Skip to content

Commit

Permalink
Populate various free_names in Closure_conversion (rebased) (ocaml#181)
Browse files Browse the repository at this point in the history
  • Loading branch information
Keryan-dev committed Sep 14, 2021
1 parent d1e3d45 commit 408da9c
Show file tree
Hide file tree
Showing 16 changed files with 383 additions and 272 deletions.
220 changes: 95 additions & 125 deletions middle_end/flambda2/from_lambda/closure_conversion.ml

Large diffs are not rendered by default.

183 changes: 128 additions & 55 deletions middle_end/flambda2/from_lambda/closure_conversion_aux.ml
Expand Up @@ -216,8 +216,7 @@ module Acc = struct
{ declared_symbols : (Symbol.t * Flambda.Static_const.t) list;
shareable_constants : Symbol.t Flambda.Static_const.Map.t;
code : Flambda.Code.t Code_id.Map.t;
free_names_of_current_function : Name_occurrences.t;
free_continuations : Name_occurrences.t;
free_names : Name_occurrences.t;
cost_metrics : Flambda.Cost_metrics.t;
seen_a_function : bool
}
Expand All @@ -237,8 +236,7 @@ module Acc = struct
{ declared_symbols = [];
shareable_constants = Flambda.Static_const.Map.empty;
code = Code_id.Map.empty;
free_names_of_current_function = Name_occurrences.empty;
free_continuations = Name_occurrences.empty;
free_names = Name_occurrences.empty;
cost_metrics = Flambda.Cost_metrics.zero;
seen_a_function = false
}
Expand All @@ -249,9 +247,7 @@ module Acc = struct

let code t = t.code

let free_names_of_current_function t = t.free_names_of_current_function

let free_continuations t = t.free_continuations
let free_names t = t.free_names

let add_declared_symbol ~symbol ~constant t =
let declared_symbols = (symbol, constant) :: t.declared_symbols in
Expand All @@ -266,35 +262,51 @@ module Acc = struct
let add_code ~code_id ~code t =
{ t with code = Code_id.Map.add code_id code t.code }

let add_symbol_to_free_names ~symbol t =
let add_free_names free_names t =
{ t with free_names = Name_occurrences.union free_names t.free_names }

let add_name_to_free_names ~name t =
{ t with
free_names_of_current_function =
Name_occurrences.add_symbol t.free_names_of_current_function symbol
Name_mode.normal
free_names = Name_occurrences.add_name t.free_names name Name_mode.normal
}

let add_closure_var_to_free_names ~closure_var t =
let add_simple_to_free_names acc simple =
Simple.pattern_match simple
~const:(fun _ -> acc)
~name:(fun name ~coercion:_ -> add_name_to_free_names ~name acc)

let remove_code_id_or_symbol_from_free_names cis t =
{ t with
free_names_of_current_function =
Name_occurrences.add_closure_var t.free_names_of_current_function
closure_var Name_mode.normal
free_names = Name_occurrences.remove_code_id_or_symbol t.free_names cis
}

let add_continuation_occurrence ~cont ~has_traps t =
let remove_symbol_from_free_names symbol t =
remove_code_id_or_symbol_from_free_names (Symbol symbol) t

let remove_var_from_free_names var t =
{ t with free_names = Name_occurrences.remove_var t.free_names var }

let remove_continuation_from_free_names cont t =
{ t with
free_continuations =
Name_occurrences.add_continuation t.free_continuations cont ~has_traps
free_names = Name_occurrences.remove_continuation t.free_names cont
}

let with_free_names free_names t =
{ t with free_names_of_current_function = free_names }
let remove_code_id_from_free_names code_id t =
remove_code_id_or_symbol_from_free_names (Code_id code_id) t

let with_free_names free_names t = { t with free_names }

let eval_branch_free_names t ~f =
let base_free_names = t.free_names in
let t, res = f { t with free_names = Name_occurrences.empty } in
t.free_names, { t with free_names = base_free_names }, res

let measure_cost_metrics acc ~f =
let saved_cost_metrics = cost_metrics acc in
let acc = with_cost_metrics Flambda.Cost_metrics.zero acc in
let acc, return = f acc in
let free_names, acc, return = eval_branch_free_names acc ~f in
let cost_metrics = cost_metrics acc in
cost_metrics, with_cost_metrics saved_cost_metrics acc, return
cost_metrics, free_names, with_cost_metrics saved_cost_metrics acc, return
end

module Function_decls = struct
Expand Down Expand Up @@ -438,17 +450,7 @@ module Expr_with_acc = struct
(Code_size.apply apply |> Cost_metrics.from_size)
acc
in
let acc =
match Apply.continuation apply with
| Never_returns -> acc
| Return cont ->
Acc.add_continuation_occurrence ~cont ~has_traps:false acc
in
let acc =
Acc.add_continuation_occurrence
~cont:(Exn_continuation.exn_handler (Apply.exn_continuation apply))
~has_traps:false acc
in
let acc = Acc.add_free_names (Apply_expr.free_names apply) acc in
acc, Expr.create_apply apply

let create_let (acc, let_expr) =
Expand All @@ -464,6 +466,7 @@ module Expr_with_acc = struct
(Code_size.switch switch |> Cost_metrics.from_size)
acc
in
let acc = Acc.add_simple_to_free_names acc (Switch_expr.scrutinee switch) in
acc, Expr.create_switch switch

let create_invalid acc ?semantics () =
Expand All @@ -475,18 +478,15 @@ end

module Apply_cont_with_acc = struct
let create acc ?trap_action cont ~args ~dbg =
let acc =
Acc.add_continuation_occurrence ~cont
~has_traps:(match trap_action with None -> false | _ -> true)
acc
in
acc, Apply_cont.create ?trap_action cont ~args ~dbg
let apply_cont = Apply_cont.create ?trap_action cont ~args ~dbg in
let acc = Acc.add_free_names (Apply_cont.free_names apply_cont) acc in
acc, apply_cont

let goto acc cont = create acc cont ~args:[] ~dbg:Debuginfo.none
end

module Let_with_acc = struct
let create acc let_bound named ~body ~free_names_of_body =
let create acc let_bound named ~body =
let cost_metrics_of_defining_expr =
match named with
| Named.Prim (prim, _) -> Code_size.prim prim |> Cost_metrics.from_size
Expand All @@ -509,29 +509,52 @@ module Let_with_acc = struct
~cost_metrics_of_defining_expr)
acc
in
acc, Let.create let_bound named ~body ~free_names_of_body
let free_names_of_body = Or_unknown.Known (Acc.free_names acc) in
let acc =
Bindable_let_bound.fold_all_bound_names let_bound ~init:acc
~var:(fun acc var ->
Acc.remove_var_from_free_names (Var_in_binding_pos.var var) acc)
~symbol:(fun acc s -> Acc.remove_symbol_from_free_names s acc)
~code_id:(fun acc cid -> Acc.remove_code_id_from_free_names cid acc)
in
let expr = Let.create let_bound named ~body ~free_names_of_body in
let acc = Acc.add_free_names (Named.free_names named) acc in
acc, expr
end

module Continuation_handler_with_acc = struct
let create acc parameters ~handler ~free_names_of_handler ~is_exn_handler =
let create acc parameters ~handler ~is_exn_handler =
let free_names_of_handler = Or_unknown.Known (Acc.free_names acc) in
let acc =
List.fold_left
(fun acc param ->
Acc.remove_var_from_free_names (Kinded_parameter.var param) acc)
acc parameters
in
( acc,
Continuation_handler.create parameters ~handler ~free_names_of_handler
~is_exn_handler )
end

module Let_cont_with_acc = struct
let create_non_recursive acc cont handler ~body ~cost_metrics_of_handler =
let free_conts = Acc.free_continuations acc in
let acc =
Acc.increment_metrics
(Cost_metrics.increase_due_to_let_cont_non_recursive
~cost_metrics_of_handler)
acc
in
( acc,
(* This function only uses continuations of [free_names_of_body] *)
Let_cont.create_non_recursive cont handler ~body
~free_names_of_body:(Known free_conts) )
let create_non_recursive acc cont handler ~body ~free_names_of_body
~cost_metrics_of_handler =
match Name_occurrences.count_continuation free_names_of_body cont with
| Zero when not (Continuation_handler.is_exn_handler handler) -> acc, body
| _ ->
let acc =
Acc.increment_metrics
(Cost_metrics.increase_due_to_let_cont_non_recursive
~cost_metrics_of_handler)
acc
in
let expr =
(* This function only uses continuations of [free_names_of_body] *)
Let_cont.create_non_recursive cont handler ~body
~free_names_of_body:(Known free_names_of_body)
in
let acc = Acc.remove_continuation_from_free_names cont acc in
acc, expr

let create_recursive acc handlers ~body ~cost_metrics_of_handlers =
let acc =
Expand All @@ -540,5 +563,55 @@ module Let_cont_with_acc = struct
~cost_metrics_of_handlers)
acc
in
acc, Let_cont.create_recursive handlers ~body
let expr = Let_cont.create_recursive handlers ~body in
let acc =
Continuation.Map.fold
(fun cont _ acc -> Acc.remove_continuation_from_free_names cont acc)
handlers acc
in
acc, expr

let build_recursive acc ~handlers ~body =
let handlers_free_names, cost_metrics_of_handlers, acc, handlers =
Continuation.Map.fold
(fun cont (handler, params, is_exn_handler)
(free_names, costs, acc, handlers) ->
let cost_metrics_of_handler, handler_free_names, acc, handler =
Acc.measure_cost_metrics acc ~f:(fun acc ->
let acc, handler = handler acc in
Continuation_handler_with_acc.create acc params ~handler
~is_exn_handler)
in
( Name_occurrences.union free_names handler_free_names,
Cost_metrics.( + ) costs cost_metrics_of_handler,
acc,
Continuation.Map.add cont handler handlers ))
handlers
(Name_occurrences.empty, Cost_metrics.zero, acc, Continuation.Map.empty)
in
let body_free_names, acc, body = Acc.eval_branch_free_names acc ~f:body in
let acc =
Acc.with_free_names
(Name_occurrences.union body_free_names handlers_free_names)
acc
in
create_recursive acc handlers ~body ~cost_metrics_of_handlers

let build_non_recursive acc cont ~handler_params ~handler ~body
~is_exn_handler =
let cost_metrics_of_handler, handler_free_names, acc, handler =
Acc.measure_cost_metrics acc ~f:(fun acc ->
let acc, handler = handler acc in
Continuation_handler_with_acc.create acc handler_params ~handler
~is_exn_handler)
in
let free_names_of_body, acc, body =
Acc.eval_branch_free_names acc ~f:body
in
let acc, expr =
create_non_recursive
(Acc.with_free_names free_names_of_body acc)
cont handler ~body ~free_names_of_body ~cost_metrics_of_handler
in
Acc.add_free_names handler_free_names acc, expr
end
53 changes: 24 additions & 29 deletions middle_end/flambda2/from_lambda/closure_conversion_aux.mli
Expand Up @@ -135,9 +135,7 @@ module Acc : sig

val code : t -> Flambda.Code.t Code_id.Map.t

val free_names_of_current_function : t -> Name_occurrences.t

val free_continuations : t -> Name_occurrences.t
val free_names : t -> Name_occurrences.t

val seen_a_function : t -> bool

Expand All @@ -151,25 +149,31 @@ module Acc : sig

val add_code : code_id:Code_id.t -> code:Flambda.Code.t -> t -> t

val add_symbol_to_free_names : symbol:Symbol.t -> t -> t
val add_free_names : Name_occurrences.t -> t -> t

val add_closure_var_to_free_names : closure_var:Var_within_closure.t -> t -> t
val remove_var_from_free_names : Variable.t -> t -> t

val add_continuation_occurrence :
cont:Continuation.t -> has_traps:bool -> t -> t
val remove_continuation_from_free_names : Continuation.t -> t -> t

val with_free_names : Name_occurrences.t -> t -> t

(* This is intended to evaluate a distinct free_names from the one in acc, one
must be careful to update acc afterward when necessary *)
val eval_branch_free_names :
t -> f:(t -> t * 'a) -> Name_occurrences.t * t * 'a

val cost_metrics : t -> Flambda.Cost_metrics.t

val increment_metrics : Flambda.Cost_metrics.t -> t -> t

val with_cost_metrics : Flambda.Cost_metrics.t -> t -> t

(* Executes [f] in an acc with an empty cost metrics and returns the cost
metrics for the term generated by f separately from the one in the acc *)
metrics for the term generated by f separately from the one in the acc. As
for [eval_branch_free_names], the returned free_names differ from the one
in acc *)
val measure_cost_metrics :
t -> f:(t -> t * 'a) -> Flambda.Cost_metrics.t * t * 'a
t -> f:(t -> t * 'a) -> Flambda.Cost_metrics.t * Name_occurrences.t * t * 'a
end

(** Used to represent information about a set of function declarations during
Expand Down Expand Up @@ -273,33 +277,24 @@ module Let_with_acc : sig
Bindable_let_bound.t ->
Named.t ->
body:Expr_with_acc.t ->
free_names_of_body:Name_occurrences.t Or_unknown.t ->
Acc.t * Let.t
end

module Continuation_handler_with_acc : sig
val create :
Acc.t ->
Kinded_parameter.t list ->
handler:Expr_with_acc.t ->
free_names_of_handler:Name_occurrences.t Or_unknown.t ->
is_exn_handler:bool ->
Acc.t * Continuation_handler.t
end

module Let_cont_with_acc : sig
val create_non_recursive :
val build_recursive :
Acc.t ->
Continuation.t ->
Continuation_handler.t ->
body:Expr_with_acc.t ->
cost_metrics_of_handler:Cost_metrics.t ->
handlers:
((Acc.t -> Acc.t * Expr_with_acc.t) * Kinded_parameter.t list * bool)
Continuation.Map.t ->
body:(Acc.t -> Acc.t * Expr_with_acc.t) ->
Acc.t * Expr_with_acc.t

val create_recursive :
val build_non_recursive :
Acc.t ->
Continuation_handler.t Continuation.Map.t ->
body:Expr_with_acc.t ->
cost_metrics_of_handlers:Cost_metrics.t ->
Continuation.t ->
handler_params:Kinded_parameter.t list ->
handler:(Acc.t -> Acc.t * Expr_with_acc.t) ->
body:(Acc.t -> Acc.t * Expr_with_acc.t) ->
is_exn_handler:bool ->
Acc.t * Expr_with_acc.t
end

0 comments on commit 408da9c

Please sign in to comment.