从Spark中管道内的StringIndexer阶段获取标签(pyspark)

2024-09-27 07:30:45 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用的是Sparkpyspark,我有一个pipeline设置了一组StringIndexer对象,用于将字符串列编码为索引列:

indexers = [StringIndexer(inputCol=column, outputCol=column + '_index').setHandleInvalid('skip')
            for column in list(set(data_frame.columns) - ignore_columns)]
pipeline = Pipeline(stages=indexers)
new_data_frame = pipeline.fit(data_frame).transform(data_frame)

问题是,我需要在每个StringIndexer对象装配好后得到它的标签列表。对于没有管道的单个列和单个StringIndexer来说,这是一项简单的任务。在DataFrame上安装索引器后,我可以访问labels属性:

^{pr2}$

然而,当我使用管道时,这似乎是不可能的,或者至少我不知道如何做到这一点。在

所以我想我的问题归结为: 有没有一种方法可以访问索引过程中为每个单独列使用的标签?在

或者,在这个用例中,我将不得不放弃管道,例如遍历StringIndexer对象列表并手动执行?)我相信这是可能的。不过,使用管道会更好)


Tags: columns对象字符串编码列表data管道pipeline
1条回答
网友
1楼 · 发布于 2024-09-27 07:30:45

示例数据和Pipeline

from pyspark.ml.feature import StringIndexer, StringIndexerModel

df = spark.createDataFrame([("a", "foo"), ("b", "bar")], ("x1", "x2"))

pipeline = Pipeline(stages=[
    StringIndexer(inputCol=c, outputCol='{}_index'.format(c))
    for c in df.columns
])

model = pipeline.fit(df)

摘录自stages

^{pr2}$
{'x1_index': ['a', 'b'], 'x2_index': ['foo', 'bar']}

从已转换的DataFrame的元数据:

indexed = model.transform(df)

{c.name: c.metadata["ml_attr"]["vals"]
for c in indexed.schema.fields if c.name.endswith("_index")}
{'x1_index': ['a', 'b'], 'x2_index': ['foo', 'bar']}

相关问题 更多 >

    热门问题