@@ -23,6 +23,8 @@ def test_create_performance_definition():
2323 RestObj ({"name" : "Test Model 2" , "id" : "67890" , "projectId" : PROJECT ["id" ]}),
2424 ]
2525 USER = "username"
26+ VERSION_MOCK = {"modelVersionName" : "1.0" }
27+ VERSION_MOCK_NONAME = {}
2628
2729 with mock .patch ("sasctl.core.Session._get_authorization_token" ):
2830 current_session ("example.com" , USER , "password" )
@@ -111,6 +113,32 @@ def test_create_performance_definition():
111113 table_prefix = "TestData" ,
112114 )
113115
116+ with pytest .raises (ValueError ):
117+ # Model verions exceeds models
118+ get_model .side_effect = copy .deepcopy (MODELS )
119+ _ = mm .create_performance_definition (
120+ models = ["model1" , "model2" ],
121+ modelVersions = ["1.0" , "2.0" , "3.0" ],
122+ library_name = "TestLibrary" ,
123+ table_prefix = "TestData" ,
124+ max_bins = 3 ,
125+ monitor_challenger = True ,
126+ monitor_champion = True ,
127+ )
128+
129+ with pytest .raises (ValueError ):
130+ # Model version dictionary missing modelVersionName
131+ get_model .side_effect = copy .deepcopy (MODELS )
132+ _ = mm .create_performance_definition (
133+ models = ["model1" , "model2" ],
134+ modelVersions = VERSION_MOCK_NONAME ,
135+ library_name = "TestLibrary" ,
136+ table_prefix = "TestData" ,
137+ max_bins = 3 ,
138+ monitor_challenger = True ,
139+ monitor_champion = True ,
140+ )
141+
114142 get_project .return_value = copy .deepcopy (PROJECT )
115143 get_project .return_value ["targetVariable" ] = "target"
116144 get_project .return_value ["targetLevel" ] = "interval"
@@ -125,21 +153,68 @@ def test_create_performance_definition():
125153 monitor_challenger = True ,
126154 monitor_champion = True ,
127155 )
156+ url , data = post_models .call_args
157+ assert post_models .call_count == 1
158+ assert PROJECT ["id" ] == data ["json" ]["projectId" ]
159+ assert MODELS [0 ]["id" ] in data ["json" ]["modelIds" ]
160+ assert MODELS [1 ]["id" ] in data ["json" ]["modelIds" ]
161+ assert "TestLibrary" == data ["json" ]["dataLibrary" ]
162+ assert "TestData" == data ["json" ]["dataPrefix" ]
163+ assert "cas-shared-default" == data ["json" ]["casServerId" ]
164+ assert data ["json" ]["name" ]
165+ assert data ["json" ]["description" ]
166+ assert data ["json" ]["maxBins" ] == 3
167+ assert data ["json" ]["championMonitored" ] is True
168+ assert data ["json" ]["challengerMonitored" ] is True
128169
129- assert post_models .call_count == 1
130- url , data = post_models .call_args
131-
132- assert PROJECT ["id" ] == data ["json" ]["projectId" ]
133- assert MODELS [0 ]["id" ] in data ["json" ]["modelIds" ]
134- assert MODELS [1 ]["id" ] in data ["json" ]["modelIds" ]
135- assert "TestLibrary" == data ["json" ]["dataLibrary" ]
136- assert "TestData" == data ["json" ]["dataPrefix" ]
137- assert "cas-shared-default" == data ["json" ]["casServerId" ]
138- assert data ["json" ]["name" ]
139- assert data ["json" ]["description" ]
140- assert data ["json" ]["maxBins" ] == 3
141- assert data ["json" ]["championMonitored" ] is True
142- assert data ["json" ]["challengerMonitored" ] is True
170+ get_model .side_effect = copy .deepcopy (MODELS )
171+ _ = mm .create_performance_definition (
172+ # One model version as a string name
173+ models = ["model1" , "model2" ],
174+ modelVersions = "1.0" ,
175+ library_name = "TestLibrary" ,
176+ table_prefix = "TestData" ,
177+ max_bins = 3 ,
178+ monitor_challenger = True ,
179+ monitor_champion = True ,
180+ )
181+
182+ assert post_models .call_count == 2
183+ url , data = post_models .call_args
184+ assert f"{ MODELS [0 ]['id' ]} :1.0" in data ["json" ]["modelIds" ]
185+ assert MODELS [1 ]["id" ] in data ["json" ]["modelIds" ]
186+
187+ get_model .side_effect = copy .deepcopy (MODELS )
188+ # List of string type model versions
189+ _ = mm .create_performance_definition (
190+ models = ["model1" , "model2" ],
191+ modelVersions = ["1.0" , "2.0" ],
192+ library_name = "TestLibrary" ,
193+ table_prefix = "TestData" ,
194+ max_bins = 3 ,
195+ monitor_challenger = True ,
196+ monitor_champion = True ,
197+ )
198+ assert post_models .call_count == 3
199+ url , data = post_models .call_args
200+ assert f"{ MODELS [0 ]['id' ]} :1.0" in data ["json" ]["modelIds" ]
201+ assert f"{ MODELS [1 ]['id' ]} :2.0" in data ["json" ]["modelIds" ]
202+
203+ get_model .side_effect = copy .deepcopy (MODELS )
204+ # List of dictionary type and string type model versions
205+ _ = mm .create_performance_definition (
206+ models = ["model1" , "model2" ],
207+ modelVersions = [VERSION_MOCK , "2.0" ],
208+ library_name = "TestLibrary" ,
209+ table_prefix = "TestData" ,
210+ max_bins = 3 ,
211+ monitor_challenger = True ,
212+ monitor_champion = True ,
213+ )
214+ assert post_models .call_count == 4
215+ url , data = post_models .call_args
216+ assert f"{ MODELS [0 ]['id' ]} :1.0" in data ["json" ]["modelIds" ]
217+ assert f"{ MODELS [1 ]['id' ]} :2.0" in data ["json" ]["modelIds" ]
143218
144219 with mock .patch (
145220 "sasctl._services.model_management.ModelManagement" ".post"
@@ -160,20 +235,39 @@ def test_create_performance_definition():
160235 monitor_champion = True ,
161236 )
162237
163- assert post_project .call_count == 1
164- url , data = post_project .call_args
165-
166- assert PROJECT ["id" ] == data ["json" ]["projectId" ]
167- assert MODELS [0 ]["id" ] in data ["json" ]["modelIds" ]
168- assert MODELS [1 ]["id" ] in data ["json" ]["modelIds" ]
169- assert "TestLibrary" == data ["json" ]["dataLibrary" ]
170- assert "TestData" == data ["json" ]["dataPrefix" ]
171- assert "cas-shared-default" == data ["json" ]["casServerId" ]
172- assert data ["json" ]["name" ]
173- assert data ["json" ]["description" ]
174- assert data ["json" ]["maxBins" ] == 3
175- assert data ["json" ]["championMonitored" ] is True
176- assert data ["json" ]["challengerMonitored" ] is True
238+ # one extra test for project with version id
239+
240+ assert post_project .call_count == 1
241+ url , data = post_project .call_args
242+
243+ assert PROJECT ["id" ] == data ["json" ]["projectId" ]
244+ assert MODELS [0 ]["id" ] in data ["json" ]["modelIds" ]
245+ assert MODELS [1 ]["id" ] in data ["json" ]["modelIds" ]
246+ assert "TestLibrary" == data ["json" ]["dataLibrary" ]
247+ assert "TestData" == data ["json" ]["dataPrefix" ]
248+ assert "cas-shared-default" == data ["json" ]["casServerId" ]
249+ assert data ["json" ]["name" ]
250+ assert data ["json" ]["description" ]
251+ assert data ["json" ]["maxBins" ] == 3
252+ assert data ["json" ]["championMonitored" ] is True
253+ assert data ["json" ]["challengerMonitored" ] is True
254+
255+ get_model .side_effect = copy .deepcopy (MODELS )
256+ # Project with model version
257+ _ = mm .create_performance_definition (
258+ project = "project" ,
259+ modelVersions = "2.0" ,
260+ library_name = "TestLibrary" ,
261+ table_prefix = "TestData" ,
262+ max_bins = 3 ,
263+ monitor_challenger = True ,
264+ monitor_champion = True ,
265+ )
266+
267+ assert post_project .call_count == 2
268+ url , data = post_project .call_args
269+ assert f"{ MODELS [0 ]['id' ]} :2.0" in data ["json" ]["modelIds" ]
270+ assert MODELS [1 ]["id" ] in data ["json" ]["modelIds" ]
177271
178272 def test_table_prefix_format ():
179273 with pytest .raises (ValueError ):
0 commit comments