Commit 2d0d276
[PyT] Plumbing correct bias dims from TE to cudnn, while adding support for additional bias shapes (NVIDIA#2537)
* Plumbing correct bias dims from TE to cudnn
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Make changes for cp bias code
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Add dbias and dbias_ to run_dpa_with_cp test
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Fix: Use output_dBias instead of input_dBias to extract the shape
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Add guards for bias/bias_/dbias/dbias_ being None
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Add support for bias shape 111s in addition to the original 1hss, 11ss, b1ss and bhss
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Add support for dbias calculation and variant packing for the dbias shapes b1ss, bhss, 11ss in addition to the already supported 1hss
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Add support for 111s bias shape in DPA
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Allow fused attn for dbias calculation for 11ss, b1ss, bhss. Disable fused attn if dbias calculation for 111s is required, else enable
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Disable requires_grad for bias for shape 111s in tests
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Disable bias grad / training flag for 111s bias in the non-CP attn tests. Add bias shape 111s to test_dpa_bias_shapes
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Fix to correctly create the bias shape tensor instead of the hard coded shape. Fix the comparison logic shapes for bias/dbias
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Add fused attn cp test cases for all supported bias shapes
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* nit: switch to elif for bias grad conditional
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Add CP support for bias/dbias shape 111s
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Add support for is_training in CP attn tests
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* nit: Fix incorrect comment
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* nit: Fix incorrect comment and assert string
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Create the dbias graph tensor only if it is a cuDNN supported bias shape
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Fix the dim that is being compared for the two cp chunks in the test
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* nit: Reinstate the original test for right side swa
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>1 parent f122b07 commit 2d0d276
File tree
10 files changed
+501
-233
lines changed- tests/pytorch
- attention
- transformer_engine
- common/fused_attn
- pytorch/attention/dot_product_attention
10 files changed
+501
-233
lines changedLarge diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
162 | 162 | | |
163 | 163 | | |
164 | 164 | | |
| 165 | + | |
| 166 | + | |
165 | 167 | | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
166 | 175 | | |
167 | 176 | | |
168 | 177 | | |
| |||
636 | 645 | | |
637 | 646 | | |
638 | 647 | | |
639 | | - | |
| 648 | + | |
| 649 | + | |
640 | 650 | | |
641 | 651 | | |
642 | 652 | | |
| |||
646 | 656 | | |
647 | 657 | | |
648 | 658 | | |
649 | | - | |
| 659 | + | |
650 | 660 | | |
651 | 661 | | |
652 | 662 | | |
| |||
1143 | 1153 | | |
1144 | 1154 | | |
1145 | 1155 | | |
| 1156 | + | |
| 1157 | + | |
1146 | 1158 | | |
| 1159 | + | |
| 1160 | + | |
| 1161 | + | |
1147 | 1162 | | |
1148 | 1163 | | |
1149 | | - | |
| 1164 | + | |
| 1165 | + | |
1150 | 1166 | | |
1151 | 1167 | | |
1152 | 1168 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
147 | 147 | | |
148 | 148 | | |
149 | 149 | | |
150 | | - | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
151 | 154 | | |
152 | 155 | | |
153 | 156 | | |
| |||
160 | 163 | | |
161 | 164 | | |
162 | 165 | | |
163 | | - | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
164 | 174 | | |
165 | 175 | | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
166 | 190 | | |
167 | 191 | | |
168 | 192 | | |
| |||
171 | 195 | | |
172 | 196 | | |
173 | 197 | | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
174 | 201 | | |
175 | 202 | | |
176 | 203 | | |
| |||
191 | 218 | | |
192 | 219 | | |
193 | 220 | | |
| 221 | + | |
194 | 222 | | |
195 | 223 | | |
| 224 | + | |
196 | 225 | | |
197 | 226 | | |
| 227 | + | |
198 | 228 | | |
199 | 229 | | |
200 | 230 | | |
| |||
324 | 354 | | |
325 | 355 | | |
326 | 356 | | |
| 357 | + | |
| 358 | + | |
327 | 359 | | |
328 | 360 | | |
329 | 361 | | |
330 | 362 | | |
331 | 363 | | |
332 | 364 | | |
| 365 | + | |
333 | 366 | | |
334 | 367 | | |
335 | 368 | | |
| |||
348 | 381 | | |
349 | 382 | | |
350 | 383 | | |
| 384 | + | |
351 | 385 | | |
352 | 386 | | |
353 | 387 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
271 | 271 | | |
272 | 272 | | |
273 | 273 | | |
274 | | - | |
275 | 274 | | |
276 | 275 | | |
277 | 276 | | |
| |||
289 | 288 | | |
290 | 289 | | |
291 | 290 | | |
292 | | - | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
293 | 294 | | |
294 | 295 | | |
295 | 296 | | |
| |||
Lines changed: 59 additions & 46 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
52 | 52 | | |
53 | 53 | | |
54 | 54 | | |
55 | | - | |
56 | | - | |
57 | | - | |
58 | | - | |
59 | | - | |
60 | | - | |
61 | | - | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
62 | 63 | | |
63 | 64 | | |
64 | 65 | | |
| |||
121 | 122 | | |
122 | 123 | | |
123 | 124 | | |
| 125 | + | |
| 126 | + | |
124 | 127 | | |
125 | 128 | | |
126 | 129 | | |
| |||
269 | 272 | | |
270 | 273 | | |
271 | 274 | | |
272 | | - | |
273 | | - | |
274 | | - | |
275 | | - | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
276 | 280 | | |
277 | 281 | | |
278 | 282 | | |
| |||
548 | 552 | | |
549 | 553 | | |
550 | 554 | | |
551 | | - | |
552 | | - | |
553 | | - | |
554 | | - | |
555 | | - | |
556 | | - | |
557 | | - | |
558 | | - | |
559 | | - | |
560 | | - | |
| 555 | + | |
| 556 | + | |
| 557 | + | |
| 558 | + | |
| 559 | + | |
| 560 | + | |
| 561 | + | |
| 562 | + | |
| 563 | + | |
| 564 | + | |
561 | 565 | | |
562 | 566 | | |
563 | 567 | | |
| |||
622 | 626 | | |
623 | 627 | | |
624 | 628 | | |
| 629 | + | |
| 630 | + | |
625 | 631 | | |
626 | 632 | | |
627 | 633 | | |
| |||
811 | 817 | | |
812 | 818 | | |
813 | 819 | | |
814 | | - | |
815 | | - | |
816 | | - | |
817 | | - | |
818 | | - | |
819 | | - | |
820 | | - | |
821 | | - | |
| 820 | + | |
| 821 | + | |
| 822 | + | |
| 823 | + | |
| 824 | + | |
822 | 825 | | |
823 | | - | |
824 | | - | |
825 | | - | |
826 | | - | |
| 826 | + | |
| 827 | + | |
| 828 | + | |
| 829 | + | |
| 830 | + | |
| 831 | + | |
| 832 | + | |
| 833 | + | |
827 | 834 | | |
828 | 835 | | |
829 | 836 | | |
| |||
974 | 981 | | |
975 | 982 | | |
976 | 983 | | |
977 | | - | |
| 984 | + | |
978 | 985 | | |
979 | | - | |
980 | | - | |
981 | 986 | | |
982 | 987 | | |
983 | 988 | | |
| |||
1083 | 1088 | | |
1084 | 1089 | | |
1085 | 1090 | | |
| 1091 | + | |
| 1092 | + | |
1086 | 1093 | | |
1087 | 1094 | | |
1088 | 1095 | | |
1089 | 1096 | | |
| 1097 | + | |
| 1098 | + | |
1090 | 1099 | | |
1091 | 1100 | | |
1092 | 1101 | | |
| |||
1152 | 1161 | | |
1153 | 1162 | | |
1154 | 1163 | | |
1155 | | - | |
| 1164 | + | |
1156 | 1165 | | |
1157 | 1166 | | |
1158 | 1167 | | |
| |||
1197 | 1206 | | |
1198 | 1207 | | |
1199 | 1208 | | |
1200 | | - | |
1201 | | - | |
1202 | | - | |
1203 | | - | |
| 1209 | + | |
| 1210 | + | |
| 1211 | + | |
| 1212 | + | |
1204 | 1213 | | |
1205 | 1214 | | |
1206 | 1215 | | |
| |||
1244 | 1253 | | |
1245 | 1254 | | |
1246 | 1255 | | |
| 1256 | + | |
| 1257 | + | |
1247 | 1258 | | |
1248 | 1259 | | |
1249 | 1260 | | |
1250 | 1261 | | |
1251 | 1262 | | |
| 1263 | + | |
| 1264 | + | |
1252 | 1265 | | |
1253 | 1266 | | |
1254 | 1267 | | |
| |||
1291 | 1304 | | |
1292 | 1305 | | |
1293 | 1306 | | |
1294 | | - | |
1295 | | - | |
1296 | | - | |
1297 | | - | |
1298 | | - | |
| 1307 | + | |
| 1308 | + | |
| 1309 | + | |
| 1310 | + | |
| 1311 | + | |
1299 | 1312 | | |
1300 | 1313 | | |
1301 | 1314 | | |
| |||
0 commit comments