#!/usr/bin/env python
# coding: utf-8
"""Demo of the prototype ``sklearn_plot_api.plot_roc_curve`` display objects.

Fits three classifiers on a synthetic binary classification problem, plots
their ROC curves on the held-out split, restyles one curve after the fact,
and replots the stored display objects onto a fresh axes.

Exported from a Jupyter notebook: the ``# In[n]:`` markers are cell
boundaries. Bare expressions at the end of a cell (e.g. ``viz_lr.figure_``)
only render when run inside IPython/Jupyter; in a plain script they are
no-ops.
"""

# In[1]:

# The notebook magics only exist under IPython (``get_ipython`` is injected
# into builtins there). Skip them when the export is run as a plain Python
# script instead of failing with NameError.
try:
    _ipython = get_ipython()  # noqa: F821 -- provided by IPython
except NameError:
    _ipython = None
if _ipython is not None:
    _ipython.run_line_magic('load_ext', 'autoreload')
    _ipython.run_line_magic('autoreload', '2')
    _ipython.run_line_magic('matplotlib', 'inline')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# In[2]:

plt.rcParams['figure.figsize'] = [14, 9]


# In[3]:

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# One seed for the data, the split and the stochastic ensembles, so the
# plotted curves are identical from run to run.
RANDOM_STATE = 0

X, y = make_classification(n_samples=10000, n_informative=5,
                           random_state=RANDOM_STATE)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=RANDOM_STATE)

lr = LogisticRegression(solver='lbfgs', max_iter=1000)
gbc = GradientBoostingClassifier(random_state=RANDOM_STATE)
rfc = RandomForestClassifier(n_estimators=100, random_state=RANDOM_STATE)

lr.fit(X_train, y_train)
gbc.fit(X_train, y_train)
rfc.fit(X_train, y_train)


# ## First plot

# In[9]:

from sklearn_plot_api import plot_roc_curve

# Score on the held-out split only. The full X/y also contains the rows the
# models were trained on, which biases every ROC curve upward.
viz_lr = plot_roc_curve(lr, X_test, y_test)


# ### Change line color

# In[10]:

# The returned display object keeps a handle to the drawn line (``line_``)
# and its figure (``figure_``), so the curve can be restyled after plotting.
viz_lr.line_.set_color('red')
viz_lr.figure_


# ## Plot multiple - Function Call

# In[11]:

# Passing ``ax`` draws every model's curve onto one shared axes for
# side-by-side comparison.
fig, ax = plt.subplots()
viz_lr = plot_roc_curve(lr, X_test, y_test, ax=ax)
viz_gbc = plot_roc_curve(gbc, X_test, y_test, ax=ax)
viz_rfc = plot_roc_curve(rfc, X_test, y_test, ax=ax)
ax.legend()


# ## Replot using ax

# In[12]:

# ``plot`` is given no estimator or data here, so it presumably redraws from
# values cached on each display object -- TODO confirm in sklearn_plot_api.
fig, ax = plt.subplots()
viz_lr.plot(ax=ax)
viz_gbc.plot(ax=ax)
viz_rfc.plot(ax=ax)
ax.legend()


# ## View Figure

# In[13]:

# NOTE(review): after the replot above, ``figure_`` presumably refers to the
# newest figure the display drew into -- confirm in sklearn_plot_api.
viz_lr.figure_