In this project, we help Hyundai Heavy Industry to build a regression model to predict the number of crew numbers the ships need. The data description is as follows.
Variables
Ship Name
Cruise Line
Age (as of 2013)
Tonnage (1000s of tons)
passengers (100s)
Length (100s of feet)
Cabins (100s)
Passenger Density
Crew (100s)
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.sql.functions import corr
spark = SparkSession.builder.appName('cruise').getOrCreate()
df = spark.read.csv("dbfs:/FileStore/shared_uploads/dizhen@hsph.harvard.edu/cruise_ship_info.csv",inferSchema=True,header=True)
df.printSchema()
root |-- Ship_name: string (nullable = true) |-- Cruise_line: string (nullable = true) |-- Age: integer (nullable = true) |-- Tonnage: double (nullable = true) |-- passengers: double (nullable = true) |-- length: double (nullable = true) |-- cabins: double (nullable = true) |-- passenger_density: double (nullable = true) |-- crew: double (nullable = true)
df.show()
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+ | Ship_name|Cruise_line|Age| Tonnage|passengers|length|cabins|passenger_density|crew| +-----------+-----------+---+------------------+----------+------+------+-----------------+----+ | Journey| Azamara| 6|30.276999999999997| 6.94| 5.94| 3.55| 42.64|3.55| | Quest| Azamara| 6|30.276999999999997| 6.94| 5.94| 3.55| 42.64|3.55| |Celebration| Carnival| 26| 47.262| 14.86| 7.22| 7.43| 31.8| 6.7| | Conquest| Carnival| 11| 110.0| 29.74| 9.53| 14.88| 36.99|19.1| | Destiny| Carnival| 17| 101.353| 26.42| 8.92| 13.21| 38.36|10.0| | Ecstasy| Carnival| 22| 70.367| 20.52| 8.55| 10.2| 34.29| 9.2| | Elation| Carnival| 15| 70.367| 20.52| 8.55| 10.2| 34.29| 9.2| | Fantasy| Carnival| 23| 70.367| 20.56| 8.55| 10.22| 34.23| 9.2| |Fascination| Carnival| 19| 70.367| 20.52| 8.55| 10.2| 34.29| 9.2| | Freedom| Carnival| 6|110.23899999999999| 37.0| 9.51| 14.87| 29.79|11.5| | Glory| Carnival| 10| 110.0| 29.74| 9.51| 14.87| 36.99|11.6| | Holiday| Carnival| 28| 46.052| 14.52| 7.27| 7.26| 31.72| 6.6| |Imagination| Carnival| 18| 70.367| 20.52| 8.55| 10.2| 34.29| 9.2| |Inspiration| Carnival| 17| 70.367| 20.52| 8.55| 10.2| 34.29| 9.2| | Legend| Carnival| 11| 86.0| 21.24| 9.63| 10.62| 40.49| 9.3| | Liberty*| Carnival| 8| 110.0| 29.74| 9.51| 14.87| 36.99|11.6| | Miracle| Carnival| 9| 88.5| 21.24| 9.63| 10.62| 41.67|10.3| | Paradise| Carnival| 15| 70.367| 20.52| 8.55| 10.2| 34.29| 9.2| | Pride| Carnival| 12| 88.5| 21.24| 9.63| 11.62| 41.67| 9.3| | Sensation| Carnival| 20| 70.367| 20.52| 8.55| 10.2| 34.29| 9.2| +-----------+-----------+---+------------------+----------+------+------+-----------------+----+ only showing top 20 rows
df.describe().show()
+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+ |summary|Ship_name|Cruise_line| Age| Tonnage| passengers| length| cabins|passenger_density| crew| +-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+ | count| 158| 158| 158| 158| 158| 158| 158| 158| 158| | mean| Infinity| null|15.689873417721518| 71.28467088607599|18.45740506329114|8.130632911392404| 8.830000000000005|39.90094936708861|7.794177215189873| | stddev| null| null| 7.615691058751413|37.229540025907866|9.677094775143416|1.793473548054825|4.4714172221480615| 8.63921711391542|3.503486564627034| | min|Adventure| Azamara| 4| 2.329| 0.66| 2.79| 0.33| 17.7| 0.59| | max|Zuiderdam| Windstar| 48| 220.0| 54.0| 11.82| 27.0| 71.43| 21.0| +-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+
df.groupBy('Cruise_line').count().show()
+-----------------+-----+ | Cruise_line|count| +-----------------+-----+ | Costa| 11| | P&O| 6| | Cunard| 3| |Regent_Seven_Seas| 5| | MSC| 8| | Carnival| 22| | Crystal| 2| | Orient| 1| | Princess| 17| | Silversea| 4| | Seabourn| 3| | Holland_American| 14| | Windstar| 3| | Disney| 2| | Norwegian| 13| | Oceania| 3| | Azamara| 2| | Celebrity| 10| | Star| 6| | Royal_Caribbean| 23| +-----------------+-----+
indexer = StringIndexer(inputCol="Cruise_line", outputCol="cruise_cat")
indexed = indexer.fit(df).transform(df)
indexed.head(5)
Out[11]: [Row(Ship_name='Journey', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55, cruise_cat=16.0), Row(Ship_name='Quest', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55, cruise_cat=16.0), Row(Ship_name='Celebration', Cruise_line='Carnival', Age=26, Tonnage=47.262, passengers=14.86, length=7.22, cabins=7.43, passenger_density=31.8, crew=6.7, cruise_cat=1.0), Row(Ship_name='Conquest', Cruise_line='Carnival', Age=11, Tonnage=110.0, passengers=29.74, length=9.53, cabins=14.88, passenger_density=36.99, crew=19.1, cruise_cat=1.0), Row(Ship_name='Destiny', Cruise_line='Carnival', Age=17, Tonnage=101.353, passengers=26.42, length=8.92, cabins=13.21, passenger_density=38.36, crew=10.0, cruise_cat=1.0)]
indexed.columns
Out[13]: ['Ship_name', 'Cruise_line', 'Age', 'Tonnage', 'passengers', 'length', 'cabins', 'passenger_density', 'crew', 'cruise_cat']
assembler = VectorAssembler(
inputCols=['Age',
'Tonnage',
'passengers',
'length',
'cabins',
'passenger_density',
'cruise_cat'],
outputCol="features")
output = assembler.transform(indexed)
output.select("features", "crew").show()
+--------------------+----+ | features|crew| +--------------------+----+ |[6.0,30.276999999...|3.55| |[6.0,30.276999999...|3.55| |[26.0,47.262,14.8...| 6.7| |[11.0,110.0,29.74...|19.1| |[17.0,101.353,26....|10.0| |[22.0,70.367,20.5...| 9.2| |[15.0,70.367,20.5...| 9.2| |[23.0,70.367,20.5...| 9.2| |[19.0,70.367,20.5...| 9.2| |[6.0,110.23899999...|11.5| |[10.0,110.0,29.74...|11.6| |[28.0,46.052,14.5...| 6.6| |[18.0,70.367,20.5...| 9.2| |[17.0,70.367,20.5...| 9.2| |[11.0,86.0,21.24,...| 9.3| |[8.0,110.0,29.74,...|11.6| |[9.0,88.5,21.24,9...|10.3| |[15.0,70.367,20.5...| 9.2| |[12.0,88.5,21.24,...| 9.3| |[20.0,70.367,20.5...| 9.2| +--------------------+----+ only showing top 20 rows
final_data = output.select("features", "crew")
train_data,test_data = final_data.randomSplit([0.7,0.3])
lr = LinearRegression(labelCol='crew')
lrModel = lr.fit(train_data)
print("Coefficients: {} Intercept: {}".format(lrModel.coefficients,lrModel.intercept))
Coefficients: [-0.00450918681822143,0.012963558609912268,-0.15963811898272345,0.3462005314106989,0.8897395131080795,-0.008624377855820762,0.06069326390922283] Intercept: -0.784750205393044
test_results = lrModel.evaluate(test_data)
print("RMSE: {}".format(test_results.rootMeanSquaredError))
print("MSE: {}".format(test_results.meanSquaredError))
print("R2: {}".format(test_results.r2))
RMSE: 1.0837245330819674 MSE: 1.174458863603728 R2: 0.8798711702553128
df.select(corr('crew','passengers')).show()
+----------------------+ |corr(crew, passengers)| +----------------------+ | 0.9152341306065384| +----------------------+
df.select(corr('crew','cabins')).show()
+------------------+ |corr(crew, cabins)| +------------------+ |0.9508226063578497| +------------------+