Skip to content

Commit

Permalink
Add transformation information to --info
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Oct 30, 2024
1 parent 5692dc4 commit 8fb1ed4
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 33 deletions.
45 changes: 39 additions & 6 deletions src/frontend/Info.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ open Middle
open Yojson.Basic

let rec unsized_basetype_json t =
let to_json (type_, dim) : t =
`Assoc [("type", `String type_); ("dimensions", `Int dim)] in
let to_json (type_, dim) =
[("type", `String type_); ("dimensions", `Int dim)] in
let internal, dims = UnsizedType.unwind_array_type t in
match internal with
| UnsizedType.UInt -> to_json ("int", dims)
Expand All @@ -16,23 +16,56 @@ let rec unsized_basetype_json t =
| UMatrix -> to_json ("real", dims + 2)
| UComplexMatrix -> to_json ("complex", dims + 2)
| UTuple internals ->
`Assoc
[ ("type", `List (List.map ~f:unsized_basetype_json internals))
; ("dimensions", `Int dims) ]
[ ( "type"
, `List
(List.map ~f:(fun f -> `Assoc (unsized_basetype_json f)) internals)
); ("dimensions", `Int dims) ]
| UMathLibraryFunction | UFun _ | UArray _ -> assert false

let basetype_dims t = SizedType.to_unsized t |> unsized_basetype_json

let rec transformation t =
let expr_string = Fmt.to_to_string Pretty_printing.pp_expression in
let expr_string e =
`String (expr_string (Ast.untyped_expression_of_typed_expression e)) in
let transform details = [("transform", details)] in
match t with
| Transformation.Identity -> transform (`String "none")
| Lower e -> transform @@ `Assoc [("lower", expr_string e)]
| Upper e -> transform @@ `Assoc [("upper", expr_string e)]
| LowerUpper (e1, e2) ->
transform @@ `Assoc [("lower", expr_string e1); ("upper", expr_string e2)]
| Offset e -> transform @@ `Assoc [("offset", expr_string e)]
| Multiplier e -> transform @@ `Assoc [("multiplier", expr_string e)]
| OffsetMultiplier (e1, e2) ->
transform
@@ `Assoc [("offset", expr_string e1); ("multiplier", expr_string e2)]
| Ordered -> transform (`String "ordered")
| PositiveOrdered -> transform (`String "positive_ordered")
| Simplex -> transform (`String "simplex")
| UnitVector -> transform (`String "unit_vector")
| SumToZero -> transform (`String "sum_to_zero")
| CholeskyCorr -> transform (`String "cholesky_corr")
| CholeskyCov -> transform (`String "cholesky_cov")
| Correlation -> transform (`String "correlation")
| Covariance -> transform (`String "covariance")
| StochasticRow -> transform (`String "stochastic_row")
| StochasticColumn -> transform (`String "stochastic_column")
| TupleTransformation ts ->
transform (`List (List.map ~f:(fun t -> `Assoc (transformation t)) ts))

let get_var_decl {stmts; _} : t =
`Assoc
(List.fold_right ~init:[]
~f:(fun stmt acc ->
match stmt.Ast.stmt with
| Ast.VarDecl decl ->
let type_info = basetype_dims decl.decl_type in
let transform_info = transformation decl.transformation in
let decl_info =
List.map
~f:(fun {identifier; _} -> (identifier.name, type_info))
~f:(fun {identifier; _} ->
(identifier.name, `Assoc (type_info @ transform_info)))
decl.variables in
decl_info @ acc
| _ -> acc)
Expand Down
85 changes: 58 additions & 27 deletions test/integration/cli-args/info/info.expected
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
$ ../../../../../install/default/bin/stanc --include-paths=.,includes --info info.stan
{
"inputs": {
"a": { "type": "int", "dimensions": 0 },
"b": { "type": "real", "dimensions": 0 },
"c": { "type": "real", "dimensions": 1 },
"d1": { "type": "real", "dimensions": 1 },
"d2": { "type": "real", "dimensions": 1 },
"e": { "type": "real", "dimensions": 2 },
"f": { "type": "int", "dimensions": 1 },
"g": { "type": "real", "dimensions": 1 },
"h": { "type": "real", "dimensions": 2 },
"i": { "type": "real", "dimensions": 3 },
"j": { "type": "int", "dimensions": 3 },
"cplx": { "type": "complex", "dimensions": 1 },
"cplx_vec": { "type": "complex", "dimensions": 1 },
"cplx_row": { "type": "complex", "dimensions": 1 },
"cplx_mat": { "type": "complex", "dimensions": 2 },
"a": { "type": "int", "dimensions": 0, "transform": "none" },
"b": { "type": "real", "dimensions": 0, "transform": "none" },
"c": { "type": "real", "dimensions": 1, "transform": "none" },
"d1": { "type": "real", "dimensions": 1, "transform": "none" },
"d2": { "type": "real", "dimensions": 1, "transform": "none" },
"e": { "type": "real", "dimensions": 2, "transform": "none" },
"f": { "type": "int", "dimensions": 1, "transform": "none" },
"g": { "type": "real", "dimensions": 1, "transform": "none" },
"h": { "type": "real", "dimensions": 2, "transform": "none" },
"i": { "type": "real", "dimensions": 3, "transform": "none" },
"j": { "type": "int", "dimensions": 3, "transform": "none" },
"cplx": { "type": "complex", "dimensions": 1, "transform": "none" },
"cplx_vec": { "type": "complex", "dimensions": 1, "transform": "none" },
"cplx_row": { "type": "complex", "dimensions": 1, "transform": "none" },
"cplx_mat": { "type": "complex", "dimensions": 2, "transform": "none" },
"tuples": {
"type": [
{ "type": "int", "dimensions": 0 },
Expand All @@ -28,22 +28,53 @@
"dimensions": 0
}
],
"dimensions": 1
"dimensions": 1,
"transform": [
{ "transform": "none" },
{ "transform": "none" },
{ "transform": [ { "transform": "none" }, { "transform": "none" } ] }
]
}
},
"parameters": {
"l": { "type": "real", "dimensions": 1 },
"m": { "type": "real", "dimensions": 1 },
"n": { "type": "real", "dimensions": 1 },
"o": { "type": "real", "dimensions": 1 },
"p": { "type": "real", "dimensions": 2 },
"q": { "type": "real", "dimensions": 2 },
"r": { "type": "real", "dimensions": 2 },
"s": { "type": "real", "dimensions": 2 },
"y": { "type": "real", "dimensions": 0 }
"low": { "type": "real", "dimensions": 0, "transform": { "lower": "0" } },
"l": { "type": "real", "dimensions": 1, "transform": "simplex" },
"m": { "type": "real", "dimensions": 1, "transform": "unit_vector" },
"n": { "type": "real", "dimensions": 1, "transform": "ordered" },
"o": { "type": "real", "dimensions": 1, "transform": "positive_ordered" },
"p": { "type": "real", "dimensions": 2, "transform": "covariance" },
"q": { "type": "real", "dimensions": 2, "transform": "correlation" },
"r": { "type": "real", "dimensions": 2, "transform": "cholesky_cov" },
"s": { "type": "real", "dimensions": 2, "transform": "cholesky_corr" },
"y": { "type": "real", "dimensions": 0, "transform": "none" },
"parameter_transforms": {
"type": [
{ "type": "real", "dimensions": 0 },
{
"type": [
{ "type": "real", "dimensions": 1 },
{ "type": "real", "dimensions": 2 }
],
"dimensions": 0
}
],
"dimensions": 0,
"transform": [
{ "transform": { "lower": "y" } },
{
"transform": [
{ "transform": "simplex" }, { "transform": "covariance" }
]
}
]
}
},
"transformed parameters": {
"t": { "type": "real", "dimensions": 2, "transform": "none" }
},
"generated quantities": {
"u": { "type": "real", "dimensions": 0, "transform": "none" }
},
"transformed parameters": { "t": { "type": "real", "dimensions": 2 } },
"generated quantities": { "u": { "type": "real", "dimensions": 0 } },
"functions": [
"fatal_error", "log", "print", "reduce_sum", "reject", "sin", "square"
],
Expand Down
3 changes: 3 additions & 0 deletions test/integration/cli-args/info/info.stan
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ transformed data {
}

parameters {
real <lower=0> low;
simplex[10] l;
unit_vector[11] m;
ordered[12] n;
Expand All @@ -46,6 +47,8 @@ parameters {
cholesky_factor_cov[16] r;
cholesky_factor_corr[17] s;
real y;

tuple(real<lower=y>, tuple(simplex[18], cov_matrix[19])) parameter_transforms;
}

transformed parameters {
Expand Down

0 comments on commit 8fb1ed4

Please sign in to comment.