Skip to content

Commit db56c16

Browse files
martinholtersfredrikekre
authored andcommitted
Add functions with dims keyword argument (#518)
* Add functions with `dims` keyword argument Add (unexported) `Compat.accumulate`, `Compat.accumulate!`, `Compat.all`, `Compat.any`, `Compat.cor`, `Compat.cov`, `Compat.cumprod`, `Compat.cumprod!`, `Compat.cumsum`, `Compat.cumsum!`, `Compat.findmax`, `Compat.findmin`, `Compat.mapreduce`, `Compat.maximum`, `Compat.mean`, `Compat.median`, `Compat.minimum`, `Compat.prod`, `Compat.reduce`, `Compat.sort`, `Compat.std`, `Compat.sum`, `Compat.var`, and `Compat.varm` with `dims` keyword argument. * add a version check for some tests * Use correction version bound for conditional tests. * Remove some unnecessary at-evals ...at the cost of splitting definitions between two if-blocks.
1 parent 76a456b commit db56c16

File tree

3 files changed

+218
-8
lines changed

3 files changed

+218
-8
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,13 @@ Currently, the `@compat` macro supports the following syntaxes:
285285

286286
* `Compat.mv` and `Compat.cp` with `force` keyword argument ([#26069]).
287287

288+
* `Compat.accumulate`, `Compat.accumulate!`, `Compat.all`, `Compat.any`, `Compat.cor`,
289+
`Compat.cov`, `Compat.cumprod`, `Compat.cumprod!`, `Compat.cumsum`, `Compat.cumsum!`,
290+
`Compat.findmax`, `Compat.findmin`, `Compat.mapreduce`, `Compat.maximum`, `Compat.mean`,
291+
`Compat.median`, `Compat.minimum`, `Compat.prod`, `Compat.reduce`, `Compat.sort`,
292+
`Compat.std`, `Compat.sum`, `Compat.var`, and `Compat.varm` with `dims` keyword argument ([#25989],[#26369]).
293+
294+
288295
## Renaming
289296

290297
* `Display` is now `AbstractDisplay` ([#24831]).
@@ -597,6 +604,7 @@ includes this fix. Find the minimum version from there.
597604
[#25873]: https://github.com/JuliaLang/julia/issues/25873
598605
[#25896]: https://github.com/JuliaLang/julia/issues/25896
599606
[#25959]: https://github.com/JuliaLang/julia/issues/25959
607+
[#25989]: https://github.com/JuliaLang/julia/issues/25989
600608
[#25990]: https://github.com/JuliaLang/julia/issues/25990
601609
[#25998]: https://github.com/JuliaLang/julia/issues/25998
602610
[#26069]: https://github.com/JuliaLang/julia/issues/26069
@@ -605,5 +613,6 @@ includes this fix. Find the minimum version from there.
605613
[#26156]: https://github.com/JuliaLang/julia/issues/26156
606614
[#26283]: https://github.com/JuliaLang/julia/issues/26283
607615
[#26316]: https://github.com/JuliaLang/julia/issues/26316
616+
[#26369]: https://github.com/JuliaLang/julia/issues/26369
608617
[#26436]: https://github.com/JuliaLang/julia/issues/26436
609618
[#26442]: https://github.com/JuliaLang/julia/issues/26442

src/Compat.jl

Lines changed: 75 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ if VERSION < v"0.7.0-DEV.755"
513513
# This is a hack to only add keyword signature that won't work on all julia versions.
514514
# However, since we really only need to support a few (0.5, 0.6 and early 0.7) versions
515515
# this should be good enough.
516-
let Tf = typeof(cov), Tkw = Core.Core.kwftype(Tf)
516+
let Tf = typeof(Base.cov), Tkw = Core.Core.kwftype(Tf)
517517
@eval begin
518518
@inline function _get_corrected(kws)
519519
corrected = true
@@ -527,14 +527,14 @@ if VERSION < v"0.7.0-DEV.755"
527527
return corrected::Bool
528528
end
529529
if VERSION >= v"0.6"
530-
(::$Tkw)(kws::Vector{Any}, ::$Tf, x::AbstractVector) = cov(x, _get_corrected(kws))
530+
(::$Tkw)(kws::Vector{Any}, ::$Tf, x::AbstractVector) = Base.cov(x, _get_corrected(kws))
531531
(::$Tkw)(kws::Vector{Any}, ::$Tf, X::AbstractVector, Y::AbstractVector) =
532-
cov(X, Y, _get_corrected(kws))
532+
Base.cov(X, Y, _get_corrected(kws))
533533
end
534534
(::$Tkw)(kws::Vector{Any}, ::$Tf, x::AbstractMatrix, vardim::Int) =
535-
cov(x, vardim, _get_corrected(kws))
535+
Base.cov(x, vardim, _get_corrected(kws))
536536
(::$Tkw)(kws::Vector{Any}, ::$Tf, X::AbstractVecOrMat, Y::AbstractVecOrMat,
537-
vardim::Int) = cov(X, Y, vardim, _get_corrected(kws))
537+
vardim::Int) = Base.cov(X, Y, vardim, _get_corrected(kws))
538538
end
539539
end
540540
end
@@ -948,7 +948,7 @@ end
948948
import Base: diagm
949949
function diagm(kv::Pair...)
950950
T = promote_type(map(x -> eltype(x.second), kv)...)
951-
n = mapreduce(x -> length(x.second) + abs(x.first), max, kv)
951+
n = Base.mapreduce(x -> length(x.second) + abs(x.first), max, kv)
952952
A = zeros(T, n, n)
953953
for p in kv
954954
inds = diagind(A, p.first)
@@ -1255,7 +1255,7 @@ end
12551255

12561256
@inline function start(iter::CartesianIndices)
12571257
iterfirst, iterlast = first(iter), last(iter)
1258-
if any(map(>, iterfirst.I, iterlast.I))
1258+
if Base.any(map(>, iterfirst.I, iterlast.I))
12591259
return iterlast+1
12601260
end
12611261
iterfirst
@@ -1316,7 +1316,7 @@ end
13161316
@inline function Base.getindex(iter::LinearIndices{N,R}, I::Vararg{Int, N}) where {N,R}
13171317
dims = length.(iter.indices)
13181318
#without the inbounds, this is slower than Base._sub2ind(iter.indices, I...)
1319-
@inbounds result = reshape(1:prod(dims), dims)[(I .- first.(iter.indices) .+ 1)...]
1319+
@inbounds result = reshape(1:Base.prod(dims), dims)[(I .- first.(iter.indices) .+ 1)...]
13201320
return result
13211321
end
13221322
elseif VERSION < v"0.7.0-DEV.3395"
@@ -1709,6 +1709,73 @@ if VERSION < v"0.7.0-DEV.4585"
17091709
const lowercasefirst = lcfirst
17101710
end
17111711

1712+
if VERSION < v"0.7.0-DEV.4064"
1713+
for f in (:mean, :cumsum, :cumprod, :sum, :prod, :maximum, :minimum, :all, :any, :median)
1714+
@eval begin
1715+
$f(a::AbstractArray; dims=nothing) =
1716+
dims===nothing ? Base.$f(a) : Base.$f(a, dims)
1717+
end
1718+
end
1719+
for f in (:sum, :prod, :maximum, :minimum, :all, :any, :accumulate)
1720+
@eval begin
1721+
$f(f, a::AbstractArray; dims=nothing) =
1722+
dims===nothing ? Base.$f(f, a) : Base.$f(f, a, dims)
1723+
end
1724+
end
1725+
for f in (:findmax, :findmin)
1726+
@eval begin
1727+
$f(a::AbstractVector; dims=nothing) =
1728+
dims===nothing ? Base.$f(a) : Base.$f(a, dims)
1729+
function $f(a::AbstractArray; dims=nothing)
1730+
vs, inds = dims===nothing ? Base.$f(a) : Base.$f(a, dims)
1731+
cis = CartesianIndices(a)
1732+
return (vs, map(i -> cis[i], inds))
1733+
end
1734+
end
1735+
end
1736+
for f in (:var, :std, :sort)
1737+
@eval begin
1738+
$f(a::AbstractArray; dims=nothing, kwargs...) =
1739+
dims===nothing ? Base.$f(a; kwargs...) : Base.$f(a, dims; kwargs...)
1740+
end
1741+
end
1742+
for f in (:cumsum!, :cumprod!)
1743+
@eval $f(out, a; dims=nothing) =
1744+
dims===nothing ? Base.$f(out, a) : Base.$f(out, a, dims)
1745+
end
1746+
end
1747+
if VERSION < v"0.7.0-DEV.4064"
1748+
varm(A::AbstractArray, m; dims=nothing, kwargs...) =
1749+
dims===nothing ? Base.varm(A, m; kwargs...) : Base.varm(A, m, dims; kwargs...)
1750+
if VERSION < v"0.7.0-DEV.755"
1751+
cov(a::AbstractMatrix; dims=1, corrected=true) = Base.cov(a, dims, corrected)
1752+
cov(a::AbstractVecOrMat, b::AbstractVecOrMat; dims=1, corrected=true) =
1753+
Base.cov(a, b, dims, corrected)
1754+
else
1755+
cov(a::AbstractMatrix; dims=nothing, kwargs...) =
1756+
dims===nothing ? Base.cov(a; kwargs...) : Base.cov(a, dims; kwargs...)
1757+
cov(a::AbstractVecOrMat, b::AbstractVecOrMat; dims=nothing, kwargs...) =
1758+
dims===nothing ? Base.cov(a, b; kwargs...) : Base.cov(a, b, dims; kwargs...)
1759+
end
1760+
cor(a::AbstractMatrix; dims=nothing) = dims===nothing ? Base.cor(a) : Base.cor(a, dims)
1761+
cor(a::AbstractVecOrMat, b::AbstractVecOrMat; dims=nothing) =
1762+
dims===nothing ? Base.cor(a, b) : Base.cor(a, b, dims)
1763+
mapreduce(f, op, a::AbstractArray; dims=nothing) =
1764+
dims===nothing ? Base.mapreduce(f, op, a) : Base.mapreducedim(f, op, a, dims)
1765+
mapreduce(f, op, v0, a::AbstractArray; dims=nothing) =
1766+
dims===nothing ? Base.mapreduce(f, op, v0, a) : Base.mapreducedim(f, op, a, dims, v0)
1767+
reduce(op, a::AbstractArray; dims=nothing) =
1768+
dims===nothing ? Base.reduce(op, a) : Base.reducedim(op, a, dims)
1769+
reduce(op, v0, a::AbstractArray; dims=nothing) =
1770+
dims===nothing ? Base.reduce(op, v0, a) : Base.reducedim(op, a, dims, v0)
1771+
accumulate!(op, out, a; dims=nothing) =
1772+
dims===nothing ? Base.accumulate!(op, out, a) : Base.accumulate!(op, out, a, dims)
1773+
end
1774+
if VERSION < v"0.7.0-DEV.4534"
1775+
reverse(a::AbstractArray; dims=nothing) =
1776+
dims===nothing ? Base.reverse(a) : Base.flipdim(a, dims)
1777+
end
1778+
17121779
include("deprecated.jl")
17131780

17141781
end # module Compat

test/runtests.jl

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1535,4 +1535,138 @@ end
15351535
@test uppercasefirst("qwerty") == "Qwerty"
15361536
@test lowercasefirst("Qwerty") == "qwerty"
15371537

1538+
# 0.7.0-DEV.4064
1539+
# some tests are behind a version check below because Julia gave
1540+
# the wrong result between 0.7.0-DEV.3262 and 0.7.0-DEV.4646
1541+
# see https://github.com/JuliaLang/julia/issues/26488
1542+
Issue26488 = VERSION < v"0.7.0-DEV.3262" || VERSION >= v"0.7.0-DEV.4646"
1543+
@test Compat.mean([1 2; 3 4]) == 2.5
1544+
@test Compat.mean([1 2; 3 4], dims=1) == [2 3]
1545+
@test Compat.mean([1 2; 3 4], dims=2) == hcat([1.5; 3.5])
1546+
@test Compat.cumsum([1 2; 3 4], dims=1) == [1 2; 4 6]
1547+
@test Compat.cumsum([1 2; 3 4], dims=2) == [1 3; 3 7]
1548+
@test Compat.cumprod([1 2; 3 4], dims=1) == [1 2; 3 8]
1549+
@test Compat.cumprod([1 2; 3 4], dims=2) == [1 2; 3 12]
1550+
@test Compat.sum([1 2; 3 4]) == 10
1551+
@test Compat.sum([1 2; 3 4], dims=1) == [4 6]
1552+
@test Compat.sum([1 2; 3 4], dims=2) == hcat([3; 7])
1553+
@test Compat.sum(x -> x+1, [1 2; 3 4]) == 14
1554+
Issue26488 && @test Compat.sum(x -> x+1, [1 2; 3 4], dims=1) == [6 8]
1555+
Issue26488 && @test Compat.sum(x -> x+1, [1 2; 3 4], dims=2) == hcat([5; 9])
1556+
@test Compat.prod([1 2; 3 4]) == 24
1557+
@test Compat.prod([1 2; 3 4], dims=1) == [3 8]
1558+
@test Compat.prod([1 2; 3 4], dims=2) == hcat([2; 12])
1559+
@test Compat.prod(x -> x+1, [1 2; 3 4]) == 120
1560+
Issue26488 && @test Compat.prod(x -> x+1, [1 2; 3 4], dims=1) == [8 15]
1561+
Issue26488 && @test Compat.prod(x -> x+1, [1 2; 3 4], dims=2) == hcat([6; 20])
1562+
@test Compat.maximum([1 2; 3 4]) == 4
1563+
@test Compat.maximum([1 2; 3 4], dims=1) == [3 4]
1564+
@test Compat.maximum([1 2; 3 4], dims=2) == hcat([2; 4])
1565+
@test Compat.maximum(x -> x+1, [1 2; 3 4]) == 5
1566+
@test Compat.maximum(x -> x+1, [1 2; 3 4], dims=1) == [4 5]
1567+
@test Compat.maximum(x -> x+1, [1 2; 3 4], dims=2) == hcat([3; 5])
1568+
@test Compat.minimum([1 2; 3 4]) == 1
1569+
@test Compat.minimum([1 2; 3 4], dims=1) == [1 2]
1570+
@test Compat.minimum([1 2; 3 4], dims=2) == hcat([1; 3])
1571+
@test Compat.minimum(x -> x+1, [1 2; 3 4]) == 2
1572+
@test Compat.minimum(x -> x+1, [1 2; 3 4], dims=1) == [2 3]
1573+
@test Compat.minimum(x -> x+1, [1 2; 3 4], dims=2) == hcat([2; 4])
1574+
@test Compat.all([true false; true false]) == false
1575+
@test Compat.all([true false; true false], dims=1) == [true false]
1576+
@test Compat.all([true false; true false], dims=2) == hcat([false; false])
1577+
@test Compat.all(isodd, [1 2; 3 4]) == false
1578+
@test Compat.all(isodd, [1 2; 3 4], dims=1) == [true false]
1579+
@test Compat.all(isodd, [1 2; 3 4], dims=2) == hcat([false; false])
1580+
@test Compat.any([true false; true false]) == true
1581+
@test Compat.any([true false; true false], dims=1) == [true false]
1582+
@test Compat.any([true false; true false], dims=2) == hcat([true; true])
1583+
@test Compat.any(isodd, [1 2; 3 4]) == true
1584+
@test Compat.any(isodd, [1 2; 3 4], dims=1) == [true false]
1585+
@test Compat.any(isodd, [1 2; 3 4], dims=2) == hcat([true; true])
1586+
@test Compat.findmax([3, 2, 7, 4]) == (7, 3)
1587+
@test Compat.findmax([3, 2, 7, 4], dims=1) == ([7], [3])
1588+
@test Compat.findmax([1 2; 3 4], dims=1) == ([3 4], [CartesianIndex(2, 1) CartesianIndex(2, 2)])
1589+
@test Compat.findmax([1 2; 3 4]) == (4, CartesianIndex(2, 2))
1590+
@test Compat.findmax([1 2; 3 4], dims=1) == ([3 4], [CartesianIndex(2, 1) CartesianIndex(2, 2)])
1591+
@test Compat.findmax([1 2; 3 4], dims=2) == (hcat([2; 4]), hcat([CartesianIndex(1, 2); CartesianIndex(2, 2)]))
1592+
@test Compat.findmin([3, 2, 7, 4]) == (2, 2)
1593+
@test Compat.findmin([3, 2, 7, 4], dims=1) == ([2], [2])
1594+
@test Compat.findmin([1 2; 3 4]) == (1, CartesianIndex(1, 1))
1595+
@test Compat.findmin([1 2; 3 4], dims=1) == ([1 2], [CartesianIndex(1, 1) CartesianIndex(1, 2)])
1596+
@test Compat.findmin([1 2; 3 4], dims=2) == (hcat([1; 3]), hcat([CartesianIndex(1, 1); CartesianIndex(2, 1)]))
1597+
@test Compat.varm([1 2; 3 4], -1) == 18
1598+
@test Compat.varm([1 2; 3 4], [-1 -2], dims=1) == [20 52]
1599+
@test Compat.varm([1 2; 3 4], [-1, -2], dims=2) == hcat([13, 61])
1600+
@test Compat.var([1 2; 3 4]) == 5/3
1601+
@test Compat.var([1 2; 3 4], dims=1) == [2 2]
1602+
@test Compat.var([1 2; 3 4], dims=2) == hcat([0.5, 0.5])
1603+
@test Compat.var([1 2; 3 4], corrected=false) == 1.25
1604+
@test Compat.var([1 2; 3 4], corrected=false, dims=1) == [1 1]
1605+
@test Compat.var([1 2; 3 4], corrected=false, dims=2) == hcat([0.25, 0.25])
1606+
@test Compat.std([1 2; 3 4]) == sqrt(5/3)
1607+
@test Compat.std([1 2; 3 4], dims=1) == [sqrt(2) sqrt(2)]
1608+
@test Compat.std([1 2; 3 4], dims=2) == hcat([sqrt(0.5), sqrt(0.5)])
1609+
@test Compat.std([1 2; 3 4], corrected=false) == sqrt(1.25)
1610+
@test Compat.std([1 2; 3 4], corrected=false, dims=1) == [sqrt(1) sqrt(1)]
1611+
@test Compat.std([1 2; 3 4], corrected=false, dims=2) == hcat([sqrt(0.25), sqrt(0.25)])
1612+
@test Compat.cov([1 2; 3 4]) == [2 2; 2 2]
1613+
@test Compat.cov([1 2; 3 4], dims=1) == [2 2; 2 2]
1614+
@test Compat.cov([1 2; 3 4], dims=2) == [0.5 0.5; 0.5 0.5]
1615+
@test Compat.cov([1 2; 3 4], [4; 5]) == hcat([1; 1])
1616+
@test Compat.cov([1 2; 3 4], [4; 5], dims=1) == hcat([1; 1])
1617+
@test Compat.cov([1 2; 3 4], [4; 5], dims=2) == hcat([0.5; 0.5])
1618+
@test Compat.cov([1 2; 3 4], [4; 5], corrected=false) == hcat([0.5; 0.5])
1619+
@test Compat.cov([1 2; 3 4], [4; 5], corrected=false, dims=1) == hcat([0.5; 0.5])
1620+
@test Compat.cov([1 2; 3 4], [4; 5], corrected=false, dims=2) == hcat([0.25; 0.25])
1621+
@test Compat.cor([1 2; 3 4]) [1 1; 1 1]
1622+
@test Compat.cor([1 2; 3 4], dims=1) [1 1; 1 1]
1623+
@test Compat.cor([1 2; 3 4], dims=2) [1 1; 1 1]
1624+
@test Compat.cor([1 2; 3 4], [4; 5]) [1; 1]
1625+
@test Compat.cor([1 2; 3 4], [4; 5], dims=1) [1; 1]
1626+
@test Compat.cor([1 2; 3 4], [4; 5], dims=2) [1; 1]
1627+
@test Compat.median([1 2; 3 4]) == 2.5
1628+
@test Compat.median([1 2; 3 4], dims=1) == [2 3]
1629+
@test Compat.median([1 2; 3 4], dims=2) == hcat([1.5; 3.5])
1630+
@test Compat.mapreduce(string, *, [1 2; 3 4]) == "1324"
1631+
Issue26488 && @test Compat.mapreduce(string, *, [1 2; 3 4], dims=1) == ["13" "24"]
1632+
Issue26488 && @test Compat.mapreduce(string, *, [1 2; 3 4], dims=2) == hcat(["12", "34"])
1633+
@test Compat.mapreduce(string, *, "z", [1 2; 3 4]) == "z1324"
1634+
@test Compat.mapreduce(string, *, "z", [1 2; 3 4], dims=1) == ["z13" "z24"]
1635+
@test Compat.mapreduce(string, *, "z", [1 2; 3 4], dims=2) == hcat(["z12", "z34"])
1636+
@test Compat.reduce(*, [1 2; 3 4]) == 24
1637+
@test Compat.reduce(*, [1 2; 3 4], dims=1) == [3 8]
1638+
@test Compat.reduce(*, [1 2; 3 4], dims=2) == hcat([2, 12])
1639+
@test Compat.reduce(*, 10, [1 2; 3 4]) == 240
1640+
@test Compat.reduce(*, 10, [1 2; 3 4], dims=1) == [30 80]
1641+
@test Compat.reduce(*, 10, [1 2; 3 4], dims=2) == hcat([20, 120])
1642+
@test Compat.sort([1, 2, 3, 4]) == [1, 2, 3, 4]
1643+
@test Compat.sort([1 2; 3 4], dims=1) == [1 2; 3 4]
1644+
@test Compat.sort([1 2; 3 4], dims=2) == [1 2; 3 4]
1645+
@test Compat.sort([1, 2, 3, 4], rev=true) == [4, 3, 2, 1]
1646+
@test Compat.sort([1 2; 3 4], rev=true, dims=1) == [3 4; 1 2]
1647+
@test Compat.sort([1 2; 3 4], rev=true, dims=2) == [2 1; 4 3]
1648+
@test Compat.accumulate(*, [1 2; 3 4], dims=1) == [1 2; 3 8]
1649+
@test Compat.accumulate(*, [1 2; 3 4], dims=2) == [1 2; 3 12]
1650+
@test Compat.cumsum([1 2; 3 4], dims=1) == [1 2; 4 6]
1651+
@test Compat.cumsum([1 2; 3 4], dims=2) == [1 3; 3 7]
1652+
@test Compat.cumprod([1 2; 3 4], dims=1) == [1 2; 3 8]
1653+
@test Compat.cumprod([1 2; 3 4], dims=2) == [1 2; 3 12]
1654+
let b = zeros(2,2)
1655+
Compat.accumulate!(*, b, [1 2; 3 4], dims=1)
1656+
@test b == [1 2; 3 8]
1657+
Compat.accumulate!(*, b, [1 2; 3 4], dims=2)
1658+
@test b == [1 2; 3 12]
1659+
Compat.cumsum!(b, [1 2; 3 4], dims=1)
1660+
@test b == [1 2; 4 6]
1661+
Compat.cumsum!(b, [1 2; 3 4], dims=2)
1662+
@test b == [1 3; 3 7]
1663+
Compat.cumprod!(b, [1 2; 3 4], dims=1)
1664+
@test b == [1 2; 3 8]
1665+
Compat.cumprod!(b, [1 2; 3 4], dims=2)
1666+
@test b == [1 2; 3 12]
1667+
end
1668+
@test Compat.reverse([1, 2, 3, 4]) == [4, 3, 2, 1]
1669+
@test Compat.reverse([1 2; 3 4], dims=1) == [3 4; 1 2]
1670+
@test Compat.reverse([1 2; 3 4], dims=2) == [2 1; 4 3]
1671+
15381672
nothing

0 commit comments

Comments
 (0)