diff --git a/src/geocodio/client.py b/src/geocodio/client.py index 0102e2d..0d82255 100644 --- a/src/geocodio/client.py +++ b/src/geocodio/client.py @@ -27,6 +27,9 @@ class GeocodioClient: BASE_PATH = "/v1.8" # keep in sync with Geocodio's current version + DEFAULT_SINGLE_TIMEOUT = 5.0 + DEFAULT_BATCH_TIMEOUT = 1800.0 # 30 minutes + LIST_API_TIMEOUT = 60.0 @staticmethod def get_status_exception_mappings() -> Dict[ @@ -43,13 +46,23 @@ def get_status_exception_mappings() -> Dict[ 500: GeocodioServerError, } - def __init__(self, api_key: Optional[str] = None, hostname: str = "api.geocod.io"): + def __init__( + self, + api_key: Optional[str] = None, + hostname: str = "api.geocod.io", + single_timeout: Optional[float] = None, + batch_timeout: Optional[float] = None, + list_timeout: Optional[float] = None, + ): self.api_key: str = api_key or os.getenv("GEOCODIO_API_KEY", "") if not self.api_key: raise AuthenticationError( detail="No API key supplied and GEOCODIO_API_KEY is not set." ) self.hostname = hostname.rstrip("/") + self.single_timeout = single_timeout or self.DEFAULT_SINGLE_TIMEOUT + self.batch_timeout = batch_timeout or self.DEFAULT_BATCH_TIMEOUT + self.list_timeout = list_timeout or self.LIST_API_TIMEOUT self._http = httpx.Client(base_url=f"https://{self.hostname}") # ────────────────────────────────────────────────────────────────────────── @@ -108,7 +121,8 @@ def geocode( params["q"] = address data = None - response = self._request("POST" if data else "GET", endpoint, params, json=data) + timeout = self.batch_timeout if data else self.single_timeout + response = self._request("POST" if data else "GET", endpoint, params, json=data, timeout=timeout) return self._parse_geocoding_response(response.json()) def reverse( @@ -144,7 +158,8 @@ def reverse( params["q"] = coordinate # "lat,lng" data = None - response = self._request("POST" if data else "GET", endpoint, params, json=data) + timeout = self.batch_timeout if data else self.single_timeout + response = self._request("POST" if data else "GET", endpoint, params, json=data, timeout=timeout) return self._parse_geocoding_response(response.json()) # ────────────────────────────────────────────────────────────────────────── @@ -158,13 +173,18 @@ def _request( params: dict, json: Optional[dict] = None, files: Optional[dict] = None, + timeout: Optional[float] = None, ) -> httpx.Response: logger.debug(f"Making Request: {method} {endpoint}") logger.debug(f"Params: {params}") logger.debug(f"JSON body: {json}") logger.debug(f"Files: {files}") - resp = self._http.request(method, endpoint, params=params, json=json, files=files, timeout=60) + if timeout is None: + timeout = self.single_timeout + + logger.debug(f"Using timeout: {timeout}s") + resp = self._http.request(method, endpoint, params=params, json=json, files=files, timeout=timeout) logger.debug(f"Response status code: {resp.status_code}") logger.debug(f"Response headers: {resp.headers}") @@ -290,7 +310,7 @@ def create_list( # Join fields with commas as required by the API params["fields"] = ",".join(fields) - response = self._request("POST", endpoint, params, files=files) + response = self._request("POST", endpoint, params, files=files, timeout=self.list_timeout) logger.debug(f"Response content: {response.text}") return self._parse_list_response(response.json(), response=response) @@ -304,7 +324,7 @@ def get_lists(self) -> PaginatedResponse: params: Dict[str, Union[str, int]] = {"api_key": self.api_key} endpoint = f"{self.BASE_PATH}/lists" - response = self._request("GET", endpoint, params) + response = self._request("GET", endpoint, params, timeout=self.list_timeout) pagination_info = response.json() logger.debug(f"Pagination info: {pagination_info}") @@ -339,7 +359,7 @@ def get_list(self, list_id: str) -> ListResponse: params: Dict[str, Union[str, int]] = {"api_key": self.api_key} endpoint = f"{self.BASE_PATH}/lists/{list_id}" - response = self._request("GET", endpoint, params) + response = self._request("GET", endpoint, params, timeout=self.list_timeout) return self._parse_list_response(response.json(), response=response) def delete_list(self, list_id: str) -> None: @@ -352,7 +372,7 @@ def delete_list(self, list_id: str) -> None: params: Dict[str, Union[str, int]] = {"api_key": self.api_key} endpoint = f"{self.BASE_PATH}/lists/{list_id}" - self._request("DELETE", endpoint, params) + self._request("DELETE", endpoint, params, timeout=self.list_timeout) @staticmethod def _parse_list_response(response_json: dict, response: httpx.Response = None) -> ListResponse: @@ -521,7 +541,7 @@ def download(self, list_id: str, filename: Optional[str] = None) -> str | bytes: params = {"api_key": self.api_key} endpoint = f"{self.BASE_PATH}/lists/{list_id}/download" - response: httpx.Response = self._request("GET", endpoint, params) + response: httpx.Response = self._request("GET", endpoint, params, timeout=self.list_timeout) if response.headers.get("content-type", "").startswith("application/json"): try: error = response.json()