这是一个基本的例子
@jax.jit
def block(arg1, arg2):
for x1 in range(cons1):
for x2 in range(cons2):
for x3 in range(cons3):
--do something--
return result
cons较小时,编译时间大约为一分钟。对于较大的缺点,编译时间要高得多—10分钟。我需要更高的犯人。可以做些什么? 从我所读到的,循环是原因。它们在编译时展开。 有什么解决办法吗?还有jax.fori_循环。但我不知道如何使用它。有一个jax.experimental.loops模块,但我还是不能理解它
我对这一切都很陌生。因此,我们非常感谢您的帮助。 如果您能提供一些关于如何使用jax循环的示例,我们将不胜感激
另外,什么是ok编译时间?几分钟后可以吗? 在其中一个示例中,编译时间为262秒,剩余的运行时间约为0.1-0.2秒
运行时的任何增益都会被编译时所掩盖
我不确定这是否与
numba
相同,但情况可能类似当我使用
numba.jit
编译器并有大数据输入时,首先在一些小示例数据上编译函数,然后使用它伪代码:
JAX的JIT编译器使所有Python循环平坦化。要了解我的意思,请看一下这个通过
jax.make_jaxpr
运行的简单函数,这是一种检查JAX的跟踪程序如何解释python代码的方法(有关更多信息,请参见Understanding Jaxprs):注意循环被扁平化:每一步都变成了一个显式的操作,发送给XLA编译器。XLA编译时间随着函数中操作次数的增加而增加,因此三层嵌套循环将导致长编译时间是有意义的。p>
那么,如何解决这个问题呢?嗯,不幸的是,答案取决于你的
do something
在做什么,所以我猜不到一般来说,最好的选择是使用向量化数组操作,而不是在这些向量中的值上循环;例如,下面是一种添加两个向量的非常缓慢的方法:
这里有一个更快的方法来做同样的事情:
如果你的操作不支持向量化,另一个选择是使用lax control flow运算符代替^ {< CD3>}循环:这将把循环推到XLA中。这在CPU上可以有相当好的性能,但在加速器上比等效矢量化阵列操作要慢
有关JAX和Python控制流语句(如
for
、if
、while
等)的更多讨论,请参见🔪 JAX - The Sharp Bits 🔪: Control Flow相关问题 更多 >
编程相关推荐