在可变长度numpy数组的列中,基于数组的dict查找创建新数组

2024-09-30 01:22:59 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在寻找一种性能化或矢量化的方法来完成这个场景,其中我有一个由可变长度数组组成的初始列(称为“数据”),并创建一个新列(也可以是numpy数组),其中原始数组中的值被dict中的某个查找值替换

在我的示例中,输出列“lookedup_values”是我试图快速创建的

import pandas as pd
import numpy as np 

def map_my_values(row):
  return [lookup.get(val) for val in row]

lookup = {10: 234234, 20: 253458, 30: 99934, 40: 90083, 50: 55847, 70: 99938, 100: 325230}

df = pd.DataFrame([
                   {'id':1234, 'data': np.array([10, 20, 30])},
                   {'id':1235, 'data': np.array([50, 70])},
                   {'id': 1236,'data': np.array([20, 10])},
                   {'id': 1237,'data': np.array([100, 30, 50, 10])}
])
df['lookedup_values'] = df['data'].map(map_my_values)
df.head()

     id               data                 lookedup_values
0  1234       [10, 20, 30]         [234234, 253458, 99934]
1  1235           [50, 70]                  [55847, 99938]
2  1236           [20, 10]                [253458, 234234]
3  1237  [100, 30, 50, 10]  [325230, 99934, 55847, 234234]

在我的“data”列中使用.map或apply并迭代数组很简单,但是在数千万行的数据集上它的速度非常慢。我希望有人知道矢量化的解决方案或方法

我已经在上面提供了一个完整/有效的演示。谢谢


Tags: 数据方法importnumpyidmapdfdata
1条回答
网友
1楼 · 发布于 2024-09-30 01:22:59

我发现减少运行时间的唯一方法是使用^{}

from numba import njit

@njit
def numba_map(row):
    lookup = {10: 234234, 20: 253458, 30: 99934, 40: 90083, 50: 55847, 70: 99938, 100: 325230}

    return np.array([lookup[val] for val in row])

df['lookedup_values'] = df['data'].map(numba_map)

可以在映射之前^{}data和按索引^{}

x = df.data.explode().map(lookup)
df['lookedup_values'] = x.groupby(x.index).apply(np.array)
df

输出:

     id               data                 lookedup_values
0  1234       [10, 20, 30]         [234234, 253458, 99934]
1  1235           [50, 70]                  [55847, 99938]
2  1236           [20, 10]                [253458, 234234]
3  1237  [100, 30, 50, 10]  [325230, 99934, 55847, 234234]

微观基准

在具有2个内核和12 GB RAM的colab实例上,pandas解决方案比具有理解能力的原始解决方案慢约10倍。我没料到

结果显示len(df)/4在x轴上

results

基准代码

import pandas as pd

def map_my_values(row):
  return [lookup.get(val) for val in row]

lookup = {10: 234234, 20: 253458, 30: 99934, 40: 90083, 50: 55847, 70: 99938, 100: 325230}

from numba import njit
@njit
def numba_map(row):
  lookup = {10: 234234, 20: 253458, 30: 99934, 40: 90083, 50: 55847, 70: 99938, 100: 325230}
  return np.array([lookup[val] for val in row])

def makedata(n=1):
  return pd.concat([pd.DataFrame([
                   {'id':1234, 'data': np.array([10, 20, 30])},
                   {'id':1235, 'data': np.array([50, 70])},
                   {'id': 1236,'data': np.array([20, 10])},
                   {'id': 1237,'data': np.array([100, 30, 50, 10])}
])]*n).reset_index(drop=True)

def comprehension(df):
  df['lookedup_values'] = df['data'].map(map_my_values)
  return df

def explode(df):
  x = df.data.explode().map(lookup)
  df['lookedup_values'] = x.groupby(x.index).apply(np.array)
  return df

def numbamap(df):
  df['lookedup_values'] = df['data'].map(numba_map)
  return df

import perfplot
perfplot.show(
    setup=makedata,
    kernels=[comprehension, explode, numbamap],
    n_range=[2**k for k in range(5,18)],
    equality_check=False
)

相关问题 更多 >

    热门问题