Mourad Askar
Decision Tree Models to predict if a crash has_injuries
NOTE: This model was experimental for exploration of the tree, but not included in the analysis
Create Models to predict if a crash has_injuries:
Steps per Model:
Report for each model run:
import pandas as pd
import numpy as np
import klib
import pandas_profiling as pp
import sweetviz
import sklearn
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn.compose import make_column_transformer
from imblearn.over_sampling import *
from imblearn.under_sampling import *
from sklearn.utils import *
from sklearn.model_selection import train_test_split
from platform import python_version
from sklearn.tree import export_graphviz
import graphviz
%precision 2
pd.options.display.max_rows = 100
pd.options.display.max_columns = 100
pd.options.display.width = 120
pd.options.display.float_format='{:,.2f}'.format
pd.options.display.precision = 2
np.set_printoptions(precision=2, linewidth=120, suppress=True, edgeitems=5)
sns.set_style("white")
StartBold = "\033[1m"
EndBold = "\033[0m"
print('python',python_version())
print(np.__name__, np.__version__)
print(pd.__name__, pd.__version__)
print(klib.__name__, klib.__version__)
#print(pp.__name__, pp.__version__)
print(sklearn.__name__, sklearn.__version__)
print(sweetviz.__name__, sweetviz.__version__)
python 3.9.2 numpy 1.20.1 pandas 1.2.3 klib 0.1.5 sklearn 0.24.1 sweetviz 2.0.9
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_validate
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn import metrics
%matplotlib inline
%config InlineBackend.figure_formats = ['retina']
#%config InlineBackend.figure_formats = ['png', 'retina', 'jpeg', 'svg', 'pdf']
# Global Parameters
random_state = 2021
n_jobs = 4
week_days = {0:'Sunday',1:'Monday',2:'Tuesday',3:'Wednesday',4:'Thursday',5:'Friday',6:'Saturday'}
is_weekday = {0:'Weekend',1:'Weekday'}
#file_parquet_c = '20210203_chi_crashes_c.parquet'
#file_parquet_c = '20210228_chi_crashes_c.parquet'
file_crash_df_parquet = 'crash_df.parquet'
crash_df = pd.read_parquet(file_crash_df_parquet)
crash_df.info()
<class 'pandas.core.frame.DataFrame'> Int64Index: 326488 entries, 0 to 328789 Data columns (total 32 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 crash_date 326488 non-null datetime64[ns] 1 crash_year 326488 non-null int16 2 crash_month 326488 non-null int8 3 crash_day_of_week 326488 non-null int8 4 crash_hour 326488 non-null int8 5 crash_time_of_day 326488 non-null category 6 latitude 326488 non-null float32 7 longitude 326488 non-null float32 8 beat_of_occurrence 326488 non-null int64 9 address 326488 non-null string 10 street_no 326488 non-null string 11 street_direction 326488 non-null category 12 street_name 326488 non-null category 13 posted_speed_limit 326488 non-null int64 14 traffic_control_device 326488 non-null category 15 device_condition 326488 non-null category 16 weather_condition 326488 non-null category 17 lighting_condition 326488 non-null category 18 trafficway_type 326488 non-null category 19 alignment 326488 non-null category 20 roadway_surface_cond 326488 non-null category 21 road_defect 326488 non-null category 22 first_crash_type 326488 non-null category 23 prim_contributory_cause 326488 non-null category 24 sec_contributory_cause 326488 non-null category 25 num_units 326488 non-null int8 26 has_injuries 326488 non-null int64 27 has_fatal 326488 non-null int64 28 crash_type 326488 non-null category 29 damage 326488 non-null category 30 injuries_total 326488 non-null int64 31 injuries_fatal 326488 non-null int64 dtypes: category(16), datetime64[ns](1), float32(2), int16(1), int64(6), int8(4), string(2) memory usage: 34.6 MB
crash_df['is_weekday'] = crash_df.crash_day_of_week.map(lambda x: 0 if x in [1,7] else 1)
features_names = [
# 'crash_date',
'crash_year',
'crash_month',
# 'crash_day_of_week',
# 'crash_hour',
'crash_time_of_day', # New
'is_weekday', #New
'latitude',
'longitude',
# 'beat_of_occurrence', # Should be considered as categorical, don't scale.
# 'address',
# 'street_no',
# 'street_direction',
# 'street_name',
'posted_speed_limit',
'traffic_control_device',
'device_condition',
'weather_condition',
'lighting_condition',
'trafficway_type',
'alignment',
'roadway_surface_cond',
'road_defect',
'first_crash_type',
'prim_contributory_cause',
'sec_contributory_cause',
'num_units',
# 'intersection_related_i',
# 'not_right_of_way_i',
# 'hit_and_run_i',
]
target_names = [
'has_injuries', # New
# 'has_fatal', # New
# 'crash_type',
# 'damage',
# 'injuries_total',
# 'injuries_fatal',
# 'injuries_incapacitating',
# 'inj_non_incap',
# 'inj_report_not_evdnt',
# 'injuries_no_indication',
# 'most_severe_injury',
]
all_columns = features_names + target_names
all_columns
['crash_year', 'crash_month', 'crash_time_of_day', 'is_weekday', 'latitude', 'longitude', 'posted_speed_limit', 'traffic_control_device', 'device_condition', 'weather_condition', 'lighting_condition', 'trafficway_type', 'alignment', 'roadway_surface_cond', 'road_defect', 'first_crash_type', 'prim_contributory_cause', 'sec_contributory_cause', 'num_units', 'has_injuries']
#crash_df = crash_df[(crash_df.crash_year>2017) & (crash_df.crash_year<2021)]
crash_2019_df = crash_df.query('crash_year == 2019')[all_columns[1:]] # drop year column
crash_2019_df.info()
<class 'pandas.core.frame.DataFrame'> Int64Index: 116737 entries, 118947 to 236699 Data columns (total 19 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 crash_month 116737 non-null int8 1 crash_time_of_day 116737 non-null category 2 is_weekday 116737 non-null int64 3 latitude 116737 non-null float32 4 longitude 116737 non-null float32 5 posted_speed_limit 116737 non-null int64 6 traffic_control_device 116737 non-null category 7 device_condition 116737 non-null category 8 weather_condition 116737 non-null category 9 lighting_condition 116737 non-null category 10 trafficway_type 116737 non-null category 11 alignment 116737 non-null category 12 roadway_surface_cond 116737 non-null category 13 road_defect 116737 non-null category 14 first_crash_type 116737 non-null category 15 prim_contributory_cause 116737 non-null category 16 sec_contributory_cause 116737 non-null category 17 num_units 116737 non-null int8 18 has_injuries 116737 non-null int64 dtypes: category(12), float32(2), int64(3), int8(2) memory usage: 6.0 MB
features = crash_2019_df[features_names[1:]].copy() # skip year column
target = crash_2019_df[target_names].has_injuries.copy()
features.shape, target.shape
((116737, 18), (116737,))
features.info()
<class 'pandas.core.frame.DataFrame'> Int64Index: 116737 entries, 118947 to 236699 Data columns (total 18 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 crash_month 116737 non-null int8 1 crash_time_of_day 116737 non-null category 2 is_weekday 116737 non-null int64 3 latitude 116737 non-null float32 4 longitude 116737 non-null float32 5 posted_speed_limit 116737 non-null int64 6 traffic_control_device 116737 non-null category 7 device_condition 116737 non-null category 8 weather_condition 116737 non-null category 9 lighting_condition 116737 non-null category 10 trafficway_type 116737 non-null category 11 alignment 116737 non-null category 12 roadway_surface_cond 116737 non-null category 13 road_defect 116737 non-null category 14 first_crash_type 116737 non-null category 15 prim_contributory_cause 116737 non-null category 16 sec_contributory_cause 116737 non-null category 17 num_units 116737 non-null int8 dtypes: category(12), float32(2), int64(2), int8(2) memory usage: 5.1 MB
features.select_dtypes('category').columns.values
array(['crash_time_of_day', 'traffic_control_device', 'device_condition', 'weather_condition', 'lighting_condition', 'trafficway_type', 'alignment', 'roadway_surface_cond', 'road_defect', 'first_crash_type', 'prim_contributory_cause', 'sec_contributory_cause'], dtype=object)
features['crash_month'] = features['crash_month'].astype('category')
#features['beat_of_occurrence'] = features['beat_of_occurrence'].astype('category')
#features['address'] = features['address'].astype('category')
features_dm = pd.get_dummies(features)
# Split data in to Features X and Target y
X = features_dm
y = target
print('X and y shapes:')
print(X.shape,y.shape,'\n')
print('Target Ratio:')
print(y.value_counts(normalize=True, dropna=False),'\n')
X and y shapes: (116737, 204) (116737,) Target Ratio: 0 0.86 1 0.14 Name: has_injuries, dtype: float64
X
is_weekday | latitude | longitude | posted_speed_limit | num_units | crash_month_1 | crash_month_2 | crash_month_3 | crash_month_4 | crash_month_5 | crash_month_6 | crash_month_7 | crash_month_8 | crash_month_9 | crash_month_10 | crash_month_11 | crash_month_12 | crash_time_of_day_overnight | crash_time_of_day_morning | crash_time_of_day_mid_day | crash_time_of_day_evening | traffic_control_device_BICYCLE CROSSING SIGN | traffic_control_device_DELINEATORS | traffic_control_device_FLASHING CONTROL SIGNAL | traffic_control_device_LANE USE MARKING | traffic_control_device_NO CONTROLS | traffic_control_device_NO PASSING | traffic_control_device_OTHER | traffic_control_device_OTHER RAILROAD CROSSING | traffic_control_device_OTHER REG. SIGN | traffic_control_device_OTHER WARNING SIGN | traffic_control_device_PEDESTRIAN CROSSING SIGN | traffic_control_device_POLICE/FLAGMAN | traffic_control_device_RAILROAD CROSSING GATE | traffic_control_device_RR CROSSING SIGN | traffic_control_device_SCHOOL ZONE | traffic_control_device_STOP SIGN/FLASHER | traffic_control_device_TRAFFIC SIGNAL | traffic_control_device_UNKNOWN | traffic_control_device_YIELD | device_condition_FUNCTIONING IMPROPERLY | device_condition_FUNCTIONING PROPERLY | device_condition_MISSING | device_condition_NO CONTROLS | device_condition_NOT FUNCTIONING | device_condition_OTHER | device_condition_UNKNOWN | device_condition_WORN REFLECTIVE MATERIAL | weather_condition_BLOWING SAND, SOIL, DIRT | weather_condition_BLOWING SNOW | ... | prim_contributory_cause_PHYSICAL CONDITION OF DRIVER | prim_contributory_cause_RELATED TO BUS STOP | prim_contributory_cause_ROAD CONSTRUCTION/MAINTENANCE | prim_contributory_cause_ROAD ENGINEERING/SURFACE/MARKING DEFECTS | prim_contributory_cause_TEXTING | prim_contributory_cause_TURNING RIGHT ON RED | prim_contributory_cause_UNABLE TO DETERMINE | prim_contributory_cause_UNDER THE INFLUENCE OF ALCOHOL/DRUGS (USE WHEN ARREST IS EFFECTED) | prim_contributory_cause_VISION OBSCURED (SIGNS, TREE LIMBS, BUILDINGS, ETC.) | prim_contributory_cause_WEATHER | sec_contributory_cause_ANIMAL | sec_contributory_cause_BICYCLE ADVANCING LEGALLY ON RED LIGHT | sec_contributory_cause_CELL PHONE USE OTHER THAN TEXTING | sec_contributory_cause_DISREGARDING OTHER TRAFFIC SIGNS | sec_contributory_cause_DISREGARDING ROAD MARKINGS | sec_contributory_cause_DISREGARDING STOP SIGN | sec_contributory_cause_DISREGARDING TRAFFIC SIGNALS | sec_contributory_cause_DISREGARDING YIELD SIGN | sec_contributory_cause_DISTRACTION - FROM INSIDE VEHICLE | sec_contributory_cause_DISTRACTION - FROM OUTSIDE VEHICLE | sec_contributory_cause_DISTRACTION - OTHER ELECTRONIC DEVICE (NAVIGATION DEVICE, DVD PLAYER, ETC.) | sec_contributory_cause_DRIVING ON WRONG SIDE/WRONG WAY | sec_contributory_cause_DRIVING SKILLS/KNOWLEDGE/EXPERIENCE | sec_contributory_cause_EQUIPMENT - VEHICLE CONDITION | sec_contributory_cause_EVASIVE ACTION DUE TO ANIMAL, OBJECT, NONMOTORIST | sec_contributory_cause_EXCEEDING AUTHORIZED SPEED LIMIT | sec_contributory_cause_EXCEEDING SAFE SPEED FOR CONDITIONS | sec_contributory_cause_FAILING TO REDUCE SPEED TO AVOID CRASH | sec_contributory_cause_FAILING TO YIELD RIGHT-OF-WAY | sec_contributory_cause_FOLLOWING TOO CLOSELY | sec_contributory_cause_HAD BEEN DRINKING (USE WHEN ARREST IS NOT MADE) | sec_contributory_cause_IMPROPER BACKING | sec_contributory_cause_IMPROPER LANE USAGE | sec_contributory_cause_IMPROPER OVERTAKING/PASSING | sec_contributory_cause_IMPROPER TURNING/NO SIGNAL | sec_contributory_cause_MOTORCYCLE ADVANCING LEGALLY ON RED LIGHT | sec_contributory_cause_NOT APPLICABLE | sec_contributory_cause_OBSTRUCTED CROSSWALKS | sec_contributory_cause_OPERATING VEHICLE IN ERRATIC, RECKLESS, CARELESS, NEGLIGENT OR AGGRESSIVE MANNER | sec_contributory_cause_PASSING STOPPED SCHOOL BUS | sec_contributory_cause_PHYSICAL CONDITION OF DRIVER | sec_contributory_cause_RELATED TO BUS STOP | sec_contributory_cause_ROAD CONSTRUCTION/MAINTENANCE | sec_contributory_cause_ROAD ENGINEERING/SURFACE/MARKING DEFECTS | sec_contributory_cause_TEXTING | sec_contributory_cause_TURNING RIGHT ON RED | sec_contributory_cause_UNABLE TO DETERMINE | sec_contributory_cause_UNDER THE INFLUENCE OF ALCOHOL/DRUGS (USE WHEN ARREST IS EFFECTED) | sec_contributory_cause_VISION OBSCURED (SIGNS, TREE LIMBS, BUILDINGS, ETC.) | sec_contributory_cause_WEATHER | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
118947 | 1 | 41.88 | -87.71 | 30 | 2 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
118948 | 1 | 41.79 | -87.59 | 35 | 2 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
118949 | 1 | 41.95 | -87.67 | 30 | 2 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
118950 | 1 | 41.88 | -87.63 | 20 | 2 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
118951 | 1 | 41.95 | -87.78 | 30 | 3 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
236694 | 1 | 41.84 | -87.64 | 30 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
236695 | 1 | 41.88 | -87.67 | 25 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
236696 | 1 | 41.91 | -87.63 | 45 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
236698 | 1 | 41.81 | -87.74 | 30 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
236699 | 1 | 41.87 | -87.70 | 35 | 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
116737 rows × 204 columns
# Split data into Train and Test
# We stratify to make sure target representation is kept in the new datasets
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=.2,
# stratify=y,
random_state=random_state)
X_train.shape, X_test.shape, y_train.shape, y_test.shape
((93389, 204), (23348, 204), (93389,), (23348,))
print('Training Data:')
print('--------------')
print('X and y shapes:')
print(X_train.shape,y_train.shape,'\n')
print('Target Ratio:')
print(y_train.value_counts(normalize=True, dropna=False),'\n')
Training Data: -------------- X and y shapes: (93389, 204) (93389,) Target Ratio: 0 0.86 1 0.14 Name: has_injuries, dtype: float64
print('Testing Data:')
print('-------------')
print('X and y shapes:')
print(X_test.shape,y_test.shape,'\n')
print('Target Ratio:')
print(y_test.value_counts(normalize=True, dropna=False),'\n')
Testing Data: ------------- X and y shapes: (23348, 204) (23348,) Target Ratio: 0 0.86 1 0.14 Name: has_injuries, dtype: float64
It's visible that we have an imbalanced target that we need to consider during our evaluations.
Resampling to compensate the target class that has low representation.
ros = RandomOverSampler(random_state=random_state)
X_train_resampled, y_train_resampled = ros.fit_resample(X_train, y_train)
print('Testing Data:')
print('-------------')
print('X and y shapes:')
print(X_train_resampled.shape,y_train_resampled.shape,'\n')
print('Target Ratio:')
print(y_train_resampled.value_counts(normalize=True, dropna=False),'\n')
Testing Data: ------------- X and y shapes: (160788, 204) (160788,) Target Ratio: 0 0.50 1 0.50 Name: has_injuries, dtype: float64
X_train
is_weekday | latitude | longitude | posted_speed_limit | num_units | crash_month_1 | crash_month_2 | crash_month_3 | crash_month_4 | crash_month_5 | crash_month_6 | crash_month_7 | crash_month_8 | crash_month_9 | crash_month_10 | crash_month_11 | crash_month_12 | crash_time_of_day_overnight | crash_time_of_day_morning | crash_time_of_day_mid_day | crash_time_of_day_evening | traffic_control_device_BICYCLE CROSSING SIGN | traffic_control_device_DELINEATORS | traffic_control_device_FLASHING CONTROL SIGNAL | traffic_control_device_LANE USE MARKING | traffic_control_device_NO CONTROLS | traffic_control_device_NO PASSING | traffic_control_device_OTHER | traffic_control_device_OTHER RAILROAD CROSSING | traffic_control_device_OTHER REG. SIGN | traffic_control_device_OTHER WARNING SIGN | traffic_control_device_PEDESTRIAN CROSSING SIGN | traffic_control_device_POLICE/FLAGMAN | traffic_control_device_RAILROAD CROSSING GATE | traffic_control_device_RR CROSSING SIGN | traffic_control_device_SCHOOL ZONE | traffic_control_device_STOP SIGN/FLASHER | traffic_control_device_TRAFFIC SIGNAL | traffic_control_device_UNKNOWN | traffic_control_device_YIELD | device_condition_FUNCTIONING IMPROPERLY | device_condition_FUNCTIONING PROPERLY | device_condition_MISSING | device_condition_NO CONTROLS | device_condition_NOT FUNCTIONING | device_condition_OTHER | device_condition_UNKNOWN | device_condition_WORN REFLECTIVE MATERIAL | weather_condition_BLOWING SAND, SOIL, DIRT | weather_condition_BLOWING SNOW | ... | prim_contributory_cause_PHYSICAL CONDITION OF DRIVER | prim_contributory_cause_RELATED TO BUS STOP | prim_contributory_cause_ROAD CONSTRUCTION/MAINTENANCE | prim_contributory_cause_ROAD ENGINEERING/SURFACE/MARKING DEFECTS | prim_contributory_cause_TEXTING | prim_contributory_cause_TURNING RIGHT ON RED | prim_contributory_cause_UNABLE TO DETERMINE | prim_contributory_cause_UNDER THE INFLUENCE OF ALCOHOL/DRUGS (USE WHEN ARREST IS EFFECTED) | prim_contributory_cause_VISION OBSCURED (SIGNS, TREE LIMBS, BUILDINGS, ETC.) | prim_contributory_cause_WEATHER | sec_contributory_cause_ANIMAL | sec_contributory_cause_BICYCLE ADVANCING LEGALLY ON RED LIGHT | sec_contributory_cause_CELL PHONE USE OTHER THAN TEXTING | sec_contributory_cause_DISREGARDING OTHER TRAFFIC SIGNS | sec_contributory_cause_DISREGARDING ROAD MARKINGS | sec_contributory_cause_DISREGARDING STOP SIGN | sec_contributory_cause_DISREGARDING TRAFFIC SIGNALS | sec_contributory_cause_DISREGARDING YIELD SIGN | sec_contributory_cause_DISTRACTION - FROM INSIDE VEHICLE | sec_contributory_cause_DISTRACTION - FROM OUTSIDE VEHICLE | sec_contributory_cause_DISTRACTION - OTHER ELECTRONIC DEVICE (NAVIGATION DEVICE, DVD PLAYER, ETC.) | sec_contributory_cause_DRIVING ON WRONG SIDE/WRONG WAY | sec_contributory_cause_DRIVING SKILLS/KNOWLEDGE/EXPERIENCE | sec_contributory_cause_EQUIPMENT - VEHICLE CONDITION | sec_contributory_cause_EVASIVE ACTION DUE TO ANIMAL, OBJECT, NONMOTORIST | sec_contributory_cause_EXCEEDING AUTHORIZED SPEED LIMIT | sec_contributory_cause_EXCEEDING SAFE SPEED FOR CONDITIONS | sec_contributory_cause_FAILING TO REDUCE SPEED TO AVOID CRASH | sec_contributory_cause_FAILING TO YIELD RIGHT-OF-WAY | sec_contributory_cause_FOLLOWING TOO CLOSELY | sec_contributory_cause_HAD BEEN DRINKING (USE WHEN ARREST IS NOT MADE) | sec_contributory_cause_IMPROPER BACKING | sec_contributory_cause_IMPROPER LANE USAGE | sec_contributory_cause_IMPROPER OVERTAKING/PASSING | sec_contributory_cause_IMPROPER TURNING/NO SIGNAL | sec_contributory_cause_MOTORCYCLE ADVANCING LEGALLY ON RED LIGHT | sec_contributory_cause_NOT APPLICABLE | sec_contributory_cause_OBSTRUCTED CROSSWALKS | sec_contributory_cause_OPERATING VEHICLE IN ERRATIC, RECKLESS, CARELESS, NEGLIGENT OR AGGRESSIVE MANNER | sec_contributory_cause_PASSING STOPPED SCHOOL BUS | sec_contributory_cause_PHYSICAL CONDITION OF DRIVER | sec_contributory_cause_RELATED TO BUS STOP | sec_contributory_cause_ROAD CONSTRUCTION/MAINTENANCE | sec_contributory_cause_ROAD ENGINEERING/SURFACE/MARKING DEFECTS | sec_contributory_cause_TEXTING | sec_contributory_cause_TURNING RIGHT ON RED | sec_contributory_cause_UNABLE TO DETERMINE | sec_contributory_cause_UNDER THE INFLUENCE OF ALCOHOL/DRUGS (USE WHEN ARREST IS EFFECTED) | sec_contributory_cause_VISION OBSCURED (SIGNS, TREE LIMBS, BUILDINGS, ETC.) | sec_contributory_cause_WEATHER | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
209921 | 1 | 41.85 | -87.69 | 25 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
120062 | 0 | 41.98 | -87.67 | 25 | 2 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
205935 | 1 | 41.90 | -87.62 | 35 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
154446 | 1 | 41.92 | -87.76 | 35 | 2 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
156371 | 1 | 41.74 | -87.68 | 35 | 5 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
121631 | 1 | 41.79 | -87.80 | 30 | 2 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
202702 | 0 | 41.94 | -87.69 | 30 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
125175 | 1 | 41.68 | -87.73 | 20 | 2 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
213271 | 1 | 41.89 | -87.62 | 30 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
210973 | 1 | 41.95 | -87.72 | 30 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
93389 rows × 204 columns
I have considered whether I should scale/normalize the dataset, then I realized that most of my features are categorical that are now expanded into 1 hot encoded dummy variables. The other Integer variables I also assume that they can be used as is, as they can represent ordinal values. Anyway, I will be testing the normalization during my steps.
NOTE: Decided that scaling will make sense to be applied on posted_speed_limit and num_units
X_train_Scaler = MinMaxScaler().fit(X_train)
X_train_mms = X_train_Scaler.transform(X_train) # mms = MinMaxScaled
X_train_resampled_Scaler = MinMaxScaler().fit(X_train_resampled)
X_train_resampled_mms = X_train_resampled_Scaler.transform(X_train_resampled)
X_test_mms = X_train_Scaler.transform(X_test)
X_test_rs_mms = X_train_resampled_Scaler.transform(X_test)
X_train_mms
array([[1. , 0.54, 0.59, 0.42, 0.09, ..., 0. , 0. , 0. , 0. , 0. ], [0. , 0.89, 0.65, 0.42, 0.09, ..., 0. , 0. , 0. , 0. , 0. ], [1. , 0.66, 0.77, 0.58, 0.09, ..., 0. , 0. , 0. , 0. , 0. ], [1. , 0.73, 0.43, 0.58, 0.09, ..., 0. , 0. , 0. , 0. , 0. ], [1. , 0.26, 0.61, 0.58, 0.36, ..., 0. , 0. , 0. , 0. , 0. ], ..., [1. , 0.4 , 0.33, 0.5 , 0.09, ..., 0. , 1. , 0. , 0. , 0. ], [0. , 0.78, 0.6 , 0.5 , 0.09, ..., 0. , 0. , 0. , 0. , 0. ], [1. , 0.1 , 0.51, 0.33, 0.09, ..., 0. , 1. , 0. , 0. , 0. ], [1. , 0.65, 0.76, 0.5 , 0.09, ..., 0. , 1. , 0. , 0. , 0. ], [1. , 0.82, 0.52, 0.5 , 0.09, ..., 0. , 1. , 0. , 0. , 0. ]])
def get_GS_cv_results(gs):
_cv_results_ = gs.cv_results_
_scorer_keys = gs.scorer_.keys()
df_gs_scores = pd.DataFrame()
for k in _cv_results_.keys():
if k.startswith('param_'):
param_key = ('param',k.split('_',1)[1])
param_values = _cv_results_[k]
df_gs_scores = pd.concat([df_gs_scores,pd.DataFrame({param_key:param_values})], axis=1)
elif k.startswith('mean_train') or k.startswith('mean_test'):
score_key = (k.split('_',2)[2],k.split('_',2)[1])
score_results = _cv_results_[k]
df_gs_scores = pd.concat([df_gs_scores,pd.DataFrame({score_key:score_results})], axis=1)
return df_gs_scores
def plot_grid_search_scores_grid(gs, display_plots=False):
_cv_results_ = gs.cv_results_
_scorer_keys = gs.scorer_.keys()
_param_keys = list(_cv_results_['params'][0].keys())
df_gs_scores = get_GS_cv_results(gs)
display(df_gs_scores)
if(display_plots):
if len(_param_keys) > 1:
for _score in _scorer_keys:
dd = df_gs_scores[['param',_score]]
dd.columns = dd.columns.droplevel(0)
ddm = dd.melt(id_vars=_param_keys[:2],value_vars=['train','test'],var_name='dataset',value_name=_score)
g = sns.FacetGrid(ddm,col=_param_keys[0],height=3)
g.map_dataframe(sns.lineplot,_param_keys[1],_score,'dataset')
g.add_legend()
else:
for _score in _scorer_keys:
dd = df_gs_scores[['param',_score]]
dd.columns = dd.columns.droplevel(0)
ddm = dd.melt(id_vars=_param_keys,value_vars=['train','test'],var_name='dataset',value_name=_score)
g = sns.FacetGrid(ddm, height=3)
g.map_dataframe(sns.lineplot,_param_keys[0],_score,'dataset')
g.add_legend()
def plot_grid_search_results(gs, plot_scoring='r2'):
_cv_results_ = gs.cv_results_
_scorer_keys = gs.scorer_.keys()
print(f'{StartBold}Estimator:{EndBold}')
print(gs.best_estimator_)
print()
print(f'{StartBold}Best Result (Suggested):{EndBold}')
print(f'{StartBold}\t{"":20} {"train":>6} {"valdn":>6}{EndBold}')
for _score in _scorer_keys:
print('\t{2:>20} {0:>6.2f} {1:>6.2f}'.format(_cv_results_['mean_train_' + _score][gs.best_index_],
_cv_results_['mean_test_' + _score][gs.best_index_],
_score))
print()
print(f'{StartBold}Params:{EndBold}')
print('\t{}'.format(gs.best_params_))
print()
scoring_label = plot_scoring
not_negative = 1
if plot_scoring.startswith('neg_'):
scoring_label = plot_scoring[4:]
not_negative = -1
df = pd.DataFrame({
'params': _cv_results_['params'],
'mean_train_' + plot_scoring: _cv_results_['mean_train_' + plot_scoring] * not_negative,
'mean_test_' + plot_scoring: _cv_results_['mean_test_' + plot_scoring] * not_negative
})
fig, ax = plt.subplots()
df.plot('params',['mean_train_' + plot_scoring,'mean_test_' + plot_scoring], 'line', ax=ax)
ax.axvline(gs.best_index_,0,1, color='r', linestyle='--')
plt.grid(axis='both',c='lightgrey',ls=':')
plt.xticks(ticks=range(df.params.count()), labels=df.params)
plt.tick_params(axis='x', rotation=90)
plt.ylabel(scoring_label)
plt.legend(['train','valdn'])
plt.show()
if("beep" in globals()): beep(True)
def score_classification_model(fitted_model, X, y_true):
p = fitted_model.predict(X)
print('accuracy score: {:>5.2f}'.format(metrics.accuracy_score(y_true,p)))
print('f1 score: {:>5.2f}'.format(metrics.f1_score(y_true,p)))
print('recall score: {:>5.2f}'.format(metrics.recall_score(y_true,p)))
print('precision score: {:>5.2f}'.format(metrics.precision_score(y_true,p)))
print('balanced_accuracy score: {:>5.2f}'.format(metrics.balanced_accuracy_score(y_true,p)))
print('roc_auc score: {:>5.2f}'.format(metrics.roc_auc_score(y_true,p)))
#print(metrics.classification_report(y_true,p))
metrics.plot_confusion_matrix(fitted_model,X, y_true)
plt.show()
#print(metrics.classification_report(y_true,p, sample_weight=compute_sample_weight('balanced',y_true)))
#metrics.plot_confusion_matrix(fitted_model,X, y_true, sample_weight=compute_sample_weight('balanced',y_true))
def plot_coefficients(coef, feature_names, top_n=0):
_ = pd.DataFrame({'features':feature_names,'coef':coef}).sort_values('coef',key=lambda x: abs(x),ascending=False)
if top_n > 0:
_ = _.iloc[:top_n,:]
#display(_)
n_features = _.shape[0]
#plt.subplots(figsize=(10,20))
plt.barh(range(n_features), _['coef'], align='center')
plt.yticks(np.arange(n_features), _['features'])
plt.xlabel('Coefficient Value')
plt.ylabel('Feature')
plt.ylim(n_features,-1)
plt.show()
#compute_class_weight('balanced',classes=[0,1],y=y_train)
#compute_sample_weight('balanced',y_train)
# Common Grid Search Parameters
grid_search_defaults = {
'cv': 3,
'scoring': ['f1', 'recall', 'precision', 'balanced_accuracy', 'accuracy', 'roc_auc'],
'refit': 'f1',
'return_train_score': True,
'error_score': 0,
'verbose': 3,
'n_jobs': n_jobs,
}
# Model 1
# Using scaled training dataset
params = {
'max_depth': np.logspace(1,7,3,base=2),
'min_samples_split': np.linspace(.01,.5,5).round(2),
}
gs_dt1 = GridSearchCV(
DecisionTreeClassifier(random_state=random_state),
params, **grid_search_defaults
)
gs_dt1.fit(
X_train,
y_train
)
plot_grid_search_results(gs_dt1, 'f1')
Fitting 3 folds for each of 15 candidates, totalling 45 fits Estimator: DecisionTreeClassifier(max_depth=16.0, min_samples_split=0.01, random_state=2021) Best Result (Suggested): train valdn f1 0.41 0.40 recall 0.28 0.27 precision 0.78 0.77 balanced_accuracy 0.63 0.63 accuracy 0.89 0.89 roc_auc 0.82 0.80 Params: {'max_depth': 16.0, 'min_samples_split': 0.01}
score_classification_model(gs_dt1, X_test_mms, y_test)
accuracy score: 0.85 f1 score: 0.33 recall score: 0.27 precision score: 0.42 balanced_accuracy score: 0.61 roc_auc score: 0.61
# Model 2
# Using scaled training dataset, balanced class weight
params = {
'max_depth': np.logspace(1,7,3,base=2),
'min_samples_split': np.linspace(.01,.5,5).round(2),
}
gs_dt2 = GridSearchCV(
DecisionTreeClassifier(random_state=random_state, class_weight='balanced'),
params, **grid_search_defaults
)
gs_dt2.fit(
X_train_mms,
y_train
)
plot_grid_search_results(gs_dt2, 'f1')
Fitting 3 folds for each of 15 candidates, totalling 45 fits Estimator: DecisionTreeClassifier(class_weight='balanced', max_depth=16.0, min_samples_split=0.38, random_state=2021) Best Result (Suggested): train valdn f1 0.43 0.43 recall 0.57 0.56 precision 0.35 0.34 balanced_accuracy 0.70 0.69 accuracy 0.79 0.79 roc_auc 0.78 0.78 Params: {'max_depth': 16.0, 'min_samples_split': 0.38}
score_classification_model(gs_dt2, X_test_mms, y_test)
accuracy score: 0.64 f1 score: 0.37 recall score: 0.77 precision score: 0.24 balanced_accuracy score: 0.69 roc_auc score: 0.69
# Model 3
# Using scaled Re-sampled training dataset
params = {
'max_depth': np.logspace(1,7,3,base=2),
'min_samples_split': np.linspace(.01,.5,5).round(2),
}
gs_dt3 = GridSearchCV(
DecisionTreeClassifier(random_state=random_state),
params, **grid_search_defaults
)
gs_dt3.fit(
X_train_resampled_mms,
y_train_resampled
)
plot_grid_search_results(gs_dt3, 'f1')
Fitting 3 folds for each of 15 candidates, totalling 45 fits
/usr/local/Caskroom/miniconda/base/envs/t1/lib/python3.9/site-packages/joblib/externals/loky/process_executor.py:688: UserWarning: A worker stopped while some jobs were given to the executor. This can be caused by a too short worker timeout or by a memory leak. warnings.warn(
Estimator: DecisionTreeClassifier(max_depth=128.0, min_samples_split=0.01, random_state=2021) Best Result (Suggested): train valdn f1 0.75 0.74 recall 0.76 0.75 precision 0.74 0.73 balanced_accuracy 0.75 0.74 accuracy 0.75 0.74 roc_auc 0.84 0.83 Params: {'max_depth': 128.0, 'min_samples_split': 0.01}
score_classification_model(gs_dt3, X_test_mms, y_test)
accuracy score: 0.71 f1 score: 0.40 recall score: 0.70 precision score: 0.28 balanced_accuracy score: 0.71 roc_auc score: 0.71
If the trees doesn't show up in the notebook, check the html export
# Original Training Dataset (Not-Oversampled)
params = {
'max_depth': 5,
'min_samples_split': .1,
# 'class_weight': 'balanced',
}
clf = DecisionTreeClassifier(**params)
clf.fit(X_train_mms, y_train)
fn = X_train.columns.values
cn = ['no_injuries','has_injuries']
_ = export_graphviz(clf,rounded=True,leaves_parallel=True,
out_file=None,
feature_names = X_train.columns.values,
class_names= cn,
filled = True,
proportion= True
)
graphviz.Source(_)
# Resampled Training Dataset (Oversampled)
params = {
'max_depth': 5,
'min_samples_split': .1
}
clf = DecisionTreeClassifier(**params)
clf.fit(X_train_resampled_mms, y_train_resampled)
fn = X_train.columns.values
cn = ['no_injuries','has_injuries']
_ = export_graphviz(clf,rounded=True,leaves_parallel=True,
out_file=None,
feature_names = X_train.columns.values,
class_names= cn,
filled = True,
proportion= True
)
graphviz.Source(_)
# Send notification to my mobile when done (will not work if script not found)
if("prowl_notify" in globals()): prowl_notify('ALL DONE')