我必须重写子类中的所有数学运算符吗?

2024-09-20 22:52:25 发布

您现在位置:Python中文网/ 问答频道 /正文

我想在python3.7程序中创建一个简单的Point2d类,它只实现了一些特性。我在一个如此的答案中看到(我现在找不到)创建Point类的一种方法是重写complex,所以我写了这样一篇文章:

import math

class Point2d(complex):

    def distanceTo(self, otherPoint):
        return math.sqrt((self.real - otherPoint.real)**2 + (self.imag - otherPoint.imag)**2)

    def x(self):
        return self.real

    def y(self):
        return self.imag

这样做有效:

In [48]: p1 = Point2d(3, 3)

In [49]: p2 = Point2d(6, 7)

In [50]: p1.distanceTo(p2)
Out[50]: 5.0

但当我这样做时,p3complex的实例,而不是Point2d

In [51]: p3 = p1 + p2

In [52]: p3.distanceTo(p1)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-52-37fbadb3015e> in <module>
----> 1 p3.distanceTo(p1)

AttributeError: 'complex' object has no attribute 'distanceTo'

我的大部分背景都是Objective-C和C,所以我仍在尝试找出做这类事情的Python式方法。我需要重写我想在Point2d类上使用的所有数学运算符吗?还是我完全错了?你知道吗


Tags: 方法inselfreturndefmathrealp2
3条回答

在这种情况下,我建议从头开始实现类Point2d。你知道吗

如果你很懒的话,可以看看类似lib的sympy,它包含一个Point类和其他几何类https://docs.sympy.org/latest/modules/geometry/index.html

问题是,当您的类使用属于complex的任何数据模型函数时,它返回一个complex,因此您需要将其交给Point2d

加上这个方法就可以了

def __add__(self, b):
    return Point2d(super().__add__(b))

但还是应该有更好的办法。但这是动态包装一些数据模型(dunder)方法的方法。你知道吗

顺便说一句,距离函数可以使它变短,像这样

def distanceTo(self, otherPoint):
    return abs(self - otherPoint)

我要提到一种重写所有方法的方法,而不必手动编写每个方法,但这只是因为我们都在这里consenting adults。我不建议这样做,如果你只是覆盖每一个操作就更清楚了。也就是说,您可以编写一个类包装器来检查基类的所有方法,并将输出转换为复杂类型的点。你知道吗

import math
import inspect


def convert_to_base(cls):
    def decorate_func(name, method, base_cls):
        def method_wrapper(*args, **kwargs):
            obj = method(*args, **kwargs)
            return cls.convert(obj, base_cls) if isinstance(obj, base_cls) else obj
        return method_wrapper if name not in ('__init__', '__new__') else method
    for base_cls in cls.__bases__:
        for name, method in inspect.getmembers(base_cls, inspect.isroutine):  # Might want to change this filter
            setattr(cls, name, decorate_func(name, method, base_cls))
    return cls


@convert_to_base
class Point2d(complex):

    @classmethod
    def convert(cls, obj, base_cls):
        # Use base_cls if you need to know which base class to convert.
        return cls(obj.real, obj.imag)

    def distanceTo(self, otherPoint):
        return math.sqrt((self.real - otherPoint.real)**2 + (self.imag - otherPoint.imag)**2)

    def x(self):
        return self.real

    def y(self):
        return self.imag

p1 = Point2d(3, 3)
p2 = Point2d(6, 7)
p3 = p1 + p2
p4 = p3.distanceTo(p1)
print(p4)

# 9.219544457292887

这里发生的事情是,它只检查基类的所有方法,如果它返回的是基类的类型,则将其转换为子类,如子类中的特殊classmethod所定义的那样。你知道吗

相关问题 更多 >

    热门问题