Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from geojson_pydantic import Feature, FeatureCollection, Polygon
from vcr.request import Request

import numpy as np
from rio_tiler.models import ImageData

from titiler.cmr.backend import CMRBackend
from titiler.cmr.settings import AuthSettings

Expand Down Expand Up @@ -192,3 +195,13 @@ def rasterio_query_params() -> Dict[str, str]:
"bands_regex": "Fmask",
"bands": "Fmask",
}


@pytest.fixture
def image_data():
"""
Create a proper ImageData object to return from tile
"""
return ImageData(
np.zeros((3, 256, 256), dtype=np.uint8) # RGB image
)
99 changes: 99 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test backend functions"""

from unittest.mock import MagicMock, patch

import pytest

from titiler.cmr.backend import Access, CMRBackend
Expand Down Expand Up @@ -28,3 +30,100 @@ def test_get_assets(access: Access, expectation: str) -> None:
assert asset_url
assert isinstance(asset_url, dict)
assert asset_url[band].startswith(expectation)


cmr_query = {
"concept_id": "C2021957657-LPCLOUD",
"temporal": ("2024-02-11", "2024-02-13"),
}


@pytest.mark.parametrize(
"method_name,method_call",
[
(
"tile",
lambda backend: backend.tile(
tile_x=0, tile_y=0, tile_z=0, cmr_query=cmr_query, bands_regex=""
),
),
(
"part",
lambda backend: backend.part(
bbox=(0, 0, 1, 1), cmr_query=cmr_query, bands_regex=""
),
),
(
"feature",
lambda backend: backend.feature(
shape={"type": "Point", "coordinates": [0, 0]},
cmr_query=cmr_query,
bands_regex="",
),
),
],
)
def test_s3_credentials_used_for_session_creation(
method_name, method_call, image_data
) -> None:
"""Test that s3_credentials from _get_s3_credentials are used to create AWS session."""
from rio_tiler.io import Reader

# Mock s3 credentials that would be returned by _get_s3_credentials
mock_s3_credentials = {
"accessKeyId": "test_access_key",
"secretAccessKey": "test_secret_key",
"sessionToken": "test_session_token",
}

# Mock asset that would be returned by assets_for_tile
mock_asset = {
"url": "s3://test-bucket/test-file.tif",
"provider": "TEST_PROVIDER",
}

# Create a mock class that will pass isinstance checks
class MockReader:
def __init__(self, *args, **kwargs):
pass

def __enter__(self):
mock_instance = MagicMock()
# Set the method to return the image_data
getattr(mock_instance, method_name).return_value = image_data
return mock_instance

def __exit__(self, *args):
pass

def class_eq(self, other):
if other is MockReader:
return True
return type.__eq__(self, other)

with CMRBackend(reader=MockReader) as backend:
# Mock asset methods to return our test asset
with (
patch.object(backend, "assets_for_tile", return_value=[mock_asset]),
patch.object(backend, "assets_for_bbox", return_value=[mock_asset]),
patch.object(backend, "get_assets", return_value=[mock_asset]),
patch.object(type(Reader), "__eq__", class_eq),
patch.object(
backend, "_get_s3_credentials", return_value=mock_s3_credentials
) as mock_get_creds,
patch.object(backend, "_create_aws_session") as mock_create_session,
patch("rasterio.Env"),
):
# Mock the session to return a valid context manager
mock_session = MagicMock()
mock_create_session.return_value = mock_session

# Call tile, which should trigger the credential flow
method_call(backend)

# Verify that _get_s3_credentials was called with the asset
mock_get_creds.assert_called_once_with(mock_asset)

# Verify that _create_aws_session was called with the credentials
# returned by _get_s3_credentials
mock_create_session.assert_called_once_with(mock_s3_credentials)
15 changes: 9 additions & 6 deletions titiler/cmr/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,10 @@ def tile(
f"No assets found for tile {tile_z}-{tile_x}-{tile_y}"
)

def _reader(asset: Asset, x: int, y: int, z: int, **kwargs: Any) -> ImageData:
s3_credentials = self._get_s3_credentials(asset)
asset = mosaic_assets[0]
s3_credentials = self._get_s3_credentials(asset)

def _reader(asset: Asset, x: int, y: int, z: int, **kwargs: Any) -> ImageData:
if isinstance(self.reader, type) and self.reader == Reader:
aws_session = self._create_aws_session(s3_credentials)

Expand Down Expand Up @@ -354,9 +355,10 @@ def part(
if not mosaic_assets:
raise NoAssetFoundError("No assets found for bbox input")

def _reader(asset: Asset, bbox: BBox, **kwargs: Any) -> ImageData:
s3_credentials = self._get_s3_credentials(asset)
asset = mosaic_assets[0]
s3_credentials = self._get_s3_credentials(asset)

def _reader(asset: Asset, bbox: BBox, **kwargs: Any) -> ImageData:
if isinstance(self.reader, type) and self.reader == Reader:
aws_session = self._create_aws_session(s3_credentials)

Expand Down Expand Up @@ -414,9 +416,10 @@ def feature(
if not mosaic_assets:
raise NoAssetFoundError("No assets found for Geometry")

def _reader(asset: Asset, shape: Dict, **kwargs: Any) -> ImageData:
s3_credentials = self._get_s3_credentials(asset)
asset = mosaic_assets[0]
s3_credentials = self._get_s3_credentials(asset)

def _reader(asset: Asset, shape: Dict, **kwargs: Any) -> ImageData:
if isinstance(self.reader, type) and self.reader == Reader:
aws_session = self._create_aws_session(s3_credentials)

Expand Down
Loading