如何在NumPy中沿轴切片多维数组?

2024-09-28 17:18:19 发布

您现在位置:Python中文网/ 问答频道 /正文

假设我有一个x形状(n,) + higher_dims的数组,其中n是一个正整数,higher_dims是一个任意长度的正整数元组。 也就是说,n是第一个轴的大小,可以有任意多个轴。你知道吗

假设我还有一个形状为(k,) + higher_dimsindices数组,其中k是一个正整数。 也就是说,indicesx具有相同的形状,除了可能的第一个轴。 假设indices的每个条目都是0n - 1之间的整数。你知道吗

我想创建一个数组y,它的形状与indices相同,并且满足

y[i, ...] = x[indices[i, ...], ...]

对于0n - 1之间的每个i。这里...表示剩余轴的索引的任意组合,而不是Ellipsis object。你知道吗

例如,如果x是三维的,我可以使用For循环来创建y

import numpy as np

x = np.arange(24).reshape((4, 2, 3))
print('x =', x, sep='\n')

indices = np.asarray([[[1, 0, 1], [2, 1, 2]], [[3, 1, 2], [0, 0, 1]]])
print('indices =', indices, sep='\n')

y = np.empty(indices.shape, dtype=x.dtype)
for i in range(indices.shape[0]):
    for j in range(indices.shape[1]):
        for k in range(indices.shape[2]):
            y[i, j, k] = x[indices[i, j, k], j, k]  # Defining property of y
print('y =', y, sep='\n')

输出:

x =
[[[ 0  1  2]
  [ 3  4  5]]
 [[ 6  7  8]
  [ 9 10 11]]
 [[12 13 14]
  [15 16 17]]
 [[18 19 20]
  [21 22 23]]]
indices =
[[[1 0 1]
  [2 1 2]]
 [[3 1 2]
  [0 0 1]]]
y =
[[[ 6  1  8]
  [15 10 17]]
 [[18  7 14]
  [ 3  4 11]]]

I am looking for a function or indexing trick to achieve this behavior in general (for ndarrays of arbitrary dimension), without Python loops if possible.


Tags: ofinfornprange数组sep形状