diff --git a/examples/UnetOnRelis3D.ipynb b/examples/UnetOnRelis3D.ipynb
new file mode 100644
index 00000000..e0054582
--- /dev/null
+++ b/examples/UnetOnRelis3D.ipynb
@@ -0,0 +1,308 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "IVxOn3wyEQ76"
+   },
+   "source": [
+    "# UNet on RELLIS-3D with DetectionMetrics\n",
+    "\n",
+    "This tutorial shows how to train a simple **UNet** model on the **RELLIS-3D dataset** and then **evaluate it** using the [DetectionMetrics](https://jderobot.github.io/DetectionMetrics/v2/) library.\n",
+    "\n",
+    "While training is included here for demonstration, the main focus of DetectionMetrics is **evaluation**.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 1. Installation\n",
+    "\n",
+    "First, install the required dependencies: **PyTorch**, **torchvision**, **matplotlib**, and **DetectionMetrics**.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Use the %pip magic so packages are installed into the kernel running this notebook\n",
+    "%pip install torch torchvision matplotlib\n",
+    "%pip install detection-metrics"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2. Imports\n",
+    "\n",
+    "We import PyTorch for model training and DetectionMetrics for dataset handling and evaluation.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "4rrg2GbREPx9"
+   },
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.optim as optim\n",
+    "from torch.utils.data import DataLoader\n",
+    "import torchvision.transforms as T\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "from detection_metrics.datasets import Rellis3DImageSegmentationDataset\n",
+    "from detection_metrics.evaluators import SegmentationEvaluator\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 3. Load RELLIS-3D Dataset\n",
+    "\n",
+    "DetectionMetrics provides a ready-to-use class `Rellis3DImageSegmentationDataset`.\n",
+    "Here we create **train** and **validation** splits, and apply basic transformations.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "fcyZcTMDEVVV"
+   },
+   "outputs": [],
+   "source": [
+    "data_root = \"/path/to/rellis3d\"  # TODO: replace with your dataset path\n",
+    "\n",
+    "# Resize the PIL image first, then convert it to a tensor.\n",
+    "# NOTE: if this pipeline is also applied to the label masks, resize them with\n",
+    "# nearest-neighbor interpolation and keep them as integer class IDs.\n",
+    "transform = T.Compose([\n",
+    "    T.Resize((256, 256)),\n",
+    "    T.ToTensor(),\n",
+    "])\n",
+    "\n",
+    "train_dataset = Rellis3DImageSegmentationDataset(\n",
+    "    root=data_root,\n",
+    "    split=\"train\",\n",
+    "    transforms=transform\n",
+    ")\n",
+    "\n",
+    "val_dataset = Rellis3DImageSegmentationDataset(\n",
+    "    root=data_root,\n",
+    "    split=\"val\",\n",
+    "    transforms=transform\n",
+    ")\n",
+    "\n",
+    "train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)\n",
+    "val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)\n",
+    "\n",
+    "print(\"Train samples:\", len(train_dataset))\n",
+    "print(\"Val samples:\", len(val_dataset))\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "P0dkft82EkeN"
+   },
+   "source": [
+    "## 4. Define UNet Model\n",
+    "\n",
+    "We define a simple UNet architecture for semantic segmentation.\n",
+    "The final layer outputs `n_classes` channels (one for each class in the dataset).\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "pydGeenMErfZ"
+   },
+   "outputs": [],
+   "source": [
+    "class DoubleConv(nn.Module):\n",
+    "    # Two 3x3 convolutions, each followed by ReLU\n",
+    "    def __init__(self, in_channels, out_channels):\n",
+    "        super().__init__()\n",
+    "        self.net = nn.Sequential(\n",
+    "            nn.Conv2d(in_channels, out_channels, 3, padding=1),\n",
+    "            nn.ReLU(inplace=True),\n",
+    "            nn.Conv2d(out_channels, out_channels, 3, padding=1),\n",
+    "            nn.ReLU(inplace=True),\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        return self.net(x)\n",
+    "\n",
+    "class UNet(nn.Module):\n",
+    "    def __init__(self, n_classes):\n",
+    "        super().__init__()\n",
+    "        # Encoder: three downsampling stages\n",
+    "        self.enc1 = DoubleConv(3, 64)\n",
+    "        self.pool = nn.MaxPool2d(2)\n",
+    "        self.enc2 = DoubleConv(64, 128)\n",
+    "        self.enc3 = DoubleConv(128, 256)\n",
+    "\n",
+    "        self.bottleneck = DoubleConv(256, 512)\n",
+    "\n",
+    "        # Decoder: upsample, then fuse with the matching encoder feature map\n",
+    "        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)\n",
+    "        self.dec3 = DoubleConv(512, 256)\n",
+    "        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)\n",
+    "        self.dec2 = DoubleConv(256, 128)\n",
+    "        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)\n",
+    "        self.dec1 = DoubleConv(128, 64)\n",
+    "\n",
+    "        self.final = nn.Conv2d(64, n_classes, 1)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        e1 = self.enc1(x)\n",
+    "        e2 = self.enc2(self.pool(e1))\n",
+    "        e3 = self.enc3(self.pool(e2))\n",
+    "        b = self.bottleneck(self.pool(e3))\n",
+    "        # Skip connections: concatenate encoder features along the channel dim\n",
+    "        d3 = self.up3(b)\n",
+    "        d3 = self.dec3(torch.cat([d3, e3], dim=1))\n",
+    "        d2 = self.up2(d3)\n",
+    "        d2 = self.dec2(torch.cat([d2, e2], dim=1))\n",
+    "        d1 = self.up1(d2)\n",
+    "        d1 = self.dec1(torch.cat([d1, e1], dim=1))\n",
+    "        return self.final(d1)\n"
+   ]
+  },
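+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before training, it is worth a quick sanity check of the architecture. The cell below pushes a random batch through an untrained `UNet` and verifies that the output has shape `(batch, n_classes, H, W)`; the class count of 5 is an arbitrary placeholder used for this check only.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Shape check with dummy data (5 classes is an arbitrary placeholder)\n",
+    "dummy_model = UNet(n_classes=5)\n",
+    "dummy_input = torch.randn(2, 3, 256, 256)  # (batch, channels, height, width)\n",
+    "with torch.no_grad():\n",
+    "    dummy_output = dummy_model(dummy_input)\n",
+    "print(dummy_output.shape)  # expected: torch.Size([2, 5, 256, 256])\n"
+   ]
+  },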
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "Y9bk1mBjExWa"
+   },
+   "source": [
+    "## 5. Training the Model\n",
+    "\n",
+    "We train UNet for a few epochs using **CrossEntropyLoss** and the **Adam** optimizer.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "SMlf-38xE0dT"
+   },
+   "outputs": [],
+   "source": [
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+    "num_classes = len(train_dataset.classes)\n",
+    "model = UNet(num_classes).to(device)\n",
+    "\n",
+    "criterion = nn.CrossEntropyLoss()\n",
+    "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
+    "\n",
+    "EPOCHS = 5\n",
+    "for epoch in range(EPOCHS):\n",
+    "    model.train()\n",
+    "    total_loss = 0.0\n",
+    "    for imgs, masks in train_loader:\n",
+    "        imgs, masks = imgs.to(device), masks.to(device)\n",
+    "        optimizer.zero_grad()\n",
+    "        outputs = model(imgs)\n",
+    "        # CrossEntropyLoss expects integer class indices of shape (N, H, W)\n",
+    "        loss = criterion(outputs, masks.long())\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
+    "        total_loss += loss.item()\n",
+    "    print(f\"Epoch [{epoch+1}/{EPOCHS}] Loss: {total_loss/len(train_loader):.4f}\")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "a-MUDeDfE1qv"
+   },
+   "source": [
+    "## 6. Evaluation with DetectionMetrics\n",
+    "\n",
+    "Now we use `SegmentationEvaluator` from DetectionMetrics to compute metrics such as:\n",
+    "- **Mean Intersection over Union (mIoU)**\n",
+    "- **Pixel Accuracy**\n",
+    "- **Per-class metrics**\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "KeogQzsqE5US"
+   },
+   "outputs": [],
+   "source": [
+    "evaluator = SegmentationEvaluator(num_classes=num_classes, class_names=train_dataset.classes)\n",
+    "\n",
+    "model.eval()\n",
+    "with torch.no_grad():\n",
+    "    for imgs, masks in val_loader:\n",
+    "        imgs, masks = imgs.to(device), masks.to(device)\n",
+    "        outputs = model(imgs)\n",
+    "        preds = torch.argmax(outputs, dim=1)\n",
+    "        evaluator.add_batch(preds.cpu().numpy(), masks.cpu().numpy())\n",
+    "\n",
+    "results = evaluator.evaluate()\n",
+    "print(\"Evaluation Results:\")\n",
+    "for metric, value in results.items():\n",
+    "    # Per-class entries may be sequences rather than scalars\n",
+    "    print(f\"{metric}: {value:.4f}\" if isinstance(value, float) else f\"{metric}: {value}\")\n"
+   ]
+  },
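+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To make the numbers above easier to interpret, here is a minimal NumPy sketch of how mIoU can be derived from a pixel-level confusion matrix. It is an illustration only, independent of how `SegmentationEvaluator` computes its metrics internally.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "# Mean IoU from a (num_classes x num_classes) confusion matrix,\n",
+    "# where rows are ground-truth classes and columns are predictions.\n",
+    "def miou_from_confusion(conf):\n",
+    "    intersection = np.diag(conf)  # true positives per class\n",
+    "    union = conf.sum(axis=0) + conf.sum(axis=1) - intersection\n",
+    "    iou = intersection / np.maximum(union, 1)  # guard against empty classes\n",
+    "    return iou.mean()\n",
+    "\n",
+    "# Toy 3-class confusion matrix\n",
+    "conf = np.array([\n",
+    "    [50,  2,  3],\n",
+    "    [ 4, 40,  1],\n",
+    "    [ 2,  3, 45],\n",
+    "])\n",
+    "print(f\"mIoU: {miou_from_confusion(conf):.4f}\")\n"
+   ]
+  },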
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 7. Visualizing Predictions\n",
+    "\n",
+    "Finally, let’s visualize some input images, their ground-truth masks, and the predicted segmentation maps.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "imgs, masks = next(iter(val_loader))\n",
+    "imgs = imgs.to(device)\n",
+    "with torch.no_grad():\n",
+    "    outputs = model(imgs)\n",
+    "preds = torch.argmax(outputs, dim=1).cpu()\n",
+    "\n",
+    "# Two rows (samples), three columns: input, ground truth, prediction\n",
+    "plt.figure(figsize=(12, 8))\n",
+    "for i in range(2):\n",
+    "    plt.subplot(2, 3, i*3 + 1)\n",
+    "    plt.imshow(imgs[i].permute(1, 2, 0).cpu())\n",
+    "    plt.title(\"Input Image\")\n",
+    "    plt.axis(\"off\")\n",
+    "    plt.subplot(2, 3, i*3 + 2)\n",
+    "    plt.imshow(masks[i].squeeze())\n",
+    "    plt.title(\"Ground Truth\")\n",
+    "    plt.axis(\"off\")\n",
+    "    plt.subplot(2, 3, i*3 + 3)\n",
+    "    plt.imshow(preds[i])\n",
+    "    plt.title(\"Predicted Mask\")\n",
+    "    plt.axis(\"off\")\n",
+    "plt.show()\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# ✅ Summary\n",
+    "\n",
+    "- We trained a UNet model on **RELLIS-3D**.\n",
+    "- More importantly, we used **DetectionMetrics** to evaluate it.\n",
+    "- Evaluation is the core purpose of DetectionMetrics; the training above simply produced a model worth measuring.\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}