沿较深维度更新张量值

2024-10-02 08:15:24 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个形状为MxNxC的张量A,其中M表示例子的数量,N表示特征的数量,C表示3个欧拉旋转角。同样我有一个类似形状的张量B,但是没有角度,而是坐标。 所需要的是将这两个张量都转换为一个包含仿射变换矩阵的张量,这样它的形状就像mxnx4x4x4。我不知道如何一起迭代这些张量,我已经寻找了tf.map_fntf.scan,但是它们只在第一维度上迭代。我所寻找的是一些方法来应用像下面的一个沿最后一个轴的所有元素的功能

def f(angles, vector): #dimensions 3 or 3x1
    ...
    return matrix # dimension 4x4

任何帮助都会有用的,谢谢


Tags: 方法功能元素map数量scantf矩阵
1条回答
网友
1楼 · 发布于 2024-10-02 08:15:24

您可以尝试以下方法:

A_flattened = tf.reshape(A, [-1, 3])# flatten it out
B_flattened = tf.reshape(B, [-1, 3])
AB_flattened = tf.map_fn(convert_to_mat, (A_flattened, B_flattened))# convert_to_mat should return a 4x4 matrix
AB = tf.reshape(AB_flattened, [M, N, 4, 4])

这应该会成功的

相关问题 更多 >

    热门问题