加载Keras的自定义数据集

2024-10-02 14:20:19 发布

您现在位置:Python中文网/ 问答频道 /正文

n示例:

来自[neon example](http://neon.nervanasys.com/index.html/mnist.html):(类似于keras)

from neon.data import MNIST

 mnist = MNIST()

(X_train, y_train), (X_test, y_test), nclass = mnist.load_data()

我想为UCF_CC_50数据集获取相同的元组集。在

这是一个由50张不同图像组成的数据集,是拥挤地区的鸟瞰图。 我正在修改the segment behind this。在

所有图像都会下载并包含在“images”文件夹中。在

这是初始化

^{pr2}$

这就是我目前所拥有的。我不知道如何修改init。在

class UCF(dataset):
**def __init__(self, path='.', subset_pct=100, normalize=True):
    super(UCF, self).__init__('Images',
                                '//url',
                                15296311,
                                path=path,
                                subset_pct=subset_pct)**
    self.normalize = normalize

def load_data(self):
    filepath = self._valid_path_append(self.path, self.filename)

    with open(filepath, 'rb') as ucf:
        (X_train, y_train), (X_test, y_test) = pickle_load(ucf)
        X_train = X_train.reshape(-1, 784)
        X_test = X_test.reshape(-1, 784)

        if self.normalize:
            X_train = X_train / 255.
            X_test = X_test / 255.

    return (X_train, y_train), (X_test, y_test), 10

def gen_iterators(self):
    (X_train, y_train), (X_test, y_test), nclass = self.load_data()
    train = ArrayIterator(X_train,
                          y_train,
                          nclass=nclass,
                          lshape=(1, 28, 28),
                          name='train')
    val = ArrayIterator(X_test,
                        y_test,
                        nclass=nclass,
                        lshape=(1, 28, 28),
                        name='valid')
    self._data_dict = {'train': train,
                       'valid': val}
    return self._data_dict

有人能帮我吗?在


Tags: pathtestselfdatainitdefloadtrain