Challenges replacing SelfAttention with ImageLinearAttention in a Vision Transformer

Posted 2024-10-01 00:19:57


When I replace SelfAttention with ImageLinearAttention in my Vision Transformer, using the code below, I get a runtime error. The ImageLinearAttention code comes from https://github.com/lucidrains/linear-attention-transformer/blob/master/linear_attention_transformer/images.py, except that I removed the channel dimension, as you can see in the commented-out code.

class ImageLinearAttention(nn.Module):
    def __init__(self, chan, chan_out = None, kernel_size = 1, padding = 0, stride = 1, key_dim = 64, value_dim = 64, heads = 8, norm_queries = True):
        super().__init__()
        self.chan = chan
        chan_out = chan if chan_out is None else chan_out

        self.key_dim = key_dim
        self.value_dim = value_dim
        self.heads = heads

        self.norm_queries = norm_queries

        conv_kwargs = {'padding': padding, 'stride': stride}
        self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
        self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
        self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, **conv_kwargs)
        print('value dim: ', value_dim)
        print('chan out: ', chan_out)
        print('kernel_size: ', kernel_size)
        out_conv_kwargs = {'padding': padding}
        print('out_conv_kwargs: ', out_conv_kwargs)
        print('in_chan: ', value_dim * heads)
        self.to_out = nn.Conv2d(value_dim * heads, chan_out, kernel_size, **out_conv_kwargs)

    def forward(self, x, context = None):
        print('x.shape: ', x.shape)
        print('*x.shape is: ', *x.shape)
        print('heads: ', self.heads)
        #b, c, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads
        b, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads
        q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))
        q, k, v = map(lambda t: t.reshape(b, heads, -1, h * w), (q, k, v))
        q, k = map(lambda x: x * (self.key_dim ** -0.25), (q, k))
        
        if context is not None:
            #context = context.reshape(b, c, 1, -1)
            context = context.reshape(b, 1, -1)
            ck, cv = self.to_k(context), self.to_v(context)
            ck, cv = map(lambda t: t.reshape(b, heads, k_dim, -1), (ck, cv))
            k = torch.cat((k, ck), dim=3)
            v = torch.cat((v, cv), dim=3)

        k = k.softmax(dim=-1)

        if self.norm_queries:
            q = q.softmax(dim=-2)

        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhdn,bhde->bhen', q, context)
        out = out.reshape(b, -1, h, w)
        out = self.to_out(out)
        return out
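
For reference, with the original unpacking line restored (the commented-out line b, c, h, w, k_dim, heads = *x.shape, self.key_dim, self.heads), this module does run on image-shaped 4-dimensional input. A minimal sketch with illustrative shapes:

import torch

# Sanity check (illustrative shapes only); assumes the original 4D unpacking
# line above is restored, not my modified 3-value version.
attn = ImageLinearAttention(chan=512, heads=12, key_dim=32)
img = torch.randn(1, 512, 8, 8)   # (batch, channels, height, width)
out = attn(img)                   # -> torch.Size([1, 512, 8, 8])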

The error is: RuntimeError: Expected 4-dimensional input for 4-dimensional weight [384, 512, 1, 1], but got 3-dimensional input of size [1, 1984, 512] instead

Also, the data I feed into the transformer has size torch.Size([1983, 512]), and my batch size is 1. The full log is:

$ bash scripts/train.sh 
train: True test: False cam: False
preparing datasets and dataloaders......
total_train_num:  176
creating models......
n_class:  2
in_dim:  512
value dim:  64
chan out:  512
kernel_size:  1
out_conv_kwargs:  {'padding': 0}
in_chan:  768
in_dim:  512
value dim:  64
chan out:  512
kernel_size:  1
out_conv_kwargs:  {'padding': 0}
in_chan:  768

=>Epoches 1, learning rate = 0.0010000, previous best = 0.0000
torch.Size([1983, 512])
features size:  torch.Size([1983, 512])
/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:129: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:154: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.
  warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
max_feature_num:  1983
batch feature size:  torch.Size([1, 1983, 512])
x.shape:  torch.Size([1, 1984, 512])
*x.shape is:  1 1984 512
heads:  12
Traceback (most recent call last):
  File "main.py", line 148, in <module>
    preds,labels,loss = trainer.train(sample_batched, model)
  File "/SeaExp/mona/research/code/cc/helper.py", line 71, in train
    pred,labels,loss = model.forward(feats, labels, masks)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 166, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/SeaExp/mona/research/code/cc/models/Transformer.py", line 31, in forward
    out = self.transformer(X)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/SeaExp/mona/research/code/cc/models/linear_att_ViT.py", line 262, in forward
    feat = self.transformer(emb)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/SeaExp/mona/research/code/cc/models/linear_att_ViT.py", line 206, in forward
    out = layer(out)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/SeaExp/mona/research/code/cc/models/linear_att_ViT.py", line 174, in forward
    out = self.attn(out)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/SeaExp/mona/research/code/cc/models/linear_att_ViT.py", line 92, in forward
    q, k, v = (self.to_q(x), self.to_k(x), self.to_v(x))
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 443, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 439, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [384, 512, 1, 1], but got 3-dimensional input of size [1, 1984, 512] instead

The original SelfAttention code is:

class SelfAttention(nn.Module):
    def __init__(self, in_dim, heads=8, dropout_rate=0.1):
        super(SelfAttention, self).__init__()
        self.heads = heads
        self.head_dim = in_dim // heads
        self.scale = self.head_dim ** 0.5
        
        self.query = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.key = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.value = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.out = LinearGeneral((self.heads, self.head_dim), (in_dim,))

        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None

    def forward(self, x):
        b, n, _ = x.shape

        q = self.query(x, dims=([2], [0]))
        k = self.key(x, dims=([2], [0]))
        v = self.value(x, dims=([2], [0]))

        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        out = torch.matmul(attn_weights, v)
        out = out.permute(0, 2, 1, 3)

        out = self.out(out, dims=([2, 3], [0, 1]))

        return out
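
For comparison, SelfAttention consumes the 3-dimensional (batch, tokens, dim) tensor directly. A minimal sketch, assuming LinearGeneral from the linked asyml repository is importable:

import torch

# Illustrative shapes matching my data: 1984 tokens of dimension 512.
sa = SelfAttention(in_dim=512, heads=8)
tokens = torch.randn(1, 1984, 512)   # (batch, tokens, dim)
out = sa(tokens)                     # -> torch.Size([1, 1984, 512])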

How can I fix this error? I call ImageLinearAttention in the Vision Transformer's encoder block as follows:

class EncoderBlock(nn.Module):
    def __init__(self, in_dim, mlp_dim, num_heads, dropout_rate=0.1, attn_dropout_rate=0.1):
        super(EncoderBlock, self).__init__()

        self.norm1 = nn.LayerNorm(in_dim)
        #self.attn = SelfAttention(in_dim, heads=num_heads, dropout_rate=attn_dropout_rate)
        ## note Mona: not sure if I am correctly passing the params
        # what about attn_dropout_rate=0.1
        ## I don't know 
        print('in_dim: ', in_dim) 
        self.attn = ImageLinearAttention(chan=in_dim, heads=num_heads, key_dim=32)
        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None
        self.norm2 = nn.LayerNorm(in_dim)
        self.mlp = MlpBlock(in_dim, mlp_dim, in_dim, dropout_rate)

    def forward(self, x):
        residual = x
        out = self.norm1(x)
        out = self.attn(out)
        if self.dropout:
            out = self.dropout(out)
        out += residual
        residual = out

        out = self.norm2(out)
        out = self.mlp(out)
        out += residual
        return out

The SelfAttention code, and the way it is used in the encoder, are mostly from https://github.com/asyml/vision-transformer-pytorch/blob/main/src/model.py


1 Answer

#1 · Posted 2024-10-01 00:19:57

It looks like ImageLinearAttention operates on 4-dimensional input shaped like an image (batch, channels, height, width), whereas SelfAttention operates on 3-dimensional input as in NLP tasks (batch, sequence length, dim). The input probably needs to be reshaped before it is passed to the attention.
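
One possibility, as a rough sketch only (not tested against your full model): keep ImageLinearAttention but hand it a 4-dimensional tensor by treating the token sequence as an image of width 1, and restore the original 4D unpacking line in its forward. A hypothetical helper around the attention call in EncoderBlock.forward could look like this:

import torch

def attend_tokens(attn, x):
    # x: (batch, tokens, dim), e.g. (1, 1984, 512)
    x = x.transpose(1, 2).unsqueeze(-1)   # -> (batch, dim, tokens, 1), a width-1 "image"
    x = attn(x)                           # ImageLinearAttention with its original 4D unpacking
    return x.squeeze(-1).transpose(1, 2)  # back to (batch, tokens, dim)

Whether a width-1 spatial layout is meaningful for your features is a separate modeling question; the linked repository also provides sequence-based linear attention modules that take (batch, tokens, dim) directly, which may be a better fit.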
