Skip to content

Commit a7f978b

Browse files
committed
Enable model_list filtering for AUTO_SELECT_SERIES via MetaSelector (from spec.model_kwargs), add tests
1 parent 42f297b commit a7f978b

File tree

4 files changed

+112
-5
lines changed

4 files changed

+112
-5
lines changed

ads/opctl/operator/lowcode/forecast/meta_selector.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class MetaSelector:
1313
The rules are based on the meta-features calculated by the FFORMS approach.
1414
"""
1515

16-
def __init__(self):
16+
def __init__(self, allowed_models=None):
1717
"""Initialize the MetaSelector with pre-learned meta rules"""
1818
# Pre-learned rules based on meta-features
1919
self._meta_rules = {
@@ -216,6 +216,22 @@ def __init__(self):
216216
},
217217
}
218218

219+
# Normalize and apply allowed_models filter if provided
220+
self._allowed_set = None
221+
if allowed_models:
222+
known = {"prophet", "arima", "neuralprophet", "automlx", "autots"}
223+
if isinstance(allowed_models, (list, tuple, set)):
224+
self._allowed_set = {str(m).lower() for m in allowed_models}
225+
else:
226+
self._allowed_set = {str(allowed_models).lower()}
227+
self._allowed_set = {m for m in self._allowed_set if m in known}
228+
if self._allowed_set:
229+
self._meta_rules = {
230+
name: rule
231+
for name, rule in self._meta_rules.items()
232+
if rule.get("model") in self._allowed_set
233+
}
234+
219235
def _evaluate_condition(self, value, operator, threshold):
220236
"""Evaluate a single condition based on pre-defined operators"""
221237
if pd.isna(value):
@@ -288,7 +304,13 @@ def select_best_model(self, meta_features_df):
288304
series_info["matched_features"] = matched_features[best_rule]
289305
else:
290306
best_rule = "default"
291-
best_model = "prophet" # Default to prophet if no rules match
307+
if getattr(self, "_allowed_set", None):
308+
if "prophet" in self._allowed_set:
309+
best_model = "prophet"
310+
else:
311+
best_model = sorted(self._allowed_set)[0]
312+
else:
313+
best_model = "prophet" # Default to prophet if no rules match
292314
series_info["matched_features"] = []
293315

294316
series_info["selected_model"] = best_model

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@
4747
AUTO_SELECT,
4848
BACKTEST_REPORT_NAME,
4949
SUMMARY_METRICS_HORIZON_LIMIT,
50+
TROUBLESHOOTING_GUIDE,
5051
ForecastOutputColumns,
5152
SpeedAccuracyMode,
5253
SupportedMetrics,
5354
SupportedModels,
54-
TROUBLESHOOTING_GUIDE,
5555
)
5656
from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
5757
from .forecast_datasets import ForecastDatasets, ForecastResults

ads/opctl/operator/lowcode/forecast/model/factory.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def get_model(
7676

7777
if model_type == AUTO_SELECT_SERIES:
7878
# Initialize MetaSelector for series-specific model selection
79-
selector = MetaSelector()
79+
allowed = operator_config.spec.model_kwargs.get("model_list", None) if hasattr(operator_config.spec, "model_kwargs") and operator_config.spec.model_kwargs else None
80+
selector = MetaSelector(allowed_models=allowed)
8081
# Create a Transformations instance
8182
transformer = Transformations(dataset_info=datasets.historical_data.spec)
8283

@@ -89,7 +90,15 @@ def get_model(
8990
)
9091
)
9192
# Get the most common model as default
92-
model_type = meta_features['selected_model'].mode().iloc[0]
93+
selected_str = str(meta_features['selected_model'].mode().iloc[0]).lower()
94+
str_to_enum = {
95+
"prophet": SupportedModels.Prophet,
96+
"arima": SupportedModels.Arima,
97+
"neuralprophet": SupportedModels.NeuralProphet,
98+
"automlx": SupportedModels.AutoMLX,
99+
"autots": SupportedModels.AutoTS,
100+
}
101+
model_type = str_to_enum.get(selected_str, SupportedModels.Prophet)
93102
# Store the series-specific model selections in the config for later use
94103
operator_config.spec.meta_features = meta_features
95104
operator_config.spec.model_kwargs = {}

tests/operators/forecast/test_datasets.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,5 +413,81 @@ def run_operator(
413413
# generate_train_metrics = True
414414

415415

416+
@pytest.mark.parametrize("allowed", [["prophet", "arima"], ["prophet"], ["arima"], ["automlx"], ["neuralprophet"]])
417+
def test_auto_select_series_model_list_filter(allowed):
418+
# Skip neuralprophet when running with NumPy 2.x due to upstream np.NaN usage
419+
if "neuralprophet" in allowed:
420+
try:
421+
import numpy as np # local import to avoid unused import in other tests
422+
major = int(str(np.__version__).split(".")[0])
423+
except Exception:
424+
major = 0
425+
if major >= 2:
426+
pytest.skip("Skipping neuralprophet with NumPy >= 2.0 due to upstream incompatibility (uses np.NaN).")
427+
428+
# Skip pure-arima case if pmdarima cannot be imported (e.g., binary incompatibility with current NumPy)
429+
if [str(m).lower() for m in allowed] == ["arima"]:
430+
try:
431+
import pmdarima as pm # noqa: F401
432+
except Exception as e:
433+
pytest.skip(f"Skipping arima due to pmdarima import error: {e}")
434+
435+
dataset_name = f"{DATASET_PREFIX}dataset1.csv"
436+
dataset_i = pd.read_csv(dataset_name)
437+
target = "Y"
438+
439+
with tempfile.TemporaryDirectory() as tmpdirname:
440+
historical_data_path = f"{tmpdirname}/primary_data.csv"
441+
test_data_path = f"{tmpdirname}/test_data.csv"
442+
output_data_path = f"{tmpdirname}/results"
443+
yaml_i = deepcopy(TEMPLATE_YAML)
444+
445+
# Train/Test split
446+
dataset_i[[DATETIME_COL, target]][:-PERIODS].to_csv(
447+
historical_data_path, index=False
448+
)
449+
dataset_i[[DATETIME_COL, target]][-PERIODS:].to_csv(test_data_path, index=False)
450+
451+
# Prepare YAML
452+
yaml_i["spec"]["historical_data"]["url"] = historical_data_path
453+
yaml_i["spec"]["test_data"] = {"url": test_data_path}
454+
yaml_i["spec"]["output_directory"]["url"] = output_data_path
455+
yaml_i["spec"]["model"] = "auto-select-series"
456+
yaml_i["spec"]["target_column"] = target
457+
yaml_i["spec"]["datetime_column"]["name"] = DATETIME_COL
458+
yaml_i["spec"]["horizon"] = PERIODS
459+
yaml_i["spec"]["generate_metrics"] = True
460+
yaml_i["spec"]["model_kwargs"] = {"model_list": allowed}
461+
462+
# Run operator
463+
run(yaml_i, backend="operator.local", debug=False)
464+
465+
# Collect per-model metrics produced by auto-select-series
466+
result_files = os.listdir(output_data_path)
467+
train_metrics_files = [
468+
f for f in result_files if f.startswith("metrics_") and f.endswith(".csv")
469+
]
470+
test_metrics_files = [
471+
f
472+
for f in result_files
473+
if f.startswith("test_metrics_") and f.endswith(".csv")
474+
]
475+
476+
# Extract model names from filenames
477+
found_models = set()
478+
for f in train_metrics_files:
479+
found_models.add(f[len("metrics_") : -len(".csv")])
480+
for f in test_metrics_files:
481+
found_models.add(f[len("test_metrics_") : -len(".csv")])
482+
483+
assert found_models, "No per-model metrics files were generated."
484+
# Ensure only allowed models are present
485+
assert found_models.issubset(set(allowed)), f"Found disallowed models in outputs: {found_models - set(allowed)}"
486+
487+
# Ensure disallowed models are absent
488+
known_models = {"prophet", "arima", "neuralprophet", "automlx", "autots"}
489+
disallowed = known_models - set(allowed)
490+
assert found_models.isdisjoint(disallowed), f"Disallowed models present: {found_models & disallowed}"
491+
416492
if __name__ == "__main__":
417493
pass

0 commit comments

Comments (0)