Input length mismatch

Published 2024-10-06 07:53:23

I am trying to run some analysis with DecisionTreeClassifier, but it gives me the following error:

ValueError: Number of features of the model must match the input. Model n_features is 1 and input n_features is 4

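I used the same training and test sets with the SVC and GaussianNB classifiers, and both worked fine. My code is below. As far as I know, the training and test sets have the same structure: before vectorization, both are lists of strings. I don't know where the mismatch comes from.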
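A likely cause is the line test_vect = vectorizer.fit_transform(test): fit_transform relearns the vocabulary from the test tweets, so the test matrix gets its own set of columns instead of the training columns. The minimal sketch below uses invented toy sentences in place of the tweet files and default TfidfVectorizer settings. It shows that transform reuses the training vocabulary, while refitting on the test text changes the column count and makes predict fail with a feature-count ValueError.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.tree import DecisionTreeClassifier

# Invented toy data standing in for the cleaned tweets (not the original files).
training = ["good happy day", "bad sad day", "good sunny morning"]
traininglabel = ["good", "bad", "good"]
test = ["happy sunny weekend", "terrible rainy evening"]

vectorizer = TfidfVectorizer()
train_vect = vectorizer.fit_transform(training)  # learns the training vocabulary
test_ok = vectorizer.transform(test)             # reuses it: same columns as training

classifier = DecisionTreeClassifier(random_state=0)
classifier.fit(train_vect, traininglabel)
classifier.predict(test_ok)                      # works: column counts agree

# Refitting on the test text replaces the vocabulary with the test words,
# so the resulting matrix has a different number of columns.
test_refit = vectorizer.fit_transform(test)
print(train_vect.shape, test_ok.shape, test_refit.shape)  # (3, 7) (2, 7) (2, 6)

try:
    classifier.predict(test_refit)
except ValueError as err:
    # Feature-count mismatch, the same kind of error as in the question.
    print("ValueError:", err)
```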

# classify with just scikit-learn

from sklearn.feature_extraction.text import TfidfVectorizer
from tools.striper import stripe, cleanupfiles
from tools.tweetprocessor import clean, wordclean

from sklearn import svm
from sklearn.naive_bayes import GaussianNB, MultinomialNB
from sklearn.metrics import classification_report
from sklearn import tree

stripe(0.1)

training = []
traininglabel = []
test = []
testlabel = []

with open('tempdata/goodtraining.txt','r') as f:
    for line in f:
        tweet = [wordclean(x) for x in clean(line.rstrip('\n')).split()]
        tweet = [x for x in tweet if len(x) >= 3]
        training.append(' '.join(tweet))
        traininglabel.append('good')
with open('tempdata/badtraining.txt','r') as f:
    for line in f:
        tweet = [wordclean(x) for x in clean(line.rstrip('\n')).split()]
        tweet = [x for x in tweet if len(x) >= 3]
        training.append(' '.join(tweet))
        traininglabel.append('bad')
with open('tempdata/goodtest.txt','r') as f:
    for line in f:
        tweet = [wordclean(x) for x in clean(line.rstrip('\n')).split()]
        test.append(' '.join(tweet))
        testlabel.append('good')
with open('tempdata/badtest.txt','r') as f:
    for line in f:
        tweet = [wordclean(x) for x in clean(line.rstrip('\n')).split()]
        test.append(' '.join(tweet))
        testlabel.append('bad')

vectorizer = TfidfVectorizer(min_df=0.1,max_df=0.9)
train_vect = vectorizer.fit_transform(training)
test_vect = vectorizer.fit_transform(test)

print(train_vect)
print(test_vect)

classifier = tree.DecisionTreeClassifier()
classifier.fit(train_vect.toarray(), traininglabel)
predictions = classifier.predict(test_vect.toarray())

print(classification_report(testlabel, predictions))

cleanupfiles()
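A sketch of one way to avoid the problem, assuming a scikit-learn Pipeline is acceptable here: make_pipeline fits the TfidfVectorizer on the training text only and reuses the fitted vectorizer at prediction time, so the training and test matrices always share the same columns. The vectorizer settings are taken from the question; the tweets and labels are invented stand-ins for the tempdata/ files, so the report numbers are not meaningful.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report
from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeClassifier

# Invented toy tweets and labels standing in for the files under tempdata/.
training = ["good happy sunny day", "good great lovely morning",
            "bad sad rainy day", "bad awful gloomy evening"]
traininglabel = ["good", "good", "bad", "bad"]
test = ["happy lovely afternoon", "sad gloomy night"]
testlabel = ["good", "bad"]

# fit() runs fit_transform on the training text only; predict() then runs
# transform (never fit_transform) on the test text before the tree sees it.
pipe = make_pipeline(
    TfidfVectorizer(min_df=0.1, max_df=0.9),
    DecisionTreeClassifier(random_state=0),
)
pipe.fit(training, traininglabel)
predictions = pipe.predict(test)

# zero_division=0 silences warnings if a class is never predicted on this tiny set.
print(classification_report(testlabel, predictions, zero_division=0))
```

Without a pipeline, replacing vectorizer.fit_transform(test) with vectorizer.transform(test) in the original script should have the same effect: the vectorizer is fitted once on the training tweets and the same fitted object transforms the test tweets.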
