Jax中的vmap ops.index_更新

2024-09-26 22:51:18 发布

您现在位置:Python中文网/ 问答频道 /正文

我有下面的代码,它使用了一个简单的for循环。我只是想知道是否有一种方法可以将其vmap?以下是原始代码:

import numpy as np 
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal
import jax

data = np.random.rand(192,334)

a = [1,-1.086740193996892,0.649914553946275,-0.124948974636730]
b = [0.054778173164082,0.164334519492245,0.164334519492245,0.054778173164082]
impulse = signal.lfilter(b, a, [1] + [0]*99) 
impulse_20 = impulse[:20]
impulse_20 = jnp.asarray(impulse_20)

@jax.jit
def filter_jax(y):
    for ind in range(0, len(y)):
      y = jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19])
    return y

jnpData = jnp.asarray(data)

%timeit filter_jax(jnpData).block_until_ready()

下面是我使用vmap的尝试:

def paraUpdate(y, ind):
    return jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19])

@jax.jit
def filter_jax2(y):
  ranger = range(0, len(y))
  return jax.vmap(paraUpdate, y)(ranger)

但我收到以下错误:

TypeError: vmap in_axes must be an int, None, or (nested) container with those types as leaves, but got Traced<ShapedArray(float32[192,334])>with<DynamicJaxprTrace(level=0/1)>.

我有点困惑,因为范围是int类型的,所以我不太确定发生了什么

最后,我将尽可能地优化这个小部件,以获得最短的时间


Tags: 代码importindexsignalreturndefasfilter
1条回答
网友
1楼 · 发布于 2024-09-26 22:51:18

^{}可以表示单个操作在输入的多个轴上独立应用的功能。您的函数有点不同:您对单个输入迭代应用了单个操作

幸运的是,JAX提供了^{},可以处理这种情况。实现将如下所示:

from jax import lax

def paraUpdate(y, ind):
    return jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19]), ind

@jax.jit
def filter_jax2(y):
  ranger = jnp.arange(len(y))
  return lax.scan(paraUpdate, y, ranger)[0]

print(np.allclose(filter_jax(jnpData), filter_jax2(jnpData)))
# True

%timeit filter_jax(jnpData).block_until_ready()
# 10 loops, best of 3: 28.6 ms per loop

%timeit filter_jax2(jnpData).block_until_ready()
# 1000 loops, best of 3: 519 µs per loop

如果更改算法,以便将操作应用于数组中的每个列,而不是第一个N列,则可以使用以下vmap表示:

@jax.jit
def filter_jax3(y):
  f = lambda col: jscp.convolve(impulse_20, col)[:-19]
  return jax.vmap(f, in_axes=1, out_axes=1)(y)

相关问题 更多 >

    热门问题