A simple CRF module written in PyTorch. The implementation is based on https://github.com/allenai/allennlp/blob/master/allennlp/modules/conditional_random_field.py
PyTorch Text CRF
This package contains a simple wrapper for using a Conditional Random Field (CRF). The code is based on the excellent AllenNLP implementation of the CRF.
Installation
pip install pytorch-text-crf
Usage
LSTM CRF implementation
For a complete working example, see https://github.com/iamsimha/pytorch-text-crf/blob/master/examples/pos_tagging/train.ipynb.
from crf.crf import ConditionalRandomField

class LSTMCRF:
    """An example implementation of a CRF model on top of an LSTM."""
    def __init__(self):
        ...
        # Initialize the conditional CRF model
        self.crf = ConditionalRandomField(
            n_class,               # Number of tags
            label_encoding="BIO",  # Label encoding format
            idx2tag=idx2tag        # Dict mapping index to a tag
        )

    def forward(self, inputs, tags):
        logits = self.lstm(inputs)  # logits dim: (batch_size, seq_length, num_tags)
        mask = inputs != "<pad token>"  # mask for ignoring pad tokens. mask dim: (batch_size, seq_length)
        log_likelihood = self.crf(logits, tags, mask)
        # Log likelihood is not normalized (it is not divided by the batch size).
        loss = -log_likelihood

        # To obtain the best sequence, use Viterbi decoding
        best_tag_sequence = self.crf.best_viterbi_tag(logits, mask)

        # To obtain output similar to the LSTM prediction, we can use the code below
        class_probabilities = logits * 0.0
        for i, instance_tags in enumerate(best_tag_sequence):
            for j, tag_id in enumerate(instance_tags[0][0]):
                class_probabilities[i, j, int(tag_id)] = 1
        return {"loss": loss, "class_probabilities": class_probabilities}

# Training
lstm_crf = LSTMCRF()
output = lstm_crf(sentences, tags)
loss = output["loss"]
loss.backward()
optimizer.step()
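In the example above, the mask is built by comparing the inputs against the pad token so that padded positions do not contribute to the CRF score. As a minimal, framework-free sketch of the same idea (the `PAD_ID` value here is hypothetical; in practice you would compare against your vocabulary's actual pad index):

PAD_ID = 0  # hypothetical pad token id

def build_mask(batch):
    """Return 1 for real tokens and 0 for padding, per position."""
    return [[1 if tok != PAD_ID else 0 for tok in seq] for seq in batch]

With tensors, the equivalent is a single elementwise comparison, as shown in the `forward` method above.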