我的目标是对蒙版元素子集执行昂贵的操作,并用零表示剩余元素。我将从一个例子开始,在这个例子中,我使用sum来代替昂贵的操作:
# shape: (3,3,2)
in = [ [ [ 1, 2 ], [ 2, 3 ], [ 3, 4 ] ],
[ [ 4, 5 ], [ 5, 6 ], [ 6, 7 ] ],
[ [ 7, 8 ], [ 8, 9 ], [ 9, 0 ] ] ]
# shape: (3,3)
mask = [ [ 0, 1, 0 ],
[ 1, 0, 0 ],
[ 0, 0, 0 ] ]
# expected sum output:
# shape: (3,3,1)
out = [ [ [ 0 ], [ 5 ], [ 0 ] ],
[ [ 9 ], [ 0 ], [ 0 ] ],
[ [ 0 ], [ 0 ], [ 0 ] ] ]
我能在一层之外工作A
是我的掩码,E
是我的输入数据,N
是我行轴和列轴中的元素数
分区过程的副作用是扁平化,因此我需要将其重新整形为原始的行/列维度,但使用新的通道维度
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import tensorflow as tf
import numpy as np
def test1():
N = 3
testA = [[[0., 1., 0.],
[1., 0., 0.],
[0., 0., 0.]]]
testE = [[[[1., 2.],
[3.1, 4.1],
[5., 6.], ],
[[7.1, 8.1],
[9., 1.],
[9., 2.], ],
[[8., 3.],
[7., 4.],
[6., 5.], ], ]]
testA = np.asarray(testA).astype('float32')
testE = np.asarray(testE).astype('float32')
part1 = tf.dynamic_partition( testE, testA, 2 )
print( len( part1 ) )
print( part1[0] )
print( part1[1] )
"""
2
tf.Tensor(
[[1. 2.]
[5. 6.]
[9. 1.]
[9. 2.]
[8. 3.]
[7. 4.]
[6. 5.]], shape=(7, 2), dtype=float32)
tf.Tensor(
[[3.1 4.1]
[7.1 8.1]], shape=(2, 2), dtype=float32)
"""
sum1 = tf.math.reduce_sum( part1[1], axis=-1, keepdims=1 )
print( sum1 )
"""
tf.Tensor(
[[ 7.2 ]
[15.200001]], shape=(2, 1), dtype=float32)
"""
indices1 = [
[ 0 ],
[ 2 ],
[ 4 ],
[ 5 ],
[ 6 ],
[ 7 ],
[ 8 ],
]
indices2 = [
[ 1 ],
[ 3 ],
]
indices = [ indices1, indices2 ]
partitioned_data = [
np.zeros( shape=(7,1) ),
sum1
]
stitch1_flat = tf.dynamic_stitch( indices, partitioned_data )
print( stitch1_flat )
"""
tf.Tensor(
[ 0. 7.2 0. 15.200001 0. 0. 0.
0. 0. ], shape=(9,), dtype=float32)
"""
stitch1 = tf.reshape( stitch1_flat, (N,N,1) )
print( stitch1 )
"""
tf.Tensor(
[[[ 0. ]
[ 7.2 ]
[ 0. ]]
[[15.200001]
[ 0. ]
[ 0. ]]
[[ 0. ]
[ 0. ]
[ 0. ]]], shape=(3, 3, 1), dtype=float32)
"""
stitch1_np = stitch1.numpy()
target = np.array([[[ 0. ],
[ 7.2 ],
[ 0. ]],
[[15.200001],
[ 0. ],
[ 0. ]],
[[ 0. ],
[ 0. ],
[ 0. ]]])
np.testing.assert_almost_equal( stitch1_np, target, decimal=3 )
我很难将其推广到keras/tf层。 我可以让分区工作,但我有一个困难的时间来计算正确的索引缝合回来。 我还希望在计算另一个粘滞元素的零张量时遇到麻烦。 如果您能为这些问题提供任何帮助,我们将不胜感激
我对python也很在行,所以不要以为我是故意做任何非常规的事情。 我可能只是不知道什么更好
提前谢谢
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import tensorflow as tf
import numpy as np
def test2():
class TestFlat(tf.keras.layers.Layer):
def __init__(self):
super(TestFlat, self).__init__()
self.N = -1 #size of row, column
self.S = -1 #size of input channel
def build(self, input_shape):
print( "input_shape: ", input_shape )
# TensorShape([None, 3, 3]), TensorShape([None, 3, 3, 2])]
assert( len( input_shape ) == 2 )
assert( len( input_shape[0] ) == 3 )
assert( len( input_shape[1] ) == 4 )
assert( input_shape[0][1] == input_shape[1][1] )
assert( input_shape[0][2] == input_shape[1][2] )
self.N = input_shape[0][1]
self.S = input_shape[1][3]
def call(self, inputs):
print( "inputs: ", inputs )
#[<tf.Tensor 'A_in:0' shape=(None, 3, 3) dtype=float32>,
# <tf.Tensor 'E_in:0' shape=(None, 3, 3, 2) dtype=float32>]
A = inputs[0] # mask
E = inputs[1] # data
A_int = tf.cast( A, "int32" )
part = tf.dynamic_partition( E, A_int, 2 )
print( len( part ) )
print( part[0] )
print( part[1] )
"""
2
tf.Tensor(
[[1. 2.]
[5. 6.]
[9. 1.]
[9. 2.]
[8. 3.]
[7. 4.]
[6. 5.]], shape=(7, 2), dtype=float32)
tf.Tensor(
[[3.1 4.1]
[7.1 8.1]], shape=(2, 2), dtype=float32)
"""
sum1 = tf.math.reduce_sum( part[1], axis=-1, keepdims=True )
# Okay so now we're done with the "expensive" calculation
# and we just need to merge with zeros back into our target shape of (None,N,N,1)
# Step 1: Calculate indices for stitching
# none of the rest of this works:
r = tf.range(self.N*self.N*self.S) #???
#tf.shape((None,self.N,self.N,1)) #???
s = tf.shape(E) #???
aa=tf.Variable(s) #???
aa[-1].assign( 1 ) #???
r = tf.reshape( r, s ) #???
indices = tf.dynamic_partition( r, A_int, 2 )
print( indices )
"""
partitioned_data = [
np.zeros( shape=(7,1) ),
sum1
]
"""
# Step 2: Create zero tensor
# Step 3: Stitch sum1 with zero tensor
return inputs[0] #dummy for now
N = 3
S = 2
A_in = Input(shape=(N, N), name='A_in')
E_in = Input(shape=(N, N, S), name='E_in')
out = TestFlat()( [A_in,E_in] )
model = Model(inputs=[A_in,E_in], outputs=out)
model.compile(optimizer='adam', loss='mean_squared_error')
model.summary()
testA = [[[0., 1., 0.],
[1., 0., 0.],
[0., 0., 0.]]]
testE = [[[[1., 2.],
[3.1, 4.1],
[5., 6.], ],
[[7.1, 8.1],
[9., 1.],
[9., 2.], ],
[[8., 3.],
[7., 4.],
[6., 5.], ], ]]
testA = np.asarray(testA).astype('float32')
testE = np.asarray(testE).astype('float32')
print( model([testA,testE]) )
当我知道你实际上不需要零值的指数时,我最终找到了一个答案。混合的零是隐含的,你只需要在最后一批货的末尾加上一些
这是混乱的,但它的工作请随时就如何改进提供建议。
相关问题 更多 >
编程相关推荐