IXI benchmarking quick start

Use the IXI benchmarking scripts in tutorials/IXI_benchmarking to run quick comparisons across registration models. The notebook tutorials/IXI_benchmarking/IXI_benchmarking.ipynb contains the full workflow. The commands below train the core benchmark models: TransMorph, TransMorphTVF, SITReg, and the SITReg_SPR variant.

# Run from the repository root. The -u flag makes Python's stdout/stderr unbuffered, so progress output appears immediately.
cd tutorials/IXI_benchmarking
python3.8 -u train_TransMorph.py
python3.8 -u train_TransMorphTVF.py
python3.8 -u train_SITReg.py
python3.8 -u train_SITReg_SPR.py

Notebook snippets

Imports and setup

from torch.utils.tensorboard import SummaryWriter
import os
import sys
import glob

from torch.utils.data import DataLoader
import numpy as np
import torch
from torch import optim
import matplotlib.pyplot as plt
from natsort import natsorted
from torchvision import transforms

from MIR.models import SpatialTransformer, EncoderFeatureExtractor, SITReg, VFA, TransMorphTVF, TransMorph, convex_adam_MIND
from MIR.models.SITReg import ReLUFactory, GroupNormalizerFactory
from MIR.models.SITReg.composable_mapping import DataFormat
from MIR.models.SITReg.deformation_inversion_layer.fixed_point_iteration import (
    AndersonSolver,
    AndersonSolverArguments,
    MaxElementWiseAbsStopCriterion,
    RelativeL2ErrorStopCriterion,
)
import MIR.models.convexAdam.configs_ConvexAdam_MIND as CONFIGS_CVXAdam
import MIR.models.configs_TransMorph as configs_TransMorph
import MIR.models.configs_VFA as CONFIGS_VFA
from data import datasets, trans
import torch.nn.functional as F

Image size and dataset path

# Spatial size of every IXI volume; all model factories and the warping layer use it.
H = 160
W = 192
D = 224
# Root folder of the extracted IXI dataset (fetched automatically if missing).
data_dir = './IXI_data/'

Model factory functions

def _transmorph_config(scale_factor):
    """Return a 3-level TransMorph config for inputs downsampled by ``scale_factor``.

    ``img_size`` follows the downsampled volume, while ``window_size`` is always
    derived from the full-resolution (H, W, D).
    """
    config = configs_TransMorph.get_3DTransMorph3Lvl_config()
    config.img_size = (H // scale_factor, W // scale_factor, D // scale_factor)
    # NOTE(review): window_size is not rescaled with scale_factor. This matches the
    # previous per-factory code; confirm it is intended for the half-resolution TVF model.
    config.window_size = (H // 64, W // 64, D // 64)
    config.out_chan = 3  # presumably one output channel per displacement axis — confirm
    return config


def transmorph_model():
    """Build a full-resolution TransMorph network on cuda:0."""
    return TransMorph(_transmorph_config(scale_factor=1)).cuda('cuda:0')


def transmorphTVF_model():
    """Build a TransMorphTVF network (7 time steps) for half-resolution inputs on cuda:0."""
    config = _transmorph_config(scale_factor=2)
    return TransMorphTVF(config, time_steps=7).cuda('cuda:0')

def vfa_model():
    """Build a VFA network for full-resolution (H, W, D) inputs, targeting cuda:0."""
    vfa_config = CONFIGS_VFA.get_VFA_default_config()
    vfa_config.img_size = (H, W, D)  # full resolution: no downsampling for VFA
    return VFA(vfa_config, device='cuda:0')

def create_model(INPUT_SHAPE) -> SITReg:
    """Assemble a SITReg registration network on the default CUDA device.

    Args:
        INPUT_SHAPE: Spatial shape of the input volumes; passed to both the
            feature extractor and SITReg.

    Returns:
        The SITReg network after ``.cuda()``.
    """
    # Six-level convolutional encoder (one entry per resolution in each list).
    encoder = EncoderFeatureExtractor(
        n_input_channels=1,
        activation_factory=ReLUFactory(),
        n_features_per_resolution=[12, 16, 32, 64, 128, 128],
        n_convolutions_per_resolution=[2, 2, 2, 2, 2, 2],
        input_shape=INPUT_SHAPE,
        normalizer_factory=GroupNormalizerFactory(2),
    ).cuda()

    def _anderson(stop_criterion):
        # Both solvers share the same Anderson memory length; only the stop rule differs.
        return AndersonSolver(stop_criterion, AndersonSolverArguments(memory_length=4))

    # Fixed-point solvers for SITReg's forward and backward passes (presumably the
    # deformation-inversion layer, given the import path): forward stops on the max
    # element-wise error, backward on the relative L2 error.
    forward_solver = _anderson(
        MaxElementWiseAbsStopCriterion(min_iterations=2, max_iterations=50, threshold=1e-2)
    )
    backward_solver = _anderson(
        RelativeL2ErrorStopCriterion(min_iterations=2, max_iterations=50, threshold=1e-2)
    )

    sitreg = SITReg(
        feature_extractor=encoder,
        n_transformation_convolutions_per_resolution=[2, 2, 2, 2, 2, 2],
        n_transformation_features_per_resolution=[12, 64, 128, 256, 256, 256],
        max_control_point_multiplier=0.99,
        affine_transformation_type=None,
        input_voxel_size=(1.0, 1.0, 1.0),
        input_shape=INPUT_SHAPE,
        transformation_downsampling_factor=(1.0, 1.0, 1.0),
        forward_fixed_point_solver=forward_solver,
        backward_fixed_point_solver=backward_solver,
        activation_factory=ReLUFactory(),
        normalizer_factory=GroupNormalizerFactory(4),
    )
    return sitreg.cuda()

def sitreg_model():
    """Build SITReg for full-resolution (H, W, D) inputs and place it on cuda:0."""
    full_shape = (H, W, D)
    network = create_model(full_shape)
    return network.cuda('cuda:0')

def convexadam_model():
    """Return the ConvexAdam-MIND callable with its default brain config.

    Returns:
        A dict ``{'model': convex_adam_MIND, 'config': <brain default config>}``.
    """
    return {
        'model': convex_adam_MIND,
        'config': CONFIGS_CVXAdam.get_ConvexAdam_MIND_brain_default_config(),
    }

# Display name -> zero-argument factory. The learning-based factories build networks
# targeting cuda:0; 'ConvexAdam-MIND' instead returns a {'model', 'config'} dict.
models_dict = dict([
    ('TransMorph', transmorph_model),
    ('TransMorphTVF', transmorphTVF_model),
    ('VFA', vfa_model),
    ('SITReg', sitreg_model),
    ('ConvexAdam-MIND', convexadam_model),
])

Data download and extraction

if not os.path.exists(data_dir):
    # Fetch and unpack the dataset only when it is not already on disk; the
    # download-only dependencies are imported lazily for the same reason.
    import gdown
    import zipfile

    file_id = '1-VQewCVNj5eTtc3eQGhTM2yXBQmgm8Ol'
    url = f"https://drive.google.com/uc?id={file_id}"
    archive_path = 'IXI_data.zip'
    try:
        # Some gdown versions report a failed download by returning None rather
        # than raising; fail with the same message as the final check.
        if gdown.download(url, archive_path, quiet=False) is None:
            raise ValueError('Data directory not found and download failed.')
        with zipfile.ZipFile(archive_path, 'r') as zip_ref:
            zip_ref.extractall('./')
    finally:
        # Remove the archive even when extraction fails (e.g. a truncated
        # download), so a broken zip is never left behind.
        if os.path.exists(archive_path):
            os.remove(archive_path)

# The archive is expected to unpack into data_dir; anything else is a failure.
if not os.path.exists(data_dir):
    raise ValueError('Data directory not found and download failed.')

Inference function

# Module-level warping layer sized for full-resolution (H, W, D) volumes, moved with .cuda().
# NOTE(review): presumably used to apply the flows returned by inference() — confirm at call sites.
spatial_trans = SpatialTransformer((H, W, D)).cuda()

def inference(model_name, model, moving, fixed):
    """Compute the flow field that registers ``moving`` to ``fixed``.

    Args:
        model_name: Which branch to run: 'TransMorph', 'VFA', 'TransMorphTVF',
            'SITReg', or 'ConvexAdam' / 'ConvexAdam-MIND'. 'ConvexAdam-MIND' is
            the key used in ``models_dict``.
        model: The built network for the learning-based models. For ConvexAdam,
            either the ``convexadam_model`` factory or the
            ``{'model': ..., 'config': ...}`` dict it returns.
        moving: Moving image tensor. It is moved with ``.cuda()``; for
            TransMorphTVF a 2x average-pooled copy is moved instead.
        fixed: Fixed image tensor, handled the same way as ``moving``.

    Returns:
        The flow produced by the selected model. For TransMorphTVF it is
        upsampled to (H, W, D).

    Raises:
        ValueError: If ``model_name`` is not one of the names above.
    """
    if model_name in ('TransMorph', 'VFA'):
        with torch.no_grad():
            model.eval()
            return model((moving.cuda(), fixed.cuda()))

    if model_name == 'TransMorphTVF':
        with torch.no_grad():
            model.eval()
            # transmorphTVF_model builds the network for half-resolution inputs
            # (scale_factor=2), so downsample before moving to the GPU.
            moving = F.avg_pool3d(moving, 2).cuda()
            fixed = F.avg_pool3d(fixed, 2).cuda()
            flow = model((moving, fixed))
            # Bring the flow back to full resolution; the x2 presumably rescales
            # displacements from half-resolution voxel units.
            return F.interpolate(flow, size=(H, W, D), mode='trilinear', align_corners=True) * 2.0

    if model_name == 'SITReg':
        with torch.no_grad():
            model.eval()
            mapping_pair = model(moving.cuda(), fixed.cuda(), mappings_for_levels=((0, False),))[0]
            return mapping_pair.forward_mapping.sample(
                data_format=DataFormat.voxel_displacements()
            ).generate_values()

    # Accept the registry key 'ConvexAdam-MIND' as well; previously only
    # 'ConvexAdam' matched, so dispatching by models_dict key left flow unbound.
    if model_name in ('ConvexAdam', 'ConvexAdam-MIND'):
        # Deliberately not wrapped in torch.no_grad(): ConvexAdam presumably
        # optimizes the flow per image pair and needs autograd.
        # Call the factory once (it was previously invoked twice), and also
        # accept the already-built dict.
        bundle = model() if callable(model) else model
        return bundle['model'](moving.cuda(), fixed.cuda(), bundle['config'])

    raise ValueError(
        f"Unknown model_name {model_name!r}; expected one of 'TransMorph', 'VFA', "
        "'TransMorphTVF', 'SITReg', 'ConvexAdam', 'ConvexAdam-MIND'."
    )