三元表示法中的快速数字和(Python)

2024-10-16 20:51:50 发布

您现在位置:Python中文网/ 问答频道 /正文

我定义了一个函数

def enumerateSpin(n):
    s = []
    for a in range(0,3**n):
        ternary_rep = np.base_repr(a,3)
        k = len(ternary_rep)
        r = (n-k)*'0'+ternary_rep
        if sum(map(int,r)) == n:
            s.append(r)
    return s

我看的是数字0<;=a<;3^N并询问三元表示法中的数字总和是否达到某个值。我首先将数字转换成三元表示的字符串。我之所以填充零,是因为我想存储一个固定长度表示的列表,以便以后用于进一步计算(即两个元素之间的逐位比较)

现在np.base_reprsum(map(int,#))分别在我的计算机上占用大约5个us,这意味着迭代大约需要10个us,我正在寻找一种方法,在这种方法中,您可以完成我所做的事情,但速度要快10倍

(编辑:关于在左侧填充零的注释)

(Edit2:事后看来,最终表示形式最好是整数的元组,而不是字符串)

(编辑3:对于那些想知道的人来说,代码的目的是枚举具有相同总S_z值的自旋-1链的状态。)


Tags: 方法字符串lt编辑mapbase定义np
3条回答

您可以使用^{}生成数字,然后转换为字符串表示形式:

import itertools as it

def new(n):
    s = []
    for digits in it.product((0, 1, 2), repeat=n):
        if sum(digits) == n:
            s.append(''.join(str(x) for x in digits))
    return s

这给了我大约7倍的加速:

In [8]: %timeit enumerateSpin(12)
2.39 s ± 7.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [9]: %timeit new(12)
347 ms ± 4.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

在Python 3.9.0(IPython 7.20.0)(Linux)上测试

上面的过程使用it.product也生成了我们通过推理知道它们不符合条件的数字(这是所有数字的一半的情况,因为数字的总和必须等于数字的数目)。对于n位,我们可以计算最终总计为n210位的各种计数。然后我们可以生成这些数字的所有distinct permutations,从而只生成相关数字:

import itertools as it
from more_itertools import distinct_permutations

def new2(n):
    all_digits = (('2',)*i + ('1',)*(n-2*i) + ('0',)*i for i in range(n//2+1))
    all_digits = it.chain.from_iterable(distinct_permutations(d) for d in all_digits)
    return (''.join(digits) for digits in all_digits)

特别是对于大量的n,这提供了额外的、显著的加速:

In [44]: %timeit -r 1 -n 1 new(16)
31.4 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

In [45]: %timeit -r 1 -n 1 list(new2(16))
7.82 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

请注意,上述解决方案newnew2具有O(1)内存缩放(将new更改为yield而不是append

通常,要获取特定基数中的数字,我们可以执行以下操作:

while num > 0:
    digit = num % base
    num //= base
    print(digit)

使用num = 14, base = 3运行此命令时,我们会得到:

2
1
1

这意味着三元中的14是112。
我们可以将其提取到方法digits(num, base)中,并且仅在实际需要将数字转换为字符串时使用np.base_repr(a,3)

def enumerateSpin(n):
    s = []
    for a in range(0,3**n):
        if sum(digits(a, 3)) == n:
            ternary_rep = np.base_repr(a,3)
            k = len(ternary_rep)
            r = (n-k)*'0'+ternary_rep
            s.append(r)
    return s

enumerateSpin(4)输出:

['0022', '0112', '0121', '0202', '0211', '0220', '1012', '1021', '1102', '1111', '1120', '1201', '1210', '2002', '2011', '2020', '2101', '2110', '2200']

通过将所有计算委托给numpy以利用矢量化处理,可以实现10倍的改进:

def eSpin(n):
    nums    = np.arange(3**n,dtype=np.int)
    base3   = nums // (3**np.arange(n))[:,None] % 3
    matches = np.sum(base3,axis=0) == n
    digits  = np.sum(base3[:,matches] * 10**np.arange(n)[:,None],axis=0)
    return [f"{a:0{n}}" for a in digits]   

其工作原理(例如eSpin(3)):

nums是一个包含最多3**n的所有数字的数组

   [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26]  

base3将其转换为附加维度中的基3位数:

[[0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2]
 [0 0 0 1 1 1 2 2 2 0 0 0 1 1 1 2 2 2 0 0 0 1 1 1 2 2 2]
 [0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2]]

matches标识base3数字之和与n匹配的列

 [0 0 0 0 0 1 0 1 0 0 0 1 0 1 0 1 0 0 0 1 0 1 0 0 0 0 0]

digits将匹配的列转换为由base3数字组成的以10为基数的数字

 [ 12  21 102 111 120 201 210]

最后,匹配的(base10)数字被格式化为前导零

性能:

from timeit import timeit
count = 1

print(enumerateSpin(10)==eSpin(10)) # True

t1 = timeit(lambda:eSpin(13),number=count)
print("eSpin",t1) # 0.634 sec

t0 = timeit(lambda:enumerateSpin(13),number=count)
print("enumerateSpin",t0) # 7.362 sec

元组版本:

def eSpin2(n):
    nums    = np.arange(3**n,dtype=np.int)
    base3   = nums// (3**np.arange(n))[:,None]  % 3
    matches = np.sum(base3,axis=0) == n
    return [*map(tuple,base3[:,matches].T)]

eSpin2(3)
[(2, 1, 0), (1, 2, 0), (2, 0, 1), (1, 1, 1), (0, 2, 1), (1, 0, 2), (0, 1, 2)]

[EDIT]一种更快的方法(比enumerateSpin快40到80倍)

使用动态编程和记忆可以提供更好的性能:

@lru_cache()
def eSpin(n,base=3,target=None):
    if target is None: target = n
    if target == 0: return [(0,)*n]
    if target>base**n-1: return []
    if n==1: return [(target,)]
    result = []
    for d in range(min(base,target+1)):
        result.extend((d,)+suffix for suffix in eSpin(n-1,base,target-d) )
    return result

t4 = timeit(lambda:eSpin(13),number=count)
print("eSpin",t4) # 0.108 sec

eSpin.cache_clear()
t5 = timeit(lambda:eSpin(16),number=count)
print("eSpin",t5) # 2.25 sec

相关问题 更多 >