We are getting ~80% correct predictions. This might be because the ten digit signatures are insufficient to capture the variation within each digit. We might try to define the different flavors of each digit using a round of k-means clustering and labeling in the training dataset, followed by a "narrow" signature definition, and then prediction on the test dataset using this narrow set of signatures. Also, visualizing these digits as images will be useful. We can also carry over the similarity score as a value-based category or use it to generate an ROC curve.
from clustergrammer2 import net
df = {}
import clustergrammer_groupby as cby
from copy import deepcopy
import random
random.seed(99)
import pandas as pd
# Load the MNIST file into net and keep the exported DataFrame under 'mnist'.
mnist_path = '../data/big_data/MNIST_row_labels.txt'
net.load_file(mnist_path)
df['mnist'] = net.export_df()
df['mnist'].shape
(784, 70000)
# Attach a 'Digit: <name>' category to every column; the digit name is the
# part of the column label that precedes the first '-'.
new_cols = []
for col_name in df['mnist'].columns.tolist():
    digit_name = col_name.split('-')[0]
    new_cols.append((col_name, 'Digit: ' + digit_name))

labeled = deepcopy(df['mnist'])
labeled.columns = new_cols
df['mnist-cat'] = labeled
print(new_cols[0])
('Zero-0', 'Digit: Zero')
# Shuffle the labeled columns and split them: the first 35000 become the
# training set and the remainder the prediction set.
cols = df['mnist-cat'].columns.tolist()
random.shuffle(cols)
train_cols, pred_cols = cols[:35000], cols[35000:]
df['mnist-train'] = df['mnist-cat'][train_cols]
df['mnist-pred'] = df['mnist-cat'][pred_cols]
print(df['mnist-train'].shape, df['mnist-pred'].shape)

# Row-wise z-score each split separately and store it under a '-z' key.
for split_key in ('mnist-train', 'mnist-pred'):
    net.load_df(df[split_key])
    net.normalize(axis='row', norm_type='zscore')
    df[split_key + '-z'] = net.export_df()
(784, 35000) (784, 35000)
def set_cat_colors(cat_color, axis, cat_index, cat_title=False):
    """Register one color per category name on the module-level ``net``.

    ``cat_color`` maps category names to colors. ``axis`` and ``cat_index``
    are forwarded unchanged to ``net.set_cat_color``. When ``cat_title`` is
    supplied each name is registered as ``'<cat_title>: <name>'``; with the
    default ``False`` the bare name is used.
    """
    for label, color in cat_color.items():
        # Equality (not identity) test kept on purpose so that the set of
        # values treated as "no title" is unchanged.
        if cat_title == False:  # noqa: E712
            full_name = label
        else:
            full_name = cat_title + ': ' + label
        net.set_cat_color(axis=axis, cat_index=cat_index,
                          cat_name=full_name, inst_color=color)
# Build digit signatures from the raw and the z-scored training data and
# store them under 'sig' and 'sig-z'.
# NOTE(review): pval_cutoff is assigned here but is not passed to
# generate_signatures in this cell -- confirm whether that is intended.
pval_cutoff = 0.00001
num_top_dims = 50
for inst_norm in ['', '-z']:
    sig_key = 'sig' + inst_norm
    sig_out = cby.generate_signatures(
        df['mnist-train' + inst_norm], 'Digit', num_top_dims=num_top_dims)
    df[sig_key], keep_genes_dict, df_gene_pval, fold_info = sig_out
    print(inst_norm, df[sig_key].shape)
(276, 10)
/Users/nickfernandez/anaconda3/envs/py36lab/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:879: RuntimeWarning: invalid value encountered in greater return (self.a < x) & (x < self.b) /Users/nickfernandez/anaconda3/envs/py36lab/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:879: RuntimeWarning: invalid value encountered in less return (self.a < x) & (x < self.b) /Users/nickfernandez/anaconda3/envs/py36lab/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:1821: RuntimeWarning: invalid value encountered in less_equal cond2 = cond0 & (x <= self.a)
-z (276, 10)
At the coarse-grained level we appear to be able to distinguish the following groups of digits: One; Three-Five-Eight; Four-Seven-Nine; and Zero-Two-Six.
# Coarse groups of digits: each key is the group label and each value lists
# the digit names that belong to that group.
coarse_digits = {
    'One': ['One'],
    'Three-Five-Eight': ['Three', 'Five', 'Eight'],
    'Four-Seven-Nine': ['Four', 'Seven', 'Nine'],
    'Zero-Two-Six': ['Zero', 'Two', 'Six'],
}
# Draw a fresh random train/prediction split of the labeled columns.
cols = df['mnist-cat'].columns.tolist()
random.shuffle(cols)
train_cols, pred_cols = cols[:35000], cols[35000:]
df['mnist-train'] = df['mnist-cat'][train_cols]
df['mnist-pred'] = df['mnist-cat'][pred_cols]
print(df['mnist-train'].shape, df['mnist-pred'].shape)

# For every coarse group keep only the training columns whose digit belongs
# to the group, then store that subset and its row-wise z-scored version.
for inst_group, members in coarse_digits.items():
    keep_cols = []
    for col in df['mnist-train']:
        if col[1].split(': ')[1] in members:
            keep_cols.append(col)
    df[inst_group] = df['mnist-train'][keep_cols]
    print(inst_group, df[inst_group].shape)
    net.load_df(df[inst_group])
    net.normalize(axis='row', norm_type='zscore')
    df[inst_group + '-z'] = net.export_df()
(784, 35000) (784, 35000) One (784, 3926) Three-Five-Eight (784, 10101) Four-Seven-Nine (784, 10575) Zero-Two-Six (784, 10398)
# Generate Signatures
# Signatures are built inside each multi-digit coarse group (labels that
# contain '-'), for both the raw and the z-scored subset of that group.
pval_cutoff = 1e-10
num_top_dims = 50
for inst_group in coarse_digits:
    if '-' not in inst_group:
        # Single-digit group: nothing to separate within it.
        continue
    for inst_norm in ['', '-z']:
        data_key = inst_group + inst_norm
        print(inst_group)
        sig_out = cby.generate_signatures(
            df[data_key], 'Digit',
            pval_cutoff=pval_cutoff, num_top_dims=num_top_dims)
        df['sig-' + data_key], keep_genes_dict, df_gene_pval, fold_info = sig_out
        print(data_key, df['sig-' + data_key].shape)
Three-Five-Eight Three-Five-Eight (103, 3) Three-Five-Eight
/Users/nickfernandez/anaconda3/envs/py36lab/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:879: RuntimeWarning: invalid value encountered in greater return (self.a < x) & (x < self.b) /Users/nickfernandez/anaconda3/envs/py36lab/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:879: RuntimeWarning: invalid value encountered in less return (self.a < x) & (x < self.b) /Users/nickfernandez/anaconda3/envs/py36lab/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:1821: RuntimeWarning: invalid value encountered in less_equal cond2 = cond0 & (x <= self.a)
Three-Five-Eight-z (103, 3) Four-Seven-Nine Four-Seven-Nine (99, 3) Four-Seven-Nine Four-Seven-Nine-z (99, 3) Zero-Two-Six Zero-Two-Six (130, 3) Zero-Two-Six Zero-Two-Six-z (130, 3)
# Inspect the shape of the signature stored for the Four-Seven-Nine group.
df['sig-Four-Seven-Nine'].shape
(99, 3)
# Cluster the digit signatures, then read the column category colors from
# net.viz and key them by the text after ': ' in each category name.
net.load_df(df['sig'])
net.cluster()
tmp_cat_color = deepcopy(net.viz['cat_colors']['col']['cat-0'])
cat_color = {
    full_name.split(': ')[1]: color
    for full_name, color in tmp_cat_color.items()
}

# Override a handful of digits with hand-picked colors.
cat_color.update({
    'Zero': 'yellow',
    'Four': 'blue',
    'Seven': 'red',
    'Nine': 'grey',
    'One': 'black',
})
set_cat_colors(cat_color, axis='col', cat_index=1, cat_title='Digit')
cat_color
{'Eight': '#393b79', 'Five': '#ff7f0e', 'Four': 'blue', 'Nine': 'grey', 'One': 'black', 'Seven': 'red', 'Six': '#FFDB58', 'Three': '#e377c2', 'Two': '#2ca02c', 'Zero': 'yellow'}
# Load the digit signatures into net and display them as a widget.
net.load_df(df['sig'])
net.widget()
ExampleWidget(network='{"row_nodes": [{"name": "pos_10-10", "ini": 276, "clust": 234, "rank": 218, "rankvar": …
# Predict
##################
# Predict a digit for every column of the prediction set from the raw digit
# signatures, then print the overall fraction correct and show the
# per-category series sorted from best to worst.
predict_kwargs = dict(truth_level=1, predict_level='Pred Digit', unknown_thresh=0.0)
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(
    df['mnist-pred'], df['sig'], **predict_kwargs)
conf_out = cby.confusion_matrix_and_correct_series(y_info)
df_conf, population, ser_correct, fraction_correct = conf_out
print('Predict: ', fraction_correct)
ser_correct.sort_values(ascending=False)
Predict: 0.816485714286
One 0.933435 Zero 0.901810 Six 0.884370 Seven 0.833700 Three 0.812517 Four 0.803198 Two 0.789207 Nine 0.780338 Eight 0.754257 Five 0.643282 dtype: float64
We will test running the narrow prediction on the broad digit groups.
# Collapse the per-digit truth and prediction labels held in y_info onto
# their broad groups and score the result at the broad level.
merge_358 = ['Three', 'Five', 'Eight']
merge_479 = ['Four', 'Seven', 'Nine']
merge_026 = ['Zero', 'Two', 'Six']

# Digit -> broad-group label. Labels that are not listed in any merge list
# (for example 'One') are left unchanged by the lookups below.
broad_lookup = {}
for broad_name, members in (('Three-Five-Eight', merge_358),
                            ('Four-Seven-Nine', merge_479),
                            ('Zero-Two-Six', merge_026)):
    for digit in members:
        # setdefault keeps the first group that claims a digit.
        broad_lookup.setdefault(digit, broad_name)

# One shared lookup replaces the two copy-pasted if-chains.
inst_true = [broad_lookup.get(inst_cat, inst_cat) for inst_cat in y_info['true']]
inst_pred = [broad_lookup.get(inst_cat, inst_cat) for inst_cat in y_info['pred']]
y_broad = {'true': inst_true, 'pred': inst_pred}

# The confusion matrix is computed once; both reports below reuse it
# (the original recomputed the identical result a second time).
df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_broad)
print('Predict: ', fraction_correct)
print('\nbroad cell type: ', fraction_correct, '\n')
print(ser_correct.sort_values(ascending=False))
Predict: 0.900114285714 broad cell type: 0.900114285714 One 0.933435 Four-Seven-Nine 0.910286 Zero-Two-Six 0.906856 Three-Five-Eight 0.869817 dtype: float64
# Predict
##################
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(
    df['mnist-pred'], df['sig'],
    truth_level=1, predict_level='Broad Digit', unknown_thresh=0.0)

# Digit -> coarse group. If a digit were listed in several groups the last
# group in coarse_digits iteration order would win.
digit_to_group = {}
for group_name, members in coarse_digits.items():
    for digit in members:
        digit_to_group[digit] = group_name

# Replace the third element of every predicted column label with its broad
# group; labels whose digit is in no group keep their original third element.
broad_cols = []
for col in df_pred_cat.columns.tolist():
    predicted = col[2].split(': ')[1]
    if predicted in digit_to_group:
        broad_label = 'Broad Digit: ' + digit_to_group[predicted]
    else:
        broad_label = col[2]
    broad_cols.append((col[0], col[1], broad_label))

# The new labels are applied positionally to a copy of the prediction set.
# NOTE(review): assumes df_pred_cat keeps the column order of
# df['mnist-pred'] -- confirm.
df['pred-broad'] = deepcopy(df['mnist-pred'])
df['pred-broad'].columns = broad_cols
# Re-run prediction on individual broad digits if necessary
df_list = []
all_cols = df['pred-broad'].columns.tolist()
for inst_broad in coarse_digits:
    keep_cols = [col for col in all_cols if col[2].split(': ')[1] == inst_broad]
    inst_df = df['pred-broad'][keep_cols]
    print(inst_broad, inst_df.shape)

    if '-' not in inst_broad:
        # Single-digit group: keep the broad call and only rename its
        # category title from 'Broad Digit:' to 'Digit:'.
        inst_df.columns = [
            (col[0], col[1], col[2].replace('Broad Digit:', 'Digit:'))
            for col in inst_df.columns.tolist()
        ]
    else:
        # drop previous prediction
        inst_df.columns = [(col[0], col[1]) for col in inst_df.columns.tolist()]
        # predict using narrow signature
        df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(
            inst_df, df['sig-' + inst_broad],
            truth_level=1, predict_level='Digit', unknown_thresh=0.0)
        inst_df.columns = df_pred_cat.columns.tolist()
    df_list.append(inst_df)

# Stitch the per-group frames back together column-wise.
df['pred-narrow'] = pd.concat(df_list, axis=1)
df['pred-narrow'].shape
One (784, 4456) Three-Five-Eight (784, 9918) Four-Seven-Nine (784, 10312) Zero-Two-Six (784, 10314)
(784, 35000)
# Score the combined prediction: element 1 of each column label is used as
# the truth and element 2 as the prediction (text after ': ' in both cases).
cols = df['pred-narrow'].columns.tolist()
cols[0]
true_labels, pred_labels = [], []
for col in cols:
    true_labels.append(col[1].split(': ')[1])
    pred_labels.append(col[2].split(': ')[1])
y_info = {'true': true_labels, 'pred': pred_labels}
conf_out = cby.confusion_matrix_and_correct_series(y_info)
df_conf, population, ser_correct, fraction_correct = conf_out
print('Predict: ', fraction_correct)
ser_correct.sort_values(ascending=False)
Predict: 0.8172
One 0.933435 Zero 0.922872 Six 0.877687 Seven 0.825165 Four 0.815407 Three 0.812238 Nine 0.805134 Two 0.777122 Eight 0.751615 Five 0.622612 dtype: float64
# Color the 'Pred Digit' category and display a random sample of 2500 columns.
# NOTE(review): df_pred_cat here is whatever the most recent
# predict_cats_from_sigs call left behind (presumably the last narrow
# prediction) -- confirm this is the frame intended for display.
df_pred_cat.shape
net.load_df(df_pred_cat)
set_cat_colors(cat_color, axis='col', cat_index=2, cat_title='Pred Digit')
# Reload the frame after setting colors, then sample columns for display.
net.load_df(df_pred_cat)
net.random_sample(axis='col', num_samples=2500, random_state=100)
net.widget()
# Predict
##################
# Predict digits for the z-scored prediction set using the signatures built
# from the z-scored training set.
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(
    df['mnist-pred-z'], df['sig-z'],
    truth_level=1, predict_level='Pred Digit', unknown_thresh=0.0)
# Give the similarity frame the same column labels as the prediction frame.
df_sig_sim.columns = df_pred_cat.columns.tolist()

# Score once and reuse the result for both reports (the original called
# confusion_matrix_and_correct_series twice on the same y_info).
df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('Predict: ', fraction_correct)
print('\nbroad cell type: ', fraction_correct, '\n')
print(ser_correct.sort_values(ascending=False))
df_conf.shape
# net.load_df(df_conf)
# net.widget()
# Display a random sample of 2500 columns of the similarity frame, with
# values rounded to two decimals.
df_sig_sim.shape
net.load_df(df_sig_sim)
net.random_sample(axis='col', num_samples=2500, random_state=99)
net.load_df(net.export_df().round(2))
net.widget()
# Same display for the prediction frame, using the same random_state.
net.load_df(df_pred_cat)
net.random_sample(axis='col', num_samples=2500, random_state=99)
net.load_df(net.export_df().round(2))
net.widget()
# Draw another random train/prediction split of the labeled columns.
cols = df['mnist-cat'].columns.tolist()
random.shuffle(cols)
train_cols, pred_cols = cols[:35000], cols[35000:]
df['mnist-train'] = df['mnist-cat'][train_cols]
df['mnist-pred'] = df['mnist-cat'][pred_cols]
print(df['mnist-train'].shape, df['mnist-pred'].shape)

# Digit -> coarse label; digits outside the three merge lists keep their own
# name as the coarse label.
coarse_of = {}
for coarse_name, members in (('Three-Five-Eight', merge_358),
                             ('Four-Seven-Nine', merge_479),
                             ('Zero-Two-Six', merge_026)):
    for digit in members:
        coarse_of.setdefault(digit, coarse_name)

# Insert a 'Coarse: <group>' category ahead of the digit category in every
# column label, then store the relabeled frame and its row-wise z-score.
for inst_data in ['mnist-train', 'mnist-pred']:
    new_cols = []
    for col in df[inst_data]:
        digit = col[1].split(': ')[1]
        new_cols.append((col[0], 'Coarse: ' + coarse_of.get(digit, digit), col[1]))
    coarse_key = inst_data + '-coarse'
    relabeled = deepcopy(df[inst_data])
    relabeled.columns = new_cols
    df[coarse_key] = relabeled
    print(df[coarse_key].shape)
    net.load_df(df[coarse_key])
    net.normalize(axis='row', norm_type='zscore')
    df[coarse_key + '-z'] = net.export_df()
# Build signatures for the 'Coarse' category from the raw and the z-scored
# coarse-labeled training data; stored under 'sig-broad' and 'sig-broad-z'.
# NOTE(review): pval_cutoff is assigned here but is not passed to
# generate_signatures in this cell -- confirm whether that is intended.
pval_cutoff = 0.00001
num_top_dims = 50
for inst_norm in ['', '-z']:
    sig_key = 'sig-broad' + inst_norm
    sig_out = cby.generate_signatures(
        df['mnist-train-coarse' + inst_norm], 'Coarse', num_top_dims=num_top_dims)
    df[sig_key], keep_genes_dict, df_gene_pval, fold_info = sig_out
    print(inst_norm, df[sig_key].shape)
We need to first predict the broad digit groups, then separate the samples by broad category and predict within each one using its narrow signature.
# Predict
##################
# Predict from the 'sig-broad' signatures on the coarse-labeled prediction
# set; truth is read from category level 1 of the column labels.
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(
    df['mnist-pred-coarse'], df['sig-broad'],
    truth_level=1, predict_level='Pred Digit', unknown_thresh=0.0)

# Score once and reuse the result for both reports (the original called
# confusion_matrix_and_correct_series twice on the same y_info).
df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('Predict: ', fraction_correct)
print('\nbroad cell type: ', fraction_correct, '\n')
print(ser_correct.sort_values(ascending=False))
# Predict
##################
# NOTE(review): no assignment to df['mnist-pred-358-z'] is visible in this
# file; unless it is created elsewhere this lookup raises KeyError --
# confirm the intended key.
df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(
    df['mnist-pred-358-z'], df['sig-z'],
    truth_level=1, predict_level='Pred Digit', unknown_thresh=0.0)

# Score once and reuse the result for both prints (the original called
# confusion_matrix_and_correct_series twice on the same y_info).
df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)
print('Predict: ', fraction_correct)
print(ser_correct.sort_values(ascending=False))
# Display a random sample of 2500 columns of the latest prediction frame,
# with values rounded to two decimals.
df_pred_cat.shape
net.load_df(df_pred_cat)
net.random_sample(axis='col', num_samples=2500, random_state=99)
net.load_df(net.export_df().round(2))
net.widget()
# Add colors for the three multi-digit coarse groups.
cat_color.update({
    'Three-Five-Eight': 'red',
    'Four-Seven-Nine': 'blue',
    'Zero-Two-Six': 'yellow',
})

# Apply the palette to the 'Coarse' category (index 1) and to the
# 'Pred Digit' category (index 3) of the prediction frame.
net.load_df(df_pred_cat)
for color_index, color_title in ((1, 'Coarse'), (3, 'Pred Digit')):
    set_cat_colors(cat_color, axis='col', cat_index=color_index, cat_title=color_title)

# Reload, sample 2500 columns, round to two decimals and display.
net.load_df(df_pred_cat)
net.random_sample(axis='col', num_samples=2500, random_state=99)
rounded = net.export_df().round(2)
net.load_df(rounded)
net.widget()