Skip to content

Commit fac0e48

Browse files
committed
Simplify data passed across functions
1 parent 4cec625 commit fac0e48

File tree

1 file changed

+27
-26
lines changed

1 file changed

+27
-26
lines changed

src/code_gen/gen_c_code.jl

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,11 @@ function declare_var(lang::LangC, val, id, type)
5555
end
5656
function declare_coeff(lang::LangC, val::T, id) where T
5757
# Use real type regardless of T.
58-
is_real_val = isreal(val)
59-
actual_val = is_real_val ? real(val) : val
58+
actual_val = isreal(val) ? real(val) : val
6059
(blas_t, blas_prefix) = get_blas_type(lang, typeof(actual_val))
6160
variable_string = "coeff_$id"
6261
dec_init_string = declare_var(lang, actual_val, variable_string, blas_t)
63-
return (is_real_val, variable_string, dec_init_string)
62+
return (val, variable_string, dec_init_string)
6463
end
6564

6665
function assignment_string(::LangC, var, real_part)
@@ -305,7 +304,7 @@ function function_end(lang::LangC, graph, mem)
305304
end
306305

307306
function add_lincomb_body(code, lang::LangC, T, nodemem,
308-
coeff_names, coeff_is_real, parent_mems)
307+
coeff_names, coeff_values, parent_mems)
309308
rhs = join((coeff_names .* " * ") .* ("*(" .* parent_mems .* " + i)"),
310309
" + \n ")
311310
for_body = "*(" * nodemem * " + i) = " * (isempty(rhs) ? "0" : rhs) * ";"
@@ -321,13 +320,14 @@ fma_function_name(T::Type{Complex{Float32}}, coeff_real) =
321320
fma_function_name(T::Type{Complex{Float64}}, coeff_real) =
322321
coeff_real ? "fma_MKL_Complex16_double" : "fma_MKL_Complex16"
323322
function add_lincomb_body(code, lang::LangC_MKL, T::Type{Complex{S}}, nodemem,
324-
coeff_names, coeff_is_real,
323+
coeff_names, coeff_values,
325324
parent_mems) where S <: Real
326-
for (it, (coeff, coeff_real, parent)) in
327-
enumerate(zip(coeff_names, coeff_is_real, parent_mems))
328-
statement = (it == 1 ?
329-
prod_function_name(T, coeff_real) :
330-
fma_function_name(T, coeff_real)) *
325+
for (it, (coeff, coeff_value, parent)) in
326+
enumerate(zip(coeff_names, coeff_values, parent_mems))
327+
operation = it == 1 ?
328+
prod_function_name(T, isreal(coeff_value)) :
329+
fma_function_name(T, isreal(coeff_value))
330+
statement = operation *
331331
"(" * nodemem * " + i, " *
332332
reference_value(lang, T, coeff) * ", " *
333333
parent * " + i);"
@@ -336,17 +336,21 @@ function add_lincomb_body(code, lang::LangC_MKL, T::Type{Complex{S}}, nodemem,
336336
end
337337

338338
function add_lincomb_identity_body(code, lang::LangC, T,
339-
nodemem, coeff_id, coeff_real)
340-
push_code!(code, "*(" * nodemem * " + i) += " * coeff_id * ";", ind_lvl = 2)
339+
nodemem, coeff_id, coeff_value)
340+
push_code!(
341+
code,
342+
"*(" * nodemem * " + i) += " * coeff_id * ";",
343+
ind_lvl = 2
344+
)
341345
end
342346

343347
acc_function_name(T::Type{Complex{Float32}}, coeff_real) =
344348
coeff_real ? "acc_MKL_Complex8_float" : "acc_MKL_Complex8"
345349
acc_function_name(T::Type{Complex{Float64}}, coeff_real) =
346350
coeff_real ? "acc_MKL_Complex16_double" : "acc_MKL_Complex16"
347351
function add_lincomb_identity_body(code, lang::LangC_MKL, T::Type{Complex{S}},
348-
nodemem, coeff_id, coeff_real) where S <: Real
349-
statement = acc_function_name(T, coeff_real) * "(" * nodemem * " + i, " *
352+
nodemem, coeff_id, coeff_value) where S <: Real
353+
statement = acc_function_name(T, isreal(coeff_value)) * "(" * nodemem * " + i, " *
350354
reference_value(lang, T, coeff_id) * ");"
351355
push_code!(code, statement, ind_lvl = 2)
352356
end
@@ -483,7 +487,7 @@ function execute_operation!(lang::LangC, T, graph, node, dealloc_list, mem)
483487

484488
# Set coefficients.
485489
coeff_names = Vector()
486-
coeff_is_real = Vector()
490+
coeff_values = Vector()
487491
parent_mems = Vector()
488492
id_coefficient = 0
489493
counter = 1
@@ -497,12 +501,12 @@ function execute_operation!(lang::LangC, T, graph, node, dealloc_list, mem)
497501
id_coefficient += v
498502
else
499503
# Coefficient of other nodes.
500-
(coeff_real, coeff_i, coeff_i_code) = declare_coeff(lang, v,
504+
(coeff_value, coeff_i, coeff_i_code) = declare_coeff(lang, v,
501505
"$node" * "_" * "$counter")
502506
counter += 1
503507
push_code!(code, coeff_i_code)
504508
push!(coeff_names, coeff_i)
505-
push!(coeff_is_real, coeff_real)
509+
push!(coeff_values, coeff_value)
506510
push!(parent_mems, get_slot_name(mem, n))
507511
end
508512
end
@@ -532,9 +536,9 @@ function execute_operation!(lang::LangC, T, graph, node, dealloc_list, mem)
532536
splice!(coeff_names, index)
533537
pushfirst!(coeff_names, coeff_name)
534538

535-
coeff_real = coeff_is_real[index]
536-
splice!(coeff_is_real, index)
537-
pushfirst!(coeff_is_real, coeff_real)
539+
coeff_value = coeff_values[index]
540+
splice!(coeff_values, index)
541+
pushfirst!(coeff_values, coeff_value)
538542
end
539543

540544
# Write the linear combination.
@@ -543,19 +547,16 @@ function execute_operation!(lang::LangC, T, graph, node, dealloc_list, mem)
543547
else
544548
push_code!(code, "for (size_t i = 0; i < n * n; i++) {")
545549
add_lincomb_body(code, lang, T, nodemem,
546-
coeff_names, coeff_is_real, parent_mems)
550+
coeff_names, coeff_value, parent_mems)
547551
push_code!(code, "}")
548552
end
549553

550554
if id_coefficient != 0
551-
(coeff_real, coeff_id, coeff_id_code) = declare_coeff(lang,
555+
(coeff_value, coeff_id, coeff_id_code) = declare_coeff(lang,
552556
id_coefficient, "$node" * "_0")
553557
push_code!(code, coeff_id_code)
554-
# for statement in coeff_id_code
555-
# push_code!(code, statement)
556-
# end
557558
push_code!(code, "for (size_t i = 0; i < n * n; i += n + 1) {")
558-
add_lincomb_identity_body(code, lang, T, nodemem, coeff_id, coeff_real)
559+
add_lincomb_identity_body(code, lang, T, nodemem, coeff_id, coeff_value)
559560
push_code!(code, "}")
560561
end
561562

0 commit comments

Comments
 (0)