33
44Kernel derived from `k` for which inputs are transformed via a [`Transform`](@ref) `t`.
55
6- It is preferred to create kernels with input transformations with [`transform`](@ref)
7- instead of `TransformedKernel` directly since [`transform`](@ref) allows optimized
8- implementations for specific kernels and transformations.
6+ The preferred way to create kernels with input transformations is to use the composition
7+ operator [`∘`](@ref) or its alias `compose` instead of `TransformedKernel` directly since
8+ this allows optimized implementations for specific kernels and transformations.
99
10- # Definition
11-
12- For inputs ``x, x'``, the transformed kernel ``\\ widetilde{k}`` derived from kernel ``k`` by
13- input transformation ``t`` is defined as
14- ```math
15- \\ widetilde{k}(x, x'; k, t) = k\\ big(t(x), t(x')\\ big).
16- ```
10+ See also: [`∘`](@ref)
1711"""
1812struct TransformedKernel{Tk<: Kernel ,Tr<: Transform } <: Kernel
1913 kernel:: Tk
4236_scale (t:: ScaleTransform , metric, x, y) = evaluate (metric, t (x), t (y))
4337
4438"""
45- transform(k::Kernel, t::Transform)
39+ kernel ∘ transform
40+ ∘(kernel, transform)
41+ compose(kernel, transform)
4642
47- Create a [`TransformedKernel`](@ref) for kernel `k` and transform `t`.
48- """
49- transform (k:: Kernel , t:: Transform ) = TransformedKernel (k, t)
50- function transform (k:: TransformedKernel , t:: Transform )
51- return TransformedKernel (k. kernel, t ∘ k. transform)
52- end
43+ Compose a `kernel` with a transformation `transform` of its inputs.
5344
54- """
55- transform(k::Kernel, ρ::Real)
45+ The prefix forms support chains of multiple transformations:
46+ `∘(kernel, transform1, transform2) = kernel ∘ transform1 ∘ transform2`.
5647
57- Create a [`TransformedKernel`](@ref) for kernel `k` and inverse lengthscale `ρ`.
58- """
59- transform (k:: Kernel , ρ:: Real ) = transform (k, ScaleTransform (ρ))
48+ # Definition
6049
61- """
62- transform(k::Kernel, ρ::AbstractVector)
50+ For inputs ``x, x'``, the transformed kernel ``\\ widetilde{k}`` derived from kernel ``k`` by
51+ input transformation ``t`` is defined as
52+ ```math
53+ \\ widetilde{k}(x, x'; k, t) = k\\ big(t(x), t(x')\\ big).
54+ ```
6355
64- Create a [`TransformedKernel`](@ref) for kernel `k` and inverse lengthscales `ρ`.
65- """
66- transform (k:: Kernel , ρ:: AbstractVector ) = transform (k, ARDTransform (ρ))
56+ # Examples
57+
58+ ```jldoctest
59+ julia> (SqExponentialKernel() ∘ ScaleTransform(0.5))(0, 2) == exp(-0.5)
60+ true
6761
68- kernel (κ) = κ. kernel
62+ julia> ∘(ExponentialKernel(), ScaleTransform(2), ScaleTransform(0.5))(1, 2) == exp(-1)
63+ true
64+ ```
65+
66+ See also: [`TransformedKernel`](@ref)
67+ """
68+ Base.:∘ (k:: Kernel , t:: Transform ) = TransformedKernel (k, t)
69+ Base.:∘ (k:: TransformedKernel , t:: Transform ) = TransformedKernel (k. kernel, k. transform ∘ t)
6970
7071Base. show (io:: IO , κ:: TransformedKernel ) = printshifted (io, κ, 0 )
7172
@@ -87,13 +88,13 @@ function kernelmatrix_diag!(
8788end
8889
8990function kernelmatrix! (K:: AbstractMatrix , κ:: TransformedKernel , x:: AbstractVector )
90- return kernelmatrix! (K, kernel (κ) , _map (κ. transform, x))
91+ return kernelmatrix! (K, κ . kernel , _map (κ. transform, x))
9192end
9293
9394function kernelmatrix! (
9495 K:: AbstractMatrix , κ:: TransformedKernel , x:: AbstractVector , y:: AbstractVector
9596)
96- return kernelmatrix! (K, kernel (κ) , _map (κ. transform, x), _map (κ. transform, y))
97+ return kernelmatrix! (K, κ . kernel , _map (κ. transform, x), _map (κ. transform, y))
9798end
9899
99100function kernelmatrix_diag (κ:: TransformedKernel , x:: AbstractVector )
@@ -105,9 +106,9 @@ function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::Abstract
105106end
106107
107108function kernelmatrix (κ:: TransformedKernel , x:: AbstractVector )
108- return kernelmatrix (kernel (κ) , _map (κ. transform, x))
109+ return kernelmatrix (κ . kernel , _map (κ. transform, x))
109110end
110111
111112function kernelmatrix (κ:: TransformedKernel , x:: AbstractVector , y:: AbstractVector )
112- return kernelmatrix (kernel (κ) , _map (κ. transform, x), _map (κ. transform, y))
113+ return kernelmatrix (κ . kernel , _map (κ. transform, x), _map (κ. transform, y))
113114end
0 commit comments