Speeding up index lookups in a numpy array

Posted 2024-09-28 03:13:46


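I have a 1-D numpy array of strings (dtype='U') called ops, with 15 million elements, in which I need to find all indices where a string called op occurs, and I need to do this 83,000 times.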

So far numpy is winning the race, but it still takes 3 hours: indices = np.where(ops==op). I also tried np.unravel_index(np.where(ops.ravel()==op), ops.shape)[0][0], with no significant difference.
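For reference, the repeated-scan baseline described above can be sketched as follows. The array contents and the queries list are made-up stand-ins, not the real data. The point is only that every np.where call compares the query against every element of ops, so the total work grows with (number of queries) x (array length).

```python
import numpy as np

# Small made-up stand-in for the real 15-million-element array
ops = np.array(["add", "mul", "add", "sub", "mul"])

# Hypothetical list of strings to look up (83,000 of them in the real case)
queries = ["add", "mul", "sub"]

# Each lookup is a full scan: ops == op builds a boolean mask over the
# whole array, and np.where returns the positions where the mask is True.
results = {op: np.where(ops == op)[0] for op in queries}

for op, idx in results.items():
    print(op, idx)
```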

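I am trying a Cython approach on random data similar to the original, but it is about 40 times slower than the numpy solution. This is my first Cython code, so maybe it can be improved. Cython code: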

import numpy as np
cimport numpy as np

def get_ixs(np.ndarray data, str x, np.ndarray[int,mode="c",ndim=1] xind):
    cdef int count, n, i
    count = 0
    n = data.shape[0]
    i = 0
    while i < n:
        if (data[i] == x):
            xind[count] = i
            count += 1
        i += 1

    return xind[0:count]

1 answer
User
#1 · Posted 2024-09-28 03:13:46

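If you call get_ixs many times with the same data, the fastest solution is to preprocess data into a dict once; each string query then becomes an O(1) (constant-time) lookup.
The keys of the dict are the strings x, and the value for each key is the list of indices i that satisfy data[i] == x.
Here is the code: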

import numpy as np

data = np.array(["toto", "titi", "toto", "titi", "tutu"])

indices = np.arange(len(data))
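# Sort the data so that equal strings are adjacent; each group's list can then be
# replaced with an ndarray as soon as the string changes, which reduces memory usage.
# A stable sort keeps the indices within each group in ascending order.
indices_data_sorted = np.argsort(data, kind="stable")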
data = data[indices_data_sorted]
indices = indices[indices_data_sorted]

# construct the dict str -> ndarray of indices (use ndarray for lower memory consumption)
dict_str_to_indices = dict()
prev_str = None
list_idx = []  # list to hold the indices for a given string
for i, s in zip(indices, data):
    if s != prev_str:  
        # the current string has changed so we can construct the ndarray and store it in the dict
        if prev_str is not None:
            dict_str_to_indices[prev_str] = np.array(list_idx, dtype="int32")
        list_idx.clear()
        prev_str = s
    list_idx.append(i)
    
dict_str_to_indices[s] = np.array(list_idx, dtype="int32")  # add the ndarray for last string

def get_ixs(dict_str_to_indices: dict, x: str):
    return dict_str_to_indices[x]

print(get_ixs(dict_str_to_indices, "toto"))
print(get_ixs(dict_str_to_indices, "titi"))
print(get_ixs(dict_str_to_indices, "tutu"))

Output:

[0 2]
[1 3]
[4]

If you call get_ixs many times with the same dict_str_to_indices, this is the asymptotically optimal solution (O(1) per lookup).
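The Python-level loop above visits every element once, which may itself be slow for 15 million strings. One possible variant, shown only as a sketch, builds the same dict by splitting the sorted index array at the positions where the string changes. The function name build_index is hypothetical and not part of the original answer; the calls used (np.argsort, np.flatnonzero, np.split, np.concatenate) are standard numpy.

```python
import numpy as np

def build_index(data):
    """Map each distinct string in data to an ndarray of its positions."""
    if data.size == 0:
        return {}
    # Stable sort: equal strings become adjacent, original order is kept
    order = np.argsort(data, kind="stable")
    sorted_data = data[order]
    # Positions in the sorted array where a new string starts
    boundaries = np.flatnonzero(sorted_data[1:] != sorted_data[:-1]) + 1
    # One array of original indices per distinct string
    groups = np.split(order, boundaries)
    starts = np.concatenate(([0], boundaries))
    return {str(sorted_data[s]): g for s, g in zip(starts, groups)}

data = np.array(["toto", "titi", "toto", "titi", "tutu"])
index = build_index(data)
print(index["toto"])
print(index["titi"])
print(index["tutu"])
```

Looking up a string that never occurs raises KeyError with a plain dict; index.get(x, np.empty(0, dtype=np.intp)) would return an empty array instead, if that behaviour is preferred.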
