@@ -57,6 +57,54 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
5757 }
5858}
5959
60+ template <typename reorder_vec_dot_q_sycl>
61+ static void mul_mat_vec_q_reorder_local_mem (const void * __restrict__ vx, const void * __restrict__ vy, sycl::local_accessor<block_q8_1, 1 > vy_local, float * __restrict__ dst,
62+ const int ncols, const int nrows, const sycl::nd_item<3 > & nd_item) {
63+ using block_type = ggml_sycl_reordered::block_q_t <reorder_vec_dot_q_sycl::gtype>;
64+ using block_traits = typename block_type::traits;
65+
66+ const auto sg = nd_item.get_sub_group ();
67+ const int sg_range = sg.get_group_linear_range ();
68+ const int workgroup_id = nd_item.get_group_linear_id ();
69+ const int sg_id = sg.get_group_linear_id ();
70+ const int row = workgroup_id * sg_range + sg_id;
71+
72+ if (row >= nrows) return ;
73+
74+ const int blocks_per_row = ncols / block_traits::qk;
75+ constexpr int blocks_per_subgroup = ceil_div (block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
76+ constexpr int block_elements_per_sg = block_traits::qi / block_traits::vdr_mmvq;
77+
78+ const int total_y_blocks = blocks_per_row * block_type::block_to_q8_1_ratio ();
79+ const int nblocks = nrows * blocks_per_row;
80+
81+ const block_q8_1 * y_global = static_cast <const block_q8_1 *>(vy);
82+ for (int i = nd_item.get_local_linear_id (); i < total_y_blocks; i += nd_item.get_local_range ().size ()) {
83+ vy_local[i] = y_global[i];
84+ }
85+ nd_item.barrier (sycl::access::fence_space::local_space);
86+
87+ float partial_sum = 0 .0f ;
88+ for (int i = sg.get_local_linear_id () / block_elements_per_sg; i < blocks_per_row; i += blocks_per_subgroup) {
89+ const int ibx = row * blocks_per_row + i;
90+ const int bx_offset = block_type::get_block_offset (ibx);
91+ const int d_offset = block_type::get_d_offset (nrows, ncols, ibx);
92+
93+ const int iby = i * block_type::block_to_q8_1_ratio ();
94+
95+ for (int elem = 0 ; elem < block_elements_per_sg; elem += WARP_SIZE) {
96+ const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id () % block_elements_per_sg);
97+ partial_sum += reorder_vec_dot_q_sycl ()(vx, bx_offset, d_offset, &vy_local[iby], iqs, nblocks);
98+ }
99+ }
100+
101+ float sum = sycl::reduce_over_group (sg, partial_sum, std::plus<>());
102+
103+ if (sg.leader ()) {
104+ dst[row] = sum;
105+ }
106+ }
107+
60108template <int qk, int qi, typename block_q_t , int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
61109static void mul_mat_vec_q (const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
62110 const int ncols, const int nrows, const sycl::nd_item<3 > & item_ct1) {
@@ -101,6 +149,58 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
101149 }
102150}
103151
152+ template <int qk, int qi, typename block_q_t , int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
153+ static void mul_mat_vec_q_local_mem (const void * __restrict__ vx, const void * __restrict__ vy, sycl::local_accessor<block_q8_1, 1 > y_local, float * __restrict__ dst,
154+ const int ncols, const int nrows, const sycl::nd_item<3 > & item_ct1) {
155+ const int row = item_ct1.get_group (2 ) * item_ct1.get_local_range (1 ) + item_ct1.get_local_id (1 );
156+
157+ if (row >= nrows) {
158+ return ;
159+ }
160+
161+ const int blocks_per_row = ncols / qk;
162+ constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1 ) / qi; // Ensuring blocks_per_warp > 0
163+
164+ assert (blocks_per_warp > 0 );
165+
166+ // partial sum for each thread
167+ float tmp = 0 .0f ;
168+
169+ const block_q_t * x = (const block_q_t *) vx;
170+ const block_q8_1 * y = (const block_q8_1 *) vy;
171+
172+ const int blocks_per_row_y = ncols / /* qk_vec */ QK8_1; // TODO:: hardcoded
173+ const int total_y_blocks = blocks_per_row_y;
174+ for (int iby = item_ct1.get_local_id (2 ); iby < total_y_blocks; iby += item_ct1.get_local_range (2 )) {
175+ y_local[iby] = y[iby];
176+ }
177+
178+ item_ct1.barrier (sycl::access::fence_space::local_space);
179+
180+ for (int i = item_ct1.get_local_id (2 ) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
181+ const int ibx = row * blocks_per_row + i; // x block index
182+
183+ const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
184+
185+ for (size_t elem = 0 ; elem < qi / vdr; elem += WARP_SIZE) {
186+ const int iqs = elem + vdr * (item_ct1.get_local_id (2 ) %
187+ (qi / vdr)); // x block quant index when casting the quants to int
188+
189+ tmp += vec_dot_q_sycl (&x[ibx], &y_local[iby], iqs);
190+ }
191+ }
192+
193+ // sum up partial sums and write back result
194+ #pragma unroll
195+ for (int mask = WARP_SIZE / 2 ; mask > 0 ; mask >>= 1 ) {
196+ tmp += dpct::permute_sub_group_by_xor (item_ct1.get_sub_group (), tmp, mask);
197+ }
198+
199+ if (item_ct1.get_local_id (2 ) == 0 ) {
200+ dst[row] = tmp;
201+ }
202+ }
203+
104204template <int qk, int qi, typename block_q_t , int vdr>
105205static void mul_mat_vec_q_iq2_xxs_q8_1 (const void *__restrict__ vx,
106206 const void *__restrict__ vy,
@@ -720,41 +820,65 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
720820 float *dst, const int ncols,
721821 const int nrows,
722822 dpct::queue_ptr stream) {
823+ // std::cout << ">>>>>>>>> THIS IS CALLED\n";
723824 GGML_ASSERT (ncols % QK_K == 0 );
724825 const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1 ) / GGML_SYCL_MMV_Y;
725826 const sycl::range<3 > block_nums (1 , 1 , block_num_y);
726827 const sycl::range<3 > block_dims (1 , GGML_SYCL_MMV_Y, WARP_SIZE);
828+
829+ using block_type = ggml_sycl_reordered::block_q_t <GGML_TYPE_Q4_K>;
830+ const int blocks_per_row = ncols / block_type::traits::qk;
831+ const int total_y_blocks = blocks_per_row * block_type::block_to_q8_1_ratio ();
832+ if (total_y_blocks * sizeof (block_q8_1) > stream->get_device ().get_info <sycl::info::device::local_mem_size>()) {
833+ // TODO: add fallback
834+ GGML_ABORT (" not enough local mem" );
835+ }
836+
727837 {
728838
729839 stream->submit ([&](sycl::handler &cgh) {
840+ sycl::local_accessor<block_q8_1, 1 > vy_local (sycl::range<1 >(total_y_blocks), cgh);
730841
731842 cgh.parallel_for (
732843 sycl::nd_range<3 >(block_nums * block_dims, block_dims),
733844 [=](sycl::nd_item<3 > item_ct1)
734845 [[sycl::reqd_sub_group_size (WARP_SIZE)]] {
735- mul_mat_vec_q <QK_K, QI4_K, block_q4_K,
846+ mul_mat_vec_q_local_mem <QK_K, QI4_K, block_q4_K,
736847 VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
737- vx, vy, dst, ncols, nrows, item_ct1);
848+ vx, vy, vy_local, dst, ncols, nrows, item_ct1);
738849 });
739850 });
740851 }
741852}
742853
743854static void reorder_mul_mat_vec_q4_k_q8_1_sycl (const void * vx, const void * vy, float * dst, const int ncols,
744855 const int nrows, dpct::queue_ptr stream) {
856+ // std::cout << ">>>>>>>>> REORDER PATH\n";
857+
745858 GGML_ASSERT (ncols % QK_K == 0 );
746859
747860 const int block_num_y = ceil_div (nrows, GGML_SYCL_MMV_Y);
748861 constexpr size_t num_subgroups = 16 ;
749862 GGML_ASSERT (block_num_y % num_subgroups == 0 );
863+ // std::cout << "block_num_y: " << block_num_y << ", num_subgroups: " << num_subgroups << ", nrows: " << nrows << ", ncols:" << ncols << "\n";
750864
751865 const sycl::range<3 > global_size (1 , GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
752866 const sycl::range<3 > workgroup_size (1 , GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
753867
868+ using block_type = ggml_sycl_reordered::block_q_t <GGML_TYPE_Q4_K>;
869+ const int blocks_per_row = ncols / block_type::traits::qk;
870+ const int total_y_blocks = blocks_per_row * block_type::block_to_q8_1_ratio ();
871+ if (total_y_blocks * sizeof (block_q8_1) > stream->get_device ().get_info <sycl::info::device::local_mem_size>()) {
872+ // TODO: add fallback
873+ GGML_ABORT (" not enough local mem" );
874+ }
875+
754876 stream->submit ([&](sycl::handler & cgh) {
877+ sycl::local_accessor<block_q8_1, 1 > vy_local (sycl::range<1 >(total_y_blocks), cgh);
878+
755879 cgh.parallel_for (sycl::nd_range<3 >(global_size, workgroup_size),
756880 [=](sycl::nd_item<3 > nd_item) [[sycl::reqd_sub_group_size (WARP_SIZE)]] {
757- mul_mat_vec_q_reorder <reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
881+ mul_mat_vec_q_reorder_local_mem <reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, vy_local , dst, ncols,
758882 nrows, nd_item);
759883 });
760884 });
0 commit comments