Commit b6f47b1

Maf 18836 gpt mistal pp rope (#41)
# What does this PR do?

Adds the Moreh pipeline and RoPE cache features.

1. To use the pipeline, the following call must be made at the Python level (see the placement sketch after the `modeling_gpt2_moreh.py` diff below):

```
hidden_states = torch.moreh.pipeline_assign(hidden_states)
```

This PR adds that call to the model code.

2. Switches Mistral's RoPE to a cached version. The existing `MistralRotaryEmbedding` recomputed the cos/sin tensors on the target hardware (GPU, NPU) on every forward pass. Since these values only require the dtype and shape of the hidden states, they can be precomputed once as an optimization (a standalone sketch of this precompute follows the JSON example below). The real motivation is that MAF's pipeline execution does not work unless the cached version is used.

Reference code:

- Old HuggingFace llama code: https://github.com/huggingface/transformers/blob/9c804f7ec42c94289ce52eaa84eed32f770311d7/src/transformers/models/deprecated/open_llama/modeling_open_llama.py#L109
- MAF's RoPE code (precomputes on the CPU): https://github.com/moreh-dev/framework/blob/df54f28ce96ff43dce4c0b40a0aeb7bff7fd6b0c/IR/driver/pytorch/torch/moreh_ops/rotary_embedding.py#L35

An example of how this is called from MAF: moreh-dev/framework#8819

The sh file is not uploaded here, but it is not much different; just run it with the model name set to `gpt2-small-moreh`.

JSON example:

```
{
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": null,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 7168,
  "max_position_embeddings": 131072,
  "model_type": "mistral-moreh",
  "num_attention_heads": 16,
  "num_hidden_layers": 7,
  "num_key_value_heads": 8,
  "pad_token_id": 2,
  "rms_norm_eps": 1e-06,
  "rope_theta": 10000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "transformers_version": "4.42.4",
  "use_cache": true,
  "vocab_size": 32000,
  "moreh_config": {
    "pipeline_layers": [3],
    "rope_cache": true
  }
}
```
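As context for item 2, here is a minimal, self-contained sketch of the precompute idea in plain PyTorch on the CPU. It is independent of the classes touched by this PR; `dim`, `base`, and the position cap are illustrative values standing in for the model's `head_dim`, `rope_theta`, and `max_position_embeddings`.

```python
import torch

# Sketch of the RoPE cache idea: cos/sin depend only on the head dim, the rope
# base, and the position index, so they can be computed once on the CPU instead
# of on the target hardware (GPU, NPU) at every forward pass.
dim, base, max_positions = 128, 10000.0, 4096  # illustrative values

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(max_positions, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)             # [max_positions, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)      # [max_positions, dim]
cos_cached, sin_cached = emb.cos(), emb.sin()

# At forward time the cache is only sliced and cast, never recomputed.
seq_len = 16
cos = cos_cached[:seq_len].to(torch.float16).unsqueeze(0)  # [1, seq_len, dim]
sin = sin_cached[:seq_len].to(torch.float16).unsqueeze(0)
print(cos.shape, sin.shape)
```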
1 parent 51d68ae commit b6f47b1

4 files changed

Lines changed: 58 additions & 4 deletions


src/transformers/models/gpt2/configuration_gpt2_moreh.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -159,6 +159,7 @@ def __init__(
         eos_token_id=50256,
         scale_attn_by_inverse_layer_idx=False,
         reorder_and_upcast_attn=False,
+        moreh_config=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -186,6 +187,8 @@ def __init__(
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
 
+        self.moreh_config = moreh_config
+
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 
```

src/transformers/models/gpt2/modeling_gpt2_moreh.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -1016,6 +1016,13 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()
 
+        # Moreh Config
+        self.moreh_pipeline_layers = []
+        moreh_config = getattr(config, "moreh_config", None)
+        if moreh_config is not None and "pipeline_layers" in moreh_config:
+            self.moreh_pipeline_layers = moreh_config["pipeline_layers"]
+
+
     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
         # Check validity of device_map
@@ -1257,6 +1264,8 @@ def forward(
                 for k, v in self.device_map.items():
                     if i == v[-1] and "cuda:" + str(k) != self.last_device:
                         hidden_states = hidden_states.to("cuda:" + str(k + 1))
+            if i in self.moreh_pipeline_layers:
+                hidden_states = torch.moreh.pipeline_assign(hidden_states)
 
         hidden_states = self.ln_f(hidden_states)
 
@@ -1293,7 +1302,6 @@ class GPT2LMHeadModelMoreh(GPT2PreTrainedModel):
 
     def __init__(self, config):
        super().__init__(config)
-        print("GPT2LMHeadModelMoreh ##################################")
         self.transformer = GPT2Model(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
```
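`torch.moreh.pipeline_assign` only exists when running under MAF, so it cannot be exercised with stock PyTorch. The toy loop below, with a pass-through stand-in for `pipeline_assign`, is only meant to show where the call sits relative to the layer loop in the change above; the stub and the `nn.Linear` blocks are illustrative and not part of this PR.

```python
import torch
import torch.nn as nn


def pipeline_assign_stub(hidden_states: torch.Tensor) -> torch.Tensor:
    # Stand-in for torch.moreh.pipeline_assign: under MAF this marks the tensor
    # as a pipeline-stage boundary; here it simply passes the tensor through.
    return hidden_states


blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(6)])
moreh_pipeline_layers = [3]            # mirrors "pipeline_layers": [3] in the example config

hidden_states = torch.randn(2, 4, 8)
for i, block in enumerate(blocks):
    hidden_states = block(hidden_states)
    if i in moreh_pipeline_layers:     # same placement as in the modified layer loop
        hidden_states = pipeline_assign_stub(hidden_states)
print(hidden_states.shape)
```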

src/transformers/models/mistral/configuration_mistral_moreh.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -116,6 +116,7 @@ def __init__(
         rope_theta=10000.0,
         sliding_window=4096,
         attention_dropout=0.0,
+        moreh_config=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -138,6 +139,8 @@ def __init__(
         self.rope_theta = rope_theta
         self.attention_dropout = attention_dropout
 
+        self.moreh_config = moreh_config
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
```
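For reference, a sketch of constructing the config with the new argument. The import path is assumed from this file's location in the fork, and `MistralMorehConfig` is the class name used in the modeling file's type hints; adjust if the class is exported elsewhere.

```python
# Sketch only: import path assumed from src/transformers/models/mistral/configuration_mistral_moreh.py.
from transformers.models.mistral.configuration_mistral_moreh import MistralMorehConfig

config = MistralMorehConfig(
    hidden_size=2048,
    intermediate_size=7168,
    num_hidden_layers=7,
    num_attention_heads=16,
    num_key_value_heads=8,
    max_position_embeddings=131072,
    moreh_config={"pipeline_layers": [3], "rope_cache": True},
)
print(config.moreh_config)
```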

src/transformers/models/mistral/modeling_mistral_moreh.py

Lines changed: 43 additions & 3 deletions
```diff
@@ -94,7 +94,7 @@ def forward(self, hidden_states):
 
 
 class MistralRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, use_rope_cache=False):
         super().__init__()
 
         self.dim = dim
@@ -103,9 +103,36 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
+        self.use_rope_cache = use_rope_cache
+        if self.use_rope_cache:
+            self._set_cos_sin_cache(max_position_embeddings, dtype=torch.float32)
+
+    def _set_cos_sin_cache(self, seq_len, dtype):
+        self.max_seq_len_cached = seq_len
+
+        t = torch.arange(seq_len, dtype=torch.float32, device="cpu")
+        freqs = torch.outer(t, self.inv_freq.cpu())  # [seq_len, dim/2]
+        emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, dim]
+
+        cos = emb.cos()
+        sin = emb.sin()
+
+        cos = cos.to(device='cuda', dtype=dtype)
+        sin = sin.to(device='cuda', dtype=dtype)
+
+        self.register_buffer("cos_cached", cos, persistent=False)
+        self.register_buffer("sin_cached", sin, persistent=False)
+
     @torch.no_grad()
     # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
     def forward(self, x, position_ids):
+        if self.use_rope_cache:
+            seq_len = position_ids.shape[-1]
+            assert seq_len <= self.max_position_embeddings, "Sequence length exceeds maximum position embeddings"
+            cos = self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device).unsqueeze(0)
+            sin = self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device).unsqueeze(0)
+            return cos, sin
+
         # x: [bs, num_attention_heads, seq_len, head_size]
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
@@ -221,10 +248,16 @@ def __init__(self, config: MistralMorehConfig, layer_idx: Optional[int] = None):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
 
+        use_rope_cache = False
+        moreh_config = getattr(config, "moreh_config", None)
+        if moreh_config is not None and "rope_cache" in moreh_config:
+            use_rope_cache = moreh_config["rope_cache"]
+
         self.rotary_emb = MistralRotaryEmbedding(
             self.head_dim,
             max_position_embeddings=self.max_position_embeddings,
             base=self.rope_theta,
+            use_rope_cache=use_rope_cache,
         )
 
     def forward(
@@ -885,6 +918,12 @@ def __init__(self, config: MistralMorehConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+        # Moreh Config
+        self.moreh_pipeline_layers = []
+        moreh_config = getattr(config, "moreh_config", None)
+        if moreh_config is not None and "pipeline_layers" in moreh_config:
+            self.moreh_pipeline_layers = moreh_config["pipeline_layers"]
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
@@ -957,7 +996,7 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for layer_idx, decoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
@@ -984,6 +1023,8 @@ def forward(
                 )
 
             hidden_states = layer_outputs[0]
+            if layer_idx in self.moreh_pipeline_layers:
+                hidden_states = torch.moreh.pipeline_assign(hidden_states)
 
             if use_cache:
                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
@@ -1123,7 +1164,6 @@ class MistralForCausalLMMoreh(MistralPreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-        print("MistralForCausalLMMoreh #########################################")
         self.model = MistralModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
```
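To illustrate why the cache slice in `forward` reproduces the existing on-the-fly branch, here is a small standalone check in plain PyTorch (not the fork's classes; sizes are illustrative). It compares the precomputed cos/sin against the `inv_freq_expanded @ position_ids_expanded` computation for a contiguous prefill, i.e. `position_ids = arange(seq_len)`.

```python
import torch

dim, base, seq_len, max_positions = 128, 10000.0, 32, 4096
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# Cached path: precompute for all positions once, then slice per forward.
t = torch.arange(max_positions, dtype=torch.float32)
freqs_full = torch.outer(t, inv_freq)
emb_full = torch.cat((freqs_full, freqs_full), dim=-1)
cos_cached = emb_full.cos()[:seq_len].unsqueeze(0)
sin_cached = emb_full.sin()[:seq_len].unsqueeze(0)

# On-the-fly path, as in the non-cached branch of MistralRotaryEmbedding.forward.
position_ids = torch.arange(seq_len)[None, :]  # contiguous prefill positions
inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)

# The two paths agree for positions 0..seq_len-1, which is the case the cache
# slice covers.
print(torch.allclose(cos_cached, emb.cos()), torch.allclose(sin_cached, emb.sin()))
```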
