如何将多列作为参数传递给pyspark write repartition()

2024-06-15 03:10:15 发布

您现在位置:Python中文网/ 问答频道 /正文

我创建了一个函数,用一些参数将数据帧写到s3位置。除重新分区列表参数外的所有工作。它失败了,错误是:raise TypeError("numPartitions should be an int or Column")

年份是int列,日期是sate类型的列。当我硬编码这两列时 .repartition("parti_create_date", "parti_hour")有效。 我试着把它们作为列表、字符串和列来提供。 似乎什么都没用。你知道吗

parti_list = ["parti_year", "parti_create_date", "parti_hour"]
re_parti_list = ["parti_create_date", "parti_hour"]

def spark_write(in_df, write_tgt_loc, parti_list, re_parti_list, tgt_file_format, write_mode, tgt_file_compression):
(in_df
            .repartition(re_parti_list)  #(re_parti_str)
            .write
            .partitionBy(parti_str)
            .mode(write_mode).format(tgt_file_format)
            .option('compression', tgt_file_compression).option("nullValue", "null").option("treatEmptyValuesAsNulls,", "true")
            .save(write_tgt_loc))

spark_write(tgt_df, "s3://bucket/out/", parti_list, re_parti_list, "parquet", "overwrite","snappy")

你能帮我弄清楚如何在PySpark中将重分区列作为参数传递吗?你知道吗


Tags: reformatdf参数datemodecreatelist
1条回答
网友
1楼 · 发布于 2024-06-15 03:10:15

重新分区需要either int or column,因此我们需要将col("<col_name>")传递给数据帧。你知道吗

Example:

df=spark.createDataFrame([(1,'a',),(2,'b',),(3,'c',)],['id','name'])
df.rdd.getNumPartitions() #number of partitions in df
1

Repartition on int:

df.repartition(10).rdd.getNumPartitions() #repartition to 10 

10

Repartition on columns:

df.repartition(col("id"),col("name")).rdd.getNumPartitions() #repartition on columns

200

Dynamic repartition on columns:

df.repartition(*[col(c) for c in df.columns]).rdd.getNumPartitions()

200

columns list映射到column类型而不是string然后在repartition.中传递列名

For your case try this way:

df.repartition(*[col(c) for c in re_parti_list])
            .write
            .partitionBy(parti_str)
            .mode(write_mode).format(tgt_file_format)
            .option('compression', tgt_file_compression).option("nullValue", "null").option("treatEmptyValuesAsNulls,", "true")
            .save(write_tgt_loc))

In scala:

df.repartition(df.columns.map(c => col(c)):_*).rdd.getNumPartitions
200

相关问题 更多 >