From 8c081e63f01f179f38bef16f6371d623ea401bbb Mon Sep 17 00:00:00 2001 From: Masahiro Sakai Date: Tue, 20 Nov 2018 17:32:20 +0900 Subject: [PATCH] add experimental Numo::NArray support --- lib/menoh.rb | 35 +++++++++++++++++++++++++++++++++++ menoh.gemspec | 1 + test/menoh_test.rb | 31 +++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+) diff --git a/lib/menoh.rb b/lib/menoh.rb index 2674afb..f384667 100644 --- a/lib/menoh.rb +++ b/lib/menoh.rb @@ -1,6 +1,7 @@ require 'menoh/version' require 'menoh/menoh_native' require 'json' +require 'numo/narray' module Menoh class Menoh @@ -20,6 +21,16 @@ def make_model(option) end class MenohModel + DTYPE_TO_NUMO_NARRAY_CLASS = { + float: Numo::SFloat, + float32: Numo::SFloat, + float64: Numo::DFloat, + int8: Numo::Int8, + int16: Numo::Int16, + int32: Numo::Int32, + int64: Numo::Int64, + } + def initialize(menoh, option) if option[:input_layers].nil? || option[:input_layers].empty? raise "Required ':input_layers'" @@ -72,6 +83,30 @@ def run(dataset) yield results if block_given? results end + + def run_numo(dataset) + raise 'Invalid dataset' if !dataset.instance_of?(Array) || dataset.empty? + if dataset.length != @option[:input_layers].length + raise "Invalid input num: expected==#{@option[:input_layers].length} actual==#{dataset.length}" + end + dataset.each do |input| + set_data_str(input[:name], input[:data].to_binary) + end + + # run + native_run + + results = {} + @option[:output_layers].each do |name| + dtype = get_dtype(name) + c = DTYPE_TO_NUMO_NARRAY_CLASS[dtype] + raise InvalidDTypeError.new("unsupported dtype: #{dtype}") if c.nil? + results[name] = c.from_binary(get_data_str(name), get_shape(name)) + end + + yield results if block_given? + results + end end module Util diff --git a/menoh.gemspec b/menoh.gemspec index 6ab2d10..61f3540 100644 --- a/menoh.gemspec +++ b/menoh.gemspec @@ -26,4 +26,5 @@ Gem::Specification.new do |spec| spec.add_development_dependency 'pry' spec.add_development_dependency 'rake', '~> 10.0' spec.add_development_dependency 'rake-compiler' + spec.add_development_dependency 'numo-narray' end diff --git a/test/menoh_test.rb b/test/menoh_test.rb index 5f9a93d..162180f 100644 --- a/test/menoh_test.rb +++ b/test/menoh_test.rb @@ -1,4 +1,5 @@ require 'test_helper' +require 'numo/narray' MNIST_ONNX_FILE = 'example/data/mnist.onnx'.freeze MNIST_IN_NAME = '139900320569040'.freeze @@ -39,6 +40,36 @@ def test_menoh_basic_function end end + def test_menoh_basic_function_numo + onnx = Menoh::Menoh.new(MNIST_ONNX_FILE) + assert_instance_of(Menoh::Menoh, onnx) + batch_size = 3 + model_opt = { + backend: 'mkldnn', + input_layers: [ + { + name: MNIST_IN_NAME, + dims: [batch_size, 1, 28, 28] + } + ], + output_layers: [MNIST_OUT_NAME] + } + model = onnx.make_model(model_opt) + assert_instance_of(Menoh::MenohModel, model) + 10.times do + imageset = [ + { + name: MNIST_IN_NAME, + data: Numo::SFloat.zeros(batch_size, 1, 28, 28) + } + ] + inferenced_results = model.run_numo imageset + assert_instance_of(Hash, inferenced_results) + assert_instance_of(Numo::SFloat, inferenced_results[MNIST_OUT_NAME]) + assert_equal([batch_size, 10], inferenced_results[MNIST_OUT_NAME].shape) + end + end + def test_menoh_basic_function_with_block batch_size = 3 model_opt = {