From 620f2cede6ad138a77c72a37b9e40e6686b65370 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Fri, 2 May 2025 18:55:19 +0200 Subject: [PATCH 01/29] updating evaluation to use new splitters --- moabb/evaluations/evaluations.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index c72f2e69d..e692acee7 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -19,6 +19,7 @@ from tqdm import tqdm from moabb.evaluations.base import BaseEvaluation +from moabb.evaluations.splitters import CrossSessionSplitter, CrossSubjectSplitter from moabb.evaluations.utils import create_save_path, save_model_cv, save_model_list @@ -516,7 +517,8 @@ def evaluate( tracker.start() # we want to store a results per session - cv = LeaveOneGroupOut() + cv = CrossSessionSplitter(random_state=self.random_state) + inner_cv = StratifiedKFold( 3, shuffle=True, random_state=self.random_state ) @@ -538,8 +540,7 @@ def evaluate( grid=False, eval_type="CrossSession", ) - - for cv_ind, (train, test) in enumerate(cv.split(X, y, groups)): + for cv_ind, (train, test) in enumerate(cv.split(y, metadata)): model_list = [] if _carbonfootprint: tracker.start() @@ -695,12 +696,19 @@ def evaluate( scorer = get_scorer(self.paradigm.scoring) # perform leave one subject out CV + if self.n_splits is None: - cv = LeaveOneGroupOut() + cv_class = LeaveOneGroupOut + cv_kwargs = {} else: - cv = GroupKFold(n_splits=self.n_splits) + cv_class = GroupKFold + cv_kwargs = {"n_splits": self.n_splits} n_subjects = self.n_splits + cv = CrossSubjectSplitter( + cv_class=cv_class, random_state=self.random_state, **cv_kwargs + ) + inner_cv = StratifiedKFold(3, shuffle=True, random_state=self.random_state) # Implement Grid Search @@ -712,7 +720,7 @@ def evaluate( # Progressbar at subject level for cv_ind, (train, test) in enumerate( tqdm( - cv.split(X, y, groups), + cv.split(y, metadata), total=n_subjects, desc=f"{dataset.code}-CrossSubject", ) From 7d6b83ff44a90b890b15ddd514dfa763c95971e6 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Fri, 2 May 2025 19:52:08 +0200 Subject: [PATCH 02/29] cross-subject --- moabb/evaluations/evaluations.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index e692acee7..e3a637235 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -19,7 +19,10 @@ from tqdm import tqdm from moabb.evaluations.base import BaseEvaluation -from moabb.evaluations.splitters import CrossSessionSplitter, CrossSubjectSplitter +from moabb.evaluations.splitters import ( + CrossSessionSplitter, + CrossSubjectSplitter, +) from moabb.evaluations.utils import create_save_path, save_model_cv, save_model_list @@ -696,7 +699,6 @@ def evaluate( scorer = get_scorer(self.paradigm.scoring) # perform leave one subject out CV - if self.n_splits is None: cv_class = LeaveOneGroupOut cv_kwargs = {} From e24d8e7f3df2361a2979d5d69a2c878d15bb3f2a Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Fri, 2 May 2025 20:04:30 +0200 Subject: [PATCH 03/29] including the whats new --- docs/source/whats_new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index e54e3be99..e6a09b0fe 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -27,7 +27,7 @@ Enhancements - Adding 
:func:`moabb.analysis.plotting.dataset_bubble_plot` plus the corresponding tutorial (:gh:`753` by `Pierre Guetschel`_) - Adding :func:`moabb.datasets.utils.plot_all_datasets` and update the tutorial (:gh:`758` by `Pierre Guetschel`_) - Improve the dataset model cards in each API page (:gh:`765` by `Pierre Guetschel`_) - +- Using the splitters in the evaluation (:gh:`766` by `Bruno Aristimunha`_) Bugs ~~~~ From 8597ad1cd7f1f7ac2424ac40d9f1095b8d714983 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Fri, 2 May 2025 20:05:22 +0200 Subject: [PATCH 04/29] updating the whats new --- docs/source/whats_new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index e6a09b0fe..e0b0bd1df 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -27,7 +27,7 @@ Enhancements - Adding :func:`moabb.analysis.plotting.dataset_bubble_plot` plus the corresponding tutorial (:gh:`753` by `Pierre Guetschel`_) - Adding :func:`moabb.datasets.utils.plot_all_datasets` and update the tutorial (:gh:`758` by `Pierre Guetschel`_) - Improve the dataset model cards in each API page (:gh:`765` by `Pierre Guetschel`_) -- Using the splitters in the evaluation (:gh:`766` by `Bruno Aristimunha`_) +- Using the splitters in the evaluation (:gh:`769` by `Bruno Aristimunha`_) Bugs ~~~~ From e64d2e1dc2c564f51128cd801137ab3c0e03fd9e Mon Sep 17 00:00:00 2001 From: Bru Date: Sat, 3 May 2025 17:35:25 +0200 Subject: [PATCH 05/29] Update docs/source/whats_new.rst Co-authored-by: Pierre Guetschel <25532709+PierreGtch@users.noreply.github.com> Signed-off-by: Bru --- docs/source/whats_new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index e0b0bd1df..215698146 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -27,7 +27,7 @@ Enhancements - Adding :func:`moabb.analysis.plotting.dataset_bubble_plot` plus the corresponding tutorial (:gh:`753` by `Pierre Guetschel`_) - Adding :func:`moabb.datasets.utils.plot_all_datasets` and update the tutorial (:gh:`758` by `Pierre Guetschel`_) - Improve the dataset model cards in each API page (:gh:`765` by `Pierre Guetschel`_) -- Using the splitters in the evaluation (:gh:`769` by `Bruno Aristimunha`_) +- Refactor :class:`moabb.evaluation.CrossSessionEvaluation` and :class:`moabb.evaluation.CrossSubjectEvaluation` to use the new splitter classes (:gh:`769` by `Bruno Aristimunha`_) Bugs ~~~~ From 062d1c33f07f8b90b85e3e25340b0d7ba804ea9d Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Sat, 3 May 2025 19:29:14 +0200 Subject: [PATCH 06/29] simple fit for everybody --- moabb/evaluations/evaluations.py | 40 ++++++++++---------------------- 1 file changed, 12 insertions(+), 28 deletions(-) diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index e3a637235..67f57701b 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -14,7 +14,7 @@ StratifiedShuffleSplit, cross_validate, ) -from sklearn.model_selection._validation import _fit_and_score, _score +from sklearn.model_selection._validation import _score from sklearn.preprocessing import LabelEncoder from tqdm import tqdm @@ -548,35 +548,19 @@ def evaluate( if _carbonfootprint: tracker.start() t_start = time() - if isinstance(X, BaseEpochs): - cvclf = clone(grid_clf) - cvclf.fit(X[train], y[train]) - model_list.append(cvclf) - score = scorer(cvclf, X[test], y[test]) - if self.hdf5_path is not None and 
self.save_model: - save_model_cv( - model=cvclf, - save_path=model_save_path, - cv_index=str(cv_ind), - ) - else: - result = _fit_and_score( - estimator=clone(grid_clf), - X=X, - y=y, - scorer=scorer, - train=train, - test=test, - verbose=False, - parameters=None, - fit_params=None, - error_score=self.error_score, - return_estimator=True, - score_params={}, + cvclf = clone(grid_clf) + cvclf.fit(X[train], y[train]) + model_list.append(cvclf) + score = scorer(cvclf, X[test], y[test]) + + if self.hdf5_path is not None and self.save_model: + save_model_cv( + model=cvclf, + save_path=model_save_path, + cv_index=str(cv_ind), ) - score = result["test_scores"] - model_list = result["estimator"] + if _carbonfootprint: emissions = tracker.stop() if emissions is None: From 68da5d12f2bf95753ed52fd4d4ba7e8e3b0c8a92 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 5 May 2025 09:34:13 +0200 Subject: [PATCH 07/29] updating the splitter --- moabb/evaluations/evaluations.py | 55 ++++++++++++-------------------- 1 file changed, 20 insertions(+), 35 deletions(-) diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index 67f57701b..d7cc7d76c 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -12,7 +12,6 @@ LeaveOneGroupOut, StratifiedKFold, StratifiedShuffleSplit, - cross_validate, ) from sklearn.model_selection._validation import _score from sklearn.preprocessing import LabelEncoder @@ -22,6 +21,7 @@ from moabb.evaluations.splitters import ( CrossSessionSplitter, CrossSubjectSplitter, + WithinSessionSplitter, ) from moabb.evaluations.utils import create_save_path, save_model_cv, save_model_list @@ -176,7 +176,11 @@ def _evaluate( tracker = EmissionsTracker(save_to_file=False, log_level="error") tracker.start() t_start = time() - cv = StratifiedKFold(5, shuffle=True, random_state=self.random_state) + cv = WithinSessionSplitter( + n_folds=self.n_splits, + shuffle=True, + random_state=self.random_state, + ) inner_cv = StratifiedKFold( 3, shuffle=True, random_state=self.random_state ) @@ -217,44 +221,25 @@ def _evaluate( eval_type="WithinSession", ) - if isinstance(X, BaseEpochs): - scorer = get_scorer(self.paradigm.scoring) - acc = list() - X_ = X[ix] - y_ = y[ix] if self.mne_labels else y_cv - for cv_ind, (train, test) in enumerate(cv.split(X_, y_)): - cvclf = clone(grid_clf) - cvclf.fit(X_[train], y_[train]) - acc.append(scorer(cvclf, X_[test], y_[test])) - - if self.hdf5_path is not None and self.save_model: - save_model_cv( - model=cvclf, - save_path=model_save_path, - cv_index=cv_ind, - ) + scorer = get_scorer(self.paradigm.scoring) + acc = list() + X_ = X[ix] + y_ = y[ix] if self.mne_labels else y_cv + for cv_ind, (train, test) in enumerate(cv.split(X_, y_)): + cvclf = clone(grid_clf) + cvclf.fit(X_[train], y_[train]) + acc.append(scorer(cvclf, X_[test], y_[test])) - acc = np.array(acc) - score = acc.mean() - else: - results = cross_validate( - grid_clf, - X[ix], - y_cv, - cv=cv, - scoring=self.paradigm.scoring, - n_jobs=self.n_jobs, - error_score=self.error_score, - return_estimator=True, - ) - score = results["test_score"].mean() if self.hdf5_path is not None and self.save_model: - save_model_list( - results["estimator"], - score_list=results["test_score"], + save_model_cv( + model=cvclf, save_path=model_save_path, + cv_index=cv_ind, ) + acc = np.array(acc) + score = acc.mean() + if _carbonfootprint: emissions = tracker.stop() if emissions is None: From 3fadc5cce1ca1e108b2c8e1f7211fdb2e4587128 Mon Sep 17 00:00:00 2001 From: 
bruAristimunha Date: Mon, 5 May 2025 10:01:41 +0200 Subject: [PATCH 08/29] parallel evaluation now --- moabb/evaluations/base.py | 59 ++++++++++++++++++++++---------- moabb/evaluations/evaluations.py | 5 ++- 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/moabb/evaluations/base.py b/moabb/evaluations/base.py index c15c2699d..9a9ae1d09 100644 --- a/moabb/evaluations/base.py +++ b/moabb/evaluations/base.py @@ -3,6 +3,7 @@ from warnings import warn import pandas as pd +from joblib import Parallel, delayed from sklearn.base import BaseEstimator from sklearn.model_selection import GridSearchCV @@ -97,7 +98,7 @@ def __init__( return_epochs=False, return_raws=False, mne_labels=False, - n_splits=None, + n_splits: int = 5, save_model=False, cache_config=None, optuna=False, @@ -201,7 +202,6 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None): This pipeline must be "fixed" because it will not be trained, i.e. no call to ``fit`` will be made. - Returns ------- results: pd.DataFrame @@ -216,26 +216,49 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None): if not (isinstance(pipeline, BaseEstimator)): raise (ValueError("pipelines must only contains Pipelines " "instance")) - res_per_db = [] - for dataset in self.datasets: - log.info("Processing dataset: {}".format(dataset.code)) - process_pipeline = self.paradigm.make_process_pipelines( + # Prepare dataset processing parameters + processing_params = [ + ( dataset, - return_epochs=self.return_epochs, - return_raws=self.return_raws, - postprocess_pipeline=postprocess_pipeline, - )[0] - # (we only keep the pipeline for the first frequency band, better ideas?) - - results = self.evaluate( - dataset, - pipelines, - param_grid=param_grid, - process_pipeline=process_pipeline, - postprocess_pipeline=postprocess_pipeline, + self.paradigm.make_process_pipelines( + dataset, + return_epochs=self.return_epochs, + return_raws=self.return_raws, + postprocess_pipeline=postprocess_pipeline, + )[0], ) + for dataset in self.datasets + ] + + # Parallel processing... 
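
# A minimal, runnable sketch of the joblib pattern this hunk adopts.
# Assumes joblib >= 1.3 (required for return_as="generator"); the
# dataset names and the evaluate() helper are placeholders, not MOABB API.
from joblib import Parallel, delayed

def evaluate(dataset):
    return [f"result-{dataset}"]

datasets = ["BNCI2014-001", "BNCI2015-001"]
# return_as="generator" yields results lazily and in submission order,
# so zip() can pair each result with the dataset that produced it
# without holding every finished result in memory at once.
results = Parallel(n_jobs=2, return_as="generator")(
    delayed(evaluate)(d) for d in datasets
)
for dataset, res in zip(datasets, results):
    print(dataset, res)
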
+ parallel_results = Parallel( + n_jobs=self.n_jobs, return_as="generator", verbose=10 + )( + delayed( + lambda dataset_processor: list( + self.evaluate( + dataset_processor[0], # dataset + pipelines, + param_grid=param_grid, + process_pipeline=dataset_processor[1], + # process_pipeline + postprocess_pipeline=postprocess_pipeline, + ) + ) + )( + params + ) # Pass parameters here + for params in processing_params + ) + + res_per_db = [] + # Process results in order + for (dataset, process_pipeline), results in zip( + processing_params, parallel_results + ): for res in results: self.push_result(res, pipelines, process_pipeline) + res_per_db.append( self.results.to_dataframe( pipelines=pipelines, process_pipeline=process_pipeline diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index d7cc7d76c..3d1f18fe6 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -225,8 +225,11 @@ def _evaluate( acc = list() X_ = X[ix] y_ = y[ix] if self.mne_labels else y_cv - for cv_ind, (train, test) in enumerate(cv.split(X_, y_)): + meta_ = metadata[ix].reset_index(drop=True) + + for cv_ind, (train, test) in enumerate(cv.split(y_, meta_)): cvclf = clone(grid_clf) + cvclf.fit(X_[train], y_[train]) acc.append(scorer(cvclf, X_[test], y_[test])) From 0297f3cd4f715eb76f20c3b51e001cada077dfc3 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 5 May 2025 10:11:40 +0200 Subject: [PATCH 09/29] solving the small issue --- moabb/evaluations/base.py | 8 +++----- moabb/evaluations/evaluations.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/moabb/evaluations/base.py b/moabb/evaluations/base.py index 9a9ae1d09..b06692958 100644 --- a/moabb/evaluations/base.py +++ b/moabb/evaluations/base.py @@ -98,7 +98,7 @@ def __init__( return_epochs=False, return_raws=False, mne_labels=False, - n_splits: int = 5, + n_splits=None, save_model=False, cache_config=None, optuna=False, @@ -229,11 +229,9 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None): ) for dataset in self.datasets ] - + n_jobs = 1 # Parallel processing... 
- parallel_results = Parallel( - n_jobs=self.n_jobs, return_as="generator", verbose=10 - )( + parallel_results = Parallel(n_jobs=n_jobs, return_as="generator", verbose=10)( delayed( lambda dataset_processor: list( self.evaluate( diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index 3d1f18fe6..a41e8910c 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -177,7 +177,7 @@ def _evaluate( tracker.start() t_start = time() cv = WithinSessionSplitter( - n_folds=self.n_splits, + n_folds=5, shuffle=True, random_state=self.random_state, ) From 2b3ae21500d1359985fc8d4984ff084662e20c8e Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 5 May 2025 10:27:52 +0200 Subject: [PATCH 10/29] updating the evaluation --- moabb/evaluations/base.py | 24 ++++++++++-------------- moabb/evaluations/evaluations.py | 19 +++++++++++++++++-- moabb/evaluations/utils.py | 18 ++++++++++++++++++ 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/moabb/evaluations/base.py b/moabb/evaluations/base.py index b06692958..9bb85d4a7 100644 --- a/moabb/evaluations/base.py +++ b/moabb/evaluations/base.py @@ -5,28 +5,21 @@ import pandas as pd from joblib import Parallel, delayed from sklearn.base import BaseEstimator -from sklearn.model_selection import GridSearchCV from moabb.analysis import Results from moabb.datasets.base import BaseDataset -from moabb.evaluations.utils import _convert_sklearn_params_to_optuna +from moabb.evaluations.utils import ( + _convert_sklearn_params_to_optuna, + check_search_avaliable, +) from moabb.paradigms.base import BaseParadigm +search_methods, optuna_available = check_search_avaliable() + log = logging.getLogger(__name__) # Making the optuna soft dependency -try: - from optuna.integration import OptunaSearchCV - - optuna_available = True -except ImportError: - optuna_available = False - -if optuna_available: - search_methods = {"grid": GridSearchCV, "optuna": OptunaSearchCV} -else: - search_methods = {"grid": GridSearchCV} class BaseEvaluation(ABC): @@ -116,6 +109,7 @@ def __init__( self.cache_config = cache_config self.optuna = optuna self.time_out = time_out + self.n_jobs_inner = 5 if self.optuna and not optuna_available: raise ImportError("Optuna is not available. Please install it first.") @@ -231,7 +225,9 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None): ] n_jobs = 1 # Parallel processing... 
- parallel_results = Parallel(n_jobs=n_jobs, return_as="generator", verbose=10)( + parallel_results = Parallel( + n_jobs=n_jobs, return_as="generator", verbose=10, backend="loky" + )( delayed( lambda dataset_processor: list( self.evaluate( diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index a41e8910c..b57a73676 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -4,6 +4,7 @@ from typing import Optional, Union import numpy as np +from joblib import parallel_backend from mne.epochs import BaseEpochs from sklearn.base import clone from sklearn.metrics import get_scorer @@ -23,7 +24,12 @@ CrossSubjectSplitter, WithinSessionSplitter, ) -from moabb.evaluations.utils import create_save_path, save_model_cv, save_model_list +from moabb.evaluations.utils import ( + check_search_avaliable, + create_save_path, + save_model_cv, + save_model_list, +) try: @@ -33,6 +39,8 @@ except ImportError: _carbonfootprint = False +search_methods, _ = check_search_avaliable() + log = logging.getLogger(__name__) # Numpy ArrayLike is only available starting from Numpy 1.20 and Python 3.8 @@ -229,8 +237,15 @@ def _evaluate( for cv_ind, (train, test) in enumerate(cv.split(y_, meta_)): cvclf = clone(grid_clf) + if any( + isinstance(cvclf, search) + for search in search_methods.values() + ): + with parallel_backend("threading", n_jobs=self.n_jobs_inner): + cvclf.fit(X_[train], y_[train]) + else: + cvclf.fit(X_[train], y_[train]) - cvclf.fit(X_[train], y_[train]) acc.append(scorer(cvclf, X_[test], y_[test])) if self.hdf5_path is not None and self.save_model: diff --git a/moabb/evaluations/utils.py b/moabb/evaluations/utils.py index 5b5b6d24c..0d66336cf 100644 --- a/moabb/evaluations/utils.py +++ b/moabb/evaluations/utils.py @@ -5,6 +5,7 @@ from typing import Sequence from numpy import argmax +from sklearn.model_selection import GridSearchCV from sklearn.pipeline import Pipeline @@ -213,3 +214,20 @@ def _convert_sklearn_params_to_optuna(param_grid: dict) -> dict: except Exception as e: raise ValueError(f"Conversion failed for parameter {key}: {e}") return optuna_params + + +def check_search_avaliable(): + """Check if optuna is available""" + try: + from optuna.integration import OptunaSearchCV + + optuna_available = True + except ImportError: + optuna_available = False + + if optuna_available: + search_methods = {"grid": GridSearchCV, "optuna": OptunaSearchCV} + else: + search_methods = {"grid": GridSearchCV} + + return search_methods, optuna_available From 86028c9dbcd7913debaa6562547f112b8e7bd580 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 5 May 2025 10:44:32 +0200 Subject: [PATCH 11/29] adjusting in the other evaluation too --- moabb/evaluations/evaluations.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index b57a73676..f31644d2b 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -241,7 +241,7 @@ def _evaluate( isinstance(cvclf, search) for search in search_methods.values() ): - with parallel_backend("threading", n_jobs=self.n_jobs_inner): + with parallel_backend("threading", n_jobs=3): cvclf.fit(X_[train], y_[train]) else: cvclf.fit(X_[train], y_[train]) @@ -553,7 +553,14 @@ def evaluate( t_start = time() cvclf = clone(grid_clf) - cvclf.fit(X[train], y[train]) + if any( + isinstance(cvclf, search) for search in search_methods.values() + ): + with parallel_backend("threading", n_jobs=3): + 
cvclf.fit(X[train], y[train]) + else: + cvclf.fit(X[train], y[train]) + model_list.append(cvclf) score = scorer(cvclf, X[test], y[test]) @@ -727,7 +734,15 @@ def evaluate( clf = self._grid_search( param_grid=param_grid, name=name, grid_clf=clf, inner_cv=inner_cv ) - model = deepcopy(clf).fit(X[train], y[train]) + + model = clone(clf) + + if any(isinstance(model, search) for search in search_methods.values()): + with parallel_backend("threading", n_jobs=3): + model.fit(X[train], y[train]) + else: + model.fit(X[train], y[train]) + if _carbonfootprint: emissions = tracker.stop() if emissions is None: From 70daa3a58d7d05c41756c8c121f7b0aa4d3beee0 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Tue, 6 May 2025 13:50:36 +0200 Subject: [PATCH 12/29] updating --- moabb/evaluations/evaluations.py | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index f31644d2b..aedf6d974 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -4,7 +4,6 @@ from typing import Optional, Union import numpy as np -from joblib import parallel_backend from mne.epochs import BaseEpochs from sklearn.base import clone from sklearn.metrics import get_scorer @@ -25,7 +24,6 @@ WithinSessionSplitter, ) from moabb.evaluations.utils import ( - check_search_avaliable, create_save_path, save_model_cv, save_model_list, @@ -39,7 +37,6 @@ except ImportError: _carbonfootprint = False -search_methods, _ = check_search_avaliable() log = logging.getLogger(__name__) @@ -237,14 +234,8 @@ def _evaluate( for cv_ind, (train, test) in enumerate(cv.split(y_, meta_)): cvclf = clone(grid_clf) - if any( - isinstance(cvclf, search) - for search in search_methods.values() - ): - with parallel_backend("threading", n_jobs=3): - cvclf.fit(X_[train], y_[train]) - else: - cvclf.fit(X_[train], y_[train]) + + cvclf.fit(X_[train], y_[train]) acc.append(scorer(cvclf, X_[test], y_[test])) @@ -553,13 +544,8 @@ def evaluate( t_start = time() cvclf = clone(grid_clf) - if any( - isinstance(cvclf, search) for search in search_methods.values() - ): - with parallel_backend("threading", n_jobs=3): - cvclf.fit(X[train], y[train]) - else: - cvclf.fit(X[train], y[train]) + + cvclf.fit(X[train], y[train]) model_list.append(cvclf) score = scorer(cvclf, X[test], y[test]) @@ -737,11 +723,7 @@ def evaluate( model = clone(clf) - if any(isinstance(model, search) for search in search_methods.values()): - with parallel_backend("threading", n_jobs=3): - model.fit(X[train], y[train]) - else: - model.fit(X[train], y[train]) + model.fit(X[train], y[train]) if _carbonfootprint: emissions = tracker.stop() From 8eb2c9d1c845d2878e1c2316020ad21d6feb9b92 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Tue, 6 May 2025 13:54:00 +0200 Subject: [PATCH 13/29] updating the evaluations --- moabb/evaluations/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/moabb/evaluations/base.py b/moabb/evaluations/base.py index 9bb85d4a7..7aae8d604 100644 --- a/moabb/evaluations/base.py +++ b/moabb/evaluations/base.py @@ -109,7 +109,6 @@ def __init__( self.cache_config = cache_config self.optuna = optuna self.time_out = time_out - self.n_jobs_inner = 5 if self.optuna and not optuna_available: raise ImportError("Optuna is not available. Please install it first.") @@ -223,10 +222,9 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None): ) for dataset in self.datasets ] - n_jobs = 1 # Parallel processing... 
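
# The machinery added in the last two patches (and partially rolled back
# just below) reduces to: keep a registry of hyperparameter-search
# classes with Optuna as an optional import, detect such estimators, and
# fit them under a thread-based joblib backend so they do not spawn a
# second layer of worker processes beneath the outer loky pool. A
# self-contained sketch of that idea with toy data; n_jobs=3 mirrors the
# hard-coded value used above:
import numpy as np
from joblib import parallel_backend
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

try:
    from optuna.integration import OptunaSearchCV
    search_methods = {"grid": GridSearchCV, "optuna": OptunaSearchCV}
except ImportError:
    search_methods = {"grid": GridSearchCV}

def fit_with_inner_threads(estimator, X, y):
    # Search objects parallelise internally; steer that work onto threads.
    if any(isinstance(estimator, cls) for cls in search_methods.values()):
        with parallel_backend("threading", n_jobs=3):
            estimator.fit(X, y)
    else:
        estimator.fit(X, y)
    return estimator

X = np.random.randn(60, 4)
y = np.repeat([0, 1], 30)
search = GridSearchCV(LogisticRegression(), {"C": [0.1, 1.0]}, cv=3)
fit_with_inner_threads(search, X, y)
print(search.best_params_)
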
parallel_results = Parallel( - n_jobs=n_jobs, return_as="generator", verbose=10, backend="loky" + n_jobs=self.n_jobs, return_as="generator", verbose=10, backend="loky" )( delayed( lambda dataset_processor: list( From 6451efd866e8d8d12a203f55a5f52db2343cf3d7 Mon Sep 17 00:00:00 2001 From: Bru Date: Tue, 6 May 2025 23:24:41 +0200 Subject: [PATCH 14/29] Apply suggestions from code review Co-authored-by: Thomas Moreau Signed-off-by: Bru --- moabb/evaluations/base.py | 29 +++++++++++------------------ moabb/evaluations/utils.py | 2 +- 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/moabb/evaluations/base.py b/moabb/evaluations/base.py index 7aae8d604..b9dc8eeb7 100644 --- a/moabb/evaluations/base.py +++ b/moabb/evaluations/base.py @@ -10,12 +10,12 @@ from moabb.datasets.base import BaseDataset from moabb.evaluations.utils import ( _convert_sklearn_params_to_optuna, - check_search_avaliable, + check_search_available, ) from moabb.paradigms.base import BaseParadigm -search_methods, optuna_available = check_search_avaliable() +search_methods, optuna_available = check_search_available() log = logging.getLogger(__name__) @@ -224,23 +224,16 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None): ] # Parallel processing... parallel_results = Parallel( - n_jobs=self.n_jobs, return_as="generator", verbose=10, backend="loky" + n_jobs=self.n_jobs, return_as="generator", )( - delayed( - lambda dataset_processor: list( - self.evaluate( - dataset_processor[0], # dataset - pipelines, - param_grid=param_grid, - process_pipeline=dataset_processor[1], - # process_pipeline - postprocess_pipeline=postprocess_pipeline, - ) - ) - )( - params - ) # Pass parameters here - for params in processing_params +```suggestion + delayed(self.evaluate)( + dataset, + pipelines, + param_grid=param_grid, + process_pipeline=process_pipeline, + postprocess_pipeline=postprocess_pipeline, + ) for dataset, process_pipeline in processing_params ) res_per_db = [] diff --git a/moabb/evaluations/utils.py b/moabb/evaluations/utils.py index 0d66336cf..0d67aea24 100644 --- a/moabb/evaluations/utils.py +++ b/moabb/evaluations/utils.py @@ -216,7 +216,7 @@ def _convert_sklearn_params_to_optuna(param_grid: dict) -> dict: return optuna_params -def check_search_avaliable(): +def check_search_available(): """Check if optuna is available""" try: from optuna.integration import OptunaSearchCV From 8c932d4eed61365970543b713c72c0d7ff781023 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Tue, 6 May 2025 23:26:13 +0200 Subject: [PATCH 15/29] updating base --- moabb/evaluations/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/moabb/evaluations/base.py b/moabb/evaluations/base.py index b9dc8eeb7..696c57e27 100644 --- a/moabb/evaluations/base.py +++ b/moabb/evaluations/base.py @@ -222,18 +222,20 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None): ) for dataset in self.datasets ] + # Parallel processing... 
parallel_results = Parallel( - n_jobs=self.n_jobs, return_as="generator", + n_jobs=self.n_jobs, + return_as="generator", )( -```suggestion delayed(self.evaluate)( dataset, pipelines, param_grid=param_grid, process_pipeline=process_pipeline, postprocess_pipeline=postprocess_pipeline, - ) for dataset, process_pipeline in processing_params + ) + for dataset, process_pipeline in processing_params ) res_per_db = [] From 3bf2ab7e5515fcd28cca305373eed295c3fd4090 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Fri, 25 Jul 2025 15:52:40 +0200 Subject: [PATCH 16/29] updating the pyproject --- pyproject.toml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4ae4822c8..40a561fa5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,10 +11,10 @@ keywords = ["eeg", "datasets", "reproducibility", "bci", "benchmark"] license = "BSD-3-Clause" [tool.poetry.dependencies] -python = ">=3.9" +python = ">=3.10" numpy = "^1.22" scipy = "^1.9.3" -mne = "^1.7.0" +mne = ">=1.10.0" pandas = ">=1.5.2" h5py = "^3.10.0" matplotlib = "^3.6.2" @@ -30,8 +30,9 @@ memory-profiler = "^0.61.0" edflib-python = "^1.0.6" edfio = "^0.4.2" pytest = "^8.3.5" -mne-bids = ">=0.14" +mne-bids = ">=0.16" scikit-learn = "<1.6" +lmdb = ">=1.7.2" # Optional dependencies for carbon emission codecarbon = { version = "^2.1.4", optional = true } From c76acce8af45597505fb993da676863043810d1a Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Fri, 25 Jul 2025 16:28:23 +0200 Subject: [PATCH 17/29] trying to solve this shit... --- moabb/evaluations/evaluations.py | 108 +++++++++++++++++++++++++++++-- 1 file changed, 101 insertions(+), 7 deletions(-) diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index bfb10ff78..3bb17224f 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -1,10 +1,16 @@ +import hashlib +import inspect import logging from copy import deepcopy +from pathlib import Path from time import time from typing import Optional, Union import numpy as np +import pandas as pd +import torch from mne.epochs import BaseEpochs +from mne.utils import get_config from sklearn.base import clone from sklearn.metrics import get_scorer from sklearn.model_selection import ( @@ -14,7 +20,8 @@ StratifiedShuffleSplit, ) from sklearn.model_selection._validation import _score -from sklearn.preprocessing import LabelEncoder +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import FunctionTransformer, LabelEncoder from tqdm import tqdm from moabb.evaluations.base import BaseEvaluation @@ -661,15 +668,32 @@ def evaluate( if len(run_pipes) == 0: return - # get the data - X, y, metadata = self.paradigm.get_data( + memmap_path = get_memmap_path( dataset=dataset, + process_pipeline=process_pipeline, + postprocess_pipeline=postprocess_pipeline, return_epochs=self.return_epochs, return_raws=self.return_raws, - cache_config=self.cache_config, - postprocess_pipeline=postprocess_pipeline, - process_pipelines=[process_pipeline], ) + if not Path(memmap_path).exists(): + # get the data + X, y, metadata = self.paradigm.get_data( + dataset=dataset, + return_epochs=self.return_epochs, + return_raws=self.return_raws, + cache_config=self.cache_config, + postprocess_pipeline=postprocess_pipeline, + process_pipelines=[process_pipeline], + ) + Path(memmap_path).mkdir(parents=True, exist_ok=True) + np.save(f"{memmap_path}/X_memmap.npy", X) + np.save(f"{memmap_path}/y_memmap.npy", y) + 
metadata.to_pickle(f"{memmap_path}/metadata.pkl") + else: + # load the data from memmap + X = np.load(f"{memmap_path}/X_memmap.npy", mmap_mode="r") + y = np.load(f"{memmap_path}/y_memmap.npy", mmap_mode="r") + metadata = pd.read_pickle(f"{memmap_path}/metadata.pkl") # encode labels le = LabelEncoder() @@ -726,7 +750,6 @@ def evaluate( ) model = clone(clf) - model.fit(X[train], y[train]) if _carbonfootprint: @@ -778,3 +801,74 @@ def evaluate( def is_valid(self, dataset): return len(dataset.subject_list) > 1 + + +class MemmapEEGDataset(torch.utils.data.Dataset): + def __init__(self, x_path, y_path, transform=None): + self.X = np.load(x_path, mmap_mode="r") + self.y = np.load(y_path, mmap_mode="r") + self.transform = transform + + def __getitem__(self, index): + x = self.X[index] + y = self.y[index] + if self.transform: + x = self.transform(x) + return torch.from_numpy(x).float(), int(y) + + def __len__(self): + return len(self.y) + + +def get_pipeline_name(pipeline): + """Create a unique name for a pipeline or list of (name, transform) steps.""" + steps = [] + + if isinstance(pipeline, Pipeline): + steps = pipeline.steps + elif isinstance(pipeline, list): # raw list of (name, transform) + steps = pipeline + else: + return "noop" + + names = [] + for name, step in steps: + if isinstance(step, FunctionTransformer): + try: + source = inspect.getsource(step.func) + hashed = hashlib.md5(source.encode()).hexdigest()[:8] + step_name = f"FunctionTransformer_{hashed}" + except Exception: + step_name = f"FunctionTransformer_{id(step.func)}" + else: + step_name = getattr(step, "__class__", str(step)).__name__ + names.append(step_name) + + return "_".join(names) + + +def get_memmap_path( + dataset, + process_pipeline, + postprocess_pipeline=None, + return_epochs=False, + return_raws=False, +): + """Generate a clean path to store/load memmap data.""" + # Handle None safely + postprocess_pipeline = postprocess_pipeline or [] + + # Convert pipeline names to unique strings + process_name = get_pipeline_name(process_pipeline) + postprocess_name = get_pipeline_name(postprocess_pipeline) + + # Compose full path + memmap_path = Path(get_config("MNE_DATA")) / Path( + dataset.code, + process_name, + postprocess_name, + "memmap", + "epochs" if return_epochs else "raws" if return_raws else "data", + ) + + return str(memmap_path) From 08b97edc381e367a2bf2046e242700019e88610d Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Fri, 25 Jul 2025 16:41:31 +0200 Subject: [PATCH 18/29] crazy things here.. 
--- moabb/evaluations/evaluations.py | 33 +++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index 3bb17224f..3cf3245b3 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -685,19 +685,23 @@ def evaluate( postprocess_pipeline=postprocess_pipeline, process_pipelines=[process_pipeline], ) + le = LabelEncoder() + y = y if self.mne_labels else le.fit_transform(y) Path(memmap_path).mkdir(parents=True, exist_ok=True) np.save(f"{memmap_path}/X_memmap.npy", X) np.save(f"{memmap_path}/y_memmap.npy", y) metadata.to_pickle(f"{memmap_path}/metadata.pkl") else: # load the data from memmap - X = np.load(f"{memmap_path}/X_memmap.npy", mmap_mode="r") - y = np.load(f"{memmap_path}/y_memmap.npy", mmap_mode="r") - metadata = pd.read_pickle(f"{memmap_path}/metadata.pkl") + dataset_memmap = MemmapEEGDataset( + x_path=f"{memmap_path}/X_memmap.npy", + y_path=f"{memmap_path}/y_memmap.npy", + metadata_path=f"{memmap_path}/metadata.pkl", + ) - # encode labels - le = LabelEncoder() - y = y if self.mne_labels else le.fit_transform(y) + X = dataset_memmap.X + y = dataset_memmap.y + metadata = dataset_memmap.metadata # extract metadata groups = metadata.subject.values @@ -722,7 +726,6 @@ def evaluate( inner_cv = StratifiedKFold(3, shuffle=True, random_state=self.random_state) # Implement Grid Search - if _carbonfootprint: # Initialise CodeCarbon tracker = EmissionsTracker(save_to_file=False, log_level="error") @@ -750,7 +753,10 @@ def evaluate( ) model = clone(clf) - model.fit(X[train], y[train]) + + train_dataset = torch.utils.data.Subset(dataset_memmap, train) + + model.fit(train_dataset, y=None) if _carbonfootprint: emissions = tracker.stop() @@ -775,6 +781,8 @@ def evaluate( # we eval on each session for session in np.unique(sessions[test]): ix = sessions[test] == session + + # Extract number of channels from dataset score = _score( estimator=model, X_test=X[test[ix]], @@ -783,6 +791,8 @@ def evaluate( score_params={}, ) + nchan = X.shape[1] # since X is memmapped already + nchan = X.info["nchan"] if isinstance(X, BaseEpochs) else X.shape[1] res = { "time": duration, @@ -804,17 +814,18 @@ def is_valid(self, dataset): class MemmapEEGDataset(torch.utils.data.Dataset): - def __init__(self, x_path, y_path, transform=None): + def __init__(self, x_path, y_path, metadata_path, transform=None): self.X = np.load(x_path, mmap_mode="r") self.y = np.load(y_path, mmap_mode="r") + self.metadata = pd.read_pickle(metadata_path) self.transform = transform def __getitem__(self, index): x = self.X[index] - y = self.y[index] + y = int(self.y[index]) if self.transform: x = self.transform(x) - return torch.from_numpy(x).float(), int(y) + return torch.from_numpy(x).float(), y def __len__(self): return len(self.y) From d3a4aa2904baa4e31a99bbeb9675545655aefbf4 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Fri, 25 Jul 2025 16:48:56 +0200 Subject: [PATCH 19/29] too much things at the same time --- moabb/evaluations/evaluations.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index 3cf3245b3..b8e17f0f0 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -693,15 +693,9 @@ def evaluate( metadata.to_pickle(f"{memmap_path}/metadata.pkl") else: # load the data from memmap - dataset_memmap = MemmapEEGDataset( - x_path=f"{memmap_path}/X_memmap.npy", - 
y_path=f"{memmap_path}/y_memmap.npy", - metadata_path=f"{memmap_path}/metadata.pkl", - ) - - X = dataset_memmap.X - y = dataset_memmap.y - metadata = dataset_memmap.metadata + X = np.load(f"{memmap_path}/X_memmap.npy", mmap_mode="r") + y = np.load(f"{memmap_path}/y_memmap.npy", mmap_mode="r") + metadata = pd.read_pickle(f"{memmap_path}/metadata.pkl") # extract metadata groups = metadata.subject.values @@ -754,9 +748,7 @@ def evaluate( model = clone(clf) - train_dataset = torch.utils.data.Subset(dataset_memmap, train) - - model.fit(train_dataset, y=None) + model.fit(X[train], y[train]) if _carbonfootprint: emissions = tracker.stop() From 22f5950272973651d69c0b7a07d300981de9a91c Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 28 Jul 2025 19:54:21 +0200 Subject: [PATCH 20/29] reverting --- moabb/evaluations/evaluations.py | 115 +++---------------------------- 1 file changed, 8 insertions(+), 107 deletions(-) diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index b8e17f0f0..d96a47b6b 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -1,16 +1,10 @@ -import hashlib -import inspect import logging from copy import deepcopy -from pathlib import Path from time import time from typing import Optional, Union import numpy as np -import pandas as pd -import torch from mne.epochs import BaseEpochs -from mne.utils import get_config from sklearn.base import clone from sklearn.metrics import get_scorer from sklearn.model_selection import ( @@ -20,8 +14,7 @@ StratifiedShuffleSplit, ) from sklearn.model_selection._validation import _score -from sklearn.pipeline import Pipeline -from sklearn.preprocessing import FunctionTransformer, LabelEncoder +from sklearn.preprocessing import LabelEncoder from tqdm import tqdm from moabb.evaluations.base import BaseEvaluation @@ -668,34 +661,16 @@ def evaluate( if len(run_pipes) == 0: return - memmap_path = get_memmap_path( + X, y, metadata = self.paradigm.get_data( dataset=dataset, - process_pipeline=process_pipeline, - postprocess_pipeline=postprocess_pipeline, return_epochs=self.return_epochs, return_raws=self.return_raws, + cache_config=self.cache_config, + postprocess_pipeline=postprocess_pipeline, + process_pipelines=[process_pipeline], ) - if not Path(memmap_path).exists(): - # get the data - X, y, metadata = self.paradigm.get_data( - dataset=dataset, - return_epochs=self.return_epochs, - return_raws=self.return_raws, - cache_config=self.cache_config, - postprocess_pipeline=postprocess_pipeline, - process_pipelines=[process_pipeline], - ) - le = LabelEncoder() - y = y if self.mne_labels else le.fit_transform(y) - Path(memmap_path).mkdir(parents=True, exist_ok=True) - np.save(f"{memmap_path}/X_memmap.npy", X) - np.save(f"{memmap_path}/y_memmap.npy", y) - metadata.to_pickle(f"{memmap_path}/metadata.pkl") - else: - # load the data from memmap - X = np.load(f"{memmap_path}/X_memmap.npy", mmap_mode="r") - y = np.load(f"{memmap_path}/y_memmap.npy", mmap_mode="r") - metadata = pd.read_pickle(f"{memmap_path}/metadata.pkl") + le = LabelEncoder() + y = y if self.mne_labels else le.fit_transform(y) # extract metadata groups = metadata.subject.values @@ -746,9 +721,7 @@ def evaluate( param_grid=param_grid, name=name, grid_clf=clf, inner_cv=inner_cv ) - model = clone(clf) - - model.fit(X[train], y[train]) + model = deepcopy(clf).fit(X[train], y[train]) if _carbonfootprint: emissions = tracker.stop() @@ -803,75 +776,3 @@ def evaluate( def is_valid(self, dataset): return len(dataset.subject_list) > 1 - 
- -class MemmapEEGDataset(torch.utils.data.Dataset): - def __init__(self, x_path, y_path, metadata_path, transform=None): - self.X = np.load(x_path, mmap_mode="r") - self.y = np.load(y_path, mmap_mode="r") - self.metadata = pd.read_pickle(metadata_path) - self.transform = transform - - def __getitem__(self, index): - x = self.X[index] - y = int(self.y[index]) - if self.transform: - x = self.transform(x) - return torch.from_numpy(x).float(), y - - def __len__(self): - return len(self.y) - - -def get_pipeline_name(pipeline): - """Create a unique name for a pipeline or list of (name, transform) steps.""" - steps = [] - - if isinstance(pipeline, Pipeline): - steps = pipeline.steps - elif isinstance(pipeline, list): # raw list of (name, transform) - steps = pipeline - else: - return "noop" - - names = [] - for name, step in steps: - if isinstance(step, FunctionTransformer): - try: - source = inspect.getsource(step.func) - hashed = hashlib.md5(source.encode()).hexdigest()[:8] - step_name = f"FunctionTransformer_{hashed}" - except Exception: - step_name = f"FunctionTransformer_{id(step.func)}" - else: - step_name = getattr(step, "__class__", str(step)).__name__ - names.append(step_name) - - return "_".join(names) - - -def get_memmap_path( - dataset, - process_pipeline, - postprocess_pipeline=None, - return_epochs=False, - return_raws=False, -): - """Generate a clean path to store/load memmap data.""" - # Handle None safely - postprocess_pipeline = postprocess_pipeline or [] - - # Convert pipeline names to unique strings - process_name = get_pipeline_name(process_pipeline) - postprocess_name = get_pipeline_name(postprocess_pipeline) - - # Compose full path - memmap_path = Path(get_config("MNE_DATA")) / Path( - dataset.code, - process_name, - postprocess_name, - "memmap", - "epochs" if return_epochs else "raws" if return_raws else "data", - ) - - return str(memmap_path) From 8996991ab5dda1725fac558050ea3424ad309cb9 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 28 Jul 2025 20:15:47 +0200 Subject: [PATCH 21/29] evaluation --- moabb/evaluations/base.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/moabb/evaluations/base.py b/moabb/evaluations/base.py index 696c57e27..9bdbe3201 100644 --- a/moabb/evaluations/base.py +++ b/moabb/evaluations/base.py @@ -224,17 +224,18 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None): ] # Parallel processing... 
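
# The lambda-plus-list() wrapper that returns in this patch exists
# because evaluate() is a generator: a worker cannot ship a live
# generator back to the parent process, so the results are materialised
# inside the worker first. A toy illustration (the default loky backend
# pickles the lambda via cloudpickle, so this runs as written):
from joblib import Parallel, delayed

def produce(n):  # stands in for the evaluate() generator
    for i in range(n):
        yield i * i

# list(...) executes inside the worker, so concrete lists, not
# generators, travel back to the parent.
out = Parallel(n_jobs=2)(
    delayed(lambda n: list(produce(n)))(n) for n in (3, 4)
)
print(out)  # [[0, 1, 4], [0, 1, 4, 9]]
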
- parallel_results = Parallel( - n_jobs=self.n_jobs, - return_as="generator", - )( - delayed(self.evaluate)( - dataset, - pipelines, - param_grid=param_grid, - process_pipeline=process_pipeline, - postprocess_pipeline=postprocess_pipeline, - ) + parallel_results = Parallel(n_jobs=self.n_jobs)( + delayed( + lambda d, p: list( + self.evaluate( + d, + pipelines, + param_grid=param_grid, + process_pipeline=p, + postprocess_pipeline=postprocess_pipeline, + ) + ) + )(dataset, process_pipeline) for dataset, process_pipeline in processing_params ) From e1b6de16d3f5d6838821bf016a4b4da356b5a126 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 28 Jul 2025 20:47:01 +0200 Subject: [PATCH 22/29] including acceptance test --- moabb/tests/acceptance_tests/test_accurary.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 moabb/tests/acceptance_tests/test_accurary.py diff --git a/moabb/tests/acceptance_tests/test_accurary.py b/moabb/tests/acceptance_tests/test_accurary.py new file mode 100644 index 000000000..5b99ee760 --- /dev/null +++ b/moabb/tests/acceptance_tests/test_accurary.py @@ -0,0 +1,46 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +from pyriemann.classification import MDM +from pyriemann.estimation import XdawnCovariances +from sklearn.pipeline import make_pipeline +from sklearn.utils import check_random_state + +from moabb.datasets import BNCI2014_001, BNCI2015_001 +from moabb.evaluations import CrossSessionEvaluation +from moabb.paradigms import MotorImagery + + +@pytest.mark.parametrize("dataset_class", [BNCI2014_001, BNCI2015_001]) +def test_decoding_performance_stable(dataset_class): + dataset_name = dataset_class.__name__ + random_state = check_random_state(42) + + dataset_cls = dataset_class + dataset = dataset_cls() + paradigm = MotorImagery() + + # Simple pipeline + pipeline = make_pipeline(XdawnCovariances(nfilter=4), MDM(n_jobs=4)) + + # Evaluate + evaluation = CrossSessionEvaluation( + paradigm=paradigm, datasets=[dataset], overwrite=True, random_state=random_state + ) + results = evaluation.process({"mdm": pipeline}) + results.drop(columns=["time"], inplace=True) + results["score"] = results["score"].astype(np.float32) + results["samples"] = results["samples"].astype(int) + results["subject"] = results["subject"].astype(int) + + folder_path = Path(__file__).parent / "reference_results_dataset_{}.csv".format( + dataset_name + ) + reference_performance = pd.read_csv(folder_path) + reference_performance.drop(columns=["time", "Unnamed: 0"], inplace=True) + reference_performance["score"] = reference_performance["score"].astype(np.float32) + reference_performance["samples"] = reference_performance["samples"].astype(int) + + pd.testing.assert_frame_equal(results, reference_performance) From 041e0b7361a76809e69d467202b7d00c5a89c6da Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 28 Jul 2025 20:47:49 +0200 Subject: [PATCH 23/29] forcing two reference results --- ...reference_results_dataset_BNCI2014_001.csv | 19 ++++++++++++ ...reference_results_dataset_BNCI2015_001.csv | 29 +++++++++++++++++++ 2 files changed, 48 insertions(+) create mode 100644 moabb/tests/acceptance_tests/reference_results_dataset_BNCI2014_001.csv create mode 100644 moabb/tests/acceptance_tests/reference_results_dataset_BNCI2015_001.csv diff --git a/moabb/tests/acceptance_tests/reference_results_dataset_BNCI2014_001.csv b/moabb/tests/acceptance_tests/reference_results_dataset_BNCI2014_001.csv new file mode 100644 index 
000000000..b029c526f --- /dev/null +++ b/moabb/tests/acceptance_tests/reference_results_dataset_BNCI2014_001.csv @@ -0,0 +1,19 @@ +,score,time,samples,subject,session,channels,n_sessions,dataset,pipeline +0,0.7430556,0.28345227,288.0,1,0train,22,2,BNCI2014-001,mdm +1,0.6944444,0.2819698,288.0,1,1test,22,2,BNCI2014-001,mdm +2,0.5486111,0.28295708,288.0,2,0train,22,2,BNCI2014-001,mdm +3,0.5555556,0.28221202,288.0,2,1test,22,2,BNCI2014-001,mdm +4,0.6527778,0.27323103,288.0,3,0train,22,2,BNCI2014-001,mdm +5,0.6319444,0.28558397,288.0,3,1test,22,2,BNCI2014-001,mdm +6,0.4652778,0.28424382,288.0,4,0train,22,2,BNCI2014-001,mdm +7,0.6076389,0.28512216,288.0,4,1test,22,2,BNCI2014-001,mdm +8,0.4340278,0.26603198,288.0,5,0train,22,2,BNCI2014-001,mdm +9,0.47569445,0.2672441,288.0,5,1test,22,2,BNCI2014-001,mdm +10,0.38194445,0.28032613,288.0,6,0train,22,2,BNCI2014-001,mdm +11,0.4652778,0.29096103,288.0,6,1test,22,2,BNCI2014-001,mdm +12,0.5625,0.26360798,288.0,7,0train,22,2,BNCI2014-001,mdm +13,0.46875,0.26497293,288.0,7,1test,22,2,BNCI2014-001,mdm +14,0.6041667,0.27954388,288.0,8,0train,22,2,BNCI2014-001,mdm +15,0.6111111,0.29071403,288.0,8,1test,22,2,BNCI2014-001,mdm +16,0.5451389,0.27546215,288.0,9,0train,22,2,BNCI2014-001,mdm +17,0.7326389,0.2862649,288.0,9,1test,22,2,BNCI2014-001,mdm diff --git a/moabb/tests/acceptance_tests/reference_results_dataset_BNCI2015_001.csv b/moabb/tests/acceptance_tests/reference_results_dataset_BNCI2015_001.csv new file mode 100644 index 000000000..97d2c3265 --- /dev/null +++ b/moabb/tests/acceptance_tests/reference_results_dataset_BNCI2015_001.csv @@ -0,0 +1,29 @@ +,score,time,samples,subject,session,channels,n_sessions,dataset,pipeline +0,0.9898,0.104274035,200.0,1,0A,13,2,BNCI2015-001,mdm +1,0.996,0.109023094,200.0,1,1B,13,2,BNCI2015-001,mdm +2,0.9822,0.11902189,200.0,2,0A,13,2,BNCI2015-001,mdm +3,0.9817,0.10449815,200.0,2,1B,13,2,BNCI2015-001,mdm +4,0.9411,0.10515785,200.0,3,0A,13,2,BNCI2015-001,mdm +5,0.9713,0.10190797,200.0,3,1B,13,2,BNCI2015-001,mdm +6,0.8777,0.107106924,200.0,4,0A,13,2,BNCI2015-001,mdm +7,0.9653,0.10397911,200.0,4,1B,13,2,BNCI2015-001,mdm +8,0.8416,0.105483055,200.0,5,0A,13,2,BNCI2015-001,mdm +9,0.8118,0.10831189,200.0,5,1B,13,2,BNCI2015-001,mdm +10,0.6624,0.12765789,200.0,6,0A,13,2,BNCI2015-001,mdm +11,0.6314,0.10389686,200.0,6,1B,13,2,BNCI2015-001,mdm +12,0.8948,0.10865617,200.0,7,0A,13,2,BNCI2015-001,mdm +13,0.8931,0.09851694,200.0,7,1B,13,2,BNCI2015-001,mdm +14,0.6032,0.18366313,400.0,8,0A,13,2,BNCI2015-001,mdm +15,0.7523,0.19959378,400.0,8,1B,13,2,BNCI2015-001,mdm +16,0.8488,0.18477702,400.0,8,2C,13,2,BNCI2015-001,mdm +17,0.7601,0.1761918,400.0,9,0A,13,2,BNCI2015-001,mdm +18,0.8687,0.17262912,400.0,9,1B,13,2,BNCI2015-001,mdm +19,0.9154,0.17855692,400.0,9,2C,13,2,BNCI2015-001,mdm +20,0.6787,0.21773195,400.0,10,0A,13,2,BNCI2015-001,mdm +21,0.6402,0.20742917,400.0,10,1B,13,2,BNCI2015-001,mdm +22,0.6116,0.19268918,400.0,10,2C,13,2,BNCI2015-001,mdm +23,0.7974,0.20285797,400.0,11,0A,13,2,BNCI2015-001,mdm +24,0.7403,0.20020509,400.0,11,1B,13,2,BNCI2015-001,mdm +25,0.7949,0.18860793,400.0,11,2C,13,2,BNCI2015-001,mdm +26,0.6574,0.10171008,200.0,12,0A,13,2,BNCI2015-001,mdm +27,0.6693,0.10934806,200.0,12,1B,13,2,BNCI2015-001,mdm From f36c2f0d01119af991485a7e4b226cd51a1683e0 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 28 Jul 2025 20:52:45 +0200 Subject: [PATCH 24/29] reverting small detail --- moabb/evaluations/evaluations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/moabb/evaluations/evaluations.py 
b/moabb/evaluations/evaluations.py index 8fb285883..c89dade7c 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -749,8 +749,6 @@ def evaluate( score = scorer(model, X[test[ix]], y[test[ix]]) - nchan = X.shape[1] # since X is memmapped already - nchan = X.info["nchan"] if isinstance(X, BaseEpochs) else X.shape[1] res = { "time": duration, From 480497fdfe49f3f98fcdb7a45fccf98a5faf0b48 Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 28 Jul 2025 20:53:31 +0200 Subject: [PATCH 25/29] updating the pyproject --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d77dcf119..3fd3128c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,6 @@ edfio = "^0.4.2" pytest = "^8.3.5" mne-bids = ">=0.16" scikit-learn = "<1.6" -lmdb = ">=1.7.2" # Optional dependencies for carbon emission codecarbon = { version = "^2.1.4", optional = true } From 9d175578c3c8f62df14e8b2fbb94875de2183a7f Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 4 Aug 2025 15:52:48 +0200 Subject: [PATCH 26/29] upgrading the mne version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3fd3128c4..a65bf9a8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ license = "BSD-3-Clause" [tool.poetry.dependencies] python = ">=3.10" -numpy = "^2.0" +numpy = ">=2.0" scipy = "^1.9.3" mne = "^1.10.0" pandas = ">=1.5.2" From 5bcdfb513a31251ce17e81184b82c4ab75e7230c Mon Sep 17 00:00:00 2001 From: bruAristimunha Date: Mon, 4 Aug 2025 16:49:57 +0200 Subject: [PATCH 27/29] solving issue with saving --- .../plot_grid_search_withinsession.py | 42 ---------- moabb/evaluations/__init__.py | 2 +- moabb/evaluations/base.py | 5 ++ moabb/evaluations/evaluations.py | 79 ++++++++++--------- moabb/evaluations/utils.py | 12 +-- 5 files changed, 53 insertions(+), 87 deletions(-) diff --git a/examples/advanced_examples/plot_grid_search_withinsession.py b/examples/advanced_examples/plot_grid_search_withinsession.py index a2f1aefb7..f63bb82be 100644 --- a/examples/advanced_examples/plot_grid_search_withinsession.py +++ b/examples/advanced_examples/plot_grid_search_withinsession.py @@ -9,7 +9,6 @@ """ import os -from pickle import load import matplotlib.pyplot as plt import seaborn as sns @@ -132,44 +131,3 @@ ) sns.pointplot(data=result, y="score", x="pipeline", ax=axes, palette="Set1") axes.set_ylabel("ROC AUC") - -########################################################## -# Load Best Model Parameter -# ------------------------- -# The best model are automatically saved in a pickle file, in the -# results directory. It is possible to load those model for each -# dataset, subject and session. Here, we could see that the grid -# search found a l1_ratio that is different from the baseline -# value. 
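
# The removed example below showed how to unpickle the best fitted model
# from disk; after this refactor the fitted estimators are written per
# fold through the now-private helpers, and searched models are routed
# under "Search_<eval_type>" by the utils change later in this patch. A
# sketch of the helper calls, with a toy fitted pipeline; these are
# internal moabb helpers, signatures taken from this series:
import numpy as np
from sklearn.dummy import DummyClassifier
from moabb.evaluations.utils import _create_save_path, _save_model_cv

model = DummyClassifier().fit(np.zeros((4, 2)), [0, 1, 0, 1])
save_path = _create_save_path(
    hdf5_path="./Results",
    code="BNCI2014-001",
    subject=1,
    session="0train",
    name="GridSearchEN",
    grid=True,  # True routes the model under Search_WithinSession
    eval_type="WithinSession",
)
_save_model_cv(model=model, save_path=save_path, cv_index=0)
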
- -with open( - "./Results/Models_WithinSession/BNCI2014-001/1/1test/GridSearchEN/fitted_model_best.pkl", - "rb", -) as pickle_file: - GridSearchEN_Session_E = load(pickle_file) - -print( - "Best Parameter l1_ratio Session_E GridSearchEN ", - GridSearchEN_Session_E.best_params_["LogistReg__l1_ratio"], -) - -print( - "Best Parameter l1_ratio Session_E VanillaEN: ", - pipelines["VanillaEN"].steps[2][1].l1_ratio, -) - -with open( - "./Results/Models_WithinSession/BNCI2014-001/1/0train/GridSearchEN/fitted_model_best.pkl", - "rb", -) as pickle_file: - GridSearchEN_Session_T = load(pickle_file) - -print( - "Best Parameter l1_ratio Session_T GridSearchEN ", - GridSearchEN_Session_T.best_params_["LogistReg__l1_ratio"], -) - -print( - "Best Parameter l1_ratio Session_T VanillaEN: ", - pipelines["VanillaEN"].steps[2][1].l1_ratio, -) diff --git a/moabb/evaluations/__init__.py b/moabb/evaluations/__init__.py index 9f8eceff5..4a5695f48 100644 --- a/moabb/evaluations/__init__.py +++ b/moabb/evaluations/__init__.py @@ -10,4 +10,4 @@ WithinSessionEvaluation, ) from .splitters import CrossSessionSplitter, CrossSubjectSplitter, WithinSessionSplitter -from .utils import create_save_path, save_model_cv, save_model_list +from .utils import _create_save_path, _save_model_cv diff --git a/moabb/evaluations/base.py b/moabb/evaluations/base.py index 9bdbe3201..7f8d70158 100644 --- a/moabb/evaluations/base.py +++ b/moabb/evaluations/base.py @@ -77,6 +77,8 @@ class BaseEvaluation(ABC): optuna, time_out parameters. """ + search = False + def __init__( self, paradigm, @@ -327,9 +329,12 @@ def _grid_search(self, param_grid, name, grid_clf, inner_cv): return_train_score=True, **extra_params, ) + self.search = True return search else: + self.search = True return grid_clf else: + self.search = False return grid_clf diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index c89dade7c..36c2badcc 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -24,9 +24,8 @@ WithinSessionSplitter, ) from moabb.evaluations.utils import ( - create_save_path, - save_model_cv, - save_model_list, + _create_save_path, + _save_model_cv, ) @@ -143,7 +142,6 @@ def __init__( super().__init__(**kwargs) # flake8: noqa: C901 - def _evaluate( self, dataset, @@ -181,8 +179,9 @@ def _evaluate( # Initialize CodeCarbon tracker = EmissionsTracker(save_to_file=False, log_level="error") tracker.start() + t_start = time() - cv = WithinSessionSplitter( + self.cv = WithinSessionSplitter( n_folds=5, shuffle=True, random_state=self.random_state, @@ -198,17 +197,6 @@ def _evaluate( grid_clf = clone(clf) - # Create folder for grid search results - create_save_path( - self.hdf5_path, - dataset.code, - subject, - session, - name, - grid=True, - eval_type="WithinSession", - ) - # Implement Grid Search grid_clf = self._grid_search( param_grid=param_grid, @@ -216,16 +204,19 @@ def _evaluate( grid_clf=grid_clf, inner_cv=inner_cv, ) + if self.hdf5_path is not None and self.save_model: - model_save_path = create_save_path( + model_save_path = _create_save_path( self.hdf5_path, dataset.code, subject, session, name, - grid=False, + grid=self.search, eval_type="WithinSession", ) + else: + model_save_path = None scorer = get_scorer(self.paradigm.scoring) acc = list() @@ -233,7 +224,7 @@ def _evaluate( y_ = y[ix] if self.mne_labels else y_cv meta_ = metadata[ix].reset_index(drop=True) - for cv_ind, (train, test) in enumerate(cv.split(y_, meta_)): + for cv_ind, (train, test) in enumerate(self.cv.split(y_, meta_)): 
cvclf = clone(grid_clf) cvclf.fit(X_[train], y_[train]) @@ -241,7 +232,7 @@ def _evaluate( acc.append(scorer(cvclf, X_[test], y_[test])) if self.hdf5_path is not None and self.save_model: - save_model_cv( + _save_model_cv( model=cvclf, save_path=model_save_path, cv_index=cv_ind, @@ -255,10 +246,9 @@ def _evaluate( if emissions is None: emissions = np.nan duration = time() - t_start - nchan = X.info["nchan"] if isinstance(X, BaseEpochs) else X.shape[1] res = { - "time": duration / 5.0, # 5 fold CV + "time": duration / self.cv.n_folds, # 5 fold CV "dataset": dataset, "subject": subject, "session": session, @@ -517,7 +507,7 @@ def evaluate( tracker.start() # we want to store a results per session - cv = CrossSessionSplitter(random_state=self.random_state) + self.cv = CrossSessionSplitter(random_state=self.random_state) inner_cv = StratifiedKFold( 3, shuffle=True, random_state=self.random_state @@ -531,16 +521,17 @@ def evaluate( ) if self.hdf5_path is not None and self.save_model: - model_save_path = create_save_path( + model_save_path = _create_save_path( hdf5_path=self.hdf5_path, code=dataset.code, subject=subject, session="", name=name, - grid=False, + grid=self.search, eval_type="CrossSession", ) - for cv_ind, (train, test) in enumerate(cv.split(y, metadata)): + + for cv_ind, (train, test) in enumerate(self.cv.split(y, metadata)): model_list = [] if _carbonfootprint: tracker.start() @@ -554,7 +545,7 @@ def evaluate( score = scorer(cvclf, X[test], y[test]) if self.hdf5_path is not None and self.save_model: - save_model_cv( + _save_model_cv( model=cvclf, save_path=model_save_path, cv_index=str(cv_ind), @@ -566,12 +557,6 @@ def evaluate( emissions = 0 duration = time() - t_start - if self.hdf5_path is not None and self.save_model: - save_model_list( - model_list=model_list, - score_list=score, - save_path=model_save_path, - ) nchan = X.info["nchan"] if isinstance(X, BaseEpochs) else X.shape[1] res = { @@ -688,7 +673,7 @@ def evaluate( cv_kwargs = {"n_splits": self.n_splits} n_subjects = self.n_splits - cv = CrossSubjectSplitter( + self.cv = CrossSubjectSplitter( cv_class=cv_class, random_state=self.random_state, **cv_kwargs ) @@ -702,7 +687,7 @@ def evaluate( # Progressbar at subject level for cv_ind, (train, test) in enumerate( tqdm( - cv.split(y, metadata), + self.cv.split(y, metadata), total=n_subjects, desc=f"{dataset.code}-CrossSubject", ) @@ -721,6 +706,23 @@ def evaluate( param_grid=param_grid, name=name, grid_clf=clf, inner_cv=inner_cv ) + if self.hdf5_path is not None and self.save_model: + # Save the best model from grid search + model_save_path = _create_save_path( + hdf5_path=self.hdf5_path, + code=dataset.code, + subject=subject, + session="", + name=name, + grid=self.search, + eval_type="CrossSubject", + ) + _save_model_cv( + model=clf, + save_path=model_save_path, + cv_index=str(cv_ind), + ) + model = deepcopy(clf).fit(X[train], y[train]) if _carbonfootprint: @@ -730,17 +732,18 @@ def evaluate( duration = time() - t_start if self.hdf5_path is not None and self.save_model: - model_save_path = create_save_path( + + model_save_path = _create_save_path( hdf5_path=self.hdf5_path, code=dataset.code, subject=subject, session="", name=name, - grid=False, + grid=self.search, eval_type="CrossSubject", ) - save_model_cv( + _save_model_cv( model=model, save_path=model_save_path, cv_index=str(cv_ind) ) # we eval on each session diff --git a/moabb/evaluations/utils.py b/moabb/evaluations/utils.py index c15a044bc..f642e1191 100644 --- a/moabb/evaluations/utils.py +++ 
@@ -54,7 +54,7 @@ def _check_if_is_pytorch_steps(model):
     return skorch_valid


-def save_model_cv(model: object, save_path: str | Path, cv_index: str | int):
+def _save_model_cv(model: object, save_path: str | Path, cv_index: str | int):
     """Save a model fitted to a given fold from cross-validation.

     Parameters
@@ -96,7 +96,7 @@ def save_model_cv(model: object, save_path: str | Path, cv_index: str | int):
         dump(model, file, protocol=HIGHEST_PROTOCOL)


-def save_model_list(model_list: list | Pipeline, score_list: Sequence, save_path: str):
+def _save_model_list(model_list: list | Pipeline, score_list: Sequence, save_path: str):
     """Save a list of models fitted to a folder.

     Parameters
@@ -120,14 +120,14 @@ def save_model_list(model_list: list | Pipeline, score_list: Sequence, save_path
         model_list = [model_list]

     for cv_index, model in enumerate(model_list):
-        save_model_cv(model, save_path, str(cv_index))
+        _save_model_cv(model, save_path, str(cv_index))

     best_model = model_list[argmax(score_list)]

-    save_model_cv(best_model, save_path, "best")
+    _save_model_cv(best_model, save_path, "best")


-def create_save_path(
+def _create_save_path(
     hdf5_path,
     code: str,
     subject: int | str,
@@ -167,7 +167,7 @@ def create_save_path(
     if grid:
         path_save = (
             Path(hdf5_path)
-            / f"GridSearch_{eval_type}"
+            / f"Search_{eval_type}"
             / code
             / f"{str(subject)}"
             / str(session)
From 6fe882d1433a0b51673784e9e0c0e2cabbec3b0f Mon Sep 17 00:00:00 2001
From: bruAristimunha
Date: Mon, 4 Aug 2025 16:53:55 +0200
Subject: [PATCH 28/29] scoring

---
 moabb/evaluations/evaluations.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py
index 36c2badcc..93d4d55f6 100644
--- a/moabb/evaluations/evaluations.py
+++ b/moabb/evaluations/evaluations.py
@@ -215,8 +215,6 @@ def _evaluate(
                 grid=self.search,
                 eval_type="WithinSession",
             )
-            else:
-                model_save_path = None

             scorer = get_scorer(self.paradigm.scoring)
             acc = list()
@@ -229,7 +227,9 @@ def _evaluate(

                cvclf.fit(X_[train], y_[train])

-                acc.append(scorer(cvclf, X_[test], y_[test]))
+                score = scorer(cvclf, X_[test], y_[test])
+
+                acc.append(score)

                if self.hdf5_path is not None and self.save_model:
                    _save_model_cv(
From 492ad04e91f290ed15afa8f0d6cc7ac5228b3260 Mon Sep 17 00:00:00 2001
From: bruAristimunha
Date: Mon, 4 Aug 2025 17:35:32 +0200
Subject: [PATCH 29/29] fixing import

---
 moabb/tests/test_evaluations.py | 42 ++-------------------------------
 1 file changed, 2 insertions(+), 40 deletions(-)

diff --git a/moabb/tests/test_evaluations.py b/moabb/tests/test_evaluations.py
index 0dc4f98c0..6beb73edb 100644
--- a/moabb/tests/test_evaluations.py
+++ b/moabb/tests/test_evaluations.py
@@ -18,7 +18,8 @@
 from moabb.datasets.fake import FakeDataset
 from moabb.evaluations import evaluations as ev
 from moabb.evaluations.base import optuna_available
-from moabb.evaluations.utils import create_save_path, save_model_cv, save_model_list
+from moabb.evaluations.utils import _create_save_path as create_save_path
+from moabb.evaluations.utils import _save_model_cv as save_model_cv
 from moabb.paradigms.motor_imagery import FakeImageryParadigm
@@ -393,17 +394,6 @@ def test_save_model_cv(self):
         # Assert that the saved model file exists
         assert os.path.isfile(os.path.join(save_path, "fitted_model_0.pkl"))

-    def test_save_model_list(self):
-        step = Dummy()
-        model = Pipeline([("step", step)])
-        model_list = [model]
-        score_list = [0.8]
-        save_path = "test_save_path"
-        save_model_list(model_list, score_list, save_path)
-
-        # Assert that the saved model file for best model exists
-        assert os.path.isfile(os.path.join(save_path, "fitted_model_best.pkl"))
-
     def test_create_save_path(self):
         hdf5_path = "base_path"
         code = "evaluation_code"
@@ -454,21 +444,6 @@ def test_save_model_cv_with_pytorch_model(self):
         assert os.path.isfile(os.path.join(save_path, "step_fitted_0_history.json"))
         assert os.path.isfile(os.path.join(save_path, "step_fitted_0_criterion.pkl"))

-    def test_save_model_list_with_multiple_models(self):
-        model1 = Dummy()
-        model2 = Dummy()
-        model_list = [model1, model2]
-        score_list = [0.8, 0.9]
-        save_path = "test_save_path"
-        save_model_list(model_list, score_list, save_path)
-
-        # Assert that the saved model files for each model exist
-        assert os.path.isfile(os.path.join(save_path, "fitted_model_0.pkl"))
-        assert os.path.isfile(os.path.join(save_path, "fitted_model_1.pkl"))
-
-        # Assert that the saved model file for the best model exists
-        assert os.path.isfile(os.path.join(save_path, "fitted_model_best.pkl"))
-
     def test_create_save_path_with_cross_session_evaluation(self):
         hdf5_path = "base_path"
         code = "evaluation_code"
@@ -516,19 +491,6 @@ def test_save_model_cv_without_hdf5_path(self):
         with pytest.raises(IOError):
             save_model_cv(model, save_path, cv_index)

-    def test_save_model_list_with_single_model(self):
-        model = Dummy()
-        model_list = model
-        score_list = [0.8]
-        save_path = "test_save_path"
-        save_model_list(model_list, score_list, save_path)
-
-        # Assert that the saved model file for the single model exists
-        assert os.path.isfile(os.path.join(save_path, "fitted_model_0.pkl"))
-
-        # Assert that the saved model file for the best model exists
-        assert os.path.isfile(os.path.join(save_path, "fitted_model_best.pkl"))
-
    def test_create_save_path_with_cross_subject_evaluation(self):
        hdf5_path = "base_path"
        code = "evaluation_code"
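
Usage sketch (not part of the patch series): the hunks above move all three evaluations onto splitter objects that are indexed with cv.split(y, metadata) instead of sklearn's cv.split(X, y, groups). The snippet below illustrates that calling convention; the constructor arguments mirror the call sites in the diffs, while the toy labels and metadata are made-up stand-ins for paradigm.get_data() output, so treat it as an assumption-laden sketch rather than a reference.

    import numpy as np
    import pandas as pd

    from moabb.evaluations.splitters import (
        CrossSessionSplitter,
        CrossSubjectSplitter,
        WithinSessionSplitter,
    )

    # Toy data: 2 subjects x 2 sessions x 10 epochs, 2 balanced classes.
    y = np.array([0, 1] * 20)
    metadata = pd.DataFrame(
        {
            "subject": [1] * 20 + [2] * 20,
            "session": (["0train"] * 10 + ["1test"] * 10) * 2,
        }
    )

    # Within-session: stratified folds inside each (subject, session) block,
    # matching WithinSessionSplitter(n_folds=5, shuffle=True, ...) above.
    cv = WithinSessionSplitter(n_folds=5, shuffle=True, random_state=42)

    # Cross-session / cross-subject variants, as constructed in the hunks above
    # (CrossSubjectSplitter additionally takes sklearn's LeaveOneGroupOut or
    # GroupKFold via cv_class / n_splits):
    # cv = CrossSessionSplitter(random_state=42)
    # cv = CrossSubjectSplitter(cv_class=LeaveOneGroupOut, random_state=42)

    for cv_ind, (train, test) in enumerate(cv.split(y, metadata)):
        print(cv_ind, len(train), len(test))

Note that the within-session evaluation also keeps the fitted splitter on self.cv, which is what allows the elapsed time to be averaged with duration / self.cv.n_folds instead of the hard-coded 5.0.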
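A second sketch (also not part of the patches) of how the GridSearch_ -> Search_ rename in _create_save_path interacts with the new search flag: _grid_search records whether a hyperparameter search was actually built, and every save site now passes grid=self.search. The "Search_" prefix comes from the utils.py hunk; the "Models_" prefix for the non-search branch and the exact component order are assumptions inferred from the example paths deleted at the top of this series ("./Results/Models_WithinSession/...").

    from pathlib import Path


    def sketch_save_path(hdf5_path, code, subject, session, name, grid, eval_type):
        # "Search_" is taken from the utils.py hunk above; "Models_" is an
        # assumed fallback for the non-search branch.
        prefix = f"Search_{eval_type}" if grid else f"Models_{eval_type}"
        return Path(hdf5_path) / prefix / code / str(subject) / str(session) / name


    # With a param_grid supplied, _grid_search sets self.search = True,
    # so fitted models would land under, e.g.:
    print(
        sketch_save_path(
            "./Results", "BNCI2014-001", 1, "0train",
            "GridSearchEN", grid=True, eval_type="WithinSession",
        )
    )
    # Results/Search_WithinSession/BNCI2014-001/1/0train/GridSearchEN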