Fix null byte \x00 issue by switching to numba.types.unicode_type #904

Closed

Conversation

M0gician
Contributor

Fixes #833

Changes

Test

I've used this patch for about a month. So far everything has been working well without any problems.

@lapp0
Contributor

lapp0 commented May 22, 2024

@M0gician please let me know if you would like any help on the tests.

Edit:

Problem

Looking further into it, the core issue is that hex representations in the vocabulary are now an iterable of length 2:
'9F'

whereas before it was an iterable of length 1:
array(['9F'], dtype='<U2')

So when we compare against alphabet_symbol_mapping, we check for a 9 key and an F key separately, but only the key 9F exists, resulting in the inappropriate exclusion of this symbol.
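To illustrate, here is a minimal sketch of the lookup mismatch, using a plain dict in place of the real alphabet_symbol_mapping (names and values are illustrative, not the actual outlines internals):

```python
# Hypothetical sketch of the failing lookup; the mapping is illustrative.
alphabet_symbol_mapping = {"9F": 3}  # only the two-character hex key exists

symbol = "9F"  # now an iterable of length 2

# Iterating the symbol looks up "9" and "F" separately -- neither key exists.
print([c in alphabet_symbol_mapping for c in symbol])  # [False, False]

# The whole symbol is the key we actually want to look up.
print(symbol in alphabet_symbol_mapping)  # True
```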

Looking into a clean way to fix this.

Solution Investigation

Perhaps we can create a format where the byte (9F) is represented as a null-prefixed custom encoding, e.g.

>>> bytearray.fromhex("009F")
bytearray(b'\x00\x9f')

Then we can treat null-prefixed characters as bytes, and everything else as characters.
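A minimal sketch of that convention (the helper name is hypothetical and not part of this PR):

```python
# Hypothetical helper illustrating the null-prefix convention.
def encode_symbol(symbol: str) -> str:
    """Prefix a two-character hex byte with a null character; pass
    ordinary characters through unchanged."""
    if len(symbol) == 2:           # e.g. "9F", the hex form of a raw byte
        return "\x00" + symbol
    return symbol                  # e.g. "😇", a normal character

print(repr(encode_symbol("9F")))   # '\x009F'
print(repr(encode_symbol("😇")))   # '😇'
```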

No major tokenizer uses tokens containing null-byte characters (they all follow the UTF-8 standard), so there wouldn't be any collisions.

Script:

from collections import Counter
from transformers import AutoTokenizer


def count_first_byte_occurrences(strings):
    # Despite the name, this counts every UTF-8 byte of every character,
    # not just the first byte of each character.
    byte_counts = Counter()
    for s in strings:
        for char in s:
            utf8_bytes = char.encode('utf-8')
            assert utf8_bytes.decode() == char  # round-trips cleanly
            for byte in utf8_bytes:
                byte_counts[byte] += 1
    return byte_counts


tokenizer_uris = [
    "Qwen/Qwen1.5-0.5B-Chat",
    "microsoft/Phi-3-vision-128k-instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "NousResearch/Hermes-2-Pro-Llama-3-8B",
    "microsoft/phi-2",
]


for tokenizer_uri in tokenizer_uris:
    print(tokenizer_uri)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_uri)
    token_strs = list(tokenizer.get_vocab())
    first_byte_occurrences = count_first_byte_occurrences(token_strs)
    assert first_byte_occurrences[0] == 0
    print("no null byte present in any characters:)")

    print("most common bytes:", first_byte_occurrences.most_common(4))
    print("number of bytes present:", len(first_byte_occurrences), "/ 256")
    print()

Output:

Qwen/Qwen1.5-0.5B-Chat
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
no null byte present in any characters:)
most common bytes: [(196, 156537), (195, 125517), (194, 100880), (160, 70087)]
number of bytes present: 162 / 256

microsoft/Phi-3-vision-128k-instruct
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
no null byte present in any characters:)
most common bytes: [(129, 17192), (226, 16768), (150, 16761), (101, 14787)]
number of bytes present: 203 / 256

mistralai/Mistral-7B-Instruct-v0.2
no null byte present in any characters:)
most common bytes: [(129, 16308), (226, 16091), (150, 16028), (101, 14940)]
number of bytes present: 231 / 256

NousResearch/Hermes-2-Pro-Llama-3-8B
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
no null byte present in any characters:)
most common bytes: [(196, 106088), (195, 72004), (160, 69912), (101, 60476)]
number of bytes present: 162 / 256

microsoft/phi-2
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
no null byte present in any characters:)
most common bytes: [(196, 34126), (160, 33158), (101, 30206), (105, 21804)]
number of bytes present: 164 / 256

@M0gician
Contributor Author

@M0gician please let me know if you would like any help on the tests.

I am currently busy this week, so my bandwidth is limited.

I tried to step into the exceptions but did not find anything decisive yet. If you have some time to check the tests, that would definitely help.

@lapp0
Contributor

lapp0 commented May 26, 2024

Could you please review lapp0@8e168c6?

This includes a new token / symbol representation format introduced in my comment above.

Before, we would have an array containing both utf-8 symbols and hex codes, and we would distinguish them by how many characters the symbol has: one character implies a utf-8 character, two implies the hex representation of a byte.

  • main: ["😇", "9", "F", "9F"]

Because we use unicode_type instead of a List[UnicodeCharSeq(2)], the example sequence would be represented as

  • this PR: ["😇", "9", "F", "9F"] -> '😇9F9F'

This demonstrates a problem: we have no idea whether consecutive hex characters represent a byte or two separate utf-8 characters. To resolve this, we prefix a hex byte with a null byte.

  • my branch part 1: ["😇", "9", "F", "9F"] -> '😇9F\x009F'.

This allows us to avoid the issues with numba U2-arrays by representing as a unicode_type, while also distinguishing hex bytes from character pairs.
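As a rough sketch (illustrative only; the real code is numba-compiled and structured differently), the concatenated string can be split back into symbols by treating \x00 as a marker that the next two characters form a hex byte:

```python
# Illustrative decoder for the null-prefixed format.
def iter_symbols(encoded: str):
    i = 0
    while i < len(encoded):
        if encoded[i] == "\x00":        # the next two characters are a hex byte
            yield encoded[i + 1:i + 3]
            i += 3
        else:                           # a plain utf-8 character
            yield encoded[i]
            i += 1

print(list(iter_symbols("😇9F\x009F")))  # ['😇', '9', 'F', '9F']
```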

Processing tokens symbol-by-symbol is inefficient, especially when applying conditional handling within _walk_fsm for \x00-prefixed characters. A further adjustment precomputes the Sequence[int] of transition keys for a given token, then calls _walk_fsm with each token's transition key sequence.

This final change improved runtime from ~210% of main to ~80%.
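A hedged sketch of the precomputation step (function and parameter names are hypothetical; the actual PR operates on numba-typed structures):

```python
from typing import Dict, List, Sequence

def token_to_transition_keys(
    token_symbols: Sequence[str],
    alphabet_symbol_mapping: Dict[str, int],
    anything_else_key: int,
) -> List[int]:
    # Resolve each symbol to its FSM transition key once, up front, so the
    # walk over the FSM only handles integers.
    return [
        alphabet_symbol_mapping.get(symbol, anything_else_key)
        for symbol in token_symbols
    ]

# _walk_fsm can then consume the precomputed List[int] rather than
# re-resolving \x00-prefixed symbols at every step.
```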

TODO:

  • fix llamacpp test failures (ValueError: ctypes objects containing pointers cannot be pickled)
  • smoke testing, additional unit tests
  • determine cause of 2x performance degradation Fix null byte lapp0/outlines#21 (comment)
    • I incorporated an index which converts tokens into a sequence of transition keys and now it's slightly faster than main!

Successfully merging this pull request may close these issues.

KeyError in BetterFSM::FSMInfo when input FSM alphabet contains UTF-8 characters that ends with \xb8\x80