怎么了nn.模块保存子模块

2024-10-03 23:26:51 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个问题,关于Pythorchnn.模块作品

import torch
import torch.nn as nn



class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.sub_module = nn.Linear(10, 5)
        self.value = 3

net = Net()
print(net.__dict__)

输出

^{pr2}$

我知道类的每个属性都应该存储在中,为什么值(int值)在其中,但是子模块(ann.模块)相反,子模块不是存储在模块中的

我读了密码nn.模块实施,但我没想到。有人有什么想法吗?在

谢谢你!!在


Tags: 模块importselfnetinitdefasnn
1条回答
网友
1楼 · 发布于 2024-10-03 23:26:51

我会尽量保持简单。在

每次在类Net中创建一个新项,例如:self.sub_module = nn.Linear(10, 5),它会调用其父类的方法__setattr__,在本例中是nn.Module。然后,在__setattr__方法中,将参数存储到它们所属的dict中。在本例中,由于nn.Linear是一个模块,因此它被存储到_modulesdict中

下面是在Modulehttps://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L389中执行此操作的代码段

相关问题 更多 >