3D Brain Tumor Segmentation in PyTorch using U-Net & Eigenvector projection for color invariance
https://github.com/Saswatm123/3D-Brain-Tumor-Segmentation-PyTorch
Preface: I'm looking for a CV or ML/MLE job since getting my last offer rescinded in December.
I created a 3D Brain Tumor segmentation model using PyTorch that takes in human brain MRI scan data as input and outputs a segmentation map of the part of the brain that has a tumor, if one is present.
Predicted segmentation map for previous input
There are many more pictures/GIFs on the README, but I would like to briefly go over the project here.
I settled on a U-Net architecture after experimenting with a couple of other architecture styles. The skip connections really make a difference in the pixel-by-pixel accuracy of the segmentation map produced.
One interesting issue that took a bit of extra math to solve involved image color. The scans often arrive in arbitrary hues since they come from different machines: one brain image might be blue-on-black, while another is orange-on-green. Tumorous tissue is detected by a deviation in color, as can be seen in the first GIF, where even an untrained eye can pick out the tumor location; the specific hue does not matter. Examples of this multiple-hue effect can be seen in the README. This not only increases the dimensionality of the problem space, but also frequently over-activates the skip connections. For example, if the network is used to lower-intensity color schemes (blue on black), a brighter color scheme (orange on green) produces an almost fully activated segmentation map, since the first skip connection simply forwards the image to the last couple of layers. I needed a way to make the images color-invariant.
The usual fix is grayscaling, which collapses each pixel's (R, G, B) vector into a single "brightness" value (e.g. its L2 norm). That does not work for this use case: an L2 norm compresses the entire shell of a 3D sphere down to a single point, so distinct colors sitting on the same shell, say [0,1,1], [1,0,1], and [1,1,0], would all be considered the same, and a tumor would go undetected. We need to preserve the 3D distances between color points while ignoring the actual hue.
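To make that concrete, here is a tiny made-up illustration of how a per-pixel norm throws away exactly the separation we care about:

import torch

# Three distinct colors that happen to lie on the same L2 sphere shell
colors = torch.tensor([[0., 1., 1.],
                       [1., 0., 1.],
                       [1., 1., 0.]])
print(colors.norm(dim= 1))  # tensor([1.4142, 1.4142, 1.4142]) -- all collapse to the same "brightness"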
Solution: We view each image as a 5D point cloud of (x, y, R, G, B), where (x, y) are the coordinates per pixel, and (R, G, B) are the values for the pixel. We may ignore (x, y) for now and focus on the (R, G, B) values. Color invariance while maintaining shape is now simply a problem of scale, translation, and rotation invariance of a 3D point cloud.
Translation invariance is trivial - we simply subtract the mean along each axis, centering the cloud. Any configuration of this point cloud that has the same shape, but is translated differently, then maps to the same position.
Rotation invariance can then be achieved by taking the eigenvectors of the centered point cloud's covariance matrix, ordering them by eigenvalue, and mapping them to axes (largest = axis 0, second largest = axis 1, etc.). We then rotate our point cloud into this eigenvector basis. This ends up being a 1-sample PCA, where the sample is the point cloud image. The README shows a table of various images with this technique applied to them, along with their point cloud representations.
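Here is a minimal sketch of that projection (my own illustrative rewrite, not the repo's exact code), assuming the image arrives as a (3, H, W) RGB tensor:

import torch

def color_pca_normalize(img, eps= 1e-8):
    '''
    Illustrative sketch only. Centers the image's RGB point cloud, then rotates it
    into the eigenvector basis of its own color covariance (a 1-sample PCA).
    img is assumed to be a (3, H, W) RGB tensor.
    '''
    C, H, W = img.shape
    pixels = img.reshape(C, -1)                              # (3, H*W) point cloud of RGB values
    centered = pixels - pixels.mean(dim= 1, keepdim= True)   # translation invariance: center each axis
    cov = centered @ centered.T / (H * W - 1)                # 3x3 covariance of the color point cloud
    eigvals, eigvecs = torch.linalg.eigh(cov)                # eigh returns eigenvalues in ascending order
    order = torch.argsort(eigvals, descending= True)
    eigvecs = eigvecs[:, order]                              # largest-variance direction -> axis 0, etc.
    rotated = eigvecs.T @ centered                           # rotation invariance: project onto eigenvectors
    scaled = rotated / (rotated.std(dim= 1, keepdim= True) + eps)  # scale invariance
    return scaled.reshape(C, H, W)

One practical wrinkle: eigenvector signs are arbitrary, so fixing a sign convention (e.g. forcing each eigenvector's largest-magnitude component to be positive) keeps the projected images consistent across scans.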
This technique helps my model beat the human accuracy benchmark, and the problem of skip connections being overwhelmed or thrown off by varying color schemes goes away.
I prefer solutions involving invariance & explicit inductive bias over augmentation, because augmentation is exponential in time & space: if there are 5 factors with 3 levels each that we wish to make our model robust to, the dataset multiplier is 3^5 = 243, and we can get rid of that with some ML craftiness. The augmented solution is also much more vulnerable to adversarial attacks in a way that an explicitly invariant model is not.
The loss functions I used were DICE & Tversky. DICE measures the overlap between our predicted segmentation map and the ground-truth segmentation map (twice the intersection divided by the sum of their sizes, closely related to Intersection over Union). Code is in the repo & below.
def DICE_loss(input, target, eps= 1e-5):
    '''
    Args:
        input:
            Predicted Tensor to gauge accuracy of. Same size as target.
        target:
            Target Tensor to use as ground truth. Same size as input.
        eps:
            Smoothing value to ensure no division by zero.
    Desc:
        DICE Loss function computes 1 - DICE coefficient. DICE coefficient
        is a representation of Intersection over Union. Formula is:
            2 * |Input && Target| / ( |Input| + |Target| )
        for |...| symbolizing cardinality of a set.
        Since input can include soft probabilities as well as hard 1/0,
        the cardinality of an input is its sum.
    Returns:
        Tensor containing 1 - DICE coefficient, optimal when minimized @ 0
    '''
    intersection = (input * target).view(input.shape[0], -1).sum(axis= -1)
    union = input.view(input.shape[0], -1).sum(axis= -1) + target.view(target.shape[0], -1).sum(axis= -1)
    return (1 - 2*intersection/(union + eps) ).sum()
Tversky is similar, but more finely tunable. It breaks the union term down into False Positive + False Negative + True Positive, and adds alpha & beta weights to the False Positive & False Negative terms so we can guide the model's learning dynamically based on the mistakes it is making. Here is the code (also in the repo).
def tversky_loss(input, target, eps= 1, alpha= .5, beta= .5):
    '''
    Args:
        input:
            Predicted Tensor to gauge accuracy of. Same size as target.
        target:
            Target Tensor to use as ground truth. Same size as input.
        eps:
            Smoothing value to ensure no division by zero.
        alpha:
            Weight to put on False Positives. Higher value penalizes more.
            Value of .5 for alpha & beta results in standard DICE loss.
        beta:
            Weight to put on False Negatives. Higher value penalizes more.
            Value of .5 for alpha & beta results in standard DICE loss.
    Desc:
        Tversky Loss is DICE Loss (IoU) with separate weights put on
        False Positives and False Negatives. The Union calculation for
        the denominator is framed as:
            Union = True Positive + False Positive + False Negative
        This allows us to put separate weights on False Positives and
        False Negatives, leading to the calculation:
            Union = True Positive + alpha * False Positive + beta * False Negative
        Values of .5 for both parameters create the standard DICE loss.
        Values lie in domain (0, inf).
    Returns:
        Tensor containing 1 - Tversky coefficient, optimal when minimized @ 0.
    '''
    # Flattens mask to single binary image since all 3 channels are the same
    # for all masks in the batch
    target = target[:,0,:,:].reshape(-1)
    input = input.reshape(-1)
    true_pos = (input * target).sum()
    false_pos = ( (1-target) * input).sum()
    false_neg = (target * (1-input) ).sum()
    tversky_coef = (true_pos + eps) / (true_pos + alpha*false_pos + beta*false_neg + eps)
    return 1 - tversky_coef
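To sanity-check the plumbing, here is a toy usage sketch with made-up tensor shapes (the real shapes come from the repo's dataloader):

import torch

# Made-up shapes for illustration only
pred = torch.rand(2, 1, 64, 64, requires_grad= True)      # soft probabilities from the model
mask = torch.randint(0, 2, (2, 1, 64, 64)).float()        # binary ground-truth mask
mask3 = mask.repeat(1, 3, 1, 1)                           # tversky_loss expects a 3-channel mask

dice = DICE_loss(pred, mask)
tversky = tversky_loss(pred, mask3, alpha= .7, beta= .3)  # penalize False Positives more heavily
tversky.backward()                                        # both losses are differentiable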
The model, as I mentioned before, is a simple U-Net architecture, which looks like this:
Image created in NN-SVG & MS Paint
The PyTorch code for the model can be found in the repo, and the README has more in-depth images of everything explained here. Thanks for reading, and I am open to hearing about job opportunities at the moment :)