-
Notifications
You must be signed in to change notification settings - Fork 28
Open
Description
We tried to convert the ESWT-12-12_LSR_x4.pth model from PyTorch to TFLite. We had to enable Flex (Select TF) ops because not all layers could be converted with the built-in TFLite ops alone; with the TF ops enabled, the model converted successfully without errors.
At inference time, however, we get the following error:
RuntimeError: tensorflow/lite/kernels/reshape.cc:92 num_input_elements != num_output_elements (0 != 8)Node number 0 (RESHAPE) failed to prepare.Node number 360 (IF) failed to prepare.
Please let us know whether you have tried converting the model to TFLite. Could you check this issue on your side?
Please note that the intermediate TensorFlow model works well.
The conversion pipeline was: pth -> onnx -> tf -> tflite
Conversion script
import numpy as np
import torch
from basicsr.models import build_model
from .utils import get_config
import onnx
import torchvision
import onnx_tf
import tensorflow as tf
from onnx import helper
def __init__(self, model_config_path, task_config_path, checkpoint_path):
    """Build the network described by the config and export it right away.

    Args:
        model_config_path: Forwarded unchanged to ``get_config``.
        task_config_path: Forwarded unchanged to ``get_config``.
        checkpoint_path: Forwarded unchanged to ``get_config``.
    """
    self.opt = get_config(model_config_path, task_config_path, checkpoint_path)
    self.device = torch.device('cpu')
    # Only the ``net_g`` sub-network is exported; it is moved to the CPU,
    # cast to float32 and switched to eval mode before export.
    network = build_model(self.opt).net_g
    self.model = network.to(self.device).to(torch.float32).eval()
    # Construction has a side effect: the export chain runs immediately.
    self.saveModel(self.model)
def saveModel(self, model):
    """Export ``model`` via the chain PyTorch -> ONNX -> TensorFlow -> TFLite.

    Three artifacts are written relative to the current working directory:
    ``sr-new.onnx``, the SavedModel export ``sr.tf`` and ``sr.tflite``.

    Args:
        model: The torch module to export. It is exported with a random
            dummy input of shape (1, 3, 256, 256).
    """
    modelName = "sr"
    input_shape = (1, 3, 256, 256)
    onnx_path = modelName + '-new.onnx'
    saved_model_path = modelName + '.tf'
    tflite_path = modelName + '.tflite'

    torch.onnx.export(
        model,
        torch.randn(input_shape),
        onnx_path,
        opset_version=12,
        input_names=['input'],
        output_names=['output'],
    )
    onnx_model = onnx.load(onnx_path)

    # Convert ONNX model to TensorFlow format
    tf_model = onnx_tf.backend.prepare(onnx_model)
    # Export TensorFlow model
    tf_model.export_graph(saved_model_path)

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
    # Allow TensorFlow (Select TF) ops in addition to the TFLite built-ins.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    tflite_model = converter.convert()

    # Context manager guarantees the file is flushed and closed.
    with open(tflite_path, 'wb') as tflite_file:
        tflite_file.write(tflite_model)
Some artifacts:
sr.tf.zip
Our inference scripts
tf:
import tensorflow as tf
import numpy as np
from PIL import Image
import PIL
import torch
import torchvision
import torchvision.transforms as T
def swapChannelsInput(input_tensor):
    """Add a leading batch axis and move the last axis to position 1.

    Args:
        input_tensor: A rank-3 tensor. The caller in this script passes a
            decoded image, presumably in (height, width, channels) order.

    Returns:
        A rank-4 TensorFlow tensor whose axes are the batched input's axes
        reordered as (0, 3, 1, 2).
    """
    batched = input_tensor[tf.newaxis, ...]
    # A single transpose replaces the former TF -> NumPy -> torch -> NumPy
    # -> TF round trip and yields the same axis order.
    return tf.transpose(batched, perm=[0, 3, 1, 2])
def showOutput(res):
    """Convert a model output tensor into a PIL image.

    Despite the name, this does not display anything; the caller is
    expected to call ``.show()`` on the returned image.

    Args:
        res: Output tensor of the model. After all size-1 axes are squeezed
            away it must be rank 3 with the channel axis first.

    Returns:
        A ``PIL.Image.Image`` built from the uint8, channels-last data.
    """
    channels_first = tf.squeeze(res).numpy()
    channels_last = np.transpose(channels_first, (1, 2, 0))
    # Saturate before the uint8 cast: values outside [0, 255] would
    # otherwise not be clamped by ``astype`` and show up as artifacts.
    pixels = np.clip(channels_last, 0, 255).astype(np.uint8)
    return PIL.Image.fromarray(pixels)
extraction_path = "sr.tf/"
test_image_path = "frame0.jpg"

# Load the exported SavedModel and grab its default serving signature.
model = tf.saved_model.load(extraction_path)
infer = model.signatures["serving_default"]

# Read the test image as float32 and reorder it with swapChannelsInput.
image_np = np.array(Image.open(test_image_path))
input_tensor = tf.convert_to_tensor(image_np, tf.float32)
input_tensor = swapChannelsInput(input_tensor)

# swapChannelsInput already returns a TensorFlow tensor, so it is passed
# straight to the signature without another tf.constant wrapper.
res = infer(input_tensor)['output']
showOutput(res).show()
tflite:
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
import PIL
import torch
import torchvision
class TFLiteModel:
    """Small convenience wrapper around ``tf.lite.Interpreter``."""

    def __init__(self, model_path: str):
        """Create the interpreter for ``model_path`` and allocate its tensors.

        Args:
            model_path: Path of the ``.tflite`` file handed to the
                interpreter.
        """
        self.interpreter = tf.lite.Interpreter(model_path)
        self.interpreter.allocate_tensors()
        # Cached so predict() does not query the interpreter on every call.
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()

    def predict(self, *data_args):
        """Run one inference and return the first output tensor.

        Args:
            *data_args: One value per model input, in the order reported by
                ``get_input_details()``. Each value is converted with
                ``np.asarray`` before it is handed to the interpreter, so
                array-likes are accepted as well as NumPy arrays.

        Returns:
            The value of the model's first output tensor.

        Raises:
            AssertionError: If the number of arguments differs from the
                number of model inputs.
        """
        assert len(data_args) == len(self.input_details), (
            f"expected {len(self.input_details)} input(s), "
            f"got {len(data_args)}"
        )
        for data, details in zip(data_args, self.input_details):
            # set_tensor is documented to take a NumPy array value.
            self.interpreter.set_tensor(details["index"], np.asarray(data))
        self.interpreter.invoke()
        return self.interpreter.get_tensor(self.output_details[0]["index"])
model = TFLiteModel("sr_12-12.tflite")
test_image_path = "frame0.jpg"

image_np = np.array(Image.open(test_image_path))
# Build the input directly in NumPy: cast to float32, add a batch axis and
# reorder the axes as (0, 3, 1, 2). This replaces the former
# TF -> NumPy -> torch -> NumPy -> TF round trip.
batched = image_np.astype(np.float32)[np.newaxis, ...]
input_array = np.ascontiguousarray(np.transpose(batched, (0, 3, 1, 2)))
# predict() returns the first output tensor; [0] selects its first entry.
res = model.predict(input_array)[0]
Metadata
Metadata
Assignees
Labels
No labels