Brain registration (affine + deformable)

Use the brain registration tutorial notebook in tutorials/brain_registration/deformable_registration.ipynb to run an affine pre-alignment followed by deformable registration with VFA, TransMorphTVF, or ConvexAdam-MIND.

cd tutorials/brain_registration
from MIR.models import SpatialTransformer, VFA, AffineReg3D, TransMorphTVF, convex_adam_MIND
import nibabel as nib
import torch
from MIR import ModelWeights
import MIR.models.configs_VFA as CONFIGS_VFA
import MIR.models.configs_TransMorph as configs_TransMorph
import MIR.models.convexAdam.configs_ConvexAdam_MIND as CONFIGS_CVXAdam
import torch.nn.functional as F

# load images
# Subject scan (the image that gets warped) and the template it is aligned to.
moving_path = 'sub-01_T1w.nii.gz'
template_path = 'LUMIR_template.nii.gz'
img_nib = nib.load(moving_path)
template_nib = nib.load(template_path)

# affine registration
# Prepend two singleton axes to each volume (presumably batch and channel),
# cast to float32 and move the result to GPU 0.
img_np = img_nib.get_fdata()[None, None, ...]
template_np = template_nib.get_fdata()[None, None, ...]
img_torch = torch.from_numpy(img_np).float().cuda(0)
template_torch = torch.from_numpy(template_np).float().cuda(0)

# Both the warping layer and the affine model are sized from the template's
# trailing (spatial) dimensions.
vol_shape = template_torch.shape[2:]
spatial_trans = SpatialTransformer(size=vol_shape, mode='bilinear').cuda(0)
affine_model = AffineReg3D(vol_shape=vol_shape, dof="affine").cuda(0)

# Fit the affine transform; the result dict carries the transform as a flow
# field ('flow') and the resampled subject image ('warped').
output = affine_model.optimize(
    img_torch,
    template_torch,
    steps_per_scale=(50, 50),
    verbose=True,
)
affine_flow, deformed = output['flow'], output['warped']

# deformable registration (VFA)
# Build VFA for the template's spatial grid and load the pretrained weights.
config = CONFIGS_VFA.get_VFA_default_config()
config.img_size = template_torch.shape[2:]
# Pin the module to the same GPU index that is passed as `device`.
VFA_model = VFA(config, device='cuda:0').cuda(0)
# map_location='cpu' makes loading independent of the device the checkpoint
# was saved from; load_state_dict then copies the tensors into the model's
# own (GPU) parameters.
vfa_checkpoint = torch.load('pretrained_wts/VFA_LUMIR24.pth', map_location='cpu')
weights = vfa_checkpoint[ModelWeights['VFA-LUMIR24-MonoModal']['wts_key']]
VFA_model.load_state_dict(weights)
# Inference only: switch off training-mode behaviour of any layers that have it
# (a freshly constructed nn.Module is in training mode by default).
VFA_model.eval()

with torch.no_grad():
    # Residual deformation estimated on the affinely pre-aligned image.
    deformable_flow = VFA_model((deformed, template_torch))
    # Compose affine and deformable steps into one field: resample the affine
    # flow through the deformable flow, then add the deformable flow.
    flow = deformable_flow + spatial_trans(affine_flow, deformable_flow)
    # Warp the original image once with the composed field, so it is
    # interpolated a single time rather than twice.
    final_output = spatial_trans(img_torch, flow)

# deformable registration (TransMorphTVF)
# The network is configured for half the template's spatial size; the inputs
# below are average-pooled by 2 to match.
tm_config = configs_TransMorph.get_3DTransMorph3Lvl_config()
tm_config.img_size = tuple(s // 2 for s in template_torch.shape[2:])
tm_config.window_size = tuple(s // 64 for s in template_torch.shape[2:])
tm_config.out_chan = 3
TM_model = TransMorphTVF(tm_config, time_steps=7).cuda(0)
# map_location='cpu' makes loading independent of the device the checkpoint
# was saved from; load_state_dict then copies the tensors into the model's
# own (GPU) parameters.
tm_checkpoint = torch.load('pretrained_wts/TransMorphTVF_LUMIR24.pth.tar', map_location='cpu')
tm_weights = tm_checkpoint[ModelWeights['TransMorphTVF-LUMIR24-MonoModal']['wts_key']]
TM_model.load_state_dict(tm_weights)
# Inference only: switch off training-mode behaviour of any layers that have it
# (a freshly constructed nn.Module is in training mode by default), so the
# predicted flow is deterministic.
TM_model.eval()

with torch.no_grad():
    # Downsample the affinely pre-aligned image and the template by 2.
    mov_small = F.avg_pool3d(deformed, 2)
    fix_small = F.avg_pool3d(template_torch, 2)
    tm_flow = TM_model((mov_small, fix_small))
    # Resample the predicted field back onto the full template grid.
    # NOTE(review): interpolation resizes the grid but does not rescale the
    # values. If the model emits displacements in half-resolution voxel units,
    # they would need multiplying by 2 here -- confirm against TransMorphTVF's
    # output convention.
    tm_flow = F.interpolate(tm_flow, size=template_torch.shape[2:], mode='trilinear', align_corners=True)
    # Compose with the affine field, then warp the original image once.
    tm_flow = tm_flow + spatial_trans(affine_flow, tm_flow)
    tm_output = spatial_trans(img_torch, tm_flow)

# deformable registration (ConvexAdam-MIND)
# Run ConvexAdam-MIND with its default brain configuration on the affinely
# pre-aligned image and the template.
cvx_config = CONFIGS_CVXAdam.get_ConvexAdam_MIND_brain_default_config()
cvx_residual = convex_adam_MIND(deformed, template_torch, cvx_config)

# Compose with the affine step: resample the affine flow through the new
# field and add the two together.
affine_resampled = spatial_trans(affine_flow, cvx_residual)
cvx_flow = cvx_residual + affine_resampled

# Warp the original image once with the composed field.
cvx_output = spatial_trans(img_torch, cvx_flow)