Skip to content

Commit 96acd4d

Browse files
authored
Subtype: avoid false alarm caused by eager forall_exists_subtype. (#48441)
* Avoid earsing `Runion` within nested `forall_exists_subtype` If `Runion.more != 0` we‘d better not erase the local `Runion` as we need it if the subtyping fails after. This commit replaces `forall_exists_subtype` with a local version. It first tries `forall_exists_subtype` and estimates the "problem scale". If subtyping fails and the scale looks small then it switches to the slow path. TODO: At present, the "problem scale" only counts the number of checked `Lunion`s. But perhaps we need a more accurate result (e.g. sum of `Runion.depth`) * Change the reversed subtyping into a local check. Make sure we don't forget the bound in `env`. (And we can fuse `local_forall_exists_subtype`) * Optimization for non-union invariant parameter.
1 parent d918576 commit 96acd4d

File tree

2 files changed

+91
-41
lines changed

2 files changed

+91
-41
lines changed

src/subtype.c

Lines changed: 82 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv
620620
return u;
621621
}
622622

623-
static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);
623+
static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow);
624624

625625
// subtype for variable bounds consistency check. needs its own forall/exists environment.
626626
static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
@@ -636,17 +636,7 @@ static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
636636
if (x == (jl_value_t*)jl_any_type && jl_is_datatype(y))
637637
return 0;
638638
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
639-
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
640-
int sub;
641-
e->Lunions.used = e->Runions.used = 0;
642-
e->Runions.depth = 0;
643-
e->Runions.more = 0;
644-
e->Lunions.depth = 0;
645-
e->Lunions.more = 0;
646-
647-
sub = forall_exists_subtype(x, y, e, 0);
648-
649-
pop_unionstate(&e->Runions, &oldRunions);
639+
int sub = local_forall_exists_subtype(x, y, e, 0, 1);
650640
pop_unionstate(&e->Lunions, &oldLunions);
651641
return sub;
652642
}
@@ -1431,6 +1421,72 @@ static int is_definite_length_tuple_type(jl_value_t *x)
14311421
return k == JL_VARARG_NONE || k == JL_VARARG_INT;
14321422
}
14331423

1424+
static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore);
1425+
1426+
static int may_contain_union_decision(jl_value_t *x, jl_stenv_t *e, jl_typeenv_t *log) JL_NOTSAFEPOINT
1427+
{
1428+
if (x == NULL || x == (jl_value_t*)jl_any_type || x == jl_bottom_type)
1429+
return 0;
1430+
if (jl_is_unionall(x))
1431+
return may_contain_union_decision(((jl_unionall_t *)x)->body, e, log);
1432+
if (jl_is_datatype(x)) {
1433+
jl_datatype_t *xd = (jl_datatype_t *)x;
1434+
for (int i = 0; i < jl_nparams(xd); i++) {
1435+
jl_value_t *param = jl_tparam(xd, i);
1436+
if (jl_is_vararg(param))
1437+
param = jl_unwrap_vararg(param);
1438+
if (may_contain_union_decision(param, e, log))
1439+
return 1;
1440+
}
1441+
return 0;
1442+
}
1443+
if (!jl_is_typevar(x))
1444+
return 1;
1445+
jl_typeenv_t *t = log;
1446+
while (t != NULL) {
1447+
if (x == (jl_value_t *)t->var)
1448+
return 1;
1449+
t = t->prev;
1450+
}
1451+
jl_typeenv_t newlog = { (jl_tvar_t*)x, NULL, log };
1452+
jl_varbinding_t *xb = lookup(e, (jl_tvar_t *)x);
1453+
return may_contain_union_decision(xb ? xb->lb : ((jl_tvar_t *)x)->lb, e, &newlog) ||
1454+
may_contain_union_decision(xb ? xb->ub : ((jl_tvar_t *)x)->ub, e, &newlog);
1455+
}
1456+
1457+
static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow)
1458+
{
1459+
int16_t oldRmore = e->Runions.more;
1460+
int sub;
1461+
if (may_contain_union_decision(y, e, NULL) && pick_union_decision(e, 1) == 0) {
1462+
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
1463+
e->Lunions.used = e->Runions.used = 0;
1464+
e->Lunions.depth = e->Runions.depth = 0;
1465+
e->Lunions.more = e->Runions.more = 0;
1466+
int count = 0, noRmore = 0;
1467+
sub = _forall_exists_subtype(x, y, e, param, &count, &noRmore);
1468+
pop_unionstate(&e->Runions, &oldRunions);
1469+
// we should not try the slow path if `forall_exists_subtype` has tested all cases;
1470+
// Once limit_slow == 1, also skip it if
1471+
// 1) `forall_exists_subtype` return false
1472+
// 2) the left `Union` looks big
1473+
if (noRmore || (limit_slow && (count > 3 || !sub)))
1474+
e->Runions.more = oldRmore;
1475+
}
1476+
else {
1477+
// slow path
1478+
e->Lunions.used = 0;
1479+
while (1) {
1480+
e->Lunions.more = 0;
1481+
e->Lunions.depth = 0;
1482+
sub = subtype(x, y, e, param);
1483+
if (!sub || !next_union_state(e, 0))
1484+
break;
1485+
}
1486+
}
1487+
return sub;
1488+
}
1489+
14341490
static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
14351491
{
14361492
if (obviously_egal(x, y)) return 1;
@@ -1449,33 +1505,13 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
14491505
}
14501506

14511507
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
1452-
e->Lunions.used = 0;
1453-
int sub;
1454-
1455-
if (!jl_has_free_typevars(x) || !jl_has_free_typevars(y)) {
1456-
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
1457-
e->Runions.used = 0;
1458-
e->Runions.depth = 0;
1459-
e->Runions.more = 0;
1460-
e->Lunions.depth = 0;
1461-
e->Lunions.more = 0;
14621508

1463-
sub = forall_exists_subtype(x, y, e, 2);
1464-
1465-
pop_unionstate(&e->Runions, &oldRunions);
1466-
}
1467-
else {
1468-
while (1) {
1469-
e->Lunions.more = 0;
1470-
e->Lunions.depth = 0;
1471-
sub = subtype(x, y, e, 2);
1472-
if (!sub || !next_union_state(e, 0))
1473-
break;
1474-
}
1475-
}
1509+
int limit_slow = !jl_has_free_typevars(x) || !jl_has_free_typevars(y);
1510+
int sub = local_forall_exists_subtype(x, y, e, 2, limit_slow) &&
1511+
local_forall_exists_subtype(y, x, e, 0, 0);
14761512

14771513
pop_unionstate(&e->Lunions, &oldLunions);
1478-
return sub && subtype(y, x, e, 0);
1514+
return sub;
14791515
}
14801516

14811517
static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_t *saved, jl_savedenv_t *se, int param)
@@ -1502,7 +1538,7 @@ static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_
15021538
}
15031539
}
15041540

1505-
static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
1541+
static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore)
15061542
{
15071543
// The depth recursion has the following shape, after simplification:
15081544
// ∀₁
@@ -1515,8 +1551,12 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in
15151551

15161552
e->Lunions.used = 0;
15171553
int sub;
1554+
if (count) *count = 0;
1555+
if (noRmore) *noRmore = 1;
15181556
while (1) {
15191557
sub = exists_subtype(x, y, e, saved, &se, param);
1558+
if (count) *count = (*count < 4) ? *count + 1 : 4;
1559+
if (noRmore) *noRmore = *noRmore && e->Runions.more == 0;
15201560
if (!sub || !next_union_state(e, 0))
15211561
break;
15221562
free_env(&se);
@@ -1528,6 +1568,11 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in
15281568
return sub;
15291569
}
15301570

1571+
static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
1572+
{
1573+
return _forall_exists_subtype(x, y, e, param, NULL, NULL);
1574+
}
1575+
15311576
static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
15321577
{
15331578
e->vars = NULL;

test/subtype.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,6 +1484,8 @@ f24521(::Type{T}, ::Type{T}) where {T} = T
14841484
@test !(Ref{Union{Int64, Val{Number}}} <: Ref{Union{Val{T}, T}} where T)
14851485
@test !(Ref{Union{Ref{Number}, Int64}} <: Ref{Union{Ref{T}, T}} where T)
14861486
@test !(Ref{Union{Val{Number}, Int64}} <: Ref{Union{Val{T}, T}} where T)
1487+
@test !(Val{Ref{Union{Int64, Ref{Number}}}} <: Val{S} where {S<:Ref{Union{Ref{T}, T}} where T})
1488+
@test !(Tuple{Ref{Union{Int64, Ref{Number}}}} <: Tuple{S} where {S<:Ref{Union{Ref{T}, T}} where T})
14871489

14881490
# issue #26180
14891491
@test !(Ref{Union{Ref{Int64}, Ref{Number}}} <: Ref{Ref{T}} where T)
@@ -2385,8 +2387,8 @@ abstract type P47654{A} end
23852387
@test_broken typeintersect(Tuple{Vector{VT}, Vector{VT}} where {N1, VT<:AbstractVector{N1}},
23862388
Tuple{Vector{VN} where {N, VN<:AbstractVector{N}}, Vector{Vector{Float64}}}) !== Union{}
23872389
#issue 40865
2388-
@test_broken Tuple{Set{Ref{Int}}, Set{Ref{Int}}} <: Tuple{Set{KV}, Set{K}} where {K,KV<:Union{K,Ref{K}}}
2389-
@test_broken Tuple{Set{Val{Int}}, Set{Val{Int}}} <: Tuple{Set{KV}, Set{K}} where {K,KV<:Union{K,Val{K}}}
2390+
@test Tuple{Set{Ref{Int}}, Set{Ref{Int}}} <: Tuple{Set{KV}, Set{K}} where {K,KV<:Union{K,Ref{K}}}
2391+
@test Tuple{Set{Val{Int}}, Set{Val{Int}}} <: Tuple{Set{KV}, Set{K}} where {K,KV<:Union{K,Val{K}}}
23902392

23912393
#issue 39099
23922394
A = Tuple{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Vararg{Int, N}}, Tuple{Vararg{Int, N}}} where N
@@ -2420,8 +2422,7 @@ end
24202422

24212423
# try to fool a greedy algorithm that picks X=Int, Y=String here
24222424
@test Tuple{Ref{Union{Int,String}}, Ref{Union{Int,String}}} <: Tuple{Ref{Union{X,Y}}, Ref{X}} where {X,Y}
2423-
# this slightly more complex case has been broken since 1.0 (worked in 0.6)
2424-
@test_broken Tuple{Ref{Union{Int,String,Missing}}, Ref{Union{Int,String}}} <: Tuple{Ref{Union{X,Y}}, Ref{X}} where {X,Y}
2425+
@test Tuple{Ref{Union{Int,String,Missing}}, Ref{Union{Int,String}}} <: Tuple{Ref{Union{X,Y}}, Ref{X}} where {X,Y}
24252426

24262427
@test !(Tuple{Any, Any, Any} <: Tuple{Any, Vararg{T}} where T)
24272428

@@ -2435,3 +2436,7 @@ let A = Tuple{Type{T}, T} where T,
24352436
C = Tuple{Type{MyType47877{W, V} where V<:Union{MyAbstract47877{W}, Base.BitInteger}}, MyType47877{W, V} where V<:Union{MyAbstract47877{W}, Base.BitInteger}} where W<:Base.BitInteger
24362437
@test typeintersect(B, A) == C
24372438
end
2439+
2440+
let a = (isodd(i) ? Pair{Char, String} : Pair{String, String} for i in 1:2000)
2441+
@test Tuple{Type{Pair{Union{Char, String}, String}}, a...} <: Tuple{Type{Pair{K, V}}, Vararg{Pair{A, B} where B where A}} where V where K
2442+
end

0 commit comments

Comments
 (0)