# Install necessary modules !pip install timm # For our convenience, take a peek at what we're working with !nvidia-smi # Import the core modules, check which GPU we end up with and scale batch size accordingly import torch # Flipping this on/off will change the memory dynamics, since I usually # validate and train with it on, will leave it on by default torch.backends.cudnn.benchmark = True import timm from timm.data import * from timm.utils import * import pynvml from collections import OrderedDict import logging import time def log_gpu_memory(): handle = pynvml.nvmlDeviceGetHandleByIndex(0) info = pynvml.nvmlDeviceGetMemoryInfo(handle) info.free = round(info.free / 1024**2) info.used = round(info.used / 1024**2) logging.info('GPU memory free: {}, memory used: {}'.format(info.free, info.used)) return info.used def get_gpu_memory_total(): handle = pynvml.nvmlDeviceGetHandleByIndex(0) info = pynvml.nvmlDeviceGetMemoryInfo(handle) info.total = round(info.total / 1024**2) return info.total pynvml.nvmlInit() setup_default_logging() log_gpu_memory() total_gpu_mem = get_gpu_memory_total() if total_gpu_mem > 12300: logging.info('Running on a T4 GPU or other with > 12GB memory, setting batch size to 128') batch_size = 128 else: logging.info('Running on a K80 GPU or other with < 12GB memory, batch size set to 80') batch_size = 80 # Download and extract the dataset (note it's not actually a gz like the file says) if not os.path.exists('./imagenetv2-matched-frequency'): !curl -s https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz | tar x dataset = Dataset('./imagenetv2-matched-frequency/') for i in range(len(dataset)): # warmup dummy = dataset[i] # A basic validation routine with timing and accuracy metrics def validate(model, loader): batch_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() model.eval() #torch.cuda.reset_max_memory_allocated() #torch.cuda.reset_max_memory_cached() gpu_used_baseline = 
log_gpu_memory() gpu_used = 0 start = end = time.time() num_batches = len(loader) log_iter = round(0.25 * num_batches) with torch.no_grad(): for i, (input, target) in enumerate(loader): target = target.cuda() input = input.cuda() output = model(input) prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) top1.update(prec1.item(), input.size(0)) top5.update(prec5.item(), input.size(0)) batch_time.update(time.time() - end) end = time.time() if i and i % log_iter == 0: if gpu_used == 0: gpu_used = log_gpu_memory() logging.info( 'Test: [{0:>4d}/{1}] ' 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 'Rate: {rate_avg:.3f} img/sec ' 'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg, loss=losses, top1=top1, top5=top5)) gpu_used = gpu_used - gpu_used_baseline # These measures are less consistent than method being used wrt # where the batch sizes can be pushed for each model #gpu_used = torch.cuda.max_memory_allocated() #gpu_cached = torch.cuda.max_memory_cached() elapsed = time.time() - start results = OrderedDict( top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3), top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3), rate=len(loader.dataset) / elapsed, gpu_used=gpu_used, ) logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f}) Rate {:.3f}'.format( results['top1'], results['top1_err'], results['top5'], results['top5_err'], results['rate'])) return results # Define the models and arguments that will be used for comparisons # include original ImageNet-1k validation results for comparison against ImageNet-V2 here orig_top1 = dict( efficientnet_b0=76.912, efficientnet_b1=78.692, efficientnet_b2=79.760, tf_efficientnet_b1=78.554, tf_efficientnet_b2=79.606, tf_efficientnet_b3=80.874, tf_efficientnet_b4=82.604, dpn68b=77.514, seresnext26_32x4d=77.104, resnet50=78.486, gluon_seresnext50_32x4d=79.912, 
gluon_seresnext101_32x4d=80.902, ig_resnext101_32x8d=82.688, ) models_effnet = [ dict(model_name='efficientnet_b0'), dict(model_name='efficientnet_b1'), dict(model_name='efficientnet_b2'), ] models_effnet_tf = [ dict(model_name='tf_efficientnet_b2'), # overlapping between TF non-TF for comparison dict(model_name='tf_efficientnet_b3'), dict(model_name='tf_efficientnet_b4'), ] models_resnet = [ dict(model_name='dpn68b'), # b0, yes, not a ResNet, need to find a better b0 comparison #dict(model_name='seresnext26_32x4d'), # b0, not the best b0 comparison either, a little slow dict(model_name='resnet50'), # b1 dict(model_name='gluon_seresnext50_32x4d'), # b2-b3 dict(model_name='gluon_seresnext101_32x4d'), # b3 dict(model_name='ig_resnext101_32x8d'), # b4 ] models_resnet_ttp = [ dict(model_name='resnet50', input_size=(3, 240, 240), ttp=True), dict(model_name='resnet50', input_size=(3, 260, 260), ttp=True), dict(model_name='gluon_seresnext50_32x4d', input_size=(3, 260, 260), ttp=True), dict(model_name='gluon_seresnext50_32x4d', input_size=(3, 300, 300), ttp=True), dict(model_name='gluon_seresnext101_32x4d', input_size=(3, 260, 260), ttp=True), dict(model_name='gluon_seresnext101_32x4d', input_size=(3, 300, 300), ttp=True), dict(model_name='ig_resnext101_32x8d', input_size=(3, 300, 300), ttp=True), ] from timm.models import TestTimePoolHead def model_runner(model_args): model_name = model_args['model_name'] pretrained = True checkpoint_path = '' if 'model_url' in model_args and model_args['model_url']: !wget -q {model_args['model_url']} checkpoint_path = './' + os.path.basename(model_args['model_url']) logging.info('Downloaded checkpoint {} from specified URL'.format(checkpoint_path)) pretrained = False model = timm.create_model( model_name, num_classes=1000, in_chans=3, pretrained=pretrained, checkpoint_path=checkpoint_path) data_config = timm.data.resolve_data_config(model_args, model=model, verbose=True) ttp = False if 'ttp' in model_args and model_args['ttp']: ttp = 
True logging.info('Applying test time pooling to model') model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) model_key = [model_name, str(data_config['input_size'][-1])] if ttp: model_key += ['ttp'] model_key = '-'.join(model_key) param_count = sum([m.numel() for m in model.parameters()]) logging.info('Model {} created, param count: {}. Running...'.format(model_key, param_count)) model = model.cuda() loader = create_loader( dataset, input_size=data_config['input_size'], batch_size=batch_size, use_prefetcher=True, interpolation='bicubic', mean=data_config['mean'], std=data_config['std'], crop_pct=1.0 if ttp else data_config['crop_pct'], num_workers=2) result = validate(model, loader) logging.info('Model {} done.\n'.format(model_key)) result['param_count'] = param_count / 1e6 # add extra non-metric keys for comparisons result['model_name'] = model_name result['input_size'] = data_config['input_size'] result['ttp'] = ttp del model del loader torch.cuda.empty_cache() return model_key, result # Run validation on all the models, get a coffee (or two) results_effnet = {} results_effnet_tf = {} results_resnet = {} results_resnet_ttp = {} logging.info('Running validation on native PyTorch EfficientNet models') for ma in models_effnet: mk, mr = model_runner(ma) results_effnet[mk] = mr logging.info('Running validation on ported Tensorflow EfficientNet models') for ma in models_effnet_tf: mk, mr = model_runner(ma) results_effnet_tf[mk] = mr logging.info('Running validation on ResNe(X)t models') for ma in models_resnet: mk, mr = model_runner(ma) results_resnet[mk] = mr logging.info('Running validation on ResNe(X)t models w/ Test Time Pooling enabled') for ma in models_resnet_ttp: mk, mr = model_runner(ma) results_resnet_ttp[mk] = mr results = {**results_effnet, **results_effnet_tf, **results_resnet, **results_resnet_ttp} # Setup common charting variables import numpy as np import matplotlib import matplotlib.pyplot as plt 
import math  # math.ceil is used for the y-limits below

matplotlib.rcParams['figure.figsize'] = [16, 10]


def annotate(ax, xv, yv, names, xo=0., yo=0., align='left'):
    """Write one text label per data point onto `ax`.

    Args:
        ax: matplotlib axes to draw on.
        xv, yv: x and y coordinates of the points.
        names: label text for each point, in the same order as `xv`/`yv`.
        xo, yo: offsets added to every x / y coordinate of the label.
        align: horizontal alignment passed to `ax.text` as `ha`.
    """
    for x, y, name in zip(xv, yv, names):
        # Draw on the axes that was passed in, not on a module-level global.
        ax.text(x + xo, y + yo, name, fontsize=9, ha=align)


names_all = list(results)
names_effnet = list(results_effnet)
names_effnet_tf = list(results_effnet_tf)
names_resnet = list(results_resnet)
names_resnet_ttp = list(results_resnet_ttp)

acc_all = np.array([results[m]['top1'] for m in names_all])
acc_effnet = np.array([results[m]['top1'] for m in names_effnet])
acc_effnet_tf = np.array([results[m]['top1'] for m in names_effnet_tf])
acc_resnet = np.array([results[m]['top1'] for m in names_resnet])
acc_resnet_ttp = np.array([results[m]['top1'] for m in names_resnet_ttp])

print('Results by top-1 accuracy:')
results_by_top1 = sorted(results, key=lambda x: results[x]['top1'], reverse=True)
for m in results_by_top1:
    print(' Model: {:34}, Top-1 {:4.2f}, Top-5 {:4.2f}, Rate: {:4.2f}'.format(
        m, results[m]['top1'], results[m]['top5'], results[m]['rate']))

# Bar chart of all models, ordered by ascending top-1 accuracy.
sort_ix = np.argsort(acc_all)
acc_sorted = acc_all[sort_ix]
acc_min, acc_max = acc_sorted[0], acc_sorted[-1]
names_sorted = np.array(names_all)[sort_ix]

fig = plt.figure()
ax1 = fig.add_subplot(111)

# Bar positions, split per model family so each family gets its own colour.
ix = np.arange(len(acc_sorted))
ix_effnet = ix[np.isin(names_sorted, names_effnet)]
ix_effnet_tf = ix[np.isin(names_sorted, names_effnet_tf)]
ix_resnet = ix[np.isin(names_sorted, names_resnet)]
ix_resnet_ttp = ix[np.isin(names_sorted, names_resnet_ttp)]

ax1.bar(ix_effnet, acc_sorted[ix_effnet], color='r', label='EfficientNet')
ax1.bar(ix_effnet_tf, acc_sorted[ix_effnet_tf], color='#8C001A', label='TF-EfficientNet')
ax1.bar(ix_resnet, acc_sorted[ix_resnet], color='b', label='ResNet')
ax1.bar(ix_resnet_ttp, acc_sorted[ix_resnet_ttp], color='#43C6DB', label='ResNet + TTP')

# Pad the y-range by 30% of the accuracy spread on both sides.
acc_pad = .3 * (acc_max - acc_min)
plt.ylim([math.ceil(acc_min - acc_pad), math.ceil(acc_max + acc_pad)])
ax1.set_title('Top-1 Comparison')
ax1.set_ylabel('Top-1 Accuracy (%)')
ax1.set_xlabel('Network Architecture')
ax1.set_xticks(ix)
# rotation must be numeric (or 'vertical'/'horizontal'); the string '45' is
# rejected by newer matplotlib releases.
ax1.set_xticklabels(names_sorted, rotation=45, ha='right')
ax1.legend()

plt.show()


def _metric_array(names, key):
    """Return `results[name][key]` for every name in `names` as a NumPy array."""
    return np.array([results[m][key] for m in names])


def _plot_series(ax, xv, yv, names, color, marker, label, **annotate_kwargs):
    """Draw one model family: scatter markers, a connecting line and point labels.

    Extra keyword arguments (`xo`, `yo`, `align`) are forwarded to `annotate`.
    """
    ax.scatter(xv, yv, s=10, c=color, marker=marker, label=label)
    ax.plot(xv, yv, c=color)
    annotate(ax, xv, yv, names, **annotate_kwargs)


# --- Top-1 vs parameter count ---
params_effnet = _metric_array(names_effnet, 'param_count')
params_effnet_tf = _metric_array(names_effnet_tf, 'param_count')
params_resnet = _metric_array(names_resnet, 'param_count')
params_resnet_ttp = _metric_array(names_resnet_ttp, 'param_count')

fig = plt.figure()
ax1 = fig.add_subplot(111)
_plot_series(ax1, params_effnet, acc_effnet, names_effnet,
             'r', 's', 'EfficientNet', xo=-.5, align='right')
_plot_series(ax1, params_effnet_tf, acc_effnet_tf, names_effnet_tf,
             '#8C001A', 'v', 'TF-EfficientNet', xo=.5, align='left')
_plot_series(ax1, params_resnet, acc_resnet, names_resnet,
             'b', 'o', 'ResNet', xo=0.5, align='left')
_plot_series(ax1, params_resnet_ttp, acc_resnet_ttp, names_resnet_ttp,
             '#43C6DB', 'x', 'ResNet TTP', xo=0.3, align='left')

ax1.set_title('Top-1 vs Parameter Count')
ax1.set_ylabel('Top-1 Accuracy (%)')
ax1.set_xlabel('Parameters (Millions)')
ax1.legend()
plt.show()

# --- Top-1 vs throughput ---
print('Results by image rate:')
results_by_rate = sorted(results, key=lambda x: results[x]['rate'], reverse=True)
for m in results_by_rate:
    print(' {:32} Rate: {:>6.2f}, Top-1 {:.2f}, Top-5: {:.2f}'.format(
        m, results[m]['rate'], results[m]['top1'], results[m]['top5']))
print()

rate_effnet = _metric_array(names_effnet, 'rate')
rate_effnet_tf = _metric_array(names_effnet_tf, 'rate')
rate_resnet = _metric_array(names_resnet, 'rate')
rate_resnet_ttp = _metric_array(names_resnet_ttp, 'rate')

fig = plt.figure()
ax1 = fig.add_subplot(111)
_plot_series(ax1, rate_effnet, acc_effnet, names_effnet,
             'r', 's', 'EfficientNet', xo=.5, align='left')
_plot_series(ax1, rate_effnet_tf, acc_effnet_tf, names_effnet_tf,
             '#8C001A', 'v', 'TF-EfficientNet', xo=-.5, yo=-.2, align='right')
_plot_series(ax1, rate_resnet, acc_resnet, names_resnet,
             'b', 'o', 'ResNet', xo=.3, align='left')
# Legend label fixed: was misspelled 'ResNet TPP'.
_plot_series(ax1, rate_resnet_ttp, acc_resnet_ttp, names_resnet_ttp,
             '#43C6DB', 'x', 'ResNet TTP', xo=0., yo=0., align='center')

ax1.set_title('Top-1 vs Rate')
ax1.set_ylabel('Top-1 Accuracy (%)')
ax1.set_xlabel('Rate (Images / sec)')
ax1.legend()
plt.show()

# --- Top-1 vs GPU memory ---
print('Results by GPU memory usage:')
results_by_mem = sorted(results, key=lambda x: results[x]['gpu_used'], reverse=False)
for m in results_by_mem:
    print(' {:32} Mem: {}, Rate: {:>6.2f}, Top-1 {:.2f}, Top-5: {:.2f}'.format(
        m, results[m]['gpu_used'], results[m]['rate'], results[m]['top1'], results[m]['top5']))

mem_effnet = _metric_array(names_effnet, 'gpu_used')
mem_effnet_tf = _metric_array(names_effnet_tf, 'gpu_used')
mem_resnet = _metric_array(names_resnet, 'gpu_used')
mem_resnet_ttp = _metric_array(names_resnet_ttp, 'gpu_used')

fig = plt.figure()
ax1 = fig.add_subplot(111)
_plot_series(ax1, mem_effnet, acc_effnet, names_effnet,
             'r', 's', 'EfficientNet', xo=-.3, align='right')
_plot_series(ax1, mem_effnet_tf, acc_effnet_tf, names_effnet_tf,
             '#8C001A', 'v', 'TF-EfficientNet', xo=-.3, align='right')
_plot_series(ax1, mem_resnet, acc_resnet, names_resnet,
             'b', 'o', 'ResNet', xo=.5, align='left')
# Too busy: the ResNet TTP series is deliberately left off this chart.
#_plot_series(ax1, mem_resnet_ttp, acc_resnet_ttp, names_resnet_ttp,
#             '#43C6DB', 'o', 'ResNet TTP', xo=.5, align='left')

ax1.set_title('Top-1 vs GPU Memory')
ax1.set_ylabel('Top-1 Accuracy (%)')
ax1.set_xlabel('GPU Memory (MB)')
ax1.legend()
plt.show()


def _pct_change(value, base):
    """Return the change of `value` relative to `base` in percent (NaN if base is 0)."""
    if not base:
        return float('nan')
    return 100. * (value - base) / base


def compare_results(results, namea, nameb):
    """Print the relative difference of model `namea` versus baseline `nameb`.

    Args:
        results: mapping of model key to a result dict containing `top1`,
            `top5`, `rate` and `gpu_used`.
        namea: key of the model being compared.
        nameb: key of the baseline model; percentages are relative to it.

    A baseline metric of 0 is reported as NaN instead of raising
    ZeroDivisionError.
    """
    resa, resb = results[namea], results[nameb]
    top1r = _pct_change(resa['top1'], resb['top1'])
    top5r = _pct_change(resa['top5'], resb['top5'])
    rater = _pct_change(resa['rate'], resb['rate'])
    memr = _pct_change(resa['gpu_used'], resb['gpu_used'])
    print('{:22} vs {:28} top1: {:+4.2f}%, top5: {:+4.2f}%, rate: {:+4.2f}%, mem: {:+.2f}%'.format(
        namea, nameb, top1r, top5r, rater, memr))


#compare_results(results, 'efficientnet_b0-224', 'seresnext26_32x4d-224')
compare_results(results, 'efficientnet_b0-224', 'dpn68b-224')
compare_results(results, 'efficientnet_b1-240', 'resnet50-224')
compare_results(results, 'efficientnet_b1-240', 'resnet50-240-ttp')
compare_results(results, 'efficientnet_b2-260', 'gluon_seresnext50_32x4d-224')
compare_results(results, 'tf_efficientnet_b3-300', 'gluon_seresnext50_32x4d-224')
compare_results(results, 'tf_efficientnet_b3-300', 'gluon_seresnext101_32x4d-224')
compare_results(results, 'tf_efficientnet_b4-380', 'ig_resnext101_32x8d-224')

print('\nNote the cost of running with the SAME padding hack:')
compare_results(results, 'tf_efficientnet_b2-260', 'efficientnet_b2-260')

print('Results by absolute accuracy gap between ImageNet-V2 Matched-Frequency and original ImageNet top-1:')
# Only non-TTP runs are compared against the original ImageNet-1k numbers.
no_ttp_keys = [k for k in results if 'ttp' not in k]
gaps = {x: (results[x]['top1'] - orig_top1[results[x]['model_name']]) for x in no_ttp_keys}
sorted_keys = sorted(no_ttp_keys, key=lambda x: gaps[x], reverse=True)
for m in sorted_keys:
    print(' Model: {:34} {:4.2f}%'.format(m, gaps[m]))
print()

print('Results by relative accuracy gap between ImageNet-V2 Matched-Frequency and original ImageNet top-1:')
gaps = {
    x: 100 * (results[x]['top1'] - orig_top1[results[x]['model_name']]) / orig_top1[results[x]['model_name']]
    for x in no_ttp_keys
}
sorted_keys = sorted(no_ttp_keys, key=lambda x: gaps[x], reverse=True)
for m in sorted_keys:
    print(' Model: {:34} {:4.2f}%'.format(m, gaps[m]))