diff --git a/compiler/ml/typecore.ml b/compiler/ml/typecore.ml index 6d9c4fd441..03be96c650 100644 --- a/compiler/ml/typecore.ml +++ b/compiler/ml/typecore.ml @@ -95,10 +95,14 @@ type error = | Type_params_not_supported of Longident.t | Field_access_on_dict_type | Jsx_not_enabled + | AliasPolyFromLiterals_NonLiteralOrPattern + | AliasPolyFromLiterals_UnsupportedBaseType of type_expr exception Error of Location.t * Env.t * error exception Error_forward of Location.error +module StringSet = Set.Make (String) + (* Forward declaration, to be filled in by Typemod.type_module *) let type_module = @@ -608,6 +612,48 @@ let build_or_pat env loc lid = in (path, rp {r with pat_loc = loc}, ty) +let maybe_closed_polyvariant_from_literal_or ~env ~loc ~attrs ~typed_pat + ~base_ty = + let has_attr = + List.exists (function + | {Location.txt = "res.asPolyVariantFromLiterals"}, _ -> true + | _ -> false) + in + if not (has_attr attrs) then None + else + let invalid = ref false in + let rec gather (p : pattern) (acc : string list) = + match p.pat_desc with + | Tpat_or (p1, p2, _) -> gather p2 (gather p1 acc) + | Tpat_constant (Const_string (s, _)) -> s :: acc + | Tpat_constant (Const_int i) -> string_of_int i :: acc + | _ -> + invalid := true; + acc + in + let literals = + gather typed_pat [] |> List.rev |> StringSet.of_list |> StringSet.elements + in + if !invalid || literals = [] then + raise (Error (loc, env, AliasPolyFromLiterals_NonLiteralOrPattern)) + else + match (repr (expand_head env base_ty)).desc with + | Tconstr (p, [], _) + when Path.same p Predef.path_string || Path.same p Predef.path_int -> + let row = + { + row_fields = List.map (fun l -> (l, Rpresent None)) literals; + row_more = newvar (); + row_closed = true; + row_fixed = false; + row_name = None; + } + in + Some (newty (Tvariant row)) + | _ -> + raise + (Error (loc, env, AliasPolyFromLiterals_UnsupportedBaseType base_ty)) + let extract_type_from_pat_variant_spread env lid expected_ty = let path, decl = Typetexp.find_type env lid.loc lid.txt in match decl with @@ -1290,7 +1336,14 @@ and type_pat_aux ~constrs ~labels ~no_existentials ~mode ~explode ~env sp let ty_var = match override_type_from_variant_spread with | Some ty -> ty - | None -> build_as_type !env q + | None -> ( + let base_ty = build_as_type !env q in + match + maybe_closed_polyvariant_from_literal_or ~env:!env ~loc + ~attrs:sp.ppat_attributes ~typed_pat:q ~base_ty + with + | Some ty -> ty + | None -> base_ty) in end_def (); generalize ty_var; @@ -4719,6 +4772,16 @@ let report_error env loc ppf error = fprintf ppf "Cannot compile JSX expression because JSX support is not enabled. Add \ \"jsx\" settings to rescript.json to enable JSX support." + | AliasPolyFromLiterals_NonLiteralOrPattern -> + fprintf ppf + "@[The alias form @{as #...id@} can only be used when matching an \ + or-pattern made solely of string or int literals, e.g. \ + @{\"a\"|\"b\"@} or @{1|2@}.@]" + | AliasPolyFromLiterals_UnsupportedBaseType ty -> + fprintf ppf + "@[The alias form @{as #...id@} requires the alias to have type \ + @{string@} or @{int@}, but here it has type@ %a@]" + type_expr ty let report_error env loc ppf err = Printtyp.wrap_printing_env env (fun () -> report_error env loc ppf err) diff --git a/compiler/ml/typecore.mli b/compiler/ml/typecore.mli index 2a35693ec2..a37ee8ce93 100644 --- a/compiler/ml/typecore.mli +++ b/compiler/ml/typecore.mli @@ -127,6 +127,8 @@ type error = | Type_params_not_supported of Longident.t | Field_access_on_dict_type | Jsx_not_enabled + | AliasPolyFromLiterals_NonLiteralOrPattern + | AliasPolyFromLiterals_UnsupportedBaseType of type_expr exception Error of Location.t * Env.t * error exception Error_forward of Location.error diff --git a/compiler/syntax/src/res_core.ml b/compiler/syntax/src/res_core.ml index ec0e7b38ad..ef10c443d9 100644 --- a/compiler/syntax/src/res_core.ml +++ b/compiler/syntax/src/res_core.ml @@ -199,6 +199,9 @@ let template_literal_attr = (Location.mknoloc "res.template", Parsetree.PStr []) let make_pat_variant_spread_attr = (Location.mknoloc "res.patVariantSpread", Parsetree.PStr []) +let make_alias_polyvariant_from_literals_attr = + (Location.mknoloc "res.asPolyVariantFromLiterals", Parsetree.PStr []) + let tagged_template_literal_attr = (Location.mknoloc "res.taggedTemplate", Parsetree.PStr []) @@ -1318,11 +1321,19 @@ and parse_alias_pattern ~attrs pattern p = match p.Parser.token with | As -> Parser.next p; + let alias_attrs = + match p.Parser.token with + | Hash -> + Parser.next p; + Parser.expect DotDotDot p; + make_alias_polyvariant_from_literals_attr :: attrs + | _ -> attrs + in let name, loc = parse_lident p in let name = Location.mkloc name loc in Ast_helper.Pat.alias ~loc:{pattern.ppat_loc with loc_end = p.prev_end_pos} - ~attrs pattern name + ~attrs:alias_attrs pattern name | _ -> pattern (* or ::= pattern | pattern diff --git a/compiler/syntax/src/res_parsetree_viewer.ml b/compiler/syntax/src/res_parsetree_viewer.ml index fe5430ebe5..8e6a8f8465 100644 --- a/compiler/syntax/src/res_parsetree_viewer.ml +++ b/compiler/syntax/src/res_parsetree_viewer.ml @@ -80,6 +80,13 @@ let has_res_pat_variant_spread_attribute attrs = | _ -> false) attrs +let has_attribute_with_name ~name attrs = + List.exists + (function + | {Location.txt}, _ when txt = name -> true + | _ -> false) + attrs + let has_dict_pattern_attribute attrs = attrs |> List.find_opt (fun (({txt}, _) : Parsetree.attribute) -> @@ -201,7 +208,8 @@ let filter_parsing_attrs attrs = ( "meth" | "res.braces" | "ns.braces" | "res.iflet" | "res.ternary" | "res.await" | "res.template" | "res.taggedTemplate" | "res.patVariantSpread" - | "res.dictPattern" | "res.inlineRecordDefinition" ); + | "res.asPolyVariantFromLiterals" | "res.dictPattern" + | "res.inlineRecordDefinition" ); }, _ ) -> false @@ -543,7 +551,8 @@ let is_printable_attribute attr = | ( { Location.txt = ( "res.iflet" | "res.braces" | "ns.braces" | "JSX" | "res.await" - | "res.template" | "res.ternary" | "res.inlineRecordDefinition" ); + | "res.template" | "res.ternary" | "res.inlineRecordDefinition" + | "res.asPolyVariantFromLiterals" ); }, _ ) -> false diff --git a/compiler/syntax/src/res_parsetree_viewer.mli b/compiler/syntax/src/res_parsetree_viewer.mli index c50e3f2ce8..ec1300e925 100644 --- a/compiler/syntax/src/res_parsetree_viewer.mli +++ b/compiler/syntax/src/res_parsetree_viewer.mli @@ -16,6 +16,7 @@ val expr_is_await : Parsetree.expression -> bool val has_await_attribute : Parsetree.attributes -> bool val has_inline_record_definition_attribute : Parsetree.attributes -> bool val has_res_pat_variant_spread_attribute : Parsetree.attributes -> bool +val has_attribute_with_name : name:string -> Parsetree.attributes -> bool val has_dict_pattern_attribute : Parsetree.attributes -> bool type if_condition_kind = diff --git a/compiler/syntax/src/res_printer.ml b/compiler/syntax/src/res_printer.ml index a7c95cf5ee..bea374dafa 100644 --- a/compiler/syntax/src/res_printer.ml +++ b/compiler/syntax/src/res_printer.ml @@ -2611,18 +2611,24 @@ and print_pattern ~state (p : Parsetree.pattern) cmt_tbl = (Doc.concat docs) | Ppat_extension ext -> print_extension ~state ~at_module_lvl:false ext cmt_tbl - | Ppat_alias (p, alias_loc) -> + | Ppat_alias (inner_p, alias_loc) -> let needs_parens = - match p.ppat_desc with + match inner_p.ppat_desc with | Ppat_or (_, _) | Ppat_alias (_, _) -> true | _ -> false in let rendered_pattern = - let p = print_pattern ~state p cmt_tbl in + let p = print_pattern ~state inner_p cmt_tbl in if needs_parens then Doc.concat [Doc.text "("; p; Doc.text ")"] else p in - Doc.concat - [rendered_pattern; Doc.text " as "; print_string_loc alias_loc cmt_tbl] + let alias_doc = + if + ParsetreeViewer.has_attribute_with_name + ~name:"res.asPolyVariantFromLiterals" p.ppat_attributes + then Doc.concat [Doc.text "#..."; print_string_loc alias_loc cmt_tbl] + else print_string_loc alias_loc cmt_tbl + in + Doc.concat [rendered_pattern; Doc.text " as "; alias_doc] (* Note: module(P : S) is represented as *) (* Ppat_constraint(Ppat_unpack, Ptyp_package) *) | Ppat_constraint diff --git a/tests/build_tests/super_errors/expected/alias_poly_nonliteral.res.expected b/tests/build_tests/super_errors/expected/alias_poly_nonliteral.res.expected new file mode 100644 index 0000000000..f4ec5208e0 --- /dev/null +++ b/tests/build_tests/super_errors/expected/alias_poly_nonliteral.res.expected @@ -0,0 +1,11 @@ + + We've found a bug for you! + /.../fixtures/alias_poly_nonliteral.res:3:5-22 + + 1 │ let f = x => + 2 │ switch x { + 3 │ | ("a" | _) as #...f => 1 + 4 │ | _ => 0 + 5 │ } + + The alias form as #...id can only be used when matching an or-pattern made solely of string or int literals, e.g. "a"|"b" or 1|2. diff --git a/tests/build_tests/super_errors/fixtures/alias_poly_nonliteral.res b/tests/build_tests/super_errors/fixtures/alias_poly_nonliteral.res new file mode 100644 index 0000000000..ac2a330835 --- /dev/null +++ b/tests/build_tests/super_errors/fixtures/alias_poly_nonliteral.res @@ -0,0 +1,5 @@ +let f = x => + switch x { + | ("a" | _) as #...f => 1 + | _ => 0 + } diff --git a/tests/syntax_tests/data/parsing/grammar/pattern/expected/or.res.txt b/tests/syntax_tests/data/parsing/grammar/pattern/expected/or.res.txt index d337882221..640f1f5825 100644 --- a/tests/syntax_tests/data/parsing/grammar/pattern/expected/or.res.txt +++ b/tests/syntax_tests/data/parsing/grammar/pattern/expected/or.res.txt @@ -4,4 +4,7 @@ | Blue as c1|Red as c2 -> () | Blue as c1|Red as c2 -> () | exception Exit|exception Continue -> () - | exception (Exit|exception Continue) -> () \ No newline at end of file + | exception (Exit|exception Continue) -> () + | (({js|a|js}|{js|b|js}|{js|c|js} as f)[@res.asPolyVariantFromLiterals ]) + -> () + | ((1|2 as n)[@res.asPolyVariantFromLiterals ]) -> () \ No newline at end of file diff --git a/tests/syntax_tests/data/parsing/grammar/pattern/or.res b/tests/syntax_tests/data/parsing/grammar/pattern/or.res index d83a3bb2f7..6b7232c1c5 100644 --- a/tests/syntax_tests/data/parsing/grammar/pattern/or.res +++ b/tests/syntax_tests/data/parsing/grammar/pattern/or.res @@ -5,4 +5,6 @@ switch x { | (Blue as c1) | (Red as c2) => () | exception Exit | exception Continue => () | exception (Exit | exception Continue) => () +| ("a" | "b" | "c") as #...f => () +| (1 | 2) as #...n => () } diff --git a/tests/tests/src/alias_poly_from_literals.mjs b/tests/tests/src/alias_poly_from_literals.mjs new file mode 100644 index 0000000000..a5e118742d --- /dev/null +++ b/tests/tests/src/alias_poly_from_literals.mjs @@ -0,0 +1,61 @@ +// Generated by ReScript, PLEASE EDIT WITH CARE + + +function useABC(x) { + if (x === "b") { + return 2; + } else if (x === "c") { + return 3; + } else { + return 1; + } +} + +function fromString(s) { + switch (s) { + case "a" : + case "b" : + case "c" : + return useABC(s); + default: + return 0; + } +} + +function useNums(x) { + if (x === 2) { + return 20; + } else { + return 10; + } +} + +function fromInt(i) { + if (i !== 2 && i !== 1) { + return 0; + } else { + return useNums(i); + } +} + +function useSingle(x) { + return x + " world!"; +} + +function fromString2(s) { + if (s === "hello") { + return useSingle(s); + } else { + return "unknown"; + } +} + +export { + useABC, + fromString, + useNums, + fromInt, + useSingle, + fromString2, +} +/* No side effect */ diff --git a/tests/tests/src/alias_poly_from_literals.res b/tests/tests/src/alias_poly_from_literals.res new file mode 100644 index 0000000000..33a6cfc473 --- /dev/null +++ b/tests/tests/src/alias_poly_from_literals.res @@ -0,0 +1,38 @@ +type abc = [#a | #b | #c] + +let useABC = (x: abc) => + switch x { + | #a => 1 + | #b => 2 + | #c => 3 + } + +let fromString = s => + switch s { + | ("a" | "b" | "c") as #...f => useABC(f) + | _ => 0 + } + +type nums = [#1 | #2] + +let useNums = (x: nums) => + switch x { + | #1 => 10 + | #2 => 20 + } + +let fromInt = i => + switch i { + | (1 | 2) as #...n => useNums(n) + | _ => 0 + } + +let useSingle = (x: [#hello]) => { + `${(x :> string)} world!` +} + +let fromString2 = s => + switch s { + | "hello" as #...g => useSingle(g) + | _ => "unknown" + }