如何有效地对lambdafision中的术语进行分组?

2024-07-02 11:26:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个有无数项的辛多项式。我想把这个公式简化一下。然而,由于它有大量的项,并且多项式被展开,所以向下的运算比最优的要多。具体地说,通过某些项组合在一起,我们可以消除一些操作。例如,考虑以下等式:

x^2y^2 + x^2y + x^2 + 1

如果我对此进行lambdizing,那么,如果xy是长度为N的1Dnp.array,那么将有4个元素的平方运算、2个元素的乘法运算和3个元素的加法运算,从而产生大约9*N的运算。你知道吗

奥托,通过做一点代数,我们得出:

x^2(y^2 + y + 1) + 1

通过对等推理,这个公式只涉及6*N运算。如果我有一个更大更复杂的公式,差别可能会很大。你知道吗

在任何情况下,我都不需要找到使性能最大化的表示法,但是很明显,一个小的术语分组至少可以提高性能。你知道吗

在lambdafing时,如何进行这种“术语分组”以实现更有效的sympy公式表示?你知道吗


Tags: 元素情况性能array公式术语表示法sympy
2条回答

您可以按常用符号对术语进行分组,并对它们使用horner

>>> d=defaultdict(list)
>>> for t in Add.make_args(eq):
...  d[tuple(ordered(t.free_symbols))].append(t)
...
>>> Add(*[horner(Add(*i)) for i in d.values()])
x**2*y*(y + 1) + x**2 + 1

最后我用了sympy.collect。如果方程没有太多的变量,就可以简单地强制所有的组合,并递归到“收集”项中。你知道吗

这是我想出的密码。可能还有很多改进的空间:

def collect_best(expr, measure=sympy.count_ops):
    # This method performs sympy.collect over all permutations of the free variables, and returns the best collection
    best = expr
    best_score = measure(expr)
    perms = itertools.permutations(expr.free_symbols)
    permlen = np.math.factorial(len(expr.free_symbols))
    print(permlen)
    for i, perm in enumerate(perms):
        if (permlen > 1000) and not (i%int(permlen/100)):
            print(i)
        collected = sympy.collect(expr, perm)
        if measure(collected) < best_score:
            best_score = measure(collected)
            best = collected
    return best

def product(args):
    arg = next(args)
    try:
        return arg*product(args)
    except:
        return arg

def rcollect_best(expr, measure=sympy.count_ops):
    # This method performs collect_best recursively on the collected terms
    best = collect_best(expr, measure)
    best_score = measure(best)
    if expr == best:
        return best
    if isinstance(best, sympy.Mul):
        return product(map(rcollect_best, best.args))
    if isinstance(best, sympy.Add):
        return sum(map(rcollect_best, best.args))

rcollect_best将其转换为(count\u ops=136):

4*a**3*d*e - 6*a**2*b*d*e - 6*a**2*c*d*e + 16*a**2*e**3 + 6*a**2*e*f**2 + 6*a**2*e*g**2 + 2*a*b**2*d*e + 8*a*b*c*d*e - 14*a*b*e**3 - 2*a*b*e*f**2 - 8*a*b*e*g**2 + 2*a*c**2*d*e - 14*a*c*e**3 - 8*a*c*e*f**2 - 2*a*c*e*g**2 - 2*b**2*c*d*e + 2*b**2*e**3 + 2*b**2*e*g**2 - 2*b*c**2*d*e + 8*b*c*e**3 + 2*b*c*e*f**2 + 2*b*c*e*g**2 + 2*c**2*e**3 + 2*c**2*e*f**2

(计数=68):

2*e*(d*(2*a**3 - 3*a**2*b + a*b**2 + c**2*(a - b) + c*(-3*a**2 + 4*a*b - b**2)) + e**2*(8*a**2 - 7*a*b + b**2 + c**2 + c*(-7*a + 4*b)) + f**2*(3*a**2 - a*b + c**2 + c*(-4*a + b)) + g**2*(3*a**2 - 4*a*b + b**2 + c*(-a + b)))

是7个变量的5次多项式。运行时间大约是10到15分钟,并且会以指数级的速度增长,所以我不建议对任何比这个要求更高的东西使用这个。我相信有一些基本的改进可以修复超指数增长,但这解决了我的问题,所以我现在兑现。:)

相关问题 更多 >