从头开始训练huggingface的GPT2:断言n_状态%config.n_头==0错误

2024-10-03 23:20:26 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试将GPT2架构用于音乐应用程序,因此需要从头开始训练它。通过谷歌搜索,我发现huggingface的github发布的1714已经“解决”了这个问题。当我尝试运行建议解决方案时:

from transformers import GPT2Config, GPT2Model

NUMLAYER = 4
NUMHEAD = 4
SIZEREDUCTION = 10 #the factor by which we reduce the size of the velocity argument.
VELSIZE = int(np.floor(127/SIZEREDUCTION)) + 1 
SEQLEN=40 #size of data sequences.
EMBEDSIZE = 5 

config = GPT2Config(vocab_size = VELSIZE, n_positions = SEQLEN, n_embd = EMBEDSIZE, n_layer = NUMLAYER, n_ctx = SEQLEN, n_head = NUMHEAD)  
model = GPT2Model(config)

我得到以下错误:

Traceback (most recent call last):

  File "<ipython-input-7-b043a7a2425f>", line 1, in <module>
    runfile('C:/Users/cnelias/Desktop/PHD/Swing project/code/script/GPT2.py', wdir='C:/Users/cnelias/Desktop/PHD/Swing project/code/script')

  File "C:\Users\cnelias\Anaconda3\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 786, in runfile
    execfile(filename, namespace)

  File "C:\Users\cnelias\Anaconda3\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 110, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)

  File "C:/Users/cnelias/Desktop/PHD/Swing project/code/script/GPT2.py", line 191, in <module>
    model = GPT2Model(config)

  File "C:\Users\cnelias\Anaconda3\lib\site-packages\transformers\modeling_gpt2.py", line 355, in __init__
    self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

  File "C:\Users\cnelias\Anaconda3\lib\site-packages\transformers\modeling_gpt2.py", line 355, in <listcomp>
    self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

  File "C:\Users\cnelias\Anaconda3\lib\site-packages\transformers\modeling_gpt2.py", line 223, in __init__
    self.attn = Attention(nx, n_ctx, config, scale)

  File "C:\Users\cnelias\Anaconda3\lib\site-packages\transformers\modeling_gpt2.py", line 109, in __init__
    assert n_state % config.n_head == 0

这意味着什么?我如何解决它

更一般地说,是否有关于如何使用GPT2进行转发呼叫的文档?我可以定义自己的train()函数吗?还是必须使用模型的内置函数?我是被迫使用Dataset来进行训练,还是可以给它单独的张量? 我找了一下,但在医生上找不到答案,但也许我遗漏了什么

附:我已经读了huggingface.co的博客,但是它遗漏了太多的信息和细节,对我的申请没有用处


Tags: inpyconfiglibpackageslinesiteusers
1条回答
网友
1楼 · 发布于 2024-10-03 23:20:26

我认为错误信息非常清楚:

assert n_state % config.n_head == 0

追溯到the code,我们可以看到

n_state = nx # in Attention: n_state=768

这表明n_state表示嵌入维度(在类似BERT的模型中,默认情况下通常为768)。当我们再看一下GPT-2 documentation时,似乎指定它的参数是n_embd,您将其设置为5。如错误所示,嵌入维度必须通过指定为4的注意头数均匀划分。因此,选择不同的嵌入维度作为4的倍数应该可以解决这个问题。当然,您也可以首先更改头的数量,但似乎不支持奇数嵌入维度

相关问题 更多 >