我有一个带有标签和图像的文本数据集。标签是表示手写数字的一维元素Dimension:(1010,)
。图像是28*28像素大小的图像。Dimension:(1010, 784)
。在从文本数据集读取之后,我有以下数据集reformatData['data']
和reformatData['target']
,它们分别是[n_samples, n_features]
和[n_samples]
同样,这些的尺寸:(1010, 784) (1010,)
打印时reformatData
现在我正在尝试进行二进制分类,并在矩阵中引入数字,我尝试使用下面的函数来实现这一点
digits1=[8]
digits2=[1]
def read(digits):
rows=28
cols=28
#lbl = array("b", reformatData['target'])
lbl = reformatData['target']
img=reformatData['data']
#img = array("B", reformatData['data'])
ind = [ k for k in xrange(len(lbl)) if lbl[k] in digits]
images = matrix(0, (len(ind), rows*cols))
labels = matrix(0, (len(ind), 1))
for i in xrange(len(ind)):
images[i, :] = img[ ind[i]*rows*cols : (ind[i]+1)*rows*cols ]
labels[i] = lbl[ind[i]]
return images, labels
print read(digits=digits1)
输出
(<0x784 matrix, tc='i'>, <0x1 matrix, tc='i'>)
我期望:
(<1010x784 matrix, tc='i'>, <1010x1 matrix, tc='i'>)
我该怎么解决这个问题
使用
numpy
where
进行矢量化和更快的计算:上面有rahfl23的数组:
对于矩阵:
要映射二进制分类的两个数字,请通过列表理解修改目标向量:
输出:
相关问题 更多 >
编程相关推荐