保留矩阵中的元素值

2024-10-01 04:46:38 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个带有标签和图像的文本数据集。标签是表示手写数字的一维元素Dimension:(1010,)。图像是28*28像素大小的图像。Dimension:(1010, 784)。在从文本数据集读取之后,我有以下数据集reformatData['data']reformatData['target'],它们分别是[n_samples, n_features][n_samples]

同样,这些的尺寸:(1010, 784) (1010,)打印时reformatData

现在我正在尝试进行二进制分类,并在矩阵中引入数字,我尝试使用下面的函数来实现这一点

digits1=[8]
digits2=[1]


def read(digits):
    rows=28
    cols=28
    #lbl = array("b", reformatData['target'])
    lbl = reformatData['target']
    img=reformatData['data']
    #img = array("B", reformatData['data'])

    ind = [ k for k in xrange(len(lbl)) if lbl[k] in digits]
    images =  matrix(0, (len(ind), rows*cols))
    labels = matrix(0, (len(ind), 1))
    for i in xrange(len(ind)):
        images[i, :] = img[ ind[i]*rows*cols : (ind[i]+1)*rows*cols ]
        labels[i] = lbl[ind[i]]
    return images, labels

print read(digits=digits1)

输出

(<0x784 matrix, tc='i'>, <0x1 matrix, tc='i'>)

我期望:

(<1010x784 matrix, tc='i'>, <1010x1 matrix, tc='i'>)

我该怎么解决这个问题


Tags: 数据in图像targetimgdatalenmatrix
2条回答

使用numpywhere进行矢量化和更快的计算:

上面有rahfl23的数组:

np.where(s==6, 0, 1)

对于矩阵:

np.where(images==6, 0, 1)

要映射二进制分类的两个数字,请通过列表理解修改目标向量:

import numpy as np

s = np.array([6, 8, 6, 6, 6, 8, 6, 8, 8, 8, 6, 6, 6, 8, 8, 6, 8, 6, 8] ) 

new = np.array([0 if i==6 else 1 for i in s])

输出:

[0 1 0 0 0 1 0 1 1 1 0 0 0 1 1 0 1 0 1]

相关问题 更多 >