使用functools.wrapps在装饰器链中保留精确的函数签名

2024-06-28 15:42:46 发布

您现在位置:Python中文网/ 问答频道 /正文

在Python3.4+中,functools.wraps保留它所包装的函数的签名。不幸的是,如果您创建的装饰器要堆叠在彼此的顶部,那么序列中的第二个(或更高)装饰器将看到包装器的泛型*args**kwargs签名,而不是在装饰器序列的底部一直保留原始函数的签名。这里有一个例子

from functools import wraps    

def validate_x(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        assert kwargs['x'] <= 2
        return func(*args, **kwargs)
    return wrapper

def validate_y(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        assert kwargs['y'] >= 2
        return func(*args, **kwargs)
    return wrapper

@validate_x
@validate_y
def foo(x=1, y=3):
    print(x + y)


# call the double wrapped function.
foo()

这给

-------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-5-69c17467332d> in <module>
     22
     23
---> 24 foo()

<ipython-input-5-69c17467332d> in wrapper(*args, **kwargs)
      4     @wraps(func)
      5     def wrapper(*args, **kwargs):
----> 6         assert kwargs['x'] <= 2
      7         return func(*args, **kwargs)
      8     return wrapper

KeyError: 'x'

如果您切换装饰程序的顺序,那么'y'会得到相同的键错误

我尝试在第二个decorator中用wraps(func.__wrapped__)替换wraps(func),但这仍然不起作用(更不用说它要求程序员明确知道他们在decorator堆栈中为给定包装器功能工作的位置)

我还查看了inspect.signature(foo),这似乎给出了正确的结果,但我发现这是因为inspect.signature有一个follow_wrapped参数默认为True,所以它不知怎么地知道如何遵循包装函数的顺序,但显然是调用foo()的常规方法调用框架对于外部包装的解析args和kwargs,将不遵循相同的协议

我如何才能让wraps忠实地传递签名,以便wraps(wraps(wraps(wraps(f))))(可以这么说)始终忠实地复制f的签名


Tags: 函数returnfoodefargs序列装饰assert
3条回答

你的诊断是错误的;实际上,functools.wraps保留了双修饰函数的签名:

>>> import inspect
>>> inspect.signature(foo)
<Signature (x=1, y=3)>

我们还可以观察到,调用具有错误签名的函数不是问题,因为这会引发TypeError,而不是KeyError

您似乎有这样一种印象,即当只使用一个decorator时,kwargs将填充参数默认值。这根本不会发生:

def test_decorator(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        print('args:', args)
        print('kwargs:', kwargs)
        return func(*args, **kwargs)
    return wrapper

@test_decorator
def foo(x=1):
    print('x:', x)

输出为:

>>> foo()
args: ()
kwargs: {}
x: 1

如您所见,argskwargs都不会收到参数的默认值,即使只使用了一个装饰符。它们都是空的,因为foo()调用包装函数时没有位置参数和关键字参数


实际的问题是您的代码有一个逻辑错误。修饰符validate_xvalidate_y期望参数作为关键字参数传递,但事实上,它们可能作为位置参数传递,或者根本不传递(因此默认值将适用),在这种情况下'x'和/或'y'不会出现在kwargs

没有简单的方法可以让你的装饰者使用一个可以作为关键字或位置传递的参数;如果只使用arguments关键字,则可以在验证前测试'x''y'是否在kwargs

def validate_x(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if 'x' in kwargs and kwargs['x'] > 2:
            raise ValueError('Invalid x, should be <= 2, was ' + str(x))
        return func(*args, **kwargs)
    return wrapper

@validate_x
def bar(*, x=1): # keyword-only arg, prevent passing as positional arg
    ...

通常最好显式地raise一个错误,而不是使用assert,因为您的程序can be run with ^{} disabled

还要注意的是,可以声明像@validate_x def baz(*, x=5): ...这样的函数,其中默认的x无效。这不会引发任何错误,因为装饰程序未检查默认参数值

实际上,您没有向函数foo传递任何参数,因此*args**kwargs对于这两个修饰符都是空的。如果您传递参数,则装饰器将正常工作

foo(x=2, y = 3) # prints 5

您可以尝试使用inspect获取默认函数参数

如果不使用inspect,就无法真正获得默认值,还需要考虑位置参数(*args)与关键字参数(**kwargs)。因此,如果数据存在,则对其进行规范化如果数据丢失,则检查函数

import inspect
from functools import wraps


def get_default_args(func):
    signature = inspect.signature(func)
    return {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty
    }


def validate_x(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if args and not kwargs and len(args) == 2:
            kwargs['x'] = args[0]
            kwargs['y'] = args[1]
            args = []
        if not args and not kwargs:
            kwargs = get_default_args(func)
        assert kwargs['x'] <= 2
        return func(*args, **kwargs)

    return wrapper


def validate_y(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if args and not kwargs and len(args) == 2:
            kwargs['x'] = args[0]
            kwargs['y'] = args[1]
            args = []
        if not args and not kwargs:
            kwargs = get_default_args(func)
        assert kwargs['y'] >= 2
        return func(*args, **kwargs)

    return wrapper


@validate_x
@validate_y
def foo(x=1, y=3):
    print(x + y)


# call the double wrapped function.
foo()
# call with positional args
foo(1, 4)
# call with keyword args
foo(x=2, y=10)

这张照片

4
5
12

相关问题 更多 >