#!/usr/bin/env python
# coding: utf-8

# Open In Colab
#
# **If on Google Colab, first do this:**
#
# Click on `Runtime` > `Change runtime type` and select `GPU` as `Hardware accelerator`.

# # CSBDeep
#
# Technical infrastructure that powers CSBDeep (and StarDist) under the hood.

# #### Basic setup for this notebook

# In[1]:


# detect Google Colab, where the required packages have to be installed first
try:
    import google.colab
    COLAB = True
except ModuleNotFoundError:
    COLAB = False

if COLAB:
    import sys
    get_ipython().system('{sys.executable} -m pip install csbdeep stardist')


# In[2]:


import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = 'none'
import matplotlib.pyplot as plt
get_ipython().run_line_magic('matplotlib', 'inline')
get_ipython().run_line_magic('config', "InlineBackend.figure_format = 'retina'")


# ## Keras with TensorFlow 1 & 2

# In[3]:


from csbdeep.utils.tf import keras_import, BACKEND as K
keras = keras_import()
# equivalent to either:
# - "import keras"                  # if using TensorFlow and separate Keras package
# - "from tensorflow import keras"  # if using TensorFlow with integrated Keras

# can also do specific imports, e.g.:
Input, Dense = keras_import('layers', 'Input', 'Dense')
assert Input == keras.layers.Input and Dense == keras.layers.Dense


# ## Blocks and Nets

# In[4]:


from csbdeep.internals.nets import common_unet, custom_unet
from csbdeep.internals.blocks import unet_block, resnet_block

model = common_unet(residual=False, n_channel_out=2)((128,128,3))
model.summary()


# In[5]:


# small classifier assembled from resnet blocks via the Keras functional API
x = inputs = Input((128,128,3))
x = resnet_block(64)(x)
x = resnet_block(128, pool=(2,2))(x)
x = keras.layers.GlobalAveragePooling2D()(x)
x = Dense(32, activation='relu')(x)
x = outputs = Dense(1, activation='sigmoid')(x)

model = keras.Model(inputs, outputs)
model.summary()


# ## BaseConfig & BaseModel

# In[6]:


from csbdeep.models import BaseModel, BaseConfig


# In[7]:


BaseConfig()


# In[8]:


class MyConfig(BaseConfig):
    """Example config: BaseConfig plus one custom attribute, `my_parameter`.

    All other keyword arguments are forwarded unchanged to BaseConfig.
    """

    def __init__(self, my_parameter, **kwargs):
        super().__init__(**kwargs)
        self.my_parameter = my_parameter

config = MyConfig(my_parameter=42)
config


# In[9]:


class MyModel(BaseModel):
    """Minimal BaseModel subclass used to demonstrate config persistence and the registry."""

    @property
    def _config_class(self):
        # the config class for this model (presumably used by BaseModel
        # to re-create the config when loading a model from disk)
        return MyConfig

    def _build(self):
        # no network for this demo: returns None
        pass


# In[10]:


# demo: delete model folder if it already exists
get_ipython().run_line_magic('rm', '-rf models/my_model')


# In[11]:


# create model folder and persist config
model = MyModel(config, 'my_model', basedir='models')
model


# In[12]:


get_ipython().run_line_magic('ls', 'models/my_model')


# In[13]:


get_ipython().run_line_magic('cat', 'models/my_model/config.json')


# In[14]:


# load model from folder (config and possibly trained weights)
model = MyModel(None, 'my_model', basedir='models')
model


# BaseModel has more to offer, some of which is shown below...

# In[15]:


[a for a in dir(model) if not a.startswith('__')]


# ## Registry for pretrained models

# In[16]:


# stardist is optional here: only list its pretrained models if it is installed
try:
    from stardist.models import StarDist2D
except ModuleNotFoundError:
    StarDist2D = None
else:
    StarDist2D.from_pretrained()


# In[17]:


# if stardist is not installed, StarDist2D is None; catching ModuleNotFoundError
# here would not help, since using the undefined name would raise a NameError
if StarDist2D is not None:
    StarDist2D.from_pretrained('Versatile (fluorescent nuclei)')


# In[18]:


MyModel.from_pretrained()


# In[19]:


from csbdeep.models import register_model, register_aliases

# register a (dummy) download URL under the key 'my_model'; the last argument is
# left empty for this demo (presumably the file hash -- see register_model docs)
register_model(MyModel, 'my_model', 'http://example.com/my_model.zip', '')
register_aliases(MyModel, 'my_model', 'My minimal model', 'Another name for my model')


# In[20]:


MyModel.from_pretrained()


# ## Example: U-Net model for multi-class semantic segmentation
# Note that the focus is on demonstrating certain concepts rather than being a good/complete segmentation approach.
# ### Helper

# In[21]:


try:
    import skimage
except ModuleNotFoundError as err:
    raise RuntimeError("This demo needs scikit-image to run.") from err

from glob import glob
from tqdm import tqdm
from tifffile import imread
from pathlib import Path
from skimage.segmentation import find_boundaries


def crop(u, shape=(256,256)):
    """Crop the central region of the given shape.

    Only the leading ``len(shape)`` axes of ``u`` are cropped; any further
    axes are kept in full.

    Raises ValueError if ``u`` is smaller than ``shape`` along a cropped axis
    (the slicing would otherwise silently return a wrongly-sized array).
    """
    if any(s < m for s, m in zip(u.shape, shape)):
        raise ValueError(f"cannot crop shape {tuple(shape)} from array of shape {u.shape}")
    return u[tuple(slice((s-m)//2, (s-m)//2+m) for s, m in zip(u.shape, shape))]


def to_3class_label(lbl, onehot=True):
    """Convert an instance labeling to a background/inner/boundary mask.

    Class values: 0 where ``lbl == 0`` (background), 1 where ``lbl > 0``
    (object interior), and 2 wherever ``find_boundaries(lbl, mode='outer')``
    marks a boundary pixel (this overrides the other two classes).

    If ``onehot`` is True, the result is one-hot encoded and has shape
    ``lbl.shape + (3,)``; otherwise it is a uint8 array of shape ``lbl.shape``.
    """
    b = find_boundaries(lbl, mode='outer')
    res = (lbl > 0).astype(np.uint8)
    res[b] = 2
    if onehot:
        res = keras.utils.to_categorical(res, num_classes=3).reshape(lbl.shape+(3,))
    return res


def dice_bce_loss(n_labels):
    """Return a Keras loss combining soft Dice loss and crossentropy.

    The loss is the mean over the ``n_labels`` channels of (1 - Dice coefficient)
    plus ``K.categorical_crossentropy``. Despite the function's name, the
    crossentropy term is categorical, not binary.
    """
    def _sum(a):
        # sums over axes 1 and 2, i.e. the spatial axes of a 2D batch laid out as
        # (batch, Y, X) after channel selection -- assumes 2D images, TODO confirm for 3D
        return K.sum(a, axis=(1,2), keepdims=True)

    def dice_coef(y_true, y_pred):
        # epsilon in numerator and denominator keeps empty channels well-defined
        return (2 * _sum(y_true * y_pred) + K.epsilon()) / (_sum(y_true) + _sum(y_pred) + K.epsilon())

    def _loss(y_true, y_pred):
        dice_loss = 0
        for i in range(n_labels):
            dice_loss += 1 - dice_coef(y_true[...,i], y_pred[...,i])
        # per-image Dice term (kept dims) broadcasts against the per-pixel crossentropy
        return dice_loss/n_labels + K.categorical_crossentropy(y_true, y_pred)
    return _loss


def datagen(X, Y, batch_size, seed=0):
    """Simple data augmentation.

    Returns an endless iterator of ``(X_batch, Y_batch)`` tuples, produced by
    applying random flips, rotations and shears to ``X`` and ``Y``.

    Both Keras ``flow`` iterators use the same ``seed``, so that (presumably)
    the same shuffling and random transforms are applied to inputs and targets.
    A fixed seed is therefore required; ``seed=None`` raises ValueError.

    NOTE(review): rotation/shear also interpolate ``Y``, so augmented one-hot
    masks may contain non-binary values -- confirm this is acceptable.
    """
    if seed is None:
        raise ValueError("datagen needs a fixed seed to keep X and Y augmentations in sync")
    try:
        ImageDataGenerator = keras.preprocessing.image.ImageDataGenerator
    except AttributeError:
        # newer Keras versions only provide the legacy implementation
        ImageDataGenerator = keras.src.legacy.preprocessing.image.ImageDataGenerator
    g = ImageDataGenerator(horizontal_flip=True, vertical_flip=True, rotation_range=10,
                           shear_range=10, fill_mode='reflect')
    gX = g.flow(X, batch_size=batch_size, seed=seed)
    gY = g.flow(Y, batch_size=batch_size, seed=seed)

    def _batches():
        """Endlessly yield paired (X, Y) batches."""
        while True:
            yield next(gX), next(gY)
    return _batches()


# ### Data

# In[22]:


from csbdeep.utils import download_and_extract_zip_file, normalize

download_and_extract_zip_file(
    url       = 'https://github.com/mpicbg-csbd/stardist/releases/download/0.1.0/dsb2018.zip',
    targetdir = 'data',
    verbose   = 1,
)


# In[23]:


image_files = sorted(glob('data/dsb2018/train/images/*.tif'))
mask_files  = sorted(glob('data/dsb2018/train/masks/*.tif'))
# images and masks are paired by sorted file order, so make sure the names match up
assert len(image_files) > 0 and len(image_files) == len(mask_files)
assert all(Path(x).name == Path(y).name for x, y in zip(image_files, mask_files))

# load and crop out central patch (for simplicity)
X       = [crop(imread(x)) for x in image_files]
Y_label = [crop(imread(y)) for y in mask_files]

# normalize input image and convert label image to 3-class segmentation mask
X = [normalize(x, 1, 99.8) for x in tqdm(X)]
Y = [to_3class_label(y) for y in tqdm(Y_label)]

# convert to numpy arrays (X gets a trailing channel axis)
X, Y, Y_label = np.expand_dims(np.stack(X), -1), np.stack(Y), np.stack(Y_label)


# In[24]:


i = 15
fig, (a0,a1,a2) = plt.subplots(1, 3, figsize=(15,5))
a0.imshow(X[i,...,0], cmap='gray');  a0.set_title('input image')
a1.imshow(Y_label[i], cmap='tab20'); a1.set_title('label image')
a2.imshow(Y[i]);                     a2.set_title('segmentation mask')
fig.suptitle("Example")
None;


# ### Model

# In[25]:


from csbdeep.data import PadAndCropResizer
from csbdeep.utils import axes_check_and_normalize
from csbdeep.utils.tf import IS_TF_1, CARETensorBoardImage

if IS_TF_1:
    raise NotImplementedError("For sake of simplicity, this example only works with TensorFlow 2.x")


class SegConfig(BaseConfig):
    """BaseConfig plus ``unet_depth``, the number of U-Net resolution levels."""

    def __init__(self, unet_depth, **kwargs):
        super().__init__(**kwargs)
        self.unet_depth = unet_depth


class SegModel(BaseModel):
    """U-Net with softmax output for multi-class semantic segmentation."""

    @property
    def _config_class(self):
        return SegConfig

    def _build(self):
        """Build the U-Net; spatial input size is left open (None, None)."""
        return common_unet(n_depth=self.config.unet_depth, n_first=32, residual=False,
                           n_channel_out=self.config.n_channel_out,
                           last_activation='softmax')((None,None,self.config.n_channel_in))

    def _prepare_for_training(self, validation_data, lr):
        """Compile the Keras model with Adam(lr) and the Dice+crossentropy loss,
        and set up checkpoint and TensorBoard callbacks."""
        # the softmax/categorical setup needs at least two classes
        assert self.config.n_channel_out > 1
        self.keras_model.compile(optimizer=keras.optimizers.Adam(lr),
                                 loss=dice_bce_loss(self.config.n_channel_out),
                                 metrics=['categorical_crossentropy','accuracy'])
        self.callbacks = self._checkpoint_callbacks()
        self.callbacks.append(keras.callbacks.TensorBoard(log_dir=str(self.logdir/'logs'),
                                                          write_graph=False, profile_batch=0))
        self.callbacks.append(CARETensorBoardImage(model=self.keras_model, data=validation_data,
                                                   log_dir=str(self.logdir/'logs'/'images'),
                                                   n_images=3, prob_out=False))
        self._model_prepared = True

    def train(self, X,Y, validation_data, lr, batch_size, epochs, steps_per_epoch):
        """Train on augmented batches of (X, Y) and return the Keras history.

        Note: ``lr`` is only used on the first call, when the model is not yet
        prepared for training; later calls keep the existing compiled optimizer.
        """
        if not self._model_prepared:
            self._prepare_for_training(validation_data, lr)
        training_data = datagen(X, Y, batch_size)
        history = self.keras_model.fit(training_data, validation_data=validation_data,
                                       epochs=epochs, steps_per_epoch=steps_per_epoch,
                                       callbacks=self.callbacks, verbose=1)
        self._training_finished()
        return history

    def predict(self, img, axes=None, normalizer=None, resizer=PadAndCropResizer()):
        """Predict class probabilities for a single image of arbitrary size.

        ``img`` has axes ``axes`` (default: ``self.config.axes``). Spatial axes are
        padded to a multiple of ``2**unet_depth`` before prediction and cropped
        back afterwards. The result is returned with axes ``self.config.axes``
        (it is not permuted back to ``axes``).

        NOTE(review): the default ``resizer`` instance is shared across calls;
        it presumably carries padding state from ``before`` to ``after``.
        """
        normalizer, resizer = self._check_normalizer_resizer(normalizer, resizer)
        axes_net = self.config.axes
        if axes is None:
            axes = axes_net
        axes = axes_check_and_normalize(axes, img.ndim)
        axes_net_div_by = tuple((2**self.config.unet_depth if a in 'XYZ' else 1) for a in axes_net)
        x = self._make_permute_axes(axes, axes_net)(img)
        x = normalizer(x, axes_net)
        x = resizer.before(x, axes_net, axes_net_div_by)
        pred = self.keras_model.predict(x[np.newaxis], verbose=0)[0]
        pred = resizer.after(pred, axes_net)
        return pred


# In[26]:


# demo: delete model folder if it already exists
get_ipython().run_line_magic('rm', '-rf models/seg_model')


# In[27]:


config = SegConfig(n_channel_in=1, n_channel_out=3, unet_depth=2)
model = SegModel(config, 'seg_model', basedir='models')
model


# In[28]:


model.keras_model.summary(line_length=110)


# ### Train

# In[29]:


from csbdeep.data import shuffle_inplace

# shuffle data
shuffle_inplace(X, Y, Y_label, seed=0)

# split into 80% training and 20% validation images
# (always keep at least one validation image: with n_val == 0, the slices
#  a[:-0] and a[-0:] would give an empty training set and a full validation set)
n_val = max(1, len(X) // 5)
n_train = len(X) - n_val
def split_train_val(a):
    """Split `a` along its first axis into (training, validation) parts."""
    return a[:n_train], a[n_train:]
X_train, X_val = split_train_val(X)
Y_train, Y_val = split_train_val(Y)
Y_label_train, Y_label_val = split_train_val(Y_label)


# In[30]:


if COLAB:
    get_ipython().run_line_magic('reload_ext', 'tensorboard')
    get_ipython().run_line_magic('tensorboard', '--logdir=models')


# In[31]:


# for demonstration purposes: training only for a very short time here
history = model.train(X_train,Y_train, validation_data=(X_val,Y_val),
                      lr=3e-4, batch_size=4, epochs=10, steps_per_epoch=10)


# Model folder after training:

# In[32]:


get_ipython().run_line_magic('ls', 'models/seg_model')


# In[33]:


# only works if "tree" is installed
get_ipython().system('tree models/seg_model')


# The model weights with the best validation loss are loaded automatically after training, and also when the model is reloaded from disk:

# In[34]:


model = SegModel(None, 'seg_model', basedir='models')


# ### Predict

# In[35]:


# can predict via keras model, but only works for properly-shaped and normalized images
Yhat_val = model.keras_model.predict(X_val, batch_size=8)
Yhat_val.shape


# In[36]:


i = 1
img, lbl, mask = X_val[i,:223,:223,0], Y_label_val[i,:223,:223], Y_val[i,:223,:223]
img.shape, lbl.shape, mask.shape


# In[37]:


# U-Net models expect the input size to be divisible by certain numbers, hence this fails for a 223x223 image.
try:
    model.keras_model.predict(img[np.newaxis])
except Exception as e:
    # deliberately broad: this cell only demonstrates the error message
    print(e)


# In[38]:


# SegModel.predict pads/crops internally, so arbitrary image sizes work
mask_pred = model.predict(img, axes='YX')
mask_pred.shape


# In[39]:


from skimage.measure import label

# threshold inner class (channel 1, shown in green) and find connected components
lbl_pred = label(mask_pred[...,1] > 0.7)

fig, ((a0,a1,a2),(b0,b1,b2)) = plt.subplots(2, 3, figsize=(15,10))
a0.imshow(img, cmap='gray');       a0.set_title('input image')
a1.imshow(lbl, cmap='tab20');      a1.set_title('label image')
a2.imshow(mask);                   a2.set_title('segmentation mask')
b0.axis('off')
b1.imshow(lbl_pred, cmap='tab20'); b1.set_title('label image (prediction)')
b2.imshow(mask_pred);              b2.set_title('segmentation mask (prediction)')
fig.suptitle("Example")
None;


# ## Tile iterator to process large images

# In[40]:


from csbdeep.internals.predict import tile_iterator
help(tile_iterator)


# In[41]:


img = imread('data/dsb2018/test/images/5f9d29d6388c700f35a3c29fa1b1ce0c1cba6667d05fdb70bd1e89004dcf71ed.tif')
img = normalize(img, 1, 99.8)

plt.figure(figsize=(8,8))
plt.imshow(img, clim=(0,1), cmap='gray')
plt.title(f"example image with shape = {img.shape}");


# In[42]:


import matplotlib.patches as patches

def process(x):
    """Segment a single 2D image (axes 'YX') with the trained model."""
    return model.predict(x, axes='YX')

# reference result from processing the whole image at once
img_processed = process(img)
img_processed_tiled = np.empty_like(img_processed)

###

block_sizes = (8,8)
n_block_overlaps = (3,5)
n_tiles = (3,5)

print(f"block_sizes = {block_sizes}")
print(f"n_block_overlaps = {n_block_overlaps}")
print(f"n_tiles = {n_tiles}")

fig, ax = plt.subplots(*n_tiles, figsize=(15,8))
ax = ax.ravel()
for a in ax:
    a.axis('off')

tiles = tile_iterator(img, n_tiles, block_sizes, n_block_overlaps, guarantee='size')
for i, (tile, s_src, s_dst) in enumerate(tiles):
    # tile is padded; will always start and end at a multiple of block size
    # tile[s_src] removes the padding; the remaining (kept) region is highlighted in magenta
    # the slice s_dst denotes the region of the full image that tile[s_src] corresponds to

    # process tile, crop the padded region from the result and put it at its original location
    img_processed_tiled[s_dst] = process(tile)[s_src]

    ax[i].imshow(tile, clim=(0,1), cmap='gray')
    # s_src is in (Y, X) order; matplotlib's Rectangle expects (x, y), width, height
    rect = patches.Rectangle( [s.start for s in reversed(s_src)],
                             *[s.stop-s.start for s in reversed(s_src)],
                              edgecolor='none', facecolor='m', alpha=0.6)
    ax[i].add_patch(rect)

plt.tight_layout()
# tiled processing must reproduce the whole-image result
assert np.allclose(img_processed, img_processed_tiled)
None;