Commit b6f47b1
authored
Maf 18836 gpt mistal pp rope (#41)
# What does this PR do?
Adds the moreh pipeline and RoPE cache features.
1. To use the pipeline, the following code must be called at the Python level:
```
hidden_states = torch.moreh.pipeline_assign(hidden_states)
```
This PR adds that call.
2. Mistral's RoPE now uses the cached version.
The existing `MistralRotaryEmbedding` computed the cos and sin tensors on the target hardware (GPU, NPU) on every forward pass.
Computing these values requires only the dtype and shape of the hidden states, so they can be precomputed as an optimization.
(The real motivation: MAF's pipeline execution does not work unless the cached version is used.)
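
The precomputation described in item 2 can be sketched as follows. This is a minimal illustration in plain Python so that it runs without torch; it is not the code from this commit. The function name `build_rope_cache` and the half-and-half layout of the angles (as in the older HuggingFace llama code) are assumptions made for the example.

```python
import math

def build_rope_cache(max_positions, head_dim, theta=10000.0):
    """Precompute cos/sin tables once instead of on every forward pass.

    Hypothetical helper: returns two lists of shape
    [max_positions][head_dim] holding cos and sin of the rotary angles.
    """
    half = head_dim // 2
    # Inverse frequencies: theta^(-2i / head_dim) for i in [0, half)
    inv_freq = [theta ** (-(2 * i) / head_dim) for i in range(half)]
    cos_table, sin_table = [], []
    for pos in range(max_positions):
        angles = [pos * f for f in inv_freq]
        # Assumed layout: the half-size angle vector is repeated twice
        angles = angles + angles
        cos_table.append([math.cos(a) for a in angles])
        sin_table.append([math.sin(a) for a in angles])
    return cos_table, sin_table

# Built once (for example at model construction); forward only slices it.
cos_table, sin_table = build_rope_cache(max_positions=16, head_dim=8)
cos_for_seq = cos_table[:4]  # rows for a sequence of length 4
```

Because the tables depend only on the position range, the head dimension, and `rope_theta`, nothing in them has to be recomputed per forward pass.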
Reference code:
The older HuggingFace llama code:
https://github.com/huggingface/transformers/blob/9c804f7ec42c94289ce52eaa84eed32f770311d7/src/transformers/models/deprecated/open_llama/modeling_open_llama.py#L109
MAF's RoPE code (the computation is done ahead of time on the CPU):
https://github.com/moreh-dev/framework/blob/df54f28ce96ff43dce4c0b40a0aeb7bff7fd6b0c/IR/driver/pytorch/torch/moreh_ops/rotary_embedding.py#L35
Below is an example of calling this from MAF:
moreh-dev/framework#8819
The sh file was not uploaded, but it is largely the same; just run with the model name set to gpt2-small-moreh.
JSON example:
```
{
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"head_dim": null,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 7168,
"max_position_embeddings": 131072,
"model_type": "mistral-moreh",
"num_attention_heads": 16,
"num_hidden_layers": 7,
"num_key_value_heads": 8,
"pad_token_id": 2,
"rms_norm_eps": 1e-06,
"rope_theta": 10000.0,
"sliding_window": null,
"tie_word_embeddings": false,
"transformers_version": "4.42.4",
"use_cache": true,
"vocab_size": 32000,
"moreh_config": {
"pipeline_layers": [3],
"rope_cache": true
}
}
```
1 parent 51d68ae commit b6f47b1
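
As a rough sketch of how the `moreh_config` block in the JSON example might be consumed: the snippet below reads `pipeline_layers` and `rope_cache` and calls a hook at each listed layer index. The helper names, the fallback values, and the choice to call the hook before (rather than after) the listed layer are assumptions for illustration. The `assign` argument stands in for `torch.moreh.pipeline_assign`, which is not imported here.

```python
import json

# Trimmed copy of the JSON example; only the keys used below are kept.
CONFIG_TEXT = """
{
  "model_type": "mistral-moreh",
  "num_hidden_layers": 7,
  "moreh_config": {"pipeline_layers": [3], "rope_cache": true}
}
"""

def read_moreh_config(config):
    """Return (pipeline_layers, rope_cache); the fallbacks are assumptions."""
    moreh = config.get("moreh_config") or {}
    pipeline_layers = list(moreh.get("pipeline_layers", []))
    rope_cache = bool(moreh.get("rope_cache", False))
    return pipeline_layers, rope_cache

def run_layers(hidden_states, layers, pipeline_layers, assign):
    """Apply layers in order, calling `assign` before each listed index.

    `assign` is a stand-in for torch.moreh.pipeline_assign.
    """
    boundaries = set(pipeline_layers)
    for index, layer in enumerate(layers):
        if index in boundaries:
            hidden_states = assign(hidden_states)
        hidden_states = layer(hidden_states)
    return hidden_states

config = json.loads(CONFIG_TEXT)
pipeline_layers, rope_cache = read_moreh_config(config)
```

With `"pipeline_layers": [3]` and seven layers, the hook would fire once, between layers 2 and 3, which would split the model into two pipeline stages under these assumptions.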
4 files changed
Lines changed: 58 additions & 4 deletions
File tree
- src/transformers/models
- gpt2
- mistral
Lines changed: 3 additions & 0 deletions
Lines changed: 9 additions & 1 deletion
Lines changed: 3 additions & 0 deletions
Lines changed: 43 additions & 3 deletions