From a92d34e86341cab9fc8938ca34a7edd58a78ef61 Mon Sep 17 00:00:00 2001
From: vaibhavi089
Date: Sat, 13 Sep 2025 18:41:06 +0530
Subject: [PATCH 1/2] Fix Issue #245: Add UNet tutorial for DetectionMetrics

---
 tutorial/UnetOnRelis3D.ipynb | 429 +++++++++++++++++++++++++++++++++++
 1 file changed, 429 insertions(+)
 create mode 100644 tutorial/UnetOnRelis3D.ipynb

diff --git a/tutorial/UnetOnRelis3D.ipynb b/tutorial/UnetOnRelis3D.ipynb
new file mode 100644
index 00000000..57aaf5f2
--- /dev/null
+++ b/tutorial/UnetOnRelis3D.ipynb
@@ -0,0 +1,429 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The Rellis-3D dataset download is split into:\n",
+ "1) Full Image (11 GB) -> the RGB inputs.\n",
+ "2) Full Image Annotations (94 MB, ID format) -> segmentation labels where each pixel corresponds to a class ID."
+ ],
+ "metadata": {
+ "id": "IVxOn3wyEQ76"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "4rrg2GbREPx9"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "from torch.utils.data import DataLoader, Dataset\n",
+ "from torchvision import transforms\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from pathlib import Path\n",
+ "from PIL import Image"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We create a custom PyTorch dataset class called `Rellis3DDataset` that loads images and segmentation masks from the Rellis-3D dataset.\n",
+ "\n",
+ "In the `__init__` method, we store the root directory of the dataset, collect sorted file paths for images and labels, set the target resize dimensions, and optionally accept a transform function for preprocessing.\n",
+ "\n",
+ "The `__len__` method returns the number of available images, while `__getitem__` loads a single sample. For each index, it opens the corresponding RGB image and segmentation mask, resizes the image with bilinear interpolation (to preserve smoothness) and the mask with nearest-neighbor interpolation (to keep class IDs intact), applies any defined transforms to the image (such as converting to a tensor and normalizing with the ImageNet mean and standard deviation), and converts the label to a tensor of class IDs. Finally, it returns the image and its label as a pair.\n",
+ "\n",
+ "The transform pipeline converts the image to a tensor and normalizes it, which is standard practice when working with pretrained encoders.\n",
+ "\n",
+ "We then create a dataset instance pointing to the validation folder of Rellis-3D and wrap it in a DataLoader, which batches the data (batch size 4 in this example) and handles iteration. With this setup we have a clean way to load, preprocess, and batch Rellis-3D samples for training or evaluation.\n",
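+ "\n",
+ "Note on paths: the dataset class below assumes you have arranged the downloaded split so that the RGB frames sit in an `images/` folder and the ID-format masks in a matching `labels/` folder, for example:\n",
+ "\n",
+ "```text\n",
+ "Rellis-3D/val/\n",
+ "├── images/   # RGB frames (*.png)\n",
+ "└── labels/   # ID-format annotation masks (*.png)\n",
+ "```\n",
+ "\n",
+ "This flat layout is an assumption made for this tutorial rather than the official RELLIS-3D directory structure (the official download is organised by sequence), so adjust the paths, and the `*.png` glob pattern if your frames use a different extension, to match your copy of the data."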
+ ], + "metadata": { + "id": "-sMH4bWtEXiz" + } + }, + { + "cell_type": "code", + "source": [ + "class Rellis3DDataset(Dataset):\n", + " def __init__(self, root_dir, image_size=(512, 512), transform=None):\n", + " self.root_dir = Path(root_dir)\n", + " self.image_paths = sorted((self.root_dir / \"images\").glob(\"*.png\"))\n", + " self.label_paths = sorted((self.root_dir / \"labels\").glob(\"*.png\"))\n", + " self.transform = transform\n", + " self.image_size = image_size\n", + "\n", + " def __len__(self):\n", + " return len(self.image_paths)\n", + "\n", + " def __getitem__(self, idx):\n", + " img = Image.open(self.image_paths[idx]).convert(\"RGB\")\n", + " label = Image.open(self.label_paths[idx])\n", + "\n", + "\n", + " img = img.resize(self.image_size, Image.BILINEAR)\n", + " label = label.resize(self.image_size, Image.NEAREST)\n", + "\n", + " if self.transform:\n", + " img = self.transform(img)\n", + "\n", + " label = torch.from_numpy(np.array(label)).long()\n", + " return img, label\n", + "\n", + "\n", + "\n", + "transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", + " std=[0.229, 0.224, 0.225]),\n", + "])\n", + "\n", + "\n", + "dataset = Rellis3DDataset(\"path/to/Rellis-3D/val\", transform=transform)\n", + "dataloader = DataLoader(dataset, batch_size=4, shuffle=False)\n" + ], + "metadata": { + "id": "fcyZcTMDEVVV" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# UNet Model Explanation (PyTorch)\n", + "\n", + "This document explains the structure and working of a **UNet model** implemented in PyTorch for image segmentation.\n", + "\n", + "---\n", + "\n", + "## 1. DoubleConv Block\n", + "\n", + "- **Purpose:** Performs feature extraction through **two consecutive convolution layers**.\n", + "- **Operations:**\n", + " - **Conv2d:** 2D convolution extracts features from input images.\n", + " - **BatchNorm2d:** Normalizes outputs of the convolution layer for faster and stable training.\n", + " - **ReLU:** Introduces non-linearity.\n", + "- **Padding:** `padding=1` ensures the output spatial dimensions are the same as the input.\n", + "- **Benefit:** Two convolutions per block allow the network to learn richer feature representations at each level.\n", + "\n", + "---\n", + "\n", + "## 2. 
UNet Architecture\n", + "\n", + "UNet is an **encoder-decoder network** with skip connections, specifically designed for image segmentation.\n", + "\n", + "### Encoder (Downsampling Path)\n", + "\n", + "- Comprised of **DoubleConv blocks** followed by **Max Pooling**.\n", + "- Reduces spatial dimensions while increasing the number of feature channels.\n", + "- Example progression: `3 → 64 → 128 → 256 → 512` channels.\n", + "- **MaxPool2d(2):** Reduces width and height by half at each step.\n", + "\n", + "### Bottleneck\n", + "\n", + "- The **deepest layer** in the network.\n", + "- Uses a **DoubleConv block** to capture high-level features with maximum channels (`512 → 1024`).\n", + "- Acts as a bridge between encoder and decoder.\n", + "\n", + "### Decoder (Upsampling Path)\n", + "\n", + "- Upsamples the feature maps using **ConvTranspose2d**, increasing spatial resolution.\n", + "- **Skip Connections:** Each upsampled feature map is concatenated with the corresponding encoder output.\n", + " - Helps the network preserve fine-grained spatial information.\n", + "- Reduces channel dimension progressively: `1024 → 512 → 256 → 128 → 64`.\n", + "\n", + "### Output Layer\n", + "\n", + "- **Final Conv2d:** Reduces channels to the number of classes (`n_classes`), producing the segmentation map.\n", + "\n", + "---\n", + "\n", + "## 3. Forward Pass Workflow\n", + "\n", + "1. **Encoding:** Input passes through encoder blocks, producing feature maps (`e1, e2, e3, e4`).\n", + "2. **Pooling:** Max pooling downsamples feature maps between encoder layers.\n", + "3. **Bottleneck:** Deepest layer extracts high-level features.\n", + "4. **Decoding:**\n", + " - Upsample bottleneck features.\n", + " - Concatenate with encoder features (skip connections).\n", + " - Refine features using DoubleConv blocks.\n", + "5. **Final Output:** Last convolution layer outputs the segmentation map with `n_classes` channels.\n", + "\n", + "---\n", + "\n", + "## 4. Model Initialization and Pretrained Weights\n", + "\n", + "- `n_classes = 19`: Number of segmentation classes.\n", + "- **Loading Weights:** Pretrained weights are loaded using `load_state_dict`.\n", + "- **Evaluation Mode:** `eval()` disables training-specific layers like dropout and fixes batch normalization.\n", + "- **GPU Usage:** `cuda()` moves the model to GPU for faster inference.\n", + "\n", + "---\n", + "\n", + "## 5. 
Key Points\n", + "\n", + "- **Skip Connections:** Preserve spatial details lost during downsampling.\n", + "- **DoubleConv Blocks:** Improve feature extraction at each level.\n", + "- **Encoder-Decoder Symmetry:** Ensures feature information is combined effectively during reconstruction.\n", + "- **UNet Use Case:** Commonly used in semantic segmentation tasks such as satellite imagery, medical imaging, and autonomous driving.\n" + ], + "metadata": { + "id": "P0dkft82EkeN" + } + }, + { + "cell_type": "code", + "source": [ + "class DoubleConv(nn.Module):\n", + " def __init__(self, in_channels, out_channels):\n", + " super().__init__()\n", + " self.double_conv = nn.Sequential(\n", + " nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),\n", + " nn.BatchNorm2d(out_channels),\n", + " nn.ReLU(inplace=True),\n", + " nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),\n", + " nn.BatchNorm2d(out_channels),\n", + " nn.ReLU(inplace=True),\n", + " )\n", + "\n", + " def forward(self, x):\n", + " return self.double_conv(x)\n", + "\n", + "\n", + "class UNet(nn.Module):\n", + " def __init__(self, n_classes):\n", + " super().__init__()\n", + " self.enc1 = DoubleConv(3, 64)\n", + " self.enc2 = DoubleConv(64, 128)\n", + " self.enc3 = DoubleConv(128, 256)\n", + " self.enc4 = DoubleConv(256, 512)\n", + "\n", + " self.pool = nn.MaxPool2d(2)\n", + "\n", + " self.bottleneck = DoubleConv(512, 1024)\n", + "\n", + " self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)\n", + " self.dec4 = DoubleConv(1024, 512)\n", + " self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)\n", + " self.dec3 = DoubleConv(512, 256)\n", + " self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)\n", + " self.dec2 = DoubleConv(256, 128)\n", + " self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)\n", + " self.dec1 = DoubleConv(128, 64)\n", + "\n", + " self.final = nn.Conv2d(64, n_classes, kernel_size=1)\n", + "\n", + " def forward(self, x):\n", + " e1 = self.enc1(x)\n", + " e2 = self.enc2(self.pool(e1))\n", + " e3 = self.enc3(self.pool(e2))\n", + " e4 = self.enc4(self.pool(e3))\n", + "\n", + " b = self.bottleneck(self.pool(e4))\n", + "\n", + " d4 = self.up4(b)\n", + " d4 = self.dec4(torch.cat([d4, e4], dim=1))\n", + " d3 = self.up3(d4)\n", + " d3 = self.dec3(torch.cat([d3, e3], dim=1))\n", + " d2 = self.up2(d3)\n", + " d2 = self.dec2(torch.cat([d2, e2], dim=1))\n", + " d1 = self.up1(d2)\n", + " d1 = self.dec1(torch.cat([d1, e1], dim=1))\n", + "\n", + " return self.final(d1)\n", + "\n", + "\n", + "n_classes = 19\n", + "model = UNet(n_classes=n_classes)\n", + "model.load_state_dict(torch.load(\"unet_rellis3d.pth\", map_location=\"cuda\"))\n", + "model.eval().cuda()\n" + ], + "metadata": { + "id": "pydGeenMErfZ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Evaluation of UNet Model for Segmentation\n", + "\n", + "This document explains the evaluation functions used to measure the performance of a **UNet segmentation model** using metrics like **mIoU** and **Pixel Accuracy**.\n", + "\n", + "---\n", + "\n", + "## 1. 
Confusion Matrix\n", + "\n", + "**Function:** `compute_confusion_matrix(pred, label, num_classes)`\n", + "\n", + "- **Purpose:** Builds a confusion matrix for a single prediction-label pair.\n", + "- **Process:**\n", + " - **Masking:** Ensures only valid labels (0 ≤ label < num_classes) are considered.\n", + " - **Indexing:** Each pair of `(label, pred)` is mapped to a unique index: `num_classes * label + pred`.\n", + " - **Bincount:** Counts occurrences of each `(label, pred)` pair.\n", + " - **Reshape:** Converts the flat count vector into a `(num_classes x num_classes)` confusion matrix.\n", + "\n", + "**Key Concept:**\n", + "\n", + "- **Rows:** Ground truth classes.\n", + "- **Columns:** Predicted classes.\n", + "- **Diagonal:** Correct predictions.\n", + "\n", + "---\n", + "\n", + "## 2. Mean Intersection over Union (mIoU)\n", + "\n", + "**Function:** `compute_mIoU(confusion_matrix)`\n", + "\n", + "- **Purpose:** Measures the overlap between predicted and true segmentation regions.\n", + "- **Computation:**\n", + " - **Intersection:** Diagonal of the confusion matrix → correctly predicted pixels for each class.\n", + " - **Union:** Sum of row + sum of column - intersection → total pixels covered by either prediction or ground truth for each class.\n", + " - **IoU per class:** `intersection / union`.\n", + " - **Mean IoU:** Average across all classes.\n", + "\n", + "**Significance:** mIoU is a standard metric for semantic segmentation, reflecting per-class accuracy while accounting for class imbalance.\n", + "\n", + "---\n", + "\n", + "## 3. Evaluation Function\n", + "\n", + "**Function:** `evaluate(model, dataloader, num_classes=19)`\n", + "\n", + "- **Purpose:** Computes **mIoU** and **Pixel Accuracy** over the entire dataset.\n", + "- **Steps:**\n", + " 1. Initialize a zeroed confusion matrix.\n", + " 2. Disable gradient computation with `torch.no_grad()` (faster evaluation).\n", + " 3. Iterate over batches from `dataloader`:\n", + " - Move images and labels to GPU.\n", + " - Run the model to get predictions.\n", + " - Convert raw outputs to class predictions using `argmax`.\n", + " - Update the confusion matrix for each image in the batch.\n", + " 4. Compute **mIoU** using `compute_mIoU`.\n", + " 5. Compute **Pixel Accuracy**: total correctly predicted pixels divided by total pixels.\n", + "\n", + "**Outputs:**\n", + "\n", + "- **mIoU:** Mean Intersection over Union (range: 0–1)\n", + "- **Pixel Accuracy:** Overall fraction of correctly classified pixels (range: 0–1)\n", + "\n", + "---\n", + "\n", + "## 4. 
Example Usage\n",
+ "\n",
+ "- Call the evaluation function with the model and dataloader:\n",
+ "\n",
+ "```python\n",
+ "mIoU, pixel_acc = evaluate(model, dataloader, num_classes=19)\n",
+ "print(f\"mIoU: {mIoU:.4f}, Pixel Accuracy: {pixel_acc:.4f}\")\n",
+ "```"
+ ],
+ "metadata": {
+ "id": "Y9bk1mBjExWa"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def compute_confusion_matrix(pred, label, num_classes):\n",
+ "    mask = (label >= 0) & (label < num_classes)\n",
+ "    hist = torch.bincount(\n",
+ "        num_classes * label[mask] + pred[mask],\n",
+ "        minlength=num_classes ** 2\n",
+ "    ).reshape(num_classes, num_classes)\n",
+ "    return hist\n",
+ "\n",
+ "def compute_mIoU(confusion_matrix):\n",
+ "    intersection = torch.diag(confusion_matrix)\n",
+ "    union = confusion_matrix.sum(1) + confusion_matrix.sum(0) - intersection\n",
+ "    # Skip classes that never appear in either labels or predictions (union == 0) to avoid NaNs\n",
+ "    valid = union > 0\n",
+ "    IoU = intersection[valid].float() / union[valid].float()\n",
+ "    return IoU.mean().item()\n",
+ "\n",
+ "def evaluate(model, dataloader, num_classes=19):\n",
+ "    confusion_matrix = torch.zeros((num_classes, num_classes), dtype=torch.int64)\n",
+ "\n",
+ "    with torch.no_grad():\n",
+ "        for imgs, labels in dataloader:\n",
+ "            imgs, labels = imgs.cuda(), labels.cuda()\n",
+ "            outputs = model(imgs)\n",
+ "            preds = torch.argmax(outputs, dim=1)\n",
+ "\n",
+ "            for p, l in zip(preds, labels):\n",
+ "                # the per-image histogram lives on the GPU; move it to the CPU accumulator\n",
+ "                confusion_matrix += compute_confusion_matrix(p.view(-1), l.view(-1), num_classes).cpu()\n",
+ "\n",
+ "    mIoU = compute_mIoU(confusion_matrix)\n",
+ "    pixel_acc = torch.diag(confusion_matrix).sum().item() / confusion_matrix.sum().item()\n",
+ "    return mIoU, pixel_acc\n",
+ "\n",
+ "\n",
+ "mIoU, pixel_acc = evaluate(model, dataloader, num_classes=19)\n",
+ "print(f\"mIoU: {mIoU:.4f}, Pixel Accuracy: {pixel_acc:.4f}\")\n"
+ ],
+ "metadata": {
+ "id": "SMlf-38xE0dT"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Visualization"
+ ],
+ "metadata": {
+ "id": "a-MUDeDfE1qv"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def visualize(model, dataset, idx=0):\n",
+ "    model.eval()\n",
+ "    img, label = dataset[idx]\n",
+ "    img_input = img.unsqueeze(0).cuda()\n",
+ "\n",
+ "    with torch.no_grad():\n",
+ "        output = model(img_input)\n",
+ "        pred = torch.argmax(output, dim=1).squeeze().cpu().numpy()\n",
+ "\n",
+ "    # undo the ImageNet normalization channel-by-channel before displaying\n",
+ "    mean = np.array([0.485, 0.456, 0.406])\n",
+ "    std = np.array([0.229, 0.224, 0.225])\n",
+ "    rgb = np.clip(img.permute(1, 2, 0).cpu().numpy() * std + mean, 0, 1)\n",
+ "\n",
+ "    fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
+ "    axes[0].imshow(rgb)\n",
+ "    axes[0].set_title(\"Input Image\")\n",
+ "    axes[1].imshow(label.numpy())\n",
+ "    axes[1].set_title(\"Ground Truth\")\n",
+ "    axes[2].imshow(pred)\n",
+ "    axes[2].set_title(\"Prediction\")\n",
+ "    plt.show()\n",
+ "\n",
+ "\n",
+ "visualize(model, dataset, idx=5)\n"
+ ],
+ "metadata": {
+ "id": "KeogQzsqE5US"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file

From 5545331ef955b5d14e8adf7b4e0d22c9858a0c76 Mon Sep 17 00:00:00 2001
From: vaibhavi089
Date: Wed, 17 Sep 2025 10:31:01 +0530
Subject: [PATCH 2/2] Move UNet on RELLIS-3D tutorial to examples/ after making changes.
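
The relocated notebook replaces the hand-rolled Rellis3DDataset loader and
confusion-matrix metrics with the DetectionMetrics dataset and evaluator
classes used in the examples/ version, e.g.:

    evaluator = SegmentationEvaluator(num_classes=num_classes, class_names=train_dataset.classes)
    results = evaluator.evaluate()

Training is kept only as a minimal demonstration; evaluation with
DetectionMetrics remains the focus of the tutorial.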
--- examples/UnetOnRelis3D.ipynb | 308 +++++++++++++++++++++++++ tutorial/UnetOnRelis3D.ipynb | 429 ----------------------------------- 2 files changed, 308 insertions(+), 429 deletions(-) create mode 100644 examples/UnetOnRelis3D.ipynb delete mode 100644 tutorial/UnetOnRelis3D.ipynb diff --git a/examples/UnetOnRelis3D.ipynb b/examples/UnetOnRelis3D.ipynb new file mode 100644 index 00000000..e0054582 --- /dev/null +++ b/examples/UnetOnRelis3D.ipynb @@ -0,0 +1,308 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "IVxOn3wyEQ76" + }, + "source": [ + "# UNet on RELLIS-3D with DetectionMetrics\n", + "\n", + "This tutorial shows how to train a simple **UNet** model on the **RELLIS-3D dataset** and then **evaluate it** using the [DetectionMetrics](https://jderobot.github.io/DetectionMetrics/v2/) library. \n", + "\n", + "While training is included here for demonstration, the main focus of DetectionMetrics is **evaluation**. \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Installation\n", + "\n", + "First, install the required dependencies: **PyTorch**, **torchvision**, and **DetectionMetrics**.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pip install torch torchvision\n", + "pip install detection-metrics" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Imports\n", + "\n", + "We import PyTorch for model training and DetectionMetrics for dataset handling and evaluation.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4rrg2GbREPx9" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader\n", + "import torchvision.transforms as T\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from detection_metrics.datasets import Rellis3DImageSegmentationDataset\n", + "from detection_metrics.evaluators import SegmentationEvaluator\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Load RELLIS-3D Dataset\n", + "\n", + "DetectionMetrics provides a ready-to-use class `Rellis3DImageSegmentationDataset`. \n", + "Here we create **train** and **validation** splits, and apply basic transformations.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fcyZcTMDEVVV" + }, + "outputs": [], + "source": [ + "data_root = \"/path/to/rellis3d\" # TODO: replace with your dataset path\n", + "\n", + "transform = T.Compose([\n", + " T.ToTensor(),\n", + " T.Resize((256, 256)),\n", + "])\n", + "\n", + "train_dataset = Rellis3DImageSegmentationDataset(\n", + " root=data_root,\n", + " split=\"train\",\n", + " transforms=transform\n", + ")\n", + "\n", + "val_dataset = Rellis3DImageSegmentationDataset(\n", + " root=data_root,\n", + " split=\"val\",\n", + " transforms=transform\n", + ")\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)\n", + "val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)\n", + "\n", + "print(\"Train samples:\", len(train_dataset))\n", + "print(\"Val samples:\", len(val_dataset))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "P0dkft82EkeN" + }, + "source": [ + "## 4. Define UNet Model\n", + "\n", + "We define a simple UNet architecture for semantic segmentation. 
\n", + "The final layer outputs `n_classes` channels (one for each class in the dataset).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pydGeenMErfZ" + }, + "outputs": [], + "source": [ + "class DoubleConv(nn.Module):\n", + " def __init__(self, in_channels, out_channels):\n", + " super(DoubleConv, self).__init__()\n", + " self.net = nn.Sequential(\n", + " nn.Conv2d(in_channels, out_channels, 3, padding=1),\n", + " nn.ReLU(inplace=True),\n", + " nn.Conv2d(out_channels, out_channels, 3, padding=1),\n", + " nn.ReLU(inplace=True),\n", + " )\n", + " def forward(self, x):\n", + " return self.net(x)\n", + "\n", + "class UNet(nn.Module):\n", + " def __init__(self, n_classes):\n", + " super(UNet, self).__init__()\n", + " self.enc1 = DoubleConv(3, 64)\n", + " self.pool = nn.MaxPool2d(2)\n", + " self.enc2 = DoubleConv(64, 128)\n", + " self.enc3 = DoubleConv(128, 256)\n", + "\n", + " self.bottleneck = DoubleConv(256, 512)\n", + "\n", + " self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)\n", + " self.dec3 = DoubleConv(512, 256)\n", + " self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)\n", + " self.dec2 = DoubleConv(256, 128)\n", + " self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)\n", + " self.dec1 = DoubleConv(128, 64)\n", + "\n", + " self.final = nn.Conv2d(64, n_classes, 1)\n", + "\n", + " def forward(self, x):\n", + " e1 = self.enc1(x)\n", + " e2 = self.enc2(self.pool(e1))\n", + " e3 = self.enc3(self.pool(e2))\n", + " b = self.bottleneck(self.pool(e3))\n", + " d3 = self.up3(b)\n", + " d3 = self.dec3(torch.cat([d3, e3], dim=1))\n", + " d2 = self.up2(d3)\n", + " d2 = self.dec2(torch.cat([d2, e2], dim=1))\n", + " d1 = self.up1(d2)\n", + " d1 = self.dec1(torch.cat([d1, e1], dim=1))\n", + " return self.final(d1)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y9bk1mBjExWa" + }, + "source": [ + "## 5. Training the Model\n", + "\n", + "We train UNet for a few epochs using **CrossEntropyLoss** and **Adam optimizer**.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SMlf-38xE0dT" + }, + "outputs": [], + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "num_classes = len(train_dataset.classes)\n", + "model = UNet(num_classes).to(device)\n", + "\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n", + "\n", + "EPOCHS = 5\n", + "for epoch in range(EPOCHS):\n", + " model.train()\n", + " total_loss = 0\n", + " for imgs, masks in train_loader:\n", + " imgs, masks = imgs.to(device), masks.to(device)\n", + " optimizer.zero_grad()\n", + " outputs = model(imgs)\n", + " loss = criterion(outputs, masks)\n", + " loss.backward()\n", + " optimizer.step()\n", + " total_loss += loss.item()\n", + " print(f\"Epoch [{epoch+1}/{EPOCHS}] Loss: {total_loss/len(train_loader):.4f}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a-MUDeDfE1qv" + }, + "source": [ + "## 6. 
Evaluation with DetectionMetrics\n", + "\n", + "Now we use `SegmentationEvaluator` from DetectionMetrics to compute metrics such as: \n", + "- **Mean Intersection over Union (mIoU)** \n", + "- **Pixel Accuracy** \n", + "- **Per-class metrics**\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KeogQzsqE5US" + }, + "outputs": [], + "source": [ + "evaluator = SegmentationEvaluator(num_classes=num_classes, class_names=train_dataset.classes)\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " for imgs, masks in val_loader:\n", + " imgs, masks = imgs.to(device), masks.to(device)\n", + " outputs = model(imgs)\n", + " preds = torch.argmax(outputs, dim=1)\n", + " evaluator.add_batch(preds.cpu().numpy(), masks.cpu().numpy())\n", + "\n", + "results = evaluator.evaluate()\n", + "print(\"Evaluation Results:\")\n", + "for metric, value in results.items():\n", + " print(f\"{metric}: {value:.4f}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Visualizing Predictions\n", + "\n", + "Finally, let’s visualize some input images, their ground-truth masks, and the predicted segmentation maps.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "imgs, masks = next(iter(val_loader))\n", + "imgs = imgs.to(device)\n", + "outputs = model(imgs)\n", + "preds = torch.argmax(outputs, dim=1).cpu()\n", + "\n", + "plt.figure(figsize=(12,6))\n", + "for i in range(2):\n", + " plt.subplot(3, 2, i*2+1)\n", + " plt.imshow(imgs[i].permute(1,2,0).cpu())\n", + " plt.title(\"Input Image\")\n", + " plt.subplot(3, 2, i*2+2)\n", + " plt.imshow(preds[i])\n", + " plt.title(\"Predicted Mask\")\n", + "plt.show()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ✅ Summary\n", + "\n", + "- We trained a UNet model on **RELLIS-3D**. \n", + "- More importantly, we used **DetectionMetrics** to evaluate it. \n", + "- The evaluation step is the main focus of DetectionMetrics and should always be included. \n" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tutorial/UnetOnRelis3D.ipynb b/tutorial/UnetOnRelis3D.ipynb deleted file mode 100644 index 57aaf5f2..00000000 --- a/tutorial/UnetOnRelis3D.ipynb +++ /dev/null @@ -1,429 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "The Rellis-3D dataset is split into:\n", - "1) Full Image (11GB) -> the RGB inputs.\n", - "2) Full Image Annotations(94MB, ID Format) -> segmentation labels where each pixel corresponds to a class Id." 
- ], - "metadata": { - "id": "IVxOn3wyEQ76" - } - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "4rrg2GbREPx9" - }, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "from torch.utils.data import DataLoader, Dataset\n", - "from torchvision import transforms\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "from pathlib import Path\n", - "from PIL import Image" - ] - }, - { - "cell_type": "markdown", - "source": [ - "We create a custom PyTorch dataset class called `Rellis3DDataset` that loads images and segmentation masks from the Rellis-3D dataset.\n", - "\n", - " In the `__init__` method, we define the root directory of the dataset, collect sorted file paths for images and labels, set the target resize dimensions, and optionally accept a transform function for preprocessing.\n", - "\n", - " The `__len__` method simply returns the number of available images, while the `__getitem__` method handles loading a single sample. For each index, it loads the corresponding RGB image and segmentation mask, resizes the image using bilinear interpolation (to preserve smoothness) and the mask using nearest neighbor interpolation (to keep class IDs intact), applies any defined transforms to the image (such as converting to a tensor and normalizing with ImageNet mean and standard deviation), and converts the label to a tensor of class IDs.\n", - " \n", - " Finally, it returns the image and its label as a pair. The transform pipeline consists of converting the image to a tensor and normalizing it, which is standard practice when using pretrained models like UNet encoders.\n", - " \n", - " We then create a dataset instance pointing to the validation folder of Rellis-3D and wrap it with a DataLoader, which batches the data (batch size of 4 in this example) and handles iteration. With this setup, we now have a clean way to load, preprocess, and batch Rellis-3D samples for training or evaluation." 
- ], - "metadata": { - "id": "-sMH4bWtEXiz" - } - }, - { - "cell_type": "code", - "source": [ - "class Rellis3DDataset(Dataset):\n", - " def __init__(self, root_dir, image_size=(512, 512), transform=None):\n", - " self.root_dir = Path(root_dir)\n", - " self.image_paths = sorted((self.root_dir / \"images\").glob(\"*.png\"))\n", - " self.label_paths = sorted((self.root_dir / \"labels\").glob(\"*.png\"))\n", - " self.transform = transform\n", - " self.image_size = image_size\n", - "\n", - " def __len__(self):\n", - " return len(self.image_paths)\n", - "\n", - " def __getitem__(self, idx):\n", - " img = Image.open(self.image_paths[idx]).convert(\"RGB\")\n", - " label = Image.open(self.label_paths[idx])\n", - "\n", - "\n", - " img = img.resize(self.image_size, Image.BILINEAR)\n", - " label = label.resize(self.image_size, Image.NEAREST)\n", - "\n", - " if self.transform:\n", - " img = self.transform(img)\n", - "\n", - " label = torch.from_numpy(np.array(label)).long()\n", - " return img, label\n", - "\n", - "\n", - "\n", - "transform = transforms.Compose([\n", - " transforms.ToTensor(),\n", - " transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", - " std=[0.229, 0.224, 0.225]),\n", - "])\n", - "\n", - "\n", - "dataset = Rellis3DDataset(\"path/to/Rellis-3D/val\", transform=transform)\n", - "dataloader = DataLoader(dataset, batch_size=4, shuffle=False)\n" - ], - "metadata": { - "id": "fcyZcTMDEVVV" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# UNet Model Explanation (PyTorch)\n", - "\n", - "This document explains the structure and working of a **UNet model** implemented in PyTorch for image segmentation.\n", - "\n", - "---\n", - "\n", - "## 1. DoubleConv Block\n", - "\n", - "- **Purpose:** Performs feature extraction through **two consecutive convolution layers**.\n", - "- **Operations:**\n", - " - **Conv2d:** 2D convolution extracts features from input images.\n", - " - **BatchNorm2d:** Normalizes outputs of the convolution layer for faster and stable training.\n", - " - **ReLU:** Introduces non-linearity.\n", - "- **Padding:** `padding=1` ensures the output spatial dimensions are the same as the input.\n", - "- **Benefit:** Two convolutions per block allow the network to learn richer feature representations at each level.\n", - "\n", - "---\n", - "\n", - "## 2. 
UNet Architecture\n", - "\n", - "UNet is an **encoder-decoder network** with skip connections, specifically designed for image segmentation.\n", - "\n", - "### Encoder (Downsampling Path)\n", - "\n", - "- Comprised of **DoubleConv blocks** followed by **Max Pooling**.\n", - "- Reduces spatial dimensions while increasing the number of feature channels.\n", - "- Example progression: `3 → 64 → 128 → 256 → 512` channels.\n", - "- **MaxPool2d(2):** Reduces width and height by half at each step.\n", - "\n", - "### Bottleneck\n", - "\n", - "- The **deepest layer** in the network.\n", - "- Uses a **DoubleConv block** to capture high-level features with maximum channels (`512 → 1024`).\n", - "- Acts as a bridge between encoder and decoder.\n", - "\n", - "### Decoder (Upsampling Path)\n", - "\n", - "- Upsamples the feature maps using **ConvTranspose2d**, increasing spatial resolution.\n", - "- **Skip Connections:** Each upsampled feature map is concatenated with the corresponding encoder output.\n", - " - Helps the network preserve fine-grained spatial information.\n", - "- Reduces channel dimension progressively: `1024 → 512 → 256 → 128 → 64`.\n", - "\n", - "### Output Layer\n", - "\n", - "- **Final Conv2d:** Reduces channels to the number of classes (`n_classes`), producing the segmentation map.\n", - "\n", - "---\n", - "\n", - "## 3. Forward Pass Workflow\n", - "\n", - "1. **Encoding:** Input passes through encoder blocks, producing feature maps (`e1, e2, e3, e4`).\n", - "2. **Pooling:** Max pooling downsamples feature maps between encoder layers.\n", - "3. **Bottleneck:** Deepest layer extracts high-level features.\n", - "4. **Decoding:**\n", - " - Upsample bottleneck features.\n", - " - Concatenate with encoder features (skip connections).\n", - " - Refine features using DoubleConv blocks.\n", - "5. **Final Output:** Last convolution layer outputs the segmentation map with `n_classes` channels.\n", - "\n", - "---\n", - "\n", - "## 4. Model Initialization and Pretrained Weights\n", - "\n", - "- `n_classes = 19`: Number of segmentation classes.\n", - "- **Loading Weights:** Pretrained weights are loaded using `load_state_dict`.\n", - "- **Evaluation Mode:** `eval()` disables training-specific layers like dropout and fixes batch normalization.\n", - "- **GPU Usage:** `cuda()` moves the model to GPU for faster inference.\n", - "\n", - "---\n", - "\n", - "## 5. 
Key Points\n", - "\n", - "- **Skip Connections:** Preserve spatial details lost during downsampling.\n", - "- **DoubleConv Blocks:** Improve feature extraction at each level.\n", - "- **Encoder-Decoder Symmetry:** Ensures feature information is combined effectively during reconstruction.\n", - "- **UNet Use Case:** Commonly used in semantic segmentation tasks such as satellite imagery, medical imaging, and autonomous driving.\n" - ], - "metadata": { - "id": "P0dkft82EkeN" - } - }, - { - "cell_type": "code", - "source": [ - "class DoubleConv(nn.Module):\n", - " def __init__(self, in_channels, out_channels):\n", - " super().__init__()\n", - " self.double_conv = nn.Sequential(\n", - " nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),\n", - " nn.BatchNorm2d(out_channels),\n", - " nn.ReLU(inplace=True),\n", - " nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),\n", - " nn.BatchNorm2d(out_channels),\n", - " nn.ReLU(inplace=True),\n", - " )\n", - "\n", - " def forward(self, x):\n", - " return self.double_conv(x)\n", - "\n", - "\n", - "class UNet(nn.Module):\n", - " def __init__(self, n_classes):\n", - " super().__init__()\n", - " self.enc1 = DoubleConv(3, 64)\n", - " self.enc2 = DoubleConv(64, 128)\n", - " self.enc3 = DoubleConv(128, 256)\n", - " self.enc4 = DoubleConv(256, 512)\n", - "\n", - " self.pool = nn.MaxPool2d(2)\n", - "\n", - " self.bottleneck = DoubleConv(512, 1024)\n", - "\n", - " self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)\n", - " self.dec4 = DoubleConv(1024, 512)\n", - " self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)\n", - " self.dec3 = DoubleConv(512, 256)\n", - " self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)\n", - " self.dec2 = DoubleConv(256, 128)\n", - " self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)\n", - " self.dec1 = DoubleConv(128, 64)\n", - "\n", - " self.final = nn.Conv2d(64, n_classes, kernel_size=1)\n", - "\n", - " def forward(self, x):\n", - " e1 = self.enc1(x)\n", - " e2 = self.enc2(self.pool(e1))\n", - " e3 = self.enc3(self.pool(e2))\n", - " e4 = self.enc4(self.pool(e3))\n", - "\n", - " b = self.bottleneck(self.pool(e4))\n", - "\n", - " d4 = self.up4(b)\n", - " d4 = self.dec4(torch.cat([d4, e4], dim=1))\n", - " d3 = self.up3(d4)\n", - " d3 = self.dec3(torch.cat([d3, e3], dim=1))\n", - " d2 = self.up2(d3)\n", - " d2 = self.dec2(torch.cat([d2, e2], dim=1))\n", - " d1 = self.up1(d2)\n", - " d1 = self.dec1(torch.cat([d1, e1], dim=1))\n", - "\n", - " return self.final(d1)\n", - "\n", - "\n", - "n_classes = 19\n", - "model = UNet(n_classes=n_classes)\n", - "model.load_state_dict(torch.load(\"unet_rellis3d.pth\", map_location=\"cuda\"))\n", - "model.eval().cuda()\n" - ], - "metadata": { - "id": "pydGeenMErfZ" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# Evaluation of UNet Model for Segmentation\n", - "\n", - "This document explains the evaluation functions used to measure the performance of a **UNet segmentation model** using metrics like **mIoU** and **Pixel Accuracy**.\n", - "\n", - "---\n", - "\n", - "## 1. 
Confusion Matrix\n", - "\n", - "**Function:** `compute_confusion_matrix(pred, label, num_classes)`\n", - "\n", - "- **Purpose:** Builds a confusion matrix for a single prediction-label pair.\n", - "- **Process:**\n", - " - **Masking:** Ensures only valid labels (0 ≤ label < num_classes) are considered.\n", - " - **Indexing:** Each pair of `(label, pred)` is mapped to a unique index: `num_classes * label + pred`.\n", - " - **Bincount:** Counts occurrences of each `(label, pred)` pair.\n", - " - **Reshape:** Converts the flat count vector into a `(num_classes x num_classes)` confusion matrix.\n", - "\n", - "**Key Concept:**\n", - "\n", - "- **Rows:** Ground truth classes.\n", - "- **Columns:** Predicted classes.\n", - "- **Diagonal:** Correct predictions.\n", - "\n", - "---\n", - "\n", - "## 2. Mean Intersection over Union (mIoU)\n", - "\n", - "**Function:** `compute_mIoU(confusion_matrix)`\n", - "\n", - "- **Purpose:** Measures the overlap between predicted and true segmentation regions.\n", - "- **Computation:**\n", - " - **Intersection:** Diagonal of the confusion matrix → correctly predicted pixels for each class.\n", - " - **Union:** Sum of row + sum of column - intersection → total pixels covered by either prediction or ground truth for each class.\n", - " - **IoU per class:** `intersection / union`.\n", - " - **Mean IoU:** Average across all classes.\n", - "\n", - "**Significance:** mIoU is a standard metric for semantic segmentation, reflecting per-class accuracy while accounting for class imbalance.\n", - "\n", - "---\n", - "\n", - "## 3. Evaluation Function\n", - "\n", - "**Function:** `evaluate(model, dataloader, num_classes=19)`\n", - "\n", - "- **Purpose:** Computes **mIoU** and **Pixel Accuracy** over the entire dataset.\n", - "- **Steps:**\n", - " 1. Initialize a zeroed confusion matrix.\n", - " 2. Disable gradient computation with `torch.no_grad()` (faster evaluation).\n", - " 3. Iterate over batches from `dataloader`:\n", - " - Move images and labels to GPU.\n", - " - Run the model to get predictions.\n", - " - Convert raw outputs to class predictions using `argmax`.\n", - " - Update the confusion matrix for each image in the batch.\n", - " 4. Compute **mIoU** using `compute_mIoU`.\n", - " 5. Compute **Pixel Accuracy**: total correctly predicted pixels divided by total pixels.\n", - "\n", - "**Outputs:**\n", - "\n", - "- **mIoU:** Mean Intersection over Union (range: 0–1)\n", - "- **Pixel Accuracy:** Overall fraction of correctly classified pixels (range: 0–1)\n", - "\n", - "---\n", - "\n", - "## 4. 
Example Usage\n", - "\n", - "- Call the evaluation function with the model and dataloader:\n", - "\n", - "```python\n", - "mIoU, pixel_acc = evaluate(model, dataloader, num_classes=19)\n", - "print(f\"mIoU: {mIoU:.4f}, Pixel Accuracy: {pixel_acc:.4f}\")\n" - ], - "metadata": { - "id": "Y9bk1mBjExWa" - } - }, - { - "cell_type": "code", - "source": [ - "def compute_confusion_matrix(pred, label, num_classes):\n", - " mask = (label >= 0) & (label < num_classes)\n", - " hist = torch.bincount(\n", - " num_classes * label[mask] + pred[mask],\n", - " minlength=num_classes ** 2\n", - " ).reshape(num_classes, num_classes)\n", - " return hist\n", - "\n", - "def compute_mIoU(confusion_matrix):\n", - " intersection = torch.diag(confusion_matrix)\n", - " union = confusion_matrix.sum(1) + confusion_matrix.sum(0) - intersection\n", - " IoU = intersection / union\n", - " return IoU.mean().item()\n", - "\n", - "def evaluate(model, dataloader, num_classes=19):\n", - " confusion_matrix = torch.zeros((num_classes, num_classes), dtype=torch.int64)\n", - "\n", - " with torch.no_grad():\n", - " for imgs, labels in dataloader:\n", - " imgs, labels = imgs.cuda(), labels.cuda()\n", - " outputs = model(imgs)\n", - " preds = torch.argmax(outputs, dim=1)\n", - "\n", - " for p, l in zip(preds, labels):\n", - " confusion_matrix += compute_confusion_matrix(p.view(-1), l.view(-1), num_classes)\n", - "\n", - " mIoU = compute_mIoU(confusion_matrix)\n", - " pixel_acc = torch.diag(confusion_matrix).sum().item() / confusion_matrix.sum().item()\n", - " return mIoU, pixel_acc\n", - "\n", - "\n", - "mIoU, pixel_acc = evaluate(model, dataloader, num_classes=19)\n", - "print(f\"mIoU: {mIoU:.4f}, Pixel Accuracy: {pixel_acc:.4f}\")\n" - ], - "metadata": { - "id": "SMlf-38xE0dT" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "### Visualization" - ], - "metadata": { - "id": "a-MUDeDfE1qv" - } - }, - { - "cell_type": "code", - "source": [ - "def visualize(model, dataset, idx=0):\n", - " model.eval()\n", - " img, label = dataset[idx]\n", - " img_input = img.unsqueeze(0).cuda()\n", - "\n", - " with torch.no_grad():\n", - " output = model(img_input)\n", - " pred = torch.argmax(output, dim=1).squeeze().cpu().numpy()\n", - "\n", - " fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n", - " axes[0].imshow(img.permute(1,2,0).cpu().numpy() * 0.229 + 0.485) # approx unnormalize\n", - " axes[0].set_title(\"Input Image\")\n", - " axes[1].imshow(label.numpy())\n", - " axes[1].set_title(\"Ground Truth\")\n", - " axes[2].imshow(pred)\n", - " axes[2].set_title(\"Prediction\")\n", - " plt.show()\n", - "\n", - "\n", - "\n", - "visualize(model, dataset, idx=5)\n" - ], - "metadata": { - "id": "KeogQzsqE5US" - }, - "execution_count": null, - "outputs": [] - } - ] -} \ No newline at end of file