对Spark DataFram中的所有单元格应用函数

2024-10-01 22:43:13 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试转换一些熊猫代码为Spark进行缩放。myfunc是复杂API的包装器,它接受一个字符串并返回一个新字符串(这意味着我不能使用向量化函数)。在

def myfunc(ds):
    for attribute, value in ds.items():
        value = api_function(attribute, value)
        ds[attribute] = value
    return ds

df = df.apply(myfunc, axis='columns')

myfunc获取一个数据系列,将其分解为各个单元格,为每个单元格调用API,并使用相同的列名构建一个新的数据系列。这将有效地修改数据帧中的所有单元格。在

我是Spark新手,我想用pyspark来翻译这个逻辑。我已将pandas数据帧转换为Spark:

^{pr2}$

我迷路了。我需要一个UDF,一个pandas_udf?如何遍历所有单元格并使用myfunc为每个单元格返回一个新字符串?spark_df.foreach()不返回任何内容,也没有map()函数。在

如果需要,我可以将myfuncDataSeries->;DataSeries修改为string->;string。在


Tags: 数据函数字符串代码gtapipandasdf
2条回答

选项1:一次对一列使用UDF

最简单的方法是重写函数,将字符串作为参数(使其成为string->;string)并使用UDF。有一个很好的例子here。这种方法一次只能处理一列。因此,如果您的DataFrame具有合理数量的列,则可以一次将UDF应用于每个列:

from pyspark.sql.functions import col
new_df = df.select(udf(col("col1")), udf(col("col2")), ...)

示例

^{pr2}$

选项2:一次映射整个数据帧

map可用于ScalaDataFrames,但目前在PySpark中不可用。 低级的RDDAPI在PySpark中确实有一个map函数。因此,如果有太多的列无法一次转换一个,可以对DataFrame中的每个单元格进行如下操作:

def map_fn(row):
    return [api_function(x) for (column, x) in row.asDict().items()

column_names = df.columns
new_df = df.rdd.map(map_fn).toDF(df.columns)

示例

df = sc.parallelize([[1, 4], [2,5], [3,6]]).toDF(["col1", "col2"])
def map_fn(row):
   return [value + 1 for (_, value) in row.asDict().items()]

columns = df.columns
new_df = df.rdd.map(map_fn).toDF(columns)
new_df.show()
+  +  +
|col1|col2|
+  +  +
|   2|   5|
|   3|   6|
|   4|   7|
+  +  +

上下文

foreachdocumentation只给出了打印的示例,但是我们可以通过查看code来验证它确实没有返回任何内容。在

你可以在this post中读到pandas_udf,但它似乎最适合于向量化函数,正如你所指出的,由于api_function,你不能使用它。在

解决方案是:

udf_func = udf(func, StringType())
for col_name in spark_df.columns:
    spark_df = spark_df.withColumn(col_name, udf_func(lit(col_name), col_name))
return spark_df.toPandas()

有3个关键见解帮助我解决了这个问题:

  1. 如果将withColumn与现有列的名称(col_name)一起使用,则Spark“overwrites”/会隐藏原始列。这本质上给人一种直接编辑列的感觉,就好像它是可变的一样。在
  2. 通过在原始列之间创建一个循环并重用相同的DataFrame变量spark_df,我使用相同的原理来模拟一个可变的数据帧,创建一个列级转换链,每次“重写”一个列(每个#1-见下文)
  3. SparkUDFs希望所有参数都是Column类型,这意味着它尝试解析每个参数的列值。因为api_function的第一个参数是一个对于向量中所有行都相同的文本值,所以必须使用lit()函数。只要将col_name传递给函数,就会尝试提取该列的列值。据我所知,传递col_name等同于传递{}。在

假设有3列“a”、“b”和“c”,展开此概念如下所示:

^{pr2}$

相关问题 更多 >

    热门问题