快速.ai:如何在验证期间获得每批损失

2024-06-17 14:18:18 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用由fast.ai实现的AWD-LSTM模型。现在我可以得到所有批次的验证损失平均值:

from fastai.text import *  
data_lm = (TextList.from_csv("data/penn", "concatenated.csv", cols='text')
    .split_from_df("is_valid")
    .label_for_lm()
    .databunch())  
learner = language_model_learner(data_lm, AWD_LSTM, pretrained=False)  
learner.fit_one_cycle(10, 1e-2)  
learner.export("exported.pkl")

itemlist = TextList.from_csv("data/penn", "concatenated.csv", cols='text')  
newlearner = load_learner(path="data/penn", test=itemlist, file="exported.pkl")  
loss, acc = newlearner.validate(newlearner.data.test_dl)

但是如何获得每批的验证损失?你知道吗

我尝试过的事情包括:
1尝试附加Recorder。但似乎Recorder不监视验证,learner.losses只存储每批的列损失。
2使用fastai.basic_train.loss_batch(learner.model, xb, yb, learner.loss_func),其中xbyb只是torch.Tensors。但是这种方法给出了以下AttributeError

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-14-5fa44a2d640f> in <module>
      5 xb = torch.ones((64, 20)).cuda().long()
      6 yb = torch.ones((64, 20)).cuda().long()
----> 7 loss_batch(newlearner.model, xb, yb, newlearner.loss_func)

~/anaconda3/envs/pytorch12/lib/python3.7/site-packages/fastai/basic_train.py in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
     27     out = cb_handler.on_loss_begin(out)
     28     if not loss_func: return to_detach(out), to_detach(yb[0])
---> 29     loss = loss_func(out, *yb)
     30 
     31     if opt is not None:

~/anaconda3/envs/pytorch12/lib/python3.7/site-packages/fastai/layers.py in __call__(self, input, target, **kwargs)
    237 
    238     def __call__(self, input:Tensor, target:Tensor, **kwargs)->Rank0Tensor:
--> 239         input = input.transpose(self.axis,-1).contiguous()
    240         target = target.transpose(self.axis,-1).contiguous()
    241         if self.floatify: target = target.float()

AttributeError: 'tuple' object has no attribute 'transpose'

Tags: csvfromselftargetinputdatamodelout
1条回答
网友
1楼 · 发布于 2024-06-17 14:18:18

我现在有办法了。你知道吗

cb_handler = CallbackHandler(newlearner.callbacks + [], None)
losses, acc = fastai.basic_train.validate(
    newlearner.model, 
    newlearner.data.test_dl, 
    newlearner.loss_func, 
    cb_handler,  # This is necessary
    average=False)

相关问题 更多 >