Pyspark使用2个数据帧中的值和阈值生成段数组

2024-10-03 02:43:32 发布

您现在位置:Python中文网/ 问答频道 /正文

我需要使用两个不同数据集中的段值及其阈值生成一个段数组。在pyspark或hivesql中有没有一种简单的方法可以做到这一点

段值数据集:

--------------------------------------------------
| user_id   | seg1  | seg2  | seg3 | seg4 | seg5 |
------------------------------------------------
| 100       |   90  |  20   |   76 |  100 |  30  |
| 200       |   56  |  15   |   67 |  99  |  25  |
| 300       |   87  |  38   |   45 |  97  |  40  |
--------------------------------------------------

段阈值数据集:

---------------------------
|seg_name | seg_threshold |
---------------------------
|  seg1   |  83           |
|  seg2   |  25           |
|  seg3   |  60           |
|  seg4   |  98           |
|  seg5   |  35           |
---------------------------

如果某段的值高于阈值,则应将用户视为该段的一部分。该用户的段数组应包括段名称(列标题)

预期产出:

-------------------------------------
| user_id| segment_array            |
-------------------------------------
| 100    | [seg1, seg3, seg4] |
| 200    | [seg3, seg4]             |
| 300    | [seg1, seg2, seg5]       |
-------------------------------------

请注意,这只是一个指示性数据集。我有几百个这样的片段

谢谢你的帮助


Tags: 数据用户id阈值数组pyspark段值seg
2条回答

数百个阈值条目可能是broadcasted。然后可以在UDF中检查值是否高于或低于阈值:

#broadcast the threshold data
thresholdDf = ...
thresholdMap = thresholdDf.rdd.collectAsMap()
thresholds = spark.sparkContext.broadcast(thresholdMap)

userDf = ...

#add a new column to the user dataframe that contains a struct with the column 
#names and their respective values. This column will be used to call the udf
user2Df = userDf.withColumn("all_cols", F.struct([F.struct(F.lit(x),userDf[x]) \
    for x in userDf.columns]))

#create the udf
def calc_segments(row):
    return [col.col1 for col in row \
        if thresholds.value.get(col.col1) != None \
        if int(thresholds.value[col.col1]) < int(col[col.col1])]
segment_udf = F.udf(calc_segments, T.ArrayType(T.StringType()))

#call the udf and drop the intermediate column
user2Df.withColumn("segment_array", segment_udf(user2Df.all_cols)) \
    .drop("all_cols").show(truncate=False)

我的结果是

+   -+  +  +  +  +  +         +
|user_id|seg1|seg2|seg3|seg4|seg5|segment_array     |
+   -+  +  +  +  +  +         +
|100    |90  |20  |76  |100 |30  |[seg1, seg3, seg4]|
|200    |56  |15  |67  |99  |25  |[seg3, seg4]      |
|300    |87  |38  |45  |97  |40  |[seg1, seg2, seg5]|
+   -+  +  +  +  +  +         +

此结果与预期结果略有不同。也许测试数据有问题

@werner的解决方案是完全有效的

在纯sparksql中,有一种方法可以在没有udf的情况下实现这一点

准备数据帧:

from pyspark.sql import Row

spark.createDataFrame([
  Row(user_id=100, seg1=90, seg2=20, seg3=76, seg4=100, seg5=30), 
  Row(user_id=200, seg1=56, seg2=15, seg3=67, seg4=99, seg5=25), 
  Row(user_id=300, seg1=87, seg2=38, seg3=45, seg4=97, seg5=40)]).createOrReplaceTempView("data")

spark.createDataFrame([
  Row(seg_name = 'seg1', seg_threshold = 83),
  Row(seg_name = 'seg2', seg_threshold = 25),
  Row(seg_name = 'seg3', seg_threshold = 60),
  Row(seg_name = 'seg4', seg_threshold = 98),
  Row(seg_name = 'seg5', seg_threshold = 35)
]).createOrReplaceTempView("thr")

现在,您可以使用一个名为stack的边缘但非常有用的函数执行“unpivot”操作:

spark.sql("""
WITH data_eva 
     AS (SELECT user_id, 
                Stack(5, 'seg1', seg1, 'seg2', seg2, 'seg3', seg3, 'seg4', seg4, 'seg5', seg5) 
         FROM   data) 
SELECT user_id, 
       Collect_list(col0) 
FROM   data_eva 
       JOIN thr 
         ON data_eva.col0 = thr.seg_name 
WHERE  col1 > seg_threshold 
GROUP  BY user_id 
 """).show()

这是输出:

+   -+         +
|user_id|collect_list(col0)|
+   -+         +
|    100|[seg4, seg1, seg3]|
|    200|      [seg4, seg3]|
|    300|[seg2, seg1, seg5]|
+   -+         +

你提到你有数百段。可以使用循环在堆栈函数内轻松生成表达式

这项技术在spark工具箱中非常有用

相关问题 更多 >