@@ -49,6 +49,7 @@ def __init__(
4949 description : str ,
5050 params : Sequence [ParameterSchema ],
5151 required_authn_params : Mapping [str , list [str ]],
52+ required_authz_tokens : Sequence [str ],
5253 auth_service_token_getters : Mapping [str , Callable [[], str ]],
5354 bound_params : Mapping [str , Union [Callable [[], Any ], Any ]],
5455 client_headers : Mapping [str , Union [Callable , Coroutine , str ]],
@@ -63,12 +64,14 @@ def __init__(
6364 name: The name of the remote tool.
6465 description: The description of the remote tool.
6566 params: The args of the tool.
66- required_authn_params: A map of required authenticated parameters to a list
67- of alternative services that can provide values for them.
68- auth_service_token_getters: A dict of authService -> token (or callables that
69- produce a token)
70- bound_params: A mapping of parameter names to bind to specific values or
71- callables that are called to produce values as needed.
67+ required_authn_params: A map of required authenticated parameters to
68+ a list of alternative services that can provide values for them.
69+ required_authz_tokens: A sequence of alternative services for
70+ providing authorization token for the tool invocation.
71+ auth_service_token_getters: A dict of authService -> token (or
72+ callables that produce a token)
73+ bound_params: A mapping of parameter names to bind to specific
74+ values or callables that are called to produce values as needed.
7275 client_headers: Client specific headers bound to the tool.
7376 """
7477 # used to invoke the toolbox API
@@ -106,6 +109,8 @@ def __init__(
106109
107110 # map of parameter name to auth service required by it
108111 self .__required_authn_params = required_authn_params
112+ # sequence of authorization tokens required by it
113+ self .__required_authz_tokens = required_authz_tokens
109114 # map of authService -> token_getter
110115 self .__auth_service_token_getters = auth_service_token_getters
111116 # map of parameter name to value (or callable that produces that value)
@@ -121,6 +126,7 @@ def __copy(
121126 description : Optional [str ] = None ,
122127 params : Optional [Sequence [ParameterSchema ]] = None ,
123128 required_authn_params : Optional [Mapping [str , list [str ]]] = None ,
129+ required_authz_tokens : Optional [Sequence [str ]] = None ,
124130 auth_service_token_getters : Optional [Mapping [str , Callable [[], str ]]] = None ,
125131 bound_params : Optional [Mapping [str , Union [Callable [[], Any ], Any ]]] = None ,
126132 client_headers : Optional [Mapping [str , Union [Callable , Coroutine , str ]]] = None ,
@@ -134,12 +140,14 @@ def __copy(
134140 name: The name of the remote tool.
135141 description: The description of the remote tool.
136142 params: The args of the tool.
137- required_authn_params: A map of required authenticated parameters to a list
138- of alternative services that can provide values for them.
139- auth_service_token_getters: A dict of authService -> token (or callables
140- that produce a token)
141- bound_params: A mapping of parameter names to bind to specific values or
142- callables that are called to produce values as needed.
143+ required_authn_params: A map of required authenticated parameters to
144+ a list of alternative services that can provide values for them.
145+ required_authz_tokens: A sequence of alternative services for
146+ providing authorization token for the tool invocation.
147+ auth_service_token_getters: A dict of authService -> token (or
148+ callables that produce a token)
149+ bound_params: A mapping of parameter names to bind to specific
150+ values or callables that are called to produce values as needed.
143151 client_headers: Client specific headers bound to the tool.
144152 """
145153 check = lambda val , default : val if val is not None else default
@@ -152,6 +160,9 @@ def __copy(
152160 required_authn_params = check (
153161 required_authn_params , self .__required_authn_params
154162 ),
163+ required_authz_tokens = check (
164+ required_authz_tokens , self .__required_authz_tokens
165+ ),
155166 auth_service_token_getters = check (
156167 auth_service_token_getters , self .__auth_service_token_getters
157168 ),
@@ -179,11 +190,15 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
179190 """
180191
181192 # check if any auth services need to be specified yet
182- if len (self .__required_authn_params ) > 0 :
193+ if (
194+ len (self .__required_authn_params ) > 0
195+ or len (self .__required_authz_tokens ) > 0
196+ ):
183197 # Gather all the required auth services into a set
184198 req_auth_services = set ()
185199 for s in self .__required_authn_params .values ():
186200 req_auth_services .update (s )
201+ req_auth_services .update (self .__required_authz_tokens )
187202 raise ValueError (
188203 f"One or more of the following authn services are required to invoke this tool"
189204 f": { ',' .join (req_auth_services )} "
@@ -269,18 +284,20 @@ def add_auth_token_getters(
269284 dict (self .__auth_service_token_getters , ** auth_token_getters )
270285 )
271286 # create a read-only updated for params that are still required
272- new_req_authn_params = types . MappingProxyType (
287+ new_req_authn_params , new_req_authz_tokens , used_auth_token_getters = (
273288 identify_required_authn_params (
274- # TODO: Add authRequired
275289 self .__required_authn_params ,
276- [] ,
290+ self . __required_authz_tokens ,
277291 auth_token_getters .keys (),
278- )[ 0 ]
292+ )
279293 )
280294
295+ # TODO: Add validation for used_auth_token_getters
296+
281297 return self .__copy (
282298 auth_service_token_getters = new_getters ,
283- required_authn_params = new_req_authn_params ,
299+ required_authn_params = types .MappingProxyType (new_req_authn_params ),
300+ required_authz_tokens = new_req_authz_tokens ,
284301 )
285302
286303 def bind_params (
0 commit comments