|
10 | 10 | #define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H |
11 | 11 |
|
12 | 12 | #include "gc/Dialect/Linalgx/LinalgxOps.h" |
| 13 | +#include "mlir/Dialect/DLTI/DLTI.h" |
13 | 14 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
14 | | -#include <cstring> |
| 15 | +#include "mlir/Interfaces/DataLayoutInterfaces.h" |
15 | 16 |
|
16 | 17 | namespace mlir { |
17 | 18 | namespace gc { |
18 | 19 |
|
19 | 20 | using namespace mlir; |
20 | 21 |
|
21 | | -// A mock for the taget information |
22 | | -// TODO: replace it with upstream hardware description model |
23 | 22 | struct SystemDesc { |
24 | | - |
25 | | - static int getPositiveIntFromStr(char *str, int defaultValue = 1) { |
26 | | - if (!str || strlen(str) == 0 || str[0] > '9' || str[0] < '0') { |
27 | | - return defaultValue; |
28 | | - } |
29 | | - auto val = std::stoi(str); |
30 | | - return val > 0 ? val : defaultValue; |
31 | | - } |
32 | | - |
33 | 23 | // get runtime OMP_NUM_THREADS |
34 | 24 | uint32_t getNumThreads() { |
35 | | - char *numThreads = getenv("OMP_NUM_THREADS"); |
36 | | - return getPositiveIntFromStr(numThreads, 1); |
| 25 | + std::optional<Attribute> numThreads = layout.getDevicePropertyValue( |
| 26 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 27 | + Builder(ctx).getStringAttr("num_threads")); |
| 28 | + if (numThreads && isa<IntegerAttr>(*numThreads)) |
| 29 | + { |
| 30 | + return dyn_cast<IntegerAttr>(*numThreads).getInt(); |
| 31 | + } |
| 32 | + return 1; |
37 | 33 | } |
38 | 34 | // get cache size by cacheLevel |
39 | 35 | size_t getCacheSize(uint8_t cacheLevel) { |
40 | 36 | if (cacheLevel == 1) { |
41 | | - char *cacheSize = getenv("L1_CACHE_SIZE"); |
42 | | - return getPositiveIntFromStr(cacheSize, 0); |
| 37 | + std::optional<Attribute> cacheSize = layout.getDevicePropertyValue( |
| 38 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 39 | + Builder(ctx).getStringAttr("L1_cache_size_in_bytes")); |
| 40 | + if (cacheSize && isa<IntegerAttr>(*cacheSize)) |
| 41 | + { |
| 42 | + return dyn_cast<IntegerAttr>(*cacheSize).getInt(); |
| 43 | + } |
43 | 44 | } else if (cacheLevel == 2) { |
44 | | - char *cacheSize = getenv("L2_CACHE_SIZE"); |
45 | | - return getPositiveIntFromStr(cacheSize, 0); |
| 45 | + std::optional<Attribute> cacheSize = layout.getDevicePropertyValue( |
| 46 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 47 | + Builder(ctx).getStringAttr("L2_cache_size_in_bytes")); |
| 48 | + if (cacheSize && isa<IntegerAttr>(*cacheSize)) |
| 49 | + { |
| 50 | + return dyn_cast<IntegerAttr>(*cacheSize).getInt(); |
| 51 | + } |
46 | 52 | } else if (cacheLevel == 3) { |
47 | | - char *cacheSize = getenv("L3_CACHE_SIZE"); |
48 | | - return getPositiveIntFromStr(cacheSize, 0); |
| 53 | + std::optional<Attribute> cacheSize = layout.getDevicePropertyValue( |
| 54 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 55 | + Builder(ctx).getStringAttr("L3_cache_size_in_bytes")); |
| 56 | + if (cacheSize && isa<IntegerAttr>(*cacheSize)) { |
| 57 | + return dyn_cast<IntegerAttr>(*cacheSize).getInt(); |
| 58 | + } |
49 | 59 | } |
50 | 60 | return 0; |
51 | 61 | } |
52 | 62 |
|
53 | 63 | // get the maximum vector length in bits |
54 | 64 | size_t getMaxVectorLength() { |
55 | | - char *maxVectorLanes = getenv("MAX_VECTOR_LENGTH"); |
56 | | - return getPositiveIntFromStr(maxVectorLanes, 512); |
| 65 | + std::optional<Attribute> maxVectorLength = layout.getDevicePropertyValue( |
| 66 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 67 | + Builder(ctx).getStringAttr("max_vector_width")); |
| 68 | + if (maxVectorLength && isa<IntegerAttr>(*maxVectorLength)) |
| 69 | + { |
| 70 | + return dyn_cast<IntegerAttr>(*maxVectorLength).getInt(); |
| 71 | + } |
| 72 | + return 512; |
57 | 73 | } |
| 74 | + |
| 75 | + SystemDesc(ModuleOp m) : layout(m), ctx(m->getContext()) {} |
| 76 | + |
| 77 | + private: |
| 78 | + DataLayout layout; |
| 79 | + MLIRContext *ctx; |
58 | 80 | }; |
59 | 81 |
|
// The configuration for matmul tiling.
// All fields default to 0 ("unset") so a value-initialized config is
// well-defined instead of carrying indeterminate garbage.
// TODO: support batch matmul
struct MatmulConfig {
  // The number of threads distributed to M, N, K
  uint32_t MThreads = 0, NThreads = 0, KThreads = 0;
  // The outer block size for M, N, K which will be used to decide the loop tile
  // size in single thread
  uint32_t MBlock = 0, NBlock = 0, KBlock = 0;
  // The innermost block size for M, N, K which will be directly converted to
  // brgemm.
  uint32_t innerMostMBlock = 0, innerMostNBlock = 0, innerMostKBlock = 0;
};
72 | 94 |
|
// Semantic role of a matmul iteration dimension: Batch, the parallel
// M/N dimensions, or the K reduction dimension.
enum DimType { Batch, M, N, K };
|
0 commit comments