Skip to content

Commit a8ec53a

Browse files
authored
Merge pull request #75 from matrixfunctions/multilincomb_refactor
More than two matrices in lincomb
2 parents 6d103c0 + 6a7c682 commit a8ec53a

22 files changed

+290
-540
lines changed

src/GraphMatFun.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ include("generators/horner.jl")
2525
include("generators/newton_schulz.jl")
2626
include("generators/exp.jl")
2727
include("generators/sastre.jl")
28+
include("generators/bigraph.jl")
2829

2930
# Error bounds
3031
include("error_bounds.jl")

src/code_gen/gen_c_code.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ function get_blas_type(::LangC_OpenBLAS, T::Type{Complex{Float64}})
107107
return ("openblas_complex_double", "z")
108108
end
109109

110+
111+
function preprocess_codegen(graph, lang::LangC)
112+
(g,_)=graph_bigraph(graph)
113+
return g # Lang C only supports bigraphs
114+
end
115+
110116
function function_definition(lang::LangC, graph, T, funname, precomputed_nodes)
111117
(blas_type, blas_prefix) = get_blas_type(lang, T)
112118
code = init_code(lang)

src/code_gen/gen_code.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
include("gen_code_mem.jl")
22
include("gen_code_snippets.jl")
33

4-
include("multilincomb.jl");
5-
64
include("gen_c_code.jl")
75
include("gen_julia_code.jl")
86
include("gen_matlab_code.jl")

src/code_gen/gen_julia_code.jl

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ export LangJulia
44
struct LangJulia
55
overwrite_input::Any # Overwrite input
66
inline::Any
7-
dot_fusing::Any # Allow dot fusion
87
axpby_header::Any
98
alloc_function
109
only_overwrite
@@ -14,18 +13,14 @@ end
1413
default_alloc_function(k)="similar(A,T)"
1514

1615
"""
17-
LangJulia(overwrite_input=true,inline=true,dot_fusing=true,axpby_header=:auto,alloc_function,only_overwrite=false)
16+
LangJulia(overwrite_input=true,inline=true,axpby_header=:auto,alloc_function,only_overwrite=false)
1817
19-
Code generation in julia language, with optional overwriting of input, inlining
20-
the function and optional usage of dot fusion. The `axpby_header` specifies if axpby function calls should be included in the beginning of the file. The parameter `alloc_function` is a function of three parameters `alloc_function(k)` where `k` is the memory slot (default is `alloc_function(k)=similar(A,T)`). The `only_overwrite` specifies if `f` should be created if the overwrite funtion `f!` contains the actual code.
18+
Code generation in julia language, with optional overwriting of input. The `axpby_header` specifies if axpby function calls should be included in the beginning of the file. The parameter `alloc_function` is a function of three parameters `alloc_function(k)` where `k` is the memory slot (default is `alloc_function(k)=similar(A,T)`). The `only_overwrite` specifies if `f` should be created if the overwrite funtion `f!` contains the actual code.
2119
"""
22-
LangJulia() = LangJulia(true, true, true, :auto, default_alloc_function,false,"ValueOne","matfun_axpby!")
20+
LangJulia() = LangJulia(true, true, :auto, default_alloc_function,false,"ValueOne","matfun_axpby!")
2321
LangJulia(overwrite_input) = LangJulia(overwrite_input, true, true, :auto, default_alloc_function,false,"ValueOne","matfun_axpby!")
24-
function LangJulia(overwrite_input, inline)
25-
return LangJulia(overwrite_input, inline, true, :auto, default_alloc_function,false,"ValueOne","matfun_axpby!")
26-
end
27-
function LangJulia(overwrite_input, inline, dot_fusing; value_one_name="ValueOne",axpby_name="matfun_axpby!")
28-
return LangJulia(overwrite_input, inline, dot_fusing, :auto, default_alloc_function,false,value_one_name,axpby_name)
22+
function LangJulia(overwrite_input, inline ; value_one_name="ValueOne",axpby_name="matfun_axpby!")
23+
return LangJulia(overwrite_input, inline, :auto, default_alloc_function,false,value_one_name,axpby_name)
2924
end
3025

3126
# Language specific operations.
@@ -60,11 +55,7 @@ function assign_coeff_basic(lang::LangJulia, v, i)
6055
end
6156

6257
function preprocess_codegen(graph, lang::LangJulia)
63-
if (lang.dot_fusing)
64-
return MultiLincombCompgraph(graph) # Merge many lincombs for dot fusion
65-
else
66-
return graph
67-
end
58+
return graph # Merge many lincombs for dot fusion
6859
end
6960

7061
# Code generation.
@@ -124,7 +115,7 @@ function function_definition(
124115
code = init_code(lang)
125116
axpby_header = lang.axpby_header
126117
if (lang.axpby_header == :auto)
127-
axpby_header = !lang.dot_fusing
118+
axpby_header = false
128119
end
129120

130121
push_code!(code, "using LinearAlgebra", ind_lvl = 0)
@@ -297,7 +288,7 @@ end
297288
function execute_operation!(
298289
lang::LangJulia,
299290
T,
300-
graph::MultiLincombCompgraph,
291+
graph,
301292
node,
302293
dealloc_list,
303294
mem,
@@ -399,10 +390,6 @@ function execute_operation!(
399390
end
400391
end
401392

402-
function execute_operation!(lang::LangJulia, T, graph, node, dealloc_list, mem)
403-
return execute_operation_basic!(lang, T, graph, node, dealloc_list, mem)
404-
end
405-
406393
# The general base case. Separated for dispatch.
407394
function execute_operation_basic!(
408395
lang::LangJulia,

src/code_gen/gen_matlab_code.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,26 @@ end
6161
function execute_operation!(lang::LangMatlab, T, graph, node, dealloc_list, mem)
6262
op = graph.operations[node]
6363
parent1 = graph.parents[node][1]
64-
parent2 = graph.parents[node][2]
6564

6665
code = init_code(lang)
6766
push_comment!(code, "Computing $node with operation: $op")
6867
# Needs to be updated
6968
if op == :mult
69+
parent2 = graph.parents[node][2]
7070
push_code!(code, "$node = $parent1 * $parent2;")
7171
elseif op == :ldiv
72+
parent2 = graph.parents[node][2]
7273
push_code!(code, "$node = $parent1 \\ $parent2;")
7374
elseif op == :lincomb
74-
(coeff1, coeff1_code) = assign_coeff(lang, graph.coeffs[node][1], 1)
75-
push_code!(code, "$coeff1_code;")
76-
(coeff2, coeff2_code) = assign_coeff(lang, graph.coeffs[node][2], 2)
77-
push_code!(code, "$coeff2_code;")
78-
push_code!(code, "$node = $coeff1*$parent1 + $coeff2*$parent2;")
75+
76+
coeff_names=Vector();
77+
for (i,coeff) = enumerate(graph.coeffs[node])
78+
(coeff, coeff_code) = assign_coeff(lang, graph.coeffs[node][i], i)
79+
push_code!(code, "$coeff_code;")
80+
push!(coeff_names,coeff)
81+
end
82+
sum_code=join((coeff_names.*"*") .* string.(graph.parents[node])," + ")
83+
push_code!(code, "$node = $sum_code;")
7984
end
8085
return (code, "$node")
8186
end

src/code_gen/multilincomb.jl

Lines changed: 0 additions & 178 deletions
This file was deleted.

0 commit comments

Comments
 (0)