<p>我用了一种非常简单的方式来实现(只使用了<code>fill_()</code>),<br/>
代码如下:</p>
<pre><code>import torch
import torch.nn as nn
class LeNet300(nn.Module):
    """Fully connected network with layer widths 28-300-100-10.

    Every ``nn.Linear`` weight is Xavier-normal initialized and every bias
    is set to zero when the module is constructed.
    """

    def __init__(self):
        super().__init__()
        # Define layers
        self.fc1 = nn.Linear(in_features=28, out_features=300)
        self.fc2 = nn.Linear(in_features=300, out_features=100)
        self.output = nn.Linear(in_features=100, out_features=10)
        self.weights_initialization()

    def forward(self, x):
        """Return the raw output-layer values for ``x``.

        ``x`` needs a trailing dimension of 28 to match ``fc1``. No
        activation is applied after the final layer.
        """
        # ``nn.functional`` is referenced through the existing ``nn`` import;
        # the conventional ``F`` alias is never bound in this file, so
        # ``F.relu`` would raise NameError.
        out = nn.functional.relu(self.fc1(x))
        out = nn.functional.relu(self.fc2(out))
        return self.output(out)

    def weights_initialization(self):
        """Initialize all linear layers in place.

        Layers assigned as attributes in ``__init__`` are registered as
        sub-modules, so ``self.modules()`` walks the whole network (including
        this module itself, which the ``isinstance`` check skips).
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
# Build a second network with the same layout and overwrite every entry of
# its state dict (weights and biases) with 1.
mask_model = LeNet300()
with torch.no_grad():
    # Fetch the state dict once instead of rebuilding it on every iteration.
    # Its tensors share storage with the model's parameters, so the in-place
    # fill_() below changes the model itself (see the output further down).
    for name, tensor in mask_model.state_dict().items():
        print(name)
        tensor.fill_(1)
mask_model.state_dict()['fc1.weight']
# tensor([[1., 1., 1., ..., 1., 1., 1.],
# [1., 1., 1., ..., 1., 1., 1.],
# [1., 1., 1., ..., 1., 1., 1.],
# ...,
# [1., 1., 1., ..., 1., 1., 1.],
# [1., 1., 1., ..., 1., 1., 1.],
# [1., 1., 1., ..., 1., 1., 1.]])
</code></pre>