Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions tensorflow/lite/micro/kernels/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt16: {
if (bias == nullptr || bias->type == kTfLiteInt32) {
if (bias != nullptr && bias->type == kTfLiteInt32) {
reference_integer_ops::ConvPerChannel(
ConvParamsQuantized(params, data),
data.per_channel_output_multiplier, data.per_channel_output_shift,
Expand All @@ -92,16 +92,16 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(
tflite::micro::GetTensorData<int32_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<std::int32_t>(bias),
tflite::micro::GetTensorData<std::int32_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
} else if (bias->type == kTfLiteInt64) {
} else if (bias == nullptr || bias->type == kTfLiteInt64) {
reference_integer_ops::ConvPerChannel(
ConvParamsQuantized(params, data),
data.per_channel_output_multiplier, data.per_channel_output_shift,
Expand All @@ -113,12 +113,12 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int64_t>(
tflite::micro::GetOptionalTensorData<int64_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<std::int64_t>(bias),
tflite::micro::GetOptionalTensorData<std::int64_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
Expand Down Expand Up @@ -194,4 +194,4 @@ TFLMRegistration Register_CONV_2D() {
return tflite::micro::RegisterOp(ConvInit, ConvPrepare, ConvEval);
}

} // namespace tflite
} // namespace tflite
12 changes: 6 additions & 6 deletions tensorflow/lite/micro/kernels/fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt16: {
switch (filter->type) {
case kTfLiteInt8: {
if (bias == nullptr || bias->type == kTfLiteInt32) {
if (bias != nullptr && bias->type == kTfLiteInt32) {
data.is_per_channel
? tflite::reference_integer_ops::FullyConnectedPerChannel(
FullyConnectedParamsQuantized(data),
Expand All @@ -253,13 +253,13 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
micro_context, filter, weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(
tflite::micro::GetTensorData<int32_t>(
micro_context, bias, bias_comp_td,
data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
tflite::micro::GetTensorData<int32_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output))
Expand All @@ -273,17 +273,17 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
micro_context, filter, weights_comp_td,
data.weights_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(
tflite::micro::GetTensorData<int32_t>(
micro_context, bias, bias_comp_td,
data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int32_t>(bias),
tflite::micro::GetTensorData<int32_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
} else if (bias->type == kTfLiteInt64) {
} else if (bias == nullptr || bias->type == kTfLiteInt64) {
data.is_per_channel
? tflite::reference_integer_ops::FullyConnectedPerChannel(
FullyConnectedParamsQuantized(data),
Expand Down
Loading