Non-Rigid Registration: Free Form Deformation

This notebook illustrates the use of the Free Form Deformation (FFD) based non-rigid registration algorithm in SimpleITK.

The data we work with is a 4D (3D+time) thoracic-abdominal CT, the Point-validated Pixel-based Breathing Thorax Model (POPI). This data consists of a set of temporal CT volumes, a set of masks segmenting each of the CTs into air/body/lung, and a set of corresponding points across the CT volumes.

The POPI model is provided by the Léon Bérard Cancer Center & CREATIS Laboratory, Lyon, France. The relevant publication is:

J. Vandemeulebroucke, D. Sarrut, P. Clarysse, "The POPI-model, a point-validated pixel-based breathing thorax model", Proc. XVth International Conference on the Use of Computers in Radiation Therapy (ICCR), Toronto, Canada, 2007.

The POPI data, and additional 4D CT data sets with reference points, are available from the CREATIS Laboratory.

In [1]:
from __future__ import print_function

import SimpleITK as sitk
import registration_utilities as ru
import registration_callbacks as rc

%matplotlib inline
import matplotlib.pyplot as plt

from ipywidgets import interact, fixed

# Utility method that either downloads data from the Girder repository or,
# if the data was already downloaded, returns the file name for reading from disk (cached data).
%run update_path_to_download_script
from downloaddata import fetch_data as fdata

Utilities

Load utilities that are specific to the POPI data: functions for loading the ground truth data, functions for display, and the label values used in the masks.

In [2]:
%run popi_utilities_setup.py

Loading Data

Load all of the images, masks and point data into corresponding lists. If the data is not available locally it will be downloaded from the original remote repository.

Take a look at the images. According to the documentation on the POPI site, volume number one corresponds to end inspiration (maximal air volume).

In [3]:
images = []
masks = []
points = []
for i in range(0,10):
    image_file_name = 'POPI/meta/{0}0-P.mhd'.format(i)
    mask_file_name = 'POPI/masks/{0}0-air-body-lungs.mhd'.format(i)
    points_file_name = 'POPI/landmarks/{0}0-Landmarks.pts'.format(i)
    images.append(sitk.ReadImage(fdata(image_file_name), sitk.sitkFloat32)) #read and cast to format required for registration
    masks.append(sitk.ReadImage(fdata(mask_file_name)))
    points.append(read_POPI_points(fdata(points_file_name)))
        
interact(display_coronal_with_overlay, temporal_slice=(0,len(images)-1), 
         coronal_slice = (0, images[0].GetSize()[1]-1), 
         images = fixed(images), masks = fixed(masks), 
         label=fixed(lung_label), window_min = fixed(-1024), window_max=fixed(976));

Getting to know your data

The POPI site states that image number 1 is end inspiration, and visual inspection appears to agree. Still, it is worth checking the lung volumes to confirm that what we expect is indeed what is happening.

Which image is end inspiration and which end expiration?

In [4]:
label_shape_statistics_filter = sitk.LabelShapeStatisticsImageFilter()

for i, mask in enumerate(masks):
    label_shape_statistics_filter.Execute(mask)
    # GetPhysicalSize returns the label's volume in physical units (mm^3 here); 1 liter = 10^6 mm^3.
    print('Lung volume in image {0} is {1} liters.'.format(i,0.000001*label_shape_statistics_filter.GetPhysicalSize(lung_label)))
Lung volume in image 0 is 5.455734929689075 liters.
Lung volume in image 1 is 5.527017905172941 liters.
Lung volume in image 2 is 5.554014379499289 liters.
Lung volume in image 3 is 5.451784830895104 liters.
Lung volume in image 4 is 5.301415956621658 liters.
Lung volume in image 5 is 5.191841246039885 liters.
Lung volume in image 6 is 5.09130732168864 liters.
Lung volume in image 7 is 5.11128669632256 liters.
Lung volume in image 8 is 5.195802788867059 liters.
Lung volume in image 9 is 5.335378032491021 liters.
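The question above can also be answered programmatically. A minimal sketch, assuming the volumes have been collected into a plain Python list (here copied from the output above and rounded to three decimals), takes the largest lung volume as end inspiration and the smallest as end expiration. `find_breathing_extremes` is a hypothetical helper, not part of the POPI utilities.

```python
def find_breathing_extremes(lung_volumes):
    """Return (end_inspiration_index, end_expiration_index).

    End inspiration is taken as the image with the largest lung volume,
    end expiration as the image with the smallest lung volume.
    """
    indices = range(len(lung_volumes))
    end_inspiration = max(indices, key=lambda i: lung_volumes[i])
    end_expiration = min(indices, key=lambda i: lung_volumes[i])
    return end_inspiration, end_expiration


# Lung volumes in liters, copied (rounded) from the output above.
volumes = [5.456, 5.527, 5.554, 5.452, 5.301, 5.192, 5.091, 5.111, 5.196, 5.335]
print(find_breathing_extremes(volumes))  # → (2, 6)
```

In the notebook itself, the list could be built inside the loop of the previous cell by appending `0.000001*label_shape_statistics_filter.GetPhysicalSize(lung_label)` for each mask. With these masks, the largest lung volume occurs in image 2 and the smallest in image 6.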

Free Form Deformation

This function aligns the fixed and moving images using an FFD. If given a mask, the similarity metric is evaluated using points sampled inside the mask. If given fixed and moving points, the similarity metric value and the target registration errors are displayed during registration.

As this notebook performs intra-modal registration, we use the MeanSquares similarity metric (simple to compute and appropriate for the task).
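To make the metric settings concrete, here is a simplified pure-Python sketch of a mean squares value estimated from a random sample of positions inside a mask. It rests on simplifying assumptions: both images are flat lists of intensities on the same grid (that is, the moving image has already been warped by the current transform), the mask is a list of booleans, and the sampling percentage is applied to the in-mask positions. `sampled_mean_squares` is a hypothetical helper and does not reproduce ITK's exact sampling scheme.

```python
import random


def sampled_mean_squares(fixed, warped_moving, mask, sampling_percentage, seed=0):
    """Estimate the mean squared intensity difference over randomly sampled mask positions.

    fixed, warped_moving: intensities on the same grid (flat lists of equal length).
    mask: booleans marking where the metric may be evaluated.
    sampling_percentage: fraction in (0, 1] of the in-mask positions to use.
    """
    candidates = [i for i, inside in enumerate(mask) if inside]
    n_samples = max(1, int(len(candidates) * sampling_percentage))
    rng = random.Random(seed)  # fixed seed so the estimate is reproducible
    samples = rng.sample(candidates, n_samples)
    return sum((fixed[i] - warped_moving[i]) ** 2 for i in samples) / n_samples


fixed = [0.0, 1.0, 2.0, 3.0]
moving = [0.0, 1.0, 4.0, 3.0]
mask = [True, True, True, False]  # the last position is outside the mask
print(sampled_mean_squares(fixed, moving, mask, 1.0))  # → 1.3333333333333333
```

Perfectly aligned intensities give a value of zero. Evaluating only a small fraction of the positions, as `SetMetricSamplingPercentage(0.01)` does in the function below, trades accuracy of each metric estimate for speed.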

In [5]:
def bspline_intra_modal_registration(fixed_image, moving_image, fixed_image_mask=None, fixed_points=None, moving_points=None):

    registration_method = sitk.ImageRegistrationMethod()
    
    # Determine the number of BSpline control points using the physical spacing we want for the control grid. 
    grid_physical_spacing = [50.0, 50.0, 50.0] # A control point every 50mm
    image_physical_size = [size*spacing for size,spacing in zip(fixed_image.GetSize(), fixed_image.GetSpacing())]
    mesh_size = [int(image_size/grid_spacing + 0.5) \
                 for image_size,grid_spacing in zip(image_physical_size,grid_physical_spacing)]

    initial_transform = sitk.BSplineTransformInitializer(image1 = fixed_image, 
                                                         transformDomainMeshSize = mesh_size, order=3)    
    registration_method.SetInitialTransform(initial_transform)
        
    registration_method.SetMetricAsMeanSquares()
    # Settings for metric sampling, usage of a mask is optional. When given a mask the sample points will be 
    # generated inside that region. Also, this implicitly speeds things up as the mask is smaller than the
    # whole image.
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)
    if fixed_image_mask:
        registration_method.SetMetricFixedMask(fixed_image_mask)
    
    # Multi-resolution framework.            
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2,1,0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    registration_method.SetInterpolator(sitk.sitkLinear)
    registration_method.SetOptimizerAsLBFGSB(gradientConvergenceTolerance=1e-5, numberOfIterations=100)
    

    # If corresponding points in the fixed and moving image are given then we display the similarity metric
    # and the TRE during the registration.
    if fixed_points and moving_points:
        registration_method.AddCommand(sitk.sitkStartEvent, rc.metric_and_reference_start_plot)
        registration_method.AddCommand(sitk.sitkEndEvent, rc.metric_and_reference_end_plot)
        registration_method.AddCommand(sitk.sitkIterationEvent, lambda: rc.metric_and_reference_plot_values(registration_method, fixed_points, moving_points))
    
    return registration_method.Execute(fixed_image, moving_image)
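The mesh size computation at the top of the function can be checked by hand. The sketch below repeats it in plain Python for a hypothetical volume of 512 x 512 x 100 voxels with 1.0 x 1.0 x 2.5 mm spacing (illustrative numbers, not the POPI image geometry). It also counts the transform parameters, assuming the ITK convention that a BSpline of order 3 has mesh size + 3 control points per dimension, each carrying one displacement component per dimension.

```python
def bspline_mesh_size(image_size, image_spacing, grid_physical_spacing):
    """Number of BSpline mesh elements per dimension, rounded to the nearest integer."""
    image_physical_size = [size * spacing for size, spacing in zip(image_size, image_spacing)]
    return [int(physical / grid + 0.5)
            for physical, grid in zip(image_physical_size, grid_physical_spacing)]


def bspline_parameter_count(mesh_size, order=3):
    """Total number of transform parameters: one displacement per dimension per control point."""
    control_points = 1
    for elements in mesh_size:
        control_points *= elements + order
    return len(mesh_size) * control_points


# Hypothetical geometry, for illustration only.
mesh = bspline_mesh_size([512, 512, 100], [1.0, 1.0, 2.5], [50.0, 50.0, 50.0])
print(mesh)                           # → [10, 10, 5]
print(bspline_parameter_count(mesh))  # → 4056
```

Even a coarse 50 mm grid yields thousands of parameters, which is one reason a quasi-Newton optimizer such as L-BFGS-B is a natural choice here.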

Perform Registration

The following cell selects the images used for registration, runs the registration, and then computes statistics comparing the target registration errors (TREs) before and after registration, displaying a histogram of the TREs.

To time the registration, uncomment the %%timeit magic on the first line of the cell. Note that this creates a separate scope for the cell: variables set inside the cell, specifically tx, become local variables, so their values are not available in other cells.
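If you need both the timing and the resulting tx, one alternative to %%timeit is to time a single call explicitly, so the result stays in the notebook's global scope. A minimal sketch using Python's `time.perf_counter`; `timed_call` is a hypothetical helper, and a single run gives a rougher estimate than timeit's repeated runs.

```python
import time


def timed_call(function, *args, **kwargs):
    """Call function once and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = function(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


# Stand-in workload; in the notebook the function could be bspline_intra_modal_registration.
result, seconds = timed_call(sum, range(1000000))
print(result)  # → 499999500000
print('Elapsed: {:.3f} s'.format(seconds))
```

In the registration cell, the call to `bspline_intra_modal_registration` could be wrapped the same way, passing the same keyword arguments, so that `tx` remains available to later cells.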

In [6]:
#%%timeit -r1 -n1

# Select the fixed and moving images, valid entries are in [0,9].
fixed_image_index = 0
moving_image_index = 7


tx = bspline_intra_modal_registration(fixed_image = images[fixed_image_index], 
                                      moving_image = images[moving_image_index],
                                      fixed_image_mask = (masks[fixed_image_index] == lung_label),
                                      fixed_points = points[fixed_image_index], 
                                      moving_points = points[moving_image_index]
                                     )
# The default Euler3DTransform is the identity, so it yields the errors before registration.
initial_errors_mean, initial_errors_std, _, initial_errors_max, initial_errors = ru.registration_errors(sitk.Euler3DTransform(), points[fixed_image_index], points[moving_image_index])
final_errors_mean, final_errors_std, _, final_errors_max, final_errors = ru.registration_errors(tx, points[fixed_image_index], points[moving_image_index])

plt.hist(initial_errors, bins=20, alpha=0.5, label='before registration', color='blue')
plt.hist(final_errors, bins=20, alpha=0.5, label='after registration', color='green')
plt.legend()
plt.title('TRE histogram');
print('Initial alignment errors in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(initial_errors_mean, initial_errors_std, initial_errors_max))
print('Final alignment errors in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(final_errors_mean, final_errors_std, final_errors_max))
Initial alignment errors in millimeters, mean(std): 5.07(2.67), max: 14.02
Final alignment errors in millimeters, mean(std): 1.56(0.93), max: 4.53
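The TRE at each landmark is the distance between the fixed-image point mapped through the transform and its corresponding moving-image point (the estimated transform maps points from the fixed image to the moving image). A minimal pure-Python sketch of these statistics, assuming the transform is given as a Python function on 3D points (a stand-in for a SimpleITK transform's `TransformPoint`) and using the population standard deviation; `tre_statistics` is a hypothetical helper, not the implementation of `ru.registration_errors`.

```python
import math
import statistics


def tre_statistics(transform_point, fixed_points, moving_points):
    """Return (mean, std, max, errors) of the target registration errors.

    transform_point maps a fixed-image point into the moving image's space.
    """
    errors = []
    for p, q in zip(fixed_points, moving_points):
        mapped = transform_point(p)
        errors.append(math.sqrt(sum((a - b) ** 2 for a, b in zip(mapped, q))))
    return statistics.mean(errors), statistics.pstdev(errors), max(errors), errors


def identity(point):
    """Identity mapping, playing the role of the default Euler3DTransform above."""
    return point


fixed_pts = [(0.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
moving_pts = [(3.0, 4.0, 0.0), (10.0, 0.0, 0.0)]
mean, std, worst, errors = tre_statistics(identity, fixed_pts, moving_pts)
print(errors)            # → [5.0, 0.0]
print(mean, std, worst)  # → 2.5 2.5 5.0
```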

Another option for evaluating the registration is to use segmentation. In this case, we transfer the segmentation from one image to the other and compare the overlaps, both visually and quantitatively.

Note: A more detailed version of the approach described here can be found in the Segmentation Evaluation notebook.
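The cell below resamples the label image with nearest neighbor interpolation. A small 1D sketch shows why: linearly interpolating between two label values can produce a value that is not a valid label, whereas nearest neighbor always returns one of the existing labels. The label values 1 and 3 and both helper functions are hypothetical illustrations, not SimpleITK calls.

```python
import math


def linear_1d(values, x):
    """Linearly interpolate a 1D list of samples at position x (0 <= x <= len(values) - 1)."""
    left = min(int(math.floor(x)), len(values) - 2)
    t = x - left
    return (1 - t) * values[left] + t * values[left + 1]


def nearest_1d(values, x):
    """Return the sample closest to position x (ties round up to the higher index)."""
    index = min(int(math.floor(x + 0.5)), len(values) - 1)
    return values[index]


labels = [1, 3]  # two neighboring voxels with different (hypothetical) labels
print(linear_1d(labels, 0.5))   # → 2.0, a label that exists in neither voxel
print(nearest_1d(labels, 0.5))  # → 3
```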

In [7]:
# Transfer the segmentation via the estimated transformation. Use Nearest Neighbor interpolation to retain the labels.
transformed_labels = sitk.Resample(masks[moving_image_index],
                                   images[fixed_image_index],
                                   tx, 
                                   sitk.sitkNearestNeighbor,
                                   0.0, 
                                   masks[moving_image_index].GetPixelID())

segmentations_before_and_after = [masks[moving_image_index], transformed_labels]
interact(display_coronal_with_label_maps_overlay, coronal_slice = (0, images[0].GetSize()[1]-1),
         mask_index=(0,len(segmentations_before_and_after)-1),
         image = fixed(images[fixed_image_index]), masks = fixed(segmentations_before_and_after), 
         label=fixed(lung_label), window_min = fixed(-1024), window_max=fixed(976));

# Compute the Dice coefficient and Hausdorff distance between the segmentations before, and after registration.
ground_truth = masks[fixed_image_index] == lung_label
before_registration = masks[moving_image_index] == lung_label
after_registration = transformed_labels == lung_label

label_overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter()
label_overlap_measures_filter.Execute(ground_truth, before_registration)
print("Dice coefficient before registration: {:.2f}".format(label_overlap_measures_filter.GetDiceCoefficient()))
label_overlap_measures_filter.Execute(ground_truth, after_registration)
print("Dice coefficient after registration: {:.2f}".format(label_overlap_measures_filter.GetDiceCoefficient()))

hausdorff_distance_image_filter = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_image_filter.Execute(ground_truth, before_registration)
print("Hausdorff distance before registration: {:.2f}".format(hausdorff_distance_image_filter.GetHausdorffDistance()))
hausdorff_distance_image_filter.Execute(ground_truth, after_registration)
print("Hausdorff distance after registration: {:.2f}".format(hausdorff_distance_image_filter.GetHausdorffDistance()))
Dice coefficient before registration: 0.94
Dice coefficient after registration: 0.97
Hausdorff distance before registration: 18.04
Hausdorff distance after registration: 14.35
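The two measures used above can be written down directly. For binary segmentations A and B, the Dice coefficient is 2|A ∩ B| / (|A| + |B|), ranging from 0 (no overlap) to 1 (identical). The Hausdorff distance is the largest distance from a point in one set to the nearest point of the other set, taken in both directions, so a single outlying region can dominate it. A brute-force sketch on small sets of voxel coordinates, assuming unit spacing; the helper names are hypothetical, and the SimpleITK filters operate on images and use more efficient algorithms.

```python
import math


def dice_coefficient(a, b):
    """Dice overlap of two sets of voxel coordinates: 2 * |A & B| / (|A| + |B|)."""
    if not a and not b:
        return 1.0
    return 2.0 * len(a & b) / (len(a) + len(b))


def directed_hausdorff(a, b):
    """Largest distance from a point in a to its nearest point in b (brute force)."""
    return max(min(math.sqrt(sum((p - q) ** 2 for p, q in zip(x, y))) for y in b) for x in a)


def hausdorff_distance(a, b):
    """Symmetric Hausdorff distance: the larger of the two directed distances."""
    return max(directed_hausdorff(a, b), directed_hausdorff(b, a))


ground_truth = {(0, 0, 0), (1, 0, 0), (2, 0, 0)}
segmentation = {(1, 0, 0), (2, 0, 0), (5, 0, 0)}
print(dice_coefficient(ground_truth, segmentation))    # → 0.6666666666666666
print(hausdorff_distance(ground_truth, segmentation))  # → 3.0
```

Here the single stray voxel at (5, 0, 0) alone determines the Hausdorff distance, which illustrates why the two measures are reported together.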

Multi-resolution control point grid

In the example above we used the standard image registration framework, which implies the same transformation model at all image resolutions. For global transformations (e.g. rigid, affine) the number of transformation parameters is independent of the image resolution. For the BSpline transformation, we can use fewer control points at the coarser levels of the image pyramid, where the images contain only low frequencies, and increase the number of control points as we move down the pyramid to the finer levels. The standard framework, in contrast, uses the same number of control points at all pyramid levels.

To use a multi-resolution control point grid, the registration method provides a BSpline-specific initialization method, SetInitialTransformAsBSpline.

The following code solves the same registration task as above, just with a multi-resolution control point grid.

In [8]:
def bspline_intra_modal_registration2(fixed_image, moving_image, fixed_image_mask=None, fixed_points=None, moving_points=None):

    registration_method = sitk.ImageRegistrationMethod()
    
    # Determine the number of BSpline control points using the physical spacing we 
    # want for the finest resolution control grid. 
    grid_physical_spacing = [50.0, 50.0, 50.0] # A control point every 50mm
    image_physical_size = [size*spacing for size,spacing in zip(fixed_image.GetSize(), fixed_image.GetSpacing())]
    mesh_size = [int(image_size/grid_spacing + 0.5) \
                 for image_size,grid_spacing in zip(image_physical_size,grid_physical_spacing)]

    # The starting mesh size will be 1/4 of the original, it will be refined by 
    # the multi-resolution framework.
    mesh_size = [int(sz/4 + 0.5) for sz in mesh_size]
    
    initial_transform = sitk.BSplineTransformInitializer(image1 = fixed_image, 
                                                         transformDomainMeshSize = mesh_size, order=3)    
    # Instead of the standard SetInitialTransform we use the BSpline specific method, which also
    # accepts the scaleFactors parameter to refine the BSpline mesh. In this case we start with
    # the given mesh_size at the highest (coarsest) pyramid level, double it at the next level, and
    # at full resolution use a mesh that is four times the initial mesh_size.
    registration_method.SetInitialTransformAsBSpline(initial_transform,
                                                     inPlace=True,
                                                     scaleFactors=[1,2,4])
    registration_method.SetMetricAsMeanSquares()
    # Settings for metric sampling, usage of a mask is optional. When given a mask the sample points will be 
    # generated inside that region. Also, this implicitly speeds things up as the mask is smaller than the
    # whole image.
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)
    if fixed_image_mask:
        registration_method.SetMetricFixedMask(fixed_image_mask)
    
    # Multi-resolution framework.            
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2,1,0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    registration_method.SetInterpolator(sitk.sitkLinear)
    # Use LBFGS2 instead of LBFGSB (used in the previous function). The latter cannot adapt to the changing control grid resolution.
    registration_method.SetOptimizerAsLBFGS2(solutionAccuracy=1e-2, numberOfIterations=100, deltaConvergenceTolerance=0.01)

    # If corresponding points in the fixed and moving image are given then we display the similarity metric
    # and the TRE during the registration.
    if fixed_points and moving_points:
        registration_method.AddCommand(sitk.sitkStartEvent, rc.metric_and_reference_start_plot)
        registration_method.AddCommand(sitk.sitkEndEvent, rc.metric_and_reference_end_plot)
        registration_method.AddCommand(sitk.sitkIterationEvent, lambda: rc.metric_and_reference_plot_values(registration_method, fixed_points, moving_points))
    
    return registration_method.Execute(fixed_image, moving_image)
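The mesh schedule that this function sets up can be traced by hand. Assuming, as the comment in the code describes, that the mesh size at each pyramid level is the initial mesh size multiplied by the corresponding scale factor, the sketch below uses a hypothetical 512 x 512 x 100 voxel volume with 1.0 x 1.0 x 2.5 mm spacing (illustrative, not the POPI data). It also shows that, because of the rounding in `int(sz/4 + 0.5)`, the finest mesh need not equal the mesh used by the single-resolution function.

```python
def single_resolution_mesh(image_size, image_spacing, grid_physical_spacing):
    """Mesh size for the finest control grid, as computed at the start of both functions."""
    return [int(size * spacing / grid + 0.5)
            for size, spacing, grid in zip(image_size, image_spacing, grid_physical_spacing)]


def mesh_schedule(finest_mesh, scale_factors=(1, 2, 4)):
    """Per-level mesh sizes: the finest mesh divided by 4 and rounded, then scaled per level."""
    initial_mesh = [int(sz / 4 + 0.5) for sz in finest_mesh]
    return [[sz * factor for sz in initial_mesh] for factor in scale_factors]


# Hypothetical geometry, for illustration only.
finest = single_resolution_mesh([512, 512, 100], [1.0, 1.0, 2.5], [50.0, 50.0, 50.0])
print(finest)                 # → [10, 10, 5]
print(mesh_schedule(finest))  # → [[3, 3, 1], [6, 6, 2], [12, 12, 4]]
```

For this geometry, the full-resolution level ends up with a 12 x 12 x 4 mesh rather than 10 x 10 x 5, so the two functions do not optimize exactly the same final transformation model.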
In [9]:
#%%timeit -r1 -n1

# Select the fixed and moving images, valid entries are in [0,9].
fixed_image_index = 0
moving_image_index = 7


tx = bspline_intra_modal_registration2(fixed_image = images[fixed_image_index], 
                                       moving_image = images[moving_image_index],
                                       fixed_image_mask = (masks[fixed_image_index] == lung_label),
                                       fixed_points = points[fixed_image_index], 
                                       moving_points = points[moving_image_index]
                                      )
initial_errors_mean, initial_errors_std, _, initial_errors_max, initial_errors = ru.registration_errors(sitk.Euler3DTransform(), points[fixed_image_index], points[moving_image_index])
final_errors_mean, final_errors_std, _, final_errors_max, final_errors = ru.registration_errors(tx, points[fixed_image_index], points[moving_image_index])

plt.hist(initial_errors, bins=20, alpha=0.5, label='before registration', color='blue')
plt.hist(final_errors, bins=20, alpha=0.5, label='after registration', color='green')
plt.legend()
plt.title('TRE histogram');
print('Initial alignment errors in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(initial_errors_mean, initial_errors_std, initial_errors_max))
print('Final alignment errors in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(final_errors_mean, final_errors_std, final_errors_max))
Initial alignment errors in millimeters, mean(std): 5.07(2.67), max: 14.02
Final alignment errors in millimeters, mean(std): 1.77(1.23), max: 7.22