Efficient way to find duplicates in an array

Posted 2024-09-27 00:22:32


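I'm currently writing a program in Python that walks through a 2D NumPy array row by row and looks for identical rows in a different array. When it finds a duplicate row, it runs a short piece of code using the index of the row from the first array.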

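This works well and is efficient enough when the arrays are small (roughly 500 rows × 2 columns each), but it quickly becomes inefficient for longer arrays. Does anyone know a way to do this with NumPy (I already use other NumPy features elsewhere, so I'd rather not change the data type), or another, more efficient approach? I'm sure there is something faster than two nested for loops over the arrays. Thanks in advance.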

import random
import numpy as np
N = 1000
speed = 50
longueur = 20000          
largeur =  30000          
quadrillage = 50 
p= 0.8               
def stick():
    u = random.random()
    if u <p:
        a = 1   #The particle is stuck
    else:
        a =0    #The particle did not stick, it will instead bounce
    return a 
obstacle_number =2000   
maxstuck = 4 
numbstuck = np.zeros((obstacle_number)) 

spacinglarg = largeur/quadrillage
spacinglong = longueur/quadrillage
obs0 = np.random.randint(0, spacinglarg,(obstacle_number,1)) *quadrillage
obs1 =  np.random.randint(0, spacinglong,(obstacle_number,1)) *quadrillage
obs = np.concatenate([obs0,obs1], axis =1)

s = (N, 2)
A = np.zeros(s)   # initial particle positions (the module-level "global A" was a no-op)
for i in range (0,N):
    a = i*longueur/N
    b = 50
    A[i,0]= b
    A[i,1]= a


T = 50*np.round(A/(50))

B=np.zeros(s)
tp = 2*np.pi
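for i in range(0, N):   # was nombre_atomes, which is undefined; N is the particle count
    aa = random.randint(0, 360) * tp / 360   # random direction, degrees -> radians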
    B[i,0]=np.cos(aa)*speed
    B[i,1]=np.sin(aa)*speed


for i in range(0, N):
    for j in range(0,len(obs)):
        if T[i,0] == obs[j,0] and T[i,1] == obs[j,1]:
            if numbstuck[j] <= maxstuck and abs(B[i,0]) != 0:    
                sss= stick()
                if sss == 1: #if it sticks
                    B[i,0]=0
                    B[i,1]=0
                    numbstuck[j] += 1
                else:
                    B[i,0]=-B[i,0] 
                    B[i,1]=-B[i,1] 

1 answer
Anonymous user
Answer #1 · posted 2024-09-27 00:22:32

The general idea is to first write the code in the straightforward way:

import numpy as np

size = 42
a = np.arange(size**2).reshape(size, size)
b = a.copy() + size * 5

def detect_duplicates(a, b):
    duplicates = []
    for i, row_a in enumerate(a):
        for j, row_b in enumerate(b):
            if np.all(row_a == row_b):
                duplicates.append((i, j))
    return duplicates

But this is fairly slow:

In [1]: %timeit detect_duplicates(a, b)
7.1 ms ± 40.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
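A NumPy-only alternative that needs no extra dependency is to compare every row of a with every row of b in a single broadcast operation. The sketch below is not part of the original answer; detect_duplicates_vectorized is a hypothetical name. It returns the same (i, j) pairs as the nested loops, but it materializes a temporary boolean array of shape (len(a), len(b), number of columns), so memory grows with the product of the two row counts.

```python
import numpy as np


def detect_duplicates_vectorized(a, b):
    """Return (i, j) pairs where a[i] equals b[j] (hypothetical helper)."""
    # a[:, None, :] has shape (n_a, 1, k) and b[None, :, :] has shape (1, n_b, k);
    # broadcasting compares every row pair element-wise -> shape (n_a, n_b, k).
    # .all(axis=2) keeps only pairs where every column matches.
    equal = (a[:, None, :] == b[None, :, :]).all(axis=2)
    # np.nonzero walks the 2D mask in row-major order (by i, then j),
    # which is the same order the nested loops produce.
    i_idx, j_idx = np.nonzero(equal)
    return list(zip(i_idx.tolist(), j_idx.tolist()))


size = 42
a = np.arange(size**2).reshape(size, size)
b = a.copy() + size * 5

pairs = detect_duplicates_vectorized(a, b)
```

With this test data, b is a shifted by five rows (b[j] equals a[j + 5]), so the result is the 37 pairs (5, 0), (6, 1), ..., (41, 36). For the arrays in the question (about 1000 × 2 against 2000 × 2) the temporary mask is a few million booleans, which is usually acceptable; for much larger inputs a hash-based lookup scales better.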

But with Numba (the @numba.njit decorator) you can speed it up dramatically without changing a single line of the loop:

import numpy as np
import numba

size = 42
a = np.arange(size**2).reshape(size, size)
b = a.copy() + size * 5

@numba.njit
def detect_duplicates(a, b):
    duplicates = []
    for i, row_a in enumerate(a):
        for j, row_b in enumerate(b):
            if np.all(row_a == row_b):
                duplicates.append((i, j))
    return duplicates

Now it is much faster:

In [1]: %timeit detect_duplicates(a, b)
235 µs ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
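Numba keeps the O(N·M) structure of the nested loops and only makes each comparison cheaper. For the simulation in the question, where both T and obs hold coordinates snapped to the same 50-unit grid, a hash lookup can remove the inner loop altogether. This is a swapped-in technique, not the answer's method: find_matching_pairs is a hypothetical helper, and the sketch assumes the coordinates compare exactly, which holds here because T is rounded to multiples of 50 and obs is a multiple of quadrillage (also 50).

```python
from collections import defaultdict

import numpy as np


def find_matching_pairs(T, obs):
    """Return (i, j) pairs with T[i] == obs[j], ordered by i then j (hypothetical helper)."""
    # Map each obstacle coordinate to every index where it occurs: several
    # random obstacles can land on the same grid point, and the original
    # nested loop visits all of them.
    positions = defaultdict(list)
    for j, row in enumerate(obs.tolist()):
        positions[tuple(row)].append(j)
    # One dictionary lookup per particle instead of a scan over all obstacles.
    # tolist() yields plain Python numbers; 50.0 == 50 and both hash the same,
    # so float rows from T find the integer keys built from obs.
    pairs = []
    for i, row in enumerate(T.tolist()):
        for j in positions.get(tuple(row), ()):
            pairs.append((i, j))
    return pairs


# Small illustrative data (made up for this example).
T = np.array([[50.0, 0.0], [50.0, 100.0], [50.0, 200.0]])
obs = np.array([[50, 100], [0, 0], [50, 100], [50, 200]])
pairs = find_matching_pairs(T, obs)  # [(1, 0), (1, 2), (2, 3)]

# In the question's script, the two nested loops would become:
#     for i, j in find_matching_pairs(T, obs):
#         if numbstuck[j] <= maxstuck and abs(B[i, 0]) != 0:
#             ...  # same stick/bounce body as before
```

Precomputing the pairs is safe in the question's loop because T is never modified inside it, and visiting the pairs in (i, j) order reproduces the original sequence of updates to B and numbstuck. Building the dictionary and doing the lookups costs roughly O(N + M) on average instead of O(N·M).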
