@@ -211,6 +211,8 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
211211 sequences = [at_C , at_D ],
212212 outputs_info = [st0 , et0 , it0 , logp_c , logp_d ],
213213 non_sequences = [beta , gamma , delta ],
214+ # multi-output Elemwise not supported in NUMBA
215+ mode = get_mode ("NUMBA" ).excluding ("fusion" ),
214216 )
215217 st .name = "S_t"
216218 et .name = "E_t"
@@ -321,6 +323,8 @@ def power_of_2(previous_power, max_value):
321323 outputs_info = at .constant (1.0 ),
322324 non_sequences = max_value ,
323325 n_steps = 1024 ,
326+ # multi-output Elemwise not supported in NUMBA
327+ mode = get_mode ("NUMBA" ).excluding ("fusion" ),
324328 )
325329
326330 out_fg = FunctionGraph ([max_value ], [values ])
@@ -370,6 +374,8 @@ def f_pow2(x_tm2, x_tm1):
370374 state_val = np .array ([1.0 , 2.0 ])
371375
372376 numba_mode = get_mode ("NUMBA" ).including ("scan_save_mem" )
377+ # multi-output Elemwise not supported in NUMBA
378+ numba_mode = numba_mode .excluding ("fusion" )
373379 py_mode = Mode ("py" ).including ("scan_save_mem" )
374380
375381 out_fg = FunctionGraph ([init_x , n_steps ], [output ])
@@ -409,6 +415,8 @@ def inner_fct(seq, state_old, state_current):
409415 g_outs = grad (out .sum (), [seq , init_x ])
410416
411417 numba_mode = get_mode ("NUMBA" ).including ("scan_save_mem" )
418+ # multi-output Elemwise not supported in NUMBA
419+ numba_mode = numba_mode .excluding ("fusion" )
412420 py_mode = Mode ("py" ).including ("scan_save_mem" )
413421
414422 out_fg = FunctionGraph ([seq , init_x ], g_outs )
0 commit comments