Metal train得到了一个意外的关键字参数“n_epochs”

2024-09-29 23:18:55 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在编写一个Python代码,在这部分代码中,我想使用Metal来训练我的模型,如下所示:

from metal.label_model import LabelModel
gen_model = LabelModel()
%time gen_model.train(L_train[0], n_epochs=500, print_every=100)

但这给了我们:

^{pr2}$

Tags: 代码from模型importmodeltimetrainlabel
1条回答
网友
1楼 · 发布于 2024-09-29 23:18:55

在0.3.0中有一个变化:

'Renames Classifier.train to Classifier.train_model to avoid overwriting the nn.Module.train function'

因此,请尝试使用train_model()而不是train()

from metal.label_model import LabelModel
gen_model = LabelModel()
%time gen_model.train_model(L_train[0], n_epochs=500, print_every=100)

资料来源:

https://github.com/HazyResearch/metal/commit/4210c7c66f3f4a6fc7287192aec133c293ed8198

相关问题 更多 >

    热门问题