# Bootstrap a local Spark environment: findspark.init must point at the
# Spark install BEFORE `import pyspark` can resolve.
import findspark
findspark.init('/Users/ryanshin/Downloads/spark-2.3.1-bin-hadoop2.7')
import pyspark
# One local SparkContext for the whole session.
sc = pyspark.SparkContext()
from pyspark.sql.session import SparkSession
# NOTE(review): this star-import shadows builtins such as `sum`, `min`,
# `max` with the Spark SQL column functions of the same names.
from pyspark.sql.functions import *
from pyspark.sql.types import *
def age_group_func(age):
    """Bucket an age in years into its decade: 22 -> 2, 38 -> 3.

    int() truncates toward zero, so the -1 "unknown age" sentinel used
    later lands in bucket 0 together with infants.
    """
    decade = age / 10
    return int(decade)
# Wrap the bucketing function as a Spark SQL UDF returning IntegerType.
age_group = udf(age_group_func, returnType=IntegerType())
# SparkSession built on top of the existing SparkContext.
spark = SparkSession(sc)
# Load the Titanic training set and normalise it for ML: cast the numeric
# feature columns (CSV delivers everything as strings), impute missing
# Embarked/Age values, and derive an age-decade bucket via the UDF.
titanic = (
    spark.read.option("header", "true")
    .csv("/Users/ryanshin/Downloads/train.csv")
    # Target column, duplicated under the conventional name "label".
    .withColumn("Survived", col("Survived").cast("double"))
    .withColumn("label", col("Survived"))
    .withColumn("Pclass", col("Pclass").cast("double"))
    .withColumn("SibSp", col("SibSp").cast("double"))
    .withColumn("Parch", col("Parch").cast("double"))
    .withColumn("Fare", col("Fare").cast("double"))
    .withColumn("Age", col("Age").cast("int"))
    # Imputation: unknown embarkation port -> "S", unknown age -> -1 sentinel.
    .na.fill("S", "Embarked")
    .na.fill(-1, "Age")
    .withColumn("age_group", age_group(col("Age")))
)
# Inspect the schema after the casts and fills above.
titanic.printSchema()
root |-- PassengerId: string (nullable = true) |-- Survived: double (nullable = true) |-- Pclass: double (nullable = true) |-- Name: string (nullable = true) |-- Sex: string (nullable = true) |-- Age: integer (nullable = true) |-- SibSp: double (nullable = true) |-- Parch: double (nullable = true) |-- Ticket: string (nullable = true) |-- Fare: double (nullable = true) |-- Cabin: string (nullable = true) |-- Embarked: string (nullable = false) |-- label: double (nullable = true) |-- age_group: integer (nullable = true)
# Peek at the first 20 rows after casting/imputation.
titanic.show()
+-----------+--------+------+--------------------+------+---+-----+-----+----------------+-------+-----+--------+-----+---------+ |PassengerId|Survived|Pclass| Name| Sex|Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|label|age_group| +-----------+--------+------+--------------------+------+---+-----+-----+----------------+-------+-----+--------+-----+---------+ | 1| 0.0| 3.0|Braund, Mr. Owen ...| male| 22| 1.0| 0.0| A/5 21171| 7.25| null| S| 0.0| 2| | 2| 1.0| 1.0|Cumings, Mrs. Joh...|female| 38| 1.0| 0.0| PC 17599|71.2833| C85| C| 1.0| 3| | 3| 1.0| 3.0|Heikkinen, Miss. ...|female| 26| 0.0| 0.0|STON/O2. 3101282| 7.925| null| S| 1.0| 2| | 4| 1.0| 1.0|Futrelle, Mrs. Ja...|female| 35| 1.0| 0.0| 113803| 53.1| C123| S| 1.0| 3| | 5| 0.0| 3.0|Allen, Mr. Willia...| male| 35| 0.0| 0.0| 373450| 8.05| null| S| 0.0| 3| | 6| 0.0| 3.0| Moran, Mr. James| male| -1| 0.0| 0.0| 330877| 8.4583| null| Q| 0.0| 0| | 7| 0.0| 1.0|McCarthy, Mr. Tim...| male| 54| 0.0| 0.0| 17463|51.8625| E46| S| 0.0| 5| | 8| 0.0| 3.0|Palsson, Master. ...| male| 2| 3.0| 1.0| 349909| 21.075| null| S| 0.0| 0| | 9| 1.0| 3.0|Johnson, Mrs. Osc...|female| 27| 0.0| 2.0| 347742|11.1333| null| S| 1.0| 2| | 10| 1.0| 2.0|Nasser, Mrs. Nich...|female| 14| 1.0| 0.0| 237736|30.0708| null| C| 1.0| 1| | 11| 1.0| 3.0|Sandstrom, Miss. ...|female| 4| 1.0| 1.0| PP 9549| 16.7| G6| S| 1.0| 0| | 12| 1.0| 1.0|Bonnell, Miss. El...|female| 58| 0.0| 0.0| 113783| 26.55| C103| S| 1.0| 5| | 13| 0.0| 3.0|Saundercock, Mr. ...| male| 20| 0.0| 0.0| A/5. 2151| 8.05| null| S| 0.0| 2| | 14| 0.0| 3.0|Andersson, Mr. An...| male| 39| 1.0| 5.0| 347082| 31.275| null| S| 0.0| 3| | 15| 0.0| 3.0|Vestrom, Miss. Hu...|female| 14| 0.0| 0.0| 350406| 7.8542| null| S| 0.0| 1| | 16| 1.0| 2.0|Hewlett, Mrs. (Ma...|female| 55| 0.0| 0.0| 248706| 16.0| null| S| 1.0| 5| | 17| 0.0| 3.0|Rice, Master. Eugene| male| 2| 4.0| 1.0| 382652| 29.125| null| Q| 0.0| 0| | 18| 1.0| 2.0|Williams, Mr. 
Cha...| male| -1| 0.0| 0.0| 244373| 13.0| null| S| 1.0| 0| | 19| 0.0| 3.0|Vander Planke, Mr...|female| 31| 1.0| 0.0| 345763| 18.0| null| S| 0.0| 3| | 20| 1.0| 3.0|Masselmani, Mrs. ...|female| -1| 0.0| 0.0| 2649| 7.225| null| C| 1.0| 0| +-----------+--------+------+--------------------+------+---+-----+-----+----------------+-------+-----+--------+-----+---------+ only showing top 20 rows
# Total number of training rows (891 per the output below).
titanic.count()
891
# Baseline 1: predict that every passenger died.
def predict1_func(dummy):
    """Constant baseline — always return 0.0 (did not survive)."""
    return 0.0
# Register baseline 1 as a UDF producing DoubleType predictions.
predict1 = udf(predict1_func, returnType=DoubleType())
# Baseline 2: predict all women survived and all men died.
def predict2_func(gender):
    """Return 1.0 when gender == "female", else 0.0."""
    return 1.0 if gender == "female" else 0.0
# Register baseline 2 as a UDF producing DoubleType predictions.
predict2 = udf(predict2_func, returnType=DoubleType())
# Apply the UDFs: build (prediction, label) frames for each baseline,
# in the column shape BinaryClassificationEvaluator expects.
prediction1result = titanic.select(predict1("Sex").alias("prediction"), col("Survived").cast("double").alias("label"))
prediction2result = titanic.select(predict2("Sex").alias("prediction"), col("Survived").cast("double").alias("label"))
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# One shared evaluator; the metric is reconfigured before each evaluation.
evaluator = BinaryClassificationEvaluator()
evaluator.setRawPredictionCol("prediction")
evaluator.setLabelCol("label")

# Area under the ROC curve for both baselines.
evaluator.setMetricName("areaUnderROC")
print("prediction1result areaUnderROC=%f" % evaluator.evaluate(prediction1result))
print("prediction2result areaUnderROC=%f" % evaluator.evaluate(prediction2result))

# Area under the precision-recall curve for both baselines.
evaluator.setMetricName("areaUnderPR")
print("prediction1result areaUnderPR=%f" % evaluator.evaluate(prediction1result))
print("prediction2result areaUnderPR=%f" % evaluator.evaluate(prediction2result))
prediction1result areaUnderROC=0.500000 prediction2result areaUnderROC=0.766873 prediction1result areaUnderPR=0.383838 prediction2result areaUnderPR=0.684957
# NOTE(review): the 'accuracy' boolean column is added but never read —
# the next line filters the frame directly instead of using it.
prediction1result = prediction1result.withColumn('accuracy',col('prediction')== col('label'))
# Keep only the correctly-classified rows, then divide by the total row
# count to get baseline 1's accuracy.
prediction1result = prediction1result.where(prediction1result.prediction==prediction1result.label)
prediction1result.count() / float(titanic.count())
0.6161616161616161
# NOTE(review): as above, the 'accuracy' column is computed but unused.
prediction2result = prediction2result.withColumn('accuracy',col('prediction')== col('label'))
# Baseline 2 ("gender model") accuracy: correct rows / total rows.
prediction2result = prediction2result.where(prediction2result.prediction==prediction2result.label)
prediction2result.count() / float(titanic.count())
0.7867564534231201
# Survival broken down by age decade and sex: survivor total, group size,
# and survival ratio per (age_group, Sex) cell. `sum`/`count` here are the
# Spark SQL aggregate functions brought in by the star-import above.
titanic.withColumn("age_group", age_group(col("Age"))).groupBy("age_group", "Sex") \
.agg(sum("Survived"), count("Survived"), sum("Survived")/count("Survived")) \
.orderBy("Sex", "age_group").show()
+---------+------+-------------+---------------+---------------------------------+ |age_group| Sex|sum(Survived)|count(Survived)|(sum(Survived) / count(Survived))| +---------+------+-------------+---------------+---------------------------------+ | 0|female| 55.0| 83| 0.6626506024096386| | 1|female| 34.0| 45| 0.7555555555555555| | 2|female| 52.0| 72| 0.7222222222222222| | 3|female| 50.0| 60| 0.8333333333333334| | 4|female| 22.0| 32| 0.6875| | 5|female| 16.0| 18| 0.8888888888888888| | 6|female| 4.0| 4| 1.0| | 0| male| 35.0| 156| 0.22435897435897437| | 1| male| 7.0| 57| 0.12280701754385964| | 2| male| 25.0| 148| 0.16891891891891891| | 3| male| 23.0| 107| 0.21495327102803738| | 4| male| 12.0| 57| 0.21052631578947367| | 5| male| 4.0| 30| 0.13333333333333333| | 6| male| 2.0| 15| 0.13333333333333333| | 7| male| 0.0| 6| 0.0| | 8| male| 1.0| 1| 1.0| +---------+------+-------------+---------------+---------------------------------+
# 70/30 train/test split. A fixed seed makes the split — and therefore the
# metrics reported below — reproducible across runs; the unseeded call
# produced a different split (and different numbers) on every execution.
training_data, test_data = titanic.randomSplit([0.7, 0.3], seed=42)
from pyspark.ml.classification import *
from pyspark.ml.feature import *
# Train a RandomForest model.
# 20 trees; all other hyperparameters are left at Spark defaults.
rf = RandomForestClassifier() \
.setNumTrees(20)
# Index the categorical string columns, fitting on the full dataset so the
# held-out split cannot contain labels unseen by the indexer.
sex_indexer = StringIndexer().setInputCol("Sex").setOutputCol("sex_idx").fit(titanic)
# NOTE(review): 'embark_idx' is produced by this stage but is NOT listed in
# the assembler's input columns below, so Embarked never reaches the model.
embark_indexer = StringIndexer().setInputCol("Embarked").setOutputCol("embark_idx").fit(titanic)
# Assemble the chosen numeric features into a single 'features' vector.
assembler = VectorAssembler() \
.setInputCols(["Pclass", "SibSp", "Parch", "age_group", "sex_idx"]) \
.setOutputCol("features")
from pyspark.ml import Pipeline
# Chain indexers and forest in a Pipeline.
pipeline = Pipeline().setStages([sex_indexer, embark_indexer, assembler, rf])
# Fit on the 70% split, then score the held-out 30%.
pipeline_model = pipeline.fit(training_data)
prediction4result = pipeline_model.transform(test_data)
# NOTE(review): the evaluator was configured (L47) with rawPredictionCol =
# "prediction", so AUC is computed from hard 0/1 predictions rather than
# the forest's probability scores — this understates the true areaUnderROC.
evaluator.setMetricName("areaUnderROC")
print(evaluator.evaluate(prediction4result))
print(" ")
evaluator.setMetricName("areaUnderPR")
print(evaluator.evaluate(prediction4result))
0.7497701149425287 0.7079536774334352
# Load the Kaggle test split and apply the same type casts, imputations,
# and age-bucket derivation that were used for the training data (no
# Survived/label columns exist in this file).
titanic_test = (
    spark.read.option("header", "true")
    .csv("/Users/ryanshin/Downloads/test.csv")
    .withColumn("Pclass", col("Pclass").cast("double"))
    .withColumn("SibSp", col("SibSp").cast("double"))
    .withColumn("Parch", col("Parch").cast("double"))
    .withColumn("Fare", col("Fare").cast("double"))
    .withColumn("Age", col("Age").cast("int"))
    # Same imputation as training: Embarked -> "S", Age -> -1 sentinel.
    .na.fill("S", "Embarked")
    .na.fill(-1, "Age")
    .withColumn("age_group", age_group(col("Age")))
)
titanic_test.show()
+-----------+------+--------------------+------+---+-----+-----+----------------+-------+-----+--------+---------+ |PassengerId|Pclass| Name| Sex|Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|age_group| +-----------+------+--------------------+------+---+-----+-----+----------------+-------+-----+--------+---------+ | 892| 3.0| Kelly, Mr. James| male| 34| 0.0| 0.0| 330911| 7.8292| null| Q| 3| | 893| 3.0|Wilkes, Mrs. Jame...|female| 47| 1.0| 0.0| 363272| 7.0| null| S| 4| | 894| 2.0|Myles, Mr. Thomas...| male| 62| 0.0| 0.0| 240276| 9.6875| null| Q| 6| | 895| 3.0| Wirz, Mr. Albert| male| 27| 0.0| 0.0| 315154| 8.6625| null| S| 2| | 896| 3.0|Hirvonen, Mrs. Al...|female| 22| 1.0| 1.0| 3101298|12.2875| null| S| 2| | 897| 3.0|Svensson, Mr. Joh...| male| 14| 0.0| 0.0| 7538| 9.225| null| S| 1| | 898| 3.0|Connolly, Miss. Kate|female| 30| 0.0| 0.0| 330972| 7.6292| null| Q| 3| | 899| 2.0|Caldwell, Mr. Alb...| male| 26| 1.0| 1.0| 248738| 29.0| null| S| 2| | 900| 3.0|Abrahim, Mrs. Jos...|female| 18| 0.0| 0.0| 2657| 7.2292| null| C| 1| | 901| 3.0|Davies, Mr. John ...| male| 21| 2.0| 0.0| A/4 48871| 24.15| null| S| 2| | 902| 3.0| Ilieff, Mr. Ylio| male| -1| 0.0| 0.0| 349220| 7.8958| null| S| 0| | 903| 1.0|Jones, Mr. Charle...| male| 46| 0.0| 0.0| 694| 26.0| null| S| 4| | 904| 1.0|Snyder, Mrs. John...|female| 23| 1.0| 0.0| 21228|82.2667| B45| S| 2| | 905| 2.0|Howard, Mr. Benjamin| male| 63| 1.0| 0.0| 24065| 26.0| null| S| 6| | 906| 1.0|Chaffee, Mrs. Her...|female| 47| 1.0| 0.0| W.E.P. 5734| 61.175| E31| S| 4| | 907| 2.0|del Carlo, Mrs. S...|female| 24| 1.0| 0.0| SC/PARIS 2167|27.7208| null| C| 2| | 908| 2.0| Keane, Mr. Daniel| male| 35| 0.0| 0.0| 233734| 12.35| null| Q| 3| | 909| 3.0| Assaf, Mr. Gerios| male| 21| 0.0| 0.0| 2692| 7.225| null| C| 2| | 910| 3.0|Ilmakangas, Miss....|female| 27| 1.0| 0.0|STON/O2. 
3101270| 7.925| null| S| 2| | 911| 3.0|"Assaf Khalil, Mr...|female| 45| 0.0| 0.0| 2696| 7.225| null| C| 4| +-----------+------+--------------------+------+---+-----+-----+----------------+-------+-----+--------+---------+ only showing top 20 rows
# Kaggle submission frame: PassengerId plus the model's 0/1 prediction.
# Cast BEFORE aliasing so the output column is explicitly named "Survived";
# the original applied .cast() after .alias(), which only kept the name via
# Spark's implicit re-aliasing of Cast expressions to their child's name.
result = pipeline_model.transform(titanic_test).select(col('PassengerId'), col('prediction').cast('int').alias("Survived"))
result.show()
+-----------+--------+ |PassengerId|Survived| +-----------+--------+ | 892| 0| | 893| 0| | 894| 0| | 895| 0| | 896| 0| | 897| 0| | 898| 0| | 899| 0| | 900| 1| | 901| 0| | 902| 0| | 903| 0| | 904| 1| | 905| 0| | 906| 1| | 907| 1| | 908| 0| | 909| 0| | 910| 1| | 911| 0| +-----------+--------+ only showing top 20 rows
%%bash
# Remove any previous output directory so Spark's CSV writer can create it.
rm -rf answer
# Repartition to one partition so the submission lands in a single part file.
result.repartition(1).write.csv('answer', header='true')
%%bash
# Inspect the generated submission file.
cat answer/part*
PassengerId,Survived 892,0 893,0 894,0 895,0 896,0 897,0 898,0 899,0 900,1 901,0 902,0 903,0 904,1 905,0 906,1 907,1 908,0 909,0 910,1 911,0 912,0 913,1 914,1 915,0 916,1 917,0 918,1 919,0 920,0 921,0 922,0 923,0 924,0 925,0 926,1 927,0 928,1 929,0 930,0 931,0 932,0 933,0 934,0 935,1 936,1 937,0 938,0 939,0 940,1 941,0 942,1 943,0 944,1 945,1 946,0 947,0 948,0 949,0 950,0 951,1 952,0 953,0 954,0 955,0 956,0 957,1 958,1 959,0 960,0 961,1 962,0 963,0 964,0 965,0 966,1 967,0 968,0 969,1 970,0 971,0 972,1 973,0 974,0 975,0 976,0 977,0 978,0 979,1 980,1 981,1 982,1 983,0 984,1 985,0 986,0 987,0 988,1 989,0 990,0 991,0 992,1 993,0 994,0 995,0 996,0 997,0 998,0 999,0 1000,0 1001,0 1002,0 1003,1 1004,1 1005,1 1006,1 1007,0 1008,0 1009,1 1010,0 1011,1 1012,1 1013,0 1014,1 1015,0 1016,0 1017,1 1018,0 1019,1 1020,0 1021,0 1022,0 1023,0 1024,0 1025,0 1026,0 1027,0 1028,0 1029,0 1030,0 1031,0 1032,0 1033,1 1034,0 1035,0 1036,0 1037,0 1038,0 1039,0 1040,0 1041,0 1042,1 1043,0 1044,0 1045,0 1046,0 1047,0 1048,1 1049,0 1050,0 1051,0 1052,1 1053,1 1054,1 1055,0 1056,0 1057,0 1058,0 1059,0 1060,1 1061,0 1062,0 1063,0 1064,0 1065,0 1066,0 1067,1 1068,1 1069,0 1070,1 1071,1 1072,0 1073,1 1074,1 1075,0 1076,1 1077,0 1078,1 1079,0 1080,0 1081,0 1082,0 1083,0 1084,0 1085,0 1086,1 1087,0 1088,1 1089,1 1090,0 1091,1 1092,1 1093,1 1094,1 1095,1 1096,0 1097,0 1098,0 1099,0 1100,1 1101,0 1102,0 1103,0 1104,0 1105,1 1106,0 1107,0 1108,1 1109,0 1110,1 1111,0 1112,1 1113,0 1114,1 1115,0 1116,1 1117,1 1118,0 1119,1 1120,0 1121,0 1122,0 1123,1 1124,0 1125,0 1126,1 1127,0 1128,0 1129,0 1130,1 1131,1 1132,1 1133,1 1134,1 1135,0 1136,1 1137,1 1138,1 1139,0 1140,1 1141,1 1142,1 1143,0 1144,1 1145,0 1146,0 1147,0 1148,0 1149,0 1150,1 1151,0 1152,0 1153,0 1154,1 1155,1 1156,0 1157,0 1158,0 1159,0 1160,1 1161,0 1162,0 1163,0 1164,1 1165,1 1166,0 1167,1 1168,0 1169,0 1170,0 1171,0 1172,0 1173,1 1174,1 1175,1 1176,1 1177,0 1178,0 1179,1 1180,0 1181,0 1182,0 1183,0 1184,0 1185,0 1186,0 1187,0 1188,1 1189,0 
1190,0 1191,0 1192,0 1193,0 1194,0 1195,0 1196,1 1197,1 1198,1 1199,1 1200,0 1201,0 1202,0 1203,0 1204,0 1205,0 1206,1 1207,1 1208,0 1209,0 1210,0 1211,0 1212,0 1213,0 1214,0 1215,0 1216,1 1217,0 1218,1 1219,0 1220,0 1221,0 1222,1 1223,0 1224,0 1225,0 1226,0 1227,0 1228,0 1229,0 1230,0 1231,0 1232,0 1233,0 1234,1 1235,1 1236,1 1237,1 1238,0 1239,0 1240,0 1241,1 1242,1 1243,0 1244,0 1245,0 1246,0 1247,0 1248,1 1249,0 1250,0 1251,1 1252,0 1253,1 1254,1 1255,0 1256,1 1257,0 1258,0 1259,0 1260,1 1261,0 1262,0 1263,1 1264,0 1265,0 1266,1 1267,1 1268,1 1269,0 1270,0 1271,0 1272,0 1273,0 1274,1 1275,1 1276,0 1277,1 1278,0 1279,0 1280,0 1281,0 1282,0 1283,1 1284,0 1285,0 1286,0 1287,1 1288,0 1289,1 1290,0 1291,0 1292,1 1293,0 1294,1 1295,0 1296,1 1297,0 1298,0 1299,0 1300,1 1301,1 1302,1 1303,1 1304,0 1305,0 1306,1 1307,0 1308,0 1309,1