Tensorflow 2.0 fit()无法识别批大小

2024-07-02 04:44:39 发布

您现在位置:Python中文网/ 问答频道 /正文

所以我将模型初始化为:

model = tf.keras.utils.multi_gpu_model(model, gpus=NUM_GPUS)当我这样做model.compile()它运行得非常好

但是当我做history = model.fit(tf.cast(X_train, tf.float32), tf.cast(Y_train, tf.float32), validation_split=0.25, batch_size = 16, verbose=1, epochs=100)时,它会给我错误:

OOM when allocating tensor with shape[4760,256,256,3] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:Cast] name: Cast/

这段代码以前运行得非常好,但在TensorFlow2.0中不再如此。我的训练集中有4760个样本。我不知道为什么要整套而不是批量


Tags: 模型modelgputftrainutilsmultinum
1条回答
网友
1楼 · 发布于 2024-07-02 04:44:39

model.compile()只为训练配置模型,没有任何内存分配

你的bug是自我解释的,你直接把一个大的numpy数组输入到模型中。我建议编写一个新的数据生成器或keras.utils.Sequence来输入数据。如果是这样,您不需要再次在fit方法中指定batch_size,因为您自己的生成器或Sequence将生成批处理

相关问题 更多 >