MXNET无效类型“<type”努比·恩达雷“>”表示数据,应为NDArray,努比·恩达雷,

2024-09-22 16:25:46 发布

您现在位置:Python中文网/ 问答频道 /正文

我在使用mxnet的基本IO时遇到问题。我试图使用mxnet.io.NDArrayIter来读取内存中的数据集,以便在mxnet中进行培训。我有下面的代码(为了简洁起见,压缩了代码),它对代码进行预处理,并尝试遍历它(主要基于tutorial):

import csv
import mxnet as mx
import numpy as np

from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.pipeline import Pipeline


with open('data.csv', 'r') as data_file:
    data = list(csv.reader(data_file))

labels = np.array(map(lambda x: x[1], data)) # one-hot encoded classes
data = map(lambda x: x[0], data) # raw text in need of pre-processing

transformer = Pipeline(steps=(('count_vectorizer', CountVectorizer()),
                              ('tfidf_transformer', TfidfTransformer())))

preprocessed_data = np.array([np.array(row) for row in transformer.fit_transform(data)])

training_data = mx.io.NDArrayIter(data=preprocessed_data, label=labels, batch_size=50)

for i, batch in enumerate(training_data):
    print(batch)

执行此代码时,我收到以下错误:

^{pr2}$

我不明白,因为在创建NDArrayIter实例之前,我的数据被转换为numpy.ndarray。有人愿意提供一些关于如何读取mxnet中数据的见解吗?在

以上代码当前使用以下版本:

  • mxnet-1.1.0版
  • 数字-1.14.2

Tags: csv数据代码inioimportdataas
1条回答
网友
1楼 · 发布于 2024-09-22 16:25:46

user2357112的帮助下,通过在python3中使用异常链接来查找异常(更新有问题):

transformer管道返回的是scipy.sparse.csr_matrix矩阵的numpy.array,而不是二维numpy.array。通过添加更改以下行以使用toarray方法进行转换,脚本将运行。在

preprocessed_data = np.array([row.toarray() for row in transformer.fit_transform(data)])

最佳解决方案:当在scipy.sparse.csr_matrix上使用时,toarray在内存消耗方面效率低下。在mxnet1.10版本中,可以使用mxnet.nd.sparse.array来更有效地存储数据:

^{pr2}$

唯一的警告是必须在NDArrayIter(功能last_batch_handlehere)中使用last_batch_handle='discard'关键字参数

相关问题 更多 >