|
| 1 | +#include <string.h> |
| 2 | +#include <stdlib.h> |
| 3 | +#include <math.h> |
| 4 | +#include <HAP_farf.h> |
| 5 | +#include <HAP_perf.h> |
| 6 | + |
| 7 | +#define GGML_COMMON_DECL_C |
| 8 | +#include "ggml-common.h" |
| 9 | +#include "ggml.h" |
| 10 | + |
| 11 | +#include "hvx-utils.h" |
| 12 | +#include "hex-dma.h" |
| 13 | + |
| 14 | +#include "htp-ctx.h" |
| 15 | +#include "htp-msg.h" |
| 16 | +#include "htp-ops.h" |
| 17 | + |
| 18 | +#ifndef MIN |
| 19 | +#define MIN(a, b) ((a) < (b) ? (a) : (b)) |
| 20 | +#endif |
| 21 | + |
// Shared, read-only state handed to every argsort worker job via the
// worker pool's opaque data pointer.
struct htp_argsort_context {
    struct htp_ops_context * octx;   // op context: tensors, scratchpads, op_params, thread count
    uint32_t nrows_per_thread;       // rows assigned to each worker job (last job may get fewer)
    // Precomputed magic-number divisors used to decompose the flat row index
    // into (i01, i02, i03) coordinates without hardware division.
    struct fastdiv_values div_ne01;       // divider for ne01
    struct fastdiv_values div_ne02_ne01;  // divider for ne02 * ne01
};
| 28 | + |
// Scalar sort implementation since std::sort is not available.
// Sorts `indices[left..right]` (inclusive) so the referenced `data` values
// are in ascending order. `data` itself is not modified.
//
// Hoare partition around the middle element's value. To keep worst-case
// stack usage at O(log n) — important on DSP threads with small stacks —
// we recurse only into the smaller partition and iterate on the larger
// one, instead of recursing into both.
static void quicksort_indices_asc(int32_t * indices, const float * data, int left, int right) {
    while (left < right) {
        float pivot = data[indices[(left + right) / 2]];
        int i = left;
        int j = right;

        while (i <= j) {
            while (data[indices[i]] < pivot) i++;
            while (data[indices[j]] > pivot) j--;
            if (i <= j) {
                int32_t tmp = indices[i];
                indices[i] = indices[j];
                indices[j] = tmp;
                i++;
                j--;
            }
        }

        // Recurse into the smaller side, loop on the larger side.
        if (j - left < right - i) {
            if (left < j) quicksort_indices_asc(indices, data, left, j);
            left = i;  // continue with [i, right]
        } else {
            if (i < right) quicksort_indices_asc(indices, data, i, right);
            right = j;  // continue with [left, j]
        }
    }
}
| 54 | + |
// Same as quicksort_indices_asc but sorts so the referenced `data` values
// are in descending order. `data` itself is not modified.
//
// Recurses only into the smaller partition and iterates on the larger one,
// bounding worst-case stack depth to O(log n) (naive two-sided recursion is
// O(n) deep on already-sorted rows, which can overflow small DSP stacks).
static void quicksort_indices_desc(int32_t * indices, const float * data, int left, int right) {
    while (left < right) {
        float pivot = data[indices[(left + right) / 2]];
        int i = left;
        int j = right;

        while (i <= j) {
            while (data[indices[i]] > pivot) i++;
            while (data[indices[j]] < pivot) j--;
            if (i <= j) {
                int32_t tmp = indices[i];
                indices[i] = indices[j];
                indices[j] = tmp;
                i++;
                j--;
            }
        }

        // Recurse into the smaller side, loop on the larger side.
        if (j - left < right - i) {
            if (left < j) quicksort_indices_desc(indices, data, left, j);
            left = i;  // continue with [i, right]
        } else {
            if (i < right) quicksort_indices_desc(indices, data, i, right);
            right = j;  // continue with [left, j]
        }
    }
}
| 78 | + |
// Worker-pool job: argsort a contiguous range of rows of an F32 tensor.
// For each row, the ne00 float values are staged into this thread's VTCM
// scratch slice, an index array [0..ne00) is sorted by those values, and
// the resulting int32 indices are written back to dst.
//
// n    - total number of jobs launched (unused here)
// i    - this job's index; selects both the row range and the scratch slice
// data - pointer to the shared htp_argsort_context
static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
    struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
    struct htp_ops_context * octx = actx->octx;

    // Unpack context
    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * dst = &octx->dst;

    // Per-thread slice of the VTCM scratchpad (laid out by op_argsort).
    uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;

    // Dimensions (elements)
    uint32_t ne00 = src0->ne[0];
    uint32_t ne01 = src0->ne[1];
    uint32_t ne02 = src0->ne[2];
    uint32_t ne03 = src0->ne[3];

    // Source strides (bytes)
    uint32_t nb01 = src0->nb[1];
    uint32_t nb02 = src0->nb[2];
    uint32_t nb03 = src0->nb[3];

    // Destination strides (bytes)
    uint32_t nb1 = dst->nb[1];
    uint32_t nb2 = dst->nb[2];
    uint32_t nb3 = dst->nb[3];

    // Sort order (ASC/DESC) is carried in the op's first parameter.
    enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];

    // This job handles rows [start_row, end_row) of the flattened
    // (i01, i02, i03) row space; the final job may get a short range.
    uint32_t total_rows = ne01 * ne02 * ne03;
    uint32_t rows_per_thread = actx->nrows_per_thread;
    uint32_t start_row = rows_per_thread * i;
    uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);

    // Scratchpad layout (per thread):
    //   values:  ne00 * sizeof(float),   padded to 128 bytes
    //   indices: ne00 * sizeof(int32_t), immediately after the values
    size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
    float * values_buf = (float *) spad;
    int32_t * indices_buf = (int32_t *) (spad + values_size);

    for (uint32_t r = start_row; r < end_row; r++) {
        // Decompose the flat row index r using fastdiv; equivalent to:
        //   i03 = r / (ne02 * ne01);  rem = r % (ne02 * ne01);
        //   i02 = rem / ne01;         i01 = rem % ne01;
        uint32_t i03 = fastdiv(r, &actx->div_ne02_ne01);
        uint32_t rem = fastmodulo(r, ne02 * ne01, &actx->div_ne02_ne01);
        uint32_t i02 = fastdiv(rem, &actx->div_ne01);
        uint32_t i01 = rem - i02 * ne01;

        // Byte offsets of this row in src and dst
        uint32_t src_offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
        uint32_t dst_offset = i03 * nb3 + i02 * nb2 + i01 * nb1;

        uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
        uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset;

        // Prefetch the row into L2 ahead of the copy.
        // NOTE(review): assumes hex_l2fetch takes (addr, width, stride, nlines)
        // in bytes for a single contiguous row — confirm against hex-dma.h.
        hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);

        // Stage the row into VTCM; the _uu variant handles unaligned src/dst.
        hvx_copy_f32_uu((uint8_t*)values_buf, src_ptr, ne00);

        // Initialize indices to the identity permutation [0, ne00)
        for (uint32_t j = 0; j < ne00; j++) {
            indices_buf[j] = j;
        }

        // Sort indices based on values (the quicksort helpers handle
        // ne00 == 0/1 via their left >= right early-out).
        if (order == GGML_SORT_ORDER_ASC) {
            quicksort_indices_asc(indices_buf, values_buf, 0, ne00 - 1);
        } else {
            quicksort_indices_desc(indices_buf, values_buf, 0, ne00 - 1);
        }

        // Copy the int32 indices back to DDR. Reuses the f32 copy routine,
        // which is valid here because int32_t and float are both 4 bytes and
        // the copy is presumably a raw 32-bit element move — TODO confirm
        // hvx_copy_f32_uu makes no float-specific assumptions.
        hvx_copy_f32_uu(dst_ptr, (const uint8_t *) indices_buf, ne00);
    }
}
| 164 | + |
| 165 | +int op_argsort(struct htp_ops_context * octx) { |
| 166 | + // Check supported types |
| 167 | + if (octx->src0.type != HTP_TYPE_F32) { |
| 168 | + return HTP_STATUS_NO_SUPPORT; |
| 169 | + } |
| 170 | + |
| 171 | + // Allocate scratchpad |
| 172 | + // We need 1 row of float + 1 row of int32 per thread. |
| 173 | + uint32_t ne00 = octx->src0.ne[0]; |
| 174 | + size_t values_size = hex_round_up(ne00 * sizeof(float), 128); |
| 175 | + size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128); |
| 176 | + size_t spad_per_thread = values_size + indices_size; |
| 177 | + |
| 178 | + // Make sure we round up to 256 for alignment requirements |
| 179 | + spad_per_thread = hex_round_up(spad_per_thread, 256); |
| 180 | + |
| 181 | + size_t total_spad_size = spad_per_thread * octx->n_threads; |
| 182 | + |
| 183 | + if (octx->ctx->vtcm_size < total_spad_size) { |
| 184 | + FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size); |
| 185 | + return HTP_STATUS_VTCM_TOO_SMALL; |
| 186 | + } |
| 187 | + |
| 188 | + octx->src0_spad.data = octx->ctx->vtcm_base; |
| 189 | + octx->src0_spad.size = total_spad_size; |
| 190 | + octx->src0_spad.size_per_thread = spad_per_thread; |
| 191 | + |
| 192 | + FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", |
| 193 | + octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3], |
| 194 | + octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3], |
| 195 | + octx->src0.data, octx->dst.data); |
| 196 | + |
| 197 | + uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; |
| 198 | + uint32_t n_jobs = MIN(total_rows, octx->n_threads); |
| 199 | + |
| 200 | + struct htp_argsort_context actx; |
| 201 | + actx.octx = octx; |
| 202 | + actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs; |
| 203 | + // Initialize fastdiv values |
| 204 | + actx.div_ne01 = init_fastdiv_values(octx->src0.ne[1]); |
| 205 | + actx.div_ne02_ne01 = init_fastdiv_values(octx->src0.ne[2] * octx->src0.ne[1]); |
| 206 | + |
| 207 | + // Run jobs |
| 208 | + worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs); |
| 209 | + |
| 210 | + return HTP_STATUS_OK; |
| 211 | +} |
0 commit comments