from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from glob import glob
from tifffile import imread
from csbdeep.utils import Path, normalize
from skimage.segmentation import find_boundaries
from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
from stardist import random_label_cmap, draw_polygons, sample_points
from stardist import Config, StarDist
# Fix NumPy's global RNG so the random choices made further down are
# reproducible (NOTE(review): assumes random_label_cmap / sample_points draw
# from np.random — confirm).
np.random.seed(seed=6)
lbl_cmap = random_label_cmap()  # colormap used below to display label instances
Using TensorFlow backend.
We assume that the data has already been downloaded via notebook 1_data.ipynb.
We now load the images from the sub-folder `test`, which have not been used during training.
# Load the held-out test images (not used during training), in sorted order.
X = sorted(glob('data/dsb2018/test/images/*.tif'))
if not X:
    # Without this check an empty folder surfaces as an opaque IndexError on
    # X[0] below; point the user at the download notebook instead.
    raise IOError("No test images found in 'data/dsb2018/test/images/'; "
                  "run notebook 1_data.ipynb to download the data first.")
X = list(map(imread, X))

# 2D images are single-channel; otherwise the last axis is taken as channels.
n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]

axis_norm = (0, 1)      # normalize channels independently
# axis_norm = (0,1,2)   # normalize channels jointly
if n_channel > 1:
    print("Normalizing image channels %s." % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))
# Optional preview of every test image in a grid; disabled by default.
SHOW_ALL_TEST_IMAGES = False
if SHOW_ALL_TEST_IMAGES:
    n_cols = 8
    # Size the grid from the number of images: a fixed 7x8 grid would
    # silently drop every image beyond the 56th.
    n_rows = max(1, int(np.ceil(len(X) / float(n_cols))))
    # squeeze=False guarantees a 2D array of axes regardless of grid size.
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(16, 16), squeeze=False)
    for i, (a, x) in enumerate(zip(ax.flat, X)):
        a.imshow(x, cmap='gray')
        a.set_title(i)
    # Plain loop: hiding axes is a side effect, not list construction.
    for a in ax.flat:
        a.axis('off')
    plt.tight_layout()
None;
We assume that two StarDist models have already been trained via notebook 2_training.ipynb.
# Model trained WITHOUT shape completion, loaded from the 'models' directory.
model_no_sc = StarDist(None,
                       basedir='models',
                       name='stardist_no_shape_completion')
Loading network weights from 'weights_best.h5'.
# Model trained WITH shape completion, loaded from the 'models' directory.
model_sc = StarDist(None,
                    basedir='models',
                    name='stardist_shape_completion')
Loading network weights from 'weights_best.h5'.
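1. Load the image and normalize it
2. Predict object `prob`abilities and star-convex polygon `dist`ances
3. Convert the `dist`ances to polygon vertex `coord`inates
4. Perform non-maximum suppression on the polygons above the object probability threshold
5. Convert the final polygons to a label instance image
%%time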
# Step-by-step prediction on one test image (X[16]) with the shape-completion
# model. Each step is timed individually with %time; the cell-level %%time
# reports the total, printed after the separator line below.
# 1. Normalize the image intensities. NOTE(review): 1 and 99.8 appear to be
#    the low/high percentiles for csbdeep's normalize — confirm.
%time img = normalize(X[16],1,99.8,axis=axis_norm)
# 2. Predict object probabilities and star-convex polygon distances.
%time prob, dist = model_sc.predict(img)
# 3. Convert the predicted distances to polygon vertex coordinates.
%time coord = dist_to_coord(dist)
# 4. Non-maximum suppression over the candidate polygons; 0.4 is the object
#    probability threshold handed to NMS.
%time points = non_maximum_suppression(coord,prob,prob_thresh=0.4)
# 5. Turn the surviving polygons into a label instance image.
%time labels = polygons_to_label(coord,prob,points)
# Separates the per-step timings from the cell's total timing in the output.
print('------------------')
CPU times: user 5.31 ms, sys: 531 µs, total: 5.84 ms Wall time: 3.34 ms CPU times: user 86.1 ms, sys: 24.7 ms, total: 111 ms Wall time: 81.9 ms CPU times: user 75.9 ms, sys: 11 ms, total: 86.9 ms Wall time: 84.7 ms CPU times: user 1.07 s, sys: 126 ms, total: 1.19 s Wall time: 56 ms CPU times: user 156 ms, sys: 406 ms, total: 562 ms Wall time: 23.8 ms ------------------ CPU times: user 1.41 s, sys: 612 ms, total: 2.02 s Wall time: 257 ms
# For multi-channel input, display only the first channel.
img_show = img[..., 0] if img.ndim != 2 else img

# Four panels: input, raw network outputs, and the final instance labels.
fig, ax = plt.subplots(2, 2, figsize=(12, 12))
panels = [
    (img_show,     'gray',    'Input image'),
    (prob,         'magma',   'Predicted object probability'),
    (dist[..., 0], 'viridis', 'Predicted distance (0°)'),
    (labels,       lbl_cmap,  'Predicted label instances'),
]
for panel_ax, (data, cmap, title) in zip(ax.flat, panels):
    panel_ax.imshow(data, cmap=cmap)
    panel_ax.set_title(title)
    panel_ax.axis('off')
plt.tight_layout()
None;
# Side by side: 200 polygons sampled where prob > 0.2 (before NMS) versus the
# polygons kept by non-maximum suppression.
plt.figure(figsize=(13, 12))
points_rnd = sample_points(200, prob > 0.2)
comparisons = [
    (points_rnd, 'Polygons randomly sampled'),
    (points,     'Polygons after non-maximum suppression'),
]
for col, (pts, title) in enumerate(comparisons, start=1):
    plt.subplot(1, 2, col)
    plt.imshow(img_show, cmap='gray')
    draw_polygons(coord, prob, pts, cmap=lbl_cmap)
    plt.axis('off')
    plt.title(title)
plt.tight_layout()
None;
def example(model, i, prob_thresh=0.4):
    """Predict on test image ``X[i]`` with ``model`` and plot the result.

    Runs the full pipeline:
    normalize -> predict -> distances to polygon coordinates ->
    non-maximum suppression -> polygons to label image.
    It then shows two panels side by side:
    - the kept polygons drawn over the input;
    - the label image overlaid on the input.

    Parameters
    ----------
    model : StarDist
        Trained model used for prediction.
    i : int
        Index into the global list of test images ``X``.
    prob_thresh : float, optional
        Object probability threshold passed to ``non_maximum_suppression``
        (default 0.4, the value used elsewhere in this notebook).
    """
    img = normalize(X[i], 1, 99.8, axis=axis_norm)
    prob, dist = model.predict(img)
    coord = dist_to_coord(dist)
    points = non_maximum_suppression(coord, prob, prob_thresh=prob_thresh)
    labels = polygons_to_label(coord, prob, points)
    # For multi-channel input, display only the first channel.
    img_show = img if img.ndim == 2 else img[..., 0]

    plt.figure(figsize=(13, 10))

    # Left panel: polygons surviving NMS drawn over the input.
    plt.subplot(121)
    plt.imshow(img_show, cmap='gray')
    plt.axis('off')
    draw_polygons(coord, prob, points, show_dist=True)
    # Only the two comparison models get a title. Compare by identity, since
    # the question is which of these two objects was passed in.
    if model is model_sc:
        plt.title('With shape completion')
    elif model is model_no_sc:
        plt.title('Without shape completion')

    # Right panel: label instances overlaid semi-transparently on the input.
    plt.subplot(122)
    plt.imshow(img_show, cmap='gray')
    plt.axis('off')
    plt.imshow(labels, cmap=lbl_cmap, alpha=0.5)

    plt.tight_layout()
    plt.show()
# Compare both models on a few test images: for each image, first the model
# without shape completion, then the one with it.
for img_idx in (42, 1, 15):
    for mdl in (model_no_sc, model_sc):
        example(mdl, img_idx)
# Model named 'dsb2018', stored in the parent directory's 'models' folder.
model_paper = StarDist(None,
                       basedir='../models',
                       name='dsb2018')
# The printed output of this cell shows that construction already loads
# 'weights_best.h5'; explicitly switch to 'weights_last.h5' instead.
model_paper.load_weights('weights_last.h5')
Loading network weights from 'weights_best.h5'.
# Prediction of the 'dsb2018' model on test image 29 (no title is drawn for it).
example(model=model_paper, i=29)