Adding a model requires a few steps:
- Convert the model to GGUF
- Define the model architecture in llama.cpp
- Build the GGML graph implementation
- Optional: Add multimodal encoder implementation
After following these steps, you can open a PR.
Also, it is important to check that the examples and main ggml backends (CUDA, METAL, CPU) are working with the new architecture, especially:
The first step, converting the model to GGUF, is done in Python with a convert script that uses the gguf library.
Depending on the model architecture, you can use either convert_hf_to_gguf.py or examples/convert_legacy_llama.py (for llama/llama2 models in .pth format).
The convert script reads the model configuration, tokenizer, and tensor names and data, and converts them to GGUF metadata and tensors.
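As a rough illustration of the metadata half of that conversion, the sketch below maps a few common Hugging Face `config.json` fields to GGUF metadata keys of the form `{arch}.<key>` defined by the GGUF specification. The helper `hf_config_to_gguf_kv` and its mapping table are hypothetical; the real convert script writes such values through `gguf.GGUFWriter` methods in `set_gguf_parameters`, and many models need extra or different fields.

```python
# Hypothetical mapping: Hugging Face config.json field -> GGUF key suffix.
HF_TO_GGUF_KEYS = {
    "max_position_embeddings": "context_length",
    "hidden_size": "embedding_length",
    "num_hidden_layers": "block_count",
    "intermediate_size": "feed_forward_length",
    "num_attention_heads": "attention.head_count",
    "num_key_value_heads": "attention.head_count_kv",
}


def hf_config_to_gguf_kv(arch, config):
    """Return GGUF metadata key/value pairs for the fields present in config."""
    kv = {"general.architecture": arch}
    for hf_key, gguf_key in HF_TO_GGUF_KEYS.items():
        if hf_key in config:  # fields missing from the config are skipped
            kv[f"{arch}.{gguf_key}"] = config[hf_key]
    return kv


cfg = {"hidden_size": 4096, "num_hidden_layers": 32, "num_attention_heads": 32}
print(hf_config_to_gguf_kv("llama", cfg))
# {'general.architecture': 'llama', 'llama.embedding_length': 4096, 'llama.block_count': 32, 'llama.attention.head_count': 32}
```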
The required steps to implement for an HF model are:
- Register the model with the `ModelBase.register` decorator on a new `TextModel` or `MmprojModel` subclass, for example:
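```python
# The registered name must match the "architectures" entry in the HF config.json
@ModelBase.register("MyModelForCausalLM")
class MyModel(TextModel):
    model_arch = gguf.MODEL_ARCH.MYMODEL  # the new MODEL_ARCH enum entry
```
or
```python
@ModelBase.register("MyModelForConditionalGeneration")
class MyModel(MmprojModel):
    model_arch = gguf.MODEL_ARCH.MYMODEL
```
- Define the layout of the GGUF tensors in constants.py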
Add an enum entry in `MODEL_ARCH`, the model's human-friendly name in `MODEL_ARCH_NAMES`, and the GGUF tensor names in `MODEL_TENSORS`.
Example for the Falcon model:
```python
# Entry in the MODEL_TENSORS dict: the tensors a Falcon GGUF file may contain
MODEL_ARCH.FALCON: [
    MODEL_TENSOR.TOKEN_EMBD,
    MODEL_TENSOR.OUTPUT_NORM,
    MODEL_TENSOR.OUTPUT,
    MODEL_TENSOR.ATTN_NORM,
    MODEL_TENSOR.ATTN_NORM_2,
    MODEL_TENSOR.ATTN_QKV,
    MODEL_TENSOR.ATTN_OUT,
    MODEL_TENSOR.FFN_DOWN,
    MODEL_TENSOR.FFN_UP,
]
```
- Map the original tensor names to the standardized equivalent in GGUF
As a general rule, before adding a new tensor name to GGUF, make sure an equivalent name does not already exist.
Once you have found the GGUF tensor name equivalent, add it to the tensor_mapping.py file.
If the tensor name is part of a repeated layer/block, the keyword `bid` stands in for the block index.
Example for the normalization tensor in attention layers:
```python
block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
    # Attention norm
    MODEL_TENSOR.ATTN_NORM: (
        "gpt_neox.layers.{bid}.input_layernorm",  # gptneox
        "transformer.h.{bid}.ln_1",  # gpt2 gpt-j refact qwen
        "transformer.blocks.{bid}.norm_1",  # mpt
        ...
    )
}
```
`transformer.blocks.{bid}.norm_1` will be mapped to `blk.{bid}.attn_norm` in GGUF.
Depending on the model configuration, tokenizer, code and tensors layout, you will have to override:
- `TextModel#set_gguf_parameters`
- `MmprojModel#set_gguf_parameters`
- `ModelBase#set_vocab`
- `ModelBase#modify_tensors`
NOTE: Tensor names must end with the `.weight` or `.bias` suffix. This is the convention, and several tools, such as quantize, expect it when processing the weights.
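To make the `{bid}` substitution and the suffix convention concrete, here is a self-contained sketch that expands block templates and maps an original tensor name to its GGUF name while keeping the `.weight`/`.bias` suffix. `BLOCK_MAPPINGS`, `build_name_map` and `map_tensor_name` are hypothetical simplifications; the real lookup is done by the `TensorNameMap` class in `tensor_mapping.py`.

```python
# Hypothetical, trimmed-down mapping: GGUF name template -> original name templates.
BLOCK_MAPPINGS = {
    "blk.{bid}.attn_norm": (
        "gpt_neox.layers.{bid}.input_layernorm",  # gptneox
        "transformer.h.{bid}.ln_1",               # gpt2 gpt-j refact qwen
        "transformer.blocks.{bid}.norm_1",        # mpt
    ),
}
SUFFIXES = (".weight", ".bias")


def build_name_map(n_blocks):
    """Expand every {bid} template for blocks 0 .. n_blocks - 1."""
    name_map = {}
    for gguf_tmpl, hf_tmpls in BLOCK_MAPPINGS.items():
        for bid in range(n_blocks):
            for hf_tmpl in hf_tmpls:
                name_map[hf_tmpl.format(bid=bid)] = gguf_tmpl.format(bid=bid)
    return name_map


def map_tensor_name(hf_name, name_map):
    """Map an original tensor name to GGUF, keeping the .weight/.bias suffix."""
    for suffix in SUFFIXES:
        if hf_name.endswith(suffix):
            base = hf_name[: -len(suffix)]
            if base in name_map:
                return name_map[base] + suffix
    # No known base name, or no .weight/.bias suffix: refuse to guess
    raise ValueError(f"cannot map tensor {hf_name!r}")


name_map = build_name_map(n_blocks=4)
print(map_tensor_name("transformer.blocks.3.norm_1.weight", name_map))  # blk.3.attn_norm.weight
```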
The model parameters and tensor layout must then be defined in the llama.cpp source files:
- Define a new `llm_arch` enum value in `src/llama-arch.h`.
- In `src/llama-arch.cpp`:
  - Add the architecture name to the `LLM_ARCH_NAMES` map.
  - Add the list of model tensors to `llm_get_tensor_names` (you may also need to update `LLM_TENSOR_NAMES`).
- Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`.
- If the model has a RoPE operation, add a case for the architecture in the `llama_model_rope_type` function in `src/llama-model.cpp`.
NOTE: The dimensions in ggml are typically in the reverse order of the PyTorch dimensions.
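A small sketch of that reversal, assuming only that ggml's `ne[0]` is the innermost (contiguous) dimension and that ggml tensors have at most 4 dimensions. `torch_shape_to_ggml_ne` is a hypothetical helper, not part of the gguf library.

```python
def torch_shape_to_ggml_ne(shape):
    """Reverse a PyTorch shape into ggml's ne[0..3] order, padding with 1s.

    ggml's ne[0] is the innermost (contiguous) dimension, which PyTorch lists last.
    """
    if len(shape) > 4:
        raise ValueError("ggml tensors have at most 4 dimensions")
    ne = list(reversed(shape))
    return ne + [1] * (4 - len(ne))


# nn.Linear(in_features=4096, out_features=11008).weight has PyTorch shape (11008, 4096)
print(torch_shape_to_ggml_ne((11008, 4096)))  # [4096, 11008, 1, 1]
```

This is also why `ggml_mul_mat(ctx, w, x)` with a weight `w` of `ne = [4096, 11008]` and activations `x` of `ne = [4096, n_tokens]` produces a result with `ne = [11008, n_tokens]`: both operands share `ne[0]`.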
Building the GGML graph is the funniest part: you have to provide the inference graph implementation of the new model architecture in src/llama-model.cpp.
Create a new struct that inherits from llm_graph_context and implement the graph-building logic in its constructor.
Have a look at existing implementations like llm_build_llama, llm_build_dbrx or llm_build_bert.
Then, in the llama_model::build_graph method, add a case for your architecture to instantiate your new graph-building struct.
Some ggml backends do not support all operations. Backend implementations can be added in a separate PR.
Note: to debug the inference graph, you can use llama-eval-callback.
If the new model supports multimodal inputs, you will need to add a new encoder definition in libmtmd. You can find more information about llama.cpp's multimodal support in the docs and in the tools/mtmd source directory.
- In the conversion script, make sure you add a subclass that extends `MmprojModel` or another class that inherits from the same base class.
- Add the encoder definition in `clip.cpp`.
- Implement the preprocessor in `mtmd.cpp`. In most cases, you can reuse an existing preprocessor.
- Implement the encoder GGML graph, either in a dedicated file if the model is truly different from existing ones, or by reusing an existing implementation (for example: siglip, pixtral, or qwen) and adding a model-specific projector.
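For orientation, the sketch below shows the kind of arithmetic a CLIP/SigLIP-style preprocessor typically performs: per-channel normalization with a mean and standard deviation, and counting how many patches (and therefore image embeddings) an image produces. The helper names, the 0.5 mean/std defaults, the 14-pixel patch size and the optional `merge` factor are assumptions for illustration; each model defines its own values.

```python
def normalize_pixel(value_u8, mean=0.5, std=0.5):
    """Scale an 8-bit channel value to [0, 1], then standardize it (assumed mean/std)."""
    return (value_u8 / 255.0 - mean) / std


def count_image_tokens(width, height, patch_size=14, merge=1):
    """Patches for an image already resized to multiples of patch_size,
    optionally merged merge x merge by the projector (assumption)."""
    if width % patch_size or height % patch_size:
        raise ValueError("resize the image to a multiple of patch_size first")
    grid_w, grid_h = width // patch_size, height // patch_size
    if grid_w % merge or grid_h % merge:
        raise ValueError("patch grid must be divisible by the merge factor")
    return (grid_w // merge) * (grid_h // merge)


print(normalize_pixel(255), normalize_pixel(0))  # 1.0 -1.0
print(count_image_tokens(224, 224))              # 256
print(count_image_tokens(224, 224, merge=2))     # 64
```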
Note:
- Many multimodal encoders are based on models that are already supported. Make sure to read the existing encoder definitions in `tools/mtmd/models` before adding a new one. In `libmtmd`, it is generally better to extend an existing model than to duplicate code.
- To debug the multimodal preprocessor and encoder, you can use llama-mtmd-debug.
- Adding a model-specific API or CLI is an anti-pattern in `libmtmd`. The goal of `libmtmd` is to provide an easy-to-use, model-agnostic library for multimodal pipelines.
- In most cases, `llama-mtmd-cli` should not be modified. If a model requires a specific prompt, either let the user provide it or bake it into the Jinja chat template.
When porting RoPE, keep in mind that PyTorch implementations usually compute the freq_cis/sin/cos components explicitly. In llama.cpp, however, most RoPE operations can be handled via `ggml_rope_ext`, which does not require a sin/cos matrix. This saves memory and allows the GGML RoPE kernel to be fused with other ops.
Because `ggml_rope_ext` only provides a subset of the RoPE variants that models use, converting models from PyTorch to llama.cpp may require some creative adaptations.
For more information about ggml_rope_ext, please refer to the in-code documentation in ggml.h.
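To see what `ggml_rope_ext` computes without a precomputed sin/cos matrix, here is a pure-Python reference for a single head vector. It is a sketch under simplifying assumptions: no YaRN extrapolation (`ext_factor = 0`), `attn_factor = 1`, and all dimensions rotated (`n_dims` equal to the vector length). `rope_reference` is a hypothetical name; `neox=False` corresponds to `GGML_ROPE_TYPE_NORMAL` (adjacent pairs) and `neox=True` to `GGML_ROPE_TYPE_NEOX` (first half paired with second half).

```python
import math


def rope_reference(x, pos, freq_base=10000.0, freq_scale=1.0,
                   freq_factors=None, neox=False):
    """Apply basic RoPE to one head vector x at position pos (sketch).

    theta_i = pos * freq_scale * freq_base ** (-2 * i / n_dims) / freq_factors[i]
    NORMAL ordering rotates (x[2i], x[2i+1]); NEOX rotates (x[i], x[i + n_dims/2]).
    """
    n_dims = len(x)
    if n_dims % 2:
        raise ValueError("RoPE needs an even number of dimensions")
    half = n_dims // 2
    out = list(x)
    for i in range(half):
        theta = pos * freq_scale * freq_base ** (-2.0 * i / n_dims)
        if freq_factors is not None:
            theta /= freq_factors[i]  # per-pair divisor, like the rope_freqs tensor
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        a, b = (i, i + half) if neox else (2 * i, 2 * i + 1)
        out[a] = x[a] * cos_t - x[b] * sin_t
        out[b] = x[a] * sin_t + x[b] * cos_t
    return out


# Same input, different pair ordering:
print(rope_reference([1.0, 0.0, 0.0, 0.0], pos=1))             # rotates x[0] with x[1]
print(rope_reference([1.0, 0.0, 0.0, 0.0], pos=1, neox=True))  # rotates x[0] with x[2]
```

A reference like this can be handy for comparing individual values against a llama-eval-callback dump when a port produces wrong outputs.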
Examples:
- `libmtmd` implements 2D RoPE with `GGML_ROPE_TYPE_NORMAL` ordering by splitting the input tensor in half, applying `ggml_rope_ext` separately to each half, then joining them back together using `ggml_concat`.
- The Kimi-K2.5 vision encoder uses vision RoPE with interleaved frequencies. The weights must be permuted during conversion in order to reuse the `build_rope_2d()` function.
- Gemma 4 uses "proportional" RoPE. We employ a trick where `rope_freqs` is set to a very large value in the last dimensions to prevent those dimensions from being rotated. See the `Gemma4Model` class in `convert_hf_to_gguf.py`.
- Some models require scaling the input position. For example, `[0, 1, 2, ...]` becomes `[0, 0.5, 1, ...]`. In this case, you can provide the scaling via `freq_scale = 0.5f`.
- Some models use learned RoPE frequencies instead of relying on `powf(freq_base, -2.0 * i / n_dims)`. In this case, you can provide the learned frequencies via the `rope_freqs` tensor (corresponding to the `c` argument in `ggml_rope_ext`), then set `freq_base = 1.0f`. An important note is that `rope_freqs` in GGML is the inverse (`theta = pos[i] / rope_freqs`), so you may need to invert `rope_freqs` during conversion.
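The last two examples can be checked numerically. The sketch below, assuming the same basic (non-YaRN) angle formula, shows that inverting learned inverse frequencies into `rope_freqs` and setting `freq_base = 1.0` reproduces `pos * inv_freq`, and that `freq_scale = 0.5` matches halved positions. `ggml_theta` and the `learned_inv_freq` values are hypothetical.

```python
import math


def ggml_theta(pos, i, n_dims, freq_base=10000.0, freq_scale=1.0, rope_freqs=None):
    """Rotation angle for pair i in basic (non-YaRN) RoPE, as a sketch."""
    theta = pos * freq_scale * freq_base ** (-2.0 * i / n_dims)
    if rope_freqs is not None:
        theta /= rope_freqs[i]  # GGML divides the angle by rope_freqs
    return theta


# Assumption: the checkpoint stores learned inverse frequencies, so the
# PyTorch code computes angle = pos * inv_freq[i].
learned_inv_freq = [1.0, 0.25, 0.0625, 0.015625]
n_dims = 2 * len(learned_inv_freq)

# During conversion, store the reciprocals so that pos / rope_freqs[i] == pos * inv_freq[i].
rope_freqs = [1.0 / f for f in learned_inv_freq]

for pos in range(5):
    for i, inv_f in enumerate(learned_inv_freq):
        # freq_base = 1.0 turns the built-in powf() term into 1.0
        got = ggml_theta(pos, i, n_dims, freq_base=1.0, rope_freqs=rope_freqs)
        assert math.isclose(got, pos * inv_f)

# freq_scale = 0.5 yields the same angles as positions [0, 0.5, 1, ...].
for pos in range(5):
    assert math.isclose(ggml_theta(pos, 1, n_dims, freq_scale=0.5),
                        ggml_theta(pos * 0.5, 1, n_dims))

print("rope_freqs to store:", rope_freqs)  # rope_freqs to store: [1.0, 4.0, 16.0, 64.0]
```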
GGUF specification: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md
Resources:
- YaRN RoPE scaling ggml-org#2268
- support Baichuan serial models ggml-org#3009
- support attention bias ggml-org#4283
- Mixtral support ggml-org#4406
- BERT embeddings ggml-org#5423
- Grok-1 support ggml-org#6204
- Command R Plus support ggml-org#6491
- support arch DBRX ggml-org#6515
- How to convert HuggingFace model to GGUF format ggml-org#2948