我试图下载一个项目的MNIST手写数字数据库的一部分。具体来说,我只希望数字0,1,2和3被发送到神经网络。你知道吗
我当前正在加载这样的数据(基于"Neural Networks and Deep Learning" by Michal Daniel Dobrzanski):
import cPickle
import gzip
import numpy as np
def load_data():
f = gzip.open('src/mnist.pkl.gz', 'rb')
training_data, validation_data, test_data = cPickle.load(f)
f.close()
return (training_data, validation_data, test_data)
def load_data_wrapper():
tr_d, va_d, te_d = load_data()
training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
training_results = [vectorized_result(y) for y in tr_d[1]]
training_data = zip(training_inputs, training_results)
validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
validation_data = zip(validation_inputs, va_d[1])
test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
test_data = zip(test_inputs, te_d[1])
return (training_data, validation_data, test_data)
在发送到load_data_wrapper()
之前,我尝试从load_data()
构建一个函数来创建新的数据集(通过在load_data_wrapper()
中将tr_d, va_d, te_d = load_data()
更改为tr_d, va_d, te_d = digitTest()
),但没有成功,请参见下面的内容:
def digitTest():
tr_d, va_d, te_d = load_data()
tr_d = list(tr_d)
va_d = list(va_d)
te_d = list(te_d)
newTrD = []
newTrD.append([])
newTrD.append([])
newVaD = []
newVaD.append([])
newVaD.append([])
newTeD = []
newTeD.append([])
newTeD.append([])
for index,label in enumerate(tr_d[1]):
if tr_d[1][index] < 4:
newTrD[0].append(tr_d[0][index])
newTrD[1].append(tr_d[1][index])
for index,label in enumerate(va_d[1]):
if va_d[1][index] < 4:
newVaD[0].append(va_d[0][index])
newVaD[1].append(va_d[1][index])
for index,label in enumerate(te_d[1]):
if te_d[1][index] < 4:
newTeD[0].append(te_d[0][index])
newTeD[1].append(te_d[1][index])
return (newTrD, newVaD, newTeD)
有可能达到我想要的吗?我该怎么做?请注意,从load\ U data函数解析时,数据存储在元组中。你知道吗
我从未使用cPickle加载mnist数据集,也不知道它返回什么。 阅读你的代码看起来你做的事情是对的,但是如果你说它不起作用,我想有些事情与cPickle返回数据的内容或方式有关。你知道吗
我没有python 2,因此无法调试您的代码,但是:
我倾向于自己做这些事情:
此函数将从文件中加载一组mnist标签和值。然后你就可以把数据解压了。 标签应该是“火车”-idx1.1标签-“乌比特”。只需将路径传递给train label和images或test label和images到函数中,它就会加载这些值。你知道吗
返回值是两个列表的元组:
其中像素是一个列表本身。你知道吗
此外,如果文件不存在或(可能)文件格式不正确,除了抛出异常之外,这不会进行错误检查,因此您可能需要考虑以某种方式进行检查。你知道吗
我也不习惯numpy,我通常在c++和java中工作,但您肯定可以很容易地将这些值转换为numpy数组—只需阅读本主题即可。你知道吗
你现在应该可以很容易地使用数字过滤了。你知道吗
您可能会看到,如果使用原始mnist数据集,则只会得到train和test图像。这里发生的事情是,你从其中一个集合中拿出一部分,用它作为——我不完全确定你在这里的措辞——测试数据来评估训练进度。培训结束后,您可以使用“t10k”文件来验证您的网络培训效果。这里很重要的一点是,如果你从这些t10k图像中分割测试数据,你就不会再使用这些图像了,只剩下剩下一部分,因为这样做的目的是验证网络上还没有看到的数据的训练。你知道吗
相关问题 更多 >
编程相关推荐