Numpy:用0填充ndarray的非最大元素

2024-10-20 03:24:39 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个ndarray,我想将最后一个维度中的所有非最大元素设置为零。在

a = np.array([[[1,8,3,4],[6,7,10,6],[11,12,15,4]],
              [[4,2,3,4],[4,7,9,8],[41,14,15,3]],
              [[4,22,3,4],[16,7,9,8],[41,12,15,43]]
             ])
print(a.shape)
(3,3,4)

我可以通过np.argmax公司():

^{pr2}$

显然,b的一维比a小,现在,我想得到一个新的三维数组,除了最大值所在的位置外,所有的零都是零。在

我要得到这个数组:

np.array([[[0,1,0,0],[0,0,1,0],[0,0,1,0]],
          [[1,0,0,1],[0,0,1,0],[1,0,0,0]],
          [[0,1,0,0],[1,0,0,0],[0,0,0,1]]
         ])

实现这一点的一种方法是,我尝试创建这些临时数组

b = np.repeat(b[:,:,np.newaxis], 4, axis=2)
t = np.repeat(np.arange(4).reshape(4,1), 9, axis=1).T.reshape(b.shape)

z = np.zeros(shape=a.shape, dtype=int)
z[t == b] = 1
z
array([[[0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 1, 0]],

   [[1, 0, 0, 0],
    [0, 0, 1, 0],
    [1, 0, 0, 0]],

   [[0, 1, 0, 0],
    [1, 0, 0, 0],
    [0, 0, 0, 1]]])

你知道如何更有效地得到这个吗?在


Tags: 方法元素np公司数组arrayrepeatprint
1条回答
网友
1楼 · 发布于 2024-10-20 03:24:39

这里有一种使用广播的方法:

In [108]: (a == a.max(axis=2, keepdims=True)).astype(int)
Out[108]: 
array([[[0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0]],

       [[1, 0, 0, 1],
        [0, 0, 1, 0],
        [1, 0, 0, 0]],

       [[0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 1]]])

相关问题 更多 >