NumPy: getting the exact indices (args) of repeated elements in a 2D array

Posted 2024-09-30 01:28:14


I have two 2D arrays, a and b. I want to find the exact indices of the rows of a in b. I followed the solution proposed here.

The problem is that my arrays contain duplicate rows, as you can see here:

import numpy as np

# The shape of b is (50, 2)
b = np.array([[ 0,  1],[ 2,  3],[ 4,  5],[ 6,  7], [ 0,  1],
             [10, 11], [12, 13], [14, 15], [16, 17], [10, 11],
             [20, 21], [22, 23], [24, 25], [26, 27], [20, 21],
             [30, 31], [32, 33], [34, 35], [36, 37], [30, 31],
             [40, 41], [42, 43], [44, 45], [46, 47], [40, 41],
             [50, 51], [52, 53], [54, 55], [56, 57], [50, 51],
             [60, 61], [62, 63], [64, 65], [66, 67], [60, 61],
             [70, 71], [72, 73], [74, 75], [76, 77], [70, 71],
             [80, 81], [82, 83], [84, 85], [86, 87], [80, 81],
             [90, 91], [92, 93], [94, 95], [96, 97], [90, 91]])

# The shape of a is (20,2)
a = np.array([[ 0,  1],[ 2,  3], [ 4,  5],[ 6,  7],[ 0,  1],
       [50, 51],[52, 53], [54, 55], [56, 57], [50, 51],
       [20, 21], [22, 23], [24, 25], [26, 27], [20, 21],
       [70, 71], [72, 73], [74, 75], [76, 77], [70, 71]])

Now, when I try something like this:

# See approach 2 in the solution linked above
def view1D(a, b): # a, b are arrays
    a = np.ascontiguousarray(a)
    b = np.ascontiguousarray(b)
    void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
    return a.view(void_dt).ravel(),  b.view(void_dt).ravel()

def argwhere_nd_searchsorted(a,b):
    A,B = view1D(a,b)
    sidxB = B.argsort()
    mask = np.isin(A,B)
    cm = A[mask]
    idx0 = np.flatnonzero(mask)
    idx1 = sidxB[np.searchsorted(B,cm, sorter=sidxB)]
    return idx0, idx1 # idx0 : indices in A, idx1 : indices in B

args0, args1 = argwhere_nd_searchsorted(a,b)

Result:

#args0
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,17, 18, 19])

#args1
 array([ 0,
  1,
  2,
  3,
  0, # this should be 4
 25,
 26,
 27,
 28,
 25, # this should be 29
 10,
 11,
 12,
 13,
 10,# this should be 14
 39,# this should be 35
 36,
 37,
 38,
 39])
# if we check
np.equal(b[args1],a).all() # This returns True

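As you can see, the problem is at the marked indices, which belong to the duplicate rows: each duplicate is mapped to another copy of the same row in b rather than to its own position. My expected results are shown in the commented lines.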
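The collapsing of duplicates can be reproduced on a much smaller case. The sketch below is only an illustration, not part of the original code: it uses a 1D array `keys` as a hypothetical stand-in for the row keys and shows that `np.searchsorted` with its default `side='left'` returns the first position of a repeated value, which is also why the equality check above still returns True.

```python
import numpy as np

# Hypothetical sorted 1D stand-in for the row keys; the value 10 appears twice.
keys = np.array([10, 10, 20, 30])
queries = np.array([10, 10, 30])

# With the default side='left', every query for 10 gets the first
# matching position, so both copies land on index 0.
left = np.searchsorted(keys, queries)
print(left)  # [0 0 3]

# The looked-up values still equal the queries, even though the
# second copy of 10 was not mapped to its own position (index 1).
print(np.array_equal(keys[left], queries))  # True
```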

Thanks for your help.


1 answer
User
#1 · Posted 2024-09-30 01:28:14

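We can add one more column of IDs that numbers the duplicates among the rows, and then use the same steps as before. We will use pandas to get those IDs, since that is simpler. So, simply do: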

import pandas as pd

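def assign_duplbl(a):
    # Label each row with its running occurrence count among identical
    # rows: 1 for the first copy, 2 for the second, and so on.
    df = pd.DataFrame(a)
    df['num'] = 1
    return df.groupby(list(range(a.shape[1]))).cumsum().values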

a1 = np.hstack((a,assign_duplbl(a)))
b1 = np.hstack((b,assign_duplbl(b)))
args0, args1 = argwhere_nd_searchsorted(a1,b1)
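If pandas is not available, the same duplicate labels can probably be computed with NumPy alone. The following is a minimal sketch under that assumption, not the answerer's method: `occurrence_ids` is a hypothetical helper name, and it assumes a 2D array of rows. It groups identical rows with `np.unique(..., axis=0, return_inverse=True)` and uses a stable sort so that copies are numbered in their original order.

```python
import numpy as np

def occurrence_ids(rows):
    """Hypothetical helper: 1-based occurrence count of each row among
    identical rows, returned as a column of shape (n, 1)."""
    rows = np.asarray(rows)
    n = rows.shape[0]
    if n == 0:
        return np.empty((0, 1), dtype=np.intp)
    # inverse[i] is the group id of row i among the unique rows.
    _, inverse = np.unique(rows, axis=0, return_inverse=True)
    inverse = inverse.ravel()  # guard against shape differences across NumPy versions
    # A stable sort keeps the original order inside each group.
    order = np.argsort(inverse, kind="stable")
    sorted_groups = inverse[order]
    # Position where each group starts in the sorted order.
    is_start = np.r_[True, sorted_groups[1:] != sorted_groups[:-1]]
    start_idx = np.flatnonzero(is_start)
    group_sizes = np.diff(np.r_[start_idx, n])
    # Rank inside the group = sorted position minus the group's start.
    ranks_sorted = np.arange(n) - np.repeat(start_idx, group_sizes)
    ids = np.empty(n, dtype=np.intp)
    ids[order] = ranks_sorted + 1
    return ids[:, None]

# Small usage example: the second copy of [0, 1] gets label 2.
a_small = np.array([[0, 1], [2, 3], [0, 1]])
augmented = np.hstack((a_small, occurrence_ids(a_small)))
print(augmented)
# [[0 1 1]
#  [2 3 1]
#  [0 1 2]]
```

The augmented arrays would then be passed to argwhere_nd_searchsorted in the same way as a1 and b1 above. Because every row now carries its own occurrence number, no two rows within one array are identical, so each row of a can match only one row of b.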
