将多个列与另一个单列进行比较时,选择立即较小/较大的值

2024-06-01 22:27:32 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个可变的列数,比如说在这个例子中,我们有4列要与一个具有不同值的列(textX)进行比较(id):

d =     [
  {'id':  500, 'text1': 1000 ,'text2': 2000 ,'text3': 3000, 'text4': 5000}, 
  {'id': 1500, 'text1': 1000 ,'text2': 2000 ,'text3': 3000, 'text4': 5000},
  {'id': 2500, 'text1': 1000 ,'text2': 2000 ,'text3': 3000, 'text4': 5000},
  {'id': 3500, 'text1': 1000 ,'text2': 2000 ,'text3': 3000, 'text4': 5000},
  {'id': 4500, 'text1': 1000 ,'text2': 2000 ,'text3': 3000, 'text4': 5000},
  {'id': 5500, 'text1': 1000 ,'text2': 2000 ,'text3': 3000, 'text4': 5000}
] 
data = spark.createDataFrame(d)

我希望根据'id'的值对textX列中的最小值和较大值进行操作。 例如,对于idvalue=2500,我希望对值2000和3000进行操作。如果'id'的值为500,则它将是null和1000。 我尝试将这些列作为附加列,例如获取较低的列值

df_cols = data.columns
thresh_list = [x for x in df_cols if x.startswith('text')]

data.withColumn('inic_th', (col(x) for x in thresh_list if col('id') > col(x)))

但是得到一个错误:

col should be Column

我猜这是因为有多个列与条件匹配,但无法在此处插入

有没有人有办法根据第三列将一个操作转换成2个值,或者如何正确地获得这些边界?实际上,textX列的数量会有所不同。由于性能问题,我将尽可能远离熊猫和UDF


Tags: iniddffordataifcollist
3条回答

下面是使用^{}^{}函数以及^{}表达式的另一种方法:

  • lowerBound=maxthresh_cols该状态为thresh_col < id
  • upperBound=minthresh_cols该状态满足条件{}
from pyspark.sql import functions as F

result = data.withColumn(
    'lowerBound',
    F.array_max(F.array(*[F.when(F.col(c) < F.col('id'), F.col(c)) for c in thresh_cols]))
).withColumn(
    'upperBound',
    F.array_min(F.array(*[F.when(F.col(c) > F.col('id'), F.col(c)) for c in thresh_cols]))
)

result.show()

#+  +  -+  -+  -+  -+     +     +
#|  id|text1|text2|text3|text4|lowerBound|upperBound|
#+  +  -+  -+  -+  -+     +     +
#| 500| 1000| 2000| 3000| 5000|      null|      1000|
#|1500| 1000| 2000| 3000| 5000|      1000|      2000|
#|2500| 1000| 2000| 3000| 5000|      2000|      3000|
#|3500| 1000| 2000| 3000| 5000|      3000|      5000|
#|4500| 1000| 2000| 3000| 5000|      3000|      5000|
#|5500| 1000| 2000| 3000| 5000|      5000|      null|
#+  +  -+  -+  -+  -+     +     +

下面是一种使用spark高阶函数的方法>=2.4:


df_cols = data.columns
thresh_list = [x for x in df_cols if x.startswith('text')]

out = (data.select("*",F.sort_array(F.array(*thresh_list)).alias("Arr"))
.withColumn("FirstVal",F.expr('element_at(filter (Arr, x-> x<id),-1)'))
.withColumn("LastVal",F.expr('filter (Arr, x->x>id)[0]')).drop("Arr")
)

out.show(truncate=False)

+  +  -+  -+  -+  -+    +   -+
|id  |text1|text2|text3|text4|FirstVal|LastVal|
+  +  -+  -+  -+  -+    +   -+
|500 |1000 |2000 |3000 |5000 |null    |1000   |
|1500|1000 |2000 |3000 |5000 |1000    |2000   |
|2500|1000 |2000 |3000 |5000 |2000    |3000   |
|3500|1000 |2000 |3000 |5000 |3000    |5000   |
|4500|1000 |2000 |3000 |5000 |3000    |5000   |
|5500|1000 |2000 |3000 |5000 |5000    |null   |
+  +  -+  -+  -+  -+    +   -+

您可以使用leastgreatest获取相关列:

import pyspark.sql.functions as F

df = data.withColumn(
    'col1',
    F.greatest(*[
        F.when(F.col(c) < F.col('id'), F.col(c))
        for c in data.columns
    ])
).withColumn(
    'col2',
    F.least(*[
        F.when(F.col(c) > F.col('id'), F.col(c))
        for c in data.columns
    ])
)

df.show()
+  +  -+  -+  -+  -+  +  +
|  id|text1|text2|text3|text4|col1|col2|
+  +  -+  -+  -+  -+  +  +
| 500| 1000| 2000| 3000| 5000|null|1000|
|1500| 1000| 2000| 3000| 5000|1000|2000|
|2500| 1000| 2000| 3000| 5000|2000|3000|
|3500| 1000| 2000| 3000| 5000|3000|5000|
|4500| 1000| 2000| 3000| 5000|3000|5000|
|5500| 1000| 2000| 3000| 5000|5000|null|
+  +  -+  -+  -+  -+  +  +

然后你可以对col1col2进行操作

相关问题 更多 >