无法广播ragged_秩=2的tf.RaggedSensor

2024-09-29 22:20:47 发布

您现在位置:Python中文网/ 问答频道 /正文

我试着将一些参差张量的值分别与另一个张量的每个值相加,得到另一个具有一个维度的参差张量。但我发现了一个广播错误:

Unable to broadcast: dimension size mismatch in dimension=[2]. lengths=[3] dim_size=[1, 1, 1, 1, 1]

因此,这个错误发生了,而不是像普通tf广播规则所说的那样,沿着大小为1的参差不齐维度进行广播

最简单的例子如下:

这是有效的(与ragged_rank=1

import tensorflow as tf
x = tf.ragged.constant(
        [
            [[1,2,3,4], [2,5,7,8]],
            [[3,12,8,9],[0,0,1,1],[4,4,9,7]],
        ],
    ragged_rank=1)
y = tf.constant([10, 20, 30])
x = x[:,:,tf.newaxis,:]
y = y[:,tf.newaxis]
print(x+y)

但这不是(用ragged_rank=2

import tensorflow as tf
x = tf.ragged.constant(
        [
            [[1,2,3,4], [2,5,7,8]],
            [[3,12,8,9],[0,0,1,1],[4,4,9,7]],
        ],
    ragged_rank=2)
y = tf.constant([10, 20, 30])
x = x[:,:,tf.newaxis,:]
y = y[:,tf.newaxis]
print(x+y)

我必须处理ragged_rank=2r.t,因为我将它作为我的tf.data.Dataset批处理管道的map函数的输入。 我还希望有一个将ragged_rank减少为1的解决方案(如果可能的话),因为内部维度应该是统一的(长度为4)

UPD:好的,我通过重新创建一个张量,成功地降低了参差不齐的_等级,如下所示:

def downgrade_bbox_batch_ragged_rank(x, inner_len=5):
    v = x.flat_values
    rs = x.row_splits
    return tf.RaggedTensor.from_row_splits(values=tf.reshape(v,(-1,inner_len)),
                                           row_splits=rs)

在此之后,就广播而言,在内部维度(平面_值)之前或中添加新轴效果良好。但在参差不齐的维度之前添加newaxis仍然不起作用。这是一种预期的行为吗


Tags: importsizetftensorflowas错误rowprint

热门问题