Skip to content

Commit 8d74ceb

Browse files
authored
Merge pull request #1471 from stan-dev/fix/1470-underscore-jacobian-bwd-compat
Allow existing uses of _jacobian in function names, with warning
2 parents 96e409a + 02c8663 commit 8d74ceb

File tree

11 files changed

+266
-30
lines changed

11 files changed

+266
-30
lines changed

src/frontend/Deprecation_analysis.ml

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,76 @@ let lkj_cov_message =
4141
independent lognormal distribution on the scales, see: \
4242
https://mc-stan.org/docs/reference-manual/deprecations.html#lkj_cov-distribution"
4343

44+
let functions_block_contains_jac_pe (stmts : untyped_statement list) =
45+
(* tracking if 'jacobian' is a variable in scope *)
46+
let jacobian_scope_id = ref 0 in
47+
let is_jacobian_in_scope () = !jacobian_scope_id > 0 in
48+
let current_scope_id = ref 1 in
49+
let found_jacobian () =
50+
if not (is_jacobian_in_scope ()) then jacobian_scope_id := !current_scope_id
51+
in
52+
let push_scope () = current_scope_id := !current_scope_id + 1 in
53+
let pop_scope () =
54+
current_scope_id := !current_scope_id - 1;
55+
(* if the scope we just left was the one defining jacobian, reset it *)
56+
if !jacobian_scope_id > !current_scope_id then jacobian_scope_id := 0 in
57+
(* walk over the tree, looking for usages of jacobian+= where
58+
there is no variable called jacobian already in scope *)
59+
let rec f (s : untyped_statement) =
60+
match s.stmt with
61+
| FunDef {body; funname; _}
62+
when String.is_suffix funname.name ~suffix:"_jacobian" ->
63+
push_scope ();
64+
let res = f body in
65+
pop_scope ();
66+
res
67+
| Block stmts | Profile (_, stmts) ->
68+
push_scope ();
69+
let res = List.exists ~f stmts in
70+
pop_scope ();
71+
res
72+
| For {loop_body; _} | While (_, loop_body) | ForEach (_, _, loop_body) ->
73+
push_scope ();
74+
let res = f loop_body in
75+
pop_scope ();
76+
res
77+
| IfThenElse (_, s1, s2_opt) ->
78+
push_scope ();
79+
let res1 = f s1 in
80+
pop_scope ();
81+
push_scope ();
82+
let res2 = match s2_opt with Some s2 -> f s2 | None -> false in
83+
pop_scope ();
84+
res1 || res2
85+
| JacobianPE _ -> true
86+
| Assignment
87+
{ assign_lhs= LValue {lval= LVariable {name; _}; _}
88+
; assign_op= OperatorAssign Plus
89+
; _ }
90+
when String.equal name "jacobian" ->
91+
not (is_jacobian_in_scope ())
92+
| VarDecl {variables; _} ->
93+
if
94+
List.exists
95+
~f:(fun {identifier; _} -> String.equal identifier.name "jacobian")
96+
variables
97+
then found_jacobian ();
98+
false
99+
| _ -> false in
100+
let res = List.exists ~f stmts in
101+
(* sanity check that pushes and pops are balanced *)
102+
if !current_scope_id <> 1 then
103+
Common.ICE.internal_compiler_error
104+
[%message
105+
"functions_block_contains_jac_pe: scope tracking failed"
106+
(!current_scope_id : int)
107+
(!jacobian_scope_id : int)
108+
(stmts : untyped_statement list)];
109+
res
110+
111+
let set_jacobian_compatibility_mode stmts =
112+
Fun_kind.jacobian_compat_mode := not (functions_block_contains_jac_pe stmts)
113+
44114
let rec collect_deprecated_expr (acc : (Location_span.t * string) list)
45115
({expr; emeta} : (typed_expr_meta, fun_kind) expr_with) :
46116
(Location_span.t * string) list =
@@ -89,6 +159,22 @@ let rec collect_deprecated_stmt fundefs (acc : (Location_span.t * string) list)
89159
, "Functions do not need to be declared before definition; all user \
90160
defined function names are always in scope regardless of \
91161
definition order." ) ]
162+
| FunDef {funname; body; _}
163+
when !Fun_kind.jacobian_compat_mode
164+
&& String.is_suffix funname.name ~suffix:"_jacobian" ->
165+
let acc =
166+
( funname.id_loc
167+
, "Functions that end in _jacobian will change meaning in Stan 2.39. \
168+
They will be used for the encapsulating usages of 'jacobian +=', \
169+
and therefore not available to be called in all the same places as \
170+
this function is now. To avoid any issues, please rename this \
171+
function to not end in _jacobian." )
172+
:: acc in
173+
fold_statement collect_deprecated_expr
174+
(collect_deprecated_stmt fundefs)
175+
collect_deprecated_lval
176+
(fun l _ -> l)
177+
acc body.stmt
92178
| Tilde {distribution; _} when String.equal distribution.name "lkj_cov" ->
93179
let acc = (distribution.id_loc, lkj_cov_message) :: acc in
94180
fold_statement collect_deprecated_expr

src/frontend/Deprecation_analysis.mli

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,8 @@ val rename_deprecated : (string * (int * int)) String.Map.t -> string -> string
1111
val stan_lib_deprecations : (string * (int * int)) String.Map.t
1212
val collect_warnings : typed_program -> Warnings.t list
1313
val remove_unneeded_forward_decls : typed_program -> typed_program
14+
15+
val set_jacobian_compatibility_mode : untyped_statement list -> unit
16+
(** Pre-Stan 2.39, we need to know if _jacobian functions are
17+
FnPlain or not. We use the presence of any jacobian+= statements
18+
as our condition. If none are present, we assume this is old code. *)

src/frontend/Typechecker.ml

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,7 @@ let verify_fn_target_plus_equals cf loc id =
454454
let verify_fn_jacobian_plus_equals cf loc id =
455455
if
456456
String.is_suffix id.name ~suffix:"_jacobian"
457+
&& (not !Fun_kind.jacobian_compat_mode)
457458
&& not (in_jacobian_function cf || cf.current_block = TParam)
458459
then Semantic_error.jacobian_plusequals_not_allowed loc |> error
459460

@@ -913,13 +914,6 @@ let check_expression_of_scalar_or_type cf tenv t e name =
913914

914915
(* -- Statements ------------------------------------------------- *)
915916
(* non returning functions *)
916-
let verify_nrfn_target loc cf id =
917-
if
918-
String.is_suffix id.name ~suffix:"_lp"
919-
&& not
920-
(in_lp_function cf || cf.current_block = Model
921-
|| cf.current_block = TParam)
922-
then Semantic_error.target_plusequals_outside_model_or_logprob loc |> error
923917

924918
let check_nrfn loc tenv id es =
925919
match Env.find tenv id.name with
@@ -960,7 +954,9 @@ let check_nrfn loc tenv id es =
960954
let check_nr_fn_app loc cf tenv id es =
961955
let tes = List.map ~f:(check_expression cf tenv) es in
962956
verify_identifier id;
963-
verify_nrfn_target loc cf id;
957+
verify_fn_target_plus_equals cf loc id;
958+
verify_fn_jacobian_plus_equals cf loc id;
959+
verify_fn_rng cf loc id;
964960
check_nrfn loc tenv id tes
965961

966962
(* target plus-equals / jacobian plus-equals *)
@@ -1894,6 +1890,8 @@ let add_userdefined_functions tenv stmts_opt =
18941890
match stmts_opt with
18951891
| None -> tenv
18961892
| Some {stmts; _} ->
1893+
(* TODO(2.39): Remove this workaround *)
1894+
Deprecation_analysis.set_jacobian_compatibility_mode stmts;
18971895
let f tenv (s : Ast.untyped_statement) =
18981896
match s with
18991897
| {stmt= FunDef {returntype; funname; arguments; body}; smeta= {loc}} ->

src/middle/Fun_kind.ml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,16 @@ type 'e t =
2121
| UserDefined of string * bool suffix
2222
[@@deriving compare, sexp, hash, map, fold]
2323

24+
(** If true, we assume _jacobian functions are
25+
"plain" functions for the purposes of typechecking and warnings
26+
*)
27+
let jacobian_compat_mode = ref false
28+
2429
let suffix_from_name fname =
2530
let is_suffix suffix = Core.String.is_suffix ~suffix fname in
2631
if is_suffix "_rng" then FnRng
2732
else if is_suffix "_lp" then FnTarget
28-
else if is_suffix "_jacobian" then FnJacobian
33+
else if is_suffix "_jacobian" && not !jacobian_compat_mode then FnJacobian
2934
else if is_suffix "_lupdf" then FnLpdf true
3035
else if is_suffix "_lupmf" then FnLpmf true
3136
else if is_suffix "_lpdf" then FnLpdf false
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
functions {
2+
// void return type to check function statement, rather than expression
3+
void foo_jacobian() {
4+
jacobian += 1;
5+
}
6+
}
7+
transformed data {
8+
foo_jacobian();
9+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
functions {
2+
void foo_rng(real x){
3+
print(normal_rng(0,x));
4+
}
5+
}
6+
7+
model {
8+
foo_rng(1.0);
9+
}

test/integration/bad/stanc.expected

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,18 @@ Semantic error in 'err-jacobian-plusequals-scope3.stan', line 14, column 11 to c
10831083
15: }
10841084
-------------------------------------------------
10851085

1086+
The jacobian adjustment can only be applied in the transformed parameters block or in functions ending with _jacobian
1087+
[exit 1]
1088+
$ ../../../../install/default/bin/stanc err-jacobian-plusequals-scope4.stan
1089+
Semantic error in 'err-jacobian-plusequals-scope4.stan', line 8, column 2 to column 17:
1090+
-------------------------------------------------
1091+
6: }
1092+
7: transformed data {
1093+
8: foo_jacobian();
1094+
^
1095+
9: }
1096+
-------------------------------------------------
1097+
10861098
The jacobian adjustment can only be applied in the transformed parameters block or in functions ending with _jacobian
10871099
[exit 1]
10881100
$ ../../../../install/default/bin/stanc err-minus-types.stan
@@ -1245,6 +1257,18 @@ Syntax error in 'err-transformed-params.stan', line 4, column 0 to column 11, pa
12451257
-------------------------------------------------
12461258

12471259
"transformed parameters {", "model {" or "generated quantities {" expected after end of parameters block.
1260+
[exit 1]
1261+
$ ../../../../install/default/bin/stanc err_void_rng_check.stan
1262+
Semantic error in 'err_void_rng_check.stan', line 8, column 4 to column 17:
1263+
-------------------------------------------------
1264+
6:
1265+
7: model {
1266+
8: foo_rng(1.0);
1267+
^
1268+
9: }
1269+
-------------------------------------------------
1270+
1271+
Random number generators are only allowed in transformed data block, generated quantities block or user-defined functions with names ending in _rng.
12481272
[exit 1]
12491273
$ ../../../../install/default/bin/stanc expect_statement_seq_close_brace.stan
12501274
Syntax error in 'expect_statement_seq_close_brace.stan', line 6, column 0 to column 0, parsing error:

test/integration/cli-args/warn-pedantic/stanc.expected

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,10 @@ Warning in 'jacobian_warning_user.stan', line 5, column 2: Left-hand side of
546546
using jacobian += in the transformed parameters block.
547547
[exit 0]
548548
$ ../../../../../install/default/bin/stanc --warn-pedantic lp_fun.stan
549+
Warning in 'lp_fun.stan', line 10, column 2: Using _lp functions in
550+
transformed parameters is deprecated and will be disallowed in Stan 2.39.
551+
Use an _jacobian function instead, as this allows change of variable
552+
adjustments which are conditionally enabled by the algorithms.
549553
Warning: The parameter y has 2 priors.
550554
[exit 0]
551555
$ ../../../../../install/default/bin/stanc --warn-pedantic missing-prior-false-alarm.stan

0 commit comments

Comments
 (0)