Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 28 additions & 22 deletions pytest_mpl/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import shutil
import hashlib
import inspect
import logging
import tempfile
import warnings
import contextlib
Expand All @@ -54,27 +55,6 @@
{actual_path}"""


def _download_file(baseline, filename):
# Note that baseline can be a comma-separated list of URLs that we can
# then treat as mirrors
for base_url in baseline.split(','):
try:
u = urlopen(base_url + filename)
content = u.read()
except Exception as e:
warnings.warn('Downloading {0} failed: {1}'.format(base_url + filename, e))
else:
break
else:
raise Exception("Could not download baseline image from any of the "
"available URLs")
result_dir = Path(tempfile.mkdtemp())
filename = result_dir / 'downloaded'
with open(str(filename), 'wb') as tmpfile:
tmpfile.write(content)
return Path(filename)


def _hash_file(in_stream):
"""
Hashes an already opened file.
Expand Down Expand Up @@ -292,6 +272,12 @@ def __init__(self,
self._test_results = {}
self._test_stats = None

# https://stackoverflow.com/questions/51737378/how-should-i-log-in-my-pytest-plugin
# turn debug prints on only if "-vv" or more passed
level = logging.DEBUG if config.option.verbose > 1 else logging.INFO
logging.basicConfig(level=level)
self.logger = logging.getLogger('pytest-mpl')

def get_compare(self, item):
"""
Return the mpl_image_compare marker for the given item.
Expand Down Expand Up @@ -364,6 +350,26 @@ def get_baseline_directory(self, item):

return baseline_dir

def _download_file(self, baseline, filename):
# Note that baseline can be a comma-separated list of URLs that we can
# then treat as mirrors
for base_url in baseline.split(','):
try:
u = urlopen(base_url + filename)
content = u.read()
except Exception as e:
self.logger.info(f'Downloading {base_url + filename} failed: {repr(e)}')
else:
break
else:
raise Exception("Could not download baseline image from any of the "
"available URLs")
result_dir = Path(tempfile.mkdtemp())
filename = result_dir / 'downloaded'
with open(str(filename), 'wb') as tmpfile:
tmpfile.write(content)
return Path(filename)

def obtain_baseline_image(self, item, target_dir):
"""
Copy the baseline image to our working directory.
Expand All @@ -378,7 +384,7 @@ def obtain_baseline_image(self, item, target_dir):
if baseline_remote:
# baseline_dir can be a list of URLs when remote, so we have to
# pass base and filename to download
baseline_image = _download_file(baseline_dir, filename)
baseline_image = self._download_file(baseline_dir, filename)
else:
baseline_image = (baseline_dir / filename).absolute()

Expand Down