RevTorch
A framework for building (partially) reversible neural networks with PyTorch
This document introduces and explains RevTorch, presented at MICCAI 2019.
If you find this code helpful for your research, please cite the following article:
@article{PartiallyRevUnet2019Bruegger,
  author={Br{\"u}gger, Robin and Baumgartner, Christian F. and Konukoglu, Ender},
  title={A Partially Reversible U-Net for Memory-Efficient Volumetric Image Segmentation},
  journal={arXiv:1906.06148},
  year={2019},
}
Installation
Install revtorch with pip:
$ pip install revtorch
revtorch requires PyTorch. However, PyTorch is not listed as a dependency, because the required PyTorch build depends on your system. Please follow the instructions on the PyTorch website to install PyTorch.
Usage
The following example shows how to use the RevTorch framework.
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import revtorch as rv

def train():
    trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True,
                                            transform=transforms.ToTensor())
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)

    net = PartiallyReversibleNet()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters())

    for epoch in range(2):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # logging stuff
            running_loss += loss.item()
            LOG_INTERVAL = 200
            if i % LOG_INTERVAL == (LOG_INTERVAL - 1):
                # print every 200 mini-batches
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / LOG_INTERVAL))
                running_loss = 0.0

class PartiallyReversibleNet(nn.Module):
    def __init__(self):
        super(PartiallyReversibleNet, self).__init__()

        # initial non-reversible convolution to get to 32 channels
        self.conv1 = nn.Conv2d(3, 32, 3)

        # construct a reversible sequence with 4 reversible blocks
        blocks = []
        for i in range(4):
            # f and g must both be an nn.Module whose output has the same shape as its input
            f_func = nn.Sequential(nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1))
            g_func = nn.Sequential(nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1))

            # we construct a reversible block with our F and G functions
            blocks.append(rv.ReversibleBlock(f_func, g_func))

        # pack all reversible blocks into a reversible sequence
        self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks))

        # non-reversible convolution to get to 10 channels (one for each label)
        self.conv2 = nn.Conv2d(32, 10, 3)

    def forward(self, x):
        x = self.conv1(x)

        # the reversible sequence can be used like any other nn.Module.
        # Memory-saving backpropagation is used automatically
        x = self.sequence(x)

        x = self.conv2(F.relu(x))
        x = F.avg_pool2d(x, (x.shape[2], x.shape[3]))
        x = x.view(x.shape[0], x.shape[1])
        return x

if __name__ == "__main__":
    train()
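The memory savings rest on additive coupling, as in reversible residual networks: the block's input is split along the channel dimension into two halves x1 and x2 (which is why F and G above operate on 16 of the 32 channels), and the outputs are y1 = x1 + F(x2) and y2 = x2 + G(y1). Because this map is exactly invertible, the input activations of each reversible block need not be stored for backpropagation; they can be recomputed from the block's output. A minimal scalar sketch of the coupling math (an illustration of the idea, not the revtorch implementation itself):

```python
def couple(x1, x2, f, g):
    # forward additive coupling: y1 = x1 + F(x2), y2 = x2 + G(y1)
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def uncouple(y1, y2, f, g):
    # exact inverse: recover the inputs from the outputs alone,
    # without having stored x1 or x2
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2

# toy F and G (any shape-preserving functions work)
f = lambda t: 3.0 * t
g = lambda t: t + 1.0

y1, y2 = couple(2.0, 5.0, f, g)
assert uncouple(y1, y2, f, g) == (2.0, 5.0)  # inputs reconstructed exactly
```

Note that F and G themselves need not be invertible; only the coupling structure around them is.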
Python version
Tested with Python 3.6 and PyTorch 1.1.0. Should work with any version of Python 3.