Technical infrastructure that powers CSBDeep (and StarDist) under the hood.
# Detect whether this notebook is running on Google Colab.
try:
    import google.colab
    COLAB = True
except ModuleNotFoundError:
    COLAB = False

# On Colab, install the required packages into the active kernel
# (uses the kernel's own interpreter, not whatever "pip" is on PATH).
if COLAB:
    import sys
    !{sys.executable} -m pip install csbdeep stardist
import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = 'none'
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# keras_import returns the Keras module matching the installed TensorFlow;
# K is the backend (used below in the loss function).
from csbdeep.utils.tf import keras_import, BACKEND as K
keras = keras_import()
# equivalent to either:
# - "import keras"                  # if using TensorFlow and separate Keras package
# - "from tensorflow import keras"  # if using TensorFlow with integrated Keras
# can also do specific imports, e.g.:
Input, Dense = keras_import('layers', 'Input','Dense')
assert Input == keras.layers.Input and Dense == keras.layers.Dense
from csbdeep.internals.nets import common_unet, custom_unet
from csbdeep.internals.blocks import unet_block, resnet_block
# non-residual U-Net with 2 output channels for 128x128 RGB input
model = common_unet(residual=False,n_channel_out=2)((128,128,3))
model.summary()
Model: "model" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input (InputLayer) [(None, 128, 128, 3 0 [] )] down_level_0_no_0 (Conv2D) (None, 128, 128, 16 448 ['input[0][0]'] ) down_level_0_no_1 (Conv2D) (None, 128, 128, 16 2320 ['down_level_0_no_0[0][0]'] ) max_0 (MaxPooling2D) (None, 64, 64, 16) 0 ['down_level_0_no_1[0][0]'] middle_0 (Conv2D) (None, 64, 64, 32) 4640 ['max_0[0][0]'] middle_2 (Conv2D) (None, 64, 64, 16) 4624 ['middle_0[0][0]'] up_sampling2d (UpSampling2D) (None, 128, 128, 16 0 ['middle_2[0][0]'] ) concatenate (Concatenate) (None, 128, 128, 32 0 ['up_sampling2d[0][0]', ) 'down_level_0_no_1[0][0]'] up_level_0_no_0 (Conv2D) (None, 128, 128, 16 4624 ['concatenate[0][0]'] ) up_level_0_no_2 (Conv2D) (None, 128, 128, 16 2320 ['up_level_0_no_0[0][0]'] ) conv2d (Conv2D) (None, 128, 128, 2) 34 ['up_level_0_no_2[0][0]'] activation (Activation) (None, 128, 128, 2) 0 ['conv2d[0][0]'] ================================================================================================== Total params: 19,010 Trainable params: 19,010 Non-trainable params: 0 __________________________________________________________________________________________________
# small classification network assembled from csbdeep resnet blocks
x = inputs = Input((128,128,3))
x = resnet_block(64)(x)
x = resnet_block(128, pool=(2,2))(x)  # pool=(2,2) halves the spatial size
x = keras.layers.GlobalAveragePooling2D()(x)
x = Dense(32, activation='relu')(x)
x = outputs = Dense(1, activation='sigmoid')(x)
model = keras.Model(inputs, outputs)
model.summary()
Model: "model_1" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, 128, 128, 3 0 [] )] conv2d_1 (Conv2D) (None, 128, 128, 64 1792 ['input_1[0][0]'] ) activation_1 (Activation) (None, 128, 128, 64 0 ['conv2d_1[0][0]'] ) conv2d_3 (Conv2D) (None, 128, 128, 64 256 ['input_1[0][0]'] ) conv2d_2 (Conv2D) (None, 128, 128, 64 36928 ['activation_1[0][0]'] ) add (Add) (None, 128, 128, 64 0 ['conv2d_3[0][0]', ) 'conv2d_2[0][0]'] activation_2 (Activation) (None, 128, 128, 64 0 ['add[0][0]'] ) conv2d_4 (Conv2D) (None, 64, 64, 128) 73856 ['activation_2[0][0]'] activation_3 (Activation) (None, 64, 64, 128) 0 ['conv2d_4[0][0]'] conv2d_6 (Conv2D) (None, 64, 64, 128) 8320 ['activation_2[0][0]'] conv2d_5 (Conv2D) (None, 64, 64, 128) 147584 ['activation_3[0][0]'] add_1 (Add) (None, 64, 64, 128) 0 ['conv2d_6[0][0]', 'conv2d_5[0][0]'] activation_4 (Activation) (None, 64, 64, 128) 0 ['add_1[0][0]'] global_average_pooling2d (Glob (None, 128) 0 ['activation_4[0][0]'] alAveragePooling2D) dense (Dense) (None, 32) 4128 ['global_average_pooling2d[0][0]' ] dense_1 (Dense) (None, 1) 33 ['dense[0][0]'] ================================================================================================== Total params: 272,897 Trainable params: 272,897 Non-trainable params: 0 __________________________________________________________________________________________________
from csbdeep.models import BaseModel, BaseConfig
BaseConfig()
BaseConfig(axes='YXC', n_channel_in=1, n_channel_out=1, n_dim=2, train_checkpoint='weights_best.h5', train_checkpoint_epoch='weights_now.h5', train_checkpoint_last='weights_last.h5')
class MyConfig(BaseConfig):
    """BaseConfig subclass carrying one additional custom parameter.

    Extra attributes set in ``__init__`` are persisted to (and restored
    from) the model folder's ``config.json`` along with the base values.
    """

    def __init__(self, my_parameter, **base_kwargs):
        super().__init__(**base_kwargs)
        self.my_parameter = my_parameter
config = MyConfig(my_parameter=42)
config
MyConfig(axes='YXC', my_parameter=42, n_channel_in=1, n_channel_out=1, n_dim=2, train_checkpoint='weights_best.h5', train_checkpoint_epoch='weights_now.h5', train_checkpoint_last='weights_last.h5')
class MyModel(BaseModel):
    """Minimal concrete BaseModel: custom config class, no actual network."""

    @property
    def _config_class(self):
        # tells BaseModel which configuration class to (de)serialize
        return MyConfig

    def _build(self):
        # demo model: nothing to build
        pass
# demo: delete model folder if it already exists
%rm -rf models/my_model
# create model folder and persist config
model = MyModel(config, 'my_model', basedir='models')
model
MyModel(my_model): YXC → YXC ├─ Directory: /home/uwe/research/csbdeep/examples/examples/other/models/my_model └─ MyConfig(axes='YXC', my_parameter=42, n_channel_in=1, n_channel_out=1, n_dim=2, train_checkpoint='weights_best.h5', train_checkpoint_epoch='weights_now.h5', train_checkpoint_last='weights_last.h5')
%ls models/my_model
config.json
%cat models/my_model/config.json
{"n_dim": 2, "axes": "YXC", "n_channel_in": 1, "n_channel_out": 1, "train_checkpoint": "weights_best.h5", "train_checkpoint_last": "weights_last.h5", "train_checkpoint_epoch": "weights_now.h5", "my_parameter": 42}
# load model from folder (config and possibly trained weights)
model = MyModel(None, 'my_model', basedir='models')
model
Couldn't find any network weights (*.h5, *.hdf5) to load.
MyModel(my_model): YXC → YXC ├─ Directory: /home/uwe/research/csbdeep/examples/examples/other/models/my_model └─ MyConfig(axes='YXC', my_parameter=42, n_channel_in=1, n_channel_out=1, n_dim=2, train_checkpoint='weights_best.h5', train_checkpoint_epoch='weights_now.h5', train_checkpoint_last='weights_last.h5')
``BaseModel`` has more to offer; some of it is shown below...
[a for a in dir(model) if not a.startswith('__')]
['_abc_impl', '_axes_out', '_build', '_check_normalizer_resizer', '_checkpoint_callbacks', '_config_class', '_find_and_load_weights', '_make_permute_axes', '_model_prepared', '_repr_extra', '_set_logdir', '_training_finished', '_update_and_check_config', 'basedir', 'config', 'export_TF', 'from_pretrained', 'keras_model', 'load_weights', 'logdir', 'name']
# show all registered pretrained models for StarDist2D (if stardist is installed)
try:
    from stardist.models import StarDist2D
    StarDist2D.from_pretrained()
except ModuleNotFoundError:
    pass
There are 4 registered models for 'StarDist2D': Name Alias(es) ──── ───────── '2D_versatile_fluo' 'Versatile (fluorescent nuclei)' '2D_versatile_he' 'Versatile (H&E nuclei)' '2D_paper_dsb2018' 'DSB 2018 (from StarDist 2D paper)' '2D_demo' None
# Load a registered pretrained model by one of its aliases.
# Also catch NameError: if the stardist import in the previous cell failed,
# StarDist2D was never defined and `except ModuleNotFoundError` alone
# would let the NameError escape.
try:
    StarDist2D.from_pretrained('Versatile (fluorescent nuclei)')
except (NameError, ModuleNotFoundError):
    pass
Found model '2D_versatile_fluo' with alias 'Versatile (fluorescent nuclei)' for 'StarDist2D'. Loading network weights from 'weights_best.h5'. Loading thresholds from 'thresholds.json'. Using default values: prob_thresh=0.479071, nms_thresh=0.3.
StarDist2D(2D_versatile_fluo): YXC → YXC ├─ Directory: None └─ Config2D(axes='YXC', backbone='unet', grid=(2, 2), n_channel_in=1, n_channel_out=33, n_classes=None, n_dim=2, n_rays=32, net_conv_after_unet=128, net_input_shape=[None, None, 1], net_mask_shape=[None, None, 1], train_background_reg=0.0001, train_batch_size=8, train_checkpoint='weights_best.h5', train_checkpoint_epoch='weights_now.h5', train_checkpoint_last='weights_last.h5', train_class_weights=(1, 1), train_completion_crop=32, train_dist_loss='mae', train_epochs=800, train_foreground_only=0.9, train_learning_rate=0.0003, train_loss_weights=[1, 0.2], train_n_val_patches=None, train_patch_size=[256, 256], train_reduce_lr={'factor': 0.5, 'patience': 80, 'min_delta': 0}, train_sample_cache=True, train_shape_completion=False, train_steps_per_epoch=400, train_tensorboard=True, unet_activation='relu', unet_batch_norm=False, unet_dropout=0.0, unet_kernel_size=[3, 3], unet_last_activation='relu', unet_n_conv_per_depth=2, unet_n_depth=3, unet_n_filter_base=32, unet_pool=[2, 2], unet_prefix='', use_gpu=False)
MyModel.from_pretrained()
There are 0 registered models for 'MyModel'
from csbdeep.models import register_model, register_aliases
# register a model (name + download URL + hash) and two aliases for it
register_model(MyModel, 'my_model', 'http://example.com/my_model.zip', '<hash>')
register_aliases(MyModel, 'my_model', 'My minimal model', 'Another name for my model')
MyModel.from_pretrained()
There is 1 registered model for 'MyModel': Name Alias(es) ──── ───────── 'my_model' 'My minimal model', 'Another name for my model'
Note that the focus is on demonstrating certain concepts rather than on presenting a good or complete segmentation approach.
# Fail early with a clear message if scikit-image is missing.
# Chain the original ModuleNotFoundError so the cause stays visible.
try:
    import skimage
except ModuleNotFoundError as e:
    raise RuntimeError("This demo needs scikit-image to run.") from e
from glob import glob
from tqdm import tqdm
from tifffile import imread
from pathlib import Path
from skimage.segmentation import find_boundaries
def crop(u, shape=(256,256)):
    """Return the centered sub-array of `u` with the given spatial `shape`.

    Only the leading len(shape) dimensions are cropped; any remaining
    trailing dimensions (e.g. channels) are kept as-is.
    """
    slices = []
    for full, want in zip(u.shape, shape):
        offset = (full - want) // 2
        slices.append(slice(offset, offset + want))
    return u[tuple(slices)]
def to_3class_label(lbl, onehot=True):
    """Convert an instance labeling to a 3-class mask.

    Classes: 0 = background, 1 = inner (foreground), 2 = outer boundary.
    If `onehot`, the mask is returned one-hot encoded with shape lbl.shape + (3,).
    """
    mask = (lbl > 0).astype(np.uint8)
    mask[find_boundaries(lbl, mode='outer')] = 2
    if onehot:
        mask = keras.utils.to_categorical(mask, num_classes=3).reshape(lbl.shape + (3,))
    return mask
def dice_bce_loss(n_labels):
    """Return a loss combining mean per-channel (1 - dice) with categorical crossentropy."""
    def _spatial_sum(t):
        # sum over the two spatial axes, keeping batch (and any channel) dims
        return K.sum(t, axis=(1,2), keepdims=True)
    def _dice_coef(y_true, y_pred):
        numerator = 2 * _spatial_sum(y_true * y_pred) + K.epsilon()
        denominator = _spatial_sum(y_true) + _spatial_sum(y_pred) + K.epsilon()
        return numerator / denominator
    def _loss(y_true, y_pred):
        dice_total = sum(1 - _dice_coef(y_true[...,c], y_pred[...,c]) for c in range(n_labels))
        return dice_total / n_labels + K.categorical_crossentropy(y_true, y_pred)
    return _loss
def datagen(X, Y, batch_size, seed=0):
    """Yield augmented (image, mask) batch pairs indefinitely.

    Both flows are seeded identically, so images and their masks receive
    the same random transformations. `seed` must not be None.
    """
    try:
        ImageDataGenerator = keras.preprocessing.image.ImageDataGenerator
    except AttributeError:
        # fallback import path (presumably newer Keras versions) -- verify
        ImageDataGenerator = keras.src.legacy.preprocessing.image.ImageDataGenerator
    augmenter = ImageDataGenerator(horizontal_flip=True, vertical_flip=True,
                                   rotation_range=10, shear_range=10, fill_mode='reflect')
    assert seed is not None
    flow_images = augmenter.flow(X, batch_size=batch_size, seed=seed)
    flow_masks = augmenter.flow(Y, batch_size=batch_size, seed=seed)
    while True:
        yield next(flow_images), next(flow_masks)
from csbdeep.utils import download_and_extract_zip_file, normalize
# download the dsb2018 example data into ./data (skipped if already present)
download_and_extract_zip_file(
    url = 'https://github.com/mpicbg-csbd/stardist/releases/download/0.1.0/dsb2018.zip',
    targetdir = 'data',
    verbose = 1,
)
Files missing, downloading... extracting... done.
# load and crop out central patch (for simplicity)
X = [crop(imread(x)) for x in sorted(glob('data/dsb2018/train/images/*.tif'))]
Y_label = [crop(imread(y)) for y in sorted(glob('data/dsb2018/train/masks/*.tif'))]
# normalize input image (1 / 99.8 percentiles) and convert label image to 3-class segmentation mask
X = [normalize(x,1,99.8) for x in tqdm(X)]
Y = [to_3class_label(y) for y in tqdm(Y_label)]
# convert to numpy arrays; the inputs get a trailing singleton channel axis
X, Y, Y_label = np.expand_dims(np.stack(X),-1), np.stack(Y), np.stack(Y_label)
100%|██████████| 447/447 [00:00<00:00, 1159.26it/s] 100%|██████████| 447/447 [00:01<00:00, 437.09it/s]
i = 15
fig, (a0,a1,a2) = plt.subplots(1,3,figsize=(15,5))
a0.imshow(X[i,...,0],cmap='gray'); a0.set_title('input image')
a1.imshow(Y_label[i],cmap='tab20'); a1.set_title('label image')
a2.imshow(Y[i]); a2.set_title('segmentation mask')
fig.suptitle("Example")
None;
from csbdeep.data import PadAndCropResizer
from csbdeep.utils import axes_check_and_normalize
from csbdeep.utils.tf import IS_TF_1, CARETensorBoardImage
# the code below relies on TensorFlow 2.x behavior
if IS_TF_1:
    raise NotImplementedError("For sake of simplicity, this example only works with TensorFlow 2.x")
class SegConfig(BaseConfig):
    """Configuration for SegModel; adds the U-Net depth to the base config."""

    def __init__(self, unet_depth, **base_kwargs):
        super().__init__(**base_kwargs)
        self.unet_depth = unet_depth
class SegModel(BaseModel):
    """3-class segmentation model (background / inner / outer) based on a U-Net."""

    @property
    def _config_class(self):
        # configuration class used to (de)serialize this model's parameters
        return SegConfig

    def _build(self):
        # fully-convolutional U-Net with softmax output; spatial input
        # dimensions are None, i.e. images of arbitrary (divisible) size
        # can be processed
        return common_unet(n_depth=self.config.unet_depth,
                           n_first=32, residual=False,
                           n_channel_out=self.config.n_channel_out,
                           last_activation='softmax')((None,None,self.config.n_channel_in))

    def _prepare_for_training(self, validation_data, lr):
        # softmax / categorical losses require more than one output channel
        assert self.config.n_channel_out > 1
        self.keras_model.compile(optimizer=keras.optimizers.Adam(lr),
                                 loss=dice_bce_loss(self.config.n_channel_out),
                                 metrics=['categorical_crossentropy','accuracy'])
        # checkpoint callbacks write weights_best.h5 / weights_last.h5 to the model folder
        self.callbacks = self._checkpoint_callbacks()
        self.callbacks.append(keras.callbacks.TensorBoard(log_dir=str(self.logdir/'logs'),
                                                          write_graph=False, profile_batch=0))
        # additionally log example prediction images to TensorBoard
        self.callbacks.append(CARETensorBoardImage(model=self.keras_model, data=validation_data,
                                                   log_dir=str(self.logdir/'logs'/'images'),
                                                   n_images=3, prob_out=False))
        self._model_prepared = True

    def train(self, X,Y, validation_data, lr, batch_size, epochs, steps_per_epoch):
        """Train on augmented batches from `datagen`; best weights are reloaded afterwards."""
        if not self._model_prepared:
            self._prepare_for_training(validation_data, lr)
        training_data = datagen(X,Y,batch_size)
        history = self.keras_model.fit(training_data, validation_data=validation_data,
                                       epochs=epochs, steps_per_epoch=steps_per_epoch,
                                       callbacks=self.callbacks, verbose=1)
        self._training_finished()
        return history

    def predict(self, img, axes=None, normalizer=None, resizer=PadAndCropResizer()):
        """Predict the segmentation mask for a single image.

        NOTE(review): ``resizer=PadAndCropResizer()`` is a mutable default --
        the same instance is shared across all calls; fine for sequential
        use, but confirm it keeps no state between before/after pairs.
        """
        normalizer, resizer = self._check_normalizer_resizer(normalizer, resizer)
        axes_net = self.config.axes
        if axes is None:
            axes = axes_net
        axes = axes_check_and_normalize(axes, img.ndim)
        # each spatial axis must be divisible by 2**depth (due to pooling)
        axes_net_div_by = tuple((2**self.config.unet_depth if a in 'XYZ' else 1) for a in axes_net)
        x = self._make_permute_axes(axes, axes_net)(img)
        x = normalizer(x, axes_net)
        x = resizer.before(x, axes_net, axes_net_div_by)
        # add a batch dimension for prediction, drop it again from the result
        pred = self.keras_model.predict(x[np.newaxis], verbose=0)[0]
        # undo the padding applied by the resizer
        pred = resizer.after(pred, axes_net)
        return pred
# demo: delete model folder if it already exists
%rm -rf models/seg_model
# grayscale input, three output classes, U-Net of depth 2
config = SegConfig(n_channel_in=1, n_channel_out=3, unet_depth=2)
model = SegModel(config, 'seg_model', basedir='models')
model
SegModel(seg_model): YXC → YXC ├─ Directory: /home/uwe/research/csbdeep/examples/examples/other/models/seg_model └─ SegConfig(axes='YXC', n_channel_in=1, n_channel_out=3, n_dim=2, train_checkpoint='weights_best.h5', train_checkpoint_epoch='weights_now.h5', train_checkpoint_last='weights_last.h5', unet_depth=2)
model.keras_model.summary(line_length=110)
Model: "model_3" ______________________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ============================================================================================================== input (InputLayer) [(None, None, None, 1) 0 [] ] down_level_0_no_0 (Conv2D) (None, None, None, 32) 320 ['input[0][0]'] down_level_0_no_1 (Conv2D) (None, None, None, 32) 9248 ['down_level_0_no_0[0][0]'] max_0 (MaxPooling2D) (None, None, None, 32) 0 ['down_level_0_no_1[0][0]'] down_level_1_no_0 (Conv2D) (None, None, None, 64) 18496 ['max_0[0][0]'] down_level_1_no_1 (Conv2D) (None, None, None, 64) 36928 ['down_level_1_no_0[0][0]'] max_1 (MaxPooling2D) (None, None, None, 64) 0 ['down_level_1_no_1[0][0]'] middle_0 (Conv2D) (None, None, None, 128 73856 ['max_1[0][0]'] ) middle_2 (Conv2D) (None, None, None, 64) 73792 ['middle_0[0][0]'] up_sampling2d_4 (UpSampling2D) (None, None, None, 64) 0 ['middle_2[0][0]'] concatenate_4 (Concatenate) (None, None, None, 128 0 ['up_sampling2d_4[0][0]', ) 'down_level_1_no_1[0][0]'] up_level_1_no_0 (Conv2D) (None, None, None, 64) 73792 ['concatenate_4[0][0]'] up_level_1_no_2 (Conv2D) (None, None, None, 32) 18464 ['up_level_1_no_0[0][0]'] up_sampling2d_5 (UpSampling2D) (None, None, None, 32) 0 ['up_level_1_no_2[0][0]'] concatenate_5 (Concatenate) (None, None, None, 64) 0 ['up_sampling2d_5[0][0]', 'down_level_0_no_1[0][0]'] up_level_0_no_0 (Conv2D) (None, None, None, 32) 18464 ['concatenate_5[0][0]'] up_level_0_no_2 (Conv2D) (None, None, None, 32) 9248 ['up_level_0_no_0[0][0]'] conv2d_9 (Conv2D) (None, None, None, 3) 99 ['up_level_0_no_2[0][0]'] activation_5 (Activation) (None, None, None, 3) 0 ['conv2d_9[0][0]'] ============================================================================================================== Total params: 332,707 Trainable params: 332,707 Non-trainable params: 0 
______________________________________________________________________________________________________________
from csbdeep.data import shuffle_inplace
# shuffle data (presumably the same permutation is applied to all three
# arrays so X, Y, Y_label stay aligned -- verify in csbdeep.data)
shuffle_inplace(X, Y, Y_label, seed=0)
# split into 80% training and 20% validation images
n_val = len(X) // 5
def split_train_val(a, n=None):
    """Split `a` into a (train, validation) pair.

    The last `n` items form the validation part. `n` defaults to the
    module-level ``n_val`` (20% of the dataset) and must be >= 1
    (``a[:-0]`` would be empty).
    """
    if n is None:
        n = n_val
    return a[:-n], a[-n:]
X_train, X_val = split_train_val(X)
Y_train, Y_val = split_train_val(Y)
Y_label_train, Y_label_val = split_train_val(Y_label)
if COLAB:
%reload_ext tensorboard
%tensorboard --logdir=models
# for demonstration purposes: training only for a very short time here
history = model.train(X_train,Y_train, validation_data=(X_val,Y_val),
lr=3e-4, batch_size=4, epochs=10, steps_per_epoch=10)
Epoch 1/10 10/10 [==============================] - 6s 404ms/step - loss: 1.8100 - categorical_crossentropy: 1.0583 - accuracy: 0.8139 - val_loss: 1.7270 - val_categorical_crossentropy: 0.9802 - val_accuracy: 0.8809 Epoch 2/10 10/10 [==============================] - 1s 82ms/step - loss: 1.5209 - categorical_crossentropy: 0.8085 - accuracy: 0.9125 - val_loss: 1.2013 - val_categorical_crossentropy: 0.5642 - val_accuracy: 0.9568 Epoch 3/10 10/10 [==============================] - 1s 77ms/step - loss: 0.9535 - categorical_crossentropy: 0.4083 - accuracy: 0.9325 - val_loss: 0.6224 - val_categorical_crossentropy: 0.1979 - val_accuracy: 0.9589 Epoch 4/10 10/10 [==============================] - 1s 76ms/step - loss: 0.6209 - categorical_crossentropy: 0.1969 - accuracy: 0.9434 - val_loss: 0.5553 - val_categorical_crossentropy: 0.1453 - val_accuracy: 0.9481 Epoch 5/10 10/10 [==============================] - 1s 78ms/step - loss: 0.5810 - categorical_crossentropy: 0.1747 - accuracy: 0.9376 - val_loss: 0.5400 - val_categorical_crossentropy: 0.1475 - val_accuracy: 0.9567 Epoch 6/10 10/10 [==============================] - 1s 76ms/step - loss: 0.5070 - categorical_crossentropy: 0.1302 - accuracy: 0.9570 - val_loss: 0.4786 - val_categorical_crossentropy: 0.1184 - val_accuracy: 0.9610 Epoch 7/10 10/10 [==============================] - 1s 75ms/step - loss: 0.4764 - categorical_crossentropy: 0.1280 - accuracy: 0.9601 - val_loss: 0.4453 - val_categorical_crossentropy: 0.1123 - val_accuracy: 0.9625 Epoch 8/10 10/10 [==============================] - 1s 79ms/step - loss: 0.4590 - categorical_crossentropy: 0.1317 - accuracy: 0.9543 - val_loss: 0.4260 - val_categorical_crossentropy: 0.1176 - val_accuracy: 0.9598 Epoch 9/10 10/10 [==============================] - 1s 162ms/step - loss: 0.4479 - categorical_crossentropy: 0.1338 - accuracy: 0.9583 - val_loss: 0.4279 - val_categorical_crossentropy: 0.1176 - val_accuracy: 0.9574 Epoch 10/10 10/10 [==============================] - 1s 
74ms/step - loss: 0.4706 - categorical_crossentropy: 0.1500 - accuracy: 0.9474 - val_loss: 0.4345 - val_categorical_crossentropy: 0.1257 - val_accuracy: 0.9561 Loading network weights from 'weights_best.h5'.
Model folder after training:
%ls models/seg_model
config.json logs/ weights_best.h5 weights_last.h5
# only works if "tree" is installed
!tree models/seg_model
models/seg_model ├── config.json ├── logs │ ├── images │ │ └── events.out.tfevents.1713794031.workstation.150763.0.v2 │ ├── train │ │ └── events.out.tfevents.1713794031.workstation.150763.1.v2 │ └── validation │ └── events.out.tfevents.1713794034.workstation.150763.2.v2 ├── weights_best.h5 └── weights_last.h5 4 directories, 6 files
Model weights at best validation loss are automatically loaded after training, and also when reloading the model from disk:
model = SegModel(None, 'seg_model', basedir='models')
Loading network weights from 'weights_best.h5'.
# can predict via keras model, but only works for properly-shaped and normalized images
Yhat_val = model.keras_model.predict(X_val, batch_size=8)
Yhat_val.shape
(89, 256, 256, 3)
i = 1
img, lbl, mask = X_val[i,:223,:223,0], Y_label_val[i,:223,:223], Y_val[i,:223,:223]
img.shape, lbl.shape, mask.shape
((223, 223), (223, 223), (223, 223, 3))
# A U-Net expects input sizes to be divisible by certain values
# (2**depth per spatial axis), hence prediction fails for this 223x223 image.
try:
    model.keras_model.predict(img[np.newaxis])
except Exception as e:
    print(e)
in user code: File "/home/uwe/sw/miniconda3/envs/ws/lib/python3.8/site-packages/keras/engine/training.py", line 1801, in predict_function * return step_function(self, iterator) File "/home/uwe/sw/miniconda3/envs/ws/lib/python3.8/site-packages/keras/engine/training.py", line 1790, in step_function ** outputs = model.distribute_strategy.run(run_step, args=(data,)) File "/home/uwe/sw/miniconda3/envs/ws/lib/python3.8/site-packages/keras/engine/training.py", line 1783, in run_step ** outputs = model.predict_step(data) File "/home/uwe/sw/miniconda3/envs/ws/lib/python3.8/site-packages/keras/engine/training.py", line 1751, in predict_step return self(x, training=False) File "/home/uwe/sw/miniconda3/envs/ws/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler raise e.with_traceback(filtered_tb) from None File "/home/uwe/sw/miniconda3/envs/ws/lib/python3.8/site-packages/keras/backend.py", line 3313, in concatenate return tf.concat([to_dense(x) for x in tensors], axis) ValueError: Exception encountered when calling layer "concatenate_6" (type Concatenate). Dimension 1 in both shapes must be equal, but are 110 and 111. Shapes are [?,110,110] and [?,111,111]. for '{{node model_4/concatenate_6/concat}} = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32](model_4/up_sampling2d_6/resize/ResizeNearestNeighbor, model_4/down_level_1_no_1/Relu, model_4/concatenate_6/concat/axis)' with input shapes: [?,110,110,64], [?,111,111,64], [] and with computed input tensors: input[2] = <3>. Call arguments received: • inputs=['tf.Tensor(shape=(None, 110, 110, 64), dtype=float32)', 'tf.Tensor(shape=(None, 111, 111, 64), dtype=float32)']
# model.predict pads (and afterwards crops) the input via its resizer,
# so odd-sized images work
mask_pred = model.predict(img, axes='YX')
mask_pred.shape
(223, 223, 3)
from skimage.measure import label
# threshold inner (green) channel and find connected components
lbl_pred = label(mask_pred[...,1] > 0.7)
fig, ((a0,a1,a2),(b0,b1,b2)) = plt.subplots(2,3,figsize=(15,10))
a0.imshow(img,cmap='gray'); a0.set_title('input image')
a1.imshow(lbl,cmap='tab20'); a1.set_title('label image')
a2.imshow(mask); a2.set_title('segmentation mask')
b0.axis('off')
b1.imshow(lbl_pred,cmap='tab20'); b1.set_title('label image (prediction)')
b2.imshow(mask_pred); b2.set_title('segmentation mask (prediction)')
fig.suptitle("Example")
None;
from csbdeep.internals.predict import tile_iterator
help(tile_iterator)
Help on function tile_iterator in module csbdeep.internals.predict: tile_iterator(x, n_tiles, block_sizes, n_block_overlaps, guarantee='size') Tile iterator for n-d arrays. Yields block-aligned tiles (`block_sizes`) that have at least a certain amount of overlapping blocks (`n_block_overlaps`) with their neighbors. Also yields slices that allow to map each tile back to the original array x. Notes ----- - Tiles will not go beyond the array boundary (i.e. no padding). This means the shape of x must be evenly divisible by the respective block_size. - It is not guaranteed that all tiles have the same size if guarantee is not 'size'. Parameters ---------- x : numpy.ndarray Input array. n_tiles : int or sequence of ints Number of tiles for each dimension of x. block_sizes : int or sequence of ints Block sizes for each dimension of x. The shape of x is assumed to be evenly divisible by block_sizes. All tiles are aligned with block_sizes. n_block_overlaps : int or sequence of ints Tiles will at least overlap this many blocks in each dimension. guarantee : str Can be either 'size' or 'n_tiles': 'size': The size of all tiles is guaranteed to be the same, but the number of tiles can be different and the amount of overlap can be larger than requested. 'n_tiles': The size of tiles can be different at the beginning and end, but the number of tiles is guarantee to be the one requested. The mount of overlap is also exactly as requested. Example ------- Duplicate an array tile-by-tile: >>> x = np.array(...) >>> y = np.empty_like(x) >>> >>> for tile,s_src,s_dst in tile_iterator(x, n_tiles, block_sizes, n_block_overlaps): >>> y[s_dst] = tile[s_src] >>> >>> np.allclose(x,y) True
img = imread('data/dsb2018/test/images/5f9d29d6388c700f35a3c29fa1b1ce0c1cba6667d05fdb70bd1e89004dcf71ed.tif')
img = normalize(img, 1,99.8)
plt.figure(figsize=(8,8))
plt.imshow(img, clim=(0,1), cmap='gray')
plt.title(f"example image with shape = {img.shape}");
import matplotlib.patches as patches
def process(x):
    """Run the trained segmentation model on a single 2-D (YX) image."""
    return model.predict(x, axes='YX')

# reference result: process the entire image in one go
img_processed = process(img)
# filled tile-by-tile below, then compared against the reference
img_processed_tiled = np.empty_like(img_processed)
###
block_sizes = (8,8)
n_block_overlaps = (3,5)
n_tiles = (3,5)
print(f"block_sizes = {block_sizes}")
print(f"n_block_overlaps = {n_block_overlaps}")
print(f"n_tiles = {n_tiles}")
# one subplot per tile
fig, ax = plt.subplots(*n_tiles, figsize=(15,8))
ax = ax.ravel()
[a.axis('off') for a in ax]
i = 0
for tile,s_src,s_dst in tile_iterator(img, n_tiles, block_sizes, n_block_overlaps, guarantee='size'):
    # tile is padded; will always start and end at a multiple of block size
    # tile[s_src] removes the padding (shown in magenta)
    # the slice s_dst denotes the region where tile[s_src] comes from
    # process tile, crop the padded region from the result and put it at its original location
    img_processed_tiled[s_dst] = process(tile)[s_src]
    ax[i].imshow(tile, clim=(0,1), cmap='gray')
    rect = patches.Rectangle( [s.start for s in reversed(s_src)],
                              *[s.stop-s.start for s in reversed(s_src)],
                              edgecolor='none',facecolor='m',alpha=0.6)
    ax[i].add_patch(rect)
    i+=1
plt.tight_layout()
# the tiled result must match the reference computed on the whole image
assert np.allclose(img_processed, img_processed_tiled)
None;
block_sizes = (8, 8) n_block_overlaps = (3, 5) n_tiles = (3, 5)