Bucketing and one-hot encoding in PySpark

Posted 2024-09-29 17:44:15


I have a PySpark DataFrame with the following columns:

id        Age
1         30
2         25
3         21

I have the following age bucket boundaries: [20, 24, 27, 30].

My expected result:

id    Age    age_bucket     age_27_30     age_24_27   age_20_24
1     30      (27-30]           1            0           0
2     25      (24-27]           0            1           0
3     21      (20-24]           0            0           1

My current code:

from pyspark.ml.feature import Bucketizer

# Each age becomes a double bucket index (0.0, 1.0, 2.0) in the age_bucket column
bucketizer = Bucketizer(splits=[20, 24, 27, 30], inputCol="Age", outputCol="age_bucket")
df1 = bucketizer.setHandleInvalid("keep").transform(df)

2 Answers

Use OneHotEncoderEstimator() (in Spark 3.0 and later it was renamed OneHotEncoder):

spark.version
'2.4.3'

df = spark.createDataFrame([(1, 30), (2, 25), (3, 21),],["id", "age"])

# buckets
from pyspark.ml.feature import Bucketizer

bucketizer = Bucketizer(splits=[20,24,27,30],inputCol="age", outputCol="age_bucket", handleInvalid="keep")
buckets = bucketizer.transform(df)

buckets.show()
+---+---+----------+
| id|age|age_bucket|
+---+---+----------+
|  1| 30|       2.0|
|  2| 25|       1.0|
|  3| 21|       0.0|
+---+---+----------+
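The bucket indices above follow Bucketizer's documented rule: bucket i covers [splits[i], splits[i+1]), and the last bucket also includes the top split. Values outside the splits are always treated as errors; handleInvalid only controls what happens to NaN values. Below is a minimal plain-Python sketch of that rule (not Spark's implementation; `bucket_index` is a hypothetical helper). Because the intervals are left-closed, a label such as "(24-27]" in the expected output only agrees with Bucketizer when no age falls exactly on an interior split like 24 or 27.

```python
from bisect import bisect_right
from typing import Sequence


def bucket_index(value: float, splits: Sequence[float]) -> int:
    """Plain-Python model of Bucketizer's bucketing rule (hypothetical helper).

    Bucket i holds [splits[i], splits[i + 1]); the last bucket also holds splits[-1].
    """
    if value < splits[0] or value > splits[-1]:
        # Bucketizer treats values outside the splits as errors.
        raise ValueError(f"{value} is outside the splits {list(splits)}")
    if value == splits[-1]:
        # The upper bound belongs to the last bucket.
        return len(splits) - 2
    # bisect_right counts the splits <= value; subtract 1 to get the bucket index.
    return bisect_right(splits, value) - 1


splits = [20, 24, 27, 30]
print([bucket_index(age, splits) for age in (30, 25, 21)])
print(bucket_index(24, splits))  # 24 lands in [24, 27), i.e. bucket 1
```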

# ohe
from pyspark.ml.feature import OneHotEncoderEstimator

encoder = OneHotEncoderEstimator(inputCols=["age_bucket"], outputCols=["age_ohe"])

model = encoder.fit(buckets)
transform_model = model.transform(buckets)

transform_model.show()
+---+---+----------+-------------+
| id|age|age_bucket|      age_ohe|
+---+---+----------+-------------+
|  1| 30|       2.0|    (2,[],[])|
|  2| 25|       1.0|(2,[1],[1.0])|
|  3| 21|       0.0|(2,[0],[1.0])|
+---+---+----------+-------------+
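The age_ohe column is a sparse vector printed as (size,[indices],[values]). OneHotEncoderEstimator defaults to dropLast=True, so with three buckets each vector has size 2 and the last bucket (age 30) becomes the all-zero vector (2,[],[]). Passing dropLast=False to the encoder keeps one slot per bucket. The plain-Python sketch below models only that encoding rule; `one_hot` is a hypothetical helper, not a Spark API.

```python
from typing import List, Tuple

# (size, indices, values), mirroring how Spark prints a SparseVector
SparseVector = Tuple[int, List[int], List[float]]


def one_hot(index: int, num_categories: int, drop_last: bool = True) -> SparseVector:
    """Model of one-hot encoding a category index as a sparse vector (hypothetical helper)."""
    if not 0 <= index < num_categories:
        raise ValueError(f"index {index} is not in [0, {num_categories})")
    size = num_categories - 1 if drop_last else num_categories
    if index < size:
        return (size, [index], [1.0])
    # With drop_last=True the final category has no slot, so it encodes as all zeros.
    return (size, [], [])


for bucket in (2, 1, 0):
    print(one_hot(bucket, 3), one_hot(bucket, 3, drop_last=False))
```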

# wrap it up in a pipeline if you want
from pyspark.ml import Pipeline

bucketizer = Bucketizer(splits=[20,24,27,30], inputCol="age", outputCol="age_bucket")
encoder = OneHotEncoderEstimator(inputCols=["age_bucket"], outputCols=["age_ohe"])

pipeline = Pipeline(stages=[bucketizer, encoder])

model = pipeline.fit(df)
fe = model.transform(df)

fe.show()
+---+---+----------+-------------+
| id|age|age_bucket|      age_ohe|
+---+---+----------+-------------+
|  1| 30|       2.0|    (2,[],[])|
|  2| 25|       1.0|(2,[1],[1.0])|
|  3| 21|       0.0|(2,[0],[1.0])|
+---+---+----------+-------------+

If you want exactly the output shown in your question, OneHotEncoderEstimator won't get you there without some additional mapping tricks.

I would use a join here:

age_buckets = [20, 24, 27, 30]
bins = list(zip(age_buckets, age_buckets[1:]))

data = [[i] + ['({0}-{1}]'.format(*bin_endpoints)] + [0] * i + [1] + [0] * (len(bins) - i - 1) 
        for i, bin_endpoints in enumerate(bins)]
schema = ', '.join('age_bucket_{}_{}: int'.format(start, end) 
                   for start, end in zip(age_buckets, age_buckets[1:]))

join_df = spark.createDataFrame(data, 'age_bucket: int, age_bucket_string: string, ' + schema)

result = (df1.join(join_df, on='age_bucket', how='left')
             .drop('age_bucket')
             .withColumnRenamed('age_bucket_string', 'age_bucket')
             .orderBy('id'))
result.show()

Output:

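+---+---+----------+----------------+----------------+----------------+
| id|Age|age_bucket|age_bucket_20_24|age_bucket_24_27|age_bucket_27_30|
+---+---+----------+----------------+----------------+----------------+
|  1| 30|   (27-30]|               0|               0|               1|
|  2| 25|   (24-27]|               0|               1|               0|
|  3| 21|   (20-24]|               1|               0|               0|
+---+---+----------+----------------+----------------+----------------+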
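The join works because join_df is a small lookup table keyed by bucket index: row i carries the label string plus a 0/1 flag for each interval. The sketch below rebuilds that lookup in plain Python and applies it to the (id, Age, age_bucket) rows produced by the Bucketizer step, mimicking the left join without Spark. `build_lookup` and `apply_lookup` are hypothetical helpers for illustration. The flag columns come out lowest interval first, as in the join code, whereas the question lists them highest first.

```python
from typing import Dict, List, Sequence, Tuple

Lookup = Dict[int, Tuple[str, List[int]]]


def build_lookup(age_buckets: Sequence[int]) -> Tuple[List[str], Lookup]:
    """Return the flag column names and a bucket-index -> (label, flags) mapping."""
    bins = list(zip(age_buckets, age_buckets[1:]))
    columns = ["age_bucket_{}_{}".format(lo, hi) for lo, hi in bins]
    lookup: Lookup = {}
    for i, (lo, hi) in enumerate(bins):
        flags = [0] * len(bins)
        flags[i] = 1  # exactly one flag is set per bucket
        lookup[i] = ("({}-{}]".format(lo, hi), flags)
    return columns, lookup


def apply_lookup(rows, columns: List[str], lookup: Lookup):
    """Mimic the left join on age_bucket: swap the index for its label and flags.

    Indices with no match get None, like the nulls a Spark left join produces.
    """
    missing = (None, [None] * len(columns))
    result = []
    for row_id, age, bucket in rows:
        label, flags = lookup.get(int(bucket), missing)
        result.append((row_id, age, label, *flags))
    return result


columns, lookup = build_lookup([20, 24, 27, 30])
bucketed = [(1, 30, 2.0), (2, 25, 1.0), (3, 21, 0.0)]  # (id, Age, age_bucket)
print(["id", "Age", "age_bucket", *columns])
for row in apply_lookup(bucketed, columns, lookup):
    print(row)
```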
