Shouldn't a neural network with no hidden layers and a linear activation function approximate linear regression?

Posted 2024-09-24 20:29:19


As far as I understand, a neural network with no hidden layers and a linear activation function should produce the same equation as linear regression, i.e. y = sum_i(w_i * x_i) + b, where i runs over the features you have.
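That claim can be checked with plain NumPy, independent of any framework: fit ordinary least squares by appending an intercept column, then apply y = Xw + b by hand. The data and weights below are illustrative, not from the question:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((100, 3))
y = X @ np.array([0.5, -1.0, 2.0]) + 0.3  # exactly linear data, known w and b

# Fit least squares with an explicit intercept column, as linear regression does
A = np.hstack([X, np.ones((100, 1))])
coef, *_ = np.linalg.lstsq(A, y, rcond=None)
w, b = coef[:-1], coef[-1]

# A single linear neuron computes exactly this matrix form
pred = X @ w + b
print(np.allclose(pred, y))  # True
```

Because the data are exactly linear, the recovered weights reproduce the targets to floating-point precision.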

I tried to verify this by taking the weights and bias from a fitted linear regression, loading them into a neural network, and checking whether the predictions match. They did not.

I would like to know whether my understanding is wrong, my code is wrong, or both.


from sklearn.linear_model import LinearRegression
import tensorflow as tf
from tensorflow import keras
import numpy as np

linearModel = LinearRegression()
linearModel.fit(np.array(normTrainFeaturesDf), np.array(trainLabelsDf))

# Gets the weights of the linear model and the intercept in a form that can be passed into the neural network
linearWeights = np.array(linearModel.coef_)
intercept = np.array([linearModel.intercept_])

trialWeights = np.reshape(linearWeights, (len(linearWeights), 1))
trialWeights = trialWeights.astype('float32')
intercept = intercept.astype('float32')
newTrialWeights = [trialWeights, intercept]

# Create a neural network and set the weights of the model to the linear model
nnModel = keras.Sequential([keras.layers.Dense(1, activation='linear', input_shape=[len(normTrainFeaturesDf.keys())]),])
nnModel.set_weights(newTrialWeights)

# Print predictions of both models (the results are vastly different)
print(linearModel.predict(np.array(normTestFeaturesDf)))
print(nnModel.predict(normTestFeaturesDf).flatten())
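For reference, Keras stores a Dense layer's weights as a kernel of shape (n_features, units) plus a bias of shape (units,), which is why the 1-D coefficient vector from sklearn has to be reshaped into a column before calling set_weights. A small shape-only sketch, assuming 10 features (the arrays here are placeholders, not the question's data):

```python
import numpy as np

n_features = 10
coef = np.random.rand(n_features)     # sklearn's coef_ is 1-D: (n_features,)
intercept = np.random.rand(1)         # intercept_ wrapped in a 1-element array

kernel = coef.reshape(n_features, 1)  # Dense kernel shape: (in_features, units)
bias = intercept                      # Dense bias shape: (units,)

print(kernel.shape, bias.shape)  # (10, 1) (1,)
```

With units=1 these are exactly the two arrays that `nnModel.set_weights([kernel, bias])` expects, in that order.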


1 Answer

Answer 1 · Posted 2024-09-24 20:29:19

Yes, a single-layer neural network with no activation function is equivalent to linear regression.

Defining some variables that were not included in your snippet:

normTrainFeaturesDf = np.random.rand(100, 10)
normTestFeaturesDf = np.random.rand(10, 10)
trainLabelsDf = np.random.rand(100)

Then the outputs are as expected:

>>> linear_model_preds = linearModel.predict(np.array(normTestFeaturesDf))
>>> nn_model_preds = nnModel.predict(normTestFeaturesDf).flatten()

>>> print(linear_model_preds)
>>> print(nn_model_preds)
[0.46030349 0.69676376 0.43064266 0.4583325  0.50750268 0.51753189
 0.47254946 0.50654825 0.52998559 0.35908762]
[0.46030346 0.69676375 0.43064266 0.45833248 0.5075026  0.5175319
 0.47254944 0.50654817 0.52998555 0.3590876 ]

Apart from tiny differences caused by floating-point precision, the numbers are identical:

>>> np.allclose(linear_model_preds, nn_model_preds)
True
