如何在JAX中从数据集读取数据

2024-09-28 01:32:26 发布

您现在位置:Python中文网/ 问答频道 /正文

我是JAX的新手。我有下面的代码,其中“特征矩阵”作为数组,“目标向量”作为数组。但我不希望程序读取这些数据数组。这些数组已经存在于代码中。我想修改代码,以便读取我导入的波士顿房价数据集。有人能告诉我,我需要对这段代码做些什么修改,才能使线性回归工作吗

import jax.numpy as np
from jax import grad, jit

from sklearn.datasets import load_boston
import sklearn.linear_model as sk

boston = load_boston()
X = np.array(boston.data)
y = np.array(boston.target)

def J(X, w, b, y):
    """Cost function for a linear regression. A forward pass of our model.

    Args:
        X: a features matrix.
        w: weights (a column vector).
        b: a bias.
        y: a target vector.

    Returns:
        scalar: a cost of this solution.    
    """
    y_hat = X.dot(w) + b # Predict values.
    return ((y_hat - y)**2).mean() # Return cost.

# A features matrix.
X = np.array([
                 [4., 7.],
                 [1., 8.],
                 [-5., -6.],
                 [3., -1.],
                 [0., 9.]
             ])

# A target column vector.
y = np.array([
                 [37.],
                 [24.],
                 [-34.], 
                 [16.],
                 [21.]
             ])

learning_rate = 0.01

w = np.zeros((2, 1))
b = 0.

%timeit grad(J, argnums=1)(X, w, b, y)

%timeit grad(J, argnums=2)(X, w, b, y)

for i in range(100):
    w -= learning_rate * grad(J, argnums=1)(X, w, b, y)
    b -= learning_rate * grad(J, argnums=2)(X, w, b, y)
    
    if i % 10 == 0:
        print(J(X, w, b, y))

Tags: 数据代码importtargetrateasnp数组

热门问题