使用yaml选择模块类和函数,不使用任何ifstatement。
easy-module-attribute-getter的Python项目详细描述
简单模块属性获取器
安装
pip install easy-module-attribute-getter
问题是:无法维护的if语句和字典
通常在yaml配置文件中指定脚本参数。例如:
^{pr2}$通常,加载配置文件,然后使用各种if语句或开关实例化对象等。它可能看起来像这样(取决于配置文件的组织方式):
models={}forkin["modelA","modelB"]:model_name=list(args.models[k].keys())[0]ifmodel_name=="densenet121":models[k]=torchvision.models.densenet121(**args.models[k][model_name])elifmodel_name=="googlenet":models[k]=torchvision.models.googlenet(**args.models[k][model_name])elifmodel_name=="resnet50":models[k]=torchvision.models.resnet50(**args.models[k][model_name])elifmodel_name=="inception_v3":models[k]=torchvision.models.inception_v3(**args.models[k][model_name])...
这有点烦人,每次PyTorch添加您想要访问的新类或函数时,您都需要向巨型if语句添加新的case。另一种方法是制作词典:
model_dict = {"densenet121": torchvision.models.densenet121,
"googlenet": torchvision.models.googlenet,
"resnet50": torchvision.models.resnet50,
"inception_v3": torchvision.models.inception_v3
...}
models = {}
for k in ["modelA", "modelB"]:
model_name = list(args.models[k].keys())[0]
models[k] = model_dict[model_name](**args.models[k][model_name])
这比if语句短,但仍然需要您手动拼写出所有键和类。当软件包更新时,你仍然需要自己更新它。在
解决方案
在一行中获取并初始化多个模型
在这个包中,上面的for循环和if语句可以简化为:
fromeasy_module_attribute_getterimportPytorchGetterpytorch_getter=PytorchGetter()models=pytorch_getter.get_multiple("model",args.models)
“models”是一个字典,它将字符串(“modelA”和“modelB”)映射到所需的对象,这些对象已经用配置文件中指定的参数初始化。在
在一条线上访问多个模块
假设您想要访问默认包(torchvision.models公司),以及pretrainedmodels包,以及其他两个自定义模型模块X和Y。您可以注册这些模块:
pytorch_getter.register('model',pretrainedmodels)pytorch_getter.register('model',X)pytorch_getter.register('model',Y)
现在您仍然可以执行1行程序:
models=pytorch_getter.get_multiple("model",args.models)
Pythorch_getter将尝试所有4个注册的模块,直到匹配为止。在
自动让yaml访问新类
如果您升级到Pythorch的新版本,其中有20个新类,您不必更改任何内容。您可以自动访问所有新类,并且可以在yaml文件中指定它们。在
通过命令行合并或重写复杂的配置选项:
示例yaml文件包含“models”,它映射到包含modelA和modelb2的嵌套字典。使用标准的python表示法可以很容易地在命令行向模型添加另一个键。在
python example.py --models {modelC: {googlenet: {pretrained: True}}}
然后在你的剧本里:
importargparseyaml_reader=YamlReader(argparse.ArgumentParser())args,_,_=yaml_reader.load_yamls({"models":['models.yaml'],"losses":['losses.yaml']},max_merge_depth=float('inf'))
现在args.型号包含3个模型。在
如果您通常希望合并配置选项,那么在load_yamls函数中,将max_merge_depth参数设置为要应用合并的子字典数。在
如果将max_merge_depth设置为1,但希望对某个特定标志执行完全覆盖,该怎么办?在这种情况下,只需将~OVERRIDE~附加到标志上:
python example.py --models~OVERRIDE~ {modelC: {googlenet: {pretrained: True}}}
现在args.型号将只包含modelC,即使max_merge_depth设置为1。在
将一个或多个yaml文件加载到一个args对象
fromeasy_module_attribute_getterimportYamlReaderyaml_reader=YamlReader()args,_,_=yaml_reader.load_yamls(['models.yaml'])
提供文件路径列表:
args,_,_=yaml_reader.load_yamls(['models.yaml','optimizers.yaml','transforms.yaml'])
或者提供根路径和将子文件夹名称映射到裸文件名的字典
root_path="/where/your/yaml/subfolders/are/"subfolder_to_name_dict={"models":"default","optimizers":"special_trial","transforms":"blah"}args,_,_=yaml_reader.load_yamls(root_path=root_path,subfolder_to_name_dict=subfolder_to_name_dict)
Pythorch特定功能
变换
在配置文件中指定转换:
transforms:train:Resize:size:256RandomResizedCrop:scale:0.16 1ratio:0.75 1.33size:227RandomHorizontalFlip:p:0.5eval:Resize:size:256CenterCrop:size:227
然后在脚本中加载合成的变换:
transforms={}fork,vinargs.transforms.items():transforms[k]=pytorch_getter.get_composed_img_transform(v,mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
转换dict现在包含:
{'train':Compose(Resize(size=256,interpolation=PIL.Image.BILINEAR)RandomResizedCrop(size=(227,227),scale=(0.16,1),ratio=(0.75,1.33),interpolation=PIL.Image.BILINEAR)RandomHorizontalFlip(p=0.5)ToTensor()Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])),'eval':Compose(Resize(size=256,interpolation=PIL.Image.BILINEAR)CenterCrop(size=(227,227))ToTensor()Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225]))}
优化器、调度程序和渐变裁剪器
(可选)在优化器参数中指定调度程序和渐变剪裁规范。在
调度程序密钥应该是scheduler_by_epoch
、scheduler_by_iteration
和{
optimizers:modelA:Adam:lr:0.00001weight_decay:0.00005scheduler_by_epoch:StepLR:step_size:2gamma:0.95scheduler_by_iteration:ExponentialLR:gamma:0.99clip_grad_norm:1modelB:Adam:lr:0.00001weight_decay:0.00005
创建优化器:
optimizers={}schedulers={}grad_clippers={}fork,vinmodels.items():optimizers[k],schedulers[k],grad_clippers[k]=pytorch_getter.get_optimizer(v,yaml_dict=args.optimizers[k])
不仅仅是为了Pythorch
注意,YamlReader和EasyModuleAttributeGetter类完全独立于PyTorch。我编写了子类PyTorchGetter,因为这正是我使用这个包的目的,但是其他两个类可以在一般情况下使用,并且可以为您自己的目的进行扩展。在
- 项目
标签: