# Source code for MIR.deformation_regularizer.GlobalRegularizers

"""Global regularizers for deformation regularization."""

from __future__ import annotations

import torch
import torch.nn.functional as nnf
import MIR.utils.registration_utils as reg_utils


def _apply_penalty(tensor: torch.Tensor, penalty: str) -> torch.Tensor:
    """Return ``tensor`` squared for ``'l2'``; otherwise return it unchanged.

    Any value other than ``'l2'`` (including ``'l1'``) falls through to the
    identity, so existing callers keep their legacy behavior.
    """
    return tensor.pow(2) if penalty == 'l2' else tensor


def _gradient_loss_2d(y_pred: torch.Tensor, penalty: str) -> torch.Tensor:
    """Average finite-difference gradient penalty over two spatial axes.

    ``y_pred`` is indexed as ``(B, C, H, W)``. Absolute forward differences
    along dims 2 and 3 are passed through ``_apply_penalty`` and averaged per
    axis; the two per-axis means are then averaged.
    """
    def axis_term(dim: int) -> torch.Tensor:
        # Leading full slices so that only ``dim`` is shifted.
        lead = (slice(None),) * dim
        delta = y_pred[lead + (slice(1, None),)] - y_pred[lead + (slice(None, -1),)]
        return torch.mean(_apply_penalty(torch.abs(delta), penalty))

    # Width term first, then height term, as in the legacy implementation.
    return (axis_term(3) + axis_term(2)) / 2.0


def _gradient_loss_3d(y_pred: torch.Tensor, penalty: str) -> torch.Tensor:
    """Average finite-difference gradient penalty over three spatial axes.

    ``y_pred`` is indexed as ``(B, C, H, W, D)``. Absolute forward
    differences along dims 2, 3 and 4 are passed through ``_apply_penalty``
    and averaged per axis; the three per-axis means are then averaged.
    """
    def axis_term(dim: int) -> torch.Tensor:
        # Leading full slices so that only ``dim`` is shifted.
        lead = (slice(None),) * dim
        delta = y_pred[lead + (slice(1, None),)] - y_pred[lead + (slice(None, -1),)]
        return torch.mean(_apply_penalty(torch.abs(delta), penalty))

    # Summation order (dim 3, dim 2, dim 4) mirrors the legacy implementation.
    return (axis_term(3) + axis_term(2) + axis_term(4)) / 3.0

class Grad2D(torch.nn.Module):
    """2D finite-difference gradient (smoothness) loss.

    Args:
        penalty: ``'l2'`` squares the absolute differences; any other value
            (default ``'l1'``) uses them as-is.
        loss_mult: Optional multiplier applied to the loss; ``None`` leaves it
            unscaled.
    """

    def __init__(self, penalty='l1', loss_mult=None):
        super().__init__()
        self.penalty = penalty
        self.loss_mult = loss_mult

    def forward(self, y_pred, y_true=None):
        """Compute 2D gradient regularization loss.

        Args:
            y_pred: Predicted displacement/field tensor (B, C, H, W).
            y_true: Unused placeholder for API compatibility. Optional, as in
                ``Grad3d.forward``, so the loss can be called with one argument.

        Returns:
            Scalar gradient penalty.
        """
        grad = _gradient_loss_2d(y_pred, self.penalty)
        if self.loss_mult is not None:
            grad *= self.loss_mult
        return grad


class Grad3d(torch.nn.Module):
    """Finite-difference gradient (smoothness) loss for 3D fields.

    Args:
        penalty: ``'l2'`` squares the absolute differences; any other value
            (default ``'l1'``) uses them directly.
        loss_mult: Optional multiplier applied to the loss; ``None`` leaves it
            unscaled.
    """

    def __init__(self, penalty='l1', loss_mult=None):
        super().__init__()
        self.penalty, self.loss_mult = penalty, loss_mult

    def forward(self, y_pred, y_true=None):
        """Return the (optionally scaled) mean gradient penalty of ``y_pred``.

        Args:
            y_pred: Predicted displacement/field tensor (B, C, H, W, D).
            y_true: Ignored; accepted so the module fits a
                ``loss(pred, target)`` calling convention.

        Returns:
            Scalar gradient penalty.
        """
        loss = _gradient_loss_3d(y_pred, self.penalty)
        if self.loss_mult is None:
            return loss
        loss *= self.loss_mult
        return loss


class Grad3DiTV(torch.nn.Module):
    """3D isotropic total-variation loss.

    Forward differences along dims 2, 3 and 4 are all evaluated on the common
    ``[1:, 1:, 1:]`` crop so they line up voxel-by-voxel, combined as
    ``sqrt(dx^2 + dy^2 + dz^2 + 1e-6)`` and averaged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true=None):
        """Compute isotropic total-variation loss in 3D.

        Args:
            y_pred: Predicted displacement/field tensor (B, C, H, W, D).
            y_true: Unused placeholder for API compatibility (optional, as in
                ``Grad3d.forward``).

        Returns:
            Scalar TV penalty: the mean gradient magnitude divided by 3.0
            (the division is retained from the original implementation so
            existing loss weights keep their meaning).
        """
        core = y_pred[:, :, 1:, 1:, 1:]
        dy = core - y_pred[:, :, :-1, 1:, 1:]
        dx = core - y_pred[:, :, 1:, :-1, 1:]
        dz = core - y_pred[:, :, 1:, 1:, :-1]
        # Squaring makes an abs() redundant; the 1e-6 keeps sqrt differentiable at 0.
        return torch.mean(torch.sqrt(dx.pow(2) + dy.pow(2) + dz.pow(2) + 1e-6)) / 3.0


class DisplacementRegularizer(torch.nn.Module):
    """Compute displacement-field regularization energies.

    Args:
        energy_type: One of ``'bending'``, ``'gradient-l2'`` or
            ``'gradient-l1'``. It is validated when :meth:`forward` runs.
    """

    def __init__(self, energy_type):
        super().__init__()
        self.energy_type = energy_type

    @staticmethod
    def gradient_dx(fv):
        """Central difference along dim 1 of ``fv``, cropping one voxel per side on dims 1-3."""
        return (fv[:, 2:, 1:-1, 1:-1] - fv[:, :-2, 1:-1, 1:-1]) / 2

    @staticmethod
    def gradient_dy(fv):
        """Central difference along dim 2 of ``fv``, cropping one voxel per side on dims 1-3."""
        return (fv[:, 1:-1, 2:, 1:-1] - fv[:, 1:-1, :-2, 1:-1]) / 2

    @staticmethod
    def gradient_dz(fv):
        """Central difference along dim 3 of ``fv``, cropping one voxel per side on dims 1-3."""
        return (fv[:, 1:-1, 1:-1, 2:] - fv[:, 1:-1, 1:-1, :-2]) / 2

    @staticmethod
    def gradient_txyz(txyz, gradient_fn):
        """Apply ``gradient_fn`` to channels 0, 1, 2 of ``txyz`` and restack them on dim 1."""
        return torch.stack([gradient_fn(txyz[:, axis, ...]) for axis in (0, 1, 2)], dim=1)

    def compute_gradient_norm(self, displacement, flag_l1=False):
        """Compute L1/L2 gradient norm of a displacement field.

        Args:
            displacement: Tensor (B, 3, H, W, D).
            flag_l1: If True, sums absolute derivatives; otherwise squared ones.

        Returns:
            Scalar: mean over all elements of the per-voxel sum, divided by 3.
        """
        dTdx = self.gradient_txyz(displacement, self.gradient_dx)
        dTdy = self.gradient_txyz(displacement, self.gradient_dy)
        dTdz = self.gradient_txyz(displacement, self.gradient_dz)
        if flag_l1:
            norms = torch.abs(dTdx) + torch.abs(dTdy) + torch.abs(dTdz)
        else:
            norms = dTdx.pow(2) + dTdy.pow(2) + dTdz.pow(2)
        return torch.mean(norms) / 3.0

    def compute_bending_energy(self, displacement):
        """Compute bending energy of a displacement field.

        The energy is the mean of the squared second derivatives, with the three
        mixed terms counted twice (the symmetric Hessian entries).

        Args:
            displacement: Tensor (B, 3, H, W, D).

        Returns:
            Scalar bending energy.
        """
        dTdx = self.gradient_txyz(displacement, self.gradient_dx)
        dTdy = self.gradient_txyz(displacement, self.gradient_dy)
        dTdz = self.gradient_txyz(displacement, self.gradient_dz)
        dTdxx = self.gradient_txyz(dTdx, self.gradient_dx)
        dTdyy = self.gradient_txyz(dTdy, self.gradient_dy)
        dTdzz = self.gradient_txyz(dTdz, self.gradient_dz)
        dTdxy = self.gradient_txyz(dTdx, self.gradient_dy)
        dTdyz = self.gradient_txyz(dTdy, self.gradient_dz)
        dTdxz = self.gradient_txyz(dTdx, self.gradient_dz)
        return torch.mean(
            dTdxx.pow(2) + dTdyy.pow(2) + dTdzz.pow(2)
            + 2 * dTdxy.pow(2) + 2 * dTdxz.pow(2) + 2 * dTdyz.pow(2)
        )

    def forward(self, disp, _=None):
        """Return the configured regularization energy of ``disp``.

        Args:
            disp: Displacement tensor (B, 3, H, W, D).
            _: Unused second argument, accepted for a ``loss(pred, target)``
                calling convention; optional.

        Returns:
            Scalar energy.

        Raises:
            ValueError: If ``energy_type`` is not ``'bending'``,
                ``'gradient-l2'`` or ``'gradient-l1'``.
        """
        if self.energy_type == 'bending':
            return self.compute_bending_energy(disp)
        if self.energy_type == 'gradient-l2':
            return self.compute_gradient_norm(disp)
        if self.energy_type == 'gradient-l1':
            return self.compute_gradient_norm(disp, flag_l1=True)
        # ValueError subclasses Exception, so legacy ``except Exception`` handlers still work.
        raise ValueError('Not recognised local regulariser!')


class GradICON3d(torch.nn.Module):
    """Gradient-ICON style penalty on composed 3D displacement fields.

    The two flows are composed through ``reg_utils.SpatialTransformer`` and
    the finite-difference gradient of the composed displacement is penalised
    with ``_gradient_loss_3d``.
    """

    def __init__(self, flow_shape, penalty='l2', loss_mult=None,
                 both_dirs=False, device='cpu'):
        """Build the penalty.

        Args:
            flow_shape: Spatial shape passed to ``reg_utils.SpatialTransformer``.
            penalty: ``'l1'`` or ``'l2'``.
            loss_mult: Optional multiplier applied to the final loss.
            both_dirs: If True, also penalise the reverse composition and
                average the two terms.
            device: Device the spatial transformer is moved to.

        Raises:
            ValueError: If ``penalty`` is not ``'l1'`` or ``'l2'``.
        """
        super().__init__()
        if penalty not in ('l1', 'l2'):
            raise ValueError("penalty must be 'l1' or 'l2'")
        self.stn = reg_utils.SpatialTransformer(flow_shape).to(device)
        self.penalty = penalty
        self.loss_mult = loss_mult
        self.both_dirs = both_dirs

    @staticmethod
    def _grad3d(disp, penalty):
        """Finite-difference gradient penalty of a single displacement field."""
        return _gradient_loss_3d(disp, penalty)

    def forward(self, flow_fwd, flow_inv):
        """Return the GradICON penalty for a forward/inverse flow pair.

        Args:
            flow_fwd: Forward displacement field.
            flow_inv: Inverse displacement field.

        Returns:
            Scalar penalty, averaged over both compositions when ``both_dirs``
            is set and scaled by ``loss_mult`` when given.
        """
        def composed_penalty(outer, inner):
            # ``inner`` plus ``outer`` resampled by ``inner``; presumably
            # stn(src, flow) warps src by flow -- TODO confirm in reg_utils.
            return self._grad3d(inner + self.stn(outer, inner), self.penalty)

        loss = composed_penalty(flow_fwd, flow_inv)
        if self.both_dirs:
            loss = 0.5 * (loss + composed_penalty(flow_inv, flow_fwd))
        if self.loss_mult is not None:
            loss *= self.loss_mult
        return loss


class GradICONExact3d(torch.nn.Module):
    """Gradient-ICON penalty using sampled finite-difference Jacobians (3D).

    The forward and inverse displacements are composed, the Jacobian of the
    composed map ``phi(x) = x + u(x)`` is estimated by forward finite
    differences at a random subset (about 1/8) of voxels, and its deviation
    from the identity is penalised by a squared (``'l2'``) or absolute
    (``'l1'``) elementwise sum, averaged over samples.

    Args:
        vol_shape: Spatial size ``(D, H, W)`` of the displacement fields.
        penalty: ``'l1'`` or ``'l2'``.
        both_dirs: If True, also penalise the reverse composition and average.
        device: Device on which the transformer and sampling grid are created.

    Raises:
        ValueError: If ``penalty`` is not ``'l1'`` or ``'l2'``.
    """

    def __init__(self, vol_shape, penalty='l2', both_dirs=False, device='cpu'):
        super().__init__()
        if penalty not in ('l1', 'l2'):
            raise ValueError("penalty must be 'l1' or 'l2'")
        self.D, self.H, self.W = vol_shape
        self.penalty = penalty
        self.both_dirs = both_dirs
        self.device = device
        self.stn = reg_utils.SpatialTransformer(vol_shape).to(device)
        # Finite-difference step per point coordinate: 1e-3 of each axis'
        # extent. Sample points are (x, y, z) = (W, H, D) voxel indices, so the
        # steps must use that same order (not the (D, H, W) order of vol_shape).
        # Non-persistent buffers follow .to()/.cuda() without touching state_dict.
        steps = torch.tensor(
            [self.W - 1, self.H - 1, self.D - 1],
            dtype=torch.float32,
            device=device,
        ) * 1e-3
        self.register_buffer('dx_vox', steps, persistent=False)
        z = torch.linspace(0, self.D - 1, self.D, device=device)
        y = torch.linspace(0, self.H - 1, self.H, device=device)
        x = torch.linspace(0, self.W - 1, self.W, device=device)
        zz, yy, xx = torch.meshgrid(z, y, x, indexing='ij')
        grid = torch.stack([xx, yy, zz], dim=-1).view(-1, 3)
        self.register_buffer('grid_full', grid, persistent=False)
        # At least one sample, so tiny volumes do not produce an empty (NaN) mean.
        self.Nsub = max(1, grid.size(0) // 8)
        self.register_buffer('eye3', torch.eye(3))

    @staticmethod
    def _fro(diff, penalty):
        """Sum |diff| ('l1') or diff**2 (otherwise) over the last two dims."""
        if penalty == 'l1':
            return diff.abs().sum((-2, -1))
        return diff.pow(2).sum((-2, -1))

    def _compose_disp(self, flow_fwd, flow_inv):
        """Return the displacement of the composed transform on the voxel grid."""
        return flow_inv + self.stn(flow_fwd, flow_inv)

    def _jacobian_samples(self, disp, pts_vox):
        """Estimate the Jacobian of ``x + disp(x)`` at the given voxel points.

        Args:
            disp: Displacement field; must have 3 channels (``view(B, 3, N)``
                below), assumed ``(B, 3, D, H, W)``.
            pts_vox: ``(B, N, 3)`` voxel coordinates in ``(x, y, z)`` order,
                i.e. indices along ``(W, H, D)``.

        Returns:
            ``(B, N, 3, 3)`` tensor ``J`` with ``J[..., i, j] = d phi_i / d x_j``.
        """
        B, N, _ = pts_vox.shape
        extents = (self.W - 1, self.H - 1, self.D - 1)

        def sample(f, p):
            # Voxel index -> [-1, 1] with voxel 0 -> -1 and voxel size-1 -> +1.
            # That is the align_corners=True convention (also used by the
            # voxel-grid spatial transformer); align_corners=False would
            # stretch each axis by size/(size-1) and bias the Jacobian.
            p_norm = p.clone()
            for axis, extent in enumerate(extents):
                p_norm[..., axis] = 2 * (p[..., axis] / extent - 0.5)
            g = p_norm.view(B, N, 1, 1, 3)
            v = nnf.grid_sample(
                f, g,
                align_corners=True,
                mode='bilinear',
                padding_mode='border',
            )
            return v.view(B, 3, N).permute(0, 2, 1)

        # NOTE(review): channel c of ``disp`` is added to point coordinate c,
        # and points are (W, H, D)-ordered. If ``disp`` stores components in
        # (D, H, W) channel order, this permutes the rows of J - I. The
        # 'l1'/'l2' sums in _fro are invariant to that, but confirm before
        # reusing J elsewhere (e.g. for determinants).
        steps = self.dx_vox.to(device=pts_vox.device, dtype=pts_vox.dtype)
        x = pts_vox
        phi_x = x + sample(disp, x)
        grads = []
        for axis in range(3):
            dx = steps[axis]
            x_shift = x.clone()
            x_shift[..., axis] += dx
            phi_shift = x_shift + sample(disp, x_shift)
            grads.append((phi_shift - phi_x) / dx)
        return torch.stack(grads, dim=-1)

    def forward(self, flow_fwd, flow_inv):
        """Compute the exact GradICON penalty for sampled voxel locations.

        Args:
            flow_fwd: Forward displacement field.
            flow_inv: Inverse displacement field.

        Returns:
            Scalar penalty (averaged over both compositions if ``both_dirs``).
        """
        B = flow_fwd.size(0)
        # Sample indices on the grid's own device (it follows module.to()).
        idx = torch.randperm(self.grid_full.size(0), device=self.grid_full.device)[:self.Nsub]
        pts0 = self.grid_full[idx].unsqueeze(0).repeat(B, 1, 1).to(flow_fwd.device)
        comp_disp = self._compose_disp(flow_fwd, flow_inv)
        J = self._jacobian_samples(comp_disp, pts0)
        # Move the identity once and reuse it for both directions.
        eye = self.eye3.to(device=J.device, dtype=J.dtype)
        loss = self._fro(J - eye, self.penalty).mean()
        if self.both_dirs:
            comp_disp_b = self._compose_disp(flow_inv, flow_fwd)
            J_b = self._jacobian_samples(comp_disp_b, pts0)
            loss = 0.5 * (loss + self._fro(J_b - eye, self.penalty).mean())
        return loss