统计自适应随机优化方法
statopt的Python项目详细描述
统计自适应随机梯度法
Pythorch优化器包,可根据在线统计测试自动安排学习速率。在
- 主要算法:SALSA和SASA
- 辅助代码:QHM和SSLS
论文合著:张、朗、刘、肖著,2020年。在
安装
pip install statopt
或来自Github:
^{pr2}$莎莎和莎莎的用法
下面我们将概述CIFAR10的关键步骤。 完整的Python代码在examples/cifar_example.py中给出。在
常用设置
首先,选择一个批大小并准备数据集和数据加载器,如this PyTorch tutorial:
importtorch,torchvisionbatch_size=128trainset=torchvision.datasets.CIFAR10(root='../data',train=True,...)trainloader=torch.utils.data.DataLoader(trainset,batch_size=batch_size,...)
选择设备、网络型号和损耗函数:
device='cuda'iftorch.cuda.is_available()else'cpu'net=torchvision.models.resnet18().to(device)loss_func=torch.nn.CrossEntropyLoss()
萨尔萨
导入statopt
,并使用较小的学习率和两个额外参数初始化SALSA:
importstatoptgamma=math.sqrt(batch_size/len(trainset))# smoothing parameter for line searchtestfreq=min(1000,len(trainloader))# frequency to perform statistical test optimizer=statopt.SALSA(net.parameters(),lr=1e-3,# any small initial learning rate momentum=0.9,weight_decay=5e-4,# common choices for CIFAR10/100gamma=gamma,testfreq=testfreq)# two extra parameters for SALSA
使用萨尔萨语的培训代码
forepochinrange(100):for(images,labels)intrainloader:net.train()# always switch to train() mode# Compute model outputs and loss function images,labels=images.to(device),labels.to(device)loss=loss_func(net(images),labels)# Compute gradient with back-propagation optimizer.zero_grad()loss.backward()# SALSA requires a closure function for line searchdefeval_loss(eval_mode=True):ifeval_mode:net.eval()withtorch.no_grad():loss=loss_func(net(images),labels)returnlossoptimizer.step(closure=eval_loss)
SASA
与大多数其他优化器一样,SASA需要良好的(手动调整的)初始学习率,但不要使用行搜索:
optimizer=statopt.SASA(net.parameters(),lr=1.0,# need a good initial learning rate momentum=0.9,weight_decay=5e-4,# common choices for CIFAR10/100testfreq=testfreq)# frequency for statistical tests
在训练循环中:optimizer.step()
不需要任何闭包函数。在
- 项目
标签: