Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions compiler/frontend/ast_option_optimizations.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
open Parsetree
open Longident

(*
Optimise calls to Option.forEach/map/flatMap so they produce the same switch
structure as handwritten code. We only rewrite calls whose callback is a
simple literal lambda or identifier; more complex callbacks are left intact
to preserve ReScript's call-by-value semantics.
*)

let value_name = "__res_option_value"

type option_call = ForEach | Map | FlatMap

(* Inlineable callbacks are bare identifiers (possibly wrapped in coercions or
type annotations). Those can be applied directly inside the emitted switch
without introducing a let-binding that might change evaluation behaviour. *)
let rec callback_is_inlineable expr =
match expr.pexp_desc with
| Pexp_ident _ -> true
| Pexp_constraint (inner, _) | Pexp_coerce (inner, _, _) ->
callback_is_inlineable inner
| _ -> false

(* Detect literal lambdas (ignoring type annotations) so we can reuse their
argument binder in the rewritten switch. *)
let rec inline_lambda expr =
match expr.pexp_desc with
| Pexp_constraint (inner, _) | Pexp_coerce (inner, _, _) ->
inline_lambda inner
| Pexp_fun {arg_label = Asttypes.Nolabel; lhs; rhs; async = false} ->
Some (lhs, rhs)
| _ -> None

let transform (expr : Parsetree.expression) : Parsetree.expression =
match expr.pexp_desc with
| Pexp_apply
{
funct =
{
pexp_desc =
Pexp_ident
{txt = Ldot (Lident ("Option" | "Stdlib_Option"), fname)};
};
args = [(_, opt_expr); (_, func_expr)];
} -> (
Comment on lines +35 to +46

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] Guard rewrite against user-defined Option modules

The optimization currently rewrites any application whose head is syntactically Option.* or Stdlib_Option.* (see the Ldot (Lident ("Option" | "Stdlib_Option"), fname) pattern). Because this runs before name resolution, it will also trigger for locally defined modules named Option or Stdlib_Option. If a project defines its own module shadowing these names (for example a custom Option.map with different semantics), the pass will silently replace the call with a match on Some/None, miscompiling the user’s code. The transformation needs to confirm it is targeting the stdlib Option module—e.g. by running after type checking or by explicitly checking the module path after resolution—otherwise it is unsound.

Useful? React with 👍 / 👎.

let call_kind =
match fname with
| "forEach" -> Some ForEach
| "map" -> Some Map
| "flatMap" -> Some FlatMap
| _ -> None
in
match call_kind with
| None -> expr
| Some call_kind -> (
let loc_ghost = {expr.pexp_loc with loc_ghost = true} in
let emit_option_match value_pat result_expr =
let some_rhs =
match call_kind with
| ForEach | FlatMap -> result_expr
| Map ->
Ast_helper.Exp.construct ~loc:loc_ghost
{txt = Lident "Some"; loc = loc_ghost}
(Some result_expr)
in
let none_rhs =
match call_kind with
| ForEach ->
Ast_helper.Exp.construct ~loc:loc_ghost
{txt = Lident "()"; loc = loc_ghost}
None
| Map | FlatMap ->
Ast_helper.Exp.construct ~loc:loc_ghost
{txt = Lident "None"; loc = loc_ghost}
None
in
let mk_case ctor payload rhs =
{
Parsetree.pc_bar = None;
pc_lhs =
Ast_helper.Pat.construct ~loc:loc_ghost
{txt = Lident ctor; loc = loc_ghost}
payload;
pc_guard = None;
pc_rhs = rhs;
}
in
let some_case = mk_case "Some" (Some value_pat) some_rhs in
let none_case = mk_case "None" None none_rhs in
let transformed =
Ast_helper.Exp.match_ ~loc:loc_ghost opt_expr [some_case; none_case]
in
{
transformed with
pexp_loc = expr.pexp_loc;
pexp_attributes = expr.pexp_attributes;
}
in
match inline_lambda func_expr with
(* Literal lambda with a simple binder: reuse the binder directly inside
the generated switch, so the body runs exactly once with the option's
payload. *)
| Some ({ppat_desc = Parsetree.Ppat_var {txt}}, body) ->
let value_pat =
Ast_helper.Pat.var ~loc:loc_ghost {txt; loc = loc_ghost}
in
emit_option_match value_pat body
(* Callback is a simple identifier (possibly annotated). Apply it inside
the switch so evaluation order matches handwritten code. *)
| _ when callback_is_inlineable func_expr ->
let value_pat =
Ast_helper.Pat.var ~loc:loc_ghost {txt = value_name; loc = loc_ghost}
in
let value_ident =
Ast_helper.Exp.ident ~loc:loc_ghost
{txt = Lident value_name; loc = loc_ghost}
in
let apply_callback =
Ast_helper.Exp.apply ~loc:loc_ghost func_expr
[(Asttypes.Nolabel, value_ident)]
in
emit_option_match value_pat apply_callback
(* Complex callbacks are left as-is so we don't change when they run. *)
| _ -> expr))
| _ -> expr
1 change: 1 addition & 0 deletions compiler/frontend/ast_option_optimizations.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
val transform : Parsetree.expression -> Parsetree.expression
3 changes: 2 additions & 1 deletion compiler/frontend/bs_builtin_ppx.ml
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ let expr_mapper ~async_context ~in_function_def (self : mapper)
body;
pexp_attributes;
})
| Pexp_apply _ -> Ast_exp_apply.app_exp_mapper e self
| Pexp_apply _ ->
Ast_exp_apply.app_exp_mapper e self |> Ast_option_optimizations.transform
| Pexp_match
( b,
[
Expand Down
12 changes: 8 additions & 4 deletions tests/tests/src/core/Core_ObjectTests.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import * as Test from "./Test.mjs";
import * as Stdlib_BigInt from "@rescript/runtime/lib/es6/Stdlib_BigInt.js";
import * as Stdlib_Option from "@rescript/runtime/lib/es6/Stdlib_Option.js";
import * as Primitive_object from "@rescript/runtime/lib/es6/Primitive_object.js";
import * as Primitive_option from "@rescript/runtime/lib/es6/Primitive_option.js";

let eq = Primitive_object.equal;

Expand Down Expand Up @@ -530,10 +531,13 @@ runGetTest({
3
]
}),
get: i => Stdlib_Option.getOr(Stdlib_Option.map(i["a"], i => i.concat([
4,
5
])), []),
get: i => {
let i$1 = i["a"];
return Stdlib_Option.getOr(i$1 !== undefined ? Primitive_option.valFromOption(i$1).concat([
4,
5
]) : undefined, []);
},
expected: [
1,
2,
Expand Down
10 changes: 5 additions & 5 deletions tests/tests/src/core/Core_TempTests.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,11 @@ console.log(regex.test(string));

let result = regex.exec(string);

let result$1 = (result == null) ? undefined : Primitive_option.some(result);

console.log(Stdlib_Option.map(result$1, prim => prim.input));
console.log(!(result == null) ? result.input : undefined);

console.log(Stdlib_Option.map(result$1, prim => prim.index));
console.log(!(result == null) ? result.index : undefined);

console.log(Stdlib_Option.map(result$1, prim => prim.slice(1)));
console.log(!(result == null) ? result.slice(1) : undefined);

console.info("");

Expand Down Expand Up @@ -322,6 +320,8 @@ let formatter = Core_IntlTests.formatter;

let segments = Core_IntlTests.segments;

let result$1 = (result == null) ? undefined : Primitive_option.some(result);

export {
_collator,
collator,
Expand Down
4 changes: 2 additions & 2 deletions tests/tests/src/core/intl/Core_IntlTests.mjs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Generated by ReScript, PLEASE EDIT WITH CARE

import * as Stdlib_JsExn from "@rescript/runtime/lib/es6/Stdlib_JsExn.js";
import * as Stdlib_Option from "@rescript/runtime/lib/es6/Stdlib_Option.js";
import * as Primitive_option from "@rescript/runtime/lib/es6/Primitive_option.js";
import * as Core_Intl_LocaleTest from "./Core_Intl_LocaleTest.mjs";
import * as Primitive_exceptions from "@rescript/runtime/lib/es6/Primitive_exceptions.js";
Expand Down Expand Up @@ -58,7 +57,8 @@ try {
let e$2 = Primitive_exceptions.internalToException(raw_e$2);
if (e$2.RE_EXN_ID === "JsExn") {
let e$3 = e$2._1;
let message = Stdlib_Option.map(Stdlib_JsExn.message(e$3), prim => prim.toLowerCase());
let __res_option_value = Stdlib_JsExn.message(e$3);
let message = __res_option_value !== undefined ? __res_option_value.toLowerCase() : undefined;
let exit = 0;
if (message === "invalid key : someinvalidkey") {
console.log("Caught expected error");
Expand Down
Loading
Loading