Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/decl assign params #1441

Merged
merged 17 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,14 @@ let gen_inline_var (name : string) (id_var : string) =

let replace_fresh_local_vars (fname : string) stmt =
let f (m : (string, string) Core.Map.Poly.t) = function
| Stmt.Fixed.Pattern.Decl {decl_adtype; decl_type; decl_id; initialize} ->
| Stmt.Fixed.Pattern.Decl
{decl_adtype; decl_type; decl_id; initialize; assignment} ->
let new_name =
match Map.Poly.find m decl_id with
| Some existing -> existing
| None -> gen_inline_var fname decl_id in
( Stmt.Fixed.Pattern.Decl
{decl_adtype; decl_id= new_name; decl_type; initialize}
{decl_adtype; decl_id= new_name; decl_type; initialize; assignment}
, Map.Poly.set m ~key:decl_id ~data:new_name )
| Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} ->
let new_name =
Expand Down Expand Up @@ -201,7 +202,8 @@ let handle_early_returns (fname : string) opt_var stmt =
{ decl_adtype= DataOnly
; decl_id= returned
; decl_type= Sized SInt
; initialize= true }
; initialize= true
; assignment= None }
; meta= Location_span.empty }
; Stmt.Fixed.
{ pattern=
Expand Down Expand Up @@ -294,7 +296,8 @@ let rec inline_function_expression propto adt fim (Expr.Fixed.{pattern; _} as e)
(Type.to_unsized decl_type)
; decl_id= inline_return_name
; decl_type
; initialize= false } ]
; initialize= false
; assignment= None } ]
(* We should minimize the code that's having its variables
replaced to avoid conflict with the (two) new dummy
variables introduced by inlining *)
Expand Down Expand Up @@ -972,7 +975,8 @@ let lazy_code_motion ?(preserve_stability = false) (mir : Program.Typed.t) =
{ decl_adtype= Expr.Typed.adlevel_of key
; decl_id= data
; decl_type= Type.Unsized (Expr.Typed.type_of key)
; initialize= true }
; initialize= true
; assignment= None }
; meta= Location_span.empty }
:: accum) in
let lazy_code_motion_base i stmt =
Expand Down
13 changes: 9 additions & 4 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,9 @@ let create_decl_with_assign decl_id declc decl_type initial_value transform
() } in
let decl =
Stmt.
{ Fixed.pattern= Decl {decl_adtype; decl_id; decl_type; initialize= true}
{ Fixed.pattern=
Decl
{decl_adtype; decl_id; decl_type; initialize= true; assignment= None}
; meta= smeta } in
let rhs_assignment =
Option.map
Expand Down Expand Up @@ -583,7 +585,8 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
{ decl_adtype= Expr.Typed.adlevel_of iteratee'
; decl_id= loopvar.name
; decl_type= Unsized decl_type
; initialize= true } } in
; initialize= true
; assignment= None } } in
let assignment var =
Stmt.Fixed.
{ pattern=
Expand Down Expand Up @@ -629,7 +632,8 @@ and trans_packed_assign loc trans_stmt lvals rhs assign_op =
{ decl_adtype= rhs.emeta.ad_level
; decl_id= sym
; decl_type= Unsized rhs_type
; initialize= false }
; initialize= false
; assignment= None }
; meta= rhs.emeta.loc } in
let assign =
{ temp with
Expand Down Expand Up @@ -743,7 +747,8 @@ let rec trans_sizedtype_decl declc tr name st =
{ decl_type= Sized SInt
; decl_id
; decl_adtype= DataOnly
; initialize= true }
; initialize= true
; assignment= None }
; meta= e.meta.loc } in
let assign =
{ Stmt.Fixed.pattern=
Expand Down
18 changes: 13 additions & 5 deletions src/middle/Stmt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ module Fixed = struct
{ decl_adtype: UnsizedType.autodifftype
; decl_id: string
; decl_type: 'a Type.t
; initialize: bool }
; initialize: bool
; assignment: 'a option }
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
[@@deriving sexp, hash, map, fold, compare]

and 'e lvalue = 'e lbase * 'e Index.t list
Expand Down Expand Up @@ -70,9 +71,15 @@ module Fixed = struct
| Block stmts ->
Fmt.pf ppf "{@;<1 2>@[<v>%a@]@;}" Fmt.(list pp_s ~sep:cut) stmts
| SList stmts -> Fmt.(list pp_s ~sep:cut |> vbox) ppf stmts
| Decl {decl_adtype; decl_id; decl_type; _} ->
Fmt.pf ppf "@[<hov 2>%a%a@ %s;@]" UnsizedType.pp_autodifftype
decl_adtype (Type.pp pp_e) decl_type decl_id
(*TODO(Steve): Need a new one for decl with assign*)
| Decl {decl_adtype; decl_id; decl_type; assignment; _} -> (
match assignment with
| Some e ->
Fmt.pf ppf "@[<hov 2>%a%a@ %s = %a;@]" UnsizedType.pp_autodifftype
decl_adtype (Type.pp pp_e) decl_type decl_id pp_e e
| None ->
Fmt.pf ppf "@[<hov 2>%a%a@ %s;@]" UnsizedType.pp_autodifftype
decl_adtype (Type.pp pp_e) decl_type decl_id)

include Foldable.Make2 (struct
type nonrec ('a, 'b) t = ('a, 'b) t
Expand Down Expand Up @@ -143,7 +150,8 @@ module Helpers = struct
{ decl_adtype= Expr.Typed.adlevel_of e
; decl_id= sym
; decl_type= Unsized (Expr.Typed.type_of e)
; initialize= true }
; initialize= true
; assignment= None }
; meta= e.meta.loc } in
let assign =
{ decl with
Expand Down
3 changes: 2 additions & 1 deletion src/middle/Stmt.mli
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ module Fixed : sig
{ decl_adtype: UnsizedType.autodifftype
; decl_id: string
; decl_type: 'a Type.t
; initialize: bool }
; initialize: bool
; assignment: 'a option }
[@@deriving sexp, hash, compare]

and 'e lvalue = 'e lbase * 'e Index.t list
Expand Down
37 changes: 23 additions & 14 deletions src/stan_math_backend/Lower_stmt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@ let rec initialize_value st adtype =
(adtype : UnsizedType.autodifftype)]

(*Initialize an object of a given size.*)
let lower_assign_sized st adtype initialize =
if initialize then Some (initialize_value st adtype) else None
let lower_assign_sized st adtype initialize assignment =
match assignment with
| Some e -> Some (lower_expr e)
| None -> if initialize then Some (initialize_value st adtype) else None

let lower_unsized_decl name ut adtype =
let type_ =
Expand All @@ -103,25 +105,32 @@ let lower_unsized_decl name ut adtype =
| true, _ -> TypeLiteral "matrix_cl<double>" in
make_variable_defn ~type_ ~name ()

let lower_possibly_opencl_decl name st adtype =
let lower_possibly_opencl_decl name st adtype assignment =
let ut = SizedType.to_unsized st in
let mem_pattern = SizedType.get_mem_pattern st in
match (Transform_Mir.is_opencl_var name, ut) with
| _, UnsizedType.(UInt | UReal) | false, _ ->
lower_possibly_var_decl adtype ut mem_pattern
| _, UnsizedType.(UInt | UReal) | false, _ -> (
match assignment with
| Some
Expr.Fixed.
{ pattern= FunApp (CompilerInternal (Internal_fun.FnReadParam _), _)
; _ } ->
Auto
| _ -> lower_possibly_var_decl adtype ut mem_pattern)
| true, UArray UInt -> TypeLiteral "matrix_cl<int>"
| true, _ -> TypeLiteral "matrix_cl<double>"

let lower_sized_decl name st adtype initialize =
let type_ = lower_possibly_opencl_decl name st adtype in
let lower_sized_decl name st adtype initialize assignment =
let type_ = lower_possibly_opencl_decl name st adtype assignment in
let init =
lower_assign_sized st adtype initialize
lower_assign_sized st adtype initialize assignment
|> Option.value_map ~default:Uninitialized ~f:(fun i -> Assignment i) in
make_variable_defn ~type_ ~name ~init ()

let lower_decl vident pst adtype initialize =
let lower_decl vident pst adtype initialize assignment =
match pst with
| Type.Sized st -> VariableDefn (lower_sized_decl vident st adtype initialize)
| Type.Sized st ->
VariableDefn (lower_sized_decl vident st adtype initialize assignment)
| Unsized ut -> VariableDefn (lower_unsized_decl vident ut adtype)

let lower_profile name body =
Expand Down Expand Up @@ -320,8 +329,8 @@ let rec lower_statement Stmt.Fixed.{pattern; meta} : stmt list =
| Return e -> [Return (Option.map ~f:lower_expr e)]
| Block ls -> [Stmts.block (lower_statements ls)]
| SList ls -> lower_statements ls
| Decl {decl_adtype; decl_id; decl_type; initialize; _} ->
[lower_decl decl_id decl_type decl_adtype initialize]
| Decl {decl_adtype; decl_id; decl_type; initialize; assignment} ->
[lower_decl decl_id decl_type decl_adtype initialize assignment]
| Profile (name, ls) -> [lower_profile name (lower_statements ls)]

and lower_statements = List.concat_map ~f:lower_statement
Expand All @@ -333,7 +342,7 @@ module Testing = struct
(Fmt.option Cpp.Printing.pp_expr)
(lower_assign_sized
(SArray (SArray (SMatrix (AoS, int 2, int 3), int 4), int 5))
DataOnly false)
DataOnly false None)
|> print_endline;
[%expect {| |}]

Expand All @@ -343,7 +352,7 @@ module Testing = struct
(Fmt.option Cpp.Printing.pp_expr)
(lower_assign_sized
(SArray (SArray (SMatrix (AoS, int 2, int 3), int 4), int 5))
DataOnly true)
DataOnly true None)
|> print_endline;
[%expect
{|
Expand Down
56 changes: 44 additions & 12 deletions src/stan_math_backend/Transform_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,8 @@ let rec var_context_read_inside_tuple enclosing_tuple_name origin_type
(SizedType.to_unsized t)
; decl_id= make_tuple_temp name
; decl_type= Sized t
; initialize= true }
; initialize= true
; assignment= None }
|> swrap)
tuple_component_names tuple_types in
let loop =
Expand Down Expand Up @@ -369,7 +370,8 @@ let rec var_context_read_inside_tuple enclosing_tuple_name origin_type
{ decl_adtype= AutoDiffable
; decl_id= decl_id_flat
; decl_type= Unsized flat_type
; initialize= true }
; initialize= true
; assignment= None }
|> swrap
, Assignment (Stmt.Helpers.lvariable decl_id_flat, flat_type, origin)
|> swrap
Expand Down Expand Up @@ -473,7 +475,8 @@ let rec var_context_read
{ decl_adtype= AutoDiffable
; decl_id= variable_name
; decl_type= Unsized array_type
; initialize= true }
; initialize= true
; assignment= None }
|> swrap_noloc
; Assignment
( Stmt.Helpers.lvariable variable_name
Expand All @@ -484,7 +487,8 @@ let rec var_context_read
{ decl_adtype= DataOnly
; decl_id= variable_name ^ "pos__"
; decl_type= Unsized UInt
; initialize= true }
; initialize= true
; assignment= None }
|> swrap_noloc
; Stmt.Fixed.Pattern.Assignment
( Stmt.Helpers.lvariable (variable_name ^ "pos__")
Expand Down Expand Up @@ -512,7 +516,8 @@ let rec var_context_read
(SizedType.to_unsized t)
; decl_id= make_tuple_temp name
; decl_type= Sized t
; initialize= true }
; initialize= true
; assignment= None }
|> swrap_noloc)
tuple_component_names tuple_types in
let loop =
Expand Down Expand Up @@ -559,7 +564,8 @@ let rec var_context_read
{ decl_adtype= AutoDiffable
; decl_id= decl_id_flat
; decl_type= Unsized flat_type
; initialize= false }
; initialize= false
; assignment= None }
|> swrap
, Assignment
( Stmt.Helpers.lvariable decl_id_flat
Expand Down Expand Up @@ -764,9 +770,31 @@ let add_reads vars mkread stmts =
let var_names = String.Map.of_alist_exn vars in
let add_read_to_decl (Stmt.Fixed.{pattern; _} as stmt) =
match pattern with
| Decl {decl_id; _} when Map.mem var_names decl_id ->
| Decl {decl_id; decl_adtype; decl_type; initialize; _}
when Map.mem var_names decl_id -> (
let loc, out = Map.find_exn var_names decl_id in
stmt :: mkread (Stmt.Helpers.lvariable decl_id, loc, out)
let param_reader = mkread (Stmt.Helpers.lvariable decl_id, loc, out) in
match param_reader with
| [ Stmt.Fixed.
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
{ pattern=
Stmt.Fixed.Pattern.Assignment
( _
, _
, (Expr.Fixed.
{ pattern=
Expr.Fixed.Pattern.FunApp
(CompilerInternal (Internal_fun.FnReadParam _), _)
; _ } as e) )
; _ } ] ->
[ { stmt with
pattern=
Stmt.Fixed.Pattern.Decl
{ decl_id
; decl_adtype
; decl_type
; initialize
; assignment= Some e } } ]
| _ -> stmt :: param_reader)
| _ -> [stmt] in
List.concat_map ~f:add_read_to_decl stmts

Expand Down Expand Up @@ -872,7 +900,8 @@ let var_context_unconstrain_transform (decl_id, smeta, outvar) =
(SizedType.to_unsized st)
; decl_id
; decl_type= Type.Sized st
; initialize= true }
; initialize= true
; assignment= None }
; meta= smeta }
:: var_context_read (Stmt.Helpers.lvariable decl_id, smeta, st)
@ param_serializer_write ~unconstrain:true (decl_id, outvar)
Expand All @@ -888,7 +917,8 @@ let array_unconstrain_transform (decl_id, smeta, outvar) =
(SizedType.to_unsized outvar.Program.out_constrained_st)
; decl_id
; decl_type= Type.Sized outvar.Program.out_constrained_st
; initialize= true }
; initialize= true
; assignment= None }
; meta= smeta } in
let rec read (lval, st) =
match st with
Expand Down Expand Up @@ -1028,7 +1058,8 @@ let trans_prog (p : Program.Typed.t) =
{ decl_adtype= DataOnly
; decl_id= pos
; decl_type= Sized SInt
; initialize= true }
; initialize= true
; assignment= None }
; Assignment (Stmt.Helpers.lvariable pos, UInt, Expr.Helpers.loop_bottom) ]
|> List.map ~f:(fun pattern ->
Stmt.Fixed.{pattern; meta= Location_span.empty}) in
Expand Down Expand Up @@ -1145,7 +1176,8 @@ let trans_prog (p : Program.Typed.t) =
{ decl_adtype= DataOnly
; decl_id= vident
; decl_type= Type.Unsized type_of_input_var
; initialize= true }
; initialize= true
; assignment= None }
; meta= Location_span.empty }
; { pattern=
Assignment
Expand Down
Loading