diff --git a/CHANGES.md b/CHANGES.md index e796fa45a..6d020452f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,10 @@ ## Unreleased +### titiler.extensions + +* add: TileMatrixSet extension (`/tms`), to create TMS document from a dataset + ## 1.1.1 (2026-01-22) ### titiler.extensions diff --git a/src/titiler/application/titiler/application/main.py b/src/titiler/application/titiler/application/main.py index 01a221b38..98a77f1eb 100644 --- a/src/titiler/application/titiler/application/main.py +++ b/src/titiler/application/titiler/application/main.py @@ -43,6 +43,7 @@ stacExtension, stacRenderExtension, stacViewerExtension, + tmsExtension, wmtsExtension, ) from titiler.mosaic.errors import MOSAIC_STATUS_CODES @@ -136,6 +137,7 @@ def validate_access_token(access_token: str = Security(api_key_query)): cogViewerExtension(), stacExtension(), wmtsExtension(), + tmsExtension(), ], enable_telemetry=api_settings.telemetry_enabled, templates=titiler_templates, diff --git a/src/titiler/extensions/tests/test_tms.py b/src/titiler/extensions/tests/test_tms.py new file mode 100644 index 000000000..7e4126217 --- /dev/null +++ b/src/titiler/extensions/tests/test_tms.py @@ -0,0 +1,44 @@ +"""Test TiTiler stac extension.""" + +import os + +from fastapi import FastAPI +from morecantile import TileMatrixSet +from starlette.testclient import TestClient + +from titiler.core.factory import TilerFactory +from titiler.extensions import tmsExtension + +cog = os.path.join(os.path.dirname(__file__), "fixtures", "cog.tif") + + +def test_tmsExtension(): + """Test stacExtension class.""" + tiler = TilerFactory() + tiler_plus_tms = TilerFactory(extensions=[tmsExtension()]) + # Check that we added one route (/tms) + assert len(tiler_plus_tms.router.routes) == len(tiler.router.routes) + 1 + + app = FastAPI() + app.include_router(tiler_plus_tms.router) + with TestClient(app) as client: + response = client.get("/tms", params={"url": cog, "f": "html"}) + assert response.status_code == 200 + assert "text/html" in response.headers["content-type"] + + response = client.get( + "/tms", params={"url": cog}, headers={"Accept": "text/html"} + ) + assert response.status_code == 200 + assert "text/html" in response.headers["content-type"] + + response = client.get("/tms", params={"url": cog}) + assert response.status_code == 200 + assert response.headers["content-type"] == "application/json" + + body = response.json() + tms = TileMatrixSet.model_validate(body) + assert tms.description + assert tms.boundingBox + assert tms.crs + assert len(tms.tileMatrices) == 5 diff --git a/src/titiler/extensions/titiler/extensions/__init__.py b/src/titiler/extensions/titiler/extensions/__init__.py index abf21fdc6..c867f7dca 100644 --- a/src/titiler/extensions/titiler/extensions/__init__.py +++ b/src/titiler/extensions/titiler/extensions/__init__.py @@ -5,6 +5,7 @@ from .cogeo import cogValidateExtension # noqa from .render import stacRenderExtension # noqa from .stac import stacExtension # noqa +from .tms import tmsExtension # noqa from .viewer import cogViewerExtension, stacViewerExtension # noqa from .wms import wmsExtension # noqa from .wmts import wmtsExtension # noqa diff --git a/src/titiler/extensions/titiler/extensions/tms.py b/src/titiler/extensions/titiler/extensions/tms.py new file mode 100644 index 000000000..9e42661ac --- /dev/null +++ b/src/titiler/extensions/titiler/extensions/tms.py @@ -0,0 +1,166 @@ +"""TileMatrixSet Extension.""" + +import math +from typing import Annotated, Any, Literal + +import pyproj +import rasterio +from attrs import define +from fastapi import Depends, Query +from morecantile import TileMatrixSet +from morecantile.models import CRS_to_uri, TileMatrix, TMSBoundingBox, crs_axis_inverted +from morecantile.utils import meters_per_unit +from pyproj.exceptions import CRSError +from starlette.requests import Request + +from titiler.core.factory import FactoryExtension, TilerFactory +from titiler.core.resources.enums import MediaType +from titiler.core.utils import ( + accept_media_type, + create_html_response, + rio_crs_to_pyproj, +) + + +@define +class tmsExtension(FactoryExtension): + """Add /tms endpoint to a TilerFactory.""" + + def register(self, factory: TilerFactory): # type: ignore [override] # noqa: C901 + """Register endpoint to the tiler factory.""" + + @factory.router.get( + "/tms", + response_model=TileMatrixSet, + response_model_exclude_none=True, + name="Create TileMatrixSet from Dataset", + operation_id=f"{factory.operation_prefix}createTMS", + ) + def create_tilematrixset( + request: Request, + src_path=Depends(factory.path_dependency), + f: Annotated[ + Literal["html", "json"] | None, + Query( + description="Response MediaType. Defaults to endpoint's default or value defined in `accept` header." + ), + ] = None, + ): + """Create TileMatrixSet document.""" + tile_matrices: list[TileMatrix] = [] + + with rasterio.open(src_path) as src_dst: + bbox = src_dst.bounds + blockxsize, blockysize = src_dst.block_shapes[0] + width = src_dst.width + height = src_dst.height + + try: + overviews = src_dst.overviews(1) + except Exception: + overviews = [] + + crs = rio_crs_to_pyproj(src_dst.crs) + mpu = meters_per_unit(crs) + screen_pixel_size = 0.28e-3 + + is_inverted = crs_axis_inverted(crs) + # TODO: check this, some image might have different origin + corner_of_origin = "topLeft" + if corner_of_origin == "topLeft": + x_origin = bbox.left if not is_inverted else bbox.top + y_origin = bbox.top if not is_inverted else bbox.left + point_of_origin = [x_origin, y_origin] + elif corner_of_origin == "bottomLeft": + x_origin = bbox.left if not is_inverted else bbox.bottom + y_origin = bbox.bottom if not is_inverted else bbox.left + point_of_origin = [x_origin, y_origin] + + res = max(src_dst.res) + base_level = TileMatrix( + id=str(len(overviews)), # Last TileMatrix + scaleDenominator=res * mpu / screen_pixel_size, + cellSize=res, + cornerOfOrigin=corner_of_origin, + pointOfOrigin=point_of_origin, + tileWidth=blockxsize, + tileHeight=blockysize, + matrixWidth=math.ceil(width / blockxsize), + matrixHeight=math.ceil(height / blockysize), + ) + + for ix, ovr in enumerate(reversed(range(len(overviews)))): + with rasterio.open(src_path, OVERVIEW_LEVEL=ovr) as src_dst: + res = max(src_dst.res) + try: + blocksize = src_dst.block_shapes[0] + except Exception: + blocksize = (src_dst.width, 1) + + width = src_dst.width + height = src_dst.height + + # add tile matrix for highest resolution (base level) + tile_matrices.append( + TileMatrix( + id=str(ix), + scaleDenominator=res * mpu / 0.28e-3, + cellSize=res, + cornerOfOrigin=corner_of_origin, + pointOfOrigin=point_of_origin, + tileWidth=blocksize[1], + tileHeight=blocksize[0], + matrixWidth=math.ceil(width / blocksize[1]), + matrixHeight=math.ceil(height / blocksize[0]), + ) + ) + + tile_matrices.append(base_level) + + if crs.to_authority(min_confidence=20): + crs_data: Any = CRS_to_uri(crs) + + # Some old Proj version might not support URI + # so we fall back to wkt + try: + pyproj.CRS.from_user_input(crs_data) + except CRSError: + crs_data = {"wkt": crs.to_json_dict()} + + else: + crs_data = {"wkt": crs.to_json_dict()} + + tms = TileMatrixSet( + description=f"TileMatrixSet document for {src_path}", + crs=crs_data, + tileMatrices=tile_matrices, + boundingBox=TMSBoundingBox( + lowerLeft=[bbox.left, bbox.bottom], + upperRight=[bbox.right, bbox.top], + crs=crs_data, + ), + ) + + if f: + output_type = MediaType[f] + else: + accepted_media = [MediaType.html, MediaType.json] + output_type = ( + accept_media_type(request.headers.get("accept", ""), accepted_media) + or MediaType.json + ) + + if output_type == MediaType.html: + return create_html_response( + request, + { + **tms.model_dump(exclude_none=True, mode="json"), + # For visualization purpose we add the tms bbox + "bbox": list(tms.bbox), + }, + title="TileMatrixSet", + template_name="tilematrixset", + templates=factory.templates, + ) + + return tms