How to use BERT and ELMo embeddings with sklearn

Posted 2024-09-30 22:18:28


I built a text classifier in sklearn using Tf-Idf features, and I would like to replace Tf-Idf with BERT and ELMo embeddings.

How would you do this?

I am using the code below to get BERT embeddings:

from flair.data import Sentence
from flair.embeddings import TransformerWordEmbeddings

# init embedding
embedding = TransformerWordEmbeddings('bert-base-uncased')

# create a sentence
sentence = Sentence('The grass is green .')

# embed words in sentence
embedding.embed(sentence)
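`TransformerWordEmbeddings` attaches one vector per token, but a classic sklearn estimator needs one fixed-length vector per document. One common option (an assumption here, not the only choice) is to mean-pool the token vectors; a minimal numpy sketch of just that pooling step:

```python
import numpy as np

def mean_pool(token_vectors):
    """Collapse a (num_tokens, dim) array of per-token embeddings
    into a single (dim,) document vector by averaging."""
    token_vectors = np.asarray(token_vectors)
    return token_vectors.mean(axis=0)

# Stand-in for per-token BERT vectors (3 tokens, 4 dims)
tokens = [[1.0, 2.0, 0.0, 0.0],
          [3.0, 0.0, 0.0, 0.0],
          [2.0, 4.0, 0.0, 0.0]]
doc_vector = mean_pool(tokens)   # -> array([2., 2., 0., 0.])
```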
And this is my current Tf-Idf based classifier:

import pandas as pd
import numpy as np

from sklearn.compose import ColumnTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MinMaxScaler
from sklearn.linear_model import LogisticRegression

column_trans = ColumnTransformer([
    ('tfidf', TfidfVectorizer(), 'text'),
    ('number_scaler', MinMaxScaler(), ['number'])
])

# Initialize data
data = [
    ['This process, however, afforded me no means of.', 20, 1],
    ['another long description', 21, 1],
    ['It never once occurred to me that the fumbling', 19, 0],
    ['How lovely is spring As we looked from Windsor', 18, 0]
]

# Create DataFrame
df = pd.DataFrame(data, columns=['text', 'number', 'target'])

X = column_trans.fit_transform(df)
X = X.toarray()
y = df.loc[:, "target"].values

# Perform classification

classifier = LogisticRegression(random_state=0)
classifier.fit(X, y)
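Whichever featurizer ends up in the `ColumnTransformer`, wrapping it together with the classifier in a `Pipeline` keeps fit/transform in one object and makes cross-validation a single call. A sketch on the same toy data (using the Tf-Idf version so it runs without downloading a model):

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MinMaxScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import cross_val_score

df = pd.DataFrame([
    ['This process, however, afforded me no means of.', 20, 1],
    ['another long description', 21, 1],
    ['It never once occurred to me that the fumbling', 19, 0],
    ['How lovely is spring As we looked from Windsor', 18, 0],
], columns=['text', 'number', 'target'])

pipe = Pipeline([
    ('features', ColumnTransformer([
        ('tfidf', TfidfVectorizer(), 'text'),
        ('number_scaler', MinMaxScaler(), ['number']),
    ])),
    ('clf', LogisticRegression(random_state=0)),
])

# 2-fold CV on 4 samples; scores is an array of 2 accuracies
scores = cross_val_score(pipe, df[['text', 'number']], df['target'], cv=2)
```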

1 answer

Sklearn offers the possibility to write custom data transformers (nothing to do with the "transformer" family of machine-learning models).

I implemented a custom sklearn data transformer that uses the flair library you are already using. Note that I used TransformerDocumentEmbeddings instead of TransformerWordEmbeddings. There is also a second one that works with the transformers library directly.

I added a SO question that discusses which transformer layer is of interest here.

I am not familiar with ELMo, although I found this using tensorflow. You should be able to adapt the code I shared to make ELMo work.

import torch
import numpy as np
from flair.data import Sentence
from flair.embeddings import TransformerDocumentEmbeddings
from sklearn.base import BaseEstimator, TransformerMixin


class FlairTransformerEmbedding(TransformerMixin, BaseEstimator):

    def __init__(self, model_name='bert-base-uncased', batch_size=None, layers=None):
        # From https://lvngd.com/blog/spacy-word-vectors-as-features-in-scikit-learn/
        # For pickling reason you should not load models in __init__
        self.model_name = model_name
        self.model_kw_args = {'batch_size': batch_size, 'layers': layers}
        self.model_kw_args = {k: v for k, v in self.model_kw_args.items()
                              if v is not None}
    
    def fit(self, X, y=None):
        return self
    
    def transform(self, X):
        model = TransformerDocumentEmbeddings(
                self.model_name, fine_tune=False,
                **self.model_kw_args)

        sentences = [Sentence(text) for text in X]
        model.embed(sentences)  # flair attaches the embeddings to the Sentence objects
        embedded = [s.get_embedding().reshape(1, -1) for s in sentences]
        return torch.cat(embedded).detach().cpu().numpy()
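Loading the model inside `transform` rather than `__init__` is what keeps the estimator picklable and clone-friendly. The fit/transform contract itself can be seen with a trivial, dependency-free transformer of the same shape (a toy example, not part of the answer's code):

```python
import pickle
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin

class TextLengthFeatures(TransformerMixin, BaseEstimator):
    """Toy transformer following the same pattern: only plain
    attributes in __init__, all real work in transform."""
    def __init__(self, scale=1.0):
        self.scale = scale

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        # One row per document, one column: scaled character count
        return np.array([[len(text) * self.scale] for text in X])

est = TextLengthFeatures()
features = est.transform(['abc', 'hello'])   # -> [[3.], [5.]]
restored = pickle.loads(pickle.dumps(est))   # pickles cleanly
```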

import torch
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from transformers import AutoTokenizer, AutoModel
from more_itertools import chunked

class TransformerEmbedding(TransformerMixin, BaseEstimator):

    def __init__(self, model_name='bert-base-uncased', batch_size=1, layer=-1):
        # From https://lvngd.com/blog/spacy-word-vectors-as-features-in-scikit-learn/
        # For pickling reason you should not load models in __init__
        self.model_name = model_name
        self.layer = layer
        self.batch_size = batch_size
    
    def fit(self, X, y=None):
        return self
    
    def transform(self, X):
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        model = AutoModel.from_pretrained(self.model_name,
                                          output_hidden_states=True)
        model.eval()

        res = []
        for batch in chunked(X, self.batch_size):
            encoded_input = tokenizer.batch_encode_plus(
                batch, return_tensors='pt', padding=True, truncation=True)
            with torch.no_grad():
                output = model(**encoded_input)
            # hidden_states[self.layer] selects the transformer layer;
            # [:, 0] takes the [CLS] token as the document embedding
            embed = output.hidden_states[self.layer][:, 0].numpy()
            res.append(embed)

        return np.concatenate(res)
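Taking a single token's hidden state is only one pooling choice; attention-mask-aware mean pooling over all tokens is a common alternative. A sketch of just the pooling math with torch (hidden states simulated with a hand-built tensor, so no model download is needed):

```python
import torch

def masked_mean_pool(last_hidden_state, attention_mask):
    """Average token embeddings, ignoring padding positions.
    last_hidden_state: (batch, seq_len, dim); attention_mask: (batch, seq_len)."""
    mask = attention_mask.unsqueeze(-1).float()      # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(dim=1)   # (batch, dim)
    counts = mask.sum(dim=1).clamp(min=1e-9)         # (batch, 1)
    return summed / counts

# Simulated model output: batch of 1, 3 tokens (last one is padding), dim 2
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [99.0, 99.0]]])
mask = torch.tensor([[1, 1, 0]])
pooled = masked_mean_pool(hidden, mask)   # -> tensor([[2., 3.]])
```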

In your case, replace the column transformer with:

column_trans = ColumnTransformer([
    ('embedding', FlairTransformerEmbedding(), 'text'),
    ('number_scaler', MinMaxScaler(), ['number'])
])
