在具有常量但不可散列对象的函数上使用functools.lru_缓存

2024-06-02 19:01:52 发布

您现在位置:Python中文网/ 问答频道 /正文

是否可以使用functools.lru_cache缓存由functools.partial创建的部分函数

我的问题是一个函数,它接受可散列参数和连续的、不可散列的对象,比如NumPy数组

考虑这个玩具例子:

import numpy as np
from functools import lru_cache, partial

def foo(key, array):
    print('%s:' % key, array)
a = np.array([1,2,3])

由于NumPy数组不可散列,因此这将不起作用:

@lru_cache(maxsize=None)
def foo(key, array):
    print('%s:' % key, array)
foo(1, a)

正如预期的那样,您会出现以下错误:

/Users/ch/miniconda/envs/sci34/lib/python3.4/functools.py in __init__(self, tup, hash)
    349     def __init__(self, tup, hash=hash):
    350         self[:] = tup
--> 351         self.hashvalue = hash(tup)
    352 
    353     def __hash__(self):

TypeError: unhashable type: 'numpy.ndarray'

因此,我的下一个想法是使用functools.partial来摆脱NumPy数组(它无论如何都是常量)

pfoo = partial(foo, array=a)
pfoo(2)

现在我有了一个只接受散列参数的函数,它应该非常适合lru_cache。但是在这种情况下是否可以使用lru_cache?我不能将它用作包装函数而不是@lru_cache装饰器,可以吗

有没有聪明的方法来解决这个问题


Tags: key函数selfnumpycache参数foodef
2条回答

由于数组是常量,您可以在实际的lru缓存函数周围使用包装器,只需将键值传递给它:

from functools import lru_cache, partial
import numpy as np


def lru_wrapper(array=None):
    @lru_cache(maxsize=None)
    def foo(key):
        return '%s:' % key, array
    return foo


arr = np.array([1, 2, 3])
func = lru_wrapper(array=arr)

for x in [0, 0, 1, 2, 2, 1, 2, 0]:
    print (func(x))

print (func.cache_info())

产出:

('0:', array([1, 2, 3]))
('0:', array([1, 2, 3]))
('1:', array([1, 2, 3]))
('2:', array([1, 2, 3]))
('2:', array([1, 2, 3]))
('1:', array([1, 2, 3]))
('2:', array([1, 2, 3]))
('0:', array([1, 2, 3]))
CacheInfo(hits=5, misses=3, maxsize=None, currsize=3)

下面是一个如何将lru_cachefunctools.partial一起使用的示例:

from functools import lru_cache, partial
import numpy as np


def foo(key, array):
    return '%s:' % key, array


arr = np.array([1, 2, 3])
pfoo = partial(foo, array=arr)
func = lru_cache(maxsize=None)(pfoo)

for x in [0, 0, 1, 2, 2, 1, 2, 0]:
    print(func(x))

print(func.cache_info())

输出:

('0:', array([1, 2, 3]))
('0:', array([1, 2, 3]))
('1:', array([1, 2, 3]))
('2:', array([1, 2, 3]))
('2:', array([1, 2, 3]))
('1:', array([1, 2, 3]))
('2:', array([1, 2, 3]))
('0:', array([1, 2, 3]))
CacheInfo(hits=5, misses=3, maxsize=None, currsize=3)

这比solution of @AshwiniChaudhary更简洁,并且在OP的要求之后使用functools.partial


p.S.:此解决方案改编自Applying ^{} to lambda

相关问题 更多 >