diff --git a/test/assets/fakedata/draw_rotated_boxes.png b/test/assets/fakedata/draw_rotated_boxes.png new file mode 100644 index 00000000000..4e5a5eb5414 Binary files /dev/null and b/test/assets/fakedata/draw_rotated_boxes.png differ diff --git a/test/common_utils.py b/test/common_utils.py index b3a26dfd441..a74f204f429 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -444,13 +444,13 @@ def sample_position(values, max_value): r_rad = r * torch.pi / 180.0 cos, sin = torch.cos(r_rad), torch.sin(r_rad) x1, y1 = x, y - x3 = x1 + w * cos - y3 = y1 - w * sin - x2 = x3 + h * sin - y2 = y3 + h * cos + x2 = x1 + w * cos + y2 = y1 - w * sin + x3 = x2 + h * sin + y3 = y2 + h * cos x4 = x1 + h * sin y4 = y1 + h * cos - parts = (x1, y1, x3, y3, x2, y2, x4, y4) + parts = (x1, y1, x2, y2, x3, y3, x4, y4) else: raise ValueError(f"Format {format} is not supported") diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 94d90b9e2f6..21e81ec37f8 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -560,6 +560,78 @@ def affine_bounding_boxes(bounding_boxes): ) +def reference_affine_rotated_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True): + format = bounding_boxes.format + canvas_size = new_canvas_size or bounding_boxes.canvas_size + + def affine_rotated_bounding_boxes(bounding_boxes): + dtype = bounding_boxes.dtype + device = bounding_boxes.device + + # Go to float before converting to prevent precision loss in case of CXCYWHR -> XYXYXYXY and W or H is 1 + input_xyxyxyxy = F.convert_bounding_box_format( + bounding_boxes.to(dtype=torch.float64, device="cpu", copy=True), + old_format=format, + new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, + inplace=True, + ) + x1, y1, x2, y2, x3, y3, x4, y4 = input_xyxyxyxy.squeeze(0).tolist() + + points = np.array( + [ + [x1, y1, 1.0], + [x2, y2, 1.0], + [x3, y3, 1.0], + [x4, y4, 1.0], + ] + ) + transformed_points = np.matmul(points, affine_matrix.astype(points.dtype).T) + output = torch.tensor( + [ + float(transformed_points[1, 0]), + float(transformed_points[1, 1]), + float(transformed_points[0, 0]), + float(transformed_points[0, 1]), + float(transformed_points[3, 0]), + float(transformed_points[3, 1]), + float(transformed_points[2, 0]), + float(transformed_points[2, 1]), + ] + ) + + output = F.convert_bounding_box_format( + output, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format + ) + + if clamp: + # It is important to clamp before casting, especially for CXCYWHR format, dtype=int64 + output = F.clamp_bounding_boxes( + output, + format=format, + canvas_size=canvas_size, + ) + else: + # We leave the bounding box as float32 so the caller gets the full precision to perform any additional + # operation + dtype = output.dtype + + return output.to(dtype=dtype, device=device) + + return tv_tensors.BoundingBoxes( + torch.cat( + [ + affine_rotated_bounding_boxes(b) + for b in bounding_boxes.reshape( + -1, 5 if format != tv_tensors.BoundingBoxFormat.XYXYXYXY else 8 + ).unbind() + ], + dim=0, + ).reshape(bounding_boxes.shape), + format=format, + canvas_size=canvas_size, + ) + + class TestResize: INPUT_SIZE = (17, 11) OUTPUT_SIZES = [17, [17], (17,), None, [12, 13], (12, 13)] @@ -1012,7 +1084,7 @@ class TestHorizontalFlip: def test_kernel_image(self, dtype, device): check_kernel(F.horizontal_flip_image, make_image(dtype=dtype, device=device)) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_bounding_boxes(self, format, dtype, device): @@ -1071,7 +1143,7 @@ def test_image_correctness(self, fn): torch.testing.assert_close(actual, expected) - def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes): + def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes): affine_matrix = np.array( [ [-1, 0, bounding_boxes.canvas_size[1]], @@ -1079,9 +1151,14 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes): ], ) - return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix) + helper = ( + reference_affine_rotated_bounding_boxes_helper + if tv_tensors.is_rotated_bounding_format(bounding_boxes.format) + else reference_affine_bounding_boxes_helper + ) + return helper(bounding_boxes, affine_matrix=affine_matrix) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize( "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)] ) @@ -1464,7 +1541,7 @@ class TestVerticalFlip: def test_kernel_image(self, dtype, device): check_kernel(F.vertical_flip_image, make_image(dtype=dtype, device=device)) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_bounding_boxes(self, format, dtype, device): @@ -1521,7 +1598,7 @@ def test_image_correctness(self, fn): torch.testing.assert_close(actual, expected) - def _reference_vertical_flip_bounding_boxes(self, bounding_boxes): + def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes): affine_matrix = np.array( [ [1, 0, 0], @@ -1529,9 +1606,14 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes): ], ) - return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix) + helper = ( + reference_affine_rotated_bounding_boxes_helper + if tv_tensors.is_rotated_bounding_format(bounding_boxes.format) + else reference_affine_bounding_boxes_helper + ) + return helper(bounding_boxes, affine_matrix=affine_matrix) - @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS) + @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)]) def test_bounding_boxes_correctness(self, format, fn): bounding_boxes = make_bounding_boxes(format=format) diff --git a/test/test_tv_tensors.py b/test/test_tv_tensors.py index a8e59ab7531..073fb13b107 100644 --- a/test/test_tv_tensors.py +++ b/test/test_tv_tensors.py @@ -43,6 +43,34 @@ def test_bbox_instance(data, format): assert bboxes.format == format +@pytest.mark.parametrize( + "format, is_rotated_expected", + [ + ("XYXY", False), + ("XYWH", False), + ("CXCYWH", False), + ("XYXYXYXY", True), + ("XYWHR", True), + ("CXCYWHR", True), + (tv_tensors.BoundingBoxFormat.XYXY, False), + (tv_tensors.BoundingBoxFormat.XYWH, False), + (tv_tensors.BoundingBoxFormat.CXCYWH, False), + (tv_tensors.BoundingBoxFormat.XYXYXYXY, True), + (tv_tensors.BoundingBoxFormat.XYWHR, True), + (tv_tensors.BoundingBoxFormat.CXCYWHR, True), + ], +) +@pytest.mark.parametrize("scripted", (False, True)) +def test_bbox_format(format, is_rotated_expected, scripted): + if isinstance(format, str): + format = tv_tensors.BoundingBoxFormat[(format.upper())] + + fn = tv_tensors.is_rotated_bounding_format + if scripted: + fn = torch.jit.script(fn) + assert fn(format) == is_rotated_expected + + def test_bbox_dim_error(): data_3d = [[[1, 2, 3, 4]]] with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"): diff --git a/test/test_utils.py b/test/test_utils.py index 3cad178d00a..000798a0609 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -17,7 +17,25 @@ PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split(".")) boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) - +rotated_boxes = torch.tensor( + [ + [100, 150, 150, 150, 150, 250, 100, 250], + [200, 350, 250, 350, 250, 250, 200, 250], + [300, 200, 200, 200, 200, 250, 300, 250], + # Not really a rectangle, but it doesn't matter + [ + 100, + 100, + 200, + 50, + 290, + 350, + 200, + 400, + ], + ], + dtype=torch.float, +) keypoints = torch.tensor([[[10, 10], [5, 5], [2, 2]], [[20, 20], [30, 30], [3, 3]]], dtype=torch.float) @@ -148,6 +166,17 @@ def test_draw_boxes_with_coloured_label_backgrounds(): assert_equal(result, expected) +@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1") +def test_draw_rotated_boxes(): + img = torch.full((3, 500, 500), 255, dtype=torch.uint8) + colors = ["blue", "yellow", (0, 255, 0), "black"] + + result = utils.draw_bounding_boxes(img, rotated_boxes, colors=colors) + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_rotated_boxes.png") + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) + assert_equal(result, expected) + + @pytest.mark.parametrize("fill", [True, False]) def test_draw_boxes_dtypes(fill): img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8) diff --git a/torchvision/ops/_box_convert.py b/torchvision/ops/_box_convert.py index 4484007bc83..1910e46088a 100644 --- a/torchvision/ops/_box_convert.py +++ b/torchvision/ops/_box_convert.py @@ -130,56 +130,56 @@ def _box_xywhr_to_cxcywhr(boxes: Tensor) -> Tensor: def _box_xywhr_to_xyxyxyxy(boxes: Tensor) -> Tensor: """ - Converts rotated bounding boxes from (x1, y1, w, h, r) format to (x1, y1, x3, y3, x2, y2, x4, y4) format. + Converts rotated bounding boxes from (x1, y1, w, h, r) format to (x1, y1, x2, y2, x3, y3, x4, y4) format. (x1, y1) refer to top left of bounding box (w, h) are width and height of the rotated bounding box r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan (x1, y1) refer to top left of rotated bounding box - (x3, y3) refer to top right of rotated bounding box - (x2, y2) refer to bottom right of rotated bounding box + (x2, y2) refer to top right of rotated bounding box + (x3, y3) refer to bottom right of rotated bounding box (x4, y4) refer to bottom left ofrotated bounding box Args: boxes (Tensor[N, 5]): rotated boxes in (cx, cy, w, h, r) format which will be converted. Returns: - boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x3, y3, x2, y2, x4, y4) format. + boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x2, y2, x3, y3, x4, y4) format. """ x1, y1, w, h, r = boxes.unbind(-1) r_rad = r * torch.pi / 180.0 cos, sin = torch.cos(r_rad), torch.sin(r_rad) - x3 = x1 + w * cos - y3 = y1 - w * sin - x2 = x3 + h * sin - y2 = y3 + h * cos + x2 = x1 + w * cos + y2 = y1 - w * sin + x3 = x2 + h * sin + y3 = y2 + h * cos x4 = x1 + h * sin y4 = y1 + h * cos - return torch.stack((x1, y1, x3, y3, x2, y2, x4, y4), dim=-1) + return torch.stack((x1, y1, x2, y2, x3, y3, x4, y4), dim=-1) def _box_xyxyxyxy_to_xywhr(boxes: Tensor) -> Tensor: """ - Converts rotated bounding boxes from (x1, y1, x3, y3, x2, y2, x4, y4) format to (x1, y1, w, h, r) format. + Converts rotated bounding boxes from (x1, y1, x2, y2, x3, y3, x4, y4) format to (x1, y1, w, h, r) format. (x1, y1) refer to top left of the rotated bounding box - (x3, y3) refer to bottom left of the rotated bounding box - (x2, y2) refer to bottom right of the rotated bounding box + (x2, y2) refer to bottom left of the rotated bounding box + (x3, y3) refer to bottom right of the rotated bounding box (x4, y4) refer to top right of the rotated bounding box (w, h) refers to width and height of rotated bounding box r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan Args: - boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x3, y3, x2, y2, x4, y4) format. + boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x2, y2, x3, y3, x4, y4) format. Returns: boxes (Tensor[N, 5]): rotated boxes in (x1, y1, w, h, r) format. """ - x1, y1, x3, y3, x2, y2, x4, y4 = boxes.unbind(-1) - r_rad = torch.atan2(y1 - y3, x3 - x1) + x1, y1, x2, y2, x3, y3, x4, y4 = boxes.unbind(-1) + r_rad = torch.atan2(y1 - y2, x2 - x1) r = r_rad * 180 / torch.pi - w = ((x3 - x1) ** 2 + (y1 - y3) ** 2).sqrt() + w = ((x2 - x1) ** 2 + (y1 - y2) ** 2).sqrt() h = ((x3 - x2) ** 2 + (y3 - y2) ** 2).sqrt() boxes = torch.stack((x1, y1, w, h, r), dim=-1) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index be8d59716bf..7fb8192e1cd 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -209,8 +209,8 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor: being width and height. r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan - ``'xyxyxyxy'``: boxes are represented via corners, x1, y1 being top left, x2, y2 bottom right, - x3, y3 bottom left, and x4, y4 top right. + ``'xyxyxyxy'``: boxes are represented via corners, x1, y1 being top left, x2, y2 top right, + x3, y3 bottom right, and x4, y4 bottom left. Args: boxes (Tensor[N, K]): boxes which will be converted. K is the number of coordinates (4 for unrotated bounding boxes, 5 or 8 for rotated bounding boxes) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 1381f5d39cb..34a6b3692b9 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -71,14 +71,38 @@ def horizontal_flip_bounding_boxes( ) -> torch.Tensor: shape = bounding_boxes.shape - bounding_boxes = bounding_boxes.clone().reshape(-1, 4) + if tv_tensors.is_rotated_bounding_format(format): + bounding_boxes = ( + bounding_boxes.clone().reshape(-1, 5) + if format != tv_tensors.BoundingBoxFormat.XYXYXYXY + else bounding_boxes.clone().reshape(-1, 8) + ) + else: + bounding_boxes = bounding_boxes.clone().reshape(-1, 4) if format == tv_tensors.BoundingBoxFormat.XYXY: bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_() elif format == tv_tensors.BoundingBoxFormat.XYWH: bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_() - else: # format == tv_tensors.BoundingBoxFormat.CXCYWH: + elif format == tv_tensors.BoundingBoxFormat.CXCYWH: + bounding_boxes[:, 0].sub_(canvas_size[1]).neg_() + elif format == tv_tensors.BoundingBoxFormat.XYXYXYXY: + bounding_boxes[:, 0::2].sub_(canvas_size[1]).neg_() + bounding_boxes = bounding_boxes[:, [2, 3, 0, 1, 6, 7, 4, 5]] + elif format == tv_tensors.BoundingBoxFormat.XYWHR: + + dtype = bounding_boxes.dtype + if not torch.is_floating_point(bounding_boxes): + # Casting to float to support cos and sin computations. + bounding_boxes = bounding_boxes.to(torch.float32) + angle_rad = bounding_boxes[:, 4].mul(torch.pi).div(180) + bounding_boxes[:, 0].add_(bounding_boxes[:, 2].mul(angle_rad.cos())).sub_(canvas_size[1]).neg_() + bounding_boxes[:, 1].sub_(bounding_boxes[:, 2].mul(angle_rad.sin())) + bounding_boxes[:, 4].neg_() + bounding_boxes = bounding_boxes.to(dtype) + else: # format == tv_tensors.BoundingBoxFormat.CXCYWHR: bounding_boxes[:, 0].sub_(canvas_size[1]).neg_() + bounding_boxes[:, 4].neg_() return bounding_boxes.reshape(shape) @@ -128,14 +152,37 @@ def vertical_flip_bounding_boxes( ) -> torch.Tensor: shape = bounding_boxes.shape - bounding_boxes = bounding_boxes.clone().reshape(-1, 4) + if tv_tensors.is_rotated_bounding_format(format): + bounding_boxes = ( + bounding_boxes.clone().reshape(-1, 5) + if format != tv_tensors.BoundingBoxFormat.XYXYXYXY + else bounding_boxes.clone().reshape(-1, 8) + ) + else: + bounding_boxes = bounding_boxes.clone().reshape(-1, 4) if format == tv_tensors.BoundingBoxFormat.XYXY: bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_() elif format == tv_tensors.BoundingBoxFormat.XYWH: bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_() - else: # format == tv_tensors.BoundingBoxFormat.CXCYWH: + elif format == tv_tensors.BoundingBoxFormat.CXCYWH: + bounding_boxes[:, 1].sub_(canvas_size[0]).neg_() + elif format == tv_tensors.BoundingBoxFormat.XYXYXYXY: + bounding_boxes[:, 1::2].sub_(canvas_size[0]).neg_() + bounding_boxes = bounding_boxes[:, [2, 3, 0, 1, 6, 7, 4, 5]] + elif format == tv_tensors.BoundingBoxFormat.XYWHR: + dtype = bounding_boxes.dtype + if not torch.is_floating_point(bounding_boxes): + # Casting to float to support cos and sin computations. + bounding_boxes = bounding_boxes.to(torch.float64) + angle_rad = bounding_boxes[:, 4].mul(torch.pi).div(180) + bounding_boxes[:, 1].sub_(bounding_boxes[:, 2].mul(angle_rad.sin())).sub_(canvas_size[0]).neg_() + bounding_boxes[:, 0].add_(bounding_boxes[:, 2].mul(angle_rad.cos())) + bounding_boxes[:, 4].neg_().add_(180) + bounding_boxes = bounding_boxes.to(dtype) + else: # format == tv_tensors.BoundingBoxFormat.CXCYWHR: bounding_boxes[:, 1].sub_(canvas_size[0]).neg_() + bounding_boxes[:, 4].neg_().add_(180) return bounding_boxes.reshape(shape) diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 31dae9a1a81..8dce16957d9 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -227,13 +227,13 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor: r_rad = xywhr[..., 4].mul(torch.pi).div(180.0) cos, sin = r_rad.cos(), r_rad.sin() xywhr = xywhr[..., :2].tile((1, 4)) - # x1 + w * cos = x3 + # x1 + w * cos = x2 xywhr[..., 2].add_(wh[..., 0].mul(cos)) - # y1 - w * sin = y3 + # y1 - w * sin = y2 xywhr[..., 3].sub_(wh[..., 0].mul(sin)) - # x1 + w * cos + h * sin = x2 + # x1 + w * cos + h * sin = x3 xywhr[..., 4].add_(wh[..., 0].mul(cos).add(wh[..., 1].mul(sin))) - # y1 - w * sin + h * cos = y2 + # y1 - w * sin + h * cos = y3 xywhr[..., 5].sub_(wh[..., 0].mul(sin).sub(wh[..., 1].mul(cos))) # x1 + h * sin = x4 xywhr[..., 6].add_(wh[..., 1].mul(sin)) @@ -252,12 +252,12 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor: xyxyxyxy = xyxyxyxy.float() r_rad = torch.atan2(xyxyxyxy[..., 1].sub(xyxyxyxy[..., 3]), xyxyxyxy[..., 2].sub(xyxyxyxy[..., 0])) - # x1, y1, (x3 - x1), (y3 - y1), (x2 - x3), (y2 - y3) x4, y4 + # x1, y1, (x2 - x1), (y2 - y1), (x3 - x2), (y3 - y2) x4, y4 xyxyxyxy[..., 4:6].sub_(xyxyxyxy[..., 2:4]) xyxyxyxy[..., 2:4].sub_(xyxyxyxy[..., :2]) - # sqrt((x3 - x1) ** 2 + (y1 - y3) ** 2) = w + # sqrt((x2 - x1) ** 2 + (y1 - y2) ** 2) = w xyxyxyxy[..., 2] = xyxyxyxy[..., 2].pow(2).add(xyxyxyxy[..., 3].pow(2)).sqrt() - # sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) = h + # sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2) = h xyxyxyxy[..., 3] = xyxyxyxy[..., 4].pow(2).add(xyxyxyxy[..., 5].pow(2)).sqrt() xyxyxyxy[..., 4] = r_rad.div_(torch.pi).mul_(180.0) return xyxyxyxy[..., :5].to(dtype) @@ -360,6 +360,16 @@ def _clamp_bounding_boxes( return out_boxes.to(in_dtype) +def _clamp_rotated_bounding_boxes( + bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int] +) -> torch.Tensor: + # TODO: For now we are not clamping rotated bounding boxes. + in_dtype = bounding_boxes.dtype + out_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float() + + return out_boxes.to(in_dtype) + + def clamp_bounding_boxes( inpt: torch.Tensor, format: Optional[BoundingBoxFormat] = None, @@ -373,11 +383,21 @@ def clamp_bounding_boxes( if format is None or canvas_size is None: raise ValueError("For pure tensor inputs, `format` and `canvas_size` have to be passed.") - return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size) + if tv_tensors.is_rotated_bounding_format(format): + return _clamp_rotated_bounding_boxes(inpt, format=format, canvas_size=canvas_size) + else: + return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size) elif isinstance(inpt, tv_tensors.BoundingBoxes): if format is not None or canvas_size is not None: raise ValueError("For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed.") - output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size) + if tv_tensors.is_rotated_bounding_format(inpt.format): + output = _clamp_rotated_bounding_boxes( + inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size + ) + else: + output = _clamp_bounding_boxes( + inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size + ) return tv_tensors.wrap(output, like=inpt) else: raise TypeError( diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py index 1ba47f60a36..93689254955 100644 --- a/torchvision/tv_tensors/__init__.py +++ b/torchvision/tv_tensors/__init__.py @@ -1,6 +1,6 @@ import torch -from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat +from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, is_rotated_bounding_format from ._image import Image from ._mask import Mask from ._torch_function_helpers import set_return_type diff --git a/torchvision/tv_tensors/_bounding_boxes.py b/torchvision/tv_tensors/_bounding_boxes.py index bf8ed8cfcb4..aad76d34448 100644 --- a/torchvision/tv_tensors/_bounding_boxes.py +++ b/torchvision/tv_tensors/_bounding_boxes.py @@ -26,8 +26,8 @@ class BoundingBoxFormat(Enum): cy being center of box, w, h being width and height. r is rotation angle in degrees. * ``XYXYXYXY``: rotated boxes represented via corners, x1, y1 being top - left, x2, y2 being bottom right, x3, y3 being bottom left, x4, y4 being - top right. + left, x2, y2 being top right, x3, y3 being bottom right, x4, y4 being + bottom left. """ XYXY = "XYXY" @@ -38,6 +38,14 @@ class BoundingBoxFormat(Enum): XYXYXYXY = "XYXYXYXY" +# TODO: Once torchscript supports Enums with staticmethod +# this can be put into BoundingBoxFormat as staticmethod +def is_rotated_bounding_format(format: BoundingBoxFormat) -> bool: + return ( + format == BoundingBoxFormat.XYWHR or format == BoundingBoxFormat.CXCYWHR or format == BoundingBoxFormat.XYXYXYXY + ) + + class BoundingBoxes(TVTensor): """:class:`torch.Tensor` subclass for bounding boxes with shape ``[N, K]``. diff --git a/torchvision/utils.py b/torchvision/utils.py index e648ee80596..eec7d21293f 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -123,6 +123,57 @@ def norm_range(t, value_range): return grid +class _ImageDrawTV(ImageDraw.ImageDraw): + """ + A wrapper around PIL.ImageDraw to add functionalities for drawing rotated bounding boxes. + """ + + def oriented_rectangle(self, xy, fill=None, outline=None, width=1): + self.dashed_line(((xy[0], xy[1]), (xy[2], xy[3])), width=width, fill=outline) + for i in range(2, len(xy), 2): + self.line( + ((xy[i], xy[i + 1]), (xy[(i + 2) % len(xy)], xy[(i + 3) % len(xy)])), + width=width, + fill=outline, + ) + self.rectangle(xy, fill=fill, outline=None, width=0) + + def dashed_line(self, xy, fill=None, width=0, joint=None, dash_length=5, space_length=5): + # Calculate the total length of the line + total_length = 0 + for i in range(1, len(xy)): + x1, y1 = xy[i - 1] + x2, y2 = xy[i] + total_length += ((x2 - x1) ** 2 + (y2 - y1) ** 2) ** 0.5 + # Initialize the current position and the current dash + current_position = 0 + current_dash = True + # Iterate over the coordinates of the line + for i in range(1, len(xy)): + x1, y1 = xy[i - 1] + x2, y2 = xy[i] + # Calculate the length of this segment + segment_length = ((x2 - x1) ** 2 + (y2 - y1) ** 2) ** 0.5 + # While there are still dashes to draw on this segment + while segment_length > 0: + # Calculate the length of this dash + dash_length_to_draw = min(segment_length, dash_length if current_dash else space_length) + # Calculate the end point of this dash + dx = x2 - x1 + dy = y2 - y1 + angle = math.atan2(dy, dx) + end_x = x1 + math.cos(angle) * dash_length_to_draw + end_y = y1 + math.sin(angle) * dash_length_to_draw + # If this is a dash, draw it + if current_dash: + self.line([(x1, y1), (end_x, end_y)], fill, width, joint) + # Update the current position and the current dash + current_position += dash_length_to_draw + segment_length -= dash_length_to_draw + x1, y1 = end_x, end_y + current_dash = not current_dash + + @torch.no_grad() def save_image( tensor: Union[torch.Tensor, list[torch.Tensor]], @@ -171,9 +222,11 @@ def draw_bounding_boxes( Args: image (Tensor): Tensor of shape (C, H, W) and dtype uint8 or float. - boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that - the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and - `0 <= ymin < ymax < H`. + boxes (Tensor): Tensor of size (N, 4) or (N, 8) containing bounding boxes. + For (N, 4), the format is (xmin, ymin, xmax, ymax) and the boxes are absolute coordinates with respect to the image. + In other words: `0 <= xmin < xmax < W` and `0 <= ymin < ymax < H`. + For (N, 8), the format is (x1, y1, x2, y2, x3, y3, x4, y4) and the boxes are absolute coordinates with respect to the underlying + object, so no need to verify the latter inequalities. labels (List[str]): List containing the labels of bounding boxes. colors (color or list of colors, optional): List containing the colors of the boxes or single color for all boxes. The color can be represented as @@ -205,7 +258,7 @@ def draw_bounding_boxes( raise ValueError("Pass individual images, not batches") elif image.size(0) not in {1, 3}: raise ValueError("Only grayscale and RGB images are supported") - elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any(): + elif boxes.shape[-1] == 4 and ((boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any()): raise ValueError( "Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them" ) @@ -248,16 +301,14 @@ def draw_bounding_boxes( img_boxes = boxes.to(torch.int64).tolist() if fill: - draw = ImageDraw.Draw(img_to_draw, "RGBA") + draw = _ImageDrawTV(img_to_draw, "RGBA") else: - draw = ImageDraw.Draw(img_to_draw) + draw = _ImageDrawTV(img_to_draw) for bbox, color, label, label_color in zip(img_boxes, colors, labels, label_colors): # type: ignore[arg-type] - if fill: - fill_color = color + (100,) - draw.rectangle(bbox, width=width, outline=color, fill=fill_color) - else: - draw.rectangle(bbox, width=width, outline=color) + draw_method = draw.oriented_rectangle if len(bbox) > 4 else draw.rectangle + fill_color = color + (100,) if fill else None + draw_method(bbox, width=width, outline=color, fill=fill_color) if label is not None: box_margin = 1