Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ynnpack/subgraph/concatenate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ ynn_status ynn_define_concatenate(ynn_subgraph_t subgraph, int32_t axis,
validate_output_tensor("concatenate", subgraph, "output_id", output_id));

const ynn_value& input0 = subgraph->value(input_ids[0]);
YNN_RETURN_IF_ERROR(
validate_axis("concatenate", "input", input0.rank(), axis));
axis = axis_to_slinky_dim(input0.rank(), axis);

// Make the output and node.
Expand Down
20 changes: 15 additions & 5 deletions ynnpack/subgraph/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ slinky::func make_reshape(ynn_runtime& runtime,

ynn_status validate_new_shape(const char* node, size_t rank,
const size_t* new_dims) {
YNN_RETURN_IF_ERROR(validate_rank(node, "new_dims", rank));
if (new_dims == nullptr && rank > 0) {
YNN_LOG_ERROR() << "For node `" << node
<< "`, new_dims must be non-null for rank > 0";
Expand Down Expand Up @@ -346,19 +347,17 @@ ynn_status ynn_define_static_expand_dims(ynn_subgraph_t subgraph,
"input_id", input_id));
YNN_RETURN_IF_ERROR(validate_output_tensor("static_expand_dims", subgraph,
"output_id", output_id));
if (new_axes == nullptr && num_new_axes > 0) {
YNN_LOG_ERROR() << "For node `static_expand_dims`, new_axes must be "
"non-null for num_new_axes > 0";
return ynn_status_invalid_parameter;
}

const ynn_value& input = subgraph->value(input_id);

ynn_value& output = subgraph->get_output_value(output_id, input);

const int new_rank = input.rank() + num_new_axes;
YNN_RETURN_IF_ERROR(validate_rank("static_expand_dims", "output", new_rank));
ynn_node::static_expand_dims op;
for (size_t i = 0; i < num_new_axes; ++i) {
YNN_RETURN_IF_ERROR(
validate_axis("static_expand_dims", "output", new_rank, new_axes[i]));
op.new_axes[axis_to_slinky_dim(new_rank, new_axes[i])] = true;
}

Expand Down Expand Up @@ -419,6 +418,10 @@ ynn_status ynn_define_fuse_dim(ynn_subgraph_t subgraph, int32_t axis,
return ynn_status_invalid_parameter;
}
const ynn_value& input = subgraph->value(input_id);
YNN_RETURN_IF_ERROR(validate_axis("fuse_dim", "input", input.rank(), axis));
YNN_RETURN_IF_ERROR(
validate_axis("fuse_dim", "input", input.rank(), axis + axes_count - 1));

ynn_node::fuse_dim op;
// Since the first axis was specified with the dims in reverse order, we
// actually want the last dim here.
Expand Down Expand Up @@ -483,6 +486,9 @@ ynn_status ynn_define_split_dim(ynn_subgraph_t subgraph, int32_t axis,
return ynn_status_invalid_parameter;
}
const ynn_value& input = subgraph->value(input_id);
YNN_RETURN_IF_ERROR(
validate_rank("split_dim", "output", input.rank() + num_splits - 1));
YNN_RETURN_IF_ERROR(validate_axis("split_dim", "input", input.rank(), axis));

ynn_value& output = subgraph->get_output_value(output_id, input);

Expand Down Expand Up @@ -546,6 +552,8 @@ ynn_status ynn_define_fuse_dims(ynn_subgraph_t subgraph, size_t num_axes,

ynn_node::fuse_dims op;
for (size_t i = 0; i < num_axes; ++i) {
YNN_RETURN_IF_ERROR(
validate_axis("fuse_dims", "input", input.rank(), axes[i]));
// Since we are reversing the axes, the first dimension to fuse is actually
// the next dimension.
op.axes[axis_to_slinky_dim(input.rank(), axes[i] + 1)] = true;
Expand Down Expand Up @@ -596,6 +604,8 @@ ynn_status ynn_define_split_dims(ynn_subgraph_t subgraph, size_t num_axes,
using split = ynn_node::split_dims::split;
ynn_node::split_dims op;
for (size_t i = 0; i < num_axes; ++i) {
YNN_RETURN_IF_ERROR(
validate_axis("split_dims", "input", input.rank(), axes[i]));
op.splits.push_back({axis_to_slinky_dim(input.rank(), axes[i]), splits[i]});
}

Expand Down
3 changes: 2 additions & 1 deletion ynnpack/subgraph/even_split.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "ynnpack/subgraph/slinky.h"
#include "ynnpack/subgraph/subgraph.h"
#include "slinky/builder/pipeline.h"
#include "slinky/runtime/buffer.h"
#include "slinky/runtime/expr.h"

namespace ynn {
Expand All @@ -34,6 +33,8 @@ ynn_status ynn_define_even_split(ynn_subgraph_t subgraph, int32_t axis,
"even_split", subgraph, "output_ids", num_outputs, output_ids));

const ynn_value& input = subgraph->value(input_id);
YNN_RETURN_IF_ERROR(validate_axis("even_split", "input", input.rank(), axis));

for (size_t i = 0; i < num_outputs; ++i) {
subgraph->get_output_value(&output_ids[i], input);
}
Expand Down
6 changes: 2 additions & 4 deletions ynnpack/subgraph/get_tensor_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,13 @@ ynn_status ynn_define_get_tensor_shape(ynn_subgraph_t subgraph, size_t num_axes,
if (op.reshape_1d) {
extents = {slinky::index_t(1)};
for (int32_t i : op.axes) {
if (input.extents[i].defined()) {
extents[0] *= input.extents[i];
}
extents[0] *= input.extent(i);
}
extents[0] = slinky::simplify(extents[0]);
} else {
extents.reserve(op.axes.size());
for (int32_t i : op.axes) {
extents.push_back(input.extents[i].defined() ? input.extents[i] : 1);
extents.push_back(input.extent(i));
}
}

Expand Down
8 changes: 6 additions & 2 deletions ynnpack/subgraph/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "slinky/builder/simplify.h"
#include "slinky/runtime/buffer.h"
#include "slinky/runtime/expr.h"
#include "slinky/runtime/print.h"
#include "slinky/runtime/stmt.h"

namespace ynn {
Expand Down Expand Up @@ -387,7 +386,12 @@ ynn_status ynn_define_reduce(ynn_subgraph_t subgraph,

ynn::axes_set k_dims;
for (size_t i = 0; i < num_axes; ++i) {
k_dims[axis_to_slinky_dim(a.rank(), axes[i])] = true;
const int axis = axis_to_slinky_dim(a.rank(), axes[i]);
if (axis < a.rank()) {
k_dims[axis] = true;
} else {
// This is a reduction of an implicit broadcast, which is a no-op.
}
}
bool keep_dims = flags & YNN_NODE_FLAG_KEEP_DIMS;

Expand Down
3 changes: 3 additions & 0 deletions ynnpack/subgraph/stack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ ynn_status ynn_define_stack(ynn_subgraph_t subgraph, int32_t axis,
validate_output_tensor("stack", subgraph, "output_id", output_id));

const ynn_value& input0 = subgraph->value(input_ids[0]);
YNN_RETURN_IF_ERROR(validate_rank("stack", "output", input0.rank() + 1));
YNN_RETURN_IF_ERROR(
validate_axis("stack", "output", input0.rank() + 1, axis));
axis = axis_to_slinky_dim(input0.rank() + 1, axis);

// Make the output and node.
Expand Down
4 changes: 3 additions & 1 deletion ynnpack/subgraph/static_pad.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -42,6 +42,8 @@ ynn_status ynn_define_static_pad(ynn_subgraph_t subgraph, size_t num_axes,
ynn_node::static_pad op;
op.paddings.reserve(num_axes);
for (size_t i = 0; i < num_axes; ++i) {
YNN_RETURN_IF_ERROR(
validate_axis("static_pad", "input", input.rank(), axes[i]));
if (pre_paddings[i] != 0 || post_paddings[i] != 0) {
op.paddings.push_back({ynn::axis_to_slinky_dim(input.rank(), axes[i]),
pre_paddings[i], post_paddings[i]});
Expand Down
6 changes: 4 additions & 2 deletions ynnpack/subgraph/static_transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ ynn_status define_static_transpose(ynn_subgraph_t subgraph,
const int elem_count = type_element_count(output.type);
output.extents.resize(permutation.size());
for (int d = 0; d < output.rank(); ++d) {
slinky::expr input_extent = input.extents[permutation[d]];
slinky::expr input_extent = input.extent(permutation[d]);
if (permutation[d] == 0 && elem_count != 1) {
// The extents are physical shapes, we need to convert to logical shapes
// when we transpose the dimensions.
Expand Down Expand Up @@ -219,9 +219,11 @@ ynn_status ynn_define_static_transpose(ynn_subgraph_t subgraph, size_t rank,
YNN_RETURN_IF_ERROR(validate_output_tensor("static_transpose", subgraph,
"output_id", output_id));
if (permutation == nullptr && rank > 0) {
YNN_LOG_ERROR() << "permutation must be non-null for rank > 0";
YNN_LOG_ERROR() << "For node `static_transpose`, permutation must be "
"non-null for rank > 0";
return ynn_status_invalid_parameter;
}
YNN_RETURN_IF_ERROR(validate_rank("static_transpose", "output", rank));

// Rewrite the permutation to be slinky dimensions.
const ynn_value& input = subgraph->value(input_id);
Expand Down
8 changes: 7 additions & 1 deletion ynnpack/subgraph/stencil_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,18 @@ ynn_status ynn_define_stencil_copy(ynn_subgraph_t subgraph, size_t num_stencils,
"stencil_copy", subgraph, "padding_id", padding_id, /*optional=*/true));
YNN_RETURN_IF_ERROR(
validate_output_tensor("stencil_copy", subgraph, "output_id", output_id));
const ynn_value& input = subgraph->value(input_id);
YNN_RETURN_IF_ERROR(
validate_rank("stencil_copy", "output", input.rank() + num_stencils));

ynn_node node;
const ynn_value& input = subgraph->value(input_id);
ynn_node::stencil_copy op_data;
op_data.stencils.reserve(num_stencils);
for (size_t i = 0; i < num_stencils; ++i) {
YNN_RETURN_IF_ERROR(
validate_axis("stencil_copy", "input", input.rank(), stencil_axes[i]));
YNN_RETURN_IF_ERROR(validate_axis(
"stencil_copy", "output", input.rank() + num_stencils, new_axes[i]));
op_data.stencils.push_back({
// Swap the axes to get the slinky dimensions.
.axis = axis_to_slinky_dim(input.rank(), stencil_axes[i]),
Expand Down
35 changes: 33 additions & 2 deletions ynnpack/subgraph/subgraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,27 @@ ynn_status validate_subgraph(const char* node, ynn_subgraph_t subgraph) {
return ynn_status_success;
}

// Checks that `rank` (the rank of the tensor named by `rank_of`) does not
// exceed the library-wide limit YNN_MAX_TENSOR_RANK.
// Returns ynn_status_success when the rank is acceptable, and
// ynn_status_unsupported_parameter (after logging) otherwise.
ynn_status validate_rank(const char* node, const char* rank_of, size_t rank) {
  if (rank <= YNN_MAX_TENSOR_RANK) {
    return ynn_status_success;
  }
  // Over-limit ranks are rejected rather than clamped: downstream code sizes
  // fixed arrays by YNN_MAX_TENSOR_RANK.
  YNN_LOG_ERROR() << "For node `" << node << "`, rank " << rank
                  << " of tensor `" << rank_of
                  << "` exceeds YNN_MAX_TENSOR_RANK " << YNN_MAX_TENSOR_RANK;
  return ynn_status_unsupported_parameter;
}

// Checks that `axis` is a valid axis index for a tensor of rank `rank`
// (the tensor named by `axis_of`). Negative axes are allowed and count
// back from the last dimension, so the valid range is [-rank, rank).
// Returns ynn_status_success on success, or ynn_status_invalid_parameter
// (after logging) when the axis is out of range.
ynn_status validate_axis(const char* node, const char* axis_of, int rank,
                         int32_t axis) {
  const bool in_range = -rank <= axis && axis < rank;
  if (in_range) {
    return ynn_status_success;
  }
  YNN_LOG_ERROR() << "For node `" << node << "`, axis " << axis
                  << " exceeds rank " << rank << " of tensor `" << axis_of
                  << "`";
  return ynn_status_invalid_parameter;
}

ynn_status validate_input_tensor(const char* node, ynn_subgraph_t subgraph,
const char* name, uint32_t id, bool optional) {
if (optional && id == YNN_INVALID_VALUE_ID) {
Expand All @@ -68,7 +89,12 @@ ynn_status validate_input_tensor(const char* node, ynn_subgraph_t subgraph,
ynn_status validate_input_tensor_array(const char* node,
ynn_subgraph_t subgraph,
const char* name, size_t count,
const uint32_t* ids) {
const uint32_t* ids, bool allow_empty) {
if (!allow_empty && count == 0) {
YNN_LOG_ERROR() << "For node `" << node << "`, input array `" << name
<< "` must be non-empty";
return ynn_status_invalid_parameter;
}
if (count > 0 && ids == nullptr) {
YNN_LOG_ERROR() << "For node `" << node << "`, input array `" << name
<< "` must be non-null if count is " << count;
Expand Down Expand Up @@ -103,7 +129,12 @@ ynn_status validate_output_tensor(const char* node, ynn_subgraph_t subgraph,
ynn_status validate_output_tensor_array(const char* node,
ynn_subgraph_t subgraph,
const char* name, size_t count,
uint32_t* ids_out) {
uint32_t* ids_out, bool allow_empty) {
if (!allow_empty && count == 0) {
YNN_LOG_ERROR() << "For node `" << node << "`, output array `" << name
<< "` must be non-empty";
return ynn_status_invalid_parameter;
}
if (count > 0 && ids_out == nullptr) {
YNN_LOG_ERROR() << "For node `" << node << "`, output array `" << name
<< "` must be non-null if count is " << count;
Expand Down
9 changes: 7 additions & 2 deletions ynnpack/subgraph/subgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,24 @@ ynn_status define_static_transpose(ynn_subgraph_t subgraph,

// Validation helpers for public APIs.
ynn_status validate_subgraph(const char* node, ynn_subgraph_t subgraph);
ynn_status validate_rank(const char* node, const char* rank_of, size_t rank);
ynn_status validate_axis(const char* node, const char* axis_of, int rank,
int32_t axis);
ynn_status validate_input_tensor(const char* node, ynn_subgraph_t subgraph,
const char* name, uint32_t id,
bool optional = false);
ynn_status validate_input_tensor_array(const char* node,
ynn_subgraph_t subgraph,
const char* name, size_t count,
const uint32_t* ids);
const uint32_t* ids,
bool allow_empty = false);
ynn_status validate_output_tensor(const char* node, ynn_subgraph_t subgraph,
const char* name, uint32_t* id_out);
ynn_status validate_output_tensor_array(const char* node,
ynn_subgraph_t subgraph,
const char* name, size_t count,
uint32_t* id_outs);
uint32_t* id_outs,
bool allow_empty = false);
ynn_status validate_runtime(ynn_runtime_t runtime);

} // namespace ynn
Expand Down
5 changes: 5 additions & 0 deletions ynnpack/subgraph/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ ynn_status ynn_define_tensor_value(ynn_subgraph_t subgraph, enum ynn_type type,
uint32_t scale_id, uint32_t flags,
uint32_t* id_out) {
YNN_RETURN_IF_ERROR(validate_subgraph("define_tensor", subgraph));
if (rank > YNN_MAX_TENSOR_RANK) {
YNN_LOG_ERROR() << "rank " << rank << " exceeds YNN_MAX_TENSOR_RANK "
<< YNN_MAX_TENSOR_RANK;
return ynn_status_unsupported_parameter;
}
if (!id_out) {
YNN_LOG_ERROR() << "id_out must be non-null";
return ynn_status_invalid_parameter;
Expand Down
Loading