@@ -17,6 +17,9 @@ abstract type RArray{ElType,Shape,N} <: AbstractArray{ElType,N} end
17
17
@inline mlir_type (:: RArray{ElType,Shape,N} ) where {ElType,Shape,N} =
18
18
MLIR. IR. TensorType (Shape, MLIR. IR. Type (ElType))
19
19
20
+ @inline mlir_type (:: Type{<:RArray{ElType,Shape,N}} ) where {ElType,Shape,N} =
21
+ MLIR. IR. TensorType (Shape, MLIR. IR. Type (ElType))
22
+
20
23
struct XLAArray{ElType,Shape,N} <: RArray{ElType,Shape,N} end
21
24
22
25
mutable struct ConcreteRArray{ElType,Shape,N} <: RArray{ElType,Shape,N}
@@ -208,6 +211,10 @@ using Enzyme
208
211
TracedSetPath = 5
209
212
end
210
213
214
+ @inline getmap (:: Val{T} ) where T = nothing
215
+ @inline getmap (:: Val{T} , a, b, args... ) where {T} = getmap (Val (T), args... )
216
+ @inline getmap (:: Val{T} , :: Val{T} , :: Val{T2} , args... ) where {T, T2} = T2
217
+
211
218
@inline is_concrete_tuple (x:: T2 ) where {T2} =
212
219
(x <: Tuple ) && ! (x === Tuple) && ! (x isa UnionAll)
213
220
@inline function traced_type (val:: Type{T} , seen:: ST , :: Val{mode} ) where {ST,T,mode}
@@ -391,17 +398,18 @@ end
391
398
return IdDict{iddict_name (T),traced_type (iddict_val (T), seen, Val (mode))}
392
399
end
393
400
394
- if Val (T) ∈ seen
395
- return T
401
+ nextTy = getmap (Val (T), seen... )
402
+ if nextTy != nothing
403
+ return nextTy
396
404
end
397
405
398
- seen = (Val (T), seen... )
406
+ seen2 = (Val (T), Val (T), seen... )
399
407
400
408
changed = false
401
409
subTys = Type[]
402
410
for f in 1 : fieldcount (T)
403
411
subT = fieldtype (T, f)
404
- subTT = traced_type (subT, seen , Val (mode))
412
+ subTT = traced_type (subT, seen2 , Val (mode))
405
413
changed |= subT != subTT
406
414
push! (subTys, subTT)
407
415
end
@@ -421,29 +429,32 @@ end
421
429
end
422
430
423
431
TT2 = Core. apply_type (T. name. wrapper, subParms... )
432
+ seen3 = (Val (T), Val (TT2), seen... )
424
433
if fieldcount (T) == fieldcount (TT2)
425
434
legal = true
426
435
for f in 1 : fieldcount (T)
427
436
subT = fieldtype (T, f)
428
437
subT2 = fieldtype (TT2, f)
429
- subTT = traced_type (subT, seen, Val (mode))
430
- legal &= subT2 == subTT
438
+ subTT = traced_type (subT, seen3, Val (mode))
439
+ if subT2 != subTT
440
+ legal = false
441
+ break
442
+ end
431
443
end
432
444
if legal
433
445
return TT2
434
446
end
435
447
end
436
448
437
449
name = Symbol[]
438
-
439
- return NamedTuple{fieldnames (T),Tuple{subTys... }}
450
+ throw (error (" Cannot convert type $T , best attempt $TT2 failed" ))
440
451
end
441
452
442
453
function append_path (path, i)
443
454
return (path... , i)
444
455
end
445
456
446
- @inline function make_tracer (seen:: IdDict , prev:: RT , path, mode) where {RT}
457
+ @inline function make_tracer (seen:: IdDict , prev:: RT , path:: Tuple , mode:: TraceMode ; toscalar = false , tobatch = nothing ) where {RT}
447
458
if haskey (seen, prev)
448
459
return seen[prev]
449
460
end
457
468
subs = []
458
469
for i in 1 : nf
459
470
xi = Base. getfield (prev, i)
460
- xi2 = make_tracer (seen, xi, append_path (path, i), mode)
471
+ xi2 = make_tracer (seen, xi, append_path (path, i), mode; toscalar, tobatch )
461
472
if xi != = xi2
462
473
changed = true
463
474
end
468
479
return prev
469
480
end
470
481
tup = (subs... ,)
471
- @show TT, subs, tup
472
482
return NamedTuple {TT.parameters[1],typeof(tup)} (tup)
473
483
end
474
484
479
489
for i in 1 : nf
480
490
if isdefined (prev, i)
481
491
xi = Base. getfield (prev, i)
482
- xi2 = make_tracer (seen, xi, append_path (path, i), mode)
492
+ xi2 = make_tracer (seen, xi, append_path (path, i), mode; toscalar, tobatch )
483
493
if xi != = xi2
484
494
changed = true
485
495
end
502
512
for i in 1 : nf
503
513
if isdefined (prev, i)
504
514
xi = Base. getfield (prev, i)
505
- xi2 = make_tracer (seen, xi, append_path (path, i), mode)
515
+ xi2 = make_tracer (seen, xi, append_path (path, i), mode; toscalar, tobatch )
506
516
if xi != = xi2
507
517
changed = true
508
518
end
522
532
end
523
533
524
534
@inline function make_tracer (
525
- seen:: IdDict , prev:: ConcreteRArray{ElType,Shape,N} , path, mode
535
+ seen:: IdDict , prev:: ConcreteRArray{ElType,Shape,N} , path:: Tuple , mode:: TraceMode ; toscalar = false , tobatch = nothing
526
536
) where {ElType,Shape,N}
527
537
if mode == ArrayToConcrete
528
538
return prev
540
550
end
541
551
542
552
@inline function make_tracer (
543
- seen:: IdDict , prev:: TracedRArray{ElType,Shape,N} , path, mode
553
+ seen:: IdDict , prev:: TracedRArray{ElType,Shape,N} , path:: Tuple , mode:: TraceMode ; toscalar = false , tobatch = nothing
544
554
) where {ElType,Shape,N}
545
555
if mode == ConcreteToTraced
546
556
throw (" Cannot trace existing trace type" )
556
566
if haskey (seen, prev)
557
567
return seen[prev]
558
568
end
559
- res = TracedRArray {ElType,Shape,N} ((path,), prev. mlir_data)
569
+ res = if toscalar
570
+ TracedRArray {ElType,(),0} ((path,), nothing )
571
+ elseif tobatch != = nothing
572
+ TracedRArray {ElType,tobatch,length(tobatch)} ((path,), prev. mlir_data)
573
+ else
574
+ TracedRArray {ElType,Shape,N} ((path,), prev. mlir_data)
575
+ end
560
576
seen[prev] = res
561
577
return res
562
578
end
@@ -573,18 +589,18 @@ end
573
589
throw (" Cannot Unknown trace mode $mode " )
574
590
end
575
591
576
- @inline function make_tracer (seen:: IdDict , prev:: RT , path, mode) where {RT<: AbstractFloat }
592
+ @inline function make_tracer (seen:: IdDict , prev:: RT , path:: Tuple , mode:: TraceMode ; toscalar = false , tobatch = nothing ) where {RT<: AbstractFloat }
577
593
return prev
578
594
end
579
595
580
- @inline function make_tracer (seen:: IdDict , prev:: Complex{RT} , path, mode) where {RT}
596
+ @inline function make_tracer (seen:: IdDict , prev:: Complex{RT} , path:: Tuple , mode:: TraceMode ; toscalar = false , tobatch = nothing ) where {RT}
581
597
return Complex (
582
- make_tracer (seen, prev. re, append_path (path, :re ), mode),
583
- make_tracer (seen, prev. im, append_path (path, :im ), mode),
598
+ make_tracer (seen, prev. re, append_path (path, :re ), mode; toscalar, tobatch ),
599
+ make_tracer (seen, prev. im, append_path (path, :im ), mode; toscalar, tobatch ),
584
600
)
585
601
end
586
602
587
- @inline function make_tracer (seen:: IdDict , prev:: RT , path, mode) where {RT<: Array }
603
+ @inline function make_tracer (seen:: IdDict , prev:: RT , path:: Tuple , mode:: TraceMode ; toscalar = false , tobatch = nothing ) where {RT<: Array }
588
604
if haskey (seen, prev)
589
605
return seen[prev]
590
606
end
598
614
for I in eachindex (prev)
599
615
if isassigned (prev, I)
600
616
pv = prev[I]
601
- nv = make_tracer (seen, pv, append_path (path, I), mode)
617
+ nv = make_tracer (seen, pv, append_path (path, I), mode; toscalar, tobatch )
602
618
if pv != = nv
603
619
same = false
604
620
end
@@ -612,27 +628,27 @@ end
612
628
return newa
613
629
end
614
630
615
- @inline function make_tracer (seen:: IdDict , prev:: RT , path, mode) where {RT<: Tuple }
631
+ @inline function make_tracer (seen:: IdDict , prev:: RT , path:: Tuple , mode:: TraceMode ; toscalar = false , tobatch = nothing ) where {RT<: Tuple }
616
632
return (
617
- (make_tracer (seen, v, append_path (path, i), mode) for (i, v) in enumerate (prev)). .. ,
633
+ (make_tracer (seen, v, append_path (path, i), mode; toscalar, tobatch ) for (i, v) in enumerate (prev)). .. ,
618
634
)
619
635
end
620
636
621
- @inline function make_tracer (seen:: IdDict , prev:: NamedTuple{A,RT} , path, mode) where {A,RT}
637
+ @inline function make_tracer (seen:: IdDict , prev:: NamedTuple{A,RT} , path:: Tuple , mode:: TraceMode ; toscalar = false , tobatch = nothing ) where {A,RT}
622
638
return NamedTuple {A,traced_type(RT, (), Val(mode))} ((
623
639
(
624
- make_tracer (seen, Base. getfield (prev, i), append_path (path, i), mode) for
640
+ make_tracer (seen, Base. getfield (prev, i), append_path (path, i), mode; toscalar, tobatch ) for
625
641
i in 1 : length (A)
626
642
). .. ,
627
643
))
628
644
end
629
645
630
- @inline function make_tracer (seen:: IdDict , prev:: Core.Box , path, mode)
646
+ @inline function make_tracer (seen:: IdDict , prev:: Core.Box , path:: Tuple , mode:: TraceMode ; toscalar = false , tobatch = nothing )
631
647
if haskey (seen, prev)
632
648
return seen[prev]
633
649
end
634
650
prev2 = prev. contents
635
- tr = make_tracer (seen, prev2, append_path (path, :contents ), mode)
651
+ tr = make_tracer (seen, prev2, append_path (path, :contents ), mode; toscalar, tobatch )
636
652
if tr == prev2
637
653
seen[prev] = prev
638
654
return prev
@@ -1100,9 +1116,12 @@ pad_dot_general<1>(1);
1100
1116
"""
1101
1117
1102
1118
function compile_to_module (mod, f, args; optimize= true )
1103
- fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn (
1104
- mod, f, args, (), " main" , true
1105
- )
1119
+ fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results =
1120
+ MLIR. IR. block! (MLIR. IR. body (mod)) do
1121
+ return make_mlir_fn (
1122
+ f, args, (), " main" , true
1123
+ )
1124
+ end
1106
1125
1107
1126
concrete_seen = IdDict ()
1108
1127
@@ -1112,6 +1131,7 @@ function compile_to_module(mod, f, args; optimize=true)
1112
1131
1113
1132
if optimize
1114
1133
XLA. RunPassPipeline (
1134
+ opt_passes * " ,enzyme-batch," *
1115
1135
opt_passes *
1116
1136
" ,enzyme,arith-raise{stablehlo=true},canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math," *
1117
1137
opt_passes,
0 commit comments