Skip to content

Commit b2351cb

Browse files
Use broadcast_like for 2d plot coordinates (#5099)
* Use broadcast_like for 2d plot coordinates Use broadcast_like if either `x` or `y` inputs are 2d to ensure that both have dimensions in the same order as the DataArray being plotted. Convert to numpy arrays after possibly using broadcast_like. Simplifies code, and fixes #5097 (bug when dimensions have the same size). * Update whats-new * Test for issue 5097 * Fix typo in doc/whats-new.rst Co-authored-by: Mathias Hauser <[email protected]> * Update doc/whats-new.rst Co-authored-by: Mathias Hauser <[email protected]>
1 parent d58a511 commit b2351cb

File tree

3 files changed

+44
-21
lines changed

3 files changed

+44
-21
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ Deprecations
9191

9292
Bug fixes
9393
~~~~~~~~~
94+
- Fix 2d plot failure for certain combinations of dimensions when `x` is 1d and `y` is
95+
2d (:issue:`5097`, :pull:`5099`). By `John Omotani <https://github.com/johnomotani>`_.
9496
- Ensure standard calendar times encoded with large values (i.e. greater than approximately 292 years), can be decoded correctly without silently overflowing (:pull:`5050`). This was a regression in xarray 0.17.0. By `Zeb Nicholls <https://github.com/znicholls>`_.
9597
- Added support for `numpy.bool_` attributes in roundtrips using `h5netcdf` engine with `invalid_netcdf=True` [which casts `bool`s to `numpy.bool_`] (:issue:`4981`, :pull:`4986`).
9698
By `Victor Negîrneac <https://github.com/caenrigen>`_.

xarray/plot/plot.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -671,28 +671,21 @@ def newplotfunc(
671671
darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb
672672
)
673673

674-
# better to pass the ndarrays directly to plotting functions
675-
xval = darray[xlab].values
676-
yval = darray[ylab].values
677-
678-
# check if we need to broadcast one dimension
679-
if xval.ndim < yval.ndim:
680-
dims = darray[ylab].dims
681-
if xval.shape[0] == yval.shape[0]:
682-
xval = np.broadcast_to(xval[:, np.newaxis], yval.shape)
683-
else:
684-
xval = np.broadcast_to(xval[np.newaxis, :], yval.shape)
685-
686-
elif yval.ndim < xval.ndim:
687-
dims = darray[xlab].dims
688-
if yval.shape[0] == xval.shape[0]:
689-
yval = np.broadcast_to(yval[:, np.newaxis], xval.shape)
690-
else:
691-
yval = np.broadcast_to(yval[np.newaxis, :], xval.shape)
692-
elif xval.ndim == 2:
693-
dims = darray[xlab].dims
674+
xval = darray[xlab]
675+
yval = darray[ylab]
676+
677+
if xval.ndim > 1 or yval.ndim > 1:
678+
# Passing 2d coordinate values, need to ensure they are transposed the same
679+
# way as darray
680+
xval = xval.broadcast_like(darray)
681+
yval = yval.broadcast_like(darray)
682+
dims = darray.dims
694683
else:
695-
dims = (darray[ylab].dims[0], darray[xlab].dims[0])
684+
dims = (yval.dims[0], xval.dims[0])
685+
686+
# better to pass the ndarrays directly to plotting functions
687+
xval = xval.values
688+
yval = yval.values
696689

697690
# May need to transpose for correct x, y labels
698691
# xlab may be the name of a coord, we have to check for dim names

xarray/tests/test_plot.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,34 @@ def test2d_1d_2d_coordinates_contourf(self):
368368
a.plot.contourf(x="time", y="depth")
369369
a.plot.contourf(x="depth", y="time")
370370

371+
def test2d_1d_2d_coordinates_pcolormesh(self):
372+
# Test with equal coordinates to catch bug from #5097
373+
sz = 10
374+
y2d, x2d = np.meshgrid(np.arange(sz), np.arange(sz))
375+
a = DataArray(
376+
easy_array((sz, sz)),
377+
dims=["x", "y"],
378+
coords={"x2d": (["x", "y"], x2d), "y2d": (["x", "y"], y2d)},
379+
)
380+
381+
for x, y in [
382+
("x", "y"),
383+
("y", "x"),
384+
("x2d", "y"),
385+
("y", "x2d"),
386+
("x", "y2d"),
387+
("y2d", "x"),
388+
("x2d", "y2d"),
389+
("y2d", "x2d"),
390+
]:
391+
p = a.plot.pcolormesh(x=x, y=y)
392+
v = p.get_paths()[0].vertices
393+
394+
# Check all vertices are different, except last vertex which should be the
395+
# same as the first
396+
_, unique_counts = np.unique(v[:-1], axis=0, return_counts=True)
397+
assert np.all(unique_counts == 1)
398+
371399
def test_contourf_cmap_set(self):
372400
a = DataArray(easy_array((4, 4)), dims=["z", "time"])
373401

0 commit comments

Comments
 (0)