扫描时没有序列?(模拟范围())

2024-06-28 19:10:42 发布

您现在位置:Python中文网/ 问答频道 /正文

我是新手。我正在做一些实验来生成可变长度的序列。我从脑海中想到的最简单的事情开始:模拟range()。下面是我写的简单代码:

from theano import scan
from theano import function
from theano import tensor as  T

X = T.iscalar('X')
STEP = T.iscalar('STEP')
MAX_LENGTH = 1024  # or any othe very large value

def fstep(i, x, t):
    n = i * t
    return n, until(n >= x)

t_fwd_range, _ = scan(
    fn=fstep,
    sequences=T.arange(MAX_LENGTH),
    non_sequences=[X, STEP]
)

getRange = function(
    inputs=[X, Param(STEP, 1, 'step')],
    outputs=t_fwd_range
)

getRange(x, step)
print list(f)
assert list(f[:-1]) == list(range(0, x, step))

所以我必须使用MAX_LENGTH作为输入fstep的范围长度scan。所以,我的主要问题是:有没有办法在没有输入序列的情况下使用scan?而且,我认为答案是,下一个问题是:这是我要做的事情的正确(最有效的,ecc)方法吗?在


Tags: fromimportscansteprangefunction序列theano
1条回答
网友
1楼 · 发布于 2024-06-28 19:10:42

不需要提供扫描的输入序列。您可以通过scan的n_steps参数指定迭代次数。或者,您也可以通过theano.scan_module.until指定一个扫描应提前停止的条件。在

因此,Python的range函数可以使用Theano的scan来模拟,而不需要通过计算构造请求序列所需的迭代次数来指定输入序列。在

下面是基于ano的scan的range函数的实现。唯一复杂的部分是计算出需要多少步骤。在

import numpy
import theano
import theano.tensor as tt
import theano.ifelse


def scan_range_step(x_tm1, step):
    return x_tm1 + step


def compile_theano_range():
    tt.arange
    symbolic_start = tt.lscalar()
    symbolic_stop = tt.lscalar()
    symbolic_step = tt.lscalar()
    n_steps = tt.cast(
        tt.ceil(tt.abs_(symbolic_stop - symbolic_start) / tt.cast(tt.abs_(symbolic_step), theano.config.floatX)),
        'int64') - 1
    outputs, _ = theano.scan(scan_range_step, outputs_info=[symbolic_start], n_steps=n_steps,
                             non_sequences=[symbolic_step], strict=True)
    outputs = theano.ifelse.ifelse(tt.eq(n_steps, 0), tt.stack(symbolic_start), outputs)
    f = theano.function([symbolic_start, symbolic_stop, symbolic_step],
                        outputs=tt.concatenate([[symbolic_start], outputs]))

    def theano_range(start, stop=None, step=1):
        assert isinstance(start, int)
        assert isinstance(step, int)
        if step == 0:
            raise ValueError()
        if stop is None:
            stop = start
            start = 0
        else:
            assert isinstance(stop, int)
        if start == stop:
            return []
        if stop < start and step > 0:
            return []
        if stop > start and step < 0:
            return []
        return f(start, stop, step)

    return theano_range


def main():
    theano_range = compile_theano_range()
    python_range = range

    for start in [-10, -5, -1, 0, 1, 5, 10]:
        for stop in [-10, -5, -1, 0, 1, 5, 10]:
            for step in [-3, -2, -1, 1, 2, 3]:
                a = theano_range(start, stop, step)
                b = python_range(start, stop, step)
                assert numpy.all(numpy.equal(a, b)), (start, stop, step, a, b)


main()

显然,这是一件愚蠢的事情,因为Theano已经提供了Python的range函数的符号版本,即theano.tensor.arange。内置实现也比我们的scan版本高效得多,因为它不使用scan,而是使用自定义操作。在

经验法则是:您必须通过rangen_steps参数来设置迭代步骤的最大数目。您可以将其设置为一个非常大的数字,然后使用theano.scan_module.until在满足停止条件的早期阶段停止迭代。在

相关问题 更多 >