Cannot initialise superclass from a Ray actor process

Posted 2024-10-01 00:23:01


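I am trying to parallelise a sampling process, so I created a sampler object. The sampler depends on two large datasets (stored as numpy arrays), which are arguments to its constructor. To avoid duplicate copies in the object store, my idea is to first add the arrays to the object store with ray.put and then initialise the sampler objects with the corresponding IDs.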

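Also, I do not want to add the decorator to the Sampler class itself. Instead, I created a subclass of Sampler, RemoteSampler, which decorates the superclass's methods and modifies them by adding .remote() calls. However, I cannot seem to initialise the superclass from the ActorClass. I get a TypeError: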

TypeError: super() argument 1 must be type, not ActorClass(RemoteSampler).

The skeleton code is as follows:

import ray
import numpy as np

class Sampler(object):

    def __init__(self, train_data, d_train_data, *others):

        # these can be big, so we want to have only one copy that
        # multiple actors share
        if isinstance(train_data, np.ndarray):
            self.train_data = train_data
        else:
            self.train_data = ray.get(train_data)

        if isinstance(d_train_data, np.ndarray):
            self.d_train_data = d_train_data
        else:
            self.d_train_data = ray.get(d_train_data)

        # Initialise the rest of the sampler state
        self.d1 = {}
        self.d2 = {}

    def __call__(self, features, n_samples):

        a, b, c = self._sampling_loop(features, n_samples)

        # process a, b, c and return something

        return a, b, c, features

    def build_lookups(self, X):
        self.d1 = {0: X[0]}
        self.d2 = {1: X[1]}
        return self.d1, self.d2

    def _sampling_loop(self, features, n_samples):
        # Use train_data, d_train_data and other attributes to return some data to __call__
        return 0, 0, 0

@ray.remote
class RemoteSampler(Sampler):

    def __init__(self, *args):
        super(RemoteSampler, self).__init__(*args)
        self.__call__ = ray.method(self.__call__, num_return_vals=4)
        self.build_lookups = ray.method(self.build_lookups, num_return_vals=2)

    def __call__(self, anchor, num_samples):
        return self.__call__(anchor, num_samples).remote()

    def build_lookups(self, X):
        a, b, c = self.build_lookups.remote(X)
        return a, b, c


def _fit_parallel(*args):
    # method of a class where the RemoteSampler objects are initialised
    # copy large objects to object store
    train_data, d_train_data, *others = args
    train_data_id = ray.put(train_data)
    d_train_data_id = ray.put(d_train_data)
    n_args = (train_data_id, d_train_data_id, *others)
    return [RemoteSampler.remote(*n_args) for _ in range(4)]

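The TypeError above can be reproduced without Ray. The sketch below uses a hypothetical decorator, fake_remote, and a hypothetical wrapper, FakeActorClass, as stand-ins for @ray.remote and ActorClass; they are not Ray's implementation. It only illustrates the general Python behaviour: a class decorator rebinds the class name to whatever it returns, so super(RemoteSampler, self) receives the wrapper instead of a type. The zero-argument super() does not look the name up and is unaffected. That Ray behaves the same way in this respect is an assumption, although the error message in the question is consistent with it.

```python
class FakeActorClass:
    """Hypothetical stand-in for the wrapper object a class decorator returns."""

    def __init__(self, cls):
        self._cls = cls

    def remote(self, *args, **kwargs):
        # Simply instantiate the wrapped class; a real actor would run remotely.
        return self._cls(*args, **kwargs)


def fake_remote(cls):
    """Hypothetical decorator: replaces the class with a wrapper instance."""
    return FakeActorClass(cls)


class Base:
    def __init__(self, value):
        self.value = value


@fake_remote
class ExplicitSuper(Base):
    def __init__(self, value):
        # At call time the name ExplicitSuper refers to the FakeActorClass
        # wrapper, not to a type, so this super() call raises TypeError.
        super(ExplicitSuper, self).__init__(value)


@fake_remote
class ZeroArgSuper(Base):
    def __init__(self, value):
        # Zero-argument super() uses the class being defined (via the
        # implicit __class__ cell), so the rebinding does not matter.
        super().__init__(value)


try:
    ExplicitSuper.remote(1)
except TypeError as exc:
    print("explicit super failed:", exc)

print("zero-argument super worked:", ZeroArgSuper.remote(2).value)
```

Under the same assumption, calling Sampler.__init__(self, *args) directly would avoid the lookup of the rebound name as well.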