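I want to train a DQN to play tic-tac-toe. I trained it to play X (while O's moves are random). After 12 hours of training it plays well, but not flawlessly. Now I want to train two networks at the same time: one for X moves and one for O moves. But when I call model.predict(state) on the second network, I get an error like this: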
ValueError: Cannot feed value of shape (9,) for Tensor 'InputData/X:0', which has shape '(?, 9)'
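But I know for sure that the network definitions and the data dimensions are the same. The problem lies in defining two DNNs.
A generic example (2_dnn_test.py) that defines two DNNs and calls fit and predict on the second one, m2, reproduces the problem. The error is as follows: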
Traceback (most recent call last):
  File "2_dnn_test.py", line 25, in <module>
    m2.fit(X, Y, n_epoch = 20)
  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/models/dnn.py", line 157, in fit
    self.targets)
  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/utils.py", line 267, in feed_dict_builder
    feed_dict[net_inputs[i]] = x
IndexError: list index out of range
This error is different because in my tic-tac-toe code I call predict on the second DNN before the first fit() runs. If I comment out m2.fit(X, Y, n_epoch = 20) in the example, I get the same error as in my game:
Traceback (most recent call last):
  File "2_dnn_test.py", line 27, in <module>
    print(m2.predict([[0.9,0.1]]))
  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/models/dnn.py", line 204, in predict
    return self.predictor.predict(feed_dict)
  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/helpers/evaluator.py", line 69, in predict
    o_pred = self.session.run(output, feed_dict=feed_dict).tolist()
  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 372, in run
    run_metadata_ptr)
  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 625, in _run
    % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (2,) for Tensor 'InputData/X:0', which has shape '(?, 2)'
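So two identical networks cannot work at the same time. How can I make both of them work?
BTW, the example does not produce the expected prediction results :)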
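It looks like I need to add a statement that prevents TFLearn from attaching both models to the default graph. With that addition, everything works correctly.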