Skip to content

Commit a104f99

Browse files
refactor: adding proper type annotations
1 parent beb8d2f commit a104f99

3 files changed

Lines changed: 33 additions & 25 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ dependencies = ["coverage[toml]", "pytest", "pytest-cov"]
4141
cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=src/aws_durable_execution_sdk_python --cov-fail-under=98"
4242

4343
[tool.hatch.envs.types]
44-
extra-dependencies = ["mypy>=1.0.0", "pytest"]
44+
extra-dependencies = ["mypy>=1.0.0", "pytest", "boto3-stubs[lambda]"]
4545
[tool.hatch.envs.types.scripts]
4646
check = "mypy --install-types --non-interactive {args:src/aws_durable_execution_sdk_python tests}"
4747

src/aws_durable_execution_sdk_python/execution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
if TYPE_CHECKING:
3333
from collections.abc import Callable, MutableMapping
3434

35-
import boto3 # type: ignore
35+
from mypy_boto3_lambda import LambdaClient as Boto3LambdaClient
3636

3737
from aws_durable_execution_sdk_python.types import LambdaContext
3838

@@ -237,7 +237,7 @@ def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput:
237237
def durable_execution(
238238
func: Callable[[Any, DurableContext], Any] | None = None,
239239
*,
240-
boto3_client: boto3.client | None = None,
240+
boto3_client: Boto3LambdaClient | None = None,
241241
) -> Callable[[Any, LambdaContext], Any]:
242242
# Decorator called with parameters
243243
if func is None:

src/aws_durable_execution_sdk_python/lambda_service.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import copy
44
import datetime
55
import logging
6+
from collections.abc import MutableMapping
67
from dataclasses import dataclass, field
78
from enum import Enum
8-
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias
9+
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, cast
910

10-
import boto3 # type: ignore
11-
from botocore.config import Config # type: ignore
11+
import boto3
12+
from botocore.config import Config
1213

1314
from aws_durable_execution_sdk_python.exceptions import (
1415
CallableRuntimeError,
@@ -17,7 +18,11 @@
1718
)
1819

1920
if TYPE_CHECKING:
20-
from collections.abc import MutableMapping
21+
from mypy_boto3_lambda import LambdaClient as Boto3LambdaClient
22+
from mypy_boto3_lambda.type_defs import (
23+
CheckpointDurableExecutionResponseTypeDef,
24+
GetDurableExecutionStateResponseTypeDef,
25+
)
2126

2227
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
2328

@@ -1031,9 +1036,9 @@ def get_execution_state(
10311036
class LambdaClient(DurableServiceClient):
10321037
"""Persist durable operations to the Lambda Durable Function APIs."""
10331038

1034-
_cached_boto_client: Any = None
1039+
_cached_boto_client: Boto3LambdaClient | None = None
10351040

1036-
def __init__(self, client: Any) -> None:
1041+
def __init__(self, client: Boto3LambdaClient) -> None:
10371042
self.client = client
10381043

10391044
@classmethod
@@ -1066,19 +1071,20 @@ def checkpoint(
10661071
client_token: str | None,
10671072
) -> CheckpointOutput:
10681073
try:
1069-
params = {
1070-
"DurableExecutionArn": durable_execution_arn,
1071-
"CheckpointToken": checkpoint_token,
1072-
"Updates": [o.to_dict() for o in updates],
1073-
}
1074+
optional_params: dict[str, str] = {}
10741075
if client_token is not None:
1075-
params["ClientToken"] = client_token
1076-
1077-
result: MutableMapping[str, Any] = self.client.checkpoint_durable_execution(
1078-
**params
1076+
optional_params["ClientToken"] = client_token
1077+
1078+
result: CheckpointDurableExecutionResponseTypeDef = (
1079+
self.client.checkpoint_durable_execution(
1080+
DurableExecutionArn=durable_execution_arn,
1081+
CheckpointToken=checkpoint_token,
1082+
Updates=cast(Any, [o.to_dict() for o in updates]),
1083+
**optional_params, # type: ignore[arg-type]
1084+
)
10791085
)
10801086

1081-
return CheckpointOutput.from_dict(result)
1087+
return CheckpointOutput.from_dict(cast(MutableMapping[str, Any], result))
10821088
except Exception as e:
10831089
checkpoint_error = CheckpointError.from_exception(e)
10841090
logger.exception(
@@ -1094,13 +1100,15 @@ def get_execution_state(
10941100
max_items: int = 1000,
10951101
) -> StateOutput:
10961102
try:
1097-
result: MutableMapping[str, Any] = self.client.get_durable_execution_state(
1098-
DurableExecutionArn=durable_execution_arn,
1099-
CheckpointToken=checkpoint_token,
1100-
Marker=next_marker,
1101-
MaxItems=max_items,
1103+
result: GetDurableExecutionStateResponseTypeDef = (
1104+
self.client.get_durable_execution_state(
1105+
DurableExecutionArn=durable_execution_arn,
1106+
CheckpointToken=checkpoint_token,
1107+
Marker=next_marker,
1108+
MaxItems=max_items,
1109+
)
11021110
)
1103-
return StateOutput.from_dict(result)
1111+
return StateOutput.from_dict(cast(MutableMapping[str, Any], result))
11041112
except Exception as e:
11051113
error = GetExecutionStateError.from_exception(e)
11061114
logger.exception(

0 commit comments

Comments
 (0)