Skip to content

Inference issue with the ESWT-12-12_LSR_x4 model converted to TFLite #30

@koranten2

Description

@koranten2

We tried to convert the ESWT-12-12_LSR_x4.pth model from PyTorch to TFLite. We had to enable Flex (Select TF) ops because not all layers could be converted initially; with TF ops enabled, the model was finally converted successfully without errors.
At inference time we get the following error:
RuntimeError: tensorflow/lite/kernels/reshape.cc:92 num_input_elements != num_output_elements (0 != 8)Node number 0 (RESHAPE) failed to prepare.Node number 360 (IF) failed to prepare.

Please tell us whether you have tried converting the model to TFLite. Could you check this issue on your side?
Please note that the intermediate TensorFlow model works well.

The conversion pipeline was pth -> onnx -> tf -> tflite.
Conversion script:

    import numpy as np
  
    import torch
    from basicsr.models import build_model
    from .utils import get_config
    import onnx
    import torchvision
    import onnx_tf
    import tensorflow as tf
    from onnx import helper
    def __init__(self, model_config_path, task_config_path, checkpoint_path):
        """Build the network on CPU and immediately export it.

        Args:
            model_config_path: forwarded to ``get_config``.
            task_config_path: forwarded to ``get_config``.
            checkpoint_path: forwarded to ``get_config``.
        """
        self.opt = get_config(model_config_path, task_config_path, checkpoint_path)
        self.device = torch.device('cpu')
        # Only the ``net_g`` sub-module of the built model is kept; it is
        # moved to the CPU, cast to float32 and switched to eval mode.
        self.model = build_model(self.opt).net_g.to(self.device).to(torch.float32).eval()
        # Constructing the object triggers the whole export chain.
        self.saveModel(self.model)

    
    def saveModel(self, model):
        """Export ``model`` through the ONNX -> TensorFlow -> TFLite chain.

        Three artifacts are written using relative paths: ``sr-new.onnx``,
        the ``sr.tf`` SavedModel and ``sr.tflite``.

        Args:
            model: the torch module handed to ``torch.onnx.export``.
        """
        modelName = "sr"
        # The ONNX export is traced with a random tensor of this fixed shape.
        input_shape = (1, 3, 256, 256)
        onnx_path = modelName + '-new.onnx'
        saved_model_path = modelName + '.tf'

        torch.onnx.export(
            model,
            torch.randn(input_shape),
            onnx_path,
            opset_version=12,
            input_names=['input'],
            output_names=['output'],
        )
        onnx_model = onnx.load(onnx_path)
        # Convert ONNX model to TensorFlow format
        tf_model = onnx_tf.backend.prepare(onnx_model)
        # Export TensorFlow model
        tf_model.export_graph(saved_model_path)

        converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
        # Allow both the built-in TFLite ops and the Select TF ops set.
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS,
        ]
        tflite_model = converter.convert()
        # Context manager so the file handle is flushed and closed
        # deterministically instead of being left to garbage collection.
        with open(modelName + '.tflite', 'wb') as tflite_file:
            tflite_file.write(tflite_model)

Some artifacts:
sr.tf.zip

sr_12-12.tflight.zip

Our inference scripts

tf:

      import tensorflow as tf
      import numpy as np
      from PIL import Image
      import PIL
      
      import torch
      import torchvision
      import torchvision.transforms as T
      
      def swapChannelsInput(input_tensor):
          """Add a leading batch axis and move the last axis to position 1.

          A rank-3 tensor laid out as (d0, d1, d2) comes back as a rank-4
          tensor laid out as (1, d2, d0, d1).

          Args:
              input_tensor: rank-3 TensorFlow tensor.

          Returns:
              The batched, axis-permuted TensorFlow tensor.
          """
          batched = input_tensor[tf.newaxis, ...]
          # Permute directly in TensorFlow; no detour through numpy and
          # torch is needed just to reorder axes.
          return tf.transpose(batched, perm=[0, 3, 1, 2])
      
      def showOutput(res):
          """Convert a model output tensor into a ``PIL.Image``.

          The tensor is squeezed, its first axis is moved to the end, and
          the values are converted to ``uint8``. The image is returned, not
          displayed; the caller decides whether to ``.show()`` it.

          Args:
              res: TensorFlow tensor that is rank 3 once squeezed.

          Returns:
              ``PIL.Image.Image`` built from the converted array.
          """
          channels_first = tf.squeeze(res).numpy()
          channels_last = np.transpose(channels_first, (1, 2, 0))
          # Clamp before the cast: values outside [0, 255] would otherwise
          # wrap around in uint8 and show up as speckle artifacts.
          pixels = np.clip(channels_last, 0, 255).astype(np.uint8)
          return PIL.Image.fromarray(pixels)
      
      # Run the intermediate TensorFlow SavedModel on a single test image.
      extraction_path = "sr.tf/"
      test_image_path = "frame0.jpg"
      model = tf.saved_model.load(extraction_path)

      infer = model.signatures["serving_default"]

      # Force RGB so the array always has exactly three channels, even for
      # grayscale or CMYK input files.
      image_np = np.array(Image.open(test_image_path).convert("RGB"))
      input_tensor = tf.convert_to_tensor(image_np, tf.float32)
      input_tensor = swapChannelsInput(input_tensor)

      res = infer(tf.constant(input_tensor))['output']
      showOutput(res).show()

tflite:

        import tensorflow as tf
        import numpy as np
        import cv2
        
        from PIL import Image
        import PIL
        
        import torch
        import torchvision
        
        class TFLiteModel:
            """Small convenience wrapper around ``tf.lite.Interpreter``."""

            def __init__(self, model_path: str):
                """Create the interpreter for ``model_path`` and allocate tensors.

                The input and output detail lists reported by the interpreter
                are stored for later use by ``predict``.
                """
                self.interpreter = tf.lite.Interpreter(model_path)
                self.interpreter.allocate_tensors()

                interpreter = self.interpreter
                self.input_details = interpreter.get_input_details()
                self.output_details = interpreter.get_output_details()

            def predict(self, *data_args):
                """Feed one value per model input, run the model, return output 0.

                Args:
                    *data_args: one value per entry of ``input_details``, in
                        the same order.

                Returns:
                    Whatever ``get_tensor`` yields for the first output.
                """
                assert len(data_args) == len(self.input_details)
                interpreter = self.interpreter
                for position, tensor_value in enumerate(data_args):
                    tensor_index = self.input_details[position]["index"]
                    interpreter.set_tensor(tensor_index, tensor_value)
                interpreter.invoke()
                first_output = self.output_details[0]
                return interpreter.get_tensor(first_output["index"])
        
        
        # Run the converted TFLite model on a single test image.
        model = TFLiteModel("sr_12-12.tflite")

        test_image_path = "frame0.jpg"
        # Force RGB so the array always has exactly three channels.
        image_np = np.array(Image.open(test_image_path).convert("RGB"))

        # Add the batch axis and move channels ahead of the spatial axes.
        # This is done in numpy directly (no tf/torch round-trip), and the
        # result is made contiguous before it is handed to the interpreter.
        # NOTE(review): the conversion script traces the model with a fixed
        # (1, 3, 256, 256) input; presumably the image must match that
        # size. Verify against the model's input details.
        input_tensor = image_np.astype(np.float32)[np.newaxis, ...]
        input_tensor = np.ascontiguousarray(np.transpose(input_tensor, (0, 3, 1, 2)))

        # Index 0 selects the first entry along the leading axis of the output.
        res = model.predict(input_tensor)[0]

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions