Skip to content

Commit d03afff

Browse files
committed
TST: Calculate RMS and diff image in C++
The current implementation is not slow, but uses a lot of memory per image. In `compare_images`, we have: - one actual and one expected image as uint8 (2×image) - both converted to int16 (though original is thrown away) (4×) which adds up to 4× the image allocated in this function. Then it calls `calculate_rms`, which has: - a difference between them as int16 (2×) - the difference cast to 64-bit float (8×) - the square of the difference as 64-bit float (though possibly the original difference was thrown away) (8×) which at its peak has 16× the image allocated in parallel. If the RMS is over the desired tolerance, then `save_diff_image` is called, which: - loads the actual and expected images _again_ as uint8 (2× image) - converts both to 64-bit float (throwing away the original) (16×) - calculates the difference (8×) - calculates the absolute value (8×) - multiples that by 10 (in-place, so no allocation) - clips to 0-255 (8×) - casts to uint8 (1×) which at peak uses 32× the image. So at their peak, `compare_images`→`calculate_rms` will have 20× the image allocated, and then `compare_images`→`save_diff_image` will have 36× the image allocated. This is generally not a problem, but on resource-constrained places like WASM, it can sometimes run out of memory just in `calculate_rms`. This implementation in C++ always allocates the diff image, even when not needed, but doesn't have all the temporaries, so it's a maximum of 3× the image size (plus a few scalar temporaries).
1 parent 38a8e15 commit d03afff

File tree

2 files changed

+80
-9
lines changed

2 files changed

+80
-9
lines changed

lib/matplotlib/testing/compare.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from PIL import Image
2020

2121
import matplotlib as mpl
22-
from matplotlib import cbook
22+
from matplotlib import cbook, _image
2323
from matplotlib.testing.exceptions import ImageComparisonFailure
2424

2525
_log = logging.getLogger(__name__)
@@ -398,7 +398,7 @@ def compare_images(expected, actual, tol, in_decorator=False):
398398
399399
The two given filenames may point to files which are convertible to
400400
PNG via the `.converter` dictionary. The underlying RMS is calculated
401-
with the `.calculate_rms` function.
401+
in a similar way to the `.calculate_rms` function.
402402
403403
Parameters
404404
----------
@@ -469,17 +469,12 @@ def compare_images(expected, actual, tol, in_decorator=False):
469469
if np.array_equal(expected_image, actual_image):
470470
return None
471471

472-
# convert to signed integers, so that the images can be subtracted without
473-
# overflow
474-
expected_image = expected_image.astype(np.int16)
475-
actual_image = actual_image.astype(np.int16)
476-
477-
rms = calculate_rms(expected_image, actual_image)
472+
rms, abs_diff = _image.calculate_rms_and_diff(expected_image, actual_image)
478473

479474
if rms <= tol:
480475
return None
481476

482-
save_diff_image(expected, actual, diff_image)
477+
Image.fromarray(abs_diff).save(diff_image, format="png")
483478

484479
results = dict(rms=rms, expected=str(expected),
485480
actual=str(actual), diff=str(diff_image), tol=tol)

src/_image_wrapper.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include <pybind11/pybind11.h>
22
#include <pybind11/numpy.h>
33

4+
#include <algorithm>
5+
46
#include "_image_resample.h"
57
#include "py_converters.h"
68

@@ -200,6 +202,77 @@ image_resample(py::array input_array,
200202
}
201203

202204

205+
// This is used by matplotlib.testing.compare to calculate RMS and a difference image.
206+
static py::tuple
207+
calculate_rms_and_diff(py::array_t<unsigned char> expected_image,
208+
py::array_t<unsigned char> actual_image)
209+
{
210+
if (expected_image.ndim() != 3) {
211+
auto exceptions = py::module_::import("matplotlib.testing.exceptions");
212+
auto ImageComparisonFailure = exceptions.attr("ImageComparisonFailure");
213+
py::set_error(
214+
ImageComparisonFailure,
215+
"Expected image must be 3-dimensional, but is {ndim}-dimensional"_s.format(
216+
"ndim"_a=expected_image.ndim()));
217+
throw py::error_already_set();
218+
}
219+
220+
if (actual_image.ndim() != 3) {
221+
auto exceptions = py::module_::import("matplotlib.testing.exceptions");
222+
auto ImageComparisonFailure = exceptions.attr("ImageComparisonFailure");
223+
py::set_error(
224+
ImageComparisonFailure,
225+
"Actual image must be 3-dimensional, but is {ndim}-dimensional"_s.format(
226+
"ndim"_a=actual_image.ndim()));
227+
throw py::error_already_set();
228+
}
229+
230+
auto height = expected_image.shape(0);
231+
auto width = expected_image.shape(1);
232+
auto depth = expected_image.shape(2);
233+
234+
if (height != actual_image.shape(0) || width != actual_image.shape(1) ||
235+
depth != actual_image.shape(2)) {
236+
auto exceptions = py::module_::import("matplotlib.testing.exceptions");
237+
auto ImageComparisonFailure = exceptions.attr("ImageComparisonFailure");
238+
py::set_error(
239+
ImageComparisonFailure,
240+
"Image sizes do not match expected size: {expected_image.shape} "_s
241+
"actual size {actual_image.shape}"_s.format(
242+
"expected_image"_a=expected_image, "actual_image"_a=actual_image));
243+
throw py::error_already_set();
244+
}
245+
auto expected = expected_image.unchecked<3>();
246+
auto actual = actual_image.unchecked<3>();
247+
248+
py::ssize_t diff_dims[3] = {height, width, 3};
249+
py::array_t<unsigned char> diff_image(diff_dims);
250+
auto diff = diff_image.mutable_unchecked<3>();
251+
252+
double total = 0.0;
253+
for (auto i = 0; i < height; i++) {
254+
for (auto j = 0; j < width; j++) {
255+
for (auto k = 0; k < depth; k++) {
256+
auto pixel_diff = static_cast<double>(expected(i, j, k)) -
257+
static_cast<double>(actual(i, j, k));
258+
259+
total += pixel_diff*pixel_diff;
260+
261+
if (k != 3) { // Hard-code a fully solid alpha channel by omitting it.
262+
diff(i, j, k) = static_cast<unsigned char>(std::clamp(
263+
// expand differences in luminance domain
264+
abs(pixel_diff) * 10,
265+
0.0, 255.0));
266+
}
267+
}
268+
}
269+
}
270+
total = total / (width * height * depth);
271+
272+
return py::make_tuple(sqrt(total), diff_image);
273+
}
274+
275+
203276
PYBIND11_MODULE(_image, m, py::mod_gil_not_used())
204277
{
205278
py::enum_<interpolation_e>(m, "_InterpolationType")
@@ -232,4 +305,7 @@ PYBIND11_MODULE(_image, m, py::mod_gil_not_used())
232305
"norm"_a = false,
233306
"radius"_a = 1,
234307
image_resample__doc__);
308+
309+
m.def("calculate_rms_and_diff", &calculate_rms_and_diff,
310+
"expected_image"_a, "actual_image"_a);
235311
}

0 commit comments

Comments
 (0)