from __future__ import print_function, unicode_literals, absolute_import, division
import sys
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from glob import glob
from tqdm import tqdm
from tifffile import imread
from csbdeep.utils import Path, normalize
from stardist import fill_label_holes, random_label_cmap
from stardist import Config, StarDist, StarDistData
np.random.seed(42)
lbl_cmap = random_label_cmap()
Using TensorFlow backend.
We assume that data has already been downloaded via notebook 1_data.ipynb.
In general, training data (for input X with associated labels Y) can be provided via lists of numpy arrays, where each image can have a different size. Alternatively, a single numpy array can also be used if all images have the same size.
Input images can either be two-dimensional (single-channel) or three-dimensional (multi-channel) arrays, where the channel axis comes last. Label images need to be integer-valued.
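The input conventions above can be checked per image pair before training. The helper below is a minimal sketch. The name check_training_pair is hypothetical and not part of stardist or csbdeep. It also assumes that each label image has the same spatial (Y, X) shape as its input image. It returns the channel count, computed the same way as n_channel in the next cell.

```python
import numpy as np


def check_training_pair(x, y):
    """Hypothetical sanity check for one (image, label) pair; returns the channel count.

    Assumes images are (Y, X) or (Y, X, C) with the channel axis last, and labels
    are integer-valued (Y, X) arrays whose shape matches the image's spatial shape.
    """
    if x.ndim not in (2, 3):
        raise ValueError("image must have shape (Y, X) or (Y, X, C)")
    if y.shape != x.shape[:2]:
        raise ValueError("label image must match the spatial shape of the image")
    if not np.issubdtype(y.dtype, np.integer):
        raise ValueError("label image must be integer-valued")
    return 1 if x.ndim == 2 else x.shape[-1]


# Images in a list may differ in size; each pair is checked on its own.
X_demo = [np.zeros((64, 48), np.float32), np.zeros((32, 40), np.float32)]
Y_demo = [np.zeros((64, 48), np.uint16), np.zeros((32, 40), np.int32)]
channels = [check_training_pair(x, y) for x, y in zip(X_demo, Y_demo)]
print(channels)  # [1, 1]
```

A float-valued label image, for example one saved after resampling with interpolation, would be rejected here instead of failing later during training.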
X = sorted(glob('data/dsb2018/train/images/*.tif'))
Y = sorted(glob('data/dsb2018/train/masks/*.tif'))
assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))
X = list(map(imread,X))
Y = list(map(imread,Y))
n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]
Normalize images and fill small label holes.
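The next cell calls csbdeep's normalize with the 1st and 99.8th percentiles and a choice of axis_norm. The function below is a sketch of percentile-based normalization. It is not csbdeep's source, and the name percentile_normalize is hypothetical. It shows what the axis argument changes: with axis=(0,1) each channel gets its own percentiles, and with axis=(0,1,2) all channels share them. This sketch does not clip values outside [0, 1].

```python
import numpy as np


def percentile_normalize(x, pmin=1, pmax=99.8, axis=None, eps=1e-20):
    """Map the pmin-th / pmax-th percentiles of x to 0 / 1 (sketch, no clipping)."""
    x = np.asarray(x, dtype=np.float32)
    lo = np.percentile(x, pmin, axis=axis, keepdims=True)
    hi = np.percentile(x, pmax, axis=axis, keepdims=True)
    return (x - lo) / (hi - lo + eps)


# Two-channel toy image whose second channel is 10x brighter than the first.
a = np.arange(16, dtype=np.float32).reshape(4, 4)
img = np.stack([a, 10 * a], axis=-1)

independent = percentile_normalize(img, 0, 100, axis=(0, 1))  # per-channel statistics
joint = percentile_normalize(img, 0, 100, axis=(0, 1, 2))     # statistics over all channels
```

With independent normalization, both channels end up identical here. With joint normalization, the relative brightness is preserved, so the dimmer channel only reaches about 0.1.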
axis_norm = (0,1) # normalize channels independently
# axis_norm = (0,1,2) # normalize channels jointly
if n_channel > 1:
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))
    sys.stdout.flush()
X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]
Y = [fill_label_holes(y) for y in tqdm(Y)]
100%|██████████| 447/447 [00:01<00:00, 409.57it/s]
100%|██████████| 447/447 [00:03<00:00, 111.88it/s]
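fill_label_holes fills holes inside each labeled object. The code below is a simplified, numpy-only sketch of that idea, not stardist's implementation. The function names are hypothetical. For each label, it flood-fills the background from the image border with 4-connectivity. Any background pixel the flood cannot reach is treated as a hole and assigned to that label. A gap that touches the image border therefore stays open.

```python
from collections import deque

import numpy as np


def fill_mask_holes(mask):
    """Return mask | holes, where holes are background regions not connected to the image border."""
    h, w = mask.shape
    outside = np.zeros((h, w), dtype=bool)
    queue = deque()
    # Seed the flood fill with every non-object pixel on the image border.
    border = [(i, j) for i in range(h) for j in (0, w - 1)] + \
             [(i, j) for i in (0, h - 1) for j in range(w)]
    for i, j in border:
        if not mask[i, j] and not outside[i, j]:
            outside[i, j] = True
            queue.append((i, j))
    # Breadth-first flood fill through non-object pixels (4-connectivity).
    while queue:
        i, j = queue.popleft()
        for ni, nj in ((i + 1, j), (i - 1, j), (i, j + 1), (i, j - 1)):
            if 0 <= ni < h and 0 <= nj < w and not mask[ni, nj] and not outside[ni, nj]:
                outside[ni, nj] = True
                queue.append((ni, nj))
    return ~outside


def fill_label_holes_sketch(lbl):
    """Fill enclosed holes of every labeled object (later labels win where filled regions overlap)."""
    filled = np.zeros_like(lbl)
    for value in np.unique(lbl):
        if value == 0:
            continue
        filled[fill_mask_holes(lbl == value)] = value
    return filled


lbl = np.zeros((7, 7), dtype=np.uint16)
lbl[1:6, 1:6] = 1
lbl[3, 3] = 0  # a one-pixel hole inside object 1
print(fill_label_holes_sketch(lbl)[3, 3])  # 1
```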
Split into train and validation datasets.
rng = np.random.RandomState(42)
ind = rng.permutation(len(X))
n_val = int(round(0.15 * len(X)))
ind_train, ind_val = ind[:-n_val], ind[-n_val:]
X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]
X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train]
print('number of images: %3d' % len(X))
print('- training: %3d' % len(X_trn))
print('- validation: %3d' % len(X_val))
number of images: 447
- training: 380
- validation:  67
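The split sizes follow from the arithmetic n_val = round(0.15 × 447) = round(67.05) = 67 and 447 − 67 = 380, which matches the printed output. The slicing ind[:-n_val] has one pitfall. If n_val rounds to 0 for a very small dataset, ind[:-0] is empty and ind[-0:] is the whole permutation. The sketch below is a hypothetical helper, split_indices. It reproduces the split above, because it uses the same RandomState(42) permutation, and it guards against that case by keeping at least one validation image.

```python
import numpy as np


def split_indices(n, val_fraction=0.15, seed=42):
    """Random train/validation index split that keeps at least one validation image."""
    rng = np.random.RandomState(seed)
    ind = rng.permutation(n)
    # Guard: with n_val == 0, ind[:-0] would be EMPTY and ind[-0:] would be everything.
    n_val = max(1, int(round(val_fraction * n)))
    return ind[:-n_val], ind[-n_val:]


ind_train_demo, ind_val_demo = split_indices(447)
print(len(ind_train_demo), len(ind_val_demo))  # 380 67
```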
Training data consists of pairs of input image and label instances.
i = 9
img, lbl = X[i], Y[i]
img = img if img.ndim==2 else img[...,0]
plt.figure(figsize=(16,10))
plt.subplot(121); plt.imshow(img,cmap='gray'); plt.axis('off'); plt.title('Raw image')
plt.subplot(122); plt.imshow(lbl,cmap=lbl_cmap); plt.axis('off'); plt.title('GT labels')
None;
From the label instance image, all data needed to train StarDist can be computed via StarDistData. Note that this is shown here only for illustration, since it happens automatically when calling StarDist.train (see below).
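The "Distance" targets in the plots below are distances from each object pixel to the object boundary along several radial directions. The code below is a simplified, pure-Python sketch of that idea, not StarDistData's implementation. The function name ray_distances is hypothetical. The angle convention (ray 0 points along +x, i.e. to the right) and the half-pixel boundary correction are our own assumptions. The loops are also far too slow for real images.

```python
import numpy as np


def ray_distances(lbl, n_rays=32):
    """Sketch: distance from each object pixel to its object's boundary along n_rays directions.

    Background pixels (label 0) get distance 0. Ray k points along
    (d_row, d_col) = (sin(phi_k), cos(phi_k)) with phi_k = 2*pi*k/n_rays (assumed convention).
    """
    h, w = lbl.shape
    dist = np.zeros((h, w, n_rays), dtype=np.float32)
    angles = 2 * np.pi * np.arange(n_rays) / n_rays
    for i in range(h):
        for j in range(w):
            value = lbl[i, j]
            if value == 0:
                continue
            for k, phi in enumerate(angles):
                dr, dc = np.sin(phi), np.cos(phi)
                r = c = 0.0  # offset from the starting pixel
                while True:
                    r += dr
                    c += dc
                    ii, jj = int(round(i + r)), int(round(j + c))
                    if not (0 <= ii < h and 0 <= jj < w) or lbl[ii, jj] != value:
                        # We overshot into the first outside pixel: step back to the pixel edge.
                        t_corr = 1 - 0.5 / max(abs(dr), abs(dc))
                        r -= t_corr * dr
                        c -= t_corr * dc
                        dist[i, j, k] = np.hypot(r, c)
                        break
    return dist


lbl = np.zeros((9, 9), dtype=np.uint16)
lbl[2:7, 2:7] = 1  # a 5x5 square object
d = ray_distances(lbl, n_rays=4)
```

For the 5x5 square, the center pixel is 2.5 pixels from the object edge in each of the four axis directions. The relevant edge lies half a pixel beyond the last object pixel center.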
With shape_completion = False (see Config below), the trained StarDist model will not predict completed shapes for partially visible cells at the image boundary. This is the default behavior.
np.random.seed(42)
data = StarDistData(X,Y,batch_size=1,n_rays=32,patch_size=(256,256),shape_completion=False)
(img,dist_mask), (prob,dist) = data[0]
fig, ax = plt.subplots(2,2, figsize=(12,12))
for a,d,cm,s in zip(ax.flat, [img,prob,dist_mask,dist], ['gray','magma','bone','viridis'],
                    ['Input image','Object probability','Distance mask','Distance (0°)']):
    a.imshow(d[0,...,0],cmap=cm)
    a.set_title(s)
plt.tight_layout()
None;
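The "Object probability" panel above has its highest values at object centers. Here we assume this target is the Euclidean distance from each object pixel to the nearest pixel outside the object, normalized per object so that its maximum is 1. The code below is a brute-force sketch under that assumption, not the library's code. The function name object_probability is hypothetical. Only non-object pixels inside the image count as "outside", and the library's handling of the image border may differ.

```python
import numpy as np


def object_probability(lbl):
    """Per-object normalized distance to the nearest non-object pixel (brute-force sketch)."""
    prob = np.zeros(lbl.shape, dtype=np.float32)
    coords = np.indices(lbl.shape).reshape(2, -1).T  # (row, col) of every pixel, row-major order
    flat = lbl.ravel()
    for value in np.unique(lbl):
        if value == 0:
            continue
        inside = flat == value
        inside_pts, outside_pts = coords[inside], coords[~inside]
        if len(outside_pts) == 0:
            continue  # object covers the whole image; undefined in this sketch
        # Distance from every object pixel to its nearest non-object pixel.
        diff = inside_pts[:, None, :] - outside_pts[None, :, :]
        nearest = np.sqrt((diff ** 2).sum(axis=-1)).min(axis=1)
        prob[lbl == value] = nearest / nearest.max()
    return prob


lbl = np.zeros((9, 9), dtype=np.uint16)
lbl[2:7, 2:7] = 1
p = object_probability(lbl)
```

In the 5x5 square, the center pixel is 3 pixels from the background and gets probability 1. A pixel on the object's edge is 1 pixel away and gets 1/3.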
With shape_completion = True (see Config below), the trained StarDist model will predict completed shapes for partially visible cells at the image boundary. For this to work, the training patches need to be cropped at their boundary, which is controlled by the Config parameter train_completion_crop (default 32). This value should be chosen based on the size of the (largest) objects. Furthermore, it may be a good idea to increase train_batch_size to offset the reduced number of pixels per training patch due to cropping.
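A back-of-the-envelope calculation shows how much cropping reduces the pixels per patch. It assumes that train_completion_crop pixels are removed from every side of a 256×256 patch. Each patch then keeps 192 × 192 = 36864 of 65536 pixels (56.25%). Matching the pixels of 4 uncropped patches then needs 4 / 0.5625 ≈ 7.1 cropped patches, which is consistent with the batch sizes of 4 and 7 used in the two configurations below. The function names are hypothetical.

```python
def cropped_patch_pixels(patch_size=(256, 256), crop=32):
    """Pixels per patch left after removing `crop` pixels from every side (assumption)."""
    h, w = patch_size
    return (h - 2 * crop) * (w - 2 * crop)


def matching_batch_size(ref_batch, patch_size=(256, 256), crop=32):
    """Batch size whose cropped patches cover about as many pixels as ref_batch full patches."""
    full_pixels = patch_size[0] * patch_size[1] * ref_batch
    return int(round(full_pixels / cropped_patch_pixels(patch_size, crop)))


print(cropped_patch_pixels(), matching_batch_size(4))  # 36864 7
```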
np.random.seed(42)
data = StarDistData(X,Y,batch_size=1,n_rays=32,patch_size=(256,256),shape_completion=True)
(img,dist_mask), (prob,dist) = data[0]
fig, ax = plt.subplots(2,2, figsize=(12,12))
for a,d,cm,s in zip(ax.flat, [img,prob,dist_mask,dist], ['gray','magma','bone','viridis'],
                    ['Input image','Object probability','Distance mask','Distance (0°)']):
    a.imshow(d[0,...,0],cmap=cm)
    a.set_title(s)
plt.tight_layout()
None;
A StarDist model is specified via a Config object.
print(Config.__doc__)
Configuration for a :class:`StarDist` model.

    Parameters
    ----------
    n_rays : int
        Number of radial directions for the star-convex polygon.
        Recommended to use a power of 2 (default: 32).
    n_channel_in : int
        Number of channels of given input image (default: 1).
    kwargs : dict
        Overwrite (or add) configuration attributes (see below).

    Attributes
    ----------
    unet_n_depth : int
        Number of U-Net resolution levels (down/up-sampling layers).
    unet_kernel_size : (int,int)
        Convolution kernel size for all (U-Net) convolution layers.
    unet_n_filter_base : int
        Number of convolution kernels (feature channels) for first U-Net layer.
        Doubled after each down-sampling layer.
    net_conv_after_unet : int
        Number of filters of the extra convolution layer after U-Net (0 to disable).
    train_shape_completion : bool
        Train model to predict complete shapes for partially visible objects at image boundary.
    train_completion_crop : int
        If 'train_shape_completion' is set to True, specify number of pixels to crop at boundary of training patches.
        Should be chosen based on (largest) object sizes.
    train_patch_size : (int,int)
        Size of patches to be cropped from provided training images.
    train_dist_loss : str
        Training loss for star-convex polygon distances ('mse' or 'mae').
    train_epochs : int
        Number of training epochs.
    train_steps_per_epoch : int
        Number of parameter update steps per epoch.
    train_learning_rate : float
        Learning rate for training.
    train_batch_size : int
        Batch size for training.
    train_tensorboard : bool
        Enable TensorBoard for monitoring training progress.
    train_checkpoint : str
        Name of checkpoint file for model weights (only best are saved); set to ``None`` to disable.
    train_reduce_lr : dict
        Parameter :class:`dict` of ReduceLROnPlateau_ callback; set to ``None`` to disable.

    .. _ReduceLROnPlateau: https://keras.io/callbacks/#reducelronplateau
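The kwargs entry above says that keyword arguments overwrite (or add) configuration attributes. The printed configurations in the next cells also show a derived attribute, net_input_shape, that follows n_channel_in. The class below is a minimal sketch of that pattern. ToyConfig is hypothetical and is not stardist.Config. Its default values are simply copied from the Config printed in this notebook for illustration.

```python
class ToyConfig:
    """Hypothetical stand-in for a model configuration (NOT stardist.Config)."""

    # Values copied from the Config printed in this notebook (illustrative only).
    DEFAULTS = dict(
        n_rays=32, unet_n_depth=3, unet_kernel_size=(3, 3), unet_n_filter_base=32,
        net_conv_after_unet=128, train_patch_size=(256, 256), train_shape_completion=False,
        train_completion_crop=32, train_dist_loss='mae', train_epochs=100,
        train_steps_per_epoch=400, train_learning_rate=0.0003,
    )

    def __init__(self, n_rays=32, n_channel_in=1, **kwargs):
        attrs = dict(self.DEFAULTS, n_rays=n_rays, n_channel_in=n_channel_in)
        # Derived attribute, mirroring net_input_shape=(None, None, 1) in the printout.
        attrs['net_input_shape'] = (None, None, n_channel_in)
        # Keyword arguments overwrite defaults or add new attributes.
        attrs.update(kwargs)
        self.__dict__.update(attrs)

    def __repr__(self):
        items = ', '.join('%s=%r' % kv for kv in sorted(vars(self).items()))
        return 'ToyConfig(%s)' % items


conf_demo = ToyConfig(n_channel_in=3, train_batch_size=7, train_shape_completion=True)
print(vars(conf_demo)['net_input_shape'])  # (None, None, 3)
```

Because the attributes live in the instance's __dict__, vars() returns them as a plain dictionary. This mirrors how vars(conf) is used in the next cells.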
You can monitor the progress during training with TensorBoard by starting it from the current working directory:
$ tensorboard --logdir=.
Then connect to http://localhost:6006/ with your browser.
conf = Config(n_channel_in=n_channel, train_batch_size=4, train_shape_completion=False)
print(conf)
vars(conf)
Config(n_channel_in=1, n_rays=32, net_conv_after_unet=128, net_input_shape=(None, None, 1), net_mask_shape=(None, None, 1), train_batch_size=4, train_checkpoint='weights_best.h5', train_completion_crop=32, train_dist_loss='mae', train_epochs=100, train_learning_rate=0.0003, train_patch_size=(256, 256), train_reduce_lr={'factor': 0.5, 'patience': 10}, train_shape_completion=False, train_steps_per_epoch=400, train_tensorboard=True, unet_kernel_size=(3, 3), unet_n_depth=3, unet_n_filter_base=32)
{'n_channel_in': 1, 'n_rays': 32, 'net_conv_after_unet': 128, 'net_input_shape': (None, None, 1), 'net_mask_shape': (None, None, 1), 'train_batch_size': 4, 'train_checkpoint': 'weights_best.h5', 'train_completion_crop': 32, 'train_dist_loss': 'mae', 'train_epochs': 100, 'train_learning_rate': 0.0003, 'train_patch_size': (256, 256), 'train_reduce_lr': {'factor': 0.5, 'patience': 10}, 'train_shape_completion': False, 'train_steps_per_epoch': 400, 'train_tensorboard': True, 'unet_kernel_size': (3, 3), 'unet_n_depth': 3, 'unet_n_filter_base': 32}
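The docstring describes train_reduce_lr as the parameter dict of Keras's ReduceLROnPlateau callback, here {'factor': 0.5, 'patience': 10}. The function below is a simplified simulation of what that means for the learning rate of 0.0003. Its name is hypothetical. It ignores cooldown, min_lr and the choice of monitored quantity, and it assumes min_delta=1e-4. Whenever the validation loss has not improved for 10 epochs, the learning rate is halved.

```python
def simulate_plateau_lr(val_losses, lr=3e-4, factor=0.5, patience=10, min_delta=1e-4):
    """Simplified ReduceLROnPlateau: multiply lr by factor after `patience` epochs without improvement."""
    best = float('inf')
    wait = 0
    history = []
    for loss in val_losses:
        if loss < best - min_delta:  # counts as an improvement
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr *= factor
                wait = 0
        history.append(lr)
    return history


# Validation loss that stops improving after the first epoch.
lrs = simulate_plateau_lr([1.0] * 21)
print(lrs[9], lrs[10], lrs[20])
```

With a flat loss, the learning rate stays at 0.0003 for the first 10 epochs. It drops to 0.00015 in epoch 11 and to 0.000075 in epoch 21.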
model = StarDist(conf, name='stardist_no_shape_completion', basedir='models')
%%capture train_log
model.train(X_trn,Y_trn,validation_data=(X_val,Y_val))
# show train log
# train_log()
conf = Config(n_channel_in=n_channel, train_batch_size=7, train_shape_completion=True)
print(conf)
vars(conf)
Config(n_channel_in=1, n_rays=32, net_conv_after_unet=128, net_input_shape=(None, None, 1), train_batch_size=7, train_checkpoint='weights_best.h5', train_completion_crop=32, train_dist_loss='mae', train_epochs=100, train_learning_rate=0.0003, train_patch_size=(256, 256), train_reduce_lr={'factor': 0.5, 'patience': 10}, train_shape_completion=True, train_steps_per_epoch=400, train_tensorboard=True, unet_kernel_size=(3, 3), unet_n_depth=3, unet_n_filter_base=32)
{'n_channel_in': 1, 'n_rays': 32, 'net_conv_after_unet': 128, 'net_input_shape': (None, None, 1), 'train_batch_size': 7, 'train_checkpoint': 'weights_best.h5', 'train_completion_crop': 32, 'train_dist_loss': 'mae', 'train_epochs': 100, 'train_learning_rate': 0.0003, 'train_patch_size': (256, 256), 'train_reduce_lr': {'factor': 0.5, 'patience': 10}, 'train_shape_completion': True, 'train_steps_per_epoch': 400, 'train_tensorboard': True, 'unet_kernel_size': (3, 3), 'unet_n_depth': 3, 'unet_n_filter_base': 32}
model = StarDist(conf, name='stardist_shape_completion', basedir='models')
%%capture train_log
model.train(X_trn,Y_trn,validation_data=(X_val,Y_val))
# show train log
# train_log()