This example demonstrates the training of the unrolled optimization with deep priors (ODP) gradient descent architecture, applied to a CT reconstruction problem.
The source images are foam phantoms generated with xdesign.
The class scico.flax.ODPNet implements the ODP architecture, which solves the optimization problem
$$\mathrm{argmin}_{\mathbf{x}} \; \| A \mathbf{x} - \mathbf{y} \|_2^2 + r(\mathbf{x}) \;,$$
where $A$ is a tomographic projector, $\mathbf{y}$ is a set of sinograms, $r$ is a regularizer and $\mathbf{x}$ is the set of reconstructed images. The ODP gradient descent architecture abstracts the iterative solution as an unrolled network in which each iteration corresponds to a different stage of the ODP network and updates the prediction by solving
$$\mathbf{x}^{k+1} = \mathrm{argmin}_{\mathbf{x}} \; \alpha_k \| A \mathbf{x} - \mathbf{y} \|_2^2 + \frac{1}{2} \| \mathbf{x} - \mathbf{x}^k - \mathbf{x}^{k+1/2} \|_2^2 \;,$$
which for the CT problem, using gradient descent, corresponds to
$$\mathbf{x}^{k+1} = \mathbf{x}^k + \mathbf{x}^{k+1/2} - \alpha_k \, A^T \, (A \mathbf{x}^k - \mathbf{y}) \;,$$
where $k$ is the index of the stage (iteration), $\mathbf{x}^k + \mathbf{x}^{k+1/2} = \mathrm{ResNet}(\mathbf{x}^{k})$ is the regularization (implemented as a residual convolutional neural network), $\mathbf{x}^k$ is the output of the previous stage and $\alpha_k > 0$ is a learned stage-wise parameter weighting the contribution of the fidelity term. The output of the final stage is the set of reconstructed images.
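As a rough illustration of the stage update above, the following NumPy sketch applies the gradient-descent step with a small random matrix standing in for the tomographic projector. The names `odp_grad_desc_stage` and `identity_prior` are hypothetical and not part of scico; the identity prior replaces the learned ResNet (so $\mathbf{x}^{k+1/2} = 0$), which reduces each stage to plain gradient descent on the data fidelity term.

```python
import numpy as np


def odp_grad_desc_stage(x, y, A, alpha, prior):
    """One unrolled stage: x_next = prior(x) - alpha * A^T (A x - y)."""
    return prior(x) - alpha * (A.T @ (A @ x - y))


def identity_prior(x):
    """Hypothetical stand-in for ResNet(x^k) = x^k + x^{k+1/2}, with x^{k+1/2} = 0."""
    return x


rng = np.random.default_rng(0)
A = rng.standard_normal((6, 4))  # toy stand-in for the CT projector
x_true = rng.standard_normal(4)
y = A @ x_true  # noise-free measurements

# Step size chosen from the spectral norm of A so that the iteration is stable.
alpha = 1.0 / np.linalg.norm(A, 2) ** 2

x = np.zeros(4)
residuals = []
for k in range(200):  # each pass plays the role of one stage
    x = odp_grad_desc_stage(x, y, A, alpha, identity_prior)
    residuals.append(np.linalg.norm(A @ x - y))

print(f"residual after first stage: {residuals[0]:.3e}")
print(f"residual after last stage:  {residuals[-1]:.3e}")
```

In the trained network the prior is a learned ResNet and each stage has its own $\alpha_k$, so the behaviour is not that of a fixed-step solver; the sketch only shows how the fidelity gradient enters each stage.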
# This scico project Jupyter notebook has been automatically modified
# to install the dependencies required for running it on Google Colab.
# If you encounter any problems in running it, please open an issue at
# https://github.com/lanl/scico-data/issues
!pip install 'scico[examples] @ git+https://github.com/lanl/scico'
import os
from functools import partial
from time import time
import numpy as np
import jax
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop.xray.astra import XRayTransform
plot.config_notebook_plotting()
Prepare parallel processing. Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
platform = jax.default_backend()  # "cpu", "gpu" or "tpu"
print("Platform: ", platform)
Platform: gpu
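The XLA_FLAGS assignment above only takes effect if it happens before JAX initializes its backends, and a plain assignment overwrites any flags already present in the environment. A minimal sketch of a helper that appends the flag instead is given below; `add_xla_flag` is a hypothetical name, and `--xla_hypothetical_flag` is a placeholder used only for the demonstration.

```python
import os


def add_xla_flag(flag, env=None):
    """Append `flag` to XLA_FLAGS in `env` unless it is already present."""
    env = os.environ if env is None else env
    existing = env.get("XLA_FLAGS", "")
    if flag not in existing.split():
        env["XLA_FLAGS"] = (existing + " " + flag).strip()
    return env["XLA_FLAGS"]


# Demonstrate on a plain dict so that the real environment is left untouched.
demo_env = {"XLA_FLAGS": "--xla_hypothetical_flag"}  # placeholder, not a real flag
result = add_xla_flag("--xla_force_host_platform_device_count=8", env=demo_env)
print(result)
```

Calling the helper without `env` modifies `os.environ`, which should be done before the first JAX computation.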
Read data from cache or generate if not available.
N = 256 # phantom size
train_nimg = 536 # number of training images
test_nimg = 64 # number of testing images
nimg = train_nimg + test_nimg
n_projection = 45 # CT views
trdt, ttdt = load_ct_data(train_nimg, test_nimg, N, n_projection, verbose=True)
Data read from path : ~/.cache/scico/examples/data
Set --training-- : Size: 536
Set --testing -- : Size: 64
Data range --images -- : Min: 0.00, Max: 1.00
Data range --sinogram-- : Min: 0.00, Max: 0.67
Data range --FBP -- : Min: 0.00, Max: 1.00
Build CT projection operator.
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
A = XRayTransform(
    input_shape=(N, N),
    detector_spacing=1,
    det_count=N,
    angles=angles,
)  # CT projection operator
A = (1.0 / N) * A  # normalized
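One way to see the effect of the 1/N normalization is with a toy projector. The sketch below is an assumption-laden stand-in, not the ASTRA operator: `toy_project` is a hypothetical function that computes only two axis-aligned views (column sums and row sums). For an N x N image with values in [0, 1], each line sum is at most N, so the scaled sinogram stays in [0, 1], comparable to the image range.

```python
import numpy as np

N = 8  # small toy size; the example above uses N = 256
rng = np.random.default_rng(0)
image = rng.uniform(0.0, 1.0, size=(N, N))  # pixel values in [0, 1]


def toy_project(img):
    """Hypothetical two-view projector: line sums along columns and along rows."""
    return np.stack([img.sum(axis=0), img.sum(axis=1)])


sino_raw = toy_project(image)     # values can be as large as N
sino_norm = (1.0 / N) * sino_raw  # same scaling as applied to A above

print("raw sinogram range:       ", sino_raw.min(), sino_raw.max())
print("normalized sinogram range:", sino_norm.min(), sino_norm.max())
```

For oblique views the ray length through the image can exceed N, so the bound is only indicative there.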
Build training and testing structures. Inputs are the sinograms and outputs are the original generated foams. Keep training and testing partitions.
numtr = 320
numtt = 32
train_ds = {"image": trdt["sino"][:numtr], "label": trdt["img"][:numtr]}
test_ds = {"image": ttdt["sino"][:numtt], "label": ttdt["img"][:numtt]}
Define configuration dictionary for model and training loop.
Parameters have been selected for demonstration purposes and relatively short training. The model depth is the number of unrolled iterations (stages) in the ODP model. The block depth controls the number of layers at each unrolled iteration. The number of filters is uniform throughout the iterations. Better performance may be obtained by increasing depth, block depth, number of filters, or training epochs, at the cost of longer training times.
# model configuration
model_conf = {
    "depth": 8,
    "num_filters": 64,
    "block_depth": 6,
}
# training configuration
train_conf: sflax.ConfigDict = {
    "seed": 1234,
    "opt_type": "ADAM",
    "batch_size": 16,
    "num_epochs": 200,
    "base_learning_rate": 1e-3,
    "warmup_epochs": 0,
    "log_every_steps": 160,
    "log": True,
    "checkpointing": True,
}
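The relation between `batch_size`, `log_every_steps` and the epochs that appear in the training log can be worked out with a few lines of arithmetic. The sketch below assumes that every epoch consists of `numtr // batch_size` steps and that epochs are numbered from zero; both are assumptions about the trainer rather than documented behaviour, although they reproduce the epoch column of the log printed during training (7, 15, ..., 199).

```python
num_train = 320        # numtr, the number of training samples kept above
batch_size = 16        # train_conf["batch_size"]
num_epochs = 200       # train_conf["num_epochs"]
log_every_steps = 160  # train_conf["log_every_steps"]

steps_per_epoch = num_train // batch_size
total_steps = steps_per_epoch * num_epochs

# Zero-based index of the epoch that finishes at each logging step.
logged_epochs = [
    step // steps_per_epoch - 1
    for step in range(log_every_steps, total_steps + 1, log_every_steps)
]

print("steps per epoch:", steps_per_epoch)
print("total steps:    ", total_steps)
print("logged epochs:  ", logged_epochs[:3], "...", logged_epochs[-1])
```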
Construct functionality for making sure that the learned fidelity weight parameter is always positive.
alphatrav = construct_traversal("alpha")  # select alpha parameters in model
alphapost = partial(
    clip_positive,  # apply this function
    traversal=alphatrav,  # to alpha parameters in model
    minval=1e-3,
)
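The idea behind this post-processing step can be sketched without Flax. The function below is a hypothetical NumPy stand-in, not scico's `clip_positive`: it walks a nested parameter dictionary and raises every leaf named `alpha` to at least `minval`, leaving all other parameters untouched. The dictionary layout is illustrative only.

```python
import numpy as np


def clip_alpha(params, minval=1e-3):
    """Return a copy of `params` with every 'alpha' leaf clipped to >= minval."""
    clipped = {}
    for name, value in params.items():
        if isinstance(value, dict):
            clipped[name] = clip_alpha(value, minval)  # recurse into sub-modules
        elif name == "alpha":
            clipped[name] = np.maximum(value, minval)
        else:
            clipped[name] = value
    return clipped


# Illustrative parameter tree; the names mimic the model but are made up here.
params = {
    "block_0": {"alpha": np.array([-0.2]), "kernel": np.array([-1.0, 2.0])},
    "block_1": {"alpha": np.array([0.05]), "kernel": np.array([0.5, -0.5])},
}
params_post = clip_alpha(params, minval=1e-3)
print(params_post["block_0"]["alpha"], params_post["block_1"]["alpha"])
```

Applying such a projection after each optimizer update keeps the fidelity weights strictly positive without changing the loss function.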
Print configuration of distributed run.
print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
print(f"{'JAX local devices: '}{jax.local_devices()}")
JAX process: 0 / 1
JAX local devices: [cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3), cuda(id=4), cuda(id=5), cuda(id=6), cuda(id=7)]
Construct ODPNet model.
channels = train_ds["image"].shape[-1]
model = sflax.ODPNet(
    operator=A,
    depth=model_conf["depth"],
    channels=channels,
    num_filters=model_conf["num_filters"],
    block_depth=model_conf["block_depth"],
    odp_block=sflax.inverse.ODPGrDescBlock,
    alpha_ini=1e-2,
)
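The size of the model built above can be checked by hand. The sketch below assumes the layer layout that appears in the parameter summary printed when training starts: per stage, one `alpha`, one single-channel batch norm, `block_depth - 1` bias-free 3x3 convolution and batch-norm blocks, and a final bias-free 3x3 convolution back to one channel. It also assumes that the initial `alpha` is halved from one stage to the next, which is what the summary shows. Neither is taken from the scico source code.

```python
depth = 8         # model_conf["depth"]
num_filters = 64  # model_conf["num_filters"]
block_depth = 6   # model_conf["block_depth"]
channels = 1      # single-channel sinograms and images
ksize = 3         # kernel size as shown in the printed summary


def conv_params(c_in, c_out):
    """Weights of a bias-free ksize x ksize convolution."""
    return ksize * ksize * c_in * c_out


def bn_params(c):
    """Trainable batch-norm parameters: one scale and one bias per channel."""
    return 2 * c


per_stage = (
    1  # alpha
    + bn_params(channels)  # resnet/BatchNorm_0
    + conv_params(channels, num_filters) + bn_params(num_filters)  # ConvBNBlock_0
    + (block_depth - 2) * (conv_params(num_filters, num_filters) + bn_params(num_filters))
    + conv_params(num_filters, channels)  # resnet/Conv_0
)
total_params = depth * per_stage

# Assumed initialization: alpha_ini halved at each successive stage.
alpha_ini = 1e-2
alphas = [alpha_ini / 2**k for k in range(depth)]

print("parameters per stage:", per_stage)
print("total parameters:    ", total_params)
print("initial alphas:      ", alphas)
```

Under these assumptions the total is 1,194,008, which agrees with the total reported in the printed summary.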
Run training loop.
workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "odp_ct_out")
train_conf["workdir"] = workdir
train_conf["post_lst"] = [alphapost]
# Construct training object
trainer = sflax.BasicFlaxTrainer(
    train_conf,
    model,
    train_ds,
    test_ds,
)
start_time = time()
modvar, stats_object = trainer.train()
time_train = time() - start_time
Channels: 1, training signals: 320, testing signals: 32, signal size: 256
+---------------------------------------------------------+----------------+--------+-----------+--------+
| Name                                                    | Shape          | Size   | Mean      | Std    |
+---------------------------------------------------------+----------------+--------+-----------+--------+
| ODPGrDescBlock_0/alpha                                  | (1,)           | 1      | 0.01      | 0.0    |
| ODPGrDescBlock_0/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_0/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | -0.000308 | 0.0568 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 2.38e-05  | 0.0416 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000402  | 0.0418 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000185  | 0.0416 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000101 | 0.0418 |
| ODPGrDescBlock_0/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | 0.00276   | 0.058  |
| ODPGrDescBlock_1/alpha                                  | (1,)           | 1      | 0.005     | 0.0    |
| ODPGrDescBlock_1/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_1/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | 0.00128   | 0.0583 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000226 | 0.0419 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.00029   | 0.0415 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -5.14e-05 | 0.0417 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000319  | 0.0416 |
| ODPGrDescBlock_1/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | 0.000334  | 0.0583 |
| ODPGrDescBlock_2/alpha                                  | (1,)           | 1      | 0.0025    | 0.0    |
| ODPGrDescBlock_2/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_2/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | -0.00119  | 0.0602 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000443  | 0.0415 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000163  | 0.0416 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.0004   | 0.0417 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 6.25e-05  | 0.0417 |
| ODPGrDescBlock_2/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | 0.000515  | 0.0586 |
| ODPGrDescBlock_3/alpha                                  | (1,)           | 1      | 0.00125   | 0.0    |
| ODPGrDescBlock_3/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_3/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | -0.00179  | 0.057  |
| ODPGrDescBlock_3/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -3.06e-05 | 0.0417 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000153  | 0.0416 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000617  | 0.0418 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000218 | 0.0419 |
| ODPGrDescBlock_3/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | -0.000309 | 0.0575 |
| ODPGrDescBlock_4/alpha                                  | (1,)           | 1      | 0.000625  | 0.0    |
| ODPGrDescBlock_4/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_4/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | 0.00173   | 0.0585 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000285 | 0.0417 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000136 | 0.0416 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000243  | 0.0417 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000183  | 0.0417 |
| ODPGrDescBlock_4/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | 0.0033    | 0.0608 |
| ODPGrDescBlock_5/alpha                                  | (1,)           | 1      | 0.000312  | 0.0    |
| ODPGrDescBlock_5/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_5/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | 0.00178   | 0.0589 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000261 | 0.0417 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000299  | 0.0417 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000267 | 0.0418 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000261  | 0.0416 |
| ODPGrDescBlock_5/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | 0.00478   | 0.0608 |
| ODPGrDescBlock_6/alpha                                  | (1,)           | 1      | 0.000156  | 0.0    |
| ODPGrDescBlock_6/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_6/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | 0.000588  | 0.0576 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000272  | 0.0416 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 6.59e-05  | 0.0417 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 7.39e-05  | 0.0418 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -5e-05    | 0.0418 |
| ODPGrDescBlock_6/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | -0.000458 | 0.0583 |
| ODPGrDescBlock_7/alpha                                  | (1,)           | 1      | 7.81e-05  | 0.0    |
| ODPGrDescBlock_7/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_7/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | 0.00317   | 0.0574 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000238 | 0.0415 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000153  | 0.0418 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -3.74e-05 | 0.0417 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000368  | 0.0418 |
| ODPGrDescBlock_7/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | -0.00299  | 0.0604 |
+---------------------------------------------------------+----------------+--------+-----------+--------+
Total: 1,194,008
+--------------------------------------------------------+-------+------+------+-----+
| Name                                                   | Shape | Size | Mean | Std |
+--------------------------------------------------------+-------+------+------+-----+
| ODPGrDescBlock_0/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_0/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_1/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_1/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_2/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_2/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_3/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_3/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_4/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_4/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_5/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_5/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_6/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_6/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_7/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_7/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
+--------------------------------------------------------+-------+------+------+-----+
Total: 5,136
Initial compilation, this might take some minutes...
Initial compilation completed.
Epoch  Time      Train_LR  Train_Loss  Train_SNR  Eval_Loss     Eval_SNR
---------------------------------------------------------------------
7      1.61e+02  0.001000  0.052377    3.85       0.132320      -4.61
15     3.06e+02  0.001000  0.007551    7.91       0.126599      -4.42
23     4.47e+02  0.001000  0.005597    9.15       0.086118      -2.74
31     5.89e+02  0.001000  0.004938    9.69       0.081113      -2.48
39     7.30e+02  0.001000  0.004495    10.10      0.028649      2.04
47     8.73e+02  0.001000  0.004167    10.42      0.050053      -0.39
55     1.02e+03  0.001000  0.003868    10.75      0.011462      6.03
63     1.16e+03  0.001000  0.003687    10.96      0.004091      10.50
71     1.30e+03  0.001000  0.006230    10.08      21614.328125  -56.68
79     1.44e+03  0.001000  0.012242    6.12       0.019623      3.68
87     1.59e+03  0.001000  0.005031    9.62       0.041972      0.38
95     1.73e+03  0.001000  0.004204    10.39      0.012248      5.73
103    1.87e+03  0.001000  0.003774    10.86      0.005177      9.47
111    2.01e+03  0.001000  0.003523    11.15      0.004176      10.40
119    2.16e+03  0.001000  0.003324    11.41      0.003917      10.68
127    2.30e+03  0.001000  0.003194    11.58      0.004408      10.17
135    2.44e+03  0.001000  0.003065    11.76      0.003885      10.72
143    2.59e+03  0.001000  0.002968    11.90      0.004074      10.51
151    2.73e+03  0.001000  0.002864    12.05      0.003808      10.80
159    2.87e+03  0.001000  0.002801    12.15      0.003001      11.84
167    3.02e+03  0.001000  0.002711    12.29      0.002757      12.21
175    3.16e+03  0.001000  0.002627    12.43      0.002760      12.21
183    3.30e+03  0.001000  0.002596    12.48      0.004782      9.84
191    3.45e+03  0.001000  0.002539    12.58      0.003661      10.98
199    3.59e+03  0.001000  0.002479    12.68      0.002647      12.39
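The evaluation loss in this log is not monotone; the entry at epoch 71 in particular is several orders of magnitude above its neighbours. A small sketch of how such a log could be scanned is given below. The rows are a subset copied from the log above, and the spike threshold (100 times the median evaluation loss) is an arbitrary choice made for illustration.

```python
# (epoch, eval_loss, eval_snr) rows copied from the training log above
log_rows = [
    (63, 0.004091, 10.50),
    (71, 21614.328125, -56.68),
    (79, 0.019623, 3.68),
    (167, 0.002757, 12.21),
    (199, 0.002647, 12.39),
]

# Epoch with the highest evaluation SNR among the selected rows.
best_epoch = max(log_rows, key=lambda row: row[2])[0]

# Flag rows whose evaluation loss is far above the median (arbitrary factor of 100).
losses = sorted(row[1] for row in log_rows)
median_loss = losses[len(losses) // 2]
spike_epochs = [epoch for epoch, loss, _ in log_rows if loss > 100 * median_loss]

print("best evaluation SNR at epoch:", best_epoch)
print("evaluation loss spikes at:   ", spike_epochs)
```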
Evaluate on testing data.
del train_ds["image"]
del train_ds["label"]
fmap = sflax.FlaxMap(model, modvar)
del model, modvar
maxn = numtt
start_time = time()
output = fmap(test_ds["image"][:maxn])
time_eval = time() - start_time
output = np.clip(output, a_min=0, a_max=1.0)
epochs = train_conf["num_epochs"]
Evaluate trained model in terms of reconstruction time and data fidelity.
snr_eval = metric.snr(test_ds["label"][:maxn], output)
psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
print(
    f"{'ODPNet training':18s}{'epochs:':2s}{epochs:>5d}{'':21s}"
    f"{'time[s]:':10s}{time_train:>7.2f}"
)
print(
    f"{'ODPNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
    f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}"
)
ODPNet training   epochs:  200                     time[s]:  3590.61
ODPNet testing    SNR: 12.40 dB   PSNR: 22.78 dB   time[s]:    11.57
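For readers unfamiliar with the two metrics reported above, the following NumPy sketch shows one common way of defining them. The conventions used here (signal power measured as the variance of the reference, peak taken as the dynamic range of the reference) are assumptions, and `scico.metric` may differ in detail; the function names `snr_db` and `psnr_db` are hypothetical.

```python
import numpy as np


def mse(reference, comparison):
    """Mean squared error between two arrays."""
    diff = np.asarray(reference) - np.asarray(comparison)
    return np.mean(diff**2)


def snr_db(reference, comparison):
    """SNR in dB, taking the variance of the reference as signal power."""
    return 10.0 * np.log10(np.var(reference) / mse(reference, comparison))


def psnr_db(reference, comparison):
    """PSNR in dB, taking the dynamic range of the reference as the peak."""
    reference = np.asarray(reference)
    signal_range = reference.max() - reference.min()
    return 10.0 * np.log10(signal_range**2 / mse(reference, comparison))


reference = np.array([0.0, 1.0, 0.0, 1.0])
comparison = reference + 0.1  # constant error of 0.1

print(f"SNR:  {snr_db(reference, comparison):.2f} dB")
print(f"PSNR: {psnr_db(reference, comparison):.2f} dB")
```

Under these definitions the two values differ by 10 log10(range^2 / var(reference)), a quantity that depends only on the reference images, which is why PSNR and SNR move together for a fixed test set.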
Plot comparison.
np.random.seed(123)
indx = np.random.randint(0, high=maxn)
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(test_ds["label"][indx, ..., 0], title="Ground truth", cbar=None, fig=fig, ax=ax[0])
plot.imview(
    test_ds["image"][indx, ..., 0],
    title="Sinogram",
    cbar=None,
    fig=fig,
    ax=ax[1],
)
plot.imview(
    output[indx, ..., 0],
    title="ODPNet Reconstruction\nSNR: %.2f (dB), PSNR: %.2f"
    % (
        metric.snr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
        metric.psnr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
    ),
    fig=fig,
    ax=ax[2],
)
divider = make_axes_locatable(ax[2])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units")
fig.show()
Plot convergence statistics. Statistics are generated only if a training cycle was done (i.e. if not reading final epoch results from checkpoint).
if stats_object is not None and len(stats_object.iterations) > 0:
    hist = stats_object.history(transpose=True)
    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
    plot.plot(
        np.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
        x=hist.Epoch,
        ptyp="semilogy",
        title="Loss function",
        xlbl="Epoch",
        ylbl="Loss value",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[0],
    )
    plot.plot(
        np.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
        x=hist.Epoch,
        title="Metric",
        xlbl="Epoch",
        ylbl="SNR (dB)",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[1],
    )
    fig.show()