# Source code for MIR.deformation_regularizer.GlobalRegularizers

'''
Global regularizers for deformation regularization.
Modified and tested by:
Junyu Chen
jchen245@jhmi.edu
Johns Hopkins University
'''

import torch
import MIR.utils.registration_utils as reg_utils
import torch.nn.functional as nnf

class Grad2D(torch.nn.Module):
    """2D finite-difference gradient loss.

    Args:
        penalty: ``'l2'`` squares the absolute forward differences; any other
            value (default ``'l1'``) uses the absolute differences directly.
        loss_mult: Optional scalar multiplier applied to the final loss.
    """

    def __init__(self, penalty='l1', loss_mult=None):
        super(Grad2D, self).__init__()
        self.penalty = penalty
        self.loss_mult = loss_mult

    def forward(self, y_pred, y_true=None):
        """Compute 2D gradient regularization loss.

        Args:
            y_pred: Predicted displacement/field tensor (B, C, H, W).
            y_true: Unused placeholder for API compatibility. Optional, so the
                call signature matches ``Grad3d.forward``.

        Returns:
            Scalar gradient penalty: the average over the two axes of the
            mean (absolute or squared) forward difference, times
            ``loss_mult`` if given.
        """
        dy = torch.abs(y_pred[:, :, 1:, :] - y_pred[:, :, :-1, :])
        dx = torch.abs(y_pred[:, :, :, 1:] - y_pred[:, :, :, :-1])
        if self.penalty == 'l2':
            dy = dy * dy
            dx = dx * dx
        d = torch.mean(dx) + torch.mean(dy)
        grad = d / 2.0
        if self.loss_mult is not None:
            # Out-of-place multiply keeps the autograd graph free of in-place ops.
            grad = grad * self.loss_mult
        return grad
class Grad3d(torch.nn.Module):
    """3D finite-difference gradient loss.

    Args:
        penalty: ``'l2'`` squares the absolute forward differences; any other
            value (default ``'l1'``) keeps them as absolute values.
        loss_mult: Optional scalar multiplier applied to the result.
    """

    def __init__(self, penalty='l1', loss_mult=None):
        super().__init__()
        self.penalty = penalty
        self.loss_mult = loss_mult

    def forward(self, y_pred, y_true=None):
        """Compute 3D gradient regularization loss.

        Args:
            y_pred: Predicted displacement/field tensor (B, C, H, W, D).
            y_true: Unused placeholder for API compatibility.

        Returns:
            Scalar: mean forward difference per spatial axis, averaged over
            the three axes, times ``loss_mult`` if given.
        """
        # Forward differences along spatial dims 2, 3 and 4.
        diff_h = torch.diff(y_pred, dim=2).abs()
        diff_w = torch.diff(y_pred, dim=3).abs()
        diff_d = torch.diff(y_pred, dim=4).abs()
        if self.penalty == 'l2':
            diff_h, diff_w, diff_d = (t * t for t in (diff_h, diff_w, diff_d))
        total = diff_w.mean() + diff_h.mean() + diff_d.mean()
        loss = total / 3.0
        return loss if self.loss_mult is None else loss * self.loss_mult
class Grad3DiTV(torch.nn.Module):
    """3D isotropic total-variation (TV) loss."""

    def __init__(self):
        super(Grad3DiTV, self).__init__()

    def forward(self, y_pred, y_true=None):
        """Compute isotropic total-variation loss in 3D.

        Args:
            y_pred: Predicted displacement/field tensor (B, C, H, W, D).
            y_true: Unused placeholder for API compatibility. Optional, so the
                call signature matches ``Grad3d.forward``.

        Returns:
            Scalar TV penalty: mean over channels and voxels of
            ``sqrt(dx^2 + dy^2 + dz^2 + 1e-6)``, divided by 3.
        """
        # All three backward differences are taken on the common
        # [1:, 1:, 1:] sub-grid so the per-voxel magnitudes line up.
        center = y_pred[:, :, 1:, 1:, 1:]
        dy = center - y_pred[:, :, :-1, 1:, 1:]
        dx = center - y_pred[:, :, 1:, :-1, 1:]
        dz = center - y_pred[:, :, 1:, 1:, :-1]
        # The 1e-6 keeps sqrt differentiable where the local gradient vanishes.
        d = torch.mean(torch.sqrt(dx * dx + dy * dy + dz * dz + 1e-6))
        # NOTE(review): the extra /3 mirrors Grad3d's per-axis averaging and is
        # kept for backward compatibility of loss weights.
        return d / 3.0
class DisplacementRegularizer(torch.nn.Module):
    """Compute displacement-field regularization energies.

    Args:
        energy_type: One of ``'bending'``, ``'gradient-l2'`` or
            ``'gradient-l1'``. The value is checked when ``forward`` runs.
    """

    def __init__(self, energy_type):
        super().__init__()
        self.energy_type = energy_type

    def gradient_dx(self, fv):
        """Central difference along dim 1 of a (B, X, Y, Z) tensor.

        The result is cropped by one voxel on every spatial border.
        """
        return (fv[:, 2:, 1:-1, 1:-1] - fv[:, :-2, 1:-1, 1:-1]) / 2

    def gradient_dy(self, fv):
        """Central difference along dim 2 of a (B, X, Y, Z) tensor, border-cropped."""
        return (fv[:, 1:-1, 2:, 1:-1] - fv[:, 1:-1, :-2, 1:-1]) / 2

    def gradient_dz(self, fv):
        """Central difference along dim 3 of a (B, X, Y, Z) tensor, border-cropped."""
        return (fv[:, 1:-1, 1:-1, 2:] - fv[:, 1:-1, 1:-1, :-2]) / 2

    def gradient_txyz(self, Txyz, fn):
        """Apply a per-component difference operator to the 3 displacement channels.

        Args:
            Txyz: Tensor (B, 3, X, Y, Z).
            fn: One of ``gradient_dx``/``gradient_dy``/``gradient_dz``.

        Returns:
            Tensor (B, 3, X-2, Y-2, Z-2).
        """
        return torch.stack([fn(Txyz[:, i, ...]) for i in [0, 1, 2]], dim=1)

    def compute_gradient_norm(self, displacement, flag_l1=False):
        """Compute L1/L2 gradient norm of a displacement field.

        Args:
            displacement: Tensor (B, 3, H, W, D); each spatial dim must be >= 3.
            flag_l1: If True, uses L1 norm; otherwise squared L2.

        Returns:
            Scalar: mean over batch, channels and interior voxels of the summed
            per-axis norms, divided by 3.
        """
        dTdx = self.gradient_txyz(displacement, self.gradient_dx)
        dTdy = self.gradient_txyz(displacement, self.gradient_dy)
        dTdz = self.gradient_txyz(displacement, self.gradient_dz)
        if flag_l1:
            norms = torch.abs(dTdx) + torch.abs(dTdy) + torch.abs(dTdz)
        else:
            norms = dTdx ** 2 + dTdy ** 2 + dTdz ** 2
        return torch.mean(norms) / 3.0

    def compute_bending_energy(self, displacement):
        """Compute bending energy of a displacement field.

        Args:
            displacement: Tensor (B, 3, H, W, D); each spatial dim must be >= 5
                because second derivatives crop two voxels per border.

        Returns:
            Scalar mean of the squared second derivatives, with the mixed
            terms counted twice (symmetric Hessian).
        """
        dTdx = self.gradient_txyz(displacement, self.gradient_dx)
        dTdy = self.gradient_txyz(displacement, self.gradient_dy)
        dTdz = self.gradient_txyz(displacement, self.gradient_dz)
        dTdxx = self.gradient_txyz(dTdx, self.gradient_dx)
        dTdyy = self.gradient_txyz(dTdy, self.gradient_dy)
        dTdzz = self.gradient_txyz(dTdz, self.gradient_dz)
        dTdxy = self.gradient_txyz(dTdx, self.gradient_dy)
        dTdyz = self.gradient_txyz(dTdy, self.gradient_dz)
        dTdxz = self.gradient_txyz(dTdx, self.gradient_dz)
        return torch.mean(
            dTdxx ** 2 + dTdyy ** 2 + dTdzz ** 2
            + 2 * dTdxy ** 2 + 2 * dTdxz ** 2 + 2 * dTdyz ** 2
        )

    def forward(self, disp, _=None):
        """Compute the configured regularization energy.

        Args:
            disp: Displacement tensor (B, 3, H, W, D).
            _: Unused placeholder for API compatibility.

        Returns:
            Scalar energy.

        Raises:
            ValueError: If ``energy_type`` is not recognised.
        """
        if self.energy_type == 'bending':
            energy = self.compute_bending_energy(disp)
        elif self.energy_type == 'gradient-l2':
            energy = self.compute_gradient_norm(disp)
        elif self.energy_type == 'gradient-l1':
            energy = self.compute_gradient_norm(disp, flag_l1=True)
        else:
            raise ValueError('Not recognised local regulariser!')
        return energy
class GradICON3d(torch.nn.Module):
    """Gradient-ICON loss for 3-D displacement fields.

    Penalises the spatial gradient of the displacement of the composition
    Φ^{AB}∘Φ^{BA}, i.e. how far its Jacobian is from the identity. The
    finite-difference penalty is the same one used by ``Grad3d``.
    """

    def __init__(self, flow_shape, penalty='l2', loss_mult=None,
                 both_dirs=False, device='cpu'):
        """
        Args:
            flow_shape: Spatial shape passed to
                ``reg_utils.SpatialTransformer``, which warps one displacement
                field by another.
            penalty: 'l1' or 'l2'.
            loss_mult: Optional scalar multiplier.
            both_dirs: If True, also penalise the reverse composition
                Φ^{BA}∘Φ^{AB} and average the two losses.
            device: Device the spatial transformer is moved to.

        Raises:
            ValueError: If ``penalty`` is not 'l1' or 'l2'.
        """
        super().__init__()
        if penalty not in ('l1', 'l2'):
            raise ValueError("penalty must be 'l1' or 'l2'")
        self.stn = reg_utils.SpatialTransformer(flow_shape).to(device)
        self.penalty = penalty
        self.loss_mult = loss_mult
        self.both_dirs = both_dirs

    @staticmethod
    def _grad3d(disp, p):
        """Finite-difference gradient penalty for one (B, C, H, W, D) field.

        The penalty is the mean absolute (``p='l1'``) or squared (``p='l2'``)
        forward difference along each spatial axis, averaged over the axes.
        """
        dy = torch.abs(disp[:, :, 1:, :, :] - disp[:, :, :-1, :, :])
        dx = torch.abs(disp[:, :, :, 1:, :] - disp[:, :, :, :-1, :])
        dz = torch.abs(disp[:, :, :, :, 1:] - disp[:, :, :, :, :-1])
        if p == 'l2':
            dy = dy * dy
            dx = dx * dx
            dz = dz * dz
        return (dx.mean() + dy.mean() + dz.mean()) / 3.0

    def forward(self, flow_fwd, flow_inv):
        """Compute the GradICON penalty.

        Args:
            flow_fwd: Forward displacement field (B, 3, H, W, D).
            flow_inv: Inverse displacement field, same shape.

        Returns:
            Scalar GradICON penalty (times ``loss_mult`` if given).
        """
        # Displacement of Φ^{AB}∘Φ^{BA} − Id: u_BA(x) + u_AB(x + u_BA(x)).
        # Assumes self.stn(src, flow) resamples src at x + flow(x).
        comp_f = flow_inv + self.stn(flow_fwd, flow_inv)
        loss = self._grad3d(comp_f, self.penalty)
        if self.both_dirs:
            # Displacement of Φ^{BA}∘Φ^{AB} − Id.
            comp_b = flow_fwd + self.stn(flow_inv, flow_fwd)
            loss = 0.5 * (loss + self._grad3d(comp_b, self.penalty))
        if self.loss_mult is not None:
            loss = loss * self.loss_mult
        return loss
class GradICONExact3d(torch.nn.Module):
    """Paper-faithful Gradient-ICON for 3-D flows.

    The Jacobian of Φ_AB∘Φ_BA is estimated with forward finite differences at
    a random subset of voxel centres. Its deviation from the identity is then
    penalised.
    """

    def __init__(self, vol_shape, penalty='l2', both_dirs=False, device='cpu'):
        """
        Args:
            vol_shape: Spatial shape (D, H, W) of the displacement fields.
            penalty: 'l1' (entry-wise absolute sum of J − I) or 'l2'
                (squared Frobenius norm of J − I).
            both_dirs: If True, also penalise Φ_BA∘Φ_AB and average.
            device: Device for the spatial transformer and sampling buffers.

        Raises:
            ValueError: If ``penalty`` is not 'l1' or 'l2'.
        """
        super().__init__()
        if penalty not in ('l1', 'l2'):
            raise ValueError("penalty must be 'l1' or 'l2'")
        self.penalty = penalty
        self.both_dirs = both_dirs
        self.device = device
        # unchanged voxelmorph ST-N
        self.stn = reg_utils.SpatialTransformer(vol_shape).to(device)
        self._init_geometry(vol_shape, device)

    def _init_geometry(self, vol_shape, device):
        """Create the voxel-grid buffers used to sample Jacobians."""
        self.D, self.H, self.W = vol_shape
        # Voxel extent per coordinate in (x, y, z) = (W, H, D) order. This is
        # the same order as the last axis of grid_full and of the sample
        # points, so axis i of a point is scaled by the matching extent.
        extent = torch.tensor([self.W - 1, self.H - 1, self.D - 1],
                              dtype=torch.float32, device=device)
        # Non-persistent buffers follow .to()/.cuda() without changing state_dict.
        self.register_buffer('vox_extent', extent, persistent=False)
        # Δx = 1e-3 in unit-cube coords → voxel step along each of x, y, z.
        self.register_buffer('dx_vox', extent * 1e-3, persistent=False)
        # Full identity grid in voxel coords, flattened to (N, 3) in (x, y, z) order.
        z = torch.linspace(0, self.D - 1, self.D, device=device)
        y = torch.linspace(0, self.H - 1, self.H, device=device)
        x = torch.linspace(0, self.W - 1, self.W, device=device)
        zz, yy, xx = torch.meshgrid(z, y, x, indexing='ij')
        grid = torch.stack([xx, yy, zz], dim=-1).view(-1, 3)
        self.register_buffer('grid_full', grid, persistent=False)
        # Sample 1/8 of the voxels (vox/2³), at least one so the mean is defined.
        self.Nsub = max(1, grid.size(0) // 8)
        self.register_buffer('eye3', torch.eye(3, device=device))

    # ---------------------------------------------------------------- utilities
    @staticmethod
    def _fro(diff, p):
        """Per-matrix penalty over the last two dims of ``diff``."""
        if p == 'l1':
            return diff.abs().sum((-2, -1))  # entry-wise L1 norm
        return diff.pow(2).sum((-2, -1))  # squared Frobenius norm ||·||_F²

    # ---------------------------------------------------------------- helpers
    def _compose_disp(self, flow_fwd, flow_inv):
        """Return the displacement of Φ_AB∘Φ_BA on the voxel grid (B,3,D,H,W).

        Assumes ``self.stn(src, flow)`` resamples ``src`` at x + flow(x).
        """
        return flow_inv + self.stn(flow_fwd, flow_inv)

    def _jacobian_samples(self, disp, pts_vox):
        """Finite-difference Jacobian of Φ(x) = x + disp(x) at sample points.

        Args:
            disp: (B,3,D,H,W) displacement field in voxel units.
            pts_vox: (B,N,3) sample points in voxel coords, (x, y, z) order.

        Returns:
            (B,N,3,3) tensor with J[..., i, j] = ∂Φ_i/∂x_j.

        NOTE(review): the displacement channels are added to the (x, y, z)
        coordinates in channel order 0, 1, 2. Under the voxelmorph
        convention the channels are ordered (D, H, W), i.e. reversed. That
        only permutes the rows of J − I, which leaves both penalties in
        ``_fro`` unchanged.
        """
        B, N, _ = pts_vox.shape

        def sample(f, p):
            # Voxel coords → [-1, 1]. With align_corners=True, -1 and 1 are the
            # centres of the first and last voxels, which is exactly this linear map.
            p_norm = 2.0 * p / self.vox_extent - 1.0
            g = p_norm.reshape(B, N, 1, 1, 3).to(f.dtype)
            v = nnf.grid_sample(f, g, align_corners=True, mode='bilinear',
                                padding_mode='border')
            return v.view(B, 3, N).permute(0, 2, 1)  # B×N×3

        u_x = sample(disp, pts_vox)
        grads = []
        # (Φ(x+Δx e_i) − Φ(x)) / Δx = e_i + (u(x+Δx e_i) − u(x)) / Δx. The
        # identity part is added exactly afterwards, which avoids
        # cancellation in (x+Δx) − x.
        for axis, dx in enumerate(self.dx_vox):
            x_shift = pts_vox.clone()
            x_shift[..., axis] += dx
            grads.append((sample(disp, x_shift) - u_x) / dx)
        grad_u = torch.stack(grads, dim=-1)  # B×N×3×3
        eye = torch.eye(3, dtype=grad_u.dtype, device=grad_u.device)
        return eye + grad_u

    # ---------------------------------------------------------------- forward
    def forward(self, flow_fwd, flow_inv):
        """Compute the Gradient-ICON penalty.

        Args:
            flow_fwd, flow_inv: (B,3,D,H,W) voxel-unit displacements.

        Returns:
            Scalar mean of ``_fro(J − I)`` over the batch and sampled points.
            With ``both_dirs`` it is averaged with the reverse composition.
        """
        B = flow_fwd.size(0)
        # Uniform random sub-sample of voxel centres.
        n_total = self.grid_full.size(0)
        idx = torch.randperm(n_total, device=self.grid_full.device)[:self.Nsub]
        pts0 = self.grid_full[idx].unsqueeze(0).repeat(B, 1, 1)  # B×Ns×3
        pts0 = pts0.to(flow_fwd.device)

        # Φ_AB∘Φ_BA displacement field and its Jacobian at the sampled points.
        comp_disp = self._compose_disp(flow_fwd, flow_inv)
        J = self._jacobian_samples(comp_disp, pts0)
        eye = self.eye3.to(device=J.device, dtype=J.dtype)
        loss = self._fro(J - eye, self.penalty).mean()  # ∇Φ − I

        if self.both_dirs:
            # Reverse term Φ_BA∘Φ_AB.
            comp_disp_b = self._compose_disp(flow_inv, flow_fwd)
            J_b = self._jacobian_samples(comp_disp_b, pts0)
            loss = 0.5 * (loss + self._fro(J_b - eye, self.penalty).mean())
        return loss