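<p>Update: This solution uses NumPy's <code>np.interp</code>, which connects the points with straight line segments (piecewise-linear interpolation) to serve as a kind of "best fit" reference. We then use an error function to measure the difference between that interpolated line and the predicted y values for each polynomial degree.</p>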
<pre><code>import numpy as np
import matplotlib.pyplot as plt
import itertools
dataTrain = [
[2.362761180904257019e-01, -4.108125266714775847e+00],
[4.324296163702689988e-01, -9.869308732049049127e+00],
[6.023323504115264404e-01, -6.684279243433971729e+00],
[3.305079685397107614e-01, -7.897042003779912278e+00],
[9.952423271981121200e-01, 3.710086310489402628e+00],
[8.308127402955634011e-02, 1.828266768673480147e+00],
[1.855495407116576345e-01, 1.039713135916495501e+00],
[7.088332047815845138e-01, -9.783208407540947560e-01],
[9.475723071629885697e-01, 1.137746192425550085e+01],
[2.343475721257285427e-01, 3.098019704040922750e+00],
[9.338350584099475160e-02, 2.316408265530458976e+00],
[2.107903139601833287e-01, -1.550451474833406396e+00],
[9.509966727520677843e-01, 9.295029459100994984e+00],
[7.164931165416982273e-01, 1.041025972594300075e+00],
[2.965557300301902011e-03, -1.060607693351102121e+01]
]
data = np.array(dataTrain)
data = data[data[:, 0].argsort()]
X,y = data[:, 0], data[:, 1]
fig,ax = plt.subplots(4, 4)
indices = list(itertools.product([0,1,2,3], repeat=2))
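for i,loc in enumerate(indices, start=1):
    # Dense grid over the data range, and the piecewise-linear interpolation on it
    xx = np.linspace(X.min(), X.max(), 1000)
    yy = np.interp(xx, X, y)
    # Least-squares polynomial fit of degree i, evaluated on the same grid
    w = np.polyfit(X, y, i)
    y_pred = np.polyval(w, xx)
    ax[loc].scatter(X, y)
    ax[loc].plot(xx, y_pred)
    # Interpolated line as a dashed red curve
    ax[loc].plot(xx, yy, 'r--')
    # Summed squared difference between the interpolated line and the fit
    error = np.square(yy - y_pred).sum() / X.shape[0]
    print(error)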
plt.show()
</code></pre>
<p>This prints one error value per polynomial degree (1 through 16).</p>
<p>Visually, it plots:</p>
<p><a href="https://i.stack.imgur.com/UG9Mz.png" rel="nofollow noreferrer"><img src="https://i.stack.imgur.com/UG9Mz.png" alt="enter image description here"/></a></p>
<p>From here, just save these errors to a list and find the minimum.</p>
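<p>A minimal sketch of that last step, assuming the same error definition as in the code above. The helper name <code>interpolation_errors</code>, the dictionary keyed by degree, and the small synthetic data set are all illustrative assumptions, not part of the original answer:</p>

```python
import numpy as np

def interpolation_errors(X, y, degrees, n_points=1000):
    """Hypothetical helper: map each polynomial degree to its error
    against the piecewise-linear interpolation of the data."""
    order = np.argsort(X)                      # np.interp needs increasing x
    X, y = X[order], y[order]
    xx = np.linspace(X.min(), X.max(), n_points)
    yy = np.interp(xx, X, y)                   # interpolated reference line
    errors = {}
    for degree in degrees:
        w = np.polyfit(X, y, degree)           # least-squares polynomial fit
        y_pred = np.polyval(w, xx)
        # Same normalization as the code above: divide by the number of data points
        errors[degree] = np.square(yy - y_pred).sum() / X.shape[0]
    return errors

# Synthetic, exactly quadratic data (an assumption for illustration only)
X = np.linspace(0.0, 1.0, 12)
y = 2.0 * X**2 - X + 0.5

errors = interpolation_errors(X, y, degrees=range(1, 5))
best_degree = min(errors, key=errors.get)      # degree with the smallest error
print(errors)
print("best degree:", best_degree)
```

<p>Keeping the errors in a dictionary rather than a plain list means the minimum maps straight back to its degree. With a list, <code>np.argmin(errors) + 1</code> would give the same answer when the degrees start at 1.</p>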