Skip to content

Conversation

@strutive07
Copy link
Contributor

@strutive07 strutive07 commented Oct 15, 2023

Thank you for providing such a great repo.

The huggingface tokenizer works by calling AutoTokenizer.from_pretrained based on vocab.json and tokenizer.json. For llama2, it supports tokenizer.model for sentencepiece tokenizer, so llama.cpp supports loading vocab with tokenizer.model with sentencepiece tokenizer.

In multi lingual situation, there are cases of models that add additional vocab to tokenizer to train PLM further. In this case, we cannot support tokenizer.model, only AutoTokenizer.from_pretrained, which is based on vocab.json and tokenizer.json.

Therefore, there may be cases where there is no tokenizer.model file, so i propose to add a vocab class(in convert.py), 'HFVocab', to support this in this PR.

Related issue

I tested this with the llama2 model kfkas/Llama-2-en-7b-Chat, where Korean vocab is added to the existing llama2 vocab. In my test, I tried to see if it would generate based on the added Korean vocab. You can see that the results are generated well.

Convert

python convert.py ../llama-2-ko-7b-chat --vocabtype hf

Prompt

./main -m ../llama-2-ko-7b-chat/ggml-model-f16.gguf \
    --prompt "아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n삼원색에 대해 알려줘\n\n### 응답:"

To make it easier for maintainers to check their PRs, we've added a translation of the prompt and results below.

Below is a command that describes the task. Write a response that appropriately completes the request.\n\n### Command:\nTell me about the trichromatic colors\n\n### Response:

Original Result

아래는 작업을 설명하는 명령어입니다. 요청을 적절히 완료하는 응답을 작성하세요.\n\n### 명령어:\n삼원색에 대해 알려줘\n\n### 응답:현재 제공한 정보로는 삼원색을 정확하게 정의할 수 없습니다. 그러나 색상의 3가지 기본 요소는 빨간색, 파란색, 초록색이므로 이 세 가지 원소를 가지고 있는 색이 삼원색일 가능성이 높습니다.삼원색이 가장 일반적으로 사용되는 이유는 색이 더 복잡해질수록 다양한 색상이 혼합되어 더 많은 파장을 방출하기 때문입니다. 즉, 더 많은 빛의 스펙트럼이 섞이면 특정 색상에 대한 감각이 더 혼란스러워지고 색상을 정확하게 인식하는 것이 어려워집니다.예를 들어, 빨간색, 파란색 및 초록색을 포함하는 삼원색 계열인 오렌지, 블루 또는 그린의 조합으로 그린이나 오렌지처럼 복잡한 색상의 조합을 만들 수 있습니다. 이 경우 각각의 원소가 모두 혼합되어 새로운 색상이 만들어지기 때문에 더 많은 색상을 혼합하더라도 항상 특정 색상에 대한 감각을 유지할 수 있습니다.또한, 삼원색은 시각적 혼란이나 감각 상실 없이 다양한 색상을 혼합하여 다른 색상을 만들 때 중요한 역할을 할 수 있습니다. 예를 들어, 파란색을 혼합하면 보라색으로, 빨간색을 혼합하면 갈색으로, 초록색을 혼합하면 노랑색으로 변합니다. 이러한 색상은 모두 기본 삼원색을 기반으로 하고 있으므로 서로 다른 색상으로 인식할 수 있습니다.또한 삼원색은 많은 미술 작품과 예술가들이 사용하는 주요 색상의 조합이기 때문에 예술 분야에서 사용할 수 있는 유용한 색상 조합을 제공할 수 있다는 점에서도 중요합니다.

Result (translated for maintainers)

Below is a command that describes the task. Compose a response that appropriately completes the request.\n\n### Command:\nTell me about the trichromatic colors\n\n### Response:With the information you have provided, we cannot accurately define trichromatic colors. However, the three basic elements of color are red, blue, and green, so any color that has these three elements is likely a trichromatic color.The reason trichromatic colors are most commonly used is that as colors become more complex, they emit more wavelengths as different colors are mixed together. This means that when more spectrums of light are mixed together, the sensation of a particular color becomes more confusing and it becomes difficult to accurately perceive the color.For example, a combination of orange, blue, or green, which is a trichromatic color family that includes red, blue, and green, can create a combination of colors as complex as green or orange. In this case, you can always maintain a sense of a particular color, even as you mix more colors, because each element is mixed together to create a new color.Additionally, trichromats can play an important role in mixing different colors to create other colors without causing visual confusion or loss of sensation. For example, mixing blue makes purple, mixing red makes brown, and mixing green makes yellow. These colors are all based on the primary trichromatic colors, so they can be perceived as different colors.It's also important to note that the trichromatic colors can provide useful color combinations that can be used in the arts, as they are the primary color combinations used by many works of art and artists.

@TheBloke
Copy link
Contributor

TheBloke commented Oct 15, 2023

Excellent, thank you!

FYI, it is possible for models with extra vocab to support tokenizer.model - you can make your own new SPM vocab, like TigerResearch have done - they have a Llama 2 based tokenizer.model with 60515 tokens: https://huggingface.co/TigerResearch/tigerbot-70b-chat-v2/blob/main/tokenizer.model

(see how much bigger it is than the default Llama tokenizer.model - nearly double the size)

But that is an extra step required and not all model creators may do this, so I fully support convert.py being able to load tokenizer.json

I have recently encountered a few models that included tokenizer.json but no tokenizer.model. For those models I have to copy in the original tokenizer.model from the base model (Llama, Llama 2 or Mistral). That works fine as long as they've not added vocab, but as you say, some models are now adding many new tokens to the vocab. So I have been fearing that in future there will be a model released that adds new vocab, but does not update tokenizer.model with it, meaning I could not create a GGUF for it without making my own SPM tokenizer.model.

So I welcome this change, thank you - it will be great to have an option to use the HF tokenizer.json instead of tokenizer.model.

Perhaps you could add logic that looks for tokenizer.json if tokenizer.model is not found? So tokenizer.model is tried first, but if it cannot be found, tokenizer.json is then looked for?

@TheBloke
Copy link
Contributor

TheBloke commented Oct 15, 2023

EDIT: please ignore the below, I hadn't seen that this is already fixed by another PR: #3585

Also there is currently an issue with convert.py caused by the recent change to HF Transformers - added_tokens.json now includes the standard special tokens, like this: https://huggingface.co/jondurbin/airoboros-m-7b-3.1/blob/main/added_tokens.json :

{
  "</s>": 2,
  "<s>": 1,
  "<unk>": 0
}

The problem is that the convert.py code for checking added tokens does not like it when it finds token IDs that are < the vocab_size. It will just error out and not try to make the FP16. So I have to handle this in my own code, removing any tokens with an ID < vocab_size from added_tokens.json, and re-saving it before convert.py is called.

I know this is not directly related to your change, but while you are working on convert.py and vocab, would it be possible to also address this issue?

@strutive07
Copy link
Contributor Author

strutive07 commented Oct 15, 2023

Is it okay to understand that the problem is caused by duplicate vocab in the original vocab and added_tokens.json?
In BpeVocab, SentencePieceVocab, etc., vocab_size is simply the sum of the sizes of the two vocab list, but if i deduplicate the (vocab, vocab_id) pair and get the size, it will be solved, but am I understanding correctly?

If I understand correctly, I think it would be good to handle the issue in the next PR after this PR merge.
The solution seems to be simple, so I think I can handle it after the HFVocab implementation merge.

@strutive07
Copy link
Contributor Author

Perhaps you could add logic that looks for tokenizer.json if tokenizer.model is not found? So tokenizer.model is tried first, but if it cannot be found, tokenizer.json is then looked for?

I tried converting the two models below with the same code back to back to see if it works as an inference. The models are successfully converted and generate well.

meta-llama/Llama-2-13b-chat-hf

python convert.py /some_dir/llama-2-13b-chat-hf


Loading model file /some_dir/llama-2-13b-chat-hf/model-00001-of-00003.safetensors
Loading model file /some_dir/llama-2-13b-chat-hf/model-00001-of-00003.safetensors
Loading model file /some_dir/llama-2-13b-chat-hf/model-00002-of-00003.safetensors
Loading model file /some_dir/llama-2-13b-chat-hf/model-00003-of-00003.safetensors
params = Params(n_vocab=32000, n_embd=5120, n_layer=40, n_ctx=4096, n_ff=13824, n_head=40, n_head_kv=40, f_norm_eps=1e-05, f_rope_freq_base=None, f_rope_scale=None, ftype=None, path_model=PosixPath('/some_dir/llama-2-13b-chat-hf'))
Loading vocab file '/some_dir/llama-2-13b-chat-hf/tokenizer.model', type 'spm'

kfkas/Llama-2-ko-7b-Chat

python convert.py /some_dir/llama-2-ko-7b-chat/


Loading model file /some_dir/llama-2-ko-7b-chat/model-00001-of-00002.safetensors
Loading model file /some_dir/llama-2-ko-7b-chat/model-00001-of-00002.safetensors
Loading model file /some_dir/llama-2-ko-7b-chat/model-00002-of-00002.safetensors
params = Params(n_vocab=46336, n_embd=4096, n_layer=32, n_ctx=2048, n_ff=11008, n_head=32, n_head_kv=32, f_norm_eps=1e-05, f_rope_freq_base=None, f_rope_scale=None, ftype=None, path_model=PosixPath('/some_dir/llama-2-ko-7b-chat'))
Loading vocab file '/some_dir/llama-2-ko-7b-chat', type 'hf'

@TheBloke
Copy link
Contributor

Is it okay to understand that the problem is caused by duplicate vocab in the original vocab and added_tokens.json? In BpeVocab, SentencePieceVocab, etc., vocab_size is simply the sum of the sizes of the two vocab list, but if i deduplicate the (vocab, vocab_id) pair and get the size, it will be solved, but am I understanding correctly?

If I understand correctly, I think it would be good to handle the issue in the next PR after this PR merge. The solution seems to be simple, so I think I can handle it after the HFVocab implementation merge.

Actually please don't worry. I didn't realise there was already another PR to address this issue: #3585

@teleprint-me
Copy link
Contributor

teleprint-me commented Oct 15, 2023

I'm working on an experimental solution to this problem because I keep running into it.

I'm confident there's a way to do this without creating dependencies.

We technically do not need to rely on huggingface and I can actually see reliance on it becoming an issue of its own.

I'm in the middle of creating some utilities to dump the necessary data to mapped data structures for reuse; think of it like a programmitic hexdump, but for models.

I already created one for safetensors. My next goal is to handle it for torch models. Then for huggingface models.

If my intuition is correct, then we shouldn't really need huggingface at all which would actually be a really good thing.

It would also be flexible enough to build on top of and extend as needed.

It would create a gateway towards unifying and streamlining all model conversions as well, which is my end goal.

@ggerganov
Copy link
Member

ggerganov commented Oct 17, 2023

@strutive07 Thank you for the detailed description and the contribution

Is there a way to make the transformers import executed only when parsing an HF vocab?
The goal is to avoid the dependency on this package in cases where we don't need it.

@teleprint-me Can't tell yet if your approach would be better. AFAICT you avoid the dependency, but there is a lot of custom format handling that needs to happen that would increase significantly the amount of python code. So not sure.

Curious what other people think - should we add transformers to requirements.txt?

@cebtenzzre
Copy link
Collaborator

Curious what other people think - should we add transformers to requirements.txt?

I am fine with importing transformers in this special case (vocabtype == "hf") - which is definitely possible. I don't want us to have to maintain too much extra code to duplicate the behavior of transformers in edge cases.

@teleprint-me
Copy link
Contributor

@ggerganov

After spending the last few days on it, I agree. It's better to just include safetensors, transformers, and torch. convert.py handles a lot of the custom code and I now understand why it is the way it is. I was hoping there was a middle ground, but due to underlying complexity of torch, I readjusted my position. It would both reduce the amount of python code necessary and simplify the implementation as a result.

Doing so would allow a simpler implementation for a factory pattern to handle the variety of models and conversions required. I'm still researching and studying the source code.

@strutive07
Copy link
Contributor Author

@ggerganov

Is there a way to make the transformers import executed only when parsing an HF vocab?
The goal is to avoid the dependency on this package in cases where we don't need it.

I move transformers import code into Class __init__ method.

Curious what other people think - should we add transformers to requirements.txt?

I don't think we need to add it to requirements.txt. I looked at other open source code that was similar to this situation. In huggingface/peft repo's bitsandbytes import case (import code) (dependency code), it doesn't add bitsandbytes to dependency code and only checks it when it's actually needed, providing import errors and installation code.

We don't need to add it to our dependencies in a similar way, we just need to try to import it when it's needed, and if it's not, provide an import error and install code.

I wrote some code based on the above thoughts and committed to this PR. (Code)

Copy link
Member

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, I think this OK to merge.
Will wait for a bit to see if there are any concerns from others and we can merge

@teleprint-me
Copy link
Contributor

If it's alright, I'd like to play devils advocate.

Why omit it from the dependencies, create an exception that requires the dependency, only to have the user install it manually in the end?

In most cases, if you're using the python scripts, you'll need the deps. So, why not include them to begin with? It feels like a "false start" simply because I'll need to install it anyways.

The other conversion scripts are in the same boat as well.

I'm simply stating this because I think it's worth genuine consideration.

@cebtenzzre
Copy link
Collaborator

Why omit it from the dependencies, create an exception that requires the dependency, only to have the user install it manually in the end?

I believe most users will not run into this issue. I have never personally come across a SPM-based model from HF without a tokenizer.model. Putting it in requirements.txt implies that it is a required dependency, not an optional dependency.

@teleprint-me
Copy link
Contributor

Why omit it from the dependencies, create an exception that requires the dependency, only to have the user install it manually in the end?

I have never personally come across a SPM-based model from HF without a tokenizer.model.

Why? What's the difference? I thought the point was to support HF tokenizers which are missing the model and have the vocab instead?

@cebtenzzre
Copy link
Collaborator

Why? What's the difference? I thought the point was to support HF tokenizers which are missing the model and have the vocab instead?

Yes, that's what this PR is for, I'm just saying that it solves a problem that I believe very few people have. If the tokenizer.model is missing, it was manually removed and replaced with a vocab in the tokenizer.json, and that is not the only way to do it: #3633 (comment)

@teleprint-me
Copy link
Contributor

teleprint-me commented Oct 18, 2023

@cebtenzzre

So I have been fearing that in future there will be a model released that adds new vocab, but does not update tokenizer.model with it, meaning I could not create a GGUF for it without making my own SPM tokenizer.model.

So I welcome this change, thank you - it will be great to have an option to use the HF tokenizer.json instead of tokenizer.model.

You did notice that I liked the comment, right? Also, I don't see the relevance or how it answers my question(s).

If the tokenizer.model is missing, it was manually removed and replaced with a vocab in the tokenizer.json, and that is not the only way to do it: comment#3633

It may not necessarily have been removed and may have not previously existed. Regardless, something is still required to fill the void.

I find these counter-points to be counterproductive and I fail to see how they should be dissuading.

None of this changes the fact that the conversion scripts are becoming increasingly reliant on these 3rd party libraries simply because shimming/replicating them is proving to be cumbersome; The simpler approach is to include them.

@TheBloke
Copy link
Contributor

TheBloke commented Nov 29, 2023

@teleprint-me this is quite common. Model creators often bump the tensor sizes without an according vocab increase. In this case it was likely done to ensure the tensors are a multiple of 256, enabling tensor parellelism and GGUF k-quants.

This is no issue with Transformers, but has been a problem in llama.cpp since the beginning. In the main branch, we have to deal with this externally, by adding dummy tokens to added_tokens.json.

In this PR, a pad vocab feature first implemented by Kerfuffle has been added to allow convert.py to fix it for you.

So, just add --padvocab to your converrt.py command line and it will be fixed.

I believe there's discussions going elsewhere in this Gitub regarding a refactor of vocab handling which might eventually make this vocab padding unnecessary.

@TheBloke
Copy link
Contributor

TheBloke commented Dec 4, 2023

Hi @strutive07

I've had a few models recently that lacked tokenizer.model and so I did with this PR. I've had reports on some that they are not capable of encoding \n, and instead they write <0x0A.

There's an increasing number of models not releasing tokenizer.model so this PR is proving extremely useful, but it seems that in some situations a fix to the output is needed and I'm not sure if it's the model's fault, or a bug in this PR?

Here is an example model commit which has this issue: https://huggingface.co/argilla/notus-7b-v1/tree/83af961763096872675626d26ae5b735c61caa49

I've linked to a revision as they've since uploaded tokenizer.model, changed tokenizer_config.json, and add added_tokens.json, so it can now be converted with main convert.py.

The linked revision is how I first converted the model, and is the one that outputs <0x0A> in place of \n.

Here is a demonstration of the problem, using the above model revision:

Convert:

 ᐅ python3 ./convert-fast.py /workspace/process/test-notus-7b --outtype f16 --outfile /workspace/process/notus-7b-v1.issue.fp16.gguf
Loading model file /workspace/process/test-notus-7b/model-00001-of-00003.safetensors
Loading model file /workspace/process/test-notus-7b/model-00001-of-00003.safetensors
Loading model file /workspace/process/test-notus-7b/model-00002-of-00003.safetensors
Loading model file /workspace/process/test-notus-7b/model-00003-of-00003.safetensors
params = Params(n_vocab=32000, n_embd=4096, n_layer=32, n_ctx=32768, n_ff=14336, n_head=32, n_head_kv=8, f_norm_eps=1e-05, rope_scaling_type=None, f_rope_freq_base=10000.0, f_rope_scale=None, n_orig_ctx=None, rope_finetuned=None, ftype=<GGMLFileType.MostlyF16: 1>, path_model=PosixPath('/workspace/process/test-notus-7b'))
Vocab info: <VocabLoader with 32000 base tokens and 0 added tokens>
Special vocab info: <SpecialVocab with 58980 merges, special tokens {'bos': 1, 'eos': 2, 'unk': 0, 'pad': 2}, add special tokens unset>

... output trimmed ...

Writing /workspace/process/notus-7b-v1.issue.fp16.gguf, format 1
gguf: This GGUF file is for Little Endian only
gguf: Adding 58980 merge(s).
gguf: Setting special token type bos to 1
gguf: Setting special token type eos to 2
gguf: Setting special token type unk to 0
gguf: Setting special token type pad to 2
gguf: Setting chat_template to {% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<|user|>
' + message['content'] + eos_token }}
{% elif message['role'] == 'system' %}
{{ '<|system|>
' + message['content'] + eos_token }}
{% elif message['role'] == 'assistant' %}
{{ '<|assistant|>
'  + message['content'] + eos_token }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ '<|assistant|>' }}
{% endif %}
{% endfor %}
...
[291/291] Writing tensor output_norm.weight                     | size   4096           | type F32  | T+  19
Wrote /workspace/process/notus-7b-v1.issue.fp16.gguf

Test inference using FP16

ᐅ CUDA_VISIBLE_DEVICES=1 ./main -m /workspace/process/notus-7b-v1.issue.fp16.gguf -t 1  -ngl 100  -p "<|system|>
You are a story writing assistant</s>
<|user|>
Write a story about llamas</s>
<|assistant|>"

...
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: PAD token        = 2 '</s>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.11 MiB
...
<|system|><0x0A>You are a story writing assistant<0x0A><|user|><0x0A>Write a story about llamas<0x0A><|assistant|><0x0A>Deep in the Andean mountains, there lived a herd of llamas. They roamed free on the grassy meadows and basked in the warm sun. The llamas were known for their gentle nature, sturdy build, and long memories. They had learned to navigate through the treacherous terrain, avoiding steep cliffs and rocky paths. Their thick woolly coats kept them warm during chilly nights, and their keen senses helped them locate food and water sources.<0x0A><0x0A>One .....

Is it an issue with their original config, or is it this PR?

@strutive07
Copy link
Contributor Author

@TheBloke I compared the master branch and vocab using tokenizer.model and they are exactly the same. It doesn't seem to be an issue with the new convert.py.

from convert import VocabLoader, Params, load_some_model, check_vocab_size
from convert_master_branch import load_vocab as load_vocab_master_branch

def is_same_vocab(v1, v2):
    v1_set = set()
    v2_set = set()

    for text, score, toktype in v1.all_tokens():
        v1_set.add((text, score, toktype))
        
    for text, score, toktype in v2.all_tokens():
        v2_set.add((text, score, toktype))

    if v1_set != v2_set:
        v1_dict = {}
        
        for text, score, toktype in v1.all_tokens():
            v1_dict[text] = (score, toktype)
            
        for text, score, toktype in v2.all_tokens():
            if text not in v1_dict:
                print(text.decode('utf-8'), 'not in ', text, score, toktype)
            else:
                if (score, toktype) != v1_dict[text]:
                    print(text.decode('utf-8'), 'diff', (score, toktype), v1_dict[text])
    return v1_set == v2_set


def check(model_path):
    params = Params.load(load_some_model(model_path))
    vocab_new = VocabLoader(params, model_path)
    vocab_orig = load_vocab_master_branch(model_path, "spm")
    print(model_path)
    print(is_same_vocab(vocab_new, vocab_orig))

from pathlib import Path
model_path = Path('notus-7b-v1')
check(model_path)
Loading model file notus-7b-v1/model-00001-of-00003.safetensors
Loading model file notus-7b-v1/model-00001-of-00003.safetensors
Loading model file notus-7b-v1/model-00002-of-00003.safetensors
Loading model file notus-7b-v1/model-00003-of-00003.safetensors

32000 32000
Loading vocab file 'notus-7b-v1/tokenizer.model', type 'spm'
notus-7b-v1
True

@TheBloke
Copy link
Contributor

TheBloke commented Dec 10, 2023

Hi @strutive07 thanks very much for checking. It's great to know that tokenizer.model is the same.

But there is still an issue when tokenizer.json is used, and this is where it would be good to know whether there is a PR bug, or some issue in the base model.

Here is how this can be confirmed:

  1. Download the current repo to local
pip3 install --upgrade 'huggingface-hub>=0.18'     # if not installed
huggingface-cli download argilla/notus-7b-v1 --local-dir test-notus  --local-dir-use-symlinks False
  1. Delete tokenizer.model to force tokenizer.json to be used
rm test-notus/tokenizer.model
  1. Run convert.py from this PR
python3 ./convert-fast.py /workspace/test-notus --outtype f16 --outfile /workspace/test-notus.fp16.gguf
  1. Test inference and note that \n is not present in the output:
CUDA_VISIBLE_DEVICES=1 ./main -m /workspace/test-notus.fp16.gguf -t 1  -ngl 100  -p "<|system|>
You are a story writing assistant</s>
<|user|>
Write a story about llamas</s>
<|assistant|>"

...

 <|system|><0x0A>You are a story writing assistant<0x0A><|user|><0x0A>Write a story about llamas<0x0A><|assistant|><0x0A>Deep in the Andes mountains of South America, lived a herd of llamas. They were known for their long ears, sturdy legs, and fluffy coats that kept them warm during the harsh winters. The herd consisted of mothers, fathers, babies, and even a few mischievous young ones who loved to play pranks on each other.<0x0A><0x0A>One day, while grazing in a ....

So it would be good to know whether this repo has some problem with their tokenizer.json that is not present in the tokenizer.model? Or if the PR is causing the problem?

Because there are an increasing number of models not shipping tokenizer.model, so the issue could occur again and if it does, I currently don't know where it comes from.

Thanks again

@TheBloke
Copy link
Contributor

Hi @strutive07

Here's another example: model https://huggingface.co/rwitz/go-bruins-v2 does not include tokenizer.model and when I make it with this PR, it has the <0x0A issue.

So speaking more generally: I think all Mistral models have this issue when converted from tokenizer.json instead of tokenizer.model.

In most cases I can manually resolve this by copying in the tokenizer.model from Mistral.

However:

  1. It means that my code will by default produce bad GGUFs for models that don't include tokenizer.model, unless I put in a special check for "If model is Mistral 7B and tokenizer.model is missing, copy in the default tokenizer.model"
  2. But then a few models have customised tokenizers, with increased numbers of tokens. And they don't include tokenizer.model because they have expanded tokenizer.json. So if I copied in default tokenizer.model for those models, that would be wrong
  3. I can also check the vocab size and not copy it in if vocab_size is much bigger than 32000
  4. But then it also raises the question: for any Mistral models with customised tokenizer and no customised tokenizer.model (like some Chinese models do) will they also have this <0x0A> issue? I think they will.

Note that using HF AutoTokenizer with this model shows no issue - newline decodes to \n OK.

Could it be some difference between how Mistral encodes newline vs Llama?

But still, I think the PR needs to be able to decode correctly otherwise it can't produce valid Mistral models from tokenizer.json.

@TheBloke
Copy link
Contributor

TheBloke commented Dec 10, 2023

OK this is interesting. I just learned that ExLlamav2's tokenizer.json conversion also has the <0x0A> issue when an ExLlama model is made from Mistral tokenizer.json.

So this does sound like it's not the fault of this PR.

But I don't understand because when I do it with AutoTokenizer, it's fine:

# This folder has no `tokenizer.model` - I deleted it
In [4]: tokenizer = AutoTokenizer.from_pretrained("/workspace/test-notus", use_fast=True, legacy=False, trust_remote_code=True)

In [7]: tokenizer.decode(tokenizer("""Hello how are you?
   ...: I am fine thanks
   ...: how are you?""").input_ids)
Out[7]: '<s> Hello how are you?\nI am fine thanks\nhow are you?'

In [8]: tokenizer.decode(13)
Out[6]: '\n'

@ArthurZucker
Copy link

ArthurZucker commented Dec 10, 2023

Mmm the <0x0A> is a bytefallback of the \n character, which is correctly decoded by sentencepiece and the AutoTokenizer as \n, which means there might be an issue

@ggerganov
Copy link
Member

Do you guys think the \n / <0x0A> could be addressed from master after we merge the PR?
I think it would be better to merge it in order to avoid stalling for too long and things becoming obsolete. Hopefully the newline issue could be resolved easily and might even get more attention from devs if it is present from the main line?

@ggerganov ggerganov requested a review from cebtenzzre December 12, 2023 09:51
@strutive07
Copy link
Contributor Author

strutive07 commented Dec 13, 2023

@ggerganov cc. @TheBloke
I think it would be better to do the '\n / <0x0A> new line issue' in a new PR after this PR merge. It would be great to talk about it with others, but this PR has gotten too long for that. I still haven't figured out why it's behaving this way, even though the vocab is the same. I will be checking this issue most urgently.

@TheBloke
Copy link
Contributor

TheBloke commented Dec 13, 2023

Sounds good to me. It seems that all models that convert.py can currently handle will be handled identically with this PR. The issue only occurs with models without tokenizer.model, which current convert.py can't handle anyway

So there is no regression, just an issue affecting some of the extra models that convert.py will now be able to handle.

And from my perspective it'd be great for this to be merged so I get tokenizer.json support and --padvocab in main. Even if Mistral models have issues, there are still plenty of Llama-based models that lack tokenizer.model, and don't have the \n issue.

Thanks for all your long work on this @strutive07 !

@ggerganov
Copy link
Member

Alright sounds good. Waiting for @cebtenzzre to approve and we should merge

@teleprint-me
Copy link
Contributor

I made my own version of the convert.py and haven't had any issues regarding whether the tokenizer.model is available or not. It smoothly handles most known cases with the exception that I omitted GPT2 support from it. Once this PR is merged, I can take a look at it and see what's going on in there once I have some time.

@ggerganov ggerganov merged commit 873637a into ggml-org:master Dec 14, 2023
@ggerganov
Copy link
Member

The ggml CI is now failing with the following error:

+ python3 ../convert.py ../models-mnt/open-llama/3B-v2
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Loading model file ../models-mnt/open-llama/3B-v2/pytorch_model.bin
params = Params(n_vocab=32000, n_embd=3200, n_layer=26, n_ctx=2048, n_ff=8640, n_head=32, n_head_kv=32, n_experts=None, n_experts_used=None, f_norm_eps=1e-06, rope_scaling_type=None, f_rope_freq_base=None, f_rope_scale=None, n_orig_ctx=None, rope_finetuned=None, ftype=None, path_model=PosixPath('../models-mnt/open-llama/3B-v2'))
Traceback (most recent call last):
  File "/home/ggml/work/llama.cpp/build-ci-release/../convert.py", line 1279, in <module>
    main()
  File "/home/ggml/work/llama.cpp/build-ci-release/../convert.py", line 1255, in main
    vocab = VocabLoader(params, vocab_dir)
  File "/home/ggml/work/llama.cpp/build-ci-release/../convert.py", line 342, in __init__
    self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), trust_remote_code=True)
  File "/home/ggml/.local/lib/python3.10/site-packages/transformers/models/auto/tokenization_auto.py", line 787, in from_pretrained
    return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
  File "/home/ggml/.local/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2028, in from_pretrained
    return cls._from_pretrained(
  File "/home/ggml/.local/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2260, in _from_pretrained
    tokenizer = cls(*init_inputs, **init_kwargs)
  File "/home/ggml/.local/lib/python3.10/site-packages/transformers/models/llama/tokenization_llama_fast.py", line 124, in __init__
    super().__init__(
  File "/home/ggml/.local/lib/python3.10/site-packages/transformers/tokenization_utils_fast.py", line 114, in __init__
    fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
  File "/home/ggml/.local/lib/python3.10/site-packages/transformers/convert_slow_tokenizer.py", line 1336, in convert_slow_tokenizer
    return converter_class(transformer_tokenizer).converted()
  File "/home/ggml/.local/lib/python3.10/site-packages/transformers/convert_slow_tokenizer.py", line 459, in __init__
    requires_backends(self, "protobuf")
  File "/home/ggml/.local/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 1276, in requires_backends
    raise ImportError("".join(failed))
ImportError: 
LlamaConverter requires the protobuf library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.

+ cur=1
+ echo 1
+ set +x

https://github.com/ggml-org/ci/blob/results/llama.cpp/87/3637afc7924f435ac44c067630a28e82eefa7b/ggml-3-arm64-cpu/stdall

Somehow I missed that this change also introduces dependency on protobuf.

peterdelevoryas added a commit to peterdelevoryas/llama.cpp that referenced this pull request Dec 27, 2023
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
…3633)

* Add HFVocab into convert.py

* Update convert.py

* Update convert.py

* add bytes_to_unicode function

* change add_meta_vocab fucntion

* remove debug code

* remove byte_encoder

* Add newline between classes

* Check tokenizer.json when tokenizer.model is not exist.

* Move transformers dependency to local code

* Add error context with 'raise from'

* Add fast tokenizer option to BpeVocab

* Update convert.py

* Add VocabLoader and remove *Vocab class

* Add transformers dependency

* remove added tokens and check newline token to decide spm or bpe

* Update convert.py

* Add special token type

* Update convert.py

* Update convert.py

* Update convert.py

* Fix typo in convert.py

* Fix when params.n_vocab < tokenizer vocab size

* update vocab class

* change funtion name

* Remove unused variable/functions, add types to class variable and methods, delete blank liens

* fix flake8 warnings

* code style cleanup

* make mypy happy

* change exception

---------

Co-authored-by: Jared Van Bortel <[email protected]>
hodlen added a commit to hodlen/llama.cpp that referenced this pull request Apr 3, 2024
readme : update hot topics

common : add `--version` option to show build info in CLI (#4433)

build : detect host compiler and cuda compiler separately (#4414)

sync : ggml (SD ops, tests, kernels) (#4444)

* sync : ggml (SD ops, tests, kernels)

ggml-ci

* cuda : restore im2col

ggml-ci

* metal : fix accuracy of dequantization kernels

ggml-ci

* cuda : restore correct im2col

ggml-ci

* metal : try to fix moe test by reducing expert size

ggml-ci

* cuda : fix bin bcast when src1 and dst have different types

ggml-ci

---------

Co-authored-by: slaren <[email protected]>

server : fix handling of characters that span multiple tokens when streaming (#4446)

readme : update supported model list (#4457)

convert : support loading vocab from fast tokenizer config (#3633)

* Add HFVocab into convert.py

* Update convert.py

* Update convert.py

* add bytes_to_unicode function

* change add_meta_vocab fucntion

* remove debug code

* remove byte_encoder

* Add newline between classes

* Check tokenizer.json when tokenizer.model is not exist.

* Move transformers dependency to local code

* Add error context with 'raise from'

* Add fast tokenizer option to BpeVocab

* Update convert.py

* Add VocabLoader and remove *Vocab class

* Add transformers dependency

* remove added tokens and check newline token to decide spm or bpe

* Update convert.py

* Add special token type

* Update convert.py

* Update convert.py

* Update convert.py

* Fix typo in convert.py

* Fix when params.n_vocab < tokenizer vocab size

* update vocab class

* change funtion name

* Remove unused variable/functions, add types to class variable and methods, delete blank liens

* fix flake8 warnings

* code style cleanup

* make mypy happy

* change exception

---------

Co-authored-by: Jared Van Bortel <[email protected]>

ggml : fix OpenCL broadcast requirement for ggml_mul (close #4453)

ggml : add ggml_row_size() (fixes llama out of space) (#4461)

* Fixes "Not enough space in the context's memory pool" encountered on certain models, which seems to be caused by some imprecision related to the automatic casting of floating point values

* do not cast to size_t, instead just use doubles

* ggml : add ggml_row_size(), deprecate ggml_type_sizef()

* ggml : fix row size compute to avoid overflows

* tests : fix sizey -> sizez

---------

Co-authored-by: Georgi Gerganov <[email protected]>

py : add protobuf dependency (#4466)

ggml : remove n_dims from ggml_tensor (#4469)

ggml-ci

ggml : use ggml_row_size where possible (#4472)

* ggml : use ggml_row_size where possible

ggml-ci

* ggml : move ggml_nbytes_split to ggml-cuda.cu

ggml : group mul_mat_id rows by matrix (cpu only) (#4480)

* ggml : group mul_mat_id rows by matrix (cpu only)

* remove mmid parameters from mm forward

* store row groups in wdata and calculate only once in GGML_TASK_INIT

ggml-ci

server : add optional API Key Authentication example (#4441)

* Add API key authentication for enhanced server-client security

* server : to snake_case

---------

Co-authored-by: Georgi Gerganov <[email protected]>

llama : sanity checks for access to logits (#4274)

Co-authored-by: Georgi Gerganov <[email protected]>

lora : add support for non-llama models (#3333)

* lora : add support for non-llama models

ggml-ci

* avoid leaking ggml_context on failure
cleanup

ggml-ci

* lora : allow 1d tensors

* lora : include embd and output layers in size calculation

* fix style

Link to cublas dynamically on Windows even with LLAMA_STATIC (#4506)

server : allow requests larger than 8K (#4500)

server : fix possible ambiguity in content type charset (#4501)

server : fix grammar being ignored (#4494)

Fix bug in identifying the grammar.

server : disable llm logs if SERVER_VERBOSE is off (#3792)

finetune : keep allocs alive until all allocations are done (#4486)

build : Check the ROCm installation location (#4485)

* build : Check the ROCm installation location

* more generic approach

* fixup! It was returning the path instead of the command output

* fixup! Trailing whitespace

gguf-py : fail fast on nonsensical special token IDs (#4489)

llama.swiftui : add bench functionality (#4483)

* llama.swiftui : add bench button

* llama.swiftui : initial bench functionality

* force to use n_gpu_layers on simulator

* add download buttons & expose llamaState.loadModel

* update project.pbxproj

* comment #Preview & fix editorconfig check

* gitignore : xcode stuff

* llama.swiftui : UX improvements

* llama.swiftui : avoid data copy via "downloadTask"

* llama.swiftui : remove model from project

* llama : remove "mostly" from model infos

* llama.swiftui : improve bench

---------

Co-authored-by: jhen <[email protected]>

readme : update hot topics

decode : fix logits_valid for legacy API (#4516)

llama : fix try_override for bool_value which always return true (#4519)

llama : add phi-2 + fix NeoX rope + ggml_mul_mat_set_prec (#4490)

* phi2 implementation

* fix breaking change

* phi-2 : various fixes

* phi-2 : use layer norm eps

* py : whitespaces

* llama : fix meta KV override bug

* convert : phi don't add BOS token

* convert : revert "added_tokens_decoder" change

* phi-2 : scale Q instead of KQ for better precision

* ggml : fix NeoX rope to rotate just first n_dims

* cuda : less diff in the rope_neox kernel

* ggml : add ggml_mul_mat_set_prec

ggml-ci

* Update ggml-cuda.cu

Co-authored-by: slaren <[email protected]>

* Update ggml-cuda.cu

Co-authored-by: slaren <[email protected]>

* cuda : ggml_cuda_op_mul_mat_cublas support F32 precision

* cuda : remove oboslete comment

---------

Co-authored-by: Ebey Abraham <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: slaren <[email protected]>

llama.swiftui : add more models

llama.swiftui : add tinyllama 1.1B F16

ggml-cuda: Fix HIP build (#4528)

regression of #4490
Adds defines for two new datatypes
cublasComputeType_t, cudaDataType_t.

Currently using deprecated hipblasDatatype_t since newer ones very recent.

ggml : fixed check for _MSC_VER (#4535)

Co-authored-by: Eric Sommerlade <[email protected]>

CUDA: Faster Mixtral prompt processing (#4538)

* CUDA: make MoE tensors contiguous for batch size>1

* Update ggml-cuda.cu

Co-authored-by: slaren <[email protected]>

---------

Co-authored-by: slaren <[email protected]>

Fix access violation in ggml_cuda_free_data if tensor->extra is NULL (#4554)

llama : disable per-tensor info prints on model load (#4562)

cuda : replace asserts in wrong architecture checks with __trap (#4556)

* cuda : replace asserts in wrong architecture checks with __trap

* make bad_arch noreturn, remove returns

cuda : better error message for ggml_get_rows (#4561)

* Update ggml-cuda.cu

* Update ggml-cuda.cu

* Update ggml-cuda.cu

---------

Co-authored-by: Georgi Gerganov <[email protected]>

py : open merges file as 'utf-8' (#4566)

Otherwise, on Windows converting bling-phi-2-v0 (<https://huggingface.co/llmware/bling-phi-2-v0>) via convert-hf-to-gguf.py will fail with the following error:

```
Traceback (most recent call last):
  File "C:\Users\User\git\gguf\convert-hf-to-gguf.py", line 1061, in <module>
    model_instance.set_vocab()
  File "C:\Users\User\git\gguf\convert-hf-to-gguf.py", line 52, in set_vocab
    self._set_vocab_gpt2()
  File "C:\Users\User\git\gguf\convert-hf-to-gguf.py", line 264, in _set_vocab_gpt2
    special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
  File "C:\Users\User\git\gguf\gguf\vocab.py", line 33, in __init__
    self._load(Path(path))
  File "C:\Users\User\git\gguf\gguf\vocab.py", line 81, in _load
    self._try_load_merges_txt(path)
  File "C:\Users\User\git\gguf\gguf\vocab.py", line 95, in _try_load_merges_txt
    for line in fp:
  File "C:\Users\User\miniconda3\envs\gguf\lib\encodings\cp1252.py", line 23, in decode
    return codecs.charmap_decode(input,self.errors,decoding_table)[0]
UnicodeDecodeError: 'charmap' codec can't decode byte 0x81 in position 1415: character maps to <undefined>
```

readme : update coding guidelines

CUDA: mul_mat_id always on GPU for batches >= 32 (#4553)

common : remove incorrect --model-draft default (#4568)

ggml-cuda: Fix HIP build by adding define for __trap (#4569)

Regression of 139882392258671ffe5acdfcadc0bc08572d6eef
HIP doesn't have trap, only abort

cuda : ROCm AMD Unified Memory Architecture (UMA) handling (#4449)

* AMD ROCm: handle UMA memory VRAM expansions

This resolves #2797 by allowing ROCm AMD GPU users with a UMA to
dynamically expand the VRAM allocated to the GPU.

Without this, AMD ROCm users with shared CPU/GPU memory usually are
stuck with the BIOS-set (or fixed) framebuffer VRAM, making it
impossible to load more than 1-2 layers.

Note that the model is duplicated in RAM because it's loaded once for
the CPU and then copied into a second set of allocations that are
managed by the HIP UMA system. We can fix this later.

* clarify build process for ROCm on linux with cmake

* avoid using deprecated ROCm hipMallocHost

* keep simplifying the change required for UMA

* cmake: enable UMA-compatible allocation when LLAMA_HIP_UMA=ON

metal : fix `ggml_metal_log` vargs (#4373)

llama : allow getting n_batch from llama_context in c api (#4540)

* allowed getting n_batch from llama_context in c api

* changed to use `uint32_t` instead of `int`

* changed to use `uint32_t` instead of `int` in `llama_n_ctx`

* Update llama.h

---------

Co-authored-by: Georgi Gerganov <[email protected]>

llama : initial ggml-backend integration (#4520)

* llama : initial ggml-backend integration

* add ggml-metal

* cuda backend can be used though ggml-backend with LLAMA_GGML_BACKEND_CUDA_TEST
access all tensor data with ggml_backend_tensor_get/set

* add ggml_backend_buffer_clear
zero-init KV cache buffer

* add ggml_backend_buffer_is_hos, used to avoid copies if possible when accesing tensor data

* disable gpu backends with ngl 0

* more accurate mlock

* unmap offloaded part of the model

* use posix_fadvise64(.., POSIX_FADV_SEQUENTIAL) to improve performance with mmap

* update quantize and lora

* update session copy/set to use ggml-backend

ggml-ci

* use posix_fadvise instead of posix_fadvise64

* ggml_backend_alloc_ctx_tensors_from_buft : remove old print

* llama_mmap::align_offset : use pointers instead of references for out parameters

* restore progress_callback behavior

* move final progress_callback call to load_all_data

* cuda : fix fprintf format string (minor)

* do not offload scales

* llama_mmap : avoid unmapping the same fragments again in the destructor

* remove unnecessary unmap

* metal : add default log function that prints to stderr, cleanup code

ggml-ci

---------

Co-authored-by: Georgi Gerganov <[email protected]>

ci : add `jlumbroso/free-disk-space` to docker workflow (#4150)

* [github][workflows][docker]: removes hardcoded `ggerganov` from `ghcr` repo

* [github][workflows][docker]: adds `jlumbroso/free-disk-space`

gguf : simplify example dependencies

gguf-py : fix broken link

ggml : change ggml_scale to take a float instead of tensor (#4573)

* ggml : change ggml_scale to take a float instead of tensor

* ggml : fix CPU implementation

* tests : fix test-grad0

ggml-ci

llama : add ability to cancel model loading (#4462)

* llama : Add ability to cancel model load

Updated llama_progress_callback so that if it returns false, the model
loading is aborted.

* llama : Add test for model load cancellation

* Fix bool return in llama_model_load, remove std::ignore use

* Update llama.cpp

Co-authored-by: Jared Van Bortel <[email protected]>

* Fail test if model file is missing

* Revert "Fail test if model file is missing"

This reverts commit 32ebd525bf7e5a87ee8a3dbaab3d92ce79fbf23d.

* Add test-model-load-cancel to Makefile

* Revert "Revert "Fail test if model file is missing""

This reverts commit 2796953257ee5383fa7c8fe8fa8fc888c048fb0b.

* Simplify .gitignore for tests, clang-tidy fixes

* Label all ctest tests

* ci : ctest uses -L main

* Attempt at writing ctest_with_model

* ci : get ci/run.sh working with test-model-load-cancel

* ci : restrict .github/workflows/build.yml ctest to -L main

* update requirements.txt

* Disable test-model-load-cancel in make

* Remove venv before creation

* Restructure requirements.txt

Top-level now imports the specific additional requirements for each
python file. Using `pip install -r requirements.txt` will fail if
versions become mismatched in the per-file requirements.

* Make per-python-script requirements work alone

This doesn't break the main requirements.txt.

* Add comment

* Add convert-persimmon-to-gguf.py to new requirements.txt scheme

* Add check-requirements.sh script and GitHub workflow

* Remove shellcheck installation step from workflow

* Add nocleanup special arg

* Fix merge

see: https://github.com/ggerganov/llama.cpp/pull/4462#discussion_r1434593573

* reset to upstream/master

* Redo changes for cancelling model load

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: Jared Van Bortel <[email protected]>

ggml : extend `enum ggml_log_level` with `GGML_LOG_LEVEL_DEBUG` (#4579)

readme : add zig bindings (#4581)

ci : tag docker image with build number (#4584)

make : add LLAMA_HIP_UMA option (#4587)

NB: LLAMA_HIP_UMA=1 (or any value) adds MK_CPPFLAG -DGGML_HIP_UMA

ggml : add comment about backward GGML_OP_DIAG_MASK_INF (#4203)

llama : fix platforms without mmap (#4578)

* llama : fix platforms without mmap

* win32 : limit prefetch size to the file size

* fix win32 error clobber, unnecessary std::string in std::runtime_error

Fix CudaMemcpy direction (#4599)

cuda : fix jetson compile error (#4560)

* fix old jetson compile error

* Update Makefile

* update jetson detect and cuda version detect

* update cuda marco define

* update makefile and cuda,fix some issue

* Update README.md

Co-authored-by: Georgi Gerganov <[email protected]>

* Update Makefile

* Update README.md

---------

Co-authored-by: Georgi Gerganov <[email protected]>

sync : ggml (fix im2col) (#4591)

* cuda : fix im2col_f32_f16 (ggml/#658)

ggml-ci

* ggml-alloc : fix ggml_tallocr_is_own

---------

Co-authored-by: leejet <[email protected]>

lookup : add prompt lookup decoding example (#4484)

* initial commit, going through initializations

* main loop finished, starting to debug

* BUG: generates gibberish/repeating tokens after a while

* kv_cache management

* Added colors to distinguish drafted tokens (--color). Updated README

* lookup : fix token positions in the draft batch

* lookup : use n_draft from CLI params

* lookup : final touches

---------

Co-authored-by: Leon Ericsson <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>

CUDA: fixed row rounding for 0 tensor splits (#4594)

grammar : check the full vocab only if necessary (opt) (#4306)

* Check the full vocab for grammar only if necessary

* Fix missing logit restoration step (?)

Does this matter, actually?

* Fix whitespace / formatting

* Adjust comment

* Didn't mean to push test gbnf

* Split sampling into the helper function (?)

And also revert the changes made to the header

* common : fix final newline

---------

Co-authored-by: Georgi Gerganov <[email protected]>

server : allow to specify custom prompt for penalty calculation (#3727)

ci(docker): fix tags in "Build and push docker image (tagged)" (#4603)

fallback to CPU buffer if host buffer alloc fails (#4610)

cuda : improve cuda pool efficiency using virtual memory (#4606)

* cuda : improve cuda pool efficiency using virtual memory

* fix mixtral

* fix cmake build

* check for vmm support, disable for hip

ggml-ci

* fix hip build

* clarify granularity

* move all caps to g_device_caps

* refactor error checking

* add cuda_pool_alloc, refactor most pool allocations

ggml-ci

* fix hip build

* CUBLAS_TF32_TENSOR_OP_MATH is not a macro

* more hip crap

* llama : fix msvc warnings

* ggml : fix msvc warnings

* minor

* minor

* cuda : fallback to CPU on host buffer alloc fail

* Update ggml-cuda.cu

Co-authored-by: Johannes Gäßler <[email protected]>

* Update ggml-cuda.cu

Co-authored-by: Johannes Gäßler <[email protected]>

* ensure allocations are always aligned

* act_size -> actual_size

---------

Co-authored-by: Johannes Gäßler <[email protected]>

llama : add PLaMo model (#3557)

* add plamo mock

* add tensor loading

* plamo convert

* update norm

* able to compile

* fix norm_rms_eps hparam

* runnable

* use inp_pos

* seems ok

* update kqv code

* remove develop code

* update README

* shuffle attn_q.weight and attn_output.weight for broadcasting

* remove plamo_llm_build_kqv and use llm_build_kqv

* fix style

* update

* llama : remove obsolete KQ_scale

* plamo : fix tensor names for correct GPU offload

---------

Co-authored-by: Georgi Gerganov <[email protected]>

simplify bug issue template (#4623)

Adding Emeltal reference to UI list (#4629)

Fix new CUDA10 compilation errors (#4635)

Update comment for AdamW implementation reference. (#4604)

Co-authored-by: Will Findley <[email protected]>

cuda : fix vmm pool with multi GPU (#4620)

* cuda : fix vmm pool with multi GPU

* hip

* use recommended granularity instead of minimum

* better error checking

* fix mixtral

* use cudaMemcpy3DPeerAsync

* use cuda_pool_alloc in ggml_cuda_op_mul_mat

* consolidate error checking in ggml_cuda_set_device

* remove unnecessary inlines

ggml-ci

* style fixes

* only use vmm for the main device

* fix scratch buffer size, re-enable vmm pool for all devices

* remove unnecessary check id != g_main_device

Add byte token type when tokenizer.model is not exists (#4641)

* Add byte token type to hf format

* remove unused variable

ggml : fix dot product for ARM (#4630)

ggml-ci

scripts : add sync-ggml-am.sh

finetune : fix output formatting in print_params (#4653)

This commit fixes the output formatting in the print_params function
which currently looks like this:
```console
print_params: n_vocab:   32000
print_params: n_ctx:     128
print_params: n_embd:    4096
print_params: n_ff:      11008
print_params: n_head:    32
print_params: n_head_kv: 32
print_params: n_layer:   32
print_params: norm_rms_eps          : 0.000010
print_params: rope_freq_base        : 10000.000000
print_params: rope_freq_scale       : 1.000000
```
With this comit the output will look like this:
```console
print_params: n_vocab               : 32000
print_params: n_ctx                 : 128
print_params: n_embd                : 4096
print_params: n_ff                  : 11008
print_params: n_head                : 32
print_params: n_head_kv             : 32
print_params: n_layer               : 32
print_params: norm_rms_eps          : 0.000010
print_params: rope_freq_base        : 10000.000000
print_params: rope_freq_scale       : 1.000000
```

Signed-off-by: Daniel Bevenius <[email protected]>

llama : add AWQ for llama, llama2, mpt, and mistral models (#4593)

* update: awq support llama-7b model

* update: change order

* update: benchmark results for llama2-7b

* update: mistral 7b v1 benchmark

* update: support 4 models

* fix: Readme

* update: ready for PR

* update: readme

* fix: readme

* update: change order import

* black

* format code

* update: work for bot mpt and awqmpt

* update: readme

* Rename to llm_build_ffn_mpt_awq

* Formatted other files

* Fixed params count

* fix: remove code

* update: more detail for mpt

* fix: readme

* fix: readme

* update: change folder architecture

* fix: common.cpp

* fix: readme

* fix: remove ggml_repeat

* update: cicd

* update: cicd

* uppdate: remove use_awq arg

* update: readme

* llama : adapt plamo to new ffn

ggml-ci

---------

Co-authored-by: Trần Đức Nam <[email protected]>
Co-authored-by: Le Hoang Anh <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>

gpt2 : Add gpt2 architecture integration (#4555)

Fix OpenAI server sampling w.r.t. temp and seed (#4668)

The default values for tfs_z and typical_p were being set to zero, which
caused the token candidates array to get shrunk down to one element thus
preventing any sampling. Note this only applies to OpenAI API compatible
HTTP server requests.

The solution is to use the default values that OpenAI documents, as well
as ensuring we use the llama.cpp defaults for the rest. I've tested this
change still ensures deterministic output by default. If a "temperature"
greater than 0 is explicitly passed, then output is unique each time. If
"seed" is specified in addition to "temperature" then the output becomes
deterministic once more.

See mozilla-Ocho/llamafile#117
See mozilla-Ocho/llamafile@9e4bf29

scripts : do not sync commits from this repo

ggml : fix some mul mat cases + add tests for src1 F16 (ggml/669)

* fixed mul-mat error for old GPUs

* style fixes

* add mul mat src1 f16 test cases, fix more cases

ggml-ci

---------

Co-authored-by: bssrdf <[email protected]>
Co-authored-by: slaren <[email protected]>

sync : ggml

ci : build with CLBlast + ggml-opencl use GGML_API (whisper/1576)

* Build with CLBlast

* Declare GGML_API

After rebasing, examples/talk-llama failed:

"D:\a\whisper.cpp\whisper.cpp\build\ALL_BUILD.vcxproj" (build target) (1) ->
"D:\a\whisper.cpp\whisper.cpp\build\examples\talk-llama\talk-llama.vcxproj" (default target) (14) ->
(Link target) ->
  llama.obj : error LNK2019: unresolved external symbol ggml_cl_free_data referenced in function "public: __cdecl llama_model::~llama_model(void)" (??1llama_model@@QEAA@XZ) [D:\a\whisper.cpp\whisper.cpp\build\examples\talk-llama\talk-llama.vcxproj]
  llama.obj : error LNK2019: unresolved external symbol ggml_cl_transform_tensor referenced in function "public: void __cdecl llama_model_loader::load_all_data(struct ggml_context *,void (__cdecl*)(float,void *),void *,struct llama_mlock *)" (?load_all_data@llama_model_loader@@QEAAXPEAUggml_context@@P6AXMPEAX@Z1PEAUllama_mlock@@@Z) [D:\a\whisper.cpp\whisper.cpp\build\examples\talk-llama\talk-llama.vcxproj]
  D:\a\whisper.cpp\whisper.cpp\build\bin\Release\talk-llama.exe : fatal error LNK1120: 2 unresolved externals [D:\a\whisper.cpp\whisper.cpp\build\examples\talk-llama\talk-llama.vcxproj]

scripts : print list of sync commits

llama.swiftui : fix infinite loop, ouput timings, buff UI (#4674)

* fix infinite loop

* slight UI simplification, clearer UX

* clearer UI text, add timings to completion log

main-cmake-pkg : fix build issue (#4665)

* Fix main-cmake-pkg compilation

* Use glob to load common files

* cmake : fix trailing whitespace

---------

Co-authored-by: Georgi Gerganov <[email protected]>

server : allow to generate multimodal embeddings (#4681)

server : fix OpenAI server sampling w.r.t. penalty. (#4675)

server : replace sleep with condition variables (#4673)

The server currently schedules tasks using a sleep(5ms) busy loop. This
adds unnecessary latency since most sleep implementations do a round up
to the system scheduling quantum (usually 10ms). Other libc sleep impls
spin for smaller time intervals which results in the server's busy loop
consuming all available cpu. Having the explicit notify() / wait() code
also helps aid in the readability of the server code.

See mozilla-Ocho/llamafile@711344b

llava-cli : refactor to use sampling library (#4669)

This change makes it possible to use flags like `--grammar` when using
the `llava-cli` program. The rest is just code cleanup deleting a long
standing TODO comment.

This change also ensures that logging information is emitted to stderr
which helps the `llava-cli` command be more friendly to shell scripts.

See Mozilla-Ocho/llamafile@1cd334f

cmake : fix ld warning duplicate libraries libllama.a (#4671)

* fix "ld: warning: ignoring duplicate libraries: '../libllama.a'"

* fix warning in example.

flake.nix : rewrite (#4605)

* flake.lock: update to hotfix CUDA::cuda_driver

Required to support https://github.com/ggerganov/llama.cpp/pull/4606

* flake.nix: rewrite

1. Split into separate files per output.

2. Added overlays, so that this flake can be integrated into others.
   The names in the overlay are `llama-cpp`, `llama-cpp-opencl`,
   `llama-cpp-cuda`, and `llama-cpp-rocm` so that they fit into the
   broader set of Nix packages from [nixpkgs](https://github.com/nixos/nixpkgs).

3. Use [callPackage](https://summer.nixos.org/blog/callpackage-a-tool-for-the-lazy/)
   rather than `with pkgs;` so that there's dependency injection rather
   than dependency lookup.

4. Add a description and meta information for each package.
   The description includes a bit about what's trying to accelerate each one.

5. Use specific CUDA packages instead of cudatoolkit on the advice of SomeoneSerge.

6. Format with `serokell/nixfmt` for a consistent style.

7. Update `flake.lock` with the latest goods.

* flake.nix: use finalPackage instead of passing it manually

* nix: unclutter darwin support

* nix: pass most darwin frameworks unconditionally

...for simplicity

* *.nix: nixfmt

nix shell github:piegamesde/nixfmt/rfc101-style --command \
    nixfmt flake.nix .devops/nix/*.nix

* flake.nix: add maintainers

* nix: move meta down to follow Nixpkgs style more closely

* nix: add missing meta attributes

nix: clarify the interpretation of meta.maintainers

nix: clarify the meaning of "broken" and "badPlatforms"

nix: passthru: expose the use* flags for inspection

E.g.:

```
❯ nix eval .#cuda.useCuda
true
```

* flake.nix: avoid re-evaluating nixpkgs too many times

* flake.nix: use flake-parts

* nix: migrate to pname+version

* flake.nix: overlay: expose both the namespace and the default attribute

* ci: add the (Nix) flakestry workflow

* nix: cmakeFlags: explicit OFF bools

* nix: cuda: reduce runtime closure

* nix: fewer rebuilds

* nix: respect config.cudaCapabilities

* nix: add the impure driver's location to the DT_RUNPATHs

* nix: clean sources more thoroughly

...this way outPaths change less frequently,
and so there are fewer rebuilds

* nix: explicit mpi support

* nix: explicit jetson support

* flake.nix: darwin: only expose the default

---------

Co-authored-by: Someone Serge <[email protected]>

python : add check-requirements.sh and GitHub workflow (#4585)

* python: add check-requirements.sh and GitHub workflow

This script and workflow forces package versions to remain compatible
across all convert*.py scripts, while allowing secondary convert scripts
to import dependencies not wanted in convert.py.

* Move requirements into ./requirements

* Fail on "==" being used for package requirements (but can be suppressed)

* Enforce "compatible release" syntax instead of ==

* Update workflow

* Add upper version bound for transformers and protobuf

* improve check-requirements.sh

* small syntax change

* don't remove venvs if nocleanup is passed

* See if this fixes docker workflow

* Move check-requirements.sh into ./scripts/

---------

Co-authored-by: Jared Van Bortel <[email protected]>

cuda: fix vmm oom issue on NVIDIA AGX Orin (#4687)

Signed-off-by: hydai <[email protected]>

clip : enable gpu backend (#4205)

* clip: enable CUDA backend

* add missing kernels

* add enough padding for alignment

* remove ggml_repeat of clip.cpp

* add metal backend

* llava : fixes

- avoid ggml_repeat
- use GGML_USE_ instead of CLIP_USE_ macros
- remove unused vars

---------

Co-authored-by: Georgi Gerganov <[email protected]>

clip : use ggml_backend_buffer_is_host (#4205)

CUDA: fix tensor core logic for Pascal and HIP (#4682)

ggml : add ggml_cpu_has_avx_vnni() (#4589)

* feat: add avx_vnni based on intel documents

* ggml: add avx vnni based on intel document

* llama: add avx vnni information display

* docs: add more details about using oneMKL and oneAPI for intel processors

* docs: add more details about using oneMKL and oneAPI for intel processors

* docs: add more details about using oneMKL and oneAPI for intel processors

* docs: add more details about using oneMKL and oneAPI for intel processors

* docs: add more details about using oneMKL and oneAPI for intel processors

* Update ggml.c

Fix indentation upgate

Co-authored-by: Georgi Gerganov <[email protected]>

---------

Co-authored-by: Georgi Gerganov <[email protected]>

CUDA: fixed tensor cores not being used on RDNA3 (#4697)

clip : refactor + bug fixes (#4696)

* clip : refactor + bug fixes

ggml-ci

* server : add log message

ggml : add ggml_vdotq_s32 alias (#4715)

ggml-ci

flake.nix: expose full scope in legacyPackages

flake.nix: rocm not yet supported on aarch64, so hide the output

flake.nix: expose checks

workflows: nix-ci: init; build flake outputs

workflows: nix-ci: add a job for eval

workflows: weekly `nix flake update`

workflows: nix-flakestry: drop tag filters

...and add a job for flakehub.com

workflows: nix-ci: add a qemu job for jetsons

flake.nix: suggest the binary caches

flake.lock: update

to a commit recently cached by nixpkgs-cuda-ci

metal : enable shader debugging (cmake option) (#4705)

* ggml : disable fast-math for Metal (cmake build only)

ggml-ci

* metal : fix Metal API debug warnings

* cmake : add -fno-inline for Metal build (#4545)

* metal : fix API debug warnings

* metal : fix compile warnings

* metal : use uint64_t for strides

* cmake : rename option to LLAMA_METAL_SHADER_DEBUG

* metal : fix mat-vec Q8_0 kernel for BS > 1

* metal : normalize mat-vec kernel signatures

* cmake : respect LLAMA_QKK_64 option

* metal : fix mat-vec Q4_K kernel for QK_K == 64

ggml-ci

finetune: fix typo in README.md (#4733)

Signed-off-by: Daniel Bevenius <[email protected]>

py : re-enable mmap in convert hf (#4732)

* update: awq support llama-7b model

* update: change order

* update: benchmark results for llama2-7b

* update: mistral 7b v1 benchmark

* update: support 4 models

* fix: Readme

* update: ready for PR

* update: readme

* fix: readme

* update: change order import

* black

* format code

* update: work for bot mpt and awqmpt

* update: readme

* Rename to llm_build_ffn_mpt_awq

* Formatted other files

* Fixed params count

* fix: remove code

* update: more detail for mpt

* fix: readme

* fix: readme

* update: change folder architecture

* fix: common.cpp

* fix: readme

* fix: remove ggml_repeat

* update: cicd

* update: cicd

* uppdate: remove use_awq arg

* update: readme

* llama : adapt plamo to new ffn

ggml-ci

* fix: update torch version

---------

Co-authored-by: Trần Đức Nam <[email protected]>
Co-authored-by: Le Hoang Anh <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>

server : add --override-kv parameter (#4710)

* Changes to server to allow metadata override

* documentation

* flake.nix: expose full scope in legacyPackages

* flake.nix: rocm not yet supported on aarch64, so hide the output

* flake.nix: expose checks

* workflows: nix-ci: init; build flake outputs

* workflows: nix-ci: add a job for eval

* workflows: weekly `nix flake update`

* workflows: nix-flakestry: drop tag filters

...and add a job for flakehub.com

* workflows: nix-ci: add a qemu job for jetsons

* flake.nix: suggest the binary caches

* flake.lock: update

to a commit recently cached by nixpkgs-cuda-ci

---------

Co-authored-by: John <[email protected]>
Co-authored-by: Someone Serge <[email protected]>

editorconfig : fix whitespace and indentation #4710

llama : differentiate the KV dims in the attention (#4657)

* Add n_key_dim and n_value_dim

Some models use values that are not derived from `n_embd`.
Also remove `n_embd_head` and `n_embd_gqa` because it is not clear
which "head" is referred to (key or value).

Fix issue #4648.

* Fix `llm_build_kqv` to use `n_value_gqa`

* Rebase

* Rename variables

* Fix llm_build_kqv to be more generic wrt n_embd_head_k

* Update default values for n_embd_head_k and n_embd_head_v

Co-authored-by: Georgi Gerganov <[email protected]>

* Fix llm_load_tensors: the asserts were not backcompat

---------

Co-authored-by: Georgi Gerganov <[email protected]>

llama : replace all API facing `int`'s with `int32_t` (#4577)

* replaced all API facing `int`'s with `int32_t`

* formatting and missed `int` in `llama_token_to_piece`

llama : llama_model_desc print number of experts

server : add token counts to html footer (#4738)

* server: add token counts to stats

* server: generate hpp

---------

Co-authored-by: phiharri <[email protected]>

metal : optimize ggml_mul_mat_id (faster Mixtral PP) (#4725)

* ggml : disable fast-math for Metal (cmake build only)

ggml-ci

* metal : fix Metal API debug warnings

* cmake : add -fno-inline for Metal build (#4545)

* metal : fix API debug warnings

* metal : fix compile warnings

* metal : use uint64_t for strides

* cmake : rename option to LLAMA_METAL_SHADER_DEBUG

* metal : fix mat-vec Q8_0 kernel for BS > 1

* metal : normalize mat-vec kernel signatures

* cmake : respect LLAMA_QKK_64 option

* metal : fix mat-vec Q4_K kernel for QK_K == 64

* metal : optimizing ggml_mul_mat_id (wip)

* metal : minor fix

* metal : opt mul_mm_id

server : throw an error when `slot unavailable` (#4741)

ggml : extend ggml_get_rows, ggml_repeat, ggml_concat (ggml/639)

* add more int ops

* ggml_compute_forward_dup_bytes

* add tests

* PR comments

* tests : minor indentations

---------

Co-authored-by: Georgi Gerganov <[email protected]>

scripts : fix sync order + metal sed

metal : add kernel_get_rows_i32

ggml-ci

sync : ggml

ggml-ci

cuda : mark I16 and I32 ops as unsupported

ggml-ci

cuda : simplify expression

Co-authored-by: slaren <[email protected]>

swift : update Package.swift to use ggml as dependency (#4691)

* updates the package.swift to use ggml as dependency

* changes the ggml package url src to ggerganov

train : fix typo in overlapping-samples help msg (#4758)

This commit fixes a typo in the help message for the
--overlapping-samples option.

Signed-off-by: Daniel Bevenius <[email protected]>

llama.swiftui : fix build of ggml.metallib (#4754)

* metal: fix metal backend init failure in swiftui

* metal: build ggml.metallib instead of copy src

* llama.swift : remove debug flags from metallib build

---------

Co-authored-by: Georgi Gerganov <[email protected]>

ggml : include stdlib.h before intrin.h (#4736)

server : fix options in README.md (#4765)

* fix examples/server/README.md

* minor : fix whitespace

---------

Co-authored-by: Georgi Gerganov <[email protected]>

llama.swiftui : support loading custom model from file picker (#4767)

* swiftui: support load model from file picker

* swiftui: remove trailing whitespace

Print backend name on test-backend-ops failure (#4751)

server : send token probs for "stream == false" (#4714)

finetune : remove unused includes (#4756)

This commit removes unused includes from finetune.cpp.

Signed-off-by: Daniel Bevenius <[email protected]>

examples : add few-shot translation example (#4783)

ggml : do not sched_yield when calling BLAS (#4761)

* ggml : do not sched_yield when calling BLAS

ggml-ci

* ggml : fix do_yield logic

ggml-ci

* ggml : simplify do_yield logic

ggml-ci

ggml : add error handling to graph_compute (whisper/1714)

ggml : fix q2_k bpw in comments (ggml/680)

metal : switch back to default.metallib (ggml/681)

ggml-ci

flake.nix : fix typo (#4700)

betwen -> between

cmake : check for openblas64 (#4134)

openblas v0.3.22 64-bit pkg-config file is named openblas64.pc
https://github.com/OpenMathLib/OpenBLAS/issues/3790

examples : improve base-translate.sh script (#4783)

llama.swiftui : use correct pointer for llama_token_eos (#4797)

server : fix n_predict check (#4798)

ggml : use __builtin_amdgcn_sudot4 in __dp4a for gfx11 (#4787)

llama.swiftui : add visionOS target (#4805)

llama : print tensor meta for debugging

llama.swiftui : use llama.cpp as SPM package (#4804)

llama : remove redundant GQA check (#4796)

llama : remove unused vars (#4796)

CUDA: fixed redundant value dequantization (#4809)

llama-bench : add no-kv-offload parameter (#4812)

readme : add lgrammel/modelfusion JS/TS client for llama.cpp (#4814)

examples : add passkey test (#3856)

* examples : add passkey test

* passkey : better prints

* passkey : select pass key pos from CLI

* passkey : simplify n_past logic

* make : add passkey target

* passkey : add "self-extend"-like context extension (#4810)

* llama : "self-extend"-like context extension

* passkey : add comment

* passkey : add readme

main : add self-extend support (#4815)

* examples : add passkey test

* passkey : better prints

* passkey : select pass key pos from CLI

* passkey : simplify n_past logic

* llama : "self-extend"-like context extension

* passkey : add comment

* main : add Self-Extend support

* llama : add comment about llama_kv_cache_seq_div

llama.swiftui : update readme

swift : exclude ggml-metal.metal from the package (#4822)

SOTA 2-bit quants (#4773)

* iq2_xxs: basics

* iq2_xxs: scalar and AVX2 dot products

Needed to change Q8_K to have quants in the -127...127 range,
else the IQ2_XXS AVX implementation becomes very awkward.
The alternative would have been to use Q8_0 instead. Perhaps
I'll change later, for now this is what we have.

* iq2_xxs: ARM_NEON dot product

Somehow strangely slow (112 ms/token).

* iq2_xxs: WIP Metal

Dequantize works, something is still wrong with the
dot product.

* iq2_xxs: Metal dot product now works

We have
PP-512 = 475 t/s
TG-128 = 47.3 t/s

Not the greatest performance, but not complete garbage either.

* iq2_xxs: slighty faster dot product

TG-128 is now 48.4 t/s

* iq2_xxs: slighty faster dot product

TG-128 is now 50.9 t/s

* iq2_xxs: even faster Metal dot product

TG-128 is now 54.1 t/s.

Strangely enough, putting the signs lookup table
into shared memory has a bigger impact than the
grid values being in shared memory.

* iq2_xxs: dequantize CUDA kernel - fix conflict with master

* iq2_xxs: quantized CUDA dot product (MMVQ)

We get TG-128 = 153.1 t/s

* iq2_xxs: slightly faster CUDA dot product

TG-128 is now at 155.1 t/s.

* iq2_xxs: add to llama ftype enum

* iq2_xxs: fix MoE on Metal

* Fix missing MMQ ops when on hipBLAS

I had put the ggml_supports_mmq call at the wrong place.

* Fix bug in qequantize_row_iq2_xxs

The 0.25f factor was missing.
Great detective work by @ggerganov!

* Fixing tests

* PR suggestion

---------

Co-authored-by: Iwan Kawrakow <[email protected]>

readme : add link to SOTA models

common : fix the short form of `--grp-attn-w`, not `-gat` (#4825)

See https://github.com/ggerganov/llama.cpp/blob/master/common/common.cpp#L230C53-L230C57

CUDA: faster softmax via shared memory + fp16 math (#4742)

ggml : fix vld1q_s8_x4 32-bit compat (#4828)

* ggml : fix vld1q_s8_x4 32-bit compat

ggml-ci

* ggml : fix 32-bit ARM compat (cont)

ggml-ci

server : add api-key flag to documentation (#4832)

Document the api-key flag added to server in https://github.com/ggerganov/llama.cpp/pull/4441

server : update readme about token probs (#4777)

* updated server readme to reflect the gg/server-token-probs-4088 commit

added explanation for the API's completion result which now includes `completion_probabilities`. Also added a JSON schema that shows the type/structure of `completion_probabilities`.

* simplified the `completion_probabilities` JSON schema

It's now easier to understand what the structure of `completion_probabilities` looks like.

* minor : fix trailing whitespace

---------

Co-authored-by: Georgi Gerganov <[email protected]>

scripts : script to get Paul Graham essays in txt format (#4838)

readme : add 3rd party collama reference to UI list (#4840)

Add a VSCode extension for llama.cpp reference to UI list

scripts : improve get-pg.sh (#4838)

metal : improve dequantize precision to match CPU (#4836)

ggml-ci

llava-cli : don't crash if --image flag is invalid (#4835)

This change fixes an issue where supplying `--image missing-file` would
result in a segfault due to a null pointer being dereferenced. This can
result in distracting info being printed if robust crash analysis tools
are being used.

convert.py : fix vanilla LLaMA model conversion (#4818)

* Update Imports and Add Notes for Future Reference

- Updated import statements in `convert.py`.
- Added import for `AutoTokenizer` from `transformers` module.
- Added conditional import for `gguf` from the local directory.
- Added comments and notes for future reference.

Additional Notes:

- Noted removal of a redundant `TypeAlias` import.
- Noted the removal of a `gguf` debug statement.
- Commented on the presence of `ARCH` and `NDArray` definitions.
- Commented on cleaning up and refactoring data type definitions.

* Refine Model Hyperparameters and Params Class

- Updated type annotations to use `Optional` for clarity.
- Improved method names and attribute consistency.
- Removed unnecessary variables for better code readability.

Additional Notes:

- Highlighted the use of `Optional` for clearer intent.
- Ensured backward and forward compatibility.

* Restore BpeVocab and SentencePieceVocab classes

- Restored the BpeVocab class for handling BPE tokenization.
- Restored the SentencePieceVocab class for SentencePiece tokenization.

These classes are essential for maintaining the original behavior of the codebase.

* refactor: Standardize vocabulary handling with HfVocab

- Replaced VocabLoader with HfVocab, aligning vocabulary handling across classes.
- Updated initialization of HfVocab with local_files_only=True for AutoTokenizer.
- Introduced optional parameter fname_added_tokens for flexible added token management.
- Streamlined added token handling for clarity and conciseness.
- Maintained special tokens and IDs, enhancing token management.
- Simplified token processing methods for improved readability.
- Added a placeholder for score computation with a default value of -1000.0.
- Optimized newline token check for efficiency.
- Updated __repr__ function for clarity in representation.
- Adjusted type alias Vocab to include BpeVocab, SentencePieceVocab, and HfVocab.
- Removed redundant code related to special token handling, reverse vocabulary mapping, and vocabulary file detection.

This refactoring promotes a standardized and modular approach to vocabulary management, facilitating future integration with a VocabFactory and improving code maintainability and scalability.

* refactor: Enhance readability, functionality, and code quality

- Improved code formatting and readability for better maintainability.
- Refactored LazyUnpickler's CLASSES dictionary for clarity.
- Added print statements and warnings in check_vocab_size for user feedback.
- Removed find_vocab_file_path, as it's superseded by VocabFactory.
- Preparatory changes for upcoming classes: OutputFile and VocabFactory.
- Overall focus on code quality, error handling, and consistency.

These changes reflect a continuous effort to refine the codebase, ensuring it meets best practices and prepares for future enhancements, such as the VocabFactory.

* refactor: Update OutputFile class for enhanced model vocabulary management

- Restructured the constructor for improved readability.
- Updated `add_meta_arch` method for flexible model name determination.
- Introduced `handle_tokenizer_model` for mapping vocab types to supported tokenizer models.
- Streamlined vocabulary extraction with `extract_vocabulary_from_model`.
- Simplified vocabulary metadata addition using `add_meta_vocab`.
- Refactored `add_tensor_info` for clarity and consistency.
- Improved error handling for better user feedback.

These changes signify the development of a versatile and comprehensive `OutputFile` class, enabling efficient management of model conversion output, metadata, vocabulary, and tensor information.

* feat: Introduce VocabFactory for flexible vocabulary management in model conversion

- The VocabFactory class is added to facilitate modular vocabulary handling.
- The constructor initializes a directory path and detects vocabulary-related files.
- The _select_file method provides file paths based on vocabulary type (e.g., BPE, SentencePiece).
- _create_special_vocab generates special vocabularies, accommodating different types.
- The load_vocab method loads vocabularies, handling BPE, SentencePiece, and Hugging Face Fast Tokenizer.
- Error handling and logging enhance debugging and user feedback.
- The modular and flexible design simplifies vocabulary management and supports future extensions.

The VocabFactory class enhances code modularity and maintainability, allowing versatile vocabulary handling in the model conversion process.

* refactor: Improve code organization, argument parsing, and user interface

- Renamed 'default_outfile' to 'default_output_file' for clarity.
- Refactored argument parser setup into 'get_argument_parser' function.
- Introduced descriptive comments for each argument in the parser.
- Added '--vocab-type' argument with choices ["spm", "bpe", "hfft"] for vocabulary processing.
- Improved flag naming consistency: '--outfile' to '--out-file' and '--bigendian' to '--big-endian'.
- Enhanced error handling to prevent overwriting input data in 'default_output_file'.
- Made 'argv' in 'main' an optional parameter for flexibility.
- Introduced dynamic import for 'awq.apply_awq' based on 'args.awq_path' for conditional dependency.

These changes enhance code clarity, organization, and the user interface of the script, aligning it with Python best practices and improving maintainability.

* refactor: Further refine functionality, improve user interaction, and streamline vocabulary handling

- Renamed command-line arguments for clarity and consistency.
- Improved path resolution and import adjustments for robustness.
- Thoughtfully handled 'awq-path' and conditional logic for the weighted model.
- Enhanced model and vocabulary loading with the 'VocabFactory' class for structured and adaptable loading.
- Strengthened error handling and user feedback for a more user-friendly experience.
- Structured output file handling with clear conditions and defaults.
- Streamlined and organized the 'main' function for better logic flow.
- Passed 'sys.argv[1:]' to 'main' for adaptability and testability.

These changes solidify the script's functionality, making it more robust, user-friendly, and adaptable. The use of the 'VocabFactory' class is a notable enhancement in efficient vocabulary handling, reflecting a thoughtful and iterative approach to script development.

* chore: Apply ruff formatting to convert.py

Signed-off-by: teleprint-me <[email protected]>

* Revert to commit 0614c33

* chore: Apply flake8 formatting rules

Signed-off-by: teleprint-me <[email protected]>

* refactor: Revise `check_vocab_size` for Enhanced Clarity and Correctness

- Resolved an unreachable branch issue by reorganizing the conditional structure.
- Moved the special case check for `params.n_vocab == -1` to the top for immediate assertion.
- Flattened the conditional logic for improved clarity and predictability of the function's behavior.

These changes enhance the readability and functional correctness of the `check_vocab_size` function without altering its intended functionality.

* py : fix outfile and outtype

* py : suggest hint for missing vocab size

---------

Signed-off-by: teleprint-me <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>

Python script to compare commits with llama-bench (#4844)

clip : support more quantization types (#4846)

Uses ggml functions instead of hardcoded names and adds support to quantize into the modern Q-K variants.
This is just the bare minimum to get k-types working - a more refined choice of types would be needed to get best quality on low quantizations.

I ran a few tests, it doesn't break anything I could notice and a Q6_K ViT works almost as well as Q8_0 but 3 times the inference speed.

llama : recognize 1B phi models (#4847)

This update categorizes models with 24 layers as MODEL_1B, ensuring compatibility with different Phi model variants without impacting existing Phi-2 model functionality.

llama : add additional suffixes for model params (#4834)

* llm_load_print_meta: Add additional suffixs for model params

* Update llama.cpp model param log

remove unneeded comments and convert from > to >=

server : add a `/health` endpoint (#4860)

* added /health endpoint to the server

* added comments on the additional /health endpoint

* Better handling of server state

When the model is being loaded, the server state is `LOADING_MODEL`. If model-loading fails, the server state becomes `ERROR`, otherwise it becomes `READY`. The `/health` endpoint provides more granular messages now according to the server_state value.

* initialized server_state

* fixed a typo

* starting http server before initializing the model

* Update server.cpp

* Update server.cpp

* fixes

* fixes

* fixes

* made ServerState atomic and turned two-line spaces into one-line

server : fix build + rename enums (#4870)

server : update readme to document the new `/health` endpoint (#4866)

* added /health endpoint to the server

* added comments on the additional /health endpoint

* Better handling of server state

When the model is being loaded, the server state is `LOADING_MODEL`. If model-loading fails, the server state becomes `ERROR`, otherwise it becomes `READY`. The `/health` endpoint provides more granular messages now according to the server_state value.

* initialized server_state

* fixed a typo

* starting http server before initializing the model

* Update server.cpp

* Update server.cpp

* fixes

* fixes

* fixes

* made ServerState atomic and turned two-line spaces into one-line

* updated `server` readme to document the `/health` endpoint too

fix : cuda order of synchronization when setting a buffer (ggml/679)

* fix : cuda order of synchronization when setting a buffer

* also sync before memcpy

---------

Co-authored-by: slaren <[email protected]>

Fix execlp call (ggml/689)

NULL can be an integer constant expression with the value zero, in this case the behavior would be undefined because of an incorrect type being passed to the variable arguments.

ggml : change GGML_MAX_NAME at compile time (ggml/682)

* change GGML_MAX_NAME to 128

* allow controlling the value of GGML_MAX_NAME through external macro definitions

metal : wrap each operation in debug group (ggml/690)

ggml : remove ggml_cpy_inplace and ggml_cont_inplace (ggml/693)

metal : fix deprecation warning (ggml/690)

sync : ggml

metal : put encoder debug group behind a define (#4873)

server : fix typo in model name (#4876)

main : print total token count and tokens consumed so far (#4874)

* Token count changes

* Add show token count

* Updating before PR

* Two requested changes

* Move param def posn

ci: nix-flake-update: new token with pr permissions (#4879)

* ci: nix-flake-update: new token with pr permissions

---------

Co-authored-by: Georgi Gerganov <[email protected]>

server : add `LOG_INFO` when model is successfully loaded (#4881)

* added /health endpoint to the server

* added comments on the additional /health endpoint

* Better handling of server state

When the model is being loaded, the server state is `LOADING_MODEL`. If model-loading fails, the server state becomes `ERROR`, otherwise it becomes `READY`. The `/health` endpoint provides more granular messages now according to the server_state value.

* initialized server_state

* fixed a typo

* starting http server before initializing the model

* Update server.cpp

* Update server.cpp

* fixes

* fixes

* fixes

* made ServerState atomic and turned two-line spaces into one-line

* updated `server` readme to document the `/health` endpoint too

* used LOG_INFO after successful model loading

server : support for multiple api keys (#4864)

* server: added support for multiple api keys, added loading api keys from file

* minor: fix whitespace

* added file error handling to --api-key-file, changed code to better
reflect current style

* server: update README.md for --api-key-file

---------

Co-authored-by: Michael Coppola <[email protected]>

server : implement credentialed CORS (#4514)

* Implement credentialed CORS according to MDN

* Fix syntax error

* Move validate_api_key up so it is defined before its first usage

swift : pin ggml commit + remove ggml.h from spm-headers (#4878)

ggml-ci

ggml : SOTA 2-bit quants (add IQ2_XS) (#4856)

* iq2_xs: basics

* iq2_xs: this should have been in the basics

* iq2_xs: CUDA and scalar CPU works

* iq2_xs: WIP Metal

* iq2_xs: Metal now works

* iq2_xs: working, but dog slow, ARM_NEON dot product

* iq2_xs: better ARM_NEON dot product

We are now at 19.5 t/s for TG-128 and 61 t/s for PP-512 when
running on the CPU.

* iq2_xs: AVX2 dot product - 19.5 t/s

* iq2_xs: faster AVX2 dit product

21.4 t/s for TG-128, 59.2 t/s for PP-512.
The latter is 2x compared to the previous version.

* iq2_xs: had forgotten to delete iq2-data.h

* Add llama enum for IQ2_XS

---------

Co-authored-by: Iwan Kawrakow <[email protected]>

llama : restore intended k-quants mixes for MoE models (#4872)

* Restore intended k-quants quantization mixes for MoE models

* Update Q2_K_S values in the quantize tool

Still using LLaMA-v1 PPL values in the quant description
today does not make much sense. But let's leave this update
for another PR.

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>

swift : track ggml release branch (#4867)

main : disable token count by default (#4874)

main : better name for variable n_print (#4874)

server : fix infill when prompt is empty (#4833)

Importance Matrix calculation (#4861)

* imatrix: 1st version

* imatrix: WIP

* Cleanup

* Update examples/imatrix/imatrix.cpp

Co-authored-by: Georgi Gerganov <[email protected]>

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>

llama : fix llm_build_k_shift to use correct n_rot (#4889)

* llama : fix llm_build_k_shift to use correct n_rot

ggml-ci

* llama : always use hparams.n_rot for ggml_rope_custom

ggml-ci

* convert : fix persimmon conversion to write correct n_rot

py : fix lint (#4889)

common : streamline the formatting of help (#4890)

* common : streamline the formatting of help

- Separate alternative parameters by a comma

- Do not indent `--version` differently

* Update common/common.cpp

---------

Co-authored-by: Georgi Gerganov <[email protected]>

llama : fix typo "imp_embd" -> "inp_embd"

CUDA: fix softmax compile for old CUDA versions (#4862)

gitignore : imatrix

llama.swiftui : update models layout (#4826)

* Updated Models Layout

- Added a models drawer
- Added downloading directly from Hugging Face
- Load custom models from local folder
- Delete models by swiping left

* trimmed trailing white space

* Updated Models Layout

export-lora : use LLAMA_FILE_MAGIC_GGLA (#4894)

This commit replaces the magic number used in export-lora.cpp with
the one defined in llama.h, which is indirectly included via common.h.

Signed-off-by: Daniel Bevenius <[email protected]>

llama : remove redundant assert for StableLM (#4901)

llama : ggml-backend integration (#4766)

* llama : ggml-backend integration

* ggml-backend : add names to buffers

* fix unmap after loading

* batched-bench : add tensor_split param

* llama : check for null tensor_split

* ggml-backend : increase GGML_MAX_BACKENDS

* improve graph splitting, partial fix for --no-kv-offload

* cuda : add ggml-backend split buffer support

* cuda : do not create buffer types for devices that don't exist (fixes usage without CUDA devices available)

* ggml : fix null backend dereference (#4807)

* ggml : fix null backend dereference

* ggml : also check ggml_backend_is_cpu

* test-backend-ops : check buffer allocation failures

* llama : add cparam (split_mode) and command line argument (--split-mode, -sm) to configure the split mode (none, layer or row)

* ggml : fix mul_mat_id work size

* llama : rewrite session kv load/set without graphs

* minor

* llama : only initialize used backends, free backends on context free

* llama : abort ctx if cuda backend init fails

* llama : rewrite lora with ggml-backend and compute on CPU

ggml-ci

* llama : only map to a backend buffer the region of the file mapping containing the tensors used in the buffer

* opencl : add ggml-backend buffer type

* cuda : only use batched_cublas with batched mat muls (fixes fp16 tg perf)

* llama : on Metal, by default offload the full model

ggml-ci

* metal : page align the data ptr (#4854)

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <[email protected]>

* cuda : fix split buffer free

* address review comments

* llama-bench : add split-mode parameter

* fix whitespace

* opencl : fix double initialization

* server : add --split-mode parameter

* use async copy and compute to improve multi-gpu performance

ggml-ci

* use async memcpys to copy the graph outputs to the CPU

* fix opencl

* use a host buffer for the cpu compute buffer for faster copies to the gpu

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: Johannes Gäßler <[email protected]>

CUDA: faster q8_0 -> f16 dequantization (#4895)

examples : add pydantic models to GBNF grammar generator (#4883)

* Create pydantic-models-to-grammar.py

* Added some comments for usage

* Refactored Grammar Generator

Added example and usage instruction.

* Update pydantic_models_to_grammar.py

* Update pydantic-models-to-grammar-examples.py

* Renamed module and imported it.

* Update pydantic-models-to-grammar.py

* Renamed file and fixed grammar generator issue.

backend_sched : fix assignments

ggml-ci

ggml : fix 32-bit ARM compat for IQ2_XS (whisper/1758)

* ggml : fix 32-bit ARM compat

* ggml : fix fix

* ggml : fix fix fix

sync : ggml

convert : update phi-2 to latest HF repo (#4903)

* convert : update phi-2 to latest HF repo

ggml-ci

* py : try to fix flake stuff

server : fix crash with multimodal models without BOS token (#4904)

server : fix deadlock that occurs in multi-prompt scenarios (#4905)

* * fix deadlock

* * dont ruint all whitespace

compare-llama-bench: tweak output format (#4910)

metal : refactor kernel loading code (#4794)

* metal : detect more GPU families

* metal : refactor kernel loading

* metal : set kernel family requirements

* metal : fix kernel init + fix compile options

* metal : take into account simdgroup reduction support

* metal : print only skipped kernels

* metal : fix check for simdgroup reduction support

* metal : check for Metal 3

* metal : free allocations

* metal : normalize encoder:setComputePipelineStatus calls

ggml-ci

* metal : fix Metal3 family check

ggml-ci

* metal : check for simdgroup matrix mul. feature

ggml-ci

gguf : fix potential infinite for-loop (#4600)

Co-authored-by: Bernhard Gstrein <[email protected]>

main : add parameter --no-display-prompt (#4541)

* add the parameter : --no-display-prompt , combine with --log-disable it will display only the generated tokens

* remove empty line

---------

Co-authored-by: Georgi Gerganov <[email protected]>

workflows: unbreak nix-build-aarch64, and split it out (#4915)

The fix should be just the `sudo apt-get update`

llama : minimize size used for state save/load (#4820)

* examples : save-load-state: save only required state

* llama : only reserve n_vocab * n_batch at most for logits

llama_decode asserts that only n_batch tokens are passed each call, and
n_ctx is expected to be bigger than n_batch.

* llama : always reserve n_vocab * n_batch for logits

llama_context de-serialization breaks if the contexts have differing
capacity for logits and llama_decode will at maximum resize to
n_vocab * n_batch.

* llama : only save and restore used logits

for batch sizes of 512 this reduces save state in the best case by
around 62 MB, which can be a lot if planning to save on each message
to allow regenerating messages.

* llama : use ostringstream and istringstream for save and load

* llama : serialize rng into minimum amount of space required

* llama : break session version due to serialization changes

metal : disable log for loaded kernels (#4794)

llama : fix detokenization of non-special added-tokens (#4916)

Co-authored-by: goerch <[email protected]>

server : fix prompt caching with system prompt (#4914)

metal : remove old API (#4919)

ggml-ci

ggml: cache sin/cos for RoPE (#4908)

sync : ggml

Make Q3_K_S be the same as olf Q3_K_L for Mixtral-8x7B (#4906)

Co-authored-by: Iwan Kawrakow <[email protected]>

2-bit quantizations (#4897)

* imatrix: load

* imatrix: WIP

* imatrix: Add Q2_K quantization

* imatrix: also guard against Q2_K_S quantization without importance matrix

* imatrix: guard even more against low-bit quantization misuse

---------

Co-authored-by: Iwan Kawrakow <[email protected]>

llama : support WinXP build with MinGW 8.1.0 (#3419)

metal : correctly set SIMD support flags on iOS (#4923)

* Correctly set support_simdgroup_reduction and support_simdgroup_mm on iPhone/iPad

* log a little bit more info on iOS

Fix ffn_down quantization mix for MoE models (#4927)

* Fix ffn_down quantization mix for MoE models

In #4872 I did not consider the part where every third
tensor is quantized with more bits. Fir MoE this leads to tensors
of the same layer being quantized with different number of bits,
which is not considered as a possibility in the inference implementation
(it is assumed all experts use the same quantization).

* Fix the fix

* Review suggestion

---------

Co-authored-by: Iwan Kawrakow <[email protected]>

llama : use LLAMA_LOG_ macros for logging

scripts : sync-ggml-am.sh option to skip commits

llama : check LLAMA_TRACE env for extra logging (#4929)

* llama : minor fix indent

* llama : check LLAMA_TRACE env for extra logging

ggml-ci

Add ability to use importance matrix for all k-quants (#4930)

Co-authored-by: Iwan Kawrakow <[email protected]>

llama : fix missing quotes (#4937)

CUDA: faster dequantize kernels for Q4_0 and Q4_1 (#4938)

Co-authored-by: Iwan Kawrakow <[email protected]>

llama : check for 256 divisibility for IQ2_XS, IQ2_XXS (#4950)

Co-authored-by: Iwan Kawrakow <[email protected]>

cuda : fix dequantize kernel names (#4938)

awq-py : fix typo in awq-py/README.md (#4947)

llama : apply classifier-free guidance to logits directly (#4951)

pass cpu-architecture arguments only to host code (C;C++) (#4943)

speculative : threading options (#4959)

* speculative: expose draft threading

* fix usage format

* accept -td and -tbd args

* speculative: revert default behavior when -td is unspecified

* fix trailing whitespace

finetune : use LLAMA_FILE_MAGIC_GGLA (#4961)

This commit replaces the magic number LLAMA_FILE_MAGIC_LORA used in
finetune.cpp with LLAMA_FILE_MAGIC_GGLA defined in llama.h.

Signed-off-by: Daniel Bevenius <[email protected]>

ggml : introduce GGML_CALL function annotation (#4850)

This change makes it possible to build ggml-cuda.cu and ggml-metal.m as
independent dynamic shared objects, that may be conditionally linked at
runtime in a multiplatform binary. It introduces a GGML_CALL annotation
that documents which functions have a cyclic call relationship, between
the application code and GPU modules.

This change does nothing, unless the build defines -DGGML_MULTIPLATFORM
which causes back-references and function pointers to conform to MS ABI
which is supported by NVCC, ROCm, XCode, GCC and Clang across platforms

examples : fix and improv docs for the grammar generator (#4909)

* Create pydantic-models-to-grammar.py

* Added some comments for usage

* Refactored Grammar Generator

Added example and usage instruction.

* Update pydantic_models_to_grammar.py

* Update pydantic-models-to-grammar-examples.py

* Renamed module and imported it.

* Update pydantic-models-to-grammar.py

* Renamed file and fixed grammar generator issue.

* Fixed some issues and bugs of the grammar generator. Imporved Documentation

* Update pydantic_models_to_grammar.py

metal : log `recommendedMaxWorkingSetSize` on iOS 16+ (#4936)

* metal: Log `recommendedMaxWorkingSetSize` on iOS 16+

* Only log on iOS and macOS, ignoring tvOS and other platforms

* Check for Xcode version before using recommendedMaxWorkingSetSize

---------

Co-authored-by: Georgi Gerganov <[email protected]>

metal : replace loop of dispatch_async with dispatch_apply (#4934)

* Replace loop of dispatch_async with dispatch_apply

* Update ggml-metal.m

---------

Co-authored-by: Georgi Gerganov <[email protected]>

android : introduce starter project example (#4926)

* Introduce starter project for Android

Based on examples/llama.swiftui.

* Add github workflow

* Set NDK version

* Only build arm64-v8a in CI

* Sync bench code

* Rename CI prop to skip-armeabi-v7a

* Remove unused tests

metal : localized logic in `ggml_metal_graph_compute` (#4924)

* Metal: Localized logic in `ggml_metal_graph_compute`, minor performance improvement

* Whitespace

* Collecting command buffer c…
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin, *.safetensors)")
parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you help me understand why you removed *.safetensors?

Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The model files are automatically discovered. The model file names are arbitrary, some are split and others are not. The model files are dynamically "lazy" loaded to save memory as a result.

https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L1348

The vocab can be set manually though.

https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L1461

Not sure if that answers your question.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.