diff --git a/benchmarks/bench_regex_fsm.py b/benchmarks/bench_regex_fsm.py new file mode 100644 index 000000000..48585c7fe --- /dev/null +++ b/benchmarks/bench_regex_fsm.py @@ -0,0 +1,56 @@ +import random + +from outlines.caching import cache_disabled +from outlines.fsm.regex import reduced_vocabulary +from outlines.models.tokenizer import Tokenizer + +from .common import ensure_numba_compiled + + +class MockTokenizer(Tokenizer): + def __init__(self, token_strs): + self.eos_token = "" + self.eos_token_id = 0 + self.pad_token_id = 1 + self.special_tokens = {0, 1} + + self.vocabulary = {"": 0, "": 1} + + for i, tok in enumerate(token_strs): + self.vocabulary[tok] = i + 2 + + @classmethod + def from_random_tokens(cls, n_tokens, max_token_length=8, seed=42): + random.seed(seed) + tokens = [ + "".join( + chr(random.randint(0, 4096)) + for __ in range(random.randint(0, max_token_length)) + ) + for _ in range(n_tokens) + ] + return cls(tokens) + + def convert_token_to_string(self, token): + return token + + def __hash__(self): + return hash(tuple(sorted(self.vocabulary.items()))) + + +def reduced_vocabulary_uncached(*args, **kwargs): + return reduced_vocabulary.__wrapped__(*args, **kwargs) + + +class RegexReducedVocabularyBenchmark: + params = [10000, 100000, 1000000] + param_names = ["vocab_size"] + + def setup(self, vocab_size): + ensure_numba_compiled(MockTokenizer([chr(i) for i in range(128)])) + + self.tokenizer = MockTokenizer.from_random_tokens(vocab_size) + + @cache_disabled() + def time_reduced_vocabulary(self, _): + reduced_vocabulary_uncached(self.tokenizer) diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py index 8cfd81ead..a1d337aa8 100644 --- a/outlines/fsm/regex.py +++ b/outlines/fsm/regex.py @@ -28,6 +28,8 @@ from numba.typed.typedobjectutils import _nonoptional from tqdm import tqdm +from outlines.fsm.vocab_trie import VocabTrie + if TYPE_CHECKING: from outlines.models.tokenizer import Tokenizer @@ -664,30 +666,39 @@ def state_scan_tokens( alphabet_anything_value: int, fsm_initial: int, fsm_finals: Set[int], - vocabulary: List[Tuple[str, Sequence[int]]], - vocabulary_transition_keys: List[Sequence[int]], + vocabulary: List[Tuple[Sequence[str], Sequence[int]]], + vocab_trie: VocabTrie, start_state: int, ) -> Set[Tuple[int, int]]: res = set() - for (token, token_ids), token_transition_keys in zip( - vocabulary, vocabulary_transition_keys - ): + # Initialize the stack with tokens having no prefixes + stack = numba.typed.List() + for token_transitions_seq in vocab_trie.get_children(): + stack.append(token_transitions_seq) + + # Process the tokens using the stack + while stack: + token_transitions_seq = stack.pop() state_seq = _walk_fsm( fsm_transitions, fsm_initial, fsm_finals, - token_transition_keys, + token_transitions_seq, start_state, False, ) - if state_seq is not None and len(state_seq) < len(token_transition_keys): + if len(state_seq) < len(token_transitions_seq): continue - for token_id in token_ids: + for token_id in vocab_trie.get_token_ids(token_transitions_seq): res.add((token_id, state_seq[-1])) + # Add successors to the stack + for new_token in vocab_trie.get_children(token_transitions_seq): + stack.append(new_token) + return res @@ -805,7 +816,7 @@ def create_fsm_index_end_to_end( desc="Compiling FSM index for all state transitions", ) - vocabulary_transition_keys = get_vocabulary_transition_keys( + vocabulary_transitions = get_vocabulary_transition_keys( fsm_info.alphabet_symbol_mapping, fsm_info.alphabet_anything_value, vocabulary, @@ -815,10 +826,16 @@ def create_fsm_index_end_to_end( else numba.typed.List.empty_list(numba.types.unicode_type) ), ) + vocab_trie = VocabTrie(vocabulary_transitions, vocabulary) while next_states: start_state = next_states.pop() + pbar.update(1) + + if start_state not in seen: + seen.add(start_state) + token_ids_end_states = state_scan_tokens( fsm_info.transitions, fsm_info.alphabet_symbol_mapping, @@ -826,7 +843,7 @@ def create_fsm_index_end_to_end( fsm_info.initial, fsm_info.finals, vocabulary, - vocabulary_transition_keys, + vocab_trie, start_state, ) @@ -838,10 +855,6 @@ def create_fsm_index_end_to_end( if end_state not in seen: next_states.add(end_state) - if start_state not in seen: - pbar.update(1) - seen.add(start_state) - pbar.close() return states_to_token_subsets @@ -887,23 +900,11 @@ def gpt2_unicode_to_bytes(): return {v: k for k, v in gpt2_bytes_to_unicode().items()} -# TODO: Cannot cache typed collections to disk, yet. See -# https://github.com/numba/numba/issues/4698 -@lru_cache -def reduced_vocabulary( - tokenizer: "Tokenizer", -) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]: - """Create a map from decoded vocabulary tokens to lists of equivalent token ids.""" +def get_normalized_vocab(tokenizer: "Tokenizer") -> Tuple[Dict[int, str], Set[int]]: + norm_vocab = {} empty_token_ids = set() - vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {} for token, token_idx in tokenizer.vocabulary.items(): - if token in tokenizer.special_tokens: - continue - - token_str: Union[str, Tuple[str, ...]] = tokenizer.convert_token_to_string( - token - ) - + token_str = tokenizer.convert_token_to_string(token) if token_str: # invalid utf-8 sequences are replaced with � (\ufffd), but there # might also be tokens specifically for �, ��, ���, etc. @@ -927,22 +928,88 @@ def reduced_vocabulary( ) token_str = "".join(byte_symbol(b) for b in token_bytes) - vocabulary.setdefault(token_str, []).append(token_idx) + norm_vocab[token_idx] = token_str else: empty_token_ids.add(numba.int64(token_idx)) - vocabulary_nb = numba.typed.List.empty_list( - numba.types.Tuple( - ( - nb_unicode_type, - numba.int64[:], - ) - ) + return norm_vocab, empty_token_ids + + +@numba.njit(cache=True, nogil=True) +def to_numba_dict(keys: List[int], values: List[str]): + """ + Pure-python numba dict construction is extremely slow. + This helper accepts equal length key and value arrays, and constructs a numba dict + """ + # Define the key and value types for the Numba dictionary + numba_dict = numba.typed.Dict.empty( + key_type=numba.types.int64, + value_type=numba.types.unicode_type, + ) + + # Fill the Numba dictionary with values from the input lists + for i in range(len(keys)): + numba_dict[keys[i]] = values[i] + + return numba_dict + + +token_id_str_pair = numba.types.Tuple((nb_unicode_type, numba.int64[:])) + + +@numba.njit( + numba.types.ListType(token_id_str_pair)( + numba.types.DictType(numba.int64, nb_unicode_type) + ), + cache=True, + nogil=True, +) +def vocab_dict_to_inverted_vocab_list( + vocab_dict_nb: Dict[int, str] +) -> List[Tuple[str, Sequence[int]]]: + """ + Helper for `reduced_vocabulary` + + Convert + - from `vocab_dict_nb`: Dict[token_id, token_str] + - to `vocab_nb`: List[token_str, token_id[:]] + """ + inverse_vocab_dict = numba.typed.Dict.empty( + key_type=numba.types.unicode_type, value_type=numba.types.int64[:] ) - for token_str, token_ids in vocabulary.items(): - token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) - vocabulary_nb.append((token_str, token_ids_np)) + # Fill the temporary dictionary + for key in vocab_dict_nb: + value = vocab_dict_nb[key] + if value not in inverse_vocab_dict: + inverse_vocab_dict[value] = np.zeros(0, dtype=np.int64) + inverse_vocab_dict[value] = np.append(inverse_vocab_dict[value], key) + + # Transfer data from the temporary dictionary to the final dictionary + vocab_nb = numba.typed.List.empty_list(token_id_str_pair) + + for value in inverse_vocab_dict: + vocab_nb.append((value, inverse_vocab_dict[value])) + + return vocab_nb + + +# TODO: Cannot cache typed collections to disk, yet. See +# https://github.com/numba/numba/issues/4698 +@lru_cache +def reduced_vocabulary( + tokenizer: "Tokenizer", +) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]: + """ + Provided the tokenizer, calculate the + - vocabulary_nb: mapping of (normalized token str -> token_ids[:]) + - empty token ids + """ + norm_vocab, empty_token_ids = get_normalized_vocab(tokenizer) + norm_vocab_dict_nb = to_numba_dict( + np.fromiter(norm_vocab.keys(), dtype=np.int64), list(norm_vocab.values()) + ) + vocabulary_nb = vocab_dict_to_inverted_vocab_list(norm_vocab_dict_nb) return vocabulary_nb, empty_token_ids diff --git a/outlines/fsm/vocab_trie.py b/outlines/fsm/vocab_trie.py new file mode 100644 index 000000000..2b74714ac --- /dev/null +++ b/outlines/fsm/vocab_trie.py @@ -0,0 +1,244 @@ +import operator +from typing import List, Optional, Sequence, Tuple + +import numpy as np +from numba import njit, typed, types +from numba.cpython.hashing import ( + _Py_uhash_t, + _PyHASH_XXPRIME_1, + _PyHASH_XXPRIME_2, + _PyHASH_XXPRIME_5, + _PyHASH_XXROTATE, + process_return, +) +from numba.experimental import jitclass, structref +from numba.extending import overload +from numba.typed import Dict + +########################### +# Dict With Int[:] Key Impl +########################### + + +# Register type +@structref.register +class IntArrayDictType(types.StructRef): + """ + Represents a dictionary using int64[:] as keys, + intended for byte-level FSM representation with int64[:] transition. + """ + + def preprocess_fields(self, fields): + return tuple( + (name, typ.dtype if isinstance(typ, types.TypeRef) else typ) + for name, typ in fields + ) + + +class IntArrayDict(structref.StructRefProxy): + """Python proxy""" + + @property + def wrapped_dict(self): + return IntArrayDict_get_wrapped_dict(self) # noqa: F821 + + +structref.define_proxy(IntArrayDict, IntArrayDictType, ["wrapped_dict"]) + + +@njit +def hash_key(key): + """ + XXH64 Hash for int64[:] keys + adapted from https://github.com/numba/numba/blob/556545/numba/cpython/hashing.py + """ + acc = _PyHASH_XXPRIME_5 + for i in range(key.shape[0]): + x = key[i] + lane = hash(x) + if lane == _Py_uhash_t(-1): + return -1 + acc += lane * _PyHASH_XXPRIME_2 + acc = _PyHASH_XXROTATE(acc) + acc *= _PyHASH_XXPRIME_1 + + acc += key.shape[0] ^ (_PyHASH_XXPRIME_5 ^ _Py_uhash_t(3527539)) + + if acc == _Py_uhash_t(-1): + return process_return(1546275796) + + return process_return(acc) + + +@overload(IntArrayDict) +def custom_int_array_dict_constructor(value_type): + if isinstance(value_type, types.Type): + + def impl(value_type): + wrapped_dictionary = Dict.empty(types.intp, value_type) + return IntArrayDict(wrapped_dictionary) + + return impl + + +@overload(operator.getitem) +def ol_int_array_dict_getitem(inst, key): + if isinstance(inst, IntArrayDictType): + + def impl(inst, key): + return inst.wrapped_dict[hash_key(key)] + + return impl + + +@overload(operator.setitem) +def ol_int_array_dict_setitem(inst, key, value): + if isinstance(inst, IntArrayDictType): + + def impl(inst, key, value): + inst.wrapped_dict[hash_key(key)] = value + + return impl + + +@overload(operator.contains) +def ol_int_array_dict_contains(inst, key): + if isinstance(inst, IntArrayDictType): + + def impl(inst, key): + return hash_key(key) in inst.wrapped_dict + + return impl + + +################# +# Vocab Trie Impl +################# + +nb_int64_array_type = types.int64[:] + +# use intp keys as that is the hash type, +# but the true key type is nb_int64_array_type +IntArrayToIntType = IntArrayDictType( + (("wrapped_dict", types.DictType(types.intp, types.int64)),) +) +IntArrayToIntArrayType = IntArrayDictType( + (("wrapped_dict", types.DictType(types.intp, nb_int64_array_type)),) +) + + +@jitclass( + [ + ("token_to_token_key", IntArrayToIntType), + ("token_key_to_token", types.DictType(types.int64, nb_int64_array_type)), + ( + "token_key_to_child_token_keys", + types.DictType(types.int64, nb_int64_array_type), + ), + ("token_to_token_ids", IntArrayToIntArrayType), + ] +) +class VocabTrie: + """ + VocabTrie: Class for efficient traversal of the vocabulary + + Bidirectional mapping between trie node ID and nb_int64_array_type token + - token_to_token_key: Dict[nb_int64_array_type, int] + - token_key_to_token: Dict[int, nb_int64_array_type] + + Allow retrieval of children in trie + - token_key_to_child_token_keys: Dict[int, int64[:]] + + Allow retrieval of token_ids for a given token + - token_to_token_ids: Dict[nb_int64_array_type, int64[:]] + + Trie structure: + Only members of the vocabulary are included as nodes, no intermediates. + Structured to guarantee that recursive calls to get_children() + will return every token once, only once. + + Given a vocabulary of ["a", "ab", "abc", "ac", "ace", "apple"], + the children of "a" are "ab", "ac", "apple". + "abc" and "ace" are excluded because they have intermediate parents in the vocabulary. + """ + + def __init__( + self, + all_token_transitions: List[Sequence[int]], + vocabulary: List[Tuple[str, Sequence[int]]], + ): + self.token_to_token_key = IntArrayDict( + typed.Dict.empty(types.intp, types.int64) + ) + self.token_key_to_token = typed.Dict.empty( + key_type=types.int64, value_type=nb_int64_array_type + ) + self.token_key_to_child_token_keys = typed.Dict.empty( + key_type=types.int64, value_type=nb_int64_array_type + ) + self.token_to_token_ids = IntArrayDict( + typed.Dict.empty(types.intp, nb_int64_array_type) + ) + + self._insert(all_token_transitions, vocabulary) + + def _insert( + self, + all_token_transitions: List[Sequence[int]], + vocabulary: List[Tuple[str, Sequence[int]]], + ) -> None: + # Initialize an empty array for the root token key to store child token keys + self.token_key_to_child_token_keys[-1] = np.empty((0,), types.int64) + + # It's necessary to insert shorter transition sequences (prefixes) first + sorted_idx_transition_seq = sorted( + enumerate(all_token_transitions), key=lambda x: len(x[1]) + ) + + for idx, token_transitions in sorted_idx_transition_seq: + token_ids = vocabulary[idx][1] + if token_transitions not in self.token_to_token_key: + # create bimapping between token and token_key (tokens trie node key) + self.token_to_token_key[token_transitions] = idx + self.token_key_to_token[idx] = token_transitions + + # find parent token key + parent_token_key = -1 # root token + for i in range(len(token_transitions) - 1, -1, -1): + prefix_token = token_transitions[:i] + if prefix_token in self.token_to_token_key: + parent_token_key = self.token_to_token_key[prefix_token] + break + # map parent token to current token + self.token_key_to_child_token_keys[parent_token_key] = np.append( + self.token_key_to_child_token_keys[parent_token_key], + np.array([idx]), + ) + # map current token to empty list of children + self.token_key_to_child_token_keys[idx] = np.empty((0,), types.int64) + + # set current tokens token ids + self.token_to_token_ids[token_transitions] = token_ids + + else: + # if exists, append to current tokens token ids + self.token_to_token_ids[token_transitions] = np.append( + self.token_to_token_ids[token_transitions], token_ids + ) + + def get_children(self, token_transitions: Optional[Sequence[int]] = None): + """ + Get the token_ids of all children for the given token_id. + If token_id is None, get the root children. + """ + if token_transitions is None: + token_key = -1 + else: + token_key = self.token_to_token_key[token_transitions] + + child_token_keys = self.token_key_to_child_token_keys[token_key] + + return [self.token_key_to_token[token_key] for token_key in child_token_keys] + + def get_token_ids(self, token): + return self.token_to_token_ids[token] diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py index fa72ad0dc..930afba77 100644 --- a/tests/fsm/test_regex.py +++ b/tests/fsm/test_regex.py @@ -18,6 +18,7 @@ reduced_vocabulary, walk_fsm, ) +from outlines.fsm.vocab_trie import VocabTrie from outlines.integrations.utils import adapt_tokenizer from outlines.models.transformers import TransformerTokenizer @@ -717,3 +718,51 @@ def test_reduced_vocabulary_with_rare_tokens(rare_token): tokenizer = adapt_tokenizer(tokenizer=tokenizer) tokenizer.vocabulary[rare_token] = max(tokenizer.vocabulary.values()) + 1 reduced_vocabulary(tokenizer) + + +def test_vocab_trie_ordering(): + class MockTokenizer: + vocabulary = {"": 0, "a": 1, "abc": 2, "def": 3, "abcd": 4, "abce": 5} + special_tokens = {""} + eos_token = "" + eos_token_id = 4 + + def convert_token_to_string(self, token): + return token + + tokenizer = MockTokenizer() + + pattern = r"abc[de]fghi" + regex_pattern = interegular.parse_pattern(pattern) + interegular_fsm = regex_pattern.to_fsm().reduce() + regex_fsm, _ = make_deterministic_fsm(interegular_fsm) + vocabulary, _ = reduced_vocabulary(tokenizer) + token_trans_keys = get_vocabulary_transition_keys( + regex_fsm.fsm_info.alphabet_symbol_mapping, + regex_fsm.fsm_info.alphabet_anything_value, + vocabulary, + numba.typed.List.empty_list(numba.types.unicode_type), + ) + + vocab_trie = VocabTrie(token_trans_keys, vocabulary) + + def get_children(parent_id=None): + trans_key = token_trans_keys[parent_id] if parent_id is not None else None + res = [ + set(vocab_trie.get_token_ids(child)) + for child in vocab_trie.get_children(trans_key) + ] + if not res: + return set() + return set.union(*res) + + # initial children - tokens with no predecessor tokens: ["eos", "a", "def"] + assert get_children() == {0, 1, 3} + # children of "a" (1) are ["abc"] + assert get_children(1) == {2} + # children of "abc" (2) are ["abcd", "abce"] + assert get_children(2) == {4, 5} + # no children for these + assert not get_children(3) + assert not get_children(4) + assert not get_children(5)