擅长:python、mysql、java
<p>您可以使用<strong>nn.ModuleList</strong>执行类似操作:</p>
<pre><code>import torch
import torch.nn as nn
import torch.nn.functional as F
class fclist(nn.Module):
    """Bank of ``k`` independent fully connected output heads.

    Every head is its own ``nn.Linear`` layer and is applied to the same
    input, so the forward pass yields one output tensor per head (e.g. one
    per cluster).

    Args:
        k: Number of heads to create (no. of clusters).
        in_features: Size of the last dimension of the tensor fed to each
            head. Defaults to ``1 * 32 * 32`` (= 1024).
        out_features: Size of each head's output. Defaults to 2.
    """

    def __init__(self, k: int, in_features: int = 1 * 32 * 32,
                 out_features: int = 2):
        super().__init__()
        self.k = k
        # ... other previous layers would be defined here ...

        # nn.ModuleList (rather than a plain Python list) registers every
        # head as a submodule, so their weights appear in .parameters(),
        # follow .to(device) and are included in state_dict().
        self.out_layers = nn.ModuleList(
            nn.Linear(in_features, out_features) for _ in range(k)
        )

    def forward(self, x):
        """Apply every head to ``x``.

        Args:
            x: Input tensor whose last dimension is ``in_features``.

        Returns:
            A list of ``k`` tensors, one per head, in head order. Each keeps
            the leading dimensions of ``x`` and has ``out_features`` as its
            last dimension. An empty list is returned when ``k == 0``.
        """
        # ... pass x through the previous layers here ...
        return [layer(x) for layer in self.out_layers]
</code></pre>
<p>示例输出：</p>
<pre><code>>>> net = fclist(k=3)
>>> inp = torch.randn(1, 1*32*32)
>>> net(inp)
[tensor([[-0.7319, -0.2686]], grad_fn=&lt;AddmmBackward&gt;), tensor([[-0.6248, 0.9180]], grad_fn=&lt;AddmmBackward&gt;), tensor([[0.2532, 0.1387]], grad_fn=&lt;AddmmBackward&gt;)]
</code></pre>