Skip to content

jamesbraza/cs231n-3d-segmentation

Repository files navigation

cs231n-3d-segmentation

Stanford CS231N Deep Learning for Computer Vision: Class Project

Summary

3-D U-Net

A 3-D U-Net gets passed batches of four MRI scans (B x 4 x D x H x W), and processes them using 3-D convolution and max pooling layers. The idea is the model can utilize 3-D spatial information, not just 2-D (per-slice) info.

Trained weights are available in unet_zoo/checkpoints. The binary threshold used on the 3-D U-Net model's raw predictions was 0.25, we found this maximized the IoU within the validation dataset's binarized predictions and labels.

Here are several cross-sections from the test dataset's example 2:

Angle Image
0 cross-section angle 1
1 cross-section angle 1
2 cross-section angle 2

Here is the same example rendered in 3-D. The color code is grey is healthy tissue, green is non-enhancing tumor core, blue is peritumoral edema, and purple is Gd-enhancing tumor.

Labels Predictions
Labels 3-D predictions

2-D U-Net

On the branch experiment/unet2d, we trained a 2-D U-Net. A 2-D U-Net gets passed batches of four MRI slices (B x 4 x H x W), where the batch dimension B is actually the original MRI's depth dimension D.

This slicing detail effectively turns a batch of MRI volumes into a batch of slices from MRI volumes. This means a 2-D U-Net doesn't learn upon data containing 3-D spatial information. The 2-D U-Net model was trained using the slices dataset found in data/loaders.py.

The 2-D U-Net uses 2-D convolution and max pooling layers (as opposed to 3-D ones), so it has 1/3rd the weights of a 3-D U-Net. This lightweight model runs substantially faster, but, given the same training data, generally under-performs a 3-D U-Net when looking at IoU with labels.

Trained weights are available in unet_zoo/checkpoints on the experiment/unet2d branch. The binary threshold used on the 2-D U-Net model's raw predictions was 0.06, using the IoU procedure detailed in the 3-D U-Net summary.

Here are several cross-sections from the test dataset's example 2:

Angle Image
0 cross-section angle 1
1 cross-section angle 1
2 cross-section angle 2

Here is the same example rendered in 3-D, its color code corresponds with the 3-D U-Net summary's colors.

Labels Predictions
Labels 2-D predictions

Dataset

We used the BraTS2020 Dataset (Training + Validation) dataset from Kaggle. All iterations of the BraTS challenge can be found here.

Here's how to easily download the dataset with the Kaggle API:

kaggle datasets download -p data/brats2020-training-validation-data --unzip awsaf49/brats20-dataset-training-validation

Development

Development began in spring 2023 using Python 3.11.

AWS

AWS granted access to G and VT instances, and at the time, AWS's Deep Learning AMI supported G3, P3, P3dn, P4d, P4de, G5, G4dn instances. Thus, the AMI used was Deep Learning AMI GPU PyTorch 2.0.0 (Ubuntu 20.04) 20230530 (release notes) with instance type g4dn.2xlarge and 120 GiB of gp3 (general purpose SSD) storage.

For your reference, the below commands take less than 20 minutes to run.

Step 1: check GPU is present.

nvcc --version  # cuda_11.8
nvidia-smi  # NVIDIA Tesla T4
sudo apt update && sudo apt upgrade -y
sudo apt install -y ubuntu-drivers-common alsa-utils
ubuntu-drivers devices  # Drivers: nvidia-driver-525, nvidia-driver-525-server

Step 2: install and configure Python 3.11.

python3 --version  # 3.8.10
sudo apt update && sudo apt upgrade -y
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt install -y python3.11 python3.11-dev python3.11-venv
sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1
sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 2
python3 --version  # 3.11.3

Step 3: git clone and install requirements into a venv.

git clone https://github.com/jamesbraza/cs231n-3d-segmentation.git
cd cs231n-3d-segmentation
python3 -m venv venv
source venv/bin/activate
python -m pip install --no-cache-dir --progress-bar off -r requirements.txt

Step 4: download BraTS 2020 dataset using the Kaggle API.

# Run these from local (non-VM) machine
export SEG01=<vm IP address>
scp -pr ~/.kaggle/ ubuntu@$SEG01:~/.kaggle/
ssh ubuntu@$SEG01
cd cs231n-3d-segmentation
source venv/bin/activate
kaggle datasets download -p data/brats2020-training-validation-data \
    --unzip awsaf49/brats20-dataset-training-validation

TensorBoard

Local Training

Here is how you kick off TensorBoard:

tensorboard --logdir <path> --port 6006

Afterwards, go to the URL: http://localhost:6006/.

Remote Training

If training on a remote machine, make sure you expose port 6006 in the AWS security group:

  • IP version: IPv4
  • Type: Custom TCP
  • Protocol: TCP
  • Port range: 6006
  • Source: 0.0.0.0/0
  • Description: TensorBoard

First, start TensorBoard on the remote machine:

tensorboard --logdir <path> --host 0.0.0.0 --port 6006

Then on the local machine:

export SEG01=<vm IP address>
python -m webbrowser http://$SEG01:6006/
# or
open http://$SEG01:6006/