1+ import os , json
2+
3+ def write_config (models , model_dir , multi_mode = False , case_label = None ):
4+ """
5+ Write one or multiple model configs to JSON.
6+
7+ models: list of dicts, each dict must contain:
8+ - input_features
9+ - output_features
10+ - model_path
11+ - active_cells
12+ - apply_time
13+ - feature_engineering_input (optional)
14+ - scaling_input (optional)
15+ - feature_engineering_output (optional)
16+ - scaling_output (optional)
17+ - input_scaling_params (optional)
18+ - output_scaling_params (optional)
19+
20+ multi_mode: if True, write all configs into a single JSON.
21+ case_label: used as filename for multi-mode JSON (required if multi_mode=True).
22+ """
23+
24+ if multi_mode and not case_label :
25+ raise ValueError ("case_label must be provided for multi_mode=True" )
26+
27+ os .makedirs (model_dir , exist_ok = True )
28+ all_configs = []
29+
30+ for m in models :
31+ # Normalize input/output features
32+ input_features = m ["input_features" ]
33+ output_features = m ["output_features" ]
34+ if isinstance (input_features , str ):
35+ input_features = [input_features ]
36+ if isinstance (output_features , str ):
37+ output_features = [output_features ]
38+
39+ n_input = len (input_features )
40+ n_output = len (output_features )
41+
42+ # Optional fields (safe defaults)
43+ fe_input = m .get ("feature_engineering_input" ) or [None ]* n_input
44+ fe_output = m .get ("feature_engineering_output" ) or [None ]* n_output
45+ scaling_input = m .get ("scaling_input" ) or [None ]* n_input
46+ scaling_output = m .get ("scaling_output" ) or [None ]* n_output
47+ input_scaling_params = m .get ("input_scaling_params" ) or [None ]* n_input
48+ output_scaling_params = m .get ("output_scaling_params" ) or [None ]* n_output
49+
50+ # Build input block
51+ input_block = {}
52+ for i , fname in enumerate (input_features ):
53+ feat_eng = fe_input [i ] if i < len (fe_input ) else None
54+ scale = scaling_input [i ] if i < len (scaling_input ) else None
55+ feature_dict = {
56+ "feature_engineering" : feat_eng .lower () if feat_eng and feat_eng .lower () != "none" else None ,
57+ "scaling" : scale .lower () if scale and scale .lower () != "none" else None ,
58+ }
59+ feature_dict = {k : v for k , v in feature_dict .items () if v is not None }
60+ if i < len (input_scaling_params ) and input_scaling_params [i ] is not None :
61+ feature_dict ["scaling_params" ] = input_scaling_params [i ]
62+ input_block [fname ] = feature_dict
63+
64+ # Build output block
65+ output_block = {}
66+ for i , fname in enumerate (output_features ):
67+ feat_eng = fe_output [i ] if i < len (fe_output ) else None
68+ scale = scaling_output [i ] if i < len (scaling_output ) else None
69+ feature_dict = {
70+ "feature_engineering" : feat_eng .lower () if feat_eng and feat_eng .lower () != "none" else None ,
71+ "scaling" : scale .lower () if scale and scale .lower () != "none" else None ,
72+ }
73+ feature_dict = {k : v for k , v in feature_dict .items () if v is not None }
74+ if i < len (output_scaling_params ) and output_scaling_params [i ] is not None :
75+ feature_dict ["scaling_params" ] = output_scaling_params [i ]
76+ output_block [fname ] = feature_dict
77+
78+ # Save active cells
79+ model_base = os .path .splitext (os .path .basename (m ["model_path" ]))[0 ]
80+ cells_file = os .path .join (model_dir , model_base + "_active_cells.txt" )
81+ with open (cells_file , "w" ) as f :
82+ for cell in m ["active_cells" ]:
83+ f .write (f"{ cell } \n " )
84+
85+ # Build config dict
86+ cfg = {
87+ "model_path" : m ["model_path" ],
88+ "cell_indices_file" : cells_file ,
89+ "apply_times" : [m ["apply_time" ]],
90+ "features" : {
91+ "inputs" : input_block ,
92+ "outputs" : output_block
93+ }
94+ }
95+ all_configs .append (cfg )
96+
97+ # Determine JSON filename
98+ if multi_mode :
99+ if not case_label :
100+ raise ValueError ("case_label is required for multi_mode" )
101+ json_name = case_label + ".json"
102+ else :
103+ json_name = os .path .splitext (os .path .basename (models [0 ]["model_path" ]))[0 ] + ".json"
104+
105+ json_path = os .path .join (model_dir , json_name )
106+
107+ # Always write a list of configs
108+ with open (json_path , "w" ) as f :
109+ json .dump (all_configs , f , indent = 2 )
110+
111+ return json_path
0 commit comments