如何在PySpark的UDF中返回“元组类型”?

2024-06-23 20:11:19 发布

您现在位置:Python中文网/ 问答频道 /正文

所有data types in ^{} are

__all__ = [
    "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
    "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
    "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"]

我必须编写一个UDF(在pyspark中),它返回一个元组数组。第二个参数是udf方法的返回类型,我该给它什么呢?它应该是ArrayType(TupleType())行上的东西。。。


Tags: indataallaretypesdatatypestringtypedoubletype
3条回答

火花中没有TupleType这样的东西。产品类型用特定类型的字段表示为structs。例如,如果要返回一个成对数组(整数、字符串),可以使用如下架构:

from pyspark.sql.types import *

schema = ArrayType(StructType([
    StructField("char", StringType(), False),
    StructField("count", IntegerType(), False)
]))

示例用法:

from pyspark.sql.functions import udf
from collections import Counter

char_count_udf = udf(
    lambda s: Counter(s).most_common(),
    schema
)

df = sc.parallelize([(1, "foo"), (2, "bar")]).toDF(["id", "value"])

df.select("*", char_count_udf(df["value"])).show(2, False)

## +---+-----+-------------------------+
## |id |value|PythonUDF#<lambda>(value)|
## +---+-----+-------------------------+
## |1  |foo  |[[o,2], [f,1]]           |
## |2  |bar  |[[r,1], [a,1], [b,1]]    |
## +---+-----+-------------------------+

Stackoverflow一直在引导我回答这个问题,所以我想我会在这里添加一些信息。

从UDF返回简单类型:

from pyspark.sql.types import *
from pyspark.sql import functions as F

def get_df():
  d = [(0.0, 0.0), (0.0, 3.0), (1.0, 6.0), (1.0, 9.0)]
  df = sqlContext.createDataFrame(d, ['x', 'y'])
  return df

df = get_df()
df.show()

# +---+---+
# |  x|  y|
# +---+---+
# |0.0|0.0|
# |0.0|3.0|
# |1.0|6.0|
# |1.0|9.0|
# +---+---+

func = udf(lambda x: str(x), StringType())
df = df.withColumn('y_str', func('y'))

func = udf(lambda x: int(x), IntegerType())
df = df.withColumn('y_int', func('y'))

df.show()

# +---+---+-----+-----+
# |  x|  y|y_str|y_int|
# +---+---+-----+-----+
# |0.0|0.0|  0.0|    0|
# |0.0|3.0|  3.0|    3|
# |1.0|6.0|  6.0|    6|
# |1.0|9.0|  9.0|    9|
# +---+---+-----+-----+

df.printSchema()

# root
#  |-- x: double (nullable = true)
#  |-- y: double (nullable = true)
#  |-- y_str: string (nullable = true)
#  |-- y_int: integer (nullable = true)

当整数不够时:

df = get_df()

func = udf(lambda x: [0]*int(x), ArrayType(IntegerType()))
df = df.withColumn('list', func('y'))

func = udf(lambda x: {float(y): str(y) for y in range(int(x))}, 
           MapType(FloatType(), StringType()))
df = df.withColumn('map', func('y'))

df.show()
# +---+---+--------------------+--------------------+
# |  x|  y|                list|                 map|
# +---+---+--------------------+--------------------+
# |0.0|0.0|                  []|               Map()|
# |0.0|3.0|           [0, 0, 0]|Map(2.0 -> 2, 0.0...|
# |1.0|6.0|  [0, 0, 0, 0, 0, 0]|Map(0.0 -> 0, 5.0...|
# |1.0|9.0|[0, 0, 0, 0, 0, 0...|Map(0.0 -> 0, 5.0...|
# +---+---+--------------------+--------------------+

df.printSchema()
# root
#  |-- x: double (nullable = true)
#  |-- y: double (nullable = true)
#  |-- list: array (nullable = true)
#  |    |-- element: integer (containsNull = true)
#  |-- map: map (nullable = true)
#  |    |-- key: float
#  |    |-- value: string (valueContainsNull = true)

从UDF返回复杂数据类型:

df = get_df()
df = df.groupBy('x').agg(F.collect_list('y').alias('y[]'))
df.show()

# +---+----------+
# |  x|       y[]|
# +---+----------+
# |0.0|[0.0, 3.0]|
# |1.0|[9.0, 6.0]|
# +---+----------+

schema = StructType([
    StructField("min", FloatType(), True),
    StructField("size", IntegerType(), True),
    StructField("edges",  ArrayType(FloatType()), True),
    StructField("val_to_index",  MapType(FloatType(), IntegerType()), True)
    # StructField('insanity', StructType([StructField("min_", FloatType(), True), StructField("size_", IntegerType(), True)]))

])

def func(values):
  mn = min(values)
  size = len(values)
  lst = sorted(values)[::-1]
  val_to_index = {x: i for i, x in enumerate(values)}
  return (mn, size, lst, val_to_index)

func = udf(func, schema)
dff = df.select('*', func('y[]').alias('complex_type'))
dff.show(10, False)

# +---+----------+------------------------------------------------------+
# |x  |y[]       |complex_type                                          |
# +---+----------+------------------------------------------------------+
# |0.0|[0.0, 3.0]|[0.0,2,WrappedArray(3.0, 0.0),Map(0.0 -> 0, 3.0 -> 1)]|
# |1.0|[6.0, 9.0]|[6.0,2,WrappedArray(9.0, 6.0),Map(9.0 -> 1, 6.0 -> 0)]|
# +---+----------+------------------------------------------------------+

dff.printSchema()

# +---+----------+------------------------------------------------------+
# |x  |y[]       |complex_type                                          |
# +---+----------+------------------------------------------------------+
# |0.0|[0.0, 3.0]|[0.0,2,WrappedArray(3.0, 0.0),Map(0.0 -> 0, 3.0 -> 1)]|
# |1.0|[6.0, 9.0]|[6.0,2,WrappedArray(9.0, 6.0),Map(9.0 -> 1, 6.0 -> 0)]|
# +---+----------+------------------------------------------------------+

向UDF传递多个参数:

df = get_df()
func = udf(lambda arr: arr[0]*arr[1],FloatType())
df = df.withColumn('x*y', func(F.array('x', 'y')))

    # +---+---+---+
    # |  x|  y|x*y|
    # +---+---+---+
    # |0.0|0.0|0.0|
    # |0.0|3.0|0.0|
    # |1.0|6.0|6.0|
    # |1.0|9.0|9.0|
    # +---+---+---+

这段代码纯粹是为了演示目的,上面所有的转换都可以在Spark代码中使用,并且会产生更好的性能。 正如上面注释中的@zero323,在pyspark中通常应该避免udf;返回复杂类型应该使您考虑简化逻辑。

对于scala版本而不是python。 版本2.4

import org.apache.spark.sql.types._

val testschema : StructType = StructType(
    StructField("number", IntegerType) ::
    StructField("Array",  ArrayType(StructType(StructField("cnt_rnk", IntegerType) :: StructField("comp", StringType) :: Nil))) :: 
    StructField("comp", StringType):: Nil)

树的结构是这样的。

testschema.printTreeString
root
 |-- number: integer (nullable = true)
 |-- Array: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- cnt_rnk: integer (nullable = true)
 |    |    |-- corp_id: string (nullable = true)
 |-- comp: string (nullable = true)

相关问题 更多 >

    热门问题