Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions src/geocodio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@

class GeocodioClient:
BASE_PATH = "/v1.8" # keep in sync with Geocodio's current version
DEFAULT_SINGLE_TIMEOUT = 5.0
DEFAULT_BATCH_TIMEOUT = 1800.0 # 30 minutes
LIST_API_TIMEOUT = 60.0

@staticmethod
def get_status_exception_mappings() -> Dict[
Expand All @@ -43,13 +46,23 @@ def get_status_exception_mappings() -> Dict[
500: GeocodioServerError,
}

def __init__(self, api_key: Optional[str] = None, hostname: str = "api.geocod.io"):
def __init__(
self,
api_key: Optional[str] = None,
hostname: str = "api.geocod.io",
single_timeout: Optional[float] = None,
batch_timeout: Optional[float] = None,
list_timeout: Optional[float] = None,
):
self.api_key: str = api_key or os.getenv("GEOCODIO_API_KEY", "")
if not self.api_key:
raise AuthenticationError(
detail="No API key supplied and GEOCODIO_API_KEY is not set."
)
self.hostname = hostname.rstrip("/")
self.single_timeout = single_timeout or self.DEFAULT_SINGLE_TIMEOUT
self.batch_timeout = batch_timeout or self.DEFAULT_BATCH_TIMEOUT
self.list_timeout = list_timeout or self.LIST_API_TIMEOUT
self._http = httpx.Client(base_url=f"https://{self.hostname}")

# ──────────────────────────────────────────────────────────────────────────
Expand Down Expand Up @@ -108,7 +121,8 @@ def geocode(
params["q"] = address
data = None

response = self._request("POST" if data else "GET", endpoint, params, json=data)
timeout = self.batch_timeout if data else self.single_timeout
response = self._request("POST" if data else "GET", endpoint, params, json=data, timeout=timeout)
return self._parse_geocoding_response(response.json())

def reverse(
Expand Down Expand Up @@ -144,7 +158,8 @@ def reverse(
params["q"] = coordinate # "lat,lng"
data = None

response = self._request("POST" if data else "GET", endpoint, params, json=data)
timeout = self.batch_timeout if data else self.single_timeout
response = self._request("POST" if data else "GET", endpoint, params, json=data, timeout=timeout)
return self._parse_geocoding_response(response.json())

# ──────────────────────────────────────────────────────────────────────────
Expand All @@ -158,13 +173,18 @@ def _request(
params: dict,
json: Optional[dict] = None,
files: Optional[dict] = None,
timeout: Optional[float] = None,
) -> httpx.Response:
logger.debug(f"Making Request: {method} {endpoint}")
logger.debug(f"Params: {params}")
logger.debug(f"JSON body: {json}")
logger.debug(f"Files: {files}")

resp = self._http.request(method, endpoint, params=params, json=json, files=files, timeout=60)
if timeout is None:
timeout = self.single_timeout

logger.debug(f"Using timeout: {timeout}s")
resp = self._http.request(method, endpoint, params=params, json=json, files=files, timeout=timeout)

logger.debug(f"Response status code: {resp.status_code}")
logger.debug(f"Response headers: {resp.headers}")
Expand Down Expand Up @@ -290,7 +310,7 @@ def create_list(
# Join fields with commas as required by the API
params["fields"] = ",".join(fields)

response = self._request("POST", endpoint, params, files=files)
response = self._request("POST", endpoint, params, files=files, timeout=self.list_timeout)
logger.debug(f"Response content: {response.text}")
return self._parse_list_response(response.json(), response=response)

Expand All @@ -304,7 +324,7 @@ def get_lists(self) -> PaginatedResponse:
params: Dict[str, Union[str, int]] = {"api_key": self.api_key}
endpoint = f"{self.BASE_PATH}/lists"

response = self._request("GET", endpoint, params)
response = self._request("GET", endpoint, params, timeout=self.list_timeout)
pagination_info = response.json()

logger.debug(f"Pagination info: {pagination_info}")
Expand Down Expand Up @@ -339,7 +359,7 @@ def get_list(self, list_id: str) -> ListResponse:
params: Dict[str, Union[str, int]] = {"api_key": self.api_key}
endpoint = f"{self.BASE_PATH}/lists/{list_id}"

response = self._request("GET", endpoint, params)
response = self._request("GET", endpoint, params, timeout=self.list_timeout)
return self._parse_list_response(response.json(), response=response)

def delete_list(self, list_id: str) -> None:
Expand All @@ -352,7 +372,7 @@ def delete_list(self, list_id: str) -> None:
params: Dict[str, Union[str, int]] = {"api_key": self.api_key}
endpoint = f"{self.BASE_PATH}/lists/{list_id}"

self._request("DELETE", endpoint, params)
self._request("DELETE", endpoint, params, timeout=self.list_timeout)

@staticmethod
def _parse_list_response(response_json: dict, response: httpx.Response = None) -> ListResponse:
Expand Down Expand Up @@ -521,7 +541,7 @@ def download(self, list_id: str, filename: Optional[str] = None) -> str | bytes:
params = {"api_key": self.api_key}
endpoint = f"{self.BASE_PATH}/lists/{list_id}/download"

response: httpx.Response = self._request("GET", endpoint, params)
response: httpx.Response = self._request("GET", endpoint, params, timeout=self.list_timeout)
if response.headers.get("content-type", "").startswith("application/json"):
try:
error = response.json()
Expand Down