如何在Numpy中将索引数组转换为掩码数组?

2024-05-17 07:34:56 发布

您现在位置:Python中文网/ 问答频道 /正文

在给定的范围内,是否可以将索引数组转换为1和0数组? i、 e.[2,3]->;[0,0,1,1,0],范围为5

我正试着把这样的事情自动化:

>>> index_array = np.arange(200,300)
array([200, 201, ... , 299])

>>> mask_array = ???           # some function of index_array and 500
array([0, 0, 0, ..., 1, 1, 1, ... , 0, 0, 0])

>>> train(data[mask_array])    # trains with 200~299
>>> predict(data[~mask_array]) # predicts with 0~199, 300~499

Tags: andofgtdataindexwithnpfunction
3条回答

有一个很好的技巧可以将其作为一个单独的行来完成,也可以使用这样的numpy.in1dnumpy.arange函数(最后一行是关键部分):

>>> x = np.linspace(-2, 2, 10)
>>> y = x**2 - 1
>>> idxs = np.where(y<0)

>>> np.in1d(np.arange(len(x)), idxs)
array([False, False, False,  True,  True,  True,  True, False, False, False], dtype=bool)

这种方法的缺点是它比沃伦·韦克瑟提出的方法慢10-100倍。。。但这是一条单行线,可能是也可能不是你要找的。

有一种方法:

In [1]: index_array = np.array([3, 4, 7, 9])

In [2]: n = 15

In [3]: mask_array = np.zeros(n, dtype=int)

In [4]: mask_array[index_array] = 1

In [5]: mask_array
Out[5]: array([0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0])

如果掩码始终是一个范围,则可以消除index_array,并将1分配给切片:

In [6]: mask_array = np.zeros(n, dtype=int)

In [7]: mask_array[5:10] = 1

In [8]: mask_array
Out[8]: array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0])

如果需要布尔值数组而不是整数,请在创建时更改mask_arraydtype

In [11]: mask_array = np.zeros(n, dtype=bool)

In [12]: mask_array
Out[12]: 
array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False], dtype=bool)

In [13]: mask_array[5:10] = True

In [14]: mask_array
Out[14]: 
array([False, False, False, False, False,  True,  True,  True,  True,
        True, False, False, False, False, False], dtype=bool)

对于单个维度,请尝试:

n = (15,)
index_array = [2, 5, 7]
mask_array = numpy.zeros(n)
mask_array[index_array] = 1

对于多个维度,请将n维索引转换为一维索引,然后使用ravel:

n = (15, 15)
index_array = [[1, 4, 6], [10, 11, 2]] # you may need to transpose your indices!
mask_array = numpy.zeros(n)
flat_index_array = np.ravel_multi_index(
    index_array,
    mask_array.shape)
numpy.ravel(mask_array)[flat_index_array] = 1

相关问题 更多 >