在流畅的界面中构建pytorch模型

torchluent的Python项目详细描述


pytorch fluent模型

为创建pytorch模型提供流畅界面的小软件包。

摘要

一个流畅的接口大致上就是一个链接方法调用的接口。阅读更多关于 流畅的界面here

这个库允许密集层,卷积层,最大池, 以及非线性或其他运算符(即标准化)。这个计算 每层后的新形状,意味着您不必重复 指定功能。

考虑以下纯pytorch代码:

importtorch.nnasnnnet=nn.Sequential(nn.Linear(28*28,128),nn.Linear(128,10))

第二层(128)的输入必须始终与第一层的输出匹配 层。这种冗余非常小,但可以改进。问题变成 当你考虑卷积层时更明显。

此外,官方的pytorch库不包括一些常见的胶水 扩展顺序块的代码。一个可能的原因是 fluent的api不太可能像传统的api那样详尽 不管怎样,我们通常不得不依赖于更详细的模块定义。

最后,它有非常多功能的thenthen_with,其中 工作在转置卷积层和未冷却的同时仍然避免 冗余层大小或通道号。

API参考

https://tjstretchalot.github.io/torchluent/

用法

使用输入的形状创建torchluent.FluentModule的实例。 fluentmodule上有一些元函数,比如 将打印形状如何通过渐进调用更改。对于那些 更改可以在一般意义上调用.transform的功能数 或者使用提供的函数之一,例如.dense,它将计算 新功能的数量。对于不改变数据形状的图层, 而不是为每个函数都包含一个函数,您可以使用.operator 接受torch.nn中属性的名称以及参数或 关键字参数。

安装

pip install torchluent

示例

fromtorchluentimportFluentModuleprint('Network:')net=(FluentModule((1,28,28)).verbose().conv2d(32,kernel_size=5).maxpool2d(kernel_size=3).operator('LeakyReLU',negative_slope=0.05).flatten().dense(128).operator('ReLU').dense(10).operator('ReLU').build())print(net)

产生:

Network:
  (1, 28, 28)
  Conv2d -> (32, 24, 24)
  MaxPool2d -> (32, 8, 8)
  LeakyReLU
  Reshape -> (2048,)
  Linear -> (128,)
  ReLU
  Linear -> (10,)
  ReLU

Sequential(
  (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (1): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (2): LeakyReLU(negative_slope=0.05)
  (3): Reshape(2048)
  (4): Linear(in_features=2048, out_features=128, bias=True)
  (5): ReLU()
  (6): Linear(in_features=128, out_features=10, bias=True)
  (7): ReLU()
)

包装和拆封

默认情况下,pytorch中没有一个概念是考虑隐藏的 以抽象的方式描述任意网络的状态。这个想法是 基本上,如果模块除了返回 转换后的输出,其中返回数组中的每个元素都是快照 当输入通过网络传播时。

下面是一个虚构的示例,演示了这样一个模块可能 看起来像:

importtorch.nnasnnclassHiddenStateModule(nn.Module):defforward(self,x):result=[]result.append(x)# initial state always therex=x**2result.append(x)# where relevantx=x*3+2x=torch.relu(x)result.append(x)returnx,result

此模块意味着不必修改 底层转换(即nn.Linear)也不能强制回退 为这种非常常见的情况创建自定义模块。

但是,这类模块出现的另一个问题是 如果只需要一个输出,结果将破坏大部分代码库。这个 当与一些抽象的训练范例相结合时,比如 火把点燃。幸运的是,很容易从 这样的一个模块,好像由下面的

importtorch.nnasnnclassStrippedStateModule(nn.Module):def__init__(self,mod):super().__init__()self.mod=moddefforward(self,x):returnself.mod(x)[0]

通过在主实现中包含数组,然后使用 “拆封”模块你可以得到最好的两个世界。用于培训和 不需要隐藏状态的通用用法,请使用剥离版本。 对于需要隐藏状态的分析,使用预剥离版本。

考虑到这个上下文,下面的代码片段将生成 网络的已包装和未包装版本:

fromtorchluentimportFluentModuleprint('Network:')net,stripped_net=(FluentModule((28*28,)).verbose().wrap(with_input=True)# create array and initialize with input.dense(128).operator('ReLU').save_state()# pushes to the array.dense(128).operator('ReLU').save_state().dense(10).operator('ReLU').save_state().build(with_stripped=True))print()print(net)

产生

Network:
  (784,)
  Linear -> (128,)
  ReLU
  Linear -> (128,)
  ReLU
  Linear -> (10,)
  ReLU

Sequential(
  (0): InitListModule(include_first=True)
  (1): WrapModule(
    (child): Linear(in_features=784, out_features=128, bias=True)
  )
  (2): WrapModule(
    (child): ReLU()
  )
  (3): SaveStateModule()
  (4): WrapModule(
    (child): Linear(in_features=128, out_features=128, bias=True)
  )
  (5): WrapModule(
    (child): ReLU()
  )
  (6): SaveStateModule()
  (7): WrapModule(
    (child): Linear(in_features=128, out_features=10, bias=True)
  )
  (8): WrapModule(
    (child): ReLU()
  )
  (9): SaveStateModule()
)

限制

对于非平凡的网络,可能会有大量使用then 以及then_with函数,它们不如所示示例好 但我相信这仍然是一个重大的进步。

欢迎加入QQ群-->: 979659372 Python中文网_新手群

推荐PyPI第三方库


热门话题
java我是否正确地实现了广告,为什么没有显示?   java Maven编译器插件与Maven默认插件?   java如何通过在Hibernate中引入二级缓存来解决N+1问题?   java如何在Android中绘制位图   java再次关闭这个“FileOutputStream”声纳   Java Android Studio应用程序开发NullPointerException。我的应用程序强制在加载第二个活动时关闭   java无法使用与postman应用程序中相同的restTemplate发送请求   java为计算器拆分输入字符串   java底部导航栏未显示在活动中   java XML读取具有不同段的相同标记   java从文本文件中添加值   java将外部JAR与插件库目录分离   spock框架中的java高级助手方法   azure ADAL for Java Proxy   java如何使用Apache httpclient 4为每个请求设置超时。*使用PoollightTPClientConnectionManager?