使用for循环时如何减少JAX编译时间?

2024-09-26 18:07:42 发布

您现在位置:Python中文网/ 问答频道 /正文

这是一个基本的例子

@jax.jit
def block(arg1, arg2):
   for x1 in range(cons1):
       for x2 in range(cons2):
          for x3 in range(cons3):
             --do something--
   return result

cons较小时,编译时间大约为一分钟。对于较大的缺点,编译时间要高得多—10分钟。我需要更高的犯人。可以做些什么? 从我所读到的,循环是原因。它们在编译时展开。 有什么解决办法吗?还有jax.fori_循环。但我不知道如何使用它。有一个jax.experimental.loops模块,但我还是不能理解它

我对这一切都很陌生。因此,我们非常感谢您的帮助。 如果您能提供一些关于如何使用jax循环的示例,我们将不胜感激

另外,什么是ok编译时间?几分钟后可以吗? 在其中一个示例中,编译时间为262秒,剩余的运行时间约为0.1-0.2秒

运行时的任何增益都会被编译时所掩盖


Tags: in示例fordef时间rangeblock例子
2条回答

我不确定这是否与numba相同,但情况可能类似

当我使用numba.jit编译器并有大数据输入时,首先在一些小示例数据上编译函数,然后使用它

伪代码:

func_being_compiled(small_amount_of_data)  # compile-only purpose
func_being_compiled(large_amount_of_data)

JAX的JIT编译器使所有Python循环平坦化。要了解我的意思,请看一下这个通过jax.make_jaxpr运行的简单函数,这是一种检查JAX的跟踪程序如何解释python代码的方法(有关更多信息,请参见Understanding Jaxprs):

import jax

def f(x):
  for i in range(5):
    x += i
  return x

print(jax.make_jaxpr(f)(0))
# { lambda  ; a.
#   let b = add a 0
#       c = add b 1
#       d = add c 2
#       e = add d 3
#       f = add e 4
#   in (f,) }

注意循环被扁平化:每一步都变成了一个显式的操作,发送给XLA编译器。XLA编译时间随着函数中操作次数的增加而增加,因此三层嵌套循环将导致长编译时间是有意义的。p>

那么,如何解决这个问题呢?嗯,不幸的是,答案取决于你的 do something 在做什么,所以我猜不到

一般来说,最好的选择是使用向量化数组操作,而不是在这些向量中的值上循环;例如,下面是一种添加两个向量的非常缓慢的方法:

import jax.numpy as jnp

def f_slow(x, y):
  z = []
  for xi, yi in zip(xi, yi):
    z.append(xi + yi)
  return jnp.array(z)

这里有一个更快的方法来做同样的事情:

def f_fast(x, y):
  return x + y

如果你的操作不支持向量化,另一个选择是使用lax control flow运算符代替^ {< CD3>}循环:这将把循环推到XLA中。这在CPU上可以有相当好的性能,但在加速器上比等效矢量化阵列操作要慢

有关JAX和Python控制流语句(如forifwhile等)的更多讨论,请参见🔪 JAX - The Sharp Bits 🔪: Control Flow

相关问题 更多 >

    热门问题