如何获取和设置XorOWRandomNumberGenerator状态以实现持久性?

2024-09-30 08:16:04 发布

您现在位置:Python中文网/ 问答频道 /正文

我的目标是保持一个复杂算法每次迭代的完整状态,这个算法还涉及通过pycuda生成的伪随机数。为了在仲裁迭代中恢复算法并确定地再现相同的结果,我需要类似于get_state() and set_state() from numpy.random.RandomState的东西

考虑到这一点:

from pycuda.curandom import XORWOWRandomNumberGenerator
gen = XORWOWRandomNumberGenerator()

如何获得gen的完整状态以便将其加载到numpy数组中

如何基于这些先前获得的numpy数组重现gen的完全相同的状态


Tags: andfrompycudanumpy算法目标get状态
1条回答
网友
1楼 · 发布于 2024-09-30 08:16:04

我没有找到现成的解决方案。因此,我从XORWOWRandomNumberGenerator派生了以下类:

from pycuda.curandom import XORWOWRandomNumberGenerator
import pycuda.driver as drv
import numpy


class PersistableXORWOWRandomNumberGenerator(XORWOWRandomNumberGenerator):
    def get_state_size(self):
        from pycuda.characterize import sizeof
        data_type_size = sizeof(self.state_type, "#include <curand_kernel.h>")
        return self.block_count * self.generators_per_block * data_type_size

    def get_state(self, ary=None):
        if ary is None:
            ary = drv.from_device(self.state, (self.get_state_size(),), numpy.uint8)
        else:
            drv.memcpy_dtoh(ary, self.state)
        return ary

    def set_state(self, state):
        drv.memcpy_htod(self.state, state)

可以这样使用:

gen = PersistableXORWOWRandomNumberGenerator()

# obtain the random state as a uint8 numpy array
state_as_numpy = gen.get_state()

# set the state from a uint8 numpy array
gen.set_state(state_as_numpy)

相关问题 更多 >

    热门问题