Skip to content

Commit d1a18b2

Browse files
committed
Rename KernelTensorSum -> KernelIndependentSum
1 parent 668fcfd commit d1a18b2

File tree

7 files changed

+47
-43
lines changed

7 files changed

+47
-43
lines changed

docs/src/kernels.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ TransformedKernel
124124
ScaledKernel
125125
KernelSum
126126
KernelProduct
127-
KernelTensorSum
127+
KernelIndependentSum
128128
KernelTensorProduct
129129
NormalizedKernel
130130
```

src/KernelFunctions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ export LinearKernel, PolynomialKernel
1515
export RationalKernel, RationalQuadraticKernel, GammaRationalKernel
1616
export PiecewisePolynomialKernel
1717
export PeriodicKernel, NeuralNetworkKernel
18-
export KernelSum, KernelProduct, KernelTensorSum, KernelTensorProduct
18+
export KernelSum, KernelProduct, KernelIndependentSum, KernelTensorProduct
1919
export TransformedKernel, ScaledKernel, NormalizedKernel
2020
export GibbsKernel
2121
export
@@ -109,7 +109,7 @@ include("kernels/normalizedkernel.jl")
109109
include("matrix/kernelmatrix.jl")
110110
include("kernels/kernelsum.jl")
111111
include("kernels/kernelproduct.jl")
112-
include("kernels/kerneltensorsum.jl")
112+
include("kernels/kernelindependentsum.jl")
113113
include("kernels/kerneltensorproduct.jl")
114114
include("kernels/overloads.jl")
115115
include("kernels/neuralkernelnetwork.jl")

src/kernels/kerneltensorsum.jl renamed to src/kernels/kernelindependentsum.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
2-
KernelTensorSum
2+
KernelIndependentSum
33
4-
Tensor sum of kernels.
4+
Independent sum of kernels.
55
66
# Definition
77
@@ -13,42 +13,42 @@ k(x, x'; k_1, \\ldots, k_n) = \\sum_{i=1}^n k_i(x_i, x'_i).
1313
1414
# Construction
1515
16-
The simplest way to specify a `KernelTensorSum` is to use the `⊕` operator (can be typed by `\\oplus<tab>`).
17-
```jldoctest tensorproduct
16+
The simplest way to specify a `KernelIndependentSum` is to use the `⊕` operator (can be typed by `\\oplus<tab>`).
17+
```jldoctest independentsum
1818
julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5, 2);
1919
2020
julia> kernelmatrix(k1 ⊕ k2, RowVecs(X)) == kernelmatrix(k1, X[:, 1]) + kernelmatrix(k2, X[:, 2])
2121
true
2222
```
2323
24-
You can also specify a `KernelTensorSum` by providing kernels as individual arguments
24+
You can also specify a `KernelIndependentSum` by providing kernels as individual arguments
2525
or as an iterable data structure such as a `Tuple` or a `Vector`. Using a tuple or
26-
individual arguments guarantees that `KernelTensorSum` is concretely typed but might
26+
individual arguments guarantees that `KernelIndependentSum` is concretely typed but might
2727
lead to large compilation times if the number of kernels is large.
28-
```jldoctest tensorproduct
29-
julia> KernelTensorSum(k1, k2) == k1 ⊕ k2
28+
```jldoctest independentsum
29+
julia> KernelIndependentSum(k1, k2) == k1 ⊕ k2
3030
true
3131
32-
julia> KernelTensorSum((k1, k2)) == k1 ⊕ k2
32+
julia> KernelIndependentSum((k1, k2)) == k1 ⊕ k2
3333
true
3434
35-
julia> KernelTensorSum([k1, k2]) == k1 ⊕ k2
35+
julia> KernelIndependentSum([k1, k2]) == k1 ⊕ k2
3636
true
3737
```
3838
"""
39-
struct KernelTensorSum{K} <: Kernel
39+
struct KernelIndependentSum{K} <: Kernel
4040
kernels::K
4141
end
4242

43-
function KernelTensorSum(kernel::Kernel, kernels::Kernel...)
44-
return KernelTensorSum((kernel, kernels...))
43+
function KernelIndependentSum(kernel::Kernel, kernels::Kernel...)
44+
return KernelIndependentSum((kernel, kernels...))
4545
end
4646

47-
@functor KernelTensorSum
47+
@functor KernelIndependentSum
4848

49-
Base.length(kernel::KernelTensorSum) = length(kernel.kernels)
49+
Base.length(kernel::KernelIndependentSum) = length(kernel.kernels)
5050

51-
function (kernel::KernelTensorSum)(x, y)
51+
function (kernel::KernelIndependentSum)(x, y)
5252
if !((nx = length(x)) == (ny = length(y)) == (nkernels = length(kernel)))
5353
throw(
5454
DimensionMismatch(
@@ -59,46 +59,46 @@ function (kernel::KernelTensorSum)(x, y)
5959
return sum(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
6060
end
6161

62-
function validate_domain(k::KernelTensorSum, x::AbstractVector, y::AbstractVector)
62+
function validate_domain(k::KernelIndependentSum, x::AbstractVector, y::AbstractVector)
6363
return (dx = dim(x)) == (dy = dim(y)) == (nkernels = length(k)) || error(
6464
"number of kernels ($nkernels) and group of features (x=$dx), y=$dy) are not consistent",
6565
)
6666
end
6767

68-
function validate_domain(k::KernelTensorSum, x::AbstractVector)
68+
function validate_domain(k::KernelIndependentSum, x::AbstractVector)
6969
return validate_domain(k, x, x)
7070
end
7171

72-
function kernelmatrix(k::KernelTensorSum, x::AbstractVector)
72+
function kernelmatrix(k::KernelIndependentSum, x::AbstractVector)
7373
validate_domain(k, x)
7474
return mapreduce(kernelmatrix, +, k.kernels, slices(x))
7575
end
7676

77-
function kernelmatrix(k::KernelTensorSum, x::AbstractVector, y::AbstractVector)
77+
function kernelmatrix(k::KernelIndependentSum, x::AbstractVector, y::AbstractVector)
7878
validate_domain(k, x, y)
7979
return mapreduce(kernelmatrix, +, k.kernels, slices(x), slices(y))
8080
end
8181

82-
function kernelmatrix_diag(k::KernelTensorSum, x::AbstractVector)
82+
function kernelmatrix_diag(k::KernelIndependentSum, x::AbstractVector)
8383
validate_domain(k, x)
8484
return mapreduce(kernelmatrix_diag, +, k.kernels, slices(x))
8585
end
8686

87-
function kernelmatrix_diag(k::KernelTensorSum, x::AbstractVector, y::AbstractVector)
87+
function kernelmatrix_diag(k::KernelIndependentSum, x::AbstractVector, y::AbstractVector)
8888
validate_domain(k, x, y)
8989
return mapreduce(kernelmatrix_diag, +, k.kernels, slices(x), slices(y))
9090
end
9191

92-
function Base.:(==)(x::KernelTensorSum, y::KernelTensorSum)
92+
function Base.:(==)(x::KernelIndependentSum, y::KernelIndependentSum)
9393
return (
9494
length(x.kernels) == length(y.kernels) &&
9595
all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
9696
)
9797
end
9898

99-
Base.show(io::IO, kernel::KernelTensorSum) = printshifted(io, kernel, 0)
99+
Base.show(io::IO, kernel::KernelIndependentSum) = printshifted(io, kernel, 0)
100100

101-
function printshifted(io::IO, kernel::KernelTensorSum, shift::Int)
101+
function printshifted(io::IO, kernel::KernelIndependentSum, shift::Int)
102102
print(io, "Tensor sum of ", length(kernel), " kernels:")
103103
for k in kernel.kernels
104104
print(io, "\n")

src/kernels/overloads.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ for (M, op, T) in (
44
(:Base, :+, :KernelSum),
55
(:Base, :*, :KernelProduct),
66
(:TensorCore, :tensor, :KernelTensorProduct),
7-
(:KernelFunctions, :, :KernelTensorSum),
7+
(:KernelFunctions, :, :KernelIndependentSum),
88
)
99
@eval begin
1010
$M.$op(k1::Kernel, k2::Kernel) = $T(k1, k2)

test/kernels/kerneltensorsum.jl renamed to test/kernels/kernelindependentsum.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
# kernels
99
k1 = SqExponentialKernel()
1010
k2 = ExponentialKernel()
11-
kernel1 = KernelTensorSum(k1, k2)
12-
kernel2 = KernelTensorSum([k1, k2])
11+
kernel1 = KernelIndependentSum(k1, k2)
12+
kernel2 = KernelIndependentSum([k1, k2])
1313

1414
@test kernel1 == kernel2
15-
@test kernel1.kernels == (k1, k2) === KernelTensorSum((k1, k2)).kernels
15+
@test kernel1.kernels == (k1, k2) === KernelIndependentSum((k1, k2)).kernels
1616
for (_k1, _k2) in Iterators.product(
17-
(k1, KernelTensorSum((k1,)), KernelTensorSum([k1])),
18-
(k2, KernelTensorSum((k2,)), KernelTensorSum([k2])),
17+
(k1, KernelIndependentSum((k1,)), KernelIndependentSum([k1])),
18+
(k2, KernelIndependentSum((k2,)), KernelIndependentSum([k2])),
1919
)
2020
@test kernel1 == _k1 _k2
2121
end
@@ -39,21 +39,21 @@
3939
TestUtils.test_interface(kernel1, ColVecs{Float64})
4040
TestUtils.test_interface(kernel1, RowVecs{Float64})
4141
TestUtils.test_interface(
42-
KernelTensorSum(WhiteKernel(), ConstantKernel(; c=1.1)), ColVecs{String}
42+
KernelIndependentSum(WhiteKernel(), ConstantKernel(; c=1.1)), ColVecs{String}
4343
)
4444
test_ADs(
45-
x -> KernelTensorSum(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))),
45+
x -> KernelIndependentSum(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))),
4646
rand(1);
4747
dims=[2, 2],
4848
)
4949
types = [ColVecs{Float64,Matrix{Float64}}, RowVecs{Float64,Matrix{Float64}}]
5050
test_interface_ad_perf(2.1, StableRNG(123456), types) do c
51-
KernelTensorSum(SqExponentialKernel(), LinearKernel(; c=c))
51+
KernelIndependentSum(SqExponentialKernel(), LinearKernel(; c=c))
5252
end
53-
test_params(KernelTensorSum(k1, k2), (k1, k2))
53+
test_params(KernelIndependentSum(k1, k2), (k1, k2))
5454

5555
@testset "single kernel" begin
56-
kernel = KernelTensorSum(k1)
56+
kernel = KernelIndependentSum(k1)
5757
@test length(kernel) == 1
5858

5959
@testset "eval" begin

test/kernels/overloads.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,13 @@
55
k2 = SqExponentialKernel()
66
k3 = RationalQuadraticKernel()
77

8-
for (op, T) in
9-
((+, KernelSum), (*, KernelProduct), (, KernelTensorProduct), (, KernelTensorSum))
10-
if T === KernelTensorProduct || T === KernelTensorSum
8+
for (op, T) in (
9+
(+, KernelSum),
10+
(*, KernelProduct),
11+
(, KernelTensorProduct),
12+
(, KernelIndependentSum),
13+
)
14+
if T === KernelTensorProduct || T === KernelIndependentSum
1115
v2_1 = rand(rng, 2)
1216
v2_2 = rand(rng, 2)
1317
v3_1 = rand(rng, 3)

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ include("test_utils.jl")
125125
include("kernels/kernelproduct.jl")
126126
include("kernels/kernelsum.jl")
127127
include("kernels/kerneltensorproduct.jl")
128-
include("kernels/kerneltensorsum.jl")
128+
include("kernels/kernelindependentsum.jl")
129129
include("kernels/overloads.jl")
130130
include("kernels/scaledkernel.jl")
131131
include("kernels/transformedkernel.jl")

0 commit comments

Comments
 (0)