pytorch动态配料
torchfold的Python项目详细描述
扭矩
博客文章:http://near.ai/articles/2017-09-06-PyTorch-Dynamic-Batching/
类似于TensorFlow Fold,使用超级简单的接口实现动态批处理。
用f.add('function name', arguments)
替换计算中对nn模块的每个直接调用。
它将构造计算的优化版本,并在f.apply
上动态批处理和执行给定nn模块上的计算。
安装
我们建议使用pip包管理器:
pip install torchfold
示例
f = torchfold.Fold()
def dfs(node):
if is_leaf(node):
return f.add('leaf', node)
else:
prev = f.add('init')
for child in children(node):
prev = f.add('child', prev, child)
return prev
class Model(nn.Module):
def __init__(self, ...):
...
def leaf(self, leaf):
...
def child(self, prev, child):
...
res = dfs(my_tree)
model = Model(...)
f.apply(model, [[res]])