我试图理解argnums
在JAX的梯度函数中的行为。
假设我有以下函数:
def make_mse(x, t):
def mse(w,b):
return np.sum(jnp.power(x.dot(w) + b - t, 2))/2
return mse
我用下面的方法计算梯度:
w_gradient, b_gradient = grad(make_mse(train_data, y), (0,1))(w,b)
argnums= (0,1)
在这种情况下,它意味着什么?关于哪些变量计算梯度?如果我改用argnums=0
,会有什么不同?
另外,我可以使用相同的函数来获得Hessian矩阵吗
我看了关于它的JAX help部分,但想不出来
将多个argnum传递给grad时,结果是一个返回渐变元组的函数,相当于单独计算每个渐变:
如果要计算混合二阶导数,可以重复应用梯度:
相关问题 更多 >
编程相关推荐