<p>下面是<a href="https://stackoverflow.com/a/56510832/7207392">https://stackoverflow.com/a/56510832/7207392</a>的解决方案,并做了必要的修改。为了简单起见,我对所有数据都使用<code>np.array</code>。我不是tensortflow专家,所以如果翻译不完全是直截了当的,你就得问问别人怎么做。你知道吗</p>
<pre><code>import numpy as np
def f(a1, a2, n):
N,M = a1.shape
a1p = np.concatenate([a1,np.zeros((1,a1.shape[1]),a1.dtype)], axis=0)
a2 = np.sort(a2, axis=1)
a2[:,1:][a2[:,1:]==a2[:,:-1]] = N
y,x = np.where(np.count_nonzero(a1p[a2], axis=1) >= n)
out = np.zeros_like(a1p)
out[a2[y],x[:,None]] = a1p[a2[y],x[:,None]]
return out[:-1]
a1 = np.array(
[[9.968594, 8.655439, 0., 0. ],
[0., 8.3356, 0., 8.8974 ],
[0., 0., 6.103182, 7.330564 ],
[6.609862, 0., 3.0614321, 0. ],
[9.497023, 0., 3.8914037, 0. ],
[0., 8.457685, 8.602337, 0. ],
[0., 0., 5.826657, 8.283971 ]])
a2 = np.array(
[[2, 5, 1],
[1, 6, 4],
[0, 0, 0],
[2, 3, 6],
[4, 2, 4]])
print(f(a1,a2,2))
</code></pre>
<p>输出:</p>
<pre><code>[[0. 0. 0. 0. ]
[0. 8.3356 0. 8.8974 ]
[0. 0. 6.103182 7.330564 ]
[0. 0. 3.0614321 0. ]
[0. 0. 3.8914037 0. ]
[0. 8.457685 8.602337 0. ]
[0. 0. 5.826657 8.283971 ]]
</code></pre>