diff --git a/cmake/config.cmake b/cmake/config.cmake index dfbe0d217893..068b17fbf09b 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -203,7 +203,8 @@ set(USE_CUBLAS OFF) set(USE_SORT ON) # Whether to build with TensorRT codegen or runtime -# Examples are available here: docs/deploy/tensorrt.rst. +# An end-to-end example is available here: +# docs/how_to/tutorials/bring_your_own_codegen.py. # # USE_TENSORRT_CODEGEN - Support for compiling a graph where supported operators are # offloaded to TensorRT. OFF/ON diff --git a/docs/how_to/tutorials/bring_your_own_codegen.py b/docs/how_to/tutorials/bring_your_own_codegen.py index b6039e493039..a0d4534cc497 100644 --- a/docs/how_to/tutorials/bring_your_own_codegen.py +++ b/docs/how_to/tutorials/bring_your_own_codegen.py @@ -18,55 +18,72 @@ """ .. _tutorial-bring-your-own-codegen: -Bring Your Own Codegen: NPU Backend Example -=========================================== - -This tutorial shows how to integrate a custom hardware backend with TVM's -BYOC framework, using the bundled example NPU backend (CPU emulation, no -real hardware required) as the worked example. You will see the key -concepts needed to offload operations to a custom accelerator: pattern -registration, graph partitioning, codegen, and runtime dispatch. - -NPUs are purpose-built accelerators designed around a fixed set of operations -common in neural network inference, such as matrix multiplication, convolution, -and activation functions. -The example backend's runtime is a *stub*: it logs the dispatch decisions an +Bring Your Own Codegen +====================== + +TVM's Bring Your Own Codegen (BYOC) framework lets you offload parts of a model +to a custom backend -- a hardware accelerator, an inference library, or your own +kernels -- while TVM compiles the rest. 
This tutorial has two parts: + +- **How BYOC works** -- we teach the flow with a bundled, hardware-free *example + NPU* backend and then drive the **same flow** on a real production backend, + NVIDIA TensorRT. Both run a small, hand-written model so every step is + visible; the only thing that changes between them is the backend, and that + contrast is the lesson. +- **Deploying a real model** -- we then put it to work, taking an actual PyTorch + ``nn.Module`` from export through TensorRT and running it on the GPU. + +The example NPU is a teaching stub: its runtime logs the dispatch decisions an NPU would make (memory tier, execution engine, fusion) but performs no real -computation, so output buffers are uninitialized. Assertions in this tutorial -therefore check shapes, not values. When you replace the runtime with your -hardware SDK calls, the same flow produces real results. - -**Prerequisites**: Build TVM with ``USE_EXAMPLE_NPU_CODEGEN=ON`` and -``USE_EXAMPLE_NPU_RUNTIME=ON``. +computation, so its output buffers are left uninitialized. We therefore check +*shapes*, not values, in the NPU sections -- its job is to make every BYOC step +visible with nothing hidden. TensorRT then runs the identical flow for real, so +we cross-check its result against a reference. + +**Prerequisites**: the example NPU sections need TVM built with +``USE_EXAMPLE_NPU_CODEGEN=ON`` and ``USE_EXAMPLE_NPU_RUNTIME=ON``; the TensorRT +sections need ``USE_TENSORRT_CODEGEN=ON``, ``USE_TENSORRT_RUNTIME=ON`` and +``USE_CUDA=ON`` plus a CUDA GPU and a matching TensorRT install (from NVIDIA's +``pip install tensorrt`` packages or the TensorRT archive); the final deployment +section also needs PyTorch. Each section degrades gracefully when its backend is +unavailable. 
""" ###################################################################### -# Overview of the BYOC Flow +# Overview of the BYOC flow # ------------------------- # -# The BYOC framework lets you plug a custom backend into TVM's compilation -# pipeline in four steps: +# BYOC plugs a custom backend into TVM's compilation pipeline in four steps: # -# 1. **Register patterns** - describe which sequences of Relax ops the -# backend can handle. +# 1. **Register patterns** - describe which sequences of Relax ops the backend +# can handle. # 2. **Partition the graph** - group matched ops into composite functions. -# 3. **Run codegen** - lower composite functions to backend-specific -# representation (JSON graph for the example NPU). -# 4. **Execute** - the runtime dispatches composite functions to the -# registered backend runtime. +# 3. **Run codegen** - lower each composite to the backend's representation +# (a JSON graph for both backends in this tutorial). +# 4. **Execute** - the runtime dispatches each composite to the backend. +# +# Steps 1 and 2 are pure Python and run anywhere; steps 3 and 4 need the +# backend's codegen and runtime compiled into TVM, which is why the +# build-and-run cells below are guarded. ###################################################################### -# Step 1: Import the backend to register its patterns -# --------------------------------------------------- +# Step 1: Import the backends to register their patterns +# ------------------------------------------------------ # -# Importing the module is enough to register all supported patterns with -# TVM's pattern registry. +# Importing a backend module registers its patterns with TVM's global registry. +# Pattern registration is independent of the C++ build -- only codegen and the +# runtime require the backend to be compiled in -- so we probe each backend and +# guard the build-and-run cells accordingly. 
+ +import os +import tempfile import numpy as np import tvm -import tvm.relax.backend.contrib.example_npu # registers patterns +import tvm.relax.backend.contrib.example_npu from tvm import relax +from tvm.relax.backend.contrib.tensorrt import partition_for_tensorrt from tvm.relax.backend.pattern_registry import get_patterns_with_prefix from tvm.relax.transform import FuseOpsByPattern, MergeCompositeFunctions, RunCodegen from tvm.script import relax as R @@ -75,148 +92,289 @@ has_example_npu_runtime = tvm.get_global_func("runtime.ExampleNPUJSONRuntimeCreate", True) has_example_npu = has_example_npu_codegen and has_example_npu_runtime -target = tvm.target.Target("llvm") - -patterns = get_patterns_with_prefix("example_npu") -print("Registered patterns:", [p.name for p in patterns]) +has_tensorrt_codegen = tvm.get_global_func("relax.ext.tensorrt", True) is not None +_is_trt_runtime_enabled = tvm.get_global_func("relax.is_tensorrt_runtime_enabled", True) +has_tensorrt = ( + has_tensorrt_codegen and _is_trt_runtime_enabled is not None and _is_trt_runtime_enabled() +) +has_cuda = tvm.cuda(0).exist ###################################################################### -# Step 2: Define a model -# ---------------------- +# Step 2: Define the model +# ------------------------ # -# We use a simple MatMul + ReLU module to illustrate the flow. +# A single convolution followed by a ReLU. This one model is used for both +# backends. 
@tvm.script.ir_module -class MatmulReLU: +class ConvReLU: @R.function def main( - x: R.Tensor((2, 4), "float32"), - w: R.Tensor((4, 8), "float32"), - ) -> R.Tensor((2, 8), "float32"): + data: R.Tensor((1, 3, 32, 32), "float32"), + weight: R.Tensor((16, 3, 3, 3), "float32"), + ) -> R.Tensor((1, 16, 30, 30), "float32"): with R.dataflow(): - y = relax.op.matmul(x, w) - z = relax.op.nn.relu(y) - R.output(z) - return z + conv = relax.op.nn.conv2d(data, weight) + out = relax.op.nn.relu(conv) + R.output(out) + return out ###################################################################### -# Step 3: Partition the graph -# --------------------------- -# -# ``FuseOpsByPattern`` groups ops that match a registered pattern into -# composite functions, controlled by two flags: -# -# - ``bind_constants=False`` keeps weights as function arguments instead -# of baking them in, so the host stays in charge of parameter -# ownership. -# - ``annotate_codegen=True`` tags each composite with its backend name -# (``example_npu``); without this tag, ``RunCodegen`` has no way to -# route the composite to a backend. -# -# ``MergeCompositeFunctions`` then consolidates adjacent composites -# that target the same backend so each group becomes a single external -# call. Note that consolidation depends on the patterns themselves: an -# ``op_a + op_b`` chain only collapses into one composite if a fused -# pattern (e.g. ``matmul_relu_fused``) was registered for it; otherwise -# each op stays as its own composite even when both target the same -# backend. 
- -mod = MatmulReLU -mod = FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod) -mod = MergeCompositeFunctions()(mod) -print("After partitioning:") -print(mod) - -###################################################################### -# Step 4: Run codegen -# ------------------- +# Step 3: Partition for the example NPU +# ------------------------------------- +# +# ``FuseOpsByPattern`` groups ops matching a registered pattern into composite +# functions; ``MergeCompositeFunctions`` then consolidates adjacent composites +# bound for the same backend into a single external call. Two flags steer +# partitioning: # -# ``RunCodegen`` lowers each annotated composite function to the backend's -# serialization format. For the example NPU this produces a JSON graph -# that the C++ runtime can execute. +# - ``bind_constants=False`` keeps weights as function arguments, so the host +# stays in charge of the parameters. (TensorRT below makes the opposite +# choice: it binds weights as constants because it bakes them into its engine.) +# - ``annotate_codegen=True`` wraps each matched composite in a function tagged +# with the backend name -- the tag ``RunCodegen`` routes on. (The follow-up +# ``MergeCompositeFunctions`` also attaches this tag when it groups composites, +# which is why ``partition_for_tensorrt`` below can leave the flag off.) # -# Steps 4 and 5 require TVM to be built with ``USE_EXAMPLE_NPU_CODEGEN=ON`` -# and ``USE_EXAMPLE_NPU_RUNTIME=ON``. +# The example NPU registers a fused ``conv2d + relu`` pattern with higher +# priority than the standalone ``conv2d`` pattern, so the two ops collapse into a +# single ``example_npu.conv2d_relu_fused`` composite -- look for it in the +# printed module. 
-if has_example_npu: - mod = RunCodegen()(mod) - print("After codegen:") - print(mod) +npu_patterns = get_patterns_with_prefix("example_npu") +npu_mod = FuseOpsByPattern(npu_patterns, bind_constants=False, annotate_codegen=True)(ConvReLU) +npu_mod = MergeCompositeFunctions()(npu_mod) +print("After partitioning for the example NPU:") +print(npu_mod) - ###################################################################### - # Step 5: Build and run - # --------------------- - # - # Build the module for the host target, create a virtual machine, and - # execute the compiled function. +###################################################################### +# Step 4: Codegen, build and run on the example NPU +# ------------------------------------------------- +# +# ``RunCodegen`` invokes each annotated composite's backend codegen, replacing it +# with the backend runtime module (here, the NPU's JSON graph); ``relax.build`` +# then compiles the remaining host-side program and links everything. Because +# the runtime is a stub that computes nothing, we assert on the output *shape* +# only -- the values are uninitialized. - np.random.seed(0) - x_np = np.random.randn(2, 4).astype("float32") - w_np = np.random.randn(4, 8).astype("float32") +np.random.seed(0) +data_np = np.random.randn(1, 3, 32, 32).astype("float32") +weight_np = np.random.randn(16, 3, 3, 3).astype("float32") - with tvm.transform.PassContext(opt_level=3): - built = relax.build(mod, target) +if has_example_npu: + npu_mod = RunCodegen()(npu_mod) - vm = relax.VirtualMachine(built, tvm.cpu()) - result = vm["main"](tvm.runtime.tensor(x_np, tvm.cpu()), tvm.runtime.tensor(w_np, tvm.cpu())) + with tvm.transform.PassContext(opt_level=3): + npu_exec = relax.build(npu_mod, tvm.target.Target("llvm")) - assert result.numpy().shape == (2, 8) - print("Execution completed. 
Output shape:", result.numpy().shape) + npu_vm = relax.VirtualMachine(npu_exec, tvm.cpu()) + npu_out = npu_vm["main"]( + tvm.runtime.tensor(data_np, tvm.cpu()), tvm.runtime.tensor(weight_np, tvm.cpu()) + ) + assert npu_out.numpy().shape == (1, 16, 30, 30) + print("Example NPU run completed. Output shape:", npu_out.numpy().shape) +else: + print("Example NPU backend unavailable; skipping its build and run.") ###################################################################### -# Step 6: Conv2D + ReLU -# --------------------- +# The same flow on a real backend: TensorRT +# ----------------------------------------- +# +# Steps 1-4 above are the whole mechanism. Aiming them at a real backend +# changes very little, so rather than repeat the walkthrough, here is only what +# differs for NVIDIA TensorRT: # -# The same flow applies to convolution workloads. Because the fused -# ``conv2d + relu`` pattern is registered after the standalone -# ``conv2d`` pattern in ``patterns.py`` (later entries have higher -# priority), both ops are offloaded as a single composite function. +# - **Partition in one call.** ``partition_for_tensorrt`` bundles the +# ``FuseOpsByPattern`` + ``MergeCompositeFunctions`` you ran by hand, using +# TensorRT's own pattern table. +# - **Weights become constants** (``bind_constants=True``): TensorRT bakes them +# into the engine it builds, so bind the parameters before partitioning. +# - **Real values.** TensorRT actually computes, so we build for CUDA, run on +# the GPU, and cross-check against a plain CPU build -- not just the shape. + +trt_mod = relax.transform.BindParams("main", {"weight": weight_np})(ConvReLU) +trt_mod = partition_for_tensorrt(trt_mod) +print("After partition_for_tensorrt:") +print(trt_mod) +###################################################################### +# Build for CUDA, run on the GPU, and compare against the CPU reference. 
-@tvm.script.ir_module -class Conv2dReLU: - @R.function - def main( - x: R.Tensor((1, 3, 32, 32), "float32"), - w: R.Tensor((16, 3, 3, 3), "float32"), - ) -> R.Tensor((1, 16, 30, 30), "float32"): - with R.dataflow(): - y = relax.op.nn.conv2d(x, w) - z = relax.op.nn.relu(y) - R.output(z) - return z +if has_tensorrt and has_cuda: + dev = tvm.cuda(0) + with tvm.transform.PassContext(opt_level=3): + trt_exec = relax.build(RunCodegen()(trt_mod), "cuda") + trt_out = relax.VirtualMachine(trt_exec, dev)["main"](tvm.runtime.tensor(data_np, dev)).numpy() + cpu_mod = relax.transform.LegalizeOps()( + relax.transform.BindParams("main", {"weight": weight_np})(ConvReLU) + ) + cpu_exec = relax.build(cpu_mod, "llvm") + cpu_out = relax.VirtualMachine(cpu_exec, tvm.cpu())["main"]( + tvm.runtime.tensor(data_np, tvm.cpu()) + ).numpy() -if has_example_npu: - mod2 = Conv2dReLU - mod2 = FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod2) - mod2 = MergeCompositeFunctions()(mod2) - mod2 = RunCodegen()(mod2) + np.testing.assert_allclose(trt_out, cpu_out, rtol=1e-2, atol=1e-2) + print("TensorRT output shape:", trt_out.shape, "- matches the CPU reference.") +else: + print("TensorRT/CUDA unavailable; skipping the GPU build and run.") + +###################################################################### +# A real backend also exposes knobs the stub does not. Setting ``use_fp16`` +# through the ``relax.ext.tensorrt.options`` config lets TensorRT pick FP16 +# kernels, trading a little accuracy for speed; nothing else about the flow +# changes. (Other options are environment-driven: ``TVM_TENSORRT_USE_INT8`` +# enables INT8 with calibration, ``TVM_TENSORRT_MAX_WORKSPACE_SIZE`` caps the +# build workspace, and ``TVM_TENSORRT_CACHE_DIR`` caches built engines to disk +# for reuse across runs.) 
+ +if has_tensorrt and has_cuda: + fp16_mod = partition_for_tensorrt( + relax.transform.BindParams("main", {"weight": weight_np})(ConvReLU) + ) + with tvm.transform.PassContext( + opt_level=3, config={"relax.ext.tensorrt.options": {"use_fp16": True}} + ): + fp16_exec = relax.build(RunCodegen()(fp16_mod), "cuda") + fp16_out = relax.VirtualMachine(fp16_exec, tvm.cuda(0))["main"]( + tvm.runtime.tensor(data_np, tvm.cuda(0)) + ).numpy() + + np.testing.assert_allclose(fp16_out, cpu_out, rtol=5e-2, atol=5e-2) + print("TensorRT FP16 output shape:", fp16_out.shape, "- matches within FP16 tolerance.") +else: + print("TensorRT/CUDA unavailable; skipping the FP16 build.") + +###################################################################### +# Example NPU vs TensorRT at a glance +# ----------------------------------- +# +# The same four-step flow, two backends: +# +# ========= ============================== ================================== +# Aspect Example NPU (teaching stub) TensorRT (real backend) +# ========= ============================== ================================== +# Runtime logs decisions, no compute builds and runs an nvinfer engine +# Output uninitialized (check shape) real values (cross-checked vs CPU) +# Weights ``bind_constants=False`` ``bind_constants=True`` (baked in) +# Partition two passes, by hand ``partition_for_tensorrt`` one call +# ========= ============================== ================================== +###################################################################### +# Deploying a PyTorch model with TensorRT +# --------------------------------------- +# +# Everything above used a hand-written ``IRModule`` so each op was visible. In +# practice you start from a trained model. 
This final section runs the *same* +# ``partition_for_tensorrt`` flow on a real PyTorch ``nn.Module``, end to end: +# export it, import it into Relax with the PyTorch frontend (the weights come in +# as constants -- exactly what TensorRT bakes into its engine), partition, build +# for CUDA, and check the GPU result against PyTorch's own output. Beyond the +# frontend import, the only difference is that the imported program returns its +# outputs as a tuple, so we index ``[0]`` for the single result tensor; the +# partition-build-run flow is otherwise unchanged. +# +# This section additionally requires PyTorch. + +try: + import torch + from torch import nn + + has_torch = True +except ImportError: + has_torch = False + +if has_torch and has_tensorrt and has_cuda: + from tvm.relax.frontend.torch import from_exported_program + + class SmallConvNet(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 8, 3) + self.conv2 = nn.Conv2d(8, 16, 3) + self.pool = nn.MaxPool2d(2) + + def forward(self, x): + x = torch.relu(self.conv1(x)) + x = self.pool(x) + x = torch.relu(self.conv2(x)) + return x + + torch_model = SmallConvNet().eval() + example_input = torch.randn(1, 3, 32, 32) + with torch.no_grad(): + torch_ref = torch_model(example_input).numpy() + exported = torch.export.export(torch_model, (example_input,)) + + torch_mod = from_exported_program(exported) + torch_mod = partition_for_tensorrt(torch_mod) + print("After importing and partitioning the PyTorch model:") + print(torch_mod) + + torch_dev = tvm.cuda(0) with tvm.transform.PassContext(opt_level=3): - built2 = relax.build(mod2, target) + torch_exec = relax.build(RunCodegen()(torch_mod), "cuda") + deployed = relax.VirtualMachine(torch_exec, torch_dev)["main"]( + tvm.runtime.tensor(example_input.numpy(), torch_dev) + )[0].numpy() - x2_np = np.random.randn(1, 3, 32, 32).astype("float32") - w2_np = np.random.randn(16, 3, 3, 3).astype("float32") + np.testing.assert_allclose(deployed, 
torch_ref, rtol=1e-2, atol=1e-2) + print("Deployed PyTorch model on TensorRT; output", deployed.shape, "matches PyTorch.") +else: + print("PyTorch / TensorRT / CUDA unavailable; skipping the deployment example.") - vm2 = relax.VirtualMachine(built2, tvm.cpu()) - result2 = vm2["main"]( - tvm.runtime.tensor(x2_np, tvm.cpu()), tvm.runtime.tensor(w2_np, tvm.cpu()) - ) - assert result2.numpy().shape == (1, 16, 30, 30) - print("Conv2dReLU output shape:", result2.numpy().shape) +###################################################################### +# Real deployment builds once and reuses the artifact. Export the compiled +# module to a shared library, then load and run it later -- in a fresh process, +# with no PyTorch and no rebuild needed. + +if has_torch and has_tensorrt and has_cuda: + with tempfile.TemporaryDirectory() as tmpdir: + lib_path = os.path.join(tmpdir, "deployed_trt.so") + torch_exec.export_library(lib_path) + loaded = tvm.runtime.load_module(lib_path) + reran = relax.VirtualMachine(loaded, torch_dev)["main"]( + tvm.runtime.tensor(example_input.numpy(), torch_dev) + )[0].numpy() + np.testing.assert_allclose(reran, torch_ref, rtol=1e-2, atol=1e-2) + print("Reloaded the exported library and reran; output", reran.shape, "still matches.") +else: + print("PyTorch / TensorRT / CUDA unavailable; skipping the export/reload step.") + +###################################################################### +# Notes for real deployments +# -------------------------- +# +# - **Operator coverage and fallback.** TensorRT offloads only the ops in its +# pattern table (see ``python/tvm/relax/backend/contrib/tensorrt.py``); +# anything unsupported simply stays on the host. Print the partitioned module +# and look for the ``Codegen: "tensorrt"`` functions to see what was offloaded. +# - **Dynamic shapes.** The builder sets up an optimization profile for a dynamic +# leading (batch) dimension, so the integration can serve a model exported with +# a symbolic batch size. 
+# - **Engine build cost.** Building a TensorRT engine is slow the first time (it +# is not a hang). Set ``TVM_TENSORRT_CACHE_DIR`` to cache built engines to +# disk and skip the rebuild on later runs. ###################################################################### # Next steps # ---------- # -# To build a real NPU backend using this example as a starting point: +# To build your own backend using the example NPU as a starting point: # -# - Replace ``example_npu_runtime.cc`` with your hardware SDK calls. +# - Replace the stub runtime in +# ``src/runtime/extra/contrib/example_npu/example_npu_runtime.cc`` with your +# hardware SDK calls. # - Extend ``patterns.py`` with the ops your hardware supports. -# - Add a C++ codegen under ``src/relax/backend/contrib/`` if your -# hardware requires a non-JSON serialization format. -# - Add your cmake module under ``cmake/modules/contrib/`` following -# the pattern in ``cmake/modules/contrib/ExampleNPU.cmake``. +# - Add a C++ codegen under ``src/relax/backend/contrib/`` if your backend needs +# a non-JSON serialization format. +# - Add a CMake module under ``cmake/modules/contrib/`` following +# ``ExampleNPU.cmake``. +# +# For a complete real-backend implementation to study, see the TensorRT +# integration: the pattern table and ``partition_for_tensorrt`` in +# ``python/tvm/relax/backend/contrib/tensorrt.py``, the codegen in +# ``src/relax/backend/contrib/tensorrt/``, and the runtime in +# ``src/runtime/extra/contrib/tensorrt/``. diff --git a/python/tvm/relax/backend/contrib/example_npu/README.md b/python/tvm/relax/backend/contrib/example_npu/README.md index 7e88a0ece00e..310670fab105 100644 --- a/python/tvm/relax/backend/contrib/example_npu/README.md +++ b/python/tvm/relax/backend/contrib/example_npu/README.md @@ -168,7 +168,7 @@ in the TVM build: - `__init__.py` - Registers the backend and its BYOC entry points with TVM so the compiler can discover and use the example NPU. 
### Runtime Implementation -- `src/runtime/contrib/example_npu/example_npu_runtime.cc` - C++ runtime implementation that handles JSON-based graph execution for the NPU backend. +- `src/runtime/extra/contrib/example_npu/example_npu_runtime.cc` - C++ runtime implementation that handles JSON-based graph execution for the NPU backend. ### Tests and Examples - `tests/python/contrib/test_example_npu.py` - Comprehensive test suite containing example IRModules (e.g. `MatmulReLU`, `Conv2dReLU`) and demonstrating the complete BYOC flow from pattern registration to runtime execution. @@ -230,7 +230,7 @@ This shows the registered patterns and that matched subgraphs were turned into c - **Power management**: Support for different power modes (high_performance, balanced, low_power) ### Pattern Matching Features -- **Memory constraint checking**: Validates tensor sizes against NPU memory limits +- **Memory constraint hooks**: Placeholder checks where a real backend would reject tensors that exceed on-chip memory; the example accepts every tensor - **Fusion opportunities**: Identifies conv+activation and other beneficial fusions - **Layout preferences**: NHWC channel-last layouts preferred by NPUs diff --git a/src/runtime/extra/contrib/example_npu/example_npu_runtime.cc b/src/runtime/extra/contrib/example_npu/example_npu_runtime.cc index 0408a3fe9acd..a0f1d0970a0f 100644 --- a/src/runtime/extra/contrib/example_npu/example_npu_runtime.cc +++ b/src/runtime/extra/contrib/example_npu/example_npu_runtime.cc @@ -18,7 +18,7 @@ */ /*! - * \file src/runtime/contrib/example_npu/example_npu_runtime.cc + * \file src/runtime/extra/contrib/example_npu/example_npu_runtime.cc * \brief Example NPU runtime demonstrating architectural concepts * * This runtime demonstrates key NPU architectural patterns: