查找numpy数组中的所有NaN切片

2024-05-20 14:16:56 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个四维Numpy ndarray(时间、压力水平、纬度、经度),我想检查每个时间和压力水平(维度0和1),如果在纬度或经度维度(2和3)上有一个全NaN切片。 我想用矢量化的方式来处理它,这样就不用在数组上循环了,但是我不知道怎么做

import numpy as np
a=np.ones([2,3,5,5])
a[0,2,:,2]=np.nan*np.ones_like(a[0,2,:,2])
a[0,1,1,:]=np.nan*np.ones_like(a[0,1,1,:])
a[0,0,1,2]=np.nan
a[1,1,:,2]=np.nan*np.ones_like(a[0,2,:,2])
a[1,1,1,:]=np.nan*np.ones_like(a[0,1,1,:])
print(a)

该数组现在保存1(即数字),并且在某些位置仅保存NaN的切片。我想知道这些地点。因此在本例中,我需要发现NaN片位于[0,2,:,2],[0,1,1,:],[1,1,:,2]和[1,1,1,:]


Tags: numpynp时间ones水平切片数组nan
1条回答
网友
1楼 · 发布于 2024-05-20 14:16:56

您应该使用np.isnan函数来创建与原始矩阵大小相同的布尔矩阵。然后只需使用np.all之类的布尔约简操作。因此,下面的代码将所有元素都等于np.nan的行(轴=1)的索引存储在idx

arr = np.array([[0, 0, 0], [np.nan, np.nan, np.nan], [1, np.nan, 1]])
arr_isnan = np.isnan(arr)
idx = np.argwhere(arr_isnan.all(axis=1))

输出:

>>>print(idx)
[[1]]

按照您的示例,此方法提供以下输出:

arr_isnan = np.isnan(a)
idx = np.argwhere(arr_isnan.all(axis=2))

>>>print(idx) #[0,2,:,2] and [1,1,:,2] because axis=2
array([[0, 2, 2],
       [1, 1, 2]], dtype=int64)

>>>print(a[idx[:,0], idx[:,1], :, idx[:,2]])
[[nan nan nan nan nan]
 [nan nan nan nan nan]]

因此,您只需根据轴调整“:”的位置

相关问题 更多 >