Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions skl2onnx/operator_converters/text_vectorizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) ...
# Licensed under the MIT License.

from skl2onnx.common._apply_operation import apply_tokenizer
from skl2onnx.common.data_types import StringTensorType
from skl2onnx.common._registration import register_converter


def convert_sklearn_count_vectorizer(scope, operator, container):
op = operator.raw_operator
input_var = operator.inputs[0]
output_var = operator.outputs[0]

if not isinstance(input_var.type, StringTensorType):
raise RuntimeError("CountVectorizer input must be a string tensor")

analyzer = getattr(op, "analyzer", "word")

if analyzer == "word":
# existing word-level tokenizer
apply_tokenizer(
scope,
input_var.full_name,
output_var.full_name,
container,
op.vocabulary_,
mark=operator.full_name,
)

elif analyzer in ("char", "char_wb"):
ngram_range = getattr(op, "ngram_range", (1, 1))

# ONNX Tokenizer regex for single characters
pattern = "."
if analyzer == "char_wb":
# Approximate: capture chars incl. word boundaries
# Note: real sklearn pads words with spaces, this regex
# is an approximation (still useful for deployment).
pattern = r"\b.|.\b|."

apply_tokenizer(
scope,
input_var.full_name,
output_var.full_name,
container,
op.vocabulary_,
mark=operator.full_name,
pattern=pattern,
ngram_range=ngram_range,
)

else:
raise NotImplementedError(
f"Analyzer={analyzer!r} not yet supported in skl2onnx."
)


register_converter(
"SklearnCountVectorizer",
convert_sklearn_count_vectorizer,
)

45 changes: 45 additions & 0 deletions tests/test_sklearn_vectorizers_char.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import unittest
import numpy as np
import onnxruntime as rt

from sklearn.feature_extraction.text import TfidfVectorizer
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import StringTensorType


class TestCharVectorizers(unittest.TestCase):
def _run_vectorizer(self, analyzer):
vec = TfidfVectorizer(
analyzer=analyzer,
ngram_range=(2, 5),
min_df=1,
max_features=1000,
)
vec.fit(["купить дрель", "газонокосилка", "пластиковые стяжки"])

onx = convert_sklearn(
vec,
initial_types=[("input", StringTensorType([None, 1]))],
)

sess = rt.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
res = sess.run(
None,
{"input": np.array([["купить дрель"]])},
)
return res[0]

def test_char_vectorizer(self):
out = self._run_vectorizer("char")
self.assertEqual(len(out.shape), 2)
self.assertGreater(out.shape[1], 0)

def test_char_wb_vectorizer(self):
out = self._run_vectorizer("char_wb")
self.assertEqual(len(out.shape), 2)
self.assertGreater(out.shape[1], 0)


if __name__ == "__main__":
unittest.main()

Loading