Skip to content

Commit c531ba4

Browse files
committed
Uncomment the 2-bit hdim128 fp16 sm80 kernel instantiations
1 parent 27a5f93 commit c531ba4

3 files changed

Lines changed: 22 additions & 8 deletions

File tree

benchmark/bench_single_decode.ipynb

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,21 @@
44
"cell_type": "code",
55
"execution_count": 1,
66
"metadata": {},
7-
"outputs": [],
7+
"outputs": [
8+
{
9+
"ename": "ImportError",
10+
"evalue": "/home/ddy/miniconda3/envs/issue/lib/python3.10/site-packages/bit_decode-1.0.0.post1-py3.10-linux-x86_64.egg/bit_decode_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _Z28run_mha_fwd_splitkv_dispatchIN7cutlass6half_tELi128ELb0ELi1ELi2ELi128EEvR16Flash_fwd_paramsP11CUstream_st",
11+
"output_type": "error",
12+
"traceback": [
13+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
14+
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
15+
"Cell \u001b[0;32mIn[1], line 9\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mflash_attn\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m flash_attn_with_kvcache\n\u001b[0;32m----> 9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mbit_decode\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m kvcache_pack_int, fwd_kvcache_int\n",
16+
"File \u001b[0;32m~/miniconda3/envs/issue/lib/python3.10/site-packages/bit_decode-1.0.0.post1-py3.10-linux-x86_64.egg/bit_decode/__init__.py:3\u001b[0m\n\u001b[1;32m 1\u001b[0m __version__ \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m1.0.0.post1\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mbit_decode\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbit_decode_interface\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m 4\u001b[0m kvcache_pack_int,\n\u001b[1;32m 5\u001b[0m fwd_kvcache_int\n\u001b[1;32m 6\u001b[0m )\n",
17+
"File \u001b[0;32m~/miniconda3/envs/issue/lib/python3.10/site-packages/bit_decode-1.0.0.post1-py3.10-linux-x86_64.egg/bit_decode/bit_decode_interface.py:8\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnn\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mnn\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mbit_decode_cuda\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mbit_decode_cuda\u001b[39;00m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mkvcache_pack_int\u001b[39m(k_cache: torch\u001b[38;5;241m.\u001b[39mTensor, k_pack: torch\u001b[38;5;241m.\u001b[39mTensor, k_params: torch\u001b[38;5;241m.\u001b[39mTensor,\n\u001b[1;32m 11\u001b[0m v_cache: torch\u001b[38;5;241m.\u001b[39mTensor, v_pack: torch\u001b[38;5;241m.\u001b[39mTensor, v_params: torch\u001b[38;5;241m.\u001b[39mTensor,\n\u001b[1;32m 12\u001b[0m opt_block_table: Optional[torch\u001b[38;5;241m.\u001b[39mTensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 16\u001b[0m group_size: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m128\u001b[39m,\n\u001b[1;32m 17\u001b[0m num_bits: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m4\u001b[39m):\n\u001b[1;32m 19\u001b[0m batch_size, seqlen_k, nheads_k, d \u001b[38;5;241m=\u001b[39m k_cache\u001b[38;5;241m.\u001b[39mshape\n",
18+
"\u001b[0;31mImportError\u001b[0m: /home/ddy/miniconda3/envs/issue/lib/python3.10/site-packages/bit_decode-1.0.0.post1-py3.10-linux-x86_64.egg/bit_decode_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _Z28run_mha_fwd_splitkv_dispatchIN7cutlass6half_tELi128ELb0ELi1ELi2ELi128EEvR16Flash_fwd_paramsP11CUstream_st"
19+
]
20+
}
21+
],
822
"source": [
923
"import torch\n",
1024
"import torch.nn as nn\n",
@@ -392,7 +406,7 @@
392406
],
393407
"metadata": {
394408
"kernelspec": {
395-
"display_name": "bitdecode",
409+
"display_name": "issue",
396410
"language": "python",
397411
"name": "python3"
398412
},
@@ -406,7 +420,7 @@
406420
"name": "python",
407421
"nbconvert_exporter": "python",
408422
"pygments_lexer": "ipython3",
409-
"version": "3.10.16"
423+
"version": "3.10.18"
410424
}
411425
},
412426
"nbformat": 4,

csrc/bit_decode/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44

55
#include "../flash_fwd_launch_template.h"
66

7-
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream);
7+
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream);
88
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 64>(Flash_fwd_params &params, cudaStream_t stream);
99
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 32>(Flash_fwd_params &params, cudaStream_t stream);

csrc/bit_decode/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
#include "../flash_fwd_launch_template.h"
66

7-
// template<>
8-
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream) {
9-
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 128>(params, stream);
10-
// }
7+
template<>
8+
void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream) {
9+
run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 128>(params, stream);
10+
}
1111
// template<>
1212
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 64>(Flash_fwd_params &params, cudaStream_t stream) {
1313
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 64>(params, stream);

0 commit comments

Comments
 (0)