Skip to content

Commit 55e23d7

Browse files
committed
Adjust mapreduce to use first element in offset range as well.
Adjust tests
1 parent 151594c commit 55e23d7

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

base/abstractarray.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1917,7 +1917,9 @@ function mapslices(f, A::AbstractArray; dims)
19171917
Rsize = copy(dimsA)
19181918
# TODO: maybe support removing dimensions
19191919
if !isa(r1, AbstractArray) || ndims(r1) == 0
1920-
r1 = [r1]
1920+
tmp = similar(Aslice, typeof(r1), reduced_indices(Aslice, 1:ndims(Aslice)))
1921+
tmp[first(CartesianIndices(tmp))] = r1
1922+
r1 = tmp
19211923
end
19221924
nextra = max(0, length(dims)-ndims(r1))
19231925
if eltype(Rsize) == Int

base/reducedim.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,27 @@ reduced_indices(a::AbstractArray, region) = reduced_indices(axes(a), region)
1717
# for reductions that keep 0 dims as 0
1818
reduced_indices0(a::AbstractArray, region) = reduced_indices0(axes(a), region)
1919

20-
function reduced_indices(inds::Indices{N}, d::Int, rd::AbstractUnitRange) where N
20+
function reduced_indices(inds::Indices{N}, d::Int) where N
2121
d < 1 && throw(ArgumentError("dimension must be ≥ 1, got $d"))
2222
if d == 1
23-
return (oftype(inds[1], rd), tail(inds)...)
23+
return (reduced_index(inds[1]), tail(inds)...)
2424
elseif 1 < d <= N
25-
return tuple(inds[1:d-1]..., oftype(inds[d], rd), inds[d+1:N]...)::typeof(inds)
25+
return tuple(inds[1:d-1]..., oftype(inds[d], reduced_index(inds[d])), inds[d+1:N]...)::typeof(inds)
2626
else
2727
return inds
2828
end
2929
end
30-
reduced_indices(inds::Indices, d::Int) = reduced_indices(inds, d, reduced_index(inds[d]))
3130

3231
function reduced_indices0(inds::Indices{N}, d::Int) where N
3332
d < 1 && throw(ArgumentError("dimension must be ≥ 1, got $d"))
3433
if d <= N
3534
ind = inds[d]
36-
return reduced_indices(inds, d, (isempty(ind) ? ind : reduced_index(inds[d])))
35+
rd = isempty(ind) ? ind : reduced_index(inds[d])
36+
if d == 1
37+
return (rd, tail(inds)...)
38+
else
39+
return tuple(inds[1:d-1]..., oftype(inds[d], rd), inds[d+1:N]...)::typeof(inds)
40+
end
3741
else
3842
return inds
3943
end

test/offsetarray.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,9 @@ A = OffsetArray(rand(4,4), (-3,5))
339339
@test maximum(A) == maximum(parent(A))
340340
@test minimum(A) == minimum(parent(A))
341341
@test extrema(A) == extrema(parent(A))
342-
@test maximum(A, dims=1) == OffsetArray(maximum(parent(A), dims=1), (0,A.offsets[2]))
343-
@test maximum(A, dims=2) == OffsetArray(maximum(parent(A), dims=2), (A.offsets[1],0))
344-
@test maximum(A, dims=1:2) == maximum(parent(A), dims=1:2)
342+
@test maximum(A, dims=1) == OffsetArray(maximum(parent(A), dims=1), A.offsets)
343+
@test maximum(A, dims=2) == OffsetArray(maximum(parent(A), dims=2), A.offsets)
344+
@test maximum(A, dims=1:2) == OffsetArray(maximum(parent(A), dims=1:2), A.offsets)
345345
C = similar(A)
346346
cumsum!(C, A, dims=1)
347347
@test parent(C) == cumsum(parent(A), dims=1)
@@ -373,11 +373,11 @@ I = findall(!iszero, z)
373373
@test findall(x->x==0, h) == [2]
374374
@test mean(A_3_3) == median(A_3_3) == 5
375375
@test mean(x->2x, A_3_3) == 10
376-
@test mean(A_3_3, dims=1) == median(A_3_3, dims=1) == OffsetArray([2 5 8], (0,A_3_3.offsets[2]))
377-
@test mean(A_3_3, dims=2) == median(A_3_3, dims=2) == OffsetArray(reshape([4,5,6],(3,1)), (A_3_3.offsets[1],0))
376+
@test mean(A_3_3, dims=1) == median(A_3_3, dims=1) == OffsetArray([2 5 8], A_3_3.offsets)
377+
@test mean(A_3_3, dims=2) == median(A_3_3, dims=2) == OffsetArray(reshape([4,5,6],(3,1)), A_3_3.offsets)
378378
@test var(A_3_3) == 7.5
379-
@test std(A_3_3, dims=1) == OffsetArray([1 1 1], (0,A_3_3.offsets[2]))
380-
@test std(A_3_3, dims=2) == OffsetArray(reshape([3,3,3], (3,1)), (A_3_3.offsets[1],0))
379+
@test std(A_3_3, dims=1) == OffsetArray([1 1 1], A_3_3.offsets)
380+
@test std(A_3_3, dims=2) == OffsetArray(reshape([3,3,3], (3,1)), A_3_3.offsets)
381381
@test sum(OffsetArray(fill(1,3000), -1000)) == 3000
382382

383383
@test norm(v) norm(parent(v))

0 commit comments

Comments
 (0)