11# SPDX-License-Identifier: Apache-2.0
2- from typing import List , Optional
2+ from typing import Optional
33
44import numpy as np
5+ from numba import jit
56
67
78class NgramProposer :
89
9- def __init__ (self ):
10- pass
11-
1210 def propose (
1311 self ,
1412 context_token_ids : np .ndarray ,
@@ -21,7 +19,7 @@ def propose(
2119 that match.
2220
2321 Args:
24- context_token_ids: List of token IDs representing the
22+ context_token_ids: Numpy array of token IDs representing the
2523 context sequence.
2624 n: Length of the n-gram to match.
2725 k: Number of tokens follow the match. If there are less
@@ -41,66 +39,65 @@ def propose(
4139 followed that pattern. Here we will return [4,2,3] because
4240 we only have three tokens after the match.
4341 """
44- # TODO: Use c++ to implement the _find_subarray_kmp to
45- # improve the efficiency
46- return self ._find_subarray_kmp (context_token_ids , n , k )
42+ return _find_subarray_kmp (context_token_ids , n , k )
4743
48- @staticmethod
49- def _kmp_lps_array (pattern : List [int ]) -> List [int ]:
50- """
51- Build the lps (longest proper prefix which is also suffix)
52- array for the pattern.
53- """
54- lps = [0 ] * len (pattern )
55- prev_lps = 0 # length of the previous longest prefix suffix
56- i = 1
5744
58- while i < len (pattern ):
59- if pattern [i ] == pattern [prev_lps ]:
60- prev_lps += 1
61- lps [i ] = prev_lps
62- i += 1
45+ @jit (nopython = True )
46+ def _kmp_lps_array (pattern : np .ndarray ) -> np .ndarray :
47+ """
48+ Build the lps (longest proper prefix which is also suffix)
49+ array for the pattern.
50+ """
51+ lps = np .zeros (len (pattern ), dtype = np .int32 )
52+ prev_lps = 0 # length of the previous longest prefix suffix
53+ i = 1
54+
55+ while i < len (pattern ):
56+ if pattern [i ] == pattern [prev_lps ]:
57+ prev_lps += 1
58+ lps [i ] = prev_lps
59+ i += 1
60+ else :
61+ if prev_lps != 0 :
62+ prev_lps = lps [prev_lps - 1 ]
6363 else :
64- if prev_lps != 0 :
65- prev_lps = lps [prev_lps - 1 ]
66- else :
67- lps [i ] = 0
68- i += 1
64+ lps [i ] = 0
65+ i += 1
66+ return lps
6967
70- return lps
7168
72- @ staticmethod
73- def _find_subarray_kmp (
74- context_token_ids : np .ndarray ,
75- n : int ,
76- k : int ,
77- ) -> Optional [np .ndarray ]:
78- context_len = context_token_ids .shape [0 ]
79- assert n > 0
69+ @ jit ( nopython = True )
70+ def _find_subarray_kmp (
71+ context_token_ids : np .ndarray ,
72+ n : int ,
73+ k : int ,
74+ ) -> Optional [np .ndarray ]:
75+ context_len = context_token_ids .shape [0 ]
76+ assert n > 0
8077
81- pattern = context_token_ids [- n :]
82- # Precompute lps array for Y
83- lps = NgramProposer . _kmp_lps_array (pattern )
78+ pattern = context_token_ids [- n :]
79+ # Precompute lps array for Y
80+ lps = _kmp_lps_array (pattern )
8481
85- i = 0
86- j = 0
87- # -n because the last n tokens are used as pattern
88- while i < context_len - n :
89- if context_token_ids [i ] == pattern [j ]:
90- i += 1
91- j += 1
82+ i = 0
83+ j = 0
84+ # -n because the last n tokens are used as pattern
85+ while i < context_len - n :
86+ if context_token_ids [i ] == pattern [j ]:
87+ i += 1
88+ j += 1
9289
93- # If we have matched the entire Y
94- if j == n :
95- # Found pattern in context, gather the next K elements
96- return context_token_ids [i :i + k ]
90+ # If we have matched the entire Y
91+ if j == n :
92+ # Found pattern in context, gather the next K elements
93+ return context_token_ids [i :i + k ]
94+ else :
95+ # Mismatch
96+ if j != 0 :
97+ # Use the lps array to avoid re-checking elements
98+ j = lps [j - 1 ]
9799 else :
98- # Mismatch
99- if j != 0 :
100- # Use the lps array to avoid re-checking elements
101- j = lps [j - 1 ]
102- else :
103- i += 1
100+ i += 1
104101
105- # Y not found
106- return None
102+ # Y not found
103+ return None
0 commit comments