pytorch框架下卷积网络的flops计数器

ptflops的Python项目详细描述


pytorch框架中卷积网络的flops计数器

Pypi version

此脚本用于计算乘法加法运算的理论量 在卷积神经网络中。它还可以计算参数的数量和 给定网络的每层打印计算成本。

支持的层:

  • conv1d/2d/3d(包括分组)
  • convTranspasse2d(包括分组)
  • 批次标准1d/2d/3d
  • 激活(relu、prelu、elu、relu6、leakyrelu)
  • 线性
  • 向上采样
  • 池(avgpool1d/2d/3d、maxpool1d/2d/3d和自适应池)

要求:Pythorch>;=0.4.1,TorchVision>;=0.2.1

感谢@warmspringwinds提供了脚本的初始版本。

使用技巧

  • 此脚本不考虑torch.nn.functional.*操作。例如,如果一个人有一个语义分割模型,并且使用torch.nn.functional.interpolate来提升特性,那么这些操作将不会贡献总的失败次数。为了避免这种情况,可以使用torch.nn.Upsample,而不是torch.nn.functional.interpolate
  • ptflops在一个随机张量上启动一个给定的模型,并在推理过程中估计计算量。复杂的模型可以有几个输入,其中一些可以是可选的。要构造非平凡的输入,可以使用get_model_complexity_infoinput_constructor参数。input_constructor是一个函数,它将输入空间分辨率作为元组,并返回带有模型的命名输入参数的dict。接下来,这个dict将作为关键字参数传递给模型。

安装最新版本

pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git

示例

importtorchvision.modelsasmodelsimporttorchfromptflopsimportget_model_complexity_infowithtorch.cuda.device(0):net=models.densenet161()flops,params=get_model_complexity_info(net,(3,224,224),as_strings=True,print_per_layer_stat=True)print('Flops:  '+flops)print('Params: '+params)

基准

torchvision

ModelInput ResolutionParams(M)MACs(G)Top-1 errorTop-5 error
alexnet224x22461.10.7243.4520.91
vgg11224x224132.867.6330.9811.37
vgg13224x224133.0511.3430.0710.75
vgg16224x224138.3615.528.419.62
vgg19224x224143.6719.6727.629.12
vgg11_bn224x224132.877.6429.6210.19
vgg13_bn224x224133.0511.3628.459.63
vgg16_bn224x224138.3715.5326.638.50
vgg19_bn224x224143.6819.725.768.15
resnet18224x22411.691.8230.2410.92
resnet34224x22421.83.6826.708.58
resnet50224x22425.564.1223.857.13
resnet101224x22444.557.8522.636.44
resnet152224x22460.1911.5821.695.94
squeezenet1_0224x2241.250.8341.9019.58
squeezenet1_1224x2241.240.3641.8119.38
densenet121224x2247.982.8825.357.83
densenet169224x22414.153.4224.007.00
densenet201224x22420.014.3722.806.43
densenet161224x22428.687.8222.356.20
inception_v3224x22427.162.8522.556.44
  • Top-1错误-ImageNet单裁剪Top-1错误(224x224)
  • Top-5错误-ImageNet单次裁剪Top-5错误(224x224)

Cadene/pretrained-models.pytorch

ModelInput ResolutionParams(M)MACs(G)Acc@1Acc@5
alexnet224x22461.10.7256.43279.194
bninception224x22411.32.0573.52491.562
cafferesnet101224x22444.557.6276.292.766
densenet121224x2247.982.8874.64692.136
densenet161224x22428.687.8277.5693.798
densenet169224x22414.153.4276.02692.992
densenet201224x22420.014.3777.15293.548
dpn107224x22486.9218.4279.74694.684
dpn131224x22479.2516.1379.43294.574
dpn68224x22412.612.3675.86892.774
dpn68b224x22412.612.3677.03493.59
dpn92224x22437.676.5679.494.62
dpn98224x22461.5711.7679.22494.488
fbresnet152224x22460.2711.677.38693.594
inceptionresnetv2299x29955.8413.2280.1795.234
inceptionv3299x29927.165.7377.29493.454
inceptionv4299x29942.6812.3180.06294.926
nasnetalarge331x33188.7524.0482.56696.086
nasnetamobile224x2245.290.5974.0891.74
pnasnet5large331x33186.0625.2182.73695.992
polynet331x33195.3734.981.00295.624
resnet101224x22444.557.8577.43893.672
resnet152224x22460.1911.5878.42894.11
resnet18224x22411.691.8270.14289.274
resnet34224x22421.83.6873.55491.456
resnet50224x22425.564.1276.00292.98
resnext101_32x4d224x22444.188.0378.18893.886
resnext101_64x4d224x22483.4615.5578.95694.252
se_resnet101224x22449.337.6378.39694.258
se_resnet152224x22466.8211.3778.65894.374
se_resnet50224x22428.093.977.63693.752
se_resnext101_32x4d224x22448.968.0580.23695.028
se_resnext50_32x4d224x22427.564.2879.07694.434
senet154224x224115.0920.8281.30495.498
squeezenet1_0224x2241.250.8358.10880.428
squeezenet1_1224x2241.240.3658.2580.8
vgg11224x224132.867.6368.9788.746
vgg11_bn224x224132.877.6470.45289.818
vgg13224x224133.0511.3469.66289.264
vgg13_bn224x224133.0511.3671.50890.494
vgg16224x224138.3615.571.63690.354
vgg16_bn224x224138.3715.5373.51891.608
vgg19224x224143.6719.6772.0890.822
vgg19_bn224x224143.6819.774.26692.066
xception299x29922.868.4278.88894.292
  • acc@1-imagenet在训练过程中使用的相同大小的验证图像上的单次裁剪最高精度。
  • acc@5-imagenet在训练过程中使用的相同大小的验证图像的单次裁剪精度达到前5名。

欢迎加入QQ群-->: 979659372 Python中文网_新手群

推荐PyPI第三方库


热门话题
Diamond不编译Java 7   web服务如何将文件(txt)从java上传到restful web服务?   java将用户组分配给角色Liferay 6.1.1   java打印对象并检查null   java JScrollPane重写setHvalue   java如何使用JAXRS标准客户端API处理来自Web服务的错误JSON内容类型?   java在尝试解压7Zip归档文件(以二进制模式)时遇到异常   在Java中从派生类调用基类构造函数   java如何在微文件rest客户端中动态添加HTTP头?   PrintWriter[]数组Java中的NullPointerException   java addFlashAttribute和保存数据   java Jboss eap 6.4到Wildfly 14/18的迁移   java eclipse:由于“无法读取文件…~$somefile.xlsx”而未生成项目