Unwanted wrapping of results from operations on numpy array subclasses

Posted 2024-06-30 15:47:36


I'm trying to understand the details of Subclassing ndarray, but it is very dense. (Second possibility: I am dense.)

Here is a minimal example:

import numpy as np

class MyArray(np.ndarray):
    def __new__(cls, input_array, foo='foo'):
        self = input_array.view(cls)
        self.foo = foo
        return self

    def __array_finalize__(self, from_array):
        if from_array is not None:
            self.foo = getattr(from_array, 'foo', 'foo')

Here is a demonstration of the unwanted behavior:

^{pr2}$

How can I prevent functions or methods that should return a single number from wrapping their result in a MyArray?
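
For concreteness, here is a minimal sketch of the kind of wrapping being asked about. It reuses the MyArray class from the question; the 3x3 input and the foo value 'bar' are illustrative assumptions, not the original demonstration data.

```python
import numpy as np

class MyArray(np.ndarray):
    def __new__(cls, input_array, foo='foo'):
        self = input_array.view(cls)
        self.foo = foo
        return self

    def __array_finalize__(self, from_array):
        if from_array is not None:
            self.foo = getattr(from_array, 'foo', 'foo')

# Illustrative input (an assumption, not the original demo data).
a = MyArray(np.arange(9).reshape(3, 3), foo='bar')

total = a.sum()
# The reduction comes back as a 0-d MyArray rather than a plain number.
print(type(total), total.shape)
```

A caller who only wants the number has to unwrap the result by hand, which is what the question is trying to avoid.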


3 answers

I'm not an expert on classes, but using the code from https://docs.scipy.org/doc/numpy/user/basics.subclassing.html:

import numpy as np

class FooArray(np.ndarray):

    def __new__(subtype, shape, dtype=float, buffer=None, offset=0,
                strides=None, order=None, foo='Foo'):
        # Create the ndarray instance of our type, given the usual
        # ndarray input arguments.  This will call the standard
        # ndarray constructor, but return an object of our type.
        # It also triggers a call to FooArray.__array_finalize__
        obj = super(FooArray, subtype).__new__(subtype, shape, dtype,
                                                buffer, offset, strides,
                                                order)
        # set the new 'foo' attribute to the value passed
        obj.foo = foo
        # Finally, we must return the newly created object:
        return obj

    def __array_finalize__(self, obj):
        if obj is None: return
        self.foo = getattr(obj, 'foo', None)

arr = np.arange(9).reshape(3,3)
fa = arr.view(FooArray)
print(fa.sum())

Edit: after looking further at the page mentioned above and playing around some more, I came up with:

^{pr2}$

Edit 2: fixed the last line of __array_finalize__

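This is hard. I can only think of 2 solutions, and I don't like either of them. The first comes from the numpy documentation on subclassing ndarray: use __array_wrap__. The second is safer: override sum and whatever other functions you need.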

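Note:

The code below needs polishing, mainly the __array_wrap__ solution. Still, __array_wrap__ looks like the right direction, because it also works for other functions (*.mean() etc.). But as you said, the documentation is very dense, and it is hard to tell what the consequences of hacking these magic functions will be. Overriding, on the other hand, is simple and straightforward, but who wants to override every existing function?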

import numpy as np

class MyArray(np.ndarray):
    def __new__(cls, input_array, foo='foo'):
        self = input_array.view(cls)
        self.foo = foo
        return self

    # First solution: define __array_wrap__, which NumPy calls on results (check the documentation).
    # return_scalar is passed by NumPy 2.x; it is accepted here and ignored.
    def __array_wrap__(self, out_arr, context=None, return_scalar=False):
        if not out_arr.shape:
            # 0-d result: pull out the single element as a NumPy scalar.
            # Check the documentation for more. This is just an example
            # and it definitely needs more polishing.
            return out_arr.reshape(-1)[0]
        else:
            # This part is the default behaviour, so it should probably be kept, but that's up to you.
            # It seems to simply pass the result on to __array_finalize__.
            return super().__array_wrap__(out_arr, context)

    def __array_finalize__(self, from_array):

        if from_array is not None:
            self.foo = getattr(from_array, 'foo', 'foo')

    # Use only if you know what you are doing
    ## Second solution: override the method, check the result, act accordingly
    #def sum(self, *args, **kwargs):
    #    result = super().sum(*args, **kwargs)
    #    if not result.shape:
    #        return result.reshape(-1)[0]
    #    else:
    #        return result

a = MyArray(np.arange(9).reshape(3,3))

print(a.foo)
print(a.sum())
print(a.mean().__class__)
print(a.sum().__class__)
print(np.mean(a).__class__)

>>> foo
>>> 36
>>> <class 'numpy.float64'>
>>> <class 'numpy.int64'>
>>> <class 'numpy.float64'>
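
On the question of who wants to override every function by hand: one way to cut the repetition, sketched here under assumptions, is to generate the overrides in a loop. The helper names `_unwrap_0d` and `_make_reduction` and the chosen list of method names are hypothetical, introduced only for illustration; this is not part of numpy and has not been checked against every reduction.

```python
import numpy as np

class MyArray(np.ndarray):
    def __new__(cls, input_array, foo='foo'):
        self = input_array.view(cls)
        self.foo = foo
        return self

    def __array_finalize__(self, from_array):
        if from_array is not None:
            self.foo = getattr(from_array, 'foo', 'foo')

def _unwrap_0d(result):
    # Hypothetical helper: turn a 0-d array into a NumPy scalar,
    # and leave arrays with one or more dimensions untouched.
    if isinstance(result, np.ndarray) and result.ndim == 0:
        return np.asarray(result)[()]
    return result

def _make_reduction(name):
    # Hypothetical helper: wrap the base ndarray method of the given name.
    base = getattr(np.ndarray, name)

    def method(self, *args, **kwargs):
        return _unwrap_0d(base(self, *args, **kwargs))

    method.__name__ = name
    return method

# The list of methods is an illustrative choice; extend it as needed.
for _name in ("sum", "mean", "min", "max", "prod"):
    setattr(MyArray, _name, _make_reduction(_name))

a = MyArray(np.arange(9).reshape(3, 3))
print(type(a.sum()), type(a.mean()), type(a.sum(axis=0)))
```

Reductions along an axis still return a MyArray; only the 0-d case is unwrapped. Any method missing from the list keeps the default wrapping.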

Great question. I'll do my best to answer, although I can't be 100% sure that this solution has no side effects.

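You have to define the __array_ufunc__ method. By modifying the code from https://docs.scipy.org/doc/numpy/user/basics.subclassing.html#array-ufunc-for-ufuncs, I obtained the expected results (which I believe are: scalar results are not wrapped as MyArray, right?)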

import numpy as np

class MyArray(np.ndarray):
    def __new__(cls, input_array, foo='foo'):
        self = input_array.view(cls)
        self.foo = foo
        return self

    def __array_finalize__(self, from_array):
        if from_array is not None:
            self.foo = getattr(from_array, 'foo', 'foo')

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        args = []
        in_no = []
        for i, input_ in enumerate(inputs):
            if isinstance(input_, MyArray):
                in_no.append(i)
                args.append(input_.view(np.ndarray))
            else:
                args.append(input_)

        outputs = kwargs.pop('out', None)
        out_no = []
        if outputs:
            out_args = []
            for j, output in enumerate(outputs):
                if isinstance(output, MyArray):
                    out_no.append(j)
                    out_args.append(output.view(np.ndarray))
                else:
                    out_args.append(output)
            kwargs['out'] = tuple(out_args)
        else:
            outputs = (None,) * ufunc.nout

        results = super(MyArray, self).__array_ufunc__(ufunc, method,
                                                 *args, **kwargs)
        if results is NotImplemented:
            return NotImplemented

        if method == 'at':
            return

        if ufunc.nout == 1:
            results = (results,) 

        results = tuple((np.asarray(result).view(MyArray)
                         if isinstance(result, np.ndarray) else result)
                         #if output is None else output)
                        for result, output in zip(results, outputs))

        return results[0] if len(results) == 1 else results

This results in:

^{pr2}$
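
As a rough way to check this behaviour, here is a condensed sketch of the same idea: strip the subclass from the inputs, run the ufunc on plain arrays, and re-wrap only the results that are still arrays. It is a shortened variant written for illustration, not the exact code above, and the nested helper names `strip` and `rewrap` are hypothetical.

```python
import numpy as np

class MyArray(np.ndarray):
    def __new__(cls, input_array, foo='foo'):
        self = input_array.view(cls)
        self.foo = foo
        return self

    def __array_finalize__(self, from_array):
        if from_array is not None:
            self.foo = getattr(from_array, 'foo', 'foo')

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        def strip(x):
            # View MyArray instances as plain ndarrays; leave everything else alone.
            return x.view(np.ndarray) if isinstance(x, MyArray) else x

        inputs = tuple(strip(x) for x in inputs)
        if kwargs.get('out') is not None:
            kwargs['out'] = tuple(strip(o) for o in kwargs['out'])

        results = super().__array_ufunc__(ufunc, method, *inputs, **kwargs)
        if results is NotImplemented:
            return NotImplemented
        if method == 'at':
            return None

        def rewrap(r):
            # Only array results are re-wrapped; scalars pass through untouched.
            return r.view(MyArray) if isinstance(r, np.ndarray) else r

        if isinstance(results, tuple):
            return tuple(rewrap(r) for r in results)
        return rewrap(results)

a = MyArray(np.arange(9).reshape(3, 3))
print(type(a.sum()))
print(type(a + 1))
print(type(a.sum(axis=0)))
```

One thing to watch in this sketch: array results are re-viewed from plain ndarrays, so __array_finalize__ finds no foo on its source and falls back to the default 'foo'. If foo has to survive arithmetic, it would need to be copied over explicitly.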
