IXI benchmarking quick start
Use the IXI benchmarking scripts in tutorials/IXI_benchmarking to run quick comparisons across models. The notebook in tutorials/IXI_benchmarking/IXI_benchmarking.ipynb contains the full workflow, and the commands below run the core benchmarks.
cd tutorials/IXI_benchmarking
python3.8 -u train_TransMorph.py
python3.8 -u train_TransMorphTVF.py
python3.8 -u train_SITReg.py
python3.8 -u train_SITReg_SPR.py
Notebook snippets
Imports and setup
from torch.utils.tensorboard import SummaryWriter
import os
import sys
import glob
from torch.utils.data import DataLoader
import numpy as np
import torch
from torch import optim
import matplotlib.pyplot as plt
from natsort import natsorted
from torchvision import transforms
from MIR.models import SpatialTransformer, EncoderFeatureExtractor, SITReg, VFA, TransMorphTVF, TransMorph, convex_adam_MIND
from MIR.models.SITReg import ReLUFactory, GroupNormalizerFactory
from MIR.models.SITReg.composable_mapping import DataFormat
from MIR.models.SITReg.deformation_inversion_layer.fixed_point_iteration import (
AndersonSolver,
AndersonSolverArguments,
MaxElementWiseAbsStopCriterion,
RelativeL2ErrorStopCriterion,
)
import MIR.models.convexAdam.configs_ConvexAdam_MIND as CONFIGS_CVXAdam
import MIR.models.configs_TransMorph as configs_TransMorph
import MIR.models.configs_VFA as CONFIGS_VFA
from data import datasets, trans
import torch.nn.functional as F
Dataset paths
# Spatial dimensions (voxels) of the preprocessed IXI volumes: height, width, depth.
H, W, D = 160, 192, 224
# Root folder of the IXI dataset; fetched automatically later in this script if absent.
data_dir = './IXI_data/'
Model factory functions
def transmorph_model():
    """Build a full-resolution TransMorph network and place it on cuda:0."""
    cfg = configs_TransMorph.get_3DTransMorph3Lvl_config()
    cfg.img_size = (H, W, D)  # scale factor of 1: no input downsampling
    cfg.window_size = (H // 64, W // 64, D // 64)
    cfg.out_chan = 3  # one displacement channel per spatial axis
    return TransMorph(cfg).cuda('cuda:0')
def transmorphTVF_model():
    """Build a TransMorph-TVF network (7 integration steps) on cuda:0.

    The TVF variant is configured for half-resolution inputs; the
    inference loop pools its inputs by 2 accordingly.
    """
    half = 2
    cfg = configs_TransMorph.get_3DTransMorph3Lvl_config()
    cfg.img_size = (H // half, W // half, D // half)
    cfg.window_size = (H // 64, W // 64, D // 64)
    cfg.out_chan = 3  # one displacement channel per spatial axis
    return TransMorphTVF(cfg, time_steps=7).cuda('cuda:0')
def vfa_model():
    """Build a full-resolution VFA network on cuda:0."""
    cfg = CONFIGS_VFA.get_VFA_default_config()
    cfg.img_size = (H, W, D)  # scale factor of 1: no input downsampling
    return VFA(cfg, device='cuda:0')
def create_model(INPUT_SHAPE) -> SITReg:
    """Assemble a SITReg network for volumes of shape ``INPUT_SHAPE``.

    Builds the fixed-point solvers for the deformation-inversion layers,
    a multi-resolution convolutional feature encoder, and the SITReg
    registration head, all moved to the default CUDA device.
    """
    # Anderson-accelerated fixed-point solvers: max-abs stopping for the
    # forward pass, relative-L2 stopping for the backward pass.
    forward_solver = AndersonSolver(
        MaxElementWiseAbsStopCriterion(min_iterations=2, max_iterations=50, threshold=1e-2),
        AndersonSolverArguments(memory_length=4),
    )
    backward_solver = AndersonSolver(
        RelativeL2ErrorStopCriterion(min_iterations=2, max_iterations=50, threshold=1e-2),
        AndersonSolverArguments(memory_length=4),
    )
    # Six-level convolutional encoder, two convolutions per resolution.
    encoder = EncoderFeatureExtractor(
        n_input_channels=1,
        activation_factory=ReLUFactory(),
        n_features_per_resolution=[12, 16, 32, 64, 128, 128],
        n_convolutions_per_resolution=[2, 2, 2, 2, 2, 2],
        input_shape=INPUT_SHAPE,
        normalizer_factory=GroupNormalizerFactory(2),
    ).cuda()
    return SITReg(
        feature_extractor=encoder,
        n_transformation_convolutions_per_resolution=[2, 2, 2, 2, 2, 2],
        n_transformation_features_per_resolution=[12, 64, 128, 256, 256, 256],
        max_control_point_multiplier=0.99,
        affine_transformation_type=None,  # deformable-only registration
        input_voxel_size=(1.0, 1.0, 1.0),
        input_shape=INPUT_SHAPE,
        transformation_downsampling_factor=(1.0, 1.0, 1.0),
        forward_fixed_point_solver=forward_solver,
        backward_fixed_point_solver=backward_solver,
        activation_factory=ReLUFactory(),
        normalizer_factory=GroupNormalizerFactory(4),
    ).cuda()
def sitreg_model():
    """Build a SITReg network at full input resolution on cuda:0."""
    network = create_model((H, W, D))
    return network.cuda('cuda:0')
def convexadam_model():
    """Return the ConvexAdam-MIND entry: registration callable plus config.

    Unlike the other factories, no network is instantiated here —
    ConvexAdam is optimization-based, so the registration function and its
    default brain configuration are returned together in a dict.
    """
    return {
        'model': convex_adam_MIND,
        'config': CONFIGS_CVXAdam.get_ConvexAdam_MIND_brain_default_config(),
    }
# Registry of model factories, keyed by benchmark name. Each value is a
# zero-argument callable; all return an instantiated network except
# 'ConvexAdam-MIND', whose factory returns a {'model', 'config'} dict.
models_dict = {
    'TransMorph': transmorph_model,
    'TransMorphTVF': transmorphTVF_model,
    'VFA': vfa_model,
    'SITReg': sitreg_model,
    'ConvexAdam-MIND': convexadam_model,
}
Data download and extraction
# Fetch and unpack the IXI dataset on first run; fail loudly if that didn't work.
if not os.path.exists(data_dir):
    import gdown
    import zipfile

    archive = 'IXI_data.zip'
    file_id = '1-VQewCVNj5eTtc3eQGhTM2yXBQmgm8Ol'
    gdown.download(f"https://drive.google.com/uc?id={file_id}", archive, quiet=False)
    with zipfile.ZipFile(archive, 'r') as zf:
        zf.extractall('./')
    os.remove(archive)
if not os.path.exists(data_dir):
    raise ValueError('Data directory not found and download failed.')
Inference loop
# Full-resolution warper for applying predicted displacement fields on GPU.
spatial_trans = SpatialTransformer((H, W, D)).cuda()
def inference(model_name, model, moving, fixed):
    """Run one registration forward pass and return the displacement field.

    Parameters
    ----------
    model_name : str
        Key from ``models_dict`` selecting the model-specific calling
        convention.
    model : torch.nn.Module or callable
        The instantiated network — except for ConvexAdam-MIND, where it is
        the factory returning a ``{'model', 'config'}`` dict.
    moving, fixed : torch.Tensor
        Image volumes; assumed shape (1, 1, H, W, D) — TODO confirm with callers.

    Returns
    -------
    torch.Tensor
        Dense displacement field at full resolution (the half-resolution
        TransMorphTVF flow is upsampled and rescaled back to (H, W, D)).

    Raises
    ------
    ValueError
        If ``model_name`` is not recognized.
    """
    if model_name in ('TransMorph', 'VFA'):
        with torch.no_grad():
            model.eval()
            flow = model((moving.cuda(), fixed.cuda()))
    elif model_name == 'TransMorphTVF':
        with torch.no_grad():
            model.eval()
            # TVF runs at half resolution: pool inputs by 2, then upsample
            # the flow and double it (displacements scale with the grid).
            moving = F.avg_pool3d(moving, 2).cuda()
            fixed = F.avg_pool3d(fixed, 2).cuda()
            flow = model((moving, fixed))
            flow = F.interpolate(flow, size=(H, W, D), mode='trilinear', align_corners=True) * 2.0
    elif model_name == 'SITReg':
        with torch.no_grad():
            model.eval()
            mapping_pair = model(moving.cuda(), fixed.cuda(), mappings_for_levels=((0, False),))[0]
            flow = mapping_pair.forward_mapping.sample(
                data_format=DataFormat.voxel_displacements()
            ).generate_values()
    elif model_name in ('ConvexAdam-MIND', 'ConvexAdam'):
        # BUG FIX: the registry key is 'ConvexAdam-MIND' but this branch
        # previously only matched 'ConvexAdam', so it never ran and `flow`
        # was left unbound. Both spellings are accepted now, and the
        # factory is called once instead of twice.
        entry = model()
        flow = entry['model'](moving.cuda(), fixed.cuda(), entry['config'])
    else:
        raise ValueError(f'Unknown model name: {model_name}')
    return flow