#!/usr/bin/env python
# coding: utf-8
#
#
# **If on Google Colab, first do this:**
#
# Click on `Runtime` > `Change runtime type` and select `GPU` as `Hardware accelerator`.
# # CSBDeep
#
# Technical infrastructure that powers CSBDeep (and StarDist) under the hood.
# #### Basic setup for this notebook
# In[1]:
# Detect whether the notebook runs on Google Colab by probing for its package.
try:
    import google.colab
except ModuleNotFoundError:
    COLAB = False
else:
    COLAB = True

# On Colab, install csbdeep and stardist with pip for the running interpreter.
if COLAB:
    import sys
    get_ipython().system('{sys.executable} -m pip install csbdeep stardist')
# In[2]:
import numpy as np
import matplotlib
# by default, display image pixels without interpolation
matplotlib.rcParams["image.interpolation"] = 'none'
import matplotlib.pyplot as plt
# draw figures inline in the notebook, using the 'retina' figure format
get_ipython().run_line_magic('matplotlib', 'inline')
get_ipython().run_line_magic('config', "InlineBackend.figure_format = 'retina'")
# ## Keras with TensorFlow 1 & 2
# In[3]:
from csbdeep.utils.tf import keras_import, BACKEND as K

# Without arguments, keras_import() yields the keras module itself, i.e. it is
# equivalent to either:
# - "import keras" # if using TensorFlow and separate Keras package
# - "from tensorflow import keras" # if using TensorFlow with integrated Keras
keras = keras_import()

# Specific names can be imported from a sub-module as well, e.g.:
Input, Dense = keras_import('layers', 'Input', 'Dense')

# both ways of importing must refer to the same layer classes
assert Input == keras.layers.Input
assert Dense == keras.layers.Dense
# ## Blocks and Nets
# In[4]:
from csbdeep.internals.nets import common_unet, custom_unet
from csbdeep.internals.blocks import unet_block, resnet_block

# common_unet(...) is called with the network options first; the returned object
# is then called with the input shape to obtain the model
model = common_unet(
    residual=False,
    n_channel_out=2,
)((128, 128, 3))
model.summary()
# In[5]:
# Small classifier assembled from csbdeep blocks and plain keras layers.
inputs = Input((128, 128, 3))
x = inputs
x = resnet_block(64)(x)
x = resnet_block(128, pool=(2, 2))(x)
x = keras.layers.GlobalAveragePooling2D()(x)
x = Dense(32, activation='relu')(x)
outputs = Dense(1, activation='sigmoid')(x)
x = outputs

model = keras.Model(inputs, outputs)
model.summary()
# ## BaseConfig & BaseModel
# In[6]:
from csbdeep.models import BaseModel, BaseConfig
# In[7]:
# instantiate with no arguments; as a bare expression the result is shown as the cell output
BaseConfig()
# In[8]:
class MyConfig(BaseConfig):
    """Example configuration that adds one extra parameter to ``BaseConfig``."""

    def __init__(self, my_parameter, **kwargs):
        """Store ``my_parameter``; all other keyword arguments go to ``BaseConfig``."""
        super().__init__(**kwargs)
        self.my_parameter = my_parameter


config = MyConfig(my_parameter=42)
# bare expression: shown as the cell output in a notebook
config
# In[9]:
class MyModel(BaseModel):
    """Minimal ``BaseModel`` subclass that uses ``MyConfig`` and builds no network."""

    @property
    def _config_class(self):
        """Configuration class associated with this model."""
        return MyConfig

    def _build(self):
        """Intentionally empty: this minimal example does not create a network."""
        pass
# In[10]:
# demo: delete model folder if it already exists
get_ipython().run_line_magic('rm', '-rf models/my_model')
# In[11]:
# create model folder and persist config
model = MyModel(config, 'my_model', basedir='models')
model
# In[12]:
# list the contents of the model folder
get_ipython().run_line_magic('ls', 'models/my_model')
# In[13]:
# print the configuration file stored in the model folder
get_ipython().run_line_magic('cat', 'models/my_model/config.json')
# In[14]:
# load model from folder (config and possibly trained weights)
model = MyModel(None, 'my_model', basedir='models')
model
# BaseModel has more to offer; some of it is shown below...
# In[15]:
# names of all model attributes, excluding those starting with a double underscore
[a for a in dir(model) if not a.startswith('__')]
# ## Registry for pretrained models
# In[16]:
# StarDist is an optional dependency: skip this part of the demo if it is missing.
try:
    from stardist.models import StarDist2D
    StarDist2D.from_pretrained()
except ModuleNotFoundError:
    pass
# In[17]:
try:
    StarDist2D.from_pretrained('Versatile (fluorescent nuclei)')
except (ModuleNotFoundError, NameError):
    # If the import above failed, the name StarDist2D was never bound, so this
    # call raises NameError (not ModuleNotFoundError); skip in that case as well.
    pass
# In[18]:
# NOTE(review): called without arguments, from_pretrained() presumably just reports the
# models registered for this class (nothing has been registered yet here) — confirm
MyModel.from_pretrained()
# In[19]:
from csbdeep.models import register_model, register_aliases
# register the key 'my_model' with a placeholder URL; the last argument is left empty
register_model(MyModel, 'my_model', 'http://example.com/my_model.zip', '')
# presumably: alternative names that refer to the 'my_model' entry — confirm
register_aliases(MyModel, 'my_model', 'My minimal model', 'Another name for my model')
# In[20]:
# same call as in In[18], now after the registration above
MyModel.from_pretrained()
# ## Example: U-Net model for multi-class semantic segmentation
# Note that the focus is on demonstrating certain concepts rather than being a good/complete segmentation approach.
# ### Helper
# In[21]:
try:
import skimage
except ModuleNotFoundError:
raise RuntimeError("This demo needs scikit-image to run.")
from glob import glob
from tqdm import tqdm
from tifffile import imread
from pathlib import Path
from skimage.segmentation import find_boundaries
def crop(u, shape=(256, 256)):
    """Crop the central region of the given shape from array ``u``.

    Parameters
    ----------
    u : array
        Array to crop. Only its leading ``len(shape)`` axes are cropped;
        any remaining trailing axes are returned in full.
    shape : tuple of int
        Requested size of the crop along each leading axis.

    Returns
    -------
    array
        Central region of ``u``. Along an axis that is shorter than the
        requested size, the whole axis is returned.
    """
    window = []
    for size, target in zip(u.shape, shape):
        # Clamp to the axis length: otherwise the start index becomes negative
        # and the slice silently selects a wrong (much smaller) region.
        target = min(size, target)
        start = (size - target) // 2
        window.append(slice(start, start + target))
    return u[tuple(window)]
def to_3class_label(lbl, onehot=True):
    """Convert an instance labeling to a background/inner/outer mask.

    Pixels with ``lbl > 0`` are set to class 1, all others to class 0; pixels
    flagged by ``find_boundaries(lbl, mode='outer')`` are then set to class 2.

    Parameters
    ----------
    lbl : array
        Instance label image.
    onehot : bool
        If true, return the mask one-hot encoded with a trailing axis of size 3;
        otherwise return the ``uint8`` class-index mask.
    """
    mask = (lbl > 0).astype(np.uint8)
    mask[find_boundaries(lbl, mode='outer')] = 2
    if not onehot:
        return mask
    encoded = keras.utils.to_categorical(mask, num_classes=3)
    return encoded.reshape(lbl.shape + (3,))
def dice_bce_loss(n_labels):
    """Return a loss function combining crossentropy and dice loss.

    The returned function adds the dice loss, averaged over the ``n_labels``
    channels of the last axis, to ``K.categorical_crossentropy(y_true, y_pred)``.
    """

    def _sum_axes_1_2(t):
        # sum over axes 1 and 2, keeping them as size-1 dimensions
        return K.sum(t, axis=(1, 2), keepdims=True)

    def _dice(y_true, y_pred):
        overlap = _sum_axes_1_2(y_true * y_pred)
        total = _sum_axes_1_2(y_true) + _sum_axes_1_2(y_pred)
        # epsilon in numerator and denominator avoids a division by zero
        return (2 * overlap + K.epsilon()) / (total + K.epsilon())

    def _loss(y_true, y_pred):
        per_class = (1 - _dice(y_true[..., c], y_pred[..., c]) for c in range(n_labels))
        return sum(per_class) / n_labels + K.categorical_crossentropy(y_true, y_pred)

    return _loss
def datagen(X, Y, batch_size, seed=0):
    """Endlessly yield augmented ``(X, Y)`` batches (simple data augmentation).

    Both arrays are passed through the same ``ImageDataGenerator`` (random flips,
    rotation, shear) using the same ``seed`` for the two flows.
    """
    # ImageDataGenerator is looked up in two places; the second is the fallback
    # when the first attribute path does not exist.
    try:
        generator_cls = keras.preprocessing.image.ImageDataGenerator
    except AttributeError:
        generator_cls = keras.src.legacy.preprocessing.image.ImageDataGenerator

    augment_options = dict(
        horizontal_flip=True,
        vertical_flip=True,
        rotation_range=10,
        shear_range=10,
        fill_mode='reflect',
    )
    augmenter = generator_cls(**augment_options)

    assert seed is not None
    x_batches, y_batches = (
        augmenter.flow(data, batch_size=batch_size, seed=seed) for data in (X, Y)
    )
    while True:
        yield next(x_batches), next(y_batches)
# ### Data
# In[22]:
from csbdeep.utils import download_and_extract_zip_file, normalize
# fetch the zipped example dataset and unpack it into the local 'data' folder
download_and_extract_zip_file(
    url = 'https://github.com/mpicbg-csbd/stardist/releases/download/0.1.0/dsb2018.zip',
    targetdir = 'data',
    verbose = 1,
)
# In[23]:
# Read training images and label masks (both in sorted file order) and keep
# only the central patch of each (for simplicity).
X = [crop(imread(path)) for path in sorted(glob('data/dsb2018/train/images/*.tif'))]
Y_label = [crop(imread(path)) for path in sorted(glob('data/dsb2018/train/masks/*.tif'))]

# Normalize each input image and turn each label image into a 3-class segmentation mask.
X = [normalize(image, 1, 99.8) for image in tqdm(X)]
Y = [to_3class_label(labels) for labels in tqdm(Y_label)]

# Stack the lists into numpy arrays; X additionally gets a trailing axis of size 1.
X = np.expand_dims(np.stack(X), -1)
Y = np.stack(Y)
Y_label = np.stack(Y_label)
# In[24]:
# Show one example: input image, instance labels, and 3-class mask side by side.
i = 15
fig, (a0, a1, a2) = plt.subplots(1, 3, figsize=(15, 5))

a0.imshow(X[i, ..., 0], cmap='gray')
a0.set_title('input image')

a1.imshow(Y_label[i], cmap='tab20')
a1.set_title('label image')

a2.imshow(Y[i])
a2.set_title('segmentation mask')

fig.suptitle("Example")
# end the cell with None to suppress the notebook's text output
None;
# ### Model
# In[25]:
from csbdeep.data import PadAndCropResizer
from csbdeep.utils import axes_check_and_normalize
from csbdeep.utils.tf import IS_TF_1, CARETensorBoardImage
# stop right away when running with TensorFlow 1: the code below only supports TensorFlow 2.x
if IS_TF_1:
    raise NotImplementedError("For sake of simplicity, this example only works with TensorFlow 2.x")
class SegConfig(BaseConfig):
    """Configuration for ``SegModel``: ``BaseConfig`` plus the U-Net depth."""

    def __init__(self, unet_depth, **kwargs):
        """Store ``unet_depth``; all other keyword arguments go to ``BaseConfig``."""
        super().__init__(**kwargs)
        # used as `n_depth` of the U-Net and to derive the required input divisibility
        self.unet_depth = unet_depth
class SegModel(BaseModel):
    """U-Net model for multi-class semantic segmentation, configured by ``SegConfig``."""

    @property
    def _config_class(self):
        """Configuration class associated with this model."""
        return SegConfig

    def _build(self):
        """Create the Keras U-Net with a softmax output and unspecified spatial size."""
        cfg = self.config
        make_unet = common_unet(
            n_depth=cfg.unet_depth,
            n_first=32,
            residual=False,
            n_channel_out=cfg.n_channel_out,
            last_activation='softmax',
        )
        return make_unet((None, None, cfg.n_channel_in))

    def _prepare_for_training(self, validation_data, lr):
        """Compile the Keras model and set up ``self.callbacks``.

        Uses the Adam optimizer with learning rate ``lr`` and the combined
        dice/crossentropy loss. Callbacks are the checkpoint callbacks plus two
        TensorBoard loggers (scalars, and images from ``validation_data``).
        """
        n_classes = self.config.n_channel_out
        assert n_classes > 1

        self.keras_model.compile(
            optimizer=keras.optimizers.Adam(lr),
            loss=dice_bce_loss(n_classes),
            metrics=['categorical_crossentropy', 'accuracy'],
        )

        self.callbacks = self._checkpoint_callbacks()
        log_root = self.logdir / 'logs'
        scalar_logger = keras.callbacks.TensorBoard(
            log_dir=str(log_root), write_graph=False, profile_batch=0,
        )
        self.callbacks.append(scalar_logger)
        image_logger = CARETensorBoardImage(
            model=self.keras_model,
            data=validation_data,
            log_dir=str(log_root / 'images'),
            n_images=3,
            prob_out=False,
        )
        self.callbacks.append(image_logger)

        self._model_prepared = True

    def train(self, X, Y, validation_data, lr, batch_size, epochs, steps_per_epoch):
        """Fit the Keras model on augmented batches of ``(X, Y)`` and return the history.

        The model is prepared for training on first use only; once prepared,
        ``lr`` is not applied again.
        """
        if not self._model_prepared:
            self._prepare_for_training(validation_data, lr)

        history = self.keras_model.fit(
            datagen(X, Y, batch_size),
            validation_data=validation_data,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            callbacks=self.callbacks,
            verbose=1,
        )
        self._training_finished()
        return history

    # NOTE: the default resizer instance is created once, when the method is
    # defined, and is therefore shared by all calls that do not pass their own.
    def predict(self, img, axes=None, normalizer=None, resizer=PadAndCropResizer()):
        """Predict the segmentation for a single image.

        Parameters
        ----------
        img : array
            Input image.
        axes : str or None
            Axes of ``img``; defaults to the axes of the network (``self.config.axes``).
        normalizer, resizer
            Checked via ``self._check_normalizer_resizer`` and then applied to the
            image (normalizer) and around the network call (resizer).

        Returns
        -------
        array
            Network output for ``img`` after ``resizer.after`` has been applied.
        """
        normalizer, resizer = self._check_normalizer_resizer(normalizer, resizer)

        net_axes = self.config.axes
        img_axes = axes_check_and_normalize(net_axes if axes is None else axes, img.ndim)

        # spatial axes must be divisible by 2**unet_depth; other axes are unconstrained
        div_by = tuple(
            (2 ** self.config.unet_depth if ax in 'XYZ' else 1) for ax in net_axes
        )

        # bring the image into the network's axes order, then normalize and resize it
        to_net_axes = self._make_permute_axes(img_axes, net_axes)
        net_input = to_net_axes(img)
        net_input = normalizer(net_input, net_axes)
        net_input = resizer.before(net_input, net_axes, div_by)

        # add a batch axis for keras and take the single result out again
        batch_out = self.keras_model.predict(net_input[np.newaxis], verbose=0)
        return resizer.after(batch_out[0], net_axes)
# In[26]:
# demo: delete model folder if it already exists
get_ipython().run_line_magic('rm', '-rf models/seg_model')
# In[27]:
# configuration: 1 input channel, 3 output channels (one per class), U-Net depth 2
config = SegConfig(n_channel_in=1, n_channel_out=3, unet_depth=2)
model = SegModel(config, 'seg_model', basedir='models')
model
# In[28]:
# print the layers of the underlying keras model (wider lines for readability)
model.keras_model.summary(line_length=110)
# ### Train
# In[29]:
from csbdeep.data import shuffle_inplace
# shuffle data (all three arrays in one call, with a fixed seed)
shuffle_inplace(X, Y, Y_label, seed=0)
# split into 80% training and 20% validation images
# (integer division: n_val is 0 when there are fewer than 5 images)
n_val = len(X) // 5
def split_train_val(a):
    """Split ``a`` into a training part and a validation part.

    The last ``n_val`` (module-level variable) entries of ``a`` form the
    validation part, everything before them the training part.

    Returns
    -------
    tuple
        ``(train, val)`` slices of ``a``.
    """
    # Use an explicit split index: with n_val == 0, `a[:-n_val]` would be empty
    # and `a[-n_val:]` would be all of `a`, i.e. the two parts would be swapped.
    n_train = max(len(a) - n_val, 0)
    return a[:n_train], a[n_train:]
# apply the same split to images, 3-class masks and instance labels
(X_train, X_val), (Y_train, Y_val), (Y_label_train, Y_label_val) = (
    split_train_val(data) for data in (X, Y, Y_label)
)
# In[30]:
# on Colab only: load the tensorboard extension and point it at the 'models' folder
if COLAB:
    get_ipython().run_line_magic('reload_ext', 'tensorboard')
    get_ipython().run_line_magic('tensorboard', '--logdir=models')
# In[31]:
# for demonstration purposes: training only for a very short time here
# (10 epochs of 10 steps each, with batches of 4 images)
history = model.train(X_train,Y_train, validation_data=(X_val,Y_val),
                      lr=3e-4, batch_size=4, epochs=10, steps_per_epoch=10)
# Model folder after training:
# In[32]:
# list the contents of the model folder
get_ipython().run_line_magic('ls', 'models/seg_model')
# In[33]:
# show the folder as a tree via the external "tree" command
# only works if "tree" is installed
get_ipython().system('tree models/seg_model')
# Model weights at best validation loss are automatically loaded after training, and also when reloading the model from disk:
# In[34]:
model = SegModel(None, 'seg_model', basedir='models')
# ### Predict
# In[35]:
# predicting directly with the keras model is possible, but only works for
# properly-shaped and normalized images
Yhat_val = model.keras_model.predict(X_val, batch_size=8)
Yhat_val.shape
# In[36]:
# cut a 223x223 region out of one validation example
i = 1
img = X_val[i, :223, :223, 0]
lbl = Y_label_val[i, :223, :223]
mask = Y_val[i, :223, :223]
img.shape, lbl.shape, mask.shape
# In[37]:
# U-Net models expect input sizes to be divisible by certain values, hence this fails.
try:
    model.keras_model.predict(img[np.newaxis])
except Exception as err:
    print(err)
# In[38]:
# the model's own predict() runs the image through its normalizer/resizer first
mask_pred = model.predict(img, axes='YX')
mask_pred.shape
# In[39]:
from skimage.measure import label

# threshold the inner class (channel 1, shown in green) and find connected components
lbl_pred = label(mask_pred[..., 1] > 0.7)

fig, ((a0, a1, a2), (b0, b1, b2)) = plt.subplots(2, 3, figsize=(15, 10))

# top row: input and ground truth
a0.imshow(img, cmap='gray')
a0.set_title('input image')
a1.imshow(lbl, cmap='tab20')
a1.set_title('label image')
a2.imshow(mask)
a2.set_title('segmentation mask')

# bottom row: predictions (first panel left empty)
b0.axis('off')
b1.imshow(lbl_pred, cmap='tab20')
b1.set_title('label image (prediction)')
b2.imshow(mask_pred)
b2.set_title('segmentation mask (prediction)')

fig.suptitle("Example")
# end the cell with None to suppress the notebook's text output
None;
# ## Tile iterator to process large images
# In[40]:
from csbdeep.internals.predict import tile_iterator
# print the documentation of tile_iterator
help(tile_iterator)
# In[41]:
# load one test image and normalize it with the same percentiles (1, 99.8) as the training images
img = imread('data/dsb2018/test/images/5f9d29d6388c700f35a3c29fa1b1ce0c1cba6667d05fdb70bd1e89004dcf71ed.tif')
img = normalize(img, 1,99.8)
plt.figure(figsize=(8,8))
plt.imshow(img, clim=(0,1), cmap='gray')
# the trailing semicolon suppresses the notebook's text output for this line
plt.title(f"example image with shape = {img.shape}");
# In[42]:
import matplotlib.patches as patches
def process(x):
    """Run the module-level ``model`` on ``x``, interpreting its axes as 'YX'."""
    return model.predict(x, axes='YX')
# Reference result: process the whole image at once.
img_processed = process(img)
# Buffer that is filled tile by tile below; must end up equal to the reference.
img_processed_tiled = np.empty_like(img_processed)
###
block_sizes = (8,8)
n_block_overlaps = (3,5)
n_tiles = (3,5)
print(f"block_sizes = {block_sizes}")
print(f"n_block_overlaps = {n_block_overlaps}")
print(f"n_tiles = {n_tiles}")
fig, ax = plt.subplots(*n_tiles, figsize=(15,8))
ax = ax.ravel()
for a in ax:
    a.axis('off')
tiles = tile_iterator(img, n_tiles, block_sizes, n_block_overlaps, guarantee='size')
for i, (tile, s_src, s_dst) in enumerate(tiles):
    # tile is padded; will always start and end at a multiple of block size
    # tile[s_src] removes the padding (shown in magenta)
    # the slice s_dst denotes the region where tile[s_src] comes from
    # process tile, crop the padded region from the result and put it at its original location
    img_processed_tiled[s_dst] = process(tile)[s_src]
    ax[i].imshow(tile, clim=(0,1), cmap='gray')
    # Rectangle expects (x, y), width, height, hence the (row, column) slices are reversed
    rect = patches.Rectangle( [s.start for s in reversed(s_src)],
                             *[s.stop-s.start for s in reversed(s_src)],
                              edgecolor='none',facecolor='m',alpha=0.6)
    ax[i].add_patch(rect)
plt.tight_layout()
# tile-wise processing must reproduce the result of processing the whole image
assert np.allclose(img_processed, img_processed_tiled)
None;