Filtering a NumPy array

Posted 2024-10-05 10:50:35


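I'm trying to filter objects out of a NumPy array. I want to remove all high-wage players, for example so that every remaining player has a wage below 20000.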

The dataset I'm using: https://www.kaggle.com/stefanoleone992/fifa-20-complete-player-dataset?select=players_20.csv

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from math import sqrt

df20 = pd.read_csv('players_20.csv')

x = df20['age'].values
y = df20['wage_eur'].values

# Training Model
lm = LinearRegression()
lm.fit(x.reshape(-1,1),y.reshape(-1,1))

y_pred=lm.predict(x.reshape(-1,1))

# creating pipeline (degree-2 polynomial features, then linear regression) and fitting it on data
steps = [('polynomial', PolynomialFeatures(degree=2)), ('model', LinearRegression())]
pipe = Pipeline(steps)
pipe.fit(x.reshape(-1, 1), y.reshape(-1, 1))

poly_pred=pipe.predict(x.reshape(-1,1))

# sorting predicted values with respect to predictor
order = np.argsort(x)
x_poly, poly_pred = x[order], poly_pred[order].ravel()

# plotting predictions
plt.figure(figsize=(10,6))
plt.scatter(x,y,s=15)
plt.plot(x,y_pred,color='r',label='Linear Regression')
plt.plot(x_poly,poly_pred,color='g',label='Polynomial Regression')
plt.xlabel('Age',fontsize=16)
plt.ylabel('Wage',fontsize=16)
plt.legend()
plt.show()
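
The code above never applies the wage filter described in the question. One common way to do it is a boolean mask on the NumPy arrays. The sketch below is only an illustration under assumptions: the small arrays are made-up stand-ins for `df20['age'].values` and `df20['wage_eur'].values`, and the names `WAGE_LIMIT`, `mask`, `x_low` and `y_low` are hypothetical.

```python
import numpy as np

# Made-up sample data standing in for the real 'age' and 'wage_eur' columns
x = np.array([19, 24, 27, 31, 33, 22])
y = np.array([3000, 45000, 250000, 18000, 20000, 9000])

WAGE_LIMIT = 20000  # threshold taken from the question (hypothetical name)

# Boolean mask: True where the wage is strictly below the limit
mask = y < WAGE_LIMIT

# Index both arrays with the same mask so ages and wages stay aligned
x_low = x[mask]
y_low = y[mask]

print(x_low)
print(y_low)
```

With the real data, `x_low.reshape(-1, 1)` and `y_low.reshape(-1, 1)` could then be passed to `fit` in place of `x` and `y`. Filtering the DataFrame first, for example `df20[df20['wage_eur'] < 20000]`, should give the same rows before the `.values` arrays are extracted.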
