33import copy
44import datetime
55import logging
6+ from collections .abc import MutableMapping
67from dataclasses import dataclass , field
78from enum import Enum
8- from typing import TYPE_CHECKING , Any , Protocol , TypeAlias
9+ from typing import TYPE_CHECKING , Any , Protocol , TypeAlias , cast
910
10- import boto3 # type: ignore
11- from botocore .config import Config # type: ignore
11+ import boto3
12+ from botocore .config import Config
1213
1314from aws_durable_execution_sdk_python .exceptions import (
1415 CallableRuntimeError ,
1718)
1819
1920if TYPE_CHECKING :
20- from collections .abc import MutableMapping
21+ from mypy_boto3_lambda import LambdaClient as Boto3LambdaClient
22+ from mypy_boto3_lambda .type_defs import (
23+ CheckpointDurableExecutionResponseTypeDef ,
24+ GetDurableExecutionStateResponseTypeDef ,
25+ )
2126
2227 from aws_durable_execution_sdk_python .identifier import OperationIdentifier
2328
@@ -1031,9 +1036,9 @@ def get_execution_state(
10311036class LambdaClient (DurableServiceClient ):
10321037 """Persist durable operations to the Lambda Durable Function APIs."""
10331038
1034- _cached_boto_client : Any = None
1039+ _cached_boto_client : Boto3LambdaClient | None = None
10351040
1036- def __init__ (self , client : Any ) -> None :
1041+ def __init__ (self , client : Boto3LambdaClient ) -> None :
10371042 self .client = client
10381043
10391044 @classmethod
@@ -1066,19 +1071,20 @@ def checkpoint(
10661071 client_token : str | None ,
10671072 ) -> CheckpointOutput :
10681073 try :
1069- params = {
1070- "DurableExecutionArn" : durable_execution_arn ,
1071- "CheckpointToken" : checkpoint_token ,
1072- "Updates" : [o .to_dict () for o in updates ],
1073- }
1074+ optional_params : dict [str , str ] = {}
10741075 if client_token is not None :
1075- params ["ClientToken" ] = client_token
1076-
1077- result : MutableMapping [str , Any ] = self .client .checkpoint_durable_execution (
1078- ** params
1076+ optional_params ["ClientToken" ] = client_token
1077+
1078+ result : CheckpointDurableExecutionResponseTypeDef = (
1079+ self .client .checkpoint_durable_execution (
1080+ DurableExecutionArn = durable_execution_arn ,
1081+ CheckpointToken = checkpoint_token ,
1082+ Updates = cast (Any , [o .to_dict () for o in updates ]),
1083+ ** optional_params , # type: ignore[arg-type]
1084+ )
10791085 )
10801086
1081- return CheckpointOutput .from_dict (result )
1087+ return CheckpointOutput .from_dict (cast ( MutableMapping [ str , Any ], result ) )
10821088 except Exception as e :
10831089 checkpoint_error = CheckpointError .from_exception (e )
10841090 logger .exception (
@@ -1094,13 +1100,15 @@ def get_execution_state(
10941100 max_items : int = 1000 ,
10951101 ) -> StateOutput :
10961102 try :
1097- result : MutableMapping [str , Any ] = self .client .get_durable_execution_state (
1098- DurableExecutionArn = durable_execution_arn ,
1099- CheckpointToken = checkpoint_token ,
1100- Marker = next_marker ,
1101- MaxItems = max_items ,
1103+ result : GetDurableExecutionStateResponseTypeDef = (
1104+ self .client .get_durable_execution_state (
1105+ DurableExecutionArn = durable_execution_arn ,
1106+ CheckpointToken = checkpoint_token ,
1107+ Marker = next_marker ,
1108+ MaxItems = max_items ,
1109+ )
11021110 )
1103- return StateOutput .from_dict (result )
1111+ return StateOutput .from_dict (cast ( MutableMapping [ str , Any ], result ) )
11041112 except Exception as e :
11051113 error = GetExecutionStateError .from_exception (e )
11061114 logger .exception (
0 commit comments