根据规则对pytorch参数进行分组
torch-parameter-groups的Python项目详细描述
火炬参数组
根据规则对pytorch参数进行分组。
安装
需要Python3.6+。
pip install torch-parameter-groups
用法
importtorchimporttorch.nnasnnimporttorch_basic_modelsimporttorch_parameter_groupsmodel=torch_basic_models.MobileNetV2.factory()optimizer=torch_parameter_groups.optimizer_factory(model=model,config={'type':'SGD','kwargs':{'momentum':0.9,'nesterov':True,'weight_decay':0.0001,},'rules':[{'param_name_list':['weight'],'kwargs':{'weight_decay':0}},{}]},)criterion=nn.CrossEntropyLoss()output=model(torch.randn(1,3,224,224))loss=criterion(output,torch.Tensor([0]).long())loss.backward()optimizer.step(closure=None)