import os

import ibmos2spark
from pyspark.sql import SparkSession

# Reuse the notebook's Spark session (or create one when run elsewhere).
spark = SparkSession.builder.getOrCreate()

# @hidden_cell
# Cloud Object Storage credentials. The API key is read from the environment
# instead of being stored in the notebook; the key previously embedded here was
# committed in plain text and should be considered leaked and rotated.
# NOTE(review): both endpoint values were empty in this export (they look like
# stripped URLs); supply them through the environment variables below.
credentials = {
    'endpoint': os.environ.get('COS_ENDPOINT', ''),
    'service_id': 'iam-ServiceId-56ce77ac-9d82-41cc-96e9-1a4689ce5806',
    'iam_service_endpoint': os.environ.get('COS_IAM_SERVICE_ENDPOINT', ''),
    'api_key': os.environ['COS_API_KEY'],
}
configuration_name = 'os_a2d5a67716ea4f0db7f2711634a21cc0_configs'
# Pass the session's SparkContext explicitly rather than relying on a global
# `sc`, which is not defined anywhere in this file.
cos = ibmos2spark.CloudObjectStorage(spark.sparkContext, credentials,
                                     configuration_name, 'bluemix_cos')

# Load the patient CSV: first row is the header, column types are inferred.
df_data_1 = (
    spark.read.format('csv')
    .option('header', 'true')
    .option('inferSchema', 'true')
    .load(cos.url('patientdataV6-1.csv', 'heartfailure-donotdelete-pr-iafaq9zj55qih4'))
)
root |-- AVGHEARTBEATSPERMIN: integer (nullable = true) |-- PALPITATIONSPERDAY: integer (nullable = true) |-- CHOLESTEROL: integer (nullable = true) |-- BMI: integer (nullable = true) |-- HEARTFAILURE: string (nullable = true) |-- AGE: integer (nullable = true) |-- SEX: string (nullable = true) |-- FAMILYHISTORY: string (nullable = true) |-- SMOKERLAST5YRS: string (nullable = true) |-- EXERCISEMINPERWEEK: integer (nullable = true)
+-------------------+------------------+-----------+---+------------+---+---+-------------+--------------+------------------+ |AVGHEARTBEATSPERMIN|PALPITATIONSPERDAY|CHOLESTEROL|BMI|HEARTFAILURE|AGE|SEX|FAMILYHISTORY|SMOKERLAST5YRS|EXERCISEMINPERWEEK| +-------------------+------------------+-----------+---+------------+---+---+-------------+--------------+------------------+ | 93| 22| 163| 25| N| 49| F| N| N| 110| | 108| 22| 181| 24| N| 32| F| N| N| 192| | 86| 0| 239| 20| N| 60| F| N| N| 121| | 80| 36| 164| 31| Y| 45| F| Y| N| 141| | 66| 36| 185| 23| N| 39| F| N| N| 63| | 125| 27| 201| 31| N| 47| M| N| N| 13| | 83| 27| 169| 20| N| 71| F| Y| N| 124| | 107| 31| 199| 32| N| 55| F| N| N| 22| | 92| 28| 174| 22| N| 44| F| N| N| 107| | 84| 12| 206| 25| N| 50| M| N| N| 199| | 60| 1| 194| 28| N| 71| M| N| N| 27| | 134| 7| 228| 34| Y| 63| F| Y| N| 92| | 103| 0| 237| 24| N| 64| F| Y| N| 34| | 101| 39| 157| 20| N| 49| M| N| N| 33| | 92| 2| 169| 26| N| 36| M| N| N| 217| | 80| 27| 234| 27| N| 50| M| N| N| 28| | 82| 14| 155| 30| N| 70| F| N| N| 207| | 63| 9| 204| 26| N| 42| M| N| N| 88| | 83| 12| 209| 29| N| 38| M| Y| N| 220| | 80| 37| 157| 20| N| 48| M| N| N| 54| +-------------------+------------------+-----------+---+------------+---+---+-------------+--------------+------------------+ only showing top 20 rows"AVGHEARTBEATSPERMIN", "PALPITATIONSPERDAY","CHOLESTEROL","BMI","HEARTFAILURE","SEX").toPandas().head()
0 | 93 | 22 | 163 | 25 | N | F |
1 | 108 | 22 | 181 | 24 | N | F |
2 | 86 | 0 | 239 | 20 | N | F |
3 | 80 | 36 | 164 | 31 | Y | F |
4 | 66 | 36 | 185 | 23 | N | F |
+-------+-------------------+------------------+------------------+------------------+------------+------------------+-----+-------------+--------------+------------------+ |summary|AVGHEARTBEATSPERMIN|PALPITATIONSPERDAY| CHOLESTEROL| BMI|HEARTFAILURE| AGE| SEX|FAMILYHISTORY|SMOKERLAST5YRS|EXERCISEMINPERWEEK| +-------+-------------------+------------------+------------------+------------------+------------+------------------+-----+-------------+--------------+------------------+ | count| 10800| 10800| 10800| 10800| 10800| 10800|10800| 10800| 10800| 10800| | mean| 87.11509259259259|20.423148148148147|195.08027777777778| 26.35972222222222| null|49.965185185185184| null| null| null|119.72953703703703| | stddev| 19.744375148984474|12.165320351622993|26.136731865042325|3.8201472810942136| null|13.079280962015586| null| null| null| 71.14706006382843| | min| 48| 0| 150| 20| N| 28| F| N| N| 0| | max| 161| 45| 245| 34| Y| 72| M| Y| Y| 250| +-------+-------------------+------------------+------------------+------------------+------------+------------------+-----+-------------+--------------+------------------+
import pixiedust
!pip install --upgrade pixiedust
import pixiedust
Collecting pixiedust Collecting mpld3 (from pixiedust) Collecting requests (from pixiedust) Using cached Collecting astunparse (from pixiedust) Using cached Collecting geojson (from pixiedust) Using cached Collecting colour (from pixiedust) Using cached Collecting markdown (from pixiedust) Using cached Collecting lxml (from pixiedust) Using cached Collecting idna<3,>=2.5 (from requests->pixiedust) Using cached Collecting chardet<4,>=3.0.2 (from requests->pixiedust) Using cached Collecting urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 (from requests->pixiedust) Using cached Collecting certifi>=2017.4.17 (from requests->pixiedust) Using cached Collecting wheel<1.0,>=0.23.0 (from astunparse->pixiedust) Using cached Collecting six<2.0,>=1.6.1 (from astunparse->pixiedust) Using cached Collecting setuptools>=36 (from markdown->pixiedust) Using cached tensorflow 1.13.1 requires tensorboard<1.14.0,>=1.13.0, which is not installed. pytest-astropy 0.8.0 requires pytest-cov>=2.0, which is not installed. pytest-astropy 0.8.0 requires pytest-filter-subpackage>=0.1, which is not installed. pytest-astropy 0.8.0 has requirement pytest>=4.6, but you'll have pytest 3.10.1 which is incompatible. ibm-cos-sdk-core 2.4.3 has requirement urllib3<1.25,>=1.20, but you'll have urllib3 1.25.9 which is incompatible. botocore 1.12.82 has requirement urllib3<1.25,>=1.20, but you'll have urllib3 1.25.9 which is incompatible. Installing collected packages: mpld3, idna, chardet, urllib3, certifi, requests, wheel, six, astunparse, geojson, colour, setuptools, markdown, lxml, pixiedust Successfully installed astunparse-1.6.3 certifi-2020.4.5.1 chardet-3.0.4 colour-0.1.5 geojson-2.5.0 idna-2.9 lxml-4.5.0 markdown-3.2.1 mpld3-0.3 pixiedust-1.1.18 requests-2.23.0 setuptools-46.1.3 six-1.14.0 urllib3-1.25.9 wheel-0.34.2
92 | 28 | 174 | 22 | N | 44 | F | N | N | 107 |
84 | 12 | 206 | 25 | N | 50 | M | N | N | 199 |
78 | 5 | 233 | 25 | N | 53 | M | N | N | 204 |
110 | 8 | 227 | 28 | N | 53 | F | Y | N | 39 |
97 | 11 | 170 | 32 | N | 38 | F | N | N | 222 |
82 | 15 | 154 | 26 | N | 46 | F | N | N | 227 |
88 | 38 | 154 | 22 | N | 47 | F | N | N | 204 |
100 | 30 | 212 | 23 | N | 72 | M | N | N | 22 |
100 | 17 | 153 | 32 | N | 49 | M | N | N | 244 |
82 | 3 | 218 | 27 | Y | 49 | F | N | N | 182 |
105 | 36 | 157 | 32 | N | 66 | F | Y | N | 5 |
136 | 42 | 233 | 26 | N | 44 | M | N | N | 195 |
76 | 3 | 205 | 27 | N | 40 | M | N | N | 189 |
140 | 24 | 216 | 24 | N | 66 | M | Y | N | 104 |
129 | 38 | 154 | 25 | N | 30 | F | N | N | 67 |
124 | 19 | 237 | 27 | N | 29 | F | N | N | 96 |
111 | 20 | 203 | 26 | N | 57 | F | N | N | 60 |
105 | 39 | 150 | 30 | N | 43 | M | N | N | 135 |
79 | 20 | 158 | 23 | N | 67 | M | N | N | 166 |
106 | 37 | 207 | 30 | N | 66 | F | Y | N | 140 |
72 | 30 | 154 | 26 | N | 64 | M | Y | Y | 81 |
116 | 27 | 215 | 29 | N | 36 | M | N | N | 150 |
119 | 16 | 220 | 21 | N | 32 | M | N | N | 226 |
66 | 22 | 237 | 21 | N | 40 | M | N | N | 148 |
64 | 28 | 180 | 29 | N | 45 | F | N | N | 113 |
94 | 38 | 221 | 28 | N | 72 | M | N | N | 123 |
89 | 39 | 194 | 25 | Y | 32 | M | N | N | 83 |
126 | 10 | 225 | 21 | N | 51 | M | Y | N | 212 |
98 | 24 | 205 | 28 | N | 60 | F | N | N | 135 |
94 | 22 | 228 | 22 | N | 59 | F | N | N | 24 |
100 | 18 | 212 | 20 | N | 37 | F | N | N | 221 |
89 | 4 | 227 | 31 | N | 72 | M | Y | N | 78 |
69 | 26 | 184 | 20 | N | 68 | F | N | N | 37 |
87 | 40 | 219 | 28 | N | 37 | M | Y | N | 103 |
120 | 44 | 230 | 34 | Y | 38 | F | Y | Y | 0 |
100 | 13 | 165 | 30 | N | 34 | M | N | N | 35 |
79 | 27 | 233 | 29 | N | 29 | M | N | N | 177 |
93 | 25 | 177 | 27 | N | 29 | F | N | N | 242 |
81 | 4 | 182 | 21 | N | 65 | M | N | N | 149 |
79 | 34 | 191 | 28 | N | 72 | F | Y | N | 194 |
97 | 36 | 206 | 27 | Y | 42 | M | Y | N | 116 |
86 | 20 | 171 | 26 | N | 41 | M | N | N | 187 |
107 | 20 | 159 | 31 | N | 61 | M | N | N | 52 |
99 | 11 | 163 | 21 | N | 57 | M | N | N | 98 |
122 | 6 | 226 | 23 | N | 45 | M | N | N | 229 |
54 | 12 | 215 | 29 | N | 49 | F | N | N | 178 |
65 | 21 | 220 | 26 | N | 40 | F | N | N | 82 |
90 | 28 | 181 | 22 | N | 30 | F | N | N | 92 |
91 | 5 | 182 | 31 | N | 59 | M | Y | N | 36 |
126 | 30 | 189 | 30 | N | 58 | F | N | N | 58 |
95 | 23 | 197 | 21 | N | 32 | M | N | N | 165 |
88 | 2 | 180 | 22 | N | 63 | M | N | N | 141 |
66 | 14 | 150 | 24 | N | 43 | F | N | N | 124 |
84 | 40 | 201 | 30 | N | 56 | M | N | N | 118 |
105 | 19 | 191 | 31 | Y | 52 | M | Y | N | 3 |
89 | 1 | 154 | 28 | N | 53 | M | N | N | 206 |
68 | 7 | 209 | 32 | N | 44 | F | N | N | 99 |
69 | 34 | 188 | 29 | N | 46 | F | N | N | 208 |
75 | 20 | 203 | 23 | N | 70 | M | Y | N | 191 |
104 | 0 | 220 | 25 | N | 48 | M | N | N | 5 |
94 | 8 | 192 | 22 | N | 61 | F | N | N | 218 |
129 | 43 | 197 | 33 | Y | 30 | F | N | N | 139 |
90 | 36 | 160 | 21 | N | 56 | F | N | N | 135 |
115 | 17 | 164 | 23 | N | 44 | M | N | N | 204 |
57 | 35 | 168 | 21 | N | 53 | F | N | N | 164 |
111 | 3 | 152 | 27 | N | 34 | F | N | N | 226 |
110 | 36 | 232 | 27 | N | 57 | F | N | N | 153 |
105 | 36 | 218 | 24 | N | 58 | F | N | N | 124 |
107 | 24 | 204 | 27 | N | 39 | F | N | N | 103 |
80 | 36 | 172 | 25 | N | 54 | M | N | N | 214 |
75 | 4 | 234 | 27 | N | 72 | M | N | N | 7 |
81 | 0 | 186 | 24 | N | 69 | F | N | N | 10 |
60 | 38 | 185 | 23 | N | 30 | M | N | N | 238 |
93 | 24 | 199 | 22 | Y | 65 | M | Y | N | 39 |
107 | 29 | 211 | 24 | N | 66 | M | N | N | 66 |
101 | 39 | 159 | 29 | N | 50 | F | N | N | 229 |
86 | 38 | 233 | 31 | N | 49 | M | N | Y | 69 |
108 | 37 | 180 | 30 | N | 61 | M | N | N | 194 |
106 | 17 | 216 | 31 | N | 69 | M | N | Y | 58 |
94 | 32 | 184 | 23 | N | 53 | M | Y | N | 210 |
80 | 40 | 238 | 28 | N | 40 | M | N | N | 128 |
102 | 45 | 232 | 28 | Y | 43 | M | Y | N | 20 |
68 | 12 | 233 | 22 | N | 33 | F | N | N | 135 |
60 | 43 | 224 | 23 | Y | 29 | M | N | N | 16 |
92 | 42 | 167 | 23 | Y | 55 | M | Y | N | 59 |
87 | 6 | 209 | 23 | N | 42 | F | N | N | 27 |
116 | 30 | 192 | 28 | N | 70 | M | N | N | 121 |
77 | 5 | 217 | 23 | N | 41 | M | N | N | 137 |
74 | 11 | 152 | 29 | N | 39 | M | N | N | 142 |
79 | 26 | 236 | 26 | N | 33 | F | N | N | 220 |
94 | 16 | 152 | 32 | N | 48 | F | N | N | 112 |
75 | 31 | 182 | 23 | N | 62 | F | N | N | 189 |
101 | 36 | 205 | 29 | Y | 40 | M | N | N | 60 |
67 | 6 | 177 | 29 | Y | 42 | M | N | N | 61 |
68 | 39 | 179 | 28 | N | 31 | M | N | N | 26 |
93 | 21 | 175 | 23 | N | 70 | F | N | N | 73 |
84 | 1 | 190 | 29 | N | 72 | F | N | N | 97 |
111 | 35 | 200 | 25 | Y | 34 | F | Y | N | 52 |
82 | 34 | 165 | 25 | N | 63 | F | N | N | 53 |
82 | 2 | 182 | 32 | N | 53 | F | N | N | 44 |
# Hold out roughly 20% of the rows for testing; seed 24 keeps the split
# reproducible between runs.
split_data = df_data_1.randomSplit([0.8, 0.20], 24)
train_data, test_data = split_data

print(f"Number of training records: {train_data.count()}")
print(f"Number of testing records : {test_data.count()}")
Number of training records: 8637 Number of testing records : 2163
from pyspark.ml.feature import StringIndexer, IndexToString, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression, NaiveBayes
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml import Pipeline, Model

# The label indexer is fitted up front so its learned labels can feed
# IndexToString below; the feature indexers are fitted inside the pipelines.
stringIndexer_label = StringIndexer(inputCol="HEARTFAILURE", outputCol="label").fit(df_data_1)
stringIndexer_sex = StringIndexer(inputCol="SEX", outputCol="SEX_IX")
stringIndexer_famhist = StringIndexer(inputCol="FAMILYHISTORY", outputCol="FAMILYHISTORY_IX")
stringIndexer_smoker = StringIndexer(inputCol="SMOKERLAST5YRS", outputCol="SMOKERLAST5YRS_IX")

# Candidate classifiers; each reads "label" and "features".
rf = RandomForestClassifier(labelCol="label", featuresCol="features")
lr = LogisticRegression()
nb = NaiveBayes(smoothing=1.0)

# Map numeric predictions back to the original HEARTFAILURE strings.
labelConverter = IndexToString(inputCol="prediction", outputCol="predictedLabel",
                               labels=stringIndexer_label.labels)

# NOTE(review): the original inputCols list was truncated after
# "PALPITATIONSPERDAY". It is reconstructed here as every non-label column
# (string columns via their *_IX indexes); confirm against the source notebook.
vectorAssembler_features = VectorAssembler(
    inputCols=["AVGHEARTBEATSPERMIN", "PALPITATIONSPERDAY", "CHOLESTEROL", "BMI",
               "AGE", "SEX_IX", "FAMILYHISTORY_IX", "SMOKERLAST5YRS_IX",
               "EXERCISEMINPERWEEK"],
    outputCol="features")

# Preview the engineered feature columns on the full dataset.
transform_df_pipeline = Pipeline(stages=[stringIndexer_sex, stringIndexer_famhist,
                                         stringIndexer_smoker, vectorAssembler_features])
transformed_df = transform_df_pipeline.fit(df_data_1).transform(df_data_1)
transformed_df.show()
+-------------------+------------------+-----------+---+------------+---+---+-------------+--------------+------------------+------+----------------+-----------------+--------------------+ |AVGHEARTBEATSPERMIN|PALPITATIONSPERDAY|CHOLESTEROL|BMI|HEARTFAILURE|AGE|SEX|FAMILYHISTORY|SMOKERLAST5YRS|EXERCISEMINPERWEEK|SEX_IX|FAMILYHISTORY_IX|SMOKERLAST5YRS_IX| features| +-------------------+------------------+-----------+---+------------+---+---+-------------+--------------+------------------+------+----------------+-----------------+--------------------+ | 93| 22| 163| 25| N| 49| F| N| N| 110| 1.0| 0.0| 0.0|[93.0,22.0,163.0,...| | 108| 22| 181| 24| N| 32| F| N| N| 192| 1.0| 0.0| 0.0|[108.0,22.0,181.0...| | 86| 0| 239| 20| N| 60| F| N| N| 121| 1.0| 0.0| 0.0|[86.0,0.0,239.0,2...| | 80| 36| 164| 31| Y| 45| F| Y| N| 141| 1.0| 1.0| 0.0|[80.0,36.0,164.0,...| | 66| 36| 185| 23| N| 39| F| N| N| 63| 1.0| 0.0| 0.0|[66.0,36.0,185.0,...| | 125| 27| 201| 31| N| 47| M| N| N| 13| 0.0| 0.0| 0.0|[125.0,27.0,201.0...| | 83| 27| 169| 20| N| 71| F| Y| N| 124| 1.0| 1.0| 0.0|[83.0,27.0,169.0,...| | 107| 31| 199| 32| N| 55| F| N| N| 22| 1.0| 0.0| 0.0|[107.0,31.0,199.0...| | 92| 28| 174| 22| N| 44| F| N| N| 107| 1.0| 0.0| 0.0|[92.0,28.0,174.0,...| | 84| 12| 206| 25| N| 50| M| N| N| 199| 0.0| 0.0| 0.0|[84.0,12.0,206.0,...| | 60| 1| 194| 28| N| 71| M| N| N| 27| 0.0| 0.0| 0.0|[60.0,1.0,194.0,2...| | 134| 7| 228| 34| Y| 63| F| Y| N| 92| 1.0| 1.0| 0.0|[134.0,7.0,228.0,...| | 103| 0| 237| 24| N| 64| F| Y| N| 34| 1.0| 1.0| 0.0|[103.0,0.0,237.0,...| | 101| 39| 157| 20| N| 49| M| N| N| 33| 0.0| 0.0| 0.0|[101.0,39.0,157.0...| | 92| 2| 169| 26| N| 36| M| N| N| 217| 0.0| 0.0| 0.0|[92.0,2.0,169.0,2...| | 80| 27| 234| 27| N| 50| M| N| N| 28| 0.0| 0.0| 0.0|[80.0,27.0,234.0,...| | 82| 14| 155| 30| N| 70| F| N| N| 207| 1.0| 0.0| 0.0|[82.0,14.0,155.0,...| | 63| 9| 204| 26| N| 42| M| N| N| 88| 0.0| 0.0| 0.0|[63.0,9.0,204.0,2...| | 83| 12| 209| 29| N| 38| M| Y| N| 220| 0.0| 1.0| 0.0|[83.0,12.0,209.0,...| | 80| 
37| 157| 20| N| 48| M| N| N| 54| 0.0| 0.0| 0.0|[80.0,37.0,157.0,...| +-------------------+------------------+-----------+---+------------+---+---+-------------+--------------+------------------+------+----------------+-----------------+--------------------+ only showing top 20 rows
# All three candidate models share the same preprocessing and differ only in
# the estimator. The already-fitted label indexer is a Transformer, which a
# Pipeline applies as-is.
sharedStages = [stringIndexer_label, stringIndexer_sex,
                stringIndexer_famhist, stringIndexer_smoker,
                vectorAssembler_features]

pipeline1 = Pipeline(stages=sharedStages + [rf, labelConverter])
m1Name = "Random Forest"
pipeline2 = Pipeline(stages=sharedStages + [lr, labelConverter])
m2Name = "Logistic Regression"
pipeline3 = Pipeline(stages=sharedStages + [nb, labelConverter])
m3Name = "Naive Bayes"

# Fit on the training split only; test_data stays held out for evaluation.
model1 = pipeline1.fit(train_data)
model2 = pipeline2.fit(train_data)
model3 = pipeline3.fit(train_data)
from pyspark.sql.types import Row
import numpy as np
def getCMEntries(threshold):
    """Count confusion-matrix cells after re-thresholding the class-1 probability.

    Reads the ``inputToThreshold`` view (columns ``label``, ``p1`` and
    ``prediction``). A row is predicted positive when ``p1 > threshold``.

    Args:
        threshold: Numeric cut-off applied to ``p1``.

    Returns:
        Tuple ``(tp, fp, fn, tn)`` of integer counts.
    """
    # float() normalizes numpy scalars and guarantees only a number is spliced
    # into the SQL text.
    newThresholdDF = spark.sql(
        "select label, p1, prediction as oldPrediction,"
        " case when p1 > " + str(float(threshold)) + " then 1.0 else 0.0 end as newPrediction"
        " from inputToThreshold")
    # The query below refers to this DataFrame by name, so register it first.
    newThresholdDF.createOrReplaceTempView("newThreshold")

    # One pass over the data instead of four separate filtered counts.
    # count(CASE ... END) ignores NULLs, so an empty cell yields 0, not NULL.
    counts = spark.sql(
        "SELECT"
        " count(CASE WHEN label = 1 AND newPrediction = 1 THEN 1 END) AS tp,"
        " count(CASE WHEN label = 0 AND newPrediction = 1 THEN 1 END) AS fp,"
        " count(CASE WHEN label = 1 AND newPrediction = 0 THEN 1 END) AS fn,"
        " count(CASE WHEN label = 0 AND newPrediction = 0 THEN 1 END) AS tn"
        " FROM newThreshold").first()
    return (counts["tp"], counts["fp"], counts["fn"], counts["tn"])
import numpy as np

# Number of equal-width intervals the [0, 1] threshold range is divided into.
numBins = 10
# numBins + 1 evenly spaced cut-offs, both end points included (0.0 ... 1.0).
thresholds = np.arange(numBins + 1) / numBins
def getModelThresholdStats(model_df, data):
    """Sweep classification thresholds for a fitted model and summarize each one.

    Runs ``model_df`` over ``data``, keeps the predicted probability of class 1
    (``p1``) for every row, and counts the confusion matrix at each cut-off in
    the module-level ``thresholds`` array (``numBins + 1`` entries).

    Side effects: registers the Spark views ``predictions`` and
    ``inputToThreshold`` (the latter is read by ``getCMEntries``), and prints
    each ``(index, threshold)`` pair as it is processed.

    Args:
        model_df: Fitted model whose ``transform`` output has ``probability``,
            ``prediction`` and ``label`` columns.
        data: Spark DataFrame to evaluate.

    Returns:
        Tuple ``(tpr, fpr, mcc, accThreshold, auc)``:
        - ``tpr``, ``fpr``, ``mcc``, ``accThreshold``: per-threshold numpy arrays
          of true-positive rate, false-positive rate, Matthews correlation
          coefficient and accuracy.
        - ``auc``: the trapezoidal ROC AUC as a float.
    """
    # Score the data and expose the result under the name queried below.
    predictionsForROC = model_df.transform(data)
    predictionsForROC.createOrReplaceTempView("predictions")
    columnsForCM = spark.sql("select probability, prediction, label from predictions")

    # Keep only P(class 1) from each probability vector. float() replaces
    # np.asscalar, which NumPy deprecated and later removed.
    extractedProbability = columnsForCM.rdd.map(
        lambda x: Row(p1=float(x[0][1]), prediction=x[1], label=x[2])).toDF()
    # getCMEntries queries this view once per interior threshold; caching avoids
    # re-applying the model to the data for every one of those queries.
    extractedProbability.cache()
    extractedProbability.createOrReplaceTempView("inputToThreshold")

    # Total number of positives and negatives in the scored dataset.
    p = spark.sql("SELECT * from predictions WHERE label = 1").count()
    n = spark.sql("SELECT * from predictions WHERE label = 0").count()

    tp = np.zeros(numBins + 1, dtype=np.int64)
    fp = np.zeros(numBins + 1, dtype=np.int64)
    fn = np.zeros(numBins + 1, dtype=np.int64)
    tn = np.zeros(numBins + 1, dtype=np.int64)

    # The end points are fixed rather than queried: threshold 0 counts every row
    # as positive, threshold 1 counts none as positive.
    tp[0], fp[0], fn[0], tn[0] = p, n, 0, 0
    tp[-1], fp[-1], fn[-1], tn[-1] = 0, 0, p, n
    for i, threshold in zip(range(numBins + 1), thresholds):
        print(i, threshold)
        if 0 < i < numBins:
            tp[i], fp[i], fn[i], tn[i] = getCMEntries(threshold)
    extractedProbability.unpersist()

    # Work in double precision from here on.
    tp, fp, fn, tn = (counts.astype(float) for counts in (tp, fp, fn, tn))
    p, n = float(p), float(n)

    tpr = tp / (tp + fn)
    fpr = fp / (fp + tn)

    # At the end points one confusion-matrix row/column is empty, so the MCC is
    # 0/0. Report 0 there (the usual convention) instead of NaN plus a
    # RuntimeWarning.
    with np.errstate(divide="ignore", invalid="ignore"):
        mccDenominator = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        mcc = np.where(mccDenominator > 0, (tp * tn - fp * fn) / mccDenominator, 0.0)

    accThreshold = (tp + tn) / (p + n)

    # Trapezoidal area under the ROC curve. fpr runs from 1 (threshold 0) down to
    # 0 (threshold 1), so the signed sum is negative and is negated. Returned as
    # a scalar so callers can round() it.
    auc = -float(np.sum(np.diff(fpr) * 0.5 * (tpr[1:] + tpr[:-1])))
    return (tpr, fpr, mcc, accThreshold, auc)
def _getStats(model, name, split, data):
    """Announce which model/split is being evaluated, then return its threshold stats."""
    print("getting stats for " + name + ": " + split)
    return getModelThresholdStats(model, data)


# Each label is derived from the model/split actually being evaluated, so the
# printed progress cannot drift from the call (it previously reported the
# Logistic Regression test run as "Random Forest: test").
(tpr1, fpr1, mcc1, acc1, auc1) = _getStats(model1, m1Name, "train", train_data)
(tpr1Test, fpr1Test, mcc1Test, acc1Test, auc1Test) = _getStats(model1, m1Name, "test", test_data)
(tpr2, fpr2, mcc2, acc2, auc2) = _getStats(model2, m2Name, "train", train_data)
(tpr2Test, fpr2Test, mcc2Test, acc2Test, auc2Test) = _getStats(model2, m2Name, "test", test_data)
(tpr3, fpr3, mcc3, acc3, auc3) = _getStats(model3, m3Name, "train", train_data)
(tpr3Test, fpr3Test, mcc3Test, acc3Test, auc3Test) = _getStats(model3, m3Name, "test", test_data)
getting stats for Random Forest: train 0 0.0 1427 7210 0 0 1 0.1 1064 2105 363 5105 2 0.2 901 910 526 6300 3 0.3 847 684 580 6526 4 0.4 579 238 848 6972 5 0.5 416 117 1011 7093 6 0.6 313 57 1114 7153 7 0.7 134 12 1293 7198 8 0.8 16 0 1411 7210 9 0.9 0 0 1427 7210 10 1.0 0 0 1427 7210 getting stats for Random Forest: test
/opt/ibm/conda/miniconda3.6/lib/python3.6/site-packages/ipykernel/ RuntimeWarning: invalid value encountered in true_divide
0 0.0 361 1802 0 0 1 0.1 275 528 86 1274 2 0.2 233 209 128 1593 3 0.3 223 153 138 1649 4 0.4 151 62 210 1740 5 0.5 102 33 259 1769 6 0.6 72 16 289 1786 7 0.7 22 4 339 1798 8 0.8 6 0 355 1802 9 0.9 0 0 361 1802 10 1.0 0 0 361 1802 getting stats for Logistic Regression 0 0.0 1427 7210 0 0 1 0.1 1111 3023 316 4187 2 0.2 855 1060 572 6150 3 0.3 740 687 687 6523 4 0.4 595 429 832 6781 5 0.5 425 233 1002 6977 6 0.6 271 95 1156 7115 7 0.7 134 32 1293 7178 8 0.8 34 9 1393 7201 9 0.9 1 1 1426 7209 10 1.0 0 0 1427 7210 getting stats for Random Forest: test 0 0.0 361 1802 0 0 1 0.1 280 730 81 1072 2 0.2 226 247 135 1555 3 0.3 203 163 158 1639 4 0.4 161 109 200 1693 5 0.5 108 53 253 1749 6 0.6 67 26 294 1776 7 0.7 25 7 336 1795 8 0.8 8 1 353 1801 9 0.9 1 0 360 1802 10 1.0 0 0 361 1802 getting stats for Naive Bayes: train 0 0.0 1427 7210 0 0 1 0.1 876 3314 551 3896 2 0.2 841 3133 586 4077 3 0.3 809 2998 618 4212 4 0.4 785 2905 642 4305 5 0.5 770 2814 657 4396 6 0.6 742 2729 685 4481 7 0.7 715 2639 712 4571 8 0.8 686 2504 741 4706 9 0.9 643 2335 784 4875 10 1.0 0 0 1427 7210 getting stats for Naive Bayes: test 0 0.0 361 1802 0 0 1 0.1 230 789 131 1013 2 0.2 221 748 140 1054 3 0.3 216 721 145 1081 4 0.4 210 695 151 1107 5 0.5 203 669 158 1133 6 0.6 198 635 163 1167 7 0.7 193 612 168 1190 8 0.8 186 579 175 1223 9 0.9 172 534 189 1268 10 1.0 0 0 361 1802
%matplotlib inline
import matplotlib.pyplot as plt
dpi = 300 #dots per square inch when plotting...higher resolution for publications.
plt.rcParams['figure.dpi']= dpi
plt.rcParams["figure.figsize"] = [4,1]
plt.rcParams['figure.figsize'] = (15,5)
plt.plot(thresholds, mcc1, '-o')
plt.plot(thresholds, mcc2, '-o')
plt.plot(thresholds, mcc3, '-o')
plt.title("Matthews Correlation Coefficient")
plt.legend([m1Name + ": AUC (train) = "+ str(round(auc1,3)),
m2Name + ": AUC (train) = "+ str(round(auc2,3)),
m3Name + ": AUC (train) = "+ str(round(auc3,3))])
plt.plot( thresholds, mcc1Test, '-o')
plt.plot( thresholds, mcc2Test, '-o')
plt.plot( thresholds, mcc3Test, '-o')
plt.title("Matthews Correlation Coefficient")
plt.legend([m1Name + ": AUC (test) = "+ str(round(auc1Test,3)),
m2Name + ": AUC (test) = "+ str(round(auc2Test,3)),
m3Name + ": AUC (test) = "+ str(round(auc3Test,3))])
plt.plot(thresholds, acc1,'-o')
plt.plot(thresholds, acc2,'-o')
plt.plot(thresholds, acc3,'-o')
plt.plot(thresholds, acc1Test,'-o')
plt.plot(thresholds, acc2Test,'-o')
plt.plot(thresholds, acc3Test,'-o')
plt.title("Receiver Operating Characteristic")
plt.plot(fpr1, tpr1,'-o')
plt.plot(fpr2, tpr2,'-o')
plt.plot(fpr3, tpr3,'-o')
plt.ylabel("True Positive Rate")
plt.xlabel("False Positive Rate")
plt.title("Receiver Operating Characteristic")
plt.plot(fpr1, tpr1Test,'-o')
plt.plot(fpr2, tpr2Test,'-o')
plt.plot(fpr3, tpr3Test,'-o')
plt.ylabel("True Positive Rate")
plt.xlabel("False Positive Rate")
plt.title("Receiver Operating Characteristic")
plt.plot(fpr1, tpr1,'-o')
plt.ylabel("True Positive Rate")
plt.xlabel("False Positive Rate")
/opt/ibm/conda/miniconda3.6/lib/python3.6/site-packages/matplotlib/cbook/ MatplotlibDeprecationWarning: Passing one of 'on', 'true', 'off', 'false' as a boolean is deprecated; use an actual boolean (True/False) instead. warn_deprecated("2.2", "Passing one of 'on', 'true', 'off', 'false' as a "
(0, 1)