Skip to content

Commit cc8a278

Browse files
Seppo Enarvi and pre-commit-ci[bot]
authored and committed
Two fixes for handling edge cases in MLflow logging (#16451)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 9346151)
1 parent 2ed069a commit cc8a278

File tree

3 files changed

+40
-23
lines changed

3 files changed

+40
-23
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2323

2424
- Fixed an issue with `MLFlowLogger` logging the wrong keys with `.log_hyperparams()` ([#16418](https://github.com/Lightning-AI/lightning/pull/16418))
2525

26+
- Fixed logging more than 100 parameters with `MLFlowLogger`, and long parameter values are now truncated instead of being discarded ([#16451](https://github.com/Lightning-AI/lightning/pull/16451))
27+
2628

2729

2830
## [1.9.0] - 2023-01-17

src/pytorch_lightning/loggers/mlflow.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -238,18 +238,14 @@ def experiment_id(self) -> Optional[str]:
238238
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
239239
params = _convert_params(params)
240240
params = _flatten_dict(params)
241-
params_list: List[Param] = []
242241

243-
for k, v in params.items():
244-
# TODO: mlflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
245-
if len(str(v)) > 250:
246-
rank_zero_warn(
247-
f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning
248-
)
249-
continue
250-
params_list.append(Param(key=k, value=v))
242+
# Truncate parameter values to 250 characters.
243+
# TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
244+
params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]
251245

252-
self.experiment.log_batch(run_id=self.run_id, params=params_list)
246+
# Log in chunks of 100 parameters (the maximum allowed by MLflow).
247+
for idx in range(0, len(params_list), 100):
248+
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])
253249

254250
@rank_zero_only
255251
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -224,19 +224,6 @@ def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir):
224224
logger.log_metrics(metrics)
225225

226226

227-
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
228-
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
229-
def test_mlflow_logger_with_long_param_value(client, _, tmpdir):
230-
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
231-
logger = MLFlowLogger("test", save_dir=tmpdir)
232-
value = "test" * 100
233-
key = "test_param"
234-
params = {key: value}
235-
236-
with pytest.warns(RuntimeWarning, match=f"Discard {key}={value}"):
237-
logger.log_hyperparams(params)
238-
239-
240227
@mock.patch("pytorch_lightning.loggers.mlflow.Metric")
241228
@mock.patch("pytorch_lightning.loggers.mlflow.Param")
242229
@mock.patch("pytorch_lightning.loggers.mlflow.time")
@@ -270,6 +257,38 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
270257
)
271258

272259

260+
def _check_value_length(value, *args, **kwargs):
261+
assert len(value) <= 250
262+
263+
264+
@mock.patch("pytorch_lightning.loggers.mlflow.Param", side_effect=_check_value_length)
265+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
266+
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
267+
def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir):
268+
"""Test that long parameter values are truncated to 250 characters."""
269+
logger = MLFlowLogger("test", save_dir=tmpdir)
270+
271+
params = {"test": "test_param" * 50}
272+
logger.log_hyperparams(params)
273+
274+
# assert_called_once_with() won't properly check the parameter value.
275+
logger.experiment.log_batch.assert_called_once()
276+
277+
278+
@mock.patch("pytorch_lightning.loggers.mlflow.Param")
279+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
280+
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
281+
def test_mlflow_logger_with_many_params(client, _, param, tmpdir):
282+
"""Test that when logging more than 100 parameters, they will be split into batches of at most 100
283+
parameters."""
284+
logger = MLFlowLogger("test", save_dir=tmpdir)
285+
286+
params = {f"test_{idx}": f"test_param_{idx}" for idx in range(150)}
287+
logger.log_hyperparams(params)
288+
289+
assert logger.experiment.log_batch.call_count == 2
290+
291+
273292
@pytest.mark.parametrize(
274293
"status,expected",
275294
[

0 commit comments

Comments
 (0)