有没有办法将quint8 pytorch格式转换为np.uint8格式?

2024-10-04 03:17:33 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用下面的代码来获取pytorch中的量化非IGEED int 8格式。但是,我无法将quant变量转换为to np.uint8。有可能做到这一点吗

import torch

quant = torch.quantize_per_tensor(torch.tensor([-1.0, 0.352, 1.321, 2.0]), 0.1, 10, torch.quint8)

Tags: to代码import格式nptorchpytorchint
1条回答
网友
1楼 · 发布于 2024-10-04 03:17:33

这可以使用torch.int_repr()实现

import torch
import numpy as np

# generate a test float32 tensor
float32_tensor = torch.tensor([-1.0, 0.352, 1.321, 2.0])
print(f'{float32_tensor.dtype}\n{float32_tensor}\n')

# convert to a quantized uint8 tensor. This format keeps the values in the range of
# the float32 format, with the resolution of a uint8 format (256 possible values)
quint8_tensor = torch.quantize_per_tensor(float32_tensor, 0.1, 10, torch.quint8)
print(f'{quint8_tensor.dtype}\n{quint8_tensor}\n')

# map the quantized data to the actual uint8 values (and then to an np array)
uint8_np_ndarray = torch.int_repr(quint8_tensor).numpy()
print(f'{uint8_np_ndarray.dtype}\n{uint8_np_ndarray}')

输出

torch.float32
tensor([-1.0000,  0.3520,  1.3210,  2.0000])

torch.quint8
tensor([-1.0000,  0.4000,  1.3000,  2.0000], size=(4,), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10)

uint8
[ 0 14 23 30]

相关问题 更多 >