检查numpy数组中的不一致/重复值并编制索引

2024-05-18 07:31:20 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个包含对象id的数组traced_descIDs,我想确定哪些项在这个数组中不是唯一的。然后,对于每个唯一的重复(小心)ID,我需要确定traced_descIDs的哪些索引与之关联。在

例如,如果我们在这里使用跟踪的描述,我希望发生以下过程:

traced_descIDs = [1, 345, 23, 345, 90, 1]
dupIds = [1, 345]
dupInds = [[0,5],[1,3]]

我正在通过以下方法找出哪些对象具有多个条目:

^{pr2}$

但是,这需要太长时间,因为len( traced_descIDs )大约是150000。有没有更快的方法达到同样的效果?在

非常感谢任何帮助。干杯。在


Tags: 对象方法idlen过程条目数组小心
3条回答

虽然字典是O(n),但是Python对象的开销有时使使用numpy的函数更加方便,这些函数使用sorting和are O(n*logn)。在您的案例中,起点是:

a = [1, 345, 23, 345, 90, 1]
unq, unq_idx, unq_cnt = np.unique(a, return_inverse=True, return_counts=True)

如果您使用的是早于1.9的numpy版本,那么最后一行必须是:

^{pr2}$

我们创建的三个数组的内容是:

>>> unq
array([  1,  23,  90, 345])
>>> unq_idx
array([0, 3, 1, 3, 2, 0])
>>> unq_cnt
array([2, 1, 1, 2])

要获取重复项:

cnt_mask = unq_cnt > 1
dup_ids = unq[cnt_mask]

>>> dup_ids
array([  1, 345])

获取指数需要更多的工作,但非常简单:

cnt_idx, = np.nonzero(cnt_mask)
idx_mask = np.in1d(unq_idx, cnt_idx)
idx_idx, = np.nonzero(idx_mask)
srt_idx = np.argsort(unq_idx[idx_mask])
dup_idx = np.split(idx_idx[srt_idx], np.cumsum(unq_cnt[cnt_mask])[:-1])

>>> dup_idx
[array([0, 5]), array([1, 3])]

您当前的方法是O(N**2),使用字典在O(N)时间内完成:

>>> from collections import defaultdict
>>> traced_descIDs = [1, 345, 23, 345, 90, 1]
>>> d = defaultdict(list)
>>> for i, x in enumerate(traced_descIDs):
...     d[x].append(i)
...     
>>> for k, v in d.items():
...     if len(v) == 1:
...         del d[k]
...         
>>> d
defaultdict(<type 'list'>, {1: [0, 5], 345: [1, 3]})

为了得到项目和索引:

^{pr2}$

请注意,如果您想保持dupIds中项目的顺序,那么就使用^{}和{a2}方法。在

有一个^{}表示每个项目的频率:

>>> xs = np.array([1, 345, 23, 345, 90, 1])
>>> ifreq = sp.stats.itemfreq(xs)
>>> ifreq
array([[  1,   2],
       [ 23,   1],
       [ 90,   1],
       [345,   2]])
>>> [(xs == w).nonzero()[0] for w in ifreq[ifreq[:,1] > 1, 0]]
[array([0, 5]), array([1, 3])]

相关问题 更多 >