向withColumn下的udf传递数据帧列和外部列表

2024-05-19 12:05:03 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个Spark数据框,结构如下。bodyText_令牌具有令牌(已处理/一组字)。我有一个嵌套的已定义关键字列表

root
 |-- id: string (nullable = true)
 |-- body: string (nullable = true)
 |-- bodyText_token: array (nullable = true)

keyword_list=['union','workers','strike','pay','rally','free','immigration',],
['farmer','plants','fruits','workers'],['outside','field','party','clothes','fashions']]

我需要检查每个关键字列表下有多少个标记,并将结果作为现有数据帧的新列添加。 例如:如果tokens =["become", "farmer","rally","workers","student"] 结果将是->;[1,2,0]

以下函数按预期工作。

def label_maker_topic(tokens,topic_words):
    twt_list = []
    for i in range(0, len(topic_words)):
        count = 0
        #print(topic_words[i])
        for tkn in tokens:
            if tkn in topic_words[i]:
                count += 1
        twt_list.append(count)

    return twt_list

我使用withColumn下的udf访问该函数,得到一个错误。我想是把一个外部列表传递给一个自定义项。有没有方法可以将外部列表和datafram列传递到udf并将新列添加到dataframe?

topicWord = udf(label_maker_topic,StringType())
myDF=myDF.withColumn("topic_word_count",topicWord(myDF.bodyText_token,keyword_list))

Tags: 数据intrue列表topiccountlistwords
3条回答

最干净的解决方案是使用闭包传递其他参数:

def make_topic_word(topic_words):
     return udf(lambda c: label_maker_topic(c, topic_words))

df = sc.parallelize([(["union"], )]).toDF(["tokens"])

(df.withColumn("topics", make_topic_word(keyword_list)(col("tokens")))
    .show())

这不需要对keyword_list或用UDF包装的函数进行任何更改。也可以使用此方法传递任意对象。例如,这可以用来传递一个sets列表,以便进行有效的查找。

如果要使用当前的UDF并直接传递topic_words,则必须首先将其转换为列文字:

from pyspark.sql.functions import array, lit

ks_lit = array(*[array(*[lit(k) for k in ks]) for ks in keyword_list])
df.withColumn("ad", topicWord(col("tokens"), ks_lit)).show()

根据您的数据和需求,可以有其他更高效的解决方案,不需要udf(explode+aggregate+collapse)或查找(hashing+vector操作)。

在任何外部参数都可以传递给UDF的情况下,下面的代码都可以正常工作(这是一个经过调整的代码,可以帮助任何人)

topicWord=udf(lambda tkn: label_maker_topic(tkn,topic_words),StringType())
myDF=myDF.withColumn("topic_word_count",topicWord(myDF.bodyText_token))

另一种方法是使用functools模块的partial

from functools import partial

func_to_call = partial(label_maker_topic, topic_words=keyword_list)

pyspark_udf = udf(func_to_call, <specify_the_type_returned_by_function_here>)

df = sc.parallelize([(["union"], )]).toDF(["tokens"])

df.withColumn("topics", pyspark_udf(col("tokens"))).show()

相关问题 更多 >

    热门问题