我想使用vmap将此代码矢量化以提高性能
def matrix(dataA, dataB):
return jnp.array([[func(a, b) for b in dataB] for a in dataA])
matrix(data, data)
我试过这个:
def f(x, y):
return func(x, y)
mapped = jax.vmap(f)
mapped(data, data)
但这只给出对角线条目
基本上我有一个向量data = [1,2,3,4,5]
(示例),我想得到一个矩阵,使得矩阵的每个条目(i, j)
都是f(data[i], data[j])
。因此,得到的矩阵形状将是(len(data), len(data))
jax.vmap
一次映射一组轴。如果要跨两组独立的轴进行映射,可以通过嵌套两个vmap
变换来实现:相关问题 更多 >
编程相关推荐