11import os
22import socket
3- from typing import TypedDict
3+ from typing import TypedDict , Unpack , Any , cast
44
5- from cadence .api .v1 .service_worker_pb2_grpc import WorkerAPIStub
6- from grpc .aio import Channel
5+ from grpc import ChannelCredentials , Compression
76
8- from cadence .data_converter import DataConverter
7+ from cadence ._internal .rpc .yarpc import YarpcMetadataInterceptor
8+ from cadence .api .v1 .service_worker_pb2_grpc import WorkerAPIStub
9+ from grpc .aio import Channel , ClientInterceptor , secure_channel , insecure_channel
10+ from cadence .data_converter import DataConverter , DefaultDataConverter
911
1012
1113class ClientOptions (TypedDict , total = False ):
1214 domain : str
13- identity : str
15+ target : str
1416 data_converter : DataConverter
17+ identity : str
18+ service_name : str
19+ caller_name : str
20+ channel_arguments : dict [str , Any ]
21+ credentials : ChannelCredentials | None
22+ compression : Compression
23+ interceptors : list [ClientInterceptor ]
24+
25+ _DEFAULT_OPTIONS : ClientOptions = {
26+ "data_converter" : DefaultDataConverter (),
27+ "identity" : f"{ os .getpid ()} @{ socket .gethostname ()} " ,
28+ "service_name" : "cadence-frontend" ,
29+ "caller_name" : "cadence-client" ,
30+ "channel_arguments" : {},
31+ "credentials" : None ,
32+ "compression" : Compression .NoCompression ,
33+ "interceptors" : [],
34+ }
1535
1636class Client :
17- def __init__ (self , channel : Channel , options : ClientOptions ) -> None :
18- self ._channel = channel
19- self ._worker_stub = WorkerAPIStub (channel )
20- self ._options = options
21- self ._identity = options ["identity" ] if "identity" in options else f"{ os .getpid ()} @{ socket .gethostname ()} "
37+ def __init__ (self , ** kwargs : Unpack [ClientOptions ]) -> None :
38+ self ._options = _validate_and_copy_defaults (ClientOptions (** kwargs ))
39+ self ._channel = _create_channel (self ._options )
40+ self ._worker_stub = WorkerAPIStub (self ._channel )
2241
2342 @property
2443 def data_converter (self ) -> DataConverter :
@@ -30,14 +49,35 @@ def domain(self) -> str:
3049
3150 @property
3251 def identity (self ) -> str :
33- return self ._identity
52+ return self ._options [ "identity" ]
3453
3554 @property
3655 def worker_stub (self ) -> WorkerAPIStub :
3756 return self ._worker_stub
3857
39-
4058 async def close (self ) -> None :
4159 await self ._channel .close ()
4260
61+ def _validate_and_copy_defaults (options : ClientOptions ) -> ClientOptions :
62+ if "target" not in options :
63+ raise ValueError ("target must be specified" )
64+
65+ if "domain" not in options :
66+ raise ValueError ("domain must be specified" )
67+
68+ # Set default values for missing options
69+ for key , value in _DEFAULT_OPTIONS .items ():
70+ if key not in options :
71+ cast (dict , options )[key ] = value
72+
73+ return options
74+
75+
76+ def _create_channel (options : ClientOptions ) -> Channel :
77+ interceptors = list (options ["interceptors" ])
78+ interceptors .append (YarpcMetadataInterceptor (options ["service_name" ], options ["caller_name" ]))
4379
80+ if options ["credentials" ]:
81+ return secure_channel (options ["target" ], options ["credentials" ], options ["channel_arguments" ], options ["compression" ], interceptors )
82+ else :
83+ return insecure_channel (options ["target" ], options ["channel_arguments" ], options ["compression" ], interceptors )
0 commit comments