在流畅的界面中构建pytorch模型
torchluent的Python项目详细描述
pytorch fluent模型
为创建pytorch模型提供流畅界面的小软件包。
摘要
一个流畅的接口大致上就是一个链接方法调用的接口。阅读更多关于 流畅的界面here。
这个库允许密集层,卷积层,最大池, 以及非线性或其他运算符(即标准化)。这个计算 每层后的新形状,意味着您不必重复 指定功能。
考虑以下纯pytorch代码:
importtorch.nnasnnnet=nn.Sequential(nn.Linear(28*28,128),nn.Linear(128,10))
第二层(128)的输入必须始终与第一层的输出匹配 层。这种冗余非常小,但可以改进。问题变成 当你考虑卷积层时更明显。
此外,官方的pytorch库不包括一些常见的胶水 扩展顺序块的代码。一个可能的原因是 fluent的api不太可能像传统的api那样详尽 不管怎样,我们通常不得不依赖于更详细的模块定义。
最后,它有非常多功能的then
和then_with
,其中
工作在转置卷积层和未冷却的同时仍然避免
冗余层大小或通道号。
API参考
https://tjstretchalot.github.io/torchluent/
用法
使用输入的形状创建torchluent.FluentModule
的实例。
fluentmodule上有一些元函数,比如
将打印形状如何通过渐进调用更改。对于那些
更改可以在一般意义上调用.transform
的功能数
或者使用提供的函数之一,例如.dense
,它将计算
新功能的数量。对于不改变数据形状的图层,
而不是为每个函数都包含一个函数,您可以使用.operator
接受torch.nn
中属性的名称以及参数或
关键字参数。
安装
pip install torchluent
示例
fromtorchluentimportFluentModuleprint('Network:')net=(FluentModule((1,28,28)).verbose().conv2d(32,kernel_size=5).maxpool2d(kernel_size=3).operator('LeakyReLU',negative_slope=0.05).flatten().dense(128).operator('ReLU').dense(10).operator('ReLU').build())print(net)
产生:
Network:
(1, 28, 28)
Conv2d -> (32, 24, 24)
MaxPool2d -> (32, 8, 8)
LeakyReLU
Reshape -> (2048,)
Linear -> (128,)
ReLU
Linear -> (10,)
ReLU
Sequential(
(0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
(1): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
(2): LeakyReLU(negative_slope=0.05)
(3): Reshape(2048)
(4): Linear(in_features=2048, out_features=128, bias=True)
(5): ReLU()
(6): Linear(in_features=128, out_features=10, bias=True)
(7): ReLU()
)
包装和拆封
默认情况下,pytorch中没有一个概念是考虑隐藏的 以抽象的方式描述任意网络的状态。这个想法是 基本上,如果模块除了返回 转换后的输出,其中返回数组中的每个元素都是快照 当输入通过网络传播时。
下面是一个虚构的示例,演示了这样一个模块可能 看起来像:
importtorch.nnasnnclassHiddenStateModule(nn.Module):defforward(self,x):result=[]result.append(x)# initial state always therex=x**2result.append(x)# where relevantx=x*3+2x=torch.relu(x)result.append(x)returnx,result
此模块意味着不必修改
底层转换(即nn.Linear
)也不能强制回退
为这种非常常见的情况创建自定义模块。
但是,这类模块出现的另一个问题是 如果只需要一个输出,结果将破坏大部分代码库。这个 当与一些抽象的训练范例相结合时,比如 火把点燃。幸运的是,很容易从 这样的一个模块,好像由下面的
importtorch.nnasnnclassStrippedStateModule(nn.Module):def__init__(self,mod):super().__init__()self.mod=moddefforward(self,x):returnself.mod(x)[0]
通过在主实现中包含数组,然后使用 “拆封”模块你可以得到最好的两个世界。用于培训和 不需要隐藏状态的通用用法,请使用剥离版本。 对于需要隐藏状态的分析,使用预剥离版本。
考虑到这个上下文,下面的代码片段将生成 网络的已包装和未包装版本:
fromtorchluentimportFluentModuleprint('Network:')net,stripped_net=(FluentModule((28*28,)).verbose().wrap(with_input=True)# create array and initialize with input.dense(128).operator('ReLU').save_state()# pushes to the array.dense(128).operator('ReLU').save_state().dense(10).operator('ReLU').save_state().build(with_stripped=True))print()print(net)
产生
Network:
(784,)
Linear -> (128,)
ReLU
Linear -> (128,)
ReLU
Linear -> (10,)
ReLU
Sequential(
(0): InitListModule(include_first=True)
(1): WrapModule(
(child): Linear(in_features=784, out_features=128, bias=True)
)
(2): WrapModule(
(child): ReLU()
)
(3): SaveStateModule()
(4): WrapModule(
(child): Linear(in_features=128, out_features=128, bias=True)
)
(5): WrapModule(
(child): ReLU()
)
(6): SaveStateModule()
(7): WrapModule(
(child): Linear(in_features=128, out_features=10, bias=True)
)
(8): WrapModule(
(child): ReLU()
)
(9): SaveStateModule()
)
限制
对于非平凡的网络,可能会有大量使用then
以及then_with
函数,它们不如所示示例好
但我相信这仍然是一个重大的进步。