这是链接到Cartesian product of nested dictionaries of lists
假设我有一个嵌套的dict,其中的列表表示多个配置,如:
{'algorithm': ['PPO', 'A2C', 'DQN'],
'env_config': {'env': 'GymEnvWrapper-Atari',
'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}
目标是计算嵌套dict中列表的笛卡尔积,以得到所有可能的配置
到目前为止,我得到的是:
def product(*args, repeat=1, root=False):
pools = [tuple(pool) for pool in args] * repeat
result = [[]]
for pool in pools:
result = [x+[y] for x in result for y in pool]
print("************************")
print(root)
for r in result:
print(tuple(r))
print("************************")
for prod in result:
yield tuple(prod)
def recursive_cartesian_product(dic, root=True):
# based on https://stackoverflow.com/a/50606871/11051330
# added differentiation between list and entry to protect strings in dicts
# with uneven depth
keys, values = dic.keys(), dic.values()
vals = (recursive_cartesian_product(v, False) if isinstance(v, dict)
else v if isinstance(v, list) else (v,) for v in
values)
print("!", root)
for conf in product(*vals, root=root):
print(conf)
yield dict(zip(keys, conf))
以下是相关输出:
************************
True
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
************************
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
请注意product
内的print语句是如何正确工作的,而yield
内的print语句失败,并且不会为以后的配置更改env
值
结果表明,问题不在上述函数内部,而在其外部。生成的conf被作为
**kwargs
传递给一个函数,这会使生成器出错下面是一个快速解决方案:
使用^{} 确实比自己滚动要简单
如果您不希望您的
env_config
发生更改(游戏名称除外),则无需实现通用递归dict访问者。因此,您只需要
algorithms
与game
名称的乘积,始终使用AtariEnv
,然后:如果你喜欢一个通用的解决方案,这里是我的:
itertools
已具有product
类型:相关问题 更多 >
编程相关推荐