Numpy矢量化元素为元组的位置

2024-09-27 09:33:46 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图理解np.vectorize是如何工作的。我想遍历一个数组,其中每个元素都是一个元组,然后修改元组中的每个元素,但这似乎不起作用

arr = [(1,3), (9,12)]

def func(i):
    return (i[0]+1, i[1]+2)

vfunc = np.vectorize(func)
vfunc(arr)

这将是预期的产出:

[(2,5), (10,14)]

请帮助我理解为什么它不遍历每个元组,以及我将如何实现这一点-而不使用for循环


Tags: 元素forreturndefnp数组元组func
1条回答
网友
1楼 · 发布于 2024-09-27 09:33:46

使用

[func(i) for i in arr]

np.vectorize,即使在它工作的地方也不会更快

np.array(arr)+np.array([[1,2]])

对于真正的numpy“矢量化”计算,将(1,2)数组添加到(2,2)arr

附言

arr是元组的列表。如果您可以构造一个元组数组,我将重新讨论这个问题

编辑

我们应该反对,因为您没有提供错误消息!尤其是当你试图理解vectorize

In [307]: np.array(arr)
Out[307]: 
array([[ 1,  3],
       [ 9, 12]])
In [308]: vfunc = np.vectorize(func)
     ...: vfunc(arr)
Traceback (most recent call last):
  File "<ipython-input-308-61dee032d8f1>", line 2, in <module>
    vfunc(arr)
  File "/usr/local/lib/python3.8/dist-packages/numpy/lib/function_base.py", line 2163, in __call__
    return self._vectorize_call(func=func, args=vargs)
  File "/usr/local/lib/python3.8/dist-packages/numpy/lib/function_base.py", line 2241, in _vectorize_call
    ufunc, otypes = self._get_ufunc_and_otypes(func=func, args=args)
  File "/usr/local/lib/python3.8/dist-packages/numpy/lib/function_base.py", line 2201, in _get_ufunc_and_otypes
    outputs = func(*inputs)
  File "<ipython-input-306-ab62cb78d2b7>", line 4, in func
    return (i[0]+1, i[1]+2)
IndexError: invalid index to scalar variable.

函数中出现错误是因为您试图为标量变量编制索引i是一个数字,不是元组

测试我最初的建议:

In [309]: [func(i) for i in arr]
Out[309]: [(2, 5), (10, 14)]
In [310]: np.array(arr)+np.array([1,2])
Out[310]: 
array([[ 2,  5],
       [10, 14]])

或制作一个2元素数组:

In [311]: A = np.empty(2, object)
In [312]: A[:]=arr
In [313]: A
Out[313]: array([(1, 3), (9, 12)], dtype=object)
In [314]: A[0]
Out[314]: (1, 3)
In [315]: vfunc(A)
Out[315]: (array([ 2, 10]), array([ 5, 14]))

将此AOut[307]进行比较。非常不同的阵列

另一种诊断方法是向func添加打印

In [317]: def func(i):
     ...:     print(i)
     ...:     return (i[0]+1, i[1]+2)
     ...: 
In [318]: vfunc = np.vectorize(func)
In [319]: vfunc(arr)
1
Traceback (most recent call last):
  File "<ipython-input-319-62576075c0b8>", line 1, in <module>
 ...
IndexError: invalid index to scalar variable.

vectorized传递np.array(arr)的第一个元素,而不是arr的第一个元组,标量1

应用于对象数组:

In [320]: vfunc(A)
(1, 3)               # the documented trial call
(1, 3)
(9, 12)
Out[320]: (array([ 2, 10]), array([ 5, 14]))

现在我们看到它正在向函数传递元组

相关问题 更多 >

    热门问题