我用Keras来训练神经网络,我已经到了数据集比计算机上安装的RAM大的程度,所以是时候修改我的训练脚本来调用了model.fit_发电机()而不是模型.拟合(),所以我不必一次将所有的培训和验证数据加载到RAM中。在
我已经做了修改,但有一件事让我有点困扰——我在网上看到的所有fit_generator()的示例用法都使用Python的yield
特性来存储生成器的状态。在本质上,我是一个老C++程序员,怀疑我不完全理解的特征,如^ {< CD1>},因此我想显式而不是隐式地维护我的生成器状态,因此我实现了我的生成器:
class DataGenerator:
def __init__(self, inputFileName, maxExamplesPerBatch):
self._inputFileName = inputFileName
self._maxExamplesPerBatch = maxExamplesPerBatch
self._inputsFile = open(inputFileName, "rb")
if (self._inputsFile == None):
self._print("Couldn't open file %s to read input data" % inputFileName)
sys.exit(10)
self._outputsFile = open(inputFileName, "rb") # yes, we're deliberately opening the same file twice (to avoid having to call seek() a lot)
if (self._outputsFile == None):
self._print("Couldn't open file %s to read output data" % inputFileName)
sys.exit(10)
headerInfo = struct.unpack("<4L", self._inputsFile.read(16))
if (headerInfo[0] != 1414676815):
print("Bad magic number in input file [%s], aborting!" % inputFileName)
sys.exit(10)
self._numExamples = headerInfo[1] # Number of input->output rows in our data-file (typically quite large, i.e. millions)
self._numInputs = headerInfo[2] # Number of input values in each row
self._numOutputs = headerInfo[3] # Number of output values in row
self.seekToTopOfData()
def __len__(self):
return (math.ceil(self._numExamples/self._maxExamplesPerBatch))
def __next__(self):
numExamplesToLoad = self._maxExamplesPerBatch
numExamplesLeft = self._numExamples - self._curExampleIdx
if (numExamplesLeft < numExamplesToLoad):
numExamplesToLoad = numExamplesLeft
inputData = np.reshape(np.fromfile(self._inputsFile, dtype='<f4', count=(numExamplesToLoad*self._numInputs)), (numExamplesToLoad, self._numInputs))
outputData = np.reshape(np.fromfile(self._outputsFile, dtype='<f4', count=(numExamplesToLoad*self._numOutputs)), (numExamplesToLoad, self._numOutputs))
self._curExampleIdx += numExamplesToLoad
if (self._curExampleIdx == self._numExamples):
self.seekToTopOfData()
return (inputData, outputData) # <----- NOTE return, not yield!!
def seekToTopOfData(self):
self._curExampleIdx = 0
self._inputsFile.seek(16)
self._outputsFile.seek(16+(self._numExamples*self._numInputs*4))
[...]
trainingDataGenerator = DataGenerator(trainingInputFileName, maxExamplesPerBatch)
validationDataGenerator = DataGenerator(validationInputFileName, maxExamplesPerBatch)
model.fit_generator(generator=trainingDataGenerator, steps_per_epoch=len(trainingDataGenerator), epochs=maxEpochs, callbacks=callbacks_list, validation_data=validationDataGenerator, validation_steps=len(validationDataGenerator))
。。。注意,我的函数以return
而不是yield
结尾,并且我显式地(通过DataGenerator对象中的私有成员变量)而不是隐式地(通过yield
魔术)存储生成器的状态。这似乎很管用。在
我的问题是,这种不寻常的方法是否会引入任何我应该意识到的不明显的行为问题?在
对代码进行一次表面检查就可以了。当您编写一个生成器函数并调用它时,调用将返回一个生成器,该生成器的
__next__
方法通常会被迭代反复调用,直到它引发StopIteration
异常。在生成器是迭代器的特例。类似Iterables的列表有一个生成迭代器的
__iter__
方法。在除非您想将值发送到生成器并将其取出,否则您的
DataGenerator
是实现迭代器的合理方法,但要编写iterable,则需要另一个类,该类的__iter__
方法返回DataGenerator
的实例。在在How to implement __iter__(self) for a container object (Python)处的答案也可能有帮助。在
相关问题 更多 >
编程相关推荐