RandomForest索引器错误:仅限整数、切片(`:`)、省略号(`…`),新轴(`None`)和整数或布尔数组是有效的索引

2024-10-01 09:24:53 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在sklearnRandomForestClassifier上工作:

class RandomForest(RandomForestClassifier):

    def fit(self, x, y):
        self.unique_train_y,  y_classes = transform_y_vectors_in_classes(y)
        return RandomForestClassifier.fit(self, x, y_classes)

    def predict(self, x):
        y_classes = RandomForestClassifier.predict(self, x)
        predictions = transform_classes_in_y_vectors(y_classes, self.unique_train_y)
        return predictions

    def transform_classes_in_y_vectors(y_classes, unique_train_y):
        cyr = [unique_train_y[predicted_index] for predicted_index in y_classes]
        predictions = np.array(float(cyr))
        return predictions

我收到了这个错误消息:

^{pr2}$

Tags: inselfreturndeftransformtrainpredictfit
1条回答
网友
1楼 · 发布于 2024-10-01 09:24:53

似乎y_classes包含的值不是有效的索引。在

当您试图用predicted_index访问unique_train_y时,您会得到一个异常,正如所预测的那样,索引并不是您想象的那样。在

尝试执行以下代码:

cyr = [unique_train_y[predicted_index] for predicted_index in range(len(y_classes))] 
# assuming unique_train_y is a list and predicted_index should be integer.

相关问题 更多 >