numpy阵列上具有特定条件的滤波张量流阵列

2024-09-30 22:18:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个tensorflow数组namestf-array和一个numpy数组namesnp_array。我想在tf_array中找到关于np-array的特定行。你知道吗

    tf-array = tf.constant(
                [[9.968594,  8.655439,  0.,        0.       ],
                 [0.,        8.3356,    0.,        8.8974   ],
                 [0.,        0.,        6.103182,  7.330564 ],
                 [6.609862,  0.,        3.0614321, 0.       ],
                 [9.497023,  0.,        3.8914037, 0.       ],
                 [0.,        8.457685,  8.602337,  0.       ],
                 [0.,        0.,        5.826657,  8.283971 ]])

我还有一个np数组:

np_array = np.matrix(
 [[2, 5, 1],
  [1, 6, 4],
  [0, 0, 0],
  [2, 3, 6],
  [4, 2, 4]]

现在我想保留tf-array中的元素,其中n(here n is 2)的组合(它们的索引)的值为np-array。这是什么意思?你知道吗

例如,在tf-array的第一列中,有值的索引是:(0,3,4)。在np-array中是否有任何行包含这两个索引的任意组合:(0,3), (0,4) or (3,4)。其实,没有这样的争吵。因此,该列中的所有元素都变成了zero。你知道吗

tf-array中第二列的索引是(0,1) (0,5) (1,5)。如您所见,记录(1,5)在第一行的np-array中可用。这就是为什么我们把它们保存在tf-array。你知道吗

所以最终结果应该是这样的:

[[0.        0.        0.        0.       ]
 [0.        8.3356    0.        8.8974   ]
 [0.        0.        6.103182  7.330564 ]
 [0.        0.        3.0614321 0.       ]
 [0.        0.        3.8914037 0.       ]
 [0.        8.457685  8.602337  0.       ]
 [0.        0.        5.826657  8.283971 ]]

我正在寻找一个非常有效的方法,因为我有大量的数据。你知道吗

更新1

我可以用下面的代码得到它,它给出了True,其中有值和false的零掩码:

[[ True  True False False]
 [False  True False  True]
 [False False  True  True]
 [ True False  True False]
 [ True False  True False]
 [False  True  True False]
 [False False  True  True]]

with tf.Session() as sess:  
 where = tf.not_equal(tf-array, 0.0)
 print(sess.run(where))

但是我怎样才能将这些矩阵与np_array进行比较呢?你知道吗

提前谢谢!你知道吗


Tags: numpyfalsetrue元素tftensorflownp数组
2条回答

下面是https://stackoverflow.com/a/56510832/7207392的解决方案,并做了必要的修改。为了简单起见,我对所有数据都使用np.array。我不是tensortflow专家,所以如果翻译不完全是直截了当的,你就得问问别人怎么做。你知道吗

import numpy as np

def f(a1, a2, n):
    N,M = a1.shape
    a1p = np.concatenate([a1,np.zeros((1,a1.shape[1]),a1.dtype)], axis=0)
    a2 = np.sort(a2, axis=1)
    a2[:,1:][a2[:,1:]==a2[:,:-1]] = N
    y,x = np.where(np.count_nonzero(a1p[a2], axis=1) >= n)
    out = np.zeros_like(a1p)
    out[a2[y],x[:,None]] = a1p[a2[y],x[:,None]]
    return out[:-1]

a1 = np.array(
    [[9.968594,  8.655439,  0.,        0.       ],
     [0.,        8.3356,    0.,        8.8974   ],
     [0.,        0.,        6.103182,  7.330564 ],
     [6.609862,  0.,        3.0614321, 0.       ],
     [9.497023,  0.,        3.8914037, 0.       ],
     [0.,        8.457685,  8.602337,  0.       ],
     [0.,        0.,        5.826657,  8.283971 ]])

a2 = np.array(
 [[2, 5, 1],
  [1, 6, 4],
  [0, 0, 0],
  [2, 3, 6],
  [4, 2, 4]])

print(f(a1,a2,2))

输出:

[[0.        0.        0.        0.       ]
 [0.        8.3356    0.        8.8974   ]
 [0.        0.        6.103182  7.330564 ]
 [0.        0.        3.0614321 0.       ]
 [0.        0.        3.8914037 0.       ]
 [0.        8.457685  8.602337  0.       ]
 [0.        0.        5.826657  8.283971 ]]

您可以尝试的一种有效方法是为每行设置位标志(0,3,4)的值是1<;<;0 | 1<;<;3 | 1<;<;4。您将有一个带有旗帜。试试看如果<;<;和|运算符在numpy中工作。 对另一个数组也一样,我猜tf数组只是包裹的numpy。 在拥有2个标志数组后,在这些标志上按位“and”。如果行的条件为true,则结果将至少有两个非零位。也有点可以做也有效率,谷歌为之。你知道吗

这个函数永远不会与float一起工作-您需要将它们转换为非常小的整数。你知道吗

import numpy as np



arr_one =  np.array(
 [[2, 5, 1],
  [1, 6, 4],
  [0, 0, 0],
  [2, 3, 6],
  [4, 2, 4]])

arr_two =  np.array(
 [[2, 0, 7],
  [1, 3, 4],
  [5, 5, 6],
  [1, 3, 6],
  [4, 2, 4]])




print('1 << arr_one.T[0] ' , 1 << arr_one.T[0] )


arr_one_flags = 1 << arr_one.T[0] | 1 << arr_one.T[1] | 1 << arr_one.T[2]

print('arr_one_flags ', arr_one_flags)

arr_two_flags = 1 << arr_two.T[0] | 1 << arr_two.T[1] | 1 << arr_two.T[2]

arr_and = arr_one_flags & arr_two_flags

print('arr_and ', arr_and)



def get_bit_count(value):
    n = 0
    while value:
        n += 1
        value &= value-1
    return n

arr_matches = np.array([get_bit_count(x) for x in arr_and])


print('arr_matches ', arr_matches )


arr_two_filtered = arr_two[arr_matches > 1]

print('arr_two_filtered ', arr_two_filtered )

相关问题 更多 >