在附加维度上重复keras(tensorflow)模型

2024-05-19 07:06:58 发布

您现在位置:Python中文网/ 问答频道 /正文

假设我有一个模型,它把一个形状为[n,10]的张量映射到一个形状为[n,2]的张量,其中n是批量大小。如何重复该模型,以便生成的模型接受形状为[n,k,10]的输入张量并输出形状为[n,k,2]的张量?模型的k个版本应该共享所有权重


Tags: 模型版本批量形状所有权
1条回答
网友
1楼 · 发布于 2024-05-19 07:06:58

你可以这样做:

input_ = Input((k, model.input.shape[1]))
input_as_list = Lambda(lambda x: tf.unstack(x, axis=1))(input_)
model_outputs = [model(x) for x in input_as_list] 
model_outputs = [Lambda(lambda x: K.expand_dims(x, axis=1))(y) for y in model_outputs]
concat_output = Concatenate(axis=1)(model_outputs)
new_model = Model(input_, concat_output)

相关问题 更多 >

    热门问题