如何在TF2.x中的参差不齐张量上运行BatchNorm?

2024-09-29 21:51:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图在TF2.x中对一批参差不齐的张量运行BatchNormalization,但在执行此操作时似乎遇到了错误。(我可以在BatchNorm forward调用前后对不规则的张量进行转换,但我无法在NonEager模式下运行to_tensor(),这是我有效训练网络的必要条件)

Pytorch有一个BatchNorm1D,但TF似乎没有任何这样的API,任何建议/指针都会有帮助


Tags: to网络apitf错误模式pytorch建议
1条回答
网友
1楼 · 发布于 2024-09-29 21:51:41

在keras上创建一个自定义层,将粗糙的张量转换为张量,然后将其注入到BatchNormalization层中

我真的试过了。这是我所做的,有一个问题,但我没有时间来解决它。这可能对你有帮助,但也许没有

在下面的代码中,我创建了一个超级简单的“to_tensor”层,可以在规范化之前使用它

这算是可行的,但由于我在to.tensor()行中创建了一个新的张量,tf再也找不到任何可训练的变量了

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *

def to_tensor(x):
    return x.to_tensor(shape=(None, 4))

X = tf.ragged.constant(
    [[3, 1, 4, 1], [], [5, 9, 2], [],[6]])

y = tf.random.normal(shape=(5,1))

inp = Input(shape=(None,), ragged=True)
x = Lambda(to_tensor)(inp)
out = Dense(1)(x)

m = Model(inp,out)

m.compile(optimizer='adam',metrics=['accuracy'])
history = m.fit(X, y, epochs=10)

相关问题 更多 >

    热门问题