使用Dagster进行交叉验证

2024-09-26 22:55:25 发布

您现在位置:Python中文网/ 问答频道 /正文

我已经开始在我们的ML管道中使用Dagster,并遇到一些基本问题,我想知道我是否遗漏了一些琐碎的东西,或者这就是它的本质

假设我有一个简单的ML管道:

Load raw data --> Process data into table --> Split train / test --> train model --> evaluate model.

线性模型在Dagster中是直截了当的。但如果我想添加一个小循环,比如为了交叉验证的目的:

Load raw data --> Process data into table --> Split into k folds, and for each fold:
  - fold 1: train model --> evaluate
  - fold 2: train model --> evaluate
  - fold 3: train model --> evaluate
  --> summarize cross validation results.

是否有一个好的&;用Dagster干净利落的方法?我做事的方式是:

Load raw data --> Process data into table --> Split into K folds --> choose fold k --> train model --> evaluate model

将折叠“k”作为管道的输入参数。然后运行管道K次

我错过了什么


Tags: datarawmodel管道tableloadtrainfold
1条回答
网友
1楼 · 发布于 2024-09-26 22:55:25

是的,Dagster确实支持将实体扇出到多个实体中,而不是在单个管道中扇入到水槽实体(即汇总结果)。下面是一些示例代码和dagit(full dagzoomed in)中相应的dag可视化

@solid
def load_raw_data(_):
    yield Output('loaded_data')


@solid
def process_data_into_table(_, raw_data):
    yield Output(raw_data)


@solid(
    output_defs=[
        OutputDefinition(name='fold_one', dagster_type=int, is_required=True),
        OutputDefinition(name='fold_two', dagster_type=int, is_required=True),
    ],
)
def split_into_two_folds(_, table):
    yield Output(1, 'fold_one')
    yield Output(2, 'fold_two')


@solid
def train_fold(_, fold):
    yield Output('model')


@solid
def evaluate_fold(_, model):
    yield Output('compute_result')


@composite_solid
def process_fold(fold):
    return evaluate_fold(train_fold(fold))


@solid
def summarize_results(context, fold_1_result, fold_2_result):
    yield Output('summary_stats')


@pipeline
def ml_pipeline():
    fold_one, fold_two = split_into_two_folds(process_data_into_table(load_raw_data()))

    process_fold_one = process_fold.alias('process_fold_one')
    process_fold_two = process_fold.alias('process_fold_two')

    summarize_results(process_fold_one(fold_one), process_fold_two(fold_two))

在示例代码中,我们使用别名es,以便对每个折叠重复使用相同的逻辑。我们还整合了在复合实体中处理每个折叠的逻辑

另一种选择是直接以编程方式创建PipelineDefinition,但我建议使用上述方法

相关问题 更多 >

    热门问题