Bitwise operations in PySpark without using a UDF

Posted 2024-09-30 04:39:37


My Spark DataFrame looks like this:

+---------+---------------------------+
|country  |sports                     |
+---------+---------------------------+
|India    |[Cricket, Hockey, Football]|
|Sri Lanka|[Cricket, Football]        |
+---------+---------------------------+

Each sport in the "sports" column is represented by a code:

sport_to_code_map = {
'Cricket' : 0x0001,
'Hockey' : 0x0002,
'Football' : 0x0004
}

Now I want to add a new column called sportsInt. It should hold the bitwise OR of the codes that the map above assigns to each sport string, giving:

+---------+---------------------------+---------+
|country  |sports                     |sportsInt|
+---------+---------------------------+---------+
|India    |[Cricket, Hockey, Football]|7        |
|Sri Lanka|[Cricket, Football]        |5        |
+---------+---------------------------+---------+
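The expected values come from OR-ing the individual bit flags. India gets 0x0001 | 0x0002 | 0x0004 = 7, and Sri Lanka gets 0x0001 | 0x0004 = 5. Below is a minimal plain-Python sketch of that fold, with no Spark involved. The helper name sports_to_int is hypothetical and only illustrates the arithmetic.

```python
from functools import reduce
import operator

sport_to_code_map = {
    'Cricket': 0x0001,
    'Hockey': 0x0002,
    'Football': 0x0004,
}

def sports_to_int(sports):
    """Hypothetical helper: OR together the bit flag of every sport in the list."""
    # Start the fold at 0, so an empty list maps to 0; an unknown sport raises KeyError.
    return reduce(operator.or_, (sport_to_code_map[s] for s in sports), 0)

print(sports_to_int(['Cricket', 'Hockey', 'Football']))  # 7
print(sports_to_int(['Cricket', 'Football']))            # 5
```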

I know one way is to use a UDF, like this:

import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType

def get_sport_to_code(sport_name):

    sport_to_code_map = {
        'Cricket': 0x0001,
        'Hockey': 0x0002,
        'Football': 0x0004
    }

    # Fail loudly on a sport that has no code
    if sport_name not in sport_to_code_map:
        raise Exception(f'Unknown Sport: {sport_name}')
    return sport_to_code_map.get(sport_name)

def sport_to_code(sports):
    # Null or empty arrays give a null result
    if not sports:
        return None

    code = 0x0000
    for sport in sports:
        # Accumulate the flags with bitwise OR
        code = code | get_sport_to_code(sport)
    return code

# The result is an integer, so declare IntegerType rather than StringType
sport_to_code_udf = F.udf(sport_to_code, IntegerType())
df = df.withColumn('sportsInt', sport_to_code_udf('sports'))

But is there a way to do this with built-in Spark functions instead of a UDF?


1 Answer
Anonymous user
Answer #1 · Posted 2024-09-30 04:39:37

In Spark 2.4+, we can use the aggregate higher-order function together with the bitwise OR operator (|) for this.

Example:

from pyspark.sql.types import *
from pyspark.sql.functions import *

sport_to_code_map = {
'Cricket' : 0x0001,
'Hockey' : 0x0002,
'Football' : 0x0004
}

#creating dataframe from dictionary
lookup = spark.createDataFrame(list(sport_to_code_map.items()), ["key", "value"])

#sample dataframe
df.show(10,False)
#+---------+---------------------------+
#|country  |sports                     |
#+---------+---------------------------+
#|India    |[Cricket, Hockey, Football]|
#|Sri Lanka|[Cricket, Football]        |
#+---------+---------------------------+

df1=df.selectExpr("explode(sports) as key","country")

df2=df1.join(lookup,['key'],'left').\
groupBy("country").\
agg(collect_list(col("key")).alias("sports"),collect_list(col("value")).alias("sportsInt"))

df2.withColumn("sportsInt",expr('aggregate(sportsInt,0,(s,x) -> int(s) | int(x))')).\
show(10,False)
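#+---------+---------------------------+---------+
#|country  |sports                     |sportsInt|
#+---------+---------------------------+---------+
#|Sri Lanka|[Cricket, Football]        |5        |
#|India    |[Cricket, Hockey, Football]|7        |
#+---------+---------------------------+---------+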
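Spark 2.4 also provides the transform and element_at functions. In principle, that means the explode/join/groupBy can be skipped entirely. The idea is to inline the dict as a SQL map(...) literal, look up each sport's code with transform, and fold the codes with aggregate. The helper below only builds that SQL expression string in plain Python. Its name, bitwise_or_expr, is hypothetical. It assumes the map keys contain no single quotes and the values are integers, and the generated expression has not been run here against a live cluster.

```python
def bitwise_or_expr(code_map, array_col):
    """Hypothetical helper: build a Spark SQL expression that ORs the codes of an array column.

    Assumes keys contain no single quotes and values are integers.
    """
    # Inline SQL map literal, e.g. map('Cricket', 1, 'Hockey', 2, 'Football', 4)
    map_args = ", ".join(f"'{k}', {int(v)}" for k, v in code_map.items())
    # transform maps each sport name to its code; aggregate folds the codes with |
    return (
        f"aggregate(transform({array_col}, s -> element_at(map({map_args}), s)), "
        f"0, (acc, x) -> acc | x)"
    )

sport_to_code_map = {'Cricket': 0x0001, 'Hockey': 0x0002, 'Football': 0x0004}
print(bitwise_or_expr(sport_to_code_map, "sports"))
```

Intended usage, which needs a SparkSession and is not executed here: df.withColumn("sportsInt", expr(bitwise_or_expr(sport_to_code_map, "sports"))). With ANSI mode off, element_at returns null for a key missing from the map. An unknown sport therefore makes the whole result null instead of raising, as the UDF does.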

If you want to avoid the lookup join, use .replace with the sport_to_code_map dict instead:

#replace() needs old and new values of the same type, so convert the codes to strings
sport_to_code_map={k:str(v) for k,v in sport_to_code_map.items()}

df1.replace(sport_to_code_map).show()
#+---+---------+
#|key|  country|
#+---+---------+
#|  1|    India|
#|  2|    India|
#|  4|    India|
#|  1|Sri Lanka|
#|  4|Sri Lanka|
#+---+---------+
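The replace output still has one row per sport, with each code stored as a string. To finish, the codes must be cast back to integers and OR-combined per country. In Spark, one way is to group by country and apply the same aggregate expression to a collect_list of the cast codes. The plain-Python sketch below only mimics that remaining step on the rows shown above. The helper name combine_codes is hypothetical.

```python
from collections import defaultdict

# Rows as shown after df1.replace(...): (code as string, country)
rows = [
    ("1", "India"),
    ("2", "India"),
    ("4", "India"),
    ("1", "Sri Lanka"),
    ("4", "Sri Lanka"),
]

def combine_codes(rows):
    """Hypothetical helper: cast each string code to int and OR the codes per country."""
    result = defaultdict(int)  # each country starts at 0
    for key, country in rows:
        result[country] |= int(key)
    return dict(result)

print(combine_codes(rows))  # {'India': 7, 'Sri Lanka': 5}
```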
