只下载某些MNIST数字

2024-10-03 21:32:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图下载一个项目的MNIST手写数字数据库的一部分。具体来说,我只希望数字0,1,2和3被发送到神经网络。你知道吗

我当前正在加载这样的数据(基于"Neural Networks and Deep Learning" by Michal Daniel Dobrzanski):

import cPickle
import gzip
import numpy as np

def load_data():
    f = gzip.open('src/mnist.pkl.gz', 'rb')
    training_data, validation_data, test_data = cPickle.load(f)
    f.close()
    return (training_data, validation_data, test_data)

def load_data_wrapper():
    tr_d, va_d, te_d = load_data()
    training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    training_data = zip(training_inputs, training_results)
    validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
    validation_data = zip(validation_inputs, va_d[1])
    test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
    test_data = zip(test_inputs, te_d[1])
    return (training_data, validation_data, test_data)

在发送到load_data_wrapper()之前,我尝试从load_data()构建一个函数来创建新的数据集(通过在load_data_wrapper()中将tr_d, va_d, te_d = load_data()更改为tr_d, va_d, te_d = digitTest()),但没有成功,请参见下面的内容:

def digitTest():
    tr_d, va_d, te_d = load_data()
    tr_d = list(tr_d)
    va_d = list(va_d)
    te_d = list(te_d)

    newTrD = []
    newTrD.append([])
    newTrD.append([])

    newVaD = []
    newVaD.append([])
    newVaD.append([])

    newTeD = []
    newTeD.append([])
    newTeD.append([])

    for index,label in enumerate(tr_d[1]):
        if tr_d[1][index] < 4:
            newTrD[0].append(tr_d[0][index])
            newTrD[1].append(tr_d[1][index])

    for index,label in enumerate(va_d[1]):
        if va_d[1][index] < 4:
            newVaD[0].append(va_d[0][index])
            newVaD[1].append(va_d[1][index])

    for index,label in enumerate(te_d[1]):
        if te_d[1][index] < 4:
            newTeD[0].append(te_d[0][index])
            newTeD[1].append(te_d[1][index])

    return (newTrD, newVaD, newTeD)

有可能达到我想要的吗?我该怎么做?请注意,从load\ U data函数解析时,数据存储在元组中。你知道吗


Tags: intestfordataindextrainingloadtr
1条回答
网友
1楼 · 发布于 2024-10-03 21:32:00

我从未使用cPickle加载mnist数据集,也不知道它返回什么。 阅读你的代码看起来你做的事情是对的,但是如果你说它不起作用,我想有些事情与cPickle返回数据的内容或方式有关。你知道吗

我没有python 2,因此无法调试您的代码,但是:

我倾向于自己做这些事情:

def loadSet(values_path, labels_path):
    labels = []
    # labels:
    # 0000     32 bit integer  0x00000803(2051) magic number
    # 0008     32 bit integer  28               number of labels
    # 0009     unsigned byte   ??               label
    # 0010     unsigned byte   ??               label
    # ....     unsigned byte   ??               label

    with open(labels_path, 'rb') as f:
        m_number = int.from_bytes(f.read(4,), 'big')
        num_labels = int.from_bytes(f.read(4), 'big')
        for i in range(num_labels):
            labels.append(int.from_bytes(f.read(1), 'big'))

    images = []
    # images:
    # 0000     32 bit integer  0x00000803(2051) magic number
    # 0004     32 bit integer  60000            number of images
    # 0008     32 bit integer  28               number of rows
    # 0012     32 bit integer  28               number of columns
    # 0016     unsigned byte   ??               pixel
    # 0020     unsigned byte   ??               pixel
    # ....     unsigned byte   ??               pixel

    with open(values_path, 'rb') as f:
        m_number = int.from_bytes(f.read(4), 'big')
        num_images = int.from_bytes(f.read(4), 'big')
        num_rows = int.from_bytes(f.read(4), 'big')
        num_cols = int.from_bytes(f.read(4), 'big')
        for i in range(num_images):
            image = []
            for x in range(num_rows * num_cols):
                image.append(int.from_bytes(f.read(1), 'big'))
            images.append(image)

此函数将从文件中加载一组mnist标签和值。然后你就可以把数据解压了。 标签应该是“火车”-idx1.1标签-“乌比特”。只需将路径传递给train label和images或test label和images到函数中,它就会加载这些值。你知道吗

返回值是两个列表的元组:

([number], [pixels])

其中像素是一个列表本身。你知道吗

此外,如果文件不存在或(可能)文件格式不正确,除了抛出异常之外,这不会进行错误检查,因此您可能需要考虑以某种方式进行检查。你知道吗

我也不习惯numpy,我通常在c++和java中工作,但您肯定可以很容易地将这些值转换为numpy数组—只需阅读本主题即可。你知道吗

你现在应该可以很容易地使用数字过滤了。你知道吗

您可能会看到,如果使用原始mnist数据集,则只会得到train和test图像。这里发生的事情是,你从其中一个集合中拿出一部分,用它作为——我不完全确定你在这里的措辞——测试数据来评估训练进度。培训结束后,您可以使用“t10k”文件来验证您的网络培训效果。这里很重要的一点是,如果你从这些t10k图像中分割测试数据,你就不会再使用这些图像了,只剩下剩下一部分,因为这样做的目的是验证网络上还没有看到的数据的训练。你知道吗

相关问题 更多 >