From 2388bd0eefa38d8f885ca9a2d8e1bfaef0203bd3 Mon Sep 17 00:00:00 2001 From: Anfimov Dima Date: Fri, 11 Oct 2024 04:18:55 +0200 Subject: [PATCH 1/2] feat: add handling prefer header --- postgrest/_async/request_builder.py | 20 +++++++- postgrest/_sync/request_builder.py | 22 +++++++-- postgrest/base_request_builder.py | 29 ++++++++--- postgrest/types.py | 5 ++ tests/_async/test_request_builder.py | 72 +++++++++++++++++++++++++--- tests/_sync/test_request_builder.py | 72 +++++++++++++++++++++++++--- 6 files changed, 195 insertions(+), 25 deletions(-) diff --git a/postgrest/_async/request_builder.py b/postgrest/_async/request_builder.py index c13cf200..60557edf 100644 --- a/postgrest/_async/request_builder.py +++ b/postgrest/_async/request_builder.py @@ -20,7 +20,7 @@ pre_upsert, ) from ..exceptions import APIError, generate_default_error_message -from ..types import ReturnMethod +from ..types import Handling, ReturnMethod from ..utils import AsyncClient, get_origin_and_cast _ReturnT = TypeVar("_ReturnT") @@ -283,16 +283,20 @@ def select( *columns: str, count: Optional[CountMethod] = None, head: Optional[bool] = None, + handling: Handling = Handling.lenient, ) -> AsyncSelectRequestBuilder[_ReturnT]: """Run a SELECT query. Args: *columns: The names of the columns to fetch. count: The method to use to get the count of rows returned. + handling: Either 'lenient' or 'strict' Returns: :class:`AsyncSelectRequestBuilder` """ - method, params, headers, json = pre_select(*columns, count=count, head=head) + method, params, headers, json = pre_select( + *columns, count=count, head=head, handling=handling + ) return AsyncSelectRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json ) @@ -305,6 +309,7 @@ def insert( returning: ReturnMethod = ReturnMethod.representation, upsert: bool = False, default_to_null: bool = True, + handling: Handling = Handling.lenient, ) -> AsyncQueryRequestBuilder[_ReturnT]: """Run an INSERT query. @@ -316,6 +321,7 @@ def insert( default_to_null: Make missing fields default to `null`. Otherwise, use the default value for the column. Only applies for bulk inserts. + handling: Either 'lenient' or 'strict' Returns: :class:`AsyncQueryRequestBuilder` """ @@ -325,6 +331,7 @@ def insert( returning=returning, upsert=upsert, default_to_null=default_to_null, + handling=handling, ) return AsyncQueryRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json @@ -339,6 +346,7 @@ def upsert( ignore_duplicates: bool = False, on_conflict: str = "", default_to_null: bool = True, + handling: Handling = Handling.lenient, ) -> AsyncQueryRequestBuilder[_ReturnT]: """Run an upsert (INSERT ... ON CONFLICT DO UPDATE) query. @@ -352,6 +360,7 @@ def upsert( default value for the column. This only applies when inserting new rows, not when merging with existing rows under `ignoreDuplicates: false`. This also only applies when doing bulk upserts. + handling: Either 'lenient' or 'strict' Returns: :class:`AsyncQueryRequestBuilder` """ @@ -362,6 +371,7 @@ def upsert( ignore_duplicates=ignore_duplicates, on_conflict=on_conflict, default_to_null=default_to_null, + handling=handling, ) return AsyncQueryRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json @@ -373,6 +383,7 @@ def update( *, count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, + handling: Handling = Handling.lenient, ) -> AsyncFilterRequestBuilder[_ReturnT]: """Run an UPDATE query. @@ -380,6 +391,7 @@ def update( json: The updated fields. count: The method to use to get the count of rows returned. returning: Either 'minimal' or 'representation' + handling: Either 'lenient' or 'strict' Returns: :class:`AsyncFilterRequestBuilder` """ @@ -387,6 +399,7 @@ def update( json, count=count, returning=returning, + handling=handling, ) return AsyncFilterRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json @@ -397,18 +410,21 @@ def delete( *, count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, + handling: Handling = Handling.lenient, ) -> AsyncFilterRequestBuilder[_ReturnT]: """Run a DELETE query. Args: count: The method to use to get the count of rows returned. returning: Either 'minimal' or 'representation' + handling: Either 'lenient' or 'strict' Returns: :class:`AsyncFilterRequestBuilder` """ method, params, headers, json = pre_delete( count=count, returning=returning, + handling=handling, ) return AsyncFilterRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json diff --git a/postgrest/_sync/request_builder.py b/postgrest/_sync/request_builder.py index 742db9a2..9a7b1d38 100644 --- a/postgrest/_sync/request_builder.py +++ b/postgrest/_sync/request_builder.py @@ -20,7 +20,7 @@ pre_upsert, ) from ..exceptions import APIError, generate_default_error_message -from ..types import ReturnMethod +from ..types import Handling, ReturnMethod from ..utils import SyncClient, get_origin_and_cast _ReturnT = TypeVar("_ReturnT") @@ -34,7 +34,7 @@ def __init__( http_method: str, headers: Headers, params: QueryParams, - json: Union[dict, list], + json: dict, ) -> None: self.session = session self.path = path @@ -283,16 +283,20 @@ def select( *columns: str, count: Optional[CountMethod] = None, head: Optional[bool] = None, + handling: Handling = Handling.lenient, ) -> SyncSelectRequestBuilder[_ReturnT]: """Run a SELECT query. Args: *columns: The names of the columns to fetch. count: The method to use to get the count of rows returned. + handling: Either 'lenient' or 'strict' Returns: :class:`SyncSelectRequestBuilder` """ - method, params, headers, json = pre_select(*columns, count=count, head=head) + method, params, headers, json = pre_select( + *columns, count=count, head=head, handling=handling + ) return SyncSelectRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json ) @@ -305,6 +309,7 @@ def insert( returning: ReturnMethod = ReturnMethod.representation, upsert: bool = False, default_to_null: bool = True, + handling: Handling = Handling.lenient, ) -> SyncQueryRequestBuilder[_ReturnT]: """Run an INSERT query. @@ -316,6 +321,7 @@ def insert( default_to_null: Make missing fields default to `null`. Otherwise, use the default value for the column. Only applies for bulk inserts. + handling: Either 'lenient' or 'strict' Returns: :class:`SyncQueryRequestBuilder` """ @@ -325,6 +331,7 @@ def insert( returning=returning, upsert=upsert, default_to_null=default_to_null, + handling=handling, ) return SyncQueryRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json @@ -339,6 +346,7 @@ def upsert( ignore_duplicates: bool = False, on_conflict: str = "", default_to_null: bool = True, + handling: Handling = Handling.lenient, ) -> SyncQueryRequestBuilder[_ReturnT]: """Run an upsert (INSERT ... ON CONFLICT DO UPDATE) query. @@ -352,6 +360,7 @@ def upsert( default value for the column. This only applies when inserting new rows, not when merging with existing rows under `ignoreDuplicates: false`. This also only applies when doing bulk upserts. + handling: Either 'lenient' or 'strict' Returns: :class:`SyncQueryRequestBuilder` """ @@ -362,6 +371,7 @@ def upsert( ignore_duplicates=ignore_duplicates, on_conflict=on_conflict, default_to_null=default_to_null, + handling=handling, ) return SyncQueryRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json @@ -373,6 +383,7 @@ def update( *, count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, + handling: Handling = Handling.lenient, ) -> SyncFilterRequestBuilder[_ReturnT]: """Run an UPDATE query. @@ -380,6 +391,7 @@ def update( json: The updated fields. count: The method to use to get the count of rows returned. returning: Either 'minimal' or 'representation' + handling: Either 'lenient' or 'strict' Returns: :class:`SyncFilterRequestBuilder` """ @@ -387,6 +399,7 @@ def update( json, count=count, returning=returning, + handling=handling, ) return SyncFilterRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json @@ -397,18 +410,21 @@ def delete( *, count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, + handling: Handling = Handling.lenient, ) -> SyncFilterRequestBuilder[_ReturnT]: """Run a DELETE query. Args: count: The method to use to get the count of rows returned. returning: Either 'minimal' or 'representation' + handling: Either 'lenient' or 'strict' Returns: :class:`SyncFilterRequestBuilder` """ method, params, headers, json = pre_delete( count=count, returning=returning, + handling=handling, ) return SyncFilterRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json diff --git a/postgrest/base_request_builder.py b/postgrest/base_request_builder.py index b86d8f29..a9340d27 100644 --- a/postgrest/base_request_builder.py +++ b/postgrest/base_request_builder.py @@ -34,7 +34,7 @@ # < 2.0.0 from pydantic import validator as field_validator -from .types import CountMethod, Filters, RequestMethod, ReturnMethod +from .types import CountMethod, Filters, Handling, RequestMethod, ReturnMethod from .utils import AsyncClient, SyncClient, get_origin_and_cast, sanitize_param @@ -70,12 +70,16 @@ def pre_select( *columns: str, count: Optional[CountMethod] = None, head: Optional[bool] = None, + handling: Handling = Handling.lenient, ) -> QueryArgs: method = RequestMethod.HEAD if head else RequestMethod.GET cleaned_columns = _cleaned_columns(columns or "*") params = QueryParams({"select": cleaned_columns}) - headers = Headers({"Prefer": f"count={count}"}) if count else Headers() + prefer_headers = [f"handling={handling}"] + if count: + prefer_headers.append(f"count={count}") + headers = Headers({"Prefer": ",".join(prefer_headers)}) return QueryArgs(method, params, headers, {}) @@ -86,8 +90,9 @@ def pre_insert( returning: ReturnMethod, upsert: bool, default_to_null: bool = True, + handling: Handling = Handling.lenient, ) -> QueryArgs: - prefer_headers = [f"return={returning}"] + prefer_headers = [f"return={returning}", f"handling={handling}"] if count: prefer_headers.append(f"count={count}") if upsert: @@ -110,9 +115,13 @@ def pre_upsert( ignore_duplicates: bool, on_conflict: str = "", default_to_null: bool = True, + handling: Handling = Handling.lenient, ) -> QueryArgs: query_params = {} - prefer_headers = [f"return={returning}"] + prefer_headers = [ + f"return={returning}", + f"handling={handling}", + ] if count: prefer_headers.append(f"count={count}") resolution = "ignore" if ignore_duplicates else "merge" @@ -133,8 +142,12 @@ def pre_update( *, count: Optional[CountMethod], returning: ReturnMethod, + handling: Handling = Handling.lenient, ) -> QueryArgs: - prefer_headers = [f"return={returning}"] + prefer_headers = [ + f"return={returning}", + f"handling={handling}", + ] if count: prefer_headers.append(f"count={count}") headers = Headers({"Prefer": ",".join(prefer_headers)}) @@ -145,8 +158,12 @@ def pre_delete( *, count: Optional[CountMethod], returning: ReturnMethod, + handling: Handling = Handling.lenient, ) -> QueryArgs: - prefer_headers = [f"return={returning}"] + prefer_headers = [ + f"return={returning}", + f"handling={handling}", + ] if count: prefer_headers.append(f"count={count}") headers = Headers({"Prefer": ",".join(prefer_headers)}) diff --git a/postgrest/types.py b/postgrest/types.py index fa6f94ce..ccc0d33d 100644 --- a/postgrest/types.py +++ b/postgrest/types.py @@ -56,3 +56,8 @@ class RequestMethod(StrEnum): class ReturnMethod(StrEnum): minimal = "minimal" representation = "representation" + + +class Handling(StrEnum): + lenient = "lenient" + strict = "strict" diff --git a/tests/_async/test_request_builder.py b/tests/_async/test_request_builder.py index feb98032..2d4a349c 100644 --- a/tests/_async/test_request_builder.py +++ b/tests/_async/test_request_builder.py @@ -5,7 +5,7 @@ from postgrest import AsyncRequestBuilder, AsyncSingleRequestBuilder from postgrest.base_request_builder import APIResponse, SingleAPIResponse -from postgrest.types import CountMethod +from postgrest.types import CountMethod, Handling from postgrest.utils import AsyncClient @@ -24,7 +24,7 @@ def test_select(self, request_builder: AsyncRequestBuilder): builder = request_builder.select("col1", "col2") assert builder.params["select"] == "col1,col2" - assert builder.headers.get("prefer") is None + assert builder.headers.get_list("prefer") == ["handling=lenient"] assert builder.http_method == "GET" assert builder.json == {} @@ -32,7 +32,10 @@ def test_select_with_count(self, request_builder: AsyncRequestBuilder): builder = request_builder.select(count=CountMethod.exact) assert builder.params["select"] == "*" - assert builder.headers["prefer"] == "count=exact" + assert builder.headers.get_list("prefer", True) == [ + "handling=lenient", + "count=exact", + ] assert builder.http_method == "GET" assert builder.json == {} @@ -40,7 +43,7 @@ def test_select_with_head(self, request_builder: AsyncRequestBuilder): builder = request_builder.select("col1", "col2", head=True) assert builder.params.get("select") == "col1,col2" - assert builder.headers.get("prefer") is None + assert builder.headers.get_list("prefer") == ["handling=lenient"] assert builder.http_method == "HEAD" assert builder.json == {} @@ -50,12 +53,23 @@ def test_select_as_csv(self, request_builder: AsyncRequestBuilder): assert builder.headers["Accept"] == "text/csv" assert isinstance(builder, AsyncSingleRequestBuilder) + def test_select_with_handling_strict(self, request_builder: AsyncRequestBuilder): + builder = request_builder.select("col1", "col2", handling=Handling.strict) + + assert builder.params["select"] == "col1,col2" + assert builder.headers.get_list("prefer") == ["handling=strict"] + assert builder.http_method == "GET" + assert builder.json == {} + class TestInsert: def test_insert(self, request_builder: AsyncRequestBuilder): builder = request_builder.insert({"key1": "val1"}) - assert builder.headers.get_list("prefer", True) == ["return=representation"] + assert builder.headers.get_list("prefer", True) == [ + "return=representation", + "handling=lenient", + ] assert builder.http_method == "POST" assert builder.json == {"key1": "val1"} @@ -64,6 +78,7 @@ def test_insert_with_count(self, request_builder: AsyncRequestBuilder): assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "count=exact", ] assert builder.http_method == "POST" @@ -74,6 +89,7 @@ def test_insert_with_upsert(self, request_builder: AsyncRequestBuilder): assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "resolution=merge-duplicates", ] assert builder.http_method == "POST" @@ -83,6 +99,7 @@ def test_upsert_with_default_single(self, request_builder: AsyncRequestBuilder): builder = request_builder.upsert([{"key1": "val1"}], default_to_null=False) assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "resolution=merge-duplicates", "missing=default", ] @@ -96,6 +113,7 @@ def test_bulk_insert_using_default(self, request_builder: AsyncRequestBuilder): ) assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "missing=default", ] assert builder.http_method == "POST" @@ -109,6 +127,7 @@ def test_upsert(self, request_builder: AsyncRequestBuilder): assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "resolution=merge-duplicates", ] assert builder.http_method == "POST" @@ -120,6 +139,7 @@ def test_bulk_upsert_with_default(self, request_builder: AsyncRequestBuilder): ) assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "resolution=merge-duplicates", "missing=default", ] @@ -129,12 +149,25 @@ def test_bulk_upsert_with_default(self, request_builder: AsyncRequestBuilder): '"key1","key2","key3"'.split(",") ) + def test_insert_with_handling_strict(self, request_builder: AsyncRequestBuilder): + builder = request_builder.insert({"key1": "val1"}, handling=Handling.strict) + + assert builder.headers.get_list("prefer", True) == [ + "return=representation", + "handling=strict", + ] + assert builder.http_method == "POST" + assert builder.json == {"key1": "val1"} + class TestUpdate: def test_update(self, request_builder: AsyncRequestBuilder): builder = request_builder.update({"key1": "val1"}) - assert builder.headers.get_list("prefer", True) == ["return=representation"] + assert builder.headers.get_list("prefer", True) == [ + "return=representation", + "handling=lenient", + ] assert builder.http_method == "PATCH" assert builder.json == {"key1": "val1"} @@ -143,17 +176,31 @@ def test_update_with_count(self, request_builder: AsyncRequestBuilder): assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "count=exact", ] assert builder.http_method == "PATCH" assert builder.json == {"key1": "val1"} + def test_update_with_handling_strict(self, request_builder: AsyncRequestBuilder): + builder = request_builder.update({"key1": "val1"}, handling=Handling.strict) + + assert builder.headers.get_list("prefer", True) == [ + "return=representation", + "handling=strict", + ] + assert builder.http_method == "PATCH" + assert builder.json == {"key1": "val1"} + class TestDelete: def test_delete(self, request_builder: AsyncRequestBuilder): builder = request_builder.delete() - assert builder.headers.get_list("prefer", True) == ["return=representation"] + assert builder.headers.get_list("prefer", True) == [ + "return=representation", + "handling=lenient", + ] assert builder.http_method == "DELETE" assert builder.json == {} @@ -162,11 +209,22 @@ def test_delete_with_count(self, request_builder: AsyncRequestBuilder): assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "count=exact", ] assert builder.http_method == "DELETE" assert builder.json == {} + def test_delete_with_handling_strict(self, request_builder: AsyncRequestBuilder): + builder = request_builder.delete(handling=Handling.strict) + + assert builder.headers.get_list("prefer", True) == [ + "return=representation", + "handling=strict", + ] + assert builder.http_method == "DELETE" + assert builder.json == {} + class TestTextSearch: def test_text_search(self, request_builder: AsyncRequestBuilder): diff --git a/tests/_sync/test_request_builder.py b/tests/_sync/test_request_builder.py index 8d8a1939..c88f1954 100644 --- a/tests/_sync/test_request_builder.py +++ b/tests/_sync/test_request_builder.py @@ -5,7 +5,7 @@ from postgrest import SyncRequestBuilder, SyncSingleRequestBuilder from postgrest.base_request_builder import APIResponse, SingleAPIResponse -from postgrest.types import CountMethod +from postgrest.types import CountMethod, Handling from postgrest.utils import SyncClient @@ -24,7 +24,7 @@ def test_select(self, request_builder: SyncRequestBuilder): builder = request_builder.select("col1", "col2") assert builder.params["select"] == "col1,col2" - assert builder.headers.get("prefer") is None + assert builder.headers.get_list("prefer") == ["handling=lenient"] assert builder.http_method == "GET" assert builder.json == {} @@ -32,7 +32,10 @@ def test_select_with_count(self, request_builder: SyncRequestBuilder): builder = request_builder.select(count=CountMethod.exact) assert builder.params["select"] == "*" - assert builder.headers["prefer"] == "count=exact" + assert builder.headers.get_list("prefer", True) == [ + "handling=lenient", + "count=exact", + ] assert builder.http_method == "GET" assert builder.json == {} @@ -40,7 +43,7 @@ def test_select_with_head(self, request_builder: SyncRequestBuilder): builder = request_builder.select("col1", "col2", head=True) assert builder.params.get("select") == "col1,col2" - assert builder.headers.get("prefer") is None + assert builder.headers.get_list("prefer") == ["handling=lenient"] assert builder.http_method == "HEAD" assert builder.json == {} @@ -50,12 +53,23 @@ def test_select_as_csv(self, request_builder: SyncRequestBuilder): assert builder.headers["Accept"] == "text/csv" assert isinstance(builder, SyncSingleRequestBuilder) + def test_select_with_handling_strict(self, request_builder: SyncRequestBuilder): + builder = request_builder.select("col1", "col2", handling=Handling.strict) + + assert builder.params["select"] == "col1,col2" + assert builder.headers.get_list("prefer") == ["handling=strict"] + assert builder.http_method == "GET" + assert builder.json == {} + class TestInsert: def test_insert(self, request_builder: SyncRequestBuilder): builder = request_builder.insert({"key1": "val1"}) - assert builder.headers.get_list("prefer", True) == ["return=representation"] + assert builder.headers.get_list("prefer", True) == [ + "return=representation", + "handling=lenient", + ] assert builder.http_method == "POST" assert builder.json == {"key1": "val1"} @@ -64,6 +78,7 @@ def test_insert_with_count(self, request_builder: SyncRequestBuilder): assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "count=exact", ] assert builder.http_method == "POST" @@ -74,6 +89,7 @@ def test_insert_with_upsert(self, request_builder: SyncRequestBuilder): assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "resolution=merge-duplicates", ] assert builder.http_method == "POST" @@ -83,6 +99,7 @@ def test_upsert_with_default_single(self, request_builder: SyncRequestBuilder): builder = request_builder.upsert([{"key1": "val1"}], default_to_null=False) assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "resolution=merge-duplicates", "missing=default", ] @@ -96,6 +113,7 @@ def test_bulk_insert_using_default(self, request_builder: SyncRequestBuilder): ) assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "missing=default", ] assert builder.http_method == "POST" @@ -109,6 +127,7 @@ def test_upsert(self, request_builder: SyncRequestBuilder): assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "resolution=merge-duplicates", ] assert builder.http_method == "POST" @@ -120,6 +139,7 @@ def test_bulk_upsert_with_default(self, request_builder: SyncRequestBuilder): ) assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "resolution=merge-duplicates", "missing=default", ] @@ -129,12 +149,25 @@ def test_bulk_upsert_with_default(self, request_builder: SyncRequestBuilder): '"key1","key2","key3"'.split(",") ) + def test_insert_with_handling_strict(self, request_builder: SyncRequestBuilder): + builder = request_builder.insert({"key1": "val1"}, handling=Handling.strict) + + assert builder.headers.get_list("prefer", True) == [ + "return=representation", + "handling=strict", + ] + assert builder.http_method == "POST" + assert builder.json == {"key1": "val1"} + class TestUpdate: def test_update(self, request_builder: SyncRequestBuilder): builder = request_builder.update({"key1": "val1"}) - assert builder.headers.get_list("prefer", True) == ["return=representation"] + assert builder.headers.get_list("prefer", True) == [ + "return=representation", + "handling=lenient", + ] assert builder.http_method == "PATCH" assert builder.json == {"key1": "val1"} @@ -143,17 +176,31 @@ def test_update_with_count(self, request_builder: SyncRequestBuilder): assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "count=exact", ] assert builder.http_method == "PATCH" assert builder.json == {"key1": "val1"} + def test_update_with_handling_strict(self, request_builder: SyncRequestBuilder): + builder = request_builder.update({"key1": "val1"}, handling=Handling.strict) + + assert builder.headers.get_list("prefer", True) == [ + "return=representation", + "handling=strict", + ] + assert builder.http_method == "PATCH" + assert builder.json == {"key1": "val1"} + class TestDelete: def test_delete(self, request_builder: SyncRequestBuilder): builder = request_builder.delete() - assert builder.headers.get_list("prefer", True) == ["return=representation"] + assert builder.headers.get_list("prefer", True) == [ + "return=representation", + "handling=lenient", + ] assert builder.http_method == "DELETE" assert builder.json == {} @@ -162,11 +209,22 @@ def test_delete_with_count(self, request_builder: SyncRequestBuilder): assert builder.headers.get_list("prefer", True) == [ "return=representation", + "handling=lenient", "count=exact", ] assert builder.http_method == "DELETE" assert builder.json == {} + def test_delete_with_handling_strict(self, request_builder: SyncRequestBuilder): + builder = request_builder.delete(handling=Handling.strict) + + assert builder.headers.get_list("prefer", True) == [ + "return=representation", + "handling=strict", + ] + assert builder.http_method == "DELETE" + assert builder.json == {} + class TestTextSearch: def test_text_search(self, request_builder: SyncRequestBuilder): From fa4edae24215795b0ccd8a45ba42e4d2263a92aa Mon Sep 17 00:00:00 2001 From: Anfimov Dima Date: Fri, 11 Oct 2024 05:01:09 +0200 Subject: [PATCH 2/2] feat: add max-affected-header --- infra/docker-compose.yaml | 2 +- postgrest/_async/request_builder.py | 6 ++ postgrest/_sync/request_builder.py | 6 ++ postgrest/base_request_builder.py | 6 ++ tests/_async/test_max_affected_integration.py | 36 ++++++++ tests/_async/test_request_builder.py | 88 ++++++++++++++++++- tests/_sync/test_max_affected_integration.py | 36 ++++++++ tests/_sync/test_request_builder.py | 88 ++++++++++++++++++- 8 files changed, 265 insertions(+), 3 deletions(-) create mode 100644 tests/_async/test_max_affected_integration.py create mode 100644 tests/_sync/test_max_affected_integration.py diff --git a/infra/docker-compose.yaml b/infra/docker-compose.yaml index 783ed1dc..6754c392 100644 --- a/infra/docker-compose.yaml +++ b/infra/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: rest: - image: postgrest/postgrest:v11.2.2 + image: postgrest/postgrest:v12.2.3 ports: - '3000:3000' environment: diff --git a/postgrest/_async/request_builder.py b/postgrest/_async/request_builder.py index 60557edf..c6624806 100644 --- a/postgrest/_async/request_builder.py +++ b/postgrest/_async/request_builder.py @@ -384,6 +384,7 @@ def update( count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, handling: Handling = Handling.lenient, + max_affected: Optional[int] = None, ) -> AsyncFilterRequestBuilder[_ReturnT]: """Run an UPDATE query. @@ -392,6 +393,7 @@ def update( count: The method to use to get the count of rows returned. returning: Either 'minimal' or 'representation' handling: Either 'lenient' or 'strict' + max_affected: Limit of rows that can be affected during request. Working only with handling=strict. Returns: :class:`AsyncFilterRequestBuilder` """ @@ -400,6 +402,7 @@ def update( count=count, returning=returning, handling=handling, + max_affected=max_affected, ) return AsyncFilterRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json @@ -411,6 +414,7 @@ def delete( count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, handling: Handling = Handling.lenient, + max_affected: Optional[int] = None, ) -> AsyncFilterRequestBuilder[_ReturnT]: """Run a DELETE query. @@ -418,6 +422,7 @@ def delete( count: The method to use to get the count of rows returned. returning: Either 'minimal' or 'representation' handling: Either 'lenient' or 'strict' + max_affected: Limit of rows that can be affected during request. Working only with handling=strict. Returns: :class:`AsyncFilterRequestBuilder` """ @@ -425,6 +430,7 @@ def delete( count=count, returning=returning, handling=handling, + max_affected=max_affected, ) return AsyncFilterRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json diff --git a/postgrest/_sync/request_builder.py b/postgrest/_sync/request_builder.py index 9a7b1d38..a9e21abb 100644 --- a/postgrest/_sync/request_builder.py +++ b/postgrest/_sync/request_builder.py @@ -384,6 +384,7 @@ def update( count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, handling: Handling = Handling.lenient, + max_affected: Optional[int] = None, ) -> SyncFilterRequestBuilder[_ReturnT]: """Run an UPDATE query. @@ -392,6 +393,7 @@ def update( count: The method to use to get the count of rows returned. returning: Either 'minimal' or 'representation' handling: Either 'lenient' or 'strict' + max_affected: Limit of rows that can be affected during request. Working only with handling=strict. Returns: :class:`SyncFilterRequestBuilder` """ @@ -400,6 +402,7 @@ def update( count=count, returning=returning, handling=handling, + max_affected=max_affected, ) return SyncFilterRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json @@ -411,6 +414,7 @@ def delete( count: Optional[CountMethod] = None, returning: ReturnMethod = ReturnMethod.representation, handling: Handling = Handling.lenient, + max_affected: Optional[int] = None, ) -> SyncFilterRequestBuilder[_ReturnT]: """Run a DELETE query. @@ -418,6 +422,7 @@ def delete( count: The method to use to get the count of rows returned. returning: Either 'minimal' or 'representation' handling: Either 'lenient' or 'strict' + max_affected: Limit of rows that can be affected during request. Working only with handling=strict. Returns: :class:`SyncFilterRequestBuilder` """ @@ -425,6 +430,7 @@ def delete( count=count, returning=returning, handling=handling, + max_affected=max_affected, ) return SyncFilterRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json diff --git a/postgrest/base_request_builder.py b/postgrest/base_request_builder.py index a9340d27..410ad869 100644 --- a/postgrest/base_request_builder.py +++ b/postgrest/base_request_builder.py @@ -143,6 +143,7 @@ def pre_update( count: Optional[CountMethod], returning: ReturnMethod, handling: Handling = Handling.lenient, + max_affected: Optional[int] = None, ) -> QueryArgs: prefer_headers = [ f"return={returning}", @@ -150,6 +151,8 @@ def pre_update( ] if count: prefer_headers.append(f"count={count}") + if max_affected and handling == handling.strict: + prefer_headers.append(f"max-affected={max_affected}") headers = Headers({"Prefer": ",".join(prefer_headers)}) return QueryArgs(RequestMethod.PATCH, QueryParams(), headers, json) @@ -159,6 +162,7 @@ def pre_delete( count: Optional[CountMethod], returning: ReturnMethod, handling: Handling = Handling.lenient, + max_affected: Optional[int] = None, ) -> QueryArgs: prefer_headers = [ f"return={returning}", @@ -166,6 +170,8 @@ def pre_delete( ] if count: prefer_headers.append(f"count={count}") + if max_affected and handling == handling.strict: + prefer_headers.append(f"max-affected={max_affected}") headers = Headers({"Prefer": ",".join(prefer_headers)}) return QueryArgs(RequestMethod.DELETE, QueryParams(), headers, {}) diff --git a/tests/_async/test_max_affected_integration.py b/tests/_async/test_max_affected_integration.py new file mode 100644 index 00000000..abbdd328 --- /dev/null +++ b/tests/_async/test_max_affected_integration.py @@ -0,0 +1,36 @@ +import pytest + +from postgrest.exceptions import APIError +from postgrest.types import Handling + +from .client import rest_client + + +async def test_update_more_rows_that_should_be_affected(): + with pytest.raises( + APIError, match="Query result exceeds max-affected preference constraint" + ): + ( + await rest_client() + .from_("countries") + .update( + {"country_name": "COUNTRY_NAME_CHANGED"}, + handling=Handling.strict, + max_affected=1, + ) + .in_("nicename", ["Albania", "Algeria"]) + .execute() + ) + + +async def test_delete_more_rows_that_should_be_affected(): + with pytest.raises( + APIError, match="Query result exceeds max-affected preference constraint" + ): + ( + await rest_client() + .from_("countries") + .delete(handling=Handling.strict, max_affected=1) + .in_("nicename", ["Albania", "Algeria"]) + .execute() + ) diff --git a/tests/_async/test_request_builder.py b/tests/_async/test_request_builder.py index 2d4a349c..44c46bcc 100644 --- a/tests/_async/test_request_builder.py +++ b/tests/_async/test_request_builder.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import pytest from httpx import Request, Response @@ -192,6 +192,50 @@ def test_update_with_handling_strict(self, request_builder: AsyncRequestBuilder) assert builder.http_method == "PATCH" assert builder.json == {"key1": "val1"} + @pytest.mark.parametrize( + "handling, max_affected, max_affected_should_be_in_headers", + [ + pytest.param( + Handling.strict, + 10, + True, + id="when-handling-is-strict-can-set-max-affected-parameter", + ), + pytest.param( + Handling.lenient, + 10, + False, + id="when-handling-is-lenient-max-affected-parameter-cant-be-set", + ), + pytest.param( + Handling.lenient, + None, + False, + id="when-max-affected-is-not-set-then-not-add-it-to-headers", + ), + pytest.param( + Handling.strict, + None, + False, + id="when-max-affected-is-not-set-then-not-add-it-to-headers-even-handling-is-strict", + ), + ], + ) + def test_update_with_max_affected( + self, + request_builder: AsyncRequestBuilder, + handling: Handling, + max_affected: Optional[int], + max_affected_should_be_in_headers: bool, + ): + builder = request_builder.update( + {"key1": "val1"}, handling=handling, max_affected=max_affected + ) + + assert ( + f"max-affected={max_affected}" in builder.headers.get_list("prefer", True) + ) is max_affected_should_be_in_headers + class TestDelete: def test_delete(self, request_builder: AsyncRequestBuilder): @@ -225,6 +269,48 @@ def test_delete_with_handling_strict(self, request_builder: AsyncRequestBuilder) assert builder.http_method == "DELETE" assert builder.json == {} + @pytest.mark.parametrize( + "handling, max_affected, max_affected_should_be_in_headers", + [ + pytest.param( + Handling.strict, + 10, + True, + id="when-handling-is-strict-can-set-max-affected-parameter", + ), + pytest.param( + Handling.lenient, + 10, + False, + id="when-handling-is-lenient-max-affected-parameter-cant-be-set", + ), + pytest.param( + Handling.lenient, + None, + False, + id="when-max-affected-is-not-set-then-not-add-it-to-headers", + ), + pytest.param( + Handling.strict, + None, + False, + id="when-max-affected-is-not-set-then-not-add-it-to-headers-even-handling-is-strict", + ), + ], + ) + def test_delete_with_max_affected( + self, + request_builder: AsyncRequestBuilder, + handling: Handling, + max_affected: Optional[int], + max_affected_should_be_in_headers: bool, + ): + builder = request_builder.delete(handling=handling, max_affected=max_affected) + + assert ( + f"max-affected={max_affected}" in builder.headers.get_list("prefer", True) + ) is max_affected_should_be_in_headers + class TestTextSearch: def test_text_search(self, request_builder: AsyncRequestBuilder): diff --git a/tests/_sync/test_max_affected_integration.py b/tests/_sync/test_max_affected_integration.py new file mode 100644 index 00000000..4f4528d4 --- /dev/null +++ b/tests/_sync/test_max_affected_integration.py @@ -0,0 +1,36 @@ +import pytest + +from postgrest.exceptions import APIError +from postgrest.types import Handling + +from .client import rest_client + + +def test_update_more_rows_that_should_be_affected(): + with pytest.raises( + APIError, match="Query result exceeds max-affected preference constraint" + ): + ( + rest_client() + .from_("countries") + .update( + {"country_name": "COUNTRY_NAME_CHANGED"}, + handling=Handling.strict, + max_affected=1, + ) + .in_("nicename", ["Albania", "Algeria"]) + .execute() + ) + + +def test_delete_more_rows_that_should_be_affected(): + with pytest.raises( + APIError, match="Query result exceeds max-affected preference constraint" + ): + ( + rest_client() + .from_("countries") + .delete(handling=Handling.strict, max_affected=1) + .in_("nicename", ["Albania", "Algeria"]) + .execute() + ) diff --git a/tests/_sync/test_request_builder.py b/tests/_sync/test_request_builder.py index c88f1954..6e616a14 100644 --- a/tests/_sync/test_request_builder.py +++ b/tests/_sync/test_request_builder.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import pytest from httpx import Request, Response @@ -192,6 +192,50 @@ def test_update_with_handling_strict(self, request_builder: SyncRequestBuilder): assert builder.http_method == "PATCH" assert builder.json == {"key1": "val1"} + @pytest.mark.parametrize( + "handling, max_affected, max_affected_should_be_in_headers", + [ + pytest.param( + Handling.strict, + 10, + True, + id="when-handling-is-strict-can-set-max-affected-parameter", + ), + pytest.param( + Handling.lenient, + 10, + False, + id="when-handling-is-lenient-max-affected-parameter-cant-be-set", + ), + pytest.param( + Handling.lenient, + None, + False, + id="when-max-affected-is-not-set-then-not-add-it-to-headers", + ), + pytest.param( + Handling.strict, + None, + False, + id="when-max-affected-is-not-set-then-not-add-it-to-headers-even-handling-is-strict", + ), + ], + ) + def test_update_with_max_affected( + self, + request_builder: SyncRequestBuilder, + handling: Handling, + max_affected: Optional[int], + max_affected_should_be_in_headers: bool, + ): + builder = request_builder.update( + {"key1": "val1"}, handling=handling, max_affected=max_affected + ) + + assert ( + f"max-affected={max_affected}" in builder.headers.get_list("prefer", True) + ) is max_affected_should_be_in_headers + class TestDelete: def test_delete(self, request_builder: SyncRequestBuilder): @@ -225,6 +269,48 @@ def test_delete_with_handling_strict(self, request_builder: SyncRequestBuilder): assert builder.http_method == "DELETE" assert builder.json == {} + @pytest.mark.parametrize( + "handling, max_affected, max_affected_should_be_in_headers", + [ + pytest.param( + Handling.strict, + 10, + True, + id="when-handling-is-strict-can-set-max-affected-parameter", + ), + pytest.param( + Handling.lenient, + 10, + False, + id="when-handling-is-lenient-max-affected-parameter-cant-be-set", + ), + pytest.param( + Handling.lenient, + None, + False, + id="when-max-affected-is-not-set-then-not-add-it-to-headers", + ), + pytest.param( + Handling.strict, + None, + False, + id="when-max-affected-is-not-set-then-not-add-it-to-headers-even-handling-is-strict", + ), + ], + ) + def test_delete_with_max_affected( + self, + request_builder: SyncRequestBuilder, + handling: Handling, + max_affected: Optional[int], + max_affected_should_be_in_headers: bool, + ): + builder = request_builder.delete(handling=handling, max_affected=max_affected) + + assert ( + f"max-affected={max_affected}" in builder.headers.get_list("prefer", True) + ) is max_affected_should_be_in_headers + class TestTextSearch: def test_text_search(self, request_builder: SyncRequestBuilder):