由于内存问题,我正在处理语言建模问题并使用predict_生成器函数。我面临的问题是predict_生成器给出的预测比输入的大小要多。在
我在predict_generator函数中提供的参数:
predictions = model.predict_generator(testDataGenerator(statements),
use_multiprocessing=True,workers=4,
steps=25,
verbose=1)
发电机功能:
^{pr2}$我总共有1568个输入,我将一批发送64个,但我得到了1600个预测。错误输出为:
25/25 [==============================] - 47s 2s/step
IndexError: Length of values does not match length of index
我认为我在这里问题的生成器函数中发送语句的方式。在
如果使用自定义生成器,则必须对预测器上的最后一步保持谨慎。在
由于您正在执行25个步骤,批处理大小为64,生成器希望您的数据正好是1600,我认为在生成器中使用一个简单的if来更改端点应该可以解决您的问题。在
相关问题 更多 >
编程相关推荐