@@ -81,6 +81,7 @@ class AuthSamlProvider(models.Model):
81
81
"auth.saml.attribute.mapping" ,
82
82
"provider_id" ,
83
83
string = "Attribute Mapping" ,
84
+ copy = True ,
84
85
)
85
86
active = fields .Boolean (default = True )
86
87
sequence = fields .Integer (index = True )
@@ -136,6 +137,20 @@ class AuthSamlProvider(models.Model):
136
137
default = True ,
137
138
help = "Whether metadata should be signed or not" ,
138
139
)
140
+ # User creation fields
141
+ create_user = fields .Boolean (
142
+ default = False ,
143
+ help = "Create user if not found. The login and name will defaults to the SAML "
144
+ "user matching attribute. Use the mapping attributes to change the value "
145
+ "used." ,
146
+ )
147
+ create_user_template_id = fields .Many2one (
148
+ comodel_name = "res.users" ,
149
+ # Template users, like base.default_user, are disabled by default so allow them
150
+ domain = "[('active', 'in', (True, False))]" ,
151
+ default = lambda self : self .env .ref ("base.default_user" ),
152
+ help = "When creating user, this user is used as a template" ,
153
+ )
139
154
140
155
@api .model
141
156
def _sig_alg_selection (self ):
@@ -256,9 +271,7 @@ def _get_auth_request(self, extra_state=None, url_root=None):
256
271
}
257
272
state .update (extra_state )
258
273
259
- sig_alg = ds .SIG_RSA_SHA1
260
- if self .sig_alg :
261
- sig_alg = getattr (ds , self .sig_alg )
274
+ sig_alg = getattr (ds , self .sig_alg )
262
275
263
276
saml_client = self ._get_client_for_provider (url_root )
264
277
reqid , info = saml_client .prepare_for_authenticate (
@@ -287,27 +300,15 @@ def _validate_auth_response(self, token: str, base_url: str = None):
287
300
saml2 .entity .BINDING_HTTP_POST ,
288
301
self ._get_outstanding_requests_dict (),
289
302
)
290
- matching_value = None
291
-
292
- if self .matching_attribute == "subject.nameId" :
293
- matching_value = response .name_id .text
294
- else :
295
- attrs = response .get_identity ()
296
-
297
- for k , v in attrs .items ():
298
- if k == self .matching_attribute :
299
- matching_value = v
300
- break
301
-
302
- if not matching_value :
303
- raise Exception (
304
- f"Matching attribute { self .matching_attribute } not found "
305
- f"in user attrs: { attrs } "
306
- )
307
-
308
- if matching_value and isinstance (matching_value , list ):
309
- matching_value = next (iter (matching_value ), None )
310
-
303
+ try :
304
+ matching_value = self ._get_attribute_value (
305
+ response , self .matching_attribute
306
+ )
307
+ except KeyError :
308
+ raise Exception (
309
+ f"Matching attribute { self .matching_attribute } not found "
310
+ f"in user attrs: { response .get_identity ()} "
311
+ ) from None
311
312
if isinstance (matching_value , str ) and self .matching_attribute_to_lower :
312
313
matching_value = matching_value .lower ()
313
314
@@ -349,24 +350,59 @@ def _metadata_string(self, valid=None, base_url: str = None):
349
350
sign = self .sign_metadata ,
350
351
)
351
352
353
+ @staticmethod
354
+ def _get_attribute_value (response , attribute_name : str ):
355
+ """
356
+
357
+ :raise: KeyError if attribute is not in the response
358
+ :param response:
359
+ :param attribute_name:
360
+ :return: value of the attribut. if the value is an empty list, return None
361
+ otherwise return the first element of the list
362
+ """
363
+ if attribute_name == "subject.nameId" :
364
+ return response .name_id .text
365
+ attrs = response .get_identity ()
366
+ attribute_value = attrs [attribute_name ]
367
+ if isinstance (attribute_value , list ):
368
+ attribute_value = next (iter (attribute_value ), None )
369
+ return attribute_value
370
+
352
371
def _hook_validate_auth_response (self , response , matching_value ):
353
372
self .ensure_one ()
354
373
vals = {}
355
- attrs = response .get_identity ()
356
374
357
375
for attribute in self .attribute_mapping_ids :
358
- if attribute .attribute_name not in attrs :
359
- _logger .debug (
376
+ try :
377
+ vals [attribute .field_name ] = self ._get_attribute_value (
378
+ response , attribute .attribute_name
379
+ )
380
+ except KeyError :
381
+ _logger .warning (
360
382
"SAML attribute '%s' found in response %s" ,
361
383
attribute .attribute_name ,
362
- attrs ,
384
+ response . get_identity () ,
363
385
)
364
- continue
365
386
366
- attribute_value = attrs [attribute .attribute_name ]
367
- if isinstance (attribute_value , list ):
368
- attribute_value = attribute_value [0 ]
387
+ return {"mapped_attrs" : vals }
369
388
370
- vals [attribute .field_name ] = attribute_value
389
+ def _user_copy_defaults (self , validation ):
390
+ """
391
+ Returns defaults when copying the template user.
371
392
372
- return {"mapped_attrs" : vals }
393
+ Can be overridden with extra information.
394
+ :param validation: validation result
395
+ :return: a dictionary for copying template user, empty to avoid copying
396
+ """
397
+ self .ensure_one ()
398
+ if not self .create_user :
399
+ return {}
400
+ saml_uid = validation ["user_id" ]
401
+ return {
402
+ "name" : saml_uid ,
403
+ "login" : saml_uid ,
404
+ "active" : True ,
405
+ # if signature is not provided by mapped_attrs, it will be computed
406
+ # due to call to compute method in calling method.
407
+ "signature" : None ,
408
+ } | validation .get ("mapped_attrs" , {})
0 commit comments