Skip to content

Commit 966bda6

Browse files
veblushTFLM-bot
andauthored
Added UINT4 support (Sync from github.com/tensorflow/tensorflow) (#3283)
* Added UINT4 support * Sync from upstream TF. --------- Co-authored-by: TFLM-bot <tflm-github-bot@google.com>
1 parent 0efdd51 commit 966bda6

10 files changed

Lines changed: 31 additions & 10 deletions

File tree

python/tflite_micro/numpy_utils.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
5555
return NPY_INT16;
5656
case kTfLiteUInt8:
5757
return NPY_UINT8;
58+
case kTfLiteUInt4:
59+
// TODO(b/246806634): NPY_UINT4 currently doesn't exist
60+
return NPY_BYTE;
5861
case kTfLiteInt4:
5962
// TODO(b/246806634): NPY_INT4 currently doesn't exist
6063
return NPY_BYTE;

tensorflow/compiler/mlir/lite/core/c/tflite_types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ typedef enum {
6565
kTfLiteInt4 = 18,
6666
kTfLiteBFloat16 = 19,
6767
kTfLiteInt2 = 20,
68+
kTfLiteUInt4 = 21,
6869
} TfLiteType;
6970
// LINT.ThenChange(//tensorflow/lite/profiling/proto/model_runtime_info.proto:EdgeDataType)
7071

tensorflow/compiler/mlir/lite/schema/schema.fbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ enum TensorType : byte {
6262
INT4 = 17,
6363
BFLOAT16 = 18,
6464
INT2 = 19,
65+
UINT4 = 20,
6566
}
6667

6768
// Custom quantization parameters for experimenting with new quantization

tensorflow/lite/core/api/flatbuffer_conversions.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
10981098
case TensorType_INT2:
10991099
*type = kTfLiteInt2;
11001100
return kTfLiteOk;
1101+
case TensorType_UINT4:
1102+
*type = kTfLiteUInt4;
1103+
return kTfLiteOk;
11011104
default:
11021105
*type = kTfLiteNoType;
11031106
TF_LITE_REPORT_ERROR(error_reporter,

tensorflow/lite/core/c/common.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,8 @@ const char* TfLiteTypeGetName(TfLiteType type) {
511511
return "INT4";
512512
case kTfLiteInt2:
513513
return "INT2";
514+
case kTfLiteUInt4:
515+
return "UINT4";
514516
}
515517
return "Unknown type";
516518
}

tensorflow/lite/micro/tools/layer_by_layer.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ TfLiteStatus ConvertTensorType(TfLiteType type, TensorTypes& tensor_type) {
117117
case kTfLiteVariant:
118118
tensor_type = TensorTypes_VARIANT;
119119
return kTfLiteOk;
120+
case kTfLiteUInt4:
121+
tensor_type = TensorTypes_UINT4;
122+
return kTfLiteOk;
120123
case kTfLiteInt4:
121124
tensor_type = TensorTypes_INT4;
122125
return kTfLiteOk;

tensorflow/lite/micro/tools/layer_by_layer_schema.fbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ enum TensorTypes : byte {
3636
INT4 = 17,
3737
BFLOAT16 = 18,
3838
INT2 = 19,
39+
UINT4 = 20,
3940
}
4041

4142
table TensorData {

tensorflow/lite/micro/tools/layer_by_layer_schema_generated.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,12 @@ enum TensorTypes : int8_t {
6060
TensorTypes_INT4 = 17,
6161
TensorTypes_BFLOAT16 = 18,
6262
TensorTypes_INT2 = 19,
63+
TensorTypes_UINT4 = 20,
6364
TensorTypes_MIN = TensorTypes_FLOAT32,
64-
TensorTypes_MAX = TensorTypes_INT2
65+
TensorTypes_MAX = TensorTypes_UINT4
6566
};
6667

67-
inline const TensorTypes (&EnumValuesTensorTypes())[20] {
68+
inline const TensorTypes (&EnumValuesTensorTypes())[21] {
6869
static const TensorTypes values[] = {
6970
TensorTypes_FLOAT32,
7071
TensorTypes_FLOAT16,
@@ -85,13 +86,14 @@ inline const TensorTypes (&EnumValuesTensorTypes())[20] {
8586
TensorTypes_UINT16,
8687
TensorTypes_INT4,
8788
TensorTypes_BFLOAT16,
88-
TensorTypes_INT2
89+
TensorTypes_INT2,
90+
TensorTypes_UINT4
8991
};
9092
return values;
9193
}
9294

9395
inline const char * const *EnumNamesTensorTypes() {
94-
static const char * const names[21] = {
96+
static const char * const names[22] = {
9597
"FLOAT32",
9698
"FLOAT16",
9799
"INT32",
@@ -112,13 +114,14 @@ inline const char * const *EnumNamesTensorTypes() {
112114
"INT4",
113115
"BFLOAT16",
114116
"INT2",
117+
"UINT4",
115118
nullptr
116119
};
117120
return names;
118121
}
119122

120123
inline const char *EnumNameTensorTypes(TensorTypes e) {
121-
if (::flatbuffers::IsOutRange(e, TensorTypes_FLOAT32, TensorTypes_INT2)) return "";
124+
if (::flatbuffers::IsOutRange(e, TensorTypes_FLOAT32, TensorTypes_UINT4)) return "";
122125
const size_t index = static_cast<size_t>(e);
123126
return EnumNamesTensorTypes()[index];
124127
}

tensorflow/lite/python/schema_py_generated.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class TensorType(object):
2828
INT4 = 17
2929
BFLOAT16 = 18
3030
INT2 = 19
31+
UINT4 = 20
3132

3233

3334
class QuantizationDetails(object):

tensorflow/lite/schema/schema_generated.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -712,11 +712,12 @@ enum TensorType : int8_t {
712712
TensorType_INT4 = 17,
713713
TensorType_BFLOAT16 = 18,
714714
TensorType_INT2 = 19,
715+
TensorType_UINT4 = 20,
715716
TensorType_MIN = TensorType_FLOAT32,
716-
TensorType_MAX = TensorType_INT2
717+
TensorType_MAX = TensorType_UINT4
717718
};
718719

719-
inline const TensorType (&EnumValuesTensorType())[20] {
720+
inline const TensorType (&EnumValuesTensorType())[21] {
720721
static const TensorType values[] = {
721722
TensorType_FLOAT32,
722723
TensorType_FLOAT16,
@@ -737,13 +738,14 @@ inline const TensorType (&EnumValuesTensorType())[20] {
737738
TensorType_UINT16,
738739
TensorType_INT4,
739740
TensorType_BFLOAT16,
740-
TensorType_INT2
741+
TensorType_INT2,
742+
TensorType_UINT4
741743
};
742744
return values;
743745
}
744746

745747
inline const char * const *EnumNamesTensorType() {
746-
static const char * const names[21] = {
748+
static const char * const names[22] = {
747749
"FLOAT32",
748750
"FLOAT16",
749751
"INT32",
@@ -764,13 +766,14 @@ inline const char * const *EnumNamesTensorType() {
764766
"INT4",
765767
"BFLOAT16",
766768
"INT2",
769+
"UINT4",
767770
nullptr
768771
};
769772
return names;
770773
}
771774

772775
inline const char *EnumNameTensorType(TensorType e) {
773-
if (::flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_INT2)) return "";
776+
if (::flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_UINT4)) return "";
774777
const size_t index = static_cast<size_t>(e);
775778
return EnumNamesTensorType()[index];
776779
}

0 commit comments

Comments
 (0)