From 1bd1f96603a78ce0b06eef3f4401d4e8e5de0c32 Mon Sep 17 00:00:00 2001
From: Fedor Krasnov
Date: Wed, 3 Sep 2025 20:45:27 +0300
Subject: [PATCH] Add support for char and char_wb analyzers in
 TfidfVectorizer/CountVectorizer

---
 .../operator_converters/text_vectorizer.py | 78 ++++++++++++++++++++
 tests/test_sklearn_vectorizers_char.py     | 49 +++++++++++++
 2 files changed, 127 insertions(+)
 create mode 100644 skl2onnx/operator_converters/text_vectorizer.py
 create mode 100644 tests/test_sklearn_vectorizers_char.py

diff --git a/skl2onnx/operator_converters/text_vectorizer.py b/skl2onnx/operator_converters/text_vectorizer.py
new file mode 100644
index 000000000..f4762bffd
--- /dev/null
+++ b/skl2onnx/operator_converters/text_vectorizer.py
@@ -0,0 +1,78 @@
+# Copyright (c) ...
+# Licensed under the MIT License.
+
+from skl2onnx.common._apply_operation import apply_tokenizer
+from skl2onnx.common.data_types import StringTensorType
+from skl2onnx.common._registration import register_converter
+
+
+def convert_sklearn_count_vectorizer(scope, operator, container):
+    op = operator.raw_operator
+    input_var = operator.inputs[0]
+    output_var = operator.outputs[0]
+
+    if not isinstance(input_var.type, StringTensorType):
+        raise RuntimeError("CountVectorizer input must be a string tensor")
+
+    analyzer = getattr(op, "analyzer", "word")
+
+    if analyzer == "word":
+        # Existing word-level tokenizer path, unchanged.
+        apply_tokenizer(
+            scope,
+            input_var.full_name,
+            output_var.full_name,
+            container,
+            op.vocabulary_,
+            mark=operator.full_name,
+        )
+
+    elif analyzer in ("char", "char_wb"):
+        ngram_range = getattr(op, "ngram_range", (1, 1))
+
+        # The ONNX Tokenizer splits on a regular expression; "." emits
+        # one token per character, matching sklearn's "char" analyzer.
+        pattern = "."
+        if analyzer == "char_wb":
+            # Approximation: sklearn's char_wb pads each word with a
+            # space before extracting n-grams. That padding cannot be
+            # expressed as a tokenization regex (every alternative
+            # below still matches exactly one character), so n-grams
+            # at word edges may differ from sklearn's output.
+            pattern = r"\b.|.\b|."
+
+        apply_tokenizer(
+            scope,
+            input_var.full_name,
+            output_var.full_name,
+            container,
+            op.vocabulary_,
+            mark=operator.full_name,
+            pattern=pattern,
+            ngram_range=ngram_range,
+        )
+
+    else:
+        raise NotImplementedError(
+            f"analyzer={analyzer!r} is not yet supported by skl2onnx."
+        )
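+
+
+# For reference, sklearn's own char_wb analyzer pads each word with a
+# single space on both sides before extracting n-grams:
+#
+#     >>> from sklearn.feature_extraction.text import CountVectorizer
+#     >>> analyze = CountVectorizer(
+#     ...     analyzer="char_wb", ngram_range=(3, 3)
+#     ... ).build_analyzer()
+#     >>> analyze("ab cd")
+#     [' ab', 'ab ', ' cd', 'cd ']
+#
+# The char_wb branch above only approximates this padding.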
+
+
+register_converter(
+    "SklearnCountVectorizer",
+    convert_sklearn_count_vectorizer,
+)
+
diff --git a/tests/test_sklearn_vectorizers_char.py b/tests/test_sklearn_vectorizers_char.py
new file mode 100644
index 000000000..0e57bca31
--- /dev/null
+++ b/tests/test_sklearn_vectorizers_char.py
@@ -0,0 +1,49 @@
+import unittest
+import numpy as np
+import onnxruntime as rt
+
+from sklearn.feature_extraction.text import TfidfVectorizer
+from skl2onnx import convert_sklearn
+from skl2onnx.common.data_types import StringTensorType
+
+
+class TestCharVectorizers(unittest.TestCase):
+    def _run_vectorizer(self, analyzer):
+        vec = TfidfVectorizer(
+            analyzer=analyzer,
+            ngram_range=(2, 5),
+            min_df=1,
+            max_features=1000,
+        )
+        # Cyrillic strings also exercise non-ASCII tokenization.
+        vec.fit(["купить дрель", "газонокосилка", "пластиковые стяжки"])
+
+        onx = convert_sklearn(
+            vec,
+            initial_types=[("input", StringTensorType([None, 1]))],
+        )
+
+        sess = rt.InferenceSession(
+            onx.SerializeToString(),
+            providers=["CPUExecutionProvider"],
+        )
+        res = sess.run(
+            None,
+            {"input": np.array([["купить дрель"]])},
+        )
+        return res[0]
+
+    def test_char_vectorizer(self):
+        out = self._run_vectorizer("char")
+        self.assertEqual(len(out.shape), 2)
+        self.assertGreater(out.shape[1], 0)
+
+    def test_char_wb_vectorizer(self):
+        out = self._run_vectorizer("char_wb")
+        self.assertEqual(len(out.shape), 2)
+        self.assertGreater(out.shape[1], 0)
+
+
+if __name__ == "__main__":
+    unittest.main()
+
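
Reviewer note: a quick local sanity check for this converter is to compare
the ONNX output with sklearn's own transform. The snippet below is a
sketch, assuming the converter above is registered with skl2onnx; close
numerical agreement is only expected for analyzer="char", since char_wb
is approximated (see the comments in the converter).

    import numpy as np
    import onnxruntime as rt
    from sklearn.feature_extraction.text import TfidfVectorizer
    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import StringTensorType

    corpus = ["buy a drill", "lawn mower", "plastic cable ties"]
    vec = TfidfVectorizer(analyzer="char", ngram_range=(2, 5)).fit(corpus)

    onx = convert_sklearn(
        vec, initial_types=[("input", StringTensorType([None, 1]))]
    )
    sess = rt.InferenceSession(
        onx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    # A text-vectorizer model produces a single dense output tensor.
    (got,) = sess.run(None, {"input": np.array([["buy a drill"]])})

    # sklearn returns a sparse matrix; densify before comparing.
    expected = vec.transform(["buy a drill"]).toarray()
    np.testing.assert_allclose(expected, got, rtol=1e-4)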