Brain registration (affine + deformable)
Follow the brain-registration tutorial notebook at tutorials/brain_registration/deformable_registration.ipynb to run an affine pre-alignment followed by deformable registration using VFA, TransMorphTVF, or ConvexAdam-MIND.
cd tutorials/brain_registration
from MIR.models import SpatialTransformer, VFA, AffineReg3D, TransMorphTVF, convex_adam_MIND
import nibabel as nib
import torch
from MIR import ModelWeights
import MIR.models.configs_VFA as CONFIGS_VFA
import MIR.models.configs_TransMorph as configs_TransMorph
import MIR.models.convexAdam.configs_ConvexAdam_MIND as CONFIGS_CVXAdam
import torch.nn.functional as F
# --- load the moving scan and the fixed atlas template ---
moving_path = 'sub-01_T1w.nii.gz'
template_path = 'LUMIR_template.nii.gz'
img_nib = nib.load(moving_path)
template_nib = nib.load(template_path)
# --- affine pre-alignment of the moving image to the template ---
# Move both volumes onto GPU 0 as (1, 1, D, H, W) float tensors.
img_torch = torch.from_numpy(img_nib.get_fdata()[None, None, ...]).float().to('cuda:0')
template_torch = torch.from_numpy(template_nib.get_fdata()[None, None, ...]).float().to('cuda:0')
vol_shape = template_torch.shape[2:]
# Warping module shared by every registration variant below.
spatial_trans = SpatialTransformer(size=vol_shape, mode='bilinear').to('cuda:0')
# Optimize a full affine transform over two scales, 50 steps each.
affine_model = AffineReg3D(vol_shape=vol_shape, dof="affine").to('cuda:0')
output = affine_model.optimize(img_torch, template_torch, steps_per_scale=(50, 50), verbose=True)
affine_flow = output['flow']
deformed = output['warped']
# deformable registration (VFA)
# Build VFA at full template resolution and load the LUMIR24 pretrained weights.
config = CONFIGS_VFA.get_VFA_default_config()
config.img_size = template_torch.shape[2:]
VFA_model = VFA(config, device='cuda:0').cuda(0)  # .cuda(0) for consistency with the rest of the script
weights = torch.load('pretrained_wts/VFA_LUMIR24.pth')[ModelWeights['VFA-LUMIR24-MonoModal']['wts_key']]
VFA_model.load_state_dict(weights)
with torch.no_grad():  # inference only — no autograd bookkeeping needed
    # Register the affinely pre-aligned image to the template.
    deformable_flow = VFA_model((deformed, template_torch))
    # Compose displacement fields: warp the affine flow by the deformable flow and
    # add, so a single resampling maps the ORIGINAL image into template space.
    flow = deformable_flow + spatial_trans(affine_flow, deformable_flow)
    final_output = spatial_trans(img_torch, flow)
# deformable registration (TransMorphTVF)
# TransMorphTVF runs at half resolution; window size is derived from the volume size.
tm_config = configs_TransMorph.get_3DTransMorph3Lvl_config()
tm_config.img_size = tuple(s // 2 for s in template_torch.shape[2:])
tm_config.window_size = tuple(s // 64 for s in template_torch.shape[2:])
tm_config.out_chan = 3
TM_model = TransMorphTVF(tm_config, time_steps=7).cuda(0)
tm_weights = torch.load('pretrained_wts/TransMorphTVF_LUMIR24.pth.tar')[ModelWeights['TransMorphTVF-LUMIR24-MonoModal']['wts_key']]
TM_model.load_state_dict(tm_weights)
with torch.no_grad():  # inference only
    # Downsample inputs 2x to the model's working resolution.
    mov_small = F.avg_pool3d(deformed, 2)
    fix_small = F.avg_pool3d(template_torch, 2)
    tm_flow = TM_model((mov_small, fix_small))
    # Upsample the half-resolution flow back to full template resolution.
    # NOTE(review): if the displacement magnitudes are in half-resolution voxel
    # units, the flow may also need scaling by 2 here — confirm MIR's convention.
    tm_flow = F.interpolate(tm_flow, size=template_torch.shape[2:], mode='trilinear', align_corners=True)
    # Compose with the affine pre-alignment so one resampling of the original suffices.
    tm_flow = tm_flow + spatial_trans(affine_flow, tm_flow)
    tm_output = spatial_trans(img_torch, tm_flow)
# deformable registration (ConvexAdam-MIND)
# Optimization-based method driven by a config — no pretrained weights are loaded.
cvx_config = CONFIGS_CVXAdam.get_ConvexAdam_MIND_brain_default_config()
cvx_flow = convex_adam_MIND(deformed, template_torch, cvx_config)
# Fold the affine pre-alignment into the deformable field, then warp the
# original image with the composed flow in a single resampling pass.
cvx_flow = spatial_trans(affine_flow, cvx_flow) + cvx_flow
cvx_output = spatial_trans(img_torch, cvx_flow)