pytorch中cuda张量的Bit运算

2024-07-08 10:57:37 发布

您现在位置:Python中文网/ 问答频道 /正文

我想在cuda中使用torch.tensor进行一些操作,例如<<&燃气轮机>;,或者提取表示浮点数的每个位,例如浮点数16中0.333的0 01101 01010101012(3555H)

我现在做的是:

def _decompose(self, value, exp_bias=None):
    '''
    decompose a single into sign, exp and mant
    '''
    if exp_bias is None:
        exp_bias = self.exp_bias
    # smallest non-zero float point
    descriminator = torch.tensor((2 ** (-exp_bias)) / 2).type_as(value)
    sign = (value > descriminator).type_as(value)
    sign -= (value < -descriminator).type_as(value)
    value = value.abs()
    exp = torch.log2(value).floor()
    mant = value / (2 ** exp)
    return sign, exp, mant

有没有办法实现这样的功能?或者我的代码有问题吗?谢谢


Tags: ltselfnonevalueastypetorchcuda

热门问题