在不调用yield()的情况下为Keras的fit_generator方法()实现“generator”是否合理?

2024-10-03 23:21:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我用Keras来训练神经网络,我已经到了数据集比计算机上安装的RAM大的程度,所以是时候修改我的训练脚本来调用了model.fit_发电机()而不是模型.拟合(),所以我不必一次将所有的培训和验证数据加载到RAM中。在

我已经做了修改,但有一件事让我有点困扰——我在网上看到的所有fit_generator()的示例用法都使用Python的yield特性来存储生成器的状态。在本质上,我是一个老C++程序员,怀疑我不完全理解的特征,如^ {< CD1>},因此我想显式而不是隐式地维护我的生成器状态,因此我实现了我的生成器:

class DataGenerator:
   def __init__(self, inputFileName, maxExamplesPerBatch):
      self._inputFileName       = inputFileName
      self._maxExamplesPerBatch = maxExamplesPerBatch

      self._inputsFile = open(inputFileName, "rb")
      if (self._inputsFile == None):
         self._print("Couldn't open file %s to read input data" % inputFileName)
         sys.exit(10)

      self._outputsFile = open(inputFileName, "rb")   # yes, we're deliberately opening the same file twice (to avoid having to call seek() a lot)
      if (self._outputsFile == None):
         self._print("Couldn't open file %s to read output data" % inputFileName)
         sys.exit(10)

      headerInfo = struct.unpack("<4L", self._inputsFile.read(16))          
      if (headerInfo[0] != 1414676815):
         print("Bad magic number in input file [%s], aborting!" % inputFileName)
         sys.exit(10)

      self._numExamples   = headerInfo[1]  # Number of input->output rows in our data-file (typically quite large, i.e. millions)
      self._numInputs     = headerInfo[2]  # Number of input values in each row
      self._numOutputs    = headerInfo[3]  # Number of output values in row
      self.seekToTopOfData()

   def __len__(self):
      return (math.ceil(self._numExamples/self._maxExamplesPerBatch))

   def __next__(self):
      numExamplesToLoad = self._maxExamplesPerBatch
      numExamplesLeft   = self._numExamples - self._curExampleIdx
      if (numExamplesLeft < numExamplesToLoad):
         numExamplesToLoad = numExamplesLeft
      inputData  = np.reshape(np.fromfile(self._inputsFile,  dtype='<f4', count=(numExamplesToLoad*self._numInputs)),  (numExamplesToLoad, self._numInputs))
      outputData = np.reshape(np.fromfile(self._outputsFile, dtype='<f4', count=(numExamplesToLoad*self._numOutputs)), (numExamplesToLoad, self._numOutputs))
      self._curExampleIdx += numExamplesToLoad
      if (self._curExampleIdx == self._numExamples):
         self.seekToTopOfData()
      return (inputData, outputData)   # <----- NOTE return, not yield!!

   def seekToTopOfData(self):
      self._curExampleIdx = 0
      self._inputsFile.seek(16)
      self._outputsFile.seek(16+(self._numExamples*self._numInputs*4))

[...]

trainingDataGenerator   = DataGenerator(trainingInputFileName, maxExamplesPerBatch)
validationDataGenerator = DataGenerator(validationInputFileName, maxExamplesPerBatch)

model.fit_generator(generator=trainingDataGenerator, steps_per_epoch=len(trainingDataGenerator), epochs=maxEpochs, callbacks=callbacks_list, validation_data=validationDataGenerator, validation_steps=len(validationDataGenerator))

。。。注意,我的函数以return而不是yield结尾,并且我显式地(通过DataGenerator对象中的私有成员变量)而不是隐式地(通过yield魔术)存储生成器的状态。这似乎很管用。在

我的问题是,这种不寻常的方法是否会引入任何我应该意识到的不明显的行为问题?在


Tags: toselfinputifdefopenfileyield
1条回答
网友
1楼 · 发布于 2024-10-03 23:21:41

对代码进行一次表面检查就可以了。当您编写一个生成器函数并调用它时,调用将返回一个生成器,该生成器的__next__方法通常会被迭代反复调用,直到它引发StopIteration异常。在

生成器是迭代器的特例。类似Iterables的列表有一个生成迭代器的__iter__方法。在

除非您想将值发送到生成器并将其取出,否则您的DataGenerator是实现迭代器的合理方法,但要编写iterable,则需要另一个类,该类的__iter__方法返回DataGenerator的实例。在

How to implement __iter__(self) for a container object (Python)处的答案也可能有帮助。在

相关问题 更多 >