API Reference

MIR.models

Model zoo and configuration helpers for MIR registration pipelines.

MIR.utils

Utility helpers for registration, training, visualization, and IO.

MIR.image_similarity

Image similarity losses and metrics used in MIR registration.

MIR.deformation_regularizer

Regularization terms for deformation fields and velocity models.

MIR.accuracy_measures

Accuracy metrics and Jacobian-based measures for evaluation.

Model zoo and configuration helpers for MIR registration pipelines.

class MIR.models.AdvancedDecoder3D(*args: Any, **kwargs: Any)[source]

Bases: Module

Advanced 3D decoder for multi-resolution encoder features.

Parameters:
  • encoder_channels – List of channel counts for each feature level.

  • aspp_out – Number of channels after ASPP.

  • num_classes – Number of output classes.

Inputs:

x_feats: List of feature maps ordered from low to high resolution.

Returns:

Logits tensor of shape [B, num_classes, D0, H0, W0].

forward(x_feats: list) torch.Tensor[source]

Run decoder forward pass.

class MIR.models.AffineReg3D(*args: Any, **kwargs: Any)[source]

Bases: Module

3D affine registration with configurable degrees of freedom. This is an optimization-based approach, not deep learning-based.

Parameters:
  • vol_shape – Spatial shape tuple (D, H, W).

  • dof – Degrees of freedom, one of: “affine”, “rigid”, “translation”, “scaling”.

  • scales – Multi-scale factors (e.g., (0.25, 0.5, 1)).

  • loss_funcs – Loss names per scale (“mse”, “l1”, “ncc”, “fastncc”, “pcc”, “localcorrratio”, “corrratio”, “ssim3d”, “mutualinformation”, “localmutualinformation”, “mind”).

  • loss_weights – Optional loss weights per scale.

  • mode – Sampling mode for SpatialTransformer.

  • batch_size – Parameter batch size (defaults to 1).

  • match_fixed – If True, pad/crop inputs to fixed image shape.

  • pad_mode – Padding mode for size matching.

  • pad_value – Constant value for padding (if pad_mode == “constant”).

apply_affine(moving: torch.Tensor, affine: torch.Tensor, invert: bool = False, target_shape: Sequence[int] | None = None) torch.Tensor[source]

Apply a given affine matrix to the moving image.

Parameters:
  • moving – Moving image tensor (B, C, D, H, W).

  • affine – Affine matrix with shape (3,4), (B,3,4), (4,4), or (B,4,4).

  • invert – If True, apply the inverse affine.

  • target_shape – Optional target shape for padding/cropping.

Returns:

Warped image tensor (B, C, D, H, W).

forward(moving: torch.Tensor, fixed: torch.Tensor, target_shape: Sequence[int] | None = None)[source]

Compute affine registration at full resolution.

Parameters:
  • moving – Moving image tensor (B, C, D, H, W).

  • fixed – Fixed image tensor (B, C, D, H, W).

Returns:

Dict with keys warped, flow, affine, loss, losses (single entry).

Return type:

dict

optimize(moving: torch.Tensor, fixed: torch.Tensor, target_shape: Sequence[int] | None = None, optimizer: torch.optim.Optimizer | None = None, optimizer_name: str = 'lbfgs', lr: float = 0.01, steps: int = 200, steps_per_scale: Sequence[int] | None = None, return_history: bool = False, verbose: bool = False, normalize: bool = True, lbfgs_history_size: int = 10)[source]

Optimize affine parameters to align moving to fixed.

Parameters:
  • moving – Moving image tensor (B, C, D, H, W).

  • fixed – Fixed image tensor (B, C, D, H, W).

  • target_shape – Optional target spatial shape for padding/cropping.

  • optimizer – Optional optimizer. If None, optimizer_name is used.

  • optimizer_name – “lbfgs” or “adam” when optimizer is None.

  • lr – Learning rate for optimizer if optimizer is None.

  • steps – Number of optimization steps (used if steps_per_scale is None).

  • steps_per_scale – Optional per-scale step counts matching self.scales.

  • return_history – Whether to return loss history.

  • normalize – Whether to normalize images before optimization.

  • lbfgs_history_size – History size for LBFGS.

Returns:

Tuple of (output_dict, loss_history) if return_history else output_dict.
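A minimal usage sketch (the volume shape, per-scale loss choices, and dict-key access below are illustrative assumptions based on the signatures above):

import torch
from MIR.models import AffineReg3D

# toy volumes in the (B, C, D, H, W) layout documented above
moving = torch.rand(1, 1, 96, 96, 96)
fixed = torch.rand(1, 1, 96, 96, 96)

reg = AffineReg3D(vol_shape=(96, 96, 96), dof='rigid',
                  scales=(0.25, 0.5, 1),
                  loss_funcs=('ncc', 'ncc', 'ncc'))  # one loss per scale
out = reg.optimize(moving, fixed, optimizer_name='adam', lr=0.01, steps=100)

# keys follow the forward() documentation above
warped, affine = out['warped'], out['affine']
# the same matrix can be reused, e.g. to invert the alignment
undone = reg.apply_affine(warped, affine, invert=True)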

reset_parameters() None[source]

Reset learnable parameters to identity transform.

class MIR.models.AffineTransformer(*args: Any, **kwargs: Any)[source]

Bases: Module

3-D Affine Transformer.

Parameters:

mode – Interpolation mode, ‘bilinear’ or ‘nearest’.

apply_affine(src, mat)[source]

Apply an affine matrix to a source volume.

Parameters:
  • src – Source tensor (B, C, H, W, D).

  • mat – Affine matrix (B, 3, 4).

Returns:

Warped tensor.

forward(src, affine, scale, translate, shear)[source]

Apply composed affine parameters to a volume.

Parameters:
  • src – Source tensor (B, C, H, W, D).

  • affine – Rotation parameters (B, 3).

  • scale – Scale parameters (B, 3).

  • translate – Translation parameters (B, 3).

  • shear – Shear parameters (B, 6).

Returns:

Tuple of (warped, affine_matrix, inverse_affine_matrix).

class MIR.models.Conv3dReLU(*args: Any, **kwargs: Any)[source]

Bases: Sequential

3D convolution + normalization + LeakyReLU block.

class MIR.models.ConvNeXtSynthHead3D(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(x: torch.Tensor) torch.Tensor[source]

x: (B, C, H, W, D); returns: (B, 1, H, W, D).

class MIR.models.Decoder(*args: Any, **kwargs: Any)[source]

Bases: Module

VFA decoder network.

Inputs:

F: List of fixed-image feature maps. M: List of moving-image feature maps.

Returns:

List of composed grids at each scale.

forward(F, M)[source]
get_candidate_from_tensor(x, dim, kernel=3, stride=1)[source]

Extract local patches into a token tensor.

Parameters:
  • x – Tensor of shape [B, C, *spatial].

  • dim – Spatial dimensionality (2 or 3).

  • kernel – Patch size.

  • stride – Patch stride.

Returns:

Tensor of shape [B, *spatial, P, C] with patch tokens.

to_token(x)[source]

Flatten spatial dimensions and move channels to the last axis.

Parameters:

x – Tensor of shape [B, C, *spatial].

Returns:

Tensor of shape [B, N, C] where N is the flattened spatial size.

class MIR.models.DecoderBlock(*args: Any, **kwargs: Any)[source]

Bases: Module

Decoder block with upsampling and optional skip connection.

forward(x, skip=None)[source]

Forward pass.

Parameters:
  • x – Tensor of shape [B, C, D, H, W].

  • skip – Optional skip tensor concatenated on channel axis.

Returns:

Tensor after upsampling and convolutions.

class MIR.models.DefSwinTransformer(*args: Any, **kwargs: Any)[source]

Bases: Module

Swin Transformer. A PyTorch implementation of Swin Transformer: Hierarchical Vision Transformer using Shifted Windows, https://arxiv.org/pdf/2103.14030

Parameters:
  • img_size (int | tuple(int)) – Input image size. Default: 224

  • patch_size (int | tuple(int)) – Patch size. Default: 4

  • in_chans (int) – Number of input image channels. Default: 3

  • num_classes (int) – Number of classes for classification head. Default: 1000

  • embed_dim (int) – Patch embedding dimension. Default: 96

  • depths (tuple(int)) – Depth of each Swin Transformer layer.

  • num_heads (tuple(int)) – Number of attention heads in different layers.

  • window_size (tuple) – Window size. Default: 7

  • mlp_ratio (float) – Ratio of mlp hidden dim to embedding dim. Default: 4

  • qkv_bias (bool) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float) – Override default qk scale of head_dim ** -0.5 if set. Default: None

  • drop_rate (float) – Dropout rate. Default: 0

  • attn_drop_rate (float) – Attention dropout rate. Default: 0

  • drop_path_rate (float) – Stochastic depth rate. Default: 0.1

  • norm_layer (nn.Module) – Normalization layer. Default: nn.LayerNorm.

  • ape (bool) – If True, add absolute position embedding to the patch embedding. Default: False

  • patch_norm (bool) – If True, add normalization after patch embedding. Default: True

  • use_checkpoint (bool) – Whether to use checkpointing to save memory. Default: False

forward(x)[source]
init_weights(pretrained=None)[source]

Initialize the weights in the backbone.

Parameters:

pretrained – Path to pre-trained weights. Defaults to None.

train(mode=True)[source]

Convert the model to training mode while keeping frozen layers frozen.

class MIR.models.DefSwinTransformerV2(*args: Any, **kwargs: Any)[source]

Bases: Module

Deformable Swin Transformer. A PyTorch implementation of Swin Transformer: Hierarchical Vision Transformer using Shifted Windows, https://arxiv.org/pdf/2103.14030

Parameters:
  • img_size (int | tuple(int)) – Input image size. Default: 224

  • patch_size (int | tuple(int)) – Patch size. Default: 4

  • in_chans (int) – Number of input image channels. Default: 3

  • num_classes (int) – Number of classes for classification head. Default: 1000

  • embed_dim (int) – Patch embedding dimension. Default: 96

  • depths (tuple(int)) – Depth of each Swin Transformer layer.

  • num_heads (tuple(int)) – Number of attention heads in different layers.

  • window_size (tuple) – Window size. Default: 7

  • mlp_ratio (float) – Ratio of mlp hidden dim to embedding dim. Default: 4

  • qkv_bias (bool) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float) – Override default qk scale of head_dim ** -0.5 if set. Default: None

  • drop_rate (float) – Dropout rate. Default: 0

  • attn_drop_rate (float) – Attention dropout rate. Default: 0

  • drop_path_rate (float) – Stochastic depth rate. Default: 0.1

  • norm_layer (nn.Module) – Normalization layer. Default: nn.LayerNorm.

  • ape (bool) – If True, add absolute position embedding to the patch embedding. Default: False

  • patch_norm (bool) – If True, add normalization after patch embedding. Default: True

  • use_checkpoint (bool) – Whether to use checkpointing to save memory. Default: False

forward(x)[source]
init_weights(pretrained=None)[source]

Initialize the weights in the backbone.

Parameters:

pretrained – Path to pre-trained weights. Defaults to None.

train(mode=True)[source]

Convert the model to training mode while keeping frozen layers frozen.

class MIR.models.DoubleConv3d(*args: Any, **kwargs: Any)[source]

Bases: Module

Two-layer 3D convolutional block with instance norm and LeakyReLU.

forward(x)[source]
class MIR.models.EfficientAdvancedSynthHead3D(*args: Any, **kwargs: Any)[source]

Bases: Module

Memory-efficient advanced synthesis head using inverted residuals + ECA attention.

forward(x: torch.Tensor) torch.Tensor[source]

x: (B, C, H, W, D); returns: (B, 1, H, W, D).

class MIR.models.GroupNet3D(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(x)[source]
class MIR.models.HyperTransMorphTVF(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(inputs, hyp_val)[source]
class MIR.models.HyperTransMorphTVFSPR(*args: Any, **kwargs: Any)[source]

Bases: Module

TransMorph TVF with spatially varying regularization.

Parameters:
  • config – Configuration object containing model parameters.

  • time_steps – Number of time steps for progressive registration.

  • SVF – Whether to use stationary velocity field (SVF) integration.

  • SVF_steps – Number of steps for SVF integration.

  • composition – Type of composition for flow integration (‘composition’ or ‘addition’).

  • swin_type – Type of Swin Transformer to use (‘swin’ or ‘dswin’).

forward(inputs, hyper_val)[source]

Forward pass for the TransMorphTVFSPR model.

Parameters:

inputs – Tuple of moving and fixed images (mov, fix).

Returns:

Tuple (flow, x_weight), where flow is the computed flow field for image registration and x_weight holds the spatial weights for regularization.

Return type:

tuple

class MIR.models.HyperVFA(*args: Any, **kwargs: Any)[source]

Bases: Module

Hyperparameter-conditioned VFA model.

Parameters:
  • configs – VFA configuration object.

  • device – Device to run the model on.

  • return_orginal – If True, return composed grids and stats.

  • return_all_flows – If True, return flows for all decoder levels.

Forward inputs:

sample: Tuple (mov, fix) tensors. hyper_val: Hyperparameter tensor.

Forward outputs:

Flow(s) depending on flags.

forward(sample, hyper_val)[source]

Run forward registration.

Parameters:
  • sample – Tuple (mov, fix) tensors.

  • hyper_val – Hyperparameter tensor.

Returns:

Output varies by flags (return_orginal, return_all_flows).
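A hedged sketch of conditioning on a hyperparameter value (the [B, 1] shape of hyper_val and the toy setup are assumptions; see get_VFA_default_config below):

import torch
from MIR.models import HyperVFA, get_VFA_default_config

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
configs = get_VFA_default_config()
model = HyperVFA(configs, device).to(device)

mov = torch.rand(1, 1, *configs.img_size, device=device)
fix = torch.rand(1, 1, *configs.img_size, device=device)
hyper_val = torch.tensor([[0.5]], device=device)  # assumed shape [B, 1]
flow = model((mov, fix), hyper_val)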

class MIR.models.HyperVFASPR(*args: Any, **kwargs: Any)[source]

Bases: Module

HyperVFA model with spatially varying regularization.

Parameters:
  • configs – VFA configuration object.

  • device – Device to run the model on.

  • return_orginal – If True, return composed grids and stats.

  • return_all_flows – If True, return flows for all decoder levels.

Forward inputs:

sample: Tuple (mov, fix) tensors. hyper_val: Hyperparameter tensor.

Forward outputs:

Flow(s) and spatial weights depending on flags.

forward(sample, hyper_val)[source]

Run forward registration.

Parameters:
  • sample – Tuple (mov, fix) tensors.

  • hyper_val – Hyperparameter tensor.

Returns:

Output varies by flags (return_orginal, return_all_flows).

class MIR.models.HyperVxmDense(*args: Any, **kwargs: Any)[source]

Bases: Module

HyperMorph variant of VoxelMorph with hyperparameter conditioning.

Inputs:

input_imgs: Tuple (mov, fix) of tensors [B, 1, *spatial]. hyp_val: Tensor of hyperparameter values.

Returns:

Flow tensor, and optionally warped image if gen_output is True.

forward(input_imgs, hyp_val)[source]

Run HyperMorph forward pass.

Parameters:
  • input_imgs – Tuple (mov, fix) tensors.

  • hyp_val – Hyperparameter tensor.

Returns:

Flow tensor, and optionally warped moving image.

class MIR.models.ListBatchSampler(*args: Any, **kwargs: Any)[source]

Bases: Sampler

Yield precomputed lists of indices as batches.

This allows constructing a DataLoader with arbitrary batch sizes per iteration while still supporting num_workers and automatic collation.

class MIR.models.MeanStream(*args: Any, **kwargs: Any)[source]

Bases: Module

Mean stream for the Deformable Template Network.

Parameters:
  • cap – Cap for the mean.

  • in_shape – Input shape.

forward(x)[source]
mean_update(pre_mean, pre_count, x, pre_cap=0.0)[source]
class MIR.models.PreAffineToTemplate(*args: Any, **kwargs: Any)[source]

Bases: Module

Affine pre-alignment to a fixed template before deformable registration.

Parameters:
  • mode – Interpolation mode for SpatialTransformer.

  • template_type – Template name to load (“lumir” or “mni”).

  • batch_size – Batch size for affine parameters.

  • dof – Degrees of freedom for affine model.

  • scales – Multi-scale factors for loss computation.

  • loss_funcs – Loss names per scale.

  • device – Device for the affine model and template.

forward(x: torch.Tensor, optimize: bool = True, verbose: bool = False)[source]

Forward pass of the affine registration module.

Parameters:

x – Input image tensor (B, C, H, W, D).

Returns:

Affine registered image tensor (B, C, H, W, D)

class MIR.models.RegistrationHead(*args: Any, **kwargs: Any)[source]

Bases: Sequential

Predict a dense displacement field from decoder features.

class MIR.models.SSLHead1Lvl(*args: Any, **kwargs: Any)[source]

Bases: Module

Self-supervised learning head with one level.

Parameters:
  • encoder – Encoder model.

  • img_size – Image size.

  • num_lvls – Number of levels.

  • channels – Number of channels per level.

  • if_upsamp – Whether to upsample predicted flows.

  • encoder_output_type – Encoder output type (‘single’ or ‘multi’).

  • swap_encoder_order – Whether to swap encoder output order.

  • gen_output – Whether to generate deformed output.

Forward inputs:

inputs: Tuple (mov, fix) tensors.

Forward outputs:

If gen_output, returns (warped, flow); otherwise flow.

forward(inputs)[source]

Run SSL head forward pass.

class MIR.models.SSLHeadNLvl(*args: Any, **kwargs: Any)[source]

Bases: Module

Self-supervised learning head with multiple levels.

Parameters:
  • encoder – Encoder model.

  • img_size – Image size.

  • num_lvls – Number of levels.

  • channels – Number of channels per level.

  • if_upsamp – Whether to upsample predicted flows.

  • encoder_output_type – Encoder output type (‘single’ or ‘multi’).

  • encoder_input_type – Encoder input type (‘single’, ‘multi’, or ‘separate’).

  • swap_encoder_order – Whether to swap encoder output order.

  • gen_output – Whether to generate deformed output.

Forward inputs:

inputs: Tuple (mov, fix) tensors.

Forward outputs:

If gen_output, returns (warped, flow, stats); otherwise (flow, stats).

forward(inputs)[source]

Run SSL head forward pass.

class MIR.models.SpatialTransformer(*args: Any, **kwargs: Any)[source]

Bases: Module

N-D Spatial Transformer, obtained from https://github.com/voxelmorph/voxelmorph.

Parameters:
  • size – Spatial size of the input tensor.

  • mode – Interpolation mode, ‘bilinear’ or ‘nearest’.

forward(src, flow)[source]

Warp a source tensor with a displacement field.

Parameters:
  • src – Source tensor (B, C, …).

  • flow – Displacement field (B, ndim, …).

Returns:

Warped tensor.
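For example, warping a volume with an all-zero flow is the identity (up to interpolation); a minimal sketch:

import torch
from MIR.models import SpatialTransformer

size = (160, 192, 224)
stn = SpatialTransformer(size)          # mode='bilinear' by default
src = torch.rand(1, 1, *size)           # (B, C, *spatial)
flow = torch.zeros(1, 3, *size)         # zero displacement = identity warp
warped = stn(src, flow)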

class MIR.models.SwinTransformer(*args: Any, **kwargs: Any)[source]

Bases: Module

Swin Transformer. A PyTorch implementation of Swin Transformer: Hierarchical Vision Transformer using Shifted Windows, https://arxiv.org/pdf/2103.14030

Parameters:
  • img_size (int | tuple(int)) – Input image size. Default: 224

  • patch_size (int | tuple(int)) – Patch size. Default: 4

  • in_chans (int) – Number of input image channels. Default: 3

  • num_classes (int) – Number of classes for classification head. Default: 1000

  • embed_dim (int) – Patch embedding dimension. Default: 96

  • depths (tuple(int)) – Depth of each Swin Transformer layer.

  • num_heads (tuple(int)) – Number of attention heads in different layers.

  • window_size (tuple) – Window size. Default: 7

  • mlp_ratio (float) – Ratio of mlp hidden dim to embedding dim. Default: 4

  • qkv_bias (bool) – If True, add a learnable bias to query, key, value. Default: True

  • qk_scale (float) – Override default qk scale of head_dim ** -0.5 if set. Default: None

  • drop_rate (float) – Dropout rate. Default: 0

  • attn_drop_rate (float) – Attention dropout rate. Default: 0

  • drop_path_rate (float) – Stochastic depth rate. Default: 0.1

  • norm_layer (nn.Module) – Normalization layer. Default: nn.LayerNorm.

  • ape (bool) – If True, add absolute position embedding to the patch embedding. Default: False

  • patch_norm (bool) – If True, add normalization after patch embedding. Default: True

  • use_checkpoint (bool) – Whether to use checkpointing to save memory. Default: False

forward(x)[source]

Forward function.

init_weights(pretrained=None)[source]

Initialize the weights in the backbone.

Parameters:

pretrained – Path to pre-trained weights. Defaults to None.

train(mode=True)[source]

Convert the model to training mode while keeping frozen layers frozen.

class MIR.models.SynthesisHead3D(*args: Any, **kwargs: Any)[source]

Bases: Module

Simple synthesis head for 3D feature maps.

Parameters:
  • in_channels – Number of input channels.

  • mid_channels – Number of intermediate channels.

  • out_channels – Number of output channels.

  • norm – Normalization type (‘instance’ or ‘batch’).

  • activation – Activation function (‘leaky_relu’ or ‘relu’).

Inputs:

x: Tensor of shape [B, C, D, H, W].

Returns:

Tensor of shape [B, out_channels, D, H, W].

forward(x: torch.Tensor) torch.Tensor[source]

Apply synthesis head.

class MIR.models.SynthesisHead3DAdvanced(*args: Any, **kwargs: Any)[source]

Bases: Module

Synthesis head with residual and SE blocks for 3D features.

Parameters:
  • in_channels – Number of input channels.

  • mid_channels – Number of intermediate channels.

  • num_res_blocks – Number of residual blocks.

  • norm – Normalization type (‘instance’ or ‘group’).

Inputs:

x: Tensor of shape [B, C, D, H, W].

Returns:

Tensor of shape [B, 1, D, H, W].

forward(x: torch.Tensor) torch.Tensor[source]

Apply advanced synthesis head.

class MIR.models.TemplateCreation(*args: Any, **kwargs: Any)[source]

Bases: Module

Deformable Template Network.

Parameters:
  • reg_model – Registration model.

  • img_size – Image size.

  • mean_cap – Mean cap for the MeanStream.

  • use_sitreg – Use SITReg-style mapping outputs when True; VFA/TransMorph/VoxelMorph-style otherwise.

  • mode – SpatialTransformer interpolation mode (SITReg path).

forward(inputs)[source]
class MIR.models.TransMorph(*args: Any, **kwargs: Any)[source]

Bases: Module

TransMorph model.

Parameters:
  • config – Configuration object containing model parameters.

  • SVF – Whether to integrate a stationary velocity field.

  • SVF_steps – Number of scaling-and-squaring steps.

  • swin_type – Transformer type (‘swin’, ‘dswin’, ‘dswinv2’).

Forward inputs:

inputs: Tuple (mov, fix) tensors of shape [B, 1, D, H, W].

Forward outputs:

Dense flow tensor of shape [B, 3, D, H, W].

forward(inputs)[source]

Forward pass for the TransMorph model.

Parameters:

inputs – Tuple of moving and fixed images (mov, fix).

Returns:

Predicted dense flow field.
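A hedged sketch of pairing the model with its config helper (whether the default config's img_size matches these toy shapes is an assumption):

import torch
from MIR.models import TransMorph, get_3DTransMorph_config

config = get_3DTransMorph_config()
model = TransMorph(config)

mov = torch.rand(1, 1, 160, 192, 224)   # assumed to match config.img_size
fix = torch.rand(1, 1, 160, 192, 224)
flow = model((mov, fix))                # dense flow, (B, 3, D, H, W)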

class MIR.models.TransMorphAffine(*args: Any, **kwargs: Any)[source]

Bases: Module

Affine TransMorph head predicting global parameters.

Parameters:
  • config – Configuration object containing model parameters.

  • swin_type – Transformer type (‘swin’, ‘dswin’, ‘dswinv2’).

Forward inputs:

inputs: Tuple (mov, fix) tensors of shape [B, 1, D, H, W].

Forward outputs:

Tuple (aff, scl, trans, shr) with rotation, scale, translation, and shear parameters.

forward(inputs)[source]

Forward pass predicting affine parameters.

Parameters:

inputs – Tuple of moving and fixed images (mov, fix).

Returns:

Tuple (aff, scl, trans, shr) of affine parameters.

softplus(x)[source]

Apply softplus activation.

class MIR.models.TransMorphTVF(*args: Any, **kwargs: Any)[source]

Bases: Module

TransMorph TVF model with progressive flow integration.

Parameters:
  • config – Configuration object containing model parameters.

  • time_steps – Number of time steps for progressive registration.

  • SVF – Whether to integrate a stationary velocity field.

  • SVF_steps – Number of scaling-and-squaring steps.

  • composition – Flow composition strategy (‘composition’ or ‘addition’).

  • swin_type – Transformer type (‘swin’, ‘dswin’, ‘dswinv2’).

Forward inputs:

inputs: Tuple (mov, fix) tensors of shape [B, 1, D, H, W].

Forward outputs:

Dense flow tensor of shape [B, 3, D, H, W].

forward(inputs)[source]

Forward pass for the TransMorphTVF model.

Parameters:

inputs – Tuple of moving and fixed images (mov, fix).

Returns:

Predicted dense flow field.

class MIR.models.TransMorphTVFSPR(*args: Any, **kwargs: Any)[source]

Bases: Module

TransMorph TVF with spatially varying regularization.

Parameters:
  • config – Configuration object containing model parameters.

  • time_steps – Number of time steps for progressive registration.

  • SVF – Whether to integrate a stationary velocity field.

  • SVF_steps – Number of scaling-and-squaring steps.

  • composition – Flow composition strategy (‘composition’ or ‘addition’).

  • swin_type – Transformer type (‘swin’, ‘dswin’, ‘dswinv2’).

Forward inputs:

inputs: Tuple (mov, fix) tensors of shape [B, 1, D, H, W].

Forward outputs:

Tuple (flow, spatial_wts) where flow is [B, 3, D, H, W].

forward(inputs)[source]

Forward pass for the TransMorphTVFSPR model.

Parameters:

inputs – Tuple of moving and fixed images (mov, fix).

Returns:

Tuple (flow, x_weight) or (pos_flow, neg_flow, x_weight) when SVF is enabled.

class MIR.models.TransVFA(*args: Any, **kwargs: Any)[source]

Bases: Module

TransVFA model for image registration.

Parameters:
  • configs_sw – Swin Transformer config.

  • configs – VFA config.

  • device – Device to run the model on.

  • swin_type – Transformer type (‘swin’, ‘dswin’, ‘dswinv2’).

  • return_orginal – If True, return composed grids and stats.

  • return_all_flows – If True, return flows for all decoder levels.

Forward inputs:

sample: Tuple (mov, fix) tensors of shape [B, 1, *spatial].

Forward outputs:

Flow(s) and optional auxiliary outputs depending on flags.

forward(sample)[source]

Run forward registration.

Parameters:

sample – Tuple (mov, fix) tensors.

Returns:

Output varies by flags (return_orginal, return_all_flows).

class MIR.models.VFA(*args: Any, **kwargs: Any)[source]

Bases: Module

VFA model for image registration.

Parameters:
  • configs – VFA configuration object.

  • device – Device to run the model on.

  • return_orginal – If True, return VFA-style composed grids and stats.

  • return_all_flows – If True, return flows for all decoder levels.

  • SVF – If True, integrate flow as stationary velocity field.

  • SVF_steps – Number of scaling-and-squaring steps for integration.

  • return_full – If True, also return warped images and inverse flow.

Forward inputs:

sample: Tuple (mov, fix) tensors of shape [B, 1, *spatial].

Forward outputs:

Depending on flags, returns flow(s) or a results dict.

forward(sample)[source]

Run forward registration.

Parameters:

sample – Tuple (mov, fix) tensors.

Returns:

Output varies by flags (return_orginal, return_all_flows, SVF, return_full). See class docstring for details.

class MIR.models.VFASPR(*args: Any, **kwargs: Any)[source]

Bases: Module

VFA model with spatially varying regularization (VFA-SPR).

Parameters:
  • configs – VFA configuration object.

  • device – Device to run the model on.

  • return_orginal – If True, return VFA-style composed grids and stats.

  • return_all_flows – If True, return flows for all decoder levels.

  • SVF – If True, integrate flow as stationary velocity field.

  • SVF_steps – Number of scaling-and-squaring steps for integration.

  • return_full – If True, also return warped images and inverse flow.

Forward inputs:

sample: Tuple (mov, fix) tensors of shape [B, 1, *spatial].

Forward outputs:

Depending on flags, returns flow(s), spatial weights, or a results dict.

forward(sample)[source]

Run forward registration.

Parameters:

sample – Tuple (mov, fix) tensors.

Returns:

Output varies by flags (return_orginal, return_all_flows, SVF, return_full). See class docstring for details.

class MIR.models.VecInt(*args: Any, **kwargs: Any)[source]

Bases: Module

Integrates a vector field via scaling and squaring.

Parameters:
  • inshape – Shape of the input tensor.

  • nsteps – Number of integration steps.

forward(vec)[source]

Integrate a vector field via scaling and squaring.

Parameters:

vec – Velocity field tensor (B, ndim, …).

Returns:

Integrated displacement field.
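A sketch of integrating a stationary velocity field and applying the result (toy shapes assumed):

import torch
from MIR.models import VecInt, SpatialTransformer

size = (96, 96, 96)
integrate = VecInt(size, 7)                  # inshape, nsteps
velocity = 0.1 * torch.randn(1, 3, *size)    # SVF, (B, ndim, *spatial)
disp = integrate(velocity)                   # displacement = exp(velocity)
warped = SpatialTransformer(size)(torch.rand(1, 1, *size), disp)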

class MIR.models.VxmDense(*args: Any, **kwargs: Any)[source]

Bases: Module

VoxelMorph network for nonlinear image registration.

Inputs:

input_imgs: Tuple (mov, fix) of tensors [B, 1, *spatial].

Returns:

Flow tensor, and optionally warped image if gen_output is True.

forward(input_imgs)[source]

Run VoxelMorph forward pass.

Parameters:

input_imgs – Tuple (mov, fix) tensors.

Returns:

Flow tensor, and optionally warped moving image.

class MIR.models.Warp3d(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(x, w)[source]
MIR.models.convex_adam_MIND(img_moving, img_fixed, configs) None[source]

Coupled convex optimisation with Adam instance optimisation.

MIR.models.convex_adam_MIND_SPR(img_moving, img_fixed, configs) None[source]

Coupled convex optimisation with Adam instance optimisation.

MIR.models.convex_adam_features(img_moving, img_fixed, configs, initial_disp=None) None[source]

Coupled convex optimisation with Adam instance optimisation.

MIR.models.convex_adam_seg_features(feat_moving, feat_fixed, configs)[source]
MIR.models.convex_adam_vfa(img_fixed: numpy.ndarray | torch.Tensor, img_moving: numpy.ndarray | torch.Tensor, convex_config=None, vfa_config=None, feature_scales: Iterable[int] | None = None, max_feat_channels: int = 32, use_no_grad: bool = True, vfa_weights_path: str | None = None, vfa_weights_key: str | None = None, vfa_encoder: torch.nn.Module | None = None, initial_disp: torch.Tensor | None = None, device: torch.device = torch.device) torch.Tensor[source]

Register images using ConvexAdam on VFA encoder multiscale features.

Parameters:
  • img_fixed – Fixed image tensor/array.

  • img_moving – Moving image tensor/array.

  • convex_config – ConvexAdam config (defaults to MIND brain config).

  • vfa_config – VFA config (defaults to VFA default config).

  • feature_scales – Iterable of encoder scale indices to use (coarse->fine).

  • max_feat_channels – Channel cap per scale for efficiency.

  • use_no_grad – If True, compute VFA features under no_grad.

  • vfa_weights_path – Optional path to a VFA checkpoint that contains both encoder and decoder.

  • vfa_weights_key – Optional key to read from a checkpoint dict.

  • vfa_encoder – Optional preloaded VFA encoder to reuse across calls.

  • initial_disp – Optional initial displacement (B, 3, H, W, D).

  • device – Torch device.

Returns:

Displacement field tensor (B, 3, H, W, D).
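A hedged call sketch using the documented defaults (MIND brain ConvexAdam config and the default VFA config); note that without vfa_weights_path the encoder runs with whatever weights default construction provides:

import torch
from MIR.models import convex_adam_vfa

fixed = torch.rand(1, 1, 160, 192, 224)
moving = torch.rand(1, 1, 160, 192, 224)
disp = convex_adam_vfa(fixed, moving,
                       device=torch.device('cpu'))  # pass a CUDA device for GPU
# disp: displacement field, (B, 3, H, W, D)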

MIR.models.ensemble_average(fields, field_types, weights=None, nb_steps=5, iters=100, device='cpu')[source]

Ensemble average of displacement fields or SVFs.

Parameters:
  • fields – List of displacement fields, each (*vol, ndims).

  • field_types – Parallel list, ‘svf’ or ‘disp’.

  • weights – List of floats or None (defaults to equal).

  • nb_steps – Number of integration steps for VecInt.

  • iters – Number of iterations for fit_warp_to_svf.

Returns:

numpy displacement field (*vol, ndims) = exp(Σ w_i v_i)

Return type:

disp_bar

Example:

# assume you already have three predictions
svf_A = np.load("model_A_velocity.npy")   # (*vol, 3)
disp_B = np.load("model_B_disp.npy")      # (*vol, 3)
disp_C = np.load("model_C_disp.npy")

fields = [svf_A, disp_B, disp_C]
field_types = ['svf', 'disp', 'disp']
# e.g. let Dice on a small val-set decide the weights
weights = [0.45, 0.35, 0.20]

disp_mean = ensemble_average(
    fields, field_types, weights,
    nb_steps=7,   # same nsteps you used during training/inference
    iters=150,    # iterations for fit_warp_to_svf on the disp fields
)

# now warp any tensor (image, mask, logits) with the consensus deformation
vol_shape = disp_mean.shape[:-1]
transform = SpatialTransformer(vol_shape).cuda()  # if GPU used above
moving_img_t = torch.from_numpy(moving_img)[None, None].float().cuda()
disp_t = torch.from_numpy(disp_mean).permute(3, 0, 1, 2)[None].float().cuda()
warped = transform(moving_img_t, disp_t)  # (1, 1, *vol)

MIR.models.fit_warp_to_svf(warp_t, nb_steps: int = 7, iters: int = 500, min_delta: float = 1e-05, lr: float = 0.1, objective: str = 'mse', init: str = 'warp', output_type: str = 'disp', verbose: bool = True, device: str = 'cpu')[source]

Fit a stationary velocity field v so that exp(v) ≈ a given displacement field. Parameters mirror the original TF implementation.

Parameters:

warp_t – Displacement field of shape (*vol_shape, ndims); numpy array or torch tensor.

Returns:

v as a numpy array of the same shape.

MIR.models.fit_warp_to_svf_fast(warp_t, nb_steps: int = 5, iters: int = 50, lr: float = 0.1, downsample_factor: int = 2, refine_iters: int = 10, min_delta: float = 1e-05, warm_start: torch.Tensor | None = None, use_amp: bool = False, output_type: str = 'disp', verbose: bool = False, **kwargs)[source]

Fast approximation to fit_warp_to_svf using coarse-to-fine fitting.

Parameters:
  • warp_t – Displacement field tensor (B, C, *vol).

  • nb_steps – VecInt steps for exponentiation.

  • iters – Optimization iterations for coarse fit.

  • lr – Learning rate for coarse fit.

  • downsample_factor – Integer downsample factor for speed (>1 recommended).

  • output_type – “disp” returns velocity (SVF); “svf” returns displacement.

  • verbose – Print progress during fitting.

  • **kwargs – Passed to fit_warp_to_svf (e.g., objective, init, min_delta).

Returns:

Tensor of same shape as warp_t, either SVF (velocity) or displacement.
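A hedged round-trip sketch: fit a velocity field to a displacement, then re-integrate it with VecInt (the default output_type is used; see the naming note above):

import torch
from MIR.models import fit_warp_to_svf_fast, VecInt

size = (96, 96, 96)
warp = 0.5 * torch.randn(1, 3, *size)         # displacement field (B, C, *vol)
vel = fit_warp_to_svf_fast(warp, nb_steps=5, iters=50, downsample_factor=2)
recon = VecInt(size, 5)(vel)                  # should approximate `warp`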

MIR.models.get_3DTransMorph3Lvl_config()[source]

Return TransMorph 3-level config.

Returns:

ml_collections.ConfigDict with model hyperparameters.

MIR.models.get_3DTransMorphDWin3Lvl_config()[source]

Return TransMorph 3-level config with dual-window attention.

Returns:

ml_collections.ConfigDict with model hyperparameters.

MIR.models.get_3DTransMorphLarge_config()[source]

A Large TransMorph Network

MIR.models.get_3DTransMorphLrn_config()[source]

TransMorph with Learnable Positional Embedding

MIR.models.get_3DTransMorphNoConvSkip_config()[source]

No skip connections from convolution layers

Computational complexity: 577.34 GMac. Number of parameters: 63.56 M.

MIR.models.get_3DTransMorphNoRelativePosEmbd_config()[source]

Return TransMorph config without relative positional embeddings.

Returns:

ml_collections.ConfigDict with model hyperparameters.

MIR.models.get_3DTransMorphNoSkip_config()[source]

No skip connections

Computational complexity: 639.93 GMac. Number of parameters: 58.4 M.

MIR.models.get_3DTransMorphNoTransSkip_config()[source]

No skip connections from Transformer blocks

Computational complexity: 639.93 GMac. Number of parameters: 58.4 M.

MIR.models.get_3DTransMorphRelativePosEmbdSimple_config()[source]
MIR.models.get_3DTransMorphSin_config()[source]

Return TransMorph config with sinusoidal positional embeddings.

Returns:

ml_collections.ConfigDict with model hyperparameters.

MIR.models.get_3DTransMorphSmall_config()[source]

A Small TransMorph Network

MIR.models.get_3DTransMorphTiny_config()[source]

A Tiny TransMorph Network

MIR.models.get_3DTransMorph_config()[source]

Return base TransMorph 4-level config.

Returns:

ml_collections.ConfigDict with model hyperparameters.

MIR.models.get_ConvexAdam_MIND_brain_default_config()[source]

ConvexAdam MIND default config.

MIR.models.get_VFA_default_config()[source]

VFA default config.
  • config.name: str, name of the model.

  • config.skip: int, skip certain displacement fields in the decoder.

  • config.initialize: float, initialize beta in the decoder.

  • config.downsamples: int, number of downsampling layers in the encoder.

  • config.start_channels: int, number of channels in the first layer of the encoder.

  • config.matching_channels: int, number of channels in the matching layer of the encoder.

  • config.int_steps: int, number of integration steps in the decoder.

  • config.affine: int, whether to use affine transformation in the decoder.

  • config.img_size: tuple, size of the input image.

  • config.in_channels: int, number of input channels.

  • config.max_channels: int, maximum number of channels in the encoder.

MIR.models.get_VXM_1_config()[source]

Return VoxelMorph-1 config.

Returns:

ml_collections.ConfigDict with model hyperparameters.

MIR.models.get_VXM_BJ_config()[source]

Return VoxelMorph-BJ config.

Returns:

ml_collections.ConfigDict with model hyperparameters.

MIR.models.get_VXM_default_config()[source]

Return default VoxelMorph-2 config.

Returns:

ml_collections.ConfigDict with model hyperparameters.

MIR.models.grid_to_flow(grid: torch.Tensor)[source]

Convert an absolute sampling grid (as produced by VFA) to a voxel‑displacement field that VoxelMorph’s SpatialTransformer expects.

Parameters:

grid – Tensor of shape [B, ndim, *spatial_dims] with absolute coords.

Returns:

Tensor of the same shape containing displacements from identity.

MIR.models.invert_warp_via_velocity(warp, nb_steps: int = 5, iters: int = 100, **kwargs)[source]

Approximate the inverse warp by fitting v, then integrating −v. Returns a displacement field (numpy) of the same shape.

MIR.models.make_epoch_batches(n_samples, min_bs, max_bs)[source]

Create randomized batch index lists for one epoch.

Parameters:
  • n_samples – Total number of samples.

  • min_bs – Minimum batch size.

  • max_bs – Maximum batch size.

Returns:

List of numpy arrays of indices for each batch.
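A sketch of pairing this with ListBatchSampler (that the sampler's constructor takes the batch list is an assumption; the TensorDataset stands in for any map-style dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset
from MIR.models import ListBatchSampler, make_epoch_batches

dataset = TensorDataset(torch.rand(100, 1, 32, 32, 32))  # stand-in dataset
batches = make_epoch_batches(n_samples=len(dataset), min_bs=1, max_bs=4)
loader = DataLoader(dataset, batch_sampler=ListBatchSampler(batches),
                    num_workers=2)
for (vols,) in loader:
    pass  # batch sizes vary between min_bs and max_bs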

Utility helpers for registration, training, visualization, and IO.

class MIR.utils.AverageMeter[source]

Bases: object

Computes and stores the average and current value.

reset()[source]

Reset stored statistics.

update(val, n=1)[source]

Update running statistics.

Parameters:
  • val – New value.

  • n – Weight/count for the value.
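Typical loop usage (a sketch; the avg attribute name is an assumption based on the class description):

from MIR.utils import AverageMeter

meter = AverageMeter()
meter.reset()
for step in range(10):
    loss_value = 1.0 / (step + 1)   # stand-in for loss.item()
    meter.update(loss_value, n=1)
print(meter.avg)                     # weighted running average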

class MIR.utils.Logger(save_dir)[source]

Bases: object

flush()[source]

No-op flush for compatibility with file-like interfaces.

write(message)[source]

Write a message to both stdout and the log file.

class MIR.utils.MultiResPatchSampler3D(patch_size)[source]

Bases: object

class MIR.utils.RandomPatchSampler3D(patch_size)[source]

Bases: object

MIR.utils.SLANT_label_reassign(label_map)[source]

Reassign SLANT label IDs to contiguous indices.

Parameters:

label_map – Label map array.

Returns:

Reassigned label map.

class MIR.utils.SpatialTransformer(*args: Any, **kwargs: Any)[source]

Bases: Module

N-D Spatial Transformer, obtained from https://github.com/voxelmorph/voxelmorph.

forward(src, flow)[source]

Warp a source image with a displacement field.

Parameters:
  • src – Source tensor (B, C, …).

  • flow – Displacement field (B, ndim, …).

Returns:

Warped source tensor.

class MIR.utils.VecInt(*args: Any, **kwargs: Any)[source]

Bases: Module

Integrates a vector field via scaling and squaring.

forward(vec)[source]

Integrate a vector field via scaling and squaring.

Parameters:

vec – Velocity field tensor (B, ndim, …).

Returns:

Integrated displacement field.

MIR.utils.get_cmap(n, name='nipy_spectral')[source]

Return a matplotlib colormap with n distinct colors.

Parameters:
  • n – Number of discrete colors.

  • name – Matplotlib colormap name.

Returns:

Colormap instance.

MIR.utils.load_partial_weights(model, checkpoint_path, weights_key='state_dict', strict=False)[source]

Load weights from a checkpoint into a model, skipping unmatched layers.

Parameters:
  • model (torch.nn.Module) – Your model with updated architecture.

  • checkpoint_path (str) – Path to the .pth or .pt file.

  • weights_key (str) – Key under which the state dict is stored in the checkpoint (defaults to ‘state_dict’).

  • strict (bool) – If True, behaves like standard strict loading. If False, loads what it can.
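A sketch of warm-starting a modified architecture (the stand-in model and checkpoint path are hypothetical):

import torch.nn as nn
from MIR.utils import load_partial_weights

model = nn.Sequential(nn.Conv3d(1, 8, 3), nn.LeakyReLU())  # stand-in model
load_partial_weights(model, 'checkpoints/old_run.pth',     # hypothetical path
                     weights_key='state_dict', strict=False)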

MIR.utils.make_affine_from_pixdim(pixdim)[source]

Create a 4x4 affine matrix from pixel spacing.

Parameters:

pixdim – Sequence of spacing values (dx, dy, dz).

Returns:

4x4 affine matrix.

MIR.utils.mk_grid_img(grid_step=8, line_thickness=1, grid_sz=(160, 192, 224), dim=0)[source]

Create a grid image tensor for visualization.

Parameters:
  • grid_step – Spacing between grid lines.

  • line_thickness – Grid line thickness.

  • grid_sz – Grid size (H, W, D).

  • dim – Axis along which to draw the grid.

Returns:

Grid tensor with shape (1, 1, H, W, D).
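A common use is deformation visualization: warp the grid image with a predicted flow (a sketch; the float cast is an assumption):

import torch
from MIR.utils import mk_grid_img, SpatialTransformer

grid_sz = (160, 192, 224)
grid = mk_grid_img(grid_step=8, line_thickness=1, grid_sz=grid_sz).float()
flow = torch.zeros(1, 3, *grid_sz)     # replace with a predicted flow
warped_grid = SpatialTransformer(grid_sz)(grid, flow)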

MIR.utils.pad_image(img, target_size)[source]

Pad a 3D tensor to a target size.

Parameters:
  • img – Input tensor (B, C, H, W, D).

  • target_size – Target spatial size (H, W, D).

Returns:

Padded tensor.

MIR.utils.pkload(fname)[source]

Load a pickled object from disk.

Parameters:

fname – Path to pickle file.

Returns:

Loaded object.

MIR.utils.resample_to_orginal_space_and_save(deformed_img, ants_affine_mat_path, img_orig_path, out_back_dir, img_pixdim, if_flip=True, flip_axis=1, interpolater='nearestNeighbor')[source]

Resample a deformed image back to the original image space.

Parameters:
  • deformed_img – Deformed image tensor (B, 1, H, W, D).

  • ants_affine_mat_path – Path to the ANTs affine matrix.

  • img_orig_path – Path to the original image (.nii.gz).

  • out_back_dir – Output directory to save the resampled image.

  • img_pixdim – Pixel spacing used during preprocessing.

  • if_flip – Whether to flip the image along the specified axis.

  • flip_axis – Axis along which to flip the image if if_flip is True.

  • interpolater – ANTs interpolator (‘linear’, ‘nearestNeighbor’, ‘bSpline’).

Returns:

The resampled image in the original space.

Return type:

img_final

MIR.utils.savepkl(data, path)[source]

Save an object to a pickle file.

Parameters:
  • data – Object to serialize.

  • path – Output pickle file path.

MIR.utils.sliding_window_inference(feat: torch.Tensor, head: torch.nn.Module, patch_size: tuple, overlap: float = 0.5, num_classes: int = 133, mode: str = 'argmax') torch.Tensor[source]

Run sliding-window inference on a full-resolution feature volume.

Parameters:
  • feat – Feature tensor (B, C, H, W, D).

  • head – Segmentation head that outputs logits.

  • patch_size – Patch size (ph, pw, pd).

  • overlap – Overlap ratio between patches.

  • num_classes – Number of output classes.

  • mode – Output mode (currently ‘argmax’).

Returns:

Segmentation prediction (B, 1, H, W, D).
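A hedged sketch (the 1×1×1 convolution stands in for a trained segmentation head):

import torch
import torch.nn as nn
from MIR.utils import sliding_window_inference

head = nn.Conv3d(16, 133, kernel_size=1)      # stand-in logits head
feat = torch.rand(1, 16, 160, 192, 224)       # full-resolution features
seg = sliding_window_inference(feat, head, patch_size=(96, 96, 96),
                               overlap=0.5, num_classes=133, mode='argmax')
# seg: (B, 1, H, W, D) label map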

MIR.utils.write2csv(line, name)[source]
MIR.utils.zoom_img(img, pixel_dims, order=3)[source]

Resize a 3D image to match target pixel dimensions.

Parameters:
  • img – Input 3D image array.

  • pixel_dims – Target spacing.

  • order – Interpolation order.

Returns:

Resampled image array.

Image similarity losses and metrics used in MIR registration.

class MIR.image_similarity.CompRecon(*args: Any, **kwargs: Any)[source]

Bases: Module

Composite loss for image synthesis without GANs:
  • Charbonnier (smooth L1) to preserve edges

  • Gradient-difference to align edges

  • Focal-frequency to boost high-frequency detail

Parameters:
  • charb_weight (float) – Weight for Charbonnier loss.

  • grad_weight (float) – Weight for gradient difference loss.

  • ff_weight (float) – Weight for focal frequency loss.

  • eps (float) – Small constant to avoid division by zero.

  • alpha (float) – Exponent for focal frequency loss.

  • beta (float) – Exponent for focal frequency loss.

forward(pred: torch.Tensor, target: torch.Tensor) torch.Tensor[source]

Compute composite reconstruction loss.

Parameters:
  • pred – Predicted image tensor.

  • target – Target image tensor.

Returns:

Scalar loss.

class MIR.image_similarity.CorrRatio(*args: Any, **kwargs: Any)[source]

Bases: Module

Correlation Ratio based on a Parzen window. Implemented by Junyu Chen, jchen245@jhmi.edu. TODO: under testing.

Reference: The Correlation Ratio as a New Similarity Measure for Multimodal Image Registration, Roche et al., 1998. https://link.springer.com/chapter/10.1007/BFb0056301

correlation_ratio(X, Y)[source]

Compute correlation ratio between two images.

Parameters:
  • X – Image tensor (B, C, H, W, D).

  • Y – Image tensor (B, C, H, W, D).

Returns:

Scalar correlation ratio.

forward(y_true, y_pred)[source]

Return negative symmetric correlation ratio as a loss.

Parameters:
  • y_true – Fixed image tensor (B, 1, …).

  • y_pred – Moving image tensor (B, 1, …).

Returns:

Scalar loss.

gaussian_kernel(diff, preterm)[source]
vol_bin_centers

Sigma for Gaussian approx.

class MIR.image_similarity.DiceLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

Dice loss

forward(y_pred, y_true)[source]

Compute Dice loss.

Parameters:
  • y_pred – Predicted probabilities (B, C, H, W, D).

  • y_true – Integer labels (B, 1, H, W, D) or (B, H, W, D).

Returns:

Scalar Dice loss.

class MIR.image_similarity.DiceSegLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(logits: torch.Tensor, target: torch.Tensor) torch.Tensor[source]

logits: (B, C, ph, pw, pd) raw scores. target: (B, ph, pw, pd) integer labels in [0..C-1].

class MIR.image_similarity.FastNCC(*args: Any, **kwargs: Any)[source]

Bases: Module

Local (over window) normalized cross correlation loss by Xi Jia, https://github.com/xi-jia/FastLNCC.

For PyTorch versions > 2.0, if there are numerical differences, add the following line:

torch.backends.cudnn.allow_tf32 = False

forward(y_true, y_pred)[source]
class MIR.image_similarity.LocalCorrRatio(*args: Any, **kwargs: Any)[source]

Bases: Module

Localized Correlation Ratio based on a Parzen window. Implemented by Junyu Chen, jchen245@jhmi.edu. TODO: under testing.

Reference: The Correlation Ratio as a New Similarity Measure for Multimodal Image Registration, Roche et al., 1998. https://link.springer.com/chapter/10.1007/BFb0056301

correlation_ratio(X, Y)[source]

Compute local correlation ratio between two images.

Parameters:
  • X – Image tensor (B, C, H, W, D).

  • Y – Image tensor (B, C, H, W, D).

Returns:

Scalar correlation ratio.

forward(y_true, y_pred)[source]

Return negative symmetric local correlation ratio as a loss.

Parameters:
  • y_true – Fixed image tensor (B, 1, …).

  • y_pred – Moving image tensor (B, 1, …).

Returns:

Scalar loss.

gaussian_kernel(diff, preterm)[source]
vol_bin_centers

Sigma for Gaussian approx.

class MIR.image_similarity.MIND_loss(*args: Any, **kwargs: Any)[source]

Bases: Module

Local (over window) normalized cross correlation loss.

MINDSSC(img, radius=2, dilation=2)[source]

Compute the MIND-SSC descriptor.

Parameters:
  • img – Image tensor (B, 1, H, W, D).

  • radius – Patch radius.

  • dilation – Neighborhood dilation.

Returns:

MIND-SSC descriptor tensor.

forward(y_pred, y_true)[source]

Compute MIND-SSC L2 loss between two images.

Parameters:
  • y_pred – Moving image tensor (B, 1, H, W, D).

  • y_true – Fixed image tensor (B, 1, H, W, D).

Returns:

Scalar loss.

pdist_squared(x)[source]

Compute pairwise squared distances.

Parameters:

x – Tensor (B, N, D).

Returns:

Pairwise squared distances (B, N, N).

class MIR.image_similarity.MutualInformation(*args: Any, **kwargs: Any)[source]

Bases: Module

Mutual Information

forward(y_true, y_pred)[source]

Return negative mutual information as a loss.

Parameters:
  • y_true – Fixed image tensor (B, 1, …).

  • y_pred – Moving image tensor (B, 1, …).

Returns:

Scalar loss.

mi(y_true, y_pred)[source]

Compute mutual information between two images.

Parameters:
  • y_true – Fixed image tensor (B, 1, …).

  • y_pred – Moving image tensor (B, 1, …).

Returns:

Scalar mutual information.

class MIR.image_similarity.NCC(*args: Any, **kwargs: Any)[source]

Bases: Module

Local (over window) normalized cross correlation loss.

forward(y_true, y_pred)[source]

Compute NCC loss over a local window.

Parameters:
  • y_true – Fixed image tensor (B, 1, …).

  • y_pred – Moving image tensor (B, 1, …).

Returns:

Scalar NCC loss (negative mean NCC).

class MIR.image_similarity.NCC_fp16(*args: Any, **kwargs: Any)[source]

Bases: Module

Local normalized cross‑correlation loss for 1‑, 2‑ or 3‑D inputs.

Parameters:
  • win (int) – Side length of the cubic averaging window. Default: 9.

  • squared (bool) –

    • False → classic NCC ( σ_xy / √(σ_x σ_y) )

    • True → squared NCC ( σ_xy² / (σ_x σ_y) )

    Default: False.

  • eps (float) – Small constant to avoid divide‑by‑zero. Default: 1e‑5.

forward(y_true: torch.Tensor, y_pred: torch.Tensor) torch.Tensor[source]
class MIR.image_similarity.NCC_gauss(*args: Any, **kwargs: Any)[source]

Bases: Module

Local (over window) normalized cross correlation loss via Gaussian

create_window_3D(window_size, channel)[source]
forward(y_true, y_pred)[source]

Compute Gaussian-windowed NCC loss.

Parameters:
  • y_true – Fixed image tensor (B, 1, …).

  • y_pred – Moving image tensor (B, 1, …).

Returns:

Scalar NCC loss (negative mean NCC).

gaussian(window_size, sigma)[source]
class MIR.image_similarity.NCC_mok(*args: Any, **kwargs: Any)[source]

Bases: Module

Local (over window) normalized cross correlation.

forward(I, J)[source]
class MIR.image_similarity.NCC_mok2(*args: Any, **kwargs: Any)[source]

Bases: Module

Local (over window) normalized cross correlation.

forward(I, J)[source]
class MIR.image_similarity.NCC_vfa(*args: Any, **kwargs: Any)[source]

Bases: Module

Multi-scale NCC from C2FViT (https://github.com/cwmok/C2FViT), suitable for FP16.

forward(I, J)[source]
class MIR.image_similarity.NCC_vxm(*args: Any, **kwargs: Any)[source]

Bases: Module

Local (over window) normalized cross correlation loss.

forward(y_true, y_pred)[source]
class MIR.image_similarity.PCC(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(I, J)[source]

Return 1 - PCC as a loss.

Parameters:
  • I – Fixed image tensor (B, 1, …).

  • J – Moving image tensor (B, 1, …).

Returns:

Scalar loss.

pcc(y_true, y_pred)[source]

Compute Pearson correlation coefficient.

Parameters:
  • y_true – Fixed image tensor (B, 1, …).

  • y_pred – Moving image tensor (B, 1, …).

Returns:

Scalar PCC.

class MIR.image_similarity.SSIM2D(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(img1, img2)[source]

Compute 2D SSIM between two images.

Parameters:
  • img1 – Tensor (B, C, H, W).

  • img2 – Tensor (B, C, H, W).

Returns:

SSIM score.

class MIR.image_similarity.SSIM3D(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(img1, img2)[source]

Compute 3D SSIM between two volumes.

Parameters:
  • img1 – Tensor (B, C, H, W, D).

  • img2 – Tensor (B, C, H, W, D).

Returns:

SSIM score.

class MIR.image_similarity.SegmentationLoss(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(logits: torch.Tensor, target: torch.Tensor) torch.Tensor[source]

logits: (B, C, ph, pw, pd). target: (B, ph, pw, pd) integer labels in [0..C-1].

MIR.image_similarity.create_window(window_size, channel)[source]

Create a 2D SSIM window tensor.

MIR.image_similarity.create_window_3D(window_size, channel)[source]

Create a 3D SSIM window tensor.

MIR.image_similarity.gaussian(window_size, sigma)[source]

Create a 1D Gaussian kernel for SSIM windows.

class MIR.image_similarity.localMutualInformation(*args: Any, **kwargs: Any)[source]

Bases: Module

Local Mutual Information for non-overlapping patches

forward(y_true, y_pred)[source]

Return negative local mutual information as a loss.

Parameters:
  • y_true – Fixed image tensor (B, 1, …).

  • y_pred – Moving image tensor (B, 1, …).

Returns:

Scalar loss.

local_mi(y_true, y_pred)[source]

Compute local mutual information over patches.

Parameters:
  • y_true – Fixed image tensor (B, 1, …).

  • y_pred – Moving image tensor (B, 1, …).

Returns:

Scalar local mutual information.

MIR.image_similarity.ssim(img1, img2, window_size=11, size_average=True)[source]
MIR.image_similarity.ssim3D(img1, img2, window_size=11, size_average=True)[source]

Regularization terms for deformation fields and velocity models.

class MIR.deformation_regularizer.DisplacementRegularizer(*args: Any, **kwargs: Any)[source]

Bases: Module

Compute displacement-field regularization energies.

compute_bending_energy(displacement)[source]

Compute bending energy of a displacement field.

Parameters:

displacement – Tensor (B, 3, H, W, D).

Returns:

Scalar bending energy.

compute_gradient_norm(displacement, flag_l1=False)[source]

Compute L1/L2 gradient norm of a displacement field.

Parameters:
  • displacement – Tensor (B, 3, H, W, D).

  • flag_l1 – If True, uses L1 norm; otherwise L2.

Returns:

Scalar gradient norm.

forward(disp, _)[source]
gradient_dx(fv)[source]
gradient_dy(fv)[source]
gradient_dz(fv)[source]
gradient_txyz(Txyz, fn)[source]
class MIR.deformation_regularizer.Grad2D(*args: Any, **kwargs: Any)[source]

Bases: Module

2D gradient loss.

forward(y_pred, y_true)[source]

Compute 2D gradient regularization loss.

Parameters:
  • y_pred – Predicted displacement/field tensor (B, C, H, W).

  • y_true – Unused placeholder for API compatibility.

Returns:

Scalar gradient penalty.

class MIR.deformation_regularizer.Grad3DiTV(*args: Any, **kwargs: Any)[source]

Bases: Module

3D gradient Isotropic TV loss.

forward(y_pred, y_true)[source]

Compute isotropic total-variation loss in 3D.

Parameters:
  • y_pred – Predicted displacement/field tensor (B, C, H, W, D).

  • y_true – Unused placeholder for API compatibility.

Returns:

Scalar TV penalty.

class MIR.deformation_regularizer.Grad3d(*args: Any, **kwargs: Any)[source]

Bases: Module

3D gradient loss.

forward(y_pred, y_true=None)[source]

Compute 3D gradient regularization loss.

Parameters:
  • y_pred – Predicted displacement/field tensor (B, C, H, W, D).

  • y_true – Unused placeholder for API compatibility.

Returns:

Scalar gradient penalty.
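In practice this penalty is combined with an image-similarity term; a sketch of a typical registration objective (the 0.1 weight is illustrative):

import torch
from MIR.image_similarity import NCC_vxm
from MIR.deformation_regularizer import Grad3d

fixed = torch.rand(1, 1, 96, 96, 96)
warped = torch.rand(1, 1, 96, 96, 96)   # moving image after warping
flow = torch.randn(1, 3, 96, 96, 96)    # predicted displacement field

loss = NCC_vxm()(fixed, warped) + 0.1 * Grad3d()(flow)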

class MIR.deformation_regularizer.GradICON3d(*args: Any, **kwargs: Any)[source]

Bases: Module

Gradient‑ICON loss for 3‑D displacement fields. Penalises the Frobenius‑norm of the Jacobian of the composition Φ^{AB}∘Φ^{BA} (forward ◦ inverse).

forward(flow_fwd, flow_inv)[source]
Returns:

Scalar GradICON penalty (loss).

class MIR.deformation_regularizer.GradICONExact3d(*args: Any, **kwargs: Any)[source]

Bases: Module

Paper‑faithful Gradient‑ICON for 3‑D flows

forward(flow_fwd, flow_inv)[source]

Parameters:
  • flow_fwd – Voxel-unit displacement field (B, 3, D, H, W).

  • flow_inv – Voxel-unit displacement field (B, 3, D, H, W).

class MIR.deformation_regularizer.KL_divergence(*args: Any, **kwargs: Any)[source]

Bases: Module

KL divergence between factorized Gaussian fields.

forward(P, Q)[source]

Compute KL divergence between two Gaussian fields.

Parameters:
  • P – Tuple (mean_p, log_sigma_p) tensors.

  • Q – Tuple (mean_q, log_sigma_q) tensors.

Returns:

Scalar KL divergence.

class MIR.deformation_regularizer.LocalGrad3d(*args: Any, **kwargs: Any)[source]

Bases: Module

Local 3D gradient loss.

forward(y_pred, weight)[source]

Compute weighted 3D gradient regularization.

Parameters:
  • y_pred – Predicted displacement/field tensor (B, C, H, W, D).

  • weight – Spatial weight tensor (B, 1, H, W, D).

Returns:

Scalar weighted gradient penalty.

class MIR.deformation_regularizer.MultiVariateKL_divergence(*args: Any, **kwargs: Any)[source]

Bases: Module

KL divergence between multivariate Gaussian fields.

forward(P, Q)[source]

Compute multivariate KL divergence between Gaussian fields.

Parameters:
  • P – Tuple (mean_p, log_sigma_p) tensors.

  • Q – Tuple (mean_q, log_sigma_q) tensors.

Returns:

Scalar KL divergence.

class MIR.deformation_regularizer.logBeta(*args: Any, **kwargs: Any)[source]

Bases: Module

Negative log-likelihood term for Beta prior on weights.

forward(weights, alpha)[source]

Compute Beta prior regularization.

Parameters:
  • weights – Tensor of weights.

  • alpha – Beta distribution alpha parameter.

Returns:

Scalar regularization loss.

class MIR.deformation_regularizer.logGaussian(*args: Any, **kwargs: Any)[source]

Bases: Module

Gaussian prior regularizer for weights.

forward(weights, inv_sigma2)[source]

Compute Gaussian prior regularization.

Parameters:
  • weights – Tensor of weights.

  • inv_sigma2 – Inverse variance scalar/tensor.

Returns:

Scalar regularization loss.

Accuracy metrics and Jacobian-based measures for evaluation.

MIR.accuracy_measures.calc_J_i(trans, grad_args)[source]

Compute Jacobian determinant using a specific finite-difference stencil.

Parameters:
  • trans – Displacement field (3, H, W, D).

  • grad_args – Gradient stencil spec (e.g., ‘0x0y0z’, ‘+x-y+z’).

Returns:

Jacobian determinant volume (H-2, W-2, D-2).

MIR.accuracy_measures.calc_Jstar_1(trans)[source]

Compute Jacobian determinant with forward-difference stencil.

Parameters:

trans – Displacement field (3, H, W, D).

Returns:

Jacobian determinant volume (H-2, W-2, D-2).

MIR.accuracy_measures.calc_Jstar_2(trans)[source]

Compute Jacobian determinant with backward-difference stencil.

Parameters:

trans – Displacement field (3, H, W, D).

Returns:

Jacobian determinant volume (H-2, W-2, D-2).

MIR.accuracy_measures.calc_jac_dets(trans)[source]

Compute multiple Jacobian determinant variants and consistency mask.

Parameters:

trans – Displacement field (3, H, W, D).

Returns:

Dict of Jacobian determinant volumes keyed by stencil string.
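For example, to report the fraction of non-positive determinants per stencil (a sketch; whether the consistency mask also appears in the returned dict is left to the implementation):

import numpy as np
from MIR.accuracy_measures import calc_jac_dets

trans = 0.5 * np.random.randn(3, 96, 96, 96)   # displacement field (3, H, W, D)
jac_dets = calc_jac_dets(trans)
for key, det in jac_dets.items():
    frac = float((np.asarray(det) <= 0).mean())
    print(f'{key}: {frac:.4%} non-positive')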

MIR.accuracy_measures.calc_measurements(jac_dets, mask)[source]
MIR.accuracy_measures.dice_val_VOI(y_pred, y_true, num_clus=4, eval_labels=None)[source]

Compute mean Dice over a set of labels of interest.

Parameters:
  • y_pred – Predicted label tensor (B, 1, H, W, D).

  • y_true – Ground-truth label tensor (B, 1, H, W, D).

  • num_clus – Number of classes (used when eval_labels is None).

  • eval_labels – Optional list/array of label IDs to evaluate.

Returns:

Mean Dice score across the selected labels.
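A sketch with toy label maps:

import torch
from MIR.accuracy_measures import dice_val_VOI

y_pred = torch.randint(0, 4, (1, 1, 96, 96, 96))
y_true = torch.randint(0, 4, (1, 1, 96, 96, 96))
dsc_all = dice_val_VOI(y_pred, y_true, num_clus=4)
dsc_voi = dice_val_VOI(y_pred, y_true, eval_labels=[1, 3])  # selected structures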

MIR.accuracy_measures.dice_val_all(y_pred, y_true, num_clus)[source]

Compute mean Dice across all classes.

Parameters:
  • y_pred – Predicted label tensor (B, 1, H, W, D).

  • y_true – Ground-truth label tensor (B, 1, H, W, D).

  • num_clus – Number of classes.

Returns:

Scalar mean Dice across classes.

MIR.accuracy_measures.dice_val_substruct(y_pred, y_true, std_idx, num_classes=46)[source]

Compute per-class Dice scores and return as a CSV line string.

Parameters:
  • y_pred – Predicted label tensor (B, 1, H, W, D).

  • y_true – Ground-truth label tensor (B, 1, H, W, D).

  • std_idx – Case identifier used in the output line.

  • num_classes – Total number of classes.

Returns:

CSV line string with per-class Dice values.

MIR.accuracy_measures.get_identity_grid(array)[source]

Return the identity transformation with the same size as the input. Expected input shape: (3, H, W, S).