Source code for MIR.models.convexAdam.convex_adam_MIND_SPR

'''
Wrapper for Convex Adam with MIND features
This code is based on the original Convex Adam code from:
Siebert, Hanna, et al. "ConvexAdam: Self-Configuring Dual-Optimisation-Based 3D Multitask Medical Image Registration." IEEE Transactions on Medical Imaging (2024).
https://github.com/multimodallearning/convexAdam

Modified and tested by:
Junyu Chen
jchen245@jhmi.edu
Johns Hopkins University
'''
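
# Example usage (a minimal sketch; the 96^3 image size is an illustrative assumption):
#
#   import torch
#   from MIR.models.convexAdam.convex_adam_MIND_SPR import convex_adam_pt
#
#   img_fixed = torch.rand(96, 96, 96)   # H x W x D fixed image
#   img_moving = torch.rand(96, 96, 96)  # H x W x D moving image
#   disp, spr_wts = convex_adam_pt(img_fixed, img_moving)
#   # disp:    (H, W, D, 3) numpy displacement field (save_disp=True by default)
#   # spr_wts: (1, 1, H, W, D) tensor of fitted spatial regularisation weights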

import argparse
import os
import time
import warnings
from pathlib import Path
from typing import Optional, Tuple, Union

import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt as edt
from MIR.models.convexAdam.convex_adam_utils import (MINDSSC, correlate, coupled_convex,
                                          inverse_consistency, validate_image)

warnings.filterwarnings("ignore")


def extract_features(
    img_fixed: torch.Tensor,
    img_moving: torch.Tensor,
    mind_r: int,
    mind_d: int,
    use_mask: bool,
    mask_fixed: torch.Tensor,
    mask_moving: torch.Tensor,
    device: torch.device = torch.device("cuda"),
    dtype: torch.dtype = torch.float16,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Extract MIND and/or semantic nnUNet features"""

    # MIND features
    if use_mask:
        H,W,D = img_fixed.shape[-3:]

        # 3x3x3 average pooling with replicate padding; thresholding its output at 0.9 slightly erodes the masks
        avg3 = nn.Sequential(nn.ReplicationPad3d(1),nn.AvgPool3d(3,stride=1))
        avg3.to(device)

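        # Fill voxels outside the fixed/moving mask with their nearest in-mask
        # intensity: a Euclidean distance transform on a half-resolution grid yields,
        # for every background voxel, the index of its closest in-mask voxel; the
        # filled image is upsampled again and the original intensities are restored
        # inside the mask. MIND features are then computed on the filled images.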
        mask = (avg3(mask_fixed.view(1,1,H,W,D).to(device))>0.9).float()
        _,idx = edt((mask[0,0,::2,::2,::2]==0).squeeze().cpu().numpy(),return_indices=True)
        fixed_r = F.interpolate((img_fixed[::2,::2,::2].to(device).reshape(-1)[idx[0]*D//2*W//2+idx[1]*D//2+idx[2]]).unsqueeze(0).unsqueeze(0),scale_factor=2,mode='trilinear')
        fixed_r.view(-1)[mask.view(-1)!=0] = img_fixed.to(device).reshape(-1)[mask.view(-1)!=0]

        mask = (avg3(mask_moving.view(1,1,H,W,D).to(device))>0.9).float()
        _,idx = edt((mask[0,0,::2,::2,::2]==0).squeeze().cpu().numpy(),return_indices=True)
        moving_r = F.interpolate((img_moving[::2,::2,::2].to(device).reshape(-1)[idx[0]*D//2*W//2+idx[1]*D//2+idx[2]]).unsqueeze(0).unsqueeze(0),scale_factor=2,mode='trilinear')
        moving_r.view(-1)[mask.view(-1)!=0] = img_moving.to(device).reshape(-1)[mask.view(-1)!=0]

        features_fix = MINDSSC(fixed_r.to(device),mind_r,mind_d,device=device).to(dtype)
        features_mov = MINDSSC(moving_r.to(device),mind_r,mind_d,device=device).to(dtype)
    else:
        img_fixed = img_fixed.unsqueeze(0).unsqueeze(0)
        img_moving = img_moving.unsqueeze(0).unsqueeze(0)
        features_fix = MINDSSC(img_fixed.to(device),mind_r,mind_d,device=device).to(dtype)
        features_mov = MINDSSC(img_moving.to(device),mind_r,mind_d,device=device).to(dtype)

    return features_fix, features_mov


def convex_adam_pt(
    img_fixed: Union[torch.Tensor, np.ndarray, nib.Nifti1Image],
    img_moving: Union[torch.Tensor, np.ndarray, nib.Nifti1Image],
    mind_r: int = 1,
    mind_d: int = 2,
    lambda_weight: float = 1.25,
    logBeta_weight: float = 0.1,
    grid_sp: int = 6,
    disp_hw: int = 4,
    selected_niter: int = 80,
    selected_smooth: int = 0,
    grid_sp_adam: int = 2,
    ic: bool = True,
    use_mask: bool = False,
    path_fixed_mask: Optional[Union[Path, str]] = None,
    path_moving_mask: Optional[Union[Path, str]] = None,
    dtype: torch.dtype = torch.float16,
    verbose: bool = False,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    save_disp: bool = True,
) -> Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]:
    """Coupled convex optimisation with Adam instance optimisation.

    Returns the displacement field (a numpy array when save_disp is True, otherwise a
    torch tensor) together with the fitted spatial regularisation (SPR) weight map.
    """
    img_fixed = validate_image(img_fixed)
    img_moving = validate_image(img_moving)
    img_fixed = img_fixed.float()
    img_moving = img_moving.float()
    
    if dtype == torch.float16 and device == torch.device("cpu"):
        print("Warning: float16 is not supported on CPU, using float32 instead")
        dtype = torch.float32

    if use_mask:
        mask_fixed = torch.from_numpy(nib.load(path_fixed_mask).get_fdata()).float()
        mask_moving = torch.from_numpy(nib.load(path_moving_mask).get_fdata()).float()
    else:
        mask_fixed = None
        mask_moving = None

    H, W, D = img_fixed.shape

    t0 = time.time()

    # compute features and downsample (using average pooling)
    with torch.no_grad():      
        features_fix, features_mov = extract_features(
            img_fixed=img_fixed,
            img_moving=img_moving,
            mind_r=mind_r,
            mind_d=mind_d,
            use_mask=use_mask,
            mask_fixed=mask_fixed,
            mask_moving=mask_moving,
            device=device,
            dtype=dtype,
        )

        features_fix_smooth = F.avg_pool3d(features_fix,grid_sp,stride=grid_sp)
        features_mov_smooth = F.avg_pool3d(features_mov,grid_sp,stride=grid_sp)

        n_ch = features_fix_smooth.shape[1]

    # compute correlation volume with SSD
    ssd,ssd_argmin = correlate(features_fix_smooth,features_mov_smooth,disp_hw,grid_sp,(H,W,D), n_ch)

    # auxiliary mesh grid enumerating the (2*disp_hw+1)^3 candidate displacements
    disp_mesh_t = F.affine_grid(disp_hw*torch.eye(3,4).to(device).to(dtype).unsqueeze(0),(1,1,disp_hw*2+1,disp_hw*2+1,disp_hw*2+1),align_corners=True).permute(0,4,1,2,3).reshape(3,-1,1)

    # perform coupled convex optimisation
    disp_soft = coupled_convex(ssd,ssd_argmin,disp_mesh_t,grid_sp,(H,W,D))

    # if "ic" flag is set: make inverse consistent
    if ic:
        scale = torch.tensor([H//grid_sp-1,W//grid_sp-1,D//grid_sp-1]).view(1,3,1,1,1).to(device).to(dtype)/2

        ssd_,ssd_argmin_ = correlate(features_mov_smooth,features_fix_smooth,disp_hw,grid_sp,(H,W,D), n_ch)

        disp_soft_ = coupled_convex(ssd_,ssd_argmin_,disp_mesh_t,grid_sp,(H,W,D))
        disp_ice,_ = inverse_consistency((disp_soft/scale).flip(1),(disp_soft_/scale).flip(1),iter=15)

        disp_hr = F.interpolate(disp_ice.flip(1)*scale*grid_sp,size=(H,W,D),mode='trilinear',align_corners=False)
    
    else:
        disp_hr=disp_soft

    # run Adam instance optimisation (skipped when lambda_weight == 0)
    spr_wts_hr = None
    if lambda_weight > 0:
        with torch.no_grad():
            patch_features_fix = F.avg_pool3d(features_fix,grid_sp_adam,stride=grid_sp_adam)
            patch_features_mov = F.avg_pool3d(features_mov,grid_sp_adam,stride=grid_sp_adam)

        #create optimisable displacement grid
        disp_lr = F.interpolate(disp_hr,size=(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),mode='trilinear',align_corners=False)

        net = nn.Sequential(nn.Conv3d(3,1,(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),bias=False))
        net[0].weight.data[:] = disp_lr.float().cpu().data/grid_sp_adam
        net.to(device)
        spr_wt = nn.Sequential(nn.Conv3d(1,1,(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),bias=False))
        spr_wt[0].weight.data[:] = spr_wt[0].weight.data[:]*0+1.0
        spr_wt.to(device)
        params = [{'params': net.parameters(), 'lr': 1}] + [{'params': spr_wt.parameters(), 'lr': 0.01}] 
        optimizer = torch.optim.Adam(params)

        grid0 = F.affine_grid(torch.eye(3,4).unsqueeze(0).to(device),(1,1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),align_corners=False)
        logBeta_wt = 1 + logBeta_weight
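        # With logBeta_wt = 1 + logBeta_weight, the log_beta term below reduces to
        # -logBeta_weight * mean(log(w)), penalising spatial regularisation weights
        # that collapse towards zero (w is clamped to [eps, 1] inside log_beta).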
        #run Adam optimisation with diffusion regularisation and B-spline smoothing
        for iter in range(selected_niter):
            optimizer.zero_grad()

            disp_sample = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(net[0].weight,3,stride=1,padding=1),3,stride=1,padding=1),3,stride=1,padding=1).permute(0,2,3,4,1)
            wt_sample = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(spr_wt[0].weight,3,stride=1,padding=1),3,stride=1,padding=1),3,stride=1,padding=1).permute(0,2,3,4,1)
            reg_loss = lambda_weight*local_grad_3d(disp_sample, wt_sample)
            reg_wt_loss = log_beta(wt_sample, logBeta_wt)
            reg_loss += reg_wt_loss
            
            scale = torch.tensor([(H//grid_sp_adam-1)/2,(W//grid_sp_adam-1)/2,(D//grid_sp_adam-1)/2]).to(device).unsqueeze(0)
            grid_disp = grid0.view(-1,3).to(device).float()+((disp_sample.view(-1,3))/scale).flip(1).float()

            patch_mov_sampled = F.grid_sample(patch_features_mov.float(),grid_disp.view(1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam,3).to(device),align_corners=False,mode='bilinear')

            sampled_cost = (patch_mov_sampled-patch_features_fix).pow(2).mean(1)*12
            loss = sampled_cost.mean()
            (loss+reg_loss).backward()
            optimizer.step()

        fitted_grid = disp_sample.detach().permute(0,4,1,2,3)
        spr_wts_fitted = wt_sample.detach().permute(0,4,1,2,3)
        spr_wts_hr = F.interpolate(spr_wts_fitted,size=(H,W,D),mode='trilinear',align_corners=False)
        disp_hr = F.interpolate(fitted_grid*grid_sp_adam,size=(H,W,D),mode='trilinear',align_corners=False)

        if selected_smooth > 0:
            if selected_smooth % 2 == 0:
                kernel_smooth = selected_smooth + 1
                print('selected_smooth should be an odd number, adding 1')
            else:
                kernel_smooth = selected_smooth
            padding_smooth = kernel_smooth//2
            disp_hr = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1)

    t1 = time.time()
    case_time = t1-t0
    if verbose:
        print(f'case time: {case_time}')
    if save_disp:
        x = disp_hr[0,0,:,:,:].cpu().to(dtype).data.numpy()
        y = disp_hr[0,1,:,:,:].cpu().to(dtype).data.numpy()
        z = disp_hr[0,2,:,:,:].cpu().to(dtype).data.numpy()
        displacements = np.stack((x,y,z),3).astype(float)
    else:
        displacements = disp_hr
    return displacements, spr_wts_hr

def log_beta(weight_volume, alpha, eps=1e-6):
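    """Prior on the spatially varying regularisation weights.

    Returns (1 - alpha) * mean(log(clamp(weight_volume, eps, 1))); with alpha > 1
    this penalises weights that drift towards zero.
    """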
    lambdas = torch.clamp(weight_volume, eps, 1.0)
    beta = torch.log(lambdas)
    return (1.-alpha)*beta.mean()

def local_grad_3d(flow, weight_volume):
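    """Spatially weighted first-order (diffusion) regulariser.

    Computes squared forward differences of the flow along each spatial axis,
    weights them with the corresponding entries of weight_volume, and averages
    the three directional terms.
    """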
    dy = torch.abs(flow[0, 1:, :, :] - flow[0, :-1, :, :])
    dx = torch.abs(flow[0, :, 1:, :] - flow[0, :, :-1, :])
    dz = torch.abs(flow[0, :, :, 1:] - flow[0, :, :, :-1])
    dy = dy.pow(2)
    dx = dx.pow(2)
    dz = dz.pow(2)
    d = torch.mean(dx*weight_volume[0, :, 1:, :])+torch.mean(dy*weight_volume[0, 1:, :, :])+torch.mean(dz*weight_volume[0, :, :, 1:])
    grad = d / 3.0
    return grad

def convex_adam(
    path_img_fixed: Union[Path, str],
    path_img_moving: Union[Path, str],
    mind_r: int = 1,
    mind_d: int = 2,
    lambda_weight: float = 1.25,
    logBeta_weight: float = 0.1,
    grid_sp: int = 6,
    disp_hw: int = 4,
    selected_niter: int = 80,
    selected_smooth: int = 0,
    grid_sp_adam: int = 2,
    ic: bool = True,
    use_mask: bool = False,
    path_fixed_mask: Optional[Union[Path, str]] = None,
    path_moving_mask: Optional[Union[Path, str]] = None,
    result_path: Union[Path, str] = './',
    verbose: bool = False,
) -> None:
    """Coupled convex optimisation with adam instance optimisation"""

    img_fixed = torch.from_numpy(nib.load(path_img_fixed).get_fdata()).float()
    img_moving = torch.from_numpy(nib.load(path_img_moving).get_fdata()).float()

    displacements, _ = convex_adam_pt(
        img_fixed=img_fixed,
        img_moving=img_moving,
        mind_r=mind_r,
        mind_d=mind_d,
        lambda_weight=lambda_weight,
        logBeta_weight=logBeta_weight,
        grid_sp=grid_sp,
        disp_hw=disp_hw,
        selected_niter=selected_niter,
        selected_smooth=selected_smooth,
        grid_sp_adam=grid_sp_adam,
        ic=ic,
        use_mask=use_mask,
        path_fixed_mask=path_fixed_mask,
        path_moving_mask=path_moving_mask,
        verbose=verbose,
    )

    affine = nib.load(path_img_fixed).affine
    disp_nii = nib.Nifti1Image(displacements, affine)
    nib.save(disp_nii, os.path.join(result_path,'disp.nii.gz'))


def convex_adam_MIND_SPR(
    img_moving,
    img_fixed,
    configs,
):
    """Coupled convex optimisation with Adam instance optimisation.

    Returns the displacement field and the fitted spatial regularisation (SPR) weights.
    """
    displacements, spr_wts = convex_adam_pt(
        img_fixed=img_fixed[0, 0],
        img_moving=img_moving[0, 0],
        mind_r=configs.mind_r,
        mind_d=configs.mind_d,
        lambda_weight=configs.lambda_weight,
        logBeta_weight=configs.logBeta_weight,
        grid_sp=configs.grid_sp,
        disp_hw=configs.disp_hw,
        selected_niter=configs.selected_niter,
        selected_smooth=configs.selected_smooth,
        grid_sp_adam=configs.grid_sp_adam,
        ic=configs.ic,
        use_mask=configs.use_mask,
        path_fixed_mask=configs.path_fixed_mask,
        path_moving_mask=configs.path_moving_mask,
        verbose=configs.verbose,
        save_disp=False,
    )
    return displacements, spr_wts
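
# A minimal sketch of a configs object for convex_adam_MIND_SPR (illustrative only;
# any object exposing these attributes works, the values shown are the module defaults):
#
#   from types import SimpleNamespace
#   configs = SimpleNamespace(
#       mind_r=1, mind_d=2, lambda_weight=1.25, logBeta_weight=0.1,
#       grid_sp=6, disp_hw=4, selected_niter=80, selected_smooth=0,
#       grid_sp_adam=2, ic=True, use_mask=False,
#       path_fixed_mask=None, path_moving_mask=None, verbose=False)
#   disp, spr_wts = convex_adam_MIND_SPR(img_moving, img_fixed, configs)
#   # img_moving / img_fixed are expected as 5D tensors (B, C, H, W, D);
#   # the wrapper registers img_moving[0, 0] to img_fixed[0, 0].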
if __name__=="__main__": parser = argparse.ArgumentParser() parser.add_argument("-f","--path_img_fixed", type=str, required=True) parser.add_argument("-m",'--path_img_moving', type=str, required=True) parser.add_argument('--mind_r', type=int, default=1) parser.add_argument('--mind_d', type=int, default=2) parser.add_argument('--lambda_weight', type=float, default=1.25) parser.add_argument('--grid_sp', type=int, default=6) parser.add_argument('--disp_hw', type=int, default=4) parser.add_argument('--selected_niter', type=int, default=80) parser.add_argument('--selected_smooth', type=int, default=0) parser.add_argument('--grid_sp_adam', type=int, default=2) parser.add_argument('--ic', choices=('True','False'), default='True') parser.add_argument('--use_mask', choices=('True','False'), default='False') parser.add_argument('--path_mask_fixed', type=str, default=None) parser.add_argument('--path_mask_moving', type=str, default=None) parser.add_argument('--result_path', type=str, default='./') args = parser.parse_args() convex_adam( path_img_fixed=args.path_img_fixed, path_img_moving=args.path_img_moving, mind_r=args.mind_r, mind_d=args.mind_d, lambda_weight=args.lambda_weight, logBeta_weight=args.logBeta_weight, grid_sp=args.grid_sp, disp_hw=args.disp_hw, selected_niter=args.selected_niter, selected_smooth=args.selected_smooth, grid_sp_adam=args.grid_sp_adam, ic=(args.ic == 'True'), use_mask=(args.use_mask == 'True'), path_fixed_mask=args.path_mask_fixed, path_moving_mask=args.path_mask_moving, result_path=args.result_path, )