from pyspark.sql import Row, SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import *
from pyspark.sql.functions import explode
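
# Break a weight into full chunks of 10.0 plus any remainder,
# e.g. explode_col(25.0) -> [10.0, 10.0, 5.0].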
def explode_col(weight):
    return int(weight // 10) * [10.0] + ([] if weight % 10 == 0 else [weight % 10])

spark = SparkSession.builder.getOrCreate()
13 StructField("feature_1", FloatType()),
14 StructField("feature_2", FloatType()),
15 StructField("bias_weight", FloatType())
20 Row(0.32, 1.43, 12.8),

df = spark.createDataFrame(spark.sparkContext.parallelize(data), StructType(dataSchema))
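
# Rescale bias_weight so that the column sums to normalizing_constant.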
normalizing_constant = 100
sum_bias_weight = df.select(F.sum('bias_weight')).collect()[0][0]
normalizing_factor = normalizing_constant / sum_bias_weight
df = df.withColumn('normalized_bias_weight', df.bias_weight * normalizing_factor)
df = df.drop('bias_weight')
df = df.withColumnRenamed('normalized_bias_weight', 'bias_weight')
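
# Map each weight to an array of <=10.0 chunks, then explode the array
# into one row per chunk.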
my_udf = udf(lambda x: explode_col(x), ArrayType(FloatType()))
df1 = df.withColumn('explode_val', my_udf(df.bias_weight))
df1 = df1.withColumn("explode_val_1", explode(df1.explode_val)).drop("explode_val")
df1 = df1.drop('bias_weight').withColumnRenamed('explode_val_1', 'bias_weight')
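
# With the full data list the exploded frame has 12 rows; the single row
# shown above would by itself normalize to 100.0 and yield only 10 rows.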
assert df1.count() == 12