Cartesian product with generators

Posted 2024-05-20 14:09:49


This is related to the question "Cartesian product of nested dictionaries of lists".

Suppose I have a nested dict in which lists represent multiple configuration options, such as:

{'algorithm': ['PPO', 'A2C', 'DQN'],
 'env_config': {'env': 'GymEnvWrapper-Atari',
                'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}

The goal is to compute the Cartesian product of the lists in the nested dict to obtain every possible configuration.
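For a flat dict whose values are all lists, this is a one-liner with itertools.product; the recursive functions in this thread generalize the same idea to nested dicts. A minimal sketch, assuming for illustration only that the game list has been pulled up into a top-level 'game' key (the options dict below is hypothetical, not the asker's structure):

```python
from itertools import product

# Hypothetical flat version of the config space: every value is a list of options.
options = {'algorithm': ['PPO', 'A2C', 'DQN'], 'game': ['breakout', 'pong']}

keys = list(options)
# product(*values) yields one tuple per combination; zip it back onto the keys.
configs = [dict(zip(keys, combo)) for combo in product(*options.values())]

for c in configs:
    print(c)
```

The nested case needs recursion only because some values are themselves dicts that must be expanded before they can take part in the product.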

Here is what I have so far:

def product(*args, repeat=1, root=False):
    pools = [tuple(pool) for pool in args] * repeat
    result = [[]]
    for pool in pools:
        result = [x+[y] for x in result for y in pool]
    print("************************")
    print(root)
    for r in result:
        print(tuple(r))
    print("************************")
    for prod in result:
        yield tuple(prod)


def recursive_cartesian_product(dic, root=True):
    # based on https://stackoverflow.com/a/50606871/11051330
    # added differentiation between list and entry to protect strings in dicts
    # with uneven depth
    keys, values = dic.keys(), dic.values()

    vals = (recursive_cartesian_product(v, False) if isinstance(v, dict)
            else v if isinstance(v, list) else (v,) for v in
            values)

    print("!", root)
    for conf in product(*vals, root=root):
        print(conf)
        yield dict(zip(keys, conf))

Here is the relevant output:

************************
True
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
************************
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}})
('PPO', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('A2C', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})
('DQN', {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}})

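Note how the print statements inside product show the correct combinations, while the print next to the yield shows wrong ones: the env (the game) is not updated for later configurations.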


3 Answers

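It turned out the problem was not inside the functions above but outside them. Each generated conf was passed as **kwargs to a function, and that messed up the generator. The reason is that the nested dicts in the yielded configurations are shared objects: both the product above and itertools.product read each input into a tuple once and reuse its elements, so an in-place change to one configuration's nested dict also shows up in every later configuration that reuses it.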
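A minimal reproduction of this effect, assuming the consumer modifies its nested argument in place. The answer does not show the real function, so train below is a hypothetical stand-in, and the config structure is simplified to a single nesting level:

```python
from itertools import product


def game_dicts():
    # Each iteration creates a brand-new dict...
    for game in ['breakout', 'pong']:
        yield {'game': game}


def configs():
    # ...but itertools.product consumes game_dicts() up front and keeps its
    # two dicts in a tuple, so every algorithm is paired with the SAME objects.
    for algo, env in product(['PPO', 'A2C', 'DQN'], game_dicts()):
        yield {'algorithm': algo, 'env_config': env}


def train(algorithm, env_config):
    # Hypothetical consumer that mutates its nested argument in place.
    env_config['game'] = 'pong'


seen = []
for conf in configs():
    # Record what the generator produced before handing it to train().
    seen.append((conf['algorithm'], conf['env_config']['game']))
    # ** copies only the top-level dict; env_config is still the shared object.
    train(**conf)

print(seen)
```

After the first call to train, the shared "breakout" dict has been overwritten, so A2C and DQN are recorded with 'pong' for both games, which is the same pattern as the asker's second block of output.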

Here is a quick fix:

import itertools
from copy import deepcopy


def recursive_cartesian_product(dic):
    # based on https://stackoverflow.com/a/50606871/11051330
    # added differentiation between list and entry to protect strings
    # yield contains deepcopy. important, as otherwise mutating a yielded
    # config changes nested dicts shared with later configs
    keys, values = dic.keys(), dic.values()

    vals = (recursive_cartesian_product(v) if isinstance(v, dict)
            else v if isinstance(v, list) else (v,) for v in
            values)

    for conf in itertools.product(*vals):
        yield deepcopy(dict(zip(keys, conf)))

Using itertools.product is indeed simpler than rolling your own.
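To check that the deepcopy fix keeps configurations independent, here is a sketch that feeds the fixed generator to a mutating consumer. As before, train is a hypothetical stand-in for the asker's real function, assumed to modify the nested game entry in place:

```python
import itertools
from copy import deepcopy


def recursive_cartesian_product(dic):
    # Same as the fix above: expand lists, recurse into dicts, wrap scalars.
    keys, values = dic.keys(), dic.values()
    vals = (recursive_cartesian_product(v) if isinstance(v, dict)
            else v if isinstance(v, list) else (v,) for v in values)
    for conf in itertools.product(*vals):
        # deepcopy gives each yielded config its own nested dicts.
        yield deepcopy(dict(zip(keys, conf)))


def train(algorithm, env_config):
    # Hypothetical consumer that mutates its nested argument in place.
    env_config['env_config']['AtariEnv']['game'] = 'pong'


config_space = {'algorithm': ['PPO', 'A2C', 'DQN'],
                'env_config': {'env': 'GymEnvWrapper-Atari',
                               'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}

seen = []
for conf in recursive_cartesian_product(config_space):
    seen.append((conf['algorithm'],
                 conf['env_config']['env_config']['AtariEnv']['game']))
    train(**conf)  # mutations now stay local to this one config

print(seen)
```

Every algorithm is now recorded with both 'breakout' and 'pong'. The trade-off is one full copy per yielded configuration, which is cheap for small config dicts like this one.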

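If your env_config never changes (apart from the game name), you don't need to implement a generic recursive dict visitor.
You only need the product of the algorithms and the game names, always using AtariEnv: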

from itertools import product

possible_configurations = {'algorithm': ['PPO', 'A2C', 'DQN'],
                           'env_config': {'env': 'GymEnvWrapper-Atari',
                                          'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}

algorithms = tuple(possible_configurations["algorithm"])
games = tuple(
    {"env": "GymEnvWrapper-Atari", "env_config": {"AtariEnv": {"game": game_name}}}
    for game_name in possible_configurations["env_config"]["env_config"]["AtariEnv"]["game"]
)

factors = (algorithms, games)
for config in product(*factors):
    print(config)
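If you want each combination back as a full dict in the same shape as the input, one option is a small factory that builds a fresh nested dict per (algorithm, game name) pair; this also avoids the shared-dict problem, because no nested dict is reused between configurations. A sketch, where make_config is a hypothetical helper name:

```python
from itertools import product

possible_configurations = {'algorithm': ['PPO', 'A2C', 'DQN'],
                           'env_config': {'env': 'GymEnvWrapper-Atari',
                                          'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}


def make_config(algorithm, game_name):
    # Hypothetical helper: builds a fresh nested dict on every call, in the
    # same shape as possible_configurations but with single values.
    return {'algorithm': algorithm,
            'env_config': {'env': possible_configurations['env_config']['env'],
                           'env_config': {'AtariEnv': {'game': game_name}}}}


games = possible_configurations['env_config']['env_config']['AtariEnv']['game']
configs = [make_config(algo, game)
           for algo, game in product(possible_configurations['algorithm'], games)]

for config in configs:
    print(config)
```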

If you prefer a generic solution, here is mine:

from itertools import product

possible_configurations = {'algorithm': ['PPO', 'A2C', 'DQN'],
                           'env_config': {'env': 'GymEnvWrapper-Atari',
                                          'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}


def product_visitor(obj):
    if isinstance(obj, dict):
        yield from (
            dict(possible_product)
            for possible_product in product(
                *(
                    [(key, possible_value) for possible_value in product_visitor(value)]
                    for key, value in obj.items())))
    elif isinstance(obj, list):
        for value in obj:
            yield from product_visitor(value)
    else:  # either a string, a number, a boolean or null (all scalars)
        yield obj


configs = tuple(product_visitor(possible_configurations))
print("\n".join(map(str, configs)))
assert configs == (
    {'algorithm': 'PPO', 'env_config': {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}}},
    {'algorithm': 'PPO', 'env_config': {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}}},
    {'algorithm': 'A2C', 'env_config': {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}}},
    {'algorithm': 'A2C', 'env_config': {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}}},
    {'algorithm': 'DQN', 'env_config': {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'breakout'}}}},
    {'algorithm': 'DQN', 'env_config': {'env': 'GymEnvWrapper-Atari', 'env_config': {'AtariEnv': {'game': 'pong'}}}},
)
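The advantage of the generic visitor is that a new list-valued key expands the product with no code changes. A sketch, assuming a hypothetical extra 'lr' hyperparameter that is not part of the original question:

```python
from itertools import product


def product_visitor(obj):
    # Same generic visitor as above.
    if isinstance(obj, dict):
        yield from (
            dict(possible_product)
            for possible_product in product(
                *(
                    [(key, possible_value) for possible_value in product_visitor(value)]
                    for key, value in obj.items())))
    elif isinstance(obj, list):
        for value in obj:
            yield from product_visitor(value)
    else:  # scalar
        yield obj


extended = {'algorithm': ['PPO', 'A2C', 'DQN'],
            'lr': [1e-3, 1e-4],  # hypothetical extra hyperparameter
            'env_config': {'env': 'GymEnvWrapper-Atari',
                           'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}}

configs = list(product_visitor(extended))
print(len(configs))
print(configs[0])
# Nested dicts are reused across configs, just like in the question.
print(configs[0]['env_config'] is configs[2]['env_config'])
```

This yields 3 × 2 × 2 = 12 configurations. Note that the visitor, like the question's code, reuses nested dicts between configurations, so wrap each one in copy.deepcopy before passing it to code that mutates it.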

itertools already provides a product function:

from itertools import product


d = {'algorithm': ['PPO', 'A2C', 'DQN'],
     'env_config': {'env': 'GymEnvWrapper-Atari',
                    'env_config': {'AtariEnv': {'game': ['breakout', 'pong']}}}

for algo, game in product(d['algorithm'],
                          d['env_config']['env_config']['AtariEnv']['game']):
    print((algo, {'env': 'GymEnvWrapper-Atari', 
                  'env_config': {'AtariEnv': {'game': game}}})) 
