Skip to content

Commit 5aaf5de

Browse files
hexagon: add ARGSORT op
1 parent e6e934c commit 5aaf5de

7 files changed

Lines changed: 296 additions & 18 deletions

File tree

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2111,6 +2111,21 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session *
21112111
return true;
21122112
}
21132113

2114+
static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2115+
const struct ggml_tensor * src0 = op->src[0]; // values
2116+
const struct ggml_tensor * dst = op; // indices
2117+
2118+
if (src0->type != GGML_TYPE_F32) {
2119+
return false;
2120+
}
2121+
2122+
if (dst->type != GGML_TYPE_I32) {
2123+
return false;
2124+
}
2125+
2126+
return true;
2127+
}
2128+
21142129
static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
21152130
const int32_t * op_params = &op->op_params[0];
21162131

@@ -2316,6 +2331,17 @@ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer *
23162331
return n_bufs;
23172332
}
23182333

2334+
static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2335+
req->op = HTP_OP_ARGSORT;
2336+
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
2337+
2338+
size_t n_bufs = 0;
2339+
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2340+
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2341+
2342+
return n_bufs;
2343+
}
2344+
23192345
template <bool _is_src0_constant>
23202346
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
23212347
switch (t->op) {
@@ -2564,6 +2590,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
25642590
ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
25652591
break;
25662592

2593+
case GGML_OP_ARGSORT:
2594+
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
2595+
break;
2596+
25672597
default:
25682598
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
25692599
}
@@ -2968,6 +2998,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
29682998
supp = ggml_hexagon_supported_cpy(sess, op);
29692999
break;
29703000

3001+
case GGML_OP_ARGSORT:
3002+
supp = ggml_hexagon_supported_argsort(sess, op);
3003+
break;
3004+
29713005
default:
29723006
break;
29733007
}

ggml/src/ggml-hexagon/htp/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
66
include_directories(
77
${HEXAGON_SDK_ROOT}/incs
88
${HEXAGON_SDK_ROOT}/incs/stddef
9+
${CMAKE_CURRENT_SOURCE_DIR}/../../../include
910
${CMAKE_CURRENT_SOURCE_DIR}/../..
1011
${CMAKE_CURRENT_SOURCE_DIR}/..
1112
${CMAKE_CURRENT_SOURCE_DIR}
@@ -28,6 +29,7 @@ add_library(${HTP_LIB} SHARED
2829
set-rows-ops.c
2930
get-rows-ops.c
3031
cpy-ops.c
32+
argsort-ops.c
3133
)
3234

3335
target_compile_definitions(${HTP_LIB} PRIVATE
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
#include <string.h>
2+
#include <stdlib.h>
3+
#include <math.h>
4+
#include <HAP_farf.h>
5+
#include <HAP_perf.h>
6+
7+
#define GGML_COMMON_DECL_C
8+
#include "ggml-common.h"
9+
#include "ggml.h"
10+
11+
#include "hvx-utils.h"
12+
#include "hex-dma.h"
13+
14+
#include "htp-ctx.h"
15+
#include "htp-msg.h"
16+
#include "htp-ops.h"
17+
18+
#ifndef MIN
19+
#define MIN(a, b) ((a) < (b) ? (a) : (b))
20+
#endif
21+
22+
struct htp_argsort_context {
23+
struct htp_ops_context * octx;
24+
uint32_t nrows_per_thread;
25+
struct fastdiv_values div_ne01;
26+
struct fastdiv_values div_ne02_ne01;
27+
};
28+
29+
// Scalar sort implementation since std::sort is not available.
30+
// Sorts indices based on values.
31+
static void quicksort_indices_asc(int32_t * indices, const float * data, int left, int right) {
32+
if (left >= right) return;
33+
34+
int pivot_idx = indices[(left + right) / 2];
35+
float pivot = data[pivot_idx];
36+
int i = left;
37+
int j = right;
38+
39+
while (i <= j) {
40+
while (data[indices[i]] < pivot) i++;
41+
while (data[indices[j]] > pivot) j--;
42+
if (i <= j) {
43+
int32_t tmp = indices[i];
44+
indices[i] = indices[j];
45+
indices[j] = tmp;
46+
i++;
47+
j--;
48+
}
49+
}
50+
51+
if (left < j) quicksort_indices_asc(indices, data, left, j);
52+
if (i < right) quicksort_indices_asc(indices, data, i, right);
53+
}
54+
55+
static void quicksort_indices_desc(int32_t * indices, const float * data, int left, int right) {
56+
if (left >= right) return;
57+
58+
int pivot_idx = indices[(left + right) / 2];
59+
float pivot = data[pivot_idx];
60+
int i = left;
61+
int j = right;
62+
63+
while (i <= j) {
64+
while (data[indices[i]] > pivot) i++;
65+
while (data[indices[j]] < pivot) j--;
66+
if (i <= j) {
67+
int32_t tmp = indices[i];
68+
indices[i] = indices[j];
69+
indices[j] = tmp;
70+
i++;
71+
j--;
72+
}
73+
}
74+
75+
if (left < j) quicksort_indices_desc(indices, data, left, j);
76+
if (i < right) quicksort_indices_desc(indices, data, i, right);
77+
}
78+
79+
static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
80+
struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
81+
struct htp_ops_context * octx = actx->octx;
82+
83+
// Unpack context
84+
const struct htp_tensor * src0 = &octx->src0;
85+
const struct htp_tensor * dst = &octx->dst;
86+
87+
// Scratchpad memory
88+
uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
89+
90+
// Dimensions
91+
uint32_t ne00 = src0->ne[0];
92+
uint32_t ne01 = src0->ne[1];
93+
uint32_t ne02 = src0->ne[2];
94+
uint32_t ne03 = src0->ne[3];
95+
96+
uint32_t nb01 = src0->nb[1];
97+
uint32_t nb02 = src0->nb[2];
98+
uint32_t nb03 = src0->nb[3];
99+
100+
uint32_t nb1 = dst->nb[1];
101+
uint32_t nb2 = dst->nb[2];
102+
uint32_t nb3 = dst->nb[3];
103+
104+
// Sort order
105+
enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];
106+
107+
// Rows to process
108+
uint32_t total_rows = ne01 * ne02 * ne03;
109+
uint32_t rows_per_thread = actx->nrows_per_thread;
110+
uint32_t start_row = rows_per_thread * i;
111+
uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
112+
113+
// Scratchpad layout:
114+
// We need space for one row of float data (values) and one row of int32 indices.
115+
// values: ne00 * sizeof(float)
116+
// indices: ne00 * sizeof(int32_t)
117+
// Padded to 128 bytes.
118+
119+
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
120+
float * values_buf = (float *) spad;
121+
int32_t * indices_buf = (int32_t *) (spad + values_size);
122+
123+
for (uint32_t r = start_row; r < end_row; r++) {
124+
// Calculate indices for 3D iteration flattened using fastdiv
125+
// uint32_t i03 = r / (ne02 * ne01);
126+
// uint32_t rem = r % (ne02 * ne01);
127+
// uint32_t i02 = rem / ne01;
128+
// uint32_t i01 = rem % ne01;
129+
130+
uint32_t i03 = fastdiv(r, &actx->div_ne02_ne01);
131+
uint32_t rem = fastmodulo(r, ne02 * ne01, &actx->div_ne02_ne01);
132+
uint32_t i02 = fastdiv(rem, &actx->div_ne01);
133+
uint32_t i01 = rem - i02 * ne01;
134+
135+
uint32_t src_offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
136+
uint32_t dst_offset = i03 * nb3 + i02 * nb2 + i01 * nb1;
137+
138+
uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
139+
uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset;
140+
141+
// Prefetch and Copy row data to VTCM
142+
hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
143+
144+
// Use vector copy if available/efficient, handles unaligned
145+
hvx_copy_f32_uu((uint8_t*)values_buf, src_ptr, ne00);
146+
147+
// Initialize indices
148+
for (uint32_t j = 0; j < ne00; j++) {
149+
indices_buf[j] = j;
150+
}
151+
152+
// Sort indices based on values
153+
if (order == GGML_SORT_ORDER_ASC) {
154+
quicksort_indices_asc(indices_buf, values_buf, 0, ne00 - 1);
155+
} else {
156+
quicksort_indices_desc(indices_buf, values_buf, 0, ne00 - 1);
157+
}
158+
159+
// Copy indices back to DDR
160+
// Indices are 32-bit integers, effectively same as float for copy purposes size-wise
161+
hvx_copy_f32_uu(dst_ptr, (const uint8_t *) indices_buf, ne00);
162+
}
163+
}
164+
165+
int op_argsort(struct htp_ops_context * octx) {
166+
// Check supported types
167+
if (octx->src0.type != HTP_TYPE_F32) {
168+
return HTP_STATUS_NO_SUPPORT;
169+
}
170+
171+
// Allocate scratchpad
172+
// We need 1 row of float + 1 row of int32 per thread.
173+
uint32_t ne00 = octx->src0.ne[0];
174+
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
175+
size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
176+
size_t spad_per_thread = values_size + indices_size;
177+
178+
// Make sure we round up to 256 for alignment requirements
179+
spad_per_thread = hex_round_up(spad_per_thread, 256);
180+
181+
size_t total_spad_size = spad_per_thread * octx->n_threads;
182+
183+
if (octx->ctx->vtcm_size < total_spad_size) {
184+
FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
185+
return HTP_STATUS_VTCM_TOO_SMALL;
186+
}
187+
188+
octx->src0_spad.data = octx->ctx->vtcm_base;
189+
octx->src0_spad.size = total_spad_size;
190+
octx->src0_spad.size_per_thread = spad_per_thread;
191+
192+
FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
193+
octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
194+
octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
195+
octx->src0.data, octx->dst.data);
196+
197+
uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
198+
uint32_t n_jobs = MIN(total_rows, octx->n_threads);
199+
200+
struct htp_argsort_context actx;
201+
actx.octx = octx;
202+
actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
203+
// Initialize fastdiv values
204+
actx.div_ne01 = init_fastdiv_values(octx->src0.ne[1]);
205+
actx.div_ne02_ne01 = init_fastdiv_values(octx->src0.ne[2] * octx->src0.ne[1]);
206+
207+
// Run jobs
208+
worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
209+
210+
return HTP_STATUS_OK;
211+
}

ggml/src/ggml-hexagon/htp/htp-msg.h

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ enum htp_op {
6464
HTP_OP_SCALE = 16,
6565
HTP_OP_GET_ROWS = 17,
6666
HTP_OP_CPY = 18,
67+
HTP_OP_ARGSORT = 19,
6768
INVALID
6869
};
6970

@@ -103,22 +104,6 @@ static inline size_t htp_type_nbytes(uint32_t t) {
103104
return 0;
104105
}
105106

106-
static const char * htp_type_name(uint32_t t) {
107-
switch (t) {
108-
case HTP_TYPE_F32:
109-
return "fp32";
110-
case HTP_TYPE_F16:
111-
return "fp16";
112-
case HTP_TYPE_Q4_0:
113-
return "q4_0";
114-
case HTP_TYPE_Q8_0:
115-
return "q8_0";
116-
case HTP_TYPE_MXFP4:
117-
return "mxfp4";
118-
}
119-
return 0;
120-
}
121-
122107
// Internal types
123108
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
124109
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks

ggml/src/ggml-hexagon/htp/htp-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx);
9898
int op_set_rows(struct htp_ops_context * octx);
9999
int op_get_rows(struct htp_ops_context * octx);
100100
int op_cpy(struct htp_ops_context * octx);
101+
int op_argsort(struct htp_ops_context * octx);
101102

102103
#endif /* HTP_OPS_H */

ggml/src/ggml-hexagon/htp/hvx-copy.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,6 @@ static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restr
136136
dst_type * restrict vdst = (dst_type *) dst; \
137137
src_type * restrict vsrc = (src_type *) src; \
138138
\
139-
const HVX_Vector zero = Q6_V_vsplat_R(0); \
140-
\
141139
const uint32_t elem_size = sizeof(__fp16); \
142140
const uint32_t epv = 128 / elem_size; \
143141
const uint32_t nvec = n / epv; \

ggml/src/ggml-hexagon/htp/main.c

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,45 @@ static void proc_matmul_req(struct htp_context * ctx,
440440
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
441441
}
442442

443+
static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
444+
struct dspqueue_buffer rsp_bufs[1];
445+
446+
// We had written to the output buffer, we'd also need to flush it
447+
rsp_bufs[0].fd = bufs[1].fd;
448+
rsp_bufs[0].ptr = bufs[1].ptr;
449+
rsp_bufs[0].offset = bufs[1].offset;
450+
rsp_bufs[0].size = bufs[1].size;
451+
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
452+
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
453+
454+
// Setup Op context
455+
struct htp_ops_context octx = { 0 };
456+
octx.ctx = ctx;
457+
octx.src0 = req->src0;
458+
octx.dst = req->dst;
459+
octx.flags = req->flags;
460+
octx.op = req->op;
461+
462+
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
463+
464+
// Update data pointers
465+
octx.src0.data = (uint32_t) bufs[0].ptr;
466+
octx.dst.data = (uint32_t) bufs[1].ptr;
467+
octx.n_threads = ctx->n_threads;
468+
469+
struct profile_data prof;
470+
profile_start(&prof);
471+
472+
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
473+
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
474+
rsp_status = op_argsort(&octx);
475+
vtcm_release(ctx);
476+
}
477+
478+
profile_stop(&prof);
479+
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
480+
}
481+
443482
static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
444483
struct dspqueue_buffer rsp_bufs[1];
445484

@@ -1035,6 +1074,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
10351074
proc_cpy_req(ctx, &req, bufs);
10361075
break;
10371076

1077+
case HTP_OP_ARGSORT:
1078+
if (n_bufs != 2) {
1079+
FARF(ERROR, "Bad argsort-req buffer list");
1080+
continue;
1081+
}
1082+
proc_argsort_req(ctx, &req, bufs);
1083+
break;
1084+
10381085
default:
10391086
FARF(ERROR, "Unknown Op %u", req.op);
10401087
break;

0 commit comments

Comments
 (0)