<p>A few hundred threshold entries can be <a href="https://spark.apache.org/docs/3.0.0/rdd-programming-guide.html#broadcast-variables" rel="nofollow noreferrer">broadcast</a>. A <a href="https://spark.apache.org/docs/3.0.0/api/python/pyspark.sql.html#pyspark.sql.functions.udf" rel="nofollow noreferrer">UDF</a> can then check whether each value is above or below its threshold:</p>
<pre class="lang-python prettyprint-override"><code>from pyspark.sql import functions as F
from pyspark.sql import types as T

# broadcast the threshold data
thresholdDf = ...
thresholdMap = thresholdDf.rdd.collectAsMap()
thresholds = spark.sparkContext.broadcast(thresholdMap)

userDf = ...
# add a new column to the user dataframe that contains a struct with the column
# names and their respective values; this column will be used to call the udf
user2Df = userDf.withColumn("all_cols", F.struct([F.struct(F.lit(x), userDf[x])
                                                  for x in userDf.columns]))

# create the udf: collect the names of all columns whose value exceeds
# the corresponding broadcast threshold
def calc_segments(row):
    return [col.col1 for col in row
            if thresholds.value.get(col.col1) is not None
            and int(thresholds.value[col.col1]) < int(col[col.col1])]

segment_udf = F.udf(calc_segments, T.ArrayType(T.StringType()))

# call the udf and drop the intermediate column
user2Df.withColumn("segment_array", segment_udf(user2Df.all_cols)) \
    .drop("all_cols").show(truncate=False)
</code></pre>
<p>The result is:</p>
<pre class="lang-none prettyprint-override"><code>+-------+----+----+----+----+----+------------------+
|user_id|seg1|seg2|seg3|seg4|seg5|segment_array     |
+-------+----+----+----+----+----+------------------+
|100    |90  |20  |76  |100 |30  |[seg1, seg3, seg4]|
|200    |56  |15  |67  |99  |25  |[seg3, seg4]      |
|300    |87  |38  |45  |97  |40  |[seg1, seg2, seg5]|
+-------+----+----+----+----+----+------------------+
</code></pre>
<p>This result differs slightly from the expected result; perhaps there is an issue with the test data.</p>
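<p>Since the udf itself only compares plain values, its selection logic can be sanity-checked without a Spark session. The sketch below mimics the <code>all_cols</code> struct with (name, value) pairs; the threshold values here are made up for illustration and are not from the question:</p>
<pre class="lang-python prettyprint-override"><code># hypothetical thresholds, standing in for the broadcast map
thresholds = {"seg1": 80, "seg2": 30, "seg3": 50, "seg4": 90, "seg5": 35}

def calc_segments(row):
    # row mimics the "all_cols" struct: one (name, value) pair per column
    return [name for name, value in row
            if name in thresholds
            and int(thresholds[name]) < int(value)]

user_row = [("user_id", 100), ("seg1", 90), ("seg2", 20),
            ("seg3", 76), ("seg4", 100), ("seg5", 30)]
print(calc_segments(user_row))  # ['seg1', 'seg3', 'seg4']
</code></pre>
<p>If a row of the real output looks wrong, feeding that row through a helper like this is a quick way to tell whether the logic or the test data is at fault.</p>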