From 8fb1ed45dbadeb402c02698fe195f388d798e167 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 30 Oct 2024 15:08:35 -0400 Subject: [PATCH] Add transformation information to --info --- src/frontend/Info.ml | 45 +++++++++-- test/integration/cli-args/info/info.expected | 85 +++++++++++++------- test/integration/cli-args/info/info.stan | 3 + 3 files changed, 100 insertions(+), 33 deletions(-) diff --git a/src/frontend/Info.ml b/src/frontend/Info.ml index 09a8125e4..a62d68136 100644 --- a/src/frontend/Info.ml +++ b/src/frontend/Info.ml @@ -4,8 +4,8 @@ open Middle open Yojson.Basic let rec unsized_basetype_json t = - let to_json (type_, dim) : t = - `Assoc [("type", `String type_); ("dimensions", `Int dim)] in + let to_json (type_, dim) = + [("type", `String type_); ("dimensions", `Int dim)] in let internal, dims = UnsizedType.unwind_array_type t in match internal with | UnsizedType.UInt -> to_json ("int", dims) @@ -16,13 +16,44 @@ let rec unsized_basetype_json t = | UMatrix -> to_json ("real", dims + 2) | UComplexMatrix -> to_json ("complex", dims + 2) | UTuple internals -> - `Assoc - [ ("type", `List (List.map ~f:unsized_basetype_json internals)) - ; ("dimensions", `Int dims) ] + [ ( "type" + , `List + (List.map ~f:(fun f -> `Assoc (unsized_basetype_json f)) internals) + ); ("dimensions", `Int dims) ] | UMathLibraryFunction | UFun _ | UArray _ -> assert false let basetype_dims t = SizedType.to_unsized t |> unsized_basetype_json +let rec transformation t = + let expr_string = Fmt.to_to_string Pretty_printing.pp_expression in + let expr_string e = + `String (expr_string (Ast.untyped_expression_of_typed_expression e)) in + let transform details = [("transform", details)] in + match t with + | Transformation.Identity -> transform (`String "none") + | Lower e -> transform @@ `Assoc [("lower", expr_string e)] + | Upper e -> transform @@ `Assoc [("upper", expr_string e)] + | LowerUpper (e1, e2) -> + transform @@ `Assoc [("lower", expr_string e1); ("upper", expr_string e2)] + | Offset e -> transform @@ `Assoc [("offset", expr_string e)] + | Multiplier e -> transform @@ `Assoc [("multiplier", expr_string e)] + | OffsetMultiplier (e1, e2) -> + transform + @@ `Assoc [("offset", expr_string e1); ("multiplier", expr_string e2)] + | Ordered -> transform (`String "ordered") + | PositiveOrdered -> transform (`String "positive_ordered") + | Simplex -> transform (`String "simplex") + | UnitVector -> transform (`String "unit_vector") + | SumToZero -> transform (`String "sum_to_zero") + | CholeskyCorr -> transform (`String "cholesky_corr") + | CholeskyCov -> transform (`String "cholesky_cov") + | Correlation -> transform (`String "correlation") + | Covariance -> transform (`String "covariance") + | StochasticRow -> transform (`String "stochastic_row") + | StochasticColumn -> transform (`String "stochastic_column") + | TupleTransformation ts -> + transform (`List (List.map ~f:(fun t -> `Assoc (transformation t)) ts)) + let get_var_decl {stmts; _} : t = `Assoc (List.fold_right ~init:[] @@ -30,9 +61,11 @@ let get_var_decl {stmts; _} : t = match stmt.Ast.stmt with | Ast.VarDecl decl -> let type_info = basetype_dims decl.decl_type in + let transform_info = transformation decl.transformation in let decl_info = List.map - ~f:(fun {identifier; _} -> (identifier.name, type_info)) + ~f:(fun {identifier; _} -> + (identifier.name, `Assoc (type_info @ transform_info))) decl.variables in decl_info @ acc | _ -> acc) diff --git a/test/integration/cli-args/info/info.expected b/test/integration/cli-args/info/info.expected index 46a805ef6..eca886d7b 100644 --- a/test/integration/cli-args/info/info.expected +++ b/test/integration/cli-args/info/info.expected @@ -1,21 +1,21 @@ $ ../../../../../install/default/bin/stanc --include-paths=.,includes --info info.stan { "inputs": { - "a": { "type": "int", "dimensions": 0 }, - "b": { "type": "real", "dimensions": 0 }, - "c": { "type": "real", "dimensions": 1 }, - "d1": { "type": "real", "dimensions": 1 }, - "d2": { "type": "real", "dimensions": 1 }, - "e": { "type": "real", "dimensions": 2 }, - "f": { "type": "int", "dimensions": 1 }, - "g": { "type": "real", "dimensions": 1 }, - "h": { "type": "real", "dimensions": 2 }, - "i": { "type": "real", "dimensions": 3 }, - "j": { "type": "int", "dimensions": 3 }, - "cplx": { "type": "complex", "dimensions": 1 }, - "cplx_vec": { "type": "complex", "dimensions": 1 }, - "cplx_row": { "type": "complex", "dimensions": 1 }, - "cplx_mat": { "type": "complex", "dimensions": 2 }, + "a": { "type": "int", "dimensions": 0, "transform": "none" }, + "b": { "type": "real", "dimensions": 0, "transform": "none" }, + "c": { "type": "real", "dimensions": 1, "transform": "none" }, + "d1": { "type": "real", "dimensions": 1, "transform": "none" }, + "d2": { "type": "real", "dimensions": 1, "transform": "none" }, + "e": { "type": "real", "dimensions": 2, "transform": "none" }, + "f": { "type": "int", "dimensions": 1, "transform": "none" }, + "g": { "type": "real", "dimensions": 1, "transform": "none" }, + "h": { "type": "real", "dimensions": 2, "transform": "none" }, + "i": { "type": "real", "dimensions": 3, "transform": "none" }, + "j": { "type": "int", "dimensions": 3, "transform": "none" }, + "cplx": { "type": "complex", "dimensions": 1, "transform": "none" }, + "cplx_vec": { "type": "complex", "dimensions": 1, "transform": "none" }, + "cplx_row": { "type": "complex", "dimensions": 1, "transform": "none" }, + "cplx_mat": { "type": "complex", "dimensions": 2, "transform": "none" }, "tuples": { "type": [ { "type": "int", "dimensions": 0 }, @@ -28,22 +28,53 @@ "dimensions": 0 } ], - "dimensions": 1 + "dimensions": 1, + "transform": [ + { "transform": "none" }, + { "transform": "none" }, + { "transform": [ { "transform": "none" }, { "transform": "none" } ] } + ] } }, "parameters": { - "l": { "type": "real", "dimensions": 1 }, - "m": { "type": "real", "dimensions": 1 }, - "n": { "type": "real", "dimensions": 1 }, - "o": { "type": "real", "dimensions": 1 }, - "p": { "type": "real", "dimensions": 2 }, - "q": { "type": "real", "dimensions": 2 }, - "r": { "type": "real", "dimensions": 2 }, - "s": { "type": "real", "dimensions": 2 }, - "y": { "type": "real", "dimensions": 0 } + "low": { "type": "real", "dimensions": 0, "transform": { "lower": "0" } }, + "l": { "type": "real", "dimensions": 1, "transform": "simplex" }, + "m": { "type": "real", "dimensions": 1, "transform": "unit_vector" }, + "n": { "type": "real", "dimensions": 1, "transform": "ordered" }, + "o": { "type": "real", "dimensions": 1, "transform": "positive_ordered" }, + "p": { "type": "real", "dimensions": 2, "transform": "covariance" }, + "q": { "type": "real", "dimensions": 2, "transform": "correlation" }, + "r": { "type": "real", "dimensions": 2, "transform": "cholesky_cov" }, + "s": { "type": "real", "dimensions": 2, "transform": "cholesky_corr" }, + "y": { "type": "real", "dimensions": 0, "transform": "none" }, + "parameter_transforms": { + "type": [ + { "type": "real", "dimensions": 0 }, + { + "type": [ + { "type": "real", "dimensions": 1 }, + { "type": "real", "dimensions": 2 } + ], + "dimensions": 0 + } + ], + "dimensions": 0, + "transform": [ + { "transform": { "lower": "y" } }, + { + "transform": [ + { "transform": "simplex" }, { "transform": "covariance" } + ] + } + ] + } + }, + "transformed parameters": { + "t": { "type": "real", "dimensions": 2, "transform": "none" } + }, + "generated quantities": { + "u": { "type": "real", "dimensions": 0, "transform": "none" } }, - "transformed parameters": { "t": { "type": "real", "dimensions": 2 } }, - "generated quantities": { "u": { "type": "real", "dimensions": 0 } }, "functions": [ "fatal_error", "log", "print", "reduce_sum", "reject", "sin", "square" ], diff --git a/test/integration/cli-args/info/info.stan b/test/integration/cli-args/info/info.stan index f79c568a5..cf2271c87 100644 --- a/test/integration/cli-args/info/info.stan +++ b/test/integration/cli-args/info/info.stan @@ -37,6 +37,7 @@ transformed data { } parameters { + real low; simplex[10] l; unit_vector[11] m; ordered[12] n; @@ -46,6 +47,8 @@ parameters { cholesky_factor_cov[16] r; cholesky_factor_corr[17] s; real y; + + tuple(real, tuple(simplex[18], cov_matrix[19])) parameter_transforms; } transformed parameters {