%matplotlib inline
import os
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import RDConfig
from rdkit.Chem import DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import Draw
from sklearn.ensemble import RandomForestClassifier
from nonconformist.nc import ClassifierNc
from nonconformist.nc import ClassifierAdapter
from nonconformist.icp import IcpClassifier
from nonconformist.evaluation import ClassIcpCvHelper
from nonconformist.evaluation import cross_val_score
# Solubility train/test sets shipped with the RDKit documentation (Book/data).
train = os.path.join(RDConfig.RDDocsDir, 'Book/data/solubility.train.sdf')
test = os.path.join(RDConfig.RDDocsDir, 'Book/data/solubility.test.sdf')
# SDMolSupplier yields None for records RDKit fails to read or sanitize; drop
# them so the GetProp('SOL_classification') calls below cannot hit None.
trainmol = [m for m in Chem.SDMolSupplier(train) if m is not None]
testmol = [m for m in Chem.SDMolSupplier(test) if m is not None]
# Distinct solubility class labels present in the training set.
labels = {m.GetProp('SOL_classification') for m in trainmol}
print(labels)
{'(A) low', '(C) high', '(B) medium'}
# Map each solubility class label to an integer target, in label order.
label2cls = {name: code for code, name in enumerate(('(A) low', '(B) medium', '(C) high'))}
def fp2arr(fp):
    """Return an RDKit fingerprint as a 1-D float64 NumPy array.

    The array starts as a one-element placeholder;
    ``DataStructs.ConvertToNumpyArray`` fills it in place. In this notebook
    the resulting rows have one column per fingerprint bit (1024).
    """
    dense = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(fp, dense)
    return dense
# Featurize both splits: Morgan fingerprints (radius 2, 1024 bits) turned into
# dense rows via fp2arr, plus integer class targets via label2cls.
trainfps = np.array([fp2arr(AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024))
                     for mol in trainmol])
testfps = np.array([fp2arr(AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024))
                    for mol in testmol])
train_cls = np.array([label2cls[mol.GetProp('SOL_classification')] for mol in trainmol])
test_cls = np.array([label2cls[mol.GetProp('SOL_classification')] for mol in testmol])
print(trainfps.shape, train_cls.shape, testfps.shape, test_cls.shape)
(1025, 1024) (1025,) (257, 1024) (257,)
# The training data is divided into a proper training set and a calibration
# set: the inductive conformal predictor (ICP) fits the model on the former and
# calibrates on the latter.
# Seed the shuffle with the same value as the forest. An unseeded
# np.random.permutation gives a different split on every run, so none of the
# numbers reported below would be reproducible.
SEED = 794
ids = np.random.RandomState(SEED).permutation(train_cls.size)
# Use the first 700 shuffled samples for training and the rest for calibration.
n_fit = 700
trainX, calibX = trainfps[ids[:n_fit], :], trainfps[ids[n_fit:], :]
trainY, calibY = train_cls[ids[:n_fit]], train_cls[ids[n_fit:]]
testX = testfps
testY = test_cls
rf = RandomForestClassifier(n_estimators=500, random_state=SEED)
nc = ClassifierNc(ClassifierAdapter(rf))
icp = IcpClassifier(nc)
icp.fit(trainX, trainY)
icp.calibrate(calibX, calibY)
# Without a significance level this returns per-class scores (p-values in
# nonconformist). Later cells take each row's argmax or threshold the scores.
pred = icp.predict(testX)
# 0/1 prediction-region masks per class at significance 0.05 (95% confidence)
# and 0.2 (80% confidence).
pred95 = icp.predict(testX, significance=0.05).astype(np.int32)
pred80 = icp.predict(testX, significance=0.2).astype(np.int32)
from nonconformist.evaluation import class_n_correct, class_avg_c
# Number of test compounds nonconformist counts as correctly covered at
# significance 0.05 (95% confidence).
class_n_correct(pred, testY, 0.05)
244
# Same count at significance 0.1 (90% confidence).
class_n_correct(pred, testY, 0.1)
232
# Count point predictions (the class with the highest score) that match the
# true class. For each compound, print: true class, predicted class, match?,
# 80% region : 95% region.
tp = 0
for true_cls, scores, region80, region95 in zip(testY, pred, pred80, pred95):
    best = np.argmax(scores)
    hit = true_cls == best
    print(true_cls, best, hit, region80, ":", region95)
    if hit:
        tp += 1
0 2 False [0 1 1] : [1 1 1] 0 1 False [0 1 1] : [1 1 1] 1 1 True [0 1 0] : [1 1 0] 1 1 True [0 1 0] : [1 1 0] 1 1 True [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [1 1 1] 0 0 True [1 0 0] : [1 0 0] 1 0 False [1 1 0] : [1 1 1] 0 1 False [0 1 0] : [1 1 0] 0 1 False [1 1 0] : [1 1 1] 0 0 True [1 0 0] : [1 1 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 1 0] : [1 1 1] 0 0 True [1 0 0] : [1 1 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 1 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 1 0 False [1 1 0] : [1 1 1] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 1 1 True [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [1 1 1] 1 1 True [0 1 0] : [0 1 1] 1 0 False [1 1 0] : [1 1 0] 1 1 True [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [1 1 1] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 1 1 True [0 1 0] : [0 1 0] 0 0 True [1 0 0] : [1 0 0] 0 1 False [0 1 0] : [0 1 0] 1 0 False [1 0 0] : [1 1 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 1 1 True [0 1 0] : [1 1 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 2 2 True [0 0 1] : [0 1 1] 2 2 True [0 0 1] : [0 1 1] 2 2 True [0 0 1] : [0 0 1] 1 2 False [0 0 1] : [0 1 1] 2 2 True [0 0 1] : [0 0 1] 1 1 True [1 1 1] : [1 1 1] 1 2 False [0 0 1] : [0 1 1] 1 2 False [0 0 1] : [0 0 1] 2 2 True [0 0 1] : [0 0 1] 1 1 True [0 1 0] : [1 1 1] 0 0 True [1 0 0] : [1 1 0] 1 2 False [0 1 1] : [1 1 1] 2 1 False [0 1 0] : [0 1 1] 2 1 False [0 1 0] : [1 1 1] 1 1 True [1 1 0] : [1 1 1] 1 0 False [1 0 0] : [1 1 0] 2 2 True [0 1 1] : [1 1 1] 0 1 False [0 
1 0] : [0 1 0] 2 1 False [0 1 1] : [1 1 1] 2 0 False [1 1 1] : [1 1 1] 1 0 False [1 1 0] : [1 1 1] 2 1 False [0 1 0] : [1 1 1] 0 1 False [0 1 0] : [0 1 0] 2 2 True [0 0 1] : [0 0 1] 0 0 True [1 0 0] : [1 0 0] 1 1 True [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [1 1 1] 1 1 True [1 1 0] : [1 1 1] 0 1 False [0 1 0] : [0 1 0] 0 0 True [1 0 0] : [1 0 0] 2 2 True [0 0 1] : [0 1 1] 2 2 True [0 0 1] : [0 0 1] 1 2 False [0 0 1] : [0 0 1] 1 2 False [0 0 1] : [0 1 1] 1 0 False [1 0 0] : [1 1 1] 0 0 True [1 0 0] : [1 0 0] 1 1 True [0 1 0] : [0 1 0] 0 0 True [1 0 0] : [1 0 0] 2 2 True [0 0 1] : [0 0 1] 2 2 True [0 0 1] : [1 1 1] 1 2 False [0 1 1] : [1 1 1] 1 0 False [1 1 0] : [1 1 1] 1 0 False [1 1 0] : [1 1 1] 2 2 True [0 0 1] : [1 1 1] 1 0 False [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 1 1 True [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [0 1 1] 1 1 True [1 1 0] : [1 1 1] 1 0 False [1 0 0] : [1 1 0] 0 1 False [1 1 0] : [1 1 1] 0 0 True [1 0 0] : [1 0 0] 1 1 True [0 1 0] : [1 1 0] 1 1 True [0 1 0] : [1 1 0] 1 1 True [0 1 0] : [1 1 1] 1 0 False [1 1 0] : [1 1 1] 1 0 False [1 0 0] : [1 1 1] 2 1 False [0 1 0] : [1 1 1] 1 1 True [1 1 1] : [1 1 1] 2 1 False [0 1 1] : [1 1 1] 1 1 True [1 1 0] : [1 1 1] 2 1 False [0 1 0] : [1 1 1] 0 1 False [0 1 0] : [1 1 0] 1 1 True [0 1 0] : [1 1 1] 1 1 True [0 1 0] : [0 1 1] 1 1 True [0 1 0] : [0 1 0] 0 0 True [1 0 0] : [1 1 0] 1 1 True [0 1 0] : [0 1 0] 0 0 True [1 0 0] : [1 0 0] 1 1 True [0 1 0] : [1 1 1] 2 1 False [0 1 0] : [1 1 1] 1 1 True [0 1 0] : [0 1 0] 0 1 False [0 1 0] : [0 1 0] 0 2 False [1 1 1] : [1 1 1] 0 0 True [1 0 0] : [1 0 0] 1 1 True [0 1 0] : [0 1 0] 1 0 False [1 1 0] : [1 1 1] 1 1 True [0 1 0] : [1 1 0] 1 1 True [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [1 1 0] 0 1 False [1 1 0] : [1 1 1] 1 1 True [1 1 1] : [1 1 1] 2 1 False [0 1 0] : [1 1 1] 1 2 False [1 1 1] : [1 1 1] 2 2 True [0 0 1] : [1 1 1] 2 2 True [0 0 1] : [0 0 1] 2 1 False [0 1 1] : [1 1 1] 2 2 True [0 
1 1] : [1 1 1] 2 2 True [0 0 1] : [0 0 1] 0 1 False [1 1 0] : [1 1 0] 2 2 True [0 0 1] : [0 0 1] 2 2 True [0 0 1] : [0 1 1] 1 1 True [0 1 0] : [1 1 1] 2 1 False [0 1 0] : [0 1 1] 0 1 False [1 1 0] : [1 1 0] 1 1 True [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [0 1 1] 1 0 False [1 1 0] : [1 1 1] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 1 0] 0 0 True [1 0 0] : [1 0 0] 0 1 False [1 1 0] : [1 1 1] 2 1 False [0 1 1] : [1 1 1] 2 2 True [0 0 1] : [0 1 1] 1 1 True [0 1 0] : [0 1 0] 2 2 True [0 1 1] : [1 1 1] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 1 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 2 2 True [0 0 1] : [1 1 1] 1 0 False [1 1 0] : [1 1 1] 1 1 True [0 1 0] : [0 1 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 1 0] 0 0 True [1 1 0] : [1 1 1] 1 1 True [1 1 0] : [1 1 1] 1 1 True [0 1 1] : [1 1 1] 0 0 True [1 0 0] : [1 1 1] 1 0 False [1 0 0] : [1 0 0] 1 1 True [0 1 0] : [1 1 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 1 0 False [1 1 0] : [1 1 1] 1 0 False [1 0 1] : [1 1 1] 1 1 True [1 1 0] : [1 1 1] 0 0 True [1 0 0] : [1 1 0] 1 2 False [0 0 1] : [1 1 1] 0 1 False [0 1 0] : [1 1 1] 1 1 True [0 1 0] : [1 1 1] 2 2 True [0 0 1] : [0 0 1] 2 2 True [0 1 1] : [1 1 1] 2 2 True [0 0 1] : [0 1 1] 1 1 True [0 1 0] : [1 1 1] 2 2 True [0 0 1] : [1 1 1] 1 2 False [0 1 1] : [1 1 1] 1 1 True [0 1 0] : [1 1 1] 0 2 False [0 0 1] : [1 1 1] 2 2 True [1 0 1] : [1 1 1] 0 1 False [0 1 0] : [1 1 0] 1 1 True [0 1 1] : [1 1 1] 1 2 False [0 0 1] : [0 1 1] 1 1 True [1 1 0] : [1 1 1] 0 0 True [1 0 0] : [1 1 0] 1 1 True [0 1 0] : [1 1 0] 1 1 True [0 1 0] : [1 1 0] 1 1 True [0 1 0] : [0 1 0] 1 0 False [1 1 0] : [1 1 1] 1 2 False [0 0 1] : [0 1 1] 1 2 False [0 1 1] : [1 1 1] 0 0 True [1 0 0] : [1 0 0] 1 0 False [1 1 0] : [1 1 1] 1 0 False [1 1 0] : [1 1 1] 1 1 True [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [1 1 1] 1 0 False [1 0 0] : [1 0 0] 1 0 False [1 0 0] : [1 1 0] 0 1 False [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [1 1 1] 1 0 False [1 1 1] : [1 1 1] 1 1 True [1 1 
0] : [1 1 1] 1 0 False [1 0 0] : [1 1 1] 1 0 False [1 1 1] : [1 1 1] 1 0 False [1 0 0] : [1 1 1] 0 1 False [1 1 0] : [1 1 1] 1 1 True [0 1 1] : [1 1 1] 0 0 True [1 0 0] : [1 1 0] 0 1 False [1 1 0] : [1 1 1] 0 0 True [1 1 0] : [1 1 0] 1 1 True [1 1 0] : [1 1 1] 1 1 True [0 1 0] : [1 1 1] 1 0 False [1 1 0] : [1 1 1] 0 0 True [1 0 0] : [1 1 0] 0 0 True [1 1 0] : [1 1 0] 0 0 True [1 1 0] : [1 1 1] 0 0 True [1 0 0] : [1 0 0] 1 1 True [0 1 0] : [0 1 0] 1 1 True [0 1 0] : [1 1 0] 0 1 False [0 1 0] : [1 1 0] 1 1 True [0 1 0] : [1 1 1] 0 0 True [1 1 0] : [1 1 1] 0 0 True [1 1 0] : [1 1 1] 0 0 True [1 0 0] : [1 1 0] 1 1 True [1 1 0] : [1 1 1] 1 1 True [0 1 0] : [1 1 1] 0 0 True [1 0 0] : [1 1 1] 1 1 True [1 1 0] : [1 1 1] 0 2 False [0 0 1] : [0 1 1] 0 2 False [0 1 1] : [1 1 1] 1 0 False [1 0 0] : [1 1 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 1 1] 0 0 True [1 0 0] : [1 0 0] 0 0 True [1 0 0] : [1 1 0] 2 2 True [0 0 1] : [1 1 1] 0 0 True [1 0 0] : [1 1 0] 0 0 True [1 0 0] : [1 0 0]
# Fraction of test compounds whose argmax prediction was correct.
print(tp / testY.shape[0])
0.6848249027237354
# Average prediction-region size (nonconformist's class_avg_c) at the two
# significance levels used for pred95 and pred80.
for sig in (0.05, 0.2):
    print(class_avg_c(pred, testY, significance=sig))
2.0155642023346303 1.2996108949416343
# Repeated cross-validation (10 iterations) of the ICP on the full training
# set, scored with class_avg_c at three significance levels.
cv_settings = {
    'iterations': 10,
    'scoring_funcs': [class_avg_c],
    'significance_levels': [0.05, 0.1, 0.2],
}
icpmodel = ClassIcpCvHelper(icp)
res = cross_val_score(icpmodel, trainfps, train_cls, **cv_settings)
# Display the first ten result rows.
res.head(10)
| | iter | fold | significance | class_avg_c |
---|---|---|---|---|
0 | 0 | 0 | 0.05 | 2.135922 |
1 | 0 | 0 | 0.10 | 1.708738 |
2 | 0 | 0 | 0.20 | 1.330097 |
3 | 0 | 1 | 0.05 | 2.495146 |
4 | 0 | 1 | 0.10 | 1.922330 |
5 | 0 | 1 | 0.20 | 1.281553 |
6 | 0 | 2 | 0.05 | 2.194175 |
7 | 0 | 2 | 0.10 | 1.582524 |
8 | 0 | 2 | 0.20 | 1.271845 |
9 | 0 | 3 | 0.05 | 2.000000 |