如何将perreplica转换为张量?

2024-10-06 10:21:17 发布

您现在位置:Python中文网/ 问答频道 /正文

在tensorflow2.0中使用多gpu进行培训时,perreplica将减少以下代码:

strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

但是,如果我只想收集(没有'sum reduce'或'mean reduce')所有gpu的预测到一个张量:

^{pr2}$

Tags: 代码nonereducegputfdistributesumstrategy
1条回答
网友
1楼 · 发布于 2024-10-06 10:21:17

简而言之,您可以将PerReplicaresult转换为如下张量的元组:

tensors_tuple = per_replica_predicitions.values

返回的tensors_tuple将是来自每个副本/设备的predictions的元组:

^{pr2}$

这个元组中元素的数量由分布式策略可用的设备决定。特别是,如果策略在单个副本/设备上运行,则策略。实验性的?运行?2与直接调用train_step函数相同(张量或张量列表由您的train_step决定)。所以您可能需要这样编写代码:

per_replica_losses, per_replica_predicitions = strategy.experimental_run_v2(train_step, args=(dataset_inputs,))

if strategy.num_replicas_in_sync > 1:
    predicition_tensors = per_replica_predicitions.values
else:
    predicition_tensors = per_replica_predicitions

PerReplica是一个封装分布式运行结果的类对象。您可以找到它的定义here,还有更多的属性/方法供我们操作PerReplica对象。在

相关问题 更多 >