-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtensor.h
More file actions
68 lines (53 loc) · 2.41 KB
/
tensor.h
File metadata and controls
68 lines (53 loc) · 2.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#ifndef TENSOR_H
#define TENSOR_H

/*
 * tensor.h - public interface of the Tensor module: an opaque N-dimensional
 * tensor of doubles with constructors, element access, element-wise
 * arithmetic, reductions, and linear-algebra / shape operations.
 *
 * NOTE(review): this header does not state ownership or error conventions.
 * Functions returning `Tensor *` presumably return a newly allocated tensor
 * owned by the caller and released with tensorDestroy(); how failures are
 * reported (NULL return, abort, ...) is not visible here. Confirm against
 * the implementation and document it.
 */

#include <stdlib.h> /* size_t */
#include <stdint.h> /* not needed by these declarations; kept because includers may rely on it */

#ifdef __cplusplus
extern "C" {
#endif

/** Opaque tensor handle; its layout is private to the implementation. */
typedef struct Tensor Tensor;

/*
 * DynamicArray belongs to another module of this project. It is forward
 * declared so this header compiles on its own instead of depending on the
 * includer having declared it first.
 * NOTE(review): this assumes the project declares it as
 * `typedef struct DynamicArray DynamicArray;` (C11 allows repeating an
 * identical typedef). If it is instead a typedef of an anonymous struct,
 * replace this line with an #include of that module's header.
 */
typedef struct DynamicArray DynamicArray;

/* ---------------------------------------------------------------------------
 * Construction and destruction
 * ------------------------------------------------------------------------- */

/**
 * Create a tensor with `ndim` dimensions whose extents are shape[0..ndim-1].
 * Initial element values are not documented here - TODO confirm.
 */
Tensor *tensorCreate(size_t ndim, const size_t *shape);

/**
 * Create a tensor of the given shape from the values held in `src`.
 * NOTE(review): unlike the *FromData helpers below this takes a DynamicArray,
 * not a raw double array. Whether `src` is copied, and how its length must
 * relate to the product of `shape`, is not visible here - confirm.
 */
Tensor *tensorCreateFromData(size_t ndim, const size_t *shape, const DynamicArray *src);

/** Create a tensor holding the single value `value` (rank not visible here - presumably 0-D; confirm). */
Tensor *tensorScalar(double value);

/** Create a 1-D tensor with `n` elements. */
Tensor *tensorVector(size_t n);

/** Create a 1-D tensor from data[0..n-1] (`data` is const, so presumably copied - confirm). */
Tensor *tensorVectorFromData(size_t n, const double *data);

/** Create a 2-D tensor with `rows` x `cols` elements. */
Tensor *tensorMatrix(size_t rows, size_t cols);

/**
 * Create a 2-D tensor from rows*cols values in `data`.
 * NOTE(review): element order (presumably row-major) is not visible here - confirm.
 */
Tensor *tensorMatrixFromData(size_t rows, size_t cols, const double *data);

/** Release a tensor. Whether NULL is accepted is not visible here - TODO confirm. */
void tensorDestroy(Tensor *t);

/* ---------------------------------------------------------------------------
 * Element access and printing
 * ------------------------------------------------------------------------- */

/** Read one element; `indices` presumably holds one index per dimension. */
double tensorGet(const Tensor *t, const size_t *indices);

/** Write one element; `indices` presumably holds one index per dimension. */
void tensorSet(Tensor *t, const size_t *indices, double value);

/** Print `t` (output stream and format are not visible here). */
void tensorPrint(const Tensor *t);

/* ---------------------------------------------------------------------------
 * Element-wise tensor-tensor arithmetic.
 * Each returns a new tensor. Shape requirements (exact match vs. broadcasting)
 * are not visible here - TODO confirm.
 * ------------------------------------------------------------------------- */

Tensor *tensorAdd(const Tensor *a, const Tensor *b);
Tensor *tensorSub(const Tensor *a, const Tensor *b);
/** Element-wise (Hadamard) product. */
Tensor *tensorHadamardMul(const Tensor *a, const Tensor *b);
/** Element-wise quotient a / b. */
Tensor *tensorHadamardDiv(const Tensor *a, const Tensor *b);

/* ---------------------------------------------------------------------------
 * Tensor-scalar arithmetic: returns a new tensor with `scalar` combined with
 * every element of `t` (presumably t[i] op scalar, with t on the left).
 * ------------------------------------------------------------------------- */

Tensor *tensorAddScalar(const Tensor *t, double scalar);
Tensor *tensorMulScalar(const Tensor *t, double scalar);
Tensor *tensorSubScalar(const Tensor *t, double scalar);
Tensor *tensorDivScalar(const Tensor *t, double scalar);

/* ---------------------------------------------------------------------------
 * Callback types
 * ------------------------------------------------------------------------- */

/** Combines two doubles into one. */
typedef double (*BinaryOp)(double, double);
/**
 * Same signature as BinaryOp.
 * NOTE(review): no declaration in this header takes a ScalarOp; it may be
 * vestigial. Kept for source compatibility.
 */
typedef double (*ScalarOp)(double, double);
/** Maps one double to another. */
typedef double (*UnaryOp)(double);
/** Combines the running accumulator `acc` with the next element `x`. */
typedef double (*ReduceOp)(double acc, double x);

/* ---------------------------------------------------------------------------
 * Higher-order operations
 * ------------------------------------------------------------------------- */

/** Return a new tensor with `op` applied to each element of `t`. */
Tensor *tensorApply(const Tensor *t, UnaryOp op);

/**
 * Fold the elements of `t` with `op`, starting from `init`.
 * Traversal order is not visible here - TODO confirm.
 */
double tensorReduce(const Tensor *t, ReduceOp op, double init);

/* ---------------------------------------------------------------------------
 * Reductions to a single double.
 * Behaviour on an empty tensor is not visible here - TODO confirm.
 * ------------------------------------------------------------------------- */

double tensorSum(const Tensor *t);
double tensorProd(const Tensor *t);
/** Euclidean (L2) norm over all elements. */
double tensorL2Norm(const Tensor *t);
double tensorMin(const Tensor *t);
double tensorMax(const Tensor *t);
double tensorMean(const Tensor *t);
/** Truth value returned as a double (presumably 1.0 / 0.0 - confirm). */
double tensorAll(const Tensor *t);
/** Truth value returned as a double (presumably 1.0 / 0.0 - confirm). */
double tensorAny(const Tensor *t);

/* ---------------------------------------------------------------------------
 * Element-wise unary maps; each returns a new tensor.
 * ------------------------------------------------------------------------- */

Tensor *tensorNeg(const Tensor *t);
Tensor *tensorAbs(const Tensor *t);
Tensor *tensorExp(const Tensor *t);
Tensor *tensorLog(const Tensor *t);

/* ---------------------------------------------------------------------------
 * Index of the extreme element.
 * NOTE(review): presumably a flat (linear) index over all elements; the
 * tie-breaking rule is not visible here - confirm.
 * ------------------------------------------------------------------------- */

size_t tensorArgMin(const Tensor *t);
size_t tensorArgMax(const Tensor *t);

/* ---------------------------------------------------------------------------
 * Linear algebra
 * ------------------------------------------------------------------------- */

/**
 * NOTE(review): how this differs from tensorMatMul / tensorHadamardMul is not
 * visible here - TODO document.
 */
Tensor *tensorMul(const Tensor *a, const Tensor *b);
/** Dot product; note the result is returned as a Tensor, not a double. */
Tensor *tensorDot(const Tensor *a, const Tensor *b);
/** Matrix product of `a` and `b`. */
Tensor *tensorMatMul(const Tensor *a, const Tensor *b);
/** Matrix product over a batch (presumably leading dimensions are batch axes - confirm). */
Tensor *tensorBatchedMatMul(const Tensor *a, const Tensor *b);

/* ---------------------------------------------------------------------------
 * Shape manipulation; each returns a new tensor.
 * ------------------------------------------------------------------------- */

/** Swap dimensions `axis1` and `axis2`. */
Tensor *tensorTranspose(const Tensor *t, size_t axis1, size_t axis2);
/** Transpose of a 2-D tensor (presumably tensorTranspose(t, 0, 1) - confirm). */
Tensor *tensorTranspose2D(const Tensor *t);
/** Reinterpret `t` with `new_ndim` dimensions of extents new_shape[0..new_ndim-1]. */
Tensor *tensorReshape(const Tensor *t, size_t new_ndim, const size_t *new_shape);
/**
 * Like tensorReshape, but presumably one extent may be left for the function
 * to infer. The marker value for that extent is not visible here - TODO document.
 */
Tensor *tensorReshapeInfer(const Tensor *t, size_t new_ndim, const size_t *new_shape);

#ifdef __cplusplus
}
#endif

#endif /* TENSOR_H */