如何序列化/反序列化pybrain网络?

2024-09-29 17:09:46 发布

您现在位置:Python中文网/ 问答频道 /正文

PyBrain是一个python库,它提供(除其他外)易于使用的人工神经网络。在

我无法使用pickle或cPickle正确序列化/反序列化PyBrain网络。在

请参见以下示例:

from pybrain.datasets            import SupervisedDataSet
from pybrain.tools.shortcuts     import buildNetwork
from pybrain.supervised.trainers import BackpropTrainer
import cPickle as pickle
import numpy as np 

#generate some data
np.random.seed(93939393)
data = SupervisedDataSet(2, 1)
for x in xrange(10):
    y = x * 3
    z = x + y + 0.2 * np.random.randn()  
    data.addSample((x, y), (z,))

#build a network and train it    

net1 = buildNetwork( data.indim, 2, data.outdim )
trainer1 = BackpropTrainer(net1, dataset=data, verbose=True)
for i in xrange(4):
    trainer1.trainEpochs(1)
    print '\tvalue after %d epochs: %.2f'%(i, net1.activate((1, 4))[0])

这是上述代码的输出:

^{pr2}$

如您所见,网络总误差随着培训的进行而减少。您还可以看到预测值接近预期值12。在

现在我们将进行类似的练习,但将包括序列化/反序列化:

print 'creating net2'
net2 = buildNetwork(data.indim, 2, data.outdim)
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True)
trainer2.trainEpochs(1)
print '\tvalue after %d epochs: %.2f'%(1, net2.activate((1, 4))[0])

#So far, so good. Let's test pickle
pickle.dump(net2, open('testNetwork.dump', 'w'))
net2 = pickle.load(open('testNetwork.dump'))
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True)
print 'loaded net2 using pickle, continue training'
for i in xrange(1, 4):
        trainer2.trainEpochs(1)
        print '\tvalue after %d epochs: %.2f'%(i, net2.activate((1, 4))[0])

这是此块的输出:

creating net2
Total error: 176.339378639
    value after 1 epochs: 5.45
loaded net2 using pickle, continue training
Total error: 123.392181859
    value after 1 epochs: 5.45
Total error: 94.2867637623
    value after 2 epochs: 5.45
Total error: 78.076711114
    value after 3 epochs: 5.45

如您所见,似乎训练对网络有一些影响(报告的总误差值继续减小),但是网络的输出值冻结在与第一次训练迭代相关的值上。在

是否有任何缓存机制需要我注意,它会导致这种错误行为?有更好的方法来序列化/反序列化pybrain网络吗?在

相关版本号:

  • Python 2.6.5(r265:790962010年3月19日,21:48:26)[MSCV.1500 32位(Intel)]
  • 数字1.5.1
  • C盘1.71
  • 大脑0.3

另外,我已经在项目的站点上创建了a bug report,并将保持SO和bug跟踪器的更新


Tags: import网络data序列化valueerrorpickletotal
1条回答
网友
1楼 · 发布于 2024-09-29 17:09:46

原因

导致这种行为的机制是对PyBrain模块中的参数(.params)和导数(.derivs)的处理:事实上,所有的网络参数都存储在一个数组中,但是单个的Module或{}对象可以访问“他们自己的”.params,但这只是整个数组的一个部分的视图。这允许在同一数据结构上同时进行本地和网络范围的写入和读取。在

显然,这个切片视图链接会因为酸洗而丢失。在

解决方案

插入

net2.sorted = False
net2.sortModules()

从文件(它重新创建此共享)加载后,它应该可以工作。在

相关问题 更多 >

    热门问题