Skip to content

Commit 6a923af

Browse files
eliascarvjuliohm
andauthored
Define == and isapprox for transforms (#14)
* Define '==' and 'isapprox' for transforms * Update CI.yml --------- Co-authored-by: Júlio Hoffimann <[email protected]>
1 parent c7ddf28 commit 6a923af

File tree

3 files changed

+28
-1
lines changed

3 files changed

+28
-1
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
version:
20-
- '1.6'
20+
- '1.9'
2121
- '1'
2222
os:
2323
- ubuntu-latest

src/interface.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,19 @@ reapply(transform::Transform, object, cache) = apply(transform, object) |> first
123123

124124
(transform::Transform)(object) = apply(transform, object) |> first
125125

126+
Base.:(==)(t₁::Transform, t₂::Transform) = nameof(typeof(t₁)) == nameof(typeof(t₂)) && parameters(t₁) == parameters(t₂)
127+
128+
Base.isapprox(t₁::Transform, t₂::Transform; kwargs...) =
129+
nameof(typeof(t₁)) == nameof(typeof(t₂)) && _isapprox(parameters(t₁), parameters(t₂); kwargs...)
130+
131+
_isapprox(tup₁::NamedTuple, tup₂::NamedTuple; kwargs...) =
132+
propertynames(tup₁) == propertynames(tup₂) && _isapprox(Tuple(tup₁), Tuple(tup₂); kwargs...)
133+
134+
_isapprox(tup₁::Tuple, tup₂::Tuple; kwargs...) =
135+
length(tup₁) == length(tup₂) && all(_isapprox(x₁, x₂; kwargs...) for (x₁, x₂) in zip(tup₁, tup₂))
136+
137+
_isapprox(x₁, x₂; kwargs...) = isapprox(x₁, x₂; kwargs...)
138+
126139
# -----------
127140
# IO METHODS
128141
# -----------

test/runtests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,20 @@ using Test
4848
@test T[begin] == TestTransform()
4949
@test T[end] == Identity()
5050

51+
# equality and approximation
52+
struct TestParamTransform <: TransformsBase.Transform
53+
param::Float64
54+
end
55+
TransformsBase.apply(t::TestParamTransform, x) = x * t.param, nothing
56+
TransformsBase.parameters(t::TestParamTransform) = (; param=t.param)
57+
T1 = TestParamTransform(1.0)
58+
T2 = TestParamTransform(1.0f0)
59+
T3 = TestTransform()
60+
@test T1 == T2
61+
@test T1 T3
62+
@test T1 T2
63+
@test T1 T3
64+
5165
T1 = Identity()
5266
T2 = TestTransform()
5367
T3 = TestTransform() TestTransform()

0 commit comments

Comments
 (0)