
Commit 125a1d2

tamarPal authored and committed
megrez-moe : fix conversion
1 parent cd46a28 commit 125a1d2

2 files changed

Lines changed: 105 additions & 64 deletions


convert_hf_to_gguf.py

Lines changed: 83 additions & 64 deletions
@@ -9033,88 +9033,91 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))


-@ModelBase.register("MegrezMoEForCausalLM")
+@ModelBase.register("MegrezMoeForCausalLM", "MegrezMoEForCausalLM")
 class MegrezMoEModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MEGREZ_MOE

     def set_vocab(self):
-        from transformers import AutoTokenizer
-        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
-
-        tokpre = self.get_vocab_base_pre(tokenizer)
-        merges = []
-        vocab = {}
-        mergeable_ranks = getattr(tokenizer, "mergeable_ranks", {})
-        for token, rank in mergeable_ranks.items():
-            vocab[QwenModel.token_bytes_to_string(token)] = rank
-            if len(token) == 1:
-                continue
-            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
-            if len(merged) == 2:
-                merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
-
-        vocab_size = self.hparams["vocab_size"]
-        assert tokenizer.vocab_size == vocab_size
-        special_tokens = getattr(tokenizer, "special_tokens", {})
-        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
-        tokens: list[str] = []
-        toktypes: list[int] = []
-        for i in range(vocab_size):
-            if i not in reverse_vocab:
-                tokens.append(f"[PAD{i}]")
-                toktypes.append(gguf.TokenType.UNUSED)
-            else:
-                token = reverse_vocab[i]
-                tokens.append(token)
-                if i in special_tokens.values():
-                    toktypes.append(gguf.TokenType.CONTROL)
-                else:
-                    toktypes.append(gguf.TokenType.NORMAL)
-
-        self.gguf_writer.add_tokenizer_model("gpt2")
-        self.gguf_writer.add_tokenizer_pre(tokpre)
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-        self.gguf_writer.add_token_merges(merges)
-
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
-        special_vocab.add_to_gguf(self.gguf_writer)
-        # BOS token fix if needed
-        # self.gguf_writer.add_bos_token_id(<id>)
+        # Megrez-MoE uses a Qwen-style BPE tokenizer
+        # Use standard GPT2 vocab loading which handles BPE correctly
+        try:
+            self._set_vocab_gpt2()
+        except Exception:
+            # Fallback to Qwen-specific handling if needed
+            self._set_vocab_qwen()
+        # Note: special_vocab.add_to_gguf() is already called within
+        # _set_vocab_gpt2() and _set_vocab_qwen(), so no need to call it again

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams

-        self.gguf_writer.add_expert_count(hparams["num_experts"])
-        self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])
-
-        moe_intermediate_size = hparams["moe_intermediate_size"]
-        assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
+        # MoE expert configuration
+        # Try multiple possible parameter names for compatibility
+        num_experts = hparams.get("num_experts") or hparams.get("n_routed_experts")
+        if num_experts is None:
+            raise ValueError("Missing 'num_experts' or 'n_routed_experts' in model config")
+        self.gguf_writer.add_expert_count(num_experts)
+
+        # Shared expert FFN size - Note: In Megrez-MoE, this is NOT the same as intermediate_size!
+        # The shared experts have their own FFN size: hidden_size * 2.75
+        # For Megrez2-3x7B-A3B: hidden_size=2048 → shared_expert_ffn=5632
+        hidden_size = hparams.get("hidden_size", 2048)
+        shared_expert_ffn_size = int(hidden_size * 2.75)
+
+        self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_ffn_size)
+
+        # Per-expert FFN size (should be consistent across all experts)
+        moe_intermediate_size = hparams.get("moe_intermediate_size")
+        if moe_intermediate_size is None:
+            raise ValueError("Missing 'moe_intermediate_size' in model config")
+        if not isinstance(moe_intermediate_size, list):
+            moe_intermediate_size = [moe_intermediate_size]
+
+        # Validate all experts have same size
+        if not all(n == moe_intermediate_size[0] for n in moe_intermediate_size):
+            raise ValueError(f"All experts must have same FFN size, got: {moe_intermediate_size}")
         self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])

-        moe_topk = hparams["moe_topk"]
-        assert all(topk == moe_topk[0] for topk in moe_topk)
-        self.gguf_writer.add_expert_used_count(moe_topk[0])
-
-        moe_shared_expert = hparams["num_shared_expert"]
-        assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
-        self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])
-
-        rope_scaling = hparams.get("rope_scaling", {})
-        if rope_scaling.get("type") == "dynamic":
+        # Top-K expert selection is already handled by parent class (TextModel)
+        # via num_experts_per_tok parameter, so we don't need to set it again here
+
+        # Shared expert count (should be consistent across layers)
+        # Try multiple possible parameter names
+        num_shared_expert = hparams.get("num_shared_expert") or hparams.get("n_shared_experts")
+        if num_shared_expert is None:
+            raise ValueError("Missing 'num_shared_expert' or 'n_shared_experts' in model config")
+        if not isinstance(num_shared_expert, list):
+            num_shared_expert = [num_shared_expert]
+
+        if not all(n == num_shared_expert[0] for n in num_shared_expert):
+            raise ValueError(f"All layers must have same shared expert count, got: {num_shared_expert}")
+        self.gguf_writer.add_expert_shared_count(num_shared_expert[0])
+
+        # RoPE scaling (Megrez may use dynamic scaling)
+        rope_scaling = hparams.get("rope_scaling")
+        if rope_scaling and rope_scaling.get("type") == "dynamic":
             alpha = rope_scaling.get("alpha", 1000)
             base = hparams.get("rope_theta", 10000.0)
-            dim = (hparams["hidden_size"] // hparams["num_attention_heads"])
+            hidden_size = hparams.get("hidden_size")
+            num_attention_heads = hparams.get("num_attention_heads")
+
+            if None in (hidden_size, num_attention_heads):
+                raise ValueError("Missing 'hidden_size' or 'num_attention_heads' for RoPE scaling")
+
+            dim = hidden_size // num_attention_heads
             scaled_base = base * (alpha ** (dim / (dim - 2)))
+
             self.gguf_writer.add_rope_freq_base(scaled_base)
             self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
             self.gguf_writer.add_rope_scaling_factor(1)
             self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024)
             self.gguf_writer.add_context_length(256 * 1024)
-            assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024], \
-                "Megrez dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
-
+
+            logger.info(
+                f"Megrez dynamic RoPE: alpha={alpha}, base={base}, dim={dim}, "
+                f"scaled_base={scaled_base:.2f}, max_ctx={256*1024}"
+            )
     _experts: list[dict[str, Tensor]] | None = None

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
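For reference, a quick sanity check of the dynamic RoPE branch above, using the values the removed assert used to pin (alpha=1000, base=10000.0, dim=128; the 16-head count is inferred from dim=128 with the hidden_size=2048 noted in the comment). This snippet is illustrative only, not part of the commit:

# Illustrative only: reproduce the scaled_base computation from the hunk above
# with the values the removed assert expected (assumed, not taken from config).
alpha = 1000
base = 10000.0
dim = 2048 // 16  # head dimension -> 128
scaled_base = base * (alpha ** (dim / (dim - 2)))
print(f"{scaled_base:.0f}")  # roughly 1.12e7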
@@ -9123,8 +9126,24 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             logger.info("Skipping tied output layer 'lm_head.weight'")
             return []

+        # Handle MoE gate bias (e_score_correction_bias) - map to exp_probs_b
+        if "e_score_correction_bias" in name:
+            # This is the expert selection bias - map to blk.N.exp_probs_b
+            # Format: model.layers.N.mlp.gate.e_score_correction_bias -> blk.N.exp_probs_b
+            layer_num = int(name.split(".")[2])  # Extract layer number
+            new_name = f"blk.{layer_num}.exp_probs_b"
+            return [(new_name, data_torch)]
+
+        # Handle shared FFN (non-expert layers) - pass through directly
+        if name.find("mlp.down_proj") != -1 or name.find("mlp.gate_proj") != -1 or name.find("mlp.up_proj") != -1:
+            if name.find("mlp.experts") == -1:
+                # This is a shared FFN layer, not an expert - pass through
+                return [(self.map_tensor_name(name), data_torch)]
+
         if name.find("mlp.experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.hparams.get("num_experts") or self.hparams.get("n_routed_experts")
+            if n_experts is None:
+                raise ValueError("Missing 'num_experts' or 'n_routed_experts' in config")
             assert bid is not None

             if self._experts is None:
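As a standalone illustration of the gate-bias renaming in the hunk above (the tensor name and layer index here are made up for the example, not taken from the commit):

# Hypothetical HF tensor name, following the format noted in the diff comment
name = "model.layers.5.mlp.gate.e_score_correction_bias"
layer_num = int(name.split(".")[2])        # token at index 2 is the layer number -> 5
new_name = f"blk.{layer_num}.exp_probs_b"  # -> "blk.5.exp_probs_b"
print(new_name)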

gguf-py/gguf/constants.py

Lines changed: 22 additions & 0 deletions
@@ -1551,6 +1551,28 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.MEGREZ_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_INP_SHEXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+    ],
     MODEL_ARCH.QWEN3VL: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
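With both files patched, conversion should go through with the script's usual invocation; the checkpoint path and output file name below are placeholders:

python convert_hf_to_gguf.py /path/to/Megrez2-3x7B-A3B --outfile megrez2-3x7b-a3b-f16.gguf --outtype f16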
