Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions ggml/src/ggml-hexagon/ggml-hexagon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2764,6 +2764,21 @@ static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session *
return true;
}

static bool ggml_hexagon_supported_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {

const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * dst = op;

if (src0->type != GGML_TYPE_F32) { return false; }
if (dst->type != GGML_TYPE_F32) { return false; }
if (!ggml_are_same_shape(src0, dst)) { return false; }
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { return false; }

return true;

GGML_UNUSED(sess);
}

static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
return sess->c_name();
Expand Down Expand Up @@ -2803,6 +2818,8 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
case GGML_OP_FILL: return HTP_OP_FILL;
case GGML_OP_DIAG: return HTP_OP_DIAG;
case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI;
case GGML_OP_TRI: return HTP_OP_TRI;

case GGML_OP_UNARY:
switch (ggml_get_unary_op(t)) {
case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU;
Expand Down Expand Up @@ -3351,6 +3368,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
case GGML_OP_SOLVE_TRI:
supp = ggml_hexagon_supported_solve_tri(sess, op);
break;

case GGML_OP_TRI:
supp = ggml_hexagon_supported_tri(sess, op);
break;

default:
break;
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-hexagon/htp/htp-ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,6 @@ int op_cumsum(struct htp_ops_context * octx);
int op_fill(struct htp_ops_context * octx);
int op_diag(struct htp_ops_context * octx);
int op_solve_tri(struct htp_ops_context * octx);
int op_tri(struct htp_ops_context * octx);

#endif /* HTP_CTX_H */
2 changes: 2 additions & 0 deletions ggml/src/ggml-hexagon/htp/htp-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ enum htp_op_code {
HTP_OP_FILL,
HTP_OP_DIAG,
HTP_OP_SOLVE_TRI,
HTP_OP_TRI,

HTP_OP_INVALID
};

Expand Down
3 changes: 3 additions & 0 deletions ggml/src/ggml-hexagon/htp/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,9 @@ static int execute_op(struct htp_ops_context * octx) {
case HTP_OP_SOLVE_TRI:
return op_solve_tri(octx);

case HTP_OP_TRI:
return op_tri(octx);

case HTP_OP_INVALID:
break;

Expand Down
114 changes: 112 additions & 2 deletions ggml/src/ggml-hexagon/htp/unary-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-ops.h"
#include "htp-ops.h"

struct htp_unary_context {
struct htp_ops_context * octx;
Expand Down Expand Up @@ -277,6 +276,95 @@ static void sigmoid_f32(const float * restrict src,
}
}

static void tri_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
const uint32_t ir,
const struct htp_unary_context * uctx) {

const int32_t ttype = op_params[0];
const HVX_Vector zero = hvx_vec_splat_f32(0.0f);
const uint32_t nvec = row_elems / VLEN_FP32;
const uint32_t nloe = row_elems % VLEN_FP32;

const uint32_t ne01 = uctx->octx->src[0]->ne[1];

for (uint32_t b = 0; b < num_rows; b++) {
const uint32_t abs_row = ir + b;
const uint32_t i01 = abs_row % ne01;

const HVX_Vector * restrict v_src = (const HVX_Vector *) ((const uint8_t *) src + b * row_size);
HVX_Vector * restrict v_dst = (HVX_Vector *) ((uint8_t *) dst + b * row_size);

uint32_t boundary;
int keep_left;
switch (ttype) {
case 0: boundary = i01; keep_left = 0; break; // keep col >= row
case 1: boundary = i01 + 1; keep_left = 0; break; // keep col > row
case 2: boundary = i01 + 1; keep_left = 1; break; // keep col <= row
case 3: boundary = i01; keep_left = 1; break; // keep col < row
default: boundary = 0; keep_left = 0; break;
}
if (boundary > row_elems) boundary = row_elems;

// Full HVX vectors — each starts at a 128-byte aligned offset
for (uint32_t i = 0; i < nvec; i++) {
const uint32_t vec_start = i * VLEN_FP32;
const uint32_t vec_end = vec_start + VLEN_FP32;
if (keep_left) {
if (vec_end <= boundary) {
v_dst[i] = v_src[i];
} else if (vec_start >= boundary) {
v_dst[i] = zero;
} else {
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
v_dst[i] = Q6_V_vmux_QVV(mask, v_src[i], zero);
}
} else {
if (vec_end <= boundary) {
v_dst[i] = zero;
} else if (vec_start >= boundary) {
v_dst[i] = v_src[i];
} else {
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
v_dst[i] = Q6_V_vmux_QVV(mask, zero, v_src[i]);
}
}
}

// Tail elements (row_elems not a multiple of VLEN_FP32)
if (nloe > 0) {
const uint32_t vec_start = nvec * VLEN_FP32;
const uint32_t vec_end = vec_start + nloe;
HVX_Vector tail_val;
if (keep_left) {
if (vec_end <= boundary) {
tail_val = v_src[nvec];
} else if (vec_start >= boundary) {
tail_val = zero;
} else {
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
tail_val = Q6_V_vmux_QVV(mask, v_src[nvec], zero);
}
} else {
if (vec_end <= boundary) {
tail_val = zero;
} else if (vec_start >= boundary) {
tail_val = v_src[nvec];
} else {
HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float));
tail_val = Q6_V_vmux_QVV(mask, zero, v_src[nvec]);
}
}
hvx_vec_store_a(&v_dst[nvec], nloe * sizeof(float), tail_val);
}
}
}

static void softplus_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
Expand Down Expand Up @@ -402,6 +490,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
case HTP_OP_UNARY_SOFTPLUS:
softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
case HTP_OP_TRI:
tri_f32(src0_spad, dst_spad, NULL, block_size, ne00, src0_row_size_aligned, op_params, ir, uctx);
break;
default:
break;
}
Expand Down Expand Up @@ -469,6 +560,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
case HTP_OP_UNARY_SOFTPLUS:
op_type = "softplus-f32";
break;
case HTP_OP_TRI:
op_type = "tri-f32";
break;

default:
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
Expand Down Expand Up @@ -532,13 +626,29 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
.block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned,
.nc = src0->ne[0],
};

worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads);
}

return err;
}

int op_tri(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;

switch (octx->src[0]->type) {
case HTP_TYPE_F32:
err = execute_op_unary_f32(octx);
break;

default:
err = HTP_STATUS_NO_SUPPORT;
break;
}

return err;
}

int op_unary(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;

Expand Down