First off, I want to be clear that I am not panning EfficientNets (https://arxiv.org/abs/1905.11946) here. They are unprecedented in their parameter and FLOP efficiency. Thanks to Mingxing Tan, Quoc V. Le, and the Google Brain team for releasing the code and weights.
I dug into the EfficientNet paper the day it was released. I had recently implemented the MobileNet-V3 and MNasNet architectures in PyTorch, and EfficientNets have a lot in common with those models. After defining new model definition strings, adding the depth scaling, and hacking together some weight porting code, they were alive.
First impressions were positive: "Wow, that's some impressive accuracy for so few parameters (and such small checkpoints)." After spending more time with the models, training them, running numerous validations, etc., some realities sank in. These models are less efficient in actual use than I'd expected. I started doing more detailed comparisons with familiar ResNet models and that's how this notebook came to be...
A few points I'm hoping to illustrate in this notebook:
The efficiencies of EfficientNets may not translate to better real-world performance on all frameworks and hardware platforms. Your trusty old ResNets may be just as good for your NN framework of choice running on an NVIDIA GPU. What consumes fewer resources in Tensorflow with an XLA-optimized graph on a TPU may end up being more resource hungry in PyTorch running with a CUDA backend.
The story of a ResNet-50 does not end with a top-1 of 76.3% on ImageNet-1k. Neither do the other ResNe(X)t networks end with the results of the original papers or the pretrained weights of canonical Caffe, Tensorflow, or PyTorch implementations. Many papers compare shiny new architectures trained with recent techniques (or algorithmically searched hyper-parameters) against ResNet baselines that aren't given the same training effort. A ResNet-50 can be trained to well over 78% on ImageNet, better than an 'original' ResNet-152, a model with 35M more parameters. I've selected better pretrained models to compare against the EfficientNets.
Most PyTorch implementations of EfficientNet that I'm aware of are using the Tensorflow ported weights, like my 'tf_efficientnet_b*' models. These ported weights require explicit padding ops to match the behaviour of Tensorflow 'SAME' padding (a minimal sketch of this padding follows below). This padding adds a runtime penalty (about 2% for forward) and a memory penalty (reducing max batch sizes by roughly 15-20%). I've natively trained the B0 through B2 models in PyTorch now, but haven't made progress on B3 and up (very slow to train).
There are some nifty inference tricks, like test time pooling, that can breathe life into old models and allow them to be used outside of their standard resolutions without retraining. A few ResNets were run with TTP here at resolutions similar to the EffNet models as a comparison.
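For reference, here's a minimal sketch of how Tensorflow 'SAME' padding gets emulated in PyTorch. This is illustrative only, not timm's exact implementation: the required padding depends on the input size, so an explicit F.pad has to run before each affected conv, and that extra op is where the runtime and memory overhead mentioned above comes from.

# Minimal sketch of emulating Tensorflow 'SAME' padding in PyTorch with an
# explicit F.pad before the conv (illustrative, not timm's exact code)
import math
import torch
import torch.nn.functional as F

def conv2d_same(x, weight, bias=None, stride=(1, 1), dilation=(1, 1)):
    ih, iw = x.shape[-2:]
    kh, kw = weight.shape[-2:]
    # total padding needed so the output size matches ceil(input / stride)
    pad_h = max((math.ceil(ih / stride[0]) - 1) * stride[0] + (kh - 1) * dilation[0] + 1 - ih, 0)
    pad_w = max((math.ceil(iw / stride[1]) - 1) * stride[1] + (kw - 1) * dilation[1] + 1 - iw, 0)
    if pad_h > 0 or pad_w > 0:
        # TF places the odd extra pixel on the bottom/right
        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
    return F.conv2d(x, weight, bias, stride, padding=0, dilation=dilation)

# e.g. a stride-2 3x3 conv on a 112x112 map: 'SAME' output is 56x56
y = conv2d_same(torch.randn(1, 32, 112, 112), torch.randn(32, 32, 3, 3), stride=(2, 2))
print(y.shape)  # torch.Size([1, 32, 56, 56])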
A few additional considerations:
I'm only running the numbers on validation here to keep the Colab notebook sane. I have trained with all of these architectures, and the relative differences in throughput and memory usage/batch size limits match my training experience as well.
This comparison is for PyTorch 1.0/1.1 with a CUDA backend. Future versions of PyTorch, CUDA, or the PyTorch XLA TPU backend may change things significantly. I'm hoping to compare these models with the PyTorch XLA impl at some point. Not sure if it's ready yet?
The analysis is for the ImageNet classification task. The extra resolution in all EfficientNets above B0 is arguably less beneficial for this task than for, say, fine-grained classification, segmentation, object detection, and other more interesting tasks. Since the input resolution is responsible for a large share of the GPU memory use (see the quick back-of-envelope sketch below), and ResNets for those other tasks are also run at higher resolutions, the comparisons made here do depend heavily on the task.
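To make that resolution point concrete, here's a quick back-of-envelope. The numbers are illustrative, a single hypothetical stem feature map rather than a full model measurement: activation memory grows with the square of the input edge, so a 380x380 input costs roughly 2.9x the activation memory of a 224x224 input before any architecture differences come into play.

# Rough activation-memory scaling with input resolution (illustrative only)
def act_mb(batch, channels, h, w, bytes_per_elem=4):
    return batch * channels * h * w * bytes_per_elem / 1024**2

print(act_mb(128, 32, 112, 112))  # stride-2 stem output at 224x224 input: ~196 MB
print(act_mb(128, 32, 190, 190))  # same stem at 380x380 input: ~564 MB, ~2.9x more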
The timm module used here is a PyPI packaging of my PyTorch Image Models collection (https://github.com/rwightman/pytorch-image-models). Standalone versions of the EfficientNet, MobileNet-V3, MNasNet, etc. models can also be found at https://github.com/rwightman/gen-efficientnet-pytorch.
# Install necessary modules
!pip install timm
Collecting timm Downloading https://files.pythonhosted.org/packages/1e/87/7de9e1175bda1151de177198bb2e99ac78cf0bdf97309b19f6d22b215b79/timm-0.1.6-py3-none-any.whl (83kB) |████████████████████████████████| 92kB 28.0MB/s Requirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (from timm) (0.3.0) Requirement already satisfied: torch>=1.0 in /usr/local/lib/python3.6/dist-packages (from timm) (1.1.0) Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision->timm) (4.3.0) Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchvision->timm) (1.16.4) Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision->timm) (1.12.0) Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from pillow>=4.1.1->torchvision->timm) (0.46) Installing collected packages: timm Successfully installed timm-0.1.6
# For our convenience, take a peek at what we're working with
!nvidia-smi
Mon Jul 1 20:17:45 2019 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 418.67 Driver Version: 410.79 CUDA Version: 10.0 | |-------------------------------+----------------------+----------------------+ | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | |===============================+======================+======================| | 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | | N/A 44C P8 15W / 70W | 0MiB / 15079MiB | 0% Default | +-------------------------------+----------------------+----------------------+ +-----------------------------------------------------------------------------+ | Processes: GPU Memory | | GPU PID Type Process name Usage | |=============================================================================| | No running processes found | +-----------------------------------------------------------------------------+
# Import the core modules, check which GPU we end up with and scale batch size accordingly
import torch
# Flipping this on/off will change the memory dynamics. Since I usually
# validate and train with it on, I'll leave it on by default.
torch.backends.cudnn.benchmark = True

import timm
from timm.data import *
from timm.utils import *

import pynvml

from collections import OrderedDict
import logging
import math  # used for the chart axis limits below
import os    # used for path handling when downloading data/checkpoints
import time
def log_gpu_memory():
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    info.free = round(info.free / 1024**2)
    info.used = round(info.used / 1024**2)
    logging.info('GPU memory free: {}, memory used: {}'.format(info.free, info.used))
    return info.used

def get_gpu_memory_total():
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    info.total = round(info.total / 1024**2)
    return info.total
pynvml.nvmlInit()
setup_default_logging()
log_gpu_memory()
total_gpu_mem = get_gpu_memory_total()
if total_gpu_mem > 12300:
    logging.info('Running on a T4 GPU or other with > 12GB memory, setting batch size to 128')
    batch_size = 128
else:
    logging.info('Running on a K80 GPU or other with < 12GB memory, batch size set to 80')
    batch_size = 80
GPU memory free: 15080, memory used: 0 Running on a T4 GPU or other with > 12GB memory, setting batch size to 128
If you're not aware, ImageNet-V2 (https://github.com/modestyachts/ImageNetV2) is a useful collection of 3 ImageNet-like validation sets that have been collected more recently, 10 years after the original ImageNet.
Aside from being conveniently smaller and easier to deploy in a notebook, it's a useful test set to compare how models might generalize beyond the original ImageNet-1k data. We're going to use the 'Matched Frequency' version of the dataset. There is a markedly lower accuracy rate across the board for this test set. It's very interesting to see how different models fall relative to each other. I've included an analysis of those differences at the bottom.
# Download and extract the dataset (note it's not actually a gz like the file says)
if not os.path.exists('./imagenetv2-matched-frequency'):
    !curl -s https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz | tar x

dataset = Dataset('./imagenetv2-matched-frequency/')
for i in range(len(dataset)):  # warmup
    dummy = dataset[i]
# A basic validation routine with timing and accuracy metrics
def validate(model, loader):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    #torch.cuda.reset_max_memory_allocated()
    #torch.cuda.reset_max_memory_cached()
    gpu_used_baseline = log_gpu_memory()
    gpu_used = 0
    start = end = time.time()
    num_batches = len(loader)
    log_iter = round(0.25 * num_batches)
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            target = target.cuda()
            input = input.cuda()

            output = model(input)

            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if i and i % log_iter == 0:
                if gpu_used == 0:
                    gpu_used = log_gpu_memory()
                logging.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Rate: {rate_avg:.3f} img/sec '
                    'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    gpu_used = gpu_used - gpu_used_baseline
    # The torch allocator peak stats below proved less consistent than the pynvml
    # sample above wrt predicting how far the batch size can be pushed for each model
    #gpu_used = torch.cuda.max_memory_allocated()
    #gpu_cached = torch.cuda.max_memory_cached()

    elapsed = time.time() - start
    results = OrderedDict(
        top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
        top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
        rate=len(loader.dataset) / elapsed, gpu_used=gpu_used,
    )
    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f}) Rate {:.3f}'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err'], results['rate']))
    return results
As per the intro, one of the goals here is to compare EfficientNets with a more capable set of baseline models. I've gone through the various models included in my collection and selected several that I feel are more appropriate matches, based on Top-1 scores that come from much better training setups than the originals.

Here we will split them into 4 lists for analysis and charting: native PyTorch EfficientNets, Tensorflow-ported EfficientNets, ResNe(X)t models, and ResNe(X)t models run with test time pooling.
Note: I realize it's not entirely fair to include the IG ResNext model since it's not technically trained purely on ImageNet like the others. But, it's a truly impressive model, and actually quite a bit easier to work with in PyTorch than even the B4 EfficientNet.
# Define the models and arguments that will be used for comparisons
# include original ImageNet-1k validation results for comparison against ImageNet-V2 here
orig_top1 = dict(
    efficientnet_b0=76.912,
    efficientnet_b1=78.692,
    efficientnet_b2=79.760,
    tf_efficientnet_b1=78.554,
    tf_efficientnet_b2=79.606,
    tf_efficientnet_b3=80.874,
    tf_efficientnet_b4=82.604,
    dpn68b=77.514,
    seresnext26_32x4d=77.104,
    resnet50=78.486,
    gluon_seresnext50_32x4d=79.912,
    gluon_seresnext101_32x4d=80.902,
    ig_resnext101_32x8d=82.688,
)
models_effnet = [
    dict(model_name='efficientnet_b0'),
    dict(model_name='efficientnet_b1'),
    dict(model_name='efficientnet_b2'),
]

models_effnet_tf = [
    dict(model_name='tf_efficientnet_b2'),  # overlapping between TF non-TF for comparison
    dict(model_name='tf_efficientnet_b3'),
    dict(model_name='tf_efficientnet_b4'),
]

models_resnet = [
    dict(model_name='dpn68b'),  # b0, yes, not a ResNet, need to find a better b0 comparison
    #dict(model_name='seresnext26_32x4d'),  # b0, not the best b0 comparison either, a little slow
    dict(model_name='resnet50'),  # b1
    dict(model_name='gluon_seresnext50_32x4d'),  # b2-b3
    dict(model_name='gluon_seresnext101_32x4d'),  # b3
    dict(model_name='ig_resnext101_32x8d'),  # b4
]

models_resnet_ttp = [
    dict(model_name='resnet50', input_size=(3, 240, 240), ttp=True),
    dict(model_name='resnet50', input_size=(3, 260, 260), ttp=True),
    dict(model_name='gluon_seresnext50_32x4d', input_size=(3, 260, 260), ttp=True),
    dict(model_name='gluon_seresnext50_32x4d', input_size=(3, 300, 300), ttp=True),
    dict(model_name='gluon_seresnext101_32x4d', input_size=(3, 260, 260), ttp=True),
    dict(model_name='gluon_seresnext101_32x4d', input_size=(3, 300, 300), ttp=True),
    dict(model_name='ig_resnext101_32x8d', input_size=(3, 300, 300), ttp=True),
]
The runner creates each model, a matching data loader, and runs the validation. It uses several features of my image collection module for this.
Test time pooling is enabled here if requested in the model_args. The pooling is implemented as a module that wraps the base network. It's important to set the crop factor for the images to 1.0 when enabling pooling.
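Roughly, the idea is to run the network fully convolutionally at the larger resolution and average the classifier outputs over spatial positions rather than cropping back down. A simplified sketch of that idea follows; the real TestTimePoolHead in timm differs in its details, and this sketch assumes the wrapped model exposes a forward_features() method returning an unpooled feature map plus a linear fc classifier.

# Simplified sketch of the test time pooling idea (not the actual TestTimePoolHead code)
import torch.nn as nn
import torch.nn.functional as F

class SimpleTestTimePool(nn.Module):
    def __init__(self, base, original_pool=7):
        super().__init__()
        self.base = base                   # assumes forward_features() + a linear .fc classifier
        self.original_pool = original_pool

    def forward(self, x):
        feat = self.base.forward_features(x)                      # N x C x H x W, H/W larger than at train res
        feat = F.avg_pool2d(feat, self.original_pool, stride=1)   # pool windows the classifier was trained on
        logits = F.conv2d(feat, self.base.fc.weight[:, :, None, None], self.base.fc.bias)
        return logits.mean(dim=(2, 3))                            # average class scores over spatial positions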
from timm.models import TestTimePoolHead
def model_runner(model_args):
    model_name = model_args['model_name']
    pretrained = True
    checkpoint_path = ''
    if 'model_url' in model_args and model_args['model_url']:
        !wget -q {model_args['model_url']}
        checkpoint_path = './' + os.path.basename(model_args['model_url'])
        logging.info('Downloaded checkpoint {} from specified URL'.format(checkpoint_path))
        pretrained = False

    model = timm.create_model(
        model_name,
        num_classes=1000,
        in_chans=3,
        pretrained=pretrained,
        checkpoint_path=checkpoint_path)

    data_config = timm.data.resolve_data_config(model_args, model=model, verbose=True)

    ttp = False
    if 'ttp' in model_args and model_args['ttp']:
        ttp = True
        logging.info('Applying test time pooling to model')
        model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])

    model_key = [model_name, str(data_config['input_size'][-1])]
    if ttp:
        model_key += ['ttp']
    model_key = '-'.join(model_key)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model {} created, param count: {}. Running...'.format(model_key, param_count))

    model = model.cuda()

    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=batch_size,
        use_prefetcher=True,
        interpolation='bicubic',
        mean=data_config['mean'],
        std=data_config['std'],
        crop_pct=1.0 if ttp else data_config['crop_pct'],
        num_workers=2)

    result = validate(model, loader)
    logging.info('Model {} done.\n'.format(model_key))

    result['param_count'] = param_count / 1e6
    # add extra non-metric keys for comparisons
    result['model_name'] = model_name
    result['input_size'] = data_config['input_size']
    result['ttp'] = ttp

    del model
    del loader
    torch.cuda.empty_cache()
    return model_key, result
# Run validation on all the models, get a coffee (or two)
results_effnet = {}
results_effnet_tf = {}
results_resnet = {}
results_resnet_ttp = {}

logging.info('Running validation on native PyTorch EfficientNet models')
for ma in models_effnet:
    mk, mr = model_runner(ma)
    results_effnet[mk] = mr

logging.info('Running validation on ported Tensorflow EfficientNet models')
for ma in models_effnet_tf:
    mk, mr = model_runner(ma)
    results_effnet_tf[mk] = mr

logging.info('Running validation on ResNe(X)t models')
for ma in models_resnet:
    mk, mr = model_runner(ma)
    results_resnet[mk] = mr

logging.info('Running validation on ResNe(X)t models w/ Test Time Pooling enabled')
for ma in models_resnet_ttp:
    mk, mr = model_runner(ma)
    results_resnet_ttp[mk] = mr

results = {**results_effnet, **results_effnet_tf, **results_resnet, **results_resnet_ttp}
Running validation on native PyTorch EfficientNet models Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth" to /root/.cache/torch/checkpoints/efficientnet_b0-d6904d92.pth 100%|██████████| 21376958/21376958 [00:02<00:00, 8676444.76it/s] Data processing configuration for current model + dataset: input_size: (3, 224, 224) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.875 Model efficientnet_b0-224 created, param count: 5288548. Running... GPU memory free: 14276, memory used: 804 GPU memory free: 11346, memory used: 3734 Test: [ 20/79] Time: 0.190 (0.805) Rate: 159.098 img/sec Prec@1: 64.8438 (69.6801) Prec@5: 87.5000 (88.9509) Test: [ 40/79] Time: 0.194 (0.800) Rate: 159.972 img/sec Prec@1: 51.5625 (68.8072) Prec@5: 79.6875 (88.5671) Test: [ 60/79] Time: 0.186 (0.790) Rate: 162.028 img/sec Prec@1: 60.9375 (66.1501) Prec@5: 83.5938 (86.6035) * Prec@1 64.580 (35.420) Prec@5 85.890 (14.110) Rate 165.732 Model efficientnet_b0-224 done. Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth" to /root/.cache/torch/checkpoints/efficientnet_b1-533bc792.pth 100%|██████████| 31502706/31502706 [00:03<00:00, 9936470.52it/s] Data processing configuration for current model + dataset: input_size: (3, 240, 240) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.882 Model efficientnet_b1-240 created, param count: 7794184. Running... GPU memory free: 14260, memory used: 820 GPU memory free: 10890, memory used: 4190 Test: [ 20/79] Time: 0.311 (0.919) Rate: 139.286 img/sec Prec@1: 69.5312 (73.9583) Prec@5: 86.7188 (90.7366) Test: [ 40/79] Time: 0.310 (0.878) Rate: 145.851 img/sec Prec@1: 58.5938 (72.1799) Prec@5: 81.2500 (89.9200) Test: [ 60/79] Time: 0.312 (0.867) Rate: 147.679 img/sec Prec@1: 67.1875 (69.0958) Prec@5: 81.2500 (87.9867) * Prec@1 67.550 (32.450) Prec@5 87.290 (12.710) Rate 151.628 Model efficientnet_b1-240 done. Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth" to /root/.cache/torch/checkpoints/efficientnet_b2-cf78dc4d.pth 100%|██████████| 36788101/36788101 [00:03<00:00, 11752398.17it/s] Data processing configuration for current model + dataset: input_size: (3, 260, 260) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.89 Model efficientnet_b2-260 created, param count: 9109994. Running... GPU memory free: 14258, memory used: 822 GPU memory free: 10266, memory used: 4814 Test: [ 20/79] Time: 0.416 (0.941) Rate: 136.036 img/sec Prec@1: 68.7500 (72.9539) Prec@5: 88.2812 (91.0714) Test: [ 40/79] Time: 0.429 (0.914) Rate: 140.068 img/sec Prec@1: 58.5938 (71.9893) Prec@5: 82.0312 (90.4535) Test: [ 60/79] Time: 0.527 (0.894) Rate: 143.120 img/sec Prec@1: 64.0625 (69.3904) Prec@5: 85.9375 (88.8960) * Prec@1 67.800 (32.200) Prec@5 88.200 (11.800) Rate 144.201 Model efficientnet_b2-260 done. 
Running validation on ported Tensorflow EfficientNet models Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth" to /root/.cache/torch/checkpoints/tf_efficientnet_b2-e393ef04.pth 100%|██████████| 36797929/36797929 [00:03<00:00, 11014399.83it/s] Data processing configuration for current model + dataset: input_size: (3, 260, 260) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.89 Model tf_efficientnet_b2-260 created, param count: 9109994. Running... GPU memory free: 14258, memory used: 822 GPU memory free: 9568, memory used: 5512 Test: [ 20/79] Time: 1.217 (0.960) Rate: 133.306 img/sec Prec@1: 66.4062 (72.7679) Prec@5: 87.5000 (90.4018) Test: [ 40/79] Time: 0.522 (0.917) Rate: 139.645 img/sec Prec@1: 58.5938 (71.3986) Prec@5: 79.6875 (89.7675) Test: [ 60/79] Time: 0.939 (0.908) Rate: 141.046 img/sec Prec@1: 64.8438 (68.9037) Prec@5: 85.1562 (88.2172) * Prec@1 67.400 (32.600) Prec@5 87.580 (12.420) Rate 142.727 Model tf_efficientnet_b2-260 done. Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth" to /root/.cache/torch/checkpoints/tf_efficientnet_b3-e3bd6955.pth 100%|██████████| 49381362/49381362 [00:03<00:00, 12584590.15it/s] Data processing configuration for current model + dataset: input_size: (3, 300, 300) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.904 Model tf_efficientnet_b3-300 created, param count: 12233232. Running... GPU memory free: 14242, memory used: 838 GPU memory free: 5604, memory used: 9476 Test: [ 20/79] Time: 1.267 (1.161) Rate: 110.269 img/sec Prec@1: 66.4062 (73.8467) Prec@5: 90.6250 (91.6667) Test: [ 40/79] Time: 0.833 (1.097) Rate: 116.649 img/sec Prec@1: 60.9375 (72.8087) Prec@5: 83.5938 (90.7393) Test: [ 60/79] Time: 1.242 (1.082) Rate: 118.310 img/sec Prec@1: 67.1875 (70.1588) Prec@5: 84.3750 (89.1522) * Prec@1 68.520 (31.480) Prec@5 88.700 (11.300) Rate 119.134 Model tf_efficientnet_b3-300 done. Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth" to /root/.cache/torch/checkpoints/tf_efficientnet_b4-74ee3bed.pth 100%|██████████| 77989689/77989689 [00:06<00:00, 12751872.12it/s] Data processing configuration for current model + dataset: input_size: (3, 380, 380) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.922 Model tf_efficientnet_b4-380 created, param count: 19341616. Running... GPU memory free: 14214, memory used: 866 GPU memory free: 2460, memory used: 12620 Test: [ 20/79] Time: 1.761 (2.057) Rate: 62.222 img/sec Prec@1: 69.5312 (76.4509) Prec@5: 91.4062 (92.6339) Test: [ 40/79] Time: 1.740 (1.914) Rate: 66.889 img/sec Prec@1: 64.8438 (75.4954) Prec@5: 83.5938 (92.2637) Test: [ 60/79] Time: 1.782 (1.866) Rate: 68.600 img/sec Prec@1: 71.0938 (72.8740) Prec@5: 85.1562 (90.6634) * Prec@1 71.340 (28.660) Prec@5 90.110 (9.890) Rate 69.103 Model tf_efficientnet_b4-380 done. 
Running validation on ResNe(X)t models Downloading: "https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth" to /root/.cache/torch/checkpoints/dpn68b_extra-84854c156.pth 100%|██████████| 50765517/50765517 [00:04<00:00, 12271223.44it/s] Data processing configuration for current model + dataset: input_size: (3, 224, 224) interpolation: bicubic mean: (0.48627450980392156, 0.4588235294117647, 0.40784313725490196) std: (0.23482446870963955, 0.23482446870963955, 0.23482446870963955) crop_pct: 0.875 Model dpn68b-224 created, param count: 12611602. Running... GPU memory free: 14240, memory used: 840 GPU memory free: 11342, memory used: 3738 Test: [ 20/79] Time: 0.442 (0.876) Rate: 146.176 img/sec Prec@1: 54.6875 (70.2381) Prec@5: 85.9375 (88.9509) Test: [ 40/79] Time: 1.007 (0.847) Rate: 151.177 img/sec Prec@1: 57.8125 (69.5122) Prec@5: 78.9062 (88.4337) Test: [ 60/79] Time: 1.015 (0.834) Rate: 153.556 img/sec Prec@1: 60.1562 (66.8033) Prec@5: 78.9062 (86.5907) * Prec@1 65.600 (34.400) Prec@5 85.940 (14.060) Rate 155.150 Model dpn68b-224 done. Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth" to /root/.cache/torch/checkpoints/rw_resnet50-86acaeed.pth 100%|██████████| 102488165/102488165 [00:07<00:00, 13755311.81it/s] Data processing configuration for current model + dataset: input_size: (3, 224, 224) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.875 Model resnet50-224 created, param count: 25557032. Running... GPU memory free: 14182, memory used: 898 GPU memory free: 12652, memory used: 2428 Test: [ 20/79] Time: 0.406 (0.859) Rate: 149.042 img/sec Prec@1: 66.4062 (72.6562) Prec@5: 90.6250 (90.4762) Test: [ 40/79] Time: 0.662 (0.820) Rate: 156.156 img/sec Prec@1: 58.5938 (71.1128) Prec@5: 85.9375 (89.5960) Test: [ 60/79] Time: 0.601 (0.807) Rate: 158.594 img/sec Prec@1: 61.7188 (68.3017) Prec@5: 82.0312 (87.7946) * Prec@1 66.810 (33.190) Prec@5 87.000 (13.000) Rate 159.510 Model resnet50-224 done. Downloading: "https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth" to /root/.cache/torch/checkpoints/gluon_seresnext50_32x4d-90cf2d6e.pth 100%|██████████| 110578827/110578827 [00:08<00:00, 12788555.61it/s] Data processing configuration for current model + dataset: input_size: (3, 224, 224) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.875 Model gluon_seresnext50_32x4d-224 created, param count: 27559896. Running... GPU memory free: 14180, memory used: 900 GPU memory free: 12510, memory used: 2570 Test: [ 20/79] Time: 1.013 (0.875) Rate: 146.238 img/sec Prec@1: 70.3125 (74.2188) Prec@5: 88.2812 (91.0714) Test: [ 40/79] Time: 1.197 (0.859) Rate: 149.059 img/sec Prec@1: 60.9375 (72.8849) Prec@5: 82.8125 (90.4345) Test: [ 60/79] Time: 1.185 (0.859) Rate: 148.930 img/sec Prec@1: 64.8438 (70.0307) Prec@5: 84.3750 (88.8064) * Prec@1 68.670 (31.330) Prec@5 88.320 (11.680) Rate 150.435 Model gluon_seresnext50_32x4d-224 done. 
Downloading: "https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth" to /root/.cache/torch/checkpoints/gluon_seresnext101_32x4d-cf52900d.pth 100%|██████████| 196505510/196505510 [00:12<00:00, 16164511.02it/s] Data processing configuration for current model + dataset: input_size: (3, 224, 224) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.875 Model gluon_seresnext101_32x4d-224 created, param count: 48955416. Running... GPU memory free: 14086, memory used: 994 GPU memory free: 12272, memory used: 2808 Test: [ 20/79] Time: 0.897 (1.016) Rate: 125.932 img/sec Prec@1: 72.6562 (75.5580) Prec@5: 88.2812 (91.6667) Test: [ 40/79] Time: 0.899 (0.997) Rate: 128.324 img/sec Prec@1: 64.8438 (74.4284) Prec@5: 83.5938 (91.2538) Test: [ 60/79] Time: 0.867 (0.986) Rate: 129.853 img/sec Prec@1: 67.1875 (71.7597) Prec@5: 89.0625 (89.6644) * Prec@1 70.010 (29.990) Prec@5 88.910 (11.090) Rate 131.572 Model gluon_seresnext101_32x4d-224 done. Downloading: "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth" to /root/.cache/torch/checkpoints/ig_resnext101_32x8-c38310e5.pth 100%|██████████| 356056638/356056638 [00:11<00:00, 31320647.42it/s] Data processing configuration for current model + dataset: input_size: (3, 224, 224) interpolation: bilinear mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.875 Model ig_resnext101_32x8d-224 created, param count: 88791336. Running... GPU memory free: 13946, memory used: 1134 GPU memory free: 10564, memory used: 4516 Test: [ 20/79] Time: 1.560 (1.664) Rate: 76.934 img/sec Prec@1: 76.5625 (78.9807) Prec@5: 93.7500 (94.2708) Test: [ 40/79] Time: 1.450 (1.582) Rate: 80.907 img/sec Prec@1: 66.4062 (77.9535) Prec@5: 88.2812 (93.7881) Test: [ 60/79] Time: 1.470 (1.540) Rate: 83.129 img/sec Prec@1: 74.2188 (75.0256) Prec@5: 91.4062 (92.6358) * Prec@1 73.830 (26.170) Prec@5 92.280 (7.720) Rate 83.352 Model ig_resnext101_32x8d-224 done. Running validation on ResNe(X)t models w/ Test Time Pooling enabled Data processing configuration for current model + dataset: input_size: (3, 240, 240) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.875 Applying test time pooling to model Model resnet50-240-ttp created, param count: 25557032. Running... GPU memory free: 14182, memory used: 898 GPU memory free: 12098, memory used: 2982 Test: [ 20/79] Time: 0.429 (0.892) Rate: 143.505 img/sec Prec@1: 67.1875 (72.7679) Prec@5: 89.0625 (90.3274) Test: [ 40/79] Time: 0.757 (0.845) Rate: 151.416 img/sec Prec@1: 55.4688 (71.1128) Prec@5: 84.3750 (89.5198) Test: [ 60/79] Time: 1.154 (0.831) Rate: 154.108 img/sec Prec@1: 61.7188 (68.4170) Prec@5: 82.8125 (87.6537) * Prec@1 67.020 (32.980) Prec@5 87.040 (12.960) Rate 154.346 Model resnet50-240-ttp done. Data processing configuration for current model + dataset: input_size: (3, 260, 260) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.875 Applying test time pooling to model Model resnet50-260-ttp created, param count: 25557032. Running... 
GPU memory free: 14182, memory used: 898 GPU memory free: 11650, memory used: 3430 Test: [ 20/79] Time: 1.172 (1.097) Rate: 116.650 img/sec Prec@1: 68.7500 (72.9911) Prec@5: 87.5000 (90.5134) Test: [ 40/79] Time: 0.902 (0.976) Rate: 131.211 img/sec Prec@1: 57.8125 (72.0084) Prec@5: 82.8125 (89.9581) Test: [ 60/79] Time: 0.832 (0.940) Rate: 136.223 img/sec Prec@1: 60.1562 (69.2751) Prec@5: 85.9375 (88.2684) * Prec@1 67.630 (32.370) Prec@5 87.630 (12.370) Rate 135.915 Model resnet50-260-ttp done. Data processing configuration for current model + dataset: input_size: (3, 260, 260) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.875 Applying test time pooling to model Model gluon_seresnext50_32x4d-260-ttp created, param count: 27559896. Running... GPU memory free: 14180, memory used: 900 GPU memory free: 11594, memory used: 3486 Test: [ 20/79] Time: 1.229 (1.147) Rate: 111.577 img/sec Prec@1: 71.8750 (74.4420) Prec@5: 86.7188 (91.2946) Test: [ 40/79] Time: 1.056 (1.053) Rate: 121.593 img/sec Prec@1: 62.5000 (73.8567) Prec@5: 85.1562 (90.6822) Test: [ 60/79] Time: 1.133 (1.015) Rate: 126.067 img/sec Prec@1: 68.7500 (71.1194) Prec@5: 86.7188 (89.0625) * Prec@1 69.670 (30.330) Prec@5 88.620 (11.380) Rate 126.519 Model gluon_seresnext50_32x4d-260-ttp done. Data processing configuration for current model + dataset: input_size: (3, 300, 300) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.875 Applying test time pooling to model Model gluon_seresnext50_32x4d-300-ttp created, param count: 27559896. Running... GPU memory free: 14180, memory used: 900 GPU memory free: 10880, memory used: 4200 Test: [ 20/79] Time: 1.041 (1.484) Rate: 86.250 img/sec Prec@1: 71.8750 (76.3021) Prec@5: 89.0625 (91.9271) Test: [ 40/79] Time: 1.037 (1.287) Rate: 99.457 img/sec Prec@1: 64.0625 (75.0572) Prec@5: 86.7188 (91.3300) Test: [ 60/79] Time: 1.064 (1.216) Rate: 105.295 img/sec Prec@1: 71.0938 (72.1952) Prec@5: 88.2812 (89.7285) * Prec@1 70.470 (29.530) Prec@5 89.180 (10.820) Rate 104.694 Model gluon_seresnext50_32x4d-300-ttp done. Data processing configuration for current model + dataset: input_size: (3, 260, 260) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.875 Applying test time pooling to model Model gluon_seresnext101_32x4d-260-ttp created, param count: 48955416. Running... GPU memory free: 14086, memory used: 994 GPU memory free: 11634, memory used: 3446 Test: [ 20/79] Time: 1.307 (1.413) Rate: 90.559 img/sec Prec@1: 71.8750 (76.3393) Prec@5: 89.0625 (92.0387) Test: [ 40/79] Time: 1.307 (1.362) Rate: 93.981 img/sec Prec@1: 61.7188 (75.6479) Prec@5: 82.0312 (91.8826) Test: [ 60/79] Time: 1.303 (1.343) Rate: 95.329 img/sec Prec@1: 74.2188 (72.8868) Prec@5: 87.5000 (90.1895) * Prec@1 71.140 (28.860) Prec@5 89.470 (10.530) Rate 95.842 Model gluon_seresnext101_32x4d-260-ttp done. Data processing configuration for current model + dataset: input_size: (3, 300, 300) interpolation: bicubic mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.875 Applying test time pooling to model Model gluon_seresnext101_32x4d-300-ttp created, param count: 48955416. Running... 
GPU memory free: 14086, memory used: 994 GPU memory free: 10834, memory used: 4246 Test: [ 20/79] Time: 1.691 (1.786) Rate: 71.683 img/sec Prec@1: 71.8750 (77.5298) Prec@5: 91.4062 (93.1176) Test: [ 40/79] Time: 1.669 (1.732) Rate: 73.888 img/sec Prec@1: 63.2812 (76.2767) Prec@5: 85.1562 (92.5877) Test: [ 60/79] Time: 1.693 (1.715) Rate: 74.635 img/sec Prec@1: 75.0000 (73.7193) Prec@5: 92.1875 (90.9964) * Prec@1 71.990 (28.010) Prec@5 90.100 (9.900) Rate 74.874 Model gluon_seresnext101_32x4d-300-ttp done. Data processing configuration for current model + dataset: input_size: (3, 300, 300) interpolation: bilinear mean: (0.485, 0.456, 0.406) std: (0.229, 0.224, 0.225) crop_pct: 0.875 Applying test time pooling to model Model ig_resnext101_32x8d-300-ttp created, param count: 88791336. Running... GPU memory free: 13946, memory used: 1134 GPU memory free: 9288, memory used: 5792 Test: [ 20/79] Time: 2.850 (3.122) Rate: 41.006 img/sec Prec@1: 75.0000 (79.3155) Prec@5: 93.7500 (94.8661) Test: [ 40/79] Time: 2.855 (2.989) Rate: 42.826 img/sec Prec@1: 64.8438 (78.6966) Prec@5: 87.5000 (94.3979) Test: [ 60/79] Time: 2.856 (2.945) Rate: 43.463 img/sec Prec@1: 74.2188 (76.2295) Prec@5: 89.0625 (93.0456) * Prec@1 75.170 (24.830) Prec@5 92.660 (7.340) Rate 43.622 Model ig_resnext101_32x8d-300-ttp done.
We're going to walk through the results and look at several things: the Top-1 accuracy ranking on ImageNet-V2, accuracy vs parameter count, accuracy vs image rate, and accuracy vs GPU memory use.
# Setup common charting variables
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = [16, 10]
def annotate(ax, xv, yv, names, xo=0., yo=0., align='left'):
    for i, (x, y) in enumerate(zip(xv, yv)):
        ax.text(x + xo, y + yo, names[i], fontsize=9, ha=align)
names_all = list(results.keys())
names_effnet = list(results_effnet.keys())
names_effnet_tf = list(results_effnet_tf.keys())
names_resnet = list(results_resnet.keys())
names_resnet_ttp = list(results_resnet_ttp.keys())
acc_all = np.array([results[m]['top1'] for m in names_all])
acc_effnet = np.array([results[m]['top1'] for m in names_effnet])
acc_effnet_tf = np.array([results[m]['top1'] for m in names_effnet_tf])
acc_resnet = np.array([results[m]['top1'] for m in names_resnet])
acc_resnet_ttp = np.array([results[m]['top1'] for m in names_resnet_ttp])
We'll start by ranking the models by Top-1 accuracy on the ImageNet-V2 validation set.
You'll notice that well trained ResNe(X)t models hold their own here, sitting alongside or above EfficientNets with similar original ImageNet-1k scores.

The ResNeXt101-32x8d pretrained on Facebook's Instagram hashtag data is in a class of its own, somewhat unfairly, since it was pretrained on a much larger dataset. However, since it generalizes to this dataset better than any model I've seen (see bottom) and runs faster with less memory overhead than the EfficientNet-B4 (despite its 88M parameters), I've included it.
print('Results by top-1 accuracy:')
results_by_top1 = list(sorted(results.keys(), key=lambda x: results[x]['top1'], reverse=True))
for m in results_by_top1:
    print(' Model: {:34}, Top-1 {:4.2f}, Top-5 {:4.2f}, Rate: {:4.2f}'.format(
        m, results[m]['top1'], results[m]['top5'], results[m]['rate']))
Results by top-1 accuracy: Model: ig_resnext101_32x8d-300-ttp , Top-1 75.17, Top-5 92.66, Rate: 43.62 Model: ig_resnext101_32x8d-224 , Top-1 73.83, Top-5 92.28, Rate: 83.35 Model: gluon_seresnext101_32x4d-300-ttp , Top-1 71.99, Top-5 90.10, Rate: 74.87 Model: tf_efficientnet_b4-380 , Top-1 71.34, Top-5 90.11, Rate: 69.10 Model: gluon_seresnext101_32x4d-260-ttp , Top-1 71.14, Top-5 89.47, Rate: 95.84 Model: gluon_seresnext50_32x4d-300-ttp , Top-1 70.47, Top-5 89.18, Rate: 104.69 Model: gluon_seresnext101_32x4d-224 , Top-1 70.01, Top-5 88.91, Rate: 131.57 Model: gluon_seresnext50_32x4d-260-ttp , Top-1 69.67, Top-5 88.62, Rate: 126.52 Model: gluon_seresnext50_32x4d-224 , Top-1 68.67, Top-5 88.32, Rate: 150.43 Model: tf_efficientnet_b3-300 , Top-1 68.52, Top-5 88.70, Rate: 119.13 Model: efficientnet_b2-260 , Top-1 67.80, Top-5 88.20, Rate: 144.20 Model: resnet50-260-ttp , Top-1 67.63, Top-5 87.63, Rate: 135.92 Model: efficientnet_b1-240 , Top-1 67.55, Top-5 87.29, Rate: 151.63 Model: tf_efficientnet_b2-260 , Top-1 67.40, Top-5 87.58, Rate: 142.73 Model: resnet50-240-ttp , Top-1 67.02, Top-5 87.04, Rate: 154.35 Model: resnet50-224 , Top-1 66.81, Top-5 87.00, Rate: 159.51 Model: dpn68b-224 , Top-1 65.60, Top-5 85.94, Rate: 155.15 Model: efficientnet_b0-224 , Top-1 64.58, Top-5 85.89, Rate: 165.73
sort_ix = np.argsort(acc_all)
acc_sorted = acc_all[sort_ix]
acc_min, acc_max = acc_sorted[0], acc_sorted[-1]
names_sorted = np.array(names_all)[sort_ix]
fig = plt.figure()
ax1 = fig.add_subplot(111)
ix = np.arange(len(acc_sorted))
ix_effnet = ix[np.in1d(names_sorted[ix], names_effnet)]
ix_effnet_tf = ix[np.in1d(names_sorted[ix], names_effnet_tf)]
ix_resnet = ix[np.in1d(names_sorted[ix], names_resnet)]
ix_resnet_ttp = ix[np.in1d(names_sorted[ix], names_resnet_ttp)]
ax1.bar(ix_effnet, acc_sorted[ix_effnet], color='r', label='EfficientNet')
ax1.bar(ix_effnet_tf, acc_sorted[ix_effnet_tf], color='#8C001A', label='TF-EfficientNet')
ax1.bar(ix_resnet, acc_sorted[ix_resnet], color='b', label='ResNet')
ax1.bar(ix_resnet_ttp, acc_sorted[ix_resnet_ttp], color='#43C6DB', label='ResNet + TTP')
plt.ylim([math.ceil(acc_min - .3*(acc_max - acc_min)),
math.ceil(acc_max + .3*(acc_max - acc_min))])
ax1.set_title('Top-1 Comparison')
ax1.set_ylabel('Top-1 Accuracy (%)')
ax1.set_xlabel('Network Architecture')
ax1.set_xticks(ix)
ax1.set_xticklabels(names_sorted, rotation='45', ha='right')
ax1.legend()
plt.show()
No surprises here: exactly as per the EfficientNet paper, the EfficientNets are in a class of their own in terms of parameter efficiency.

The test time pooling effectively increases the parameter efficiency of the ResNet models, but at the cost of both throughput and memory efficiency (see later graphs).

I'm not going to repeat the FLOP differences as there are again no surprises; the picture matches the paper, barring differences in the models being compared against. If you are looking at FLOP counts for the EfficientNet models, do keep in mind that their counts appear to be for inference-optimized models with the BatchNorm layers fused. The counts will be higher if you're working with trainable models that still have their BN layers. You can see some counts I did on ONNX-optimized models here (https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/BENCHMARK.md)
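For intuition on why fused and unfused counts differ, here's the standard conv + BatchNorm folding identity as a small sketch (generic code, nothing timm-specific): at inference the BN scale and shift can be folded into the preceding conv's weight and bias, so the BN parameters and its per-element work disappear from the counts.

# Fold a BatchNorm into the preceding conv for inference (standard identity)
import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, conv.dilation, conv.groups, bias=True)
    b = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)   # gamma / sqrt(var + eps)
    fused.weight.data.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    fused.bias.data.copy_((b - bn.running_mean) * scale + bn.bias)
    return fused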
params_effnet = np.array([results[m]['param_count'] for m in names_effnet])
params_effnet_tf = np.array([results[m]['param_count'] for m in names_effnet_tf])
params_resnet = np.array([results[m]['param_count'] for m in names_resnet])
params_resnet_ttp = np.array([results[m]['param_count'] for m in names_resnet_ttp])
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(params_effnet, acc_effnet, s=10, c='r', marker="s", label='EfficientNet')
ax1.plot(params_effnet, acc_effnet, c='r')
annotate(ax1, params_effnet, acc_effnet, names_effnet, xo=-.5, align='right')
ax1.scatter(params_effnet_tf, acc_effnet_tf, s=10, c='#8C001A', marker="v", label='TF-EfficientNet')
ax1.plot(params_effnet_tf, acc_effnet_tf, c='#8C001A')
annotate(ax1, params_effnet_tf, acc_effnet_tf, names_effnet_tf, xo=.5, align='left')
ax1.scatter(params_resnet, acc_resnet, s=10, c='b', marker="o", label='ResNet')
ax1.plot(params_resnet, acc_resnet, c='b')
annotate(ax1, params_resnet, acc_resnet, names_resnet, xo=0.5, align='left')
ax1.scatter(params_resnet_ttp, acc_resnet_ttp, s=10, c='#43C6DB', marker="x", label='ResNet TTP')
ax1.plot(params_resnet_ttp, acc_resnet_ttp, c='#43C6DB')
annotate(ax1, params_resnet_ttp, acc_resnet_ttp, names_resnet_ttp, xo=0.3, align='left')
ax1.set_title('Top-1 vs Parameter Count')
ax1.set_ylabel('Top-1 Accuracy (%)')
ax1.set_xlabel('Parameters (Millions)')
ax1.legend()
plt.show()
One of the first things I noticed running batches through my first ported EfficientNet weights: the image throughput does not scale with FLOP or parameter counts. Much larger ResNet, DPN, etc. models can match the throughput of EfficientNet models that have far fewer parameters and FLOPS. I've trained many of these models, and the relative training throughputs mirror the validation numbers here.

This was surprising to me given the FLOP ratios. I'd like to see an in-depth comparison with Tensorflow, XLA enabled, targeted at both GPU and TPU.
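If you want to reproduce the relative throughput numbers on your own GPU without running the full validation, a rough forward-only timing loop like the following is enough. It's a sketch using synthetic inputs and pretrained=False, so it measures compute only, not accuracy or data loading.

# Rough forward-pass throughput check (synthetic data, no accuracy involved)
def images_per_sec(model_name, input_size=224, batch=64, iters=20):
    model = timm.create_model(model_name, pretrained=False).cuda().eval()
    x = torch.randn(batch, 3, input_size, input_size).cuda()
    with torch.no_grad():
        for _ in range(3):                # warmup, lets cudnn.benchmark autotune
            model(x)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            model(x)
        torch.cuda.synchronize()
    return batch * iters / (time.time() - start)

print(images_per_sec('efficientnet_b0'))
print(images_per_sec('resnet50'))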
print('Results by image rate:')
results_by_rate = list(sorted(results.keys(), key=lambda x: results[x]['rate'], reverse=True))
for m in results_by_rate:
    print(' {:32} Rate: {:>6.2f}, Top-1 {:.2f}, Top-5: {:.2f}'.format(
        m, results[m]['rate'], results[m]['top1'], results[m]['top5']))
print()
Results by image rate: efficientnet_b0-224 Rate: 165.73, Top-1 64.58, Top-5: 85.89 resnet50-224 Rate: 159.51, Top-1 66.81, Top-5: 87.00 dpn68b-224 Rate: 155.15, Top-1 65.60, Top-5: 85.94 resnet50-240-ttp Rate: 154.35, Top-1 67.02, Top-5: 87.04 efficientnet_b1-240 Rate: 151.63, Top-1 67.55, Top-5: 87.29 gluon_seresnext50_32x4d-224 Rate: 150.43, Top-1 68.67, Top-5: 88.32 efficientnet_b2-260 Rate: 144.20, Top-1 67.80, Top-5: 88.20 tf_efficientnet_b2-260 Rate: 142.73, Top-1 67.40, Top-5: 87.58 resnet50-260-ttp Rate: 135.92, Top-1 67.63, Top-5: 87.63 gluon_seresnext101_32x4d-224 Rate: 131.57, Top-1 70.01, Top-5: 88.91 gluon_seresnext50_32x4d-260-ttp Rate: 126.52, Top-1 69.67, Top-5: 88.62 tf_efficientnet_b3-300 Rate: 119.13, Top-1 68.52, Top-5: 88.70 gluon_seresnext50_32x4d-300-ttp Rate: 104.69, Top-1 70.47, Top-5: 89.18 gluon_seresnext101_32x4d-260-ttp Rate: 95.84, Top-1 71.14, Top-5: 89.47 ig_resnext101_32x8d-224 Rate: 83.35, Top-1 73.83, Top-5: 92.28 gluon_seresnext101_32x4d-300-ttp Rate: 74.87, Top-1 71.99, Top-5: 90.10 tf_efficientnet_b4-380 Rate: 69.10, Top-1 71.34, Top-5: 90.11 ig_resnext101_32x8d-300-ttp Rate: 43.62, Top-1 75.17, Top-5: 92.66
rate_effnet = np.array([results[m]['rate'] for m in names_effnet])
rate_effnet_tf = np.array([results[m]['rate'] for m in names_effnet_tf])
rate_resnet = np.array([results[m]['rate'] for m in names_resnet])
rate_resnet_ttp = np.array([results[m]['rate'] for m in names_resnet_ttp])
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(rate_effnet, acc_effnet, s=10, c='r', marker="s", label='EfficientNet')
ax1.plot(rate_effnet, acc_effnet, c='r')
annotate(ax1, rate_effnet, acc_effnet, names_effnet, xo=.5, align='left')
ax1.scatter(rate_effnet_tf, acc_effnet_tf, s=10, c='#8C001A', marker="v", label='TF-EfficientNet')
ax1.plot(rate_effnet_tf, acc_effnet_tf, c='#8C001A')
annotate(ax1, rate_effnet_tf, acc_effnet_tf, names_effnet_tf, xo=-.5, yo=-.2, align='right')
ax1.scatter(rate_resnet, acc_resnet, s=10, c='b', marker="o", label='ResNet')
ax1.plot(rate_resnet, acc_resnet, c='b')
annotate(ax1, rate_resnet, acc_resnet, names_resnet, xo=.3, align='left')
ax1.scatter(rate_resnet_ttp, acc_resnet_ttp, s=10, c='#43C6DB', marker="x", label='ResNet TPP')
ax1.plot(rate_resnet_ttp, acc_resnet_ttp, c='#43C6DB')
annotate(ax1, rate_resnet_ttp, acc_resnet_ttp, names_resnet_ttp, xo=0., yo=0., align='center')
ax1.set_title('Top-1 vs Rate')
ax1.set_ylabel('Top-1 Accuracy (%)')
ax1.set_xlabel('Rate (Images / sec)')
ax1.legend()
plt.show()
Measuring the 'practical' GPU memory consumption is a bit of a challenge. By 'practical', what I want to capture is the relative GPU memory usage that indicates what the likely maximum batch sizes will be. With cudnn.benchmark = True set, the torch memory allocator metrics didn't prove reliable. In the end, using pynvml (same output as nvidia-smi) and taking a sample part way through the validation set proved the most consistent.

I've verified the sampling by pushing batch sizes for several of the models, on a T4 Colab instance, to the point where they fail with an OOM exception; a rough probe along those lines is sketched below. The relative memory measures match the relative maximum batch sizes, so I can roughly predict where the largest batch size will land from this measure.

Overall, the EfficientNets are not particularly memory efficient. The monster ResNeXt101-32x8d, with 88M params, is more memory efficient at 224x224 than the 9.1M param EfficientNet-B2 at 260x260. This is especially true for the 'tf' variants with the 'SAME' padding hack enabled; there is up to a 20% memory penalty for the hack, which does impact the max usable batch size.
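For completeness, the kind of probe I mean is roughly the following. It's an illustrative sketch only: the model names and sizes are just examples, and the doubling search is crude.

# Crude max-batch-size probe: double the batch until a forward pass OOMs
def find_max_batch(model_name, input_size=224, start=32):
    model = timm.create_model(model_name, pretrained=False).cuda().eval()
    bs = start
    while True:
        try:
            with torch.no_grad():
                model(torch.randn(bs, 3, input_size, input_size).cuda())
            bs *= 2
        except RuntimeError:              # CUDA OOM surfaces as a RuntimeError
            torch.cuda.empty_cache()
            return bs // 2

print(find_max_batch('efficientnet_b2', input_size=260))
print(find_max_batch('tf_efficientnet_b2', input_size=260))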
print('Results by GPU memory usage:')
results_by_mem = list(sorted(results.keys(), key=lambda x: results[x]['gpu_used'], reverse=False))
for m in results_by_mem:
    print(' {:32} Mem: {}, Rate: {:>6.2f}, Top-1 {:.2f}, Top-5: {:.2f}'.format(
        m, results[m]['gpu_used'], results[m]['rate'], results[m]['top1'], results[m]['top5']))
Results by GPU memory usage: resnet50-224 Mem: 1530, Rate: 159.51, Top-1 66.81, Top-5: 87.00 gluon_seresnext50_32x4d-224 Mem: 1670, Rate: 150.43, Top-1 68.67, Top-5: 88.32 gluon_seresnext101_32x4d-224 Mem: 1814, Rate: 131.57, Top-1 70.01, Top-5: 88.91 resnet50-240-ttp Mem: 2084, Rate: 154.35, Top-1 67.02, Top-5: 87.04 gluon_seresnext101_32x4d-260-ttp Mem: 2452, Rate: 95.84, Top-1 71.14, Top-5: 89.47 resnet50-260-ttp Mem: 2532, Rate: 135.92, Top-1 67.63, Top-5: 87.63 gluon_seresnext50_32x4d-260-ttp Mem: 2586, Rate: 126.52, Top-1 69.67, Top-5: 88.62 dpn68b-224 Mem: 2898, Rate: 155.15, Top-1 65.60, Top-5: 85.94 efficientnet_b0-224 Mem: 2930, Rate: 165.73, Top-1 64.58, Top-5: 85.89 gluon_seresnext101_32x4d-300-ttp Mem: 3252, Rate: 74.87, Top-1 71.99, Top-5: 90.10 gluon_seresnext50_32x4d-300-ttp Mem: 3300, Rate: 104.69, Top-1 70.47, Top-5: 89.18 efficientnet_b1-240 Mem: 3370, Rate: 151.63, Top-1 67.55, Top-5: 87.29 ig_resnext101_32x8d-224 Mem: 3382, Rate: 83.35, Top-1 73.83, Top-5: 92.28 efficientnet_b2-260 Mem: 3992, Rate: 144.20, Top-1 67.80, Top-5: 88.20 ig_resnext101_32x8d-300-ttp Mem: 4658, Rate: 43.62, Top-1 75.17, Top-5: 92.66 tf_efficientnet_b2-260 Mem: 4690, Rate: 142.73, Top-1 67.40, Top-5: 87.58 tf_efficientnet_b3-300 Mem: 8638, Rate: 119.13, Top-1 68.52, Top-5: 88.70 tf_efficientnet_b4-380 Mem: 11754, Rate: 69.10, Top-1 71.34, Top-5: 90.11
mem_effnet = np.array([results[m]['gpu_used'] for m in names_effnet])
mem_effnet_tf = np.array([results[m]['gpu_used'] for m in names_effnet_tf])
mem_resnet = np.array([results[m]['gpu_used'] for m in names_resnet])
mem_resnet_ttp = np.array([results[m]['gpu_used'] for m in names_resnet_ttp])
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(mem_effnet, acc_effnet, s=10, c='r', marker="s", label='EfficientNet')
ax1.plot(mem_effnet, acc_effnet, c='r')
annotate(ax1, mem_effnet, acc_effnet, names_effnet, xo=-.3, align='right')
ax1.scatter(mem_effnet_tf, acc_effnet_tf, s=10, c='#8C001A', marker="v", label='TF-EfficientNet')
ax1.plot(mem_effnet_tf, acc_effnet_tf, c='#8C001A')
annotate(ax1, mem_effnet_tf, acc_effnet_tf, names_effnet_tf, xo=-.3, align='right')
ax1.scatter(mem_resnet, acc_resnet, s=10, c='b', marker="o", label='ResNet')
ax1.plot(mem_resnet, acc_resnet, c='b')
annotate(ax1, mem_resnet, acc_resnet, names_resnet, xo=.5, align='left')
# Too busy
#ax1.scatter(mem_resnet_ttp, acc_resnet_ttp, s=10, c='#43C6DB', marker="o", label='ResNet TTP')
#ax1.plot(mem_resnet_ttp, acc_resnet_ttp, c='#43C6DB')
#annotate(ax1, mem_resnet_ttp, acc_resnet_ttp, names_resnet_ttp, xo=.5, align='left')
ax1.set_title('Top-1 vs GPU Memory')
ax1.set_ylabel('Top-1 Accuracy (%)')
ax1.set_xlabel('GPU Memory (MB)')
ax1.legend()
plt.show()
A few model-to-model comparisons, pairing models a little more fairly than the original paper does once you consider accuracy, rate, and memory efficiency together.
def compare_results(results, namea, nameb):
    resa, resb = results[namea], results[nameb]
    top1r = 100. * (resa['top1'] - resb['top1']) / resb['top1']
    top5r = 100. * (resa['top5'] - resb['top5']) / resb['top5']
    rater = 100. * (resa['rate'] - resb['rate']) / resb['rate']
    memr = 100. * (resa['gpu_used'] - resb['gpu_used']) / resb['gpu_used']
    print('{:22} vs {:28} top1: {:+4.2f}%, top5: {:+4.2f}%, rate: {:+4.2f}%, mem: {:+.2f}%'.format(
        namea, nameb, top1r, top5r, rater, memr))
#compare_results(results, 'efficientnet_b0-224', 'seresnext26_32x4d-224')
compare_results(results, 'efficientnet_b0-224', 'dpn68b-224')
compare_results(results, 'efficientnet_b1-240', 'resnet50-224')
compare_results(results, 'efficientnet_b1-240', 'resnet50-240-ttp')
compare_results(results, 'efficientnet_b2-260', 'gluon_seresnext50_32x4d-224')
compare_results(results, 'tf_efficientnet_b3-300', 'gluon_seresnext50_32x4d-224')
compare_results(results, 'tf_efficientnet_b3-300', 'gluon_seresnext101_32x4d-224')
compare_results(results, 'tf_efficientnet_b4-380', 'ig_resnext101_32x8d-224')
print('\nNote the cost of running with the SAME padding hack:')
compare_results(results, 'tf_efficientnet_b2-260', 'efficientnet_b2-260')
efficientnet_b0-224 vs dpn68b-224 top1: -1.55%, top5: -0.06%, rate: +6.82%, mem: +1.10% efficientnet_b1-240 vs resnet50-224 top1: +1.11%, top5: +0.33%, rate: -4.94%, mem: +120.26% efficientnet_b1-240 vs resnet50-240-ttp top1: +0.79%, top5: +0.29%, rate: -1.76%, mem: +61.71% efficientnet_b2-260 vs gluon_seresnext50_32x4d-224 top1: -1.27%, top5: -0.14%, rate: -4.14%, mem: +139.04% tf_efficientnet_b3-300 vs gluon_seresnext50_32x4d-224 top1: -0.22%, top5: +0.43%, rate: -20.81%, mem: +417.25% tf_efficientnet_b3-300 vs gluon_seresnext101_32x4d-224 top1: -2.13%, top5: -0.24%, rate: -9.45%, mem: +376.19% tf_efficientnet_b4-380 vs ig_resnext101_32x8d-224 top1: -3.37%, top5: -2.35%, rate: -17.10%, mem: +247.55% Note the cost of running with the SAME padding hack: tf_efficientnet_b2-260 vs efficientnet_b2-260 top1: -0.59%, top5: -0.70%, rate: -1.02%, mem: +17.48%
This is often an interesting comparison. The results for the IG ResNeXt are impressive; it has the lowest gap between ImageNet-1k and ImageNet-V2 validation scores that I've seen (http://people.csail.mit.edu/ludwigs/papers/imagenet.pdf).
print('Results by absolute accuracy gap between ImageNet-V2 Matched-Frequency and original ImageNet top-1:')
no_ttp_keys = [k for k in results.keys() if 'ttp' not in k]
gaps = {x: (results[x]['top1'] - orig_top1[results[x]['model_name']]) for x in no_ttp_keys}
sorted_keys = list(sorted(no_ttp_keys, key=lambda x: gaps[x], reverse=True))
for m in sorted_keys:
    print(' Model: {:34} {:4.2f}%'.format(m, gaps[m]))
print()
print('Results by relative accuracy gap between ImageNet-V2 Matched-Frequency and original ImageNet top-1:')
gaps = {x: 100 * (results[x]['top1'] - orig_top1[results[x]['model_name']]) / orig_top1[results[x]['model_name']] for x in no_ttp_keys}
sorted_keys = list(sorted(no_ttp_keys, key=lambda x: gaps[x], reverse=True))
for m in sorted_keys:
    print(' Model: {:34} {:4.2f}%'.format(m, gaps[m]))
Results by absolute accuracy gap between ImageNet-V2 Matched-Frequency and original ImageNet top-1: Model: ig_resnext101_32x8d-224 -8.86% Model: gluon_seresnext101_32x4d-224 -10.89% Model: efficientnet_b1-240 -11.14% Model: gluon_seresnext50_32x4d-224 -11.24% Model: tf_efficientnet_b4-380 -11.26% Model: resnet50-224 -11.68% Model: dpn68b-224 -11.91% Model: efficientnet_b2-260 -11.96% Model: tf_efficientnet_b2-260 -12.21% Model: efficientnet_b0-224 -12.33% Model: tf_efficientnet_b3-300 -12.35% Results by relative accuracy gap between ImageNet-V2 Matched-Frequency and original ImageNet top-1: Model: ig_resnext101_32x8d-224 -10.71% Model: gluon_seresnext101_32x4d-224 -13.46% Model: tf_efficientnet_b4-380 -13.64% Model: gluon_seresnext50_32x4d-224 -14.07% Model: efficientnet_b1-240 -14.16% Model: resnet50-224 -14.88% Model: efficientnet_b2-260 -14.99% Model: tf_efficientnet_b3-300 -15.28% Model: tf_efficientnet_b2-260 -15.33% Model: dpn68b-224 -15.37% Model: efficientnet_b0-224 -16.03%