I have a PySpark DataFrame trips on which I am performing an aggregation. For each PULocationID, I first compute the average of total_amount, then the number of trips, and finally the number of trips whose DOLocationID appears in the DOLocationID column of mtrips, which is another PySpark DataFrame.
I have included the schemas of trips and mtrips below.
My current code is as follows, but it is incomplete:
import pyspark.sql.functions as F
cnt_cond = lambda cond: F.sum(F.when(cond, 1).otherwise(0))
(
trips
.groupBy('PULocationID', 'DOLocationID')
.agg(
F.mean('total_amount').alias('avg_total_amt'),
F.count('*').alias('trip_count'),
cnt_cond(mtrips.DOLocationID.contains(trips.DOLocationID)).alias('trips_to_pop')
)
.show(200)
)
trips.printSchema()
# root
# |-- VendorID: integer (nullable = true)
# |-- tpep_pickup_datetime: timestamp (nullable = true)
# |-- tpep_dropoff_datetime: timestamp (nullable = true)
# |-- passenger_count: integer (nullable = true)
# |-- trip_distance: double (nullable = true)
# |-- RatecodeID: integer (nullable = true)
# |-- store_and_fwd_flag: string (nullable = true)
# |-- PULocationID: integer (nullable = true)
# |-- DOLocationID: integer (nullable = true)
# |-- payment_type: integer (nullable = true)
# |-- fare_amount: double (nullable = true)
# |-- extra: double (nullable = true)
# |-- mta_tax: double (nullable = true)
# |-- tip_amount: double (nullable = true)
# |-- tolls_amount: double (nullable = true)
# |-- improvement_surcharge: double (nullable = true)
# |-- total_amount: double (nullable = true)
# |-- congestion_surcharge: double (nullable = true)
mtrips.printSchema()
# root
# |-- DOLocationID: integer (nullable = true)
# |-- pcount: long (nullable = true)
The following lines of code solve the problem: