In case of problems or questions, please first check the list of Frequently Asked Questions (FAQ).
Please shut down all other training/prediction notebooks before running this notebook (as they might otherwise occupy the GPU memory).
from __future__ import print_function, unicode_literals, absolute_import, division
import sys
import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = 'none'
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from glob import glob
from tqdm import tqdm
from tifffile import imread
from csbdeep.utils import Path, normalize
from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available
from stardist.matching import matching, matching_dataset
from stardist.models import Config2D, StarDist2D, StarDistData2D
np.random.seed(42)
lbl_cmap = random_label_cmap()
Using TensorFlow backend.
We assume that data has already been downloaded via notebook 1_data.ipynb.
X = sorted(glob('data/dsb2018/train/images/*.tif'))
Y = sorted(glob('data/dsb2018/train/masks/*.tif'))
assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))
X = list(map(imread,X))
Y = list(map(imread,Y))
n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]
Normalize images and fill small label holes.
axis_norm = (0,1) # normalize channels independently
# axis_norm = (0,1,2) # normalize channels jointly
if n_channel > 1:
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))
    sys.stdout.flush()
X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]
Y = [fill_label_holes(y) for y in tqdm(Y)]
100%|██████████| 447/447 [00:01<00:00, 390.04it/s]
100%|██████████| 447/447 [00:04<00:00, 89.76it/s]
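Under the hood, normalize rescales intensities based on low/high percentiles. A minimal numpy sketch of the same idea (percentile_norm is a hypothetical stand-in, not the csbdeep implementation):

```python
import numpy as np

def percentile_norm(x, pmin=1, pmax=99.8, axis=None, eps=1e-20):
    """Affinely rescale x so the pmin/pmax percentiles map to 0/1."""
    lo = np.percentile(x, pmin, axis=axis, keepdims=axis is not None)
    hi = np.percentile(x, pmax, axis=axis, keepdims=axis is not None)
    return (x - lo) / (hi - lo + eps)

img = np.random.RandomState(0).uniform(0, 255, (64, 64)).astype(np.float32)
img_n = percentile_norm(img)  # values mostly in [0, 1]; outliers may fall slightly outside
```

Note that values below/above the chosen percentiles end up slightly outside [0, 1], which is intended: clipping would destroy intensity information in the extremes.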
Split into train and validation datasets.
assert len(X) > 1, "not enough training data"
rng = np.random.RandomState(42)
ind = rng.permutation(len(X))
n_val = max(1, int(round(0.15 * len(ind))))
ind_train, ind_val = ind[:-n_val], ind[-n_val:]
X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]
X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train]
print('number of images: %3d' % len(X))
print('- training: %3d' % len(X_trn))
print('- validation: %3d' % len(X_val))
number of images: 447
- training:       380
- validation:      67
Training data consists of pairs of input image and label instances.
def plot_img_label(img, lbl, img_title="image", lbl_title="label", **kwargs):
    fig, (ai,al) = plt.subplots(1,2, figsize=(12,5), gridspec_kw=dict(width_ratios=(1.25,1)))
    im = ai.imshow(img, cmap='gray', clim=(0,1))
    ai.set_title(img_title)
    fig.colorbar(im, ax=ai)
    al.imshow(lbl, cmap=lbl_cmap)
    al.set_title(lbl_title)
    plt.tight_layout()
i = min(9, len(X)-1)
img, lbl = X[i], Y[i]
assert img.ndim in (2,3)
img = img if (img.ndim==2 or img.shape[-1]==3) else img[...,0]
plot_img_label(img,lbl)
None;
A StarDist2D model is specified via a Config2D object.
print(Config2D.__doc__)
Configuration for a :class:`StarDist2D` model.

Parameters
----------
axes : str or None
    Axes of the input images.
n_rays : int
    Number of radial directions for the star-convex polygon.
    Recommended to use a power of 2 (default: 32).
n_channel_in : int
    Number of channels of given input image (default: 1).
grid : (int,int)
    Subsampling factors (must be powers of 2) for each of the axes.
    Model will predict on a subsampled grid for increased efficiency and larger field of view.
backbone : str
    Name of the neural network architecture to be used as backbone.
kwargs : dict
    Overwrite (or add) configuration attributes (see below).

Attributes
----------
unet_n_depth : int
    Number of U-Net resolution levels (down/up-sampling layers).
unet_kernel_size : (int,int)
    Convolution kernel size for all (U-Net) convolution layers.
unet_n_filter_base : int
    Number of convolution kernels (feature channels) for first U-Net layer.
    Doubled after each down-sampling layer.
unet_pool : (int,int)
    Maxpooling size for all (U-Net) convolution layers.
net_conv_after_unet : int
    Number of filters of the extra convolution layer after U-Net (0 to disable).
unet_* : *
    Additional parameters for U-net backbone.
train_shape_completion : bool
    Train model to predict complete shapes for partially visible objects at image boundary.
train_completion_crop : int
    If 'train_shape_completion' is set to True, specify number of pixels to crop at boundary of training patches.
    Should be chosen based on (largest) object sizes.
train_patch_size : (int,int)
    Size of patches to be cropped from provided training images.
train_background_reg : float
    Regularizer to encourage distance predictions on background regions to be 0.
train_foreground_only : float
    Fraction (0..1) of patches that will only be sampled from regions that contain foreground pixels.
train_dist_loss : str
    Training loss for star-convex polygon distances ('mse' or 'mae').
train_loss_weights : tuple of float
    Weights for losses relating to (probability, distance).
train_epochs : int
    Number of training epochs.
train_steps_per_epoch : int
    Number of parameter update steps per epoch.
train_learning_rate : float
    Learning rate for training.
train_batch_size : int
    Batch size for training.
train_n_val_patches : int
    Number of patches to be extracted from validation images (``None`` = one patch per image).
train_tensorboard : bool
    Enable TensorBoard for monitoring training progress.
train_reduce_lr : dict
    Parameter :class:`dict` of ReduceLROnPlateau_ callback; set to ``None`` to disable.
use_gpu : bool
    Indicate that the data generator should use OpenCL to do computations on the GPU.

.. _ReduceLROnPlateau: https://keras.io/callbacks/#reducelronplateau
# 32 is a good default choice (see 1_data.ipynb)
n_rays = 32
# Use OpenCL-based computations for data generator during training (requires 'gputools')
use_gpu = False and gputools_available()
# Predict on subsampled grid for increased efficiency and larger field of view
grid = (2,2)
conf = Config2D(
    n_rays       = n_rays,
    grid         = grid,
    use_gpu      = use_gpu,
    n_channel_in = n_channel,
)
print(conf)
vars(conf)
Config2D(axes='YXC', backbone='unet', grid=(2, 2), n_channel_in=1, n_channel_out=33, n_dim=2, n_rays=32, net_conv_after_unet=128, net_input_shape=(None, None, 1), net_mask_shape=(None, None, 1), train_background_reg=0.0001, train_batch_size=4, train_checkpoint='weights_best.h5', train_checkpoint_epoch='weights_now.h5', train_checkpoint_last='weights_last.h5', train_completion_crop=32, train_dist_loss='mae', train_epochs=400, train_foreground_only=0.9, train_learning_rate=0.0003, train_loss_weights=(1, 0.2), train_n_val_patches=None, train_patch_size=(256, 256), train_reduce_lr={'factor': 0.5, 'patience': 40, 'min_delta': 0}, train_shape_completion=False, train_steps_per_epoch=100, train_tensorboard=True, unet_activation='relu', unet_batch_norm=False, unet_dropout=0.0, unet_kernel_size=(3, 3), unet_last_activation='relu', unet_n_conv_per_depth=2, unet_n_depth=3, unet_n_filter_base=32, unet_pool=(2, 2), unet_prefix='', use_gpu=False)
{'axes': 'YXC', 'backbone': 'unet', 'grid': (2, 2), 'n_channel_in': 1, 'n_channel_out': 33, 'n_dim': 2, 'n_rays': 32, 'net_conv_after_unet': 128, 'net_input_shape': (None, None, 1), 'net_mask_shape': (None, None, 1), 'train_background_reg': 0.0001, 'train_batch_size': 4, 'train_checkpoint': 'weights_best.h5', 'train_checkpoint_epoch': 'weights_now.h5', 'train_checkpoint_last': 'weights_last.h5', 'train_completion_crop': 32, 'train_dist_loss': 'mae', 'train_epochs': 400, 'train_foreground_only': 0.9, 'train_learning_rate': 0.0003, 'train_loss_weights': (1, 0.2), 'train_n_val_patches': None, 'train_patch_size': (256, 256), 'train_reduce_lr': {'factor': 0.5, 'min_delta': 0, 'patience': 40}, 'train_shape_completion': False, 'train_steps_per_epoch': 100, 'train_tensorboard': True, 'unet_activation': 'relu', 'unet_batch_norm': False, 'unet_dropout': 0.0, 'unet_kernel_size': (3, 3), 'unet_last_activation': 'relu', 'unet_n_conv_per_depth': 2, 'unet_n_depth': 3, 'unet_n_filter_base': 32, 'unet_pool': (2, 2), 'unet_prefix': '', 'use_gpu': False}
if use_gpu:
    from csbdeep.utils.tf import limit_gpu_memory
    # adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations
    limit_gpu_memory(0.8)
    # alternatively, try this:
    # limit_gpu_memory(None, allow_growth=True)
Note: The trained StarDist2D model will not predict completed shapes for partially visible objects at the image boundary if train_shape_completion=False (which is the default option).
model = StarDist2D(conf, name='stardist', basedir='models')
Using default values: prob_thresh=0.5, nms_thresh=0.4.
Check if the neural network has a large enough field of view to see up to the boundary of most objects.
median_size = calculate_extents(list(Y), np.median)
fov = np.array(model._axes_tile_overlap('YX'))
print(f"median object size: {median_size}")
print(f"network field of view : {fov}")
if any(median_size > fov):
    print("WARNING: median object size larger than field of view of the neural network.")
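calculate_extents (from stardist) measures typical object sizes from their bounding boxes. The idea can be sketched in plain numpy (label_extents is a hypothetical helper, not the stardist implementation):

```python
import numpy as np

def label_extents(lbl):
    """Bounding-box height/width of each labeled object in a 2D label image."""
    extents = []
    for l in np.unique(lbl):
        if l == 0:
            continue  # skip background
        ys, xs = np.nonzero(lbl == l)
        extents.append((ys.max() - ys.min() + 1, xs.max() - xs.min() + 1))
    return np.array(extents)

lbl = np.zeros((20, 20), int)
lbl[2:6, 3:9]    = 1  # object 1: 4 x 6 bounding box
lbl[10:17, 12:15] = 2  # object 2: 7 x 3 bounding box
ext = label_extents(lbl)
median_size = np.median(ext, axis=0)  # per-axis median extent
```

If the median extent exceeds the network's field of view, most objects cannot be seen in full by the network, and increasing unet_n_depth (or the grid subsampling) is advisable.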
You can define a function/callable that applies augmentation to each batch of the data generator. Here we use an augmenter that applies random rotations, flips, and intensity changes, which are typically sensible for (2D) microscopy images (but you can disable augmentation by setting augmenter = None).
def random_fliprot(img, mask):
    assert img.ndim >= mask.ndim
    axes = tuple(range(mask.ndim))
    perm = tuple(np.random.permutation(axes))
    img = img.transpose(perm + tuple(range(mask.ndim, img.ndim)))
    mask = mask.transpose(perm)
    for ax in axes:
        if np.random.rand() > 0.5:
            img = np.flip(img, axis=ax)
            mask = np.flip(mask, axis=ax)
    return img, mask

def random_intensity_change(img):
    img = img*np.random.uniform(0.6,2) + np.random.uniform(-0.2,0.2)
    return img

def augmenter(x, y):
    """Augmentation of a single input/label image pair.
    x is an input image
    y is the corresponding ground-truth label image
    """
    x, y = random_fliprot(x, y)
    x = random_intensity_change(x)
    # add some gaussian noise
    sig = 0.02*np.random.uniform(0,1)
    x = x + sig*np.random.normal(0,1,x.shape)
    return x, y
# plot some augmented examples
img, lbl = X[0],Y[0]
plot_img_label(img, lbl)
for _ in range(3):
    img_aug, lbl_aug = augmenter(img,lbl)
    plot_img_label(img_aug, lbl_aug, img_title="image augmented", lbl_title="label augmented")
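As a quick sanity check, geometric augmentation should rearrange pixels without changing the set of object labels or their pixel counts. A self-contained sketch (repeating random_fliprot from above so the snippet runs on its own):

```python
import numpy as np

def random_fliprot(img, mask):
    """Random axis permutation and flips, applied identically to image and mask."""
    assert img.ndim >= mask.ndim
    axes = tuple(range(mask.ndim))
    perm = tuple(np.random.permutation(axes))
    img = img.transpose(perm + tuple(range(mask.ndim, img.ndim)))
    mask = mask.transpose(perm)
    for ax in axes:
        if np.random.rand() > 0.5:
            img = np.flip(img, axis=ax)
            mask = np.flip(mask, axis=ax)
    return img, mask

rng  = np.random.RandomState(0)
img  = rng.rand(32, 32).astype(np.float32)
mask = (rng.rand(32, 32) > 0.7).astype(np.uint8) * 3  # toy label image with label id 3
img_aug, mask_aug = random_fliprot(img, mask)
# geometry changes, but label ids and per-label pixel counts are preserved
```

Checks like these catch the classic augmentation bug of transforming the image and the label mask inconsistently.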
We recommend monitoring the progress during training with TensorBoard. You can start it in the shell from the current working directory like this:
$ tensorboard --logdir=.
Then connect to http://localhost:6006/ with your browser.
quick_demo = True
if quick_demo:
    print (
        "NOTE: This is only for a quick demonstration!\n"
        "      Please set the variable 'quick_demo = False' for proper (long) training.",
        file=sys.stderr, flush=True
    )
    model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,
                epochs=2, steps_per_epoch=10)
    print("====> Stopping training and loading previously trained demo model from disk.", file=sys.stderr, flush=True)
    model = StarDist2D.from_pretrained('2D_demo')
else:
    model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter)
None;
NOTE: This is only for a quick demonstration!
      Please set the variable 'quick_demo = False' for proper (long) training.
Epoch 1/2
10/10 [==============================] - 4s 365ms/step - loss: 3.7202 - prob_loss: 0.5916 - dist_loss: 15.6428 - prob_kld: 0.5061 - dist_relevant_mae: 15.6428 - dist_relevant_mse: 376.8492 - val_loss: 3.0740 - val_prob_loss: 0.4098 - val_dist_loss: 13.3209 - val_prob_kld: 0.3434 - val_dist_relevant_mae: 13.3209 - val_dist_relevant_mse: 291.9536
Epoch 2/2
10/10 [==============================] - 1s 97ms/step - loss: 3.0843 - prob_loss: 0.3404 - dist_loss: 13.7193 - prob_kld: 0.2711 - dist_relevant_mae: 13.7192 - dist_relevant_mse: 323.0617 - val_loss: 2.7932 - val_prob_loss: 0.3304 - val_dist_loss: 12.3143 - val_prob_kld: 0.2640 - val_dist_relevant_mae: 12.3142 - val_dist_relevant_mse: 263.8004
Loading network weights from 'weights_best.h5'.
====> Stopping training and loading previously trained demo model from disk.
Found model '2D_demo' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.486166, nms_thresh=0.5.
While the default values for the probability and non-maximum suppression thresholds already yield good results in many cases, we still recommend adapting the thresholds to your data. The optimized threshold values are saved to disk and will be automatically loaded with the model.
if quick_demo:
    # only use a subset of the validation data for a quick demo
    model.optimize_thresholds(X_val[:2], Y_val[:2])
else:
    model.optimize_thresholds(X_val, Y_val)
NMS threshold = 0.3: 75%|███████▌  | 15/20 [00:02<00:01, 4.54it/s, 0.511 -> 0.700]
NMS threshold = 0.4: 75%|███████▌  | 15/20 [00:01<00:00, 6.60it/s, 0.511 -> 0.688]
NMS threshold = 0.5: 75%|███████▌  | 15/20 [00:01<00:00, 6.63it/s, 0.511 -> 0.688]
Using optimized values: prob_thresh=0.508553, nms_thresh=0.3.
{'nms': 0.3, 'prob': 0.5085528305763067}
Besides the losses and metrics during training, we can also quantitatively evaluate the actual detection/segmentation performance on the validation data: a ground-truth object counts as correctly matched if a predicted object overlaps it (measured here by intersection over union, IoU) beyond a chosen IoU threshold $\tau$.
The corresponding matching statistics (average overlap, accuracy, recall, precision, etc.) are typically of greater practical relevance than the losses/metrics computed during training (but harder to formulate as a loss function). The value of $\tau$ ranges from 0 (even slightly overlapping objects count as correctly predicted) to 1 (only pixel-perfectly overlapping objects count); which $\tau$ to use depends on the segmentation precision your application requires.
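For intuition, the IoU criterion can be computed directly from boolean masks. A minimal sketch of a single pairwise check (simplified; matching performs a proper one-to-one assignment over all objects):

```python
import numpy as np

def iou(a, b):
    """Intersection over union of two boolean masks."""
    inter = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    return inter / union if union > 0 else 0.0

gt   = np.zeros((10, 10), bool); gt[2:6, 2:6]   = True  # 4x4 ground-truth object
pred = np.zeros((10, 10), bool); pred[3:7, 3:7] = True  # same size, shifted by 1 pixel

tau = 0.5
score = iou(gt, pred)   # intersection 9, union 23
matched = score >= tau  # not a true positive at tau = 0.5
```

Note how sensitive the criterion is: a one-pixel shift of a small object already drops the IoU to 9/23 ≈ 0.39, which is why high values of $\tau$ are only meaningful when pixel-accurate boundaries matter.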
Please see help(matching) for definitions of the abbreviations used in the evaluation below, and see the Wikipedia page on Sensitivity and specificity for further details.
# help(matching)
First predict the labels for all validation images:
Y_val_pred = [model.predict_instances(x, n_tiles=model._guess_n_tiles(x), show_tile_progress=False)[0]
              for x in tqdm(X_val)]
100%|██████████| 67/67 [00:06<00:00, 16.93it/s]
Plot a GT/prediction example
plot_img_label(X_val[0],Y_val[0], lbl_title="label GT")
plot_img_label(X_val[0],Y_val_pred[0], lbl_title="label Pred")
Choose several IoU thresholds $\tau$ that might be of interest and for each compute matching statistics for the validation data.
taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
stats = [matching_dataset(Y_val, Y_val_pred, thresh=t, show_progress=False) for t in tqdm(taus)]
100%|██████████| 9/9 [00:04<00:00, 2.00it/s]
Example: Print all available matching statistics for $\tau=0.5$
stats[taus.index(0.5)]
DatasetMatching(criterion='iou', thresh=0.5, fp=103, tp=2210, fn=333, precision=0.9554690877648077, recall=0.86905230043256, accuracy=0.8352229780801209, f1=0.9102141680395387, n_true=2543, n_pred=2313, mean_true_score=0.7361843108137301, mean_matched_score=0.8471116300449393, panoptic_quality=0.7710530075779719, by_image=False)
Plot the matching statistics and the number of true/false positives/negatives as a function of the IoU threshold $\tau$.
fig, (ax1,ax2) = plt.subplots(1,2, figsize=(15,5))

for m in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'):
    ax1.plot(taus, [s._asdict()[m] for s in stats], '.-', lw=2, label=m)
ax1.set_xlabel(r'IoU threshold $\tau$')
ax1.set_ylabel('Metric value')
ax1.grid()
ax1.legend()

for m in ('fp', 'tp', 'fn'):
    ax2.plot(taus, [s._asdict()[m] for s in stats], '.-', lw=2, label=m)
ax2.set_xlabel(r'IoU threshold $\tau$')
ax2.set_ylabel('Number #')
ax2.grid()
ax2.legend();