# nixos/tests/spark/spark_sample.py
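# PySpark smoke test: build a small DataFrame, normalize the bias_weight column,
# then explode each weight into multiple rows and check the resulting row count.
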
from pyspark.sql import Row, SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import explode, udf
from pyspark.sql.types import ArrayType, FloatType, StructField, StructType
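
# Split a weight into chunks of 10.0 plus any nonzero remainder, e.g. 12.8 -> [10.0, 2.8].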
def explode_col(weight):
    return int(weight // 10) * [10.0] + ([] if weight % 10 == 0 else [weight % 10])

spark = SparkSession.builder.getOrCreate()
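
# Schema and sample rows for the test DataFrame.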
dataSchema = [
    StructField("feature_1", FloatType()),
    StructField("feature_2", FloatType()),
    StructField("bias_weight", FloatType())
]

data = [
    Row(0.1, 0.2, 10.32),
    Row(0.32, 1.43, 12.8),
    Row(1.28, 1.12, 0.23)
]

df = spark.createDataFrame(spark.sparkContext.parallelize(data), StructType(dataSchema))
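
# Rescale bias_weight so that the column sums to normalizing_constant (100).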
normalizing_constant = 100
sum_bias_weight = df.select(F.sum('bias_weight')).collect()[0][0]
normalizing_factor = normalizing_constant / sum_bias_weight
df = df.withColumn('normalized_bias_weight', df.bias_weight * normalizing_factor)
df = df.drop('bias_weight')
df = df.withColumnRenamed('normalized_bias_weight', 'bias_weight')
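
# Expand each normalized bias_weight into one row per element returned by explode_col,
# via the UDF and explode().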
my_udf = udf(lambda x: explode_col(x), ArrayType(FloatType()))
df1 = df.withColumn('explode_val', my_udf(df.bias_weight))
df1 = df1.withColumn("explode_val_1", explode(df1.explode_val)).drop("explode_val")
df1 = df1.drop('bias_weight').withColumnRenamed('explode_val_1', 'bias_weight')

df1.show()
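
# The three normalized weights (about 44.2, 54.8, and 1.0) explode into 5 + 6 + 1 = 12 rows.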
assert df1.count() == 12