diff --git a/.coveragerc b/.coveragerc index 62db6f887c..c82a90c324 100644 --- a/.coveragerc +++ b/.coveragerc @@ -19,14 +19,16 @@ exclude_lines = if False: if __name__ == .__main__.: pass + if TYPE_CHECKING: + if typing.TYPE_CHECKING: omit = # Omit files that cannot be tested dace/jupyter.py # Omit deprecated files - dace/frontend/tensorflow/__init__.py - dace/frontend/tensorflow/tensorflow.py - dace/frontend/tensorflow/winograd.py - dace/frontend/tensorflow/transformations/__init__.py - dace/frontend/tensorflow/transformations/redundant_array.py + dace/frontend/ml/tensorflow/__init__.py + dace/frontend/ml/tensorflow/tensorflow.py + dace/frontend/ml/tensorflow/winograd.py + dace/frontend/ml/tensorflow/transformations/__init__.py + dace/frontend/ml/tensorflow/transformations/redundant_array.py diff --git a/.github/workflows/copilot-setup-steps.yml b/.github/workflows/copilot-setup-steps.yml index fe60e4867c..c3d1a4088e 100644 --- a/.github/workflows/copilot-setup-steps.yml +++ b/.github/workflows/copilot-setup-steps.yml @@ -35,6 +35,6 @@ jobs: - name: Install DaCe in development mode run: | - python -m pip install --editable ".[testing,linting]" + python -m pip install --editable ".[testing,linting,ml]" pre-commit install pre-commit run diff --git a/.github/workflows/dace-updater.yml b/.github/workflows/dace-updater.yml new file mode 100644 index 0000000000..4f22121381 --- /dev/null +++ b/.github/workflows/dace-updater.yml @@ -0,0 +1,50 @@ +name: Inform the Python package index about a new DaCe release. + +on: + # Trigger for all pushes to tags matching this pattern + push: + tags: + - __gt4py-next-integration_* + + # To "install" this workflow you must enable this trigger, such that the workflow runs at least one. + # You should also disable any processing such that no commit in the index repo is performed. + # See https://stackoverflow.com/a/71057825 + #pull_request: + + # Allows to trigger the update manually. + # NOTE: Is only possible if the workflow file is located on the default and the branch where it should run on. + workflow_dispatch: + +jobs: + update-dace: + runs-on: ubuntu-latest + steps: + - name: Inform Index + shell: bash + run: | + INDEX_ORGANIZATION="gridtools" + INDEX_REPO="pypi" + + # We are using `github.sha` here to be sure that we transmit an identifier to the index + # that can be checked out. Before we used `github.ref_name` but got strange results + # with it. + DEPENDENCY_REF="${{ github.sha }}" + SOURCE_REPO="dace" + SOURCE_OWNER="gridtools" + + curl -L -v --fail-with-body \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.PKG_UPDATE_TOKEN }}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + "https://api.github.com/repos/${INDEX_ORGANIZATION}/${INDEX_REPO}/dispatches" \ + -d '{"event_type":"update_package_index","client_payload":{"source_repo":"'"${SOURCE_REPO}"'","source_org":"'"${SOURCE_OWNER}"'","dependency_ref":"'"${DEPENDENCY_REF}"'"}}' + CURL_RET=$? + + if [ "${CURL_RET}" -ne 0 ] + then + echo "POST to '${INDEX_ORGANIZATION}:${INDEX_REPO}' failed with error code '${CURL_RET}'" + exit 1 + fi + + exit 0 diff --git a/.github/workflows/fpga-ci.yml b/.github/workflows/fpga-ci.yml deleted file mode 100644 index 926d4c69e9..0000000000 --- a/.github/workflows/fpga-ci.yml +++ /dev/null @@ -1,75 +0,0 @@ -name: FPGA Tests - -on: - push: - branches: [ main, ci-fix ] - pull_request: - branches: [ main, ci-fix ] - merge_group: - branches: [ main, ci-fix ] - -env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - -concurrency: - group: ${{github.workflow}}-${{github.ref}} - cancel-in-progress: true - -jobs: - test-fpga: - if: ${{ !contains(github.event.pull_request.labels.*.name, 'no-ci') }} - runs-on: [self-hosted, linux, intel-fpga, xilinx-fpga] - steps: - - uses: actions/checkout@v4 - with: - submodules: 'recursive' - - name: Install dependencies - run: | - rm -f ~/.dace.conf - rm -rf .dacecache tests/.dacecache - python -m venv ~/.venv # create venv so we can use pip - source ~/.venv/bin/activate # activate venv - python -m pip install --upgrade pip - pip install pytest-xdist flake8 coverage click - pip uninstall -y dace - pip install -e ".[testing]" - curl -Os https://uploader.codecov.io/latest/linux/codecov - chmod +x codecov - - - name: Run FPGA Tests - run: | - source ~/.venv/bin/activate # activate venv - export COVERAGE_RCFILE=`pwd`/.coveragerc - - # Xilinx setup - export PATH=/opt/Xilinx/Vitis/2022.1/bin:/opt/Xilinx/Vitis_HLS/2022.1/bin:/opt/Xilinx/Vivado/2022.1/bin:$PATH - export XILINX_XRT=/opt/xilinx/xrt - export LD_LIBRARY_PATH=$XILINX_XRT/lib:$LD_LIBRARY_PATH - export XILINX_VITIS=/opt/Xilinx/Vitis/2022.1 - export DACE_compiler_xilinx_platform=xilinx_u250_gen3x16_xdma_4_1_202210_1 - - # Intel FPGA setup - export INTELFPGAOCLSDKROOT=/opt/intelFPGA_pro/19.1/hld - export ALTERAOCLSDKROOT=$INTELFPGAOCLSDKROOT - export AOCL_BOARD_PACKAGE_ROOT=/opt/intelFPGA_pro/19.1/hld/board/a10_ref - export PATH=$INTELFPGAOCLSDKROOT/bin:$PATH - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$AOCL_BOARD_PACKAGE_ROOT/linux64/lib - export QUARTUS_ROOTDIR_OVERRIDE=/opt/intelFPGA_pro/19.1/quartus - export LD_PRELOAD=/lib/x86_64-linux-gnu/libstdc++.so.6 # Work around dependency issues - - # Due to an internal bug in the Xilinx tools, where the current datetime is passed as an integer - # and overflowed in the year 2022, run the FPGA tests pretending like it's January 1st 2021. - # faketime -f "@2021-01-01 00:00:00" pytest -n auto --cov-report=xml --cov=dace --tb=short -m "fpga" - # Try running without faketime - pytest -n auto --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=300 -m "fpga" - - coverage report - coverage xml - reachable=0 - ping -W 2 -c 1 codecov.io || reachable=$? - if [ $reachable -eq 0 ]; then - ./codecov - else - echo "Codecov.io is unreachable" - fi - killall -9 xsim xsimk || true diff --git a/.github/workflows/general-ci.yml b/.github/workflows/general-ci.yml index 1d9dc3fa79..534cf26cfc 100644 --- a/.github/workflows/general-ci.yml +++ b/.github/workflows/general-ci.yml @@ -59,7 +59,7 @@ jobs: else export DACE_optimizer_automatic_simplification=${{ matrix.simplify }} fi - pytest -n auto --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=300 -m "not gpu and not verilator and not tensorflow and not mkl and not sve and not papi and not mlir and not lapack and not fpga and not mpi and not rtl_hardware and not scalapack and not datainstrument and not long and not sequential" + pytest -n auto --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=300 -m "not gpu and not autodiff and not torch and not onnx and not tensorflow and not mkl and not sve and not papi and not mlir and not lapack and not mpi and not scalapack and not datainstrument and not long and not sequential" ./codecov - name: Test OpenBLAS LAPACK diff --git a/.github/workflows/gpu-ci.yml b/.github/workflows/gpu-ci.yml index a7a28d7d91..68cfabfa16 100644 --- a/.github/workflows/gpu-ci.yml +++ b/.github/workflows/gpu-ci.yml @@ -37,7 +37,7 @@ jobs: pip install mpi4py pip install cupy pip uninstall -y dace - pip install -e ".[testing]" + pip install -e ".[testing,ml]" curl -Os https://uploader.codecov.io/latest/linux/codecov chmod +x codecov diff --git a/.github/workflows/hardware_test.yml b/.github/workflows/hardware_test.yml deleted file mode 100644 index 59dc201e4b..0000000000 --- a/.github/workflows/hardware_test.yml +++ /dev/null @@ -1,25 +0,0 @@ -name: DaCe RTL hardware emulation -on: workflow_dispatch -jobs: - test-rtl: - runs-on: [self-hosted, linux, xilinx-fpga] - steps: - - uses: actions/checkout@v4 - with: - submodules: 'recursive' - - name: Install dependencies - run: | - rm -f ~/.dace.conf - rm -rf .dacecache tests/.dacecache - . /opt/setupenv - python -m pip install --upgrade pip - pip install pytest-xdist flake8 - pip uninstall -y dace - pip install -e ".[testing]" - - - name: Run FPGA Tests - run: | - # Due to an internal bug in the Xilinx tools, where the current datetime is passed as an integer - # and overflowed in the year 2022, run the RTL FPGA tests pretending like it's January 1st 2021. - faketime -f "@2021-01-01 00:00:00" pytest -n auto --tb=short -m "rtl_hardware" - killall -9 xsim xsimk || true diff --git a/.github/workflows/heterogeneous-ci.yml b/.github/workflows/heterogeneous-ci.yml index 53a8788dce..baa3134ef8 100644 --- a/.github/workflows/heterogeneous-ci.yml +++ b/.github/workflows/heterogeneous-ci.yml @@ -48,7 +48,7 @@ jobs: run: | source ~/.venv/bin/activate # activate venv export DACE_cache=unique - pytest --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=300 -m "verilator or mkl or papi or datainstrument" + pytest --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=300 -m "mkl or papi or datainstrument" - name: Run MPI tests run: | diff --git a/.github/workflows/ml-ci.yml b/.github/workflows/ml-ci.yml new file mode 100644 index 0000000000..d6ffc83a5d --- /dev/null +++ b/.github/workflows/ml-ci.yml @@ -0,0 +1,62 @@ +name: Machine Learning and Autodiff Tests + +on: + push: + branches: [ main, ci-fix ] + pull_request: + branches: [ main, ci-fix ] + merge_group: + branches: [ main, ci-fix ] + +concurrency: + group: ${{github.workflow}}-${{github.ref}} + cancel-in-progress: true + +jobs: + test: + if: "!contains(github.event.pull_request.labels.*.name, 'no-ci')" + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.13'] + simplify: [0,1,autoopt] + + steps: + - uses: actions/checkout@v4 + with: + submodules: 'recursive' + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y libyaml-dev cmake + sudo apt-get install -y libblas-dev libopenblas-dev liblapacke-dev + python -m pip install --upgrade pip + pip install flake8 pytest-xdist coverage + pip install -e ".[ml-testing,ml]" + curl -Os https://uploader.codecov.io/latest/linux/codecov + chmod +x codecov + + - name: Test with pytest + run: | + export NOSTATUSBAR=1 + export DACE_testing_serialization=1 + export DACE_testing_deserialize_exception=1 + export DACE_cache=unique + if [ "${{ matrix.simplify }}" = "autoopt" ]; then + export DACE_optimizer_automatic_simplification=1 + export DACE_optimizer_autooptimize=1 + echo "Auto-optimization heuristics" + else + export DACE_optimizer_automatic_simplification=${{ matrix.simplify }} + fi + pytest --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=600 -v -m "(torch or onnx or autodiff) and not gpu" + ./codecov + + - uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true diff --git a/.github/workflows/verilator_compatibility.yml b/.github/workflows/verilator_compatibility.yml deleted file mode 100644 index dce0c9b1fb..0000000000 --- a/.github/workflows/verilator_compatibility.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: DaCe Verilator Compatibility Check -on: - workflow_dispatch: - inputs: - reason: - description: 'Reason for the trigger' - required: true - default: 'Check compatibility' - schedule: - - cron: '0 0 1 * *' # monthly -jobs: - build: - strategy: - matrix: - verilator_version: ['4.028', '4.034', '4.036', '4.100', 'master'] - runs-on: ubuntu-20.04 - steps: - - name: trigger reason - run: echo "Trigger Reason:" ${{ github.event.inputs.reason }} - - uses: actions/checkout@v4 - - name: checkout submodules - run: git submodule update --init --recursive - - name: install apt packages - run: sudo apt-get update && sudo apt-get -y install git make autoconf g++ flex bison libfl2 libfl-dev - - name: compile verilator - run: git clone https://github.com/verilator/verilator.git && cd verilator && git fetch origin && if [ ! "${{ matrix.verilator_version }}" == "master" ]; then git checkout v${{ matrix.verilator_version }}; fi && autoconf && ./configure && make -j2 && sudo make install - - uses: actions/setup-python@v5 - with: - python-version: '3.8' - architecture: 'x64' - - uses: BSFishy/pip-action@v1 - with: - packages: pytest - requirements: requirements.txt - - name: install dace - run: python3 -m pip install . - - run: pytest -m "verilator" diff --git a/.gitignore b/.gitignore index 5a54e2df44..14ea244bcd 100644 --- a/.gitignore +++ b/.gitignore @@ -150,12 +150,6 @@ perf.json perf*.csv /dace/frontend/octave/parsetab.py -# Xilinx -xilinx_vcu1525_* -sdaccel_profile_* -sdaccel_timeline_* -.run/ - # NVIDIA *.nvprof out.sdfg @@ -195,3 +189,8 @@ _build/ # Ignoring the test junk _all_tests/ + + +# Ignore downloaded ONNX models +/*.onnx +/*.bin diff --git a/.gitmodules b/.gitmodules index bc68bc3441..85d0edabec 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,12 +5,6 @@ [submodule "dace/external/moodycamel"] path = dace/external/moodycamel url = https://github.com/cameron314/concurrentqueue.git -[submodule "dace/external/hlslib"] - path = dace/external/hlslib - url = https://github.com/definelicht/hlslib.git [submodule "dace/viewer/webclient"] path = dace/viewer/webclient url = https://github.com/spcl/dace-webclient.git -[submodule "dace/external/rtllib"] - path = dace/external/rtllib - url = https://github.com/carljohnsen/rtllib.git diff --git a/MANIFEST.in b/MANIFEST.in index a24e07bf0f..f495d9c9a5 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,12 +1,8 @@ global-include LICENSE LICENSE.* -include dace/*.yml dace/codegen/CMakeLists.txt dace/codegen/tools/*.cpp dace/external/moodycamel/*.h dace/codegen/Xilinx_HLS.tcl.in dace/viewer/webclient/*.css dace/viewer/webclient/*.html dace/viewer/webclient/dist/*.js +include dace/*.yml dace/codegen/CMakeLists.txt dace/codegen/tools/*.cpp dace/external/moodycamel/*.h dace/viewer/webclient/*.css dace/viewer/webclient/*.html dace/viewer/webclient/dist/*.js recursive-include dace/codegen *.cmake graft dace/runtime/include graft dace/libraries graft dace/viewer/webclient/external_lib graft dace/viewer/templates graft dace/external/cub/cub -graft dace/external/hlslib/cmake -graft dace/external/hlslib/include -graft dace/external/rtllib/cmake -graft dace/external/rtllib/templates diff --git a/README.md b/README.md index 5f530c3d01..d69c818594 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ [![General Tests](https://github.com/spcl/dace/actions/workflows/general-ci.yml/badge.svg)](https://github.com/spcl/dace/actions/workflows/general-ci.yml) [![GPU Tests](https://github.com/spcl/dace/actions/workflows/gpu-ci.yml/badge.svg)](https://github.com/spcl/dace/actions/workflows/gpu-ci.yml) -[![FPGA Tests](https://github.com/spcl/dace/actions/workflows/fpga-ci.yml/badge.svg)](https://github.com/spcl/dace/actions/workflows/fpga-ci.yml) [![Documentation Status](https://readthedocs.org/projects/spcldace/badge/?version=latest)](https://spcldace.readthedocs.io/en/latest/?badge=latest) [![PyPI version](https://badge.fury.io/py/dace.svg)](https://badge.fury.io/py/dace) [![codecov](https://codecov.io/gh/spcl/dace/branch/main/graph/badge.svg)](https://codecov.io/gh/spcl/dace) @@ -13,7 +12,7 @@ _Decoupling domain science from performance optimization._ DaCe is a [fast](https://nbviewer.org/github/spcl/dace/blob/main/tutorials/benchmarking.ipynb) parallel programming framework that takes code in Python/NumPy and other programming languages, and maps it to high-performance -**CPU, GPU, and FPGA** programs, which can be optimized to achieve state-of-the-art. Internally, DaCe +**CPU, GPU, and [FPGA](https://github.com/spcl/dace-fpga)** programs, which can be optimized to achieve state-of-the-art. Internally, DaCe uses the Stateful DataFlow multiGraph (SDFG) *data-centric intermediate representation*: A transformable, interactive representation of code based on data movement. diff --git a/codecov.yml b/codecov.yml index 1f7e594398..49fbd61acd 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,6 +1,8 @@ ignore: - "dace/jupyter.py" # Omit files that cannot be tested - - "dace/frontend/tensorflow/**/*" # Omit deprecated files + - "dace/frontend/ml/tensorflow/**/*" # Omit deprecated files + - "samples/**/*" + - "tests/**/*" coverage: range: 40..90 @@ -18,6 +20,6 @@ coverage: codecov: notify: - after_n_builds: 18 + after_n_builds: 23 comment: false diff --git a/dace/__init__.py b/dace/__init__.py index 823abb9111..98a44bd217 100644 --- a/dace/__init__.py +++ b/dace/__init__.py @@ -35,6 +35,17 @@ sys.path.insert(0, __external_transformations_path__) +# Lazy loading for ml module to avoid eager TensorFlow/PyTorch imports +def __getattr__(name): + if name == 'ml': + import importlib + ml_module = importlib.import_module('.ml', package='dace') + # Cache the module to avoid re-importing + globals()['ml'] = ml_module + return ml_module + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + + # Hack that enables using @dace as a decorator # See https://stackoverflow.com/a/48100440/6489142 class DaceModule(sys.modules[__name__].__class__): diff --git a/dace/autodiff/__init__.py b/dace/autodiff/__init__.py new file mode 100644 index 0000000000..2e0e9bf746 --- /dev/null +++ b/dace/autodiff/__init__.py @@ -0,0 +1,58 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +DaCe Automatic Differentiation (AD) System. + +This module provides reverse-mode automatic differentiation for DaCe programs, +enabling automatic computation of gradients for optimized numerical kernels. + +Main Components +--------------- +- **add_backward_pass**: Main entry point for adding backward pass to an SDFG +- **BackwardPassGenerator**: Core algorithm for generating backward passes +- **BackwardImplementation**: ABC for implementing operation-specific backward rules +- **BackwardContext**: Context information for backward pass generation +- **BackwardResult**: Result of backward pass generation with forward/backward SDFGs +- **AutoDiffException**: Base exception for autodiff errors + +Key Features +------------ +- Support for control flow (loops, conditionals) +- Data forwarding strategies (store vs recompute tradeoffs) +- Extensible backward implementations for library nodes +- Integration with PyTorch autograd +- Automatic memory management for intermediate values + + +""" + +from .base_abc import BackwardImplementation, BackwardContext, BackwardResult, AutoDiffException +from .backward_pass_generator import BackwardPassGenerator +from .autodiff import add_backward_pass + +try: + from .torch import make_backward_function + TORCH_INTEGRATION_AVAILABLE = True +except ImportError: + make_backward_function = None + TORCH_INTEGRATION_AVAILABLE = False + +import sys +from . import library + +__all__ = [ + # Main API + "add_backward_pass", + # Core classes + "BackwardPassGenerator", + "BackwardContext", + "BackwardResult", + # Extension points + "BackwardImplementation", + # Exceptions + "AutoDiffException", + # Submodules + "library", +] + +if TORCH_INTEGRATION_AVAILABLE: + __all__.append("make_backward_function") diff --git a/dace/autodiff/analysis.py b/dace/autodiff/analysis.py new file mode 100644 index 0000000000..224f0db9f8 --- /dev/null +++ b/dace/autodiff/analysis.py @@ -0,0 +1,103 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Analysis helpers for autodiff +""" +from typing import Dict, Set, Tuple, Optional +import collections + +import networkx as nx + +from dace.sdfg import SDFG, SDFGState, nodes, utils as sdfg_utils +from dace.transformation.passes import analysis +from dace.sdfg.state import FunctionCallRegion + +AccessSets = Dict[SDFGState, Tuple[Set[str], Set[str]]] + + +def dependency_analysis(sdfg: SDFG) -> Dict[str, Set[str]]: + """ + Analyze read dependencies of arrays in the SDFG. + + :param sdfg: SDFG to analyze + :return: A dictionary mapping array names to a list of read dependencies. + """ + + # FIXME can be made more efficient + dependencies = nx.DiGraph() + for sdfg_node in sdfg.nodes(): + if isinstance(sdfg_node, SDFGState): + for node in sdfg_node.data_nodes(): + for edge in sdfg_node.edge_bfs(node, reverse=True): + dependencies.add_edge(node.data, edge.data.data) + elif isinstance(sdfg_node, FunctionCallRegion): + for state in sdfg_node.nodes(): + assert isinstance(state, SDFGState) + for node in state.data_nodes(): + for edge in state.edge_bfs(node, reverse=True): + dependencies.add_edge(node.data, edge.data.data) + + dependencies = nx.transitive_closure(dependencies) + result = {} + for array in dependencies: + result[array] = {nbr for nbr in dependencies.neighbors(array)} + return result + + +def inverse_reachability(sdfg: SDFG) -> Dict[SDFGState, Set[SDFGState]]: + + reachability = analysis.StateReachability().apply_pass(sdfg, {}) + inverse_reachability = collections.defaultdict(set) + # iterate over cfg_ids + for cfg_id in reachability.keys(): + for pred, successors in reachability[cfg_id].items(): + for successor in successors: + inverse_reachability[successor].add(pred) + + return inverse_reachability + + +def is_previously_written(sdfg: SDFG, + state: SDFGState, + node: nodes.Node, + array_name: str, + access_sets: Optional[AccessSets] = None) -> bool: + """ + Determine whether the given array name was written before the current node. + + :param sdfg: the sdfg containing the node + :param state: the state containing the node + :param node: the node to check + :param array_name: the array name to check + :return: True if the array was written before the node, False otherwise. + """ + + if access_sets is None: + access_sets = analysis.AccessSets().apply_pass(sdfg, {}) + + reachable = inverse_reachability(sdfg) + + # Check the current state + for subgraph in sdfg_utils.concurrent_subgraphs(state): + if node in subgraph.nodes(): + # Get all the access nodes in the subgraph to the same data + for other_node in subgraph.data_nodes(): + if other_node != node and other_node.data == array_name: + # Check if this is a write node + for in_edge in subgraph.in_edges(other_node): + if in_edge.data.data == array_name: + # Check if there's a path to our node, + # since we only care about writes that happen before the current node + if nx.has_path(state.nx, other_node, node): + return True + else: + # This is not our current subgraph, check the write states + _, write_set = subgraph.read_and_write_sets() + if array_name in write_set: + return True + + # Check other states + for predecessor in reachable[state]: + _, write_set = access_sets[predecessor] + if array_name in write_set: + return True + return False diff --git a/dace/autodiff/autodiff.md b/dace/autodiff/autodiff.md new file mode 100644 index 0000000000..cdec31941c --- /dev/null +++ b/dace/autodiff/autodiff.md @@ -0,0 +1,821 @@ +Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +# DaCe Automatic Differentiation (AD) System - Design Document + +## Table of Contents + +1. [Introduction](#1-introduction) +2. [Directory Structure](#2-directory-structure) +3. [Core Components](#3-core-components) +4. [Data Forwarding System](#4-data-forwarding-system) +5. [Backward Implementations](#5-backward-implementations) +6. [Library Integration](#6-library-integration) +7. [PyTorch Integration](#7-pytorch-integration) +8. [Gradient Accumulation and Clearing](#8-gradient-accumulation-and-clearing) + +--- + +## 1. Introduction + +### 1.1 Purpose + +The DaCe Automatic Differentiation (AD) module provides **reverse-mode automatic differentiation** for Stateful DataFlow Graphs (SDFGs). It enables automatic computation of gradients for optimized numerical kernels, making it possible to differentiate DaCe programs for machine learning, optimization, and scientific computing applications. + +### 1.2 Reverse-Mode AD Fundamentals + +Reverse-mode automatic differentiation (also known as backpropagation) computes gradients by: + +1. **Forward Pass**: Execute the original computation and record intermediate values +2. **Backward Pass**: Traverse the computation graph in reverse, accumulating gradients using the chain rule + +For a function `f: ℝⁿ → ℝᵐ`, reverse-mode AD efficiently computes the gradient when `m << n` (common in ML where loss is scalar). + +**Example**: For a composite function `y = f(g(h(x)))`: +- **Forward pass**: Compute and store intermediate values: + - `a = h(x)` + - `b = g(a)` + - `y = f(b)` +- **Backward pass**: Apply chain rule in reverse order. Given seed gradient `∂L/∂y`: + - Compute `∂L/∂b = ∂L/∂y · (∂f/∂b)` where `(∂f/∂b)` is evaluated at the stored value of `b` + - Compute `∂L/∂a = ∂L/∂b · (∂g/∂a)` where `(∂g/∂a)` is evaluated at the stored value of `a` + - Compute `∂L/∂x = ∂L/∂a · (∂h/∂x)` where `(∂h/∂x)` is evaluated at the input value `x` + +### 1.3 Key Features + +- **Control Flow Support**: Handles loops (`LoopRegion`) and conditionals +- **Data Forwarding Strategies**: Flexible tradeoff between memory (store intermediates) and computation (recompute on demand) +- **Extensible Backward Implementations**: Registry-based system for adding backward rules for new operations +- **ONNX Integration**: Differentiate ONNX neural network models imported into DaCe +- **PyTorch Compatibility**: Integration with PyTorch's autograd system via `torch.autograd.Function` +- **Library Node Support**: Backward implementations for DaCe standard library (BLAS, reductions, etc.) +- **Nested SDFG Differentiation**: Recursive backward pass generation for nested SDFGs + +### 1.4 Use Cases + +1. **Machine Learning Training**: Compute gradients for neural network parameters +2. **Sensitivity Analysis**: Determine how outputs change with respect to inputs +3. **Optimization**: Gradient-based optimization of physical simulations +4. **Inverse Problems**: Solve inverse problems by differentiating forward models +5. **Scientific Computing**: Adjoint methods for PDEs and large-scale simulations + +### 1.5 Component Interaction Flow + +``` +Input: Forward SDFG + Output Arrays + Input Arrays + ▼ +1. add_backward_pass() - Entry point + • Validate SDFG + • Simplify (optional) + • Inline control flow (conditionals) + ▼ +2. BackwardPassGenerator.__init__() + • Convert AccessNodes/strings to internal format + • Initialize mappings (reverse_map, array_grad_map, etc.) + • Set data forwarding strategy + ▼ +3. BackwardPassGenerator.backward() + • Reverse states in topological order + • For each state: + a. Reverse nodes (AccessNode, Tasklet, Map, etc.) + b. Find backward implementation via registry + c. Call implementation.backward() + d. Connect gradients with WCR + ▼ +4. DataForwardingManager.forward_data_to_backward_pass() + • Identify intermediates needed in backward pass + • Check if overwritten + • Apply strategy (store or recompute) + ▼ +5. Simplify and validate (optional) + ▼ +Output: Backward SDFG with gradients computed +``` + +--- + +## 2. Directory Structure + +### 2.1 File Organization + +``` +dace/autodiff/ +├── __init__.py # Main API exports +│ └── Exports: add_backward_pass, BackwardPassGenerator, +│ BackwardImplementation, AutoDiffException, etc. +│ +├── autodiff.py # Entry point +│ └── add_backward_pass() - High-level API +│ +├── base_abc.py # Abstract base classes +│ ├── BackwardImplementation (ABC) +│ ├── BackwardContext (dataclass) +│ ├── BackwardResult (dataclass) +│ ├── AutoDiffException +│ └── find_backward_implementation() +| +├── backward_pass_generator.py # Core AD engine +│ └── BackwardPassGenerator class - Main differentiation algorithm +│ +├── analysis.py # SDFG analysis +│ ├── dependency_analysis() +│ ├── inverse_reachability() +│ └── is_previously_written() +│ +├── utils.py # Utility functions +│ ├── Descriptor management +│ ├── Symbolic differentiation +│ ├── Graph traversal +│ └── Loop analysis +│ +├── torch.py # PyTorch integration +│ └── make_backward_function() - Convert ONNX to PyTorch differentiable +│ +├── data_forwarding/ # Store or recompute strategies +│ ├── __init__.py # Package exports +│ ├── manager.py # Strategy coordinator +│ │ └── DataForwardingManager +│ ├── store.py # Store strategy +│ │ └── resolve_overwrite_with_store() +│ └── recompute.py # Recompute strategy +│ └── resolve_overwrite_with_recomputation() +│ get_recomputation_nsdfg() +│ +├── implementations/ # Backward rules for node types +│ ├── __init__.py # Package exports (46 lines) +│ ├── dace_nodes.py # Pure SDFG elements (487 lines) +│ │ └── DaceNodeBackwardImplementations +│ │ ├── _reverse_AccessNode() +│ │ ├── _reverse_Tasklet() +│ │ ├── _reverse_MapEntry() +│ │ ├── _reverse_MapExit() +│ │ └── _reverse_NestedSDFG() +│ ├── dace_reduction_nodes.py # Reduction operations (307 lines) +│ │ ├── ReverseReduce +│ │ └── ... (reduction backward implementations) +│ ├── onnx_ops.py # ONNX operations (1045 lines) +│ │ ├── ONNXConvBackward +│ │ ├── ONNXMatMulBackward +│ │ └── ... (50+ ONNX ops) +│ └── pytorch_ops.py # PyTorch operations (128 lines) +│ └── Depthwise convolution backward pass +│ +└── library/ # Library integrations + ├── __init__.py # Package exports (31 lines) + ├── library.py # BackwardPass node (286 lines) + │ ├── ParameterArray (data descriptor) + │ ├── BackwardPass (LibraryNode) + │ └── ExpandBackwardPass (expansion) + └── torch_integration.py # PyTorch hooks (39 lines) +``` + + +## 3. Core Components + +### 3.1 Entry Point: `autodiff.py` + +**Location**: [autodiff.py](autodiff.py) + +The main entry point for users to add backward passes to SDFGs. + + +#### 3.1.1 Workflow + +``` +┌─────────────────────┐ +│ 1. Validate SDFG │ +└──────────┬──────────┘ + ▼ + ┌───────────────┐ + │ 2. Simplify │ + └──────┬────────┘ + ▼ +┌─────────────────────────────────┐ +│ 3. Inline Control Flow │ +│ (conditionals, not loops) │ +└──────────┬──────────────────────┘ + ▼ +┌─────────────────────────────────────┐ +│ 4. Create Backward SDFG │ +│ (if separate_sdfgs flag is True) │ +└──────────┬──────────────────────────┘ + ▼ +┌─────────────────────────────────┐ +│ 5. Initialize BackwardPass- │ +│ Generator │ +└──────────┬──────────────────────┘ + ▼ +┌─────────────────────────────────┐ +│ 6. generator.backward() │ +│ (main differentiation) │ +└──────────┬──────────────────────┘ + ▼ +┌─────────────────────┐ +│ 7. Validate SDFG │ +└──────────┬──────────┘ + ▼ + ┌───────────────┐ + │ 8. Simplify │ + └──────┬────────┘ + ▼ +┌─────────────────────┐ +│ 9. Return SDFG │ +└─────────────────────┘ +``` + +#### 3.1.2 Key Constraints + +- **Supported Nodes**: + - Maps, AccessNodes, Tasklets, LoopRegions, ControlFlowRegions (inlined into state machine) + - Reductions (Sum, Min, Max) + - ONNXOps (with registered backward implementations) + - NestedSDFGs + +--- + +### 3.2 BackwardPassGenerator: The Core AD Engine + +**Location**: [backward_pass_generator.py](backward_pass_generator.py) + +The `BackwardPassGenerator` class is the core of the AD system. It orchestrates the entire backward pass generation process. + +#### 3.2.1 Key Data Structures + +The generator maintains several mappings and data structures: + +- **Configuration**: + - `sdfg`: Forward SDFG + - `backward_sdfg`: Backward SDFG (can be same or separate) + - `given_gradients_data`: Output arrays (seed gradients provided) + - `required_gradients_data`: Input arrays (gradients to compute) + - `data_forwarding_strategy`: "store_all", "recompute_all", "user_defined" + +- **Generated Mappings**: + - `reverse_map: Dict[Node, Node]`: Forward node → backward node + - `reversed_states_map: Dict[SDFGState, SDFGState]`: Forward state → backward state + - `array_grad_map: Dict[str, str]`: Array name → gradient array name + - `result_map: Dict[Node, BackwardResult]`: Forward node → BackwardResult + +- **Analysis Results**: + - `read_only_arrays`: Arrays never written to + - `backward_grad_arrays`: Gradient array descriptors + - `backward_input_arrays`: Forward values needed in backward pass + - `data_to_forward`: List of data to forward from forward to backward + +#### 3.2.2 Main Algorithm: `backward()` + +**Steps**: + +1. **Initialize gradient arrays** for all required outputs +2. **Compute state order** (topological sort of SDFG states) +3. **Extract the Critical Computation Subgraph (CCS) of each state** +4. **Reverse the CCS of states** in reverse topological order: + - Create backward state + - Reverse nodes within CCS of the state + - Connect gradients between reversed nodes +5. **Reverse loop regions** by generating loop regions in the backward pass +6. **Handle data forwarding** (store or recompute intermediates) +7. **Create interstate edges** to reverse control flow and connect all reversed components +8. **Return** backward result with gradient mappings + +#### 3.2.3 State Reversal + +For each forward state, the generator: + +1. Creates a corresponding backward state +2. For each node in the CCS of the state: + - Finds appropriate backward implementation from registry + - Determines given/required gradients + - Calls `implementation.backward()` + - Stores mapping and result +3. Connects gradients between reversed nodes + +--- + +### 3.3 Abstract Base Classes: `base_abc.py` + +**Location**: [base_abc.py](base_abc.py) + +#### 3.3.1 BackwardImplementation (ABC) + +The abstract base class for all backward implementations. + +```python```python +@dace.registry.make_registry +class BackwardImplementation(abc.ABC): + + @staticmethod + def backward_can_be_applied(node: nd.Node, state: SDFGState, + sdfg: SDFG) -> bool: + """Check if this implementation can be applied to the node.""" + return True + + @staticmethod + @abc.abstractmethod + def backward( + forward_node: nd.Node, + context: BackwardContext, + given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]] + ) -> Tuple[nd.Node, BackwardResult]: + """Generate the backward pass for this node.""" + ... +``` + +**Registration Example**: + +```python + +# For ONNX operations +@dace.registry.autoregister_params(op="MatMul", name="pure") +class MatMulBackward(BackwardImplementation): + ... +``` + +#### 3.3.2 BackwardContext (Dataclass) + +Contains all context information needed by backward implementations: + +```python +@dataclasses.dataclass +class BackwardContext: + forward_sdfg: SDFG # The forward SDFG + backward_sdfg: SDFG # The backward SDFG + backward_generator: BackwardPassGenerator # The generator (for utilities) +``` + +#### 3.3.3 BackwardResult (Dataclass) + +Returns information about the generated backward node: + +```python +@dataclasses.dataclass +class BackwardResult: + """Result of differentiating a node.""" + + # Mapping from forward input connector → gradient connector name + required_grad_names: Dict[Optional[str], Optional[str]] + + # Mapping from forward output connector → gradient connector name + given_grad_names: Dict[Optional[str], Optional[str]] + + # Which gradients should be zero-initialized + zero_init: Dict[Optional[str], Optional[bool]] +``` + +#### 3.3.4 find_backward_implementation() + +Looks up the registered backward implementation for a node by: + +1. Querying `BackwardImplementation.extensions()` registry +2. Filtering by `node_type` (for DaCe nodes) or `op` (for ONNX) +3. Checking `backward_can_be_applied()` for each candidate +4. Returning first valid implementation + +--- + +### 3.4 Analysis Utilities: `analysis.py` + +**Location**: [analysis.py](analysis.py) + +Provides SDFG analysis functions used by the AD engine: + +#### 3.4.1 dependency_analysis() + +Computes transitive read dependencies for each array. For example, if `C = A + B`, then `dependencies["C"] = {"A", "B"}`. Uses graph traversal and transitive closure to build a complete dependency map. + +#### 3.4.2 inverse_reachability() + +For each state, computes the set of predecessor states that can reach it. Uses DaCe's `StateReachability` analysis pass. + +#### 3.4.3 is_previously_written() + +Determines if an array was written before a given node in a state. Used by data forwarding to determine if an intermediate value needs to be stored (because it will be overwritten). + +Checks both: +1. Current state (concurrent subgraphs) +2. Predecessor states + +--- + +### 3.5 Utility Functions: `utils.py` + +**Location**: [utils.py](utils.py) + +The `utils.py` module contains many helper functions organized into categories: + +#### 3.5.1 Descriptor Management + +- `add_backward_desc()`: Add gradient array descriptor to backward SDFG +- `add_backward_desc_for_connector()`: Add backward descriptor for specific connector +- Helper functions for managing array descriptors and data types + +#### 3.5.2 Symbolic Differentiation + +- `differentiate_tasklet()`: Symbolically differentiates tasklet code using AST parsing and SymPy +- Converts tasklet expressions to symbolic form, computes derivatives, and generates backward code + +#### 3.5.3 Graph Traversal + +- `get_all_path_edges()`: Gets all edges on paths from source to target +- `concurrent_subgraphs()`: Finds concurrent execution regions in a state +- Helper functions for navigating SDFG structures + +#### 3.5.4 Loop Analysis + +- `state_within_loop()`: Checks if a state is inside a loop region +- `get_loop_carried_dependencies()`: Finds arrays with loop-carried dependencies +- Loop-specific helper functions + +--- + +## 4. Data Forwarding System + +### 4.1 The Core Problem + +During backward pass generation, we often need access to intermediate values from the forward pass: + +**Example**: +```python +# Forward +y = sigmoid(x) +z = y * y +L = z # Identity loss function + +# Backward (to compute dL/dx) +dL/dy = dL/dz * 2y # Need y from forward pass! +dL/dx = dL/dy * y * (1 - y) # Need y again! +``` + +**Two strategies**: +1. **Store**: Save `y` during forward pass, load it during backward + - **Pro**: Fast backward pass (no recomputation) + - **Con**: High memory usage +2. **Recompute**: Recompute `y = sigmoid(x)` during the backward pass + - **Pro**: Low memory usage (no storage required) + - **Con**: Slower backward pass due to recomputation cost + +--- + +### 4.2 DataForwardingManager: `manager.py` + +**Location**: [data_forwarding/manager.py](data_forwarding/manager.py) + +Coordinates the data forwarding strategy. + +#### 4.2.1 Strategy Selection + +The manager provides three strategies: + +1. **`store_all`** (default): Store all intermediate values + - Fastest backward pass + - Highest memory usage + - Best for memory-rich environments + +2. **`recompute_all`**: Recompute all intermediate values + - Experimental feature to test recomputation capabilities + +3. **`user_defined`**: User specifies which arrays to recompute + - Balanced approach + - Requires domain knowledge + - Allows fine-grained control + +#### 4.2.2 Main Algorithm + +For each data item that needs to be forwarded: + +1. Determine if the data is overwritten before backward pass needs it +2. If overwritten, choose resolution strategy (store or recompute) +3. Apply strategy: + - **Store**: Create copy before overwrite, load in backward + - **Recompute**: Extract computation subgraph, inline in backward + +#### 4.2.3 Overwrite Detection Algorithm + +**Problem**: Determine if an intermediate value is overwritten before the backward pass needs it + +**Algorithm**: +``` +is_overwritten(array, state, node): + 1. Check if array is written in concurrent subgraphs + 2. Check if array is written in successor states + 3. If either is true, the array is overwritten + 4. Apply data forwarding strategy (store or recompute) +``` + +**Uses**: `is_previously_written()` from `analysis.py` + +--- + +### 4.3 Store Strategy: `store.py` + +**Location**: [data_forwarding/store.py](data_forwarding/store.py) + +**Key Function**: `resolve_overwrite_with_store()` + +**Approach**: + +``` +Forward Pass State Backward Pass State +┌──────────────┐ ┌──────────────┐ +│ Compute x │ │ │ +│ Store x_copy │ │ Load x_copy │ +│ Overwrite x │ │ Use in grad │ +└──────────────┘ └──────────────┘ +``` + +**Steps**: +1. Create a storage descriptor for the intermediate value +2. Add a copy operation in the forward state (before overwrite) +3. Add a load operation in the backward state (when needed) +4. Update memlets to use the stored copy + +--- + +### 4.4 Recompute Strategy: `recompute.py` (Experimental!) + +**Location**: [data_forwarding/recompute.py](data_forwarding/recompute.py) + +**Key Function**: `resolve_overwrite_with_recomputation()` + +**Approach**: + +``` +Forward Pass State Backward Pass State +┌──────────────┐ ┌──────────────┐ +│ Compute x │ │ Recompute x │ +│ │ → │ Use in grad │ +│ Overwrite x │ │ │ +└──────────────┘ └──────────────┘ +``` + +**Steps**: +1. Extract the computation subgraph that produces the value +2. Create a nested SDFG containing the recomputation logic +3. Inline the nested SDFG in the backward state +4. Connect inputs and outputs appropriately + +**Subgraph Extraction** (`get_recomputation_nsdfg()`): +- Performs backward breadth-first search (BFS) from the data node to find all dependencies +- Copies nodes and edges into a new nested SDFG +- Handles map scopes and connectors +- Ensures all dependencies are included + +--- + +## 5. Backward Implementations + +### 5.1 DaCe Core Nodes: `dace_nodes.py` + +**Location**: [implementations/dace_nodes.py](implementations/dace_nodes.py) + +Implements backward passes for core SDFG elements. + +#### 5.1.1 AccessNode + +**Purpose**: Create gradient AccessNode + +**Approach**: +- Forward: `AccessNode("x")` +- Backward: `AccessNode("grad_x")` with `setzero=True` + +Also handles view connectors for arrays with views or subsets. + +#### 5.1.2 Tasklet + +**Purpose**: Symbolically differentiate tasklet code + +**Approach**: +1. Parse tasklet code to AST +2. Extract output expressions +3. Use SymPy to compute symbolic derivatives +4. Generate backward code: `grad_input = grad_output * derivative` + +**Example**: +- Forward: `y = x * x` +- Backward: `grad_x = grad_y * 2 * x` + +#### 5.1.3 Maps + +**Purpose**: Reverse map structure + +Maps are special: `MapEntry` and `MapExit` nodes are swapped in the backward pass. + +**Forward**: +``` +AccessNode → MapEntry → [Tasklet in scope] → MapExit → AccessNode +``` + +**Backward**: +``` +AccessNode → MapEntry (reversed) → [Tasklet_grad in scope] → MapExit (reversed) → AccessNode +``` + +**Approach**: +- `MapEntry` → `MapExit` in backward pass +- `MapExit` → `MapEntry` in backward pass +- Connectors inverted: `IN_X` ↔ `OUT_X` +- Same map object used for both + +#### 5.1.4 NestedSDFG + +**Purpose**: Recursively differentiate nested SDFGs + +**Approach**: +1. Recursively call `add_backward_pass()` on nested SDFG +2. Map forward connectors to backward connectors +3. Handle symbols and interstate edges +4. Ensure proper gradient flow through nested boundaries + +#### 5.1.5 LoopRegions + +**Purpose**: Reverse loops in the forward SDFG + +**Approach**: +Loops are reversed by creating a backward loop that iterates in the reverse direction to process gradients. + + +``` +# Forward loop: +for i in range(N): + y[i+1] = f(x[i]) + +# Backward loop: +for i in reversed(range(N)): + grad_x[i] = grad_f(x[i]) * grad_y[i+1] +``` + +--- + +### 5.2 DaCe Reduction Nodes: `dace_reduction_nodes.py` + +**Location**: [implementations/dace_reduction_nodes.py](implementations/dace_reduction_nodes.py) + +Implements backward passes for DaCe reduction operations (307 lines). + +#### 5.2.1 Key Implementations + +| Operation | Backward Implementation | Notes | +|-----------|------------------------|-------| +| **Reduce (Sum)** | Broadcast gradient to match input shape | Handles axis reduction | +| **Reduce (Max/Min)** | Gradient flows only to max/min elements | Requires forward values | + +--- + +### 5.3 ONNX Operations: `onnx_ops.py` + +**Location**: [implementations/onnx_ops.py](implementations/onnx_ops.py) + +Implements backward passes for 50+ ONNX operations. Each implementation follows the ONNX operator specification for gradient computation. + +**Categories**: + +- **Element-wise**: Add, Sub, Mul, Div, Sqrt, Exp, Log, Pow, etc. +- **Activation**: Relu, Sigmoid, Tanh, Softmax, etc. +- **Matrix**: MatMul, Gemm, BatchMatMul +- **Convolution**: Conv, ConvTranspose +- **Pooling**: MaxPool, AveragePool, GlobalAveragePool +- **Normalization**: BatchNormalization, LayerNormalization +- **Reduction**: ReduceSum, ReduceMean, ReduceMax, etc. +- **Shape**: Reshape, Transpose, Concat, Split, Squeeze, Unsqueeze +- **Advanced**: Gather, Scatter, Einsum, etc. + +Each ONNX backward implementation is registered with `@dace.registry.autoregister_params(op="OpName")`. + +--- + +### 5.4 PyTorch Operations: `pytorch_ops.py` + +**Location**: [implementations/pytorch_ops.py](implementations/pytorch_ops.py) + +Implements backward passes using PyTorch's optimized CUDA kernels (128 lines). + +#### 5.4.1 Key Implementations + +| Operation | Backward Implementation | Notes | +|-----------|------------------------|-------| +| **Conv (depthwise)** | `PyTorchConvBackward` | Uses `at::thnn_conv_depthwise2d_backward_out` | + +This implementation leverages PyTorch's C++ ATen library for GPU-accelerated depthwise convolution backward passes. + +--- + +## 6. Library Integration + +### 6.1 BackwardPass Library Node: `library.py` + +**Location**: [library/library.py](library/library.py) + +Provides a library node for encapsulating backward passes as reusable components. + +#### 6.1.1 ParameterArray + +A special data descriptor for gradient accumulation buffers that mimics PyTorch Parameters. + +#### 6.1.2 BackwardPass + +A library node that wraps a backward pass SDFG, allowing backward passes to be composed and reused like other library operations. + +#### 6.1.3 ExpandBackwardPass + +Expands the `BackwardPass` library node into the full SDFG. Handles: +- Gradient initialization (zero or provided seed) +- Parameter gradient accumulation + +--- + +## 7. PyTorch Integration + +### 7.1 Overview: `torch.py` + +**Location**: [torch.py](torch.py) + +Enables the integration between DaCe AD and PyTorch's autograd system. + +### 7.2 make_backward_function() + +**Purpose**: Convert ONNX model to PyTorch-differentiable function + +**Signature**: +```python +def make_backward_function( + forward_sdfg: SDFG, + inputs: List[str], + outputs: List[str], + parameters: Optional[List[str]] = None +) -> Type[torch.autograd.Function]: +``` + +**Returns**: PyTorch `autograd.Function` subclass with: +- `forward()`: Compiles and runs forward SDFG +- `backward()`: Compiles and runs backward SDFG +- Handles PyTorch tensor ↔ DaCe array conversion +- Supports scalar inputs/outputs +- Manages parameter gradients + +### 7.3 Integration Flow + +``` +PyTorch Model + ↓ +DaCe ONNX Import + ↓ +Forward SDFG + ↓ +add_backward_pass() + ↓ +Backward SDFG + ↓ +make_backward_function() + ↓ +torch.autograd.Function + ↓ +Use in PyTorch training loop +``` + +--- + +## 8. Gradient Accumulation and Clearing + +### 8.1 Gradient Accumulation + +**Problem**: Multiple paths can contribute to same gradient + +**Example**: +``` + ┌─→ y1 ─┐ + x ──┤ ├─→ z + └─→ y2 ─┘ +``` + +Both `y1` and `y2` contribute to `grad_x`. + +**Solution**: Write-Conflict Resolution (WCR) + +When connecting gradients, use WCR on memlets: +```python +memlet.wcr = "lambda a, b: a + b" +``` + +This ensures multiple gradient contributions are summed correctly. + +### 8.2 Gradient Clearing + +**Problem**: Overwritten arrays in the forward pass require clearing the gradients of the corresponding gradient arrays to allow the always-accumulate solution presented above. + +**When to Clear Gradients**: +- In the backward pass, at the corresponding point where arrays in the forward pass were overwritten. + +**Implementation Strategies**: + +1. **Zero Initialization for all intermediate arrays**: Set all gradient arrays to zero before backward pass + ```python + # In DaCe, gradient arrays can be initialized with setzero=True + grad_array = AccessNode("grad_x", setzero=True) + ``` + +2. **Manual Clearing**: Explicitly zero out gradient arrays if necessary + ```python + # Reset gradients if an overwrite is detected in dace/autodiff/backward_pass_generator.py + self._zero_out_gradient(forward_state=forward_state, + forward_node=node, + memlet=edge.data) + ``` diff --git a/dace/autodiff/autodiff.py b/dace/autodiff/autodiff.py new file mode 100644 index 0000000000..cfd686c224 --- /dev/null +++ b/dace/autodiff/autodiff.py @@ -0,0 +1,83 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from typing import List, Union, Optional + +from dace.autodiff.backward_pass_generator import BackwardPassGenerator + +from dace.sdfg import SDFG, nodes +from dace.sdfg.utils import inline_control_flow_regions +from dace.sdfg.state import LoopRegion + + +def add_backward_pass(sdfg: SDFG, + outputs: List[Union[nodes.AccessNode, str]], + inputs: List[Union[nodes.AccessNode, str]], + data_forwarding_strategy: str = "store_all", + data_to_recompute: Optional[List[str]] = None, + simplify: bool = True, + separate_sdfgs: bool = False) -> Optional[SDFG]: + """ Experimental: Add a backward pass to `state` using reverse-mode automatic differentiation. + + ``inputs``, ``outputs`` and ``grads`` can be provided either as ``AccessNode`` nodes, or as ``str``, in which + case the graph will be searched for exactly one matching ``AccessNode`` with data matching the ``str``. + + The SDFG may contain the following nodes: + + * Maps + * AccessNodes + * Reductions (Sum, Min, Max) + * ONNXOps + * Multiple states + * LoopRegions + * NestedSDFGs (subject to the same constraints) + + When differentiating an :class:`~dace.libraries.onnx.nodes.onnx_op.ONNXOp`, the ONNXBackward registry will be checked + for any matching backward pass implementations. If none are found, the ONNXForward registry will be checked for + matching pure implementations. If one is found, symbolic differentiation of the pure implementation will be + attempted. If this fails, or no pure forward implementation is found, the method will fail. + + .. note:: + This function modifies the input SDFG in-place. Even if ``separate_sdfgs`` is ``True``, modifications + such as storing intermediate results and inlining ControlFlowRegions can be applied to the original SDFG. + + :param sdfg: the SDFG to add the backward pass to. + :param outputs: the forward pass outputs of the function to differentiate. + :param inputs: the inputs w.r.t. which the gradient will be returned. + :param data_forwarding_strategy: strategy for forwarding data to the backward pass. Could be one of: + * "store_all": store all intermediate data (default, uses most memory, fastest). + * "recompute_all": recompute all intermediate data. + * "user_defined": store all intermediates except for ones specified in `data_to_recompute`. + :param data_to_recompute: list of data arrays to recompute instead of storing. Only used if + `data_forwarding_strategy` is "user_defined". + :param simplify: whether to apply the simplify pass to the forward and backward SDFGs. + :param separate_sdfgs: whether to create a separate SDFG for the backward pass. + :return: the backward SDFG if separate_sdfgs is True, the original SDFG (which now also contains the backward pass) otherwise. + """ + # Validate the SDFG + sdfg.validate() + + if simplify: + sdfg.simplify() + + # Inline conditional blocks but keep loops + inline_control_flow_regions(sdfg, ignore_region_types=[LoopRegion]) + + if separate_sdfgs: + backward_sdfg = SDFG(sdfg.name + "_backward") + else: + backward_sdfg = sdfg + + # Add backward pass + gen = BackwardPassGenerator(sdfg=sdfg, + given_gradients=outputs, + required_gradients=inputs, + backward_sdfg=backward_sdfg, + data_forwarding_strategy=data_forwarding_strategy, + data_to_recompute=data_to_recompute) + gen.backward() + sdfg.validate() + + if simplify: + sdfg.simplify() + sdfg.validate() + + return backward_sdfg diff --git a/dace/autodiff/backward_pass_generator.py b/dace/autodiff/backward_pass_generator.py new file mode 100644 index 0000000000..6857749643 --- /dev/null +++ b/dace/autodiff/backward_pass_generator.py @@ -0,0 +1,2056 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import copy +from typing import List, Tuple, Set, Dict, Union, Optional, Sequence +import sympy as sp + +# DaCe imports +import dace +from dace.properties import CodeBlock +import dace.sdfg.nodes as nodes +import dace.transformation.transformation as xf +from dace import dtypes, data as dt +from dace.sdfg import SDFG, SDFGState, state as dstate, utils as dace_utils +from dace.sdfg.state import LoopRegion +from dace.memlet import Memlet + +try: + from dace.libraries.onnx.forward_implementation_abc import ONNXForward + from dace.libraries.onnx.nodes.onnx_op import ONNXOp + ONNX_AVAILABLE = True +except ImportError: + ONNXForward = None + ONNXOp = None + ONNX_AVAILABLE = False + +# Autodiff imports +from dace.autodiff.base_abc import (BackwardContext, BackwardResult, AutoDiffException, find_backward_implementation, + ExpansionTemplate) +import dace.autodiff.utils as ad_utils +from dace.autodiff.implementations.dace_nodes import DaceNodeBackwardImplementations +from dace.autodiff.data_forwarding.manager import DataForwardingManager + + +class BackwardPassGenerator: + """Generator for automatic differentiation backward passes on DaCe SDFGs. + + This class orchestrates the creation of backward passes for automatic differentiation + using reverse-mode AD. It handles gradient computation, data forwarding between + forward and backward passes, and complex control flow structures. + + :param sdfg: The forward SDFG to differentiate. + :param given_gradients: Output arrays for which gradients are provided (seed gradients). + :param required_gradients: Input arrays for which gradients should be computed. + :param backward_sdfg: SDFG to contain the backward pass. Can be same as forward SDFG. + :param array_grad_map: Optional mapping from array names to gradient array names. + :param conflicted_gradient_buffers: Arrays with potential write conflicts requiring special handling. + :param data_forwarding_strategy: Strategy for forwarding data ('store_all', 'recompute_all', 'user_defined'). + :param data_to_recompute: Arrays to recompute instead of storing (when strategy='user_defined'). + :raises AutoDiffException: If the backward pass generation fails. + + Example:: + + gen = BackwardPassGenerator( + sdfg=forward_sdfg, + given_gradients=['loss'], + required_gradients=['weights', 'input'] + ) + gen.backward() + """ + + def __init__( + self, + *, + sdfg: SDFG, + given_gradients: Sequence[Union[nodes.AccessNode, str]], + required_gradients: Sequence[Union[nodes.AccessNode, str]], + backward_sdfg: SDFG, # This can be the same as sdfg + array_grad_map: Optional[Dict[str, str]] = None, + conflicted_gradient_buffers: Optional[Set[str]] = None, + data_forwarding_strategy: str = "store_all", + data_to_recompute: Optional[List[str]] = None, + ): + + self.sdfg: SDFG = sdfg + self.data_to_recompute = data_to_recompute + self.backward_sdfg: SDFG = backward_sdfg + + given_gradients = [ + n if isinstance(n, nodes.AccessNode) else self._str_to_access(n, "outputs") for n in given_gradients + ] + required_gradients = [ + n if isinstance(n, nodes.AccessNode) else self._str_to_access(n, "inputs") for n in required_gradients + ] + required_gradients = [n for n in required_gradients if n is not None] + + self.given_gradients_data = {n.data for n in given_gradients} + self.required_gradients_data = {n.data for n in required_gradients} + + self.input_names = {n.data for n in required_gradients} + self.output_names = {n.data for n in given_gradients} + + #: Arrays descriptors for the gradients + self.backward_grad_arrays: Dict[str, dt.Array] = {} + + #: Arrays descriptors for inputs that are required from the forward pass + self.backward_input_arrays: Dict[str, dt.Array] = {} + + #: Mapping from forward node -> backward node, and forward map -> backward map + self.reverse_map: Dict[nodes.Node, Union[nodes.Node, nodes.Map]] = {} + + #: Mapping from forward state -> backward state + self.reversed_states_map: Dict[SDFGState, SDFGState] = {} + + #: Mapping from forward LoopRegion -> backward LoopRegion + self.reversed_loops_map: Dict[LoopRegion, LoopRegion] = {} + + #: Mapping from forward state -> backward state for loop states + self.reversed_loop_states_map: Dict[nodes.Node, nodes.Node] = {} + + #: Mapping between states and their subgraph views for AD processing + self.states_view_map: Dict[SDFGState, dstate.StateSubgraphView] = {} + + #: Mapping between loop states and their subgraph views for AD processing + self.loop_states_view_map: Dict[SDFGState, dstate.StateSubgraphView] = {} + + #: Mapping between the map entry of a conditional assignment block and its zero-out AN + self.conditional_block_entry: Dict[nodes.MapEntry, nodes.AccessNode] = {} + + #: Mapping from forward_node -> BackwardResult for that node + self.result_map: Dict[nodes.Node, BackwardResult] = {} + + #: Mapping from forward name to gradient name for arrays + self.array_grad_map: Dict[str, str] = array_grad_map or {} + + #: Mapping from the backward access nodes that will be zeroed out + # to the transients that contain the values before they are zeroed out + self.zeroed_out: Dict[nodes.AccessNode, List[nodes.AccessNode]] = {} + + #: The read-only arrays of the forward SDFG. Used in data forwarding decisions + self.read_only_arrays: Set[str] = ad_utils.get_read_only_arrays(self.sdfg) + + #: Mapping from overwritten input name to storing AccessNode + self.stored_inputs: Dict[str, nodes.AccessNode] = {} + + # Variable to check if backward has already been applied + self._applied = False + + self.data_forwarding_strategy = data_forwarding_strategy + + # Topological ordering of the states + self.state_order = ad_utils.get_state_topological_order(self.sdfg) + self.conflicted_gradient_buffers: Set[str] = conflicted_gradient_buffers or set() + + self.interstate_symbols: Dict[str, str] = {} + for edge in self.sdfg.all_interstate_edges(): + for assign_symbol, assignment in edge.data.assignments.items(): + self.interstate_symbols[assign_symbol] = assignment + + # Validate parameters and setup SDFG configuration + self._validate_gradients() + self._setup_sdfg_configuration(sdfg, backward_sdfg, given_gradients) + + # DaCe nodes backward implementations + self.dace_node_impl = DaceNodeBackwardImplementations(self) + + #: List containing information about all the data to be forwarded to the backward pass + self.data_to_forward: List[Tuple[SDFGState, SDFGState, nodes.AccessNode, nodes.Node, + dstate.MultiConnectorEdge]] = [] + + # Data forwarding manager + self.data_forwarding_manager = DataForwardingManager(self) + + def _validate_gradients(self) -> None: + """Validate that gradient arrays exist in the SDFG. + + Raises: + AutoDiffException: If gradient arrays are not found in SDFG arrays. + """ + # Check outputs (given gradients) + for outp in self.given_gradients_data: + if outp not in self.sdfg.arrays: + raise AutoDiffException(f"Could not find output '{outp}' in SDFG array descriptors") + + # Check inputs (required gradients) + for inp in self.required_gradients_data: + if inp not in self.sdfg.arrays: + raise AutoDiffException(f"Could not find input '{inp}' in SDFG array descriptors") + + def _setup_sdfg_configuration(self, sdfg: SDFG, backward_sdfg: SDFG, + given_gradients: List[nodes.AccessNode]) -> None: + """Setup SDFG configuration for separate or combined forward/backward passes. + + :param sdfg: Forward SDFG. + :param backward_sdfg: Backward SDFG. + :param given_gradients: List of gradient output nodes. + :raises AutoDiffException: If configuration is invalid for combined SDFG mode. + """ + if sdfg is backward_sdfg: + # Combined mode requires single scalar output + if len(given_gradients) != 1: + raise AutoDiffException("When forward and backward SDFGs are the same, exactly one output is required, " + f"got {len(given_gradients)}") + + output_array = sdfg.arrays[given_gradients[0].data] + if not ad_utils.is_int_eq_value(output_array.total_size, 1): + raise AutoDiffException("When forward and backward SDFGs are the same, output must be a single scalar") + + self.separate_sdfgs = False + else: + self.separate_sdfgs = True + + def create_child_generator(self, **kwargs) -> 'BackwardPassGenerator': + """Create a child generator for nested SDFG differentiation. + + This factory method creates a new BackwardPassGenerator instance for differentiating + nested SDFGs, propagating relevant configuration from the parent generator. + + :param kwargs: Parameters to pass to the child generator constructor. + Required: sdfg, given_gradients, required_gradients, backward_sdfg. + :return: A new BackwardPassGenerator instance configured for the nested SDFG. + """ + defaults = { + 'data_forwarding_strategy': self.data_forwarding_strategy, + 'data_to_recompute': self.data_to_recompute, + } + defaults.update(kwargs) + return BackwardPassGenerator(**defaults) + + def backward(self) -> Tuple[BackwardResult, Dict[str, dt.Array], Dict[str, dt.Array]]: + """Generate the backward pass in backward_sdfg.""" + return self.reverse_sdfg() + + def reverse_sdfg(self) -> Tuple[BackwardResult, Dict[str, dt.Array], Dict[str, dt.Array]]: + """Generate the backward pass by reversing all SDFG states. + + Processes all states in the SDFG and creates their backward counterparts, + connecting them with appropriate control flow for gradient computation. + + :return: A tuple containing: + + * ``BackwardResult`` - Contains gradient mappings and metadata. + * ``Dict[str, dt.Array]`` - Gradient array descriptors (backward pass outputs). + * ``Dict[str, dt.Array]`` - Forward pass arrays required by backward pass. + + :raises AutoDiffException: If backward pass was already applied to this generator. + """ + + if self._applied: + raise AutoDiffException("Backward may only be called once. Instantiate a new BackwardPassGenerator.") + + # Create state views mapping and expand all the SDFG nodes + self._create_stateviews_mapping() + + # Reverse each state in the graph + self._reverse_states() + + # Connect the new reversed states to the other states correctly + self._connect_reversed_states() + + # Fill the interstate edges with the correct conditions + self._fill_interstate_edge_conditions() + + # Add interstate assignments for control flow decisions + self._add_interstate_edge_assignments() + + # Forward required data by the backward pass according to a user defined strategy + self.data_forwarding_manager.forward_data_to_backward_pass() + + # In some cases (accessnode -> accessnode), the descriptors for the gradients of the function outputs are not + # added yet. Add them now. + for given_grad in sorted(self.given_gradients_data): + if self.array_grad_name(given_grad) not in self.backward_sdfg.arrays: + self._add_gradient_data_descriptor(given_grad) + + # Prepare the output + required_grad_names = {name: self.array_grad_name(name) for name in self.required_gradients_data} + given_grad_names = {name: self.array_grad_name(name) for name in self.given_gradients_data} + + # Set mapping from gradient name to whether it should be zeroed out on initialization + zero_init: Dict[str, bool] = {} + for node, bres in self.result_map.items(): + forward_state = self._get_node_state(node=node) + for zname, zinit in bres.zero_init.items(): + # Reverse lookup + cname = next(k for k, v in bres.required_grad_names.items() if v == zname) + + for e in forward_state.in_edges_by_connector(node, cname): + zero_init[e.data.data] = zinit + for e in forward_state.out_edges_by_connector(node, cname): + zero_init[e.data.data] = zinit + + self._applied = True + result = BackwardResult(required_grad_names=required_grad_names, + given_grad_names=given_grad_names, + zero_init=zero_init) + return result, self.backward_grad_arrays, self.backward_input_arrays + + def _create_stateviews_mapping(self) -> None: + """Map each state in the SDFG to views that indicate what to differentiate.""" + self._find_subgraph_to_differentiate() + # Expand until there is nothing left to expand + while self._expand_nodes(): + # Nodes have been expanded again on the expanded graph; recalculate the forward graph + self._find_subgraph_to_differentiate() + + def _reverse_states(self) -> None: + """Go through all states of the forward SDFG, reverse them and add them to the backward SDFG.""" + # For reversal we want to iterate through the states in reverse topological order + for state in reversed(self.state_order): + # Get all the views of this state + if state not in self.states_view_map: + raise AutoDiffException(f"State {state} not found in states view map") + state_subgraph_views = [self.states_view_map[state]] + + # In case this is a state loop + state_subgraph_loop_view = [] + if state in self.loop_states_view_map: + loop_view = self.loop_states_view_map[state] + state_subgraph_loop_view.append(loop_view) + + for state_subgraph_view in state_subgraph_views: + + # Make sure this state has not already been reversed + if state in self.reversed_states_map: + raise AutoDiffException(f"State {state} has already been reversed") + + # Create the new reversed state label + if state_subgraph_view in state_subgraph_loop_view: + reversed_state_label = f"{state.label}_loop_reversed" if state.label else None + else: + reversed_state_label = f"{state.label}_reversed" if state.label else None + + # Create new state for reversal + # At the moment we add all states to the backward_sdfg directly + # This will later be modified when connecting the states + reversed_state = self.backward_sdfg.add_state(label=reversed_state_label) + + # Add the new state to the reversed map dict + if state_subgraph_view in state_subgraph_loop_view: + self.reversed_loop_states_map[state] = reversed_state + else: + self.reversed_states_map[state] = reversed_state + + # Check that all edges are float, int, or boolean + ad_utils.check_edges_type_in_state(state_subgraph_view) + + # Recursively reverse the subgraph + self._reverse_subgraph(forward_state=state, backward_state=reversed_state, subgraph=state_subgraph_view) + + # We also reverse all the LoopRegions in the graph + for node in self.sdfg.nodes(): + if not isinstance(node, LoopRegion): + continue + self._reverse_loop_region(node) + + def _connect_reversed_states(self) -> None: + """Connect backward states corresponding to forward SDFG states. + + All incoming edges of a forward state become outgoing edges in the backward SDFG. + """ + + for state in self.state_order: + # All states should be reversed already + if state not in self.reversed_states_map: + raise AutoDiffException(f"State {state} not found in reversed states map") + backward_state = self.reversed_states_map[state] + + # Get all the out edges of the forward state + parent_graph = state.parent_graph + state_out_edges = parent_graph.out_edges(state) + + # If there are no outgoing connections + if len(state_out_edges) == 0: + # This is an end-state and it needs to be connected to its reversed state + # we do this only if the backward sdfg is the same as the forward one + if parent_graph == self.sdfg and not self.separate_sdfgs: + self.backward_sdfg.add_edge(src=state, dst=backward_state, data=dace.InterstateEdge()) + + # Get all the in connections of the forward state + forward_state_in_edges = parent_graph.in_edges(state) + + # Get the backward state again + # We need to do this in case the state is linked to an initialization state + # For outgoing edges, we connect the actual state not its initialization + backward_state = self.reversed_states_map[state] + + for edge in forward_state_in_edges: + # Each incoming edge to a forward state will add an outgoing edge to a backward state + fwd_src = edge.src + if isinstance(fwd_src, SDFGState): + bwd_src = self.reversed_states_map[fwd_src] + elif isinstance(fwd_src, LoopRegion): + bwd_src = self.reversed_loops_map[fwd_src] + + graph = bwd_src.parent_graph + graph.add_edge(src=backward_state, dst=bwd_src, data=dace.InterstateEdge()) + + # Connect all the loops + for loop in self.reversed_loops_map.keys(): + + # Get the loop parent + parent_graph = loop.parent_graph + + # Get the reversed loop + reversed_loop = self.reversed_loops_map[loop] + + # Get all the out edges of the forward state + loop_out_edges = parent_graph.out_edges(loop) + + # If there are no outgoing connections + if len(loop_out_edges) == 0: + # This is an end-region and it needs to be connected to its reversed region + # We do this only if the backward sdfg is the same as the forward one + if parent_graph == self.sdfg and not self.separate_sdfgs: + self.backward_sdfg.add_edge(src=state, dst=backward_state, data=dace.InterstateEdge()) + + # Get all the in edges + loop_in_edges = parent_graph.in_edges(loop) + + for edge in loop_in_edges: + + # A loop region could be connected to a state or another loop region + fwd_src = edge.src + if isinstance(fwd_src, SDFGState): + bwd_src = self.reversed_states_map[fwd_src] + elif isinstance(fwd_src, LoopRegion): + bwd_src = self.reversed_loops_map[fwd_src] + + # Get the graph to add the edge to + if isinstance(parent_graph, LoopRegion): + bwd_parent_graph = self.reversed_loops_map[parent_graph] + else: + bwd_parent_graph = self.backward_sdfg + + bwd_parent_graph.add_edge(src=reversed_loop, dst=bwd_src, data=dace.InterstateEdge()) + + def _fill_interstate_edge_conditions_in_scope(self, graph: Union[SDFG, LoopRegion]) -> None: + """ + Get all the nodes within this graph in topological order, + Connect the states and call the function recursively on the nested scopes. + """ + # A dictionary that keeps track of the conditions necessary to reach a state in the forward passs + conditions_map: dict[SDFGState, str] = {} + + # Iterate through all the nodes in topological order + nodes = dace_utils.dfs_topological_sort(graph, graph.source_nodes()) + for node in nodes: + # A list of the conditions on all the in edges for this state + in_edges_conditions: List[str] = [] + if isinstance(node, SDFG) or isinstance(node, LoopRegion): + # if this is not a reversed loop region + if not node in self.reversed_loops_map: + continue + self._fill_interstate_edge_conditions_in_scope(node) + else: + + if not isinstance(node, SDFGState): + raise AutoDiffException(f"Expected SDFGState, got {type(node)}") + forward_state = node + parent_graph = forward_state.parent_graph + + # if this is not a reversed state + if node not in self.reversed_states_map: + continue + + # We will iterate through all the incoming edges to the forward state + edges_list = parent_graph.in_edges(forward_state) + + # If there are none, this is a start state + # If there is only one incoming edge, no condition necessary + if len(edges_list) < 2: + conditions_map[forward_state] = "1" + + for edge in edges_list: + # Get the src state + src_state = edge.src + + # Get the condition to get to the source state in the forward pass + src_state_condition = conditions_map[src_state] + + # Add the condition in the current edge + current_edge_condition = edge.data.condition.as_string + + # New backward edge condition + # Handle "1" (unconditional) to avoid creating expressions like "1 and condition" + if src_state_condition == "1" and current_edge_condition == "1": + new_bwd_edge_condition = "1" + elif src_state_condition == "1": + new_bwd_edge_condition = current_edge_condition + elif current_edge_condition == "1": + new_bwd_edge_condition = src_state_condition + else: + new_bwd_edge_condition = f"({src_state_condition}) and ({current_edge_condition})" + + bwd_edge = self._get_backward_state_edge(edge) + + # Add the condition to the edge + bwd_edge.data.condition = CodeBlock(new_bwd_edge_condition) + + # If there is a special case for the first iteration of the backward state + if forward_state in self.loop_states_view_map: + + # Get the corresponding edge between the loop states + bwd_loop_edge = self._get_backward_loop_state_edge(edge) + + # Add the same condition to the edge + bwd_loop_edge.data.condition = CodeBlock(new_bwd_edge_condition) + + # Add the forward condition to the list to update the conditions_map dict + if new_bwd_edge_condition != "1": + # Only add the condition if it exists + in_edges_conditions.append(new_bwd_edge_condition) + + # Update the conditions mapping + # This will be the logical or of all the saved conditions + # because we can reach this state by taking any of the incoming edges + if len(in_edges_conditions) == 0: + condition_for_state = "1" + else: + condition_for_state = in_edges_conditions[0] + for i in range(1, len(in_edges_conditions)): + condition_for_state += f" or {in_edges_conditions[i]}" + + # Since we are doing topological sort before iterating + conditions_map[node] = condition_for_state + + def _fill_interstate_edge_conditions(self) -> None: + """ + Go through all of the states in the forward graph and fill the necessary conditions in the backward states. + Each edge in the backward SDFG will be the logical AND between the equivalent edge in the forward SDFG and + all of the conditions that are necessary to get to this state in the forward pass. + """ + self._fill_interstate_edge_conditions_in_scope(self.sdfg) + + # Iterate through all the loop regions and connect the loop states if necessary + for loop in self.sdfg.all_control_flow_regions(): + # Only iterate over loop regions + if not isinstance(loop, LoopRegion): + continue + # Get the start state + loop_start_state = loop.start_block + if not isinstance(loop_start_state, SDFGState): + # This would be the case for perfectly nested loops + # Nothing to do in this case + continue + + if not loop_start_state in self.reversed_loop_states_map: + # There are no extra states to connect + continue + + # If there are loop states to connect + # Prepare the condition for the new state + loop_it = loop.loop_variable + reversed_loop = self.reversed_loops_map[loop] + start, _ = self._extract_loop_region_info(reversed_loop) + + # We only want the loop state to execute + # in the first iteration of the reversed loop + first_state_condition = f"{loop_it} == {start}" + first_state_condition = CodeBlock(first_state_condition) + + leftover_loop_state = self.reversed_loop_states_map[loop_start_state] + + # Get the reversed loop start state + reversed_loop_start_state = self.reversed_states_map[loop_start_state] + + # Add a state to the reversed loop region + new_start_state = reversed_loop.add_state_before(reversed_loop_start_state, + is_start_block=True, + condition=first_state_condition) + + # The condition for this interstate edge should be all iterations expect the fist + leftover_iterations_condition = f"not {first_state_condition.as_string}" + + # Add a connection between this new start state and the first iteration state + reversed_loop.add_edge(src=new_start_state, + dst=leftover_loop_state, + data=dace.InterstateEdge(condition=leftover_iterations_condition)) + + def _add_interstate_edge_assignments(self) -> None: + """ + We will need to add interstate assignments at the start of the backward SDFG + This is necessary to make sure the control flow in the backward pass is correctly preserved. + """ + # We will add an empty state to the backward pass which will have all the assignments + + new_assignments = {} + # Get all the interstate edges in the forward sdfg + for edge in self.sdfg.all_interstate_edges(): + if edge.data.assignments: + # There are assignments to be added to the start of the backward pass + new_assignments = {**new_assignments, **edge.data.assignments} + + # We need to check if any data needs to be used in these assignment + # This is important in the case of a NSDFG where data will need to be forwarded + for _, rhs in edge.data.assignments.items(): + # If any of the sdfg arrays are in the rhs assignment + assignment_arrays = [array for array in self.sdfg.arrays.keys() if array in rhs] + if assignment_arrays and self.separate_sdfgs: + # We need to forward this data to the backward pass + for array in assignment_arrays: + if array not in self.backward_input_arrays: + self.backward_input_arrays[array] = self.sdfg.arrays[array] + # Special case if this is a symbol that is doesn't have a descriptor yet + if array not in self.backward_sdfg.arrays: + # We add it now + self.backward_sdfg.add_datadesc(array, copy.deepcopy(self.sdfg.arrays[array])) + + if new_assignments: + # Add the new state to the backward pass + # First we get the start block of the backward pass + if self.separate_sdfgs: + bwd_start_block = self.backward_sdfg.start_block + else: + fwd_start_state = self.sdfg.start_block + if isinstance(fwd_start_state, LoopRegion): + bwd_start_block = self.reversed_loops_map[fwd_start_state] + elif isinstance(fwd_start_state, SDFGState): + bwd_start_block = self.reversed_states_map[fwd_start_state] + else: + raise AutoDiffException("Need to add an assignments state but can't find the start block") + # TODO would this work on a loop region? + self.backward_sdfg.add_state_before(state=bwd_start_block, + label="_bwd_interstate_assignments_state", + assignments=new_assignments) + + def is_within_map(self, state: SDFGState, node: nodes.AccessNode) -> bool: + # Get the scope dictionary for the state + scope_dict = state.scope_dict() + + # Check if the node is within the scope of a map + scope_entry = scope_dict.get(node, None) + while scope_entry is not None: + if isinstance(scope_entry, nodes.MapEntry): + return True + scope_entry = scope_dict.get(scope_entry, None) + + return False + + def _zero_out_gradient(self, forward_state: SDFGState, forward_node: nodes.AccessNode, memlet: Memlet) -> None: + """ + Zero out gradients for overwritten arrays in the forward pass. + + Overwritten arrays need their gradients zeroed for gradient accumulation + to work correctly. This method: + + 1. Copies current gradient values to a temporary array (for one last use + in the backward pass) + 2. Zeros out the overwritten access in the backward pass + 3. Updates the read mapping to use the temporary instead of the original + + The operation is skipped when possible to optimize performance. + + :param forward_state: The state in the forward pass containing the write. + :param forward_node: The access node being overwritten. + :param memlet: The memlet describing the write operation. + """ + # Extra checks to only do this if necessary + # If this access node is not written to in the forward pass except for this one time, we don't need to zero it out + # An exception is made for required gradients that can be read outside the scope of the SDFG + clear_out_gradients = forward_node.data in self.required_gradients_data + + # Get the write instances in the forward sdfg to this node that happen in states before the current state + # These will represent the reads that will happen after this AccessNode + # This should avoid unnecessary zeroing out of dace generated temporaries + for state in self.state_order[0:self.state_order.index(forward_state) + 1]: + state_view = self.states_view_map[state] + for node, parent in state_view.all_nodes_recursive(): + if isinstance(node, nodes.AccessNode) and node.data == forward_node.data: + if parent.in_degree(node) > 0: + # We need to check if the the forward node is inside a map scope or a LoopRegion + within_loop, _ = ad_utils.state_within_loop(state) + within_map = self.is_within_map(state, node) + if node != forward_node or (node == forward_node and (within_loop or within_map)): + clear_out_gradients = True + break + + # We can avoid clearing out the gradients + if not clear_out_gradients: + return + + # Get the backward state + backward_state: SDFGState = self.reversed_states_map[forward_state] + + # Get the backward node + backward_node: nodes.AccessNode = self.reverse_map[forward_node] + + # Get the original array + array_desc = self.backward_sdfg.arrays[backward_node.data] + + if dtypes.can_access(dtypes.ScheduleType.CPU_Multicore, array_desc.storage): + cuda = False + elif dtypes.can_access(dtypes.ScheduleType.GPU_Default, array_desc.storage): + cuda = True + else: + raise ValueError(f"Unsupported storage {array_desc.storage}") + + # Careful! The order of the ifs here matters since ArrayView is a subclass of Array + if isinstance(array_desc, dt.View): + # No need to initialize: the viewed array will always be visited + # (since a view can never be a required grad), and thus the viewed array will be initialized. + pass + elif isinstance(array_desc, (dt.Array, dt.Scalar)): + # Create a new memlet to write to the gradient arrays + map_exit_memlet = copy.deepcopy(memlet) + map_exit_memlet.data = backward_node.data + + # Create the tasklet to zero out only the section in the memlet + # First, Get the range that the zeroout map should iterate over + # TODO: We are looking at writes in the forward pass, + # We should take the dst_subset of the memlet + # Are there cases where dst_subset is None? + ranges = [] + for iteration in map_exit_memlet.dst_subset: + if isinstance(iteration, tuple): + # The end of the range is inclusive in the loop + # We add 1 to get the upper bound for the map + ranges.append((iteration[0], iteration[1] + 1)) + elif isinstance(iteration, sp.Number): + # This covers the case of a single element being written + ranges.append((int(iteration), int(iteration) + 1)) + else: + raise AutoDiffException(f"Unsupported subset type {type(iteration)} in memlet {memlet}") + + # Create the indices dict + indices = {f"i{i}": f"{start}:{end}" for i, (start, end) in enumerate(ranges)} + + # Create the tasklet memlet from the indices + tasklet_memlet = dace.Memlet.simple(backward_node.data, ", ".join(indices.keys())) + + # Create the tasklet + _, map_entry, map_exit = backward_state.add_mapped_tasklet( + "_clear_" + backward_node.data + "_", + indices, {}, + f"__out = 0", { + "__out": tasklet_memlet, + }, + schedule=dtypes.ScheduleType.GPU_Device if cuda else dtypes.ScheduleType.Default, + external_edges=True) + + # Get the edge from the map exit to the backward node + edge = backward_state.out_edges(map_exit)[0] + + # Get the cleared out AN + cleared_out_node = edge.dst + if not isinstance(cleared_out_node, nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode as cleared out node, got {type(cleared_out_node)}") + + # Create a copy of new memlet that will keep its other subset + # We want to copy the elements to their same indices in the new tmp array + # Create a new memlet that copies what memlet is writing to to the tmp + new_memlet_subset = memlet.subset if memlet.data == forward_node.data else memlet.other_subset + original_to_tmp_memlet = dace.Memlet(data=backward_node.data, + subset=new_memlet_subset, + other_subset=new_memlet_subset) + + # Remove the src_subset of the new memlet and replace the memlet in the edge + map_exit_memlet.subset = memlet.subset if memlet.data == forward_node.data else memlet.other_subset + map_exit_memlet.other_subset = None + edge.data = map_exit_memlet + + # Add an edge from the backward_node to the new map entry + backward_state.add_edge(backward_node, None, map_entry, None, dace.Memlet()) + + # A race will happen unless we make sure the data is being copied being it is zeroed out + # There is a read from the same array + # We need to add a transient that reads the content from forward pass before it is zeroed out + # Create a new array descriptor for the transient + transient_desc = copy.deepcopy(array_desc) + transient_desc.transient = True + + # Add the new array to the sdfg + transient_name = self.array_grad_name(forward_node.data) + "_tmp" + + # Check if the array is already in the backward sdfg + if transient_name not in self.backward_sdfg.arrays: + self.backward_sdfg.add_datadesc(transient_name, transient_desc) + + # Create an AcessNode for this transient and add it to backward state + transient_node = backward_state.add_read(transient_name) + + # Add a read from the backward node to the transient + backward_state.add_edge(backward_node, None, transient_node, None, original_to_tmp_memlet) + + # Add an empty edge from the transient to the map entry + backward_state.add_edge(transient_node, None, map_entry, None, dace.Memlet()) + if backward_node not in self.zeroed_out: + self.zeroed_out[backward_node] = [transient_node] + else: + self.zeroed_out[backward_node].append(transient_node) + else: + raise AutoDiffException("Unsupported data descriptor {}".format(array_desc)) + + def _remove_onnx_attribute_accessnodes(self, nodes_list: List[nodes.Node], state: SDFGState) -> None: + """Remove ONNX attribute AccessNodes that don't need gradient tracking. + + For some ONNX operators, nodes have attributes as input connectors even if the inputs are actually constant. + Examples of such attributes are `axis` and `keepdims` in `ReduceSum`. + Gradients for these attributes should not be tracked since they represent control flow and not data flow. + """ + attribute_to_remove = {"axis", "keepdims", "axes", "p", "dilations", "kernel_shape", "strides"} + for node in nodes_list[:]: # Iterate over a copy of the list to avoid modification issues + if isinstance(node, nodes.AccessNode): + out_edges = state.out_edges(node) + if out_edges and all( + ONNX_AVAILABLE and isinstance(edge.dst, ONNXOp) and edge.dst_conn in attribute_to_remove + for edge in out_edges): + nodes_list.remove(node) + + def _remove_maps_without_input_connectors(self, nodes_list: List[nodes.Node], state: SDFGState) -> None: + """Remove maps that don't have any input connectors from the nodes_list. + + These are maps that won't have an output in the backward pass and thus can be skipped from the reversal process. + Note that we do not remove the AccessNode that the no-input map writes to. + This is because we might need to zero out the gradient of this node. + If no zeroing out is necessary, the node will be removed in the reverse_subgraph function cleanup at the end. + """ + for node in nodes_list[:]: # Iterate over a copy of the list to avoid modification issues + if isinstance(node, nodes.MapEntry) and len(node.in_connectors) == 0: + nodes_list.remove(node) + # Remove the MapExit and everything in between + # Get the equivalent map exit for the map entry + map_exit = state.exit_node(node) + nodes_list.remove(map_exit) + + # Get all the nodes between the map entry and exit + for state_node in state.nodes(): + # Check the scope of the node if it is within the map + if state_node in state.scope_dict() and state.scope_dict( + )[state_node] == node and state_node in nodes_list: + nodes_list.remove(state_node) + + def _find_subgraph_to_differentiate(self) -> None: + """Determine which nodes we need to reverse; this forms the subgraph we will differentiate. + + We do a reverse BFS from the target output node. + In the case where a state is within a loop, this may result in different subgraphs + depending on the loop iteration. + + To calculate the gradients for a node x in ``required_gradients``, we need to sum up the gradient + contributions from every node y where x is used as an input. + """ + backward_nodes: set[nodes.Node] = set() + given_gradients_all_states = set(self.given_gradients_data) + + required_gradients_all_states = {n for n in self.required_gradients_data} + given_gradients_all_states = given_gradients_all_states | required_gradients_all_states + + # Do the backward BFS iteratively + for state in reversed(self.state_order): + state_given_gradients: List[nodes.AccessNode] = [] + + for node in state: + if isinstance(node, nodes.AccessNode) and node.data in given_gradients_all_states: + state_given_gradients.append(node) + + backward_nodes = {n for e in state.edge_bfs(state_given_gradients, reverse=True) for n in [e.src, e.dst]} + nodes_list = list(backward_nodes) + + # Clean up unwanted elements + self._remove_maps_without_input_connectors(nodes_list, state) + self._remove_onnx_attribute_accessnodes(nodes_list, state) + + state_subgraph = dstate.StateSubgraphView(state, nodes_list) + + state_subgraph = self._add_missing_nested_sdfg_connectors_to_view(state=state, + state_subgraph=state_subgraph, + view_nodes=nodes_list) + + # Add mapping + self.states_view_map[state] = state_subgraph + + # In the case where this state is within a for loop + within_loop, _ = ad_utils.state_within_loop(state) + if within_loop: + # Other elements that are not within state_subgraph will need to be reversed + # We create a separate mapping for these elements + + # Get all the access nodes that are used in the previous view + subgraph_an = [node.data for node in state_subgraph.nodes() if isinstance(node, nodes.AccessNode)] + + # For each access node in this view + for state_node in state: + if isinstance(state_node, nodes.AccessNode) and state_node.data in subgraph_an: + state_given_gradients.append(state_node) + + # Do reverse BFS starting from this new set of nodes + backward_nodes = { + n + for e in state.edge_bfs(state_given_gradients, reverse=True) + for n in [e.src, e.dst] + } + + view_nodes = list(backward_nodes) + self._remove_maps_without_input_connectors(nodes_list, state) + + loop_state_subgraph = dstate.StateSubgraphView(state, view_nodes) + + loop_state_subgraph = self._add_missing_nested_sdfg_connectors_to_view( + state=state, state_subgraph=loop_state_subgraph, view_nodes=view_nodes) + + # If the two views are different + # Here we only check if the number of nodes is the same + # Since states_view_map[state] is a subset of loop_states_view_map[state] + if len(state_subgraph) != len(loop_state_subgraph): + self.loop_states_view_map[state] = loop_state_subgraph + + # Update the list of given gradients to use for states + for node in backward_nodes: + if isinstance(node, nodes.AccessNode) and node.data not in given_gradients_all_states: + # We want all of the backward AccessNodes that made it to the intersection + given_gradients_all_states.add(node.data) + + def array_grad_name(self, forward_name: str) -> str: + """Return the gradient name of a name from the forward pass.""" + if forward_name not in self.array_grad_map: + self.array_grad_map[forward_name] = \ + self.backward_sdfg._find_new_name("gradient_" + forward_name) + + return self.array_grad_map[forward_name] + + def _add_gradient_data_descriptor(self, data_name: str) -> dt.Array: + """Add the data descriptor for the gradient for `data_name`. + + :param data_name: The name of the forward descriptor. + """ + grad_name = self.array_grad_name(data_name) + + if grad_name in self.backward_sdfg.arrays: + raise AutoDiffException(f"descriptor for gradient of {data_name} ({grad_name}) already exists") + + array = self.sdfg.arrays[data_name] + + if not isinstance(array, (dt.Scalar, dt.Array, dt.View)): + raise AutoDiffException("Unsupported data descriptor {}".format(array)) + + cloned_datadesc = copy.deepcopy(array) + + # only the grads of the inputs and the outputs are not transient + cloned_datadesc.transient = data_name not in self.input_names and data_name not in self.output_names + + # Store references + self.backward_grad_arrays[grad_name] = cloned_datadesc + self.backward_sdfg.arrays[grad_name] = cloned_datadesc + return cloned_datadesc + + def _reverse_loop_conditional(self, loop: LoopRegion) -> str: + """Given a loop region as a parameter, create the conditional for the reversed version of this loop.""" + + # Get the loop iterator + it = loop.loop_variable + + # Get the loop start + start, _ = ad_utils.extract_loop_region_info(loop) + + # Get the stride sign + stride_sign = ad_utils.get_stride_sign(loop) + + # Reverse the conditional to end at the start of the original loop + # This will be incremented or decremented depending on the stride + if stride_sign > 0: + reversed_condition = f"{it} > {start}-1" + else: + reversed_condition = f"{it} < {start}+1" + + return reversed_condition + + def _reverse_loop_initial_statement(self, loop: LoopRegion) -> str: + """Given a loop region as a parameter, create the initialization statement for the reversed version of this loop.""" + # Get the loop iterator + it = loop.loop_variable + + stride_sign = ad_utils.get_stride_sign(loop) + + # Get the loop end + _, end = ad_utils.extract_loop_region_info(loop) + + # Reverse the initialization to start from the end of the forward loop + # This will be incremented or decremented depending on the stride + if stride_sign > 0: + init_expr = f"{it} = {end}-1" + else: + init_expr = f"{it} = {end}+1" + + return init_expr + + def _reverse_loop_update_statement(self, loop: LoopRegion) -> str: + """Given a loop region as a parameter, create the update statement for the reversed version of this loop.""" + + # Get the original update statement + fwd_update = loop.update_statement.as_string + + stride_sign = ad_utils.get_stride_sign(loop) + + # If the stride is positive + if stride_sign > 0: + update_statement = fwd_update.replace("+", "-") + else: + # If the stride is negative + update_statement = fwd_update.replace("-", "+") + + return update_statement + + def _match_loop_region(self, fwd_loop: LoopRegion) -> LoopRegion: + """Create the backward LoopRegion and fill it with the reversal of the forward LoopRegion.""" + + init_expr = self._reverse_loop_initial_statement(fwd_loop) + reversed_condition = self._reverse_loop_conditional(fwd_loop) + update_statement = self._reverse_loop_update_statement(fwd_loop) + + # Create the label + reversed_label = f"{fwd_loop.label}_reversed" + + # Create the loop object and return it + reversed_loop = LoopRegion(label=reversed_label, + initialize_expr=init_expr, + condition_expr=reversed_condition, + update_expr=update_statement, + loop_var=fwd_loop.loop_variable) + + return reversed_loop + + def _reverse_loop_region(self, loop: LoopRegion): + """Given a LoopRegion as a parameter, reverse it, add the loop states that belong in this region.""" + + # Create the reversed loop region + reversed_loop = self._match_loop_region(fwd_loop=loop) + self.reversed_loops_map[loop] = reversed_loop + + # Add the reversed loop directly + parent_graph = self._get_reversed_parent_graph(loop) + parent_graph.add_node(reversed_loop) + + # Add all the loop nodes to the graph and recursivly reverse child loop regions + for node in loop.nodes(): + if isinstance(node, LoopRegion): + + # This node shouldn't be reversed already since we're going top-down + if node in self.reversed_loops_map: + raise AutoDiffException(f"Loop {node} has already been reversed") + self._reverse_loop_region(node) + elif isinstance(node, SDFGState): + + # Get the backward_node + bwd_node = self.reversed_states_map[node] + + # Remove from the backward SDFG + self.backward_sdfg.remove_node(bwd_node) + + # Add it to the loop region + reversed_loop.add_node(bwd_node) + + # Also add loop states if any + if node in self.reversed_loop_states_map: + # Get the backward_node + bwd_node = self.reversed_loop_states_map[node] + + def _add_missing_nested_sdfg_connectors_to_view(self, state: SDFGState, state_subgraph: dstate.StateSubgraphView, + view_nodes: List[nodes.Node]): + """Add missing NestedSDFG connectors to the view for correctness. + + There is a special case for NestedSDFGs that we need to fix + in the case where a NestedSDFG has an inout connector, + but we only care about one of those connectors for the sake of AD. + We need to add the missing connector for correctness. + TODO: This is only a problem if the said connector is written to + inside the NestedSDFG. + """ + # In the case where a NestedSDFG has an inout connector, + # but we only care about one of those connectors for the sake of AD + # we need to add the missing connector for correctness + # TODO: this is only a problem if the said connector is written to + # inside the NestedSDFG + # Iterate over the nested SDFGs in the view + for g in state_subgraph.nodes(): + if isinstance(g, nodes.NestedSDFG): + + inout_connectors = set(g.in_connectors).intersection(set(g.out_connectors)) + # If there are any inout connectors + if len(inout_connectors) > 0: + out_connectors = {edge.src_conn: edge for edge in state.out_edges(g)} + in_connectors = {edge.dst_conn: edge for edge in state.in_edges(g)} + view_out_connectors = {edge.src_conn: edge for edge in state_subgraph.out_edges(g)} + view_in_connectors = {edge.dst_conn: edge for edge in state_subgraph.in_edges(g)} + for con in inout_connectors: + # Check if it is missing in the out or in connectors of the view + if con in view_out_connectors and con not in view_in_connectors: + # Get the equivalent in node and connector + edge = in_connectors[con] + if not isinstance(edge.src, nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode as source, got {type(edge.src)}") + view_nodes.append(edge.src) + if con not in view_out_connectors and con in view_in_connectors: + # Add the corresponding edge to the view + edge = out_connectors[con] + if not isinstance(edge.dst, nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode as destination, got {type(edge.dst)}") + view_nodes.append(edge.dst) + + return dstate.StateSubgraphView(state, view_nodes) + + def _compare_memlet_accesses_to_array_size(self, data_name: str, memlet: Memlet) -> int: + """Compare the memlet range with the size of the array to see if the array is being overwritten.""" + total_size = self.backward_sdfg.arrays[data_name].total_size + try: + if total_size > memlet.num_accesses: + return 1 + elif memlet.num_accesses == total_size: + return 0 + + # Something is wrong here raise an exception + raise AutoDiffException(f"Memlet {memlet} has more accesses than the size of the data {data_name}") + + # If the comparison can not be made, return None + except TypeError: + return None + + def _get_reversed_parent_graph(self, forward_node: nodes.Node): + """Given a node in the SDFG, get the reversed parent of this node.""" + fwd_parent_graph = forward_node.parent_graph + + if fwd_parent_graph == self.sdfg: + parent_graph = self.backward_sdfg + elif isinstance(fwd_parent_graph, SDFGState): + parent_graph = self.reversed_states_map[fwd_parent_graph] + elif isinstance(fwd_parent_graph, LoopRegion): + parent_graph = self.reversed_loops_map[fwd_parent_graph] + + return parent_graph + + def _get_backward_loop_state_edge(self, forward_edge: dace.InterstateEdge) -> dace.InterstateEdge: + """Given an edge from the forward pass, return the equivalent edge in the backward pass.""" + # Get the source and destination states + forward_src = forward_edge.src + forward_dst = forward_edge.dst + + if isinstance(forward_src, LoopRegion): + fwd_src_is_loop = True + if forward_src not in self.reversed_loops_map: + raise AutoDiffException(f"Forward loop {forward_src} not found in reversed loops map") + else: + fwd_src_is_loop = False + if forward_src not in self.reversed_states_map: + raise AutoDiffException(f"Forward state {forward_src} not found in reversed states map") + + if isinstance(forward_dst, LoopRegion): + fwd_dst_is_loop = True + if forward_dst not in self.reversed_loops_map: + raise AutoDiffException(f"Forward loop {forward_dst} not found in reversed loops map") + else: + fwd_dst_is_loop = False + if forward_dst not in self.reversed_states_map: + raise AutoDiffException(f"Forward state {forward_dst} not found in reversed states map") + + # Note that the source will become the destination + backward_dst = self.reversed_states_map[forward_src] if not fwd_src_is_loop else self.reversed_loops_map[ + forward_src] + backward_src = self.reversed_states_map[forward_dst] if not fwd_dst_is_loop else self.reversed_loops_map[ + forward_dst] + + # Each one of these in edges needs to have an equivalent + # out edge in the backward part of the SDFG + bwd_edge = None + connection_state = backward_dst + + # Find the equivalent edge in the backward SDFG + for b_edge in connection_state.parent_graph.in_edges(connection_state): + if b_edge.src == backward_src: + bwd_edge = b_edge + break + + if not bwd_edge: + raise AutoDiffException(f"Can't find the equivalent edge of {forward_edge} in the backward pass") + + return bwd_edge + + def _get_backward_state_edge(self, forward_edge: dace.InterstateEdge) -> dace.InterstateEdge: + """Given an edge from the forward pass, return the equivalent edge in the backward pass.""" + # Get the source and destination states + forward_state_src = forward_edge.src + forward_state_dst = forward_edge.dst + + # Get the equivalent states in the backward pass + if (forward_state_src not in self.reversed_states_map and forward_state_src not in self.reversed_loops_map): + raise AutoDiffException(f"Forward state source {forward_state_src} not found in reversed maps") + if (forward_state_dst not in self.reversed_states_map and forward_state_src not in self.reversed_loops_map): + raise AutoDiffException(f"Forward state destination {forward_state_dst} not found in reversed maps") + + # Note that the src will become the destination + backward_state_dst = self.reversed_states_map[ + forward_state_src] if forward_state_src in self.reversed_states_map else self.reversed_loops_map[ + forward_state_src] + backward_state_src = self.reversed_states_map[ + forward_state_dst] if forward_state_dst in self.reversed_states_map else self.reversed_loops_map[ + forward_state_dst] + + # Each one of these in edges needs to have an equivalent + # out edge in the backward part of the SDFG + bwd_edge = None + connection_state = backward_state_dst + + # Find the equivalent edge in the backward SDFG + for b_edge in connection_state.parent_graph.in_edges(connection_state): + if b_edge.src == backward_state_src: + bwd_edge = b_edge + break + + if not bwd_edge: + raise AutoDiffException(f"Can't find the equivalent edge of {forward_edge} in the backward pass") + + return bwd_edge + + def _str_to_access(self, data: str, source: str) -> nodes.AccessNode: + """Given a string containing the name of the accessed array, return the AccessNode in the state. + + Given a string containing the name of the accessed array, return the AccessNode in the state + that points to this array. + If there are multiple AccessNodes, the behavior will depend on whether we want + an output or input AccessNode. + Input: We will return the first occurrence of this node in the state and make sure there are + only outgoing edges from this node. + Output: We will return the last occurrence of this node in the state + where the node only has incoming edges. + """ + matches = [(node, state) for state in self.sdfg.states() for node in state.nodes() + if isinstance(node, nodes.AccessNode) and node.data == data] + # Unused in model + if len(matches) == 0: + return None + + # there is only a single AccessNode with this name + if len(matches) == 1: + return matches[0][0] + + # len(matches) > 1 + else: + # There are multiple occurrences of the same AccessNode + if source == "inputs": + # We return the first node with this data + input_node: nodes.AccessNode = matches[0][0] + return input_node + + if source == "outputs": + # Go through the list of matches in reverse + for output_node, output_node_state in reversed(matches): + # We want the first node that has at least one incoming edge to it + # This represents the last time the output data was modified + in_edges = output_node_state.in_edges(output_node) + if len(in_edges) > 0: + return output_node + + raise AutoDiffException( + f"The specified output {data} was not written to by any AccessNode in this state") + + raise AutoDiffException(f"There are multiple nodes with data {data} " + f" but the source (inputs or outputs) was not specified correctly") + + def _expand_nodes(self) -> bool: + """Expand all library nodes in the sdfg to pure implementations. + + Returns whether something was expanded. + """ + expanded_something = False + for state_view in self.states_view_map.values(): + for node, parent_graph in state_view.all_nodes_recursive(): + if isinstance(parent_graph, dstate.StateSubgraphView): + parent_graph = parent_graph.graph + + # Check if the node exists in the backward implementation repository + if find_backward_implementation(parent_graph.parent_graph, parent_graph, node) is not None: + continue + + # Only check others if we didn't break out of the above loop + if ONNX_AVAILABLE and isinstance(node, ONNXOp): + impls = ONNXForward.registered_implementations(node.schema.name) + + # Order the implementations so that implementations containing "pure" are tried first + impls = [i for name, i in impls if "pure" in name] + [i for name, i in impls if "pure" not in name] + for impl in impls: + if impl.forward_can_be_applied(node, parent_graph, self.sdfg): + # Configure the module-level expansion class + ExpansionTemplate.environments = impl.environments if hasattr(impl, "environments") else [] + ExpansionTemplate._impl = impl + ExpansionTemplate._match_node = xf.PatternNode(type(node)) + ExpansionTemplate.apply_to(parent_graph.parent, verify=False, _match_node=node) + expanded_something = True + break + + # This could later on be changed to check if the expansion is differentiable and if not, move + # on to the next expansion. For now we will just apply the first one that matches, prioritizing ones that + # have "pure" in the name + if isinstance(node, nodes.LibraryNode) and not (ONNX_AVAILABLE and isinstance(node, ONNXOp)): + # Try to select an expansion + if hasattr(node, "implementations"): + implementations = node.implementations + + pure_candidates = [name for name, _ in sorted(implementations.items()) if "pure" in name] + if len(pure_candidates) > 0: + expansion = pure_candidates[0] + else: + expansion = node.implementation + else: + expansion = node.implementation + + node.implementation = expansion + node.expand(parent_graph.parent, parent_graph) + expanded_something = True + + return expanded_something + + def _get_node_state(self, node: nodes.Node) -> SDFGState: + """Return the SDFG state that contains this node.""" + matches = [] + for state in self.sdfg.states(): + if node in state.nodes(): + matches.append(state) + + if len(matches) != 1: + raise AutoDiffException(f"Expected exactly one match, got {len(matches)}") + return matches[0] + + def _connect_conditional_map_exist(self, forward_state: SDFGState, backward_state: SDFGState, + backward_map_exit: nodes.MapExit, fwd_tasklet: nodes.Tasklet): + """Connect the map exit of a conditional tasklet to a new access node which will zero out the gradient. + """ + + if len(backward_map_exit.in_connectors) != 0: + raise AutoDiffException( + f"Expected no input connectors on backward map exit, got {len(backward_map_exit.in_connectors)}") + + # Add the in and out connectors for the zero-out operation + backward_map_exit.add_in_connector("IN_zero_out") + backward_map_exit.add_out_connector("OUT_zero_out") + + # Get the memlet data for the edge from the tasklet to the map exist + tasklet_out_edge = forward_state.out_edges(fwd_tasklet) + if len(tasklet_out_edge) != 1: + raise AutoDiffException(f"Expected exactly one tasklet output edge, got {len(tasklet_out_edge)}") + tasklet_out_edge = tasklet_out_edge[0] + tasklet_memlet_path = forward_state.memlet_path(tasklet_out_edge) + if len(tasklet_memlet_path) != 2: + raise AutoDiffException(f"Expected tasklet memlet path of length 2, got {len(tasklet_memlet_path)}") + + # Copy the memlet and change the data name + memlet_data = copy.deepcopy(tasklet_memlet_path[0].data) + memlet_data.data = self.array_grad_map[memlet_data.data] + + # Get the reversed tasklet + bwd_tasklet = self.reverse_map[fwd_tasklet] + + # Connect this map exist to the tasklet + backward_state.add_edge(bwd_tasklet, "__zero_out_conn__", backward_map_exit, "IN_zero_out", memlet_data) + + # Replicate the target accedd node and connect it + fwd_target_an: nodes.AccessNode = tasklet_memlet_path[-1].dst + if not isinstance(fwd_target_an, nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode for forward target, got {type(fwd_target_an)}") + if fwd_target_an not in self.reverse_map: + raise AutoDiffException(f"Forward target AccessNode {fwd_target_an} not found in reverse map") + bwd_target_an = self.reverse_map[fwd_target_an] + + replicated_bwd_target_an = copy.deepcopy(bwd_target_an) + backward_state.add_node(replicated_bwd_target_an) + + an_memlet_data: nodes.AccessNode = copy.deepcopy(tasklet_memlet_path[1].data) + an_memlet_data.data = self.array_grad_map[an_memlet_data.data] + backward_state.add_edge(backward_map_exit, "OUT_zero_out", replicated_bwd_target_an, None, an_memlet_data) + + # We need to get the map entry that starts the conditional block + # First get the conditional tasklet + conditional_block = self._extract_conditional_array_assignment_block( + forward_state=forward_state, tasklet_node=fwd_tasklet, subgraph=self.states_view_map[forward_state]) + # Get the map entry of the conditional bloc + map_entries = [n for n in conditional_block if isinstance(n, nodes.MapEntry)] + + if len(map_entries) != 1: + raise AutoDiffException( + f"Expected a single MapEntry node in the conditional block, found {len(map_entries)}") + else: + map_entry = map_entries[0] + + # Add the new access node to a dictionary in case it needs to be connected + self.conditional_block_entry[map_entry] = replicated_bwd_target_an + + def _conditional_tasklet(self, tasklet_node: nodes.Tasklet): + """Check if this tasklet contains a conditional. + + This only happens in conditional array assignments and requires special treatment in reversing the graph. + """ + # sanity check + if not isinstance(tasklet_node, nodes.Tasklet): + raise AutoDiffException(f"Expected Tasklet node, got {type(tasklet_node)}") + + # get the code string and check if there is an if + # TODO: How to more accurately check this? + return "if" in tasklet_node.code.as_string + + def _conditional_nested_sdfg(self, forward_state: SDFGState, node: nodes.NestedSDFG): + """Check if this NestedSDFG contains a conditional. + + This only happens in conditional array assignments and requires special treatment in reversing the graph. + """ + # sanity check + if not isinstance(node, nodes.NestedSDFG): + raise AutoDiffException(f"Expected NestedSDFG node, got {type(node)}") + + # get the incoming edges to the sdfg + in_edges = forward_state.in_edges(node) + + # check if any of the incoming edges are boolean edges + for edge in in_edges: + if self.sdfg.arrays[edge.data.data].dtype == dace.bool: + return True + + # get the code string and check if there is an if + return False + + def _extract_conditional_array_assignment_block(self, forward_state: SDFGState, tasklet_node: nodes.Node, + subgraph: dstate.SubgraphView): + """Extract a conditional array assignment block. + + Given a conditional tasklet, check if this is a conditional array assignment of the type + A[A>=0 and A<=5] = cst. At the moment the function only supports constant assignments. + """ + try: + + if not isinstance(tasklet_node, nodes.Tasklet): + raise AutoDiffException(f"Expected Tasklet node, got {type(tasklet_node)}") + # This applies to both Tasklet and NestedSDFG nodes + # get the AccessNode containing the boolean values for this assignment + tasklet_in_edges = forward_state.in_edges(tasklet_node) + tasklet_boolean_edge = None + single_boolean_edge_found = False + for edge in tasklet_in_edges: + edge_type = self.sdfg.arrays[edge.data.data].dtype + if edge_type == dace.bool: + # sanity check + if single_boolean_edge_found: + # we expect there to be a single AccessNode where the booleans come from + raise AutoDiffException( + "Multiple boolean edges found for conditional assignment. Expected only one.") + tasklet_boolean_edge = edge + single_boolean_edge_found = True + + if tasklet_boolean_edge is None: + raise AutoDiffException("Expected to find a boolean edge for conditional assignment") + tasklet_in_memlet_path = forward_state.memlet_path(tasklet_boolean_edge) + # the first element in the path is the boolean AN + bools_an = tasklet_in_memlet_path[0].src + if not isinstance(bools_an, nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode for boolean values, got {type(bools_an)}") + + # save all the nodes in the path to the assignment block list + conditional_assingement_block_nodes = { + n + for e in forward_state.edge_bfs(bools_an, reverse=True) + for n in [e.src, e.dst] + } + + # if any of the nodes in the block are required for gradient tracking + nodes_to_keep_tracking: set[nodes.Node] = self._get_gradient_nodes_to_track( + forward_state=forward_state, block_nodes=conditional_assingement_block_nodes, subgraph=subgraph) + for node in nodes_to_keep_tracking: + # we get the reverse bfs of this node and remove it from block nodes to avoid skipping these nodes + node_subgraph = {n for e in forward_state.edge_bfs(node, reverse=True) for n in [e.src, e.dst]} + + # add the node itself + node_subgraph.add(node) + conditional_assingement_block_nodes = conditional_assingement_block_nodes.difference(node_subgraph) + + except Exception as e: + # if this is not the structure we are expecting, fail + raise AutoDiffException(f"The boolean datatype in edges is limited to conditional array assingements." + f" This stucture is not supported.") from e + + return conditional_assingement_block_nodes + + def _get_gradient_nodes_to_track(self, forward_state: SDFGState, block_nodes: List[nodes.Node], + subgraph: dstate.SubgraphView): + """Get gradient nodes that need tracking in conditional assignments. + + When extracting the block for a conditional assignment, we need to make sure we keep tracking + the required gradient AccessNodes. + This function checks all the required access nodes that are in the conditional block. + At the moment this is just the target access node. + """ + nodes_to_track: List[nodes.AccessNode] = [] + gradient_nodes = [n for n in self.required_gradients_data] + gradient_nodes += [n for n in self.given_gradients_data] + + # get the subgraph difference + difference = set(subgraph.nodes()).difference(set(block_nodes)) + + # go through all the access nodes in the conditional block + for node in block_nodes: + if not isinstance(node, nodes.AccessNode): + continue + + # we always want to track the gradient nodes + if node.data in gradient_nodes: + nodes_to_track.append(node) + continue + # if this access node has multiple edges and any of them are outside the block + + node_out_edges = forward_state.out_edges(node) + if len(node_out_edges) > 1: + for edge in node_out_edges: + if edge.dst in difference: + nodes_to_track.append(node) + data = node.data + + # search for this array in the graph difference + for d_node in difference: + if not isinstance(d_node, nodes.AccessNode): + continue + if d_node.data == data: + nodes_to_track.append(node) + return nodes_to_track + + def _reverse_subgraph(self, forward_state: SDFGState, backward_state: SDFGState, + subgraph: dstate.StateSubgraphView) -> None: + """Reverse a given subgraph by reversing all nodes within it. + + :param forward_state: The forward state containing the subgraph. + :param backward_state: The backward state to add reversed nodes to. + :param subgraph: The subgraph view containing nodes to reverse. + """ + + # Conditional assignment nodes + conditional_assignment_nodes: List[nodes.Node] = [] + + # A reversed topological sort is a topological sort on the reverse graph + for node in reversed(list(dace_utils.dfs_topological_sort(subgraph, subgraph.source_nodes()))): + + try: + # If this node is a part of the conditional assignment block, we skip it + if node in conditional_assignment_nodes: + continue + + # Output names on the forward node + # (for which the gradient will be connected as an input on the reverse node) + given_gradients = [ + edge.src_conn for edge in subgraph.out_edges(node) + if ad_utils.path_src_node_in_subgraph(edge, subgraph) + ] + + # Input names on the forward node that gradients should be generated for + # note that the edge for the conditional is not included + required_gradients = [ + edge.dst_conn for edge in subgraph.in_edges(node) + if ad_utils.path_src_node_in_subgraph(edge, subgraph) + and self.sdfg.arrays[edge.data.data].dtype != dace.bool + ] + + reversed_node, backward_result = self._get_reverse_node(forward_state, backward_state, node, + given_gradients, required_gradients) + + self.reverse_map[node] = reversed_node + self.result_map[node] = backward_result + + # Connect the required inputs of the reverse node: + # the gradients ... + self._connect_given_gradients(forward_state=forward_state, + backward_state=backward_state, + subgraph=subgraph, + forward_node=node) + + # ... and any required input values from the forward pass + #################################### + # Determine which forward inputs we need to connect. + # these are the in_connectors on the reverse node, minus what has already been connected. + already_connected = {e.dst_conn for e in backward_state.in_edges(reversed_node)} + required_inputs = set(reversed_node.in_connectors).difference(already_connected) + required_inputs = {c: c for c in required_inputs} + self._connect_forward_inputs(forward_state, backward_state, node, reversed_node, required_inputs) + + if isinstance(node, nodes.AccessNode): + + # this means we are writing out a grad to an array. + # initialize the gradient if it hasn't been initialized already (this can also happen in + # _connect_given_gradients + array_grad_name = self.array_grad_name(node.data) + if array_grad_name not in self.backward_sdfg.arrays: + # this grad hasn't been written before: initialize it + self._add_gradient_data_descriptor(node.data) + + # we need to make all incoming gradients sum + if backward_state.in_degree(reversed_node) > 1: + + # Add a wcr to all the writes to the AccessNode + for edge in backward_state.in_edges(reversed_node): + # Add wcr to the memlet + for tree_edge in backward_state.memlet_tree(edge): + tree_edge.data.wcr = "lambda x, y: x + y" + + # If this node is a tasklet with a condition, we add some modification to the backward state + elif (isinstance(node, nodes.Tasklet) + and self._conditional_tasklet(node)) or (isinstance(node, nodes.NestedSDFG) + and self._conditional_nested_sdfg(forward_state, node)): + # extract the conditional assignment block or fail if this is an unexpected structure + conditional_block = self._extract_conditional_array_assignment_block(forward_state=forward_state, + tasklet_node=node, + subgraph=subgraph) + + # add these nodes to be skipped in the future + conditional_assignment_nodes.extend(conditional_block) + + # If the node is an AccessNode and it is being overwritten in the forward pass, + # we need to zero-out the gradients of the overwritten values + if isinstance(node, nodes.AccessNode): + # Check if there is an incoming edge to this node + incoming_edges = forward_state.in_edges(node) + + # If there is an incoming edge, we need to zero-out the gradient + for edge in incoming_edges: + + # Check, if possible, if the written subset is not zero + write_size = edge.data.subset.num_elements() + + # Check if the node doesn't have a WCR + # If it does, this is not an overwrite and the gradients should not be cleared + has_wcr = edge.data.wcr is not None + + # Check if the edge is dynamic, this means not all values are overwritten + # We will skip zeroing out the gradient in this case + if edge.data.dynamic: + Warning("Dynamic memlets are not fully supported in the reverse pass. " + "The gradient of the overwritten values may not be zeroed out.") + if not has_wcr and not edge.data.dynamic: + # Determine if we need to zero out the gradient + zero_out = not (isinstance(write_size, int) and write_size == 0) + + # We need to zero out the same memlet accesses in the backward pass + if zero_out: + self._zero_out_gradient(forward_state=forward_state, + forward_node=node, + memlet=edge.data) + + # Cleanup of isolated nodes + # We will have an isolated node if it is not connected to any other node in the state view + # And it has not been cleared out if it is an AccessNode + # Isolated nodes should only appear from clearing out gradients + # Check if this is an isolated node and remove it if it is + if backward_state.out_degree(reversed_node) == 0 and backward_state.in_degree(reversed_node) == 0: + if isinstance(node, nodes.AccessNode) and node not in self.zeroed_out: + backward_state.remove_node(reversed_node) + + except AutoDiffException as e: + raise AutoDiffException("Failed at node {}: {}".format(node, str(e))) from e + + def _set_wcr_if_needed(self, backward_state: SDFGState, backward_node: nodes.Node, + edge: dstate.MultiConnectorEdge) -> None: + """Set write-conflict resolution (WCR) for gradient accumulation if needed. + + If this AccessNode represents a gradient that has already been used elsewhere, + we want to accumulate the gradients rather than overwrite them. + + :param backward_state: The backward state containing the edge. + :param backward_node: The backward node (should be AccessNode for gradients). + :param edge: The edge that may need WCR for gradient accumulation. + """ + + # Check if the forward node is an AccessNode + if not isinstance(backward_node, nodes.AccessNode): + return + + # Otherwise, we add up the gradients, not overwrite them + for tree_edge in backward_state.memlet_tree(edge): + tree_edge.data.wcr = "lambda x, y: x + y" + + def _connect_given_gradients(self, forward_state: SDFGState, backward_state: SDFGState, + subgraph: dstate.StateSubgraphView, forward_node: nodes.Node) -> Optional[SDFGState]: + """Connect output gradients of forward_node as inputs to the corresponding reverse node. + + :param forward_state: The forward state containing the node. + :param backward_state: The backward state to add connections to. + :param subgraph: The subgraph view for the current operation. + :param forward_node: The forward node whose output gradients to connect. + :return: The backward state (possibly modified) or None. + """ + new_backward_state = None + # First, create the data descriptor if this is an access node and it hasn't been added before + if isinstance(forward_node, nodes.AccessNode): + grad_name = self.array_grad_name(forward_node.data) + if grad_name not in self.backward_sdfg.arrays: + # This grad hasn't been written before: initialize it + self._add_gradient_data_descriptor(forward_node.data) + + for edge in subgraph.out_edges(forward_node): + if not ad_utils.path_src_node_in_subgraph(edge, subgraph) or edge.dst not in self.reverse_map: + if edge.dst in self.conditional_block_entry: + backward_node = self.reverse_map[edge.src] + if not isinstance(edge.dst, nodes.MapEntry): + raise AutoDiffException(f"Expected MapEntry in conditional block, got {type(edge.dst)}") + conditional_zero_out_an = self.conditional_block_entry[edge.dst] + # Add an empty edge to skip the conditional block + backward_state.add_edge(conditional_zero_out_an, None, backward_node, None, Memlet()) + # skip connecting edges for which we don't need to generate grads. + continue + + # Skip connecting boolean edges + if self.sdfg.arrays[edge.data.data].dtype == dace.bool: + # we also need to remove this connector otherwise it will be dangling + backward_node = self.reverse_map[edge.src] + if not (isinstance(backward_node, nodes.MapEntry) or isinstance(backward_node, nodes.MapExit)): + # If this is not a map entry or exit, the boolean gradients will not be added + # No need to remove the connector in this case + continue + + conn_to_remove = ad_utils.invert_map_connector(edge.src_conn) + assert conn_to_remove in backward_node.in_connectors + assert backward_node.remove_in_connector(conn_to_remove) + if len(backward_node.in_connectors) == 0: + self._connect_conditional_map_exist(forward_state=forward_state, + backward_state=backward_state, + backward_map_exit=backward_node, + fwd_tasklet=edge.dst) + continue + + _, output_conn, dest_node, input_conn, fwd_memlet = edge + + memlet = copy.deepcopy(fwd_memlet) + + # Remove the WCR since these are now read edges + memlet.wcr = None + + grad_name = self.array_grad_name(memlet.data) + if grad_name not in self.backward_sdfg.arrays: + # This grad hasn't been written before: initialize it + self._add_gradient_data_descriptor(memlet.data) + + # We should not rely on the memlet data because that depends on the subset and other subset attibutes + # If this is an access node, and the memlet data is not the same as the AN data + memlet.data = grad_name + + # Check of the values have been zeroed out + backward_dst_node = self.reverse_map[dest_node] + if backward_dst_node in self.zeroed_out: + # The values will be zeroed out in the backward node + # We use the transient array instead + copied_zeroed_nodes = self.zeroed_out[backward_dst_node] + if len(copied_zeroed_nodes) == 1: + backward_dst_node = copied_zeroed_nodes[0] + else: + for node in copied_zeroed_nodes: + # Get the memlet to this node + zero_in_dege = backward_state.in_edges(node) + assert len(zero_in_dege) == 1 + zeroed_memlet = zero_in_dege[0].data + if zeroed_memlet.subset == edge.data.subset: + backward_dst_node = node + break + + memlet.data = backward_dst_node.data + + # We also need to Add an empty edge from the cleared node to where the data will be used + tmp_clear_node_out_edges = backward_state.out_edges(backward_dst_node) + for e in tmp_clear_node_out_edges: + if e.data.data is None and e.data.subset is None and e.data.other_subset is None: + clearing_map_entry = e.dst + assert isinstance(clearing_map_entry, nodes.MapEntry) + clearing_map_exit = backward_state.exit_node(clearing_map_entry) + assert isinstance(clearing_map_exit, nodes.MapExit) + # Check that this only has a single output edge and get the destination + assert backward_state.out_degree(clearing_map_exit) == 1 + cleared_out_node = backward_state.out_edges(clearing_map_exit)[0].dst + backward_node = self.reverse_map[forward_node] + backward_state.add_edge(cleared_out_node, None, backward_node, None, dace.Memlet()) + + # If this is a connection between two access nodes we need to flip the memlet subsets + if isinstance(forward_node, nodes.AccessNode): + # Special case for when the two access nodes are the same + if forward_node.data == dest_node.data and fwd_memlet.other_subset is not None: + new_memlet = dace.Memlet(data=self.reverse_map[forward_node].data, + subset=fwd_memlet.other_subset, + other_subset=fwd_memlet.subset) + else: + new_memlet = dace.Memlet(data=self.reverse_map[forward_node].data, + subset=fwd_memlet.subset + if fwd_memlet.data == forward_node.data else fwd_memlet.other_subset, + other_subset=fwd_memlet.other_subset + if fwd_memlet.data == forward_node.data else fwd_memlet.subset) + memlet = new_memlet + if input_conn not in self.result_map[dest_node].required_grad_names: + continue + new_edge = backward_state.add_edge( + backward_dst_node, + self._lookup_required_grad_name(dest_node, input_conn), + self.reverse_map[forward_node], + self._lookup_given_grad_name(forward_node, output_conn), + memlet, + ) + + # Change the access data in the memlet path if it has been zeroed out + # Calling the memlet path while reversing will raise an error + # Because the map has not been completely added for the backward state yet + # We also don't need to do anything for an AccessNode -> AccessNode connection + if (not isinstance(forward_node, + (nodes.MapExit, nodes.MapEntry))) and not (isinstance(forward_node, nodes.AccessNode) + and isinstance(dest_node, nodes.AccessNode)): + # Check if we can call the memlet path on new_edge safely + path = backward_state.memlet_path(new_edge) + + # Get the source access node in the path + source_access_node = list(path)[0].src + if isinstance(source_access_node, nodes.AccessNode): + # Check if this is a zeroed out node + in_values = any(source_access_node in values for values in self.zeroed_out.values()) + if source_access_node.data != memlet.data and in_values: + memlet.data = source_access_node.data + self._set_wcr_if_needed(backward_state=backward_state, + backward_node=self.reverse_map[forward_node], + edge=new_edge) + + return new_backward_state + + def _get_accessnode_to_forward(self, forward_state: SDFGState, forward_node: nodes.AccessNode): + """ + Check if this AccessNode is at the base level of the state. If yes, this is the node we want to connect + Otherwise, in the case the AN is encolsed by maps, we walk up the maps until we find the source AN. + """ + scope_dict = forward_state.scope_dict()[forward_node] + is_base_level = scope_dict is None + if is_base_level: + return forward_node + else: + # The node is within a map nest + # It should have an in edge leading to the original AN + in_edges = forward_state.in_edges(forward_node) + assert len(in_edges) == 1 + + # Get the memlet path and the original AN + memlet_path = forward_state.memlet_path(in_edges[0]) + original_an = memlet_path[0] + assert isinstance(original_an, nodes.AccessNode) + + # This should be a base level AN + assert forward_state.scope_dict()[original_an] is None + return original_an + + def _connect_forward_inputs(self, state: SDFGState, backward_state: SDFGState, forward_node: nodes.Node, + backward_node: nodes.Node, required_inputs: Dict[str, str]) -> None: + """Connect the reversed node to all required non-gradient inputs. + + This function handles non-trivial routing scenarios: + 1. When reading from an AccessNode in forward pass, route through maps in backward pass + 2. Save connector values to arrays when backward pass needs to read them + + Currently supports two strategies: store-all and recompute-all. + + :param state: Forward state containing the forward node. + :param backward_state: Backward state containing the backward node. + :param forward_node: The forward pass node. + :param backward_node: The backward pass node (not necessarily a reversed node). + :param required_inputs: Maps forward pass connector names to backward pass connector names. + :raises AutoDiffException: If required connectors don't exist on forward node. + """ + + if set(required_inputs).difference(forward_node.in_connectors): + missing_connectors = \ + set(required_inputs).difference(forward_node.in_connectors) + raise AutoDiffException(f"Cannot connect connectors {missing_connectors} to {backward_node} " + f"because they don't exist on the corresponding forward node {forward_node}") + + # note we use forward state here: we might need to connect inputs that are not in the + # forward pass + input_edges_to_connect = (edge for edge in state.in_edges(forward_node) if edge.dst_conn in required_inputs) + + for edge in input_edges_to_connect: + # Boolean to decide if the source of this edge needs to be replicated + replicate_node = False + + # Boolean to decide if the connection to the replicated node is required + # This is set to False if the connection has already been established + connect_replicated_node = True + edge_src = edge.src + next_required_inputs: Dict[Optional[str], Optional[str]] + replicated_edge_src: nodes.Node + replicated_edge_src_conn: str + + if isinstance(edge_src, nodes.MapEntry): + # In the map case, we need to connect the AN at the start of this memlet path + memlet_path = state.memlet_path(edge) + + # Get the AccessNode at the start of this path + starting_edge = memlet_path[0] + starting_an = starting_edge.src + if not isinstance(starting_an, nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode at start of memlet path, got {type(starting_an)}") + + # Save the information about the data to be forwarded + # to call the function to connect this required AccessNode + # after the reversal + self.data_to_forward.append((state, backward_state, starting_an, forward_node, edge)) + + # No further recusrive calls are required + # in this branch; next_required_inputs stays empty + next_required_inputs = {} + + # Everything will be done in the connect forward accessnode function + replicate_node = False + connect_replicated_node = False + + elif isinstance(edge_src, nodes.AccessNode): + # Get the AccessNode to connect + an_to_connect = self._get_accessnode_to_forward(state, edge_src) + + # Save the information about the data to be forwarded + # to call the function to connect this required AccessNode + # after the reversal + self.data_to_forward.append((state, backward_state, an_to_connect, forward_node, edge)) + + # No further recusrive calls are required + # in this branch; next_required_inputs stays empty + next_required_inputs = {} + + # Everything will be done in the connect forward accessnode function + replicate_node = False + connect_replicated_node = False + + elif isinstance(edge_src, nodes.Tasklet): + + replicate_node = True + # In the tasklet case, we need to connect all inputs + next_required_inputs = {c: c for c in edge_src.in_connectors} + else: + raise AutoDiffException("Unsupported node") + + if replicate_node: + replicated_edge_src_conn = edge.src_conn + + # always replicate the access node + replicated_edge_src = copy.deepcopy(edge_src) + backward_state.add_node(replicated_edge_src) + + if connect_replicated_node: + new_edge_data = copy.deepcopy(edge.data) + if isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode): + # code->code edges have a small special case: + # we need to copy the descriptor + data_name = new_edge_data.data + data_desc = copy.deepcopy(self.sdfg.arrays[data_name]) + if self.separate_sdfgs: + self.backward_sdfg.add_datadesc(data_name, data_desc) + else: + new_data_name = self.backward_sdfg.add_datadesc(data_name, data_desc, find_new_name=True) + new_edge_data.data = new_data_name + + if isinstance(edge_src, nodes.AccessNode) and isinstance(data_desc, dt.View): + if self.separate_sdfgs: + # Remove the view connector + assert replicated_edge_src.remove_in_connector("views") + else: + # If this is a view, we need to connect it to the AccessNode it is viewing + edge_src_in_edge = state.in_edges(edge_src) + + # A view should only have one incoming edge + assert len(edge_src_in_edge) == 1 + edge_src_in_edge = edge_src_in_edge[0] + + # Replicate the viewed node and its memlet and connect it + view_origin = edge_src_in_edge.src + replicated_view = copy.deepcopy(view_origin) + view_memlet = copy.deepcopy(edge_src_in_edge.data) + if self.separate_sdfgs: + # If the SDFGs are separate, we need to add the descriptor for this data + origin_desc = self.sdfg.arrays[view_origin.data] + origin_desc.transient = False + backward_state.sdfg.add_datadesc(view_origin.data, origin_desc) + backward_state.add_edge(replicated_view, None, replicated_edge_src, "views", view_memlet) + + # Add the new edge + backward_state.add_edge(replicated_edge_src, replicated_edge_src_conn, backward_node, + required_inputs[edge.dst_conn], new_edge_data) + + if next_required_inputs: + # If there are any required inputs on the new node, we need to + # recursively call + self._connect_forward_inputs(state, backward_state, edge.src, replicated_edge_src, next_required_inputs) + + def _lookup_required_grad_name(self, node: nodes.Node, connector: str) -> str: + """Look up the required gradient name for a given node and connector. + + :param node: The forward pass node. + :param connector: The connector name to look up. + :return: The required gradient name for the connector. + :raises AutoDiffException: If the node's backward result is not available. + """ + if node not in self.result_map: + raise AutoDiffException(f"Attempted to access required gradient of {node} " + f"before the backward node was created") + return self.result_map[node].required_grad_names[connector] + + def _lookup_given_grad_name(self, node: nodes.Node, connector: str) -> str: + """Look up the given gradient name for a given node and connector. + + :param node: The forward pass node. + :param connector: The connector name to look up. + :return: The given gradient name for the connector. + :raises AutoDiffException: If the node's backward result is not available. + """ + if node not in self.result_map: + raise AutoDiffException(f"Attempted to access given gradient of {node} " + f"before the backward node was created") + return self.result_map[node].given_grad_names[connector] + + def _find_backward_entry_node_for_map_entry(self, backward_state: SDFGState, + entry_node: nodes.MapEntry) -> nodes.MapEntry: + """Find the entry node in the backward pass corresponding to a forward pass entry node. + + :param backward_state: The backward state to search in. + :param entry_node: The MapEntry node from the forward pass. + :return: The corresponding MapEntry node in the backward pass. + :raises AutoDiffException: If exactly one corresponding node is not found. + """ + src_candidates = [ + node for node in backward_state.nodes() + if isinstance(node, nodes.MapEntry) and node.map == self.reverse_map[entry_node.map] + ] + if len(src_candidates) != 1: + raise AutoDiffException(f"Expected exactly one backward MapEntry for forward MapEntry {entry_node}, " + f"but found {len(src_candidates)} candidates") + + return src_candidates[0] + + def _get_reverse_node(self, state: SDFGState, backward_state: SDFGState, node: nodes.Node, + given_gradients: List[str], + required_gradients: List[str]) -> Tuple[nodes.Node, BackwardResult]: + """Add the reverse node for a node from the forward pass to the backward pass. + + Resolution order: + 1) Check for methods on this class + 2) Check the backward pass repository + + :param state: Forward state containing the node. + :param backward_state: Backward state to add the reverse node to. + :param node: Node from the forward pass to reverse. + :param given_gradients: Output names on the forward node for gradient input connections. + :param required_gradients: Input names on the forward node that need gradients generated. + :return: Tuple of (reversed node, BackwardResult with gradient connector names). + :raises AutoDiffException: If no backward implementation is found for the node type. + """ + + # (1) + if hasattr(self.dace_node_impl, "_reverse_" + type(node).__name__): + reverse_method = getattr(self.dace_node_impl, f"_reverse_{type(node).__name__}") + return reverse_method(state, backward_state, node, given_gradients, required_gradients) + + # (2) + impl = find_backward_implementation(self.sdfg, forward_state=state, node=node) + if impl is not None: + backward_node, backward_result = impl.backward(forward_node=node, + context=BackwardContext( + forward_state=state, + forward_sdfg=self.sdfg, + backward_state=backward_state, + backward_sdfg=self.backward_sdfg, + backward_generator=self, + ), + given_gradients=given_gradients, + required_gradients=required_gradients) + if isinstance(backward_node, nodes.CodeNode): + backward_node.schedule = node.schedule + return backward_node, backward_result + + raise AutoDiffException(f"Unable to differentiate node type {type(node)}. " + f"Either add a pure forward implementation or a backward implementation to progress.") diff --git a/dace/autodiff/base_abc.py b/dace/autodiff/base_abc.py new file mode 100644 index 0000000000..c47cc83d0d --- /dev/null +++ b/dace/autodiff/base_abc.py @@ -0,0 +1,183 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Abstract Base Classes for Autodiff +""" +import abc +import dataclasses +import typing +from typing import TYPE_CHECKING + +import dace.registry +from dace import config +from dace.sdfg import SDFG, SDFGState, nodes as nd +import dace.transformation.transformation as xf + +if TYPE_CHECKING: + from dace.autodiff.backward_pass_generator import BackwardPassGenerator + +try: + from dace.libraries.onnx.nodes.onnx_op import ONNXOp + ONNX_AVAILABLE = True +except ImportError: + ONNXOp = None + ONNX_AVAILABLE = False + + +class AutoDiffException(Exception): + """Base class for all exceptions related to automatic differentiation failures.""" + pass + + +@dataclasses.dataclass +class BackwardContext: + """A tuple holding the graph context required to construct reverse nodes.""" + forward_sdfg: SDFG #: The forward SDFG + forward_state: SDFGState #: The forward SDFG state + backward_sdfg: SDFG #: The backward SDFG + backward_state: SDFGState #: The backward SDFG state + backward_generator: 'BackwardPassGenerator' #: The backward pass generator + + +@dataclasses.dataclass +class BackwardResult: + """The return type of a differentiated node. It contains the names of the gradients the node calculates and + requires. + """ + + #: Mapping from names of output connectors to the connector name of the gradient for that connector. + required_grad_names: typing.Dict[typing.Optional[str], typing.Optional[str]] + + #: Mapping from names of input connectors to the connector name of the gradient for that connector. + given_grad_names: typing.Dict[typing.Optional[str], typing.Optional[str]] + + #: Mapping from names of gradients to whether they should be zeroed out on initialization. + zero_init: typing.Dict[typing.Optional[str], typing.Optional[bool]] + + def __init__(self, required_grad_names, given_grad_names, zero_init=None): + self.required_grad_names = required_grad_names + self.given_grad_names = given_grad_names + self.zero_init = zero_init or {} + + @staticmethod + def empty(): + """Create an empty BackwardResult with no gradients.""" + return BackwardResult(given_grad_names={}, required_grad_names={}, zero_init={}) + + +@dace.registry.make_registry +class BackwardImplementation(abc.ABC): + """ABC for backward implementations. + + This registry accepts two types of registrations. + The register function expects an argument ``node_type=TYPE`` where ``TYPE`` is the type of node that this + backward implementation supports. + It can also take an argument ``op=node_name`` where ``node_name`` is the string of the ONNX op it supports, + e.g. ``"Conv"``. + + It also expects a ``name`` argument that names the implementation. + """ + + @staticmethod + def backward_can_be_applied(node: nd.Node, state: SDFGState, sdfg: SDFG) -> bool: + """Return whether this expansion can be applied. + + :param node: The candidate node. + :param state: The candidate state. + :param sdfg: The candidate SDFG. + :return: True if the implementation can be applied, False otherwise. + """ + return True + + @staticmethod + @abc.abstractmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: typing.List[typing.Optional[str]], + required_gradients: typing.List[typing.Optional[str]]) -> typing.Tuple[nd.Node, BackwardResult]: + """Add the reverse node for a node from the forward pass to the backward pass, and return it. + + For each input connector with name ``n`` of the forward in required_gradients, the returned backward node must + add an output connector with name ``required_gradients[n]`` that will output the gradient for that input. + + If any input from the forward pass is required, simply add a connector with the same name as the connector + on the forward node. The input will later be connected as required. + + :param forward_node: The node for which the backward pass should be generated. + :param context: The context for this node (see + :class:`~dace.autodiff.backward_implementation.BackwardContext`). + :param given_gradients: The names of outputs of the node that gradients will be connected for. + :param required_gradients: The names of connectors that gradients should be generated for. + :return: The reverse node and gradient names + (see :class:`~dace.autodiff.backward_implementation.BackwardResult`). + """ + ... + + +# Register the implementations +import dace.autodiff.implementations + + +def find_backward_implementation(forward_sdfg: SDFG, forward_state: SDFGState, + node: nd.Node) -> typing.Optional[BackwardImplementation]: + """Try to find the backward implementation for ``node``. + + :param forward_sdfg: The parent SDFG of the node. + :param forward_state: The parent SDFG state of the node. + :param node: The node to find the implementation for. + :return: The BackwardImplementation for node if one is registered and can be applied, else None. + """ + valid_impls = [] + for impl, args in BackwardImplementation.extensions().items(): + if "name" not in args: + raise ValueError(f"Expected name in arguments of implementation {impl}.") + + if "node_type" in args and isinstance(node, args["node_type"]) or (ONNX_AVAILABLE and isinstance(node, ONNXOp) + and "op" in args + and node.schema.name == args["op"]): + + if impl.backward_can_be_applied(node, forward_state, forward_sdfg): + valid_impls.append((args["name"], impl)) + + if ONNX_AVAILABLE and isinstance(node, ONNXOp) and node.backward_implementation: + + implementation = node.backward_implementation + elif ONNX_AVAILABLE and isinstance(node, ONNXOp) and node.default_backward_implementation: + implementation = node.default_backward_implementation + else: + implementation = None + + if implementation: + filtered_impls = [i for name, i in valid_impls if name == implementation] + if filtered_impls: + return filtered_impls[0] + + if config.Config.get_bool('debugprint'): + print(f"Warning: Set backward_implementation {node.backward_implementation} on {node}, but it could not be" + f" applied. Falling back to default selection.") + if valid_impls: + return valid_impls[0][1] + else: + return None + + +class ExpansionTemplate(xf.ExpandTransformation): + """Module-level expansion class for operations during autodiff. + + This class is used by BackwardPassGenerator._expand_nodes to expand operations + that don't have backward implementations. It needs to be at module level for serialization. + + The class is dynamically configured before use by setting: + - environments: List of required environments + - _impl: The implementation object containing the forward method + - _match_node: The pattern node to match + """ + environments = [] + _impl = None + + @classmethod + def expansion(cls, node, state, sdfg): + if cls._impl is None: + raise RuntimeError("_ONNXExpansion._impl must be set before expansion") + return cls._impl.forward(node, state, sdfg) + + @staticmethod + def annotates_memlets() -> bool: + return True diff --git a/dace/autodiff/data_forwarding/__init__.py b/dace/autodiff/data_forwarding/__init__.py new file mode 100644 index 0000000000..da93544194 --- /dev/null +++ b/dace/autodiff/data_forwarding/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Data Forwarding Strategies for Automatic Differentiation. + +This package manages the tradeoff between storing intermediate values and +recomputing them during the backward pass. This is a fundamental memory-time +tradeoff in automatic differentiation. +""" + +from .manager import DataForwardingManager +from .store import resolve_overwrite_with_store +from .recompute import get_recomputation_nsdfg, resolve_overwrite_with_recomputation + +__all__ = [ + "DataForwardingManager", + "resolve_overwrite_with_store", + "resolve_overwrite_with_recomputation", + "get_recomputation_nsdfg", +] diff --git a/dace/autodiff/data_forwarding/manager.py b/dace/autodiff/data_forwarding/manager.py new file mode 100644 index 0000000000..cb510632ea --- /dev/null +++ b/dace/autodiff/data_forwarding/manager.py @@ -0,0 +1,388 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import copy +from typing import List, Tuple, Optional + +# DaCe imports +import dace.sdfg.nodes as nodes +from dace import config, data as dt +from dace.sdfg import SDFGState, graph as dgraph + +# Autodiff imports +from dace.autodiff.base_abc import AutoDiffException +import dace.autodiff.utils as ad_utils +import dace.autodiff.data_forwarding as data_forwarding + + +class DataForwardingManager: + + def __init__(self, bwd_generator: 'BackwardPassGenerator'): + + # The user specified strategy for forwarding + # Whether to forward data through separate SDFGs + self.bwd_generator: 'BackwardPassGenerator' = bwd_generator + + def forward_data_to_backward_pass(self) -> None: + """ + Iterate through all the data that needs to be forwarded to the backward pass states. + """ + # Get the strategy decision for each data that needs to be forwarded to the backward pass + strategy_choice, recomputation_nsdfgs = self._get_overwrite_resolution_strategy() + + # Make the connection according to the chosen strategy + for index, (forward_state, backward_state, access_node, node, + edge) in enumerate(self.bwd_generator.data_to_forward): + self._connect_forward_accessnode(forward_state, backward_state, access_node, node, edge, + recomputation_nsdfgs[index], strategy_choice[index]) + + def _get_overwrite_resolution_strategy(self) -> Tuple[List[str], List[Optional[nodes.NestedSDFG]]]: + """ + Choose a strategy for resolving overwritten data that we need to forward to the backward pass. + If the user wants a specific strategy, we use it. + Otherwise, we evaluate what strategy is best for this specific node. + """ + strategy_choice: List[str] = [] + recomputation_nsdfgs: List[Optional[nodes.NestedSDFG]] = [] + + # As preprocessing step, + # We will store all of the global program inputs, + # if they are required for the backward pass + # NOTE: This can be relaxed since if an input is not overwritten + # it can be recomputed + to_remove = [] + for i, (forward_state, backward_state, access_node, node, + edge) in enumerate(self.bwd_generator.data_to_forward): + if access_node.data not in self.bwd_generator.sdfg.arg_names: + continue + + # Store the input + self._connect_forward_accessnode(forward_state, backward_state, access_node, node, edge, None, "store") + + # Remove this element from the list of the data to forward + to_remove.append(i) + + # Remove elements from the list of data to be forwarded (in reverse order to maintain indices) + for idx in sorted(to_remove, reverse=True): + del self.bwd_generator.data_to_forward[idx] + + if self.bwd_generator.data_forwarding_strategy == "store_all": + strategy_choice = ["store"] * len(self.bwd_generator.data_to_forward) + + # A recomputation block is not necessary + recomputation_nsdfgs = [None] * len(self.bwd_generator.data_to_forward) + elif self.bwd_generator.data_forwarding_strategy == "recompute_all": + strategy_choice = ["recompute"] * len(self.bwd_generator.data_to_forward) + + # We will delay getting the recomputation block for now + recomputation_nsdfgs = [None] * len(self.bwd_generator.data_to_forward) + elif self.bwd_generator.data_forwarding_strategy == "user_defined": + if self.bwd_generator.data_to_recompute is None: + raise AutoDiffException("The overwrite resolution strategy is User Defined " + "but no recomputation list has been provided." + "Please set the data_to_recompute parameter.") + + for forward_state, backward_state, access_node, node, edge in self.bwd_generator.data_to_forward: + + if access_node.data in self.bwd_generator.data_to_recompute: + try: + nsdfg = data_forwarding.get_recomputation_nsdfg(self.bwd_generator, forward_state, access_node) + choice = "recompute" + except Exception as e: + # If anything goes wrong, print a warning and fall back to storing + if config.Config.get_bool('debugprint'): + print( + f"Warning: Couldn't get the recomputation nested SDFG for {access_node.label} because {e}" + ) + nsdfg = None + choice = "store" + recomputation_nsdfgs.append(nsdfg) + strategy_choice.append(choice) + else: + # We store everything else + recomputation_nsdfgs.append(None) + strategy_choice.append("store") + else: + raise AutoDiffException("Please specify a valid overwrite resolution strategy. " + "Expected either store_all, recompute_all, or user_defined " + f"but got {self.bwd_generator.data_forwarding_strategy}") + return strategy_choice, recomputation_nsdfgs + + def _connect_forward_accessnode(self, forward_state: SDFGState, backward_state: SDFGState, + forward_node: nodes.AccessNode, target_node: nodes.Node, + starting_edge: dgraph.MultiConnectorEdge, + recomputation_nsdfg: Optional[nodes.NestedSDFG], strategy: str): + """ + We need to forward an array from the forward pass to the backward pass. + To do this we first check if this array has been overwritten or not. + If the array has not been overwritten, we just need to replicate it + in the backward pass and then forward it. + If the array has been overwritten, we pick a strategy for this AccessNode: + - Store strategy: + - We modify the forward pass to save the values in a new array + - Connect this new array to the node in the backward pass + - Recomputation: + - Add the recomputation as a NestedSDFG + - Connect the output of the NestedSDFG to the node in the backward pass + """ + + # First, we check if the node has been overwritten + overwritten, recomputable = self._check_node_overwrite(forward_state=forward_state, node=forward_node) + + # Boolean indicating whether we should fall back to storing + fallback = False + if strategy == "recompute" and recomputable: + try: + if recomputation_nsdfg is None: + recomputation_nsdfg = data_forwarding.get_recomputation_nsdfg(self.bwd_generator, + forward_state, + target_an=forward_node) + data_forwarding.resolve_overwrite_with_recomputation(recomputation_nsdfg=recomputation_nsdfg, + forward_state=forward_state, + backward_state=backward_state, + target_an=forward_node, + target_node=target_node, + starting_edge=starting_edge) + except Exception as e: + # If anything goes bad, print a warning and fall back to storing + if config.Config.get_bool('debugprint'): + print(f"Warning: Failed to recompute {forward_node.data}: {e}. Falling back to storing") + fallback = True + + if strategy == "store" or (strategy == "recompute" and not recomputable) or fallback: + # We store if: + # - This was the specified strategy + # - We tried to recompute a program input + # - We tried to recompute something that didn't work and we're falling back to storing + + # The data has been overwritten + if not overwritten: + # We still have access to this data + self._connect_forward_accessnode_not_overwritten(forward_state, backward_state, forward_node, + target_node, starting_edge) + return + + data_forwarding.resolve_overwrite_with_store(bwd_generator=self.bwd_generator, + forward_state=forward_state, + backward_state=backward_state, + forward_node=forward_node, + target_node=target_node, + starting_edge=starting_edge) + + def _check_node_overwrite(self, forward_state: SDFGState, node: nodes.AccessNode) -> Tuple[bool, bool]: + """ + Given an AccessNode from the forward state, check if the data of this node has changed. + We look at all the AccessNodes with the same data that occur after the 'node' parameter + if any of them has an incoming edge, return the node has been overwritten. + + :param node: the AccessNode to perform the check for. + :return: a tuple of whether this node has been overwritten, and if it can be recomputed + """ + overwritten = False + decided = False + recomputable = False + + # Get the descendant and ascendant states to look in for an overwrite + if forward_state not in self.bwd_generator.state_order: + raise AutoDiffException(f"Forward state {forward_state} not found in state order") + index = self.bwd_generator.state_order.index(forward_state) + descendant_states = self.bwd_generator.state_order[index:] + + # Check if this access node is a view + if isinstance(node.desc(self.bwd_generator.sdfg), dt.ArrayView): + # The view should have one incoming edge from the original access node + in_edges = forward_state.in_edges(node) + + # Sanity checks + if len(in_edges) != 1: + raise AutoDiffException(f"Expected exactly one incoming edge for view node {node}, got {len(in_edges)}") + if "views" not in node.in_connectors: + raise AutoDiffException(f"Expected 'views' connector in node {node}, but not found") + + # We want to check if the source has been overwritten + node = in_edges[0].src + + # Get all the AccessNodes with the same data + matches = [] + for d_state in descendant_states: + matches += [(nd, parent) for nd, parent in d_state.all_nodes_recursive() + if isinstance(nd, nodes.AccessNode) and nd.data == node.data] + + # There needs to be at least one occurrence which is the node passed as a parameter + if len(matches) == 0 or (node, forward_state) not in matches: + raise AutoDiffException(f"Node {node} not found in descendant states") + + # If there is only one occurrence of this data, it will not be overwritten later in the graph + if len(matches) == 1: + overwritten = False + decided = True + + # Get the index of the parameter node + index = matches.index((node, forward_state)) + + # If the parameter node is the last occurrence in the descendant states, + # it will not be overwritten + if len(matches) - 1 == index: + overwritten = False + decided = True + + # If we haven't already confirmed that this node has not been overwritten + if not decided: + # Iterate through all the successor occurrences + for nd, parent in matches[index + 1:]: + # Check if this node has an incoming edge + if len(parent.in_edges(nd)) > 0: + overwritten = True + + if not overwritten: + # There is no overwrite so far + # Check if this state is within a loop + is_in_loop, loop = ad_utils.state_within_loop(forward_state) + if is_in_loop: + + # Check if there is any write to this access node within the loop + loop_matches = [(nd, parent) for nd, parent in loop.all_nodes_recursive() + if isinstance(nd, nodes.AccessNode) and nd.data == node.data] + for match, match_parent in loop_matches: + # Check if this node has an incoming edge + if len(match_parent.in_edges(match)) > 0: + overwritten = True + + if overwritten and len(matches) == 1: + # Check if the overwrite is from constant arrays + # This means that the same value will be assigned at each iteration of the loop + # And no storing is necessary + match, match_parent = loop_matches[0] + all_read_only = True + for edge in match_parent.edge_bfs(match, reverse=True): + if edge.data.subset is not None and len(edge.data.subset.free_symbols) != 0: + all_read_only = False + break + if isinstance(edge.src, nodes.AccessNode): + # The memlet needs to be constant + if edge.src.data not in self.bwd_generator.read_only_arrays: + all_read_only = False + break + # Check if the data is read only + if all_read_only: + overwritten = False + + # Iterate through all the predecessor occurrences + for nd, parent in matches[:index + 1]: + # Check if this node has an incoming edge + if len(parent.in_edges(nd)) > 0: + recomputable = True + return overwritten, recomputable + + def _connect_forward_accessnode_not_overwritten(self, + forward_state: SDFGState, + backward_state: SDFGState, + forward_node: nodes.AccessNode, + target_node: nodes.Node, + starting_edge: dgraph.MultiConnectorEdge, + replicated_node: Optional[nodes.AccessNode] = None): + """ + Replicate and connect the forward AccessNode to the requesting node in the backward pass. + Because the AccessNode has not been overwritten, we just need to create the same connection + in the backward pass. + """ + + # First, replicate the AccessNode and add it to the backward pass + # If it has not already been replicated and passed as a parameter + if replicated_node is None: + replicated_node = copy.deepcopy(forward_node) + backward_state.add_node(replicated_node) + if self.bwd_generator.separate_sdfgs: + # Need to copy over the descriptor from the forward pass + data_name = replicated_node.data + data_desc = copy.deepcopy(forward_node.desc(self.bwd_generator.sdfg)) + data_desc.transient = False + if data_name not in self.bwd_generator.backward_sdfg.arrays: + self.bwd_generator.backward_sdfg.add_datadesc(data_name, data_desc) + + # We also need to forward this array + if data_name not in self.bwd_generator.backward_input_arrays: + # If the data is needed inside a NestedSDFG + # This will make sure the added array is correctly forwarded + # and an in connector to the NestedSDFG is added + self.bwd_generator.backward_input_arrays[data_name] = data_desc + + # We replicate the exact link between this forward access node and the target node + # Get all the edges in the path + all_edges_inbetween = ad_utils.get_all_path_edges(state=forward_state, + source=forward_node, + starting_edge=starting_edge) + + # A dictionary to keep track of temporary nodes in the path + replicated_tmp_nodes = {} + + # For each edge in the path + for edge in all_edges_inbetween: + src, src_conn, dst, dst_conn, data = edge + bwd_src, bwd_src_conn, bwd_dst, bwd_dst_conn, bwd_data = src, src_conn, dst, dst_conn, copy.deepcopy(data) + + # If the destination is a map entry, + if isinstance(dst, nodes.MapEntry): + # We need to get the corresponding map entry in the backward pass. + bwd_dst = self.bwd_generator._find_backward_entry_node_for_map_entry(backward_state=backward_state, + entry_node=dst) + # Add the dst connector to the map + added = bwd_dst.add_in_connector(bwd_dst_conn) + assert added + + # If the destination is a map entry, + if isinstance(src, nodes.MapEntry): + # We need to get the corresponding map entry in the backward pass. + bwd_src = self.bwd_generator._find_backward_entry_node_for_map_entry(backward_state=backward_state, + entry_node=src) + # Add the src connector to the map + added = bwd_src.add_out_connector(bwd_src_conn) + assert added + + if src is forward_node: + # If this is the node we replicated + bwd_src = replicated_node + elif isinstance(src, nodes.AccessNode): + # This is a temporary AccessNodes + # we should have already seen and replicated this + assert src in replicated_tmp_nodes + bwd_src = replicated_tmp_nodes[src] + + if dst is target_node: + # If this is the final connection node + bwd_dst = self.bwd_generator.reverse_map[dst] + elif isinstance(dst, nodes.AccessNode): + # This is a temporary AccessNodes + # we want to replicate and add it to the path + bwd_dst = copy.deepcopy(dst) + backward_state.add_node(bwd_dst) + replicated_tmp_nodes[dst] = bwd_dst + + # Modify the data in the memlet in case the array is replicated outside of the function + bwd_data.data = replicated_node.data + + # Add the edge to the backward state + backward_state.add_edge(bwd_src, bwd_src_conn, bwd_dst, bwd_dst_conn, bwd_data) + + # If we just connected a view, we need to remove the view in connector + data_desc = self.bwd_generator.sdfg.arrays[forward_node.data] + if isinstance(forward_node, nodes.AccessNode) and isinstance(data_desc, dt.View): + if self.bwd_generator.separate_sdfgs: + # Remove the view connector + assert replicated_node.remove_in_connector("views") + else: + # if this is a view, we need to connect it to the AccessNode it is viewing + edge_src_in_edge = forward_state.in_edges(forward_node) + + # a view should only have one incoming edge + assert len(edge_src_in_edge) == 1 + edge_src_in_edge = edge_src_in_edge[0] + + # replicate the viewed node and its memlet and connect it + view_origin = edge_src_in_edge.src + replicated_view = copy.deepcopy(view_origin) + view_memlet = copy.deepcopy(edge_src_in_edge.data) + if self.bwd_generator.separate_sdfgs: + # if the sdfgs are separate, we need to add the descriptor for this data + origin_desc = self.bwd_generator.sdfg.arrays[view_origin.data] + origin_desc.transient = False + backward_state.sdfg.add_datadesc(view_origin.data, origin_desc) + backward_state.add_edge(replicated_view, None, replicated_node, "views", view_memlet) diff --git a/dace/autodiff/data_forwarding/recompute.py b/dace/autodiff/data_forwarding/recompute.py new file mode 100644 index 0000000000..ce6330dbff --- /dev/null +++ b/dace/autodiff/data_forwarding/recompute.py @@ -0,0 +1,298 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import copy +from typing import List + +# DaCe imports +import dace +import dace.sdfg.nodes as nodes +from dace.sdfg import SDFG, SDFGState, graph as dgraph, state as dstate +from dace.sdfg.state import LoopRegion + +# Autodiff imports +from dace.autodiff.base_abc import AutoDiffException +import dace.autodiff.utils as ad_utils + + +def resolve_overwrite_with_recomputation( + recomputation_nsdfg: nodes.NestedSDFG, + forward_state: SDFGState, + backward_state: SDFGState, + target_an: nodes.AccessNode, + target_node: nodes.Node, + starting_edge: dstate.MultiConnectorEdge, +): + """ + Experimental! Use recomputation in the backward pass to compute data that was overwritten in the forward pass. + """ + + # Add the nsdfg where it is required + _connect_recomputation_nsdfg(forward_state=forward_state, + backward_state=backward_state, + nsdfg=recomputation_nsdfg, + target_an=target_an, + target_node=target_node, + starting_edge=starting_edge) + + +def _connect_recomputation_nsdfg(bwd_generator: 'BackwardPassGenerator', forward_state: SDFGState, + backward_state: SDFGState, target_an: nodes.AccessNode, target_node: nodes.Node, + nsdfg: nodes.NestedSDFG, starting_edge: dstate.MultiConnectorEdge): + """ + + """ + # Connect all the SDFG inputs to the nested SDFG + # First, add the nested sdfg + for input in nsdfg.in_connectors.keys(): + # For each argument + input_name = input if "recomputation_" not in input else input[14:] + + # Get the first instance of this AN in the SDFG + first_instance = None + for node, parent in bwd_generator.forward_sdfg.all_nodes_recursive(): + if isinstance(node, nodes.AccessNode) and node.data == input: + first_instance = node + first_node_state = parent + break + + assert first_instance + + new_an = nodes.AccessNode(input_name) + backward_state.add_node(new_an) + + # Create a memlet passing all the data to the nested-SDFG + memlet = bwd_generator.forward_sdfg.make_array_memlet(input_name) + + # Add the connection to the nested SDFG + backward_state.add_edge(new_an, None, nsdfg, input, memlet) + + # Write the data to a new access node in the backward state + # Add a new AccessNode and array to the forward pass + # First, check if a recomputated array with this name already exists + if "recomputed_" + target_an.data not in bwd_generator.backward_sdfg.arrays: + new_recomp_node_name = "recomputed_" + target_an.data + else: + i = 0 + while True: + if f"recomputed_{i}_" + target_an.data not in bwd_generator.backward_sdfg.arrays: + new_recomp_node_name = f"recomputed_{i}_" + target_an.data + break + i += 1 + + # Get the new array shape + # This will be the shape of the current array + shape: List[int] = list(bwd_generator.forward_sdfg.arrays[target_an.data].shape) + + # Add the array descriptor and AccessNode to the forward state + original_desc = target_an.desc(forward_state) + new_recomp_node = backward_state.add_array( + name=new_recomp_node_name, + shape=shape, + dtype=original_desc.dtype, + transient=True, + ) + new_recomp_node.setzero = True + + # Create a memlet passing all the data to the nested-SDFG + memlet = bwd_generator.forward_sdfg.make_array_memlet(new_recomp_node.data) + + nsdfg_out_conn = list(nsdfg.out_connectors.keys()) + assert len(nsdfg_out_conn) == 1 + nsdfg_out_conn = nsdfg_out_conn[0] + + # Connect the output of the NestedSDFG + backward_state.add_edge(nsdfg, nsdfg_out_conn, new_recomp_node, None, memlet) + + # Connect the new AccessNode to the required computation + bwd_generator._connect_forward_accessnode_not_overwritten(forward_state=forward_state, + backward_state=backward_state, + forward_node=target_an, + target_node=target_node, + starting_edge=starting_edge, + replicated_node=new_recomp_node) + + +def _prune_descendants_recomputation_nsdfg(forward_state: SDFGState, target_an: nodes.AccessNode, + nsdfg: nodes.NestedSDFG): + """ + 1: From this Nested-SDFG, we remove everything that will be executed after the target access node to be recomputed + 2: Prune the unnecessary computation inside the forward state + Note: this is even necessary sometimes since the output could be overwritten in the same state + """ + + # 1 + # Get the states order for the nested_sdfg + states_order: List[SDFGState] = ad_utils.get_state_topological_order(nsdfg.sdfg) + state_index = states_order.index(forward_state) + descendant_states: List[SDFGState] = states_order[state_index:] + assert descendant_states.pop(0) == forward_state + + # Check if the target state is within a loop + target_within_loop, target_loop = ad_utils.state_within_loop(forward_state) + + # We will save the states that are within the same loop because they require special treatement + same_loop_states: List[SDFGState] = [] + for state in descendant_states: + # We want to avoid removing the descendant states that are inside the same loop region + if target_within_loop: + descendant_within_loop, descendant_loop = ad_utils.state_within_loop(state) + if descendant_within_loop and descendant_loop == target_loop: + # If the state is within the same loop, we don't remove it + same_loop_states.add(state) + continue + + # Remove the state from the nested_sdfg + parent = state.parent_graph + parent.remove_node(state) + + # Cleanup empty LoopRegions if any + for node in nsdfg.sdfg.all_nodes_recursive(): + if isinstance(node, LoopRegion) and len(node.nodes()) == 0: + parent = node.parent_graph + parent.remove_node(node) + + # 2 + # Within the same state + if target_within_loop: + # For now we keep all of the computation inside the loop + # TODO: if there is an overwrite to the same array in the decendnat computation + # We need to make a special case for the last iteration of the loop where the + # else branch of this if is executed and a special version of the loop is added + raise AutoDiffException("Recomputation with overwrites within loops is not supported yet.") + else: + # If the target state is not within a loop + # We remove all the descendant computation from the graph + + # Do a reverse bfs to get all the necessary computation + backward_nodes = {n for e in forward_state.edge_bfs(target_an, reverse=True) for n in [e.src, e.dst]} + + # Remove everything else + descendant_nodes = set(forward_state.nodes()) - backward_nodes + + for node in descendant_nodes: + if node is not target_an: + forward_state.remove_node(node) + + +def _prune_recomputation_sdfg(forward_state: SDFGState, target_an: nodes.AccessNode, nsdfg: nodes.NestedSDFG): + """ + 1: From this Nested-SDFG, we remove everything that will be executed after the target access node to be recomputed + 2: Prune the unnecessary computation inside the forward state + Note: this is even necessary sometimes since the output could be overwritten in the same state + 3: TODO: From the target access node, we go backward in the graph and see what elements are required to get this array + """ + + # 1 and 2 + _prune_descendants_recomputation_nsdfg(forward_state=forward_state, target_an=target_an, nsdfg=nsdfg) + + +def _rename_descriptors_for_recomputation_nsdfg(forward_sdfg: SDFG, nsdfg: nodes.NestedSDFG): + """ + """ + # Get all the nodes to rename in the NestedSDFG + to_rename = [] + for inp in nsdfg.in_connectors: + for node, parent in nsdfg.sdfg.all_nodes_recursive(): + if isinstance(node, nodes.AccessNode) and node.data == inp and parent.in_degree(node) > 0: + # This is an input that will be written to in the SDFG we need to rename it + to_rename.append(inp) + break + + if len(to_rename) > 0: + # Add a new state to copy the data at the start of the SDFG + initi_state = nsdfg.sdfg.add_state_before(nsdfg.sdfg.start_state, label=f"init_{nsdfg.label}") + + # Rename the descriptors in the nested SDFG in addition to the in connector + for name in to_rename: + # Create a new array + new_name = f"recomputation_{name}" + + # Change the accessnodes in the NestedSDFG + for node, parent in nsdfg.sdfg.all_nodes_recursive(): + if isinstance(node, nodes.AccessNode) and node.data == name: + node.data = new_name + + # Change the memlets in the SDFG + for edge, parent in nsdfg.sdfg.all_edges_recursive(): + # Skip interstate edges + if isinstance(edge.data, dace.InterstateEdge): + continue + + if edge.data.data == name: + edge.data.data = new_name + + # Add the desciptor + old_desc = nsdfg.sdfg.arrays[name] + new_desc = copy.deepcopy(old_desc) + + # Check if this is the output of the recomputation block + if name not in nsdfg.out_connectors: + new_desc.transient = True + else: + new_desc.transient = False + + nsdfg.sdfg.add_datadesc(name=new_name, datadesc=new_desc) + + # Add a copy operation between the input node and the new descriptor + input_node = nodes.AccessNode(name) + new_node = nodes.AccessNode(new_name) + initi_state.add_node(input_node) + initi_state.add_node(new_node) + + # Add memory copy edge + initi_state.add_edge(input_node, None, new_node, None, forward_sdfg.make_array_memlet(name)) + + # Change the output if necessary + if name in nsdfg.out_connectors: + nsdfg.remove_out_connector(name) + nsdfg.add_out_connector(new_name) + + +def get_recomputation_nsdfg(bwd_generator: 'BackwardPassGenerator', forward_state: SDFGState, + target_an: nodes.AccessNode) -> nodes.NestedSDFG: + """ + Given an AccessNode for data that needs to be forwarded from the forward pass to the backward pass, + Return a nested SDFG that recomputes this data from input data. + """ + nsdfg_label = "recomputation_nsdfg_" + target_an.data + + # Initially, we will replicate the whole SDFG into a Nested-SDFG and connect it + # TODO: we likely need a copy of the SDFG before starting AD if separate_sdfgs + nsdfg = nodes.NestedSDFG(label=nsdfg_label, + sdfg=copy.deepcopy(bwd_generator.sdfg), + inputs=bwd_generator.sdfg.arg_names, + outputs=[target_an.data]) + + # We need to make sure the output inside the NestedSDFG is not a transient (anymore) + nsdfg.sdfg.arrays[target_an.data].transient = False + + # Find the same target node and state in the nsdfg + nsdfg_forward_state: SDFGState = None + nb_occurrences = 0 + for state in nsdfg.sdfg.states(): + if state.label == forward_state.label: + nsdfg_forward_state = state + nb_occurrences += 1 + + # Sanity check + assert nb_occurrences == 1 + assert nsdfg_forward_state + + # Find the target AccessNode within the state + nsdfg_target_node: nodes.AccessNode = None + nb_occurrences = 0 + for node in nsdfg_forward_state.nodes(): + if isinstance(node, nodes.AccessNode) and node.data == target_an.data and nsdfg_forward_state.node_id( + node) == forward_state.node_id(target_an): + nsdfg_target_node = node + nb_occurrences += 1 + + # Sanity check + assert nb_occurrences == 1 + assert nsdfg_target_node + + _prune_recomputation_sdfg(nsdfg=nsdfg, forward_state=nsdfg_forward_state, target_an=nsdfg_target_node) + + # Change descriptors if the inputs are written to + _rename_descriptors_for_recomputation_nsdfg(forward_sdfg=bwd_generator.sdfg, nsdfg=nsdfg) + + return nsdfg diff --git a/dace/autodiff/data_forwarding/store.py b/dace/autodiff/data_forwarding/store.py new file mode 100644 index 0000000000..a744884e36 --- /dev/null +++ b/dace/autodiff/data_forwarding/store.py @@ -0,0 +1,683 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import copy +from typing import List, Tuple +import sympy as sp + +# DaCe imports +import dace.sdfg.nodes as nodes +from dace import dtypes, data as dt, symbolic +from dace.sdfg import SDFGState, graph as dgraph, state as dstate +from dace.memlet import Memlet +from dace.sdfg.state import LoopRegion + +# Autodiff imports +from dace.autodiff.base_abc import AutoDiffException +import dace.autodiff.utils as ad_utils + + +def resolve_overwrite_with_store(bwd_generator: 'BackwardPassGenerator', forward_state: SDFGState, + backward_state: SDFGState, forward_node: nodes.AccessNode, target_node: nodes.Node, + starting_edge: dstate.MultiConnectorEdge): + """ + Given the AccessNode pointing to the data required by the backward pass, + We will save the values of this array in a new array and forward it to the backward pass. + """ + + # Modify the forward pass to save the data in a new array + new_stored_array, memlets = _store_data(bwd_generator=bwd_generator, + forward_state=forward_state, + backward_state=backward_state, + forward_an=forward_node, + target_node=target_node, + edge=starting_edge) + + # Check if this data needs to be forwarded through NestedSDFGs + if bwd_generator.separate_sdfgs or forward_state.sdfg.parent_sdfg is not None: + # We need to make sure the new array is forwarded to the backward SDFG + if new_stored_array.data not in bwd_generator.backward_input_arrays: + # If the data is needed inside a NestedSDFG + # This will make sure the added array is correctly forwarded + # and an in connector to the NestedSDFG is added + data_desc = new_stored_array.desc(forward_state) + bwd_generator.backward_input_arrays[new_stored_array.data] = data_desc + + # Connect the new array to the target node + _connect_stored_data_to_target(bwd_generator=bwd_generator, + forward_state=forward_state, + backward_state=backward_state, + source_node=new_stored_array, + forward_node=forward_node, + starting_edge=starting_edge, + memlets=memlets, + target_node=target_node) + + +def _store_data(bwd_generator: 'BackwardPassGenerator', forward_state: SDFGState, backward_state: SDFGState, + forward_an: nodes.AccessNode, target_node: nodes.Node, + edge: dgraph.MultiConnectorEdge) -> Tuple[nodes.AccessNode, List[Memlet]]: + """ + Given an edge leading an AccessNode or a map to the target node in the forward state, + add a path from the connector for this AccessNode to store its values for all iterations. + This can increase the dimension of the array. i.e. the size of the stored array is + greater or equal to the size of the original array. + + :param edge: the edge connecting the AccessNode to save data from to a map node. + :return: the new AccessNode which contains the stored data, + a list of memlets connecting an assign tasklet to this new AccessNode. + """ + + # Get the connector and edge to save + if isinstance(edge.src, nodes.AccessNode) and edge.src is not forward_an: + + # Get the incoming edge to this AccessNode + in_edges = forward_state.in_edges(edge.src) + + # There should only be one incoming edge + assert len(in_edges) == 1 + + # Get the memlet path for the edge incoming to this AccessNode + memlet_path = forward_state.memlet_path(in_edges[0]) + + # The start of this path should be the forward AccessNode + assert forward_an is memlet_path[0].src + + # The last edge in the memlet path has the connector we want to save + edge = memlet_path[-1] + + # Add a new AccessNode and array to the forward pass + # First, check if a stored array with this name already exists + new_store_node_name = forward_state.sdfg._find_new_name("stored_" + forward_an.data) + + # Get the new array shape + # This will be the shape of the current array + shape: List[int] = list(bwd_generator.sdfg.arrays[forward_an.data].shape) + + # If the shape is an expression: + free_symbols_dict = {sym: None for sym in bwd_generator.sdfg.free_symbols} + if any(symbolic.issymbolic(s, free_symbols_dict) for s in shape): + # Otherwise, replace all the loop dependent allocations with the max length of the loop + # For example, an array of size [i+1] in a range(2, 10) loop will be stored in a [10, 10] array (1) + # Additionally, an array of size [32-i] in the same loop will be stored in a [10, 30] (2) + loops = _get_all_enclosing_loops(forward_state) + + if len(loops) > 0: + # Loop over the shape dimensions + for i, s in enumerate(shape): + if ad_utils.shape_has_symbols_to_replace(bwd_generator.sdfg, s): + loop_size, loop_index = _get_symbol_upper_bound_from_loop(bwd_generator, s, loops) + # Replace the symbol with the loop size and evaluate the expression + # Check if loop size can be converted to an integer + loop_index_sym = symbolic.pystr_to_symbolic(loop_index) + loop_size_sym = loop_size if isinstance(loop_size, int) else symbolic.pystr_to_symbolic(loop_size) + shape[i] = s.subs(loop_index_sym, loop_size_sym) + + # Plus the size of any enclosing loops + enclosed, _ = ad_utils.state_within_loop(forward_state=forward_state) + nb_enclosing_loops = 0 + loop_param_list = [] + if enclosed: + # Get all enclosing loops + all_encolsing_loops = _get_all_enclosing_loops(forward_state=forward_state) + nb_enclosing_loops = len(all_encolsing_loops) + # Get the size of each loop and add it to the list + for loop in all_encolsing_loops: + # Get the end of the loop + start, end = ad_utils.extract_loop_region_info(loop) + + # Check if the loop is increasing or decreasing + # First, try to convert the strings to ints if possible + # Note that we look for the start or end of the loop + # And not the size of the loop. + # This is because we access using the loop indices + # Using the loop sizes instead would require shifting accesses + _, new_dim = ad_utils.get_loop_end(start, end, loop) + + # First we check if the new dimension contains symbols + # These will need to be replaced with scalars for correct allocation + # The sdfg symbols are allowed to be in the shape + if ad_utils.shape_has_symbols_to_replace(bwd_generator.sdfg, new_dim): + # Take the expression to sympy for easier processing + if isinstance(new_dim, str): + new_dim = symbolic.pystr_to_symbolic(new_dim) + + # Try to replace the symbols with the loop size + loop_size, loop_index = _get_symbol_upper_bound_from_loop(bwd_generator, new_dim, all_encolsing_loops) + loop_index_sym = symbolic.pystr_to_symbolic(loop_index) + loop_size_sym = loop_size if isinstance(loop_size, int) else symbolic.pystr_to_symbolic(loop_size) + new_dim = new_dim.subs(loop_index_sym, loop_size_sym) + shape.insert(0, new_dim) + loop_param_list.insert(0, loop.loop_variable) + + # Add the array descriptor and AccessNode to the forward state + original_desc = forward_an.desc(forward_state) + + # We make a special case for a memlet of the type A[i, j] in an i, j loop + # In this case we only need an array of the same size as the forward node + if enclosed and edge.data.data == forward_an.data and len(edge.data.subset) == nb_enclosing_loops: + # Check that the memlet subset matches perfectly the order of loop nest + # Make sure the subset elements are (i,i,1) and (j,j,1) + # Then check if this matches the loop indices + if all( + str(subset[0]) == loop_param_list[i] and subset[0] == subset[1] and subset[2] == 1 + for i, subset in enumerate(edge.data.subset)): + # We only use the loop accesses + # Both should work since shape[:nb_enclosing_loops] == shape[nb_enclosing_loops:] + shape = shape[nb_enclosing_loops:] + + # We want to build the memlet as if this was not in a a loop + nb_enclosing_loops = 0 + + new_store_node = forward_state.add_array( + name=new_store_node_name, + shape=shape, + dtype=original_desc.dtype, + transient=True, + ) + + # Connect the edge source and connector to the new access node + # We will save the memlets we create and return them + # This is useful to make the connections for the backward state + memlets_stack = [] + + # The loop accesses will be the same within the state + # Prepare them for all edges + loop_access = ','.join([f'{loop_param_list[i]}' for i in range(nb_enclosing_loops)]) + + # In the other cases, we need to route the storing through maps + all_edges = ad_utils.get_all_path_edges(forward_state, forward_an, edge) + + # Get the map nest memlet informtation + start_range, param_list, shape_list, param_dict = ad_utils.get_map_nest_information(all_edges) + + # The parameters to add for the current memlet in the loop + # At first we will use all of the parameters that are used in the memlet + # param_dict = {key: val for key, val in param_dict.items() if key in edge.data.free_symbols} + new_param_dict = {} + + # Iterate through the subset + for index, element in enumerate(edge.data.subset): + if str(element[0]) in edge.data.free_symbols and str(element[0]) in param_dict.keys(): + # Add the range from the param_dict + new_param_dict.update({str(element[0]): param_dict[str(element[0])]}) + else: + # Add the range from the param_dict + new_param_dict.update({index: element}) + + params_to_add = new_param_dict + # First, we need to add an assign tasklet + assign_tasklet_node, assign_tasklet_node_out_connector = _get_assign_tasklet(forward_state=forward_state, + node=forward_an, + stored_node=new_store_node, + last_edge=edge, + loop_iterators=loop_access) + + # Start iterating + previous_node = assign_tasklet_node + previous_node_out_connector = assign_tasklet_node_out_connector + map_exist = None + for edge in reversed(all_edges): + if isinstance(edge.src, nodes.MapEntry): + # Get the corresponding map exit + map_exist = _find_map_exist_for_map_entry(map_entry=edge.src, state=forward_state) + + # Add the Connectors to the map + map_exit_in_connector = f"IN_stored_{new_store_node.label}" + map_exit_out_connector = f"OUT_stored_{new_store_node.label}" + added = map_exist.add_in_connector(map_exit_in_connector) + assert added + added = map_exist.add_out_connector(map_exit_out_connector) + assert added + + # Prepare the memlet data for this edge + access_list = [] + for key, val in new_param_dict.items(): + if isinstance(key, str): + if key in params_to_add.keys(): + access_list.append(key) + else: + start = val[0] + end = val[1] + access_list.append(f'{start}:{end}') + elif isinstance(key, int): + start = val[0] + end = val[1] + 1 + access_list.append(f'{start}:{end}') + else: + raise AutoDiffException("Found unexepected type in memlet parameters dictionary") + + in_state_access = ','.join(access_list) + + memlet_data = Memlet( + expr=f"{new_store_node.data}[{loop_access},{in_state_access}]") if loop_access else Memlet( + expr=f"{new_store_node.data}[{in_state_access}]") + + # Save the memlet for later + memlets_stack.append(memlet_data) + + # Connect the previous node to this map exist + forward_state.add_edge(previous_node, previous_node_out_connector, map_exist, map_exit_in_connector, + memlet_data) + + previous_node = map_exist + previous_node_out_connector = map_exit_out_connector + + # Remove the parameters seen in the current map + # Since they will become out of scope in the next iteration + params_to_add = {} + for key, val in new_param_dict.items(): + if isinstance(key, str): + if key not in edge.src.params: + start = val[0] + end = val[1] + params_to_add.update({key: (start, end)}) + elif isinstance(key, int): + params_to_add.update({key: val}) + else: + raise AutoDiffException("Found unexepected type in memlet parameters dictionary") + + else: + # Prepare the memlet data for this edge + access_list = [] + for key, val in new_param_dict.items(): + if isinstance(key, str): + start = val[0] + end = val[1] + access_list.append(f'{start}:{end}') + elif isinstance(key, int): + start = val[0] + end = val[1] + 1 + access_list.append(f'{start}:{end}') + else: + raise AutoDiffException("Found unexepected type in memlet parameters dictionary") + + in_state_access = ','.join(access_list) + + # Get the memlet data for the connection between the last map exit and the new store AccessNode + memlet_data = Memlet( + expr=f"{new_store_node.data}[{loop_access},{in_state_access}]") if loop_access else Memlet( + expr=f"{new_store_node.data}[{in_state_access}]") + + memlets_stack.append(memlet_data) + + # This should be the last connection + forward_state.add_edge(previous_node, previous_node_out_connector, new_store_node, None, memlet_data) + break + + # We need to add an empty memlet from the new store AccessNode to make sure the data is stored before it is + # potentially altered + # First, we check if this can be avoided + # We do a BFS exploration to see if the data we are trying to store is overwritten within the same execution state + bfs_nodes = list(forward_state.bfs_nodes(source=forward_an)) + + # We make sure that views are also compared with their original array to check for conflicts + conflict_arrays = [forward_an.data] + # Check if the access node is a view + if isinstance(forward_an.desc(forward_state), dt.View): + # Get the original array name + viewed_array = next(forward_state.in_edges_by_connector(forward_an, "views")).data.data + conflict_arrays.append(viewed_array) + + if any(isinstance(n, nodes.AccessNode) and n.data in conflict_arrays and n is not forward_an for n in bfs_nodes): + to_connect = [] + for out_edge in forward_state.out_edges(forward_an): + # Get the destination of the edge + dst = out_edge.dst + if not isinstance(dst, nodes.MapEntry) and dst is not assign_tasklet_node: + # This will not be necessary for maps since the storing is added to the same map + # We also don't connect the newly created assign tasklet to avoid creating a cycle + if dst not in to_connect: + # We only need to make a single connection to the new stored data + to_connect.append(dst) + + for node in to_connect: + # Connect the new store AccessNode to assure the store happens first + # If there isn't already a connnection between these two nodes + if not any(e.dst == node for e in forward_state.out_edges(new_store_node)): + forward_state.add_edge(new_store_node, None, node, None, Memlet()) + + # Another case for making sure data is stored before it is altered is when the map we save from writes itself to the data we want to save + # In this case this would depend on the codegen order of the tasklets within the map and is thus not safe + # Detect if this is the case + if map_exist: + # Check if this map exit writes to the data we want to save + if any( + isinstance(e.dst, nodes.AccessNode) and e.dst.data == forward_an.data + for e in forward_state.out_edges(map_exist)): + # Get the map entry of this map exit + tasklet_in_edges = forward_state.in_edges(assign_tasklet_node) + assert len(tasklet_in_edges) == 1 + tasklet_in_edge = tasklet_in_edges[0] + + # Safety check + if not isinstance(tasklet_in_edge.src, nodes.MapEntry): + raise AutoDiffException( + "The map exit writes to the data we want to save, but the storing strcuture is not what we expect" + ) + + # Get all the edges coming out of this specific in connector + collusion_edges = [ + e for e in forward_state.out_edges(tasklet_in_edge.src) + if e.src_conn == tasklet_in_edge.src_conn and e.dst != assign_tasklet_node + ] + + # We need to add an empty memlet from the new store tasklet to everything else that reads from that connector + for out_edge in collusion_edges: + forward_state.add_edge(assign_tasklet_node, None, out_edge.dst, None, Memlet()) + + return new_store_node, memlets_stack + + +def _connect_stored_data_to_target(bwd_generator: 'BackwardPassGenerator', forward_state: SDFGState, + backward_state: SDFGState, source_node: nodes.AccessNode, + forward_node: nodes.AccessNode, target_node: nodes.Node, memlets: List[Memlet], + starting_edge: dgraph.MultiConnectorEdge): + """ + Connect the source node to the sink target node (both in the backawrd state) through a set of maps using the parameter memelets. + We use the forward_sink_edge to track which maps to make this connection through. + :param source_node: the source node of the new memlet path + :param sink_node: the sink node of the new memlet path + :param memlets: the set of memlets to use for the edges in the path + :param forward_sink_edge: the sink edge connecting the original nodes in the forward state + """ + # First, if the stored data is not already in the sdfg descriptors, add it + # This is the case for NestedSDFGs + if source_node.data not in backward_state.sdfg.arrays: + # Get the data descriptor from the original sdfg + data_desc = copy.deepcopy(bwd_generator.sdfg.arrays[source_node.data]) + data_desc.transient = False # The stored data will be forwarded + backward_state.sdfg.add_datadesc(source_node.data, data_desc) + + # Get the memlet path from the forward state + all_edges = ad_utils.get_all_path_edges(forward_state, forward_node, starting_edge) + assert len(all_edges) > 0 + + # We will iterate and connect parent -> child + reversed_child_node = bwd_generator.reverse_map[target_node] + child_node = reversed_child_node + child_node_in_connector = all_edges[-1].dst_conn + + # Iterate through the maps in the path in reverse + for edge in reversed(all_edges): + edge_src = edge.src + if isinstance(edge_src, nodes.MapEntry): + # Get the correponding map exist + map_exit = _find_map_exist_for_map_entry(map_entry=edge_src, state=forward_state) + + # Use the lookup table to get the map entry in the backward state corresponding to this map exist in the forward state + # Sanity check: this map entry should already exist in the backward state + assert map_exit in bwd_generator.reverse_map + bwd_map_entry = bwd_generator.reverse_map[map_exit] + + # Get a new connector id + next_conn = bwd_map_entry.next_connector() + + # Add a new in connector to the mapexit + parent_node_in_connector = "IN_stored_" + source_node.data + "_" + next_conn + added = bwd_map_entry.add_in_connector(parent_node_in_connector) + assert added + + # Add a new out connector to the mapexit + parent_node_out_connector = "OUT_stored_" + source_node.data + "_" + next_conn + added = bwd_map_entry.add_out_connector(parent_node_out_connector) + assert added + + memlet_data = copy.deepcopy(memlets.pop(0)) + + # Add the edge with the corresponding memlet + backward_state.add_edge(bwd_map_entry, parent_node_out_connector, child_node, child_node_in_connector, + memlet_data) + + child_node = bwd_map_entry + child_node_in_connector = parent_node_in_connector + + if isinstance(edge_src, nodes.AccessNode): + # The connection from the stored data will be made here + assert edge_src == forward_node + memlet_data = copy.deepcopy(memlets.pop(0)) + + # Replicate the source stored node + replicated_source_node = copy.deepcopy(source_node) + backward_state.add_node(replicated_source_node) + + # Change the memlet data to read from the stored data and not the original data + memlet_data.data = replicated_source_node.data + + # Add the final connection to the source node + backward_state.add_edge(replicated_source_node, None, child_node, child_node_in_connector, memlet_data) + + # If this connection was made to a NestedSDFG and the forward node was a view, + # We need to change the strides in the data descriptor this points to + # Since the stored data is not a view + # For example, if the stride of A is 5 (because it points to a column in a 2d array), + # The stored data will only contain the row and the stride for it should be one + # This is only a problem if the view points to a NestedSDFG input, + # that expects a descriptor with the original view stride + if isinstance(child_node, nodes.NestedSDFG) and isinstance(forward_node.desc(bwd_generator.sdfg), dt.View): + # Get the strides of the stored data + stored_data_desc = bwd_generator.sdfg.arrays[source_node.data] + stored_strides = stored_data_desc.strides + + # Get the NestedSDFG input descriptor + input_desc = child_node.sdfg.arrays[child_node_in_connector] + + # Set the strides to be the last elements of the stored strides + # We take the last elements since we might add loop indices to the shape + # Sanity check the strides for this desc should be less than or equal to the stored strides + assert len(input_desc.strides) <= len(stored_strides) + input_desc.strides = stored_strides[-len(input_desc.shape):] + + # There should be the same number of memlets through the new path + assert len(memlets) == 0 + + +def _get_assign_tasklet(forward_state: SDFGState, + node: nodes.AccessNode, + stored_node: nodes.AccessNode, + last_edge: dgraph.MultiConnectorEdge, + loop_iterators: str, + cuda: bool = False): + """ + """ + # Create the assign tasklet + assign_tasklet_node_in_connector = "in_stored_" + node.data + assign_tasklet_node_out_connector = "out_stored_" + node.data + + # Create the memlet for the assignment + # This will be the same as the memlet going to the tasklet + assign_memlet_data = copy.deepcopy(last_edge.data) + param_dict = {} + memlet_access_iterators = [] + + # We check the incoming memlet volume + if assign_memlet_data.volume != 1: + # We need to add a map to iterate through the missing dimensions + # For this we will create an assign block containing a map + + # First, Get the missing dimensions + # Iterate through the subset + for element in last_edge.data.subset: + if str(element[0]) in last_edge.data.free_symbols: + # This is a symbol we will keep in the store memlet + memlet_access_iterators.append(str(element[0])) + else: + # This is a range tuple we need to add an iterator for + # Create a random new free symbol + free_symbol = forward_state.sdfg.find_new_symbol("si") + + # Add the new symbol here so that find_new_symbol doesn't return it again + forward_state.sdfg.add_symbol(free_symbol, dtypes.int64) + memlet_access_iterators.append(free_symbol) + param_dict.update({free_symbol: element}) + + # Build the memlets for input and output + in_state_access = ','.join(memlet_access_iterators) + input_memlet = Memlet(expr=f"{last_edge.data.data}[{in_state_access}]") + if loop_iterators: + output_memlet = Memlet(expr=f"{stored_node.data}[{loop_iterators},{in_state_access}]") + else: + output_memlet = Memlet(expr=f"{stored_node.data}[{in_state_access}]") + + assign_tasklet_node, map_entry, map_exit = forward_state.add_mapped_tasklet( + name=f"__store_{node.data}_assign_", + map_ranges=param_dict, + inputs={assign_tasklet_node_in_connector: input_memlet}, + code=f"{assign_tasklet_node_out_connector} = {assign_tasklet_node_in_connector}", + outputs={assign_tasklet_node_out_connector: output_memlet}, + schedule=dtypes.ScheduleType.GPU_Device if cuda else dtypes.ScheduleType.Default, + external_edges=False) + + # Add the necessary connectors for external connections + map_entry.add_in_connector("IN_store_block") + map_exit.add_out_connector("OUT_store_block") + + # Update the internal edges to route through the new connectors + # Find and update the edge from map_entry to tasklet + for e in list(forward_state.out_edges(map_entry)): + if e.dst == assign_tasklet_node: + # Update the source connector to route through our external connector + forward_state.remove_edge(e) + forward_state.add_edge(map_entry, "OUT_store_block", assign_tasklet_node, + assign_tasklet_node_in_connector, e.data) + map_entry.add_out_connector("OUT_store_block") + break + + # Find and update the edge from tasklet to map_exit + for e in list(forward_state.in_edges(map_exit)): + if e.src == assign_tasklet_node: + # Update the destination connector to route through our external connector + forward_state.remove_edge(e) + forward_state.add_edge(assign_tasklet_node, assign_tasklet_node_out_connector, map_exit, + "IN_store_block", e.data) + map_exit.add_in_connector("IN_store_block") + break + + # Make sure this block is connected correctly + assign_block = map_entry + assign_block_in_connector = "IN_store_block" + return_node = map_exit + return_connector = "OUT_store_block" + else: + # Volume is 1, create a simple tasklet without a map + assign_tasklet_node = nodes.Tasklet( + label=f"__store_{node.data}_assign_", + inputs={assign_tasklet_node_in_connector}, + outputs={assign_tasklet_node_out_connector}, + code=f"{assign_tasklet_node_out_connector} = {assign_tasklet_node_in_connector}", + ) + + # Add it to the state + forward_state.add_node(assign_tasklet_node) + + assign_block = assign_tasklet_node + assign_block_in_connector = assign_tasklet_node_in_connector + return_node = assign_tasklet_node + return_connector = assign_tasklet_node_out_connector + + # Get the last map + last_map = last_edge.src + last_map_connector = last_edge.src_conn + + # Add the new edge from the last map entrance to the new assign block + forward_state.add_edge(last_map, last_map_connector, assign_block, assign_block_in_connector, assign_memlet_data) + return return_node, return_connector + + +def _find_map_exist_for_map_entry(map_entry: nodes.MapEntry, state: SDFGState) -> nodes.MapExit: + """ + Find the map exist that corresponds to the input map entry + """ + src_candidates = [node for node in state.nodes() if isinstance(node, nodes.MapExit) and node.map == map_entry.map] + if len(src_candidates) != 1: + # this shouldn't happen; if we are within a scope, the exit nodes + # for the scope should already exist in the backward pass + raise AutoDiffException("Invalid graph") + + return src_candidates[0] + + +def _get_symbol_upper_bound_from_loop(bwd_generator: 'DataForwardingbwd_generator', s: sp.Symbol, + loops: List[LoopRegion]) -> int: + """ + Given a symbol and a list of loops, get the upper bound of the symbol from the loops. + Raises an error if the symbol is not a loop index or the upper bound cannot be extracted correctly. + """ + # Get the symbol to match + if isinstance(s, (sp.Symbol, sp.Expr)): + # We don't want to match global SDFG symbols + loop_indices = {symb for symb in s.free_symbols if str(symb) not in bwd_generator.sdfg.free_symbols} + if len(loop_indices) != 1: + raise AutoDiffException(f"Symbol dimension {s} couldn't be parsed correctly during storing") + loop_index = str(list(loop_indices)[0]) + elif isinstance(s, str): + # Convert the string to a symbolic expression and extract free symbols + try: + expr = sp.sympify(s) + except (sp.SympifyError, TypeError, ValueError) as e: + raise AutoDiffException(f"Symbol dimension {s} couldn't be parsed as a symbolic expression: {e}") + + # We don't want to match global SDFG symbols + loop_indices = {symb for symb in expr.free_symbols if str(symb) not in bwd_generator.sdfg.free_symbols} + if len(loop_indices) != 1: + raise AutoDiffException(f"Symbol dimension {s} couldn't be parsed correctly during storing") + loop_index = str(list(loop_indices)[0]) + else: + raise AutoDiffException(f"Symbol dimension {s} is not a string and not a sympy symbol") + + # If the loop bound can be directly extracted from the interstate edges + if loop_index in bwd_generator.interstate_symbols: + loop_size = bwd_generator.interstate_symbols[loop_index] + else: + # Get the loop range for this symbol + loop_size = None + for l in loops: + # Convert the sympy symbol to string to check if it macthes the loop variable + if loop_index in l.loop_variable: + # Get the max loop range + start, end = ad_utils.extract_loop_region_info(l) + + # Check if the loop variable has a negative coefficient + # by extracting the coefficient from the affine expression + s_expr = sp.sympify(s) if isinstance(s, str) else s + # Find the actual symbol in the expression that matches loop_index by name + loop_symbol = None + for sym in s_expr.free_symbols: + if str(sym) == loop_index: + loop_symbol = sym + break + + # Extract the coefficient of the loop variable + if loop_symbol is not None: + coeff = s_expr.coeff(loop_symbol) + # If coefficient is negative we need to use smallest instead of largest + matched = coeff is not None and (coeff < 0) == True + else: + # Loop variable not found in expression + matched = False + smallest, largest = ad_utils.get_loop_end(start, end, l) + if not matched: + loop_size = largest + else: + loop_size = smallest + + if loop_size is None: + raise AutoDiffException( + f"Can't figure out how to save the data inside: {l.label} because of its symbol shape {s}") + + # We will call this function recusrively until loop size is numeric or it is a global SDFG symbol + if ad_utils.shape_has_symbols_to_replace(bwd_generator.sdfg, loop_size): + loop_size, _ = _get_symbol_upper_bound_from_loop(bwd_generator, loop_size, loops) + return loop_size, loop_index + + +def _get_all_enclosing_loops(forward_state: SDFGState) -> List[LoopRegion]: + """ + Check if this state will be executed several times within a loop. + We check if any of the parents of this state is a loop region. + """ + all_loops = [] + parent = forward_state.parent_graph + while parent is not None: + if isinstance(parent, LoopRegion): + all_loops.append(parent) + parent = parent.parent_graph + return all_loops diff --git a/dace/autodiff/implementations/__init__.py b/dace/autodiff/implementations/__init__.py new file mode 100644 index 0000000000..21feb8e552 --- /dev/null +++ b/dace/autodiff/implementations/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Backward Pass Implementations for SDFG Elements. + +This package provides backward (gradient) implementations for various SDFG node types. +Each implementation defines how to compute gradients for specific operations. + +Implementation Categories +------------------------- +1. **DaCe Nodes** (dace_nodes.py): + - Core SDFG elements: Tasklet, MapEntry, AccessNode, etc. + - Fundamental building blocks for all DaCe programs + - Registered in DaceNodeBackwardImplementations + +2. **DaCe Reduction Nodes** (dace_reduction_nodes.py): + - Reduction operations: Sum, Max, Min + - Registered using @autoregister decorator + +3. **ONNX Operations** (onnx_ops.py): + - ONNX-specific operations from dace.libraries.onnx + - Neural network layers and operators + - Supports ONNX model differentiation + +4. **PyTorch Operations** (pytorch_ops.py): + - Operations using PyTorch CUDA kernels + - Depthwise convolution backward pass +""" + +import dace.autodiff.implementations.dace_reduction_nodes +from dace.autodiff.implementations.dace_nodes import DaceNodeBackwardImplementations + +# ONNX ops are optional +try: + import dace.autodiff.implementations.onnx_ops +except ImportError: + pass + +# PyTorch ops are optional +try: + import dace.autodiff.implementations.pytorch_ops +except ImportError: + pass + +__all__ = [ + "DaceNodeBackwardImplementations", +] diff --git a/dace/autodiff/implementations/dace_nodes.py b/dace/autodiff/implementations/dace_nodes.py new file mode 100644 index 0000000000..2439e08137 --- /dev/null +++ b/dace/autodiff/implementations/dace_nodes.py @@ -0,0 +1,487 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" + Class for defining the reversal of pure SDFG nodes: AccessNode, Tasklet, MapEntry/Exit, NestedSDFG + Each method should return a tuple (reversed_node, BackwardResult) +""" +import ast +import collections +import copy +import numbers +import astunparse +import sympy as sp +from typing import List, Tuple + +# DaCe imports +import dace +import dace.sdfg.nodes as nodes +from dace import dtypes +from dace.data import Reference, Structure +from dace.sdfg import SDFGState +from dace.data import find_new_name + +# Autodiff imports +from dace.autodiff.base_abc import BackwardResult, AutoDiffException +import dace.autodiff.utils as ad_utils + + +class DaceNodeBackwardImplementations: + + def __init__(self, backward_pass_generator: 'BackwardPassGenerator'): + self.bwd_engine = backward_pass_generator + pass + + def _reverse_NestedSDFG( + self, + forward_state: SDFGState, + backward_state: SDFGState, + node: nodes.NestedSDFG, + given_gradients: List[str], + required_gradients: List[str], + ) -> Tuple[nodes.Node, BackwardResult]: + reverse_nsdfg = dace.SDFG(node.sdfg.name + "_backward") + + gen = self.bwd_engine.create_child_generator( + sdfg=node.sdfg, + given_gradients=given_gradients, + required_gradients=required_gradients, + backward_sdfg=reverse_nsdfg, + ) + backward_result, _, backward_input_arrays = gen.backward() + + # we need to defer add edges until after the arrays have been added because creation of the nested + # sdfg fails otherwise + deferred_edges = [] + + inputs = set(backward_result.given_grad_names[name] for name in sorted(given_gradients)) + # loop through the arrays that we need from the forward pass + for name, desc in sorted(backward_input_arrays.items()): + # if the name is not already passed to the reverse SDFG node ... + if name not in required_gradients and name not in node.in_connectors: + # ... this array needs to be forwarded out of the forward SDFG (i.e. it is an intermediate value) + # 1) add it to the current SDFG, and to self.bwd_engine.backward_input_arrays + # 2) add an out connector to the forward nested SDFG, add a write node to the current state, and an edge + # from the output to there + # 3) add a read node to the backward state, and an edge into it + + desc = node.sdfg.arrays[name] + + # if the original view node is in the in-connector, no need to connect it, continue + # if forwarded_name in node.in_connectors: + # continue + + # (1) + new_name = find_new_name(name + "_forwarded", self.bwd_engine.sdfg.arrays) + if new_name in self.bwd_engine.sdfg.arrays or new_name in self.bwd_engine.backward_input_arrays: + raise AutoDiffException( + "Attempted to create array with name '{}', but it already existed".format(new_name)) + + self.bwd_engine.sdfg.add_datadesc(new_name, copy.deepcopy(desc)) + self.bwd_engine.backward_input_arrays[new_name] = copy.deepcopy(desc) + + if self.bwd_engine.separate_sdfgs: + to_add = copy.deepcopy(desc) + to_add.transient = False + self.bwd_engine.backward_sdfg.add_datadesc(new_name, to_add) + + # (2) + node.sdfg.arrays[name].transient = False + added = node.add_out_connector(name, force=True) + assert added + write = forward_state.add_write(new_name) + forward_state.add_edge(node, name, write, None, self.bwd_engine.sdfg.make_array_memlet(new_name)) + + # (3) + read = backward_state.add_read(new_name) + deferred_edges.append( + dict(u=read, + u_connector=None, + v_connector=name, + memlet=self.bwd_engine.backward_sdfg.make_array_memlet(new_name))) + inputs.add(name) + else: + inputs.add(name) + + outputs = set(backward_result.required_grad_names[name] for name in required_gradients) + + for inp in inputs: + if inp in reverse_nsdfg.arrays: + reverse_nsdfg.arrays[inp].transient = False + for outp in outputs: + if outp in reverse_nsdfg.arrays: + reverse_nsdfg.arrays[outp].transient = False + # Create the sdfg and return it + nsdfg = backward_state.add_nested_sdfg( + reverse_nsdfg, + inputs=inputs, + outputs=outputs, + ) + + # If any input connectors point to symbols + for conn, _ in nsdfg.in_connectors.items(): + if conn in nsdfg.sdfg.symbols: + # We need to add a new symbol and create a mapping + new_symbol = find_new_name(conn, nsdfg.sdfg.symbols) + nsdfg.sdfg.add_symbol(new_symbol, nsdfg.sdfg.symbols[conn]) + nsdfg.sdfg.replace(conn, new_symbol) + nsdfg.symbol_mapping[new_symbol] = conn + # Remove it from the symbol mapping too + if conn in nsdfg.symbol_mapping: + del nsdfg.symbol_mapping[conn] + for edge_args in deferred_edges: + edge_args["v"] = nsdfg + backward_state.add_edge(**edge_args) + + return nsdfg, BackwardResult(required_grad_names=backward_result.required_grad_names, + given_grad_names=backward_result.given_grad_names) + + def _reverse_AccessNode( + self, + forward_state: SDFGState, + backward_state: SDFGState, + node: nodes.AccessNode, + given_gradients: List[str], + required_gradients: List[str], + ) -> Tuple[nodes.Node, BackwardResult]: + + desc = self.bwd_engine.sdfg.arrays[node.data] + if isinstance(desc, Reference): + raise AutoDiffException(f"AccessNode '{node.data}' points to a Reference, which is not yet supported") + if isinstance(desc, Structure): + raise AutoDiffException(f"AccessNode '{node.data}' points to a Structure, which is not yet supported") + + rev = nodes.AccessNode(self.bwd_engine.array_grad_name(node.data)) + # We want all gradient arrays to be initialized to zero + # This is important for correct gradient accumulation + rev.setzero = True + backward_state.add_node(rev) + required_grad_names = {None: None} + given_grad_names = {None: None} + + if "views" in node.in_connectors: + required_grad_names = {"views": "views"} + if "views" in node.out_connectors: + given_grad_names = {"views": "views"} + + return rev, BackwardResult(required_grad_names=required_grad_names, given_grad_names=given_grad_names) + + def _reverse_MapEntry( + self, + forward_state: SDFGState, + backward_state: SDFGState, + node: nodes.MapEntry, + given_gradients: List[str], + required_gradients: List[str], + ) -> Tuple[nodes.Node, BackwardResult]: + + required_grad_names = {n: ad_utils.invert_map_connector(n) for n in required_gradients} + given_grad_names = {n: ad_utils.invert_map_connector(n) for n in given_gradients} + result = BackwardResult(required_grad_names=required_grad_names, given_grad_names=given_grad_names) + rev = nodes.MapExit(self.bwd_engine.reverse_map[node.map]) + + for _, conn in sorted(given_grad_names.items()): + added = rev.add_in_connector(conn) + assert added + + for _, conn in sorted(required_grad_names.items()): + added = rev.add_out_connector(conn) + assert added + + backward_state.add_node(rev) + return rev, result + + def _reverse_MapExit( + self, + forward_state: SDFGState, + backward_state: SDFGState, + node: nodes.MapExit, + given_gradients: List[str], + required_gradients: List[str], + ): + self.bwd_engine.reverse_map[node.map] = copy.deepcopy(node.map) + + rev = nodes.MapEntry(self.bwd_engine.reverse_map[node.map]) + for conn in sorted(node.in_connectors): + added = rev.add_in_connector(conn) + assert added + + for conn in sorted(node.out_connectors): + added = rev.add_out_connector(conn) + assert added + + backward_state.add_node(rev) + # yapf: disable + return ( + rev, + BackwardResult(required_grad_names={ + n: ad_utils.invert_map_connector(n) + for n in required_gradients + }, + given_grad_names={ + n: ad_utils.invert_map_connector(n) + for n in given_gradients + }), + ) + # yapf: enable + + def _reverse_Tasklet( + self, + state: SDFGState, + backward_state: SDFGState, + tasklet: nodes.Tasklet, + given_gradients: List[str], + required_gradients: List[str], + ) -> Tuple[nodes.Node, BackwardResult]: + if tasklet.language is not dtypes.Language.Python: + raise AutoDiffException("Expected tasklet with language Python, got language {}".format(tasklet.language)) + + # tasklets should have scalar inputs (can be relaxed) + for _, _, _, _, memlet in state.in_edges(tasklet): + if memlet.data is not None: + try: + ad_utils.is_int_eq_value(memlet.subset.num_elements(), 1) + except AutoDiffException as e: + raise AutoDiffException( + "Autodiff only supported for tasklets with scalar inputs and outputs") from e + + for _, _, _, _, memlet in state.out_edges(tasklet): + if memlet.data is not None: + try: + ad_utils.is_int_eq_value(memlet.subset.num_elements(), 1) + except AutoDiffException as e: + raise AutoDiffException( + "Autodiff only supported for tasklets with scalar inputs and outputs") from e + + code_str = tasklet.code.as_string + + # check if this is a conditional tasklet + if self.bwd_engine._conditional_tasklet(tasklet): + # we want to extract the if and else expressions and pass them to sympy + if_expression, else_expression, conditional = ad_utils.extract_conditional_expressions(tasklet) + + if_code, if_rev_inputs, if_rev_outputs, if_result = self._differentiate_code_symbolically( + self.bwd_engine.sdfg, if_expression, state, tasklet, given_gradients, required_gradients) + + if else_expression: + else_code, else_rev_inputs, else_rev_outputs, else_result = self._differentiate_code_symbolically( + self.bwd_engine.sdfg, else_expression, state, tasklet, given_gradients, required_gradients) + assert else_rev_inputs == if_rev_inputs + assert if_rev_outputs == else_rev_outputs + assert else_result == if_result + + # prepare the tasklet code depending on the conditional type + # add the same conditional to the if_code + # first, add indentation + if_code = if_code.replace("\n", "\n\t") + if_code = f"if {conditional}:\n{if_code}" + + # add the conditional to the in connectors + if_rev_inputs.add(conditional) + joint_code = if_code + + if ":" not in code_str: + # only an if in the original code + assert else_expression + else_code = else_code.replace("\n", "\n\t") + else_code = f"else:\n{else_code}" + joint_code = f"{if_code}\n{else_code}" + + # in case there are no out_connectors, we will zero out the assigned-to AccessNode + if len(if_rev_outputs) == 0: + if_rev_outputs = {"__zero_out_conn__"} + + rev = nodes.Tasklet("_" + tasklet.label + "_reverse_", + inputs=if_rev_inputs, + outputs=if_rev_outputs, + code=joint_code, + debuginfo=tasklet.debuginfo) + + result = if_result + else: + code, rev_inputs, rev_outputs, result = self._differentiate_code_symbolically( + self.bwd_engine.sdfg, code_str, state, tasklet, given_gradients, required_gradients) + rev = nodes.Tasklet("_" + tasklet.label + "_reverse_", + inputs=rev_inputs, + outputs=rev_outputs, + code=code, + debuginfo=tasklet.debuginfo) + backward_state.add_node(rev) + return rev, result + + def _differentiate_code_symbolically( + self, + sdfg: dace.SDFG, + code_str: str, + forward_state: SDFGState, + tasklet: nodes.Tasklet, + given_gradients: List[str], + required_gradients: List[str], + ): + """Performs symbolic differentiation on tasklet code to generate the backward-pass tasklet. + + This method uses SymPy to symbolically differentiate expressions in a tasklet's code, + applying the chain rule to compute gradients with respect to input variables. + + :param sdfg: The parent SDFG containing the tasklet. + :param code_str: Code string from the tasklet to differentiate. + :param forward_state: The SDFGState containing the forward tasklet. + :param tasklet: The forward tasklet node being differentiated. + :param given_gradients: List of output connector names for which gradients are provided (∂L/∂output). + :param required_gradients: List of input connector names for which gradients must be computed (∂L/∂input). + :return: A 4-tuple containing (code, rev_inputs, rev_outputs, result) where code is the generated + Python code for the backward tasklet, rev_inputs is the set of input connector names, + rev_outputs is the set of output connector names, and result is the BackwardResult mapping. + :raises AutoDiffException: If symbolic differentiation fails (e.g., non-differentiable operations, + unexpected graph structure, missing input edges). + + .. note:: + - Uses SymPy's symbolic differentiation and common subexpression elimination (CSE) + - Supports indexed array accesses (e.g., A[i, j]) via IndexedBase + - Handles constant assignments by zeroing gradients + - Gradient names are generated with "_gradient" suffix to avoid conflicts + - SDFG-level symbols are excluded from backward tasklet inputs + - Type casting ensures gradient types match forward pass data types + - Applies chain rule: ∂L/∂input = ∂L/∂output * (∂output/∂input) + + Example:: + + Forward tasklet: ``y = x * x + 2 * x`` + Given gradient: dy (∂L/∂y) + Required gradient: dx (∂L/∂x) + Generated code: ``dx_gradient = dy_gradient * (2*x + 2)`` + """ + output_exprs, indexed_objects_map = ad_utils.code_to_exprs(code_str, tasklet, list(sdfg.symbols.keys())) + + # for each output that an input is used in, there will be an entry for the expression of the + # grad in this list in the final code snippet. When we generate the final code for the + # reverse tasklet, we need to add them all up. + rev_code = collections.defaultdict(list) + + # the outputs of the reversed nodes are the grads of inputs of the original node + rev_outputs = set() + rev_inputs = set() + + result = BackwardResult(required_grad_names={}, given_grad_names={}) + + # symbol generator to use for CSE + symbol_generator = sp.numbered_symbols() + + code = "" + + for output_conn in sorted(given_gradients): + + # special case for conditional tasklets with constant assignment + if len(required_gradients) == 0: + # for this we need to assing a zero to the gradient output + # pick a name for the input gradient + rev_input_grad_name = find_new_name(output_conn + "_gradient", rev_inputs) + result.given_grad_names[output_conn] = rev_input_grad_name + + # zero out the gradient + code = f"\n__zero_out_conn__ = 0.0" + rev_outputs = {} + rev_inputs = {rev_input_grad_name} + + # for each output_conn... + for inp in sorted(required_gradients): + # ...add the code to generate {inp}_grad + + if inp not in result.required_grad_names: + # pick a name for the gradient + rev_output_grad_name = find_new_name(inp + "_gradient", rev_outputs) + result.required_grad_names[inp] = rev_output_grad_name + rev_outputs.add(rev_output_grad_name) + else: + rev_output_grad_name = result.required_grad_names[inp] + + output_expr = output_exprs[output_conn] + # if the expression is a constant assignment, we need to cast the float to the sympy equivalent + if isinstance(output_expr, numbers.Real): + output_expr = sp.Float(output_expr) + + # We need to prepare the w.r.t expression + if inp in indexed_objects_map: + # if the input is an indexed object, we need to create the sympy expression + indexed_base = sp.IndexedBase(inp) + idx_objects = [sp.Idx(index) for index in indexed_objects_map[inp]] + inp_expr = indexed_base[tuple(idx_objects)] + else: + # if the input is not an indexed object, we can just use it as is + inp_expr = sp.symbols(inp) + + # symbolically differentiate the output w.r.t inp + diff_expr = output_expr.diff(inp_expr) + + # do common subexpression elimination + sub_expressions, diff_expr = sp.cse(diff_expr, symbols=symbol_generator) + + diff_expr = diff_expr[0] + + if diff_expr.atoms(sp.Derivative): + # the final result contains a call to sp.Derivative + raise AutoDiffException("Unable to symbolically differentiate expression: {}".format( + diff_expr.expr)) + + if output_conn not in result.given_grad_names: + # pick a name for the input gradient + rev_input_grad_name = find_new_name(output_conn + "_gradient", rev_inputs) + result.given_grad_names[output_conn] = rev_input_grad_name + else: + rev_input_grad_name = result.given_grad_names[output_conn] + + input_symbols = diff_expr.free_symbols\ + .union(s for _, e in sub_expressions for s in e.free_symbols)\ + .difference(e for e, _ in sub_expressions) + + string_symbols = {str(symb) for symb in input_symbols} + + # If there are any symbols that are defined at the global SDFG scope + # We do not need to add these as inputs to the reverse tasklet + string_symbols = string_symbols.difference(set(sdfg.symbols.keys())) + rev_inputs |= string_symbols | {rev_input_grad_name} + + diff_code_str = "{input} * ({diff_expr})".format(input=rev_input_grad_name, diff_expr=str(diff_expr)) + # small hack: our heaviside is lowercase + diff_code_str = diff_code_str.replace("Heaviside", "heaviside") + + diff_code_str = astunparse.unparse(ad_utils.SympyCleaner().visit(ast.parse(diff_code_str))) + + sub_expression_code_strs = "\n".join(f"{target} = {expression}" + for target, expression in sub_expressions) + + # get the the final type of the gradient: this is just the type of the input connector we creating the + # gradient for + cands = list(forward_state.in_edges_by_connector(tasklet, inp)) + if len(cands) != 1: + raise AutoDiffException(f"Unexpected graph structure, could not find input edge for connector {inp}" + f" on tasklet {tasklet}") + + converted_code = ad_utils.cast_consts_to_type(diff_code_str, sdfg.arrays[cands[0].data.data].dtype) + converted_code = converted_code.replace("\n", " ") + + converted_sub_expressions = ad_utils.cast_consts_to_type(sub_expression_code_strs, + sdfg.arrays[cands[0].data.data].dtype) + + # If there is indirection in the input + if inp in indexed_objects_map: + # We need to have indirection of the output container in the backward + output_code = rev_output_grad_name + "[" + " , ".join(indexed_objects_map[inp]) + "]" + + # We also need to add the indices as connectors so that they are forwarded from the forward pass + for idx in indexed_objects_map[inp]: + if idx not in rev_inputs: + # This needs to be available in the forward pass in the first place + if idx not in tasklet.in_connectors: + raise AutoDiffException( + f"Expected index {idx} to be an input connector of the tasklet {tasklet}, " + f"but it is not. This is required for the backward pass to work correctly.") + rev_inputs.add(idx) + else: + output_code = rev_output_grad_name + + code += converted_sub_expressions + "\n" + rev_code[output_code].append(converted_code) + + for output, exprs in sorted(rev_code.items()): + code += "\n" + output + " = " + " + ".join(exprs) + + return code, rev_inputs, rev_outputs, result diff --git a/dace/autodiff/implementations/dace_reduction_nodes.py b/dace/autodiff/implementations/dace_reduction_nodes.py new file mode 100644 index 0000000000..05232acbbf --- /dev/null +++ b/dace/autodiff/implementations/dace_reduction_nodes.py @@ -0,0 +1,307 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +DaCe Library Node Backward Pass Implementations for Automatic Differentiation. + +This module provides backward pass implementations for DaCe standard library nodes +in the automatic differentiation system. Each class implements the BackwardImplementation +interface to compute gradients for specific library operations during reverse-mode +automatic differentiation. + +""" + +import copy +import typing + +# DaCe core imports +import dace +import dace.dtypes as dtypes +import dace.libraries.standard.nodes +from dace import SDFGState, SDFG, Memlet +from dace.sdfg.nodes import Node + +# DaCe frontend imports +from dace.frontend.operations import detect_reduction_type +from dace.registry import autoregister_params + +# Autodiff imports +from dace.autodiff.base_abc import BackwardImplementation, BackwardContext, BackwardResult, AutoDiffException + +# Utility imports +from dace.sdfg.utils import in_desc_with_name, out_desc_with_name + + +@autoregister_params(node_type=dace.libraries.standard.nodes.Reduce, name="pure") +class ReverseReduce(BackwardImplementation): + """Backward implementation for DaCe Reduce library nodes. + + Supports Sum, Max, and Min reduction operations. The backward pass distributes + gradients appropriately based on the reduction type: + - Sum: Broadcasts gradients uniformly across reduced dimensions + - Max/Min: Routes gradients only to positions that achieved the extremal value + """ + + @staticmethod + def backward_can_be_applied(node: Node, state: SDFGState, sdfg: SDFG) -> bool: + """Check if backward pass can be applied to this reduction node. + + :param node: The reduction node to check. + :param state: The SDFG state containing the node (unused but required by interface). + :param sdfg: The SDFG containing the state (unused but required by interface). + :return: True if backward pass can be applied, False otherwise. + """ + reduction_type = detect_reduction_type(node.wcr) + if reduction_type not in (dtypes.ReductionType.Sum, dtypes.ReductionType.Max, dtypes.ReductionType.Min): + return False + + return True + + @staticmethod + def backward(forward_node: Node, context: BackwardContext, given_gradients: typing.List[typing.Optional[str]], + required_gradients: typing.List[typing.Optional[str]]) -> typing.Tuple[Node, BackwardResult]: + """Generate the backward pass for a reduction node. + + :param forward_node: The forward reduction node. + :param context: The backward pass context. + :param given_gradients: List of gradient names provided to this node. + :param required_gradients: List of gradient names required by this node. + :return: Tuple of the backward node and the backward result. + :raises AutoDiffException: If the node has invalid number of edges. + """ + reduction_type = detect_reduction_type(forward_node.wcr) + + if len(given_gradients) != 1: + raise AutoDiffException(f"Invalid SDFG: reduce node {forward_node} should have exactly one output edge, " + f"got {len(given_gradients)} output gradients") + + if len(required_gradients) != 1: + raise AutoDiffException(f"Invalid SDFG: reduce node {forward_node} should have exactly one input edge, " + f"got {len(required_gradients)} input gradients") + + input_name = next(iter(required_gradients)) + in_desc = in_desc_with_name(forward_node, context.forward_state, context.forward_sdfg, input_name) + + output_name = next(iter(given_gradients)) + out_desc = out_desc_with_name(forward_node, context.forward_state, context.forward_sdfg, output_name) + + all_axes: typing.List[int] = list(range(len(in_desc.shape))) + reduce_axes: typing.List[int] = all_axes if forward_node.axes is None else forward_node.axes + non_reduce_axes: typing.List[int] = [i for i in all_axes if i not in reduce_axes] + + result = BackwardResult.empty() + + return ReverseReduce._backward_reduction(forward_node, context, result, reduction_type, input_name, output_name, + in_desc, out_desc, all_axes, non_reduce_axes) + + @staticmethod + def _backward_reduction(forward_node: Node, context: BackwardContext, result: BackwardResult, + reduction_type: dtypes.ReductionType, input_name: str, output_name: str, in_desc, out_desc, + all_axes: typing.List[int], + non_reduce_axes: typing.List[int]) -> typing.Tuple[Node, BackwardResult]: + """Backward pass for Sum/Max/Min reductions. + + - Sum: Broadcasts gradients uniformly across reduced dimensions + - Max/Min: Routes gradients to positions that achieved the extremal value, + split equally among tied values + + :param forward_node: The forward reduction node. + :param context: The backward pass context. + :param result: The backward result to populate. + :param reduction_type: The type of reduction (Sum, Max, or Min). + :param input_name: Name of the input connector. + :param output_name: Name of the output connector. + :param in_desc: Input data descriptor. + :param out_desc: Output data descriptor. + :param all_axes: List of all axes indices. + :param non_reduce_axes: List of axes not being reduced. + :return: Tuple of the nested SDFG node and the backward result. + """ + is_extremal = reduction_type in (dtypes.ReductionType.Max, dtypes.ReductionType.Min) + type_name = { + dtypes.ReductionType.Sum: "sum", + dtypes.ReductionType.Max: "max", + dtypes.ReductionType.Min: "min" + }[reduction_type] + + sdfg = SDFG("_reverse_" + str(reduction_type).replace(".", "_") + "_") + + rev_input_conn_name = "input_gradient" + rev_output_conn_name = "output_gradient" + + result.required_grad_names[output_name] = rev_output_conn_name + result.given_grad_names[input_name] = rev_input_conn_name + + sdfg.add_array(rev_input_conn_name, shape=out_desc.shape, dtype=out_desc.dtype, strides=out_desc.strides) + sdfg.add_array(rev_output_conn_name, shape=in_desc.shape, dtype=in_desc.dtype, strides=in_desc.strides) + + nsdfg_inputs = {rev_input_conn_name} + + if is_extremal: + extremal_conn_name = f"input_{type_name}" + extremal_idx_conn_name = f"input_{type_name}_idx" + sdfg.add_array(extremal_conn_name, shape=out_desc.shape, dtype=out_desc.dtype, strides=out_desc.strides) + sdfg.add_array(extremal_idx_conn_name, shape=in_desc.shape, dtype=in_desc.dtype, strides=in_desc.strides) + nsdfg_inputs.update({extremal_conn_name, extremal_idx_conn_name}) + + # Add transient array to count matching elements per output position + count_arr_name = f"_{type_name}_count" + sdfg.add_array(count_arr_name, shape=out_desc.shape, dtype=out_desc.dtype, transient=True) + + reduce_all_axes = forward_node.axes is None or set(range(len(in_desc.shape))) == set(forward_node.axes) + + if is_extremal: + # Two-state approach for max/min: + # State 1: Count elements matching extremal value + # State 2: Compute normalized gradient + + count_state = sdfg.add_state(f"count_{type_name}_{id(forward_node)}") + grad_state = sdfg.add_state(f"grad_{type_name}_{id(forward_node)}") + sdfg.add_edge(count_state, grad_state, dace.InterstateEdge()) + + # State 1: Count matching elements + count_memlet = Memlet.simple(count_arr_name, + "0" if reduce_all_axes else ",".join("i" + str(i) for i in non_reduce_axes), + wcr_str="lambda x, y: x + y") + extremal_val_memlet_count = Memlet.simple( + extremal_conn_name, "0" if reduce_all_axes else ",".join("i" + str(i) for i in non_reduce_axes)) + extremal_idx_memlet_count = Memlet.simple(extremal_idx_conn_name, ",".join("i" + str(i) for i in all_axes)) + + _, _, count_exit = count_state.add_mapped_tasklet( + f"_count_{type_name}_matches_", { + "i" + str(i): "0:{}".format(shape) + for i, shape in enumerate(in_desc.shape) + }, { + "__extremal_val": extremal_val_memlet_count, + "__extremal_val_idx": extremal_idx_memlet_count + }, + "__count = 1.0 if __extremal_val == __extremal_val_idx else 0.0", {"__count": count_memlet}, + external_edges=True) + + # Set count array to zero before accumulation + count_out_edges = count_state.out_edges(count_exit) + if len(count_out_edges) == 1: + count_out_node = count_out_edges[0].dst + if isinstance(count_out_node, dace.nodes.AccessNode): + count_out_node.setzero = True + + # State 2: Compute normalized gradient (grad / count) + reduction_memlet = Memlet.simple( + rev_input_conn_name, "0" if reduce_all_axes else ",".join("i" + str(i) for i in non_reduce_axes)) + reverse_reduction_memlet = Memlet.simple(rev_output_conn_name, + ",".join("i" + str(i) for i in all_axes), + wcr_str="lambda x, y: x + y") + extremal_val_memlet = Memlet.simple( + extremal_conn_name, "0" if reduce_all_axes else ",".join("i" + str(i) for i in non_reduce_axes)) + extremal_idx_memlet = Memlet.simple(extremal_idx_conn_name, ",".join("i" + str(i) for i in all_axes)) + count_read_memlet = Memlet.simple( + count_arr_name, "0" if reduce_all_axes else ",".join("i" + str(i) for i in non_reduce_axes)) + + tasklet_inputs = { + "__in": reduction_memlet, + "__extremal_val": extremal_val_memlet, + "__extremal_val_idx": extremal_idx_memlet, + "__count": count_read_memlet + } + tasklet_code = "__out = __in / __count if __extremal_val == __extremal_val_idx else 0" + + _, _, exit_map = grad_state.add_mapped_tasklet(f"_{type_name}_grad_" + + str(reduction_type).replace(".", "_") + "_", { + "i" + str(i): "0:{}".format(shape) + for i, shape in enumerate(in_desc.shape) + }, + tasklet_inputs, + tasklet_code, {"__out": reverse_reduction_memlet}, + external_edges=True) + + state = grad_state + else: + # Sum reduction: simple broadcast + state = sdfg.add_state(f"block_{id(forward_node)}") + reduction_memlet = Memlet.simple( + rev_input_conn_name, "0" if reduce_all_axes else ",".join("i" + str(i) for i in non_reduce_axes)) + reverse_reduction_memlet = Memlet.simple(rev_output_conn_name, + ",".join("i" + str(i) for i in all_axes), + wcr_str="lambda x, y: x + y") + tasklet_inputs = {"__in": reduction_memlet} + tasklet_code = "__out = __in" + + _, _, exit_map = state.add_mapped_tasklet(f"_{type_name}_grad_" + str(reduction_type).replace(".", "_") + + "_", { + "i" + str(i): "0:{}".format(shape) + for i, shape in enumerate(in_desc.shape) + }, + tasklet_inputs, + tasklet_code, {"__out": reverse_reduction_memlet}, + external_edges=True) + + nsdfg = context.backward_state.add_nested_sdfg(sdfg, nsdfg_inputs, {rev_output_conn_name}) + + out_edges = state.out_edges(exit_map) + if len(out_edges) != 1: + raise AutoDiffException(f"Expected exactly one output edge from map exit, got {len(out_edges)}") + out_edge = out_edges[0] + out_node = out_edge.dst + if not isinstance(out_node, dace.nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode as output, got {type(out_node)}") + out_node.setzero = True + + if not is_extremal: + return nsdfg, result + + backward_state = context.backward_state + fwd_in_edges = context.forward_state.in_edges(forward_node) + if len(fwd_in_edges) != 1: + raise AutoDiffException(f"Expected exactly one input edge to forward node, got {len(fwd_in_edges)}") + fwd_in_edge = fwd_in_edges[0] + fwd_in_node = fwd_in_edge.src + if not isinstance(fwd_in_node, dace.nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode as input source, got {type(fwd_in_node)}") + + # Register forward input array for data forwarding (in case it's overwritten) + if fwd_in_node.data not in context.backward_generator.backward_input_arrays: + data_desc = copy.deepcopy(context.forward_sdfg.arrays[fwd_in_node.data]) + context.backward_generator.backward_input_arrays[fwd_in_node.data] = data_desc + + bwd_read = backward_state.add_read(fwd_in_node.data) + backward_state.add_edge(bwd_read, None, nsdfg, extremal_idx_conn_name, copy.deepcopy(fwd_in_edge.data)) + + if isinstance(context.forward_sdfg.arrays[fwd_in_node.data], (dace.data.View, dace.data.ArrayView)): + in_edge = context.forward_state.in_edges(fwd_in_node) + if len(in_edge) != 1: + raise AutoDiffException(f"Expected exactly one input edge to view node, got {len(in_edge)}") + in_edge = in_edge[0] + in_node = in_edge.src + if isinstance(in_node, dace.nodes.AccessNode): + if isinstance(context.forward_sdfg.arrays[in_node.data], (dace.data.View, dace.data.ArrayView)): + raise AutoDiffException(f"Nested views are not supported: {in_node.data}") + bwd_in_read = backward_state.add_read(in_node.data) + backward_state.add_edge(bwd_in_read, None, bwd_read, "views", copy.deepcopy(in_edge.data)) + + fwd_out_edges = context.forward_state.out_edges(forward_node) + if len(fwd_out_edges) != 1: + raise AutoDiffException(f"Expected exactly one output edge from forward node, got {len(fwd_out_edges)}") + fwd_out_edge = fwd_out_edges[0] + fwd_out_node = fwd_out_edge.dst + if not isinstance(fwd_out_node, dace.nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode as output destination, got {type(fwd_out_node)}") + + # Register forward output array for data forwarding (in case it's overwritten) + if fwd_out_node.data not in context.backward_generator.backward_input_arrays: + data_desc = copy.deepcopy(context.forward_sdfg.arrays[fwd_out_node.data]) + context.backward_generator.backward_input_arrays[fwd_out_node.data] = data_desc + + bwd_out_read = backward_state.add_read(fwd_out_node.data) + backward_state.add_edge(bwd_out_read, None, nsdfg, extremal_conn_name, copy.deepcopy(fwd_out_edge.data)) + + if isinstance(context.forward_sdfg.arrays[fwd_out_node.data], (dace.data.View, dace.data.ArrayView)): + out_edge = context.forward_state.out_edges(fwd_out_node) + if len(out_edge) != 1: + raise AutoDiffException(f"Expected exactly one output edge from view node, got {len(out_edge)}") + out_edge = out_edge[0] + out_node = out_edge.dst + if isinstance(out_node, dace.nodes.AccessNode): + if isinstance(context.forward_sdfg.arrays[out_node.data], (dace.data.View, dace.data.ArrayView)): + raise AutoDiffException(f"Nested views are not supported: {out_node.data}") + bwd_in_read = backward_state.add_read(out_node.data) + backward_state.add_edge(bwd_in_read, None, bwd_out_read, "views", copy.deepcopy(out_edge.data)) + + return nsdfg, result diff --git a/dace/autodiff/implementations/onnx_ops.py b/dace/autodiff/implementations/onnx_ops.py new file mode 100644 index 0000000000..e04f86c728 --- /dev/null +++ b/dace/autodiff/implementations/onnx_ops.py @@ -0,0 +1,1045 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +ONNX Backward Pass Implementations for Automatic Differentiation. + +This module provides backward pass implementations for ONNX operations in the DaCe autodiff +system. Each class implements the BackwardImplementation interface to compute gradients +for specific ONNX operations during reverse-mode automatic differentiation. + +The implementations handle various ONNX operations including: +- Mathematical operations (Einsum, Clip, Softmax, etc.) +- Neural network layers (Conv, LayerNormalization, etc.) +- Pooling operations (MaxPool, GlobalAveragePool) +- Utility operations (Transpose, Where, etc.) +""" + +import copy +import itertools +from typing import List, Optional, Tuple, Dict, Union + +import numpy as np + +# DaCe core imports +import dace +from dace.frontend.common import einsum +import dace.libraries +from dace.registry import autoregister_params +from dace import nodes as nd + +# ONNX-specific imports +import dace.libraries.onnx as donnx +from dace.libraries.onnx.converters import clean_onnx_name +from dace.libraries.onnx.op_implementations.linalg_ops import PureEinsum +from dace.transformation.onnx.replacement import onnx_constant_or_none + +# Autodiff imports +import dace.autodiff.utils as butils +from dace.autodiff.base_abc import BackwardImplementation, BackwardContext, BackwardResult + +# Utility imports +from dace.sdfg.utils import in_desc_with_name + + +def reverse_einsum_wrt_input(forward_node: 'donnx.nodes.onnx_op.ONNXOp', input_name: str) -> Tuple[List[str], str]: + """Produce the einsum string that computes the gradient of forward_node w.r.t. input_name. + + .. note:: + There is an edge case we currently don't handle (can be implemented though). + Something like 'ii->i' would become 'i->ii'. This is invalid because 'i' is repeated in the output. + + :param forward_node: The einsum node to reverse. + :param input_name: The connector on the forward node to produce the gradient computation for. + :return: Tuple of (list of forward node connectors required as inputs, einsum string). + The first parameter of the produced einsum string is implicitly the grad of Output. + """ + + _, input_idx = donnx.parse_variadic_param(input_name) + parser = einsum.EinsumParser(forward_node.equation) + + backward_input_expressions = [parser.output] + parser.inputs[:input_idx] + parser.inputs[input_idx + 1:] + backward_input_arrays = [ + f"Inputs__{i}" for i in itertools.chain(range(input_idx), range(input_idx + 1, len(parser.inputs))) + ] + + einsum_str = f"{','.join(backward_input_expressions)}->{parser.inputs[input_idx]}" + return backward_input_arrays, einsum_str + + +@autoregister_params(op="Einsum", name="default") +class DefaultEinsumBackward(BackwardImplementation): + """Backward implementation for ONNX Einsum operation. + + The symbolic autodiff can automatically derive matmuls, but the produced maps are more difficult to optimize. + This implementation provides a more efficient ONNX-based backward pass. + """ + + @staticmethod + def backward_can_be_applied(node: nd.Node, state: dace.SDFGState, sdfg: dace.SDFG) -> bool: + return PureEinsum.forward_can_be_applied(node, state, sdfg) + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + + nsdfg = dace.SDFG(forward_node.label + "_backward") + nstate = nsdfg.add_state() + + # setup arrays + output_desc = butils.forward_out_desc_with_name(forward_node, context, "Output") + result = BackwardResult.empty() + result.given_grad_names["Output"] = butils.add_backward_desc(nsdfg, context.forward_sdfg, output_desc, "Output") + access_output_grad = nstate.add_read(result.given_grad_names["Output"]) + + def create_access_node(connector: str) -> nd.AccessNode: + nsdfg.add_datadesc(connector, + copy.deepcopy(butils.forward_in_desc_with_name(forward_node, context, connector))) + return nstate.add_read(connector) + + # the forward inputs we will require + # maps the connector name to the accessnode + required_forward_inputs: Dict[str, nd.AccessNode] = {} + + for input_name in sorted(required_gradients): + # we add an einsum for each required gradient + forward_inputs, einsum_str = reverse_einsum_wrt_input(forward_node, input_name) + + einsum_node = donnx.ONNXEinsum(input_name + "_backward", equation=einsum_str) + nstate.add_node(einsum_node) + + # the first input is always the output grad + einsum_node.add_in_connector(f"Inputs__0") + nstate.add_edge(access_output_grad, None, einsum_node, "Inputs__0", + nsdfg.make_array_memlet(result.given_grad_names["Output"])) + + # add the other inputs from forward that we need + for i, forward_input in enumerate(sorted(forward_inputs)): + connector = f"Inputs__{i + 1}" + einsum_node.add_in_connector(connector) + if forward_input not in required_forward_inputs: + required_forward_inputs[forward_input] = create_access_node(forward_input) + + nstate.add_edge(required_forward_inputs[forward_input], None, einsum_node, connector, + nsdfg.make_array_memlet(forward_input)) + + # write out the gradient + butils.forward_in_desc_with_name(forward_node, context, input_name) + result.required_grad_names[input_name] = butils.add_backward_desc_for_connector( + nsdfg, forward_node, context, input_name, True) + memlet = nsdfg.make_array_memlet(result.required_grad_names[input_name]) + # Add a wcr for gradient accumulation + memlet.wcr = "lambda x, y: x + y" + nstate.add_edge(einsum_node, "Output", nstate.add_write(result.required_grad_names[input_name]), None, + memlet) + + result_node = context.backward_state.add_nested_sdfg( + nsdfg, + set(result.given_grad_names.values()).union(required_forward_inputs), + set(result.required_grad_names.values())) + + return result_node, result + + +@autoregister_params(op="Clip", name="default") +class DefaultClipBackward(BackwardImplementation): + """Backward implementation for ONNX Clip operation. + + Computes gradients by zeroing out regions where the input was clipped + and passing through gradients where the input was within bounds. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[Union[nd.Node, dace.SDFG], BackwardResult]: + + result_node, result = butils.add_empty_sdfg_for_node(forward_node, ["input_grad", "output_grad", "input"], + context) + + nstate = result_node.sdfg.add_state() + + min_node = next(context.forward_state.in_edges_by_connector(forward_node, 'min')).src + max_node = next(context.forward_state.in_edges_by_connector(forward_node, 'max')).src + minval = onnx_constant_or_none(context.forward_sdfg, min_node) + maxval = onnx_constant_or_none(context.forward_sdfg, max_node) + + idesc = butils.forward_in_desc_with_name(forward_node, context, "input") + shape = idesc.shape + map_ranges = {f"i{i}": f"0:{s}" for i, s in enumerate(shape)} + + input_dtype = idesc.dtype + minstr = f"dace.{input_dtype.to_string()}({minval})" + maxstr = f"dace.{input_dtype.to_string()}({maxval})" + + index_str = f"{', '.join(map_ranges.keys())}" + code = f""" +if __input < {minstr} or __input > {maxstr}: + __input_grad = 0 +else: + __input_grad = __output_grad + """ + nstate.add_mapped_tasklet(forward_node.label + "_backward", + map_ranges=map_ranges, + inputs={ + f"__output_grad": dace.Memlet(f"output_grad[{index_str}]"), + f"__input": dace.Memlet(f"input[{index_str}]"), + }, + code=code, + outputs={f"__input_grad": dace.Memlet(f"input_grad[{index_str}]")}, + external_edges=True) + + return result_node, result + + +@autoregister_params(op="Dropout", name="default") +class DefaultDropoutBackward(BackwardImplementation): + """Backward implementation for ONNX Dropout operation. + + Applies the dropout mask to the output gradients and scales by the keep probability. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[Union[nd.Node, dace.SDFG], BackwardResult]: + + result_node, result = butils.add_empty_sdfg_for_node(forward_node, + ["data_grad", "output_grad", "mask", "ratio"], context) + + nstate = result_node.sdfg.add_state() + + data_desc = butils.forward_in_desc_with_name(forward_node, context, "data") + shape = data_desc.shape + dtype = data_desc.dtype + dtype_str = dtype.to_string() + map_ranges = {f"i{i}": f"0:{s}" for i, s in enumerate(shape)} + index_str = f"{', '.join(map_ranges.keys())}" + code = f""" +scale = dace.{dtype_str}(1.0) / (1 - __ratio) +__data_grad = __output_grad * __mask * scale + """ + nstate.add_mapped_tasklet(forward_node.label + "_backward", + map_ranges=map_ranges, + inputs={ + "__output_grad": dace.Memlet(f"output_grad[{index_str}]"), + "__mask": dace.Memlet(f"mask[{index_str}]"), + "__ratio": dace.Memlet("ratio[0]") + }, + code=code, + outputs={f"__data_grad": dace.Memlet(f"data_grad[{index_str}]")}, + external_edges=True) + + return result_node, result + + +@autoregister_params(op="Softmax", name="default") +class DefaultSoftmaxBackward(BackwardImplementation): + """Backward implementation for ONNX Softmax operation. + + Computes gradients using the mathematical relationship: + dX = softmax(X) * (dY - sum(dY * softmax(X))) + where dY is the output gradient and dX is the input gradient. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[Union[nd.Node, dace.SDFG], BackwardResult]: + + dim = forward_node.axis + + output_desc = copy.deepcopy(butils.forward_out_desc_with_name(forward_node, context, "output")) + output_desc.transient = False + + sums_shape = list(copy.deepcopy(output_desc.shape)) + sums_shape[dim] = 1 + + # Create new SDFG + nsdfg = dace.SDFG(forward_node.label + "_backward") + nstate = nsdfg.add_state() + + result = BackwardResult.empty() + + # Given gradients (from output of forward pass) + result.given_grad_names["output"] = "output_grad" + output_grad_desc = copy.deepcopy(output_desc) + nsdfg.add_datadesc("output_grad", output_grad_desc) + + # Required gradient to be computed + input_name = "input" + if "input" not in required_gradients: + # this can happen for example in bert, where the input to softmax is masked + input_name = next(iter(required_gradients)) + + input_grad_desc = copy.deepcopy(butils.forward_in_desc_with_name(forward_node, context, input_name)) + input_grad_desc.transient = False + input_grad_desc_dtype = input_grad_desc.dtype + result.required_grad_names[input_name] = "input_grad" + nsdfg.add_datadesc("input_grad", input_grad_desc) + + # We need the output of the forward op + nsdfg.add_datadesc("output", output_desc) + + # Intermediate arrays + prod_desc = copy.deepcopy(output_desc) + prod_desc.transient = True + nsdfg.add_datadesc("prod", prod_desc) + + sums_desc = dace.data.Array(input_grad_desc_dtype, sums_shape, transient=True) + nsdfg.add_datadesc("sums", sums_desc) + + sub_term_desc = copy.deepcopy(output_desc) + sub_term_desc.transient = True + nsdfg.add_datadesc("sub_term", sub_term_desc) + + # Add nodes + output_grad_read = nstate.add_read("output_grad") + forward_output_read = nstate.add_read("output") + input_grad_write = nstate.add_write("input_grad") + prod_access = nstate.add_access("prod") + sums_access = nstate.add_access("sums") + sub_term_access = nstate.add_access("sub_term") + + # prod = forward_output * output_grad + mul_node1 = donnx.ONNXMul("mul_prod") + nstate.add_node(mul_node1) + nstate.add_edge(forward_output_read, None, mul_node1, "A", nsdfg.make_array_memlet("output")) + nstate.add_edge(output_grad_read, None, mul_node1, "B", nsdfg.make_array_memlet("output_grad")) + nstate.add_edge(mul_node1, "C", prod_access, None, nsdfg.make_array_memlet("prod")) + + # sums = ReduceSum(prod, axes=[dim], keepdims=1) + reduce_sum_node = donnx.ONNXReduceSum("reduce_sum", keepdims=1, optional={"axes"}) + reduce_sum_node.axes = dim + nstate.add_node(reduce_sum_node) + + # Setup the axes input for the ReduceSum node + axes_name, _ = nsdfg.add_array(name="reduce_sum_axes", shape=[1], dtype=dace.int64, transient=True) + axes_access = nstate.add_access(axes_name) + axes_tasklet = nstate.add_tasklet("init_axes", {}, {"out"}, f"out = {dim};", language=dace.Language.CPP) + nstate.add_edge(axes_tasklet, "out", axes_access, None, dace.Memlet(f"{axes_name}")) + + nstate.add_edge(prod_access, None, reduce_sum_node, "data", nsdfg.make_array_memlet("prod")) + nstate.add_edge(axes_access, None, reduce_sum_node, "axes", nsdfg.make_array_memlet(axes_name)) + nstate.add_edge(reduce_sum_node, "reduced", sums_access, None, nsdfg.make_array_memlet("sums")) + + # sub_term = forward_output * sums + mul_node2 = donnx.ONNXMul("mul_sub_term") + nstate.add_node(mul_node2) + nstate.add_edge(forward_output_read, None, mul_node2, "A", nsdfg.make_array_memlet("output")) + nstate.add_edge(sums_access, None, mul_node2, "B", nsdfg.make_array_memlet("sums")) + nstate.add_edge(mul_node2, "C", sub_term_access, None, nsdfg.make_array_memlet("sub_term")) + + # input_grad = prod - sub_term + sub_node = donnx.ONNXSub("sub_input_grad") + nstate.add_node(sub_node) + nstate.add_edge(prod_access, None, sub_node, "A", nsdfg.make_array_memlet("prod")) + nstate.add_edge(sub_term_access, None, sub_node, "B", nsdfg.make_array_memlet("sub_term")) + nstate.add_edge(sub_node, "C", input_grad_write, None, nsdfg.make_array_memlet("input_grad")) + + # Create nested SDFG + result_node = context.backward_state.add_nested_sdfg( + nsdfg, + # Inputs to nested SDFG + {"output", "output_grad"}, + # Outputs from nested SDFG + {"input_grad"}) + + butils.connect_output_from_forward(forward_node, result_node, context, "output") + + return result_node, result + + +def _find_map_by_param(sdfg: dace.SDFG, pname: str) -> dace.nodes.MapEntry: + """Find the first map entry node by the given parameter name. + + :param sdfg: The SDFG to search. + :param pname: The parameter name to look for. + :return: The first MapEntry node containing the specified parameter. + :raises StopIteration: If no MapEntry with the parameter is found. + """ + return next(n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.MapEntry) and pname in n.params) + + +@autoregister_params(op="MaxPool", name="default") +class DefaultMaxPoolBackward(BackwardImplementation): + """Backward implementation for ONNX MaxPool operation. + + Implements gradient computation by routing gradients only to the locations + that achieved the maximum value in the forward pass. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[Union[nd.Node, dace.SDFG], BackwardResult]: + + output_shape = butils.forward_out_desc_with_name(forward_node, context, "Y").shape + + N, C, H, W = output_shape + sty, stx = forward_node.strides + sy, sx = forward_node.kernel_shape + dtype = butils.forward_in_desc_with_name(forward_node, context, "X").dtype + + def maxpool_backward(X, Y_grad, X_grad): + for b, c, ti, tj in dace.map[0:N, 0:C, 0:H, 0:W]: + maxv = np.empty([1], dtype=dtype) + maxi = np.empty([1], dtype=np.int32) + maxj = np.empty([1], dtype=np.int32) + with dace.tasklet: + v >> maxv + v = -9999999 + + # Deterministic argmax + for i, j in dace.map[0:sy, 0:sx] @ dace.ScheduleType.Sequential: + with dace.tasklet: + o << X[b, c, sty * ti + i, stx * tj + j] + vin << maxv + v >> maxv(-1) + ind_i >> maxi(-1) + ind_j >> maxj(-1) + if o > vin: + v = o + ind_i = i + ind_j = j + with dace.tasklet: + igrad << Y_grad[b, c, ti, tj] + ind_i << maxi + ind_j << maxj + ograd >> X_grad(1)[b, c, :, :] + ograd[ind_i, ind_j] = igrad + + result_node, result = butils.backward_program_for_node(maxpool_backward, context, forward_node) + + return result_node, result + + +@autoregister_params(op="LogSoftmax", name="default") +class DefaultLogSoftmaxBackward(BackwardImplementation): + """Backward implementation for ONNX LogSoftmax operation. + + Computes gradients using the mathematical relationship for log-softmax: + dX = dY - exp(Y) * sum(dY) + where Y is the forward output and dY is the output gradient. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + + dim = forward_node.axis + output_shape = butils.forward_out_desc_with_name(forward_node, context, "output").shape + output_dtype = butils.forward_out_desc_with_name(forward_node, context, "output").dtype + + sums_shape = list(copy.deepcopy(output_shape)) + sums_shape[dim] = 1 + + def logsoftmax_backward(output, output_grad, input_grad): + exp_output = dace.define_local(output_shape, output_dtype) + donnx.ONNXExp(input=output, output=exp_output) + + grad_output_sum = dace.define_local(sums_shape, output_dtype) + donnx.ONNXReduceSum(data=output_grad, reduced=grad_output_sum, keepdims=1, axes=[dim]) + # let's not use ONNXMul here; not sure how this inplace op is handled by ORT... + exp_output[:] = exp_output * grad_output_sum + donnx.ONNXSub(A=output_grad, B=exp_output, C=input_grad) + + result_node, result = butils.backward_program_for_node(logsoftmax_backward, context, forward_node) + + butils.connect_output_from_forward(forward_node, result_node, context, "output") + return result_node, result + + +@autoregister_params(op="GlobalAveragePool", name="pure") +class PureGlobalAveragePoolingBackward(BackwardImplementation): + """Pure implementation of GlobalAveragePool backward pass. + + Broadcasts the output gradient uniformly across the spatial dimensions + with appropriate scaling by the pool size. + """ + + @staticmethod + def backward_can_be_applied(node: nd.Node, state: dace.SDFGState, sdfg: dace.SDFG) -> bool: + return len(in_desc_with_name(node, state, sdfg, "X").shape) == 4 + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + desc = butils.forward_in_desc_with_name(forward_node, context, "X") + N, C, H, W = desc.shape + dtype = desc.dtype + + inv = 1.0 / (H * W) + + def bwd(X_grad, Y_grad): + for n, c, h, w in dace.map[0:N, 0:C, 0:H, 0:W]: + with dace.tasklet: + y_grad << Y_grad[n, c] + x_grad >> X_grad[n, c, h, w] + x_grad = y_grad * dtype(inv) + + return butils.backward_program_for_node(bwd, context, forward_node) + + +@autoregister_params(op="Transpose", name="default") +class DefaultTransposeBackward(BackwardImplementation): + """Backward implementation for ONNX Transpose operation. + + The gradient of transpose is another transpose with inverted permutation. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + inv_perm = tuple(np.argsort(forward_node.perm)) + + node = donnx.ONNXTranspose(forward_node.name + "_backward", perm=inv_perm) + context.backward_state.add_node(node) + + result = BackwardResult.empty() + result.given_grad_names["transposed"] = "data" + result.required_grad_names["data"] = "transposed" + + return node, result + + +@autoregister_params(op="Where", name="default") +class WhereBackward(BackwardImplementation): + """Backward implementation for ONNX Where operation. + + Routes gradients based on the condition: gradients flow to X where condition is True, + and to Y where condition is False. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + # condition, X, Y -> Output + # Get condition descriptor for shape information + _ = butils.forward_in_desc_with_name(forward_node, context, "condition") + + # NOTE: We cannot use ONNX ops for further potential lowering + # transformations because ONNXMul does not support boolean inputs. + # notcondition = dace.define_local(condition_shape, condition_dtype) + # donnx.ONNXMul(A=condition, B=output_grad, C=X_grad) + # donnx.ONNXNot(X=condition, Y=notcondition) + # donnx.ONNXMul(A=notcondition, B=output_grad, C=Y_grad) + + if 'X' in required_gradients and 'Y' not in required_gradients: + + def where_backward(condition, output_grad, X_grad): + X_grad[:] = condition * output_grad + elif 'Y' in required_gradients and 'X' not in required_gradients: + + def where_backward(condition, output_grad, Y_grad): + Y_grad[:] = ~condition * output_grad + elif 'X' in required_gradients and 'Y' in required_gradients: + + def where_backward(condition, output_grad, X_grad, Y_grad): + X_grad[:] = condition * output_grad + Y_grad[:] = ~condition * output_grad + + result_node, result = butils.backward_program_for_node(where_backward, context, forward_node) + + return result_node, result + + +@autoregister_params(op="LayerNormalization", name="default") +class DefaultLayerNormalizationBackward(BackwardImplementation): + """Backward implementation for ONNX LayerNormalization operation. + + Computes gradients for input, scale, and bias parameters using the + mathematical formulation of layer normalization backward pass. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + # Create new SDFG + nsdfg = dace.SDFG(forward_node.label + "_backward") + nstate = nsdfg.add_state() + + # Get input/output descriptors + X_desc = copy.deepcopy(butils.forward_in_desc_with_name(forward_node, context, "X")) + Scale_desc = copy.deepcopy(butils.forward_in_desc_with_name(forward_node, context, "Scale")) + Y_grad_desc = copy.deepcopy(butils.forward_out_desc_with_name(forward_node, context, "Y")) + X_desc.transient = False + Y_grad_desc.transient = False + Scale_desc.transient = False + + result = BackwardResult.empty() + # setup gradient arrays + result.given_grad_names["Y"] = "Y_grad" + if "X" in required_gradients: + result.required_grad_names["X"] = "X_grad" + if "Scale" in required_gradients: + result.required_grad_names["Scale"] = "Scale_grad" + if "B" in required_gradients: + result.required_grad_names["B"] = "B_grad" + + # Add data descriptors to SDFG + nsdfg.add_datadesc("X", X_desc) + nsdfg.add_datadesc("Scale", Scale_desc) + nsdfg.add_datadesc("Y_grad", Y_grad_desc) + + if "X" in required_gradients: + X_grad_desc = copy.deepcopy(X_desc) + nsdfg.add_datadesc("X_grad", X_grad_desc) + if "Scale" in required_gradients: + Scale_grad_desc = copy.deepcopy(Scale_desc) + nsdfg.add_datadesc("Scale_grad", Scale_grad_desc) + if "B" in required_gradients: + B_desc = copy.deepcopy(butils.forward_in_desc_with_name(forward_node, context, "B")) + B_desc.transient = False + B_grad_desc = copy.deepcopy(B_desc) + nsdfg.add_datadesc("B_grad", B_grad_desc) + # Add B to SDFG inputs when needed + nsdfg.add_datadesc("B", B_desc) + + # Get axis and epsilon + axis = forward_node.axis if hasattr(forward_node, 'axis') else -1 + epsilon = forward_node.epsilon if hasattr(forward_node, 'epsilon') else 1e-5 + + rank = len(X_desc.shape) + if axis < 0: + axis = rank + axis + reduction_axes = list(range(axis, rank)) + leading_non_normalized_axes = list(range(axis)) + # Calculate normalization size for reference (currently unused) + _ = float(np.prod([X_desc.shape[i] for i in range(axis, rank)])) + + # Create axes tensor for reduction + axes_name = "reduction_axes" + axes_desc = dace.data.Array(dace.int64, [len(reduction_axes)]) + axes_desc.transient = True # Make it transient since it's internal + nsdfg.add_datadesc(axes_name, axes_desc) + axes_access = nstate.add_access(axes_name) + + # Initialize reduction axes as a constant array + axes_tasklet = nstate.add_tasklet(name="init_axes", + inputs={}, + outputs={"out": dace.pointer(dace.int64)}, + code=f"\n".join([f"out[{i}] = {0};" for i, _ in enumerate(reduction_axes)]), + language=dace.Language.CPP) + nstate.add_edge(axes_tasklet, "out", axes_access, None, dace.Memlet(f"{axes_name}[0:{len(reduction_axes)}]")) + + # Create mean descriptor with reduced shape + mean_shape = list(X_desc.shape) + for i in reduction_axes: + mean_shape[i] = 1 + mean_desc = dace.data.Array(X_desc.dtype, mean_shape) + mean_desc.transient = True + mean_name = "mean" + nsdfg.add_datadesc(mean_name, mean_desc) + + mean_op = donnx.ONNXReduceMean("mean_op", keepdims=1, optional={"axes"}) + mean_op.axes = reduction_axes + nstate.add_node(mean_op) + nstate.add_edge(nstate.add_read("X"), None, mean_op, "data", nsdfg.make_array_memlet("X")) + nstate.add_edge(axes_access, None, mean_op, "axes", nsdfg.make_array_memlet(axes_name)) + mean_access = nstate.add_access("mean") + nstate.add_edge(mean_op, "reduced", mean_access, None, nsdfg.make_array_memlet("mean")) + + # Recompute variance + diff_shape = list(X_desc.shape) + diff_desc = dace.data.Array(X_desc.dtype, diff_shape) + diff_desc.transient = True + diff_name = "diff" + nsdfg.add_datadesc(diff_name, diff_desc) + + diff_op = donnx.ONNXSub("diff_op") + nstate.add_node(diff_op) + nstate.add_edge(nstate.add_read("X"), None, diff_op, "A", nsdfg.make_array_memlet("X")) + nstate.add_edge(mean_access, None, diff_op, "B", nsdfg.make_array_memlet("mean")) + diff_access = nstate.add_access("diff") + nstate.add_edge(diff_op, "C", diff_access, None, nsdfg.make_array_memlet("diff")) + + # Create squared difference descriptor + sq_diff_shape = list(X_desc.shape) + sq_diff_desc = dace.data.Array(X_desc.dtype, sq_diff_shape) + sq_diff_desc.transient = True + sq_diff_name = "sq_diff" + nsdfg.add_datadesc(sq_diff_name, sq_diff_desc) + + sq_diff_op = donnx.ONNXMul("sq_diff_op") + nstate.add_node(sq_diff_op) + nstate.add_edge(diff_access, None, sq_diff_op, "A", nsdfg.make_array_memlet("diff")) + nstate.add_edge(diff_access, None, sq_diff_op, "B", nsdfg.make_array_memlet("diff")) + sq_diff_access = nstate.add_access("sq_diff") + nstate.add_edge(sq_diff_op, "C", sq_diff_access, None, nsdfg.make_array_memlet("sq_diff")) + + # Create variance descriptor with reduced shape + variance_shape = list(X_desc.shape) + for i in reduction_axes: + variance_shape[i] = 1 + variance_desc = dace.data.Array(X_desc.dtype, variance_shape) + variance_desc.transient = True + variance_name = "variance" + nsdfg.add_datadesc(variance_name, variance_desc) + + variance_op = donnx.ONNXReduceMean("variance_op", keepdims=1, optional={"axes"}) + variance_op.axes = reduction_axes + nstate.add_node(variance_op) + nstate.add_edge(sq_diff_access, None, variance_op, "data", nsdfg.make_array_memlet("sq_diff")) + nstate.add_edge(axes_access, None, variance_op, "axes", nsdfg.make_array_memlet(axes_name)) + variance_access = nstate.add_access("variance") + nstate.add_edge(variance_op, "reduced", variance_access, None, nsdfg.make_array_memlet("variance")) + + # Add epsilon to variance + epsilon_name, _ = nsdfg.add_scalar("epsilon", X_desc.dtype, transient=True) + epsilon_tasklet = nstate.add_tasklet( + "make_epsilon", + {}, + {"out"}, + f"out = {epsilon};", + language=dace.Language.CPP, + ) + epsilon_write = nstate.add_write(epsilon_name) + nstate.add_edge(epsilon_tasklet, "out", epsilon_write, None, dace.Memlet(f"{epsilon_name}[0]")) + + # Create variance_eps descriptor + variance_eps_desc = dace.data.Array(X_desc.dtype, variance_shape) + variance_eps_desc.transient = True + variance_eps_name = "variance_eps" + nsdfg.add_datadesc(variance_eps_name, variance_eps_desc) + + variance_eps_op = donnx.ONNXAdd("variance_eps_op") + nstate.add_node(variance_eps_op) + nstate.add_edge(variance_access, None, variance_eps_op, "A", nsdfg.make_array_memlet("variance")) + nstate.add_edge(epsilon_write, None, variance_eps_op, "B", nsdfg.make_array_memlet(epsilon_name)) + variance_eps_access = nstate.add_access("variance_eps") + nstate.add_edge(variance_eps_op, "C", variance_eps_access, None, nsdfg.make_array_memlet("variance_eps")) + + # Create std_dev descriptor + std_dev_desc = dace.data.Array(X_desc.dtype, variance_shape) + std_dev_desc.transient = True + std_dev_name = "std_dev" + nsdfg.add_datadesc(std_dev_name, std_dev_desc) + + std_dev_op = donnx.ONNXSqrt("std_dev_op") + nstate.add_node(std_dev_op) + nstate.add_edge(variance_eps_access, None, std_dev_op, "X", nsdfg.make_array_memlet("variance_eps")) + std_dev_access = nstate.add_access("std_dev") + nstate.add_edge(std_dev_op, "Y", std_dev_access, None, nsdfg.make_array_memlet("std_dev")) + + # Create inv_std_dev descriptor + one_name, _ = nsdfg.add_scalar("one", X_desc.dtype, transient=True) + one_tasklet = nstate.add_tasklet("make_one", {}, {"out"}, "out = 1.0;", language=dace.Language.CPP) + one_write = nstate.add_write(one_name) + nstate.add_edge(one_tasklet, "out", one_write, None, dace.Memlet(f"{one_name}[0]")) + + inv_std_dev_desc = dace.data.Array(X_desc.dtype, variance_shape) + inv_std_dev_desc.transient = True + inv_std_dev_name = "inv_std_dev" + nsdfg.add_datadesc(inv_std_dev_name, inv_std_dev_desc) + + inv_std_dev_op = donnx.ONNXDiv("inv_std_dev_op") + nstate.add_node(inv_std_dev_op) + nstate.add_edge(one_write, None, inv_std_dev_op, "A", nsdfg.make_array_memlet(one_name)) + nstate.add_edge(std_dev_access, None, inv_std_dev_op, "B", nsdfg.make_array_memlet("std_dev")) + inv_std_dev_access = nstate.add_access("inv_std_dev") + nstate.add_edge(inv_std_dev_op, "C", inv_std_dev_access, None, nsdfg.make_array_memlet("inv_std_dev")) + + # Create x_hat descriptor (normalized input) + x_hat_desc = dace.data.Array(X_desc.dtype, X_desc.shape) + x_hat_desc.transient = True + x_hat_name = "x_hat" + nsdfg.add_datadesc(x_hat_name, x_hat_desc) + + x_hat_op = donnx.ONNXMul("x_hat_op") + nstate.add_node(x_hat_op) + nstate.add_edge(diff_access, None, x_hat_op, "A", nsdfg.make_array_memlet("diff")) + nstate.add_edge(inv_std_dev_access, None, x_hat_op, "B", nsdfg.make_array_memlet("inv_std_dev")) + x_hat_access = nstate.add_access("x_hat") + nstate.add_edge(x_hat_op, "C", x_hat_access, None, nsdfg.make_array_memlet("x_hat")) + + # Compute bias gradient if needed + if "B" in required_gradients: + b_grad_op = donnx.ONNXReduceSum("b_grad_op", keepdims=0, optional={"axes"}) + # This reduction will sum over the leading non-normalized axes + b_grad_op.axes = leading_non_normalized_axes + nstate.add_node(b_grad_op) + nstate.add_edge(nstate.add_read("Y_grad"), None, b_grad_op, "data", nsdfg.make_array_memlet("Y_grad")) + nstate.add_edge(axes_access, None, b_grad_op, "axes", nsdfg.make_array_memlet(axes_name)) + nstate.add_edge(b_grad_op, "reduced", nstate.add_write("B_grad"), None, nsdfg.make_array_memlet("B_grad")) + + # Compute scale gradient if needed + if "Scale" in required_gradients: + dY_x_hat_desc = dace.data.Array(X_desc.dtype, X_desc.shape) + dY_x_hat_desc.transient = True + dY_x_hat_name = "dY_x_hat" + nsdfg.add_datadesc(dY_x_hat_name, dY_x_hat_desc) + + dY_x_hat_op = donnx.ONNXMul("dY_x_hat_op") + nstate.add_node(dY_x_hat_op) + nstate.add_edge(nstate.add_read("Y_grad"), None, dY_x_hat_op, "A", nsdfg.make_array_memlet("Y_grad")) + nstate.add_edge(x_hat_access, None, dY_x_hat_op, "B", nsdfg.make_array_memlet("x_hat")) + dY_x_hat_access = nstate.add_access("dY_x_hat") + nstate.add_edge(dY_x_hat_op, "C", dY_x_hat_access, None, nsdfg.make_array_memlet("dY_x_hat")) + + scale_grad_op = donnx.ONNXReduceSum("scale_grad_op", keepdims=0, optional={"axes"}) + scale_grad_op.axes = leading_non_normalized_axes + nstate.add_node(scale_grad_op) + nstate.add_edge(dY_x_hat_access, None, scale_grad_op, "data", nsdfg.make_array_memlet("dY_x_hat")) + nstate.add_edge(axes_access, None, scale_grad_op, "axes", nsdfg.make_array_memlet(axes_name)) + nstate.add_edge(scale_grad_op, "reduced", nstate.add_write("Scale_grad"), None, + nsdfg.make_array_memlet("Scale_grad")) + + # Compute X gradient if needed + if "X" in required_gradients: + # Create dX_hat descriptor (gradient with respect to normalized input) + dX_hat_desc = dace.data.Array(X_desc.dtype, X_desc.shape) + dX_hat_desc.transient = True + dX_hat_name = "dX_hat" + nsdfg.add_datadesc(dX_hat_name, dX_hat_desc) + + dX_hat_op = donnx.ONNXMul("dX_hat_op") + nstate.add_node(dX_hat_op) + nstate.add_edge(nstate.add_read("Y_grad"), None, dX_hat_op, "A", nsdfg.make_array_memlet("Y_grad")) + nstate.add_edge(nstate.add_read("Scale"), None, dX_hat_op, "B", nsdfg.make_array_memlet("Scale")) + dX_hat_access = nstate.add_access("dX_hat") + nstate.add_edge(dX_hat_op, "C", dX_hat_access, None, nsdfg.make_array_memlet("dX_hat")) + + # Compute mean of dX_hat over reduction axes + dX_hat_mean_desc = dace.data.Array(X_desc.dtype, variance_shape) + dX_hat_mean_desc.transient = True + dX_hat_mean_name = "dX_hat_mean" + nsdfg.add_datadesc(dX_hat_mean_name, dX_hat_mean_desc) + + dX_hat_mean_op = donnx.ONNXReduceMean("dX_hat_mean_op", keepdims=1, optional={"axes"}) + dX_hat_mean_op.axes = reduction_axes + nstate.add_node(dX_hat_mean_op) + nstate.add_edge(dX_hat_access, None, dX_hat_mean_op, "data", nsdfg.make_array_memlet("dX_hat")) + nstate.add_edge(axes_access, None, dX_hat_mean_op, "axes", nsdfg.make_array_memlet(axes_name)) + dX_hat_mean_access = nstate.add_access("dX_hat_mean") + nstate.add_edge(dX_hat_mean_op, "reduced", dX_hat_mean_access, None, nsdfg.make_array_memlet("dX_hat_mean")) + + # Compute dX_hat * x_hat + dX_hat_x_hat_desc = dace.data.Array(X_desc.dtype, X_desc.shape) + dX_hat_x_hat_desc.transient = True + dX_hat_x_hat_name = "dX_hat_x_hat" + nsdfg.add_datadesc(dX_hat_x_hat_name, dX_hat_x_hat_desc) + + dX_hat_x_hat_op = donnx.ONNXMul("dX_hat_x_hat_op") + nstate.add_node(dX_hat_x_hat_op) + nstate.add_edge(dX_hat_access, None, dX_hat_x_hat_op, "A", nsdfg.make_array_memlet("dX_hat")) + nstate.add_edge(x_hat_access, None, dX_hat_x_hat_op, "B", nsdfg.make_array_memlet("x_hat")) + dX_hat_x_hat_access = nstate.add_access("dX_hat_x_hat") + nstate.add_edge(dX_hat_x_hat_op, "C", dX_hat_x_hat_access, None, nsdfg.make_array_memlet("dX_hat_x_hat")) + + # Compute mean of dX_hat * x_hat over reduction axes + dX_hat_x_hat_mean_desc = dace.data.Array(X_desc.dtype, variance_shape) + dX_hat_x_hat_mean_desc.transient = True + dX_hat_x_hat_mean_name = "dX_hat_x_hat_mean" + nsdfg.add_datadesc(dX_hat_x_hat_mean_name, dX_hat_x_hat_mean_desc) + + dX_hat_x_hat_mean_op = donnx.ONNXReduceMean("dX_hat_x_hat_mean_op", keepdims=1, optional={"axes"}) + dX_hat_x_hat_mean_op.axes = reduction_axes + nstate.add_node(dX_hat_x_hat_mean_op) + nstate.add_edge(dX_hat_x_hat_access, None, dX_hat_x_hat_mean_op, "data", + nsdfg.make_array_memlet("dX_hat_x_hat")) + nstate.add_edge(axes_access, None, dX_hat_x_hat_mean_op, "axes", nsdfg.make_array_memlet(axes_name)) + dX_hat_x_hat_mean_access = nstate.add_access("dX_hat_x_hat_mean") + nstate.add_edge(dX_hat_x_hat_mean_op, "reduced", dX_hat_x_hat_mean_access, None, + nsdfg.make_array_memlet("dX_hat_x_hat_mean")) + + # Compute x_hat * mean(dX_hat * x_hat) + x_hat_dX_hat_x_hat_mean_desc = dace.data.Array(X_desc.dtype, X_desc.shape) + x_hat_dX_hat_x_hat_mean_desc.transient = True + x_hat_dX_hat_x_hat_mean_name = "x_hat_dX_hat_x_hat_mean" + nsdfg.add_datadesc(x_hat_dX_hat_x_hat_mean_name, x_hat_dX_hat_x_hat_mean_desc) + + x_hat_dX_hat_x_hat_mean_op = donnx.ONNXMul("x_hat_dX_hat_x_hat_mean_op") + nstate.add_node(x_hat_dX_hat_x_hat_mean_op) + nstate.add_edge(x_hat_access, None, x_hat_dX_hat_x_hat_mean_op, "A", nsdfg.make_array_memlet("x_hat")) + nstate.add_edge(dX_hat_x_hat_mean_access, None, x_hat_dX_hat_x_hat_mean_op, "B", + nsdfg.make_array_memlet("dX_hat_x_hat_mean")) + x_hat_dX_hat_x_hat_mean_access = nstate.add_access("x_hat_dX_hat_x_hat_mean") + nstate.add_edge(x_hat_dX_hat_x_hat_mean_op, "C", x_hat_dX_hat_x_hat_mean_access, None, + nsdfg.make_array_memlet("x_hat_dX_hat_x_hat_mean")) + + # Compute dX_hat - mean(dX_hat) - x_hat * mean(dX_hat * x_hat) + dX_hat_minus_mean_desc = dace.data.Array(X_desc.dtype, X_desc.shape) + dX_hat_minus_mean_desc.transient = True + dX_hat_minus_mean_name = "dX_hat_minus_mean" + nsdfg.add_datadesc(dX_hat_minus_mean_name, dX_hat_minus_mean_desc) + + dX_hat_minus_mean_op = donnx.ONNXSub("dX_hat_minus_mean_op") + nstate.add_node(dX_hat_minus_mean_op) + nstate.add_edge(dX_hat_access, None, dX_hat_minus_mean_op, "A", nsdfg.make_array_memlet("dX_hat")) + nstate.add_edge(dX_hat_mean_access, None, dX_hat_minus_mean_op, "B", nsdfg.make_array_memlet("dX_hat_mean")) + dX_hat_minus_mean_access = nstate.add_access("dX_hat_minus_mean") + nstate.add_edge(dX_hat_minus_mean_op, "C", dX_hat_minus_mean_access, None, + nsdfg.make_array_memlet("dX_hat_minus_mean")) + + # Final subtraction + dX_hat_final_desc = dace.data.Array(X_desc.dtype, X_desc.shape) + dX_hat_final_desc.transient = True + dX_hat_final_name = "dX_hat_final" + nsdfg.add_datadesc(dX_hat_final_name, dX_hat_final_desc) + + dX_hat_final_op = donnx.ONNXSub("dX_hat_final_op") + nstate.add_node(dX_hat_final_op) + nstate.add_edge(dX_hat_minus_mean_access, None, dX_hat_final_op, "A", + nsdfg.make_array_memlet("dX_hat_minus_mean")) + nstate.add_edge(x_hat_dX_hat_x_hat_mean_access, None, dX_hat_final_op, "B", + nsdfg.make_array_memlet("x_hat_dX_hat_x_hat_mean")) + dX_hat_final_access = nstate.add_access("dX_hat_final") + nstate.add_edge(dX_hat_final_op, "C", dX_hat_final_access, None, nsdfg.make_array_memlet("dX_hat_final")) + + # Multiply by inv_std_dev to get final X gradient + x_grad_op = donnx.ONNXMul("x_grad_op") + nstate.add_node(x_grad_op) + nstate.add_edge(inv_std_dev_access, None, x_grad_op, "A", nsdfg.make_array_memlet("inv_std_dev")) + nstate.add_edge(dX_hat_final_access, None, x_grad_op, "B", nsdfg.make_array_memlet("dX_hat_final")) + nstate.add_edge(x_grad_op, "C", nstate.add_write("X_grad"), None, nsdfg.make_array_memlet("X_grad")) + + # Set up inputs for nested SDFG + inputs = {"X", "Scale", "Y_grad"} + if "B" in required_gradients: + inputs.add("B") + + outputs = set(result.required_grad_names.values()) + bwd_node = context.backward_state.add_nested_sdfg(nsdfg, inputs, outputs) + return bwd_node, result + + +@autoregister_params(op="ReduceSum", name="default") +class DefaultReduceSumBackward(BackwardImplementation): + """Backward implementation for ONNX ReduceSum operation. + + The backward pass of a reduction is a broadcast of the output gradient + to match the input shape. Handles both keepdims=True and keepdims=False cases. + """ + + @staticmethod + def backward_can_be_applied(node: nd.Node, state: dace.SDFGState, sdfg: dace.SDFG) -> bool: + return True + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + + # The backward pass of a reduction is a broadcast. + # We use ONNXExpand to perform the broadcast. + # If keepdims=False, we first need to unsqueeze the gradient. + + input_desc = butils.forward_in_desc_with_name(forward_node, context, "data") + output_desc = butils.forward_out_desc_with_name(forward_node, context, "reduced") + + nsdfg = dace.SDFG(f"{forward_node.label}_backward") + nstate = nsdfg.add_state() + + result = BackwardResult.empty() + result.given_grad_names["reduced"] = "reduced_grad" + result.required_grad_names["data"] = "data_grad" + + reduced_grad_desc = copy.deepcopy(output_desc) + reduced_grad_desc.transient = False + nsdfg.add_datadesc("reduced_grad", reduced_grad_desc) + + data_grad_desc_tmp = copy.deepcopy(input_desc) + data_grad_desc_tmp.transient = True + nsdfg.add_datadesc("data_grad_tmp", data_grad_desc_tmp) + + data_grad_desc = copy.deepcopy(input_desc) + data_grad_desc.transient = False + nsdfg.add_datadesc("data_grad", data_grad_desc) + + grad_to_expand = "reduced_grad" + read_grad_to_expand = nstate.add_read(grad_to_expand) + + keepdims = getattr(forward_node, 'keepdims', 1) + + if not keepdims: + # When keepdims is False, the rank of the output is reduced. We need to + # unsqueeze the gradient to match the input rank before broadcasting. + + # Deduce reduced axes by comparing input and output shapes. + in_shape = input_desc.shape + out_shape = reduced_grad_desc.shape + unsqueezed_shape = [] + axes = [] + if len(in_shape) < len(out_shape): + raise ValueError(f"Input shape {in_shape} has fewer dimensions than output shape {out_shape}. " + f"This is unexpected for a ReduceSum operation.") + if len(in_shape) > len(out_shape): + # This assumes that non-reduced dimensions are preserved in order. + out_shape_idx = 0 + for i, dim in enumerate(in_shape): + if out_shape_idx < len(out_shape) and dim == out_shape[out_shape_idx]: + out_shape_idx += 1 + unsqueezed_shape.append(dim) + else: + axes.append(i) + unsqueezed_shape.append(1) + + # If shapes are equal, it's a no-op reduction and axes is empty. + if (not axes) != (len(in_shape) == len(out_shape)): + raise ValueError(f"Inconsistent state: axes={axes}, input_shape={in_shape}, output_shape={out_shape}. " + f"For equal shapes, axes should be empty.") + + if 'axes' in forward_node.in_connectors: + # The axes are a dynamic input to the forward node. Pass them to the backward node. + axes_desc = butils.forward_in_desc_with_name(forward_node, context, "axes") + axes_desc_copy = copy.deepcopy(axes_desc) + axes_desc_copy.transient = False + nsdfg.add_datadesc("axes", axes_desc_copy) + axes_access = nstate.add_read("axes") + elif axes: + # Create a constant array for the axes to be passed to Unsqueeze + axes_name_in_bwd, axes_desc_bwd = nsdfg.add_array(f"axes_for_unsqueeze_{forward_node.name}", + [len(axes)], + dace.int64, + transient=True) + axes_tasklet = nstate.add_tasklet( + 'init_axes', + {}, + {'out'}, + '\n'.join([f'out[{i}] = {v};' for i, v in enumerate(axes)]), + language=dace.Language.CPP, + ) + axes_access = nstate.add_access(axes_name_in_bwd) + nstate.add_edge(axes_tasklet, 'out', axes_access, None, + dace.Memlet.from_array(axes_name_in_bwd, axes_desc_bwd)) + + unsqueezed_desc = dace.data.Array(dtype=reduced_grad_desc.dtype, shape=unsqueezed_shape, transient=True) + nsdfg.add_datadesc("unsqueezed_grad", unsqueezed_desc) + + unsqueeze_op = donnx.ONNXUnsqueeze("unsqueeze_grad") + nstate.add_node(unsqueeze_op) + + nstate.add_edge(read_grad_to_expand, None, unsqueeze_op, "data", nsdfg.make_array_memlet("reduced_grad")) + nstate.add_edge(axes_access, None, unsqueeze_op, "axes", + dace.Memlet(data=axes_access.data, subset=f'0:{axes_access.desc(nsdfg).shape[0]}')) + + grad_to_expand = "unsqueezed_grad" + read_grad_to_expand = nstate.add_access(grad_to_expand) + nstate.add_edge(unsqueeze_op, "expanded", read_grad_to_expand, None, + nsdfg.make_array_memlet("unsqueezed_grad")) + + # Create shape tensor for ONNXExpand + shape_name, shape_desc = nsdfg.add_array("shape_for_expand", [len(input_desc.shape)], + dace.int64, + transient=True) + shape_tasklet = nstate.add_tasklet("init_shape", {}, {"out"}, + '\n'.join([f"out[{i}] = {s};" for i, s in enumerate(input_desc.shape)])) + shape_access = nstate.add_access(shape_name) + nstate.add_edge(shape_tasklet, "out", shape_access, None, dace.Memlet.from_array(shape_name, shape_desc)) + + expand_op = donnx.ONNXExpand("expand_grad") + nstate.add_node(expand_op) + write_data_grad_tmp = nstate.add_write("data_grad_tmp") + + nstate.add_edge(read_grad_to_expand, None, expand_op, "input", nsdfg.make_array_memlet(grad_to_expand)) + nstate.add_edge(shape_access, None, expand_op, "shape", nsdfg.make_array_memlet(shape_name)) + nstate.add_edge(expand_op, "output", write_data_grad_tmp, None, nsdfg.make_array_memlet("data_grad_tmp")) + + # We add an additional write from data_grad_tmp to data_grad + # This is necessary to accumulate gradients in the backward pass. + finale_memlet = nsdfg.make_array_memlet("data_grad") + finale_memlet.wcr = "lambda x, y: x + y" + write_data_grad = nstate.add_write("data_grad") + nstate.add_edge(write_data_grad_tmp, None, write_data_grad, None, finale_memlet) + + inputs = {"reduced_grad"} + if not keepdims and 'axes' in forward_node.in_connectors: + inputs.add("axes") + + result_node = context.backward_state.add_nested_sdfg(nsdfg, inputs, {"data_grad"}) + + return result_node, result diff --git a/dace/autodiff/implementations/pytorch_ops.py b/dace/autodiff/implementations/pytorch_ops.py new file mode 100644 index 0000000000..9f50edbb9e --- /dev/null +++ b/dace/autodiff/implementations/pytorch_ops.py @@ -0,0 +1,128 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +import copy +import itertools +from typing import List, Optional, Tuple + +import dace +import dace.libraries.torch +from dace.registry import autoregister_params +from dace import nodes as nd + +from dace.libraries.onnx.converters import clean_onnx_name + +import dace.autodiff.utils as butils +from dace.autodiff.base_abc import BackwardImplementation, BackwardContext, BackwardResult +from dace.sdfg.utils import in_desc_with_name + + +@autoregister_params(op="Conv", name="PyTorch-dwise") +class PyTorchConvBackward(BackwardImplementation): + """Depthwise convolution backward implementation using PyTorch. + + This implementation leverages PyTorch's optimized CUDA kernels for + depthwise convolution backward pass computation. + """ + + @staticmethod + def backward_can_be_applied(node: nd.Node, state: dace.SDFGState, sdfg: dace.SDFG) -> bool: + X_desc = in_desc_with_name(node, state, sdfg, "X") + return len(X_desc.shape) == 4 + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + + nsdfg = dace.SDFG(forward_node.label + "_backward") + X_desc = butils.forward_in_desc_with_name(forward_node, context, "X") + W_desc = butils.forward_in_desc_with_name(forward_node, context, "W") + + T = X_desc.dtype + if str(T) == 'float': + pytorch_dtype = 'kFloat' + elif str(T) == 'double': + pytorch_dtype = 'kDouble' + else: + raise ValueError(f"PyTorch backward conv expansion supports only float and double tensors, got {str(T)}. " + f"Supported types: float, double") + + # setup gradient arrays + result = BackwardResult.empty() + required_grads = set(required_gradients) + for r in sorted(required_grads): + result.required_grad_names[r] = butils.add_backward_desc_for_connector(nsdfg, + forward_node, + context, + r, + input=True) + result.given_grad_names["Y"] = butils.add_backward_desc_for_connector(nsdfg, + forward_node, + context, + "Y", + input=False) + + # setup non-gradient arrays + required_forward_inputs = ["W", "X"] + for i in sorted(required_forward_inputs): + new_desc = copy.deepcopy(butils.forward_in_desc_with_name(forward_node, context, i)) + new_desc.transient = False + nsdfg.add_datadesc(i, new_desc) + + # setup state + nstate = nsdfg.add_state() + unique_id = "{}_{}_{}_{}_bwd".format(clean_onnx_name(forward_node.name), context.forward_sdfg.sdfg_id, + context.forward_sdfg.node_id(context.forward_state), + context.forward_state.node_id(forward_node)) + + init_code = "" + finalize_code = "" + code_global = """ + #include + #include + """ + tasklet_inputs = {f"_{i}": dace.pointer(T) for i in itertools.chain(["dY"], sorted(required_forward_inputs))} + tasklet_outputs = {f"_d{i}": dace.pointer(T) for i in itertools.chain(sorted(required_gradients))} + + tasklet_code = f""" + std::vector x_shape = {{ {", ".join(map(str, X_desc.shape))} }}; + std::vector x_strides = {{ {", ".join(map(str, X_desc.strides))} }}; + std::vector w_shape = {{ {", ".join(map(str, W_desc.shape))} }}; + std::vector w_strides = {{ {", ".join(map(str, W_desc.strides))} }}; + at::Tensor x = at::from_blob(_X, x_shape, x_strides, [](void*){{}}, at::TensorOptions().device(at::kCUDA).dtype(at::{pytorch_dtype}).requires_grad(false)); + at::Tensor w = at::from_blob(_W, w_shape, w_strides, [](void*){{}}, at::TensorOptions().device(at::kCUDA).dtype(at::{pytorch_dtype}).requires_grad(false)); + at::Tensor dy = at::from_blob(_dY, x_shape, x_strides, [](void*){{}}, at::TensorOptions().device(at::kCUDA).dtype(at::{pytorch_dtype}).requires_grad(false)); + at::Tensor dw = at::from_blob(_dW, w_shape, w_strides, [](void*){{}}, at::TensorOptions().device(at::kCUDA).dtype(at::{pytorch_dtype}).requires_grad(false)); + at::Tensor dx = at::from_blob(_dX, x_shape, x_strides, [](void*){{}}, at::TensorOptions().device(at::kCUDA).dtype(at::{pytorch_dtype}).requires_grad(false)); + + std::vector kernel_shape = {{ {", ".join(map(str, forward_node.kernel_shape))} }}; + std::vector conv_strides = {{ {", ".join(map(str, forward_node.strides))} }}; + std::vector padding = {{ {", ".join(map(str, forward_node.pads[::2]))} }}; + std::vector dilation = {{ {", ".join(map(str, forward_node.dilations))} }}; + + at::thnn_conv_depthwise2d_backward_out(dx, dw, dy, x, w, kernel_shape, conv_strides, padding, dilation); + """ + + tasklet = nstate.add_tasklet(name=unique_id, + inputs=tasklet_inputs, + outputs=tasklet_outputs, + code=tasklet_code, + language=dace.dtypes.Language.CPP, + code_global=code_global, + code_init=init_code, + code_exit=finalize_code) + tasklet.environments = {dace.libraries.torch.environments.PyTorch.full_class_path()} + + nstate.add_edge(nstate.add_read(result.given_grad_names["Y"]), None, tasklet, f"_dY", + nsdfg.make_array_memlet((result.given_grad_names["Y"]))) + for name in sorted(required_forward_inputs): + nstate.add_edge(nstate.add_read(name), None, tasklet, f"_{name}", nsdfg.make_array_memlet(name)) + + for name in sorted(required_gradients): + arr_name = result.required_grad_names[name] + nstate.add_edge(tasklet, f"_d{name}", nstate.add_write(arr_name), None, nsdfg.make_array_memlet(arr_name)) + + inputs = {result.given_grad_names["Y"]}.union(required_forward_inputs) + outputs = {result.required_grad_names[n] for n in sorted(required_gradients)} + node = context.backward_state.add_nested_sdfg(nsdfg, inputs, outputs) + + return node, result diff --git a/dace/autodiff/library/__init__.py b/dace/autodiff/library/__init__.py new file mode 100644 index 0000000000..2a56e34067 --- /dev/null +++ b/dace/autodiff/library/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Library Integration for Automatic Differentiation. + +This package provides integration between DaCe's autodiff system and various +libraries and frontends. It enables differentiation of code that uses +library operations and provides hooks for frontend-specific optimizations. +""" + +import dace.library + +from . import library + +# PyTorch integrations are optional +try: + from . import torch_integration + from dace.frontend.python.replacements import torch_autodiff + TORCH_INTEGRATION_AVAILABLE = True +except ImportError: + torch_integration = None + torch_autodiff = None + TORCH_INTEGRATION_AVAILABLE = False + +dace.library.register_library(__name__, "autodiff") + +__all__ = [ + "library", +] + +if TORCH_INTEGRATION_AVAILABLE: + __all__.extend(["torch_integration", "torch_autodiff"]) diff --git a/dace/autodiff/library/library.py b/dace/autodiff/library/library.py new file mode 100644 index 0000000000..b5e2a60e98 --- /dev/null +++ b/dace/autodiff/library/library.py @@ -0,0 +1,188 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Dace library for autodiff + +Includes the BackwardPass library node, and the replacements for the python frontend +""" +from typing import Dict, Set, Optional +import copy + +import dace +import dace.library +from dace import data, properties +from dace.transformation import transformation as pm +from dace.sdfg import SDFG, SDFGState, graph, nodes + +from dace.autodiff import backward_pass_generator as engine, analysis as autodiff_analysis +from dace.autodiff.utils import init_grad +from dace.sdfg.utils import in_edge_with_name +from dace.transformation.passes.analysis import AccessSets + +# Import ParameterArray from the data package for backward compatibility +from dace.data.ml import ParameterArray + + +@dace.library.expansion +class ExpandBackwardPass(pm.ExpandTransformation): + environments = [] + + @staticmethod + def expansion(node: 'BackwardPass', state: SDFGState, sdfg: SDFG): + + node.validate(sdfg, state) + + in_array_name = lambda connector_name: in_edge_with_name(node, state, connector_name).data.data + + array_grad_map = {} + + access_sets = AccessSets().apply_pass(sdfg, {}) + + nsdfg = SDFG("backward_" + sdfg.label) + + # Check for other BackwardPasses that also compute the same gradients as us + node.propagate_conflicts(sdfg, state) + + # get the names of the output arrays in the forward pass + given_gradients = node.outer_names_given_gradients(state) + + array_grad_map.update(node.required_gradients) + array_grad_map.update((in_array_name(value_conn_name), grad_conn_name) + for grad_conn_name, value_conn_name in node.given_gradients.items()) + + # remove the non-grad arrays as inputs from the forward pass; + # they were also just added for control dependencies + for forward_non_grad_conn_name in node.given_gradients.values(): + for edge in list(state.in_edges_by_connector(node, forward_non_grad_conn_name)): + state.remove_edge(edge) + if state.in_degree(edge.src) + state.out_degree(edge.src) == 0: + state.remove_node(edge.src) + node.remove_in_connector(forward_non_grad_conn_name) + + gen = engine.BackwardPassGenerator(sdfg=sdfg, + given_gradients=given_gradients, + required_gradients=node.required_gradients.keys(), + backward_sdfg=nsdfg, + array_grad_map=array_grad_map, + conflicted_gradient_buffers=node._conflicted_gradients) + + _, _, required_forwarded_values = gen.backward() + + # Add zero initialization for all gradients which we are the first to compute + for outer_edge in state.out_edges(node): + gradient_we_are_writing: str = outer_edge.data.data + is_written_with_wcr = any(edge.data.wcr is not None and edge.data.data == outer_edge.src_conn + for edge, _ in nsdfg.all_edges_recursive() + if isinstance(edge, graph.MultiConnectorEdge)) + + anyone_written_before_us = autodiff_analysis.is_previously_written(sdfg, + state, + node, + gradient_we_are_writing, + access_sets=access_sets) + if not anyone_written_before_us and is_written_with_wcr: + init_grad(gradient_we_are_writing, sdfg, state) + + for name in required_forwarded_values: + # get the access to the forwarded_value + # there should only be one since we don't allow inplace modification + n = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data == name] + if len(n) > 1: + raise ValueError( + "Expected only one access node for forwarded value, does the graph have in-place modification?") + elif len(n) == 0: + n = state.add_read(name) + else: + n = n[0] + + node.add_in_connector(name) + state.add_edge(n, None, node, name, sdfg.make_array_memlet(name)) + + nsdfg.validate() + + return nsdfg + + +@dace.library.node +class BackwardPass(nodes.LibraryNode): + """ + The BackwardPass library node expands to an implementation of a + BackwardPass that computes the requested gradients. + + These gradients are computed using the DaCe autograd engine. + + The gradient will be computed for each array in the output connectors. + For this, the names of the output connectors must match the name of the + array for which the gradient is to be computed. + """ + + # Global properties + implementations = { + "differentiate": ExpandBackwardPass, + } + default_implementation = "differentiate" + + given_gradients = properties.DictProperty( + key_type=str, + value_type=str, + desc="Mapping between connector names of the given gradients and the names of the arrays they correspond to.") + required_gradients = properties.DictProperty( + key_type=str, + value_type=str, + desc= + "Mapping from array name for which a gradient should be computed to the name of the connector that will receive the gradient." + ) + + _conflicted_gradients = properties.SetProperty( + element_type=str, + desc="Keys from required_gradients for which the gradients are also computed elsewhere, and thus writes to the " + " buffer need to be with write-conflict-resolution. Note: this field is automatically populated upon expansion." + ) + + def __init__(self, name, given_gradients: Dict[str, str], *args, **kwargs): + super().__init__(name, *args, **kwargs) + self.given_gradients = given_gradients + self.required_gradients = {} + + def outer_names_given_gradients(self, state: SDFGState) -> Set[str]: + """ + Returns the names of the arrays that are passed as given gradients. + """ + in_array_name = lambda connector_name: in_edge_with_name(self, state, connector_name).data.data + return set(map(in_array_name, self.given_gradients.values())) + + def propagate_conflicts(self, sdfg: SDFG, state: SDFGState): + """ + Across this SDFG, check for other BackwardPasses that also compute the same gradients as us. + + If there are multiple BackwardPasses that compute the same gradients, update their list of conflicts. + """ + + ours = set(self.required_gradients) + + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, BackwardPass): + if node is self: + continue + conflicts = ours.intersection(node.required_gradients) + if conflicts: + self._conflicted_gradients |= conflicts + node._conflicted_gradients |= conflicts + + def validate(self, sdfg, state): + # Check that there is a correspondence between given gradients and inputs + all_inputs = set(self.in_connectors) + for given_grad, tensor_name in self.given_gradients.items(): + if given_grad not in all_inputs: + raise ValueError("Given gradient '{}' is not an input of the node".format(given_grad)) + + all_inputs.remove(given_grad) + all_inputs.remove(tensor_name) + + if all_inputs: + raise ValueError("The following in connectors were not included in given_gradients: {}".format( + ', '.join(all_inputs))) + + # Check that we are computing at least one gradient + if len(self.out_connectors) == 0: + raise ValueError("BackwardPass node '{}' does not compute any gradients".format(self.name)) diff --git a/dace/autodiff/library/torch_integration.py b/dace/autodiff/library/torch_integration.py new file mode 100644 index 0000000000..fe9e150dac --- /dev/null +++ b/dace/autodiff/library/torch_integration.py @@ -0,0 +1,39 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Hooks for PyTorch tensors to make them compatible with dace +""" +import copy + +try: + import torch + TORCH_AVAILABLE = True +except ImportError: + torch = None + TORCH_AVAILABLE = False + +from dace import data + +from dace.autodiff.library.library import ParameterArray + +if TORCH_AVAILABLE: + + def create_descriptor_tensor(self: torch.Tensor) -> data.Data: + """ + Creates a descriptor for a tensor. + If the tensor requires grad, we convert to a ParameterArray + """ + + desc = data.create_datadescriptor(self, no_custom_desc=True) + if not isinstance(desc, data.Array): + raise ValueError("Unsupported descriptor: {}".format(desc)) + + if not self.requires_grad: + return desc + + new_desc = copy.deepcopy(desc) + new_desc.__class__ = ParameterArray + new_desc.gradient = None + return new_desc + + # register with pytorch + torch.Tensor.__descriptor__ = create_descriptor_tensor diff --git a/dace/autodiff/torch.py b/dace/autodiff/torch.py new file mode 100644 index 0000000000..754a77a0a2 --- /dev/null +++ b/dace/autodiff/torch.py @@ -0,0 +1,124 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from typing import Tuple, Dict, List + +import dace +from dace import data as dt + +from dace.autodiff.backward_pass_generator import BackwardPassGenerator +from dace.autodiff.base_abc import AutoDiffException, BackwardResult + +try: + from dace.libraries.onnx.converters import clean_onnx_name + from dace.frontend.ml.onnx import ONNXModel + ONNX_AVAILABLE = True +except ImportError: + ONNX_AVAILABLE = False + clean_onnx_name = None + ONNXModel = None + + +def make_backward_function( + model, # ONNXModel type hint removed for optional import + required_grads: List[str], +) -> Tuple[dace.SDFG, dace.SDFG, BackwardResult, Dict[str, dt.Data]]: + """ Convert an ONNXModel to a PyTorch differentiable function. This method should not be used on its own. + Instead use the ``backward=True`` parameter of :class:`dace.ml.DaceModule`. + + :param model: the model to convert. + :param required_grads: the list of inputs names of the module that we must compute gradients for. + :return: A 4-tuple of forward SDFG, backward SDFG, backward result, and input arrays for + backward pass (as mapping of names to DaCe data descriptors). + """ + if not ONNX_AVAILABLE: + raise ImportError("make_backward_function requires ONNX. Install with: pip install dace[ml]") + + if len(model.sdfg.nodes()) != 1: + raise AutoDiffException("Expected to find exactly one SDFGState, found {}".format(len(model.sdfg.nodes()))) + + forward_sdfg = model.sdfg + + backward_sdfg = dace.SDFG(forward_sdfg.name + "_backward") + + gen = BackwardPassGenerator(sdfg=forward_sdfg, + given_gradients=[clean_onnx_name(name) for name in model.outputs], + required_gradients=required_grads, + backward_sdfg=backward_sdfg) + + backward_result, backward_grad_arrays, backward_input_arrays = gen.backward() + + replaced_scalars = {} + + # get the forward state + forward_state = forward_sdfg.nodes() + # A loaded pytorch model should only have one state + if len(forward_state) != 1: + raise AutoDiffException(f"Expected forward SDFG to have exactly one state, found {len(forward_state)}") + forward_state = forward_state[0] + + # get the backward state + backward_state = backward_sdfg.nodes() + # A loaded pytorch model should only have one state + if len(backward_state) != 1: + raise AutoDiffException(f"Expected backward SDFG to have exactly one state, found {len(backward_state)}") + backward_state = backward_state[0] + + for name, desc in backward_input_arrays.items(): + if name not in forward_sdfg.arrays: + raise AutoDiffException("Expected to find array with name '{}' in SDFG".format(name)) + + forward_desc = forward_sdfg.arrays[name] + # we will save this output and pass it to the backward pass + + # Views should not be forwarded. Instead the backward pass generator should forward the source of the view, + # and rebuild the sequence of required views in the backward pass. + if type(forward_desc) is dt.View: + raise AutoDiffException( + f"Cannot forward View '{name}' to backward pass. " + "Views should not be forwarded; the backward pass generator should forward " + "the source of the view and rebuild the sequence of required views in the backward pass.") + if isinstance(forward_desc, dt.Scalar): + # we can't return scalars from SDFGs, so we add a copy to an array of size 1 + fwd_arr_name, _ = forward_sdfg.add_array(name + "_array", [1], + forward_desc.dtype, + transient=False, + storage=forward_desc.storage, + find_new_name=True) + bwd_arr_name, bwd_desc = backward_sdfg.add_array(name + "_array", [1], + forward_desc.dtype, + transient=False, + storage=forward_desc.storage, + find_new_name=True) + backward_sdfg.arrays[name].transient = True + + fwd_copy_state = forward_sdfg.add_state_after(forward_state, label="copy_out_" + fwd_arr_name) + bwd_copy_state = backward_sdfg.add_state_before(backward_state, label="copy_in_" + bwd_arr_name) + fwd_copy_state.add_edge(fwd_copy_state.add_read(name), None, fwd_copy_state.add_write(fwd_arr_name), None, + dace.Memlet(name + "[0]")) + + bwd_copy_state.add_edge(bwd_copy_state.add_read(bwd_arr_name), None, bwd_copy_state.add_write(name), None, + dace.Memlet(name + "[0]")) + replaced_scalars[name] = (bwd_arr_name, bwd_desc) + else: + forward_sdfg.arrays[name].transient = False + + for orig_name, (replaced_name, replaced_desc) in replaced_scalars.items(): + del backward_input_arrays[orig_name] + backward_input_arrays[replaced_name] = replaced_desc + + for fwd_name, bwd_name in backward_result.required_grad_names.items(): + desc = backward_sdfg.arrays[bwd_name] + if isinstance(desc, dt.Scalar): + arr_name, arr_desc = backward_sdfg.add_array(bwd_name + "_array", [1], + desc.dtype, + transient=False, + storage=desc.storage, + find_new_name=True) + desc.transient = True + bwd_copy_state = backward_sdfg.add_state_after(backward_state, label="copy_out_" + bwd_name) + bwd_copy_state.add_edge(bwd_copy_state.add_read(bwd_name), None, bwd_copy_state.add_write(arr_name), None, + dace.Memlet(bwd_name + "[0]")) + backward_result.required_grad_names[fwd_name] = arr_name + + backward_sdfg.validate() + + return forward_sdfg, backward_sdfg, backward_result, backward_input_arrays diff --git a/dace/autodiff/utils.py b/dace/autodiff/utils.py new file mode 100644 index 0000000000..d1693e3ef4 --- /dev/null +++ b/dace/autodiff/utils.py @@ -0,0 +1,910 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import ast +import collections +import copy +import inspect +import numbers +import re +from typing import Dict, List, Set, Tuple, Union + +import astunparse +import sympy as sp + +# DaCe imports +import dace +import dace.sdfg.utils as utils +from dace import dtypes +from dace import data as dt +from dace.frontend.python.parser import DaceProgram +from dace.sdfg import SDFG, SDFGState, graph as dgraph, nodes as nd, state as dstate +from dace.sdfg.state import LoopRegion + +# Autodiff imports +from dace.autodiff.base_abc import AutoDiffException, BackwardContext, BackwardResult + + +def forward_in_desc_with_name(forward_node: nd.Node, context: BackwardContext, name: str) -> dt.Data: + """Find the descriptor of the data that connects to input connector ``name``. + + :param forward_node: The node in the forward pass. + :param context: The backward context containing forward SDFG and state information. + :param name: The input connector name to find the descriptor for. + :return: The data descriptor that connects to the specified connector. + """ + return utils.in_desc_with_name(forward_node, context.forward_state, context.forward_sdfg, name) + + +def forward_out_desc_with_name(forward_node: nd.Node, context: BackwardContext, name: str) -> dt.Data: + """Find the descriptor of the data that connects to output connector ``name``. + + :param forward_node: The node in the forward pass. + :param context: The backward context containing forward SDFG and state information. + :param name: The output connector name to find the descriptor for. + :return: The data descriptor that connects to the specified connector. + """ + return utils.out_desc_with_name(forward_node, context.forward_state, context.forward_sdfg, name) + + +def add_backward_desc_for_connector(backward_sdfg: dace.SDFG, forward_node: nd.Node, context: BackwardContext, + connector: str, input: bool) -> str: + """Adds the backward array for the connector of ``forward_node``. + + :param backward_sdfg: The SDFG to add the backward array descriptor to. + :param forward_node: The forward node with the connector to create a descriptor for. + :param context: The backward context containing forward SDFG and state information. + :param connector: The connector name on the forward node. + :param input: True if the connector is an input, False if it's an output. + :return: The name of the newly added gradient array in ``backward_sdfg``. + """ + + if input: + edge = utils.in_edge_with_name(forward_node, context.forward_state, connector) + else: + edge = utils.out_edge_with_name(forward_node, context.forward_state, connector) + arr_name = edge.data.data + + forward_desc = context.forward_sdfg.arrays[arr_name] + + new_desc = copy.deepcopy(forward_desc) + new_desc.transient = False + return backward_sdfg.add_datadesc(arr_name + "_grad", new_desc, find_new_name=True) + + +def add_backward_desc(backward_sdfg: dace.SDFG, forward_sdfg: dace.SDFG, forward_desc: dt.Data, + forward_name: str) -> str: + """Adds the backward array for the given descriptor. + + :param backward_sdfg: The SDFG to add the backward array descriptor to. + :param forward_sdfg: The forward SDFG used for finding unique names. + :param forward_desc: The data descriptor of the forward array. + :param forward_name: A name for the forward array (doesn't have to match its actual name). + :return: The name of the newly added gradient array in ``backward_sdfg``. + """ + backward_name = dt.find_new_name(forward_name + "_grad", forward_sdfg.arrays) + new_desc = copy.deepcopy(forward_desc) + new_desc.transient = False + return backward_sdfg.add_datadesc(backward_name, new_desc) + + +def add_empty_sdfg_for_node(forward_node: nd.Node, required_descriptors: List[str], + context: BackwardContext) -> Tuple[nd.NestedSDFG, BackwardResult]: + """ Given a node, return an SDFG that can be used as a nested SDFG expansion for that node. + + ``required_descriptors`` may contain: + * Inputs/outputs of the forward node (these will be hooked up as required) + * Inputs/outputs of the forward node with the ``_grad`` suffix. These will be hooked up + as the gradients of the corresponding inputs/outputs. + + The descriptors will be initialized using the descriptors connected to edges of the + forward node. + + :param forward_node: the node in the forward pass + :param required_descriptors: A list of descriptors that should be added to the SDFG. + :param context: the backward context + :return: the nested SDFG and backward result for the forward node + """ + + nsdfg = dace.SDFG(forward_node.label + "_backward_expansion") + + def _get_fwd_descriptor(name): + """Returns the descriptor and whether it is an input""" + if name in forward_node.out_connectors: + return forward_out_desc_with_name(forward_node, context, name), False + elif name in forward_node.in_connectors: + return forward_in_desc_with_name(forward_node, context, name), True + + raise ValueError(f"Could not find {name} in inputs or outputs of {forward_node}") + + outputs_to_connect_from_forward = [] + + result = BackwardResult.empty() + inputs = set() + outputs = set() + + for name in required_descriptors: + if name.endswith("_grad"): + # hook this up as a gradient + desc, is_input = _get_fwd_descriptor(name[:-5]) + if is_input: + result.required_grad_names[name[:-5]] = name + else: + result.given_grad_names[name[:-5]] = name + # input grads are outputs of the backward node + if is_input: + outputs.add(name) + else: + inputs.add(name) + else: + desc, is_input = _get_fwd_descriptor(name) + if not is_input: + outputs_to_connect_from_forward.append(name) + inputs.add(name) + ndesc = copy.deepcopy(desc) + ndesc.transient = False + nsdfg.add_datadesc(name, ndesc) + + bwd_node = context.backward_state.add_nested_sdfg(nsdfg, inputs, outputs) + for output in outputs_to_connect_from_forward: + connect_output_from_forward(forward_node, bwd_node, context, output) + + return bwd_node, result + + +def backward_program_for_node(program, context: BackwardContext, + forward_node: nd.Node) -> Tuple[nd.Node, BackwardResult]: + """ Expand a function to the backward function for a node. + + The dtypes for the arguments will be extracted by matching the parameter names to edges. + + Gradient parameters should be the name of the forward parameter, appended with _grad. For these arguments the + data descriptors will match the data descriptors of the inputs/outputs they correspond to. + """ + + input_names = set(inp.name for inp in forward_node.schema.inputs) + output_names = set(outp.name for outp in forward_node.schema.outputs) + + if input_names.intersection(output_names): + # this is currently the case for only one onnx op + raise ValueError("program_for_node cannot be applied on nodes of this type;" + " '{}' is both an input and an output".format(next(input_names.intersection(output_names)))) + + def name_without_grad_in(name, collection): + return name[-5:] == "_grad" and name[:-5] in collection + + params = inspect.signature(program).parameters + + backward_result = BackwardResult.empty() + + inputs = {} + outputs = {} + for name, _ in params.items(): + if name in input_names: + inputs[name] = copy.deepcopy(forward_in_desc_with_name(forward_node, context, name)) + + elif name_without_grad_in(name, input_names): + outputs[name] = copy.deepcopy(forward_in_desc_with_name(forward_node, context, name[:-5])) + backward_result.required_grad_names[name[:-5]] = name + + elif name in output_names: + inputs[name] = copy.deepcopy(forward_out_desc_with_name(forward_node, context, name)) + + elif name_without_grad_in(name, output_names): + inputs[name] = copy.deepcopy(forward_out_desc_with_name(forward_node, context, name[:-5])) + backward_result.given_grad_names[name[:-5]] = name + + else: + raise ValueError("'{}' was not found as an input or output for {}".format(name, forward_node.schema.name)) + + program.__annotations__ = {**inputs, **outputs} + + sdfg = DaceProgram(program, (), {}, False, dace.DeviceType.CPU).to_sdfg() + + result_node = context.backward_state.add_nested_sdfg(sdfg, set(inputs), set(outputs)) + + return result_node, backward_result + + +def connect_output_from_forward(forward_node: nd.Node, backward_node: nd.Node, context: BackwardContext, + output_connector_name: str): + """ Connect an output of the forward node as an input to the backward node. This is done by forwarding the array + from the forward pass. + + Conceptually, this is similar to pytorch's ctx.save_for_backward. + + :param forward_node: the node in the forward pass. + :param backward_node: the node in the backward pass. + :param context: the backward context. + :param output_connector_name: the name of the connector on the backward pass. The output of that connector will + be forwarded to the connector of the same name on the backward node. + """ + output_edge = utils.out_edge_with_name(forward_node, context.forward_state, output_connector_name) + + # add the array of the output to backward_input_arrays that it will be forwarded by the autodiff engine + output_arr_name = output_edge.data.data + if output_arr_name not in context.backward_generator.backward_input_arrays: + data_desc = copy.deepcopy(context.forward_sdfg.arrays[output_arr_name]) + context.backward_generator.backward_input_arrays[output_arr_name] = data_desc + + if context.backward_generator.separate_sdfgs: + data_desc.transient = False + context.backward_sdfg.add_datadesc(output_arr_name, data_desc) + + read = context.backward_state.add_read(output_arr_name) + else: + cand = [ + n for n, _ in context.backward_state.all_nodes_recursive() + if isinstance(n, nd.AccessNode) and n.data == output_arr_name + ] + read = cand[0] + context.backward_state.add_edge(read, None, backward_node, output_connector_name, copy.deepcopy(output_edge.data)) + + +def cast_consts_to_type(code: str, dtype: dace.typeclass) -> str: + """Convert a piece of code so that constants are wrapped in casts to ``dtype``. + + For example:: + + x * (3 / 2) + + becomes:: + + x * (dace.float32(3) / dace.float32(2)) + + This is only done when it is required due to a Div operator to ensure proper + type casting in mathematical expressions during automatic differentiation. + + :param code: The code string to convert. + :param dtype: The DaCe typeclass to cast constants to. + :return: A string of the converted code with properly typed constants. + """ + + class CastConsts(ast.NodeTransformer): + + def __init__(self): + self._in_div_stack = collections.deque() + + def visit_Num(self, node): + if self._in_div_stack: + return ast.copy_location( + ast.parse(f"dace.{dtype.to_string()}({astunparse.unparse(node)})").body[0].value, node) + else: + return self.generic_visit(node) + + def visit_BinOp(self, node: ast.BinOp): + if node.op.__class__.__name__ == "Pow": + # within pow, we don't need to cast unless there is a new div + old_stack = self._in_div_stack + # reset the stack + self._in_div_stack = collections.deque() + node = self.generic_visit(node) + self._in_div_stack = old_stack + return node + + elif node.op.__class__.__name__ == "Div": + self._in_div_stack.append(None) + node = self.generic_visit(node) + self._in_div_stack.popleft() + return node + else: + return self.generic_visit(node) + + def visit_Constant(self, node): + if self._in_div_stack: + return ast.copy_location( + ast.parse(f"dace.{dtype.to_string()}({astunparse.unparse(node)})").body[0].value, node) + else: + return self.generic_visit(node) + + return astunparse.unparse(CastConsts().visit(ast.parse(code))) + + +def init_grad(data: str, sdfg: SDFG, current_state: SDFGState) -> None: + """Add a state where ``data`` is initialized with zero. + + This function creates a new state before the current state that initializes + the gradient array with zeros. It handles different storage types (CPU/GPU) + and array types appropriately. + + :param data: The name of the data array to initialize. + :param sdfg: The SDFG to add the initialization state to. + :param current_state: The current state; initialization will be done before this state. + :raises ValueError: If the storage type is not supported. + :raises AutoDiffException: If the data descriptor type is not supported. + """ + arr = sdfg.arrays[data] + + state = sdfg.add_state_before(current_state, label="init_" + data) + + scalar = 0 + if dtypes.can_access(dtypes.ScheduleType.CPU_Multicore, arr.storage): + cuda = False + elif dtypes.can_access(dtypes.ScheduleType.GPU_Default, arr.storage): + cuda = True + else: + raise ValueError(f"Unsupported storage {arr.storage}") + + if isinstance(arr, (dt.Array, dt.Scalar)): + state.add_mapped_tasklet( + "_init_" + data + "_", { + "i{}".format(i): "0:{}".format(shape) + for i, shape in enumerate(arr.shape) + }, {}, + "__out = {}".format(scalar), + {"__out": dace.Memlet.simple(data, ", ".join("i{}".format(i) for i in range(len(arr.shape))))}, + schedule=dtypes.ScheduleType.GPU_Device if cuda else dtypes.ScheduleType.Default, + external_edges=True) + elif type(arr) is dt.View: + # not need to initialize: the viewed array will always be visited + # (since a view can never be a required grad), and thus the viewed array will be initialized. + pass + else: + raise AutoDiffException("Unsupported data descriptor {}".format(arr)) + + +def extract_indices(expression: str) -> Dict[str, List[str]]: + """Extracts indexed array names and their indices from a given string expression. + + This function uses regular expressions to find patterns like "array[i, j, k]" + and returns a dictionary mapping array names to their index lists. + + :param expression: The string expression to analyze. + :return: A dictionary mapping array names to lists of their indices. + + Example:: + + >>> extract_indices("a[i, j] + b[k]") + {'a': ['i', 'j'], 'b': ['k']} + """ + # Regular expression to match the array names and their indices + pattern = r"(\w+)\[((?:\w+,?\s*)+)\]" + + # Find all matches in the given expression + matches = re.findall(pattern, expression) + + # Create a dictionary to store the arrays and their indices + index_map = {} + for name, indices in matches: + # Split indices by comma and remove any extra spaces + index_list = [index.strip() for index in indices.split(',')] + index_map[name] = index_list + + return index_map + + +def code_to_exprs(code: str, tasklet: nd.Tasklet, + symbols: List[str]) -> Tuple[Dict[str, sp.Expr], Dict[str, List[str]]]: + """ Convert a python string to a set of (simplified) symbolic sympy expressions. Currently, this + supports only code consisting of assignment statements. + + :param code: the code to convert + :param inputs: the inputs (i.e. the defined variables) for the code + :param outputs: the outputs to generate simplified expressions for + :return: map from outputs to symbolic expressions + """ + + inputs: List[str] = list(tasklet.in_connectors) + outputs: List[str] = list(tasklet.out_connectors) + + # Add the definition of global constant symbols that are presen in the code + # Prepare the Symbol declaration code + symbol_code = "" + for symb in symbols: + symbol_code += f" {symb} = sp.symbols('{symb}')\n" + + # We prepare a map of indexed objects and their indices + indexed_objects_map = extract_indices(code) + + # For now, make sure none of the outputs are indexed objects + indexed_outputs = [out for out in outputs if out in indexed_objects_map] + if indexed_outputs: + raise AutoDiffException(f"Indexed outputs are not currently supported: {indexed_outputs}") + + # Add the definition of indexed objects to the sympy code + indexed_objects_code = "" + for conn in inputs + outputs: + if (conn in inputs and isinstance(tasklet.in_connectors[conn], dace.dtypes.pointer) + or (conn in outputs and isinstance(tasklet.out_connectors[conn], dace.dtypes.pointer))): + if conn not in indexed_objects_map: + raise AutoDiffException(f"Expected connector '{conn}' to be in indexed objects map for pointer type") + indexed_objects_code += f" {conn} = sp.IndexedBase('{conn}')\n" + for idx in indexed_objects_map[conn]: + indexed_objects_code += f" {idx} = sp.symbols('{idx}', cls=sp.Idx)\n" + + code_fn = """ +def symbolic_execution({}): + # define functions from cmath.h + from sympy import exp, log + def log2(x): + return log(x, 2) + def log10(x): + return log(x, 10) + from sympy import sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh + from sympy import sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh + from sympy import Pow as pow, sqrt + from sympy import sign, floor, ceiling as ceil, Abs as abs, Abs as fabs + from sympy import Max as max, Min as min + from sympy import Max as fmax, Min as fmin + from sympy import erf + import sympy as sp +{} +{} +{} + return {} + """ + code_fn = code_fn.format( + ", ".join(inputs), + symbol_code, + indexed_objects_code, + "\n".join(" " + line.strip() for line in code.split("\n")), + ", ".join(outputs), + ) + + # Clean out type conversions from the code + code_fn = re.sub(r"dace\.(float32|int32|float64|int64)\((.*?)\)", r"\2", code_fn) + + try: + # need to have dace so things like `dace.float32(1)` work + temp_globals = {'dace': dace} + exec(code_fn, temp_globals) + + # no idea why, but simply calling symbolic_execution doesn't work + results = temp_globals["symbolic_execution"](*[sp.symbols(inp) for inp in inputs]) + + if len(outputs) > 1: + # make sure that everything is a sympy expression + for i, res in enumerate(results): + if not isinstance(res, sp.Expr): + results[i] = sp.sympify(res) + return dict(zip(outputs, results)), indexed_objects_map + else: + # make sure that everything is a sympy expression + if not isinstance(results, sp.Expr): + results = sp.sympify(results) + return {outputs[0]: results}, indexed_objects_map + except Exception as e: + raise AutoDiffException( + "Exception occurred while attempting to symbolically execute code:\n{}".format(code)) from e + + +def is_int_eq_value(value, target_value: int) -> bool: + if isinstance(value, numbers.Integral): + return value == target_value + + if len(value.free_symbols) > 0 or int(value) != target_value: + return False + + return True + + +def invert_map_connector(conn: str) -> str: + if conn.startswith("IN"): + return "OUT" + conn[2:] + elif conn.startswith("OUT"): + return "IN" + conn[3:] + else: + raise AutoDiffException("Could not parse map connector '{}'".format(conn)) + + +def path_src_node_in_subgraph(edge: dgraph.MultiConnectorEdge, subgraph: dstate.StateSubgraphView) -> bool: + path_src = subgraph.memlet_path(edge)[0].src + return path_src in subgraph.nodes() + + +def get_read_only_arrays(sdfg: SDFG) -> Set[str]: + """Get the arrays that are only read in SDFG. + + This function identifies arrays that are never written to (only have outgoing + edges with data or only empty memlets on incoming edges). + + :param sdfg: The SDFG to analyze. + :return: A set of array names that are read-only in the SDFG. + """ + written_to_arrays = set() + for node, parent in sdfg.all_nodes_recursive(): + if isinstance(node, nd.AccessNode): + if parent.in_degree(node) > 0 and any(not e.data.is_empty() for e in parent.in_edges(node)): + written_to_arrays.add(node.data) + + read_only_arrays = set(sdfg.arrays.keys()) - written_to_arrays + return read_only_arrays + + +def get_state_topological_order(graph) -> List[SDFGState]: + """ + Returns the SDFG states in topological order. + """ + all_nodes = list(utils.dfs_topological_sort(graph, graph.source_nodes())) + state_order = [] + for node in all_nodes: + if isinstance(node, SDFGState): + state_order.append(node) + elif isinstance(node, LoopRegion): + loop_state_order = get_state_topological_order(node) + state_order.extend(loop_state_order) + else: + raise AutoDiffException( + f"Unsupported node type {node} at the highest level of the SDFG while getting the state order") + + # All states in the graph need to be present in the state order + if isinstance(graph, SDFG) and set(state_order) != set(graph.states()): + raise AutoDiffException("Could not find all states of the SDFG in the state order") + return state_order + + +def shape_has_symbols_to_replace(sdfg: SDFG, shape: Union[str, sp.Symbol, sp.Expr]) -> bool: + """ + Check if the shape dimension passed as a parameter has a symbol that needs to be replaced. + We do not replace global SDFG symbols but rather the loop indices only. + """ + defined_symbols = sdfg.free_symbols | set(sdfg.arg_names) + if isinstance(shape, str): + shape = dace.symbolic.pystr_to_symbolic(shape) + return dace.symbolic.issymbolic(shape, defined_symbols) + + +def get_loop_end(start: str, end: str, loop: LoopRegion) -> str: + """ + Get the smallest and largest index of a loop given the start and end values. + This is an attempt at estimating the number of iterations of the loop. + """ + start_sym = dace.symbolic.pystr_to_symbolic(start) + end_sym = dace.symbolic.pystr_to_symbolic(end) + if not dace.symbolic.issymbolic(start_sym) and not dace.symbolic.issymbolic(end_sym): + int_start, int_end = int(start_sym), int(end_sym) + if int_start < int_end: + # Increasing loop + largest_index = int_end + smallest_index = int_start + else: + # Decreasing loop e.g., range(6, -1, -1) + # Since the start will be the first index there are start+1 iterations + largest_index = int_start + 1 + smallest_index = int_end + else: + # We check using the update statement + change = analyze_loop_change(loop.update_statement.as_string, loop.loop_variable) + if change == "increase": + # Increasing loop + largest_index = end + smallest_index = start + else: + # Decreasing loop + # Since the start will be the first index there are start+1 iterations + largest_index = start + "+1" + smallest_index = end + + return smallest_index, largest_index + + +def analyze_loop_change(code: str, loop_variable: str) -> str: + """Analyze if the given loop variable in the provided code increases or decreases. + + :param code: The Python code to analyze. + :param loop_variable: The name of the loop variable to analyze. + :return: ``'increase'``, ``'decrease'``, or ``'unknown'``. + """ + tree = ast.parse(code) + change_type = "unknown" + + for node in ast.walk(tree): + # Look for assignment statements + if isinstance(node, ast.Assign): + # Ensure the assignment targets the loop variable + if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): + target = node.targets[0].id + if target == loop_variable and isinstance(node.value, ast.BinOp): + # Check for `loop_variable = loop_variable + ...` + if isinstance(node.value.left, ast.Name) and node.value.left.id == loop_variable: + # Analyze the right-hand side for increase or decrease + rhs = node.value.right + if isinstance(rhs, ast.UnaryOp) and isinstance(rhs.op, ast.USub): # Unary negative + if isinstance(rhs.operand, ast.Constant) and isinstance(rhs.operand.value, (int, float)): + change_type = "decrease" + elif isinstance(rhs, ast.UnaryOp) and isinstance(rhs.op, ast.UAdd): # Unary positive + if isinstance(rhs.operand, ast.Constant) and isinstance(rhs.operand.value, (int, float)): + change_type = "increase" + elif isinstance(rhs, ast.Constant) and isinstance(rhs.value, (int, float)): + change_type = "increase" if rhs.value > 0 else "decrease" + if change_type == "unknown": + raise AutoDiffException(f"Could not determine loop variable change in code: {code}") + return change_type + + +def get_map_nest_information( + edges_list: List[dstate.MultiConnectorEdge]) -> Tuple[List, List[str], List, Dict[str, Tuple]]: + """ + """ + # First, get the shape of the new array + shape_list = [] + + # We will also need the starting range of the maps in the path + start_range = [] + + # And the names of the parameters of the maps in the path + param_list = [] + + for e in edges_list: + edge_src = e.src + if isinstance(edge_src, nd.MapEntry): + for rng in edge_src.map.range.ranges: + # the range contains the last index in the loop + # while we want the size so we add 1 + shape_list.append(rng[1] + 1) + start_range.append(rng[0]) + for par in edge_src.map.params: + param_list.append(par) + + if not (len(param_list) == len(shape_list) == len(start_range)): + raise AutoDiffException( + f"Mismatched lengths: params={len(param_list)}, shapes={len(shape_list)}, ranges={len(start_range)}") + + # Create a dictionary mapping parameters to their start and end ranges + param_dict = {param: (start, end) for param, start, end in zip(param_list, start_range, shape_list)} + return start_range, param_list, shape_list, param_dict + + +def get_all_path_edges(state: SDFGState, source: nd.Node, + starting_edge: dgraph.MultiConnectorEdge) -> List[dgraph.MultiConnectorEdge]: + """ + We will start from the target node and go back until we reach the destination. + Starting edge should be an in node + """ + all_edges = [] + memlet_path = state.memlet_path(starting_edge) + all_edges += memlet_path + first_source = memlet_path[0].src + if first_source == source: + return all_edges + + # If there is only one edge coming to the first node + if state.in_degree(first_source) == 1: + edge = state.in_edges(first_source)[0] + memlet_path = state.memlet_path(edge) + all_edges += memlet_path + first_source = memlet_path[0].src + if first_source == source: + return all_edges + + raise AutoDiffException("Can't easily find path. Upgrade function.") + + +def extract_conditional_expressions(tasklet_node: nd.Tasklet) -> Tuple[str, str, str]: + """ + Given a conditional tasklet node, extract the if and else expressions and return them with the conditional. + The else statement could be None in case there is only an if statement. The current supported formats are the following: + 1 - if cond: + out = expression_1 + which would return ("out = expression_1", None, "if cond") + 2- out = expression_1 if cond else expression 2 + """ + + tasklet_code = tasklet_node.code.as_string + + # check which type of assignment this is + if ":" in tasklet_code: + # get the conditional input connector through regular expression matching + matches = re.search(r"if (.)*:", tasklet_code) + if not matches: + raise AutoDiffException(f"Could not find 'if' statement in conditional tasklet code: {tasklet_code}") + conditional = matches.group() + + # remove the conditional from the code to get the expression + if_statement = tasklet_code.replace(conditional, "") + if_statement = if_statement.replace("\n", "") + + # remove indentation + if_statement = if_statement[3:] + + # extract the in connector only + conditional = conditional.replace(":", "") + conditional = conditional.replace("if ", "") + if conditional not in tasklet_node.in_connectors: + raise AutoDiffException( + f"Conditional '{conditional}' not found in tasklet input connectors: {list(tasklet_node.in_connectors.keys())}" + ) + + else_statement = None + + # match the out connector + matches = re.search(r"^(.)* =", if_statement) + if not matches: + raise AutoDiffException(f"Could not find output assignment in if statement: {if_statement}") + out_connector = matches.group() + + # remove the assignment from the if statement + if_statement = if_statement.replace(out_connector, "") + + # extract the out connector only + out_connector = out_connector[1:].replace(" =", "") + + else: + # get the conditional input connector through regular expression matching + matches = re.search(r"if (.)* else", tasklet_code) + if not matches: + raise AutoDiffException(f"Could not find 'if...else' statement in conditional tasklet code: {tasklet_code}") + conditional = matches.group() + + # extract the in connector only + conditional = conditional.replace("if ", "") + conditional = conditional.replace(" else", "") + + if conditional not in tasklet_node.in_connectors: + raise AutoDiffException( + f"Conditional '{conditional}' not found in tasklet input connectors: {list(tasklet_node.in_connectors.keys())}" + ) + + # get the if statement by matching what comes before the if until we encounter a parenthesis or = + matches = re.search(r"= \((.)* if", tasklet_code) + if not matches: + # try without the parenthesis + matches = re.search(r"= (.)* if", tasklet_code) + if not matches: + raise AutoDiffException(f"Could not find if expression pattern in tasklet code: {tasklet_code}") + + if_statement = matches.group() + + # extract the in statement only + if_statement = if_statement.replace("= (", "") + if_statement = if_statement.replace(" if", "") + + # get the else statement by matching the else and what comes after it until we encounter a parenthesis + matches = re.search(r"else (.)*\)", tasklet_code) + if not matches: + raise AutoDiffException(f"Could not find else expression pattern in tasklet code: {tasklet_code}") + else_statement = matches.group() + + # extract the in statement only + else_statement = else_statement.replace("else ", "") + + # remove the last closing parenthesis if it exists + if else_statement.endswith(")"): + else_statement = else_statement[:-1] + + # match the out connector + matches = re.search(r"^(.)* =", tasklet_code) + if not matches: + raise AutoDiffException(f"Could not find output assignment in tasklet code: {tasklet_code}") + out_connector = matches.group() + + # extract the in statement only + out_connector = out_connector.replace(" =", "") + + # sanity check this should be in the out connectors of the tasklet + if out_connector not in tasklet_node.out_connectors: + raise AutoDiffException( + f"Output connector '{out_connector}' not found in tasklet output connectors: {list(tasklet_node.out_connectors.keys())}" + ) + + # create the return expressions + if_expression = f"{out_connector} = {if_statement}" + else_expression = f"{out_connector} = {else_statement}" if else_statement else None + + return if_expression, else_expression, conditional + + +def check_edges_type_in_state(subgraph: dstate.StateSubgraphView) -> None: + """ + Check if all the edges in this state are of type float, int, or boolean. + """ + for edge, parent_subgraph in subgraph.all_edges_recursive(): + if isinstance(parent_subgraph, SDFGState): + parent_sdfg = parent_subgraph.parent + elif isinstance(parent_subgraph, dstate.StateSubgraphView): + parent_sdfg = parent_subgraph.graph.parent + elif isinstance(parent_subgraph, SDFG) or isinstance(parent_subgraph, LoopRegion): + # if there are any fancy things on the interstate edges we should probably throw an error + continue + else: + raise AutoDiffException("Unexpected subgraph structure") + + if edge.data.data: + edge_type = parent_sdfg.arrays[edge.data.data].dtype + if edge_type in [dace.string]: + raise AutoDiffException( + f"Expected Subgraph to differentiate to only contain float, int, and bool edges, but data {edge.data}" + f" on edge {edge} has type {edge_type}") + + +def state_within_loop(forward_state: SDFGState) -> Tuple[bool, LoopRegion]: + """ + Check if this state will be executed several times within a loop. + We check if any of the parents of this state is a loop region. + """ + parent = forward_state.parent_graph + while parent is not None: + if isinstance(parent, LoopRegion): + return True, parent + parent = parent.parent_graph + return False, None + + +class SympyCleaner(ast.NodeTransformer): + + def visit_Name(self, node): + if node.id == "pi": + return ast.copy_location(ast.parse("dace.math.pi").body[0].value, node) + return self.generic_visit(node) + + +def extract_loop_region_info(loop: LoopRegion) -> Tuple[str, str]: + """ + Use regular expression matching to extract the start and end of the loop region. + We only treat regular for-loops with incrementation and decrementation updates. + """ + + # Extract the loop iterator + it = loop.loop_variable + + # Extract the end of the loop from the conditional statement + conditional = loop.loop_condition.as_string + + stride_sign = get_stride_sign(loop) + + # If the stride is positive + if stride_sign > 0: + conditional_expression = fr".*{it} < .*" + else: + # If the stride is negative + conditional_expression = fr".*{it} > .*" + + # Match the conditional using regular expressions + matches = re.search(conditional_expression, conditional) + if not matches: + raise AutoDiffException(f"Could not match conditional expression '{conditional_expression}' in '{conditional}'") + expression = matches.group() + matches_inner = re.search(conditional_expression[:-2], conditional) + if not matches_inner: + raise AutoDiffException( + f"Could not match conditional pattern '{conditional_expression[:-2]}' in '{conditional}'") + expression_to_remove = matches_inner.group() + end = expression.replace(expression_to_remove, "") + + # TODO: need more generalized solution for functions in the loop bounds + if "floor" not in conditional: + # There is no function call in the statement, remove parenthesis + end = end.replace("(", "") + end = end.replace(")", "") + end = end.replace(" ", "") + else: + if expression_to_remove.startswith("(") and not expression_to_remove.endswith(")") and expression.endswith(")"): + # Remove extra parenthesis + end = end[:-1] + + # Get the start from the initialization code + init_code = loop.init_statement.as_string + matches = re.search(fr".*{it} = .*", init_code) + if not matches: + raise AutoDiffException(f"Could not find initialization pattern for loop variable '{it}' in '{init_code}'") + expression = matches.group() + matches = re.search(fr"{it} =", init_code) + if not matches: + raise AutoDiffException(f"Could not find assignment pattern for loop variable '{it}' in '{init_code}'") + expression_to_remove = matches.group() + start = expression.replace(expression_to_remove, "") + + # Remove parenthesis and space + start = start.replace("(", "") + start = start.replace(")", "") + start = start.replace(" ", "") + + return start, end + + +def get_stride_sign(loop: LoopRegion) -> int: + """Check if the stride for this loop is positive or negative. + + :param loop: The loop region to analyze. + :return: ``1`` if the stride is positive, ``-1`` if negative. + :raises AutoDiffException: If the loop has an unsupported structure. + """ + if loop.update_statement is None: + raise AutoDiffException("While loops are not yet supported in DaCe AD") + update_statement = loop.update_statement.as_string + if "-" in update_statement: + return -1 + if "+" in update_statement: + return 1 + + # unsupported loop structure + raise AutoDiffException(f"Expected the loop region {loop.label} to have a regular update statement." + f" Instead got: {update_statement}") diff --git a/dace/builtin_hooks.py b/dace/builtin_hooks.py index 2a5b49e983..b691cd0296 100644 --- a/dace/builtin_hooks.py +++ b/dace/builtin_hooks.py @@ -96,7 +96,8 @@ def _make_filter_function(filter: Optional[Union[str, Callable[[Any], bool]]], if isinstance(filter, str): # If a string was given, construct predicate based on wildcard name matching if with_attr: - filter_func = lambda elem: fnmatch.fnmatch(elem.name, filter) + filter_func = lambda elem: fnmatch.fnmatch(elem.name, filter) if hasattr(elem, 'name') else fnmatch.fnmatch( + elem.label, filter) else: filter_func = lambda elem: fnmatch.fnmatch(elem, filter) elif callable(filter): diff --git a/dace/codegen/CMakeLists.txt b/dace/codegen/CMakeLists.txt index e1a5e33947..7b93f60d64 100644 --- a/dace/codegen/CMakeLists.txt +++ b/dace/codegen/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. cmake_minimum_required(VERSION 3.17) project(dace_program) @@ -7,38 +7,16 @@ set(DACE_PROGRAM_NAME "dace_program" CACHE STRING "Name of DaCe program") set(DACE_SRC_DIR "" CACHE STRING "Root directory of generated code files") set(DACE_FILES "" CACHE STRING "List of host code files relative to the root of the source directory") set(DACE_LIBS "" CACHE STRING "Extra libraries") -set(HLSLIB_PART_NAME "${DACE_XILINX_PART_NAME}") +set(DACE_CPP_STANDARD "20" CACHE STRING "C++ standard to use for compilation (e.g., 14, 17, 20, 23, 26)") +set(DACE_CMAKE_FILES "" CACHE STRING "List of additional CMake files to include") # CUDA set(DACE_CUDA_ARCHITECTURES_DEFAULT "" CACHE STRING "Default CUDA architectures in case native not found") -# FPGA specific -set(DACE_FPGA_AUTOBUILD_BITSTREAM OFF CACHE STRING "Automatically build bitstreams if they are not present.") - -# Allow passing flags to various stages of Xilinx compilation process -set(DACE_XILINX_MODE "simulation" CACHE STRING "Type of compilation/execution [simulation/software_emulation/hardware_emulation/hardware].") -set(DACE_XILINX_HOST_FLAGS "" CACHE STRING "Extra flags to host code") -set(DACE_XILINX_SYNTHESIS_FLAGS "" CACHE STRING "Extra flags for performing high-level synthesis") -set(DACE_XILINX_BUILD_FLAGS "" CACHE STRING "Extra flags to xocc build phase") -set(DACE_XILINX_TARGET_CLOCK "" CACHE STRING "Target clock frequency of FPGA kernel") -set(DACE_XILINX_PART_NAME "xcu280-fsvh2892-2L-e" CACHE STRING "Xilinx chip to target from HLS") -set(DACE_XILINX_TARGET_PLATFORM "xilinx_u280_xdma_201920_1" CACHE STRING "Vitis platform to target") -set(DACE_XILINX_ENABLE_DEBUGGING OFF CACHE STRING "Inject debugging cores to kernel build (always on for simulation/emulation)") - -# Intel FPGA options -set(DACE_INTELFPGA_MODE "simulation" CACHE STRING "Type of compilation/execution [emulator/simulator/hardare].") -set(DACE_INTELFPGA_HOST_FLAGS "" CACHE STRING "Extra flags to host compiler.") -set(DACE_INTELFPGA_KERNEL_FLAGS "" CACHE STRING "Extra flags to kernel compiler.") -set(DACE_INTELFPGA_TARGET_BOARD "a10gx" CACHE STRING "Target FPGA board.") -set(DACE_INTELFPGA_ENABLE_DEBUGGING OFF CACHE STRING "Enable debugging.") - # Target detection set(DACE_ENABLE_MPI OFF) set(DACE_ENABLE_CUDA OFF) set(DACE_ENABLE_HIP OFF) -set(DACE_ENABLE_XILINX OFF) -set(DACE_ENABLE_INTELFPGA OFF) -set(DACE_ENABLE_RTL OFF) # Split list by target foreach(DACE_FILE ${DACE_FILES}) @@ -65,35 +43,9 @@ foreach(DACE_FILE ${DACE_FILES}) set(DACE_ENABLE_CUDA ON) set(DACE_CPP_FILES ${DACE_CPP_FILES} ${DACE_FILE}) endif() - elseif(${DACE_FILE_TARGET} STREQUAL "xilinx") - set(DACE_ENABLE_XILINX ON) - if(DACE_FILE_TARGET_TYPE MATCHES "host") - set(DACE_XILINX_HOST_FILES ${DACE_XILINX_HOST_FILES} ${DACE_FILE}) - elseif (DACE_FILE_EXT MATCHES "ip.cpp") - set(DACE_ENABLE_RTL ON) - set(DACE_XILINX_IP_FILES ${DACE_XILINX_IP_FILES} ${DACE_FILE}) - elseif(DACE_FILE_EXT MATCHES ".cpp") - set(DACE_XILINX_KERNEL_FILES ${DACE_XILINX_KERNEL_FILES} ${DACE_FILE}) - elseif(DACE_FILE_EXT MATCHES ".cfg") - set(DACE_XILINX_CONFIG_FILE ${DACE_FILE}) - endif() - elseif(${DACE_FILE_TARGET} STREQUAL "intel_fpga") - set(DACE_ENABLE_INTELFPGA ON) - if(DACE_FILE_TARGET_TYPE MATCHES "host") - set(DACE_INTELFPGA_HOST_FILES ${DACE_INTELFPGA_HOST_FILES} ${DACE_FILE}) - else() - set(DACE_INTELFPGA_KERNEL_FILES ${DACE_INTELFPGA_KERNEL_FILES} ${DACE_FILE}) - endif() elseif(${DACE_FILE_TARGET} STREQUAL "mpi") set(DACE_ENABLE_MPI ON) set(DACE_CPP_FILES ${DACE_CPP_FILES} ${DACE_FILE}) - elseif(${DACE_FILE_TARGET} STREQUAL "rtl") - set(DACE_ENABLE_RTL ON) - if(DACE_FILE_EXT MATCHES ".v" OR DACE_FILE_EXT MATCHES ".sv") - set(DACE_RTL_FILES ${DACE_RTL_FILES} ${DACE_FILE}) - elseif(DACE_FILE_EXT MATCHES ".cpp") - set(DACE_HOST_FILES ${DACE_HOST_FILES} ${DACE_FILE}) - endif() else() set(DACE_CPP_FILES ${DACE_CPP_FILES} ${DACE_FILE}) endif() @@ -141,7 +93,7 @@ if(DACE_ENABLE_CUDA) set(CMAKE_CUDA_ARCHITECTURES "${LOCAL_CUDA_ARCHITECTURES}") enable_language(CUDA) - list(APPEND DACE_LIBS CUDA::cudart) + list(APPEND DACE_LIBS CUDA::cudart CUDA::nvtx3) add_definitions(-DWITH_CUDA) if (MSVC_IDE) @@ -167,6 +119,21 @@ if(DACE_ENABLE_HIP) # Add libraries such as rocBLAS link_directories(${HIP_PATH}/../lib) + if(ROCM_PATH) + find_path(ROCTX_INCLUDE_DIR roctx.h HINTS ${ROCM_PATH}/include/roctracer ${ROCM_PATH}/roctracer/include) + if(NOT ROCTX_INCLUDE_DIR) + message(WARNING "Could not find roctx.h in ${ROCM_PATH}/include/roctracer or ${ROCM_PATH}/roctracer/include") + endif() + endif() + if(ROCM_PATH AND ROCTX_INCLUDE_DIR) + find_path(ROCTX_LIBRARY_DIR "libroctx64.so" HINTS ${ROCM_PATH}/lib) + if(NOT ROCTX_LIBRARY_DIR) + message(WARNING "Could not find libroctx64.so in ${ROCM_PATH}/lib") + else() + list(APPEND DACE_LIBS "-lroctx64 -L${ROCTX_LIBRARY_DIR}") + include_directories(SYSTEM ${ROCTX_INCLUDE_DIR}) + endif() + endif() endif() # Function for performing deferred variable expansion @@ -241,27 +208,6 @@ set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${DACE_ENV_LINK_FLAGS}") set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} ${DACE_ENV_LINK_FLAGS}") set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} ${DACE_ENV_LINK_FLAGS}") -if(DACE_ENABLE_XILINX OR DACE_ENABLE_INTELFPGA) - set(DACE_HLSLIB_DIR ${CMAKE_SOURCE_DIR}/../external/hlslib) - set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${DACE_HLSLIB_DIR}/cmake) - include_directories(SYSTEM ${DACE_HLSLIB_DIR}/include) -endif() -if(DACE_ENABLE_XILINX) - find_package(Vitis REQUIRED) - include_directories(SYSTEM ${Vitis_INCLUDE_DIRS}) - add_definitions(-DDACE_XILINX -DDACE_VITIS_DIR=\"${VITIS_ROOT_DIR}\") - set(DACE_LIBS ${DACE_LIBS} ${Vitis_LIBRARIES}) -endif() -if(DACE_ENABLE_INTELFPGA) - find_package(IntelFPGAOpenCL REQUIRED) - include_directories(SYSTEM ${IntelFPGAOpenCL_INCLUDE_DIRS}) - add_definitions(-DDACE_INTELFPGA) - set(DACE_LIBS ${DACE_LIBS} ${IntelFPGAOpenCL_LIBRARIES}) -endif() -if (DACE_ENABLE_RTL AND DACE_ENABLE_XILINX) - set(DACE_RTLLIB_DIR ${CMAKE_SOURCE_DIR}/../external/rtllib) - include ("${DACE_RTLLIB_DIR}/cmake/rtl_target.cmake") -endif() # Create HIP object files if(DACE_ENABLE_HIP) @@ -310,297 +256,17 @@ if(DACE_ENABLE_HIP) set(DACE_OBJECTS ${DACE_OBJECTS} ${DACE_HIP_FILES}) endif() # DACE_ENABLE_HIP -# create verilator RTL simulation objects -if(DACE_ENABLE_RTL) - if (DACE_ENABLE_XILINX AND (NOT (DACE_XILINX_MODE STREQUAL "simulation"))) - # Get all of the kernel names - list(APPEND RTL_KERNELS "") - foreach(RTL_FILE ${DACE_RTL_FILES}) - get_filename_component(RTL_KERNEL ${RTL_FILE} DIRECTORY) - list(APPEND RTL_KERNELS ${RTL_KERNEL}) - endforeach() - list(REMOVE_DUPLICATES RTL_KERNELS) - - # Prepare build folders - set (RTL_GENERATED_DIR "${CMAKE_CURRENT_BINARY_DIR}/rtl/generated") - set (RTL_LOG_DIR "${CMAKE_CURRENT_BINARY_DIR}/rtl/log") - set (RTL_TEMP_DIR "${CMAKE_CURRENT_BINARY_DIR}/rtl/tmp") - file (MAKE_DIRECTORY - ${RTL_GENERATED_DIR} - ${RTL_LOG_DIR} - ${RTL_TEMP_DIR}) - execute_process(COMMAND ${Vitis_PLATFORMINFO} -p ${DACE_XILINX_TARGET_PLATFORM} -jhardwarePlatform.board.part - OUTPUT_VARIABLE RTL_PART - RESULT_VARIABLE _platforminfo_res) - - if (NOT ${_platforminfo_res} EQUAL 0) - message(FATAL_ERROR "No part was found for platform ${DACE_XILINX_TARGET_PLATFORM} after querying 'platforminfo -p ${DACE_XILINX_TARGET_PLATFORM} -j\"hardwarePlatform.board.part\"'") - endif() - - # Generate all of the .xo targets - foreach(RTL_SRC_DIR ${RTL_KERNELS}) - get_filename_component(RTL_KERNEL ${RTL_SRC_DIR} NAME) - get_filename_component(RTL_SCRIPTS "${RTL_SRC_DIR}/../scripts" ABSOLUTE) - set(RTL_XO "${RTL_KERNEL}.xo") - rtllib_rtl_target(${RTL_KERNEL} ${RTL_SRC_DIR} ${RTL_SCRIPTS} ${RTL_GENERATED_DIR} ${RTL_LOG_DIR} ${RTL_TEMP_DIR} "${RTLLIB_DIR}/rtl" ${RTL_XO} ${RTL_PART} "" "\"\"") - add_custom_target(${RTL_KERNEL} DEPENDS ${RTL_XO}) - set(DACE_RTL_KERNELS ${DACE_RTL_KERNELS} ${RTL_XO}) - set(DACE_RTL_DEPENDS ${DACE_RTL_DEPENDS} ${RTL_KERNEL}) - endforeach() - else() - # find verilator installation - find_package(verilator HINTS $ENV{VERILATOR_ROOT} ${VERILATOR_ROOT}) - if (NOT verilator_FOUND) - message(FATAL_ERROR "Verilator was not found. Either install it, or set the VERILATOR_ROOT environment variable") - endif() - - # check minimal version requirements - set(VERILATOR_MIN_VERSION "4.028") - if("${verilator_VERSION}" VERSION_LESS VERILATOR_MIN_VERSION) - message(ERROR "Please upgrade verilator to version >=${VERILATOR_MIN_VERSION}") - endif() - - # get verilator flags from dace.conf - set(VERILATOR_FLAGS "${DACE_RTL_VERILATOR_FLAGS}") - - # add lint verilator flags - if("${DACE_RTL_VERILATOR_LINT_WARNINGS}") - # -Wall: Enable all style warnings - # -Wno-fatal: Disable fatal exit on warnings - set(VERILATOR_FLAGS "${VERILATOR_FLAGS}" "-Wall" "-Wno-fatal") - endif() - - # add verilated.cpp source - set(DACE_CPP_FILES "${DACE_CPP_FILES}" "${VERILATOR_ROOT}/include/verilated.cpp" "${VERILATOR_ROOT}/include/verilated_threads.cpp" ) - - foreach(RTL_FILE ${DACE_RTL_FILES}) - - # extract design name - get_filename_component(RTL_FILE_NAME "${RTL_FILE}" NAME_WE) - - # add verilated design - add_library("${RTL_FILE_NAME}" OBJECT) - - # include verilator - set(VERILATOR_INCLUDE "${VERILATOR_ROOT}/include" "${dace_program_BINARY_DIR}/CMakeFiles/${RTL_FILE_NAME}.dir/V${RTL_FILE_NAME}.dir") - include_directories(${VERILATOR_INCLUDE}) - - # verilate design - verilate("${RTL_FILE_NAME}" SOURCES ${RTL_FILE} VERILATOR_ARGS "${VERILATOR_FLAGS}") - file(GLOB VSRC_FILES "${dace_program_BINARY_DIR}/CMakeFiles/${RTL_FILE_NAME}.dir/V${RTL_FILE_NAME}.dir/*.cpp") - set(DACE_CPP_FILES "${DACE_CPP_FILES}" ${VSRC_FILES} "${dace_program_BINARY_DIR}/CMakeFiles/${RTL_FILE_NAME}.dir/V${RTL_FILE_NAME}.dir/V${RTL_FILE_NAME}.cpp") - - # add object library for linking - set(DACE_LIBS ${DACE_LIBS} ${${RTL_FILE_NAME}}) - - endforeach() - endif() -endif() # DACE_ENABLE_RTL - - -# Create Xilinx object files -if(DACE_ENABLE_XILINX) - - if (DACE_XILINX_TARGET_CLOCK MATCHES "[|]") - string(REGEX MATCH "0:([0-9]+)" DACE_XILINX_EXTERNAL_TARGET_CLOCK ${DACE_XILINX_TARGET_CLOCK}) - string(REGEX MATCH "1:([0-9]+)" DACE_XILINX_INTERNAL_TARGET_CLOCK ${DACE_XILINX_TARGET_CLOCK}) - string(SUBSTRING ${DACE_XILINX_EXTERNAL_TARGET_CLOCK} 2 -1 DACE_XILINX_EXTERNAL_TARGET_CLOCK) - string(SUBSTRING ${DACE_XILINX_INTERNAL_TARGET_CLOCK} 2 -1 DACE_XILINX_INTERNAL_TARGET_CLOCK) - else() - set(DACE_XILINX_EXTERNAL_TARGET_CLOCK ${DACE_XILINX_TARGET_CLOCK}) - set(DACE_XILINX_INTERNAL_TARGET_CLOCK ${DACE_XILINX_TARGET_CLOCK}) - endif() - - if((NOT (DACE_XILINX_MODE STREQUAL "hardware")) OR DACE_XILINX_ENABLE_DEBUGGING) - set(DACE_XILINX_HOST_FLAGS "${DACE_XILINX_HOST_FLAGS} -g") - endif() - - set_source_files_properties(${DACE_XILINX_KERNEL_FILES} ${DACE_XILINX_HOST_FILES} PROPERTIES COMPILE_FLAGS "${DACE_XILINX_HOST_FLAGS}") - set_source_files_properties(${DACE_XILINX_KERNEL_FILES} PROPERTIES COMPILE_FLAGS "-DDACE_XILINX_DEVICE_CODE ${DACE_XILINX_HOST_FLAGS}") - set(DACE_OBJECTS ${DACE_OBJECTS} ${DACE_XILINX_KERNEL_FILES} ${DACE_XILINX_HOST_FILES}) - - if(DACE_XILINX_MODE STREQUAL "simulation") - # This will cause the OpenCL calls to instead call a simulation code - # running on the host - add_definitions(-DHLSLIB_SIMULATE_OPENCL) - endif() - - if(DACE_MINIMUM_FIFO_DEPTH) - set(DACE_XILINX_MINIMUM_FIFO_DEPTH "\nconfig_dataflow -fifo_depth ${DACE_MINIMUM_FIFO_DEPTH}") - endif() - - - # If the project uses generated IP cores (e.g. through multi-pumping) - if(DACE_XILINX_IP_FILES) - set(DACE_XILINX_BUILD_FLAGS ${DACE_XILINX_BUILD_FLAGS} --user_ip_repo_paths ip_cores) - endif() - - unset(DACE_KERNEL_TARGETS) - - # Generate the target kernel for each IP (multi-pumped kernel) - foreach(DACE_IP ${DACE_XILINX_IP_FILES}) - get_filename_component(DACE_KERNEL_NAME ${DACE_IP} NAME_WE) - get_filename_component(DACE_KERNEL_SRC ${DACE_IP} DIRECTORY) - - # Configure the tcl script for packaging the C++ kernel as an IP core for Vivado. - configure_file(${CMAKE_SOURCE_DIR}/Xilinx_IP.tcl.in Package_${DACE_KERNEL_NAME}.tcl) - add_custom_command( - OUTPUT ip_cores/${DACE_KERNEL_NAME}/impl/export.zip - COMMAND XILINX_PATH=${CMAKE_BINARY_DIR} ${Vitis_HLS} - -f Package_${DACE_KERNEL_NAME}.tcl - DEPENDS ${DACE_IP} - ) - - # Get the hardware part of the board, which is needed to package the .xo file. - execute_process(COMMAND ${Vitis_PLATFORMINFO} -p ${DACE_XILINX_TARGET_PLATFORM} -jhardwarePlatform.board.part - OUTPUT_VARIABLE RTL_PART - RESULT_VARIABLE _platforminfo_res - OUTPUT_STRIP_TRAILING_WHITESPACE) - - # Add target for packaging the kernel into an .xo file. - set (RTL_XO "${DACE_KERNEL_NAME}.xo") - rtllib_rtl_target(${DACE_KERNEL_NAME} ${DACE_KERNEL_SRC} ${DACE_KERNEL_SRC} ${DACE_KERNEL_SRC} log tmp ${DACE_KERNEL_SRC} ${RTL_XO} ${RTL_PART} ip_cores/${DACE_KERNEL_NAME}/impl/export.zip ip_cores) - add_custom_target(${DACE_KERNEL_NAME} DEPENDS ${RTL_XO}) - set(DACE_RTL_KERNELS ${DACE_RTL_KERNELS} ${RTL_XO}) - set(DACE_RTL_DEPENDS ${DACE_RTL_DEPENDS} ${DACE_KERNEL_NAME}) - endforeach() - - foreach(DACE_KERNEL_FILE ${DACE_XILINX_KERNEL_FILES}) - # Extract kernel name - get_filename_component(DACE_KERNEL_NAME ${DACE_KERNEL_FILE} NAME) - string(REGEX REPLACE "(.+).cpp" "\\1" DACE_KERNEL_NAME "${DACE_KERNEL_NAME}") - - add_vitis_kernel(${DACE_KERNEL_NAME} - FILES ${DACE_VITIS_KERNEL_FILES} ${DACE_KERNEL_FILE} - HLS_FLAGS "${DACE_XILINX_SYNTHESIS_FLAGS} -DDACE_SYNTHESIS -DDACE_XILINX -DDACE_XILINX_DEVICE_CODE" - HLS_CONFIG "config_compile -pipeline_style frp${DACE_XILINX_MINIMUM_FIFO_DEPTH}" - INCLUDE_DIRS ${CMAKE_SOURCE_DIR}/../external/hlslib/include - ${CMAKE_SOURCE_DIR}/../runtime/include) - set(DACE_KERNEL_TARGETS ${DACE_KERNEL_TARGETS} ${DACE_KERNEL_NAME}) - endforeach() - - add_vitis_program(${DACE_PROGRAM_NAME} - ${DACE_XILINX_TARGET_PLATFORM} - KERNELS ${DACE_KERNEL_TARGETS} - DEBUGGING ${DACE_XILINX_ENABLE_DEBUGGING} - CLOCK ${DACE_XILINX_EXTERNAL_TARGET_CLOCK} - BUILD_FLAGS ${DACE_XILINX_BUILD_FLAGS} - LINK_FLAGS ${DACE_RTL_KERNELS} - DEPENDS ${DACE_RTL_DEPENDS} - CONFIG ${DACE_XILINX_CONFIG_FILE}) - -endif() # DACE_ENABLE_XILINX - -# Create Intel FPGA object files -if(DACE_ENABLE_INTELFPGA) - - if((NOT (DACE_INTELFPGA_MODE STREQUAL "hardware")) OR DACE_INTELFPGA_ENABLE_DEBUGGING) - set(DACE_INTELFPGA_HOST_FLAGS "${DACE_INTELFPGA_HOST_FLAGS} -g") - set(DACE_INTELFPGA_SYNTHESIS_FLAGS "${DACE_INTELFPGA_KERNEL_FLAGS} -fast-compile -profile=all -g -fast-emulator") - endif() - - set_source_files_properties(${DACE_INTELFPGA_KERNEL_FILES} ${DACE_INTELFPGA_HOST_FILES} PROPERTIES COMPILE_FLAGS "${DACE_INTELFPGA_HOST_FLAGS}") - set_source_files_properties(${DACE_INTELFPGA_KERNEL_FILES} PROPERTIES COMPILE_FLAGS "-DDACE_INTELFPGA_DEVICE_CODE ${DACE_INTELFPGA_HOST_FLAGS}") - set(DACE_OBJECTS ${DACE_OBJECTS} ${DACE_INTELFPGA_KERNEL_FILES} ${DACE_INTELFPGA_HOST_FILES}) - - # Add synthesis and build commands - set(DACE_AOC_KERNEL_FILES) - set(DACE_AOC_DEFINITIONS "-DDACE_INTELFPGA") - foreach(DACE_KERNEL_FILE ${DACE_INTELFPGA_KERNEL_FILES}) - - get_filename_component(DACE_KERNEL_NAME ${DACE_KERNEL_FILE} NAME) - string(REGEX REPLACE "kernel_(.+).cl" "\\1" DACE_KERNEL_NAME "${DACE_KERNEL_NAME}") - set(DACE_AOC_KERNEL_FILES ${DACE_AOC_KERNEL_FILES} ${DACE_KERNEL_FILE}) - - # Intel compiler does not allow to specify the output file if more than input file is used. - # In this case, the output AOCX file will be named as the last OpenCL file given in input to the compiler. - # We need to save the name of the last input file, so that later we can assign a proper name to the produced bitstream. - get_filename_component(DACE_AOC_OUTPUT_FILE ${DACE_KERNEL_FILE} NAME_WE) - endforeach() - - string(REPLACE " " ";" DACE_INTELFPGA_KERNEL_FLAGS_INTERNAL - "${DACE_INTELFPGA_KERNEL_FLAGS}") - - set(DACE_AOC_BUILD_FLAGS - -I${CMAKE_SOURCE_DIR}/include - -I${CMAKE_SOURCE_DIR}/../external/hlslib/include - -I${CMAKE_SOURCE_DIR}/../runtime/include - -I${CMAKE_BINARY_DIR} - -board=${DACE_INTELFPGA_TARGET_BOARD} - ${DACE_INTELFPGA_KERNEL_FLAGS_INTERNAL} - ${DACE_AOC_DEFINITIONS}) - - add_custom_target( - intelfpga_report_${DACE_PROGRAM_NAME} - COMMAND - ${IntelFPGAOpenCL_AOC} - ${DACE_AOC_BUILD_FLAGS} - ${DACE_AOC_KERNEL_FILES} - -rtl - -report - COMMAND mv ${DACE_AOC_OUTPUT_FILE} ${DACE_PROGRAM_NAME}) - - add_custom_command( - OUTPUT ${DACE_PROGRAM_NAME}_emulator.aocx - COMMAND ${IntelFPGAOpenCL_AOC} - ${DACE_AOC_BUILD_FLAGS} - -march=emulator - ${DACE_AOC_KERNEL_FILES} - COMMAND mv ${DACE_AOC_OUTPUT_FILE}.aocx ${DACE_PROGRAM_NAME}_emulator.aocx - DEPENDS ${DACE_AOC_KERNEL_FILES}) - - add_custom_command( - OUTPUT ${DACE_PROGRAM_NAME}_hardware.aocx - COMMAND ${IntelFPGAOpenCL_AOC} - ${DACE_AOC_BUILD_FLAGS} - ${DACE_AOC_KERNEL_FILES} - COMMAND mv ${DACE_AOC_OUTPUT_FILE}.aocx ${DACE_PROGRAM_NAME}_hardware.aocx - COMMAND mv ${DACE_AOC_OUTPUT_FILE} ${DACE_PROGRAM_NAME} - DEPENDS ${DACE_AOC_KERNEL_FILES}) - -endif() - -include("targets/mlir/mlir.cmake") +# Additional target-specific CMake files +foreach(CMAKE_FILE ${DACE_CMAKE_FILES}) + include(${CMAKE_FILE}) +endforeach() # Create DaCe library file add_library(${DACE_PROGRAM_NAME} SHARED ${DACE_CPP_FILES} ${DACE_OBJECTS}) target_link_libraries(${DACE_PROGRAM_NAME} PUBLIC ${DACE_LIBS}) -# Add additional required files -if(DACE_ENABLE_INTELFPGA) - if(DACE_INTELFPGA_MODE STREQUAL "emulator") - add_custom_target(intelfpga_compile_${DACE_PROGRAM_NAME}_emulator - ALL DEPENDS ${DACE_PROGRAM_NAME}_emulator.aocx) - else() - add_custom_target(intelfpga_compile_${DACE_PROGRAM_NAME}_emulator - DEPENDS ${DACE_PROGRAM_NAME}_emulator.aocx) - endif() - if(DACE_INTELFPGA_MODE STREQUAL "hardware" AND DACE_FPGA_AUTOBUILD_BITSTREAM) - add_custom_target(intelfpga_compile_${DACE_PROGRAM_NAME}_hardware - ALL DEPENDS ${DACE_PROGRAM_NAME}_hardware.aocx) - else() - add_custom_target(intelfpga_compile_${DACE_PROGRAM_NAME}_hardware - DEPENDS ${DACE_PROGRAM_NAME}_hardware.aocx) - endif() -endif() - -if(DACE_ENABLE_XILINX) - if(DACE_XILINX_MODE STREQUAL "software_emulation" AND DACE_FPGA_AUTOBUILD_BITSTREAM) - add_custom_target(autobuild_bitstream ALL - COMMENT "Automatically built bitstream for software emulation." - DEPENDS sw_emu) - endif() - if(DACE_XILINX_MODE STREQUAL "hardware_emulation" AND DACE_FPGA_AUTOBUILD_BITSTREAM) - add_custom_target(autobuild_bitstream ALL - COMMENT "Automatically built bitstream for hardware emulation." - DEPENDS hw_emu) - endif() - if(DACE_XILINX_MODE STREQUAL "hardware" AND DACE_FPGA_AUTOBUILD_BITSTREAM) - add_custom_target(autobuild_bitstream ALL - COMMENT "Automatically built bitstream for hardware." - DEPENDS hw) - endif() -endif() +# Set C++ standard to C++20 (or the configured standard) +set_property(TARGET ${DACE_PROGRAM_NAME} PROPERTY CXX_STANDARD ${DACE_CPP_STANDARD}) # Create DaCe loader stub add_library(dacestub_${DACE_PROGRAM_NAME} SHARED "${CMAKE_SOURCE_DIR}/tools/dacestub.cpp") diff --git a/dace/codegen/Xilinx_IP.tcl.in b/dace/codegen/Xilinx_IP.tcl.in deleted file mode 100644 index 986ca9efc4..0000000000 --- a/dace/codegen/Xilinx_IP.tcl.in +++ /dev/null @@ -1,9 +0,0 @@ -open_project ip_cores -set_top ${DACE_KERNEL_NAME} -add_files -cflags "-std=c++11 -DDACE_SYNTHESIS -DDACE_XILINX -DDACE_XILINX_DEVICE_CODE -DHLSLIB_SYNTHESIS -DHLSLIB_XILINX -DVITIS_MAJOR_VERSION=2021 -DVITIS_MINOR_VERSION=1 -DVITIS_VERSION=2021.1 -D__VITIS_HLS__ -I${DACE_RUNTIME_DIR}/include -I${DACE_HLSLIB_DIR}/include -I${CMAKE_BINARY_DIR}" "${DACE_IP}" -open_solution "${DACE_KERNEL_NAME}" -flow_target vivado -set_part ${DACE_XILINX_PART_NAME} -create_clock -period ${DACE_XILINX_INTERNAL_TARGET_CLOCK}MHz -name default -csynth_design -export_design -format ip_catalog -quit diff --git a/dace/codegen/codegen.py b/dace/codegen/codegen.py index 3ccbb56dc6..d83b498de9 100644 --- a/dace/codegen/codegen.py +++ b/dace/codegen/codegen.py @@ -1,4 +1,4 @@ -# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. import functools from typing import List @@ -13,9 +13,6 @@ from dace.config import Config from dace.sdfg import infer_types -# Import CPU code generator. TODO: Remove when refactored -from dace.codegen.targets import cpp, cpu - from dace.codegen.instrumentation import InstrumentationProvider from dace.sdfg.state import SDFGState from dace.transformation.pass_pipeline import FixedPointPipeline @@ -61,6 +58,7 @@ def generate_dummy(sdfg: SDFG, frame: framecode.DaCeCodeGenerator) -> str: # allocate the array args using calloc for argname, arg in al.items(): if isinstance(arg, data.Array): + from dace.codegen.targets import cpp dims_mul = cpp.sym2cpp(functools.reduce(lambda a, b: a * b, arg.shape, 1)) basetype = str(arg.dtype) allocations += (" " + str(arg.as_arg(name=argname, with_types=True)) + " = (" + basetype + "*) calloc(" + @@ -160,7 +158,7 @@ def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]: :param validate: If True, validates the SDFG before generating the code. :return: List of code objects that correspond to files to compile. """ - from dace.codegen.targets.target import TargetCodeGenerator # Avoid import loop + from dace.codegen.target import TargetCodeGenerator # Avoid import loop # Before compiling, validate SDFG correctness if validate: @@ -176,7 +174,7 @@ def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]: sdfg.save(f'{tmp_dir}/test.sdfg', hash=False) sdfg2 = SDFG.from_file(f'{tmp_dir}/test.sdfg') sdfg2.save(f'{tmp_dir}/test2.sdfg', hash=False) - print('Testing SDFG serialization...') + if not filecmp.cmp(f'{tmp_dir}/test.sdfg', f'{tmp_dir}/test2.sdfg'): with open(f'{tmp_dir}/test.sdfg', 'r') as f1: with open(f'{tmp_dir}/test2.sdfg', 'r') as f2: @@ -217,6 +215,7 @@ def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]: # Instantiate CPU first (as it is used by the other code generators) # TODO: Refactor the parts used by other code generators out of CPU + from dace.codegen.targets import cpu default_target = cpu.CPUCodeGen for k, v in TargetCodeGenerator.extensions().items(): # If another target has already been registered as CPU, use it instead diff --git a/dace/codegen/common.py b/dace/codegen/common.py index d8524eacbc..f5bbf445a2 100644 --- a/dace/codegen/common.py +++ b/dace/codegen/common.py @@ -171,3 +171,14 @@ def get_gpu_runtime() -> gpu_runtime.GPURuntime: 'environment variable to point to the libraries.') return gpu_runtime.GPURuntime(backend, libpath) + + +def platform_library_name(libname: str) -> str: + """ Get the filename of a library. + + :param libname: the name of the library. + :return: the filename of the library. + """ + prefix = config.Config.get('compiler', 'library_prefix') + suffix = config.Config.get('compiler', 'library_extension') + return f"{prefix}{libname}.{suffix}" diff --git a/dace/codegen/compiled_sdfg.py b/dace/codegen/compiled_sdfg.py index 733f0ba53c..ba2dd8d465 100644 --- a/dace/codegen/compiled_sdfg.py +++ b/dace/codegen/compiled_sdfg.py @@ -5,10 +5,11 @@ import re import shutil import subprocess -from typing import Any, Callable, Dict, List, Tuple, Optional, Type, Union +from typing import Any, Callable, Dict, List, Tuple, Optional, Type, Union, Sequence import warnings import tempfile import pickle +import pathlib import sys import numpy as np @@ -77,7 +78,8 @@ def is_loaded(self) -> bool: lib_cfilename = ctypes.c_wchar_p(self._library_filename) else: # As UTF-8 - lib_cfilename = ctypes.c_char_p(self._library_filename.encode('utf-8')) + tt = self._library_filename.encode('utf-8') + lib_cfilename = ctypes.c_char_p(tt) return self._stub.is_library_loaded(lib_cfilename) == 1 @@ -96,21 +98,39 @@ def load(self): # Check if library is already loaded is_loaded = True lib_cfilename = None + lib_filename = self._library_filename + counter = 0 while is_loaded: # Convert library filename to string according to OS if os.name == 'nt': # As UTF-16 - lib_cfilename = ctypes.c_wchar_p(self._library_filename) + lib_cfilename = ctypes.c_wchar_p(lib_filename) else: # As UTF-8 - lib_cfilename = ctypes.c_char_p(self._library_filename.encode('utf-8')) + lib_cfilename = ctypes.c_char_p(lib_filename.encode('utf-8')) + # Test if the library is loaded. is_loaded = self._stub.is_library_loaded(lib_cfilename) + if is_loaded == 1: warnings.warn(f'Library {self._library_filename} already loaded, renaming file') + + # The library is loaded, copy the _original_ library file to a new file + # and then try to load that. We only do the copy if the new new name is + # free. It seems that at least on LINUX there is some issue if we + # overwrite a file that already exists. + lib_filename = self._library_filename + f'_{counter}' + counter += 1 + if pathlib.Path(lib_filename).exists(): + assert pathlib.Path(lib_filename).is_file() + continue + + # The file name is not taken, so make a copy. There might be a race condition + # here in the presence of multiple processes. + # TODO: Investigate if we should switch to hardlinks if they are supported. try: - shutil.copyfile(self._library_filename, self._library_filename + '_') - self._library_filename += '_' + assert self._library_filename != lib_filename + shutil.copyfile(self._library_filename, lib_filename) except shutil.Error: raise cgx.DuplicateDLLError(f'Library {os.path.basename(self._library_filename)}' 'is already loaded somewhere else and cannot be unloaded. ' @@ -118,6 +138,7 @@ def load(self): # Actually load the library self._lib = ctypes.c_void_p(self._stub.load_library(lib_cfilename)) + self._library_filename = lib_filename if self._lib.value is None: # Try to understand why the library is not loading, if dynamic @@ -147,12 +168,38 @@ def __enter__(self, *args, **kwargs): def __exit__(self, *args, **kwargs): self.unload() + def __copy__(self): + raise RuntimeError(f'Can not copy ReloadableDLL({self._library_filename})') + + def __deepcopy__(self, memodict={}): + raise RuntimeError(f'Can not copy ReloadableDLL({self._library_filename})') + class CompiledSDFG(object): """ A compiled SDFG object that can be called through Python. - Todo: - Scalar return values are not handled properly, this is a code gen issue. + Essentially this class makes an SDFG callable. Normally a user will not create it + directly but instead it is generated by some utilities such as `SDFG.compile()`. + + The class performs the following tasks: + - It ensures that the SDFG object is properly initialized, either by a direct + call to `initialize()` or the first time it is called. Furthermore, it will + also take care of the finalization if it does out of scope. + - It transforms Python arguments into C arguments. + + Technically there are two ways how the SDFG can be called, the first is using + `__call__()`, i.e. as a normal function. However, this will always processes + the arguments and does some error checking and is thus slow. The second way + is the advanced interface, which allows to decompose the calling into different + subset. For more information see `construct_arguments()`, `fast_call()` and + `convert_return_values()`. + + :note: In previous version the arrays used as return values were sometimes reused. + However, this was changed and every time `construct_arguments()` is called + new arrays are allocated. + :note: It is not possible to return scalars. Note that currently using scalars + as return values is a validation error. The only exception are (probably) + Python objects. """ def __init__(self, sdfg, lib: ReloadableDLL, argnames: List[str] = None): @@ -161,9 +208,14 @@ def __init__(self, sdfg, lib: ReloadableDLL, argnames: List[str] = None): self._lib = lib self._initialized = False self._libhandle = ctypes.c_void_p(0) - self._lastargs = () self.do_not_execute = False + # Contains the pointer arguments that where used to call the SDFG, `__call__()` + # was used. It is also used by `get_workspace_size()`. + # NOTE: Using its content might be dangerous as only the pointers to arrays are + # stored. It is the users responsibility to ensure that they are valid. + self._lastargs = None + lib.load() # Explicitly load the library self._init = lib.get_symbol('__dace_init_{}'.format(sdfg.name)) self._init.restype = ctypes.c_void_p @@ -172,17 +224,27 @@ def __init__(self, sdfg, lib: ReloadableDLL, argnames: List[str] = None): self._cfunc = lib.get_symbol('__program_{}'.format(sdfg.name)) # Cache SDFG return values - self._create_new_arrays: bool = True self._return_syms: Dict[str, Any] = None + # It will contain the shape of the array or the name if the return array is passed as argument. self._retarray_shapes: List[Tuple[str, np.dtype, dtypes.StorageType, Tuple[int], Tuple[int], int]] = [] - self._retarray_is_scalar: List[bool] = [] + # Is only `True` if teh return value is a scalar _and_ a `pyobject`. + self._retarray_is_pyobject: List[bool] = [] self._return_arrays: List[np.ndarray] = [] self._callback_retval_references: List[Any] = [] # Avoids garbage-collecting callback return values + # If there are return values then this is `True` it is is a single value. Note that + # `False` either means that a tuple is returned or there are no return values. + # NOTE: Needed to handle the case of a tuple with one element. + self._is_single_value_ret: bool = False + if '__return' in self._sdfg.arrays: + assert not any(aname.startswith('__return_') for aname in self._sdfg.arrays.keys()) + self._is_single_value_ret = True + # Cache SDFG argument properties self._typedict = self._sdfg.arglist() self._sig = self._sdfg.signature_arglist(with_types=False, arglist=self._typedict) self._free_symbols = self._sdfg.free_symbols + self._constants = self._sdfg.constants self.argnames = argnames if self.argnames is None and len(sdfg.arg_names) != 0: @@ -269,12 +331,21 @@ def get_workspace_sizes(self) -> Dict[dtypes.StorageType, int]: """ Returns the total external memory size to be allocated for this SDFG. + Note that the function queries the sizes of the last call that was made by + `__call__()` or `initialize()`. Calls made by `fast_call()` or `safe_call()` + will not be considered. + :return: A dictionary mapping storage types to the number of bytes necessary to allocate for the SDFG to work properly. + :note: It is the users responsibility that all arguments, especially the array + arguments, remain valid between the call to `__call__()` or `initialize()` + and the call to this function. """ if not self._initialized: raise ValueError('Compiled SDFG is uninitialized, please call ``initialize`` prior to ' 'querying external memory size.') + if self._lastargs is None: + raise ValueError('To use `get_workspace_sizes()` `__call__()` or `initialize()` must be called before.') result: Dict[dtypes.StorageType, int] = {} for storage in self.external_memory_types: @@ -288,15 +359,24 @@ def set_workspace(self, storage: dtypes.StorageType, workspace: Any): """ Sets the workspace for the given storage type to the given buffer. + Note that the function queries the sizes of the last call that was made by + `__call__()` or `initialize()`. Calls made by `fast_call()` or `safe_call()` + will not be considered. + :param storage: The storage type to fill. :param workspace: An array-convertible object (through ``__[cuda_]array_interface__``, see ``array_interface_ptr``) to use for the workspace. + :note: It is the users responsibility that all arguments, especially the array + arguments, remain valid between the call to `__call__()` or `initialize()` + and the call to this function. """ if not self._initialized: raise ValueError('Compiled SDFG is uninitialized, please call ``initialize`` prior to ' 'setting external memory.') if storage not in self.external_memory_types: raise ValueError(f'Compiled SDFG does not specify external memory of {storage}') + if self._lastargs is None: + raise ValueError('To use `get_workspace_sizes()` `__call__()` or `initialize()` must be called before.') func = self._lib.get_symbol(f'__dace_set_external_memory_{storage.name}', None) ptr = dtypes.array_interface_ptr(workspace, storage) @@ -331,12 +411,13 @@ def initialize(self, *args, **kwargs): if self._initialized: return - if len(args) > 0 and self.argnames is not None: - kwargs.update({aname: arg for aname, arg in zip(self.argnames, args)}) - # Construct arguments in the exported C function order - _, initargtuple = self._construct_args(kwargs) + callargtuple, initargtuple = self.construct_arguments(*args, **kwargs) self._initialize(initargtuple) + + # The main reason for setting `_lastargs` here is, to allow calls to `get_workspace_size()`. + self._lastargs = (callargtuple, initargtuple) + return self._libhandle def finalize(self): @@ -361,38 +442,34 @@ def __call__(self, *args, **kwargs): """ Forwards the Python call to the compiled ``SDFG``. - The order of the positional arguments is expected to be the same as in - the ``argnames`` member. The function will roughly perform the - following tasks: - - Change the order of the Python arguments into the one required by - the binary. - - Performing some basic sanity checks. - - Transforming the Python arguments into their ``C`` equivalents. - - Allocate the memory for the return values. - - Call the ``C` function. + The order of the positional arguments is expected to be the same as in the + ``argnames`` member. The function will perform the following tasks: + - Calling ``construct_arguments()`` and creating the argument vector and + allocating the memory for the return values. + - Performing the actual call by means of ``fast_call()``, with enabled error + checks. + - Then it will convert the return value into the expected format by means of + ``convert_return_values()`` and return that value. :note: The memory for the return values is only allocated the first time this function is called. Thus, this function will always return the same objects. To force the allocation of new memory you can call ``clear_return_values()`` in advance. """ - if self.argnames is None and len(args) != 0: - raise KeyError(f"Passed positional arguments to an SDFG that does not accept them.") - elif len(args) > 0 and self.argnames is not None: - kwargs.update( - # `_construct_args` will handle all of its arguments as kwargs. - { - aname: arg - for aname, arg in zip(self.argnames, args) - }) - argtuple, initargtuple = self._construct_args(kwargs) # Missing arguments will be detected here. - # Return values are cached in `self._lastargs`. - return self.fast_call(argtuple, initargtuple, do_gpu_check=True) + argtuple, initargtuple = self.construct_arguments(*args, **kwargs) # Missing arguments will be detected here. + self._lastargs = (argtuple, initargtuple) + self.fast_call(argtuple, initargtuple, do_gpu_check=True) + return self.convert_return_values() def safe_call(self, *args, **kwargs): """ Forwards the Python call to the compiled ``SDFG`` in a separate process to avoid crashes in the main process. Raises an exception if the SDFG execution fails. + + Note the current implementation lacks the proper handling of return values. + Thus output can only be transmitted through inout arguments. """ + if any(aname == '__return' or aname.startswith('__return_') for aname in self.sdfg.arrays.keys()): + raise NotImplementedError('`CompiledSDFG.safe_call()` does not support return values.') # Pickle the SDFG and arguments with tempfile.NamedTemporaryFile(mode='wb', delete=False) as f: @@ -444,24 +521,25 @@ def safe_call(self, *args, **kwargs): def fast_call( self, - callargs: Tuple[Any, ...], - initargs: Tuple[Any, ...], + callargs: Sequence[Any], + initargs: Sequence[Any], do_gpu_check: bool = False, - ) -> Union[Tuple[Any, ...], Any]: + ) -> None: """ - Calls the underlying binary functions directly and bypassing - argument sanitation. + Calls the underlying binary functions directly and bypassing argument sanitation. - This is a faster, but less user friendly version of ``__call__()``. - While ``__call__()`` will transforms its Python arguments such that - they can be forwarded, this function assumes that this processing - was already done by the user. + This is a faster, but less user friendly version of ``__call__()``. While + ``__call__()`` will transforms its Python arguments such that they can be + forwarded and allocate memory for the return values, this function assumes + that this processing was already done by the user. + To build the argument vectors you should use `self.construct_arguments()`. :param callargs: Arguments passed to the actual computation. :param initargs: Arguments passed to the initialization function. :param do_gpu_check: Check if errors happened on the GPU. - :note: You may use `_construct_args()` to generate the processed arguments. + :note: This is an advanced interface. + :note: In previous versions this function also called `convert_return_values()`. """ try: # Call initializer function if necessary, then SDFG @@ -485,8 +563,7 @@ def fast_call( if lasterror is not None: raise RuntimeError( f'An error was detected when calling "{self._sdfg.name}": {self._get_error_text(lasterror)}') - - return self._convert_return_values() + return except (RuntimeError, TypeError, UnboundLocalError, KeyError, cgx.DuplicateDLLError, ReferenceError): self._lib.unload() raise @@ -498,18 +575,40 @@ def __del__(self): self._libhandle = ctypes.c_void_p(0) self._lib.unload() - def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: - """ - Main function that controls argument construction for calling - the C prototype of the SDFG. + def construct_arguments(self, *args: Any, **kwargs: Any) -> Tuple[Tuple[Any], Tuple[Any]]: + """Construct the argument vectors suitable for from its argument. + + The function returns a pair of tuple, that are suitable for `fast_call()`. + The first element of is `callargs`, i.e. the full arguments, while the + second element is `initargs`, which is only used/needed the first time + an SDFG is called. + + It is important that this function will also allocate new return values. + The array objects are managed by `self` and remain valid until this + function is called again. However, they are also returned by `self.__call__()`. - Organizes arguments first by ``sdfg.arglist``, then data descriptors - by alphabetical order, then symbols by alphabetical order. + It is also possible to pass the array, that should be used to return a value, + directly as argument. In that case the allocation for that return value will + be skipped. - :note: If not initialized this function will initialize the memory for - the return values, however, it might also reallocate said memory. - :note: This function will also update the internal argument cache. + :note: In case of arrays, the returned argument vectors only contains the + pointers to the underlying memory. Thus it is the user's responsibility + to ensure that the memory remains allocated until the argument vector + is used. + :note: This is an advanced interface. """ + if self.argnames is None and len(args) != 0: + raise KeyError(f"Passed positional arguments to an SDFG that does not accept them.") + elif len(args) > 0 and self.argnames is not None: + positional_arguments = {aname: avalue for aname, avalue in zip(self.argnames, args)} + if not positional_arguments.keys().isdisjoint(kwargs.keys()): + raise ValueError( + f'The arguments where passed once as positional and named arguments: {set(positional_arguments.keys()).intersection(kwargs.keys())}' + ) + kwargs.update(positional_arguments) + + # NOTE: This might invalidate the elements associated to the return values of + # all argument vectors that were created before. self._initialize_return_values(kwargs) # Add the return values to the arguments, since they are part of the C signature. @@ -539,31 +638,51 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: argnames = [] sig = [] - # Type checking - cargs = [] no_view_arguments = not Config.get_bool('compiler', 'allow_view_arguments') - for i, (a, arg, atype) in enumerate(zip(argnames, arglist, argtypes)): - carg = dt.make_ctypes_argument(arg, - atype, - a, - allow_views=not no_view_arguments, - symbols=kwargs, - callback_retval_references=self._callback_retval_references) - cargs.append(carg) - - constants = self.sdfg.constants + cargs = tuple( + dt.make_ctypes_argument(aval, + atype, + aname, + allow_views=not no_view_arguments, + symbols=kwargs, + callback_retval_references=self._callback_retval_references) + for aval, atype, aname in zip(arglist, argtypes, argnames)) + symbols = self._free_symbols callparams = tuple((carg, aname) for arg, carg, aname in zip(arglist, cargs, argnames) - if not (symbolic.issymbolic(arg) and (hasattr(arg, 'name') and arg.name in constants))) - - newargs = tuple(carg for carg, aname in callparams) + if not ((hasattr(arg, 'name') and arg.name in self._constants) and symbolic.issymbolic(arg))) + newargs = tuple(carg for carg, _aname in callparams) initargs = tuple(carg for carg, aname in callparams if aname in symbols) - self._lastargs = newargs, initargs - return self._lastargs + return (newargs, initargs) + + def convert_return_values(self) -> Union[Any, Tuple[Any, ...]]: + """Convert the return arguments. + + Execute the `return` statement and return. This function should only be called + after `fast_call()` has been run. + Keep in mid that it is not possible to return scalars (with the exception of + `pyobject`s), they will be always returned as an array with shape `(1,)`. + + :note: This is an advanced interface. + :note: After `fast_call()` returns it is only allowed to call this function once. + """ + # TODO: Make sure that the function is called only once by checking it. + # NOTE: Currently it is not possible to return a scalar value, see `tests/sdfg/scalar_return.py` + if not self._return_arrays: + return None + elif self._is_single_value_ret: + assert len(self._return_arrays) == 1 + return self._return_arrays[0].item() if self._retarray_is_pyobject[0] else self._return_arrays[0] + else: + return tuple(r.item() if is_pyobj else r + for r, is_pyobj in zip(self._return_arrays, self._retarray_is_pyobject)) def clear_return_values(self): - self._create_new_arrays = True + warnings.warn( + 'The "CompiledSDFG.clear_return_values" API is deprecated, as this behaviour has' + ' become the new default, and is a noops.', DeprecationWarning) + pass def _create_array(self, _: str, dtype: np.dtype, storage: dtypes.StorageType, shape: Tuple[int], strides: Tuple[int], total_size: int): @@ -583,8 +702,6 @@ def ndarray(*args, buffer=None, **kwargs): zeros = cupy.empty except (ImportError, ModuleNotFoundError): raise NotImplementedError('GPU return values are unsupported if cupy is not installed') - if storage is dtypes.StorageType.FPGA_Global: - raise NotImplementedError('FPGA return values are unsupported') # Create an array with the properties of the SDFG array return ndarray(shape, dtype, buffer=zeros(total_size, dtype), strides=strides) @@ -599,52 +716,76 @@ def _initialize_return_values(self, kwargs): # Clear references from last call (allow garbage collection) self._callback_retval_references.clear() - if self._initialized: - if self._return_syms == syms: - if not self._create_new_arrays: - return - else: - self._create_new_arrays = False - # Use stored sizes to recreate arrays (fast path) - self._return_arrays = tuple(kwargs[desc[0]] if desc[0] in kwargs else self._create_array(*desc) - for desc in self._retarray_shapes) - return + if self._initialized and self._return_syms == syms: + # Use stored sizes to recreate arrays (fast path) + self._return_arrays = tuple(kwargs[desc[0]] if desc[0] in kwargs else self._create_array(*desc) + for desc in self._retarray_shapes) + return self._return_syms = syms - self._create_new_arrays = False - - # Initialize return values with numpy arrays - self._retarray_shapes = [] self._return_arrays = [] + self._retarray_shapes = [] + self._retarray_is_pyobject = [] for arrname, arr in sorted(self.sdfg.arrays.items()): - if arrname.startswith('__return') and not arr.transient: - if arrname in kwargs: + if arrname.startswith('__return'): + if arr.transient: + raise ValueError(f'Used the special array name "{arrname}" as transient.') + + elif arrname in kwargs: + # The return value is passed as an argument, in that case store the name in `self._retarray_shapes`. + warnings.warn(f'Return value "{arrname}" is passed as a regular argument.', stacklevel=2) self._return_arrays.append(kwargs[arrname]) - self._retarray_is_scalar.append(isinstance(arr, dt.Scalar)) self._retarray_shapes.append((arrname, )) - continue - if isinstance(arr, dt.Stream): + elif isinstance(arr, dt.Stream): raise NotImplementedError('Return streams are unsupported') - shape = tuple(symbolic.evaluate(s, syms) for s in arr.shape) - dtype = arr.dtype.as_numpy_dtype() - total_size = int(symbolic.evaluate(arr.total_size, syms)) - strides = tuple(symbolic.evaluate(s, syms) * arr.dtype.bytes for s in arr.strides) - shape_desc = (arrname, dtype, arr.storage, shape, strides, total_size) - self._retarray_is_scalar.append(isinstance(arr, dt.Scalar) or isinstance(arr.dtype, dtypes.pyobject)) - self._retarray_shapes.append(shape_desc) - - # Create an array with the properties of the SDFG array - arr = self._create_array(*shape_desc) - self._return_arrays.append(arr) + else: + shape = tuple(symbolic.evaluate(s, syms) for s in arr.shape) + dtype = arr.dtype.as_numpy_dtype() + total_size = int(symbolic.evaluate(arr.total_size, syms)) + strides = tuple(symbolic.evaluate(s, syms) * arr.dtype.bytes for s in arr.strides) + shape_desc = (arrname, dtype, arr.storage, shape, strides, total_size) + self._retarray_shapes.append(shape_desc) + + # Create an array with the properties of the SDFG array + return_array = self._create_array(*shape_desc) + self._return_arrays.append(return_array) + + # BUG COMPATIBILITY(PR#2206): + # In the original version `_retarray_is_pyobject` was named `_retarray_is_scalar`, however + # since scalars could not be returned on an [implementation level](https://github.com/spcl/dace/pull/1609) + # it was essentially useless. But was used for `pyobject` in _some_ cases. And indeed, + # since `pyobject`s are essentially `void` pointers is was, in principle possible, to return/pass + # them as "scalars", read "not inside an array". + # However, if the return value was passed as argument, i.e. the first `elif`, then it + # was ignored if `arr` was a `pyobject`. Only if the return value was managed by `self`, + # i.e. the `else` case, then it was considered, in a way at least. The problem was, that it was + # done using the following check: + # `isinstance(arr, dt.Scalar) or isinstance(arr.dtype, dtypes.pyobject)` + # Because of the `or` that is used, _everything_ whose `dtype` is `pyobject` was classified + # as a scalar `pyobject`, i.e. one element, even if it was in fact an array of millions of `pyobject`s. + # The correct behaviour would be to change the `or` to an `and` but then several unit + # tests (`test_pyobject_return`, `test_pyobject_return_tuple` and `test_nested_autoparse[False]` + # in `tests/python_frontend/callee_autodetect_test.py`) will fail. + # The following code is bug compatible and also allows to pass a `pyobject` directly, i.e. + # through `kwargs`. + if isinstance(arr.dtype, dtypes.pyobject): + if isinstance(arr, dt.Scalar): + # Proper scalar. + self._retarray_is_pyobject.append(True) + elif isinstance(arr, dt.Array): + # An array, let's check if it is just a wrapper for a single value. + if not (len(arr.shape) == 1 and arr.shape[0] == 1): + warnings.warn(f'Decay an array of `pyobject`s with shape {arr.shape} to a single one.', + stacklevel=2) + self._retarray_is_pyobject.append(True) + else: + raise ValueError( + f'Does not know how to handle "{arrname}", which is a {type(arr).__name__} of `pyobject`.') + else: + self._retarray_is_pyobject.append(False) - def _convert_return_values(self): - # Return the values as they would be from a Python function - # NOTE: Currently it is not possible to return a scalar value, see `tests/sdfg/scalar_return.py` - if not self._return_arrays: - return None - elif len(self._return_arrays) == 1: - return self._return_arrays[0].item() if self._retarray_is_scalar[0] else self._return_arrays[0] - else: - return tuple(r.item() if scalar else r for r, scalar in zip(self._return_arrays, self._retarray_is_scalar)) + assert (not self._is_single_value_ret) or (len(self._return_arrays) == 1) + assert len(self._return_arrays) == len(self._retarray_shapes) == len(self._retarray_is_pyobject) + self._return_arrays = tuple(self._return_arrays) diff --git a/dace/codegen/compiler.py b/dace/codegen/compiler.py index 00f40da622..d4ac377177 100644 --- a/dace/codegen/compiler.py +++ b/dace/codegen/compiler.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. """ Handles compilation of code objects. Creates the proper folder structure, compiles each target separately, links all targets to one binary, and returns the corresponding CompiledSDFG object. """ @@ -18,10 +18,10 @@ import dace from dace.config import Config from dace.codegen import exceptions as cgx -from dace.codegen.targets.target import TargetCodeGenerator +from dace.codegen.target import TargetCodeGenerator from dace.codegen.codeobject import CodeObject from dace.codegen import compiled_sdfg as csd -from dace.codegen.targets.target import make_absolute +from dace.codegen.target import make_absolute T = TypeVar('T') @@ -170,6 +170,7 @@ def configure_and_compile(program_folder, program_name=None, output_stream=None) "-DDACE_SRC_DIR=\"{}\"".format(src_folder), "-DDACE_FILES=\"{}\"".format(";".join(files)), "-DDACE_PROGRAM_NAME={}".format(program_name), + "-DDACE_CPP_STANDARD={}".format(Config.get('compiler', 'cpp_standard')), ] # Get required environments are retrieve the CMake information @@ -188,9 +189,11 @@ def configure_and_compile(program_folder, program_name=None, output_stream=None) # Generate CMake options for each compiler libraries = set() + cmake_files = [] for target_name, target in sorted(targets.items()): try: cmake_command += target.cmake_options() + cmake_files += target.cmake_files() libraries |= unique_flags(Config.get("compiler", target_name, "libs")) except KeyError: pass @@ -198,7 +201,7 @@ def configure_and_compile(program_folder, program_name=None, output_stream=None) raise cgx.CompilerConfigurationError(str(ex)) cmake_command.append("-DDACE_LIBS=\"{}\"".format(" ".join(sorted(libraries)))) - + cmake_command.append(f"-DDACE_CMAKE_FILES=\"{';'.join(cmake_files)}\"") cmake_command.append(f"-DCMAKE_BUILD_TYPE={Config.get('compiler', 'build_type')}") # Set linker and linker arguments, iff they have been specified diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index 2436c19b7c..1b5384428e 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -120,6 +120,11 @@ def _loop_region_to_code(region: LoopRegion, dispatch_state: Callable[[SDFGState expr += f'{update};\n' expr += '}\n' else: + if loop.unroll: + if loop.unroll_factor >= 1: + expr += f'#pragma unroll {loop.unroll_factor}\n' + else: + expr += f'#pragma unroll\n' expr += f'for ({init}; {cond}; {update}) {{\n' expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols)) expr += '\n}\n' diff --git a/dace/codegen/cppunparse.py b/dace/codegen/cppunparse.py index 842a515fb7..a473b59db3 100644 --- a/dace/codegen/cppunparse.py +++ b/dace/codegen/cppunparse.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. # This module is derived from astunparse: https://github.com/simonpercivall/astunparse ########################################################################## ### astunparse LICENSES @@ -85,7 +85,7 @@ from numbers import Number from six import StringIO from dace import dtypes -from dace.codegen.tools import type_inference +from dace.sdfg import type_inference if sys.version_info < (3, 8): BytesConstant = ast.Bytes @@ -108,14 +108,6 @@ _py2c_reserved = {"True": "true", "False": "false", "None": "nullptr", "inf": "INFINITY", "nan": "NAN"} -_py2c_typeconversion = { - "uint": dace.dtypes.typeclass(np.uint32), - "int": dace.dtypes.typeclass(int), - "float": dace.dtypes.typeclass(float), - "float64": dace.dtypes.typeclass(np.float64), - "str": dace.dtypes.pointer(dace.dtypes.int8) -} - def interleave(inter, f, seq, **kwargs): """ @@ -342,15 +334,7 @@ def _Assign(self, t): raise RuntimeError(f"Failed to infer type of \"{target.id}\".") self.locals.define(target.id, t.lineno, self._indent, inferred_type) - if self.language == dace.dtypes.Language.OpenCL and (inferred_type is not None - and inferred_type.veclen > 1): - # if the veclen is greater than one, this should be defined with a vector data type - self.write("{}{} ".format(dace.dtypes._OCL_VECTOR_TYPES[inferred_type.type], - inferred_type.veclen)) - elif self.language == dace.dtypes.Language.OpenCL: - self.write(dace.dtypes._OCL_TYPES[inferred_type.type] + " ") - else: - self.write(dace.dtypes._CTYPES[inferred_type.type] + " ") + self.write(dace.dtypes._CTYPES[inferred_type.type] + " ") else: self.locals.define(target.id, t.lineno, self._indent) self.write("auto ") diff --git a/dace/codegen/dispatcher.py b/dace/codegen/dispatcher.py index 13fd27aeb1..86117cb0ab 100644 --- a/dace/codegen/dispatcher.py +++ b/dace/codegen/dispatcher.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. """ Contains the DaCe code generator target dispatcher, which is responsible for flexible code generation with multiple backends by dispatching certain @@ -9,7 +9,7 @@ from dace import config, data as dt, dtypes, nodes, registry from dace.memlet import Memlet from dace.codegen import exceptions as cgx, prettycode -from dace.codegen.targets import target +from dace.codegen import target from dace.sdfg import utils as sdutil, SDFG, SDFGState, ScopeSubgraphView from dace.sdfg.graph import MultiConnectorEdge from typing import Callable, Dict, List, Optional, Set, Tuple, Union @@ -28,7 +28,7 @@ class DefinedType(aenum.AutoNumberEnum): Object = () # An object moved by reference Stream = () # A stream object moved by reference and accessed via a push/pop API StreamArray = () # An array of Streams - FPGA_ShiftRegister = () # A shift-register object used in FPGA code generation + # TODO Remove ArrayInterface in subsequent PR ArrayInterface = () # An object representing an interface to an array, used mostly in FPGA diff --git a/dace/codegen/instrumentation/__init__.py b/dace/codegen/instrumentation/__init__.py index d357e1a5a3..49295b85a1 100644 --- a/dace/codegen/instrumentation/__init__.py +++ b/dace/codegen/instrumentation/__init__.py @@ -6,6 +6,6 @@ from .likwid import LIKWIDInstrumentationCPU, LIKWIDInstrumentationGPU from .timer import TimerProvider from .gpu_events import GPUEventProvider -from .fpga import FPGAInstrumentationProvider +from .gpu_tx_markers import GPUTXMarkersProvider from .data.data_dump import SaveProvider, RestoreProvider diff --git a/dace/codegen/instrumentation/data/data_dump.py b/dace/codegen/instrumentation/data/data_dump.py index 94bfa2f3bc..b05d59f0a1 100644 --- a/dace/codegen/instrumentation/data/data_dump.py +++ b/dace/codegen/instrumentation/data/data_dump.py @@ -1,9 +1,8 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. from dace import data as dt, dtypes, registry, SDFG -from dace.sdfg import nodes, is_devicelevel_gpu +from dace.sdfg import nodes from dace.codegen.prettycode import CodeIOStream from dace.codegen.instrumentation.provider import InstrumentationProvider -from dace.sdfg.scope import is_devicelevel_fpga from dace.sdfg.state import ControlFlowRegion, SDFGState from dace.codegen import common from dace.codegen import cppunparse @@ -84,14 +83,13 @@ class SaveProvider(InstrumentationProvider, DataInstrumentationProviderMixin): def __init__(self): super().__init__() self.gpu_runtime_init = False - from dace.codegen.targets.framecode import DaCeCodeGenerator # Avoid import loop - self.codegen: DaCeCodeGenerator = None + self.framecode: 'DaCeCodeGenerator' = None def on_sdfg_begin(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: CodeIOStream, codegen: 'DaCeCodeGenerator'): # Initialize serializer versioning object if sdfg.parent is None: - self.codegen = codegen + self.framecode = codegen path = os.path.abspath(os.path.join(sdfg.build_folder, 'data')).replace('\\', '/') codegen.statestruct.append('dace::DataSerializer *serializer;') sdfg.append_init_code(f'__state->serializer = new dace::DataSerializer("{path}");\n') @@ -132,10 +130,6 @@ def on_node_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream): from dace.codegen.dispatcher import DefinedType # Avoid import loop - if is_devicelevel_gpu(sdfg, state, node) or is_devicelevel_fpga(sdfg, state, node): - # Only run on host code - return - condition_preamble, condition_postamble = '', '' condition: Optional[CodeBlock] = node.instrument_condition if condition is not None and not condition.as_string == '1': @@ -153,8 +147,8 @@ def on_node_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node desc = node.desc(sdfg) # Obtain a pointer for arrays and scalars - ptrname = cpp.ptr(node.data, desc, sdfg, self.codegen) - defined_type, _ = self.codegen.dispatcher.defined_vars.get(ptrname) + ptrname = cpp.ptr(node.data, desc, sdfg, self.framecode) + defined_type, _ = self.framecode.dispatcher.defined_vars.get(ptrname) if defined_type == DefinedType.Scalar: ptrname = '&' + ptrname @@ -190,8 +184,7 @@ class RestoreProvider(InstrumentationProvider, DataInstrumentationProviderMixin) def __init__(self): super().__init__() self.gpu_runtime_init = False - from dace.codegen.targets.framecode import DaCeCodeGenerator # Avoid import loop - self.codegen: DaCeCodeGenerator = None + self.framecode: 'DaCeCodeGenerator' = None def _generate_report_setter(self, sdfg: SDFG) -> str: return f''' @@ -204,7 +197,7 @@ def on_sdfg_begin(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: C codegen: 'DaCeCodeGenerator'): # Initialize serializer versioning object if sdfg.parent is None: - self.codegen = codegen + self.framecode = codegen codegen.statestruct.append('dace::DataSerializer *serializer;') sdfg.append_init_code(f'__state->serializer = new dace::DataSerializer("");\n') @@ -248,10 +241,6 @@ def on_node_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, no outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream): from dace.codegen.dispatcher import DefinedType # Avoid import loop - if is_devicelevel_gpu(sdfg, state, node) or is_devicelevel_fpga(sdfg, state, node): - # Only run on host code - return - condition_preamble, condition_postamble = '', '' condition: Optional[CodeBlock] = node.instrument_condition if condition is not None and not condition.as_string == '1': @@ -269,8 +258,8 @@ def on_node_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, no desc = node.desc(sdfg) # Obtain a pointer for arrays and scalars - ptrname = cpp.ptr(node.data, desc, sdfg, self.codegen) - defined_type, _ = self.codegen.dispatcher.defined_vars.get(ptrname) + ptrname = cpp.ptr(node.data, desc, sdfg, self.framecode) + defined_type, _ = self.framecode.dispatcher.defined_vars.get(ptrname) if defined_type == DefinedType.Scalar: ptrname = '&' + ptrname diff --git a/dace/codegen/instrumentation/data/data_report.py b/dace/codegen/instrumentation/data/data_report.py index d944c916f3..c13fabae77 100644 --- a/dace/codegen/instrumentation/data/data_report.py +++ b/dace/codegen/instrumentation/data/data_report.py @@ -2,10 +2,14 @@ from dataclasses import dataclass import struct from typing import Any, Dict, List, Set, Tuple, Union +from numbers import Number import os from dace import dtypes, SDFG -from dace.data import ArrayLike, Number # Type hint +try: + from numpy.typing import ArrayLike +except ImportError: + ArrayLike = Any # type: ignore import numpy as np diff --git a/dace/codegen/instrumentation/fpga.py b/dace/codegen/instrumentation/fpga.py deleted file mode 100644 index b9f4fdf758..0000000000 --- a/dace/codegen/instrumentation/fpga.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from dace import dtypes, registry -from dace.codegen.instrumentation.provider import InstrumentationProvider - - -@registry.autoregister_params(type=dtypes.InstrumentationType.FPGA) -class FPGAInstrumentationProvider(InstrumentationProvider): - """Dummy provider to register the instrumentation type.""" - - def __init__(self): - super().__init__() diff --git a/dace/codegen/instrumentation/gpu_tx_markers.py b/dace/codegen/instrumentation/gpu_tx_markers.py new file mode 100644 index 0000000000..02c3454e89 --- /dev/null +++ b/dace/codegen/instrumentation/gpu_tx_markers.py @@ -0,0 +1,259 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import os +from typing import Union + +from dace import dtypes, registry +from dace.codegen import common +from dace.codegen.prettycode import CodeIOStream +from dace.codegen.instrumentation.provider import InstrumentationProvider +from dace.memlet import Memlet +from dace.sdfg import nodes, SDFG +from dace.sdfg.graph import MultiConnectorEdge +from dace.sdfg.nodes import NestedSDFG +from dace.sdfg.scope import is_devicelevel_gpu_kernel +from dace.sdfg.sdfg import SDFG +from dace.sdfg.state import ControlFlowRegion, SDFGState + + +@registry.autoregister_params(type=dtypes.InstrumentationType.GPU_TX_MARKERS) +class GPUTXMarkersProvider(InstrumentationProvider): + """ Timing instrumentation that adds NVTX/rocTX ranges to SDFGs and states. """ + NVTX_HEADER_INCLUDE = '#include ' + ROCTX_HEADER_INCLUDE = '#include ' + + def __init__(self): + self.backend = common.get_gpu_backend() + # Check if ROCm TX libraries and headers are available + rocm_path = os.getenv('ROCM_PATH', '/opt/rocm') + roctx_header_paths = [ + os.path.join(rocm_path, 'roctracer/include/roctx.h'), + os.path.join(rocm_path, 'include/roctracer/roctx.h') + ] + roctx_library_path = os.path.join(rocm_path, 'lib', 'libroctx64.so') + self.enable_rocTX = any(os.path.isfile(path) + for path in roctx_header_paths) and os.path.isfile(roctx_library_path) + self.include_generated = False + super().__init__() + + def _print_include(self, sdfg: SDFG) -> None: + """ Prints the include statement for the NVTX/rocTX library for a given SDFG. """ + if self.include_generated: + return + if self.backend == 'cuda': + sdfg.append_global_code(self.NVTX_HEADER_INCLUDE, 'frame') + elif self.backend == 'hip': + if self.enable_rocTX: + sdfg.append_global_code(self.ROCTX_HEADER_INCLUDE, 'frame') + else: + raise NameError('GPU backend "%s" not recognized' % self.backend) + self.include_generated = True + + def print_include(self, stream: CodeIOStream) -> None: + """ Prints the include statement for the NVTX/rocTX library in stream. """ + if stream is None: + return + if self.include_generated: + return + if self.backend == 'cuda': + stream.write(self.NVTX_HEADER_INCLUDE) + elif self.backend == 'hip': + if self.enable_rocTX: + stream.write(self.ROCTX_HEADER_INCLUDE) + else: + raise NameError('GPU backend "%s" not recognized' % self.backend) + self.include_generated = True + + def print_range_push(self, name: str, sdfg: SDFG, stream: CodeIOStream) -> None: + if stream is None: + return + self._print_include(sdfg) + if name is None: + name = 'None' + if self.backend == 'cuda': + stream.write(f'nvtxRangePush("{name}");') + elif self.backend == 'hip': + if self.enable_rocTX: + stream.write(f'roctxRangePush("{name}");') + else: + raise NameError(f'GPU backend "{self.backend}" not recognized') + + def print_range_pop(self, stream: CodeIOStream) -> None: + if stream is None: + return + if self.backend == 'cuda': + stream.write('nvtxRangePop();') + elif self.backend == 'hip': + if self.enable_rocTX: + stream.write('roctxRangePop();') + else: + raise NameError(f'GPU backend "{self.backend}" not recognized') + + def _is_sdfg_in_device_code(self, sdfg: SDFG) -> bool: + """ Check if the SDFG is in device code and not top level SDFG. """ + sdfg_parent_state = sdfg.parent + while sdfg_parent_state is not None: + sdfg_parent_node = sdfg.parent_nsdfg_node + if is_devicelevel_gpu_kernel(sdfg, sdfg_parent_state, sdfg_parent_node): + return True + sdfg_parent_state = sdfg_parent_state.sdfg.parent + return False + + def on_sdfg_begin(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: CodeIOStream, codegen) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + self.print_include(global_stream) + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_push(f'sdfg_{sdfg.name}', sdfg, local_stream) + + def on_sdfg_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_pop(local_stream) + + def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, + global_stream: CodeIOStream) -> None: + if state.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_push(f'state_{state.label}', sdfg, local_stream) + + def on_state_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, + global_stream: CodeIOStream) -> None: + if state.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_pop(local_stream) + + def on_copy_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, src_node: nodes.Node, + dst_node: nodes.Node, edge: MultiConnectorEdge[Memlet], local_stream: CodeIOStream, + global_stream: CodeIOStream, copy_shape, src_strides, dst_strides) -> None: + if state.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if is_devicelevel_gpu_kernel(sdfg, state, src_node) or is_devicelevel_gpu_kernel(sdfg, state, dst_node): + # Don't instrument device code + return + self.print_range_push(f'copy_{src_node.label}_to_{dst_node.label}', sdfg, local_stream) + + def on_copy_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, src_node: nodes.Node, + dst_node: nodes.Node, edge: MultiConnectorEdge[Memlet], local_stream: CodeIOStream, + global_stream: CodeIOStream) -> None: + if state.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if is_devicelevel_gpu_kernel(sdfg, state, src_node) or is_devicelevel_gpu_kernel(sdfg, state, dst_node): + # Don't instrument device code + return + self.print_range_pop(local_stream) + + def on_scope_entry(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.EntryNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + if node.map.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if is_devicelevel_gpu_kernel(sdfg, state, node): + # Don't instrument device code + return + self.print_range_push(f'scope_{node.label}', sdfg, outer_stream) + + def on_scope_exit(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.ExitNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + entry_node = state.entry_node(node) + if entry_node.map.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if is_devicelevel_gpu_kernel(sdfg, state, entry_node): + # Don't instrument device code + return + self.print_range_pop(outer_stream) + + def on_sdfg_init_begin(self, sdfg: SDFG, callsite_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + # cannot push rocTX markers before initializing HIP + if self.enable_rocTX: + return + self.print_range_push(f'init_{sdfg.name}', sdfg, callsite_stream) + + def on_sdfg_init_end(self, sdfg: SDFG, callsite_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + # cannot push rocTX markers before initializing HIP so there's no marker to pop + if self.enable_rocTX: + return + self.print_range_pop(callsite_stream) + + def on_sdfg_exit_begin(self, sdfg: SDFG, callsite_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_push(f'exit_{sdfg.name}', sdfg, callsite_stream) + + def on_sdfg_exit_end(self, sdfg: SDFG, callsite_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_pop(callsite_stream) + + def on_allocation_begin(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + # We only want to instrument allocations at the SDFG or state level + if not isinstance(scope, (SDFGState, SDFG)): + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_push(f'alloc_{sdfg.name}', sdfg, stream) + + def on_allocation_end(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + # We only want to instrument allocations at the SDFG or state level + if not isinstance(scope, (SDFGState, SDFG)): + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_pop(stream) + + def on_deallocation_begin(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + # We only want to instrument allocations at the SDFG or state level + if not isinstance(scope, (SDFGState, SDFG)): + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_push(f'dealloc_{sdfg.name}', sdfg, stream) + + def on_deallocation_end(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + # We only want to instrument allocations at the SDFG or state level + if not isinstance(scope, (SDFGState, SDFG)): + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_pop(stream) diff --git a/dace/codegen/instrumentation/papi.py b/dace/codegen/instrumentation/papi.py index e14fcade10..116559039a 100644 --- a/dace/codegen/instrumentation/papi.py +++ b/dace/codegen/instrumentation/papi.py @@ -368,7 +368,7 @@ def perf_get_supersection_start_string(node, dfg, unified_id): elif x.map.schedule == dtypes.ScheduleType.Sequential: x.map._can_be_supersection_start = False else: - # Any other type (FPGA, GPU) - not supported by PAPI. + # Any other type (e.g., GPU) - not supported by PAPI. x.map._can_be_supersection_start = False if (node.map._can_be_supersection_start and not dace.sdfg.is_parallel(dfg)): diff --git a/dace/codegen/instrumentation/provider.py b/dace/codegen/instrumentation/provider.py index a95c0495ba..4c14e7a98a 100644 --- a/dace/codegen/instrumentation/provider.py +++ b/dace/codegen/instrumentation/provider.py @@ -183,3 +183,79 @@ def on_node_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node :param global_stream: Code generator for global (external) code. """ pass + + def on_sdfg_init_begin(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + """ Event called at the beginning of SDFG initialization code generation. + + :param sdfg: The generated SDFG object. + :param local_stream: Code generator for the in-function code. + :param global_stream: Code generator for global (external) code. + """ + pass + + def on_sdfg_init_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + """ Event called at the end of SDFG initialization code generation. + + :param sdfg: The generated SDFG object. + :param local_stream: Code generator for the in-function code. + :param global_stream: Code generator for global (external) code. + """ + pass + + def on_sdfg_exit_begin(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + """ Event called at the beginning of SDFG exit code generation. + + :param sdfg: The generated SDFG object. + :param local_stream: Code generator for the in-function code. + :param global_stream: Code generator for global (external) code. + """ + pass + + def on_sdfg_exit_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + """ Event called at the end of SDFG exit code generation. + + :param sdfg: The generated SDFG object. + :param local_stream: Code generator for the in-function code. + :param global_stream: Code generator for global (external) code. + """ + pass + + def on_allocation_begin(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + stream: CodeIOStream) -> None: + """ Event called at the beginning of an allocation code generation. + + :param sdfg: The generated SDFG object. + :param scope: The scope in which allocation is performed. + :param stream: Code generator. + """ + pass + + def on_allocation_end(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + lstream: CodeIOStream) -> None: + """ Event called at the end of an allocation code generation. + + :param sdfg: The generated SDFG object. + :param scope: The scope in which allocation is performed. + :param local_stream: Code generator. + """ + pass + + def on_deallocation_begin(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + stream: CodeIOStream) -> None: + """ Event called at the beginning of a deallocation code generation. + + :param sdfg: The generated SDFG object. + :param scope: The scope in which deallocation is performed. + :param local_stream: Code generator. + """ + pass + + def on_deallocation_end(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + lstream: CodeIOStream) -> None: + """ Event called at the end of a deallocation code generation. + + :param sdfg: The generated SDFG object. + :param scope: The scope in which deallocation is performed. + :param local_stream: Code generator. + """ + pass diff --git a/dace/codegen/targets/target.py b/dace/codegen/target.py similarity index 84% rename from dace/codegen/targets/target.py rename to dace/codegen/target.py index fb9c415438..f52440e2b2 100644 --- a/dace/codegen/targets/target.py +++ b/dace/codegen/target.py @@ -1,16 +1,19 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. import os import shutil # which -from typing import List +from typing import List, TYPE_CHECKING import warnings -from dace import memlet as mm, data as dt +from dace import memlet as mm, data as dt, dtypes, subsets from dace.sdfg import nodes, SDFG, SDFGState, ScopeSubgraphView, graph as gr from dace.registry import make_registry from dace.codegen.prettycode import CodeIOStream from dace.codegen.codeobject import CodeObject from dace.sdfg.state import ControlFlowRegion +if TYPE_CHECKING: + from dace.codegen.targets.framecode import DaCeCodeGenerator + @make_registry class TargetCodeGenerator(object): @@ -39,6 +42,14 @@ def cmake_options() -> List[str]: """ return [] + @staticmethod + def cmake_files() -> List[str]: + """ + Returns a list of CMake file paths that should be included + during the CMake configuration step. + """ + return [] + def preprocess(self, sdfg: SDFG) -> None: """ Called before code generation on any target that will be dispatched. @@ -50,6 +61,23 @@ def preprocess(self, sdfg: SDFG) -> None: """ pass + def get_framecode_generator(self) -> 'DaCeCodeGenerator': + """ + Returns the frame-code generator associated with this target. + + :return: The frame-code generator. + """ + return self._frame + + def get_includes(self) -> dict[str, list[str]]: + """ + Returns a dictionary mapping backends to lists of include files + required by this target. + + :return: A dictionary of backend names to lists of include files. + """ + return {} + @property def has_initializer(self) -> bool: """ Returns True if the target generates a `__dace_init_` @@ -196,6 +224,29 @@ def copy_memory(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: SDFGState, state_ """ raise NotImplementedError('Abstract class') + def emit_interstate_variable_declaration(self, name: str, dtype: dtypes.typeclass, callsite_stream: CodeIOStream, + sdfg: SDFG): + """ Emits the declaration of an interstate variable at the given + call-site. + + :param name: The name of the variable. + :param dtype: The data type of the variable. + :param callsite_stream: A ``CodeIOStream`` object that points + to the current location (call-site) + in the code. + :param sdfg: The SDFG in which the variable is declared. + """ + raise NotImplementedError('Abstract class') + + def adjust_subset_for_codegen(self, nodedesc: dt.Data, subset: subsets.Subset) -> subsets.Subset: + """ Adjusts a memlet subset for code generation, if necessary. + + :param subset: The original subset. + :param nodedesc: The data descriptor the subset applies to. + :return: The adjusted subset. + """ + return subset + class IllegalCopy(TargetCodeGenerator): """ A code generator that is triggered when invalid copies are specified diff --git a/dace/codegen/targets/__init__.py b/dace/codegen/targets/__init__.py index cd4d5f957f..e101ea3988 100644 --- a/dace/codegen/targets/__init__.py +++ b/dace/codegen/targets/__init__.py @@ -1,11 +1,7 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. from .cpu import CPUCodeGen from .cuda import CUDACodeGen -from .intel_fpga import IntelFPGACodeGen from .mpi import MPICodeGen -from .xilinx import XilinxCodeGen -from .rtl import RTLCodeGen -from .unroller import UnrollCodeGen from .mlir.mlir import MLIRCodeGen from .sve.codegen import SVECodeGen from .snitch import SnitchCodeGen diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index 7871962cad..bfc4835348 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. """ Helper functions for C++ code generation. NOTE: The C++ code generator is currently located in cpu.py. @@ -29,11 +29,12 @@ from dace.sdfg import nodes, graph as gr, utils, propagation from dace.properties import LambdaProperty from dace.sdfg import SDFG, is_devicelevel_gpu, SDFGState -from dace.codegen.targets import fpga from dace.sdfg.state import ControlFlowRegion, StateSubgraphView if TYPE_CHECKING: from dace.codegen.dispatcher import TargetDispatcher + from dace.codegen.target import TargetCodeGenerator + from dace.codegen.targets.framecode import DaCeCodeGenerator def mangle_dace_state_struct_name(sdfg: Union[SDFG, str]) -> str: @@ -60,7 +61,7 @@ def copy_expr( is_write=None, # Otherwise it's a read offset=None, relative_offset=True, - packed_types=False, + name_override=None, ): data_desc = sdfg.arrays[data_name] # NOTE: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs? @@ -71,6 +72,8 @@ def copy_expr( else: name = data_name ptrname = ptr(data_name, data_desc, sdfg, dispatcher.frame) + if name_override is not None: + ptrname = name_override if relative_offset: s = memlet.subset o = offset @@ -104,39 +107,18 @@ def copy_expr( if not defined_types: defined_types = dispatcher.defined_vars.get(ptrname, is_global=is_global) def_type, _ = defined_types - if fpga.is_fpga_array(data_desc): - # get conf flag - decouple_array_interfaces = Config.get_bool("compiler", "xilinx", "decouple_array_interfaces") - - # TODO: Study structures on FPGAs. Should probably use 'name' instead of 'data_name' here. - expr = fpga.fpga_ptr( - data_name, - data_desc, - sdfg, - s, - is_write, - dispatcher, - 0, - def_type == DefinedType.ArrayInterface - # If this is a view, it has already been renamed - and not isinstance(data_desc, data.View), - decouple_array_interfaces=decouple_array_interfaces) + if name_override is not None: + expr = name_override else: expr = ptr(name, data_desc, sdfg, dispatcher.frame) add_offset = offset_cppstr != "0" - if def_type in [DefinedType.Pointer, DefinedType.ArrayInterface]: + if def_type in [DefinedType.Pointer, DefinedType.ArrayInterface, DefinedType.Object]: return "{}{}{}".format(dt, expr, " + {}".format(offset_cppstr) if add_offset else "") - elif def_type == DefinedType.StreamArray: return "{}[{}]".format(expr, offset_cppstr) - - elif def_type == DefinedType.FPGA_ShiftRegister: - return expr - - elif def_type in [DefinedType.Scalar, DefinedType.Stream, DefinedType.Object]: - + elif def_type in [DefinedType.Scalar, DefinedType.Stream]: if add_offset: raise TypeError("Tried to offset address of scalar {}: {}".format(data_name, offset_cppstr)) @@ -145,8 +127,7 @@ def copy_expr( else: return data_name else: - raise NotImplementedError("copy_expr not implemented " - "for connector type: {}".format(def_type)) + return expr def memlet_copy_to_absolute_strides(dispatcher: 'TargetDispatcher', @@ -155,7 +136,9 @@ def memlet_copy_to_absolute_strides(dispatcher: 'TargetDispatcher', edge: gr.MultiConnectorEdge[mmlt.Memlet], src_node: nodes.AccessNode, dst_node: nodes.AccessNode, - packed_types: bool = False): + src_name_override: Optional[str] = None, + dst_name_override: Optional[str] = None, + codegen: 'TargetCodeGenerator' = None): memlet = edge.data copy_shape = memlet.subset.size_exact() src_nodedesc = src_node.desc(sdfg) @@ -167,6 +150,10 @@ def memlet_copy_to_absolute_strides(dispatcher: 'TargetDispatcher', dst_subset = memlet.get_dst_subset(edge, state) is_src_write = not memlet._is_data_src + if codegen is not None: + src_subset = codegen.adjust_subset_for_codegen(src_nodedesc, src_subset) + dst_subset = codegen.adjust_subset_for_codegen(dst_nodedesc, dst_subset) + if dispatcher is not None: src_expr = copy_expr(dispatcher, sdfg, @@ -175,7 +162,7 @@ def memlet_copy_to_absolute_strides(dispatcher: 'TargetDispatcher', is_write=is_src_write, offset=src_subset, relative_offset=False, - packed_types=packed_types) + name_override=src_name_override) dst_expr = copy_expr(dispatcher, sdfg, dst_node.data, @@ -183,7 +170,7 @@ def memlet_copy_to_absolute_strides(dispatcher: 'TargetDispatcher', is_write=(not is_src_write), offset=dst_subset, relative_offset=False, - packed_types=packed_types) + name_override=dst_name_override) if src_subset is None: src_subset = subsets.Range.from_array(src_nodedesc) if dst_subset is None: @@ -251,16 +238,16 @@ def is_cuda_codegen_in_device(framecode) -> bool: return cuda_codegen_in_device -def ptr(name: str, desc: data.Data, sdfg: SDFG = None, framecode=None) -> str: +def ptr(name: str, desc: data.Data, sdfg: SDFG = None, framecode: 'DaCeCodeGenerator' = None) -> str: """ Returns a string that points to the data based on its name and descriptor. :param name: Data name. :param desc: Data descriptor. + :param sdfg: SDFG in which the data resides. + :param framecode: Frame-code generator object. :return: C-compatible name that can be used to access the data. """ - from dace.codegen.targets.framecode import DaCeCodeGenerator # Avoid import loop - framecode: DaCeCodeGenerator = framecode if '.' in name: root = name.split('.')[0] @@ -290,25 +277,21 @@ def emit_memlet_reference(dispatcher: 'TargetDispatcher', memlet: mmlt.Memlet, pointer_name: str, conntype: dtypes.typeclass, + codegen: 'TargetCodeGenerator', ancestor: int = 1, - is_write: bool = None, - device_code: bool = False, - decouple_array_interfaces: bool = False) -> Tuple[str, str, str]: + is_write: bool = None) -> Tuple[str, str, str]: """ Returns a tuple of three strings with a definition of a reference to an existing memlet. Used in nested SDFG arguments. - :param device_code: boolean flag indicating whether we are in the process of generating FPGA device code - :param decouple_array_interfaces: boolean flag, used for Xilinx FPGA code generation. It indicates whether or not - we are generating code by decoupling reads/write from memory. :return: A tuple of the form (type, name, value). """ desc = sdfg.arrays[memlet.data] typedef = conntype.ctype - offset = cpp_offset_expr(desc, memlet.subset) + offset = cpp_offset_expr(desc, memlet.subset, codegen=codegen) offset_expr = '[' + offset + ']' - is_scalar = not isinstance(conntype, dtypes.pointer) and not fpga.is_fpga_array(desc) - ptrname = ptr(memlet.data, desc, sdfg, dispatcher.frame) + is_scalar = not isinstance(conntype, dtypes.pointer) + ptrname = codegen.ptr(memlet.data, desc, sdfg, subset=memlet.subset, ancestor=ancestor, is_write=is_write) ref = '' # Get defined type (pointer, stream etc.) and change the type definition @@ -325,20 +308,7 @@ def emit_memlet_reference(dispatcher: 'TargetDispatcher', defined_types = dispatcher.defined_vars.get(ptrname, ancestor) defined_type, defined_ctype = defined_types - if fpga.is_fpga_array(desc): - - datadef = fpga.fpga_ptr(memlet.data, - desc, - sdfg, - memlet.subset, - is_write, - dispatcher, - ancestor, - defined_type == DefinedType.ArrayInterface, - decouple_array_interfaces=decouple_array_interfaces) - - else: - datadef = ptr(memlet.data, desc, sdfg, dispatcher.frame) + datadef = ptrname def make_const(expr: str) -> str: # check whether const has already been added before @@ -395,25 +365,11 @@ def make_const(expr: str) -> str: ref = '' typedef = defined_ctype defined_type = DefinedType.StreamArray - elif defined_type == DefinedType.FPGA_ShiftRegister: - ref = '&' if is_scalar else '' - defined_type = DefinedType.Pointer else: raise TypeError('Unsupported memlet type "%s"' % defined_type.name) - if (not device_code and defined_type != DefinedType.ArrayInterface and desc.storage == dace.StorageType.FPGA_Global - and not isinstance(desc, dace.data.Scalar)): - # This is a device buffer accessed on the host. - # Can not be accessed with offset different than zero. Check this if we can: - if (isinstance(offset, int) and int(offset) != 0) or (isinstance(offset, str) and offset.isnumeric() - and int(offset) != 0): - raise TypeError("Can not offset device buffers from host code ({}, offset {})".format(datadef, offset)) - # Device buffers are passed by reference - expr = datadef - ref = '&' - else: - # Cast as necessary - expr = make_ptr_vector_cast(datadef + offset_expr, desc.dtype, conntype, is_scalar, defined_type) + # Cast as necessary + expr = codegen.make_ptr_vector_cast(datadef + offset_expr, desc.dtype, conntype, is_scalar, defined_type) # Register defined variable dispatcher.defined_vars.add(pointer_name, defined_type, typedef, allow_shadowing=True) @@ -479,7 +435,8 @@ def ndcopy_to_strided_copy( """ # Cannot degenerate tiled copies - if any(ts != 1 for ts in subset.tile_sizes): + # In the case where subset is of type Indices, there are no tile_sizes + if hasattr(subset, 'tile_sizes') and any(ts != 1 for ts in subset.tile_sizes): return None # If the copy is contiguous, the difference between the first and last @@ -524,6 +481,7 @@ def ndcopy_to_strided_copy( if copy_shape == src_copy_shape: srcdim = copydim else: + # TODO: Remove try-except in subsequent FPGA PR try: srcdim = next(i for i, c in enumerate(src_copy_shape) if c != 1) except StopIteration: @@ -539,6 +497,7 @@ def ndcopy_to_strided_copy( if copy_shape == dst_copy_shape: dstdim = copydim else: + # TODO: Remove try-except in subsequent FPGA PR try: dstdim = next(i for i, c in enumerate(dst_copy_shape) if c != 1) except StopIteration: @@ -555,7 +514,12 @@ def ndcopy_to_strided_copy( return None -def cpp_offset_expr(d: data.Data, subset_in: subsets.Subset, offset=None, packed_veclen=1, indices=None): +def cpp_offset_expr(d: data.Data, + subset_in: subsets.Subset, + offset=None, + packed_veclen=1, + indices=None, + codegen: Optional['TargetCodeGenerator'] = None) -> str: """ Creates a C++ expression that can be added to a pointer in order to offset it to the beginning of the given subset and offset. @@ -566,10 +530,12 @@ def cpp_offset_expr(d: data.Data, subset_in: subsets.Subset, offset=None, packed vector length that the final offset should be divided by. :param indices: A tuple of indices to use for expression. + :param codegen: Optional code generator to adjust subset. :return: A string in C++ syntax with the correct offset """ - if fpga.is_multibank_array_with_distributed_index(d): - subset_in = fpga.modify_distributed_subset(subset_in, 0) + # Offset according to code generator + if codegen is not None: + subset_in = codegen.adjust_subset_for_codegen(d, subset_in) # Offset according to parameters, then offset according to array if offset is not None: @@ -597,13 +563,14 @@ def cpp_array_expr(sdfg, use_other_subset=False, indices=None, referenced_array=None, - codegen=None): + codegen: Optional['TargetCodeGenerator'] = None, + framecode: Optional['DaCeCodeGenerator'] = None): """ Converts an Indices/Range object to a C++ array access string. """ subset = memlet.subset if not use_other_subset else memlet.other_subset s = subset if relative_offset else subsets.Indices(offset) o = offset if relative_offset else None desc = (sdfg.arrays[memlet.data] if referenced_array is None else referenced_array) - offset_cppstr = cpp_offset_expr(desc, s, o, packed_veclen, indices=indices) + offset_cppstr = cpp_offset_expr(desc, s, o, packed_veclen, indices=indices, codegen=codegen) # NOTE: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs? # TODO: Study this when changing Structures to be (optionally?) non-pointers. @@ -614,17 +581,10 @@ def cpp_array_expr(sdfg, name = memlet.data if with_brackets: - if fpga.is_fpga_array(desc): - # get conf flag - decouple_array_interfaces = Config.get_bool("compiler", "xilinx", "decouple_array_interfaces") - # TODO: Study structures on FPGAs. Should probably use 'name' instead of 'memlet.data' here. - ptrname = fpga.fpga_ptr(memlet.data, - desc, - sdfg, - subset, - decouple_array_interfaces=decouple_array_interfaces) + if codegen is not None: + ptrname = codegen.ptr(name, desc, sdfg, memlet.subset) else: - ptrname = ptr(name, desc, sdfg, codegen) + ptrname = ptr(name, desc, sdfg, framecode=framecode) return "%s[%s]" % (ptrname, offset_cppstr) else: return offset_cppstr @@ -654,7 +614,7 @@ def cpp_ptr_expr(sdfg, use_other_subset=False, indices=None, is_write=None, - codegen=None, + codegen: 'TargetCodeGenerator' = None, decouple_array_interface=False): """ Converts a memlet to a C++ pointer expression. """ subset = memlet.subset if not use_other_subset else memlet.other_subset @@ -664,19 +624,8 @@ def cpp_ptr_expr(sdfg, if isinstance(indices, str): offset_cppstr = indices else: - offset_cppstr = cpp_offset_expr(desc, s, o, indices=indices) - if fpga.is_fpga_array(desc): - dname = fpga.fpga_ptr(memlet.data, - desc, - sdfg, - s, - is_write, - None, - None, - defined_type == DefinedType.ArrayInterface, - decouple_array_interfaces=decouple_array_interface) - else: - dname = ptr(memlet.data, desc, sdfg, codegen) + offset_cppstr = cpp_offset_expr(desc, s, o, indices=indices, codegen=codegen) + dname = codegen.ptr(memlet.data, desc, sdfg, memlet.subset) if defined_type == DefinedType.Scalar: dname = '&' + dname @@ -1069,9 +1018,9 @@ class InterstateEdgeUnparser(cppunparse.CPPUnparser): inter-state edge code generation. """ - def __init__(self, sdfg: SDFG, tree: ast.AST, file: IO[str], defined_symbols=None, codegen=None): + def __init__(self, sdfg: SDFG, tree: ast.AST, file: IO[str], defined_symbols=None, framecode=None): self.sdfg = sdfg - self.codegen = codegen + self.framecode = framecode super().__init__(tree, 0, cppunparse.CPPLocals(), file, expr_semicolon=False, defined_symbols=defined_symbols) def _Name(self, t: ast.Name): @@ -1081,7 +1030,7 @@ def _Name(self, t: ast.Name): # Replace values with their code-generated names (for example, persistent arrays) desc = self.sdfg.arrays[t.id] ref = '' if not isinstance(desc, data.View) else '*' - self.write(ref + ptr(t.id, desc, self.sdfg, self.codegen)) + self.write(ref + ptr(t.id, desc, self.sdfg, self.framecode)) def _Attribute(self, t: ast.Attribute): from dace.frontend.python.astutils import rname @@ -1091,7 +1040,7 @@ def _Attribute(self, t: ast.Attribute): # Replace values with their code-generated names (for example, persistent arrays) desc = self.sdfg.arrays[name] - self.write(ptr(name, desc, self.sdfg, self.codegen)) + self.write(ptr(name, desc, self.sdfg, self.framecode)) def _Subscript(self, t: ast.Subscript): from dace.frontend.python.astutils import subscript_to_slice @@ -1101,14 +1050,7 @@ def _Subscript(self, t: ast.Subscript): raise SyntaxError('Range subscripts disallowed in interstate edges') memlet = mmlt.Memlet(data=target, subset=rng) - - if target not in self.sdfg.arrays: - # This could be an FPGA array whose name has been mangled - unqualified = fpga.unqualify_fpga_array_name(self.sdfg, target) - desc = self.sdfg.arrays[unqualified] - self.write(cpp_array_expr(self.sdfg, memlet, referenced_array=desc, codegen=self.codegen)) - else: - self.write(cpp_array_expr(self.sdfg, memlet, codegen=self.codegen)) + self.write(cpp_array_expr(self.sdfg, memlet, framecode=self.framecode)) class DaCeKeywordRemover(ExtNodeTransformer): @@ -1128,7 +1070,6 @@ def __init__(self, sdfg, memlets, constants, codegen): self.constants = constants self.codegen = codegen self.allow_casts = True - self._decouple_array_interfaces = Config.get_bool("compiler", "xilinx", "decouple_array_interfaces") def visit_TopLevelExpr(self, node): # This is a DaCe shift, omit it @@ -1235,7 +1176,7 @@ def visit_Assign(self, node): return node elif isinstance(desc, data.Stream): if desc.is_stream_array(): - index = cpp_offset_expr(desc, memlet.subset) + index = cpp_offset_expr(desc, memlet.subset, codegen=self.codegen) target = f"{ptrname}[{index}]" else: target = ptrname @@ -1252,25 +1193,16 @@ def visit_Assign(self, node): )) elif (var_type != DefinedType.ArrayInterface or isinstance(desc, data.View)): newnode = ast.Name(id="%s = %s;" % ( - cpp_array_expr(self.sdfg, memlet, codegen=self.codegen._frame), + cpp_array_expr(self.sdfg, memlet, codegen=self.codegen), cppunparse.cppunparse(value, expr_semicolon=False), )) else: - array_interface_name = fpga.fpga_ptr( - ptrname, - desc, - self.sdfg, - memlet.dst_subset, - True, - None, - None, - True, - decouple_array_interfaces=self._decouple_array_interfaces) + array_interface_name = self.codegen.ptr(ptrname, desc, self.sdfg, memlet.dst_subset, True, + None, None, True) newnode = ast.Name( id=f"{array_interface_name}" f"[{cpp_array_expr(self.sdfg, memlet, with_brackets=False, codegen=self.codegen._frame)}]" f" = {cppunparse.cppunparse(value, expr_semicolon=False)};") - return self._replace_assignment(newnode, node) except TypeError: # cannot determine truth value of Relational pass diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index ef9b42fe33..82bd7da280 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. from copy import deepcopy from dace.sdfg.graph import MultiConnectorEdge from dace.sdfg.state import ControlFlowRegion, SDFGState, StateSubgraphView @@ -11,7 +11,7 @@ from dace.codegen.prettycode import CodeIOStream from dace.codegen.targets import cpp from dace.codegen.common import codeblock_to_cpp, sym2cpp, update_persistent_desc -from dace.codegen.targets.target import TargetCodeGenerator, make_absolute +from dace.codegen.target import TargetCodeGenerator, make_absolute from dace.codegen.dispatcher import DefinedType, TargetDispatcher from dace.frontend import operations from dace.sdfg import nodes, utils as sdutils @@ -20,7 +20,6 @@ from dace.sdfg.scope import is_devicelevel_gpu, is_in_scope from dace.sdfg.validation import validate_memlet_data from typing import TYPE_CHECKING, Optional, Tuple, Union -from dace.codegen.targets import fpga if TYPE_CHECKING: from dace.codegen.targets.framecode import DaCeCodeGenerator @@ -93,11 +92,6 @@ def __init__(self, frame_codegen: 'DaCeCodeGenerator', sdfg: SDFG): # Keep nested SDFG schedule when descending into it self._toplevel_schedule = None - # FIXME: this allows other code generators to change the CPU - # behavior to assume that arrays point to packed types, thus dividing - # all addresess by the vector length. - self._packed_types = False - # Keep track of traversed nodes self._generated_nodes = set() @@ -194,7 +188,7 @@ def allocate_view(self, name = node.data nodedesc = node.desc(sdfg) - ptrname = cpp.ptr(name, nodedesc, sdfg, self._frame) + ptrname = self.ptr(name, nodedesc, sdfg) # Check if array is already declared declared = self._dispatcher.declared_arrays.has(ptrname) @@ -227,14 +221,14 @@ def allocate_view(self, memlet, name, dtypes.pointer(nodedesc.dtype), + codegen=self, ancestor=0, - is_write=is_write, - decouple_array_interfaces=decouple_array_interfaces) + is_write=is_write) # Test for views of container arrays and structs if isinstance(sdfg.arrays[viewed_dnode.data], (data.Structure, data.ContainerArray, data.ContainerView)): vdesc = sdfg.arrays[viewed_dnode.data] - ptrname = cpp.ptr(memlet.data, vdesc, sdfg, self._dispatcher.frame) + ptrname = self.ptr(memlet.data, vdesc, sdfg) field_name = None if is_write and mpath[-1].dst_conn: field_name = mpath[-1].dst_conn @@ -243,12 +237,12 @@ def allocate_view(self, # Plain view into a container array if isinstance(vdesc, data.ContainerArray) and not isinstance(vdesc.stype, data.Structure): - offset = cpp.cpp_offset_expr(vdesc, memlet.subset) + offset = cpp.cpp_offset_expr(vdesc, memlet.subset, codegen=self) value = f'{ptrname}[{offset}]' else: if field_name is not None: if isinstance(vdesc, data.ContainerArray): - offset = cpp.cpp_offset_expr(vdesc, memlet.subset) + offset = cpp.cpp_offset_expr(vdesc, memlet.subset, codegen=self) arrexpr = f'{ptrname}[{offset}]' stype = vdesc.stype else: @@ -282,7 +276,7 @@ def allocate_reference(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: SDFGState, allocation_stream: CodeIOStream) -> None: name = node.data nodedesc = node.desc(sdfg) - ptrname = cpp.ptr(name, nodedesc, sdfg, self._frame) + ptrname = self.ptr(name, nodedesc, sdfg) # Check if reference is already declared declared = self._dispatcher.declared_arrays.has(ptrname) @@ -305,7 +299,7 @@ def declare_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphVi "that must have their declaration and allocation separate.") name = node.root_data - ptrname = cpp.ptr(name, nodedesc, sdfg, self._frame) + ptrname = self.ptr(name, nodedesc, sdfg) if nodedesc.transient is False: return @@ -351,7 +345,7 @@ def allocate_array(self, declaration_stream: CodeIOStream, allocation_stream: CodeIOStream, allocate_nested_data: bool = True) -> None: - alloc_name = cpp.ptr(node.data, nodedesc, sdfg, self._frame) + alloc_name = self.ptr(node.data, nodedesc, sdfg) name = alloc_name tokens = node.data.split('.') @@ -372,7 +366,7 @@ def allocate_array(self, if len(tokens) > 1: for i in range(len(tokens) - 1): tmp_name = '.'.join(tokens[:i + 1]) - tmp_alloc_name = cpp.ptr(tmp_name, sdfg.arrays[tmp_name], sdfg, self._frame) + tmp_alloc_name = self.ptr(tmp_name, sdfg.arrays[tmp_name], sdfg) if not self._dispatcher.defined_vars.has(tmp_alloc_name): self.allocate_array(sdfg, cfg, @@ -447,11 +441,7 @@ def allocate_array(self, self.allocate_array(sdfg, cfg, dfg, state_id, memlet_path[-1].dst, memlet_path[-1].dst.desc(sdfg), function_stream, declaration_stream, allocation_stream) - array_expr = cpp.copy_expr(self._dispatcher, - sdfg, - nodedesc.sink, - edges[0].data, - packed_types=self._packed_types) + array_expr = cpp.copy_expr(self._dispatcher, sdfg, nodedesc.sink, edges[0].data) threadlocal = "" threadlocal_stores = [dtypes.StorageType.CPU_ThreadLocal, dtypes.StorageType.Register] if (sdfg.arrays[nodedesc.sink].storage in threadlocal_stores or nodedesc.storage in threadlocal_stores): @@ -577,7 +567,7 @@ def deallocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgrap if not isinstance(nodedesc.dtype, dtypes.opaque): arrsize_bytes = arrsize * nodedesc.dtype.bytes - alloc_name = cpp.ptr(node.data, nodedesc, sdfg, self._frame) + alloc_name = self.ptr(node.data, nodedesc, sdfg) if isinstance(nodedesc, data.Array) and nodedesc.start_offset != 0: alloc_name = f'({alloc_name} - {cpp.sym2cpp(nodedesc.start_offset)})' @@ -718,7 +708,7 @@ def _emit_copy( dst_nodedesc = dst_node.desc(sdfg) if write: - vconn = cpp.ptr(dst_node.data, dst_nodedesc, sdfg, self._frame) + vconn = self.ptr(dst_node.data, dst_nodedesc, sdfg) ctype = dst_nodedesc.dtype.ctype ############################################# @@ -726,10 +716,10 @@ def _emit_copy( # Setting a reference if isinstance(dst_nodedesc, data.Reference) and orig_vconn == 'set': - srcptr = cpp.ptr(src_node.data, src_nodedesc, sdfg, self._frame) + srcptr = self.ptr(src_node.data, src_nodedesc, sdfg) defined_type, _ = self._dispatcher.defined_vars.get(srcptr) stream.write( - "%s = %s;" % (vconn, cpp.cpp_ptr_expr(sdfg, memlet, defined_type)), + "%s = %s;" % (vconn, cpp.cpp_ptr_expr(sdfg, memlet, defined_type, codegen=self)), cfg, state_id, [src_node, dst_node], @@ -756,14 +746,12 @@ def _emit_copy( if memlet.data != src_node.data and memlet.other_subset: stream_subset = memlet.other_subset - stream_expr = cpp.cpp_offset_expr(src_nodedesc, stream_subset) - array_expr = cpp.cpp_offset_expr(dst_nodedesc, array_subset) + stream_expr = cpp.cpp_offset_expr(src_nodedesc, stream_subset, codegen=self) + array_expr = cpp.cpp_offset_expr(dst_nodedesc, array_subset, codegen=self) assert functools.reduce(lambda a, b: a * b, src_nodedesc.shape, 1) == 1 stream.write( - "{s}.pop(&{arr}[{aexpr}], {maxsize});".format(s=cpp.ptr(src_node.data, src_nodedesc, sdfg, - self._frame), - arr=cpp.ptr(dst_node.data, dst_nodedesc, sdfg, - self._frame), + "{s}.pop(&{arr}[{aexpr}], {maxsize});".format(s=self.ptr(src_node.data, src_nodedesc, sdfg), + arr=self.ptr(dst_node.data, dst_nodedesc, sdfg), aexpr=array_expr, maxsize=cpp.sym2cpp(array_subset.num_elements())), cfg, @@ -775,17 +763,17 @@ def _emit_copy( if isinstance(src_nodedesc, (data.Scalar, data.Array)) and isinstance(dst_nodedesc, data.Stream): if isinstance(src_nodedesc, data.Scalar): stream.write( - "{s}.push({arr});".format(s=cpp.ptr(dst_node.data, dst_nodedesc, sdfg, self._frame), - arr=cpp.ptr(src_node.data, src_nodedesc, sdfg, self._frame)), + "{s}.push({arr});".format(s=self.ptr(dst_node.data, dst_nodedesc, sdfg), + arr=self.ptr(src_node.data, src_nodedesc, sdfg)), cfg, state_id, [src_node, dst_node], ) elif hasattr(src_nodedesc, "src"): # ArrayStreamView stream.write( - "{s}.push({arr});".format(s=cpp.ptr(dst_node.data, dst_nodedesc, sdfg, self._frame), - arr=cpp.ptr(src_nodedesc.src, sdfg.arrays[src_nodedesc.src], sdfg, - self._frame)), + "{s}.push({arr});".format(s=self.ptr(dst_node.data, dst_nodedesc, sdfg), + arr=self.ptr(src_nodedesc.src, sdfg.arrays[src_nodedesc.src], + sdfg)), cfg, state_id, [src_node, dst_node], @@ -793,9 +781,8 @@ def _emit_copy( else: copysize = " * ".join([cpp.sym2cpp(s) for s in memlet.subset.size()]) stream.write( - "{s}.push({arr}, {size});".format(s=cpp.ptr(dst_node.data, dst_nodedesc, sdfg, self._frame), - arr=cpp.ptr(src_node.data, src_nodedesc, sdfg, - self._frame), + "{s}.push({arr}, {size});".format(s=self.ptr(dst_node.data, dst_nodedesc, sdfg), + arr=self.ptr(src_node.data, src_nodedesc, sdfg), size=copysize), cfg, state_id, @@ -811,7 +798,7 @@ def _emit_copy( state_dfg: SDFGState = cfg.nodes()[state_id] copy_shape, src_strides, dst_strides, src_expr, dst_expr = cpp.memlet_copy_to_absolute_strides( - self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node, self._packed_types) + self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node, codegen=self) # Which numbers to include in the variable argument part dynshape, dynsrc, dyndst = 1, 1, 1 @@ -949,12 +936,12 @@ def write_and_resolve_expr(self, redtype = operations.detect_reduction_type(memlet.wcr) atomic = "_atomic" if not nc else "" - ptrname = cpp.ptr(memlet.data, sdfg.arrays[memlet.data], sdfg, self._frame) + ptrname = self.ptr(memlet.data, sdfg.arrays[memlet.data], sdfg) defined_type, _ = self._dispatcher.defined_vars.get(ptrname) if isinstance(indices, str): - ptr = '%s + %s' % (cpp.cpp_ptr_expr(sdfg, memlet, defined_type, codegen=self._frame), indices) + ptr = '%s + %s' % (cpp.cpp_ptr_expr(sdfg, memlet, defined_type, codegen=self), indices) else: - ptr = cpp.cpp_ptr_expr(sdfg, memlet, defined_type, indices=indices, codegen=self._frame) + ptr = cpp.cpp_ptr_expr(sdfg, memlet, defined_type, indices=indices, codegen=self) if isinstance(dtype, dtypes.pointer): dtype = dtype.base_type @@ -972,7 +959,6 @@ def write_and_resolve_expr(self, vec_prefix = 'v' vec_suffix = f'<{dtype.veclen}>' dtype = dtype.base_type - func = f'{vec_prefix}reduce{atomic}{vec_suffix}' # Special call for detected reduction types @@ -1064,10 +1050,10 @@ def process_out_memlets(self, out_local_name = " __" + uconn in_local_name = uconn if not locals_defined: - out_local_name = self.memlet_ctor(sdfg, memlet, node.out_connectors[uconn], True) + out_local_name = codegen.memlet_ctor(sdfg, memlet, node.out_connectors[uconn], True) in_memlets = [d for _, _, _, _, d in dfg.in_edges(node)] assert len(in_memlets) == 1 - in_local_name = self.memlet_ctor(sdfg, in_memlets[0], node.out_connectors[uconn], False) + in_local_name = codegen.memlet_ctor(sdfg, in_memlets[0], node.out_connectors[uconn], False) if memlet.wcr is not None: nc = not cpp.is_write_conflicted(dfg, edge, sdfg_schedule=self._toplevel_schedule) @@ -1079,7 +1065,7 @@ def process_out_memlets(self, # which we skip since the memlets are references continue desc = sdfg.arrays[memlet.data] - ptrname = cpp.ptr(memlet.data, desc, sdfg, self._frame) + ptrname = codegen.ptr(memlet.data, desc, sdfg) is_global = desc.lifetime in (dtypes.AllocationLifetime.Global, dtypes.AllocationLifetime.Persistent, dtypes.AllocationLifetime.External) @@ -1089,37 +1075,14 @@ def process_out_memlets(self, defined_type, _ = self._dispatcher.defined_vars.get(ptrname, is_global=is_global) if defined_type == DefinedType.Scalar: - mname = cpp.ptr(memlet.data, desc, sdfg, self._frame) + mname = codegen.ptr(memlet.data, desc, sdfg) write_expr = f"{mname} = {in_local_name};" elif defined_type == DefinedType.Pointer and is_refset: - mname = cpp.ptr(memlet.data, desc, sdfg, self._frame) + mname = codegen.ptr(memlet.data, desc, sdfg) write_expr = f"{mname} = {in_local_name};" - elif (defined_type == DefinedType.ArrayInterface and not isinstance(desc, data.View)): - # Special case: No need to write anything between - # array interfaces going out - try: - deftype, _ = self._dispatcher.defined_vars.get(in_local_name) - except KeyError: - deftype = None - if deftype == DefinedType.ArrayInterface: - continue - array_expr = cpp.cpp_array_expr(sdfg, memlet, with_brackets=False, codegen=self._frame) - decouple_array_interfaces = Config.get_bool("compiler", "xilinx", - "decouple_array_interfaces") - ptr_str = fpga.fpga_ptr( # we are on fpga, since this is array interface - memlet.data, - desc, - sdfg, - memlet.subset, - True, - None, - None, - True, - decouple_array_interfaces=decouple_array_interfaces) - write_expr = f"*({ptr_str} + {array_expr}) = {in_local_name};" else: desc_dtype = desc.dtype - expr = cpp.cpp_array_expr(sdfg, memlet, codegen=self._frame) + expr = cpp.cpp_array_expr(sdfg, memlet, codegen=codegen) write_expr = codegen.make_ptr_assignment(in_local_name, conntype, expr, desc_dtype) # Write out @@ -1155,7 +1118,7 @@ def make_ptr_assignment(self, src_expr, src_dtype, dst_expr, dst_dtype, codegen= def memlet_view_ctor(self, sdfg: SDFG, memlet: mmlt.Memlet, dtype, is_output: bool) -> str: memlet_params = [] - memlet_name = cpp.ptr(memlet.data, sdfg.arrays[memlet.data], sdfg, self._frame) + memlet_name = self.ptr(memlet.data, sdfg.arrays[memlet.data], sdfg) def_type, _ = self._dispatcher.defined_vars.get(memlet_name) if def_type == DefinedType.Pointer: @@ -1166,13 +1129,7 @@ def memlet_view_ctor(self, sdfg: SDFG, memlet: mmlt.Memlet, dtype, is_output: bo raise TypeError("Unsupported connector type {}".format(def_type)) if isinstance(memlet.subset, subsets.Indices): - - # FIXME: _packed_types influences how this offset is - # generated from the FPGA codegen. We should find a nicer solution. - if self._packed_types is True: - offset = cpp.cpp_array_expr(sdfg, memlet, False, codegen=self._frame) - else: - offset = cpp.cpp_array_expr(sdfg, memlet, False, codegen=self._frame) + offset = cpp.cpp_array_expr(sdfg, memlet, False, codegen=self) # Compute address memlet_params.append(memlet_expr + " + " + offset) @@ -1183,14 +1140,7 @@ def memlet_view_ctor(self, sdfg: SDFG, memlet: mmlt.Memlet, dtype, is_output: bo if isinstance(memlet.subset, subsets.Range): dims = len(memlet.subset.ranges) - - # FIXME: _packed_types influences how this offset is - # generated from the FPGA codegen. We should find a nicer - # solution. - if self._packed_types is True: - offset = cpp.cpp_offset_expr(sdfg.arrays[memlet.data], memlet.subset) - else: - offset = cpp.cpp_offset_expr(sdfg.arrays[memlet.data], memlet.subset) + offset = cpp.cpp_offset_expr(sdfg.arrays[memlet.data], memlet.subset, codegen=self) if offset == "0": memlet_params.append(memlet_expr) else: @@ -1252,7 +1202,7 @@ def memlet_definition(self, local_name: str, conntype: Union[data.Data, dtypes.typeclass] = None, allow_shadowing: bool = False, - codegen: 'CPUCodeGen' = None): + codegen: Optional['CPUCodeGen'] = None): # TODO: Robust rule set if conntype is None: raise ValueError('Cannot define memlet for "%s" without connector type' % local_name) @@ -1272,7 +1222,7 @@ def memlet_definition(self, # Allocate variable type memlet_type = conntype.dtype.ctype - ptr = cpp.ptr(memlet.data, desc, sdfg, self._frame) + ptr = codegen.ptr(memlet.data, desc, sdfg) types = None # Non-free symbol dependent Arrays due to their shape dependent_shape = (isinstance(desc, data.Array) and not isinstance(desc, data.View) and any( @@ -1288,20 +1238,8 @@ def memlet_definition(self, types = self._dispatcher.defined_vars.get(ptr, is_global=True) var_type, ctypedef = types - if fpga.is_fpga_array(desc): - decouple_array_interfaces = Config.get_bool("compiler", "xilinx", "decouple_array_interfaces") - ptr = fpga.fpga_ptr(memlet.data, - desc, - sdfg, - memlet.subset, - output, - self._dispatcher, - 0, - var_type == DefinedType.ArrayInterface and not isinstance(desc, data.View), - decouple_array_interfaces=decouple_array_interfaces) - result = '' - expr = (cpp.cpp_array_expr(sdfg, memlet, with_brackets=False, codegen=self._frame) + expr = (cpp.cpp_array_expr(sdfg, memlet, with_brackets=False, codegen=self) if var_type in [DefinedType.Pointer, DefinedType.StreamArray, DefinedType.ArrayInterface] else ptr) if expr != ptr: @@ -1345,7 +1283,7 @@ def memlet_definition(self, if not memlet.dynamic and memlet.num_accesses == 1: if not output: if isinstance(desc, data.Stream) and desc.is_stream_array(): - index = cpp.cpp_offset_expr(desc, memlet.subset) + index = cpp.cpp_offset_expr(desc, memlet.subset, codegen=codegen) expr = f"{memlet.data}[{index}]" result += f'{memlet_type} {local_name} = ({expr}).pop();' defined = DefinedType.Scalar @@ -1364,11 +1302,11 @@ def memlet_definition(self, def memlet_stream_ctor(self, sdfg: SDFG, memlet: mmlt.Memlet) -> str: stream = sdfg.arrays[memlet.data] - return memlet.data + ("[{}]".format(cpp.cpp_offset_expr(stream, memlet.subset)) + return memlet.data + ("[{}]".format(cpp.cpp_offset_expr(stream, memlet.subset, codegen=self)) if isinstance(stream, data.Stream) and stream.is_stream_array() else "") def memlet_ctor(self, sdfg: SDFG, memlet: mmlt.Memlet, dtype, is_output: bool) -> str: - ptrname = cpp.ptr(memlet.data, sdfg.arrays[memlet.data], sdfg, self._frame) + ptrname = self.ptr(memlet.data, sdfg.arrays[memlet.data], sdfg) def_type, _ = self._dispatcher.defined_vars.get(ptrname) if def_type in [DefinedType.Stream, DefinedType.Object, DefinedType.StreamArray]: @@ -1593,11 +1531,11 @@ def define_out_memlet(self, sdfg: SDFG, cfg: ControlFlowRegion, state_dfg: State is_refset = isinstance(desc, data.Reference) and state_dfg.memlet_path(edge)[-1].dst_conn == 'set' if not is_refset and not isinstance(desc.dtype, dtypes.pointer): - ptrname = cpp.ptr(edge.data.data, desc, sdfg, self._frame) + ptrname = self.ptr(edge.data.data, desc, sdfg) is_global = desc.lifetime in (dtypes.AllocationLifetime.Global, dtypes.AllocationLifetime.Persistent, dtypes.AllocationLifetime.External) defined_type, _ = self._dispatcher.defined_vars.get(ptrname, is_global=is_global) - base_ptr = cpp.cpp_ptr_expr(sdfg, edge.data, defined_type, codegen=self._frame) + base_ptr = cpp.cpp_ptr_expr(sdfg, edge.data, defined_type, codegen=self) callsite_stream.write(f'{cdtype.ctype} {edge.src_conn} = {base_ptr};', cfg, state_id, src_node) else: callsite_stream.write(f'{cdtype.as_arg(edge.src_conn)};', cfg, state_id, src_node) @@ -1605,7 +1543,6 @@ def define_out_memlet(self, sdfg: SDFG, cfg: ControlFlowRegion, state_dfg: State callsite_stream.write(f'{cdtype.ctype} {edge.src_conn};', cfg, state_id, src_node) def generate_nsdfg_header(self, sdfg, cfg, state, state_id, node, memlet_references, sdfg_label, state_struct=True): - # TODO: Use a single method for GPU kernels, FPGA modules, and NSDFGs arguments = [] if state_struct: @@ -1654,12 +1591,6 @@ def generate_nsdfg_arguments(self, sdfg, cfg, dfg, state, node): # Connectors that are both input and output share the same name inout = set(node.in_connectors.keys() & node.out_connectors.keys()) - for _, _, _, vconn, memlet in state.all_edges(node): - if (memlet.data in sdfg.arrays and fpga.is_multibank_array(sdfg.arrays[memlet.data]) - and fpga.parse_location_bank(sdfg.arrays[memlet.data])[0] == "HBM"): - - raise NotImplementedError("HBM in nested SDFGs not supported in non-FPGA code.") - memlet_references = [] for _, _, _, vconn, in_memlet in sorted(state.in_edges(node), key=lambda e: e.dst_conn or ''): if vconn in inout or in_memlet.data is None: @@ -1669,6 +1600,7 @@ def generate_nsdfg_arguments(self, sdfg, cfg, dfg, state, node): sdfg, in_memlet, vconn, + codegen=self, is_write=vconn in node.out_connectors, conntype=node.in_connectors[vconn])) @@ -1679,6 +1611,7 @@ def generate_nsdfg_arguments(self, sdfg, cfg, dfg, state, node): sdfg, out_memlet, uconn, + codegen=self, conntype=node.out_connectors[uconn])) return memlet_references @@ -1873,7 +1806,7 @@ def _generate_MapEntry( # Define all input connectors of this map entry for e in dynamic_map_inputs(state_dfg, node): - if cpp.ptr(e.data.data, sdfg.arrays[e.data.data], sdfg, self._frame) != e.dst_conn: + if self.ptr(e.data.data, sdfg.arrays[e.data.data], sdfg) != e.dst_conn: callsite_stream.write( self.memlet_definition(sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, node) @@ -2313,3 +2246,29 @@ def generate_tasklet_postamble(self, sdfg, cfg, dfg_scope, state_id, node, funct def make_ptr_vector_cast(self, *args, **kwargs): return cpp.make_ptr_vector_cast(*args, **kwargs) + + def ptr(self, + name: str, + desc: data.Data, + sdfg: SDFG = None, + subset: Optional[subsets.Subset] = None, + is_write: Optional[bool] = None, + ancestor: int = 0) -> str: + """ + Returns a string that points to the data based on its name and descriptor. + + :param name: Data name. + :param desc: Data descriptor. + :param sdfg: SDFG in which the data resides. + :param subset: Optional subset associated with the data. + :param is_write: Whether the access is a write access. + :param ancestor: Scope ancestor level. + :return: C-compatible name that can be used to access the data. + """ + return cpp.ptr(name, desc, sdfg, self._frame) + + def emit_interstate_variable_declaration(self, name: str, dtype: dtypes.typeclass, callsite_stream: CodeIOStream, + sdfg: SDFG): + isvar = data.Scalar(dtype) + callsite_stream.write('%s;\n' % (isvar.as_arg(with_types=True, name=name)), sdfg) + self._frame.dispatcher.defined_vars.add(name, DefinedType.Scalar, dtype.ctype) diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index feb5193091..850fabbd45 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. import ctypes import functools import warnings @@ -20,7 +20,7 @@ from dace.codegen.common import update_persistent_desc from dace.codegen.targets.cpp import (codeblock_to_cpp, cpp_array_expr, memlet_copy_to_absolute_strides, sym2cpp, synchronize_streams, unparse_cr, mangle_dace_state_struct_name) -from dace.codegen.targets.target import IllegalCopy, TargetCodeGenerator, make_absolute +from dace.codegen.target import IllegalCopy, TargetCodeGenerator, make_absolute from dace.config import Config from dace.frontend import operations from dace.sdfg import (SDFG, ScopeSubgraphView, SDFGState, has_dynamic_map_inputs, is_array_stream_view, @@ -183,8 +183,13 @@ def preprocess(self, sdfg: SDFG) -> None: # NOTE: If possible `memlet_copy_to_absolute_strides()` will collapse a # ND copy into a 1D copy if the memory is contiguous. In that case # `copy_shape` will only have one element. - copy_shape, src_strides, dst_strides, _, _ = memlet_copy_to_absolute_strides( - None, nsdfg, state, e, e.src, e.dst) + copy_shape, src_strides, dst_strides, _, _ = memlet_copy_to_absolute_strides(None, + nsdfg, + state, + e, + e.src, + e.dst, + codegen=self) dims = len(copy_shape) # Skip supported copy types @@ -268,8 +273,6 @@ def _compute_pool_release(self, top_sdfg: SDFG): if not pooled: continue self.has_pool = True - if self.backend != 'cuda': - raise ValueError(f'Backend "{self.backend}" does not support the memory pool allocation hint') # Keep only global arrays pooled = filter( @@ -380,13 +383,13 @@ def get_generated_codeobjects(self): pool_header = '' if self.has_pool: - poolcfg = Config.get('compiler', 'cuda', 'mempool_release_threshold') - pool_header = f''' - cudaMemPool_t mempool; - cudaDeviceGetDefaultMemPool(&mempool, 0); - uint64_t threshold = {poolcfg if poolcfg != -1 else 'UINT64_MAX'}; - cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold); -''' + poolcfg = int(Config.get('compiler', 'cuda', 'mempool_release_threshold')) + pool_header = """ + {backend}MemPool_t mempool; + {backend}DeviceGetDefaultMemPool(&mempool, 0); + uint64_t threshold = {poolcfg_threshold}; + {backend}MemPoolSetAttribute(mempool, {backend}MemPoolAttrReleaseThreshold, &threshold); +""".format(backend=self.backend, poolcfg_threshold=('UINT64_MAX' if poolcfg == -1 else poolcfg)) self._codeobject.code = """ #include <{backend_header}> @@ -573,7 +576,7 @@ def declare_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphVi raise NotImplementedError("The declare_array method should only be used for variables " "that must have their declaration and allocation separate.") - ptrname = cpp.ptr(node.data, nodedesc, sdfg, self._frame) + ptrname = self.ptr(node.data, nodedesc, sdfg) # Check if array is already declared if self._dispatcher.declared_arrays.has(ptrname): @@ -599,7 +602,7 @@ def declare_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphVi def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, node: nodes.AccessNode, nodedesc: dt.Data, function_stream: CodeIOStream, declaration_stream: CodeIOStream, allocation_stream: CodeIOStream) -> None: - dataname = cpp.ptr(node.data, nodedesc, sdfg, self._frame) + dataname = self.ptr(node.data, nodedesc, sdfg) try: self._dispatcher.defined_vars.get(dataname) @@ -700,7 +703,7 @@ def allocate_stream(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraph node: nodes.AccessNode, nodedesc: dt.Data, function_stream: CodeIOStream, declaration_stream: CodeIOStream, allocation_stream: CodeIOStream) -> None: dataname = node.data - allocname = cpp.ptr(dataname, nodedesc, sdfg, self._frame) + allocname = self.ptr(dataname, nodedesc, sdfg) if nodedesc.storage == dtypes.StorageType.GPU_Global: fmtargs = { 'name': allocname, # TODO: Handle persistent streams @@ -719,7 +722,7 @@ def allocate_stream(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraph raise NotImplementedError("Cannot handle streams writing to multiple arrays.") fmtargs['ptr'] = nodedesc.sink + ' + ' + cpp_array_expr( - sdfg, edges[0].data, with_brackets=False, codegen=self._frame) + sdfg, edges[0].data, with_brackets=False, codegen=self) # Assuming 1D subset of sink/src # sym2cpp(edges[0].data.subset[-1]) @@ -764,7 +767,7 @@ def allocate_stream(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraph def deallocate_stream(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, node: nodes.AccessNode, nodedesc: dt.Data, function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: - dataname = cpp.ptr(node.data, nodedesc, sdfg, self._frame) + dataname = self.ptr(node.data, nodedesc, sdfg) if nodedesc.storage == dtypes.StorageType.GPU_Global: if is_array_stream_view(sdfg, dfg, node): callsite_stream.write('dace::FreeGPUArrayStreamView(%s);' % dataname, cfg, state_id, node) @@ -774,7 +777,7 @@ def deallocate_stream(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgra def deallocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, node: nodes.AccessNode, nodedesc: dt.Data, function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: - dataname = cpp.ptr(node.data, nodedesc, sdfg, self._frame) + dataname = self.ptr(node.data, nodedesc, sdfg) if isinstance(nodedesc, dt.Array) and nodedesc.start_offset != 0: dataname = f'({dataname} - {cpp.sym2cpp(nodedesc.start_offset)})' @@ -1027,7 +1030,7 @@ def _emit_copy(self, state_id: int, src_node: nodes.Node, src_storage: dtypes.St # Obtain copy information copy_shape, src_strides, dst_strides, src_expr, dst_expr = (memlet_copy_to_absolute_strides( - self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node, self._cpu_codegen._packed_types)) + self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node, codegen=self)) dims = len(copy_shape) dtype = dst_node.desc(sdfg).dtype @@ -1088,6 +1091,11 @@ def _emit_copy(self, state_id: int, src_node: nodes.Node, src_storage: dtypes.St is_c_order = is_fortran_order dims = 1 + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_copy_begin(sdfg, cfg, state_dfg, src_node, dst_node, edge, callsite_stream, None, + copy_shape, src_strides, dst_strides) + if dims > 2: # Currently we only support ND copies when they can be represented # as a 1D copy or as a 2D strided copy @@ -1243,6 +1251,10 @@ def _emit_copy(self, state_id: int, src_node: nodes.Node, src_storage: dtypes.St self._emit_sync(callsite_stream) + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_copy_end(sdfg, cfg, state_dfg, src_node, dst_node, edge, callsite_stream, None) + # Copy within the GPU elif (src_storage in gpu_storage_types and dst_storage in gpu_storage_types): @@ -1270,7 +1282,7 @@ def _emit_copy(self, state_id: int, src_node: nodes.Node, src_storage: dtypes.St if inner_schedule == dtypes.ScheduleType.GPU_Device: # Obtain copy information copy_shape, src_strides, dst_strides, src_expr, dst_expr = (memlet_copy_to_absolute_strides( - self._dispatcher, sdfg, state, edge, src_node, dst_node, self._cpu_codegen._packed_types)) + self._dispatcher, sdfg, state, edge, src_node, dst_node, codegen=self)) dims = len(copy_shape) funcname = 'dace::%sTo%s%dD' % (_get_storagename(src_storage), _get_storagename(dst_storage), dims) @@ -1404,7 +1416,7 @@ def generate_state(self, continue desc = sd.arrays[name] - ptrname = cpp.ptr(name, desc, sd, self._frame) + ptrname = self.ptr(name, desc, sd) if isinstance(desc, dt.Array) and desc.start_offset != 0: ptrname = f'({ptrname} - {cpp.sym2cpp(desc.start_offset)})' @@ -1479,7 +1491,7 @@ def generate_devicelevel_state(self, sdfg: SDFG, cfg: ControlFlowRegion, state: and node.desc(sdfg).lifetime == dtypes.AllocationLifetime.Scope ] for stream in streams_to_reset: - ptrname = cpp.ptr(stream.data, stream.desc(sdfg), sdfg, self._frame) + ptrname = self.ptr(stream.data, stream.desc(sdfg), sdfg) callsite_stream.write("{}.reset();".format(ptrname), cfg, state.block_id) components = dace.sdfg.concurrent_subgraphs(state) @@ -1591,7 +1603,11 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub # Handle dynamic map inputs for e in dace.sdfg.dynamic_map_inputs(state, scope_entry): - kernel_args[str(e.src)] = e.src.desc(sdfg) + if e.data is None: + raise Exception("Dynamic map input's memlet can't be None") + data_name = e.data.data + data_desc = state.sdfg.arrays[data_name] + kernel_args[data_name] = data_desc # Add data from nested SDFGs to kernel arguments extra_call_args = [] @@ -1608,12 +1624,12 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub continue visited.add((nsdfg, node.data)) if desc.transient and self._frame.where_allocated[(nsdfg, node.data)] is not nsdfg: - outer_name = cpp.ptr(node.data, desc, nsdfg, self._frame) + outer_name = self.ptr(node.data, desc, nsdfg) # Create name from within kernel oldval = self._in_device_code self._in_device_code = True - inner_name = cpp.ptr(node.data, desc, nsdfg, self._frame) + inner_name = self.ptr(node.data, desc, nsdfg) self._in_device_code = oldval self.extra_nsdfg_args.append((desc.as_arg(name=''), inner_name, outer_name)) @@ -1670,12 +1686,12 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub defined_type, ctype = (self._dispatcher.declared_arrays.get(aname, is_global=is_global)) except KeyError: pass - ptrname = cpp.ptr(aname, data_desc, sdfg, self._frame) + ptrname = self.ptr(aname, data_desc, sdfg) if not defined_type: defined_type, ctype = self._dispatcher.defined_vars.get(ptrname, is_global=is_global) self._in_device_code = True - inner_ptrname = cpp.ptr(aname, data_desc, sdfg, self._frame) + inner_ptrname = self.ptr(aname, data_desc, sdfg) self._in_device_code = False self._dispatcher.defined_vars.add(inner_ptrname, @@ -1688,13 +1704,13 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub else: if aname in sdfg.arrays: data_desc = sdfg.arrays[aname] - ptrname = cpp.ptr(aname, data_desc, sdfg, self._frame) + ptrname = self.ptr(aname, data_desc, sdfg) is_global = data_desc.lifetime in (dtypes.AllocationLifetime.Global, dtypes.AllocationLifetime.Persistent, dtypes.AllocationLifetime.External) defined_type, ctype = self._dispatcher.defined_vars.get(ptrname, is_global=is_global) self._in_device_code = True - inner_ptrname = cpp.ptr(aname, data_desc, sdfg, self._frame) + inner_ptrname = self.ptr(aname, data_desc, sdfg) self._in_device_code = False self._dispatcher.defined_vars.add(inner_ptrname, defined_type, ctype, allow_shadowing=True) @@ -1812,7 +1828,15 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub # make sure dynamic map inputs are properly handled for e in dace.sdfg.dynamic_map_inputs(state, scope_entry): + if e.data is not None and e.data.data == e.dst_conn: + warnings.warn( + f"Dynamic map input name {e.data.data} is same as the dst connector. Will result in a name clash, omitting of code for this assignment is skipped." + ) + comment_out_str = "// Omitted name clash on dynamic map input\n//" + else: + comment_out_str = "" self._localcode.write( + comment_out_str + self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, scope_entry) @@ -1883,17 +1907,24 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub callsite_stream.write( 'DACE_GPU_CHECK({backend}EventSynchronize(__state->gpu_context->events[{ev}]));'.format( ev=ev, backend=self.backend), cfg, state_id, [e.src, e.dst]) + if e.data is not None and e.data.data == e.dst_conn: + warnings.warn( + f"Dynamic map input name {e.data.data} is same as the dst connector. Will result in a name clash, omitting of code for this assignment is skipped." + ) + comment_out_str = "// Omitted name clash on dynamic map input\n//" + else: + comment_out_str = "" callsite_stream.write( + comment_out_str + self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, node) # Invoke kernel call callsite_stream.write( '__dace_runkernel_%s(%s);\n' % - (kernel_name, - ', '.join(['__state'] + [cpp.ptr(aname, arg, sdfg, self._frame) - for aname, arg in kernel_args.items()] + extra_call_args)), cfg, state_id, - scope_entry) + (kernel_name, ', '.join(['__state'] + [self.ptr(aname, arg, sdfg) + for aname, arg in kernel_args.items()] + extra_call_args)), cfg, + state_id, scope_entry) # If there are dynamic Map inputs, put the kernel invocation in its own scope to avoid redefinitions. if dace.sdfg.has_dynamic_map_inputs(state, scope_entry): @@ -2119,8 +2150,8 @@ def get_kernel_dimensions(self, dfg_scope): # Check block size against configured maximum values, if those can be determined total_bsize = prod(block_size) - total_limit = Config.get('compiler', 'cuda', 'block_size_limit') - lastdim_limit = Config.get('compiler', 'cuda', 'block_size_lastdim_limit') + total_limit = int(Config.get('compiler', 'cuda', 'block_size_limit')) + lastdim_limit = int(Config.get('compiler', 'cuda', 'block_size_lastdim_limit')) if (total_bsize > total_limit) == True: raise ValueError(f'Block size for kernel "{kernelmap_entry.map.label}" ({block_size}) ' f'is larger than the possible number of threads per block ({total_limit}). ' @@ -2175,7 +2206,15 @@ def generate_kernel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: S # handle dynamic map inputs for e in dace.sdfg.dynamic_map_inputs(cfg.node(state_id), dfg_scope.source_nodes()[0]): + if e.data is not None and e.data.data == e.dst_conn: + warnings.warn( + f"Dynamic map input name {e.data.data} is same as the dst connector. Will result in a name clash, omitting of code for this assignment is skipped." + ) + comment_out_str = "// Omitted name clash on dynamic map input\n//" + else: + comment_out_str = "" kernel_stream.write( + comment_out_str + self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, dfg_scope.source_nodes()[0]) @@ -2353,9 +2392,16 @@ def generate_devicelevel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_sco # They define outside the schedule the bounds of the dynamic Map's for-loop invocation. # NOTE: The value of the dynamic Map's variable may differ inside and outside the schedule. for e in dace.sdfg.dynamic_map_inputs(dfg, scope_entry): + if e.data is not None and e.data.data == e.dst_conn: + warnings.warn( + f"Dynamic map input name {e.data.data} is same as the dst connector. Will result in a name clash, omitting of code for this assignment is skipped." + ) + comment_out_str = "// Omitted name clash on dynamic map input\n//" + else: + comment_out_str = "" callsite_stream.write( - self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, - e.dst.in_connectors[e.dst_conn]), cfg, state_id, scope_entry) + comment_out_str + self._cpu_codegen.memlet_definition( + sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, scope_entry) dynmap_var = scope_map.params[0] dynmap_begin = scope_map.range[0][0] @@ -2384,9 +2430,16 @@ def generate_devicelevel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_sco param=dynmap_var), cfg, state_id, scope_entry) for e in dace.sdfg.dynamic_map_inputs(dfg, scope_entry): + if e.data is not None and e.data.data == e.dst_conn: + warnings.warn( + f"Dynamic map input name {e.data.data} is same as the dst connector. Will result in a name clash, omitting of code for this assignment is skipped." + ) + comment_out_str = "// Omitted name clash on dynamic map input\n//" + else: + comment_out_str = "" callsite_stream.write( - self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, - e.dst.in_connectors[e.dst_conn]), cfg, state_id, scope_entry) + comment_out_str + self._cpu_codegen.memlet_definition( + sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, scope_entry) if dynmap_step != 1: callsite_stream.write( @@ -2419,9 +2472,16 @@ def generate_devicelevel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_sco # handle dynamic map inputs for e in dace.sdfg.dynamic_map_inputs(dfg, scope_entry): + if e.data is not None and e.data.data == e.dst_conn: + warnings.warn( + f"Dynamic map input name {e.data.data} is same as the dst connector. Will result in a name clash, omitting of code for this assignment is skipped." + ) + comment_out_str = "// Omitted name clash on dynamic map input\n//" + else: + comment_out_str = "" callsite_stream.write( - self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, - e.dst.in_connectors[e.dst_conn]), cfg, state_id, + comment_out_str + self._cpu_codegen.memlet_definition( + sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, scope_entry) # variables that need to be declared + the value they need to be initialized with @@ -2893,6 +2953,32 @@ def _generate_Tasklet(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgra def make_ptr_vector_cast(self, *args, **kwargs): return cpp.make_ptr_vector_cast(*args, **kwargs) + def ptr(self, + name: str, + desc: dt.Data, + sdfg: SDFG = None, + subset: Optional[subsets.Subset] = None, + is_write: Optional[bool] = None, + ancestor: int = 0) -> str: + """ + Returns a string that points to the data based on its name and descriptor. + + :param name: Data name. + :param desc: Data descriptor. + :param sdfg: SDFG in which the data resides. + :param subset: Optional subset associated with the data. + :param is_write: Whether the access is a write access. + :param ancestor: Scope ancestor level. + :return: C-compatible name that can be used to access the data. + """ + return cpp.ptr(name, desc, sdfg, self._frame) + + def emit_interstate_variable_declaration(self, name: str, dtype: dtypes.typeclass, callsite_stream: CodeIOStream, + sdfg: SDFG): + isvar = dt.Scalar(dtype) + callsite_stream.write('%s;\n' % (isvar.as_arg(with_types=True, name=name)), sdfg) + self._frame.dispatcher.defined_vars.add(name, DefinedType.Scalar, dtype.ctype) + ######################################################################## ######################################################################## diff --git a/dace/codegen/targets/fpga.py b/dace/codegen/targets/fpga.py deleted file mode 100644 index 5485ed078e..0000000000 --- a/dace/codegen/targets/fpga.py +++ /dev/null @@ -1,2449 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -from six import StringIO -import collections -import itertools -import re -import warnings -import numpy as np -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union -import copy - -import dace -from dace.codegen.targets import cpp -from dace import subsets, data as dt, dtypes, memlet, symbolic -from dace.config import Config -from dace.sdfg import SDFG, nodes, utils, dynamic_map_inputs -from dace.sdfg import ScopeSubgraphView -from dace.sdfg.graph import MultiConnectorEdge -from dace.codegen import exceptions as cgx -from dace.codegen.dispatcher import DefinedType -from dace.codegen.prettycode import CodeIOStream -from dace.codegen.common import update_persistent_desc -from dace.codegen.targets.target import TargetCodeGenerator -from dace.codegen import cppunparse -from dace.sdfg.state import ControlFlowRegion, SDFGState, StateSubgraphView -from dace.sdfg.utils import is_fpga_kernel -from dace.symbolic import evaluate -from collections import defaultdict - -if TYPE_CHECKING: - from dace.codegen.targets.framecode import DaCeCodeGenerator - from dace.codegen.targets.cpu import CPUCodeGen - -_CPU_STORAGE_TYPES = {dtypes.StorageType.CPU_Heap, dtypes.StorageType.CPU_ThreadLocal, dtypes.StorageType.CPU_Pinned} -_FPGA_STORAGE_TYPES = { - dtypes.StorageType.FPGA_Global, dtypes.StorageType.FPGA_Local, dtypes.StorageType.FPGA_Registers, - dtypes.StorageType.FPGA_ShiftRegister -} - -_FPGA_LOCAL_STORAGE_TYPES = { - dtypes.StorageType.FPGA_Local, dtypes.StorageType.FPGA_Registers, dtypes.StorageType.FPGA_ShiftRegister -} - - -def vector_element_type_of(dtype): - if isinstance(dtype, dace.pointer): - # "Dereference" the pointer type and try again - return vector_element_type_of(dtype.base_type) - elif isinstance(dtype, dace.vector): - return dtype.base_type - return dtype - - -def is_external_stream(node: dace.sdfg.nodes.Node, subgraph: Union[dace.sdfg.SDFGState, ScopeSubgraphView]): - """ - Given a node and a subgraph, returns whether this is an external stream (the other endpoint is in - another FPGA Kernel) or not. - - :return: True if node represent an external stream, False otherwise - """ - - external = False - - # If this is a stream, check if the other side of it is in the same kernel/subgraph - if isinstance(node, dace.nodes.AccessNode) and isinstance(node.desc(subgraph), dt.Stream): - for nn in subgraph.nodes(): - if nn != node and isinstance(nn, dace.nodes.AccessNode) and node.desc(subgraph) == nn.desc(subgraph): - break - else: - external = True - - return external - - -def is_multibank_array(array: dt.Data): - """ - :return: True if this array is placed on HBM/DDR on FPGA Global memory - """ - if (isinstance(array, dt.Array) and array.storage == dtypes.StorageType.FPGA_Global): - res = parse_location_bank(array) - return res is not None and (res[0] == "HBM" or res[0] == "DDR") - else: - return False - - -def is_multibank_array_with_distributed_index(array: dt.Data): - """ - :param array: access node to be checked - :return: True if this array is placed on HBM/DDR and has an extra first - dimension equal to the number of banks is placed on. For HBM/DDR arrays - spanning across multiple banks this is always true. - """ - if is_multibank_array(array): - res = parse_location_bank(array) - low, high = get_multibank_ranges_from_subset(res[1], None) - return high - low > 1 or (len(array.shape) > 1 and str(array.shape[0]) == "1") - else: - return False - - -def is_fpga_array(array: dt.Data): - """ - :return: True if this array is placed on FPGA memory - """ - return isinstance(array, dt.Array) and array.storage in _FPGA_STORAGE_TYPES - - -def iterate_multibank_interface_ids(array: dt.Array, interface_ids: Union[int, List[Tuple[int, int]]]): - """ - Works on the interface_ids generated by make_parameter. If the array is a hbm/ddr multibank array, - interface_ids is a list of tuples of the form (bank, id), and the method will yield the values - one by one. If it is not, it will return a tuple of 0 (bank) and the interface id once. - """ - if is_multibank_array_with_distributed_index(array): - for bank, id in interface_ids: - yield (bank, id) - else: - yield (0, interface_ids) - - -def iterate_distributed_subset(desc: dt.Array, access_memlet: memlet.Memlet, is_write: bool, sdfg: SDFG): - """ - :param desc: The array accessed by the memlet - :param access_memlet: The memlet - :param is_write: If we care about the write or read direction. is_write means we write to desc, - not is_write means we read from it - :return: if access_memlet contains a distributed subset the method will count from the lower to the upper - end of it. Otherwise returns 0 once. - """ - if is_multibank_array_with_distributed_index(desc): - if is_write: - subset = access_memlet.dst_subset or access_memlet.subset - else: - subset = access_memlet.src_subset or access_memlet.subset - if subset is None: - yield 0 - else: - # We can assume anywhere in the FPGA codegen that distributed subsets - # are evaluatable, because all maps are unrolled before codegen - low, high = get_multibank_ranges_from_subset(subset, sdfg) - for k in range(low, high): - yield k - else: - yield 0 - - -def modify_distributed_subset(subset: subsets.Subset, change: int): - """ - Modifies the first index of :param subset: (the one used for distributed subsets). - - :param subset: is deepcopied before any modification to it is done. - :param change: the first index is set to this value, unless it's (-1) in which case - the first index is completly removed - """ - cps = copy.deepcopy(subset) - if change == -1: - cps.pop([0]) - else: - cps[0] = (change, change, 1) - return cps - - -def get_multibank_ranges_from_subset(subset: Union[subsets.Subset, str], sdfg: SDFG) -> Tuple[int, int]: - """ - Returns the upper and lower end of the accessed multibank-range, evaluated using the - constants on the SDFG. - - :return: (low, high) where low = the lowest accessed bank and high the - highest accessed bank + 1. - """ - if isinstance(subset, str): - subset = subsets.Range.from_string(subset) - low, high, stride = subset[0] - if stride != 1: - raise NotImplementedError(f"Strided multibank subsets not supported.") - try: - low = int(symbolic.resolve_symbol_to_constant(low, sdfg)) - high = int(symbolic.resolve_symbol_to_constant(high, sdfg)) - except: - raise ValueError(f"Only constant evaluatable indices allowed for multibank-memlets on the bank index.") - return (low, high + 1) - - -def parse_location_bank(array: dt.Array) -> Tuple[str, str]: - """ - :param array: an array on FPGA global memory - :return: None if an array is given which does not have a location['memorytype'] value. - Otherwise it will return a tuple (bank_type, bank_assignment), where bank_type - is one of 'DDR', 'HBM' and bank_assignment a string that describes which banks are - used. - """ - if "memorytype" in array.location: - if "bank" not in array.location: - raise ValueError("If 'memorytype' is specified for an array 'bank' must also be specified") - val: str = array.location["bank"] - memorytype: str = array.location["memorytype"] - memorytype = memorytype.upper() - if (memorytype == "DDR" or memorytype == "HBM"): - return (memorytype, array.location["bank"]) - else: - raise ValueError(f"{memorytype} is an invalid memorytype. Supported are HBM and DDR.") - else: - return None - - -def fpga_ptr(name: str, - desc: dt.Data = None, - sdfg: SDFG = None, - subset_info: Union[subsets.Subset, int] = None, - is_write: bool = None, - dispatcher=None, - ancestor: int = 0, - is_array_interface: bool = False, - interface_id: int = None, - decouple_array_interfaces: bool = False): - """ - Returns a string that points to the data based on its name, and various other conditions - that may apply for that data field. - - :param name: Data name. - :param desc: Data descriptor. - :param subset_info: Any additional information about the accessed subset. - :param ancestor: The ancestor level where the variable should be searched for if - is_array_interface is True when dispatcher is not None - :param is_array_interface: Data is pointing to an interface in FPGA-Kernel compilation - :param interface_id: An optional interface id that will be added to the name (only for array interfaces) - :param decouple_array_interfaces: if True it will qualify the name of an array interface, depending whether - it is used for reading from or writing to memory - :return: C-compatible name that can be used to access the data. - """ - if (desc is not None and is_multibank_array_with_distributed_index(desc)): - - location_bank = parse_location_bank(desc) - mem_type = "" - if location_bank is not None: - mem_type = location_bank[0].lower() - - if (subset_info == None): - raise ValueError("Cannot generate name for bank without subset info") - elif (isinstance(subset_info, int)): - name = f"{mem_type}{subset_info}_{name}" - elif (isinstance(subset_info, subsets.Subset)): - if (sdfg == None): - raise ValueError("Cannot generate name for bank using subset if sdfg not provided") - low, high = get_multibank_ranges_from_subset(subset_info, sdfg) - if (low + 1 != high): - raise ValueError("ptr cannot generate names for subsets accessing more than one memory bank") - - name = f"{mem_type}{low}_{name}" - - subset_info = low #used for arrayinterface name where it must be int - if is_array_interface: - - if decouple_array_interfaces: - # qualify the name - if is_write is None: - raise ValueError("is_write must be set for ArrayInterface.") - ptr_in = f"__{name}_in" - ptr_out = f"__{name}_out" - if dispatcher is not None: - # DaCe allows reading from an output connector, even though it - # is not an input connector. If this occurs, panic and read - # from the output interface instead - if is_write or not dispatcher.defined_vars.has(ptr_in, ancestor): - # Throw a KeyError if this pointer also doesn't exist - dispatcher.defined_vars.get(ptr_out, ancestor) - # Otherwise use it - name = ptr_out - else: - name = ptr_in - else: - # We might call this before the variable is even defined (e.g., because - # we are about to define it), so if the dispatcher is not passed, just - # return the appropriate string - name = ptr_out if is_write else ptr_in - # Append the interface id, if provided - if interface_id is not None: - name = f"{name}_{interface_id}" - - return name - - -def unqualify_fpga_array_name(sdfg: dace.SDFG, arr_name: str): - """ - Returns the unqualified array name if it refers to an array interface. - Otherwise return it as it is. - - :param name: array name to unqualify - """ - - if arr_name not in sdfg.arrays and (arr_name.endswith('_in') - or arr_name.endswith('out')) and arr_name.startswith('__'): - unqualified = re.sub('_in$|_out$', '', arr_name) - unqualified = re.sub('^__', '', unqualified) - return unqualified - else: - return arr_name - - -def is_vendor_supported(fpga_vendor: str) -> bool: - """ - Returns wheter the given vendor is supported or not, by looking - among the registered FPGA code-generators. - - :param fpga_vendor: the fpga vendor - """ - - registered_codegens = dace.codegen.targets.target.TargetCodeGenerator._registry_ - supported_vendors = set() - for cl, attr in registered_codegens.items(): - if issubclass(cl, dace.codegen.targets.fpga.FPGACodeGen): - if attr["name"] == fpga_vendor.lower(): - break - else: - supported_vendors.add(attr["name"]) - else: - raise cgx.CompilerConfigurationError( - f"FPGA vendor {fpga_vendor} is not supported. The supported vendors are {supported_vendors}.") - - -class FPGACodeGen(TargetCodeGenerator): - # Set by deriving class - target_name = None - title = None - language = None - - def __init__(self, frame_codegen: 'DaCeCodeGenerator', sdfg: SDFG): - - # The inheriting class must set target_name, title and language. - - self._in_device_code = False - self._cpu_codegen: Optional['CPUCodeGen'] = None - self._frame = frame_codegen - self._dispatcher = frame_codegen.dispatcher - self._kernel_count = 0 - self._global_sdfg = sdfg - self._program_name = sdfg.name - - # Verify that we did not miss the allocation of any global arrays, even - # if they're nested deep in the SDFG - self._allocated_global_arrays = set() - self._unrolled_pes = set() - - # Dictionary node->kernel_id - self._node_to_kernel = defaultdict() - # Keep track of dependencies among kernels (if any) - self._kernels_dependencies = dict() - self._kernels_names_to_id = dict() - - # Register dispatchers - self._cpu_codegen = self._dispatcher.get_generic_node_dispatcher() - - self._host_codes = [] - self._ip_codes = [] - self._kernel_codes = [] - # any other kind of generated file if any (name, code object) - self._other_codes = {} - self._bank_assignments = {} # {(data name, sdfg): (type, id)} - self._stream_connections = {} # { name: [src, dst] } - # For generating kernel instrumentation code, is incremented every time - # a kernel is instrumented - self._kernel_instrumentation_index: int = 0 - - self._decouple_array_interfaces = False - # Register additional FPGA dispatchers - self._dispatcher.register_map_dispatcher( - [dtypes.ScheduleType.FPGA_Device, dtypes.ScheduleType.FPGA_Multi_Pumped], self) - - self._dispatcher.register_state_dispatcher(self, predicate=is_fpga_kernel) - - self._dispatcher.register_node_dispatcher( - self, - predicate=lambda sdfg, state, node: self._in_device_code and not (isinstance( - node, nodes.Tasklet) and node.language == dtypes.Language.SystemVerilog)) - - fpga_storage = [ - dtypes.StorageType.FPGA_Global, - dtypes.StorageType.FPGA_Local, - dtypes.StorageType.FPGA_Registers, - dtypes.StorageType.FPGA_ShiftRegister, - ] - self._dispatcher.register_array_dispatcher(fpga_storage, self) - - # Register permitted copies - for storage_from in itertools.chain(fpga_storage, [dtypes.StorageType.Register]): - for storage_to in itertools.chain(fpga_storage, [dtypes.StorageType.Register]): - if (storage_from == dtypes.StorageType.Register and storage_to == dtypes.StorageType.Register): - # register this as copy dispatcher only if the destination is scheduled on FPGA - self._dispatcher.register_copy_dispatcher(storage_from, storage_to, dtypes.ScheduleType.FPGA_Device, - self) - self._dispatcher.register_copy_dispatcher(storage_from, storage_to, - dtypes.ScheduleType.FPGA_Multi_Pumped, self) - else: - self._dispatcher.register_copy_dispatcher(storage_from, storage_to, None, self) - self._dispatcher.register_copy_dispatcher(dtypes.StorageType.FPGA_Global, dtypes.StorageType.CPU_Heap, None, - self) - self._dispatcher.register_copy_dispatcher(dtypes.StorageType.FPGA_Global, dtypes.StorageType.CPU_ThreadLocal, - None, self) - self._dispatcher.register_copy_dispatcher(dtypes.StorageType.CPU_Heap, dtypes.StorageType.FPGA_Global, None, - self) - self._dispatcher.register_copy_dispatcher(dtypes.StorageType.CPU_ThreadLocal, dtypes.StorageType.FPGA_Global, - None, self) - - # Memory width converters (gearboxing) to generate globally - self.converters_to_generate = set() - - @property - def has_initializer(self): - return True - - @property - def has_finalizer(self): - return False - - def find_rtl_tasklet(self, subgraph: ScopeSubgraphView): - ''' - Finds a tasklet with SystemVerilog as its language, within the given subgraph, if it contains one. - - :param subgraph: The subgraph to check. - :return: The tasklet node if one exists, None otherwise. - ''' - for n in subgraph.nodes(): - if isinstance(n, dace.nodes.NestedSDFG): - if len(n.sdfg.nodes()) == 1 and isinstance(n.sdfg.nodes()[0], SDFGState): - for sg in dace.sdfg.concurrent_subgraphs(n.sdfg.start_state): - node = self.find_rtl_tasklet(sg) - if node: - return node - elif isinstance(n, dace.nodes.Tasklet) and n.language == dace.dtypes.Language.SystemVerilog: - return n - return None - - def is_multi_pumped_subgraph(self, subgraph: ScopeSubgraphView): - ''' - Checks whether the given subgraph is a multi-pumped subgraph. A subgraph is multi-pumped if it contains a map whose schedule is set to multi-pumped. - - :param subgraph: The subgraph to check. - :return: True if the given subgraph is a multi-pumped subgraph. - ''' - for n in subgraph.nodes(): - if isinstance(n, dace.nodes.NestedSDFG): - if len(n.sdfg.nodes()) == 1 and isinstance(n.sdfg.nodes()[0], SDFGState): - for sg in dace.sdfg.concurrent_subgraphs(n.sdfg.nodes()[0]): - if self.is_multi_pumped_subgraph(sg): - return True - elif isinstance(n, dace.nodes.MapEntry) and n.schedule == dace.ScheduleType.FPGA_Multi_Pumped: - return True - return False - - def preprocess(self, sdfg: SDFG) -> None: - # Right before finalizing code, write FPGA context to state structure - self._frame.statestruct.append('dace_fpga_context *fpga_context;') - - # Call vendor-specific preprocessing - self._internal_preprocess(sdfg) - - def _kernels_subgraphs(self, graph: Union[dace.sdfg.SDFGState, ScopeSubgraphView], dependencies: dict): - """ - Finds subgraphs of an SDFGState or ScopeSubgraphView that correspond to kernels. - This is done by looking to which kernel, each node belongs. - - :param graph: the state/subgraph to consider - :param dependencies: a dictionary containing for each kernel ID, the IDs of the kernels on which it - depends on - :return: a list of tuples (subgraph, kernel ID) topologically ordered according kernel dependencies. - """ - from dace.sdfg.scope import ScopeSubgraphView - - if not isinstance(graph, (dace.sdfg.SDFGState, ScopeSubgraphView)): - raise TypeError("Expected SDFGState or ScopeSubgraphView, got: {}".format(type(graph).__name__)) - - subgraphs = collections.defaultdict(list) # {kernel_id: {nodes in subgraph}} - - # Go over the nodes and populate the kernels subgraphs - for node in graph.nodes(): - if isinstance(node, dace.sdfg.SDFGState): - continue - - node_repr = utils.unique_node_repr(graph, node) - if node_repr in self._node_to_kernel: - subgraphs[self._node_to_kernel[node_repr]].append(node) - - # add this node to the corresponding subgraph - if isinstance(node, dace.nodes.AccessNode): - # AccessNodes can be read from multiple kernels, so - # check all out edges - - start_nodes = [e.dst for e in graph.out_edges(node)] - for n in start_nodes: - n_repr = utils.unique_node_repr(graph, n) - if n_repr in self._node_to_kernel: - subgraphs[self._node_to_kernel[n_repr]].append(node) - - # Now stick each of the found components together in a ScopeSubgraphView and return - # them. Sort according kernel dependencies order. - - # Build a dependency graph - import networkx as nx - kernels_graph = nx.DiGraph() - for k in subgraphs.keys(): - # we could have no dependencies at all - kernels_graph.add_node(k) - if k in dependencies: - kernel_dependencies = dependencies[k] - for p in kernel_dependencies: - kernels_graph.add_edge(p, k) - - subgraph_views = [] - all_nodes = graph.nodes() - - # Use topological sort to order kernels according to their dependencies - for kernel_id in nx.topological_sort(kernels_graph): - # Return the subgraph and the kernel id - subgraph_views.append((ScopeSubgraphView(graph, [n for n in all_nodes if n in subgraphs[kernel_id]], - None), kernel_id)) - del kernels_graph - return subgraph_views - - def generate_state(self, sdfg: dace.SDFG, cfg: ControlFlowRegion, state: dace.SDFGState, - function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: - """ - Generate an FPGA State, possibly comprising multiple Kernels and/or PEs. - - :param sdfg: - :param state: - :param function_stream: CPU code stream: contains global declarations (e.g. exported forward declaration of - device specific host functions). - :param callsite_stream: CPU code stream, contains the actual code (for creating global buffers, invoking - device host functions, and so on). - """ - state_id = state.block_id - - if not self._in_device_code: - # Avoid import loop - from dace.transformation.dataflow import MapUnroll - - # Unroll maps directly in the SDFG so the subgraphs can be - # recognized as independent processing elements - top_level_unrolled = [ - n for n in state.scope_children()[None] - if isinstance(n, dace.sdfg.nodes.MapEntry) and n.schedule == dtypes.ScheduleType.Unrolled - ] - for map_entry in top_level_unrolled: - MapUnroll.apply_to(sdfg, map_entry=map_entry) - if top_level_unrolled: - disp = self._dispatcher.get_scope_dispatcher(dtypes.ScheduleType.Unrolled) - self._dispatcher._used_targets.add(disp) - - kernels = [] # List of tuples (subgraph, kernel_id) - - # Start a new state code generation: reset previous dependencies if any - self._kernels_dependencies.clear() - self._kernels_names_to_id.clear() - - # Determine independent components: these are our starting kernels. - # Then, try to split these components further - subgraphs = dace.sdfg.concurrent_subgraphs(state) - - if Config.get_bool("compiler", "fpga", "concurrent_kernel_detection"): - start_kernel = 0 - for sg in subgraphs: - # Determine kernels in state - num_kernels, dependencies = self.partition_kernels(sg, default_kernel=start_kernel) - if num_kernels > 1: - # For each kernel, derive the corresponding subgraphs - # and keep track of dependencies - kernels.extend(self._kernels_subgraphs(sg, dependencies)) - self._kernels_dependencies.update(dependencies) - else: - kernels.append((sg, start_kernel)) - start_kernel = start_kernel + num_kernels - - # There is no need to generate additional kernels if the number of found kernels - # is equal to the number of connected components: use PEs instead (only one kernel) - if len(subgraphs) == len(kernels): - kernels = [(state, 0)] - else: - # Only one FPGA kernel (possibly with multiple PEs) - kernels = [(state, 0)] - - self._num_kernels = len(kernels) - - state_parameters = [] - - # As long as we generate kernels, generate the host file for invoking kernels, - # synchronize them, create transient buffers. - state_host_header_stream = CodeIOStream() - state_host_body_stream = CodeIOStream() - instrumentation_stream = CodeIOStream() - - # Kernels are now sorted considering their dependencies - for kern, kern_id in kernels: - # Generate all kernels in this state - subgraphs = dace.sdfg.concurrent_subgraphs(kern) - single_sgs: list(ScopeSubgraphView) = [] - multi_sgs: list(ScopeSubgraphView) = [] - for sg in subgraphs: - if self.is_multi_pumped_subgraph(sg): - multi_sgs.append(sg) - else: - single_sgs.append(sg) - shared_transients = set(sdfg.shared_transients()) - - # Allocate global memory transients, unless they are shared with - # other states - all_transients = set(kern.all_transients()) - allocated = set(shared_transients) - for node in kern.data_nodes(): - data = node.desc(sdfg) - if node.data not in all_transients or node.data in allocated: - continue - if (data.storage == dtypes.StorageType.FPGA_Global and not isinstance(data, dt.View)): - allocated.add(node.data) - self._dispatcher.dispatch_allocate(sdfg, cfg, kern, state_id, node, data, function_stream, - callsite_stream) - - # Create a unique kernel name to avoid name clashes - # If this kernels comes from a Nested SDFG, use that name also - if sdfg.parent_nsdfg_node is not None: - kernel_name = f"{sdfg.parent_nsdfg_node.label}_{state.label}_{kern_id}_{cfg.cfg_id}" - else: - kernel_name = f"{state.label}_{kern_id}_{cfg.cfg_id}" - - # Vitis HLS removes double underscores, which leads to a compilation - # error down the road due to kernel name mismatch. Remove them here - # to prevent this - kernel_name = re.sub(r"__+", "_", kernel_name) - - self._kernels_names_to_id[kernel_name] = kern_id - - if len(multi_sgs) != 0: - # Currently, there is only added one additional multi pumped per state. In the future, when we can - # emit multi-pumped kernels that to not consist of directly connected subgraphs, more than 1 should - # be added. - self._num_kernels += 1 - - # Generate kernel code - self.generate_kernel(sdfg, cfg, state, kernel_name, single_sgs, function_stream, callsite_stream, - state_host_header_stream, state_host_body_stream, instrumentation_stream, - state_parameters, kern_id) - - if len(multi_sgs) != 0: - func_stream = CodeIOStream() - call_stream = CodeIOStream() - ignore = CodeIOStream() - # TODO should be able to generate multiple 'pumps'. e.g. pump b and d in - # a > b > c > d > e - # Currently, it only works if the subgraphs are directly chained - self.generate_kernel(sdfg, cfg, state, f'{kernel_name}_pumped', multi_sgs, func_stream, call_stream, - state_host_header_stream, state_host_body_stream, ignore, state_parameters, 42) - - kernel_args_call_host = [] - kernel_args_opencl = [] - - # Include state in args - kernel_args_opencl.append(f"{cpp.mangle_dace_state_struct_name(self._global_sdfg)} *__state") - kernel_args_call_host.append(f"__state") - - for is_output, arg_name, arg, interface_id in state_parameters: - # Streams and Views are not passed as arguments - if (isinstance(arg, dt.Array)): - for bank, _ in iterate_multibank_interface_ids(arg, interface_id): - current_name = fpga_ptr(arg_name, - arg, - sdfg, - bank, - decouple_array_interfaces=self._decouple_array_interfaces) - kernel_args_call_host.append(arg.as_arg(False, name=current_name)) - kernel_args_opencl.append(FPGACodeGen.make_opencl_parameter(current_name, arg)) - elif (not isinstance(arg, dt.Stream) and not isinstance(arg, dt.View)): - kernel_args_call_host.append(arg.as_arg(False, name=arg_name)) - kernel_args_opencl.append(FPGACodeGen.make_opencl_parameter(arg_name, arg)) - - kernel_args_call_host = dtypes.deduplicate(kernel_args_call_host) - kernel_args_opencl = dtypes.deduplicate(kernel_args_opencl) - - ## Generate the global function here - - kernel_host_stream = CodeIOStream() - host_function_name = f"__dace_runstate_{cfg.cfg_id}_{state.name}_{state_id}" - function_stream.write("\n\nDACE_EXPORTED void {}({});\n\n".format(host_function_name, - ", ".join(kernel_args_opencl))) - - # add generated header information - kernel_host_stream.write(state_host_header_stream.getvalue()) - - kernel_host_stream.write(f"""\ -DACE_EXPORTED void {host_function_name}({', '.join(kernel_args_opencl)}) {{""") - - if state.instrument == dtypes.InstrumentationType.FPGA: - kernel_host_stream.write("""\ -const unsigned long int _dace_fpga_begin_us = std::chrono::duration_cast(std::chrono::high_resolution_clock::now().time_since_epoch()).count(); -""") - - kernel_host_stream.write(f"""\ - hlslib::ocl::Program program = __state->fpga_context->Get().CurrentlyLoadedProgram();\ -""") - # Create a vector to collect all events that are being generated to allow - # waiting before exiting this state - kernel_host_stream.write("std::vector all_events;") - - # Kernels invocations - kernel_host_stream.write(state_host_body_stream.getvalue()) - - # Wait for all events - kernel_host_stream.write("hlslib::ocl::WaitForEvents(all_events);") - - # Instrumentation - if state.instrument == dtypes.InstrumentationType.FPGA: - kernel_host_stream.write(""" -// Begin FPGA kernel runtime instrumentation -cl_ulong first_start = std::numeric_limits::max(); -cl_ulong last_end = std::numeric_limits::min();""") - if Config.get_bool("instrumentation", "print_fpga_runtime"): - kernel_host_stream.write(""" -std::cout << std::scientific;""") - kernel_host_stream.write(instrumentation_stream.getvalue()) - kernel_host_stream.write(f"""\ -const unsigned long int _dace_fpga_end_us = std::chrono::duration_cast(std::chrono::high_resolution_clock::now().time_since_epoch()).count(); -// Convert from nanoseconds (reported by OpenCL) to microseconds (expected by the profiler) -__state->report.add_completion("Full FPGA kernel runtime for {state.label}", "FPGA", 1e-3 * first_start, 1e-3 * last_end, {sdfg.cfg_id}, {state_id}, -1); -__state->report.add_completion("Full FPGA state runtime for {state.label}", "FPGA", _dace_fpga_begin_us, _dace_fpga_end_us, {sdfg.cfg_id}, {state_id}, -1); -""") - if Config.get_bool("instrumentation", "print_fpga_runtime"): - kernel_host_stream.write(f""" -const double elapsed = 1e-6 * (_dace_fpga_end_us - _dace_fpga_begin_us); -std::cout << "FPGA program \\"{state.label}\\" executed in " << elapsed << " seconds.\\n";\ -""") - - kernel_host_stream.write("}\n") - - callsite_stream.write("{}({});".format(host_function_name, ", ".join(kernel_args_call_host))) - - # Store code strings to be passed to compilation phase - self._host_codes.append((kernel_name, kernel_host_stream.getvalue())) - - else: # self._in_device_code == True - - to_allocate = dace.sdfg.local_transients(sdfg, state, None) - allocated = set() - subgraphs = dace.sdfg.concurrent_subgraphs(state) - - for node in state.data_nodes(): - data = node.desc(sdfg) - if node.data not in to_allocate or node.data in allocated: - continue - # Make sure there are no global transients in the nested state - # that are thus not gonna be allocated - if data.storage == dtypes.StorageType.FPGA_Global and not isinstance(data, dt.View): - raise cgx.CodegenError("Cannot allocate global memory from device code.") - allocated.add(node.data) - # Allocate transients - self._dispatcher.dispatch_allocate(sdfg, cfg, state, state_id, node, data, function_stream, - callsite_stream) - - self.generate_nested_state(sdfg, cfg, state, state.label, subgraphs, function_stream, callsite_stream) - - @staticmethod - def shared_data(subgraphs): - """ - Returns a set of data objects that are shared between two or more of - the specified subgraphs. - """ - shared = set() - if len(subgraphs) >= 2: - seen = {} - for sg in subgraphs: - for node in sg: - if isinstance(node, dace.sdfg.nodes.AccessNode): - if node.data in seen: - if seen[node.data] != sg: - shared.add(node.data) - else: - seen[node.data] = sg - return shared - - def make_parameters(self, sdfg: SDFG, state: SDFGState, subgraphs): - """ - Determines the parameters that must be passed to the passed list of - subgraphs, as well as to the global kernel. - - :return: A tuple with the following six entries: - - Data container parameters that should be passed from the - host to the FPGA kernel. - - Data containers that are local to the kernel, but must be - allocated by the host prior to invoking the kernel. - - A dictionary mapping from each processing element subgraph - to which parameters it needs (from the total list of - parameters). - - Parameters that must be passed to the kernel from the host, - but that do not exist before the CPU calls the kernel - wrapper. - - A dictionary of which memory interfaces should be assigned to - which memory banks. - - External streams that connect different FPGA kernels, and - must be defined during the compilation flow. - """ - - # Get a set of data nodes that are shared across subgraphs - shared_data = self.shared_data(subgraphs) - # Transients that are accessed in other states in this SDFG - used_outside = sdfg.shared_transients() - transients = [t for t in sdfg.transients() if t not in used_outside] - datanodes = set() - for sg in subgraphs: - for n in sg.data_nodes(): - datanodes.add(n.data) - used_inside = [dn for dn in datanodes if dn in transients] - - # Build a dictionary of arrays to arbitrary data nodes referring to - # them, needed to trace memory bank assignments and to pass to the array - # allocator - data_to_node: Dict[str, dace.nodes.Node] = {} - - global_data_parameters = set() - # Count appearances of each global array to create multiple interfaces - if self._decouple_array_interfaces: - global_interfaces: Dict[str, int] = collections.defaultdict(int) - else: - # For Xilinx, even if we are not decoupling array interfaces we need anyway to use different interfaces - # if we access the same container from different PEs - global_interfaces: Dict[str, (int, int)] = collections.defaultdict(lambda: (0, 0)) - - top_level_local_data = set() - subgraph_parameters = collections.OrderedDict() # {subgraph: [params]} - nested_global_transients = set() - # [(Is an output, dataname string, data object, interface)] - # TODO rephrase is_output. Currently it is "Is an output from the main kernel", but in the future there can be - # more kernels, so make it "is output from current subgraph", but then it needs to map to each subgraph. - external_streams: Set[tuple[bool, str, dt.Data, dict[str, int]]] = set() - - # Mapping from global arrays to memory interfaces - bank_assignments: Dict[str, Tuple[str, Union[int, subsets.Range]]] = {} - - # Mapping from symbol to a unique parameter tuple - all_symbols = {k: (False, k, dt.Scalar(v), None) for k, v in sdfg.symbols.items() if k not in sdfg.constants} - - # Add symbols from inter-state edges - global_symbols = copy.deepcopy(sdfg.symbols) - interstate_symbols = {} - for e in sdfg.dfs_edges(sdfg.start_state): - symbols = e.data.new_symbols(sdfg, global_symbols) - # Inferred symbols only take precedence if global symbol not defined or None - symbols = { - k: v if (k not in global_symbols or global_symbols[k] is None) else global_symbols[k] - for k, v in symbols.items() - } - interstate_symbols.update(symbols) - global_symbols.update(symbols) - all_symbols.update({ - k: (False, k, dt.Scalar(v), None) - for k, v in interstate_symbols.items() if k not in all_symbols and k not in sdfg.constants - }) - - # Symbols that will be passed as parameters to the top-level kernel - global_symbols = set() - - # Sorting by name, then by input/output, then by interface id - sort_func = lambda t: f"{t[1]}{t[0]}{t[3]}" - - subgraph_counter = 0 - for subgraph in subgraphs: - data_to_node.update( - {node.data: node - for node in subgraph.nodes() if isinstance(node, dace.sdfg.nodes.AccessNode)}) - is_rtl_subgraph = self.find_rtl_tasklet(subgraph) - is_multi_subgraph = self.is_multi_pumped_subgraph(subgraph) - subsdfg = subgraph.parent - candidates = [] # type: List[Tuple[bool,str,Data]] - # [(is an output, dataname string, data object)] - array_to_banks_used_out: Dict[str, Set[int]] = {} - array_to_banks_used_in: Dict[str, Set[int]] = {} - sources = subgraph.source_nodes() - for n in sources: - # Check if the node is connected to an RTL tasklet, in which - # case it should be an external stream - is_external = is_rtl_subgraph - is_multi = is_multi_subgraph - is_output = True - if not is_external and not is_multi and self._num_kernels > 1: - if is_external_stream(n, subgraph): - is_external = True - is_output = False - - if is_multi: - if n.data in shared_data: - is_external = False - elif is_external_stream(n, subgraph): - is_external = True - is_output = False - - if is_external and isinstance(subsdfg.arrays[n.data], dt.Stream): - external_streams.add((is_output, n.data, subsdfg.arrays[n.data], None)) - else: - candidates += [(False, e.data.data, subsdfg.arrays[e.data.data]) for e in state.in_edges(n)] - - sinks = subgraph.sink_nodes() - for n in sinks: - # Check if the node is connected to an RTL tasklet, in which - # case it should be an external stream - is_external = is_rtl_subgraph - is_multi = is_multi_subgraph - is_output = False - if not is_external and not is_multi and self._num_kernels > 1: - if is_external_stream(n, subgraph): - is_external = True - is_output = True - - if is_multi: - if n.data in shared_data: - is_external = False - elif is_external_stream(n, subgraph): - is_external = True - is_output = True - - if is_external and isinstance(subsdfg.arrays[n.data], dt.Stream): - external_streams.add((is_output, n.data, subsdfg.arrays[n.data], None)) - else: - candidates += [(True, e.data.data, subsdfg.arrays[e.data.data]) for e in state.out_edges(n)] - # Find other data nodes that are used internally - for n, scope in subgraph.all_nodes_recursive(): - if isinstance(n, dace.sdfg.nodes.AccessNode): - # Add nodes if they are outer-level, or an inner-level transient - # (inner-level inputs/outputs are just connected to data in the outer layers, - # whereas transients can be independent). - # Views are not nested global transients - if scope == subgraph or n.desc(scope).transient: - desc = n.desc(scope) - if scope.out_degree(n) > 0: - candidates.append((False, n.data, desc)) - if scope.in_degree(n) > 0: - candidates.append((True, n.data, desc)) - if is_multibank_array_with_distributed_index(desc): - # Record all banks used by this subgraph to generate interfaces for them - # inputs and outputs seperate, because using a bank as an input doesn't mean - # we also need an output interface - current_banks_out = set() - current_banks_in = set() - for edge in scope.in_edges(n): - for bank in iterate_distributed_subset(desc, edge.data, True, sdfg): - current_banks_out.add(bank) - for edge in scope.out_edges(n): - for bank in iterate_distributed_subset(desc, edge.data, False, sdfg): - current_banks_in.add(bank) - if n.data in array_to_banks_used_in: - array_to_banks_used_in[n.data].update(current_banks_in) - else: - array_to_banks_used_in[n.data] = current_banks_in - if n.data in array_to_banks_used_out: - array_to_banks_used_out[n.data].update(current_banks_out) - else: - array_to_banks_used_out[n.data] = current_banks_out - if scope != subgraph: - if (isinstance(n.desc(scope), dt.Array) - and n.desc(scope).storage == dtypes.StorageType.FPGA_Global - and not isinstance(n.desc(scope), dt.View)): - nested_global_transients.add(n) - subgraph_parameters[subgraph] = set() - # For each subgraph, keep a listing of array to current interface ID - data_to_interface: Dict[str, int] = {} - # multibank data name -> is_output -> List of (bank, interface id) - # same as data_to_interface, but for HBM/DDR-arrays with multiple banks - multibank_data_to_interface: Dict[str, Dict[bool, List[Tuple[int, int]]]] = {} - - # Differentiate global and local arrays. The former are allocated - # from the host and passed to the device code, while the latter are - # (statically) allocated on the device side. - for is_output, data_name, desc in candidates: - # Ignore views, as these never need to be explicitly passed - if isinstance(desc, dt.View): - continue - if not isinstance(desc, dt.Array): - is_output = None - # If this is a global array, assign the correct interface ID and - # memory interface (e.g., DDR or HBM bank) - if (isinstance(desc, dt.Array) and desc.storage == dtypes.StorageType.FPGA_Global): - if data_name in data_to_interface: - interface_id = data_to_interface[data_name] - elif data_name in multibank_data_to_interface and is_output in multibank_data_to_interface[ - data_name]: - interface_id = multibank_data_to_interface[data_name][is_output] - else: - # Get and update global memory interface ID - if is_multibank_array_with_distributed_index(desc): - tmp_interface_ids = [] - if is_output: - banks_looked_at = array_to_banks_used_out[data_name] - else: - banks_looked_at = array_to_banks_used_in[data_name] - for bank in banks_looked_at: - ptr_str = fpga_ptr(data_name, - desc, - sdfg, - bank, - decouple_array_interfaces=self._decouple_array_interfaces) - - if self._decouple_array_interfaces: - tmp_interface_id = global_interfaces[ptr_str] - global_interfaces[ptr_str] += 1 - else: - if ptr_str not in global_interfaces: - global_interfaces[ptr_str] = (0, subgraph_counter) - - tmp_interface_id, last_used_in = global_interfaces[ptr_str] - if last_used_in != subgraph_counter: - # we accessed the same container from a different subgraph/PE: we need - # to use a different interface - tmp_interface_id += 1 - global_interfaces[ptr_str] = (tmp_interface_id, subgraph_counter) - - tmp_interface_ids.append((bank, tmp_interface_id)) - interface_id = tuple(tmp_interface_ids) - if data_name not in multibank_data_to_interface: - multibank_data_to_interface[data_name] = {} - multibank_data_to_interface[data_name][is_output] = interface_id - else: - if self._decouple_array_interfaces: - interface_id = global_interfaces[data_name] - global_interfaces[data_name] += 1 - else: - if data_name not in global_interfaces: - global_interfaces[data_name] = (0, subgraph_counter) - - interface_id, last_used_in = global_interfaces[data_name] - if last_used_in != subgraph_counter: - # we accessed the same container from a different data subgraph/PE: we need - # to use a different interface - global_interfaces[data_name] = (interface_id + 1, subgraph_counter) - interface_id += 1 - data_to_interface[data_name] = interface_id - # Collect the memory bank specification, if present, by - # traversing outwards to where the data container is - # actually allocated - inner_node = data_to_node[data_name] - trace = utils.trace_nested_access(inner_node, subgraph, sdfg) - bank = None - bank_type = None - for (trace_in, trace_out), _, _, trace_sdfg in trace: - trace_node = trace_in or trace_out - trace_name = trace_node.data - trace_desc = trace_node.desc(trace_sdfg) - if "bank" in trace_desc.location: - trace_type, trace_bank = parse_location_bank(trace_desc) - if (bank is not None and bank_type is not None - and (bank != trace_bank or bank_type != trace_type)): - raise cgx.CodegenError("Found inconsistent memory bank " - f"specifier for {trace_name}.") - bank = trace_bank - bank_type = trace_type - - # Make sure the array has been allocated on this bank in the - # outermost scope - if bank_type is not None: - outer_node = trace[0][0][0] or trace[0][0][1] - outer_desc = outer_node.desc(trace[0][2]) - okbank = False - if ("bank" in outer_desc.location): - trace_type, trace_bank = parse_location_bank(outer_desc) - okbank = (trace_type == bank_type and trace_bank == bank) - if not okbank: - raise cgx.CodegenError("Memory bank allocation must be present on " - f"outermost data descriptor {outer_node.data} " - "to be allocated correctly.") - bank_assignments[data_name] = (bank_type, bank) - else: - bank_assignments[data_name] = None - else: - interface_id = None - if (not desc.transient or desc.storage == dtypes.StorageType.FPGA_Global or data_name in used_outside): - # Add the data as a parameter to this PE - subgraph_parameters[subgraph].add((is_output, data_name, desc, interface_id)) - # Global data is passed from outside the kernel - global_data_parameters.add((is_output, data_name, desc, interface_id)) - elif data_name in shared_data: - # Add the data as a parameter to this PE - subgraph_parameters[subgraph].add((is_output, data_name, desc, interface_id)) - # Must be allocated outside PEs and passed to them - top_level_local_data.add(data_name) - - # Order by name - subgraph_parameters[subgraph] = list(sorted(subgraph_parameters[subgraph], key=sort_func)) - # Append symbols used in this subgraph - for k in sorted(self._frame.free_symbols(subgraph)): - if k not in sdfg.constants: - param = all_symbols[k] - subgraph_parameters[subgraph].append(param) - global_symbols.add(param) - - subgraph_counter += 1 - # Order by name - global_data_parameters = list(sorted(global_data_parameters, key=sort_func)) - global_data_parameters += sorted(global_symbols, key=sort_func) - external_streams = list(sorted(external_streams, key=sort_func)) - nested_global_transients = list(sorted(nested_global_transients)) - - stream_names = {sname for _, sname, _, _ in external_streams} - top_level_local_data = [data_to_node[name] for name in sorted(top_level_local_data) if name not in stream_names] - - return (global_data_parameters, top_level_local_data, subgraph_parameters, nested_global_transients, - bank_assignments, external_streams) - - def generate_nested_state(self, sdfg: SDFG, cfg: ControlFlowRegion, state: dace.SDFGState, nest_name: str, - subgraphs: List[ScopeSubgraphView], function_stream: CodeIOStream, - callsite_stream: CodeIOStream) -> None: - - for sg in subgraphs: - self._dispatcher.dispatch_subgraph(sdfg, - cfg, - sg, - cfg.node_id(state), - function_stream, - callsite_stream, - skip_entry_node=False) - - def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSubgraphView, state_id: int, - function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: - - if not self._in_device_code: - # If we're not already generating kernel code, fail - raise cgx.CodegenError('FPGA kernel needs to be generated inside a device state.') - - self.generate_node(sdfg, cfg, dfg_scope, state_id, - dfg_scope.source_nodes()[0], function_stream, callsite_stream) - - self._dispatcher.dispatch_subgraph(sdfg, - cfg, - dfg_scope, - state_id, - function_stream, - callsite_stream, - skip_entry_node=True) - - def declare_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, - node: nodes.AccessNode, nodedesc: dt.Data, function_stream: CodeIOStream, - declaration_stream: CodeIOStream) -> None: - - fsymbols = self._frame.symbols_and_constants(sdfg) - if not utils.is_nonfree_sym_dependent(node, nodedesc, dfg, fsymbols): - raise NotImplementedError("The declare_array method should only be used for variables " - "that must have their declaration and allocation separate.") - - result_decl = StringIO() - arrsize = nodedesc.total_size - dataname = node.data - ptrname = cpp.ptr(dataname, nodedesc, sdfg, self._frame) - - # Check if array is already declared - if self._dispatcher.declared_arrays.has(ptrname): - return - - if nodedesc.storage == dtypes.StorageType.FPGA_Global: - - if self._in_device_code: - - if nodedesc not in self._allocated_global_arrays: - raise RuntimeError("Cannot allocate global array " - "from device code: {} in {}".format(node.label, sdfg.name)) - - else: - # TODO: Distinguish between read, write, and read+write - # Define buffer, using proper type - result_decl.write("hlslib::ocl::Buffer <{}, hlslib::ocl::Access::readWrite> {};".format( - nodedesc.dtype.ctype, ptrname)) - self._dispatcher.declared_arrays.add( - ptrname, DefinedType.Pointer, - 'hlslib::ocl::Buffer <{}, hlslib::ocl::Access::readWrite>'.format(nodedesc.dtype.ctype)) - elif (nodedesc.storage in (dtypes.StorageType.FPGA_Local, dtypes.StorageType.FPGA_Registers, - dtypes.StorageType.FPGA_ShiftRegister)): - - raise ValueError("Dynamic allocation of FPGA " - "fast memory not allowed: {}, size {}".format(dataname, arrsize)) - - else: - raise NotImplementedError("Unimplemented storage type " + str(nodedesc.storage)) - - declaration_stream.write(result_decl.getvalue(), cfg, state_id, node) - - def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, - node: nodes.AccessNode, nodedesc: dt.Data, function_stream: CodeIOStream, - declaration_stream: CodeIOStream, allocation_stream: CodeIOStream) -> None: - - # NOTE: The code below fixes symbol-related issues with transient data originally defined in a NestedSDFG scope - # but promoted to be persistent. These data must have their free symbols replaced with the corresponding - # top-level SDFG symbols. - if nodedesc.lifetime in (dtypes.AllocationLifetime.Persistent, dtypes.AllocationLifetime.External): - nodedesc = update_persistent_desc(nodedesc, sdfg) - - result_decl = StringIO() - result_alloc = StringIO() - arrsize = nodedesc.total_size - is_dynamically_sized = dace.symbolic.issymbolic(arrsize, sdfg.constants) - dataname = cpp.ptr(node.data, nodedesc, sdfg, self._frame) - - if not isinstance(nodedesc, dt.Stream): - # Unless this is a Stream, if the variable has been already defined we can return - # For Streams, we still allocate them to keep track of their names across - # nested SDFGs (needed by Intel FPGA backend for channel mangling) - if self._dispatcher.defined_vars.has(dataname): - return - - # Check if array is already declared - declared = self._dispatcher.declared_arrays.has(dataname) - - if isinstance(nodedesc, dt.View): - return self.allocate_view(sdfg, cfg, dfg, state_id, node, function_stream, declaration_stream, - allocation_stream) - elif isinstance(nodedesc, dt.Reference): - return self.allocate_reference(sdfg, cfg, dfg, state_id, node, function_stream, declaration_stream, - allocation_stream) - elif isinstance(nodedesc, dt.Stream): - - if not self._in_device_code: - raise cgx.CodegenError("Cannot allocate FIFO from CPU code: {}".format(node.data)) - - if is_dynamically_sized: - raise cgx.CodegenError("Arrays of streams cannot have dynamic size on FPGA") - - try: - buffer_size = dace.symbolic.evaluate(nodedesc.buffer_size, sdfg.constants) - except TypeError: - raise cgx.CodegenError("Buffer length of stream cannot have dynamic size on FPGA") - - # Language-specific implementation - ctype, is_global = self.define_stream(nodedesc.dtype, buffer_size, dataname, arrsize, function_stream, - result_decl, sdfg) - - # defined type: decide whether this is a stream array or a single stream - def_type = (DefinedType.StreamArray if arrsize != 1 else DefinedType.Stream) - if is_global: - self._dispatcher.defined_vars.add_global(dataname, def_type, ctype) - else: - self._dispatcher.defined_vars.add(dataname, def_type, ctype) - - elif isinstance(nodedesc, dt.Array): - - if nodedesc.storage == dtypes.StorageType.FPGA_Global: - - if self._in_device_code: - - if nodedesc not in self._allocated_global_arrays: - raise RuntimeError("Cannot allocate global array " - "from device code: {} in {}".format(node.label, sdfg.name)) - - else: - # TODO: Distinguish between read, write, and read+write - self._allocated_global_arrays.add(node.data) - memory_bank_arg_count = 1 - bank_offset = -1 - storage_type_str = "hlslib::ocl::StorageType::DDR" # DDR to use unspecified memory - - # Fix bankassignments if present - bank_info = parse_location_bank(nodedesc) - if bank_info is not None: - bank_type, bank = bank_info - - bank_low, bank_high = get_multibank_ranges_from_subset(bank, sdfg) - memory_bank_arg_count = bank_high - bank_low - if bank_high - bank_low > 1: - arrsize = dace.symbolic.pystr_to_symbolic( - f"int_ceil(({str(arrsize)}) , ({str(bank_high - bank_low)}))") - else: - arrsize = dace.symbolic.pystr_to_symbolic(f"({str(arrsize)})") - - bank_offset = bank_low - - if bank_type == "HBM": - storage_type_str = "hlslib::ocl::StorageType::HBM" - else: - storage_type_str = "hlslib::ocl::StorageType::DDR" - - # Define buffer, using proper type - for bank_index in range(memory_bank_arg_count): - alloc_name = fpga_ptr(dataname, - nodedesc, - sdfg, - bank_index, - decouple_array_interfaces=self._decouple_array_interfaces) - if not declared: - result_decl.write("hlslib::ocl::Buffer <{}, hlslib::ocl::Access::readWrite> {};\n".format( - nodedesc.dtype.ctype, alloc_name)) - result_alloc.write(f"{alloc_name} = __state->fpga_context->Get()." - f"MakeBuffer<{nodedesc.dtype.ctype}, hlslib::ocl::Access::readWrite>" - f"({storage_type_str}, {bank_offset + bank_index}, " - f"{cpp.sym2cpp(arrsize)});\n") - - self._dispatcher.defined_vars.add( - alloc_name, DefinedType.Pointer, - 'hlslib::ocl::Buffer <{}, hlslib::ocl::Access::readWrite>'.format(nodedesc.dtype.ctype)) - - elif (nodedesc.storage in (dtypes.StorageType.FPGA_Local, dtypes.StorageType.FPGA_Registers, - dtypes.StorageType.FPGA_ShiftRegister)): - - if not self._in_device_code: - raise cgx.CodegenError("Tried to allocate local FPGA memory " - "outside device code: {}".format(dataname)) - if is_dynamically_sized: - raise ValueError("Dynamic allocation of FPGA " - "fast memory not allowed: {}, size {}".format(dataname, arrsize)) - - generate_scalar = cpp.sym2cpp(arrsize) == "1" - - if generate_scalar: - # Language-specific - ctype = self.make_vector_type(nodedesc.dtype, False) - define_str = "{} {};".format(ctype, dataname) - result_decl.write(define_str) - self._dispatcher.defined_vars.add(dataname, DefinedType.Scalar, ctype) - else: - # Language-specific - if (nodedesc.storage == dtypes.StorageType.FPGA_ShiftRegister): - self.define_shift_register(dataname, nodedesc, arrsize, function_stream, result_decl, sdfg, - state_id, node) - else: - self.define_local_array(dataname, nodedesc, arrsize, function_stream, result_decl, sdfg, - state_id, node) - - else: - raise NotImplementedError("Unimplemented storage type " + str(nodedesc.storage)) - - elif isinstance(nodedesc, dt.Scalar): - - ctype = self.make_vector_type(nodedesc.dtype, False) - result_decl.write("{} {};\n".format(ctype, dataname)) - self._dispatcher.defined_vars.add(dataname, DefinedType.Scalar, ctype) - - else: - raise TypeError("Unhandled data type: {}".format(type(nodedesc).__name__)) - - declaration_stream.write(result_decl.getvalue(), cfg, state_id, node) - allocation_stream.write(result_alloc.getvalue(), cfg, state_id, node) - - def deallocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, - node: nodes.AccessNode, nodedesc: dt.Data, function_stream: CodeIOStream, - callsite_stream: CodeIOStream) -> None: - pass # Handled by destructor - - def partition_kernels(self, state: dace.SDFGState, default_kernel: int = 0): - """ - Associate node to different kernels. - This field is applied to all FPGA maps, tasklets, and library nodes - that can be executed in parallel in separate kernels. - - :param state: the state to analyze. - :param default_kernel: The Kernel ID to start counting from. - :return: a tuple containing the number of kernels and the dependencies among them - """ - - concurrent_kernels = 0 # Max number of kernels - sdfg = state.sdfg - - def increment(kernel_id): - if concurrent_kernels > 0: - return (kernel_id + 1) % concurrent_kernels - return kernel_id + 1 - - # Dictionary containing dependencies among kernels: - # dependencies[K] = [list of kernel IDs on which K depends] - dependencies = dict() - - source_nodes = state.source_nodes() - sink_nodes = state.sink_nodes() - - max_kernels = default_kernel - # First step: assign a different Kernel ID - # to each source node which is not an AccessNode - for i, node in enumerate(source_nodes): - if isinstance(node, nodes.AccessNode): - continue - - self._node_to_kernel[utils.unique_node_repr(state, node)] = max_kernels - max_kernels = increment(max_kernels) - - # Consecutive nodes that are not crossroads can be in the same Kernel - # A node is said to be a crossroad, if it belongs in more that two - # disjoint paths that connect graph sinks and sources - - scopes = state.scope_dict() - - for e in state.dfs_edges(source_nodes): - if utils.unique_node_repr(state, e.dst) in self._node_to_kernel: - # Node has been already visited) - continue - - e_src_repr = utils.unique_node_repr(state, e.src) - if e_src_repr in self._node_to_kernel: - kernel = self._node_to_kernel[e_src_repr] - - if (isinstance(e.dst, nodes.AccessNode) and isinstance(sdfg.arrays[e.dst.data], dt.View)): - # Skip views - self._node_to_kernel[utils.unique_node_repr(state, e.dst)] = kernel - continue - - # Does this node (e.dst) need to be in another kernel? - # If it is a crossroad node (has more than one predecessor, its predecessors contain some compute, - # no local buffers, the destination is not a sink) then it should be on a separate kernel. - - if len(list(state.predecessors(e.dst))) > 1 and not isinstance( - e.dst, nodes.ExitNode) and e.dst not in sink_nodes and scopes[e.dst] == None: - # Loop over all predecessors (except this edge) - crossroad_node = False - for pred_edge in state.in_edges(e.dst): - if pred_edge != e and self._trace_back_edge(pred_edge, state): - crossroad_node = True - break - - if crossroad_node: - kernel = max_kernels - max_kernels = increment(max_kernels) - - else: - - # From this edge we don't have any kernel id. - # Look up for the other predecessor nodes, if any of them has a kernel - # ID, use that. - for pred_edge in state.in_edges(e.dst): - if pred_edge != e: - kernel = self._trace_back_edge(pred_edge, state, look_for_kernel_id=True) - if kernel is not None: - break - else: - # Look at the successor nodes: because of the DFS visit, it may occur - # that one of them has already an associated kernel ID. If this is the case, and - # if the edge that connects this node with it is a tasklet-to-tasklet - # edge, then we use that kernel ID. In all the other cases, we use a new one. - - # TODO: support more robust detection - # It could be the case that we need to look also at the predecessors: - # if they are associated with a different kernel, and there is a tasklet-to-tasklet, - # maybe we don't want to generate a different kernel. - - for succ_edge in state.out_edges(e.dst): - succ_edge_dst_repr = utils.unique_node_repr(state, succ_edge.dst) - if succ_edge_dst_repr in self._node_to_kernel and isinstance( - succ_edge.src, nodes.Tasklet) and isinstance(succ_edge.dst, nodes.Tasklet): - kernel = self._node_to_kernel[succ_edge_dst_repr] - break - else: - # Trace this edge forward: if it finds something that has a kernel id and - # there is at least one local buffer along the way, then reuse that kernel id - only_global, kern = self._trace_forward_edge(e, state) - if not only_global and kern is not None: - kernel = kern - else: - kernel = max_kernels - if (isinstance(e.dst, nodes.AccessNode) and (isinstance(sdfg.arrays[e.dst.data], dt.View))): - # Skip views and local buffers - pass - else: - max_kernels = increment(max_kernels) - - self._node_to_kernel[utils.unique_node_repr(state, e.dst)] = kernel - - # do another pass and track dependencies among Kernels - for node in state.nodes(): - node_repr = utils.unique_node_repr(state, node) - if node_repr in self._node_to_kernel: - this_kernel = self._node_to_kernel[node_repr] - # get all predecessors and see their associated kernel ID - for pred in state.predecessors(node): - pred_repr = utils.unique_node_repr(state, pred) - if pred_repr in self._node_to_kernel and self._node_to_kernel[pred_repr] != this_kernel: - if this_kernel not in dependencies: - dependencies[this_kernel] = set() - dependencies[this_kernel].add(self._node_to_kernel[pred_repr]) - - max_kernels = max_kernels if concurrent_kernels == 0 else concurrent_kernels - return max_kernels, dependencies - - def _trace_back_edge(self, - edge: MultiConnectorEdge[dace.Memlet], - state: dace.SDFGState, - look_for_kernel_id: bool = False) -> Union[bool, int]: - """ - Given an edge, this traverses the edges backwards. - It can be used either for: - - - understanding if along the backward path there is some compute node but no local buffers, or - - looking for the kernel_id of a predecessor (look_for_kernel_id must be set to True) - - :return: if look_for_kernel_id is false it returns a boolean indicating if there is a - compute node on the backward path and no access nodes to local buffers. Otherwise, it returns - the kernel_id of a predecessor node. - """ - - curedge = edge - source_nodes = state.source_nodes() - contains_compute = False - contains_only_global_buffers = True - while not curedge.src in source_nodes: - - if not look_for_kernel_id: - if isinstance(curedge.src, (nodes.EntryNode, nodes.ExitNode, nodes.CodeNode)): - # We can stop here: this is a scope which will contain some compute, or a tasklet/libnode - contains_compute = True - elif isinstance(curedge.src, nodes.AccessNode): - if curedge.src.desc(state).storage in _FPGA_LOCAL_STORAGE_TYPES: - contains_only_global_buffers = False - - else: - src_repr = utils.unique_node_repr(state, curedge.src) - if src_repr in self._node_to_kernel: - # Found a node with a kernel id. Use that - return self._node_to_kernel[src_repr] - next_edge = next(e for e in state.in_edges(curedge.src)) - curedge = next_edge - - # We didn't return before - if not look_for_kernel_id: - return contains_compute and contains_only_global_buffers - else: - src_repr = utils.unique_node_repr(state, curedge.src) - return self._node_to_kernel[src_repr] if src_repr in self._node_to_kernel else None - - def _trace_forward_edge(self, edge: MultiConnectorEdge[dace.Memlet], state: dace.SDFGState) -> Tuple[bool, int]: - """ - Given an edge, this traverses the edges forward. - It can be used either for: - - - understanding if along the forward path there is a local buffer, and - - returning the the kernel_id of a successor if any - - :return: a tuple containing two booleans indicating if the path contains only global buffers - and the kernel_id of a successor if any - """ - - curedge = edge - sink_nodes = state.sink_nodes() - contains_compute = False - contains_only_global_buffers = True - while not curedge.dst in sink_nodes: - - if isinstance(curedge.dst, nodes.AccessNode): - if curedge.dst.desc(state).storage in _FPGA_LOCAL_STORAGE_TYPES: - contains_only_global_buffers = False - dst_repr = utils.unique_node_repr(state, curedge.dst) - if dst_repr in self._node_to_kernel: - # Found a node with a kernel id. Use that - return contains_only_global_buffers, self._node_to_kernel[dst_repr] - next_edge = next(e for e in state.out_edges(curedge.dst)) - curedge = next_edge - - # We didn't return before - dst_repr = utils.unique_node_repr(state, curedge.dst) - kernel_id = self._node_to_kernel[dst_repr] if dst_repr in self._node_to_kernel else None - return contains_only_global_buffers, kernel_id - - def _emit_copy(self, sdfg: SDFG, cfg: ControlFlowRegion, state_id: int, src_node: nodes.Node, - src_storage: dtypes.StorageType, dst_node: nodes.Node, dst_storage: dtypes.StorageType, - dst_schedule: dtypes.ScheduleType, edge: MultiConnectorEdge[memlet.Memlet], dfg: StateSubgraphView, - function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: - - u, v, memlet = edge.src, edge.dst, edge.data - - # Determine directionality - if isinstance(src_node, dace.sdfg.nodes.AccessNode) and memlet.data == src_node.data: - outgoing_memlet = True - elif isinstance(dst_node, dace.sdfg.nodes.AccessNode) and memlet.data == dst_node.data: - outgoing_memlet = False - else: - raise LookupError("Memlet does not point to any of the nodes") - - data_to_data = (isinstance(src_node, dace.sdfg.nodes.AccessNode) - and isinstance(dst_node, dace.sdfg.nodes.AccessNode)) - - host_to_device = (data_to_data and src_storage in _CPU_STORAGE_TYPES - and dst_storage == dtypes.StorageType.FPGA_Global) - device_to_host = (data_to_data and src_storage == dtypes.StorageType.FPGA_Global - and dst_storage in _CPU_STORAGE_TYPES) - device_to_device = (data_to_data and src_storage == dtypes.StorageType.FPGA_Global - and dst_storage == dtypes.StorageType.FPGA_Global) - - if (host_to_device or device_to_host) and self._in_device_code: - raise RuntimeError("Cannot copy between host and device from device") - - if (host_to_device or device_to_host or (device_to_device and not self._in_device_code)): - - dims = memlet.subset.dims() - src_nodedesc = src_node.desc(sdfg) - dst_nodedesc = dst_node.desc(sdfg) - src_is_subset = memlet._is_data_src is None or memlet._is_data_src - - copy_shape = memlet.subset.bounding_box_size() - is_src_using_multibank = (src_is_subset and is_multibank_array_with_distributed_index(src_nodedesc)) - is_dst_using_multibank = (not src_is_subset and is_multibank_array_with_distributed_index(dst_nodedesc)) - if is_src_using_multibank or is_dst_using_multibank: - copy_shape = copy_shape[1:] - - offset_src, offset_dst = "0", "0" - if memlet.src_subset is not None: - offset_src = cpp.cpp_array_expr(sdfg, - memlet, - with_brackets=False, - referenced_array=src_nodedesc, - use_other_subset=(not src_is_subset - and memlet.other_subset is not None), - codegen=self._frame) - if memlet.dst_subset is not None: - offset_dst = cpp.cpp_array_expr(sdfg, - memlet, - with_brackets=False, - referenced_array=dst_nodedesc, - use_other_subset=(src_is_subset and memlet.other_subset is not None), - codegen=self._frame) - - if (not sum(copy_shape) == 1 and - (not isinstance(memlet.subset, subsets.Range) or any([step != 1 for _, _, step in memlet.subset]))): - raise NotImplementedError("Only contiguous copies currently " - "supported for FPGA codegen.") - - if host_to_device or device_to_device: - host_dtype = sdfg.data(src_node.data).dtype - device_dtype = sdfg.data(dst_node.data).dtype - elif device_to_host: - device_dtype = sdfg.data(src_node.data).dtype - host_dtype = sdfg.data(dst_node.data).dtype - cast = False - if not device_to_device and host_dtype != device_dtype: - host_dtype_base = host_dtype - while True: - updated = host_dtype_base.base_type - if updated != host_dtype_base: - host_dtype_base = updated - continue - break - device_dtype_base = device_dtype - while True: - updated = device_dtype_base.base_type - if updated != device_dtype_base: - device_dtype_base = updated - continue - break - if ((isinstance(host_dtype, dace.vector) or isinstance(device_dtype, dace.vector)) - and host_dtype_base == device_dtype_base): - if ((host_to_device and memlet.data == src_node.data) - or (device_to_host and memlet.data == dst_node.data)): - if host_dtype.bytes > device_dtype.bytes: - copy_shape[-1] *= (host_dtype.bytes // device_dtype.bytes) - else: - copy_shape[-1] //= (device_dtype.bytes // host_dtype.bytes) - cast = True - else: - raise TypeError("Memory copy type mismatch: {} vs {}".format(host_dtype, device_dtype)) - - copysize = " * ".join([cppunparse.pyexpr2cpp(dace.symbolic.symstr(s, cpp_mode=True)) for s in copy_shape]) - - src_subset = memlet.src_subset or memlet.subset - dst_subset = memlet.dst_subset or memlet.subset - if host_to_device: - - ptr_str = (fpga_ptr(src_node.data, - src_nodedesc, - sdfg, - src_subset, - decouple_array_interfaces=self._decouple_array_interfaces) + - (" + {}".format(offset_src) if outgoing_memlet and str(offset_src) != "0" else "")) - if cast: - ptr_str = "reinterpret_cast<{} const *>({})".format(device_dtype.ctype, ptr_str) - - callsite_stream.write( - "{}.CopyFromHost({}, {}, {});".format( - fpga_ptr(dst_node.data, - dst_nodedesc, - sdfg, - dst_subset, - decouple_array_interfaces=self._decouple_array_interfaces), - (offset_dst if not outgoing_memlet else 0), copysize, ptr_str), cfg, state_id, - [src_node, dst_node]) - - elif device_to_host: - - ptr_str = (fpga_ptr(dst_node.data, - dst_nodedesc, - sdfg, - dst_subset, - decouple_array_interfaces=self._decouple_array_interfaces) + - (" + {}".format(offset_dst) if outgoing_memlet and str(offset_dst) != "0" else "")) - if cast: - ptr_str = "reinterpret_cast<{} *>({})".format(device_dtype.ctype, ptr_str) - - callsite_stream.write( - "{}.CopyToHost({}, {}, {});".format( - fpga_ptr(src_node.data, - src_nodedesc, - sdfg, - src_subset, - decouple_array_interfaces=self._decouple_array_interfaces), - (offset_src if outgoing_memlet else 0), copysize, ptr_str), cfg, state_id, [src_node, dst_node]) - - elif device_to_device: - - callsite_stream.write( - "{}.CopyToDevice({}, {}, {}, {});".format( - fpga_ptr(src_node.data, - src_nodedesc, - sdfg, - src_subset, - decouple_array_interfaces=self._decouple_array_interfaces), - (offset_src if outgoing_memlet else 0), copysize, - fpga_ptr(dst_node.data, - dst_nodedesc, - sdfg, - dst_subset, - decouple_array_interfaces=self._decouple_array_interfaces), - (offset_dst if not outgoing_memlet else 0)), cfg, state_id, [src_node, dst_node]) - - # Reject copying to/from local memory from/to outside the FPGA - elif (data_to_data and - (((src_storage in (dtypes.StorageType.FPGA_Local, dtypes.StorageType.FPGA_Registers, - dtypes.StorageType.FPGA_ShiftRegister)) and dst_storage not in _FPGA_STORAGE_TYPES) or - ((dst_storage in (dtypes.StorageType.FPGA_Local, dtypes.StorageType.FPGA_Registers, - dtypes.StorageType.FPGA_ShiftRegister)) and src_storage not in _FPGA_STORAGE_TYPES))): - raise NotImplementedError("Copies between host memory and FPGA " - "local memory not supported: from {} to {}".format(src_node, dst_node)) - - elif data_to_data: - - if memlet.wcr is not None: - raise NotImplementedError("WCR not implemented for copy edges") - - if src_storage == dtypes.StorageType.FPGA_ShiftRegister: - raise NotImplementedError("Reads from shift registers only supported from tasklets.") - - # Try to turn into degenerate/strided ND copies - state_dfg = cfg.node(state_id) - copy_shape, src_strides, dst_strides, src_expr, dst_expr = (cpp.memlet_copy_to_absolute_strides( - self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node, packed_types=True)) - - dtype = src_node.desc(sdfg).dtype - ctype = dtype.ctype - - if dst_storage == dtypes.StorageType.FPGA_ShiftRegister: - if len(copy_shape) != 1: - raise ValueError("Only single-dimensional writes " - "to shift registers supported: {}{}".format(dst_node.data, copy_shape)) - - # Check if we are copying between vectorized and non-vectorized - # types - memwidth_src = src_node.desc(sdfg).veclen - memwidth_dst = dst_node.desc(sdfg).veclen - if memwidth_src < memwidth_dst: - is_pack = True - is_unpack = False - packing_factor = memwidth_dst // memwidth_src - if memwidth_dst % memwidth_src != 0: - raise ValueError("Destination vectorization width {} " - "is not divisible by source vectorization width {}.".format( - memwidth_dst, memwidth_src)) - self.converters_to_generate.add((False, vector_element_type_of(dtype).ctype, packing_factor)) - elif memwidth_src > memwidth_dst: - is_pack = False - is_unpack = True - packing_factor = memwidth_src // memwidth_dst - if memwidth_src % memwidth_dst != 0: - raise ValueError("Source vectorization width {} is not divisible " - "by destination vectorization width {}.".format(memwidth_dst, memwidth_src)) - self.converters_to_generate.add((True, vector_element_type_of(dtype).ctype, packing_factor)) - else: - is_pack = False - is_unpack = False - packing_factor = 1 - - # TODO: detect in which cases we shouldn't unroll - register_to_register = (src_node.desc(sdfg).storage == dtypes.StorageType.FPGA_Registers - or dst_node.desc(sdfg).storage == dtypes.StorageType.FPGA_Registers) - - num_loops = sum((d != 1 for d in copy_shape), 0) - has_pipelined_loops = (num_loops > 0 and not register_to_register) - - # Determine data that should have dependency pragmas injected - dependency_pragma_nodes = [] - for node in ((src_node, ) if src_node.data == dst_node.data else (src_node, dst_node)): - desc = node.desc(sdfg) - if (isinstance(desc, dt.Array) and desc.storage == dtypes.StorageType.FPGA_Local - and desc.total_size != 1): - dependency_pragma_nodes.append(node) - - if has_pipelined_loops: - # Language-specific - self.generate_pipeline_loop_pre(callsite_stream, sdfg, cfg, state_id, dst_node) - if len(copy_shape) > 1: - # Language-specific - self.generate_flatten_loop_pre(callsite_stream, sdfg, cfg, state_id, dst_node) - for node in dependency_pragma_nodes: - # Inject dependence pragmas - self.generate_no_dependence_pre(callsite_stream, sdfg, cfg, state_id, dst_node, node.data) - - # Loop intro - for i, copy_dim in enumerate(copy_shape): - if copy_dim != 1: - if register_to_register: - # Language-specific - self.generate_unroll_loop_pre(callsite_stream, None, sdfg, cfg, state_id, dst_node) - - callsite_stream.write( - "for (int __dace_copy{} = 0; __dace_copy{} < {}; " - "++__dace_copy{}) {{".format(i, i, cpp.sym2cpp(copy_dim), i), cfg, state_id, dst_node) - - if register_to_register: - # Language-specific - self.generate_unroll_loop_post(callsite_stream, None, sdfg, cfg, state_id, dst_node) - - # Pragmas - if has_pipelined_loops: - # Language-specific - self.generate_pipeline_loop_post(callsite_stream, sdfg, cfg, state_id, dst_node) - self.generate_flatten_loop_post(callsite_stream, sdfg, cfg, state_id, dst_node) - # Inject dependence pragmas - for node in dependency_pragma_nodes: - self.generate_no_dependence_post(callsite_stream, sdfg, cfg, state_id, dst_node, node.data) - - src_name = cpp.ptr(src_node.data, src_node.desc(sdfg), sdfg, self._frame) - dst_name = cpp.ptr(dst_node.data, dst_node.desc(sdfg), sdfg, self._frame) - src_def_type, _ = self._dispatcher.defined_vars.get(src_name) - dst_def_type, _ = self._dispatcher.defined_vars.get(dst_name) - - # Construct indices (if the length of the stride array is zero, - # resolves to an empty string) - src_index = " + ".join([ - "__dace_copy{}{}".format(i, " * " + cpp.sym2cpp(stride) if stride != 1 else "") - for i, stride in enumerate(src_strides) if copy_shape[i] != 1 - ]) - dst_index = " + ".join([ - "__dace_copy{}{}".format(i, " * " + cpp.sym2cpp(stride) if stride != 1 else "") - for i, stride in enumerate(dst_strides) if copy_shape[i] != 1 - ]) - if not src_index: - src_index = "0" - if not dst_index: - dst_index = "0" - - # Language specific - read_expr = self.make_read(src_def_type, dtype, src_node.label, src_expr, src_index, is_pack, - packing_factor) - - # Language specific - if dst_storage == dtypes.StorageType.FPGA_ShiftRegister: - write_expr = self.make_shift_register_write(dst_def_type, dtype, dst_node.label, dst_expr, dst_index, - read_expr, None, is_unpack, packing_factor, sdfg) - else: - write_expr = self.make_write(dst_def_type, dtype, dst_node.label, dst_expr, dst_index, read_expr, None, - is_unpack, packing_factor) - - callsite_stream.write(write_expr) - - # Loop outtro - for _ in range(num_loops): - callsite_stream.write("}") - - else: - - self.generate_memlet_definition(sdfg, cfg, dfg, state_id, src_node, dst_node, edge, callsite_stream) - - @staticmethod - def make_opencl_parameter(name, desc): - if isinstance(desc, dt.Array): - return (f"hlslib::ocl::Buffer<{desc.dtype.ctype}, " - f"hlslib::ocl::Access::readWrite> &{name}") - else: - return (desc.as_arg(with_types=True, name=name)) - - def get_next_scope_entries(self, sdfg, dfg, scope_entry): - parent_scope_entry = dfg.entry_node(scope_entry) - parent_scope = dfg.scope_subgraph(parent_scope_entry) - - # Get all scopes from the same level - all_scopes = [node for node in parent_scope.bfs_nodes() if isinstance(node, dace.sdfg.nodes.EntryNode)] - - return all_scopes[all_scopes.index(scope_entry) + 1:] - - def generate_node(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, node: nodes.Node, - function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: - method_name = "_generate_" + type(node).__name__ - # Fake inheritance... use this class' method if it exists, - # otherwise fall back on CPU codegen - if hasattr(self, method_name): - - if hasattr(node, "schedule") and node.schedule not in [ - dtypes.ScheduleType.Default, dtypes.ScheduleType.FPGA_Device, dtypes.ScheduleType.FPGA_Multi_Pumped - ]: - warnings.warn("Found schedule {} on {} node in FPGA code. " - "Ignoring.".format(node.schedule, - type(node).__name__)) - - getattr(self, method_name)(sdfg, cfg, dfg, state_id, node, function_stream, callsite_stream) - else: - old_codegen = self._cpu_codegen.calling_codegen - self._cpu_codegen.calling_codegen = self - - self._cpu_codegen.generate_node(sdfg, cfg, dfg, state_id, node, function_stream, callsite_stream) - - self._cpu_codegen.calling_codegen = old_codegen - - def copy_memory(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, - src_node: Union[nodes.CodeNode, nodes.AccessNode], - dst_node: Union[nodes.CodeNode, nodes.AccessNode], edge: MultiConnectorEdge[memlet.Memlet], - function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: - if isinstance(src_node, dace.sdfg.nodes.CodeNode): - src_storage = dtypes.StorageType.Register - try: - src_parent = dfg.entry_node(src_node) - except KeyError: - src_parent = None - dst_schedule = (None if src_parent is None else src_parent.map.schedule) - else: - src_storage = src_node.desc(sdfg).storage - - if isinstance(dst_node, dace.sdfg.nodes.CodeNode): - dst_storage = dtypes.StorageType.Register - else: - dst_storage = dst_node.desc(sdfg).storage - - try: - dst_parent = dfg.entry_node(dst_node) - except KeyError: - dst_parent = None - dst_schedule = None if dst_parent is None else dst_parent.map.schedule - state_dfg = cfg.state(state_id) - - # Check if this is a copy memlet using at least one multibank array - edge_list = [] - if (isinstance(src_node, dace.sdfg.nodes.AccessNode) and isinstance(dst_node, dace.sdfg.nodes.AccessNode)): - src_array = src_node.desc(sdfg) - dst_array = dst_node.desc(sdfg) - src_is_multibank = is_multibank_array_with_distributed_index(src_array) - dst_is_multibank = is_multibank_array_with_distributed_index(dst_array) - if src_is_multibank or dst_is_multibank: - modedge = copy.deepcopy(edge) - mem: memlet.Memlet = modedge.data - if mem.src_subset is None: - mem.src_subset = subsets.Range.from_array(src_array) - if mem.dst_subset is None: - mem.dst_subset = subsets.Range.from_array(dst_array) - if src_is_multibank: - bankbeg, bankend = get_multibank_ranges_from_subset(mem.src_subset, sdfg) - if dst_is_multibank: - bankbeg, bankend = get_multibank_ranges_from_subset(mem.dst_subset, sdfg) - num_accessed_banks = bankend - bankbeg - oldmem = copy.deepcopy(mem) - for i in range(num_accessed_banks): - src_index = oldmem.src_subset[0][0] + i - dst_index = oldmem.dst_subset[0][0] + i - # Support for ignoring the distributed index if it's not required, e.g. on the host - if src_is_multibank or num_accessed_banks > 1: - mem.src_subset = modify_distributed_subset(mem.src_subset, src_index) - if dst_is_multibank or num_accessed_banks > 1: - mem.dst_subset = modify_distributed_subset(mem.dst_subset, dst_index) - edge_list.append(copy.deepcopy(modedge)) - else: - edge_list.append(edge) - else: - edge_list.append(edge) - - # Emit actual copy - for current_edge in edge_list: - self._emit_copy(sdfg, cfg, state_id, src_node, src_storage, dst_node, dst_storage, dst_schedule, - current_edge, state_dfg, function_stream, callsite_stream) - - def _generate_PipelineEntry(self, *args, **kwargs): - self._generate_MapEntry(*args, **kwargs) - - def _is_innermost(self, scope, scope_dict, sdfg): - to_search = list(scope) - while len(to_search) > 0: - x = to_search.pop() - if (isinstance(x, (dace.sdfg.nodes.MapEntry, dace.sdfg.nodes.PipelineEntry))): - # Degenerate loops should not be pipelined - fully_degenerate = True - for begin, end, skip in x.map.range: - if not self._is_degenerate(begin, end, skip, sdfg)[0]: - fully_degenerate = False - break - # Non-unrolled, non-degenerate loops must be pipelined, so we - # are not innermost - if not x.unroll and not fully_degenerate: - return False - to_search += scope_dict[x] - elif isinstance(x, dace.sdfg.nodes.NestedSDFG): - for state in x.sdfg.states(): - if not self._is_innermost(state.nodes(), state.scope_children(), x.sdfg): - return False - return True - - @staticmethod - def _is_degenerate(begin, end, skip, sdfg): - try: - begin_val = evaluate(begin, sdfg.constants) - skip_val = evaluate(skip, sdfg.constants) - end_val = evaluate(end, sdfg.constants) - is_degenerate = begin_val + skip_val > end_val - return is_degenerate, begin_val - except TypeError: # Cannot statically evaluate expression - return False, begin - - def _generate_MapEntry(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, - node: nodes.MapEntry, function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: - result = callsite_stream - - scope_dict = dfg.scope_dict() - if node.map in self._unrolled_pes: - - # This is a top-level unrolled map, meaning it has been used to - # replicate processing elements. Don't generate anything here. - pass - - else: - # Add extra opening brace (dynamic map ranges, closed in MapExit - # generator) - callsite_stream.write('{', cfg, state_id, node) - - # Define dynamic loop bounds variables (dynamic input memlets to - # the MapEntry node) - for e in dynamic_map_inputs(cfg.state(state_id), node): - if e.data.data != e.dst_conn: - callsite_stream.write( - self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, - e.dst.in_connectors[e.dst_conn]), cfg, state_id, node) - - # Pipeline innermost loops - scope_children = dfg.scope_children() - scope = scope_children[node] - is_innermost = self._is_innermost(scope, scope_children, sdfg) - - # Generate custom iterators if this is a pipelined (and thus - # flattened) loop - if isinstance(node, dace.sdfg.nodes.PipelineEntry): - for i in range(len(node.map.range)): - result.write("long {} = {};\n".format(node.map.params[i], node.map.range[i][0])) - for var, value in node.pipeline.additional_iterators.items(): - result.write("long {} = {};\n".format(var, value)) - - is_degenerate = [] - degenerate_values = [] - for begin, end, skip in node.map.range: - # If we know at compile-time that a loop will only have a - # single iteration, we can replace it with a simple assignment - b, val = self._is_degenerate(begin, end, skip, sdfg) - is_degenerate.append(b) - degenerate_values.append(val) - fully_degenerate = all(is_degenerate) - - # Being this a map (each iteration is independent), we can add pragmas to ignore dependencies on data - # that is read/written inside this map, if there are no WCR. If there are no WCR at all, we can add - # a more generic pragma to ignore all loop-carried dependencies. - map_exit_node = dfg.exit_node(node) - state = cfg.state(state_id) - candidates_in = set() - candidates_out = set() - is_there_a_wcr = False - # get data that is read/written - for _, _, _, _, memlet in state.in_edges(node): - if memlet.data is not None: - desc = sdfg.arrays[memlet.data] - if (isinstance(desc, dt.Array) and (desc.storage == dtypes.StorageType.FPGA_Global - or desc.storage == dtypes.StorageType.FPGA_Local) - and memlet.wcr is None): - candidates_in.add(memlet.data) - elif memlet.wcr is not None: - is_there_a_wcr = True - - for _, _, _, _, memlet in state.out_edges(map_exit_node): - if memlet.data is not None: - desc = sdfg.arrays[memlet.data] - if (isinstance(desc, dt.Array) and (desc.storage == dtypes.StorageType.FPGA_Global - or desc.storage == dtypes.StorageType.FPGA_Local) - and memlet.wcr is None and desc.total_size != 1): - candidates_out.add(memlet.data) - elif memlet.wcr is not None: - is_there_a_wcr = True - in_out_data = candidates_in.intersection(candidates_out) - - # Generate nested loops - if not isinstance(node, dace.sdfg.nodes.PipelineEntry): - - for i, r in enumerate(node.map.range): - - # Add pragmas - if not fully_degenerate and not is_degenerate[i]: - if node.map.unroll: - self.generate_unroll_loop_pre(result, None, sdfg, cfg, state_id, node) - elif is_innermost: - self.generate_pipeline_loop_pre(result, sdfg, cfg, state_id, node) - # Do not put pragma if this is degenerate (loop does not exist) - self.generate_flatten_loop_pre(result, sdfg, cfg, state_id, node) - if not node.map.unroll: - if len(in_out_data) > 0 and is_there_a_wcr == False: - # add pragma to ignore all loop carried dependencies - self.generate_no_dependence_pre(result, sdfg, cfg, state_id, node) - else: - # add specific pragmas - for candidate in in_out_data: - self.generate_no_dependence_pre(result, sdfg, cfg, state_id, node, candidate) - - var = node.map.params[i] - begin, end, skip = r - # decide type of loop variable - loop_var_type = "int" - # try to decide type of loop variable - try: - if (evaluate(begin, sdfg.constants) >= 0 and evaluate(skip, sdfg.constants) > 0): - # it could be an unsigned (uint32) variable: we need - # to check to the type of 'end', - # if we are able to determine it - symbols = list(dace.symbolic.symlist(end).values()) - if len(symbols) > 0: - sym = symbols[0] - if str(sym) in sdfg.symbols: - end_type = sdfg.symbols[str(sym)].dtype - else: - # Symbol not found, try to use symbol object - # or use the default symbol type (int32) - end_type = sym.dtype - else: - end_type = None - if end_type is not None: - if np.dtype(end_type.dtype.type) > np.dtype('uint32'): - v = dace.config.Config.get("compiler", "fpga", "vendor") - if v.casefold() == 'intel_fpga'.casefold(): - loop_var_type = end_type.ocltype - else: - loop_var_type = end_type.ctype - elif np.issubdtype(np.dtype(end_type.dtype.type), np.unsignedinteger): - loop_var_type = "size_t" - except (UnboundLocalError): - raise UnboundLocalError('Pipeline scopes require ' - 'specialized bound values') - except (TypeError): - # Raised when the evaluation of begin or skip fails. - # This could occur, for example, if they are defined in terms of other symbols, which - # is the case in a tiled map - pass - - # To enforce opencl type long instead of c type long long for intel fpga - v = dace.config.Config.get("compiler", "fpga", "vendor") - if v.casefold() == 'intel_fpga'.casefold(): - loop_var_type = loop_var_type.replace("long long", "long") - - if is_degenerate[i]: - result.write("{{\nconst {} {} = {}; // Degenerate loop".format( - loop_var_type, var, degenerate_values[i])) - else: - result.write( - "for ({} {} = {}; {} < {}; {} += {}) {{\n".format(loop_var_type, var, cpp.sym2cpp(begin), - var, cpp.sym2cpp(end + 1), var, - cpp.sym2cpp(skip)), cfg, state_id, node) - - #Add unroll pragma - if not fully_degenerate and not is_degenerate[i] and node.map.unroll: - self.generate_unroll_loop_post(result, None, sdfg, cfg, state_id, node) - - else: - pipeline = node.pipeline - flat_it = pipeline.iterator_str() - bound = pipeline.loop_bound_str() - - if len(in_out_data) > 0: - if is_there_a_wcr == False: - # add pragma to ignore all loop carried dependencies - self.generate_no_dependence_pre(result, sdfg, cfg, state_id, node) - else: - # add specific pragmas - for candidate in in_out_data: - self.generate_no_dependence_pre(result, sdfg, cfg, state_id, node, candidate) - result.write("for (long {it} = 0; {it} < {bound}; ++{it}) {{\n".format( - it=flat_it, bound=node.pipeline.loop_bound_str())) - if pipeline.init_size != 0: - result.write("const bool {} = {} < {};\n".format(node.pipeline.init_condition(), flat_it, - cpp.sym2cpp(pipeline.init_size))) - if pipeline.drain_size != 0: - result.write("const bool {} = {} >= {};\n".format( - node.pipeline.drain_condition(), flat_it, - bound + (" - " + cpp.sym2cpp(pipeline.drain_size) if pipeline.drain_size != 0 else ""))) - - # Add pragmas - if not fully_degenerate: - if not node.map.unroll: - if is_innermost: - self.generate_pipeline_loop_post(result, sdfg, cfg, state_id, node) - self.generate_flatten_loop_post(result, sdfg, cfg, state_id, node) - # add pragmas for data read/written inside this map, but only for local arrays - for candidate in in_out_data: - if sdfg.arrays[candidate].storage != dtypes.StorageType.FPGA_Global: - self.generate_no_dependence_post(result, sdfg, cfg, state_id, node, candidate) - - # Emit internal transient array allocation - to_allocate = dace.sdfg.local_transients(sdfg, cfg.state(state_id), node) - allocated = set() - for child in dfg.scope_children()[node]: - if not isinstance(child, dace.sdfg.nodes.AccessNode): - continue - if child.data not in to_allocate or child.data in allocated: - continue - allocated.add(child.data) - self._dispatcher.dispatch_allocate(sdfg, cfg, dfg, state_id, child, child.desc(sdfg), None, result) - - def _generate_PipelineExit(self, *args, **kwargs): - self._generate_MapExit(*args, **kwargs) - - def _generate_MapExit(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, - node: nodes.MapExit, function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: - scope_dict = dfg.scope_dict() - entry_node = scope_dict[node] - if entry_node.map in self._unrolled_pes: - # This was generated as unrolled processing elements, no need to - # generate anything here - return - if isinstance(node, dace.sdfg.nodes.PipelineExit): - flat_it = node.pipeline.iterator_str() - bound = node.pipeline.loop_bound_str() - pipeline = node.pipeline - cond = [] - if pipeline.init_size != 0 and pipeline.init_overlap == False: - cond.append("!" + pipeline.init_condition()) - if pipeline.drain_size != 0 and pipeline.drain_overlap == False: - cond.append("!" + pipeline.drain_condition()) - if len(cond) > 0: - callsite_stream.write("if ({}) {{".format(" && ".join(cond))) - # ranges could have been defined in terms of floor/ceiling. Before printing the code - # they are converted from a symbolic expression to a C++ compilable expression - for it, r in reversed(list(zip(pipeline.params, pipeline.range))): - callsite_stream.write("if ({it} >= {end}) {{\n{it} = {begin};\n".format( - it=it, - begin=dace.symbolic.symstr(r[0], cpp_mode=True), - end=dace.symbolic.symstr(r[1], cpp_mode=True))) - for it, r in zip(pipeline.params, pipeline.range): - callsite_stream.write("}} else {{\n{it} += {step};\n}}\n".format(it=it, - step=dace.symbolic.symstr( - r[2], cpp_mode=True))) - if len(cond) > 0: - callsite_stream.write("}\n") - callsite_stream.write("}\n}\n") - else: - self._cpu_codegen._generate_MapExit(sdfg, cfg, dfg, state_id, node, function_stream, callsite_stream) - - def generate_kernel(self, - sdfg: dace.SDFG, - cfg: ControlFlowRegion, - state: dace.SDFGState, - kernel_name: str, - subgraphs: list, - function_stream: CodeIOStream, - callsite_stream: CodeIOStream, - state_host_header_stream: CodeIOStream, - state_host_body_stream: CodeIOStream, - instrumentation_stream: CodeIOStream, - state_parameters: list, - kernel_id: int = None): - """ - Entry point for generating an FPGA Kernel out of the given subgraphs. - - :param sdfg: - :param state: - :param kernel_name: the generated kernel name. - :param subgraphs: the connected components that constitute this kernel. - :param function_stream: CPU code stream, contains global declarations. - :param callsite_stream: CPU code stream, contains code for invoking kernels, ... - :param state_host_header_stream: Device-specific host code stream: contains the host code - for the state global declarations. - :param state_host_body_stream: Device-specific host code stream: contains all the code related - to this state, for creating transient buffers, spawning kernels, and synchronizing them. - :param instrumentation_stream: Code for profiling kernel execution time. - :param state_parameters: a list of parameters that must be passed to the state. It will get populated - considering all the parameters needed by the kernels in this state. - :param kernel_id: Unique ID of this kernels as computed in the generate_state function - """ - - if self._in_device_code: - raise cgx.CodegenError("Tried to generate kernel from device code") - self._in_device_code = True - self._cpu_codegen._packed_types = True - kernel_stream = CodeIOStream() - - predecessors = [] - # Check if this kernels depends from someone else - if kernel_id is not None and kernel_id in self._kernels_dependencies: - - def get_kernel_name(val): - for key, value in self._kernels_names_to_id.items(): - if val == value: - return key - raise RuntimeError(f"Error while generating kernel dependencies. Kernel {val} not found.") - - # Build a list containing all the name of kernels from which this one depends - for pred in self._kernels_dependencies[kernel_id]: - predecessors.append(get_kernel_name(pred)) - - # Actual kernel code generation - self.generate_kernel_internal(sdfg, cfg, state, kernel_name, predecessors, subgraphs, kernel_stream, - state_host_header_stream, state_host_body_stream, instrumentation_stream, - function_stream, callsite_stream, state_parameters) - self._kernel_count = self._kernel_count + 1 - self._in_device_code = False - self._cpu_codegen._packed_types = False - - # Check if this is a multi pumped kernel - is_multi_pumped = all([self.is_multi_pumped_subgraph(sg) for sg in subgraphs]) - - # Store code strings to be passed to compilation phase - self._kernel_codes.append((kernel_name, kernel_stream.getvalue(), is_multi_pumped)) - - self._allocated_global_arrays = set() - - def _module_name(self, subgraph, state): - """ - Generate the name of an FPGA module produced from the given subgraph. - """ - to_traverse = subgraph.source_nodes() - seen = set() - tasklet_list = [] - access_nodes = [] - while len(to_traverse) > 0: - n = to_traverse.pop() - if n in seen: - continue - seen.add(n) - if (isinstance(n, dace.sdfg.nodes.Tasklet) or isinstance(n, dace.sdfg.nodes.NestedSDFG)): - tasklet_list.append(n) - else: - if isinstance(n, dace.sdfg.nodes.AccessNode): - access_nodes.append(n) - for e in subgraph.out_edges(n): - if e.dst not in seen: - to_traverse.append(e.dst) - # Name module according to all reached tasklets (can be just one) - labels = [n.label.replace(" ", "_") + f"_{state.node_id(n)}" for n in tasklet_list] - # If there are no tasklets, name it after access nodes in the - # subgraph - if len(labels) == 0: - labels = [n.label.replace(" ", "_") for n in access_nodes] - if len(labels) == 0: - raise RuntimeError("Expected at least one tasklet or data node.") - return "_".join(labels) - - def generate_modules(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, kernel_name: str, subgraphs, - subgraph_parameters, module_stream, entry_stream, host_stream, instrumentation_stream): - """ - Generate all PEs inside an FPGA Kernel. - """ - for subgraph in subgraphs: - module_name = self._module_name(subgraph, state) - self.generate_module(sdfg, cfg, state, kernel_name, module_name, subgraph, subgraph_parameters[subgraph], - module_stream, entry_stream, host_stream, instrumentation_stream) - - def generate_nsdfg_header(self, sdfg, cfg, state, state_id, node, memlet_references, sdfg_label): - return self._cpu_codegen.generate_nsdfg_header(sdfg, - cfg, - state, - state_id, - node, - memlet_references, - sdfg_label, - state_struct=False) - - def generate_nsdfg_call(self, sdfg, cfg, state, node, memlet_references, sdfg_label): - return self._cpu_codegen.generate_nsdfg_call(sdfg, - cfg, - state, - node, - memlet_references, - sdfg_label, - state_struct=False) - - def generate_nsdfg_arguments(self, sdfg, cfg, dfg, state, node): - return self._cpu_codegen.generate_nsdfg_arguments(sdfg, cfg, state, dfg, node) - - def generate_host_function_boilerplate(self, sdfg, cfg, state, nested_global_transients, host_code_stream): - """ - Generates global transients that must be passed to the state (required by a kernel) - """ - - # Any extra transients stored in global memory on the FPGA must now be - # allocated and passed to the kernel - for arr_node in nested_global_transients: - self._dispatcher.dispatch_allocate(sdfg, cfg, state, None, arr_node, arr_node.desc(sdfg), None, - host_code_stream) - - def _generate_Tasklet(self, *args, **kwargs): - # Call CPU implementation with this code generator as callback - self._cpu_codegen._generate_Tasklet(*args, codegen=self, **kwargs) - - def define_out_memlet(self, sdfg: SDFG, cfg: ControlFlowRegion, state_dfg: StateSubgraphView, state_id: int, - src_node: nodes.Node, dst_node: nodes.Node, edge: MultiConnectorEdge[memlet.Memlet], - function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: - self._dispatcher.dispatch_copy(src_node, dst_node, edge, sdfg, cfg, state_dfg, state_id, function_stream, - callsite_stream) - - def process_out_memlets(self, *args, **kwargs): - # Call CPU implementation with this code generator as callback - self._cpu_codegen.process_out_memlets(*args, codegen=self, **kwargs) - - def generate_tasklet_preamble(self, *args, **kwargs): - # Fall back on CPU implementation - self._cpu_codegen.generate_tasklet_preamble(*args, **kwargs) - - def generate_tasklet_postamble(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, - node: nodes.Node, function_stream: CodeIOStream, before_memlets_stream: CodeIOStream, - after_memlets_stream: CodeIOStream) -> None: - # Inject dependency pragmas on memlets - for edge in dfg.out_edges(node): - dataname = edge.data.data - if not dataname: - continue # Empty memlet - datadesc = sdfg.arrays[dataname] - if (isinstance(datadesc, dt.Array) and datadesc.storage == dace.StorageType.FPGA_Local - and not cpp.is_write_conflicted(dfg, edge) and self._dispatcher.defined_vars.has(edge.src_conn) - and datadesc.total_size != 1): - if is_multibank_array_with_distributed_index(datadesc): - accessed_subset, _ = get_multibank_ranges_from_subset(edge.data.dst_subset or edge.data.subset, - sdfg) - else: - accessed_subset = 0 - - self.generate_no_dependence_post(after_memlets_stream, sdfg, cfg, state_id, node, dataname, - accessed_subset) - - def make_ptr_vector_cast(self, *args, **kwargs): - return cpp.make_ptr_vector_cast(*args, **kwargs) - - def make_ptr_assignment(self, *args, **kwargs): - return self._cpu_codegen.make_ptr_assignment(*args, codegen=self, **kwargs) - - def instrument_opencl_kernel(self, kernel_name: str, state_id: int, cfg_id: int, code_stream: CodeIOStream): - """ - Emits code to instrument the OpenCL kernel with the given `kernel_name`. - """ - kernel_index = self._kernel_instrumentation_index - self._kernel_instrumentation_index += 1 - if Config.get_bool("instrumentation", "print_fpga_runtime"): - print_str = f""" -const double elapsed = 1e-9 * (event_end - event_start); -std::cout << "FPGA OpenCL kernel \\"{kernel_name}\\" executed in " << elapsed << " seconds.\\n";\ - """ - else: - print_str = "" - code_stream.write(f"""\ - {{ -cl_ulong event_start = 0; -cl_ulong event_end = 0; -hlslib::ocl::Event const &event = all_events[{kernel_index}]; -event.getProfilingInfo(CL_PROFILING_COMMAND_START, &event_start); -event.getProfilingInfo(CL_PROFILING_COMMAND_END, &event_end); -if (event_start < first_start) {{ - first_start = event_start; -}} -if (event_end > last_end) {{ - last_end = event_end; -}} -// Convert from nanoseconds (reported by OpenCL) to microseconds (expected by the profiler) -__state->report.add_completion("{kernel_name}", "FPGA", 1e-3 * event_start, 1e-3 * event_end, {cfg_id}, {state_id}, -1);{print_str} -}}""") diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 449e312efa..d63dddf8dc 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -1,4 +1,4 @@ -# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. import collections import copy import pathlib @@ -14,8 +14,8 @@ from dace.codegen import dispatcher as disp from dace.codegen.prettycode import CodeIOStream from dace.codegen.common import codeblock_to_cpp, sym2cpp -from dace.codegen.targets.target import TargetCodeGenerator -from dace.codegen.tools.type_inference import infer_expr_type +from dace.codegen.target import TargetCodeGenerator +from dace.sdfg.type_inference import infer_expr_type from dace.sdfg import SDFG, SDFGState, nodes from dace.sdfg import scope as sdscope from dace.sdfg import utils @@ -139,6 +139,12 @@ def generate_fileheader(self, sdfg: SDFG, global_stream: CodeIOStream, backend: global_stream.write('#include "../../include/hash.h"\n', sdfg) ######################################################### + # Target-based includes + for target in self._dispatcher.used_targets: + headers = target.get_includes() + if backend in headers: + global_stream.write("\n".join("#include \"" + h + "\"" for h in headers[backend]), sdfg) + # Environment-based includes for env in self.environments: if len(env.headers) > 0: @@ -279,11 +285,17 @@ def generate_footer(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stre callsite_stream.write( f""" DACE_EXPORTED {mangle_dace_state_struct_name(sdfg)} *__dace_init_{sdfg.name}({initparams}) -{{ - int __result = 0; - {mangle_dace_state_struct_name(sdfg)} *__state = new {mangle_dace_state_struct_name(sdfg)}; +{{""", sdfg) + + # Invoke all instrumentation providers + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_sdfg_init_begin(sdfg, callsite_stream, global_stream) - """, sdfg) + callsite_stream.write( + f""" + int __result = 0; + {mangle_dace_state_struct_name(sdfg)} *__state = new {mangle_dace_state_struct_name(sdfg)};""", sdfg) for target in self._dispatcher.used_targets: if target.has_initializer: @@ -304,17 +316,29 @@ def generate_footer(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stre callsite_stream.write(self._initcode.getvalue(), sdfg) - callsite_stream.write( - f""" + callsite_stream.write(f""" if (__result) {{ delete __state; return nullptr; }} +""", sdfg) + # Invoke all instrumentation providers + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_sdfg_init_end(sdfg, callsite_stream, global_stream) + callsite_stream.write( + f""" return __state; }} DACE_EXPORTED int __dace_exit_{sdfg.name}({mangle_dace_state_struct_name(sdfg)} *__state) {{ +""", sdfg) + # Invoke all instrumentation providers + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_sdfg_exit_begin(sdfg, callsite_stream, global_stream) + callsite_stream.write(f""" int __err = 0; """, sdfg) @@ -349,6 +373,10 @@ def generate_footer(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stre callsite_stream.write("}") callsite_stream.write('delete __state;\n', sdfg) + # Invoke all instrumentation providers + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_sdfg_exit_end(sdfg, callsite_stream, global_stream) callsite_stream.write('return __err;\n}\n', sdfg) def generate_external_memory_management(self, sdfg: SDFG, callsite_stream: CodeIOStream): @@ -798,6 +826,11 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG): def allocate_arrays_in_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, scope: Union[nodes.EntryNode, SDFGState, SDFG], function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: + if len(self.to_allocate[scope]) == 0: + return + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_allocation_begin(sdfg, scope, callsite_stream) """ Dispatches allocation of all arrays in the given scope. """ for tsdfg, state, node, declare, allocate, _ in self.to_allocate[scope]: if state is not None: @@ -809,10 +842,18 @@ def allocate_arrays_in_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, scope: Un self._dispatcher.dispatch_allocate(tsdfg, cfg if state is None else state.parent_graph, state, state_id, node, desc, function_stream, callsite_stream, declare, allocate) + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_allocation_end(sdfg, scope, callsite_stream) def deallocate_arrays_in_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, scope: Union[nodes.EntryNode, SDFGState, SDFG], function_stream: CodeIOStream, callsite_stream: CodeIOStream): + if len(self.to_allocate[scope]) == 0: + return + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_deallocation_begin(sdfg, scope, callsite_stream) """ Dispatches deallocation of all arrays in the given scope. """ for tsdfg, state, node, _, _, deallocate in self.to_allocate[scope]: if not deallocate: @@ -826,6 +867,9 @@ def deallocate_arrays_in_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, scope: self._dispatcher.dispatch_deallocate(tsdfg, state.parent_graph, state, state_id, node, desc, function_stream, callsite_stream) + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_deallocation_end(sdfg, scope, callsite_stream) def generate_code(self, sdfg: SDFG, @@ -911,6 +955,11 @@ def generate_code(self, interstate_symbols.update(symbols) global_symbols.update(symbols) + try: + edge_codegen = self.dispatcher.get_scope_dispatcher(schedule) + except KeyError: + edge_codegen = self.dispatcher.get_generic_node_dispatcher() + for isvarName, isvarType in interstate_symbols.items(): if isvarType is None: raise TypeError(f'Type inference failed for symbol {isvarName}') @@ -921,17 +970,10 @@ def generate_code(self, # as part of the function's arguments if not is_top_level and isvarName in sdfg.parent_nsdfg_node.symbol_mapping: continue - isvar = data.Scalar(isvarType) - if (schedule in (dtypes.ScheduleType.FPGA_Device, dtypes.ScheduleType.FPGA_Multi_Pumped) - and config.Config.get('compiler', 'fpga', 'vendor').lower() == 'intel_fpga'): - # Emit OpenCL type - callsite_stream.write(f'{isvarType.ocltype} {isvarName};\n', sdfg) - self.dispatcher.defined_vars.add(isvarName, disp.DefinedType.Scalar, isvarType.ctype) - else: - # If the variable is passed as an input argument to the SDFG, do not need to declare it - if isvarName not in outside_symbols: - callsite_stream.write('%s;\n' % (isvar.as_arg(with_types=True, name=isvarName)), sdfg) - self.dispatcher.defined_vars.add(isvarName, disp.DefinedType.Scalar, isvarType.ctype) + if isvarName not in outside_symbols: + edge_codegen.emit_interstate_variable_declaration(isvarName, isvarType, callsite_stream, sdfg) + # If the variable is passed as an input argument to the SDFG, do not need to declare it + callsite_stream.write('\n', sdfg) ####################################################################### diff --git a/dace/codegen/targets/intel_fpga.py b/dace/codegen/targets/intel_fpga.py deleted file mode 100644 index 55bae28669..0000000000 --- a/dace/codegen/targets/intel_fpga.py +++ /dev/null @@ -1,1582 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import ast -import functools -import copy -import itertools -from six import StringIO -import numpy as np - -import dace -from dace import registry, dtypes, symbolic -from dace.codegen import cppunparse -from dace.config import Config -from dace.codegen import exceptions as cgx -from dace.codegen.codeobject import CodeObject -from dace.codegen.dispatcher import DefinedType -from dace.codegen.prettycode import CodeIOStream -from dace.codegen.targets import cpp, fpga -from dace.codegen.common import codeblock_to_cpp -from dace.codegen.tools.type_inference import infer_expr_type -from dace.frontend.python.astutils import rname, unparse, evalnode -from dace.frontend import operations -from dace.sdfg import find_input_arraynode, find_output_arraynode -from dace.sdfg import nodes, utils as sdutils -from dace.codegen.common import sym2cpp -from dace.sdfg import SDFGState -from dace.sdfg.sdfg import SDFG -from dace.sdfg.state import ControlFlowRegion, StateSubgraphView -import dace.sdfg.utils as utils -from dace.symbolic import evaluate -from collections import defaultdict - -REDUCTION_TYPE_TO_HLSLIB = { - dace.dtypes.ReductionType.Min: "min", - dace.dtypes.ReductionType.Max: "max", - dace.dtypes.ReductionType.Sum: "+", - dace.dtypes.ReductionType.Sub: "-", - dace.dtypes.ReductionType.Product: "*", - dace.dtypes.ReductionType.Div: "/", - dace.dtypes.ReductionType.Logical_And: " && ", - dace.dtypes.ReductionType.Bitwise_And: "&", - dace.dtypes.ReductionType.Logical_Or: "||", - dace.dtypes.ReductionType.Bitwise_Or: "|", - dace.dtypes.ReductionType.Bitwise_Xor: "^" -} - -REDUCTION_TYPE_TO_PYEXPR = { - dace.dtypes.ReductionType.Min: "min({a}, {b})", - dace.dtypes.ReductionType.Max: "max({a}, {b})", - dace.dtypes.ReductionType.Sum: "{a} + {b}", - dace.dtypes.ReductionType.Product: "*", - dace.dtypes.ReductionType.Logical_And: " && ", - dace.dtypes.ReductionType.Bitwise_And: "&", - dace.dtypes.ReductionType.Logical_Or: "||", - dace.dtypes.ReductionType.Bitwise_Or: "|", - dace.dtypes.ReductionType.Bitwise_Xor: "^" -} - - -@registry.autoregister_params(name='intel_fpga') -class IntelFPGACodeGen(fpga.FPGACodeGen): - target_name = 'intel_fpga' - title = 'Intel FPGA' - language = 'hls' - - def __init__(self, *args, **kwargs): - self.fpga_vendor = Config.get("compiler", "fpga", "vendor") - - # Check that the given vendor is supported - fpga.is_vendor_supported(self.fpga_vendor) - - if self.fpga_vendor.lower() != "intel_fpga": - # Don't register this code generator - return - # Keep track of generated converters to avoid multiple definition - self.generated_converters = set() - # constants - self.generated_constants = set() - # Channel mangles - self.channel_mangle = defaultdict(dict) - # Modules name mangles - self.module_mange = defaultdict(dict) - - # Keep track of external streams - self.external_streams = set() - - super().__init__(*args, **kwargs) - - @staticmethod - def cmake_options(): - - host_flags = Config.get("compiler", "intel_fpga", "host_flags") - kernel_flags = Config.get("compiler", "intel_fpga", "kernel_flags") - mode = Config.get("compiler", "intel_fpga", "mode") - target_board = Config.get("compiler", "intel_fpga", "board") - enable_debugging = ("ON" if Config.get_bool("compiler", "intel_fpga", "enable_debugging") else "OFF") - autobuild = ("ON" if Config.get_bool("compiler", "fpga", "autobuild_bitstreams") else "OFF") - options = [ - "-DDACE_INTELFPGA_HOST_FLAGS=\"{}\"".format(host_flags), - "-DDACE_INTELFPGA_KERNEL_FLAGS=\"{}\"".format(kernel_flags), "-DDACE_INTELFPGA_MODE={}".format(mode), - "-DDACE_INTELFPGA_TARGET_BOARD=\"{}\"".format(target_board), - "-DDACE_INTELFPGA_ENABLE_DEBUGGING={}".format(enable_debugging), - "-DDACE_FPGA_AUTOBUILD_BITSTREAM={}".format(autobuild) - ] - # Override Intel FPGA OpenCL installation directory - if Config.get("compiler", "intel_fpga", "path"): - options.append("-DINTELFPGAOCL_ROOT_DIR=\"{}\"".format( - Config.get("compiler", "intel_fpga", "path").replace("\\", "/"))) - return options - - def get_generated_codeobjects(self): - - execution_mode = Config.get("compiler", "intel_fpga", "mode") - kernel_file_name = "DACE_BINARY_DIR \"/{}".format(self._program_name) - emulation_flag = "" - if execution_mode == "emulator": - kernel_file_name += "_emulator.aocx\"" - emulation_flag = ("\n dace::set_environment_variable" - "(\"CL_CONTEXT_EMULATOR_DEVICE_INTELFPGA\", \"1\");") - elif execution_mode == "simulator": - kernel_file_name += "_simulator.aocx\"" - elif execution_mode == "hardware": - kernel_file_name += "_hardware.aocx\"" - else: - raise cgx.CodegenError("Unknown Intel FPGA execution mode: {}".format(execution_mode)) - - host_code = CodeIOStream() - host_code.write('#include "dace/intel_fpga/host.h"') - if len(self._dispatcher.instrumentation) > 2: - host_code.write("""\ -#include "dace/perf/reporting.h" -#include -#include -#include -#include -""") - host_code.write("\n\n") - - self._frame.generate_fileheader(self._global_sdfg, host_code, 'intelfpga_host') - - params_comma = self._global_sdfg.init_signature(free_symbols=self._frame.free_symbols(self._global_sdfg)) - if params_comma: - params_comma = ', ' + params_comma - - host_code.write(""" -DACE_EXPORTED int __dace_init_intel_fpga({sdfg_state_name} *__state{signature}) {{{emulation_flag} - __state->fpga_context = new dace_fpga_context(); - __state->fpga_context->Get().MakeProgram({kernel_file_name}); - return 0; -}} - -DACE_EXPORTED int __dace_exit_intel_fpga({sdfg_state_name} *__state) {{ - delete __state->fpga_context; - return 0; -}} - -{host_code}""".format(signature=params_comma, - sdfg=self._global_sdfg, - sdfg_state_name=cpp.mangle_dace_state_struct_name(self._global_sdfg), - emulation_flag=emulation_flag, - kernel_file_name=kernel_file_name, - host_code="".join([ - "{separator}\n// State: {kernel_name}" - "\n{separator}\n\n{code}\n\n".format(separator="/" * 79, kernel_name=name, code=code) - for (name, code) in self._host_codes - ]))) - - host_code_obj = CodeObject(self._program_name, - host_code.getvalue(), - "cpp", - IntelFPGACodeGen, - "Intel FPGA", - target_type="host", - sdfg=self._global_sdfg) - - kernel_code_objs = [ - CodeObject(kernel_name, - code, - "cl", - IntelFPGACodeGen, - "Intel FPGA", - target_type="device", - sdfg=self._global_sdfg) for (kernel_name, code, _) in self._kernel_codes - ] - # add the util header if present - other_code_objs = [ - CodeObject(file_name, - code.getvalue(), - "cl", - IntelFPGACodeGen, - "Intel FPGA", - target_type="device", - sdfg=self._global_sdfg) for (file_name, code) in self._other_codes.items() - ] - - return [host_code_obj] + kernel_code_objs + other_code_objs - - def _internal_preprocess(self, sdfg: dace.SDFG): - """ - Vendor-specific SDFG Preprocessing - """ - pass - - def create_mangled_channel_name(self, var_name, kernel_id, external_stream): - """ - Memorize and returns the mangled name of a global channel - The dictionary is organized as ``(var_name) : {kernel_id: mangled_name}`` - - :param external_stream: indicates whether this channel is an external stream - (inter-FPGA Kernel) or not. If this is the case, it will not actually mangle - the name by appending a suffix. - """ - - if kernel_id not in self.channel_mangle[var_name]: - if not external_stream: - existing_count = len(self.channel_mangle[var_name]) - suffix = f"_{existing_count}" if existing_count > 0 else "" - mangled_name = f"{var_name}{suffix}" - else: - mangled_name = var_name - self.channel_mangle[var_name][kernel_id] = mangled_name - return self.channel_mangle[var_name][kernel_id] - - def get_mangled_channel_name(self, var_name, kernel_id): - """ - Returns the mangled name of a channel if it is a global channel, - or var_name if it is an alias (generated through #define) - """ - if var_name in self.channel_mangle: - return self.channel_mangle[var_name][kernel_id] - else: - return var_name - - def create_mangled_module_name(self, module_name, kernel_id): - """ - Memorize and returns the mangled name of a module (OpenCL kernel) - The dictionary is organized as {module_name: {kernel_id: mangled_name}} - """ - - if kernel_id not in self.module_mange[module_name]: - existing_count = len(self.module_mange[module_name]) - suffix = f"_{existing_count}" if existing_count > 0 else "" - mangled_name = f"{module_name}{suffix}" - self.module_mange[module_name][kernel_id] = mangled_name - return self.module_mange[module_name][kernel_id] - - def define_stream(self, dtype, buffer_size, var_name, array_size, function_stream, kernel_stream, sdfg): - """ - Defines a stream - - :return: a tuple containing the type of the created variable, and boolean indicating - whether this is a global variable or not - """ - vec_type = self.make_vector_type(dtype, False) - minimum_depth = Config.get("compiler", "fpga", "minimum_fifo_depth") - buffer_size = evaluate(buffer_size, sdfg.constants) - if minimum_depth: - minimum_depth = int(minimum_depth) - if minimum_depth > buffer_size: - buffer_size = minimum_depth - if buffer_size != 1: - depth_attribute = " __attribute__((depth({})))".format(cpp.sym2cpp(buffer_size)) - else: - depth_attribute = "" - if cpp.sym2cpp(array_size) != "1": - size_str = "[" + cpp.sym2cpp(array_size) + "]" - else: - size_str = "" - - if var_name in self.external_streams: - # This is an external streams: it connects two different FPGA Kernels - # that will be code-generated as two separate files. - # We need to declare the channel as global variable and it must have have - # the same name in both the files. - - chan_name = self.create_mangled_channel_name(var_name, self._kernel_count, True) - function_stream.write("channel {} {}{}{};".format(vec_type, chan_name, size_str, depth_attribute)) - else: - # mangle name - chan_name = self.create_mangled_channel_name(var_name, self._kernel_count, False) - - kernel_stream.write("channel {} {}{}{};".format(vec_type, chan_name, size_str, depth_attribute)) - - # Return value is used for adding to defined_vars in fpga.py - # In Intel FPGA, streams must be defined as global entity, so they will be added to the global variables - return 'channel {}'.format(vec_type), True - - def define_local_array(self, var_name, desc, array_size, function_stream, kernel_stream, sdfg, state_id, node): - vec_type = self.make_vector_type(desc.dtype, False) - if desc.storage == dace.dtypes.StorageType.FPGA_Registers: - attributes = " __attribute__((register))" - else: - attributes = "" - kernel_stream.write("{}{} {}[{}];\n".format(vec_type, attributes, var_name, cpp.sym2cpp(array_size))) - self._dispatcher.defined_vars.add(var_name, DefinedType.Pointer, vec_type) - - def define_shift_register(self, *args, **kwargs): - # Shift registers are just arrays on Intel - self.define_local_array(*args, **kwargs) - - @staticmethod - def make_vector_type(dtype, is_const): - return "{}{}".format("const " if is_const else "", dtype.ocltype) - - def make_kernel_argument(self, data, var_name, is_output, with_vectorization): - if isinstance(data, dace.data.Array): - if with_vectorization: - vec_type = data.dtype.ocltype - else: - vec_type = fpga.vector_element_type_of(data.dtype).ocltype - return "__global volatile {}* restrict {}".format(vec_type, var_name) - elif isinstance(data, dace.data.Stream): - return None # Streams are global objects - else: # Scalar or structure - return f'{data.dtype.ocltype} {var_name}' - - @staticmethod - def generate_unroll_loop_pre(kernel_stream, factor, sdfg, cfg, state_id, node): - if factor is not None: - factor_str = " " + factor - else: - factor_str = "" - kernel_stream.write("#pragma unroll{}".format(factor_str), cfg, state_id, node) - - @staticmethod - def generate_unroll_loop_post(kernel_stream, factor, sdfg, cfg, state_id, node): - pass - - @staticmethod - def generate_pipeline_loop_pre(kernel_stream, sdfg, cfg, state_id, node): - pass - - @staticmethod - def generate_pipeline_loop_post(kernel_stream, sdfg, cfg, state_id, node): - pass - - @staticmethod - def generate_flatten_loop_pre(kernel_stream, sdfg, cfg, state_id, node): - kernel_stream.write("#pragma loop_coalesce") - - @staticmethod - def generate_flatten_loop_post(kernel_stream, sdfg, cfg, state_id, node): - pass - - def make_read(self, defined_type, dtype, var_name, expr, index, is_pack, packing_factor): - if defined_type in [DefinedType.Stream, DefinedType.StreamArray]: - # channel mangling: the expression could contain indexing - expr.replace(var_name, self.get_mangled_channel_name(var_name, self._kernel_count)) - read_expr = "read_channel_intel({})".format(expr) - elif defined_type == DefinedType.Pointer: - if index and index != "0": - read_expr = f"*({expr} + {index})" - else: - if " " in expr: - expr = f"({expr})" - read_expr = f"*{expr}" - elif defined_type == DefinedType.Scalar: - read_expr = var_name - else: - raise NotImplementedError("Unimplemented read type: {}".format(defined_type)) - if is_pack: - ocltype = fpga.vector_element_type_of(dtype).ocltype - self.converters_to_generate.add((True, ocltype, packing_factor)) - return "pack_{}{}(&({}))".format(ocltype, packing_factor, read_expr) - else: - return read_expr - - def make_write(self, defined_type, dtype, var_name, write_expr, index, read_expr, wcr, is_unpack, packing_factor): - """ - Creates write expression, taking into account wcr if present - """ - if wcr is not None: - redtype = operations.detect_reduction_type(wcr, openmp=True) - - if defined_type in [DefinedType.Stream, DefinedType.StreamArray]: - #mangle name - chan_name = self.get_mangled_channel_name(write_expr, self._kernel_count) - if defined_type == DefinedType.StreamArray: - write_expr = "{}[{}]".format(chan_name, index) - if is_unpack: - return "\n".join("write_channel_intel({}, {}[{}]);".format(write_expr, read_expr, i) - for i in range(packing_factor)) - else: - return "write_channel_intel({}, {});".format(chan_name, read_expr) - elif defined_type == DefinedType.Pointer: - if wcr is not None: - if (redtype != dace.dtypes.ReductionType.Min and redtype != dace.dtypes.ReductionType.Max): - return "{}[{}] = {}[{}] {} {};".format(write_expr, index, write_expr, index, - REDUCTION_TYPE_TO_HLSLIB[redtype], read_expr) - else: - # use max/min opencl builtins - return "{}[{}] = {}{}({}[{}],{});".format( - write_expr, index, ("f" if dtype.ocltype == "float" or dtype.ocltype == "double" else ""), - REDUCTION_TYPE_TO_HLSLIB[redtype], write_expr, index, read_expr) - else: - if is_unpack: - ocltype = fpga.vector_element_type_of(dtype).ocltype - self.converters_to_generate.add((False, ocltype, packing_factor)) - if not index or index == "0": - return "unpack_{}{}({}, {});".format(ocltype, packing_factor, read_expr, write_expr) - else: - return "unpack_{}{}({}, {} + {});".format(ocltype, packing_factor, read_expr, write_expr, index) - else: - if " " in write_expr: - write_expr = f"({write_expr})" - if index and index != "0": - return f"{write_expr}[{index}] = {read_expr};" - else: - return f"*{write_expr} = {read_expr};" - elif defined_type == DefinedType.Scalar: - if wcr is not None: - if redtype != dace.dtypes.ReductionType.Min and redtype != dace.dtypes.ReductionType.Max: - return "{} = {} {} {};".format(write_expr, write_expr, REDUCTION_TYPE_TO_HLSLIB[redtype], read_expr) - else: - # use max/min opencl builtins - return "{} = {}{}({},{});".format( - write_expr, ("f" if dtype.ocltype == "float" or dtype.ocltype == "double" else ""), - REDUCTION_TYPE_TO_HLSLIB[redtype], write_expr, read_expr) - else: - if is_unpack: - ocltype = fpga.vector_element_type_of(dtype).ocltype - self.converters_to_generate.add((False, ocltype, packing_factor)) - return "unpack_{}{}({}, {});".format( - vector_element_type_of(dtype).ocltype, packing_factor, read_expr, var_name) - else: - return "{} = {};".format(var_name, read_expr) - raise NotImplementedError("Unimplemented write type: {}".format(defined_type)) - - def make_shift_register_write(self, defined_type, dtype, var_name, write_expr, index, read_expr, wcr, is_unpack, - packing_factor, sdfg): - if defined_type != DefinedType.Pointer: - raise TypeError("Intel shift register must be an array: " - "{} is {}".format(var_name, defined_type)) - # Shift array - arr_size = functools.reduce(lambda a, b: a * b, sdfg.data(var_name).shape, 1) - res = """ -#pragma unroll -for (int u_{name} = 0; u_{name} < {size} - {veclen}; ++u_{name}) {{ - {name}[u_{name}] = {name}[u_{name} + {veclen}]; -}}\n""".format(name=var_name, size=arr_size, veclen=cpp.sym2cpp(dtype.veclen)) - # Then do write - res += self.make_write(defined_type, dtype, var_name, write_expr, index, read_expr, wcr, is_unpack, - packing_factor) - return res - - @staticmethod - def generate_no_dependence_pre(kernel_stream, sdfg, cfg, state_id, node, var_name=None): - """ - Adds pre-loop pragma for ignoring loop carried dependencies on a given variable - (if var_name is provided) or all variables - """ - if var_name is None: - kernel_stream.write("#pragma ivdep", cfg, state_id, node) - else: - kernel_stream.write("#pragma ivdep array({})".format(var_name), cfg, state_id, node) - - @staticmethod - def generate_no_dependence_post(kernel_stream, sdfg, cfg, state_id, node, var_name=None, accessed_subset=None): - pass - - def generate_kernel_internal(self, sdfg: dace.SDFG, cfg: ControlFlowRegion, state: dace.SDFGState, kernel_name: str, - predecessors: list, subgraphs: list, kernel_stream: CodeIOStream, - state_host_header_stream: CodeIOStream, state_host_body_stream: CodeIOStream, - instrumentation_stream: CodeIOStream, function_stream: CodeIOStream, - callsite_stream: CodeIOStream, state_parameters: list) -> None: - """ - Generates Kernel code, both device and host side. - - :param sdfg: - :param state: - :param kernel_name: - :param predecessors: list containing all the name of kernels from which this one depends - :param subgraphs: - :param kernel_stream: Device code stream, contains the kernel code - :param state_host_header_stream: Device-specific code stream: contains the host code - for the state global declarations. - :param state_host_body_stream: Device-specific code stream: contains all the code related to - this state, for creating transient buffers, spawning kernels, and synchronizing them. - :param instrumentation_stream: Code for profiling kernel execution time. - :param function_stream: CPU code stream. - :param callsite_stream: CPU code stream. - :param state_parameters: list of state parameters. The kernel-specific parameters will be appended to it. - """ - - # In xilinx one of them is not used because part of the code goes in another place (entry_stream) - state_id = state.block_id - - kernel_header_stream = CodeIOStream() - kernel_body_stream = CodeIOStream() - - #reset list of needed converters - self.converters_to_generate = set() - - kernel_header_stream.write("#include \n\n", cfg) - self.generate_constants(sdfg, kernel_header_stream) - kernel_header_stream.write("\n", cfg) - - (global_data_parameters, top_level_local_data, subgraph_parameters, nested_global_transients, bank_assignments, - external_streams) = self.make_parameters(sdfg, state, subgraphs) - - # save the name of external streams - self.external_streams = set([chan_name for _, chan_name, _, _ in external_streams]) - - # Emit allocations of inter-kernel memories - for node in top_level_local_data: - self._dispatcher.dispatch_allocate(sdfg, cfg, state, state_id, node, node.desc(sdfg), callsite_stream, - kernel_body_stream) - - kernel_body_stream.write("\n") - state_parameters.extend(global_data_parameters) - # Generate host code (Global transients) - self.generate_host_function_boilerplate(sdfg, cfg, state, nested_global_transients, state_host_body_stream) - - self.generate_host_function_prologue(sdfg, cfg, state, state_host_body_stream, kernel_name) - - # Generate PEs code - self.generate_modules(sdfg, cfg, state, kernel_name, subgraphs, subgraph_parameters, kernel_body_stream, - state_host_header_stream, state_host_body_stream, instrumentation_stream) - - kernel_body_stream.write("\n") - - # Generate data width converters - self.generate_converters(sdfg, cfg, kernel_header_stream) - - kernel_stream.write(kernel_header_stream.getvalue() + kernel_body_stream.getvalue()) - - # Generate host kernel invocation - self.generate_host_function_body(sdfg, cfg, state, state_host_body_stream, kernel_name, predecessors) - - def generate_host_function_prologue(self, sdfg, cfg, state, host_stream, kernel_name): - seperator = "/" * 59 - host_stream.write(f"\n{seperator}\n// Kernel: {kernel_name}\n{seperator}\n\n") - - host_stream.write(f"std::vector {kernel_name}_kernels;", cfg, state.block_id) - - def generate_host_function_body(self, sdfg: dace.SDFG, cfg: ControlFlowRegion, state: dace.SDFGState, - host_stream: CodeIOStream, kernel_name: str, predecessors: list) -> None: - """ - Generate the host-specific code for spawning and synchronizing the given kernel. - - :param sdfg: - :param state: - :param host_stream: Device-specific code stream - :param kernel_name: - :param predecessors: list containing all the name of kernels that must be finished before starting this one - """ - state_id = state.block_id - - # Check if this kernel depends from other kernels - needs_synch = len(predecessors) > 0 - - if needs_synch: - # Build a vector containing all the events associated with the kernels from which this one depends - kernel_deps_name = f"deps_{kernel_name}" - host_stream.write(f"std::vector {kernel_deps_name};") - for pred in predecessors: - # concatenate events from predecessor kernel - host_stream.write( - f"{kernel_deps_name}.insert({kernel_deps_name}.end(), {pred}_events.begin(), {pred}_events.end());") - - # While spawning the kernel, indicates the synchronization events (if any) - host_stream.write( - f"""\ - std::vector {kernel_name}_events; - for (auto &k : {kernel_name}_kernels) {{ - {kernel_name}_events.emplace_back(k.ExecuteTaskAsync({f'{kernel_deps_name}.begin(), {kernel_deps_name}.end()' if needs_synch else ''})); - }} - all_events.insert(all_events.end(), {kernel_name}_events.begin(), {kernel_name}_events.end()); -""", cfg, state_id) - - def generate_module(self, sdfg, cfg, state, kernel_name, module_name, subgraph, parameters, module_stream, - host_header_stream, host_body_stream, instrumentation_stream): - state_id = state.block_id - dfg = cfg.state(state_id) - - kernel_args_opencl = [] - kernel_args_host = [] - kernel_args_call = [] - for is_output, pname, p, _ in parameters: - if isinstance(p, dace.data.View): - continue - arg = self.make_kernel_argument(p, pname, is_output, True) - - if arg is not None: - #change c type to opencl type - if arg in dtypes._CTYPES_TO_OCLTYPES: - arg = dtypes._CTYPES_TO_OCLTYPES[arg] - - kernel_args_opencl.append(arg) - kernel_args_host.append(p.as_arg(True, name=pname)) - kernel_args_call.append(pname) - - # If the kernel takes no arguments, we don't have to call it from the - # host - is_autorun = len(kernel_args_opencl) == 0 - - # create a unique module name to prevent name clashes - module_function_name = "mod_" + str(cfg.cfg_id) + "_" + module_name - # The official limit suggested by Intel for module name is 61. However, the compiler - # can also append text to the module. Longest seen so far is - # "_cra_slave_inst", which is 15 characters, so we restrict to - # 61 - 15 = 46, and round down to 36 to be conservative, since - # internally could still fail while dealing with RTL. - # However, in this way we could have name clashes (e.g., if we have two almost identical NestedSDFG). - # Therefore we explicitly take care of this by mangling the name - module_function_name = self.create_mangled_module_name(module_function_name[0:36], self._kernel_count) - - # Unrolling processing elements: if there first scope of the subgraph - # is an unrolled map, generate a processing element for each iteration - scope_children = subgraph.scope_children() - top_scopes = [n for n in scope_children[None] if isinstance(n, dace.sdfg.nodes.EntryNode)] - unrolled_loop = None - if len(top_scopes) == 1: - scope = top_scopes[0] - if scope.unroll: - # Unrolled processing elements - self._unrolled_pes.add(scope.map) - kernel_args_opencl += ["const int " + p for p in scope.params] # PE id will be a macro defined constant - kernel_args_call += [p for p in scope.params] - unrolled_loop = scope.map - - # Ensure no duplicate parameters are used - kernel_args_opencl = dtypes.deduplicate(kernel_args_opencl) - kernel_args_call = dtypes.deduplicate(kernel_args_call) - - # Add kernel call host function - if not is_autorun: - if unrolled_loop is None: - host_body_stream.write( - "{}_kernels.emplace_back(program.MakeKernel(\"{}\"{}));".format( - kernel_name, module_function_name, - ", ".join([""] + kernel_args_call) if len(kernel_args_call) > 0 else ""), cfg, state_id) - if state.instrument == dtypes.InstrumentationType.FPGA: - self.instrument_opencl_kernel(module_function_name, state_id, cfg.cfg_id, instrumentation_stream) - else: - # We will generate a separate kernel for each PE. Adds host call - start, stop, skip = unrolled_loop.range.ranges[0] - start_idx = evaluate(start, sdfg.constants) - stop_idx = evaluate(stop, sdfg.constants) - skip_idx = evaluate(skip, sdfg.constants) - # Due to restrictions on channel indexing, PE IDs must start - # from zero and skip index must be 1 - if start_idx != 0 or skip_idx != 1: - raise cgx.CodegenError(f"Unrolled Map in {sdfg.name} should start from 0 " - "and have skip equal to 1") - for p in range(start_idx, stop_idx + 1, skip_idx): - # Last element in list kernel_args_call is the PE ID, but - # this is already written in stone in the OpenCL generated - # code - unrolled_module_name = f"{module_function_name}_{p}" - host_body_stream.write( - "{}_kernels.emplace_back(program.MakeKernel(\"{}\"{}));".format( - kernel_name, unrolled_module_name, - ", ".join([""] + kernel_args_call[:-1]) if len(kernel_args_call) > 1 else ""), cfg, - state_id) - if state.instrument == dtypes.InstrumentationType.FPGA: - self.instrument_opencl_kernel(unrolled_module_name, state_id, cfg.cfg_id, - instrumentation_stream) - - # ---------------------------------------------------------------------- - # Generate kernel code - # ---------------------------------------------------------------------- - - self._dispatcher.defined_vars.enter_scope(subgraph) - - module_body_stream = CodeIOStream() - - AUTORUN_STR = """\ -__attribute__((max_global_work_dim(0))) -__attribute__((autorun))\n""" - - if unrolled_loop is None: - module_body_stream.write( - "{}__kernel void {}({}) {{".format(AUTORUN_STR if is_autorun else "", module_function_name, - ", ".join(kernel_args_opencl)), cfg, state_id) - else: - # Unrolled PEs: we have to generate a kernel for each PE. We will generate - # a function that will be used create a kernel multiple times - - # generate a unique name for this function - pe_function_name = "pe_" + str(cfg.cfg_id) + "_" + module_name + "_func" - module_body_stream.write("inline void {}({}) {{".format(pe_function_name, ", ".join(kernel_args_opencl)), - cfg, state_id) - - # Allocate local transients - data_to_allocate = (set(subgraph.top_level_transients()) - set(sdfg.shared_transients()) - - set([p[1] for p in parameters])) - allocated = set() - for node in subgraph.nodes(): - if not isinstance(node, dace.sdfg.nodes.AccessNode): - continue - if node.data not in data_to_allocate or node.data in allocated: - continue - allocated.add(node.data) - self._dispatcher.dispatch_allocate(sdfg, cfg, state, state_id, node, node.desc(sdfg), module_stream, - module_body_stream) - - self._dispatcher.dispatch_subgraph(sdfg, - cfg, - subgraph, - state_id, - module_stream, - module_body_stream, - skip_entry_node=False) - - module_stream.write(module_body_stream.getvalue(), cfg, state_id) - module_stream.write("}\n\n") - - if unrolled_loop is not None: - - AUTORUN_STR_MACRO = """ -__attribute__((max_global_work_dim(0))) \\ -__attribute__((autorun)) \\""" - - # Unrolled PEs: create as many kernels as the number of PEs - # To avoid long and duplicated code, do it with define (gosh) - # Since OpenCL is "funny", it does not support variadic macros - # One of the argument is for sure the PE_ID, which is also the last one in kernel_args lists: - # it will be not passed by the host but code-generated - module_stream.write("""\ -#define _DACE_FPGA_KERNEL_{}(PE_ID{}{}) \\{} -__kernel void \\ -{}_##PE_ID({}) \\ -{{ \\ - {}({}{}PE_ID); \\ -}}\\\n\n""".format(module_function_name, ", " if len(kernel_args_call) > 1 else "", ", ".join(kernel_args_call[:-1]), - AUTORUN_STR_MACRO if is_autorun else "", module_function_name, ", ".join(kernel_args_opencl[:-1]), - pe_function_name, ", ".join(kernel_args_call[:-1]), ", " if len(kernel_args_call) > 1 else "")) - - # create PE kernels by using the previously defined macro - start, stop, skip = unrolled_loop.range.ranges[0] - start_idx = evaluate(start, sdfg.constants) - stop_idx = evaluate(stop, sdfg.constants) - skip_idx = evaluate(skip, sdfg.constants) - # First macro argument is the processing element id - for p in range(start_idx, stop_idx + 1, skip_idx): - module_stream.write("_DACE_FPGA_KERNEL_{}({}{}{})\n".format(module_function_name, p, - ", " if len(kernel_args_call) > 1 else "", - ", ".join(kernel_args_call[:-1]))) - module_stream.write("#undef _DACE_FPGA_KERNEL_{}\n".format(module_function_name)) - - self._dispatcher.defined_vars.exit_scope(subgraph) - - def generate_nsdfg_header(self, sdfg, cfg, state, state_id, node, memlet_references, sdfg_label): - # Intel FPGA needs to deal with streams - arguments = [f'{atype} {aname}' for atype, aname, _ in memlet_references] - fsyms = node.sdfg.used_symbols(all_symbols=False, keep_defined_in_mapping=True) - arguments += [ - f'{node.sdfg.symbols[aname].ocltype} {aname}' for aname in sorted(node.symbol_mapping.keys()) - if aname in fsyms and aname not in sdfg.constants - ] - arguments = ', '.join(arguments) - function_header = f'void {sdfg_label}({arguments}) {{' - nested_stream = CodeIOStream() - - #generate Stream defines if needed - for edge in state.in_edges(node): - if edge.data.data is not None: # skip empty memlets - desc = sdfg.arrays[edge.data.data] - if isinstance(desc, dace.data.Stream): - src_node = find_input_arraynode(state, edge) - self._dispatcher.dispatch_copy(src_node, node, edge, sdfg, cfg, state, state_id, None, - nested_stream) - for edge in state.out_edges(node): - if edge.data.data is not None: # skip empty memlets - desc = sdfg.arrays[edge.data.data] - if isinstance(desc, dace.data.Stream): - dst_node = find_output_arraynode(state, edge) - self._dispatcher.dispatch_copy(node, dst_node, edge, sdfg, cfg, state, state_id, None, - nested_stream) - return function_header + "\n" + nested_stream.getvalue() - - def generate_nsdfg_arguments(self, sdfg, cfg, dfg, state, node): - # Connectors that are both input and output share the same name - inout = set(node.in_connectors.keys() & node.out_connectors.keys()) - memlet_references = [] - - for _, _, _, vconn, in_memlet in state.in_edges(node): - if vconn in inout or in_memlet.data is None: - continue - desc = sdfg.arrays[in_memlet.data] - ptrname = cpp.ptr(in_memlet.data, desc, sdfg, self._frame) - defined_type, defined_ctype = self._dispatcher.defined_vars.get(ptrname, 1) - - #change c type to opencl type - if defined_ctype in dtypes._CTYPES_TO_OCLTYPES: - defined_ctype = dtypes._CTYPES_TO_OCLTYPES[defined_ctype] - - if isinstance(desc, dace.data.Array) and (desc.storage == dtypes.StorageType.FPGA_Global - or desc.storage == dtypes.StorageType.FPGA_Local): - # special case: in intel FPGA this must be handled properly to guarantee OpenCL compatibility - # (no pass by reference) - # The defined type can be a scalar, and therefore we get its address - vec_type = desc.dtype.ocltype - offset = cpp.cpp_offset_expr(desc, in_memlet.subset, None) - offset_expr = '[' + offset + ']' if defined_type is not DefinedType.Scalar else '' - - expr = self.make_ptr_vector_cast(ptrname + offset_expr, desc.dtype, node.in_connectors[vconn], False, - defined_type) - if desc.storage == dtypes.StorageType.FPGA_Global: - typedef = "__global volatile {}* restrict".format(vec_type) - else: - typedef = "{} *".format(vec_type) - ref = '&' if defined_type is DefinedType.Scalar else '' - memlet_references.append((typedef, vconn, ref + expr)) - # get the defined type (as defined in the parent) - # Register defined variable - self._dispatcher.defined_vars.add(vconn, DefinedType.Pointer, typedef, allow_shadowing=True) - elif isinstance(desc, dace.data.Stream): - # streams are defined as global variables - continue - elif isinstance(desc, dace.data.Scalar): - typedef = defined_ctype - if defined_type is DefinedType.Scalar: - # if this is a scalar and the argument passed is also a scalar - # then we have to pass it by value - ref = (typedef, vconn, ptrname) - self._dispatcher.defined_vars.add(vconn, defined_type, typedef, allow_shadowing=True) - else: - # otherwise, pass it as a pointer (references do not exist in C99) - ref = (typedef, vconn, cpp.cpp_ptr_expr(sdfg, in_memlet, defined_type, codegen=self._frame)) - self._dispatcher.defined_vars.add(vconn, defined_type, typedef, allow_shadowing=True) - memlet_references.append(ref) - else: - # all the other cases - memlet_references.append( - cpp.emit_memlet_reference(self._dispatcher, - sdfg, - in_memlet, - vconn, - conntype=node.in_connectors[vconn])) - - for _, uconn, _, _, out_memlet in state.out_edges(node): - if out_memlet.data is not None: - desc = sdfg.arrays[out_memlet.data] - ptrname = cpp.ptr(out_memlet.data, desc, sdfg, self._frame) - defined_type, defined_ctype = self._dispatcher.defined_vars.get(ptrname, 1) - - #change c type to opencl type - if defined_ctype in dtypes._CTYPES_TO_OCLTYPES: - defined_ctype = dtypes._CTYPES_TO_OCLTYPES[defined_ctype] - - if isinstance(desc, dace.data.Array) and (desc.storage == dtypes.StorageType.FPGA_Global - or desc.storage == dtypes.StorageType.FPGA_Local): - # special case: in intel FPGA this must be handled properly. - # The defined type can be scalar, and therefore we get its address - vec_type = desc.dtype.ocltype - offset = cpp.cpp_offset_expr(desc, out_memlet.subset, None) - offset_expr = '[' + offset + ']' if defined_type is not DefinedType.Scalar else '' - if desc.storage == dtypes.StorageType.FPGA_Global: - typedef = "__global volatile {}* restrict".format(vec_type) - else: - typedef = "{}*".format(vec_type) - ref = '&' if defined_type is DefinedType.Scalar else '' - expr = self.make_ptr_vector_cast(ptrname + offset_expr, desc.dtype, node.out_connectors[uconn], - False, defined_type) - memlet_references.append((typedef, uconn, ref + expr)) - # Register defined variable - self._dispatcher.defined_vars.add(uconn, DefinedType.Pointer, typedef, allow_shadowing=True) - elif isinstance(desc, dace.data.Stream): - # streams are defined as global variables - continue - elif isinstance(desc, dace.data.Scalar): - # if this is a scalar and the argument passed is also a scalar - # then we have to pass it by reference, i.e., we should define it - # as a pointer since references do not exist in C99 - typedef = defined_ctype - if defined_type is not DefinedType.Pointer: - typedef = typedef + "*" - memlet_references.append( - (typedef, uconn, cpp.cpp_ptr_expr(sdfg, out_memlet, defined_type, codegen=self._frame))) - self._dispatcher.defined_vars.add(uconn, DefinedType.Pointer, typedef, allow_shadowing=True) - else: - memlet_references.append( - cpp.emit_memlet_reference(self._dispatcher, - sdfg, - out_memlet, - uconn, - conntype=node.out_connectors[uconn])) - - # Special case for Intel FPGA: this comes out from the unrolling processing elements: - # if the first scope of the subgraph is an unrolled map, generates a processing element for each iteration - # We need to pass to this function also the id of the PE (the top scope parameter) - scope_children = dfg.scope_children() - top_scopes = [n for n in scope_children[None] if isinstance(n, dace.sdfg.nodes.EntryNode)] - if len(top_scopes) == 1: - scope = top_scopes[0] - if scope.unroll: - # Unrolled processing elements - typedef = "const int" - for p in scope.params: - # if this is not already a mapped symbol, add it - if p not in node.symbol_mapping.keys(): - memlet_references.append((typedef, p, p)) - return memlet_references - - def allocate_view(self, sdfg: dace.SDFG, cfg: ControlFlowRegion, dfg: SDFGState, state_id: int, - node: dace.nodes.AccessNode, global_stream: CodeIOStream, declaration_stream: CodeIOStream, - allocation_stream: CodeIOStream) -> None: - """ - Allocates (creates pointer and refers to original) a view of an - existing array, scalar, or view. Specifically tailored for Intel FPGA - """ - name = node.data - nodedesc = node.desc(sdfg) - ptrname = cpp.ptr(name, nodedesc, sdfg, self._frame) - if self._dispatcher.defined_vars.has(ptrname): - return # View was already allocated - - # Check directionality of view (referencing dst or src) - edge = sdutils.get_view_edge(dfg, node) - - # Allocate the viewed data before the view, if necessary - mpath = dfg.memlet_path(edge) - viewed_dnode = mpath[0].src if edge.dst is node else mpath[-1].dst - self._dispatcher.dispatch_allocate(sdfg, cfg, dfg, state_id, viewed_dnode, viewed_dnode.desc(sdfg), - global_stream, allocation_stream) - - # Emit memlet as a reference and register defined variable - if nodedesc.storage == dace.dtypes.StorageType.FPGA_Global: - # If the viewed (hence the view) node has global storage type, we need to specifically - # derive the declaration/definition - - qualifier = "__global volatile " - atype = dtypes.pointer(nodedesc.dtype).ocltype + " restrict" - aname = ptrname - viewed_desc = sdfg.arrays[edge.data.data] - eptr = cpp.ptr(edge.data.data, viewed_desc, sdfg, self._frame) - defined_type, _ = self._dispatcher.defined_vars.get(eptr, 0) - # Register defined variable - self._dispatcher.defined_vars.add(aname, defined_type, atype, allow_shadowing=True) - _, _, value = cpp.emit_memlet_reference(self._dispatcher, - sdfg, - edge.data, - name, - dtypes.pointer(nodedesc.dtype), - ancestor=0, - device_code=self._in_device_code) - else: - qualifier = "" - atype, aname, value = cpp.emit_memlet_reference(self._dispatcher, - sdfg, - edge.data, - name, - dtypes.pointer(nodedesc.dtype), - ancestor=0) - declaration_stream.write(f'{qualifier}{atype} {aname} = {value};', cfg, state_id, node) - - def generate_memlet_definition(self, sdfg, cfg, dfg, state_id, src_node, dst_node, edge, callsite_stream): - - if isinstance(edge.dst, dace.sdfg.nodes.CodeNode): - # Input memlet - connector = edge.dst_conn - is_output = False - tasklet = edge.dst - conntype = tasklet.in_connectors[connector] - elif isinstance(edge.src, dace.sdfg.nodes.CodeNode): - # Output memlet - connector = edge.src_conn - is_output = True - tasklet = edge.src - conntype = tasklet.out_connectors[connector] - else: - raise NotImplementedError("Not implemented for {} to {}".format(type(edge.src), type(edge.dst))) - - memlet = edge.data - data_name = memlet.data - data_desc = sdfg.arrays[data_name] - data_dtype = data_desc.dtype - - is_scalar = not isinstance(conntype, dtypes.pointer) - dtype = conntype if is_scalar else conntype._typeclass - - memlet_type = self.make_vector_type(dtype, False) - offset = cpp.cpp_offset_expr(data_desc, memlet.subset, None) - - if dtype != data_dtype: - if (isinstance(dtype, dace.vector) and dtype.base_type == data_dtype): - cast = True - else: - raise TypeError("Type mismatch: {} vs. {}".format(dtype, data_dtype)) - else: - cast = False - - result = "" - - # NOTE: FPGA Streams are defined at the top-level scope. We use the - # following boolean to pass this informations to the `get` method of - # the `defined_vars` object. - is_global = False - if isinstance(data_desc, dace.data.Stream): - # Derive the name of the original stream, by tracing the memlet path through nested SDFGs - outer_stream_node_trace = utils.trace_nested_access(dst_node if is_output else src_node, - cfg.state(state_id), sdfg) - data_name = outer_stream_node_trace[0][0][1 if is_output else 0].label - is_global = True - - data_name = cpp.ptr(data_name, data_desc, sdfg, self._frame) - - def_type, ctypedef = self._dispatcher.defined_vars.get(data_name, is_global=is_global) - if def_type == DefinedType.Scalar: - if cast: - rhs = f"(*({memlet_type} const *)&{data_name})" - else: - rhs = data_name - if not memlet.dynamic: - if not is_output: - # We can pre-read the value - result += "{} {} = {};".format(memlet_type, connector, rhs) - else: - # The value will be written during the tasklet, and will be - # automatically written out after - init = "" - - result += "{} {}{};".format(memlet_type, connector, init) - self._dispatcher.defined_vars.add(connector, DefinedType.Scalar, memlet_type) - else: - # Variable number of reads or writes - result += "{} *{} = &{};".format(memlet_type, connector, rhs) - self._dispatcher.defined_vars.add(connector, DefinedType.Pointer, '%s *' % memlet_type) - elif def_type == DefinedType.Pointer: - if cast: - rhs = f"(({memlet_type} const *){data_name})" - else: - rhs = data_name - if is_scalar and not memlet.dynamic: - if is_output: - result += "{} {};".format(memlet_type, connector) - else: - result += "{} {} = {}[{}];".format(memlet_type, connector, rhs, offset) - self._dispatcher.defined_vars.add(connector, DefinedType.Scalar, memlet_type) - else: - if data_desc.storage == dace.dtypes.StorageType.FPGA_Global: - qualifiers = "__global " - else: - qualifiers = "" - ctype = '{}{} *'.format(qualifiers, memlet_type) - result += "{}{} = &{}[{}];".format(ctype, connector, rhs, offset) - self._dispatcher.defined_vars.add(connector, DefinedType.Pointer, ctype) - elif def_type == DefinedType.Stream: - if cast: - raise TypeError("Cannot cast stream from {} to {}.".format(data_dtype, dtype)) - - # In the define we refer to the stream defined in the outermost scope - if not memlet.dynamic and memlet.num_accesses == 1: - if is_output: - result += "{} {};".format(memlet_type, connector) - else: - result += "{} {} = read_channel_intel({});".format( - memlet_type, connector, self.get_mangled_channel_name(data_name, self._kernel_count)) - self._dispatcher.defined_vars.add(connector, DefinedType.Scalar, memlet_type) - else: - # Desperate times call for desperate measures - result += "#define {} {} // God save us".format( - connector, self.get_mangled_channel_name(data_name, self._kernel_count)) - self._dispatcher.defined_vars.add(connector, DefinedType.Stream, ctypedef) - elif def_type == DefinedType.StreamArray: - if cast: - raise TypeError("Cannot cast stream array from {} to {}.".format(data_dtype, dtype)) - # We need to refer to the stream defined in the outermost scope - # Since this is a Stream Array, we need also the offset, which is contained in the memlet that arrives/departs - # from that stream - outer_memlet = outer_stream_node_trace[0][1][1 if is_output else 0] - outer_sdfg = outer_stream_node_trace[0][-1] - - if not memlet.dynamic and memlet.num_accesses == 1 and (is_output is True - or isinstance(edge.dst, dace.sdfg.nodes.Tasklet)): - # if this is an input memlet, generate the read only if this is a tasklet - if is_output: - result += "{} {};".format(memlet_type, connector) - else: - global_node = utils.trace_nested_access(dst_node if is_output else src_node, cfg.state(state_id), - sdfg) - data_name = global_node[0][0][1 if is_output else 0].label - - if outer_memlet is not None: - offset = cpp.cpp_offset_expr(outer_sdfg.arrays[data_name], outer_memlet.subset) - - result += "{} {} = read_channel_intel({}[{}]);".format( - memlet_type, connector, self.get_mangled_channel_name(data_name, self._kernel_count), offset) - self._dispatcher.defined_vars.add(connector, DefinedType.Scalar, memlet_type) - else: - # Must happen directly in the code - # Here we create a macro which take the proper channel - if outer_memlet is not None: - channel_idx = cpp.cpp_offset_expr(outer_sdfg.arrays[data_name], outer_memlet.subset) - else: - channel_idx = cpp.cpp_offset_expr(sdfg.arrays[data_name], memlet.subset) - result += "#define {} {}[{}] // God save us".format( - connector, self.get_mangled_channel_name(data_name, self._kernel_count), channel_idx) - self._dispatcher.defined_vars.add(connector, DefinedType.Stream, ctypedef) - else: - raise TypeError("Unknown variable type: {}".format(def_type)) - - callsite_stream.write(result, cfg, state_id, tasklet) - - def generate_channel_writes(self, sdfg, cfg, dfg, node, callsite_stream, state_id): - for edge in dfg.out_edges(node): - connector = edge.src_conn - memlet = edge.data - data_name = memlet.data - if data_name is not None: - data_desc = sdfg.arrays[data_name] - if (isinstance(data_desc, dace.data.Stream) and memlet.volume == 1 and not memlet.dynamic): - # mangle channel - chan_name = self.get_mangled_channel_name(data_name, self._kernel_count) - if data_desc.is_stream_array(): - offset = cpp.cpp_offset_expr(data_desc, memlet.subset) - target = f"{chan_name}[{offset}]" - else: - target = chan_name - callsite_stream.write(f"write_channel_intel({target}, {connector});", cfg) - - def generate_undefines(self, sdfg, cfg, dfg, node, callsite_stream): - for edge in itertools.chain(dfg.in_edges(node), dfg.out_edges(node)): - memlet = edge.data - data_name = memlet.data - - if edge.src == node: - memlet_name = edge.src_conn - elif edge.dst == node: - memlet_name = edge.dst_conn - - if data_name is not None: - data_desc = sdfg.arrays[data_name] - if (isinstance(data_desc, dace.data.Stream) and (memlet.dynamic or memlet.num_accesses != 1)): - callsite_stream.write("#undef {}".format(memlet_name), cfg) - - def _generate_converter(self, is_unpack, ctype, veclen, sdfg, cfg, function_stream): - # Get the file stream - if "converters" not in self._other_codes: - self._other_codes["converters"] = CodeIOStream() - converter_stream = self._other_codes["converters"] - - veclen = cpp.sym2cpp(veclen) - - if is_unpack: - converter_name = "unpack_{dtype}{veclen}".format(dtype=ctype, veclen=veclen) - signature = "void {name}(const {dtype}{veclen} value, {dtype} *const ptr)".format(name=converter_name, - dtype=ctype, - veclen=veclen) - if converter_name not in self.generated_converters: - self.generated_converters.add(converter_name) - - # create code for converter in appropriate header file - converter_stream.write( - """\ -{signature} {{ - #pragma unroll - for (int u = 0; u < {veclen}; ++u) {{ - ptr[u] = value[u]; - }} -}}\n\n""".format(signature=signature, dtype=ctype, veclen=veclen), cfg) - - # add forward declaration - function_stream.write("extern {};".format(signature), cfg) - - else: - converter_name = "pack_{dtype}{veclen}".format(dtype=ctype, veclen=veclen) - signature = "{dtype}{veclen} {name}({dtype} const *const ptr)".format(name=converter_name, - dtype=ctype, - veclen=veclen) - if converter_name not in self.generated_converters: - self.generated_converters.add(converter_name) - # create code for converter in appropriate header file - converter_stream.write( - """\ -{signature} {{ - {dtype}{veclen} vec; - #pragma unroll - for (int u = 0; u < {veclen}; ++u) {{ - vec[u] = ptr[u]; - }} - return vec; -}}\n\n""".format(signature=signature, dtype=ctype, veclen=veclen), cfg) - - # add forward declaration - function_stream.write("extern {};".format(signature), cfg, self) - - def generate_converters(self, sdfg, cfg, function_stream): - for unpack, ctype, veclen in self.converters_to_generate: - self._generate_converter(unpack, ctype, veclen, sdfg, cfg, function_stream) - - def unparse_tasklet(self, sdfg: SDFG, cfg: ControlFlowRegion, state_id: int, dfg: StateSubgraphView, - node: nodes.Tasklet, function_stream: CodeIOStream, callsite_stream: CodeIOStream, locals, - ldepth, toplevel_schedule) -> str: - if node.label is None or node.label == "": - return '' - - state_dfg = cfg.state(state_id) - - # Not [], "" or None - if not node.code: - return '' - # Not [], "" or None - if node.code_global and node.code_global.code: - function_stream.write( - codeblock_to_cpp(node.code_global), - cfg, - state_id, - node, - ) - function_stream.write("\n", cfg, state_id, node) - - # If raw C++ or OpenCL code, return the code directly - if node.language != dtypes.Language.Python: - if node.language != dtypes.Language.CPP and node.language != dtypes.Language.OpenCL: - raise ValueError("Only Python, C++ and OpenCL code are supported in Intel FPGA codegen, got: {}".format( - node.language)) - callsite_stream.write(type(node).__properties__["code"].to_string(node.code), cfg, state_id, node) - return - - body = node.code.code - - callsite_stream.write('// Tasklet code (%s)\n' % node.label, cfg, state_id, node) - - # Map local names to memlets (for WCR detection) - memlets = {} - for edge in state_dfg.all_edges(node): - u, uconn, v, vconn, memlet = edge - if u == node: - if uconn in u.out_connectors: - conntype = u.out_connectors[uconn] - else: - conntype = None - - # this could be a wcr - memlets[uconn] = (memlet, not edge.data.wcr_nonatomic, edge.data.wcr, conntype) - elif v == node: - if vconn in v.in_connectors: - conntype = v.in_connectors[vconn] - else: - conntype = None - memlets[vconn] = (memlet, False, None, conntype) - - # Build dictionary with all the previously defined symbols - # This is used for forward type inference - defined_symbols = state_dfg.symbols_defined_at(node) - - # This could be problematic for numeric constants that have no dtype - defined_symbols.update({ - k: v.dtype if hasattr(v, 'dtype') else dtypes.typeclass(type(v)) - for k, v in sdfg.constants.items() - }) - - for connector, (memlet, _, _, conntype) in memlets.items(): - if connector is not None: - defined_symbols.update({connector: conntype}) - - for stmt in body: # for each statement in tasklet body - stmt = copy.deepcopy(stmt) - ocl_visitor = OpenCLDaceKeywordRemover(sdfg, self._dispatcher.defined_vars, memlets, self) - - if isinstance(stmt, ast.Expr): - rk = ocl_visitor.visit_TopLevelExpr(stmt) - else: - rk = ocl_visitor.visit(stmt) - - # Generate width converters - self.converters_to_generate |= ocl_visitor.width_converters - - if rk is not None: - result = StringIO() - cppunparse.CPPUnparser(rk, - ldepth + 1, - locals, - result, - defined_symbols=defined_symbols, - type_inference=True, - language=dtypes.Language.OpenCL) - callsite_stream.write(result.getvalue(), cfg, state_id, node) - - def generate_constants(self, sdfg, callsite_stream): - # To avoid a constant being multiple defined, define it once and - # declare it as extern everywhere else. - - for cstname, (csttype, cstval) in sdfg.constants_prop.items(): - if isinstance(csttype, dace.data.Array): - const_str = "__constant " + csttype.dtype.ocltype + \ - " " + cstname + "[" + str(cstval.size) + "]" - - if cstname not in self.generated_constants: - # First time, define it - self.generated_constants.add(cstname) - const_str += " = {" - it = np.nditer(cstval, order='C') - for i in range(cstval.size - 1): - const_str += str(it[0]) + ", " - it.iternext() - const_str += str(it[0]) + "};\n" - else: - # only define - const_str = "extern " + const_str + ";\n" - callsite_stream.write(const_str, sdfg) - else: - # This is a scalar: defining it as an extern variable has the drawback - # that it is not resolved at compile time, preventing the compiler to - # allocate fast memory. Therefore, we will use a #define - callsite_stream.write(f"#define {cstname} {sym2cpp(cstval)}\n", sdfg) - - def generate_tasklet_postamble(self, sdfg, cfg, dfg, state_id, node, function_stream, callsite_stream, - after_memlets_stream): - super().generate_tasklet_postamble(sdfg, cfg, dfg, state_id, node, function_stream, callsite_stream, - after_memlets_stream) - self.generate_channel_writes(sdfg, cfg, dfg, node, after_memlets_stream, state_id) - - def write_and_resolve_expr(self, sdfg, memlet, nc, outname, inname, indices=None, dtype=None): - desc = sdfg.arrays[memlet.data] - offset = cpp.cpp_offset_expr(desc, memlet.subset, None) - ptrname = cpp.ptr(memlet.data, desc, sdfg, self._frame) - defined_type, _ = self._dispatcher.defined_vars.get(ptrname) - return self.make_write(defined_type, dtype, ptrname, ptrname, offset, inname, memlet.wcr, False, 1) - - def make_ptr_vector_cast(self, dst_expr, dst_dtype, src_dtype, is_scalar, defined_type): - """ - Cast a destination pointer so the source expression can be written to it. - - :param dst_expr: Expression of the target pointer. - :param dst_dtype: Type of the target pointer. - :param src_dtype: Type of the variable that needs to be written. - :param is_scalar: Whether the variable to be written is a scalar. - :param defined_type: The code generated variable type of the - destination. - """ - vtype = self.make_vector_type(src_dtype, False) - expr = dst_expr - if dst_dtype != src_dtype: - if is_scalar: - expr = f"*({vtype} *)(&{dst_expr})" - elif src_dtype.base_type != dst_dtype: - expr = f"({vtype})(&{expr})" - elif defined_type == DefinedType.Pointer: - expr = "&" + expr - elif not is_scalar: - expr = "&" + expr - return expr - - def process_out_memlets(self, sdfg, cfg, state_id, node, dfg, dispatcher, result, locals_defined, function_stream, - **kwargs): - # Call CPU implementation with this code generator as callback - self._cpu_codegen.process_out_memlets(sdfg, - cfg, - state_id, - node, - dfg, - dispatcher, - result, - locals_defined, - function_stream, - codegen=self, - **kwargs) - # Inject undefines - self.generate_undefines(sdfg, cfg, dfg, node, result) - - -class OpenCLDaceKeywordRemover(cpp.DaCeKeywordRemover): - """ - Removes Dace Keywords and enforces OpenCL compliance - """ - - nptypes_to_ctypes = {'float64': 'double', 'float32': 'float', 'int32': 'int', 'int64': 'long'} - nptypes = ['float64', 'float32', 'int32', 'int64'] - ctypes = [ - 'bool', 'char', 'cl_char', 'unsigned char', 'uchar', 'cl_uchar', 'short', 'cl_short', 'unsigned short', - 'ushort', 'int', 'unsigned int', 'uint', 'long', 'unsigned long', 'ulong', 'float', 'half', 'size_t', - 'ptrdiff_t', 'intptr_t', 'uintptr_t', 'void', 'double' - ] - - def __init__(self, sdfg, defined_vars, memlets, codegen): - self.sdfg = sdfg - self.defined_vars = defined_vars - # Keep track of the different streams used in a tasklet - self.used_streams = [] - self.width_converters = set() # Pack and unpack vectors - self.dtypes = {k: v[3] for k, v in memlets.items() if k is not None} # Type inference - # consider also constants: add them to known dtypes - for k, v in sdfg.constants.items(): - if k is not None: - self.dtypes[k] = v.dtype - - super().__init__(sdfg, memlets, sdfg.constants, codegen) - - def visit_Assign(self, node): - target = rname(node.targets[0]) - if target not in self.memlets: - # If we don't have a memlet for this target, it could be the case - # that on the right hand side we have a constant (a Name or a subscript) - # If this is the case, we try to infer the type, otherwise we fallback to generic visit - if ((isinstance(node.value, ast.Name) and node.value.id in self.constants) - or (isinstance(node.value, ast.Subscript) and node.value.value.id in self.constants)): - dtype = infer_expr_type(unparse(node.value), self.dtypes) - value = cppunparse.cppunparse(self.visit(node.value), expr_semicolon=False) - code_str = "{} {} = {};".format(dtype, target, value) - updated = ast.Name(id=code_str) - return updated - else: - return self.generic_visit(node) - - memlet, nc, wcr, dtype = self.memlets[target] - is_scalar = not isinstance(dtype, dtypes.pointer) - - value = cppunparse.cppunparse(self.visit(node.value), expr_semicolon=False) - - veclen_lhs = self.sdfg.data(memlet.data).veclen - try: - dtype_rhs = infer_expr_type(unparse(node.value), self.dtypes) - except SyntaxError: - # non-valid python - dtype_rhs = None - - if dtype_rhs is None: - # If we don't understand the vector length of the RHS, assume no - # conversion is needed - veclen_rhs = veclen_lhs - else: - veclen_rhs = dtype_rhs.veclen - - if ((veclen_lhs > veclen_rhs and veclen_rhs != 1) or (veclen_lhs < veclen_rhs and veclen_lhs != 1)): - raise ValueError("Conflicting memory widths: {} and {}".format(veclen_lhs, veclen_rhs)) - - if veclen_rhs > veclen_lhs: - veclen = veclen_rhs - ocltype = fpga.vector_element_type_of(dtype).ocltype - self.width_converters.add((True, ocltype, veclen)) - unpack_str = "unpack_{}{}".format(ocltype, cpp.sym2cpp(veclen)) - - if veclen_lhs > veclen_rhs and isinstance(dtype_rhs, dace.pointer): - veclen = veclen_lhs - ocltype = fpga.vector_element_type_of(dtype).ocltype - self.width_converters.add((False, ocltype, veclen)) - pack_str = "pack_{}{}".format(ocltype, cpp.sym2cpp(veclen)) - # TODO: Horrible hack to not dereference pointers if we have to - # unpack it - if value[0] == "*": - value = value[1:] - value = "{}({})".format(pack_str, value) - - defined_type, _ = self.defined_vars.get(target) - - if defined_type == DefinedType.Pointer: - # In case of wcr over an array, resolve access to pointer, replacing the code inside - # the tasklet - if isinstance(node.targets[0], ast.Subscript): - - if veclen_rhs > veclen_lhs: - code_str = unpack_str + "({src}, &{dst}[{idx}]);" - else: - code_str = "{dst}[{idx}] = {src};" - slice = self.visit(node.targets[0].slice) - if (isinstance(slice, ast.Slice) and isinstance(slice.value, ast.Tuple)): - subscript = unparse(slice)[1:-1] - else: - subscript = unparse(slice) - if wcr is not None: - redtype = operations.detect_reduction_type(wcr) - red_str = REDUCTION_TYPE_TO_PYEXPR[redtype].format(a="{}[{}]".format(memlet.data, subscript), - b=value) - code_str = code_str.format(dst=memlet.data, idx=subscript, src=red_str) - else: - code_str = code_str.format(dst=target, idx=subscript, src=value) - else: # Target has no subscript - if veclen_rhs > veclen_lhs: - code_str = unpack_str + "({}, {});".format(value, target) - else: - if self.defined_vars.get(target)[0] == DefinedType.Pointer: - code_str = "*{} = {};".format(target, value) - else: - code_str = "{} = {};".format(target, value) - updated = ast.Name(id=code_str) - - elif (defined_type == DefinedType.Stream or defined_type == DefinedType.StreamArray): - if memlet.dynamic or memlet.num_accesses != 1: - updated = ast.Name(id="write_channel_intel({}, {});".format(target, value)) - self.used_streams.append(target) - else: - # in this case for an output stream we have - # previously defined an output local var: we use that one - # instead of directly writing to channel - updated = ast.Name(id="{} = {};".format(target, value)) - elif memlet is not None and (not is_scalar or memlet.dynamic): - newnode = ast.Name(id="*{} = {}; ".format(target, value)) - return ast.copy_location(newnode, node) - elif defined_type == DefinedType.Scalar: - code_str = "{} = {};".format(target, value) - updated = ast.Name(id=code_str) - else: - raise RuntimeError("Unhandled case: {}, type {}, veclen {}, " - "memory size {}, {} accesses".format(target, defined_type, veclen_lhs, veclen_lhs, - memlet.num_accesses)) - - return ast.copy_location(updated, node) - - def visit_BinOp(self, node): - if node.op.__class__.__name__ == 'Pow': - - # Special case for integer power: do not generate dace namespaces (dace::math) but just call pow - if not (isinstance(node.right, - (ast.Num, ast.Constant)) and int(node.right.n) == node.right.n and node.right.n >= 0): - - left_value = cppunparse.cppunparse(self.visit(node.left), expr_semicolon=False) - - try: - unparsed = symbolic.pystr_to_symbolic(evalnode(node.right, { - **self.constants, - 'dace': dace, - })) - evaluated = symbolic.symstr(evaluate(unparsed, self.constants), cpp_mode=True) - infered_type = infer_expr_type(evaluated, self.dtypes) - right_value = evaluated - - if infered_type == dtypes.int64 or infered_type == dtypes.int32: - updated = ast.Name(id="pown({},{})".format(left_value, right_value)) - else: - updated = ast.Name(id="pow({},{})".format(left_value, right_value)) - - except (TypeError, AttributeError, NameError, KeyError, ValueError, SyntaxError): - right_value = cppunparse.cppunparse(self.visit(node.right), expr_semicolon=False) - updated = ast.Name(id="pow({},{})".format(left_value, right_value)) - - return ast.copy_location(updated, node) - - return self.generic_visit(node) - - def visit_Name(self, node): - if node.id not in self.memlets: - return self.generic_visit(node) - - memlet, nc, wcr, dtype = self.memlets[node.id] - defined_type, _ = self.defined_vars.get(node.id) - updated = node - - if ((defined_type == DefinedType.Stream or defined_type == DefinedType.StreamArray) and memlet.dynamic): - # Input memlet, we read from channel - # we should not need mangle here, since we are in a tasklet - updated = ast.Call(func=ast.Name(id="read_channel_intel"), args=[ast.Name(id=node.id)], keywords=[]) - self.used_streams.append(node.id) - elif defined_type == DefinedType.Pointer and memlet.dynamic: - # if this has a variable number of access, it has been declared - # as a pointer. We need to deference it - if isinstance(node.id, ast.Subscript): - slice = self.visit(node.id.slice) - if isinstance(slice.value, ast.Tuple): - subscript = unparse(slice)[1:-1] - else: - subscript = unparse(slice) - updated = ast.Name(id="{}[{}]".format(node.id, subscript)) - else: # no subscript - updated = ast.Name(id="*{}".format(node.id)) - - return ast.copy_location(updated, node) - - # Replace default modules (e.g., math) with OpenCL Compliant (e.g. "dace::math::"->"") - def visit_Attribute(self, node): - attrname = rname(node) - module_name = attrname[:attrname.rfind(".")] - func_name = attrname[attrname.rfind(".") + 1:] - if module_name in dtypes._OPENCL_ALLOWED_MODULES: - cppmodname = dtypes._OPENCL_ALLOWED_MODULES[module_name] - return ast.copy_location(ast.Name(id=(cppmodname + func_name), ctx=ast.Load), node) - return self.generic_visit(node) - - def visit_Call(self, node): - # enforce compliance to OpenCL - # Type casting: - if isinstance(node.func, ast.Name): - if node.func.id in self.ctypes: - node.func.id = "({})".format(node.func.id) - elif node.func.id in self.nptypes_to_ctypes: - # if it as numpy type, convert to C type - node.func.id = "({})".format(self.nptypes_to_ctypes[node.func.id]) - elif isinstance(node.func, ast.Attribute): - if node.func.attr in self.ctypes: - node.func.attr = "({})".format(node.func.attr) - elif node.func.attr in self.nptypes_to_ctypes: - # if it as numpy type, convert to C type - node.func.attr = "({})".format(self.nptypes_to_ctypes[node.func.attr]) - elif (isinstance(node.func, (ast.Num, ast.Constant)) - and (node.func.n.to_string() in self.ctypes or node.func.n.to_string() in self.nptypes)): - new_node = ast.Name(id="({})".format(node.func.n), ctx=ast.Load) - new_node = ast.copy_location(new_node, node) - node.func = new_node - - return self.generic_visit(node) diff --git a/dace/codegen/targets/mlir/mlir.py b/dace/codegen/targets/mlir/mlir.py index 57a9924042..e3451c951d 100644 --- a/dace/codegen/targets/mlir/mlir.py +++ b/dace/codegen/targets/mlir/mlir.py @@ -1,8 +1,9 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +import os from typing import TYPE_CHECKING from dace import registry, dtypes from dace.codegen.codeobject import CodeObject -from dace.codegen.targets.target import TargetCodeGenerator +from dace.codegen.target import TargetCodeGenerator from dace.codegen.targets.cpu import CPUCodeGen from dace.sdfg import nodes from dace.sdfg.sdfg import SDFG @@ -40,3 +41,8 @@ def generate_node(self, sdfg, cfg, dfg, state_id, node, function_stream, callsit def cmake_options(): options = [] return options + + @staticmethod + def cmake_files(): + mlir_cmake = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mlir.cmake') + return [mlir_cmake] diff --git a/dace/codegen/targets/mpi.py b/dace/codegen/targets/mpi.py index 41b273e19d..dfff046191 100644 --- a/dace/codegen/targets/mpi.py +++ b/dace/codegen/targets/mpi.py @@ -4,7 +4,7 @@ from dace import registry, symbolic, dtypes from dace.codegen.prettycode import CodeIOStream from dace.codegen.codeobject import CodeObject -from dace.codegen.targets.target import TargetCodeGenerator, make_absolute +from dace.codegen.target import TargetCodeGenerator, make_absolute from dace.codegen.targets.cpp import mangle_dace_state_struct_name from dace.sdfg import nodes, SDFG from dace.config import Config diff --git a/dace/codegen/targets/rtl.py b/dace/codegen/targets/rtl.py deleted file mode 100644 index 699071e17a..0000000000 --- a/dace/codegen/targets/rtl.py +++ /dev/null @@ -1,849 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -import itertools -from typing import List, Dict -import warnings - -from dace import dtypes, config, registry, symbolic, nodes, data, SDFG -from dace.sdfg import graph, find_input_arraynode, find_output_arraynode -from dace.codegen import codeobject, dispatcher, prettycode -from dace.codegen.targets import target, framecode -from dace.codegen.common import sym2cpp -from dace.sdfg.state import ControlFlowRegion, SDFGState, StateSubgraphView - - -@registry.autoregister_params(name='rtl') -class RTLCodeGen(target.TargetCodeGenerator): - """ RTL Code Generator (SystemVerilog) """ - - title = 'RTL' - target_name = 'rtl' - languages = [dtypes.Language.SystemVerilog] - n_unrolled: Dict[str, int] = {} - - def __init__(self, frame_codegen: framecode.DaCeCodeGenerator, sdfg: SDFG): - # store reference to sdfg - self.sdfg = sdfg - # store reference to frame code generator - self.frame = frame_codegen - self._frame = self.frame - # get dispatcher to register callbacks for allocation/nodes/.. code generators - self.dispatcher: dispatcher.TargetDispatcher = frame_codegen.dispatcher - # register node dispatcher -> generate_node(), predicate: process tasklets only - self.dispatcher.register_node_dispatcher( - self, lambda sdfg, state, node: isinstance(node, nodes.Tasklet) and node.language == dtypes.Language. - SystemVerilog) - # register all storage types that connect from/to an RTL tasklet - for src_storage, dst_storage in itertools.product(dtypes.StorageType, dtypes.StorageType): - self.dispatcher.register_copy_dispatcher( - src_storage, dst_storage, None, self, lambda sdfg, dfg, src_node, dest_node: - (isinstance(src_node, nodes.Tasklet) and src_node.language == dtypes.Language.SystemVerilog) or - (isinstance(dest_node, nodes.Tasklet) and dest_node.language == dtypes.Language.SystemVerilog)) - # local variables - self.verilator_debug: bool = config.Config.get_bool("compiler", "rtl", "verilator_enable_debug") - self.code_objects: List[codeobject.CodeObject] = list() - self.cpp_general_header_added: bool = False - self.vendor: str = config.Config.get("compiler", "fpga", "vendor") - self.hardware_target: bool = config.Config.get("compiler", "xilinx", "mode").startswith("hardware") - self.frequencies: str = config.Config.get("compiler", "xilinx", "frequency") - - def generate_node(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, node: nodes.Node, - function_stream: prettycode.CodeIOStream, callsite_stream: prettycode.CodeIOStream) -> None: - # check instance type - if isinstance(node, nodes.Tasklet): - """ - handle Tasklet: - (1) generate in->tasklet - (2) generate tasklet->out - (3) generate tasklet - """ - callsite_stream.write('{', cfg, state_id, dfg.node_id(node)) - # generate code to handle data input to the tasklet - for edge in dfg.in_edges(node): - # find input array - src_node = find_input_arraynode(dfg, edge) - # dispatch code gen (copy_memory) - self.dispatcher.dispatch_copy(src_node, node, edge, sdfg, cfg, dfg, state_id, function_stream, - callsite_stream) - # generate code to handle data output from the tasklet - for edge in dfg.out_edges(node): - # find output array - dst_node = find_output_arraynode(dfg, edge) - # dispatch code gen (define_out_memlet) - self.dispatcher.dispatch_output_definition(node, dst_node, edge, sdfg, cfg, dfg, state_id, - function_stream, callsite_stream) - # generate tasklet code - self.unparse_tasklet(sdfg, cfg, dfg, state_id, node, function_stream, callsite_stream) - callsite_stream.write('}', cfg, state_id, dfg.node_id(node)) - else: - raise RuntimeError( - "Only tasklets are handled here, not {}. This should have been filtered by the predicate".format( - type(node))) - - def copy_memory(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, - src_node: nodes.Node, dst_node: nodes.Node, edge: graph.MultiConnectorEdge, - function_stream: prettycode.CodeIOStream, callsite_stream: prettycode.CodeIOStream) -> None: - """ - Generate input/output memory copies from the array references to local variables (i.e. for the tasklet code). - """ - if isinstance(edge.src, nodes.AccessNode) and isinstance(edge.dst, nodes.Tasklet): # handle AccessNode->Tasklet - if isinstance(dst_node.in_connectors[edge.dst_conn], dtypes.pointer): # pointer accessor - line: str = "{} {} = &{}[0];".format(dst_node.in_connectors[edge.dst_conn].ctype, edge.dst_conn, - edge.src.data) - elif isinstance(dst_node.in_connectors[edge.dst_conn], dtypes.vector): # vector accessor - line: str = "{} {} = *({} *)(&{}[0]);".format(dst_node.in_connectors[edge.dst_conn].ctype, - edge.dst_conn, - dst_node.in_connectors[edge.dst_conn].ctype, - edge.src.data) - else: # scalar accessor - arr = sdfg.arrays[edge.data.data] - if isinstance(arr, data.Array): - line: str = "{}* {} = &{}[0];".format(dst_node.in_connectors[edge.dst_conn].ctype, edge.dst_conn, - edge.src.data) - elif isinstance(arr, data.Scalar): - line: str = "{} {} = {};".format(dst_node.in_connectors[edge.dst_conn].ctype, edge.dst_conn, - edge.src.data) - elif isinstance(arr, data.Stream): - # TODO Streams are currently unsupported, as the proper - # behaviour has to be implemented to avoid deadlocking. It - # is only a warning, as the RTL backend is partially used - # by the Xilinx backend, which may hit this case, but will - # discard the errorneous code. - warnings.warn( - 'Streams are currently unsupported by the RTL backend.' \ - 'This may produce errors or deadlocks in the generated code.' - ) - line: str = "// WARNING: Unsupported read from ({}) variable '{}' from stream '{}'." \ - " This may lead to a deadlock if used in code.\n".format( - dst_node.in_connectors[edge.dst_conn].ctype, edge.dst_conn, edge.src_conn) - line += "{} {} = {}.pop();".format(dst_node.in_connectors[edge.dst_conn].ctype, edge.dst_conn, - edge.src.data) - elif isinstance(edge.src, nodes.MapEntry) and isinstance(edge.dst, nodes.Tasklet): - rtl_name = self.unique_name(edge.dst, cfg.state(state_id)) - self.n_unrolled[rtl_name] = symbolic.evaluate(edge.src.map.range[0][1] + 1, sdfg.constants) - line: str = f'{dst_node.in_connectors[edge.dst_conn]} {edge.dst_conn} = &{edge.data.data}[{edge.src.map.params[0]}*{edge.data.volume}];' - else: - raise RuntimeError("Not handling copy_memory case of type {} -> {}.".format(type(edge.src), type(edge.dst))) - # write accessor to file - callsite_stream.write(line) - - def define_out_memlet(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, - src_node: nodes.Node, dst_node: nodes.Node, edge: graph.MultiConnectorEdge, - function_stream: prettycode.CodeIOStream, callsite_stream: prettycode.CodeIOStream): - """ - Generate output copy code (handled within the rtl tasklet code). - """ - if isinstance(edge.src, nodes.Tasklet) and isinstance(edge.dst, nodes.AccessNode): - if isinstance(src_node.out_connectors[edge.src_conn], dtypes.pointer): # pointer accessor - line: str = "{} {} = &{}[0];".format(src_node.out_connectors[edge.src_conn].ctype, edge.src_conn, - edge.dst.data) - elif isinstance(src_node.out_connectors[edge.src_conn], dtypes.vector): # vector accessor - line: str = "{}* {} = ({} *)(&{}[0]);".format(src_node.out_connectors[edge.src_conn].ctype, - edge.src_conn, - src_node.out_connectors[edge.src_conn].ctype, - edge.dst.data) - else: # scalar accessor - line: str = "{}* {} = &{}[0];".format(src_node.out_connectors[edge.src_conn].ctype, edge.src_conn, - edge.dst.data) - elif isinstance(edge.src, nodes.Tasklet) and isinstance(edge.dst, nodes.MapExit): - line: str = f'{src_node.out_connectors[edge.src_conn].ctype} {edge.src_conn} = &{edge.data.data}[{edge.dst.map.params[0]}*{edge.data.volume}];' - else: - raise RuntimeError("Not handling define_out_memlet case of type {} -> {}.".format( - type(edge.src), type(edge.dst))) - # write accessor to file - callsite_stream.write(line) - - def get_generated_codeobjects(self): - """ - Return list of code objects (that are later generating code files). - """ - return self.code_objects - - @property - def has_initializer(self): - """ - Disable initializer method generation. - """ - return False - - @property - def has_finalizer(self): - """ - Disable exit/finalizer method generation. - """ - return False - - @staticmethod - def cmake_options(): - """ - Process variables to be exposed to the CMakeList.txt script. - """ - # get flags from config - verbose = config.Config.get_bool("compiler", "rtl", "verbose") - verilator_flags = config.Config.get("compiler", "rtl", "verilator_flags") - verilator_lint_warnings = config.Config.get_bool("compiler", "rtl", "verilator_lint_warnings") - mode: str = config.Config.get("compiler", "rtl", "mode") - - # create options list - options = [ - "-DDACE_RTL_VERBOSE=\"{}\"".format(verbose), "-DDACE_RTL_VERILATOR_FLAGS=\"{}\"".format(verilator_flags), - "-DDACE_RTL_VERILATOR_LINT_WARNINGS=\"{}\"".format(verilator_lint_warnings), - "-DDACE_RTL_MODE=\"{}\"".format(mode) - ] - return options - - def generate_rtl_parameters(self, constants): - """ - Construct parameters module header - """ - if len(constants) == 0: - return str() - else: - return "#(\n{}\n)".format(" " + "\n".join([ - "{} parameter {} = {}".format("," if i > 0 else "", key, sym2cpp(constants[key])) - for i, key in enumerate(constants) - ])) - - def generate_padded_axis(self, is_output, name, total_size, veclen): - """ - Generates a padded list of strings for pretty printing streaming - AXI port definitions. E.g. for a streaming input port named "a", the - output would be: - - , input s_axis_a_tvalid - , input [31:0] s_axis_a_tdata - , output reg s_axis_a_tready - , input [3:0] s_axis_a_tkeep - , input s_axis_a_tlast - """ - vec_str = '' if veclen <= 1 else f'[{veclen-1}:0]' - bits_str = f'[{(total_size // veclen) * 8 - 1}:0]' - bytes_str = f'[{total_size - 1}:0]' - dir_str = 'output reg' if is_output else 'input ' - ndir_str = 'input ' if is_output else 'output reg' - prefix = f'm_axis_{name}' if is_output else f's_axis_{name}' - padding = ' ' * (len(bits_str) + len(vec_str)) - bytes_padding = ' ' * ((len(bits_str) - len(bytes_str)) + len(vec_str)) - return [ - f', {dir_str} {padding} {prefix}_tvalid', - f', {dir_str} {vec_str}{bits_str} {prefix}_tdata', - f', {ndir_str} {padding} {prefix}_tready', - f', {dir_str} {bytes_padding}{bytes_str} {prefix}_tkeep', - f', {dir_str} {padding} {prefix}_tlast', - ] - - def generate_rtl_inputs_outputs(self, buses, scalars): - """ - Generates all of the input and output ports for the tasklet - """ - inputs = [] - outputs = [] - if scalars: - inputs += [', input scalars_valid'] - for scalar, (is_output, total_size) in scalars.items(): - inputs += [f', {"output" if is_output else "input"} [{(total_size*8)-1}:0] {scalar}'] - - for bus, (_, is_output, total_size, vec_len, volume) in buses.items(): - if is_output: - inputs += self.generate_padded_axis(True, bus, total_size, vec_len) - else: - outputs += self.generate_padded_axis(False, bus, total_size, vec_len) - - return inputs, outputs - - def generate_clk_from_cfg(self): - """ - Generate the clock handling initialization expressions. - """ - if self.frequencies == '': # Default case: no frequency specified, set to 300. - freqs = ['0:300'] - elif ':' not in self.frequencies: # Case of a single number without id. - freqs = [f'0:{self.frequencies}'] - else: # Multiple clocks specified in the format "0:freq_0\|1:freq_1" - freqs = self.frequencies.strip('"').split('\\|') - - prm_clk_format = \ - '''input ap_aclk // convention: ap_aclk clocks the design, specifically the external ports -, input ap_areset // convention: ap_areset resets the design''' - scd_clk_format = \ - ''', input ap_aclk_{id} // convention: ap_aclk_{id} is a secondary clock, which can be used inside -, input ap_areset_{id} // convention: ap_areset_{id} resets the components clocked by ap_aclk_{id}''' - - nclks = len(freqs) - ports = [prm_clk_format] + [scd_clk_format.format(id=i + 2) for i in range(nclks - 1)] - clks = ['&(model->ap_aclk)'] + [f'&(model->ap_aclk_{i+2})' for i in range(nclks - 1)] - freqs = f'{{ {", ".join([freq.split(":")[1] for freq in freqs])} }}' - nclks = str(nclks) - clks = f'{{ {", ".join(clks)} }}' - ports = '\n'.join(ports) - return nclks, freqs, clks, ports - - def generate_cpp_zero_inits(self, buses, scalars): - """ - Generate zero initialization statements - """ - valids = [] - readys = [] - for name, (arr, is_output, bytes, veclen, volume) in buses.items(): - if is_output: - readys.append(f'model->m_axis_{name}_tready = 0;') - else: - valids.append(f'model->s_axis_{name}_tvalid = 0;') - - scals = [f'model->{name} = {name};' for name, _ in scalars.items()] - return valids, readys, scals - - def generate_cpp_inputs_outputs(self, tasklet, buses): - - # generate cpp input reading/output writing code - """ - input: - for arrays: - model->a = a[in_ptr_a++]; - for vectors: - tmp = a[in_ptr_a++]; - for (int i = 0; i < WIDTH; i++) {{ - model->a[i] = tmp[i]; - }} - for scalars: - model->a = a[0]; - - output: - for arrays: - b[out_ptr_b++] = (int)model->b - for vectors: - for(int i = 0; i < WIDTH; i++) {{ - tmp[i] = (int)model->b[i]; - }} - b[out_ptr_b++] = tmp; - for scalars: - b[0] = (int)model->b; - """ - inputs = {} - outputs = {} - - for name, (arr, is_output, bytes, veclen, volume) in buses.items(): - if is_output: - conn = tasklet.out_connectors[name] - if isinstance(conn, dtypes.vector): - outputs[name] = f'''out_ptr_{name}++; -for (int j = 0; j < {veclen}; j++) {{ - {name}[0][j] = (int)(model->m_axis_{name}_tdata[j]); -}}''' - elif isinstance(conn, dtypes.pointer): - if isinstance(conn.base_type, dtypes.vector): - outputs[name] = f'''int idx = out_ptr_{name}++; -for (int j = 0; j < {veclen}; j++) {{ - {name}[idx][j] = (int)(model->m_axis_{name}_tdata[j]); -}}''' - else: - outputs[name] = f'{name}[out_ptr_{name}++] = (int)(model->m_axis_{name}_tdata);' - else: - outputs[name] = f'{name}[out_ptr_{name}++] = (int)(model->m_axis_{name}_tdata);' - - else: # input - conn = tasklet.in_connectors[name] - if isinstance(conn, dtypes.vector): - inputs[name] = f'''in_ptr_{name}++; -for (int j = 0; j < {veclen}; j++) {{ - model->s_axis_{name}_tdata[j] = {name}[j]; -}}''' - elif isinstance(conn, dtypes.pointer): - if isinstance(conn.base_type, dtypes.vector): - inputs[name] = f'''int idx = in_ptr_{name}++; -for (int j = 0; j < {veclen}; j++) {{ - model->s_axis_{name}_tdata[j] = (int){name}[idx][j]; -}}''' - else: - inputs[name] = f'model->s_axis_{name}_tdata = {name}[in_ptr_{name}++];' - else: - inputs[name] = f'''in_ptr_{name}++; -model->s_axis_{name}_tdata = {name}[0];''' - - return inputs, outputs - - def generate_cpp_vector_init(self, tasklet): - inits = [] - for name in tasklet.in_connectors: - conn = tasklet.in_connectors[name] - if isinstance(conn, dtypes.pointer): - conn = conn.base_type - if isinstance(conn, dtypes.vector): - inits.append(f'''for (int j = 0; j < {conn.veclen}; j++) {{ - model->s_axis_{name}_tdata[j] = 0; -}}''') - - return "\n".join(inits) - - def generate_cpp_num_elements(self, buses): - # TODO: compute num_elements=#elements that enter/leave the pipeline, for now we assume in_elem=out_elem (i.e. no reduction) - return [ - f'''int num_elements_{name} = {volume};''' - for name, (arr, is_output, bytes, veclen, volume) in buses.items() - ] - - def generate_cpp_internal_state(self, buses): - internal_state_strs = [] - internal_state_vars = [] - for name, (_, is_output, _, veclen, _) in buses.items(): - prefix = 'm' if is_output else 's' - data_format = '0x%x' if veclen == 1 else f'[{" ".join(["0x%x"]*veclen)}]' - internal_state_strs.append(f'{name}={data_format} ready=%u valid=%u') - data_vars = f'model->{prefix}_axis_{name}_tdata' if veclen == 1 else ', '.join( - [f'model->{prefix}_axis_{name}_tdata[{i}]' for i in range(veclen)]) - internal_state_vars.append( - f'{data_vars}, model->{prefix}_axis_{name}_tvalid, model->{prefix}_axis_{name}_tready') - internal_state_str = " | ".join(internal_state_strs) - internal_state_var = ", ".join(internal_state_vars) - return internal_state_str, internal_state_var - - def generate_input_hs(self, buses): - """ - Generate checking whether input to the tasklet has been consumed - """ - return [ - f'''if (model->s_axis_{name}_tready == 1 && model->s_axis_{name}_tvalid == 1) {{ - read_input_hs_{name} = true; - }}''' for name, (arr, is_output, bytes, veclen, volume) in buses.items() if not is_output - ] - - def generate_feeding(self, tasklet, inputs): - """ - Generate statements for feeding into a streaming AXI bus - """ - debug_feed_element = "std::cout << \"feed new element\" << std::endl;\n" if self.verilator_debug else "" - return [ - f'''if (model->s_axis_{name}_tvalid == 0 && in_ptr_{name} < num_elements_{name}) {{ - {debug_feed_element}{inputs[name]} - model->s_axis_{name}_tvalid = 1; - }}''' for name in inputs - ] - - def generate_ptrs(self, tasklet): - """ - Generate pointers for the transaction counters - """ - ins = [f'''int in_ptr_{name} = 0;''' for name in tasklet.in_connectors] - outs = [f'''int out_ptr_{name} = 0;''' for name in tasklet.out_connectors] - return ins, outs - - def generate_exporting(self, tasklet, outputs): - """ - Generate statements for whether an element output by the tasklet is ready. - """ - debug_export_element = "std::cout << \"export element\" << std::endl;\n" if self.verilator_debug else "" - return [ - f'''if (model->m_axis_{name}_tvalid == 1) {{ - {debug_export_element}{outputs[name]} - model->m_axis_{name}_tready = 1; - }}''' for name in tasklet.out_connectors - ] - - def generate_write_output_hs(self, tasklet): - """ - Generate check for whether an element has been consumed from the output of a tasklet. - """ - return [ - f'''if (model->m_axis_{name}_tready && model->m_axis_{name}_tvalid == 1) {{ - write_output_hs_{name} = true; - }}''' for name in tasklet.out_connectors - ] - - def generate_hs_flags(self, buses): - """ - Generate flags - """ - return [ - f'bool {"write_out" if is_output else "read_in"}put_hs_{name} = false;' - for name, (arr, is_output, bytes, veclen, volume) in buses.items() - ] - - def generate_input_hs_toggle(self, buses): - """ - Generate statements for toggling input flags. - """ - debug_read_input_hs = "\nstd::cout << \"remove read_input_hs flag\" << std::endl;" if self.verilator_debug else "" - return [ - f'''if (read_input_hs_{name}) {{ - // remove valid flag {debug_read_input_hs} - model->s_axis_{name}_tvalid = 0; - read_input_hs_{name} = false; - }}''' for name, (arr, is_output, bytes, veclen, volume) in buses.items() if not is_output - ] - - def generate_output_hs_toggle(self, buses): - """ - Generate statements for toggling output flags. - """ - debug_write_output_hs = "\nstd::cout << \"remove write_output_hs flag\" << std::endl;" if self.verilator_debug else "" - return [ - f'''if (write_output_hs_{name}) {{ - // remove ready flag {debug_write_output_hs} - model->m_axis_{name}_tready = 0; - write_output_hs_{name} = false; - }}''' for name, (arr, is_output, bytes, veclen, volume) in buses.items() if is_output - ] - - def generate_running_condition(self, tasklet): - """ - Generate the condition for whether the simulation should be running. - """ - # TODO should be changed with free-running kernels. Currently only - # one element is supported. Additionally, this should not be used as - # condition, as the amount of input and output elements might not be - # equal to each other. - evals = ' && '.join([f'out_ptr_{name} < num_elements_{name}' for name in tasklet.out_connectors]) - return evals - - def unique_name(self, node: nodes.RTLTasklet, state: SDFGState): - return "{}_{}_{}_{}".format(node.name, state.parent_graph.cfg_id, state.block_id, state.node_id(node)) - - def unparse_tasklet(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, - node: nodes.Node, function_stream: prettycode.CodeIOStream, - callsite_stream: prettycode.CodeIOStream): - - # extract data - state = cfg.state(state_id) - tasklet = node - - # construct variables paths - unique_name: str = self.unique_name(tasklet, state) - - # Collect all of the input and output connectors into buses and scalars - buses = {} # {tasklet_name: (array_name, output_from_rtl, bytes, veclen)} - scalars = {} - for edge in state.in_edges(tasklet): - arr = sdfg.arrays[edge.data.data] - conn = tasklet.in_connectors[edge.dst_conn] - if isinstance(conn, dtypes.pointer): - conn = conn.base_type - - # catch symbolic (compile time variables) - check_issymbolic([conn.veclen, conn.bytes], sdfg) - - # extract parameters - vec_len = int(symbolic.evaluate(conn.veclen, sdfg.constants)) - total_size = int(symbolic.evaluate(conn.bytes, sdfg.constants)) - if isinstance(arr, data.Array): - if self.hardware_target: - raise NotImplementedError('Array input for RTL hardware* targets not implemented') - else: - buses[edge.dst_conn] = (edge.data.data, False, total_size, vec_len, edge.data.volume) - elif isinstance(arr, data.Stream): - buses[edge.dst_conn] = (edge.data.data, False, total_size, vec_len, edge.data.volume) - elif isinstance(arr, data.Scalar): - scalars[edge.dst_conn] = (False, total_size) - - for edge in state.out_edges(tasklet): - arr = sdfg.arrays[edge.data.data] - conn = tasklet.out_connectors[edge.src_conn] - if isinstance(conn, dtypes.pointer): - conn = conn.base_type - - # catch symbolic (compile time variables) - check_issymbolic([conn.veclen, conn.bytes], sdfg) - - # extract parameters - vec_len = int(symbolic.evaluate(conn.veclen, sdfg.constants)) - total_size = int(symbolic.evaluate(conn.bytes, sdfg.constants)) - if isinstance(arr, data.Array): - if self.hardware_target: - raise NotImplementedError('Array input for RTL hardware* targets not implemented') - else: - buses[edge.src_conn] = (edge.data.data, True, total_size, vec_len, edge.data.volume) - elif isinstance(arr, data.Stream): - buses[edge.src_conn] = (edge.data.data, True, total_size, vec_len, edge.data.volume) - elif isinstance(arr, data.Scalar): - raise NotImplementedError('Scalar output from RTL kernels not implemented') - - # generate system verilog module components - parameter_string: str = self.generate_rtl_parameters(sdfg.constants) - inputs, outputs = self.generate_rtl_inputs_outputs(buses, scalars) - nclks, freqs, clks, ports = self.generate_clk_from_cfg() - - # create rtl code object (that is later written to file) - self.code_objects.append( - codeobject.CodeObject(name="{}".format(unique_name), - code=RTLCodeGen.RTL_HEADER.format(name=unique_name, - parameters=parameter_string, - inputs="\n".join(inputs), - outputs="\n".join(outputs), - clk_rst_ports=ports) + tasklet.code.code + - RTLCodeGen.RTL_FOOTER, - language="sv", - target=RTLCodeGen, - title="rtl", - target_type="{}".format(unique_name), - additional_compiler_kwargs="", - linkable=True, - environments=None)) - - if self.hardware_target: - if self.vendor == 'xilinx': - # Avoid importing submodule if not necessary - from dace.external.rtllib.templates.control import generate_from_config as rtllib_control - from dace.external.rtllib.templates.package import generate_from_config as rtllib_package - from dace.external.rtllib.templates.synth import generate_from_config as rtllib_synth - from dace.external.rtllib.templates.top import generate_from_config as rtllib_top - - rtllib_config = { - "name": unique_name, - "buses": { - name: ('m_axis' if is_output else 's_axis', vec_len) - for name, (_, is_output, _, vec_len, _) in buses.items() - }, - "params": { - "scalars": { - name: total_size * 8 # width in bits - for name, (_, total_size) in scalars.items() - }, - "memory": {} - }, - "unroll": self.n_unrolled[unique_name] if unique_name in self.n_unrolled else 1, - "ip_cores": tasklet.ip_cores if isinstance(tasklet, nodes.RTLTasklet) else {}, - "clocks": int(nclks) - } - - self.code_objects.append( - codeobject.CodeObject(name=f"{unique_name}_control", - code=rtllib_control(rtllib_config), - language="v", - target=RTLCodeGen, - title="rtl", - target_type="{}".format(unique_name), - additional_compiler_kwargs="", - linkable=True, - environments=None)) - - self.code_objects.append( - codeobject.CodeObject(name=f"{unique_name}_top", - code=rtllib_top(rtllib_config), - language="v", - target=RTLCodeGen, - title="rtl", - target_type="{}".format(unique_name), - additional_compiler_kwargs="", - linkable=True, - environments=None)) - - self.code_objects.append( - codeobject.CodeObject(name=f"{unique_name}_package", - code=rtllib_package(rtllib_config), - language="tcl", - target=RTLCodeGen, - title="rtl", - target_type="scripts", - additional_compiler_kwargs="", - linkable=True, - environments=None)) - - self.code_objects.append( - codeobject.CodeObject(name=f"{unique_name}_synth", - code=rtllib_synth(rtllib_config), - language="tcl", - target=RTLCodeGen, - title="rtl", - target_type="scripts", - additional_compiler_kwargs="", - linkable=True, - environments=None)) - else: # self.vendor != "xilinx" - raise NotImplementedError('Only RTL codegen for Xilinx is implemented') - else: # not hardware_target - # generate verilator simulation cpp code components - inputs, outputs = self.generate_cpp_inputs_outputs(tasklet, buses) - valid_zeros, ready_zeros, scalar_zeros = self.generate_cpp_zero_inits(buses, scalars) - vector_init = self.generate_cpp_vector_init(tasklet) - num_elements = self.generate_cpp_num_elements(buses) - internal_state_str, internal_state_var = self.generate_cpp_internal_state(buses) - read_input_hs = self.generate_input_hs(buses) - feed_elements = self.generate_feeding(tasklet, inputs) - in_ptrs, out_ptrs = self.generate_ptrs(tasklet) - export_elements = self.generate_exporting(tasklet, outputs) - write_output_hs = self.generate_write_output_hs(tasklet) - hs_flags = self.generate_hs_flags(buses) - input_hs_toggle = self.generate_input_hs_toggle(buses) - output_hs_toggle = self.generate_output_hs_toggle(buses) - running_condition = self.generate_running_condition(tasklet) - - # add header code to stream - if not self.cpp_general_header_added: - sdfg.append_global_code(cpp_code=RTLCodeGen.CPP_GENERAL_HEADER_TEMPLATE.format( - debug_include="// generic includes\n#include " if self.verilator_debug else "")) - self.cpp_general_header_added = True - sdfg.append_global_code(cpp_code=RTLCodeGen.CPP_MODEL_HEADER_TEMPLATE.format(name=unique_name)) - - # add main cpp code to stream - callsite_stream.write(contents=RTLCodeGen.CPP_MAIN_TEMPLATE.format( - name=unique_name, - inputs=inputs, - outputs=outputs, - num_elements=str.join('\n', num_elements), - vector_init=vector_init, - valid_zeros=str.join('\n', valid_zeros), - ready_zeros=str.join('\n', ready_zeros), - scalar_zeros=str.join('\n', scalar_zeros), - read_input_hs=str.join('\n', read_input_hs), - feed_elements=str.join('\n', feed_elements), - in_ptrs=str.join('\n', in_ptrs), - out_ptrs=str.join('\n', out_ptrs), - export_elements=str.join('\n', export_elements), - write_output_hs=str.join('\n', write_output_hs), - hs_flags=str.join('\n', hs_flags), - input_hs_toggle=str.join('\n', input_hs_toggle), - output_hs_toggle=str.join('\n', output_hs_toggle), - nclks=nclks, - freqs=freqs, - clks=clks, - running_condition=running_condition, - internal_state_str=internal_state_str, - internal_state_var=internal_state_var, - debug_sim_start="std::cout << \"SIM {name} START\" << std::endl;" if self.verilator_debug else "", - debug_internal_state=f''' -// report internal state -VL_PRINTF("[t=%lu] ap_aclk=%u ap_areset=%u\\n", main_time, model->ap_aclk, model->ap_areset); -VL_PRINTF("{internal_state_str}\\n", {internal_state_var}); -std::cout << std::flush; -''' if self.verilator_debug else '', - debug_sim_end="\nstd::cout << \"SIM {name} END\" << std::endl;" if self.verilator_debug else "", - ), - cfg=cfg, - state_id=state_id, - node_id=node) - - CPP_GENERAL_HEADER_TEMPLATE = """\ -{debug_include} -// verilator includes -#include -""" - - CPP_MODEL_HEADER_TEMPLATE = """\ -// include model header, generated from verilating the sv design -#include "V{name}.h" -""" - - CPP_MAIN_TEMPLATE = """\ -{debug_sim_start} - -vluint64_t main_time = 0; - -// instantiate model(s) -V{name}* model = new V{name}; - -// instantiate clock handling -int nclks = {nclks}; -int freqs[nclks] = {freqs}; -CData* clks[nclks] = {clks}; -double periods[nclks]; -for (int i = 0; i < nclks; i++) {{ - periods[i] = (1000.0 / freqs[i]); -}} -double ttf[nclks]; // time to flip -auto tick = [&]() {{ - // Lambda function for driving all of the clock signals of the model, until the first clock signal have had a full in-between-rising-edge cycle - bool first_rised = *(clks[0]); - while (true) {{ - int next = 0; - for (int i = 1; i < nclks; i++) {{ - if (ttf[i] < ttf[next]) {{ - next = i; - }} - }} - double time = ttf[next]; - for (int i = 0; i < nclks; i++) {{ - if (i == next) {{ - ttf[i] = periods[i] / 2; - *(clks[i]) = !*(clks[i]); - }} else {{ - ttf[i] -= time; - }} - }} - main_time += time; - model->eval(); - if (next == 0 && *(clks[0]) == 1) {{ - if (first_rised) - break; - else - first_rised = true; - }} - }} -}}; - -// apply initial input values -model->ap_areset = 0; // no reset -{valid_zeros} -{ready_zeros} -{scalar_zeros} - -model->eval(); - -// Initialize vectors -{vector_init} - -// reset design -model->ap_areset = 1; -tick(); -model->ap_areset = 0; -tick(); - -// simulate until in_handshakes = out_handshakes = num_elements -{hs_flags} -{in_ptrs} -{out_ptrs} -{num_elements} - -while ({running_condition}) {{ - // increment time - main_time++; - - // feed elements -{feed_elements} - // export elements -{export_elements} - - // check if valid and ready have been asserted at the rising clock edge -> input read handshake -{read_input_hs} - // check if valid and ready have been asserted at the rising clock edge -> output write handshake -{write_output_hs} - - tick(); {debug_internal_state} - - // check if valid and ready has been asserted for each input at the rising clock edge -{input_hs_toggle} - - // check if valid and ready haæ been asserted for each output at the rising clock edge -{output_hs_toggle} -}} {debug_internal_state} - -// final model cleanup -model->final(); - -// clean up resources -delete model; -model = NULL; -{debug_sim_end}""" - - RTL_HEADER = """\ -module {name} -{parameters} -( {clk_rst_ports} -, input ap_start // convention: ap_start indicates a start from host -, output ap_done // convention: ap_done tells the host that the kernel has finished - -{inputs} - -{outputs} -); -""" - - RTL_FOOTER = ''' -endmodule -''' - - -def check_issymbolic(iterator: iter, sdfg): - for item in iterator: - # catch symbolic (compile time variables) - if symbolic.issymbolic(item, sdfg.constants): - raise ValueError("Please use sdfg.specialize to make the following symbol(s) constant: {}".format(", ".join( - [str(x) for x in item.free_symbols if str(x) not in sdfg.constants]))) diff --git a/dace/codegen/targets/snitch.py b/dace/codegen/targets/snitch.py index 6d080a2a14..bb7b3a93c3 100644 --- a/dace/codegen/targets/snitch.py +++ b/dace/codegen/targets/snitch.py @@ -1,6 +1,6 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. -from typing import Union +from typing import Optional, Union import dace import itertools import numpy as np @@ -12,13 +12,13 @@ from dace.sdfg.state import ControlFlowRegion, SDFGState, StateSubgraphView from dace.transformation.dataflow.streaming_memory import _collect_map_ranges -from dace import registry, data, dtypes, config, symbolic +from dace import registry, data, dtypes, config, symbolic, subsets from dace.sdfg import nodes, utils as sdutils from dace.sdfg.scope import ScopeSubgraphView from dace.codegen.prettycode import CodeIOStream from dace.codegen.targets import cpp from dace.codegen.common import update_persistent_desc -from dace.codegen.targets.target import TargetCodeGenerator +from dace.codegen.target import TargetCodeGenerator from dace.codegen.targets.framecode import DaCeCodeGenerator from dace.codegen.targets.cpp import sym2cpp from dace.codegen.dispatcher import DefinedType @@ -43,8 +43,6 @@ def __init__(self, frame_codegen: DaCeCodeGenerator, sdfg: dace.SDFG): self.frame = frame_codegen # Can be used to dispatch other code generators for allocation/nodes self.dispatcher = frame_codegen.dispatcher - # ??? - self.packed_types = False # Mapping of ssr to ssr_config self.ssrs = MAX_SSR_STREAMERS * [None] @@ -76,6 +74,14 @@ def __init__(self, frame_codegen: DaCeCodeGenerator, sdfg: dace.SDFG): self.dispatcher.register_array_dispatcher(dace.StorageType.Snitch_TCDM, self) self.dispatcher.register_array_dispatcher(dace.StorageType.Snitch_SSR, self) + def get_framecode_generator(self) -> 'DaCeCodeGenerator': + """ + Returns the frame-code generator associated with this target. + + :return: The frame-code generator. + """ + return self.frame + def state_dispatch_predicate(self, sdfg, state): for node in state.nodes(): if isinstance(node, nodes.AccessNode): @@ -107,7 +113,7 @@ def try_simplify(expr): continue dbg(f'emitting ssr config for ssr {ssr}') node = ssr["data"] - alloc_name = cpp.ptr(node.data, node.desc(sdfg)) + alloc_name = self.ptr(node.data, node.desc(sdfg), sdfg) # emit bound/stride setup stride_off = '0' for dim_num, dim in enumerate(ssr["dims"]): @@ -294,7 +300,7 @@ def define_out_memlet(self, sdfg: SDFG, cfg: ControlFlowRegion, state_dfg: State elif isinstance(cdtype, dtypes.pointer): # If pointer, also point to output defined_type, _ = self.dispatcher.defined_vars.get(edge.data.data) - base_ptr = cpp.cpp_ptr_expr(sdfg, edge.data, defined_type) + base_ptr = cpp.cpp_ptr_expr(sdfg, edge.data, defined_type, codegen=self) callsite_stream.write(f'{cdtype.ctype} {edge.src_conn} = {base_ptr};', cfg, state_id, src_node) else: callsite_stream.write(f'{cdtype.ctype} {edge.src_conn};', cfg, state_id, src_node) @@ -319,19 +325,14 @@ def memlet_definition(self, sdfg, memlet, output, local_name, conntype=None, all memlet_type = conntype.dtype.ctype desc = sdfg.arrays[memlet.data] - ptr = cpp.ptr(memlet.data, desc) + ptr = self.ptr(memlet.data, desc, sdfg) var_type, ctypedef = self.dispatcher.defined_vars.get(memlet.data) result = '' - expr = (cpp.cpp_array_expr(sdfg, memlet, with_brackets=False) - if var_type in [DefinedType.Pointer, DefinedType.StreamArray, DefinedType.ArrayInterface] else ptr) + expr = (cpp.cpp_array_expr(sdfg, memlet, with_brackets=False, codegen=self) + if var_type in [DefinedType.Pointer, DefinedType.StreamArray] else ptr) - # Special case: ArrayInterface, append _in or _out _ptr = ptr - if var_type == DefinedType.ArrayInterface: - # Views have already been renamed - if not isinstance(desc, data.View): - ptr = cpp.array_interface_variable(ptr, output, self.dispatcher) if expr != _ptr: expr = '%s[%s]' % (ptr, expr) # If there is a type mismatch, cast pointer @@ -339,11 +340,9 @@ def memlet_definition(self, sdfg, memlet, output, local_name, conntype=None, all defined = None - if var_type in [DefinedType.Scalar, DefinedType.Pointer, DefinedType.ArrayInterface]: + if var_type in [DefinedType.Scalar, DefinedType.Pointer]: if output: - if is_pointer and var_type == DefinedType.ArrayInterface: - result += "{} {} = {};".format(memlet_type, local_name, expr) - elif not memlet.dynamic or (memlet.dynamic and memlet.wcr is not None): + if not memlet.dynamic or (memlet.dynamic and memlet.wcr is not None): # Dynamic WCR memlets start uninitialized result += "{} {};".format(memlet_type, local_name) defined = DefinedType.Scalar @@ -397,7 +396,7 @@ def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphV # Compute array size arrsize = nodedesc.total_size arrsize_bytes = arrsize * nodedesc.dtype.bytes - alloc_name = cpp.ptr(name, nodedesc) + alloc_name = self.ptr(name, nodedesc, sdfg) dbg(' arrsize "{}" arrsize_bytes "{}" alloc_name "{}" nodedesc "{}"'.format( arrsize, arrsize_bytes, alloc_name, nodedesc)) @@ -466,7 +465,7 @@ def deallocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgrap node: nodes.AccessNode, nodedesc: data.Data, function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: arrsize = nodedesc.total_size - alloc_name = cpp.ptr(node.data, nodedesc) + alloc_name = self.ptr(node.data, nodedesc, sdfg) dbg(f'-- deallocate_array storate="{nodedesc.storage}" arrsize="{arrsize}" alloc_name="{alloc_name}"') if isinstance(nodedesc, data.Scalar): @@ -613,8 +612,7 @@ def copy_memory( copy_shape, src_strides, dst_strides, src_expr, dst_expr = \ cpp.memlet_copy_to_absolute_strides( - self.dispatcher, sdfg, state_dfg, edge, src_node, dst_node, - self.packed_types) + self.dispatcher, sdfg, state_dfg, edge, src_node, dst_node, codegen=self) dbg(f' copy_shape = "{copy_shape}", src_strides = "{src_strides}", dst_strides = "{dst_strides}", src_expr = "{src_expr}", dst_expr = "{dst_expr}"' ) @@ -1052,9 +1050,9 @@ def write_and_resolve_expr(self, sdfg, memlet, nc, outname, inname, indices=None defined_type, _ = self.dispatcher.defined_vars.get(memlet.data) if isinstance(indices, str): - ptr = '%s + %s' % (cpp.cpp_ptr_expr(sdfg, memlet, defined_type), indices) + ptr = '%s + %s' % (cpp.cpp_ptr_expr(sdfg, memlet, defined_type, codegen=self), indices) else: - ptr = cpp.cpp_ptr_expr(sdfg, memlet, defined_type, indices=indices) + ptr = cpp.cpp_ptr_expr(sdfg, memlet, defined_type, indices=indices, codegen=self) if isinstance(dtype, dtypes.pointer): dtype = dtype.base_type # If there is a type mismatch, cast pointer @@ -1148,3 +1146,23 @@ def gen_code_snitch(sdfg): ccode = ccode.replace(i, o) return (ccode, hdrs) + + def ptr(self, + name: str, + desc: data.Data, + sdfg: SDFG = None, + subset: Optional[subsets.Subset] = None, + is_write: Optional[bool] = None, + ancestor: int = 0) -> str: + """ + Returns a string that points to the data based on its name and descriptor. + + :param name: Data name. + :param desc: Data descriptor. + :param sdfg: SDFG in which the data resides. + :param subset: Optional subset associated with the data. + :param is_write: Whether the access is a write access. + :param ancestor: Scope ancestor level. + :return: C-compatible name that can be used to access the data. + """ + return cpp.ptr(name, desc, sdfg, self.frame) diff --git a/dace/codegen/targets/sve/codegen.py b/dace/codegen/targets/sve/codegen.py index 1581aec90d..14d2c88e5b 100644 --- a/dace/codegen/targets/sve/codegen.py +++ b/dace/codegen/targets/sve/codegen.py @@ -1,11 +1,11 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. """ Code generation: This module is responsible for converting an SDFG into SVE code. """ from dace.sdfg.scope import ScopeSubgraphView from dace.codegen.prettycode import CodeIOStream -from dace.codegen.targets.target import TargetCodeGenerator +from dace.codegen.target import TargetCodeGenerator from dace.codegen.targets.framecode import DaCeCodeGenerator from dace.sdfg import nodes, SDFG, SDFGState, ScopeSubgraphView, graph as gr from dace.codegen.prettycode import CodeIOStream @@ -15,7 +15,7 @@ from dace.sdfg.scope import is_in_scope import itertools from dace.codegen.targets.sve import util as util -from typing import List +from typing import List, Optional import copy from six import StringIO import dace.codegen.targets.sve.unparse @@ -29,7 +29,7 @@ import copy import numpy as np from dace.codegen.targets.cpp import is_write_conflicted -from dace import data as data +from dace import data, subsets from dace.frontend.operations import detect_reduction_type import dace.codegen.targets @@ -72,6 +72,9 @@ def __init__(self, frame_codegen: DaCeCodeGenerator, sdfg: dace.SDFG): self.cpu_codegen: dace.codegen.targets.CPUCodeGen = self.dispatcher.get_generic_node_dispatcher() + def get_framecode_generator(self) -> DaCeCodeGenerator: + return self.frame + def get_generated_codeobjects(self): res = super().get_generated_codeobjects() return res @@ -180,7 +183,7 @@ def generate_read(self, sdfg: SDFG, state: SDFGState, map: nodes.Map, edge: grap ################## # Pointer reference code.write( - f'{dst_type} {dst_name} = {cpp.cpp_ptr_expr(sdfg, edge.data, None, codegen=self.frame)};') + f'{dst_type} {dst_name} = {cpp.cpp_ptr_expr(sdfg, edge.data, None, codegen=self.cpu_codegen)};') elif util.is_vector(dst_type): ################## # Vector load @@ -200,7 +203,7 @@ def generate_read(self, sdfg: SDFG, state: SDFGState, map: nodes.Map, edge: grap # Regular load and gather share the first arguments load_args = '{}, {}'.format( util.get_loop_predicate(sdfg, state, edge.dst), - ptr_cast + cpp.cpp_ptr_expr(sdfg, edge.data, DefinedType.Pointer, codegen=self.frame)) + ptr_cast + cpp.cpp_ptr_expr(sdfg, edge.data, DefinedType.Pointer, codegen=self.cpu_codegen)) if stride == 1: code.write('{} = svld1({});'.format(load_lhs, load_args)) @@ -211,7 +214,7 @@ def generate_read(self, sdfg: SDFG, state: SDFGState, map: nodes.Map, edge: grap else: ################## # Scalar read from array - code.write(f'{dst_type} {dst_name} = {cpp.cpp_array_expr(sdfg, edge.data, codegen=self.frame)};') + code.write(f'{dst_type} {dst_name} = {cpp.cpp_array_expr(sdfg, edge.data, codegen=self)};') elif isinstance(desc, data.Scalar): # Refer to shared variable src_type = desc.dtype @@ -327,7 +330,7 @@ def generate_writeback(self, sdfg: SDFG, state: SDFGState, map: nodes.Map, store_args = '{}, {}'.format( util.get_loop_predicate(sdfg, state, edge.src), - ptr_cast + cpp.cpp_ptr_expr(sdfg, edge.data, DefinedType.Pointer, codegen=self.frame), + ptr_cast + cpp.cpp_ptr_expr(sdfg, edge.data, DefinedType.Pointer, codegen=self.cpu_codegen), ) if stride == 1: @@ -339,7 +342,7 @@ def generate_writeback(self, sdfg: SDFG, state: SDFGState, map: nodes.Map, else: ################## # Scalar write into array - code.write(f'{cpp.cpp_array_expr(sdfg, edge.data, codegen=self.frame)} = {src_name};') + code.write(f'{cpp.cpp_array_expr(sdfg, edge.data, codegen=self)} = {src_name};') elif isinstance(desc, data.Scalar): ################## # Write into Scalar @@ -381,7 +384,7 @@ def allocate_array(self, sdfg: SDFG, cfg: state.ControlFlowRegion, dfg: SDFGStat nodedesc.dtype, dtypes.vector): # Special allocation if vector Code->Code register in SVE scope # We prevent dace::vec<>'s and allocate SVE registers instead - ptrname = cpp.ptr(node.data, nodedesc, sdfg, self.frame) + ptrname = self.ptr(node.data, nodedesc, sdfg) if self.dispatcher.defined_vars.has(ptrname): sve_type = util.TYPE_TO_SVE[nodedesc.dtype.vtype] self.dispatcher.defined_vars.add(ptrname, DefinedType.Scalar, sve_type) @@ -515,3 +518,23 @@ def unparse_tasklet(self, sdfg: SDFG, cfg: state.ControlFlowRegion, dfg: state.S callsite_stream.write(result.getvalue(), cfg, state_id, node) callsite_stream.write('///////////////////\n\n') + + def ptr(self, + name: str, + desc: data.Data, + sdfg: SDFG = None, + subset: Optional[subsets.Subset] = None, + is_write: Optional[bool] = None, + ancestor: int = 0) -> str: + """ + Returns a string that points to the data based on its name and descriptor. + + :param name: Data name. + :param desc: Data descriptor. + :param sdfg: SDFG in which the data resides. + :param subset: Optional subset associated with the data. + :param is_write: Whether the access is a write access. + :param ancestor: Scope ancestor level. + :return: C-compatible name that can be used to access the data. + """ + return cpp.ptr(name, desc, sdfg, self.frame) diff --git a/dace/codegen/targets/sve/infer.py b/dace/codegen/targets/sve/infer.py index 10556a9581..97f8329c5f 100644 --- a/dace/codegen/targets/sve/infer.py +++ b/dace/codegen/targets/sve/infer.py @@ -6,7 +6,7 @@ import numpy as np import ast from dace import dtypes -from dace.codegen import cppunparse +from dace.sdfg import type_inference from dace.symbolic import SymExpr import sympy import sys @@ -33,7 +33,7 @@ def _dispatch(tree, symbols, inferred_symbols): if hasattr(patch, name): meth = getattr(patch, name) else: - meth = getattr(dace.codegen.tools.type_inference, name) + meth = getattr(type_inference, name) return meth(tree, symbols, inferred_symbols) diff --git a/dace/codegen/targets/sve/unparse.py b/dace/codegen/targets/sve/unparse.py index bdab13d1eb..06c56ff70d 100644 --- a/dace/codegen/targets/sve/unparse.py +++ b/dace/codegen/targets/sve/unparse.py @@ -346,7 +346,7 @@ def vector_reduction_expr(self, edge, dtype, rhs): store_args = '{}, {}'.format( self.pred_name, - ptr_cast + cpp_ptr_expr(self.sdfg, edge.data, DefinedType.Pointer, codegen=self.cpu_codegen._frame), + ptr_cast + cpp_ptr_expr(self.sdfg, edge.data, DefinedType.Pointer, codegen=self.cpu_codegen), ) red_type = util.REDUCTION_TYPE_TO_SVE[reduction_type][:-1] + '_x' diff --git a/dace/codegen/targets/unroller.py b/dace/codegen/targets/unroller.py deleted file mode 100644 index f4c2bdd2c0..0000000000 --- a/dace/codegen/targets/unroller.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import copy - -import dace -from dace import registry -from dace.sdfg.scope import ScopeSubgraphView -from dace.codegen.prettycode import CodeIOStream -from dace.codegen.targets.target import TargetCodeGenerator -from dace.codegen.targets.framecode import DaCeCodeGenerator -from itertools import product -from dace.sdfg import state -import dace.subsets -import dace.sdfg -from dace.sdfg import nodes as nd -import dace.codegen.common - - -@registry.autoregister_params(name='unroll') -class UnrollCodeGen(TargetCodeGenerator): - """ A constant-range map unroller code generator. """ - target_name = 'unroll' - title = 'Unrolled' - language = 'cpp' - - def __init__(self, frame_codegen: DaCeCodeGenerator, sdfg: dace.SDFG): - self._frame = frame_codegen - self._dispatcher = frame_codegen.dispatcher - dispatcher = self._dispatcher - - # Register dispatchers - self._dispatcher.register_map_dispatcher(dace.ScheduleType.Unrolled, self) - - def get_generated_codeobjects(self): - return [] - - #Generate new names for nsdfgs, and adds defined variables to constants - def nsdfg_prepare_unroll(self, scope: ScopeSubgraphView, paramname: str, paramval: str): - backup = [] - for node in scope.nodes(): - if (isinstance(node, nd.NestedSDFG)): - backup.append((node, node.unique_name, node.sdfg.name, node.symbol_mapping, node.sdfg.constants_prop)) - node.unique_name = copy.deepcopy(node.unique_name) - node.sdfg.name = copy.deepcopy(node.sdfg.name) - node.symbol_mapping = copy.deepcopy(node.symbol_mapping) - node.sdfg.constants_prop = copy.deepcopy(node.sdfg.constants_prop) - node.unique_name = f"{node.unique_name}_{paramname}{paramval}" - node.sdfg.name = f"{node.sdfg.name}_{paramname}{paramval}" - for nstate in node.sdfg.nodes(): - backup.extend(self.nsdfg_prepare_unroll(nstate, paramname, paramval)) - if paramname in node.symbol_mapping: - node.symbol_mapping.pop(paramname) - node.sdfg.add_constant(paramname, int(paramval)) - return backup - - def nsdfg_after_unroll(self, backup: "list[tuple[str, str, dict, dict]]"): - for node, unique_name, name, symbols, constants in backup: - node.unique_name = unique_name - node.sdfg.name = name - node.symbol_mapping = symbols - node.sdfg.constants_prop = constants - - #TODO: Expand the unroller so it can also generate openCL code - def generate_scope(self, sdfg: dace.SDFG, cfg: state.ControlFlowRegion, scope: ScopeSubgraphView, state_id: int, - function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: - entry_node: nd.MapEntry = scope.source_nodes()[0] - index_list = [] - - for begin, end, stride in entry_node.map.range: - l = [] - while begin <= end: - l.append(begin) - begin += stride - index_list.append(l) - - sdfgconsts = sdfg.constants_prop - sdfg.constants_prop = copy.deepcopy(sdfg.constants_prop) - - mapsymboltypes = entry_node.new_symbols(sdfg, scope, [entry_node.map.params]) - for indices in product(*index_list): - callsite_stream.write('{') - nsdfg_unroll_info = None - for param, index in zip(entry_node.map.params, indices): - if nsdfg_unroll_info is None: - nsdfg_unroll_info = self.nsdfg_prepare_unroll(scope, str(param), str(index)) - else: - self.nsdfg_prepare_unroll(scope, str(param), str(index)) - callsite_stream.write( - f"constexpr {mapsymboltypes[param]} {param} = " - f"{dace.codegen.common.sym2cpp(index)};\n", cfg) - sdfg.add_constant(param, int(index)) - - callsite_stream.write('{') - self._dispatcher.dispatch_subgraph( - sdfg, - cfg, - scope, - state_id, - function_stream, - callsite_stream, - skip_entry_node=True, - skip_exit_node=True, - ) - callsite_stream.write('}') - callsite_stream.write('}') - self.nsdfg_after_unroll(nsdfg_unroll_info) - - sdfg.constants_prop = sdfgconsts diff --git a/dace/codegen/targets/xilinx.py b/dace/codegen/targets/xilinx.py deleted file mode 100644 index 0c99974f97..0000000000 --- a/dace/codegen/targets/xilinx.py +++ /dev/null @@ -1,1339 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import copy -from dace.sdfg.sdfg import SDFG -import re -import ast - -import dace -from dace import data as dt, registry, dtypes, subsets -from dace.config import Config -from dace.frontend import operations -from dace.sdfg import nodes -from dace.codegen import exceptions as cgx -from dace.codegen.codeobject import CodeObject -from dace.codegen.dispatcher import DefinedType -from dace.codegen.prettycode import CodeIOStream -from dace.codegen.targets import cpp, fpga -from typing import List, Union, Tuple - -from dace.sdfg.state import ControlFlowRegion - -REDUCTION_TYPE_TO_HLSLIB = { - dace.dtypes.ReductionType.Min: "hlslib::op::Min", - dace.dtypes.ReductionType.Max: "hlslib::op::Max", - dace.dtypes.ReductionType.Sum: "hlslib::op::Sum", - dace.dtypes.ReductionType.Product: "hlslib::op::Product", - dace.dtypes.ReductionType.Logical_And: "hlslib::op::And", -} - - -@registry.autoregister_params(name='xilinx') -class XilinxCodeGen(fpga.FPGACodeGen): - """ Xilinx FPGA code generator. """ - - target_name = 'xilinx' - title = 'Xilinx' - language = 'hls' - - def __init__(self, *args, **kwargs): - self.fpga_vendor = Config.get("compiler", "fpga", "vendor") - - # Check that the given vendor is supported - fpga.is_vendor_supported(self.fpga_vendor) - - if self.fpga_vendor.lower() != "xilinx": - # Don't register this code generator - return - super().__init__(*args, **kwargs) - # Used to pass memory bank assignments from kernel generation code to - # where they are written to file - self._bank_assignments = {} - - # Keep track of external streams: original_name -> mangled_name - self._external_streams = dict() - self._defined_external_streams = set() - self._execution_mode = Config.get("compiler", "xilinx", "mode") - self._decouple_array_interfaces = Config.get_bool("compiler", "xilinx", "decouple_array_interfaces") - - @staticmethod - def cmake_options(): - host_flags = Config.get("compiler", "xilinx", "host_flags") - synthesis_flags = Config.get("compiler", "xilinx", "synthesis_flags") - build_flags = Config.get("compiler", "xilinx", "build_flags") - mode = Config.get("compiler", "xilinx", "mode") - target_platform = Config.get("compiler", "xilinx", "platform") - enable_debugging = ("ON" if Config.get_bool("compiler", "xilinx", "enable_debugging") else "OFF") - autobuild = ("ON" if Config.get_bool("compiler", "fpga", "autobuild_bitstreams") else "OFF") - frequency = Config.get("compiler", "xilinx", "frequency").strip() - options = [ - "-DDACE_XILINX_HOST_FLAGS=\"{}\"".format(host_flags), - "-DDACE_XILINX_SYNTHESIS_FLAGS=\"{}\"".format(synthesis_flags), - "-DDACE_XILINX_BUILD_FLAGS=\"{}\"".format(build_flags), "-DDACE_XILINX_MODE={}".format(mode), - "-DDACE_XILINX_TARGET_PLATFORM=\"{}\"".format(target_platform), - "-DDACE_XILINX_ENABLE_DEBUGGING={}".format(enable_debugging), - "-DDACE_FPGA_AUTOBUILD_BITSTREAM={}".format(autobuild), "-DDACE_XILINX_TARGET_CLOCK={}".format(frequency) - ] - # Override Vitis/SDx/SDAccel installation directory - if Config.get("compiler", "xilinx", "path"): - options.append("-DVITIS_ROOT_DIR=\"{}\"".format( - Config.get("compiler", "xilinx", "path").replace("\\", "/"))) - return options - - def get_generated_codeobjects(self): - - execution_mode = Config.get("compiler", "xilinx", "mode") - - kernel_file_name = "DACE_BINARY_DIR \"/{}".format(self._program_name) - if execution_mode == "software_emulation": - kernel_file_name += "_sw_emu.xclbin\"" - xcl_emulation_mode = "\"sw_emu\"" - xilinx_sdx = "DACE_VITIS_DIR" - elif execution_mode == "hardware_emulation": - kernel_file_name += "_hw_emu.xclbin\"" - xcl_emulation_mode = "\"hw_emu\"" - xilinx_sdx = "DACE_VITIS_DIR" - elif execution_mode == "hardware" or execution_mode == "simulation": - kernel_file_name += "_hw.xclbin\"" - xcl_emulation_mode = None - xilinx_sdx = None - else: - raise cgx.CodegenError("Unknown Xilinx execution mode: {}".format(execution_mode)) - - set_env_vars = "" - set_str = "dace::set_environment_variable(\"{}\", {});\n" - unset_str = "dace::unset_environment_variable(\"{}\");\n" - set_env_vars += (set_str.format("XCL_EMULATION_MODE", xcl_emulation_mode) - if xcl_emulation_mode is not None else unset_str.format("XCL_EMULATION_MODE")) - set_env_vars += (set_str.format("XILINX_SDX", xilinx_sdx) - if xilinx_sdx is not None else unset_str.format("XILINX_SDX")) - set_env_vars += set_str.format( - "EMCONFIG_PATH", - "DACE_BINARY_DIR") if execution_mode == 'hardware_emulation' else unset_str.format("EMCONFIG_PATH") - - host_code = CodeIOStream() - host_code.write("""\ -#include "dace/xilinx/host.h" -#include "dace/dace.h" -#include "dace/xilinx/stream.h" -""") - if len(self._dispatcher.instrumentation) > 2: - host_code.write("""\ -#include "dace/perf/reporting.h" -#include -#include -#include -#include -""") - host_code.write("\n\n") - - self._frame.generate_fileheader(self._global_sdfg, host_code, 'xilinx_host') - - params_comma = self._global_sdfg.init_signature(free_symbols=self._frame.free_symbols(self._global_sdfg)) - if params_comma: - params_comma = ', ' + params_comma - - host_code.write(""" -DACE_EXPORTED int __dace_init_xilinx({sdfg_state_name} *__state{signature}) {{ - {environment_variables} - - __state->fpga_context = new dace_fpga_context(); - __state->fpga_context->Get().MakeProgram({kernel_file_name}); - return 0; -}} - -DACE_EXPORTED int __dace_exit_xilinx({sdfg_state_name} *__state) {{ - delete __state->fpga_context; - return 0; -}} - -{host_code}""".format(signature=params_comma, - sdfg=self._global_sdfg, - sdfg_state_name=cpp.mangle_dace_state_struct_name(self._global_sdfg), - environment_variables=set_env_vars, - kernel_file_name=kernel_file_name, - host_code="".join([ - "{separator}\n// Kernel: {kernel_name}" - "\n{separator}\n\n{code}\n\n".format(separator="/" * 79, kernel_name=name, code=code) - for (name, code) in self._host_codes - ]))) - - host_code_obj = CodeObject(self._program_name, - host_code.getvalue(), - "cpp", - XilinxCodeGen, - "Xilinx", - target_type="host") - - ip_objs = [ - CodeObject(kernel_name, code, ext, XilinxCodeGen, "Xilinx", target_type="device") - for (kernel_name, ext, code) in self._ip_codes - ] - - kernel_code_objs = [ - CodeObject(kernel_name, - code, - f"{'ip.' if is_multipumped else ''}cpp", - XilinxCodeGen, - "Xilinx", - target_type="device") for (kernel_name, code, is_multipumped) in self._kernel_codes - ] - - # Memory bank and streaming interfaces connectivity configuration file - link_cfg = CodeIOStream() - self._other_codes["link.cfg"] = link_cfg - link_cfg.write("[connectivity]") - are_assigned = [v is not None for v in self._bank_assignments.values()] - if any(are_assigned): - if not all(are_assigned): - raise RuntimeError("Some, but not all global memory arrays " - "were assigned to memory banks: {}".format(self._bank_assignments)) - # Emit mapping from kernel memory interfaces to DRAM banks - for (kernel_name, interface_name), (memory_type, memory_bank) in self._bank_assignments.items(): - link_cfg.write(f"sp={kernel_name}_1.m_axi_{interface_name}:{memory_type}[{memory_bank}]") - # Emit mapping between inter-kernel streaming interfaces - for stream_name, (src, dst) in self._stream_connections.items(): - if src is None or dst is None: - link_cfg.write(f'{stream_name} failed {src} {dst}') - elif src.replace('m_axis', 's_axis') != dst: - link_cfg.write(f"stream_connect={src}:{dst}") - - ip_names = [ip_name for ip_name, _, _ in self._ip_codes] - for kernel_name, _, _ in self._kernel_codes: - postfix = '_top_1' if f'{kernel_name}_top' in ip_names else '_1' - link_cfg.write(f'slr={kernel_name}{postfix}:SLR0') - - other_objs = [] - for name, code in self._other_codes.items(): - name = name.split(".") - other_objs.append( - CodeObject(name[0], code.getvalue(), ".".join(name[1:]), XilinxCodeGen, "Xilinx", target_type="device")) - - return [host_code_obj] + ip_objs + kernel_code_objs + other_objs - - def _internal_preprocess(self, sdfg: dace.SDFG): - """ - Vendor-specific SDFG Preprocessing - """ - - if self._decouple_array_interfaces: - # If array accesses are decoupled, preprocess inter state edge assignments: - # - look at every interstate edge - # - if any of them accesses an ArrayInterface (Global FPGA memory), qualify its name and replace it - # in the assignment string - - for graph in sdfg.all_sdfgs_recursive(): - for e in graph.all_interstate_edges(): - if len(e.data.assignments) > 0: - replace_dict = dict() - - for variable, value in e.data.assignments.items(): - expr = ast.parse(value) - # walk in the expression, get all array names and check whether we need to qualify them - for node in ast.walk(expr): - if isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name): - arr_name = node.value.id - - if arr_name not in replace_dict and arr_name in graph.arrays and graph.arrays[ - arr_name].storage == dace.dtypes.StorageType.FPGA_Global: - repl = fpga.fpga_ptr(arr_name, - graph.arrays[node.value.id], - sdfg, - None, - False, - None, - None, - True, - decouple_array_interfaces=self._decouple_array_interfaces) - replace_dict[arr_name] = repl - - # Perform replacement and update graph.arrays to allow type inference - # on interstate edges - for k, v in replace_dict.items(): - e.data.replace(k, v) - if v not in graph.arrays: - # Note: this redundancy occurs only during codegen - graph.arrays[v] = graph.arrays[k] - - def define_stream(self, dtype, buffer_size, var_name, array_size, function_stream, kernel_stream, sdfg): - """ - Defines a stream - - :return: a tuple containing the type of the created variable, and boolean indicating - whether this is a global variable or not - """ - - ctype = "dace::FIFO<{}, {}, {}>".format(dtype.base_type.ctype, cpp.sym2cpp(dtype.veclen), - cpp.sym2cpp(buffer_size)) - - array_size_cpp = cpp.sym2cpp(array_size) - if array_size_cpp == "1": - kernel_stream.write("{} {}(\"{}\");".format(ctype, var_name, var_name)) - else: - kernel_stream.write("{} {}[{}];\n".format(ctype, var_name, array_size_cpp)) - kernel_stream.write("dace::SetNames({}, \"{}\", {});".format(var_name, var_name, array_size_cpp)) - # In Xilinx, streams are defined as local variables - # Return value is used for adding to defined_vars in fpga.py - return ctype, False - - def define_local_array(self, var_name, desc, array_size, function_stream, kernel_stream, sdfg, state_id, node): - dtype = desc.dtype - kernel_stream.write("{} {}[{}];\n".format(dtype.ctype, var_name, cpp.sym2cpp(array_size))) - if desc.storage == dace.dtypes.StorageType.FPGA_Registers: - kernel_stream.write("#pragma HLS ARRAY_PARTITION variable={} " - "complete\n".format(var_name)) - elif desc.storage == dace.dtypes.StorageType.FPGA_Local: - pass - else: - raise ValueError("Unsupported storage type: {}".format(desc.storage.name)) - self._dispatcher.defined_vars.add(var_name, DefinedType.Pointer, '%s *' % dtype.ctype) - - def define_shift_register(*args, **kwargs): - raise NotImplementedError("Xilinx shift registers NYI") - - @staticmethod - def make_vector_type(dtype, is_const): - return "{}{}".format("const " if is_const else "", dtype.ctype) - - @staticmethod - def make_kernel_argument(data: dt.Data, - var_name: str, - subset_info: Union[int, subsets.Subset], - sdfg: SDFG, - is_output: bool, - with_vectorization: bool, - interface_id: Union[int, List[int]] = None, - decouple_array_interfaces=False): - if isinstance(data, dt.Array): - var_name = fpga.fpga_ptr(var_name, - data, - sdfg, - subset_info, - is_output, - None, - None, - True, - interface_id, - decouple_array_interfaces=decouple_array_interfaces) - if with_vectorization: - dtype = data.dtype - else: - dtype = data.dtype.base_type - return "{} *{}".format(dtype.ctype, var_name) - if isinstance(data, dt.Stream): - ctype = "dace::FIFO<{}, {}, {}>".format(data.dtype.base_type.ctype, cpp.sym2cpp(data.dtype.veclen), - cpp.sym2cpp(data.buffer_size)) - if data.shape[0] == 1: - return "{} &{}".format(ctype, var_name) - else: - return "{} {}[{}]".format(ctype, var_name, data.shape[0]) - else: - return data.as_arg(with_types=True, name=var_name) - - def generate_unroll_loop_pre(self, kernel_stream, factor, sdfg, cfg, state_id, node): - pass - - @staticmethod - def generate_unroll_loop_post(kernel_stream, factor, sdfg, cfg, state_id, node): - if factor is None: - kernel_stream.write("#pragma HLS UNROLL", cfg, state_id, node) - else: - kernel_stream.write("#pragma HLS UNROLL factor={}".format(factor), cfg, state_id, node) - - @staticmethod - def generate_pipeline_loop_pre(kernel_stream, sdfg, cfg, state_id, node): - pass - - @staticmethod - def generate_pipeline_loop_post(kernel_stream, sdfg, cfg, state_id, node): - kernel_stream.write("#pragma HLS PIPELINE II=1", cfg, state_id, node) - - @staticmethod - def generate_flatten_loop_pre(kernel_stream, sdfg, cfg, state_id, node): - pass - - @staticmethod - def generate_flatten_loop_post(kernel_stream, sdfg, cfg, state_id, node): - kernel_stream.write("#pragma HLS LOOP_FLATTEN") - - def generate_nsdfg_header(self, sdfg, cfg, state, state_id, node, memlet_references, sdfg_label): - # TODO: Use a single method for GPU kernels, FPGA modules, and NSDFGs - arguments = [f'{atype} {aname}' for atype, aname, _ in memlet_references] - fsyms = node.sdfg.used_symbols(all_symbols=False, keep_defined_in_mapping=True) - arguments += [ - f'{node.sdfg.symbols[aname].as_arg(aname)}' for aname in sorted(node.symbol_mapping.keys()) - if aname in fsyms and aname not in sdfg.constants - ] - arguments = ', '.join(arguments) - return f'void {sdfg_label}({arguments}) {{\n#pragma HLS INLINE' - - def write_and_resolve_expr(self, sdfg, memlet, nc, outname, inname, indices=None, dtype=None): - """ - Emits a conflict resolution call from a memlet. - """ - redtype = operations.detect_reduction_type(memlet.wcr, openmp=True) - ptrname = cpp.ptr(memlet.data, sdfg.arrays[memlet.data], sdfg, self._frame) - defined_type, _ = self._dispatcher.defined_vars.get(ptrname) - if isinstance(indices, str): - ptr = '%s + %s' % (cpp.cpp_ptr_expr(sdfg, - memlet, - defined_type, - is_write=True, - codegen=self._frame, - decouple_array_interface=self._decouple_array_interfaces), indices) - else: - ptr = cpp.cpp_ptr_expr(sdfg, - memlet, - defined_type, - indices=indices, - is_write=True, - codegen=self._frame, - decouple_array_interface=self._decouple_array_interfaces) - - if isinstance(dtype, dtypes.pointer): - dtype = dtype.base_type - - # Special call for detected reduction types - if redtype != dtypes.ReductionType.Custom: - if redtype == dace.dtypes.ReductionType.Sub: - # write this as an addition - credtype = "dace::ReductionType::Sum" - is_sub = True - else: - credtype = "dace::ReductionType::" + str(redtype)[str(redtype).find(".") + 1:] - is_sub = False - - if isinstance(dtype, dtypes.vector): - return (f'dace::xilinx_wcr_fixed_vec<{credtype}, ' - f'{dtype.vtype.ctype}, {dtype.veclen}>::reduce(' - f'{ptr}, {"-" if is_sub else ""}{inname})') - return (f'dace::xilinx_wcr_fixed<{credtype}, {dtype.ctype}>::reduce(' - f'{ptr}, {"-" if is_sub else ""}{inname})') - - # General reduction - raise NotImplementedError('General reductions not yet implemented') - - @staticmethod - def make_read(defined_type, dtype, var_name, expr, index, is_pack, packing_factor): - if defined_type in [DefinedType.Stream, DefinedType.StreamArray]: - if " " in expr: - expr = "(" + expr + ")" - read_expr = "{}.pop()".format(expr) - elif defined_type == DefinedType.Scalar: - read_expr = var_name - else: - if index is not None and index != "0": - read_expr = "{} + {}".format(expr, index) - else: - read_expr = expr - if is_pack: - return "dace::Pack<{}, {}>({})".format(dtype.base_type.ctype, packing_factor, read_expr) - else: - return "dace::Read<{}, {}>({})".format(dtype.base_type.ctype, dtype.veclen, read_expr) - - def generate_converter(*args, **kwargs): - pass # Handled in C++ - - @staticmethod - def make_write(defined_type, dtype, var_name, write_expr, index, read_expr, wcr, is_unpack, packing_factor): - if defined_type in [DefinedType.Stream, DefinedType.StreamArray]: - if defined_type == DefinedType.StreamArray: - write_expr = "{}[{}]".format(write_expr, "0" if not index else index) - if is_unpack: - return "\n".join("{}.push({}[{}]);".format(write_expr, read_expr, i) for i in range(packing_factor)) - else: - return "{}.push({});".format(write_expr, read_expr) - else: - if defined_type == DefinedType.Scalar: - write_expr = var_name - elif index and index != "0": - write_expr = "{} + {}".format(write_expr, index) - if is_unpack: - return "dace::Unpack<{}, {}>({}, {});".format(dtype.base_type.ctype, packing_factor, read_expr, - write_expr) - else: - # TODO: Temporary hack because we don't have the output - # vector length. - veclen = max(dtype.veclen, packing_factor) - return "dace::Write<{}, {}>({}, {});".format(dtype.base_type.ctype, veclen, write_expr, read_expr) - - def make_shift_register_write(self, defined_type, dtype, var_name, write_expr, index, read_expr, wcr, is_unpack, - packing_factor, sdfg): - raise NotImplementedError("Xilinx shift registers NYI") - - @staticmethod - def generate_no_dependence_pre(kernel_stream, sdfg, cfg, state_id, node, var_name=None): - pass - - def generate_no_dependence_post(self, - kernel_stream, - sdfg: SDFG, - cfg: ControlFlowRegion, - state_id: int, - node: nodes.Node, - var_name: str, - accessed_subset: Union[int, subsets.Subset] = None): - """ - Adds post loop pragma for ignoring loop carried dependencies on a given variable - """ - defined_type, _ = self._dispatcher.defined_vars.get(var_name) - - if var_name in sdfg.arrays: - array = sdfg.arrays[var_name] - else: - array = None - - var_name = fpga.fpga_ptr(var_name, - array, - sdfg, - accessed_subset, - True, - self._dispatcher, - is_array_interface=(defined_type == DefinedType.ArrayInterface), - decouple_array_interfaces=self._decouple_array_interfaces) - kernel_stream.write("#pragma HLS DEPENDENCE variable={} false".format(var_name), cfg, state_id, node) - - def generate_kernel_boilerplate_pre(self, sdfg, cfg, state_id, kernel_name, parameters, bank_assignments, - module_stream, kernel_stream, external_streams, multi_pumped): - - # Write header - module_stream.write("""#include -#include -#include """, cfg) - self._frame.generate_fileheader(sdfg, module_stream, 'xilinx_device') - module_stream.write("\n", cfg) - - argname_to_bank_assignment = {} - # Build kernel signature - kernel_args = [] - array_args = [] - for is_output, data_name, data, interface in parameters: - is_assigned = data_name in bank_assignments and bank_assignments[data_name] is not None - if is_assigned and isinstance(data, dt.Array): - memory_bank = bank_assignments[data_name] - lowest_bank_index, _ = fpga.get_multibank_ranges_from_subset(memory_bank[1], sdfg) - - for bank, interface_id in fpga.iterate_multibank_interface_ids(data, interface): - kernel_arg = self.make_kernel_argument(data, - data_name, - bank, - sdfg, - is_output, - True, - interface_id, - decouple_array_interfaces=self._decouple_array_interfaces) - if kernel_arg: - kernel_args.append(kernel_arg) - array_args.append((kernel_arg, data_name)) - argname_to_bank_assignment[kernel_arg] = (memory_bank[0], lowest_bank_index + bank) - else: - kernel_arg = self.make_kernel_argument(data, - data_name, - None, - None, - is_output, - True, - interface, - decouple_array_interfaces=self._decouple_array_interfaces) - if kernel_arg: - kernel_args.append(kernel_arg) - if isinstance(data, dt.Array): - array_args.append((kernel_arg, data_name)) - argname_to_bank_assignment[kernel_arg] = None - - stream_args = [] - for is_output, data_name, data, interface in external_streams: - kernel_arg = self.make_kernel_argument(data, - data_name, - None, - None, - is_output, - True, - interface, - decouple_array_interfaces=self._decouple_array_interfaces) - if kernel_arg: - stream_args.append(kernel_arg) - - # Sometimes streams are added as an argument twice, which they shouldn't. - stream_args = dtypes.deduplicate(stream_args) - - if not self._decouple_array_interfaces: - kernel_args = dtypes.deduplicate(kernel_args) - - # Write kernel signature - kernel_stream.write("DACE_EXPORTED void {}({}) {{\n".format(kernel_name, ', '.join(kernel_args + stream_args)), - cfg, state_id) - - # Insert interface pragmas - num_mapped_args = 0 - if not self._decouple_array_interfaces: - array_args = dtypes.deduplicate(array_args) - - for arg, data_name in array_args: - var_name = re.findall(r"\w+", arg)[-1] - if "*" in arg: - interface_name = "gmem{}".format(num_mapped_args) - kernel_stream.write( - "#pragma HLS INTERFACE m_axi port={} " - "offset=slave bundle={}".format(var_name, interface_name), cfg, state_id) - # Map this interface to the corresponding location - # specification to be passed to the Xilinx compiler - memory_bank = argname_to_bank_assignment[arg] - self._bank_assignments[(kernel_name, interface_name)] = memory_bank - num_mapped_args += 1 - - if multi_pumped: - kernel_stream.write('#pragma HLS INTERFACE ap_ctrl_none port=return') - else: - for arg in kernel_args + ["return"]: - var_name = re.findall(r"\w+", arg)[-1] - kernel_stream.write("#pragma HLS INTERFACE s_axilite port={} bundle=control".format(var_name)) - - axis_pragmas = [] - for _, var_name, node, _ in external_streams: - arr_len = dace.symbolic.evaluate(node.shape[0], sdfg.constants) - if arr_len > 1: - partition_pragma = f"#pragma HLS ARRAY_PARTITION variable={var_name} dim=1 complete" - axis_pragmas.append(partition_pragma) - port_pragma = f"#pragma HLS INTERFACE axis port={var_name}" - axis_pragmas.append(port_pragma) - # Sometimes, streams are added as arguments twice, which they shouldn't. - axis_pragmas = dtypes.deduplicate(axis_pragmas) - for axis_pragma in axis_pragmas: - kernel_stream.write(axis_pragma) - - # TODO: add special case if there's only one module for niceness - kernel_stream.write("\n#pragma HLS DATAFLOW") - kernel_stream.write("\nHLSLIB_DATAFLOW_INIT();") - - @staticmethod - def generate_kernel_boilerplate_post(kernel_stream, sdfg, cfg, state_id): - kernel_stream.write("HLSLIB_DATAFLOW_FINALIZE();\n}\n", cfg, state_id) - - def generate_host_function_body(self, sdfg: dace.SDFG, cfg: ControlFlowRegion, state: dace.SDFGState, - kernel_name: str, predecessors: list, parameters: list, rtl_tasklet_names: list, - kernel_stream: CodeIOStream, instrumentation_stream: CodeIOStream, - multi_pumped: bool) -> None: - """ - Generate the host-specific code for spawning and synchronizing the given kernel. - - :param sdfg: The SDFG. - :param state: The state to generate in. - :param predecessors: list containing all the name of kernels that must be finished before starting this one - :param parameters: list containing the kernel parameters (of all kernels in this state) - :param rtl_tasklet_names: A list of RTL tasklet names. - :param kernel_stream: Device-specific code stream. - :param instrumentation_stream: Code for profiling kernel execution time. - :param multi_pumped: Whether the kernel is multi pumped - """ - - # Keep track of kernel arguments as (arg, interface_id) pair - kernel_args: List[Tuple[str, int]] = [] - - for _, name, p, interface_ids in parameters: - if isinstance(p, dt.Array): - for bank, interface_id in fpga.iterate_multibank_interface_ids(p, interface_ids): - # Keep track of the interface_id (if any), while creating kernel arguments. - # In Xilinx we may have kernel argument with the same name but we want to keep all of them - # if they have different interface IDs (this could be the case if the same data is accessed - # from different PEs) - - kernel_args.append( - (p.as_arg(False, - name=fpga.fpga_ptr(name, - p, - sdfg, - bank, - decouple_array_interfaces=self._decouple_array_interfaces)), - interface_id)) - elif isinstance(p, dt.Stream) and name in self._defined_external_streams: - if p.is_stream_array(): - kernel_args.append((f" hlslib::ocl::SimulationOnly(&{p.as_arg(False, name=name)}[0])", 0)) - else: - kernel_args.append((f" hlslib::ocl::SimulationOnly({p.as_arg(False, name=name)})", 0)) - else: - kernel_args.append((p.as_arg(False, name=name), 0)) - - kernel_function_name = kernel_name - kernel_file_name = "{}.xclbin".format(kernel_name) - - # Check if this kernel depends from other kernels - needs_synch = len(predecessors) > 0 - - if needs_synch: - # Build a vector containing all the events associated with the kernels from which this one depends - kernel_deps_name = f"deps_{kernel_name}" - kernel_stream.write(f"std::vector {kernel_deps_name};") - for pred in predecessors: - # concatenate events from predecessor kernel - kernel_stream.write(f"{kernel_deps_name}.push_back({pred}_event);") - if not self._decouple_array_interfaces: - kernel_args = dtypes.deduplicate(kernel_args) - # Launch HLS kernel, passing synchronization events (if any) - if multi_pumped: - kernel_signature = f'"{kernel_function_name}_top"' - else: - kernel_signature = f'{kernel_function_name}, "{kernel_function_name}"' - kernel_stream.write( - f"""auto {kernel_name}_kernel = program.MakeKernel({kernel_signature}, {", ".join(ka[0] for ka in kernel_args)});""" - ) - - kernel_stream.write( - f"""\ - hlslib::ocl::Event {kernel_name}_event = {kernel_name}_kernel.ExecuteTaskAsync({f'{kernel_deps_name}.begin(), {kernel_deps_name}.end()' if needs_synch else ''}); - all_events.push_back({kernel_name}_event);""", cfg, state.block_id) - if state.instrument == dtypes.InstrumentationType.FPGA: - self.instrument_opencl_kernel(kernel_name, state.block_id, cfg.cfg_id, instrumentation_stream) - - def generate_module(self, sdfg, cfg, state, kernel_name, name, subgraph, parameters, module_stream, entry_stream, - host_stream, instrumentation_stream): - """Generates a module that will run as a dataflow function in the FPGA - kernel.""" - - state_id = state.block_id - dfg = cfg.state(state_id) - - kernel_args_call = [] - kernel_args_module = [] - for is_output, pname, p, interface_ids in parameters: - if isinstance(p, dt.Array): - for bank, interface_id in fpga.iterate_multibank_interface_ids(p, interface_ids): - arr_name = fpga.fpga_ptr(pname, - p, - sdfg, - bank, - is_output, - is_array_interface=True, - decouple_array_interfaces=self._decouple_array_interfaces) - # Add interface ID to called module, but not to the module - # arguments - argname = fpga.fpga_ptr(pname, - p, - sdfg, - bank, - is_output, - is_array_interface=True, - interface_id=interface_id, - decouple_array_interfaces=self._decouple_array_interfaces) - - kernel_args_call.append(argname) - dtype = p.dtype - - if self._decouple_array_interfaces: - kernel_args_module.append("{} {}*{}".format(dtype.ctype, "const " if not is_output else "", - arr_name)) - else: - # in this case we don't know if this is accessed read-only or not - kernel_args_module.append("{} *{}".format(dtype.ctype, arr_name)) - - else: - if isinstance(p, dt.Stream): - # if this is an external stream, its name may have been mangled in the kernel - call_name = self._external_streams[pname] if pname in self._external_streams else pname - kernel_args_call.append(p.as_arg(with_types=False, name=call_name)) - if p.is_stream_array(): - kernel_args_module.append("dace::FIFO<{}, {}, {}> {}[{}]".format( - p.dtype.base_type.ctype, cpp.sym2cpp(p.veclen), cpp.sym2cpp(p.buffer_size), pname, - p.size_string())) - else: - kernel_args_module.append("dace::FIFO<{}, {}, {}> &{}".format( - p.dtype.base_type.ctype, cpp.sym2cpp(p.veclen), cpp.sym2cpp(p.buffer_size), pname)) - else: - kernel_args_call.append(p.as_arg(with_types=False, name=pname)) - kernel_args_module.append(p.as_arg(with_types=True, name=pname)) - - # Check if we are generating an RTL module, in which case only the - # accesses to the streams should be handled - rtl_tasklet = self.find_rtl_tasklet(subgraph) - if rtl_tasklet: - # Write placeholders in the original kernel. - entry_stream.write(f'// [RTL] HLSLIB_DATAFLOW_FUNCTION({name}, {", ".join(kernel_args_call)});') - module_stream.write(f'// [RTL] void {name}({", ".join(kernel_args_module)});\n\n') - - rtl_name = self.rtl_tasklet_name(rtl_tasklet, state, cfg) - - # _i in names are due to vitis - source_accessors = [] - for node in subgraph.source_nodes(): - if isinstance(node, dace.nodes.MapEntry): - source_accessors += [e.dst for e in state.out_edges(node)] - else: - source_accessors += [node] - - for node in source_accessors: - if isinstance(sdfg.arrays[node.data], dt.Stream): - # TODO multiple readers accessing a single stream should fail - dst = subgraph.out_edges(node)[0].dst - if isinstance(dst, dace.nodes.MapEntry) and dst.map.unroll: - unrolled_map_range = dace.symbolic.evaluate(dst.map.range[0][1] + 1, sdfg.constants) - else: - unrolled_map_range = 1 - if unrolled_map_range > 1: - elements_to_add = [f'{node.data}_{i}' for i in range(unrolled_map_range)] - else: - elements_to_add = [node.data] - for i in range(unrolled_map_range): - elem = elements_to_add[i] - postfix = f'_{i}' if unrolled_map_range > 1 else '' - if elem not in self._stream_connections: - self._stream_connections[elem] = [None, None] - for edge in subgraph.out_edges(node): - rtl_dst = state.memlet_path(edge)[-1].dst_conn - val = '{}_top_1.s_axis_{}{}'.format(rtl_name, rtl_dst, postfix) - self._stream_connections[elem][1] = val - - sink_accessors = [] - for node in subgraph.sink_nodes(): - if isinstance(node, dace.nodes.MapExit): - sink_accessors += [e.src for e in state.in_edges(node)] - else: - sink_accessors += [node] - - for node in sink_accessors: - if isinstance(sdfg.arrays[node.data], dt.Stream): - # TODO multiple writers accessing a single stream should fail - src = subgraph.in_edges(node)[0].src - if (isinstance(src, dace.nodes.MapExit) and src.map.unroll): - unrolled_map_range = dace.symbolic.evaluate(src.map.range[0][1] + 1, sdfg.constants) - else: - unrolled_map_range = 1 - if unrolled_map_range > 1: - elements_to_add = [f'{node.data}_{i}' for i in range(unrolled_map_range)] - else: - elements_to_add = [node.data] - for i in range(unrolled_map_range): - elem = elements_to_add[i] - postfix = f'_{i}' if unrolled_map_range > 1 else '' - if elem not in self._stream_connections: - self._stream_connections[elem] = [None, None] - for edge in state.in_edges(node): - rtl_src = subgraph.memlet_path(edge)[0].src_conn - self._stream_connections[elem][0] = '{}_top_1.m_axis_{}{}'.format( - rtl_name, rtl_src, postfix) - - # Make the dispatcher trigger generation of the RTL module, but - # ignore the generated code, as the RTL codegen will generate the - # appropriate files. - ignore_stream = CodeIOStream() - self._dispatcher.dispatch_subgraph(sdfg, - cfg, - subgraph, - state_id, - ignore_stream, - ignore_stream, - skip_entry_node=False) - - # Launch the kernel from the host code - # kernel arguments - host_stream.write( - f"all_events.push_back(program.MakeKernel(\"{rtl_name}_top\"{', '.join([''] + [name for _, name, p, _ in parameters if not isinstance(p, dt.Stream)])}).ExecuteTaskAsync());", - cfg, state_id, rtl_tasklet) - if state.instrument == dtypes.InstrumentationType.FPGA: - self.instrument_opencl_kernel(rtl_name, state_id, cfg.cfg_id, instrumentation_stream) - - return - - # create a unique module name to prevent name clashes - module_function_name = f"module_{name}_{cfg.cfg_id}" - - # Unrolling processing elements: if there first scope of the subgraph - # is an unrolled map, generate a processing element for each iteration - scope_children = subgraph.scope_children() - top_scopes = [n for n in scope_children[None] if isinstance(n, dace.sdfg.nodes.EntryNode)] - unrolled_loops = 0 - if len(top_scopes) == 1: - scope = top_scopes[0] - if scope.unroll: - self._unrolled_pes.add(scope.map) - kernel_args_call += ", ".join(scope.map.params) - kernel_args_module += ["int " + p for p in scope.params] - for p, r in zip(scope.map.params, scope.map.range): - if len(r) > 3: - raise cgx.CodegenError("Strided unroll not supported") - entry_stream.write("for (size_t {param} = {begin}; {param} < {end}; " - "{param} += {increment}) {{\n#pragma HLS UNROLL".format(param=p, - begin=r[0], - end=r[1] + 1, - increment=r[2])) - unrolled_loops += 1 - - # Generate caller code in top-level function - if not self._decouple_array_interfaces: - kernel_args_call = dtypes.deduplicate(kernel_args_call) - entry_stream.write( - "HLSLIB_DATAFLOW_FUNCTION({}, {});".format(module_function_name, ", ".join(kernel_args_call)), cfg, - state_id) - - for _ in range(unrolled_loops): - entry_stream.write("}") - - # ---------------------------------------------------------------------- - # Generate kernel code - # ---------------------------------------------------------------------- - - self._dispatcher.defined_vars.enter_scope(subgraph) - - module_body_stream = CodeIOStream() - - if not self._decouple_array_interfaces: - kernel_args_module = dtypes.deduplicate(kernel_args_module) - - module_body_stream.write("void {}({}) {{".format(module_function_name, ", ".join(kernel_args_module)), cfg, - state_id) - - # Register the array interface as a naked pointer for use inside the - # FPGA kernel - interfaces_added = set() - - for is_output, argname, arg, interface_id in parameters: - for bank, _ in fpga.iterate_multibank_interface_ids(arg, interface_id): - if isinstance(arg, dt.Stream) and argname in self._external_streams: - # This is an external stream being passed to the module - # Add this to defined vars - if not self._dispatcher.defined_vars.has(argname): - self._dispatcher.defined_vars.add(argname, DefinedType.Stream, arg.ctype) - continue - - if (not (isinstance(arg, dt.Array) and arg.storage == dace.dtypes.StorageType.FPGA_Global)): - continue - ctype = dtypes.pointer(arg.dtype).ctype - ptr_name = fpga.fpga_ptr(argname, - arg, - sdfg, - bank, - is_output, - None, - is_array_interface=True, - decouple_array_interfaces=self._decouple_array_interfaces) - if not is_output and self._decouple_array_interfaces: - ctype = f"const {ctype}" - - if self._decouple_array_interfaces: - self._dispatcher.defined_vars.add(ptr_name, DefinedType.Pointer, ctype) - if argname in interfaces_added: - continue - interfaces_added.add(argname) - self._dispatcher.defined_vars.add(argname, DefinedType.ArrayInterface, ctype, allow_shadowing=True) - module_body_stream.write("\n") - - # Allocate local transients - data_to_allocate = (set(subgraph.top_level_transients()) - set(sdfg.shared_transients()) - - set([p[1] for p in parameters])) - allocated = set() - for node in subgraph.nodes(): - if not isinstance(node, dace.sdfg.nodes.AccessNode): - continue - if node.data not in data_to_allocate or node.data in allocated: - continue - allocated.add(node.data) - self._dispatcher.dispatch_allocate(sdfg, cfg, state, state_id, node, node.desc(sdfg), module_stream, - module_body_stream) - - self._dispatcher.dispatch_subgraph(sdfg, - cfg, - subgraph, - state_id, - module_stream, - module_body_stream, - skip_entry_node=False) - - module_stream.write(module_body_stream.getvalue(), cfg, state_id) - module_stream.write("}\n\n") - - self._dispatcher.defined_vars.exit_scope(subgraph) - - def rtl_tasklet_name(self, node: nodes.RTLTasklet, state, cfg): - return "{}_{}_{}_{}".format(node.name, cfg.cfg_id, state.block_id, state.node_id(node)) - - def generate_kernel_internal(self, sdfg: dace.SDFG, cfg: ControlFlowRegion, state: dace.SDFGState, kernel_name: str, - predecessors: list, subgraphs: list, kernel_stream: CodeIOStream, - state_host_header_stream: CodeIOStream, state_host_body_stream: CodeIOStream, - instrumentation_stream: CodeIOStream, function_stream: CodeIOStream, - callsite_stream: CodeIOStream, state_parameters: list) -> None: - """ - Generates Kernel code, both device and host side. - - :param sdfg: - :param state: - :param kernel_name: - :param predecessors: list containing all the name of kernels from which this one depends - :param subgraphs: - :param kernel_stream: Device code stream, contains the kernel code - :param state_host_header_stream: Device-specific code stream: contains the host code - for the state global declarations. - :param state_host_body_stream: Device-specific code stream: contains all the code related to - this state, for creating transient buffers, spawning kernels, and synchronizing them. - :param instrumentation_stream: Code for profiling kernel execution time. - :param function_stream: CPU code stream. - :param callsite_stream: CPU code stream. - :param state_parameters: list of state parameters. The kernel-specific parameters will be appended to it. - """ - - (global_data_parameters, top_level_local_data, subgraph_parameters, nested_global_transients, bank_assignments, - external_streams) = self.make_parameters(sdfg, state, subgraphs) - - state_parameters.extend(global_data_parameters) - - # We need to pass external streams as parameters to module - # (unless they are already there. This could be case of inter-PE intra-kernel streams) - # TODO It doesn't break RTL, but the streams are passed to sub kernels that don't need the streams, in turn relying on Vitis to optimize them away again. - for k, v in subgraph_parameters.items(): - for stream_is_out, stream_name, stream_desc, stream_iid in external_streams: - for is_output, data_name, desc, interface_id in v: - if data_name == stream_name and stream_desc == desc: - break - else: - v.append((stream_is_out, stream_name, stream_desc, stream_iid)) - - # Xilinx does not like external streams name with leading underscores or multiple underscores in a row to be - # used as port names. We remove them, and we check that they are not defined anywhere else. - for es in external_streams: - - new_name = re.sub('_+', '_', es[1]) - new_name = new_name.strip("_") - self._external_streams[es[1]] = new_name - - if new_name != es[1]: - clashes = [param for param in global_data_parameters if param[1] == new_name] - clashes.extend([param for param in top_level_local_data if param.data == new_name]) - clashes.extend( - [param for params in subgraph_parameters.values() for param in params if param[1] == new_name]) - clashes.extend([param for param in nested_global_transients if param.data == new_name]) - if len(clashes) > 0: - raise cgx.CodegenError( - f"External stream {es[1]} with sanitized name {new_name} clashes with other paramters {len(clashes)} times." - ) - else: - # Update the sdfg - sdfg.replace(es[1], new_name) - - # Update the global data parameters - for i, p in enumerate(global_data_parameters): - if p[1] == es[1]: - global_data_parameters[i] = (p[1], new_name, p[2], p[3]) - - # Update the top level local data - for p in top_level_local_data: - if p.data == es[1]: - p.data = new_name - - # Update the subgraph parameters - for v in subgraph_parameters.values(): - for i, p in enumerate(v): - if p[1] == es[1]: - v[i] = (p[0], new_name, p[2], p[3]) - - # Update the nested global transients - for p in nested_global_transients: - if p.data == es[1]: - p.data = new_name - - # Update the external streams - external_streams.remove(es) - external_streams.append((es[0], new_name, es[2], es[3])) - - # Detect RTL tasklets, which will be launched as individual kernels - rtl_tasklet_names = [ - self.rtl_tasklet_name(nd, state, cfg) for nd in state.nodes() if isinstance(nd, nodes.RTLTasklet) - ] - - multi_pumped = all([self.is_multi_pumped_subgraph(sg) for sg in subgraphs]) - - # Generate host code - self.generate_host_header(sdfg, cfg, kernel_name, global_data_parameters + external_streams, - state_host_header_stream, multi_pumped) - self.generate_host_function_boilerplate(sdfg, cfg, state, nested_global_transients, state_host_body_stream) - - # Now we write the device code - module_stream = CodeIOStream() - entry_stream = CodeIOStream() - - state_id = cfg.node_id(state) - - self.generate_kernel_boilerplate_pre(sdfg, cfg, state_id, kernel_name, global_data_parameters, bank_assignments, - module_stream, entry_stream, external_streams, multi_pumped) - - # Emit allocations - for node in top_level_local_data: - self._dispatcher.dispatch_allocate(sdfg, cfg, state, state_id, node, node.desc(sdfg), module_stream, - entry_stream) - - for is_output, name, node, _ in external_streams: - buffer_size = dace.symbolic.evaluate(node.buffer_size, sdfg.constants) - ctype = "dace::FIFO<{}, {}, {}>".format(node.dtype.base_type.ctype, node.dtype.veclen, buffer_size) - num_streams = dace.symbolic.evaluate(node.shape[0], sdfg.constants) - self._dispatcher.defined_vars.add_global(name, DefinedType.Stream, ctype) - key = 0 if is_output else 1 - - # Define here external streams - if name not in self._defined_external_streams: - self.define_stream(node.dtype, node.buffer_size, name, node.total_size, None, state_host_body_stream, - sdfg) - self._defined_external_streams.add(name) - - if num_streams > 1: - streams = [f'{name}_{i}' for i in range(num_streams)] - else: # _num should not be appended, when there is only one kernel - streams = [name] - if name not in self._defined_external_streams: - self.define_stream(node.dtype, node.buffer_size, name, node.total_size, None, - state_host_body_stream, sdfg) - self._defined_external_streams.add(name) - - for stream in streams: - if stream not in self._stream_connections: - self._stream_connections[stream] = [None, None] - stream_prefix = 'm_axis_' if is_output else 's_axis_' - stream_prefix = stream_prefix if multi_pumped else '' - kernel_postfix = '_top_1' if multi_pumped else '_1' - val = '{}{}.{}{}'.format(kernel_name, kernel_postfix, stream_prefix, stream) - self._stream_connections[stream][key] = val - - self.generate_modules(sdfg, cfg, state, kernel_name, subgraphs, subgraph_parameters, module_stream, - entry_stream, state_host_body_stream, instrumentation_stream) - - if multi_pumped: - # We have to generate the rest of the RTL files for multi-pumping. In particular: - # - The tcl script for configuring the C++ kernel and data plumbing IP cores. - # - The top-level file for instantiating the C++ kernel and data plumbing IP cores. - # - The Verilog controller for communicating with the host program. - # - A tcl script for synthesizing the multi-pumped kernel for a faster development cycle. - rtllib_config = { - "name": kernel_name, - "buses": { # TODO unroll factor - pname: ('m_axis' if is_output or pname.endswith('_out') else 's_axis', p.veclen) - for is_output, pname, p, _ in external_streams - }, - "params": { - "scalars": { - #name: total_size - #for name, (_, total_size) in scalars.items() - }, - "memory": {} - }, - #"unroll": True, - "double_pump": True, - "ip_cores": { - # TODO Maybe with some help from rtllib - }, - #"version": 20211, - "clocks": - 2 # TODO make this "trickle" down here. Maybe add speeds as well? Might be usefull when packaging - } - # Add the emitted C++ kernel as an IP core - rtllib_config['ip_cores'][f'{kernel_name}_0'] = { - 'name': f'{kernel_name}', - 'vendor': 'xilinx.com', - 'library': 'hls', - 'version': '1.0', - 'params': {} - } - # Add the IP cores for clock synchronization - for _, pname, p, _ in external_streams: - rtllib_config['ip_cores'][f'clock_sync_{pname}'] = { - 'name': 'axis_clock_converter', - 'vendor': 'xilinx.com', - 'library': 'ip', - 'version': '1.1', - 'params': { - 'CONFIG.TDATA_NUM_BYTES': p.dtype.bytes, - 'CONFIG.SYNCHRONIZATION_STAGES': 8, - } - } - - # Avoid importing submodule if not necessary - from dace.external.rtllib.templates.control import generate_from_config as rtllib_control - from dace.external.rtllib.templates.package import generate_from_config as rtllib_package - from dace.external.rtllib.templates.top import generate_from_config as rtllib_top - from dace.external.rtllib.templates.synth import generate_from_config as rtllib_synth - - # Trigger the generation - self._ip_codes.append((f"{kernel_name}_control", 'v', rtllib_control(rtllib_config))) - self._ip_codes.append((f'{kernel_name}_top', 'v', rtllib_top(rtllib_config))) - self._ip_codes.append((f'{kernel_name}_package', 'tcl', rtllib_package(rtllib_config))) - self._ip_codes.append((f'{kernel_name}_synth', 'tcl', rtllib_synth(rtllib_config))) - - self.generate_host_function_body(sdfg, cfg, state, kernel_name, predecessors, - global_data_parameters + external_streams, rtl_tasklet_names, - state_host_body_stream, instrumentation_stream, multi_pumped) - - # Store code to be passed to compilation phase - # self._host_codes.append((kernel_name, host_code_stream.getvalue())) - kernel_stream.write(module_stream.getvalue()) - kernel_stream.write(entry_stream.getvalue()) - - self.generate_kernel_boilerplate_post(kernel_stream, sdfg, cfg, state_id) - - def generate_host_header(self, sdfg, cfg, kernel_function_name, parameters, host_code_stream, multi_pumped): - - kernel_args = [] - for is_output, name, arg, interface_ids in parameters: - if isinstance(arg, dt.Stream): - - if arg.is_stream_array(): - kernel_args.append("dace::FIFO<{}, {}, {}> {}[{}]".format(arg.dtype.base_type.ctype, - cpp.sym2cpp(arg.veclen), - cpp.sym2cpp(arg.buffer_size), name, - arg.size_string())) - else: - kernel_args.append("dace::FIFO<{}, {}, {}> &{}".format(arg.dtype.base_type.ctype, - cpp.sym2cpp(arg.veclen), - cpp.sym2cpp(arg.buffer_size), name)) - elif isinstance(arg, dt.Array): - for bank, interface_id in fpga.iterate_multibank_interface_ids(arg, interface_ids): - argname = fpga.fpga_ptr(name, - arg, - sdfg, - bank, - is_output, - None, - None, - True, - interface_id, - decouple_array_interfaces=self._decouple_array_interfaces) - kernel_args.append(arg.as_arg(with_types=True, name=argname)) - else: - kernel_args.append(arg.as_arg(with_types=True, name=name)) - if not self._decouple_array_interfaces: - kernel_args = dtypes.deduplicate(kernel_args) - ignore_signature = '//' if multi_pumped else '' - host_code_stream.write( - """\ -// Signature of kernel function (with raw pointers) for argument matching -{ignore_signature}DACE_EXPORTED void {kernel_function_name}({kernel_args});\n\n""".format( - kernel_function_name=kernel_function_name, - ignore_signature=ignore_signature, - kernel_args=", ".join(kernel_args)), cfg) - - def generate_memlet_definition(self, sdfg, cfg, dfg, state_id, src_node, dst_node, edge, callsite_stream): - memlet = edge.data - ptrname = cpp.ptr(memlet.data, sdfg.arrays[memlet.data], sdfg, self._frame) - - if (self._dispatcher.defined_vars.get(ptrname)[0] == DefinedType.FPGA_ShiftRegister): - raise NotImplementedError("Shift register for Xilinx NYI") - else: - self._cpu_codegen.copy_memory(sdfg, cfg, dfg, state_id, src_node, dst_node, edge, None, callsite_stream) - - def allocate_view(self, sdfg: dace.SDFG, cfg: ControlFlowRegion, dfg: dace.SDFGState, state_id: int, - node: dace.nodes.AccessNode, global_stream: CodeIOStream, declaration_stream: CodeIOStream, - allocation_stream: CodeIOStream) -> None: - return self._cpu_codegen.allocate_view(sdfg, - cfg, - dfg, - state_id, - node, - global_stream, - declaration_stream, - allocation_stream, - decouple_array_interfaces=self._decouple_array_interfaces) - - def generate_nsdfg_arguments(self, sdfg, cfg, dfg, state, node): - # Connectors that are both input and output share the same name, unless - # they are pointers to global memory in device code, in which case they - # are split into explicit input and output interfaces - inout = set(node.in_connectors.keys() & node.out_connectors.keys()) - - memlet_references = [] - for _, _, _, vconn, in_memlet in sorted(state.in_edges(node), key=lambda e: e.dst_conn or ""): - if in_memlet.data is None: - continue - if not self._decouple_array_interfaces and vconn in inout: - # Only one interface will be generated - continue - ptrname = cpp.ptr(in_memlet.data, sdfg.arrays[in_memlet.data], sdfg, self._frame) - is_memory_interface = (self._dispatcher.defined_vars.get(ptrname, 1)[0] == DefinedType.ArrayInterface) - desc = sdfg.arrays[in_memlet.data] - if is_memory_interface: - for bank in fpga.iterate_distributed_subset(sdfg.arrays[in_memlet.data], in_memlet, False, sdfg): - interface_name = fpga.fpga_ptr(vconn, - sdfg.arrays[in_memlet.data], - sdfg, - bank, - False, - is_array_interface=True, - decouple_array_interfaces=self._decouple_array_interfaces) - passed_memlet = copy.deepcopy(in_memlet) - passed_memlet.subset = fpga.modify_distributed_subset(passed_memlet.subset, bank) - interface_ref = cpp.emit_memlet_reference(self._dispatcher, - sdfg, - passed_memlet, - interface_name, - conntype=node.in_connectors[vconn], - is_write=False, - decouple_array_interfaces=self._decouple_array_interfaces) - memlet_references.append(interface_ref) - - if vconn in inout: - continue - if fpga.is_multibank_array_with_distributed_index(sdfg.arrays[in_memlet.data]): - passed_memlet = copy.deepcopy(in_memlet) - passed_memlet.subset = fpga.modify_distributed_subset(passed_memlet.subset, - 0) # dummy so it works for HBM - else: - passed_memlet = in_memlet - ref = cpp.emit_memlet_reference(self._dispatcher, - sdfg, - passed_memlet, - vconn, - conntype=node.in_connectors[vconn], - is_write=False, - decouple_array_interfaces=self._decouple_array_interfaces) - if not is_memory_interface: - memlet_references.append(ref) - - for _, uconn, _, _, out_memlet in sorted(state.out_edges(node), key=lambda e: e.src_conn or ""): - if out_memlet.data is None: - continue - if fpga.is_multibank_array_with_distributed_index(sdfg.arrays[out_memlet.data]): - passed_memlet = copy.deepcopy(out_memlet) - passed_memlet.subset = fpga.modify_distributed_subset(passed_memlet.subset, - 0) # dummy so it works for HBM - else: - passed_memlet = out_memlet - desc = sdfg.arrays[out_memlet.data] - ref = cpp.emit_memlet_reference(self._dispatcher, - sdfg, - passed_memlet, - uconn, - conntype=node.out_connectors[uconn], - is_write=True, - decouple_array_interfaces=self._decouple_array_interfaces) - ptrname = cpp.ptr(out_memlet.data, sdfg.arrays[out_memlet.data], sdfg, self._frame) - is_memory_interface = (self._dispatcher.defined_vars.get(ptrname, 1)[0] == DefinedType.ArrayInterface) - - if is_memory_interface: - for bank in fpga.iterate_distributed_subset(sdfg.arrays[out_memlet.data], out_memlet, True, sdfg): - interface_name = fpga.fpga_ptr(uconn, - sdfg.arrays[out_memlet.data], - sdfg, - bank, - True, - is_array_interface=True, - decouple_array_interfaces=self._decouple_array_interfaces) - passed_memlet = copy.deepcopy(out_memlet) - passed_memlet.subset = fpga.modify_distributed_subset(passed_memlet.subset, bank) - interface_ref = cpp.emit_memlet_reference(self._dispatcher, - sdfg, - passed_memlet, - interface_name, - conntype=node.out_connectors[uconn], - is_write=True, - decouple_array_interfaces=self._decouple_array_interfaces) - memlet_references.append(interface_ref) - else: - memlet_references.append(ref) - - return memlet_references - - def unparse_tasklet(self, *args, **kwargs): - # Pass this object for callbacks into the Xilinx codegen - cpp.unparse_tasklet(*args, codegen=self, **kwargs) - - def make_ptr_assignment(self, src_expr, src_dtype, dst_expr, dst_dtype): - """ - Write source to destination, where the source is a scalar, and the - destination is a pointer. - - :return: String of C++ performing the write. - """ - return self.make_write(DefinedType.Pointer, dst_dtype, None, "&" + dst_expr, None, src_expr, None, - dst_dtype.veclen < src_dtype.veclen, src_dtype.veclen) diff --git a/dace/config.py b/dace/config.py index c917dac34d..53328a7536 100644 --- a/dace/config.py +++ b/dace/config.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. import contextlib import os import platform @@ -178,6 +178,25 @@ def load_schema(self, filename: Optional[str] = None): with open(filename, 'r') as f: self._config_metadata = yaml.load(f.read(), Loader=yaml.SafeLoader) + def extend(self, schema_filename: str): + """ + Extends the current configuration schema with another schema from file. + + :param schema_filename: The schema file to load. + """ + with open(schema_filename, 'r') as f: + new_metadata = yaml.load(f.read(), Loader=yaml.SafeLoader) + + def merge_dicts(d1: Dict[str, Any], d2: Dict[str, Any]): + for k, v in d2.items(): + if k in d1 and isinstance(d1[k], dict) and isinstance(v, dict): + merge_dicts(d1[k], v) + else: + d1[k] = v + + merge_dicts(self._config_metadata['required'], new_metadata['required']) + _add_defaults(self._config, new_metadata['required']) + def save(self, path: Optional[str] = None, all: bool = False, file: Optional[io.FileIO] = None): if path is None and file is None: path = self._cfg_filename @@ -323,6 +342,15 @@ def cfg_filename(): """ return Config._data.cfg_filename() + @staticmethod + def extend(schema_filename: str): + """ + Extends the current configuration schema with another schema from file. + + :param schema_filename: The schema file to load. + """ + return Config._data.extend(schema_filename=schema_filename) + @staticmethod def load(filename: Optional[str] = None, file: Optional[io.FileIO] = None): """ diff --git a/dace/config_schema.yml b/dace/config_schema.yml index 812e24329e..2b05d45232 100644 --- a/dace/config_schema.yml +++ b/dace/config_schema.yml @@ -173,6 +173,13 @@ required: The typename of this struct is derived by appending this value to the SDFG's name. Note that the suffix may only contains letters, digits and underscores. + cpp_standard: + type: str + default: "20" + title: C++ standard version + description: > + C++ standard to use for compilation (e.g., 14, 17, 20, 23, 26). + format_code: type: bool default: false @@ -183,7 +190,7 @@ required: format_config_file: type: str default: "" - title: Path the clang-format file + title: Path to the .clang-format file description: > Clang-format file to be used by clang-format, only used if format_code is true @@ -261,7 +268,7 @@ required: type: str title: Arguments description: Compiler argument flags - default: '-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -ffast-math -Wno-unused-parameter -Wno-unused-label' + default: '-fPIC -Wall -Wextra -O3 -march=native -ffast-math -Wno-unused-parameter -Wno-unused-label' default_Windows: '/O2 /fp:fast /arch:AVX2 /D_USRDLL /D_WINDLL /D__restrict__=__restrict' libs: @@ -310,7 +317,7 @@ required: type: str title: hipcc Arguments description: Compiler argument flags for HIP - default: '-std=c++17 -fPIC -O3 -ffast-math -Wno-unused-parameter' + default: '-fPIC -O3 -ffast-math -Wno-unused-parameter' cuda_arch: type: str @@ -458,210 +465,6 @@ required: will raise an exception if such a Memlet is encountered. This allows the user to have full control over all Maps in the SDFG. - - ############################################# - # General FPGA flags - fpga: - type: dict - title: FPGA - description: "Common preferences for FPGA compilation." - required: - - autobuild_bitstreams: - type: bool - default: true - title: Automatically build bitstreams - description: > - If set to true, CMake will automatically build missing - bitstreams when running an FPGA program. This can take a - very long time, and users might want to do this manually. - If set to false, the program will optimistically assume - that the bitstream is present in the build directory, and - will crash if this is not the case. - - minimum_fifo_depth: - type: int - default: '' - title: Minimum depth of FIFOs - description: Sets the minimum depth of any generated FIFO. - - vendor: - type: str - default: xilinx - title: FPGA vendor - description: > - Target Xilinx ("xilinx") or Intel ("intel_fpga") FPGAs when - generating code. - - concurrent_kernel_detection: - type: bool - default: false - title: Detect parts of an SDFG that can run in parallel - description: > - If set to false, DaCe will place each weakly connected - component found in an SDFG state in a different Kernel/Processing Element. - If true, a heuristic will further inspect each independent component - for other parallelism opportunities (e.g., branches of the SDFG - that can be executed in parallel), creating the corresponding kernels. - - ############################################# - # FPGA (Xilinx) compiler flags - xilinx: - type: dict - title: Xilinx - description: FPGA (Xilinx) compiler preferences - required: - - mode: - type: str - default: simulation - title: Compilation mode - description: Target of FPGA kernel build (simulation/software_emulation/hardware_emulation/hardware) - - path: - type: str - default: '' - title: Vitis installation override - description: > - Path to specific Vitis/SDx/SDAccel installation to - use instead of just searching PATH and environment - variables. - - platform: - type: str - default: xilinx_u250_xdma_201830_2 - title: Target platform for Xilinx - description: Platform name of Vitis/SDx/SDAccel target. - - frequency: - type: str - default: '' - title: Target frequency for Xilinx kernels - description: > - Target frequency, in MHz, when compiling kernels - for Xilinx. Will not necessarily be achieved in - practice. To enable multiple clocks, enter values - in the format "clock_id:frequency", with frequency - being specified in MHz separated by an escaped bar, - all enclosed in quotes. E.g. "0:250\|1:500". - - enable_debugging: - type: bool - default: false - title: Enable debugging for hardware kernels - description: > - Injects debugging cores on the interfaces of the - kernel, allowing fine-grained debugging of hardware - runs at the cost of additional resources. This is - always enabled for emulation runs. - - host_flags: - type: str - title: Host arguments - description: Extra host compiler argument flags - default: "-Wno-unknown-pragmas -Wno-unused-label" - - synthesis_flags: - type: str - title: Synthesis arguments - description: High-level synthesis C++ flags - default: "-std=c++14" - - build_flags: - type: str - title: Arguments - description: Kernel build C++ flags - default: "" - - decouple_array_interfaces: - type: bool - default: false - title: Decouple array memory interfaces - description: > - If an array is both read and written, this option decouples - its accesses, by creatin a memory interface for reading and one - for writing. - Note that this may hide potential Read-After-Write or - Write-After-Read dependencies. - - - ############################################# - # Intel FPGA compiler flags - intel_fpga: - type: dict - title: Intel FPGA - description: Intel FPGA compiler preferences. - required: - - mode: - type: str - default: emulator - title: Compilation mode - description: > - Target of FPGA kernel build - (emulator/simulator/hardware). - - path: - type: str - default: '' - title: Intel FPGA OpenCL SDK installation override - description: > - Path to specific Intel FPGA OpenCL SDK installation - to use instead of just searching PATH and - environment variables. - - board: - type: str - default: a10gx - title: Target FPGA board - description: FPGA board to compile for, obtain list by running ``aoc --list-boards``. - - enable_debugging: - type: bool - default: false - title: Enable debugging for hardware kernels - description: Injects debugging cores where available. - - host_flags: - type: str - title: Host arguments - description: Extra host compiler argument flags - default: "-Wno-unknown-pragmas" - - kernel_flags: - type: str - title: Kernel flags - description: High-level synthesis C++ flags - default: "-fp-relaxed -cl-no-signed-zeros -cl-fast-relaxed-math -cl-single-precision-constant -no-interleaving=default" - - ############################################# - # RTL (SystemVerilog) compiler - rtl: - type: dict - title: RTL - description: RTL (SystemVerilog) compiler preferences - required: - verbose: - type: bool - default: false - title: Verbose Build & Execution Output - description: Output full build and execution (incl internal state) log. - verilator_flags: - type: str - default: '' - title: Additional Verilator Arguments - description: Additional arguments feed to verilator. - verilator_lint_warnings: - type: bool - default: true - title: Verilator Lint Warnings - description: Enable/Disable detailed SV lint checker output. - verilator_enable_debug: - type: bool - default: false - title: Verilator Enable Debug - description: Enable/disable verbose internal state debug output. - ############################################# # MPI compiler mpi: @@ -692,10 +495,7 @@ required: type: str title: Arguments description: Linker argument flags - # Tell linker to use rpath instead of runpath. Intel - # FPGA programs fail to find certain libraries at - # runtime with runpath. - default: '-Wl,--disable-new-dtags' + default: '' default_Darwin: '' default_Windows: '' @@ -738,12 +538,6 @@ required: description: > Enables analysis of gcc vectorization information. Only gcc/g++ is supported. - print_fpga_runtime: - type: bool - default: false - title: Print FPGA runtime - description: Prints the runtime of instrumented FPGA kernel states to standard output. - ############################################# # Python frontend settings @@ -1020,19 +814,6 @@ required: Force the default implementation, even if an implementation has been explicitly set on a node. - fpga: - type: dict - title: FPGA - description: FPGA-specific BLAS options. - required: - default_stream_depth: - type: int - default: 32 - title: Default FPGA stream depth - description: > - Default FPGA stream depth used in the BLAS - library nodes and the corresponding - streaming transformations lapack: type: dict title: LAPACK diff --git a/dace/data/__init__.py b/dace/data/__init__.py new file mode 100644 index 0000000000..4620474f01 --- /dev/null +++ b/dace/data/__init__.py @@ -0,0 +1,110 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Data descriptors for DaCe. + +This package contains classes for describing data containers (arrays, scalars, streams, etc.) +that can be used in SDFGs. The classes in this package are used to specify the shape, type, +and storage location of data, as well as other properties that affect code generation. + +For backward compatibility, all classes and functions are re-exported from the top-level +`dace.data` module. +""" + +# Core data descriptors +from dace.data.core import ( + Data, + Scalar, + Array, + ContainerArray, + Stream, + Structure, + View, + Reference, + ArrayView, + StructureView, + ContainerView, + ArrayReference, + StructureReference, + ContainerArrayReference, +) + +# Import prod from utils and expose as _prod for backward compatibility +from dace.utils import prod as _prod + +# Tensor/sparse tensor support +from dace.data.tensor import ( + TensorIterationTypes, + TensorAssemblyType, + TensorIndex, + TensorIndexDense, + TensorIndexCompressed, + TensorIndexSingleton, + TensorIndexRange, + TensorIndexOffset, + Tensor, +) + +# Convenience aliases for tensor indices +Dense = TensorIndexDense +Compressed = TensorIndexCompressed +Singleton = TensorIndexSingleton +Range = TensorIndexRange +Offset = TensorIndexOffset + +# ML-related data descriptors +from dace.data.ml import ParameterArray + +# Descriptor creation and array creation from descriptors +from dace.data.creation import ( + create_datadescriptor, + make_array_from_descriptor, + make_reference_from_descriptor, +) + +# Ctypes interoperability +from dace.data.ctypes_interop import make_ctypes_argument + +# Import utility function from utils (for backward compatibility) +from dace.utils import find_new_name + +__all__ = [ + # Core classes + 'Data', + 'Scalar', + 'Array', + 'ContainerArray', + 'Stream', + 'Structure', + 'View', + 'Reference', + 'ArrayView', + 'StructureView', + 'ContainerView', + 'ArrayReference', + 'StructureReference', + 'ContainerArrayReference', + # Tensor support + 'TensorIterationTypes', + 'TensorAssemblyType', + 'TensorIndex', + 'TensorIndexDense', + 'TensorIndexCompressed', + 'TensorIndexSingleton', + 'TensorIndexRange', + 'TensorIndexOffset', + 'Tensor', + # Tensor aliases + 'Dense', + 'Compressed', + 'Singleton', + 'Range', + 'Offset', + # ML descriptors + 'ParameterArray', + # Functions + 'create_datadescriptor', + 'make_array_from_descriptor', + 'make_reference_from_descriptor', + 'make_ctypes_argument', + 'find_new_name', +] diff --git a/dace/data.py b/dace/data/core.py similarity index 58% rename from dace/data.py rename to dace/data/core.py index 6026b24f32..416783a29e 100644 --- a/dace/data.py +++ b/dace/data/core.py @@ -1,15 +1,17 @@ -# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. -import aenum +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +""" +Core data descriptor classes. + +This module contains the base ``Data`` class and all core descriptor classes: +``Scalar``, ``Array``, ``ContainerArray``, ``Stream``, ``Structure``, +``View``, ``Reference``, and their subclasses. +""" import copy as cp import ctypes import dataclasses -import functools -import warnings -from abc import ABC, abstractmethod from collections import OrderedDict -from numbers import Number -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Dict, List, Set, Tuple, Union import numpy as np import sympy as sp @@ -19,146 +21,26 @@ except (ModuleNotFoundError, ImportError): ArrayLike = Any -from dace import config, dtypes, serialize, symbolic -from dace.codegen import cppunparse +from dace import dtypes, serialize, symbolic from dace.properties import (DebugInfoProperty, DictProperty, EnumProperty, ListProperty, NestedDataClassProperty, OrderedDictProperty, Property, ShapeProperty, SymbolicProperty, TypeClassProperty, make_properties) +from dace.utils import prod +# Backward compatibility alias +_prod = prod -def create_datadescriptor(obj, no_custom_desc=False): - """ Creates a data descriptor from various types of objects. - :see: dace.data.Data - """ - if isinstance(obj, Data): - return obj - elif not no_custom_desc and hasattr(obj, '__descriptor__'): - return obj.__descriptor__() - elif not no_custom_desc and hasattr(obj, 'descriptor'): - return obj.descriptor - elif type(obj).__module__ == "torch" and type(obj).__name__ == "Tensor": - # special case for torch tensors. Maybe __array__ could be used here for a more - # general solution, but torch doesn't support __array__ for cuda tensors. - try: - # If torch is importable, define translations between typeclasses and torch types. These are reused by daceml. - # conversion happens here in pytorch: - # https://github.com/pytorch/pytorch/blob/143ef016ee1b6a39cf69140230d7c371de421186/torch/csrc/utils/tensor_numpy.cpp#L237 - import torch - TYPECLASS_TO_TORCH_DTYPE = { - dtypes.bool_: torch.bool, - dtypes.int8: torch.int8, - dtypes.int16: torch.int16, - dtypes.int32: torch.int32, - dtypes.int64: torch.int64, - dtypes.uint8: torch.uint8, - dtypes.float16: torch.float16, - dtypes.float32: torch.float32, - dtypes.float64: torch.float64, - dtypes.complex64: torch.complex64, - dtypes.complex128: torch.complex128, - } - - TORCH_DTYPE_TO_TYPECLASS = {v: k for k, v in TYPECLASS_TO_TORCH_DTYPE.items()} - - storage = dtypes.StorageType.GPU_Global if obj.device.type == 'cuda' else dtypes.StorageType.Default - - return Array(dtype=TORCH_DTYPE_TO_TYPECLASS[obj.dtype], - strides=obj.stride(), - shape=tuple(obj.shape), - storage=storage) - except ImportError: - raise ValueError("Attempted to convert a torch.Tensor, but torch could not be imported") - elif dtypes.is_array(obj) and (hasattr(obj, '__array_interface__') or hasattr(obj, '__cuda_array_interface__')): - if dtypes.is_gpu_array(obj): - interface = obj.__cuda_array_interface__ - storage = dtypes.StorageType.GPU_Global - else: - interface = obj.__array_interface__ - storage = dtypes.StorageType.Default +def _arrays_to_json(arrays): + if arrays is None: + return None + return [(k, serialize.to_json(v)) for k, v in arrays.items()] - if hasattr(obj, 'dtype') and obj.dtype.fields is not None: # Struct - dtype = dtypes.struct('unnamed', **{k: dtypes.typeclass(v[0].type) for k, v in obj.dtype.fields.items()}) - else: - if np.dtype(interface['typestr']).type is np.void: # Struct from __array_interface__ - if 'descr' in interface: - dtype = dtypes.struct('unnamed', **{ - k: dtypes.typeclass(np.dtype(v).type) - for k, v in interface['descr'] - }) - else: - raise TypeError(f'Cannot infer data type of array interface object "{interface}"') - else: - dtype = dtypes.typeclass(np.dtype(interface['typestr']).type) - itemsize = np.dtype(interface['typestr']).itemsize - if len(interface['shape']) == 0: - return Scalar(dtype, storage=storage) - return Array(dtype=dtype, - shape=interface['shape'], - strides=(tuple(s // itemsize for s in interface['strides']) if interface['strides'] else None), - storage=storage) - elif isinstance(obj, (list, tuple)): - # Lists and tuples are cast to numpy - obj = np.array(obj) - - if obj.dtype.fields is not None: # Struct - dtype = dtypes.struct('unnamed', **{k: dtypes.typeclass(v[0].type) for k, v in obj.dtype.fields.items()}) - else: - dtype = dtypes.typeclass(obj.dtype.type) - return Array(dtype=dtype, strides=tuple(s // obj.itemsize for s in obj.strides), shape=obj.shape) - elif type(obj).__module__ == "cupy" and type(obj).__name__ == "ndarray": - # special case for CuPy and HIP, which does not support __cuda_array_interface__ - storage = dtypes.StorageType.GPU_Global - dtype = dtypes.typeclass(obj.dtype.type) - itemsize = obj.itemsize - return Array(dtype=dtype, shape=obj.shape, strides=tuple(s // itemsize for s in obj.strides), storage=storage) - elif symbolic.issymbolic(obj): - return Scalar(symbolic.symtype(obj)) - elif isinstance(obj, dtypes.typeclass): - return Scalar(obj) - elif (obj is int or obj is float or obj is complex or obj is bool or obj is None): - return Scalar(dtypes.typeclass(obj)) - elif isinstance(obj, type) and issubclass(obj, np.number): - return Scalar(dtypes.typeclass(obj)) - elif isinstance(obj, (Number, np.number, np.bool_)): - return Scalar(dtypes.typeclass(type(obj))) - elif obj is type(None): - # NoneType is void * - return Scalar(dtypes.pointer(dtypes.typeclass(None))) - elif isinstance(obj, str) or obj is str: - return Scalar(dtypes.string) - elif callable(obj): - # Cannot determine return value/argument types from function object - return Scalar(dtypes.callback(None)) - - raise TypeError(f'Could not create a DaCe data descriptor from object {obj}. ' - 'If this is a custom object, consider creating a `__descriptor__` ' - 'adaptor method to the type hint or object itself.') - - -def _prod(sequence): - return functools.reduce(lambda a, b: a * b, sequence, 1) - - -def find_new_name(name: str, existing_names: Sequence[str]) -> str: - """ - Returns a name that matches the given ``name`` as a prefix, but does not - already exist in the given existing name set. The behavior is typically - to append an underscore followed by a unique (increasing) number. If the - name does not already exist in the set, it is returned as-is. - - :param name: The given name to find. - :param existing_names: The set of existing names. - :return: A new name that is not in existing_names. - """ - if name not in existing_names: - return name - cur_offset = 0 - new_name = name + '_' + str(cur_offset) - while new_name in existing_names: - cur_offset += 1 - new_name = name + '_' + str(cur_offset) - return new_name + +def _arrays_from_json(obj, context=None): + if obj is None: + return {} + return OrderedDict((k, serialize.from_json(v, context)) for k, v in obj) @make_properties @@ -340,915 +222,20 @@ def set_strides_from_layout(self, self.strides = strides self.total_size = totalsize - def __matmul__(self, storage: dtypes.StorageType): - """ - Syntactic sugar for specifying the storage of a data descriptor. - This enables controlling the storage location as follows: - - .. code-block:: python - - @dace - def add(X: dace.float32[10, 10] @ dace.StorageType.GPU_Global): - return X + 1 - """ - new_desc = cp.deepcopy(self) - new_desc.storage = storage - return new_desc - - -def _arrays_to_json(arrays): - if arrays is None: - return None - return [(k, serialize.to_json(v)) for k, v in arrays.items()] - - -def _arrays_from_json(obj, context=None): - if obj is None: - return {} - return OrderedDict((k, serialize.from_json(v, context)) for k, v in obj) - - -@make_properties -class Structure(Data): - """ Base class for structures. """ - - members = OrderedDictProperty(default=OrderedDict(), - desc="Dictionary of structure members", - from_json=_arrays_from_json, - to_json=_arrays_to_json) - name = Property(dtype=str, desc="Structure type name") - - def __init__(self, - members: Union[Dict[str, Data], List[Tuple[str, Data]]], - name: str = 'Structure', - transient: bool = False, - storage: dtypes.StorageType = dtypes.StorageType.Default, - location: Dict[str, str] = None, - lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope, - debuginfo: dtypes.DebugInfo = None): - - self.members = OrderedDict(members) - for k, v in self.members.items(): - if isinstance(v, dtypes.typeclass): - v = Scalar(v) - self.members[k] = v - v.transient = transient - - self.name = name - fields_and_types = OrderedDict() - symbols = set() - for k, v in self.members.items(): - if isinstance(v, Structure): - symbols |= v.free_symbols - fields_and_types[k] = (v.dtype, str(v.total_size)) - elif isinstance(v, Array): - symbols |= v.free_symbols - fields_and_types[k] = (dtypes.pointer(v.dtype), str(_prod(v.shape))) - elif isinstance(v, Scalar): - symbols |= v.free_symbols - fields_and_types[k] = v.dtype - elif isinstance(v, dtypes.typeclass): - fields_and_types[k] = v - elif isinstance(v, (sp.Basic, symbolic.SymExpr)): - symbols |= v.free_symbols - fields_and_types[k] = symbolic.symtype(v) - elif isinstance(v, (int, np.integer)): - fields_and_types[k] = dtypes.typeclass(type(v)) - else: - raise TypeError(f"Attribute {k}'s value {v} has unsupported type: {type(v)}") - - # NOTE: We will not store symbols in the dtype for now, but leaving it as a comment to investigate later. - # NOTE: See discussion about data/object symbols. - # for s in symbols: - # if str(s) in fields_and_types: - # continue - # if hasattr(s, "dtype"): - # fields_and_types[str(s)] = s.dtype - # else: - # fields_and_types[str(s)] = dtypes.int32 - - dtype = dtypes.pointer(dtypes.struct(name, **fields_and_types)) - dtype.base_type.__descriptor__ = self - shape = (1, ) - super(Structure, self).__init__(dtype, shape, transient, storage, location, lifetime, debuginfo) - - @staticmethod - def from_json(json_obj, context=None): - if json_obj['type'] != 'Structure': - raise TypeError("Invalid data type") - - # Create dummy object - ret = Structure({}) - serialize.set_properties_from_json(ret, json_obj, context=context) - - return ret - - @staticmethod - def from_dataclass(cls, **overrides) -> 'Structure': - """ - Creates a Structure data descriptor from a dataclass instance. - - :param cls: The dataclass to convert. - :param overrides: Optional overrides for the structure fields. - :return: A Structure data descriptor. - """ - members = {} - for field in dataclasses.fields(cls): - # Recursive structures - if dataclasses.is_dataclass(field.type): - members[field.name] = Structure.from_dataclass(field.type) - continue - members[field.name] = field.type - - members.update(overrides) - return Structure(members, name=cls.__name__) - - @property - def total_size(self): - return -1 - - @property - def offset(self): - return [0] - - @property - def start_offset(self): - return 0 - - @property - def strides(self): - return [1] - - @property - def free_symbols(self) -> Set[symbolic.SymbolicType]: - """ Returns a set of undefined symbols in this data descriptor. """ - result = set() - for k, v in self.members.items(): - result |= v.free_symbols - return result - - def __repr__(self): - return f"{self.name} ({', '.join([f'{k}: {v}' for k, v in self.members.items()])})" - - def as_arg(self, with_types=True, for_call=False, name=None): - if self.storage is dtypes.StorageType.GPU_Global: - return Array(self.dtype, [1]).as_arg(with_types, for_call, name) - if not with_types or for_call: - return name - return self.dtype.as_arg(name) - - def __getitem__(self, s): - """ This is syntactic sugar that allows us to define an array type - with the following syntax: ``Structure[N,M]`` - :return: A ``data.ContainerArray`` data descriptor. - """ - if isinstance(s, list) or isinstance(s, tuple): - return ContainerArray(self, tuple(s)) - return ContainerArray(self, (s, )) - - # NOTE: Like Scalars? - @property - def may_alias(self) -> bool: - return False - - # TODO: Can Structures be optional? - @property - def optional(self) -> bool: - return False - - def keys(self): - result = self.members.keys() - for k, v in self.members.items(): - if isinstance(v, Structure): - result |= set(map(lambda x: f"{k}.{x}", v.keys())) - return result - - def clone(self): - return Structure(self.members, self.name, self.transient, self.storage, self.location, self.lifetime, - self.debuginfo) - - # NOTE: Like scalars? - @property - def pool(self) -> bool: - return False - - def make_argument(self, **fields) -> ctypes.Structure: - """ - Creates a structure instance from the given field values, which can be used as - an argument for DaCe programs. - - :param fields: Dictionary of field names to values. - :return: A ctypes Structure instance. - """ - struct_type: dtypes.struct = self.dtype.base_type - struct_ctype = struct_type.as_ctypes() - - def _make_arg(arg: Any, expected_type: Data, name: str) -> Any: - if isinstance(expected_type, Structure): - return ctypes.pointer(expected_type.make_argument_from_object(arg)) - return make_ctypes_argument(arg, expected_type, name) - - args = { - field_name: _make_arg(field_value, self.members[field_name], field_name) - for field_name, field_value in fields.items() if field_name in self.members - } - - struct_instance = struct_ctype(**args) - return struct_instance - - def make_argument_from_object(self, obj) -> ctypes.Structure: - """ - Creates a structure instance from the given object, which can be used as - an argument for DaCe programs. If the object has attributes matching the field names, - those attributes are used as field values. Other attributes are ignored. - - :param obj: Object containing field values. - :return: A ctypes Structure instance. - """ - return self.make_argument(**{field_name: getattr(obj, field_name) for field_name in self.members}) - - -class TensorIterationTypes(aenum.AutoNumberEnum): - """ - Types of tensor iteration capabilities. - - Value (Coordinate Value Iteration) allows to directly iterate over - coordinates such as when using the Dense index type. - - Position (Coordinate Position Iteratation) iterates over coordinate - positions, at which the actual coordinates lie. This is for example the case - with a compressed index, in which the pos array enables one to iterate over - the positions in the crd array that hold the actual coordinates. - """ - Value = () - Position = () - - -class TensorAssemblyType(aenum.AutoNumberEnum): - """ - Types of possible assembly strategies for the individual indices. - - NoAssembly: Assembly is not possible as such. - - Insert: index allows inserting elements at random (e.g. Dense) - - Append: index allows appending to a list of existing coordinates. Depending - on append order, this affects whether the index is ordered or not. This - could be changed by sorting the index after assembly - """ - NoAssembly = () - Insert = () - Append = () - - -class TensorIndex(ABC): - """ - Abstract base class for tensor index implementations. - """ - - @property - @abstractmethod - def iteration_type(self) -> TensorIterationTypes: - """ - Iteration capability supported by this index. - - See TensorIterationTypes for reference. - """ - pass - - @property - @abstractmethod - def locate(self) -> bool: - """ - True if the index supports locate (aka random access), False otw. - """ - pass - - @property - @abstractmethod - def assembly(self) -> TensorAssemblyType: - """ - What assembly type is supported by the index. - - See TensorAssemblyType for reference. - """ - pass - - @property - @abstractmethod - def full(self) -> bool: - """ - True if the level is full, False otw. - - A level is considered full if it encompasses all valid coordinates along - the corresponding tensor dimension. - """ - pass - - @property - @abstractmethod - def ordered(self) -> bool: - """ - True if the level is ordered, False otw. - - A level is ordered when all coordinates that share the same ancestor are - ordered by increasing value (e.g. in typical CSR). - """ - pass - - @property - @abstractmethod - def unique(self) -> bool: - """ - True if coordinate in the level are unique, False otw. - - A level is considered unique if no collection of coordinates that share - the same ancestor contains duplicates. In CSR this is True, in COO it is - not. - """ - pass - - @property - @abstractmethod - def branchless(self) -> bool: - """ - True if the level doesn't branch, false otw. - - A level is considered branchless if no coordinate has a sibling (another - coordinate with same ancestor) and all coordinates in parent level have - a child. In other words if there is a bijection between the coordinates - in this level and the parent level. An example of the is the Singleton - index level in the COO format. - """ - pass - - @property - @abstractmethod - def compact(self) -> bool: - """ - True if the level is compact, false otw. - - A level is compact if no two coordinates are separated by an unlabled - node that does not encode a coordinate. An example of a compact level - can be found in CSR, while the DIA formats range and offset levels are - not compact (they have entries that would coorespond to entries outside - the tensors index range, e.g. column -1). - """ - pass - - @abstractmethod - def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: - """ - Generates the fields needed for the index. - - :return: a Dict of fields that need to be present in the struct - """ - pass - - def to_json(self): - attrs = serialize.all_properties_to_json(self) - - retdict = {"type": type(self).__name__, "attributes": attrs} - - return retdict - - @classmethod - def from_json(cls, json_obj, context=None): - - # Selecting proper subclass - if json_obj['type'] == "TensorIndexDense": - self = TensorIndexDense.__new__(TensorIndexDense) - elif json_obj['type'] == "TensorIndexCompressed": - self = TensorIndexCompressed.__new__(TensorIndexCompressed) - elif json_obj['type'] == "TensorIndexSingleton": - self = TensorIndexSingleton.__new__(TensorIndexSingleton) - elif json_obj['type'] == "TensorIndexRange": - self = TensorIndexRange.__new__(TensorIndexRange) - elif json_obj['type'] == "TensorIndexOffset": - self = TensorIndexOffset.__new__(TensorIndexOffset) - else: - raise TypeError(f"Invalid data type, got: {json_obj['type']}") - - serialize.set_properties_from_json(self, json_obj['attributes'], context=context) - - return self - - -@make_properties -class TensorIndexDense(TensorIndex): - """ - Dense tensor index. - - Levels of this type encode the the coordinate in the interval [0, N), where - N is the size of the corresponding dimension. This level doesn't need any - index structure beyond the corresponding dimension size. - """ - - _ordered = Property(dtype=bool, default=False) - _unique = Property(dtype=bool) - - @property - def iteration_type(self) -> TensorIterationTypes: - return TensorIterationTypes.Value - - @property - def locate(self) -> bool: - return True - - @property - def assembly(self) -> TensorAssemblyType: - return TensorAssemblyType.Insert - - @property - def full(self) -> bool: - return True - - @property - def ordered(self) -> bool: - return self._ordered - - @property - def unique(self) -> bool: - return self._unique - - @property - def branchless(self) -> bool: - return False - - @property - def compact(self) -> bool: - return True - - def __init__(self, ordered: bool = True, unique: bool = True): - self._ordered = ordered - self._unique = unique - - def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: - return {} - - def __repr__(self) -> str: - s = "Dense" - - non_defaults = [] - if not self._ordered: - non_defaults.append("¬O") - if not self._unique: - non_defaults.append("¬U") - - if len(non_defaults) > 0: - s += f"({','.join(non_defaults)})" - - return s - - -@make_properties -class TensorIndexCompressed(TensorIndex): - """ - Tensor level that stores coordinates in segmented array. - - Levels of this type are compressed using a segented array. The pos array - holds the start and end positions of the segment in the crd (coordinate) - array that holds the child coordinates corresponding the parent. - """ - - _full = Property(dtype=bool, default=False) - _ordered = Property(dtype=bool, default=False) - _unique = Property(dtype=bool, default=False) - - @property - def iteration_type(self) -> TensorIterationTypes: - return TensorIterationTypes.Position - - @property - def locate(self) -> bool: - return False - - @property - def assembly(self) -> TensorAssemblyType: - return TensorAssemblyType.Append - - @property - def full(self) -> bool: - return self._full - - @property - def ordered(self) -> bool: - return self._ordered - - @property - def unique(self) -> bool: - return self._unique - - @property - def branchless(self) -> bool: - return False - - @property - def compact(self) -> bool: - return True - - def __init__(self, full: bool = False, ordered: bool = True, unique: bool = True): - self._full = full - self._ordered = ordered - self._unique = unique - - def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: - return { - f"idx{lvl}_pos": dtypes.int32[dummy_symbol], # TODO (later) choose better length - f"idx{lvl}_crd": dtypes.int32[dummy_symbol], # TODO (later) choose better length - } - - def __repr__(self) -> str: - s = "Compressed" - - non_defaults = [] - if self._full: - non_defaults.append("F") - if not self._ordered: - non_defaults.append("¬O") - if not self._unique: - non_defaults.append("¬U") - - if len(non_defaults) > 0: - s += f"({','.join(non_defaults)})" - - return s - - -@make_properties -class TensorIndexSingleton(TensorIndex): - """ - Tensor index that encodes a single coordinate per parent coordinate. - - Levels of this type hold exactly one coordinate for every coordinate in the - parent level. An example can be seen in the COO format, where every - coordinate but the first is encoded in this manner. - """ - - _full = Property(dtype=bool, default=False) - _ordered = Property(dtype=bool, default=False) - _unique = Property(dtype=bool, default=False) - - @property - def iteration_type(self) -> TensorIterationTypes: - return TensorIterationTypes.Position - - @property - def locate(self) -> bool: - return False - - @property - def assembly(self) -> TensorAssemblyType: - return TensorAssemblyType.Append - - @property - def full(self) -> bool: - return self._full - - @property - def ordered(self) -> bool: - return self._ordered - - @property - def unique(self) -> bool: - return self._unique - - @property - def branchless(self) -> bool: - return True - - @property - def compact(self) -> bool: - return True - - def __init__(self, full: bool = False, ordered: bool = True, unique: bool = True): - self._full = full - self._ordered = ordered - self._unique = unique - - def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: - return { - f"idx{lvl}_crd": dtypes.int32[dummy_symbol], # TODO (later) choose better length - } - - def __repr__(self) -> str: - s = "Singleton" - - non_defaults = [] - if self._full: - non_defaults.append("F") - if not self._ordered: - non_defaults.append("¬O") - if not self._unique: - non_defaults.append("¬U") - - if len(non_defaults) > 0: - s += f"({','.join(non_defaults)})" - - return s - - -@make_properties -class TensorIndexRange(TensorIndex): - """ - Tensor index that encodes a interval of coordinates for every parent. - - The interval is computed from an offset for each parent together with the - tensor dimension size of this level (M) and the parent level (N) parents - corresponding tensor. Given the parent coordinate i, the level encodes the - range of coordinates between max(0, -offset[i]) and min(N, M - offset[i]). - """ - - _ordered = Property(dtype=bool, default=False) - _unique = Property(dtype=bool, default=False) - - @property - def iteration_type(self) -> TensorIterationTypes: - return TensorIterationTypes.Value - - @property - def locate(self) -> bool: - return False - - @property - def assembly(self) -> TensorAssemblyType: - return TensorAssemblyType.NoAssembly - - @property - def full(self) -> bool: - return False - - @property - def ordered(self) -> bool: - return self._ordered - - @property - def unique(self) -> bool: - return self._unique - - @property - def branchless(self) -> bool: - return False - - @property - def compact(self) -> bool: - return False - - def __init__(self, ordered: bool = True, unique: bool = True): - self._ordered = ordered - self._unique = unique - - def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: - return { - f"idx{lvl}_offset": dtypes.int32[dummy_symbol], # TODO (later) choose better length - } - - def __repr__(self) -> str: - s = "Range" - - non_defaults = [] - if not self._ordered: - non_defaults.append("¬O") - if not self._unique: - non_defaults.append("¬U") - - if len(non_defaults) > 0: - s += f"({','.join(non_defaults)})" - - return s - - -@make_properties -class TensorIndexOffset(TensorIndex): - """ - Tensor index that encodes the next coordinates as offset from parent. - - Given a parent coordinate i and an offset index k, the level encodes the - coordinate j = i + offset[k]. - """ - - _ordered = Property(dtype=bool, default=False) - _unique = Property(dtype=bool, default=False) - - @property - def iteration_type(self) -> TensorIterationTypes: - return TensorIterationTypes.Position - - @property - def locate(self) -> bool: - return False - - @property - def assembly(self) -> TensorAssemblyType: - return TensorAssemblyType.NoAssembly - - @property - def full(self) -> bool: - return False - - @property - def ordered(self) -> bool: - return self._ordered - - @property - def unique(self) -> bool: - return self._unique - - @property - def branchless(self) -> bool: - return True - - @property - def compact(self) -> bool: - return False - - def __init__(self, ordered: bool = True, unique: bool = True): - self._ordered = ordered - self._unique = unique - - def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: - return { - f"idx{lvl}_offset": dtypes.int32[dummy_symbol], # TODO (later) choose better length - } - - def __repr__(self) -> str: - s = "Offset" - - non_defaults = [] - if not self._ordered: - non_defaults.append("¬O") - if not self._unique: - non_defaults.append("¬U") - - if len(non_defaults) > 0: - s += f"({','.join(non_defaults)})" - - return s - - -@make_properties -class Tensor(Structure): - """ - Abstraction for Tensor storage format. - - This abstraction is based on [https://doi.org/10.1145/3276493]. - """ - - value_dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses) - tensor_shape = ShapeProperty(default=[]) - indices = ListProperty(element_type=TensorIndex) - index_ordering = ListProperty(element_type=symbolic.SymExpr) - value_count = SymbolicProperty(default=0) - - def __init__(self, - value_dtype: dtypes.Typeclasses, - tensor_shape, - indices: List[Tuple[TensorIndex, Union[int, symbolic.SymExpr]]], - value_count: symbolic.SymExpr, - name: str, - transient: bool = False, - storage: dtypes.StorageType = dtypes.StorageType.Default, - location: Dict[str, str] = None, - lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope, - debuginfo: dtypes.DebugInfo = None): - """ - Constructor for Tensor storage format. - - Below are examples of common matrix storage formats: - - .. code-block:: python - - M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) - - csr = dace.data.Tensor( - dace.float32, - (M, N), - [(dace.data.Dense(), 0), (dace.data.Compressed(), 1)], - nnz, - "CSR_Matrix", - ) - - csc = dace.data.Tensor( - dace.float32, - (M, N), - [(dace.data.Dense(), 1), (dace.data.Compressed(), 0)], - nnz, - "CSC_Matrix", - ) - - coo = dace.data.Tensor( - dace.float32, - (M, N), - [ - (dace.data.Compressed(unique=False), 0), - (dace.data.Singleton(), 1), - ], - nnz, - "CSC_Matrix", - ) - - num_diags = dace.symbol('num_diags') # number of diagonals stored - - diag = dace.data.Tensor( - dace.float32, - (M, N), - [ - (dace.data.Dense(), num_diags), - (dace.data.Range(), 0), - (dace.data.Offset(), 1), - ], - nnz, - "DIA_Matrix", - ) - - Below you can find examples of common 3rd order tensor storage formats: - - .. code-block:: python - - I, J, K, nnz = (dace.symbol(s) for s in ('I', 'J', 'K', 'nnz')) - - coo = dace.data.Tensor( - dace.float32, - (I, J, K), - [ - (dace.data.Compressed(unique=False), 0), - (dace.data.Singleton(unique=False), 1), - (dace.data.Singleton(), 2), - ], - nnz, - "COO_3D_Tensor", - ) - - csf = dace.data.Tensor( - dace.float32, - (I, J, K), - [ - (dace.data.Compressed(), 0), - (dace.data.Compressed(), 1), - (dace.data.Compressed(), 2), - ], - nnz, - "CSF_3D_Tensor", - ) - - :param value_type: data type of the explicitly stored values. - :param tensor_shape: logical shape of tensor (#rows, #cols, etc...) - :param indices: - a list of tuples, each tuple represents a level in the tensor - storage hirachy, specifying the levels tensor index type, and the - corresponding dimension this level encodes (as index of the - tensor_shape tuple above). The order of the dimensions may differ - from the logical shape of the tensor, e.g. as seen in the CSC - format. If an index's dimension is unrelated to the tensor shape - (e.g. in diagonal format where the first index's dimension is the - number of diagonals stored), a symbol can be specified instead. - :param value_count: number of explicitly stored values. - :param name: name of resulting struct. - :param others: See Structure class for remaining arguments - """ - - self.value_dtype = value_dtype - self.tensor_shape = tensor_shape - self.value_count = value_count - - indices, index_ordering = zip(*indices) - self.indices, self.index_ordering = list(indices), list(index_ordering) - - num_dims = len(tensor_shape) - dimension_order = [idx for idx in self.index_ordering if isinstance(idx, int)] - - # all tensor dimensions must occure exactly once in indices - if not sorted(dimension_order) == list(range(num_dims)): - raise TypeError((f"All tensor dimensions must be refferenced exactly once in " - f"tensor indices. (referenced dimensions: {dimension_order}; " - f"tensor dimensions: {list(range(num_dims))})")) - - # assembling permanent and index specific fields - fields = dict( - order=Scalar(dtypes.int32), - dim_sizes=dtypes.int32[num_dims], - value_count=value_count, - values=dtypes.float32[value_count], - ) - - for (lvl, index) in enumerate(indices): - fields.update(index.fields(lvl, value_count)) - - super(Tensor, self).__init__(fields, name, transient, storage, location, lifetime, debuginfo) - - def __repr__(self): - return f"{self.name} (dtype: {self.value_dtype}, shape: {list(self.tensor_shape)}, indices: {self.indices})" - - @staticmethod - def from_json(json_obj, context=None): - if json_obj['type'] != 'Tensor': - raise TypeError("Invalid data type") + def __matmul__(self, storage: dtypes.StorageType): + """ + Syntactic sugar for specifying the storage of a data descriptor. + This enables controlling the storage location as follows: - # Create dummy object - tensor = Tensor.__new__(Tensor) - serialize.set_properties_from_json(tensor, json_obj, context=context) + .. code-block:: python - return tensor + @dace + def add(X: dace.float32[10, 10] @ dace.StorageType.GPU_Global): + return X + 1 + """ + new_desc = cp.deepcopy(self) + new_desc.storage = storage + return new_desc @make_properties @@ -1727,6 +714,60 @@ def is_packed_c_strides(self) -> bool: return tuple(strides) == tuple(self.strides) +@make_properties +class ContainerArray(Array): + """ An array that may contain other data containers (e.g., Structures, other arrays). """ + + stype = NestedDataClassProperty(allow_none=True, default=None) + + def __init__(self, + stype: Data, + shape, + transient=False, + allow_conflicts=False, + storage=dtypes.StorageType.Default, + location=None, + strides=None, + offset=None, + may_alias=False, + lifetime=dtypes.AllocationLifetime.Scope, + alignment=0, + debuginfo=None, + total_size=None, + start_offset=None, + optional=None, + pool=False): + + self.stype = stype + if stype: + if isinstance(stype, Structure): + dtype = stype.dtype + else: + dtype = dtypes.pointer(stype.dtype) + else: + dtype = dtypes.pointer(dtypes.typeclass(None)) # void* + super(ContainerArray, + self).__init__(dtype, shape, transient, allow_conflicts, storage, location, strides, offset, may_alias, + lifetime, alignment, debuginfo, total_size, start_offset, optional, pool) + + @classmethod + def from_json(cls, json_obj, context=None): + # Create dummy object + ret = cls(None, ()) + serialize.set_properties_from_json(ret, json_obj, context=context) + + # Default shape-related properties + if not ret.offset: + ret.offset = [0] * len(ret.shape) + if not ret.strides: + # Default strides are C-ordered + ret.strides = [_prod(ret.shape[i + 1:]) for i in range(len(ret.shape))] + if ret.total_size == 0: + ret.total_size = _prod(ret.shape) + + return ret + + @make_properties class Stream(Data): """ Stream (or stream array) data descriptor. """ @@ -1773,174 +814,319 @@ def from_json(cls, json_obj, context=None): ret = cls(dtypes.int8, 1) serialize.set_properties_from_json(ret, json_obj, context=context) - return ret + return ret + + def __repr__(self): + return '%s (dtype=%s, shape=%s)' % (type(self).__name__, self.dtype, self.shape) + + @property + def total_size(self): + return _prod(self.shape) + + @property + def strides(self): + return [_prod(self.shape[i + 1:]) for i in range(len(self.shape))] + + @property + def start_offset(self): + return 0 + + @property + def optional(self) -> bool: + return False + + @property + def may_alias(self) -> bool: + return False + + def clone(self): + return type(self)(self.dtype, self.buffer_size, self.shape, self.transient, self.storage, self.location, + self.offset, self.lifetime, self.debuginfo) + + # Checks for equivalent shape and type + def is_equivalent(self, other): + if not isinstance(other, type(self)): + return False + + # Test type + if self.dtype != other.dtype: + return False + + # Test dimensionality + if len(self.shape) != len(other.shape): + return False + + # Test shape + for dim, otherdim in zip(self.shape, other.shape): + if dim != otherdim: + return False + return True + + def as_arg(self, with_types=True, for_call=False, name=None): + if not with_types or for_call: return name + if self.storage in [dtypes.StorageType.GPU_Global, dtypes.StorageType.GPU_Shared]: + return 'dace::GPUStream<%s, %s> %s' % (str( + self.dtype.ctype), 'true' if sp.log(self.buffer_size, 2).is_Integer else 'false', name) + + return 'dace::Stream<%s> %s' % (str(self.dtype.ctype), name) + + def sizes(self): + return [d.name if isinstance(d, symbolic.symbol) else str(d) for d in self.shape] + + def is_stream_array(self): + return _prod(self.shape) != 1 + + def covers_range(self, rng): + if len(rng) != len(self.shape): + return False + + for s, (rb, re, rs) in zip(self.shape, rng): + # Shape has to be positive + if isinstance(s, sp.Basic): + olds = s + if 'positive' in s.assumptions0: + s = sp.Symbol(str(s), **s.assumptions0) + else: + s = sp.Symbol(str(s), positive=True, **s.assumptions0) + if isinstance(rb, sp.Basic): + rb = rb.subs({olds: s}) + if isinstance(re, sp.Basic): + re = re.subs({olds: s}) + if isinstance(rs, sp.Basic): + rs = rs.subs({olds: s}) + + try: + if rb < 0: # Negative offset + return False + except TypeError: # cannot determine truth value of Relational + pass + #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (rb > 0), + # 'If this expression is false, please refine symbol definitions in the program.') + try: + if re > s: # Beyond shape + return False + except TypeError: # cannot determine truth value of Relational + pass + #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (re < s), + # 'If this expression is false, please refine symbol definitions in the program.') + + return True + + def used_symbols(self, all_symbols: bool) -> Set[symbolic.SymbolicType]: + result = super().used_symbols(all_symbols) + if (self.transient or all_symbols) and isinstance(self.buffer_size, sp.Expr): + result |= set(self.buffer_size.free_symbols) + for o in self.offset: + if isinstance(o, sp.Expr): + result |= set(o.free_symbols) + + return result + + @property + def free_symbols(self): + return self.used_symbols(all_symbols=True) + + +@make_properties +class Structure(Data): + """ Base class for structures. """ + + members = OrderedDictProperty(default=OrderedDict(), + desc="Dictionary of structure members", + from_json=_arrays_from_json, + to_json=_arrays_to_json) + name = Property(dtype=str, desc="Structure type name") + + def __init__(self, + members: Union[Dict[str, Data], List[Tuple[str, Data]]], + name: str = 'Structure', + transient: bool = False, + storage: dtypes.StorageType = dtypes.StorageType.Default, + location: Dict[str, str] = None, + lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope, + debuginfo: dtypes.DebugInfo = None): + + self.members = OrderedDict(members) + for k, v in self.members.items(): + if isinstance(v, dtypes.typeclass): + v = Scalar(v) + self.members[k] = v + v.transient = transient + + self.name = name + fields_and_types = OrderedDict() + symbols = set() + for k, v in self.members.items(): + if isinstance(v, Structure): + symbols |= v.free_symbols + fields_and_types[k] = (v.dtype, str(v.total_size)) + elif isinstance(v, Array): + symbols |= v.free_symbols + fields_and_types[k] = (dtypes.pointer(v.dtype), str(_prod(v.shape))) + elif isinstance(v, Scalar): + symbols |= v.free_symbols + fields_and_types[k] = v.dtype + elif isinstance(v, dtypes.typeclass): + fields_and_types[k] = v + elif isinstance(v, (sp.Basic, symbolic.SymExpr)): + symbols |= v.free_symbols + fields_and_types[k] = symbolic.symtype(v) + elif isinstance(v, (int, np.integer)): + fields_and_types[k] = dtypes.typeclass(type(v)) + else: + raise TypeError(f"Attribute {k}'s value {v} has unsupported type: {type(v)}") + + # NOTE: We will not store symbols in the dtype for now, but leaving it as a comment to investigate later. + # NOTE: See discussion about data/object symbols. + # for s in symbols: + # if str(s) in fields_and_types: + # continue + # if hasattr(s, "dtype"): + # fields_and_types[str(s)] = s.dtype + # else: + # fields_and_types[str(s)] = dtypes.int32 + + dtype = dtypes.pointer(dtypes.struct(name, **fields_and_types)) + dtype.base_type.__descriptor__ = self + shape = (1, ) + super(Structure, self).__init__(dtype, shape, transient, storage, location, lifetime, debuginfo) + + @staticmethod + def from_json(json_obj, context=None): + if json_obj['type'] != 'Structure': + raise TypeError("Invalid data type") + + # Create dummy object + ret = Structure({}) + serialize.set_properties_from_json(ret, json_obj, context=context) + + return ret + + @staticmethod + def from_dataclass(cls, **overrides) -> 'Structure': + """ + Creates a Structure data descriptor from a dataclass instance. + + :param cls: The dataclass to convert. + :param overrides: Optional overrides for the structure fields. + :return: A Structure data descriptor. + """ + members = {} + for field in dataclasses.fields(cls): + # Recursive structures + if dataclasses.is_dataclass(field.type): + members[field.name] = Structure.from_dataclass(field.type) + continue + members[field.name] = field.type - def __repr__(self): - return '%s (dtype=%s, shape=%s)' % (type(self).__name__, self.dtype, self.shape) + members.update(overrides) + return Structure(members, name=cls.__name__) @property def total_size(self): - return _prod(self.shape) + return -1 @property - def strides(self): - return [_prod(self.shape[i + 1:]) for i in range(len(self.shape))] + def offset(self): + return [0] @property def start_offset(self): return 0 @property - def optional(self) -> bool: - return False + def strides(self): + return [1] @property - def may_alias(self) -> bool: - return False - - def clone(self): - return type(self)(self.dtype, self.buffer_size, self.shape, self.transient, self.storage, self.location, - self.offset, self.lifetime, self.debuginfo) - - # Checks for equivalent shape and type - def is_equivalent(self, other): - if not isinstance(other, type(self)): - return False - - # Test type - if self.dtype != other.dtype: - return False - - # Test dimensionality - if len(self.shape) != len(other.shape): - return False + def free_symbols(self) -> Set[symbolic.SymbolicType]: + """ Returns a set of undefined symbols in this data descriptor. """ + result = set() + for k, v in self.members.items(): + result |= v.free_symbols + return result - # Test shape - for dim, otherdim in zip(self.shape, other.shape): - if dim != otherdim: - return False - return True + def __repr__(self): + return f"{self.name} ({', '.join([f'{k}: {v}' for k, v in self.members.items()])})" def as_arg(self, with_types=True, for_call=False, name=None): - if not with_types or for_call: return name - if self.storage in [dtypes.StorageType.GPU_Global, dtypes.StorageType.GPU_Shared]: - return 'dace::GPUStream<%s, %s> %s' % (str( - self.dtype.ctype), 'true' if sp.log(self.buffer_size, 2).is_Integer else 'false', name) - - return 'dace::Stream<%s> %s' % (str(self.dtype.ctype), name) - - def sizes(self): - return [d.name if isinstance(d, symbolic.symbol) else str(d) for d in self.shape] - - def size_string(self): - return (" * ".join([cppunparse.pyexpr2cpp(symbolic.symstr(s, cpp_mode=True)) for s in self.shape])) - - def is_stream_array(self): - return _prod(self.shape) != 1 - - def covers_range(self, rng): - if len(rng) != len(self.shape): - return False - - for s, (rb, re, rs) in zip(self.shape, rng): - # Shape has to be positive - if isinstance(s, sp.Basic): - olds = s - if 'positive' in s.assumptions0: - s = sp.Symbol(str(s), **s.assumptions0) - else: - s = sp.Symbol(str(s), positive=True, **s.assumptions0) - if isinstance(rb, sp.Basic): - rb = rb.subs({olds: s}) - if isinstance(re, sp.Basic): - re = re.subs({olds: s}) - if isinstance(rs, sp.Basic): - rs = rs.subs({olds: s}) + if self.storage is dtypes.StorageType.GPU_Global: + return Array(self.dtype, [1]).as_arg(with_types, for_call, name) + if not with_types or for_call: + return name + return self.dtype.as_arg(name) - try: - if rb < 0: # Negative offset - return False - except TypeError: # cannot determine truth value of Relational - pass - #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (rb > 0), - # 'If this expression is false, please refine symbol definitions in the program.') - try: - if re > s: # Beyond shape - return False - except TypeError: # cannot determine truth value of Relational - pass - #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (re < s), - # 'If this expression is false, please refine symbol definitions in the program.') + def __getitem__(self, s): + """ This is syntactic sugar that allows us to define an array type + with the following syntax: ``Structure[N,M]`` + :return: A ``data.ContainerArray`` data descriptor. + """ + if isinstance(s, list) or isinstance(s, tuple): + return ContainerArray(self, tuple(s)) + return ContainerArray(self, (s, )) - return True + # NOTE: Like Scalars? + @property + def may_alias(self) -> bool: + return False - def used_symbols(self, all_symbols: bool) -> Set[symbolic.SymbolicType]: - result = super().used_symbols(all_symbols) - if (self.transient or all_symbols) and isinstance(self.buffer_size, sp.Expr): - result |= set(self.buffer_size.free_symbols) - for o in self.offset: - if isinstance(o, sp.Expr): - result |= set(o.free_symbols) + # TODO: Can Structures be optional? + @property + def optional(self) -> bool: + return False + def keys(self): + result = self.members.keys() + for k, v in self.members.items(): + if isinstance(v, Structure): + result |= set(map(lambda x: f"{k}.{x}", v.keys())) return result - @property - def free_symbols(self): - return self.used_symbols(all_symbols=True) + def clone(self): + return Structure(self.members, self.name, self.transient, self.storage, self.location, self.lifetime, + self.debuginfo) + # NOTE: Like scalars? + @property + def pool(self) -> bool: + return False -@make_properties -class ContainerArray(Array): - """ An array that may contain other data containers (e.g., Structures, other arrays). """ + def make_argument(self, **fields) -> ctypes.Structure: + """ + Creates a structure instance from the given field values, which can be used as + an argument for DaCe programs. - stype = NestedDataClassProperty(allow_none=True, default=None) + :param fields: Dictionary of field names to values. + :return: A ctypes Structure instance. + """ + # Import here to avoid circular import + from dace.data.ctypes_interop import make_ctypes_argument + struct_type: dtypes.struct = self.dtype.base_type + struct_ctype = struct_type.as_ctypes() - def __init__(self, - stype: Data, - shape, - transient=False, - allow_conflicts=False, - storage=dtypes.StorageType.Default, - location=None, - strides=None, - offset=None, - may_alias=False, - lifetime=dtypes.AllocationLifetime.Scope, - alignment=0, - debuginfo=None, - total_size=None, - start_offset=None, - optional=None, - pool=False): + def _make_arg(arg: Any, expected_type: Data, name: str) -> Any: + if isinstance(expected_type, Structure): + return ctypes.pointer(expected_type.make_argument_from_object(arg)) + return make_ctypes_argument(arg, expected_type, name) - self.stype = stype - if stype: - if isinstance(stype, Structure): - dtype = stype.dtype - else: - dtype = dtypes.pointer(stype.dtype) - else: - dtype = dtypes.pointer(dtypes.typeclass(None)) # void* - super(ContainerArray, - self).__init__(dtype, shape, transient, allow_conflicts, storage, location, strides, offset, may_alias, - lifetime, alignment, debuginfo, total_size, start_offset, optional, pool) + args = { + field_name: _make_arg(field_value, self.members[field_name], field_name) + for field_name, field_value in fields.items() if field_name in self.members + } - @classmethod - def from_json(cls, json_obj, context=None): - # Create dummy object - ret = cls(None, ()) - serialize.set_properties_from_json(ret, json_obj, context=context) + struct_instance = struct_ctype(**args) + return struct_instance - # Default shape-related properties - if not ret.offset: - ret.offset = [0] * len(ret.shape) - if not ret.strides: - # Default strides are C-ordered - ret.strides = [_prod(ret.shape[i + 1:]) for i in range(len(ret.shape))] - if ret.total_size == 0: - ret.total_size = _prod(ret.shape) + def make_argument_from_object(self, obj) -> ctypes.Structure: + """ + Creates a structure instance from the given object, which can be used as + an argument for DaCe programs. If the object has attributes matching the field names, + those attributes are used as field values. Other attributes are ignored. - return ret + :param obj: Object containing field values. + :return: A ctypes Structure instance. + """ + return self.make_argument(**{field_name: getattr(obj, field_name) for field_name in self.members}) class View: @@ -2245,230 +1431,3 @@ def as_array(self): copy = cp.deepcopy(self) copy.__class__ = ContainerArray return copy - - -def make_array_from_descriptor(descriptor: Array, - original_array: Optional[ArrayLike] = None, - symbols: Optional[Dict[str, Any]] = None) -> ArrayLike: - """ - Creates an array that matches the given data descriptor, and optionally copies another array to it. - - :param descriptor: The data descriptor to create the array from. - :param original_array: An optional array to fill the content of the return value with. - :param symbols: An optional symbol mapping between symbol names and their values. Used for creating arrays - with symbolic sizes. - :return: A NumPy-compatible array (CuPy for GPU storage) with the specified size and strides. - """ - symbols = symbols or {} - - free_syms = set(map(str, descriptor.free_symbols)) - symbols.keys() - if free_syms: - raise NotImplementedError(f'Cannot make Python references to arrays with undefined symbolic sizes: {free_syms}') - - if descriptor.storage == dtypes.StorageType.GPU_Global: - try: - import cupy as cp - except (ImportError, ModuleNotFoundError): - raise NotImplementedError('GPU memory can only be allocated in Python if cupy is installed') - - def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: - buffer = cp.ndarray(shape=[total_size], dtype=dtype) - view = cp.ndarray(shape=shape, - dtype=dtype, - memptr=buffer.data, - strides=[s * dtype.itemsize for s in strides]) - return view - - def copy_array(dst, src): - dst[:] = cp.asarray(src) - - elif descriptor.storage == dtypes.StorageType.FPGA_Global: - raise TypeError('Cannot allocate FPGA array in Python') - else: - - def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: - buffer = np.ndarray([total_size], dtype=dtype) - view = np.ndarray(shape, dtype, buffer=buffer, strides=[s * dtype.itemsize for s in strides]) - return view - - def copy_array(dst, src): - dst[:] = src - - # Make numpy array from data descriptor - npdtype = descriptor.dtype.as_numpy_dtype() - evaluated_shape = tuple(symbolic.evaluate(s, symbols) for s in descriptor.shape) - evaluated_size = symbolic.evaluate(descriptor.total_size, symbols) - evaluated_strides = tuple(symbolic.evaluate(s, symbols) for s in descriptor.strides) - view = create_array(evaluated_shape, npdtype, evaluated_size, evaluated_strides) - if original_array is not None: - copy_array(view, original_array) - - return view - - -def make_reference_from_descriptor(descriptor: Array, - original_array: ctypes.c_void_p, - symbols: Optional[Dict[str, Any]] = None) -> ArrayLike: - """ - Creates an array that matches the given data descriptor from the given pointer. Shares the memory - with the argument (does not create a copy). - - :param descriptor: The data descriptor to create the array from. - :param original_array: The array whose memory the return value would be used in. - :param symbols: An optional symbol mapping between symbol names and their values. Used for referencing arrays - with symbolic sizes. - :return: A NumPy-compatible array (CuPy for GPU storage) with the specified size and strides, sharing memory - with the pointer specified in ``original_array``. - """ - symbols = symbols or {} - - original_array: int = ctypes.cast(original_array, ctypes.c_void_p).value - - free_syms = set(map(str, descriptor.free_symbols)) - symbols.keys() - if free_syms: - raise NotImplementedError(f'Cannot make Python references to arrays with undefined symbolic sizes: {free_syms}') - - if descriptor.storage == dtypes.StorageType.GPU_Global: - try: - import cupy as cp - except (ImportError, ModuleNotFoundError): - raise NotImplementedError('GPU memory can only be referenced in Python if cupy is installed') - - def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: - buffer = dtypes.ptrtocupy(original_array, descriptor.dtype.as_ctypes(), (total_size, )) - view = cp.ndarray(shape=shape, - dtype=dtype, - memptr=buffer.data, - strides=[s * dtype.itemsize for s in strides]) - return view - - elif descriptor.storage == dtypes.StorageType.FPGA_Global: - raise TypeError('Cannot reference FPGA array in Python') - else: - - def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: - buffer = dtypes.ptrtonumpy(original_array, descriptor.dtype.as_ctypes(), (total_size, )) - view = np.ndarray(shape, dtype, buffer=buffer, strides=[s * dtype.itemsize for s in strides]) - return view - - # Make numpy array from data descriptor - npdtype = descriptor.dtype.as_numpy_dtype() - evaluated_shape = tuple(symbolic.evaluate(s, symbols) for s in descriptor.shape) - evaluated_size = symbolic.evaluate(descriptor.total_size, symbols) - evaluated_strides = tuple(symbolic.evaluate(s, symbols) for s in descriptor.strides) - return create_array(evaluated_shape, npdtype, evaluated_size, evaluated_strides) - - -def make_ctypes_argument(arg: Any, - argtype: Data, - name: Optional[str] = None, - allow_views: Optional[bool] = None, - symbols: Optional[Dict[str, Any]] = None, - callback_retval_references: Optional[List[Any]] = None) -> Any: - """ - Converts a given argument to the expected ``ctypes`` type for passing to compiled SDFG functions. - - :param arg: The argument to convert. - :param argtype: The expected data descriptor type of the argument. - :param name: The name of the argument (for error messages). - :param allow_views: Whether to allow views and references as input. If False, raises an error if a view or - reference is passed. If None (default), uses the global configuration setting - ``compiler.allow_view_arguments``. - :param symbols: An optional symbol mapping between symbol names and their values. Used for evaluating symbolic - sizes in callback arguments. - :param callback_retval_references: A list to store references to callback return values (to avoid garbage - collection of said return values). This object must be kept alive until the - SDFG call is complete. - :return: The argument converted to the appropriate ctypes type. - """ - if allow_views is None: - no_view_arguments = not config.Config.get_bool('compiler', 'allow_view_arguments') - else: - no_view_arguments = not allow_views - a = name or '' - atype = argtype - - result = arg - is_array = dtypes.is_array(arg) - is_ndarray = isinstance(arg, np.ndarray) - is_dtArray = isinstance(argtype, Array) - if not is_array and is_dtArray: - if isinstance(arg, list): - print(f'WARNING: Casting list argument "{a}" to ndarray') - elif arg is None: - if atype.optional is False: # If array cannot be None - raise TypeError(f'Passing a None value to a non-optional array in argument "{a}"') - # Otherwise, None values are passed as null pointers below - elif isinstance(arg, ctypes._Pointer): - pass - elif isinstance(arg, str): - # Cast to bytes - result = ctypes.c_char_p(arg.encode('utf-8')) - else: - raise TypeError(f'Passing an object (type {type(arg).__name__}) to an array in argument "{a}"') - elif is_array and not is_dtArray: - # GPU scalars and return values are pointers, so this is fine - if atype.storage != dtypes.StorageType.GPU_Global and not a.startswith('__return'): - raise TypeError(f'Passing an array to a scalar (type {atype.dtype.ctype}) in argument "{a}"') - elif (is_dtArray and is_ndarray and not isinstance(atype, ContainerArray) - and atype.dtype.as_numpy_dtype() != arg.dtype): - # Make exception for vector types - if (isinstance(atype.dtype, dtypes.vector) and atype.dtype.vtype.as_numpy_dtype() == arg.dtype): - pass - else: - print(f'WARNING: Passing {arg.dtype} array argument "{a}" to a {atype.dtype.type.__name__} array') - elif is_dtArray and is_ndarray and arg.base is not None and not '__return' in a and no_view_arguments: - raise TypeError(f'Passing a numpy view (e.g., sub-array or "A.T") "{a}" to DaCe ' - 'programs is not allowed in order to retain analyzability. ' - 'Please make a copy with "numpy.copy(...)". If you know what ' - 'you are doing, you can override this error in the ' - 'configuration by setting compiler.allow_view_arguments ' - 'to True.') - elif (not isinstance(atype, (Array, Structure)) and not isinstance(atype.dtype, dtypes.callback) - and not isinstance(arg, (atype.dtype.type, sp.Basic)) - and not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)): - is_int = isinstance(arg, int) - if is_int and atype.dtype.type == np.int64: - pass - elif (is_int and atype.dtype.type == np.int32 and abs(arg) <= (1 << 31) - 1): - pass - elif (is_int and atype.dtype.type == np.uint32 and arg >= 0 and arg <= (1 << 32) - 1): - pass - elif isinstance(arg, float) and atype.dtype.type == np.float64: - pass - elif isinstance(arg, bool) and atype.dtype.type == np.bool_: - pass - elif (isinstance(arg, str) or arg is None) and atype.dtype == dtypes.string: - if arg is None: - result = ctypes.c_char_p(None) - else: - # Cast to bytes - result = ctypes.c_char_p(arg.encode('utf-8')) - else: - warnings.warn(f'Casting scalar argument "{a}" from {type(arg).__name__} to {atype.dtype.type}') - result = atype.dtype.type(arg) - - # Call a wrapper function to make NumPy arrays from pointers. - if isinstance(argtype.dtype, dtypes.callback): - result = argtype.dtype.get_trampoline(result, symbols or {}, callback_retval_references) - # List to array - elif isinstance(result, list) and isinstance(argtype, Array): - result = np.array(result, dtype=argtype.dtype.type) - # Null pointer - elif result is None and isinstance(argtype, Array): - result = ctypes.c_void_p(0) - - # Retain only the element datatype for upcoming checks and casts - actype = argtype.dtype.as_ctypes() - - try: - if dtypes.is_array(result): # `c_void_p` is subclass of `ctypes._SimpleCData`. - result = ctypes.c_void_p(dtypes.array_interface_ptr(result, atype.storage)) - elif not isinstance(result, (ctypes._SimpleCData, ctypes._Pointer)): - result = actype(result) - else: - pass - except TypeError as ex: - raise TypeError(f'Invalid type for scalar argument "{a}": {ex}') - - return result diff --git a/dace/data/creation.py b/dace/data/creation.py new file mode 100644 index 0000000000..8dec40156a --- /dev/null +++ b/dace/data/creation.py @@ -0,0 +1,239 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Data descriptor creation functions. + +This module contains functions for creating data descriptors from arbitrary objects, +as well as functions for creating arrays from descriptors. +""" +import ctypes + +from numbers import Number +from typing import Any, Dict, Optional, Tuple + +import numpy as np + +try: + from numpy.typing import ArrayLike +except (ModuleNotFoundError, ImportError): + ArrayLike = Any + +from dace import dtypes, symbolic +from dace.data.core import Array, Data, Scalar + + +def create_datadescriptor(obj, no_custom_desc=False): + """ Creates a data descriptor from various types of objects. + + :see: dace.data.Data + """ + if isinstance(obj, Data): + return obj + elif not no_custom_desc and hasattr(obj, '__descriptor__'): + return obj.__descriptor__() + elif not no_custom_desc and hasattr(obj, 'descriptor'): + return obj.descriptor + elif type(obj).__module__ == "torch" and type(obj).__name__ == "Tensor": + # special case for torch tensors. Maybe __array__ could be used here for a more + # general solution, but torch doesn't support __array__ for cuda tensors. + try: + # If torch is importable, define translations between typeclasses and torch types. These are reused by daceml. + # conversion happens here in pytorch: + # https://github.com/pytorch/pytorch/blob/143ef016ee1b6a39cf69140230d7c371de421186/torch/csrc/utils/tensor_numpy.cpp#L237 + import torch + TYPECLASS_TO_TORCH_DTYPE = { + dtypes.bool_: torch.bool, + dtypes.int8: torch.int8, + dtypes.int16: torch.int16, + dtypes.int32: torch.int32, + dtypes.int64: torch.int64, + dtypes.uint8: torch.uint8, + dtypes.float16: torch.float16, + dtypes.float32: torch.float32, + dtypes.float64: torch.float64, + dtypes.complex64: torch.complex64, + dtypes.complex128: torch.complex128, + } + + TORCH_DTYPE_TO_TYPECLASS = {v: k for k, v in TYPECLASS_TO_TORCH_DTYPE.items()} + + storage = dtypes.StorageType.GPU_Global if obj.device.type == 'cuda' else dtypes.StorageType.Default + + return Array(dtype=TORCH_DTYPE_TO_TYPECLASS[obj.dtype], + strides=obj.stride(), + shape=tuple(obj.shape), + storage=storage) + except ImportError: + raise ValueError("Attempted to convert a torch.Tensor, but torch could not be imported") + elif dtypes.is_array(obj) and (hasattr(obj, '__array_interface__') or hasattr(obj, '__cuda_array_interface__')): + if dtypes.is_gpu_array(obj): + interface = obj.__cuda_array_interface__ + storage = dtypes.StorageType.GPU_Global + else: + interface = obj.__array_interface__ + storage = dtypes.StorageType.Default + + if hasattr(obj, 'dtype') and obj.dtype.fields is not None: # Struct + dtype = dtypes.struct('unnamed', **{k: dtypes.typeclass(v[0].type) for k, v in obj.dtype.fields.items()}) + else: + if np.dtype(interface['typestr']).type is np.void: # Struct from __array_interface__ + if 'descr' in interface: + dtype = dtypes.struct('unnamed', **{ + k: dtypes.typeclass(np.dtype(v).type) + for k, v in interface['descr'] + }) + else: + raise TypeError(f'Cannot infer data type of array interface object "{interface}"') + else: + dtype = dtypes.typeclass(np.dtype(interface['typestr']).type) + itemsize = np.dtype(interface['typestr']).itemsize + if len(interface['shape']) == 0: + return Scalar(dtype, storage=storage) + return Array(dtype=dtype, + shape=interface['shape'], + strides=(tuple(s // itemsize for s in interface['strides']) if interface['strides'] else None), + storage=storage) + elif isinstance(obj, (list, tuple)): + # Lists and tuples are cast to numpy + obj = np.array(obj) + + if obj.dtype.fields is not None: # Struct + dtype = dtypes.struct('unnamed', **{k: dtypes.typeclass(v[0].type) for k, v in obj.dtype.fields.items()}) + else: + dtype = dtypes.typeclass(obj.dtype.type) + return Array(dtype=dtype, strides=tuple(s // obj.itemsize for s in obj.strides), shape=obj.shape) + elif type(obj).__module__ == "cupy" and type(obj).__name__ == "ndarray": + # special case for CuPy and HIP, which does not support __cuda_array_interface__ + storage = dtypes.StorageType.GPU_Global + dtype = dtypes.typeclass(obj.dtype.type) + itemsize = obj.itemsize + return Array(dtype=dtype, shape=obj.shape, strides=tuple(s // itemsize for s in obj.strides), storage=storage) + elif symbolic.issymbolic(obj): + return Scalar(symbolic.symtype(obj)) + elif isinstance(obj, dtypes.typeclass): + return Scalar(obj) + elif (obj is int or obj is float or obj is complex or obj is bool or obj is None): + return Scalar(dtypes.typeclass(obj)) + elif isinstance(obj, type) and issubclass(obj, np.number): + return Scalar(dtypes.typeclass(obj)) + elif isinstance(obj, (Number, np.number, np.bool_)): + return Scalar(dtypes.typeclass(type(obj))) + elif obj is type(None): + # NoneType is void * + return Scalar(dtypes.pointer(dtypes.typeclass(None))) + elif isinstance(obj, str) or obj is str: + return Scalar(dtypes.string) + elif callable(obj): + # Cannot determine return value/argument types from function object + return Scalar(dtypes.callback(None)) + + raise TypeError(f'Could not create a DaCe data descriptor from object {obj}. ' + 'If this is a custom object, consider creating a `__descriptor__` ' + 'adaptor method to the type hint or object itself.') + + +def make_array_from_descriptor(descriptor: Array, + original_array: Optional[ArrayLike] = None, + symbols: Optional[Dict[str, Any]] = None) -> ArrayLike: + """ + Creates an array that matches the given data descriptor, and optionally copies another array to it. + + :param descriptor: The data descriptor to create the array from. + :param original_array: An optional array to fill the content of the return value with. + :param symbols: An optional symbol mapping between symbol names and their values. Used for creating arrays + with symbolic sizes. + :return: A NumPy-compatible array (CuPy for GPU storage) with the specified size and strides. + """ + symbols = symbols or {} + + free_syms = set(map(str, descriptor.free_symbols)) - symbols.keys() + if free_syms: + raise NotImplementedError(f'Cannot make Python references to arrays with undefined symbolic sizes: {free_syms}') + + if descriptor.storage == dtypes.StorageType.GPU_Global: + try: + import cupy as cp + except (ImportError, ModuleNotFoundError): + raise NotImplementedError('GPU memory can only be allocated in Python if cupy is installed') + + def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: + buffer = cp.ndarray(shape=[total_size], dtype=dtype) + view = cp.ndarray(shape=shape, + dtype=dtype, + memptr=buffer.data, + strides=[s * dtype.itemsize for s in strides]) + return view + + def copy_array(dst, src): + dst[:] = cp.asarray(src) + + else: + + def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: + buffer = np.ndarray([total_size], dtype=dtype) + view = np.ndarray(shape, dtype, buffer=buffer, strides=[s * dtype.itemsize for s in strides]) + return view + + def copy_array(dst, src): + dst[:] = src + + # Make numpy array from data descriptor + npdtype = descriptor.dtype.as_numpy_dtype() + evaluated_shape = tuple(symbolic.evaluate(s, symbols) for s in descriptor.shape) + evaluated_size = symbolic.evaluate(descriptor.total_size, symbols) + evaluated_strides = tuple(symbolic.evaluate(s, symbols) for s in descriptor.strides) + view = create_array(evaluated_shape, npdtype, evaluated_size, evaluated_strides) + if original_array is not None: + copy_array(view, original_array) + + return view + + +def make_reference_from_descriptor(descriptor: Array, + original_array: ctypes.c_void_p, + symbols: Optional[Dict[str, Any]] = None) -> ArrayLike: + """ + Creates an array that matches the given data descriptor from the given pointer. Shares the memory + with the argument (does not create a copy). + + :param descriptor: The data descriptor to create the array from. + :param original_array: The array whose memory the return value would be used in. + :param symbols: An optional symbol mapping between symbol names and their values. Used for referencing arrays + with symbolic sizes. + :return: A NumPy-compatible array (CuPy for GPU storage) with the specified size and strides, sharing memory + with the pointer specified in ``original_array``. + """ + symbols = symbols or {} + + original_array: int = ctypes.cast(original_array, ctypes.c_void_p).value + + free_syms = set(map(str, descriptor.free_symbols)) - symbols.keys() + if free_syms: + raise NotImplementedError(f'Cannot make Python references to arrays with undefined symbolic sizes: {free_syms}') + + if descriptor.storage == dtypes.StorageType.GPU_Global: + try: + import cupy as cp + except (ImportError, ModuleNotFoundError): + raise NotImplementedError('GPU memory can only be referenced in Python if cupy is installed') + + def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: + buffer = dtypes.ptrtocupy(original_array, descriptor.dtype.as_ctypes(), (total_size, )) + view = cp.ndarray(shape=shape, + dtype=dtype, + memptr=buffer.data, + strides=[s * dtype.itemsize for s in strides]) + return view + + else: + + def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: + buffer = dtypes.ptrtonumpy(original_array, descriptor.dtype.as_ctypes(), (total_size, )) + view = np.ndarray(shape, dtype, buffer=buffer, strides=[s * dtype.itemsize for s in strides]) + return view + + # Make numpy array from data descriptor + npdtype = descriptor.dtype.as_numpy_dtype() + evaluated_shape = tuple(symbolic.evaluate(s, symbols) for s in descriptor.shape) + evaluated_size = symbolic.evaluate(descriptor.total_size, symbols) + evaluated_strides = tuple(symbolic.evaluate(s, symbols) for s in descriptor.strides) + return create_array(evaluated_shape, npdtype, evaluated_size, evaluated_strides) diff --git a/dace/data/ctypes_interop.py b/dace/data/ctypes_interop.py new file mode 100644 index 0000000000..d9dfba58e1 --- /dev/null +++ b/dace/data/ctypes_interop.py @@ -0,0 +1,133 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Ctypes interoperability for data descriptors. + +This module contains functions for converting data descriptors to ctypes. +""" +import ctypes +import warnings + +from typing import Any, Dict, List, Optional + +import numpy as np +import sympy as sp + +from dace import config, dtypes, symbolic + + +def make_ctypes_argument(arg: Any, + argtype: 'Data', + name: Optional[str] = None, + allow_views: Optional[bool] = None, + symbols: Optional[Dict[str, Any]] = None, + callback_retval_references: Optional[List[Any]] = None) -> Any: + """ + Converts a given argument to the expected ``ctypes`` type for passing to compiled SDFG functions. + + :param arg: The argument to convert. + :param argtype: The expected data descriptor type of the argument. + :param name: The name of the argument (for error messages). + :param allow_views: Whether to allow views and references as input. If False, raises an error if a view or + reference is passed. If None (default), uses the global configuration setting + ``compiler.allow_view_arguments``. + :param symbols: An optional symbol mapping between symbol names and their values. Used for evaluating symbolic + sizes in callback arguments. + :param callback_retval_references: A list to store references to callback return values (to avoid garbage + collection of said return values). This object must be kept alive until the + SDFG call is complete. + :return: The argument converted to the appropriate ctypes type. + """ + # Import here to avoid circular imports + from dace.data.core import Array, ContainerArray, Structure + + if allow_views is None: + no_view_arguments = not config.Config.get_bool('compiler', 'allow_view_arguments') + else: + no_view_arguments = not allow_views + a = name or '' + atype = argtype + + result = arg + is_array = dtypes.is_array(arg) + is_ndarray = isinstance(arg, np.ndarray) + is_dtArray = isinstance(argtype, Array) + if not is_array and is_dtArray: + if isinstance(arg, list): + print(f'WARNING: Casting list argument "{a}" to ndarray') + elif arg is None: + if atype.optional is False: # If array cannot be None + raise TypeError(f'Passing a None value to a non-optional array in argument "{a}"') + # Otherwise, None values are passed as null pointers below + elif isinstance(arg, ctypes._Pointer): + pass + elif isinstance(arg, str): + # Cast to bytes + result = ctypes.c_char_p(arg.encode('utf-8')) + else: + raise TypeError(f'Passing an object (type {type(arg).__name__}) to an array in argument "{a}"') + elif is_array and not is_dtArray: + # GPU scalars and return values are pointers, so this is fine + if atype.storage != dtypes.StorageType.GPU_Global and not a.startswith('__return'): + raise TypeError(f'Passing an array to a scalar (type {atype.dtype.ctype}) in argument "{a}"') + elif (is_dtArray and is_ndarray and not isinstance(atype, ContainerArray) + and atype.dtype.as_numpy_dtype() != arg.dtype): + # Make exception for vector types + if (isinstance(atype.dtype, dtypes.vector) and atype.dtype.vtype.as_numpy_dtype() == arg.dtype): + pass + else: + print(f'WARNING: Passing {arg.dtype} array argument "{a}" to a {atype.dtype.type.__name__} array') + elif is_dtArray and is_ndarray and arg.base is not None and not '__return' in a and no_view_arguments: + raise TypeError(f'Passing a numpy view (e.g., sub-array or "A.T") "{a}" to DaCe ' + 'programs is not allowed in order to retain analyzability. ' + 'Please make a copy with "numpy.copy(...)". If you know what ' + 'you are doing, you can override this error in the ' + 'configuration by setting compiler.allow_view_arguments ' + 'to True.') + elif (not isinstance(atype, (Array, Structure)) and not isinstance(atype.dtype, dtypes.callback) + and not isinstance(arg, (atype.dtype.type, sp.Basic)) + and not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)): + is_int = isinstance(arg, int) + if is_int and atype.dtype.type == np.int64: + pass + elif (is_int and atype.dtype.type == np.int32 and abs(arg) <= (1 << 31) - 1): + pass + elif (is_int and atype.dtype.type == np.uint32 and arg >= 0 and arg <= (1 << 32) - 1): + pass + elif isinstance(arg, float) and atype.dtype.type == np.float64: + pass + elif isinstance(arg, bool) and atype.dtype.type == np.bool_: + pass + elif (isinstance(arg, str) or arg is None) and atype.dtype == dtypes.string: + if arg is None: + result = ctypes.c_char_p(None) + else: + # Cast to bytes + result = ctypes.c_char_p(arg.encode('utf-8')) + else: + warnings.warn(f'Casting scalar argument "{a}" from {type(arg).__name__} to {atype.dtype.type}') + result = atype.dtype.type(arg) + + # Call a wrapper function to make NumPy arrays from pointers. + if isinstance(argtype.dtype, dtypes.callback): + result = argtype.dtype.get_trampoline(result, symbols or {}, callback_retval_references) + # List to array + elif isinstance(result, list) and isinstance(argtype, Array): + result = np.array(result, dtype=argtype.dtype.type) + # Null pointer + elif result is None and isinstance(argtype, Array): + result = ctypes.c_void_p(0) + + # Retain only the element datatype for upcoming checks and casts + actype = argtype.dtype.as_ctypes() + + try: + if dtypes.is_array(result): # `c_void_p` is subclass of `ctypes._SimpleCData`. + result = ctypes.c_void_p(dtypes.array_interface_ptr(result, atype.storage)) + elif not isinstance(result, (ctypes._SimpleCData, ctypes._Pointer)): + result = actype(result) + else: + pass + except TypeError as ex: + raise TypeError(f'Invalid type for scalar argument "{a}": {ex}') + + return result diff --git a/dace/data/ml.py b/dace/data/ml.py new file mode 100644 index 0000000000..5f26c6e3a1 --- /dev/null +++ b/dace/data/ml.py @@ -0,0 +1,113 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +ML-related data descriptors. + +This module contains data descriptors that are specific to machine learning workflows, +such as ParameterArray for automatic differentiation. +""" +import copy + +from dace import properties +from dace.data.core import Array +from dace.sdfg import SDFG, nodes + + +@properties.make_properties +class ParameterArray(Array): + """ + An array for which a gradient can be computed. + """ + # since this can be None, this is not a DataProperty + gradient = properties.Property(dtype=str, desc="The corresponding gradient buffer", default=None, allow_none=True) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __repr__(self): + return "Parameter" + Array.__repr__(self) + + def add_gradient_buffer(self, sdfg: SDFG, name: str) -> str: + """ + Find or create a gradient buffer for the parameter in the given SDFG. + + :param sdfg: the SDFG containing the parameter + :param name: the name of the parameter + :return: the name of the gradient buffer + """ + + if self.gradient: + return self.gradient + + # First, check if this array already has a gradient buffer in a nested + # SDFG. This happens, for example when pytorch modules are used in the + # frontend. In that case: + # 1. the parser assembles the closure of the module, which adds + # descriptors for all the parameters and their gradients (if they + # are required). + # 2. A nested sdfg is added for the module, with those array names. + # 3. The DaceProgram will then pass these arrays in when the + # DaceProgram is called, using the names from the closure that + # match the names from the NestedSDFG + # 4. When parsing the backward nodes, we want the gradient buffers in + # the closure to match the gradient buffers that we pass in. Thus, + # we need to make sure that we use the same name as the NestedSDFG + # + # Note that we do not currently do any nesting beyond this level, + # because nested modules are converted to one SDFG. + + cands = set() + for state in sdfg.nodes(): + for node in state.nodes(): + if not isinstance(node, nodes.NestedSDFG): + continue + + nested_names = set() + + for edge in state.in_edges(node): + if edge.data.data == name: + nested_names.add(edge.dst_conn) + for edge in state.out_edges(node): + if edge.data.data == name: + nested_names.add(edge.dst_conn) + + for name in nested_names: + nested_desc = node.sdfg.arrays[name] + if isinstance(nested_desc, ParameterArray) and nested_desc.gradient: + cands.add(nested_desc.gradient) + + if len(cands) > 1: + raise ValueError("Multiple gradient buffers found for parameter " + name) + elif len(cands) == 1: + # we found a name of a gradient buffer in a nested SDFG: + # reuse the same name in the outer sdfg if there is a matching descriptor + grad_name = cands.pop() + if grad_name in sdfg.arrays: + self.gradient = grad_name + return grad_name + else: + grad_name = sdfg._find_new_name('gradient_' + name) + + # Create a gradient buffer for the array + grad_desc = copy.deepcopy(self) + grad_desc.__class__ = Array + grad_desc.transient = True + grad_name = sdfg.add_datadesc(grad_name, grad_desc, find_new_name=True) + self.gradient = grad_name + return grad_name + + @staticmethod + def make_parameter(sdfg: SDFG, name: str): + """ + Converts an existing array into a parameter, without copying. + + :param sdfg: the SDFG containing the array. + :param name: the name of the array. + """ + desc = sdfg.arrays[name] + if isinstance(desc, ParameterArray): + return + + new_desc = copy.deepcopy(desc) + new_desc.__class__ = ParameterArray + new_desc.gradient = None + sdfg.arrays[name] = new_desc diff --git a/dace/data/tensor.py b/dace/data/tensor.py new file mode 100644 index 0000000000..4444c31f0f --- /dev/null +++ b/dace/data/tensor.py @@ -0,0 +1,698 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Tensor data descriptors for sparse tensor formats. + +This module contains classes for representing various sparse tensor storage formats +based on the abstraction described in [https://doi.org/10.1145/3276493]. +""" +import aenum + +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple, Union + +from dace import dtypes, serialize, symbolic +from dace.data.core import Data, Scalar, Structure +from dace.properties import ListProperty, Property, ShapeProperty, SymbolicProperty, TypeClassProperty, make_properties + + +class TensorIterationTypes(aenum.AutoNumberEnum): + """ + Types of tensor iteration capabilities. + + Value (Coordinate Value Iteration) allows to directly iterate over + coordinates such as when using the Dense index type. + + Position (Coordinate Position Iteratation) iterates over coordinate + positions, at which the actual coordinates lie. This is for example the case + with a compressed index, in which the pos array enables one to iterate over + the positions in the crd array that hold the actual coordinates. + """ + Value = () + Position = () + + +class TensorAssemblyType(aenum.AutoNumberEnum): + """ + Types of possible assembly strategies for the individual indices. + + NoAssembly: Assembly is not possible as such. + + Insert: index allows inserting elements at random (e.g. Dense) + + Append: index allows appending to a list of existing coordinates. Depending + on append order, this affects whether the index is ordered or not. This + could be changed by sorting the index after assembly + """ + NoAssembly = () + Insert = () + Append = () + + +class TensorIndex(ABC): + """ + Abstract base class for tensor index implementations. + """ + + @property + @abstractmethod + def iteration_type(self) -> TensorIterationTypes: + """ + Iteration capability supported by this index. + + See TensorIterationTypes for reference. + """ + pass + + @property + @abstractmethod + def locate(self) -> bool: + """ + True if the index supports locate (aka random access), False otw. + """ + pass + + @property + @abstractmethod + def assembly(self) -> TensorAssemblyType: + """ + What assembly type is supported by the index. + + See TensorAssemblyType for reference. + """ + pass + + @property + @abstractmethod + def full(self) -> bool: + """ + True if the level is full, False otw. + + A level is considered full if it encompasses all valid coordinates along + the corresponding tensor dimension. + """ + pass + + @property + @abstractmethod + def ordered(self) -> bool: + """ + True if the level is ordered, False otw. + + A level is ordered when all coordinates that share the same ancestor are + ordered by increasing value (e.g. in typical CSR). + """ + pass + + @property + @abstractmethod + def unique(self) -> bool: + """ + True if coordinate in the level are unique, False otw. + + A level is considered unique if no collection of coordinates that share + the same ancestor contains duplicates. In CSR this is True, in COO it is + not. + """ + pass + + @property + @abstractmethod + def branchless(self) -> bool: + """ + True if the level doesn't branch, false otw. + + A level is considered branchless if no coordinate has a sibling (another + coordinate with same ancestor) and all coordinates in parent level have + a child. In other words if there is a bijection between the coordinates + in this level and the parent level. An example of the is the Singleton + index level in the COO format. + """ + pass + + @property + @abstractmethod + def compact(self) -> bool: + """ + True if the level is compact, false otw. + + A level is compact if no two coordinates are separated by an unlabled + node that does not encode a coordinate. An example of a compact level + can be found in CSR, while the DIA formats range and offset levels are + not compact (they have entries that would coorespond to entries outside + the tensors index range, e.g. column -1). + """ + pass + + @abstractmethod + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + """ + Generates the fields needed for the index. + + :return: a Dict of fields that need to be present in the struct + """ + pass + + def to_json(self): + attrs = serialize.all_properties_to_json(self) + + retdict = {"type": type(self).__name__, "attributes": attrs} + + return retdict + + @classmethod + def from_json(cls, json_obj, context=None): + + # Selecting proper subclass + if json_obj['type'] == "TensorIndexDense": + self = TensorIndexDense.__new__(TensorIndexDense) + elif json_obj['type'] == "TensorIndexCompressed": + self = TensorIndexCompressed.__new__(TensorIndexCompressed) + elif json_obj['type'] == "TensorIndexSingleton": + self = TensorIndexSingleton.__new__(TensorIndexSingleton) + elif json_obj['type'] == "TensorIndexRange": + self = TensorIndexRange.__new__(TensorIndexRange) + elif json_obj['type'] == "TensorIndexOffset": + self = TensorIndexOffset.__new__(TensorIndexOffset) + else: + raise TypeError(f"Invalid data type, got: {json_obj['type']}") + + serialize.set_properties_from_json(self, json_obj['attributes'], context=context) + + return self + + +@make_properties +class TensorIndexDense(TensorIndex): + """ + Dense tensor index. + + Levels of this type encode the the coordinate in the interval [0, N), where + N is the size of the corresponding dimension. This level doesn't need any + index structure beyond the corresponding dimension size. + """ + + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Value + + @property + def locate(self) -> bool: + return True + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.Insert + + @property + def full(self) -> bool: + return True + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return False + + @property + def compact(self) -> bool: + return True + + def __init__(self, ordered: bool = True, unique: bool = True): + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return {} + + def __repr__(self) -> str: + s = "Dense" + + non_defaults = [] + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexCompressed(TensorIndex): + """ + Tensor level that stores coordinates in segmented array. + + Levels of this type are compressed using a segented array. The pos array + holds the start and end positions of the segment in the crd (coordinate) + array that holds the child coordinates corresponding the parent. + """ + + _full = Property(dtype=bool, default=False) + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Position + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.Append + + @property + def full(self) -> bool: + return self._full + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return False + + @property + def compact(self) -> bool: + return True + + def __init__(self, full: bool = False, ordered: bool = True, unique: bool = True): + self._full = full + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_pos": dtypes.int32[dummy_symbol], # TODO (later) choose better length + f"idx{lvl}_crd": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Compressed" + + non_defaults = [] + if self._full: + non_defaults.append("F") + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexSingleton(TensorIndex): + """ + Tensor index that encodes a single coordinate per parent coordinate. + + Levels of this type hold exactly one coordinate for every coordinate in the + parent level. An example can be seen in the COO format, where every + coordinate but the first is encoded in this manner. + """ + + _full = Property(dtype=bool, default=False) + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Position + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.Append + + @property + def full(self) -> bool: + return self._full + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return True + + @property + def compact(self) -> bool: + return True + + def __init__(self, full: bool = False, ordered: bool = True, unique: bool = True): + self._full = full + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_crd": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Singleton" + + non_defaults = [] + if self._full: + non_defaults.append("F") + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexRange(TensorIndex): + """ + Tensor index that encodes a interval of coordinates for every parent. + + The interval is computed from an offset for each parent together with the + tensor dimension size of this level (M) and the parent level (N) parents + corresponding tensor. Given the parent coordinate i, the level encodes the + range of coordinates between max(0, -offset[i]) and min(N, M - offset[i]). + """ + + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Value + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.NoAssembly + + @property + def full(self) -> bool: + return False + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return False + + @property + def compact(self) -> bool: + return False + + def __init__(self, ordered: bool = True, unique: bool = True): + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_offset": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Range" + + non_defaults = [] + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexOffset(TensorIndex): + """ + Tensor index that encodes the next coordinates as offset from parent. + + Given a parent coordinate i and an offset index k, the level encodes the + coordinate j = i + offset[k]. + """ + + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Position + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.NoAssembly + + @property + def full(self) -> bool: + return False + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return True + + @property + def compact(self) -> bool: + return False + + def __init__(self, ordered: bool = True, unique: bool = True): + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_offset": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Offset" + + non_defaults = [] + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class Tensor(Structure): + """ + Abstraction for Tensor storage format. + + This abstraction is based on [https://doi.org/10.1145/3276493]. + """ + + value_dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses) + tensor_shape = ShapeProperty(default=[]) + indices = ListProperty(element_type=TensorIndex) + index_ordering = ListProperty(element_type=symbolic.SymExpr) + value_count = SymbolicProperty(default=0) + + def __init__(self, + value_dtype: dtypes.Typeclasses, + tensor_shape, + indices: List[Tuple[TensorIndex, Union[int, symbolic.SymExpr]]], + value_count: symbolic.SymExpr, + name: str, + transient: bool = False, + storage: dtypes.StorageType = dtypes.StorageType.Default, + location: Dict[str, str] = None, + lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope, + debuginfo: dtypes.DebugInfo = None): + """ + Constructor for Tensor storage format. + + Below are examples of common matrix storage formats: + + .. code-block:: python + + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + + csr = dace.data.Tensor( + dace.float32, + (M, N), + [(dace.data.Dense(), 0), (dace.data.Compressed(), 1)], + nnz, + "CSR_Matrix", + ) + + csc = dace.data.Tensor( + dace.float32, + (M, N), + [(dace.data.Dense(), 1), (dace.data.Compressed(), 0)], + nnz, + "CSC_Matrix", + ) + + coo = dace.data.Tensor( + dace.float32, + (M, N), + [ + (dace.data.Compressed(unique=False), 0), + (dace.data.Singleton(), 1), + ], + nnz, + "CSC_Matrix", + ) + + num_diags = dace.symbol('num_diags') # number of diagonals stored + + diag = dace.data.Tensor( + dace.float32, + (M, N), + [ + (dace.data.Dense(), num_diags), + (dace.data.Range(), 0), + (dace.data.Offset(), 1), + ], + nnz, + "DIA_Matrix", + ) + + Below you can find examples of common 3rd order tensor storage formats: + + .. code-block:: python + + I, J, K, nnz = (dace.symbol(s) for s in ('I', 'J', 'K', 'nnz')) + + coo = dace.data.Tensor( + dace.float32, + (I, J, K), + [ + (dace.data.Compressed(unique=False), 0), + (dace.data.Singleton(unique=False), 1), + (dace.data.Singleton(), 2), + ], + nnz, + "COO_3D_Tensor", + ) + + csf = dace.data.Tensor( + dace.float32, + (I, J, K), + [ + (dace.data.Compressed(), 0), + (dace.data.Compressed(), 1), + (dace.data.Compressed(), 2), + ], + nnz, + "CSF_3D_Tensor", + ) + + :param value_type: data type of the explicitly stored values. + :param tensor_shape: logical shape of tensor (#rows, #cols, etc...) + :param indices: + a list of tuples, each tuple represents a level in the tensor + storage hirachy, specifying the levels tensor index type, and the + corresponding dimension this level encodes (as index of the + tensor_shape tuple above). The order of the dimensions may differ + from the logical shape of the tensor, e.g. as seen in the CSC + format. If an index's dimension is unrelated to the tensor shape + (e.g. in diagonal format where the first index's dimension is the + number of diagonals stored), a symbol can be specified instead. + :param value_count: number of explicitly stored values. + :param name: name of resulting struct. + :param others: See Structure class for remaining arguments + """ + + self.value_dtype = value_dtype + self.tensor_shape = tensor_shape + self.value_count = value_count + + indices, index_ordering = zip(*indices) + self.indices, self.index_ordering = list(indices), list(index_ordering) + + num_dims = len(tensor_shape) + dimension_order = [idx for idx in self.index_ordering if isinstance(idx, int)] + + # all tensor dimensions must occure exactly once in indices + if not sorted(dimension_order) == list(range(num_dims)): + raise TypeError((f"All tensor dimensions must be refferenced exactly once in " + f"tensor indices. (referenced dimensions: {dimension_order}; " + f"tensor dimensions: {list(range(num_dims))})")) + + # assembling permanent and index specific fields + fields = dict( + order=Scalar(dtypes.int32), + dim_sizes=dtypes.int32[num_dims], + value_count=value_count, + values=dtypes.float32[value_count], + ) + + for (lvl, index) in enumerate(indices): + fields.update(index.fields(lvl, value_count)) + + super(Tensor, self).__init__(fields, name, transient, storage, location, lifetime, debuginfo) + + def __repr__(self): + return f"{self.name} (dtype: {self.value_dtype}, shape: {list(self.tensor_shape)}, indices: {self.indices})" + + @staticmethod + def from_json(json_obj, context=None): + if json_obj['type'] != 'Tensor': + raise TypeError("Invalid data type") + + # Create dummy object + tensor = Tensor.__new__(Tensor) + serialize.set_properties_from_json(tensor, json_obj, context=context) + + return tensor diff --git a/dace/dtypes.py b/dace/dtypes.py index faadc84a50..982339e204 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. """ A module that contains various DaCe type definitions. """ import ctypes import json @@ -64,7 +64,6 @@ class ScheduleType(aenum.AutoNumberEnum): MPI = () #: MPI processes CPU_Multicore = () #: OpenMP parallel for loop CPU_Persistent = () #: OpenMP parallel region - Unrolled = () #: Unrolled code SVE_Map = () #: Arm SVE #: Default scope schedule for GPU code. Specializes to schedule GPU_Device and GPU_Global during inference. @@ -76,7 +75,6 @@ class ScheduleType(aenum.AutoNumberEnum): FPGA_Device = () Snitch = () Snitch_Multicore = () - FPGA_Multi_Pumped = () #: Used for double pumping # A subset of GPU schedule types @@ -102,7 +100,6 @@ class ScheduleType(aenum.AutoNumberEnum): FPGA_STORAGES = [ StorageType.FPGA_Local, StorageType.FPGA_Registers, - StorageType.FPGA_ShiftRegister, ] @@ -166,7 +163,7 @@ class InstrumentationType(aenum.AutoNumberEnum): LIKWID_CPU = () LIKWID_GPU = () GPU_Events = () - FPGA = () + GPU_TX_MARKERS = () @undefined_safe_enum @@ -215,14 +212,12 @@ class TilingType(aenum.AutoNumberEnum): ScheduleType.MPI: ScheduleType.CPU_Multicore, ScheduleType.CPU_Multicore: ScheduleType.Sequential, ScheduleType.CPU_Persistent: ScheduleType.CPU_Multicore, - ScheduleType.Unrolled: ScheduleType.CPU_Multicore, ScheduleType.GPU_Default: ScheduleType.GPU_Device, ScheduleType.GPU_Persistent: ScheduleType.GPU_Device, ScheduleType.GPU_Device: ScheduleType.GPU_ThreadBlock, ScheduleType.GPU_ThreadBlock: ScheduleType.Sequential, ScheduleType.GPU_ThreadBlock_Dynamic: ScheduleType.Sequential, ScheduleType.FPGA_Device: ScheduleType.FPGA_Device, - ScheduleType.FPGA_Multi_Pumped: ScheduleType.FPGA_Device, ScheduleType.SVE_Map: ScheduleType.Sequential, ScheduleType.Snitch: ScheduleType.Snitch, ScheduleType.Snitch_Multicore: ScheduleType.Snitch_Multicore @@ -265,68 +260,6 @@ class TilingType(aenum.AutoNumberEnum): numpy.complex128: "dace::complex128", } -# Translation of types to OpenCL types -_OCL_TYPES = { - None: "void", - int: "int", - float: "float", - bool: "bool", - numpy.bool_: "bool", - numpy.int8: "char", - numpy.int16: "short", - numpy.int32: "int", - numpy.intc: "int", - numpy.int64: "long", - numpy.uint8: "uchar", - numpy.uint16: "ushort", - numpy.uint32: "uint", - numpy.uint64: "ulong", - numpy.uintc: "uint", - numpy.float32: "float", - numpy.float64: "double", - numpy.complex64: "complex float", - numpy.complex128: "complex double", -} - -_CTYPES_TO_OCLTYPES = { - "void": "void", - "int": "int", - "float": "float", - "double": "double", - "dace::complex64": "complex float", - "dace::complex128": "complex double", - "bool": "bool", - "char": "char", - "short": "short", - "int": "int", - "int64_t": "long", - "uint8_t": "uchar", - "uint16_t": "ushort", - "uint32_t": "uint", - "dace::uint": "uint", - "uint64_t": "ulong", - "dace::float16": "half", -} - -# Translation of types to OpenCL vector types -_OCL_VECTOR_TYPES = { - numpy.int8: "char", - numpy.uint8: "uchar", - numpy.int16: "short", - numpy.uint16: "ushort", - numpy.int32: "int", - numpy.intc: "int", - numpy.uint32: "uint", - numpy.uintc: "uint", - numpy.int64: "long", - numpy.uint64: "ulong", - numpy.float16: "half", - numpy.float32: "float", - numpy.float64: "double", - numpy.complex64: "complex float", - numpy.complex128: "complex double", -} - # Translation of types to ctypes types _FFI_CTYPES = { None: ctypes.c_void_p, @@ -502,10 +435,6 @@ def base_type(self): def veclen(self): return 1 - @property - def ocltype(self): - return _OCL_TYPES[self.type] - def as_arg(self, name): return self.ctype + ' ' + name @@ -703,10 +632,6 @@ def as_numpy_dtype(self): def base_type(self): return self._typeclass - @property - def ocltype(self): - return f"{self.base_type.ocltype}*" - class vector(typeclass): """ @@ -734,14 +659,6 @@ def from_json(json_obj, context=None): def ctype(self): return "dace::vec<%s, %s>" % (self.vtype.ctype, self.veclen) - @property - def ocltype(self): - if self.veclen > 1: - vectype = _OCL_VECTOR_TYPES[self.type] - return f"{vectype}{self.veclen}" - else: - return self.base_type.ocltype - @property def ctype_unaligned(self): return self.ctype @@ -1362,9 +1279,6 @@ def dtype_to_typeclass(dtype=None): "cmath": "dace::cmath::", } -# Lists allowed modules and maps them to OpenCL -_OPENCL_ALLOWED_MODULES = {"builtins": "", "dace": "", "math": ""} - def ismodule(var): """ Returns True if a given object is a module. """ diff --git a/dace/external/hlslib b/dace/external/hlslib deleted file mode 160000 index 1b5b3aee5d..0000000000 --- a/dace/external/hlslib +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1b5b3aee5dab19adcc443fa9a7cd45244bd246b1 diff --git a/dace/external/rtllib b/dace/external/rtllib deleted file mode 160000 index 4f320ac020..0000000000 --- a/dace/external/rtllib +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 4f320ac02007fd50d16aa3ff5b9fce7fa804e955 diff --git a/dace/fpga_testing.py b/dace/fpga_testing.py deleted file mode 100644 index e1f76b58be..0000000000 --- a/dace/fpga_testing.py +++ /dev/null @@ -1,293 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from datetime import datetime -import importlib.util -import inspect -import os -import multiprocessing as mp -from pathlib import Path -import pytest -import re -import subprocess as sp -from typing import Callable, Iterable, Optional, Tuple, Union - -from dace import SDFG -from dace.config import Config, temporary_config - -TEST_TIMEOUT_SW = 600 # Timeout software simulation tests after 10 minutes -TEST_TIMEOUT_HW = 900 # Timeout hardware emulation tests after 15 minutes - - -class Colors: - SUCCESS = "\033[92m" - STATUS = "\033[94m" - ERROR = "\033[91m" - BOLD = "\033[1m" - UNDERLINE = "\033[4m" - END = "\033[0m" - - -def print_status(message): - timestamp = datetime.now().strftime("%H:%M:%S") - print(f"{Colors.STATUS}{Colors.BOLD}[{timestamp}]{Colors.END} {message}") - - -def print_success(message): - timestamp = datetime.now().strftime("%H:%M:%S") - print(f"{Colors.SUCCESS}{Colors.BOLD}[{timestamp}]{Colors.END} {message}") - - -def print_error(message): - timestamp = datetime.now().strftime("%H:%M:%S") - print(f"{Colors.ERROR}{Colors.BOLD}[{timestamp}]{Colors.END} {message}") - - -def dump_logs(proc_or_logs: Union[sp.CompletedProcess, Tuple[str, str]]): - if isinstance(proc_or_logs, tuple): - log_out, log_err = proc_or_logs - else: - proc_or_logs.terminate() - proc_or_logs.kill() - try: - log_out, log_err = proc_or_logs.communicate(timeout=10) - except sp.TimeoutExpired: - return None # Failed to even kill the process - if log_out: - print(log_out) - if log_err: - print(log_err) - return log_out, log_err - - -# https://stackoverflow.com/a/33599967/2949968 -class FPGATestProcess(mp.Process): - - def __init__(self, *args, **kwargs): - mp.Process.__init__(self, *args, **kwargs) - self._pconn, self._cconn = mp.Pipe() - self._exception = None - - def run(self): - try: - ret = mp.Process.run(self) - self._cconn.send(ret) - except Exception as e: - self._cconn.send(e) - raise e - - @property - def exception(self): - if self._pconn.poll(): - self._exception = self._pconn.recv() - return self._exception - - -class TestFailed(Exception): - pass - - -def raise_error(message): - print_error(message) - raise TestFailed(message) - - -def _run_fpga_test(vendor: str, - test_function: Callable, - test_timeout: int, - run_synthesis: bool = True, - assert_ii_1: bool = True): - path = Path(inspect.getfile(test_function)) - base_name = f"{path.stem}::{Colors.UNDERLINE}{test_function.__name__}{Colors.END}" - with temporary_config(): - Config.set("compiler", "use_cache", value=False) - Config.set("cache", value="unique") - Config.set("call_hooks", value=None) - Config.set("optimizer", "autooptimize", value=False) - if vendor == "xilinx": - Config.set("compiler", "fpga", "vendor", value="xilinx") - Config.set("compiler", "xilinx", "mode", value="simulation") - Config.set("compiler", "xilinx", "frequency", value="100") # 100 is the vitis_hls default - - # Simulation in software - print_status(f"{base_name} [Xilinx]: Running simulation.") - if "rtl" in path.parts: - Config.set("compiler", "xilinx", "mode", value="hardware_emulation") - if "LIBRARY_PATH" not in os.environ: - os.environ["LIBRARY_PATH"] = "" - library_path_backup = None - else: - library_path_backup = os.environ["LIBRARY_PATH"] - os.environ["LIBRARY_PATH"] += ":/usr/lib/x86_64-linux-gnu" - sdfgs = test_function() - if "rtl" in path.parts: - if library_path_backup is None: - del os.environ["LIBRARY_PATH"] - else: - os.environ["LIBRARY_PATH"] = library_path_backup - if sdfgs is None: - raise_error("No SDFG(s) returned by FPGA test.") - elif isinstance(sdfgs, SDFG): - sdfgs = [sdfgs] - print_success(f"{base_name} [Xilinx]: " - "Simulation successful.") - - for sdfg in sdfgs: - build_folder = Path(sdfg.build_folder) / "build" - if not build_folder.exists(): - raise_error(f"Build folder {build_folder} " - f"not found for {base_name}.") - - # High-level synthesis - if run_synthesis: - print_status(f"{base_name} [Xilinx]: Running high-level " - f"synthesis for {sdfg.name}.") - try: - proc = sp.Popen(["make", "synthesis"], - cwd=build_folder, - stdout=sp.PIPE, - stderr=sp.PIPE, - encoding="utf=8") - syn_out, syn_err = proc.communicate(timeout=test_timeout) - except sp.TimeoutExpired: - dump_logs(proc) - raise_error(f"{base_name} [Xilinx]: High-level " - f"synthesis timed out after " - f"{test_timeout} seconds.") - if proc.returncode != 0: - dump_logs(proc) - raise_error(f"{base_name} [Xilinx]: High-level " - f"synthesis failed.") - print_success(f"{base_name} [Xilinx]: High-level " - f"synthesis successful for " - f"{sdfg.name}.") - open(build_folder / "synthesis.out", "w").write(syn_out) - open(build_folder / "synthesis.err", "w").write(syn_err) - - # Check if loops were pipelined with II=1 - if assert_ii_1: - loops_found = False - for f in build_folder.iterdir(): - if "hls.log" in f.name: - hls_log = f - break - else: - raise_error(f"{base_name} [Xilinx]: HLS " - f"log file not found.") - hls_log = open(hls_log, "r").read() - for m in re.finditer(r"Final II = ([0-9]+)", hls_log): - loops_found = True - if int(m.group(1)) != 1: - dump_logs((syn_out, syn_err)) - raise_error(f"{base_name} [Xilinx]: " - f"Failed to achieve II=1.") - if not loops_found: - dump_logs((syn_out, syn_err)) - raise_error(f"{base_name} [Xilinx]: No " - f"pipelined loops found.") - print_success(f"{base_name} [Xilinx]: II=1 " - f"achieved.") - - elif vendor == "intel_fpga": - # Set environment variables - Config.set("compiler", "fpga", "vendor", value="intel_fpga") - Config.set("compiler", "default_data_types", value="C") - Config.set("compiler", "intel_fpga", "mode", value="emulator") - - # Simulation in software - print_status(f"{base_name} [Intel FPGA]: Running " - f"emulation.") - test_function() - print_success(f"{base_name} [Intel FPGA]: Emulation " - f"successful.") - else: - raise ValueError(f"Unrecognized vendor {vendor}.") - - -def fpga_test(run_synthesis: bool = True, - assert_ii_1: bool = True, - xilinx: bool = True, - intel: bool = True, - rtl: bool = False): - """ - Decorator to run an FPGA test with pytest, setting the appropriate - variables and performing additional checks, such as running HLS and - asserting II=1. The test function must return an SDFG or a list of SDFGs - that will be used for this check. - - :param run_synthesis: Whether to run HLS for Xilinx tests (Intel tests will always run synthesis). - :param assert_ii_1: Assert that all loops have been fully pipelined (currently only implemented for Xilinx). - :param xilinx: Run as a Xilinx test. - :param intel: Run as an Intel test. - :param rtl: Run as an RTL Xilinx test. - """ - - # Check arguments - if not xilinx and not intel: - raise ValueError("FPGA test must be run for Xilinx, Intel, or both.") - pytest_params = [] - if xilinx: - pytest_params.append("xilinx") - if intel: - pytest_params.append("intel_fpga") - test_timeout = TEST_TIMEOUT_HW if rtl else TEST_TIMEOUT_SW - - def decorator(test_function: Callable): - - def internal(vendor: Optional[str]): - if vendor == None: - vendor = Config.get("compiler", "fpga", "vendor") - p = FPGATestProcess(target=_run_fpga_test, - args=(vendor, test_function, test_timeout, run_synthesis, assert_ii_1)) - p.start() - p.join(timeout=test_timeout) - if p.is_alive(): - p.kill() - raise_error(f"Test {Colors.UNDERLINE}{test_function.__name__}" - f"{Colors.END} timed out.") - if p.exception: - raise p.exception - - if rtl: - - @pytest.mark.rtl_hardware - @pytest.mark.parametrize("vendor", pytest_params) - def wrapper(vendor: Optional[str]): - internal(vendor) - else: - - @pytest.mark.fpga - @pytest.mark.parametrize("vendor", pytest_params) - def wrapper(vendor: Optional[str]): - internal(vendor) - - return wrapper - - return decorator - - -def xilinx_test(*args, **kwargs): - return fpga_test(*args, xilinx=True, intel=False, **kwargs) - - -def intel_fpga_test(*args, **kwargs): - return fpga_test(*args, xilinx=False, intel=True, **kwargs) - - -def rtl_test(*args, **kwargs): - return fpga_test(*args, xilinx=True, intel=False, rtl=True, **kwargs) - - -def import_sample(path: Union[Path, str]): - """ - Import a Python file from the samples directory as a module so it can be - used in a test. - - :param path: Path relative to the DaCe samples directory. - """ - path = Path(__file__).parent.parent / "samples" / Path(path) - if not path.exists(): - raise ValueError(f"Sample {path} not found.") - name = path.stem - spec = importlib.util.spec_from_file_location(name, path) - loaded_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(loaded_module) - return loaded_module diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py index df1c8de34e..d33b0150f3 100644 --- a/dace/frontend/common/einsum.py +++ b/dace/frontend/common/einsum.py @@ -9,7 +9,7 @@ import dace from dace import dtypes, subsets, symbolic -from dace.data import _prod as prod +from dace.utils import prod from dace.sdfg.nodes import AccessNode from dace.sdfg import SDFG, SDFGState, InterstateEdge from dace.memlet import Memlet diff --git a/dace/frontend/ml/__init__.py b/dace/frontend/ml/__init__.py new file mode 100644 index 0000000000..6e6305d8f9 --- /dev/null +++ b/dace/frontend/ml/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +try: + from .torch import DaceModule +except ImportError: + DaceModule = None + +try: + from .onnx import ONNXModel +except ImportError: + ONNXModel = None + +__all__ = ['DaceModule', 'ONNXModel'] diff --git a/dace/frontend/ml/onnx/__init__.py b/dace/frontend/ml/onnx/__init__.py new file mode 100644 index 0000000000..aa0c16bf05 --- /dev/null +++ b/dace/frontend/ml/onnx/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +from .importer import ONNXModel + +__all__ = ['ONNXModel'] diff --git a/dace/frontend/ml/onnx/importer.py b/dace/frontend/ml/onnx/importer.py new file mode 100644 index 0000000000..226b3e821f --- /dev/null +++ b/dace/frontend/ml/onnx/importer.py @@ -0,0 +1,794 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +ONNX Model Importer for DaCe. + +This module provides the ONNXModel class, which is the main entry point for +importing ONNX models into DaCe. It handles the complete pipeline of: + +1. **Model Loading**: Loading ONNX models from files or protobuf objects +2. **Model Simplification**: Applying onnx-simplifier for optimization +3. **Shape Inference**: Computing tensor shapes symbolically or concretely +4. **Graph Conversion**: Converting ONNX graph to DaCe SDFG +5. **Weight Management**: Handling model parameters and initializers +6. **Compilation**: Compiling the SDFG to executable code +7. **Execution**: Running the model with NumPy or PyTorch tensors + +Key Features: +- Automatic shape inference for dynamic models +- Support for both CPU and CUDA execution +- Integration with PyTorch for seamless tensor conversion +- Configurable optimization levels +- Weight initialization and parameter management +- Support for nested models and subgraphs + +Typical Workflow: + >>> import onnx + >>> from dace.frontend.ml.onnx import ONNXModel + >>> + >>> # Load ONNX model + >>> onnx_model = onnx.load("model.onnx") + >>> dace_model = ONNXModel("my_model", onnx_model) + >>> + >>> # Run inference + >>> import numpy as np + >>> input_data = np.random.randn(1, 3, 224, 224).astype(np.float32) + >>> output = dace_model(input_data) + +The module also provides utility functions for: +- Type conversion between NumPy, PyTorch, and ONNX types +- Model validation and checking +- Shape inference helpers +- Weight loading and initialization + +Note: + This is a large module (900+ lines) that handles multiple concerns. + Consider the architectural recommendations in the code review for + potential refactoring into smaller, focused modules. +""" + +import collections +import copy +import tempfile +from itertools import chain, repeat +from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Union + +import numpy as np + +# PyTorch is optional (only needed for tensor conversion features) +try: + import torch + TORCH_AVAILABLE = True +except ImportError: + torch = None + TORCH_AVAILABLE = False + +# ONNX is mandatory for this module +try: + import onnx + import onnx.checker + from onnx import numpy_helper +except ImportError as e: + raise ImportError("ONNX library is required. Install with: pip install dace[ml]") from e + +# ONNXRuntime for symbolic shape inference +try: + from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference + ONNXRUNTIME_AVAILABLE = True +except ImportError: + SymbolicShapeInference = None + ONNXRUNTIME_AVAILABLE = False + +# onnxsim is optional (only needed for model simplification) +try: + import onnxsim + ONNXSIM_AVAILABLE = True +except ImportError: + onnxsim = None + ONNXSIM_AVAILABLE = False + +import dace +from dace import config, SDFG, SDFGState, data as dt, dtypes, nodes +from dace.codegen import compiled_sdfg +from dace.frontend.python import parser +from dace.sdfg import utils as sdfg_utils +from dace.symbolic import pystr_to_symbolic +from dace.transformation.onnx import auto_optimize_onnx as auto_opt +from dace.transformation.onnx import expand_onnx_nodes as onnx_node_expander + +from dace.libraries.onnx.converters import clean_onnx_name, convert_attribute_proto, onnx_tensor_type_to_typeclass +from dace.libraries.onnx.nodes.onnx_op_registry import get_onnx_node, has_onnx_node +from dace.libraries.onnx.schema import ONNXParameterType + +#: Mapping from NumPy dtypes to PyTorch dtypes for tensor conversion +if TORCH_AVAILABLE: + numpy_to_torch_dtype_dict = { + np.bool_: torch.bool, + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128 + } + + #: Reverse mapping from PyTorch dtypes to NumPy dtypes + torch_to_numpy_dtype_dict = {v: k for k, v in numpy_to_torch_dtype_dict.items()} +else: + numpy_to_torch_dtype_dict = {} + torch_to_numpy_dtype_dict = {} + + +def _nested_HasField(obj, full_attr: str) -> bool: + """ + Check if a protobuf object has a nested field. + + This function performs a nested hasattr check by traversing dot-separated + attribute names on a protobuf object. + + :param obj: The protobuf object to check. + :param full_attr: Dot-separated attribute path (e.g., "graph.node"). + :return: True if all attributes in the path exist, False otherwise. + + Example:: + + >>> _nested_HasField(model, "graph.node") + True + """ + attrs = full_attr.split(".") + for attr in attrs: + if obj.HasField(attr): + obj = getattr(obj, attr) + else: + return False + return True + + +def infer_shapes_onnx_model(model: onnx.ModelProto, auto_merge: bool = False) -> onnx.ModelProto: + """ + Perform shape inference on an ONNX model using ONNXRuntime's symbolic shape inference. + + This function uses ONNXRuntime's symbolic shape inference tool which provides + better support for symbolic dimensions and dynamic shapes compared to ONNX's + built-in shape inference. + + :param model: The ONNX model to perform shape inference on. + :param auto_merge: Whether to automatically merge symbolic dimensions when possible. + :return: The ONNX model with inferred shapes. + + .. note:: + Falls back to ONNX's built-in shape inference if ONNXRuntime is not available + or if symbolic shape inference produces incomplete results. + """ + if not ONNXRUNTIME_AVAILABLE: + if config.Config.get_bool('debugprint'): + print("Warning: ONNXRuntime not available, falling back to ONNX shape inference.") + # Fallback to ONNX's built-in shape inference + import onnx.shape_inference + return onnx.shape_inference.infer_shapes(model, check_type=False, strict_mode=False, data_prop=True) + + try: + # Use static method API + model = SymbolicShapeInference.infer_shapes( + model, + int_max=2**31 - 1, + auto_merge=auto_merge, + guess_output_rank=False, + verbose=0, + ) + + # Check if shape inference completed successfully for all value_infos + incomplete_shapes = False + for value in model.graph.value_info: + if not _nested_HasField(value, "type.tensor_type.shape"): + incomplete_shapes = True + break + + if incomplete_shapes: + if config.Config.get_bool('debugprint'): + print("Warning: ONNXRuntime symbolic shape inference produced incomplete results, " + "falling back to ONNX shape inference.") + import onnx.shape_inference + return onnx.shape_inference.infer_shapes(model, check_type=False, strict_mode=False, data_prop=True) + + return model + except Exception as e: + if config.Config.get_bool('debugprint'): + print(f"Warning: ONNXRuntime symbolic shape inference failed ({e}), " + "falling back to ONNX shape inference.") + import onnx.shape_inference + return onnx.shape_inference.infer_shapes(model, check_type=False, strict_mode=False, data_prop=True) + + +def simplify_onnx_model(model: onnx.ModelProto, auto_merge: bool) -> onnx.ModelProto: + """ + Simplify an ONNX model using onnx-simplifier. + + This function applies various optimizations to the ONNX model including: + - Constant folding + - Dead code elimination + - Shape inference + - Operator fusion (except batch normalization) + + :param model: The ONNX model to simplify. + :param auto_merge: Whether to automatically merge nodes (passed to onnxsim). + :return: The simplified ONNX model. + :raises ImportError: If onnxsim is not installed. + :raises RuntimeError: If onnx-simplifier optimizations fail validation. + + .. note:: + Batch normalization fusion is skipped (skip_fuse_bn=True) to maintain + numerical accuracy and allow separate optimization strategies. + """ + if not ONNXSIM_AVAILABLE: + raise ImportError("onnxsim is required for model simplification. Install with: pip install dace[ml]") + + try: + model, check = onnxsim.simplify(model, skip_fuse_bn=True) + if not check: + raise RuntimeError("onnx-simplifier optimizations failed validation") + return model + except (onnx.checker.ValidationError, ValueError) as e: + # If simplification fails due to validation errors (e.g., missing shape info), + # return the original model + if config.Config.get_bool('debugprint'): + print(f"Warning: ONNX simplification failed with error: {e}. Continuing without simplification.") + return model + + +class ONNXModel: + """ Loads an ONNX model into an SDFG. + + :Example: + First download an ONNX model, such as + `efficientnet `_. + + .. testsetup:: + + import subprocess + model_path = os.path.join("..", "tests", "onnx_files", "efficientnet.onnx") + # Download model + if not os.path.exists(model_path): + subprocess.check_call([ + "wget", + "http://spclstorage.inf.ethz.ch/~rauscho/efficientnet-lite4-11.onnx", + "--output-document={}".format(model_path), + "--no-verbose" + ]) + + + .. testcode:: + + import onnx + import os + import numpy as np + from dace.onnx import ONNXModel + + model_path = os.path.join("..", "tests", "onnx_files", "efficientnet.onnx") + model = onnx.load(model_path) + dace_model = ONNXModel("efficientnet", model) + + test_input = np.random.rand(1, 3, 224, 224).astype(np.float32) + dace_model(test_input) + + """ + + def __init__(self, + name: str, + model: onnx.ModelProto, + cuda: bool = False, + auto_optimize: bool = False, + simplify: bool = False, + onnx_simplify: bool = True, + storage: Optional[dtypes.StorageType] = None, + save_transients: Optional[Dict[str, torch.Tensor]] = None, + auto_merge: bool = False): + """ + :param name: the name for the SDFG. + :param model: the model to import. + :param cuda: if ``True``, the model will be executed on the GPU. + :param simplify: if ``True``, apply simplification transformations after all nodes have been expanded. + :param onnx_simplify: if True, run ONNX-level simplifications such as constant folding and shape inference. + :param auto_optimize: if ``True``, apply automatic optimizations before calling. + :param storage: the storage type of the parameters, inputs and outputs. If None, will be set according to + ``cuda``. + :param save_transients: if not None, save transients to this dict (for debugging). + :param: whether to automatically merge conflicting shapes in symbolic shape inference. + :param auto_merge: whether to automatically merge symbolic shapes in symbolic shape inference. + """ + + onnx.checker.check_model(model) + + # Use temporary files for intermediate model saves + with tempfile.NamedTemporaryFile(suffix='.onnx', delete=True) as temp_original: + onnx.save(model, temp_original.name) + model = infer_shapes_onnx_model(model, auto_merge=auto_merge) + + with tempfile.NamedTemporaryFile(suffix='.onnx', delete=True) as temp_shapes: + onnx.save(model, temp_shapes.name) + + if onnx_simplify: + model = simplify_onnx_model(model, auto_merge) + with tempfile.NamedTemporaryFile(suffix='.onnx', delete=True) as temp_simplified: + onnx.save(model, temp_simplified.name) + + self.do_auto_optimize = auto_optimize + self.model = model + graph: onnx.GraphProto = model.graph + self.save_transients = save_transients + self.sdfg: SDFG = SDFG(name) #: the generated SDFG. + self.sdfg._parent_onnx_model = self + self.cuda = cuda + self.simplify = simplify + self.state: SDFGState = self.sdfg.add_state() #: the state containing the model computation. + + # Add all values to the SDFG, check for unsupported ops + ########################################## + + self.value_infos = {} + + self.inputs: List[str] = [] #: the inputs to the model + self.outputs: List[str] = [] #: the outputs of the model + + if storage is None: + storage = dtypes.StorageType.GPU_Global if self.cuda else dtypes.StorageType.Default + + for value, is_input in chain(zip(graph.input, repeat(True)), zip(graph.output, repeat(False))): + if not value.HasField("name"): + raise ValueError("Got input or output without name") + if is_input: + self.inputs.append(value.name) + else: + self.outputs.append(value.name) + + self.value_infos[value.name] = value + storage = storage + self._add_value_info(value, storage=storage) + + self.sdfg.arg_names = [clean_onnx_name(i) for i in self.inputs] + + for value in graph.value_info: + if not value.HasField("name"): + raise ValueError("Got input or output without name") + if value.name not in self.value_infos: + self.value_infos[value.name] = value + + # add weights + self.weights: Dict[str, torch.Tensor] = {} #: mapping from weight name to array + for init in graph.initializer: + self._add_constant_tensor(init, storage) + + access_nodes = {} + self._idx_to_node = [] + for i, node in enumerate(graph.node): + if not has_onnx_node(node.op_type): + raise ValueError("Unsupported ONNX operator: '{}'".format(node.op_type)) + + # extract the op attributes + op_attributes = { + attribute_proto.name: convert_attribute_proto(attribute_proto) + for attribute_proto in node.attribute + } + + if node.op_type == "Constant": + # Add constants to weights immediately + possible_values = [ + "sparse_value", "value", "value_float", "value_floats", "value_int", "value_ints", "value_string", + "value_strings" + ] + + # do some manual validation here since the node validation will never run + if set(op_attributes).difference(possible_values): + raise ValueError(f"Got unexpected attributes on Constant node " + f"{set(op_attributes).difference(possible_values)}") + + if len(op_attributes) != 1: + raise ValueError("Expected Constant node to have exactly one of its attributes set") + + if len(node.input) != 0 or len(node.output) != 1: + raise ValueError("Expected Constant node to have no inputs and exactly 1 output") + + value_name = next(iter(op_attributes)) + + self._add_constant_tensor((node.output[0], op_attributes[value_name]), storage) + continue + + if node.HasField("name"): + node_name = clean_onnx_name(node.name) + else: + node_name = node.op_type + "_" + str(i) + + # construct the dace node + [opset] = [i for i in model.opset_import if not i.domain] + node_schema = onnx.defs.get_schema(node.op_type, opset.version) + node_version = node_schema.since_version + op_node = get_onnx_node(node.op_type, node_version)(node_name, **op_attributes) + self.state.add_node(op_node) + self._idx_to_node.append(op_node) + + for param_idx, (name, is_input) in chain(enumerate(zip(node.input, repeat(True))), + enumerate(zip(node.output, repeat(False)))): + # Get parameter schema + params = op_node.schema.inputs if is_input else op_node.schema.outputs + params_len = len(params) + + # Determine parameter type and validate + if param_idx >= params_len: + # Variadic parameter beyond schema range + if params[-1].param_type != ONNXParameterType.Variadic: + raise ValueError( + "Expected the last {i_or_o} parameter to be variadic," + " since the {i_or_o} with idx {param_idx} has more parameters than the schema ({params_len})" + .format(i_or_o="input" if is_input else "output", + param_idx=param_idx, + params_len=params_len)) + param_type = ONNXParameterType.Variadic + conn_name = params[-1].name + "__" + str(param_idx - params_len + 1) + else: + param_type = params[param_idx].param_type + if param_type == ONNXParameterType.Variadic: + conn_name = params[param_idx].name + "__0" + else: + conn_name = params[param_idx].name + + # Handle optional parameters + if param_type == ONNXParameterType.Optional and not name: + continue + + # Validate required parameters + if param_type != ONNXParameterType.Optional and not name: + raise ValueError("Required {i_or_o} parameter '{param_name}' is not set".format( + i_or_o="input" if is_input else "output", param_name=params[param_idx].name)) + + # Create array if needed + if clean_onnx_name(name) not in self.sdfg.arrays: + if name not in self.value_infos: + raise ValueError("Could not find array with name '{}'".format(name)) + self._add_value_info(self.value_infos[name]) + + # Get or create access node + if name in access_nodes: + access = access_nodes[name] + else: + access = nodes.AccessNode(clean_onnx_name(name)) + self.state.add_node(access) + access_nodes[name] = access + + data_desc = self.sdfg.arrays[clean_onnx_name(name)] + + # Add connector and edge + if is_input: + if conn_name not in op_node.in_connectors: + assert op_node.add_in_connector(conn_name) + self.state.add_edge(access, None, op_node, conn_name, + dace.Memlet.from_array(clean_onnx_name(name), data_desc)) + else: + if conn_name not in op_node.out_connectors: + assert op_node.add_out_connector(conn_name) + self.state.add_edge(op_node, conn_name, access, None, + dace.Memlet.from_array(clean_onnx_name(name), data_desc)) + + # scalars need to be promoted to arrays so that we can return them from the dace program + # however, this is only for CPU: on GPU, scalars are already pointers + self._promoted_scalars = set() + + # insert copies from outputs to __return arrays + copy_out_state = self.sdfg.add_state_after(self.state, label='copy_out') + new_output_names = [] + for i, output in enumerate(self.outputs): + clean_name = clean_onnx_name(output) + new_output_name = '__return' + if len(self.outputs) > 1: + new_output_name += '_' + str(i) + new_output_names.append(new_output_name) + + desc = copy.deepcopy(self.sdfg.arrays[clean_name]) + if isinstance(desc, dt.Scalar) and not self.cuda: + desc = dt.Array(desc.dtype, (1, )) + self._promoted_scalars.add(new_output_name) + + # insert new descriptor + self.sdfg.arrays[new_output_name] = desc + desc.transient = False + + copy_out_state.add_edge(copy_out_state.add_read(clean_name), None, + copy_out_state.add_write(new_output_name), None, + self.sdfg.make_array_memlet(clean_name)) + + # finally, rename outputs, and fuse states + self.outputs = new_output_names + sdfg_utils.fuse_states(self.sdfg) + + if self.cuda: + self.sdfg.apply_gpu_transformations() + + def _add_constant_tensor(self, tensor: Union[onnx.TensorProto, Tuple[str, np.ndarray]], + storage: dtypes.StorageType): + if isinstance(tensor, tuple): + unclean_name, value = tensor + dtype = dtypes.dtype_to_typeclass(value.dtype.type) + shape = value.shape + np_array = value + else: + if not tensor.HasField("name"): + raise ValueError("Got tensor without name") + + if not tensor.HasField("data_type"): + raise ValueError("Initializer tensor '{}' has no type".format(tensor.name)) + unclean_name = tensor.name + dtype = onnx_tensor_type_to_typeclass(tensor.data_type) + shape = [d for d in tensor.dims] + np_array = numpy_helper.to_array(tensor) + + name = clean_onnx_name(unclean_name) + if unclean_name in self.inputs: + # remove the tensor from inputs since this is a constant + self.inputs.remove(unclean_name) + # note: inputs already have data-descriptors created for them, so + # we skip the below code + elif len(shape) == 0: + # this is a scalar + self.sdfg.add_scalar(name, dtype, storage=storage) + else: + if name not in self.sdfg.arrays: + self.sdfg.add_array(name, shape, dtype, storage=storage, transient=False) + else: + existing_arr = self.sdfg.arrays[name] + if existing_arr.dtype != dtype: + raise ValueError( + "Invalid ONNX model; found two values with name '{}', but different dtypes ({} and {})".format( + name, existing_arr.dtype, dtype)) + if tuple(existing_arr.shape) != tuple(shape): + raise ValueError( + "Invalid ONNX model; found two values with name '{}', but different dimensions ({} and {})". + format(name, existing_arr.shape, shape)) + + # we need to copy here because the weight_arr tensor is not writable + self.weights[unclean_name] = torch.from_numpy(np_array.copy()) + + def _add_value_info(self, value_info: onnx.ValueInfoProto, storage=dtypes.StorageType.Default): + if not value_info.HasField("name"): + raise ValueError("Got value without name") + + name = value_info.name + + if not _nested_HasField(value_info, "type.tensor_type.shape"): + raise ValueError("Value '{}' does not have a shape in this graph." + " Please run shape inference before importing.".format(name)) + + tensor_type = value_info.type.tensor_type + + if not tensor_type.HasField("elem_type"): + raise ValueError("Value '{}' does not have a type in this graph." + " Please run type inference before importing.".format(name)) + + shape = [] + for d in tensor_type.shape.dim: + if d.HasField("dim_value"): + shape.append(d.dim_value) + elif d.HasField("dim_param"): + parsed = pystr_to_symbolic(d.dim_param) + + for sym in parsed.free_symbols: + if clean_onnx_name(str(sym)) not in self.sdfg.symbols: + self.sdfg.add_symbol(clean_onnx_name(str(sym)), stype=int) + parsed = parsed.subs(sym, dace.symbol(clean_onnx_name(str(sym)))) + + shape.append(parsed) + else: + raise ValueError("Value '{}' does not have a shape in this graph." + " Please run shape inference before importing.".format(name)) + transient = name not in self.inputs + if len(shape) == 0: + self.sdfg.add_scalar(clean_onnx_name(name), + dtype=onnx_tensor_type_to_typeclass(tensor_type.elem_type), + transient=transient, + storage=storage) + else: + self.sdfg.add_array(clean_onnx_name(name), + shape=shape, + dtype=onnx_tensor_type_to_typeclass(tensor_type.elem_type), + transient=transient, + storage=storage) + + @property + def clean_weights(self): + return {clean_onnx_name(k): v for k, v in self.weights.items()} + + def compile_and_init(self) -> compiled_sdfg.CompiledSDFG: + """ Compile the SDFG and load parameters into GPU memory. """ + + compiled_sdfg = self.sdfg.compile() + + # copy all parameters to the device + self.initialized_parameters = {} + for name, arr in self.weights.items(): + if clean_onnx_name(name) in compiled_sdfg.sdfg.arrays: + desc = self.sdfg.arrays[clean_onnx_name(name)] + cuda = desc.storage in dace.dtypes.GPU_STORAGES + if type(desc) is dt.Scalar: + self.initialized_parameters[clean_onnx_name(name)] = arr.cuda() if cuda else arr.cpu().numpy()[()] + else: + self.initialized_parameters[clean_onnx_name(name)] = arr.cuda() if cuda else arr + + return compiled_sdfg + + def __call__(self, *args, + **kwargs) -> Union[Union[torch.Tensor, np.ndarray], Tuple[Union[torch.Tensor, np.ndarray]]]: + """ Execute the model. + + :param args: positional arguments to the model. The i-th argument will be passed as the i-th input of the + model. + :param kwargs: named arguments to the model. The passed names should match the names in the ONNX model. + :return: the output of the model (or a tuple of outputs if there are multiple). + """ + + transient_kwargs = {} + if self.save_transients is not None: + for node, parent in self.sdfg.all_nodes_recursive(): + if isinstance(node, nodes.AccessNode): + desc = self.sdfg.arrays[node.data] + if not isinstance(desc, dt.View) and desc.transient: + desc.transient = False + transient_kwargs[node.data] = desc + + if self.do_auto_optimize: + self.auto_optimize() + + compiled = self.compile_and_init() + + inputs, symbols, outputs = self._call_args(args=args, kwargs=kwargs) + + for name, desc in transient_kwargs.items(): + if name in self.initialized_parameters: + transient_kwargs[name] = self.initialized_parameters[name] + self.initialized_parameters.pop(name) + else: + transient_kwargs[name] = create_output_array(symbols, desc, use_torch=True, zeros=True) + self.save_transients[name] = transient_kwargs[name] + + compiled(**inputs, **outputs, **self.initialized_parameters, **symbols, **transient_kwargs) + + # demote scalars we promoted above + for scalar in self._promoted_scalars: + outputs[scalar] = outputs[scalar].reshape(()) + + if len(outputs) == 1: + return next(iter(outputs.values())) + + return tuple(outputs.values()) + + def _call_args(self, + *, + args, + kwargs, + torch_outputs: bool = None) -> Tuple[Dict[str, Any], Dict[str, Any], OrderedDict[str, Any]]: + """ Prepare the arguments for a call. + + This returns 4 dicts; one for each of the following: + 1. the inputs + 3. inferred values for symbols for dynamic dimensions + 4. outputs + + These arguments can be passed to `self.sdfg`. + + :param args: model positional args + :param kwargs: model kwargs + :param torch_outputs: if not None, the outputs will be torch tensors depending on the boolean value. + Otherwise the outputs will be torch tensors only if at least one of the inputs is a + torch tensor. + :return: the tuple of dicts + """ + inputs = kwargs + + # convert the positional args to kwargs + if len(args) > len(self.inputs): + raise ValueError("Expected {} arguments, got {}".format(len(self.inputs), len(args))) + + inputs.update(dict(zip(self.inputs, args))) + + # check that there are no missing inputs + if len(set(self.inputs).difference(inputs)) != 0: + raise ValueError("Missing inputs {}".format(", ".join(set(self.inputs).difference(inputs)))) + + # check that there are no unknown inputs + # NOTE symbols can only be passed as kwargs + if len(set(inputs).difference(self.inputs).difference(self.sdfg.free_symbols)) != 0: + raise ValueError("Unknown inputs {}".format(", ".join(set(inputs).difference(self.inputs)))) + + clean_inputs = {} + for input, arr in inputs.items(): + if input in self.sdfg.free_symbols: + clean_inputs[input] = arr + else: + clean_inputs[clean_onnx_name(input)] = arr + + inferred_symbols = parser.infer_symbols_from_datadescriptor(self.sdfg, { + **clean_inputs, + **self.initialized_parameters + }) + inferred_symbols = {k: int(v) for k, v in inferred_symbols.items()} + + if torch_outputs is None: + torch_outputs = any(self.sdfg.arrays[clean_onnx_name(o)].storage in dace.dtypes.GPU_STORAGES + for o in self.outputs) or any( + isinstance(inp, torch.Tensor) for _, inp in clean_inputs.items()) + + outputs = collections.OrderedDict() + # create numpy arrays for the outputs + for name in self.outputs: + clean_name = clean_onnx_name(name) + outputs[clean_name] = create_output_array(inferred_symbols, + self.sdfg.arrays[clean_name], + use_torch=torch_outputs, + zeros=True) + + # check that there's no overlap + seen = set() + for parameters in [clean_inputs, self.initialized_parameters, outputs, inferred_symbols]: + new_parameters = set(parameters) + assert not seen.intersection(new_parameters) + seen |= new_parameters + + return clean_inputs, inferred_symbols, outputs + + def expand_onnx_nodes(self): + onnx_node_expander(self.sdfg) + + def auto_optimize(self): + auto_opt( + self.sdfg, + self.cuda, + simplify=self.simplify, + # constants have been folded before GPU transforms + fold_constants=False) + + +def create_output_array(inferred_symbols: Dict[str, int], + desc: dt.Data, + use_torch=False, + zeros: bool = False) -> Union[np.ndarray, torch.tensor]: + """ Create the array for an output. This is either a numpy array or a torch tensor depending on `use_torch` + + When `self.force_torch_outputs` is True, the outputs will be tensors. Otherwise, the outputs will be tensors + :param inferred_symbols: the symbols inferred from `infer_symbols_from_datadescriptor`. + :param desc: the data descriptor for the array + :param use_torch: whether to return a numpy array or a torch tensor. + :param zeros: if true init with zeros else empty. + """ + + def eval_dim(dim): + for sym in dim.free_symbols: + dim = dim.subs(sym, inferred_symbols[sym.name]) + return dim + + cuda = desc.storage in dace.dtypes.GPU_STORAGES + if cuda and not use_torch: + raise ValueError("Got use_torch=False, but received a GPU descriptor") + + if isinstance(desc, dt.Scalar): + shape = [] + else: + shape = [eval_dim(d) if type(d) is dace.symbol else d for d in desc.shape] + + if use_torch: + # torch functions don't accept the empty shape, so create shape [1] then reshape to () + if len(shape) == 0: + shape = [1] + + # as_numpy_dtype doesn't seem to work for indexing into the dict + if desc.dtype == dace.pointer(dace.typeclass(None)): + # assuming 64 bit ptrs + dtype = torch.int64 + else: + dtype = numpy_to_torch_dtype_dict[getattr(np, desc.dtype.to_string())] + tens = (torch.zeros if zeros else torch.empty)(*shape, dtype=dtype) + if isinstance(desc, dt.Scalar): + tens = tens.reshape(()) + + return tens.cuda() if cuda else tens + else: + return (np.zeros if zeros else np.empty)(shape, dtype=getattr(np, desc.dtype.to_string())) diff --git a/dace/frontend/tensorflow/__init__.py b/dace/frontend/ml/tensorflow/__init__.py similarity index 100% rename from dace/frontend/tensorflow/__init__.py rename to dace/frontend/ml/tensorflow/__init__.py diff --git a/dace/frontend/tensorflow/tensorflow.py b/dace/frontend/ml/tensorflow/tensorflow.py similarity index 99% rename from dace/frontend/tensorflow/tensorflow.py rename to dace/frontend/ml/tensorflow/tensorflow.py index af71493214..ef6cfdb409 100644 --- a/dace/frontend/tensorflow/tensorflow.py +++ b/dace/frontend/ml/tensorflow/tensorflow.py @@ -17,8 +17,8 @@ from dace.data import Scalar from dace.sdfg.nodes import Tasklet, NestedSDFG from dace.symbolic import symstr, SymExpr -from dace.frontend.tensorflow.winograd import winograd_convolution -from dace.frontend.tensorflow.transformations.redundant_array import (TensorflowRedundantArray) +from .winograd import winograd_convolution +from .transformations.redundant_array import TensorflowRedundantArray try: import tensorflow as tf diff --git a/dace/frontend/tensorflow/transformations/__init__.py b/dace/frontend/ml/tensorflow/transformations/__init__.py similarity index 100% rename from dace/frontend/tensorflow/transformations/__init__.py rename to dace/frontend/ml/tensorflow/transformations/__init__.py diff --git a/dace/frontend/tensorflow/transformations/redundant_array.py b/dace/frontend/ml/tensorflow/transformations/redundant_array.py similarity index 100% rename from dace/frontend/tensorflow/transformations/redundant_array.py rename to dace/frontend/ml/tensorflow/transformations/redundant_array.py diff --git a/dace/frontend/tensorflow/winograd.py b/dace/frontend/ml/tensorflow/winograd.py similarity index 100% rename from dace/frontend/tensorflow/winograd.py rename to dace/frontend/ml/tensorflow/winograd.py diff --git a/dace/frontend/ml/torch/__init__.py b/dace/frontend/ml/torch/__init__.py new file mode 100644 index 0000000000..d1563fac13 --- /dev/null +++ b/dace/frontend/ml/torch/__init__.py @@ -0,0 +1,6 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +from .module import DaceModule +from .interface import module + +__all__ = ['DaceModule', 'module'] diff --git a/dace/frontend/ml/torch/interface.py b/dace/frontend/ml/torch/interface.py new file mode 100644 index 0000000000..6dc1f68d1f --- /dev/null +++ b/dace/frontend/ml/torch/interface.py @@ -0,0 +1,89 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Python interface for DaCe PyTorch/Torch integration. + +This module provides decorators and utilities for converting PyTorch modules +to DaCe-accelerated implementations. +""" + +from functools import wraps +from typing import Optional, Tuple, List + +from dace.dtypes import paramdec + + +@paramdec +def module(moduleclass, + dummy_inputs: Optional[Tuple] = None, + cuda: Optional[bool] = None, + training: bool = False, + backward=False, + inputs_to_skip: Optional[List[str]] = None, + onnx_simplify: bool = True, + simplify: bool = True, + auto_optimize: bool = True, + sdfg_name: Optional[str] = None, + compile_torch_extension: bool = True, + debug_transients: bool = False): + """ + Decorator to apply on a definition of a ``torch.nn.Module`` to convert it + to a data-centric module upon construction. + + Example:: + + import dace.ml + import torch.nn as nn + + @dace.ml.module + class MyDecoratedModule(nn.Module): + def forward(self, x): + x = torch.log(x) + x = torch.sqrt(x) + return x + + module_instance = MyDecoratedModule() + module_instance(torch.ones(2)) # tensor([0., 0.]) + + .. Note:: + You must import ``dace.ml`` (not just ``dace``) to use this decorator. + + :param moduleclass: The model to wrap. + :param dummy_inputs: A tuple of tensors to use as input when tracing the model. + :param cuda: If ``True``, the module will execute using CUDA. + If ``None``, it will be detected from the module. + :param training: Whether to use train mode when tracing the model. + :param backward: Whether to enable the backward pass. + :param inputs_to_skip: If provided, a list of inputs to skip computing gradients for + (only relevant when the backward pass is enabled). + :param onnx_simplify: Whether to apply ONNX simplification using onnxsim. + :param simplify: Whether to apply simplification transforms after conversion. + This generally improves performance but can be slow. + :param auto_optimize: Whether to apply automatic optimizations. + :param sdfg_name: The name to give to the SDFG (defaults to moduleclass name). + :param compile_torch_extension: If ``True``, a torch C++ extension will be compiled + and used for this module. Otherwise, a Python ctypes implementation will be used. + :param debug_transients: If ``True``, the module will have all transients as outputs. + """ + wraps(moduleclass) + + def _create(*args, **kwargs): + # Lazy import DaceModule when decorator is actually used + try: + from dace.frontend.ml.torch import DaceModule + except ImportError: + raise ImportError("DaceModule requires PyTorch. Install with: pip install torch") + + return DaceModule(moduleclass(*args, **kwargs), + dummy_inputs=dummy_inputs, + cuda=cuda, + training=training, + backward=backward, + inputs_to_skip=inputs_to_skip, + onnx_simplify=onnx_simplify, + simplify=simplify, + auto_optimize=auto_optimize, + sdfg_name=sdfg_name, + compile_torch_extension=compile_torch_extension, + debug_transients=debug_transients) + + return _create diff --git a/dace/frontend/ml/torch/module.py b/dace/frontend/ml/torch/module.py new file mode 100644 index 0000000000..ba820faf87 --- /dev/null +++ b/dace/frontend/ml/torch/module.py @@ -0,0 +1,581 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" DaCe Python parsing functionality and entry point to Python frontend. """ +from dataclasses import dataclass +import collections +import itertools +import tempfile +import copy +import os +from typing import Any, Callable, Dict, OrderedDict, List, Optional, Set, Sequence, Tuple, Union + +# Try importing ML dependencies +try: + import torch + from torch import Tensor + import torch.nn as nn + from torch.onnx import TrainingMode + TORCH_AVAILABLE = True +except ImportError: + torch = None + Tensor = None + nn = None + TrainingMode = None + TORCH_AVAILABLE = False + +try: + import onnx + ONNX_AVAILABLE = True +except ImportError: + onnx = None + ONNX_AVAILABLE = False + +import dace +from dace import config, data +from dace.codegen import compiled_sdfg +from dace.sdfg import SDFG, nodes +from dace.frontend.python import common as pycommon +from dace.data import find_new_name + +if TORCH_AVAILABLE and ONNX_AVAILABLE: + from dace.libraries.onnx.converters import clean_onnx_name + from dace.libraries.torch import dispatchers + from dace.autodiff import torch as torch_autodiff + from dace.autodiff.library import library as autodiff_library + from dace.frontend.ml.onnx import ONNXModel + from dace.transformation.onnx import auto_optimize_onnx as auto_opt +else: + clean_onnx_name = None + dispatchers = None + torch_autodiff = None + autodiff_library = None + ONNXModel = None + auto_opt = None + +if TORCH_AVAILABLE and ONNX_AVAILABLE: + + def _onnx_delete_initializers(model: onnx.ModelProto, names: Set[str]) -> None: + """ + Delete the given initializers from the given onnx model. + + :param model: The ONNX model to modify. + :param names: Set of initializer names to delete. + :note: Operates in-place. + """ + to_remove = [] + for i, initializer in enumerate(model.graph.initializer): + if initializer.name in names: + to_remove.append(i) + + for i in reversed(to_remove): + model.graph.initializer.pop(i) + + class DaceModule(nn.Module, pycommon.SDFGConvertible): + """ A wrapper that converts a PyTorch ``nn.Module`` to a PyTorch compatible data-centric ``nn.Module``. + + :param module: the model to wrap. + :param dummy_inputs: a tuple of tensors to use as input when tracing ``model``. + :param cuda: if ``True``, the module will execute using CUDA. If ``None``, it will be detected from the + ``module``. + :param training: whether to use train mode when tracing ``model``. + :param backward: whether to enable the backward pass. + :param inputs_to_skip: if provided, a list of inputs to skip computing gradients for. + (only relevant when the backward pass is enabled) + :param onnx_simplify: whether to apply onnx simplification using onnxsim. + :param simplify: whether to apply simplification transforms after conversion (this generally improves performance, + but can be slow). + :param sdfg_name: the name to give to the sdfg (defaults to moduleclass name). + :param auto_optimize: whether to apply automatic optimizations. + :param compile_torch_extension: if True, a torch C++ extension will be compiled and used for this module. + Otherwise, a python ctypes implementation will be used. + :param debug_transients: if True, the module will have all transients as outputs. + + :Example: + >>> from dace.frontend.ml.torch import DaceModule + >>> class MyModule(nn.Module): + ... def forward(self, x): + ... x = torch.log(x) + ... x = torch.sqrt(x) + ... return x + >>> module = MyModule() + >>> module(torch.ones(2)) + tensor([0., 0.]) + >>> dace_module = DaceModule(module) + >>> dace_module(torch.ones(2)) + tensor([0., 0.]) + """ + + def __init__(self, + module: nn.Module, + dummy_inputs: Optional[Tuple[torch.Tensor, ...]] = None, + cuda: Optional[bool] = None, + training: bool = False, + backward: bool = False, + inputs_to_skip: Optional[List[str]] = None, + onnx_simplify: bool = True, + simplify: bool = True, + auto_optimize: bool = False, + debug_transients: bool = False, + compile_torch_extension: bool = True, + sdfg_name: Optional[str] = None): + + super(DaceModule, self).__init__() + + self.backward = backward + self.model = module + self.dace_model: Optional[ONNXModel] = None + self.training = training + self.sdfg: Optional[SDFG] = None + self.use_cuda = cuda + self.sdfg_name = sdfg_name or type(module).__name__ + self.auto_optimize = auto_optimize + self.onnx_simplify = onnx_simplify + self.simplify = simplify + self.debug_transients = debug_transients + self.compile_torch_extension = compile_torch_extension + self.inputs_to_skip = inputs_to_skip or [] + + self.function = None + + #: hooks that are executed after onnx graph is imported to an SDFG + self.post_onnx_hooks: OrderedDict[str, Callable[[DaceModule], None]] = collections.OrderedDict() + + #: hooks that are executed after the backpropagation sdfg has been created + self.post_autodiff_hooks: OrderedDict[str, Callable[[SDFG, SDFG], None]] = collections.OrderedDict() + + #: hooks that are executed after the sdfg is compiled + self.post_compile_hooks: OrderedDict[str, Callable[[compiled_sdfg.CompiledSDFG], + None]] = collections.OrderedDict() + # setup debug hook + if self.debug_transients: + + def transients_outputs(module): + for state in module.sdfg.nodes(): + for node in state.nodes(): + if (isinstance(node, nodes.AccessNode) and node.desc(module.sdfg).transient + and not isinstance(node.desc(module.sdfg), data.Scalar)): + if "mean" not in node.data and "std" not in node.data: + module.dace_model.outputs.append(node.data) + node.desc(module.sdfg).transient = False + + self.prepend_post_onnx_hook("make_transients_outputs", transients_outputs) + + # setup optimization hooks + if self.auto_optimize: + if self.backward: + + def auto_optimize_backward(fwd_sdfg, bwd_sdfg): + auto_opt(fwd_sdfg, self.use_cuda, simplify=self.simplify) + auto_opt(bwd_sdfg, self.use_cuda, simplify=self.simplify) + + self.append_post_autodiff_hook("auto_optimize", auto_optimize_backward) + else: + self.append_post_onnx_hook( + "auto_optimize", lambda dace_module: auto_opt( + dace_module.dace_model.sdfg, self.use_cuda, simplify=self.simplify)) + elif self.simplify: + if self.backward: + + def simplify_hook(fwd_sdfg, bwd_sdfg): + fwd_sdfg.simplify() + bwd_sdfg.simplify() + + self.append_post_autodiff_hook("simplify", simplify_hook) + else: + self.append_post_onnx_hook("simplify", lambda dace_module: dace_module.sdfg.simplify()) + + if dummy_inputs is not None: + self.function = self._initialize_sdfg(dummy_inputs) + + def reset_sdfg(self) -> None: + """Clear the SDFG so that optimizations are reapplied.""" + self.function = None + + def _detect_cuda_usage(self, dummy_inputs) -> bool: + """ + Detect whether CUDA should be used based on inputs and model parameters. + + :param dummy_inputs: Tuple of tensors to check. + :return: True if CUDA should be used, False otherwise. + """ + try: + module_is_cuda = next(iter(dummy_inputs)).is_cuda + except StopIteration: + module_is_cuda = False + + if not module_is_cuda: + # check the parameters + try: + module_is_cuda = next(self.model.parameters()).is_cuda + except StopIteration: + module_is_cuda = False + return module_is_cuda + + def prepend_post_onnx_hook(self, name: str, func: Callable[["DaceModule"], None]) -> None: + """ + Add a hook to be executed after ONNX graph import, at the beginning of the hook list. + + :param name: Name of the hook (will be made unique if necessary). + :param func: Callable to execute after ONNX import. + """ + if self.function is not None: + if config.Config.get_bool('debugprint'): + print(f"Warning: Added a hook after the model was already initialized. This hook " + f"(with name {name}) will not be executed!") + name = find_new_name(name, self.post_onnx_hooks) + self.post_onnx_hooks[name] = func + self.post_onnx_hooks.move_to_end(name, last=False) + + def append_post_onnx_hook(self, name: str, func: Callable[["DaceModule"], None]) -> None: + """ + Add a hook to be executed after ONNX graph import, at the end of the hook list. + + :param name: Name of the hook (will be made unique if necessary). + :param func: Callable to execute after ONNX import. + """ + if self.function is not None: + if config.Config.get_bool('debugprint'): + print(f"Warning: Added a hook after the model was already initialized. This hook " + f"(with name {name}) will not be executed!") + name = find_new_name(name, self.post_onnx_hooks) + self.post_onnx_hooks[name] = func + + def prepend_post_autodiff_hook(self, name: str, func: Callable[[SDFG, SDFG], None]) -> None: + """ + Add a hook to be executed after autodiff, at the beginning of the hook list. + + :param name: Name of the hook (will be made unique if necessary). + :param func: Callable to execute after autodiff, receiving forward and backward SDFGs. + """ + if self.function is not None: + if config.Config.get_bool('debugprint'): + print(f"Warning: Added a hook after the model was already initialized. This hook " + f"(with name {name}) will not be executed!") + name = find_new_name(name, self.post_autodiff_hooks) + self.post_autodiff_hooks[name] = func + self.post_autodiff_hooks.move_to_end(name, last=False) + + def append_post_autodiff_hook(self, name: str, func: Callable[[SDFG, SDFG], None]) -> None: + """ + Add a hook to be executed after autodiff, at the end of the hook list. + + :param name: Name of the hook (will be made unique if necessary). + :param func: Callable to execute after autodiff, receiving forward and backward SDFGs. + """ + if self.function is not None: + if config.Config.get_bool('debugprint'): + print(f"Warning: Added a hook after the model was already initialized. This hook " + f"(with name {name}) will not be executed!") + name = find_new_name(name, self.post_autodiff_hooks) + self.post_autodiff_hooks[name] = func + + def prepend_post_compile_hook(self, name: str, func: Callable[[compiled_sdfg.CompiledSDFG], None]) -> None: + """ + Add a hook to be executed after compilation, at the beginning of the hook list. + + :param name: Name of the hook (will be made unique if necessary). + :param func: Callable to execute after compilation, receiving the compiled SDFG. + """ + if self.function is not None: + if config.Config.get_bool('debugprint'): + print(f"Warning: Added a hook after the model was already initialized. This hook " + f"(with name {name}) will not be executed!") + name = find_new_name(name, self.post_compile_hooks) + self.post_compile_hooks[name] = func + self.post_compile_hooks.move_to_end(name, last=False) + + def append_post_compile_hook(self, name: str, func: Callable[[compiled_sdfg.CompiledSDFG], None]) -> None: + """ + Add a hook to be executed after compilation, at the end of the hook list. + + :param name: Name of the hook (will be made unique if necessary). + :param func: Callable to execute after compilation, receiving the compiled SDFG. + """ + if self.function is not None: + if config.Config.get_bool('debugprint'): + print(f"Warning: Added a hook after the model was already initialized. This hook " + f"(with name {name}) will not be executed!") + name = find_new_name(name, self.post_compile_hooks) + self.post_compile_hooks[name] = func + + def _initialize_sdfg(self, dummy_inputs): + """ + Initialize the SDFG by converting the PyTorch module to ONNX and then to DaCe. + + :param dummy_inputs: Tuple of tensors to use for tracing. + :return: Forward function to be called during execution. + """ + # determine whether we are using CUDA + if self.use_cuda is None: + self.use_cuda = self._detect_cuda_usage(dummy_inputs) + + if self.use_cuda: + self.model = self.model.cuda() + + # TODO change to StringIO if not too big + with tempfile.TemporaryDirectory() as dir_name: + export_name = os.path.join(dir_name, "export.onnx") + + # save the state of the model, and restore it after tracing + state = copy.deepcopy(self.state_dict()) + torch.onnx.export( + self.model, + dummy_inputs, + export_name, + verbose=config.Config.get_bool('debugprint'), + # Some models will require training even when we don't want to train: + # when training is set to EVAL, pytorch currently performs an optimization pass ("onnx_eval_peephole") + # that renames weights and thus breaks the model in some settings. + training=(TrainingMode.TRAINING if self.training else TrainingMode.EVAL), + opset_version=18, + export_params=not self.backward, + # pytorch constant folding will add new unnamed inputs to the graph and remove some of the + # named parameters of the model: this means that we can't match with the state dict + # anymore, so we disable this. Our CF is more flexible. + do_constant_folding=False, + keep_initializers_as_inputs=True, + dynamo=False) + self.load_state_dict(state) + onnx_model_exported = onnx.load(export_name) + + # Remove buffers and parameters from initializers + # they should already be in the inputs (from the pytorch exporter) + # this prevents onnx tools from messing with parameters + input_names = set() + for name, _ in itertools.chain(self.named_parameters(), self.named_buffers()): + # pytorch adds a "model." prefix here that isn't in the onnx export; + # remove it + if not name.startswith("model."): + raise ValueError("Expected parameter names to start with 'model.'") + input_names.add(name[6:]) + + # save the parameters as they are now for later access + self._exported_parameters = dict( + (n, p) for n, p in itertools.chain(self.model.named_parameters(), self.model.named_buffers())) + + _onnx_delete_initializers(onnx_model_exported, input_names) + + # load using importer + dace_model = ONNXModel(self.sdfg_name, + onnx_model_exported, + onnx_simplify=self.onnx_simplify, + cuda=self.use_cuda, + auto_optimize=self.auto_optimize) + self.sdfg = dace_model.sdfg + self.dace_model = dace_model + + self.sdfg.validate() + + for _, hook in self.post_onnx_hooks.items(): + hook(self) + + # choose the backend that will generate the function to call during + # forward + if self.compile_torch_extension: + function_generator = dispatchers.register_and_compile_torch_extension + else: + function_generator = dispatchers.get_ctypes_dispatcher + + if self.backward: + + # Determine what grads we need + # For now: we want gradients for all inputs that are not pytorch buffers + named_buffers = {n for n, _ in self.model.named_buffers()} + required_gradients = [ + clean_onnx_name(name) for name in self.dace_model.inputs + if name not in named_buffers and name not in self.inputs_to_skip + ] + named_parameters = dict(self.model.named_parameters()) + required_gradients.extend( + clean_onnx_name(name) for name, param in named_parameters.items() if param.requires_grad) + required_gradients = list(set(required_gradients)) + + self.forward_sdfg, self.backward_sdfg, self._ad_result, self._ad_inp_arrs = torch_autodiff.make_backward_function( + dace_model, required_gradients) + + for _, hook in self.post_autodiff_hooks.items(): + hook(self.forward_sdfg, self.backward_sdfg) + self.compiled_function = function_generator(self, dummy_inputs) + else: + self.compiled_function = function_generator(self, dummy_inputs) + + # order the parameters + parameters_to_pass = self._call_params() + + def forward(*args): + return self.compiled_function.function(*self.compiled_function.ptr, *args, *parameters_to_pass) + + return forward + + def _call_params(self) -> Tuple[Union[Tensor, nn.parameter.Parameter], ...]: + """ + Get the parameters that we need to pass to the model, in the correct order. + + :return: Tuple of parameters and buffers in the order expected by the SDFG. + """ + # self.dace_model.inputs contains the buffers, parameters and the inputs. + # We only want the parameters and buffers + model_inputs = self.dace_model.inputs + + # find the index of the first input that is a parameter or buffer + start_idx = 0 + while start_idx < len(model_inputs) and model_inputs[start_idx] not in self._exported_parameters: + start_idx += 1 + + return tuple(self._exported_parameters[i] for i in model_inputs[start_idx:]) + + def forward(self, *actual_inputs): + """ + Execute the forward pass using the traced module. + + :param actual_inputs: Input tensors to the model. + :return: Output tensors from the model. + """ + if self.function is None: + self.function = self._initialize_sdfg(actual_inputs) + + return self.function(*actual_inputs) + + # SDFGConvertible methods: + # used when the model is called in a DaceProgram. + ################################################# + + def __sdfg__(self, *args): + """ + Get the SDFG representation of this module (SDFGConvertible interface). + + :param args: Arguments (currently unused). + :return: The SDFG representation. + :raises ValueError: If the model has not been initialized yet. + """ + if self.sdfg is None: + raise ValueError("Using a PyTorch model in a DaceProgram requires that the model is initialized first. " + "Either call this model using some inputs, or pass 'dummy_inputs' to the constructor.") + for name, param in self._exported_parameters.items(): + onnx_name = clean_onnx_name(name) + if param.requires_grad: + autodiff_library.ParameterArray.make_parameter(self.sdfg, onnx_name) + return self.sdfg + + def _add_gradient_buffers(self) -> List[str]: + """ + Allocate gradient buffers for all parameters, and add their descriptors to the SDFG. + + :return: a list of the sdfg array names of the gradient buffers + """ + + assert self.sdfg is not None + if hasattr(self, '_gradient_buffers'): + return self._gradient_buffers + + buffers = [] + for name, param in self._exported_parameters.items(): + onnx_name = clean_onnx_name(name) + desc = self.sdfg.arrays[onnx_name] + + if param.requires_grad: + # allocate gradient buffer + param.grad = torch.empty_like(param.data) + + # add gradient buffer descriptor to sdfg + autodiff_library.ParameterArray.make_parameter(self.sdfg, onnx_name) + desc: autodiff_library.ParameterArray = self.sdfg.arrays[onnx_name] + grad_name = desc.add_gradient_buffer(self.sdfg, onnx_name) + grad_desc = self.sdfg.arrays[grad_name] + grad_desc.transient = False + buffers.append(grad_name) + self._gradient_buffers = buffers + return buffers + + def __sdfg_signature__(self): + """ + Get the SDFG signature (SDFGConvertible interface). + + :return: Tuple of (input names, output names). + :raises ValueError: If the SDFG has not been generated yet. + """ + if self.dace_model is None: + raise ValueError("Can't determine signature before SDFG is generated.") + inputs = [clean_onnx_name(name) for name in self.dace_model.inputs] + grad_buffers = self._add_gradient_buffers() + inputs.extend(grad_buffers) + + return inputs, [] + + @staticmethod + def _tensor_from_param(param) -> Tensor: + """ + Extract tensor from parameter while preserving requires_grad flag. + + :param param: PyTorch parameter. + :return: Tensor with correct requires_grad setting. + """ + t = param.data + # Accessing .data on a Parameter resets the requires_grad flag + t.requires_grad = param.requires_grad + return t + + def __sdfg_closure__(self, reevaluate: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + """ + Get the SDFG closure (SDFGConvertible interface). + + :param reevaluate: Optional dictionary for reevaluation (unused). + :return: Dictionary mapping parameter names to their tensor values. + """ + result = {} + for name, param in self._exported_parameters.items(): + onnx_name = clean_onnx_name(name) + result[onnx_name] = self._tensor_from_param(param) + if param.requires_grad: + grad_name = self.sdfg.arrays[onnx_name].gradient + assert grad_name, "Expected gradient descriptor to be present" + assert param.grad is not None, "Expected gradient buffer to be allocated" + result[grad_name] = param.grad + + return result + + def closure_resolver(self, + constant_args: Dict[str, Any], + given_args: Set[str], + parent_closure: Optional[pycommon.SDFGClosure] = None) -> pycommon.SDFGClosure: + """ + Resolve closure for SDFG execution (SDFGConvertible interface). + + :param constant_args: Constant arguments. + :param given_args: Arguments already provided. + :param parent_closure: Optional parent closure. + :return: SDFGClosure object containing closure arrays. + """ + assert self.sdfg is not None, "SDFG must be initialized before resolving closure" + result = pycommon.SDFGClosure() + + class TensorClosure: + """Helper class to wrap tensor access in a callable.""" + + def __init__(self, t): + self.t = t + + def __call__(self): + return self.t + + for name, param in self._exported_parameters.items(): + onnx_name = clean_onnx_name(name) + desc = self.sdfg.arrays[onnx_name] + + if param.requires_grad: + # the gradient was already added when __sdfg_signature__ was called earlier + assert desc.gradient, "Expected gradient descriptor to be present" + grad_name = desc.gradient + # also add the gradient to the closure, because we need to write to it + result.closure_arrays[grad_name] = (grad_name, self.sdfg.arrays[grad_name], + TensorClosure(param.grad), False) + + result.closure_arrays[onnx_name] = (name, desc, TensorClosure(self._tensor_from_param(param)), False) + return result + +else: + # Stub class when ML dependencies are not available + class DaceModule: + """Stub class for DaceModule when PyTorch and ONNX are not installed.""" + + def __init__(self, *args, **kwargs): + raise ImportError("DaceModule requires PyTorch and ONNX. Install with: pip install dace[ml]") diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py index 446079d8f1..18943d6799 100644 --- a/dace/frontend/python/interface.py +++ b/dace/frontend/python/interface.py @@ -3,7 +3,7 @@ import inspect from functools import wraps -from typing import Any, Callable, Deque, Dict, Generator, Optional, Tuple, TypeVar, Union, overload +from typing import Any, Callable, Deque, Dict, Generator, Optional, Tuple, TypeVar, Union, overload, TYPE_CHECKING from dace import dtypes from dace.dtypes import paramdec diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index b5e8c72ea8..724ad16a45 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -33,6 +33,7 @@ from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, FunctionCallRegion, LoopRegion, ControlFlowRegion, NamedRegion) from dace.sdfg.replace import replace_datadesc_names +from dace.sdfg.type_inference import infer_expr_type from dace.symbolic import pystr_to_symbolic, inequal_symbols import numpy @@ -1630,7 +1631,6 @@ def _symbols_from_params(self, params: List[Tuple[str, Union[str, dtypes.typecla object to maintain compatibility with global symbols. Used to maintain typed symbols in SDFG scopes (e.g., map, consume). """ - from dace.codegen.tools.type_inference import infer_expr_type result = {} # Add map inputs first @@ -2011,7 +2011,7 @@ def _parse_map_inputs(self, name: str, params: List[Tuple[str, str]], if symbolic.issymbolic(atom, self.sdfg.constants): # Check for undefined variables atomstr = str(atom) - if atomstr not in self.defined: + if atomstr not in self.defined and atomstr not in self.sdfg.arrays: raise DaceSyntaxError(self, node, 'Undefined variable "%s"' % atom) # Add to global SDFG symbols @@ -2457,8 +2457,6 @@ def visit_For(self, node: ast.For): self._add_dependencies(state, tasklet, me, mx, inputs, outputs, map_inputs, symbols) elif iterator == 'range': # Create an extra typed symbol for the loop iterate - from dace.codegen.tools.type_inference import infer_expr_type - sym_name = indices[0] integer = True nonnegative = None @@ -3245,8 +3243,14 @@ def _add_access( else: var_name = self.get_target_name() - parent_name = self.scope_vars[name] - parent_array = self.scope_arrays[parent_name] + parent_name = self.scope_vars[until(name, '.')] + if '.' in name: + struct_field = name[name.index('.'):] + parent_name += struct_field + scope_ndict = dace.sdfg.NestedDict(self.scope_arrays) + parent_array = scope_ndict[parent_name] + else: + parent_array = self.scope_arrays[parent_name] has_indirection = (_subset_has_indirection(rng, self) or _subset_is_local_symbol_dependent(rng, self)) strides = list(parent_array.strides) @@ -3419,7 +3423,7 @@ def _add_write_access(self, return self.accesses[(name, rng, 'w')] elif name in self.variables: return (self.variables[name], rng) - elif (name, rng, 'r') in self.accesses or name in self.scope_vars: + elif (name, rng, 'r') in self.accesses or until(name, '.') in self.scope_vars: return self._add_access(name, rng, 'w', target, new_name, arr_type) else: raise NotImplementedError @@ -3527,8 +3531,10 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): while isinstance(last_subscript.value, ast.Subscript): last_subscript = last_subscript.value if isinstance(target, ast.Subscript) and not isinstance(last_subscript.value, ast.Name): - store_target = copy.copy(last_subscript.value) - store_target.ctx = ast.Store() + store_target = astutils.copy_tree(last_subscript.value) + for n in ast.walk(store_target): # Recursively make attributes into stores + if hasattr(n, 'ctx'): + n.ctx = ast.Store() true_name = self.visit(store_target) # Refresh defined variables and arrays defined_vars = {**self.variables, **self.scope_vars} @@ -3736,7 +3742,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): raise IndexError('Boolean array indexing cannot be combined with indirect access') if self.nested and not new_data and not visited_target: - new_name, new_rng = self._add_write_access(name, rng, target) + new_name, new_rng = self._add_write_access(true_name, rng, target) # Local symbol or local data dependent if _subset_is_local_symbol_dependent(rng, self): new_rng = rng diff --git a/dace/frontend/python/replacements/torch_autodiff.py b/dace/frontend/python/replacements/torch_autodiff.py new file mode 100644 index 0000000000..d110046be9 --- /dev/null +++ b/dace/frontend/python/replacements/torch_autodiff.py @@ -0,0 +1,163 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Integration with the dace python frontend +""" + +from contextlib import contextmanager +from typing import Optional, Union, Sequence +import itertools +import warnings + +import torch +import torch.autograd + +from dace import SDFG, SDFGState, config, data +import dace.sdfg.sdfg +from dace.transformation import optimizer +from dace.frontend.python import common +from dace.frontend.common import op_repository +from dace.frontend.python import newast +from dace.transformation.passes.fusion_inline import InlineControlFlowRegions +from dace.data import find_new_name +from dace.sdfg.utils import expand_nodes +from dace.libraries.onnx.op_implementations.common import iterables_equal +from dace.autodiff import analysis as autodiff_analysis + +from dace.autodiff.library.library import ParameterArray, BackwardPass + +TensorOrTensors = Union[str, Sequence[str]] + + +@op_repository.replaces('torch.autograd.backward') +def backward(pv: newast.ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + tensors: TensorOrTensors, + grads: Optional[TensorOrTensors] = None): + """ + Adds a backward pass node to the SDFG. + + This function analyses the dependency tree of the tensors and computes + gradients for each Parameter that was used to compute the tensors. + """ + + # First, remove function call regions + transformation = InlineControlFlowRegions() + transformation.set_opts({ + 'no_inline_function_call_regions': False, + 'no_inline_named_regions': False, + 'no_inline_loops': True, + 'no_inline_conditional': True + }) + transformation.apply_pass(sdfg, {}) + + if isinstance(tensors, str): + tensors = [tensors] + + if isinstance(grads, str): + grads = [grads] + + if grads is None: + grads = [] + # when the tensors are scalars, we can implicity create the grads with ones + for tensor in tensors: + tensor_desc = sdfg.arrays[tensor] + if tensor_desc.total_size == 1: + constant_name = sdfg._find_new_name("one") + desc = data.Scalar(tensor_desc.dtype, transient=True, storage=tensor_desc.storage) + sdfg.add_constant(constant_name, 1, dtype=desc) + sdfg.arrays[constant_name] = desc + grads.append(constant_name) + else: + raise common.DaceSyntaxError(pv, None, "grad can be implicitly created only for scalar outputs") + + if len(grads) != len(tensors): + raise common.DaceSyntaxError(pv, None, "grads and tensors must correspond, but they were not the same length") + + for grad, tensor in zip(grads, tensors): + if grad not in sdfg.arrays and grad not in sdfg.constants_prop: + raise common.DaceSyntaxError(pv, None, "Gradient {} is not an array".format(grad)) + if tensor not in sdfg.arrays: + raise common.DaceSyntaxError(pv, None, "Tensor {} is not an array".format(tensor)) + + grad_desc = sdfg.arrays[grad] if grad in sdfg.arrays else sdfg.constants_prop[grad][0] + + if not iterables_equal(grad_desc.shape, sdfg.arrays[tensor].shape): + raise common.DaceSyntaxError(pv, None, + "Gradient {} and tensor {} have different shapes".format(grad, tensor)) + + given_gradients = dict(zip(grads, tensors)) + + bwd_node = BackwardPass('backward', + inputs=set(itertools.chain(tensors, grads)), + outputs=set(), + given_gradients=given_gradients) + state.add_node(bwd_node) + + for inp in itertools.chain(tensors, grads): + state.add_edge(state.add_read(inp), None, bwd_node, inp, sdfg.make_array_memlet(inp)) + + # determine what grdaients to compute + dependencies = autodiff_analysis.dependency_analysis(sdfg) + + to_compute = { + dependency + for tensor in tensors + for dependency in dependencies[tensor] if isinstance(sdfg.arrays[dependency], ParameterArray) + } + + for param in to_compute: + param_desc: ParameterArray = sdfg.arrays[param] + grad_name = param_desc.add_gradient_buffer(sdfg, param) + + conn_name = find_new_name(grad_name, bwd_node.out_connectors) + bwd_node.required_gradients[param] = conn_name + bwd_node.add_out_connector(conn_name) + write_an = state.add_write(grad_name) + write_an.setzero = True + state.add_edge(bwd_node, conn_name, write_an, None, sdfg.make_array_memlet(grad_name)) + + +@op_repository.replaces_attribute('ParameterArray', 'grad') +def grad(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str) -> str: + """ + Returns the name of the gradient buffer of the given array. + + The Array must have been marked as requires_grad_ using + ``arr.requires_grad_()``, otherwise there will be an error + """ + + if arr not in sdfg.arrays: + raise common.DaceSyntaxError(pv, None, "Array {} is not defined".format(arr)) + desc = sdfg.arrays[arr] + if not isinstance(desc, ParameterArray): + raise common.DaceSyntaxError( + pv, None, "Called .grad on an Array that was not a Parameter. Convert it to a parameter " + " first using .requires_grad_()") + + return desc.gradient + + +@op_repository.replaces_method('Array', 'requires_grad_') +@op_repository.replaces_method('Scalar', 'requires_grad_') +def requires_grad_(pv: newast.ProgramVisitor, sdfg: SDFG, state: SDFGState, self: str): + """ + Converts a array to a ParameterArray. This creates a descriptor for + the gradient buffer for this array. + """ + + if self not in sdfg.arrays: + raise common.DaceSyntaxError(pv, None, "Array {} is not defined".format(self)) + ParameterArray.make_parameter(sdfg, self) + + +@op_repository.replaces_method('Array', 'backward') +@op_repository.replaces_method('Scalar', 'backward') +def backward_method(pv: newast.ProgramVisitor, sdfg: SDFG, state: SDFGState, self: str, grad: Optional[str] = None): + """ + Alias for ``torch.autograd.backward(self)`` + """ + backward(pv, sdfg, state, self, grad) + + +dace.hooks.register_sdfg_call_hook(before_hook=lambda sdfg: expand_nodes(sdfg, lambda n: isinstance(n, BackwardPass))) diff --git a/dace/frontend/python/replacements/utils.py b/dace/frontend/python/replacements/utils.py index 92711fdc9d..0d04591bb6 100644 --- a/dace/frontend/python/replacements/utils.py +++ b/dace/frontend/python/replacements/utils.py @@ -154,13 +154,13 @@ def get_idx(i): all_idx_dict[get_idx(i)] = dim1 # if unidirectional, this is not allowed - elif dim1 == None and not unidirectional: + elif dim1 is None and not unidirectional: # dim2 != None must hold here a2_idx.append(get_idx(i)) all_idx_dict[get_idx(i)] = dim2 - elif dim2 == None: + elif dim2 is None: # dim1 != None must hold here a1_idx.append(get_idx(i)) diff --git a/dace/libraries/blas/blas_helpers.py b/dace/libraries/blas/blas_helpers.py index 6a568f6e4a..42a5c3287b 100644 --- a/dace/libraries/blas/blas_helpers.py +++ b/dace/libraries/blas/blas_helpers.py @@ -181,7 +181,7 @@ def get_gemm_opts(a_strides, b_strides, c_strides) -> Dict[str, Any]: }, } - if sAM == 1: + if sAM == 1 and sAK != 1: optA = 'm' elif sAK == 1: optA = 'k' diff --git a/dace/libraries/blas/nodes/axpy.py b/dace/libraries/blas/nodes/axpy.py index 30efa6a2d1..4c6a02906f 100644 --- a/dace/libraries/blas/nodes/axpy.py +++ b/dace/libraries/blas/nodes/axpy.py @@ -1,10 +1,10 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. import dace.library import dace.properties import dace.sdfg.nodes from dace.transformation.transformation import ExpandTransformation from dace.libraries.blas import environments -from dace import (config, data as dt, dtypes, memlet as mm, SDFG, SDFGState, symbolic) +from dace import config, data as dt, dtypes, memlet as mm, SDFG, SDFGState, symbolic from dace.frontend.common import op_repository as oprepo @@ -22,8 +22,7 @@ def expansion(node, parent_state: SDFGState, parent_sdfg, schedule=dace.Schedule :param node: Node to expand. :param parent_state: State that the node is in. :param parent_sdfg: SDFG that the node is in. - :param schedule: The schedule to set on maps in the expansion. For FPGA - expansion, this should be set to FPGA_Device. + :param schedule: The schedule to set on maps in the expansion. """ node.validate(parent_sdfg, parent_state) @@ -82,29 +81,6 @@ def expansion(node, parent_state: SDFGState, parent_sdfg, schedule=dace.Schedule return axpy_sdfg -@dace.library.expansion -class ExpandAxpyFpga(ExpandTransformation): - """ - FPGA expansion which uses the generic implementation, but sets the map - schedule to be executed on FPGA. - """ - - environments = [] - - @staticmethod - def expansion(node, parent_state: SDFGState, parent_sdfg: SDFG, **kwargs): - """ - :param node: Node to expand. - :param parent_state: State that the node is in. - :param parent_sdfg: SDFG that the node is in. - """ - return ExpandAxpyVectorized.expansion(node, - parent_state, - parent_sdfg, - schedule=dace.ScheduleType.FPGA_Device, - **kwargs) - - @dace.library.node class Axpy(dace.sdfg.nodes.LibraryNode): """ @@ -116,7 +92,6 @@ class Axpy(dace.sdfg.nodes.LibraryNode): # Global properties implementations = { "pure": ExpandAxpyVectorized, - "fpga": ExpandAxpyFpga, } default_implementation = None diff --git a/dace/libraries/blas/nodes/dot.py b/dace/libraries/blas/nodes/dot.py index c994504048..03d6e822ff 100644 --- a/dace/libraries/blas/nodes/dot.py +++ b/dace/libraries/blas/nodes/dot.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. import copy import warnings import dace.library @@ -151,345 +151,6 @@ def expansion(node, parent_state, parent_sdfg, n=None, **kwargs): return tasklet -@dace.library.expansion -class ExpandDotFpgaPartialSums(ExpandTransformation): - """ - FPGA-expansion of DOT that does NOT assume that native accumulation of the - data type is possible (e.g., floating point on Xilinx devices or float64 - on Stratix 10). - - To achieve II=1, accumulation is done into multiple partial sums, which are - reduced at the end of the computation. - """ - - environments = [] - - @staticmethod - def expansion(node, parent_state, parent_sdfg, n=None, partial_width=8): - """ - :param node: The node to expand. - :param parent_state: The state that the node is in. - :param parent_sdfg: The SDFG that the node is in. - :param n: Override the vector dimension. If this is not set, the value - specified in the node is used. - :param partial_width: Width of the inner reduction buffer. Must be - larger than the latency of addition on the given - data type. - """ - (desc_x, stride_x), (desc_y, stride_y), desc_res, sz = node.validate(parent_sdfg, parent_state) - - n = n or node.n or sz - - sdfg = dace.SDFG("dot") - - stream_state = sdfg.add_state("stream") - - dtype = desc_x.dtype.base_type - veclen = desc_x.veclen - vtype = dtypes.vector(dtype, veclen) - - desc_x = desc_x.clone() - desc_x.transient = False - desc_y = desc_y.clone() - desc_y.transient = False - desc_res = desc_res.clone() - desc_res.transient = False - sdfg.add_datadesc("_x", desc_x) - sdfg.add_datadesc("_y", desc_y) - sdfg.add_datadesc("_result", desc_res) - - x_read = stream_state.add_read("_x") - y_read = stream_state.add_read("_y") - res_write = stream_state.add_write("_result") - - input_x_name = "input_x" - sdfg.add_array(input_x_name, (1, ), vtype, transient=True, storage=dtypes.StorageType.FPGA_Local) - input_x_access = stream_state.add_access(input_x_name) - - input_y_name = "input_y" - sdfg.add_array(input_y_name, (1, ), vtype, transient=True, storage=dtypes.StorageType.FPGA_Local) - input_y_access = stream_state.add_access(input_y_name) - - entry, exit = stream_state.add_map("stream", {"_i_dot": f"0:{n}/{veclen}"}, - schedule=dtypes.ScheduleType.FPGA_Device) - - index_x = "0" if isinstance(desc_x, dt.Stream) else "_i_dot" - index_y = "0" if isinstance(desc_y, dt.Stream) else "_i_dot" - - stream_state.add_memlet_path(x_read, - entry, - input_x_access, - memlet=dace.Memlet(f"{x_read.data}[{index_x}]", other_subset="0", dynamic=False)) - stream_state.add_memlet_path(y_read, - entry, - input_y_access, - memlet=dace.Memlet(f"{y_read.data}[{index_y}]", other_subset="0", dynamic=False)) - - tasklet = stream_state.add_tasklet("multiply", {"__x", "__y"}, {f"_product": vtype}, f"_product = __x * __y") - - stream_state.add_memlet_path(input_x_access, tasklet, dst_conn="__x", memlet=dace.Memlet(f"{input_x_name}[0]")) - stream_state.add_memlet_path(input_y_access, tasklet, dst_conn="__y", memlet=dace.Memlet(f"{input_y_name}[0]")) - - product_name = "product" - sdfg.add_array(product_name, (veclen, ), dtype, transient=True, storage=dtypes.StorageType.FPGA_Local) - product_access = stream_state.add_access(product_name) - - stream_state.add_memlet_path(tasklet, - product_access, - src_conn="_product", - memlet=dace.Memlet(f"{product_name}[0:{veclen}]")) - - collapse_name = "reduce_vector" - sdfg.add_array(collapse_name, (1, ), dtype, transient=True, storage=dtypes.StorageType.FPGA_Local) - collapse_read = stream_state.add_read(collapse_name) - collapse_access = stream_state.add_access(collapse_name) - - unroll_entry, unroll_exit = stream_state.add_map("unroll", {"_j_dot": f"0:{veclen}"}, - unroll=True, - schedule=dtypes.ScheduleType.FPGA_Device) - - collapse_tasklet = stream_state.add_tasklet( - "reduce_vector", {"val_in", "reduce_in"}, {"reduce_out"}, """\ -prev = reduce_in if _j_dot > 0 else 0 -reduce_out = prev + val_in""") - - stream_state.add_memlet_path(collapse_read, - unroll_entry, - collapse_tasklet, - dst_conn="reduce_in", - memlet=dace.Memlet(f"{collapse_name}[0]")) - stream_state.add_memlet_path(entry, collapse_read, memlet=dace.Memlet()) - stream_state.add_memlet_path(collapse_tasklet, - unroll_exit, - collapse_access, - src_conn="reduce_out", - memlet=dace.Memlet(f"{collapse_name}[0]")) - stream_state.add_memlet_path(product_access, - unroll_entry, - collapse_tasklet, - dst_conn="val_in", - memlet=dace.Memlet(f"{product_name}[_j_dot]")) - - buffer_name = "partial_sums" - sdfg.add_array(buffer_name, (partial_width, ), dtype, transient=True, storage=dtypes.StorageType.FPGA_Local) - - # The partial result buffer must be initialized. - init_tasklet = stream_state.add_tasklet("init_dummy_ps", {}, {"init_data"}, "init_data = 0") - init_ps_entry, init_ps_exit = stream_state.add_map("init_unroll", {"_j_dot": f"0:{partial_width}"}, - unroll=True, - schedule=dtypes.ScheduleType.FPGA_Device) - buffer_read = stream_state.add_access(buffer_name) - stream_state.add_memlet_path(init_ps_entry, init_tasklet, memlet=dace.Memlet()) - stream_state.add_memlet_path(init_tasklet, - init_ps_exit, - buffer_read, - src_conn="init_data", - memlet=dace.Memlet(f"{buffer_name}[_j_dot]")) - - buffer_write = stream_state.add_write(buffer_name) - - partial_sum_tasklet = stream_state.add_tasklet( - "partial_sum", {"result_in", "buffer_in"}, {"buffer_out"}, f"""\ -prev = buffer_in if _i_dot >= {partial_width} else 0 -buffer_out = prev + result_in""") - - stream_state.add_memlet_path(collapse_access, - partial_sum_tasklet, - dst_conn="result_in", - memlet=dace.Memlet(f"{collapse_access.data}[0]")) - stream_state.add_memlet_path(buffer_read, - entry, - partial_sum_tasklet, - dst_conn=f"buffer_in", - memlet=dace.Memlet(f"{buffer_name}[_i_dot%{partial_width}]")) - stream_state.add_memlet_path(partial_sum_tasklet, - exit, - buffer_write, - src_conn=f"buffer_out", - memlet=dace.Memlet(f"{buffer_name}[_i_dot%{partial_width}]")) - - reduce_entry, reduce_exit = stream_state.add_map("reduce", {"_i_dot": f"0:{partial_width}"}, - schedule=dtypes.ScheduleType.FPGA_Device, - unroll=True) - - reduce_tasklet = stream_state.add_tasklet( - "reduce", {"reduce_in", "result_in"}, {"reduce_out"}, """\ -prev = reduce_in if _i_dot > 0 else 0 -reduce_out = prev + result_in""") - - stream_state.add_memlet_path(buffer_write, - reduce_entry, - reduce_tasklet, - dst_conn="result_in", - memlet=dace.Memlet(f"{buffer_name}[_i_dot]")) - - reduce_name = "reduce" - sdfg.add_array(reduce_name, (1, ), dtype, transient=True, storage=dtypes.StorageType.FPGA_Local) - reduce_read = stream_state.add_read(reduce_name) - reduce_access = stream_state.add_access(reduce_name) - - stream_state.add_memlet_path(reduce_read, - reduce_entry, - reduce_tasklet, - dst_conn="reduce_in", - memlet=dace.Memlet(f"{reduce_name}[0]")) - stream_state.add_memlet_path(reduce_tasklet, - reduce_exit, - reduce_access, - src_conn="reduce_out", - memlet=dace.Memlet(f"{reduce_name}[0]")) - - stream_state.add_memlet_path(reduce_access, - res_write, - memlet=dace.Memlet(f"{reduce_name}[0]", other_subset="0")) - - return sdfg - - -@dace.library.expansion -class ExpandDotFpgaAccumulate(ExpandTransformation): - """ - Version of DOT that assumes that native II=1 accumulation of the data type - is possible on the target architecture (e.g., 32-bit floating point on - Stratix 10). - """ - - environments = [] - - @staticmethod - def expansion(node, parent_state, parent_sdfg, n=None, **kwargs): - """ - :param node: The node to expand. - :param parent_state: The state that the node is in. - :param parent_sdfg: The SDFG that the node is in. - :param n: Override the vector dimension. If this is not set, the value - specified in the node is used. - """ - - (desc_x, stride_x), (desc_y, stride_y), desc_res, sz = node.validate(parent_sdfg, parent_state) - - n = n or node.n or sz - - sdfg = dace.SDFG("dot") - - state = sdfg.add_state("dot") - - dtype = desc_x.dtype.base_type - veclen = desc_x.veclen - vtype = dtypes.vector(dtype, veclen) - - desc_x = desc_x.clone() - desc_x.transient = False - desc_y = desc_y.clone() - desc_y.transient = False - desc_res = desc_res.clone() - desc_res.transient = False - sdfg.add_datadesc("_x", desc_x) - sdfg.add_datadesc("_y", desc_y) - sdfg.add_datadesc("_result", desc_res) - - x_read = state.add_read("_x") - y_read = state.add_read("_y") - res_write = state.add_write("_result") - - input_x_name = "input_x" - sdfg.add_array(input_x_name, (1, ), vtype, transient=True, storage=dtypes.StorageType.FPGA_Local) - input_x_access = state.add_access(input_x_name) - - input_y_name = "input_y" - sdfg.add_array(input_y_name, (1, ), vtype, transient=True, storage=dtypes.StorageType.FPGA_Local) - input_y_access = state.add_access(input_y_name) - - entry, exit = state.add_map("stream", {"_i_dot": f"0:{n}/{veclen}"}, schedule=dtypes.ScheduleType.FPGA_Device) - - index_x = "0" if isinstance(desc_x, dt.Stream) else "_i_dot" - index_y = "0" if isinstance(desc_y, dt.Stream) else "_i_dot" - - state.add_memlet_path(x_read, - entry, - input_x_access, - memlet=dace.Memlet(f"{x_read.data}[{index_x}]", other_subset="0", dynamic=False)) - state.add_memlet_path(y_read, - entry, - input_y_access, - memlet=dace.Memlet(f"{y_read.data}[{index_y}]", other_subset="0", dynamic=False)) - - tasklet = state.add_tasklet("multiply", {"__x", "__y"}, {f"_product": vtype}, f"_product = __x * __y") - - state.add_memlet_path(input_x_access, tasklet, dst_conn="__x", memlet=dace.Memlet(f"{input_x_name}[0]")) - state.add_memlet_path(input_y_access, tasklet, dst_conn="__y", memlet=dace.Memlet(f"{input_y_name}[0]")) - - product_name = "product" - sdfg.add_array(product_name, (veclen, ), dtype, transient=True, storage=dtypes.StorageType.FPGA_Local) - product_access = state.add_access(product_name) - - state.add_memlet_path(tasklet, - product_access, - src_conn="_product", - memlet=dace.Memlet(f"{product_name}[0:{veclen}]")) - - collapse_name = "reduce_vector" - sdfg.add_array(collapse_name, (1, ), dtype, transient=True, storage=dtypes.StorageType.FPGA_Local) - collapse_read = state.add_read(collapse_name) - collapse_access = state.add_access(collapse_name) - - unroll_entry, unroll_exit = state.add_map("unroll", {"_j_dot": f"0:{veclen}"}, - unroll=True, - schedule=dtypes.ScheduleType.FPGA_Device) - - collapse_tasklet = state.add_tasklet("reduce_vector", {"val_in", "reduce_in"}, {"reduce_out"}, """\ -prev = reduce_in if _j_dot > 0 else 0 -reduce_out = prev + val_in""") - - state.add_memlet_path(collapse_read, - unroll_entry, - collapse_tasklet, - dst_conn="reduce_in", - memlet=dace.Memlet(f"{collapse_name}[0]")) - state.add_memlet_path(entry, collapse_read, memlet=dace.Memlet()) - state.add_memlet_path(collapse_tasklet, - unroll_exit, - collapse_access, - src_conn="reduce_out", - memlet=dace.Memlet(f"{collapse_name}[0]")) - state.add_memlet_path(product_access, - unroll_entry, - collapse_tasklet, - dst_conn="val_in", - memlet=dace.Memlet(f"{product_name}[_j_dot]")) - - buffer_name = "reduce_buffer" - sdfg.add_array(buffer_name, (1, ), dtype, transient=True, storage=dtypes.StorageType.FPGA_Local) - buffer_read = state.add_read(buffer_name) - buffer_write = state.add_access(buffer_name) - - zero_tasklet = state.add_tasklet("zero", {}, {"buffer"}, "buffer = 0") - state.add_memlet_path(zero_tasklet, buffer_read, src_conn="buffer", memlet=dace.Memlet(f"{buffer_name}[0]")) - - reduce_tasklet = state.add_tasklet("sum", {"buffer_in", "result_in"}, {"buffer_out"}, """\ -prev = buffer_in if _i_dot > 0 else 0 -buffer_out = prev + result_in""") - - state.add_memlet_path(collapse_access, - reduce_tasklet, - dst_conn="result_in", - memlet=dace.Memlet(f"{collapse_access.data}[0]")) - state.add_memlet_path(buffer_read, - entry, - reduce_tasklet, - dst_conn="buffer_in", - memlet=dace.Memlet(f"{buffer_name}[0]")) - state.add_memlet_path(reduce_tasklet, - exit, - buffer_write, - src_conn=f"buffer_out", - memlet=dace.Memlet(f"{buffer_name}[0]")) - - state.add_memlet_path(buffer_write, res_write, memlet=dace.Memlet(f"{buffer_name}[0]", other_subset="0")) - - return sdfg - - @dace.library.node class Dot(dace.sdfg.nodes.LibraryNode): @@ -499,8 +160,6 @@ class Dot(dace.sdfg.nodes.LibraryNode): "OpenBLAS": ExpandDotOpenBLAS, "MKL": ExpandDotMKL, "cuBLAS": ExpandDotCuBLAS, - "FPGA_PartialSums": ExpandDotFpgaPartialSums, - "FPGA_Accumulate": ExpandDotFpgaAccumulate, } default_implementation = None diff --git a/dace/libraries/blas/nodes/gemm.py b/dace/libraries/blas/nodes/gemm.py index 003ab45bba..b13f462b41 100644 --- a/dace/libraries/blas/nodes/gemm.py +++ b/dace/libraries/blas/nodes/gemm.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. from copy import deepcopy as dc from dace import dtypes, memlet as mm, properties, data as dt from dace.symbolic import symstr, equal, equal_valued @@ -493,466 +493,6 @@ def _gemm_pblas(_a: dtype[M, K], _b: dtype[K, N], _c: dtype[M, N]): return _gemm_pblas.to_sdfg() -class ExpandGemmFPGA1DSystolic(ExpandTransformation): - """ - FPGA based implementation of GEMM, using a 1D systolic array. - - Currently it supports non-transposed input matrices, and non-vectorized input array A. - """ - - environments = [] - - @staticmethod - def expansion(node, parent_state, parent_sdfg, num_pes=32, tile_size_m=None): - """ - GEMM node expansion. - - :param node: Node to expand. - :param parent_state: State that the node is in. - :param parent_sdfg: SDFG that the node is in. - :param num_pes: Number of Processing Elements of the systolic array. By default it is set to 32. - - :param tile_size_m: tiling size considering columns of the input matrix B and resulting matrix C. - If B/C are vectorized, the tile size refers to the vectorized container. - If set to None, no tiling is used, corresponding to setting the tile size - equal to the number of columns of B/C. - :return: - """ - - ((edge_a, outer_array_a, shape_a, strides_a, _, _), (edge_b, outer_array_b, shape_b, strides_b, _, _), - (edge_c, outer_array_c, shape_c, strides_c, _, _)) = _get_matmul_operands(node, parent_state, parent_sdfg) - - dtype_a = outer_array_a.dtype.type - dtype_b = outer_array_b.dtype.type - dtype_c = dace.dtype_to_typeclass(np.result_type(dtype_a, dtype_b).type) - shape_c = (shape_a[0], shape_b[1]) - if node.transA: - raise NotImplementedError("GEMM FPGA expansion not implemented for transposed A.") - if node.transB: - raise NotImplementedError("GEMM FPGA expansion not implemented for transposed B.") - - if outer_array_a.veclen > 1: - raise NotImplementedError("Vectorization not support for input array A.") - - if len(shape_a) != 2 or len(shape_b) != 2 or shape_a[1] != shape_b[0]: - raise SyntaxError("Matrix sizes must match") - - if outer_array_b.dtype.veclen != outer_array_c.dtype.veclen: - raise SyntaxError("Vectorization lengths of B and C must match") - - ###################################################################### - # GEMM Parameters and checks - - # Note: the following sizes consider also vectorization - vec_width = outer_array_b.dtype.veclen - vec_type = dace.vector(dtype_c, vec_width) - N, K, M = shape_a[0], shape_a[1], shape_b[1] - - P = num_pes - T = tile_size_m - if T is None: - T = M - - # we will perform sanity check using T and M. But at this stage, we still - # don't know to what outer symbol they will map. - # We try to resolve them to constant if they are symbolic, otherwise we skip the checks - T_constant = dace.symbolic.resolve_symbol_to_constant(T, parent_sdfg) - K_constant = dace.symbolic.resolve_symbol_to_constant(K, parent_sdfg) - - # Safe delay: this will be used in the compute state, pipeline scope, to insert - # a delay between accumulation on the same result if needed. - # Further explanations are provided in the compute state. - - # Note: this is a platform and type dependent parameter. - if T_constant is not None: - L = max(16 - T_constant, 0) - else: - L = 0 - - # This implementation uses a flattened nested loop, that overlaps feeding, - # computing and draining phases. Each PE is responsible for computing one - # tile of one row of the final result C. With the current implementation, - # A PE needs K*T cycles to compute the results and then P*T clock cycles - # to fully drain them (draining is distributed across PEs). - # Therefore, in order to guarantee correctness and deadlock free we have - # to ensure that the number of cycles needed to drain the results is less - # or equal to the number of cycles needed to compute them. - # That is PT <= KT. - - if K_constant is not None and P > K_constant: - raise ValueError(f"GEMM-FPGA: Number of processing elements {P} must be smaller than the K-dimension {K}.") - - ###################################################################### - # Build the SDFG - - new_sdfg = dace.SDFG(node.label + "_sdfg") - new_state = new_sdfg.add_state("compute") - - # Add data descriptors - new_sdfg.add_array("_a", shape_a, dtype_a, strides=strides_a, storage=outer_array_a.storage) - new_sdfg.add_array("_b", shape_b, dtype_b, strides=strides_b, storage=outer_array_b.storage) - new_sdfg.add_array("_c", shape_c, dtype_c, strides=strides_c, storage=outer_array_c.storage) - - def make_read_A(state): - - # A given row of A must be repeated according to B number of tiles - # Both N and M can be not a multiple of P and T respectively - entry, exit = state.add_map("read_A", { - "n0": f"0:ceiling({N}/{P})", - "tm": f"0:ceiling({M}/{T})", - "k": f"0:{K}", - "n1": f"0:{P}" - }, - schedule=dace.ScheduleType.FPGA_Device) - - # The reader of A reads one element per clock cycle. - # Note that if P > T+L, then this will be the bottleneck - - mem = state.add_read("_a") - pipe = state.add_write("A_pipe") - - # Read data from memory: if we are out-of-bound do not read from memory - # but inject dummy data - tasklet = state.add_tasklet("read_A", {"from_memory"}, {"to_kernel"}, f"""\ -data = from_memory if n0 * {P} + n1 < {N} else 0 -to_kernel = data""") - - state.add_memlet_path(mem, - entry, - tasklet, - dst_conn="from_memory", - memlet=dace.Memlet(f"_a[n0 * {P} + n1, k]", dynamic=True, allow_oob=True)) - state.add_memlet_path(tasklet, - exit, - pipe, - src_conn="to_kernel", - memlet=dace.Memlet(f"A_pipe[{P} - n1 - 1]")) - - def make_read_B(state): - - # Also while reading B, we have to consider that T and P could not divide - # M and N - - entry, exit = state.add_map("read_B", { - "n": f"0:ceiling({N}/{P})", - "tm": f"0:ceiling({M}/{T})", - "k": f"0:{K}", - "m": f"0:{T}" - }, - schedule=dace.ScheduleType.FPGA_Device) - - # If we are out-of bound, use a dummy value - new_sdfg.add_array("B_dummy", - dtype=vec_type, - shape=[1], - transient=True, - storage=dace.dtypes.StorageType.FPGA_Registers) - b_dummy = state.add_access("B_dummy") - init_tasklet = state.add_tasklet("init_dummy_B", {}, {"init_data"}, "init_data = 0") - - state.add_memlet_path(init_tasklet, b_dummy, src_conn="init_data", memlet=dace.Memlet("B_dummy[0]")) - - mem = state.add_read("_b") - pipe = state.add_write("B_pipe") - tasklet = state.add_tasklet( - "read_B", {"from_memory", "dummy_data"}, {"to_kernel"}, f"""\ -data = from_memory if tm*{T} + m < {M} else dummy_data -to_kernel = data""") - - state.add_memlet_path(b_dummy, entry, tasklet, dst_conn="dummy_data", memlet=dace.Memlet("B_dummy[0]")) - - state.add_memlet_path(mem, - entry, - tasklet, - dst_conn="from_memory", - memlet=dace.Memlet(f"_b[k, tm*{T} + m]", dynamic=True, allow_oob=True)) - - state.add_memlet_path(tasklet, exit, pipe, src_conn="to_kernel", memlet=dace.Memlet("B_pipe[0]")) - - def make_write_C(state): - - # Receives the results and adds it to C - - pipe = state.add_read("C_pipe") - if not equal_valued(0, node.beta): - mem_read = state.add_read("_c") - mem = state.add_write("_c") - - entry_map, exit_map = state.add_map("write_C", { - "n0": f"0:ceiling({N}/{P})", - "tm": f"0:ceiling({M}/{T})", - "n1": f"0:{P}", - "m": f"0:{T}" - }, - schedule=dace.ScheduleType.FPGA_Device) - - # write in memory by adding C when we copy that to memory - - # deal with out-of-bound accesses - - mul_accumulated = f"{node.alpha} * from_kernel" if not equal_valued(1, node.alpha) else "from_kernel" - if not equal_valued(0, node.beta): - if not equal_valued(1, node.beta): - add_prev_c = f" + {node.beta} * prev_c" - else: - add_prev_c = " + prev_c" - else: - add_prev_c = "" - tasklet_inputs = {"from_kernel", "prev_c"} if not equal_valued(0, node.beta) else {"from_kernel"} - tasklet = state.add_tasklet( - "write_C", tasklet_inputs, {"to_memory"}, f"""\ -if tm * {T} + m < {M} and n0 * {P} + n1 < {N} : - to_memory = {mul_accumulated}{add_prev_c} -""") - state.add_memlet_path(pipe, - entry_map, - tasklet, - dst_conn="from_kernel", - memlet=dace.Memlet(f"C_pipe[{P}-1]")) - if not equal_valued(0, node.beta): - state.add_memlet_path(mem_read, - entry_map, - tasklet, - dst_conn="prev_c", - memlet=dace.Memlet(f"_c[n0 * {P} + n1, tm * {T} + m]", - dynamic=True, - allow_oob=True)) - - state.add_memlet_path(tasklet, - exit_map, - mem, - src_conn="to_memory", - memlet=dace.Memlet(f"_c[n0 * {P} + n1, tm * {T} + m]", dynamic=True, allow_oob=True)) - - def make_compute(sdfg, state): - - A_pipe_in = state.add_read("A_pipe") - B_pipe_in = state.add_read("B_pipe") - B_pipe_out = state.add_write("B_pipe") - C_pipe_in = state.add_read("C_pipe") - C_pipe_out = state.add_write("C_pipe") - - # The computation is expressed a single, flattened loop, which is generated by the following - # pipeline scope. Each PE accumulates over T partial results. The drain phase last P*T clock cycles. - # Draining and compute are overlapped. - # We are generating the loop by explicitly ignoring loop carried dependencies. Therefore, we have - # to guarantee that the PE will accumulate on the same partial result only when its value is consolidated. - # The + L is a safe delay between accumulation between the same partial result. - # It must be computed by considering T and the latency needed to consolidate a partial result - # (which is the latency of the add + latency for reading and writing to BRAM). - - entry_pipeline, exit_pipeline = state.add_pipeline("compute_and_drain", { - "n0": f"0:ceiling({N}/{P})", - "tm": f"0:ceiling({M}/{T})", - "k": f"0:{K}", - "m": f"0:{T} + {L}" - }, - drain_size=P * T, - drain_overlap=False, - additional_iterators={ - 'm_drain': 0, - 'k_drain': 0 - }, - schedule=dace.ScheduleType.FPGA_Device) - - # Instantiate buffers - sdfg.add_scalar("A_reg", dtype=dtype_a, transient=True, storage=dace.dtypes.StorageType.FPGA_Registers) - A_reg = state.add_write("A_reg") - A_reg_init = state.add_access("A_reg") - - # For C result we are going to use vectorized data type - - # Note: for some of the Sacred Mysteries of Intel OpenCL Compiler (TM), if this buffer is smaller - # than 24 floats, the II of the pipeline will be 5. Therefore we check this and in case we enlarge it - buffer_size = T if T_constant is None else max(T_constant, 24) - sdfg.add_array("C_buffer", [buffer_size], - dtype=vec_type, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Local) - C_buffer_in = state.add_read("C_buffer") - C_buffer_out = state.add_write("C_buffer") - - # Init data to reset partial results - new_sdfg.add_array("C_init", - dtype=vec_type, - shape=[1], - transient=True, - storage=dace.dtypes.StorageType.FPGA_Registers) - C_init = state.add_access("C_init") - C_init_tasklet = state.add_tasklet("C_data_init", {}, {"init_data"}, "init_data = 0") - - state.add_memlet_path(C_init_tasklet, C_init, src_conn="init_data", memlet=dace.Memlet("C_init[0]")) - state.add_memlet_path(entry_pipeline, C_init_tasklet, memlet=dace.Memlet()) - - # Feed A - # every PE: reads input data, buffer the data assigned to it - buffer_a_tasklet = state.add_tasklet( - "buffer_a", {"a_in"}, { - "a_reg", - }, f"""\ -if m == 0 and not {entry_pipeline.pipeline.drain_condition()}: - a_reg = a_in""") - - state.add_memlet_path(A_pipe_in, - entry_pipeline, - buffer_a_tasklet, - memlet=dace.Memlet("A_pipe[p]", dynamic=True), - dst_conn="a_in") - state.add_memlet_path(buffer_a_tasklet, - A_reg, - memlet=dace.Memlet("A_reg[0]", dynamic=True), - src_conn="a_reg") - - # Feed B - sdfg.add_array("B_reg", - shape=[1], - dtype=vec_type, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Local) - B_reg = state.add_access("B_reg") - buffer_b_tasklet = state.add_tasklet( - "buffer_b", {"b_in"}, {"b_reg_out"}, f"""\ -if m>={L} and not {entry_pipeline.pipeline.drain_condition()}: - b_reg_out = b_in""") - - state.add_memlet_path(B_pipe_in, - entry_pipeline, - buffer_b_tasklet, - memlet=dace.Memlet("B_pipe[p]", dynamic=True), - dst_conn="b_in") - state.add_memlet_path(buffer_b_tasklet, - B_reg, - memlet=dace.Memlet("B_reg[0]", dynamic=True), - src_conn="b_reg_out") - - # Compute, Forward B, and Drain - compute_tasklet = state.add_tasklet( - "compute_and_drain", {"a_in", "b_in", "c_in", "forward_in", "c_init_data"}, - {"b_out", "c_out", "c_pipe_out"}, f"""\ -result = c_in -if m >= {L} and not {entry_pipeline.pipeline.drain_condition()}: - c_prev = c_init_data if k == 0 else c_in - result = c_prev + a_in * b_in - c_out = result - if p < {P} - 1: - b_out = b_in -# Drain -# when we have to drain: -# - if we are working on second assigned row or second tile and we have something to drain -# - if k = K-1 and m>=L: each PE has just finished to compute something -# - if we are in the draining phase -# How: -# - if k = K-1 and m>=L: then the PE drains its own result -#- otherwise, if k_drain

0 or tm > 0) and k_drain

= {L}) or ({entry_pipeline.pipeline.drain_condition()} and k_drain < p): - c_pipe_out = result if (p==0 or (k_drain=={K}-1 and not {entry_pipeline.pipeline.drain_condition()})) else forward_in - -# adjust draining iterators -if not {entry_pipeline.pipeline.drain_condition()}: - if m_drain >= {L} + {T} -1: - m_drain = 0 - if k_drain >= {K} - 1: - k_drain = 0 - else: - k_drain = k_drain +1 - else: - m_drain = m_drain + 1 -else: - if m_drain >= {T} -1: - m_drain = 0 - if k_drain >= {K} - 1: - k_drain = 0 - else: - k_drain = k_drain +1 - else: - m_drain = m_drain + 1 - """) - - state.add_memlet_path(A_reg, compute_tasklet, dst_conn="a_in", memlet=dace.Memlet("A_reg[0]")) - state.add_memlet_path(B_reg, - compute_tasklet, - memlet=dace.Memlet("B_reg[0]", dynamic=False), - dst_conn="b_in") - state.add_memlet_path(C_init, compute_tasklet, memlet=dace.Memlet("C_init[0]"), dst_conn="c_init_data") - - state.add_memlet_path(compute_tasklet, - exit_pipeline, - B_pipe_out, - memlet=dace.Memlet("B_pipe[p + 1]", dynamic=True), - src_conn="b_out") - state.add_memlet_path(C_buffer_in, - entry_pipeline, - compute_tasklet, - dst_conn="c_in", - memlet=dace.Memlet(f"C_buffer[m-{L}]", allow_oob=True)) - - state.add_memlet_path(compute_tasklet, - exit_pipeline, - C_buffer_out, - memlet=dace.Memlet(f"C_buffer[m-{L}]", allow_oob=True, dynamic=True), - src_conn="c_out") - - state.add_memlet_path(C_pipe_in, - entry_pipeline, - compute_tasklet, - memlet=dace.Memlet("C_pipe[p-1]", dynamic=True), - dst_conn="forward_in") - state.add_memlet_path(compute_tasklet, - exit_pipeline, - C_pipe_out, - memlet=dace.Memlet("C_pipe[p]", dynamic=True), - src_conn="c_pipe_out") - - # Unroll processing elements - compute_entry, compute_exit = state.add_map("unroll_compute", {"p": "0:{}".format(P)}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - - # Bring data nodes into scope - state.add_memlet_path(compute_entry, A_pipe_in, memlet=dace.memlet.Memlet()) - state.add_memlet_path(compute_entry, B_pipe_in, memlet=dace.memlet.Memlet()) - state.add_memlet_path(compute_entry, C_pipe_in, memlet=dace.memlet.Memlet()) - - state.add_memlet_path(B_pipe_out, compute_exit, memlet=dace.memlet.Memlet()) - - state.add_memlet_path(C_pipe_out, compute_exit, memlet=dace.memlet.Memlet()) - - state.add_memlet_path(compute_entry, A_reg_init, memlet=dace.memlet.Memlet()) - state.add_memlet_path(A_reg_init, entry_pipeline, memlet=dace.memlet.Memlet()) - b_init = state.add_access("B_reg") - state.add_memlet_path(compute_entry, b_init, memlet=dace.Memlet()) - state.add_memlet_path(b_init, entry_pipeline, memlet=dace.Memlet()) - state.add_memlet_path(compute_entry, C_buffer_in, memlet=dace.Memlet()) - state.add_memlet_path(C_buffer_out, compute_exit, memlet=dace.Memlet()) - - # build the compute State - - new_sdfg.add_stream("A_pipe", - dtype_a, - transient=True, - shape=(P, ), - storage=dace.dtypes.StorageType.FPGA_Local, - buffer_size=str(P)) - new_sdfg.add_stream("B_pipe", - vec_type, - transient=True, - shape=(P + 1, ), - buffer_size=1, - storage=dace.dtypes.StorageType.FPGA_Local) - new_sdfg.add_stream("C_pipe", - vec_type, - transient=True, - shape=(P + 1, ), - buffer_size=T, - storage=dace.dtypes.StorageType.FPGA_Local) - - make_read_A(new_state) - make_read_B(new_state) - make_compute(new_sdfg, new_state) - make_write_C(new_state) - return new_sdfg - - @dace.library.node class Gemm(dace.sdfg.nodes.LibraryNode): """Executes alpha * (A @ B) + beta * C. C should be unidirectionally @@ -967,7 +507,6 @@ class Gemm(dace.sdfg.nodes.LibraryNode): "cuBLAS": ExpandGemmCuBLAS, "rocBLAS": ExpandGemmRocBLAS, "PBLAS": ExpandGemmPBLAS, - "FPGA1DSystolic": ExpandGemmFPGA1DSystolic } default_implementation = None diff --git a/dace/libraries/blas/nodes/gemv.py b/dace/libraries/blas/nodes/gemv.py index a791faba10..0cc0ff147f 100644 --- a/dace/libraries/blas/nodes/gemv.py +++ b/dace/libraries/blas/nodes/gemv.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. import copy from dace import properties, symbolic import dace.library @@ -112,509 +112,6 @@ def expansion(node, parent_state, parent_sdfg, **kwargs): return sdfg -@dace.library.expansion -class ExpandGemvFpgaAccumulate(ExpandTransformation): - """ - This FPGA-oriented expansion iterates over the input matrix A in simple - row-major order, with optional tiling in both dimensions, where the tiles - are also traversed in simple row-major order. This means that y is only - written once, but x is read for every tile in the y-dimension. - - The implementation requires accumulation on the output, and does NOT assume - native accumulation for the given data type. Instead it uses multiple - partial sums to ensure that II=1, and only writes the final accumulated - value once it has been combined from the partial sums. - - This works for both transposed and non-transposed A, but vectorization is - only implemented for non-transposed A. - """ - # The above corresponds to gemv_v1 in FBLAS - - environments = [] - - @staticmethod - def expansion(node, parent_state, parent_sdfg, tile_size_x=None, tile_size_y=None, num_partial_sums=16): - """ - :param node: Node to expand. - :param parent_state: State that the node is in. - :param parent_sdfg: SDFG that the node is in. - :param tile_size_x: Tile size along the dimension of the vector x. If - set to None, no tiling is used, corresponding to - setting the tile size equal to the full size of x. - :param tile_size_y: Tile size along the dimension of the vector y. If - set to None, no tiling is used, corresponding to - setting the tile size equal to the full size of y. - :param num_partial_sums: The number of distinct registers to accumulate - contributions to the final sum into. Should be - a power of two, and should be higher than the - latency of adding two numbers of the given - data type. - """ - - node.validate(parent_sdfg, parent_state) - - sdfg = dace.SDFG("gemv") - state = sdfg.add_state("gemv") - - alpha = node.alpha - beta = node.beta - - # Get input/output data (the method considers also the presence of view nodes) - ((edge_a, desc_a, _, _, shape_a, strides_a), (edge_x, desc_x, _, _, shape_x, strides_x), - (edge_y, desc_y, _, _, shape_y, strides_y)) = _get_matmul_operands(node, - parent_state, - parent_sdfg, - name_lhs="_A", - name_rhs="_x", - name_out="_y") - - # Create local versions of input/output data nodes - _, desc_a = sdfg.add_array("_A", - shape_a, - desc_a.dtype, - strides=strides_a, - storage=desc_a.storage, - transient=False) - _, desc_x = sdfg.add_array("_x", - shape_x, - desc_x.dtype, - strides=strides_x, - storage=desc_x.storage, - transient=False) - _, desc_y_y = sdfg.add_array("_y", - shape_y, - desc_y.dtype, - strides=strides_y, - storage=desc_y.storage, - transient=False) - - if node.transA and desc_a.dtype.veclen > 1: - raise NotImplementedError("Vectorization not implemented for transposed A.") - - # Create accesses - read_a = state.add_read("_A") - read_x = state.add_read("_x") - if beta != 0: - read_y = state.add_read("_y") - write_y = state.add_write("_y") - - size_x = desc_x.shape[0] - size_y = desc_y.shape[0] - if tile_size_x is None: - tile_size_x = size_x - if tile_size_y is None: - tile_size_y = size_y - num_tiles_y = f"{size_y}/{tile_size_y}" - num_tiles_x = f"{size_x}/{tile_size_x}" - - veclen = desc_a.dtype.veclen - - # Create tile map - y_tile_entry, y_tile_exit = state.add_map("y_tiles", {"ty": f"0:{num_tiles_y}"}, - schedule=dace.ScheduleType.FPGA_Device) - x_tile_entry, x_tile_exit = state.add_map("x_tiles", {"tx": f"0:{num_tiles_x}"}, - schedule=dace.ScheduleType.FPGA_Device) - - # Create y map - y_entry, y_exit = state.add_map("y", {"iy": f"0:{tile_size_y}"}, schedule=dace.ScheduleType.FPGA_Device) - - # Create x map - x_entry, x_exit = state.add_map("x", {"ix": f"0:{tile_size_x}"}, schedule=dace.ScheduleType.FPGA_Device) - - # Local buffer of x - sdfg.add_array("x_local", (tile_size_x, ), desc_x.dtype, storage=dace.StorageType.FPGA_Local, transient=True) - x_local_access = state.add_read("x_local") - - if beta != 0: - raise NotImplementedError("Not yet implemented.") - - multiply_tasklet = state.add_tasklet("multiply", {"A_in", "x_in"}, {f"product": desc_a.dtype}, - "product = A_in * x_in") - - if isinstance(desc_a, dt.Stream): - subset = "0" - elif node.transA: - subset = f"tx * {tile_size_x} + ix, ty * {tile_size_y} + iy" - else: - subset = f"ty * {tile_size_y} + iy, tx * {tile_size_x} + ix" - state.add_memlet_path(read_a, - y_tile_entry, - x_tile_entry, - y_entry, - x_entry, - multiply_tasklet, - dst_conn="A_in", - memlet=dace.Memlet(f"_A[{subset}]")) - read_x_entry, read_x_exit = state.add_map("read_x", {"ix": f"0:{tile_size_x}"}, - schedule=dace.ScheduleType.FPGA_Device) - subset = ("0" if isinstance(desc_x, dt.Stream) else f"tx*{tile_size_x} + ix") - read_x_tasklet = state.add_tasklet("read_x", {"x_memory"}, {"x_buffer"}, "x_buffer = x_memory") - state.add_memlet_path(read_x, - y_tile_entry, - x_tile_entry, - read_x_entry, - read_x_tasklet, - dst_conn="x_memory", - memlet=dace.Memlet(f"_x[{subset}]")) - state.add_memlet_path(read_x_tasklet, - read_x_exit, - x_local_access, - src_conn="x_buffer", - memlet=dace.Memlet(f"x_local[ix]")) - state.add_memlet_path(x_local_access, - y_entry, - x_entry, - multiply_tasklet, - dst_conn="x_in", - memlet=dace.Memlet(f"x_local[ix]")) - - # Write to buffer - sdfg.add_array("product_vector", (1, ), desc_a.dtype, transient=True, storage=dace.StorageType.FPGA_Local) - product_vector = state.add_access("product_vector") - state.add_memlet_path(multiply_tasklet, - product_vector, - src_conn="product", - memlet=dace.Memlet(f"product_vector[0]")) - - # Vector length conversion - sdfg.add_array("product_scalar", (veclen, ), - desc_a.dtype.base_type, - transient=True, - storage=dace.StorageType.FPGA_Local) - product_scalar = state.add_access("product_scalar") - state.add_memlet_path(product_vector, - product_scalar, - memlet=dace.Memlet(f"product_vector[0]", other_subset=f"0:{veclen}")) - - # Now we need to collapse this - reduce_vector_entry, reduce_vector_exit = state.add_map("reduce_vector", {"u": f"0:{veclen}"}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - - reduce_vector_tasklet = state.add_tasklet("reduce_vector", {"product_in", "acc_in"}, {"acc_out"}, - "acc_out = product_in + acc_in") - state.add_memlet_path(product_scalar, - reduce_vector_entry, - reduce_vector_tasklet, - dst_conn="product_in", - memlet=dace.Memlet(f"{product_scalar}[u]")) - - # Add accumulation register - sdfg.add_array("accumulate_product", (1, ), - desc_a.dtype.base_type, - transient=True, - storage=dace.StorageType.FPGA_Local) - accumulate_product_read = state.add_access("accumulate_product") - accumulate_product_write = state.add_access("accumulate_product") - - # Initialize it to zero - init_reduce_vector_tasklet = state.add_tasklet("init_reduce_vector", {}, {"acc_out"}, "acc_out = 0") - state.add_memlet_path(x_entry, init_reduce_vector_tasklet, memlet=dace.Memlet()) - state.add_memlet_path(init_reduce_vector_tasklet, - accumulate_product_read, - src_conn="acc_out", - memlet=dace.Memlet(f"accumulate_product[0]")) - - # Connect it to the tasklet - state.add_memlet_path(accumulate_product_read, - reduce_vector_entry, - reduce_vector_tasklet, - dst_conn="acc_in", - memlet=dace.Memlet(f"accumulate_product[0]")) - state.add_memlet_path(reduce_vector_tasklet, - reduce_vector_exit, - accumulate_product_write, - src_conn="acc_out", - memlet=dace.Memlet(f"accumulate_product[0]")) - - # Partial sums - sdfg.add_array("partial_sums", (num_partial_sums, ), - desc_y.dtype, - storage=dace.StorageType.FPGA_Registers, - transient=True) - partial_sum_read = state.add_read("partial_sums") - partial_sum_write = state.add_access("partial_sums") - - # Output array - sdfg.add_array("y_local", (tile_size_y, ), desc_y.dtype, storage=dace.StorageType.FPGA_Local, transient=True) - - # Now we need to actually accumulate into a local register of y - y_local_read = state.add_read("y_local") - y_local_write = state.add_read("y_local") - update_y_tasklet = state.add_tasklet( - "update_y", {"y_in", "acc_in"}, {"acc_out"}, f"""\ -prev = acc_in if ix >= {num_partial_sums} else 0 -acc_out = prev + y_in""") - state.add_memlet_path(accumulate_product_write, - update_y_tasklet, - dst_conn="y_in", - memlet=dace.Memlet(f"accumulate_product[0]")) - state.add_memlet_path(partial_sum_read, - x_entry, - update_y_tasklet, - dst_conn="acc_in", - memlet=dace.Memlet(f"partial_sums[ix%{num_partial_sums}]")) - state.add_memlet_path(y_tile_entry, y_local_read, memlet=dace.Memlet()) - state.add_memlet_path(y_entry, partial_sum_read, memlet=dace.Memlet()) - state.add_memlet_path(update_y_tasklet, - x_exit, - partial_sum_write, - src_conn="acc_out", - memlet=dace.Memlet(f"partial_sums[ix%{num_partial_sums}]")) - - # Reduce the partial sums - reduce_sums_entry, reduce_sums_exit = state.add_map("reduce_partial_sums", {"u": f"0:{num_partial_sums}"}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - reduce_sums_tasklet = state.add_tasklet("reduce_partial_sums", {"sum_in", "val_in"}, {"sum_out"}, """ -prev = sum_in if u > 0 else 0 -sum_out = prev + val_in""") - sdfg.add_array("accumulate_sum", (1, ), desc_y.dtype, transient=True, storage=dace.StorageType.FPGA_Local) - accumulate_sum_read = state.add_access("accumulate_sum") - accumulate_sum_write = state.add_access("accumulate_sum") - state.add_memlet_path(y_entry, accumulate_sum_read, memlet=dace.Memlet()) - state.add_memlet_path(accumulate_sum_read, - reduce_sums_entry, - reduce_sums_tasklet, - dst_conn="sum_in", - memlet=dace.Memlet("accumulate_sum[0]")) - state.add_memlet_path(reduce_sums_tasklet, - reduce_sums_exit, - accumulate_sum_write, - src_conn="sum_out", - memlet=dace.Memlet("accumulate_sum[0]")) - state.add_memlet_path(partial_sum_write, - reduce_sums_entry, - reduce_sums_tasklet, - dst_conn="val_in", - memlet=dace.Memlet("partial_sums[u]")) - - # Combine with y buffer - combine_tasklet = state.add_tasklet("combine_y", {"val", "buffer_in"}, {"buffer_out"}, """\ -prev = buffer_in if tx > 0 else 0 -buffer_out = prev + val""") - state.add_memlet_path(accumulate_sum_write, - combine_tasklet, - dst_conn="val", - memlet=dace.Memlet("accumulate_sum[0]")) - state.add_memlet_path(y_local_read, - x_tile_entry, - y_entry, - combine_tasklet, - dst_conn="buffer_in", - memlet=dace.Memlet("y_local[iy]")) - - state.add_memlet_path(combine_tasklet, - y_exit, - x_tile_exit, - y_local_write, - src_conn="buffer_out", - memlet=dace.Memlet(f"y_local[iy]")) - - subset = ("0" if isinstance(desc_y, dt.Stream) else f"ty*{tile_size_y} + iy") - write_y_entry, write_y_exit = state.add_map("write_y", {"iy": f"0:{tile_size_y}"}, - schedule=dace.ScheduleType.FPGA_Device) - write_y_tasklet = state.add_tasklet("write_y", {"y_buffer"}, {"y_memory"}, "y_memory = y_buffer") - state.add_memlet_path(y_local_write, - write_y_entry, - write_y_tasklet, - dst_conn="y_buffer", - memlet=dace.Memlet(f"y_local[iy]")) - state.add_memlet_path(write_y_tasklet, - write_y_exit, - y_tile_exit, - write_y, - src_conn="y_memory", - memlet=dace.Memlet(f"_y[{subset}]")) - - return sdfg - - -@dace.library.expansion -class ExpandGemvFpgaTilesByColumn(ExpandTransformation): - """ - FPGA-oriented expansion that reads the input matrix A in column-major - order, such that consecutive values are accumulated into different - registers, avoiding a loop-carried dependency due to accumulation. - - The matrix can optionally be tiled, where the tiles will be traversed in - row-major order in order to bound the size of the output buffer to the tile - size. The tile size on y must be larger than the latency of addition for - the given data type. - - This expansion supports both transposed A and non-transposed A, but - vectorization is only implemented for transposed A. - """ - # This corresponds to gemv_v2 in FBLAS - - environments = [] - - @staticmethod - def expansion(node, state, sdfg, tile_size_x=None, tile_size_y=None): - """ - :param node: Node to expand. - :param parent_state: State that the node is in. - :param parent_sdfg: SDFG that the node is in. - :param tile_size_x: Tile size along the dimension of the vector x. If - set to None, no tiling is used, corresponding to - setting the tile size equal to the full size of x. - :param tile_size_y: Tile size along the dimension of the vector y. If - set to None, no tiling is used, corresponding to - setting the tile size equal to the full size of y. - """ - - node.validate(sdfg, state) - - for e in state.in_edges(node): - if e.dst_conn == "_A": - desc_a = sdfg.arrays[e.data.data] - elif e.dst_conn == "_x": - desc_x = sdfg.arrays[e.data.data] - for e in state.out_edges(node): - if e.src_conn == "_y": - desc_y = sdfg.arrays[e.data.data] - - sdfg = dace.SDFG("gemv") - state = sdfg.add_state("gemv") - - alpha = node.alpha - beta = node.beta - - # Create local versions of input data nodes - desc_a = desc_a.clone() - desc_a.transient = False - sdfg.add_datadesc("_A", desc_a) - desc_x = desc_x.clone() - desc_x.transient = False - sdfg.add_datadesc("_x", desc_x) - desc_y = desc_y.clone() - desc_y.transient = False - sdfg.add_datadesc("_y", desc_y) - - if not node.transA and desc_a.dtype.veclen > 1: - raise NotImplementedError("Vectorization not implemented for non-transposed A.") - - # Create accesses - read_a = state.add_read("_A") - read_x = state.add_read("_x") - if beta != 0: - read_y = state.add_read("_y") - write_y = state.add_write("_y") - - size_x = desc_x.shape[0] - size_y = desc_y.shape[0] - if tile_size_x is None: - tile_size_x = size_x - if tile_size_y is None: - tile_size_y = size_y - num_tiles_y = f"{size_y}/{tile_size_y}" - num_tiles_x = f"{size_x}/{tile_size_x}" - - # Create y tile map - y_tile_entry, y_tile_exit = state.add_map("y_tiles", {"ty": f"0:{num_tiles_y}"}, - schedule=dace.ScheduleType.FPGA_Device) - - # Create buffer - sdfg.add_array("y_local", (tile_size_y, ), desc_y.dtype, storage=dace.StorageType.FPGA_Local, transient=True) - y_local = state.add_access("y_local") - y_local_write = state.add_access("y_local") - - # Initialize buffer - init_entry, init_exit = state.add_map("init", {"iy": f"0:{tile_size_y}"}, - schedule=dace.ScheduleType.FPGA_Device) - if beta != 0: - if isinstance(desc_y, dt.Stream): - subset = "0" - else: - subset = f"ty*{tile_size_y}+iy" - init_tasklet = state.add_tasklet("init", {"y_in"}, {"y_out"}, - f"y_out = {desc_y.dtype.base_type.ctype}({beta}) * y_in") - state.add_memlet_path(read_y, - y_tile_entry, - init_entry, - init_tasklet, - dst_conn="y_in", - memlet=dace.Memlet(f"_y[{subset}]")) - state.add_memlet_path(init_tasklet, - init_exit, - y_local, - src_conn="y_out", - memlet=dace.Memlet(f"y_local[iy]")) - else: - state.add_memlet_path(y_tile_entry, init_entry, memlet=dace.Memlet()) - init_tasklet = state.add_tasklet("init", {}, {"y_out"}, "y_out = 0") - state.add_memlet_path(init_entry, init_tasklet, memlet=dace.Memlet()) - state.add_memlet_path(init_tasklet, init_exit, y_local, src_conn="y_out", memlet=dace.Memlet("y_local[iy]")) - - # Create x tile map - x_tile_entry, x_tile_exit = state.add_map("x_tiles", {"tx": f"0:{num_tiles_x}"}, - schedule=dace.ScheduleType.FPGA_Device) - - # Create loop over tile size in x - x_entry, x_exit = state.add_map("x", {"ix": f"0:{tile_size_x}"}, schedule=dace.ScheduleType.FPGA_Device) - - # Buffer a scalar value of x - sdfg.add_array("x_local", (1, ), desc_x.dtype, transient=True, storage=dace.StorageType.FPGA_Local) - x_local = state.add_access("x_local") - subset = "0" if isinstance(desc_x, dt.Stream) else f"tx*{tile_size_x}+ix" - state.add_memlet_path(read_x, y_tile_entry, x_tile_entry, x_entry, x_local, memlet=dace.Memlet(f"_x[{subset}]")) - - # Create loop over tile size in y - y_entry, y_exit = state.add_map("y", {"iy": f"0:{tile_size_y}"}, schedule=dace.ScheduleType.FPGA_Device) - - # Do computation - tasklet = state.add_tasklet("gemv", {"A_in", "x_in", "y_in"}, {"y_out"}, - f"y_out = y_in + {alpha} * A_in * x_in") - state.add_memlet_path(y_local, - x_tile_entry, - x_entry, - y_entry, - tasklet, - dst_conn="y_in", - memlet=dace.Memlet("y_local[iy]")) - state.add_memlet_path(x_local, y_entry, tasklet, dst_conn="x_in", memlet=dace.Memlet("x_local[0]")) - state.add_memlet_path(tasklet, - y_exit, - x_exit, - x_tile_exit, - y_local_write, - src_conn="y_out", - memlet=dace.Memlet("y_local[iy]")) - if isinstance(desc_a, dt.Stream): - subset = "0" - elif node.transA: - subset = f"tx * {tile_size_x} + ix, ty * {tile_size_y} + iy" - else: - subset = f"ty * {tile_size_y} + iy, tx * {tile_size_x} + ix" - state.add_memlet_path(read_a, - y_tile_entry, - x_tile_entry, - x_entry, - y_entry, - tasklet, - dst_conn="A_in", - memlet=dace.Memlet(f"_A[{subset}]")) - - # Write out tile of y - write_y_entry, write_y_exit = state.add_map("write_y", {"iy": f"0:{tile_size_y}"}, - schedule=dace.ScheduleType.FPGA_Device) - write_y_tasklet = state.add_tasklet("write_y", {"y_in"}, {"y_out"}, "y_out = y_in") - subset = ("0" if isinstance(desc_y, dt.Stream) else f"ty * {tile_size_y} + iy") - state.add_memlet_path(y_local_write, - write_y_entry, - write_y_tasklet, - dst_conn="y_in", - memlet=dace.Memlet("y_local[iy]")) - state.add_memlet_path(write_y_tasklet, - write_y_exit, - y_tile_exit, - write_y, - src_conn="y_out", - memlet=dace.Memlet(f"_y[{subset}]")) - - return sdfg - - @dace.library.expansion class ExpandGemvCuBLAS(ExpandTransformation): @@ -877,8 +374,6 @@ class Gemv(dace.sdfg.nodes.LibraryNode): "OpenBLAS": ExpandGemvOpenBLAS, "MKL": ExpandGemvMKL, "cuBLAS": ExpandGemvCuBLAS, - "FPGA_Accumulate": ExpandGemvFpgaAccumulate, - "FPGA_TilesByColumn": ExpandGemvFpgaTilesByColumn, "PBLAS": ExpandGemvPBLAS } default_implementation = None diff --git a/dace/libraries/blas/nodes/ger.py b/dace/libraries/blas/nodes/ger.py index c22f8f7010..a91c5e3b10 100644 --- a/dace/libraries/blas/nodes/ger.py +++ b/dace/libraries/blas/nodes/ger.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. from dace.properties import SymbolicProperty from dace.transformation.transformation import ExpandTransformation from dace.frontend.common import op_repository as oprepo @@ -76,145 +76,6 @@ def expansion(node, parent_state, parent_sdfg, **kwargs): return nsdfg_node -@dace.library.expansion -class ExpandGerFpga(ExpandTransformation): - """ - FPGA-specific expansion of GER with support for vectorization and tiling - in both dimensions. - """ - - environments = [] - - @staticmethod - def expansion(node, state, sdfg, m=None, n=None, tile_size_x=None, tile_size_y=None): - """ - :param node: Node to expand. - :param state: State that the node is in. - :param sdfg: SDFG that the node is in. - :param m: Override the number of rows. - :param n: Override the number of columns. - :param tile_size_x: Tile size along the M-dimension (rows of A, size of - vector x). - :param tile_size_x: Tile size along the N-dimension (columns of A, - size of vector y). - """ - - desc_a_in, desc_x, desc_y = node.validate(sdfg, state) - desc_a_out = None - for e in state.out_edges(node): - if e.src_conn == "_res": - desc_a_out = sdfg.arrays[e.data.data] - - sdfg = dace.SDFG("ger") - state = sdfg.add_state("ger") - - desc_a_in = desc_a_in.clone() - desc_x = desc_x.clone() - desc_y = desc_y.clone() - desc_a_out = desc_a_out.clone() - desc_a_in.transient = False - desc_a_out.transient = False - desc_x.transient = False - desc_y.transient = False - sdfg.add_datadesc("_A", desc_a_in) - sdfg.add_datadesc("_res", desc_a_out) - sdfg.add_datadesc("_x", desc_x) - sdfg.add_datadesc("_y", desc_y) - - m = m or node.m - n = n or node.n - alpha = node.alpha - veclen = desc_y.dtype.veclen - - size_x = m - size_y = n / veclen - - num_tiles_x = f"{size_x} / {tile_size_x}" - num_tiles_y = f"{size_y} / {tile_size_y}" - - y_tile_entry, y_tile_exit = state.add_map("y_tiles", {"ty": f"0:{num_tiles_y}"}, - schedule=dace.ScheduleType.FPGA_Device) - - sdfg.add_array("y_local", (tile_size_y, ), desc_y.dtype, transient=True, storage=dace.StorageType.FPGA_Local) - y_local = state.add_access("y_local") - - # Load y buffer - read_y = state.add_read("_y") - subset = ("0" if isinstance(desc_y, dace.data.Stream) else f"ty*{tile_size_y}+iy") - read_y_entry, read_y_exit = state.add_map("read_y", {"iy": f"0:{tile_size_y}"}, - schedule=dace.ScheduleType.FPGA_Device) - read_y_tasklet = state.add_tasklet("read_y", {"y_memory"}, {"y_buffer"}, "y_buffer = y_memory") - state.add_memlet_path(read_y, - y_tile_entry, - read_y_entry, - read_y_tasklet, - dst_conn="y_memory", - memlet=dace.Memlet(f"_y[{subset}]")) - state.add_memlet_path(read_y_tasklet, - read_y_exit, - y_local, - src_conn="y_buffer", - memlet=dace.Memlet(f"y_local[iy]")) - - x_tile_entry, x_tile_exit = state.add_map("x_tiles", {"tx": f"0:{num_tiles_x}"}, - schedule=dace.ScheduleType.FPGA_Device) - - x_entry, x_exit = state.add_map("x", {"ix": f"0:{tile_size_x}"}, schedule=dace.ScheduleType.FPGA_Device) - - # Load x - read_x = state.add_read("_x") - sdfg.add_array("x_local", (1, ), desc_x.dtype, transient=True, storage=dace.StorageType.FPGA_Local) - x_local = state.add_access("x_local") - subset = ("0" if isinstance(desc_x, dace.data.Stream) else f"tx*{tile_size_x} + ix") - state.add_memlet_path(read_x, - y_tile_entry, - x_tile_entry, - x_entry, - x_local, - memlet=dace.Memlet(f"_x[{subset}]", other_subset="0")) - - y_entry, y_exit = state.add_map("y", {"iy": f"0:{tile_size_y}"}, schedule=dace.ScheduleType.FPGA_Device) - - # Actual computation - compute_tasklet = state.add_tasklet("ger", {"a_in", "x_in", "y_in"}, {"a_out"}, - f"a_out = {alpha} * x_in * y_in + a_in") - - # Stream in A - read_a = state.add_read("_A") - subset_a = ("0" if isinstance(desc_a_in, dace.data.Stream) else f"tx*{tile_size_x} + ix, ty*{tile_size_y} + iy") - state.add_memlet_path(read_a, - y_tile_entry, - x_tile_entry, - x_entry, - y_entry, - compute_tasklet, - dst_conn="a_in", - memlet=dace.Memlet(f"_A[{subset_a}]")) - - # Load buffered x and y - state.add_memlet_path(x_local, y_entry, compute_tasklet, dst_conn="x_in", memlet=dace.Memlet("x_local[0]")) - state.add_memlet_path(y_local, - x_tile_entry, - x_entry, - y_entry, - compute_tasklet, - dst_conn="y_in", - memlet=dace.Memlet(f"y_local[iy]")) - - # Store result - write_a = state.add_write("_res") - state.add_memlet_path(compute_tasklet, - y_exit, - x_exit, - x_tile_exit, - y_tile_exit, - write_a, - src_conn="a_out", - memlet=dace.Memlet(f"_res[{subset_a}]")) - - return sdfg - - @library.node class Ger(LibraryNode): """ @@ -226,7 +87,7 @@ class Ger(LibraryNode): """ # Global properties - implementations = {"pure": ExpandGerPure, "FPGA": ExpandGerFpga} + implementations = {"pure": ExpandGerPure} default_implementation = None # Object fields diff --git a/dace/libraries/linalg/nodes/tensordot.py b/dace/libraries/linalg/nodes/tensordot.py index e2e6e54e46..03b89a5a2c 100644 --- a/dace/libraries/linalg/nodes/tensordot.py +++ b/dace/libraries/linalg/nodes/tensordot.py @@ -4,7 +4,7 @@ import dace.libraries.linalg.environments as environments from dace import library, nodes, properties -from dace.data import _prod +from dace.utils import prod as _prod from dace.libraries.blas import blas_helpers from dace.symbolic import symstr from dace.transformation.transformation import ExpandTransformation diff --git a/dace/libraries/mpi/nodes/gather.py b/dace/libraries/mpi/nodes/gather.py index 8ad2b0df8b..b231ff7cee 100644 --- a/dace/libraries/mpi/nodes/gather.py +++ b/dace/libraries/mpi/nodes/gather.py @@ -1,6 +1,6 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. from dace import dtypes, library, properties -from dace.data import _prod +from dace.utils import prod as _prod from dace.libraries.mpi import utils from dace.sdfg import nodes from dace.symbolic import symstr diff --git a/dace/libraries/mpi/nodes/redistribute.py b/dace/libraries/mpi/nodes/redistribute.py index e58ed544b7..19dddc0f8b 100644 --- a/dace/libraries/mpi/nodes/redistribute.py +++ b/dace/libraries/mpi/nodes/redistribute.py @@ -1,6 +1,6 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. from dace import dtypes, library, properties, subsets, symbolic -from dace.data import _prod +from dace.utils import prod as _prod from dace.libraries.mpi import utils from dace.sdfg import nodes from dace.transformation.transformation import ExpandTransformation diff --git a/dace/libraries/mpi/nodes/scatter.py b/dace/libraries/mpi/nodes/scatter.py index 04367cbbfb..59cc54f3a7 100644 --- a/dace/libraries/mpi/nodes/scatter.py +++ b/dace/libraries/mpi/nodes/scatter.py @@ -1,6 +1,6 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. from dace import dtypes, library, properties -from dace.data import _prod +from dace.utils import prod as _prod from dace.libraries.mpi import utils from dace.sdfg import nodes from dace.symbolic import symstr diff --git a/dace/libraries/onnx/__init__.py b/dace/libraries/onnx/__init__.py new file mode 100644 index 0000000000..449c7913e8 --- /dev/null +++ b/dace/libraries/onnx/__init__.py @@ -0,0 +1,61 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +DaCe ONNX Integration Library. + +This module provides comprehensive support for importing and executing ONNX models +in DaCe. It enables: + +- Importing ONNX models and converting them to DaCe SDFGs +- Implementing ONNX operations as DaCe library nodes +- Automatic shape inference for dynamic models +- Multiple implementation strategies (pure, optimized, etc.) + +Main Components: +- ONNXModel: Main class for importing and manipulating ONNX models +- ONNXOp: Base class for ONNX operation nodes in SDFGs +- Schema system: Type checking and validation for ONNX operations + +The library is registered with DaCe and uses 'pure' as the default implementation +strategy for ONNX operations. +""" + +from dace.library import register_library, _DACE_REGISTERED_LIBRARIES + +try: + # Import schema and node utilities (nodes are lazy-loaded via __getattr__) + from .schema import onnx_representation, ONNXAttributeType, ONNXAttribute, ONNXTypeConstraint, ONNXParameterType, ONNXSchema, ONNXParameter + from .nodes import get_onnx_node, has_onnx_node + + register_library(__name__, "dace.libraries.onnx") + _DACE_REGISTERED_LIBRARIES["dace.libraries.onnx"].default_implementation = "pure" + + ONNX_AVAILABLE = True + + def __getattr__(name): + """Lazy attribute access for ONNX node classes, ONNXModel, and utilities.""" + if name == 'ONNXModel': + from dace.frontend.ml.onnx import ONNXModel as _ONNXModel + return _ONNXModel + if name == 'parse_variadic_param': + from .nodes.node_utils import parse_variadic_param as _parse_variadic_param + return _parse_variadic_param + if name.startswith('ONNX'): + # Initialize registry and get the node class + from .nodes.onnx_op_registry import _initialize_onnx_registry + _initialize_onnx_registry() + from .nodes import onnx_op_registry + if hasattr(onnx_op_registry, name): + return getattr(onnx_op_registry, name) + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + +except ImportError: + # ONNX library not available + ONNXModel = None + onnx_representation = None + ONNXAttributeType = None + ONNXAttribute = None + ONNXTypeConstraint = None + ONNXParameterType = None + ONNXSchema = None + ONNXParameter = None + ONNX_AVAILABLE = False diff --git a/dace/libraries/onnx/converters.py b/dace/libraries/onnx/converters.py new file mode 100644 index 0000000000..5ba326934a --- /dev/null +++ b/dace/libraries/onnx/converters.py @@ -0,0 +1,247 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Type conversion utilities for ONNX-DaCe integration. + +This module provides conversion functions between ONNX and DaCe type systems: +- Converting ONNX protobuf types to DaCe types +- Converting DaCe types to ONNX representation +- Handling ONNX AttributeProto conversions +- Type validation and name sanitization + +Key Functions: +- convert_onnx_proto: Convert ONNX protobuf objects to Python/DaCe types +- onnx_tensor_type_to_typeclass: Convert ONNX tensor types to DaCe typeclasses +- typeclass_to_onnx_str: Convert DaCe types to ONNX string representation +- clean_onnx_name: Sanitize ONNX names for valid DaCe identifiers +""" + +import re +from typing import Union + +import onnx +from dace import config, dtypes as dt +from dace.dtypes import typeclass +from onnx.numpy_helper import to_array + + +def get_proto_attr(proto, name: str): + """Safely access a protobuf attribute with encoding validation. + + This function provides defensive checks against encoding issues when accessing + protobuf attributes. Python's getattr expects strings, but protobuf uses UTF-8. + + :param proto: The protobuf object to access. + :param name: The attribute name to retrieve (must be ASCII). + :return: The value of the requested attribute. + :raises ValueError: If the name is not ASCII-encodable. + + .. note:: + HasField checks may break in proto3, but ONNX doesn't use proto3 yet. + """ + + def is_ascii(s: str) -> bool: + """Check if a string is ASCII-encodable.""" + try: + s.encode('ascii') + except UnicodeEncodeError: + return False + else: + return True + + if not is_ascii(name): + raise ValueError( + f"Attempted to access non-ASCII property name '{name}' on protobuf {proto} (type {type(proto)}). " + "Please open an issue") + + return getattr(proto, name) + + +def convert_onnx_proto(attribute): + from dace.libraries.onnx.schema import ONNXAttributeType, _KNOWN_ONNX_PROTOS, ONNXParameterType + + if type(attribute) in _KNOWN_ONNX_PROTOS: + return _KNOWN_ONNX_PROTOS[type(attribute)].from_onnx_proto(attribute) + + # Check ONNX enum types BEFORE basic types, because ONNX enums derive from + # IntEnum and would incorrectly match isinstance(attribute, int) + if type(attribute) is onnx.defs.OpSchema.FormalParameterOption: + if attribute == onnx.defs.OpSchema.FormalParameterOption.Single: + return ONNXParameterType.Single + elif attribute == onnx.defs.OpSchema.FormalParameterOption.Optional: + return ONNXParameterType.Optional + elif attribute == onnx.defs.OpSchema.FormalParameterOption.Variadic: + return ONNXParameterType.Variadic + else: + raise NotImplementedError( + "Only single, optional and variadic formal parameters are supported, got".format(attribute)) + + if type(attribute) is onnx.defs.OpSchema.AttrType: + if attribute == onnx.defs.OpSchema.AttrType.FLOAT: + return ONNXAttributeType.Float + elif attribute == onnx.defs.OpSchema.AttrType.FLOATS: + return ONNXAttributeType.Floats + elif attribute == onnx.defs.OpSchema.AttrType.INT: + return ONNXAttributeType.Int + elif attribute == onnx.defs.OpSchema.AttrType.INTS: + return ONNXAttributeType.Ints + elif attribute == onnx.defs.OpSchema.AttrType.STRING: + return ONNXAttributeType.String + elif attribute == onnx.defs.OpSchema.AttrType.STRINGS: + return ONNXAttributeType.Strings + elif attribute == onnx.defs.OpSchema.AttrType.TENSOR: + return ONNXAttributeType.Tensor + else: + if config.Config.get_bool('debugprint'): + print("Got unsupported attribute type {}".format(attribute)) + return ONNXAttributeType.Unsupported + + if type(attribute) is onnx.AttributeProto: + return convert_attribute_proto(attribute) + + # Check basic Python types after ONNX enums (must be after enum checks) + if isinstance(attribute, (int, str, bool, float)): + return attribute + + raise NotImplementedError("No conversion implemented for {} (type {})".format(attribute, type(attribute))) + + +def convert_attribute_proto(proto): + # we cache the reverse map as an attribute of the method + if hasattr(convert_attribute_proto, "inv_map"): + inv_map = convert_attribute_proto.inv_map + else: + inv_map = {} + for k, v in onnx.AttributeProto.AttributeType.items(): + if k == "FLOAT": + inv_map[v] = lambda attr: get_proto_attr(attr, "f") + elif k == "FLOATS": + inv_map[v] = lambda attr: list(get_proto_attr(attr, "floats")) + elif k == "INT": + inv_map[v] = lambda attr: get_proto_attr(attr, "i") + elif k == "INTS": + inv_map[v] = lambda attr: list(get_proto_attr(attr, "ints")) + elif k == "STRING": + inv_map[v] = lambda attr: get_proto_attr(attr, "s").decode('utf-8') + elif k == "STRINGS": + inv_map[v] = lambda attr: list(map(lambda x: x.decode('utf-8'), get_proto_attr(attr, "strings"))) + elif k == "TENSOR": + inv_map[v] = lambda attr: to_array(get_proto_attr(attr, "t")) + + convert_attribute_proto.inv_map = inv_map + + onnx_type = get_proto_attr(proto, "type") + + if onnx_type == 0: + # in case of undefined return None + return None + + if onnx_type not in inv_map: + type_str = {v: k for k, v in onnx.AttributeProto.AttributeType.items()}[onnx_type] + raise NotImplementedError( + "Only FLOAT, FLOATS, INT, INTS, STRING, STRINGS and TENSOR attributes are supported, got attribute with type {}" + .format(type_str)) + + return inv_map[onnx_type](proto) + + +ONNX_DTYPES_TO_DACE_TYPE_CLASS = { + 'bool': dt.bool, + 'int8': dt.int8, + 'int16': dt.int16, + 'int32': dt.int32, + 'int64': dt.int64, + 'uint8': dt.uint8, + 'uint16': dt.uint16, + 'uint32': dt.uint32, + 'uint64': dt.uint64, + 'float16': dt.float16, + 'float': dt.float32, + 'double': dt.float64, + 'complex64': dt.complex64, + 'complex128': dt.complex128, +} + + +def typeclass_to_onnx_tensor_type_int(dtype: typeclass) -> int: + # we cache the reverse map as an attribute of the method + if not hasattr(typeclass_to_onnx_tensor_type_int, "inv_map"): + typeclass_to_onnx_tensor_type_int.inv_map = { + v: getattr(onnx.TensorProto.DataType, k.upper()) + for k, v in ONNX_DTYPES_TO_DACE_TYPE_CLASS.items() + } + + return typeclass_to_onnx_tensor_type_int.inv_map[dtype] + + +def onnx_tensor_type_to_typeclass(elem_type: int) -> typeclass: + # we cache the reverse map as an attribute of the method + if hasattr(onnx_tensor_type_to_typeclass, "inv_map"): + inv_map = onnx_tensor_type_to_typeclass.inv_map + else: + k: str + v: int + inv_map = {} + for k, v in onnx.TensorProto.DataType.items(): + if k.lower() in ONNX_DTYPES_TO_DACE_TYPE_CLASS: + inv_map[v] = ONNX_DTYPES_TO_DACE_TYPE_CLASS[k.lower()] + + onnx_tensor_type_to_typeclass.inv_map = inv_map + + if elem_type not in inv_map: + raise ValueError("Got unsupported ONNX tensor type: {}".format({ + v: k + for k, v in onnx.TensorProto.DataType.items() + }[elem_type])) + + return inv_map[elem_type] + + +def typeclass_to_onnx_str(dtype: typeclass) -> str: + # we cache the reverse map as an attribute of the method + if hasattr(typeclass_to_onnx_str, "inv_map"): + inv_map = typeclass_to_onnx_str.inv_map + else: + inv_map = {v: k for k, v in ONNX_DTYPES_TO_DACE_TYPE_CLASS.items()} + + if dtype not in inv_map: + raise ValueError("Attempted to convert unsupported dace type to ONNX type: {}".format(dtype)) + + return inv_map[dtype] + + +def onnx_type_str_to_typeclass(onnx_str) -> Union[typeclass, None]: + """Converts an onnx type string, like tensor(float16) to a dace typeclass""" + + results = re.findall(r"^tensor\((.+)\)", onnx_str) + if len(results) != 1 or results[0] not in ONNX_DTYPES_TO_DACE_TYPE_CLASS: + # we return None here, these types will be filtered out later + return None + + return ONNX_DTYPES_TO_DACE_TYPE_CLASS[str(results[0])] + + +def clean_onnx_name(name: str) -> str: + """Sanitize an ONNX name to make it a valid DaCe identifier. + + This function transforms ONNX names that may contain invalid characters + or patterns into valid DaCe identifiers by: + + - Prefixing names starting with digits with "ONNX_" + - Replacing special characters with their textual equivalents + + :param name: The ONNX name to sanitize. + :return: A valid DaCe identifier based on the ONNX name. + + Example:: + + >>> clean_onnx_name("123_layer") + 'ONNX_123_layer' + >>> clean_onnx_name("my.tensor:0") + 'myDOTtensorCOLON0' + """ + # If the first character is a digit, add the ONNX_ prefix + if re.match("^[0-9]", name): + name = f"ONNX_{name}" + + # Replace special characters with their textual equivalents + return (name.replace(".", "DOT").replace(":", "COLON").replace("/", "SLASH").replace("-", "DASH")) diff --git a/dace/libraries/onnx/forward_implementation_abc.py b/dace/libraries/onnx/forward_implementation_abc.py new file mode 100644 index 0000000000..b43d7fc4b3 --- /dev/null +++ b/dace/libraries/onnx/forward_implementation_abc.py @@ -0,0 +1,105 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Abstract Base Class for ONNX Operation Forward Implementations. + +This module defines the interface that all ONNX operation implementations must +follow in DaCe. It uses a registry pattern to allow multiple implementations +for each ONNX operation, enabling: + +- Pure Python implementations for correctness +- Optimized implementations for performance +- Hardware-specific implementations +- Custom user-provided implementations + +The ONNXForward ABC provides: +- Registration mechanism via @make_registry decorator +- Implementation selection based on applicability +- Expansion of ONNX ops to DaCe SDFG nodes + +Implementation Registration: + Implementations register themselves by inheriting from ONNXForward and + using the @op_implementation decorator with: + - `op`: ONNX operation name (e.g., "Conv", "MatMul") + - `name`: Implementation name (e.g., "pure", "optimized") + +Example: + @op_implementation(op="MatMul", name="pure") + class PureMatMul(ONNXForward): + @staticmethod + def forward(node, state, sdfg): + # Implementation here + pass +""" + +import abc +import typing + +from dace import SDFG, SDFGState +from dace.registry import make_registry +from dace.sdfg.nodes import Node + +from dace.libraries.onnx.nodes.onnx_op import ONNXOp + + +@make_registry +class ONNXForward(abc.ABC): + """ + Abstract base class for ONNX operation forward implementations. + + This class defines the interface for implementing ONNX operations in DaCe. + Subclasses must implement the `forward` method to expand an ONNX operation + node into DaCe SDFG constructs. + + The registry system allows multiple implementations per operation, with + selection based on applicability criteria. + """ + + @staticmethod + def forward_can_be_applied(node: ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + """Check whether this implementation can be applied to the given node. + + This method is called during SDFG expansion to determine if this + implementation is suitable for the given context. The default + implementation returns True (always applicable). + + :param node: The ONNX operation node to expand. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: True if this implementation can be applied, False otherwise. + """ + return True + + @staticmethod + @abc.abstractmethod + def forward(node: ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + """Expand an ONNX operation node into DaCe SDFG constructs. + + This is the main method that must be implemented by subclasses. It takes + an ONNX operation node and replaces it with equivalent DaCe constructs + (tasklets, nested SDFGs, library nodes, etc.). + + :param node: The ONNX operation node to expand. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: The expanded node or a nested SDFG representing the operation. + """ + ... + + @classmethod + def registered_implementations(cls, op_name: str) -> typing.List[typing.Tuple[str, "ONNXForward"]]: + """Get all registered implementations for a specific ONNX operation. + + :param op_name: The ONNX operation name (e.g., "Conv", "MatMul"). + :return: List of tuples (implementation_name, implementation_class) for + all registered implementations of the given operation. + """ + impls = [] + for impl, args in cls.extensions().items(): + if "op" in args and args["op"] == op_name: + impls.append((args["name"], impl)) + + return impls + + +# Import op_implementations to trigger registration of all implementations +import dace.libraries.onnx.op_implementations diff --git a/dace/libraries/onnx/nodes/__init__.py b/dace/libraries/onnx/nodes/__init__.py new file mode 100644 index 0000000000..0ed1814fe5 --- /dev/null +++ b/dace/libraries/onnx/nodes/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from .onnx_op_registry import get_onnx_node, has_onnx_node diff --git a/dace/libraries/onnx/nodes/node_utils.py b/dace/libraries/onnx/nodes/node_utils.py new file mode 100644 index 0000000000..d377660f65 --- /dev/null +++ b/dace/libraries/onnx/nodes/node_utils.py @@ -0,0 +1,90 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Utility functions for ONNX node operations. + +This module provides helper functions for working with ONNX operation nodes +in DaCe SDFGs, including: + +- Parsing variadic parameter names +- Validating parameter formats +- Schema utilities for ONNX operations + +These utilities support the ONNX node system by handling the complexities +of variadic inputs/outputs and parameter naming conventions. +""" + +from typing import Tuple + +from dace.libraries.onnx.schema import ONNXParameterType, ONNXSchema + + +def parse_variadic_param(param: str) -> Tuple[str, int]: + """Parse a variadic parameter name into its base name and index. + + ONNX operations can have variadic inputs/outputs, which are named using + the convention 'base_name__index' (e.g., 'input__0', 'input__1'). + This function extracts the base name and numeric index. + + :param param: The variadic parameter name in format 'name__number'. + :return: A tuple of (base_name, index) where base_name is the parameter name + and index is the variadic position (zero-indexed). + :raises ValueError: If the parameter format is invalid, has leading zeros + in the number, or the number is negative. + + Example:: + + >>> parse_variadic_param("input__0") + ('input', 0) + >>> parse_variadic_param("output__5") + ('output', 5) + >>> parse_variadic_param("input__01") # raises ValueError + """ + split = param.split('__') + if len(split) != 2: + raise ValueError("Unable to parse variadic parameter '{}'".format(param)) + name = split[0] + number = split[1] + + if number[0] == '0' and len(number) > 1: + raise ValueError("Variadic parameters must not be numbered with leading zeros, got: '{}'".format(number)) + + number = int(number) + if number < 0: + raise ValueError("Variadic parameter numberings must be greater than zero, got: '{}'".format(number)) + return name, number + + +def get_position(schema: ONNXSchema, is_input: bool, parameter_name: str): + """Get the position that the parameter has in the ONNX op. + + :param schema: The ONNX schema containing parameter definitions. + :param is_input: True if looking for input parameters, False for output parameters. + :param parameter_name: The name of the parameter to find position for. + :return: The position index of the parameter in the operation signature. + :raises ValueError: If parameter is not found, has incorrect variadic format, + or schema validation fails. + """ + if "__" in parameter_name: + parameter_name, variadic_number = parse_variadic_param(parameter_name) + else: + variadic_number = None + + matches = [(i, param) for i, param in enumerate(schema.inputs if is_input else schema.outputs) + if param.name == parameter_name] + if len(matches) != 1: + raise ValueError("Error in schema: found more or less than one parameter with name {}".format(parameter_name)) + + index, param = matches[0] + + if variadic_number is not None and param.param_type != ONNXParameterType.Variadic: + raise ValueError("Got variadic index for non-variadic parameter {}".format(parameter_name)) + + if variadic_number is None and param.param_type == ONNXParameterType.Variadic: + raise ValueError("Did not get variadic index for variadic parameter {}. " + "Specify a variadic index by renaming the parameter to {}__i, where i is a number".format( + parameter_name, parameter_name)) + + if variadic_number is not None: + return variadic_number + index + else: + return index diff --git a/dace/libraries/onnx/nodes/onnx_op.py b/dace/libraries/onnx/nodes/onnx_op.py new file mode 100644 index 0000000000..e448bf8443 --- /dev/null +++ b/dace/libraries/onnx/nodes/onnx_op.py @@ -0,0 +1,295 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import itertools +from typing import Iterator, Tuple, List + +import dace.sdfg.nodes as nd +from dace.sdfg import SDFG, SDFGState +from dace.properties import Property, make_properties +from dace.sdfg.graph import MultiConnectorEdge + +from dace.libraries.onnx.nodes.node_utils import parse_variadic_param +from dace.libraries.onnx.schema import ONNXSchema, ONNXParameterType + + +def get_missing_arguments_message(function_name, missing_arguments, argument_type): + names = list(map(lambda x: "'" + x + "'", missing_arguments)) + + if len(missing_arguments) == 1: + arglist = names[0] + else: + arglist = ", ".join(names[:-1]) + ", and " + names[-1] + + return "{function_name} missing {num_missing} required {argument_type}{s}: {arglist}".format( + function_name=function_name, + num_missing=len(missing_arguments), + argument_type=argument_type, + s='' if len(missing_arguments) == 1 else 's', + arglist=arglist) + + +@make_properties +class ONNXOp(nd.LibraryNode): + """ Abstract superclass for all ONNX ops. Do not use this class, use the concrete subclasses + (e.g. :class:`~dace.libraries.onnx.nodes.onnx_op.ONNXConv`) instead. + """ + + # Global properties + # these two are filled out in the generated constructor + implementations = {} + default_implementation = None + default_backward_implementation = None + + # Object fields + schema = Property(dtype=ONNXSchema, desc="The operator's ONNX OpSchema", allow_none=True) + + backward_implementation = Property( + dtype=str, + allow_none=True, + desc="Which implementation this library node will expand into in the backward pass.") + + def iter_outputs_in_onnx_order(self, state: SDFGState) -> List[MultiConnectorEdge]: + """ Iterate through the input edges in the same order as they would appear in an ONNX node proto. + This assumes that the node has been validated! + + :param state: the state containing this node. + :return: the out edges in the order as they would appear in the node proto. + """ + return self._iter_params_in_onnx_order(state, inputs=False) + + def iter_inputs_in_onnx_order(self, state: SDFGState) -> List[MultiConnectorEdge]: + """ Iterate through the output edges in the same order as they would appear in an ONNX node proto. + This assumes that the node has been validated! + + :param state: the state containing this node. + :return: the in edges in the order as they would appear in the node proto. + """ + return self._iter_params_in_onnx_order(state, inputs=True) + + def _iter_params_in_onnx_order(self, state: SDFGState, inputs: bool = False) -> List[MultiConnectorEdge]: + parameters = list(self.schema.inputs if inputs else self.schema.outputs) + if len(parameters) == 0: + return [] + if parameters[-1].param_type == ONNXParameterType.Variadic: + name = parameters[-1].name + parameters = itertools.chain([param.name for param in parameters[:-1]], + (name + "__" + str(i) for i in itertools.count())) + else: + parameters = [param.name for param in parameters] + + edges = state.in_edges(self) if inputs else state.out_edges(self) + parameters = list(itertools.islice(parameters, len(edges))) + conn_to_edge = {edge.dst_conn if inputs else edge.src_conn: edge for edge in edges} + + return [conn_to_edge[name] for name in parameters] + + def iter_edges( + self, + state: SDFGState, + ignore_unknown=False, + ) -> Iterator[Tuple[MultiConnectorEdge, bool]]: + """ Returns an iterator over tuples of an edge and a boolean that indicates whether that edge is an input, + ordered by the order required by the schema. + This method assumes that this node has been validated. + + :param state: the state containing this node. + :param ignore_unknown: whether to ignore any edges that don't exist in the ONNX schema. Otherwise, an + error will be thrown. + """ + in_edges: List[MultiConnectorEdge] = state.in_edges(self) + out_edges: List[MultiConnectorEdge] = state.out_edges(self) + + def get_idx(parameters, name): + if '__' in name: + name, number = parse_variadic_param(name) + else: + number = 0 + + matched = [i for i, param in enumerate(parameters) if param.name == name] + + if len(matched) != 1: + if ignore_unknown: + return None + raise ValueError("Found {} connectors with name '{}', expected to find exactly one".format( + len(matched), name)) + + parameter_idx = matched[0] + + # add on the variadic parameter index + parameter_idx += number + + return parameter_idx + + if ignore_unknown: + in_edges = [e for e in in_edges if get_idx(self.schema.inputs, e.dst_conn) is not None] + out_edges = [e for e in out_edges if get_idx(self.schema.outputs, e.src_conn) is not None] + + sorted_in = sorted(in_edges, key=lambda edge: get_idx(self.schema.inputs, edge.dst_conn)) + sorted_out = sorted(out_edges, key=lambda edge: get_idx(self.schema.outputs, edge.src_conn)) + + return itertools.chain(zip(sorted_in, itertools.repeat(True)), zip(sorted_out, itertools.repeat(False))) + + def validate(self, sdfg: SDFG, state: SDFGState): + """ Validate this node. + + :param sdfg: the parent sdfg. + :param state: the parent state. + """ + in_edges = state.in_edges(self) + out_edges = state.out_edges(self) + + # check that we don't have connectors to None + all_connectors = {edge.dst_conn for edge in in_edges}.union(edge.src_conn for edge in out_edges) + if None in all_connectors: + raise ValueError("Edges to ONNX Ops must not have connector None") + + # check that all edges have connectors + ########################################## + for edge, is_input in self.iter_edges(state): + if is_input: + conn_name = edge.dst_conn + if conn_name not in self.in_connectors: + raise ValueError("Memlet {} leading to nonexistent input connector '{}'".format( + edge.data, conn_name)) + else: + conn_name = edge.src_conn + if conn_name not in self.out_connectors: + raise ValueError("Memlet {} leading to nonexistent output connector '{}'".format( + edge.data, conn_name)) + + # check that we have all required in_edges + ########################################## + required_inputs = {inp.name for inp in self.schema.inputs if inp.param_type == ONNXParameterType.Single} + passed_inputs = {inp.dst_conn + for inp in in_edges if '__' not in inp.dst_conn} # we will test variadic inputs separately + known_inputs = {inp.name for inp in self.schema.inputs} + + missing_inputs = required_inputs.difference(passed_inputs) + if len(missing_inputs) > 0: + raise ValueError(get_missing_arguments_message(self.schema.name, missing_inputs, "input")) + + # check that we have all required out_edges + ########################################## + required_outputs = {outp.name for outp in self.schema.outputs if outp.param_type == ONNXParameterType.Single} + passed_outputs = {outp.src_conn + for outp in out_edges if '__' not in outp.src_conn} # we will test variadic inputs separately + known_outputs = {outp.name for outp in self.schema.outputs} + + missing_outputs = required_outputs.difference(passed_outputs) + if len(missing_outputs) > 0: + raise ValueError(get_missing_arguments_message(self.schema.name, missing_outputs, "output")) + + # check that we have no unknown in edges + ########################################## + unknown_inputs = passed_inputs.difference(known_inputs) + if len(unknown_inputs) > 0: + raise TypeError("Got an unexpected argument '{}'".format(list(unknown_inputs)[0])) + + # check that we have no unknown out edges + ########################################## + unknown_outputs = passed_outputs.difference(known_outputs) + if len(unknown_outputs) > 0: + raise TypeError("Got an unexpected argument '{}'".format(list(unknown_outputs)[0])) + + # check variadic params + ########################################## + variadic_inputs = {inp.name for inp in self.schema.inputs if inp.param_type == ONNXParameterType.Variadic} + passed_variadic_inputs = {edge.dst_conn for edge in in_edges if '__' in edge.dst_conn} + + seen_variadic_numbers = set() + for param in passed_variadic_inputs: + name, number = parse_variadic_param(param) + if name not in variadic_inputs: + raise ValueError("Got an unexpected variadic argument '{}'".format(param)) + if number in seen_variadic_numbers: + raise ValueError("Got two variadic inputs with index {}, expected at most one".format(number)) + seen_variadic_numbers.add(number) + + # check that we have seen every number + for i in range(len(seen_variadic_numbers)): + if i not in seen_variadic_numbers: + raise ValueError( + "Since {} variadic inputs were passed, expected variadic parameter with number {}".format( + len(seen_variadic_numbers), i)) + + variadic_outputs = {outp.name for outp in self.schema.outputs if outp.param_type == ONNXParameterType.Variadic} + passed_variadic_outputs = {edge.src_conn for edge in out_edges if '__' in edge.src_conn} + seen_variadic_numbers = set() + for param in passed_variadic_outputs: + name, number = parse_variadic_param(param) + if name not in variadic_outputs: + raise ValueError("Got an unexpected variadic argument '{}'".format(param)) + if number in seen_variadic_numbers: + raise ValueError("Got two variadic outputs with index {}, expected at most one".format(number)) + seen_variadic_numbers.add(number) + + # check that we have seen every number + for i in range(len(seen_variadic_numbers)): + if i not in seen_variadic_numbers: + raise ValueError( + "Since {} variadic outputs were passed, expected variadic parameter with number {}".format( + len(seen_variadic_numbers), i)) + + # check that type params solve + ########################################## + + assigned_params = {} + for edge, is_input in self.iter_edges(state): + conn_name = edge.dst_conn if is_input else edge.src_conn + + if '__' in conn_name: + parsed_name, number = parse_variadic_param(conn_name) + else: + parsed_name = conn_name + + matching = [ + inp for inp in (self.schema.inputs if is_input else self.schema.outputs) if inp.name == parsed_name + ] + + if len(matching) != 1: + raise ValueError("Expected to find one {} parameter in schema with name '{}', but found {}".format( + "input" if is_input else "output", parsed_name, len(matching))) + matched = matching[0] + + if '__' in conn_name and matched.param_type != ONNXParameterType.Variadic: + raise ValueError("Got variadic argument '{}' for non-variadic parameter '{}'." + " Ensure that non-variadic args do not contain '__'".format(conn_name, matched.name)) + + if '__' not in conn_name and matched.param_type == ONNXParameterType.Variadic: + raise ValueError( + "Expected variadic argument for variadic parameter '{}', got '{}'. Use '{}__i' as the connector" + " name, where i is the desired index of the variadic parameter.".format( + matched.name, conn_name, conn_name)) + + edge_data = edge.data.data + edge_dtype = sdfg.arrays[edge_data].dtype + # edge_dtype can be a vector type + if matched.param_type == ONNXParameterType.Variadic and not matched.homogeneous: + # non homogeneous parameters don't need to be consistent + pass + elif matched.type_str in assigned_params and (assigned_params[matched.type_str] != edge_dtype and + assigned_params[matched.type_str] != edge_dtype.base_type): + raise ValueError( + "Could not solve type constraints;" + " excepted type '{expected}' for {param_type} '{conn_name}', got type '{actual}'".format( + expected=assigned_params[matched.type_str], + param_type="input" if is_input else "output", + conn_name=matched.name, + actual=edge_dtype)) + + # otherwise, matched.type_str was not assigned a type yet: try to assign it + cons = self.schema.type_constraints[matched.type_str] + if edge_dtype not in cons.types and edge_dtype.base_type not in cons.types: + raise ValueError( + "Expected type in '{possible}' for {param_type} '{conn_name}', got type '{actual}'".format( + possible=cons.types, + param_type="input" if is_input else "output", + conn_name=matched.name, + actual=edge_dtype)) + assigned_params[matched.type_str] = edge_dtype.base_type + + # check that we have all required attributes + ########################################## + required_attrs = {name for name, attr in self.schema.attributes.items() if attr.required} + for attr in required_attrs: + if getattr(self, attr) is None: + raise ValueError("Expected value for required attribute '{}', got None".format(attr)) diff --git a/dace/libraries/onnx/nodes/onnx_op_registry.py b/dace/libraries/onnx/nodes/onnx_op_registry.py new file mode 100644 index 0000000000..6b145a0859 --- /dev/null +++ b/dace/libraries/onnx/nodes/onnx_op_registry.py @@ -0,0 +1,351 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import collections +from typing import Iterator, Tuple, List, Dict, Type + +import dace +import dace.library +import dace.sdfg.nodes as nd +import dace.frontend.common.op_repository as dace_op_repo +from dace.frontend.python.newast import ProgramVisitor +from dace import config, SDFG, SDFGState, dtypes, data +from dace.properties import Property, ListProperty, make_properties +from dace.sdfg.graph import MultiConnectorEdge +from dace.transformation.transformation import ExpandTransformation + +from dace.libraries.onnx.nodes.node_utils import parse_variadic_param +from dace.libraries.onnx.schema import ONNXSchema, ONNXAttributeType, _ATTR_TYPE_TO_PYTHON_TYPE, ONNXParameterType, ONNXAttribute, ONNXParameter, ONNXTypeConstraint + +import dace.libraries.onnx.nodes.onnx_op as onnx_op +from dace.frontend.python.common import StringLiteral + +import onnx + + +def _get_typecons_docstring(cons: ONNXTypeConstraint) -> str: + """Generate documentation string for type constraints.""" + return " * **{}** -- {}".format(cons.type_str, + ", ".join(":class:`{}`".format(t.to_string()) for t in cons.types)) + + +def _get_connector_docstring(param: ONNXParameter) -> str: + """Generate documentation string for connectors.""" + return " * **{}** ({}, {}) -- {}".format(param.name, param.type_str, param.param_type.name.lower(), + param.description) + + +def _get_attr_docstring(attr: ONNXAttribute) -> str: + """Generate documentation string for attributes.""" + param_doc = ":param {}: {}".format(attr.name, attr.description) + + if attr.attribute_type is ONNXAttributeType.Unsupported: + return "" + + if attr.attribute_type is ONNXAttributeType.Tensor: + type_string = "numpy.ndarray" + else: + type_string = _ATTR_TYPE_TO_PYTHON_TYPE[attr.attribute_type].__name__ + + type_string = ":class:`{}`".format(type_string) + + if attr.attribute_type in [ONNXAttributeType.Ints, ONNXAttributeType.Floats, ONNXAttributeType.Strings]: + type_string = ":class:`List` [{}]".format(type_string) + + if not attr.required: + type_string = ":class:`Optional` [{}], default={}".format(type_string, repr(attr.default_value)) + + param_type = ":type {}: {}".format(attr.name, type_string) + + return param_doc + "\n" + param_type + + +def _get_all_schemas(): + """Get all ONNX schemas with version history.""" + name_to_schemas = collections.defaultdict(list) + for schema in onnx.defs.get_all_schemas_with_history(): + name_to_schemas[schema.name].append(schema) + + all_schemas = [] + for name, schemas in name_to_schemas.items(): + all_schemas.extend(schemas) + + return all_schemas + + +def register_op_repo_replacement(cls: Type[onnx_op.ONNXOp], cls_name: str, dace_schema: ONNXSchema): + """Register an op repository replacement for the given ONNX operation class.""" + + @dace_op_repo.replaces("dace.libraries.onnx.{}".format(cls_name)) + def op_repo_replacement(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, **kwargs): + attrs = {name: value for name, value in kwargs.items() if name in dace_schema.attributes} + # Remove used attrs + kwargs = {k: v for k, v in kwargs.items() if k not in attrs} + + onnx_node = cls(name=cls_name, **attrs) + state.add_node(onnx_node) + + input_names = dace_schema.non_variadic_inputs() + variadic_inputs = dace_schema.variadic_inputs() + + output_names = dace_schema.non_variadic_outputs() + variadic_outputs = dace_schema.variadic_outputs() + + inputs = { + name: arr_name + for name, arr_name in kwargs.items() + if (name in input_names or + # variadic params + ("__" in name and parse_variadic_param(name)[0] in variadic_inputs)) + } + + kwargs = {k: v for k, v in kwargs.items() if k not in inputs} + + outputs = { + name: arr_name + for name, arr_name in kwargs.items() + if (name in output_names or + # variadic params + ("__" in name and parse_variadic_param(name)[0] in variadic_outputs)) + } + + kwargs = {k: v for k, v in kwargs.items() if k not in outputs} + + if len(kwargs) > 0: + raise TypeError(f"Unknown arguments {', '.join(kwargs)}") + + # Remove all non-string attributes + # Sometimes constants are passed as inputs, but they do not require AccessNodes + # so we add them first as attributes to the node + for inp, arr_name in inputs.items(): + if not isinstance(arr_name, str): + setattr(onnx_node, inp, arr_name) + + inputs = {inp: arr_name for inp, arr_name in inputs.items() if isinstance(arr_name, str)} + + for inp, arr_name in inputs.items(): + read = state.add_read(arr_name) + state.add_edge(read, None, onnx_node, inp, sdfg.make_array_memlet(arr_name)) + onnx_node.add_in_connector(inp) + + for outp, arr_name in outputs.items(): + write = state.add_read(arr_name) + state.add_edge(onnx_node, outp, write, None, sdfg.make_array_memlet(arr_name)) + onnx_node.add_out_connector(outp) + return [] + + +_ONNX_OPS = {} +_REGISTRY_INITIALIZED = False + + +def _initialize_onnx_registry(): + """ + Lazy initialization of ONNX operator registry. + This function is called only when ONNX nodes are actually used. + We add a global flag (_REGISTRY_INITIALIZED) to avoid re-initializing the registry multiple times. + """ + global _REGISTRY_INITIALIZED, _ONNX_OPS + + if _REGISTRY_INITIALIZED: + return + + _REGISTRY_INITIALIZED = True + + # Import these here to avoid circular imports at module load time + from dace.libraries.onnx.forward_implementation_abc import ONNXForward + import dace.libraries.onnx.op_implementations # Registers implementations + + # Generate all of the Op Nodes + for schema in _get_all_schemas(): + try: + dace_schema = ONNXSchema.from_onnx_proto(schema) + # If the schema has a parameter name that exists as both an input and an output, prepend "in_" and "out_" + intersecting_names = set(i.name for i in dace_schema.inputs).intersection(o.name + for o in dace_schema.outputs) + for name in intersecting_names: + in_cands = [i for i in dace_schema.inputs if i.name == name] + out_cands = [i for i in dace_schema.outputs if i.name == name] + assert len(in_cands) == len(out_cands) == 1 + in_cands[0].name = "in_" + name + out_cands[0].name = "out_" + name + + except Exception as e: + if config.Config.get_bool('debugprint'): + print("Import of {} failed: {}".format(schema.name, e)) + continue + + attrs = {} + # Add properties for each op attribute + for name, attr in dace_schema.attributes.items(): + if attr.attribute_type in [ + ONNXAttributeType.Int, ONNXAttributeType.String, ONNXAttributeType.Float, ONNXAttributeType.Tensor + ]: + attrs[name] = Property(dtype=_ATTR_TYPE_TO_PYTHON_TYPE[attr.attribute_type], + desc=attr.description, + allow_none=True, + default=None if attr.default_value is None else attr.default_value) + elif attr.attribute_type in [ONNXAttributeType.Ints, ONNXAttributeType.Strings, ONNXAttributeType.Floats]: + attrs[name] = ListProperty(element_type=_ATTR_TYPE_TO_PYTHON_TYPE[attr.attribute_type], + desc=attr.description, + allow_none=True, + default=None if attr.default_value is None else attr.default_value) + elif attr.required: + raise NotImplementedError("Required attribute '{}' has an unsupported type".format(attr.name)) + + required_attrs = {name for name, attr in dace_schema.attributes.items() if attr.required} + + def __init__(self, name, *args, location=None, optional=set(), **op_attributes): + super(onnx_op.ONNXOp, self).__init__( + name, + location=location, + # Add required parameters as in/out connectors, without types for now + inputs={ + inp.name + for inp in self.schema.inputs if inp.param_type == ONNXParameterType.Single or ( + inp.name in optional and inp.param_type == ONNXParameterType.Optional) + }, + outputs={ + out.name + for out in self.schema.outputs if out.param_type == ONNXParameterType.Single or ( + out.name in optional and out.param_type == ONNXParameterType.Optional) + }) + self.backward_implementation = None + + if len(args) > 0: + raise TypeError("__init__() takes 1 positional arguments but {} were given".format(1 + len(args))) + + missing_arguments = required_attrs.difference(op_attributes) + if len(missing_arguments) > 0: + + raise TypeError( + onnx_op.get_missing_arguments_message("__init__()", missing_arguments, "keyword-only argument")) + + unknown_attrs = set(op_attributes).difference(self.schema.attributes) + if len(unknown_attrs) > 0: + raise TypeError("{}.__init__() got an unexpected keyword argument '{}'".format( + self.schema.name, + list(unknown_attrs)[0])) + + for name, attr in op_attributes.items(): + if isinstance(attr, StringLiteral): + attr = attr.value + setattr(self, name, attr) + + input_connector_docstrings = "\n".join(_get_connector_docstring(param) for param in dace_schema.inputs) + output_connector_docstrings = "\n".join(_get_connector_docstring(param) for param in dace_schema.outputs) + + cls_name = "ONNX" + dace_schema.name + + # The first line of the init docstring contains the signature of the method. This will be picked up by sphinx and + # means that the generated sphinx docs have a proper signature, and not just *args, **kwargs. + init_docstring = "__init__(name, *, {})\n".format(", ".join(attr.name if attr.required else attr.name + "=" + + repr(attr.default_value) + for _, attr in dace_schema.attributes.items())) + init_docstring += ":param name: The name of the node.\n" + "\n".join( + _get_attr_docstring(attr) for _, attr in dace_schema.attributes.items()) + + docstring = "\n" + dace_schema.doc + type_docstrings = "\n".join(_get_typecons_docstring(cons) for _, cons in dace_schema.type_constraints.items()) + docstring += "\n\n" + docstring += ":Node Inputs:" + input_connector_docstrings + docstring += "\n\n" + docstring += ":Node Outputs:" + output_connector_docstrings + docstring += "\n\n" + docstring += ":Type Constraints:" + type_docstrings + + attrs['__doc__'] = docstring + "\n" + attrs['schema'] = dace_schema + + attrs['__init__'] = __init__ + + cls_name_ver = cls_name + "_" + str(dace_schema.since_version) + + cls = type(cls_name_ver, (onnx_op.ONNXOp, ), attrs) + cls = dace.library.node(cls) + cls.__init__.__doc__ = "\n" + init_docstring + # Set library name for lazy-loaded nodes + cls._dace_library_name = "dace.libraries.onnx" + + # Register pure implementations + registered = False + for impl, args in ONNXForward.extensions().items(): + if "op" in args and args["op"] == schema.name: + + class Expansion(ExpandTransformation): + environments = [] + forward_impl: ONNXForward = impl + + @classmethod + def expansion(cls, node, state, sdfg, **kwargs): + # validate + node.validate(sdfg, state) + + if cls.forward_impl.forward_can_be_applied(node, state, sdfg): + result = cls.forward_impl.forward(node, state, sdfg, **kwargs) + if hasattr(cls.forward_impl, "environments"): + cls.environments.extend(cls.forward_impl.environments) + return result + + implementation_name = args["name"] + + # Give the Expansion class a unique name and register it in globals + # so it can be located during deserialization + expansion_class_name = f"{cls_name_ver}_Expansion_{implementation_name}" + Expansion.__name__ = expansion_class_name + Expansion.__qualname__ = expansion_class_name + globals()[expansion_class_name] = Expansion + + cls.register_implementation(implementation_name, Expansion) + registered = True + + if not registered: + # WARNING: No implementation found for this op + cls.default_implementation = None + + version = schema.since_version + + if cls_name not in _ONNX_OPS: + _ONNX_OPS[cls_name] = {} + _ONNX_OPS[cls_name][version] = cls + + for name, ver_to_cls in _ONNX_OPS.items(): + _ONNX_OPS[name] = dict(sorted(ver_to_cls.items())) + for i, (version, cls) in enumerate(_ONNX_OPS[name].items()): + if i == len(_ONNX_OPS[name]) - 1: + # last version registered as the default + globals()[name] = cls + # register python frontend replacement + register_op_repo_replacement(cls, name, cls.schema) + # all other versions are registered with version as a suffix + globals()[name + "_" + str(version)] = cls + + +def has_onnx_node(name: str) -> bool: + """Check if an ONNX operator is supported. + + :param name: The operator name. + :return: True if the operator is supported, False otherwise. + """ + _initialize_onnx_registry() + return ("ONNX" + name) in _ONNX_OPS + + +def get_onnx_node(name: str, version: int = -1) -> onnx_op.ONNXOp: + """Get the ONNX Operator node for an operator by name. + + :param name: The operator name. + :param version: The version of the operator (-1 for latest). + :return: The ONNX operator node class. + :raises ValueError: If no version of the operator is found for the given version. + """ + _initialize_onnx_registry() + name_to_versions = list(_ONNX_OPS["ONNX" + name].items()) + + if version == -1: + # Take the latest version + return name_to_versions[-1][1] + else: + # Take the latest version which is less than or equal to the given version + for ver, cls in reversed(name_to_versions): + if ver <= version: + return cls + raise ValueError(f"No version of {name} found for version {version}") diff --git a/dace/libraries/onnx/onnx.md b/dace/libraries/onnx/onnx.md new file mode 100644 index 0000000000..0f43d037a3 --- /dev/null +++ b/dace/libraries/onnx/onnx.md @@ -0,0 +1,993 @@ +Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +# DaCe ONNX Integration Library - Design Document + +## Table of Contents + +1. [Introduction](#1-introduction) +2. [Architecture Overview](#2-architecture-overview) +3. [Directory Structure](#3-directory-structure) +4. [Core Components](#4-core-components) +5. [Import Pipeline](#5-import-pipeline) +6. [Shape Inference System](#6-shape-inference-system) +7. [Implementation Strategies](#7-implementation-strategies) +8. [Key Algorithms](#8-key-algorithms) +9. [Extension Points](#9-extension-points) + +--- + +## 1. Introduction + +### 1.1 Purpose + +The DaCe ONNX Integration Library enables **the importing and executing of ONNX (Open Neural Network Exchange) models** within the DaCe framework. It provides a pipeline for converting ONNX neural network models into optimized DaCe SDFGs (Stateful DataFlow Graphs) that can run efficiently on CPUs, GPUs, and other accelerators. + +### 1.2 Current Capabilities + +- **Model Import**: Load ONNX models from files or protobuf objects +- **Shape Inference**: Automatic computation of tensor shapes (symbolic and concrete) +- **Multi-Strategy Implementations**: Pure (correctness), optimized (performance), hardware-specific (GPU/FPGA) +- **Type Safety**: Schema-based validation and type checking +- **Framework Integration**: Interoperability with PyTorch and NumPy + +### 1.3 Use Cases + +1. **ML Inference Optimization**: Optimize pre-trained models for production deployment +2. **Hardware Acceleration**: Leverage DaCe's code generation for GPU/FPGA execution +3. **Cross-Framework Compatibility**: Run PyTorch/TensorFlow models in DaCe ecosystem +4. **Research and Experimentation**: Analyze and optimize neural network architectures +5. **Custom Optimization**: Apply DaCe transformations to ML workloads +6. **Benchmarking**: Compare performance across different implementations + +### 1.4 ONNX Background + +ONNX is an open standard for representing machine learning models, supported by major frameworks: +- **Export**: PyTorch, TensorFlow, Keras, scikit-learn +- **Operators**: 150+ standard operations (Conv, MatMul, Attention, etc.) +- **Opsets**: Versioned operator specifications (current: opset 18) +- **Use**: Model exchange, optimization, deployment + +--- + +## 2. Architecture Overview + +### 2.1 High-Level System Diagram + +``` +┌─────────────────────────────────────────────────────────────┐ +│ USER INTERFACE │ +│ ┌──────────────┐ ┌──────────────┐ ┌─────────────────┐ │ +│ │ ONNXModel │ │ ONNX Backend │ │ Direct ONNX Op │ │ +│ │ (main API) │ │ (testing) │ │ calls │ │ +│ └──────┬───────┘ └──────┬───────┘ └────────┬────────┘ │ +└─────────┼─────────────────┼───────────────────┼─────────────┘ + │ │ │ + └─────────────────┼───────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ IMPORT PIPELINE │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ ONNXModel (frontend/ml/onnx/importer.py) │ │ +│ │ 1. Load Model → 4. Graph Construction │ │ +│ │ 2. Simplify → 5. Weight Management │ │ +│ │ 3. Shape Infer → 6. Compilation │ │ +│ └──────────────────┬───────────────────────────────────┘ │ +└─────────────────────┼───────────────────────────────────────┘ + │ + ┌────────────┼───────────┐ + ▼ ▼ ▼ +┌──────────────┐ ┌─────────┐ ┌──────────────────┐ +│ REGISTRY │ │ SCHEMA │ │ SHAPE INFERENCE │ +├──────────────┤ ├─────────┤ ├──────────────────┤ +│ Dynamic Node │ │ Type │ │ Symbolic Shape │ +│ Generation │ │ System │ │ Inference │ +│ │ │ │ │ (Microsoft impl) │ +│ • 100+ ops │ │ • Valid-│ │ │ +│ • Versioning │ │ ation │ │ • Dynamic dims │ +│ • Properties │ │ • Const-│ │ • Auto-merge │ +│ • Connectors │ │ raints│ │ • Concrete eval │ +└──────────────┘ └─────────┘ └──────────────────┘ + │ │ │ + └────────────┼────────────┘ + ▼ +┌────────────────────────────────────────────────────────────┐ +│ IMPLEMENTATION LAYER │ +│ ┌─────────────┐ ┌──────────────┐ ┌──────────────────┐ │ +│ │ Pure │ │ Optimized │ │ Hardware │ │ +│ │ (SDFGs) │ │ (img ops) │ │ (cuDNN, etc) │ │ +│ ├─────────────┤ ├──────────────┤ ├──────────────────┤ │ +│ │ Reference │ │ Performance │ │ GPU/FPGA │ │ +│ │ impl for │ │ focused │ │ specialized │ │ +│ │ correctness │ │ operations │ │ libraries │ │ +│ └─────────────┘ └──────────────┘ └──────────────────┘ │ +└────────────────────────────────────────────────────────────┘ + ▼ + DaCe SDFG with ONNX Nodes + ▼ + Expansion → Optimization → Code Generation +``` + +### 2.2 Component Interaction Flow + +``` +ONNX Model File + ↓ +ONNXModel.__init__() + ↓ +1. onnx.checker.check_model() → Validate + ↓ +2. shape_inference.infer_shapes() → Compute shapes + ↓ +3. onnxsim.simplify() (optional) → Optimize ONNX graph + ↓ +4. Create SDFG structure + ↓ +5. For each ONNX node: + ├─→ get_onnx_node(op_type, version) → Retrieve node class + ├─→ Create instance with attributes + ├─→ Add connectors from schema + └─→ Create edges with memlets + ↓ +6. Load weights (initializers) + ↓ +7. Handle outputs (scalar promotion, return arrays) + ↓ +8. Apply GPU transformations (if cuda=True) + ↓ +SDFG with ONNX Library Nodes + ↓ +Compilation (triggered by first call or explicit compile()): + ├─→ Expand ONNX nodes (select implementation) + ├─→ Apply DaCe optimizations + ├─→ Generate C++/CUDA code + └─→ Compile to binary + ↓ +Execution: + ├─→ Infer runtime symbols from input shapes + ├─→ Call compiled function with inputs + weights + └─→ Return outputs (NumPy or PyTorch tensors) +``` + +--- + +## 3. Directory Structure + +### 3.1 File Organization + +``` +dace/libraries/onnx/ +├── __init__.py # Library registration (61 lines) +│ └── Exports: ONNXModel (lazy), get_onnx_node, has_onnx_node, schema types +│ +├── (Note: ONNXModel is at dace/frontend/ml/onnx/importer.py) +│ +├── schema.py # Type system +│ ├── @onnx_representation decorator +│ ├── ONNXSchema +│ ├── ONNXAttribute +│ ├── ONNXParameter +│ └── ONNXTypeConstraint +│ +├── converters.py # Type conversions +│ ├── convert_onnx_proto() +│ ├── onnx_tensor_type_to_typeclass() +│ ├── clean_onnx_name() +│ └── convert_attribute_proto() +│ +├── forward_implementation_abc.py # Implementation interface +│ └── ONNXForward (ABC + registry) +│ +├── nodes/ # ONNX operation nodes +│ ├── onnx_op.py # Base class +│ │ └── ONNXOp - Abstract superclass for all ONNX ops +│ ├── onnx_op_registry.py # Dynamic generation +│ │ ├── _get_all_schemas() +│ │ ├── _create_node_class() +│ │ └── get_onnx_node() / has_onnx_node() +│ └── node_utils.py # Utilities +│ └── parse_variadic_param() +│ +├── op_implementations/ # Implementation strategies +│ ├── __init__.py # Package exports (11 lines) +│ ├── elementwise_ops.py # Element-wise operations (212 lines) +│ ├── reduction_ops.py # Reduction operations (304 lines) +│ ├── array_ops.py # Array operations (681 lines) +│ ├── linalg_ops.py # Linear algebra ops (359 lines) +│ ├── normalization_ops.py # Normalization ops (281 lines) +│ ├── image_ops.py # Image operations (443 lines) +│ ├── img_op_implementations.py # Optimized image ops (563 lines) +│ ├── criteria_implementations.py # Conditional selection (90 lines) +│ ├── common.py # Common utilities (11 lines) +│ └── utils.py # Helpers (223 lines) +│ ├── @op_implementation decorator +│ ├── @python_pure_op_implementation +│ ├── program_for_node() +│ └── empty_sdfg_for_node() +│ +└── shape_inference/ # Dynamic shape support + └── (Empty - uses onnxruntime.tools.symbolic_shape_infer instead) + +``` + +### 3.2 File Size Distribution + +| File | Lines | Purpose | +|------|-------|---------| +| `array_ops.py` | 681 | Array operations (Concat, Gather, etc.) | +| `img_op_implementations.py` | 563 | Optimized image operations | +| `image_ops.py` | 443 | Image operation implementations | +| `linalg_ops.py` | 359 | Linear algebra operations | +| `onnx_op_registry.py` | 351 | Dynamic node class generation | +| `schema.py` | 333 | Type system and validation | +| `reduction_ops.py` | 304 | Reduction operations | +| `onnx_op.py` | 295 | Base class for ONNX operations | +| `normalization_ops.py` | 281 | Normalization operations | +| `converters.py` | 247 | Type conversion utilities | +| `utils.py` | 223 | Implementation helpers | +| `elementwise_ops.py` | 212 | Element-wise operations | +| `forward_implementation_abc.py` | 105 | Implementation interface | + +**Note**: ONNXModel (794 lines) is located at [dace/frontend/ml/onnx/importer.py](../../frontend/ml/onnx/importer.py). + +--- + +## 4. Core Components + +### 4.1 ONNXModel: The Main Entry Point + +**Location**: [dace/frontend/ml/onnx/importer.py](../../frontend/ml/onnx/importer.py) + +The `ONNXModel` class is the primary interface for importing and executing ONNX models. + +#### Key Features + +- **Model Loading**: Loads models from files or ONNX protobuf objects +- **Automatic Optimization**: Provides optional ONNX-level simplification +- **Shape Inference**: Handles dynamic and symbolic shapes automatically +- **Weight Management**: Loads and manages model parameters efficiently +- **Compilation**: Supports lazy or explicit compilation to optimized code +- **Execution**: Provides direct `__call__` interface with NumPy/PyTorch tensors +- **GPU Support**: Automatic GPU transformation when `cuda=True` + +#### Constructor Signature + +```python +class ONNXModel: + def __init__( + self, + name: str, + model: Union[str, onnx.ModelProto], + cuda: bool = False, + apply_strict: bool = False, + auto_optimize: bool = True, + onnx_simplify: bool = True, + infer_shapes: bool = True, + auto_merge: bool = False + ): + """ + Import an ONNX model into DaCe. + + Args: + name: Name for the generated SDFG + model: Path to .onnx file or onnx.ModelProto object + cuda: Enable GPU execution + apply_strict: Strict ONNX validation + auto_optimize: Apply DaCe optimizations on first run + onnx_simplify: Apply onnx-simplifier before import + infer_shapes: Run shape inference + auto_merge: Auto-merge conflicting symbolic shapes + """ +``` + +#### Main Methods + +- **`__call__()`**: Execute the model with inputs +- **`compile()`**: Explicitly compile the SDFG +- **`save()`**: Save compiled model to disk +- **`infer_symbols()`**: Infer symbolic dimension values from input shapes + +--- + +### 4.2 Registry System: Dynamic Node Generation + +**Location**: [nodes/onnx_op_registry.py](nodes/onnx_op_registry.py) + +The registry system **dynamically generates Python classes** for all ONNX operations at import time, eliminating the need to manually write 100+ node classes. + +#### How It Works + +**Process**: +``` +1. Query ONNX for all supported operations + ↓ +2. For each operation (e.g., "Conv"): + ├─ Get all versions (e.g., Conv_1, Conv_11, Conv_13) + ├─ Convert ONNX OpSchema to ONNXSchema + └─ For each version: + ├─ Create Python properties from attributes + ├─ Generate __init__ constructor + ├─ Add input/output connectors + ├─ Generate documentation + ├─ Create class with type() + └─ Register with DaCe library system + ↓ +3. Store in global registry: + _ONNX_OPS["Conv"][11] = ONNXConv_11 + _ONNX_OPS["Conv"][13] = ONNXConv_13 + ↓ +4. Export latest version to module: + ONNXConv = ONNXConv_13 +``` + +#### Generated Class Structure + +For each ONNX operation, the registry generates: + +- **Class Name**: `ONNX{OpName}_{Version}` (e.g., `ONNXConv_11`) +- **Properties**: One DaCe property per ONNX attribute +- **Constructor**: Validates required attributes, sets defaults +- **Connectors**: Input/output connectors from schema +- **Schema**: Embedded `ONNXSchema` for validation +- **Implementations**: Linked expansion transformations +- **Documentation**: Auto-generated from ONNX docs + +#### API Functions + +```python +def has_onnx_node(name: str) -> bool: + """Check if ONNX operation is supported.""" + +def get_onnx_node(name: str, opset_version: int = None) -> Type[ONNXOp]: + """Get ONNX node class by name and version.""" +``` + +--- + +### 4.3 Schema System: Type Safety + +**Location**: [schema.py](schema.py) + +The schema system provides a Python representation layer for ONNX protobuf schemas, enabling type-safe interactions. + +#### Key Components + +**ONNXSchema** - Complete operation specification: +```python +@dataclass +class ONNXSchema: + name: str # Operation name (e.g., "Conv") + since_version: int # First opset supporting this + doc: str # Documentation + inputs: List[ONNXParameter] # Input specifications + outputs: List[ONNXParameter] # Output specifications + attributes: Dict[str, ONNXAttribute] # Attribute specs + type_constraints: Dict[str, ONNXTypeConstraint] # Type constraints +``` + +**ONNXParameter** - Input/output parameter: +```python +@dataclass +class ONNXParameter: + name: str # Parameter name + type_str: str # Type constraint reference + param_type: ONNXParameterType # Single/Optional/Variadic + description: str # Documentation + homogeneous: bool # For variadic params +``` + +**ONNXAttribute** - Operation configuration: +```python +@dataclass +class ONNXAttribute: + name: str # Attribute name + type: ONNXAttributeType # Int/Float/String/Tensor/etc. + required: bool # Must be provided? + default_value: Any # Default if not provided + description: str # Documentation +``` + +**ONNXTypeConstraint** - Allowed types: +```python +@dataclass +class ONNXTypeConstraint: + type_param_str: str # Type parameter (e.g., "T") + allowed_types: List[typeclass] # Allowed DaCe types + description: str # Documentation +``` + +#### The @onnx_representation Decorator + +Enables creating Python classes from ONNX protobufs: + +```python +@onnx_representation(onnx.TensorProto) +class ONNXTensor: + dims: List[int] + data_type: int + # ... other fields +``` + +Automatically generates: +- `__init__()` constructor +- `from_onnx_proto()` class method +- `from_json()` / `to_json()` serialization +- Registration in the global protobuf registry + +--- + +### 4.4 ONNXOp Base Class + +**Location**: [nodes/onnx_op.py](nodes/onnx_op.py) + +`ONNXOp` is the abstract base class for all ONNX operation nodes in DaCe SDFGs. + +#### Key Methods + +- **`iter_inputs_in_onnx_order()`**: Get input edges in schema order +- **`iter_outputs_in_onnx_order()`**: Get output edges in schema order +- **`iter_edges()`**: Iterate all edges with input/output flag +- **Validation**: Automatic schema-based validation during SDFG construction + +#### Properties + +- `schema`: The operation's ONNXSchema +- `backward_implementation`: Which backward impl to use (for autodiff) +- `implementations`: Available forward implementations +- `default_implementation`: Default expansion strategy + +--- + +### 4.5 Type Converters + +**Location**: [converters.py](converters.py) + +Provides bidirectional conversion between ONNX, DaCe, NumPy, and PyTorch type systems. + +#### Key Functions + +**Type Conversion**: +- `onnx_tensor_type_to_typeclass()`: ONNX type enum → DaCe typeclass +- `typeclass_to_onnx_tensor_type_int()`: DaCe typeclass → ONNX type enum +- `convert_onnx_proto()`: Generic protobuf → Python conversion +- `convert_attribute_proto()`: ONNX AttributeProto → Python value + +**Name Sanitization**: +- `clean_onnx_name()`: Makes ONNX names valid DaCe identifiers + - Prefixes digit-starting names: `123` → `ONNX_123` + - Replaces special characters: `.` → `DOT`, `:` → `COLON`, `/` → `SLASH` + +**Helper Functions**: +- `get_proto_attr()`: Provides safe protobuf attribute access with encoding checks + +--- + +## 5. Import Pipeline + +### 5.1 Complete Workflow + +``` +┌─────────────────────────────────────────────────────────┐ +│ Phase 1: Model Loading and Validation │ +├─────────────────────────────────────────────────────────┤ +│ 1. Load ONNX model (from file or protobuf) │ +│ 2. Run onnx.checker.check_model() │ +│ 3. Validate model conforms to ONNX spec │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 2: Shape Inference │ +├─────────────────────────────────────────────────────────┤ +│ 1. Run symbolic shape inference │ +│ 2. Compute concrete shapes where possible │ +│ 3. Create symbolic dimensions for dynamic shapes │ +│ 4. Auto-merge conflicting symbols (optional) │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 3: ONNX-Level Optimization (optional) │ +├─────────────────────────────────────────────────────────┤ +│ 1. Apply onnxsim.simplify() │ +│ - Constant folding │ +│ - Dead code elimination │ +│ - Operator fusion │ +│ 2. Validate optimization preserves semantics │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 4: SDFG Construction │ +├─────────────────────────────────────────────────────────┤ +│ 1. Create empty SDFG with initial state │ +│ 2. Register inputs/outputs as data descriptors │ +│ 3. For each ONNX node: │ +│ a. Get node class from registry │ +│ b. Extract and convert attributes │ +│ c. Create node instance │ +│ d. Add input/output connectors │ +│ e. Create AccessNodes for data │ +│ f. Add edges with memlets │ +│ 4. Handle special cases (Constants, Identities) │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 5: Weight Management │ +├─────────────────────────────────────────────────────────┤ +│ 1. Load initializers (weights/biases) from ONNX │ +│ 2. Convert to PyTorch tensors │ +│ 3. Store in self.weights dictionary │ +│ 4. Create corresponding DaCe arrays (non-transient) │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 6: Output Handling │ +├─────────────────────────────────────────────────────────┤ +│ 1. Promote scalars to arrays (CPU only) │ +│ 2. Create return arrays (__return, __return_0, etc.) │ +│ 3. Add copy-out state for outputs │ +│ 4. Fuse states for efficiency │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 7: GPU Transformation (if cuda=True) │ +├─────────────────────────────────────────────────────────┤ +│ 1. Apply sdfg.apply_gpu_transformations() │ +│ 2. Convert memory to GPU_Global storage │ +│ 3. Add GPU kernel launch infrastructure │ +└──────────────────────┬──────────────────────────────────┘ + ▼ + SDFG with ONNX Library Nodes +``` + +### 5.2 Node Construction Details + +For each ONNX operation in the graph: + +**Step 1: Operation Lookup** +```python +if not has_onnx_node(node.op_type): + raise ValueError(f"Unsupported operation: {node.op_type}") +``` + +**Step 2: Attribute Extraction** +```python +attributes = {attr.name: convert_attribute_proto(attr) + for attr in node.attribute} +``` + +**Step 3: Node Class Retrieval** +```python +node_class = get_onnx_node(node.op_type, model_opset_version) +``` + +**Step 4: Instance Creation** +```python +dace_node = node_class(name=node.name, **attributes) +``` + +**Step 5: Connector and Edge Creation** +```python +for input_param in node_class.schema.inputs: + # Validate parameter type (Single/Optional/Variadic) + # Create or reuse AccessNode + # Add connector to operation node + # Create Memlet edge with full array semantics +``` + +### 5.3 Special Handling + +- **Constants**: Directly added to weights, no node created +- **Identities**: Can be elided during optimization +- **Variadic Parameters**: Use naming convention `param_name__index` +- **Optional Parameters**: Checked for presence, skipped if absent + +--- + +## 6. Shape Inference System + +### 6.1 Purpose and Motivation + +ONNX models often have **dynamic shapes** where tensor dimensions depend on runtime inputs: +- Batch size: Variable number of samples +- Sequence length: Variable-length sequences (NLP) +- Image dimensions: Variable-size images + +Shape inference computes tensor shapes either symbolically or concretely for all intermediate tensors in the model. + +### 6.2 Integration + +Shape inference uses `onnxruntime.tools.symbolic_shape_infer.SymbolicShapeInference` from the ONNX Runtime library. + +Called during model import: +```python +from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference +model = SymbolicShapeInference.infer_shapes(model, auto_merge=auto_merge) +``` + +### 6.3 Capabilities + +**Symbolic Dimensions**: +```python +# Input shape: [batch_size, 3, 224, 224] +# After Conv: [batch_size, 64, 112, 112] +# After Pool: [batch_size, 64, 56, 56] +``` + +**Concrete Evaluation**: +```python +# Known: kernel_size=3, stride=2, padding=1, input_size=224 +# Computed: output_size = (224 + 2*1 - 3) / 2 + 1 = 112 +``` + +**Broadcasting**: +```python +# Shape A: [batch, 256, 1, 1] +# Shape B: [batch, 256, 7, 7] +# Result: [batch, 256, 7, 7] +``` + +**Auto-Merge** (optional): +```python +# Before: tensor_0: [batch_0, seq_len_0] +# tensor_1: [batch_1, seq_len_1] +# After: tensor_0: [batch, seq_len] +# tensor_1: [batch, seq_len] +``` + +### 6.4 Implementation Details + +Shape inference uses the **ONNX Runtime implementation** (`onnxruntime.tools.symbolic_shape_infer`) which provides: + +- Helper functions for dimension extraction and axis handling +- `SymbolicShapeInference` class with per-operation rules +- Sympy-based symbolic computation +- Integration with ONNX's native shape inference +- Special handling for complex operations (Reshape, Transpose, Concat) + +### 6.5 DaCe Integration + +Symbolic dimensions are added to the SDFG symbol table: +```python +for dim_name in symbolic_dimensions: + sdfg.add_symbol(dim_name, dace.int64) +``` + +At runtime, DaCe infers symbol values from input shapes: +```python +symbols = {} +if 'batch_size' in sdfg.symbols: + symbols['batch_size'] = input_tensor.shape[0] +``` + +--- + +## 7. Implementation Strategies + +### 7.1 The ONNXForward Interface + +**Location**: [forward_implementation_abc.py](forward_implementation_abc.py) + +```python +@make_registry +class ONNXForward(abc.ABC): + """Abstract base for ONNX operation implementations.""" + + @staticmethod + def forward_can_be_applied(node: ONNXOp, state: SDFGState, + sdfg: SDFG) -> bool: + """Check if implementation is applicable.""" + return True + + @staticmethod + @abc.abstractmethod + def forward(node: ONNXOp, state: SDFGState, + sdfg: SDFG) -> Union[Node, SDFG]: + """Expand node to DaCe constructs.""" + ... +``` + +### 7.2 Implementation Types + +#### 1. Pure Implementations + +**Location**: Implementations are organized across multiple files in [op_implementations/](op_implementations/): +- `elementwise_ops.py` - Element-wise operations (Add, Mul, Div, etc.) +- `reduction_ops.py` - Reduction operations (ReduceMean, ReduceSum, etc.) +- `array_ops.py` - Array operations (Concat, Gather, Reshape, etc.) +- `linalg_ops.py` - Linear algebra operations (MatMul, Gemm, etc.) +- `normalization_ops.py` - Normalization operations (BatchNorm, LayerNorm, etc.) +- `image_ops.py` - Image operations (Conv, Pool, etc.) + +**Purpose**: Provides reference implementations focused on correctness + +**Characteristics**: +- Written in Python/NumPy style +- Automatically parsed via the DaCe Python frontend +- Semantically correct according to ONNX specifications +- May not be optimally performant until further transformations are applied + +**Implementation Pattern**: +```python +@python_pure_op_implementation +def Relu(X: dace.float32[H, W]): + """Pure implementation of ReLU activation.""" + return np.maximum(X, 0) +``` + +**Process**: +1. Decorator creates an `ONNXForward` subclass +2. Function is parsed via the DaCe Python frontend +3. Converted to SDFG with maps and tasklets +4. Result: Efficient parallel code generation + +#### 2. Optimized Implementations + +**Location**: [op_implementations/img_op_implementations.py](op_implementations/img_op_implementations.py) + +**Purpose**: Provides performance-optimized implementations for specific operations + +**Examples**: +- `Conv`: Optimized convolution with im2col or Winograd +- `MaxPool/AveragePool`: Efficient pooling operations +- `BatchNormalization`: Fused batch normalization + +**Characteristics**: +- Hand-crafted SDFG construction +- May use library calls (BLAS, cuDNN) +- Optimized for specific hardware/configurations + +#### 3. Hardware-Specific Implementations + +**Concept**: Implementations optimized for specific hardware + +**Examples** (potential): +- `cuDNN` implementations for GPU (Conv, Pool, BatchNorm) +- `MKL-DNN` implementations for CPU +- `FPGA` implementations for reconfigurable hardware + +**Selection via Applicability**: +```python +@op_implementation(op="Conv", name="cudnn") +class CuDNNConv(ONNXForward): + @staticmethod + def forward_can_be_applied(node, state, sdfg): + return sdfg.gpu and has_cudnn() +``` + +### 7.3 Implementation Selection + +**Process**: + +1. Query the registry for the operation's implementations +2. Filter by applicability: `forward_can_be_applied()` +3. Prefer user-specified implementation (if set) +4. Fall back to the default implementation +5. Expand the node using the selected implementation + +**Priority Order**: +1. User-specified implementation (node property) +2. First applicable implementation (by registration order) +3. Default implementation (usually "pure") + +### 7.4 Common Implementation Patterns + +#### Pattern A: Pure Python with Decorator + +```python +@python_pure_op_implementation +def Softmax(X: dace.float32[N, M], axis: int = -1): + """Softmax activation function.""" + exp_x = np.exp(X - np.max(X, axis=axis, keepdims=True)) + return exp_x / np.sum(exp_x, axis=axis, keepdims=True) +``` + +#### Pattern B: Manual SDFG Construction + +```python +@op_implementation(op="MatMul", name="blas") +class BLASMatMul(ONNXForward): + @staticmethod + def forward(node, state, sdfg): + # Create nested SDFG + nsdfg = dace.SDFG(f"{node.label}_matmul") + nstate = nsdfg.add_state() + + # Use BLAS library node + from dace.libraries.blas import MatMul + matmul_node = MatMul("matmul") + + # Connect inputs/outputs + # ... + + return nsdfg +``` + +#### Pattern C: Library Call Integration + +```python +@op_implementation(op="Conv", name="optimized") +class OptimizedConv(ONNXForward): + @staticmethod + def forward(node, state, sdfg): + # Leverage existing DaCe library nodes + from dace.libraries.standard import Conv2D + + # Convert ONNX semantics to library call + conv_node = Conv2D(...) + + # Return library node (further expanded by DaCe) + return conv_node +``` + +### 7.5 Implementation Utilities + +**Location**: [op_implementations/utils.py](op_implementations/utils.py) + +**Key Functions**: + +- `@op_implementation(op, name)`: Register implementation with registry +- `@python_pure_op_implementation`: Create implementation from Python function +- `program_for_node()`: Convert Python function to nested SDFG +- `empty_sdfg_for_node()`: Create empty nested SDFG template + +--- + +## 8. Key Algorithms + +### 8.1 Dynamic Node Class Generation + +**Algorithm**: Creates Python classes at import time + +``` +For each ONNX operation in onnx.defs.get_all_schemas(): + 1. Extract OpSchema from ONNX + 2. Convert to ONNXSchema (DaCe representation) + 3. For each version of the operation: + a. Generate class name: ONNX{OpName}_{Version} + b. Create properties from attributes: + - Map ONNX types to DaCe property types + - Set defaults and required flags + c. Generate __init__ constructor: + - Validate required attributes provided + - Convert types (e.g., StringLiteral → str) + - Set up connectors for parameters + d. Generate documentation from schema + e. Create class with type(): + cls = type(cls_name, (ONNXOp,), attrs) + f. Register as DaCe library node: + cls = dace.library.node(cls) + g. Link implementations: + - Query ONNXForward.extensions() + - Create ExpandTransformation wrappers + - Register with node class + h. Store in registry: + _ONNX_OPS[op_name][version] = cls + 4. Export latest version to module: + globals()[f"ONNX{OpName}"] = latest_version +``` + +**Result**: 100+ operation classes generated automatically, ready for use + +### 8.2 Schema-Based Validation + +**Algorithm**: Validates node construction + +``` +When creating ONNX node instance: + 1. Check required attributes provided: + missing = required_attrs - provided_attrs + if missing: raise ValueError(...) + + 2. Validate connector usage: + For each edge connected to node: + a. Determine parameter (input/output) + b. Check parameter type (Single/Optional/Variadic) + c. Validate connector naming: + - Single/Optional: exact name + - Variadic: name__index format + d. Verify edge data type matches constraints + + 3. Type constraint checking: + For each connector with type constraint: + a. Get connector data type + b. Look up constraint allowed types + c. Verify type in allowed set + d. If not: raise validation error +``` + +### 8.3 Runtime Symbol Inference + +**Algorithm**: Infers symbolic dimension values from inputs + +``` +When executing ONNXModel: + 1. Collect all symbols in SDFG: + symbols = sdfg.free_symbols + + 2. For each input tensor: + For each dimension in tensor.shape: + if dimension_name in symbols: + inferred_symbols[dimension_name] = dimension_value + + 3. Verify all required symbols inferred: + missing = symbols - inferred_symbols.keys() + if missing: raise ValueError(...) + + 4. Pass symbols to compiled SDFG: + result = compiled_sdfg(inputs..., **inferred_symbols) +``` + +### 8.4 Type Conversion Pipeline + +**Algorithm**: Converts between type systems + +``` +ONNX Type → DaCe Type: + 1. Extract ONNX type enum (e.g., TensorProto.FLOAT) + 2. Look up in cached mapping: + dace_type = onnx_to_dace_type_map[onnx_type] + 3. Return DaCe typeclass (e.g., dace.float32) + +DaCe Type → NumPy Type: + 1. Get DaCe typeclass + 2. Extract numpy_dtype property + 3. Return numpy dtype (e.g., np.float32) + +NumPy Type → PyTorch Type: + 1. Look up in numpy_to_torch_dtype_dict + 2. Return torch dtype (e.g., torch.float32) +``` + +--- + +## 9. Extension Points + +### 9.1 Adding New ONNX Operations + +If an ONNX operation is not yet supported, you can add it by creating an implementation: + +**Step 1: Create Implementation Class** + +```python +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.op_implementations.utils import op_implementation + +@op_implementation(op="CustomOp", name="pure") +class CustomOpImplementation(ONNXForward): + @staticmethod + def forward_can_be_applied(node, state, sdfg): + # Check if this implementation is applicable + return True + + @staticmethod + def forward(node, state, sdfg): + # Create nested SDFG for operation + # ... + return nested_sdfg +``` + +**Step 2: Register Implementation** + +The `@op_implementation` decorator automatically registers the implementation with the ONNXForward registry. + +**Step 3: Use in Models** + +The operation will now be available when importing ONNX models that use it. + +### 9.2 Custom Implementations for Existing Operations + +Override the default implementation with a custom one: + +```python +@op_implementation(op="Conv", name="my_optimized_conv") +class MyOptimizedConv(ONNXForward): + @staticmethod + def forward_can_be_applied(node, state, sdfg): + # Only apply for specific configurations + return (node.kernel_shape == [3, 3] and + node.stride == [1, 1]) + + @staticmethod + def forward(node, state, sdfg): + # Custom optimized implementation + # ... +``` + +**Selection**: Set `node.default_implementation = "my_optimized_conv"` or allow DaCe to select automatically based on applicability. diff --git a/dace/libraries/onnx/op_implementations/__init__.py b/dace/libraries/onnx/op_implementations/__init__.py new file mode 100644 index 0000000000..256b3a444a --- /dev/null +++ b/dace/libraries/onnx/op_implementations/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from .utils import * +from .common import * +from .elementwise_ops import * +from .reduction_ops import * +from .normalization_ops import * +from .array_ops import * +from .linalg_ops import * +from .image_ops import * +from .img_op_implementations import * +from .criteria_implementations import * diff --git a/dace/libraries/onnx/op_implementations/array_ops.py b/dace/libraries/onnx/op_implementations/array_ops.py new file mode 100644 index 0000000000..c918d9b6b3 --- /dev/null +++ b/dace/libraries/onnx/op_implementations/array_ops.py @@ -0,0 +1,681 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Array and Tensor Manipulation Operations for ONNX in DaCe. + +This module provides pure DaCe implementations for ONNX array/tensor manipulation +operations. These operations handle shape manipulation, slicing, and other array transformations. + +The module contains: +- Shape manipulation operations (Reshape, Flatten, Squeeze, Unsqueeze, Expand) +- Slicing and indexing operations (Slice, SliceAllConstant, Gather) +- Concatenation and splitting operations (Concat, Split) +- Transposition operations (Transpose, EinsumTranspose) +- Shape query operations (Shape) + +Each implementation follows the ONNX specification and is designed to be: +- Semantically correct according to ONNX standards +- Efficient when converted to DaCe SDFGs +""" + +import copy +from math import prod +import typing + +import dace +import numpy as np +from dace import SDFG, SDFGState, subsets +from dace.sdfg.nodes import Node +from dace.sdfg.utils import in_desc_with_name, in_edge_with_name, out_desc_with_name +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.op_implementations.common import iterables_equal +from dace.libraries.onnx.nodes import onnx_op +from dace.libraries.onnx.op_implementations.utils import (empty_sdfg_for_node, op_implementation, program_for_node, + python_pure_op_implementation) +from dace.transformation.onnx import constant_folding +from dace.transformation.onnx.replacement import onnx_constant_or_none +from dace.libraries.onnx import converters + +# ============================================================================== +# Concatenation Operations +# ============================================================================== + + +@op_implementation(op="Concat", name="pure") +class PureConcat(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + axis = node.axis + + num_inputs = len(state.in_edges(node)) + + def inp_name(i): + return f"inputs__{i}" + + out_name = "concat_result" + + inp_data = [in_desc_with_name(node, state, sdfg, inp_name(i)) for i in range(num_inputs)] + out_data = out_desc_with_name(node, state, sdfg, out_name) + + nsdfg = dace.SDFG(node.label) + + inp_data_descs = [copy.deepcopy(desc) for desc in inp_data] + for i, desc in enumerate(inp_data_descs): + desc.transient = False + nsdfg.add_datadesc(inp_name(i), desc) + out_data_desc = copy.deepcopy(out_data) + out_data_desc.transient = False + nsdfg.add_datadesc(out_name, out_data_desc) + + inp_shapes = [d.shape for d in inp_data] + out_shape = out_data_desc.shape + + nstate = nsdfg.add_state() + out_write = nstate.add_write(out_name) + + for inp_idx in range(num_inputs): + inp_read = nstate.add_read(inp_name(inp_idx)) + + tasklet = nstate.add_tasklet( + f'concat_{inp_idx}', + {'inp': inp_data_descs[inp_idx].dtype}, + {'out': out_data_desc.dtype}, + "out = inp", + ) + + map_entry, map_exit = nstate.add_map(f"concat_map_{inp_idx}", { + f"i{i}": f"0:{s}" + for i, s in enumerate(inp_shapes[inp_idx]) + }) + + inp_access = [f'i{i}' for i, _ in enumerate(inp_shapes[inp_idx])] + inp_access_str = ", ".join(inp_access) + inp_memlet = dace.Memlet(f"{inp_name(inp_idx)}[{inp_access_str}]") + + stack_idx_offset = "" + for i in range(inp_idx): + stack_idx_offset += f" + ({inp_shapes[i][axis]})" + + out_access = [f'i{i}' for i, _ in enumerate(out_shape)] + if stack_idx_offset: + out_access[axis] += stack_idx_offset + out_access_str = ", ".join(out_access) + out_memlet = dace.Memlet(f"{out_name}[{out_access_str}]") + + nstate.add_memlet_path(inp_read, map_entry, tasklet, memlet=inp_memlet, dst_conn="inp") + nstate.add_memlet_path(tasklet, map_exit, out_write, memlet=out_memlet, src_conn="out") + + return nsdfg + + +# ============================================================================== +# Shape Manipulation Operations - Unsqueeze +# ============================================================================== + + +@op_implementation(op="Unsqueeze", name="pure") +class PureUnsqueeze(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + # Get input/output descriptors + expanded_desc = copy.deepcopy(out_desc_with_name(node, state, sdfg, "expanded")) + + def prog(data, expanded): + expanded[:] = np.reshape(data, expanded_desc.shape) + + return program_for_node(prog, sdfg, state, node) + + +# ============================================================================== +# Shape Manipulation Operations - Squeeze +# ============================================================================== + + +@op_implementation(op="Squeeze", name="pure") +class PureSqueeze(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + squeezed_desc = copy.deepcopy(out_desc_with_name(node, state, sdfg, "squeezed")) + + def prog(data, squeezed): + squeezed[:] = np.reshape(data, squeezed_desc.shape) + + return program_for_node(prog, sdfg, state, node) + + +# ============================================================================== +# Shape Manipulation Operations - Expand +# ============================================================================== + + +@op_implementation(op="Expand", name="pure") +class PureExpand(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + shape = out_desc_with_name(node, state, sdfg, "output").shape + + def prog(input, output): + output = np.broadcast_to(input, shape) + + return program_for_node(prog, sdfg, state, node) + + +@op_implementation(op="Expand", name="pure") +class PureExpand(ONNXForward): + """ Handle no-op case for Expand """ + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + return iterables_equal( + in_desc_with_name(node, state, sdfg, "input").shape, + out_desc_with_name(node, state, sdfg, "output").shape) + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + constant_folding.remove_node_and_computation(sdfg, state, node, "shape") + + def prog(input, output): + output[:] = input + + return program_for_node(prog, sdfg, state, node) + + +# ============================================================================== +# Transposition Operations +# ============================================================================== + + +@python_pure_op_implementation( + perm=lambda node, data: node.perm if node.perm is not None else list(reversed(range(len(data.shape))))) +def Transpose(data, transposed): + transposed[:] = np.transpose(data, axes=perm) + + +@op_implementation(op="Transpose", name="einsum") +class EinsumTranspose(ONNXForward): + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + from dace.libraries.onnx.nodes.onnx_op_registry import ONNXEinsum # avoid import loop + perm = node.perm + input_desc = in_desc_with_name(node, state, sdfg, "data") + output_desc = out_desc_with_name(node, state, sdfg, "transposed") + + letters = [chr(ord('z') - i) for i in range(26)] + input_letters = "".join(letters[i] for i, _ in enumerate(input_desc.shape)) + output_letters = "".join(letters[i] for i in perm) + equation_str = f"{input_letters}->{output_letters}" + + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + einsum_node: onnx_op.ONNXOp = ONNXEinsum(node.label + "_einsum_expansion", equation=equation_str) + + nstate.add_node(einsum_node) + einsum_node.add_in_connector("Inputs__0") + nsdfg.add_datadesc("data", copy.deepcopy(input_desc)) + nsdfg.add_datadesc("transposed", copy.deepcopy(output_desc)) + nsdfg.arrays["data"].transient = False + nsdfg.arrays["transposed"].transient = False + + nstate.add_edge(nstate.add_read("data"), None, einsum_node, "Inputs__0", nsdfg.make_array_memlet("data")) + nstate.add_edge(einsum_node, "Output", nstate.add_write("transposed"), None, + nsdfg.make_array_memlet("transposed")) + + return nsdfg + + +# ============================================================================== +# Reshape and Flatten Operations +# ============================================================================== + + +@python_pure_op_implementation(shape=lambda reshaped: reshaped.shape, + allowzero=lambda node: getattr(node, 'allowzero', 0)) +def Reshape(data, reshaped): + # If allowzero is 0 (default), we use numpy's reshape which doesn't allow zeros + # If allowzero is 1, we need to handle zeros in the shape tensor + if allowzero == 0: + reshaped[:] = np.reshape(data, shape) + else: + # For allowzero=1, we need to handle zeros in the shape tensor + # This means we need to preserve the original dimension size when a zero is encountered + new_shape = list(shape) + for i, dim in enumerate(new_shape): + if dim == 0: + new_shape[i] = data.shape[i] + reshaped[:] = np.reshape(data, new_shape) + + +@python_pure_op_implementation(shape=lambda input, node: [prod(input.shape[:node.axis]), prod(input.shape[node.axis:])]) +def Flatten(input, output): + output[:] = input.reshape(shape) + + +# ============================================================================== +# Slicing Operations +# ============================================================================== + + +@op_implementation(op="Slice", name="pure") +class PureSlice(ONNXForward): + ''' + Slice expansion + ''' + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + # Check that all the inputs (even the optional ones) are present and constant + + if not hasattr(sdfg, "_parent_onnx_model"): + return False + + constant_starts = in_edge_with_name(node, state, "starts").src.data in sdfg._parent_onnx_model.clean_weights + + if not constant_starts: + return False + if in_edge_with_name(node, state, "ends").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + + # optional inputs + is_axes_present = True + try: + if in_edge_with_name(node, state, "axes").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + except ValueError: + is_axes_present = False + + is_steps_present = True + try: + if in_edge_with_name(node, state, "steps").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + except ValueError: + is_steps_present = False + + # Current constraints: axes and steps must be explict. Axes must be zero and steps must be 1 + if not is_axes_present or not is_steps_present: + return False + + step = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "steps").src.data].numpy()[0] + axis = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "axes").src.data].numpy()[0] + + if step != 1 or axis != 0: + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + start = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "starts").src.data].numpy()[0] + end = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "ends").src.data].numpy()[0] + + output_shape = out_desc_with_name(node, state, sdfg, "output").shape + if end == np.iinfo(np.int64).max: + # Pytorch exporter artifact + end = start + output_shape[0] + + def prog(data, output): + tmp = data[start:end:1, :] + # We need reshape to avoid Invalid Edge errors + output[:] = np.reshape(tmp, output.shape) + + return program_for_node(prog, sdfg, state, node) + + +@op_implementation(op="Slice", name="pure") +class PureSliceAllConstant(ONNXForward): + + @staticmethod + def _get_constant(conn: str, node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG): + try: + srcnode = next(state.in_edges_by_connector(node, conn)).src + except StopIteration: + # Return default values + if conn == "steps": + return 1 + return None + # Scalar copied to GPU + if 'gpu_' in srcnode.data: + srcnode = state.predecessors(srcnode)[0] + return onnx_constant_or_none(sdfg, srcnode) + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + for inconn in ("axes", "ends", "starts", "steps"): + if PureSliceAllConstant._get_constant(inconn, node, state, sdfg) is None: + return False + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + axes = PureSliceAllConstant._get_constant('axes', node, state, sdfg) + ends = PureSliceAllConstant._get_constant('ends', node, state, sdfg) + starts = PureSliceAllConstant._get_constant('starts', node, state, sdfg) + steps = PureSliceAllConstant._get_constant('steps', node, state, sdfg) + + constant_folding.remove_node_and_computation(sdfg, state, node, "axes") + constant_folding.remove_node_and_computation(sdfg, state, node, "ends") + constant_folding.remove_node_and_computation(sdfg, state, node, "starts") + constant_folding.remove_node_and_computation(sdfg, state, node, "steps") + + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + idesc = in_desc_with_name(node, state, sdfg, "data") + odesc = out_desc_with_name(node, state, sdfg, "output") + nsdfg.add_datadesc("data", copy.deepcopy(idesc)) + nsdfg.add_datadesc("output", copy.deepcopy(odesc)) + nsdfg.arrays["data"].transient = False + nsdfg.arrays["output"].transient = False + + if not isinstance(axes, (tuple, list)): + axes = [axes] + ends = [ends] + starts = [starts] + steps = [steps] + + # Set up slicing memlet + rng = [(0, s - 1, 1) for s in idesc.shape] + for axis, start, end, step in zip(axes, starts, ends, steps): + s = idesc.shape[axis] + if end > s: + end = s + rng[axis] = (start, end - 1, step) + + sbs = subsets.Range(rng) + osbs = subsets.Range.from_array(odesc) + + # Make copy / view + rnode = nstate.add_read("data") + wnode = nstate.add_write("output") + + nstate.add_nedge(rnode, wnode, dace.Memlet(data="data", subset=sbs, other_subset=osbs)) + + return nsdfg + + +# ============================================================================== +# Split Operations +# ============================================================================== + + +@op_implementation(op="Split", name="pure") +class SplitPure(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + from dace.transformation.onnx.replacement import onnx_constant_or_none + + # Check if we have either split input or num_outputs attribute + has_split_input = len(list(state.in_edges_by_connector(node, "split"))) > 0 + has_num_outputs = hasattr(node, 'num_outputs') + + if not (has_split_input or has_num_outputs): + return False + + # If split input is provided, it must be a constant + if has_split_input: + split_node = next(state.in_edges_by_connector(node, "split")).src + if not onnx_constant_or_none(sdfg, split_node): + return False + + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + from dace.transformation.onnx.replacement import onnx_constant_or_none + + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + split_dim = node.axis + idesc = in_desc_with_name(node, state, sdfg, "input") + nsdfg.add_datadesc("input", copy.deepcopy(idesc)) + nsdfg.arrays["input"].transient = False + + rnode = nstate.add_read("input") + + # Get split sizes either from input or compute from num_outputs + if len(list(state.in_edges_by_connector(node, "split"))) > 0: + # Get split sizes from input tensor + split_node = next(state.in_edges_by_connector(node, "split")).src + split_sizes = onnx_constant_or_none(sdfg, split_node) + if split_sizes is None: + raise ValueError("Split sizes must be constant") + + # Add split input as a data descriptor + split_desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, "split")) + split_desc.transient = False + nsdfg.add_datadesc("split", split_desc) + split_read = nstate.add_read("split") + else: + # Compute split sizes from num_outputs + num_outputs = node.num_outputs + total_size = idesc.shape[split_dim] + base_size = total_size // num_outputs + remainder = total_size % num_outputs + split_sizes = [base_size + (1 if i < remainder else 0) for i in range(num_outputs)] + + # Verify split sizes + if sum(split_sizes) != idesc.shape[split_dim]: + raise ValueError( + f"Sum of split sizes ({sum(split_sizes)}) must equal dimension size ({idesc.shape[split_dim]})") + + offset = 0 + for i, odim in enumerate(split_sizes): + # Set up new node shape and memlet + new_shape = list(idesc.shape) + new_shape[split_dim] = odim + rng = subsets.Range([(0, s - 1, 1) if j != split_dim else (offset, offset + odim - 1, 1) + for j, s in enumerate(new_shape)]) + offset += odim + + # Set up data descriptor + oname = f"outputs__{i}" + odesc = copy.deepcopy(out_desc_with_name(node, state, sdfg, oname)) + odesc.transient = False + nsdfg.add_datadesc(oname, odesc) + wnode = nstate.add_write(oname) + + # Perform copy (view) + nstate.add_nedge(rnode, wnode, + dace.Memlet(data="input", subset=rng, other_subset=subsets.Range.from_array(odesc))) + + return nsdfg + + +# ============================================================================== +# Shape Query Operations +# ============================================================================== + + +@op_implementation(op="Shape", name="pure") +class PureShape(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + data_desc = in_desc_with_name(node, state, sdfg, "data") + + try: + np.array(data_desc.shape, np.int64) + except Exception: + # this happens if the shape is symbolic, for example + return False + + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + data_desc = in_desc_with_name(node, state, sdfg, "data") + shape_val = np.array(data_desc.shape, np.int64) + + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + nsdfg.add_datadesc( + "data", + copy.deepcopy(data_desc), + ) + nsdfg.arrays["data"].transient = False + nsdfg.add_array("shape", shape_val.shape, dtype=dace.int64) + s = nstate.add_write("shape") + + for i, v in enumerate(shape_val): + tasklet = nstate.add_tasklet("write_shape", {}, {'shape_scalar': dace.int64}, f"shape_scalar = {v}") + nstate.add_edge(tasklet, "shape_scalar", s, None, dace.Memlet("shape[{}]".format(i))) + + return nsdfg + + +# ============================================================================== +# Gather Operations +# ============================================================================== + + +@op_implementation(op="Gather", name="pure") +class PureGather(ONNXForward): + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + # To understand this operator, read the docs for np.take. + # The ONNX docs are not easy to understand (and are incorrect in opset 11) + + nsdfg, nstate, _, _ = empty_sdfg_for_node(sdfg, state, node, add_access_nodes=False) + out_desc = out_desc_with_name(node, state, sdfg, "output") + out_shape = out_desc.shape + idx_desc = in_desc_with_name(node, state, sdfg, "indices") + idx_shape = idx_desc.shape + data_shape = in_desc_with_name(node, state, sdfg, "data").shape + + # FIXME: we can sometimes generate views + + # Generate a copy kernel that loops over every element in the output + # and read the correct element according to the indices + + axis = node.axis + + map_ranges = [(f"i{i}", f"0:{s}") for i, s in enumerate(out_shape)] + # the map ranges can be partitioned into two parts. + # the first part is the range over the indices, the second part is the + # range over the data + if isinstance(idx_desc, dace.data.Scalar): + # handle the edgecase here because the shape of a scalar in dace is + # (1,) not () + idx_len = 0 + else: + idx_len = len(idx_shape) + map_ranges_indices = map_ranges[axis:axis + idx_len] + map_ranges_data = map_ranges[:axis] + map_ranges[axis + idx_len:] + + # compute the indexing expressions + fst = lambda x: x[0] + output_idx_str = 'output[' + ', '.join(map(fst, map_ranges)) + ']' + # the memlet string used to read data, which reads the whole axis + data_memlet_elems = list(map(fst, map_ranges_data)) + data_memlet_elems.insert(axis, f'0:{data_shape[axis]}') + + data_memlet_str = 'data[' + ', '.join(data_memlet_elems) + ']' + + indices_idx_str = 'indices' + if map_ranges_indices: + indices_idx_str += '[' + ', '.join(map(fst, map_ranges_indices)) + ']' + else: + indices_idx_str += '[0]' + + tasklet, me, mx = nstate.add_mapped_tasklet(node.label + "_tasklet", + map_ranges=map_ranges, + inputs={ + "__data": dace.Memlet(data_memlet_str), + "idx": dace.Memlet(indices_idx_str), + }, + code=f"__output = __data[idx]", + outputs={"__output": dace.Memlet(output_idx_str)}, + external_edges=True) + + # required to make underlying code to see it as a pointer and enable index-based access + # even if the data contains just a single element + tasklet.in_connectors["__data"] = dace.pointer(out_desc.dtype) + + return nsdfg + + +# ============================================================================== +# Utility Operations +# ============================================================================== + + +@python_pure_op_implementation +def Where(condition, X, Y, output): + output[:] = np.where(condition, X, Y) + + +@python_pure_op_implementation +def Identity(input, output): + output[:] = input + + +@op_implementation(op="Cast", name="pure") +class PureCast(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + + if (in_desc_with_name(node, state, sdfg, "input").dtype == out_desc_with_name(node, state, sdfg, + "output").dtype): + return True + + target_type = node.to + try: + converters.onnx_tensor_type_to_typeclass(target_type) + except ValueError: + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + input_desc = in_desc_with_name(node, state, sdfg, "input") + output_desc = out_desc_with_name(node, state, sdfg, "output") + if (input_desc.dtype == output_desc.dtype): + + def prog(input, output): + output[:] = input + + return program_for_node(prog, sdfg, state, node) + else: + + nsdfg, nstate, _, _ = empty_sdfg_for_node(sdfg, state, node, add_access_nodes=False) + + shape = out_desc_with_name(node, state, sdfg, "output").shape + map_ranges = {f"i{i}": f"0:{s}" for i, s in enumerate(shape)} + index_str = f"{', '.join(map_ranges.keys())}" + tasklet, _, _ = nstate.add_mapped_tasklet(node.label + "_tasklet", + map_ranges=map_ranges, + inputs={f"__input": dace.Memlet(f"input[{index_str}]")}, + code=f"__output = __input", + outputs={"__output": dace.Memlet(f"output[{index_str}]")}, + external_edges=True) + + return nsdfg diff --git a/dace/libraries/onnx/op_implementations/common.py b/dace/libraries/onnx/op_implementations/common.py new file mode 100644 index 0000000000..fcda74fc27 --- /dev/null +++ b/dace/libraries/onnx/op_implementations/common.py @@ -0,0 +1,11 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Common utilities and helper functions for ONNX pure implementations. +""" + + +def iterables_equal(a, b) -> bool: + """ Return whether the two iterables ``a`` and ``b`` are equal. """ + if len(a) != len(b): + return False + return all(x == y for x, y in zip(a, b)) diff --git a/dace/libraries/onnx/op_implementations/criteria_implementations.py b/dace/libraries/onnx/op_implementations/criteria_implementations.py new file mode 100644 index 0000000000..42eccd6bca --- /dev/null +++ b/dace/libraries/onnx/op_implementations/criteria_implementations.py @@ -0,0 +1,90 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +from typing import Union + +import numpy as np + +import dace +from dace import SDFG, SDFGState, nodes as nd + +from dace.libraries.onnx.op_implementations.utils import op_implementation, program_for_node +from dace.libraries.onnx.nodes import onnx_op +from dace.libraries.onnx.forward_implementation_abc import ONNXForward + +from dace.sdfg.utils import in_desc_with_name + + +@op_implementation(op="SoftmaxCrossEntropyLoss", name="pure") +class PureSoftmaxCrossEntropyLoss(ONNXForward): + """Pure implementation of SoftmaxCrossEntropyLoss operation.""" + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + """Check if this implementation can be applied to the given node. + + :param node: The SoftmaxCrossEntropyLoss ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: True if the implementation can be applied, False otherwise. + """ + # Softmax is weird in opset 11, so let's stick to 2D for now + if len(in_desc_with_name(node, state, sdfg, "scores").shape) != 2: + return False + + if node.ignore_index is not None and node.ignore_index >= 0: + return False + + # The weights and log_prob arguments are optional + # We don't support them in this implementation + if 'weights' in node.in_connectors: + return False + if 'log_prob' in node.out_connectors: + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> Union[nd.Node, SDFG]: + """Generate the forward pass implementation for SoftmaxCrossEntropyLoss. + + :param node: The SoftmaxCrossEntropyLoss ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: A nested SDFG implementing the SoftmaxCrossEntropyLoss operation. + """ + + if node.reduction == 'mean': + + def reduction(x): + return np.mean(x) + elif node.reduction == 'none': + + def reduction(x): + return x + elif node.reduction == 'sum': + + def reduction(x): + return np.sum(x) + else: + raise ValueError("Unsupported reduction: {}".format(node.reduction)) + reduction = dace.program(reduction) + + # This implementation doesn't use ONNX LogSoftmax, and thus saves the + # final sum reduction by just grabbing the label scores directly, and + # skipping the computation of log softmax for all non-label scores + def prog(scores, labels, output): + # Extract the scores for the labels + + # Compute the log softmax normalization + maximum = np.maximum.reduce(scores, axis=1, keepdims=True) + max_sub = scores - maximum + exponent = np.exp(max_sub) + sum = np.add.reduce(exponent, axis=1) + log_sum = np.log(sum) + + # Compute the loss values + label_exponents = max_sub[:, labels] + losses = log_sum - label_exponents + output[:] = reduction(losses) + + return program_for_node(prog, sdfg, state, node) diff --git a/dace/libraries/onnx/op_implementations/elementwise_ops.py b/dace/libraries/onnx/op_implementations/elementwise_ops.py new file mode 100644 index 0000000000..bf66a868d5 --- /dev/null +++ b/dace/libraries/onnx/op_implementations/elementwise_ops.py @@ -0,0 +1,212 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Elementwise and mathematical ONNX operations. + +This module contains pure implementations of elementwise mathematical operations including: +- Basic arithmetic: Add, Sub, Mul, Div, Pow +- Unary math functions: Log, Exp, Sqrt, Sin, Cos, Tanh, Erf, Neg, Reciprocal +- Activation functions: Relu, LeakyRelu, Sigmoid, Softplus +- Utility operations: Clip + +All operations support broadcasting where applicable. +""" + +import typing + +import dace +import numpy as np +from dace import SDFG, SDFGState +from dace.sdfg.nodes import Node + +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.nodes import onnx_op +from dace.libraries.onnx.op_implementations.utils import (op_implementation, out_desc_with_name, program_for_node, + python_pure_op_implementation) +from dace.sdfg.utils import in_desc_with_name, in_edge_with_name, out_desc_with_name +from dace.transformation.onnx.replacement import onnx_constant_or_none + +# ============================================================================ +# Unary Mathematical Operations +# ============================================================================ + + +@python_pure_op_implementation +def Log(input, output): + """ONNX Log operation implementation. + + Computes the natural logarithm of the input tensor element-wise. + + :param input: Input tensor of any numeric type. + :param output: Output tensor with the same shape and type as input. + """ + output[:] = np.log(input) + + +@python_pure_op_implementation +def Exp(input, output): + """ONNX Exp operation implementation. + + Computes the exponential of the input tensor element-wise. + + :param input: Input tensor of any numeric type. + :param output: Output tensor with the same shape and type as input. + """ + output[:] = np.exp(input) + + +@python_pure_op_implementation +def Sqrt(X, Y): + """ONNX Sqrt operation implementation. + + Computes the square root of the input tensor element-wise. + + :param X: Input tensor of any numeric type. + :param Y: Output tensor with the same shape and type as X. + """ + Y[:] = dace.elementwise(lambda x: sqrt(x), X) + + +@python_pure_op_implementation +def Sin(input, output): + output[:] = np.sin(input) + + +@python_pure_op_implementation +def Cos(input, output): + output[:] = np.cos(input) + + +@python_pure_op_implementation +def Tanh(input, output): + output[:] = dace.elementwise(lambda x: tanh(x), input) + + +@python_pure_op_implementation +def Erf(input, output): + output[:] = dace.elementwise(lambda x: erf(x), input) + + +@python_pure_op_implementation +def Neg(X, Y): + Y[:] = -X + + +@python_pure_op_implementation(string=lambda X: "lambda x: dace.{}(1) / x".format(X.dtype.to_string())) +def Reciprocal(X, Y): + Y[:] = dace.elementwise(string, X) + + +@python_pure_op_implementation +def Softplus(X, Y): + Y[:] = np.log(1 + np.exp(X)) + + +@python_pure_op_implementation(dtype=lambda X: X.dtype) +def Sigmoid(X, Y): + Y[:] = dace.elementwise(lambda x: dtype(1) / (dtype(1) + exp(-x)), X) + + +# ============================================================================ +# Binary Arithmetic Operations +# ============================================================================ + + +@op_implementation(op="Pow", name="pure") +class PurePow(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + # Special case for constant exponents + y_value = None + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "Y").src.data in sdfg._parent_onnx_model.clean_weights: + y_value = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "Y").src.data].numpy() + except ValueError: + pass + + if y_value is not None and y_value.ndim == 0: + y_value = int(y_value) + + def prog(X, Z): + Z[:] = X**y_value + + return program_for_node(prog, sdfg, state, node) + + # General case + def prog(X, Y, Z): + Z[:] = X**Y + + return program_for_node(prog, sdfg, state, node) + + +@python_pure_op_implementation +def Add(A, B, C): + C[:] = A + B + + +@python_pure_op_implementation +def Sub(A, B, C): + C[:] = A - B + + +@python_pure_op_implementation +def Mul(A, B, C): + C[:] = A * B + + +@python_pure_op_implementation +def Div(A, B, C): + C[:] = A / B + + +# ============================================================================ +# Activation Functions and Clipping +# ============================================================================ + + +@python_pure_op_implementation(cast_lambda=lambda X: "lambda x: max(x, dace.{}(0))".format(X.dtype.to_string())) +def Relu(X, Y): + Y[:] = dace.elementwise(cast_lambda, X) + + +@python_pure_op_implementation( + cast_lambda=lambda node, X: "lambda x: (max(x, dace.{dtype}(0)) + {alpha} * min(x, dace.{dtype}(0)))".format( + dtype=X.dtype.to_string(), alpha=node.alpha)) +def LeakyRelu(X, Y): + Y[:] = dace.elementwise(cast_lambda, X) + + +@op_implementation(op="Clip", name="pure") +class PureClip(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + min_node = next(state.in_edges_by_connector(node, 'min')).src + max_node = next(state.in_edges_by_connector(node, 'max')).src + # TODO other cases + return (onnx_constant_or_none(sdfg, min_node) is not None and onnx_constant_or_none(sdfg, max_node) is not None) + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + min_node = next(state.in_edges_by_connector(node, 'min')).src + max_node = next(state.in_edges_by_connector(node, 'max')).src + minval = onnx_constant_or_none(sdfg, min_node) + maxval = onnx_constant_or_none(sdfg, max_node) + + input_dtype = in_desc_with_name(node, state, sdfg, "input").dtype + minstr = f"dace.{input_dtype.to_string()}({minval})" + maxstr = f"dace.{input_dtype.to_string()}({maxval})" + + lfunc = f"lambda x: min(max(x, {minstr}), {maxstr})" + + def prog(input, output): + output[:] = dace.elementwise(lfunc, input) + + return program_for_node(prog, sdfg, state, node) diff --git a/dace/libraries/onnx/op_implementations/image_ops.py b/dace/libraries/onnx/op_implementations/image_ops.py new file mode 100644 index 0000000000..55bad5d2f8 --- /dev/null +++ b/dace/libraries/onnx/op_implementations/image_ops.py @@ -0,0 +1,443 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Image and Signal Processing Operations for ONNX in DaCe. + +This module provides implementations for ONNX operations related to image and signal +processing, including resizing, interpolation, and related transformations. + +Operations implemented: +- Resize: Image resizing with various interpolation modes (nearest, linear, cubic) + and coordinate transformation modes +""" + +import copy +import typing + +import dace +from dace import SDFG, SDFGState +from dace.sdfg.nodes import Node +from dace.sdfg.utils import in_desc_with_name, out_desc_with_name + +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.op_implementations.utils import op_implementation + + +@op_implementation(op="Resize", name="pure") +class PureResize(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + # Check if we have either scales or sizes (but not both) + has_scales = len(list(state.in_edges_by_connector(node, 'scales'))) > 0 + has_sizes = len(list(state.in_edges_by_connector(node, 'sizes'))) > 0 + + if has_scales == has_sizes: + return False + + # Check interpolation mode + mode = getattr(node, 'mode', 'nearest') + if mode is not None and mode not in ['nearest', 'linear', 'cubic']: + return False + + # Check nearest mode if using nearest interpolation + if mode == 'nearest': + nearest_mode = getattr(node, 'nearest_mode', 'round_prefer_floor') + if nearest_mode is not None and nearest_mode not in [ + 'round_prefer_floor', 'round_prefer_ceil', 'floor', 'ceil' + ]: + return False + + # Check coordinate transformation mode + coord_mode = getattr(node, 'coordinate_transformation_mode', 'half_pixel') + if coord_mode is not None and coord_mode not in [ + 'half_pixel', 'half_pixel_symmetric', 'pytorch_half_pixel', 'align_corners', 'asymmetric', + 'tf_crop_and_resize' + ]: + return False + + # For tf_crop_and_resize, roi must be present + if coord_mode == 'tf_crop_and_resize': + has_roi = len(list(state.in_edges_by_connector(node, 'roi'))) > 0 + if not has_roi: + return False + + # Check keep_aspect_ratio_policy if using sizes + if has_sizes: + policy = getattr(node, 'keep_aspect_ratio_policy', 'stretch') + if policy is not None and policy not in ['stretch', 'not_larger', 'not_smaller']: + return False + + # Check antialias + antialias = getattr(node, 'antialias', 0) + if antialias is not None and antialias not in [0, 1]: + return False + + # Check exclude_outside + exclude_outside = getattr(node, 'exclude_outside', 0) + if exclude_outside is not None and exclude_outside not in [0, 1]: + return False + + # Check extrapolation_value + extrapolation_value = getattr(node, 'extrapolation_value', 0.0) + if extrapolation_value is not None and not isinstance(extrapolation_value, (int, float)): + return False + + # Check cubic coefficient + if mode == 'cubic': + cubic_coeff_a = getattr(node, 'cubic_coeff_a', -0.75) + if cubic_coeff_a is not None and not isinstance(cubic_coeff_a, (int, float)): + return False + + # Check axes if provided + axes = getattr(node, 'axes', None) + if axes is not None: + if not isinstance(axes, (list, tuple)): + return False + # Check for duplicate axes + if len(set(axes)) != len(axes): + return False + # Check for valid axis values + rank = len(in_desc_with_name(node, state, sdfg, 'X').shape) + for axis in axes: + if not isinstance(axis, int) or axis < -rank or axis >= rank: + return False + + # Check input shapes + x_desc = in_desc_with_name(node, state, sdfg, 'X') + rank = len(x_desc.shape) + if has_scales: + scales_desc = in_desc_with_name(node, state, sdfg, 'scales') + if len(scales_desc.shape) != 1: + return False + if len(axes) if axes is not None else rank != scales_desc.shape[0]: + return False + if has_sizes: + sizes_desc = in_desc_with_name(node, state, sdfg, 'sizes') + if len(sizes_desc.shape) != 1: + return False + if len(axes) if axes is not None else rank != sizes_desc.shape[0]: + return False + + # Check output shape + y_desc = out_desc_with_name(node, state, sdfg, 'Y') + if len(x_desc.shape) != len(y_desc.shape): + return False + + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + inp_name = 'X' + out_name = 'Y' + + nsdfg = dace.SDFG(node.label) + + # Add required input and output descriptors + inp_data_desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, inp_name)) + inp_data_desc.transient = False + nsdfg.add_datadesc(inp_name, inp_data_desc) + + out_data_desc = copy.deepcopy(out_desc_with_name(node, state, sdfg, out_name)) + out_data_desc.transient = False + nsdfg.add_datadesc(out_name, out_data_desc) + + # Check for optional parameters + has_scales = len(list(state.in_edges_by_connector(node, 'scales'))) > 0 + has_sizes = len(list(state.in_edges_by_connector(node, 'sizes'))) > 0 + has_roi = len(list(state.in_edges_by_connector(node, 'roi'))) > 0 + + # Get axes to resize + axes = node.axes or list(range(len(inp_data_desc.shape))) + + # Convert negative axes to positive + axes = [ax if ax >= 0 else len(inp_data_desc.shape) + ax for ax in axes] + + # Add optional parameter descriptors if they exist + if has_scales: + scales_name = 'scales' + scales_data_desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, scales_name)) + scales_data_desc.transient = False + nsdfg.add_datadesc(scales_name, scales_data_desc) + + if has_sizes: + sizes_name = 'sizes' + sizes_data_desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, sizes_name)) + sizes_data_desc.transient = False + nsdfg.add_datadesc(sizes_name, sizes_data_desc) + + if has_roi: + roi_name = 'roi' + roi_data_desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, roi_name)) + roi_data_desc.transient = False + nsdfg.add_datadesc(roi_name, roi_data_desc) + + num_dims = len(inp_data_desc.shape) + + # setup inner SDFG + nstate = nsdfg.add_state() + + inp_read = nstate.add_read(inp_name) + out_write = nstate.add_write(out_name) + + # Add reads for optional parameters + tasklet_inputs = {'__inp': dace.pointer(inp_data_desc.dtype)} + if has_scales: + scales_read = nstate.add_read(scales_name) + tasklet_inputs['__scales'] = dace.pointer(scales_data_desc.dtype) + if has_sizes: + sizes_read = nstate.add_read(sizes_name) + tasklet_inputs['__sizes'] = dace.pointer(sizes_data_desc.dtype) + if has_roi: + roi_read = nstate.add_read(roi_name) + tasklet_inputs['__roi'] = dace.pointer(roi_data_desc.dtype) + + # Generate tasklet code for interpolation + tasklet_code = [] + + # Get interpolation parameters + coord_mode = getattr(node, 'coordinate_transformation_mode', 'half_pixel') + mode = getattr(node, 'mode', 'nearest') + antialias = getattr(node, 'antialias', 0) + exclude_outside = getattr(node, 'exclude_outside', 0) + extrapolation_value = getattr(node, 'extrapolation_value', 0.0) + + # Add cubic interpolation helper functions if needed + if mode == 'cubic': + cubic_coeff_a = getattr(node, 'cubic_coeff_a', -0.75) + tasklet_code.append(f""" + // Cubic interpolation helper functions + float cubic_weight(float x) {{ + float a = {cubic_coeff_a}; + float absx = abs(x); + if (absx < 1.0) {{ + return (a + 2.0) * absx * absx * absx - (a + 3.0) * absx * absx + 1.0; + }} else if (absx < 2.0) {{ + return a * absx * absx * absx - 5.0 * a * absx * absx + 8.0 * a * absx - 4.0 * a; + }} + return 0.0; + }} + """) + + # Loop over output dimensions + tasklet_code.append(""" + // Loop over all output dimensions + """) + + # Create nested loops for each dimension + for i in range(len(out_data_desc.shape)): + tasklet_code.append(f"for (int i{i} = 0; i{i} < {out_data_desc.shape[i]}; i{i}++) {{") + + # Calculate input indices + tasklet_code.append(""" + // Calculate input indices for each dimension + int inp_indices[{}]; + """.format(num_dims)) + + # Declare all size variables at the beginning + for i in range(num_dims): + if i in axes: + tasklet_code.append(f"float inp_size_{i};") + tasklet_code.append(f"float out_size_{i};") + + for i in range(num_dims): + tasklet_code.append(f"// Dimension {i}") + if i in axes: + axis_idx = axes.index(i) + if has_scales: + tasklet_code.append(f""" + float scale_{i} = __scales[{axis_idx}]; + inp_size_{i} = {inp_data_desc.shape[i]}; + out_size_{i} = {out_data_desc.shape[i]}; + float x_resized_{i} = i{i}; + float x_original_{i}; + """) + + # Add coordinate transformation based on mode + if coord_mode == 'half_pixel': + tasklet_code.append(f""" + x_original_{i} = (x_resized_{i} + 0.5) / scale_{i} - 0.5; + """) + elif coord_mode == 'half_pixel_symmetric': + tasklet_code.append(f""" + float adjustment_{i} = out_size_{i} / (out_size_{i} - 1); + float center_{i} = inp_size_{i} / 2; + float offset_{i} = center_{i} * (1 - adjustment_{i}); + x_original_{i} = offset_{i} + (x_resized_{i} + 0.5) / scale_{i} - 0.5; + """) + elif coord_mode == 'pytorch_half_pixel': + tasklet_code.append(f""" + x_original_{i} = out_size_{i} > 1 ? (x_resized_{i} + 0.5) / scale_{i} - 0.5 : 0; + """) + elif coord_mode == 'align_corners': + tasklet_code.append(f""" + x_original_{i} = x_resized_{i} * (inp_size_{i} - 1) / (out_size_{i} - 1); + """) + elif coord_mode == 'asymmetric': + tasklet_code.append(f""" + x_original_{i} = x_resized_{i} / scale_{i}; + """) + elif coord_mode == 'tf_crop_and_resize': + tasklet_code.append(f""" + float roi_start_{i} = __roi[{axis_idx}]; + float roi_end_{i} = __roi[{len(axes) + axis_idx}]; + if (out_size_{i} > 1) {{ + x_original_{i} = roi_start_{i} * (inp_size_{i} - 1) + x_resized_{i} * (roi_end_{i} - roi_start_{i}) * (inp_size_{i} - 1) / (out_size_{i} - 1); + }} else {{ + x_original_{i} = 0.5 * (roi_start_{i} + roi_end_{i}) * (inp_size_{i} - 1); + }} + """) + + # Add interpolation mode handling + if mode == 'nearest': + nearest_mode = getattr(node, 'nearest_mode', 'round_prefer_floor') + if nearest_mode == 'floor': + tasklet_code.append(f"inp_indices[{i}] = int(floor(x_original_{i}));") + elif nearest_mode == 'ceil': + tasklet_code.append(f"inp_indices[{i}] = int(ceil(x_original_{i}));") + else: # round_prefer_floor or round_prefer_ceil + tasklet_code.append(f"inp_indices[{i}] = int(round(x_original_{i}));") + elif mode == 'linear': + tasklet_code.append(f""" + float x0_{i} = floor(x_original_{i}); + float x1_{i} = ceil(x_original_{i}); + float w0_{i} = x1_{i} - x_original_{i}; + float w1_{i} = x_original_{i} - x0_{i}; + inp_indices[{i}] = int(x0_{i}); + inp_indices[{i} + {num_dims}] = int(x1_{i}); // Store second index for linear interpolation + """) + elif mode == 'cubic': + tasklet_code.append(f""" + float x0_{i} = floor(x_original_{i}); + float x1_{i} = x0_{i} + 1; + float x2_{i} = x0_{i} + 2; + float x3_{i} = x0_{i} + 3; + float w0_{i} = cubic_weight(x_original_{i} - x0_{i}); + float w1_{i} = cubic_weight(x_original_{i} - x1_{i}); + float w2_{i} = cubic_weight(x_original_{i} - x2_{i}); + float w3_{i} = cubic_weight(x_original_{i} - x3_{i}); + inp_indices[{i}] = int(x0_{i}); + inp_indices[{i} + {num_dims}] = int(x1_{i}); // Store indices for cubic interpolation + inp_indices[{i} + {2*num_dims}] = int(x2_{i}); + inp_indices[{i} + {3*num_dims}] = int(x3_{i}); + """) + else: # has_sizes + tasklet_code.append(f""" + inp_size_{i} = {inp_data_desc.shape[i]}; + out_size_{i} = {out_data_desc.shape[i]}; + inp_indices[{i}] = int(floor(i{i} * inp_size_{i} / out_size_{i})); + """) + else: + tasklet_code.append(f"inp_indices[{i}] = i{i};") + + # Calculate input index + tasklet_code.append(""" + // Calculate input index + int inp_idx = 0; + """) + for i in range(num_dims): + tasklet_code.append(f"inp_idx += inp_indices[{i}] * {inp_data_desc.strides[i]};") + + # Calculate output index + tasklet_code.append(""" + // Calculate output index + int out_idx = 0; + """) + for i in range(num_dims): + tasklet_code.append(f"out_idx += i{i} * {out_data_desc.strides[i]};") + + # Perform interpolation based on mode + if mode == 'linear': + tasklet_code.append(f""" + // Linear interpolation + float x0 = __inp [inp_idx]; + float x1 = __inp [inp_idx + {inp_data_desc.strides[axes[0]]}]; // Second index for linear interpolation + float result = w0 * x0 + w1 * x1; + """) + elif mode == 'cubic': + tasklet_code.append(f""" + // Cubic interpolation + float x0 = __inp [inp_idx]; + float x1 = __inp [inp_idx + {inp_data_desc.strides[axes[0]]}]; + float x2 = __inp [inp_idx + {2*inp_data_desc.strides[axes[0]]}]; + float x3 = __inp [inp_idx + {3*inp_data_desc.strides[axes[0]]}]; + float result = w0 * x0 + w1 * x1 + w2 * x2 + w3 * x3; + """) + else: # nearest or default + tasklet_code.append(""" + // Nearest neighbor interpolation + float result = __inp [inp_idx]; + """) + + # Handle antialiasing if enabled + if antialias == 1 and mode in ['linear', 'cubic']: + tasklet_code.append(""" + // Apply antialiasing filter + float scale = __scales[0]; // Assuming first axis is being resized + if (scale < 1.0) { + float filter_scale = max(1.0, 1.0 / scale); + result *= filter_scale; + } + """) + + # Handle exclude_outside if enabled + if exclude_outside == 1: + tasklet_code.append(f""" + // Handle exclude_outside + bool is_outside = false; + for (int i = 0; i < {num_dims}; i++) {{ + if (inp_indices[i] < 0 || inp_indices[i] >= {inp_data_desc.shape[0]}) {{ + is_outside = true; + break; + }} + }} + if (is_outside) {{ + result = 0.0; + }} + """) + + # Handle extrapolation_value for tf_crop_and_resize + if coord_mode == 'tf_crop_and_resize': + tasklet_code.append(f""" + // Handle extrapolation for tf_crop_and_resize + bool is_outside = false; + for (int i = 0; i < {num_dims}; i++) {{ + if (inp_indices[i] < 0 || inp_indices[i] >= {inp_data_desc.shape[0]}) {{ + is_outside = true; + break; + }} + }} + if (is_outside) {{ + result = {extrapolation_value}; + }} + """) + + # Write the result to output + tasklet_code.append(""" + // Write output + __out [out_idx] = result; + """) + + # Close dimension loops + for i in range(len(out_data_desc.shape)): + tasklet_code.append("}") + + tasklet = nstate.add_tasklet(f'tasklet_reshape', + tasklet_inputs, {'__out': dace.pointer(out_data_desc.dtype)}, + "\n".join(tasklet_code), + language=dace.Language.CPP) + + # Connect tasklet inputs + nstate.add_edge(inp_read, None, tasklet, "__inp", dace.Memlet.from_array(inp_name, inp_data_desc)) + if has_scales: + nstate.add_edge(scales_read, None, tasklet, "__scales", + dace.Memlet.from_array(scales_name, scales_data_desc)) + if has_sizes: + nstate.add_edge(sizes_read, None, tasklet, "__sizes", dace.Memlet.from_array(sizes_name, sizes_data_desc)) + if has_roi: + nstate.add_edge(roi_read, None, tasklet, "__roi", dace.Memlet.from_array(roi_name, roi_data_desc)) + + # Connect tasklet output + nstate.add_edge(tasklet, "__out", out_write, None, dace.Memlet.from_array(out_name, out_data_desc)) + + return nsdfg diff --git a/dace/libraries/onnx/op_implementations/img_op_implementations.py b/dace/libraries/onnx/op_implementations/img_op_implementations.py new file mode 100644 index 0000000000..1c3727b07c --- /dev/null +++ b/dace/libraries/onnx/op_implementations/img_op_implementations.py @@ -0,0 +1,563 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import copy +import functools +import typing + +import numpy as np + +import dace +from dace import SDFGState, SDFG, dtypes +from dace.sdfg import nodes, propagation +from dace.transformation.dataflow import MapExpansion, MapCollapse +from dace.sdfg.nodes import Node +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.nodes.onnx_op import ONNXOp +from dace.libraries.onnx.op_implementations.utils import op_implementation, program_for_node +from dace.sdfg.utils import in_desc_with_name, out_desc_with_name, in_edge_with_name, out_edge_with_name + + +def _prod(sequence): + return functools.reduce(lambda a, b: a * b, sequence, 1) + + +@op_implementation(op="MaxPool", name="pure") +class PureMaxPool2D(ONNXForward): + """Pure implementation of 2D MaxPool operation.""" + + @staticmethod + def forward_can_be_applied(node: ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + """Check if this implementation can be applied to the given node. + + :param node: The MaxPool ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: True if the implementation can be applied, False otherwise. + """ + X = in_desc_with_name(node, state, sdfg, "X") + + if "Indices" in {e.src_conn for e in state.out_edges(node)}: + return False + + image_dims = len(X.shape) - 2 + + # Only do 2D for now + if image_dims != 2: + return False + + if node.pads is not None and (len(node.pads) != image_dims * 2): + return False + + if node.strides is not None and len(node.strides) != image_dims: + return False + + if node.auto_pad != 'NOTSET': + return False + + if node.ceil_mode != 0 or node.storage_order != 0: + return False + + if node.dilations is not None and (not all(d == 1 + for d in node.dilations) or len(node.dilations) != image_dims): + return False + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[nodes.Node, SDFG]: + """Generate the forward pass implementation for MaxPool2D. + + :param node: The MaxPool ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: A nested SDFG implementing the MaxPool operation. + """ + X = in_desc_with_name(node, state, sdfg, "X") + Y = out_desc_with_name(node, state, sdfg, "Y") + + image_dims = len(X.shape) - 2 + batch_size = X.shape[0] + num_channels = X.shape[1] + strides = node.strides if node.strides is not None else [1 for _ in range(image_dims)] + pads = node.pads if node.pads is not None else [0 for _ in range(image_dims) * 2] + stride_x, stride_y = strides + assert pads[0] == pads[2] and pads[1] == pads[3] + pad_x, pad_y, _, _ = pads + filter_hx, filter_hy = node.kernel_shape + input_size_x, input_size_y = X.shape[2:] + output_size_x, output_size_y = Y.shape[2:] + + # Create new SDFG + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + # Add data descriptors + nsdfg.add_datadesc("X", copy.deepcopy(X)) + nsdfg.add_datadesc("Y", copy.deepcopy(Y)) + nsdfg.arrays["X"].transient = False + nsdfg.arrays["Y"].transient = False + + # Add access nodes + X_read = nstate.add_read("X") + Y_write = nstate.add_write("Y") + + # Create tasklet that performs the max pooling operation + tasklet = nstate.add_tasklet(name=node.label + "_tasklet", + inputs={"__X": dace.pointer(X.dtype)}, + outputs={"__Y": dace.pointer(Y.dtype)}, + code=f""" + // Initialize output with minimum value + for (int b = 0; b < {batch_size}; b++) {{ + for (int c = 0; c < {num_channels}; c++) {{ + for (int out_x = 0; out_x < {output_size_x}; out_x++) {{ + for (int out_y = 0; out_y < {output_size_y}; out_y++) {{ + __Y[b * {Y.strides[0]} + c * {Y.strides[1]} + out_x * {Y.strides[2]} + out_y * {Y.strides[3]}] = {dtypes.min_value(Y.dtype)}; + }} + }} + }} + }} + + // Main max pooling computation + for (int b = 0; b < {batch_size}; b++) {{ + for (int c = 0; c < {num_channels}; c++) {{ + for (int out_x = 0; out_x < {output_size_x}; out_x++) {{ + for (int out_y = 0; out_y < {output_size_y}; out_y++) {{ + for (int hx = 0; hx < {filter_hx}; hx++) {{ + for (int hy = 0; hy < {filter_hy}; hy++) {{ + int sx = hx + out_x * {stride_x} - {pad_x}; + int sy = hy + out_y * {stride_y} - {pad_y}; + + if (0 <= sx && sx < {input_size_x} && 0 <= sy && sy < {input_size_y}) {{ + float input_val = __X[b * {X.strides[0]} + c * {X.strides[1]} + sx * {X.strides[2]} + sy * {X.strides[3]}]; + float& output_val = __Y[b * {Y.strides[0]} + c * {Y.strides[1]} + out_x * {Y.strides[2]} + out_y * {Y.strides[3]}]; + output_val = max(output_val, input_val); + }} + }} + }} + }} + }} + }} + }} + """, + language=dace.Language.CPP) + + # Connect the tasklet with memlets + nstate.add_edge(X_read, None, tasklet, "__X", dace.Memlet.from_array("X", X)) + nstate.add_edge(tasklet, "__Y", Y_write, None, dace.Memlet.from_array("Y", Y)) + + return nsdfg + + +@op_implementation(op="Conv", name="pure") +class PureConv2D(ONNXForward): + """Convolution implementation with support for grouped and depthwise convolutions.""" + + @staticmethod + def forward_can_be_applied(node: ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + """Check if this implementation can be applied to the given node. + + :param node: The Conv ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: True if the implementation can be applied, False otherwise. + """ + X = in_desc_with_name(node, state, sdfg, "X") + W = in_desc_with_name(node, state, sdfg, "W") + try: + B = in_desc_with_name(node, state, sdfg, "B") + except Exception as e: + B = None + + image_dims = len(X.shape) - 2 + num_filters = W.shape[0] + num_channels = X.shape[1] + + if (X.dtype not in [dace.float16, dace.float32, dace.float64] + or W.dtype not in [dace.float16, dace.float32, dace.float64]): + return False + + # Only do 2D for now + if len(X.shape) != 4 or len(W.shape) != 4: + return False + + # Check group convolution constraints + groups = node.group if node.group is not None else 1 + if groups < 1: + return False + + # For grouped convolution: + # - Input channels must be divisible by groups + # - Output channels (num_filters) must be divisible by groups + # - Weight shape[1] should be num_channels // groups + if num_channels % groups != 0: + return False + if num_filters % groups != 0: + return False + if W.shape[1] != num_channels // groups: + return False + + if node.dilations is not None and (not all(d == 1 + for d in node.dilations) or len(node.dilations) != image_dims): + return False + + if node.pads is not None and (len(node.pads) != image_dims * 2): + return False + + if node.strides is not None and (len(node.strides) != image_dims): + return False + + if B is not None and B.shape[0] != num_filters: + return False + + if node.auto_pad != 'NOTSET': + return False + + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[nodes.Node, SDFG]: + """Generate the forward pass implementation for Conv2D. + + :param node: The Conv ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: A nested SDFG implementing the Conv operation. + """ + X = in_desc_with_name(node, state, sdfg, "X") + W = in_desc_with_name(node, state, sdfg, "W") + Y = out_desc_with_name(node, state, sdfg, "Y") + + # Check if bias is present in input connectors + B = in_desc_with_name(node, state, sdfg, "B") if "B" in node.in_connectors else None + + if node.kernel_shape is not None: + filter_hx, filter_hy = node.kernel_shape + else: + filter_hx, filter_hy = W.shape[2:] + + num_filters = W.shape[0] + num_channels = X.shape[1] + batch_size = X.shape[0] + + # Get number of groups (default to 1 for standard convolution) + groups = node.group if node.group is not None else 1 + channels_per_group = num_channels // groups + filters_per_group = num_filters // groups + + input_size_x, input_size_y = X.shape[2:] + output_size_y, output_size_x = Y.shape[2:] + stride_y, stride_x = node.strides or [1, 1] + pad_x, pad_y, _, _ = node.pads or [0, 0, 0, 0] + + dtype = X.dtype + + # Create new SDFG + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + # Add data descriptors + nsdfg.add_datadesc("X", copy.deepcopy(X)) + nsdfg.add_datadesc("W", copy.deepcopy(W)) + nsdfg.add_datadesc("Y", copy.deepcopy(Y)) + if B is not None: + nsdfg.add_datadesc("B", copy.deepcopy(B)) + + # Set arrays as non-transient since they are inputs/outputs + nsdfg.arrays["X"].transient = False + nsdfg.arrays["W"].transient = False + nsdfg.arrays["Y"].transient = False + if B is not None: + nsdfg.arrays["B"].transient = False + + # Add access nodes + X_read = nstate.add_read("X") + W_read = nstate.add_read("W") + Y_write = nstate.add_write("Y") + if B is not None: + B_read = nstate.add_read("B") + + # Generate C++ code for the grouped convolution + code = f""" + // Initialize output + {f''' + // Initialize with bias + for (int b = 0; b < {batch_size}; b++) {{ + for (int m = 0; m < {num_filters}; m++) {{ + for (int out_x = 0; out_x < {output_size_x}; out_x++) {{ + for (int out_y = 0; out_y < {output_size_y}; out_y++) {{ + __Y[b * {Y.strides[0]} + m * {Y.strides[1]} + out_x * {Y.strides[2]} + out_y * {Y.strides[3]}] = __B[m]; + }} + }} + }} + }} + ''' if B is not None else f''' + // Zero-initialize output + for (int b = 0; b < {batch_size}; b++) {{ + for (int m = 0; m < {num_filters}; m++) {{ + for (int out_x = 0; out_x < {output_size_x}; out_x++) {{ + for (int out_y = 0; out_y < {output_size_y}; out_y++) {{ + __Y[b * {Y.strides[0]} + m * {Y.strides[1]} + out_x * {Y.strides[2]} + out_y * {Y.strides[3]}] = 0; + }} + }} + }} + }} + '''} + + // Main grouped convolution computation + for (int b = 0; b < {batch_size}; b++) {{ + for (int g = 0; g < {groups}; g++) {{ + // Each group processes a subset of input/output channels + int in_channel_start = g * {channels_per_group}; + int out_channel_start = g * {filters_per_group}; + + for (int m = 0; m < {filters_per_group}; m++) {{ + int out_channel = out_channel_start + m; + + for (int out_x = 0; out_x < {output_size_x}; out_x++) {{ + for (int out_y = 0; out_y < {output_size_y}; out_y++) {{ + // Only convolve with channels in the same group + for (int c = 0; c < {channels_per_group}; c++) {{ + int in_channel = in_channel_start + c; + + for (int hx = 0; hx < {filter_hx}; hx++) {{ + for (int hy = 0; hy < {filter_hy}; hy++) {{ + int sx = hx + out_x * {stride_x} - {pad_x}; + int sy = hy + out_y * {stride_y} - {pad_y}; + + if (0 <= sx && sx < {input_size_x} && 0 <= sy && sy < {input_size_y}) {{ + // Note: Weight tensor layout for grouped conv: + // [num_filters, channels_per_group, filter_hx, filter_hy] + float filter = __W[out_channel * {W.strides[0]} + c * {W.strides[1]} + hx * {W.strides[2]} + hy * {W.strides[3]}]; + float image = __X[b * {X.strides[0]} + in_channel * {X.strides[1]} + sx * {X.strides[2]} + sy * {X.strides[3]}]; + __Y[b * {Y.strides[0]} + out_channel * {Y.strides[1]} + out_x * {Y.strides[2]} + out_y * {Y.strides[3]}] += filter * image; + }} + }} + }} + }} + }} + }} + }} + }} + }} + """ + + # Create tasklet inputs and outputs + tasklet_inputs = { + "__X": dace.pointer(X.dtype), + "__W": dace.pointer(W.dtype), + } + tasklet_outputs = { + "__Y": dace.pointer(Y.dtype), + } + + if B is not None: + tasklet_inputs["__B"] = dace.pointer(B.dtype) + + # Create the tasklet + tasklet = nstate.add_tasklet(name=node.label + "_tasklet", + inputs=tasklet_inputs, + outputs=tasklet_outputs, + code=code, + language=dace.Language.CPP) + + # Connect the tasklet with memlets + nstate.add_edge(X_read, None, tasklet, "__X", dace.Memlet.from_array("X", X)) + nstate.add_edge(W_read, None, tasklet, "__W", dace.Memlet.from_array("W", W)) + if B is not None: + nstate.add_edge(B_read, None, tasklet, "__B", dace.Memlet.from_array("B", B)) + nstate.add_edge(tasklet, "__Y", Y_write, None, dace.Memlet.from_array("Y", Y)) + + return nsdfg + + +@op_implementation(op="BatchNormalization", name="pure") +class PureBatchNormalization(ONNXForward): + """Pure implementation of BatchNormalization operation.""" + + @staticmethod + def forward_can_be_applied(node: ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + """Check if this implementation can be applied to the given node. + + :param node: The BatchNormalization ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: True if the implementation can be applied, False otherwise. + """ + X = in_desc_with_name(node, state, sdfg, "X") + if len(X.shape) != 4: + return False + + if "in_mean" in node.in_connectors and "input_mean" not in node.in_connectors: + # Replace the old names with the new ones + node.add_in_connector("input_mean", node.in_connectors["in_mean"]) + node.remove_in_connector("in_mean") + + if "in_var" in node.in_connectors and "input_var" not in node.in_connectors: + # Replace the old names with the new ones + node.add_in_connector("input_var", node.in_connectors["in_var"]) + node.remove_in_connector("in_var") + + # Check for the new output names + if not {"scale", "B", "input_mean", "input_var"}.issubset(node.in_connectors): + return False + + return True + + @staticmethod + def forward(node: ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[nodes.Node, SDFG]: + """Generate the forward pass implementation for BatchNormalization. + + :param node: The BatchNormalization ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: A nested SDFG implementing the BatchNormalization operation. + """ + shape = copy.deepcopy(in_desc_with_name(node, state, sdfg, "X").shape) + reduce_axes = list(shape) + num_channels = reduce_axes.pop(1) + + N = _prod(reduce_axes) + broadcast_shape = [num_channels, 1, 1] + dtype = in_desc_with_name(node, state, sdfg, "X").dtype + eps = node.epsilon + momentum = node.momentum + inv_momentum = 1 - node.momentum + + axis = tuple(i for i in range(len(shape)) if i != 1) + + # Check if training_mode attribute exists + if not hasattr(node, "training_mode"): + # By default, set to False (inference mode) + node.training_mode = False + + if node.training_mode: + # TRAINING: compute batch statistics and update running statistics (EMA like PyTorch) + def prog(input_mean, scale, input_var, B, X, Y, running_mean, running_var): + # Batch mean, variance over axis=(0,2,3) for NCHW (your `axis`/`N` already set) + batch_mean = np.add.reduce(X, axis=axis) / N + + batch_mean_broadcastable = dace.define_local(broadcast_shape, dtype) + batch_mean_broadcastable[:] = batch_mean + X_minus_mean = X - batch_mean_broadcastable + + batch_var = np.add.reduce(X_minus_mean * X_minus_mean, axis=axis) / N + batch_var_eps = np.reshape(batch_var + eps, broadcast_shape) + + inv_std = dace.elementwise(lambda x: dtype(1.0) / sqrt(x), batch_var_eps) + normalized = X_minus_mean * inv_std + + scale_reshaped = np.reshape(scale, broadcast_shape) + bias_reshaped = np.reshape(B, broadcast_shape) + Y[:] = normalized * scale_reshaped + bias_reshaped + + # FIXED: PyTorch EMA + # running = (1 - momentum) * running + momentum * batch + running_mean[:] = input_mean * (1.0 - momentum) + batch_mean * momentum + running_var[:] = input_var * (1.0 - momentum) + batch_var * momentum + + new_sdfg = program_for_node(prog, sdfg, state, node) + + # Keep your "write-back" edges as-is + new_state = sdfg.add_state_after(sdfg.nodes()[0]) + rm_name = out_edge_with_name(node, state, "running_mean").data.data + new_state.add_edge(new_state.add_read(rm_name), None, + new_state.add_read(in_edge_with_name(node, state, "input_mean").data.data), None, + sdfg.make_array_memlet(rm_name)) + rv_name = out_edge_with_name(node, state, "running_var").data.data + new_state.add_edge(new_state.add_read(rv_name), None, + new_state.add_read(in_edge_with_name(node, state, "input_var").data.data), None, + sdfg.make_array_memlet(rv_name)) + else: + # EVAL: use provided running statistics; DO NOT recompute mean/var + def prog(input_mean, scale, input_var, B, X, Y): + mean_b = dace.define_local(broadcast_shape, dtype) + var_b = dace.define_local(broadcast_shape, dtype) + mean_b[:] = input_mean + var_b[:] = input_var + + X_minus_mean = X - mean_b + inv_std = dace.elementwise(lambda x: dtype(1.0) / sqrt(x + eps), var_b) + + normalized = X_minus_mean * inv_std + scale_b = np.reshape(scale, broadcast_shape) + bias_b = np.reshape(B, broadcast_shape) + Y[:] = normalized * scale_b + bias_b + + new_sdfg = program_for_node(prog, sdfg, state, node) + + return new_sdfg + + +@op_implementation(op="GlobalAveragePool", name="pure") +class PureGlobalAveragePool(ONNXForward): + """Pure implementation of GlobalAveragePool operation.""" + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + """Check if this implementation can be applied to the given node. + + :param node: The GlobalAveragePool ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: Always True for this implementation. + """ + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + """Generate the forward pass implementation for GlobalAveragePool. + + :param node: The GlobalAveragePool ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: A nested SDFG implementing the GlobalAveragePool operation. + """ + from dace.libraries.onnx.nodes.onnx_op_registry import ONNXReduceMean + + # Get input and output descriptors + X_desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, "X")) + Y_desc = copy.deepcopy(out_desc_with_name(node, state, sdfg, "Y")) + + # Create new SDFG + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + # Add data descriptors + nsdfg.add_datadesc("X", X_desc) + nsdfg.add_datadesc("Y", Y_desc) + nsdfg.arrays["X"].transient = False + nsdfg.arrays["Y"].transient = False + + # Add access nodes + X_read = nstate.add_read("X") + Y_write = nstate.add_write("Y") + + # Create axes array for reduction over spatial dimensions (2, 3) + axes_name = "axes" + rank = len(X_desc.shape) # e.g., (N, C, H, W) -> 4 + axes_values = list(range(2, rank)) + axes_arr_dtype = dace.int64 + axes_arr_shape = [len(axes_values)] + _, axes_desc = nsdfg.add_array(axes_name, axes_arr_shape, axes_arr_dtype, transient=True) + axes_node = nstate.add_access(axes_name) + + # Add a tasklet to initialize the axes array + axes_init_tasklet = nstate.add_tasklet("init_axes", + set(), {"out": dace.pointer(axes_arr_dtype)}, + "\n".join( + [f"out [{idx}] = {val};" for idx, val in enumerate(axes_values)]), + language=dace.Language.CPP) + nstate.add_edge(axes_init_tasklet, "out", axes_node, None, dace.Memlet(f"{axes_name}[0:{len(axes_values)}]")) + + # Create ONNXReduceMean node + reduce_mean_op = ONNXReduceMean("reduce_mean", keepdims=1) + reduce_mean_op.axes = axes_values + nstate.add_node(reduce_mean_op) + reduce_mean_op.add_in_connector("data") + reduce_mean_op.add_in_connector("axes") + reduce_mean_op.add_out_connector("reduced") + + # Connect the ReduceMean operation + nstate.add_edge(X_read, None, reduce_mean_op, "data", nsdfg.make_array_memlet("X")) + nstate.add_edge(axes_node, None, reduce_mean_op, "axes", nsdfg.make_array_memlet(axes_name)) + nstate.add_edge(reduce_mean_op, "reduced", Y_write, None, nsdfg.make_array_memlet("Y")) + + return nsdfg diff --git a/dace/libraries/onnx/op_implementations/linalg_ops.py b/dace/libraries/onnx/op_implementations/linalg_ops.py new file mode 100644 index 0000000000..ca5c2bd9ad --- /dev/null +++ b/dace/libraries/onnx/op_implementations/linalg_ops.py @@ -0,0 +1,359 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Linear algebra operations for ONNX. + +This module contains implementations of linear algebra operations including: +- MatMul: Matrix multiplication with broadcasting +- Gemm: General matrix multiplication (alpha*A*B + beta*C) +- Einsum: Einstein summation notation for tensor operations + +""" + +import copy +import itertools +import typing + +import dace +from dace import SDFG, SDFGState, nodes +from dace.sdfg.nodes import Node +from dace.sdfg.utils import in_desc_with_name, out_desc_with_name + +from dace import config +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.nodes import onnx_op +from dace.libraries.onnx.op_implementations.utils import in_desc_with_name, op_implementation, out_desc_with_name +from dace.frontend.common import create_einsum_sdfg + +# ============================================================================ +# Matrix Multiplication +# ============================================================================ + + +@op_implementation(op="MatMul", name="pure") +class PureMatMul(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + input0_dim = len(in_desc_with_name(node, state, sdfg, "A").shape) + input1_dim = len(in_desc_with_name(node, state, sdfg, "B").shape) + + if input0_dim == 1 or input1_dim == 1: + return False + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + from dace.libraries.onnx.nodes.onnx_op_registry import ONNXEinsum # avoid import loop + + A_desc = in_desc_with_name(node, state, sdfg, "A") + B_desc = in_desc_with_name(node, state, sdfg, "B") + Y_desc = out_desc_with_name(node, state, sdfg, "Y") + input0_dim = A_desc.shape + input1_dim = B_desc.shape + + # list containing letters from z-a + letters = [chr(ord('z') - i) for i in range(26)] + # i j k are used for the last dimensions + letters = [l for l in letters if l not in ['i', 'j', 'k']] + + if len(input0_dim) == 1: + if len(input1_dim) != 2: + raise ValueError("invalid dimensions") + arg1 = 'k' + arg2 = 'kj' + result = 'j' + elif len(input1_dim) == 1: + if len(input0_dim) != 2: + raise ValueError("invalid dimensions") + arg1 = 'ik' + arg2 = 'k' + result = 'i' + else: + # build the einsum. The last two dimensions are always just the matrix multiply einsum + # dace will later specialize to a batched matmul if possible + arg1 = 'ik' + arg2 = 'kj' + result = 'ij' + if input0_dim[-2] != input0_dim[-1]: + if dace.symbolic.issymbolic(input0_dim[-2]): + if config.Config.get_bool('debugprint'): + print( + f"Warning: overriding symbol {input0_dim[-2]} with value {input1_dim[-1]} in descriptor of input A of node {node}" + ) + new_shape = list(A_desc.shape) + new_shape[-1] = input1_dim[-2] + A_desc.shape = new_shape + elif dace.symbolic.issymbolic(input1_dim[-1]): + if config.Config.get_bool('debugprint'): + print( + f"Warning: overriding symbol {input0_dim[-1]} with value {input0_dim[-2]} in descriptor of input B of node {node}" + ) + new_shape = list(B_desc.shape) + new_shape[-2] = input0_dim[-1] + B_desc.shape = new_shape + input0_dim = input0_dim[:-2] + input1_dim = input1_dim[:-2] + for dim0, dim1 in itertools.zip_longest(reversed(input0_dim), reversed(input1_dim)): + if dim0 is None: + # only dim0 exists + letter = letters.pop() + arg2 = letter + arg2 + result = letter + result + elif dim1 is None: + # only dim1 exists + letter = letters.pop() + arg1 = letter + arg1 + result = letter + result + else: + # both exist + letter = letters.pop() + arg1 = letter + arg1 + arg2 = letter + arg2 + result = letter + result + + einsum_str = '{},{}->{}'.format(arg1, arg2, result) + + # we lower to an ONNXEinsum node instead straight to the dace einsum to + # make the autodiff simpler + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + einsum_node: onnx_op.ONNXOp = ONNXEinsum(node.label + "_einsum_expansion", equation=einsum_str) + + nstate.add_node(einsum_node) + einsum_node.add_in_connector("Inputs__0") + einsum_node.add_in_connector("Inputs__1") + nsdfg.add_datadesc("A", copy.deepcopy(A_desc)) + nsdfg.add_datadesc("B", copy.deepcopy(B_desc)) + nsdfg.add_datadesc("Y", copy.deepcopy(Y_desc)) + nsdfg.arrays["A"].transient = False + nsdfg.arrays["B"].transient = False + nsdfg.arrays["Y"].transient = False + + nstate.add_edge(nstate.add_read("A"), None, einsum_node, "Inputs__0", nsdfg.make_array_memlet("A")) + nstate.add_edge(nstate.add_read("B"), None, einsum_node, "Inputs__1", nsdfg.make_array_memlet("B")) + nstate.add_edge(einsum_node, "Output", nstate.add_write("Y"), None, nsdfg.make_array_memlet("Y")) + + return nsdfg + + +# ============================================================================ +# Einstein Summation +# ============================================================================ + + +@op_implementation(op="Einsum", name="pure") +class PureEinsum(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + if "..." in node.equation: + return False + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + for e in node.iter_inputs_in_onnx_order(state): + desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, e.dst_conn)) + desc.transient = False + nsdfg.add_datadesc(e.dst_conn, desc) + for e in node.iter_outputs_in_onnx_order(state): + desc = copy.deepcopy(out_desc_with_name(node, state, sdfg, e.src_conn)) + desc.transient = False + nsdfg.add_datadesc(e.src_conn, desc) + + # Check if there is a wcr sum to accumulate the result instead of initialization the output + # This is necessary for gradient accumulation to be consistent + output_edge = state.out_edges(node) + assert len(output_edge) == 1, "Einsum node should have exactly one output edge" + output_edge = output_edge[0] + beta = 1 if output_edge.data.wcr else 0 + create_einsum_sdfg(nsdfg, + nstate, + node.equation.replace(" ", ""), + *(e.dst_conn for e in node.iter_inputs_in_onnx_order(state)), + output="Output", + beta=beta) + return nsdfg + + +# ============================================================================ +# General Matrix Multiplication (Gemm) +# ============================================================================ + + +@op_implementation(op="Gemm", name="pure") +class PureGemm(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + from dace.libraries.onnx.nodes.onnx_op_registry import ONNXEinsum # avoid import loop + A_desc = in_desc_with_name(node, state, sdfg, "A") + B_desc = in_desc_with_name(node, state, sdfg, "B") + Y_desc = out_desc_with_name(node, state, sdfg, "Y") + input0_dim = A_desc.shape + input1_dim = B_desc.shape + + # list containing letters from z-a + letters = [chr(ord('z') - i) for i in range(26)] + # i j k are used for the last dimensions + letters = [l for l in letters if l not in ['i', 'j', 'k']] + + if len(input0_dim) == 1: + if len(input1_dim) != 2: + raise ValueError("invalid dimensions") + arg1 = 'k' + arg2 = 'kj' + result = 'j' + elif len(input1_dim) == 1: + if len(input0_dim) != 2: + raise ValueError("invalid dimensions") + arg1 = 'ik' + arg2 = 'k' + result = 'i' + else: + # build the einsum. The last two dimensions are always just the matrix multiply einsum + # dace will later specialize to a batched matmul if possible + arg1 = 'ik' + arg2 = 'kj' + result = 'ij' + if input0_dim[-2] != input0_dim[-1]: + if dace.symbolic.issymbolic(input0_dim[-2]): + if config.Config.get_bool('debugprint'): + print( + f"Warning: overriding symbol {input0_dim[-2]} with value {input1_dim[-1]} in descriptor of input A of node {node}" + ) + new_shape = list(A_desc.shape) + new_shape[-1] = input1_dim[-2] + A_desc.shape = new_shape + elif dace.symbolic.issymbolic(input1_dim[-1]): + if config.Config.get_bool('debugprint'): + print( + f"Warning: overriding symbol {input0_dim[-1]} with value {input0_dim[-2]} in descriptor of input B of node {node}" + ) + new_shape = list(B_desc.shape) + new_shape[-2] = input0_dim[-1] + B_desc.shape = new_shape + input0_dim = input0_dim[:-2] + input1_dim = input1_dim[:-2] + for dim0, dim1 in itertools.zip_longest(reversed(input0_dim), reversed(input1_dim)): + if dim0 is None: + # only dim0 exists + letter = letters.pop() + arg2 = letter + arg2 + result = letter + result + elif dim1 is None: + # only dim1 exists + letter = letters.pop() + arg1 = letter + arg1 + result = letter + result + else: + # both exist + letter = letters.pop() + arg1 = letter + arg1 + arg2 = letter + arg2 + result = letter + result + + if node.transA == 1: + arg1 = ''.join(reversed(arg1)) + if node.transB == 1: + arg2 = ''.join(reversed(arg2)) + + einsum_str = '{},{}->{}'.format(arg1, arg2, result) + + # we lower to an ONNXEinsum node instead straight to the dace einsum to + # make the autodiff simpler + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + # Einsum: "A", "B" -> mm_result + einsum_node: nodes.LibraryNode = ONNXEinsum(node.label + "_einsum_expansion", equation=einsum_str) + + nstate.add_node(einsum_node) + einsum_node.add_in_connector("Inputs__0") + einsum_node.add_in_connector("Inputs__1") + nsdfg.add_datadesc("A", copy.deepcopy(A_desc)) + nsdfg.add_datadesc("B", copy.deepcopy(B_desc)) + nsdfg.add_datadesc("Y", copy.deepcopy(Y_desc)) + nsdfg.arrays["A"].transient = False + nsdfg.arrays["B"].transient = False + nsdfg.arrays["Y"].transient = False + + # Decide on array names based on alpha and beta + uid = state.node_id(node) + mm_result = "Y" + if node.alpha != 1 or node.beta != 0: + mm_result = f"Ytmp_{uid}" + scal_result = mm_result + if node.alpha != 1: + scal_result = f"scaled_{uid}" + + # Create arrays according to alpha and beta + if node.alpha != 1 or node.beta != 0: + Ytmp_desc = out_desc_with_name(node, state, sdfg, "Y") + nsdfg.add_datadesc(f"Ytmp_{uid}", copy.deepcopy(Ytmp_desc)) + nsdfg.arrays[f"Ytmp_{uid}"].transient = True + if node.beta != 0: + beta_desc = out_desc_with_name(node, state, sdfg, "Y") + nsdfg.add_datadesc(f"scaled_{uid}", copy.deepcopy(beta_desc)) + nsdfg.arrays[f"scaled_{uid}"].transient = True + + nstate.add_edge(nstate.add_read("A"), None, einsum_node, "Inputs__0", nsdfg.make_array_memlet("A")) + nstate.add_edge(nstate.add_read("B"), None, einsum_node, "Inputs__1", nsdfg.make_array_memlet("B")) + mm_result_node = nstate.add_write(mm_result) + nstate.add_edge(einsum_node, "Output", mm_result_node, None, nsdfg.make_array_memlet(mm_result)) + + # Multiply by alpha: mm_result -> scal_result + if node.alpha != 1: + nstate.add_mapped_tasklet( + node.label + '_alphascale', + { + k: f'0:{Ytmp_desc.shape[i]}' + for i, k in enumerate(result) + }, + dict(a=dace.Memlet(data=mm_result, subset=','.join(result))), + f'o = a * dace.{Ytmp_desc.dtype}({node.alpha})', + dict(o=dace.Memlet(data=scal_result, subset=','.join(result))), + external_edges=True, + input_nodes=dict(a=mm_result_node), + ) + + # Multiply by beta: scal_result, "C" -> "Y" + if node.beta != 0: + C_desc = in_desc_with_name(node, state, sdfg, "C") + nsdfg.add_datadesc("C", copy.deepcopy(C_desc)) + nsdfg.arrays["C"].transient = False + scal_result_node = next(n for n in nstate.sink_nodes() + if isinstance(n, dace.nodes.AccessNode) and n.data == scal_result) + beta_scale_code = f'o = s + c * dace.{C_desc.dtype}({node.beta})' + if node.beta == 1: + beta_scale_code = f'o = s + c' + + # Support broadcasting in C -> Y + c_index = result[-len(C_desc.shape):] + for c_shp, y_shp in zip(reversed(C_desc.shape), reversed(Y_desc.shape)): + if c_shp != y_shp: + raise ValueError('Could not broadcast dimensions from C ' + 'to Y in ONNXGemm') + + nstate.add_mapped_tasklet( + node.label + '_betascale', + { + k: f'0:{Y_desc.shape[i]}' + for i, k in enumerate(result) + }, + dict(s=dace.Memlet(data=scal_result, subset=','.join(result)), + c=dace.Memlet(data="C", subset=','.join(c_index))), + beta_scale_code, + dict(o=dace.Memlet(data="Y", subset=','.join(result))), + external_edges=True, + input_nodes={scal_result: scal_result_node}, + ) + + return nsdfg diff --git a/dace/libraries/onnx/op_implementations/normalization_ops.py b/dace/libraries/onnx/op_implementations/normalization_ops.py new file mode 100644 index 0000000000..f434e3d5ac --- /dev/null +++ b/dace/libraries/onnx/op_implementations/normalization_ops.py @@ -0,0 +1,281 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Normalization operations for ONNX. + +This module contains implementations of normalization operations including: +- Softmax, LogSoftmax: Softmax normalization +- LayerNormalization: Layer normalization +- Dropout: Dropout regularization +""" + +import copy +import typing + +import dace +import numpy as np +from dace import SDFG, SDFGState, nodes +from dace.sdfg.utils import in_desc_with_name, out_desc_with_name + +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.op_implementations.utils import (in_desc_with_name, op_implementation, out_desc_with_name, + python_pure_op_implementation) + +# ============================================================================ +# Softmax Operations +# ============================================================================ + +softmax_compute = dict(axis=lambda node, input: tuple(range(len(input.shape)))[node.axis:]) + + +@python_pure_op_implementation(**softmax_compute) +def Softmax(input, output): + maximum = np.maximum.reduce(input, axis=axis, keepdims=True) + exp_values = np.exp(input - maximum) + sum_exp = np.add.reduce(exp_values, axis=axis, keepdims=True) + output[:] = exp_values / sum_exp + + +@python_pure_op_implementation(**softmax_compute) +def LogSoftmax(input, output): + maximum = np.maximum.reduce(input, axis=axis, keepdims=True) + max_sub = input - maximum + exponent = np.exp(max_sub) + sum = np.add.reduce(exponent, axis=axis, keepdims=True) + log_sum = np.log(sum) + output[:] = max_sub - log_sum + + +# ============================================================================ +# Layer Normalization +# ============================================================================ + + +def _layernorm_axis(node, X): + axis = node.axis if hasattr(node, 'axis') and node.axis >= 0 else len(X.shape) + node.axis + return tuple(range(axis, len(X.shape))) + + +def _layernorm_norm_size(node, X): + axis = node.axis if hasattr(node, 'axis') and node.axis >= 0 else len(X.shape) + node.axis + return int(np.prod([X.shape[i] for i in range(axis, len(X.shape))])) + + +def _layernorm_epsilon(node, X): + eps = getattr(node, 'epsilon', 1e-5) + return X.dtype.type(eps) + + +def _layernorm_one(X): + return X.dtype.type(1) + + +layernorm_compute = dict(axis=_layernorm_axis, + epsilon=_layernorm_epsilon, + norm_size=_layernorm_norm_size, + one=_layernorm_one) + + +@python_pure_op_implementation(**layernorm_compute) +def LayerNormalization(X, Scale, B, Y): + sum_x = np.add.reduce(X, axis=axis, keepdims=True) + mean = sum_x / norm_size + diff = X - mean + sum_sq = np.add.reduce(diff * diff, axis=axis, keepdims=True) + variance = sum_sq / norm_size + inv_std = one / np.sqrt(variance + epsilon) + normalized = diff * inv_std + Y[:] = normalized * Scale + B + + +# ============================================================================ +# Dropout +# ============================================================================ + + +@op_implementation(op="Dropout", name="pure") +class PureDropout(ONNXForward): + """ Dropout implementation with support for training and inference modes. + """ + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + # Get input descriptor + data = in_desc_with_name(node, state, sdfg, "data") + + # Check if optional inputs are present + has_ratio = "ratio" in node.in_connectors + has_training_mode = "training_mode" in node.in_connectors + + # Check data type + if data.dtype not in [dace.float16, dace.float32, dace.float64]: + return False + + # If ratio is provided as input, it should be a scalar + if has_ratio: + ratio = in_desc_with_name(node, state, sdfg, "ratio") + if ratio.total_size != 1: + return False + + # If training_mode is provided as input, it should be a scalar boolean + if has_training_mode: + training_mode = in_desc_with_name(node, state, sdfg, "training_mode") + if training_mode.total_size != 1: + return False + + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[nodes.Node, SDFG]: + # Get descriptors + data = in_desc_with_name(node, state, sdfg, "data") + output = out_desc_with_name(node, state, sdfg, "output") + + # Check for optional mask output + has_mask_output = "mask" in node.out_connectors + mask = out_desc_with_name(node, state, sdfg, "mask") if has_mask_output else None + + # Check for optional inputs + has_ratio_input = "ratio" in node.in_connectors + has_training_mode_input = "training_mode" in node.in_connectors + + ratio_desc = in_desc_with_name(node, state, sdfg, "ratio") if has_ratio_input else None + training_mode_desc = in_desc_with_name(node, state, sdfg, "training_mode") if has_training_mode_input else None + + # Get dropout ratio (from attribute or will be provided as input) + # ONNX spec: default ratio is 0.5 if not specified + dropout_ratio = getattr(node, 'ratio', 0.5) if not has_ratio_input else None + + # Get seed if specified (for reproducible dropout) + seed = getattr(node, 'seed', None) + + # Calculate total elements + total_elements = data.total_size + + # Get data type + dtype = data.dtype + dtype_str = str(dtype).replace("dace.", "") + + # Create new SDFG + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + # Add data descriptors + nsdfg.add_datadesc("data", copy.deepcopy(data)) + nsdfg.add_datadesc("output", copy.deepcopy(output)) + + if has_mask_output: + nsdfg.add_datadesc("mask", copy.deepcopy(mask)) + + if has_ratio_input: + nsdfg.add_datadesc("ratio", copy.deepcopy(ratio_desc)) + + if has_training_mode_input: + nsdfg.add_datadesc("training_mode", copy.deepcopy(training_mode_desc)) + + # Set arrays as non-transient + nsdfg.arrays["data"].transient = False + nsdfg.arrays["output"].transient = False + if has_mask_output: + nsdfg.arrays["mask"].transient = False + if has_ratio_input: + nsdfg.arrays["ratio"].transient = False + if has_training_mode_input: + nsdfg.arrays["training_mode"].transient = False + + # Add access nodes + data_read = nstate.add_read("data") + output_write = nstate.add_write("output") + mask_write = nstate.add_write("mask") if has_mask_output else None + ratio_read = nstate.add_read("ratio") if has_ratio_input else None + training_mode_read = nstate.add_read("training_mode") if has_training_mode_input else None + + # Generate C++ code for dropout + # Note: This implementation uses a simple linear congruential generator for portability + # In production, you might want to use a better random number generator + + code = f""" + #include + #include + + // Get dropout ratio + {dtype_str} ratio = {dropout_ratio if not has_ratio_input else '__ratio'}; + + // Get training mode (default to false if not specified) + bool training_mode = {('__training_mode' if has_training_mode_input else 'false')}; + + // If in inference mode, just copy input to output + if (!training_mode) {{ + for (int i = 0; i < {total_elements}; i++) {{ + __output[i] = __data[i]; + {"__mask[i] = true;" if has_mask_output else ""} + }} + }} else {{ + // Training mode: apply dropout + + // Initialize random seed + static uint64_t rng_state = {seed if seed is not None else 'uint64_t(std::time(nullptr))'}; + + // Scale factor for remaining values (1 / (1 - ratio)) + {dtype_str} scale = ({dtype_str})(1.0 / (1.0 - ratio)); + + // Apply dropout + for (int i = 0; i < {total_elements}; i++) {{ + // Simple LCG for random number generation + // This generates a random number in [0, 1) + rng_state = (rng_state * 1664525ULL + 1013904223ULL); + double random_val = double(rng_state) / double(UINT64_MAX); + + // Dropout: keep if random value is greater than ratio + bool keep = (random_val >= ratio); + + if (keep) {{ + // Scale the kept values + __output[i] = __data[i] * scale; + {"__mask[i] = true;" if has_mask_output else ""} + }} else {{ + // Drop the value + __output[i] = 0; + {"__mask[i] = false;" if has_mask_output else ""} + }} + }} + }} + """ + + # Create tasklet inputs and outputs + tasklet_inputs = { + "__data": dace.pointer(data.dtype), + } + tasklet_outputs = { + "__output": dace.pointer(output.dtype), + } + + if has_ratio_input: + tasklet_inputs["__ratio"] = ratio_desc.dtype + if has_training_mode_input: + tasklet_inputs["__training_mode"] = training_mode_desc.dtype + if has_mask_output: + tasklet_outputs["__mask"] = dace.pointer(mask.dtype) + + # Create the tasklet + tasklet = nstate.add_tasklet(name=node.label + "_tasklet", + inputs=tasklet_inputs, + outputs=tasklet_outputs, + code=code, + language=dace.Language.CPP) + + # Connect the tasklet with memlets + nstate.add_edge(data_read, None, tasklet, "__data", dace.Memlet.from_array("data", data)) + + if has_ratio_input: + nstate.add_edge(ratio_read, None, tasklet, "__ratio", dace.Memlet.from_array("ratio", ratio_desc)) + + if has_training_mode_input: + nstate.add_edge(training_mode_read, None, tasklet, "__training_mode", + dace.Memlet.from_array("training_mode", training_mode_desc)) + + nstate.add_edge(tasklet, "__output", output_write, None, dace.Memlet.from_array("output", output)) + + if has_mask_output: + nstate.add_edge(tasklet, "__mask", mask_write, None, dace.Memlet.from_array("mask", mask)) + + return nsdfg diff --git a/dace/libraries/onnx/op_implementations/reduction_ops.py b/dace/libraries/onnx/op_implementations/reduction_ops.py new file mode 100644 index 0000000000..cd9a361b8e --- /dev/null +++ b/dace/libraries/onnx/op_implementations/reduction_ops.py @@ -0,0 +1,304 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Reduction operations for ONNX. + +This module contains implementations of reduction operations including: +- ReduceSum, ReduceMean: Standard reductions over specified axes +- ReduceMax, ReduceMin: Min/max reductions +- CumSum: Cumulative sum along an axis +- Sum: Element-wise sum of multiple inputs + +""" + +import copy +import typing + +import dace +import numpy as np +from dace import SDFG, SDFGState +from dace.sdfg.nodes import Node +from dace.sdfg.utils import in_desc_with_name, in_edge_with_name, out_desc_with_name +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.nodes import onnx_op +from dace.libraries.onnx.op_implementations.common import iterables_equal +from dace.libraries.onnx.op_implementations.utils import (empty_sdfg_for_node, in_desc_with_name, op_implementation, + out_desc_with_name, program_for_node) + +# ============================================================================ +# Cumulative Sum +# ============================================================================ + + +@op_implementation(op="CumSum", name="pure") +class PureCumSum(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + if node.exclusive or node.reverse: + return False + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axis").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + except ValueError: + return False + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + axis = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "axis").src.data].numpy().item() + + def prog(x, y): + y[:] = np.cumsum(x, axis=axis) + + return program_for_node(prog, sdfg, state, node) + + +# ============================================================================ +# ReduceMean Operations +# ============================================================================ + + +@op_implementation(op="ReduceMean", name="pure") +class PureReduceMean(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + is_axes_present = True + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + except ValueError: + is_axes_present = False + + if not is_axes_present and hasattr(node, "axes"): + is_axes_present = True + + if not is_axes_present: + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + axes = None + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data in sdfg._parent_onnx_model.clean_weights: + axes = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "axes").src.data].numpy() + except ValueError: + pass + if axes is not None: + if len(axes) == 1: + axes = axes[0] + else: + axes = tuple(axes) + else: + axes = node.axes if hasattr(node, "axes") else None + + def prog(data, reduced): + reduced[:] = np.mean(data, axis=axes) + + return program_for_node(prog, sdfg, state, node) + + +# ============================================================================ +# ReduceSum Operations +# ============================================================================ + + +@op_implementation(op="ReduceSum", name="pure") +class PureReduceSum(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + is_axes_present = True + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + except ValueError: + is_axes_present = False + + if not is_axes_present and hasattr(node, "axes"): + is_axes_present = True + + if not is_axes_present: + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + axes = None + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data in sdfg._parent_onnx_model.clean_weights: + axes = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "axes").src.data].numpy() + except ValueError: + pass + if axes is not None: + if len(axes) == 1: + axes = axes[0] + else: + axes = tuple(axes) + else: + axes = node.axes if hasattr(node, "axes") else None + + def prog(data, reduced): + reduced[:] = np.sum(data, axis=axes) + + return program_for_node(prog, sdfg, state, node) + + +# ============================================================================ +# ReduceMax and ReduceMin Operations +# ============================================================================ + + +@op_implementation(op="ReduceMax", name="pure") +class PureReduceMax(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + is_axes_present = True + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + except ValueError: + is_axes_present = False + + if not is_axes_present and hasattr(node, "axes"): + is_axes_present = True + + if not is_axes_present: + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + axes = None + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data in sdfg._parent_onnx_model.clean_weights: + axes = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "axes").src.data].numpy() + except ValueError: + pass + if axes is not None: + if len(axes) == 1: + axes = axes[0] + else: + axes = tuple(axes) + else: + axes = node.axes if hasattr(node, "axes") else None + + def prog(data, reduced): + reduced[:] = np.max(data, axis=axes) + + return program_for_node(prog, sdfg, state, node) + + +@op_implementation(op="ReduceMin", name="pure") +class PureReduceMin(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + is_axes_present = True + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + except ValueError: + is_axes_present = False + + if not is_axes_present and hasattr(node, "axes"): + is_axes_present = True + + if not is_axes_present: + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + axes = None + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data in sdfg._parent_onnx_model.clean_weights: + axes = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "axes").src.data].numpy() + except ValueError: + pass + if axes is not None: + if len(axes) == 1: + axes = axes[0] + else: + axes = tuple(axes) + else: + axes = node.axes if hasattr(node, "axes") else None + + def prog(data, reduced): + reduced[:] = np.min(data, axis=axes) + + return program_for_node(prog, sdfg, state, node) + + +# ============================================================================ +# Sum (Multi-input sum) +# ============================================================================ + + +@op_implementation(op="Sum", name="pure") +class PureSum(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + # check that all shapes are arrays, and that the shapes are all equal + shape = None + for edge in node.iter_inputs_in_onnx_order(state): + desc = in_desc_with_name(node, state, sdfg, edge.dst_conn) + if shape is None: + shape = desc.shape + + if not iterables_equal(shape, desc.shape): + return False + + if not iterables_equal(shape, out_desc_with_name(node, state, sdfg, "sum").shape): + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + nsdfg = dace.SDFG(node.name) + input_names = [] + for e in node.iter_inputs_in_onnx_order(state): + new_desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, e.dst_conn)) + new_desc.transient = False + nsdfg.add_datadesc(e.dst_conn, new_desc) + input_names.append(e.dst_conn) + + new_desc = copy.deepcopy(out_desc_with_name(node, state, sdfg, "sum")) + new_desc.transient = False + nsdfg.add_datadesc("sum", new_desc) + + nstate = nsdfg.add_state() + # we know all shapes are equal to the output shape + shape = out_desc_with_name(node, state, sdfg, "sum").shape + map_ranges = {f"i{i}": f"0:{s}" for i, s in enumerate(shape)} + index_str = f"{', '.join(map_ranges.keys())}" + tasklet, _, _ = nstate.add_mapped_tasklet( + node.name + "_tasklet", + map_ranges=map_ranges, + inputs={f"__{inp}": dace.Memlet(f"{inp}[{index_str}]") + for inp in input_names}, + code=f"__sum = {' + '.join(f'__{inp}' for inp in input_names)}", + outputs={"__sum": dace.Memlet(f"sum[{index_str}]")}, + external_edges=True) + + tasklet.in_connectors = {f"__{inp}": in_desc_with_name(node, state, sdfg, inp).dtype for inp in input_names} + tasklet.out_connectors = {"__sum": out_desc_with_name(node, state, sdfg, "sum").dtype} + return nsdfg diff --git a/dace/libraries/onnx/op_implementations/utils.py b/dace/libraries/onnx/op_implementations/utils.py new file mode 100644 index 0000000000..620d9b3081 --- /dev/null +++ b/dace/libraries/onnx/op_implementations/utils.py @@ -0,0 +1,223 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import inspect +import copy +from typing import Dict, Tuple, Optional, Callable, Union, Any +import functools +import textwrap + +import dace +from dace import SDFGState, SDFG, dtypes, nodes +from dace.frontend.python.parser import DaceProgram +from dace.registry import autoregister + +from dace.libraries.onnx.nodes import onnx_op +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.nodes.node_utils import parse_variadic_param +from dace.sdfg.utils import in_desc_with_name, out_desc_with_name + + +def op_implementation(op, name): + """A decorator that registers an op implementation. + + It should be used on classes that extend :class:`~dace.libraries.onnx.forward_implementation_abc.ONNXForward`. + + :param op: The ONNX name of the op to register for. + :param name: The name of the implementation. + """ + + def dec(cls): + if cls.__doc__ is not None: + cls.__doc__ +=\ + """ + :Implementation name: ``"{}"`` + """.format(name) + else: + cls.__doc__ =\ + """ + :Implementation name: ``"{}"`` + """.format(name) + + return autoregister(cls, op=op, name=name) + + return dec + + +def program_for_node(program, + sdfg: SDFG, + state: SDFGState, + node: onnx_op.ONNXOp, + extra_vars: Optional[Dict[str, Any]] = None) -> SDFG: + """Expand a function to a DaCe program. + + The dtypes for the arguments will be extracted by matching the parameter names to edges. + + All inputs that are not specified as parameters will be removed using + constant_folding.remove_node_and_computation. + + :param program: The function to expand into a DaCe program. + :param sdfg: The parent SDFG. + :param state: The SDFG state containing the node. + :param node: The ONNX node to create a program for. + :param extra_vars: Optional extra variables to add to the program. + :return: A new SDFG implementing the program. + """ + + from dace.transformation.onnx import constant_folding # avoid import loop + input_names = node.schema.non_variadic_inputs() + variadic_input_names = node.schema.variadic_inputs() + + output_names = node.schema.non_variadic_outputs() + variadic_output_names = node.schema.variadic_outputs() + + if set(input_names).intersection(output_names): + # This is currently the case for only one ONNX op + raise ValueError("program_for_node cannot be applied on nodes of this type;" + " '{}' are both an input and an output".format(set(input_names).intersection(output_names))) + + params = inspect.signature(program).parameters + connectors_to_remove = set(input_names).difference(params) + + annotations = {} + for name, param in params.items(): + if name in input_names or ("__" in name and parse_variadic_param(name)[0] in variadic_input_names): + annotations[name] = in_desc_with_name(node, state, sdfg, name) + elif name in output_names or ("__" in name and parse_variadic_param(name)[0] in variadic_output_names): + annotations[name] = out_desc_with_name(node, state, sdfg, name) + else: + raise ValueError("'{}' was not found as an input or output for {}".format(name, node.schema.name)) + + program.__annotations__ = annotations + + program.__name__ = node.label + "_expansion" + result = DaceProgram(program, (), {}, False, dace.DeviceType.CPU) + if extra_vars is not None: + result.global_vars.update(extra_vars) + + for conn in connectors_to_remove: + constant_folding.remove_node_and_computation(sdfg, state, node, conn) + + sdfg = result.to_sdfg() + + if node.schedule in [dtypes.ScheduleType.GPU_Default] + dtypes.GPU_SCHEDULES: + sdfg.apply_gpu_transformations() + + return sdfg + + +def empty_sdfg_for_node( + sdfg: SDFG, + state: SDFGState, + node: onnx_op.ONNXOp, + add_access_nodes=True) -> Tuple[SDFG, SDFGState, Dict[str, nodes.AccessNode], Dict[str, nodes.AccessNode]]: + """Given a node, return an SDFG that can be used as a nested SDFG expansion for that node. + + The dtypes for the arguments will be extracted by matching the parameter names to edges. + + :param sdfg: The parent SDFG. + :param state: The SDFG state containing the node. + :param node: The ONNX node to create an SDFG for. + :param add_access_nodes: Whether to add access nodes to the SDFG. + :return: A tuple containing (nested SDFG, nested state, input nodes dict, output nodes dict). + """ + nsdfg = SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + input_nodes = {} + output_nodes = {} + for edge, is_input in node.iter_edges(state, ignore_unknown=True): + if is_input: + conn_name = edge.dst_conn + nsdfg.add_datadesc(conn_name, copy.deepcopy(in_desc_with_name(node, state, sdfg, conn_name))) + if add_access_nodes: + input_nodes[conn_name] = nstate.add_read(conn_name) + else: + conn_name = edge.src_conn + nsdfg.add_datadesc(conn_name, copy.deepcopy(out_desc_with_name(node, state, sdfg, conn_name))) + if add_access_nodes: + output_nodes[conn_name] = nstate.add_write(conn_name) + nsdfg.arrays[conn_name].transient = False + + return nsdfg, nstate, input_nodes, output_nodes + + +@dace.dtypes.paramdec +def python_pure_op_implementation(func, **compute: Dict[str, Callable]): + """A decorator that registers a Python op implementation. + + The name of the function will be the name of the op that is being replaced. + + The compute parameter enables you to compute a variable given the node and + its inputs/outputs. This variable will be namespaced when parsing the function. + + To use this, the argument names of the functions can be either: + + * ``node``, in which case the argument will be passed the node we are expanding, + * or, the name of any connector of the node, in which case the argument will be + the data descriptor for that connector + + For example, the following compute argument instantiation will make + variables ``axis`` and ``shape`` available when the function is parsed. + + + .. highlight:: python + .. code-block:: python + + compute=dict( + # Grabs the axis of a node + axis=lambda node: node.axis + # Grabs the shape of the connector with name 'data' + shape=lambda data: data.shape + ) + + :param func: The function to register as an implementation + :param compute: A dictionary of functions that compute variables. + """ + + @op_implementation(op=func.__name__, name="pure") + class PureImpl(ONNXForward): + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> Union[nodes.Node, SDFG]: + + def compute_argument_resolver(arg: str): + if arg == "node": + return node + elif arg in node.in_connectors: + return in_desc_with_name(node, state, sdfg, arg) + elif arg in node.out_connectors: + return out_desc_with_name(node, state, sdfg, arg) + else: + raise ValueError("Got unknown compute argument {}." + " Arguments to compute can be either 'node'," + " or the name of a connector of the node".format(arg)) + + extra_vars = {} + if compute is not None: + for var_name, function in compute.items(): + + # Get the names of the lambda + argument_names = list(inspect.signature(function).parameters) + + args = map(compute_argument_resolver, argument_names) + var_value = function(*args) + + extra_vars[var_name] = var_value + + return program_for_node(func, sdfg, state, node, extra_vars=extra_vars) + + doc = \ + """ +Pure implementation parsed with +:func:`~dace.libraries.onnx.op_implementations.utils.python_pure_op_implementation`. + +.. code :: python + +""" + doc += textwrap.indent(inspect.getsource(func), prefix=" ") + + PureImpl.__module__ = func.__module__ + PureImpl.__name__ = func.__name__ + PureImpl.__qualname__ = func.__qualname__ + PureImpl.__doc__ = doc + + return PureImpl diff --git a/dace/libraries/onnx/schema.py b/dace/libraries/onnx/schema.py new file mode 100644 index 0000000000..9d9ba3991d --- /dev/null +++ b/dace/libraries/onnx/schema.py @@ -0,0 +1,333 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +ONNX Schema System for DaCe. + +This module provides a Python representation layer for ONNX protobuf schemas, +enabling type-safe interaction with ONNX operations in DaCe. It handles: + +- Converting ONNX protobuf definitions to Python classes +- Type validation and constraint checking for ONNX operations +- Attribute and parameter schema definitions +- Automatic mapping between ONNX types and DaCe types + +Key Components: +- onnx_representation: Decorator for creating Python representations of ONNX protobufs +- ONNXSchema: Complete schema for an ONNX operation +- ONNXAttribute: Attribute definitions (e.g., kernel_shape, strides) +- ONNXParameter: Input/output parameter specifications +- ONNXTypeConstraint: Type constraints for operation parameters +- Enums: ONNXAttributeType, ONNXParameterType for type classification + +The schema system enables: +- Compile-time validation of ONNX operations +- Automatic property generation from schemas +- Type-safe conversion between ONNX and DaCe representations +- Integration with DaCe's property system + +Example: + @onnx_representation(onnx.TensorProto) + class ONNXTensor: + dims: List[int] + data_type: int +""" + +from itertools import chain +from typing import List + +import aenum +import numpy as np +import onnx + +import dace +from dace import config +from dace.dtypes import typeclass +from dace.libraries.onnx.converters import convert_onnx_proto, get_proto_attr, onnx_type_str_to_typeclass +from dace.properties import DictProperty, ListProperty, Property, make_properties + +#: Global registry of known ONNX protobuf types and their Python representations +_KNOWN_ONNX_PROTOS = {} + + +def onnx_representation(represents, **mapping): + """Decorator for python representations of ONNX protobufs. + + The decorator will monkey patch in the following methods: + + * ``__init__`` - a constructor based on the class properties + * ``construct_from_onnx_proto`` + * ``construct_from_json`` + + :param represents: The ONNX protobuf type that the decorated class represents. + :param mapping: A mapping from class property names to either: + + * A string ``s`` - ``convert_onnx_attribute`` will be applied on the + protobuf attribute with the name ``s`` to get the property value. + * A function ``f`` - ``f`` will be called with the protobuf, and the + property value will be set to the return value of that call. + + If a property name is not present in ``mapping``, the property name + itself will be used to access the protobuf attribute. + """ + + def decorator(cls): + + cls = make_properties(cls) + + # initialize the mapping with identity + # this means that by default, we will read the property of the protobuf using the same name as the property name + for name, _ in cls.__properties__.items(): + if name not in mapping: + mapping[name] = name + + def __init__(self, *args, **kwargs): + args = list(args) + for name, prop in self.__properties__.items(): + if len(args) > 0: + # try to init all the positional args first + setattr(self, name, args.pop(0)) + else: + # then try kwargs + setattr(self, name, kwargs[name]) + self._represents = represents + if hasattr(self, "validate"): + self.validate() + + @classmethod + def from_onnx_proto(cls, onnx_proto): + + if type(onnx_proto) is not represents: + raise ValueError("Unexpected protobuf '{}' (type {}), expected protobuf of type {}".format( + onnx_proto, type(onnx_proto), represents)) + + constructor_args = {} + for name, _ in cls.__properties__.items(): + if type(mapping[name]) is str: + # if the value of the mapping for that property is a string, read the attribute with that name + constructor_args[name] = convert_onnx_proto(get_proto_attr(onnx_proto, mapping[name])) + else: + # the value of the mapping should be a function, apply it to the onnx_proto + constructor_args[name] = mapping[name](onnx_proto) + + return cls(**constructor_args) + + @classmethod + def from_json(cls, json, context=None): + + constructor_args = { + name: prop.from_json(json[name] if name in json else prop.default) + for name, prop in cls.__properties__.items() + } + return cls(**constructor_args) + + def to_json(self): + serialized = dace.serialize.all_properties_to_json(self) + serialized["type"] = cls.__name__ + return serialized + + cls.__init__ = __init__ + + # the first line of the init docstring contains the signature of the method. This will be picked up by sphinx + # and means that the generated sphinx docs have a proper signature, and not just *args, **kwargs. + init_docstring = "__init__({})\n\n".format(", ".join(name + "=" + repr(prop._default) + for name, prop in cls.__properties__.items())) + + def get_prop_docstring(name, prop): + return ":param {}: {}\n:type {}: ``{}``, default ``{}``".format( + name, prop.__doc__, name, + prop._dtype.__name__ if prop._dtype is not None else type(prop._default).__name__, repr(prop._default)) + + init_docstring += "\n".join(get_prop_docstring(name, prop) for name, prop in cls.__properties__.items()) + + cls.__init__.__doc__ = init_docstring + + cls.from_onnx_proto = from_onnx_proto + cls.from_json = from_json + cls.to_json = to_json + from_onnx_proto.__func__.__doc__ = " Construct an object from an ONNX proto of type ``{}``. ".format(represents) + from_json.__func__.__doc__ = " Construct an object json ".format(represents) + to_json.__doc__ = " Serialize to json ".format(represents) + + # register so that we're able to load it + _KNOWN_ONNX_PROTOS[represents] = cls + + return cls + + return decorator + + +class ONNXParameterType(aenum.AutoNumberEnum): + Single = () #: single/required parameters + Optional = () #: optional parameters + Variadic = () #: variadic parameters + + +@onnx_representation(onnx.defs.OpSchema.FormalParameter, + type_str='type_str', + param_type='option', + homogeneous="is_homogeneous") +class ONNXParameter: + """ Python representation of an ONNX parameter. """ + + name = Property(dtype=str, desc="The parameter name") + description = Property(dtype=str, desc="A description of the parameter") + type_str = Property(dtype=str, desc="The type string of this parameter") + param_type = Property(choices=ONNXParameterType, + desc="The type of the this parameter", + default=ONNXParameterType.Single) + homogeneous = Property(dtype=bool, desc="Whether this parameter is homogeneous") + + def __repr__(self): + return "{} ({})".format(self.name, str(self.param_type)) + + +class ONNXAttributeType(aenum.AutoNumberEnum): + Int = () #: Integer (python representation is ``int``) + Float = () #: Float (python representation is ``float``) + String = () #: String (python representation is ``str``) + Ints = () #: Ints (python representation is ``List`` [``int``]) + Floats = () #: Floats (python representation is ``List`` [``float``]) + Strings = () #: Strings (python representation is ``List`` [``str``]) + Tensor = () #: Tensor (python representation is ``numpy.ndarray``) + Unsupported = () #: Any unsupported attribute type + + +_ATTR_TYPE_TO_PYTHON_TYPE = { + ONNXAttributeType.Int: int, + ONNXAttributeType.Ints: int, + ONNXAttributeType.Float: float, + ONNXAttributeType.Floats: float, + ONNXAttributeType.String: str, + ONNXAttributeType.Strings: str, + ONNXAttributeType.Tensor: np.ndarray +} + + +@onnx_representation(onnx.defs.OpSchema.Attribute, attribute_type='type') +class ONNXAttribute: + """ Python representation of an ONNX attribute. """ + + name = Property(dtype=str, desc="The attribute name") + description = Property(dtype=str, desc="A description this attribute") + required = Property(dtype=bool, desc="Whether this attribute is required") + attribute_type = Property(choices=ONNXAttributeType, + desc="The type of this attribute", + default=ONNXAttributeType.Int) + default_value = Property(dtype=None, desc="The default value of this attribute", default=None, allow_none=True) + + def validate(self): + if self.required and self.attribute_type == ONNXAttributeType.Unsupported: + raise NotImplementedError("Required attribute '{}' has an unsupported type".format(self.name)) + + def __repr__(self): + return self.name + + +@onnx_representation( + onnx.defs.OpSchema.TypeConstraintParam, + type_str='type_param_str', + types=lambda proto: list( + filter(lambda x: x is not None, map(onnx_type_str_to_typeclass, get_proto_attr(proto, "allowed_type_strs"))))) +class ONNXTypeConstraint: + """ Python representation of an ONNX type constraint. """ + + type_str = Property(dtype=str, desc="The type parameter string") + types = ListProperty(element_type=typeclass, + desc="The possible types. Note that only tensor types are currently supported.") + + def __repr__(self): + return self.type_str + + +@onnx_representation( + onnx.defs.OpSchema, + inputs=lambda proto: list(map(convert_onnx_proto, get_proto_attr(proto, "inputs"))), + outputs=lambda proto: list(map(convert_onnx_proto, get_proto_attr(proto, "outputs"))), + attributes=lambda proto: { + str(k): convert_onnx_proto(v) + for k, v in get_proto_attr(proto, "attributes").items() + }, + type_constraints=lambda proto: + {str(cons.type_param_str): convert_onnx_proto(cons) + for cons in get_proto_attr(proto, "type_constraints")}) +class ONNXSchema: + """Python representation of an ONNX schema""" + + name = Property(dtype=str, desc="The operator name") + domain = Property(dtype=str, desc="The operator domain") + doc = Property(dtype=str, desc="The operator's docstring") + since_version = Property(dtype=int, desc="The version of the operator") + attributes = DictProperty(key_type=str, + value_type=ONNXAttribute, + desc="The operator attributes. Keys should contain the name of the attribute, and values " + "should have type :class:`~dace.libraries.onnx.ONNXAttribute`.") + type_constraints = DictProperty( + key_type=str, + value_type=ONNXTypeConstraint, + desc="The type constraints for inputs and outputs. Keys should contain the type string of the constraint, " + "values should have type :class:`~dace.libraries.onnx.ONNXTypeConstraint`.") + inputs = ListProperty(element_type=ONNXParameter, + desc="The operator input parameter descriptors. Entries should have type" + " :class:`~dace.libraries.onnx.ONNXParameter`.") + outputs = ListProperty(element_type=ONNXParameter, + desc="The operator output parameter descriptors. Entries should have type" + " :class:`~dace.libraries.onnx.ONNXParameter`.") + + def __repr__(self): + return self.domain + "." + self.name + + def non_variadic_inputs(self) -> List[str]: + return [i.name for i in self.inputs if i.param_type is not ONNXParameterType.Variadic] + + def variadic_inputs(self) -> List[str]: + return [i.name for i in self.inputs if i.param_type is ONNXParameterType.Variadic] + + def non_variadic_outputs(self) -> List[str]: + return [i.name for i in self.outputs if i.param_type is not ONNXParameterType.Variadic] + + def variadic_outputs(self) -> List[str]: + return [i.name for i in self.outputs if i.param_type is ONNXParameterType.Variadic] + + def validate(self): + # check all parameters with a type str have a entry in the type constraints + for param in chain(self.inputs, self.outputs): + if param.type_str not in self.type_constraints: + # some operators put a type descriptor here. for those, we will try to insert a new type constraint + cons_name = param.name + "_constraint" + if cons_name in self.type_constraints: + raise ValueError( + "Attempted to insert new type constraint, but the name already existed. Please open an issue.") + parsed_typeclass = onnx_type_str_to_typeclass(param.type_str) + + if parsed_typeclass is None: + if config.Config.get_bool('debugprint'): + print("Could not parse typeStr '{}' for parameter '{}'".format(param.type_str, param.name)) + + cons = ONNXTypeConstraint(cons_name, [parsed_typeclass] if parsed_typeclass is not None else []) + self.type_constraints[cons_name] = cons + param.type_str = cons_name + + # check for required parameters with no supported type + for param in chain(self.inputs, self.outputs): + if ((param.param_type == ONNXParameterType.Single or param.param_type == ONNXParameterType.Variadic) + and len(self.type_constraints[param.type_str].types) == 0): + raise NotImplementedError("None of the types for parameter '{}' are supported".format(param.name)) + + # check that all variadic parameter names do not contain "__" + for param in chain(self.inputs, self.outputs): + if param.param_type == ONNXParameterType.Variadic and "__" in param.name: + raise ValueError( + "Unsupported parameter name '{}': variadic parameter names must not contain '__'".format( + param.name)) + + # check that all inputs and outputs have unique names + seen = set() + for param in self.inputs: + if param.name in seen: + raise ValueError("Got duplicate input parameter name '{}'".format(param.name)) + seen.add(param.name) + + seen = set() + for param in self.outputs: + if param.name in seen: + raise ValueError("Got duplicate output parameter name '{}'".format(param.name)) + seen.add(param.name) diff --git a/dace/libraries/standard/nodes/gearbox.py b/dace/libraries/standard/nodes/gearbox.py index 0197855944..32f1383996 100644 --- a/dace/libraries/standard/nodes/gearbox.py +++ b/dace/libraries/standard/nodes/gearbox.py @@ -236,7 +236,7 @@ def __init__(self, size, name=None, schedule=None, **kwargs): memory into n/4 elements (vector size 4), this parameter should be set to n/16. """ - super().__init__(name=name or "gearbox", schedule=schedule or dace.ScheduleType.FPGA_Device, **kwargs) + super().__init__(name=name or "gearbox", schedule=schedule, **kwargs) self.size = size if schedule is not None: self.schedule = schedule diff --git a/dace/libraries/standard/nodes/reduce.py b/dace/libraries/standard/nodes/reduce.py index 37feba4843..34b19e1788 100644 --- a/dace/libraries/standard/nodes/reduce.py +++ b/dace/libraries/standard/nodes/reduce.py @@ -1,4 +1,4 @@ -# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. """ File defining the reduction library node. """ import ast @@ -875,213 +875,6 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG): #return reduce_node.expand(sdfg, state) -@dace.library.expansion -class ExpandReduceFPGAPartialReduction(pm.ExpandTransformation): - """ - FPGA SDFG Reduce expansion. This does not assume single-cycle accumulation of the given data type. - - To achieve II=1, reduction is done into multiple partial reduction, which are then - combined at the end. - """ - environments = [] - - # Reduction type expressions dictionary - _REDUCTION_TYPE_EXPR = { - dtypes.ReductionType.Max: 'max(prev, data_in)', - dtypes.ReductionType.Min: 'min(prev, data_in)', - dtypes.ReductionType.Sum: 'prev + data_in', - dtypes.ReductionType.Product: 'prev * data_in', - dtypes.ReductionType.Sub: 'prev - data_in', - dtypes.ReductionType.Div: 'prev / data_in' - } - - @staticmethod - def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG, partial_width=16): - """ - - :param node: the node to expand - :param state: the state in which the node is in - :param sdfg: the SDFG in which the node is in - :param partial_width: Width of the inner reduction buffer. Must be - larger than the latency of the reduction operation on the given - data type - """ - node.validate(sdfg, state) - inedge: graph.MultiConnectorEdge = state.in_edges(node)[0] - outedge: graph.MultiConnectorEdge = state.out_edges(node)[0] - input_dims = len(inedge.data.subset) - output_dims = len(outedge.data.subset) - input_data = sdfg.arrays[inedge.data.data] - output_data = sdfg.arrays[outedge.data.data] - - # Standardize axes - axes = node.axes if node.axes else [i for i in range(input_dims)] - - # Create nested SDFG - nsdfg = SDFG('reduce') - - nsdfg.add_array('_in', - inedge.data.subset.size(), - input_data.dtype, - strides=input_data.strides, - storage=input_data.storage) - - nsdfg.add_array('_out', - outedge.data.subset.size(), - output_data.dtype, - strides=output_data.strides, - storage=output_data.storage) - if input_data.dtype.veclen > 1: - raise NotImplementedError('Vectorization currently not implemented for FPGA expansion of Reduce.') - - nstate = nsdfg.add_state() - - # (If axes != all) Add outer map, which corresponds to the output range - if len(axes) != input_dims: - all_axis = False - # Interleave input and output axes to match input memlet - ictr, octr = 0, 0 - input_subset = [] - for i in range(input_dims): - if i in axes: - input_subset.append(f'_i{ictr}') - ictr += 1 - else: - input_subset.append(f'_o{octr}') - octr += 1 - - output_size = outedge.data.subset.size() - - ome, omx = nstate.add_map('reduce_output', { - f'_o{i}': f'0:{symstr(sz)}' - for i, sz in enumerate(outedge.data.subset.size()) - }) - outm_idx = ','.join([f'_o{i}' for i in range(output_dims)]) - outm = dace.Memlet(f'_out[{outm_idx}]') - inm_idx = ','.join(input_subset) - inmm = dace.Memlet(f'_in[{inm_idx}]') - else: - all_axis = True - ome, omx = None, None - outm = dace.Memlet('_out[0]') - inm_idx = ','.join([f'_i{i}' for i in range(len(axes))]) - inmm = dace.Memlet(f'_in[{inm_idx}]') - - # Add inner map, which corresponds to the range to reduce - r = nstate.add_read('_in') - w = nstate.add_read('_out') - - # TODO support vectorization - buffer_name = 'partial_results' - nsdfg.add_array(buffer_name, (partial_width, ), - input_data.dtype, - transient=True, - storage=dtypes.StorageType.FPGA_Local) - buffer = nstate.add_access(buffer_name) - buffer_write = nstate.add_write(buffer_name) - - # Initialize explicitly partial results, as the inner map could run for a number of iteration < partial_width - init_me, init_mx = nstate.add_map('partial_results_init', {'i': f'0:{partial_width}'}, - schedule=dtypes.ScheduleType.FPGA_Device, - unroll=True) - init_tasklet = nstate.add_tasklet('init_pr', {}, {'pr_out'}, f'pr_out = {node.identity}') - nstate.add_memlet_path(init_me, init_tasklet, memlet=dace.Memlet()) - nstate.add_memlet_path(init_tasklet, - init_mx, - buffer, - src_conn='pr_out', - memlet=dace.Memlet(f'{buffer_name}[i]')) - - if not all_axis: - nstate.add_memlet_path(ome, init_me, memlet=dace.Memlet()) - - ime, imx = nstate.add_map('reduce_values', { - f'_i{i}': f'0:{symstr(inedge.data.subset.size()[axis])}' - for i, axis in enumerate(sorted(axes)) - }) - - # Accumulate over partial results - redtype = detect_reduction_type(node.wcr) - if redtype not in ExpandReduceFPGAPartialReduction._REDUCTION_TYPE_EXPR: - raise ValueError('Reduction type not supported for "%s"' % node.wcr) - else: - reduction_expr = ExpandReduceFPGAPartialReduction._REDUCTION_TYPE_EXPR[redtype] - - # generate flatten index considering inner map: will be used for indexing into partial results - ranges_size = ime.range.size() - inner_index = '+'.join([f'_i{i} * {ranges_size[i + 1]}' for i in range(len(axes) - 1)]) - inner_op = ' + ' if len(axes) > 1 else '' - inner_index = inner_index + f'{inner_op}_i{(len(axes) - 1)}' - partial_reduce_tasklet = nstate.add_tasklet('partial_reduce', {'data_in', 'buffer_in'}, {'buffer_out'}, f'''\ -prev = buffer_in -buffer_out = {reduction_expr}''') - - if not all_axis: - # Connect input and partial sums - nstate.add_memlet_path(r, ome, ime, partial_reduce_tasklet, dst_conn='data_in', memlet=inmm) - else: - nstate.add_memlet_path(r, ime, partial_reduce_tasklet, dst_conn='data_in', memlet=inmm) - nstate.add_memlet_path(buffer, - ime, - partial_reduce_tasklet, - dst_conn='buffer_in', - memlet=dace.Memlet(f'{buffer_name}[({inner_index})%{partial_width}]')) - nstate.add_memlet_path(partial_reduce_tasklet, - imx, - buffer_write, - src_conn='buffer_out', - memlet=dace.Memlet(f'{buffer_name}[({inner_index})%{partial_width}]')) - - # Then perform reduction on partial results - reduce_entry, reduce_exit = nstate.add_map('reduce', {'i': f'0:{partial_width}'}, - schedule=dtypes.ScheduleType.FPGA_Device, - unroll=True) - - reduce_tasklet = nstate.add_tasklet( - 'reduce', {'reduce_in', 'data_in'}, {'reduce_out'}, f'''\ -prev = reduce_in if i > 0 else {node.identity} -reduce_out = {reduction_expr}''') - nstate.add_memlet_path(buffer_write, - reduce_entry, - reduce_tasklet, - dst_conn='data_in', - memlet=dace.Memlet(f'{buffer_name}[i]')) - - reduce_name = 'reduce_result' - nsdfg.add_array(reduce_name, (1, ), output_data.dtype, transient=True, storage=dtypes.StorageType.FPGA_Local) - reduce_read = nstate.add_access(reduce_name) - reduce_access = nstate.add_access(reduce_name) - - if not all_axis: - nstate.add_memlet_path(ome, reduce_read, memlet=dace.Memlet()) - - nstate.add_memlet_path(reduce_read, - reduce_entry, - reduce_tasklet, - dst_conn='reduce_in', - memlet=dace.Memlet(f'{reduce_name}[0]')) - nstate.add_memlet_path(reduce_tasklet, - reduce_exit, - reduce_access, - src_conn='reduce_out', - memlet=dace.Memlet(f'{reduce_name}[0]')) - - if not all_axis: - # Write out the result - nstate.add_memlet_path(reduce_access, omx, w, memlet=outm) - else: - nstate.add_memlet_path(reduce_access, w, memlet=outm) - - # Rename outer connectors and add to node - inedge._dst_conn = '_in' - outedge._src_conn = '_out' - node.add_in_connector('_in') - node.add_out_connector('_out') - nsdfg.validate() - - return nsdfg - - @dace.library.expansion class ExpandReduceGPUAuto(pm.ExpandTransformation): """ @@ -1556,7 +1349,6 @@ class Reduce(dace.sdfg.nodes.LibraryNode): 'CUDA (device)': ExpandReduceCUDADevice, 'CUDA (block)': ExpandReduceCUDABlock, 'CUDA (block allreduce)': ExpandReduceCUDABlockAll, - 'FPGAPartialReduction': ExpandReduceFPGAPartialReduction, 'GPUAuto': ExpandReduceGPUAuto # 'CUDA (warp)': ExpandReduceCUDAWarp, # 'CUDA (warp allreduce)': ExpandReduceCUDAWarpAll diff --git a/dace/libraries/stencil/intel_fpga.py b/dace/libraries/stencil/intel_fpga.py deleted file mode 100644 index 1f7ac89c84..0000000000 --- a/dace/libraries/stencil/intel_fpga.py +++ /dev/null @@ -1,516 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import ast -import collections -import functools -import itertools -import operator -import re - -import dace -from dace import data as dt, subsets as sbs -import numpy as np -from .subscript_converter import SubscriptConverter -from ._common import * - - -@dace.library.expansion -class ExpandStencilIntelFPGA(dace.library.ExpandTransformation): - - environments = [] - - @staticmethod - def expansion(node, parent_state, parent_sdfg): - - sdfg = dace.SDFG(node.label + "_outer") - state = sdfg.add_state(node.label + "_outer") - - (inputs, outputs, shape, field_to_data, field_to_desc, field_to_edge, - vector_lengths) = parse_connectors(node, parent_state, parent_sdfg) - - ####################################################################### - # Parse the tasklet code - ####################################################################### - - # Replace relative indices with memlet names - converter = SubscriptConverter() - - # Add copy boundary conditions - for field in node.boundary_conditions: - if node.boundary_conditions[field]["btype"] == "copy": - center_index = tuple(0 for _ in range(len(parent_sdfg.arrays[field_to_data[field]].shape))) - # This will register the renaming - converter.convert(field, center_index) - - # Replace accesses in the code - code, field_accesses = parse_accesses(node.code.as_string, outputs) - - iterator_mapping = make_iterator_mapping(node, field_accesses, shape) - vector_length = validate_vector_lengths(vector_lengths, iterator_mapping) - shape_vectorized = tuple(s / vector_length if i == len(shape) - 1 else s for i, s in enumerate(shape)) - - # Extract which fields to read from streams and what to buffer - buffer_sizes = collections.OrderedDict() - buffer_accesses = collections.OrderedDict() - scalars = {} # {name: type} - for field_name in inputs: - relative = field_accesses[field_name] - dim_mask = iterator_mapping[field_name] - if not any(dim_mask): - # This is a scalar, no buffer needed. Instead, the SDFG must - # take this as a symbol - scalars[field_name] = parent_sdfg.symbols[field_name] - sdfg.add_symbol(field_name, parent_sdfg.symbols[field_name]) - continue - abs_indices = ( - [dim_to_abs_val(i, tuple(s for s, m in zip(shape, dim_mask) if m), parent_sdfg) - for i in relative] + ([0] if field_name in node.boundary_conditions - and node.boundary_conditions[field_name]["btype"] == "copy" else [])) - max_access = max(abs_indices) - min_access = min(abs_indices) - buffer_size = max_access - min_access + vector_lengths[field_name] - buffer_sizes[field_name] = buffer_size - # (indices relative to center, buffer indices, center index) - buffer_accesses[field_name] = ([tuple(r) for r in relative], [i - min_access - for i in abs_indices], -min_access) - - # Create a initialization phase corresponding to the highest distance - # to the center - init_sizes = [(buffer_sizes[key] - vector_lengths[key] - val[2]) // vector_length - for key, val in buffer_accesses.items()] - init_size_max = int(np.max(init_sizes)) - - parameters = [f"_i{i}" for i in range(len(shape))] - - # Dimensions we need to iterate over - iterator_mask = np.array([s != 0 and s != 1 for s in shape], dtype=bool) - iterators = make_iterators(tuple(s for s, m in zip(shape_vectorized, iterator_mask) if m), - parameters=tuple(s for s, m in zip(parameters, iterator_mask) if m)) - - # Manually add pipeline entry and exit nodes - pipeline_range = dace.properties.SubsetProperty.from_string(', '.join(iterators.values())) - pipeline = dace.sdfg.nodes.PipelineScope("compute_" + node.label, - list(iterators.keys()), - pipeline_range, - dace.dtypes.ScheduleType.FPGA_Device, - False, - init_size=init_size_max, - init_overlap=False, - drain_size=init_size_max, - drain_overlap=True) - entry = dace.sdfg.nodes.PipelineEntry(pipeline) - exit = dace.sdfg.nodes.PipelineExit(pipeline) - state.add_nodes_from([entry, exit]) - - # Add nested SDFG to do 1) shift buffers 2) read from input 3) compute - nested_sdfg = dace.SDFG(node.label + "_inner", parent=state) - nested_sdfg_tasklet = state.add_nested_sdfg( - nested_sdfg, - # Input connectors - [k + "_in" - for k in inputs if any(iterator_mapping[k])] + [name + "_buffer_in" for name, _ in buffer_sizes.items()], - # Output connectors - [k + "_out" for k in outputs] + [name + "_buffer_out" for name, _ in buffer_sizes.items()], - schedule=dace.ScheduleType.FPGA_Device) - # Propagate symbols - for sym_name, sym_type in parent_sdfg.symbols.items(): - nested_sdfg.add_symbol(sym_name, sym_type) - nested_sdfg_tasklet.symbol_mapping[sym_name] = sym_name - # Map iterators - for p in parameters: - nested_sdfg.add_symbol(p, dace.int64) - nested_sdfg_tasklet.symbol_mapping[p] = p - - # Shift state, which shifts all buffers by one - shift_state = nested_sdfg.add_state(node.label + "_shift") - - # Update state, which reads new values from memory - update_state = nested_sdfg.add_state(node.label + "_update") - - ####################################################################### - # Implement boundary conditions - ####################################################################### - - boundary_code, oob_cond = generate_boundary_conditions(node, shape, field_accesses, field_to_desc, - iterator_mapping) - - ####################################################################### - # Only write if we're in bounds - ####################################################################### - - write_code = ("\n".join([ - "{}_inner_out = {}\n".format(output, field_accesses[output][tuple(0 for _ in range(len(shape)))]) - for output in outputs - ])) - if init_size_max > 0 or len(oob_cond) > 0: - write_cond = [] - if init_size_max > 0: - init_cond = pipeline.init_condition() - write_cond.append("not " + init_cond) - nested_sdfg_tasklet.symbol_mapping[init_cond] = init_cond - nested_sdfg.add_symbol(init_cond, dace.bool) - if len(oob_cond) > 0: - oob_cond = " or ".join(sorted(oob_cond)) - oob_cond = f"not ({oob_cond})" - write_cond.append(oob_cond) - write_cond = " and ".join(write_cond) - write_cond = f"if {write_cond}:\n\t" - else: - write_cond = "" - - code = boundary_code + "\n" + code + "\n" + write_code - - ####################################################################### - # Create DaCe compute state - ####################################################################### - - # Compute state, which reads from input channels, performs the compute, - # and writes to the output channel(s) - compute_state = nested_sdfg.add_state(node.label + "_compute") - compute_inputs = list( - itertools.chain.from_iterable([["_" + v for v in field_accesses[f].values()] for f in inputs - if any(iterator_mapping[f])])) - compute_tasklet = compute_state.add_tasklet(node.label + "_compute", - compute_inputs, {name + "_inner_out" - for name in outputs}, - code, - language=dace.dtypes.Language.Python) - if vector_length > 1: - compute_unroll_entry, compute_unroll_exit = compute_state.add_map(compute_state.label + "_unroll", - {"i_unroll": f"0:{vector_length}"}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - - # Connect the three nested states - nested_sdfg.add_edge(shift_state, update_state, dace.sdfg.InterstateEdge()) - nested_sdfg.add_edge(update_state, compute_state, dace.sdfg.InterstateEdge()) - - # First, grab scalar variables - for scalar, scalar_type in scalars.items(): - nested_sdfg.add_symbol(scalar, scalar_type) - - # Code to increment custom iterators - iterator_code = "" - - for (field_name, size), init_size in zip(buffer_sizes.items(), init_sizes): - - data_name = field_to_data[field_name] - connector = field_to_edge[field_name].dst_conn - data_name_outer = connector - data_name_inner = field_name + "_in" - parent_desc = parent_sdfg.arrays[data_name] - if isinstance(parent_desc, dt.View): - if isinstance(parent_desc, (dt.ArrayView, dt.ContainerArray)): - desc_outer = parent_desc.as_array() - elif isinstance(parent_desc, dt.StructureView): - desc_outer = parent_desc.as_structure() - else: - desc_outer = parent_sdfg.arrays[data_name].clone() - desc_outer.transient = False - sdfg.add_datadesc(data_name_outer, desc_outer) - - mapping = iterator_mapping[field_name] - is_array = not isinstance(desc_outer, dt.Stream) - - # If this array is part of the initialization phase, it needs its - # own iterator, which we need to instantiate and increment in the - # outer SDFG - if is_array: - if init_size == 0: - field_index = [s for s, p in zip(parameters, mapping) if p] - else: - # Create custom iterators for this array - num_dims = sum(mapping, 0) - field_iterators = [(f"_{field_name}_i{i}", shape[i]) for i in range(num_dims) if mapping[i]] - start_index = init_size_max - init_size - tab = "" - if start_index > 0: - iterator_code += (f"if {pipeline.iterator_str()} >= {start_index}:\n") - tab += " " - for i, (it, s) in enumerate(reversed(field_iterators)): - iterator_code += f"""\ -{tab}if {it} < {s} - 1: -{tab} {it} = {it} + 1 -{tab}else: -{tab} {it} = 0\n""" - tab += " " - field_index = [fi[0] for fi in field_iterators] - for fi in field_index: - pipeline.additional_iterators[fi] = "0" - nested_sdfg.add_symbol(fi, dace.int64) - nested_sdfg_tasklet.symbol_mapping[fi] = fi - field_index = ", ".join(field_index) - else: - field_index = "0" - - # Begin reading according to this field's own buffer size, which is - # translated to an index by subtracting it from the maximum buffer - # size - begin_reading = init_size_max - init_size - total_size = functools.reduce(operator.mul, shape_vectorized, 1) - end_reading = total_size + init_size_max - init_size - - # Outer memory read - read_node_outer = state.add_read(data_name_outer) - if begin_reading != 0 or end_reading != total_size + init_size_max: - sdfg.add_scalar(f"{field_name}_wavefront", - desc_outer.dtype, - storage=dace.StorageType.FPGA_Local, - transient=True) - wavefront_access = state.add_access(f"{field_name}_wavefront") - condition = [] - it = pipeline.iterator_str() - if begin_reading != 0: - condition.append(f"{it} >= {begin_reading}") - if end_reading != total_size + init_size_max: - condition.append(f"{it} < {end_reading}") - condition = " and ".join(condition) - update_tasklet = state.add_tasklet(f"read_{field_name}", {"wavefront_in"}, {"wavefront_out"}, - f"if {condition}:\n" - "\twavefront_out = wavefront_in\n", - language=dace.dtypes.Language.Python) - state.add_memlet_path(read_node_outer, - entry, - update_tasklet, - dst_conn="wavefront_in", - memlet=dace.Memlet(f"{data_name_outer}[{field_index}]", dynamic=True)) - state.add_memlet_path(update_tasklet, - wavefront_access, - src_conn="wavefront_out", - memlet=dace.Memlet(f"{field_name}_wavefront", dynamic=True)) - state.add_memlet_path(wavefront_access, - nested_sdfg_tasklet, - dst_conn=f"{field_name}_in", - memlet=dace.Memlet(f"{field_name}_wavefront")) - else: - state.add_memlet_path(read_node_outer, - entry, - nested_sdfg_tasklet, - dst_conn=f"{field_name}_in", - memlet=dace.Memlet(f"{data_name_outer}[{field_index}]")) - - # Create inner memory access - nested_sdfg.add_scalar(data_name_inner, - desc_outer.dtype, - storage=dace.StorageType.FPGA_Local, - transient=False) - - buffer_name_outer = f"{node.label}_{field_name}_buffer" - buffer_name_inner_read = f"{field_name}_buffer_in" - buffer_name_inner_write = f"{field_name}_buffer_out" - - # Create buffer transient in outer SDFG - field_dtype = parent_sdfg.data(data_name).dtype - _, desc_outer = sdfg.add_array(buffer_name_outer, (size, ), - field_dtype.base_type, - storage=dace.dtypes.StorageType.FPGA_Local, - transient=True) - - # Create read and write nodes - read_node_outer = state.add_read(buffer_name_outer) - write_node_outer = state.add_write(buffer_name_outer) - - # Outer buffer read - state.add_memlet_path(read_node_outer, - entry, - nested_sdfg_tasklet, - dst_conn=buffer_name_inner_read, - memlet=dace.Memlet(f"{buffer_name_outer}[0:{size}]")) - - # Outer buffer write - state.add_memlet_path(nested_sdfg_tasklet, - exit, - write_node_outer, - src_conn=buffer_name_inner_write, - memlet=dace.Memlet(f"{write_node_outer.data}[0:{size}]", dynamic=True)) - - # Inner copy - desc_inner_read = desc_outer.clone() - desc_inner_read.transient = False - desc_inner_read.name = buffer_name_inner_read - desc_inner_write = desc_inner_read.clone() - desc_inner_write.name = buffer_name_inner_write - nested_sdfg.add_datadesc(buffer_name_inner_read, desc_inner_read) - nested_sdfg.add_datadesc(buffer_name_inner_write, desc_inner_write) - - # Make shift state if necessary - if size > 1: - shift_read = shift_state.add_read(buffer_name_inner_read) - shift_write = shift_state.add_write(buffer_name_inner_write) - shift_entry, shift_exit = shift_state.add_map(f"shift_{field_name}", - {"i_shift": f"0:{size} - {vector_lengths[field_name]}"}, - schedule=dace.dtypes.ScheduleType.FPGA_Device, - unroll=True) - shift_tasklet = shift_state.add_tasklet(f"shift_{field_name}", {f"{field_name}_shift_in"}, - {f"{field_name}_shift_out"}, - f"{field_name}_shift_out = {field_name}_shift_in") - shift_state.add_memlet_path(shift_read, - shift_entry, - shift_tasklet, - dst_conn=field_name + "_shift_in", - memlet=dace.Memlet(f"{shift_read.data}" - f"[i_shift + {vector_lengths[field_name]}]")) - shift_state.add_memlet_path(shift_tasklet, - shift_exit, - shift_write, - src_conn=field_name + "_shift_out", - memlet=dace.Memlet(f"{shift_write.data}[i_shift]")) - - # Make update state - update_read = update_state.add_read(data_name_inner) - update_write = update_state.add_write(buffer_name_inner_write) - subset = f"{size} - {vector_length}:{size}" if size > 1 else "0" - update_state.add_memlet_path(update_read, - update_write, - memlet=dace.Memlet(f"{update_read.data}", other_subset=f"{subset}")) - - # Make compute state - compute_read = compute_state.add_read(buffer_name_inner_read) - for relative, offset in zip(buffer_accesses[field_name][0], buffer_accesses[field_name][1]): - memlet_name = field_accesses[field_name][tuple(relative)] - if vector_length > 1: - if vector_lengths[field_name] > 1: - offset = f"{offset} + i_unroll" - else: - offset = str(offset) - path = [compute_read, compute_unroll_entry, compute_tasklet] - else: - offset = str(offset) - path = [compute_read, compute_tasklet] - compute_state.add_memlet_path(*path, - dst_conn="_" + memlet_name, - memlet=dace.Memlet(f"{compute_read.data}[{offset}]")) - - # Tasklet to update iterators - if iterator_code: - update_iterator_tasklet = state.add_tasklet(f"{node.label}_update_iterators", {}, {}, iterator_code) - state.add_memlet_path(nested_sdfg_tasklet, update_iterator_tasklet, memlet=dace.Memlet()) - state.add_memlet_path(update_iterator_tasklet, exit, memlet=dace.Memlet()) - - for field_name in outputs: - - for offset in field_accesses[field_name]: - if offset is not None and list(offset) != [0] * len(offset): - raise NotImplementedError("Output offsets not implemented") - - data_name = field_to_data[field_name] - - # Outer write - data_name_outer = field_name - data_name_inner = field_name + "_out" - parent_desc = parent_sdfg.arrays[data_name] - if isinstance(parent_desc, dt.View): - if isinstance(parent_desc, dt.ArrayView): - desc_outer = dt.Array(parent_desc.dtype, - parent_desc.shape, - transient=False, - allow_conflicts=parent_desc.allow_conflicts, - storage=parent_desc.storage, - location=parent_desc.location, - strides=parent_desc.strides, - offset=parent_desc.offset, - may_alias=parent_desc.may_alias, - lifetime=parent_desc.lifetime, - alignment=parent_desc.alignment, - debuginfo=parent_desc.debuginfo, - total_size=parent_desc.total_size, - start_offset=parent_desc.start_offset, - optional=parent_desc.optional, - pool=parent_desc.pool) - elif isinstance(parent_desc, dt.StructureView): - desc_outer = dt.Structure(members=parent_desc.members, - name=parent_desc.name, - transient=False, - storage=parent_desc.storage, - location=parent_desc.location, - lifetime=parent_desc.lifetime, - debuginfo=parent_desc.debuginfo) - elif isinstance(parent_desc, dt.ContainerView): - desc_outer = dt.ContainerArray(parent_desc.dtype, - parent_desc.shape, - transient=False, - allow_conflicts=parent_desc.allow_conflicts, - storage=parent_desc.storage, - location=parent_desc.location, - strides=parent_desc.strides, - offset=parent_desc.offset, - may_alias=parent_desc.may_alias, - lifetime=parent_desc.lifetime, - alignment=parent_desc.alignment, - debuginfo=parent_desc.debuginfo, - total_size=parent_desc.total_size, - start_offset=parent_desc.start_offset, - optional=parent_desc.optional, - pool=parent_desc.pool) - else: - desc_outer = parent_sdfg.arrays[data_name].clone() - desc_outer.transient = False - array_index = ", ".join(map(str, parameters)) - try: - sdfg.add_datadesc(data_name_outer, desc_outer) - except NameError: # Already an input - pass - - # Create inner access - nested_sdfg.add_scalar(data_name_inner, - desc_outer.dtype, - storage=dace.StorageType.FPGA_Local, - transient=False) - - # Inner write - write_node_inner = compute_state.add_write(data_name_inner) - - # Intermediate buffer, mostly relevant for vectorization - output_buffer_name = field_name + "_output_buffer" - nested_sdfg.add_array(output_buffer_name, (vector_length, ), - desc_outer.dtype.base_type, - storage=dace.StorageType.FPGA_Registers, - transient=True) - output_buffer = compute_state.add_access(output_buffer_name) - - # If vectorized, we need to pass through the unrolled scope - if vector_length > 1: - compute_state.add_memlet_path(compute_tasklet, - compute_unroll_exit, - output_buffer, - src_conn=field_name + "_inner_out", - memlet=dace.Memlet(f"{output_buffer_name}[i_unroll]")) - else: - compute_state.add_memlet_path(compute_tasklet, - output_buffer, - src_conn=field_name + "_inner_out", - memlet=dace.Memlet(f"{output_buffer_name}[0]")), - - # Final memlet to the output - compute_state.add_memlet_path(output_buffer, - write_node_inner, - memlet=dace.Memlet(f"{write_node_inner.data}")), - - # Conditional write tasklet - sdfg.add_scalar(f"{field_name}_result", - desc_outer.dtype, - storage=dace.StorageType.FPGA_Local, - transient=True) - output_access = state.add_access(f"{field_name}_result") - state.add_memlet_path(nested_sdfg_tasklet, - output_access, - src_conn=data_name_inner, - memlet=dace.Memlet(f"{field_name}_result")) - output_tasklet = state.add_tasklet(f"{field_name}_conditional_write", {f"_{field_name}_result"}, - {f"_{data_name_inner}"}, - (write_cond + f"_{data_name_inner} = _{field_name}_result")) - state.add_memlet_path(output_access, - output_tasklet, - dst_conn=f"_{field_name}_result", - memlet=dace.Memlet(f"{field_name}_result")) - write_node_outer = state.add_write(data_name_outer) - if isinstance(desc_outer, dt.Stream): - subset = "0" - else: - subset = array_index - state.add_memlet_path(output_tasklet, - exit, - write_node_outer, - src_conn=f"_{data_name_inner}", - memlet=dace.Memlet(f"{write_node_outer.data}[{subset}]", dynamic=True)), - - return sdfg diff --git a/dace/libraries/stencil/stencil.py b/dace/libraries/stencil/stencil.py index f845af2ef2..bff5d89e06 100644 --- a/dace/libraries/stencil/stencil.py +++ b/dace/libraries/stencil/stencil.py @@ -6,8 +6,6 @@ import dace.library from .cpu import ExpandStencilCPU -from .intel_fpga import ExpandStencilIntelFPGA -# from .xilinx import ExpandStencilXilinx @dace.library.node @@ -48,8 +46,6 @@ class Stencil(dace.library.LibraryNode): implementations = { "pure": ExpandStencilCPU, - "intel_fpga": ExpandStencilIntelFPGA, - # "xilinx": ExpandStencilXilinx } default_implementation = "pure" diff --git a/dace/libraries/torch/__init__.py b/dace/libraries/torch/__init__.py new file mode 100644 index 0000000000..8cc16a958d --- /dev/null +++ b/dace/libraries/torch/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +DaCe PyTorch Integration Library. + +This module provides integration between DaCe (Data-Centric Parallel Programming) +and PyTorch, enabling: +- Compilation of PyTorch operations to optimized DaCe SDFGs +- Interoperability between PyTorch tensors and DaCe arrays +- Support for both CPU (PyTorch) and GPU (PyTorchGPU) execution +- DLPack-based zero-copy tensor sharing + +The main exports are environment classes that define the PyTorch runtime +dependencies and configuration for code generation. +""" + +try: + from .environments import PyTorch, PyTorchGPU + __all__ = ["PyTorch", "PyTorchGPU"] +except ImportError: + # PyTorch not available + PyTorch = None + PyTorchGPU = None + __all__ = [] diff --git a/dace/libraries/torch/dispatchers/__init__.py b/dace/libraries/torch/dispatchers/__init__.py new file mode 100644 index 0000000000..33b5aeecee --- /dev/null +++ b/dace/libraries/torch/dispatchers/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +PyTorch Dispatchers for DaCe Modules. + +This module provides different dispatcher implementations for executing DaCe SDFGs +from PyTorch. Dispatchers handle: +- Compiling SDFGs to native code +- Initializing runtime state and memory +- Converting between PyTorch tensors and DaCe arrays +- Calling forward and backward SDFG functions +- Managing the integration with PyTorch's autograd system + +Available dispatchers: +- CTypes dispatcher: Uses ctypes for direct C function calls +- C++ PyTorch extension: Registers as a native PyTorch extension with custom autograd +""" + +from .common import DaceTorchFunction +from .cpp_torch_extension import register_and_compile_torch_extension +from .ctypes_module import get_ctypes_dispatcher + +__all__ = ["DaceTorchFunction", "register_and_compile_torch_extension", "get_ctypes_dispatcher"] diff --git a/dace/libraries/torch/dispatchers/common.py b/dace/libraries/torch/dispatchers/common.py new file mode 100644 index 0000000000..80ee1d9f28 --- /dev/null +++ b/dace/libraries/torch/dispatchers/common.py @@ -0,0 +1,112 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Common utilities for PyTorch-DaCe dispatchers. + +This module provides shared functionality for different dispatcher implementations, +including: +- SDFG compilation and initialization +- Argument list extraction and processing +- State management for forward and backward passes +- Integration with PyTorch's autograd system +""" + +import dataclasses +from typing import Callable, List, Tuple, Union + +import dace +import torch +from dace.codegen.compiled_sdfg import CompiledSDFG +from dace.libraries.onnx.converters import clean_onnx_name +from dace.frontend.ml.onnx.importer import create_output_array + + +@dataclasses.dataclass +class DaceTorchFunction: + """ + An initialized, callable function for a DaceModule and its associated state. + + This dataclass encapsulates a compiled DaCe module with its runtime state, + providing a callable interface for PyTorch integration. + + Attributes: + function: The PyTorch callable function that executes the SDFG. + compiled_sdfgs: The compiled SDFGs holding their runtime states. + ptr: Pointers to the initialized SDFG state handles. These must be + passed as the first arguments to the function. + """ + function: Callable + compiled_sdfgs: List[CompiledSDFG] + ptr: List[torch.Tensor] + + +def get_arglist(module: 'dace.frontend.ml.torch.DaceModule') -> Tuple[List[str], List[str]]: + """Get the list of forward-pass argument names for a module. + + :param module: The DaCe module to extract argument names from. + :return: A tuple of (input_names, output_names) where each is a list of cleaned + argument names suitable for use in generated code. + """ + + arglist = [clean_onnx_name(input_name) for input_name in module.dace_model.inputs] + outputs = [clean_onnx_name(output_name) for output_name in module.dace_model.outputs] + return arglist, outputs + + +def compile_and_init_sdfgs( + module: 'dace.frontend.ml.torch.DaceModule', dummy_inputs +) -> Union[Tuple[CompiledSDFG, torch.Tensor], Tuple[CompiledSDFG, torch.Tensor, CompiledSDFG, torch.Tensor]]: + """Compile SDFGs and initialize them using the provided dummy inputs. + + This function compiles the forward pass SDFG and optionally the backward pass + SDFG if the module has automatic differentiation enabled. It initializes both + SDFGs with the appropriate tensors and parameters. + + :param module: The DaCe module to compile SDFGs for. + :param dummy_inputs: The dummy inputs to use for shape inference and initialization. + :return: If the module has no backward pass: (compiled_sdfg, state_ptr). + If the module has a backward pass: (compiled_fwd_sdfg, fwd_state_ptr, + compiled_bwd_sdfg, bwd_state_ptr). Where state_ptr is a torch.Tensor + containing the pointer to the SDFG state. + """ + + compiled: CompiledSDFG = module.dace_model.compile_and_init() + # Construct the arguments and initialize the SDFG + args = tuple(dummy_inputs) + module._call_params() + args = tuple(arg.detach() for arg in args) + inputs, symbols, outputs = module.dace_model._call_args(args=args, kwargs={}) + + if module.backward: + forwarded_transients = { + name: + create_output_array(symbols, desc, use_torch=True, zeros=True) + if name not in module.dace_model.initialized_parameters else module.dace_model.initialized_parameters[name] + for name, desc in module._ad_inp_arrs.items() + } + else: + forwarded_transients = {} + + all_kwargs = {**inputs, **outputs, **symbols, **forwarded_transients, **module.dace_model.initialized_parameters} + + compiled.initialize(**all_kwargs) + for _, hook in module.post_compile_hooks.items(): + hook(compiled) + handle_ptr = torch.tensor([compiled._libhandle.value]).squeeze(0) + + if module.backward: + # Compile and initialize the backward_sdfg + compiled_bwd: CompiledSDFG = module.backward_sdfg.compile() + + required_grads = { + bwd_name: create_output_array(symbols, compiled_bwd.sdfg.arrays[bwd_name], use_torch=True, zeros=True) + for _, bwd_name in module._ad_result.required_grad_names.items() + } + given_grads = { + bwd_name: create_output_array(symbols, compiled_bwd.sdfg.arrays[bwd_name], use_torch=True, zeros=True) + for _, bwd_name in module._ad_result.given_grad_names.items() + } + + compiled_bwd.initialize(**required_grads, **given_grads, **forwarded_transients) + bwd_handle_ptr = torch.tensor([compiled_bwd._libhandle.value]).squeeze(0) + return compiled, handle_ptr, compiled_bwd, bwd_handle_ptr + else: + return compiled, handle_ptr diff --git a/dace/libraries/torch/dispatchers/cpp_torch_extension.py b/dace/libraries/torch/dispatchers/cpp_torch_extension.py new file mode 100644 index 0000000000..f7449df2f4 --- /dev/null +++ b/dace/libraries/torch/dispatchers/cpp_torch_extension.py @@ -0,0 +1,699 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +"""Code generation for PyTorch C++ dispatched operators.""" +import copy +import dataclasses +from distutils import sysconfig +import hashlib +import itertools +import operator +import os +import sys +from typing import List, Tuple, Callable, Optional, Dict, Union + +import dace.library +import numpy as np +import torch +from torch.utils.cpp_extension import load as torch_load +import dace +from dace import config, dtypes as dt, data +from dace.codegen import targets, compiler +from dace.codegen.codeobject import CodeObject +from dace.codegen.compiled_sdfg import CompiledSDFG +from dace.codegen.prettycode import CodeIOStream +from dace.codegen.common import sym2cpp, platform_library_name + +from dace.autodiff import BackwardResult +from dace.libraries.torch.environments import PyTorch + +from dace.libraries.torch.dispatchers.common import DaceTorchFunction, compile_and_init_sdfgs, get_arglist + +_REPLACED_CTYPES = {dace.int64: "int64_t", dace.uint64: "uint64_t", dace.float16: "at::Half"} + + +def torch_ctype(dtype: dace.typeclass) -> str: + """Convert a DaCe type to the corresponding PyTorch C++ type string. + + :param dtype: The DaCe typeclass to convert. + :return: The corresponding C++ type string for PyTorch. + """ + if isinstance(dtype, dace.pointer): + # assuming pointers are 64 bit + ctype = "int64_t" + elif dtype in _REPLACED_CTYPES: + ctype = _REPLACED_CTYPES[dtype] + else: + ctype = dtype.ctype + return ctype + + +_TYPECLASS_TO_TORCH_DTYPE_STR = { + dt.bool: "kBool", + dt.int8: "kInt8", + dt.uint8: "kUInt8", + dt.int16: "kInt16", + dt.int32: "kInt32", + dt.int64: "kInt64", + dt.float16: "kFloat16", + dt.float32: "kFloat32", + dt.float64: "kFloat64", + dt.complex64: "kComplexFloat", + dt.complex128: "kComplexDouble", +} + + +def typeclass_to_torch_cpp_type(type: dace.typeclass) -> str: + """Convert a DaCe typeclass to PyTorch C++ tensor type string. + + :param type: The DaCe typeclass to convert. + :return: The corresponding PyTorch tensor type string (e.g., 'kFloat32'). + """ + if isinstance(type, dace.pointer): + # assuming pointers are 64 bit + return "kInt64" + else: + return _TYPECLASS_TO_TORCH_DTYPE_STR[type] + + +def tensor_init_for_desc(name: str, desc: data.Data, clean_weights: Dict[str, torch.Tensor], zeros=True) -> str: + """Emit the initialization code for a descriptor. + + :param name: The name of the tensor. + :param desc: The data descriptor. + :param clean_weights: Dictionary of constant weights. + :param zeros: Whether to initialize with zeros (True) or empty (False). + :return: C++ code string for tensor initialization. + """ + + # Check if name is in clean_weights + if name in clean_weights: + # Get the tensor from clean_weights + weight_tensor = clean_weights[name] + + # Convert the tensor to a C++ initializer list format + # Flatten the tensor and convert to list + values = weight_tensor.flatten().tolist() + + # Format the values based on the data type + def format_value(v, dtype): + if dtype in [dt.float32, dt.float16]: + return f'{v}f' + elif dtype == dt.float64: + return str(v) + elif dtype in [dt.int8, dt.int16, dt.int32, dt.int64, dt.uint8]: + return str(int(v)) + elif dtype == dt.bool: + return str(v).lower() + else: + return str(v) + + # Format the values as a C++ initializer list + values_str = ', '.join(format_value(v, desc.dtype) for v in values) + + return f"""\ + Tensor {name} = torch::from_blob( + new float[{len(values)}]{{{values_str}}}, + {{{', '.join(str(s) for s in desc.shape)}}}, + torch::TensorOptions() + .dtype(torch::{typeclass_to_torch_cpp_type(desc.dtype)}) + .device(torch::{'kCUDA' if desc.storage in dace.dtypes.GPU_STORAGES else 'kCPU'}) + .layout(torch::kStrided)).clone(); + """ + else: + # Initialize with zeros or empty + return f"""\ + Tensor {name} = torch::{'zeros' if zeros else 'empty'}( + {{{', '.join(str(s) for s in desc.shape)}}}, + torch::TensorOptions() + .dtype(torch::{typeclass_to_torch_cpp_type(desc.dtype)}) + .device(torch::{'kCUDA' if desc.storage in dace.dtypes.GPU_STORAGES else 'kCPU'}) + .layout(torch::kStrided)); + """ + + +def initialize_outputs_code(module: 'dace.frontend.ml.torch.DaceModule', output_names: List[str], + clean_weights: Dict[str, torch.Tensor]) -> str: + """Generate the code that initializes the output tensors. + + :param module: The module + :param output_names: The output names of the SDFG. + :param clean_weights: Dictionary of constant weights + :return: The code + """ + arglist = module.sdfg.arglist() + code = "" + for name in sorted(output_names): + code += tensor_init_for_desc(name, arglist[name], clean_weights) + + return code + + +def argument_codegen(sdfg: dace.SDFG, + clean_weights: Dict[str, torch.Tensor], + input_names: List[str], + output_names: List[str], + guard_contiguous: Optional[List[str]] = None) -> Tuple[str, str, str]: + """Generate the code that grabs the pointers of inputs and outputs. + + The names of the tensors will match the SDFG tensor names. Tensors that are not created by us (i.e. inputs) + should be named {sdfg_name}_ first, and then .contiguous() will be called on them to yield the tensor that we + require. This is the case for all tensors in ``guard_contiguous``. + + :param sdfg: The SDFG to generate code for + :param clean_weights: The constant weights of the SDFG. + :param input_names: Names of inputs to the torch function. + :param output_names: Names of outputs to the torch function. + :param guard_contiguous: A subset of input_names to call .contiguous on. If None, all input names will be + guarded. + :return: The code for initializing the argument, the SDFG arguments in order, and the init call arguments + """ + arglist = sdfg.arglist() + + guard_contiguous = set(guard_contiguous or input_names) + + assert set(input_names).issubset(arglist.keys()), \ + f"Input names {set(input_names).difference(arglist.keys())} are not SDFG arguments {arglist.keys()}" + + # Initialize the inputs and outputs + ptr_init_code = "\n// Setup input and output pointers\n" + for name in sorted(input_names): + tctype = torch_ctype(arglist[name].dtype) + dctype = arglist[name].dtype + + if isinstance(arglist[name], data.Array) or dt.can_access(dt.ScheduleType.GPU_Device, arglist[name].storage): + if name in guard_contiguous: + if config.Config.get_bool('debugprint'): + ptr_init_code += f""" + if (!{name}_.is_contiguous()) {{ + fprintf(stderr, "{name} was not contiguous!"); + }} + """ + ptr_init_code += '\n' + f"Tensor {name} = {name}_.contiguous();" + + ptr_init_code += '\n' + f"{dctype} *{name}_ptr = reinterpret_cast<{dctype}*>({name}.data_ptr<{tctype}>());" + + elif isinstance(arglist[name], data.Scalar): + if name in guard_contiguous: + ptr_init_code += '\n' + f"{dctype} {name}_ptr = static_cast<{dctype}>({name}_.item().to<{tctype}>());" + else: + ptr_init_code += '\n' + f"{dctype} {name}_ptr = static_cast<{dctype}>({name}.item().to<{tctype}>());" + else: + raise ValueError(f"Unsupported data type {type(arglist[name])} for descriptor {name}") + + ptr_init_code += '\n' + + # Outputs and backward arrays + ptr_init_code += '\n'.join( + f"{arglist[name].dtype.ctype} *{name}_ptr = reinterpret_cast<{arglist[name].dtype.ctype}*>" + f"({name}.data_ptr<{torch_ctype(arglist[name].dtype)}>());" for name in sorted(output_names)) + ptr_init_code += "\n// Setup constant arguments\n" + + all_access_nodes = set() + for state in sdfg.nodes(): + all_access_nodes |= set(n.data for n in state.data_nodes()) + + # Initialize all remaining parameters + remaining = set(arglist).difference(itertools.chain(input_names, output_names)) + for name in sorted(remaining): + # Remaining args must be constants + if name not in clean_weights: + raise ValueError(f"Cannot generate PyTorch module C++ code: SDFG argument {name} is not an input or output" + f" of the PyTorch Module, and not a constant.") + + value = clean_weights[name] + ptr_init_code += f"{constant_initializer_code(name, arglist[name], value)}\n" + + arguments = ", ".join(f"{n}_ptr" for n in arglist) + init_arguments = ", ".join(f"{n}_ptr" for n, desc in arglist.items() if isinstance(desc, data.Scalar)) + + return ptr_init_code, arguments, init_arguments + + +def item_to_cpp_literal(item) -> str: + """Convert a numpy item to a C++ literal string. + + :param item: The numpy item to convert. + :return: The C++ literal representation as a string. + """ + dtype = str(item.dtype) + if np.isneginf(item): + return "-std::numeric_limits::infinity()" + if np.isposinf(item): + return "std::numeric_limits::infinity()" + if dtype == "float32": + return f"{item}f" + elif dtype == "bool": + return f"{str(item).lower()}" + elif dtype == "int64": + return f"{item}l" + elif dtype == "float16": + ctype = dace.dtypes._CTYPES[item.dtype.type] + return f"(({ctype}){item})" + elif dtype in ["float64", "int32", "int16", "int8"]: + return str(item) + else: + raise ValueError(f"Unsupported tensor type {item.dtype}") + + +def constant_initializer_code(name: str, desc: data.Data, value) -> str: + """Generate C++ code for initializing a constant value. + + :param name: The name of the constant. + :param desc: The data descriptor. + :param value: The constant value. + :return: C++ code string for constant initialization. + """ + gpu_storage = dt.can_access(dt.ScheduleType.GPU_Device, desc.storage) + gpu_storage = False + if desc.total_size == 0: + return f"{desc.dtype.ctype} *{name}_ptr = nullptr;" + elif isinstance(desc, data.Array) or gpu_storage: + numpyval = value.cpu().numpy() + if len(numpyval.shape) == 0: + numpyval = numpyval.reshape((1, )) + iterator = np.nditer(numpyval, order="C") + gpu_copy_code = f""" + Tensor {name} = torch::from_blob({name}_ptr_cpu, {{{', '.join(sym2cpp(s) for s in desc.shape)}}}, + {{{', '.join(sym2cpp(s) for s in desc.strides)}}}, torch::{typeclass_to_torch_cpp_type(desc.dtype)}) + .to(torch::kCUDA); + {desc.dtype.ctype} *{name}_ptr = reinterpret_cast<{desc.dtype.ctype}*>({name}.data_ptr<{torch_ctype(desc.dtype)}>()); + """ + return f""" + {desc.dtype.ctype} {name}_ptr{'_cpu' if gpu_storage else ''}[{sym2cpp(desc.total_size)}] = + {{{', '.join(item_to_cpp_literal(e) for e in iterator)}}}; + {gpu_copy_code if gpu_storage else ""} + """ + elif isinstance(desc, data.Scalar): + if str(value.item()) == "-inf": + return f"{desc.dtype.ctype} {name}_ptr = -std::numeric_limits<{desc.dtype.ctype}>::infinity();" + elif str(value.item()) == "inf": + return f"{desc.dtype.ctype} {name}_ptr = std::numeric_limits<{desc.dtype.ctype}>::infinity();" + if desc.dtype.ctype == "bool": + # Special case for bools + bool_str = "true" if value.item() else "false" + return f"{desc.dtype.ctype} {name}_ptr = {bool_str};" + return f"{desc.dtype.ctype} {name}_ptr = {str(value.item())};" + else: + raise ValueError("Unsupported data descriptor") + + +def return_type_str(outputs: List[str]) -> str: + """Generate the return type string for the given outputs. + + :param outputs: List of output names. + :return: The C++ return type string. + """ + return f"""{"Tensor" if len(outputs) == 1 else f"variable_list"}""" + + +def save_non_inputs_outputs(names: List[str]): + """Generate code to save non-input/output tensors for backward pass. + + :param names: List of tensor names to save. + :return: C++ code string for saving tensors. + """ + return "\n".join(f'ctx->saved_data["{n}"] = {n};' for n in names) + + +def recover_saved_inputs_outputs(saved_inputs_outputs: List[str], other_saved: List[str]): + """Generate code to recover saved tensors in backward pass. + + :param saved_inputs_outputs: List of saved input/output tensor names. + :param other_saved: List of other saved tensor names. + :return: C++ code string for recovering saved tensors. + """ + code = "" + if saved_inputs_outputs: + code += "auto saved = ctx->get_saved_variables();\n" + for i, n in enumerate(saved_inputs_outputs): + code += f"\nauto {n} = saved[{i}];" + + for n in other_saved: + code += f'\nauto {n} = ctx->saved_data["{n}"].toTensor();' + + return code + + +def setup_grad_values(backward_result: BackwardResult, sdfg: dace.SDFG, outputs: List[str], + clean_weights: Dict[str, torch.Tensor]) -> str: + """Generate code to setup gradient values for backward pass. + + :param backward_result: The backward pass result containing gradient information. + :param sdfg: The SDFG. + :param outputs: List of output names. + :param clean_weights: Dictionary of constant weights. + :return: C++ code string for gradient setup. + """ + code = "// input grads" + for param_name, grad_name in sorted(backward_result.required_grad_names.items()): + zero_init = backward_result.zero_init.get(param_name, True) + code += "\n" + tensor_init_for_desc(grad_name, sdfg.arrays[grad_name], clean_weights, zeros=zero_init) + + code += "// output grads" + for i, o in enumerate(outputs): + grad_name = backward_result.given_grad_names[o] + code += f'\nauto {grad_name}_ = grad_outputs[{i}];' + + return code + + +def code_for_backward_function(module: 'dace.frontend.ml.torch.DaceModule', forward_sdfg: dace.SDFG, + backward_sdfg: dace.SDFG, backward_result: BackwardResult, + forwarded_arrays: Dict[str, data.Data]) -> str: + """Generate C++ code for a differentiable PyTorch function. + + :param module: The DaCe module. + :param forward_sdfg: The forward SDFG. + :param backward_sdfg: The backward SDFG. + :param backward_result: The backward pass result. + :param forwarded_arrays: Arrays forwarded from forward to backward pass. + :return: Complete C++ code string for the differentiable function. + """ + + inputs, outputs = get_arglist(module) + sdfg_name = forward_sdfg.name + + ret_str = return_type_str(outputs) + + outputs_with_forwarded_outputs = copy.deepcopy(outputs) + outputs_with_forwarded_outputs.extend(n for n in forwarded_arrays if n not in inputs and n not in outputs) + + fwd_ptr_init_code, fwd_sdfg_call_arguments, _ = argument_codegen(forward_sdfg, module.dace_model.clean_weights, + inputs, outputs_with_forwarded_outputs) + + # Inputs are given_grads + forwarded_outputs + bwd_inputs = list(backward_result.given_grad_names.values()) + list(forwarded_arrays) + + # Outputs are required grads + bwd_outputs = list(backward_result.required_grad_names.values()) + + bwd_ptr_init_code, bwd_sdfg_call_arguments, _ = argument_codegen(backward_sdfg, + module.dace_model.clean_weights, + bwd_inputs, + bwd_outputs, + guard_contiguous=list( + backward_result.given_grad_names.values())) + + # Saved inputs/outputs + saved_io_for_backward = [n for n in forwarded_arrays if n in inputs or n in outputs] + other_saved_for_backward = [n for n in forwarded_arrays if n not in inputs and n not in outputs] + return f""" +{get_header(forward_sdfg, backward_sdfg, inputs, outputs, module.use_cuda)} +class {sdfg_name}Function : public torch::autograd::Function<{sdfg_name}Function> {{ + public: + static + {ret_str} + forward( + AutogradContext *ctx, + int64_t fwd_handle_ptr, int64_t bwd_handle_ptr, {", ".join(f"const Tensor& {name}_" for name in inputs)}) {{ + + at::AutoDispatchBelowADInplaceOrView g; + + // initialize outputs + {initialize_outputs_code(module, outputs_with_forwarded_outputs, module.dace_model.clean_weights)} + + {fwd_ptr_init_code} + + // get SDFG state handle + {forward_sdfg.name}Handle_t handle = reinterpret_cast<{forward_sdfg.name}Handle_t>(fwd_handle_ptr); + + + // call SDFG + __program_{forward_sdfg.name}(handle, {fwd_sdfg_call_arguments}); + + // save inputs/outputs for backward + { + f"ctx->save_for_backward({{{', '.join(f'{n}' for n in saved_io_for_backward)}}});" + if saved_io_for_backward else "" + } + + // save non-inputs/outputs + {save_non_inputs_outputs(other_saved_for_backward)} + + // save bwd handle + ctx->saved_data["bwd_handle"] = bwd_handle_ptr; + + // return to torch + return {f"{outputs[0]}" if len(outputs) == 1 + else f"{{{', '.join(o for o in outputs)}}}"}; + }} + + static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {{ + // recover bwd_handle_ptr + int64_t bwd_handle_ptr = ctx->saved_data.find("bwd_handle")->second.toInt(); + + // recover saved values + {recover_saved_inputs_outputs(saved_io_for_backward, other_saved_for_backward)} + + // create grad values + // NOTE, it might make sense take these from .grad() + {setup_grad_values(backward_result, backward_sdfg, outputs, module.dace_model.clean_weights)} + + {bwd_ptr_init_code} + + // get SDFG state handle + {backward_sdfg.name}Handle_t handle = reinterpret_cast<{backward_sdfg.name}Handle_t>(bwd_handle_ptr); + + // call bwd SDFG + __program_{backward_sdfg.name}(handle, {bwd_sdfg_call_arguments}); + + // return calculated grads in correct order + // first two grads are None (these are the grads for the handle ptrs) + return {{ + Tensor(), Tensor(), {', '.join(backward_result.required_grad_names[i] if i in backward_result.required_grad_names else 'Tensor()' for i in inputs )} + }}; +}} +}}; + +{ret_str} +{sdfg_name}_autograd(int64_t handle_ptr, int64_t bwd_handle_ptr, {",".join(f"const Tensor& {name}_" for name in inputs)}) {{ +return {sdfg_name}Function::apply( +handle_ptr, bwd_handle_ptr, {", ".join(f"{name}_" for name in inputs)} +); +}} + +TORCH_LIBRARY_IMPL(dace_{sdfg_name}, Autograd{'CUDA' if module.use_cuda else 'CPU'}, m) {{ +m.impl("{sdfg_name}", {sdfg_name}_autograd); +}} +""" + + +def code_for_module(module: 'dace.frontend.ml.torch.DaceModule', compiled_sdfg: CompiledSDFG) -> str: + """Generate the code for an operator that calls the SDFGs in the module. + + :param module: The module. + :param compiled_sdfg: The compiled SDFG. + """ + + inputs, outputs = get_arglist(module) + sdfg_name = compiled_sdfg.sdfg.name + + ret_str = return_type_str(outputs) + ptr_init_code, sdfg_call_arguments, init_arguments = argument_codegen(compiled_sdfg.sdfg, + module.dace_model.clean_weights, inputs, + outputs) + return f""" +{get_header(compiled_sdfg.sdfg, None, inputs, outputs, module.use_cuda)} + +// Function definition +{ret_str} +{sdfg_name}(int64_t handle_ptr, {",".join(f"const Tensor& {name}_" for name in inputs)}) {{ + + // Initialize outputs + {initialize_outputs_code(module, outputs, module.dace_model.clean_weights)} + + {ptr_init_code} + + // Get SDFG state handle + {sdfg_name}Handle_t handle = reinterpret_cast<{sdfg_name}Handle_t>(handle_ptr); + + // Call SDFG + __program_{sdfg_name}(handle, {sdfg_call_arguments}); + + // Return to torch + return {f"{outputs[0]}" if len(outputs) == 1 + else f"{{{', '.join(o for o in outputs)}}}"}; +}} + +TORCH_LIBRARY_IMPL(dace_{sdfg_name}, {'CUDA' if module.use_cuda else 'CPU'}, m) {{ + m.impl("{sdfg_name}", {sdfg_name}); +}} + """ + + +def get_header(fwd_sdfg: dace.SDFG, bwd_sdfg: Optional[dace.SDFG], inputs, outputs, use_cuda: bool) -> str: + """Generate the C++ header code for the PyTorch extension. + + :param fwd_sdfg: The forward SDFG. + :param bwd_sdfg: The backward SDFG (optional). + :param inputs: List of input names. + :param outputs: List of output names. + :param use_cuda: Whether CUDA is used. + :return: C++ header code string. + """ + return f""" +#include +#include +#include "{fwd_sdfg.name}.h" +{"" if bwd_sdfg is None else f'#include "{bwd_sdfg.name}.h"'} +using torch::Tensor; +using torch::DeviceType; +using torch::autograd::tensor_list; +using torch::autograd::variable_list; +using torch::autograd::AutogradContext; + +TORCH_LIBRARY(dace_{fwd_sdfg.name}, m) {{ + m.def("{fwd_sdfg.name}(int handle_ptr,{"int bwd_handle_ptr," if bwd_sdfg else ""} {", ".join('Tensor ' + arg for arg in inputs)}) -> {'Tensor' if len(outputs) == 1 else 'Tensor[]'}"); +}} +""" + + +def _torch_ext_root() -> str: + """Resolve the torch extensions root without using private PyTorch APIs.""" + env = os.environ.get("TORCH_EXTENSIONS_DIR") + if env: + return env + + return os.path.join(os.path.expanduser("~"), ".cache", "torch_extensions") + + +def register_and_compile_torch_extension(module: 'dace.frontend.ml.torch.DaceModule', + dummy_inputs) -> DaceTorchFunction: + """Get a torch callable for the module. This will compile the SDFG, compile a PyTorch C++ operator, register it + with PyTorch and return the function that calls it. + + This function handles code generation for both the forward and backward pass. + + :param module: The module. + :param dummy_inputs: Dummy inputs to initialize the model with. + :return: The callable function for the SDFG. + """ + + # Build the SDFG + # Set all states to not-sync + for state in module.sdfg.nodes(): + state.nosync = True + + environments = { + PyTorch.full_class_path(), + } + if module.backward: + compiled, handle_ptr, compiled_bwd, bwd_handle_ptr = compile_and_init_sdfgs(module, dummy_inputs) + compiled_sdfgs = [compiled, compiled_bwd] if compiled_bwd is not None else [compiled] + ptrs = [handle_ptr, bwd_handle_ptr] if compiled_bwd is not None else [handle_ptr] + if compiled_bwd is not None: + environments.add(get_env_for_sdfg(compiled_bwd).full_class_path()) + bwd_sdfg = compiled_bwd.sdfg + code = code_for_backward_function(module, compiled.sdfg, bwd_sdfg, module._ad_result, module._ad_inp_arrs) + else: + bwd_sdfg = module.backward_sdfg + code = code_for_module(module, compiled) + else: + compiled, handle_ptr = compile_and_init_sdfgs(module, dummy_inputs) + compiled_sdfgs = [compiled] + ptrs = [handle_ptr] + code = code_for_module(module, compiled) + + environments.add(get_env_for_sdfg(compiled).full_class_path()) + code = indent_code(code) + + # ---------- Build the PyTorch module ---------- + base_libname = f"torch_{compiled.sdfg.name}" + program = CodeObject(base_libname, + code, + "cpp", + targets.cpu.CPUCodeGen, + f"Torch{module.sdfg_name}", + environments=environments) + + torch_module_build_path = os.path.join('.dacecache', base_libname) + parts = os.path.normpath(compiled.filename).split(os.sep) + sdfg_folder_name = parts[parts.index('.dacecache') + 1] + + # Treat the case where a hash is added to the SDFG folder dir + backward_sdfg_folder_name = f"{compiled.sdfg.name}_backward_{sdfg_folder_name.removeprefix(compiled.sdfg.name + '_')}" if sdfg_folder_name != compiled.sdfg.name else f"{compiled.sdfg.name}_backward" + compiler.generate_program_folder(None, [program], torch_module_build_path) + + include_path = os.path.abspath(os.path.join('.dacecache', sdfg_folder_name, "include")) + include_path_bwd = os.path.abspath(os.path.join('.dacecache', backward_sdfg_folder_name, "include")) + dace_include_path = os.path.abspath(os.path.join(os.path.dirname(dace.__file__), "runtime", "include")) + dace_include_onnx = os.path.abspath(os.path.join(os.path.dirname(dace.__file__), "libraries", "onnx", "include")) + dace_include_blas = os.path.abspath(os.path.join(os.path.dirname(dace.__file__), "libraries", "blas", "include")) + + code_path = os.path.join('.dacecache', sdfg_folder_name, "src", "cpu", f"{compiled.sdfg.name}.cpp") + code_path_bwd = os.path.join('.dacecache', backward_sdfg_folder_name, "src", "cpu", + f"{compiled.sdfg.name}_backward.cpp") + torch_code_path = os.path.join('.dacecache', base_libname, "src", "cpu", f"{base_libname}.cpp") + + sources = [p for p in [code_path, torch_code_path, code_path_bwd] if os.path.exists(p)] + + pid = os.getpid() + salt = hashlib.sha1(("".join(sources)).encode("utf-8")).hexdigest()[:8] + base_libname = f"torch_{compiled.sdfg.name}" + unique_name = f"{base_libname}_p{pid}_{salt}" + + build_root = _torch_ext_root() # <- uses our helper + unique_build_dir = os.path.join(build_root, unique_name) + os.makedirs(unique_build_dir, exist_ok=True) + + # We pass unique name + unique build directory to avoid FileBaton contention + torch_load( + name=unique_name, + sources=sources, + build_directory=unique_build_dir, + extra_cflags=["-g"], + extra_include_paths=[ + p for p in { + include_path, + include_path_bwd if os.path.exists(include_path_bwd) else None, + dace_include_path, + dace_include_blas, + dace_include_onnx, + } if p + ], + is_python_module=False, + ) + + torch_function = operator.attrgetter(f"dace_{compiled.sdfg.name}.{compiled.sdfg.name}")(torch.ops) + + return DaceTorchFunction(function=torch_function, compiled_sdfgs=compiled_sdfgs, ptr=ptrs) + + +def get_env_for_sdfg(compiled: CompiledSDFG): + """Create an environment for the given compiled SDFG. + + :param compiled: The compiled SDFG. + :return: The environment class for the SDFG. + """ + sdfg_build_path = os.path.abspath(compiled.sdfg.build_folder) + + class SDFGEnvironment: + """Environment for the SDFG.""" + + cmake_minimum_version = None + cmake_packages = [] + cmake_variables = {} + cmake_includes = [os.path.join(sdfg_build_path, "include")] + cmake_compile_flags = [] + cmake_link_flags = [] + cmake_files = [] + cmake_libraries = [os.path.join(sdfg_build_path, "build", platform_library_name(compiled.sdfg.name))] + state_fields = [] + dependencies = [] + headers = [] + init_code = "" + finalize_code = "" + + SDFGEnvironment.__name__ = compiled.sdfg.name + dace.library.environment(SDFGEnvironment) + return SDFGEnvironment + + +def indent_code(code: str) -> str: + """Indent the given code string properly. + + :param code: The code string to indent. + :return: The indented code string. + """ + stream = CodeIOStream() + stream.write(code) + return stream.getvalue() diff --git a/dace/libraries/torch/dispatchers/ctypes_module.py b/dace/libraries/torch/dispatchers/ctypes_module.py new file mode 100644 index 0000000000..fb73179bc1 --- /dev/null +++ b/dace/libraries/torch/dispatchers/ctypes_module.py @@ -0,0 +1,222 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +A torch python autograd function that calls the SDFG using ctypes. + +This can be as an alternative to the C++ registration for large neural nets to +get around the 64 parameter limit of torch's dispatcher. +""" +import copy +import itertools +from typing import List, Dict, Tuple + +from dace import data +import torch +from dace.codegen.compiled_sdfg import CompiledSDFG + +import dace +from dace.autodiff import BackwardResult +from dace.frontend.ml.onnx.importer import create_output_array +from dace.libraries.torch.dispatchers import DaceTorchFunction +from dace.libraries.torch.dispatchers.common import compile_and_init_sdfgs, \ + get_arglist + + +def init_remaining_parameters(module, fwd_arglist, input_names, output_names): + """Initialize remaining parameters that are not inputs or outputs. + + :param module: The DaCe module containing the weights. + :param fwd_arglist: Forward pass argument list. + :param input_names: Names of input tensors. + :param output_names: Names of output tensors. + :return: Dictionary of constant parameters. + :raises ValueError: If a parameter is neither an input/output nor a constant. + """ + # initialize all remaining parameters + remaining = set(fwd_arglist).difference(itertools.chain(input_names, output_names)) + constants = {} + for name in remaining: + # remaining arguments must be constant + if name not in module.dace_model.clean_weights: + raise ValueError(f"Cannot generate ctypes dispatcher: SDFG argument {name} is " + f"not an input or output of the PyTorch Module, and not a" + f" constant.") + constants[name] = module.dace_model.clean_weights[name] + if fwd_arglist[name].storage in dace.dtypes.GPU_STORAGES: + constants[name] = constants[name].cuda() + return constants + + +def callable_for_fwd_module(module: 'dace.frontend.ml.torch.DaceModule', forward_compiled: CompiledSDFG): + """Create a callable for forward pass execution. + + :param module: The DaCe module containing the model. + :param forward_compiled: Compiled SDFG for forward pass. + :return: Function that executes the forward pass. + """ + assert forward_compiled._initialized + + fwd_arglist = forward_compiled.sdfg.arglist() + + input_names, output_names = get_arglist(module) + + constants = init_remaining_parameters(module, fwd_arglist, input_names, output_names) + + def forward(*inputs): + kwargs = {} + + # set the inputs + for i, input_name in enumerate(input_names): + kwargs[input_name] = inputs[i].contiguous() + + # initialize the outputs + for name in output_names: + output_desc = forward_compiled.sdfg.arrays[name] + kwargs[name] = create_output_array( + {}, output_desc, use_torch=True, zeros=False + ) if name not in module.dace_model.initialized_parameters else module.dace_model.initialized_parameters[name] + + # call the SDFG + return forward_compiled(**kwargs, **constants) + + return forward + + +def callable_for_bwd_module(module: 'dace.frontend.ml.torch.DaceModule', forward_compiled: CompiledSDFG, + backward_compiled: CompiledSDFG, backward_result: BackwardResult, + forwarded_arrays: Dict[str, data.Data]): + + assert forward_compiled._initialized + assert backward_compiled._initialized + + fwd_arglist = forward_compiled.sdfg.arglist() + + input_names, output_names = get_arglist(module) + + # arrays that we will forward to the backward pass using saved_for_backward + forwarded_io_names: List[str] = [name for name in forwarded_arrays if name in output_names or name in input_names] + + # non input/output arrays that we are forwarding + forwarded_non_io_names: List[str] = [ + name for name in forwarded_arrays if name not in output_names and name not in input_names + ] + + # for each gradient array that is required, this contains the: + # * name of the gradient + # * whether the array requires zero initialization + # * the descriptor for the array + gradient_descriptors: List[Tuple[str, bool, data.Data]] = [] + + for _, grad_name in backward_result.required_grad_names.items(): + zero_init = backward_result.zero_init.get(grad_name, True) + desc = backward_compiled.sdfg.arrays[grad_name] + + gradient_descriptors.append((grad_name, zero_init, desc)) + + outputs_with_forwarded_outputs: List[str] = copy.deepcopy(output_names) + outputs_with_forwarded_outputs.extend(n for n in forwarded_arrays if n not in input_names and n not in output_names) + + output_gradient_names: List[str] = [ + backward_result.given_grad_names[output] if output in backward_result.given_grad_names else None + for output in output_names + ] + input_gradient_names: List[str] = [ + backward_result.required_grad_names[input] if input in backward_result.required_grad_names else None + for input in input_names + ] + + constants = init_remaining_parameters(module, fwd_arglist, input_names, outputs_with_forwarded_outputs) + + class DifferentiableFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, *inputs): + kwargs = {} + + # set the inputs + for i, input_name in enumerate(input_names): + kwargs[input_name] = inputs[i].contiguous() + + # initialize the outputs + for name in outputs_with_forwarded_outputs: + output_desc = forward_compiled.sdfg.arrays[name] + kwargs[name] = create_output_array( + {}, output_desc, use_torch=True, zeros=True + ) if name not in module.dace_model.initialized_parameters else module.dace_model.initialized_parameters[ + name] + + # call the SDFG + outputs = forward_compiled(**kwargs, **constants) + + # save inputs/outputs for backward + ctx.save_for_backward(*(kwargs[name] for name in forwarded_io_names)) + + # save non- input/output values for backward + for name in forwarded_non_io_names: + setattr(ctx, f"dace_saved_{name}", kwargs[name]) + + return outputs + + @staticmethod + def backward(ctx, *grad_outputs): + kwargs = {} + + # recover saved values + saved = ctx.saved_tensors + for value_name, saved_value in zip(forwarded_io_names, saved): + kwargs[value_name] = saved_value + + for value_name in forwarded_non_io_names: + kwargs[value_name] = getattr(ctx, f"dace_saved_{value_name}") + + # create gradient buffers of inputs + for grad_name, zero_init, desc in gradient_descriptors: + kwargs[grad_name] = create_output_array({}, desc, use_torch=True, zeros=zero_init) + + # grab gradient buffers of outputs + for grad_name, grad_value in zip(output_gradient_names, grad_outputs): + kwargs[grad_name] = grad_value.contiguous() + + # call bwd sdfg + backward_compiled(**kwargs) + + # return grads + grads = tuple(None if name is None else kwargs[name] for name in input_gradient_names) + if len(grads) == 1: + return grads[0] + return grads + + return lambda *args: DifferentiableFunction.apply(*args) + + +def get_ctypes_dispatcher(module: 'dace.frontend.ml.torch.DaceModule', dummy_inputs) -> DaceTorchFunction: + """ + Get a torch callable for the module. This will compile the sdfg and create a + wrapper python callable that can be used with PyTorch. + + :param module: the module. + :param dummy_inputs: dummy inputs to initialize the model with. + :return: the callable function for the SDFG. + """ + + # build the SDFG + # set all states to not-sync + for state in module.sdfg.nodes(): + state.nosync = True + + if module.backward: + # TODO we could return the inferred symbols here + compiled, _, compiled_bwd, _ = compile_and_init_sdfgs(module, dummy_inputs) + + function = callable_for_bwd_module(module, compiled, compiled_bwd, module._ad_result, module._ad_inp_arrs) + compiled_sdfgs = [compiled, compiled_bwd] + else: + compiled, _ = compile_and_init_sdfgs(module, dummy_inputs) + function = callable_for_fwd_module(module, compiled) + compiled_sdfgs = [compiled] + + result = DaceTorchFunction( + function=function, + compiled_sdfgs=compiled_sdfgs, + # no pointers required for ctypes dispatcher + ptr=[]) + return result diff --git a/dace/libraries/torch/dlpack.py b/dace/libraries/torch/dlpack.py new file mode 100644 index 0000000000..614e1a5345 --- /dev/null +++ b/dace/libraries/torch/dlpack.py @@ -0,0 +1,188 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Interface for integrating with DLPack. + +Some of the following code is derived from the following resources: +https://github.com/dmlc/dlpack/blob/main/apps/from_numpy/main.py +https://github.com/vadimkantorov/pydlpack/blob/master/dlpack.py +""" + +import ctypes + +import dace +import torch +import torch.utils.dlpack +from dace import data, dtypes + + +class DLDeviceType(ctypes.c_int): + """DLPack device type enumeration.""" + kDLCPU = 1 + kDLGPU = 2 + kDLCPUPinned = 3 + kDLOpenCL = 4 + kDLVulkan = 7 + kDLMetal = 8 + kDLVPI = 9 + kDLROCM = 10 + kDLExtDev = 12 + + +class DLDataTypeCode(ctypes.c_uint8): + """DLPack data type code enumeration.""" + kDLInt = 0 + kDLUInt = 1 + kDLFloat = 2 + kDLBfloat = 4 + + +class DLDataType(ctypes.Structure): + """DLPack data type structure.""" + _fields_ = [('type_code', DLDataTypeCode), ('bits', ctypes.c_uint8), ('lanes', ctypes.c_uint16)] + + +dace_to_dldtype_dict = { + dace.float32: DLDataType(DLDataTypeCode.kDLFloat, 32, 1), + dace.float64: DLDataType(DLDataTypeCode.kDLFloat, 64, 1), + dace.uint8: DLDataType(DLDataTypeCode.kDLUInt, 8, 1), + dace.uint16: DLDataType(DLDataTypeCode.kDLUInt, 16, 1), + dace.uint32: DLDataType(DLDataTypeCode.kDLUInt, 32, 1), + dace.uint64: DLDataType(DLDataTypeCode.kDLUInt, 64, 1), + dace.int8: DLDataType(DLDataTypeCode.kDLInt, 8, 1), + dace.int16: DLDataType(DLDataTypeCode.kDLInt, 16, 1), + dace.int32: DLDataType(DLDataTypeCode.kDLInt, 32, 1), + dace.int64: DLDataType(DLDataTypeCode.kDLInt, 64, 1), +} + + +class DLContext(ctypes.Structure): + """DLPack context structure for device information.""" + _fields_ = [('device_type', DLDeviceType), ('device_id', ctypes.c_int)] + + +class DLTensor(ctypes.Structure): + """DLPack tensor structure.""" + _fields_ = [('data', ctypes.c_void_p), ('ctx', DLContext), ('ndim', ctypes.c_int), ('dtype', DLDataType), + ('shape', ctypes.POINTER(ctypes.c_int64)), ('strides', ctypes.POINTER(ctypes.c_int64)), + ('byte_offset', ctypes.c_uint64)] + + +class DLManagedTensor(ctypes.Structure): + """DLPack managed tensor structure.""" + pass + + +DLManagedTensorHandle = ctypes.POINTER(DLManagedTensor) + +DeleterFunc = ctypes.CFUNCTYPE(None, DLManagedTensorHandle) + +DLManagedTensor._fields_ = [("dl_tensor", DLTensor), ("manager_ctx", ctypes.c_void_p), ("deleter", DeleterFunc)] + + +def make_manager_ctx(obj) -> ctypes.c_void_p: + """Create a manager context from a Python object. + + This function wraps a Python object in a ctypes void pointer and increments + its reference count to prevent garbage collection while in use by DLPack. + + :param obj: The Python object to create a context for. + :return: A ctypes void pointer to the object. + """ + pyobj = ctypes.py_object(obj) + void_p = ctypes.c_void_p.from_buffer(pyobj) + ctypes.pythonapi.Py_IncRef(pyobj) + return void_p + + +@DeleterFunc +def dl_managed_tensor_deleter(_dl_managed_tensor_handle) -> None: + """Deleter function for DLPack managed tensors. + + This is a no-op deleter because the underlying data is managed by DaCe + and will be freed when the SDFG state struct is deallocated. + + :param _dl_managed_tensor_handle: Handle to the managed tensor (unused). + """ + # Do nothing: the data is freed in the state struct + pass + + +class PyCapsule: + """Python capsule interface for DLPack integration.""" + New = ctypes.pythonapi.PyCapsule_New + New.restype = ctypes.py_object + New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p) + + SetContext = ctypes.pythonapi.PyCapsule_SetContext + SetContext.restype = ctypes.c_int + SetContext.argtypes = (ctypes.py_object, ctypes.c_void_p) + + GetContext = ctypes.pythonapi.PyCapsule_GetContext + GetContext.restype = ctypes.c_void_p + GetContext.argtypes = (ctypes.py_object, ) + + GetPointer = ctypes.pythonapi.PyCapsule_GetPointer + GetPointer.restype = ctypes.c_void_p + GetPointer.argtypes = (ctypes.py_object, ctypes.c_char_p) + + Destructor = ctypes.CFUNCTYPE(None, ctypes.py_object) + + SetDestructor = ctypes.pythonapi.PyCapsule_SetDestructor + SetDestructor.argtypes = (ctypes.py_object, Destructor) + SetDestructor.restype = ctypes.c_int + + +def array_to_torch_tensor(ptr: ctypes.c_void_p, desc: data.Array) -> torch.Tensor: + """Convert a DaCe array descriptor to a PyTorch tensor that points to the same data. + + This function performs zero-copy conversion using the DLPack protocol, + allowing PyTorch to access DaCe arrays without data duplication. + + :param ptr: The pointer to the memory of the array. + :param desc: The DaCe array descriptor containing shape, strides, and dtype information. + :return: A PyTorch tensor that shares memory with the DaCe array. + :raises ValueError: If the storage type or dtype is unsupported. + """ + + if desc.storage is dtypes.StorageType.GPU_Global: + device_type = DLDeviceType.kDLGPU + elif desc.storage in [dtypes.StorageType.CPU_Heap, dtypes.StorageType.Default]: + device_type = DLDeviceType.kDLCPU + else: + raise ValueError(f"Unsupported storage type {desc.storage}") + + context = DLContext(device_type=device_type, device_id=0) + + if desc.dtype not in dace_to_dldtype_dict: + raise ValueError(f"Unsupported dtype {desc.dtype}") + dtype = dace_to_dldtype_dict[desc.dtype] + + shape = (ctypes.c_int64 * len(desc.shape))() + for i, s in enumerate(desc.shape): + shape[i] = s + + strides = (ctypes.c_int64 * len(desc.shape))() + for i, s in enumerate(desc.strides): + strides[i] = s + + dltensor = DLTensor(data=ptr, + ctx=context, + ndim=len(desc.shape), + dtype=dtype, + shape=shape, + strides=strides, + byte_offset=0) + + c_obj = DLManagedTensor() + c_obj.dl_tensor = dltensor + c_obj.manager_ctx = ctypes.c_void_p(0) + c_obj.deleter = dl_managed_tensor_deleter + + # The capsule must be used in the same stack frame, otherwise it will be deallocated and the capsule will + # point to invalid data. + capsule = PyCapsule.New(ctypes.byref(c_obj), b"dltensor", None) + tensor: torch.Tensor = torch.utils.dlpack.from_dlpack(capsule) + + # Store the dltensor as an attribute of the tensor so that the tensor takes ownership + tensor._dace_dlpack = c_obj + return tensor diff --git a/dace/libraries/torch/environments/__init__.py b/dace/libraries/torch/environments/__init__.py new file mode 100644 index 0000000000..3af1d61e60 --- /dev/null +++ b/dace/libraries/torch/environments/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from .pytorch_env import PyTorch, PyTorchGPU diff --git a/dace/libraries/torch/environments/pytorch_env.py b/dace/libraries/torch/environments/pytorch_env.py new file mode 100644 index 0000000000..1d82beebdf --- /dev/null +++ b/dace/libraries/torch/environments/pytorch_env.py @@ -0,0 +1,100 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import os + +try: + import torch.utils.cpp_extension +except ImportError as e: + raise ImportError("PyTorch is required for torch integration. Install with: pip install dace[ml]") from e + +import dace.library + +from dace.codegen.common import platform_library_name, get_gpu_backend + + +@dace.library.environment +class PyTorch: + """Environment used to build PyTorch C++ Operators.""" + + cmake_minimum_version = None + cmake_packages = [] + cmake_variables = {} + cmake_includes = torch.utils.cpp_extension.include_paths() + + @staticmethod + def cmake_libraries(): + """Get the required PyTorch library paths for linking. + + :return: List of library paths for PyTorch CPU libraries. + :raises RuntimeError: If a required library cannot be found. + """ + library_names = ["c10", "torch", "torch_cpu", "torch_python"] + library_paths = [] + + for name in library_names: + for path in torch.utils.cpp_extension.library_paths(): + path = os.path.join(path, platform_library_name(name)) + if os.path.isfile(path): + library_paths.append(path) + break + else: + raise RuntimeError(f"Couldn't locate shared library {name} in PyTorch library paths") + + return library_paths + + cmake_compile_flags = ["-D_GLIBCXX_USE_CXX11_ABI=0"] + cmake_link_flags = [] + cmake_files = [] + state_fields = [] + dependencies = [] + + headers = [] + init_code = "" + finalize_code = "" + + +@dace.library.environment +class PyTorchGPU: + """Environment used to build PyTorch C++ Operators (with CUDA/HIP).""" + + cmake_minimum_version = None + cmake_packages = [] + cmake_variables = {} + cmake_includes = torch.utils.cpp_extension.include_paths() + + @staticmethod + def cmake_libraries(): + """ + Get the required PyTorch library paths for linking with GPU support. + + :return: List of library paths for PyTorch GPU libraries. + :raises RuntimeError: If a required library cannot be found. + """ + backend = get_gpu_backend() + if backend == 'hip': + library_names = ["c10", "torch", "torch_cpu", "torch_hip", "torch_python", "c10_hip"] + runtime_lib = "amdhip64" + else: + library_names = ["c10", "torch", "torch_cpu", "torch_cuda", "torch_python", "c10_cuda"] + runtime_lib = "cudart" + + library_paths = [] + for name in library_names: + for path in torch.utils.cpp_extension.library_paths(device_type=backend): + path = os.path.join(path, platform_library_name(name)) + if os.path.isfile(path): + library_paths.append(path) + break + else: + raise RuntimeError(f"Couldn't locate shared library {name} in PyTorch library paths") + + return library_paths + [runtime_lib] + + cmake_compile_flags = ["-D_GLIBCXX_USE_CXX11_ABI=0"] + cmake_link_flags = [] + cmake_files = [] + state_fields = [] + dependencies = [] + + headers = [] + init_code = "" + finalize_code = "" diff --git a/dace/libraries/torch/torch.md b/dace/libraries/torch/torch.md new file mode 100644 index 0000000000..1a83857c82 --- /dev/null +++ b/dace/libraries/torch/torch.md @@ -0,0 +1,1254 @@ +Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +# DaCe PyTorch Integration Library - Design Document + +## Table of Contents + +1. [Introduction](#1-introduction) +2. [Architecture Overview](#2-architecture-overview) +3. [Directory Structure](#3-directory-structure) +4. [Core Components](#4-core-components) +5. [Dispatcher Strategies](#5-dispatcher-strategies) +6. [Integration Pipeline](#6-integration-pipeline) +7. [Zero-Copy Tensor Sharing](#7-zero-copy-tensor-sharing) +8. [Autograd Integration](#8-autograd-integration) + +--- + +## 1. Introduction + +### 1.1 Purpose + +The DaCe PyTorch Integration Library provides **bidirectional integration** between PyTorch's neural network framework and DaCe's high-performance SDFG execution engine. It enables: + +- **Optimizing PyTorch models** using DaCe's dataflow transformations +- **Accelerating training and inference** with optimized compiled code + +### 1.2 Current Capabilities + +- **Model Optimization**: Convert PyTorch `nn.Module` to optimized DaCe SDFGs +- **Automatic Differentiation**: Integration with PyTorch's autograd system +- **Dual Dispatch**: C++ extension (performance) or CTypes (flexibility) +- **Training Support**: Backward pass generation and gradient computation + +### 1.3 Integration Directions + +The library supports bidirectional data flow: + +**1. PyTorch → DaCe (Primary Direction)**: +```python +# Wrap PyTorch model for DaCe optimization +dace_module = DaceModule(pytorch_model, dummy_inputs, backward=True) + +# Use as drop-in replacement +output = dace_module(input_tensor) +loss.backward() # Autograd works! +``` + +**Workflow**: PyTorch Model → ONNX Export → DaCe SDFG → Compiled Code → PyTorch Operator + +**2. DaCe → PyTorch (Zero-Copy Access)**: +```python +# DaCe arrays accessible as PyTorch tensors (no copy) +torch_tensor = array_to_torch_tensor(ptr, dace_descriptor) +``` + +**Mechanism**: DLPack protocol for memory sharing + +### 1.4 Use Cases + +1. **Neural Network Optimization**: Speed up inference for production deployment +2. **Training Acceleration**: Optimize forward and backward passes for faster training +3. **Custom Operators**: Implement custom PyTorch operations with DaCe +4. **Research**: Experiment with dataflow-level optimizations on ML models +5. **Mixed Workflows**: Combine PyTorch layers with DaCe-optimized modules + +--- + +## 2. Architecture Overview + +### 2.1 High-Level System Diagram + +``` +┌───────────────────────────────────────────────────────────┐ +│ USER INTERFACE │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ DaceModule (pytorch_model, dummy_inputs, ...) │ │ +│ │ • Wraps PyTorch nn.Module │ │ +│ │ • Provides PyTorch-compatible interface │ │ +│ │ • Supports forward + backward passes │ │ +│ └──────────────────┬─────────────────────────────────┘ │ +└─────────────────────┼─────────────────────────────────────┘ + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ ONNX EXPORT PIPELINE │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ torch.onnx.export() │ │ +│ │ PyTorch Model → ONNX ModelProto │ │ +│ └──────────────────┬─────────────────────────────────┘ │ +└─────────────────────┼─────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ SDFG CONSTRUCTION │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ ONNXModel (onnx_proto) │ │ +│ │ ONNX → DaCe SDFG (Forward) │ │ +│ └────────────────────┬───────────────────────────────┘ │ +│ │ │ +│ ▼ (if backward=True) │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ BackwardPassGenerator │ │ +│ │ Forward SDFG → Backward SDFG │ │ +│ └──────────────────┬─────────────────────────────────┘ │ +└─────────────────────┼───────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ DISPATCHER SELECTION │ +│ ┌──────────────────┐ ┌──────────────────┐ │ +│ │ C++ Extension │ OR │ CTypes Module │ │ +│ ├──────────────────┤ ├──────────────────┤ │ +│ │ • Performance │ │ • No param limit │ │ +│ │ • Native PyTorch │ │ • Faster compile │ │ +│ │ • 64 param limit │ │ • Pure Python │ │ +│ └────────┬─────────┘ └────────┬─────────┘ │ +└───────────┼──────────────────────────────┼──────────────────┘ + │ │ + └──────────┬───────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ CODE GENERATION & COMPILATION │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ SDFG.compile() → Shared Library (.so) │ │ +│ │ C++ Codegen → PyTorch Operator Registration │ │ +│ │ State Initialization → Handle Creation │ │ +│ └──────────────────┬─────────────────────────────────┘ │ +└─────────────────────┼───────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ CALLABLE PYTORCH OPERATOR │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ dace_module(inputs) → outputs │ │ +│ │ • Zero-copy tensor access via DLPack │ │ +│ │ • Stateful execution via handles │ │ +│ └──────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ +``` + +### 2.2 Component Interaction Flow + +``` +User Code: dace_module = DaceModule(model, dummy_inputs, backward=True) + ↓ +1. Store model and configuration + ↓ +User Code: output = dace_module(actual_input) # First call + ↓ +2. Detect function is None → Initialize SDFG + ↓ +3. Export to ONNX + a. torch.onnx.export(model, dummy_inputs) + b. Save parameters and model structure + ↓ +4. Import ONNX to DaCe + a. ONNXModel(onnx_proto) + b. Create forward SDFG + ↓ +5. Generate Backward (if backward=True) + a. Determine required gradients + b. BackwardPassGenerator.backward() + c. Create backward SDFG + d. Identify forwarded transients + ↓ +6. Compile SDFGs + a. forward_sdfg.compile() + b. backward_sdfg.compile() (if applicable) + c. Initialize with dummy inputs + ↓ +7. Select and Initialize Dispatcher + If compile_torch_extension: + a. Generate C++ code with autograd + b. Compile as PyTorch extension + c. Register as torch.ops.dace_{name}.{name} + Else: + a. Create Python autograd.Function + b. Wrap with CTypes calls + ↓ +8. Create Wrapper Function + a. Accept user inputs + b. Pass state handles as first args + c. Call compiled operator + d. Return outputs + ↓ +9. Execute and Return + a. Zero-copy tensor access via DLPack + b. Call native SDFG code + c. Return PyTorch tensors + ↓ +User Code: loss.backward() # Backward pass + ↓ +10. PyTorch calls backward function + a. Recover saved tensors from context + b. Allocate gradient buffers + c. Call backward SDFG + d. Return input gradients +``` + +--- + +## 3. Directory Structure + +### 3.1 File Organization + +``` +dace/libraries/torch/ +├── __init__.py # Library exports +│ └── Exports: PyTorch, PyTorchGPU environment classes +│ +├── dlpack.py # Zero-copy tensor sharing +│ ├── DLPack structure definitions +│ ├── Type conversion mappings +│ └── array_to_torch_tensor() - Main conversion function +│ +├── dispatchers/ # Dispatcher implementations +│ ├── __init__.py # Package exports +│ │ └── Exports: DaceTorchFunction, get_ctypes_dispatcher, register_and_compile_torch_extension +│ │ +│ ├── common.py # Shared utilities +│ │ ├── DaceTorchFunction dataclass +│ │ ├── get_arglist() +│ │ └── compile_and_init_sdfgs() +│ │ +│ ├── cpp_torch_extension.py # C++ extension generator +│ │ ├── Type conversion utilities +│ │ ├── C++ code generation for forward/backward +│ │ ├── Autograd function generation +│ │ ├── Tensor initialization +│ │ └── register_and_compile_torch_extension() +│ │ +│ └── ctypes_module.py # CTypes dispatcher +│ ├── init_remaining_parameters() +│ ├── callable_for_fwd_module() +│ ├── callable_for_bwd_module() +│ └── get_ctypes_dispatcher() +│ +└── environments/ # Build configuration + ├── __init__.py # Package exports (1 line) + └── pytorch_env.py # PyTorch environments + ├── PyTorch (CPU environment) + └── PyTorchGPU (GPU environment) + +``` + +### 3.2 Component Responsibilities + +| Component | Lines | Purpose | +|-----------|-------|---------| +| `cpp_torch_extension.py` | 699 | C++ code generation for PyTorch operators | +| `ctypes_module.py` | 222 | CTypes-based dispatcher for large models | +| `dlpack.py` | 188 | Zero-copy tensor sharing via DLPack | +| `common.py` | 112 | Shared dispatcher utilities | +| `pytorch_env.py` | 100 | CMake build configuration | +| `__init__.py` (dispatchers) | 22 | Dispatcher exports | +| `__init__.py` (main) | 23 | Library exports | + +**Note**: DaceModule is located at [dace/frontend/ml/torch/module.py](../../frontend/ml/torch/module.py) (581 lines). + +--- + +## 4. Core Components + +### 4.1 DaceModule: The Main Entry Point + +**Location**: [dace/frontend/ml/torch/module.py](../../frontend/ml/torch/module.py) + +#### Constructor Signature + +```python +class DaceModule: + def __init__( + self, + module: torch.nn.Module, + dummy_inputs: Tuple[torch.Tensor, ...], + cuda: bool = False, + backward: bool = False, + compile_torch_extension: bool = True, + auto_optimize: bool = True, + **onnx_kwargs + ): + """ + Wrap a PyTorch module for DaCe optimization. + + Args: + module: PyTorch module to optimize + dummy_inputs: Sample inputs for shape inference and tracing + cuda: Enable GPU execution + backward: Generate backward pass for training + compile_torch_extension: Use C++ extension (True) or CTypes (False) + auto_optimize: Apply DaCe optimizations + **onnx_kwargs: Additional arguments for torch.onnx.export() + """ +``` + +#### Key Methods + +- **`__call__(*inputs)`**: Execute the optimized module +- **`_initialize_sdfg(inputs)`**: Lazy compilation on first call +- **`_call_params()`**: Get model parameters as tensors + +#### Workflow Summary + +1. **Initialization**: Store model and configuration +2. **First Call**: Export to ONNX → Import to SDFG → Compile → Create dispatcher +3. **Subsequent Calls**: Direct execution via the compiled operator +4. **Backward**: Automatic execution via PyTorch autograd integration + +--- + +### 4.2 DLPack Bridge: Zero-Copy Tensor Sharing + +**Location**: [dlpack.py](dlpack.py) + +The DLPack bridge enables **zero-copy conversion** between DaCe arrays and PyTorch tensors. + +#### DLPack Structure Definitions + +**Type System**: +```python +class DLDeviceType(ctypes.c_int): + kDLCPU = 1 + kDLGPU = 2 + # ... other devices + +class DLDataTypeCode(ctypes.c_uint8): + kDLInt = 0 + kDLUInt = 1 + kDLFloat = 2 + kDLBfloat = 4 + +class DLDataType(ctypes.Structure): + _fields_ = [('type_code', DLDataTypeCode), + ('bits', ctypes.c_uint8), + ('lanes', ctypes.c_uint16)] +``` + +**Tensor Representation**: +```python +class DLTensor(ctypes.Structure): + _fields_ = [ + ('data', ctypes.c_void_p), # Raw pointer + ('ctx', DLContext), # Device info + ('ndim', ctypes.c_int), # Number of dimensions + ('dtype', DLDataType), # Data type + ('shape', ctypes.POINTER(ctypes.c_int64)), # Shape array + ('strides', ctypes.POINTER(ctypes.c_int64)), # Strides array + ('byte_offset', ctypes.c_uint64) # Byte offset + ] +``` + +#### Zero-Copy Conversion + +**Function**: `array_to_torch_tensor(ptr, desc)` + +**Process**: +1. Map the DaCe storage type to DLDeviceType +2. Convert the DaCe dtype to DLDataType +3. Create shape and strides arrays +4. Build the DLTensor structure +5. Wrap in DLManagedTensor with a no-op deleter +6. Create a PyCapsule with name "dltensor" +7. Call `torch.utils.dlpack.from_dlpack(capsule)` to create the PyTorch tensor +8. Store the DLPack structure as a tensor attribute (prevents garbage collection) + +**Memory Ownership**: +- Data is owned by the DaCe SDFG state struct +- No-op deleter ensures that DaCe manages deallocation +- PyTorch tensor is a **view** into DaCe memory (zero-copy) + +**Type Mapping**: +```python +dace_to_dldtype_dict = { + dace.float32: DLDataType(kDLFloat, 32, 1), + dace.float64: DLDataType(kDLFloat, 64, 1), + dace.int32: DLDataType(kDLInt, 32, 1), + # ... complete mapping +} +``` + +--- + +### 4.3 Common Dispatcher Utilities + +**Location**: [dispatchers/common.py](dispatchers/common.py) + +#### DaceTorchFunction Dataclass + +```python +@dataclasses.dataclass +class DaceTorchFunction: + """Encapsulates a compiled DaCe module with PyTorch interface.""" + function: Callable # The callable (torch op or Python function) + compiled_sdfgs: List[CompiledSDFG] # [forward, backward] (or just [forward]) + ptr: List[torch.Tensor] # State handle pointers [fwd_handle, bwd_handle] +``` + +**Purpose**: Provides a uniform interface regardless of dispatcher choice (C++ or CTypes) + +#### compile_and_init_sdfgs() + +**Function Signature**: +```python +def compile_and_init_sdfgs( + module: DaceModule, + dummy_inputs: Tuple[torch.Tensor, ...] +) -> Union[ + Tuple[CompiledSDFG, torch.Tensor], # No backward + Tuple[CompiledSDFG, torch.Tensor, # With backward + CompiledSDFG, torch.Tensor] +]: +``` + +**Process**: +1. Compile the forward SDFG +2. Construct arguments from dummy inputs and parameters +3. Infer symbols from input shapes +4. Allocate forwarded transients (for backward pass) +5. Initialize the forward SDFG state +6. Extract the state handle as `torch.tensor([libhandle])` +7. If backward is enabled: + - Compile the backward SDFG + - Allocate gradient buffers + - Initialize the backward SDFG state + - Extract the backward handle +8. Return the compiled SDFGs and handles + +#### get_arglist() + +**Function**: +```python +def get_arglist(module: DaceModule) -> Tuple[List[str], List[str]]: + """Extracts input and output names with ONNX name cleaning.""" + inputs = [clean_onnx_name(name) for name in module.dace_model.inputs] + outputs = [clean_onnx_name(name) for name in module.dace_model.outputs] + return inputs, outputs +``` + +--- + +### 4.4 PyTorch Environment Configuration + +**Location**: [environments/pytorch_env.py](environments/pytorch_env.py) + +Defines the CMake build configuration for linking against PyTorch libraries. + +#### PyTorch Environment (CPU) + +```python +@dace.library.environment +class PyTorch: + """Environment for building PyTorch C++ operators (CPU).""" + + cmake_includes = torch.utils.cpp_extension.include_paths() + + @staticmethod + def cmake_libraries(): + """Locate and return PyTorch library paths.""" + library_names = ["c10", "torch", "torch_cpu", "torch_python"] + # Search in torch.utils.cpp_extension.library_paths() + return library_paths + + cmake_compile_flags = ["-D_GLIBCXX_USE_CXX11_ABI=0"] # ABI compatibility +``` + +#### PyTorchGPU Environment (GPU) + +```python +@dace.library.environment +class PyTorchGPU: + """Environment for building PyTorch C++ operators (CUDA).""" + + cmake_includes = torch.utils.cpp_extension.include_paths() + + @staticmethod + def cmake_libraries(): + """Locate and return PyTorch CUDA library paths.""" + library_names = ["c10", "torch", "torch_cpu", "torch_cuda", + "torch_python", "c10_cuda"] + return library_paths + ["cudart"] +``` + +**Integration with DaCe**: +- Registered via the `@dace.library.environment` decorator +- DaCe's CMake generator uses these settings for linker configuration +- Ensures that compiled code can call the PyTorch C++ API + +--- + +## 5. Dispatcher Strategies + +### 5.1 Why Two Dispatchers? + +The library provides two dispatcher implementations to handle different use cases: + +| Feature | C++ Extension | CTypes Module | +|---------|--------------|---------------| +| **Performance** | High (native call) | Good (small overhead) | +| **Parameter Limit** | 64 parameters | Unlimited | +| **Compilation Time** | Slower (C++ compile) | Faster (no codegen) | +| **Registration** | `torch.ops.dace_name` | Python function | + +### 5.2 C++ PyTorch Extension + +**Location**: [dispatchers/cpp_torch_extension.py](dispatchers/cpp_torch_extension.py) + +#### Overview + +Generates C++ code that registers a custom PyTorch operator with native autograd support. + +#### Type Conversion Utilities + +**DaCe → PyTorch C++ Types**: +```python +_REPLACED_CTYPES = { + dace.int64: "int64_t", + dace.uint64: "uint64_t", + dace.float16: "at::Half" +} + +def torch_ctype(dtype: dace.typeclass) -> str: + """Convert DaCe type to PyTorch C++ type string.""" + if isinstance(dtype, dace.pointer): + return "int64_t" + elif dtype in _REPLACED_CTYPES: + return _REPLACED_CTYPES[dtype] + else: + return dtype.ctype # e.g., "float", "double" +``` + +**DaCe → PyTorch Tensor Dtype**: +```python +_TYPECLASS_TO_TORCH_DTYPE_STR = { + dt.bool: "kBool", + dt.int8: "kInt8", + dt.float32: "kFloat32", + dt.float64: "kFloat64", + # ... complete mapping +} +``` + +#### Tensor Initialization Code Generation + +**Function**: `tensor_init_for_desc()` + +**Purpose**: Generates C++ code to allocate PyTorch tensors + +**Approach**: +- Checks if tensor is a constant (from weights) +- If constant: embeds values as a C++ initializer list +- If output: allocates with `torch::zeros()` or `torch::empty()` +- Sets proper dtype, device (CPU/CUDA), and layout + +**Example Output**: +```cpp +Tensor output = torch::zeros( + {10, 256}, + torch::TensorOptions() + .dtype(torch::kFloat32) + .device(torch::kCPU) + .layout(torch::kStrided) +); +``` + +#### Forward Function Code Generation + +**Generated Structure**: +```cpp +Tensor forward_function( + int64_t fwd_handle_ptr, + int64_t bwd_handle_ptr, // if backward + const Tensor& input_0_, + const Tensor& input_1_, + // ... more inputs +) { + // 1. Initialize outputs + Tensor output = torch::zeros({...}, torch::TensorOptions()...); + + // 2. Ensure inputs are contiguous + Tensor input_0 = input_0_.contiguous(); + + // 3. Extract pointers + float *input_0_ptr = reinterpret_cast(input_0.data_ptr()); + float *output_ptr = reinterpret_cast(output.data_ptr()); + + // 4. Call SDFG + MySDFGHandle_t handle = reinterpret_cast(fwd_handle_ptr); + __program_my_sdfg(handle, input_0_ptr, output_ptr); + + // 5. Return outputs + return output; // or std::make_tuple(...) for multiple +} +``` + +#### Autograd Function Code Generation + +**Generated Structure**: +```cpp +class MySDFGFunction : public torch::autograd::Function { +public: + static Tensor forward( + AutogradContext *ctx, + int64_t fwd_handle_ptr, + int64_t bwd_handle_ptr, + const Tensor& input_ + ) { + // Run forward pass + Tensor output = forward_function(fwd_handle_ptr, bwd_handle_ptr, input_); + + // Save for backward + ctx->save_for_backward({input_, output}); + + // Save non-I/O transients + ctx->saved_data["intermediate"] = intermediate_value; + + // Save backward handle + ctx->saved_data["bwd_handle"] = bwd_handle_ptr; + + return output; + } + + static tensor_list backward( + AutogradContext *ctx, + tensor_list grad_outputs + ) { + // 1. Recover saved tensors + auto saved = ctx->get_saved_variables(); + auto input = saved[0]; + auto intermediate = ctx->saved_data["intermediate"].toTensor(); + + // 2. Get backward handle + int64_t bwd_handle_ptr = ctx->saved_data["bwd_handle"].toInt(); + MySDFGBackwardHandle_t bwd_handle = + reinterpret_cast(bwd_handle_ptr); + + // 3. Allocate gradient buffers + Tensor grad_input = torch::zeros({...}); // or empty if zero_init=False + + // 4. Get output gradients + Tensor grad_output = grad_outputs[0].contiguous(); + + // 5. Call backward SDFG + __program_my_sdfg_backward( + bwd_handle, + grad_output.data_ptr(), + input.data_ptr(), + intermediate.data_ptr(), + grad_input.data_ptr() + ); + + // 6. Return gradients (None for non-differentiable) + return {Tensor(), Tensor(), grad_input}; // None for handles, grad for input + } +}; +``` + +#### Operator Registration + +**Generated Code**: +```cpp +// Register operator +TORCH_LIBRARY(dace_my_sdfg, m) { + m.def("my_sdfg", forward_function); +} + +// Register autograd if backward enabled +TORCH_LIBRARY_IMPL(dace_my_sdfg, Autograd, m) { + m.impl("my_sdfg", MySDFGFunction::apply); +} +``` + +#### Compilation Process + +**Function**: `register_and_compile_torch_extension()` + +**Steps**: +1. Generate the complete C++ source code +2. Write to a temporary file +3. Use `torch.utils.cpp_extension.load()` for JIT compilation +4. Link against: + - PyTorch libraries (from environment) + - Compiled SDFG shared library +5. Return the operator accessible via `torch.ops.dace_name.name` + +**Limitations**: +- PyTorch dispatcher has **64 parameter limit** +- Longer compilation time (~seconds) +- Requires C++ compiler + +--- + +### 5.3 CTypes Module + +**Location**: [dispatchers/ctypes_module.py](dispatchers/ctypes_module.py) + +#### Overview + +A pure Python dispatcher that calls compiled SDFGs via ctypes, avoiding C++ code generation. + +#### When to Use + +- Models with >64 parameters +- Rapid development/iteration +- Environments where C++ compilation is problematic +- Prototyping and debugging + +#### Forward-Only Callable + +**Function**: `callable_for_fwd_module()` + +**Generated Function**: +```python +def forward(*inputs): + kwargs = {} + + # Set inputs + for i, input_name in enumerate(input_names): + kwargs[input_name] = inputs[i].contiguous() + + # Initialize outputs + for name in output_names: + kwargs[name] = create_output_array( + {}, + forward_compiled.sdfg.arrays[name], + use_torch=True, + zeros=False + ) + + # Add constants + kwargs.update(constants) + + # Call SDFG (ctypes handles conversion) + return forward_compiled(**kwargs) +``` + +#### Forward+Backward Callable + +**Function**: `callable_for_bwd_module()` + +**Generated Autograd Function**: +```python +class DifferentiableFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, *inputs): + kwargs = {} + + # Set inputs + for i, input_name in enumerate(input_names): + kwargs[input_name] = inputs[i].contiguous() + + # Initialize outputs + forwarded transients + for name in outputs_and_forwarded: + kwargs[name] = create_output_array(...) + + # Call forward SDFG + outputs = forward_compiled(**kwargs, **constants) + + # Save I/O tensors for backward + ctx.save_for_backward(*(kwargs[name] for name in forwarded_io_names)) + + # Save non-I/O transients as attributes + for name in forwarded_non_io_names: + setattr(ctx, f"dace_saved_{name}", kwargs[name]) + + return outputs + + @staticmethod + def backward(ctx, *grad_outputs): + kwargs = {} + + # Recover saved I/O tensors + saved = ctx.saved_tensors + for name, val in zip(forwarded_io_names, saved): + kwargs[name] = val + + # Recover non-I/O transients + for name in forwarded_non_io_names: + kwargs[name] = getattr(ctx, f"dace_saved_{name}") + + # Allocate gradient buffers + for grad_name, zero_init, desc in gradient_descriptors: + kwargs[grad_name] = create_output_array(..., zeros=zero_init) + + # Set output gradients from PyTorch + for grad_name, grad_val in zip(output_gradient_names, grad_outputs): + kwargs[grad_name] = grad_val.contiguous() + + # Call backward SDFG + backward_compiled(**kwargs) + + # Return input gradients (None for non-differentiable) + return tuple(kwargs.get(grad_name) for grad_name in input_gradient_names) + +return DifferentiableFunction.apply +``` + +#### Parameter Handling + +**Function**: `init_remaining_parameters()` + +**Purpose**: Extracts constant parameters (model weights) that are neither inputs nor outputs + +**Process**: +1. Identify parameters not in the input/output lists +2. Verify they exist in `module.dace_model.clean_weights` +3. Transfer to CUDA if needed +4. Return as a constants dictionary + +--- + +## 6. Integration Pipeline + +### 6.1 Complete Workflow + +``` +┌─────────────────────────────────────────────────────────┐ +│ Phase 1: Initialization │ +├─────────────────────────────────────────────────────────┤ +│ dace_module = DaceModule(model, dummy_inputs, ...) │ +│ │ +│ 1. Store PyTorch model reference │ +│ 2. Store configuration (cuda, backward, dispatcher) │ +│ 3. Set function = None (lazy compilation) │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 2: First Forward Call │ +├─────────────────────────────────────────────────────────┤ +│ output = dace_module(actual_input) │ +│ │ +│ Detect function is None → Trigger _initialize_sdfg() │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 3: ONNX Export │ +├─────────────────────────────────────────────────────────┤ +│ 1. Call torch.onnx.export(model, dummy_inputs, ...) │ +│ 2. Save exported ONNX ModelProto │ +│ 3. Extract and save model parameters │ +│ 4. Remove initializers that overlap with inputs │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 4: ONNX → DaCe SDFG │ +├─────────────────────────────────────────────────────────┤ +│ 1. Create ONNXModel(onnx_proto) │ +│ - Import ONNX graph to SDFG │ +│ - Run shape inference │ +│ - Apply simplifications │ +│ 2. Store forward SDFG as module.sdfg │ +│ 3. Apply post_onnx_hooks (if any) │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 5: Backward SDFG Generation (if backward=True) │ +├─────────────────────────────────────────────────────────┤ +│ 1. Determine required gradients: │ +│ - Model inputs (if not in clean_weights) │ +│ - Parameters with requires_grad=True │ +│ │ +│ 2. Call make_backward_function(): │ +│ a. Create backward SDFG │ +│ b. Initialize BackwardPassGenerator │ +│ c. Generate reverse operations │ +│ d. Identify forwarded transients │ +│ │ +│ 3. Modify forward SDFG: │ +│ - Make forwarded arrays non-transient (outputs) │ +│ - Convert scalars to size-1 arrays │ +│ │ +│ 4. Store: │ +│ - module.forward_sdfg │ +│ - module.backward_sdfg │ +│ - module._ad_result (BackwardResult) │ +│ - module._ad_inp_arrs (forwarded arrays) │ +│ │ +│ 5. Apply post_autodiff_hooks (if any) │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 6: SDFG Compilation │ +├─────────────────────────────────────────────────────────┤ +│ Call compile_and_init_sdfgs(module, dummy_inputs): │ +│ │ +│ 1. Compile forward SDFG → forward_compiled │ +│ 2. Construct arguments from dummy inputs + parameters │ +│ 3. Call _call_args() to infer symbols │ +│ 4. Allocate forwarded transients (if backward) │ +│ 5. Initialize forward SDFG state │ +│ 6. Extract state handle: fwd_handle=compiled._libhandle │ +│ │ +│ 7. If backward: │ +│ a. Compile backward SDFG → backward_compiled │ +│ b. Allocate gradient buffers │ +│ c. Initialize backward SDFG state │ +│ d. Extract backward handle │ +│ │ +│ 8. Apply post_compile_hooks (if any) │ +│ │ +│ 9. Return compiled SDFGs and handles │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 7: Dispatcher Generation │ +├─────────────────────────────────────────────────────────┤ +│ If compile_torch_extension: │ +│ ├─→ register_and_compile_torch_extension() │ +│ │ 1. Generate C++ code with autograd │ +│ │ 2. Compile as PyTorch extension │ +│ │ 3. Register operator │ +│ │ 4. Return torch.ops.dace_name.name │ +│ │ │ +│ Else: │ +│ └─→ get_ctypes_dispatcher() │ +│ 1. Create Python autograd.Function │ +│ 2. Wrap compiled SDFGs with ctypes calls │ +│ 3. Return callable │ +│ │ +│ Return DaceTorchFunction(function,compiled_sdfgs,ptrs)│ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 8: Wrapper Function Creation │ +├─────────────────────────────────────────────────────────┤ +│ Create forward() wrapper: │ +│ │ +│ def forward(*args): │ +│ return compiled_function.function( │ +│ *compiled_function.ptr, # State handles │ +│ *args, # User inputs │ +│ *parameters_to_pass) # Model params │ +│ │ +│ Store as module.function │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 9: Execution │ +├─────────────────────────────────────────────────────────┤ +│ Forward Pass: │ +│ 1. User calls dace_module(input) │ +│ 2. Wrapper extracts .contiguous() tensors │ +│ 3. Zero-copy access via DLPack (if needed) │ +│ 4. Call compiled SDFG with pointers │ +│ 5. Return PyTorch tensors │ +│ │ +│ Backward Pass (if backward=True): │ +│ 1. User calls loss.backward() │ +│ 2. PyTorch autograd calls backward function │ +│ 3. Recover saved tensors from context │ +│ 4. Allocate gradient buffers │ +│ 5. Call backward SDFG │ +│ 6. Return input gradients to PyTorch │ +└─────────────────────────────────────────────────────────┘ +``` + +### 6.2 Data Transformations + +**Input Transformation** (PyTorch → DaCe): +``` +torch.Tensor (user input) + ↓ .contiguous() +torch.Tensor (contiguous memory) + ↓ .data_ptr() (C++ extension) or direct pass (CTypes) +Raw pointer / PyTorch tensor + ↓ Passed to SDFG +SDFG operates on memory +``` + +**Output Transformation** (DaCe → PyTorch): +``` +Allocate torch.Tensor (zeros or empty) + ↓ Extract .data_ptr() +Raw pointer + ↓ Pass to SDFG as output parameter +SDFG fills memory + ↓ No copy needed +Return torch.Tensor (already owns memory) +``` + +**Constant Transformation**: +``` +PyTorch model parameters + ↓ Extract in ONNX export +ONNX initializers + ↓ Save as clean_weights +Embed in C++ (C++ extension) or pass as kwargs (CTypes) +``` + +--- + +## 7. Zero-Copy Tensor Sharing + +### 7.1 The DLPack Protocol + +**Purpose**: Industry-standard protocol for zero-copy tensor exchange between frameworks + +**Key Concept**: Shares memory pointers and metadata between frameworks without copying data + +### 7.2 DaCe → PyTorch Conversion + +**Function**: `array_to_torch_tensor(ptr, desc)` + +**Complete Process**: + +**Step 1: Device Mapping** +```python +if desc.storage == dtypes.StorageType.GPU_Global: + device_type = DLDeviceType.kDLGPU +elif desc.storage in [StorageType.CPU_Heap, StorageType.Default]: + device_type = DLDeviceType.kDLCPU +``` + +**Step 2: Type Conversion** +```python +dtype = dace_to_dldtype_dict[desc.dtype] +# e.g., dace.float32 → DLDataType(kDLFloat, 32, 1) +``` + +**Step 3: Shape and Strides** +```python +shape = (ctypes.c_int64 * len(desc.shape))(*desc.shape) +strides = (ctypes.c_int64 * len(desc.shape))(*desc.strides) +``` + +**Step 4: DLTensor Construction** +```python +dltensor = DLTensor( + data=ptr, # Raw pointer from DaCe + ctx=DLContext(device_type, device_id=0), + ndim=len(desc.shape), + dtype=dtype, + shape=shape, + strides=strides, + byte_offset=0 +) +``` + +**Step 5: Managed Tensor Wrapper** +```python +managed = DLManagedTensor( + dl_tensor=dltensor, + manager_ctx=None, + deleter=no_op_deleter # DaCe owns memory +) +``` + +**Step 6: PyCapsule Creation** +```python +capsule = PyCapsule.New( + ctypes.byref(managed), + b"dltensor", + None +) +``` + +**Step 7: PyTorch Conversion** +```python +tensor = torch.utils.dlpack.from_dlpack(capsule) +tensor._dace_dlpack = managed # Prevent GC +``` + +### 7.3 Memory Lifecycle + +**Ownership**: +- The DaCe SDFG state struct owns the memory +- PyTorch tensor is a **view** that shares the memory +- No-op deleter ensures that DaCe handles deallocation + +**Safety**: +- Keep the SDFG state alive as long as tensors exist +- State handles are stored as `torch.Tensor` objects (ref-counted) +- PyTorch's garbage collector won't free memory prematurely + +**Use Cases**: +- Return DaCe outputs as PyTorch tensors +- Access intermediate SDFG arrays from PyTorch +- Enable PyTorch operations on DaCe memory + +--- + +## 8. Autograd Integration + +### 8.1 Backward Pass Generation + +**Entry Point**: `make_backward_function()` (in `dace/autodiff/torch.py`) + +**Workflow**: + +**Step 1: Determine Required Gradients** +```python +required_grads = [] +for param_name in model.parameters(): + if param_name.requires_grad and param_name not in inputs: + required_grads.append(param_name) +``` + +**Step 2: Create Backward SDFG** +```python +generator = BackwardPassGenerator( + forward_sdfg, + backward_sdfg, + given_gradients=model_outputs, + required_gradients=model_inputs + required_params +) +backward_result = generator.backward() +``` + +**Step 3: Identify Forwarded Transients** +- Identifies values needed for gradient computation +- Example: For `y = x * w`, the backward pass needs both `x` and `w` +- These are marked as non-transient (outputs) in the forward SDFG + +**Step 4: Modify Forward SDFG** +- Makes forwarded arrays non-transient +- Converts scalar outputs to size-1 arrays +- Ensures proper storage types + +### 8.2 C++ Extension Autograd + +**Forward Method**: +```cpp +static Tensor forward(AutogradContext *ctx, int64_t fwd_handle, + int64_t bwd_handle, Tensor input) { + // Execute forward SDFG + Tensor output = forward_function(fwd_handle, bwd_handle, input); + + // Save I/O tensors + ctx->save_for_backward({input, output}); + + // Save non-I/O transients (not saved by PyTorch) + ctx->saved_data["intermediate"] = intermediate_value; + + // Save backward handle + ctx->saved_data["bwd_handle"] = bwd_handle; + + return output; +} +``` + +**Backward Method**: +```cpp +static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) { + // 1. Recover saved tensors + auto saved = ctx->get_saved_variables(); + auto input = saved[0]; + auto intermediate = ctx->saved_data["intermediate"].toTensor(); + + // 2. Get handles + int64_t bwd_handle = ctx->saved_data["bwd_handle"].toInt(); + + // 3. Allocate gradient buffers + Tensor grad_input = torch::zeros({...}); // If zero_init=True + // OR + Tensor grad_input = torch::empty({...}); // If zero_init=False + + // 4. Get output gradients + Tensor grad_output = grad_outputs[0].contiguous(); + + // 5. Call backward SDFG + __program_backward( + bwd_handle, + grad_output.data_ptr(), + input.data_ptr(), + intermediate.data_ptr(), + grad_input.data_ptr() + ); + + // 6. Return gradients (None for handles, grads for inputs) + return {Tensor(), Tensor(), grad_input}; +} +``` + +### 8.3 CTypes Autograd + +**Forward Method**: +```python +@staticmethod +def forward(ctx, *inputs): + kwargs = {} + + # Set inputs + for i, name in enumerate(input_names): + kwargs[name] = inputs[i].contiguous() + + # Allocate outputs + forwarded transients + for name in all_output_names: + kwargs[name] = create_output_array(...) + + # Call forward SDFG + forward_compiled(**kwargs, **constants) + + # Save I/O for backward + ctx.save_for_backward(*(kwargs[n] for n in forwarded_io_names)) + + # Save non-I/O transients as attributes + for name in forwarded_non_io_names: + setattr(ctx, f"dace_saved_{name}", kwargs[name]) + + return tuple(kwargs[n] for n in model_output_names) +``` + +**Backward Method**: +```python +@staticmethod +def backward(ctx, *grad_outputs): + kwargs = {} + + # Recover I/O tensors + saved = ctx.saved_tensors + for name, val in zip(forwarded_io_names, saved): + kwargs[name] = val + + # Recover non-I/O transients + for name in forwarded_non_io_names: + kwargs[name] = getattr(ctx, f"dace_saved_{name}") + + # Allocate gradient buffers + for grad_name, zero_init in gradient_specs: + kwargs[grad_name] = create_output_array(..., zeros=zero_init) + + # Set output gradients + for grad_name, grad_val in zip(out_grad_names, grad_outputs): + kwargs[grad_name] = grad_val.contiguous() + + # Call backward SDFG + backward_compiled(**kwargs) + + # Return input gradients + return tuple(kwargs.get(g) for g in input_grad_names) +``` + +### 8.4 Gradient Accumulation + +**BackwardResult Structure**: +```python +required_grad_names = { + "input_0": "grad_input_0", + "param_weight": "grad_param_weight" +} + +given_grad_names = { + "output": "grad_output" +} + +zero_init = { + "grad_input_0": True, # Initialize to zero + "grad_param_weight": False # Don't initialize (accumulate) +} +``` + +**Usage**: +- `zero_init=True`: First gradient computation (allocate and initialize to zeros) +- `zero_init=False`: Accumulate into existing buffer (for gradient accumulation) + +--- diff --git a/dace/ml/__init__.py b/dace/ml/__init__.py new file mode 100644 index 0000000000..2e8dc8c341 --- /dev/null +++ b/dace/ml/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +# Import PyTorch frontend +try: + from dace.frontend.ml.torch import DaceModule, module +except ImportError: + DaceModule = None + module = None + +# Import ONNX frontend +try: + from dace.frontend.ml.onnx import ONNXModel +except ImportError: + ONNXModel = None + +__all__ = ['DaceModule', 'module', 'ONNXModel'] diff --git a/dace/optimization/data_layout_tuner.py b/dace/optimization/data_layout_tuner.py index 1dab5b3f17..f45325f313 100644 --- a/dace/optimization/data_layout_tuner.py +++ b/dace/optimization/data_layout_tuner.py @@ -5,7 +5,7 @@ import copy import itertools -from typing import Generator, Optional, Tuple, Dict, List, Sequence, Set +from typing import Any, Generator, Optional, Tuple, Dict, List, Sequence, Set from dace import data as dt, SDFG, dtypes from dace.optimization import cutout_tuner @@ -19,6 +19,11 @@ except (ImportError, ModuleNotFoundError): tqdm = lambda x, **kwargs: x +try: + from numpy.typing import ArrayLike +except ImportError: + ArrayLike = Any # type: ignore + class TuningGroups(enum.Enum): Separate = enum.auto() @@ -111,7 +116,7 @@ def pre_evaluate(self, cutout: dace.SDFG, dreport: data_report.InstrumentedDataR cutout.instrument = self.instrument # Prepare original arguments to sub-SDFG from instrumented data report - arguments: Dict[str, dt.ArrayLike] = {} + arguments: Dict[str, ArrayLike] = {} for cstate in cutout.nodes(): for dnode in cstate.data_nodes(): if cutout.arrays[dnode.data].transient: diff --git a/dace/properties.py b/dace/properties.py index b0658d8572..5137abc621 100644 --- a/dace/properties.py +++ b/dace/properties.py @@ -15,7 +15,10 @@ from dace.symbolic import pystr_to_symbolic from dace.dtypes import DebugInfo, typeclass from numbers import Integral, Number -from typing import List, Set, Type, Union, TypeVar, Generic +from typing import List, Set, Type, Union, TypeVar, Generic, TYPE_CHECKING + +if TYPE_CHECKING: + from dace.data import Data as dData T = TypeVar('T') @@ -123,10 +126,16 @@ def tmp_func(self): self._category = category if desc is not None and len(desc) > 0: self.__doc__ = desc - elif self.dtype is not None: - self.__doc__ = "Object property of type %s" % self.dtype.__name__ else: - self.__doc__ = "Object property of type %s" % type(self).__name__ + try: + dtype = self.dtype + if dtype is not None: + self.__doc__ = "Object property of type %s" % dtype.__name__ + else: + self.__doc__ = "Object property of type %s" % type(self).__name__ + except (ImportError, AttributeError): + # Handle circular import case - defer docstring generation + self.__doc__ = "Object property of type %s" % type(self).__name__ def __get__(self, obj, objtype=None) -> T: if obj is None: @@ -887,7 +896,7 @@ def to_json(self, obj): return LambdaProperty.to_string(obj) def from_json(self, s, sdfg=None): - if s == None: return None + if s is None: return None return LambdaProperty.from_string(s) def __set__(self, obj, val): @@ -1343,7 +1352,7 @@ def from_json(obj, context=None): class NestedDataClassProperty(Property): """ Custom property type for nested data. """ - def __get__(self, obj, objtype=None) -> 'Data': + def __get__(self, obj, objtype=None) -> 'dData': return super().__get__(obj, objtype) @property diff --git a/dace/registry.py b/dace/registry.py index 08efeb65ed..d2218b3101 100644 --- a/dace/registry.py +++ b/dace/registry.py @@ -3,10 +3,12 @@ subclasses and values can be registered externally. """ import aenum -from typing import Dict, Type +from typing import Dict, Type, TypeVar +T = TypeVar('T') -def make_registry(cls: Type): + +def make_registry(cls: Type[T]) -> Type[T]: """ Decorator that turns a class into a user-extensible class with three class methods: ``register``, ``unregister``, and ``extensions``. diff --git a/dace/runtime/include/dace/cuda/copy.cuh b/dace/runtime/include/dace/cuda/copy.cuh index 14018b74d4..706e63c9bd 100644 --- a/dace/runtime/include/dace/cuda/copy.cuh +++ b/dace/runtime/include/dace/cuda/copy.cuh @@ -769,14 +769,15 @@ namespace dace #pragma unroll for (int i = 0; i < WRITES; ++i) { - wcr_custom::template reduce( - wcr, ptr + (ltid + i * BLOCK_SIZE) * dst_xstride, + const auto __dace__reduction_lambda = {wcr}; + wcr_custom::template reduce( + __dace__reduction_lambda, ptr + (ltid + i * BLOCK_SIZE) * dst_xstride, *(smem + (ltid + i * BLOCK_SIZE) * src_xstride)); } if (REM_WRITES != 0) { if (ltid < REM_WRITES) - wcr_custom::template reduce( + wcr_custom::reduce( ptr + (ltid + WRITES * BLOCK_SIZE)* dst_xstride, *(smem + (ltid + WRITES * BLOCK_SIZE) * src_xstride)); } @@ -793,14 +794,14 @@ namespace dace #pragma unroll for (int i = 0; i < WRITES; ++i) { - wcr_fixed::template reduce_atomic( + wcr_fixed::reduce_atomic( ptr + (ltid + i * BLOCK_SIZE) * dst_xstride, *(smem + (ltid + i * BLOCK_SIZE) * src_xstride)); } if (REM_WRITES != 0) { if (ltid < REM_WRITES) - wcr_fixed::template reduce_atomic( + wcr_fixed::reduce_atomic( ptr + (ltid + WRITES*BLOCK_SIZE)* dst_xstride, *(smem + (ltid + WRITES * BLOCK_SIZE) * src_xstride)); } diff --git a/dace/runtime/include/dace/cuda/multidim_gbar.cuh b/dace/runtime/include/dace/cuda/multidim_gbar.cuh index 599fe7edb5..55e119daa2 100644 --- a/dace/runtime/include/dace/cuda/multidim_gbar.cuh +++ b/dace/runtime/include/dace/cuda/multidim_gbar.cuh @@ -87,7 +87,11 @@ public: // Threadfence and syncthreads to make sure global writes are visible before // thread-0 reports in with its sync counter __threadfence(); + #if __CUDACC_VER_MAJOR__ >= 13 + __syncthreads(); + #else CTA_SYNC(); + #endif int linear_tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x; int linear_blockid = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.y * gridDim.x; @@ -102,7 +106,11 @@ public: d_vol_sync[linear_blockid] = 1; } + #if __CUDACC_VER_MAJOR__ >= 13 + __syncthreads(); + #else CTA_SYNC(); + #endif // Wait for everyone else to report in for (int peer_block = linear_tid; peer_block < grid; peer_block += block) @@ -113,7 +121,11 @@ public: } } + #if __CUDACC_VER_MAJOR__ >= 13 + __syncthreads(); + #else CTA_SYNC(); + #endif // Let everyone know it's safe to proceed for (int peer_block = linear_tid; peer_block < grid; peer_block += block) @@ -135,7 +147,11 @@ public: } } + #if __CUDACC_VER_MAJOR__ >= 13 + __syncthreads(); + #else CTA_SYNC(); + #endif } } }; diff --git a/dace/runtime/include/dace/dace.h b/dace/runtime/include/dace/dace.h index 960aece94c..d6b8a1cf57 100644 --- a/dace/runtime/include/dace/dace.h +++ b/dace/runtime/include/dace/dace.h @@ -33,14 +33,4 @@ #include "cudainterop.h" #endif -#ifdef DACE_XILINX -#include "xilinx/host.h" -#endif - -#ifdef DACE_INTELFPGA -#include "intel_fpga/host.h" -#endif - -#include "fpga_common.h" - #endif // __DACE_RUNTIME_H diff --git a/dace/runtime/include/dace/fpga_common.h b/dace/runtime/include/dace/fpga_common.h deleted file mode 100644 index f3aba7b0b9..0000000000 --- a/dace/runtime/include/dace/fpga_common.h +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -#pragma once - -// Defined as a struct rather than a class for C compatibility with OpenCL -// For definition, see fpga_host.h -struct dace_fpga_context; diff --git a/dace/runtime/include/dace/fpga_device.h b/dace/runtime/include/dace/fpga_device.h deleted file mode 100644 index f7cb59fba2..0000000000 --- a/dace/runtime/include/dace/fpga_device.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -#pragma once - -#ifdef DACE_XILINX -#include "xilinx/device.h" -#endif - -#ifdef DACE_INTELFPGA -#include "intel_fpga/device.h" -#endif - -// Defined as a struct rather than a class for C compatibility with OpenCL -// For definition, see fpga_host.h -struct dace_fpga_context; diff --git a/dace/runtime/include/dace/fpga_host.h b/dace/runtime/include/dace/fpga_host.h deleted file mode 100644 index 82ebe2fd99..0000000000 --- a/dace/runtime/include/dace/fpga_host.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -#pragma once - -#if !defined(DACE_XILINX) && !defined(DACE_INTELFPGA) -#error "Either DACE_XILINX or DACE_INTELFPGA must be defined." -#endif - -#include -#include - -struct dace_fpga_context { - dace_fpga_context() = default; - ~dace_fpga_context() = default; - - inline hlslib::ocl::Context &Get(int device_id = 0) { - auto c = contexts_.find(device_id); - if (c != contexts_.end()) { - return c->second; - } else { - contexts_.emplace(device_id, device_id); - return contexts_.at(device_id); - } - } - - private: - // Don't allow copying or moving - dace_fpga_context(dace_fpga_context const &) = delete; - dace_fpga_context(dace_fpga_context &&) = delete; - dace_fpga_context &operator=(dace_fpga_context const &) = delete; - dace_fpga_context &operator=(dace_fpga_context &&) = delete; - - std::unordered_map contexts_; -}; diff --git a/dace/runtime/include/dace/intel_fpga/device.h b/dace/runtime/include/dace/intel_fpga/device.h deleted file mode 100644 index 187c5c7f36..0000000000 --- a/dace/runtime/include/dace/intel_fpga/device.h +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -#pragma once - -#pragma OPENCL EXTENSION cl_intel_channels : enable - -#include "dace/intel_fpga/math.h" diff --git a/dace/runtime/include/dace/intel_fpga/host.h b/dace/runtime/include/dace/intel_fpga/host.h deleted file mode 100644 index cb4af4b7ff..0000000000 --- a/dace/runtime/include/dace/intel_fpga/host.h +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -#pragma once - -#include // For concurrent kernel launches - -#include "hlslib/intel/OpenCL.h" - -#include // Must be included after hlslib/intel/OpenCL.h -#include -#include -#include diff --git a/dace/runtime/include/dace/intel_fpga/math.h b/dace/runtime/include/dace/intel_fpga/math.h deleted file mode 100644 index c03835a92e..0000000000 --- a/dace/runtime/include/dace/intel_fpga/math.h +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -/** - This file contains a set of preprocessor macros useful - for simple arithmetic -*/ - -#pragma once - -#define int_ceil(N,D) ((int)((N+D-1)/D)) -#define int_floor(N,D) (1+(int)((N-1)/D)) -#define Min min -#define Max max -#define Abs abs diff --git a/dace/runtime/include/dace/reduction.h b/dace/runtime/include/dace/reduction.h index d4201168a6..8fbd2a2f09 100644 --- a/dace/runtime/include/dace/reduction.h +++ b/dace/runtime/include/dace/reduction.h @@ -11,6 +11,8 @@ #ifdef __CUDACC__ #if __has_include() #include + #include + #include #else #include "../../../external/cub/cub/device/device_segmented_reduce.cuh" #include "../../../external/cub/cub/device/device_reduce.cuh" @@ -599,9 +601,19 @@ namespace dace { }; inline auto stridedIterator(size_t stride) { - cub::CountingInputIterator counting_iterator(0); + #if __CUDACC_VER_MAJOR__ >= 13 + thrust::counting_iterator + #else + cub::CountingInputIterator + #endif + counting_iterator(0); StridedIteratorHelper conversion_op(stride); + #if __CUDACC_VER_MAJOR__ >= 13 + thrust::transform_iterator itr(counting_iterator, conversion_op); + #else cub::TransformInputIterator itr(counting_iterator, conversion_op); + #endif + return itr; } #endif diff --git a/dace/runtime/include/dace/vector.h b/dace/runtime/include/dace/vector.h index 58a9242259..fb0378ebff 100644 --- a/dace/runtime/include/dace/vector.h +++ b/dace/runtime/include/dace/vector.h @@ -1,10 +1,11 @@ -// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +// Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. #ifndef __DACE_VECTOR_H #define __DACE_VECTOR_H #ifdef DACE_XILINX -#include "xilinx/vec.h" -#else // Don't include this file if building for Xilinx +#include +// Don't include the code below if building for Xilinx +#else #include "types.h" @@ -329,5 +330,5 @@ namespace dace } -#endif // XILINX_DEVICE_CODE +#endif // DACE_XILINX #endif // __DACE_VECTOR_H diff --git a/dace/runtime/include/dace/xilinx/access.h b/dace/runtime/include/dace/xilinx/access.h deleted file mode 100644 index d22a71e614..0000000000 --- a/dace/runtime/include/dace/xilinx/access.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -#pragma once - -#include "dace/xilinx/vec.h" -#include "dace/xilinx/stream.h" - -namespace dace { - -template -vec Read(vec const *ptr) { - #pragma HLS INLINE - return *ptr; -} - -template -vec Read(vec const &ref) { - #pragma HLS INLINE - return ref; -} - -template -void Write(vec *ptr, vec const &value) { - #pragma HLS INLINE - *ptr = value; -} - -template -void Write(vec &ref, vec const &value) { - #pragma HLS INLINE - ref = value; -} - -template -vec Pack(T const *const ptr) { - #pragma HLS INLINE - vec val; - for (unsigned i = 0; i < vector_length; ++i) { - #pragma HLS UNROLL - val[i] = ptr[i]; - } - return val; -} - -template -void Unpack(vec const &val, T *const ptr) { - #pragma HLS INLINE - for (unsigned i = 0; i < vector_length; ++i) { - #pragma HLS UNROLL - ptr[i] = val[i]; - } -} - -} // End namespace dace diff --git a/dace/runtime/include/dace/xilinx/device.h b/dace/runtime/include/dace/xilinx/device.h deleted file mode 100644 index 05ccefdab7..0000000000 --- a/dace/runtime/include/dace/xilinx/device.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -#pragma once - -#include "hlslib/xilinx/Simulation.h" -#include "hlslib/xilinx/Utility.h" -#include "hlslib/xilinx/ShiftRegister.h" - -#include "dace/copy.h" -#include "dace/types.h" -#include "dace/pyinterop.h" - -#include "dace/xilinx/reduce.h" -#include "dace/xilinx/stream.h" -#include "dace/xilinx/vec.h" -#include "dace/xilinx/access.h" - -#include "dace/xilinx/math.h" diff --git a/dace/runtime/include/dace/xilinx/host.h b/dace/runtime/include/dace/xilinx/host.h deleted file mode 100644 index de27734f95..0000000000 --- a/dace/runtime/include/dace/xilinx/host.h +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -#pragma once - -#include // For concurrent kernel launches - -#include "hlslib/xilinx/OpenCL.h" - -#include // Must be included after hlslib/xilinx/OpenCL.h -#include -#include diff --git a/dace/runtime/include/dace/xilinx/math.h b/dace/runtime/include/dace/xilinx/math.h deleted file mode 100644 index 9bbca29117..0000000000 --- a/dace/runtime/include/dace/xilinx/math.h +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -/** - Support for additional math operators on Xilinx -*/ - -#pragma once - -// fabs support for xilinx -template -DACE_HDFI hlslib::DataPack fabs(const hlslib::DataPack& a) { - hlslib::DataPack res; - for (int i = 0; i < vector_length; ++i) { - #pragma HLS UNROLL - const auto elem = a[i]; - res[i] = elem < 0 ? -elem : elem; - } - return res; -} diff --git a/dace/runtime/include/dace/xilinx/reduce.h b/dace/runtime/include/dace/xilinx/reduce.h deleted file mode 100644 index f0275fb84d..0000000000 --- a/dace/runtime/include/dace/xilinx/reduce.h +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -#pragma once - -#include -#include -#include - -#include "hlslib/xilinx/Operators.h" -#include "hlslib/xilinx/TreeReduce.h" - -#include "dace/types.h" -#include "dace/xilinx/access.h" - -namespace dace { - -//////////////////////////////////////////////////////////////////////////////// -// Conversion from DACE reduction types to hlslib types -//////////////////////////////////////////////////////////////////////////////// - -template -struct ConvertReduction; - -template <> -struct ConvertReduction { - template - using Operator = hlslib::op::Min; -}; - -template <> -struct ConvertReduction { - template - using Operator = hlslib::op::Max; -}; - -template <> -struct ConvertReduction { - template - using Operator = hlslib::op::Sum; -}; - -template <> -struct ConvertReduction { - template - using Operator = hlslib::op::Product; -}; - -template <> -struct ConvertReduction { - template - using Operator = hlslib::op::And; -}; - -//////////////////////////////////////////////////////////////////////////////// -// Helper functions/template implementation -// (Actual implementation is at the bottom of the file.) -//////////////////////////////////////////////////////////////////////////////// - -namespace { - -template -struct IsRandomAccess { - static constexpr bool value = false; -}; - -template -struct IsRandomAccess< - T, typename std::enable_if::value_type, void>::value>::type> { - static constexpr bool value = true; -}; - -// Vector to a scalar, call tree reduction -template -typename std::enable_if<(IsRandomAccess::value && - !IsRandomAccess::value && W_out - W_in < 0), - T_out>::type -ReduceImpl(T_in &&a, T_out &&b) { -#pragma HLS INLINE - static_assert(W_out != 1, - "Vector reduction only supported for output length 1."); - const auto a_reduced = hlslib::TreeReduce(a); - return Functor::Apply(a_reduced, b[0]); -} - -// Vector to a scalar wrapped in a vector type, call tree reduction -template -typename std::enable_if<(IsRandomAccess::value && - IsRandomAccess::value && W_out - W_in < 0), - T_out>::type -ReduceImpl(T_in &&a, T_out &&b) { -#pragma HLS INLINE - static_assert(W_out != 1, - "Vector reduction only supported for output length 1."); - const auto a_reduced = hlslib::TreeReduce(a); - typename std::remove_reference::type result; - result[0] = Functor::Apply(a_reduced, b[0]); - return result; -} - -// Between two scalars -template -typename std::enable_if<(!IsRandomAccess::value && - !IsRandomAccess::value && W_in == 1 && - W_out == 1), - typename std::remove_reference::type>::type -ReduceImpl(T_in &&a, T_out &&b) { - #pragma HLS INLINE - return Functor::Apply(std::forward(a), std::forward(b)); -} - -// Between two scalars wrapped in vector types -template -typename std::enable_if<(IsRandomAccess::value && - IsRandomAccess::value && W_in == 1 && - W_out == 1), - T_out>::type -ReduceImpl(T_in &&a, T_out &&b) { - #pragma HLS INLINE - typename std::remove_reference::type result; - result[0] = Functor::Apply(a[0], b[0]); - return result; -} - -// Vector-to-vector, apply the reduction on every index -template -typename std::enable_if<(IsRandomAccess() && IsRandomAccess() && - W_in > 1 && W_out > 1), - T_out>::type -ReduceImpl(T_in &&a, T_out &&b) { - #pragma HLS INLINE - return hlslib::op::Wide(std::forward(a), - std::forward(b)); -} - -} // End anonymous namespace - -//////////////////////////////////////////////////////////////////////////////// -// Function exposed to DACE -//////////////////////////////////////////////////////////////////////////////// - -template -T Reduce(T_in &&a, T_out &&b) { - #pragma HLS INLINE - static_assert(W_out <= W_in, - "Output vector length must be shorter or identical to input " - "vector length."); - return ReduceImpl(Read(a), - Read(b)); -} - -template -struct xilinx_wcr_fixed { - static inline T reduce(T *ptr, const T &value) { - #pragma HLS INLINE - using Functor = - typename ConvertReduction::template Operator; - T old_val = *ptr; - *ptr = Reduce(old_val, value); - return old_val; - } -}; - -// Specialization for vector types -template -struct xilinx_wcr_fixed_vec { - static inline vec reduce(vec *ptr, - const vec &value) { - #pragma HLS INLINE - using Functor = - typename ConvertReduction::template Operator; - vec old_val = *ptr; - *ptr = Reduce(old_val, value); - return old_val; - } -}; - -} // End namespace dace diff --git a/dace/runtime/include/dace/xilinx/stream.h b/dace/runtime/include/dace/xilinx/stream.h deleted file mode 100644 index 0036ea43f5..0000000000 --- a/dace/runtime/include/dace/xilinx/stream.h +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -#pragma once - -#include "hlslib/xilinx/Stream.h" -#include "dace/xilinx/vec.h" -#ifndef DACE_SYNTHESIS -#include // std::to_string -#endif - -namespace dace { - -/// Proxy class that wraps hlslib::Stream in a dace::Stream-compatible -/// interface. -template -class FIFO { - public: - FIFO() : stream_() { - #pragma HLS INLINE - } - FIFO(char const *const name) : stream_(name) { - #pragma HLS INLINE - } - FIFO(FIFO const&) = delete; - FIFO(FIFO&&) = delete; - FIFO& operator=(FIFO const&) = delete; - FIFO& operator=(FIFO&&) = delete; - ~FIFO() = default; - - using Data_t = dace::vec; - - Data_t pop_blocking() { - #pragma HLS INLINE - return stream_.ReadBlocking(); - } - - Data_t pop() { - #pragma HLS INLINE - return pop_blocking(); - } - - bool pop_try(Data_t& output) { - #pragma HLS INLINE - return stream_.ReadNonBlocking(output); - } - - template - void push_blocking(U&& val) { - #pragma HLS INLINE - stream_.WriteBlocking(std::forward(val)); - } - - template - void push(U&& val) { - #pragma HLS INLINE - push_blocking(val); - } - - // ArrayView-compatible interface - - template - void write(U&& val) { - #pragma HLS INLINE - push(std::forward(val)); - } - - template - void operator=(U&& val) { - #pragma HLS INLINE - push(std::forward(val)); - } - - operator Data_t() { - #pragma HLS INLINE - return pop_blocking(); - } - -#ifndef DACE_SYNTHESIS - void SetName(std::string const &str) { - stream_.set_name(str.c_str()); - } -#endif - - private: - hlslib::Stream stream_; -}; - -template -void SetNames(FIFO fifos[], char const *const str, - const unsigned num) { - #pragma HLS INLINE -#ifndef DACE_SYNTHESIS - for (unsigned i = 0; i < num; ++i) { - fifos[i].SetName(std::string(str) + "[" + std::to_string(i) + "]"); - } -#endif -} - -} // End namespace dace diff --git a/dace/runtime/include/dace/xilinx/vec.h b/dace/runtime/include/dace/xilinx/vec.h deleted file mode 100644 index 129b26d88a..0000000000 --- a/dace/runtime/include/dace/xilinx/vec.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -#pragma once - -#include "hlslib/xilinx/DataPack.h" -#include - -namespace dace { - -template -using vec = - typename std::conditional<(width > 1), hlslib::DataPack, T>::type; - -// Don't distinguish aligned and unaligned on FPGA -template -using vecu = vec; - -} // End namespace dace diff --git a/dace/sdfg/analysis/cutout.py b/dace/sdfg/analysis/cutout.py index 432c765aa0..ea754a2cbc 100644 --- a/dace/sdfg/analysis/cutout.py +++ b/dace/sdfg/analysis/cutout.py @@ -8,7 +8,8 @@ import sympy as sp from collections import deque import copy -from typing import Deque, Dict, List, Set, Tuple, Union, Optional, Any +from typing import Any, Deque, Dict, List, Set, Tuple, Union, Optional +from numbers import Number from dace import data, DataInstrumentationType from dace.sdfg import nodes as nd, SDFG, SDFGState, utils as sdutil, InterstateEdge from dace.memlet import Memlet @@ -19,6 +20,11 @@ from dace.transformation.interstate.loop_detection import DetectLoop from dace.transformation.passes.analysis import StateReachability +try: + from numpy.typing import ArrayLike +except ImportError: + ArrayLike = Any # type: ignore + class SDFGCutout(SDFG): @@ -52,12 +58,12 @@ def _dry_run_base_sdfg(self, *args, **kwargs) -> None: self._instrument_base_sdfg() self._base_sdfg(*args, **kwargs) - def find_inputs(self, *args, **kwargs) -> Dict[str, Union[data.ArrayLike, data.Number]]: + def find_inputs(self, *args, **kwargs) -> Dict[str, Union[ArrayLike, Number]]: self._dry_run_base_sdfg(*args, **kwargs) drep = self._base_sdfg.get_instrumented_data() if drep: - vals: Dict[str, Union[data.ArrayLike, data.Number]] = dict() + vals: Dict[str, Union[ArrayLike, Number]] = dict() for ip in self.input_config.union(set(self.symbols)): val = drep.get_first_version(ip) vals[ip] = val @@ -697,7 +703,7 @@ def _reduce_in_configuration(state: SDFGState, # If there is no unique outer entry node, we use a proxy node as the source. scope_nodes: Set[nd.Node] = set() - if source == None: + if source is None: source = nd.Node() scope_nodes = set(scope_children[None]) else: diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 3ffcf1cb08..46eb37cdb2 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -23,7 +23,6 @@ NODE_TO_SCOPE_TYPE = { dace.nodes.MapEntry: tn.MapScope, dace.nodes.ConsumeEntry: tn.ConsumeScope, - dace.nodes.PipelineEntry: tn.PipelineScope, } @@ -731,7 +730,7 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) Converts an SDFG into a schedule tree. The schedule tree is a tree of nodes that represent the execution order of the SDFG. Each node in the tree can either represent a single statement (symbol assignment, tasklet, copy, library node, etc.) - or a ``ScheduleTreeScope`` block (map, for-loop, pipeline, etc.) that contains other nodes. + or a ``ScheduleTreeScope`` block (map, for-loop, etc.) that contains other nodes. It can be used to generate code from an SDFG, or to perform schedule transformations on the SDFG. For example, erasing an empty if branch, or merging two consecutive for-loops. The SDFG can then be reconstructed via the diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 0471e0b75d..0d709ee0fd 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -1,4 +1,4 @@ -# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. from dataclasses import dataclass, field from dace import nodes, data, subsets @@ -262,18 +262,6 @@ def as_string(self, indent: int = 0): return result + super().as_string(indent) -@dataclass -class PipelineScope(DataflowScope): - """ - Pipeline scope. - """ - - def as_string(self, indent: int = 0): - rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) - result = indent * INDENTATION + f'pipeline {", ".join(self.node.map.params)} in [{rangestr}]:\n' - return result + super().as_string(indent) - - @dataclass class TaskletNode(ScheduleTreeNode): node: nodes.Tasklet diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 31ab055b48..c790f9411d 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -1,4 +1,4 @@ -# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. """ Contains classes implementing the different types of nodes of the stateful dataflow multigraph representation. """ @@ -20,6 +20,7 @@ from dace.symbolic import issymbolic, pystr_to_symbolic from dace import data, subsets as sbs, dtypes from dace.sdfg import tasklet_validation as tval +from dace.sdfg.type_inference import infer_types, infer_expr_type import pydoc import warnings @@ -538,9 +539,6 @@ def infer_connector_types(self, sdfg, state): raise TypeError('Cannot infer output connectors of tasklet "%s", ' 'not all input connectors have types' % str(self)) - # Avoid import loop - from dace.codegen.tools.type_inference import infer_types - # Get symbols defined at beginning of node, and infer all types in # tasklet syms = state.symbols_defined_at(self) @@ -870,8 +868,6 @@ def free_symbols(self) -> Set[str]: return set(k for k in self._map.range.free_symbols if k not in dyn_inputs) def new_symbols(self, sdfg, state, symbols) -> Dict[str, dtypes.typeclass]: - from dace.codegen.tools.type_inference import infer_expr_type - result = {} # Add map params for p, rng in zip(self._map.params, self._map.range): @@ -1184,8 +1180,6 @@ def free_symbols(self) -> Set[str]: return result - dyn_inputs def new_symbols(self, sdfg, state, symbols) -> Dict[str, dtypes.typeclass]: - from dace.codegen.tools.type_inference import infer_expr_type - result = {} # Add PE index result[self._consume.pe_index] = infer_expr_type(self._consume.num_pes, symbols) @@ -1317,120 +1311,6 @@ def get_param_num(self): # ------------------------------------------------------------------------------ -@dace.serialize.serializable -class PipelineEntry(MapEntry): - - @staticmethod - def map_type(): - return PipelineScope - - @property - def pipeline(self): - return self._map - - @pipeline.setter - def pipeline(self, val): - self._map = val - - def new_symbols(self, sdfg, state, symbols) -> Dict[str, dtypes.typeclass]: - result = super().new_symbols(sdfg, state, symbols) - for param in self.map.params: - result[param] = dtypes.int64 # Overwrite params from Map - for param in self.pipeline.additional_iterators: - result[param] = dtypes.int64 - result[self.pipeline.iterator_str()] = dtypes.int64 - try: - result[self.pipeline.init_condition()] = dtypes.bool - except ValueError: - pass # Overlaps - try: - result[self.pipeline.drain_condition()] = dtypes.bool - except ValueError: - pass # Overlaps - return result - - -@dace.serialize.serializable -class PipelineExit(MapExit): - - @staticmethod - def map_type(): - return PipelineScope - - @property - def pipeline(self): - return self._map - - @pipeline.setter - def pipeline(self, val): - self._map = val - - -@make_properties -class PipelineScope(Map): - """ This a convenience-subclass of Map that allows easier implementation of - loop nests (using regular Map indices) that need a constant-sized - initialization and drain phase (e.g., N*M + c iterations), which would - otherwise need a flattened one-dimensional map. - """ - init_size = SymbolicProperty(default=0, desc="Number of initialization iterations.") - init_overlap = Property(dtype=bool, - default=True, - desc="Whether to increment regular map indices during initialization.") - drain_size = SymbolicProperty(default=1, desc="Number of drain iterations.") - drain_overlap = Property(dtype=bool, - default=True, - desc="Whether to increment regular map indices during pipeline drain.") - additional_iterators = Property(dtype=dict, desc="Additional iterators, managed by the user inside the scope.") - - def __init__(self, - *args, - init_size=0, - init_overlap=False, - drain_size=0, - drain_overlap=False, - additional_iterators={}, - **kwargs): - super(PipelineScope, self).__init__(*args, **kwargs) - self.init_size = init_size - self.init_overlap = init_overlap - self.drain_size = drain_size - self.drain_overlap = drain_overlap - self.additional_iterators = additional_iterators - - def iterator_str(self): - return "__" + "".join(self.params) - - def loop_bound_str(self): - from dace.codegen.common import sym2cpp - bound = 1 - for begin, end, step in self.range: - bound *= (step + end - begin) // step - # Add init and drain phases when relevant - add_str = (" + " + sym2cpp(self.init_size) if self.init_size != 0 and not self.init_overlap else "") - add_str += (" + " + sym2cpp(self.drain_size) if self.drain_size != 0 and not self.drain_overlap else "") - return sym2cpp(bound) + add_str - - def init_condition(self): - """Variable that can be checked to see if pipeline is currently in - initialization phase.""" - if self.init_size == 0: - raise ValueError("No init condition exists for " + self.label) - return self.iterator_str() + "_init" - - def drain_condition(self): - """Variable that can be checked to see if pipeline is currently in - draining phase.""" - if self.drain_size == 0: - raise ValueError("No drain condition exists for " + self.label) - return self.iterator_str() + "_drain" - - -PipelineEntry = indirect_properties(PipelineScope, lambda obj: obj.map)(PipelineEntry) - -# ------------------------------------------------------------------------------ - - # Based on https://stackoverflow.com/a/2020083/6489142 def full_class_path(cls_or_obj: Union[type, object]): if isinstance(cls_or_obj, type): diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 44a085603d..5c9061a3bb 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. import ast import collections import copy @@ -27,6 +27,7 @@ from dace.frontend.python import astutils from dace.sdfg import nodes as nd from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, SDFGState, ControlFlowRegion +from dace.sdfg.type_inference import infer_expr_type from dace.distr_types import ProcessGrid, SubArray, RedistrArray from dace.dtypes import validate_name from dace.properties import (DebugInfoProperty, EnumProperty, ListProperty, make_properties, Property, CodeProperty, @@ -262,7 +263,7 @@ def used_symbols(self, all_symbols: bool = False, union_lhs_symbols: bool = Fals # - condition = 'i < 10', assignments = {'i': '3'} # - assignments = {'j': 'i + 1', 'i': '3'} # The new algorithm below addresses the issue by iterating over the edge's condition and assignments and - # exlcuding keys from being considered "defined" if they have been already read. + # excluding keys from being considered "defined" if they have been already read. # Symbols in conditions are always free, because the condition is executed before the assignments cond_symbols = set(map(str, dace.symbolic.symbols_in_ast(self.condition.code[0]))) @@ -360,7 +361,6 @@ def new_symbols(self, sdfg, symbols) -> Dict[str, dtypes.typeclass]: Returns a mapping between symbols defined by this edge (i.e., assignments) to their type. """ - from dace.codegen.tools.type_inference import infer_expr_type if sdfg is not None: alltypes = copy.copy(symbols) @@ -864,7 +864,7 @@ def set_global_code(self, cpp_code: str, location: str = 'frame'): :param cpp_code: The code to set. :param location: The file/backend in which to generate the code. Options are None (all files), "frame", "openmp", - "cuda", "xilinx", "intel_fpga", or any code generator + "cuda", or any code generator name. """ self.global_code[location] = CodeBlock(cpp_code, dace.dtypes.Language.CPP) @@ -877,7 +877,7 @@ def set_init_code(self, cpp_code: str, location: str = 'frame'): :param cpp_code: The code to set. :param location: The file/backend in which to generate the code. Options are None (all files), "frame", "openmp", - "cuda", "xilinx", "intel_fpga", or any code generator + "cuda", or any code generator name. """ self.init_code[location] = CodeBlock(cpp_code, dtypes.Language.CPP) @@ -890,7 +890,7 @@ def set_exit_code(self, cpp_code: str, location: str = 'frame'): :param cpp_code: The code to set. :param location: The file/backend in which to generate the code. Options are None (all files), "frame", "openmp", - "cuda", "xilinx", "intel_fpga", or any code generator + "cuda", or any code generator name. """ self.exit_code[location] = CodeBlock(cpp_code, dtypes.Language.CPP) @@ -903,7 +903,7 @@ def append_global_code(self, cpp_code: str, location: str = 'frame'): :param cpp_code: The code to set. :param location: The file/backend in which to generate the code. Options are None (all files), "frame", "openmp", - "cuda", "xilinx", "intel_fpga", or any code generator + "cuda", or any code generator name. """ if location not in self.global_code: @@ -918,7 +918,7 @@ def append_init_code(self, cpp_code: str, location: str = 'frame'): :param cpp_code: The code to append. :param location: The file/backend in which to generate the code. Options are None (all files), "frame", "openmp", - "cuda", "xilinx", "intel_fpga", or any code generator + "cuda", or any code generator name. """ if location not in self.init_code: @@ -933,7 +933,7 @@ def append_exit_code(self, cpp_code: str, location: str = 'frame'): :param cpp_code: The code to append. :param location: The file/backend in which to generate the code. Options are None (all files), "frame", "openmp", - "cuda", "xilinx", "intel_fpga", or any code generator + "cuda", or any code generator name. """ if location not in self.exit_code: @@ -948,7 +948,7 @@ def prepend_exit_code(self, cpp_code: str, location: str = 'frame'): :param cpp_code: The code to prepend. :param location: The file/backend in which to generate the code. Options are None (all files), "frame", "openmp", - "cuda", "xilinx", "intel_fpga", or any code generator + "cuda", or any code generator name. """ if location not in self.exit_code: @@ -1030,8 +1030,12 @@ def get_latest_report_path(self) -> Optional[str]: :return: A path to the latest instrumentation report, or None if one does not exist. """ path = os.path.join(self.build_folder, 'perf') - files = [f for f in os.listdir(path) if f.startswith('report-')] - if len(files) == 0: + try: + files = [f for f in os.listdir(path) if f.startswith('report-')] + except FileNotFoundError: + return None + + if not files: return None return os.path.join(path, sorted(files, reverse=True)[0]) @@ -1097,7 +1101,7 @@ def clear_data_reports(self): def call_with_instrumented_data(self, dreport: 'InstrumentedDataReport', *args, **kwargs): """ Invokes an SDFG with an instrumented data report, generating and compiling code if necessary. - Arguments given as ``args`` and ``kwargs`` will be overriden by the data containers defined in the report. + Arguments given as ``args`` and ``kwargs`` will be overridden by the data containers defined in the report. :param dreport: The instrumented data report to use upon calling. :param args: Arguments to call SDFG with. @@ -1129,7 +1133,7 @@ def as_schedule_tree(self, in_place: bool = False) -> 'ScheduleTreeScope': Creates a schedule tree from this SDFG and all nested SDFGs. The schedule tree is a tree of nodes that represent the execution order of the SDFG. Each node in the tree can either represent a single statement (symbol assignment, tasklet, copy, library node, - etc.) or a ``ScheduleTreeScope`` block (map, for-loop, pipeline, etc.) that contains other nodes. + etc.) or a ``ScheduleTreeScope`` block (map, for-loop, etc.) that contains other nodes. It can be used to generate code from an SDFG, or to perform schedule transformations on the SDFG. For example, erasing an empty if branch, or merging two consecutive for-loops. @@ -2110,7 +2114,7 @@ def _add_symbols(sdfg: SDFG, desc: dt.Data): if isinstance(v, dt.Data): _add_symbols(sdfg, v) for sym in desc.free_symbols: - if sym.name not in sdfg.symbols: + if sym.name not in sdfg.symbols and sym.name not in sdfg.arg_names: sdfg.add_symbol(sym.name, sym.dtype) # Add the data descriptor to the SDFG and all symbols that are not yet known. @@ -2846,21 +2850,6 @@ def apply_gpu_transformations(self, permissive=permissive, states=states) - def apply_fpga_transformations(self, states=None, validate=True, validate_all=False, permissive=False): - """ Applies a series of transformations on the SDFG for it to - generate FPGA code. - - :note: This is an in-place operation on the SDFG. - """ - # Avoiding import loops - from dace.transformation.interstate import FPGATransformSDFG - - self.apply_transformations(FPGATransformSDFG, - validate=validate, - validate_all=validate_all, - permissive=permissive, - states=states) - def expand_library_nodes(self, recursive=True): """ Recursively expand all unexpanded library nodes in the SDFG, diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index d558053d3d..aeab8505b5 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -1,4 +1,4 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. """ Contains classes of a single SDFG state and dataflow subgraphs. """ import ast @@ -28,6 +28,7 @@ from dace.sdfg.graph import (MultiConnectorEdge, NodeNotFoundError, OrderedMultiDiConnectorGraph, SubgraphView, OrderedDiGraph, Edge, generate_element_id) from dace.sdfg.propagation import propagate_memlet +from dace.sdfg.type_inference import infer_expr_type from dace.sdfg.validation import validate_state from dace.subsets import Range, Subset @@ -1797,7 +1798,6 @@ def add_nested_sdfg( raise ValueError('Missing symbols on nested SDFG "%s": %s' % (name, missing_symbols)) # Add new global symbols to nested SDFG - from dace.codegen.tools.type_inference import infer_expr_type for sym, symval in s.symbol_mapping.items(): if sym not in sdfg.symbols: # TODO: Think of a better way to avoid calling @@ -2059,56 +2059,6 @@ def add_reduce( self.add_node(result) return result - def add_pipeline(self, - name, - ndrange, - init_size=0, - init_overlap=False, - drain_size=0, - drain_overlap=False, - additional_iterators={}, - schedule=dtypes.ScheduleType.FPGA_Device, - debuginfo=None, - **kwargs) -> Tuple[nd.PipelineEntry, nd.PipelineExit]: - """ Adds a pipeline entry and pipeline exit. These are used for FPGA - kernels to induce distinct behavior between an "initialization" - phase, a main streaming phase, and a "draining" phase, which require - a additive number of extra loop iterations (i.e., N*M + I + D), - where I and D are the number of initialization/drain iterations. - The code can detect which phase it is in by querying the - init_condition() and drain_condition() boolean variable. - - :param name: Pipeline label - :param ndrange: Mapping between range variable names and - their subsets (parsed from strings) - :param init_size: Number of iterations of initialization phase. - :param init_overlap: Whether the initialization phase overlaps - with the "main" streaming phase of the loop. - :param drain_size: Number of iterations of draining phase. - :param drain_overlap: Whether the draining phase overlaps with - the "main" streaming phase of the loop. - :param additional_iterators: a dictionary containing additional - iterators that will be created for this scope and that are not - automatically managed by the scope code. - The dictionary takes the form 'variable_name' -> init_value - :return: (map_entry, map_exit) node 2-tuple - """ - debuginfo = _getdebuginfo(debuginfo or self._default_lineinfo) - pipeline = nd.PipelineScope(name, - *_make_iterators(ndrange), - init_size=init_size, - init_overlap=init_overlap, - drain_size=drain_size, - drain_overlap=drain_overlap, - additional_iterators=additional_iterators, - schedule=schedule, - debuginfo=debuginfo, - **kwargs) - pipeline_entry = nd.PipelineEntry(pipeline) - pipeline_exit = nd.PipelineExit(pipeline) - self.add_nodes_from([pipeline_entry, pipeline_exit]) - return pipeline_entry, pipeline_exit - def add_edge_pair( self, scope_node, @@ -3230,6 +3180,12 @@ class LoopRegion(ControlFlowRegion): 'do-while style into a while(true) with a break before the update (at the end ' + 'of an iteration) if the condition no longer holds.') loop_variable = Property(dtype=str, default='', desc='The loop variable, if given') + unroll = Property(dtype=bool, + default=False, + desc='If True, indicates that this loop should be unrolled during code generation.') + unroll_factor = Property(dtype=int, + default=1, + desc='If unrolling is enabled, the factor by which to unroll the loop.') def __init__(self, label: str, @@ -3239,7 +3195,9 @@ def __init__(self, update_expr: Optional[Union[str, CodeBlock]] = None, inverted: bool = False, sdfg: Optional['SDFG'] = None, - update_before_condition=True): + update_before_condition=True, + unroll: bool = False, + unroll_factor: int = 1): super(LoopRegion, self).__init__(label, sdfg) if initialize_expr is not None: @@ -3269,6 +3227,8 @@ def __init__(self, self.loop_variable = loop_var or '' self.inverted = inverted self.update_before_condition = update_before_condition + self.unroll = unroll + self.unroll_factor = unroll_factor def inline(self, lower_returns: bool = False) -> Tuple[bool, Any]: """ @@ -3610,7 +3570,6 @@ def _used_symbols_internal(self, def new_symbols(self, symbols) -> Dict[str, dtypes.typeclass]: # Avoid cyclic import - from dace.codegen.tools.type_inference import infer_expr_type from dace.transformation.passes.analysis import loop_analysis if self.init_statement and self.loop_variable: diff --git a/dace/codegen/tools/type_inference.py b/dace/sdfg/type_inference.py similarity index 92% rename from dace/codegen/tools/type_inference.py rename to dace/sdfg/type_inference.py index a753aa3703..1a141c9629 100644 --- a/dace/codegen/tools/type_inference.py +++ b/dace/sdfg/type_inference.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. """ Type inference: traverses code and returns types for all undefined symbols according to C semantics infer() has a lenient implementation: if something it not inferred (for example an unsupported construct) it will not @@ -11,13 +11,54 @@ import ast from dace import data, dtypes from dace import symbolic -from dace.codegen import cppunparse from dace.symbolic import symbol, SymExpr, symstr import sympy import sys import dace.frontend.python.astutils import inspect -from typing import Union +from typing import Callable, Union + +# Additional function names that can be used to infer types +KNOWN_FUNCTIONS: dict[str, Callable[[list[dtypes.typeclass]], dtypes.typeclass]] = { + 'abs': lambda arg_types: arg_types[0], + 'log': lambda arg_types: arg_types[0], + 'min': lambda arg_types: dtypes.result_type_of(arg_types[0], *arg_types), + 'max': lambda arg_types: dtypes.result_type_of(arg_types[0], *arg_types), + 'round': lambda arg_types: dtypes.typeclass(int), +} + +_cmpops = { + "Eq": "==", + "NotEq": "!=", + "Lt": "<", + "LtE": "<=", + "Gt": ">", + "GtE": ">=", + "Is": "==", + "IsNot": "!=", + # "In":"in", "NotIn":"not in" +} + +_funcops = { + "FloorDiv": (" /", "dace::math::ifloor"), + "MatMult": (",", "dace::gemm"), +} + +_py2c_reserved = { + "True": "true", + "False": "false", + "None": "nullptr", + "inf": "INFINITY", + "nan": "NAN", +} + +_py2c_typeconversion = { + "uint": dace.dtypes.typeclass(np.uint32), + "int": dace.dtypes.typeclass(int), + "float": dace.dtypes.typeclass(float), + "float64": dace.dtypes.typeclass(np.float64), + "str": dace.dtypes.pointer(dace.dtypes.int8), +} def infer_types(code, symbols=None): @@ -142,8 +183,8 @@ def _Assign(t, symbols, inferred_symbols): def _AugAssign(t, symbols, inferred_symbols): _dispatch(t.target, symbols, inferred_symbols) # Operations that require a function call - if t.op.__class__.__name__ in cppunparse.CPPUnparser.funcops: - separator, func = cppunparse.CPPUnparser.funcops[t.op.__class__.__name__] + if t.op.__class__.__name__ in _funcops: + separator, func = _funcops[t.op.__class__.__name__] if not t.target.id in symbols and not t.target.id in inferred_symbols: _dispatch(t.target, symbols, inferred_symbols) inferred_type = _dispatch(t.value, symbols, inferred_symbols) @@ -277,7 +318,7 @@ def _JoinedStr(t, symbols, inferred_symbols): def _Name(t, symbols, inferred_symbols): - if t.id in cppunparse._py2c_reserved: + if t.id in _py2c_reserved: return dtypes.typeclass(np.result_type(t.id)) else: # check if this name is a python type, it is in defined_symbols or in local symbols. @@ -286,8 +327,8 @@ def _Name(t, symbols, inferred_symbols): # if this is a statement generated from a tasklet with a dynamic memlet, it could have a leading * (pointer) t_id = t.id[1:] if t.id.startswith('*') else t.id - if t_id.strip("()") in cppunparse._py2c_typeconversion: - inferred_type = cppunparse._py2c_typeconversion[t_id.strip("()")] + if t_id.strip("()") in _py2c_typeconversion: + inferred_type = _py2c_typeconversion[t_id.strip("()")] elif t_id in symbols: # defined symbols could have dtypes, in case convert it to typeclass inferred_type = symbols[t_id] @@ -339,8 +380,8 @@ def _UnaryOp(t, symbols, inferred_symbols): def _BinOp(t, symbols, inferred_symbols): # Operations that require a function call - if t.op.__class__.__name__ in cppunparse.CPPUnparser.funcops: - separator, func = cppunparse.CPPUnparser.funcops[t.op.__class__.__name__] + if t.op.__class__.__name__ in _funcops: + separator, func = _funcops[t.op.__class__.__name__] # get the type of left and right operands for type inference type_left = _dispatch(t.left, symbols, inferred_symbols) @@ -382,7 +423,7 @@ def _Compare(t, symbols, inferred_symbols): if isinstance(inf_type, dtypes.vector): vec_len = inf_type.veclen for o, e in zip(t.ops, t.comparators): - if o.__class__.__name__ not in cppunparse.CPPUnparser.cmpops: + if o.__class__.__name__ not in _cmpops: continue if isinstance(e, ast.Constant) and e.value is None: continue @@ -455,16 +496,9 @@ def _Call(t, symbols, inferred_symbols): if module == 'math': return dtypes.result_type_of(arg_types[0], *arg_types) - # Reading from an Intel channel returns the channel type - if name == 'read_channel_intel': - return arg_types[0] - - if name in ('abs', 'log'): - return arg_types[0] - if name in ('min', 'max'): # binary math operations that do not exist in the math module - return dtypes.result_type_of(arg_types[0], *arg_types) - if name in ('round', ): - return dtypes.typeclass(int) + # Check in known functions + if name in KNOWN_FUNCTIONS: + return KNOWN_FUNCTIONS[name](arg_types) # dtypes (dace.int32, np.float64) can be used as functions inf_type = _infer_dtype(t) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index ad13aefd51..84660da9a6 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1099,6 +1099,8 @@ def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdg """ in_edges = state.in_edges(view) + # We should ignore empty synchronization edges + in_edges = [e for e in in_edges if not e.data.is_empty()] out_edges = state.out_edges(view) # Invalid case: No data to view @@ -2689,3 +2691,87 @@ def specialize_scalar(sdfg: 'dace.SDFG', scalar_name: str, scalar_val: Union[flo assert isinstance(scalar_name, str) assert isinstance(scalar_val, (float, int, str)) _specialize_scalar_impl(sdfg, sdfg, scalar_name, scalar_val) + + +def in_edge_with_name(node: nd.Node, state: SDFGState, name: str) -> MultiConnectorEdge: + """ + Find the edge that connects to input connector `name` on `node`. + + :param node: the node. + :param state: the state. + :param name: the input connector name. + :return: the edge that connects to connector `name`. + """ + cands = list(state.in_edges_by_connector(node, name)) + if len(cands) != 1: + raise ValueError("Expected to find exactly one edge with name '{}', found {}".format(name, len(cands))) + return cands[0] + + +def out_edge_with_name(node: nd.Node, state: SDFGState, name: str) -> MultiConnectorEdge: + """ + Find the edge that connects to output connector `name` on `node`. + + :param node: the node. + :param state: the state. + :param name: the output connector name. + :return: the edge that connects to connector `name`. + """ + cands = list(state.out_edges_by_connector(node, name)) + if len(cands) != 1: + raise ValueError("Expected to find exactly one edge with name '{}', found {}".format(name, len(cands))) + return cands[0] + + +def in_desc_with_name(node: nd.Node, state: SDFGState, sdfg: SDFG, name: str) -> dt.Data: + """ + Find the descriptor of the data that connects to input connector `name`. + + :param node: the node. + :param state: the state. + :param sdfg: the sdfg. + :param name: the input connector name. + :return: the descriptor of the data that connects to connector `name`. + """ + return sdfg.arrays[in_edge_with_name(node, state, name).data.data] + + +def out_desc_with_name(node: nd.Node, state: SDFGState, sdfg: SDFG, name: str) -> dt.Data: + """ + Find the descriptor of the data that connects to output connector `name`. + + :param node: the node. + :param state: the state. + :param sdfg: the sdfg. + :param name: the output connector name. + :return: the descriptor of the data that connects to connector `name`. + """ + return sdfg.arrays[out_edge_with_name(node, state, name).data.data] + + +def expand_nodes(sdfg: SDFG, predicate: Callable[[nd.Node], bool]): + """ + Recursively expand library nodes in the SDFG using a given predicate. + + :param sdfg: the sdfg to expand nodes on. + :param predicate: a predicate that will be called to check if a node should be expanded. + """ + if sdfg is None: + return + states = list(sdfg.states()) + while len(states) > 0: + state = states.pop() + expanded_something = False + for node in list(state.nodes()): + if isinstance(node, nd.NestedSDFG): + expand_nodes(node.sdfg, predicate=predicate) + elif isinstance(node, nd.LibraryNode): + if predicate(node): + impl_name = node.expand(sdfg, state) + if config.Config.get_bool('debugprint'): + print("Automatically expanded library node \"{}\" with implementation \"{}\".".format( + str(node), impl_name)) + expanded_something = True + + if expanded_something: + states.append(state) diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 2cb66bc765..9434794cd5 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -1,4 +1,4 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. """ Exception classes and methods for validation of SDFGs. """ import copy @@ -226,7 +226,6 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context """ # Avoid import loop from dace import data as dt - from dace.codegen.targets import fpga from dace.sdfg.scope import is_devicelevel_fpga, is_devicelevel_gpu from dace.sdfg.state import ConditionalBlock @@ -330,32 +329,6 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context "Array %s cannot be both persistent/external and use Register as " "storage type. Please use a different storage location." % name, sdfg, None) - # Check for valid bank assignments - try: - bank_assignment = fpga.parse_location_bank(desc) - except ValueError as e: - raise InvalidSDFGError(str(e), sdfg, None) - if bank_assignment is not None: - if bank_assignment[0] == "DDR" or bank_assignment[0] == "HBM": - try: - tmp = subsets.Range.from_string(bank_assignment[1]) - except SyntaxError: - raise InvalidSDFGError( - "Memory bank specifier must be convertible to subsets.Range" - f" for array {name}", sdfg, None) - try: - low, high = fpga.get_multibank_ranges_from_subset(bank_assignment[1], sdfg) - except ValueError as e: - raise InvalidSDFGError(str(e), sdfg, None) - if (high - low < 1): - raise InvalidSDFGError( - "Memory bank specifier must at least define one bank to be used" - f" for array {name}", sdfg, None) - if (high - low > 1 and (high - low != desc.shape[0] or len(desc.shape) < 2)): - raise InvalidSDFGError( - "Arrays that use a multibank access pattern must have the size of the first dimension equal" - f" the number of banks and have at least 2 dimensions for array {name}", sdfg, None) - # Check if SDFG is located within a GPU kernel context['in_gpu'] = is_devicelevel_gpu(sdfg, None, None) context['in_fpga'] = is_devicelevel_fpga(sdfg, None, None) @@ -438,7 +411,6 @@ def validate_state(state: 'dace.sdfg.SDFGState', # Avoid import loops from dace import data as dt from dace import subsets as sbs - from dace.codegen.targets import fpga from dace.config import Config from dace.sdfg import SDFG from dace.sdfg import nodes as nd @@ -610,17 +582,6 @@ def validate_state(state: 'dace.sdfg.SDFGState', nid, ) - # Tasklets may only access 1 HBM bank at a time - if isinstance(node, nd.Tasklet): - for attached in state.all_edges(node): - if attached.data.data in sdfg.arrays: - if fpga.is_multibank_array_with_distributed_index(sdfg.arrays[attached.data.data]): - low, high, _ = attached.data.subset[0] - if (low != high): - raise InvalidSDFGNodeError( - "Tasklets may only be directly connected" - " to HBM-memlets accessing only one bank", sdfg, state_id, nid) - # Connector tests ######################################## # Tasklet connector tests diff --git a/dace/subsets.py b/dace/subsets.py index c2e898d9b9..3eeffc4903 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -32,7 +32,7 @@ def bounding_box_cover_exact(subset_a, subset_b, approximation=False) -> bool: and `False` otherwise. :param subset_a: The first subset, the one that should cover. - :param subset_b: The second subset, the one that should be convered. + :param subset_b: The second subset, the one that should be covered. :param approximation: If `True` then use the approximated bounds. """ min_elements_a = subset_a.min_element_approx() if approximation else subset_a.min_element() @@ -73,7 +73,7 @@ def bounding_box_symbolic_positive(subset_a, subset_b, approximation=False) -> b and `False` otherwise. :param subset_a: The first subset, the one that should cover. - :param subset_b: The second subset, the one that should be convered. + :param subset_b: The second subset, the one that should be covered. :param approximation: If `True` then use the approximated bounds. :note: In previous versions this function raised `TypeError` in some cases @@ -646,7 +646,7 @@ def from_string(string): # Open parenthesis found, increase count by 1 if token[i] == '(': count += 1 - # Closing parenthesis found, decrease cound by 1 + # Closing parenthesis found, decrease count by 1 elif token[i] == ')': count -= 1 # Move to the next character @@ -1267,7 +1267,7 @@ def covers(self, other): if isinstance(other, SubsetUnion): for subset in self.subset_list: - # check if ther is a subset in self that covers every subset in other + # check if there is a subset in self that covers every subset in other if all(subset.covers(s) for s in other.subset_list): return True # return False if that's not the case for any of the subsets in self @@ -1285,7 +1285,7 @@ def covers_precise(self, other): if isinstance(other, SubsetUnion): for subset in self.subset_list: - # check if ther is a subset in self that covers every subset in other + # check if there is a subset in self that covers every subset in other if all(subset.covers_precise(s) for s in other.subset_list): return True # return False if that's not the case for any of the subsets in self diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 561673dbeb..41f9ec62e8 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. """ Automatic optimization routines for SDFGs. """ import dace @@ -10,7 +10,7 @@ from dace.sdfg.scope import is_devicelevel_gpu_kernel from dace import config, data as dt, dtypes, Memlet, symbolic from dace.sdfg import SDFG, nodes, graph as gr -from typing import Set, Tuple, Union, List, Iterable, Dict +from typing import Set, Tuple, Union, List, Iterable, Dict, Callable import warnings # Transformations @@ -27,9 +27,6 @@ # Enumerator from dace.transformation.estimator.enumeration import GreedyEnumerator -# FPGA AutoOpt -from dace.transformation.auto import fpga as fpga_auto_opt - GraphViewType = Union[SDFG, SDFGState, gr.SubgraphView, ControlFlowRegion] @@ -342,7 +339,7 @@ def find_fast_library(device: dtypes.DeviceType) -> List[str]: # Returns the optimized library node implementations for the given target # device - if device is dtypes.DeviceType.GPU: + if device == dtypes.DeviceType.GPU: try: backend = get_gpu_backend() except RuntimeError: @@ -354,10 +351,7 @@ def find_fast_library(device: dtypes.DeviceType) -> List[str]: return ['rocBLAS', 'GPUAuto', 'pure'] else: return ['GPUAuto', 'pure'] - - elif device is dtypes.DeviceType.FPGA: - return ['FPGA_PartialSums', 'FPGAPartialReduction', 'FPGA_Accumulate', 'FPGA1DSystolic', 'pure'] - elif device is dtypes.DeviceType.CPU: + elif device == dtypes.DeviceType.CPU: result = [] # BLAS calls @@ -396,19 +390,28 @@ def move_small_arrays_to_stack(sdfg: SDFG) -> None: print(f'Statically allocating {converted} transient arrays') -def set_fast_implementations(sdfg: SDFG, device: dtypes.DeviceType, blocklist: List[str] = None): +def set_fast_implementations(sdfg: SDFG, + device: dtypes.DeviceType, + blocklist: List[str] = None, + find_fast_library_fn: Callable[[dtypes.DeviceType], List[str]] = None) -> None: """ Set fast library node implementations for the given device :param sdfg: The SDFG to optimize. :param device: the device to optimize for. :param blocklist: list of disallowed implementations. + :param find_fast_library_fn: function that returns the prioritized list of + implementations for the given device, which will take priority over + the built-in ``find_fast_library`` function. :note: Operates in-place on the given SDFG. """ - if blocklist is None: - implementation_prio = find_fast_library(device) - else: - implementation_prio = [i for i in find_fast_library(device) if i not in blocklist] + implementation_prio = [] + if find_fast_library_fn is not None: + implementation_prio.extend(find_fast_library_fn(device)) + implementation_prio.extend(find_fast_library(device)) + + if blocklist is not None: + implementation_prio = [i for i in implementation_prio if i not in blocklist] # specialized nodes: pre-expand for current_sdfg in sdfg.all_sdfgs_recursive(): @@ -559,7 +562,8 @@ def auto_optimize(sdfg: SDFG, validate: bool = True, validate_all: bool = False, symbols: Dict[str, int] = None, - use_gpu_storage: bool = False) -> SDFG: + use_gpu_storage: bool = False, + find_fast_library_fn: Callable[[dtypes.DeviceType], List[str]] = None) -> SDFG: """ Runs a basic sequence of transformations to optimize a given SDFG to decent performance. In particular, performs the following: @@ -580,6 +584,9 @@ def auto_optimize(sdfg: SDFG, :param validate_all: If True, validates the SDFG after every step. :param symbols: Optional dict that maps symbols (str/symbolic) to int/float :param use_gpu_storage: If True, changes the storage of non-transient data to GPU global memory. + :param find_fast_library_fn: Optional function that returns the prioritized list of + implementations for the given device, which will take priority over + the existing set of fast libraries found using auto-optimize. :return: The optimized SDFG. :note: Operates in-place on the given SDFG. :note: This function is still experimental and may harm correctness in @@ -590,6 +597,7 @@ def auto_optimize(sdfg: SDFG, # Simplification and loop parallelization transformed = True sdfg.apply_transformations_repeated(TrivialMapElimination, validate=validate, validate_all=validate_all) + while transformed: sdfg.simplify(validate=False, validate_all=validate_all) l2ms = sdfg.apply_transformations_repeated((LoopToMap, RefineNestedAccess), @@ -612,7 +620,6 @@ def auto_optimize(sdfg: SDFG, # fuse subgraphs greedily sdfg.simplify() sdfg.reset_cfg_list() - greedy_fuse(sdfg, device=device, validate_all=validate_all) # fuse stencils greedily @@ -622,16 +629,6 @@ def auto_optimize(sdfg: SDFG, from dace.transformation.interstate import MoveLoopIntoMap sdfg.apply_transformations_repeated([MoveLoopIntoMap]) - if device == dtypes.DeviceType.FPGA: - # apply FPGA Transformations - sdfg.apply_fpga_transformations() - fpga_auto_opt.fpga_global_to_local(sdfg) - fpga_auto_opt.fpga_rr_interleave_containers_to_banks(sdfg) - - # Set all library nodes to expand to fast library calls - set_fast_implementations(sdfg, device) - return sdfg - # Tiled WCR and streams for nsdfg in list(sdfg.all_sdfgs_recursive()): tile_wcrs(nsdfg, validate_all) @@ -646,7 +643,7 @@ def auto_optimize(sdfg: SDFG, pass # Set all library nodes to expand to fast library calls - set_fast_implementations(sdfg, device) + set_fast_implementations(sdfg, device, find_fast_library_fn=find_fast_library_fn) # NOTE: We need to `infer_types` in case a LibraryNode expands to other LibraryNodes (e.g., np.linalg.solve) infer_types.infer_connector_types(sdfg) diff --git a/dace/transformation/auto/fpga.py b/dace/transformation/auto/fpga.py deleted file mode 100644 index 8139337ed1..0000000000 --- a/dace/transformation/auto/fpga.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" -FPGA-Oriented Automatic optimization routines for SDFGs. -""" - -from dace.sdfg import SDFG, SDFGState, trace_nested_access -from dace import config, data as dt, dtypes, Memlet, symbolic -from dace.sdfg import SDFG, nodes, graph as gr - - -def fpga_global_to_local(sdfg: SDFG, max_size: int = 1048576) -> None: - """ Takes an entire SDFG and changes the storage type of a global FPGA data container - to Local in the following situation: - - the data is transient, - - the data is not a transient shared with other states, and - - the data has a compile-time known size. - :param sdfg: The SDFG to operate on. It must be a top-level SDFG. - :param max_size: maximum size (in bytes) that a container can have to be considered for - storage type change - :note: Operates in-place on the SDFG. - """ - converted = [] - - for name, desc in sdfg.arrays.items(): - if desc.transient and name not in sdfg.shared_transients() and desc.storage == dtypes.StorageType.FPGA_Global: - - # Get the total size, trying to resolve it to constant if it is a symbol - total_size = symbolic.resolve_symbol_to_constant(desc.total_size, sdfg) - - if total_size is not None and total_size * desc.dtype.bytes <= max_size: - desc.storage = dtypes.StorageType.FPGA_Local - converted.append(name) - - # update all access nodes that refer to this container - for node, graph in sdfg.all_nodes_recursive(): - if isinstance(node, nodes.AccessNode): - trace = trace_nested_access(node, graph, graph.parent) - - for (_, candidate), memlet_trace, state_trace, sdfg_trace in trace: - if candidate is not None and candidate.data == name: - nodedesc = node.desc(graph) - nodedesc.storage = dtypes.StorageType.FPGA_Local - if config.Config.get_bool('debugprint'): - print(f'Applied {len(converted)} Global-To-Local{": " if len(converted)>0 else "."} {", ".join(converted)}') - - -def fpga_rr_interleave_containers_to_banks(sdfg: SDFG, num_banks: int = 4, memory_type: str = "DDR"): - """ - Allocates the (global) arrays to FPGA off-chip memory banks, interleaving them in a - Round-Robin (RR) fashion. This applies to all the arrays in the SDFG hierarchy. - - :param sdfg: The SDFG to operate on. - :param num_banks: number of off-chip memory banks to consider - :param memory_type: type of off-chip memory, either "DDR" or "HBM" (if the target FPGA supports it) - :return: a list containing the number of (transient) arrays allocated to each bank - :note: Operates in-place on the SDFG. - """ - - if memory_type.upper() not in {"DDR", "HBM"}: - raise ValueError("Memory type should be either \"DDR\" or \"HBM\"") - - # keep track of memory allocated to each bank - num_allocated = [0 for i in range(num_banks)] - - i = 0 - for sd, aname, desc in sdfg.arrays_recursive(): - if not isinstance(desc, dt.Stream) and desc.storage == dtypes.StorageType.FPGA_Global and desc.transient: - desc.location["memorytype"] = memory_type.upper() - desc.location["bank"] = str(i % num_banks) - num_allocated[i % num_banks] = num_allocated[i % num_banks] + 1 - i = i + 1 - - return num_allocated diff --git a/dace/transformation/dataflow/__init__.py b/dace/transformation/dataflow/__init__.py index 992db083fd..9cf8ad8a30 100644 --- a/dace/transformation/dataflow/__init__.py +++ b/dace/transformation/dataflow/__init__.py @@ -50,8 +50,6 @@ from .gpu_transform_local_storage import GPUTransformLocalStorage from .mpi import MPITransformMap from .warp_tiling import WarpTiling -from .bank_split import BankSplit -from .hbm_transform import HbmTransform # Algorithmic from .matrix_product_transpose import MatrixProductTranspose diff --git a/dace/transformation/dataflow/add_threadblock_map.py b/dace/transformation/dataflow/add_threadblock_map.py index 9bc5a8a2a7..febdb12861 100644 --- a/dace/transformation/dataflow/add_threadblock_map.py +++ b/dace/transformation/dataflow/add_threadblock_map.py @@ -76,8 +76,8 @@ def validate_block_size_limits(kernel_map_entry: nodes.MapEntry, block_size: Lis kernel_map_label = kernel_map_entry.map.label total_block_size = product(block_size) - limit = Config.get('compiler', 'cuda', 'block_size_limit') - lastdim_limit = Config.get('compiler', 'cuda', 'block_size_lastdim_limit') + limit = int(Config.get('compiler', 'cuda', 'block_size_limit')) + lastdim_limit = int(Config.get('compiler', 'cuda', 'block_size_lastdim_limit')) if (total_block_size > limit) == True: raise ValueError(f'Block size for kernel "{kernel_map_label}" ({block_size}) ' diff --git a/dace/transformation/dataflow/bank_split.py b/dace/transformation/dataflow/bank_split.py deleted file mode 100644 index 8161b22ede..0000000000 --- a/dace/transformation/dataflow/bank_split.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from typing import Any, Dict, Iterable, List, Tuple, Union - -from dace import data, dtypes, properties -from dace.sdfg import utils -from dace.transformation import transformation -from dace.sdfg import nodes as nd -from dace import SDFG, SDFGState, memlet -from dace import symbolic -import functools - - -@properties.make_properties -class BankSplit(transformation.SingleStateTransformation): - """ - A transformation that allow splitting an array and distribute it on another - array with one dimension more, or vice versa. Works with arbitrary arrays, - but its intended use case is to distribute data on many HBM-banks. - Matches any 2 AccessNodes connected by an edge, if the dimensionality of the two accessed - arrays differ by exactly one. The sizes of the arrays have to be large enough with - respect to the split executed, but this is not verified. While it is allowed to use symbolics - for the shapes of the array, it is expected that each dimension is divisible by the number - of splits specified. - - When appling an unrolled map is generated around the accessnodes, which copies the parts of - the array to the target array. - - Examples: - Distribute: Suppose for example we copy from A to B, where A has shape [100, 100] and B shape - [10, 100, 10]. We can distribute A in that case to B using the transformation by setting - split_array_info=[1, 10]. A will then be divided along it's second dimension into 10 parts - of size [100, 10] and distributed on B. - Gather: Suppose A has shape [4, 50, 50] and B has shape [100, 100]. If one sets - split_array_info to [2, 2] and applies the transformation, it will split - equally in all dimensions. - Therefore A[0] will be copied to B[0:50, 0:50], A[1] to B[0:50, 50:100], A[2] to B[50:100, 0:50] and - A[3] to B[50:100, 50:100]. - - Note that simply reversing the AccessNodes for the arrays in the above examples would - have lead to the inverse operation, i.e. the gather would become a distribute and - the other way around. - """ - - src_node = transformation.PatternNode(nd.AccessNode) - dst_node = transformation.PatternNode(nd.AccessNode) - - # dtype=List[int] - split_array_info = properties.Property( - dtype=List, - default=None, - allow_none=True, - desc="Describes how many times this array is split in each dimension, " - "where the k-th number describes how many times dimension k is split. " - "If the k-th number is 1 this means that the array is not split in " - "the k-th dimension at all. " - "If None, then the transform will split the first dimension exactly shape[0] times.") - - default_to_storage = properties.Property( - dtype=dtypes.StorageType, - default=dtypes.StorageType.CPU_Heap, - allow_none=False, - desc="The storage type of involved arrays will be set to the value of this property if " - "they have Default storage type. ") - - def _get_split_size(self, virtual_shape: Iterable, split_count: List[int]) -> List[int]: - """ - :return: the shape of a part-array on one HBMbank - """ - new_shape_list = [] - for d in range(len(virtual_shape)): - if split_count[d] != 1: - new_shape_list.append(virtual_shape[d] // split_count[d]) - else: - new_shape_list.append(virtual_shape[d]) - return new_shape_list - - def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool) -> bool: - src = self.src_node - dst = self.dst_node - src_array = sdfg.arrays[src.data] - dst_array = sdfg.arrays[dst.data] - - plain_array = lambda array: isinstance(array, data.Array) and not isinstance(array, data.View) - - if not plain_array(src_array): - return False - if not plain_array(dst_array): - return False - - # same dimensions means HBM-array needs 1 dimension more - collect_src = len(src_array.shape) - 1 == len(dst_array.shape) - distribute_dst = len(src_array.shape) + 1 == len(dst_array.shape) - if collect_src and symbolic.issymbolic(src_array.shape[0], sdfg.constants): - return False - elif distribute_dst and symbolic.issymbolic(dst_array.shape[0], sdfg.constants): - return False - return collect_src or distribute_dst - - @classmethod - def expressions(cls): - return [utils.node_path_graph(cls.src_node, cls.dst_node)] - - def apply(self, graph: SDFGState, sdfg: SDFG) -> Union[Any, None]: - # Load/parse infos from the SDFG - src = self.src_node - dst = self.dst_node - src_array = sdfg.arrays[src.data] - dst_array = sdfg.arrays[dst.data] - collect_src = len(src_array.shape) - 1 == len( - dst_array.shape) # If this is not true we have to distribute to dst (checked in can_apply) - if collect_src: - bank_count = int(src_array.shape[0]) - true_size = dst_array.shape - else: - bank_count = int(dst_array.shape[0]) - true_size = src_array.shape - ndim = len(true_size) - - # Move Default storage - if sdfg.arrays[src.data].storage == dtypes.StorageType.Default: - sdfg.arrays[src.data].storage = self.default_to_storage - if sdfg.arrays[dst.data].storage == dtypes.StorageType.Default: - sdfg.arrays[dst.data].storage = self.default_to_storage - - # Figure out how to split - if self.split_array_info is None: - split_info = [1] * ndim - split_info[0] = bank_count - else: - split_info = self.split_array_info - if len(split_info) != ndim: - raise RuntimeError("Length of split_array_info must match number of " - "dimensions") - if functools.reduce(lambda a, b: a * b, split_info) != bank_count: - raise RuntimeError("Splitting is not possible with the selected splits" - "and this number of HBM-banks (required number of banks " - "!= actual number of banks)") - - # create the copy-subgraph - ndrange = dict() - usable_params = [] - for i in range(ndim): - usable_params.append(f"i{i}") - for i in range(ndim): - ndrange[usable_params[i]] = f"0:{split_info[i]}" - graph.remove_edge_and_connectors(graph.edges_between(src, dst)[0]) - copy_map_enter, copy_map_exit = graph.add_map("hbm_bank_split", ndrange, dtypes.ScheduleType.Unrolled) - graph.add_edge(copy_map_enter, None, src, None, memlet.Memlet()) - graph.add_edge(dst, None, copy_map_exit, None, memlet.Memlet()) - - target_size = [str(x) for x in self._get_split_size(true_size, split_info)] - target_hbm_bank = [] - for i in range(ndim): - target_hbm_bank.append(usable_params[i]) - for j in range(i): - target_hbm_bank[j] = f"{split_info[i]}*{target_hbm_bank[j]}" - target_offset = [] - for i in range(ndim): - target_offset.append(f"{usable_params[i]}*{target_size[i]}") - - target_size_str = ", ".join([f"{x}:{y}" for x, y in zip([0] * ndim, target_size)]) - target_hbm_bank_str = "+ ".join(target_hbm_bank) - target_offset_str = ", ".join([f"({x}):({x}+{y})" for x, y in zip(target_offset, target_size)]) - if collect_src: - copy_memlet = memlet.Memlet(f"{src.data}[{target_hbm_bank_str}, {target_size_str}]->" - f"[{target_offset_str}]") - else: - copy_memlet = memlet.Memlet(f"{src.data}[{target_offset_str}]->[{target_hbm_bank_str}, " - f"{target_size_str}]") - graph.add_edge(src, None, dst, None, copy_memlet) diff --git a/dace/transformation/dataflow/hbm_transform.py b/dace/transformation/dataflow/hbm_transform.py deleted file mode 100644 index 18b34fae9b..0000000000 --- a/dace/transformation/dataflow/hbm_transform.py +++ /dev/null @@ -1,477 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from dace.transformation.dataflow import StripMining, MapCollapse -from typing import Any, Dict, List, Union - -from dace import dtypes, properties, registry, subsets, symbolic -from dace.sdfg import propagation, utils, graph -from dace.codegen.targets import fpga -from dace.transformation import transformation, helpers -from dace.sdfg import nodes as nd -from dace import SDFG, SDFGState, memlet, data - - -def modify_bank_assignment(array_name: str, - sdfg: SDFG, - new_memory: str, - new_bank: str, - split_array_info: List[int] = None, - set_storage_type: bool = True): - """ - Updates bank assignments for the array on the SDFG. Will update - the shape of the array as well depending on the previous assignment. - :param split_array_info: A list with the same length as the old dimension - of the array. When transfering to HBM the size in each dimension is divided by - the corresponding int, when moving to DDR it is multiplied. - :param set_storage_type: Place the array on FPGA_Global. - """ - desc = sdfg.arrays[array_name] - old_memory = None - if 'memorytype' in desc.location and desc.location["memorytype"] is not None: - old_memory = desc.location["memorytype"] - if new_memory == "HBM": - low, high = fpga.get_multibank_ranges_from_subset(new_bank, sdfg) - else: - low, high = int(new_bank), int(new_bank) + 1 - if split_array_info is None: - d_size = len(desc.shape) - if fpga.is_multibank_array_with_distributed_index(desc): - d_size -= 1 - split_array_info = [1] * d_size - - if (old_memory is None or old_memory == "DDR") and new_memory == "HBM": - desc = sdfg.arrays[array_name] - new_shape = [x // y for x, y in zip(desc.shape, split_array_info)] - if high - low > 1: - desc.set_shape((high - low, *new_shape)) - else: - desc.set_shape(new_shape) - elif old_memory == "HBM" and (new_memory == "DDR" or new_memory is None): - desc = sdfg.arrays[array_name] - if fpga.is_multibank_array_with_distributed_index(desc): - old_shape = list(desc.shape)[1:] - else: - old_shape = desc.shape - new_shape = [x * y for x, y in zip(old_shape, split_array_info)] - desc.set_shape(new_shape) - elif old_memory == "HBM" and new_memory == "HBM": - oldlow, oldhigh = fpga.get_multibank_ranges_from_subset(desc.location["bank"], sdfg) - if oldlow == low and oldhigh == high: - return - # It would be problematic to change the number of banks, because of split_array_info - raise NotImplementedError("Cannot directly transfer from HBM to HBM") - desc.location["memorytype"] = new_memory - desc.location['bank'] = new_bank - if set_storage_type: - desc.storage = dtypes.StorageType.FPGA_Global - - -def _update_memlet_hbm(state: SDFGState, inner_edge: graph.MultiConnectorEdge, inner_subset_index: symbolic.symbol, - this_node: nd.AccessNode): - """ - Add the subset_index to the inner index. If the end/start of - the path is also an AccessNode, it will insert a tasklet before the access to - avoid validation failures due to dimensionality mismatch. - :param inner_edge: The inner_edge to modify - :param inner_subset_index: The distributed subset for the innermost edge on - the memlet path defined by convertible_node - :param this_node: The AccessNode for HBM which is associated with this call to - the function. (i.e. one side of the path) - """ - mem: memlet.Memlet = inner_edge.data - # If the memlet already contains the distributed subset, ignore it - # That's helpful because of inconsistencies when nesting and because - # one can 'hint' the correct bank assignment when using this function - if len(mem.subset) == len(state.parent.arrays[this_node.data].shape): - return - new_subset = subsets.Range([[inner_subset_index, inner_subset_index, 1]] + [x for x in mem.subset]) - - path = state.memlet_path(inner_edge) - if path[-1].dst == this_node: - is_write = True - other_node = path[0].src - elif path[0].src == this_node: - is_write = False - other_node = path[-1].dst - - if isinstance(other_node, nd.NestedSDFG): # Ignore those and update them via propagation - new_subset = subsets.Range.from_array(state.parent.arrays[this_node.data]) - - if isinstance(other_node, nd.AccessNode): - fwtasklet = state.add_tasklet("fwtasklet", set(["_in"]), set(["_out"]), "_out = _in") - state.remove_edge(inner_edge) - target_other_subset = mem.other_subset - mem.other_subset = None - if is_write: - inner_edge = state.add_edge(fwtasklet, '_out', inner_edge.dst, inner_edge.dst_conn, mem) - state.add_edge(other_node, path[0].src_conn, fwtasklet, "_in", - memlet.Memlet(other_node.data, subset=target_other_subset)) - else: - inner_edge = state.add_edge(inner_edge.src, inner_edge.src_conn, fwtasklet, '_in', mem) - state.add_edge(fwtasklet, "_out", other_node, path[-1].dst_conn, - memlet.Memlet(other_node.data, subset=target_other_subset)) - - inner_edge.data.subset = new_subset - - -@properties.make_properties -class HbmTransform(transformation.SingleStateTransformation): - """ - A transformation that applies on a map when all attached global memories are - assigned to banks. Note that the assignment is rather a hinting, the actual - dimensionality changes (when required) will be done by the transformation. - All arrays that should span across multiple banks (i.e. should be split) must have - exactly one dimension where an access happens in dependence of the map variable i. - All attached arrays must either be assigned to one single bank or span across - the same number of banks. - - Moves all attached arrays to their banks, and changes their size according to the - bank assignment. Adds an outer unrolled loop around the map on which it is applied - and changes all accesses such that they go to the same place as before (for single - bank assignments), respectively to the right bank and the modified location for - multibank arrays. By default the outer unrolled loop is made the top-level map of the - scope. - - For any access in some dimension to an array that should be split the transformation - assumes that the access to the dimension behaves linearly in the map variable. - Also in such dimension the array size has to be dividable by the number of banks - across which all multibank arrays span. - - At the moment the transformation cannot apply if the same array - in global memory is attached to the map with multiple edges. This also implies - that write-back to an array from which data is read is disallowed. It is done this way - because it reduces the complexity of the transformation somewhat and because in - the context of HBM this is likely a bad idea anyway. (More complex routing + reduced IO). - """ - """ - - - @staticmethod - def expressions(): - return [utils.node_path_graph(HbmTransform._map_entry)] - """ - _map_entry = transformation.PatternNode(nd.MapEntry) - #_map_entry = nd.MapEntry(nd.Map("", [], [])) - - @classmethod - def expressions(cls): - return [utils.node_path_graph(cls._map_entry)] - - new_dim = properties.Property(dtype=str, default="k", desc="Defines the map param of the outer unrolled map") - - move_to_FPGA_global = properties.Property(dtype=bool, - default=True, - desc="All assigned arrays have their storage changed to FPGA_Global") - - make_new_map_outermost = properties.Property(dtype=bool, - default=True, - desc="Make the map with schedule unrolled the outermost of the scope") - - def can_be_applied(self, - graph: Union[SDFG, SDFGState], - expr_index: int, - sdfg: SDFG, - permissive: bool = False) -> bool: - - # This can only be applied when this state can run on FPGA - if not isinstance(graph, SDFGState) or not helpers.can_run_state_on_fpga(graph): - return False - - map_entry = self._map_entry - map_exit = graph.exit_node(map_entry) - - # Can't handle nesting - scope = graph.scope_subgraph(map_entry) - for node in scope.nodes(): - if isinstance(node, nd.NestedSDFG) or isinstance(node, nd.LibraryNode): - return False - - if len(map_entry.map.params) != 1: - return False - - # Check if all arrays are assigned, and we can somehow split - result = HbmTransform._scan_paths(sdfg, graph, map_entry, map_exit) - if result is None: - return False - - return True - - def apply(self, graph: Union[SDFG, SDFGState], sdfg: SDFG) -> Union[Any, None]: - state: SDFGState = graph - unroll_entry: nd.MapEntry = self._map_entry - unroll_exit = state.exit_node(unroll_entry) - - split_arrays, no_split_arrays, unroll_factor, split_dimensions = HbmTransform._scan_paths( - sdfg, - state, - unroll_entry, - unroll_exit, - ) - - unroll_entry, inner_entry, tmp_old_outer_param = HbmTransform.unroll_map(sdfg, state, unroll_entry, - unroll_factor, self.new_dim, False) - scope_view = state.scope_subgraph(unroll_entry) - tmp_to_inner_range = inner_entry.map.range[0] - - # We remove the multiplication (since it's not needed any more) on the old parameter, - # but keep it so we can easier replace with the new dimension later - scope_view.replace(tmp_old_outer_param, f"({tmp_old_outer_param}/{unroll_factor})") - - # Actually place the arrays and update paths - for edge in state.in_edges(unroll_entry) + state.out_edges(unroll_exit): - name = edge.data.data - if name in split_arrays or name in no_split_arrays: - desc = sdfg.arrays[name] - memory_type = desc.location["memorytype"] - desc.location.pop("memorytype") - bank = desc.location["bank"] - desc.location.pop("bank") - if name in split_arrays: - division_info = [1] * len(sdfg.arrays[name].shape) - division_info[split_dimensions[name]] = unroll_factor - else: - division_info = None - modify_bank_assignment(name, sdfg, memory_type, bank, division_info, self.move_to_FPGA_global) - - if name in split_arrays: - path = state.memlet_path(edge) - if isinstance(path[0].src, nd.AccessNode) and path[0].src.data == name: - this_node = path[0].src - inner_edge = path[-1] - else: - this_node = path[-1].dst - inner_edge = path[0] - - inner_edge.data.replace({tmp_old_outer_param: "0"}) - _update_memlet_hbm(state, inner_edge, self.new_dim, this_node) - - # Replace the dummy symbol everywhere where it still remains - scope_view.replace(tmp_old_outer_param, f"({tmp_to_inner_range[1]}+1)*{self.new_dim}") - - if self.make_new_map_outermost: - HbmTransform.make_outermost_map(sdfg, state, unroll_entry) - - # Propagate the modified inner memlets - propagation.propagate_memlets_state(sdfg, state) - - @staticmethod - def all_innermost_edges(state: SDFGState, of: Union[nd.AccessNode, graph.MultiConnectorEdge]): - """ - Generator that returns all the innermost edges, i.e. the last edge(s) for a read - path and the first edge(s) for a write path. - :param of: If of is an AccessNode all the innermost edges of all attached edges are returned. - If of is an edge the innermost edges of that memlet path is returned. - """ - - def get_innermost_edges(tree): - res = [] - if len(tree.children) == 0: - return [tree.edge] - for child in tree.children: - res.extend(get_innermost_edges(child)) - return res - - if isinstance(of, nd.AccessNode): - src = lambda: state.all_edges(of) - else: - src = lambda: [of] - for edge in src(): - tree = state.memlet_tree(edge) - res = get_innermost_edges(tree) - for r in res: - yield r - - @staticmethod - def _scan_paths(sdfg: SDFG, state: SDFGState, map_entry: nd.MapEntry, map_exit: nd.MapExit): - """ - Find all arrays attached to the map, check their bank assignment/accesses - and find a suitable unroll factor if possible - :return: A tuple of (split_arrays, no_split_arrays, unroll_factor, split_dimensions, - array_dimensions), where split_arrays the array names that are split, - no_split_arrays array names that are not split, - unroll_factor the value for the number of splits that are created from - split_arrrays, - split_dimensions a mapping from array name to the dimension along which - the array should be split (always only 1). - """ - - unroll_factor = None - no_split_arrays = {} - split_arrays = {} - split_dimensions = {} - has_pending_changes = False # Will there something be done? - - attached_array = {} - for edge in state.in_edges(map_entry) + state.out_edges(map_exit): - if edge.data.is_empty(): - continue - if edge.data.data in attached_array: # Only one edge per array - return None - attached_array[edge.data.data] = edge - - for name in attached_array: - desc = sdfg.arrays[name] - - if not isinstance(desc, data.Array) or isinstance(desc, data.View): - continue - if desc.storage != dtypes.StorageType.FPGA_Global and desc.storage != dtypes.StorageType.Default: # If not in global memory ignore - continue - - assigned = fpga.parse_location_bank(desc) - if assigned is None: # All arrays must be assigned - return None - else: - if assigned[0] == "HBM": - low, high = fpga.get_multibank_ranges_from_subset(assigned[1], sdfg) - if high - low == 1: - no_split_arrays[name] = assigned - continue - if unroll_factor is None: - unroll_factor = high - low - else: - if unroll_factor != high - low: # All split arrays must have the same number of banks - return None - split_arrays[name] = assigned - - if desc.shape[0] != high - low: # Otherwise we assume the array was already placed - has_pending_changes = True - else: - return None # If an array was already placed on HBM we cannot apply - else: - no_split_arrays[name] = assigned - - # Check if the arrays which should be split can do so - for name in split_arrays: - edge = attached_array[name] - count_innermost = 0 - for edge in HbmTransform.all_innermost_edges(state, edge): - count_innermost += 1 - if count_innermost > 1: - return None # Can't handle trees - innermost = edge - - found = None - for i, val in enumerate(innermost.data.subset): - low, high, stride = val - if stride != 1 or low != high: - continue - if map_entry.map.params[0] in set([str(x) for x in low.free_symbols]): - if found is None: - found = i - else: - return None # Only 1 dimension may be dependent. - if found is None: - return None - - # We assume that the found dimension behaves linear in the map symbol - split_dimensions[name] = found - - if not has_pending_changes: # In this case we would do nothing - return None - - return (split_arrays, no_split_arrays, unroll_factor, split_dimensions) - - @staticmethod - def make_outermost_map(sdfg: SDFG, state: SDFGState, map_entry: nd.MapEntry): - """ - Make the map defined by map_entry the outermost of the scope - """ - # Find to top level map of the given entry - scopes = state.scope_dict() - outer_map = map_entry - while scopes[outer_map] is not None: - outer_map = scopes[outer_map] - if outer_map == map_entry: - return - - # "Strip Mine" with a factor of 1, i.e insert a dummy map - new_map: nd.Map = StripMining.apply_to(sdfg, { - "dim_idx": 0, - "tile_size": 1, - "divides_evenly": True, - }, - map_entry=outer_map) - for n in state.nodes(): - if isinstance(n, nd.MapEntry) and n.map == new_map: - # nodes are turned around by strip mine - inner_entry = outer_map - outer_map = n - break - # Copy relevant map values to new outermost - inner_entry.map.range = outer_map.map.range - inner_entry.map.schedule = dtypes.ScheduleType.Default - outer_map.map.params[0] = map_entry.map.params[0] - outer_map.map.range = map_entry.map.range - outer_map.map.schedule = map_entry.map.schedule - - # Delete the old map - scopes = state.scope_dict() - direct_parent_map = scopes[map_entry] - direct_parent_map, _ = MapCollapse.apply_to(sdfg, {}, - permissive=True, - outer_map_entry=direct_parent_map, - inner_map_entry=map_entry) - ndim = len(direct_parent_map.map.params) - direct_parent_map.map.params.pop() - direct_parent_map.map.range.pop([ndim - 1]) - - @staticmethod - def unroll_map(sdfg: SDFG, - state: SDFGState, - unroll_entry: nd.MapEntry, - unroll_factor: int, - new_param_name="k", - update_memlets=True): - """ - Add an unrolled map around unroll_entry and reduce the range of unroll_entry - accordingly. - :param update_memlets: If True all memlets are updated to behave the same as - before. Otherwise the old map symbols are kept (this is an invalid SDFG, - but usefull if one wants to modify the accesses in another way, like HBMTransform - does). - """ - tile_prefix = sdfg._find_new_name("bank") - map_range_int = None - - try: - map_range_int = int(symbolic.resolve_symbol_to_constant(unroll_entry.map.range[0][1], sdfg)) + 1 - except TypeError: - pass - - new_map: nd.Map = StripMining.apply_to(sdfg, { - "tile_size": unroll_factor, - "divides_evenly": True, - "skew": True, - "tiling_type": dtypes.TilingType.CeilRange, - "new_dim_prefix": tile_prefix - }, - map_entry=unroll_entry) - for n in state.nodes(): - if isinstance(n, nd.MapEntry) and n.map == new_map: - # nodes are turned around by strip mine - inner_entry = unroll_entry - unroll_entry = n - break - - # sympy does weird stuff when int_ceil with actual integers, so compute it for them - if map_range_int is not None: - low, _, stride = unroll_entry.map.range[0] - unroll_entry.map.range = subsets.Range([(low, ((map_range_int + unroll_factor - 1) // unroll_factor) - 1, - stride)]) - - # Switch the maps, update schedules, set outer parameter - tmp_to_inner_range = unroll_entry.map.range[0] - tmp_to_outer_range = inner_entry.map.range[0] - tmp_old_outer_param = unroll_entry.map.params[0] - - unroll_entry.map.params[0] = new_param_name - unroll_entry.map.range[0] = tmp_to_outer_range - inner_entry.map.range[0] = tmp_to_inner_range - inner_entry.map.schedule = dtypes.ScheduleType.Default - unroll_entry.map.schedule = dtypes.ScheduleType.Unrolled - - # Update the memlets, or return the symbol used before, so the caller can do that by itself - if update_memlets: - scope_view = state.scope_subgraph(unroll_entry) - scope_view.replace(tmp_old_outer_param, f"(({tmp_to_inner_range[1]}+1)*{new_param_name})/{unroll_factor}") - else: - return (unroll_entry, inner_entry, tmp_old_outer_param) diff --git a/dace/transformation/dataflow/map_distribution.py b/dace/transformation/dataflow/map_distribution.py index 5cd551d3ac..eaa8d07f2a 100644 --- a/dace/transformation/dataflow/map_distribution.py +++ b/dace/transformation/dataflow/map_distribution.py @@ -319,7 +319,7 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): Px = dace.symbol('Px', dtype=dace.int32, integer=True, positive=True) Py = dace.symbol('Py', dtype=dace.int32, integer=True, positive=True) - from dace.data import _prod + from dace.utils import prod as _prod # NOTE: Maps with step in their ranges are currently not supported if len(map_entry.map.params) == 2: diff --git a/dace/transformation/dataflow/map_expansion.py b/dace/transformation/dataflow/map_expansion.py index 3835450172..26341ac378 100644 --- a/dace/transformation/dataflow/map_expansion.py +++ b/dace/transformation/dataflow/map_expansion.py @@ -34,7 +34,7 @@ class MapExpansion(pm.SingleStateTransformation): dtype=dtypes.ScheduleType, default=dtypes.ScheduleType.Sequential, allow_none=True) - expansion_limit = Property(desc="How many unidimensional maps will be creaed, known as k. " + expansion_limit = Property(desc="How many unidimensional maps will be created, known as k. " "If None, the default no limit is in place.", dtype=int, allow_none=True, diff --git a/dace/transformation/dataflow/map_fusion_vertical.py b/dace/transformation/dataflow/map_fusion_vertical.py index eda0e639f5..4dc4762931 100644 --- a/dace/transformation/dataflow/map_fusion_vertical.py +++ b/dace/transformation/dataflow/map_fusion_vertical.py @@ -7,6 +7,7 @@ from dace import data, properties, subsets, symbolic, transformation from dace.sdfg import SDFG, SDFGState, graph, nodes, propagation from dace.transformation.dataflow import map_fusion_helper as mfhelper +from dace.sdfg.type_inference import infer_expr_type @properties.make_properties @@ -1548,11 +1549,15 @@ def _is_data_accessed_downstream( def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: return (edge.dst for edge in graph.out_edges(node)) - # Dataflow graph is acyclic, so we do not need to keep a list of - # what we have visited. + # Track visited nodes to avoid exponential blowup from visiting + # the same node multiple times via different paths in the DAG. to_visit: List[nodes.Node] = list(next_nodes(begin)) + visited: Set[nodes.Node] = set() while len(to_visit) > 0: node = to_visit.pop() + if node in visited: + continue + visited.add(node) if isinstance(node, nodes.AccessNode) and node.data == data: return True to_visit.extend(next_nodes(node)) @@ -1842,7 +1847,6 @@ def compute_new_shape_or_stride(inner_values, outer_values, pattern): else: # NOTE: This code is copied from `SDFGState.add_nested_sdfg()`, according # to the description this is not a very good implementation. - from dace.codegen.tools.type_inference import infer_expr_type new_inner_value_type = infer_expr_type(outer_value, outer_sdfg.symbols) or dtypes.typeclass(int) inner_sdfg.add_symbol(new_inner_value_sym, new_inner_value_type) return new_inner_values diff --git a/dace/transformation/dataflow/streaming_memory.py b/dace/transformation/dataflow/streaming_memory.py index 6dec58e4bc..ed93a37c1a 100644 --- a/dace/transformation/dataflow/streaming_memory.py +++ b/dace/transformation/dataflow/streaming_memory.py @@ -394,7 +394,7 @@ def apply(self, state: SDFGState, sdfg: SDFG) -> nodes.AccessNode: read_to_gearbox = state.add_read(input_gearbox_name) write_from_gearbox = state.add_write(output_gearbox_name) - gearbox = Gearbox(total_size / vector_size) + gearbox = Gearbox(total_size / vector_size, schedule=dtypes.ScheduleType.FPGA_Device) state.add_node(gearbox) @@ -417,7 +417,7 @@ def apply(self, state: SDFGState, sdfg: SDFG) -> nodes.AccessNode: newdesc = input_gearbox_newdesc else: - # Qualify name to avoid name clashes if memory interfaces are not decoupled for Xilinx + # Qualify name to avoid name clashes stream_name = "stream_" + dnode.data name, newdesc = sdfg.add_stream(stream_name, desc.dtype, @@ -696,7 +696,7 @@ def apply(self, state: SDFGState, sdfg: SDFG) -> nodes.AccessNode: # Create new stream of shape 1 desc = sdfg.arrays[access.data] - # Qualify name to avoid name clashes if memory interfaces are not decoupled for Xilinx + # Qualify name to avoid name clashes stream_name = "stream_" + access.data name, newdesc = sdfg.add_stream(stream_name, desc.dtype, diff --git a/dace/transformation/dataflow/sve/infer_types.py b/dace/transformation/dataflow/sve/infer_types.py index 819dea2be5..cc1c8dc431 100644 --- a/dace/transformation/dataflow/sve/infer_types.py +++ b/dace/transformation/dataflow/sve/infer_types.py @@ -17,6 +17,7 @@ import dace.dtypes as dtypes from collections import defaultdict from dace.sdfg.utils import dfs_topological_sort +from dace.sdfg.type_inference import infer_types class TypeInferenceDict(DefaultDict[Tuple[Tasklet, str, bool], dtypes.typeclass]): @@ -35,9 +36,6 @@ def infer_tasklet_connectors(sdfg: SDFG, state: SDFGState, node: Tasklet, inferr raise TypeError('Cannot infer output connectors of tasklet "%s", ' 'not all input connectors have types' % str(node)) - # Avoid import loop - from dace.codegen.tools.type_inference import infer_types - # Get symbols defined at beginning of node syms = state.symbols_defined_at(node) diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 4875279bea..bb4bd85d3f 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -971,6 +971,7 @@ def unsqueeze_memlet(internal_memlet: Memlet, external_offset: Tuple[int] = None) -> Memlet: """ Unsqueezes and offsets a memlet, as per the semantics of nested SDFGs. + :param internal_memlet: The internal memlet (inside nested SDFG) before modification. :param external_memlet: The external memlet before modification. :param preserve_minima: Do not change the subset's minimum elements. @@ -1613,40 +1614,6 @@ def replace_code_to_code_edges(sdfg: SDFG): state.remove_edge(edge) -def can_run_state_on_fpga(state: SDFGState): - """ - Checks if state can be executed on FPGA. Used by FPGATransformState - and HbmTransform. - """ - for node, graph in state.all_nodes_recursive(): - # Consume scopes are currently unsupported - if isinstance(node, (nodes.ConsumeEntry, nodes.ConsumeExit)): - return False - - # Streams have strict conditions due to code generator limitations - if (isinstance(node, nodes.AccessNode) and isinstance(graph.sdfg.arrays[node.data], data.Stream)): - nodedesc = graph.sdfg.arrays[node.data] - sdict = graph.scope_dict() - if nodedesc.storage in [ - dtypes.StorageType.CPU_Heap, dtypes.StorageType.CPU_Pinned, dtypes.StorageType.CPU_ThreadLocal - ]: - return False - - # Cannot allocate FIFO from CPU code - if sdict[node] is None: - return False - - # Arrays of streams cannot have symbolic size on FPGA - if symbolic.issymbolic(nodedesc.total_size, graph.sdfg.constants): - return False - - # Streams cannot be unbounded on FPGA - if nodedesc.buffer_size < 1: - return False - - return True - - def make_map_internal_write_external(sdfg: SDFG, state: SDFGState, map_exit: nodes.MapExit, access: nodes.AccessNode, sink: nodes.AccessNode): """ diff --git a/dace/transformation/interstate/__init__.py b/dace/transformation/interstate/__init__.py index 8464f7218f..c7fdd02efb 100644 --- a/dace/transformation/interstate/__init__.py +++ b/dace/transformation/interstate/__init__.py @@ -8,8 +8,6 @@ from .state_fusion_with_happens_before import StateFusionExtended from .state_elimination import (EndStateElimination, StartStateElimination, StateAssignElimination, SymbolAliasPromotion, HoistState) -from .fpga_transform_state import FPGATransformState -from .fpga_transform_sdfg import FPGATransformSDFG from .gpu_transform_sdfg import GPUTransformSDFG from .sdfg_nesting import NestSDFG, InlineSDFG, InlineTransients, RefineNestedAccess from .loop_unroll import LoopUnroll diff --git a/dace/transformation/interstate/fpga_transform_sdfg.py b/dace/transformation/interstate/fpga_transform_sdfg.py deleted file mode 100644 index 09a6ee2aa8..0000000000 --- a/dace/transformation/interstate/fpga_transform_sdfg.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. -""" Contains inter-state transformations of an SDFG to run on an FPGA. """ - -import networkx as nx - -from dace import properties -from dace.sdfg.sdfg import SDFG -from dace.transformation import transformation - - -@properties.make_properties -@transformation.explicit_cf_compatible -class FPGATransformSDFG(transformation.MultiStateTransformation): - """ Implements the FPGATransformSDFG transformation, which takes an entire - SDFG and transforms it into an FPGA-capable SDFG. """ - - promote_global_trans = properties.Property( - dtype=bool, - default=True, - desc="If True, transient arrays that are fully internal are pulled out so " - "that they can be allocated on the host.") - - @staticmethod - def annotates_memlets(): - return True - - @classmethod - def expressions(cls): - # Match anything - return [nx.DiGraph()] - - def can_be_applied(self, graph, expr_index, sdfg: SDFG, permissive=False): - # Avoid import loops - from dace.transformation.interstate import FPGATransformState - - # Condition match depends on matching FPGATransformState for each state - for state in sdfg.states(): - fps = FPGATransformState() - fps.setup_match(sdfg, state.parent_graph.cfg_id, -1, {FPGATransformState.state: state.block_id}, 0) - if not fps.can_be_applied(state.parent_graph, expr_index, sdfg): - return False - - return True - - def apply(self, _, sdfg: SDFG): - # Avoid import loops - from dace.transformation.interstate import NestSDFG - from dace.transformation.interstate import FPGATransformState - - cfg_id = sdfg.cfg_id - nesting = NestSDFG() - nesting.setup_match(sdfg, cfg_id, -1, {}, self.expr_index) - nesting.promote_global_trans = self.promote_global_trans - nesting.apply(sdfg, sdfg) - - # The state ID is zero since we applied NestSDFG and have only one state in the new SDFG - fpga_transform = FPGATransformState() - fpga_transform.setup_match(sdfg, cfg_id, -1, {FPGATransformState.state: 0}, self.expr_index) - fpga_transform.apply(sdfg, sdfg) diff --git a/dace/transformation/interstate/fpga_transform_state.py b/dace/transformation/interstate/fpga_transform_state.py deleted file mode 100644 index dc888d8c33..0000000000 --- a/dace/transformation/interstate/fpga_transform_state.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. -""" Contains inter-state transformations of an SDFG to run on an FPGA. """ - -import copy -import dace -from dace import memlet, dtypes, sdfg as sd, subsets -from dace.sdfg import nodes -from dace.sdfg import utils as sdutil -from dace.sdfg.sdfg import SDFG -from dace.sdfg.state import ControlFlowRegion, SDFGState -from dace.transformation import transformation, helpers as xfh - - -def fpga_update(sdfg: SDFG, state: SDFGState, depth: int): - scope_dict = state.scope_dict() - for node in state.nodes(): - if (isinstance(node, nodes.AccessNode) and node.desc(sdfg).storage == dtypes.StorageType.Default): - nodedesc = node.desc(sdfg) - pmap = xfh.get_parent_map(state, node) - if depth >= 2 or (pmap is not None and pmap[0].schedule == dtypes.ScheduleType.FPGA_Device): - nodedesc.storage = dtypes.StorageType.FPGA_Local - else: - if scope_dict[node]: - nodedesc.storage = dtypes.StorageType.FPGA_Local - else: - nodedesc.storage = dtypes.StorageType.FPGA_Global - if (hasattr(node, "schedule") and node.schedule == dace.dtypes.ScheduleType.Default): - node.schedule = dace.dtypes.ScheduleType.FPGA_Device - if isinstance(node, nodes.NestedSDFG): - for s in node.sdfg.states(): - fpga_update(node.sdfg, s, depth + 1) - - -@transformation.explicit_cf_compatible -class FPGATransformState(transformation.MultiStateTransformation): - """ Implements the FPGATransformState transformation. """ - - state = transformation.PatternNode(sd.SDFGState) - - @classmethod - def expressions(cls): - return [sdutil.node_path_graph(cls.state)] - - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): - state = self.state - - if not xfh.can_run_state_on_fpga(state): - return False - - for node in state.nodes(): - - if (isinstance(node, nodes.AccessNode) - and node.desc(sdfg).storage not in (dtypes.StorageType.Default, dtypes.StorageType.Register)): - return False - - if not isinstance(node, nodes.MapEntry): - continue - - map_entry = node - candidate_map = map_entry.map - - # Map schedules that are disallowed to transform to FPGAs - if (candidate_map.schedule == dtypes.ScheduleType.MPI - or candidate_map.schedule == dtypes.ScheduleType.GPU_Device - or candidate_map.schedule == dtypes.ScheduleType.FPGA_Device - or candidate_map.schedule == dtypes.ScheduleType.GPU_ThreadBlock): - return False - - # Recursively check parent for FPGA schedules - sdict = state.scope_dict() - current_node = map_entry - while current_node is not None: - if (current_node.map.schedule == dtypes.ScheduleType.GPU_Device - or current_node.map.schedule == dtypes.ScheduleType.FPGA_Device - or current_node.map.schedule == dtypes.ScheduleType.GPU_ThreadBlock): - return False - current_node = sdict[current_node] - - return True - - def apply(self, graph: ControlFlowRegion, sdfg: SDFG): - state = self.state - - # Find source/sink (data) nodes that are relevant outside this FPGA - # kernel - shared_transients = set(sdfg.shared_transients()) - input_nodes = [ - n for n in state.source_nodes() - if isinstance(n, nodes.AccessNode) and (not sdfg.arrays[n.data].transient or n.data in shared_transients) - ] - output_nodes = [ - n for n in state.sink_nodes() - if isinstance(n, nodes.AccessNode) and (not sdfg.arrays[n.data].transient or n.data in shared_transients) - ] - - fpga_data = {} - - # Input nodes may also be nodes with WCR memlets - # We have to recur across nested SDFGs to find them - wcr_input_nodes = set() - - for node, node_parent_graph in state.all_nodes_recursive(): - if isinstance(node, dace.sdfg.nodes.AccessNode): - for e in node_parent_graph.in_edges(node): - if e.data.wcr is not None: - trace = dace.sdfg.trace_nested_access(node, node_parent_graph, node_parent_graph.sdfg) - for node_trace, memlet_trace, state_trace, sdfg_trace in trace: - # Find the name of the accessed node in our scope - if state_trace == state and sdfg_trace == sdfg: - _, outer_node = node_trace - if outer_node is not None: - break - else: - # This does not trace back to the current state, so - # we don't care - continue - if any(outer_node.data == n.data for n in input_nodes): - continue # Skip adding duplicates - input_nodes.append(outer_node) - wcr_input_nodes.add(outer_node) - if input_nodes: - # create pre_state - pre_state = sd.SDFGState('pre_' + state.label, sdfg) - - for node in input_nodes: - - if not isinstance(node, dace.sdfg.nodes.AccessNode): - continue - desc = node.desc(sdfg) - if not isinstance(desc, dace.data.Array): - # TODO: handle streams - continue - - if node.data in fpga_data: - fpga_array = fpga_data[node.data] - elif node not in wcr_input_nodes: - fpga_array = sdfg.add_array('fpga_' + node.data, - desc.shape, - desc.dtype, - transient=True, - storage=dtypes.StorageType.FPGA_Global, - allow_conflicts=desc.allow_conflicts, - strides=desc.strides, - offset=desc.offset) - fpga_array[1].location = copy.copy(desc.location) - desc.location.clear() - fpga_data[node.data] = fpga_array - - pre_node = pre_state.add_read(node.data) - pre_fpga_node = pre_state.add_write('fpga_' + node.data) - mem = memlet.Memlet(data=node.data, subset=subsets.Range.from_array(desc)) - pre_state.add_edge(pre_node, None, pre_fpga_node, None, mem) - - if node not in wcr_input_nodes: - fpga_node = state.add_read('fpga_' + node.data) - sdutil.change_edge_src(state, node, fpga_node) - state.remove_node(node) - - graph.add_node(pre_state) - sdutil.change_edge_dest(graph, state, pre_state) - graph.add_edge(pre_state, state, sd.InterstateEdge()) - - if output_nodes: - - post_state = sd.SDFGState('post_' + state.label, sdfg) - - for node in output_nodes: - - if not isinstance(node, dace.sdfg.nodes.AccessNode): - continue - desc = node.desc(sdfg) - if not isinstance(desc, dace.data.Array): - # TODO: handle streams - continue - - if node.data in fpga_data: - fpga_array = fpga_data[node.data] - else: - fpga_array = sdfg.add_array('fpga_' + node.data, - desc.shape, - desc.dtype, - transient=True, - storage=dtypes.StorageType.FPGA_Global, - allow_conflicts=desc.allow_conflicts, - strides=desc.strides, - offset=desc.offset) - fpga_array[1].location = copy.copy(desc.location) - desc.location.clear() - fpga_data[node.data] = fpga_array - # fpga_node = type(node)(fpga_array) - - post_node = post_state.add_write(node.data) - post_fpga_node = post_state.add_read('fpga_' + node.data) - mem = memlet.Memlet(f"fpga_{node.data}", None, subsets.Range.from_array(desc)) - post_state.add_edge(post_fpga_node, None, post_node, None, mem) - - fpga_node = state.add_write('fpga_' + node.data) - sdutil.change_edge_dest(state, node, fpga_node) - state.remove_node(node) - - graph.add_node(post_state) - sdutil.change_edge_src(graph, state, post_state) - graph.add_edge(state, post_state, sd.InterstateEdge()) - - # propagate memlet info from a nested sdfg - for src, src_conn, dst, dst_conn, mem in state.edges(): - if mem.data is not None and mem.data in fpga_data: - mem.data = 'fpga_' + mem.data - fpga_update(sdfg, state, 0) diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 063070bd94..ba6e2b2841 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -8,7 +8,7 @@ import warnings from dace import data as dt, dtypes, memlet, nodes, sdfg as sd, symbolic, subsets, properties -from dace.codegen.tools.type_inference import infer_expr_type +from dace.sdfg.type_inference import infer_expr_type from dace.sdfg import graph as gr, nodes from dace.sdfg import SDFG, SDFGState from dace.sdfg import utils as sdutil diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index 78ba72c78d..b90d448c56 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -141,8 +141,6 @@ def can_be_applied(self, graph: SDFGState, expr_index, sdfg, permissive=False): # Ensure that every connector has at least one corresponding access # node in the (nested) SDFG. Otherwise, inlining is not possible. - # NOTE: FPGA-compatible SDFGs can have input connectors for data that - # are only written. inp_data = {conn: set() for conn in in_connectors} for e in graph.in_edges(nested_sdfg): src = graph.memlet_path(e)[0].src @@ -1222,7 +1220,7 @@ def apply(self, _, sdfg: SDFG) -> nodes.NestedSDFG: # If this transient has a symbolic shape, and if any symbol is in in the "ranges" # of the state then substitute it with its max value (if it can be inferred). # This is useful for the cases where the transient comes from a slice operation - # (e.g. array[:i] or array[i:]), and we are on devices such as FPGAs that do not + # (e.g. array[:i] or array[i:]), and we are on devices that do not # support dynamic memory allocation. propagation.propagate_states(nested_sdfg) diff --git a/dace/transformation/onnx/__init__.py b/dace/transformation/onnx/__init__.py new file mode 100644 index 0000000000..651cd8eed4 --- /dev/null +++ b/dace/transformation/onnx/__init__.py @@ -0,0 +1,10 @@ +try: + from .constant_folding import ConstantFolding + from .parameter_to_transient import parameter_to_transient + from .optimize import expand_onnx_nodes, auto_optimize_onnx +except ImportError: + # ONNX transformations not available + ConstantFolding = None + parameter_to_transient = None + expand_onnx_nodes = None + auto_optimize_onnx = None diff --git a/dace/transformation/onnx/constant_folding.py b/dace/transformation/onnx/constant_folding.py new file mode 100644 index 0000000000..53b54daf5c --- /dev/null +++ b/dace/transformation/onnx/constant_folding.py @@ -0,0 +1,158 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from typing import Optional, TYPE_CHECKING + +import numpy as np + +import dace +import torch +from dace import config +from dace.properties import make_properties +from dace.transformation import transformation +from dace.sdfg import nodes as nd +from dace.sdfg import utils as sdutil + +import dace.libraries.onnx as donnx +from dace.libraries.onnx.converters import clean_onnx_name +from dace.libraries.onnx.nodes.onnx_op import ONNXOp + +if TYPE_CHECKING: + from dace.frontend.ml.onnx import ONNXModel + +# blocklist of nondeterministic ops +# yapf: disable +NONDETERMINISTIC_OPS = {'ONNXDropout', + 'ONNXGradient', + 'ONNXGraphCall', + 'ONNXIf', + 'ONNXLoop', + 'ONNXMomentum', + 'ONNXMultinomial', + 'ONNXRandomNormal', + 'ONNXRandomNormalLike', + 'ONNXRandomUniform', + 'ONNXRandomUniformLike', + 'ONNXSVMClassifier', + 'ONNXSVMRegressor', + 'ONNXScan', + 'ONNXTreeEnsembleClassifier', + 'ONNXTreeEnsembleRegressor'} +# yapf: enable + + +@make_properties +class ConstantFolding(transformation.SingleStateTransformation): + """ Remove nodes where all inputs are known and replace them with constant nodes by precomputing the output. + """ + # pattern matching only checks that the type of the node matches, + onnx_node = transformation.PatternNode(ONNXOp) + + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.onnx_node)] + + @staticmethod + def is_constant(sdfg: dace.SDFG, state: dace.SDFGState, node) -> bool: + if len(state.in_edges(node)) > 0: + return False + + # the ONNX importer adds a _parent_onnx_model attribute to the sdfg + if isinstance(node, nd.AccessNode) and node.data in sdfg._parent_onnx_model.clean_weights: + return True + + return False + + def can_be_applied(self, + graph: dace.sdfg.graph.OrderedMultiDiConnectorGraph, + expr_index: int, + sdfg, + permissive: bool = False): + + node = self.onnx_node + + # SDFG must be imported from an ONNXModel + if not hasattr(sdfg, "_parent_onnx_model"): + return False + + if not 'ONNX' + node.schema.name not in NONDETERMINISTIC_OPS: + return False + + if isinstance(node, donnx.ONNXShape): + assert len(graph.in_edges(node)) == 1 + shape_in_edge = graph.in_edges(node)[0] + assert shape_in_edge.dst_conn == "data" + shape_desc = sdfg.arrays[shape_in_edge.src.data] + try: + np.array(shape_desc.shape, np.int64) + except Exception: + # this happens if the shape is symbolic, for example + return False + + return True + + # all inputs are constant + for edge in graph.in_edges(node): + if not ConstantFolding.is_constant(sdfg, graph, edge.src): + return False + + return True + + @classmethod + def match_to_str(cls, graph): + node: ONNXOp = cls.onnx_node + return "Precompute outputs of {}".format(node) + + def apply(self, state: dace.SDFGState, sdfg: dace.SDFG): + parent: "ONNXModel" = sdfg._parent_onnx_model + node = self.onnx_node + if config.Config.get_bool('debugprint'): + print(f"Applying constant folding: {node} in {state}") + + if isinstance(node, donnx.ONNXShape): + # if we have a shape node, replace it with a constant + assert len(state.in_edges(node)) == 1 + shape_in_edge = state.in_edges(node)[0] + assert shape_in_edge.dst_conn == "data" + shape_desc = sdfg.arrays[shape_in_edge.src.data] + + constant_name = sdfg.temp_data_name() + clean_constant_name = clean_onnx_name(constant_name) + sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ), dace.int64) + + assert constant_name not in parent.clean_weights + parent.weights[constant_name] = torch.from_numpy(np.array(shape_desc.shape, np.int64)) + + assert len(state.out_edges(node)) == 1 + output_edge = state.out_edges(node)[0] + access_shape = state.add_access(clean_constant_name) + state.add_edge(access_shape, None, output_edge.dst, output_edge.dst_conn, + sdfg.make_array_memlet(clean_constant_name)) + + # remove all now useless nodes with a reverse BFS + remove_node_and_computation(sdfg, state, node) + + +def remove_node_and_computation(sdfg: dace.SDFG, state: dace.SDFGState, node: nd.Node, connector: Optional[str] = None): + """ Remove a node and the parent nodes that compute this node, if the outputs are not used elsewhere. + + :param sdfg: the sdfg containing the node. + :param state: the state containing the node. + :param node: the node to remove + :param connector: if not None, the computation of the connector of + ``node`` will be removed, but not ``node`` itself. + """ + if connector is not None: + if connector not in node.in_connectors: + return + node.remove_in_connector(connector) + edges = state.in_edges_by_connector(node, connector) + for e in edges: + state.remove_edge(e) + else: + edges = state.out_edges(node) + for e in edges: + state.remove_edge(e) + + # remove dangling nodes, this can happen with non-transients + for node, parent in sdfg.all_nodes_recursive(): + if (isinstance(node, nd.AccessNode) and parent.in_degree(node) + parent.out_degree(node) == 0): + parent.remove_node(node) diff --git a/dace/transformation/onnx/optimize.py b/dace/transformation/onnx/optimize.py new file mode 100644 index 0000000000..dc266aa4c0 --- /dev/null +++ b/dace/transformation/onnx/optimize.py @@ -0,0 +1,65 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from typing import Optional, Callable + +import dace +from dace import config, nodes as nd +from dace.libraries import blas +from dace.sdfg.utils import expand_nodes +from dace.transformation import dataflow +from dace.transformation.auto.auto_optimize import set_fast_implementations +from dace.transformation.dataflow import CopyToMap + + +def expand_onnx_nodes(sdfg: dace.SDFG, predicate: Optional[Callable[[nd.Node], bool]] = None): + """ Recursively expand all onnx library nodes in the SDFG, resulting in an SDFG that can be optimized by + dace transformations. Will also specialize dace matmuls. + + :param sdfg: the sdfg to expand nodes on. + :param predicate: a predicate that will be called to check if a node should be expanded. + """ + + try: + from dace.libraries.onnx.nodes.onnx_op import ONNXOp # avoid import loop + except ImportError: + raise ImportError("expand_onnx_nodes requires ONNX. Install with: pip install dace[ml]") + + if predicate is None: + new_predicate = lambda n: isinstance(n, (ONNXOp, blas.MatMul)) + else: + new_predicate = lambda n: predicate(n) and isinstance(n, (ONNXOp, blas.MatMul)) + + expand_nodes(sdfg, new_predicate) + + +def auto_optimize_onnx(sdfg: dace.SDFG, cuda, simplify=False, fold_constants=True): + """ Automatically optimize ``sdfg``. + + :param sdfg: the sdfg to optimize (inplace). + :param cuda: whether to optimize for cuda. + :param simplify: whether to apply simplification transformations to the sdfg after optimization. + :param fold_constants: whether to apply constant folding. + """ + + try: + from dace.transformation.onnx import ConstantFolding # avoid import loop + except ImportError: + raise ImportError("auto_optimize_onnx requires ONNX. Install with: pip install dace[ml]") + + if config.Config.get_bool('debugprint'): + print("Applying automatic optimizations") + if fold_constants: + if config.Config.get_bool('debugprint'): + print("Applying constant folding") + sdfg.apply_transformations_repeated([ConstantFolding, dataflow.RedundantSecondArray], validate_all=True) + if config.Config.get_bool('debugprint'): + print("Expanding ONNX nodes") + expand_onnx_nodes(sdfg) + if config.Config.get_bool('debugprint'): + print("Setting fast implementations") + set_fast_implementations(sdfg, dace.DeviceType.GPU if cuda else dace.DeviceType.CPU) + if simplify: + if config.Config.get_bool('debugprint'): + print("Applying simplification transforms") + sdfg.simplify() + if cuda: + sdfg.apply_transformations_once_everywhere(CopyToMap) diff --git a/dace/transformation/onnx/parameter_to_transient.py b/dace/transformation/onnx/parameter_to_transient.py new file mode 100644 index 0000000000..39988b83b4 --- /dev/null +++ b/dace/transformation/onnx/parameter_to_transient.py @@ -0,0 +1,83 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import operator + +import dace +from dace import config, dtypes, nodes + +from dace.libraries.onnx.converters import clean_onnx_name +from dace.libraries.torch import dlpack + + +def parameter_to_transient(dace_module: 'dace.frontend.ml.torch', parameter_path: str): + """ Convert the dace array for pytorch parameter found at parameter_path to a persistently allocated transient. + + :param dace_module: the module containing the weight to transform. + :param weight_path: the dotted path to the weight + """ + + if config.Config.get_bool('debugprint'): + print(f"Converting parameter {parameter_path} to a transient") + + pt_weight_name = parameter_path + pt_tensor = operator.attrgetter(pt_weight_name)(dace_module.model) + array_name = clean_onnx_name(pt_weight_name) + dace_module.dace_model.inputs.remove(parameter_path) + + # the the access node for this array of this array + cands = [(node, parent) for (node, parent) in dace_module.sdfg.all_nodes_recursive() + if isinstance(node, nodes.AccessNode) and node.data == array_name] + + if len(cands) == 0: + if config.Config.get_bool('debugprint'): + print(f"Warning: Could not find access node with name '{array_name}', skipping parameter to transient") + return + + if len(cands) != 1: + raise ValueError("parameter_to_transient does not work when the target array has multiple AccessNodes") + + if array_name not in dace_module.sdfg.arrays: + raise ValueError(f"Could not find parameter {array_name} in sdfg.") + + if dace_module.sdfg.arrays[array_name].storage is dtypes.StorageType.GPU_Global: + dace_module.sdfg.arrays[array_name].transient = True + dace_module.sdfg.arrays[array_name].lifetime = dtypes.AllocationLifetime.Persistent + gpu_array_name = array_name + else: + + # find the GPU transient of this array + state: dace.SDFGState + cand, state = cands[0] + if state.out_degree(cand) != 1: + raise ValueError(f"expected one out edge coming out of {cand}, found {state.out_degree(cand)}") + _, _, dst_node, _, _ = state.out_edges(cand)[0] + if (not isinstance(dst_node, nodes.AccessNode) + or dace_module.sdfg.arrays[dst_node.data].storage is not dtypes.StorageType.GPU_Global): + raise ValueError(f"parameter_to_transient only works for arrays that are copied to GPU_Global arrays," + f" but array {array_name} was connected to {dst_node}") + + gpu_array_name = dst_node.data + + # since it is parsable, proceed with the transformation + dace_module.sdfg.arrays[gpu_array_name].transient = True + dace_module.sdfg.arrays[gpu_array_name].lifetime = dtypes.AllocationLifetime.Persistent + + # remove the CPU node + state.remove_node(cand) + del dace_module.sdfg[array_name] + + def post_compile_hook(compiled_sdfg): + + struct = compiled_sdfg.get_state_struct() + + param_sdfg = compiled_sdfg.sdfg + struct_entry_name = f'__{param_sdfg.sdfg_id}_{gpu_array_name}' + + if not hasattr(struct, struct_entry_name): + raise ValueError(f"Could not parse parameter {gpu_array_name} from state_struct.") + + ptr = getattr(struct, struct_entry_name) + # copy the data into the torch parameter tensor + torch_tensor = dlpack.array_to_torch_tensor(ptr, param_sdfg.arrays[gpu_array_name]) + torch_tensor[:] = pt_tensor + + dace_module.post_compile_hooks["init_" + pt_weight_name] = post_compile_hook diff --git a/dace/transformation/onnx/replacement.py b/dace/transformation/onnx/replacement.py new file mode 100644 index 0000000000..ce51fdd14c --- /dev/null +++ b/dace/transformation/onnx/replacement.py @@ -0,0 +1,159 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" General class for pattern replacement transformations. """ +import abc +import dace +from dace import registry, nodes, data as dt +from dace.transformation import transformation, helpers as xfh +from typing import Any, Dict, List, Optional, Tuple, Union +from dace.sdfg import utils as sdutil +from dace.libraries.onnx import nodes as onnx_op +from dace.sdfg import graph as gr + + +def make_onnx_path(*path_nodes: nodes.Node) -> gr.OrderedDiGraph: + result = gr.OrderedDiGraph() + + # First add the nodes in order, so that they can be accessed + path_nodes = [transformation.PatternNode(n) for n in path_nodes] + result.add_nodes_from(path_nodes) + + # Then make a path and add access nodes as necessary + last_node = None + for node in path_nodes: + if last_node is not None: + result.add_edge(last_node, node) + last_node = node + + return result + + +def add_connecting_access_nodes(graph: gr.OrderedDiGraph): + edges_to_remove = [] + outputs = {} + for pnode in graph.nodes(): + if issubclass(pnode.node, (nodes.LibraryNode, nodes.NestedSDFG)): + if any(issubclass(e.dst.node, (nodes.LibraryNode, nodes.NestedSDFG)) for e in graph.out_edges(pnode)): + # Make new output node that everyone will link from + new_node = transformation.PatternNode(nodes.AccessNode) + graph.add_node(new_node) + graph.add_edge(pnode, new_node) + outputs[pnode] = new_node + + for e in graph.edges(): + if (issubclass(e.src.node, (nodes.LibraryNode, nodes.NestedSDFG)) + and issubclass(e.dst.node, (nodes.LibraryNode, nodes.NestedSDFG))): + # Direct path between two library nodes means that there is at least + # another access node in between + if e.src in outputs: + graph.add_edge(outputs[e.src], e.dst) + edges_to_remove.append(e) + else: + raise ValueError('Found directly connected library nodes with source not designated as output') + for e in edges_to_remove: + graph.remove_edge(e) + + +def onnx_constant_or_none(sdfg: dace.SDFG, node_or_name: Union[nodes.AccessNode, str]) -> Optional[Any]: + name = node_or_name if isinstance(node_or_name, str) else node_or_name.data + if name not in sdfg._parent_onnx_model.clean_weights: + return None + cten = sdfg._parent_onnx_model.clean_weights[name] + return cten.item() if cten.numel() == 1 else cten.tolist() + + +class ReplacementTransformation(transformation.SingleStateTransformation, abc.ABC): + + @classmethod + @abc.abstractmethod + def pattern(cls) -> gr.OrderedDiGraph[nodes.Node, dace.Memlet]: + """ Returns a pattern to match as a directed graph. """ + raise NotImplementedError + + @abc.abstractmethod + def replacement(self, subgraph: List[nodes.Node], sdfg: dace.SDFG, + state: dace.SDFGState) -> Tuple[nodes.Node, Dict[str, Tuple[nodes.Node, Union[str, dt.Data]]]]: + """ + Defines replacement behavior for the transformation. This method returns + a node (which could also be a nested SDFG if a subgraph should be + returned), accompanied by instructions for reconnecting the surrounding + nodes and creating new data (arrays). + :param subgraph: The list of nodes in the matched state with the same + IDs as the pattern subgraph. + :param sdfg: The SDFG in which to perform the replacement. + :param state: The state in which the subgraph was found. + :return: A 2-tuple of (new node, mapping), where the latter maps a + connector name on the new node to either a pair of + (old node, old connector) to redirect from, or + (None, data descriptor) if a new one shall be created. + """ + raise NotImplementedError + + @classmethod + def expressions(cls): + if hasattr(cls, '_pattern'): + return [cls._pattern] + result = cls.pattern() + add_connecting_access_nodes(result) + + # Set subgraph as class property + cls._pattern = result + # Set pattern nodes as class properties + for i, node in enumerate(result.nodes()): + setattr(cls, f'_pnode{i}', node) + return [result] + + def can_be_applied(self, graph: Union[dace.SDFG, dace.SDFGState], candidate: Dict[transformation.PatternNode, int], + expr_index: int, sdfg: dace.SDFG, simplify: bool) -> bool: + # All internal nodes must not be global (non-transient) or reused + # anywhere else + subgraph = gr.SubgraphView(graph, [graph.node(id) for id in candidate.values()]) + for node in subgraph.nodes(): + # Check for internal nodes + if node in subgraph.source_nodes() or node in subgraph.sink_nodes(): + continue + if not isinstance(node, nodes.AccessNode): + continue + if not node.desc(sdfg).transient: + return False + other_data_nodes_with_same_name = [ + n for s in sdfg.nodes() for n in s.nodes() + if isinstance(n, nodes.AccessNode) and n.data == node.data and n not in subgraph.nodes() + ] + if len(other_data_nodes_with_same_name) > 0: + return False + return True + + def apply(self, sdfg: dace.SDFG) -> nodes.Node: + state: dace.SDFGState = sdfg.node(self.state_id) + matcher = self.expressions()[0] + subgraph = [state.node(self.subgraph[n]) for n in matcher.nodes()] + new_node, reconnection = self.replacement(subgraph, sdfg, state) + + # Remap edges and add new arrays + for new_conn, (node, old_conn) in reconnection.items(): + # Make new array + if node is None: + desc = old_conn + name = sdfg.add_datadesc('_' + new_conn, desc, find_new_name=True) + node = state.add_access(name) + if new_conn in new_node.in_connectors: + state.add_edge(node, None, new_node, new_conn, dace.Memlet(name)) + elif new_conn in new_node.out_connectors: + state.add_edge(new_node, new_conn, node, None, dace.Memlet(name)) + continue + # END of new array + + if new_conn in new_node.in_connectors: + e = next(state.in_edges_by_connector(node, old_conn)) + xfh.redirect_edge(state, e, new_dst=new_node, new_dst_conn=new_conn) + elif new_conn in new_node.out_connectors: + e = next(state.out_edges_by_connector(node, old_conn)) + xfh.redirect_edge(state, e, new_src=new_node, new_src_conn=new_conn) + + # Remove subgraph nodes that are not connected from outside + sgview = gr.SubgraphView(state, subgraph) + state.remove_nodes_from( + [n for n in subgraph if isinstance(n, nodes.CodeNode) or state.degree(n) == sgview.degree(n)]) + # Remove orphan nodes + state.remove_nodes_from([n for n in state.nodes() if isinstance(n, nodes.AccessNode) and state.degree(n) == 0]) + return new_node diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index 59ca8afabe..f3a87db377 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -33,7 +33,7 @@ class ConstantPropagation(ppl.Pass): CATEGORY: str = 'Simplification' - recursive = properties.Property(dtype=bool, default=True, desc='Propagagte recursively through nested SDFGs') + recursive = properties.Property(dtype=bool, default=True, desc='Propagate recursively through nested SDFGs') progress = properties.Property(dtype=bool, default=None, allow_none=True, desc='Show progress') def modifies(self) -> ppl.Modifies: @@ -201,7 +201,7 @@ def _propagate_loop(self, loop: LoopRegion, post_constants: BlockConstsT, if loop in post_constants and post_constants[loop] is not None: if loop.update_statement is not None and (loop.inverted and loop.update_before_condition or not loop.inverted): - # Replace the RHS of the update experssion + # Replace the RHS of the update expression post_mapping = { k: v for k, v in post_constants[loop].items() diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index 2514d8412d..693a4a7777 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -145,6 +145,10 @@ def find_promotable_scalars(sdfg: sd.SDFG, transients_only: bool = True, integer continue edge = state.in_edges(node)[0] + # Edge must not be empty + if edge.data.is_empty(): + continue + # Edge must not be WCR if edge.data.wcr is not None: candidates.remove(candidate) @@ -169,11 +173,16 @@ def find_promotable_scalars(sdfg: sd.SDFG, transients_only: bool = True, integer if state.out_degree(edge.src) > 1: candidates.remove(candidate) continue - # If inputs to tasklets are not arrays, skip + for tinput in state.in_edges(edge.src): + # If inputs to tasklets are not arrays, skip if not isinstance(tinput.src, nodes.AccessNode): candidates.remove(candidate) break + # If edge memlet is empty, skip + if tinput.data.is_empty(): + candidates.remove(candidate) + break if isinstance(sdfg.arrays[tinput.src.data], dt.Stream): candidates.remove(candidate) break diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index 0749b03b82..2a47221116 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -50,7 +50,7 @@ def _lift_returns(self, sdfg: SDFG) -> int: region, and not the entire SDFG. :param sdfg: The SDFG in which to lift returns - :returns: The number of return blocks lifted + :return: The number of return blocks lifted """ returns_lifted = 0 for nd in sdfg.nodes(): @@ -220,7 +220,7 @@ def _lift_unstructured(self, sdfg: SDFG) -> int: cycles represent unstructured control flow. :param sdfg: The SDFG in which to lift unstructured control flow - :returns: The number of unstructured control flow blocks lifted + :return: The number of unstructured control flow blocks lifted """ lifted = 0 for cfg in sdfg.all_control_flow_regions(): diff --git a/dace/transformation/subgraph/__init__.py b/dace/transformation/subgraph/__init__.py index 92abe349d9..a4aa969b0c 100644 --- a/dace/transformation/subgraph/__init__.py +++ b/dace/transformation/subgraph/__init__.py @@ -5,5 +5,4 @@ from .expansion import MultiExpansion from .subgraph_fusion import SubgraphFusion from .stencil_tiling import StencilTiling -from .temporal_vectorization import TemporalVectorization from .composite import CompositeFusion diff --git a/dace/transformation/subgraph/helpers.py b/dace/transformation/subgraph/helpers.py index 7fa0c933e7..07e08226c0 100644 --- a/dace/transformation/subgraph/helpers.py +++ b/dace/transformation/subgraph/helpers.py @@ -151,7 +151,7 @@ def get_outermost_scope_maps(sdfg, graph, subgraph=None, scope_dict=None): If the underlying subgraph is not connected, there might be multiple locally outermost scopes. In this ambiguous case, the method returns an empty list. - If subgraph == None, the whole graph is taken + If subgraph is None, the whole graph is taken for analysis. """ subgraph = graph if subgraph is None else subgraph diff --git a/dace/transformation/subgraph/stencil_tiling.py b/dace/transformation/subgraph/stencil_tiling.py index 93a6517720..caaa367fdb 100644 --- a/dace/transformation/subgraph/stencil_tiling.py +++ b/dace/transformation/subgraph/stencil_tiling.py @@ -256,6 +256,11 @@ def can_be_applied(sdfg, subgraph) -> bool: if data_name in coverages[child_entry][0]: children_coverage = subsets.union(children_coverage, coverages[child_entry][0][data_name]) + # TODO: Is there a better fix for this? + if children_coverage is None: + # no coverage for this data_name in children + # this is not supported + return False # extend mapping map_parameter -> coverage # by the previous mapping diff --git a/dace/transformation/subgraph/temporal_vectorization.py b/dace/transformation/subgraph/temporal_vectorization.py deleted file mode 100644 index b0d500b71c..0000000000 --- a/dace/transformation/subgraph/temporal_vectorization.py +++ /dev/null @@ -1,166 +0,0 @@ -import dace -from dace import data, dtypes, nodes, properties -from dace.libraries.standard import Gearbox -from dace.memlet import Memlet -from dace.sdfg.graph import SubgraphView -from dace.sdfg.sdfg import SDFG -from dace.sdfg.state import SDFGState -from dace.transformation import transformation -from dace.transformation.subgraph import helpers - - -@properties.make_properties -class TemporalVectorization(transformation.SubgraphTransformation): - ''' - This transformation applies the multi-pumping optimization to a subgraph targeting FPGAs, in turn packing more computations temporally rather than spatially as is done in traditional vectorization. - - Currently, it can only be applied to applications targeting Xilinx FPGAs, and to subgraphs that purely communicate through streams. It can be applied in two ways: - 1. Where the widths of the internal paths remain unchanged while the external paths are widened by the multi-pumping factor. This gives the benefit of increased throughput at the same critical-resource footprint. - 2. Where the widths of the internal paths are divided by the multi-pumping factor, while the widths of the external paths remain unchanged. This gives the benefit of a reduced critical-resource footprint at the same throughput. - ''' - # TODO 3rd approach: where the subgraph is just clocked faster without introducing gearboxing. This could help designs where the subgraph reaches II=2, then by clocking it two times faster, it essentially behaves as II=1 from the slow clock domain. - factor = properties.Property(dtype=int, - default=2, - desc='The multi-pumping factor. E.g. double-pumping is a factor of 2.') - approach = properties.Property( - dtype=int, - default=2, - desc= - 'Which approach to use. Can be 1 (increased throughput, same resources) or 2 (same throughput, reduced resourced).' - ) - - def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: - ''' - Temporal vectorization can be applied if: - 1. There is one outermost map in the subgraph. - 2. All of the non- source and sink nodes are only accessed within this subgraph. - 3. All of the source and sink nodes are either streams or scalars. - 4. If the approach is 1, then either: - - The elemental type of the streams must be a vector type. - - The elemental type must be convertible to a vector type. - - Data packers/issuers are allowed to be inserted at the cost of performance through additional data plumbing overhead. - 5. If the approach is 2, all the elemental types of the streams must be a vector type that is integer divisable by the multi-pumping factor. - ''' - # Extract all of the relevant components of the subgraph - graph = subgraph.graph - src_nodes = subgraph.source_nodes() - dst_nodes = subgraph.sink_nodes() - srcdst_nodes = src_nodes + dst_nodes - srcdst_arrays = [sdfg.arrays[node.data] for node in srcdst_nodes if isinstance(node, nodes.AccessNode)] - access_nodes = [ - node for node in subgraph.nodes() if isinstance(node, nodes.AccessNode) and not node in srcdst_nodes - ] - map_entries = helpers.get_outermost_scope_maps(sdfg, graph, subgraph) - map_exits = [graph.exit_node(map_entry) for map_entry in map_entries] - maps = [map_entry.map for map_entry in map_entries] - - # Perform checks - # 1. There is at least one map. - if len(maps) < 1: return False - - # TODO array of streams - # TODO map of computation (matmul) - # TODO scalars - - # 2. All of the non- source and sink nodes only resides within this subgraph. - for sg in dace.sdfg.concurrent_subgraphs(graph): - if sg == subgraph: continue - for nd in sg.nodes(): - if isinstance(nd, nodes.AccessNode) and nd in access_nodes: - return False - - # 3. All of the source and sink nodes only are either streams or scalars. - for arr in srcdst_arrays: - if not (isinstance(arr, data.Stream) or isinstance(arr, data.Scalar)): - return False - - # 4. If the approach is 1, then either the dataype must be a vector type or must be convertible to a vector type. - if self.approach == 1: - # TODO not implemented yet. - return False - - # 5. If the approach is 2, then all the elemental datatype of the streams must be a vector type. - elif self.approach == 2: - for arr in srcdst_arrays: - if (isinstance(arr, data.Stream) - and not isinstance(arr.dtype, dtypes.vector)) or arr.veclen % self.factor != 0: - return False - - # If the approach is wrong, then it should not be applied. - else: - return False - - return True - - def issuer(self, sdfg: SDFG, state: SDFGState, subgraph: SubgraphView, src): - arr = sdfg.arrays[src.data] - veclen = arr.dtype.veclen // self.factor - dtype = dace.vector(arr.dtype.base_type, veclen) - name = f'{src.data}_pumped' - new_src = sdfg.add_stream(name, dtype, storage=dtypes.StorageType.FPGA_Local, transient=True) - - # Update the subgraph - old_edge = subgraph.out_edges(src)[0] - old_path = state.memlet_path(old_edge) - for edge in old_path[1:]: - edge.data = dace.Memlet(f'{name}[0]') - old_path[-1].dst.in_connectors[old_path[-1].dst_conn] = dtype - state.remove_edge(old_edge) - new_src = state.add_read(name) - state.add_edge(new_src, old_edge.src_conn, old_edge.dst, old_edge.dst_conn, memlet=dace.Memlet(f'{name}[0]')) - innermost_map = [n.src.map for n in old_path if isinstance(n.src, nodes.MapEntry)][-1] - - # Insert gearboxing for converting stream widths - gearbox = Gearbox(innermost_map.range.ranges[0][1] + 1, schedule=dtypes.ScheduleType.FPGA_Multi_Pumped) - gearbox_src = state.add_write(name) - state.add_memlet_path(src, gearbox, dst_conn='from_memory', memlet=Memlet(f'{src.data}[0]')) - state.add_memlet_path(gearbox, gearbox_src, src_conn='to_kernel', memlet=Memlet(f'{name}[0]')) - return innermost_map - - def packer(self, sdfg: SDFG, state: SDFGState, subgraph: SubgraphView, dst): - arr = sdfg.arrays[dst.data] - veclen = arr.dtype.veclen // self.factor - dtype = dace.vector(arr.dtype.base_type, veclen) - name = f'{dst.data}_pumped' - sdfg.add_stream(name, dtype, storage=dtypes.StorageType.FPGA_Local, transient=True) - - # Update the subgraph - old_edge = subgraph.in_edges(dst)[0] - old_path = state.memlet_path(old_edge) - for edge in old_path[:-1]: - edge.data = dace.Memlet(f'{name}[0]') - old_path[0].src.out_connectors[old_path[0].src_conn] = dtype - state.remove_edge(old_edge) - new_dst = state.add_write(name) - state.add_edge(old_edge.src, old_edge.src_conn, new_dst, old_edge.dst_conn, memlet=dace.Memlet(f'{name}[0]')) - innermost_map = [n.dst.map for n in old_path if isinstance(n.dst, nodes.MapExit)][0] - - # Insert gearbox for converting stream widths. - gearbox = Gearbox(innermost_map.range.ranges[0][1] + 1, schedule=dtypes.ScheduleType.FPGA_Multi_Pumped) - gearbox_dst = state.add_read(name) - state.add_memlet_path(gearbox_dst, gearbox, dst_conn='from_memory', memlet=Memlet(f'{name}[0]')) - state.add_memlet_path(gearbox, dst, src_conn='to_kernel', memlet=Memlet(f'{dst.data}[0]')) - return innermost_map - - def apply(self, sdfg: SDFG, **kwargs): - # Get the graphs and the nodes - subgraph = self.subgraph_view(sdfg) - graph = subgraph.graph - src_nodes = subgraph.source_nodes() - dst_nodes = subgraph.sink_nodes() - affected_maps = set() - - # Update all of the subgraph inputs - for src in src_nodes: - affected_maps.add(self.issuer(sdfg, graph, subgraph, src)) - - # Update all of the subgraph outputs - for dst in dst_nodes: - affected_maps.add(self.packer(sdfg, graph, subgraph, dst)) - - # Update the schedules of the innermost affected maps. - for map in affected_maps: - rng = list(map.range.ranges[0]) - rng[1] = ((rng[1] + 1) * 2) - 1 - map.range.ranges[0] = rng - map.schedule = dtypes.ScheduleType.FPGA_Multi_Pumped diff --git a/dace/utils.py b/dace/utils.py new file mode 100644 index 0000000000..26e3661be8 --- /dev/null +++ b/dace/utils.py @@ -0,0 +1,58 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Utility functions for DaCe. + +This module provides general utility functions that are used across various parts of DaCe. +""" + +import math +from typing import Iterable, Sequence, Union + +import sympy + +# Type alias for numeric or symbolic values +NumericType = Union[int, float, sympy.Basic] + + +def prod(sequence: Iterable[NumericType], start: NumericType = 1) -> NumericType: + """ + Computes the product of a sequence of numbers or symbolic expressions. + + This function handles both numeric values and SymPy symbolic expressions, + making it suitable for use with DaCe's symbolic shape calculations. + + :param sequence: An iterable of numbers or symbolic expressions. + :param start: The starting value for the product (default: 1). + :return: The product of all elements in the sequence, multiplied by start. + Returns start if the sequence is empty. + """ + result = start + for item in sequence: + result = result * item + return result + + +def find_new_name(name: str, existing_names: Sequence[str]) -> str: + """ + Returns a name that matches the given ``name`` as a prefix, but does not + already exist in the given existing name set. The behavior is typically + to append an underscore followed by a unique (increasing) number. If the + name does not already exist in the set, it is returned as-is. + + :param name: The given name to find. + :param existing_names: The set of existing names. + :return: A new name that is not in existing_names. + """ + if name not in existing_names: + return name + cur_offset = 0 + new_name = name + '_' + str(cur_offset) + while new_name in existing_names: + cur_offset += 1 + new_name = name + '_' + str(cur_offset) + return new_name + + +def deduplicate(iterable): + """ Removes duplicates in the passed iterable. """ + return type(iterable)([i for i in sorted(set(iterable), key=lambda x: iterable.index(x))]) diff --git a/dace/version.py b/dace/version.py index 1f356cc57b..14f8047584 100644 --- a/dace/version.py +++ b/dace/version.py @@ -1 +1 @@ -__version__ = '1.0.0' +__version__ = '43!2026.01.13' diff --git a/doc/codegen/codegen.rst b/doc/codegen/codegen.rst index af34aac569..06edd03936 100644 --- a/doc/codegen/codegen.rst +++ b/doc/codegen/codegen.rst @@ -79,7 +79,7 @@ among others. Within the control flow tree, each state is visited by the frame-code generator, which will then dispatch the other targets using the :class:`~dace.codegen.dispatcher.TargetDispatcher` class. Code generator targets register themselves -with the dispatcher by extending the :class:`~dace.codegen.targets.target.TargetCodeGenerator` class, and then +with the dispatcher by extending the :class:`~dace.codegen.target.TargetCodeGenerator` class, and then the dispatcher via the ``register_*_dispatcher`` methods (e.g., :func:`~dace.codegen.dispatcher.TargetDispatcher.register_node_dispatcher`) in their constructor. The dispatcher will then call the given predicate function to determine whether the target should be invoked for the given node. For example, an excerpt from the GPU code generator is shown below: @@ -104,7 +104,7 @@ should be invoked for the given node. For example, an excerpt from the GPU code The dispatcher will then invoke its ``dispatch_*`` methods (e.g., :func:`~dace.codegen.dispatcher.TargetDispatcher.dispatch_node`) -to invoke the target. Those will then call the ``generate_*`` methods (e.g., :func:`~dace.codegen.targets.target.TargetCodeGenerator.generate_node`). +to invoke the target. Those will then call the ``generate_*`` methods (e.g., :func:`~dace.codegen.target.TargetCodeGenerator.generate_node`). On most targets, each node type has a matching ``_generate_`` method, similarly to AST visitors, which are responsible for that node type. For example, see :func:`~dace.codegen.targets.cpu.CPUCodeGen._generate_MapEntry` in :class:`~dace.codegen.targets.cpu.CPUCodeGen`. @@ -116,7 +116,7 @@ for global declarations). At this point, instrumentation providers are also invo exact methods that are invoked can be found in :class:`~dace.codegen.instrumentation.provider.InstrumentationProvider`. After the graph is traversed, each target is invoked with two methods: -:func:`~dace.codegen.targets.target.TargetCodeGenerator.get_generated_codeobjects` and :func:`~dace.codegen.targets.target.TargetCodeGenerator.cmake_options` +:func:`~dace.codegen.target.TargetCodeGenerator.get_generated_codeobjects` and :func:`~dace.codegen.target.TargetCodeGenerator.cmake_options` to retrieve any extra :class:`~dace.codegen.codeobject.CodeObject` files and CMake options, respectively. The frame-code generator will then merge all code objects and return them, along with any environments/libraries that were requested by the code generators (e.g., link with CUBLAS). The compiler interface then generates the ``.dacecache`` @@ -132,8 +132,7 @@ The code generator uses a thin C++ runtime for support. The folder, which contai be found in the ``dace/runtime`` folder. The ``dace.h`` header file is the point of entry for the runtime, and it includes all the other necessary headers. The runtime is used for: - * **Target-specific runtime functions**: Header files inside the ``cuda``, ``intel_fpga``, and ``xilinx`` folders contain - GPU (CUDA/HIP), Intel FPGA, and Xilinx-specific functions, respectively. + * **Target-specific runtime functions**: Header files inside the ``cuda`` folder contains GPU (CUDA/HIP) specific functions. * Memory management * **Profiling**: ``perf/reporting.h`` contains functions that create :ref:`instrumentation reports `, ``perf/papi.h`` contains functions that use the `PAPI `_ library to measure performance counters. @@ -177,132 +176,3 @@ For example, if we want to debug the code generation of a specific node, we can and add a condition to it, such as ``node.label == "my_node"``. This will stop the code generation process when the code generator reaches the node with the label ``my_node``. This can be used to debug the code generation of a specific node, or to debug the code generation of a specific node type (e.g., ``isinstance(node, dace.nodes.MapEntry)``). - - -FPGA Code Generation --------------------- -The FPGA Code Generation emits High-Level Synthesis device code and all the host code required to target either Xilinx or Intel FPGAs. - -The FPGA code generation is implemented by different modules, organized hierarchically: - - * a generic FPGA backend (``dace/codegen/target/fpga.py``) is in charge of traversing the SDFG as shown in :ref:`codegen_how_it_works`; - * two lower level components that are in charge of generating device-specific code for Vivado HLS (``dace/codegen/target/xilinx.py``) or Intel FPGA OpenCL (``dace/codegen/target/intel_fpga.py``). - -Vendor-specific semantics and syntax are handled by the two lower-level components triggered by the generic FPGA backend. - -The FPGA code generation relies on the `HLSLIB `_ external library to facilitate host/device interaction and HLS code generation. - - -Maps: pipelined and unrolled parallelism -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Maps are used to express parallel scopes in SDFGs. -In the context of FPGAs, we exploit this parallelism in two ways: pipelined and unrolled parallelism. - -.. rubric:: - Pipeline parallelism - -By default. maps are code-generated as pipelined loops, where iterations are executed in sequence, with the lowest II that can -be attained by the compiler. -With the Intel OpenCL compiler, loops are automatically pipelined. For the Xilinx backend, proper pragmas are generated (``#pragma HLS pipeline``). - - -.. rubric:: - Unrolled (or spatial) parallelism - -If a map is explicitly unrolled, this will be code generated as a loop with unrolling hints. -In this case, the compiler will unroll the loop, replicating the hardware and exploiting the spatial parallelism of the device. - - - -Streams -^^^^^^^ - -Streams are DaCe containers that represent first-in, first-out queues. -In FPGAs, they can be implemented in hardware (FIFOs) to exploit the on-chip resources and allow fast -communication between different program components. - -These containers and their related operations are generated differently for Xilinx and Intel FPGA: - - * for Xilinx FPGAs, streams are emitted in the top-level kernel function as local objects. - Then they are passed as arguments to the producer and consumer accessing them. - - * for Intel FPGAs, they must be emitted to the global kernel scope, where the - producer and consumer will read them directly (i.e., rather than receiving them as arguments). - This would require, among the others, considering the case where different streams are defined - using the same name. In this case, the Intel FPGA Code generator will mangle their name so - they can be uniquely identified in the program. - -Finally, we should also consider the presence of streams that connect different FPGA kernels (see the section about FPGA kernels and processing elements). -In this case, they are defined either in the connectivity configuration file (``link.cfg``) that is passed to the Vitis compiler (Xilinx), -or in a shared header that is then included by the different kernels (Intel OpenCL). - - - -Decoupled Memory interfaces -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -When a container stored in the FPGA Device Memory (off-chip memory) is both read and written, DaCe, by default, -creates a single memory interface for both types of accesses. - -While this has no particular performance impact on Intel, for Xilinx this could impair place and route step, resulting in -a lower synthesis frequency. - -For this reason, the programmer can set to true the DaCe configuration option ``DACE_compiler_fpga_xilixn_decouple_array_interfaces``. -This has an effect on the code generated for Xilinx. Any time that an array is If an array is both read and written, this option decouples -its accesses by creating a memory interface for reading and one for writing. The array name is qualified and code generated with a ``_in`` or -``_out`` suffix, indicating the access directionality. - - -*Warning*: while decoupling memory interfaces can improve performance, it must be used carefully. This may hide potential Read-After-Write or -Write-After-Read dependencies to the Vitis compiler, resulting in erroneous hardware. In addition to this, enabling the configuration could create up to 2 times the number of interfaces, -possibly reaching the limits supported by the device/Vitis. - - -.. _codegen_fpga_kernels: - -FPGA Kernels and Processing Elements -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -When the DaCe code generator backend encounters a state that only accesses containers situated on the FPGA, it designates it as an *FPGA kernel* -and triggers FPGA code generation (:func:`~dace.codegen.targets.fpga.FPGACodeGen.generate_state`). - -Before continuing the traversal to generate the hardware itself, the kernel *boundary* is detected. -Here, DaCe supports two options: - - * by default, it will infer the entire SDFG state as an FPGA kernel. The DaCe code generator will generate each weakly connected - component found in an SDFG state in a different *Processing Element*. Being independent, these SDFG components can be executed in parallel. - The notion of partitioning the functionality of a kernel into multiple independently-scheduled modules - is central to designing large FPGA architectures. - - * if the ``DACE_compiler_fpga_concurrent_kernel_detection`` configuration option is set to ``True``, - a heuristic will further inspect each independent component for other parallelism opportunities (e.g., branches of the SDFG - that can be executed in parallel). With this, inside the same state there could be multiple FPGA Kernels, that may depending - on each other (e.g., a kernel must wait for the completion of a previous one before it can be executed). - - -Once kernel boundaries are identified, the code generator infers the necessary arguments that must be passed and generate -host code call for kernel launches and synchronizations. - -Regarding processing elements, in the Vivado HLS toolflow, processing elements are expressed by annotating a scope in the -generated C++ code with the ``DATAFLOW`` pragma, resulting in every loop and function call in the scope to be scheduled -as a distinct processing element. -Intel OpenCL has no distinction between processing elements and kernels. Therefore every processing element must be expressed as a -separate OpenCL kernel. Launching each processing element is thus done directly from the host code. - - - - -Systolic Arrays -^^^^^^^^^^^^^^^ -Systolic arrays are used to express parametric parallelism, by using an array of communicating processing elements that can be programmed to perform a common operation. - -In a SDFG, 1D systolic arrays can be represented by unrolled maps in the outermost FPGA kernel scope. -The map can have a symbolic, but compile-time specialized, number of iterations, and must be coupled with array(s) of stream objects. - -When the map is unrolled, its body get replicated, and each instance becomes a weakly connected component in the state, resulting in them being instantiated as separate processing elements (see :ref:`codegen_fpga_kernels`). - - -The actual code generation varies between Xilinx and Intel FPGA. In the former case, it is sufficient to unroll a loop in the C++ kernel code with bounds known at compile tim. For Intel, the OpenCL kernel representing the processing element is replicated and specialized directly in the generated code. - - -.. TODO: adding figure/example may help understanding what's going on. diff --git a/doc/design/codegen.md b/doc/design/codegen.md index f236179c49..7cd8e9efc1 100644 --- a/doc/design/codegen.md +++ b/doc/design/codegen.md @@ -26,7 +26,7 @@ The current code generation system in DaCe follows this monolithic structure in 2. **Frame Code Generator Setup** - Create `DaCeCodeGenerator` instance - - Target instantiation (CPU, CUDA, FPGA, etc.) + - Target instantiation (CPU, GPU, etc.) - Code generation target querying (`_get_codegen_targets()`) 3. **Target Preprocessing** @@ -50,11 +50,9 @@ The current code generation system in DaCe follows this monolithic structure in #### Target System: - **`targets/target.py`**: Base target interface - **`targets/cpu.py`**: CPU/OpenMP code generation -- **`targets/cuda.py`**: CUDA/HIP GPU code generation -- **`targets/fpga.py`**: FPGA code generation +- **`targets/gpu.py`**: CUDA/HIP GPU code generation - **`targets/cpp.py`**: C++ utilities - **`targets/mpi.py`**: MPI parallelization -- **`targets/rtl.py`**: RTL/SystemVerilog generation #### Specialized Systems: - **`targets/sve/`**: ARM SVE vectorization @@ -102,43 +100,43 @@ The `DaCeCodeGenerator` class currently handles numerous responsibilities that s ### Phase 1: Scheduling and Analysis Passes -#### 1. **ValidationPass** +#### **ValidationPass** - **Purpose**: Run SDFG validation prior to code generation - **Input**: Input SDFG - **Output**: None - **Current Location**: `validate.py` -#### 2. **TypeInferencePass** +#### **TypeInferencePass** - **Purpose**: Infer connector types and set default storage/schedule types - **Input**: Input SDFG - **Output**: SDFG with inferred types, pipeline_results["type_info"] - **Current Location**: `infer_types.py` functions -#### 3. **LibraryExpansionPass** +#### **LibraryExpansionPass** - **Purpose**: Expand all library nodes that haven't been expanded - **Input**: Type-inferred SDFG - **Output**: SDFG with expanded library nodes - **Current Location**: `sdfg.expand_library_nodes()` -#### 4. **TypeInferencePass** +#### **TypeInferencePass** - **Purpose**: After expanding library nodes, run a second type inference pass if the SDFG changed - **Input**: Library-expanded SDFG - **Output**: SDFG with inferred types, updated pipeline_results["type_info"] - **Current Location**: `infer_types.py` functions -#### 5. **MetadataCollectionPass** +#### **MetadataCollectionPass** - **Purpose**: Collect free symbols, argument lists, constants, shared transients - **Input**: Expanded SDFG - **Output**: pipeline_results["metadata"] = {symbols, arglist, constants, shared_transients} - **Current Location**: `DaCeCodeGenerator.__init__()` -#### 6. **ControlFlowRaising** +#### **ControlFlowRaising** - **Purpose**: Extract structured control flow from state machines, if Control Flow Regions were not already given - **Input**: SDFG - **Output**: SDFG with Control Flow Regions - **Current Location**: Already exists -#### 7. **AllocationAnalysisPass** +#### **AllocationAnalysisPass** - **Purpose**: Determine allocation lifetimes and scopes for all data containers - **Input**: SDFG with metadata - **Output**: SDFG with allocation/deallocation points stored in node metadata @@ -146,7 +144,7 @@ The `DaCeCodeGenerator` class currently handles numerous responsibilities that s these decisions. - **Current Location**: `DaCeCodeGenerator.determine_allocation_lifetime()` -#### 8. **StreamAssignmentPass** (mostly GPU-specific) +#### **StreamAssignmentPass** (mostly GPU-specific) - **Purpose**: Assign streams for concurrent execution. Currently used for CUDA/HIP streams but can apply to other architectures - **Input**: SDFG - **Output**: SDFG with stream assignments stored in node metadata @@ -157,35 +155,42 @@ The `DaCeCodeGenerator` class currently handles numerous responsibilities that s #### Target-Specific Preprocessing Passes - **Purpose**: Perform preprocessing modifications on the SDFG based on the code generators that will be used next -- **Examples**: `FPGAPreprocessingPass` for FPGAs, `StreamAssignmentPass` for GPUs, `CopyToMap` for heterogeneous targets in general (see below) +- **Examples**: `StreamAssignmentPass` for GPUs, `CopyToMap` for heterogeneous targets in general (see below) -#### 9. **LowerAllocations** +#### **LowerConsume** +- **Purpose**: Convert Consume scopes into while loops or kernels, depending on the target (`LowerConsumeCPP`, `LowerConsumeGPU`) +- **Input**: SDFG with consume scopes +- **Output**: SDFG with control flow regions +- **Current Location**: Inline in code generators +- **Note**: This modifies the SDFG structure rather than generating code + +#### **LowerAllocations** - **Purpose**: Add allocation/deallocation annotations (e.g., as tasklets) to the SDFG for each scope - **Input**: SDFG with allocation analysis - **Output**: SDFG with allocation/deallocation tasklets inserted - **Current Location**: `allocate_arrays_in_scope()`, `deallocate_arrays_in_scope()` - **Note**: This modifies the SDFG structure rather than generating code -#### 10. **CopyToMap** +#### **CopyToMap** - **Purpose**: Convert nontrivial memory copies to Map nodes where needed - **Input**: SDFG with targets identified - **Output**: SDFG with transformed copies - **Current Location**: `cuda.py` preprocessing, various target preprocessors -- **Applies To**: GPU strided copies, FPGA transfers +- **Applies To**: GPU strided copies -#### 11. **LowerTaskletLanguage** +#### **LowerTaskletLanguage** - **Purpose**: Convert Python/generic tasklets to tasklets in the target language (C++/CUDA/etc.) - **Input**: SDFG with tasklets - **Output**: SDFG with lowered tasklets - **Current Location**: Distributed across target generators -#### 12. **LowerMemlets** +#### **LowerMemlets** - **Purpose**: Lower high-level memlets to explicit copy operations - **Input**: SDFG with target analysis - **Output**: SDFG with explicit copies annotated (e.g., as tasklets) - **Current Location**: Embedded in target-specific copy generation -#### 13. **SplitSDFGToTargets** +#### **SplitSDFGToTargets** - **Purpose**: The final lowering step splits the single SDFG into an SDFG per target file. This means that, for example, a GPU kernel map will be converted to an ExternalSDFG call to another SDFG file that contains the kernel. @@ -199,22 +204,22 @@ The `DaCeCodeGenerator` class currently handles numerous responsibilities that s ### Phase 3: Code Generation Passes -#### 14. **GenerateStateStruct** +#### **GenerateStateStruct** - **Purpose**: Generate state struct definitions for persistent data - **Input**: SDFG with allocation info - **Output**: pipeline_results["state_struct"] = {struct_def, struct_init} - **Current Location**: `DaCeCodeGenerator.generate_code()` -#### 15. **GenerateTargetCode** +#### **GenerateTargetCode** - **Purpose**: Generate both frame code and target-specific code for each SDFG file by traversing the graph and emitting code for each element. - **Input**: Split SDFGs with all previous analyses - **Output**: pipeline_results["code_objects"] = List[CodeObject] with complete code - **Current Location**: Combined from `DaCeCodeGenerator.generate_code()` and target-specific `get_generated_codeobjects()` -- **Note**: This pass may call individual target code generators (CppCodeGen, GPUCodeGen, FPGACodeGen, etc.) to +- **Note**: This pass may call individual target code generators (CppCodeGen, GPUCodeGen, etc.) to generate platform-specific code -#### 14. **GenerateHeaders** +#### **GenerateHeaders** - **Purpose**: Generate C/C++ header files for SDFG interface - **Input**: CodeObjects with complete code - **Output**: pipeline_results["headers"] = {call_header, sample_main} @@ -244,8 +249,9 @@ class CodeGenerationPipeline(Pipeline): # Phase 2: Lowering LowerAllocations(), ConditionalPipeline([ - (lambda r: 'cuda' in r.get('targets', []), CopyToMapPass()), - (lambda r: 'fpga' in r.get('targets', []), FPGAPreprocessingPass()), + (lambda r: 'gpu' in r.get('targets', []), CopyToMapPass()), + (lambda r: 'gpu' in r.get('targets', []), LowerConsumeGPU()), + (lambda r: 'cpu' in r.get('targets', []), LowerConsumeCPP()), ]), LowerTaskletLanguage(), LowerMemlets(), @@ -301,7 +307,6 @@ dace/codegen/ │ ├── __init__.py │ ├── analysis/ # Analysis passes │ │ ├── __init__.py -│ │ ├── type_inference.py │ │ ├── metadata_collection.py │ │ └── allocation_analysis.py │ ├── transformation/ # Transformation passes @@ -329,7 +334,6 @@ dace/codegen/ │ ├── openmp.py # OpenMP backend (split from cpu.py) │ ├── cpp.py # Pure C++ backend (split from cpu.py and cpp.py) │ ├── gpu.py # GPU backend (generalized from cuda.py) -│ ├── fpga/ # FPGA backends │ └── specialized/ # Other specialized targets ├── runtime/ # Runtime interface (from compiled_sdfg.py) └── utils/ # Utilities (dispatcher, codeobject, etc.) @@ -349,11 +353,11 @@ dace/codegen/ - Base for other C++ based backends - Sequential execution model - Basic memory management +- Current "CPU" backend functionality (without parallelism) #### 2. **OpenMP Backend** (`targets/openmp.py`) - Extends C++ backend with OpenMP directives - CPU parallelization via OpenMP -- Current "CPU" backend functionality - Shared memory parallelism #### 3. **GPU Backend** (`targets/gpu.py`) @@ -373,10 +377,6 @@ TargetCodeGenerator (base) ├── GPUCodeGen (unified GPU backend) │ ├── CUDACodeGen (NVIDIA specifics) │ └── HIPCodeGen (AMD specifics) -├── FPGACodeGen (FPGA base) -│ ├── XilinxCodeGen -│ ├── IntelFPGACodeGen -| └── RTLCodeGen └── MLIRCodeGen ``` diff --git a/doc/design/frontend.md b/doc/design/frontend.md index 4e030f7348..dc331b865e 100644 --- a/doc/design/frontend.md +++ b/doc/design/frontend.md @@ -153,7 +153,7 @@ Existing nodes in `treenodes.py`: - `GBlock` (general control flow block) **Dataflow Scopes:** -- `MapScope`, `ConsumeScope`, `PipelineScope` +- `MapScope`, `ConsumeScope` **Leaf Nodes:** - `TaskletNode`, `LibraryCall`, `CopyNode`, `ViewNode` diff --git a/doc/index.rst b/doc/index.rst index d3e96c2ef7..616dca24e7 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -20,7 +20,7 @@ DaCe generates high-performance programs for: * Multi-core CPUs (tested on Intel, IBM POWER9, and ARM with SVE) * NVIDIA GPUs and AMD GPUs (see :ref:`how to use HIP in DaCe `) - * Xilinx and Intel FPGAs + * `Xilinx and Intel FPGAs `_ diff --git a/doc/optimization/fpga.rst b/doc/optimization/fpga.rst deleted file mode 100644 index ac6b4ad861..0000000000 --- a/doc/optimization/fpga.rst +++ /dev/null @@ -1,201 +0,0 @@ -FPGA Optimization Best Practices -================================ - -.. note:: - - This document is a work in progress. Feel free to make any contributions or suggestions via Pull Requests. - - -This section provides guidance on leveraging DaCe functionalities to optimize DaCe programs targeting FPGAs. -Once the user program is parsed into an SDFG, we can optimize (transform) it to improve performance. In the case of FPGA programs, -the user can apply transformations or follow best practices to reduce data movements, specialize operations implementation, and -increase spatial parallelism. -In the following we start by presenting automatic transformations specifically helpful for FPGA programs. Then we discuss how to specialize -library node implementation. Finally we show how to control various low-level aspects, such as Maps scheduling and Memory Hierarchy. - - -.. _fpga_transformations: - -Transformations for FPGA programs ---------------------------------- - -.. TODO: Structure this slightly differently (don't assume the user knows anything). Show an example of apply_fpga_transformation, -.. and dedicate subsubsections for transformation types (streaming transformations, memory layout transformations) instead of just simple bullet points. - -Existing SDFGs can be transformed from a generic to an FPGA implementation using graph transformations. -The resulting SDFGs can be can be further optimized using general-purpose transformations available in the DaCe toolbox. -This includes platform-agnostic transformations (such as Trivial Map Elimination, Map Collapsing, Map tiling, ...) and more -FPGA-oriented transformations, which we describe here. - -* :py:func:`~dace.transformation.interstate.fpga_transform_sdfg.FPGATransformSDFG`: programmers can automatically offload a full - SDFG using this transformation. This takes care of creating create additional pre- and post-states performing memory transfers - between host and device. The memories accessed by the transformed subgraph are replaced with their FPGA equivalents. -* :py:func:`~dace.transformation.dataflow.streaming_memory.StreamingMemory`: this transformation enables the automatic creation of - streaming memory accessors (see :ref:`fpga_streams`). The transformation analyzes accesses to data containers. If applicable, - it converts an existing memory access to a streaming memory access: the data is read/written to/from a stream in a separate connected - component than the computation. If the `use_memory_buffering` option is set to ``True``, the transformation enables burst reads/write form/to memory, by - using a wider data format (e.g., 512 bits), and then convert it on the fly to the right data type used by the computation. -* :py:func:`~dace.transformation.dataflow.streaming_memory.StreamingComposition`: in unoptimized SDFGs, intermediate data occuring between two consecutive computations - is represented as data access nodes, pointing to off-chip memory by default. This off-chip accesses are undesirable, and in certain conditions can be completely avoided. - This transformation converts two connected computations (nodes, map scopes) into two separate processing elements, with a stream connecting the results. - The transformation performs checks similar to the previous one, and applyes only if the memory access patterns of the two computations match. -* :py:func:`~dace.transformation.auto.fpga.fpga_global_to_local`: changes the storage of containers allocated in global memory to local memory when this is possible. -* :py:func:`~dace.transformation.auto.fpga.fpga_rr_interleave_containers_to_banks`: interleaved global memory containers on the available off-chip memory banks. - Containers are allocated in a Round-Robin fashion. - - -Library Nodes and FPGA specialization -------------------------------------- - -Library nodes are high-level nodes that represent specific functions (e.g., matrix multiplication). During compilation and optimization, -Library Nodes are *expanded* by replacing them with a subgraph, *lowering* them towards a concrete -implementation of their behavior. - -.. TODO: add links to the library node (rather than mention their name). For this, we need to enable their docs - -Available FPGA expansions -^^^^^^^^^^^^^^^^^^^^^^^^^ -DaCe provides FPGA-specific expansions for the principal numerical linear algebra or common operations: - -* vector dot product (``dot``) can be specialized for FPGA using two expansions: ``FPGA_Accumulate`` and ``FPGA_PartialSums``. The former assumes that - native single clock cycle accumulation of the data type is possible on the target architecture (e.g., 32-bit floating - point on Intel Stratix 10). The latter does not assume that native accumulation of the data type is possible. - Both expansions achieve an Initiation Interval of 1. -* matrix-vector multiplication (``gemv``) is available in two versions: - - * ``FPGA_Accumulate``: this FPGA-oriented expansion iterates over the input matrix in simple row-major order, with optional - tiling in both dimensions, where the tiles are also traversed in simple row-major order. - * ``FPGA_TilesByColumn``: this expansion reads the input matrix in column-major order, such that consecutive values are accumulated into different - registers. The matrix can optionally be tiled, where the tiles will be traversed in row-major order. - - These two expansions complement each other as they can be used to favor composability (pipeline-ability) with the rest of the computation. - For example, if another library node produces the input matrix by row, it makes sense to use the first expansion so that the matrix values - can be streamed directly. -* outer product (``ger``) can be expanded for FPGA using the ``FPGA`` expansion. Input vectors can be optionally tiled. -* matrix-matrix multiplication(``gemm``) FPGA specialization is implemented by the ``FPGA1DSystolic`` expansion. This implements the matrix-matrix - multiplication (with accumulation) using a 1D systolic array. The matrices can optionally be tiled along the result columns. - The user can specify the number of used processing elements and tile size according to her needs. -* Reduction library nodes can be inserted by the frontend. They "reduce" an array according to a binary operation (e.g., sum, max), starting - with initial value identity, over the given axis. Reductions can be specialized for FPGAs using the ``FPGAPartialReduction`` expansion. - - -How to specialized library node expansions for FPGA -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Users can target FPGA expansions either through VSCode plugin, or programmatically. -In the VSCode plugin, the user can select for each library node the desired expansion and apply it. - -To do this programmatically, the user has two options: - -* expand specific library nodes. This can be done by choosing the implementation, and manually expand it: - - .. code-block:: python - - # Get the library node that we want to expand, e.g., a gemv node - gemv_node = ... - - # Set the desired expansion, e.g., "FPGA_Accumulate" - gemv_node.implementation = "FPGA_Accumulate" - - # Expand it by passing the SDFG and state that contains it together - # with expansion arguments (if any). - # For example, in this case we specify a tile size of 1024 x 1024 elements - expansion_args = { - "tile_size_x": 1024, - "tile_size_y": 1024 - } - gemv_node.expand(sdfg, state, **expansion_args) - -* set a default expansion for all the library nodes of a given type: - - .. code-block:: python - - # Set a default expansion for all GEMM library node - from dace.libraries.blas import Gemm - Gemm.default_implementation = "FPGA1DSystolic" - - -Vectorization -------------------------------------- -TBD - -Maps and parallelism --------------------- - -In DaCe maps are used to express parallel scopes in SDFGs. -In the context of FPGAs, we distinguish between: - -* *pipelined* maps, where iterations are executed in sequence, exploiting pipeline parallelism in the mapped computation; -* *unrolled* maps, which represent parametrically replicated hardware, such as systolic arrays or SIMD-style vectorization. - -By default, maps are code-generated as pipelined loops. The user can switch to unrolled maps by changing their schedule (either -programmatically or through the VSCode plugin). For pipelined maps, the schedule must be set to :py:data:`~dace.dtypes.ScheduleType.Default`, while -for unrolled maps it must be set to :py:data:`~dace.dtypes.ScheduleType.Unrolled`. - -.. TODO: add a simple illustrative figure (or a snippet of generated code) -- probably it is better to add both of them - -FPGA memory hierarchy ------------------------------ - -Modern FPGAs are characterized by having small, fast on-chip memory and large, but slower, off-chip memory. - -DaCe allows to specify for each FPGA container, where it should be allocated by specifying its :py:data:`~dace.dtypes.StorageType`, either programmatically -or through the VSCode plugin. We can distinguish between: - -* *global* memory (:py:data:`~dace.dtypes.StorageType.FPGA_Global`), which represents data present in off-chip, memory-mapped storage such as DDR or HBM. - Containers in global memory can be created/accessed from both the host and the device side; -* *local* memory (:py:data:`~dace.dtypes.StorageType.FPGA_Local`), representing any on-chip memory implementation such as registers, BRAM/M20K, - LUTRAM, or UltraRAM. Which one will be actually used is left up to the HLS compiler; -* *register* memory (:py:data:`~dace.dtypes.StorageType.FPGA_Register`), which is a subset of local memory, but forces the compiler to implement it - as register (LUT), allowing parallel read/write to the container. This can be useful in the presence of unrolled maps. - - -.. TODO: also introduce Shift Register - -.. _fpga_streams: - -Streams and how to exploit them -------------------------------- -In DaCe, stream containers represent single or multidimensional arrays of First-In-First-Out (FIFO) queues (see :ref:`descriptors`). - -In FPGAs, they are implemented in hardware (FIFOs) either using BRAM or registers. This implies that streams -cannot be unbounded and must be single-producer, single-consumer. - -Streams can be particularly useful in FPGA programs as: - -* they facilitate the division of the program logic in processing elements. The different processing elements can be - simultaneously in execution while communicating using fast on-chip resources, reducing more expensive off-chip memory - accesses; -* they allow memory access extraction, enabling compute and memory accesses to be pipelined and optimized separately. - Creating streaming accessors has many benefits, including using burst mode in memory controllers, tailored buffering, - or broadcasting off-chip memory to multiple processing elements. - - -While these opportunities can be exploited by carefully designing the SDFG, -DaCe also provides transformations to automatically enabling them (see :ref:`fpga_transformations`). - -.. TODO: add sample code - - - - -FPGA kernels and processing elements ------------------------------------- - -.. TODO: this is part of the general info (schedule, storage, dataflow structure) -.. an embedded SDFG example would go a long way - -In DaCe, a state that only accesses containers situated on the FPGA will trigger FPGA code generation. - -In DaCe, we hierarchically organize the code in *FPGA Kernels*, which can be further divided into multiple *Processing elements*. -These concepts will be mapped to different entities depending on the used FPGA backend (see :ref:`Code generating FPGA kernels and processing elements `). - - - -By default, an SDFG state with only FPGA containers is inferred as an FPGA kernel. Then, each of the weakly connected component -found in the state are treated as different Processing Elements, that can be executed in parallel. -The notion of partitioning the functionality of a kernel into multiple independently-scheduled modules is central to designing large FPGA architectures, and can be exploited to write systolic arrays. - -If the :envvar:`compiler.fpga.concurrent_kernel_detection` configuration option is set to ``True``, -a heuristic will further inspect each independent component for other parallelism opportunities (e.g., branches of the SDFG -that can be executed in parallel). If this is the case, multiple, possibly depending, FPGA Kernels are generated for the same state. diff --git a/doc/optimization/optimization.rst b/doc/optimization/optimization.rst index 69a7b382c5..392d1895c4 100644 --- a/doc/optimization/optimization.rst +++ b/doc/optimization/optimization.rst @@ -49,7 +49,6 @@ The following subsections provide more information on the different types of opt blas vscode gpu - fpga .. interactive diff --git a/doc/optimization/profiling.rst b/doc/optimization/profiling.rst index 87539e87a8..93346fc32c 100644 --- a/doc/optimization/profiling.rst +++ b/doc/optimization/profiling.rst @@ -121,7 +121,8 @@ Instrumentation can also collect performance counters on CPUs and GPUs using `LI The :class:`~dace.dtypes.InstrumentationType.LIKWID_Counters` instrumentation type can be configured to collect a wide variety of performance counters on CPUs and GPUs. An example use can be found in the `LIKWID instrumentation code sample `_. - +The :class:`~dace.dtypes.InstrumentationType.GPU_TX_MARKERS` instrumentation type wraps a DaCe program executed on the GPU with NVTX or rocTX markers. Important parts of the execution of the program on the GPU as the different states, SDFGs and initialization and finalization phases are marked with these markers. +These markers can be used to visualize and measure the GPU activity using the NVIDIA Nsight Systems or ROCm Systems profilers. Instrumentation file format ~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/sdfg/auto_optimize.rst b/doc/sdfg/auto_optimize.rst index e8770f5838..e6cda4725b 100644 --- a/doc/sdfg/auto_optimize.rst +++ b/doc/sdfg/auto_optimize.rst @@ -51,7 +51,6 @@ transformations, applied in this order: * Multi-dimensional :class:`~dace.transformation.dataflow.map_collapse.MapCollapse` to parallelize across multiple dimensions. * Greedy subgraph fusion (fusing contents of maps with common dimensions to reduce data movement). See :class:`~dace.transformation.subgraph.subgraph_fusion.SubgraphFusion` for more information. * Move loops into maps (when memory access pattern permits) in order to increase the granularity of work threads perform (:class:`~dace.transformation.interstate.move_loop_into_map.MoveLoopIntoMap`). - * (for FPGAs) Interleave data containers (e.g. arrays) in off-chip memory banks, and use local memory (e.g. BRAM) when possible. * Tiling of maps with write-conflict resolution to reduce atomic operations (tile sizes are configurable via :envvar:`optimizer.autotile_size`). Partial parallelism (non-conflicting dimensions) can also be extracted to convert atomics to simple updates (configurable in :envvar:`optimizer.autotile_partial_parallelism`, True by default). @@ -63,4 +62,4 @@ transformations, applied in this order: * Make transient data containers' allocation lifetime :class:`dace.dtypes.AllocationLifetime.Persistent`, if possible. This moves allocation and deallocation out of the critical code path and into the SDFG init/exit functions. -Apart from those, the pass transforms the SDFG to run on the specified platform (e.g., GPU, FPGA). +Apart from those, the pass transforms the SDFG to run on the specified platform (e.g. GPU). diff --git a/doc/setup/config.rst b/doc/setup/config.rst index 58895e0bc0..922350db69 100644 --- a/doc/setup/config.rst +++ b/doc/setup/config.rst @@ -121,7 +121,3 @@ GPU programming and debugging: ``hip`` for AMD GPUs). * :envvar:`compiler.cuda.syncdebug` (default: False): If True, calls device-synchronization after every GPU kernel and checks for errors. Good for checking crashes or invalid memory accesses. - -FPGA programming: - - * :envvar:`compiler.fpga.vendor`: Can be ``xilinx`` for Xilinx FPGAs, or ``intel_fpga`` for Intel FPGAs. diff --git a/doc/setup/installation.rst b/doc/setup/installation.rst index e8dbc4d570..9f122d2c59 100644 --- a/doc/setup/installation.rst +++ b/doc/setup/installation.rst @@ -14,16 +14,13 @@ Most dependencies will be resolved when the package is installed with ``pip`` or however, it requires two more runtime dependencies to be installed and available in the ``PATH`` environment variable (if not, see :ref:`config` for how to configure different compiler paths): - * A C++14-capable compiler (e.g., gcc 5.3+) + * A C++20-capable compiler (e.g., gcc 10+) * CMake 3.15 or newer. *Note: if CMake cannot be found or is too old, pip will try to install a version but it sometimes fails.* **GPU**: For NVIDIA GPUs, the CUDA toolkit is also required, and AMD GPUs require HIP. :ref:`See more information on how to configure DaCe to use AMD GPUs `. You may (optionally) want to install `CuPy `_ for easy integration of GPU arrays in Python. -**FPGA**: Xilinx FPGAs require the Vitis suite and Intel FPGAs require the Intel FPGA SDK to be installed. -DaCe has been tested with Intel FPGA SDK for OpenCL Pro edition v18.1 and v19.1, targeting Arria 10 and Stratix 10 devices, and Xilinx Vitis HLS v2020.x, v2021.x targeting u250 and u280 devices. - **Distributed Computing**: If using multiple nodes, MPI has to be installed and available. @@ -138,12 +135,6 @@ Common issues with the DaCe Python module * **Bug in DaCe**: If you suspect an issue happens within DaCe, see :ref:`debugging` for ways to pinpoint the source of the issue. - * **Intel FPGA libraries not found**: when targeting Intel FPGAs, the compilation process may fail due to missing OpenCL headers (CMake returns - a ``Could NOT find IntelFPGAOpenCL`` error). This is usually the case when Intel OpenCL compiler does not return the right path to OpenCL host headers. - DaCe relies on ``hlslib`` for compiling FPGA programs, which in turns relies on Intel's compiler to derive the right include path. Please verify that - the include path returned by the Intel compiler (using the ``aocl compile-config`` command) points to a directory that actually contains the OpenCL headers (namely ``cl.hpp`` and - ``cl2.hpp`` files). If this is not the case, please locate them under the Intel Quartus installation folder, and symlink (or copy) them in the ``aocl`` returned path. - .. _qa_vscode: Common issues with the Visual Studio Code extension diff --git a/doc/source/dace.codegen.instrumentation.rst b/doc/source/dace.codegen.instrumentation.rst index d476090d6a..138ca40386 100644 --- a/doc/source/dace.codegen.instrumentation.rst +++ b/doc/source/dace.codegen.instrumentation.rst @@ -4,10 +4,10 @@ dace.codegen.instrumentation package Submodules ---------- -dace.codegen.instrumentation.fpga module ----------------------------------------- +dace.codegen.instrumentation.gpu_tx_markers module +----------------------------------------------- -.. automodule:: dace.codegen.instrumentation.fpga +.. automodule:: dace.codegen.instrumentation.gpu_tx_markers :members: :undoc-members: :show-inheritance: diff --git a/doc/source/dace.codegen.rst b/doc/source/dace.codegen.rst index d3611c697b..17502ec157 100644 --- a/doc/source/dace.codegen.rst +++ b/doc/source/dace.codegen.rst @@ -84,6 +84,13 @@ dace.codegen.prettycode module :undoc-members: :show-inheritance: +dace.codegen.target module +---------------------------------- + +.. automodule:: dace.codegen.target + :members: + :undoc-members: + :show-inheritance: Module contents --------------- diff --git a/doc/source/dace.codegen.targets.rst b/doc/source/dace.codegen.targets.rst index e133bf3fb2..03e5cf276c 100644 --- a/doc/source/dace.codegen.targets.rst +++ b/doc/source/dace.codegen.targets.rst @@ -36,39 +36,6 @@ dace.codegen.targets.mpi module :undoc-members: :show-inheritance: -dace.codegen.targets.target module ----------------------------------- - -.. automodule:: dace.codegen.targets.target - :members: - :undoc-members: - :show-inheritance: - -dace.codegen.targets.fpga module --------------------------------- - -.. automodule:: dace.codegen.targets.fpga - :members: - :undoc-members: - :show-inheritance: - -dace.codegen.targets.xilinx module ----------------------------------- - -.. automodule:: dace.codegen.targets.xilinx - :members: - :undoc-members: - :show-inheritance: - -dace.codegen.targets.intel_fpga module --------------------------------------- - -.. automodule:: dace.codegen.targets.intel_fpga - :members: - :undoc-members: - :show-inheritance: - - Module contents --------------- diff --git a/doc/source/dace.transformation.auto.rst b/doc/source/dace.transformation.auto.rst index 6d603e487a..eb4d807179 100644 --- a/doc/source/dace.transformation.auto.rst +++ b/doc/source/dace.transformation.auto.rst @@ -21,12 +21,3 @@ Module contents :members: :undoc-members: :show-inheritance: - - -dace.transformation.auto.fpga module ------------------------------------- - -.. automodule:: dace.transformation.auto.fpga - :members: - :undoc-members: - :show-inheritance: diff --git a/doc/source/dace.transformation.interstate.rst b/doc/source/dace.transformation.interstate.rst index 01d9230e85..f0529edeef 100644 --- a/doc/source/dace.transformation.interstate.rst +++ b/doc/source/dace.transformation.interstate.rst @@ -4,22 +4,6 @@ dace.transformation.interstate package Submodules ---------- -dace.transformation.interstate.fpga\_transform\_sdfg module ------------------------------------------------------------ - -.. automodule:: dace.transformation.interstate.fpga_transform_sdfg - :members: - :undoc-members: - :show-inheritance: - -dace.transformation.interstate.fpga\_transform\_state module ------------------------------------------------------------- - -.. automodule:: dace.transformation.interstate.fpga_transform_state - :members: - :undoc-members: - :show-inheritance: - dace.transformation.interstate.gpu\_transform\_sdfg module ---------------------------------------------------------- diff --git a/pytest.ini b/pytest.ini index b0aa6e9b8f..3925db3286 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,20 +3,21 @@ markers = gpu: Test requires a compute-capable GPU (select with '-m "gpu"') tensorflow: Test requires TensorFlow installed (select with '-m "tensorflow"') mkl: Test requires Intel MKL (select with '-m "mkl"') - verilator: Test requires Verilator (>=v4.028) installed (select with '-m "verilator"') papi: Test requires PAPI counters to work (select with '-m "papi"') mlir: Test requires pyMLIR, MLIR tools and LLVM compiler to work (select with '-m "mlir"') sve: Test requires SVE-capable ARM processor (select with '-m "sve"') lapack: Test for the LAPACK library that requires OpenBLAS (select with '-m "lapack"') - fpga: Test requires the Xilinx and Intel FPGA tools to be evaluated. (select with '-m "fpga"') - rtl_hardware: Test requires the Xilinx tools to be evaluated. (select with '-m rtl_hardware') mpi: Test requires MPI. (select with '-m mpi') scalapack: Test requires ScaLAPACK (Intel MKL and OpenMPI). (select with '-m scalapack') datainstrument: Test uses data instrumentation (select with '-m datainstrument') hptt: Test requires the HPTT library (select with '-m "hptt') long: Test runs for a long time and is skipped in CI (select with '-m "long"') + torch: Test for the PyTorch/ONNX frontend (select with '-m "torch"') + autodiff: Test for automatic differentiation (select with '-m "autodiff"') + onnx: Test for the ONNX frontend (select with '-m "onnx"') sequential: Test must be run sequentially (select with '-m "sequential"') python_files = + test_*.py *_test.py *_cudatest.py addopts = --ignore=dace/external --color=yes diff --git a/samples/README.md b/samples/README.md index e5d51a696f..e244ebf54d 100644 --- a/samples/README.md +++ b/samples/README.md @@ -7,6 +7,5 @@ There are several sub-folders here: * **explicit**: Examples that use the explicit data-centric API (`dace.map`, `dace.tasklet`), giving fine-grained control over hardware mapping * **sdfg_api**: Programs that use the SDFG API to create the DaCe intermediate representation directly. Useful if you are writing your own frontend. * **optimization**: Examples that use the transformation and instrumentation API to optimize programs to run fast on CPUs and GPUs -* **fpga**: FPGA programs with explicit circuit design patterns (e.g., systolic arrays), mostly using the SDFG API * **distributed**: Python/NumPy and explicit applications that run on multiple machines * **codegen**: Samples showing how to extend the code generator of DaCe to support new platforms (e.g., Tensor Cores) diff --git a/samples/codegen/tensor_cores.py b/samples/codegen/tensor_cores.py index 5fc2afc1d0..a8a55d8388 100644 --- a/samples/codegen/tensor_cores.py +++ b/samples/codegen/tensor_cores.py @@ -13,7 +13,7 @@ # Code generator imports and helpers from dace.codegen.targets.framecode import DaCeCodeGenerator -from dace.codegen.targets.target import TargetCodeGenerator +from dace.codegen.target import TargetCodeGenerator from dace.codegen.targets.cpp import cpp_array_expr, cpp_offset_expr # Frontend imports and helpers @@ -135,9 +135,9 @@ def copy_memory(self, sdfg: dace.SDFG, cfg: ControlFlowRegion, dfg: StateSubgrap # Set non-tensor-core C++ expression based on memlet if edge.data.data == nontc_node.data: - other_expr = cpp_array_expr(sdfg, edge.data) + other_expr = cpp_array_expr(sdfg, edge.data, framecode=self._frame) elif edge.data.other_subset is not None: - offset_cppstr = cpp_offset_expr(nontc_desc, edge.data.other_subset) + offset_cppstr = cpp_offset_expr(nontc_desc, edge.data.other_subset, codegen=self) other_expr = '%s[%s]' % (nontc_node.data, offset_cppstr) else: other_expr = '%s[0]' % nontc_node.data diff --git a/samples/fpga/axpy_transformed.py b/samples/fpga/axpy_transformed.py deleted file mode 100644 index 85e6775a22..0000000000 --- a/samples/fpga/axpy_transformed.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import argparse -import dace -import numpy as np -from dace.transformation.interstate import FPGATransformSDFG - -N = dace.symbol('N') - - -@dace.program(dace.float64, dace.float64[N], dace.float64[N]) -def axpy(A, X, Y): - - @dace.map(_[0:N]) - def multiplication(i): - in_A << A - in_X << X[i] - in_Y << Y[i] - out >> Y[i] - - out = in_A * in_X + in_Y - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("N", type=int, nargs="?", default=24) - args = vars(parser.parse_args()) - - print('Scalar-vector multiplication %d' % (args['N'])) - - A = dace.float64(np.random.rand()) - X = np.random.rand(args['N']) - Y = np.random.rand(args['N']) - expected = A * X + Y - - # Obtain SDFG from @dace.program - sdfg = axpy.to_sdfg() - - # Convert SDFG to FPGA using a transformation - sdfg.apply_transformations(FPGATransformSDFG) - - # Specialize and execute SDFG on FPGA - sdfg._name = 'axpy_fpga_%d' % args['N'] - sdfg.specialize(dict(N=args['N'])) - sdfg(A=A, X=X, Y=Y) - - diff = np.linalg.norm(expected - Y) / args['N'] - print("Difference:", diff) - exit(0 if diff <= 1e-5 else 1) diff --git a/samples/fpga/gemm_systolic_vectorized.py b/samples/fpga/gemm_systolic_vectorized.py deleted file mode 100644 index 744bf543dd..0000000000 --- a/samples/fpga/gemm_systolic_vectorized.py +++ /dev/null @@ -1,734 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" -Computes C = A @ B + C. - -This implementation is based on the HLS implementation from: - https://github.com/spcl/gemm_hls -It uses the compute and I/O optimal strategy described in the FPGA'20 paper: - "Flexible Communication Avoiding Matrix Multiplication on FPGA with - High-Level Synthesis". -""" - -import click -import copy -import dace -import numpy as np -from dace.libraries.standard import Gearbox -from dace.transformation.interstate import InlineSDFG - -MINIMUM_CHANNEL_DEPTH = 8 - -# Symbols used in this implementation: -# -# N: Number of rows of A and C. -# K: Number of columns of A and rows of B. -# M: Number of columns of B and C. -# TN: The tile size in N, which must divide the size in N. -# TM: The tile size in M, which must divide the size in M. -# P: The number of (vertically unrolled) processing elements, and -# consequently one of the two degrees of parallelism in the kernel. Must -# divide the tile size TN. -# W: The vectorization width, being the other degree of parallelism. Must -# divide the tile size TM. - - -def make_copy_to_fpga_state(sdfg, vtype): - """ - Creates the pre-state where the matrices are transferred to the FPGA. - """ - - state = sdfg.add_state("copy_to_device") - dtype = vtype.base_type - # mem_veclen is the vectorization width necessary to create a 512-bit - # interface to memory, and mtype is the corresponding type. - mem_veclen = 64 // dtype.bytes - mtype = dace.vector(dtype, mem_veclen) - - # Host data has plain data types - sdfg.add_array("A", ["N", "K"], dtype=dtype) - sdfg.add_array("B", ["K", "M"], dtype=dtype) - sdfg.add_array("C", ["N", "M"], dtype=dtype) - A_host = state.add_read("A") - B_host = state.add_read("B") - C_host = state.add_read("C") - - # On the device, vector B and C will be vectorized along rows. A is read - # column-wise, so it is not vectorized. - sdfg.add_array("A_device", ["N", f"K//{mem_veclen}"], - dtype=mtype, - transient=True, - location={ - "memorytype": "DDR", - "bank": 1 - }, - storage=dace.dtypes.StorageType.FPGA_Global) - sdfg.add_array("B_device", ["K", f"M//{mem_veclen}"], - dtype=mtype, - transient=True, - location={ - "memorytype": "DDR", - "bank": 1 - }, - storage=dace.dtypes.StorageType.FPGA_Global) - sdfg.add_array("C_device", ["N", f"M//{mem_veclen}"], - dtype=mtype, - transient=True, - location={ - "memorytype": "DDR", - "bank": 1 - }, - storage=dace.dtypes.StorageType.FPGA_Global) - A_device = state.add_write("A_device") - B_device = state.add_write("B_device") - C_device = state.add_write("C_device") - - state.add_memlet_path(A_host, A_device, memlet=dace.Memlet(f"A_device[0:N, 0:K//{mem_veclen}]")) - state.add_memlet_path(B_host, B_device, memlet=dace.Memlet(f"B_device[0:K, 0:M//{mem_veclen}]")) - state.add_memlet_path(C_host, C_device, memlet=dace.Memlet(f"C_device[0:N, 0:M//{mem_veclen}]")) - - return state - - -def make_copy_to_host_state(sdfg, vtype): - """ - Creates the post-state where C is copied back to the host. - """ - - state = sdfg.add_state("copy_to_host") - - C_device = state.add_read("C_device") - C_host = state.add_write("C") - - state.add_memlet_path(C_device, C_host, memlet=dace.Memlet("C[0:N, 0:M]")) - - return state - - -def make_read_A(sdfg, state, vtype): - """ - Creates the memory read from A, which performs in-memory transposition by - reading 512-bit wide vectors of A, then piping them into separate streams - that are popped one at a time and sent to the kernel. - """ - - # Deduce types - dtype = vtype.base_type - mem_veclen = 64 // dtype.bytes - - # Unpack vector into a register - sdfg.add_array("transpose_reg", (mem_veclen, ), dtype, storage=dace.StorageType.FPGA_Local, transient=True) - - # Add a stream for each element in the vector - sdfg.add_stream( - "transpose", - dtype, - # Allow loading the next column while the previous is being - # used - buffer_size="2 * TN", - shape=(mem_veclen, ), - storage=dace.StorageType.FPGA_Local, - transient=True) - - # Read each element into a buffer to unpack the vector into individual - # elements - mem = state.add_read("A_device") - entry, exit = state.add_map("read_A", { - "n0": "0:N//TN", - "m": "0:M//TM", - "k0": f"0:K//{mem_veclen}", - "n1": "0:TN", - }, - schedule=dace.ScheduleType.FPGA_Device) - buffer_access = state.add_access("transpose_reg") - state.add_memlet_path(mem, entry, buffer_access, memlet=dace.Memlet("A_device[n0 * TN + n1, k0]")) - - # Now stick each element into a separate stream - unroll_entry, unroll_exit = state.add_map("unpack_A", {"k1": f"0:{mem_veclen}"}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - unroll_tasklet = state.add_tasklet("unpack_A", {"from_memory"}, {"to_pipe"}, "to_pipe = from_memory") - unroll_write = state.add_write("transpose") - state.add_memlet_path(buffer_access, - unroll_entry, - unroll_tasklet, - dst_conn="from_memory", - memlet=dace.Memlet(f"transpose_reg[k1]")) - state.add_memlet_path(unroll_tasklet, - unroll_exit, - exit, - unroll_write, - src_conn="to_pipe", - memlet=dace.Memlet(f"transpose[k1]")) - - # A separate processing element will pop from the streams one at a time - transpose_read = state.add_read("transpose") - transpose_entry, transpose_exit = state.add_map("transpose_A", { - "n0": "0:N//TN", - "m": "0:M//TM", - "k0": f"0:K//{mem_veclen}", - "k1": f"0:{mem_veclen}", - "n1": "0:TN", - }, - schedule=dace.ScheduleType.FPGA_Device) - pipe_out = state.add_write("A_pipe") - tasklet = state.add_tasklet("transpose_A", {"from_transpose"}, {"to_kernel"}, "to_kernel = from_transpose") - state.add_memlet_path(transpose_read, - transpose_entry, - tasklet, - dst_conn="from_transpose", - memlet=dace.Memlet(f"transpose[k1]")) - state.add_memlet_path(tasklet, transpose_exit, pipe_out, src_conn="to_kernel", memlet=dace.Memlet("A_pipe[0]")) - - -def make_read_B(sdfg, state, vtype): - - # Deduce types - dtype = vtype.base_type - mem_veclen = 64 // dtype.bytes - mtype = dace.vector(dtype, mem_veclen) - - entry, exit = state.add_map("read_B", { - "n0": "0:N//TN", - "m0": "0:M//TM", - "k": "0:K", - "m1": f"0:TM//{mem_veclen}" - }, - schedule=dace.ScheduleType.FPGA_Device) - - mem = state.add_read("B_device") - to_feeder = state.add_write("B_to_feeder") - tasklet = state.add_tasklet("read_B", {"from_memory"}, {"to_feeder"}, "to_feeder = from_memory") - state.add_memlet_path(mem, - entry, - tasklet, - dst_conn="from_memory", - memlet=dace.Memlet(f"B_device[k, m0 * (TM//{mem_veclen}) + m1]")) - - if mem_veclen > vtype.veclen: - - # Data arrives as 512-bit wide vectors, and will be converted to the - # vector length of the kernel - - sdfg.add_stream("B_to_converter", - dtype=mtype, - buffer_size=MINIMUM_CHANNEL_DEPTH, - storage=dace.StorageType.FPGA_Local, - transient=True) - to_converter_write = state.add_write("B_to_converter") - state.add_memlet_path(tasklet, - exit, - to_converter_write, - src_conn="to_feeder", - memlet=dace.Memlet("B_to_converter[0]")) - - # Convert 512-bit vectors to whatever width the kernel uses - to_converter_read = state.add_read("B_to_converter") - gearbox = Gearbox(f"(N//TN) * (M//TM) * K * (TM//{mem_veclen})", "convert_B", dace.ScheduleType.FPGA_Device) - state.add_memlet_path(to_converter_read, - gearbox, - dst_conn="from_memory", - memlet=dace.Memlet(f"B_to_converter[0]", dynamic=True)) - state.add_memlet_path(gearbox, - to_feeder, - src_conn="to_feeder", - memlet=dace.Memlet("B_to_feeder[0]", dynamic=True)) - - else: - - # If the kernel uses the full memory width, just send the data directly - # without any conversion - state.add_memlet_path(tasklet, exit, to_feeder, src_conn="to_feeder", memlet=dace.Memlet(f"B_to_feeder[0]")) - - -def make_feed_B(sdfg, state, vtype): - """ - This module will buffer the values read from the B matrix, sending them - multiple times to the kernel for each row of A in the current outer product. - """ - - entry, exit = state.add_map("feed_B", { - "n0": "0:N//TN", - "m0": "0:M//TM", - "k": "0:K", - "n1": "0:TN//P", - "m1": "0:TM//W" - }, - schedule=dace.ScheduleType.FPGA_Device) - - sdfg.add_array("feed_B_buffer", ("TM//W", ), vtype, storage=dace.StorageType.FPGA_Local, transient=True) - buffer_read = state.add_read("feed_B_buffer") - buffer_write = state.add_write("feed_B_buffer") - - read = state.add_read("B_to_feeder") - write = state.add_write("B_pipe") - tasklet = state.add_tasklet( - "feed_B", {"from_memory", "buffer_in"}, {"to_kernel", "buffer_out"}, """ -val = buffer_in -if n1 == 0: - val = from_memory -to_kernel = val -buffer_out = val""") - - state.add_memlet_path(read, - entry, - tasklet, - dst_conn="from_memory", - memlet=dace.Memlet("B_to_feeder[0]", dynamic=True)) - state.add_memlet_path(buffer_read, entry, tasklet, dst_conn="buffer_in", memlet=dace.Memlet("feed_B_buffer[m1]")) - state.add_memlet_path(tasklet, exit, buffer_write, src_conn="buffer_out", memlet=dace.Memlet("feed_B_buffer[m1]")) - state.add_memlet_path(tasklet, exit, write, src_conn="to_kernel", memlet=dace.Memlet("B_pipe[0]")) - - -def make_write_C(sdfg, state, vtype): - - # Deduce types - dtype = vtype.base_type - mem_veclen = 64 // dtype.bytes - mtype = dace.vector(dtype, mem_veclen) - - from_kernel = state.add_read("C_pipe") - mem_read = state.add_read("C_device") - mem_write = state.add_write("C_device") - - if mem_veclen > vtype.veclen: - - # We need to convert from the kernel vectorization length to 512-bit - # vectors that are written back to memory - - gearbox = Gearbox(f"(N//TN) * (M//TM) * TN * (TM//{mem_veclen})", - "convert_C", - schedule=dace.ScheduleType.FPGA_Device) - sdfg.add_stream("C_from_converter", - mtype, - buffer_size=f"TM//{mem_veclen}", - storage=dace.StorageType.FPGA_Local, - transient=True) - converter_write = state.add_write("C_from_converter") - state.add_memlet_path(from_kernel, - gearbox, - dst_conn="from_kernel", - memlet=dace.Memlet(f"C_pipe[0]", dynamic=True)) - state.add_memlet_path(gearbox, - converter_write, - src_conn="to_memory", - memlet=dace.Memlet("C_from_converter[0]", dynamic=True)) - - to_writer = state.add_read("C_from_converter") - to_writer_subset = "C_from_converter[0]" - - else: - - # Just send the data directly to the reader - to_writer = from_kernel - to_writer_subset = "C_pipe[0]" - - entry, exit = state.add_map("write_C", { - "n0": "0:N//TN", - "m0": "0:M//TM", - "n1": "0:TN", - "m1": f"0:TM//{mem_veclen}" - }, - schedule=dace.ScheduleType.FPGA_Device) - - tasklet = state.add_tasklet("write_C", {"from_kernel", "prev"}, {"to_memory"}, "to_memory = from_kernel + prev") - state.add_memlet_path(to_writer, entry, tasklet, dst_conn="from_kernel", memlet=dace.Memlet(to_writer_subset)) - - state.add_memlet_path(mem_read, - entry, - tasklet, - dst_conn="prev", - memlet=dace.Memlet(f"C_device[n0 * TN + n1, m0 * (TM//{mem_veclen}) + m1]")) - - state.add_memlet_path(tasklet, - exit, - mem_write, - src_conn="to_memory", - memlet=dace.Memlet(f"C_device[n0 * TN + n1, m0 * (TM//{mem_veclen}) + m1]")) - - -def make_compute(sdfg, state, vtype): - - dtype = vtype.base_type - - # Pipes connecting the systolic array - A_pipe_in = state.add_read("A_pipe") - A_pipe_out = state.add_write("A_pipe") - B_pipe_in = state.add_read("B_pipe") - B_pipe_out = state.add_write("B_pipe") - C_pipe_in = state.add_read("C_pipe") - C_pipe_out = state.add_write("C_pipe") - - # Instantiate the buffer for A, and initialize it - sdfg.add_array("A_buffer", ("2 * (TN//P)", ), - dtype=dtype, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Registers) - init_A = state.add_access("A_buffer") - init_entry, init_exit = state.add_map("init", { - "n0": "0:TN//P", - "n1": "0:P-p" - }, - schedule=dace.ScheduleType.FPGA_Device) - init_tasklet = state.add_tasklet( - "init_A", {"from_prev"}, {"to_buffer", "to_next"}, """\ -val = from_prev -if n1 == 0: - to_buffer = val -elif p < P - 1: - to_next = val""") - state.add_memlet_path(A_pipe_in, init_entry, init_tasklet, dst_conn="from_prev", memlet=dace.Memlet("A_pipe[p]")) - state.add_memlet_path(init_tasklet, - init_exit, - init_A, - src_conn="to_buffer", - memlet=dace.Memlet(f"A_buffer[n0]", dynamic=True)) - state.add_memlet_path(init_tasklet, - init_exit, - A_pipe_out, - src_conn="to_next", - memlet=dace.Memlet(f"A_pipe[p + 1]", dynamic=True)) - - # Now instantiate the body of the computation - outer_entry, outer_exit = state.add_map("tiles", { - "n0": "0:N//TN", - "m0": "0:M//TM" - }, - schedule=dace.ScheduleType.FPGA_Device) - - # Make a dummy edge to bring the initialization buffer into scope - state.add_memlet_path(init_A, outer_entry, memlet=dace.Memlet()) - - # Loop over the reduced dimension - k_entry, k_exit = state.add_map("k", {"k": "0:K"}, schedule=dace.ScheduleType.FPGA_Device) - - # Loops over the tile content - inner_entry, inner_exit = state.add_map("inner", { - "n1": "0:TN//P", - "m1": "0:TM//W" - }, - schedule=dace.ScheduleType.FPGA_Device) - - # Double buffering scheme of A - update_A = state.add_access("A_buffer") - buffer_tasklet = state.add_tasklet( - "double_buffer_A", {"from_prev"}, {"to_buffer", "to_next"}, """\ -if (n0 < (N/TN) - 1 or m0 < (M/TM) - 1 or k < K - 1) and m1 >= p and m1 < P: - val = from_prev - if m1 == p: - to_buffer = val - elif p < P - 1: - to_next = val""") - state.add_memlet_path(A_pipe_in, - outer_entry, - k_entry, - inner_entry, - buffer_tasklet, - dst_conn="from_prev", - memlet=dace.Memlet(f"A_pipe[p]", dynamic=True)) - state.add_memlet_path(buffer_tasklet, - update_A, - src_conn="to_buffer", - memlet=dace.Memlet(f"A_buffer[n1 + (1 - (k % 2)) * (TN//P)]", dynamic=True)) - state.add_memlet_path(buffer_tasklet, - inner_exit, - k_exit, - outer_exit, - A_pipe_out, - src_conn="to_next", - memlet=dace.Memlet(f"A_pipe[p + 1]", dynamic=True)) - - # Instantiate the "big" buffer of the output, where most of our fast memory - # will be spent - sdfg.add_array("C_buffer", ("TN/P", "TM/W"), vtype, storage=dace.StorageType.FPGA_Local, transient=True) - - # Now the tasklet performing the actual computation - compute_tasklet = state.add_tasklet( - "multiply_add", {"a_in", "b_in", "c_in"}, {"b_out", "c_out"}, """\ -if p < P - 1: - b_out = b_in -c_val = c_in -if k == 0: - c_val = 0 -c_out = c_val + a_in * b_in""") - C_buffer_read = state.add_read("C_buffer") - C_buffer_write = state.add_access("C_buffer") - state.add_memlet_path(update_A, - compute_tasklet, - dst_conn="a_in", - memlet=dace.Memlet(f"A_buffer[n1 + (k % 2) * (TN//P)]")) - state.add_memlet_path(B_pipe_in, - outer_entry, - k_entry, - inner_entry, - compute_tasklet, - dst_conn="b_in", - memlet=dace.Memlet("B_pipe[p]")) - state.add_memlet_path(C_buffer_read, - outer_entry, - k_entry, - inner_entry, - compute_tasklet, - dst_conn="c_in", - memlet=dace.Memlet("C_buffer[n1, m1]")) - state.add_memlet_path(compute_tasklet, - inner_exit, - k_exit, - outer_exit, - B_pipe_out, - src_conn="b_out", - memlet=dace.Memlet("B_pipe[p + 1]", dynamic=True)) - state.add_memlet_path(compute_tasklet, - inner_exit, - k_exit, - C_buffer_write, - src_conn="c_out", - memlet=dace.Memlet("C_buffer[n1, m1]")) - - # Now we need to write C out after each tile has been processed - write_entry, write_exit = state.add_map("write_C", {"n1": "0:TN//P"}, schedule=dace.ScheduleType.FPGA_Device) - - # We need to enforce sequentiality between these loops - write_sdfg = dace.SDFG("write_C") - write_sdfg_node = state.add_nested_sdfg(write_sdfg, {"buffer_in", "forward_in"}, {"forward_out"}) - state.add_memlet_path(C_buffer_write, - write_entry, - write_sdfg_node, - dst_conn="buffer_in", - memlet=dace.Memlet("C_buffer[n1, 0:TM/W]")) - state.add_memlet_path(C_pipe_in, - outer_entry, - write_entry, - write_sdfg_node, - dst_conn="forward_in", - memlet=dace.Memlet("C_pipe[p + 1]", dynamic=True)) - state.add_memlet_path(write_sdfg_node, - write_exit, - outer_exit, - C_pipe_out, - src_conn="forward_out", - memlet=dace.Memlet("C_pipe[p]", dynamic=True)) - write_sdfg.add_stream("forward_in", - vtype, - buffer_size=MINIMUM_CHANNEL_DEPTH, - storage=dace.StorageType.FPGA_Local, - transient=False) - write_sdfg.add_stream("forward_out", - vtype, - buffer_size=MINIMUM_CHANNEL_DEPTH, - storage=dace.StorageType.FPGA_Local, - transient=False) - write_sdfg.add_array("buffer_in", ("TM//W", ), vtype, transient=False, storage=dace.StorageType.FPGA_Local) - # Send results from this PE - send_state = write_sdfg.add_state("send_C") - send_read = send_state.add_read("buffer_in") - send_write = send_state.add_write("forward_out") - send_tasklet = send_state.add_tasklet("send_C", {"from_buffer"}, {"to_next"}, "to_next = from_buffer") - send_entry, send_exit = send_state.add_map("send_C", {"m1": "0:TM//W"}, schedule=dace.ScheduleType.FPGA_Device) - send_state.add_memlet_path(send_read, - send_entry, - send_tasklet, - dst_conn="from_buffer", - memlet=dace.Memlet("buffer_in[m1]")) - send_state.add_memlet_path(send_tasklet, - send_exit, - send_write, - src_conn="to_next", - memlet=dace.Memlet("forward_out[0]")) - # And finally forward results from earlier PEs - forward_state = write_sdfg.add_state("forward_C") - forward_read = forward_state.add_read("forward_in") - forward_write = forward_state.add_read("forward_out") - forward_tasklet = forward_state.add_tasklet("forward_C", {"from_prev"}, {"to_next"}, """\ -if p < P - 1: - to_next = from_prev""") - forward_entry, forward_exit = forward_state.add_map("forward_C", { - "n1": "0:P - p - 1", - "m1": "0:TM//W" - }, - schedule=dace.ScheduleType.FPGA_Device) - # These must be dynamic so the compiler can optimize out the write from the - # last processing element - forward_state.add_memlet_path(forward_read, - forward_entry, - forward_tasklet, - dst_conn="from_prev", - memlet=dace.Memlet("forward_in[0]", dynamic=True)) - forward_state.add_memlet_path(forward_tasklet, - forward_exit, - forward_write, - src_conn="to_next", - memlet=dace.Memlet("forward_out[0]", dynamic=True)) - # Enforce sending own data before forwarding - write_sdfg.add_edge(send_state, forward_state, dace.InterstateEdge()) - - # Unroll processing elements - unroll_entry, unroll_exit = state.add_map("unroll_processing_elements", {"p": "0:P"}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - - # Bring data nodes into scope - state.add_memlet_path(unroll_entry, A_pipe_in, memlet=dace.memlet.Memlet()) - state.add_memlet_path(unroll_entry, B_pipe_in, memlet=dace.memlet.Memlet()) - state.add_memlet_path(unroll_entry, C_pipe_in, memlet=dace.memlet.Memlet()) - state.add_memlet_path(unroll_entry, C_buffer_read, memlet=dace.memlet.Memlet()) - state.add_memlet_path(A_pipe_out, unroll_exit, memlet=dace.memlet.Memlet()) - state.add_memlet_path(B_pipe_out, unroll_exit, memlet=dace.memlet.Memlet()) - state.add_memlet_path(C_pipe_out, unroll_exit, memlet=dace.memlet.Memlet()) - - # Propagate symbols - write_sdfg.symbols = copy.deepcopy(sdfg.symbols) - write_sdfg.add_symbol("p", sdfg.symbols["P"]) - write_sdfg_node.symbol_mapping = {k: k for k in sdfg.free_symbols} - write_sdfg_node.symbol_mapping["p"] = "p" - - -def make_fpga_state(sdfg, vtype): - - state = sdfg.add_state("gemm") - - sdfg.add_stream("A_pipe", - vtype.base_type, - transient=True, - shape=("P + 1", ), - storage=dace.dtypes.StorageType.FPGA_Local, - buffer_size="P") - sdfg.add_stream("B_pipe", - vtype, - transient=True, - shape=("P + 1", ), - buffer_size=MINIMUM_CHANNEL_DEPTH, - storage=dace.dtypes.StorageType.FPGA_Local) - sdfg.add_stream("B_to_feeder", - vtype, - transient=True, - buffer_size=MINIMUM_CHANNEL_DEPTH, - storage=dace.StorageType.FPGA_Local) - sdfg.add_stream("C_pipe", - vtype, - transient=True, - shape=("P + 1", ), - buffer_size=MINIMUM_CHANNEL_DEPTH, - storage=dace.dtypes.StorageType.FPGA_Local) - - make_read_A(sdfg, state, vtype) - make_read_B(sdfg, state, vtype) - make_feed_B(sdfg, state, vtype) - make_compute(sdfg, state, vtype) - make_write_C(sdfg, state, vtype) - - return state - - -def make_sdfg(name, vtype): - - sdfg = dace.SDFG(name) - - pre_state = make_copy_to_fpga_state(sdfg, vtype) - compute_state = make_fpga_state(sdfg, vtype) - post_state = make_copy_to_host_state(sdfg, vtype) - - sdfg.add_edge(pre_state, compute_state, dace.sdfg.InterstateEdge()) - sdfg.add_edge(compute_state, post_state, dace.sdfg.InterstateEdge()) - - if vtype.bytes < 64: - sdfg.expand_library_nodes() - assert sdfg.apply_transformations_repeated(InlineSDFG) == 2 - - return sdfg - - -@click.command() -@click.argument("N", type=int) -@click.argument("K", type=int) -@click.argument("M", type=int) -@click.argument("num-pes", type=int) -@click.argument("vector-width", type=int) -@click.option("--dtype", default="float32") -@click.option("--tile-size-n", - type=int, - default=None, - help=("Must be a multiple of the number of processing elements, " - "and must divide the size in N.")) -@click.option("--tile-size-m", - type=int, - default=None, - help=("Must be a multiple of the vector size, and must divide" - " the size in M.")) -@click.option("--specialize/--no-specialize", default=False, help="Fix matrix sizes at compile time.") -def cli(n, k, m, num_pes, dtype, tile_size_n, tile_size_m, vector_width, specialize): - - # Some reasonable default values for tile sizes - if not tile_size_n: - tile_size_n = n // num_pes - if not tile_size_m: - tile_size_m = min(m, 1024) - - # Rename - P = num_pes - W = vector_width - TN = tile_size_n - TM = tile_size_m - - dtype = getattr(dace.dtypes, dtype) # Convert from string to typeclass - vtype = dace.vector(dtype, vector_width) - - if TN % P != 0: - raise ValueError(f"Tile size in N {TN} must be divisible by the number of processing elements {P}.") - if TM % W != 0: - raise ValueError(f"Tile size in M {TM} must be divisible by the vectorization width {W}.") - if n % TN != 0: - raise ValueError(f"Size in N {n} must be divisible by the tile size in N {TN}.") - if n % TM != 0: - raise ValueError(f"Size in M {m} must be divisible by the tile size in M {TM}.") - if (dtype.bytes * TM) % 64 != 0: - raise ValueError(f"Tile size in M {TM} must be a multiple of 64 bytes.") - if (dtype.bytes * k) % 64 != 0: - raise ValueError(f"Size in K {K} must be a multiple of 64 bytes.") - - dtype = dtype.type # Convert from typeclass to NumPy type - - if specialize: - name = (f"gemm_fpga_systolic_vectorized_d{num_pes}_" - f"w{vector_width}_{tile_size_n}x{tile_size_m}_{n}x{k}x{m}") - else: - name = (f"gemm_fpga_systolic_vectorized_d{num_pes}_" - f"w{vector_width}_{tile_size_n}x{tile_size_m}_NxKxM") - - sdfg = make_sdfg(name, vtype) - - # Specialize compile time constants - sdfg.specialize({"P": P, "W": W, "TN": TN, "TM": TM}) - if specialize: - sdfg.specialize({"N": n, "K": k, "M": m}) - - print(f"Matrix multiplication {n}x{k}x{m} " - f"with {num_pes} PEs " - f"and vectorization width {vector_width}, " - f"and tile sizes {TN}x{TM}.") - - # Initialize arrays: Randomize A and B, zero C - A = np.ndarray([n, k], dtype=dtype) - B = np.ndarray([k, m], dtype=dtype) - C = np.ndarray([n, m], dtype=dtype) - A[:] = np.random.rand(n, k).astype(dtype) - B[:] = np.random.rand(k, m).astype(dtype) - C[:] = np.random.rand(n, m).astype(dtype) - - # Compute reference result - C_reference = A @ B + C - - # Run DaCe program - if specialize: - sdfg(A=A, B=B, C=C) - else: - sdfg(A=A, B=B, C=C, N=n, K=k, M=m) - - # Verify results - if not np.allclose(C, C_reference): - raise ValueError("Verification failed.") - else: - print("Results successfully verified.") - - -if __name__ == "__main__": - cli() diff --git a/samples/fpga/gemv_fpga.py b/samples/fpga/gemv_fpga.py deleted file mode 100644 index 770fef1e30..0000000000 --- a/samples/fpga/gemv_fpga.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import argparse -import dace -import numpy as np -import select -import sys - -N = dace.symbol("N") -M = dace.symbol("M") -dtype = dace.float64 - -# This implementation of transposed DGEMV assumes that the two vectors (x and y) -# fit into FPGA on-chip memory - - -def make_init_state(sdfg): - - state = sdfg.add_state("init") - - a_host = state.add_array("A", (M, N), dtype) - a_device = state.add_array("A_device", (M, N), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Global) - x_host = state.add_array("x", (M, ), dtype) - x_device = state.add_array("x_device", (M, ), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Global) - y_host = state.add_array("y", (M, ), dtype) - y_device = state.add_array("y_device", (N, ), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Global) - - state.add_memlet_path(a_host, a_device, memlet=dace.memlet.Memlet.simple(a_device, "0:N, 0:M")) - - state.add_memlet_path(x_host, x_device, memlet=dace.memlet.Memlet.simple(x_device, "0:M")) - - state.add_memlet_path(y_host, y_device, memlet=dace.memlet.Memlet.simple(y_device, "0:N")) - - return state - - -def make_finalize_state(sdfg): - - state = sdfg.add_state("finalize") - - y_host = state.add_array("y", (M, ), dtype) - y_device = state.add_array("y_device", (N, ), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Global) - - state.add_memlet_path(y_device, y_host, memlet=dace.memlet.Memlet.simple(y_host, "0:N")) - - return state - - -def make_load_state(sdfg): - - state = sdfg.add_state("load") - - y = state.add_array("y_nested", (N, ), dtype, storage=dace.dtypes.StorageType.FPGA_Global) - y_buffer = state.add_array("y_buffer", (N, ), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Local) - - state.add_memlet_path(y, y_buffer, memlet=dace.memlet.Memlet.simple(y_buffer, "0:N")) - - return state - - -def make_store_state(sdfg): - - state = sdfg.add_state("store") - - y_buffer = state.add_array("y_buffer", (N, ), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Local) - y = state.add_array("y_nested", (N, ), dtype, storage=dace.dtypes.StorageType.FPGA_Global) - - state.add_memlet_path(y_buffer, y, memlet=dace.memlet.Memlet.simple(y, "0:N")) - - return state - - -def make_compute_state(sdfg): - - state = sdfg.add_state("compute") - - a = state.add_array("A_nested", (M, N), dtype, storage=dace.dtypes.StorageType.FPGA_Global) - x = state.add_array("x_nested", (M, ), dtype, storage=dace.dtypes.StorageType.FPGA_Global) - y_buffer = state.add_array("y_buffer", (N, ), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Local) - - cols_entry, cols_exit = state.add_map("cols", {"m": "0:M"}, schedule=dace.ScheduleType.FPGA_Device) - rows_entry, rows_exit = state.add_map("rows", {"n": "0:N"}, schedule=dace.ScheduleType.FPGA_Device) - - tasklet = state.add_tasklet("update", {"a", "x_in"}, {"update"}, "update = a * x_in") - - wcr_memlet = dace.memlet.Memlet.simple(y_buffer, "n", wcr_str="lambda a, b: a + b") - - state.add_memlet_path(a, cols_entry, rows_entry, tasklet, dst_conn="a", memlet=dace.memlet.Memlet.simple(a, "m, n")) - state.add_memlet_path(x, cols_entry, rows_entry, tasklet, dst_conn="x_in", memlet=dace.memlet.Memlet.simple(x, "m")) - state.add_memlet_path(tasklet, rows_exit, cols_exit, y_buffer, src_conn="update", memlet=wcr_memlet) - - return state - - -def make_outer_compute_state(sdfg): - - state = sdfg.add_state("gemv_transposed") - - nested_sdfg = dace.SDFG("gemv_transposed") - load_state = make_load_state(nested_sdfg) - compute_state = make_compute_state(nested_sdfg) - store_state = make_store_state(nested_sdfg) - nested_sdfg.add_edge(load_state, compute_state, dace.sdfg.InterstateEdge()) - nested_sdfg.add_edge(compute_state, store_state, dace.sdfg.InterstateEdge()) - - tasklet = state.add_nested_sdfg(nested_sdfg, {"A_nested", "x_nested", "y_nested"}, {"y_nested"}) - - a_device = state.add_array("A_device", (M, N), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Global) - x_device = state.add_array("x_device", (M, ), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Global) - y_device_r = state.add_array("y_device", (N, ), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Global) - y_device_w = state.add_array("y_device", (N, ), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Global) - - state.add_memlet_path(a_device, - tasklet, - dst_conn="A_nested", - memlet=dace.memlet.Memlet.simple(a_device, "0:M, 0:N")) - state.add_memlet_path(x_device, tasklet, dst_conn="x_nested", memlet=dace.memlet.Memlet.simple(x_device, "0:M")) - state.add_memlet_path(y_device_r, tasklet, dst_conn="y_nested", memlet=dace.memlet.Memlet.simple(y_device_r, "0:N")) - state.add_memlet_path(tasklet, y_device_w, src_conn="y_nested", memlet=dace.memlet.Memlet.simple(y_device_w, "0:N")) - - return state - - -def make_sdfg(specialize, N, M): - - if specialize: - name = "gemv_transposed_{}x{}".format(N, M) - else: - name = "gemv_transposed_{}xM".format(N) - - sdfg = dace.SDFG(name) - - init_state = make_init_state(sdfg) - fpga_state = make_outer_compute_state(sdfg) - finalize_state = make_finalize_state(sdfg) - - sdfg.add_edge(init_state, fpga_state, dace.sdfg.InterstateEdge()) - sdfg.add_edge(fpga_state, finalize_state, dace.sdfg.InterstateEdge()) - - return sdfg - - -def run_gemv(n: int, m: int, specialize: bool): - - print("==== Program start ====") - - if specialize: - print("Specializing M...") - - gemv = make_sdfg(specialize, n, m) - gemv.specialize(dict(N=n)) - - if specialize: - gemv.specialize(dict(M=m)) - - print("Running GEMV {}x{} ({}specialized)".format(n, m, ("" if specialize else "not "))) - - A = dace.ndarray([m, n], dtype=dtype) - x = dace.ndarray([m], dtype=dtype) - y = dace.ndarray([n], dtype=dtype) - - # Intialize: randomize A, x and y - # A[:, :] = np.random.rand(M, N).astype(dtype.type) - # x[:] = np.random.rand(M).astype(dtype.type) - # y[:] = np.random.rand(N).astype(dtype.type) - A[:, :] = 1 - x[:] = 1 - y[:] = 0 - - # Regression - regression = np.matmul(np.transpose(A), x) + y - - ############################################# - # Run DaCe program - - if specialize: - gemv(A=A, x=x, y=x) - else: - gemv(A=A, M=m, x=x, y=y) - - residual = np.linalg.norm(y - regression) / (n * m) - print("Residual:", residual) - diff = np.abs(y - regression) - wrong_elements = np.transpose(np.nonzero(diff >= 0.01)) - highest_diff = np.max(diff) - - print("==== Program end ====") - if residual >= 0.01 or highest_diff >= 0.01: - print("Verification failed!") - print("Residual: {}".format(residual)) - print("Incorrect elements: {} / {}".format(wrong_elements.shape[0], (n * m))) - print("Highest difference: {}".format(highest_diff)) - print("** Result:\n", y) - print("** Reference:\n", regression) - raise RuntimeError("Validation failed/") - - return gemv - - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - parser.add_argument("n", type=int) - parser.add_argument("m", type=int) - parser.add_argument("-specialize", default=False, action="store_true", help="Also fix M in hardware") - args = parser.parse_args() - - run_gemv(args.n, args.m, args.specialize) diff --git a/samples/fpga/histogram_fpga.py b/samples/fpga/histogram_fpga.py deleted file mode 100644 index f1005d8b11..0000000000 --- a/samples/fpga/histogram_fpga.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from __future__ import print_function - -import argparse -import dace -import math -import numpy as np - -W = dace.symbol("W") -H = dace.symbol("H") -num_bins = dace.symbol("num_bins") -dtype = dace.float32 - - -def make_copy_to_fpga_state(sdfg): - - state = sdfg.add_state("copy_to_fpga") - - a_host = state.add_array("A", (H, W), dtype) - a_device = state.add_array("A_device", (H, W), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Global) - hist_host = state.add_array("hist", (num_bins, ), dace.uint32) - hist_device = state.add_array("hist_device", (num_bins, ), - dace.uint32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - - state.add_memlet_path(a_host, a_device, memlet=dace.memlet.Memlet.simple(a_device, "0:H, 0:W")) - state.add_memlet_path(hist_host, hist_device, memlet=dace.memlet.Memlet.simple(hist_device, "0:num_bins")) - - return state - - -def make_copy_to_host_state(sdfg): - - state = sdfg.add_state("copy_to_host") - - hist_device = state.add_array("hist_device", (num_bins, ), - dace.uint32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - hist_host = state.add_array("hist", (num_bins, ), dace.uint32) - - state.add_memlet_path(hist_device, hist_host, memlet=dace.memlet.Memlet.simple(hist_host, "0:num_bins")) - - return state - - -def make_compute_state(sdfg): - - state = sdfg.add_state("histogram_fpga") - - a = state.add_array("A_in", (H, W), dtype, storage=dace.dtypes.StorageType.FPGA_Global) - hist = state.add_array("hist_buffer", (num_bins, ), - dace.uint32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Local) - - entry, exit = state.add_map("map", {"i": "0:H", "j": "0:W"}) - - tasklet = state.add_tasklet("compute", {"a"}, {"out"}, "out[int(float(num_bins) * a)] = 1") - - read_memlet = dace.memlet.Memlet.simple(a, "i, j") - write_memlet = dace.memlet.Memlet.simple(hist, "0:num_bins", wcr_str="lambda a, b: a + b") - - state.add_memlet_path(a, entry, tasklet, memlet=read_memlet, dst_conn="a") - state.add_memlet_path(tasklet, exit, hist, memlet=write_memlet, src_conn="out") - - return state - - -def make_init_buffer_state(sdfg): - - state = sdfg.add_state("init_buffer") - - hist_buffer = state.add_array("hist_buffer", (num_bins, ), - dace.uint32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Local) - - entry, exit = state.add_map("init_map", {"i": "0:num_bins"}) - tasklet = state.add_tasklet("zero", {}, {"out"}, "out = 0") - state.add_nedge(entry, tasklet, dace.memlet.Memlet()) - state.add_memlet_path(tasklet, - exit, - hist_buffer, - src_conn="out", - memlet=dace.memlet.Memlet.simple(hist_buffer, "i")) - - return state - - -def make_write_buffer_state(sdfg): - - state = sdfg.add_state("write_buffer") - - hist_buffer = state.add_array("hist_buffer", (num_bins, ), - dace.uint32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Local) - hist_dram = state.add_array("hist_out", (num_bins, ), dace.uint32, storage=dace.dtypes.StorageType.FPGA_Global) - - state.add_memlet_path(hist_buffer, hist_dram, memlet=dace.memlet.Memlet.simple(hist_dram, "0:num_bins")) - - return state - - -def make_nested_sdfg(parent): - - sdfg = dace.SDFG("compute") - - init_state = make_init_buffer_state(sdfg) - compute_state = make_compute_state(sdfg) - finalize_state = make_write_buffer_state(sdfg) - - sdfg.add_edge(init_state, compute_state, dace.sdfg.InterstateEdge()) - sdfg.add_edge(compute_state, finalize_state, dace.sdfg.InterstateEdge()) - - return sdfg - - -def make_sdfg(specialize, h, w): - - if specialize: - sdfg = dace.SDFG("histogram_fpga_{}x{}".format(h, w)) - else: - sdfg = dace.SDFG("histogram_fpga") - - copy_to_fpga_state = make_copy_to_fpga_state(sdfg) - - state = sdfg.add_state("compute") - nested_sdfg = make_nested_sdfg(state) - tasklet = state.add_nested_sdfg(nested_sdfg, {"A_in"}, {"hist_out"}) - a_device = state.add_array("A_device", (H, W), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Global) - hist_device = state.add_array("hist_device", (num_bins, ), - dace.uint32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - state.add_memlet_path(a_device, tasklet, dst_conn="A_in", memlet=dace.memlet.Memlet.simple(a_device, "0:H, 0:W")) - state.add_memlet_path(tasklet, - hist_device, - src_conn="hist_out", - memlet=dace.memlet.Memlet.simple(hist_device, "0:num_bins")) - - copy_to_host_state = make_copy_to_host_state(sdfg) - - sdfg.add_edge(copy_to_fpga_state, state, dace.sdfg.InterstateEdge()) - sdfg.add_edge(state, copy_to_host_state, dace.sdfg.InterstateEdge()) - - return sdfg - - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - parser.add_argument("H", type=int) - parser.add_argument("W", type=int) - parser.add_argument("-specialize", - default=False, - action="store_true", - help="Fix all symbols at compile time/in hardware") - args = vars(parser.parse_args()) - - nbins = 256 - - if args["specialize"]: - h = args["H"] - w = args["W"] - histogram = make_sdfg(True, h, w) - histogram.specialize(dict(H=h, W=w, num_bins=nbins)) - else: - histogram = make_sdfg(False) - histogram.specialize(dict(num_bins=num_bins)) - h = args["H"] - w = args["W"] - - print("Histogram {}x{} ({}specialized)".format(h, w, "" if args["specialize"] else "not ")) - - A = dace.ndarray([H, W], dtype=dtype) - hist = dace.ndarray([num_bins], dtype=dace.uint32) - - A[:] = np.random.rand(h, w).astype(dace.float32.type) - hist[:] = dace.uint32(0) - - if args["specialize"]: - histogram(A=A, hist=hist) - else: - histogram(A=A, H=H, W=W, hist=hist) - - if dace.Config.get_bool('profiling'): - dace.timethis('histogram', 'numpy', (h * w), np.histogram, A, num_bins) - - diff = np.linalg.norm(np.histogram(A, bins=nbins, range=(0.0, 1.0))[0][1:-1] - hist[1:-1]) - - print("Difference:", diff) - if diff > 1e-5: - print("Validation failed.") - print("==== Program end ====") - - exit(0 if diff <= 1e-5 else 1) diff --git a/samples/fpga/jacobi_fpga_systolic.py b/samples/fpga/jacobi_fpga_systolic.py deleted file mode 100644 index f70075aeea..0000000000 --- a/samples/fpga/jacobi_fpga_systolic.py +++ /dev/null @@ -1,343 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import argparse -import dace -import numpy as np -import select -import sys -from scipy import ndimage - -W = dace.symbol("W") -H = dace.symbol("H") -T = dace.symbol("T") -P = dace.symbol("P") # Number of processing elements -dtype = dace.float32 - - -def make_init_state(sdfg): - state = sdfg.add_state("init") - - a0 = state.add_read("A") - tmp0 = state.add_write("tmp") - state.add_memlet_path(a0, tmp0, memlet=dace.Memlet.simple(tmp0, "0, 0:H, 0:W")) - - a1 = state.add_read("A") - tmp1 = state.add_write("tmp") - state.add_memlet_path(a1, tmp1, memlet=dace.Memlet.simple(tmp1, "1, 0:H, 0:W")) - - return state - - -def make_finalize_state(sdfg, even): - state = sdfg.add_state("finalize_" + ("even" if even else "odd")) - - tmp = state.add_read("tmp") - a = state.add_write("A") - state.add_memlet_path(tmp, a, memlet=dace.Memlet.simple(tmp, "{}, 0:H, 0:W".format(0 if even else 1))) - - return state - - -def make_compute_sdfg(): - sdfg = dace.SDFG("compute") - - pre_shift = sdfg.add_state("pre_shift") - loop_body = sdfg.add_state("compute_body") - post_shift = sdfg.add_state("post_shift") - - sdfg.add_edge(pre_shift, loop_body, dace.sdfg.InterstateEdge()) - sdfg.add_edge(loop_body, post_shift, dace.sdfg.InterstateEdge()) - - sdfg.add_stream("stream_in", dtype, storage=dace.dtypes.StorageType.FPGA_Local) - sdfg.add_stream("stream_out", dtype, storage=dace.dtypes.StorageType.FPGA_Local) - sdfg.add_array("row_buffers", (2, W), dtype, storage=dace.dtypes.StorageType.FPGA_Local) - sdfg.add_array("sliding_window", (3, 3), dtype, storage=dace.dtypes.StorageType.FPGA_Registers) - - stream_in = pre_shift.add_read("stream_in") - stream_out = loop_body.add_write("stream_out") - - rows_in = pre_shift.add_read("row_buffers") - rows_out = post_shift.add_write("row_buffers") - - window_buffer_in = post_shift.add_read("sliding_window") - window_buffer_out = pre_shift.add_write("sliding_window") - window_compute_in = loop_body.add_read("sliding_window") - window_shift_in = post_shift.add_read("sliding_window") - window_shift_out = post_shift.add_write("sliding_window") - - code = """\ -res = 0.0 -if y >= 3 and x >= 3 and y < H - 1 and x < W - 1: - res = float(0.2) * (window[0, 1] + window[1, 0] + window[1, 1] + window[1, 2] + window[2, 1]) -elif y >= 2 and x >= 2: - res = window[1, 1] -if (y >= 3 and x >= 3 and y < H - 1 and x < W - 1) or (y >= 2 and x >= 2): - result = res""" - - tasklet = loop_body.add_tasklet("compute", {"window"}, {"result"}, code) - - # Input window - loop_body.add_memlet_path(window_compute_in, - tasklet, - dst_conn="window", - memlet=dace.Memlet.simple(window_compute_in, "0:3, 0:3")) - - # Output result (conditional write) - out_memlet = dace.Memlet.simple(stream_out, "0", num_accesses=-1) - loop_body.add_memlet_path(tasklet, stream_out, src_conn="result", memlet=out_memlet) - - # Read row buffer - read_row_memlet = dace.Memlet.simple(rows_in, "0:2, x", num_accesses=2, other_subset_str="0:2, 2") - pre_shift.add_memlet_path(rows_in, window_buffer_out, memlet=read_row_memlet) - - # Read from memory - read_memory_memlet = dace.Memlet(f"{stream_in.data}[0]", dynamic=True) - read_memory_tasklet = pre_shift.add_tasklet("skip_last", {"read"}, {"window_buffer"}, - "if y < H - 1 and x < W - 1:\n\twindow_buffer = read") - pre_shift.add_memlet_path(stream_in, read_memory_tasklet, memlet=read_memory_memlet, dst_conn="read") - pre_shift.add_memlet_path(read_memory_tasklet, - window_buffer_out, - memlet=dace.Memlet.simple(window_buffer_out, "2, 2"), - src_conn="window_buffer") - - # Shift window - shift_window_memlet = dace.Memlet.simple(window_shift_in, '0:3, 1:3', other_subset_str='0:3, 0:2') - post_shift.add_memlet_path(window_shift_in, window_shift_out, memlet=shift_window_memlet) - - # To row buffer - write_row_memlet = dace.Memlet.simple(window_buffer_in, '1:3, 2', other_subset_str='0:2, x') - post_shift.add_memlet_path(window_buffer_in, rows_out, memlet=write_row_memlet) - - return sdfg - - -def make_outer_compute_state(sdfg): - state = sdfg.add_state("fpga_outer_state") - - tmp_in = state.add_read("tmp") - pipes_memory_read = state.add_stream("pipes", - dtype, - 1, - transient=True, - shape=(P + 1, ), - storage=dace.dtypes.StorageType.FPGA_Local) - pipes_read = state.add_stream("pipes", - dtype, - 1, - transient=True, - shape=(P + 1, ), - storage=dace.dtypes.StorageType.FPGA_Local) - pipes_write = state.add_stream("pipes", - dtype, - 1, - transient=True, - shape=(P + 1, ), - storage=dace.dtypes.StorageType.FPGA_Local) - pipes_memory_write = state.add_stream("pipes", - dtype, - 1, - transient=True, - shape=(P + 1, ), - storage=dace.dtypes.StorageType.FPGA_Local) - - # Read memory - read_entry, read_exit = state.add_map("read", { - "t": "0:T/P", - "y": "1:H-1", - "x": "1:W-1" - }, - schedule=dace.ScheduleType.FPGA_Device) - read_tasklet = state.add_tasklet("read", {"mem"}, {"to_kernel"}, "to_kernel = mem") - state.add_memlet_path(tmp_in, - read_entry, - read_tasklet, - dst_conn="mem", - memlet=dace.Memlet(f"{tmp_in.data}[t % 2, y, x]")) - state.add_memlet_path(read_tasklet, - read_exit, - pipes_memory_write, - src_conn="to_kernel", - memlet=dace.Memlet(f"{pipes_memory_write.data}[0]")), - - # Compute - compute_sdfg = make_compute_sdfg() - compute_sdfg_node = state.add_nested_sdfg(compute_sdfg, {"stream_in", "sliding_window", "row_buffers"}, - {"stream_out", "sliding_window", "row_buffers"}) - systolic_entry, systolic_exit = state.add_map("unroll_compute", {"p": "0:P"}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - state.add_memlet_path(systolic_entry, pipes_read, memlet=dace.Memlet()) - sdfg.add_array("_sliding_window", (3, 3), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Registers) - sliding_window_read = state.add_read("_sliding_window") - sliding_window_write = state.add_write("_sliding_window") - sdfg.add_array("_row_buffers", (2, W), dtype, storage=dace.dtypes.StorageType.FPGA_Local, transient=True) - row_buffers_read = state.add_read("_row_buffers") - row_buffers_write = state.add_write("_row_buffers") - compute_entry, compute_exit = state.add_map("compute", { - "t": "0:T/P", - "y": "1:H", - "x": "1:W" - }, - schedule=dace.ScheduleType.FPGA_Device) - state.add_memlet_path(pipes_read, - compute_entry, - compute_sdfg_node, - dst_conn="stream_in", - memlet=dace.Memlet(f"{pipes_read.data}[p]", dynamic=True)) - state.add_memlet_path(systolic_entry, sliding_window_read, memlet=dace.Memlet()) - state.add_memlet_path(sliding_window_read, - compute_entry, - compute_sdfg_node, - dst_conn="sliding_window", - memlet=dace.Memlet(f"_sliding_window[0:3, 0:3]")) - state.add_memlet_path(sliding_window_write, systolic_exit, memlet=dace.Memlet()) - state.add_memlet_path(compute_sdfg_node, - compute_exit, - sliding_window_write, - src_conn="sliding_window", - memlet=dace.Memlet(f"_sliding_window[0:3, 0:3]")) - state.add_memlet_path(systolic_entry, row_buffers_read, memlet=dace.Memlet()) - state.add_memlet_path(row_buffers_read, - compute_entry, - compute_sdfg_node, - dst_conn="row_buffers", - memlet=dace.Memlet(f"_row_buffers[0:2, 0:W]")) - state.add_memlet_path(row_buffers_write, systolic_exit, memlet=dace.Memlet()) - state.add_memlet_path(compute_sdfg_node, - compute_exit, - row_buffers_write, - src_conn="row_buffers", - memlet=dace.Memlet(f"_row_buffers[0:2, 0:W]")) - state.add_memlet_path(compute_sdfg_node, - compute_exit, - pipes_write, - src_conn="stream_out", - memlet=dace.Memlet(f"{pipes_write.data}[p + 1]", dynamic=True)) - state.add_memlet_path(pipes_write, systolic_exit, memlet=dace.Memlet()) - - # Write memory - write_entry, write_exit = state.add_map("write", { - "t": "0:T/P", - "y": "1:H-1", - "x": "1:W-1" - }, - schedule=dace.ScheduleType.FPGA_Device) - write_tasklet = state.add_tasklet("write", {"from_kernel"}, {"mem"}, "mem = from_kernel") - tmp_out = state.add_write("tmp") - state.add_memlet_path(pipes_memory_read, - write_entry, - write_tasklet, - dst_conn="from_kernel", - memlet=dace.Memlet(f"{pipes_memory_read}[P]")) - state.add_memlet_path(write_tasklet, - write_exit, - tmp_out, - src_conn="mem", - memlet=dace.Memlet(f"{tmp_out.data}[1 - t % 2, y, x]")) - - return state - - -def make_sdfg(specialize_all, h, w, t, p): - name = "jacobi_fpga_systolic_{}_{}x{}x{}".format(p, ("H" if not specialize_all else h), w, - ("T" if not specialize_all else t)) - - sdfg = dace.SDFG(name) - sdfg.add_symbol('T', dace.int32) - - sdfg.add_array("A", (H, W), dtype) - sdfg.add_array("tmp", (2, H, W), dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Global) - - init_state = make_init_state(sdfg) - - fpga_state = make_outer_compute_state(sdfg) - - finalize_even = make_finalize_state(sdfg, True) - finalize_odd = make_finalize_state(sdfg, False) - - sdfg.add_edge(init_state, fpga_state, dace.sdfg.InterstateEdge()) - sdfg.add_edge( - fpga_state, finalize_even, - dace.sdfg.InterstateEdge(condition=dace.properties.CodeProperty.from_string( - "(T / P) % 2 == 0", language=dace.dtypes.Language.Python))) - sdfg.add_edge( - fpga_state, finalize_odd, - dace.sdfg.InterstateEdge(condition=dace.properties.CodeProperty.from_string( - "(T / P) % 2 == 1", language=dace.dtypes.Language.Python))) - - return sdfg - - -def run_jacobi(w: int, h: int, t: int, p: int, specialize_all: bool = False): - print("==== Program start ====") - - # Width and number of PEs must be known at compile time, as it will - # influence the hardware layout - if specialize_all: - print("Specializing H and T...") - - jacobi = make_sdfg(specialize_all, h, w, t, p) - jacobi.specialize(dict(W=w, P=p)) - - if specialize_all: - jacobi.specialize(dict(H=h, T=t)) - - if t % p != 0: - raise ValueError("Iteration must be divisable by number of processing elements") - - print("Jacobi Stencil {}x{} ({} steps) with {} PEs{}".format(h, w, t, p, - (" (fully specialized)" if specialize_all else ""))) - - A = dace.ndarray([h, w], dtype=dace.float32) - - # Initialize arrays: Randomize A, zero B - A[:] = dace.float32(0) - A[2:h - 2, 2:w - 2] = 1 - regression = np.ndarray([h - 4, w - 4], dtype=np.float32) - regression[:] = A[2:h - 2, 2:w - 2] - - ############################################# - # Run DaCe program - - if specialize_all: - jacobi(A=A) - else: - jacobi(A=A, H=h, T=t) - - # Regression - kernel = np.array([[0, 0.2, 0], [0.2, 0.2, 0.2], [0, 0.2, 0]], dtype=np.float32) - for i in range(t): - regression = ndimage.convolve(regression, kernel, mode='constant', cval=0.0) - - residual = np.linalg.norm(A[2:h - 2, 2:w - 2] - regression) / (h * w) - print("Residual:", residual) - diff = np.abs(A[2:h - 2, 2:w - 2] - regression) - wrong_elements = np.transpose(np.nonzero(diff >= 0.01)) - highest_diff = np.max(diff) - - print("==== Program end ====") - if residual >= 0.01 or highest_diff >= 0.01: - print("Verification failed!") - print("Residual: {}".format(residual)) - print("Incorrect elements: {} / {}".format(wrong_elements.shape[0], h * w)) - print("Highest difference: {}".format(highest_diff)) - print("** Result:\n", A[:min(6, h), :min(6, w)]) - print("** Reference:\n", regression[:min(5, h), :min(4, w)]) - raise RuntimeError("Validation failed.") - - return jacobi - - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - parser.add_argument("H", type=int, nargs="?", default=64) - parser.add_argument("W", type=int, nargs="?", default=8192) - parser.add_argument("T", type=int, nargs="?", default=16) - parser.add_argument("P", type=int, nargs="?", default=8) - parser.add_argument("-specialize_all", - default=False, - action="store_true", - help="Fix all loop bounds at compile time/in hardware") - args = parser.parse_args() - - run_jacobi(args.H, args.W, args.T, args.P, args.specialize_all) diff --git a/samples/fpga/matrix_multiplication_pipelined.py b/samples/fpga/matrix_multiplication_pipelined.py deleted file mode 100644 index 2a511bc46a..0000000000 --- a/samples/fpga/matrix_multiplication_pipelined.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import argparse -import dace -import numpy as np - -N = dace.symbol("N") -M = dace.symbol("M") -K = dace.symbol("K") - - -def make_sdfg(specialized, n, k, m): - - if specialized: - sdfg = dace.SDFG("mm_fpga_pipelined_{}x{}x{}".format(n, k, m)) - else: - sdfg = dace.SDFG("mm_fpga_pipelined_NxKx{}".format(m)) - - ########################################################################### - # Copy data to FPGA - - pre_state = sdfg.add_state("pre_mm") - - A_host = pre_state.add_array("A", [N, K], dtype=dace.float32) - B_host = pre_state.add_array("B", [K, M], dtype=dace.float32) - C_host = pre_state.add_array("C", [N, M], dtype=dace.float32) - - A_device = pre_state.add_array("A_device", [N, K], - dtype=dace.float32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - B_device = pre_state.add_array("B_device", [K, M], - dtype=dace.float32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - C_device = pre_state.add_array("C_device", [N, M], - dtype=dace.float32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - - pre_state.add_edge(A_host, None, A_device, None, dace.memlet.Memlet.simple(A_device, "0:N, 0:K")) - pre_state.add_edge(B_host, None, B_device, None, dace.memlet.Memlet.simple(B_device, "0:K, 0:M")) - pre_state.add_edge(C_host, None, C_device, None, dace.memlet.Memlet.simple(C_device, "0:N, 0:M")) - - ########################################################################### - # Compute - - state = sdfg.add_state("mm") - sdfg.add_edge(pre_state, state, dace.sdfg.InterstateEdge()) - - A = state.add_array("A_device", [N, K], - dtype=dace.float32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - B = state.add_array("B_device", [K, M], - dtype=dace.float32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - C = state.add_array("C_device", [N, M], - dtype=dace.float32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - - C_buffer_in = state.add_array("C_buffer", [M], - dtype=dace.float32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Local) - C_buffer_out = state.add_array("C_buffer", [M], - dtype=dace.float32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Local) - - n_entry, n_exit = state.add_map("Map_N", {"n": "0:N"}, schedule=dace.dtypes.ScheduleType.FPGA_Device) - k_entry, k_exit = state.add_map("Map_K", {"k": "0:K"}, schedule=dace.dtypes.ScheduleType.FPGA_Device) - m_entry, m_exit = state.add_map("Map_M", {"m": "0:M"}, schedule=dace.dtypes.ScheduleType.FPGA_Device) - - state.add_nedge(n_entry, C_buffer_in, dace.memlet.Memlet()) - - ########################################################################### - # Nested SDFG - - nested_sdfg = dace.SDFG("zero_or_wcr") - - if_state = nested_sdfg.add_state("if_state") - then_state = nested_sdfg.add_state("then_state") - else_state = nested_sdfg.add_state("else_state") - end_state = nested_sdfg.add_state("end_state") - nested_sdfg.add_edge( - if_state, then_state, - dace.sdfg.InterstateEdge( - condition=dace.properties.CodeProperty.from_string("k == 0", language=dace.dtypes.Language.Python))) - nested_sdfg.add_edge( - if_state, else_state, - dace.sdfg.InterstateEdge( - condition=dace.properties.CodeProperty.from_string("k != 0", language=dace.dtypes.Language.Python))) - nested_sdfg.add_edge(then_state, end_state, dace.sdfg.InterstateEdge()) - nested_sdfg.add_edge(else_state, end_state, dace.sdfg.InterstateEdge()) - - # These are identical, they only differ in their confres - then_tasklet = then_state.add_tasklet("multiply", {"a", "b"}, {"c_out"}, "c_out = a * b") - else_tasklet = else_state.add_tasklet("multiply", {"a", "b", "c_in"}, {"c_out"}, "c_out = c_in + a * b") - - # Add scalar I/O - then_A_val = then_state.add_scalar("A_val", dtype=dace.float32, storage=dace.dtypes.StorageType.FPGA_Local) - then_B_val = then_state.add_scalar("B_val", dtype=dace.float32, storage=dace.dtypes.StorageType.FPGA_Local) - then_C_out = then_state.add_scalar("C_out", dtype=dace.float32, storage=dace.dtypes.StorageType.FPGA_Local) - - else_A_val = else_state.add_scalar("A_val", dtype=dace.float32, storage=dace.dtypes.StorageType.FPGA_Local) - else_B_val = else_state.add_scalar("B_val", dtype=dace.float32, storage=dace.dtypes.StorageType.FPGA_Local) - else_C_in = else_state.add_scalar("C_in", dtype=dace.float32, storage=dace.dtypes.StorageType.FPGA_Local) - else_C_out = else_state.add_scalar("C_out", dtype=dace.float32, storage=dace.dtypes.StorageType.FPGA_Local) - - # Memlets - then_a_val_memlet = dace.memlet.Memlet.simple(then_A_val, "0") - then_b_val_memlet = dace.memlet.Memlet.simple(then_B_val, "0") - then_c_out_memlet = dace.memlet.Memlet.simple(then_C_out, "0") - - else_a_val_memlet = dace.memlet.Memlet.simple(else_A_val, "0") - else_b_val_memlet = dace.memlet.Memlet.simple(else_B_val, "0") - else_c_in_memlet = dace.memlet.Memlet.simple(else_C_in, "0") - else_c_out_memlet = dace.memlet.Memlet.simple(else_C_out, "0") - - # Draw paths within each state - then_state.add_memlet_path(then_A_val, then_tasklet, memlet=then_a_val_memlet, dst_conn="a") - then_state.add_memlet_path(then_B_val, then_tasklet, memlet=then_b_val_memlet, dst_conn="b") - then_state.add_memlet_path(then_tasklet, then_C_out, memlet=then_c_out_memlet, src_conn="c_out") - - else_state.add_memlet_path(else_A_val, else_tasklet, memlet=else_a_val_memlet, dst_conn="a") - else_state.add_memlet_path(else_B_val, else_tasklet, memlet=else_b_val_memlet, dst_conn="b") - else_state.add_memlet_path(else_C_in, else_tasklet, memlet=else_c_in_memlet, dst_conn="c_in") - else_state.add_memlet_path(else_tasklet, else_C_out, memlet=else_c_out_memlet, src_conn="c_out") - - tasklet = state.add_nested_sdfg(nested_sdfg, {"A_val", "B_val", "C_in"}, {"C_out"}) - - ########################################################################### - # Compute continued - - # tasklet = state.add_tasklet("multiply", {"a", "b"}, {"c"}, "c = a * b") - - read_a_memlet = dace.memlet.Memlet.simple(A, "n, k") - read_b_memlet = dace.memlet.Memlet.simple(B, "k, m") - read_c_memlet = dace.memlet.Memlet.simple(C_buffer_in, "m") - - state.add_memlet_path(A, n_entry, k_entry, m_entry, tasklet, memlet=read_a_memlet, dst_conn="A_val") - state.add_memlet_path(B, n_entry, k_entry, m_entry, tasklet, memlet=read_b_memlet, dst_conn="B_val") - state.add_memlet_path(C_buffer_in, k_entry, m_entry, tasklet, memlet=read_c_memlet, dst_conn="C_in") - - write_buffer_memlet = dace.memlet.Memlet.simple(C_buffer_out, "m") - - state.add_memlet_path(tasklet, m_exit, k_exit, C_buffer_out, memlet=write_buffer_memlet, src_conn="C_out") - - write_c_memlet = dace.memlet.Memlet.simple(C, "n, 0:M") - - state.add_memlet_path(C_buffer_out, n_exit, C, memlet=write_c_memlet) - - ########################################################################### - # Copy back result - - post_state = sdfg.add_state("post_mm") - sdfg.add_edge(state, post_state, dace.sdfg.InterstateEdge()) - - C_device = post_state.add_array("C_device", [N, M], - dtype=dace.float32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - - C_host = post_state.add_array("C", [N, M], dtype=dace.float32) - - post_state.add_edge(C_device, None, C_host, None, dace.memlet.Memlet.simple(C_device, "0:N, 0:M")) - - return sdfg - - -if __name__ == "__main__": - print("==== Program start ====") - - parser = argparse.ArgumentParser() - parser.add_argument("M", type=int) - parser.add_argument("N", type=int) - parser.add_argument("K", type=int) - parser.add_argument("-specialize", - default=False, - action="store_true", - help="Fix all loop bounds at compile time/in hardware") - args = vars(parser.parse_args()) - - n = args.N - k = args.K - m = args.M - - if not args["specialize"]: - # M must always be specialized, as it's used for the static buffer size - sdfg = make_sdfg(False, n, k, m) - sdfg.specialize(dict(M=m)) - else: - sdfg = make_sdfg(True, n, k, m) - sdfg.specialize(dict(M=m, N=n, K=k)) - - print("Matrix multiplication {}x{}x{} ({}specialized)".format(m, n, k, "" if args["specialize"] else "not ")) - - # Initialize arrays: Randomize A and B, zero C - A = np.ndarray([n, k], dtype=dace.float32.type) - B = np.ndarray([k, m], dtype=dace.float32.type) - C = np.ndarray([n, m], dtype=dace.float32.type) - A[:] = np.random.rand(m, n).astype(dace.float32.type) - B[:] = np.random.rand(n, k).astype(dace.float32.type) - C[:] = dace.float32(0) - - if args["specialize"]: - sdfg(A=A, B=B, C=C) - else: - sdfg(A=A, B=B, C=C, N=n, K=k) - - diff = np.linalg.norm((A @ B) - C) / float(m * k) - if diff > 1e-6: - raise ValueError(f"Verification failed, difference: {diff}") - else: - print("Results successfully verified.") diff --git a/samples/fpga/matrix_multiplication_stream.py b/samples/fpga/matrix_multiplication_stream.py deleted file mode 100644 index 10105f79a6..0000000000 --- a/samples/fpga/matrix_multiplication_stream.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import argparse -import dace -import numpy as np -import pdb -import select -import sys - -N = dace.symbol("N") -M = dace.symbol("M") -K = dace.symbol("K") - - -def make_copy_to_fpga_state(sdfg): - - ########################################################################### - # Copy data to FPGA - - state = sdfg.add_state("copy_to_device") - - sdfg.add_array("A", [N, K], dtype=dace.float32) - sdfg.add_array("B", [K, M], dtype=dace.float32) - sdfg.add_array("C", [N, M], dtype=dace.float32) - - A_host = state.add_read("A") - B_host = state.add_read("B") - C_host = state.add_read("C") - - sdfg.add_array("A_device", [N, K], dtype=dace.float32, transient=True, storage=dace.StorageType.FPGA_Global) - sdfg.add_array("B_device", [K, M], dtype=dace.float32, transient=True, storage=dace.StorageType.FPGA_Global) - sdfg.add_array("C_device", [N, M], dtype=dace.float32, transient=True, storage=dace.StorageType.FPGA_Global) - - A_device = state.add_write("A_device") - B_device = state.add_write("B_device") - C_device = state.add_write("C_device") - - state.add_edge(A_host, None, A_device, None, dace.Memlet("A_device")) - state.add_edge(B_host, None, B_device, None, dace.Memlet("B_device")) - state.add_edge(C_host, None, C_device, None, dace.Memlet("C_device")) - - return state - - -def make_copy_to_host_state(sdfg): - - ########################################################################### - # Copy data to FPGA - - state = sdfg.add_state("copy_to_host") - - C_device = state.add_read("C_device") - C_host = state.add_write("C") - - state.add_edge(C_device, None, C_host, None, dace.Memlet("C")) - - return state - - -def make_fpga_state(sdfg): - - state = sdfg.add_state("mm") - - A = state.add_read("A_device") - B = state.add_read("B_device") - C = state.add_write("C_device") - - A_pipe_in = state.add_stream("A_pipe", dace.float32, transient=True, storage=dace.StorageType.FPGA_Local) - B_pipe_in = state.add_stream("B_pipe", dace.float32, transient=True, storage=dace.StorageType.FPGA_Local) - C_pipe_in = state.add_stream("C_pipe", dace.float32, transient=True, storage=dace.StorageType.FPGA_Local) - A_pipe_out = state.add_stream("A_pipe", dace.float32, transient=True, storage=dace.StorageType.FPGA_Local) - B_pipe_out = state.add_stream("B_pipe", dace.float32, transient=True, storage=dace.StorageType.FPGA_Local) - C_pipe_out = state.add_stream("C_pipe", dace.float32, transient=True, storage=dace.StorageType.FPGA_Local) - - state.add_memlet_path(A, A_pipe_out, memlet=dace.Memlet("A_device")) - - read_b_entry, read_b_exit = state.add_map("read_b", { - "n": "0:N", - "k": "0:K", - "m": "0:M" - }, - schedule=dace.ScheduleType.FPGA_Device) - read_b_tasklet = state.add_tasklet("read_b", {"mem"}, {"s"}, "s = mem") - state.add_memlet_path(B, read_b_entry, read_b_tasklet, dst_conn="mem", memlet=dace.Memlet("B_device[k, m]")) - state.add_memlet_path(read_b_tasklet, read_b_exit, B_pipe_out, src_conn="s", memlet=dace.Memlet("B_pipe[0]")) - - state.add_memlet_path(C_pipe_in, C, src_conn="mem", memlet=dace.Memlet("C_device")) - - ########################################################################### - - n_entry, n_exit = state.add_map("outer_map", {"n": "0:N"}, schedule=dace.ScheduleType.FPGA_Device) - km_entry, km_exit = state.add_map("inner_map", {"k": "0:K", "m": "0:M"}, schedule=dace.ScheduleType.FPGA_Device) - - sdfg.add_array("output_buffer", [M], dtype=dace.float32, transient=True, storage=dace.StorageType.FPGA_Local) - sdfg.add_array("A_reg", [1], dtype=dace.float32, transient=True, storage=dace.StorageType.FPGA_Local) - output_buffer_read = state.add_read("output_buffer") - output_buffer_write = state.add_write("output_buffer") - read_a_reg = state.add_read("A_reg") - write_a_reg = state.add_write("A_reg") - - tasklet = state.add_tasklet( - "multiply_accumulate", {"a_mem", "a_reg_in", "b", "c_in"}, {"a_reg_out", "c_out"}, """\ -a = a_mem if m == 0 else a_reg_in -a_reg_out = a -prev = 0 if k == 0 else c_in -c_out = prev + a * b""") - - state.add_memlet_path(A_pipe_in, - n_entry, - km_entry, - tasklet, - dst_conn="a_mem", - memlet=dace.Memlet("A_pipe[0]", dynamic=True)) - - state.add_memlet_path(B_pipe_in, n_entry, km_entry, tasklet, dst_conn="b", memlet=dace.Memlet("B_pipe[0]")) - - state.add_memlet_path(read_a_reg, n_entry, km_entry, tasklet, dst_conn="a_reg_in", memlet=dace.Memlet("A_reg[0]")) - - state.add_memlet_path(output_buffer_read, - km_entry, - tasklet, - dst_conn="c_in", - memlet=dace.Memlet("output_buffer[m]")) - - # Make sure it's in scope - state.add_memlet_path(n_entry, output_buffer_read, memlet=dace.Memlet()) - - state.add_memlet_path(tasklet, - km_exit, - output_buffer_write, - src_conn="c_out", - memlet=dace.Memlet("output_buffer[m]")) - - state.add_memlet_path(tasklet, km_exit, n_exit, write_a_reg, src_conn="a_reg_out", memlet=dace.Memlet("A_reg[0]")) - - state.add_memlet_path(output_buffer_write, n_exit, C_pipe_out, memlet=dace.Memlet("output_buffer[0:M]")) - - return state - - -def make_sdfg(specialized, n, k, m): - - if specialized: - sdfg = dace.SDFG("mm_fpga_stream_{}x{}x{}".format(n, k, m)) - else: - sdfg = dace.SDFG("mm_fpga_stream_NxKx{}".format(m)) - - pre_state = make_copy_to_fpga_state(sdfg) - compute_state = make_fpga_state(sdfg) - post_state = make_copy_to_host_state(sdfg) - - sdfg.add_edge(pre_state, compute_state, dace.InterstateEdge()) - sdfg.add_edge(compute_state, post_state, dace.InterstateEdge()) - - return sdfg - - -if __name__ == "__main__": - print("==== Program start ====") - - parser = argparse.ArgumentParser() - parser.add_argument("M", type=int) - parser.add_argument("N", type=int) - parser.add_argument("K", type=int) - parser.add_argument("-specialize", - default=False, - action="store_true", - help="Fix all loop bounds at compile time/in hardware") - args = vars(parser.parse_args()) - - m = args.M - n = args.N - k = args.K - - if not args["specialize"]: - # M must always be specialized, as it's used for the static buffer size - sdfg = make_sdfg(False, n, k, m) - sdfg.specialize(dict(M=m)) - else: - sdfg = make_sdfg(True, n, k, m) - sdfg.specialize(dict(M=m, N=n, K=k)) - - print("Matrix multiplication {}x{}x{} ({}specialized)".format(m, n, k, "" if args["specialize"] else "not ")) - - # Initialize arrays: Randomize A and B, zero C - A = np.ndarray([n, k], dtype=dace.float32.type) - B = np.ndarray([k, m], dtype=dace.float32.type) - C = np.ndarray([n, m], dtype=dace.float32.type) - A[:] = 1 # np.random.rand(n, k).astype(dace.float32.type) - B[:] = 1 # np.random.rand(k, m).astype(dace.float32.type) - C[:] = dace.float32(0) - - A_regression = np.ndarray([n, k], dtype=np.float32) - B_regression = np.ndarray([k, m], dtype=np.float32) - C_regression = np.ndarray([n, m], dtype=np.float32) - A_regression[:] = A[:] - B_regression[:] = B[:] - C_regression[:] = C[:] - - if args["specialize"]: - sdfg(A=A, B=B, C=C) - else: - sdfg(A=A, B=B, C=C, N=n, K=k) - - diff = np.linalg.norm((A @ B) - C) / float(m * k) - if diff > 1e-6: - raise ValueError(f"Verification failed, difference: {diff}") - else: - print("Results successfully verified.") diff --git a/samples/fpga/matrix_multiplication_systolic.py b/samples/fpga/matrix_multiplication_systolic.py deleted file mode 100644 index f630233fba..0000000000 --- a/samples/fpga/matrix_multiplication_systolic.py +++ /dev/null @@ -1,316 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import argparse -import click -import dace -import numpy as np -import pdb -import select -import sys - -N = dace.symbol("N") -K = dace.symbol("K") -M = dace.symbol("M") -P = dace.symbol("P") - - -def make_copy_to_fpga_state(sdfg): - - ########################################################################### - # Copy data to FPGA - - state = sdfg.add_state("copy_to_device") - - sdfg.add_array("A", [N, K], dtype=dace.float32) - sdfg.add_array("B", [K, M], dtype=dace.float32) - sdfg.add_array("C", [N, M], dtype=dace.float32) - A_host = state.add_read("A") - B_host = state.add_read("B") - C_host = state.add_read("C") - - sdfg.add_array("A_device", [N, K], dtype=dace.float32, transient=True, storage=dace.dtypes.StorageType.FPGA_Global) - sdfg.add_array("B_device", [K, M], dtype=dace.float32, transient=True, storage=dace.dtypes.StorageType.FPGA_Global) - sdfg.add_array("C_device", [N, M], dtype=dace.float32, transient=True, storage=dace.dtypes.StorageType.FPGA_Global) - A_device = state.add_write("A_device") - B_device = state.add_write("B_device") - C_device = state.add_write("C_device") - - state.add_memlet_path(A_host, A_device, memlet=dace.Memlet("A_device[0:N, 0:K]")) - state.add_memlet_path(B_host, B_device, memlet=dace.Memlet("B_device[0:K, 0:M]")) - state.add_memlet_path(C_host, C_device, memlet=dace.Memlet("C_device[0:N, 0:M]")) - - return state - - -def make_copy_to_host_state(sdfg): - - ########################################################################### - # Copy data to FPGA - - state = sdfg.add_state("copy_to_host") - - C_device = state.add_read("C_device") - C_host = state.add_write("C") - - state.add_memlet_path(C_device, C_host, memlet=dace.Memlet("C[0:N, 0:M]")) - - return state - - -def make_read_A(state): - - entry, exit = state.add_map("read_A", { - "n0": "0:N/P", - "k": "0:K", - "n1": "0:P" - }, - schedule=dace.ScheduleType.FPGA_Device) - - mem = state.add_read("A_device") - pipe = state.add_write("A_pipe") - tasklet = state.add_tasklet("read_A", {"from_memory"}, {"to_kernel"}, "to_kernel = from_memory") - - state.add_memlet_path(mem, entry, tasklet, dst_conn="from_memory", memlet=dace.Memlet("A_device[n0 * P + n1, k]")) - state.add_memlet_path(tasklet, exit, pipe, src_conn="to_kernel", memlet=dace.Memlet("A_pipe[0]")) - - -def make_read_B(state): - - entry, exit = state.add_map("read_B", { - "n": "0:N/P", - "k": "0:K", - "m": "0:M" - }, - schedule=dace.ScheduleType.FPGA_Device) - - mem = state.add_read("B_device") - pipe = state.add_write("B_pipe") - tasklet = state.add_tasklet("read_B", {"from_memory"}, {"to_kernel"}, "to_kernel = from_memory") - - state.add_memlet_path(mem, entry, tasklet, dst_conn="from_memory", memlet=dace.Memlet("B_device[k, m]")) - state.add_memlet_path(tasklet, exit, pipe, src_conn="to_kernel", memlet=dace.Memlet("B_pipe[0]")) - - -def make_write_C(state): - - pipe = state.add_read("C_pipe") - mem = state.add_write("C_device") - - state.add_memlet_path(pipe, mem, memlet=dace.Memlet("C_device[0:N, 0:M]", other_subset="P - 1")) - - -def make_compute(sdfg, state): - - A_pipe_in = state.add_read("A_pipe") - A_pipe_out = state.add_write("A_pipe") - B_pipe_in = state.add_read("B_pipe") - B_pipe_out = state.add_write("B_pipe") - C_pipe_in = state.add_read("C_pipe") - C_pipe_out = state.add_write("C_pipe") - - entry_n0, exit_n0 = state.add_map("n0", { - "n0": "0:N/P", - }, schedule=dace.ScheduleType.FPGA_Device) - entry_k, exit_k = state.add_map("k", {"k": "0:K"}, schedule=dace.ScheduleType.FPGA_Device) - entry_a, exit_a = state.add_map("buffer_A", {"n1": "0:P"}, schedule=dace.ScheduleType.FPGA_Device) - entry_m, exit_m = state.add_map("m", {"m": "0:M"}, schedule=dace.ScheduleType.FPGA_Device) - entry_c, exit_c = state.add_map("write_C", {"n1": "0:P", "m": "0:M"}, schedule=dace.ScheduleType.FPGA_Device) - - # Instantiate buffers - sdfg.add_scalar("A_reg", dtype=dace.float32, transient=True, storage=dace.dtypes.StorageType.FPGA_Registers) - A_reg = state.add_write("A_reg") - sdfg.add_array("C_buffer", [M], dtype=dace.float32, transient=True, storage=dace.dtypes.StorageType.FPGA_Local) - C_buffer_in = state.add_read("C_buffer") - C_buffer_out = state.add_write("C_buffer") - - buffer_a_tasklet = state.add_tasklet("buffer_a", {"a_in"}, {"a_reg", "a_out"}, """\ -if n1 == P - p - 1: - a_reg = a_in -if p < P - 1: - a_out = a_in""") - state.add_memlet_path(A_pipe_in, - entry_n0, - entry_k, - entry_a, - buffer_a_tasklet, - memlet=dace.Memlet("A_pipe[p]", dynamic=False), - dst_conn="a_in") - state.add_memlet_path(buffer_a_tasklet, - exit_a, - A_reg, - memlet=dace.Memlet("A_reg[0]", dynamic=True), - src_conn="a_reg") - state.add_memlet_path(buffer_a_tasklet, - exit_a, - exit_k, - exit_n0, - A_pipe_out, - memlet=dace.Memlet("A_pipe[p + 1]", dynamic=True), - src_conn="a_out") - - compute_tasklet = state.add_tasklet( - "multiply_add", {"a_in", "b_in", "c_in"}, {"b_out", "c_out"}, """\ -c_prev = 0 if k == 0 else c_in -c_out = c_prev + a_in * b_in -if p < P - 1: - b_out = b_in""") - - state.add_memlet_path(A_reg, entry_m, compute_tasklet, dst_conn="a_in", memlet=dace.Memlet("A_reg[0]")) - state.add_memlet_path(B_pipe_in, - entry_n0, - entry_k, - entry_m, - compute_tasklet, - memlet=dace.Memlet("B_pipe[p]", dynamic=False), - dst_conn="b_in") - state.add_memlet_path(compute_tasklet, - exit_m, - exit_k, - exit_n0, - B_pipe_out, - memlet=dace.Memlet("B_pipe[p + 1]", dynamic=True), - src_conn="b_out") - state.add_memlet_path(C_buffer_in, - entry_k, - entry_m, - compute_tasklet, - dst_conn="c_in", - memlet=dace.Memlet("C_buffer[m]")) - state.add_memlet_path(entry_n0, C_buffer_in, memlet=dace.Memlet()) - state.add_memlet_path(compute_tasklet, - exit_m, - exit_k, - C_buffer_out, - memlet=dace.Memlet("C_buffer[m]"), - src_conn="c_out") - state.add_memlet_path(C_buffer_out, exit_n0, memlet=dace.Memlet()) - - # Write back - write_c_tasklet = state.add_tasklet("write_c", {"buffer_in", "forward_in"}, {"c_out"}, """\ -if n1 <= p: - c_out = forward_in if p > 0 and n1 > 0 else buffer_in""") - state.add_memlet_path(C_buffer_out, - entry_c, - write_c_tasklet, - memlet=dace.Memlet("C_buffer[m]", dynamic=True), - dst_conn="buffer_in") - state.add_memlet_path(C_pipe_in, - entry_n0, - entry_c, - write_c_tasklet, - memlet=dace.Memlet("C_pipe[p-1]", dynamic=True), - dst_conn="forward_in") - state.add_memlet_path(write_c_tasklet, - exit_c, - exit_n0, - C_pipe_out, - memlet=dace.Memlet("C_pipe[p]", dynamic=True), - src_conn="c_out") - - # Unroll processing elements - compute_entry, compute_exit = state.add_map("unroll_compute", {"p": "0:P"}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - - # Bring data nodes into scope - state.add_memlet_path(compute_entry, A_pipe_in, memlet=dace.memlet.Memlet()) - state.add_memlet_path(compute_entry, B_pipe_in, memlet=dace.memlet.Memlet()) - state.add_memlet_path(compute_entry, C_pipe_in, memlet=dace.memlet.Memlet()) - state.add_memlet_path(A_pipe_out, compute_exit, memlet=dace.memlet.Memlet()) - state.add_memlet_path(B_pipe_out, compute_exit, memlet=dace.memlet.Memlet()) - state.add_memlet_path(C_pipe_out, compute_exit, memlet=dace.memlet.Memlet()) - - -def make_fpga_state(sdfg): - - state = sdfg.add_state("mm") - - sdfg.add_stream("A_pipe", - dace.float32, - transient=True, - shape=(P + 1, ), - storage=dace.dtypes.StorageType.FPGA_Local, - buffer_size="P") - sdfg.add_stream("B_pipe", dace.float32, transient=True, shape=(P + 1, ), storage=dace.dtypes.StorageType.FPGA_Local) - sdfg.add_stream("C_pipe", dace.float32, transient=True, shape=(P, ), storage=dace.dtypes.StorageType.FPGA_Local) - - make_read_A(state) - make_read_B(state) - make_compute(sdfg, state) - make_write_C(state) - - return state - - -def make_sdfg(specialized, p, n, k, m): - - if specialized: - sdfg = dace.SDFG("mm_fpga_systolic_{}_{}x{}x{}".format(p, n, k, m)) - else: - sdfg = dace.SDFG("mm_fpga_systolic_{}_NxKx{}".format(p, m)) - - pre_state = make_copy_to_fpga_state(sdfg) - compute_state = make_fpga_state(sdfg) - post_state = make_copy_to_host_state(sdfg) - - sdfg.add_edge(pre_state, compute_state, dace.sdfg.InterstateEdge()) - sdfg.add_edge(compute_state, post_state, dace.sdfg.InterstateEdge()) - - return sdfg - - -def run_matmul_systolic(m, n, k, p, specialize): - print("==== Program start ====") - - if not specialize: - # M must always be specialized, as it's used for the static buffer size - sdfg = make_sdfg(False, p, n, k, m) - sdfg.specialize(dict(P=p, M=m)) - else: - sdfg = make_sdfg(True, p, n, k, m) - sdfg.specialize(dict(P=p, M=m, N=n, K=k)) - - print("Matrix multiplication {}x{}x{} with {} PEs ({}specialized)".format(m, n, k, p, "" if specialize else "not ")) - - # Initialize arrays: Randomize A and B, zero C - A = np.ndarray([n, k], dtype=dace.float32.type) - B = np.ndarray([k, m], dtype=dace.float32.type) - C = np.ndarray([n, m], dtype=dace.float32.type) - A[:] = np.random.rand(n, k).astype(dace.float32.type) - B[:] = np.random.rand(k, m).astype(dace.float32.type) - C[:] = dace.float32(0) - - A_regression = np.ndarray([n, k], dtype=np.float32) - B_regression = np.ndarray([k, m], dtype=np.float32) - C_regression = np.ndarray([n, m], dtype=np.float32) - A_regression[:] = A[:] - B_regression[:] = B[:] - C_regression[:] = C[:] - - if specialize: - sdfg(A=A, B=B, C=C) - else: - sdfg(A=A, B=B, C=C, N=n, K=k) - - diff = np.linalg.norm((A @ B) - C) / float(m * k) - if diff > 1e-6: - raise ValueError(f"Verification failed, difference: {diff}") - else: - print("Results successfully verified.") - - print("==== Program end ====") - - return sdfg - - -@click.command() -@click.argument("M", type=int) -@click.argument("N", type=int) -@click.argument("K", type=int) -@click.argument("P", type=int) -@click.option("--specialize/--no-specialize", default=False, help="Fix all loop bounds at compile time/in hardware") -def cli(m, n, k, p, specialize): - run_matmul_systolic(m, n, k, p, specialize) - - -if __name__ == "__main__": - cli() diff --git a/samples/fpga/rtl/add_fortytwo.py b/samples/fpga/rtl/add_fortytwo.py deleted file mode 100644 index bd24aaa378..0000000000 --- a/samples/fpga/rtl/add_fortytwo.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" - This sample shows adding a constant integer value to a stream of integers. - - It is intended for running hardware_emulation or hardware xilinx targets. -""" - -import dace -import numpy as np - -# add symbol -N = dace.symbol('N') - -# add sdfg -sdfg = dace.SDFG('add_fortytwo') - -# add state -state = sdfg.add_state('device_state') - -# add arrays -sdfg.add_array('A', [N], dtype=dace.int32, storage=dace.StorageType.CPU_Heap) -sdfg.add_array('B', [N], dtype=dace.int32, storage=dace.StorageType.CPU_Heap) -sdfg.add_array('fpga_A', [N], dtype=dace.int32, transient=True, storage=dace.StorageType.FPGA_Global) -sdfg.add_array('fpga_B', [N], dtype=dace.int32, transient=True, storage=dace.StorageType.FPGA_Global) - -# add streams -sdfg.add_stream('A_stream', dtype=dace.int32, transient=True, storage=dace.StorageType.FPGA_Local) -sdfg.add_stream('B_stream', dtype=dace.int32, transient=True, storage=dace.StorageType.FPGA_Local) - -# add custom rtl tasklet -rtl_tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a'}, - outputs={'b'}, - code=''' - /* - Convention: - |--------------------------------------------------------| - | | - -->| ap_aclk (clock input) | - -->| ap_areset (reset input, rst on high) | - -->| ap_start (start pulse from host) | - <--| ap_done (tells the host that the kernel is done) | - | | - | For each input: For each output: | - | | - -->| s_axis_{input}_tvalid reg m_axis_{output}_tvalid |--> - -->| s_axis_{input}_tdata reg m_axis_{output}_tdata |--> - <--| reg s_axis_{input}_tready m_axis_{output}_tready |<-- - -->| s_axis_{input}_tkeep reg m_axis_{output}_tkeep |--> - -->| s_axis_{input}_tlast reg m_axis_{output}_tlast |--> - | | - |--------------------------------------------------------| - */ - - assign ap_done = 1; // free-running kernel - - always@(posedge ap_aclk) begin - if (ap_areset) begin // case: reset - s_axis_a_tready <= 1'b1; - m_axis_b_tvalid <= 1'b0; - m_axis_b_tdata <= 0; - end else if (s_axis_a_tvalid && s_axis_a_tready) begin - s_axis_a_tready <= 1'b0; - m_axis_b_tvalid <= 1'b1; - m_axis_b_tdata <= s_axis_a_tdata + 42; - end else if (!s_axis_a_tready && m_axis_b_tvalid && m_axis_b_tready) begin - s_axis_a_tready <= 1'b1; - m_axis_b_tvalid <= 1'b0; - end - end - ''', - language=dace.Language.SystemVerilog) - -# add read and write tasklets -read_a = state.add_tasklet('read_a', {'inp'}, {'out'}, 'out = inp') -write_b = state.add_tasklet('write_b', {'inp'}, {'out'}, 'out = inp') - -# add read and write maps -read_a_entry, read_a_exit = state.add_map('read_a_map', dict(i='0:N'), schedule=dace.ScheduleType.FPGA_Device) -write_b_entry, write_b_exit = state.add_map('write_b_map', dict(i='0:N'), schedule=dace.ScheduleType.FPGA_Device) - -# add read_a memlets and access nodes -read_a_inp = state.add_read('fpga_A') -read_a_out = state.add_write('A_stream') -state.add_memlet_path(read_a_inp, read_a_entry, read_a, dst_conn='inp', memlet=dace.Memlet('fpga_A[i]')) -state.add_memlet_path(read_a, read_a_exit, read_a_out, src_conn='out', memlet=dace.Memlet('A_stream[0]')) - -# add tasklet memlets -A = state.add_read('A_stream') -B = state.add_write('B_stream') -state.add_memlet_path(A, rtl_tasklet, dst_conn='a', memlet=dace.Memlet('A_stream[0]')) -state.add_memlet_path(rtl_tasklet, B, src_conn='b', memlet=dace.Memlet('B_stream[0]')) - -# add write_b memlets and access nodes -write_b_inp = state.add_read('B_stream') -write_b_out = state.add_write('fpga_B') -state.add_memlet_path(write_b_inp, write_b_entry, write_b, dst_conn='inp', memlet=dace.Memlet('B_stream[0]')) -state.add_memlet_path(write_b, write_b_exit, write_b_out, src_conn='out', memlet=dace.Memlet('fpga_B[i]')) - -# add copy to device state -copy_to_device = sdfg.add_state('copy_to_device') -cpu_a = copy_to_device.add_read('A') -dev_a = copy_to_device.add_write('fpga_A') -copy_to_device.add_memlet_path(cpu_a, dev_a, memlet=dace.Memlet('A[0:N]')) -sdfg.add_edge(copy_to_device, state, dace.InterstateEdge()) - -# add copy to host state -copy_to_host = sdfg.add_state('copy_to_host') -dev_b = copy_to_host.add_read('fpga_B') -cpu_b = copy_to_host.add_write('B') -copy_to_host.add_memlet_path(dev_b, cpu_b, memlet=dace.Memlet('B[0:N]')) -sdfg.add_edge(state, copy_to_host, dace.InterstateEdge()) - -# validate sdfg -sdfg.validate() - -###################################################################### - -if __name__ == '__main__': - with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='hardware_emulation'): - # init data structures - N = 8192 - a = np.random.randint(0, 100, N).astype(np.int32) - b = np.zeros((N, )).astype(np.int32) - - # show initial values - print("a={}, b={}".format(a, b)) - - # call program - sdfg(A=a, B=b, N=N) - - # show result - print("a={}, b={}".format(a, b)) - - # check result - for i in range(N): - assert b[i] == a[i] + 42 diff --git a/samples/fpga/rtl/axpy.py b/samples/fpga/rtl/axpy.py deleted file mode 100644 index c4e9ba0af1..0000000000 --- a/samples/fpga/rtl/axpy.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" - This sample shows the AXPY BLAS routine. It is implemented through Xilinx IPs in order to utilize floating point - operations. - - It is intended for running hardware_emulation or hardware xilinx targets. -""" - -import dace -import numpy as np - -# add symbol -N = dace.symbol('N') - - -def make_sdfg(veclen=2): - # add sdfg - sdfg = dace.SDFG('axpy') - - # add state - state = sdfg.add_state('device_state') - - # add parameter - sdfg.add_constant('VECLEN', veclen) - - # add arrays - sdfg.add_scalar('a', dtype=dace.float32, storage=dace.StorageType.FPGA_Global) - sdfg.add_array('x', [N // veclen], dtype=dace.vector(dace.float32, veclen), storage=dace.StorageType.CPU_Heap) - sdfg.add_array('y', [N // veclen], dtype=dace.vector(dace.float32, veclen), storage=dace.StorageType.CPU_Heap) - sdfg.add_array('result', [N // veclen], dtype=dace.vector(dace.float32, veclen), storage=dace.StorageType.CPU_Heap) - sdfg.add_array('fpga_x', [N // veclen], - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Global) - sdfg.add_array('fpga_y', [N // veclen], - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Global) - sdfg.add_array('fpga_result', [N // veclen], - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Global) - - # add streams - sdfg.add_stream('x_stream', - buffer_size=32, - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Local) - sdfg.add_stream('y_stream', - buffer_size=32, - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Local) - sdfg.add_stream('result_stream', - buffer_size=32, - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Local) - - # add custom rtl tasklet - rtl_tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a_in', 'x_in', 'y_in'}, - outputs={'result_out'}, - code=''' - /* - Convention: - |--------------------------------------------------------| - | | - -->| ap_aclk (slow clock input) | - -->| ap_areset (slow reset input, rst on high) | - -->| ap_aclk (fast clock input) | - -->| ap_areset_2 (fast reset input, rst on high) | - -->| ap_start (start pulse from host) | - <--| ap_done (tells the host that the kernel is done) | - | | - | For each input: For each output: | - | | - -->| s_axis_{input}_tvalid reg m_axis_{output}_tvalid |--> - -->| s_axis_{input}_tdata reg m_axis_{output}_tdata |--> - <--| reg s_axis_{input}_tready m_axis_{output}_tready |<-- - -->| s_axis_{input}_tkeep reg m_axis_{output}_tkeep |--> - -->| s_axis_{input}_tlast reg m_axis_{output}_tlast |--> - | | - |--------------------------------------------------------| - */ - assign ap_done = 1; // free-running kernel - - wire axis_ax_tvalid; - wire [31:0] axis_ax_tdata; - wire axis_ax_tready; - - reg [VECLEN-1:0] s_axis_x_in_tready_tmp; - reg [VECLEN-1:0] s_axis_y_in_tready_tmp; - reg [VECLEN-1:0] m_axis_result_out_tvalid_tmp; - - generate for (genvar i = 0; i < VECLEN; i++) begin - - wire axis_ax_tvalid; - wire [31:0] axis_ax_tdata; - wire axis_ax_tready; - - floating_point_mult multiplier ( - .aclk(ap_aclk), - - .s_axis_a_tvalid(scalars_valid), - .s_axis_a_tdata(a_in), - //.s_axis_a_tready(), - - .s_axis_b_tvalid(s_axis_x_in_tvalid), - .s_axis_b_tdata( s_axis_x_in_tdata[i]), - .s_axis_b_tready(s_axis_x_in_tready_tmp[i]), - - .m_axis_result_tvalid(axis_ax_tvalid), - .m_axis_result_tdata( axis_ax_tdata), - .m_axis_result_tready(axis_ax_tready) - ); - - floating_point_add adder ( - .aclk(ap_aclk), - - .s_axis_a_tvalid(axis_ax_tvalid), - .s_axis_a_tdata( axis_ax_tdata), - .s_axis_a_tready(axis_ax_tready), - - .s_axis_b_tvalid(s_axis_y_in_tvalid), - .s_axis_b_tdata( s_axis_y_in_tdata[i]), - .s_axis_b_tready(s_axis_y_in_tready_tmp[i]), - - .m_axis_result_tvalid(m_axis_result_out_tvalid_tmp[i]), - .m_axis_result_tdata( m_axis_result_out_tdata[i]), - .m_axis_result_tready(m_axis_result_out_tready) - ); - - end endgenerate - - assign s_axis_x_in_tready = &s_axis_x_in_tready_tmp; - assign s_axis_y_in_tready = &s_axis_y_in_tready_tmp; - assign m_axis_result_out_tvalid = &m_axis_result_out_tvalid_tmp; - ''', - language=dace.Language.SystemVerilog) - - rtl_tasklet.add_ip_core( - 'floating_point_mult', 'floating_point', 'xilinx.com', '7.1', { - "CONFIG.Operation_Type": "Multiply", - "CONFIG.C_Mult_Usage": "Max_Usage", - "CONFIG.Axi_Optimize_Goal": "Performance", - "CONFIG.A_Precision_Type": "Single", - "CONFIG.C_A_Exponent_Width": "8", - "CONFIG.C_A_Fraction_Width": "24", - "CONFIG.Result_Precision_Type": "Single", - "CONFIG.C_Result_Exponent_Width": "8", - "CONFIG.C_Result_Fraction_Width": "24", - "CONFIG.C_Latency": "9", - "CONFIG.C_Rate": "1" - }) - - rtl_tasklet.add_ip_core('floating_point_add', 'floating_point', 'xilinx.com', '7.1', { - "CONFIG.Add_Sub_Value": "Add", - "CONFIG.Axi_Optimize_Goal": "Performance", - "CONFIG.C_Latency": "14" - }) - - # add read and write tasklets - read_x = state.add_tasklet('read_x', {'inp'}, {'out'}, 'out = inp') - read_y = state.add_tasklet('read_y', {'inp'}, {'out'}, 'out = inp') - write_result = state.add_tasklet('write_result', {'inp'}, {'out'}, 'out = inp') - - # add read and write maps - read_x_entry, read_x_exit = state.add_map('read_x_map', - dict(i='0:N//VECLEN'), - schedule=dace.ScheduleType.FPGA_Device) - read_y_entry, read_y_exit = state.add_map('read_y_map', - dict(i='0:N//VECLEN'), - schedule=dace.ScheduleType.FPGA_Device) - write_result_entry, write_result_exit = state.add_map('write_result_map', - dict(i='0:N//VECLEN'), - schedule=dace.ScheduleType.FPGA_Device) - - # add read_a memlets and access nodes - read_x_inp = state.add_read('fpga_x') - read_x_out = state.add_write('x_stream') - state.add_memlet_path(read_x_inp, read_x_entry, read_x, dst_conn='inp', memlet=dace.Memlet('fpga_x[i]')) - state.add_memlet_path(read_x, read_x_exit, read_x_out, src_conn='out', memlet=dace.Memlet('x_stream[0]')) - - read_y_inp = state.add_read('fpga_y') - read_y_out = state.add_write('y_stream') - state.add_memlet_path(read_y_inp, read_y_entry, read_y, dst_conn='inp', memlet=dace.Memlet('fpga_y[i]')) - state.add_memlet_path(read_y, read_y_exit, read_y_out, src_conn='out', memlet=dace.Memlet('y_stream[0]')) - - # add tasklet memlets - a = state.add_read('a') - x = state.add_read('x_stream') - y = state.add_read('y_stream') - result = state.add_write('result_stream') - state.add_memlet_path(a, rtl_tasklet, dst_conn='a_in', memlet=dace.Memlet('a[0]')) - state.add_memlet_path(x, rtl_tasklet, dst_conn='x_in', memlet=dace.Memlet('x_stream[0]')) - state.add_memlet_path(y, rtl_tasklet, dst_conn='y_in', memlet=dace.Memlet('y_stream[0]')) - state.add_memlet_path(rtl_tasklet, result, src_conn='result_out', memlet=dace.Memlet('result_stream[0]')) - - # add write_c memlets and access nodes - write_result_inp = state.add_read('result_stream') - write_result_out = state.add_write('fpga_result') - state.add_memlet_path(write_result_inp, - write_result_entry, - write_result, - dst_conn='inp', - memlet=dace.Memlet('result_stream[0]')) - state.add_memlet_path(write_result, - write_result_exit, - write_result_out, - src_conn='out', - memlet=dace.Memlet('fpga_result[i]')) - - # add copy to device state - copy_to_device = sdfg.add_state('copy_to_device') - cpu_x = copy_to_device.add_read('x') - cpu_y = copy_to_device.add_read('y') - dev_x = copy_to_device.add_write('fpga_x') - dev_y = copy_to_device.add_write('fpga_y') - copy_to_device.add_memlet_path(cpu_x, dev_x, memlet=dace.Memlet('x[0:N//VECLEN]')) - copy_to_device.add_memlet_path(cpu_y, dev_y, memlet=dace.Memlet('y[0:N//VECLEN]')) - sdfg.add_edge(copy_to_device, state, dace.InterstateEdge()) - - # add copy to host state - copy_to_host = sdfg.add_state('copy_to_host') - dev_result = copy_to_host.add_read('fpga_result') - cpu_result = copy_to_host.add_write('result') - copy_to_host.add_memlet_path(dev_result, cpu_result, memlet=dace.Memlet('result[0:N//VECLEN]')) - sdfg.add_edge(state, copy_to_host, dace.InterstateEdge()) - - # validate sdfg - sdfg.validate() - - return sdfg - - -###################################################################### - -if __name__ == '__main__': - with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='hardware_emulation'): - # init data structures - N = 4096 - a = np.random.rand(1)[0].astype(np.float32) - x = np.random.rand(N).astype(np.float32) - y = np.random.rand(N).astype(np.float32) - result = np.zeros((N, )).astype(np.float32) - - # show initial values - print("a={}, x={}, y={}".format(a, x, y)) - - # Build the SDFG - sdfg = make_sdfg() - - # call program - sdfg(a=a, x=x, y=y, result=result, N=N) - - # show result - print("result={}".format(result)) - - # check result - expected = a * x + y - diff = np.linalg.norm(expected - result) / N - print("Difference:", diff) - assert diff <= 1e-5 diff --git a/samples/fpga/rtl/axpy_double_pump.py b/samples/fpga/rtl/axpy_double_pump.py deleted file mode 100644 index e87ddfd68b..0000000000 --- a/samples/fpga/rtl/axpy_double_pump.py +++ /dev/null @@ -1,456 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" - This sample shows the AXPY BLAS routine. It is implemented through Xilinx - IPs in order to utilize double pumping, which doubles the performance per - consumed FPGA resource. The double pumping operation is "inwards", which - means that the internal vectorization width of the core computation is half - that of the external vectorization width. This translates into utilizing half - the amount of internal computing resources, compared to a regular vectorized - implementetation. The block diagram of the design for a 32-bit floating-point - implementation using vectorization width 2 is: - - ap_aclk s_axis_y_in s_axis_x_in a - │ │ │ │ - │ │ │ │ - │ │ │ │ - ┌───────┼─────────┬────────┼─────────┐ │ │ - │ │ │ │ │ │ │ - │ │ │ ▼ │ ▼ │ - │ │ │ ┌────────────┐ │ ┌────────────┐ │ - │ │ └─►│ │ └─►│ │ │ - │ │ │ Clock sync │ │ Clock sync │ │ - │ │ ┌─►│ │ ┌─►│ │ │ - │ ▼ 300 MHz │ └─────┬──────┘ │ └─────┬──────┘ │ - │ ┌────────────┐ │ │ │ │ │ - │ │ Clock │ │ │ │ │ │ - │ │ │ ├────────┼─────────┤ │ │ - │ │ Multiplier │ │ │ │ │ │ - │ └─────┬──────┘ │ ▼ 64 bit │ ▼ 64 bit │ - │ │ 600 MHz │ ┌────────────┐ │ ┌────────────┐ │ - │ │ │ │ │ │ │ │ │ - │ └─────────┼─►│ Data issue │ └─►│ Data issue │ │ - │ │ │ │ │ │ │ - │ │ └─────┬──────┘ └─────┬──────┘ │ - │ │ │ 32 bit │ 32 bit │ - │ │ │ │ │ - │ │ │ │ │ - │ │ │ ▼ ▼ - │ │ │ ┌────────────┐ - │ │ │ │ │ - │ ├────────┼────────────────►│ Multiplier │ - │ │ │ │ │ - │ │ │ └─────┬──────┘ - │ │ │ │ - │ │ │ ┌──────────────┘ - │ │ │ │ - │ │ ▼ ▼ - │ │ ┌────────────┐ - │ │ │ │ - │ ├─────►│ Adder │ - │ │ │ │ - │ │ └─────┬──────┘ - │ │ │ - │ │ ▼ 32 bit - │ │ ┌─────────────┐ - │ │ │ │ - │ ├─────►│ Data packer │ - │ │ │ │ - │ │ └─────┬───────┘ - │ │ │ 64 bit - │ │ ▼ - │ │ ┌────────────┐ - │ └─────►│ │ - │ │ Clock sync │ - └───────────────────────►│ │ - └─────┬──────┘ - │ - ▼ - m_axis_result_out - - It is intended for running hardware_emulation or hardware xilinx targets. -""" - -import dace -import numpy as np - -# add symbol -N = dace.symbol('N') - - -def make_sdfg(veclen=2): - # Double check that the provided veclen is divisible by 2 - assert veclen >= 2 and veclen % 2 == 0 - - # add sdfg - sdfg = dace.SDFG(f'axpy_double_pump_v{veclen}') - - # add state - state = sdfg.add_state('device_state') - - # add parameter - sdfg.add_constant('VECLEN', veclen) - - # add arrays - sdfg.add_scalar('a', dtype=dace.float32, storage=dace.StorageType.FPGA_Global) - sdfg.add_array('x', [N // veclen], dtype=dace.vector(dace.float32, veclen), storage=dace.StorageType.CPU_Heap) - sdfg.add_array('y', [N // veclen], dtype=dace.vector(dace.float32, veclen), storage=dace.StorageType.CPU_Heap) - sdfg.add_array('result', [N // veclen], dtype=dace.vector(dace.float32, veclen), storage=dace.StorageType.CPU_Heap) - sdfg.add_array('fpga_x', [N // veclen], - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Global) - sdfg.add_array('fpga_y', [N // veclen], - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Global) - sdfg.add_array('fpga_result', [N // veclen], - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Global) - - # add streams - sdfg.add_stream('x_stream', - buffer_size=32, - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Local) - sdfg.add_stream('y_stream', - buffer_size=32, - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Local) - sdfg.add_stream('result_stream', - buffer_size=32, - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Local) - - # add custom rtl tasklet - rtl_tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a_in', 'x_in', 'y_in'}, - outputs={'result_out'}, - code=''' - /* - Convention: - |--------------------------------------------------------| - | | - -->| ap_aclk (slow clock input) | - -->| ap_areset (slow reset input, rst on high) | - -->| ap_aclk_2 (fast clock input) | - -->| ap_areset_2 (fast reset input, rst on high) | - -->| ap_start (start pulse from host) | - <--| ap_done (tells the host that the kernel is done) | - | | - | For each input: For each output: | - | | - -->| s_axis_{input}_tvalid reg m_axis_{output}_tvalid |--> - -->| s_axis_{input}_tdata reg m_axis_{output}_tdata |--> - <--| reg s_axis_{input}_tready m_axis_{output}_tready |<-- - -->| s_axis_{input}_tkeep reg m_axis_{output}_tkeep |--> - -->| s_axis_{input}_tlast reg m_axis_{output}_tlast |--> - | | - |--------------------------------------------------------| - */ - assign ap_done = 1; // free-running kernel - wire ap_areset_n = ~ap_areset; - wire ap_areset_n_2 = ~ap_areset_2; - - wire axis_x_clk_data_tvalid; - wire [VECLEN-1:0][31:0] axis_x_clk_data_tdata; - wire axis_x_clk_data_tready; - - slow_to_fast_clk clock_sync_x ( - .s_axis_aclk(ap_aclk), - .s_axis_aresetn(ap_areset_n), - .m_axis_aclk(ap_aclk_2), - .m_axis_aresetn(ap_areset_n_2), - - .s_axis_tvalid(s_axis_x_in_tvalid), - .s_axis_tdata( s_axis_x_in_tdata), - .s_axis_tready(s_axis_x_in_tready), - - .m_axis_tvalid(axis_x_clk_data_tvalid), - .m_axis_tdata( axis_x_clk_data_tdata), - .m_axis_tready(axis_x_clk_data_tready) - ); - - wire axis_y_clk_data_tvalid; - wire [VECLEN-1:0][31:0] axis_y_clk_data_tdata; - wire axis_y_clk_data_tready; - - slow_to_fast_clk clock_sync_y ( - .s_axis_aclk(ap_aclk), - .s_axis_aresetn(ap_areset_n), - .m_axis_aclk(ap_aclk_2), - .m_axis_aresetn(ap_areset_n_2), - - .s_axis_tvalid(s_axis_y_in_tvalid), - .s_axis_tdata( s_axis_y_in_tdata), - .s_axis_tready(s_axis_y_in_tready), - - .m_axis_tvalid(axis_y_clk_data_tvalid), - .m_axis_tdata( axis_y_clk_data_tdata), - .m_axis_tready(axis_y_clk_data_tready) - ); - - wire axis_x_data_fl_tvalid; - wire [(VECLEN/2)-1:0][31:0] axis_x_data_fl_tdata; - wire [(VECLEN/2)-1:0] axis_x_data_fl_tready; - - slow_to_fast_data data_issue_x ( - .aclk(ap_aclk_2), - .aresetn(ap_areset_n_2), - - .s_axis_tvalid(axis_x_clk_data_tvalid), - .s_axis_tdata( axis_x_clk_data_tdata), - .s_axis_tready(axis_x_clk_data_tready), - - .m_axis_tvalid(axis_x_data_fl_tvalid), - .m_axis_tdata( axis_x_data_fl_tdata), - .m_axis_tready(&axis_x_data_fl_tready) - ); - - wire axis_y_data_fl_tvalid; - wire [(VECLEN/2)-1:0][31:0] axis_y_data_fl_tdata; - wire [(VECLEN/2)-1:0] axis_y_data_fl_tready; - - slow_to_fast_data data_issue_y ( - .aclk(ap_aclk_2), - .aresetn(ap_areset_n_2), - - .s_axis_tvalid(axis_y_clk_data_tvalid), - .s_axis_tdata( axis_y_clk_data_tdata), - .s_axis_tready(axis_y_clk_data_tready), - - .m_axis_tvalid(axis_y_data_fl_tvalid), - .m_axis_tdata( axis_y_data_fl_tdata), - .m_axis_tready(&axis_y_data_fl_tready) - ); - - wire [(VECLEN/2)-1:0] axis_ax_tvalid; - wire [(VECLEN/2)-1:0][31:0] axis_ax_tdata; - wire [(VECLEN/2)-1:0] axis_ax_tready; - - generate - for (genvar i = 0; i < (VECLEN/2); i++) begin - floating_point_mult multiplier ( - .aclk(ap_aclk_2), - - .s_axis_a_tvalid(scalars_valid), - .s_axis_a_tdata(a_in), - //.s_axis_a_tready(), - - .s_axis_b_tvalid(axis_x_data_fl_tvalid), - .s_axis_b_tdata( axis_x_data_fl_tdata[i]), - .s_axis_b_tready(axis_x_data_fl_tready[i]), - - .m_axis_result_tvalid(axis_ax_tvalid[i]), - .m_axis_result_tdata( axis_ax_tdata[i]), - .m_axis_result_tready(axis_ax_tready[i]) - ); - end - endgenerate - - wire [(VECLEN/2)-1:0] axis_result_tvalid; - wire [(VECLEN/2)-1:0][31:0] axis_result_tdata; - wire axis_result_tready; - - generate - for (genvar i = 0; i < (VECLEN/2); i++) begin - floating_point_add adder ( - .aclk(ap_aclk_2), - - .s_axis_a_tvalid(axis_ax_tvalid[i]), - .s_axis_a_tdata( axis_ax_tdata[i]), - .s_axis_a_tready(axis_ax_tready[i]), - - .s_axis_b_tvalid(axis_y_data_fl_tvalid), - .s_axis_b_tdata( axis_y_data_fl_tdata[i]), - .s_axis_b_tready(axis_y_data_fl_tready[i]), - - .m_axis_result_tvalid(axis_result_tvalid[i]), - .m_axis_result_tdata( axis_result_tdata[i]), - .m_axis_result_tready(axis_result_tready) - ); - end - endgenerate - - wire axis_result_data_clk_tvalid; - wire [VECLEN-1:0][31:0] axis_result_data_clk_tdata; - wire axis_result_data_clk_tready; - - fast_to_slow_data data_packer ( - .aclk(ap_aclk_2), - .aresetn(ap_areset_n_2), - - .s_axis_tvalid(&axis_result_tvalid), - .s_axis_tdata( axis_result_tdata), - .s_axis_tready(axis_result_tready), - - .m_axis_tvalid(axis_result_data_clk_tvalid), - .m_axis_tdata( axis_result_data_clk_tdata), - .m_axis_tready(axis_result_data_clk_tready) - ); - - fast_to_slow_clk clock_sync_result ( - .s_axis_aclk(ap_aclk_2), - .s_axis_aresetn(ap_areset_n_2), - .m_axis_aclk(ap_aclk), - .m_axis_aresetn(ap_areset_n), - - .s_axis_tvalid(axis_result_data_clk_tvalid), - .s_axis_tdata( axis_result_data_clk_tdata), - .s_axis_tready(axis_result_data_clk_tready), - - .m_axis_tvalid(m_axis_result_out_tvalid), - .m_axis_tdata( m_axis_result_out_tdata), - .m_axis_tready(m_axis_result_out_tready) - ); - ''', - language=dace.Language.SystemVerilog) - - rtl_tasklet.add_ip_core('slow_to_fast_clk', 'axis_clock_converter', 'xilinx.com', '1.1', { - "CONFIG.TDATA_NUM_BYTES": f'{4*veclen}', - "CONFIG.SYNCHRONIZATION_STAGES": "8" - }) - - rtl_tasklet.add_ip_core('slow_to_fast_data', 'axis_dwidth_converter', 'xilinx.com', '1.1', { - "CONFIG.S_TDATA_NUM_BYTES": f'{4*veclen}', - "CONFIG.M_TDATA_NUM_BYTES": f'{4*(veclen//2)}' - }) - - rtl_tasklet.add_ip_core( - 'floating_point_mult', 'floating_point', 'xilinx.com', '7.1', { - "CONFIG.Operation_Type": "Multiply", - "CONFIG.C_Mult_Usage": "Max_Usage", - "CONFIG.Axi_Optimize_Goal": "Performance", - "CONFIG.A_Precision_Type": "Single", - "CONFIG.C_A_Exponent_Width": "8", - "CONFIG.C_A_Fraction_Width": "24", - "CONFIG.Result_Precision_Type": "Single", - "CONFIG.C_Result_Exponent_Width": "8", - "CONFIG.C_Result_Fraction_Width": "24", - "CONFIG.C_Latency": "9", - "CONFIG.C_Rate": "1" - }) - - rtl_tasklet.add_ip_core('floating_point_add', 'floating_point', 'xilinx.com', '7.1', { - "CONFIG.Add_Sub_Value": "Add", - "CONFIG.Axi_Optimize_Goal": "Performance", - "CONFIG.C_Latency": "14" - }) - - rtl_tasklet.add_ip_core('fast_to_slow_data', 'axis_dwidth_converter', 'xilinx.com', '1.1', { - "CONFIG.S_TDATA_NUM_BYTES": f'{4*(veclen//2)}', - "CONFIG.M_TDATA_NUM_BYTES": f'{4*veclen}' - }) - - rtl_tasklet.add_ip_core('fast_to_slow_clk', 'axis_clock_converter', 'xilinx.com', '1.1', { - "CONFIG.TDATA_NUM_BYTES": f'{4*veclen}', - "CONFIG.SYNCHRONIZATION_STAGES": "8" - }) - - # add read and write tasklets - read_x = state.add_tasklet('read_x', {'inp'}, {'out'}, 'out = inp') - read_y = state.add_tasklet('read_y', {'inp'}, {'out'}, 'out = inp') - write_result = state.add_tasklet('write_result', {'inp'}, {'out'}, 'out = inp') - - # add read and write maps - read_x_entry, read_x_exit = state.add_map('read_x_map', - dict(i='0:N//VECLEN'), - schedule=dace.ScheduleType.FPGA_Device) - read_y_entry, read_y_exit = state.add_map('read_y_map', - dict(i='0:N//VECLEN'), - schedule=dace.ScheduleType.FPGA_Device) - write_result_entry, write_result_exit = state.add_map('write_result_map', - dict(i='0:N//VECLEN'), - schedule=dace.ScheduleType.FPGA_Device) - - # add read_a memlets and access nodes - read_x_inp = state.add_read('fpga_x') - read_x_out = state.add_write('x_stream') - state.add_memlet_path(read_x_inp, read_x_entry, read_x, dst_conn='inp', memlet=dace.Memlet('fpga_x[i]')) - state.add_memlet_path(read_x, read_x_exit, read_x_out, src_conn='out', memlet=dace.Memlet('x_stream[0]')) - - read_y_inp = state.add_read('fpga_y') - read_y_out = state.add_write('y_stream') - state.add_memlet_path(read_y_inp, read_y_entry, read_y, dst_conn='inp', memlet=dace.Memlet('fpga_y[i]')) - state.add_memlet_path(read_y, read_y_exit, read_y_out, src_conn='out', memlet=dace.Memlet('y_stream[0]')) - - # add tasklet memlets - a = state.add_read('a') - x = state.add_read('x_stream') - y = state.add_read('y_stream') - result = state.add_write('result_stream') - state.add_memlet_path(a, rtl_tasklet, dst_conn='a_in', memlet=dace.Memlet('a[0]')) - state.add_memlet_path(x, rtl_tasklet, dst_conn='x_in', memlet=dace.Memlet('x_stream[0]')) - state.add_memlet_path(y, rtl_tasklet, dst_conn='y_in', memlet=dace.Memlet('y_stream[0]')) - state.add_memlet_path(rtl_tasklet, result, src_conn='result_out', memlet=dace.Memlet('result_stream[0]')) - - # add write_c memlets and access nodes - write_result_inp = state.add_read('result_stream') - write_result_out = state.add_write('fpga_result') - state.add_memlet_path(write_result_inp, - write_result_entry, - write_result, - dst_conn='inp', - memlet=dace.Memlet('result_stream[0]')) - state.add_memlet_path(write_result, - write_result_exit, - write_result_out, - src_conn='out', - memlet=dace.Memlet('fpga_result[i]')) - - # add copy to device state - copy_to_device = sdfg.add_state('copy_to_device') - cpu_x = copy_to_device.add_read('x') - cpu_y = copy_to_device.add_read('y') - dev_x = copy_to_device.add_write('fpga_x') - dev_y = copy_to_device.add_write('fpga_y') - copy_to_device.add_memlet_path(cpu_x, dev_x, memlet=dace.Memlet('x[0:N//VECLEN]')) - copy_to_device.add_memlet_path(cpu_y, dev_y, memlet=dace.Memlet('y[0:N//VECLEN]')) - sdfg.add_edge(copy_to_device, state, dace.InterstateEdge()) - - # add copy to host state - copy_to_host = sdfg.add_state('copy_to_host') - dev_result = copy_to_host.add_read('fpga_result') - cpu_result = copy_to_host.add_write('result') - copy_to_host.add_memlet_path(dev_result, cpu_result, memlet=dace.Memlet('result[0:N//VECLEN]')) - sdfg.add_edge(state, copy_to_host, dace.InterstateEdge()) - - # validate sdfg - sdfg.validate() - - return sdfg - - -###################################################################### - -if __name__ == '__main__': - with dace.config.set_temporary('compiler', 'xilinx', 'frequency', value='"0:300\\|1:600"'): - with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='hardware_emulation'): - # init data structures - N = 4096 - a = np.random.rand(1)[0].astype(np.float32) - x = np.random.rand(N).astype(np.float32) - y = np.random.rand(N).astype(np.float32) - result = np.zeros((N, )).astype(np.float32) - - # show initial values - print("a={}, x={}, y={}".format(a, x, y)) - - # Build the SDFG - sdfg = make_sdfg() - - # call program - sdfg(a=a, x=x, y=y, result=result, N=N) - - # show result - print("result={}".format(result)) - - # check result - expected = a * x + y - diff = np.linalg.norm(expected - result) / N - print("Difference:", diff) - - assert diff <= 1e-5 diff --git a/samples/fpga/rtl/fladd.py b/samples/fpga/rtl/fladd.py deleted file mode 100644 index c2e2fdcc81..0000000000 --- a/samples/fpga/rtl/fladd.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" - This sample shows how to utilize an IP core in an RTL tasklet. This is done - through the vector add problem, which adds two floating point vectors - together. - - It is intended for running hardware_emulation or hardware xilinx targets. -""" - -import dace -import numpy as np - -# add symbol -N = dace.symbol('N') - -# add sdfg -sdfg = dace.SDFG('fladd') - -# add state -state = sdfg.add_state('device_state') - -# add parameter -veclen = 1 -sdfg.add_constant('VECLEN', veclen) - -# add arrays -sdfg.add_array('A', [N // veclen], dtype=dace.vector(dace.float32, veclen), storage=dace.StorageType.CPU_Heap) -sdfg.add_array('B', [N // veclen], dtype=dace.vector(dace.float32, veclen), storage=dace.StorageType.CPU_Heap) -sdfg.add_array('C', [N // veclen], dtype=dace.vector(dace.float32, veclen), storage=dace.StorageType.CPU_Heap) -sdfg.add_array('fpga_A', [N // veclen], - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Global) -sdfg.add_array('fpga_B', [N // veclen], - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Global) -sdfg.add_array('fpga_C', [N // veclen], - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Global) - -# add streams -sdfg.add_stream('A_stream', - buffer_size=32, - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Local) -sdfg.add_stream('B_stream', - buffer_size=32, - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Local) -sdfg.add_stream('C_stream', - buffer_size=32, - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Local) - -# add custom rtl tasklet -rtl_tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a', 'b'}, - outputs={'c'}, - code=''' - /* - Convention: - |--------------------------------------------------------| - | | - -->| ap_aclk (clock input) | - -->| ap_areset (reset input, rst on high) | - -->| ap_start (start pulse from host) | - <--| ap_done (tells the host that the kernel is done) | - | | - | For each input: For each output: | - | | - -->| s_axis_{input}_tvalid reg m_axis_{output}_tvalid |--> - -->| s_axis_{input}_tdata reg m_axis_{output}_tdata |--> - <--| reg s_axis_{input}_tready m_axis_{output}_tready |<-- - -->| s_axis_{input}_tkeep reg m_axis_{output}_tkeep |--> - -->| s_axis_{input}_tlast reg m_axis_{output}_tlast |--> - | | - |--------------------------------------------------------| - */ - - assign ap_done = 1; // free-running kernel - - wire ap_aresetn = ~ap_areset; // IP core is active-low reset - - floating_point_add add( - .aclk(ap_aclk), - .aresetn(ap_aresetn), - - .s_axis_a_tvalid(s_axis_a_tvalid), - .s_axis_a_tdata(s_axis_a_tdata), - .s_axis_a_tready(s_axis_a_tready), - - .s_axis_b_tvalid(s_axis_b_tvalid), - .s_axis_b_tdata(s_axis_b_tdata), - .s_axis_b_tready(s_axis_b_tready), - - .m_axis_result_tvalid(m_axis_c_tvalid), - .m_axis_result_tdata(m_axis_c_tdata), - .m_axis_result_tready(m_axis_c_tready) - ); - ''', - language=dace.Language.SystemVerilog) - -rtl_tasklet.add_ip_core('floating_point_add', 'floating_point', 'xilinx.com', '7.1', { - 'CONFIG.Add_Sub_Value': 'Add', - 'CONFIG.Has_ARESETn': 'true' -}) - -# add read and write tasklets -read_a = state.add_tasklet('read_a', {'inp'}, {'out'}, 'out = inp') -read_b = state.add_tasklet('read_b', {'inp'}, {'out'}, 'out = inp') -write_c = state.add_tasklet('write_c', {'inp'}, {'out'}, 'out = inp') - -# add read and write maps -read_a_entry, read_a_exit = state.add_map('read_a_map', dict(i='0:N//VECLEN'), schedule=dace.ScheduleType.FPGA_Device) -read_b_entry, read_b_exit = state.add_map('read_b_map', dict(i='0:N//VECLEN'), schedule=dace.ScheduleType.FPGA_Device) -write_c_entry, write_c_exit = state.add_map('write_c_map', - dict(i='0:N//VECLEN'), - schedule=dace.ScheduleType.FPGA_Device) - -# add read_a memlets and access nodes -read_a_inp = state.add_read('fpga_A') -read_a_out = state.add_write('A_stream') -state.add_memlet_path(read_a_inp, read_a_entry, read_a, dst_conn='inp', memlet=dace.Memlet('fpga_A[i]')) -state.add_memlet_path(read_a, read_a_exit, read_a_out, src_conn='out', memlet=dace.Memlet('A_stream[0]')) - -read_b_inp = state.add_read('fpga_B') -read_b_out = state.add_write('B_stream') -state.add_memlet_path(read_b_inp, read_b_entry, read_b, dst_conn='inp', memlet=dace.Memlet('fpga_B[i]')) -state.add_memlet_path(read_b, read_b_exit, read_b_out, src_conn='out', memlet=dace.Memlet('B_stream[0]')) - -# add tasklet memlets -A = state.add_read('A_stream') -B = state.add_read('B_stream') -C = state.add_write('C_stream') -state.add_memlet_path(A, rtl_tasklet, dst_conn='a', memlet=dace.Memlet('A_stream[0]')) -state.add_memlet_path(B, rtl_tasklet, dst_conn='b', memlet=dace.Memlet('B_stream[0]')) -state.add_memlet_path(rtl_tasklet, C, src_conn='c', memlet=dace.Memlet('C_stream[0]')) - -# add write_c memlets and access nodes -write_c_inp = state.add_read('C_stream') -write_c_out = state.add_write('fpga_C') -state.add_memlet_path(write_c_inp, write_c_entry, write_c, dst_conn='inp', memlet=dace.Memlet('C_stream[0]')) -state.add_memlet_path(write_c, write_c_exit, write_c_out, src_conn='out', memlet=dace.Memlet('fpga_C[i]')) - -# add copy to device state -copy_to_device = sdfg.add_state('copy_to_device') -cpu_a = copy_to_device.add_read('A') -cpu_b = copy_to_device.add_read('B') -dev_a = copy_to_device.add_write('fpga_A') -dev_b = copy_to_device.add_write('fpga_B') -copy_to_device.add_memlet_path(cpu_a, dev_a, memlet=dace.Memlet('A[0:N//VECLEN]')) -copy_to_device.add_memlet_path(cpu_b, dev_b, memlet=dace.Memlet('B[0:N//VECLEN]')) -sdfg.add_edge(copy_to_device, state, dace.InterstateEdge()) - -# add copy to host state -copy_to_host = sdfg.add_state('copy_to_host') -dev_c = copy_to_host.add_read('fpga_C') -cpu_c = copy_to_host.add_write('C') -copy_to_host.add_memlet_path(dev_c, cpu_c, memlet=dace.Memlet('C[0:N//VECLEN]')) -sdfg.add_edge(state, copy_to_host, dace.InterstateEdge()) - -# validate sdfg -sdfg.validate() - -###################################################################### - -if __name__ == '__main__': - with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='hardware_emulation'): - # init data structures - N = 8192 - a = np.random.randint(0, 100, N).astype(np.float32) - b = np.random.randint(0, 100, N).astype(np.float32) - c = np.zeros((N // veclen, )).astype(np.float32) - print(a.shape, b.shape, c.shape) - - # show initial values - print("a={}, b={}".format(a, b)) - - # call program - sdfg(A=a, B=b, C=c, N=N) - - # show result - print("a={}, b={}, c={}".format(a, b, c)) - - # check result - expected = a + b - diff = np.linalg.norm(expected - c) / N - print("Difference:", diff) - assert diff <= 1e-5 diff --git a/samples/fpga/rtl/pipeline.py b/samples/fpga/rtl/pipeline.py deleted file mode 100644 index 4fe5d3c74b..0000000000 --- a/samples/fpga/rtl/pipeline.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" - This sample shows a DEPTH deep pipeline, where each stage adds 1 to the - integer input stream. - - It is intended for running hardware_emulation or hardware xilinx targets. -""" - -import dace -import numpy as np - -# add symbols -N = dace.symbol('N') - -# add sdfg -sdfg = dace.SDFG('pipeline') - -# add state -state = sdfg.add_state('device_state') - -# add constants -depth = 10 -sdfg.add_constant('DEPTH', depth) - -# add arrays -sdfg.add_array('A', [N], dtype=dace.int32, storage=dace.StorageType.CPU_Heap) -sdfg.add_array('B', [N], dtype=dace.int32, storage=dace.StorageType.CPU_Heap) -sdfg.add_array('fpga_A', [N], dtype=dace.int32, transient=True, storage=dace.StorageType.FPGA_Global) -sdfg.add_array('fpga_B', [N], dtype=dace.int32, transient=True, storage=dace.StorageType.FPGA_Global) - -# add streams -sdfg.add_stream('A_stream', dtype=dace.int32, transient=True, storage=dace.StorageType.FPGA_Local) -sdfg.add_stream('B_stream', dtype=dace.int32, transient=True, storage=dace.StorageType.FPGA_Local) - -# add custom rtl tasklet -rtl_tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a'}, - outputs={'b'}, - code=''' - /* - Convention: - |--------------------------------------------------------| - | | - -->| ap_aclk (clock input) | - -->| ap_areset (reset input, rst on high) | - -->| ap_start (start pulse from host) | - <--| ap_done (tells the host that the kernel is done) | - | | - | For each input: For each output: | - | | - -->| s_axis_{input}_tvalid reg m_axis_{output}_tvalid |--> - -->| s_axis_{input}_tdata reg m_axis_{output}_tdata |--> - <--| reg s_axis_{input}_tready m_axis_{output}_tready |<-- - -->| s_axis_{input}_tkeep reg m_axis_{output}_tkeep |--> - -->| s_axis_{input}_tlast reg m_axis_{output}_tlast |--> - | | - |--------------------------------------------------------| - */ - - assign ap_done = 1; // free-running kernel - - reg [DEPTH-1:0] tvalids; - reg [31:0] tdatas [DEPTH-1:0]; - reg [DEPTH-1:0] treadys; - integer i; - - always @(posedge ap_aclk) begin - if (ap_areset) begin - for (i = 0; i < DEPTH; i = i + 1) begin - tvalids[i] = 0; - tdatas[i] = 0; - treadys[i] = 1; - end - s_axis_a_tready = 1; - m_axis_b_tvalid = 0; - m_axis_b_tdata = 0; - end else begin - // Handle m_axis - if (!m_axis_b_tvalid || (m_axis_b_tvalid && m_axis_b_tready)) begin - m_axis_b_tvalid = tvalids[DEPTH-1]; - m_axis_b_tdata = tdatas[DEPTH-1]; - tvalids[DEPTH-1] = 0; - tvalids[DEPTH-1] = 0; - end - treadys[DEPTH-1] = !m_axis_b_tvalid; - - // Handle intermediates - for (i = DEPTH-1; i > 0; i = i - 1) begin - if (tvalids[i-1] && treadys[i-1]) begin - tvalids[i] = tvalids[i-1]; - tdatas[i] = tdatas[i-1] + 1; - tvalids[i-1] = 0; - tdatas[i-1] = 0; - end - treadys[i-1] = !tvalids[i]; - end - - // Handle s_axis - if (s_axis_a_tvalid && s_axis_a_tready) begin - tvalids[0] = s_axis_a_tvalid; - tdatas[0] = s_axis_a_tdata + 1; - end - s_axis_a_tready = !tvalids[0]; - end - end - ''', - language=dace.Language.SystemVerilog) - -# add read and write tasklets -read_a = state.add_tasklet('read_a', {'inp'}, {'out'}, 'out = inp') -write_b = state.add_tasklet('write_b', {'inp'}, {'out'}, 'out = inp') - -# add read and write maps -read_a_entry, read_a_exit = state.add_map('read_a_map', dict(i='0:N'), schedule=dace.ScheduleType.FPGA_Device) -write_b_entry, write_b_exit = state.add_map('write_b_map', dict(i='0:N'), schedule=dace.ScheduleType.FPGA_Device) - -# add read_a memlets and access nodes -read_a_inp = state.add_read('fpga_A') -read_a_out = state.add_write('A_stream') -state.add_memlet_path(read_a_inp, read_a_entry, read_a, dst_conn='inp', memlet=dace.Memlet('fpga_A[i]')) -state.add_memlet_path(read_a, read_a_exit, read_a_out, src_conn='out', memlet=dace.Memlet('A_stream[0]')) - -# add tasklet memlets -A = state.add_read('A_stream') -B = state.add_write('B_stream') -state.add_memlet_path(A, rtl_tasklet, dst_conn='a', memlet=dace.Memlet('A_stream[0]')) -state.add_memlet_path(rtl_tasklet, B, src_conn='b', memlet=dace.Memlet('B_stream[0]')) - -# add write_b memlets and access nodes -write_b_inp = state.add_read('B_stream') -write_b_out = state.add_write('fpga_B') -state.add_memlet_path(write_b_inp, write_b_entry, write_b, dst_conn='inp', memlet=dace.Memlet('B_stream[0]')) -state.add_memlet_path(write_b, write_b_exit, write_b_out, src_conn='out', memlet=dace.Memlet('fpga_B[i]')) - -# add copy to device state -copy_to_device = sdfg.add_state('copy_to_device') -cpu_a = copy_to_device.add_read('A') -dev_a = copy_to_device.add_write('fpga_A') -copy_to_device.add_memlet_path(cpu_a, dev_a, memlet=dace.Memlet('A[0:N]')) -sdfg.add_edge(copy_to_device, state, dace.InterstateEdge()) - -# add copy to host state -copy_to_host = sdfg.add_state('copy_to_host') -dev_b = copy_to_host.add_read('fpga_B') -cpu_b = copy_to_host.add_write('B') -copy_to_host.add_memlet_path(dev_b, cpu_b, memlet=dace.Memlet('B[0:N]')) -sdfg.add_edge(state, copy_to_host, dace.InterstateEdge()) - -# validate sdfg -sdfg.validate() - -###################################################################### - -if __name__ == '__main__': - with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='hardware_emulation'): - # init data structures - N = 8192 - a = np.random.randint(0, 100, N).astype(np.int32) - b = np.zeros((N, )).astype(np.int32) - - # show initial values - print("a={}, b={}".format(a, b)) - - # call program - sdfg(A=a, B=b, N=N) - - # show result - print("a={}, b={}".format(a, b)) - - # check result - for i in range(N): - assert b[i] == a[i] + depth diff --git a/samples/fpga/rtl/rtl_multi_tasklet.py b/samples/fpga/rtl/rtl_multi_tasklet.py deleted file mode 100644 index 4a4a09deec..0000000000 --- a/samples/fpga/rtl/rtl_multi_tasklet.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" - Two sequential RTL tasklets connected through a memlet. - - It is intended for running simulation xilinx targets. -""" - -import dace -import numpy as np - -# add sdfg -sdfg = dace.SDFG('rtl_multi_tasklet') - -# add state -state = sdfg.add_state() - -# add arrays -sdfg.add_array('A', [1], dtype=dace.int32) -sdfg.add_array('B', [1], dtype=dace.int32) -sdfg.add_array('C', [1], dtype=dace.int32) - -# add custom cpp tasklet -tasklet0 = state.add_tasklet(name='rtl_tasklet0', - inputs={'a'}, - outputs={'b'}, - code="""\ -typedef enum [1:0] {READY, BUSY, DONE} state_e; -state_e state; - -always@(posedge ap_aclk) begin - if (ap_areset) begin // case: reset - m_axis_b_tdata <= 0; - s_axis_a_tready <= 1'b1; - state <= READY; - end else if (s_axis_a_tvalid && state == READY) begin // case: load a - m_axis_b_tdata <= s_axis_a_tdata; - s_axis_a_tready <= 1'b0; - state <= BUSY; - end else if (m_axis_b_tdata < 80) // case: increment counter b - m_axis_b_tdata <= m_axis_b_tdata + 1; - else - m_axis_b_tdata <= m_axis_b_tdata; - state <= DONE; -end - -assign m_axis_b_tvalid = (m_axis_b_tdata >= 80) ? 1'b1:1'b0; -""", - language=dace.Language.SystemVerilog) - -tasklet1 = state.add_tasklet(name='rtl_tasklet1', - inputs={'b'}, - outputs={'c'}, - code="""\ -typedef enum [1:0] {READY, BUSY, DONE} state_e; -state_e state; - -always@(posedge ap_aclk) begin - if (ap_areset) begin // case: reset - m_axis_c_tdata <= 0; - s_axis_b_tready <= 1'b1; - state <= READY; - end else if (s_axis_b_tvalid && state == READY) begin // case: load a - m_axis_c_tdata <= s_axis_b_tdata; - s_axis_b_tready <= 1'b0; - state <= BUSY; - end else if (m_axis_c_tdata < 100) // case: increment counter b - m_axis_c_tdata <= m_axis_c_tdata + 1; - else - m_axis_c_tdata <= m_axis_c_tdata; - state <= DONE; -end - -assign m_axis_c_tvalid = (m_axis_c_tdata >= 100) ? 1'b1:1'b0; -""", - language=dace.Language.SystemVerilog) - -# add input/output array -A = state.add_read('A') -B_w = state.add_write('B') -B_r = state.add_read('B') -C = state.add_write('C') - -# connect input/output array with the tasklet -state.add_edge(A, None, tasklet0, 'a', dace.Memlet('A[0]')) -state.add_edge(tasklet0, 'b', B_w, None, dace.Memlet('B[0]')) -state.add_edge(B_r, None, tasklet1, 'b', dace.Memlet('B[0]')) -state.add_edge(tasklet1, 'c', C, None, dace.Memlet('C[0]')) - -# validate sdfg -sdfg.validate() - -###################################################################### - -if __name__ == '__main__': - with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='simulation'): - # init data structures - a = np.random.randint(0, 80, 1).astype(np.int32) - b = np.array([0]).astype(np.int32) - c = np.array([0]).astype(np.int32) - - # show initial values - print("a={}, b={}, c={}".format(a, b, c)) - - # call program - sdfg(A=a, B=b, C=c) - - # show result - print("a={}, b={}, c={}".format(a, b, c)) - - # check result - assert b == 80 - assert c == 100 diff --git a/samples/fpga/rtl/rtl_tasklet_parameter.py b/samples/fpga/rtl/rtl_tasklet_parameter.py deleted file mode 100644 index 112e88a6bf..0000000000 --- a/samples/fpga/rtl/rtl_tasklet_parameter.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" - Simple RTL tasklet with a single scalar input and a single scalar output. It increments b from a up to 100. - - It is intended for running simulation xilinx targets. -""" - -import dace -import numpy as np - -# add sdfg -sdfg = dace.SDFG('rtl_tasklet_parameter') - -# add state -state = sdfg.add_state() - -# add arrays -sdfg.add_array('A', [1], dtype=dace.int32) -sdfg.add_array('B', [1], dtype=dace.int32) - -# add parameters -sdfg.add_constant("MAX_VAL", 42) - -# add custom cpp tasklet -tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a'}, - outputs={'b'}, - code=''' - /* - Convention: - |---------------------------------------------------------------------| - -->| ap_aclk (clock input) | - -->| ap_areset (reset input, rst on high) | - | | - -->| {inputs} reg {outputs} |--> - | | - <--| s_axis_a_tready (ready for data) (data avail) m_axis_b_tvalid |--> - -->| s_axis_a_tvalid (new data avail) (data consumed) m_axis_b_tready |<-- - |---------------------------------------------------------------------| - */ - - typedef enum [1:0] {READY, BUSY, DONE} state_e; - state_e state; - - always@(posedge ap_aclk) begin - if (ap_areset) begin // case: reset - m_axis_b_tdata <= 0; - s_axis_a_tready <= 1'b1; - state <= READY; - end else if (s_axis_a_tvalid && state == READY) begin // case: load a - m_axis_b_tdata <= s_axis_a_tdata; - s_axis_a_tready <= 1'b0; - state <= BUSY; - end else if (m_axis_b_tdata < MAX_VAL) // case: increment counter b - m_axis_b_tdata <= m_axis_b_tdata + 1; - else - m_axis_b_tdata <= m_axis_b_tdata; - state <= DONE; - end - - assign m_axis_b_tvalid = (m_axis_b_tdata >= MAX_VAL) ? 1'b1:1'b0; - ''', - language=dace.Language.SystemVerilog) - -# add input/output array -A = state.add_read('A') -B = state.add_write('B') - -# connect input/output array with the tasklet -state.add_edge(A, None, tasklet, 'a', dace.Memlet('A[0]')) -state.add_edge(tasklet, 'b', B, None, dace.Memlet('B[0]')) - -# validate sdfg -sdfg.validate() - -###################################################################### - -if __name__ == '__main__': - with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='simulation'): - # init data structures - a = np.random.randint(0, 100, 1).astype(np.int32) - b = np.array([0]).astype(np.int32) - - # show initial values - print("a={}, b={}".format(a, b)) - - # call program - sdfg(A=a, B=b) - - # show result - print("a={}, b={}".format(a, b)) - - # check result - assert b == sdfg.constants["MAX_VAL"] diff --git a/samples/fpga/rtl/rtl_tasklet_pipeline.py b/samples/fpga/rtl/rtl_tasklet_pipeline.py deleted file mode 100644 index 3ef20cd03f..0000000000 --- a/samples/fpga/rtl/rtl_tasklet_pipeline.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" - Pipelined, AXI-handshake compliant example that increments b from a up to 100. - - It is intended for running simulation xilinx targets. -""" - -import dace -import numpy as np - -# add symbol -N = dace.symbol('N') - -# add sdfg -sdfg = dace.SDFG('rtl_tasklet_pipeline') - -# add state -state = sdfg.add_state() - -# define compile-time constant -sdfg.specialize(dict(N=4)) - -# disable sv debugging output -sdfg.add_constant("SYSTEMVERILOG_DEBUG", False) - -# add arrays -sdfg.add_array('A', [N], dtype=dace.int32) -sdfg.add_array('B', [N], dtype=dace.int32) - -# add custom cpp tasklet -tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a'}, - outputs={'b'}, - code=''' - /* - Convention: - |---------------------------------------------------------------------| - -->| ap_aclk (clock input) | - -->| ap_areset (reset input, rst on high) | - | | - -->| {inputs} reg {outputs} |--> - | | - <--| s_axis_a_tready (ready for data) (data avail) m_axis_b_tvalid |--> - -->| s_axis_a_tvalid (new data avail) (data consumed) m_axis_b_tready |<-- - |---------------------------------------------------------------------| - */ - - /**** - * Finite State Machine - *****/ - typedef enum [1:0] {READY, BUSY, DONE} state_e; - state_e state, state_next; - - always@(posedge ap_aclk) - begin - if(ap_areset) - state <= READY; - else - state <= state_next; - end - - always_comb - begin - state_next = state; - case(state) - READY: if(s_axis_a_tvalid) state_next = BUSY; - BUSY: if(m_axis_b_tdata >= 99) state_next = DONE; - DONE: if(m_axis_b_tready) state_next = READY; - default: state_next = state; - endcase - end - - - /*********** - * Control Logic - ************/ - always_comb - begin - // init default value - s_axis_a_tready = 0; - m_axis_b_tvalid = 0; - // set actual value - case(state) - READY: s_axis_a_tready = 1; - DONE: m_axis_b_tvalid = 1; - default:; - endcase - end - - /**** - * Data Path - ****/ - always@(posedge ap_aclk) - begin - case(state) - READY: if(s_axis_a_tvalid) m_axis_b_tdata <= s_axis_a_tdata; - BUSY: m_axis_b_tdata <= m_axis_b_tdata + 1; - DONE: m_axis_b_tdata <= m_axis_b_tdata; - default: m_axis_b_tdata <= m_axis_b_tdata; - endcase - end - - /***** - * DEBUG - *****/ - always@(posedge ap_aclk) - begin - if(SYSTEMVERILOG_DEBUG) - begin - case(state) - READY: $display("READY"); - BUSY: $display("BUSY"); - DONE: $display("DONE"); - default: $display("Undefined State"); - endcase - end - end - ''', - language=dace.Language.SystemVerilog) - -# add input/output array -A = state.add_read('A') -B = state.add_write('B') - -# connect input/output array with the tasklet -state.add_edge(A, None, tasklet, 'a', dace.Memlet('A[0:N-1]')) -state.add_edge(tasklet, 'b', B, None, dace.Memlet('B[0:N-1]')) - -# validate sdfg -sdfg.validate() - -###################################################################### - -if __name__ == '__main__': - with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='simulation'): - # init data structures - num_elements = dace.symbolic.evaluate(N, sdfg.constants) - a = np.random.randint(0, 100, num_elements).astype(np.int32) - b = np.array([0] * num_elements).astype(np.int32) - - # show initial values - print("a={}, b={}".format(a, b)) - - # call program - sdfg(A=a, B=b) - - # show result - print("a={}, b={}".format(a, b)) - - assert b[ - 0] == 100 # TODO: implement detection of #elements to process, s.t. we can extend the assertion to the whole array - assert np.all(map((lambda x: x == 0), b[1:-1])) # should still be at the init value (for the moment) diff --git a/samples/fpga/rtl/rtl_tasklet_scalar.py b/samples/fpga/rtl/rtl_tasklet_scalar.py deleted file mode 100644 index cf8d53ec91..0000000000 --- a/samples/fpga/rtl/rtl_tasklet_scalar.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" - Simple RTL tasklet with a single scalar input and a single scalar output. It increments b from a up to 100. - - It is intended for running simulation xilinx targets. -""" - -import dace -import numpy as np - -# add sdfg -sdfg = dace.SDFG('rtl_tasklet_scalar') - -# add state -state = sdfg.add_state() - -# add arrays -sdfg.add_array('A', [1], dtype=dace.int32) -sdfg.add_array('B', [1], dtype=dace.int32) - -# add custom cpp tasklet -tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a'}, - outputs={'b'}, - code=''' - /* - Convention: - |--------------------------------------------------------| - | | - -->| ap_aclk (clock input) | - -->| ap_areset (reset input, rst on high) | - | | - | For each input: For each output: | - | | - -->| s_axis_{input}_tvalid reg m_axis_{output}_tvalid |--> - -->| s_axis_{input}_tdata reg m_axis_{output}_tdata |--> - <--| reg s_axis_{input}_tready m_axis_{output}_tready |<-- - -->| s_axis_{input}_tkeep reg m_axis_{output}_tkeep |--> - -->| s_axis_{input}_tlast reg m_axis_{output}_tlast |--> - | | - |--------------------------------------------------------| - */ - - typedef enum logic [1:0] {READY, BUSY, DONE} state_e; - state_e state; - - always@(posedge ap_aclk) begin - if (ap_areset) begin // case: reset - m_axis_b_tdata <= 0; - s_axis_a_tready <= 1'b1; - state <= READY; - end else if (s_axis_a_tvalid && state == READY) begin // case: load a - m_axis_b_tdata <= s_axis_a_tdata; - s_axis_a_tready <= 1'b0; - state <= BUSY; - end else if (m_axis_b_tdata < 100) // case: increment counter b - m_axis_b_tdata <= m_axis_b_tdata + 1; - else begin - m_axis_b_tdata <= m_axis_b_tdata; - state <= DONE; - end - end - - assign m_axis_b_tvalid = (m_axis_b_tdata >= 100) ? 1'b1:1'b0; - ''', - language=dace.Language.SystemVerilog) - -# add input/output array -A = state.add_read('A') -B = state.add_write('B') - -# connect input/output array with the tasklet -state.add_memlet_path(A, tasklet, dst_conn='a', memlet=dace.Memlet('A[0]')) -state.add_memlet_path(tasklet, B, src_conn='b', memlet=dace.Memlet('B[0]')) - -# validate sdfg -sdfg.validate() - -###################################################################### - -if __name__ == '__main__': - with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='simulation'): - # init data structures - a = np.random.randint(0, 100, 1).astype(np.int32) - b = np.array([0]).astype(np.int32) - - # show initial values - print("a={}, b={}".format(a, b)) - - # call program - sdfg(A=a, B=b) - - # show result - print("a={}, b={}".format(a, b)) - - # check result - assert b == 100 diff --git a/samples/fpga/rtl/rtl_tasklet_vector.py b/samples/fpga/rtl/rtl_tasklet_vector.py deleted file mode 100644 index 9015b4f35e..0000000000 --- a/samples/fpga/rtl/rtl_tasklet_vector.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" - RTL tasklet with a vector input of 4 int32 (width=128bits) and a single scalar output. It increments b from a[31:0] up to 100. - - It is intended for running simulation xilinx targets. -""" - -import dace -import numpy as np - -# add symbol -WIDTH = dace.symbol('WIDTH') - -# add sdfg -sdfg = dace.SDFG('rtl_tasklet_vector') - -# define compile-time constant -sdfg.specialize(dict(WIDTH=4)) - -# add state -state = sdfg.add_state() - -# add arrays -sdfg.add_array('A', [WIDTH], dtype=dace.int32) -sdfg.add_array('B', [1], dtype=dace.int32) - -# add custom cpp tasklet -tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a': dace.vector(dace.int32, WIDTH)}, - outputs={'b'}, - code=''' - /* - Convention: - |--------------------------------------------------------------------| - -->| ap_aclk (clock input) | - -->| ap_areset (reset input, rst on high) | - | | - -->| {inputs} reg {outputs} |--> - | | - <--| s_axis_a_tready (ready for data) (data avail) m_axis_b_tvalid|--> - -->| s_axis_a_tvalid (new data avail) (data consumed) m_axis_b_tready|<-- - |--------------------------------------------------------------------| - */ - - typedef enum [1:0] {READY, BUSY, DONE} state_e; - state_e state; - - always@(posedge ap_aclk) begin - if (ap_areset) begin // case: reset - m_axis_b_tdata <= 0; - s_axis_a_tready <= 1'b1; - state <= READY; - end else if (s_axis_a_tvalid && state == READY) begin // case: load a - m_axis_b_tdata <= s_axis_a_tdata[0]; - s_axis_a_tready <= 1'b0; - state <= BUSY; - end else if (m_axis_b_tdata < s_axis_a_tdata[0] + s_axis_a_tdata[1] && state == BUSY) begin // case: increment counter b - m_axis_b_tdata <= m_axis_b_tdata + 1; - end else if (state == BUSY) begin - m_axis_b_tdata <= m_axis_b_tdata; - state <= DONE; - end - end - - assign m_axis_b_tvalid = (m_axis_b_tdata >= s_axis_a_tdata[0] + s_axis_a_tdata[1] && (state == BUSY || state == DONE)) ? 1'b1:1'b0; - ''', - language=dace.Language.SystemVerilog) - -# add input/output array -A = state.add_read('A') -B = state.add_write('B') - -# connect input/output array with the tasklet -state.add_edge(A, None, tasklet, 'a', dace.Memlet('A[0:WIDTH-1]')) -state.add_edge(tasklet, 'b', B, None, dace.Memlet('B[0]')) - -# validate sdfg -sdfg.validate() - -###################################################################### - -if __name__ == '__main__': - with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='simulation'): - # init data structures - a = np.random.randint(0, 100, dace.symbolic.evaluate(WIDTH, sdfg.constants)).astype(np.int32) - b = np.array([0]).astype(np.int32) - - # show initial values - print("a={}, b={}".format(a, b)) - - # call program - sdfg(A=a, B=b) - - # show result - print("a={}, b={}".format(a, b)) - - # check result - assert b == a[0] + a[1] diff --git a/samples/fpga/spmv_fpga_stream.py b/samples/fpga/spmv_fpga_stream.py deleted file mode 100644 index 7661873c66..0000000000 --- a/samples/fpga/spmv_fpga_stream.py +++ /dev/null @@ -1,506 +0,0 @@ -# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. -from __future__ import print_function - -import argparse -import dace -import math -import numpy as np -import scipy -import pdb -import select -import sys - -from dace.sdfg import SDFG, InterstateEdge -from dace.memlet import Memlet -from dace.dtypes import AllocationLifetime, StorageType, ScheduleType -from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion - -cols = dace.symbol("cols") -rows = dace.symbol("rows") -nnz = dace.symbol("nnz") -itype = dace.uint32 -dtype = dace.float32 - - -def make_pre_state(sdfg: SDFG): - - state = sdfg.add_state("pre_state", is_start_block=True) - - a_row_host = state.add_array("A_row", (rows + 1, ), itype) - a_row_device = state.add_array("A_row_device", (rows + 1, ), itype, transient=True, storage=StorageType.FPGA_Global) - - a_col_host = state.add_array("A_col", (nnz, ), itype) - a_col_device = state.add_array("A_col_device", (nnz, ), itype, transient=True, storage=StorageType.FPGA_Global) - - a_val_host = state.add_array("A_val", (nnz, ), dtype) - a_val_device = state.add_array("A_val_device", (nnz, ), dtype, transient=True, storage=StorageType.FPGA_Global) - - x_host = state.add_array("x", (cols, ), dtype) - x_device = state.add_array("x_device", (cols, ), dtype, transient=True, storage=StorageType.FPGA_Global) - - state.add_memlet_path(a_row_host, a_row_device, memlet=dace.memlet.Memlet.simple(a_row_device, "0:rows+1")) - - state.add_memlet_path(a_col_host, a_col_device, memlet=dace.memlet.Memlet.simple(a_col_device, "0:nnz")) - - state.add_memlet_path(a_val_host, a_val_device, memlet=dace.memlet.Memlet.simple(a_val_device, "0:nnz")) - - state.add_memlet_path(x_host, x_device, memlet=dace.memlet.Memlet.simple(x_device, "0:cols")) - - return state - - -def make_post_state(sdfg: SDFG): - - state = sdfg.add_state("post_state") - - b_device = state.add_array("b_device", (rows, ), dtype, transient=True, storage=StorageType.FPGA_Global) - b_host = state.add_array("b", (rows, ), dtype) - - state.add_memlet_path(b_device, b_host, memlet=dace.memlet.Memlet.simple(b_host, "0:rows")) - - return state - - -def make_write_sdfg(): - - sdfg = SDFG("spmv_write") - - loop = LoopRegion('write_loop', 'h < rows', 'h', 'h = 0', 'h = h + 1') - sdfg.add_node(loop, is_start_block=True) - state = loop.add_state('body', is_start_block=True) - - result_to_write_in = state.add_stream("b_pipe", dtype, storage=StorageType.FPGA_Local) - b = state.add_array("b_mem", (rows, ), dtype, storage=StorageType.FPGA_Global) - - state.add_memlet_path(result_to_write_in, b, memlet=Memlet.simple(b, "h")) - - return sdfg - - -def make_iteration_space(sdfg: SDFG): - - pre_state = sdfg.add_state("pre_state", is_start_block=True) - - rows_loop = LoopRegion('rows_loop', 'h < rows', 'h', 'h = 0', 'h = h + 1') - sdfg.add_node(rows_loop) - sdfg.add_edge(pre_state, rows_loop, InterstateEdge()) - - shift_rowptr = rows_loop.add_state('shift_rowptr', is_start_block=True) - read_rowptr = rows_loop.add_state('read_rowptr') - rows_loop.add_edge(shift_rowptr, read_rowptr, InterstateEdge()) - - cols_loop = LoopRegion('cols_loop', 'c < row_end - row_begin', 'c', 'c = 0', 'c = c + 1') - rows_loop.add_node(cols_loop) - rows_loop.add_edge(read_rowptr, cols_loop, InterstateEdge()) - - body = cols_loop.add_state('compute', is_start_block=True) - - post_state = rows_loop.add_state('post_state') - rows_loop.add_edge(cols_loop, post_state, InterstateEdge()) - - row_end_first = pre_state.add_scalar("row_end", itype, transient=True, storage=StorageType.FPGA_Registers) - row_pipe_first = pre_state.add_stream("row_pipe", itype, storage=StorageType.FPGA_Local) - pre_state.add_memlet_path(row_pipe_first, row_end_first, memlet=Memlet.simple(row_end_first, "0")) - - row_end_shift = shift_rowptr.add_scalar("row_end", itype, transient=True, storage=StorageType.FPGA_Registers) - row_begin_shift = shift_rowptr.add_scalar("row_begin", - itype, - transient=True, - lifetime=AllocationLifetime.SDFG, - storage=StorageType.FPGA_Registers) - shift_rowptr.add_memlet_path(row_end_shift, row_begin_shift, memlet=Memlet.simple(row_begin_shift, "0")) - - row_pipe = read_rowptr.add_stream("row_pipe", itype, storage=StorageType.FPGA_Local) - row_end = read_rowptr.add_scalar("row_end", itype, transient=True, storage=StorageType.FPGA_Registers) - read_rowptr.add_memlet_path(row_pipe, row_end, memlet=Memlet.simple(row_end, "0")) - - return pre_state, body, post_state - - -def make_compute_nested_sdfg(): - - sdfg = SDFG('spmv_compute_nested') - - init_state = sdfg.add_state("init", is_start_block=True) - - conditional = ConditionalBlock('spmv_conditional') - sdfg.add_node(conditional) - sdfg.add_edge(init_state, conditional, InterstateEdge()) - - then_branch = ControlFlowRegion('then_branch') - conditional.add_branch('c == 0', then_branch) - then_state = then_branch.add_state('then', is_start_block=True) - - else_branch = ControlFlowRegion('then_branch') - conditional.add_branch('c != 0', else_branch) - else_state = else_branch.add_state('else', is_start_block=True) - - a_in = init_state.add_scalar("a_in", dtype, storage=StorageType.FPGA_Registers) - x_in = init_state.add_scalar("x_in", dtype, storage=StorageType.FPGA_Registers) - b_tmp_out = init_state.add_scalar("b_tmp", dtype, transient=True, storage=StorageType.FPGA_Registers) - tasklet = init_state.add_tasklet("compute", {"_a_in", "_x_in"}, {"_b_out"}, "_b_out = _a_in * _x_in") - init_state.add_memlet_path(a_in, tasklet, dst_conn="_a_in", memlet=Memlet.simple(a_in, "0")) - init_state.add_memlet_path(x_in, tasklet, dst_conn="_x_in", memlet=Memlet.simple(x_in, "0")) - init_state.add_memlet_path(tasklet, b_tmp_out, src_conn="_b_out", memlet=Memlet.simple(b_tmp_out, "0")) - - b_tmp_then_in = then_state.add_scalar("b_tmp", dtype, transient=True, storage=StorageType.FPGA_Registers) - b_then_out = then_state.add_scalar("b_out", dtype, storage=StorageType.FPGA_Registers) - then_state.add_memlet_path(b_tmp_then_in, b_then_out, memlet=Memlet.simple(b_then_out, "0")) - - b_tmp_else_in = else_state.add_scalar("b_tmp", dtype, transient=True, storage=StorageType.FPGA_Registers) - b_else_in = else_state.add_scalar("b_in", dtype, storage=StorageType.FPGA_Registers) - b_else_out = else_state.add_scalar("b_out", dtype, storage=StorageType.FPGA_Registers) - else_tasklet = else_state.add_tasklet("b_wcr", {"_b_in", "b_prev"}, {"_b_out"}, "_b_out = b_prev + _b_in") - else_state.add_memlet_path(b_tmp_else_in, else_tasklet, dst_conn="_b_in", memlet=Memlet.simple(b_tmp_else_in, "0")) - else_state.add_memlet_path(b_else_in, else_tasklet, dst_conn="b_prev", memlet=Memlet.simple(b_else_in, "0")) - else_state.add_memlet_path(else_tasklet, b_else_out, src_conn="_b_out", memlet=Memlet.simple(b_else_out, "0")) - - return sdfg - - -def make_compute_sdfg(): - - sdfg = SDFG("spmv_compute") - - pre_state, body, post_state = make_iteration_space(sdfg) - - a_pipe = body.add_stream("a_pipe", dtype, storage=StorageType.FPGA_Local) - x_pipe = body.add_stream("x_pipe", dtype, storage=StorageType.FPGA_Local) - b_buffer_in = body.add_scalar("b_buffer", dtype, transient=True, storage=StorageType.FPGA_Registers) - b_buffer_out = body.add_scalar("b_buffer", dtype, transient=True, storage=StorageType.FPGA_Registers) - nested_sdfg = make_compute_nested_sdfg() - tasklet = body.add_nested_sdfg(nested_sdfg, {"a_in", "x_in", "b_in"}, {"b_out"}, schedule=ScheduleType.FPGA_Device) - body.add_memlet_path(a_pipe, tasklet, dst_conn="a_in", memlet=Memlet.simple(a_pipe, "0")) - body.add_memlet_path(b_buffer_in, tasklet, dst_conn="b_in", memlet=Memlet.simple(b_buffer_in, "0")) - body.add_memlet_path(x_pipe, tasklet, dst_conn="x_in", memlet=Memlet.simple(x_pipe, "0")) - body.add_memlet_path(tasklet, b_buffer_out, src_conn="b_out", memlet=Memlet.simple(b_buffer_out, "0")) - - b_buffer_post_in = post_state.add_scalar("b_buffer", dtype, transient=True, storage=StorageType.FPGA_Registers) - b_pipe = post_state.add_stream("b_pipe", dtype, storage=StorageType.FPGA_Local) - post_state.add_memlet_path(b_buffer_post_in, b_pipe, memlet=Memlet.simple(b_pipe, "0")) - - return sdfg - - -def make_read_x(): - - sdfg = SDFG("spmv_read_x") - - pre_state, body, post_state = make_iteration_space(sdfg) - - x_mem = body.add_array("x_mem", (cols, ), dtype, storage=StorageType.FPGA_Global) - col_pipe = body.add_stream("col_pipe", itype, storage=StorageType.FPGA_Local) - compute_pipe = body.add_stream("compute_pipe", dtype, storage=StorageType.FPGA_Local) - - tasklet = body.add_tasklet("read_x", {"x_in", "col_in"}, {"x_out"}, "x_out = x_in[col_in]") - - body.add_memlet_path(x_mem, tasklet, dst_conn="x_in", memlet=Memlet.simple(x_mem, "0:cols")) - body.add_memlet_path(col_pipe, tasklet, dst_conn="col_in", memlet=Memlet.simple(col_pipe, "0")) - body.add_memlet_path(tasklet, compute_pipe, src_conn="x_out", memlet=Memlet.simple(compute_pipe, "0")) - - return sdfg - - -def make_read_val(): - - sdfg = SDFG("spmv_read_val") - - pre_state, body, post_state = make_iteration_space(sdfg) - - a_val_mem = body.add_array("A_val_mem", (nnz, ), dtype, storage=StorageType.FPGA_Global) - compute_pipe = body.add_stream("compute_pipe", dtype, storage=StorageType.FPGA_Local) - - tasklet = body.add_tasklet("read_val", {"a_in"}, {"a_out"}, "a_out = a_in[row_begin + c]") - - body.add_memlet_path(a_val_mem, tasklet, dst_conn="a_in", memlet=Memlet.simple(a_val_mem, "0:nnz")) - body.add_memlet_path(tasklet, compute_pipe, src_conn="a_out", memlet=Memlet.simple(compute_pipe, "0")) - - return sdfg - - -def make_read_col(): - - sdfg = SDFG("spmv_read_col") - - pre_state, body, post_state = make_iteration_space(sdfg) - - a_col = body.add_array("A_col_mem", (nnz, ), itype, storage=StorageType.FPGA_Global) - col_pipe = body.add_stream("col_pipe", itype, storage=StorageType.FPGA_Local) - - tasklet = body.add_tasklet("read_col", {"col_in"}, {"col_out"}, "col_out = col_in[row_begin + c]") - - body.add_memlet_path(a_col, tasklet, dst_conn="col_in", memlet=Memlet.simple(a_col, "0:nnz")) - body.add_memlet_path(tasklet, col_pipe, src_conn="col_out", memlet=Memlet.simple(col_pipe, "0")) - - return sdfg - - -def make_read_row(): - - sdfg = SDFG("spmv_read_row") - - loop = LoopRegion('read_row_loop', 'h < (rows + 1)', 'h', 'h = 0', 'h = h + 1') - sdfg.add_node(loop, is_start_block=True) - body = loop.add_state("body") - - a_row_mem = body.add_array("A_row_mem", (rows + 1, ), itype, storage=StorageType.FPGA_Global) - to_val_pipe = body.add_stream("to_val_pipe", itype, storage=StorageType.FPGA_Local) - to_col_pipe = body.add_stream("to_col_pipe", itype, storage=StorageType.FPGA_Local) - to_compute_pipe = body.add_stream("to_compute_pipe", itype, storage=StorageType.FPGA_Local) - to_x_pipe = body.add_stream("to_x_pipe", itype, storage=StorageType.FPGA_Local) - tasklet = body.add_tasklet( - "read_row", {"row_in"}, {"to_val_out", "to_col_out", "to_compute_out", "to_x_out"}, "to_val_out = row_in\n" - "to_col_out = row_in\n" - "to_compute_out = row_in\n" - "to_x_out = row_in") - - body.add_memlet_path(a_row_mem, tasklet, dst_conn="row_in", memlet=Memlet.simple(a_row_mem, "h")) - body.add_memlet_path(tasklet, to_val_pipe, src_conn="to_val_out", memlet=Memlet.simple(to_val_pipe, "0")) - body.add_memlet_path(tasklet, to_col_pipe, src_conn="to_col_out", memlet=Memlet.simple(to_col_pipe, "0")) - body.add_memlet_path(tasklet, - to_compute_pipe, - src_conn="to_compute_out", - memlet=Memlet.simple(to_compute_pipe, "0")) - body.add_memlet_path(tasklet, to_x_pipe, src_conn="to_x_out", memlet=Memlet.simple(to_x_pipe, "0")) - - return sdfg - - -def make_main_state(sdfg: SDFG): - - state = sdfg.add_state("spmv") - - # Read row pointers and send to value and column readers - a_row = state.add_array("A_row_device", (rows + 1, ), itype, transient=True, storage=StorageType.FPGA_Global) - row_to_val_out = state.add_stream("row_to_val", itype, transient=True, storage=StorageType.FPGA_Local) - row_to_col_out = state.add_stream("row_to_col", itype, transient=True, storage=StorageType.FPGA_Local) - row_to_x_out = state.add_stream("row_to_x", itype, transient=True, storage=StorageType.FPGA_Local) - row_to_compute_out = state.add_stream("row_to_compute", itype, transient=True, storage=StorageType.FPGA_Local) - read_row_sdfg = make_read_row() - read_row_tasklet = state.add_nested_sdfg(read_row_sdfg, {"A_row_mem"}, - {"to_val_pipe", "to_col_pipe", "to_x_pipe", "to_compute_pipe"}, - schedule=ScheduleType.FPGA_Device) - state.add_memlet_path(a_row, - read_row_tasklet, - memlet=dace.memlet.Memlet.simple(a_row, "0:rows+1"), - dst_conn="A_row_mem") - state.add_memlet_path(read_row_tasklet, - row_to_val_out, - memlet=dace.memlet.Memlet.simple(row_to_val_out, "0", num_accesses=-1), - src_conn="to_val_pipe") - state.add_memlet_path(read_row_tasklet, - row_to_col_out, - memlet=dace.memlet.Memlet.simple(row_to_col_out, "0", num_accesses=-1), - src_conn="to_col_pipe") - state.add_memlet_path(read_row_tasklet, - row_to_x_out, - memlet=dace.memlet.Memlet.simple(row_to_x_out, "0", num_accesses=-1), - src_conn="to_x_pipe") - state.add_memlet_path(read_row_tasklet, - row_to_compute_out, - memlet=dace.memlet.Memlet.simple(row_to_compute_out, "0", num_accesses=-1), - src_conn="to_compute_pipe") - - # Read columns of A using row pointers and send to x reader - a_col = state.add_array("A_col_device", (nnz, ), itype, transient=True, storage=StorageType.FPGA_Global) - row_to_col_in = state.add_stream("row_to_col", itype, transient=True, storage=StorageType.FPGA_Local) - col_to_x_out = state.add_stream("col_to_x", itype, transient=True, storage=StorageType.FPGA_Local) - read_col_sdfg = make_read_col() - read_col_tasklet = state.add_nested_sdfg(read_col_sdfg, {"A_col_mem", "row_pipe"}, {"col_pipe"}, - schedule=ScheduleType.FPGA_Device) - state.add_memlet_path(a_col, - read_col_tasklet, - memlet=dace.memlet.Memlet.simple(a_col, "0:nnz"), - dst_conn="A_col_mem") - state.add_memlet_path(row_to_col_in, - read_col_tasklet, - memlet=dace.memlet.Memlet.simple(row_to_col_in, "0", num_accesses=-1), - dst_conn="row_pipe") - state.add_memlet_path(read_col_tasklet, - col_to_x_out, - memlet=dace.memlet.Memlet.simple(col_to_x_out, "0", num_accesses=-1), - src_conn="col_pipe") - - # Read values of A using row pointers and send to compute - a_val = state.add_array("A_val_device", (nnz, ), dtype, transient=True, storage=StorageType.FPGA_Global) - row_to_val_in = state.add_stream("row_to_val", itype, transient=True, storage=StorageType.FPGA_Local) - val_to_compute_out = state.add_stream("val_to_compute", dtype, transient=True, storage=StorageType.FPGA_Local) - read_val_sdfg = make_read_val() - read_val_tasklet = state.add_nested_sdfg(read_val_sdfg, {"A_val_mem", "row_pipe"}, {"compute_pipe"}, - schedule=ScheduleType.FPGA_Device) - state.add_memlet_path(a_val, - read_val_tasklet, - dst_conn="A_val_mem", - memlet=dace.memlet.Memlet.simple(a_val, "0:nnz")) - state.add_memlet_path(row_to_val_in, - read_val_tasklet, - dst_conn="row_pipe", - memlet=dace.memlet.Memlet.simple(row_to_val_in, "0", num_accesses=-1)) - state.add_memlet_path(read_val_tasklet, - val_to_compute_out, - src_conn="compute_pipe", - memlet=dace.memlet.Memlet.simple(val_to_compute_out, "0", num_accesses=-1)) - - # Read values of x using column pointers and send to compute - x = state.add_array("x_device", (cols, ), dtype, transient=True, storage=StorageType.FPGA_Global) - row_to_x_in = state.add_stream("row_to_x", itype, transient=True, storage=StorageType.FPGA_Local) - col_to_x_in = state.add_stream("col_to_x", itype, transient=True, storage=StorageType.FPGA_Local) - x_to_compute_out = state.add_stream("x_to_compute", dtype, transient=True, storage=StorageType.FPGA_Local) - read_x_sdfg = make_read_x() - read_x_tasklet = state.add_nested_sdfg(read_x_sdfg, {"x_mem", "col_pipe", "row_pipe"}, {"compute_pipe"}, - schedule=ScheduleType.FPGA_Device) - state.add_memlet_path(x, read_x_tasklet, dst_conn="x_mem", memlet=dace.memlet.Memlet.simple(x, "0:cols")) - state.add_memlet_path(col_to_x_in, - read_x_tasklet, - dst_conn="col_pipe", - memlet=dace.memlet.Memlet.simple(col_to_x_in, "0", num_accesses=-1)) - state.add_memlet_path(row_to_x_in, - read_x_tasklet, - dst_conn="row_pipe", - memlet=dace.memlet.Memlet.simple(row_to_x_in, "0", num_accesses=-1)) - state.add_memlet_path(read_x_tasklet, - x_to_compute_out, - src_conn="compute_pipe", - memlet=dace.memlet.Memlet.simple(x_to_compute_out, "0", num_accesses=-1)) - - # Receive values of A and x and compute resulting values of b - row_to_compute_in = state.add_stream("row_to_compute", itype, transient=True, storage=StorageType.FPGA_Local) - val_to_compute_in = state.add_stream("val_to_compute", dtype, transient=True, storage=StorageType.FPGA_Local) - x_to_compute_in = state.add_stream("x_to_compute", dtype, transient=True, storage=StorageType.FPGA_Local) - result_to_write_out = state.add_stream("result_to_write", dtype, transient=True, storage=StorageType.FPGA_Local) - compute_sdfg = make_compute_sdfg() - compute_tasklet = state.add_nested_sdfg(compute_sdfg, {"row_pipe", "a_pipe", "x_pipe"}, {"b_pipe"}, - schedule=ScheduleType.FPGA_Device) - state.add_memlet_path(row_to_compute_in, - compute_tasklet, - dst_conn="row_pipe", - memlet=dace.memlet.Memlet.simple(row_to_compute_out, "0", num_accesses=-1)) - state.add_memlet_path(val_to_compute_in, - compute_tasklet, - dst_conn="a_pipe", - memlet=dace.memlet.Memlet.simple(val_to_compute_in, "0", num_accesses=-1)) - state.add_memlet_path(x_to_compute_in, - compute_tasklet, - dst_conn="x_pipe", - memlet=dace.memlet.Memlet.simple(x_to_compute_in, "0", num_accesses=-1)) - state.add_memlet_path(compute_tasklet, - result_to_write_out, - src_conn="b_pipe", - memlet=dace.memlet.Memlet.simple(result_to_write_out, "0", num_accesses=-1)) - - # Write back values of b - result_to_write_in = state.add_stream("result_to_write", dtype, transient=True, storage=StorageType.FPGA_Local) - b = state.add_array("b_device", (rows, ), dtype, transient=True, storage=StorageType.FPGA_Global) - write_sdfg = make_write_sdfg() - write_tasklet = state.add_nested_sdfg(write_sdfg, {"b_pipe"}, {"b_mem"}, schedule=ScheduleType.FPGA_Device) - state.add_memlet_path(result_to_write_in, - write_tasklet, - dst_conn="b_pipe", - memlet=dace.memlet.Memlet.simple(result_to_write_in, "0", num_accesses=-1)) - state.add_memlet_path(write_tasklet, b, src_conn="b_mem", memlet=dace.memlet.Memlet.simple(b, "0:rows")) - - return state - - -def make_sdfg(specialize, rows, cols, nnz): - - if specialize: - name = "spmv_fpga_stream_{}x{}x{}".format(rows, cols, nnz) - else: - name = "spmv_fpga_stream" - sdfg = dace.SDFG(name) - - pre_state = make_pre_state(sdfg) - main_state = make_main_state(sdfg) - post_state = make_post_state(sdfg) - - sdfg.add_edge(pre_state, main_state, dace.sdfg.InterstateEdge()) - sdfg.add_edge(main_state, post_state, dace.sdfg.InterstateEdge()) - - return sdfg - - -def run_spmv(size_w, size_h, num_nonzero, specialize): - print("Sparse Matrix-Vector Multiplication {}x{} " - "({} non-zero elements, {}specialized)".format(size_w, size_h, num_nonzero, "not " if not specialize else "")) - - A_row = dace.ndarray([size_h + 1], dtype=itype) - A_col = dace.ndarray([num_nonzero], dtype=itype) - A_val = dace.ndarray([num_nonzero], dtype=dtype) - - x = dace.ndarray([size_w], dtype) - b = dace.ndarray([size_h], dtype) - - # Assuming uniform sparsity distribution across rows - nnz_per_row = num_nonzero // size_h - nnz_last_row = nnz_per_row + (num_nonzero % size_h) - if nnz_last_row > size_w: - print("Too many nonzeros per row") - exit(1) - - # RANDOMIZE SPARSE MATRIX - A_row[0] = itype(0) - A_row[1:size_h] = itype(nnz_per_row) - A_row[-1] = itype(nnz_last_row) - A_row = np.cumsum(A_row, dtype=itype.type) - - # Fill column data - for i in range(size_h - 1): - A_col[nnz_per_row*i:nnz_per_row*(i+1)] = \ - np.sort(np.random.choice(size_w, nnz_per_row, replace=False)) - # Fill column data for last row - A_col[nnz_per_row * (size_h - 1):] = np.sort(np.random.choice(size_w, nnz_last_row, replace=False)) - - A_val[:] = np.random.rand(num_nonzero).astype(dtype.type) - ######################### - - x[:] = np.random.rand(size_w).astype(dtype.type) - #b[:] = dtype(0) - - # Setup regression - A_sparse = scipy.sparse.csr_matrix((A_val, A_col, A_row), shape=(size_h, size_w)) - - spmv = make_sdfg(specialize, size_h, size_w, num_nonzero) - if specialize: - spmv.specialize(dict(rows=size_h, cols=size_w, nnz=num_nonzero)) - spmv(A_row=A_row, A_col=A_col, A_val=A_val, x=x, b=b, rows=size_h, cols=size_w, nnz=num_nonzero) - - if dace.Config.get_bool("profiling"): - dace.timethis("spmv", "scipy", 0, A_sparse.dot, x) - - diff = np.linalg.norm(A_sparse.dot(x) - b) / float(size_h) - print("Difference:", diff) - if diff >= 1e-5: - print("Validation failed.") - print("Result:") - print(b) - print("Reference:") - print(A_sparse.dot(x)) - print("Type \"debug\" to enter debugger, " - "or any other string to quit (timeout in 10 seconds)") - read, _, _ = select.select([sys.stdin], [], [], 10) - if len(read) > 0 and sys.stdin.readline().strip().lower() == "debug": - print("Entering debugger...") - pdb.set_trace() - else: - print("Exiting...") - print("==== Program end ====") - if diff > 1e-5: - raise RuntimeError("Validation failed.") - - return spmv - - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - parser.add_argument("cols", type=int) - parser.add_argument("rows", type=int) - parser.add_argument("nnz", type=int) - parser.add_argument("-specialize", - default=False, - action="store_true", - help="Fix all symbols at compile time/in hardware") - args = parser.parse_args() - - run_spmv(args.cols, args.rows, args.nnz, args.specialize) diff --git a/samples/optimization/matmul.py b/samples/optimization/matmul.py index 06b6a38939..9353a5aae9 100644 --- a/samples/optimization/matmul.py +++ b/samples/optimization/matmul.py @@ -11,7 +11,6 @@ # For optimizations from dace.transformation.dataflow import (DoubleBuffering, MapCollapse, MapExpansion, MapReduceFusion, StripMining, InLocalStorage, AccumulateTransient, Vectorization) -from dace.transformation.interstate import FPGATransformSDFG from dace.transformation import helpers as xfutil # For library node implementations @@ -197,8 +196,7 @@ def optimize_for_gpu(sdfg: dace.SDFG, m: int, n: int, k: int): @click.option('-K', type=int, default=64) @click.option('-N', type=int, default=64) @click.option('--version', - type=click.Choice(('unoptimized', 'optimize_cpu', 'optimize_gpu', 'mkl', 'cublas', 'rocblas', - 'fpga_naive', 'fpga_library')), + type=click.Choice(('unoptimized', 'optimize_cpu', 'optimize_gpu', 'mkl', 'cublas', 'rocblas')), default='unoptimized') @click.option('--verify/--no-verify', default=True) def cli(m, k, n, version, verify): @@ -248,13 +246,6 @@ def cli(m, k, n, version, verify): dace.libraries.blas.default_implementation = 'rocBLAS' # Call program C = matmul_lib(A, B) - elif version == 'fpga_naive': - matmul_sdfg = matmul.to_sdfg() - matmul_sdfg.apply_transformations(FPGATransformSDFG) - matmul_sdfg(A=A, B=B, C=C, N=n, K=k, M=m) - elif version == 'fpga_systolic': - dace.libraries.blas.default_implementation = 'FPGA1DSystolic' - C = matmul_lib(A, B) else: raise ValueError('Invalid version %s' % version) diff --git a/setup.py b/setup.py index aaf120c8f0..79f5802d09 100644 --- a/setup.py +++ b/setup.py @@ -22,12 +22,6 @@ ] cub_files = [f[len(dace_path):] for f in glob.glob(dace_path + 'external/cub/cub/**/*', recursive=True) ] + [dace_path + 'external/cub/LICENSE.TXT'] -hlslib_files = [f[len(dace_path):] for f in glob.glob(dace_path + 'external/hlslib/cmake/**/*', recursive=True)] + [ - f[len(dace_path):] for f in glob.glob(dace_path + 'external/hlslib/include/**/*', recursive=True) -] + [dace_path + 'external/hlslib/LICENSE.md'] -rtllib_files = [f[len(dace_path):] for f in glob.glob(dace_path + 'external/rtllib/cmake/**/*', recursive=True)] + [ - f[len(dace_path):] for f in glob.glob(dace_path + 'external/rtllib/templates/**/*', recursive=True) -] # See if CMake is available and if not, install as a dependency cmake_requires = ['scikit-build', 'cmake'] @@ -68,8 +62,8 @@ package_data={ '': [ '*.yml', 'codegen/CMakeLists.txt', 'codegen/tools/*.cpp', 'external/moodycamel/*.h', - 'external/moodycamel/LICENSE.md', 'codegen/Xilinx_HLS.tcl.in' - ] + runtime_files + cub_files + viewer_files + hlslib_files + library_files + rtllib_files + cmake_files + 'external/moodycamel/LICENSE.md' + ] + runtime_files + cub_files + viewer_files + library_files + cmake_files }, include_package_data=True, install_requires=[ @@ -78,9 +72,22 @@ 'typing-compat; python_version < "3.8"', 'packaging' ] + cmake_requires, extras_require={ + 'ml': ['onnx', 'torch', 'onnxsim', 'onnxscript', 'onnxruntime', 'protobuf', 'ninja'], 'testing': [ + 'coverage', + 'pytest-cov', + 'scipy', + 'absl-py', + 'opt_einsum', + 'pymlir', + 'click', + 'ipykernel', + 'nbconvert', + 'pytest-timeout', + ], + 'ml-testing': [ 'coverage', 'pytest-cov', 'scipy', 'absl-py', 'opt_einsum', 'pymlir', 'click', 'ipykernel', 'nbconvert', - 'pytest-timeout' + 'pytest-timeout', 'transformers == 4.50', 'jax <= 0.6.2', 'efficientnet_pytorch' ], 'docs': ['jinja2<3.2.0', 'sphinx-autodoc-typehints', 'sphinx-rtd-theme>=0.5.1'], 'linting': ['pre-commit==4.1.0', 'yapf==0.43.0'], diff --git a/tests/.gitignore b/tests/.gitignore index 748014c0fd..6553fb299d 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -2,3 +2,6 @@ *.from_serialized *.serialized *.txt + +# Ignore downloaded files for tests +data/ diff --git a/tests/autodiff/test_multi_state.py b/tests/autodiff/test_multi_state.py new file mode 100644 index 0000000000..db23109d63 --- /dev/null +++ b/tests/autodiff/test_multi_state.py @@ -0,0 +1,304 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import numpy as np +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import torch + +import dace +from dace import SDFG, InterstateEdge, Memlet +from test_single_state import SDFGBackwardRunner, run_correctness + + +@pytest.mark.autodiff +@run_correctness +def test_two_state_add_mul(): + """ + Test a two-state SDFG: + - State 1: Z = X + Y (element-wise addition) + - State 2: S = sum(Z * Z) (element-wise multiplication then sum) + """ + + sdfg = SDFG("two_state_add_mul") + + sdfg.add_array("X", [3, 3], dace.float32) + sdfg.add_array("Y", [3, 3], dace.float32) + sdfg.add_array("Z", [3, 3], dace.float32, transient=False) + sdfg.add_array("S", [1], dace.float32) + + state1 = sdfg.add_state("state1") + X_read = state1.add_access("X") + Y_read = state1.add_access("Y") + Z_write = state1.add_access("Z") + + map_entry, map_exit = state1.add_map("add_map", dict(i="0:3", j="0:3")) + + tasklet_add = state1.add_tasklet("add", {"x", "y"}, {"z"}, "z = x + y") + + state1.add_memlet_path(X_read, map_entry, tasklet_add, dst_conn="x", memlet=Memlet("X[i, j]")) + state1.add_memlet_path(Y_read, map_entry, tasklet_add, dst_conn="y", memlet=Memlet("Y[i, j]")) + state1.add_memlet_path(tasklet_add, map_exit, Z_write, src_conn="z", memlet=Memlet("Z[i, j]")) + + state2 = sdfg.add_state("state2") + Z_read = state2.add_access("Z") + S_write = state2.add_access("S") + + map_entry2, map_exit2 = state2.add_map("mul_map", dict(i="0:3", j="0:3")) + + tasklet_mul = state2.add_tasklet("mul", {"z"}, {"s"}, "s = z * z") + + state2.add_memlet_path(Z_read, map_entry2, tasklet_mul, dst_conn="z", memlet=Memlet("Z[i, j]")) + state2.add_memlet_path(tasklet_mul, + map_exit2, + S_write, + src_conn="s", + memlet=Memlet("S[0]", wcr="lambda a, b: a + b")) + + sdfg.add_edge(state1, state2, InterstateEdge()) + + # PyTorch reference implementation + def torch_func(*, X, Y): + Z = X + Y + S = (Z * Z).sum() + S.backward() + return dict(gradient_X=X.grad, gradient_Y=Y.grad) + + return ( + SDFGBackwardRunner(sdfg, "S"), + torch_func, + dict( + X=np.random.rand(3, 3).astype(np.float32), + Y=np.random.rand(3, 3).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_conditional_simple(): + """ + Test a Python program with a simple conditional in the forward pass: + if X[0, 0] > 0.5: + Y = X * 2 + else: + Y = X * 3 + S = sum(Y) + """ + + @dace.program + def conditional_program(X: dace.float32[3, 3], Y: dace.float32[3, 3], S: dace.float32[1]): + if X[0, 0] > 0.5: + Y[:] = X * 2.0 + else: + Y[:] = X * 3.0 + S[0] = np.sum(Y) + + sdfg = conditional_program.to_sdfg(simplify=True) + + # PyTorch reference implementation + def torch_func(*, X): + Y = torch.where(X[0, 0] > 0.5, X * 2.0, X * 3.0) + S = Y.sum() + S.backward() + return dict(gradient_X=X.grad) + + return ( + SDFGBackwardRunner(sdfg, "S", simplify=False), + torch_func, + dict(X=np.random.rand(3, 3).astype(np.float32)), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_for_loop(): + """ + Test a simple for loop similar to jacobi_1d, but simplified: + for i in range(3): + A = A + B + S = sum(A) + """ + + @dace.program + def for_loop_program(A: dace.float32[10], B: dace.float32[10]): + for i in range(3): + A[:] = A + B + return np.sum(A) + + sdfg = for_loop_program.to_sdfg() + + # PyTorch reference implementation + def torch_func(*, A, B): + A_result = A.clone() + for i in range(3): + A_result = A_result + B + S = A_result.sum() + S.backward() + return dict(gradient_A=A.grad, gradient_B=B.grad) + + return ( + SDFGBackwardRunner(sdfg, "__return"), + torch_func, + dict( + A=np.random.rand(10).astype(np.float32), + B=np.random.rand(10).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_diamond_pattern_conditional(): + """ + Test an SDFG with a diamond pattern control flow using GOTOs. + + Structure: + state1: Y = X * 2 + if X[0] > 0.5: + goto state3 + else: + goto state2 + state2: Y = Y + 1 + state3: S = sum(Y) + + This creates a diamond pattern where both paths can reach state3. + """ + + sdfg = SDFG("irreducible_cf") + + # Add arrays + sdfg.add_array("X", [5], dace.float32) + sdfg.add_array("Y", [5], dace.float32, transient=False) + sdfg.add_array("S", [1], dace.float32) + + # State 1: Y = X * 2 + state1 = sdfg.add_state("state1") + X_read1 = state1.add_access("X") + Y_write1 = state1.add_access("Y") + + map_entry1, map_exit1 = state1.add_map("mul_map", dict(i="0:5")) + tasklet1 = state1.add_tasklet("mul", {"x"}, {"y"}, "y = x * 2.0") + + state1.add_memlet_path(X_read1, map_entry1, tasklet1, dst_conn="x", memlet=Memlet("X[i]")) + state1.add_memlet_path(tasklet1, map_exit1, Y_write1, src_conn="y", memlet=Memlet("Y[i]")) + + # State 2: Y = Y + 1 + state2 = sdfg.add_state("state2") + Y_read2 = state2.add_access("Y") + Y_write2 = state2.add_access("Y") + + map_entry2, map_exit2 = state2.add_map("add_map", dict(i="0:5")) + tasklet2 = state2.add_tasklet("add", {"y_in"}, {"y_out"}, "y_out = y_in + 1.0") + + state2.add_memlet_path(Y_read2, map_entry2, tasklet2, dst_conn="y_in", memlet=Memlet("Y[i]")) + state2.add_memlet_path(tasklet2, map_exit2, Y_write2, src_conn="y_out", memlet=Memlet("Y[i]")) + + # State 3: S = sum(Y) + state3 = sdfg.add_state("state3") + Y_read3 = state3.add_access("Y") + S_write3 = state3.add_access("S") + + map_entry3, map_exit3 = state3.add_map("sum_map", dict(i="0:5")) + tasklet3 = state3.add_tasklet("sum", {"y"}, {"s"}, "s = y") + + state3.add_memlet_path(Y_read3, map_entry3, tasklet3, dst_conn="y", memlet=Memlet("Y[i]")) + state3.add_memlet_path(tasklet3, map_exit3, S_write3, src_conn="s", memlet=Memlet("S[0]", wcr="lambda a, b: a + b")) + + # Create conditional edges (irreducible control flow) + # Add condition: if X[0] > 0.5 goto state3, else goto state2 + sdfg.add_edge(state1, state3, InterstateEdge(condition="X[0] > 0.5")) + sdfg.add_edge(state1, state2, InterstateEdge(condition="X[0] <= 0.5")) + sdfg.add_edge(state2, state3, InterstateEdge()) + + # PyTorch reference implementation + def torch_func(*, X): + Y = X * 2.0 + Y = torch.where(X[0] > 0.5, Y, Y + 1.0) + S = Y.sum() + S.backward() + return dict(gradient_X=X.grad) + + return ( + SDFGBackwardRunner(sdfg, "S", simplify=False), + torch_func, + dict(X=np.random.rand(5).astype(np.float32)), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_multi_output_state(): + """ + Test a two-state SDFG where the first state produces multiple outputs: + State 1: Y = X * 2, Z = X + 1 + State 2: S = sum(Y * Z) + """ + + # Build SDFG using API + sdfg = SDFG("multi_output_state") + + # Add arrays + sdfg.add_array("X", [5], dace.float32) + sdfg.add_array("Y", [5], dace.float32, transient=False) + sdfg.add_array("Z", [5], dace.float32, transient=False) + sdfg.add_array("S", [1], dace.float32) + + # State 1: Compute Y and Z + state1 = sdfg.add_state("state1") + X_read1 = state1.add_access("X") + Y_write1 = state1.add_access("Y") + Z_write1 = state1.add_access("Z") + + map_entry1, map_exit1 = state1.add_map("compute_map", dict(i="0:5")) + tasklet_y = state1.add_tasklet("compute_y", {"x"}, {"y"}, "y = x * 2.0") + tasklet_z = state1.add_tasklet("compute_z", {"x"}, {"z"}, "z = x + 1.0") + + state1.add_memlet_path(X_read1, map_entry1, tasklet_y, dst_conn="x", memlet=Memlet("X[i]")) + state1.add_memlet_path(tasklet_y, map_exit1, Y_write1, src_conn="y", memlet=Memlet("Y[i]")) + + X_read2 = state1.add_access("X") + state1.add_memlet_path(X_read2, map_entry1, tasklet_z, dst_conn="x", memlet=Memlet("X[i]")) + state1.add_memlet_path(tasklet_z, map_exit1, Z_write1, src_conn="z", memlet=Memlet("Z[i]")) + + # State 2: Multiply and sum + state2 = sdfg.add_state("state2") + Y_read2 = state2.add_access("Y") + Z_read2 = state2.add_access("Z") + S_write2 = state2.add_access("S") + + map_entry2, map_exit2 = state2.add_map("mul_sum_map", dict(i="0:5")) + tasklet_mul = state2.add_tasklet("mul", {"y", "z"}, {"s"}, "s = y * z") + + state2.add_memlet_path(Y_read2, map_entry2, tasklet_mul, dst_conn="y", memlet=Memlet("Y[i]")) + state2.add_memlet_path(Z_read2, map_entry2, tasklet_mul, dst_conn="z", memlet=Memlet("Z[i]")) + state2.add_memlet_path(tasklet_mul, + map_exit2, + S_write2, + src_conn="s", + memlet=Memlet("S[0]", wcr="lambda a, b: a + b")) + + # Connect states + sdfg.add_edge(state1, state2, InterstateEdge()) + + # PyTorch reference implementation + def torch_func(*, X): + Y = X * 2.0 + Z = X + 1.0 + S = (Y * Z).sum() + S.backward() + return dict(gradient_X=X.grad) + + return ( + SDFGBackwardRunner(sdfg, "S"), + torch_func, + dict(X=np.random.rand(5).astype(np.float32)), + ) + + +if __name__ == "__main__": + test_two_state_add_mul() + test_conditional_simple() + test_for_loop() + test_diamond_pattern_conditional() + test_multi_output_state() diff --git a/tests/autodiff/test_nested.py b/tests/autodiff/test_nested.py new file mode 100644 index 0000000000..427efce8c2 --- /dev/null +++ b/tests/autodiff/test_nested.py @@ -0,0 +1,174 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import numpy as np +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import torch + +import dace +from dace.transformation.interstate import StateFusion + +import dace.libraries.onnx as donnx +from test_single_state import SDFGBackwardRunner, run_correctness + + +@dace.program +def inner_sdfg(Z: dace.float32[3, 3], W: dace.float32[3, 3]): + W[:] = dace.elementwise(lambda x: log(x), Z) + + +@dace.program +def inner_sdfg_with_intermediate(Z: dace.float32[3, 3], W: dace.float32[3, 3]): + intermediate = dace.define_local([3, 3], dace.float32) + intermediate[:] = dace.elementwise(lambda x: sqrt(x), Z) + W[:] = dace.elementwise(lambda x: log(x), intermediate) + + +@dace.program +def middle_sqrt(Y: dace.float32[3, 3]): + intermediate = dace.define_local([3, 3], dace.float32) + W = dace.define_local([3, 3], dace.float32) + intermediate[:] = dace.elementwise(lambda x: sqrt(x), Y) + inner_sdfg(intermediate, W) + Z = np.sum(W) + return Z + + +@pytest.mark.autodiff +@run_correctness +def test_nested(): + sdfg = middle_sqrt.to_sdfg(simplify=True) + + def torch_func(*, Y): + inter = torch.sqrt(Y) + W = torch.log(inter) + Z = torch.sum(W) + Z.backward() + return dict(gradient_Y=Y.grad) + + return (SDFGBackwardRunner(sdfg, "__return", + simplify=False), torch_func, dict(Y=np.random.rand(3, 3).astype(np.float32))) + + +@dace.program +def middle_sqrt_with_intermediate(Y: dace.float32[3, 3]): + intermediate = dace.define_local([3, 3], dace.float32) + W = dace.define_local([3, 3], dace.float32) + intermediate[:] = dace.elementwise(lambda x: sqrt(x), Y) + inner_sdfg_with_intermediate(intermediate, W) + Z = np.sum(W) + return Z + + +@pytest.mark.autodiff +@run_correctness +def test_nested_forwarding(): + sdfg = middle_sqrt_with_intermediate.to_sdfg(simplify=True) + + def torch_func(*, Y): + inter = torch.sqrt(Y) + inter2 = torch.sqrt(inter) + W = torch.log(inter2) + Z = torch.sum(W) + Z.backward() + return dict(gradient_Y=Y.grad) + + return (SDFGBackwardRunner(sdfg, "__return", + simplify=False), torch_func, dict(Y=np.random.rand(3, 3).astype(np.float32))) + + +@dace.program +def middle_sqrt_no_sum(Y: dace.float32[3, 3]): + intermediate = dace.define_local([3, 3], dace.float32) + W = dace.define_local([3, 3], dace.float32) + intermediate[:] = dace.elementwise(lambda x: sqrt(x), Y) + inner_sdfg_with_intermediate(intermediate, W) + return W + + +@dace.program +def outer_sqrt_with_intermediate(Y: dace.float32[3, 3]): + intermediate = dace.define_local([3, 3], dace.float32) + W = dace.define_local([3, 3], dace.float32) + intermediate[:] = dace.elementwise(lambda x: sqrt(x), Y) + W[:] = middle_sqrt_no_sum(intermediate) + Z = np.sum(W) + return Z + + +@pytest.mark.autodiff +@run_correctness +def test_triple_nested_forwarding(): + sdfg = outer_sqrt_with_intermediate.to_sdfg(simplify=True) + + def torch_func(*, Y): + inter = torch.sqrt(Y) + inter2 = torch.sqrt(inter) + inter3 = torch.sqrt(inter2) + W = torch.log(inter3) + Z = torch.sum(W) + Z.backward() + return dict(gradient_Y=Y.grad) + + return (SDFGBackwardRunner(sdfg, "__return", + simplify=False), torch_func, dict(Y=np.random.rand(3, 3).astype(np.float32))) + + +@pytest.mark.autodiff +@run_correctness +def test_view_forwarding(): + # Prepare the inner sdfg + old_default = donnx.default_implementation + donnx.default_implementation = "pure" + + @dace.program + def add_reshape_grad_test_nested(inp1: dace.float64[9], bias: dace.float64[3], target_shape: dace.int64[2], + result: dace.float64): + reshaped = dace.define_local([3, 3], dace.float64) + added = inp1 + 1 + donnx.ONNXReshape(data=added, shape=target_shape, reshaped=reshaped) + Z = reshaped * bias + Zl = dace.elementwise(lambda x: log(x + 1), Z) + result[:] = np.sum(Zl) + + sdfg = add_reshape_grad_test_nested.to_sdfg(simplify=True) + + sdfg.expand_library_nodes() + del sdfg.arrays["target_shape"] + + donnx.default_implementation = old_default + + # Prepare the outer SDFG + @dace.program + def inner_view_forwarding(inp1: dace.float64[9], bias: dace.float64[3]): + result = dace.define_local_scalar(dace.float64) + # target shape gets removed by the pure reshape expansion + sdfg(inp1=inp1, bias=bias, result=result) + return result + 1 + + # This generates a FunctionCallRegion in the current frontned + # We need to simplify. + outer_sdfg = inner_view_forwarding.to_sdfg(simplify=True) + outer_sdfg.apply_transformations_repeated([StateFusion]) + + def torch_func(*, inp1, bias): + reshaped = torch.reshape(inp1 + 1, [3, 3]) + + Z = reshaped * bias + Zl = torch.log(Z + 1) + S = Zl.sum() + 1 + + S.backward() + return dict(gradient_inp1=inp1.grad, gradient_bias=bias.grad) + + return (SDFGBackwardRunner(outer_sdfg, "__return", simplify=False), torch_func, + dict(inp1=np.random.rand(9).astype(np.float64), bias=np.random.rand(3).astype(np.float64))) + + +if __name__ == "__main__": + test_nested() + test_nested_forwarding() + test_triple_nested_forwarding() + test_view_forwarding() diff --git a/tests/autodiff/test_single_state.py b/tests/autodiff/test_single_state.py new file mode 100644 index 0000000000..aeeeb0eb1c --- /dev/null +++ b/tests/autodiff/test_single_state.py @@ -0,0 +1,635 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import numpy as np +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import torch + +import dace +import dace.sdfg.nodes as nd +from dace.transformation.interstate import StateFusion + +import dace.libraries.onnx as donnx +from dace.autodiff import add_backward_pass + + +################################## +# Testing utilities +def run_correctness(func): + + def test_correctness(): + runner, pytorch_func, inputs = func() + sdfg_dict = {name: arr.copy() for name, arr in inputs.items()} + torch_dict = {name: torch.tensor(arr.copy(), requires_grad=True) for name, arr in inputs.items()} + + sdfg_results = runner.run(**sdfg_dict) + torch_results = pytorch_func(**torch_dict) + + for k, v in torch_results.items(): + v = v.detach().numpy() + diff = np.linalg.norm(sdfg_results[k] - v) / np.prod(v.shape) + assert diff < 1e-5, f"Gradient mismatch for '{k}': normalized difference {diff:.2e} exceeds tolerance 1e-5" + + return test_correctness + + +class SDFGBackwardRunner: + + def __init__(self, sdfg, target, simplify=True): + if simplify: + sdfg.simplify() + self.sdfg: dace.SDFG = sdfg + self.target = target + + # Collect all non-transient float arrays from all states as required gradients + required_grads = [] + seen_names = set() + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, nd.AccessNode): + arr = node.desc(sdfg) + if (arr.dtype in [dace.float32, dace.float64] and not arr.transient + and node.data not in seen_names): + required_grads.append(node) + seen_names.add(node.data) + + add_backward_pass(sdfg=self.sdfg, outputs=[self.target], inputs=required_grads, simplify=simplify) + + def run(self, **inputs): + + # Zero out all arrays + intermediate_arrs = {} + gradient_target = "gradient_" + self.target + + for name, arr in self.sdfg.arrays.items(): + # Skip gradient target, dunder names, inputs, and transients + if (name == gradient_target or name.startswith("__") or name in inputs or arr.transient): + continue + + dtype = getattr(np, arr.dtype.to_string()) + intermediate_arrs[name] = np.zeros(arr.shape, dtype=dtype) + + inputs.update(intermediate_arrs) + inputs["gradient_" + self.target] = np.ones((1, ), + dtype=getattr(np, self.sdfg.arrays[self.target].dtype.to_string())) + + self.sdfg(**inputs) + + results = {name: arr for name, arr in inputs.items()} + return results + + +################################## +# Tests +@pytest.mark.autodiff +@run_correctness +def test_gemm(): + + def torch_gemm(*, X, Y): + Z = X @ Y + S = Z.sum() + S.backward() + return dict(gradient_X=X.grad, gradient_Y=Y.grad) + + @dace.program + def dace_gemm( + X: dace.float32[5, 4], + Y: dace.float32[4, 3], + Z: dace.float32[5, 3], + S: dace.float32[1], + ): + + Z[:] = X @ Y + + @dace.map(_[0:5, 0:3]) + def summap(i, j): + s >> S(1, lambda x, y: x + y)[0] + z << Z[i, j] + s = z + + sdfg = dace_gemm.to_sdfg() + + return ( + SDFGBackwardRunner(sdfg, "S"), + torch_gemm, + dict( + X=np.random.rand(5, 4).astype(np.float32), + Y=np.random.rand(4, 3).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_sum(): + + def torch_sum(*, X, Y): + Z = X + Y + Z = Z * Z + S = Z.sum() + S.backward() + return dict(gradient_X=X.grad, gradient_Y=Y.grad) + + @dace.program + def dace_sum( + X: dace.float32[3, 3], + Y: dace.float32[3, 3], + Z: dace.float32[3, 3], + S: dace.float32[1], + ): + + Z[:] = X + Y + + @dace.map(_[0:3, 0:3]) + def summap(i, j): + s >> S(1, lambda x, y: x + y)[0] + z << Z[i, j] + s = z * z + + sdfg = dace_sum.to_sdfg() + + return ( + SDFGBackwardRunner(sdfg, "S"), + torch_sum, + dict( + X=np.random.rand(3, 3).astype(np.float32), + Y=np.random.rand(3, 3).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_complex_tasklet(): + + def torch_sum(*, X, Y): + Z = X + Y + Z = Z * Z + S = Z.sum() + S.backward() + return dict(gradient_X=X.grad, gradient_Y=Y.grad) + + @dace.program + def dace_sum_complex( + X: dace.float32[3, 3], + Y: dace.float32[3, 3], + Z: dace.float32[3, 3], + S: dace.float32[1], + ): + + Z[:] = X + Y + + @dace.map(_[0:3, 0:3]) + def summap(i, j): + s >> S(1, lambda x, y: x + y)[0] + z << Z[i, j] + + z1 = z + 1 + log(3) # random expr + z2 = z - 1 * (2 / 2) + # hello world 1, 2, 3 + s = z1 * z2 + + sdfg = dace_sum_complex.to_sdfg() + + return ( + SDFGBackwardRunner(sdfg, "S"), + torch_sum, + dict( + X=np.random.rand(3, 3).astype(np.float32), + Y=np.random.rand(3, 3).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_tasklets_only_reuse(): + + def torch_func(*, A): + tmp_a = torch.sqrt(A) + tmp_b = torch.log(A + 1) + + C = tmp_a * tmp_b + + C.backward() + return dict(gradient_A=A.grad) + + @dace.program + def tasklets_only_reuse(A: dace.float32[1], C: dace.float32[1]): + tmp_a = dace.define_local_scalar(dace.float32) + tmp_b = dace.define_local_scalar(dace.float32) + + with dace.tasklet: + a << A[0] + a_out >> tmp_a + + a_out = sqrt(a) + + with dace.tasklet: + a << A[0] + a_out >> tmp_b + + a_out = log(a + 1) + + with dace.tasklet: + a << tmp_a + b << tmp_b + c >> C[0] + c = a * b + + sdfg = tasklets_only_reuse.to_sdfg(simplify=False) + sdfg.simplify() + return ( + SDFGBackwardRunner(sdfg, "C"), + torch_func, + dict(A=np.random.rand(1).astype(np.float32)), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_tasklets_multioutput(): + + def torch_func(*, A, B): + tmp_a = torch.sqrt(A) + tmp_b = torch.log(B + 1) + + C = tmp_a * tmp_b * B + + C.backward() + return dict(gradient_A=A.grad, gradient_B=B.grad) + + @dace.program + def tasklets_multioutput(A: dace.float32[1], B: dace.float32[1], C: dace.float32[1]): + tmp_a = dace.define_local_scalar(dace.float32) + tmp_b = dace.define_local_scalar(dace.float32) + tmp_d = dace.define_local_scalar(dace.float32) + + with dace.tasklet: + a << A[0] + a_out >> tmp_a + + a_out = sqrt(a) + + with dace.tasklet: + b << B[0] + b_out >> tmp_b + d_out >> tmp_d + + b_out = log(b + 1) + d_out = b + + with dace.tasklet: + a << tmp_a + b << tmp_b + d << tmp_d + c >> C[0] + c = a * b * d + + sdfg = tasklets_multioutput.to_sdfg(simplify=False) + sdfg.simplify() + + return ( + SDFGBackwardRunner(sdfg, "C"), + torch_func, + dict( + A=np.random.rand(1).astype(np.float32), + B=np.random.rand(1).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_tasklets_only(): + + def torch_func(*, A, B): + tmp_a = torch.sqrt(A) + tmp_b = torch.log(B + 1) + + C = tmp_a * tmp_b + + C.backward() + return dict(gradient_A=A.grad, gradient_B=B.grad) + + @dace.program + def tasklets_only(A: dace.float32[1], B: dace.float32[1], C: dace.float32[1]): + tmp_a = dace.define_local_scalar(dace.float32) + tmp_b = dace.define_local_scalar(dace.float32) + + with dace.tasklet: + a << A[0] + a_out >> tmp_a + + a_out = sqrt(a) + + with dace.tasklet: + a << B[0] + a_out >> tmp_b + + a_out = log(a + 1) + + with dace.tasklet: + a << tmp_a + b << tmp_b + c >> C[0] + c = a * b + + sdfg = tasklets_only.to_sdfg(simplify=False) + sdfg.simplify() + + return ( + SDFGBackwardRunner(sdfg, "C"), + torch_func, + dict( + A=np.random.rand(1).astype(np.float32), + B=np.random.rand(1).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_add_mmul_transpose_log(): + + def torch_func(*, X, Y, W): + + Xt = X.T + YW = W * Y + Z = Xt @ YW + Zl = torch.log(Z + 1) + + S = Zl.sum() + S.backward() + return dict(gradient_X=X.grad, gradient_Y=Y.grad, gradient_W=W.grad) + + @dace.program + def add_mmul_transpose_log( + X: dace.float32[4, 5], + Y: dace.float32[4, 3], + W: dace.float32[4, 3], + S: dace.float32[1], + ): + + Xt = np.transpose(X) + YW = W * Y + Z = Xt @ YW + + @dace.map(_[0:5, 0:3]) + def summap(i, j): + s >> S(1, lambda x, y: x + y)[0] + z << Z[i, j] + s = log(z + 1) + + sdfg = add_mmul_transpose_log.to_sdfg() + + return ( + SDFGBackwardRunner(sdfg, "S"), + torch_func, + dict( + X=np.random.rand(4, 5).astype(np.float32), + W=np.random.rand(4, 3).astype(np.float32), + Y=np.random.rand(4, 3).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_reduce_node_1_axis_and_none_axis(): + + def torch_func(*, X, Y, W): + + Xt = X.T + YW = torch.sum(W, dim=0) * Y + Z = Xt @ YW + Zl = torch.log(Z + 1) + + S = Zl.sum() + S.backward() + return dict(gradient_X=X.grad, gradient_Y=Y.grad, gradient_W=W.grad) + + @dace.program + def reduce_node_1_axis_and_none_axis(X: dace.float32[4, 5], Y: dace.float32[4, 3], W: dace.float32[7, 4, 3]): + + Xt = np.transpose(X) + YW = np.sum(W, axis=0) * Y + Z = Xt @ YW + + Zl = dace.elementwise(lambda x: log(x + 1), Z) + S = np.sum(Zl) + return S + + sdfg = reduce_node_1_axis_and_none_axis.to_sdfg() + + return ( + SDFGBackwardRunner(sdfg, "__return"), + torch_func, + dict( + X=np.random.rand(4, 5).astype(np.float32), + W=np.random.rand(7, 4, 3).astype(np.float32), + Y=np.random.rand(4, 3).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_reduce_max_simple(): + + def torch_func(*, W): + + Z = torch.max(W, dim=1) + S = Z.values.sum() + S.backward() + return dict(gradient_W=W.grad) + + @dace.program + def reduce_max_simple(W: dace.float32[4, 5]): + + Z = np.max(W, axis=1) + S = np.sum(Z) + return S + + sdfg = reduce_max_simple.to_sdfg() + + return ( + SDFGBackwardRunner(sdfg, "__return"), + torch_func, + dict(W=np.random.rand(4, 5).astype(np.float32)), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_reduce_max_node_1_axis(): + + def torch_func(*, X, Y, W): + + Xt = X.T + YW = torch.min(W, dim=0).values * Y + Z = Xt @ YW + Zl = torch.log(Z + 1) + + S = Zl.sum() + S.backward() + return dict(gradient_X=X.grad, gradient_Y=Y.grad, gradient_W=W.grad) + + @dace.program + def dace_func(X: dace.float64[4, 5], Y: dace.float64[4, 3], W: dace.float64[7, 4, 3]): + + Xt = np.transpose(X) + YW = np.min(W, axis=0) * Y + Z = Xt @ YW + + Zl = dace.elementwise(lambda x: log(x + 1), Z) + S = np.sum(Zl) + return S + + sdfg = dace_func.to_sdfg() + + return ( + SDFGBackwardRunner(sdfg, "__return"), + torch_func, + dict( + X=np.random.rand(4, 5).astype(np.float64), + W=np.random.rand(7, 4, 3).astype(np.float64), + Y=np.random.rand(4, 3).astype(np.float64), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_reshape(): + + @dace.program + def single_state_reshape(inp: dace.float64[9], bias: dace.float64[3], target_shape: dace.int64[2]): + reshaped = dace.define_local([3, 3], dace.float64) + donnx.ONNXReshape(data=inp, shape=target_shape, reshaped=reshaped) + Z = reshaped + bias + Zl = dace.elementwise(lambda x: log(x + 1), Z) + S = np.sum(Zl) + return S + + sdfg = single_state_reshape.to_sdfg(simplify=False) + + sdfg.apply_transformations_repeated([StateFusion]) + + def torch_func(*, inp, bias): + reshaped = torch.reshape(inp, [3, 3]) + + Z = reshaped + bias + Zl = torch.log(Z + 1) + S = Zl.sum() + + S.backward() + return dict(gradient_inp=inp.grad, gradient_bias=bias.grad) + + return ( + SDFGBackwardRunner(sdfg, "__return", simplify=False), + torch_func, + dict( + inp=np.random.rand(9).astype(np.float64), + bias=np.random.rand(3).astype(np.float64), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_reshape_on_memlet_path(): + old_default = donnx.default_implementation + donnx.default_implementation = "pure" + + @dace.program + def single_state_reshape_memlet_path(inp1: dace.float64[9], bias: dace.float64[3], target_shape: dace.int64[2]): + reshaped = dace.define_local([3, 3], dace.float64) + donnx.ONNXReshape(data=inp1, shape=target_shape, reshaped=reshaped) + Z = reshaped + bias + Zl = dace.elementwise(lambda x: log(x + 1), Z) + S = np.sum(Zl) + return S + + sdfg = single_state_reshape_memlet_path.to_sdfg(simplify=False) + + sdfg.expand_library_nodes() + sdfg.apply_transformations_repeated([StateFusion]) + + donnx.default_implementation = old_default + + def torch_func(*, inp1, bias): + reshaped = torch.reshape(inp1, [3, 3]) + + Z = reshaped + bias + Zl = torch.log(Z + 1) + S = Zl.sum() + + S.backward() + return dict(gradient_inp1=inp1.grad, gradient_bias=bias.grad) + + return ( + SDFGBackwardRunner(sdfg, "__return", simplify=False), + torch_func, + dict( + inp1=np.random.rand(9).astype(np.float64), + bias=np.random.rand(3).astype(np.float64), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_reshape_reuse_in_same_state(): + old_default = donnx.default_implementation + donnx.default_implementation = "pure" + + @dace.program + def single_state_reshape_same_state(inp: dace.float64[9], target_shape: dace.int64[2]): + reshaped = dace.define_local([3, 3], dace.float64) + donnx.ONNXReshape(data=inp, shape=target_shape, reshaped=reshaped) + Zl = dace.elementwise(lambda x: log(x + 1), reshaped) + S = np.sum(Zl) + return S + + sdfg = single_state_reshape_same_state.to_sdfg(simplify=False) + + sdfg.expand_library_nodes() + sdfg.apply_transformations_repeated([StateFusion]) + + donnx.default_implementation = old_default + + def torch_func(*, inp): + reshaped = torch.reshape(inp, [3, 3]) + + Z = reshaped + Zl = torch.log(Z + 1) + S = Zl.sum() + + S.backward() + return dict(gradient_inp=inp.grad) + + return ( + SDFGBackwardRunner(sdfg, "__return", simplify=False), + torch_func, + dict(inp=np.random.rand(9).astype(np.float64), ), + ) + + +if __name__ == "__main__": + test_gemm() + test_sum() + test_complex_tasklet() + test_tasklets_only_reuse() + test_tasklets_multioutput() + test_tasklets_only() + test_add_mmul_transpose_log() + test_reduce_node_1_axis_and_none_axis() + test_reduce_max_simple() + test_reduce_max_node_1_axis() + test_reshape() + test_reshape_on_memlet_path() + test_reshape_reuse_in_same_state() diff --git a/tests/autodiff/torch_backward/test_dont_compute_input_grads.py b/tests/autodiff/torch_backward/test_dont_compute_input_grads.py new file mode 100644 index 0000000000..d64a9e0400 --- /dev/null +++ b/tests/autodiff/torch_backward/test_dont_compute_input_grads.py @@ -0,0 +1,68 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +from torch import nn + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_skip_input_grads(use_cpp_dispatcher: bool): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.fc1 = nn.Parameter(torch.rand(10, 10)) + + def forward(self, x): + return x @ self.fc1 + + dace_module = Module() + pt_module = Module() + pt_module.load_state_dict(dace_module.state_dict()) + + shape = [8, 10] + input_value = torch.rand(*shape, dtype=torch.float32) + + pytorch_input = torch.empty( + *shape, + dtype=torch.float32, + requires_grad=False, + ) + pytorch_input.copy_(input_value) + dace_input = torch.empty(*shape, dtype=torch.float32, requires_grad=False) + dace_input.copy_(input_value) + + # TODO: provide a better API for input names + dispatcher_suffix = "cpp" if use_cpp_dispatcher else "ctypes" + dace_module = DaceModule(dace_module, + sdfg_name=f"test_skip_input_grads_{dispatcher_suffix}", + backward=True, + inputs_to_skip=["onnx::MatMul_0"], + compile_torch_extension=use_cpp_dispatcher) + + dy = torch.rand(8, 10) + + dace_output = dace_module(dace_input) + pt_output = pt_module(pytorch_input) + + torch_tensors_close("output", pt_output, dace_output) + + # check that fc1.grad is being computed + dace_output.backward(dy) + pt_output.backward(dy) + torch_tensors_close("param_grad", pt_module.fc1.grad, dace_module.model.fc1.grad) + + # Make sure that input grad is not being computed + assert len(dace_module.backward_sdfg.node(0).sink_nodes()) == 1, \ + f"Expected 1 sink node (no input gradient), got {len(dace_module.backward_sdfg.node(0).sink_nodes())}" + + +if __name__ == "__main__": + test_skip_input_grads(use_cpp_dispatcher=True) + test_skip_input_grads(use_cpp_dispatcher=False) diff --git a/tests/autodiff/torch_backward/test_dropout.py b/tests/autodiff/torch_backward/test_dropout.py new file mode 100644 index 0000000000..7e360ccf26 --- /dev/null +++ b/tests/autodiff/torch_backward/test_dropout.py @@ -0,0 +1,69 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from typing import Literal, Union +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +from torch import nn +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +@pytest.mark.torch +def test_dropout_fwd_training(): + p = 0.5 + module = nn.Dropout(p=p).train() + dace_module = DaceModule(module, + sdfg_name="test_dropout_fwd_training", + dummy_inputs=(torch.ones(10, 10), ), + training=True) + + # dropout will set some of these to zero + test_data = torch.randint(1, 10, (10, 10)).float() + + out = dace_module(torch.clone(test_data)) + zeroed = out == 0 + + scale = 1 / (1 - p) + torch_tensors_close("output", test_data[~zeroed] * scale, out[~zeroed]) + + +@pytest.mark.torch +@pytest.mark.autodiff +@pytest.mark.parametrize("p", [0, 0.99, 0.6, 0.5]) +def test_dropout_bwd(p: Union[float, Literal[0]]): + module = nn.Dropout(p=p).train() + sdfg_name = f"test_dropout_{str(p).replace('.', '_')}_bwd" + dace_module = DaceModule(module, + sdfg_name=sdfg_name, + dummy_inputs=(torch.ones(10, 10), ), + backward=True, + training=True) + + test_data = torch.randint(1, 10, (10, 10)).float() + test_data.requires_grad = True + dy = torch.rand_like(test_data) + + out = dace_module(torch.clone(test_data)) + + zeroed = out == 0 + scale = 1 / (1 - p) + # check that fwd was correct + torch_tensors_close("output", test_data[~zeroed] * scale, out[~zeroed]) + + out.backward(dy) + + # check that the gradient is correct: + zeros = torch.zeros_like(test_data.grad) + # check that zeroed values are zero in the grad + torch_tensors_close("grad_zeroed", zeros[zeroed], test_data.grad[zeroed]) + + # check that non-zeroed values are correct + torch_tensors_close("grad_zeroed", dy[~zeroed] * scale, test_data.grad[~zeroed]) + + +if __name__ == "__main__": + test_dropout_fwd_training() + # Test with different dropout probabilities + for p in [0, 0.99, 0.6, 0.5]: + test_dropout_bwd(p=p) diff --git a/tests/autodiff/torch_backward/test_extremal_reduction_backward.py b/tests/autodiff/torch_backward/test_extremal_reduction_backward.py new file mode 100644 index 0000000000..8758cdec43 --- /dev/null +++ b/tests/autodiff/torch_backward/test_extremal_reduction_backward.py @@ -0,0 +1,160 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed") +import torch + +import dace +from tests.autodiff.test_single_state import SDFGBackwardRunner + + +def run_max_reduction_test(dace_func, torch_func, inputs, rtol=1e-5, atol=1e-5): + sdfg = dace_func.to_sdfg() + runner = SDFGBackwardRunner(sdfg, "__return") + + sdfg_dict = {name: arr.copy() for name, arr in inputs.items()} + torch_dict = {name: torch.tensor(arr.copy(), requires_grad=True) for name, arr in inputs.items()} + + sdfg_results = runner.run(**sdfg_dict) + torch_results = torch_func(**torch_dict) + + for k, v in torch_results.items(): + v = v.detach().numpy() + assert np.allclose(sdfg_results[k], v, rtol=rtol, atol=atol), \ + f"Gradient mismatch for '{k}':\n DaCe: {sdfg_results[k]}\n PyTorch: {v}" + + +@pytest.mark.autodiff +def test_max_single_maximum(): + """Max reduction with single maximum - no ties.""" + + def torch_func(*, W): + Z = torch.amax(W, dim=0) + S = Z.sum() + S.backward() + return dict(gradient_W=W.grad) + + @dace.program + def dace_func(W: dace.float32[4]): + Z = np.max(W, axis=0) + S = np.sum(Z) + return S + + inputs = dict(W=np.array([1.0, 3.0, 2.0, 0.0], dtype=np.float32)) + run_max_reduction_test(dace_func, torch_func, inputs) + + +@pytest.mark.autodiff +def test_max_tied_values_2d(): + """Max reduction with tied values along an axis. + + For input [[1, 3], [3, 2]] with max along axis=0: + - Column 0: max=3 at row 1 only -> grad [0, 1] + - Column 1: max=3 at row 0 only -> grad [1, 0] + """ + + def torch_func(*, W): + Z = torch.amax(W, dim=0) + S = Z.sum() + S.backward() + return dict(gradient_W=W.grad) + + @dace.program + def dace_func(W: dace.float32[2, 2]): + Z = np.max(W, axis=0) + S = np.sum(Z) + return S + + inputs = dict(W=np.array([[1.0, 3.0], [3.0, 2.0]], dtype=np.float32)) + run_max_reduction_test(dace_func, torch_func, inputs) + + +@pytest.mark.autodiff +def test_max_tied_values_same_column(): + """Max reduction with tied values in the same column. + + For input [[3, 1], [3, 2]] with max along axis=0: + - Column 0: max=3 at rows 0 AND 1 -> split grad equally: [0.5, 0.5] + - Column 1: max=2 at row 1 only -> grad [0, 1] + + Expected gradient: [[0.5, 0], [0.5, 1]] + """ + + def torch_func(*, W): + Z = torch.amax(W, dim=0) + S = Z.sum() + S.backward() + return dict(gradient_W=W.grad) + + @dace.program + def dace_func(W: dace.float32[2, 2]): + Z = np.max(W, axis=0) + S = np.sum(Z) + return S + + inputs = dict(W=np.array([[3.0, 1.0], [3.0, 2.0]], dtype=np.float32)) + run_max_reduction_test(dace_func, torch_func, inputs) + + +@pytest.mark.autodiff +def test_max_all_equal_column(): + """Max reduction where entire column has same value. + + For input [[3, 1], [3, 2], [3, 0]] with max along axis=0: + - Column 0: all values are 3 -> split equally: [1/3, 1/3, 1/3] + - Column 1: max=2 at row 1 only -> grad [0, 1, 0] + + Expected gradient: [[1/3, 0], [1/3, 1], [1/3, 0]] + """ + + def torch_func(*, W): + Z = torch.amax(W, dim=0) + S = Z.sum() + S.backward() + return dict(gradient_W=W.grad) + + @dace.program + def dace_func(W: dace.float32[3, 2]): + Z = np.max(W, axis=0) + S = np.sum(Z) + return S + + inputs = dict(W=np.array([[3.0, 1.0], [3.0, 2.0], [3.0, 0.0]], dtype=np.float32)) + run_max_reduction_test(dace_func, torch_func, inputs) + + +@pytest.mark.autodiff +def test_min_tied_values(): + """Min reduction with tied values. + + For input [[1, 2], [1, 3], [2, 1]] with min along axis=0: + - Column 0: min=1 at rows 0 AND 1 -> split: [0.5, 0.5, 0] + - Column 1: min=1 at row 2 only -> grad [0, 0, 1] + + Expected gradient: [[0.5, 0], [0.5, 0], [0, 1]] + """ + + def torch_func(*, W): + Z = torch.amin(W, dim=0) + S = Z.sum() + S.backward() + return dict(gradient_W=W.grad) + + @dace.program + def dace_func(W: dace.float32[3, 2]): + Z = np.min(W, axis=0) + S = np.sum(Z) + return S + + inputs = dict(W=np.array([[1.0, 2.0], [1.0, 3.0], [2.0, 1.0]], dtype=np.float32)) + run_max_reduction_test(dace_func, torch_func, inputs) + + +if __name__ == "__main__": + test_max_single_maximum() + test_max_tied_values_2d() + test_max_tied_values_same_column() + test_max_all_equal_column() + test_min_tied_values() diff --git a/tests/autodiff/torch_backward/test_full_training_graph.py b/tests/autodiff/torch_backward/test_full_training_graph.py new file mode 100644 index 0000000000..9e97b10b65 --- /dev/null +++ b/tests/autodiff/torch_backward/test_full_training_graph.py @@ -0,0 +1,227 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import copy + +import pytest + +import numpy as np + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch + +import dace + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close, tensors_close + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_module(): + gpu = False + module = torch.nn.Sequential(torch.nn.Linear(12, 2, bias=False)) + + torch_module = copy.deepcopy(module) + dace_module = copy.deepcopy(module) + + dace_module = DaceModule(dace_module, + sdfg_name="test_full_training_graph_module", + simplify=False, + backward=True, + training=True, + auto_optimize=False) + + x = torch.randn(8, 12) + + expected_output = torch_module(x) + result = dace_module(x) + torch_tensors_close('output', expected_output, result) + + dc_loss = dace_module(x).sum() + dc_loss.backward() + + pt_loss = torch_module(x).sum() + pt_loss.backward() + + tensors_close("loss", pt_loss, dc_loss) + assert all(hasattr(p, 'grad') and p.grad is not None for p in dace_module.parameters()), \ + "Not all parameters have gradients computed" + + for d, t in zip(dace_module.parameters(), torch_module.parameters()): + torch_tensors_close("param", t.grad, d.grad) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_parse_backward_simple(): + x = torch.randn(10, 5, dtype=torch.float64) + dy = torch.randn(10, dtype=torch.float64) + + @dace.program + def train_step(x: dace.float64[10, 5], dy: dace.float64[10]): + x.requires_grad_() + red = np.add.reduce(x, axis=1) + torch.autograd.backward(red, dy) + return x.grad + + sdfg = train_step.to_sdfg() + sdfg.expand_library_nodes() + sdfg.validate() + + result = sdfg(x.clone(), dy.clone()) + tensors_close('x.grad', dy.reshape(10, 1).expand(10, 5), result) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_parse_backward_scalar(): + x = torch.randn(10, 5, dtype=torch.float64) + + @dace.program + def train_step(x: dace.float64[10, 5]): + x.requires_grad_() + red = np.add.reduce(x, axis=[0, 1]) + torch.autograd.backward(red) + return x.grad + + sdfg = train_step.to_sdfg() + sdfg.expand_library_nodes() + sdfg.validate() + + result = sdfg(x.clone()) + tensors_close('x.grad', 1, result) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_parse_backward_with_forwarding(): + x = torch.randn(10, 5, dtype=torch.float64) + dy = torch.randn(10, dtype=torch.float64) + + @dace.program + def train_step(x: dace.float64[10, 5]): + x.requires_grad_() + y = x + 1 + red = np.add.reduce(x, axis=1, keepdims=True) + z = red * y + loss = np.add.reduce(z, axis=[0, 1]) + torch.autograd.backward(loss) + return x.grad + + def torch_fn(x): + x.requires_grad_() + y = x + 1 + red = x.sum(axis=1, keepdims=True) + z = red * y + loss = z.sum() + loss.backward() + return x.grad + + sdfg = train_step.to_sdfg() + sdfg.expand_library_nodes() + sdfg.validate() + + result = sdfg(x.clone()) + expected = torch_fn(x.clone()) + tensors_close('x.grad', expected, result) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_two_backward_passes(): + + @dace.program + def train_step(x1: dace.float64[10, 5], x2: dace.float64[5], dy: dace.float64[10]): + x1.requires_grad_() + x2.requires_grad_() + + z1 = x1 + 1 + y1 = np.log(z1) + l1 = np.add.reduce(y1, axis=1) + + z2 = x2 * 2 + y2 = np.log(z2) + l2 = y2.sum() + + l2.backward() + l1.backward(dy) + return x1.grad, x2.grad + + def torch_fn(x1, x2, dy): + x1.requires_grad_() + x2.requires_grad_() + z1 = x1 + 1 + y1 = torch.log(z1).sum(axis=1) + + z2 = x2 * 2 + y2 = torch.log(z2).sum() + y2.backward() + y1.backward(dy) + return x1.grad, x2.grad + + sdfg = train_step.to_sdfg() + sdfg.validate() + sdfg.expand_library_nodes() + sdfg.validate() + + x1 = torch.randn(10, 5, dtype=torch.float64) + x2 = torch.randn(5, dtype=torch.float64) + dy = torch.randn(10, dtype=torch.float64) + + r1, r2 = sdfg(x1.clone(), x2.clone(), dy.clone()) + ex_1, ex_2 = torch_fn(x1.clone(), x2.clone(), dy.clone()) + tensors_close('x2.grad', ex_2, r2) + tensors_close('x1.grad', ex_1, r1) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_two_backward_passes_accumulate(): + + @dace.program + def train_step(x: dace.float64[10, 5], dy: dace.float64[10]): + x.requires_grad_() + + z1 = x + 1 + y1 = np.log(z1) + l1 = np.add.reduce(y1, axis=1) + + z2 = x * 2 + y2 = np.log(z2) + l2 = y2.sum() + + l2.backward() + l1.backward(dy) + return x.grad + + def torch_fn(x, dy): + x.requires_grad = True + z1 = x + 1 + y1 = torch.log(z1).sum(axis=1) + + z2 = x * 2 + y2 = torch.log(z2).sum() + y2.backward() + y1.backward(dy) + return x.grad + + sdfg = train_step.to_sdfg() + sdfg.validate() + sdfg.expand_library_nodes() + sdfg.validate() + + x1 = torch.randn(10, 5, dtype=torch.float64) + dy = torch.randn(10, dtype=torch.float64) + + result = sdfg(x1.clone(), dy.clone()) + expected = torch_fn(x1.clone(), dy.clone()) + + tensors_close('x.grad', expected, result) + + +if __name__ == "__main__": + test_module() + test_parse_backward_simple() + test_parse_backward_scalar() + test_parse_backward_with_forwarding() + test_two_backward_passes() + test_two_backward_passes_accumulate() diff --git a/tests/autodiff/torch_backward/test_llama_decoder_backward.py b/tests/autodiff/torch_backward/test_llama_decoder_backward.py new file mode 100644 index 0000000000..79cb8621b1 --- /dev/null +++ b/tests/autodiff/torch_backward/test_llama_decoder_backward.py @@ -0,0 +1,107 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +pytest.importorskip("transformers", + reason="transformers not installed. Please install with: pip install dace[ml-testing]") +import torch +import torch.nn as nn +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaConfig +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +# Create a wrapper module that handles the position embeddings internally +class LlamaDecoderLayerWrapper(nn.Module): + + def __init__(self, decoder_layer, config): + super().__init__() + self.decoder_layer = decoder_layer + self.config = config + + # Create rotary embeddings as part of the wrapper + from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding + self.rotary_emb = LlamaRotaryEmbedding(config) + + def forward(self, hidden_states, attention_mask, position_ids): + # Generate position embeddings + cos, sin = self.rotary_emb(hidden_states, position_ids) + + # Call the decoder layer + outputs = self.decoder_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=(cos, sin), + past_key_value=None, + output_attentions=False, + use_cache=False, + ) + + # Return only the hidden states (first element of the tuple) + return outputs[0] + + +@pytest.mark.xdist_group("large_ML_models") +@pytest.mark.torch +@pytest.mark.autodiff +def test_llama_decoder_backward(): + # Create configuration + config = LlamaConfig( + hidden_size=512, + intermediate_size=1024, + num_attention_heads=8, + num_key_value_heads=8, + max_position_embeddings=128, + rms_norm_eps=1e-5, + rope_theta=10000.0, + attention_dropout=0.0, + hidden_act="silu", + ) + + # Create decoder layer + decoder_layer = LlamaDecoderLayer(config, layer_idx=0) + + # Prepare dummy inputs + batch_size = 2 + seq_length = 128 + + # Create input tensors + hidden_states = torch.randn(batch_size, seq_length, config.hidden_size) + attention_mask = torch.ones(batch_size, 1, seq_length, seq_length) + position_ids = torch.arange(seq_length).unsqueeze(0).expand(batch_size, seq_length) + + # Create wrapped model + wrapped_model = LlamaDecoderLayerWrapper(decoder_layer, config) + + # Avoid the simplify pass since it takes too long for this model + dace_model = DaceModule( + wrapped_model, + sdfg_name="test_llama_decoder_backward", + onnx_simplify=True, + backward=True, + ) + + hidden_states_pt, attention_mask_pt, position_ids_pt = (torch.clone(hidden_states), torch.clone(attention_mask), + torch.clone(position_ids)) + hidden_states_pt.requires_grad = True + + hidden_states_dace, attention_mask_dace, position_ids_dace = (torch.clone(hidden_states), + torch.clone(attention_mask), + torch.clone(position_ids)) + hidden_states_dace.requires_grad = True + + wrapped_model(hidden_states_pt, attention_mask_pt, position_ids_pt).sum().backward() + dace_model(hidden_states_dace, attention_mask_dace, position_ids_dace).sum().backward() + + # Check gradients of the parameters + for (name, dace_param), (pt_name, pt_param) in zip(wrapped_model.named_parameters(), dace_model.named_parameters()): + assert 'model.' + name == pt_name, f"Parameter name mismatch: expected 'model.{name}', got '{pt_name}'" + torch_tensors_close(name, dace_param.grad, pt_param.grad) + + # Check the gradients of the input tensor + torch_tensors_close("hidden_states_pt_grad", hidden_states_pt.grad, hidden_states_dace.grad) + + +if __name__ == "__main__": + test_llama_decoder_backward() diff --git a/tests/autodiff/torch_backward/test_llama_for_causalLM_backward.py b/tests/autodiff/torch_backward/test_llama_for_causalLM_backward.py new file mode 100644 index 0000000000..a1ec63a3c8 --- /dev/null +++ b/tests/autodiff/torch_backward/test_llama_for_causalLM_backward.py @@ -0,0 +1,105 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +pytest.importorskip("transformers", + reason="transformers not installed. Please install with: pip install dace[ml-testing]") +import torch +import torch.nn as nn +from transformers import LlamaForCausalLM, LlamaConfig +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +class LlamaWrapper(nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + self.config = model.config + + def forward(self, input_ids): + # Get the embeddings + inputs_embeds = self.model.model.embed_tokens(input_ids) + + # Create position ids + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) + + # Process through decoder layers + hidden_states = inputs_embeds + + # Create causal mask for attention + causal_mask = torch.triu(torch.ones((seq_length, seq_length), device=input_ids.device), diagonal=1) + causal_mask = causal_mask.masked_fill(causal_mask == 1, float('-inf')) + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) + + # Forward through each layer + for decoder_layer in self.model.model.layers: + # Get rotary embeddings + cos, sin = self.model.model.rotary_emb(hidden_states, position_ids) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=None, + output_attentions=False, + use_cache=False, + position_embeddings=(cos, sin), + ) + hidden_states = layer_outputs[0] + + # Final layer norm + hidden_states = self.model.model.norm(hidden_states) + + # Get logits + logits = self.model.lm_head(hidden_states) + + return logits + + +@pytest.mark.xdist_group("large_ML_models") +@pytest.mark.torch +@pytest.mark.autodiff +@pytest.mark.long +def test_llama_model_backward(): + # Create a small LLaMA configuration + config = LlamaConfig( + vocab_size=32000, + hidden_size=512, + intermediate_size=1024, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=8, + max_position_embeddings=128, + rms_norm_eps=1e-5, + rope_theta=10000.0, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ) + + # Create the full model + model = LlamaForCausalLM(config) + export_seq_length = 16 + export_batch_size = 1 + input = torch.randint(3, config.vocab_size, (export_batch_size, export_seq_length)) + + wrapped_model = LlamaWrapper(model) + + # Avoid the simplify pass since it takes too long for this model + dace_model = DaceModule(wrapped_model, sdfg_name="test_llama_model_backward", backward=True, onnx_simplify=True) + + wrapped_model(input.clone()).sum().backward() + dace_model(input.clone()).sum().backward() + + # Check gradients of the parameters + for (name, dace_param), (pt_name, pt_param) in zip(wrapped_model.named_parameters(), dace_model.named_parameters()): + assert 'model.' + name == pt_name, f"Parameter name mismatch: expected 'model.{name}', got '{pt_name}'" + torch_tensors_close(name, pt_param.grad, dace_param.grad) + + +if __name__ == "__main__": + test_llama_model_backward() diff --git a/tests/autodiff/torch_backward/test_multi_output_ad.py b/tests/autodiff/torch_backward/test_multi_output_ad.py new file mode 100644 index 0000000000..057813add3 --- /dev/null +++ b/tests/autodiff/torch_backward/test_multi_output_ad.py @@ -0,0 +1,64 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_multi_output(use_cpp_dispatcher: bool): + + class Module(torch.nn.Module): + + def forward(self, x): + return x + 1, x * 2 + + module = Module() + + input_value = torch.rand(5, 10, dtype=torch.float32) + + pytorch_input = torch.empty( + 5, + 10, + dtype=torch.float32, + requires_grad=False, + ) + pytorch_input.copy_(input_value) + + dace_input = torch.empty(5, 10, dtype=torch.float32, requires_grad=False) + dace_input.copy_(input_value) + + pytorch_input.requires_grad = True + dace_input.requires_grad = True + + torch_dy = torch.randn(5, 10, dtype=torch.float32) + dace_dy = torch_dy.clone() + + pytorch_y1, pytorch_y2 = module(pytorch_input) + + pytorch_y1.backward(torch_dy) + pytorch_y2.backward(torch_dy) + + dispatcher_suffix = "cpp" if use_cpp_dispatcher else "ctypes" + dace_module = DaceModule( + module, + sdfg_name=f"test_multi_output_ad_{dispatcher_suffix}", + backward=True, + compile_torch_extension=use_cpp_dispatcher, + ) + + dace_y1, dace_y2 = dace_module(dace_input) + + dace_y1.backward(dace_dy, retain_graph=True) + dace_y2.backward(dace_dy) + + torch_tensors_close("grad", pytorch_input.grad, dace_input.grad) + + +if __name__ == "__main__": + test_multi_output(use_cpp_dispatcher=True) + test_multi_output(use_cpp_dispatcher=False) diff --git a/tests/autodiff/torch_backward/test_pytorch.py b/tests/autodiff/torch_backward/test_pytorch.py new file mode 100644 index 0000000000..05d2f6948a --- /dev/null +++ b/tests/autodiff/torch_backward/test_pytorch.py @@ -0,0 +1,305 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import numpy as np +import pytest +import copy + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +import torch.nn as nn +import torch.nn.functional as F + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +@pytest.mark.torch +@pytest.mark.autodiff +def run_pytorch_module( + module: torch.nn.Module, + sdfg_name: str, + shape: tuple = None, + use_max: bool = False, + auto_optimize: bool = False, + rtol: float = 1e-4, + atol: float = 1e-3, + post_onnx_hooks: list = None, +): + shape = shape or (3, 5) + + pt_model_for_dace = copy.deepcopy(module) + + input_value = torch.rand(*shape, dtype=torch.float32) + + pytorch_input = torch.empty( + *shape, + dtype=torch.float32, + requires_grad=False, + ) + pytorch_input.copy_(input_value) + + dace_input = torch.empty(*shape, dtype=torch.float32, requires_grad=False) + dace_input.copy_(input_value) + + pytorch_input.requires_grad = True + dace_input.requires_grad = True + + if use_max: + pytorch_s = module(pytorch_input).max() + else: + pytorch_s = module(pytorch_input).sum() + pytorch_s.backward() + + dace_module = DaceModule( + pt_model_for_dace, + sdfg_name=sdfg_name, + simplify=False, + backward=True, + auto_optimize=auto_optimize, + compile_torch_extension=True, + ) + if post_onnx_hooks is not None: + for i, h in enumerate(post_onnx_hooks): + dace_module.append_post_onnx_hook(str(i), h) + + if use_max: + dace_s = dace_module(dace_input).max() + else: + dace_s = dace_module(dace_input).sum() + dace_s.backward() + torch_tensors_close("grad", pytorch_input.grad, dace_input.grad, rtol=rtol, atol=atol) + + for (name, dace_param), (pt_name, pt_param) in zip(module.named_parameters(), dace_module.named_parameters()): + assert 'model.' + name == pt_name + torch_tensors_close(name, pt_param.grad, dace_param.grad, rtol=rtol, atol=atol) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_simple(): + + class Module(torch.nn.Module): + + def forward(self, x): + x = torch.sqrt(x) + x = torch.log(x) + return x + + run_pytorch_module(Module(), sdfg_name="test_simple") + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_repeated(): + + class Module(torch.nn.Module): + + def forward(self, x): + x = torch.sqrt(x) + x = torch.sqrt(x) + return x + + run_pytorch_module(Module(), sdfg_name="test_repeated") + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_softmax(): + + class Module(torch.nn.Module): + + def forward(self, x): + x = F.softmax(x, dim=1) + return x + + run_pytorch_module(Module(), sdfg_name="test_softmax", use_max=True) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_reshape_on_memlet_path(): + # required test: this function in a nn.Module, with apply simplify so that the reshape is + # inlined and copy is removed + class Module(torch.nn.Module): + + def forward(self, x): + reshaped = torch.reshape(x + 1, [3, 3]) + return torch.log(reshaped) + torch.reshape(torch.tensor([[3, 2, 1]], device=reshaped.device), [3]) + + run_pytorch_module(Module(), sdfg_name="test_reshape_on_memlet_path", shape=(9, )) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_weights_ln(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.fc1 = nn.Linear(784, 120) + self.fc2 = nn.Linear(120, 32) + self.ln = nn.LayerNorm(32) + self.fc3 = nn.Linear(32, 10) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.ln(x) + x = self.fc3(x) + return x + + run_pytorch_module(Module(), sdfg_name="test_weights_ln", shape=(4, 784), auto_optimize=False) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_layernorm(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.ln = nn.LayerNorm(3) + + def forward(self, x): + return self.ln(x) + + run_pytorch_module(Module(), sdfg_name="test_layernorm", shape=(2, 3), use_max=True, atol=1e-2) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_weights(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.fc1 = nn.Linear(784, 120) + self.fc2 = nn.Linear(120, 32) + self.fc3 = nn.Linear(32, 10) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + run_pytorch_module(Module(), sdfg_name="test_weights", shape=(4, 784), use_max=False, auto_optimize=False) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_nested_gradient_summation(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.fc1 = nn.Parameter(torch.rand(10, 10)) + + def forward(self, x): + y = x @ self.fc1 + z = x * 2 + return z + y + + run_pytorch_module(Module(), + sdfg_name="test_nested_gradient_summation", + shape=(4, 10), + use_max=False, + auto_optimize=False) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_trans_add(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + + def forward(self, x): + x = x + 1 + x = torch.transpose(x.reshape(4, 4), 1, 0) + return x + + run_pytorch_module(Module(), sdfg_name="test_trans_add", shape=(16, ), use_max=False) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_batched_matmul(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.fc1 = nn.Parameter(torch.ones([10, 5, 3])) + + def forward(self, x): + return self.fc1 @ x + + run_pytorch_module(Module(), sdfg_name="test_batched_matmul", use_max=False, auto_optimize=False) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_scalar_forwarding(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.factor = nn.Parameter(torch.ones(())) + + def forward(self, x): + return self.factor * x + + run_pytorch_module(Module(), sdfg_name="test_scalar_forwarding", use_max=False, auto_optimize=False) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_scalar_buffer(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.register_buffer("factor", torch.tensor(2)) + + def forward(self, x): + return self.factor * x + + run_pytorch_module(Module(), sdfg_name="test_scalar_buffer", use_max=False) + + +@pytest.mark.torch +@pytest.mark.autodiff +@pytest.mark.skip(reason="Requires pure implementation of expand") +def test_simple_broadcasted_mul(): + + class Module(torch.nn.Module): + + def forward(self, x): + y = x.sum(axis=0) + return x * y + + run_pytorch_module(Module(), sdfg_name="test_simple_broadcasted_mul") + + +if __name__ == "__main__": + test_simple() + test_repeated() + test_softmax() + test_reshape_on_memlet_path() + test_weights_ln() + test_layernorm() + test_weights() + test_nested_gradient_summation() + test_trans_add() + test_batched_matmul() + test_scalar_forwarding() + test_scalar_buffer() + # test_simple_broadcasted_mul is skipped diff --git a/tests/autodiff/torch_backward/test_training.py b/tests/autodiff/torch_backward/test_training.py new file mode 100644 index 0000000000..0858288145 --- /dev/null +++ b/tests/autodiff/torch_backward/test_training.py @@ -0,0 +1,124 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import os +import copy +import pytest + +import numpy as np + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +pytest.importorskip("transformers", + reason="transformers not installed. Please install with: pip install dace[ml-testing]") +import torch +from torch import nn, optim +from transformers import BertConfig +from transformers.models.bert.modeling_bert import BertLayer + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +def training_step( + dace_model: torch.nn.Module, + pt_model: torch.nn.Module, + train_batch: tuple, + sdfg_name: str, + train_criterion: torch.nn.Module = None, +): + + # Copy over the weights + dace_model.load_state_dict(pt_model.state_dict()) + for dace_value, value in zip(pt_model.state_dict().values(), dace_model.state_dict().values()): + assert torch.allclose(dace_value, value), "State dict copy verification failed" + + dace_model = DaceModule(dace_model, sdfg_name=sdfg_name, backward=True, simplify=True, training=True) + + x, y = train_batch + + train_criterion = train_criterion or nn.NLLLoss() + + pt_loss = train_criterion(pt_model(x), y) + + dace_output = dace_model(x) + dace_loss = train_criterion(dace_output, y) + + diff = abs(pt_loss.item() - dace_loss.item()) / pt_loss.item() + assert diff < 1e-5, f"Loss mismatch: relative difference {diff:.2e} exceeds tolerance 1e-5" + + pt_loss.backward() + dace_loss.backward() + + for (name, dace_param), (pt_name, pt_param) in zip(pt_model.named_parameters(), dace_model.named_parameters()): + assert 'model.' + name == pt_name, f"Parameter name mismatch: expected 'model.{name}', got '{pt_name}'" + torch_tensors_close(name, pt_param.grad, dace_param.grad) + + optimizer = optim.SGD(pt_model.parameters(), lr=0.001) + dace_optimizer = optim.SGD(dace_model.parameters(), lr=0.001) + optimizer.step() + dace_optimizer.step() + + for (name, dace_param), (pt_name, pt_param) in zip(pt_model.named_parameters(), dace_model.named_parameters()): + assert 'model.' + name == pt_name, f"Parameter name mismatch after optimizer step: expected 'model.{name}', got '{pt_name}'" + torch_tensors_close(name, pt_param.detach(), dace_param.detach()) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_mnist(): + input_size = 784 + hidden_sizes = [128, 64] + output_size = 10 + + # initialize modules + # yapf: disable + model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]), + nn.ReLU(), + nn.Linear(hidden_sizes[0], hidden_sizes[1]), + nn.ReLU(), + nn.Linear(hidden_sizes[1], output_size), + nn.LayerNorm(output_size), + nn.LogSoftmax(dim=1)) + + dace_model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]), + nn.ReLU(), + nn.Linear(hidden_sizes[0], hidden_sizes[1]), + nn.ReLU(), + nn.Linear(hidden_sizes[1], output_size), + nn.LayerNorm(output_size), + nn.LogSoftmax(dim=1)) + + # check forward pass using loss + images = torch.randn(64, 784) + labels = torch.randint(0, 10, [64], dtype=torch.long) + + training_step(dace_model, model, (images, labels), sdfg_name="test_mnist_training") + +@pytest.mark.xdist_group("large_ML_models") +@pytest.mark.torch +@pytest.mark.autodiff +@pytest.mark.skip(reason="Requires pure implementation of expand") +def test_bert(): + batch_size = 2 + seq_len = 512 + hidden_size = 768 + + class BertTokenSoftmaxClf(nn.Module): + + def __init__(self): + super(BertTokenSoftmaxClf, self).__init__() + self.bert = BertLayer(BertConfig(hidden_act="relu")).eval() + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, x): + embs = self.bert(x)[0] + return self.sm(embs.sum(dim=-1)) + + # check forward pass using loss + input = torch.randn([batch_size, seq_len, hidden_size]) + labels = torch.tensor([0, 123], dtype=torch.long) + + training_step(BertTokenSoftmaxClf(), BertTokenSoftmaxClf(), (input, labels), sdfg_name="test_bert_training") + + +if __name__ == "__main__": + test_mnist() + # test_bert is skipped diff --git a/tests/blas/nodes/axpy_test.py b/tests/blas/nodes/axpy_test.py index 297961b78c..c7feb2029f 100755 --- a/tests/blas/nodes/axpy_test.py +++ b/tests/blas/nodes/axpy_test.py @@ -9,11 +9,10 @@ import random import dace -from dace.fpga_testing import fpga_test from dace.memlet import Memlet import dace.libraries.blas as blas -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.transformation.interstate import InlineSDFG from dace.transformation.dataflow import StreamingMemory from dace.libraries.standard.memory import aligned_ndarray @@ -35,22 +34,13 @@ def run_test(configs, target): ref_result = reference_result(x, y_ref, a) - if target == "fpga_stream": - sdfg = stream_fpga_graph(veclen, dtype, "fpga", i) - elif target == "fpga_array": - sdfg = fpga_graph(veclen, dtype, "fpga", i) - else: - sdfg = pure_graph(veclen, dtype, "pure", i) + sdfg = pure_graph(veclen, dtype, "pure", i) program = sdfg.compile() with dace.config.set_temporary('compiler', 'allow_view_arguments', value=True): - if target in ["fpga_stream", "fpga_array"]: - program(x=x, y=y, a=a, n=np.int32(n)) - ref_norm = np.linalg.norm(y - ref_result) / n - else: - program(x=x, y=y, a=a, n=np.int32(n)) - ref_norm = np.linalg.norm(y - ref_result) / n + program(x=x, y=y, a=a, n=np.int32(n)) + ref_norm = np.linalg.norm(y - ref_result) / n if ref_norm >= 1e-5: raise ValueError(f"Failed validation for target {target}.") @@ -100,47 +90,5 @@ def test_pure(): run_test(configs, "pure") -def fpga_graph(veclen, dtype, test_case, expansion): - sdfg = pure_graph(veclen, dtype, test_case, expansion) - sdfg.apply_transformations_repeated([FPGATransformSDFG, InlineSDFG]) - return sdfg - - -def stream_fpga_graph(veclen, precision, test_case, expansion): - sdfg = fpga_graph(veclen, precision, test_case, expansion) - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG, StreamingMemory], [{}, {"storage": dace.StorageType.FPGA_Local}]) - return sdfg - - -# TODO: Investigate and re-enable if possible. -@pytest.mark.skip(reason="Unexplained CI Regression") -@fpga_test() -def test_axpy_fpga_array(): - configs = [(0.5, 1, dace.float32), (1.0, 4, dace.float64)] - return run_test(configs, "fpga_array") - - -# TODO: Investigate and re-enable if possible. -@pytest.mark.skip(reason="Unexplained CI Regression") -@fpga_test() -def test_axpy_fpga_stream(): - configs = [(0.5, 1, dace.float32), (1.0, 4, dace.float64)] - return run_test(configs, "fpga_stream") - - if __name__ == "__main__": - - cmdParser = argparse.ArgumentParser(allow_abbrev=False) - - cmdParser.add_argument("--target", dest="target", default="pure") - - args = cmdParser.parse_args() - - if args.target == "fpga": - test_axpy_fpga_array(None) - test_axpy_fpga_stream(None) - elif args.target == "pure": - test_pure() - else: - raise RuntimeError(f"Unknown target \"{args.target}\".") + test_pure() diff --git a/tests/blas/nodes/dot_test.py b/tests/blas/nodes/dot_test.py index 8f87c24240..c06f3a9eea 100755 --- a/tests/blas/nodes/dot_test.py +++ b/tests/blas/nodes/dot_test.py @@ -9,11 +9,9 @@ import dace from dace.memlet import Memlet -from dace.fpga_testing import xilinx_test, intel_fpga_test import dace.libraries.blas as blas -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory +from dace.transformation.interstate import InlineSDFG from dace.config import set_temporary @@ -48,23 +46,9 @@ def pure_graph(implementation, dtype, veclen): return sdfg -def fpga_graph(implementation, dtype, veclen): - sdfg = pure_graph(implementation, dtype, veclen) - sdfg.apply_transformations_repeated([FPGATransformSDFG, InlineSDFG]) - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG, StreamingMemory], [{}, {"storage": dace.StorageType.FPGA_Local}]) - return sdfg - - def run_test(target, size, vector_length): if target == "pure": sdfg = pure_graph("pure", dace.float32, vector_length) - elif target == "intel_fpga": - dace.Config.set("compiler", "fpga", "vendor", value="intel_fpga") - sdfg = fpga_graph("FPGA_Accumulate", dace.float32, vector_length) - elif target == "xilinx": - dace.Config.set("compiler", "fpga", "vendor", value="xilinx") - sdfg = fpga_graph("FPGA_PartialSums", dace.float32, vector_length) else: print(f"Unsupported target: {target}") exit(-1) @@ -96,25 +80,6 @@ def test_dot_pure(): assert isinstance(run_test("pure", 64, 1), dace.SDFG) -# TODO: Refactor to use assert or return True/False (pytest deprecation of returning non-booleans) -@xilinx_test() -def test_dot_xilinx(): - return run_test("xilinx", 64, 16) - - -# TODO: Refactor to use assert or return True/False (pytest deprecation of returning non-booleans) -@xilinx_test() -def test_dot_xilinx_decoupled(): - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - return run_test("xilinx", 64, 16) - - -# TODO: Refactor to use assert or return True/False (pytest deprecation of returning non-booleans) -@intel_fpga_test() -def test_dot_intel_fpga(): - return run_test("intel_fpga", 64, 16) - - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("N", type=int, nargs="?", default=64) diff --git a/tests/blas/nodes/gemv_test.py b/tests/blas/nodes/gemv_test.py index 2915fdc57f..b745b41f2e 100644 --- a/tests/blas/nodes/gemv_test.py +++ b/tests/blas/nodes/gemv_test.py @@ -5,9 +5,7 @@ from dace.memlet import Memlet import dace.libraries.blas as blas -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory +from dace.transformation.interstate import InlineSDFG def pure_graph(dtype, transposed, expansion, veclen, alpha, beta, expansion_args=None): @@ -47,18 +45,6 @@ def pure_graph(dtype, transposed, expansion, veclen, alpha, beta, expansion_args return sdfg -def fpga_graph(dtype, transposed, expansion, veclen, alpha, beta, tile_size_x, tile_size_y): - sdfg = pure_graph(dtype, transposed, expansion, veclen, alpha, beta, { - "tile_size_x": tile_size_x, - "tile_size_y": tile_size_y - }) - sdfg.apply_transformations_repeated([FPGATransformSDFG, InlineSDFG]) - - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG, StreamingMemory], [{}, {"storage": dace.StorageType.FPGA_Local}]) - return sdfg - - def run_gemv(target: str, n: int, m: int, @@ -71,26 +57,6 @@ def run_gemv(target: str, beta = 0 # TODO: GEMV is not currently implemented for beta != 0 if target == "pure": sdfg = pure_graph(dace.float32, transposed, "pure", vectorize, alpha, beta) - elif target == "tiles_by_column": - if not transposed and vectorize > 1: - raise NotImplementedError("Non-transposed vectorized tile-by-column NYI.") - sdfg = fpga_graph(dace.float32, - transposed, - "FPGA_TilesByColumn", - vectorize, - alpha, - beta, - tile_size_x=tile_size_x, - tile_size_y=tile_size_y) - elif target == "accumulate": - sdfg = fpga_graph(dace.float32, - transposed, - "FPGA_Accumulate", - vectorize, - alpha, - beta, - tile_size_x=tile_size_x, - tile_size_y=tile_size_y) else: raise ValueError("Unsupported target") @@ -115,16 +81,6 @@ def test_pure(): run_gemv("pure", 256, 512, transposed=True) -@fpga_test() -def test_gemv_fpga_tiles_by_column(): - return run_gemv("tiles_by_column", 256, 512, transposed=True, vectorize=4) - - -@fpga_test() -def test_gemv_fpga_accumulate(): - return run_gemv("accumulate", 256, 512, vectorize=4) - - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("M", type=int, nargs="?", default=256) diff --git a/tests/blas/nodes/ger_test.py b/tests/blas/nodes/ger_test.py index 096d44786f..efd91a69c0 100755 --- a/tests/blas/nodes/ger_test.py +++ b/tests/blas/nodes/ger_test.py @@ -8,12 +8,9 @@ import dace import dace.libraries.blas as blas -from dace.fpga_testing import fpga_test from dace.libraries.standard.memory import aligned_ndarray from dace.memlet import Memlet -from dace.transformation.dataflow.streaming_memory import StreamingMemory from dace.transformation.interstate.sdfg_nesting import InlineSDFG -from dace.transformation.interstate.fpga_transform_sdfg import FPGATransformSDFG def pure_graph(implementation, dtype, veclen): @@ -49,39 +46,6 @@ def pure_graph(implementation, dtype, veclen): return ger_node, state, sdfg -def fpga_graph(dtype, veclen, tile_size_x, tile_size_y): - ger_node, state, sdfg = pure_graph("FPGA", dtype, veclen) - ger_node.expand(sdfg, state, tile_size_x=tile_size_x, tile_size_y=tile_size_y) - sdfg.apply_transformations_repeated([FPGATransformSDFG, InlineSDFG]) - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG, StreamingMemory], [{}, {"storage": dace.StorageType.FPGA_Local}]) - return sdfg - - -def run_test(ger, target): - - x = np.ndarray(m, dtype=np.float32) - y = np.ndarray(n, dtype=np.float32) - A = np.ndarray((m, n), dtype=np.float32) - res = A.copy() - ref = res.copy() - - x[:] = np.random.rand(m).astype(np.float32) - y[:] = np.random.rand(n).astype(np.float32) - A[:] = np.random.rand(m, n).astype(np.float32) - - ger(alpha=alpha, x=x, y=y, A=A, res=res, m=m, n=n) - - ref = scipy.linalg.blas.sger(alpha=alpha, x=x, y=y, a=A) - - diff = np.linalg.norm(np.subtract(res, ref)) - if diff >= args.eps * n * m: - raise RuntimeError("Unexpected result returned from ger rank 1 operation: " - "got:\n{}\nexpected:\n{} on {}".format(A, ref, target)) - else: - print("Ok") - - def run_ger(target: str, n: int, m: int, @@ -95,8 +59,6 @@ def run_ger(target: str, ger_node, state, sdfg = pure_graph("pure", dace.float32, veclen) ger_node.expand(sdfg, state) sdfg.apply_transformations_repeated([InlineSDFG]) - elif target == "fpga": - sdfg = fpga_graph(dace.float32, veclen, tile_size_x, tile_size_y) else: raise ValueError("Unsupported target") @@ -126,11 +88,6 @@ def test_ger_pure(): run_ger("pure", 256, 512, 16, 32) -@fpga_test() -def test_ger_fpga(): - return run_ger("fpga", 256, 512, 16, 32) - - if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/tests/codegen/external_memory_test.py b/tests/codegen/external_memory_test.py index 169e050914..47eac55ff3 100644 --- a/tests/codegen/external_memory_test.py +++ b/tests/codegen/external_memory_test.py @@ -30,7 +30,7 @@ def tester(a: dace.float64[N]): a = np.random.rand(20) if symbolic: - extra_args = dict(a=a, N=20) + extra_args = dict(N=20) else: extra_args = {} diff --git a/tests/codegen/unroller_general_test.py b/tests/codegen/unroller_general_test.py deleted file mode 100644 index ba9aad8c62..0000000000 --- a/tests/codegen/unroller_general_test.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from dace.sdfg.sdfg import InterstateEdge -from dace import subsets as sbs, dtypes, memlet as mem -import dace -import numpy as np - - -def create_deeply_nested_sdfg(): - sdfg = dace.SDFG("deepnest_test") - state: dace.SDFGState = sdfg.add_state("init") - xarr = state.add_array("x", [4, 100], dace.float32) - yarr = state.add_array("y", [4, 100], dace.float32) - - topMapEntry, topMapExit = state.add_map("topmap", dict(k="0:2")) - topMapEntry.schedule = dtypes.ScheduleType.Unrolled - - nsdfg = dace.SDFG("nest") - nstate = nsdfg.add_state("nested_state", True) - xRead = nstate.add_array("xin", [4, 100], dace.float32) - xWrite = nstate.add_array("xout", [4, 100], dace.float32) - mapEntry, mapExit = nstate.add_map("map1", dict(w="0:2")) - mapEntry.schedule = dtypes.ScheduleType.Unrolled - noUnrollEntry, noUnrollExit = nstate.add_map("map2", dict(i="0:100")) - nope = nstate.add_tasklet("nop", dict(_in=None), dict(_out=None), "_out = _in") - inputMem = mem.Memlet("xin[2*k+w, i]") - outputMem = mem.Memlet("xout[2*k+w, i]") - nstate.add_memlet_path( - xRead, - mapEntry, - noUnrollEntry, - nope, - memlet=inputMem, - dst_conn="_in", - ) - nstate.add_memlet_path( - nope, - noUnrollExit, - mapExit, - xWrite, - memlet=outputMem, - src_conn="_out", - ) - - nstate2 = nsdfg.add_state("second_nest") - tasklet = nstate2.add_tasklet("overwrite", set(), set(["_out"]), "_out = 15.0") - xWrite2 = nstate2.add_write("xout") - nstate2.add_memlet_path( - tasklet, - xWrite2, - memlet=mem.Memlet("xout[mpt, 0]"), - src_conn="_out", - ) - - nsdfg.add_edge(nstate, nstate2, InterstateEdge(None, dict(mpt="k"))) - nsdfg_node = state.add_nested_sdfg(nsdfg, set(["xin"]), set(['xout'])) - nsdfg_node.unique_name = "SomeUniqueName" - - state.add_memlet_path( - xarr, - topMapEntry, - nsdfg_node, - memlet=mem.Memlet.from_array("x", sdfg.arrays["x"]), - dst_conn="xin", - ) - state.add_memlet_path( - nsdfg_node, - topMapExit, - yarr, - memlet=mem.Memlet.from_array("y", sdfg.arrays["y"]), - src_conn="xout", - ) - - return sdfg - - -def test_unrolled_deeply_nested(): - sdfg = create_deeply_nested_sdfg() - passed = np.full((4, 100), 42.0, dtype=np.float32) - returns = np.zeros((4, 100), np.float32) - sdfg(x=passed, y=returns) - expected = passed - expected[1, 0] = 15.0 - expected[0, 0] = 15.0 - assert (np.allclose(expected, returns, 1e-6)) - - -def create_simple_unrolled_sdfg(): - - @dace.program - def ucopy(input: dace.float32[4], output: dace.float32[4]): - for i in dace.map[0:4]: - output[i] = input[i] - - sdfg = ucopy.to_sdfg() - for node in sdfg.states()[0].nodes(): - if (isinstance(node, dace.sdfg.nodes.MapEntry)): - node.schedule = dace.ScheduleType.Unrolled - return sdfg - - -def test_unrolled_simple_map(): - sdfg = create_simple_unrolled_sdfg() - passed = np.full((4), 42.0, dtype=np.float32) - returns = np.zeros((4), np.float32) - sdfg(input=passed, output=returns) - assert (np.allclose(passed, returns, 1e-6)) - - -if __name__ == "__main__": - test_unrolled_deeply_nested() - test_unrolled_simple_map() diff --git a/tests/codegen/unroller_test.py b/tests/codegen/unroller_test.py deleted file mode 100644 index 03cfa0d908..0000000000 --- a/tests/codegen/unroller_test.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import dace -import numpy as np -import unittest - - -@dace.program -def Copy(output: dace.int32[5], input: dace.int32[5]): - - @dace.map - def mytasklet(i: _[0:5]): - inp << input[i] - out >> output[i] - - out = inp - - -class UnrollerTest(unittest.TestCase): - - def test_unroller(self): - sdfg = Copy.to_sdfg() - - # Transform map to unrolled - for state in sdfg.nodes(): - for node in state.nodes(): - if isinstance(node, dace.sdfg.nodes.MapEntry): - node.schedule = dace.ScheduleType.Unrolled - - input = np.ones([5], dtype=np.int32) - output = np.zeros([5], dtype=np.int32) - sdfg(output=output, input=input) - - self.assertTrue((output == input).all()) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/conftest.py b/tests/conftest.py index 57f611ce66..8fe2fb56f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,3 +13,14 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): if config.option.markexpr == 'mpi': if exitstatus in (pytest.ExitCode.TESTS_FAILED, pytest.ExitCode.INTERNAL_ERROR, pytest.ExitCode.INTERRUPTED): os._exit(1) + + +def pytest_generate_tests(metafunc): + """ + This method sets up the parametrizations for the custom fixtures + """ + if "use_cpp_dispatcher" in metafunc.fixturenames: + metafunc.parametrize("use_cpp_dispatcher", [ + pytest.param(True, id="use_cpp_dispatcher"), + pytest.param(False, id="no_use_cpp_dispatcher"), + ]) diff --git a/tests/fortran/parent_test.py b/tests/fortran/parent_test.py index 5369e58ac3..d4619103ae 100644 --- a/tests/fortran/parent_test.py +++ b/tests/fortran/parent_test.py @@ -29,7 +29,7 @@ def test_fortran_frontend_parent(): ast_transforms.ParentScopeAssigner().visit(ast) assert ast.parent is None - assert ast.main_program.parent == None + assert ast.main_program.parent is None main_program = ast.main_program # Both executed lines @@ -43,7 +43,7 @@ def test_fortran_frontend_parent(): for subroutine in ast.subroutine_definitions: - assert subroutine.parent == None + assert subroutine.parent is None assert subroutine.execution_part.parent == subroutine for execution in subroutine.execution_part.execution: assert execution.parent == subroutine @@ -78,10 +78,10 @@ def test_fortran_frontend_module(): ast_transforms.ParentScopeAssigner().visit(ast) assert ast.parent is None - assert ast.main_program.parent == None + assert ast.main_program.parent is None module = ast.modules[0] - assert module.parent == None + assert module.parent is None specification = module.specification_part.specifications[0] assert specification.parent == module diff --git a/tests/fpga/auto_opt_fpga_test.py b/tests/fpga/auto_opt_fpga_test.py deleted file mode 100644 index 8f059c96e2..0000000000 --- a/tests/fpga/auto_opt_fpga_test.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" FPGA Tests for Auto Optimization """ - -import dace -import numpy as np -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG -from dace.transformation.auto import auto_optimize as aopt -from dace.transformation.auto import fpga as fpga_auto_opt - -N = dace.symbol('N') - - -@fpga_test() -def test_global_to_local(): - """ - Tests global_to_local optimization - """ - - @dace.program - def global_to_local(alpha: dace.float32, B: dace.float32[N]): - tmp = alpha / 2 - return tmp * B - - size = 8 - - alpha = 0.5 - B = np.random.rand(size).astype(np.float32) - - sdfg = global_to_local.to_sdfg() - - aopt.auto_optimize(sdfg, dace.DeviceType.FPGA) - - # Check that transformation has been actually applied - # There should be only one transient among the sdfg arrays and it must have Local Storage Type - candidate = None - for name, array in sdfg.arrays.items(): - if array.transient: - assert array.storage == dace.dtypes.StorageType.FPGA_Local - candidate = name - break - - assert candidate is not None - - # Check that all access nodes that refer to this container have also been updated - for node, graph in sdfg.all_nodes_recursive(): - if isinstance(node, dace.sdfg.nodes.AccessNode): - trace = dace.sdfg.utils.trace_nested_access(node, graph, graph.parent) - - for (_, acc_node), memlet_trace, state_trace, sdfg_trace in trace: - if acc_node is not None and acc_node.data == candidate: - nodedesc = node.desc(graph) - assert nodedesc.storage == dace.dtypes.StorageType.FPGA_Local - - C = sdfg(alpha=alpha, B=B, N=size) - ref = alpha / 2 * B - assert np.allclose(ref, C) - - return sdfg - - -@fpga_test() -def test_rr_interleave(): - """ - Tests RR interleaving of containers to memory banks - """ - - @dace.program - def rr_interleave(A: dace.float32[8], B: dace.float32[8], C: dace.float32[8]): - return A + B + C - - A = np.random.rand(8).astype(np.float32) - B = np.random.rand(8).astype(np.float32) - C = np.random.rand(8).astype(np.float32) - - sdfg = rr_interleave.to_sdfg() - sdfg.apply_transformations([FPGATransformSDFG]) - - #specifically run the the interleave transformation - allocated = fpga_auto_opt.fpga_rr_interleave_containers_to_banks(sdfg) - - # There will be 5 arrays (one is a temporary containing A + B) - assert allocated == [2, 1, 1, 1] - - R = sdfg(A=A, B=B, C=C) - assert np.allclose(A + B + C, R) - - return sdfg - - -if __name__ == "__main__": - test_global_to_local(8) - test_rr_interleave() diff --git a/tests/fpga/autorun_test.py b/tests/fpga/autorun_test.py deleted file mode 100644 index 5e52a953dc..0000000000 --- a/tests/fpga/autorun_test.py +++ /dev/null @@ -1,111 +0,0 @@ -import argparse -import dace -import numpy as np -import re -from dace.fpga_testing import intel_fpga_test - -DTYPE = dace.float32 - - -def make_sdfg(): - - N = dace.symbol("N", DTYPE) - P = dace.symbol("P", DTYPE) - - sdfg = dace.SDFG("autorun_test") - - pre_state = sdfg.add_state("host_to_device") - state = sdfg.add_state("compute") - post_state = sdfg.add_state("device_to_host") - - sdfg.add_edge(pre_state, state, dace.InterstateEdge()) - sdfg.add_edge(state, post_state, dace.InterstateEdge()) - - sdfg.add_array("arr_host", (N, ), DTYPE) - sdfg.add_array("arr", (N, ), DTYPE, storage=dace.StorageType.FPGA_Global, transient=True) - - # Copy from host to device - pre_host = pre_state.add_read("arr_host") - pre_device = pre_state.add_write("arr") - pre_state.add_memlet_path(pre_host, pre_device, memlet=dace.Memlet("arr[0:N]")) - - # Copy from device to host - post_device = post_state.add_read("arr") - post_host = post_state.add_write("arr_host") - post_state.add_memlet_path(post_device, post_host, memlet=dace.Memlet("arr_host[0:N]")) - - sdfg.add_stream("pipe_in", DTYPE, storage=dace.StorageType.FPGA_Local, transient=True) - - # Read from memory into a stream - memory_read = state.add_read("arr") - pipe_in_write = state.add_write("pipe_in") - state.add_memlet_path(memory_read, pipe_in_write, memlet=dace.Memlet("arr[0:N]", other_subset="0")) - - sdfg.add_stream("pipes_systolic", DTYPE, shape=(P + 1, ), storage=dace.StorageType.FPGA_Local, transient=True) - - # Simple processing element that can be autorun - pipe_in_read = state.add_read("pipe_in") - entry_add, exit_add = state.add_map("add", {"i": "0:N"}, schedule=dace.ScheduleType.FPGA_Device) - tasklet_add = state.add_tasklet("add", {"val_in"}, {"val_out"}, "val_out = val_in + 9") - state.add_memlet_path(pipe_in_read, entry_add, tasklet_add, dst_conn="val_in", memlet=dace.Memlet("pipe_in[0]")) - pipe_systolic_write_head = state.add_write("pipes_systolic") - state.add_memlet_path(tasklet_add, - exit_add, - pipe_systolic_write_head, - src_conn="val_out", - memlet=dace.Memlet("pipes_systolic[0]")) - - # Systolic array which can be autorun - unroll_entry, unroll_exit = state.add_map("systolic_array", {"p": "0:P"}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - pipe_unroll_read = state.add_read("pipes_systolic") - state.add_memlet_path(unroll_entry, pipe_unroll_read, memlet=dace.Memlet()) - systolic_entry, systolic_exit = state.add_map("add_systolic", {"i": "0:N"}, schedule=dace.ScheduleType.FPGA_Device) - systolic_tasklet = state.add_tasklet("add_systolic", {"val_in"}, {"val_out"}, "val_out = 2 * val_in") - state.add_memlet_path(pipe_unroll_read, - systolic_entry, - systolic_tasklet, - dst_conn="val_in", - memlet=dace.Memlet("pipes_systolic[p]")) - pipe_unroll_write = state.add_write("pipes_systolic") - state.add_memlet_path(systolic_tasklet, - systolic_exit, - pipe_unroll_write, - src_conn="val_out", - memlet=dace.Memlet("pipes_systolic[p + 1]")) - state.add_memlet_path(pipe_unroll_write, unroll_exit, memlet=dace.Memlet()) - - # Write back to memory - pipe_systolic_read_tail = state.add_read("pipes_systolic") - memory_write = state.add_write("arr") - state.add_memlet_path(pipe_systolic_read_tail, memory_write, memlet=dace.Memlet("arr[0:N]", other_subset="P")) - - return sdfg - - -@intel_fpga_test() -def test_autorun(): - - n = 128 - p = 4 - - sdfg = make_sdfg() - sdfg.specialize({"N": 128, "P": 4}) - - arr = np.ones((128, ), dtype=DTYPE.type) - - for c in (c for c in sdfg.generate_code() if c.language == "cl"): - if len(re.findall(r"__attribute__\(\(autorun\)\)", c.code)) != 2: - raise RuntimeError("Autogen attributes not found.") - - sdfg(arr_host=arr) - - if any(arr != 2**4 * 10): - raise ValueError("Verification failed.") - - return sdfg - - -if __name__ == "__main__": - test_autorun(None) diff --git a/tests/fpga/axpy_transform_test.py b/tests/fpga/axpy_transform_test.py deleted file mode 100644 index 3c8a97a013..0000000000 --- a/tests/fpga/axpy_transform_test.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import argparse -import dace -import numpy as np -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG - -N = dace.symbol('N') - - -@dace.program(dace.float64, dace.float64[N], dace.float64[N]) -def axpy(A, X, Y): - - @dace.map(_[0:N]) - def multiplication(i): - in_A << A - in_X << X[i] - in_Y << Y[i] - out >> Y[i] - - out = in_A * in_X + in_Y - - -@fpga_test() -def test_axpy_transformed(): - - n = 24 - - print(f'Scalar-vector multiplication {n}') - - A = dace.float64(np.random.rand()) - X = np.random.rand(n) - Y = np.random.rand(n) - expected = A * X + Y - - # Obtain SDFG from @dace.program - sdfg = axpy.to_sdfg() - - # Convert SDFG to FPGA using a transformation - sdfg.apply_transformations(FPGATransformSDFG) - - # Specialize and execute SDFG on FPGA - sdfg._name = f'axpy_fpga_{n}' - sdfg.specialize(dict(N=n)) - sdfg(A=A, X=X, Y=Y) - - diff = np.linalg.norm(expected - Y) / n - assert diff <= 1e-5 - - return sdfg - - -if __name__ == "__main__": - test_axpy_transformed(None) diff --git a/tests/fpga/bank_split_test.py b/tests/fpga/bank_split_test.py deleted file mode 100644 index cb9f39eb6c..0000000000 --- a/tests/fpga/bank_split_test.py +++ /dev/null @@ -1,62 +0,0 @@ -import dace -from multibank_copy_fpga_test import mkc -from dace.dtypes import StorageType -from dace.transformation.dataflow import BankSplit -from dace.transformation import optimizer -import numpy as np - - -def test_simple_split(): - sdfg = dace.SDFG("hbm_bank_split_first_dim") - _, b, a = mkc(sdfg, None, "b", "a", StorageType.CPU_Heap, StorageType.CPU_Heap, [4, 10, 10], [40, 10], "b") - for xform in optimizer.Optimizer(sdfg).get_pattern_matches(patterns=BankSplit): - xform.apply(sdfg.node(xform.state_id), sdfg) - sdfg(a=a, b=b) - assert np.allclose(b[1], a[10:20, :]) - assert np.allclose(b[3], a[30:40, :]) - - -def test_even_split_3d(): - sdfg = dace.SDFG("hbm_bank_split_even_split_3d") - s, b, a = mkc(sdfg, None, "b", "a", StorageType.CPU_Heap, StorageType.CPU_Heap, [8, 50, 50, 50], [100, 100, 100], - "b") - for xform in optimizer.Optimizer(sdfg).get_pattern_matches(patterns=BankSplit): - xform.split_array_info = [2, 2, 2] - xform.apply(sdfg.node(xform.state_id), sdfg) - b = np.random.uniform(0, 100, [8, 50, 50, 50]).astype(np.int32) - sdfg(a=a, b=b) - assert np.allclose(a[0:50, 0:50, 0:50], b[0, :, :, :]) - assert np.allclose(a[50:100, 50:100, 50:100], b[7, :, :, :]) - assert np.allclose(a[0:50, 50:100, 0:50], b[2, :, :, :]) - - -def test_second_dim_split_2d(): - sdfg = dace.SDFG("hbm_bank_split_sec_dim_split2d") - s, a, b = mkc(sdfg, None, "a", "b", StorageType.CPU_Heap, StorageType.CPU_Heap, [10, 100], [10, 10, 10], "b") - for xform in optimizer.Optimizer(sdfg).get_pattern_matches(patterns=BankSplit): - xform.split_array_info = [1, 10] - xform.apply(sdfg.node(xform.state_id), sdfg) - a = np.random.uniform(0, 10, [10, 100]).astype(np.int32) - sdfg(a=a, b=b) - for i in range(10): - assert np.allclose(a[0:10, 10 * i:(10 * i + 10)], b[i]) - - -def test_explicit_split_3d(): - sdfg = dace.SDFG("hbm_bank_split_explicit_3d") - s, a, b = mkc(sdfg, None, "a", "b", StorageType.CPU_Heap, StorageType.CPU_Heap, [120, 100, 100], [24, 40, 50, 25]) - for xform in optimizer.Optimizer(sdfg).get_pattern_matches(patterns=BankSplit): - xform.split_array_info = [3, 2, 4] - xform.apply(sdfg.node(xform.state_id), sdfg) - a = np.random.uniform(0, 100, [120, 100, 100]).astype(np.int32) - sdfg(a=a, b=b) - assert np.allclose(a[80:120, 50:100, 75:100], b[23]) - assert np.allclose(a[0:40, 50:100, 75:100], b[7]) - assert np.allclose(a[40:80, 0:50, 25:50], b[9]) - - -if __name__ == "__main__": - test_simple_split() - test_even_split_3d() - test_second_dim_split_2d() - test_explicit_split_3d() diff --git a/tests/fpga/channel_mangling_test.py b/tests/fpga/channel_mangling_test.py deleted file mode 100644 index 19533781af..0000000000 --- a/tests/fpga/channel_mangling_test.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -# The scope of the test is to verify channel mangling. In the -# SDFG we have two nested SDFG that increments some streaming data by one. -# The NestedSDFGs are similar, but they work on different sizes - -# Note: to generate the code for both NSDFG, we set Dace config for compiler->unique_functions to False - -import dace -import numpy as np -import argparse -import subprocess -from dace.config import Config - -from dace.fpga_testing import intel_fpga_test -from dace.memlet import Memlet - -N = dace.symbol("N") - - -def make_increment_sdfg(sdfg_name: str, dtype=dace.float32): - inc_sdfg = dace.SDFG(sdfg_name) - - # FPGA State - - fpga_state = inc_sdfg.add_state("fpga_state") - - inc_sdfg.add_array("x", shape=[N], dtype=dtype, storage=dace.dtypes.StorageType.FPGA_Global) - inc_sdfg.add_array("y", shape=[N], dtype=dtype, storage=dace.dtypes.StorageType.FPGA_Global) - inc_sdfg.add_stream("what_a_nice_pipe", dtype, transient=True, storage=dace.dtypes.StorageType.FPGA_Local) - - data_in = fpga_state.add_read("x") - data_out = fpga_state.add_write("y") - pipe_write = fpga_state.add_write("what_a_nice_pipe") - pipe_read = fpga_state.add_read("what_a_nice_pipe") - - # ---------- ---------- - read_map_entry, read_map_exit = fpga_state.add_map('read_incr_map', - dict(i='0:N'), - schedule=dace.dtypes.ScheduleType.FPGA_Device) - - incr_tasklet = fpga_state.add_tasklet('incr_task', ['in_con'], ['out_con'], 'out_con = in_con + 1') - - # From memory to increment - fpga_state.add_memlet_path(data_in, - read_map_entry, - incr_tasklet, - dst_conn='in_con', - memlet=dace.Memlet(f"{data_in.data}[i]")) - # from increment to pipe - fpga_state.add_memlet_path(incr_tasklet, - read_map_exit, - pipe_write, - src_conn='out_con', - memlet=dace.Memlet("what_a_nice_pipe[0]")) - - # from pipe to memory - write_map_entry, write_map_exit = fpga_state.add_map('write_map', - dict(i='0:N'), - schedule=dace.dtypes.ScheduleType.FPGA_Device) - - copy_tasklet = fpga_state.add_tasklet('copy_task', ['in_con'], ['out_con'], 'out_con = in_con ') - - fpga_state.add_memlet_path(pipe_read, - write_map_entry, - copy_tasklet, - dst_conn='in_con', - memlet=dace.Memlet("what_a_nice_pipe[0]")) - fpga_state.add_memlet_path(copy_tasklet, write_map_exit, data_out, src_conn='out_con', memlet=dace.Memlet("y[i]")) - - ######### - # Validate - inc_sdfg.fill_scope_connectors() - inc_sdfg.validate() - return inc_sdfg - - -def make_nested_sdfg_fpga(dtype=dace.float32): - """ - Build an SDFG with two nested SDFGs, each one a different state - """ - - sdfg = dace.SDFG("channels_mangling") - - ########################################################################### - # Copy data to FPGA - - copy_in_state = sdfg.add_state("copy_to_device") - - sdfg.add_array("X", shape=[N], dtype=dtype) - - in_host_x = copy_in_state.add_read("X") - - sdfg.add_array("device_X", shape=[N], dtype=dtype, storage=dace.dtypes.StorageType.FPGA_Global, transient=True) - sdfg.add_array("device_tmp", shape=[N], dtype=dtype, storage=dace.dtypes.StorageType.FPGA_Global, transient=True) - - in_device_x = copy_in_state.add_write("device_X") - - copy_in_state.add_memlet_path(in_host_x, in_device_x, memlet=Memlet.simple(in_host_x, "0:N")) - - ########################################################################### - # Copy data from FPGA - - copy_out_state = sdfg.add_state("copy_to_host") - sdfg.add_array("Y", shape=[N], dtype=dtype) - sdfg.add_array("device_Y", shape=[N], dtype=dtype, storage=dace.dtypes.StorageType.FPGA_Global, transient=True) - - out_device = copy_out_state.add_read("device_Y") - out_host = copy_out_state.add_write("Y") - - copy_out_state.add_memlet_path(out_device, out_host, memlet=Memlet.simple(out_host, "0:N")) - - ######################################################################## - # First state - state = sdfg.add_state("state") - state.location["is_FPGA_kernel"] = False - - to_nest = make_increment_sdfg("nest_1", dtype) - x = state.add_read("device_X") - tmp = state.add_write("device_tmp") - - # add nested sdfg with symbol mapping - nested_sdfg = state.add_nested_sdfg(to_nest, {"x"}, {"y"}) - state.add_memlet_path(x, nested_sdfg, dst_conn="x", memlet=Memlet("device_X[0:N]")) - state.add_memlet_path(nested_sdfg, tmp, src_conn="y", memlet=Memlet("device_tmp[0:N]")) - - ######################################################################## - # First state - state2 = sdfg.add_state("state2") - state2.location["is_FPGA_kernel"] = False - - to_nest = make_increment_sdfg("nest_2", dtype) - tmp_read = state2.add_read("device_tmp") - y = state2.add_write("device_Y") - - # add nested sdfg with symbol mapping - nested_sdfg = state2.add_nested_sdfg(to_nest, {"x"}, {"y"}) - state2.add_memlet_path(tmp_read, nested_sdfg, dst_conn="x", memlet=Memlet("device_tmp[0:N]")) - state2.add_memlet_path(nested_sdfg, y, src_conn="y", memlet=Memlet("device_Y[0:N]")) - - ###################################### - # Interstate edges - sdfg.add_edge(state, state2, dace.sdfg.sdfg.InterstateEdge()) - - # Interstate edges - sdfg.add_edge(copy_in_state, state, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(state2, copy_out_state, dace.sdfg.sdfg.InterstateEdge()) - sdfg.validate() - - return sdfg - - -@intel_fpga_test() -def test_channel_mangling(): - - parser = argparse.ArgumentParser() - parser.add_argument("N", type=int, nargs="?", default=32) - args = vars(parser.parse_args()) - - size_n = args["N"] - - from dace.config import Config - # set unique function to false to generate both sdfgs - Config.set("compiler", "unique_functions", value="none") - sdfg = make_nested_sdfg_fpga() - - X = np.random.rand(size_n).astype(np.float32) - Y = np.random.rand(size_n).astype(np.float32) - sdfg(X=X, Y=Y, N=size_n) - ref = X + 2 - diff = np.linalg.norm(ref - Y) / size_n - assert diff <= 1e-5 - - return sdfg - - -if __name__ == "__main__": - test_channel_mangling(None) diff --git a/tests/fpga/conflict_resolution_test.py b/tests/fpga/conflict_resolution_test.py deleted file mode 100644 index 966b8f4422..0000000000 --- a/tests/fpga/conflict_resolution_test.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -# Tests whether conflict resolution is handled correctly on both local and -# global memory containers from within an FPGA kernel. - -import dace -import numpy as np -from dace.fpga_testing import fpga_test - - -def make_sdfg(): - - N = dace.symbol("N") - - sdfg = dace.SDFG("fpga_conflict_resolution") - - sdfg.add_array("host_memory", [N], dace.int32) - sdfg.add_array("global_memory", [N], dace.int32, transient=True, storage=dace.StorageType.FPGA_Global) - sdfg.add_array("local_memory", [1], dace.int32, transient=True, storage=dace.StorageType.FPGA_Local) - - state = sdfg.add_state("fpga_conflict_resolution") - - # Copy memory to FPGA - pre_state = sdfg.add_state("pre_state") - pre_host = pre_state.add_read("host_memory") - pre_device = pre_state.add_write("global_memory") - pre_state.add_memlet_path(pre_host, pre_device, memlet=dace.Memlet("global_memory[0:N]")) - sdfg.add_edge(pre_state, state, dace.InterstateEdge()) - - # Copy memory back - post_state = sdfg.add_state("post_state") - post_device = post_state.add_read("global_memory") - post_host = post_state.add_write("host_memory") - post_state.add_memlet_path(post_device, post_host, memlet=dace.Memlet("global_memory[0:N]")) - sdfg.add_edge(state, post_state, dace.InterstateEdge()) - - gmem_read = state.add_read("global_memory") - gmem_write = state.add_write("global_memory") - - local_init = state.add_access("local_memory") - local_write = state.add_access("local_memory") - - # Initialize local memory - init_tasklet = state.add_tasklet("init", {}, {"out"}, "out = 0") - state.add_memlet_path(init_tasklet, local_init, src_conn="out", memlet=dace.Memlet("local_memory[0]")) - - # Accumulate on local memory - acc_entry, acc_exit = state.add_map("wcr_local", {"i": "0:N"}, schedule=dace.ScheduleType.FPGA_Device) - acc_tasklet = state.add_tasklet("wcr_local", {"gmem"}, {"lmem"}, "lmem = gmem") - state.add_memlet_path(gmem_read, acc_entry, acc_tasklet, dst_conn="gmem", memlet=dace.Memlet("global_memory[i]")) - state.add_memlet_path(local_init, acc_entry, memlet=dace.Memlet()) - state.add_memlet_path(acc_tasklet, - acc_exit, - local_write, - src_conn="lmem", - memlet=dace.Memlet("local_memory[0]", wcr="lambda a, b: a + b")) - - # Write with conflict into global memory - wcr_entry, wcr_exit = state.add_map("wcr_global", {"i": "0:N"}, schedule=dace.ScheduleType.FPGA_Device) - wcr_tasklet = state.add_tasklet("wcr_global", {"lmem"}, {"gmem"}, "gmem = lmem") - state.add_memlet_path(local_write, wcr_entry, wcr_tasklet, dst_conn="lmem", memlet=dace.Memlet("local_memory[0]")) - state.add_memlet_path(wcr_tasklet, - wcr_exit, - gmem_write, - src_conn="gmem", - memlet=dace.Memlet("global_memory[i]", wcr="lambda a, b: a + b")) - - return sdfg - - -@fpga_test() -def test_fpga_wcr(): - sdfg = make_sdfg() - size = 128 - host_memory = np.arange(size, dtype=np.int32) - reference = host_memory.copy() - sdfg(host_memory=host_memory, N=size) - assert all(np.sum(reference) + reference == host_memory) - return sdfg - - -if __name__ == "__main__": - test_fpga_wcr(None) diff --git a/tests/fpga/dot_fpga_test.py b/tests/fpga/dot_fpga_test.py deleted file mode 100644 index 44fb27af8f..0000000000 --- a/tests/fpga/dot_fpga_test.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -# Dot product with WCR -# Used as simple test for WCR over scalar - -#!/usr/bin/env python - -import click -import dace -import numpy as np -from dace.fpga_testing import fpga_test - -from dace.transformation.dataflow import MapTiling -from dace.transformation.interstate import FPGATransformSDFG - -N = dace.symbol("N") - - -@dace.program -def dot(A: dace.float32[N], B: dace.float32[N], out: dace.float32[1]): - - @dace.map - def product(i: _[0:N]): - a << A[i] - b << B[i] - o >> out(1, lambda x, y: x + y) - o = a * b - - -def run_dot(n, tile_first): - N = n - A = dace.ndarray([N], dtype=dace.float32) - B = dace.ndarray([N], dtype=dace.float32) - out_AB = dace.scalar(dace.float32) - - A[:] = np.random.rand(N).astype(dace.float32.type) - B[:] = np.random.rand(N).astype(dace.float32.type) - out_AB[0] = dace.float32(0) - - sdfg = dot.to_sdfg() - if tile_first: - sdfg.apply_transformations(MapTiling) - sdfg.apply_transformations(FPGATransformSDFG) - else: - sdfg.apply_transformations(FPGATransformSDFG) - sdfg.apply_transformations(MapTiling) - - sdfg(A=A, B=B, out=out_AB, N=N) - - diff_ab = np.linalg.norm(np.dot(A, B) - out_AB) / float(N) - assert diff_ab <= 1e-5 - - return sdfg - - -@fpga_test(assert_ii_1=False) -def test_dot_tile_first(): - return run_dot(64, True) - - -@fpga_test(assert_ii_1=False) -def test_dot_fpga_transform_first(): - return run_dot(64, False) - - -if __name__ == "__main__": - test_dot_tile_first(None) - test_dot_fpga_transform_first(None) diff --git a/tests/fpga/fpga_instrumentation_test.py b/tests/fpga/fpga_instrumentation_test.py deleted file mode 100644 index 591bb1bab7..0000000000 --- a/tests/fpga/fpga_instrumentation_test.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import dace -from dace.sdfg.utils import is_fpga_kernel -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace import config -import numpy as np -import re - - -def make_sdfg(make_tmp_local: bool): - """ - Creates an SDFG that has a left and a right branch writing into two - respective temporary arrays, which are both read by subsequent map. - If `male_tmp_local` is set, the temporary arrays will be made local, such - that DaCe will generate a single kernel for the state. Otherwise, DaCe - will generate three separate kernels. - """ - - sdfg = dace.SDFG("instrumentation_test") - sdfg.add_array("in0", (16, ), dace.float32) - sdfg.add_array("in1", (16, ), dace.float32) - sdfg.add_array("in2", (16, ), dace.float32) - sdfg.add_array("tmp0", (16, ), dace.float32, transient=True) - sdfg.add_array("tmp1", (16, ), dace.float32, transient=True) - sdfg.add_array("out0", (16, ), dace.float32) - sdfg.add_array("out1", (16, ), dace.float32) - - state = sdfg.add_state("instrumentation_test") - - in0 = state.add_read("in0") - in1 = state.add_read("in1") - tmp0 = state.add_access("tmp0") - tmp1 = state.add_access("tmp1") - out0 = state.add_write("out0") - - # Left branch subgraph - entry_left, exit_left = state.add_map("left_map", {"i": "0:16"}) - tasklet_left = state.add_tasklet("left_tasklet", {"_in"}, {"_tmp"}, "_tmp = _in + 1") - state.add_memlet_path(in0, entry_left, tasklet_left, dst_conn="_in", memlet=dace.Memlet("in0[i]")) - state.add_memlet_path(tasklet_left, exit_left, tmp0, src_conn="_tmp", memlet=dace.Memlet("tmp0[i]")) - - # Right branch subgraph - entry_right, exit_right = state.add_map("right_map", {"i": "0:16"}) - tasklet_right = state.add_tasklet("right_tasklet", {"_in"}, {"_tmp"}, "_tmp = _in + 1") - state.add_memlet_path(in1, entry_right, tasklet_right, dst_conn="_in", memlet=dace.Memlet("in1[i]")) - state.add_memlet_path(tasklet_right, exit_right, tmp1, src_conn="_tmp", memlet=dace.Memlet("tmp1[i]")) - - # Bottom subgraph - entry_after, exit_after = state.add_map("after_map", {"i": "0:16"}) - tasklet_after = state.add_tasklet("after_tasklet", {"_tmp0", "_tmp1"}, {"_c"}, "_c = 2 * (_tmp0 + _tmp1)") - state.add_memlet_path(tmp0, entry_after, tasklet_after, dst_conn="_tmp0", memlet=dace.Memlet("tmp0[i]")) - state.add_memlet_path(tmp1, entry_after, tasklet_after, dst_conn="_tmp1", memlet=dace.Memlet("tmp1[i]")) - state.add_memlet_path(tasklet_after, exit_after, out0, src_conn="_c", memlet=dace.Memlet("out0[i]")) - - # Extra independent subgraph (will be a PE on Xilinx, kernel on Intel) - in2 = state.add_read("in2") - out1 = state.add_write("out1") - entry_extra, exit_extra = state.add_map("extra_map", {"i": "0:16"}) - tasklet_extra = state.add_tasklet("extra_tasklet", {"_in"}, {"_out"}, "_out = _in * _in") - state.add_memlet_path(in2, entry_extra, tasklet_extra, dst_conn="_in", memlet=dace.Memlet("in2[i]")) - state.add_memlet_path(tasklet_extra, exit_extra, out1, src_conn="_out", memlet=dace.Memlet("out1[i]")) - - assert sdfg.apply_transformations(FPGATransformSDFG) == 1 - assert sdfg.apply_transformations(InlineSDFG) == 1 - - if make_tmp_local: - made_local = 0 - for name, desc in sdfg.arrays.items(): - if "tmp" in name: - desc.storage = dace.StorageType.FPGA_Local - made_local += 1 - assert made_local == 2 - - for s in sdfg.states(): - if is_fpga_kernel(sdfg, s): - s.instrument = dace.InstrumentationType.FPGA - break - else: - raise RuntimeError("FPGA state was not found.") - - return sdfg - - -def run_program(sdfg): - - in0 = np.zeros((16, ), np.float32) - in1 = np.ones((16, ), np.float32) - in2 = np.ones((16, ), np.float32) - out0 = np.empty((16, ), np.float32) - out1 = np.empty((16, ), np.float32) - - sdfg(in0=in0, in1=in1, in2=in2, out0=out0, out1=out1) - - assert np.allclose(out0, 2 * ((in0 + 1) + (in1 + 1))) - assert np.allclose(out1, in2 * in2) - - -@fpga_test() -def test_instrumentation_single(): - sdfg = make_sdfg(True) - run_program(sdfg) - report = sdfg.get_latest_report() - # There should be three runtimes: One for the kernel, and two for the state - if dace.Config.get("compiler", "fpga", "vendor") == "xilinx": - # For Xilinx, processing elements live within a single kernel - expected_num_kernels = 1 - elif dace.Config.get("compiler", "fpga", "vendor") == "intel_fpga": - # For Intel, each processing element is a distinct kernel - expected_num_kernels = 2 - assert len(re.findall(r"[0-9\.]+\s+[0-9\.]+\s+[0-9\.]+\s+[0-9\.]+\s+", str(report))) == 2 + expected_num_kernels - return sdfg - - -@fpga_test() -def test_instrumentation_multiple(): - sdfg = make_sdfg(False) - with config.set_temporary("compiler", "fpga", "concurrent_kernel_detection", value=True): - run_program(sdfg) - report = sdfg.get_latest_report() - # There should be five runtimes: One for each kernel, and two for the state - assert len(re.findall(r"[0-9\.]+\s+[0-9\.]+\s+[0-9\.]+\s+[0-9\.]+\s+", str(report))) == 6 - return sdfg - - -if __name__ == "__main__": - test_instrumentation_multiple(None) - test_instrumentation_single(None) diff --git a/tests/fpga/fpga_stencil_test.py b/tests/fpga/fpga_stencil_test.py deleted file mode 100644 index 8152b84c3f..0000000000 --- a/tests/fpga/fpga_stencil_test.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import copy -import dace -from dace.fpga_testing import fpga_test - - -def make_sdfg(name="fpga_stcl_test", dtype=dace.float32, veclen=8): - - vtype = dace.vector(dtype, veclen) - - n = dace.symbol("N") - m = dace.symbol("M") - - sdfg = dace.SDFG(name) - - pre_state = sdfg.add_state(name + "_pre") - state = sdfg.add_state(name) - post_state = sdfg.add_state(name + "_post") - sdfg.add_edge(pre_state, state, dace.InterstateEdge()) - sdfg.add_edge(state, post_state, dace.InterstateEdge()) - - _, desc_input_host = sdfg.add_array("a", (n, m / veclen), vtype) - _, desc_output_host = sdfg.add_array("b", (n, m / veclen), vtype) - desc_input_device = copy.copy(desc_input_host) - desc_input_device.storage = dace.StorageType.FPGA_Global - desc_input_device.location["memorytype"] = "ddr" - desc_input_device.location["bank"] = "0" - desc_input_device.transient = True - desc_output_device = copy.copy(desc_output_host) - desc_output_device.storage = dace.StorageType.FPGA_Global - desc_output_device.location["memorytype"] = "ddr" - desc_output_device.location["bank"] = "1" - desc_output_device.transient = True - sdfg.add_datadesc("a_device", desc_input_device) - sdfg.add_datadesc("b_device", desc_output_device) - - # Host to device - pre_read = pre_state.add_read("a") - pre_write = pre_state.add_write("a_device") - pre_state.add_memlet_path(pre_read, pre_write, memlet=dace.Memlet(f"a_device[0:N, 0:M/{veclen}]")) - - # Device to host - post_read = post_state.add_read("b_device") - post_write = post_state.add_write("b") - post_state.add_memlet_path(post_read, post_write, memlet=dace.Memlet(f"b_device[0:N, 0:M/{veclen}]")) - - # Compute state - read_memory = state.add_read("a_device") - write_memory = state.add_write("b_device") - - # Memory streams - sdfg.add_stream("a_stream", vtype, storage=dace.StorageType.FPGA_Local, transient=True) - sdfg.add_stream("b_stream", vtype, storage=dace.StorageType.FPGA_Local, transient=True) - produce_input_stream = state.add_write("a_stream") - consume_input_stream = state.add_read("a_stream") - produce_output_stream = state.add_write("b_stream") - consume_output_stream = state.add_write("b_stream") - - tasklet = state.add_tasklet( - name, {"_north", "_west", "_east", "_south"}, {"result"}, """\ -north = _north if i >= 1 else 1 -west = _west if {W}*j + u >= 1 else 1 -east = _east if {W}*j + u < M - 1 else 1 -south = _south if i < N - 1 else 1 - -result = 0.25 * (north + west + east + south)""".format(W=veclen)) - - entry, exit = state.add_pipeline(name, { - "i": "0:N", - "j": "0:M/{}".format(veclen), - }, - schedule=dace.ScheduleType.FPGA_Device, - init_size=m / veclen, - init_overlap=False, - drain_size=m / veclen, - drain_overlap=True) - - # Unrolled map - unroll_entry, unroll_exit = state.add_map(name + "_unroll", {"u": "0:{}".format(veclen)}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - - # Container-to-container copies between arrays and streams - state.add_memlet_path(read_memory, - produce_input_stream, - memlet=dace.Memlet(f"{read_memory.data}[0:N, 0:M/{veclen}]", other_subset="0")) - state.add_memlet_path(consume_output_stream, - write_memory, - memlet=dace.Memlet(write_memory.data, - f"{write_memory.data}[0:N, 0:M/{veclen}]", - other_subset="0")) - - # Container-to-container copy from vectorized stream to non-vectorized - # buffer - sdfg.add_array("input_buffer", (1, ), vtype, storage=dace.StorageType.FPGA_Local, transient=True) - sdfg.add_array("shift_register", (2 * m + veclen, ), - dtype, - storage=dace.StorageType.FPGA_ShiftRegister, - transient=True) - sdfg.add_array("output_buffer", (veclen, ), dtype, storage=dace.StorageType.FPGA_Local, transient=True) - sdfg.add_array("output_buffer_packed", (1, ), vtype, storage=dace.StorageType.FPGA_Local, transient=True) - input_buffer = state.add_access("input_buffer") - shift_register = state.add_access("shift_register") - output_buffer = state.add_access("output_buffer") - output_buffer_packed = state.add_access("output_buffer_packed") - - # Only write if not initializing - read_tasklet = state.add_tasklet(name + "_conditional_read", {"_in"}, {"_out"}, - "if not {}:\n\t_out = _in".format(entry.pipeline.drain_condition())) - - # Input stream to buffer - state.add_memlet_path(consume_input_stream, - entry, - read_tasklet, - dst_conn="_in", - memlet=dace.Memlet(f"{consume_input_stream.data}[0]", dynamic=True)) - state.add_memlet_path(read_tasklet, input_buffer, src_conn="_out", memlet=dace.Memlet(f"{input_buffer.data}[0]")) - state.add_memlet_path(input_buffer, - shift_register, - memlet=dace.Memlet(f"{input_buffer.data}[0]", other_subset=f"2*M:(2*M + {veclen})")) - - # Stencils accesses - state.add_memlet_path(shift_register, - unroll_entry, - tasklet, - dst_conn="_north", - memlet=dace.Memlet(f"{shift_register.data}[u]")) # North - state.add_memlet_path(shift_register, - unroll_entry, - tasklet, - dst_conn="_west", - memlet=dace.Memlet(f"{shift_register.data}[u + M - 1]")) # West - state.add_memlet_path(shift_register, - unroll_entry, - tasklet, - dst_conn="_east", - memlet=dace.Memlet(f"{shift_register.data}[u + M + 1]")) # East - state.add_memlet_path(shift_register, - unroll_entry, - tasklet, - dst_conn="_south", - memlet=dace.Memlet(f"{shift_register.data}[u + 2 * M]")) # South - - # Tasklet to buffer - state.add_memlet_path(tasklet, - unroll_exit, - output_buffer, - src_conn="result", - memlet=dace.Memlet(f"{output_buffer.data}[u]")) - - # Pack buffer - state.add_memlet_path(output_buffer, - output_buffer_packed, - memlet=dace.Memlet(f"{output_buffer_packed.data}[0]", other_subset=f"0:{veclen}")) - - # Only write if not initializing - write_tasklet = state.add_tasklet(name + "_conditional_write", {"_in"}, {"_out"}, - "if not {}:\n\t_out = _in".format(entry.pipeline.init_condition())) - - # Buffer to output stream - state.add_memlet_path(output_buffer_packed, - write_tasklet, - dst_conn="_in", - memlet=dace.Memlet(f"{output_buffer_packed.data}[0]")) - - # Buffer to output stream - state.add_memlet_path(write_tasklet, - exit, - produce_output_stream, - src_conn="_out", - memlet=dace.Memlet(f"{produce_output_stream.data}[0]", dynamic=True)) - - return sdfg - - -@fpga_test(xilinx=False) -def test_fpga_stencil(): - - import numpy as np - - dtype = dace.float32 - n = 16 - m = 16 - - jacobi = make_sdfg(dtype=dtype) - jacobi.specialize({"N": n, "M": m}) - - a = np.arange(n * m, dtype=dtype.type).reshape((n, m)) - b = np.empty((n, m), dtype=dtype.type) - - jacobi(a=a, b=b) - padded = np.ones((n + 2, m + 2), dtype.type) - padded[1:-1, 1:-1] = a - ref = 0.25 * (padded[:-2, 1:-1] + padded[2:, 1:-1] + padded[1:-1, :-2] + padded[1:-1, 2:]) - - if (b != ref).any(): - raise ValueError("Unexpected output:\nGot: {}\nExpected: {}".format(b, ref)) - - return jacobi - - -if __name__ == "__main__": - test_fpga_stencil(None) diff --git a/tests/fpga/gemv_fpga_test.py b/tests/fpga/gemv_fpga_test.py deleted file mode 100644 index 4c049928b9..0000000000 --- a/tests/fpga/gemv_fpga_test.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from dace.fpga_testing import fpga_test, import_sample -from pathlib import Path - - -@fpga_test() -def test_gemv_fpga(): - gemv = import_sample(Path("fpga") / "gemv_fpga.py") - return gemv.run_gemv(1024, 1024, False) - - -if __name__ == "__main__": - test_gemv_fpga(None) diff --git a/tests/fpga/hbm_transform_test.py b/tests/fpga/hbm_transform_test.py deleted file mode 100644 index 32bbbf48c0..0000000000 --- a/tests/fpga/hbm_transform_test.py +++ /dev/null @@ -1,283 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. - -from dace.fpga_testing import xilinx_test -from dace.sdfg.state import SDFGState -import numpy as np -from dace import dtypes -from dace.transformation.interstate.sdfg_nesting import InlineSDFG -from typing import List, Tuple -from dace.sdfg import SDFG, nodes -import dace -from dace.transformation.dataflow import HbmTransform -from dace.transformation.interstate import NestSDFG -from functools import reduce - - -def set_assignment(sdfg: SDFG, assignments: List[Tuple[str, str, str]]): - for array, memorytype, bank in assignments: - desc = sdfg.arrays[array] - desc.location["memorytype"] = memorytype - desc.location["bank"] = bank - - -def rand_float(input_shape): - a = np.random.rand(*input_shape) - a = a.astype(np.float32) - #a = np.ones(input_shape, np.float32) - return a - - -def _exec_hbmtransform(sdfg_source, assign, nest=False, num_apply=1, apply_to=None): - sdfg = sdfg_source() - set_assignment(sdfg, assign) - if apply_to is None: - assert sdfg.apply_transformations_repeated(HbmTransform, { - "new_dim": "kw", - "move_to_FPGA_global": False - }, - validate=False) == num_apply - if num_apply == 0: - return sdfg - else: - for map_entry in apply_to(sdfg): - HbmTransform.apply_to(sdfg, { - "new_dim": "kw", - "move_to_FPGA_global": False - }, - save=False, - _map_entry=map_entry) - if nest: - for _, desc in sdfg.arrays.items(): - if desc.storage == dtypes.StorageType.Default: - desc.storage = dtypes.StorageType.FPGA_Global - sdfg.apply_transformations(NestSDFG, validate=False) - for _, desc in sdfg.arrays.items(): - if desc.storage == dtypes.StorageType.FPGA_Global: - desc.storage = dtypes.StorageType.Default - sdfg.apply_fpga_transformations(validate=False) - sdfg.apply_transformations_repeated(InlineSDFG, validate=False) - csdfg = sdfg.compile() - return (csdfg, sdfg) - - -def create_vadd_sdfg(name, array_shape=dace.symbol("n"), map_range=dace.symbol("n")): - - @dace.program - def vadd(x: dace.float32[array_shape], y: dace.float32[array_shape], z: dace.float32[array_shape]): - for i in dace.map[0:map_range]: - with dace.tasklet: - xin << x[i] - yin << y[i] - zout >> z[i] - zout = xin + yin - - sdfg = vadd.to_sdfg() - sdfg.name = name - sdfg.apply_strict_transformations() - return sdfg - - -def create_multi_access_sdfg(name): - N = dace.symbol("N") - - @dace.program - def sth(z: dace.float32[N], x: dace.float32[N], y: dace.float32[N], w: dace.float32[N], o1: dace.float32[N], - o2: dace.float32[N]): - for i in dace.map[0:N]: - o1[i] = z[i] + x[i] - for i in dace.map[0:N]: - o2[i] = w[i] + y[i] - - sdfg = sth.to_sdfg() - sdfg.name = name - sdfg.apply_strict_transformations() - return sdfg - - -def create_nd_sdfg(name): - n = dace.symbol("n") - m = dace.symbol("m") - - @dace.program - def nd_sdfg(x: dace.float32[n, m], y: dace.float32[m, n], z: dace.float32[n, m]): - for i in dace.map[0:n]: - for j in dace.map[0:m]: - with dace.tasklet: - yin << y[j, i] - xin << x[i, j] - zout >> z[i, j] - zout = yin + xin - - sdfg = nd_sdfg.to_sdfg() - sdfg.name = name - sdfg.apply_strict_transformations() - return sdfg - - -def create_gemv_blas_sdfg(name, tile_size_y=None, tile_size_x=None, m=None): - N = dace.symbol("N") - M = dace.symbol("M") - - @dace.program - def gemv(A: dace.float32[M, N], x: dace.float32[N], y: dace.float32[M]): - y[:] = A @ x - - sdfg = gemv.to_sdfg() - sdfg.apply_strict_transformations() - if m is not None: - sdfg.specialize({M: m}) - libnode = list(filter(lambda x: isinstance(x, nodes.LibraryNode), sdfg.nodes()[0].nodes()))[0] - libnode.expand(sdfg, sdfg.nodes()[0]) - libnode = list(filter(lambda x: isinstance(x, nodes.LibraryNode), sdfg.nodes()[0].nodes()))[0] - libnode.implementation = "FPGA_TilesByColumn" - libnode.expand(sdfg, sdfg.nodes()[0], tile_size_y=tile_size_y, tile_size_x=tile_size_x) - sdfg.apply_strict_transformations() - sdfg.name = name - return sdfg - - -def validate_vadd_sdfg(csdfg, input_shape): - a = rand_float(input_shape) - b = rand_float(input_shape) - c = rand_float(input_shape) - expect = a + b - - csdfg(x=a, y=b, z=c, n=reduce(lambda x, y: x * y, input_shape)) - assert np.allclose(expect, c) - - -def validate_gemv_sdfg(csdfg, matrix_shape, x_shape, y_shape): - # A and potentially y is assumed to be split along dim 0 - A = rand_float(matrix_shape) - x = rand_float(x_shape) - y = rand_float(y_shape) - expect = np.matmul(A, x) - - csdfg(A=A, x=x, y=y, M=matrix_shape[0] * matrix_shape[1], N=matrix_shape[2]) - if len(y_shape) == 1: - y = np.reshape(y, [matrix_shape[0], matrix_shape[1]]) - assert np.allclose(y, expect) - - -def validate_nd_sdfg(csdfg, m, n, divide_m=1, divide_n=1): - A = np.zeros([divide_m * divide_n, n // divide_n, m // divide_m], np.float32) - B = np.zeros([divide_m * divide_n, m // divide_m, n // divide_n], np.float32) - Z = np.zeros([divide_m * divide_n, n // divide_n, m // divide_m], np.float32) - expect = np.zeros([divide_m * divide_n, n // divide_n, m // divide_m], np.float32) - - for k_i in range(1, divide_n + 1): - for k_j in range(1, divide_m + 1): - for i in range(n // divide_n): - for j in range(m // divide_m): - index = k_i * k_j - 1 - A[index, i, j] = np.random.random() - B[index, j, i] = np.random.random() - expect[index, i, j] = A[index, i, j] + B[index, j, i] - - csdfg(x=A, y=B, z=Z, m=m, n=n) - assert np.allclose(expect, Z) - - -@xilinx_test(run_synthesis=False) -def test_axpy_unroll_3(): - csdfg, sdfg = _exec_hbmtransform(lambda: create_vadd_sdfg("axpy_unroll_3"), - [("x", "HBM", "3:6"), ("y", "HBM", "0:3"), ("z", "HBM", "6:9")]) - validate_vadd_sdfg(csdfg, [3, 20]) - return sdfg - - -@xilinx_test(run_synthesis=False) -def test_axpy_unroll_mixed(): - csdfg, sdfg = _exec_hbmtransform(lambda: create_vadd_sdfg("axpy_mixed"), [("x", "DDR", "0"), ("y", "HBM", "0:2"), - ("z", "HBM", "0:2")]) - validate_vadd_sdfg(csdfg, [2, 20]) - return sdfg - - -@xilinx_test(run_synthesis=False) -def test_nd_split(): - csdfg, sdfg = _exec_hbmtransform(lambda: create_nd_sdfg("nd_split"), [("x", "HBM", "0:10"), ("y", "HBM", "10:20"), - ("z", "HBM", "20:30")]) - validate_nd_sdfg(csdfg, 10, 10, divide_n=10) - return sdfg - - -@xilinx_test(run_synthesis=False) -def test_nd_split_inner(): - - def apply_to(sdfg): - state: SDFGState = sdfg.start_state - for node in state.nodes(): - if isinstance(node, nodes.MapEntry) and node.map.params[0] == "j": - return [node] - - csdfg, sdfg = _exec_hbmtransform(lambda: create_nd_sdfg("nd_split_inner"), [("x", "HBM", "0:10"), - ("y", "HBM", "10:20"), - ("z", "HBM", "20:30")], - apply_to=apply_to) - validate_nd_sdfg(csdfg, 10, 10, divide_m=10) - return sdfg - - -@xilinx_test(run_synthesis=False) -def test_gemv_blas_1(): - csdfg, sdfg = _exec_hbmtransform(lambda: create_gemv_blas_sdfg("gemv_1", 32), [("x", "HBM", "31:32"), - ("y", "HBM", "30:31"), - ("A", "HBM", "0:30")], True) - validate_gemv_sdfg(csdfg, [30, 32, 5], [5], [32 * 30]) - return sdfg - - -@xilinx_test(run_synthesis=False) -def test_gemv_blas_2(): - csdfg, sdfg = _exec_hbmtransform(lambda: create_gemv_blas_sdfg("gemv_2", 32), [("x", "HBM", "31:32"), - ("y", "HBM", "15:30"), - ("A", "HBM", "0:15")], True) - validate_gemv_sdfg(csdfg, [15, 32, 5], [5], [15, 32]) - return sdfg - - -@xilinx_test(run_synthesis=False) -def test_multiple_applications(): - _, sdfg = _exec_hbmtransform(lambda: create_multi_access_sdfg("multi_access"), [("x", "HBM", "0:2"), - ("y", "HBM", "2:4"), - ("z", "HBM", "4:6"), - ("w", "HBM", "10:12"), - ("o1", "HBM", "6:8"), - ("o2", "HBM", "8:10")], - num_apply=2) - return sdfg - - -# This does not run with synthesis enabled -@xilinx_test(run_synthesis=False) -def test_axpy_unroll_1(): - # This SDFG is fine, but we would do nothing at all - sdfg = _exec_hbmtransform(lambda: create_vadd_sdfg("axpy_unroll_1"), [("x", "DDR", "0"), ("y", "HBM", "0:1"), - ("z", "DDR", "1")], - num_apply=0) - sdfg.compile() # We still have to compile for pytest, so the build folder exists - return sdfg - - -# This does not run with synthesis enabled -@xilinx_test(run_synthesis=False) -def test_axpy_inconsistent_no_apply(): - sdfg = _exec_hbmtransform(lambda: create_vadd_sdfg("axpy_inconsistent"), [("x", "HBM", "0:2"), ("y", "DDR", "0"), - ("z", "HBM", "0:3")], - num_apply=0) - set_assignment(sdfg, [("x", "DDR", "0"), ("y", "HBM", "0:1"), ("z", "DDR", "1")]) - sdfg.compile() # We still have to compile for pytest, so the build folder exists - return sdfg - - -if __name__ == "__main__": - test_axpy_unroll_3(None) - test_axpy_unroll_1(None) - test_axpy_unroll_mixed(None) - test_nd_split(None) - test_nd_split_inner(None) - test_gemv_blas_1(None) - test_gemv_blas_2(None) - test_axpy_inconsistent_no_apply(None) - test_multiple_applications(None) diff --git a/tests/fpga/jacobi_fpga_test.py b/tests/fpga/jacobi_fpga_test.py deleted file mode 100644 index 290eccaaf7..0000000000 --- a/tests/fpga/jacobi_fpga_test.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from dace.fpga_testing import xilinx_test, import_sample -from pathlib import Path -import pytest - - -# This kernel does not work with the Intel FPGA codegen, because it uses the -# constant systolic array index in the connector on the nested SDFG. -@pytest.mark.skip('Xilinx failure due to unresolved phi nodes, Intel FPGA failure due to systolic array index') -@xilinx_test(assert_ii_1=False) -def test_jacobi_fpga(): - jacobi = import_sample(Path("fpga") / "jacobi_fpga_systolic.py") - return jacobi.run_jacobi(64, 512, 16, 4) - - -if __name__ == "__main__": - test_jacobi_fpga(None) diff --git a/tests/fpga/kernel_detection_test.py b/tests/fpga/kernel_detection_test.py deleted file mode 100644 index 4e74175821..0000000000 --- a/tests/fpga/kernel_detection_test.py +++ /dev/null @@ -1,574 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -# Tests for kernels detection - -import dace -import numpy as np -from pathlib import Path -import pytest -import re -from dace.sdfg.utils import is_fpga_kernel -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.fpga_testing import fpga_test -from dace import config - - -def count_kernels(sdfg: dace.SDFG): - """ - Test utility functions: Counts the number of generated device kernels - - :param sdfg: Already compiled SDFG to count kernels for. - :return: number of kernels - """ - - import csv - kernels = 0 - with open(Path(sdfg.build_folder) / "dace_files.csv", "r") as csv_file: - csv_reader = csv.reader(csv_file, delimiter=',') - for row in csv_reader: - if row[1] == "device" and (row[-1].endswith("cpp") or row[-1].endswith("cl")): - kernels = kernels + 1 - return kernels - - -@fpga_test() -def test_kernels_inside_component_0(): - """ - Tests for kernels detection inside a single connected component. - It computes z =(x+y) + (v+w) - - High-level overview: - ┌───────────┐ - │ Add_Map_0 │ - └──────┬────┘ - │ - ┌───────────┐ ┌───────────┐ - │ Add_Map_1 │ │ Add_Map_2 │ - └──────┬────┘ └──────┬────┘ - │ ┌───────────┐ │ - └─► │ Add_Map_3 │◄───┘ - └───────────┘ - The 4 maps, should belong to three distinct kernels - """ - - @dace.program - def kernels_inside_component_0(x: dace.float32[8], y: dace.float32[8], v: dace.float32[8], w: dace.float32[8], - z: dace.float32[8]): - tmp = (x + y) + v - return tmp + (w + z) - - x = np.random.rand(8).astype(np.float32) - y = np.random.rand(8).astype(np.float32) - v = np.random.rand(8).astype(np.float32) - w = np.random.rand(8).astype(np.float32) - z = np.random.rand(8).astype(np.float32) - - sdfg = kernels_inside_component_0.to_sdfg() - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - for state in sdfg.states(): - if is_fpga_kernel(sdfg, state): - state.instrument = dace.InstrumentationType.FPGA - - with config.set_temporary("compiler", "fpga", "concurrent_kernel_detection", value=True): - res = sdfg(x=x, y=y, v=v, w=w, z=z) - assert count_kernels(sdfg) == 3 - assert np.allclose(res, x + y + v + w + z) - - report = sdfg.get_latest_report() - assert len(report.durations[(0, 0, -1)]) == 5 - - full_fpga_events = 0 - for event_name in report.durations[(0, 0, -1)]: - if "Full FPGA" in event_name: - full_fpga_events += 1 - - assert full_fpga_events == 2 - - return sdfg - - -@fpga_test() -def test_kernels_inside_component_1(): - """ - Tests for kernels detection inside a single connected component. - The program computes: - - z = alpha*((x+y) + (v+w)) - - t = beta*((x+y) + (v+w)) - - High-level overview: - ┌───────────┐ ┌───────────┐ - │ Add_Map_0 │ │ Add_Map_1 │ - └──────┬────┘ └──────┬────┘ - │ ┌───────────┐ │ - └─► │ Add_Map_2 │◄───┘ - ────└───────────┘──── - │ │ - ┌──────v────┐ ┌─────v─────┐ - │ Mul_3 │ │ Mul_4 │ - └───────────┘ └───────────┘ - - The five Maps should belong to 5 distinct kernels - - """ - - @dace.program - def kernels_inside_component_1(x: dace.float32[8], y: dace.float32[8], v: dace.float32[8], w: dace.float32[8], - z: dace.float32[8], t: dace.float32[8], alpha: dace.float32, beta: dace.float32): - tmp1 = x + y - tmp2 = v + w - tmp3 = tmp1 + tmp2 - z[:] = alpha * tmp3 - t[:] = beta * tmp3 - - x = np.random.rand(8).astype(np.float32) - y = np.random.rand(8).astype(np.float32) - v = np.random.rand(8).astype(np.float32) - w = np.random.rand(8).astype(np.float32) - z = np.random.rand(8).astype(np.float32) - t = np.random.rand(8).astype(np.float32) - alpha = 1.0 - beta = 2.0 - - sdfg = kernels_inside_component_1.to_sdfg() - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - with config.set_temporary("compiler", "fpga", "concurrent_kernel_detection", value=True): - program = sdfg.compile() - assert count_kernels(sdfg) == 5 - program(x=x, y=y, v=v, w=w, z=z, t=t, alpha=alpha, beta=beta) - ref_z = alpha * (x + y + v + w) - ref_t = beta * (x + y + v + w) - assert np.allclose(z, ref_z) - assert np.allclose(t, ref_t) - - return sdfg - - -@fpga_test() -def test_kernels_inside_component_2(): - """ - Tests for PEs detection inside a single Component. - It computes z =(x+y) and t = (y+v) - - - High-level overview: - - x y v - │ │ │ - ┌──V────────<───┘────>────V──────┐ - │ Add_Map_0 │ │ Add_Map_1 │ - └───────────┘ └───────────┘ - - Map_0 and Map_1 should belong to two distinct kernels - """ - - @dace.program - def kernels_inside_component_2(x: dace.float32[8], y: dace.float32[8], v: dace.float32[8], z: dace.float32[8], - t: dace.float32[8]): - z[:] = x + y - t[:] = y + v - - x = np.random.rand(8).astype(np.float32) - y = np.random.rand(8).astype(np.float32) - v = np.random.rand(8).astype(np.float32) - z = np.random.rand(8).astype(np.float32) - t = np.random.rand(8).astype(np.float32) - - sdfg = kernels_inside_component_2.to_sdfg() - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - with config.set_temporary("compiler", "fpga", "concurrent_kernel_detection", value=True): - program = sdfg.compile() - - # NOTE: here we have only one kernel since subgraph detection already - # detects two PEs - assert count_kernels(sdfg) == 1 - program(x=x, y=y, v=v, t=t, z=z) - assert np.allclose(z, x + y) - assert np.allclose(t, v + y) - - return sdfg - - -@fpga_test(assert_ii_1=False) -def test_kernels_lns_inside_component(): - """ - Tests for kernels detection inside a single connected component where we - have multiple library nodes. - - It computes z =(x+y) + (v+w) - - High-level overview: - ┌───────────┐ ┌───────────┐ - │ Matmul_0 │ │ Matmul_1 │ - └──────┬────┘ └──────┬────┘ - │ ┌───────────┐ │ - └─► │ Dot_2 │◄───┘ - └───────────┘ - """ - - # (Provisional) Disable unique function - unique_functions_conf = dace.config.Config.get('compiler', 'unique_functions') - dace.config.Config.set('compiler', 'unique_functions', value="none") - - @dace.program - def kernels_lns_inside_component(A: dace.float32[8, 8], x: dace.float32[8], B: dace.float32[8, 8], - y: dace.float32[8]): - tmp1 = A @ x - tmp2 = B @ y - return np.dot(tmp1, tmp2) - - A = np.random.rand(8, 8).astype(np.float32) - B = np.random.rand(8, 8).astype(np.float32) - x = np.random.rand(8).astype(np.float32) - y = np.random.rand(8).astype(np.float32) - - sdfg = kernels_lns_inside_component.to_sdfg() - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - with config.set_temporary("compiler", "fpga", "concurrent_kernel_detection", value=True): - program = sdfg.compile() - - assert count_kernels(sdfg) == 3 - z = program(A=A, x=x, B=B, y=y) - ref = np.dot(A @ x, B @ y) - assert np.allclose(z, ref) - dace.config.Config.set('compiler', 'unique_functions', value=unique_functions_conf) - - return sdfg - - -@fpga_test() -def test_kernels_inside_components_0(): - """ - Tests for kernels detection in two distinct connected components. - The program computes: - z = (x+y) + (v+w) - zz = (xx+yy) + (vv+ww) - - High-level overview: the two connected components are the same and look - like the following - ┌───────────┐ ┌───────────┐ - │ Add_Map_0 │ │ Add_Map_1 │ - └──────┬────┘ └──────┬────┘ - │ ┌───────────┐ │ - └─► │ Add_Map_2 │◄───┘ - └───────────┘ - The three maps, should belong to three distinct kernels - - """ - - @dace.program - def kernels_inside_components_0(x: dace.float32[8], y: dace.float32[8], v: dace.float32[8], w: dace.float32[8], - xx: dace.float32[8], yy: dace.float32[8], vv: dace.float32[8], ww: dace.float32[8]): - z = (x + y) + (v + w) - zz = (xx + yy) + (vv + ww) - return z, zz - - x = np.random.rand(8).astype(np.float32) - y = np.random.rand(8).astype(np.float32) - v = np.random.rand(8).astype(np.float32) - w = np.random.rand(8).astype(np.float32) - xx = np.random.rand(8).astype(np.float32) - yy = np.random.rand(8).astype(np.float32) - vv = np.random.rand(8).astype(np.float32) - ww = np.random.rand(8).astype(np.float32) - - sdfg = kernels_inside_components_0.to_sdfg() - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - with config.set_temporary("compiler", "fpga", "concurrent_kernel_detection", value=True): - program = sdfg.compile() - - assert count_kernels(sdfg) == 6 - z, zz = program(x=x, y=y, v=v, w=w, xx=xx, yy=yy, vv=vv, ww=ww) - assert np.allclose(z, x + y + v + w) - assert np.allclose(zz, xx + yy + vv + ww) - - return sdfg - - -@fpga_test() -def test_kernels_inside_components_multiple_states(): - """ - Tests for kernels detection in two distinct states. - It computes - z = (x+y) + (v+w) - zz = (xx+yy) + (vv+ww) - - High-level overview: the two connected components are the same and look - like the following - ┌───────────┐ ┌───────────┐ - │ Add_Map_0 │ │ Add_Map_1 │ - └──────┬────┘ └──────┬────┘ - │ ┌───────────┐ │ - └─► │ Add_Map_2 │◄───┘ - └───────────┘ - The three maps, should belong to three distinct kernels - """ - - def make_sdfg(dtype=dace.float32): - sdfg = dace.SDFG("multiple_kernels_multiple_states") - n = dace.symbol("size") - - input_data = ["x", "y", "v", "w", "xx", "yy", "vv", "ww"] - output_data = ["z", "zz"] - device_transient_data = ["device_tmp0", "device_tmp1", "device_tmp2", "device_tmp3"] - - for d in input_data + output_data: - sdfg.add_array(d, shape=[n], dtype=dtype) - sdfg.add_array(f"device_{d}", - shape=[n], - dtype=dtype, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - for d in device_transient_data: - sdfg.add_array(d, shape=[n], dtype=dtype, storage=dace.dtypes.StorageType.FPGA_Global, transient=True) - - ########################################################################### - # Copy data to FPGA - - copy_in_state = sdfg.add_state("copy_to_device") - - for d in input_data: - in_host = copy_in_state.add_read(d) - in_device = copy_in_state.add_read(f"device_{d}") - - copy_in_state.add_memlet_path(in_host, in_device, memlet=dace.Memlet(f"{d}[0:{n}]")) - - ########################################################################### - # Copy data from FPGA - copy_out_state = sdfg.add_state("copy_to_host") - - for d in output_data: - out_host = copy_out_state.add_write(d) - out_device = copy_out_state.add_read(f"device_{d}") - - copy_out_state.add_memlet_path(out_device, out_host, memlet=dace.Memlet(f"{d}[0:{n}]")) - - ######################################################################## - # FPGA, First State - - fpga_state_0 = sdfg.add_state("fpga_state_0") - - x_in = fpga_state_0.add_read("device_x") - y_in = fpga_state_0.add_read("device_y") - v_in = fpga_state_0.add_read("device_v") - w_in = fpga_state_0.add_read("device_w") - device_tmp0 = fpga_state_0.add_access("device_tmp0") - device_tmp1 = fpga_state_0.add_access("device_tmp1") - z_out = fpga_state_0.add_write("device_z") - - # x + y - vecMap_entry00, vecMap_exit00 = fpga_state_0.add_map('vecAdd_map00', - dict(i=f'0:{n}'), - schedule=dace.dtypes.ScheduleType.FPGA_Device) - - vecAdd_tasklet00 = fpga_state_0.add_tasklet('vec_add_task00', ['x_con', 'y_con'], ['z_con'], - 'z_con = x_con + y_con') - - fpga_state_0.add_memlet_path(x_in, - vecMap_entry00, - vecAdd_tasklet00, - dst_conn='x_con', - memlet=dace.Memlet("device_x[i]")) - - fpga_state_0.add_memlet_path(y_in, - vecMap_entry00, - vecAdd_tasklet00, - dst_conn='y_con', - memlet=dace.Memlet("device_y[i]")) - - fpga_state_0.add_memlet_path(vecAdd_tasklet00, - vecMap_exit00, - device_tmp0, - src_conn='z_con', - memlet=dace.Memlet("device_tmp0[i]")) - - # v + w - - vecMap_entry01, vecMap_exit01 = fpga_state_0.add_map('vecAdd_map01', - dict(i=f'0:{n}'), - schedule=dace.dtypes.ScheduleType.FPGA_Device) - - vecAdd_tasklet01 = fpga_state_0.add_tasklet('vec_add_task01', ['x_con', 'y_con'], ['z_con'], - 'z_con = x_con + y_con') - - fpga_state_0.add_memlet_path(v_in, - vecMap_entry01, - vecAdd_tasklet01, - dst_conn='x_con', - memlet=dace.Memlet(f"device_v[i]")) - - fpga_state_0.add_memlet_path(w_in, - vecMap_entry01, - vecAdd_tasklet01, - dst_conn='y_con', - memlet=dace.Memlet(f"device_w[i]")) - - fpga_state_0.add_memlet_path(vecAdd_tasklet01, - vecMap_exit01, - device_tmp1, - src_conn='z_con', - memlet=dace.Memlet(f"device_tmp1[i]")) - - # tmp0 + tmp 1 - - vecMap_entry02, vecMap_exit02 = fpga_state_0.add_map('vecAdd_map02', - dict(i=f'0:{n}'), - schedule=dace.dtypes.ScheduleType.FPGA_Device) - - vecAdd_tasklet02 = fpga_state_0.add_tasklet('vec_add_task02', ['x_con', 'y_con'], ['z_con'], - 'z_con = x_con + y_con') - - fpga_state_0.add_memlet_path(device_tmp0, - vecMap_entry02, - vecAdd_tasklet02, - dst_conn='x_con', - memlet=dace.Memlet("device_tmp0[i]")) - - fpga_state_0.add_memlet_path(device_tmp1, - vecMap_entry02, - vecAdd_tasklet02, - dst_conn='y_con', - memlet=dace.Memlet("device_tmp1[i]")) - - fpga_state_0.add_memlet_path(vecAdd_tasklet02, - vecMap_exit02, - z_out, - src_conn='z_con', - memlet=dace.Memlet("device_z[i]")) - ######################################################################## - # FPGA, Second State - - fpga_state_1 = sdfg.add_state("fpga_state_1") - - xx_in = fpga_state_1.add_read("device_xx") - yy_in = fpga_state_1.add_read("device_yy") - vv_in = fpga_state_1.add_read("device_vv") - ww_in = fpga_state_1.add_read("device_ww") - device_tmp2 = fpga_state_1.add_access("device_tmp2") - device_tmp3 = fpga_state_1.add_access("device_tmp3") - zz_out = fpga_state_1.add_write("device_zz") - - # xx + yy - vecMap_entry10, vecMap_exit10 = fpga_state_1.add_map('vecAdd_map10', - dict(i=f'0:{n}'), - schedule=dace.dtypes.ScheduleType.FPGA_Device) - - vecAdd_tasklet10 = fpga_state_1.add_tasklet('vec_add_task10', ['x_con', 'y_con'], ['z_con'], - 'z_con = x_con + y_con') - - fpga_state_1.add_memlet_path(xx_in, - vecMap_entry10, - vecAdd_tasklet10, - dst_conn='x_con', - memlet=dace.Memlet("device_xx[i]")) - - fpga_state_1.add_memlet_path(yy_in, - vecMap_entry10, - vecAdd_tasklet10, - dst_conn='y_con', - memlet=dace.Memlet("device_yy[i]")) - - fpga_state_1.add_memlet_path(vecAdd_tasklet10, - vecMap_exit10, - device_tmp2, - src_conn='z_con', - memlet=dace.Memlet("device_tmp2[i]")) - - # vv + ww - vecMap_entry11, vecMap_exit11 = fpga_state_1.add_map('vecAdd_map11', - dict(i=f'0:{n}'), - schedule=dace.dtypes.ScheduleType.FPGA_Device) - - vecAdd_tasklet11 = fpga_state_1.add_tasklet('vec_add_task11', ['x_con', 'y_con'], ['z_con'], - 'z_con = x_con + y_con') - - fpga_state_1.add_memlet_path(vv_in, - vecMap_entry11, - vecAdd_tasklet11, - dst_conn='x_con', - memlet=dace.Memlet(f"device_vv[i]")) - - fpga_state_1.add_memlet_path(ww_in, - vecMap_entry11, - vecAdd_tasklet11, - dst_conn='y_con', - memlet=dace.Memlet(f"device_ww[i]")) - - fpga_state_1.add_memlet_path(vecAdd_tasklet11, - vecMap_exit11, - device_tmp3, - src_conn='z_con', - memlet=dace.Memlet(f"device_tmp3[i]")) - - # tmp2 + tmp 3 - - vecMap_entry12, vecMap_exit12 = fpga_state_1.add_map('vecAdd_map12', - dict(i=f'0:{n}'), - schedule=dace.dtypes.ScheduleType.FPGA_Device) - - vecAdd_tasklet12 = fpga_state_1.add_tasklet('vec_add_task12', ['x_con', 'y_con'], ['z_con'], - 'z_con = x_con + y_con') - - fpga_state_1.add_memlet_path(device_tmp2, - vecMap_entry12, - vecAdd_tasklet12, - dst_conn='x_con', - memlet=dace.Memlet("device_tmp2[i]")) - - fpga_state_1.add_memlet_path(device_tmp3, - vecMap_entry12, - vecAdd_tasklet12, - dst_conn='y_con', - memlet=dace.Memlet("device_tmp3[i]")) - - fpga_state_1.add_memlet_path(vecAdd_tasklet12, - vecMap_exit12, - zz_out, - src_conn='z_con', - memlet=dace.Memlet("device_zz[i]")) - - ###################################### - # Interstate edges - sdfg.add_edge(copy_in_state, fpga_state_0, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(fpga_state_0, fpga_state_1, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(fpga_state_1, copy_out_state, dace.sdfg.sdfg.InterstateEdge()) - - ######### - # Validate - sdfg.fill_scope_connectors() - sdfg.validate() - return sdfg - - x = np.random.rand(8).astype(np.float32) - y = np.random.rand(8).astype(np.float32) - v = np.random.rand(8).astype(np.float32) - w = np.random.rand(8).astype(np.float32) - z = np.random.rand(8).astype(np.float32) - xx = np.random.rand(8).astype(np.float32) - yy = np.random.rand(8).astype(np.float32) - vv = np.random.rand(8).astype(np.float32) - ww = np.random.rand(8).astype(np.float32) - zz = np.random.rand(8).astype(np.float32) - - sdfg = make_sdfg() - with config.set_temporary("compiler", "fpga", "concurrent_kernel_detection", value=True): - program = sdfg.compile() - assert count_kernels(sdfg) == 6 - program(z=z, zz=zz, x=x, y=y, v=v, w=w, xx=xx, yy=yy, vv=vv, ww=ww, size=8) - assert np.allclose(z, x + y + v + w) - assert np.allclose(zz, xx + yy + vv + ww) - - return sdfg - - -if __name__ == "__main__": - test_kernels_inside_component_0(None) - test_kernels_inside_component_1(None) - test_kernels_inside_component_2(None) - test_kernels_lns_inside_component(None) - test_kernels_inside_components_0(None) - test_kernels_inside_components_multiple_states(None) diff --git a/tests/fpga/long_long_opencl_test.py b/tests/fpga/long_long_opencl_test.py deleted file mode 100644 index 23852bb3cc..0000000000 --- a/tests/fpga/long_long_opencl_test.py +++ /dev/null @@ -1,68 +0,0 @@ -# Simple test case to check that for the intel fpga the openCL type long gets used instead of C type long long - -import dace.dtypes -import numpy as np -import dace as dc -import argparse -from dace.fpga_testing import intel_fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG - -N = dc.symbol('N', dtype=dc.int64) - - -@dc.program -def simple_add_kernel(A: dc.int64[N], B: dc.int64[N]): - B += A - - -def initialize(N, datatype=np.int64): - A = np.fromfunction(lambda i: (i + 372036854775807), (N, ), dtype=datatype) - B = np.fromfunction(lambda i: (i + 3372036854775807), (N, ), dtype=datatype) - return A, B - - -def ground_truth(A, B): - B += A - - -def run_simple_add(device_type: dace.dtypes.DeviceType): - ''' - Runs simple add for the given device - :return: the SDFG - ''' - - # Initialize data (polybench small size) - N = 120 - A, B = initialize(N) - A_ref = np.copy(A) - B_ref = np.copy(B) - - # Parse SDFG and apply FPGA friendly optimization - sdfg = simple_add_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - sdfg.specialize(dict(N=N)) - sdfg(A=A, B=B) - - # Compute ground truth and validate - ground_truth(A_ref, B_ref) - assert np.allclose(B, B_ref) - return sdfg - - -@intel_fpga_test(assert_ii_1=False) -def test_fpga(): - return run_simple_add(dace.dtypes.DeviceType.FPGA) - - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') - - args = vars(parser.parse_args()) - target = args["target"] - - if target == "fpga": - run_simple_add(dace.dtypes.DeviceType.FPGA) diff --git a/tests/fpga/mandelbrot_fpga_test.py b/tests/fpga/mandelbrot_fpga_test.py deleted file mode 100644 index 20579cef54..0000000000 --- a/tests/fpga/mandelbrot_fpga_test.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import dace -from dace.fpga_testing import fpga_test, import_sample -from dace.transformation.interstate import FPGATransformSDFG -from pathlib import Path - - -# TODO: Pipeline control flow while-loop? -@fpga_test(assert_ii_1=False) -def test_mandelbrot_fpga(): - mandelbrot = import_sample(Path("simple") / "mandelbrot.py") - h, w, max_iterations = 64, 64, 1000 - out = dace.ndarray([h, w], dtype=dace.uint16) - out[:] = dace.uint32(0) - sdfg = mandelbrot.mandelbrot.to_sdfg() - sdfg.apply_transformations(FPGATransformSDFG) - sdfg(output=out, maxiter=max_iterations, W=w, H=h) - return sdfg - - -if __name__ == "__main__": - test_mandelbrot_fpga(None) diff --git a/tests/fpga/map_unroll_processing_elements_test.py b/tests/fpga/map_unroll_processing_elements_test.py deleted file mode 100644 index 069ab47e04..0000000000 --- a/tests/fpga/map_unroll_processing_elements_test.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import dace -import dace.sdfg.nodes as nodes -from dace.fpga_testing import xilinx_test -import importlib.util -import numpy as np -from pathlib import Path -import pytest -from dace.config import set_temporary - - -@pytest.mark.skip('Xilinx HLS fails due to unresolved phi nodes') -@xilinx_test(assert_ii_1=False) -def test_map_unroll_processing_elements(): - # Grab the systolic GEMM implementation the samples directory - # To achieve II=1 with Xilinx, we need to decouple reads/writes from memory - - spec = importlib.util.spec_from_file_location( - "gemm", - Path(__file__).parent.parent.parent / "samples" / "fpga" / "gemm_systolic_vectorized.py") - gemm = importlib.util.module_from_spec(spec) - spec.loader.exec_module(gemm) - - N = 128 - K = 256 - M = 512 - P = 8 - W = 4 - TN = 32 - TM = 128 - - # Create an SDFG with multiple processing elements - sdfg = gemm.make_sdfg("map_unroll_processing_elements", dace.vector(dace.float32, W)) - sdfg.specialize({"P": P, "W": W, "TN": TN, "TM": TM}) - for state in sdfg.states(): - for node in state.nodes(): - if isinstance(node, nodes.MapEntry) and node.params == ["p"]: - node.unroll = False - node.schedule = dace.ScheduleType.Unrolled - - # Initialize arrays: Randomize A and B, zero C - A = np.ndarray([N, K], dtype=dace.float32.type) - B = np.ndarray([K, M], dtype=dace.float32.type) - C = np.ndarray([N, M], dtype=dace.float32.type) - A[:] = np.random.rand(N, K).astype(dace.float32.type) - B[:] = np.random.rand(K, M).astype(dace.float32.type) - C[:] = np.random.rand(N, M).astype(dace.float32.type) - - C_regression = A @ B + C - - sdfg(A=A, B=B, C=C, N=N, M=M, K=K) - diff = np.linalg.norm(C_regression - C) / float(N * M) - if not np.allclose(C_regression, C): - raise ValueError("Verification failed.") - - return sdfg - - -@pytest.mark.skip('Test no longer achieves II=1') -@xilinx_test(assert_ii_1=True) -def test_map_unroll_processing_elements_decoupled(): - # Grab the systolic GEMM implementation the samples directory - - spec = importlib.util.spec_from_file_location( - "gemm", - Path(__file__).parent.parent.parent / "samples" / "fpga" / "gemm_systolic_vectorized.py") - gemm = importlib.util.module_from_spec(spec) - spec.loader.exec_module(gemm) - - N = 128 - K = 256 - M = 512 - P = 8 - W = 4 - TN = 32 - TM = 128 - - # Create an SDFG with multiple processing elements - sdfg = gemm.make_sdfg("map_unroll_processing_elements", dace.vector(dace.float32, W)) - sdfg.specialize({"P": P, "W": W, "TN": TN, "TM": TM}) - for state in sdfg.states(): - for node in state.nodes(): - if isinstance(node, nodes.MapEntry) and node.params == ["p"]: - node.unroll = False - node.schedule = dace.ScheduleType.Unrolled - - # Initialize arrays: Randomize A and B, zero C - A = np.ndarray([N, K], dtype=dace.float32.type) - B = np.ndarray([K, M], dtype=dace.float32.type) - C = np.ndarray([N, M], dtype=dace.float32.type) - A[:] = np.random.rand(N, K).astype(dace.float32.type) - B[:] = np.random.rand(K, M).astype(dace.float32.type) - C[:] = np.random.rand(N, M).astype(dace.float32.type) - - C_regression = A @ B + C - - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - sdfg(A=A, B=B, C=C, N=N, M=M, K=K) - diff = np.linalg.norm(C_regression - C) / float(N * M) - if not np.allclose(C_regression, C): - raise ValueError("Verification failed.") - - return sdfg - - -if __name__ == "__main__": - test_map_unroll_processing_elements(None) - test_map_unroll_processing_elements_decoupled(None) diff --git a/tests/fpga/matmul_test.py b/tests/fpga/matmul_test.py deleted file mode 100644 index 1a5423bd52..0000000000 --- a/tests/fpga/matmul_test.py +++ /dev/null @@ -1,255 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import dace -from dace.fpga_testing import fpga_test, import_sample, xilinx_test -import dace.libraries.blas as blas -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -import numpy as np -import pytest -from pathlib import Path -from dace.config import set_temporary - - -def create_gemm_sdfg(sdfg_name, - alpha, - beta, - A, - B, - C, - dtype, - transA=False, - transB=False, - vec_width=1, - expansion_args=None): - """ - Build an SDFG that perform the given GEMM operation along the given axis - Input data A, B, and C is not vectorized - """ - sdfg = dace.SDFG(sdfg_name) - - ########################################################################### - # Copy data to FPGA - - copy_in_state = sdfg.add_state("copy_to_device") - A_shape = A.shape - B_shape = B.shape - C_shape = C.shape - N = A_shape[0] - K = A_shape[1] - M = B_shape[1] - vec_type = dace.vector(dtype, vec_width) - - # Create data containers - sdfg.add_array('A', A_shape, dtype) - sdfg.add_array("A_device", shape=A_shape, dtype=dtype, storage=dace.dtypes.StorageType.FPGA_Global, transient=True) - sdfg.add_array("B", [K, M / vec_width], dtype=vec_type) - sdfg.add_array("B_device", [K, M / vec_width], - dtype=vec_type, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - - sdfg.add_array("C", [N, M / vec_width], dtype=vec_type) - sdfg.add_array("C_device", [N, M / vec_width], - dtype=vec_type, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - - # Copy A - in_host_A = copy_in_state.add_read("A") - in_device_A = copy_in_state.add_write("A_device") - copy_in_state.add_memlet_path(in_host_A, in_device_A, memlet=dace.Memlet(f"A[0:{N}, 0:{K}]")) - - # Copy B - in_host_B = copy_in_state.add_read("B") - in_device_B = copy_in_state.add_write("B_device") - copy_in_state.add_memlet_path(in_host_B, in_device_B, memlet=dace.Memlet(f"B[0:{K}, 0:{M}/{vec_width}]")) - - # Copy C - in_host_C = copy_in_state.add_read("C") - in_device_C = copy_in_state.add_write("C_device") - copy_in_state.add_memlet_path(in_host_C, in_device_C, memlet=dace.Memlet(f"C[0:{N}, 0:{M}/{vec_width}]")) - - ########################################################################### - # Copy data from FPGA - copy_out_state = sdfg.add_state("copy_from_device") - - out_device = copy_out_state.add_read("C_device") - out_host = copy_out_state.add_write("C") - copy_out_state.add_memlet_path(out_device, out_host, memlet=dace.Memlet(f"C[0:{N}, 0:{M}//{vec_width}]")) - - ######################################################################## - # FPGA State - - fpga_state = sdfg.add_state("fpga_state") - in_A = fpga_state.add_read("A_device") - in_B = fpga_state.add_read("B_device") - in_C = fpga_state.add_read("C_device") - out_C = fpga_state.add_read("C_device") - - gemm_node = blas.Gemm("gemm", transA=transA, transB=transB, alpha=alpha, beta=beta) - gemm_node.implementation = "FPGA1DSystolic" - - fpga_state.add_memlet_path(in_A, gemm_node, dst_conn="_a", memlet=dace.Memlet(f"A_device[0:{N}, 0:{K}]")) - fpga_state.add_memlet_path(in_B, - gemm_node, - dst_conn="_b", - memlet=dace.Memlet(f"B_device[0:{K}, 0:{M}/{vec_width}]")) - fpga_state.add_memlet_path(in_C, - gemm_node, - dst_conn="_c", - memlet=dace.Memlet(f"C_device[0:{N}, 0:{M}/{vec_width}]")) - fpga_state.add_memlet_path(gemm_node, - out_C, - src_conn="_c", - memlet=dace.Memlet(f"C_device[0:{N}, 0:{M}/{vec_width}]")) - - ###################################### - # Interstate edges - sdfg.add_edge(copy_in_state, fpga_state, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(fpga_state, copy_out_state, dace.sdfg.sdfg.InterstateEdge()) - sdfg.validate() - - if expansion_args is not None: - gemm_node.expand(sdfg, fpga_state, **expansion_args) - - return sdfg - - -@fpga_test(assert_ii_1=False) -def test_naive_matmul_fpga(): - matmul = import_sample(Path("optimization") / "matmul.py") - sdfg = matmul.matmul.to_sdfg() - sdfg.apply_transformations(FPGATransformSDFG) - - n, k, m = 64, 64, 64 - - A = np.random.rand(m, k).astype(np.float64) - B = np.random.rand(k, n).astype(np.float64) - C = np.zeros((m, n), dtype=np.float64) - - sdfg(A=A, B=B, C=C, N=n, K=k, M=m) - - expected = A @ B - diff = np.linalg.norm(C - expected) / (m * n) - - assert diff <= 1e-6 - - return sdfg - - -@fpga_test(xilinx=False) -def test_systolic_matmul_fpga(): - matmul = import_sample(Path("fpga") / "matrix_multiplication_systolic.py") - return matmul.run_matmul_systolic(128, 32, 64, 4, False) - - -@fpga_test(assert_ii_1=False, xilinx=False) -def test_gemm_vectorized(): - # Test with vectorization - # To achieve II=1 with Xilinx, we need to decouple reads/writes from memory - A = np.random.rand(128, 128).astype(np.float32) - B = np.random.rand(128, 128).astype(np.float32) - C = np.random.rand(128, 128).astype(np.float32) - alpha = 2.1 - beta = 1.5 - vec_width = 4 - sdfg = create_gemm_sdfg("gemm_vectorized", alpha, beta, A, B, C, dace.float32, vec_width=vec_width) - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG]) - # Compute ground truth - C_regression = alpha * (A @ B) + beta * C - sdfg(A=A, B=B, C=C) - assert np.allclose(C, C_regression, atol=1e-6) - return sdfg - - -@pytest.mark.skip('Xilinx HLS fails due to unresolved phi nodes') -@xilinx_test(assert_ii_1=True) -def test_gemm_vectorized_decoupled(): - # Test with vectorization - A = np.random.rand(128, 128).astype(np.float32) - B = np.random.rand(128, 128).astype(np.float32) - C = np.random.rand(128, 128).astype(np.float32) - alpha = 2.1 - beta = 1.5 - vec_width = 4 - sdfg = create_gemm_sdfg("gemm_vectorized", alpha, beta, A, B, C, dace.float32, vec_width=vec_width) - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG]) - # Compute ground truth - C_regression = alpha * (A @ B) + beta * C - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - sdfg(A=A, B=B, C=C) - assert np.allclose(C, C_regression, atol=1e-6) - return sdfg - - -@fpga_test(assert_ii_1=False, xilinx=False) -def test_gemm_size_not_multiples_of(): - - # Test with matrix sizes that are not a multiple of #PEs and Tile sizes - A = np.random.rand(120, 128).astype(np.float32) - B = np.random.rand(128, 128).astype(np.float32) - C = np.random.rand(120, 128).astype(np.float32) - expansion_args = {"tile_size_m": 50, "num_pes": 7} - sdfg = create_gemm_sdfg("gemm_not_multiple_of", 1, 1, A, B, C, dace.float32, expansion_args=expansion_args) - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG]) - # compute ground truth - C_regression = A @ B + C - sdfg(A=A, B=B, C=C) - assert np.allclose(C, C_regression, atol=1e-6) - return sdfg - - -@pytest.mark.skip('Xilinx HLS fails due to unresolved phi nodes') -@xilinx_test() -def test_gemm_size_not_multiples_of_decoupled(): - # Test with matrix sizes that are not a multiple of #PEs and Tile sizes - # To achieve II=1 with Xilinx, we need to decouple reads/writes from memory - A = np.random.rand(120, 128).astype(np.float32) - B = np.random.rand(128, 128).astype(np.float32) - C = np.random.rand(120, 128).astype(np.float32) - expansion_args = {"tile_size_m": 50, "num_pes": 7} - sdfg = create_gemm_sdfg("gemm_not_multiple_of", 1, 1, A, B, C, dace.float32, expansion_args=expansion_args) - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG]) - # compute ground truth - C_regression = A @ B + C - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - sdfg(A=A, B=B, C=C) - assert np.allclose(C, C_regression, atol=1e-6) - return sdfg - - -@fpga_test(xilinx=False) -def test_matmul_np(): - # Test with numpy matmul, and double precision - @dace.program - def matmul_np(A: dace.float64[128, 64], B: dace.float64[64, 32], C: dace.float64[128, 32]): - C[:] = A @ B - - A = np.random.rand(128, 64).astype(np.float64) - B = np.random.rand(64, 32).astype(np.float64) - C = np.random.rand(128, 32).astype(np.float64) - - sdfg = matmul_np.to_sdfg() - sdfg.apply_transformations([FPGATransformSDFG]) - from dace.libraries.blas import Gemm - Gemm.default_implementation = "FPGA1DSystolic" - # We have to Inline - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG]) - C_regression = A @ B - sdfg(A=A, B=B, C=C) - assert np.allclose(C, C_regression, atol=1e-6) - return sdfg - - -if __name__ == "__main__": - test_naive_matmul_fpga(None) - test_systolic_matmul_fpga(None) - test_gemm_vectorized(None) - test_gemm_vectorized_decoupled(None) - test_gemm_size_not_multiples_of(None) - test_gemm_size_not_multiples_of_decoupled(None) - test_matmul_np(None) diff --git a/tests/fpga/memory_buffering_test.py b/tests/fpga/memory_buffering_test.py deleted file mode 100644 index 12aa0f223e..0000000000 --- a/tests/fpga/memory_buffering_test.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" -Tests memory buffering in an FPGA SDFG, where memory is read and written using -512-bit wide accesses, and converted (using a "gearbox") to/from the vector -width used by the computational kernel. - -Unfortunately this doesn't currently work for Intel, since Intel does not -support vectors of vectors in kernel code. -""" -import dace -from dace.fpga_testing import fpga_test, xilinx_test -from dace.libraries.standard import Gearbox -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -import numpy as np - -dtype = dace.float32 -mem_width = 64 // dtype.bytes -n = dace.symbol("n") - - -def run_program(sdfg: dace.SDFG): - size = 16 * mem_width - input_array = np.ones((size, ), dtype.type) - output_array = np.empty((size, ), dtype.type) - sdfg(input_array_host=input_array, output_array_host=output_array, n=size) - assert all(output_array == input_array + 1) - - -def memory_buffering(vec_width, use_library_node, elementwise): - - gear_factor = mem_width // vec_width - kernel_type = dace.vector(dtype, vec_width) - if elementwise: - memory_type = dace.vector(dtype, mem_width) - else: - memory_type = dace.vector(kernel_type, gear_factor) - sdfg = dace.SDFG("memory_buffering_library_node") - state = sdfg.add_state("memory_buffering_library_node") - - sdfg.add_array("input_array", (n / mem_width, ), memory_type, transient=True, storage=dace.StorageType.FPGA_Global) - sdfg.add_array("output_array", (n / mem_width, ), memory_type, transient=True, storage=dace.StorageType.FPGA_Global) - sdfg.add_stream("read_to_gearbox", memory_type, transient=True, storage=dace.StorageType.FPGA_Local) - sdfg.add_stream("gearbox_to_kernel", kernel_type, transient=True, storage=dace.StorageType.FPGA_Local) - sdfg.add_stream("kernel_to_gearbox", kernel_type, transient=True, storage=dace.StorageType.FPGA_Local) - sdfg.add_stream("gearbox_to_write", memory_type, transient=True, storage=dace.StorageType.FPGA_Local) - - # Read from memory - memory_read = state.add_read("input_array") - read_to_gearbox_write = state.add_write("read_to_gearbox") - read_entry, read_exit = state.add_map("read", {"i": f"0:n/{mem_width}"}, schedule=dace.ScheduleType.FPGA_Device) - read_tasklet = state.add_tasklet("read", {"mem"}, {"to_gearbox"}, "to_gearbox = mem") - state.add_memlet_path(memory_read, read_entry, read_tasklet, dst_conn="mem", memlet=dace.Memlet(f"input_array[i]")) - state.add_memlet_path(read_tasklet, - read_exit, - read_to_gearbox_write, - src_conn="to_gearbox", - memlet=dace.Memlet(f"read_to_gearbox[0]")) - - # Gearbox input - read_to_gearbox_read = state.add_read("read_to_gearbox") - gearbox_to_kernel_write = state.add_write("gearbox_to_kernel") - if use_library_node: - read_gearbox = Gearbox(n / mem_width, name="read_gearbox") - state.add_node(read_gearbox) - state.add_memlet_path(read_to_gearbox_read, - read_gearbox, - dst_conn="from_memory", - memlet=dace.Memlet("read_to_gearbox[0]", volume=n / mem_width)) - state.add_memlet_path(read_gearbox, - gearbox_to_kernel_write, - src_conn="to_kernel", - memlet=dace.Memlet("gearbox_to_kernel[0]", volume=n / vec_width)) - else: - sdfg.add_array("read_buffer", (1, ), memory_type, storage=dace.StorageType.FPGA_Local, transient=True) - read_buffer_read = state.add_read("read_buffer") - read_buffer_write = state.add_write("read_buffer") - read_gearbox_entry, read_gearbox_exit = state.add_map("gearbox_read", { - "i": f"0:n/{mem_width}", - "j": f"0:{gear_factor}" - }, - schedule=dace.ScheduleType.FPGA_Device) - read_gearbox_tasklet = state.add_tasklet( - "gearbox_read", { - "from_memory": memory_type, - "buffer_in": None - }, {"to_kernel", "buffer_out"}, """\ -wide = from_memory if j == 0 else buffer_in -to_kernel = wide[j] -buffer_out = wide""") - state.add_memlet_path(read_to_gearbox_read, - read_gearbox_entry, - read_gearbox_tasklet, - dst_conn="from_memory", - memlet=dace.Memlet("read_to_gearbox[0]", dynamic=True)) - state.add_memlet_path(read_buffer_read, - read_gearbox_entry, - read_gearbox_tasklet, - dst_conn="buffer_in", - memlet=dace.Memlet("read_buffer[0]")) - state.add_memlet_path(read_gearbox_tasklet, - read_gearbox_exit, - gearbox_to_kernel_write, - src_conn="to_kernel", - memlet=dace.Memlet("gearbox_to_kernel[0]")) - state.add_memlet_path(read_gearbox_tasklet, - read_gearbox_exit, - read_buffer_write, - src_conn="buffer_out", - memlet=dace.Memlet("read_buffer[0]")) - - # Some fictional compute - gearbox_to_kernel_read = state.add_read("gearbox_to_kernel") - kernel_to_gearbox_write = state.add_write("kernel_to_gearbox") - compute_entry, compute_exit = state.add_map("compute", {"i": f"0:n/{vec_width}"}, - schedule=dace.ScheduleType.FPGA_Device) - compute_tasklet = state.add_tasklet("compute", {"val_in"}, {"val_out"}, "val_out = val_in + 1") - state.add_memlet_path(gearbox_to_kernel_read, - compute_entry, - compute_tasklet, - dst_conn="val_in", - memlet=dace.Memlet("gearbox_to_kernel[0]")) - state.add_memlet_path(compute_tasklet, - compute_exit, - kernel_to_gearbox_write, - src_conn="val_out", - memlet=dace.Memlet("kernel_to_gearbox[0]")) - - # Gearbox output - kernel_to_gearbox_read = state.add_write("kernel_to_gearbox") - gearbox_to_write_write = state.add_read("gearbox_to_write") - if use_library_node: - write_gearbox = Gearbox(n / mem_width, name="write_gearbox") - state.add_node(write_gearbox) - state.add_memlet_path(kernel_to_gearbox_read, - write_gearbox, - dst_conn="from_kernel", - memlet=dace.Memlet("kernel_to_gearbox[0]", volume=n / vec_width)) - state.add_memlet_path(write_gearbox, - gearbox_to_write_write, - src_conn="to_memory", - memlet=dace.Memlet("gearbox_to_write[0]", volume=n / mem_width)) - else: - sdfg.add_array("write_buffer", (1, ), memory_type, storage=dace.StorageType.FPGA_Local, transient=True) - write_buffer_read = state.add_read("write_buffer") - write_buffer_write = state.add_write("write_buffer") - write_gearbox_entry, write_gearbox_exit = state.add_map("gearbox_write", { - "i": f"0:n/{mem_width}", - "j": f"0:{gear_factor}" - }, - schedule=dace.ScheduleType.FPGA_Device) - write_gearbox_tasklet = state.add_tasklet( - "gearbox_write", {"from_kernel", "buffer_in"}, {"to_memory", "buffer_out"}, f"""\ -wide = buffer_in -wide[j] = from_kernel -if j == {gear_factor} - 1: - to_memory = wide -buffer_out = wide""") - state.add_memlet_path(kernel_to_gearbox_read, - write_gearbox_entry, - write_gearbox_tasklet, - dst_conn="from_kernel", - memlet=dace.Memlet("kernel_to_gearbox[0]")) - state.add_memlet_path(write_buffer_read, - write_gearbox_entry, - write_gearbox_tasklet, - dst_conn="buffer_in", - memlet=dace.Memlet("write_buffer[0]")) - state.add_memlet_path(write_gearbox_tasklet, - write_gearbox_exit, - gearbox_to_write_write, - src_conn="to_memory", - memlet=dace.Memlet("gearbox_to_write[0]", dynamic=True)) - state.add_memlet_path(write_gearbox_tasklet, - write_gearbox_exit, - write_buffer_write, - src_conn="buffer_out", - memlet=dace.Memlet("write_buffer[0]")) - - # Write memory - gearbox_to_write_read = state.add_read("gearbox_to_write") - memory_write = state.add_write("output_array") - write_entry, write_exit = state.add_map("write", {"i": f"0:n/{mem_width}"}, schedule=dace.ScheduleType.FPGA_Device) - write_tasklet = state.add_tasklet("write", {"from_gearbox"}, {"mem"}, "mem = from_gearbox") - state.add_memlet_path(gearbox_to_write_read, - write_entry, - write_tasklet, - dst_conn="from_gearbox", - memlet=dace.Memlet("gearbox_to_write[0]")) - state.add_memlet_path(write_tasklet, - write_exit, - memory_write, - src_conn="mem", - memlet=dace.Memlet("output_array[i]")) - - # Copy data to the FPGA - sdfg.add_array("input_array_host", (n, ), dtype) - pre_state = sdfg.add_state("host_to_device") - host_to_device_read = pre_state.add_read("input_array_host") - host_to_device_write = pre_state.add_write("input_array") - pre_state.add_memlet_path(host_to_device_read, - host_to_device_write, - memlet=dace.Memlet(f"input_array[0:n/{mem_width}]")) - - # Copy data back to the host - sdfg.add_array("output_array_host", (n, ), dtype) - post_state = sdfg.add_state("device_to_host") - device_to_host_read = post_state.add_read("output_array") - device_to_host_write = post_state.add_write("output_array_host") - post_state.add_memlet_path(device_to_host_read, - device_to_host_write, - memlet=dace.Memlet(f"output_array[0:n/{mem_width}]")) - - # Link states - sdfg.add_edge(pre_state, state, dace.InterstateEdge()) - sdfg.add_edge(state, post_state, dace.InterstateEdge()) - - run_program(sdfg) - - return sdfg - - -@xilinx_test() -def test_memory_buffering_manual(): - return memory_buffering(4, False, False) - - -@xilinx_test() -def test_memory_buffering_manual_scalar(): - return memory_buffering(1, False, False) - - -@xilinx_test() -def test_memory_buffering_library_node(): - return memory_buffering(4, True, False) - - -@xilinx_test() -def test_memory_buffering_library_node_scalar(): - return memory_buffering(1, True, False) - - -@fpga_test() -def test_memory_buffering_library_node_elementwise(): - return memory_buffering(4, True, True) - - -@fpga_test() -def test_memory_buffering_library_node_elementwise_scalar(): - return memory_buffering(1, True, True) - - -if __name__ == "__main__": - test_memory_buffering_manual(None) - test_memory_buffering_library_node(None) - test_memory_buffering_library_node_scalar(None) - test_memory_buffering_library_node_elementwise(None) diff --git a/tests/fpga/multibank_copy_fpga_test.py b/tests/fpga/multibank_copy_fpga_test.py deleted file mode 100644 index 5f48e4373a..0000000000 --- a/tests/fpga/multibank_copy_fpga_test.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from dace import subsets as sbs, dtypes, memlet as mem -import dace -import numpy as np -from dace.codegen.targets.fpga import _FPGA_STORAGE_TYPES -from dace.dtypes import StorageType -from dace.fpga_testing import fpga_test, xilinx_test - -# A test checking copies involving Multibank-arrays using HBM and DDR in some way - - -def mkc(sdfg: dace.SDFG, - state_before, - src_name, - dst_name, - src_storage=None, - dst_storage=None, - src_shape=None, - dst_shape=None, - copy_expr=None, - src_loc=None, - dst_loc=None): - """ - Helper MaKe_Copy that creates and appends states performing exactly one copy. If a provided - arrayname already exists it will use the old array, and ignore all newly passed values - """ - - if copy_expr is None: - copy_expr = src_name - if (state_before == None): - state = sdfg.add_state(is_start_block=True) - else: - state = sdfg.add_state_after(state_before) - - def mkarray(name, shape, storage, loc): - if (name in sdfg.arrays): - return sdfg.arrays[name] - is_transient = False - if (storage in _FPGA_STORAGE_TYPES): - is_transient = True - arr = sdfg.add_array(name, shape, dace.int32, storage, transient=is_transient) - if loc is not None: - arr[1].location["memorytype"] = loc[0] - arr[1].location["bank"] = loc[1] - return arr - - a = mkarray(src_name, src_shape, src_storage, src_loc) - b = mkarray(dst_name, dst_shape, dst_storage, dst_loc) - - aAcc = state.add_access(src_name) - bAcc = state.add_access(dst_name) - - edge = state.add_edge(aAcc, None, bAcc, None, mem.Memlet(copy_expr)) - - a_np_arr, b_np_arr = None, None - if src_shape is not None: - try: - a_np_arr = np.zeros(src_shape, dtype=np.int32) - except: - pass - if dst_shape is not None: - try: - b_np_arr = np.zeros(dst_shape, dtype=np.int32) - except: - pass - return (state, a_np_arr, b_np_arr) - - -# Note, usually there are only 4 ddr banks but much more hmb banks. -# Since the tests run in simulation mode, this should not be an issue. - - -def copy_multibank_1_mem_type(mem_type): - sdfg = dace.SDFG("copy_multibank_1_mem_type_" + mem_type) - s, a, _ = mkc(sdfg, None, "a", "x", StorageType.Default, StorageType.FPGA_Global, [3, 4, 4], [3, 4, 4], "a", None, - (mem_type, "0:3")) - s, _, _ = mkc(sdfg, s, "x", "y", None, StorageType.FPGA_Global, None, [2, 4, 4, 4], - "x[1, 1:4, 1:4]->[1, 1:4, 1:4, 1]", None, (mem_type, "3:5")) - s, _, _ = mkc(sdfg, s, "y", "z", None, StorageType.FPGA_Global, None, [1, 4, 4, 4], - "y[1, 0:4, 0:4, 0:4]->[0, 0:4, 0:4, 0:4]", None, (mem_type, "5:6")) - s, _, _ = mkc(sdfg, s, "z", "w", None, StorageType.FPGA_Global, None, [1, 4, 4, 4], "z", None, (mem_type, "6:7")) - s, _, c = mkc(sdfg, s, "w", "c", None, StorageType.Default, None, [1, 4, 4, 4], "w") - - a.fill(1) - a[1, 0:4, 1] += 2 - a[1, 1, 0:4] += 2 - expect = np.copy(c) - expect.fill(1) - expect[0, 1:5, 1, 1] += 2 - expect[0, 1, 1:5, 1] += 2 - sdfg(a=a, c=c) - assert np.allclose(c[0, 1:4, 1:4, 1], expect[0, 1:4, 1:4, 1]) - return sdfg - - -def copy_multibank_2_mem_type(mem_type_1, mem_type_2): - sdfg = dace.SDFG("copy_multibank_2_mem_type_" + mem_type_1 + "_" + mem_type_2) - s, a, _ = mkc(sdfg, None, "a", "x", StorageType.Default, StorageType.FPGA_Global, [3, 5, 5], [3, 5, 5], "a", None, - (mem_type_1, "0:3")) - s, _, _ = mkc(sdfg, s, "x", "d1", None, StorageType.FPGA_Global, None, [3, 5, 5], "x[2, 0:5, 0:5]->[1, 0:5, 0:5]", - None, (mem_type_2, "1:4")) - s, _, _ = mkc(sdfg, s, "d1", "y", None, StorageType.FPGA_Global, None, [1, 7, 7], "d1[1, 0:5,0:5]->[0, 2:7, 2:7]", - None, (mem_type_1, "3:4")) - s, _, c = mkc(sdfg, s, "y", "c", None, StorageType.Default, None, [1, 7, 7], "y") - - a.fill(1) - a[2, 2:4, 2:4] += 3 - expect = np.copy(c) - expect.fill(1) - expect[0, 4:6, 4:6] += 3 - sdfg(a=a, c=c) - assert np.allclose(c[2:7], expect[2:7]) - return sdfg - - -@xilinx_test() -def test_copy_hbm2hbm(): - return copy_multibank_1_mem_type(mem_type="hbm") - - -@xilinx_test() -def test_copy_ddr2ddr(): - return copy_multibank_1_mem_type(mem_type="ddr") - - -@xilinx_test() -def test_copy_hbm2ddr(): - return copy_multibank_2_mem_type(mem_type_1="hbm", mem_type_2="ddr") - - -@xilinx_test() -def test_copy_ddr2hbm(): - return copy_multibank_2_mem_type(mem_type_1="ddr", mem_type_2="hbm") - - -if __name__ == "__main__": - test_copy_hbm2hbm(None) # HBM to HBM to HBM - test_copy_ddr2ddr(None) # DDR to DDR to DDR - test_copy_hbm2ddr(None) # HBM to DDR to HBM - test_copy_ddr2hbm(None) # DDR to HBM to DDR diff --git a/tests/fpga/multibank_deeply_nested_fpga_test.py b/tests/fpga/multibank_deeply_nested_fpga_test.py deleted file mode 100644 index e584cf161f..0000000000 --- a/tests/fpga/multibank_deeply_nested_fpga_test.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import dace -from dace.fpga_testing import xilinx_test -from dace import subsets as sbs, dtypes, memlet as mem -from dace import subsets -import numpy as np - -# A test checking Multibank HBM/DDR in the context of nested maps and nested sdfgs -# Note, usually there are only 4 ddr banks but much more hmb banks. -# Since the tests run in simulation mode, this should not be an issue. - - -def create_deeply_nested_sdfg(mem_type): - sdfg = dace.SDFG("deepnest_test_" + mem_type) - state: dace.SDFGState = sdfg.add_state("init") - xarr = state.add_array("x", [4, 10], dace.float32) - sdfg.arrays["x"].location["memorytype"] = mem_type - sdfg.arrays["x"].location["bank"] = "0:4" - yarr = state.add_array("y", [4, 10], dace.float32) - sdfg.arrays["y"].location["memorytype"] = mem_type - sdfg.arrays["y"].location["bank"] = "4:8" - - top_map_entry, top_map_exit = state.add_map("topmap", dict(k="0:2")) - top_map_entry.schedule = dtypes.ScheduleType.Unrolled - - nsdfg = dace.SDFG("nest") - nstate = nsdfg.add_state("nested_state") - x_read = nstate.add_array("xin", [4, 10], dace.float32, dtypes.StorageType.FPGA_Global) - x_write = nstate.add_array("xout", [4, 10], dace.float32, dtypes.StorageType.FPGA_Global) - nsdfg.arrays["xin"].location["memorytype"] = mem_type - nsdfg.arrays["xin"].location["bank"] = "0:4" - nsdfg.arrays["xout"].location["memorytype"] = mem_type - nsdfg.arrays["xout"].location["bank"] = "4:8" - map_entry, map_exit = nstate.add_map("map1", dict(w="0:2")) - map_entry.schedule = dtypes.ScheduleType.Unrolled - imap_entry, imap_exit = nstate.add_map("map2", dict(i="0:10")) - nope = nstate.add_tasklet("nop", dict(_in=None), dict(_out=None), "_out = _in") - input_mem = mem.Memlet("xin[2*k+w, i]") - output_mem = mem.Memlet("xout[2*k+w, i]") - nstate.add_memlet_path(x_read, map_entry, imap_entry, nope, memlet=input_mem, dst_conn="_in") - nstate.add_memlet_path(nope, imap_exit, map_exit, x_write, memlet=output_mem, src_conn="_out") - nsdfg_node = state.add_nested_sdfg(nsdfg, set(["xin"]), set(['xout'])) - - state.add_memlet_path(xarr, - top_map_entry, - nsdfg_node, - memlet=mem.Memlet.from_array("x", sdfg.arrays["x"]), - dst_conn="xin") - state.add_memlet_path(nsdfg_node, - top_map_exit, - yarr, - memlet=mem.Memlet.from_array("y", sdfg.arrays["y"]), - src_conn="xout") - sdfg.apply_fpga_transformations() - - return sdfg - - -def deeply_nested_sdfg(mem_type): - sdfg = create_deeply_nested_sdfg(mem_type) - a = np.zeros((4, 10), np.float32) - a[2, 4:9] += 1 - a[3, 3:8] += 2 - a[0, 7] += 3 - c = np.ones((4, 10), np.float32) - sdfg(x=a, y=c) - assert np.allclose(a, c, 10e-6) - return sdfg - - -@xilinx_test() -def test_hbm_deeply_nested_sdfg(): - return deeply_nested_sdfg(mem_type="hbm") - - -@xilinx_test() -def test_ddr_deeply_nested_sdfg(): - return deeply_nested_sdfg(mem_type="ddr") - - -if __name__ == "__main__": - test_hbm_deeply_nested_sdfg(None) - test_ddr_deeply_nested_sdfg(None) diff --git a/tests/fpga/multibank_dynamic_memlets_test.py b/tests/fpga/multibank_dynamic_memlets_test.py deleted file mode 100644 index 96edee0c20..0000000000 --- a/tests/fpga/multibank_dynamic_memlets_test.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -import dace -from dace import subsets as sbs, dtypes, memlet as mem -from dace.fpga_testing import xilinx_test -import numpy as np - -# Checks dynamic access and dynamic loop bounds from multibank HBM and DDR - - -def create_dynamic_memlet_sdfg(mem_type): - sdfg = dace.SDFG("dyn_memlet_" + mem_type) - state: dace.SDFGState = sdfg.add_state("dyn_memlet") - xarr = state.add_array("x", [4, 10], dace.int32) - sdfg.arrays["x"].location["memorytype"] = mem_type - sdfg.arrays["x"].location["bank"] = "0:4" - yarr = state.add_array("y", [4, 10], dace.int32) - sdfg.arrays["y"].location["memorytype"] = mem_type - sdfg.arrays["y"].location["bank"] = "4:8" - - map_enter, map_exit = state.add_map(mem_type + "map", dict(k="0:4"), dtypes.ScheduleType.Unrolled) - arr_map_enter, arr_map_exit = state.add_map("map", dict(i="0:_dynbound")) - tasklet = state.add_tasklet("dyn", set(["_in"]), set(["_out"]), ("if(i == 2):\n" - " _out = 2\n" - "elif (_in != 2):\n" - " _out = _in\n")) - - state.add_memlet_path(xarr, - map_enter, - arr_map_enter, - tasklet, - memlet=mem.Memlet("x[k, i]", dynamic=True), - dst_conn="_in") - state.add_memlet_path(tasklet, - arr_map_exit, - map_exit, - yarr, - memlet=mem.Memlet("y[k, i]", dynamic=True), - src_conn="_out") - state.add_memlet_path(xarr, map_enter, arr_map_enter, memlet=mem.Memlet("x[1, 0]"), dst_conn="_dynbound") - sdfg.apply_fpga_transformations() - return sdfg - - -def dynamic_memlet(mem_type): - sdfg = create_dynamic_memlet_sdfg(mem_type) - x = np.zeros((4, 10), dtype=np.int32) - y = np.ones((4, 10), dtype=np.int32) # has to be copied to sdfg - x[0:4, 8] = 2 - x[1, 0] = 10 - expected = np.copy(x) - expected[0:4, 2] = 2 - expected[0:4, 8] = 1 - sdfg(x=x, y=y) - assert np.allclose(y, expected) - return sdfg - - -@xilinx_test() -def test_hbm_dynamic_memlet(): - return dynamic_memlet("hbm") - - -@xilinx_test() -def test_ddr_dynamic_memlet(): - return dynamic_memlet("ddr") - - -if __name__ == "__main__": - test_hbm_dynamic_memlet(None) - test_ddr_dynamic_memlet(None) diff --git a/tests/fpga/multibank_multiple_interface_test.py b/tests/fpga/multibank_multiple_interface_test.py deleted file mode 100644 index b0d72a7c5c..0000000000 --- a/tests/fpga/multibank_multiple_interface_test.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -from dace.transformation.dataflow.map_unroll import MapUnroll -from dace import dtypes, subsets -import dace -from dace import memlet -from dace.fpga_testing import xilinx_test -import numpy as np -from dace.sdfg import SDFG -from dace.transformation.interstate import InlineSDFG -from dace.config import set_temporary -# Checks multiple interfaces attached to the same HBM/DDR-bank. - - -def four_interface_to_2_banks(mem_type, decouple_interfaces): - sdfg = SDFG("test_4_interface_to_2_banks_" + mem_type) - state = sdfg.add_state() - - _, desc_a = sdfg.add_array("a", [2, 2], dace.int32) - desc_a.location["memorytype"] = mem_type - desc_a.location["bank"] = "0:2" - acc_read1 = state.add_read("a") - acc_write1 = state.add_write("a") - - t1 = state.add_tasklet("r1", set(["_x1", "_x2"]), set(["_y1"]), "_y1 = _x1 + _x2") - - m1_in, m1_out = state.add_map("m", {"k": "0:2"}, dtypes.ScheduleType.Unrolled) - - state.add_memlet_path(acc_read1, m1_in, t1, memlet=memlet.Memlet("a[0, 0]"), dst_conn="_x1") - state.add_memlet_path(acc_read1, m1_in, t1, memlet=memlet.Memlet("a[1, 0]"), dst_conn="_x2") - state.add_memlet_path(t1, m1_out, acc_write1, memlet=memlet.Memlet("a[0, 1]"), src_conn="_y1") - - sdfg.apply_fpga_transformations() - assert sdfg.apply_transformations(InlineSDFG) == 1 - assert sdfg.apply_transformations(MapUnroll) == 1 - for node in sdfg.states()[0].nodes(): - if isinstance(node, dace.sdfg.nodes.Tasklet): - sdfg.states()[0].out_edges(node)[0].data.subset = subsets.Range.from_string("1, 1") - break - - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=decouple_interfaces): - bank_assignment = sdfg.generate_code()[3].clean_code - # if we are not decoupling array interfaces we will use less mem interfaces - assert bank_assignment.count("sp") == 6 if decouple_interfaces else 4 - assert bank_assignment.count(mem_type + "[0]") == 3 if decouple_interfaces else 2 - assert bank_assignment.count(mem_type + "[1]") == 3 if decouple_interfaces else 2 - - a = np.zeros([2, 2], np.int32) - a[0, 0] = 2 - a[1, 0] = 3 - sdfg(a=a) - assert a[0, 1] == 5 - - return sdfg - - -@xilinx_test(assert_ii_1=False) -def test_4_interface_to_2_banks_ddr_non_decoupled_interfaces(): - return four_interface_to_2_banks(mem_type="DDR", decouple_interfaces=False) - - -@xilinx_test(assert_ii_1=False) -def test_4_interface_to_2_banks_ddr_decoupled_interfaces(): - return four_interface_to_2_banks(mem_type="DDR", decouple_interfaces=True) - - -@xilinx_test(assert_ii_1=False) -def test_4_interface_to_2_banks_hbm_non_decoupled_interface(): - return four_interface_to_2_banks(mem_type="HBM", decouple_interfaces=False) - - -@xilinx_test(assert_ii_1=False) -def test_4_interface_to_2_banks_hbm_decoupled_interface(): - return four_interface_to_2_banks(mem_type="HBM", decouple_interfaces=True) - - -if __name__ == "__main__": - test_4_interface_to_2_banks_hbm_decoupled_interface(None) - test_4_interface_to_2_banks_hbm_non_decoupled_interface(None) - test_4_interface_to_2_banks_ddr_decoupled_interfaces(None) - test_4_interface_to_2_banks_ddr_non_decoupled_interfaces(None) diff --git a/tests/fpga/multibank_reduce_fpga_test.py b/tests/fpga/multibank_reduce_fpga_test.py deleted file mode 100644 index 8f202fab91..0000000000 --- a/tests/fpga/multibank_reduce_fpga_test.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -import dace -from dace import subsets -from dace.fpga_testing import xilinx_test -import numpy as np -import pytest -from dace.config import set_temporary - -# A test checking wcr-reduction with HBM/DDR arrays as inputs and output - - -def create_multibank_reduce_sdfg( - name, - mem_type, - banks=2, -): - N = dace.symbol("N") - M = dace.symbol("M") - - sdfg = dace.SDFG(name + "_" + mem_type) - state = sdfg.add_state('red_' + mem_type, True) - - in1 = sdfg.add_array("in1", [banks, N, M], dace.float32) - in2 = sdfg.add_array("in2", [banks, N, M], dace.float32) - out = sdfg.add_array("out", [banks, N], dace.float32) - in1[1].location["memorytype"] = mem_type - in2[1].location["memorytype"] = mem_type - out[1].location["memorytype"] = mem_type - in1[1].location["bank"] = f"0:{banks}" - in2[1].location["bank"] = f"{banks}:{2*banks}" - out[1].location["bank"] = f"{2*banks}:{3*banks}" - - read_in1 = state.add_read("in1") - read_in2 = state.add_read("in2") - out_write = state.add_write("out") - tmp_in1_memlet = dace.Memlet(f"in1[k, i, j]") - tmp_in2_memlet = dace.Memlet(f"in2[k, i, j]") - tmp_out_memlet = dace.Memlet(f"out[k, i]", wcr="lambda x,y: x+y") - - outer_entry, outer_exit = state.add_map("vadd_outer_map", dict(k=f'0:{banks}')) - map_entry, map_exit = state.add_map("vadd_inner_map", dict(i="0:N", j="0:M")) - tasklet = state.add_tasklet("mul", dict(__in1=None, __in2=None), dict(__out=None), '__out = __in1 * __in2') - outer_entry.map.schedule = dace.ScheduleType.Unrolled - - state.add_memlet_path(read_in1, outer_entry, map_entry, tasklet, memlet=tmp_in1_memlet, dst_conn="__in1") - state.add_memlet_path(read_in2, outer_entry, map_entry, tasklet, memlet=tmp_in2_memlet, dst_conn="__in2") - state.add_memlet_path(tasklet, map_exit, outer_exit, out_write, memlet=tmp_out_memlet, src_conn="__out") - - sdfg.apply_fpga_transformations() - return sdfg - - -def create_test_set(N, M, banks): - in1 = np.random.rand(*[banks, N, M]).astype('f') - in2 = np.random.rand(*[banks, N, M]).astype('f') - expected = np.sum(in1 * in2, axis=2, dtype=np.float32) - out = np.zeros((banks, N), dtype=np.float32) - return (in1, in2, expected, out) - - -def exec_test(N, M, banks, mem_type, name): - in1, in2, expected, target = create_test_set(N, M, banks) - sdfg = create_multibank_reduce_sdfg(name, mem_type, banks) - sdfg(in1=in1, in2=in2, out=target, N=N, M=M) - assert np.allclose(expected, target, rtol=1e-6) - return sdfg - - -@xilinx_test() -def test_hbm_reduce_2x3_2b(): - return exec_test(2, 3, 2, "hbm", "red_2x3_2b") - - -@xilinx_test() -def test_hbm_reduce_2x3_2b_decouple_array_interfaces(): - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - return exec_test(2, 3, 2, "hbm", "red_2x3_2b_decoupled") - - -@xilinx_test() -def test_hbm_reduce_10x50_4b(): - return exec_test(10, 50, 4, "hbm", "red_10x50_4b") - - -@xilinx_test() -def test_hbm_reduce_10x50_4b_decouple_array_interfaces(): - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - return exec_test(10, 50, 4, "hbm", "red_10x50_4b_decoupled") - - -@xilinx_test() -def test_hbm_reduce_red_1x50_1b(): - return exec_test(1, 50, 1, "hbm", "red_1x50_1b") - - -@xilinx_test() -def test_hbm_reduce_red_1x50_1b_decouple_array_interfaces(): - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - return exec_test(1, 50, 1, "hbm", "red_1x50_1b_decoupled") - - -@xilinx_test() -def test_hbm_reduce_red_1x40_8b(): - return exec_test(1, 40, 8, "hbm", "red_1x40_8b") - - -@xilinx_test() -def test_hbm_reduce_red_1x40_8b_decouple_array_interfaces(): - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - return exec_test(1, 40, 8, "hbm", "red_1x40_8b_decoupled") - - -@xilinx_test() -def test_hbm_reduce_red_2x40_6b(): - return exec_test(2, 40, 6, "hbm", "red_2x40_6b") - - -@xilinx_test() -def test_hbm_reduce_red_2x40_6b_decouple_array_interfaces(): - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - return exec_test(2, 40, 6, "hbm", "red_2x40_6b_decoupled") - - -@xilinx_test() -def test_ddr_reduce_2x3_2b(): - return exec_test(2, 3, 2, "ddr", "red_2x3_2b") - - -@xilinx_test() -def test_ddr_reduce_2x3_2b_decouple_array_interfaces(): - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - return exec_test(2, 3, 2, "ddr", "red_2x3_2b_decoupled") - - -@xilinx_test() -def test_ddr_reduce_10x50_4b(): - return exec_test(10, 50, 4, "ddr", "red_10x50_4b") - - -@xilinx_test() -def test_ddr_reduce_10x50_4b_decouple_array_interfaces(): - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - return exec_test(10, 50, 4, "ddr", "red_10x50_4b_decoupled") - - -@xilinx_test() -def test_ddr_reduce_red_1x50_1b(): - return exec_test(1, 50, 1, "ddr", "red_1x50_1b") - - -@xilinx_test() -def test_ddr_reduce_red_1x50_1b_decouple_array_interfaces(): - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - return exec_test(1, 50, 1, "ddr", "red_1x50_1b_decoupled") - - -@xilinx_test() -def test_ddr_reduce_red_1x40_8b(): - return exec_test(1, 40, 8, "ddr", "red_1x40_8b") - - -@xilinx_test() -def test_ddr_reduce_red_1x40_8b_decouple_array_interfaces(): - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - return exec_test(1, 40, 8, "ddr", "red_1x40_8b_decoupled") - - -@xilinx_test() -def test_ddr_reduce_red_2x40_6b(): - return exec_test(2, 40, 6, "ddr", "red_2x40_6b") - - -@xilinx_test() -def test_ddr_reduce_red_2x40_6b_decouple_array_interfaces(): - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - return exec_test(2, 40, 6, "ddr", "red_2x40_6b_decoupled") - - -if __name__ == "__main__": - test_hbm_reduce_2x3_2b(None) - test_hbm_reduce_10x50_4b(None) - test_hbm_reduce_red_1x50_1b(None) - test_hbm_reduce_red_1x40_8b(None) - test_hbm_reduce_red_2x40_6b(None) - test_ddr_reduce_2x3_2b(None) - test_ddr_reduce_10x50_4b(None) - test_ddr_reduce_red_1x50_1b(None) - test_ddr_reduce_red_1x40_8b(None) - test_ddr_reduce_red_2x40_6b(None) diff --git a/tests/fpga/multibank_vadd_fpga_test.py b/tests/fpga/multibank_vadd_fpga_test.py deleted file mode 100644 index e7d3b4d718..0000000000 --- a/tests/fpga/multibank_vadd_fpga_test.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -from dace import subsets -from dace.fpga_testing import xilinx_test -import dace -import numpy as np -from dace.transformation.interstate import InlineSDFG - -# A test executing vector addition with multidimensional arrays using HBM/DDR. - - -def create_vadd_multibank_sdfg(bank_count_per_array=2, - ndim=1, - unroll_map_inside=False, - mem_type="hbm", - sdfg_name="vadd_hbm"): - N = dace.symbol("N") - M = dace.symbol("M") - S = dace.symbol("S") - - sdfg = dace.SDFG(sdfg_name + "_" + mem_type) - state = sdfg.add_state('vadd_' + mem_type, True) - shape = [bank_count_per_array, N] - access_str = "i" - inner_map_range = dict() - inner_map_range["i"] = "0:N" - if (ndim >= 2): - shape = [bank_count_per_array, N, M] - access_str = "i, j" - inner_map_range["j"] = "0:M" - if (ndim >= 3): - shape = [bank_count_per_array, N, M, S] - access_str = "i, j, t" - inner_map_range["t"] = "0:S" - - in1 = sdfg.add_array("in1", shape, dace.float32) - in2 = sdfg.add_array("in2", shape, dace.float32) - out = sdfg.add_array("out", shape, dace.float32) - - in1[1].location["memorytype"] = mem_type - in2[1].location["memorytype"] = mem_type - out[1].location["memorytype"] = mem_type - in1[1].location["bank"] = f"0:{bank_count_per_array}" - in2[1].location["bank"] = f"{bank_count_per_array}:{2*bank_count_per_array}" - out[1].location["bank"] = f"{2*bank_count_per_array}:{3*bank_count_per_array}" - - read_in1 = state.add_read("in1") - read_in2 = state.add_read("in2") - out_write = state.add_write("out") - - tmp_in1_memlet = dace.Memlet(f"in1[k, {access_str}]") - tmp_in2_memlet = dace.Memlet(f"in2[k, {access_str}]") - tmp_out_memlet = dace.Memlet(f"out[k, {access_str}]") - - outer_entry, outer_exit = state.add_map("vadd_outer_map", dict(k=f'0:{bank_count_per_array}')) - map_entry, map_exit = state.add_map("vadd_inner_map", inner_map_range) - tasklet = state.add_tasklet("addandwrite", dict(__in1=None, __in2=None), dict(__out=None), '__out = __in1 + __in2') - outer_entry.map.schedule = dace.ScheduleType.Unrolled - - if (unroll_map_inside): - state.add_memlet_path(read_in1, map_entry, outer_entry, tasklet, memlet=tmp_in1_memlet, dst_conn="__in1") - state.add_memlet_path(read_in2, map_entry, outer_entry, tasklet, memlet=tmp_in2_memlet, dst_conn="__in2") - state.add_memlet_path(tasklet, outer_exit, map_exit, out_write, memlet=tmp_out_memlet, src_conn="__out") - else: - state.add_memlet_path(read_in1, outer_entry, map_entry, tasklet, memlet=tmp_in1_memlet, dst_conn="__in1") - state.add_memlet_path(read_in2, outer_entry, map_entry, tasklet, memlet=tmp_in2_memlet, dst_conn="__in2") - state.add_memlet_path(tasklet, map_exit, outer_exit, out_write, memlet=tmp_out_memlet, src_conn="__out") - - sdfg.apply_fpga_transformations() - sdfg.apply_transformations(InlineSDFG) - return sdfg - - -def create_test_set(dim, size1D, banks): - shape = [banks] - for i in range(dim): - shape.append(size1D) - in1 = np.random.rand(*shape) - in2 = np.random.rand(*shape) - in1 = in1.astype(np.float32) - in2 = in2.astype(np.float32) - expected = in1 + in2 - out = np.empty(shape, dtype=np.float32) - return (in1, in2, expected, out) - - -def exec_test( - dim, - size1D, - banks, - mem_type, - test_name, - unroll_map_inside=False, -): - in1, in2, expected, target = create_test_set(dim, size1D, banks) - sdfg = create_vadd_multibank_sdfg(banks, dim, unroll_map_inside, mem_type, test_name) - if (dim == 1): - sdfg(in1=in1, in2=in2, out=target, N=size1D) - elif (dim == 2): - sdfg(in1=in1, in2=in2, out=target, N=size1D, M=size1D) - else: - sdfg(in1=in1, in2=in2, out=target, N=size1D, M=size1D, S=size1D) - assert np.allclose(expected, target, rtol=1e-6) - return sdfg - - -@xilinx_test() -def test_vadd_hbm_1b1d(): - return exec_test(1, 50, 1, "hbm", "vadd_1b1d") - - -@xilinx_test() -def test_vadd_hbm_2b1d(): - return exec_test(1, 50, 2, "hbm", "vadd_2b1d") - - -@xilinx_test() -def test_vadd_hbm_2b2d(): - return exec_test(2, 50, 2, "hbm", "vadd_2b2d") - - -@xilinx_test() -def test_vadd_hbm_2b3d(): - return exec_test(3, 10, 2, "hbm", "vadd_2b3d") - - -@xilinx_test() -def test_vadd_hbm_8b1d(): - return exec_test(1, 50, 8, "hbm", "vadd_8b1d", True) - - -@xilinx_test() -def test_vadd_ddr_1b1d(): - return exec_test(1, 50, 1, "ddr", "vadd_1b1d") - - -@xilinx_test() -def test_vadd_ddr_2b1d(): - return exec_test(1, 50, 2, "ddr", "vadd_2b1d") - - -@xilinx_test() -def test_vadd_ddr_2b2d(): - return exec_test(2, 50, 2, "ddr", "vadd_2b2d") - - -@xilinx_test() -def test_vadd_ddr_2b3d(): - return exec_test(3, 10, 2, "ddr", "vadd_2b3d") - - -@xilinx_test() -def test_vadd_ddr_8b1d(): - return exec_test(1, 50, 8, "ddr", "vadd_8b1d", True) - - -if __name__ == '__main__': - test_vadd_hbm_1b1d(None) - test_vadd_hbm_2b1d(None) - test_vadd_hbm_2b2d(None) - test_vadd_hbm_2b3d(None) - test_vadd_hbm_8b1d(None) - - test_vadd_ddr_1b1d(None) - test_vadd_ddr_2b1d(None) - test_vadd_ddr_2b2d(None) - test_vadd_ddr_2b3d(None) - test_vadd_ddr_8b1d(None) diff --git a/tests/fpga/multibank_validation_test.py b/tests/fpga/multibank_validation_test.py deleted file mode 100644 index 699f1a91dd..0000000000 --- a/tests/fpga/multibank_validation_test.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -from dace.sdfg.validation import InvalidSDFGEdgeError, InvalidSDFGError, InvalidSDFGNodeError, validate -from dace import subsets as sbs, dtypes, memlet as mem -import dace -import numpy as np -from dace import subsets -from dace.sdfg import nodes as nd - -# A test to check the changes to the validation required for the support for HBM and DDR - - -def assert_validation_failure(sdfg, exceptiontype): - ok = False - try: - sdfg.validate() - except exceptiontype as msg: - ok = True - assert ok - - -def multibank_deep_scope(mem_type): - - @dace.program - def deep_scope(input: dace.int32[12, 10], output: dace.int32[12, 10]): - for k in dace.map[0:10]: - for j in dace.map[0:2]: - for z in dace.map[0:2]: - with dace.tasklet: - _read << input[k + j + z, 0] - _write >> output[k + j * z, 0] - _write = _read + 1 - - sdfg = deep_scope.to_sdfg() - for state in sdfg.nodes(): - for node in state.nodes(): - if isinstance(node, nd.MapEntry): - node.map.schedule = dtypes.ScheduleType.Unrolled - sdfg.arrays["input"].location["memorytype"] = mem_type - sdfg.arrays["output"].location["memorytype"] = mem_type - sdfg.arrays["input"].location["bank"] = "0:12" - sdfg.arrays["output"].location["bank"] = "12:24" - sdfg.apply_fpga_transformations(validate=False) - sdfg.validate() - - -def multibank_multi_tasklet(mem_type): - - @dace.program - def multi_tasklet(input: dace.int32[12, 10], output: dace.int32[12, 10]): - with dace.tasklet: - m << input[0:2, 4] - n >> output[0:4, 5] - n = m - - sdfg = multi_tasklet.to_sdfg() - sdfg.validate() - sdfg.arrays["input"].location["memorytype"] = mem_type - sdfg.arrays["output"].location["memorytype"] = mem_type - sdfg.arrays["input"].location["bank"] = "0:12" - sdfg.arrays["output"].location["bank"] = "12:24" - sdfg.apply_fpga_transformations(validate=False) - assert_validation_failure(sdfg, InvalidSDFGNodeError) - - @dace.program - def singletasklet(input: dace.int32[2, 10], output: dace.int32[2, 10]): - with dace.tasklet: - m << input[0, 0:10] - n >> output[1, 0:10] - n = m - - sdfg = singletasklet.to_sdfg() - sdfg.arrays["input"].location["memorytype"] = mem_type - sdfg.arrays["output"].location["memorytype"] = mem_type - sdfg.arrays["input"].location["bank"] = "0:2" - sdfg.arrays["output"].location["bank"] = "2:4" - sdfg.apply_fpga_transformations() - sdfg.validate() - - -def multibank_unsound_location(mem_type_1, mem_type_2): - sdfg = dace.SDFG("jdj") - sdfg.add_array("a", [4, 3], dtypes.int32, dtypes.StorageType.FPGA_Global) - sdfg.add_array("b", [4], dtypes.int32, dtypes.StorageType.FPGA_Global) - state = sdfg.add_state("dummy") - sdfg.validate() - sdfg.arrays["a"].location["memorytype"] = ":" - assert_validation_failure(sdfg, InvalidSDFGError) - sdfg.arrays["a"].location["memorytype"] = mem_type_1 - sdfg.arrays["a"].location["bank"] = "2:5" - assert_validation_failure(sdfg, InvalidSDFGError) - sdfg.add_constant("k", 1) - sdfg.arrays["a"].location["memorytype"] = mem_type_1 - sdfg.arrays["a"].location["bank"] = "k:5" - sdfg.validate() - sdfg.constants_prop.clear() - assert_validation_failure(sdfg, InvalidSDFGError) - sdfg.arrays["a"].location["memorytype"] = mem_type_1 - sdfg.arrays["a"].location["bank"] = "2:2" - assert_validation_failure(sdfg, InvalidSDFGError) - sdfg.arrays["a"].location["memorytype"] = mem_type_1 - sdfg.arrays["a"].location["bank"] = "0:4" - sdfg.validate() - sdfg.arrays["b"].location["memorytype"] = mem_type_1 - sdfg.arrays["b"].location["bank"] = "0:4" - assert_validation_failure(sdfg, InvalidSDFGError) - sdfg.arrays["b"].location["memorytype"] = mem_type_2 - sdfg.arrays["b"].location["bank"] = "abc" - assert_validation_failure(sdfg, InvalidSDFGError) - sdfg.arrays["b"].location["memorytype"] = mem_type_2 - sdfg.arrays["b"].location["bank"] = "1" - sdfg.validate() - sdfg.arrays["b"].location["memorytype"] = mem_type_1 - sdfg.arrays["b"].location["bank"] = "4" - sdfg.validate() - - -def test_multibank_deep_scope_hbm(): - multibank_deep_scope("hbm") - - -def test_multibank_deep_scope_ddr(): - multibank_deep_scope("ddr") - - -def test_multibank_multi_tasklet_hbm(): - multibank_multi_tasklet("hbm") - - -def test_multibank_multi_tasklet_ddr(): - multibank_multi_tasklet("ddr") - - -def test_multibank_unsound_location_hmb2ddr(): - multibank_unsound_location("hbm", "ddr") - - -def test_multibank_unsound_location(): - multibank_unsound_location("ddr", "hbm") - - -if __name__ == "__main__": - test_multibank_deep_scope_hbm() - test_multibank_deep_scope_ddr() - test_multibank_multi_tasklet_hbm() - test_multibank_multi_tasklet_ddr() - test_multibank_unsound_location_hmb2ddr() - test_multibank_unsound_location() diff --git a/tests/fpga/multiple_kernels_stream.py b/tests/fpga/multiple_kernels_stream.py deleted file mode 100644 index 0da2dfab36..0000000000 --- a/tests/fpga/multiple_kernels_stream.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -# Two FPGA states that communicate through stream - -import dace -import numpy as np -import argparse - -from dace.memlet import Memlet - - -def make_sdfg(dtype=dace.float32): - sdfg = dace.SDFG("multiple_kernels_streams") - - ########################################################################### - # Copy data to FPGA - - copy_in_state = sdfg.add_state("copy_to_device") - - sdfg.add_array("a", shape=[1], dtype=dtype) - sdfg.add_array("c", shape=[1], dtype=dtype) - - in_host_a = copy_in_state.add_read("a") - in_host_c = copy_in_state.add_read("c") - - sdfg.add_array("device_a", shape=[1], dtype=dtype, storage=dace.dtypes.StorageType.FPGA_Global, transient=True) - sdfg.add_array("device_c", shape=[1], dtype=dtype, storage=dace.dtypes.StorageType.FPGA_Global, transient=True) - - in_device_a = copy_in_state.add_write("device_a") - in_device_c = copy_in_state.add_write("device_c") - - copy_in_state.add_memlet_path(in_host_a, in_device_a, memlet=Memlet(f"{in_host_a}[0]")) - copy_in_state.add_memlet_path(in_host_c, in_device_c, memlet=Memlet(f"{in_host_c}[0]")) - - ########################################################################### - # Copy data from FPGA - copy_out_state = sdfg.add_state("copy_to_host") - - device_c = copy_out_state.add_read("device_c") - host_c = copy_out_state.add_write("c") - - copy_out_state.add_memlet_path(device_c, host_c, memlet=Memlet(f"{host_c}[0]")) - - ######################################################################## - # FPGA, First State - - # create the stream for connecting the two states - sdfg.add_stream('device_b_stream', dtype, buffer_size=32, storage=dace.dtypes.StorageType.FPGA_Local) - - fpga_state_0 = sdfg.add_state("fpga_state_0") - - a_in = fpga_state_0.add_read("device_a") - b_stream_out = fpga_state_0.add_write("device_b_stream") - - state_0_tasklet = fpga_state_0.add_tasklet('state_0_tasklet', ['inCon'], ['outCon'], 'outCon = inCon + 1') - - fpga_state_0.add_memlet_path(a_in, state_0_tasklet, dst_conn='inCon', memlet=dace.Memlet(f"{a_in}[0]")) - - fpga_state_0.add_memlet_path(state_0_tasklet, - b_stream_out, - src_conn='outCon', - memlet=dace.Memlet(f"{b_stream_out}[0]", dynamic=True)) - - ######################################################################## - # FPGA, Second State - - fpga_state_1 = sdfg.add_state("fpga_state_1") - - b_stream_in = fpga_state_1.add_read("device_b_stream") - c_out = fpga_state_1.add_write("device_c") - - state_1_tasklet = fpga_state_1.add_tasklet('state_1_tasklet', ['inCon'], ['outCon'], 'outCon = inCon + 1') - - fpga_state_1.add_memlet_path(b_stream_in, - state_1_tasklet, - dst_conn='inCon', - memlet=dace.Memlet(f"{b_stream_in}[0]", dynamic=True)) - - fpga_state_1.add_memlet_path(state_1_tasklet, c_out, src_conn='outCon', memlet=dace.Memlet(f"{c_out.data}[0]")) - - ###################################### - # Interstate edges - sdfg.add_edge(copy_in_state, fpga_state_0, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(fpga_state_0, fpga_state_1, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(fpga_state_1, copy_out_state, dace.sdfg.sdfg.InterstateEdge()) - - ######### - # Validate - sdfg.fill_scope_connectors() - sdfg.validate() - return sdfg - - -if __name__ == "__main__": - - sdfg = make_sdfg() - - comp = sdfg.compile() - - a = np.random.rand(1).astype(np.float32) - c = np.random.rand(1).astype(np.float32) - ref_a = a[0] - ref_c = c[0] - print(a) - print(c) - comp(a=a, c=c) - - diff = ((ref_a + 2) - c) / c - print("Ref_c {}, new c {}".format(ref_c, c)) - print("Difference:", diff) - if diff <= 1e-5: - print("==== Program end ====") - else: - print("==== Program Error! ====") - - exit(0 if diff <= 1e-5 else 1) diff --git a/tests/fpga/multiple_kernels_test.py b/tests/fpga/multiple_kernels_test.py deleted file mode 100644 index bf8dc51965..0000000000 --- a/tests/fpga/multiple_kernels_test.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -# DaCe program with two state that will be generated as two kernels - -import dace -import numpy as np -import argparse - -from dace.memlet import Memlet - - -def make_sdfg(dtype=dace.float32): - sdfg = dace.SDFG("multiple_kernels") - - ########################################################################### - # Copy data to FPGA - - copy_in_state = sdfg.add_state("copy_to_device") - - sdfg.add_array("a", shape=[1], dtype=dtype) - sdfg.add_array("b", shape=[1], dtype=dtype) - sdfg.add_array("c", shape=[1], dtype=dtype) - - in_host_a = copy_in_state.add_read("a") - in_host_b = copy_in_state.add_read("b") - in_host_c = copy_in_state.add_read("c") - - sdfg.add_array("device_a", shape=[1], dtype=dtype, storage=dace.dtypes.StorageType.FPGA_Global, transient=True) - sdfg.add_array("device_b", shape=[1], dtype=dtype, storage=dace.dtypes.StorageType.FPGA_Global, transient=True) - sdfg.add_array("device_c", shape=[1], dtype=dtype, storage=dace.dtypes.StorageType.FPGA_Global, transient=True) - - in_device_a = copy_in_state.add_write("device_a") - in_device_b = copy_in_state.add_write("device_b") - in_device_c = copy_in_state.add_write("device_c") - - copy_in_state.add_memlet_path(in_host_a, in_device_a, memlet=Memlet.simple(in_host_a, "0")) - copy_in_state.add_memlet_path(in_host_b, in_device_b, memlet=Memlet.simple(in_host_b, "0")) - copy_in_state.add_memlet_path(in_host_c, in_device_c, memlet=Memlet.simple(in_host_c, "0")) - - ########################################################################### - # Copy data from FPGA - copy_out_state = sdfg.add_state("copy_to_host") - - device_c = copy_out_state.add_read("device_c") - host_c = copy_out_state.add_write("c") - - device_b = copy_out_state.add_read("device_b") - host_b = copy_out_state.add_write("b") - - copy_out_state.add_memlet_path(device_c, host_c, memlet=Memlet.simple(host_c, "0")) - - copy_out_state.add_memlet_path(device_b, host_b, memlet=Memlet.simple(host_b, "0")) - - ######################################################################## - # FPGA, First State - - fpga_state_0 = sdfg.add_state("fpga_state_0") - - a_in = fpga_state_0.add_read("device_a") - b_out = fpga_state_0.add_write("device_b") - - state_0_tasklet = fpga_state_0.add_tasklet('state_0_tasklet', ['inCon'], ['outCon'], 'outCon = inCon + 1') - - fpga_state_0.add_memlet_path(a_in, state_0_tasklet, dst_conn='inCon', memlet=dace.Memlet.simple(a_in.data, '0')) - - fpga_state_0.add_memlet_path(state_0_tasklet, b_out, src_conn='outCon', memlet=dace.Memlet.simple(b_out.data, '0')) - - ######################################################################## - # FPGA, Second State - - fpga_state_1 = sdfg.add_state("fpga_state_1") - - b_in = fpga_state_1.add_read("device_b") - c_out = fpga_state_1.add_write("device_c") - - state_1_tasklet = fpga_state_1.add_tasklet('state_1_tasklet', ['inCon'], ['outCon'], 'outCon = inCon + 1') - - fpga_state_1.add_memlet_path(b_in, state_1_tasklet, dst_conn='inCon', memlet=dace.Memlet.simple(b_in.data, '0')) - - fpga_state_1.add_memlet_path(state_1_tasklet, c_out, src_conn='outCon', memlet=dace.Memlet.simple(c_out.data, '0')) - - ###################################### - # Interstate edges - sdfg.add_edge(copy_in_state, fpga_state_0, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(fpga_state_0, fpga_state_1, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(fpga_state_1, copy_out_state, dace.sdfg.sdfg.InterstateEdge()) - - ######### - # Validate - sdfg.fill_scope_connectors() - sdfg.validate() - return sdfg - - -if __name__ == "__main__": - - sdfg = make_sdfg() - - comp = sdfg.compile() - - a = np.random.rand(1).astype(np.float32) - b = np.random.rand(1).astype(np.float32) - c = np.random.rand(1).astype(np.float32) - ref_a = a[0] - ref_b = b[0] - ref_c = c[0] - comp(a=a, b=b, c=c) - - diff1 = ((ref_a + 1) - b) / b - diff2 = ((ref_a + 2) - c) / c - if diff1 <= 1e-5 and diff2 <= 1e-5: - print("==== Program end ====") - else: - raise Exception("==== Program Error! ====") diff --git a/tests/fpga/multiple_veclen_conversions_test.py b/tests/fpga/multiple_veclen_conversions_test.py deleted file mode 100644 index 43422a17eb..0000000000 --- a/tests/fpga/multiple_veclen_conversions_test.py +++ /dev/null @@ -1,281 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -# The purpose of this test is to verify that gearboxing functions are created only -# once, even if reused multiple times (so they are not defined multiple times) -# There are three different kernels, each one read/write to intermediate access nodes using gearboxing -# each of them increment by 1 the data - -# NOTE: this is a nice case where we need streams that cross states - -import dace -import numpy as np -import argparse - -from dace.memlet import Memlet -from dace.fpga_testing import fpga_test - -N = dace.symbol("N") - - -def make_sdfg(dtype=dace.float32, vec_width=4): - sdfg = dace.SDFG("multiple_veclen_conversions") - - ########################################################################### - # Copy data to FPGA - - copy_in_state = sdfg.add_state("copy_to_device") - - sdfg.add_array("a", shape=[N], dtype=dtype) - sdfg.add_array("d", shape=[N], dtype=dtype) - - in_host_a = copy_in_state.add_read("a") - in_host_d = copy_in_state.add_read("d") - vec_type = dace.vector(dtype, vec_width) - - sdfg.add_array("device_a", shape=[N], dtype=dtype, storage=dace.dtypes.StorageType.FPGA_Global, transient=True) - sdfg.add_array("device_b", - shape=[N / vec_width], - dtype=vec_type, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - sdfg.add_array("device_c", - shape=[N / vec_width], - dtype=vec_type, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - sdfg.add_array("device_d", shape=[N], dtype=dtype, storage=dace.dtypes.StorageType.FPGA_Global, transient=True) - - in_device_a = copy_in_state.add_write("device_a") - in_device_d = copy_in_state.add_write("device_d") - - copy_in_state.add_memlet_path(in_host_a, in_device_a, memlet=Memlet(f"{in_host_a}[0:N]")) - copy_in_state.add_memlet_path(in_host_d, in_device_d, memlet=Memlet(f"{in_host_d}[0:N]")) - - ########################################################################### - # Copy data from FPGA - copy_out_state = sdfg.add_state("copy_to_host") - - device_d = copy_out_state.add_read("device_d") - host_d = copy_out_state.add_write("d") - - copy_out_state.add_memlet_path(device_d, host_d, memlet=Memlet(f"{host_d}[0:N]")) - - ######################################################################## - # FPGA, First State - - # reads data, increment by 1 pack - - fpga_state_0 = sdfg.add_state("fpga_state_0") - - a_in = fpga_state_0.add_read("device_a") - b_out = fpga_state_0.add_write("device_b") - - # local storage to read and increment data - sdfg.add_array('vec_data', - shape=[vec_width], - dtype=dtype, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Registers) - sdfg.add_array('inc_data', - shape=[vec_width], - dtype=dtype, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Registers) - vect_data = fpga_state_0.add_access("vec_data") - inc_data = fpga_state_0.add_access("inc_data") - # Read the data - - map_entry, map_exit = fpga_state_0.add_map("read_A", { - "n0": "0:{}/{}".format(N, vec_width), - }, - schedule=dace.ScheduleType.FPGA_Device) - read_map_entry, read_map_exit = fpga_state_0.add_map("unrolled_reads", {"n1": "0:{}".format(vec_width)}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - - # In the innermost map we read W=vec_width data elements and we store them into `vec_data` - tasklet = fpga_state_0.add_tasklet("read_data", {"from_memory"}, {"to_kernel"}, "to_kernel = from_memory") - fpga_state_0.add_memlet_path(a_in, - map_entry, - read_map_entry, - tasklet, - dst_conn="from_memory", - memlet=dace.Memlet("device_a[n0*{}+n1]".format(vec_width))) - - fpga_state_0.add_memlet_path(tasklet, - read_map_exit, - vect_data, - src_conn="to_kernel", - memlet=dace.Memlet("vec_data[n1]")) - - # Increment all the elements by one - inc_map_entry, inc_map_exit = fpga_state_0.add_map("unrolled_inc", {"n1": "0:{}".format(vec_width)}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - - inc_tasklet = fpga_state_0.add_tasklet("increment", {"_a"}, {"_b"}, "_b = _a + 1 ") - - fpga_state_0.add_memlet_path(vect_data, - inc_map_entry, - inc_tasklet, - dst_conn="_a", - memlet=dace.Memlet("vec_data[n1]")) - - fpga_state_0.add_memlet_path(inc_tasklet, inc_map_exit, inc_data, src_conn="_b", memlet=dace.Memlet("inc_data[n1]")) - - # write it to memory (it wil pack it) - fpga_state_0.add_memlet_path(inc_data, - map_exit, - b_out, - src_conn="outCon", - memlet=dace.Memlet(f"{b_out}[n0]", other_subset="0:{}".format(vec_width))) - - ######################################################################## - # FPGA, Second State: - # this read, increment, unpack, re-pack and save to memory - - fpga_state_1 = sdfg.add_state("fpga_state_1") - - b_in = fpga_state_1.add_read("device_b") - c_out = fpga_state_1.add_write("device_c") - - # unpack data - sdfg.add_array('vec_data_B', - shape=[vec_width], - dtype=dtype, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Registers) - sdfg.add_array('inc_data_B', - shape=[vec_width], - dtype=dtype, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Registers) - - map_entry, map_exit = fpga_state_1.add_map("read_B", { - "n0": "0:{}/{}".format(N, vec_width), - }, - schedule=dace.ScheduleType.FPGA_Device) - - vect_data = fpga_state_1.add_access("vec_data_B") - inc_data = fpga_state_1.add_access("inc_data_B") - - # unpack data - fpga_state_1.add_memlet_path(b_in, - map_entry, - vect_data, - memlet=dace.Memlet("device_b[n0]", other_subset="0:{}".format(vec_width))) - - # Increment all the elements by one - inc_map_entry, inc_map_exit = fpga_state_1.add_map("unrolled_inc", {"n1": "0:{}".format(vec_width)}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - - inc_tasklet = fpga_state_1.add_tasklet("increment", {"_a"}, {"_b"}, "_b = _a + 1 ") - - fpga_state_1.add_memlet_path(vect_data, - inc_map_entry, - inc_tasklet, - dst_conn="_a", - memlet=dace.Memlet("vec_data_B[n1]")) - - fpga_state_1.add_memlet_path(inc_tasklet, - inc_map_exit, - inc_data, - src_conn="_b", - memlet=dace.Memlet("inc_data_B[n1]")) - - # then we copy that to C, we need other gearboxing - fpga_state_1.add_memlet_path(inc_data, - map_exit, - c_out, - src_conn="to_memory", - memlet=dace.Memlet(f"{c_out.data}[n0]", other_subset="0:{}".format(vec_width))) - - ######################################################################## - # FPGA, third State, read from C, write unpacked to D - - fpga_state_2 = sdfg.add_state("fpga_state_2") - - c_in = fpga_state_2.add_read("device_c") - d_out = fpga_state_2.add_write("device_d") - - # unpack data - sdfg.add_array('vec_data_C', - shape=[vec_width], - dtype=dtype, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Registers) - - map_entry, map_exit = fpga_state_2.add_map("read_C", { - "n0": "0:{}/{}".format(N, vec_width), - }, - schedule=dace.ScheduleType.FPGA_Device) - vect_data = fpga_state_2.add_access("vec_data_C") - write_map_entry, write_map_exit = fpga_state_2.add_map("unrolled_reads", {"n1": "0:{}".format(vec_width)}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - - fpga_state_2.add_memlet_path(c_in, - map_entry, - vect_data, - memlet=dace.Memlet("device_c[n0]", other_subset="0:{}".format(vec_width))) - - # then we copy that to memory - tasklet = fpga_state_2.add_tasklet("write_D", {"from_kernel"}, {"to_memory"}, "to_memory = from_kernel") - fpga_state_2.add_memlet_path(vect_data, - write_map_entry, - tasklet, - dst_conn="from_kernel", - memlet=dace.Memlet("vec_data_C[n1]")) - - fpga_state_2.add_memlet_path(tasklet, - write_map_exit, - map_exit, - d_out, - src_conn="to_memory", - memlet=dace.Memlet(f"{d_out.data}[n0*{vec_width}+n1]")) - - ###################################### - # Interstate edges - sdfg.add_edge(copy_in_state, fpga_state_0, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(fpga_state_0, fpga_state_1, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(fpga_state_1, fpga_state_2, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(fpga_state_2, copy_out_state, dace.sdfg.sdfg.InterstateEdge()) - - ######### - # Validate - sdfg.fill_scope_connectors() - sdfg.validate() - return sdfg - - -@fpga_test(assert_ii_1=False) -def test_multiple_veclen_conversions_test(): - - parser = argparse.ArgumentParser() - parser.add_argument("N", type=int, nargs="?", default=32) - args = vars(parser.parse_args()) - - size_n = args["N"] - - sdfg = make_sdfg() - comp = sdfg.compile() - - a = np.random.rand(size_n).astype(np.float32) - d = np.random.rand(size_n).astype(np.float32) - ref = a + 2 - comp(a=a, d=d, N=size_n) - diff = np.linalg.norm(ref - d) / size_n - print("Difference:", diff) - if diff <= 1e-5: - print("==== Program end ====") - else: - print("==== Program Error! ====") - - assert diff <= 1e-5 - - return sdfg - - -if __name__ == "__main__": - test_multiple_veclen_conversions_test(None) diff --git a/tests/fpga/nested_sdfg_as_kernel_test.py b/tests/fpga/nested_sdfg_as_kernel_test.py deleted file mode 100644 index fe3777c978..0000000000 --- a/tests/fpga/nested_sdfg_as_kernel_test.py +++ /dev/null @@ -1,448 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -# The scope of the test is to verify that two nested SDFGs within the same state are generated -# as two different FPGA kernels. - -# There are two different tests: -# - independent: the two nested SDFGs (vector addition and vector multiplication) do not depend one from the other -# - dependent: the second nested SDFG uses the result produced by the first one. The result is stored on an FPGA array -# The two Nested SDFGs implements vector addition - -import dace -import numpy as np -import argparse -import subprocess - -from dace.fpga_testing import fpga_test -from dace.memlet import Memlet - - -def make_vec_add_sdfg(dtype=dace.float32): - - # Vector addition SDFG - - vecWidth = 4 - n = dace.symbol("size") - vecAdd_sdfg = dace.SDFG("vec_add") - vecType = dace.vector(dtype, vecWidth) - fpga_state = vecAdd_sdfg.add_state("vec_add_state") - - vecAdd_sdfg.add_array('_device_x', shape=[n / vecWidth], dtype=vecType, storage=dace.dtypes.StorageType.FPGA_Global) - vecAdd_sdfg.add_array('_device_y', shape=[n / vecWidth], dtype=vecType, storage=dace.dtypes.StorageType.FPGA_Global) - vecAdd_sdfg.add_array('_device_z', shape=[n / vecWidth], dtype=vecType, storage=dace.dtypes.StorageType.FPGA_Global) - - x = fpga_state.add_read("_device_x") - y = fpga_state.add_read("_device_y") - z = fpga_state.add_write("_device_z") - - # ---------- ---------- - # COMPUTE - # ---------- ---------- - vecMap_entry, vecMap_exit = fpga_state.add_map('vecAdd_map', - dict(i='0:{0}/{1}'.format(n, vecWidth)), - schedule=dace.dtypes.ScheduleType.FPGA_Device) - - vecAdd_tasklet = fpga_state.add_tasklet('vec_add_task', ['x_con', 'y_con'], ['z_con'], 'z_con = x_con + y_con') - - fpga_state.add_memlet_path(x, vecMap_entry, vecAdd_tasklet, dst_conn='x_con', memlet=dace.Memlet(f"{x.data}[i]")) - - fpga_state.add_memlet_path(y, vecMap_entry, vecAdd_tasklet, dst_conn='y_con', memlet=dace.Memlet(f"{y.data}[i]")) - - fpga_state.add_memlet_path(vecAdd_tasklet, vecMap_exit, z, src_conn='z_con', memlet=dace.Memlet(f"{z.data}[i]")) - - ######### - # Validate - vecAdd_sdfg.fill_scope_connectors() - vecAdd_sdfg.validate() - return vecAdd_sdfg - - -def make_vec_mul_sdfg(dtype=dace.float32): - # Vector multiplication SDFG - - vecWidth = 4 - n = dace.symbol("size") - vecMul_sdfg = dace.SDFG("vec_mul") - vecType = dace.vector(dtype, vecWidth) - fpga_state = vecMul_sdfg.add_state("vec_mul_state") - - vecMul_sdfg.add_array('_device_x', shape=[n / vecWidth], dtype=vecType, storage=dace.dtypes.StorageType.FPGA_Global) - vecMul_sdfg.add_array('_device_y', shape=[n / vecWidth], dtype=vecType, storage=dace.dtypes.StorageType.FPGA_Global) - vecMul_sdfg.add_array('_device_z', shape=[n / vecWidth], dtype=vecType, storage=dace.dtypes.StorageType.FPGA_Global) - - x = fpga_state.add_read("_device_x") - y = fpga_state.add_read("_device_y") - z = fpga_state.add_write("_device_z") - - # ---------- ---------- - # COMPUTE - # ---------- ---------- - vecMap_entry, vecMap_exit = fpga_state.add_map('vecMul_map', - dict(i='0:{0}/{1}'.format(n, vecWidth)), - schedule=dace.dtypes.ScheduleType.FPGA_Device) - - vecMul_tasklet = fpga_state.add_tasklet('vecMul_task', ['x_con', 'y_con'], ['z_con'], 'z_con = x_con * y_con') - - fpga_state.add_memlet_path(x, vecMap_entry, vecMul_tasklet, dst_conn='x_con', memlet=dace.Memlet(f"{x.data}[i]")) - - fpga_state.add_memlet_path(y, vecMap_entry, vecMul_tasklet, dst_conn='y_con', memlet=dace.Memlet(f"{y.data}[i]")) - - fpga_state.add_memlet_path(vecMul_tasklet, vecMap_exit, z, src_conn='z_con', memlet=dace.Memlet(f"{z.data}[i]")) - - ######### - # Validate - vecMul_sdfg.fill_scope_connectors() - vecMul_sdfg.validate() - return vecMul_sdfg - - -def make_fpga_sdfg(): - """ - Build an SDFG with two nested SDFGs in a single FPGA state - """ - - n = dace.symbol("n") - vecWidth = 4 - vecType = dace.vector(dace.float32, vecWidth) - sdfg = dace.SDFG("nested_sdfg_kernels") - - ########################################################################### - # Copy data to FPGA - - copy_in_state = sdfg.add_state("copy_to_device") - - sdfg.add_array("x", shape=[n / vecWidth], dtype=vecType) - sdfg.add_array("y", shape=[n / vecWidth], dtype=vecType) - - sdfg.add_array("v", shape=[n / vecWidth], dtype=vecType) - - in_host_x = copy_in_state.add_read("x") - in_host_y = copy_in_state.add_read("y") - - in_host_v = copy_in_state.add_read("v") - - sdfg.add_array("device_x", - shape=[n / vecWidth], - dtype=vecType, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - sdfg.add_array("device_y", - shape=[n / vecWidth], - dtype=vecType, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - sdfg.add_array("device_v", - shape=[n / vecWidth], - dtype=vecType, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - in_device_x = copy_in_state.add_write("device_x") - in_device_y = copy_in_state.add_write("device_y") - - in_device_v = copy_in_state.add_write("device_v") - - copy_in_state.add_memlet_path(in_host_x, in_device_x, memlet=dace.Memlet(f"{in_host_x.data}[0:{n}/{vecWidth}]")) - copy_in_state.add_memlet_path(in_host_y, in_device_y, memlet=dace.Memlet(f"{in_host_y.data}[0:{n}/{vecWidth}]")) - - copy_in_state.add_memlet_path(in_host_v, in_device_v, memlet=dace.Memlet(f"{in_host_v.data}[0:{n}/{vecWidth}]")) - - ########################################################################### - # Copy data from FPGA - sdfg.add_array("z", shape=[n / vecWidth], dtype=vecType) - sdfg.add_array("u", shape=[n / vecWidth], dtype=vecType) - - copy_out_state = sdfg.add_state("copy_to_host") - - sdfg.add_array("device_z", - shape=[n / vecWidth], - dtype=vecType, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - sdfg.add_array("device_u", - shape=[n / vecWidth], - dtype=vecType, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - out_device_z = copy_out_state.add_read("device_z") - out_host_z = copy_out_state.add_write("z") - - out_device_u = copy_out_state.add_read("device_u") - out_host_u = copy_out_state.add_write("u") - - copy_out_state.add_memlet_path(out_device_z, out_host_z, memlet=dace.Memlet(f"{out_host_z.data}[0:{n}/{vecWidth}]")) - copy_out_state.add_memlet_path(out_device_u, out_host_u, memlet=dace.Memlet(f"{out_host_u.data}[0:{n}/{vecWidth}]")) - ########################################################################### - # State that must not become an FPGA kernel - - non_fpga_state = sdfg.add_state("I_do_not_want_to_be_fpga_kernel") - non_fpga_state.location["is_FPGA_kernel"] = False - # Build the vec addition SDFG and nest it - - in_device_x = non_fpga_state.add_read("device_x") - in_device_y = non_fpga_state.add_read("device_y") - in_device_v = non_fpga_state.add_read("device_v") - out_device_z = non_fpga_state.add_write("device_z") - out_device_u = non_fpga_state.add_write("device_u") - - to_nest = make_vec_add_sdfg() - # add nested sdfg with symbol mapping - nested_sdfg = non_fpga_state.add_nested_sdfg(to_nest, {"_device_x", "_device_y"}, {"_device_z"}, {"size": "n"}) - - non_fpga_state.add_memlet_path(in_device_x, - nested_sdfg, - dst_conn="_device_x", - memlet=dace.Memlet(f"{in_device_x.data}[0:{n}/{vecWidth}]")) - non_fpga_state.add_memlet_path(in_device_y, - nested_sdfg, - dst_conn="_device_y", - memlet=dace.Memlet(f"{in_device_y.data}[0:{n}/{vecWidth}]")) - non_fpga_state.add_memlet_path(nested_sdfg, - out_device_z, - src_conn="_device_z", - memlet=dace.Memlet(f"{out_device_z.data}[0:{n}/{vecWidth}]")) - - # Build the second vec addition SDFG and nest it - - to_nest = make_vec_add_sdfg() - # add nested sdfg with symbol mapping - nested_sdfg = non_fpga_state.add_nested_sdfg(to_nest, {"_device_x", "_device_y"}, {"_device_z"}, {"size": "n"}) - - non_fpga_state.add_memlet_path(out_device_z, - nested_sdfg, - dst_conn="_device_x", - memlet=dace.Memlet(f"{out_device_z.data}[0:{n}/{vecWidth}]")) - non_fpga_state.add_memlet_path(in_device_v, - nested_sdfg, - dst_conn="_device_y", - memlet=dace.Memlet(f"{in_device_v.data}[0:{n}/{vecWidth}]")) - non_fpga_state.add_memlet_path(nested_sdfg, - out_device_u, - src_conn="_device_z", - memlet=dace.Memlet(f"{out_device_u.data}[0:{n}/{vecWidth}]")) - - ###################################### - # Interstate edges - sdfg.add_edge(copy_in_state, non_fpga_state, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(non_fpga_state, copy_out_state, dace.sdfg.sdfg.InterstateEdge()) - sdfg.fill_scope_connectors() - sdfg.validate() - - return sdfg - - -def make_fpga_sdfg_independent(): - """ - Build an SDFG with two nested SDFGs in a single FPGA state - """ - - n = dace.symbol("n") - vecWidth = 4 - vecType = dace.vector(dace.float32, vecWidth) - sdfg = dace.SDFG("nested_sdfg_kernels") - - ########################################################################### - # Copy data to FPGA - - copy_in_state = sdfg.add_state("copy_to_device") - - sdfg.add_array("x", shape=[n / vecWidth], dtype=vecType) - sdfg.add_array("y", shape=[n / vecWidth], dtype=vecType) - - sdfg.add_array("v", shape=[n / vecWidth], dtype=vecType) - sdfg.add_array("w", shape=[n / vecWidth], dtype=vecType) - - in_host_x = copy_in_state.add_read("x") - in_host_y = copy_in_state.add_read("y") - - in_host_v = copy_in_state.add_read("v") - in_host_w = copy_in_state.add_read("w") - - sdfg.add_array("device_x", - shape=[n / vecWidth], - dtype=vecType, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - sdfg.add_array("device_y", - shape=[n / vecWidth], - dtype=vecType, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - sdfg.add_array("device_v", - shape=[n / vecWidth], - dtype=vecType, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - sdfg.add_array("device_w", - shape=[n / vecWidth], - dtype=vecType, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - in_device_x = copy_in_state.add_write("device_x") - in_device_y = copy_in_state.add_write("device_y") - - in_device_v = copy_in_state.add_write("device_v") - in_device_w = copy_in_state.add_write("device_w") - - copy_in_state.add_memlet_path(in_host_x, in_device_x, memlet=dace.Memlet(f"{in_host_x.data}[0:{n}/{vecWidth}]")) - copy_in_state.add_memlet_path(in_host_y, in_device_y, memlet=dace.Memlet(f"{in_host_y.data}[0:{n}/{vecWidth}]")) - - copy_in_state.add_memlet_path(in_host_v, in_device_v, memlet=dace.Memlet(f"{in_host_v.data}[0:{n}/{vecWidth}]")) - copy_in_state.add_memlet_path(in_host_w, in_device_w, memlet=dace.Memlet(f"{in_host_w.data}[0:{n}/{vecWidth}]")) - - ########################################################################### - # Copy data from FPGA - sdfg.add_array("z", shape=[n / vecWidth], dtype=vecType) - sdfg.add_array("u", shape=[n / vecWidth], dtype=vecType) - - copy_out_state = sdfg.add_state("copy_to_host") - - sdfg.add_array("device_z", - shape=[n / vecWidth], - dtype=vecType, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - sdfg.add_array("device_u", - shape=[n / vecWidth], - dtype=vecType, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - out_device_z = copy_out_state.add_read("device_z") - out_host_z = copy_out_state.add_write("z") - - out_device_u = copy_out_state.add_read("device_u") - out_host_u = copy_out_state.add_write("u") - - copy_out_state.add_memlet_path(out_device_z, out_host_z, memlet=dace.Memlet(f"{out_host_z.data}[0:{n}/{vecWidth}]")) - copy_out_state.add_memlet_path(out_device_u, out_host_u, memlet=dace.Memlet(f"{out_host_u.data}[0:{n}/{vecWidth}]")) - ########################################################################### - # Non-FPGA state - - non_fpga_state = sdfg.add_state("I_do_not_want_to_be_fpga_kernel") - non_fpga_state.location["is_FPGA_kernel"] = False - - in_device_x = non_fpga_state.add_read("device_x") - in_device_y = non_fpga_state.add_read("device_y") - in_device_v = non_fpga_state.add_read("device_v") - in_device_w = non_fpga_state.add_read("device_w") - out_device_z = non_fpga_state.add_write("device_z") - out_device_u = non_fpga_state.add_write("device_u") - - # Build the vec addition SDFG and nest it - - to_nest = make_vec_add_sdfg() - # add nested sdfg with symbol mapping - nested_sdfg = non_fpga_state.add_nested_sdfg(to_nest, {"_device_x", "_device_y"}, {"_device_z"}, {"size": "n"}) - - non_fpga_state.add_memlet_path(in_device_x, - nested_sdfg, - dst_conn="_device_x", - memlet=dace.Memlet(f"{in_device_x.data}[0:{n}/{vecWidth}]")) - non_fpga_state.add_memlet_path(in_device_y, - nested_sdfg, - dst_conn="_device_y", - memlet=dace.Memlet(f"{in_device_y.data}[0:{n}/{vecWidth}]")) - non_fpga_state.add_memlet_path(nested_sdfg, - out_device_z, - src_conn="_device_z", - memlet=dace.Memlet(f"{out_device_z.data}[0:{n}/{vecWidth}]")) - - # Build the vec multiplication SDFG and nest it - - to_nest = make_vec_mul_sdfg() - # add nested sdfg with symbol mapping - nested_sdfg = non_fpga_state.add_nested_sdfg(to_nest, {"_device_x", "_device_y"}, {"_device_z"}, {"size": "n"}) - - non_fpga_state.add_memlet_path(in_device_v, - nested_sdfg, - dst_conn="_device_x", - memlet=dace.Memlet(f"{in_device_v.data}[0:{n}/{vecWidth}]")) - non_fpga_state.add_memlet_path(in_device_w, - nested_sdfg, - dst_conn="_device_y", - memlet=dace.Memlet(f"{in_device_w.data}[0:{n}/{vecWidth}]")) - non_fpga_state.add_memlet_path(nested_sdfg, - out_device_u, - src_conn="_device_z", - memlet=dace.Memlet(f"{out_device_u.data}[0:{n}/{vecWidth}]")) - - ###################################### - # Interstate edges - sdfg.add_edge(copy_in_state, non_fpga_state, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(non_fpga_state, copy_out_state, dace.sdfg.sdfg.InterstateEdge()) - sdfg.fill_scope_connectors() - sdfg.validate() - - return sdfg - - -@fpga_test() -def test_nested_sdfg_as_kernel(): - - parser = argparse.ArgumentParser() - parser.add_argument("N", type=int, nargs="?", default=32) - args = vars(parser.parse_args()) - - size_n = args["N"] - - ########################################## - # SDFG with two disconnected Nested SDFGs - ######################################### - sdfg = make_fpga_sdfg_independent() - vec_ops = sdfg.compile() - - x = np.random.rand(size_n).astype(np.float32) - y = np.random.rand(size_n).astype(np.float32) - z = np.random.rand(size_n).astype(np.float32) - - v = np.random.rand(size_n).astype(np.float32) - u = np.random.rand(size_n).astype(np.float32) - w = np.random.rand(size_n).astype(np.float32) - - vec_ops(x=x, y=y, z=z, v=v, w=w, u=u, n=size_n) - ref1 = np.add(x, y) - ref2 = np.multiply(v, w) - diff1 = np.linalg.norm(ref1 - z) / size_n - diff2 = np.linalg.norm(ref2 - u) / size_n - - ########################################## - # SDFG with two connected Nested SDFGs - ########################################## - - sdfg = make_fpga_sdfg() - - vec_ops = sdfg.compile() - - x = np.random.rand(size_n).astype(np.float32) - y = np.random.rand(size_n).astype(np.float32) - z = np.random.rand(size_n).astype(np.float32) - - v = np.random.rand(size_n).astype(np.float32) - - vec_ops(x=x, y=y, z=z, v=v, u=u, n=size_n) - ref3 = np.add(x, y) - ref4 = np.add(ref3, v) - - diff3 = np.linalg.norm(ref3 - z) / size_n - diff4 = np.linalg.norm(ref4 - u) / size_n - - if diff1 <= 1e-5 and diff2 <= 1e-5 and diff3 <= 1e-5 and diff4 <= 1e-5: - print("==== Program end ====") - else: - raise Exception("==== Program Error! ====") - - # There is no need to check that the Nested SDFG has been generated only once. If this is not the case - # the test will fail while compiling - - return sdfg - - -if __name__ == "__main__": - test_nested_sdfg_as_kernel(None) diff --git a/tests/fpga/overapprox_transient_shapes_test.py b/tests/fpga/overapprox_transient_shapes_test.py deleted file mode 100644 index adf4f020d3..0000000000 --- a/tests/fpga/overapprox_transient_shapes_test.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" - Tests over-approximation of transient shapes. - In the computation, the result produced by the inner loop - is stored in a transient container, whose shape should be correctly overapproximated. -""" - -import numpy as np -import dace -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG - -M, N = (dace.symbol(s, dtype=dace.int32) for s in ('M', 'N')) - - -@dace.program -def overapprox(alpha: dace.float32, C: dace.float32[N, N], A: dace.float32[N, M]): - - for i in range(N): - tmp = np.zeros((N, ), dtype=np.float32) - for k in range(M): - tmp[:i + 1] += alpha * A[:i + 1, k] - C[i, :i + 1] = tmp[:i + 1] - - -def reference(alpha, A, C, N, M): - - for i in range(N): - tmp = np.zeros((N, ), dtype=np.float32) - for k in range(M): - tmp[:i + 1] += alpha * A[:i + 1, k] - C[i, :i + 1] = tmp[:i + 1] - - -@fpga_test() -def test_overapprox_transient_shapes(): - size_n = 4 - size_m = 8 - alpha = 1.1 - C = np.random.rand(size_n, size_n).astype(np.float32) - A = np.random.rand(size_n, size_m).astype(np.float32) - C_np = np.copy(C) - sdfg = overapprox.to_sdfg() - sdfg.apply_transformations([FPGATransformSDFG]) - sdfg(N=size_n, M=size_m, A=A, C=C, alpha=alpha) - reference(alpha, A, C_np, size_n, size_m) - assert np.allclose(C_np, C) - return sdfg - - -if __name__ == "__main__": - test_overapprox_transient_shapes(None) diff --git a/tests/fpga/pipeline_scope_test.py b/tests/fpga/pipeline_scope_test.py deleted file mode 100644 index d5faab1c22..0000000000 --- a/tests/fpga/pipeline_scope_test.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import copy -import dace -from dace.fpga_testing import fpga_test, xilinx_test - - -def make_sdfg(dtype, - name="pipeline_test", - input_device_memory="ddr", - input_device_bank="0", - output_device_memory="ddr", - output_device_bank="1"): - - n = dace.symbol("N") - k = dace.symbol("K") - m = dace.symbol("M") - - sdfg = dace.SDFG(name) - - pre_state = sdfg.add_state(name + "_pre") - state = sdfg.add_state(name) - post_state = sdfg.add_state(name + "_post") - sdfg.add_edge(pre_state, state, dace.InterstateEdge()) - sdfg.add_edge(state, post_state, dace.InterstateEdge()) - - _, desc_input_host = sdfg.add_array("a", (n, k, m), dtype) - _, desc_output_host = sdfg.add_array("b", (n, k, m), dtype) - desc_input_device = copy.copy(desc_input_host) - desc_input_device.storage = dace.StorageType.FPGA_Global - desc_input_device.location["memorytype"] = input_device_memory - desc_input_device.location["bank"] = input_device_bank - desc_input_device.transient = True - desc_output_device = copy.copy(desc_output_host) - desc_output_device.storage = dace.StorageType.FPGA_Global - desc_output_device.location["memorytype"] = output_device_memory - desc_output_device.location["bank"] = output_device_bank - desc_output_device.transient = True - sdfg.add_datadesc("a_device", desc_input_device) - sdfg.add_datadesc("b_device", desc_output_device) - - # Host to device - pre_read = pre_state.add_read("a") - pre_write = pre_state.add_write("a_device") - pre_state.add_memlet_path(pre_read, pre_write, memlet=dace.Memlet("a_device[0:N, 0:K, 0:M]")) - - # Device to host - post_read = post_state.add_read("b_device") - post_write = post_state.add_write("b") - post_state.add_memlet_path(post_read, post_write, memlet=dace.Memlet("b[0:N, 0:K, 0:M]")) - - # Compute state - read_memory = state.add_read("a_device") - write_memory = state.add_write("b_device") - - # Memory streams - sdfg.add_stream("a_stream", dtype, storage=dace.StorageType.FPGA_Local, transient=True) - sdfg.add_stream("b_stream", dtype, storage=dace.StorageType.FPGA_Local, transient=True) - produce_input_stream = state.add_write("a_stream") - consume_input_stream = state.add_read("a_stream") - produce_output_stream = state.add_write("b_stream") - consume_output_stream = state.add_write("b_stream") - - entry, exit = state.add_pipeline(name, { - "n": "0:N", - "k": "0:K", - "m": "0:M", - }, - schedule=dace.ScheduleType.FPGA_Device, - init_size=k * m, - init_overlap=True, - drain_size=k * m, - drain_overlap=True, - additional_iterators={'user_var': 0}) - # for the sake of testing, use the additional user_var to set to zero the last element of each row - tasklet = state.add_tasklet( - name, {"_in"}, {"_out"}, """\ -_out = _in + (0 if user_var==M-1 else (1 if {} else (3 if {} else 2))) -if user_var == M-1: - user_var = 0 -else: - user_var = user_var + 1 -""".format(entry.pipeline.init_condition(), entry.pipeline.drain_condition())) - - # Container-to-container copies between arrays and streams - state.add_memlet_path(read_memory, - produce_input_stream, - memlet=dace.Memlet("a_device[0:N, 0:K, 0:M]", other_subset="0", volume=n * k * m)) - state.add_memlet_path(consume_output_stream, - write_memory, - memlet=dace.Memlet("b_device[0:N, 0:K, 0:M]", other_subset="0", volume=n * k * m)) - - # Input stream to buffer - state.add_memlet_path(consume_input_stream, - entry, - tasklet, - dst_conn="_in", - memlet=dace.Memlet("a_stream[0]", dynamic=True)) - - # Buffer to output stream - state.add_memlet_path(tasklet, - exit, - produce_output_stream, - src_conn="_out", - memlet=dace.Memlet("b_stream[0]", dynamic=True)) - - return sdfg - - -def exec_jacobi(jacobi, dtype): - import numpy as np - - n = 16 - k = 24 - m = 32 - - jacobi.specialize({"N": n, "K": k, "M": m}) - - a = np.copy(np.arange(n * k * m, dtype=dtype).reshape((n, k, m))) - b = np.empty((n, k, m), dtype=dtype) - - jacobi(a=a, b=b) - - ref = copy.copy(a) - ref[0, :, 0:-1] += 1 - ref[1:-1, :, 0:-1] += 2 - ref[-1, :, 0:-1] += 3 - - if (b != ref).any(): - print(b) - print(ref) - raise ValueError("Unexpected output.") - - return jacobi - - -@fpga_test() -def test_pipeline_scope(): - import numpy as np - - dtype = np.float64 - jacobi = make_sdfg(dtype=dtype) - return exec_jacobi(jacobi, dtype) - - -@xilinx_test() -def test_pipeline_scope_hbm(): - import numpy as np - - dtype = np.float32 - jacobi = make_sdfg(dtype, "pipeline_hbm_test", "hbm", "1", "hbm", "2") - return exec_jacobi(jacobi, dtype) - - -if __name__ == "__main__": - test_pipeline_scope(None) - test_pipeline_scope_hbm(None) diff --git a/tests/fpga/power_opencl_test.py b/tests/fpga/power_opencl_test.py deleted file mode 100644 index 2170f6da0b..0000000000 --- a/tests/fpga/power_opencl_test.py +++ /dev/null @@ -1,73 +0,0 @@ -# Simple test case to check how ** is translated to openCL - -import dace.dtypes -import numpy as np -import dace as dc -import argparse -from dace.fpga_testing import intel_fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG - -N = dc.symbol('N', dtype=dc.int64) - - -@dc.program -def power_test_kernel(A: dc.float64[N], B: dc.float64[N], C: dc.int32[N]): - for i in range(N): - B[i] = B[i]**A[i] - A[i] = A[i]**C[i] - A[i] = A[i]**2.3 - B[i] = B[i]**3 - - -def initialize(N, datatype=np.float64): - - A = np.full((N, ), 1.2, dtype=datatype) - B = np.fromfunction(lambda i: (i + 3.1), (N, ), dtype=datatype) - C = np.full((N, ), 2, dtype=np.int32) - return A, B, C - - -def ground_truth(A, B, C): - for i in range(120): - B[i] = B[i]**A[i] - A[i] = A[i]**C[i] - A[i] = A[i]**2.3 - B[i] = B[i]**3 - - -def run_power_test(device_type: dace.dtypes.DeviceType): - ''' - Runs simple power calculations for Intel FPGAS - :return: the SDFG - ''' - - # Initialize data (polybench small size) - N = 120 - A, B, C = initialize(N) - A_ref = np.copy(A) - B_ref = np.copy(B) - C_ref = np.copy(C) - - # Parse SDFG and apply FPGA friendly optimization - sdfg = power_test_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - sdfg.specialize(dict(N=N)) - sdfg(A=A, B=B, C=C) - - # Compute ground truth and validate - ground_truth(A_ref, B_ref, C_ref) - - assert np.allclose(B, B_ref) - assert np.allclose(A, A_ref) - return sdfg - - -@intel_fpga_test(assert_ii_1=False) -def test_fpga(): - return run_power_test(dace.dtypes.DeviceType.FPGA) - - -if __name__ == "__main__": - - run_power_test(dace.dtypes.DeviceType.FPGA) diff --git a/tests/fpga/reduce_fpga_test.py b/tests/fpga/reduce_fpga_test.py deleted file mode 100644 index 007a8c770d..0000000000 --- a/tests/fpga/reduce_fpga_test.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -# Tests reduce expansions for FPGA -import dace -import numpy as np -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG - - -def create_reduce_sdfg(wcr_str, reduction_axis, sdfg_name, input_data, output_data, dtype): - """ - Build an SDFG that perform the given reduction along the given axis - - :param wcr_str: reduction operation to perform - :param reduction_axis: the axis on which operate - :param sdfg_name: - :param input_data: - :param output_data: - """ - sdfg = dace.SDFG(sdfg_name) - - ########################################################################### - # Copy data to FPGA - - copy_in_state = sdfg.add_state("copy_to_device") - input_data_shape = input_data.shape - output_data_shape = output_data.shape - - sdfg.add_array('A', input_data_shape, dtype) - - in_host_A = copy_in_state.add_read('A') - - sdfg.add_array("device_A", - shape=input_data_shape, - dtype=dtype, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - in_device_A = copy_in_state.add_write("device_A") - - copy_in_memlet = dace.Memlet("A[{}]".format(",".join([f"0:{i}" for i in input_data_shape]))) - - copy_in_state.add_memlet_path(in_host_A, in_device_A, memlet=copy_in_memlet) - - ########################################################################### - # Copy data from FPGA - - copy_out_state = sdfg.add_state("copy_from_device") - sdfg.add_array("B", output_data_shape, dtype) - sdfg.add_array("device_B", - shape=output_data_shape, - dtype=dtype, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - out_device = copy_out_state.add_read("device_B") - out_host = copy_out_state.add_write("B") - copy_out_memlet = dace.Memlet("B[{}]".format(",".join([f"0:{i}" for i in output_data_shape]))) - copy_out_state.add_memlet_path(out_device, out_host, memlet=copy_out_memlet) - - ######################################################################## - # FPGA State - - fpga_state = sdfg.add_state("fpga_state") - r = fpga_state.add_read("device_A") - w = fpga_state.add_write("device_B") - red = fpga_state.add_reduce(wcr_str, reduction_axis, 0, schedule=dace.dtypes.ScheduleType.FPGA_Device) - - fpga_state.add_nedge(r, red, dace.Memlet(data="device_A")) - fpga_state.add_nedge(red, w, dace.Memlet(data="device_B")) - - ###################################### - # Interstate edges - sdfg.add_edge(copy_in_state, fpga_state, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(fpga_state, copy_out_state, dace.sdfg.sdfg.InterstateEdge()) - - sdfg.validate() - - return sdfg - - -@fpga_test(assert_ii_1=False) -def test_reduce_sum_one_axis(): - A = np.random.rand(8, 8).astype(np.float32) - B = np.random.rand(8).astype(np.float32) - sdfg = create_reduce_sdfg("lambda a,b: a+b", [0], "reduction_sum_one_axis", A, B, dace.float32) - from dace.libraries.standard import Reduce - Reduce.default_implementation = "FPGAPartialReduction" - sdfg.expand_library_nodes() - sdfg(A=A, B=B) - assert np.allclose(B, np.sum(A, axis=0)) - return sdfg - - -@fpga_test() -def test_reduce_sum_all_axis(): - A = np.random.rand(4, 4).astype(np.float32) - B = np.random.rand(1).astype(np.float32) - sdfg = create_reduce_sdfg("lambda a,b: a+b", (0, 1), "reduction_sum_all_axis", A, B, dace.float32) - from dace.libraries.standard import Reduce - Reduce.default_implementation = "FPGAPartialReduction" - sdfg.expand_library_nodes() - sdfg(A=A, B=B) - assert np.allclose(B, np.sum(A, axis=(0, 1))) - return sdfg - - -@fpga_test(xilinx=False) -def test_reduce_sum_4D(): - A = np.random.rand(4, 4, 4, 12).astype(np.float64) - B = np.random.rand(4, 4).astype(np.float64) - sdfg = create_reduce_sdfg("lambda a,b: a+b", [2, 3], "reduction_sum_4D", A, B, dace.float64) - from dace.libraries.standard import Reduce - Reduce.default_implementation = "FPGAPartialReduction" - sdfg.expand_library_nodes() - sdfg(A=A, B=B) - assert np.allclose(B, np.sum(A, axis=(2, 3))) - return sdfg - - -@fpga_test(assert_ii_1=False) -def test_reduce_max(): - A = np.random.rand(4, 4).astype(np.float32) - B = np.random.rand(4).astype(np.float32) - sdfg = create_reduce_sdfg("lambda a,b: max(a,b)", [1], "reduction_max", A, B, dace.float32) - from dace.libraries.standard import Reduce - Reduce.default_implementation = "FPGAPartialReduction" - sdfg.expand_library_nodes() - sdfg(A=A, B=B) - assert np.allclose(B, np.max(A, axis=1)) - return sdfg - - -@fpga_test(assert_ii_1=False) -def test_reduce_scalar(): - - @dace.program - def reduction_to_scalar(A: dace.float64[64]): - result = dace.reduce(lambda a, b: a + b, A) - return result - - sdfg = reduction_to_scalar.to_sdfg() - sdfg.apply_transformations(FPGATransformSDFG) - - A = np.random.rand(64) - res = sdfg(A) - assert np.allclose(res, np.sum(A)) - return sdfg - - -if __name__ == "__main__": - test_reduce_sum_one_axis(None) - test_reduce_sum_all_axis(None) - test_reduce_sum_4D(None) - test_reduce_max(None) diff --git a/tests/fpga/remove_degenerate_loop_test.py b/tests/fpga/remove_degenerate_loop_test.py deleted file mode 100644 index cc98f66038..0000000000 --- a/tests/fpga/remove_degenerate_loop_test.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import dace -from dace.fpga_testing import fpga_test - -import copy -import numpy as np -import re - - -def make_sdfg(name="transpose"): - - n = dace.symbol("N") - m = dace.symbol("M") - - sdfg = dace.SDFG(name) - - pre_state = sdfg.add_state(name + "_pre") - state = sdfg.add_state(name) - post_state = sdfg.add_state(name + "_post") - sdfg.add_edge(pre_state, state, dace.InterstateEdge()) - sdfg.add_edge(state, post_state, dace.InterstateEdge()) - - _, desc_input_host = sdfg.add_array("a_input", (n, m), dace.float64) - _, desc_output_host = sdfg.add_array("a_output", (m, n), dace.float64) - desc_input_device = copy.copy(desc_input_host) - desc_input_device.storage = dace.StorageType.FPGA_Global - desc_input_device.location["memorytype"] = "ddr" - desc_input_device.location["bank"] = "0" - desc_input_device.transient = True - desc_output_device = copy.copy(desc_output_host) - desc_output_device.storage = dace.StorageType.FPGA_Global - desc_output_device.location["memorytype"] = "ddr" - desc_output_device.location["bank"] = "1" - desc_output_device.transient = True - sdfg.add_datadesc("a_input_device", desc_input_device) - sdfg.add_datadesc("a_output_device", desc_output_device) - - # Host to device - pre_read = pre_state.add_read("a_input") - pre_write = pre_state.add_write("a_input_device") - pre_state.add_memlet_path(pre_read, pre_write, memlet=dace.Memlet.simple(pre_write, "0:N, 0:M")) - - # Device to host - post_read = post_state.add_read("a_output_device") - post_write = post_state.add_write("a_output") - post_state.add_memlet_path(post_read, post_write, memlet=dace.Memlet.simple(post_write, "0:N, 0:M")) - - # Compute state - read = state.add_read("a_input_device") - write = state.add_write("a_output_device") - - # Trivial tasklet - tasklet = state.add_tasklet(name, {"_in"}, {"_out"}, "_out = _in") - - entry, exit = state.add_map(name, { - "i": "0:N", - "j": "0:M", - }, schedule=dace.ScheduleType.FPGA_Device) - - state.add_memlet_path(read, - entry, - tasklet, - dst_conn="_in", - memlet=dace.Memlet.simple("a_input_device", "i, j", num_accesses=1)) - state.add_memlet_path(tasklet, - exit, - write, - src_conn="_out", - memlet=dace.Memlet.simple("a_output_device", "j, i", num_accesses=1)) - - return sdfg - - -@fpga_test() -def test_remove_degenerate_loop(): - - sdfg = make_sdfg("remove_degenerate_loop_test") - - size = 8192 - - sdfg.specialize({"N": size, "M": 1}) # Degenerate dimension - - codes = sdfg.generate_code() - tasklet_name = sdfg.name + "_tasklet" - for code in codes: - if code.target_type == "device": - break # code now points to the appropriate code object - else: # Sanity check - raise ValueError("Didn't find tasklet in degenerate map.") - - if re.search(r"for \(.+\bj\b < \bM\b", code.code) is not None: - raise ValueError("Single iteration loop was not removed.") - - first_assignment = re.search(r"\bj\b\s*=\s*0\s*;", code.code) - if first_assignment is None: - raise ValueError("Assignment to constant variable not found.") - - a_input = np.copy(np.arange(size, dtype=np.float64).reshape((size, 1))) - a_output = np.empty((1, size), dtype=np.float64) - - sdfg(a_input=a_input, a_output=a_output) - - if any(a_input.ravel() != a_output.ravel()): - raise ValueError("Unexpected output.") - - return sdfg - - -if __name__ == "__main__": - test_remove_degenerate_loop(None) diff --git a/tests/fpga/reshape_view_fpga_test.py b/tests/fpga/reshape_view_fpga_test.py deleted file mode 100644 index b1ddeffc3a..0000000000 --- a/tests/fpga/reshape_view_fpga_test.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" FPGA Tests for reshaping and reinterpretation of existing arrays. - Part of the following tests are based on the ones in numpy/reshape_test.py""" -import dace -import numpy as np -import pytest -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG, GPUTransformSDFG, NestSDFG -from dace.fpga_testing import fpga_test - -N = dace.symbol('N') - - -@fpga_test() -def test_view_fpga_sdfg(): - """ - Manually built FPGA-SDFG with a view: Array -> view -> Array - """ - - sdfg = dace.SDFG("view_fpga") - - ########################################################################### - # Copy data to FPGA - - copy_in_state = sdfg.add_state("copy_to_device") - - sdfg.add_array('A', [2, 3, 4], dace.float32) - - in_host_A = copy_in_state.add_read('A') - - sdfg.add_array("device_A", - shape=[2, 3, 4], - dtype=dace.float32, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - in_device_A = copy_in_state.add_write("device_A") - - copy_in_state.add_memlet_path(in_host_A, in_device_A, memlet=dace.Memlet("A[0:2,0:3,0:4]")) - ########################################################################### - # Copy data from FPGA - - copy_out_state = sdfg.add_state("copy_from_device") - sdfg.add_array('B', [8, 3], dace.float32) - sdfg.add_array("device_B", - shape=[8, 3], - dtype=dace.float32, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - out_device = copy_out_state.add_read("device_B") - out_host = copy_out_state.add_write("B") - - copy_out_state.add_memlet_path(out_device, out_host, memlet=dace.Memlet("B[0:8,0:3]")) - - ######################################################################## - # FPGA State - - fpga_state = sdfg.add_state("fpga_state") - - sdfg.add_view('Av', [8, 3], dace.float32, storage=dace.dtypes.StorageType.FPGA_Global) - r = fpga_state.add_read('device_A') - v = fpga_state.add_access('Av') - w = fpga_state.add_write('device_B') - fpga_state.add_edge(r, None, v, 'views', dace.Memlet(data='device_A')) - fpga_state.add_nedge(v, w, dace.Memlet(data='device_B')) - - ###################################### - # Interstate edges - sdfg.add_edge(copy_in_state, fpga_state, dace.sdfg.sdfg.InterstateEdge()) - sdfg.add_edge(fpga_state, copy_out_state, dace.sdfg.sdfg.InterstateEdge()) - - sdfg.validate() - - ########################################################################################### - # Execute - - A = np.random.rand(2, 3, 4).astype(np.float32) - B = np.random.rand(8, 3).astype(np.float32) - sdfg(A=A, B=B) - assert np.allclose(A, np.reshape(B, [2, 3, 4])) - - return sdfg - - -@fpga_test() -def test_reshape_np(): - """ - Dace program with numpy reshape, transformed for FPGA - """ - - @dace.program - def reshp_np(A: dace.float32[3, 4], B: dace.float32[2, 6]): - B[:] = np.reshape(A, [2, 6]) - - A = np.random.rand(3, 4).astype(np.float32) - B = np.random.rand(2, 6).astype(np.float32) - - sdfg = reshp_np.to_sdfg() - sdfg.apply_transformations([FPGATransformSDFG]) - sdfg(A=A, B=B) - assert np.allclose(np.reshape(A, [2, 6]), B) - - return sdfg - - -@fpga_test() -def test_reshape_dst_explicit(): - """ Tasklet->View->Array """ - sdfg = dace.SDFG('reshapedst') - sdfg.add_array('A', [2, 3, 4], dace.float64) - sdfg.add_view('Bv', [2, 3, 4], dace.float64) - sdfg.add_array('B', [8, 3], dace.float64) - state = sdfg.add_state() - - me, mx = state.add_map('compute', dict(i='0:2', j='0:3', k='0:4')) - t = state.add_tasklet('add', {'a'}, {'b'}, 'b = a + 1') - state.add_memlet_path(state.add_read('A'), me, t, dst_conn='a', memlet=dace.Memlet('A[i,j,k]')) - v = state.add_access('Bv') - state.add_memlet_path(t, mx, v, src_conn='b', memlet=dace.Memlet('Bv[i,j,k]')) - state.add_nedge(v, state.add_write('B'), dace.Memlet('B')) - sdfg.validate() - - A = np.random.rand(2, 3, 4) - B = np.random.rand(8, 3) - sdfg.apply_transformations([FPGATransformSDFG]) - sdfg(A=A, B=B) - assert np.allclose(A + 1, np.reshape(B, [2, 3, 4])) - - return sdfg - - -@fpga_test(assert_ii_1=False) -def test_view_slice(): - """ - In this test we use slice. In this case a view is used to access - the desired portion of the original array - (this is part of symm polybench kernel) - """ - N = dace.symbol('N', dace.int32) - M = dace.symbol('M', dace.int32) - - @dace.program - def view_slice(alpha: dace.float32, beta: dace.float32, C: dace.float32[M, N], A: dace.float32[M, M], - B: dace.float32[M, N]): - - C *= beta - for i in range(M): - for j in range(N): - C[:i, j] += alpha * B[i, j] * A[i, :i] - - def kernel_numpy(M, N, alpha, beta, C, A, B): - C *= beta - for i in range(M): - for j in range(N): - C[:i, j] += alpha * B[i, j] * A[i, :i] - - M, N = 16, 32 - alpha = 2 - beta = 3 - A = np.random.rand(M, M).astype(np.float32) - B = np.random.rand(M, N).astype(np.float32) - C = np.random.rand(M, N).astype(np.float32) - np_C = np.copy(C) - kernel_numpy(M, N, alpha, beta, np_C, A, B) - sdfg = view_slice.to_sdfg() - sdfg.apply_transformations([FPGATransformSDFG]) - sdfg(A=A, B=B, C=C, alpha=alpha, beta=beta, M=M, N=N) - assert np.allclose(C, np_C, atol=1e-06) - - return sdfg - - -if __name__ == "__main__": - test_reshape_np(None) - test_view_fpga_sdfg(None) - test_reshape_dst_explicit(None) - test_view_slice(None) diff --git a/tests/fpga/simple_systolic_array_test.py b/tests/fpga/simple_systolic_array_test.py deleted file mode 100644 index 6a77b2797e..0000000000 --- a/tests/fpga/simple_systolic_array_test.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" -Simple systolic array of P processing element, each one increments by 1 the -incoming element. -""" - -import argparse -import dace -import numpy as np -import select -import sys -from dace.fpga_testing import fpga_test - -N = dace.symbol("N") -P = dace.symbol("P") - - -def make_copy_to_fpga_state(sdfg): - - ########################################################################### - # Copy data to FPGA - - state = sdfg.add_state("copy_to_device") - - A_host = state.add_array("A", [N], dtype=dace.int32) - - A_device = state.add_array("A_device", [N], - dtype=dace.int32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - - state.add_edge(A_host, None, A_device, None, dace.memlet.Memlet.simple(A_device, "0:N")) - return state - - -def make_copy_to_host_state(sdfg): - - ########################################################################### - # Copy data to FPGA - - state = sdfg.add_state("copy_to_host") - - A_device = state.add_array("A_device", [N], - dtype=dace.int32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - A_host = state.add_array("A", [N], dtype=dace.int32) - - state.add_edge(A_device, None, A_host, None, dace.memlet.Memlet.simple(A_host, "0:N")) - - return state - - -def make_read_A_sdfg(): - - sdfg = dace.SDFG("array_read_A") - - n_inner_begin = sdfg.add_state("n_inner_begin") - n_inner_entry = sdfg.add_state("n_inner_entry") - n_inner_end = sdfg.add_state("n_inner_end") - - loop_body = sdfg.add_state("read_memory") - - sdfg.add_edge(n_inner_begin, n_inner_entry, dace.sdfg.InterstateEdge(assignments={"n": 0})) - sdfg.add_edge( - n_inner_entry, loop_body, - dace.sdfg.InterstateEdge( - condition=dace.properties.CodeProperty.from_string("n < N", language=dace.dtypes.Language.Python))) - sdfg.add_edge( - n_inner_entry, n_inner_end, - dace.sdfg.InterstateEdge( - condition=dace.properties.CodeProperty.from_string("n >= N", language=dace.dtypes.Language.Python))) - - sdfg.add_edge(loop_body, n_inner_entry, dace.sdfg.InterstateEdge(assignments={"n": "n + 1"})) - - mem = loop_body.add_array("mem", [N], dtype=dace.int32, storage=dace.dtypes.StorageType.FPGA_Global) - - pipe = loop_body.add_stream("pipe", dace.int32, storage=dace.dtypes.StorageType.FPGA_Local) - - loop_body.add_memlet_path(mem, pipe, memlet=dace.memlet.Memlet.simple(pipe, '0', other_subset_str='n')) - - return sdfg - - -def make_write_A_sdfg(): - - sdfg = dace.SDFG("array_write_A") - - n_begin = sdfg.add_state("n_begin") - n_entry = sdfg.add_state("n_entry") - n_end = sdfg.add_state("n_end") - - loop_body = sdfg.add_state("write_memory") - - sdfg.add_edge(n_begin, n_entry, dace.sdfg.InterstateEdge(assignments={"n": 0})) - - sdfg.add_edge( - n_entry, loop_body, - dace.sdfg.InterstateEdge( - condition=dace.properties.CodeProperty.from_string("n < N", language=dace.dtypes.Language.Python))) - - sdfg.add_edge(loop_body, n_entry, dace.sdfg.InterstateEdge(assignments={"n": "n + 1"})) - - sdfg.add_edge( - n_entry, n_end, - dace.sdfg.InterstateEdge( - condition=dace.properties.CodeProperty.from_string("n >= N", language=dace.dtypes.Language.Python))) - - mem = loop_body.add_array("mem", [N], dtype=dace.int32, storage=dace.dtypes.StorageType.FPGA_Global) - - pipe = loop_body.add_stream("pipe", dace.int32, storage=dace.dtypes.StorageType.FPGA_Local) - - loop_body.add_memlet_path(pipe, mem, memlet=dace.memlet.Memlet.simple(mem, 'n', other_subset_str='0')) - - return sdfg - - -def make_compute_sdfg(): - - sdfg = dace.SDFG("gemm_compute") - - n_begin = sdfg.add_state("n_begin") - n_entry = sdfg.add_state("n_entry") - n_end = sdfg.add_state("n_end") - - state = sdfg.add_state("compute") - - # Data nodes - A_pipe_in = state.add_stream("A_stream_in", dace.int32, storage=dace.dtypes.StorageType.FPGA_Local) - A_pipe_out = state.add_stream("A_stream_out", dace.int32, storage=dace.dtypes.StorageType.FPGA_Local) - - # N-loop - sdfg.add_edge(n_begin, n_entry, dace.sdfg.InterstateEdge(assignments={"n": 0})) - sdfg.add_edge( - n_entry, state, - dace.sdfg.InterstateEdge( - condition=dace.properties.CodeProperty.from_string("n < N", language=dace.dtypes.Language.Python))) - sdfg.add_edge( - n_entry, n_end, - dace.sdfg.InterstateEdge( - condition=dace.properties.CodeProperty.from_string("n >= N", language=dace.dtypes.Language.Python))) - - # Backtrack two loops - sdfg.add_edge(state, n_entry, dace.sdfg.InterstateEdge(assignments={"n": "n + 1"})) - - # Compute tasklet - - compute_tasklet = state.add_tasklet("add", {"a_in"}, {"a_out"}, "a_out = a_in +1") - - state.add_memlet_path(A_pipe_in, - compute_tasklet, - memlet=dace.memlet.Memlet.simple(A_pipe_in, '0', num_accesses=-1), - dst_conn="a_in") - state.add_memlet_path(compute_tasklet, - A_pipe_out, - memlet=dace.memlet.Memlet.simple(A_pipe_out, '0', num_accesses=-1), - src_conn="a_out") - - return sdfg - - -def make_fpga_state(sdfg): - - state = sdfg.add_state("simple_array") - - read_A_sdfg = make_read_A_sdfg() - read_A_sdfg_node = state.add_nested_sdfg(read_A_sdfg, {"mem"}, {"pipe"}) - - compute_sdfg = make_compute_sdfg() - compute_sdfg_node = state.add_nested_sdfg(compute_sdfg, {"A_stream_in"}, {"A_stream_out"}) - - write_A_sdfg = make_write_A_sdfg() - write_A_sdfg_node = state.add_nested_sdfg(write_A_sdfg, {"pipe"}, {"mem"}) - - A_IN = state.add_array("A_device", [N], - dtype=dace.int32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - A_OUT = state.add_array("A_device", [N], - dtype=dace.int32, - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - A_pipe_read = state.add_stream("A_pipe", - dace.int32, - transient=True, - shape=(P + 1, ), - storage=dace.dtypes.StorageType.FPGA_Local) - A_pipe_in = state.add_stream("A_pipe", - dace.int32, - transient=True, - shape=(P + 1, ), - storage=dace.dtypes.StorageType.FPGA_Local) - A_pipe_write = state.add_stream("A_pipe", - dace.int32, - transient=True, - shape=(P + 1, ), - storage=dace.dtypes.StorageType.FPGA_Local) - A_pipe_out = state.add_stream("A_pipe", - dace.int32, - transient=True, - shape=(P + 1, ), - storage=dace.dtypes.StorageType.FPGA_Local) - - compute_entry, compute_exit = state.add_map("unroll_compute", {"p": "0:P"}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - - # Bring data nodes into scope - state.add_memlet_path(compute_entry, A_pipe_in, memlet=dace.memlet.Memlet()) - state.add_memlet_path(A_pipe_out, compute_exit, memlet=dace.memlet.Memlet()) - - # Connect data nodes - state.add_memlet_path(A_pipe_in, - compute_sdfg_node, - dst_conn="A_stream_in", - memlet=dace.memlet.Memlet.simple(A_pipe_in, - 'p', - num_accesses=dace.symbolic.pystr_to_symbolic("N/P"))) - state.add_memlet_path(compute_sdfg_node, - A_pipe_out, - src_conn="A_stream_out", - memlet=dace.memlet.Memlet.simple(A_pipe_out, - 'p + 1', - num_accesses=dace.symbolic.pystr_to_symbolic("N/P"))) - - state.add_memlet_path(A_IN, read_A_sdfg_node, dst_conn="mem", memlet=dace.memlet.Memlet.simple(A_IN, "0:N")) - state.add_memlet_path(read_A_sdfg_node, - A_pipe_read, - src_conn="pipe", - memlet=dace.memlet.Memlet.simple(A_pipe_in, - '0', - num_accesses=dace.symbolic.pystr_to_symbolic("N"))) - - state.add_memlet_path(A_pipe_write, - write_A_sdfg_node, - dst_conn="pipe", - memlet=dace.memlet.Memlet.simple(A_pipe_out, - 'P', - num_accesses=dace.symbolic.pystr_to_symbolic("N"))) - state.add_memlet_path(write_A_sdfg_node, A_OUT, src_conn="mem", memlet=dace.memlet.Memlet.simple(A_OUT, "0:N")) - - return state - - -def make_sdfg(name=None, p=None): - - if name is None: - if p is not None: - name = "simple_systolic_array_{}".format(p) - else: - name = "simple_systolic_array_P" - - sdfg = dace.SDFG(name) - - pre_state = make_copy_to_fpga_state(sdfg) - compute_state = make_fpga_state(sdfg) - post_state = make_copy_to_host_state(sdfg) - - sdfg.add_edge(pre_state, compute_state, dace.sdfg.InterstateEdge()) - sdfg.add_edge(compute_state, post_state, dace.sdfg.InterstateEdge()) - - return sdfg - - -@fpga_test(xilinx=False) -def test_simple_systolic_array(): - - P = 4 - N = 128 - - sdfg = make_sdfg() - sdfg.specialize(dict(P=P, N=N)) - - # Initialize arrays: Randomize A and B, zero C - A = np.ndarray([N], dtype=dace.int32.type) - A[:] = np.random.randint(0, 1000, N).astype(dace.int32.type) - - A_Exp = A + P - - sdfg(A=A) - # print("A: ", A) - # print("A_Exp: ", A_Exp) - diff = np.abs(A_Exp - A) - diff_total = np.sum(diff) - highest_diff = np.max(diff) - wrong_elements = np.transpose(np.nonzero(diff >= 0.01)) - - assert diff_total < 0.01 - - return sdfg - - -if __name__ == "__main__": - test_simple_systolic_array(None) diff --git a/tests/fpga/spmv_fpga_test.py b/tests/fpga/spmv_fpga_test.py deleted file mode 100644 index ea33b2f497..0000000000 --- a/tests/fpga/spmv_fpga_test.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from pathlib import Path - -import dace -from dace.fpga_testing import fpga_test, import_sample - - -@fpga_test(assert_ii_1=False) -def test_spmv_fpga(): - spmv = import_sample(Path("fpga") / "spmv_fpga_stream.py") - return spmv.run_spmv(64, 64, 640, False) - - -if __name__ == "__main__": - test_spmv_fpga(None) diff --git a/tests/fpga/streaming_memory_test.py b/tests/fpga/streaming_memory_test.py deleted file mode 100644 index 025fc63341..0000000000 --- a/tests/fpga/streaming_memory_test.py +++ /dev/null @@ -1,994 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" Tests the StreamingMemory transformation. """ -import copy - -from numpy.core.numeric import allclose -import pytest -import dace -import dace.libraries.blas -import networkx as nx -import numpy as np - -from dace.transformation.dataflow import streaming_memory as sm, MapExpansion -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.fpga_testing import xilinx_test -from dace.transformation.auto.fpga import fpga_rr_interleave_containers_to_banks - -M, N, K = 64, 64, 64 - -M_s = dace.symbol('M_s') -N_s = dace.symbol('N_s') -K_s = dace.symbol('K_s') - - -@dace.program -def two_maps_kernel_legal(A: dace.float32[N], B: dace.float32[N], C: dace.float32[N], D: dace.float32[N], - E: dace.float32[N]): - - @dace.map - def sum(i: _[0:N]): - in_a << A[i] - in_b << B[i] - out >> D[i] - out = in_a + in_b - - @dace.map - def sum(i: _[0:N]): - in_b << B[i] - in_c << C[i] - out >> E[i] - out = in_b + in_c - - -@dace.program -def two_maps_kernel_illegal(A: dace.float32[N], B: dace.float32[N], C: dace.float32[N], D: dace.float32[N], - E: dace.float32[N]): - - @dace.map - def sum(i: _[0:N]): - in_a << A[i] - in_b << B[i] - out >> D[i] - out = in_a + in_b - - @dace.map - def sum(i: _[0:N:2]): - in_b << B[i] - in_c << C[i] - out >> E[i] - out = in_b + in_c - - -@dace.program -def bicg(A: dace.float32[N, M], p: dace.float32[M], r: dace.float32[N]): - return r @ A, A @ p - - -@dace.program -def atax(A: dace.float32[M, N], x: dace.float32[N]): - return (A @ x) @ A - - -@dace.program -def vecadd_1_streaming(A: dace.float32[N], B: dace.float32[N]): - B[:] = A + 1.0 - - -@dace.program -def vecadd_1_streaming_non_appl_0(A: dace.float32[N], B: dace.float32[N]): - for i in dace.map[0:61]: - with dace.tasklet: - in_A << A[i] - out >> B[i] - out = in_A + 1.0 - - -@dace.program -def vecadd_1_streaming_non_appl_1(A: dace.float32[N], B: dace.float32[N]): - for i in dace.map[0:N:2]: - with dace.tasklet: - in_A << A[i] - out >> B[i] - out = in_A + 1.0 - - -@dace.program -def vecadd_1_streaming_symbol(A: dace.float32[N_s], B: dace.float32[N_s]): - B[:] = A + 1.0 - - -@dace.program -def vecadd_streaming(A: dace.float32[N], B: dace.float32[N], C: dace.float32[N]): - C[:] = A + B - - -def vecadd_streaming_type(type0, type1, type2): - - @dace.program - def vecadd_streaming_type_kernel(A: type0[N], B: type1[N], C: type2[N]): - C[:] = A + B - - return vecadd_streaming_type_kernel - - -@dace.program -def matadd_streaming(A: dace.float32[M, N], B: dace.float32[M, N], C: dace.float32[M, N]): - C[:] = A + B - - -@dace.program -def matadd_streaming_symbol(A: dace.float32[M_s, N_s], B: dace.float32[M_s, N_s], C: dace.float32[M_s, N_s]): - C[:] = A + B - - -@dace.program -def matadd_streaming_bad_stride(A: dace.float32[M + 1, N + 1], B: dace.float32[M + 1, N + 1], C: dace.float32[M + 1, - N + 1]): - C[:] = A + B - - -@dace.program -def tensoradd_streaming(A: dace.float32[M, N, K], B: dace.float32[M, N, K], C: dace.float32[M, N, K]): - C[:] = A + B - - -@dace.program -def maporder_streaming(A: dace.float32[N, N, N], B: dace.float32[N, N, N], C: dace.float32[N, N, N], - D: dace.float32[N, N, N], E: dace.float32[N, N, N], F: dace.float32[N, N, - N], G: dace.float32[N, N]): - for i, j in dace.map[0:N, 0:N]: - with dace.tasklet: - in_A << A[i, j, 0] # No - in_B << B[i, 0, j] # Yes - in_C << C[0, i, j] # Yes - in_D << D[j, i, 0] # No - in_E << E[j, 0, i] # No - in_F << F[0, j, i] # No - out >> G[i, j] # Yes - - out = in_A + in_B + in_C + in_D + in_E + in_F - - -@dace.program -def matadd_multistream(A: dace.float32[M, N], B: dace.float32[M, N], C: dace.float32[M, N], D: dace.float32[M, N]): - C[:] = A + B - D[:] = A - B - - -@dace.program -def matmul_streaming(A: dace.float32[M, K], B: dace.float32[K, N], C: dace.float32[M, N]): - tmp = np.ndarray([M, N, K], dtype=A.dtype) - - # Multiply every pair of values to a large 3D temporary array - for i, j, k in dace.map[0:M, 0:N, 0:K]: - with dace.tasklet: - in_A << A[i, k] - in_B << B[k, j] - out >> tmp[i, j, k] - - out = in_A * in_B - - # Sum last dimension of temporary array to obtain resulting matrix - dace.reduce(lambda a, b: a + b, tmp, C, axis=2, identity=0) - - -@dace.program -def streamingcomp(A: dace.float32[M, N], B: dace.float32[M, N]): - # Slightly tricky situation - tmp = np.ndarray((M, N), dtype=A.dtype) - for i, j in dace.map[0:M, 0:N]: - with dace.tasklet: - a << A[i, j] - b << B[i, j] - t >> tmp[i, j] - t = a + b - - return tmp * B - - -@dace.program -def streaming_not_composable(A: dace.float32[M, N], B: dace.float32[M, N]): - for i, j in dace.map[0:M, 0:N - 1]: - with dace.tasklet: - a1 << A[i, j + 1] - a2 << A[i, j] - b >> B[i, j] - b = (a1 + a2) / 2 - for i, j in dace.map[0:M, 0:N - 1]: - with dace.tasklet: - a1 << B[i, j + 1] - a2 << B[i, j] - b >> A[i, j] - b = (a1 + a2) / 2 - - -@xilinx_test() -def test_streaming_mem(): - # Make SDFG - sdfg: dace.SDFG = matadd_streaming.to_sdfg() - # Transform - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, dict(storage=dace.StorageType.FPGA_Local)) == 3 - - # Run verification - A = np.random.rand(M, N).astype(np.float32) - B = np.random.rand(M, N).astype(np.float32) - C = np.random.rand(M, N).astype(np.float32) - - sdfg(A=A, B=B, C=C) - - diff = np.linalg.norm(C - (A + B)) - assert diff <= 1e-5 - - return sdfg - - -@xilinx_test() -def test_multistream(): - # Make SDFG - sdfg: dace.SDFG = matadd_multistream.to_sdfg() - # Transform - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, dict(storage=dace.StorageType.FPGA_Local)) == 4 - - # Ensure only 4 connected components exist - mainstate = next(s for s in sdfg.nodes() if 'copy' not in s.label) - assert len(list(nx.weakly_connected_components(mainstate.nx))) == 6 - - # Run verification - A = np.random.rand(M, N).astype(np.float32) - B = np.random.rand(M, N).astype(np.float32) - C = np.random.rand(M, N).astype(np.float32) - D = np.random.rand(M, N).astype(np.float32) - - sdfg(A=A, B=B, C=C, D=D) - - diff1 = np.linalg.norm(C - (A + B)) - diff2 = np.linalg.norm(D - (A - B)) - assert diff1 <= 1e-5 and diff2 <= 1e-5 - - return sdfg - - -@xilinx_test() -def test_multistream_with_deps(): - # Make SDFG - sdfg: dace.SDFG = streamingcomp.to_sdfg() - # Transform - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, dict(storage=dace.StorageType.FPGA_Local)) == 3 - - # Ensure only 4 connected components exist - mainstate = next(s for s in sdfg.nodes() if 'copy' not in s.label) - assert len(list(nx.weakly_connected_components(mainstate.nx))) == 4 - - # Run verification - A = np.random.rand(M, N).astype(np.float32) - B = np.random.rand(M, N).astype(np.float32) - - C = sdfg(A=A, B=B) - - diff = np.linalg.norm(C - ((A + B) * B)) / (M * N) - assert diff <= 1e-5 - - return sdfg - - -@xilinx_test() -def test_streaming_mem_mapnests(): - # Make SDFG - sdfg: dace.SDFG = matadd_streaming.to_sdfg() - # Transform - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG, MapExpansion]) - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, dict(storage=dace.StorageType.FPGA_Local)) == 3 - - # Run verification - A = np.random.rand(M, N).astype(np.float32) - B = np.random.rand(M, N).astype(np.float32) - C = np.random.rand(M, N).astype(np.float32) - - sdfg(A=A, B=B, C=C) - - diff = np.linalg.norm(C - (A + B)) - assert diff <= 1e-5 - - return sdfg - - -@xilinx_test() -def test_streaming_composition_matching(): - sdfg: dace.SDFG = streaming_not_composable.to_sdfg() - assert sdfg.apply_transformations_repeated(sm.StreamingComposition) == 0 - return [] # SDFG was not compiled, so we can't run HLS on it - - -@xilinx_test() -def test_streaming_composition(): - # Make SDFG - sdfg: dace.SDFG = streamingcomp.to_sdfg() - # Transform - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - assert sdfg.apply_transformations_repeated(sm.StreamingComposition, dict(storage=dace.StorageType.FPGA_Local)) == 1 - - # Run verification - A = np.random.rand(M, N).astype(np.float32) - B = np.random.rand(M, N).astype(np.float32) - - C = sdfg(A=A, B=B) - - diff = np.linalg.norm(C - ((A + B) * B)) / (M * N) - assert diff <= 1e-5 - - return sdfg - - -@xilinx_test() -def test_streaming_composition_mapnests(): - # Make SDFG - sdfg: dace.SDFG = streamingcomp.to_sdfg() - # Transform - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - # Test 1 - both maps expanded - test1 = copy.deepcopy(sdfg) - assert test1.apply_transformations_repeated(MapExpansion) == 2 - assert test1.apply_transformations_repeated(sm.StreamingComposition, dict(storage=dace.StorageType.FPGA_Local)) == 1 - - # Test 2 - one only one map expanded - sdfg.apply_transformations(MapExpansion) - assert sdfg.apply_transformations_repeated(sm.StreamingComposition, dict(storage=dace.StorageType.FPGA_Local)) == 1 - - # Run verification - A = np.random.rand(M, N).astype(np.float32) - B = np.random.rand(M, N).astype(np.float32) - C = np.random.rand(M, N).astype(np.float32) - - C = sdfg(A=A, B=B) - - diff = np.linalg.norm(C - ((A + B) * B)) / (M * N) - assert diff <= 1e-5 - - return sdfg - - -@xilinx_test() -def test_streaming_and_composition(): - # Make SDFG - sdfg: dace.SDFG = streamingcomp.to_sdfg() - # Transform - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, dict(storage=dace.StorageType.FPGA_Local)) == 3 - assert sdfg.apply_transformations_repeated(sm.StreamingComposition, dict(storage=dace.StorageType.FPGA_Local)) == 1 - - # Run verification - A = np.random.rand(M, N).astype(np.float32) - B = np.random.rand(M, N).astype(np.float32) - C = np.random.rand(M, N).astype(np.float32) - - C = sdfg(A=A, B=B) - - diff = np.linalg.norm(C - ((A + B) * B)) / (M * N) - assert diff <= 1e-5 - - return sdfg - - -@pytest.mark.long -def test_mem_buffer_vec_add_1(): - # Make SDFG - sdfg: dace.SDFG = vecadd_1_streaming.to_sdfg() - # Transform - - sdfg.apply_transformations([ - FPGATransformSDFG, - InlineSDFG, - ]) - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local - }]) == 2 - - # Run verification - A = np.random.rand(N).astype(np.float32) - B = np.random.rand(N).astype(np.float32) - - sdfg(A=A, B=B) - - assert all(B == A + 1) - - return sdfg - - -@pytest.mark.long -def test_mem_buffer_vec_add_1_symbolic(): - # Make SDFG - sdfg: dace.SDFG = vecadd_1_streaming_symbol.to_sdfg() - # Transform - - sdfg.apply_transformations([ - FPGATransformSDFG, - InlineSDFG, - ]) - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local - }]) == 2 - - # Run verification - A = np.random.rand(N).astype(np.float32) - B = np.random.rand(N).astype(np.float32) - - sdfg(A=A, B=B, N_s=256) - - assert all(B == A + 1) - - return sdfg - - -@xilinx_test() -def test_mem_buffer_vec_add(): - # Make SDFG - sdfg: dace.SDFG = vecadd_streaming.to_sdfg() - # Transform - - sdfg.apply_transformations([ - FPGATransformSDFG, - InlineSDFG, - ]) - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local - }]) == 3 - - # Run verification - A = np.random.rand(N).astype(np.float32) - B = np.random.rand(N).astype(np.float32) - C = np.random.rand(N).astype(np.float32) - - sdfg(A=A, B=B, C=C) - - diff = np.linalg.norm(C - (A + B)) - assert diff <= 1e-5 - - return sdfg - - -def mem_buffer_vec_add_types(dace_type0, dace_type1, dace_type2, np_type0, np_type1, np_type2): - - sdfg: dace.SDFG = vecadd_streaming_type(dace_type0, dace_type1, dace_type2).to_sdfg() - - sdfg.apply_transformations([ - FPGATransformSDFG, - InlineSDFG, - ]) - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local - }]) == 3 - - # Run verification - A = (np.random.rand(N) * 100).astype(np_type0) - B = (np.random.rand(N) * 100).astype(np_type1) - C = (np.random.rand(N) * 100).astype(np_type2) - - sdfg(A=A, B=B, C=C) - - diff = np.linalg.norm(C - (A + B)) - assert diff <= 1e-5 - - return sdfg - - -@pytest.mark.long -# def test_mem_buffer_vec_add_float16(): -# return mem_buffer_vec_add_types(dace.float16, dace.float16, dace.float16, np.float16, np.float16, np.float16) -@pytest.mark.long -def test_mem_buffer_vec_add_float32(): - return mem_buffer_vec_add_types(dace.float32, dace.float32, dace.float32, np.float32, np.float32, np.float32) - - -@pytest.mark.long -def test_mem_buffer_vec_add_float64(): - return mem_buffer_vec_add_types(dace.float64, dace.float64, dace.float64, np.float64, np.float64, np.float64) - - -@pytest.mark.long -def test_mem_buffer_vec_add_int8(): - return mem_buffer_vec_add_types(dace.int8, dace.int8, dace.int8, np.int8, np.int8, np.int8) - - -@pytest.mark.long -def test_mem_buffer_vec_add_int16(): - return mem_buffer_vec_add_types(dace.int16, dace.int16, dace.int16, np.int16, np.int16, np.int16) - - -@pytest.mark.long -def test_mem_buffer_vec_add_int32(): - return mem_buffer_vec_add_types(dace.int32, dace.int32, dace.int32, np.int32, np.int32, np.int32) - - -@pytest.mark.long -def test_mem_buffer_vec_add_int64(): - return mem_buffer_vec_add_types(dace.int64, dace.int64, dace.int64, np.int64, np.int64, np.int64) - - -@pytest.mark.long -def test_mem_buffer_vec_add_complex64(): - return mem_buffer_vec_add_types(dace.complex64, dace.complex64, dace.complex64, np.complex64, np.complex64, - np.complex64) - - -@pytest.mark.long -def test_mem_buffer_vec_add_complex128(): - return mem_buffer_vec_add_types(dace.complex128, dace.complex128, dace.complex128, np.complex128, np.complex128, - np.complex128) - - -@pytest.mark.long -# def test_mem_buffer_vec_add_mixed_float(): -# return mem_buffer_vec_add_types(dace.float16, dace.float32, dace.float64, np.float16, np.float32, np.float64) -@pytest.mark.long -def test_mem_buffer_vec_add_mixed_int(): - return mem_buffer_vec_add_types(dace.int16, dace.int32, dace.int64, np.int16, np.int32, np.int64) - - -# TODO: Investigate and re-enable if possible. -@pytest.mark.skip(reason="Unexplained CI Regression") -@xilinx_test() -def test_mem_buffer_mat_add(): - # Make SDFG - sdfg: dace.SDFG = matadd_streaming.to_sdfg() - # Transform - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local - }]) == 3 - - # Run verification - A = np.random.rand(M, N).astype(np.float32) - B = np.random.rand(M, N).astype(np.float32) - C = np.random.rand(M, N).astype(np.float32) - - sdfg(A=A, B=B, C=C) - - diff = np.linalg.norm(C - (A + B)) - - assert diff <= 1e-5 - - return sdfg - - -@pytest.mark.long -def test_mem_buffer_mat_add_symbol(): - # Make SDFG - sdfg: dace.SDFG = matadd_streaming_symbol.to_sdfg() - # Transform - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local - }]) == 3 - - # Run verification - A = np.random.rand(M, N).astype(np.float32) - B = np.random.rand(M, N).astype(np.float32) - C = np.random.rand(M, N).astype(np.float32) - - sdfg(A=A, B=B, C=C, M_s=256, N_s=512) - - diff = np.linalg.norm(C - (A + B)) - - assert diff <= 1e-5 - - return sdfg - - -@pytest.mark.long -def test_mem_buffer_tensor_add(): - # Make SDFG - sdfg: dace.SDFG = tensoradd_streaming.to_sdfg() - # Transform - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local - }]) == 3 - - # Run verification - A = np.random.rand(M, N, K).astype(np.float32) - B = np.random.rand(M, N, K).astype(np.float32) - C = np.random.rand(M, N, K).astype(np.float32) - - sdfg(A=A, B=B, C=C) - - diff = np.linalg.norm(C - (A + B)) - - assert diff <= 1e-5 - - return sdfg - - -# TODO: Investigate and re-enable if possible. -@pytest.mark.skip(reason="Unexplained CI Regression") -@xilinx_test() -def test_mem_buffer_multistream(): - # Make SDFG - sdfg: dace.SDFG = matadd_multistream.to_sdfg() - # Transform - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local - }]) == 4 - - mainstate = next(s for s in sdfg.nodes() if 'copy' not in s.label) - assert len(list(nx.weakly_connected_components(mainstate.nx))) == 12 - - # Run verification - A = np.random.rand(M, N).astype(np.float32) - B = np.random.rand(M, N).astype(np.float32) - C = np.random.rand(M, N).astype(np.float32) - D = np.random.rand(M, N).astype(np.float32) - - sdfg(A=A, B=B, C=C, D=D) - - diff1 = np.linalg.norm(C - (A + B)) - diff2 = np.linalg.norm(D - (A - B)) - assert diff1 <= 1e-5 and diff2 <= 1e-5 - - return sdfg - - -# TODO: Investigate and re-enable if possible. -@pytest.mark.skip(reason="Unexplained CI Regression") -@xilinx_test() -def test_mem_buffer_multistream_with_deps(): - # Make SDFG - sdfg: dace.SDFG = streamingcomp.to_sdfg() - # Transform - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local - }]) == 3 - - mainstate = next(s for s in sdfg.nodes() if 'copy' not in s.label) - assert len(list(nx.weakly_connected_components(mainstate.nx))) == 8 - - # Run verification - A = np.random.rand(M, N).astype(np.float32) - B = np.random.rand(M, N).astype(np.float32) - - C = sdfg(A=A, B=B) - - diff = np.linalg.norm(C - ((A + B) * B)) / (M * N) - assert diff <= 1e-5 - - return sdfg - - -@pytest.mark.long -def test_mem_buffer_mat_mul(): - # Make SDFG - sdfg: dace.SDFG = matmul_streaming.to_sdfg() - # Transform - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local - }]) == 1 - - # Run verification - A = np.random.rand(M, K).astype(np.float32) - B = np.random.rand(K, N).astype(np.float32) - C = np.random.rand(M, N).astype(np.float32) - - sdfg(A=A, B=B, C=C) - - diff = np.linalg.norm(C - (A @ B)) - assert diff <= 1e-5 - - return sdfg - - -# TODO: Investigate and re-enable if possible. -@pytest.mark.skip(reason="Unexplained CI Regression") -@xilinx_test() -def test_mem_buffer_map_order(): - # Make SDFG - sdfg: dace.SDFG = maporder_streaming.to_sdfg() - # Transform - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local - }]) == 3 - - # Run verification - A = np.random.rand(N, N, N).astype(np.float32) - B = np.random.rand(N, N, N).astype(np.float32) - C = np.random.rand(N, N, N).astype(np.float32) - D = np.random.rand(N, N, N).astype(np.float32) - E = np.random.rand(N, N, N).astype(np.float32) - F = np.random.rand(N, N, N).astype(np.float32) - G = np.random.rand(N, N).astype(np.float32) - G_sol = np.random.rand(N, N).astype(np.float32) - - for i in range(N): - for j in range(N): - G_sol[i][j] = A[i, j, 0] + B[i, 0, j] + \ - C[0, i, j] + D[j, i, 0] + E[j, 0, i] + F[0, j, i] - - sdfg(A=A, B=B, C=C, D=D, E=E, F=F, G=G) - - assert allclose(G_sol, G) - - return sdfg - - -@xilinx_test() -def test_mem_buffer_not_applicable(): - - sdfg: dace.SDFG = vecadd_1_streaming.to_sdfg() - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local, - "memory_buffering_target_bytes": 65 - }]) == 0 - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local, - "memory_buffering_target_bytes": 0 - }]) == 0 - - sdfg2: dace.SDFG = matadd_streaming_bad_stride.to_sdfg() - sdfg2.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - assert sdfg2.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local, - }]) == 0 - - sdfg3: dace.SDFG = vecadd_1_streaming_non_appl_0.to_sdfg() - sdfg3.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - assert sdfg3.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local, - }]) == 0 - - sdfg4: dace.SDFG = vecadd_1_streaming_non_appl_1.to_sdfg() - sdfg4.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - assert sdfg4.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local, - }]) == 0 - - return [] - - -@pytest.mark.long -def test_mem_buffer_atax(): - - A = np.random.rand(M, N).astype(np.float32) - x = np.random.rand(N).astype(np.float32) - - # Parse SDFG and apply FPGA friendly optimization - sdfg = atax.to_sdfg(strict=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - fpga_rr_interleave_containers_to_banks(sdfg, num_banks=4) - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Gemv - Gemv.default_implementation = "FPGA_Accumulate" - sdfg.expand_library_nodes() - sm_applied = sdfg.apply_transformations_repeated([InlineSDFG, sm.StreamingMemory], [{}, { - 'storage': dace.StorageType.FPGA_Local, - 'use_memory_buffering': True - }], - print_report=False) - assert sm_applied == 5 # 3 inlines and 2 Streaming memories - - sm_applied = sdfg.apply_transformations_repeated([InlineSDFG, sm.StreamingMemory], [{}, { - 'storage': dace.StorageType.FPGA_Local, - 'use_memory_buffering': False - }], - print_report=False) - - assert sm_applied == 1 # 1 Streaming memories - - # specialize the SDFG (needed by the GEMV expansion) - sdfg.specialize(dict(M=M, N=N)) - - y = sdfg(A=A, x=x) - - # Compute ground truth and Validate result - y_ref = atax.f(A, x) - - assert np.allclose(y, y_ref) - return sdfg - - -@pytest.mark.long -def test_mem_buffer_bicg(): - - A = np.random.rand(N, M).astype(np.float32) - p = np.random.rand(M).astype(np.float32) - r = np.random.rand(M).astype(np.float32) - - # Parse SDFG and apply FPGA friendly optimization - sdfg = bicg.to_sdfg(strict=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - fpga_rr_interleave_containers_to_banks(sdfg, num_banks=4) - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Gemv - Gemv.default_implementation = "FPGA_Accumulate" - sdfg.expand_library_nodes() - sm_applied = sdfg.apply_transformations_repeated([InlineSDFG, sm.StreamingMemory], [{}, { - 'storage': dace.StorageType.FPGA_Local, - 'use_memory_buffering': True - }], - print_report=True) - assert sm_applied == 7 # 3 inlines and 4 Streaming memories - - sm_applied = sdfg.apply_transformations_repeated([InlineSDFG, sm.StreamingMemory], [{}, { - 'storage': dace.StorageType.FPGA_Local, - 'use_memory_buffering': False - }], - print_report=True) - - assert sm_applied == 1 # 1 Streaming memories - - # specialize the SDFG (needed by the GEMV expansion) - sdfg.specialize(dict(M=M, N=N)) - - res0, res1 = sdfg(A=A, p=p, r=r) - - # Compute ground truth and Validate result - res0_ref, res1_ref = bicg.f(A, p, r) - - assert np.allclose(res0_ref, res0) - assert np.allclose(res1, res1_ref) - - return sdfg - - -@xilinx_test() -def test_two_maps_legal(): - - A = np.random.rand(N).astype(dace.float32.type) - B = np.random.rand(N).astype(dace.float32.type) - C = np.random.rand(N).astype(dace.float32.type) - D = np.random.rand(N).astype(dace.float32.type) - E = np.random.rand(N).astype(dace.float32.type) - - D_exp = A + B - E_exp = B + C - - sdfg: dace.SDFG = two_maps_kernel_legal.to_sdfg() - - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - assert sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local - }]) == 5 - - sdfg(A=A, B=B, C=C, D=D, E=E) - - assert np.allclose(D, D_exp) - assert np.allclose(E, E_exp) - - return sdfg - - -@xilinx_test() -def test_two_maps_illegal(): - - A = np.random.rand(N).astype(dace.float32.type) - B = np.random.rand(N).astype(dace.float32.type) - C = np.random.rand(N).astype(dace.float32.type) - D = np.random.rand(N).astype(dace.float32.type) - E = np.random.rand(N).astype(dace.float32.type) - E_exp = np.copy(E) - - D_exp = A + B - for i in range(0, 64, 2): - E_exp[i] = B[i] + C[i] - - sdfg = two_maps_kernel_illegal.to_sdfg() - - sdfg.apply_transformations([FPGATransformSDFG, InlineSDFG]) - - sdfg.apply_transformations_repeated(sm.StreamingMemory, - options=[{ - 'use_memory_buffering': True, - "storage": dace.StorageType.FPGA_Local - }]) == 2 - - sdfg(A=A, B=B, C=C, D=D, E=E) - - assert np.allclose(D, D_exp) - assert np.allclose(E, E_exp) - - return sdfg - - -if __name__ == "__main__": - test_streaming_mem(None) - test_streaming_mem_mapnests(None) - test_multistream(None) - test_multistream_with_deps(None) - test_streaming_composition_matching(None) - test_streaming_composition(None) - test_streaming_composition_mapnests(None) - test_streaming_and_composition(None) - - test_mem_buffer_vec_add_1(None) - test_mem_buffer_vec_add_1_symbolic(None) - test_mem_buffer_vec_add(None) - test_mem_buffer_mat_add(None) - test_mem_buffer_mat_add_symbol(None) - test_mem_buffer_tensor_add(None) - test_mem_buffer_multistream(None) - test_mem_buffer_multistream_with_deps(None) - test_mem_buffer_mat_mul(None) - test_mem_buffer_not_applicable(None) - test_mem_buffer_map_order(None) - - # test_mem_buffer_vec_add_float16(None) - test_mem_buffer_vec_add_float32(None) - test_mem_buffer_vec_add_float64(None) - test_mem_buffer_vec_add_int8(None) - test_mem_buffer_vec_add_int16(None) - test_mem_buffer_vec_add_int32(None) - test_mem_buffer_vec_add_int64(None) - # test_mem_buffer_vec_add_mixed_float(None) - test_mem_buffer_vec_add_mixed_int(None) - test_mem_buffer_vec_add_complex64(None) - test_mem_buffer_vec_add_complex128(None) - - test_mem_buffer_atax(None) - test_mem_buffer_bicg(None) - - test_two_maps_legal(None) - test_two_maps_illegal(None) diff --git a/tests/fpga/type_inference_test.py b/tests/fpga/type_inference_test.py deleted file mode 100644 index 3b09d38a7d..0000000000 --- a/tests/fpga/type_inference_test.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" Type inference test with annotated types. """ - -import dace -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG -import numpy as np - -N = dace.symbol("N") -CONSTANT_ARRAY = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.float32) -CONSTANT_VALUE = float(1) - - -def make_sdfg(): - sdfg = dace.SDFG("constant_type_inference") - - sdfg.add_array("output", shape=[N], dtype=dace.float32) - - sdfg.add_array("device_output", - shape=[N], - dtype=dace.float32, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - ########################################################################### - # Copy data from FPGA - copy_out_state = sdfg.add_state("copy_to_host") - - device_output = copy_out_state.add_read("device_output") - host_output = copy_out_state.add_write("output") - - copy_out_state.add_memlet_path(device_output, host_output, memlet=dace.Memlet(f"{host_output}[0:N]")) - - ######################################################################## - # FPGA, First State - - # increment constant array of 1 elements - - fpga_state = sdfg.add_state("fpga_state") - - out = fpga_state.add_write("device_output") - map_entry, map_exit = fpga_state.add_map("increment_map", {"i": "0:N"}, schedule=dace.ScheduleType.FPGA_Device) - - # Force type inference for constant array - tasklet = fpga_state.add_tasklet("increment_tasklet", {}, {"out"}, "incr = constant_value\n" - "tmp = constant_array[i]\n" - "out = tmp + incr") - - fpga_state.add_memlet_path(map_entry, tasklet, memlet=dace.Memlet()) - fpga_state.add_memlet_path(tasklet, map_exit, out, src_conn="out", memlet=dace.Memlet("device_output[i]")) - - sdfg.add_edge(fpga_state, copy_out_state, dace.sdfg.sdfg.InterstateEdge()) - sdfg.fill_scope_connectors() - sdfg.validate() - return sdfg - - -@dace.program -def type_inference(x: dace.float32[N], y: dace.float32[N]): - - @dace.map - def comp(i: _[0:N]): - in_x << x[i] - in_y << y[i] - out >> y[i] - - # computes y[i]=(int)x[i] + ((int)y[i])*2.1 - var1 = int(in_x) - var2: int = in_y - var3 = 2.1 if (i > 1 and i < 10) else 2.1 # Just for the sake of testing - res = var1 + var3 * var2 - out = res - - -@fpga_test() -def test_type_inference_fpga(): - - N = 24 - - # Initialize vector: X - X = np.random.uniform(-10, 0, N).astype(dace.float32.type) - Y = np.random.uniform(-10, 0, N).astype(dace.float32.type) - # compute expected result - Z = np.zeros(N) - for i in range(0, N): - Z[i] = int(X[i]) + int(Y[i]) * 2.1 - - sdfg = type_inference.to_sdfg() - sdfg.apply_transformations(FPGATransformSDFG) - sdfg(x=X, y=Y, N=N) - - diff = np.linalg.norm(Z - Y) / N - - assert diff <= 1e-5 - - return sdfg - - -@fpga_test() -def test_constant_type_inference_fpga(): - sdfg = make_sdfg() - sdfg.add_constant('constant_array', CONSTANT_ARRAY) - sdfg.add_constant('constant_value', CONSTANT_VALUE) - - out = dace.ndarray([CONSTANT_ARRAY.size], dtype=dace.float32) - sdfg(N=CONSTANT_ARRAY.size, output=out) - ref = CONSTANT_ARRAY + CONSTANT_VALUE - diff = np.linalg.norm(ref - out) / CONSTANT_ARRAY.size - assert diff <= 1e-5 - return sdfg - - -if __name__ == "__main__": - test_type_inference_fpga(None) - test_constant_type_inference_fpga(None) diff --git a/tests/fpga/unique_nested_sdfg_fpga_test.py b/tests/fpga/unique_nested_sdfg_fpga_test.py deleted file mode 100644 index bd05ba1b6d..0000000000 --- a/tests/fpga/unique_nested_sdfg_fpga_test.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -# The scope of the test is to verify that code nested SDFGs with a unique name is generated only once -# The nested SDFG compute vector addition on FPGA, with vectorization - -import dace -import numpy as np -import argparse -import subprocess - -from dace.memlet import Memlet -from dace.fpga_testing import fpga_test - - -def make_vecAdd_sdfg(sdfg_name: str, dtype=dace.float32): - vecWidth = 4 - n = dace.symbol("size") - vecAdd_sdfg = dace.SDFG(sdfg_name) - vecType = dace.vector(dtype, vecWidth) - - x_name = "x" - y_name = "y" - z_name = "z" - - ########################################################################### - # Copy data to FPGA - - copy_in_state = vecAdd_sdfg.add_state("copy_to_device") - - vecAdd_sdfg.add_array(x_name, shape=[n / vecWidth], dtype=vecType) - vecAdd_sdfg.add_array(y_name, shape=[n / vecWidth], dtype=vecType) - - in_host_x = copy_in_state.add_read(x_name) - in_host_y = copy_in_state.add_read(y_name) - - vecAdd_sdfg.add_array("device_x", - shape=[n / vecWidth], - dtype=vecType, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - vecAdd_sdfg.add_array("device_y", - shape=[n / vecWidth], - dtype=vecType, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - in_device_x = copy_in_state.add_write("device_x") - in_device_y = copy_in_state.add_write("device_y") - - copy_in_state.add_memlet_path(in_host_x, - in_device_x, - memlet=Memlet.simple(in_host_x, "0:{}/{}".format(n, vecWidth))) - copy_in_state.add_memlet_path(in_host_y, - in_device_y, - memlet=Memlet.simple(in_host_y, "0:{}/{}".format(n, vecWidth))) - - ########################################################################### - # Copy data from FPGA - vecAdd_sdfg.add_array(z_name, shape=[n / vecWidth], dtype=vecType) - - copy_out_state = vecAdd_sdfg.add_state("copy_to_host") - - vecAdd_sdfg.add_array("device_z", - shape=[n / vecWidth], - dtype=vecType, - storage=dace.dtypes.StorageType.FPGA_Global, - transient=True) - - out_device = copy_out_state.add_read("device_z") - out_host = copy_out_state.add_write(z_name) - - copy_out_state.add_memlet_path(out_device, out_host, memlet=Memlet.simple(out_host, "0:{}/{}".format(n, vecWidth))) - - ######################################################################## - # FPGA State - - fpga_state = vecAdd_sdfg.add_state("fpga_state") - - x = fpga_state.add_read("device_x") - y = fpga_state.add_read("device_y") - z = fpga_state.add_write("device_z") - - # ---------- ---------- - # COMPUTE - # ---------- ---------- - vecMap_entry, vecMap_exit = fpga_state.add_map('vecAdd_map', - dict(i='0:{0}/{1}'.format(n, vecWidth)), - schedule=dace.dtypes.ScheduleType.FPGA_Device) - - vecAdd_tasklet = fpga_state.add_tasklet('vecAdd_task', ['x_con', 'y_con'], ['z_con'], 'z_con = x_con + y_con') - - fpga_state.add_memlet_path(x, - vecMap_entry, - vecAdd_tasklet, - dst_conn='x_con', - memlet=dace.Memlet.simple(x.data, "i")) - - fpga_state.add_memlet_path(y, - vecMap_entry, - vecAdd_tasklet, - dst_conn='y_con', - memlet=dace.Memlet.simple(y.data, 'i')) - - fpga_state.add_memlet_path(vecAdd_tasklet, vecMap_exit, z, src_conn='z_con', memlet=dace.Memlet.simple(z.data, 'i')) - - ###################################### - # Interstate edges - vecAdd_sdfg.add_edge(copy_in_state, fpga_state, dace.sdfg.sdfg.InterstateEdge()) - vecAdd_sdfg.add_edge(fpga_state, copy_out_state, dace.sdfg.sdfg.InterstateEdge()) - - ######### - # Validate - vecAdd_sdfg.fill_scope_connectors() - vecAdd_sdfg.validate() - return vecAdd_sdfg - - -def make_nested_sdfg_fpga(unique_names): - """ - Build an SDFG with two nested SDFGs, each one a different state - - :param unique_names: use unique names for all the Nested SDFGs - """ - n = dace.symbol("n") - m = dace.symbol("m") - - sdfg = dace.SDFG("two_vecAdd") - state = sdfg.add_state("state") - - # build the first axpy: works with x,y, and z of n-elements - - # ATTENTION: this two nested SDFG must have the same name as they are equal - sdfg_name = "vecAdd" - to_nest = make_vecAdd_sdfg(sdfg_name) - - sdfg.add_array("x", [n], dace.float32) - sdfg.add_array("y", [n], dace.float32) - sdfg.add_array("z", [n], dace.float32) - x = state.add_read("x") - y = state.add_read("y") - z = state.add_write("z") - - # add nested sdfg with symbol mapping - nested_sdfg = state.add_nested_sdfg(to_nest, {"x", "y"}, {"z"}, {"size": "n"}) - - if unique_names: - nested_sdfg.unique_name = sdfg_name - - state.add_memlet_path(x, nested_sdfg, dst_conn="x", memlet=Memlet.simple(x, "0:n", num_accesses=n)) - state.add_memlet_path(y, nested_sdfg, dst_conn="y", memlet=Memlet.simple(y, "0:n", num_accesses=n)) - state.add_memlet_path(nested_sdfg, z, src_conn="z", memlet=Memlet.simple(z, "0:n", num_accesses=n)) - - # Build the second axpy: works with v,w and u of m elements, use another state - - state2 = sdfg.add_state("state2") - - to_nest = make_vecAdd_sdfg(sdfg_name) - - sdfg.add_array("v", [m], dace.float32) - sdfg.add_array("w", [m], dace.float32) - sdfg.add_array("u", [m], dace.float32) - v = state2.add_read("v") - w = state2.add_read("w") - u = state2.add_write("u") - - nested_sdfg = state2.add_nested_sdfg(to_nest, {"x", "y"}, {"z"}, {"size": "m"}) - - if unique_names: - nested_sdfg.unique_name = sdfg_name - - state2.add_memlet_path(v, nested_sdfg, dst_conn="x", memlet=Memlet.simple(v, "0:m", num_accesses=m)) - state2.add_memlet_path(w, nested_sdfg, dst_conn="y", memlet=Memlet.simple(w, "0:m", num_accesses=m)) - state2.add_memlet_path(nested_sdfg, u, src_conn="z", memlet=Memlet.simple(u, "0:m", num_accesses=m)) - ###################################### - # Interstate edges - sdfg.add_edge(state, state2, dace.sdfg.sdfg.InterstateEdge()) - sdfg.validate() - - return sdfg - - -@fpga_test() -def test_unique_nested_sdfg_fpga(): - - parser = argparse.ArgumentParser() - parser.add_argument("N", type=int, nargs="?", default=32) - parser.add_argument("M", type=int, nargs="?", default=64) - args = vars(parser.parse_args()) - - size_n = args["N"] - size_m = args["M"] - - x = np.random.rand(size_n).astype(np.float32) - y = np.random.rand(size_n).astype(np.float32) - z_hash = np.random.rand(size_n).astype(np.float32) - z_u_name = np.random.rand(size_n).astype(np.float32) - - v = np.random.rand(size_m).astype(np.float32) - w = np.random.rand(size_m).astype(np.float32) - u_hash = np.random.rand(size_m).astype(np.float32) - u_u_name = np.random.rand(size_m).astype(np.float32) - - ref1 = np.add(x, y) - ref2 = np.add(v, w) - - # Hash based detection of equivalent SDFGs - two_axpy_hash = make_nested_sdfg_fpga(False) - with dace.config.set_temporary('compiler', 'unique_functions', value='hash'): - two_axpy_hash(x=x, y=y, z=z_hash, v=v, w=w, u=u_hash, n=size_n, m=size_m) - - diff1 = np.linalg.norm(ref1 - z_hash) / size_n - diff2 = np.linalg.norm(ref2 - u_hash) / size_m - if diff1 <= 1e-5 and diff2 <= 1e-5: - print("==== Program end ====") - else: - raise Exception("==== Program Error! ====") - - # Unique_name based detection of equivalent SDFGs - two_axpy_u_name = make_nested_sdfg_fpga(True) - with dace.config.set_temporary('compiler', 'unique_functions', value='unique_name'): - two_axpy_u_name(x=x, y=y, z=z_u_name, v=v, w=w, u=u_u_name, n=size_n, m=size_m) - - diff1 = np.linalg.norm(ref1 - z_u_name) / size_n - diff2 = np.linalg.norm(ref2 - u_u_name) / size_m - assert diff1 <= 1e-5 and diff2 <= 1e-5 - - return [two_axpy_hash, two_axpy_u_name] - - -if __name__ == "__main__": - test_unique_nested_sdfg_fpga(None) diff --git a/tests/fpga/vec_sum_test.py b/tests/fpga/vec_sum_test.py deleted file mode 100644 index e09c57300f..0000000000 --- a/tests/fpga/vec_sum_test.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" -Vector addition with explicit dataflow. Computes Z += X + Y -Can be used for simple vectorization test -""" - -import dace -from dace.fpga_testing import fpga_test, xilinx_test -import numpy as np -from dace.config import set_temporary -import pytest - - -def run_vec_sum(vectorize_first: bool): - N = dace.symbol("N") - - @dace.program - def vec_sum(x: dace.float32[N], y: dace.float32[N], z: dace.float32[N]): - - @dace.map - def sum(i: _[0:N]): - in_x << x[i] - in_y << y[i] - in_z << z[i] - out >> z[i] - - out = in_x + in_y + in_z - - n = 24 - - # Initialize arrays: X, Y and Z - rng = np.random.default_rng(42) - X = rng.random(n, dtype=np.float32) - Y = rng.random(n, dtype=np.float32) - Z = rng.random(n, dtype=np.float32) - ref = X + Y + Z - - sdfg = vec_sum.to_sdfg() - - if vectorize_first: - transformations = [ - dace.transformation.dataflow.vectorization.Vectorization, - dace.transformation.interstate.fpga_transform_sdfg.FPGATransformSDFG - ] - transformation_options = [{"propagate_parent": True, "postamble": False}, {}] - else: - transformations = [ - dace.transformation.interstate.fpga_transform_sdfg.FPGATransformSDFG, - dace.transformation.dataflow.vectorization.Vectorization - ] - transformation_options = [{}, {"propagate_parent": True, "postamble": False}] - - assert sdfg.apply_transformations(transformations, transformation_options) == 2 - - sdfg(x=X, y=Y, z=Z, N=n) - - print(f"ref ({ref.shape}): {ref}") - print(f"Z ({Z.shape}): {Z}") - - diff = np.linalg.norm(ref - Z) / n - if diff > 1e-5: - raise ValueError("Difference: {}".format(diff)) - - return sdfg - - -@fpga_test(assert_ii_1=False) -def test_vec_sum_vectorize_first(): - return run_vec_sum(True) - - -# TODO: Investigate and re-enable if possible. -@pytest.mark.skip(reason="Unexplained CI Regression") -@fpga_test(assert_ii_1=False, intel=False) -def test_vec_sum_fpga_transform_first(): - return run_vec_sum(False) - - -@xilinx_test(assert_ii_1=True) -def test_vec_sum_vectorize_first_decoupled_interfaces(): - # For this test, decoupled read/write interfaces are needed to achieve II=1 - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - return run_vec_sum(True) - - -@xilinx_test(assert_ii_1=True) -def test_vec_sum_fpga_transform_first_decoupled_interfaces(): - # For this test, decoupled read/write interfaces are needed to achieve II=1 - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - with set_temporary('testing', 'serialization', value=False): - return run_vec_sum(True) - - -if __name__ == "__main__": - test_vec_sum_vectorize_first(None) - test_vec_sum_fpga_transform_first(None) - test_vec_sum_fpga_transform_first_decoupled_interfaces(None) diff --git a/tests/fpga/veclen_conversion_connector_test.py b/tests/fpga/veclen_conversion_connector_test.py deleted file mode 100644 index 24b2d95fe3..0000000000 --- a/tests/fpga/veclen_conversion_connector_test.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import numpy as np -import pytest -from veclen_conversion_test import make_sdfg -from dace.fpga_testing import fpga_test - - -#TODO: Investigate and re-enable if possible. -@pytest.mark.skip(reason="Unexplained CI Regression") -@fpga_test() -def test_veclen_conversion_connector(): - - size = 128 - vector_length = 4 - - if size % vector_length != 0: - raise ValueError("Size {} must be divisible by vector length {}.".format(size, vector_length)) - - sdfg = make_sdfg(name="veclen_conversion_connector", vectorize_connector=True) - sdfg.specialize({"W": vector_length}) - - A = np.arange(size, dtype=np.float64) - B = np.zeros((size, ), dtype=np.float64) - - sdfg(A=A, B=B, N=size) - - mid = vector_length // 2 - - for i in range(size // vector_length): - expected = np.concatenate( - (A[i * vector_length + mid:(i + 1) * vector_length], A[i * vector_length:i * vector_length + mid])) - if any(B[i * vector_length:(i + 1) * vector_length] != expected): - raise ValueError("Shuffle failed: {} (should be {})".format(B, expected)) - - return sdfg - - -if __name__ == "__main__": - test_veclen_conversion_connector() diff --git a/tests/fpga/veclen_conversion_test.py b/tests/fpga/veclen_conversion_test.py deleted file mode 100755 index f9a7967ad5..0000000000 --- a/tests/fpga/veclen_conversion_test.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import dace -import numpy as np -from dace.fpga_testing import intel_fpga_test - -SIZE = dace.symbol("N") -VECTOR_LENGTH = dace.symbol("W") -DTYPE = dace.float64 - - -def make_copy_to_fpga_state(sdfg, veclen): - - state = sdfg.add_state("copy_to_device") - - A_host = sdfg.add_array("A", [SIZE // veclen], dtype=dace.vector(DTYPE, veclen)) - - A_device = sdfg.add_array("A_device", [SIZE], - dtype=dace.vector(DTYPE, veclen), - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - - read = state.add_read("A") - write = state.add_write("A_device") - - state.add_memlet_path(read, - write, - memlet=dace.memlet.Memlet.simple("A_device", - "0:N//{}".format(veclen), - num_accesses=SIZE // veclen)) - - return state - - -def make_copy_to_host_state(sdfg, veclen): - - state = sdfg.add_state("copy_to_host") - - B_device = sdfg.add_array("B_device", [SIZE], - dtype=dace.vector(DTYPE, veclen), - transient=True, - storage=dace.dtypes.StorageType.FPGA_Global) - - B_host = sdfg.add_array("B", [SIZE // veclen], dtype=dace.vector(DTYPE, veclen)) - - read = state.add_read("B_device") - write = state.add_write("B") - - state.add_memlet_path(read, - write, - memlet=dace.memlet.Memlet.simple("B", "0:N//{}".format(veclen), num_accesses=SIZE // veclen)) - - return state - - -def make_fpga_state(sdfg, vectorize_connector, veclen): - - state = sdfg.add_state("fpga_state") - - sdfg.add_array("input_buffer", (veclen, ), DTYPE, transient=True, storage=dace.StorageType.FPGA_Registers) - sdfg.add_array("output_buffer", (veclen, ), DTYPE, transient=True, storage=dace.StorageType.FPGA_Registers) - - read_input = state.add_read("A_device") - read_buffer = state.add_access("input_buffer") - write_buffer = state.add_access("output_buffer") - write_output = state.add_write("B_device") - - outer_entry, outer_exit = state.add_map("outer_map", {"i": "0:N/W"}, schedule=dace.ScheduleType.FPGA_Device) - - # Test read from packed memory to an unpacked buffer - if vectorize_connector: - outputs = {"a_unpacked": dace.vector(DTYPE, veclen)} - else: - outputs = {"a_unpacked"} # Infers an array - unpack_tasklet = state.add_tasklet("unpack_tasklet", {"a"}, outputs, "a_unpacked = a") - state.add_memlet_path(read_input, - outer_entry, - unpack_tasklet, - dst_conn="a", - memlet=dace.Memlet.simple("A_device", "i")) - state.add_memlet_path(unpack_tasklet, - read_buffer, - src_conn="a_unpacked", - memlet=dace.Memlet.simple(read_buffer.data, "0:{}".format(veclen))) - - unroll_entry, unroll_exit = state.add_map("shuffle_map", {"w": "0:W"}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - - tasklet = state.add_tasklet("shuffle_tasklet", {"a"}, {"b"}, "b = a") - - state.add_memlet_path(read_buffer, - unroll_entry, - tasklet, - dst_conn="a", - memlet=dace.Memlet.simple("input_buffer", "(w + W // 2) % W", num_accesses=1)) - - state.add_memlet_path(tasklet, - unroll_exit, - write_buffer, - src_conn="b", - memlet=dace.Memlet.simple("output_buffer", "w", num_accesses=1)) - - # Test writing from unpacked to packed from inside tasklet - if vectorize_connector: - outputs = {"b": dace.vector(DTYPE, veclen)} - else: - outputs = {"b"} - pack_tasklet = state.add_tasklet("pack_tasklet", outputs, {"b_packed"}, "b_packed = b") - state.add_memlet_path(write_buffer, - pack_tasklet, - dst_conn="b", - memlet=dace.Memlet.simple(write_buffer.data, "0:{}".format(veclen))) - - # Write back out to memory from unpacked to packed memory - state.add_memlet_path(pack_tasklet, - outer_exit, - write_output, - src_conn="b_packed", - memlet=dace.Memlet.simple("B_device", "i")) - - return state - - -def make_sdfg(name=None, vectorize_connector=False, veclen=4): - - if name is None: - name = "veclen_conversion" - - sdfg = dace.SDFG(name) - - pre_state = make_copy_to_fpga_state(sdfg, veclen) - post_state = make_copy_to_host_state(sdfg, veclen) - compute_state = make_fpga_state(sdfg, vectorize_connector, veclen) - - sdfg.add_edge(pre_state, compute_state, dace.sdfg.InterstateEdge()) - sdfg.add_edge(compute_state, post_state, dace.sdfg.InterstateEdge()) - - return sdfg - - -@intel_fpga_test() -def test_veclen_conversion(): - size = 128 - vector_length = 4 - - if size % vector_length != 0: - raise ValueError("Size {} must be divisible by vector length {}.".format(size, vector_length)) - - sdfg = make_sdfg(vectorize_connector=False, veclen=vector_length) - sdfg.specialize({"W": vector_length}) - - A = np.arange(size, dtype=np.float64) - B = np.zeros((size, ), dtype=np.float64) - - sdfg(A=A, B=B, N=size) - - mid = vector_length // 2 - - for i in range(size // vector_length): - expected = np.concatenate( - (A[i * vector_length + mid:(i + 1) * vector_length], A[i * vector_length:i * vector_length + mid])) - if any(B[i * vector_length:(i + 1) * vector_length] != expected): - raise ValueError("Shuffle failed: {} (should be {})".format(B, expected)) - - return sdfg - - -if __name__ == "__main__": - test_veclen_conversion(None) diff --git a/tests/fpga/veclen_copy_conversion_test.py b/tests/fpga/veclen_copy_conversion_test.py deleted file mode 100644 index 3aa4716ef5..0000000000 --- a/tests/fpga/veclen_copy_conversion_test.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import copy -import dace -from dace.fpga_testing import fpga_test - - -def make_sdfg(tasklet_code=None, name="veclen_copy_conversion", dtype=dace.float32, veclen=16): - - vtype = dace.vector(dace.float32, veclen) - - if tasklet_code is None: - tasklet_code = "_out = _in" - - n = dace.symbol("N") - - sdfg = dace.SDFG(name) - - pre_state = sdfg.add_state(name + "_pre") - state = sdfg.add_state(name) - post_state = sdfg.add_state(name + "_post") - sdfg.add_edge(pre_state, state, dace.InterstateEdge()) - sdfg.add_edge(state, post_state, dace.InterstateEdge()) - - _, desc_input_host = sdfg.add_array("a", (n // veclen, ), vtype) - _, desc_output_host = sdfg.add_array("b", (n // veclen, ), vtype) - desc_input_device = copy.copy(desc_input_host) - desc_input_device.storage = dace.StorageType.FPGA_Global - desc_input_device.location["memorytype"] = "ddr" - desc_input_device.location["bank"] = "0" - desc_input_device.transient = True - desc_output_device = copy.copy(desc_output_host) - desc_output_device.storage = dace.StorageType.FPGA_Global - desc_output_device.location["memorytype"] = "ddr" - desc_output_device.location["bank"] = "1" - desc_output_device.transient = True - sdfg.add_datadesc("a_device", desc_input_device) - sdfg.add_datadesc("b_device", desc_output_device) - - # Host to device - pre_read = pre_state.add_read("a") - pre_write = pre_state.add_write("a_device") - pre_state.add_memlet_path(pre_read, pre_write, memlet=dace.Memlet(pre_write.data, None)) - - # Device to host - post_read = post_state.add_read("b_device") - post_write = post_state.add_write("b") - post_state.add_memlet_path(post_read, post_write, memlet=dace.Memlet(post_write.data, None)) - - # Compute state - read_memory = state.add_read("a_device") - write_memory = state.add_write("b_device") - - # Memory streams - sdfg.add_stream("a_stream", vtype, storage=dace.StorageType.FPGA_Local, transient=True) - sdfg.add_stream("b_stream", vtype, storage=dace.StorageType.FPGA_Local, transient=True) - produce_input_stream = state.add_write("a_stream") - consume_input_stream = state.add_read("a_stream") - produce_output_stream = state.add_write("b_stream") - consume_output_stream = state.add_write("b_stream") - - tasklet = state.add_tasklet(name, {"_in"}, {"_out"}, tasklet_code) - - # Iterative map - entry, exit = state.add_map(name, { - "i": "0:N//{}".format(veclen), - }, schedule=dace.ScheduleType.FPGA_Device) - - # Unrolled map - unroll_entry, unroll_exit = state.add_map(name + "_unroll", {"u": "0:{}".format(veclen)}, - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - - # Container-to-container copies between arrays and streams - state.add_memlet_path(read_memory, produce_input_stream, memlet=dace.Memlet(read_memory.data)) - state.add_memlet_path(consume_output_stream, write_memory, memlet=dace.Memlet(write_memory.data)) - - # Container-to-container copy from vectorized stream to non-vectorized - # buffer - sdfg.add_array("a_buffer", (veclen, ), dtype, storage=dace.StorageType.FPGA_Local, transient=True) - sdfg.add_array("b_buffer", (veclen, ), dtype, storage=dace.StorageType.FPGA_Local, transient=True) - a_buffer = state.add_access("a_buffer") - b_buffer = state.add_access("b_buffer") - - # Input stream to buffer - state.add_memlet_path(consume_input_stream, - entry, - a_buffer, - memlet=dace.Memlet.simple(consume_input_stream.data, - "0", - other_subset_str="0:{}".format(veclen))) - # Buffer to tasklet - state.add_memlet_path(a_buffer, - unroll_entry, - tasklet, - dst_conn="_in", - memlet=dace.Memlet.simple(a_buffer.data, "u", num_accesses=1)) - - # Tasklet to buffer - state.add_memlet_path(tasklet, - unroll_exit, - b_buffer, - src_conn="_out", - memlet=dace.Memlet.simple(b_buffer.data, "u", num_accesses=1)) - - # Buffer to output stream - state.add_memlet_path(b_buffer, - exit, - produce_output_stream, - memlet=dace.Memlet.simple(produce_output_stream.data, - "0", - other_subset_str="0:{}".format(veclen), - num_accesses=1)) - - return sdfg - - -@fpga_test() -def test_veclen_copy_conversion(): - - import numpy as np - - dtype = np.float32 - - gearbox = make_sdfg(tasklet_code="_out = _in + 1", dtype=dtype) - - size = 1024 - a = np.arange(size, dtype=dtype) - b = np.empty(size, dtype=dtype) - - gearbox(a=a, b=b, N=size) - - if any(b != a + 1): - raise ValueError("Unexpected output.") - - return gearbox - - -if __name__ == "__main__": - test_veclen_copy_conversion(None) diff --git a/tests/fpga/vector_reduce_test.py b/tests/fpga/vector_reduce_test.py deleted file mode 100644 index 8c2fa9b332..0000000000 --- a/tests/fpga/vector_reduce_test.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" Sums all the element of the vector with a reduce. """ - -import dace -import numpy as np -import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG - -N = dace.symbol('N') - - -@dace.program -def vector_reduce(x: dace.float32[N], s: dace.scalar(dace.float32)): - #transient - tmp = dace.define_local([N], dtype=x.dtype) - - @dace.map - def sum(i: _[0:N]): - in_x << x[i] - out >> tmp[i] - - out = in_x - - dace.reduce(lambda a, b: a + b, tmp, s, axis=(0), identity=0) - - -@fpga_test() -def test_vector_reduce(): - - N = 24 - - # Initialize arrays: X, Y and Z - X = np.random.rand(N).astype(dace.float32.type) - s = dace.scalar(dace.float32) - - sdfg = vector_reduce.to_sdfg() - sdfg.apply_transformations(FPGATransformSDFG) - sdfg(x=X, s=s, N=N) - - # Compute expected result - s_exp = 0.0 - for x in X: - s_exp += x - diff = np.linalg.norm(s_exp - s) / N - assert diff <= 1e-5 - - return sdfg - - -if __name__ == "__main__": - test_vector_reduce(None) diff --git a/tests/fpga/xilinx_interstate_preprocessing.py b/tests/fpga/xilinx_interstate_preprocessing.py deleted file mode 100644 index aaefa80e57..0000000000 --- a/tests/fpga/xilinx_interstate_preprocessing.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. -""" - Xilinx SDFG preprocessing replaces variable names in interstate edges with their - qualified name (i.e., appending '_in'/'_out' to the container name). - This behaviour is tested by `tests/npbench/nussinov_test.py`. - - This test (issue #972) tests that type inference for interstate edge variables (triggered - after preprocessing) is done correctly. - -""" - -import dace.dtypes -import numpy as np -import dace as dc -import pytest -import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG - -N0 = dace.symbol('N0', dtype=dace.uint32) -N1 = dace.symbol('N1', dtype=dace.uint32) - - -@dace.program -def ComputeRestriction(Axf: dace.float64[N0], rc: dace.float64[N1], f2c: dace.uint32[N0], rf: dace.float64[N0]): - for i_res in range(0, N1): - problem_variable = f2c[i_res] - rc[i_res] = rf[problem_variable] - Axf[problem_variable] - - -@dace.program -def program(f2cOperator0: dace.uint32[N0], rc1: dace.float64[N1], Axf0: dace.float64[N0], x: dace.float64[N0]): - ComputeRestriction(Axf=Axf0, rc=rc1, f2c=f2cOperator0, rf=x) - - -@fpga_test(assert_ii_1=False, intel=False) -def test_type_inference(): - sdfg = program.to_sdfg() - sdfg.apply_transformations(FPGATransformSDFG) - f2cOperator = np.array([0, 1, 2, 3], dtype=np.uint32) - rc = np.array([42, 42, 42, 42], dtype=np.float64) - Axf = np.array([0, 2, 4, 6], dtype=np.float64) - x = np.array([0, 1, 2, 3], dtype=np.float64) - sdfg(f2cOperator0=f2cOperator, rc1=rc, Axf0=Axf, x=x, N0=4, N1=4) - - assert ((rc == np.array([0, -1, -2, -3])).all()) - return sdfg - - -if __name__ == "__main__": - test_type_inference(None) diff --git a/tests/fpga_polybench_test.sh b/tests/fpga_polybench_test.sh deleted file mode 100755 index 6ba90242ad..0000000000 --- a/tests/fpga_polybench_test.sh +++ /dev/null @@ -1,208 +0,0 @@ -#!/bin/bash -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. - -# Executes polybench kernels for correctness check, by default on both Xilinx and Intel FPGA -# Tests have also a timeout time (defined in variable) TEST_TIMEOUT -# this must be invoked on test directory -# usage: fpga_polybench.sh - -set -a - -SCRIPTPATH="$( - cd "$(dirname "$0")" - pwd -P -)" -PYTHONPATH=$SCRIPTPATH/.. - -DACE_debugprint="${DACE_debugprint:-0}" -ERRORS=0 -FAILED_TESTS="" -TIMEDOUT_TESTS="" -POLYBENCH_INPUT="mini" -TESTS=0 -PYTHON_BINARY="${PYTHON_BINARY:-python3}" - -TEST_TIMEOUT="30s" - -RED='\033[0;31m' -YELLOW='\033[0;33m' -NC='\033[0m' - -################################################ - -bail() { - ERRORSTR=$1 - /bin/echo -e "${RED}ERROR${NC} in $ERRORSTR" 1>&2 - ERRORS=$(expr $ERRORS + 1) - FAILED_TESTS="${FAILED_TESTS} $ERRORSTR\n" -} - -timedout() { - ERRORSTR=$1 - /bin/echo -e "${RED}ERROR${NC} in $ERRORSTR" 1>&2 - ERRORS=$(expr $ERRORS + 1) - TIMEDOUT_TESTS="${TIMEDOUT_TESTS} $ERRORSTR\n" -} - -run_sample_intel() { - # Args: - # 1 - Relative path of FPGA test starting from test folder - # 2 - name of the build folder - # 3 - number of the SDFG transformation to apply - TESTS=$(expr $TESTS + 1) - echo -e "${YELLOW}Running test $1...${NC}" - - #1: generate the opencl - #remove previously built version. This helps to avoid stall in case the program does not terminates - rm -fr .dacecache/$2 2>/dev/null - echo -e "FPGATransformSDFG\$${3}\ny" | $PYTHON_BINARY $1.py -size ${POLYBENCH_INPUT} 2>/dev/null | : - - #2: compile for emulation - cd .dacecache/$2/build - if [ $? -ne 0 ]; then - bail "$1 (${RED}Code generation failed${NC})" - return 1 - fi - - make intelfpga_compile_$2_emulator - cd ../../../ - - if [ $? -ne 0 ]; then - bail "$1 (${RED}high-level synthesis failed${NC})" - return 1 - fi - - #3: execute the emulation with timeout - echo -e "FPGATransformSDFG\$${3}\ny" | timeout $TEST_TIMEOUT $PYTHON_BINARY $1.py -size ${POLYBENCH_INPUT} - - ret_status=$? - if [ $ret_status -ne 0 ]; then - echo "Result " $ret_status - if [ $ret_status -eq "124" ]; then #timeout test - timedout "Intel_FPGA: $1 (${RED}Test timeout${NC})" - else - bail "Intel_FPGA: $1 (${RED}Result not correct${NC})" - fi - return 1 - fi - - #4 cleanup - cd .dacecache/$2/build - rm -fr $1_* - cd - - - return 0 -} - -run_sample_xilinx() { - # Args: - # 1 - Relative path of FPGA test starting from test folder - # 2 - name of the build folder - # 3 - number of the SDFG transformation to apply - TESTS=$(expr $TESTS + 1) - echo -e "${YELLOW}Running test $1...${NC}" - - #1: execute the benchmark with timeout - echo -e "FPGATransformSDFG\$${3}\ny" | timeout $TEST_TIMEOUT $PYTHON_BINARY $1.py -size ${POLYBENCH_INPUT} ${@:3} - ret_status=$? - if [ $ret_status -ne 0 ]; then - echo "Result " $ret_status - if [ $ret_status -eq "124" ]; then #timeout test - timedout "Xilinx: $1 (${RED}Test timeout${NC})" - else - bail "Xilinx: $1 (${RED}Result not correct${NC})" - fi - return 1 - fi - - return 0 -} - -run_all() { - # Args: - # - run_sample function to invoke - - echo "Removing cache..." - sleep 5 - rm -fr .dacecache - - $1 2mm k2mm 0 - $1 3mm k3mm 0 - $1 adi adi 0 - $1 atax atax 0 - $1 bicg bicg 0 - $1 cholesky cholesky 0 - $1 correlation correlation 0 - $1 covariance covariance 0 - $1 deriche deriche 0 - $1 doitgen doitgen 1 - $1 durbin durbin 0 - $1 fdtd-2d fdtd2d 0 - $1 floyd-warshall floyd_warshall 0 - $1 gemm gemm 0 - $1 gemver gemver 0 - $1 gesummv gesummv 0 - $1 gramschmidt gramschmidt 1 - $1 heat-3d heat3d 0 - $1 jacobi-1d jacobi1d 0 - $1 jacobi-2d jacobi2d 0 - $1 ludcmp ludcmp 0 - $1 lu lu 0 - $1 mvt mvt 0 - $1 nussinov nussinov 0 - $1 seidel-2d seidel2d 0 - $1 symm symm 1 - $1 syr2k syr2k 0 - $1 syrk syrk 0 - $1 trisolv trisolv 0 - $1 trmm trmm 1 -} - -if [ "$1" == "intel_fpga" ]; then - # Check if aoc is vailable - which aoc - if [ $? -ne 0 ]; then - echo "aocc not available" - exit 99 - fi - - echo "====== Target: INTEL FPGA ======" - export DACE_compiler_use_cache=0 - export DACE_compiler_fpga_vendor="intel_fpga" - export DACE_compiler_intel_fpga_mode="emulator" - - TEST_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null && pwd)" - cd $TEST_DIR/polybench - - run_all run_sample_intel - -else - # assuming xilinx - # Check if xocc is vailable - which xocc - if [ $? -ne 0 ]; then - echo "xocc not available" - exit 99 - fi - - echo "====== Target: Xilinx ======" - - export DACE_compiler_use_cache=0 - export DACE_compiler_xilinx_mode="simulation" - export DACE_compiler_fpga_vendor="xilinx" - - echo "Attention: this will cleanup cache in 5 seconds..." - #cleanup - rm -fr .dacecache/ - - TEST_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null && pwd)" - cd $TEST_DIR/polybench - run_all run_sample_xilinx -fi - -PASSED=$(expr $TESTS - $ERRORS) -echo "$PASSED / $TESTS tests passed" -if [ $ERRORS -ne 0 ]; then - printf "Failed tests:\n${FAILED_TESTS}" - exit 1 -fi diff --git a/tests/instrumentation_test.py b/tests/instrumentation_test.py index 2aa26edf36..120e025812 100644 --- a/tests/instrumentation_test.py +++ b/tests/instrumentation_test.py @@ -4,6 +4,7 @@ import pytest import numpy as np +import re import sys import dace @@ -39,14 +40,17 @@ def onetest(instrumentation: dace.InstrumentationType, size=128): if isinstance(node, nodes.MapEntry) and node.map.label == 'mult': node.map.instrument = instrumentation state.instrument = instrumentation - # Set Timer instrumentation on the whole SDFG - if instrumentation == dace.InstrumentationType.Timer: - sdfg.instrument = instrumentation - if instrumentation == dace.InstrumentationType.GPU_Events: + if instrumentation in [dace.InstrumentationType.GPU_Events, dace.InstrumentationType.GPU_TX_MARKERS]: sdfg.apply_transformations(GPUTransformSDFG) - sdfg(A=A, B=B, C=C, N=size) + with dace.instrument(instrumentation, + filter='*', + annotate_maps=True, + annotate_tasklets=False, + annotate_states=True, + annotate_sdfgs=True): + sdfg(A=A, B=B, C=C, N=size) # Check for correctness assert np.allclose(C, 20 * A @ B) @@ -57,6 +61,22 @@ def onetest(instrumentation: dace.InstrumentationType, size=128): report = sdfg.get_latest_report() print(report) + # Check that the NVTX/rocTX range wrapper is present in the generated CPU code + if instrumentation == dace.InstrumentationType.GPU_TX_MARKERS: + code = sdfg.generate_code()[0].clean_code + tx_include = re.search(r'#include <(nvtx3/nvToolsExt|roctx).h>', code) + assert tx_include is not None + range_push = re.search(r'(nvtx|roctx)RangePush\("sdfg', code) is not None + range_push &= re.search(r'(nvtx|roctx)RangePush\("copy', code) is not None + range_push &= re.search(r'(nvtx|roctx)RangePush\("state', code) is not None + range_push &= re.search(r'(nvtx|roctx)RangePush\("alloc', code) is not None + range_push &= re.search(r'(nvtx|roctx)RangePush\("dealloc', code) is not None + range_push &= re.search(r'(nvtx|roctx)RangePush\("init', code) is not None + range_push &= re.search(r'(nvtx|roctx)RangePush\("exit', code) is not None + assert range_push + range_pop = re.search(r'(nvtx|roctx)RangePop\b', code) + assert range_pop is not None + def test_timer(): onetest(dace.InstrumentationType.Timer) @@ -73,8 +93,14 @@ def test_gpu_events(): onetest(dace.InstrumentationType.GPU_Events) +@pytest.mark.gpu +def test_gpu_tx_markers(): + onetest(dace.InstrumentationType.GPU_TX_MARKERS) + + if __name__ == '__main__': test_timer() test_papi() if len(sys.argv) > 1 and sys.argv[1] == 'gpu': test_gpu_events() + test_gpu_tx_markers() diff --git a/tests/library/codelibnode_test.py b/tests/library/codelibnode_test.py index 457f297004..a996730244 100644 --- a/tests/library/codelibnode_test.py +++ b/tests/library/codelibnode_test.py @@ -3,7 +3,6 @@ from dace.data import Array from dace.properties import Property, make_properties from dace.libraries.standard.nodes import CodeLibraryNode -from dace.codegen.targets.cpp import cpp_offset_expr import numpy as np from typing import Dict diff --git a/tests/library/stencil/stencil_node_test.py b/tests/library/stencil/stencil_node_test.py index c6d6298e7a..d33814d3e9 100644 --- a/tests/library/stencil/stencil_node_test.py +++ b/tests/library/stencil/stencil_node_test.py @@ -1,9 +1,8 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. import dace -from dace.fpga_testing import intel_fpga_test from dace.libraries.stencil import Stencil import numpy as np -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.transformation.interstate import InlineSDFG SIZE = dace.symbol("size") ROWS = dace.symbol("rows") @@ -84,19 +83,6 @@ def test_stencil_node_1d(): run_stencil_1d(make_sdfg_1d("pure", 1), 32) -def stencil_node_1d_fpga_array(vector_length: int): - sdfg = make_sdfg_1d(dace.Config.get("compiler", "fpga", "vendor"), vector_length) - assert sdfg.apply_transformations(FPGATransformSDFG) == 1 - assert sdfg.apply_transformations(InlineSDFG) == 1 - run_stencil_1d(sdfg, 32) - return sdfg - - -@intel_fpga_test() -def test_stencil_node_1d_fpga_array(): - return stencil_node_1d_fpga_array(1) - - def run_stencil_2d(sdfg, rows, cols, specialize: bool): a = np.zeros((rows, cols), dtype=DTYPE) a[1:-1, 1:-1] = np.arange(1, (rows - 2) * (cols - 2) + 1, dtype=DTYPE).reshape((rows - 2, cols - 2)) @@ -117,28 +103,6 @@ def test_stencil_node_2d(): run_stencil_2d(make_sdfg_2d("pure", 1), 16, 32, False) -def stencil_node_2d_fpga_array(vector_length: int): - sdfg = make_sdfg_2d(dace.Config.get("compiler", "fpga", "vendor"), vector_length) - sdfg.specialize({"cols": 8}) - assert sdfg.apply_transformations(FPGATransformSDFG) == 1 - assert sdfg.apply_transformations(InlineSDFG) == 1 - run_stencil_2d(sdfg, 4, 8, True) - return sdfg - - -@intel_fpga_test() -def test_stencil_node_2d_fpga_array(): - return stencil_node_2d_fpga_array(1) - - -@intel_fpga_test() -def test_stencil_node_2d_fpga_array_vectorized(): - return stencil_node_2d_fpga_array(4) - - if __name__ == "__main__": test_stencil_node_1d() - test_stencil_node_1d_fpga_array(None) test_stencil_node_2d() - test_stencil_node_2d_fpga_array(None) - test_stencil_node_2d_fpga_array_vectorized(None) diff --git a/tests/memlet_propagation_volume_test.py b/tests/memlet_propagation_volume_test.py deleted file mode 100644 index 4e46dfb120..0000000000 --- a/tests/memlet_propagation_volume_test.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from dace.memlet import Memlet -from dace.dtypes import Language, StorageType -from dace.properties import CodeProperty -from dace.sdfg import propagation -from dace.sdfg.sdfg import InterstateEdge -from dace.symbolic import symbol -import dace - -N = dace.symbol('N') -M = dace.symbol('M') - - -def memlet_check_parameters(memlet, volume, dynamic, subsets): - if memlet.volume != volume: - raise RuntimeError('Expected volume of {}, got {}'.format(volume, memlet.volume)) - elif dynamic and not memlet.dynamic: - raise RuntimeError('Expected dynamic volume, got static') - elif memlet.dynamic and not dynamic: - raise RuntimeError('Expected static volume, got dynamic') - - if len(subsets) != memlet.subset.dims(): - raise RuntimeError('Expected subset of dim {}, got {}'.format(len(subsets), memlet.subset.dims())) - - for i in range(len(subsets)): - if subsets[i] != memlet.subset.ranges[i]: - raise RuntimeError('Expected subset {} at dim {}, got {}'.format(subsets[i], i, memlet.subset.ranges[i])) - - -def state_check_executions(state, expected, expected_dynamic=False): - if state.executions != expected: - raise RuntimeError('Expected {} executions, got {}'.format(expected, state.executions)) - elif expected_dynamic and not state.dynamic_executions: - raise RuntimeError('Expected dynamic executions, got static') - elif state.dynamic_executions and not expected_dynamic: - raise RuntimeError('Expected static executions, got dynamic') - - -def make_nested_sdfg(): - sdfg = dace.SDFG('vol_propagation_nested') - - assign_loop_bound = sdfg.add_state('assign') - guard_state = sdfg.add_state('guard') - loop_state = sdfg.add_state('for') - end_state = sdfg.add_state('endfor') - - sdfg.add_edge(assign_loop_bound, guard_state, InterstateEdge(assignments={'i': '0'})) - sdfg.add_edge(guard_state, loop_state, - InterstateEdge(condition=CodeProperty.from_string('i < loop_bound', language=Language.Python))) - sdfg.add_edge(loop_state, guard_state, InterstateEdge(assignments={'i': 'i+1'})) - sdfg.add_edge(guard_state, end_state, - InterstateEdge(condition=CodeProperty.from_string('not (i < loop_bound)', language=Language.Python))) - - in_bound = assign_loop_bound.add_stream('IN_bound', dace.int32, storage=StorageType.FPGA_Local) - loop_bound = assign_loop_bound.add_scalar('loop_bound', - dace.int32, - transient=True, - storage=StorageType.FPGA_Registers) - assign_loop_bound.add_memlet_path(in_bound, loop_bound, memlet=Memlet.simple(loop_bound, '0')) - - in_a = loop_state.add_array('IN_a', [N], dace.int32, storage=StorageType.FPGA_Global) - out_stream = loop_state.add_stream('OUT_stream', dace.int32, storage=StorageType.FPGA_Local) - tasklet2 = loop_state.add_tasklet('compute', {'_IN_a'}, {'_OUT_stream'}, '_OUT_stream = _IN_a[0]') - loop_state.add_memlet_path(in_a, tasklet2, dst_conn='_IN_a', memlet=Memlet.simple(in_a, '0:N')) - loop_state.add_memlet_path(tasklet2, out_stream, src_conn='_OUT_stream', memlet=Memlet.simple(out_stream, '0')) - - return sdfg - - -def make_sdfg(): - sdfg = dace.SDFG('vol_propagation') - - sdfg.add_symbol('N', dace.int32) - sdfg.add_symbol('M', dace.int32) - - state = sdfg.add_state('main') - - a_in = state.add_array('A_in', [N], dace.int32, storage=StorageType.FPGA_Global) - bound_pipe = state.add_stream('bound_in', dace.int32, transient=True, storage=StorageType.FPGA_Local) - out_stream = state.add_stream('out_stream', dace.int32, transient=True, storage=StorageType.FPGA_Local) - - nest = state.add_nested_sdfg(make_nested_sdfg(), { - 'IN_a', - 'IN_bound', - }, { - 'OUT_stream', - }) - - state.add_memlet_path(a_in, nest, dst_conn='IN_a', memlet=Memlet.simple(a_in, '0:N')) - state.add_memlet_path(bound_pipe, nest, dst_conn='IN_bound', memlet=Memlet.simple(bound_pipe, '0', num_accesses=-1)) - state.add_memlet_path(nest, - out_stream, - src_conn='OUT_stream', - memlet=Memlet.simple(out_stream, '0', num_accesses=-1)) - - return sdfg - - -def test_memlet_volume_propagation_nsdfg(): - sdfg = make_sdfg() - propagation.propagate_memlets_sdfg(sdfg) - - main_state = sdfg.nodes()[0] - data_in_memlet = main_state.edges()[0].data - bound_stream_in_memlet = main_state.edges()[1].data - out_stream_memlet = main_state.edges()[2].data - - memlet_check_parameters(data_in_memlet, 0, True, [(0, N - 1, 1)]) - memlet_check_parameters(bound_stream_in_memlet, 1, False, [(0, 0, 1)]) - memlet_check_parameters(out_stream_memlet, 0, True, [(0, 0, 1)]) - - nested_sdfg = main_state.nodes()[3].sdfg - - loop_state = nested_sdfg.nodes()[2] - - state_check_executions(loop_state, symbol('loop_bound')) - - -def test_memlet_volume_constants(): - sdfg = dace.SDFG('cmprop') - sdfg.add_constant('N', 32) - sdfg.add_array('A', [32], dace.float64) - state = sdfg.add_state() - state.add_mapped_tasklet('doit', dict(i='0:N'), {}, 'a = i', dict(a=dace.Memlet('A[i]')), external_edges=True) - - sdfg.validate() - sink_node = next(iter(state.sink_nodes())) - edge = state.in_edges(sink_node)[0] - - assert not edge.data.dynamic - assert edge.data.volume == dace.symbol('N') - - -if __name__ == '__main__': - test_memlet_volume_propagation_nsdfg() - test_memlet_volume_constants() diff --git a/tests/npbench/deep_learning/conv2d_bias_test.py b/tests/npbench/deep_learning/conv2d_bias_test.py index 648903ffb9..5b776b84c9 100644 --- a/tests/npbench/deep_learning/conv2d_bias_test.py +++ b/tests/npbench/deep_learning/conv2d_bias_test.py @@ -5,11 +5,10 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize from dace.config import set_temporary +from dace.autodiff import add_backward_pass C_in, C_out, H, K, N, W = (dc.symbol(s, dc.int64) for s in ('C_in', 'C_out', 'H', 'K', 'N', 'W')) @@ -72,6 +71,43 @@ def conv2d_bias_np(input, weights, bias): return conv2d_np(input, weights) + bias +def conv2d_lax(jnp, lax, input, weights): + # Kernel size, number of input images, and output dimensions. + K = weights.shape[0] # Assuming square kernel of size K x K. + N = input.shape[0] # Batch size. + H_out = input.shape[1] - K + 1 # Output height. + W_out = input.shape[2] - K + 1 # Output width. + C_out = weights.shape[3] # Number of output channels. + + # Allocate output array. + output = jnp.empty((N, H_out, W_out, C_out), dtype=input.dtype) + + # Row update: iterate over output rows. + def row_update(out, i): + # Column update: iterate over output columns. + def col_update(out, j): + # Extract a patch from 'input' at the given (i, j) position. + patch = lax.dynamic_slice(input, (0, i, j, 0), (N, K, K, input.shape[-1])) + # Expand dims on the patch to broadcast with weights. + # weights: shape (K, K, in_channels, C_out) + # patch[..., None] becomes shape (N, K, K, in_channels, 1) + # We add a new leading dimension to weights to broadcast: + conv = jnp.sum(patch[..., None] * weights[None, :, :, :], axis=(1, 2, 3)) + # conv now has shape (N, C_out). Update output at (0, i, j, 0). + out = lax.dynamic_update_slice(out, conv[:, None, None, :], (0, i, j, 0)) + return out, None + + out, _ = lax.scan(col_update, out, jnp.arange(W_out)) + return out, None + + output, _ = lax.scan(row_update, output, jnp.arange(H_out)) + return output + + +def conv2d_bias_jax_kernel(jnp, lax, input, weights, bias): + return jnp.sum(conv2d_lax(jnp, lax, input, weights) + bias) + + def run_conv2d_bias(device_type: dace.dtypes.DeviceType): ''' Runs conv2d_bias for the given device @@ -87,19 +123,6 @@ def run_conv2d_bias(device_type: dace.dtypes.DeviceType): sdfg = conv2d_bias_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) out = sdfg(input, weight, bias, C_in=C_in, C_out=C_out, H=H, K=K, N=N, W=W) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = conv2d_bias_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.standard import Reduce - Reduce.default_implementation = "FPGAPartialReduction" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(C_in=C_in, C_out=C_out, H=H, K=K, N=N, W=W)) - out = sdfg(input, weight, bias) # Compute ground truth and validate out_ref = conv2d_bias_np(input, weight, bias) @@ -107,6 +130,52 @@ def run_conv2d_bias(device_type: dace.dtypes.DeviceType): return sdfg +def run_conv2d_bias_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (npbench test size) + N, C_in, C_out, K, H, W = 4, 3, 8, 2, 12, 12 + input, weights, bias = initialize(C_in, C_out, H, K, N, W) + + # Initialize gradient computation data + gradient_input = np.zeros_like(input, dtype=np.float32) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(input: dc.float32[N, H, W, C_in], weights: dc.float32[K, K, C_in, C_out], + bias: dc.float32[C_out]): + A = conv2d(input, weights) + bias + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["input"], outputs=["__return"]) + + sdfg(input, + weights, + bias, + C_in=C_in, + C_out=C_out, + H=H, + K=K, + N=N, + W=W, + gradient_input=gradient_input, + gradient___return=gradient___return) + + # Enable float32 for JAX to match DaCe consistency + jax.config.update("jax_enable_x64", False) + + # Numerically validate vs JAX + jax_kernel = lambda input, weights, bias: conv2d_bias_jax_kernel(jnp, lax, input, weights, bias) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_input = jax_grad(input, weights, bias) + np.testing.assert_allclose(gradient_input, jax_grad_input, atol=1e-6, rtol=1e-6) + + def test_cpu(): run_conv2d_bias(dace.dtypes.DeviceType.CPU) @@ -117,22 +186,22 @@ def test_gpu(): run_conv2d_bias(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_conv2d_bias(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_conv2d_bias_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_conv2d_bias(dace.dtypes.DeviceType.CPU) + run_conv2d_bias_autodiff() elif target == "gpu": run_conv2d_bias(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_conv2d_bias(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/deep_learning/lenet_test.py b/tests/npbench/deep_learning/lenet_test.py index 37cba9af9b..46dfb1aa4f 100644 --- a/tests/npbench/deep_learning/lenet_test.py +++ b/tests/npbench/deep_learning/lenet_test.py @@ -5,12 +5,8 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import temporary_config, Config -import os +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass N, H, W, C_before_fc1, S0, S1, S2, S3, S4, S5 = (dc.symbol(s, dtype=dc.int64) for s in ('N', 'H', 'W', 'C_before_fc1', 'S0', 'S1', 'S2', 'S3', 'S4', @@ -146,6 +142,74 @@ def lenet5_np(input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, return x @ fc3w + fc3b +def conv2d_lax(jnp, lax, input, weights): + # Kernel size, number of input images, and output dimensions. + K = weights.shape[0] # Assuming square kernel of size K x K. + N = input.shape[0] # Batch size. + H_out = input.shape[1] - K + 1 # Output height. + W_out = input.shape[2] - K + 1 # Output width. + C_out = weights.shape[3] # Number of output channels. + + # Allocate output array. + output = jnp.empty((N, H_out, W_out, C_out), dtype=input.dtype) + + # Row update: iterate over output rows. + def row_update(out, i): + # Column update: iterate over output columns. + def col_update(out, j): + # Extract a patch from 'input' at the given (i, j) position. + patch = lax.dynamic_slice(input, (0, i, j, 0), (N, K, K, input.shape[-1])) + # Expand dims on the patch to broadcast with weights. + # weights: shape (K, K, in_channels, C_out) + # patch[..., None] becomes shape (N, K, K, in_channels, 1) + # We add a new leading dimension to weights to broadcast: + conv = jnp.sum(patch[..., None] * weights[None, :, :, :], axis=(1, 2, 3)) + # conv now has shape (N, C_out). Update output at (0, i, j, 0). + out = lax.dynamic_update_slice(out, conv[:, None, None, :], (0, i, j, 0)) + return out, None + + out, _ = lax.scan(col_update, out, jnp.arange(W_out)) + return out, None + + output, _ = lax.scan(row_update, output, jnp.arange(H_out)) + return output + + +def maxpool2d_lax(jnp, lax, x): + output = jnp.empty([x.shape[0], x.shape[1] // 2, x.shape[2] // 2, x.shape[3]], dtype=x.dtype) + + def row_update(output, i): + + def col_update(output, j): + input_slice = lax.dynamic_slice(x, (0, 2 * i, 2 * j, 0), (x.shape[0], 2, 2, x.shape[3])) + output = lax.dynamic_update_slice(output, jnp.max(input_slice, axis=(1, 2))[:, None, None, :], (0, i, j, 0)) + return output, None + + output, _ = lax.scan(col_update, output, jnp.arange(x.shape[2] // 2)) + return output, None + + output, _ = lax.scan(row_update, output, jnp.arange(x.shape[1] // 2)) + + return output + + +def jax_relu(jnp, x): + return jnp.maximum(x, 0) + + +def lenet_jax_kernel(jnp, lax, input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b): + C_before_fc1 = fc1w.shape[0] + N = input.shape[0] + x = jax_relu(jnp, conv2d_lax(jnp, lax, input, conv1) + conv1bias) + x = maxpool2d_lax(jnp, lax, x) + x = jax_relu(jnp, conv2d_lax(jnp, lax, x, conv2) + conv2bias) + x = maxpool2d_lax(jnp, lax, x) + x = jnp.reshape(x, (N, C_before_fc1)) + x = jax_relu(jnp, x @ fc1w + fc1b) + x = jax_relu(jnp, x @ fc2w + fc2b) + return jnp.sum(x @ fc3w + fc3b) + + def run_lenet(device_type: dace.dtypes.DeviceType): ''' Runs lenet for the given device @@ -175,19 +239,6 @@ def run_lenet(device_type: dace.dtypes.DeviceType): H=H, W=W, C_before_fc1=C_before_fc1) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = lenet5_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.standard import Reduce - Reduce.default_implementation = "FPGAPartialReduction" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N, H=W, W=W, C_before_fc1=C_before_fc1)) - out = sdfg(input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b) # Compute ground truth and validate out_ref = lenet5_np(input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b, N, C_before_fc1) @@ -195,6 +246,58 @@ def run_lenet(device_type: dace.dtypes.DeviceType): return sdfg +def run_lenet_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (npbench test size) + N, H, W = 4, 16, 16 + input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b, C_before_fc1 = initialize(N, H, W) + + # Initialize gradient computation data + gradient_input = np.zeros_like(input, dtype=np.float32) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(input: dc.float32[N, H, W, 1], conv1: dc.float32[5, 5, 1, 6], conv1bias: dc.float32[6], + conv2: dc.float32[5, 5, 6, 16], conv2bias: dc.float32[16], fc1w: dc.float32[C_before_fc1, 120], + fc1b: dc.float32[120], fc2w: dc.float32[120, 84], fc2b: dc.float32[84], + fc3w: dc.float32[84, 10], fc3b: dc.float32[10]): + result = lenet5_kernel(input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b) + return np.sum(result) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["input"], outputs=["__return"]) + + sdfg(input, + conv1, + conv1bias, + conv2, + conv2bias, + fc1w, + fc1b, + fc2w, + fc2b, + fc3w, + fc3b, + N=N, + H=H, + W=W, + C_before_fc1=C_before_fc1, + gradient_input=gradient_input, + gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b: lenet_jax_kernel( + jnp, lax, input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_input = jax_grad(input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b) + np.testing.assert_allclose(gradient_input, jax_grad_input, rtol=1e-6) + + def test_cpu(monkeypatch): # Serialization causes issues, we temporarily disable it monkeypatch.setenv("DACE_testing_serialization", 0) @@ -207,23 +310,22 @@ def test_gpu(): run_lenet(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Dynamic memory allocation") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_lenet(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_lenet_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_lenet(dace.dtypes.DeviceType.CPU) + run_lenet_autodiff() elif target == "gpu": run_lenet(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_lenet(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/deep_learning/mlp_test.py b/tests/npbench/deep_learning/mlp_test.py index 9588b66f68..0dd68fc467 100644 --- a/tests/npbench/deep_learning/mlp_test.py +++ b/tests/npbench/deep_learning/mlp_test.py @@ -5,11 +5,8 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass C_in, N, S0, S1, S2, N1, N2 = (dc.symbol(s, dtype=dc.int64) for s in ('C_in', 'N', 'S0', 'S1', 'S2', 'N1', 'N2')) @@ -78,6 +75,25 @@ def mlp_np(input, w1, b1, w2, b2, w3, b3): return x +def jax_relu(jnp, x): + return jnp.maximum(x, 0) + + +# Numerically-stable version of softmax +def jax_softmax(jnp, x): + tmp_max = jnp.max(x, axis=-1, keepdims=True) + tmp_out = jnp.exp(x - tmp_max) + tmp_sum = jnp.sum(tmp_out, axis=-1, keepdims=True) + return tmp_out / tmp_sum + + +def mlp_jax_kernel(jnp, input, w1, b1, w2, b2, w3, b3): + x = jax_relu(jnp, input @ w1 + b1) + x = jax_relu(jnp, x @ w2 + b2) + x = jax_softmax(jnp, x @ w3 + b3) # Softmax call can be omitted if necessary + return jnp.sum(x) + + def run_mlp(device_type: dace.dtypes.DeviceType): ''' Runs conv2d_bias for the given device @@ -93,21 +109,6 @@ def run_mlp(device_type: dace.dtypes.DeviceType): sdfg = mlp_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) out = sdfg(input, w1, b1, w2, b2, w3, b3, N=N, S0=S0, S1=S1, S2=S2, C_in=C_in) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = mlp_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.standard import Reduce - Reduce.default_implementation = "FPGAPartialReduction" - from dace.libraries.blas import Gemm - Gemm.default_implementation = "FPGA1DSystolic" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N, S0=S0, S1=S1, S2=S2, C_in=C_in)) - out = sdfg(input, w1, b1, w2, b2, w3, b3) # Compute ground truth and validate out_ref = mlp_np(input, w1, b1, w2, b2, w3, b3) @@ -115,6 +116,53 @@ def run_mlp(device_type: dace.dtypes.DeviceType): return sdfg +def run_mlp_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (npbench test size) + C_in, N, S0, S1, S2 = 3, 8, 30, 20, 20 + input, w1, b1, w2, b2, w3, b3 = initialize(C_in, N, S0, S1, S2) + + # Initialize gradient computation data + gradient_input = np.zeros_like(input, dtype=np.float32) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(input: dc.float32[N, C_in], w1: dc.float32[C_in, S0], b1: dc.float32[S0], + w2: dc.float32[S0, S1], b2: dc.float32[S1], w3: dc.float32[S1, S2], b3: dc.float32[S2]): + x1 = relu(input @ w1 + b1) + x2 = relu(x1 @ w2 + b2) + x3 = softmax(x2 @ w3 + b3) + return np.sum(x3) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["input"], outputs=["__return"]) + + sdfg(input, + w1, + b1, + w2, + b2, + w3, + b3, + N=N, + S0=S0, + S1=S1, + S2=S2, + C_in=C_in, + gradient_input=gradient_input, + gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda input, w1, b1, w2, b2, w3, b3: mlp_jax_kernel(jnp, input, w1, b1, w2, b2, w3, b3) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_input = jax_grad(input, w1, b1, w2, b2, w3, b3) + np.testing.assert_allclose(gradient_input, jax_grad_input, rtol=1e-4, atol=1e-10) + + def test_cpu(): run_mlp(dace.dtypes.DeviceType.CPU) @@ -124,23 +172,22 @@ def test_gpu(): run_mlp(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Intel, compilation error") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_mlp(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_mlp_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_mlp(dace.dtypes.DeviceType.CPU) + run_mlp_autodiff() elif target == "gpu": run_mlp(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_mlp(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/deep_learning/resnet_test.py b/tests/npbench/deep_learning/resnet_test.py index cfe43718e0..e171c5cbe4 100644 --- a/tests/npbench/deep_learning/resnet_test.py +++ b/tests/npbench/deep_learning/resnet_test.py @@ -5,11 +5,8 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass N, H, W, C1, C2, S0, S1, S2, S3, S4, S5 = (dc.symbol(s, dtype=dc.int64) for s in ('N', 'H', 'W', 'C1', 'C2', 'S0', 'S1', 'S2', 'S3', 'S4', 'S5')) @@ -168,6 +165,65 @@ def resnet_basicblock_np(input, conv1, conv2, conv3): return relu_np(x + input) +def conv2d_lax(jnp, lax, input, weights): + # Kernel size, number of input images, and output dimensions. + K = weights.shape[0] # Assuming square kernel of size K x K. + N = input.shape[0] # Batch size. + H_out = input.shape[1] - K + 1 # Output height. + W_out = input.shape[2] - K + 1 # Output width. + C_out = weights.shape[3] # Number of output channels. + + # Allocate output array. + output = jnp.empty((N, H_out, W_out, C_out), dtype=input.dtype) + + # Row update: iterate over output rows. + def row_update(out, i): + # Column update: iterate over output columns. + def col_update(out, j): + # Extract a patch from 'input' at the given (i, j) position. + patch = lax.dynamic_slice(input, (0, i, j, 0), (N, K, K, input.shape[-1])) + # Expand dims on the patch to broadcast with weights. + # weights: shape (K, K, in_channels, C_out) + # patch[..., None] becomes shape (N, K, K, in_channels, 1) + # We add a new leading dimension to weights to broadcast: + conv = jnp.sum(patch[..., None] * weights[None, :, :, :], axis=(1, 2, 3)) + # conv now has shape (N, C_out). Update output at (0, i, j, 0). + out = lax.dynamic_update_slice(out, conv[:, None, None, :], (0, i, j, 0)) + return out, None + + out, _ = lax.scan(col_update, out, jnp.arange(W_out)) + return out, None + + output, _ = lax.scan(row_update, output, jnp.arange(H_out)) + return output + + +def jax_relu(jnp, x): + return jnp.maximum(x, 0) + + +# Batch normalization operator, as used in ResNet +def jax_batchnorm2d(jnp, x, eps=1e-5): + mean = jnp.mean(x, axis=0, keepdims=True) + std = jnp.std(x, axis=0, keepdims=True) + return (x - mean) / jnp.sqrt(std + eps) + + +def resnet_jax_kernel(jnp, lax, input, conv1, conv2, conv3): + # Pad output of first convolution for second convolution + padded = jnp.zeros((input.shape[0], input.shape[1] + 2, input.shape[2] + 2, conv1.shape[3]), dtype=input.dtype) + padded = lax.dynamic_update_slice(padded, conv2d_lax(jnp, lax, input, conv1), (0, 1, 1, 0)) + x = jax_batchnorm2d(jnp, padded) + x = jax_relu(jnp, x) + + x = conv2d_lax(jnp, lax, x, conv2) + x = jax_batchnorm2d(jnp, x) + x = jax_relu(jnp, x) + x = conv2d_lax(jnp, lax, x, conv3) + x = jax_batchnorm2d(jnp, x) + return jnp.sum(jax_relu(jnp, x + input)) + + def run_resnet(device_type: dace.dtypes.DeviceType): ''' Runs resnet for the given device @@ -183,19 +239,6 @@ def run_resnet(device_type: dace.dtypes.DeviceType): sdfg = resnet_basicblock.to_sdfg() sdfg = auto_optimize(sdfg, device_type) out = sdfg(input=input, conv1=conv1, conv2=conv2, conv3=conv3, N=N, W=W, H=H, C1=C1, C2=C2) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = resnet_basicblock.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.standard import Reduce - Reduce.default_implementation = "FPGAPartialReduction" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N, W=W, H=H, C1=C1, C2=C2)) - out = sdfg(input=input, conv1=conv1, conv2=conv2, conv3=conv3) # Compute ground truth and validate out_ref = resnet_basicblock_np(input, conv1, conv2, conv3) @@ -203,6 +246,53 @@ def run_resnet(device_type: dace.dtypes.DeviceType): return sdfg +def run_resnet_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (npbench test size) + N, W, H, C1, C2 = 2, 8, 8, 8, 4 + input, conv1, conv2, conv3 = initialize(N, W, H, C1, C2) + + # Initialize gradient computation data + gradient_input = np.zeros_like(input, dtype=np.float32) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(input: dc.float32[N, H, W, C1], conv1: dc.float32[1, 1, C1, C2], + conv2: dc.float32[3, 3, C2, C2], conv3: dc.float32[1, 1, C2, C1]): + # Pad output of first convolution for second convolution + x = resnet_basicblock(input, conv1, conv2, conv3) + return np.sum(x) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg(simplify=True) + add_backward_pass(sdfg=sdfg, inputs=["input"], outputs=["__return"]) + + sdfg(input, + conv1, + conv2, + conv3, + N=N, + W=W, + H=H, + C1=C1, + C2=C2, + gradient_input=gradient_input, + gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda input, conv1, conv2, conv3: resnet_jax_kernel(jnp, lax, input, conv1, conv2, conv3) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_input = jax_grad(input, conv1, conv2, conv3) + + # The tolerance if fairly high with float32 inputs + # The same code using float64 works with a 1e-12 tolerance + np.testing.assert_allclose(gradient_input, jax_grad_input, atol=1e-2, rtol=1e-2) + + def test_cpu(): run_resnet(dace.dtypes.DeviceType.CPU) @@ -213,23 +303,22 @@ def test_gpu(): run_resnet(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Dynamic memory allocation") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_resnet(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_resnet_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_resnet(dace.dtypes.DeviceType.CPU) + run_resnet_autodiff() elif target == "gpu": run_resnet(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_resnet(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/deep_learning/softmax_test.py b/tests/npbench/deep_learning/softmax_test.py index 4408645adc..aaf460c9f1 100644 --- a/tests/npbench/deep_learning/softmax_test.py +++ b/tests/npbench/deep_learning/softmax_test.py @@ -5,11 +5,8 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass N, H, SM = (dc.symbol(s, dc.int64) for s in ('N', 'H', 'SM')) @@ -37,6 +34,13 @@ def ground_truth(x): return tmp_out / tmp_sum +def softmax_jax_kernel(jnp, x): + tmp_max = jnp.max(x, axis=-1, keepdims=True) + tmp_out = jnp.exp(x - tmp_max) + tmp_sum = jnp.sum(tmp_out, axis=-1, keepdims=True) + return jnp.sum(tmp_out / tmp_sum) + + def run_softmax(device_type: dace.dtypes.DeviceType): ''' Runs Softmax for the given device @@ -52,19 +56,6 @@ def run_softmax(device_type: dace.dtypes.DeviceType): sdfg = softmax_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) out = sdfg(x, N=N, H=H, SM=SM) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = softmax_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.standard import Reduce - Reduce.default_implementation = "FPGAPartialReduction" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N, H=H, SM=SM)) - out = sdfg(x) # Compute ground truth and validate out_ref = ground_truth(x) @@ -72,6 +63,36 @@ def run_softmax(device_type: dace.dtypes.DeviceType): return sdfg +def run_softmax_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (npbench test size) + N, H, SM = 4, 4, 32 + x = initialize(N, H, SM) + out = np.zeros_like(x) + + # Initialize gradient computation data + gradient_x = np.zeros_like(x) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def softmax_autodiff_kernel(x: dc.float32[N, H, SM, SM]): + return np.sum(softmax_kernel(x)) + + # Add the backward pass to the SDFG + sdfg = softmax_autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["x"], outputs=["__return"]) + sdfg(x, out, N=N, H=H, SM=SM, gradient_x=gradient_x, gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda x: softmax_jax_kernel(jnp, x) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_x = jax_grad(x) + np.testing.assert_allclose(gradient_x, jax_grad_x, atol=1e-6) + + def test_cpu(): run_softmax(dace.dtypes.DeviceType.CPU) @@ -81,22 +102,22 @@ def test_gpu(): run_softmax(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_softmax(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_softmax_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_softmax(dace.dtypes.DeviceType.CPU) + run_softmax_autodiff() elif target == "gpu": run_softmax(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_softmax(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/arc_distance_test.py b/tests/npbench/misc/arc_distance_test.py index 571c0dd772..42b125e608 100644 --- a/tests/npbench/misc/arc_distance_test.py +++ b/tests/npbench/misc/arc_distance_test.py @@ -6,8 +6,6 @@ import pytest import argparse from dace.transformation.auto.auto_optimize import auto_optimize -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG N = dace.symbol('N', dtype=dace.int64) @@ -50,20 +48,8 @@ def run_arc_distance(device_type: dace.dtypes.DeviceType): sdfg = arc_distance.to_sdfg() sdfg = auto_optimize(sdfg, device_type) val = sdfg(theta_1=t0, phi_1=p0, theta_2=t1, phi_2=p1, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = arc_distance.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N, )) - val = sdfg( - theta_1=t0, - phi_1=p0, - theta_2=t1, - phi_2=p1, - ) + else: + raise ValueError(f'Unsupported device type: {device_type}') # Compute ground truth and Validate result ref = arc_distance.f(t0, p0, t1, p1) @@ -80,17 +66,10 @@ def test_gpu(): run_arc_distance(dace.dtypes.DeviceType.GPU) -# TODO: Investigate and re-enable if possible. -@pytest.mark.skip(reason="Unexplained CI Regression") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_arc_distance(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -99,5 +78,3 @@ def test_fpga(): run_arc_distance(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_arc_distance(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_arc_distance(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/azimint_hist_test.py b/tests/npbench/misc/azimint_hist_test.py index c6ea9a620d..ed010ceb9b 100644 --- a/tests/npbench/misc/azimint_hist_test.py +++ b/tests/npbench/misc/azimint_hist_test.py @@ -6,8 +6,6 @@ import pytest import argparse from dace.transformation.auto.auto_optimize import auto_optimize -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG N, bins, npt = (dace.symbol(s, dtype=dace.int64) for s in ('N', 'bins', 'npt')) @@ -100,15 +98,8 @@ def run_azimint_hist(device_type: dace.dtypes.DeviceType): sdfg = dace_azimint_hist.to_sdfg() sdfg = auto_optimize(sdfg, device_type) val = sdfg(data=data, radius=radius, N=N, npt=npt) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = dace_azimint_hist.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N, npt=npt)) - val = sdfg(data=data, radius=radius) + else: + raise ValueError(f'Unsupported device type: {device_type}') # Compute ground truth and Validate result ref = numpy_azimint_hist(data, radius, npt) @@ -129,16 +120,10 @@ def test_gpu(): run_azimint_hist(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="FPGA Transform error") -@fpga_test(assert_ii_1=False) -def test_fpga(): - run_azimint_hist(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -147,5 +132,3 @@ def test_fpga(): run_azimint_hist(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_azimint_hist(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_azimint_hist(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/azimint_naive_test.py b/tests/npbench/misc/azimint_naive_test.py index 930571498a..b8b165b41a 100644 --- a/tests/npbench/misc/azimint_naive_test.py +++ b/tests/npbench/misc/azimint_naive_test.py @@ -6,8 +6,6 @@ import pytest import argparse from dace.transformation.auto.auto_optimize import auto_optimize -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG N, npt = (dace.symbol(s, dtype=dace.int64) for s in ('N', 'npt')) @@ -70,19 +68,8 @@ def run_azimint_naive(device_type: dace.dtypes.DeviceType): sdfg = dace_azimint_naive.to_sdfg() sdfg = auto_optimize(sdfg, device_type) val = sdfg(data=data, radius=radius, N=N, npt=npt) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = dace_azimint_naive.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - from dace.libraries.standard import Reduce - Reduce.default_implementation = "FPGAPartialReduction" - sdfg.expand_library_nodes() - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N, npt=npt)) - val = sdfg(data=data, radius=radius) + else: + raise ValueError(f'Unsupported device type: {device_type}') # Compute ground truth and Validate result ref = numpy_azimint_naive(data, radius, npt) @@ -99,16 +86,10 @@ def test_gpu(): run_azimint_naive(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Incorrect output") -@fpga_test(assert_ii_1=False) -def test_fpga(): - run_azimint_naive(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -117,5 +98,3 @@ def test_fpga(): run_azimint_naive(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_azimint_naive(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_azimint_naive(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/cavity_flow_test.py b/tests/npbench/misc/cavity_flow_test.py index d2e4d50f13..5b59a35c14 100644 --- a/tests/npbench/misc/cavity_flow_test.py +++ b/tests/npbench/misc/cavity_flow_test.py @@ -5,9 +5,8 @@ import dace import pytest import argparse -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass nx, ny, nit = (dace.symbol(s, dace.int64) for s in ('nx', 'ny', 'nit')) @@ -146,6 +145,67 @@ def initialize(ny, nx): return u, v, p, dx, dy, dt +def jax_build_up_b(b, rho, dt, u, v, dx, dy): + b = b.at[1:-1, 1:-1].set( + (rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) + (v[2:, 1:-1] - v[0:-2, 1:-1]) / + (2 * dy)) - ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 * + ((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) * (v[1:-1, 2:] - v[1:-1, 0:-2]) / + (2 * dx)) - ((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2))) + return b + + +def jax_pressure_poisson(jnp, nit, p, dx, dy, b): + pn = jnp.empty_like(p) + pn = p.copy() + + for q in range(nit): + pn = p.copy() + p = p.at[1:-1, 1:-1].set((((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 + (pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) / + (2 * (dx**2 + dy**2)) - dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, 1:-1])) + + p = p.at[:, -1].set(p[:, -2]) # dp/dx = 0 at x = 2 + p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0 + p = p.at[:, 0].set(p[:, 1]) # dp/dx = 0 at x = 0 + p = p.at[-1, :].set(0) # p = 0 at y = 2 + return p + + +def cavity_flow_jax_kernel(jnp, nx, ny, nt, nit, u, v, dt, dx, dy, p, rho, nu): + un = jnp.empty_like(u) + vn = jnp.empty_like(v) + b = jnp.zeros((ny, nx)) + + for n in range(nt): + un = u.copy() + vn = v.copy() + + b = jax_build_up_b(b, rho, dt, u, v, dx, dy) + p = jax_pressure_poisson(jnp, nit, p, dx, dy, b) + + u = u.at[1:-1, 1:-1].set( + (un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * (un[1:-1, 1:-1] - un[1:-1, 0:-2]) - vn[1:-1, 1:-1] * dt / dy * + (un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) * (p[1:-1, 2:] - p[1:-1, 0:-2]) + nu * + (dt / dx**2 * (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) + dt / dy**2 * + (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1])))) + + v = v.at[1:-1, 1:-1].set( + (vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * (vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) - vn[1:-1, 1:-1] * dt / dy * + (vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) * (p[2:, 1:-1] - p[0:-2, 1:-1]) + nu * + (dt / dx**2 * (vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) + dt / dy**2 * + (vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1])))) + + u = u.at[0, :].set(0) + u = u.at[:, 0].set(0) + u = u.at[:, -1].set(0) + u = u.at[-1, :].set(1) + v = v.at[0, :].set(0) + v = v.at[-1, :].set(0) + v = v.at[:, 0].set(0) + v = v.at[:, -1].set(0) + + return jnp.sum(v) + + def run_cavity_flow(device_type: dace.dtypes.DeviceType): ''' Runs cavity-flow for the given device @@ -162,19 +222,8 @@ def run_cavity_flow(device_type: dace.dtypes.DeviceType): sdfg = dace_cavity_flow.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(nt=nt, nit=nit, u=dace_u, v=dace_v, dt=dt, dx=dx, dy=dy, p=dace_p, rho=rho, nu=nu, ny=ny, nx=nx) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = dace_cavity_flow.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - ########################### - # FPGA Auto Opt - fpga_auto_opt.fpga_global_to_local(sdfg) - - sdfg.specialize(dict(nx=nx, ny=ny)) - sdfg(nt=nt, u=dace_u, v=dace_v, dt=dt, dx=dx, dy=dy, p=dace_p, rho=rho, nu=nu, nit=nit) + else: + raise ValueError(f'Unsupported device type: {device_type}') # Compute ground truth and Validate result numpy_cavity_flow(nx, ny, nt, nit, u, v, dt, dx, dy, p, rho, nu) @@ -184,6 +233,58 @@ def run_cavity_flow(device_type: dace.dtypes.DeviceType): return sdfg +def run_cavity_flow_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (test size from benchmark) + ny, nx, nt, nit, rho, nu = (4, 4, 4, 5, 1.0, 0.1) + u, v, p, dx, dy, dt = initialize(ny, nx) + jax_u, jax_v, jax_p = u.copy(), v.copy(), p.copy() + + # Initialize gradient computation data + gradient_u = np.zeros_like(u) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define cavity flow kernel based on benchmark with __return pattern + @dace.program + def autodiff_kernel(nt: dace.int64, nit: dace.int64, u: dace.float64[ny, nx], v: dace.float64[ny, nx], + dt: dace.float64, dx: dace.float64, dy: dace.float64, p: dace.float64[ny, nx], + rho: dace.float64, nu: dace.float64): + + dace_cavity_flow(nt, nit, u, v, dt, dx, dy, p, rho, nu) + return np.sum(v) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg(simplify=True) + add_backward_pass(sdfg=sdfg, inputs=["u"], outputs=["__return"]) + sdfg(nt, + nit, + u, + v, + dt, + dx, + dy, + p, + rho, + nu, + ny=ny, + nx=nx, + nit=nit, + gradient_u=gradient_u, + gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda nx, ny, nt, nit, u, v, dt, dx, dy, p, rho, nu: cavity_flow_jax_kernel( + jnp, nx, ny, nt, nit, u, v, dt, dx, dy, p, rho, nu) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=4), static_argnums=(0, 1, 2, 3)) + jax_grad_u = jax_grad(nx, ny, nt, nit, jax_u, jax_v, dt, dx, dy, jax_p, rho, nu) + np.testing.assert_allclose(gradient_u, jax_grad_u, rtol=1e-6, atol=1e-10) + + def test_cpu(): run_cavity_flow(dace.dtypes.DeviceType.CPU) @@ -193,23 +294,22 @@ def test_gpu(): run_cavity_flow(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Intel FPGA kernel arguments") -@fpga_test(assert_ii_1=False) -def test_fpga(): - run_cavity_flow(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_cavity_flow_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_cavity_flow(dace.dtypes.DeviceType.CPU) + run_cavity_flow_autodiff() elif target == "gpu": run_cavity_flow(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_cavity_flow(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/channel_flow_test.py b/tests/npbench/misc/channel_flow_test.py index c5768cd11f..4dea47b563 100644 --- a/tests/npbench/misc/channel_flow_test.py +++ b/tests/npbench/misc/channel_flow_test.py @@ -6,8 +6,6 @@ import pytest import argparse from dace.transformation.auto.auto_optimize import auto_optimize -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG nx, ny, nit = (dace.symbol(s, dace.int64) for s in ('nx', 'ny', 'nit')) @@ -253,34 +251,8 @@ def run_channel_flow(device_type: dace.dtypes.DeviceType): sdfg = dace_channel_flow.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(nit=nit, u=dace_u, v=dace_v, dt=dt, dx=dx, dy=dy, p=dace_p, rho=rho, nu=nu, F=F, ny=ny, nx=nx) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = dace_channel_flow.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - from dace.libraries.standard import Reduce - Reduce.default_implementation = "FPGAPartialReduction" - sdfg.expand_library_nodes() - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - ########################### - # FPGA Auto Opt - # fpga_auto_opt.fpga_global_to_local(sdfg) - - sdfg.specialize(dict(nx=nx, ny=ny)) - sdfg( - nit=nit, - u=dace_u, - v=dace_v, - dt=dt, - dx=dx, - dy=dy, - p=dace_p, - rho=rho, - nu=nu, - F=F, - ) + else: + raise ValueError(f'Unsupported device type: {device_type}') # Compute ground truth and Validate result numpy_channel_flow(nit, u, v, dt, dx, dy, p, rho, nu, F) @@ -299,16 +271,10 @@ def test_gpu(): run_channel_flow(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Compiler error after codegen") -@fpga_test(assert_ii_1=False) -def test_fpga(): - run_channel_flow(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -317,5 +283,3 @@ def test_fpga(): run_channel_flow(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_channel_flow(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_channel_flow(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/compute_test.py b/tests/npbench/misc/compute_test.py index feb5585ca5..9bc220ddbf 100644 --- a/tests/npbench/misc/compute_test.py +++ b/tests/npbench/misc/compute_test.py @@ -6,8 +6,7 @@ import pytest import argparse from dace.transformation.auto.auto_optimize import auto_optimize -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.autodiff import add_backward_pass def relerror(val, ref): @@ -50,20 +49,8 @@ def run_compute(device_type: dace.dtypes.DeviceType): sdfg = compute.to_sdfg() sdfg = auto_optimize(sdfg, device_type) val = sdfg(array_1=array_1, array_2=array_2, a=a, b=b, c=c, M=M, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = compute.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - from dace.libraries.standard import Reduce - Reduce.default_implementation = "FPGAPartialReduction" - sdfg.expand_library_nodes() - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - - sdfg.specialize(dict(M=M, N=N)) - val = sdfg(array_1=array_1, array_2=array_2, a=a, b=b, c=c) + else: + raise ValueError(f'Unsupported device type: {device_type}') # Compute ground truth and Validate result ref = compute.f(array_1, array_2, a, b, c) @@ -80,16 +67,10 @@ def test_gpu(): run_compute(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Compiler error") -@fpga_test(assert_ii_1=False) -def test_fpga(): - run_compute(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -98,5 +79,3 @@ def test_fpga(): run_compute(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_compute(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_compute(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/contour_integral_test.py b/tests/npbench/misc/contour_integral_test.py index 02e1629b76..8dbe791187 100644 --- a/tests/npbench/misc/contour_integral_test.py +++ b/tests/npbench/misc/contour_integral_test.py @@ -6,8 +6,6 @@ import pytest import argparse from dace.transformation.auto.auto_optimize import auto_optimize -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG def relerror(val, ref): @@ -87,20 +85,8 @@ def run_contour_integral(device_type: dace.dtypes.DeviceType): sdfg = dace_contour_integral.to_sdfg() sdfg = auto_optimize(sdfg, device_type) val0, val1 = sdfg(Ham=Ham, int_pts=int_pts, Y=Y, NR=NR, NM=NM, slab_per_bc=slab_per_bc) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = dace_contour_integral.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - from dace.libraries.standard import Reduce - Reduce.default_implementation = "FPGAPartialReduction" - sdfg.expand_library_nodes() - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - - sdfg.specialize(dict(NR=NR, NM=NM, slab_per_bc=slab_per_bc)) - val0, val1 = sdfg(Ham=Ham, int_pts=int_pts, Y=Y) + else: + raise ValueError(f'Unsupported device type: {device_type}') # Compute ground truth and Validate result ref0, ref1 = numpy_contour_integral(NR, NM, slab_per_bc, Ham, int_pts, Y) @@ -119,16 +105,10 @@ def test_gpu(): run_contour_integral(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Missing FPGA friendly expansions for solve, getrf, getrs") -@fpga_test(assert_ii_1=False) -def test_fpga(): - run_contour_integral(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -137,5 +117,3 @@ def test_fpga(): run_contour_integral(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_contour_integral(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_contour_integral(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/crc16_test.py b/tests/npbench/misc/crc16_test.py index 3a7e1ae7cf..4b5d7b73e3 100644 --- a/tests/npbench/misc/crc16_test.py +++ b/tests/npbench/misc/crc16_test.py @@ -5,11 +5,7 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.auto.auto_optimize import auto_optimize N = dace.symbol('N') poly: dace.uint16 = 0x8408 @@ -77,15 +73,8 @@ def run_crc16(device_type: dace.dtypes.DeviceType): sdfg = crc16_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) out = sdfg(data, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = crc16_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N)) - out = sdfg(data) + else: + raise ValueError(f'Unsupported device type: {device_type}') # Compute ground truth and validate out_ref = ground_truth(data) @@ -102,16 +91,10 @@ def test_gpu(): run_crc16(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Operand type in binary expressions") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_crc16(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -120,5 +103,3 @@ def test_fpga(): run_crc16(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_crc16(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_crc16(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/go_fast_test.py b/tests/npbench/misc/go_fast_test.py index bf686b5b15..eed1d89873 100644 --- a/tests/npbench/misc/go_fast_test.py +++ b/tests/npbench/misc/go_fast_test.py @@ -5,11 +5,8 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass N = dc.symbol('N', dtype=dc.int64) @@ -50,15 +47,8 @@ def run_go_fast(device_type: dace.dtypes.DeviceType): sdfg = go_fast_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) out = sdfg(a, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = go_fast_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N)) - out = sdfg(a) + else: + raise ValueError(f'Unsupported device type: {device_type}') # Compute ground truth and validate out_ref = ground_truth(a) @@ -66,6 +56,51 @@ def run_go_fast(device_type: dace.dtypes.DeviceType): return sdfg +def go_fast_jax_kernel(jnp, lax, a): + + def body_fn(trace, i): + # Update the trace by adding tanh(a[i, i]) + new_trace = trace + jnp.tanh(a[i, i]) + return new_trace, None # Return a dummy output. + + trace, _ = lax.scan(body_fn, 0.0, jnp.arange(a.shape[0])) + return jnp.sum(a + trace) + + +def run_go_fast_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize forward data (using smaller size for AD test) + N = 20 + a = initialize(N) + + # Initialize gradient computation data + gradient_a = np.zeros_like(a) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(a: dc.float64[N, N]): + result = go_fast_kernel(a) + return np.sum(result) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["a"], outputs=["__return"]) + sdfg(a, N=N, gradient_a=gradient_a, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda a: go_fast_jax_kernel(jnp, lax, a) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_a = jax_grad(a) + np.testing.assert_allclose(gradient_a, jax_grad_a) + + def test_cpu(): run_go_fast(dace.dtypes.DeviceType.CPU) @@ -75,23 +110,22 @@ def test_gpu(): run_go_fast(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Operand type in binary expressions") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_go_fast(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_go_fast_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_go_fast(dace.dtypes.DeviceType.CPU) + run_go_fast_autodiff() elif target == "gpu": run_go_fast(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_go_fast(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/mandelbrot1_test.py b/tests/npbench/misc/mandelbrot1_test.py index 6fb8262aa2..b79dba9498 100644 --- a/tests/npbench/misc/mandelbrot1_test.py +++ b/tests/npbench/misc/mandelbrot1_test.py @@ -5,11 +5,7 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.auto.auto_optimize import auto_optimize XN, YN, N = (dc.symbol(s, dtype=dc.int64) for s in ['XN', 'YN', 'N']) @@ -86,15 +82,8 @@ def run_mandelbrot1(device_type: dace.dtypes.DeviceType): sdfg = mandelbrot_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) Z, N = sdfg(xmin, xmax, ymin, ymax, maxiter, horizon, XN=XN, YN=YN) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = mandelbrot_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(XN=XN, YN=YN)) - Z, N = sdfg(xmin, xmax, ymin, ymax, maxiter, horizon) + else: + raise ValueError(f'Unsupported device type: {device_type}') # Compute ground truth and validate Z_ref, N_ref = ground_truth(xmin, xmax, ymin, ymax, XN, YN, maxiter) @@ -114,16 +103,10 @@ def test_gpu(): run_mandelbrot1(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Parsing error (see issue #1139)") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_mandelbrot1(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -132,5 +115,3 @@ def test_fpga(): run_mandelbrot1(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_mandelbrot1(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_mandelbrot1(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/mandelbrot2_test.py b/tests/npbench/misc/mandelbrot2_test.py index 1be84d0c1c..fe492adc01 100644 --- a/tests/npbench/misc/mandelbrot2_test.py +++ b/tests/npbench/misc/mandelbrot2_test.py @@ -5,11 +5,7 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.auto.auto_optimize import auto_optimize XN, YN, N = (dc.symbol(s, dtype=dc.int64) for s in ['XN', 'YN', 'N']) @@ -134,15 +130,8 @@ def run_mandelbrot2(device_type: dace.dtypes.DeviceType): sdfg = mandelbrot_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) Z, N = sdfg(xmin, xmax, ymin, ymax, maxiter, horizon, XN=XN, YN=YN) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = mandelbrot_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(XN=XN, YN=YN)) - Z, N = sdfg(xmin, xmax, ymin, ymax, maxiter, horizon) + else: + raise ValueError(f'Unsupported device type: {device_type}') # Compute ground truth and validate Z_ref, N_ref = ground_truth(xmin, xmax, ymin, ymax, XN, YN, maxiter) @@ -162,16 +151,10 @@ def test_gpu(): run_mandelbrot2(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Parsing error (see issue #1139)") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_mandelbrot2(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -180,5 +163,3 @@ def test_fpga(): run_mandelbrot2(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_mandelbrot2(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_mandelbrot2(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/nbody_test.py b/tests/npbench/misc/nbody_test.py index 24505465e5..f413dba599 100644 --- a/tests/npbench/misc/nbody_test.py +++ b/tests/npbench/misc/nbody_test.py @@ -5,11 +5,7 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.auto.auto_optimize import auto_optimize N, Nt = (dc.symbol(s, dtype=dc.int64) for s in ('N', 'Nt')) @@ -277,21 +273,8 @@ def run_nbody(device_type: dace.dtypes.DeviceType): sdfg = nbody.to_sdfg() sdfg = auto_optimize(sdfg, device_type) KE, PE = sdfg(mass, pos, vel, dt, G, softening, N=N, Nt=Nt) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = nbody.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - from dace.libraries.standard import Reduce - Reduce.default_implementation = "FPGAPartialReduction" - from dace.libraries.blas import Gemv - Gemv.default_implementation = "FPGA_Accumulate" - sdfg.expand_library_nodes() - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N, Nt=Nt)) - KE, PE = sdfg(mass, pos, vel, dt, G, softening) + else: + raise ValueError(f"Unsupported device type: {device_type}") # Compute ground truth and validate KE_ref, PE_ref = nbody_np(mass_ref, pos_ref, vel_ref, N, Nt, dt, G, softening) @@ -310,16 +293,10 @@ def test_gpu(): run_nbody(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Xilinx validation error, Intel argument overflow") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_nbody(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -328,5 +305,3 @@ def test_fpga(): run_nbody(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_nbody(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_nbody(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/scattering_self_test.py b/tests/npbench/misc/scattering_self_test.py index 5b9a5ade62..6b973f1c63 100644 --- a/tests/npbench/misc/scattering_self_test.py +++ b/tests/npbench/misc/scattering_self_test.py @@ -5,11 +5,7 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.auto.auto_optimize import auto_optimize NA, NB, Nkz, NE, Nqz, Nw, Norb, N3D = (dc.symbol(s, dc.int64) for s in ('NA', 'NB', 'Nkz', 'NE', 'Nqz', 'Nw', 'Norb', 'N3D')) @@ -92,19 +88,8 @@ def run_scattering_self_test(device_type: dace.dtypes.DeviceType): sdfg = scattering_self_energies_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(neigh_idx, dH, G, D, Sigma, Nkz=Nkz, NE=NE, Nqz=Nqz, N3D=N3D, NA=NA, NB=NB, Norb=Norb, Nw=Nw) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = scattering_self_energies_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - from dace.libraries.blas import Gemm - Gemm.default_implementation = "FPGA1DSystolic" - sdfg.expand_library_nodes() - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(Nkz=Nkz, NE=NE, Nqz=Nqz, N3D=N3D, NA=NA, NB=NB, Norb=Norb, Nw=Nw)) - sdfg(neigh_idx, dH, G, D, Sigma) + else: + raise ValueError(f'Unsupported device type: {device_type}') # Compute ground truth and validate ground_truth(neigh_idx, dH, G, D, Sigma_ref) @@ -121,16 +106,10 @@ def test_gpu(): run_scattering_self_test(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Compiler error") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_scattering_self_test(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -139,5 +118,3 @@ def test_fpga(): run_scattering_self_test(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_scattering_self_test(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_scattering_self_test(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/spmv_test.py b/tests/npbench/misc/spmv_test.py index 0771ff5198..b657879352 100644 --- a/tests/npbench/misc/spmv_test.py +++ b/tests/npbench/misc/spmv_test.py @@ -6,8 +6,6 @@ import pytest import argparse from dace.transformation.auto.auto_optimize import auto_optimize -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG M, N, nnz = (dc.symbol(s, dtype=dc.int64) for s in ('M', 'N', 'nnz')) @@ -72,21 +70,8 @@ def run_spmv(device_type: dace.dtypes.DeviceType): sdfg = spmv_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) y = sdfg(A_rows, A_cols, np.copy(A_vals), x, M=M, N=N, nnz=nnz) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = spmv_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Dot - Dot.default_implementation = "FPGA_PartialSums" - sdfg.expand_library_nodes() - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - - sdfg.specialize(dict(M=M, N=N, nnz=nnz)) - y = sdfg(A_rows, A_cols, np.copy(A_vals), x) + else: + raise ValueError(f'Unsupported device type: {device_type}') # Compute ground truth and Validate result y_ref = ground_truth(A_rows, A_cols, A_vals, x) @@ -103,16 +88,10 @@ def test_gpu(): run_spmv(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Missing free symbol") -@fpga_test(assert_ii_1=False) -def test_fpga(): - run_spmv(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -121,5 +100,3 @@ def test_fpga(): run_spmv(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_spmv(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_spmv(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/misc/stockham_fft_test.py b/tests/npbench/misc/stockham_fft_test.py index 5878cf621a..746c555cb1 100644 --- a/tests/npbench/misc/stockham_fft_test.py +++ b/tests/npbench/misc/stockham_fft_test.py @@ -6,8 +6,6 @@ import pytest import argparse from dace.transformation.auto.auto_optimize import auto_optimize -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG R, K, M1, M2 = (dc.symbol(s, dtype=dc.int64, integer=True, positive=True) for s in ('R', 'K', 'M1', 'M2')) N = R**K @@ -133,21 +131,8 @@ def run_stockham_fft(device_type: dace.dtypes.DeviceType): sdfg = auto_optimize(sdfg, device_type) sdfg.expand_library_nodes() sdfg(x=x, y=y, N=N, R=R, K=K) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = stockham_fft_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Gemm - Gemm.default_implementation = "FPGA1DSystolic" - sdfg.expand_library_nodes() - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - - sdfg.specialize(dict(N=N, R=R, K=K)) - sdfg(x=x, y=y) + else: + raise ValueError(f'Unsupported device type: {device_type}') # Compute ground truth and Validate result ground_truth(N, R, K, x, y_ref) @@ -166,16 +151,10 @@ def test_gpu(): run_stockham_fft(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Missing free symbol") -@fpga_test(assert_ii_1=False) -def test_fpga(): - run_stockham_fft(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -184,5 +163,3 @@ def test_fpga(): run_stockham_fft(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_stockham_fft(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_stockham_fft(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/adi_test.py b/tests/npbench/polybench/adi_test.py index 9cccf7da0f..d76e70b4f6 100644 --- a/tests/npbench/polybench/adi_test.py +++ b/tests/npbench/polybench/adi_test.py @@ -5,9 +5,8 @@ import dace import pytest import argparse -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Dataset sizes # TSTEPS, N @@ -43,7 +42,7 @@ def numpy_kernel(TSTEPS, N, u): e = 1.0 + mul2 f = d - for t in range(1, TSTEPS + 1): + for t in range(0, TSTEPS): v[0, 1:N - 1] = 1.0 p[1:N - 1, 0] = 0.0 q[1:N - 1, 0] = v[0, 1:N - 1] @@ -89,7 +88,7 @@ def adi_kernel(TSTEPS: dace.int64, u: dace.float64[N, N]): e = 1.0 + mul2 f = d - for t in range(1, TSTEPS + 1): + for t in range(0, TSTEPS): v[0, 1:N - 1] = 1.0 p[1:N - 1, 0] = 0.0 q[1:N - 1, 0] = v[0, 1:N - 1] @@ -118,6 +117,80 @@ def initialize(N, datatype=np.float64): return u +def adi_jax_kernel(jnp, lax, TSTEPS, u): + N = u.shape[0] + v = jnp.zeros_like(u) + p = jnp.zeros_like(u) + q = jnp.zeros_like(u) + + DX = 1.0 / N + DY = 1.0 / N + DT = 1.0 / TSTEPS + B1 = 2.0 + B2 = 1.0 + mul1 = B1 * DT / (DX * DX) + mul2 = B2 * DT / (DY * DY) + a = -mul1 / 2.0 + b = 1.0 + mul2 + c = a + d = -mul2 / 2.0 + e = 1.0 + mul2 + f = d + + def first_j_scan(carry, j): + p, q, u = carry + + p = p.at[1:N - 1, j].set(-c / (a * p[1:N - 1, j - 1] + b)) + q = q.at[1:N - 1, + j].set((-d * u[j, 0:N - 2] + (1.0 + 2.0 * d) * u[j, 1:N - 1] - f * u[j, 2:N] - a * q[1:N - 1, j - 1]) / + (a * p[1:N - 1, j - 1] + b)) + return (p, q, u), None + + def first_backward_j_scan(carry, j): + v, p, q = carry + idx = N - 2 - j # reverse order index: when j=0, idx = N-2; when j=N-2, idx = 0. + v = v.at[idx, 1:N - 1].set(p[1:N - 1, idx] * v[idx + 1, 1:N - 1] + q[1:N - 1, idx]) + return (v, p, q), None + + def second_j_scan(carry, j): + p, q, v = carry + p = p.at[1:N - 1, j].set(-f / (d * p[1:N - 1, j - 1] + e)) + q = q.at[1:N - 1, + j].set((-a * v[0:N - 2, j] + (1.0 + 2.0 * a) * v[1:N - 1, j] - c * v[2:N, j] - d * q[1:N - 1, j - 1]) / + (d * p[1:N - 1, j - 1] + e)) + return (p, q, v), None + + def second_backward_j_scan(carry, j): + u, p, q = carry + idx = N - 2 - j + u = u.at[1:N - 1, idx].set(p[1:N - 1, idx] * u[1:N - 1, idx + 1] + q[1:N - 1, idx]) + return (u, p, q), None + + def time_step_body(carry, t): + u, v, p, q = carry + + v = v.at[0, 1:N - 1].set(1.0) + p = p.at[1:N - 1, 0].set(0.0) + q = q.at[1:N - 1, 0].set(v[0, 1:N - 1]) + (p, q, u), _ = lax.scan(first_j_scan, (p, q, u), jnp.arange(1, N - 1)) + + v = v.at[N - 1, 1:N - 1].set(1.0) + + (v, p, q), _ = lax.scan(first_backward_j_scan, (v, p, q), jnp.arange(0, N - 2)) + + u = u.at[1:N - 1, 0].set(1.0) + p = p.at[1:N - 1, 0].set(0.0) + q = q.at[1:N - 1, 0].set(u[1:N - 1, 0]) + (p, q, v), _ = lax.scan(second_j_scan, (p, q, v), jnp.arange(1, N - 1)) + u = u.at[1:N - 1, N - 1].set(1.0) + (u, p, q), _ = lax.scan(second_backward_j_scan, (u, p, q), jnp.arange(0, N - 2)) + + return (u, v, p, q), None + + (u, v, p, q), _ = lax.scan(time_step_body, (u, v, p, q), jnp.arange(0, TSTEPS)) + return jnp.sum(u) + + def run_adi(device_type: dace.dtypes.DeviceType): ''' Runs ADI for the given device @@ -134,21 +207,8 @@ def run_adi(device_type: dace.dtypes.DeviceType): sdfg = adi_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(TSTEPS=TSTEPS, u=dace_u, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = adi_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Gemm - Gemm.default_implementation = "FPGA1DSystolic" - sdfg.expand_library_nodes() - # In this case, we want to generate the top-level state as an host-based state, - # not an FPGA kernel. We need to explicitly indicate that - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N)) - sdfg(TSTEPS=TSTEPS, u=dace_u) + else: + raise ValueError(f"Unsupported device type: {device_type}") # Compute ground truth and Validate result numpy_kernel(TSTEPS, N, u) @@ -156,6 +216,46 @@ def run_adi(device_type: dace.dtypes.DeviceType): return sdfg +def run_adi_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size for smaller problem) + _, N = sizes["mini"] + + # Use smaller number of timesteps to avoid exploding gradients + TSTEPS = 10 + + u = initialize(N) + + # Initialize gradient computation data + gradient_u = np.zeros_like(u) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dace.program + def autodiff_kernel(TSTEPS: dace.int64, u: dace.float64[N, N]): + adi_kernel(TSTEPS, u) + return np.sum(u) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["u"], outputs=["__return"]) + sdfg(TSTEPS, u, N=N, gradient_u=gradient_u, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda TSTEPS, u: adi_jax_kernel(jnp, lax, TSTEPS, u) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=0) + u_jax = np.copy(initialize(N)) + jax_grad_u = jax_grad(TSTEPS, u_jax) + + np.testing.assert_allclose(gradient_u, jax_grad_u, rtol=1e-6) + + def test_cpu(): run_adi(dace.dtypes.DeviceType.CPU) @@ -165,23 +265,22 @@ def test_gpu(): run_adi(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Intel FPGA argument overflow") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_adi(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_adi_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_adi(dace.dtypes.DeviceType.CPU) + run_adi_autodiff() elif target == "gpu": run_adi(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_adi(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/atax_test.py b/tests/npbench/polybench/atax_test.py index dc0a438fab..8dbd812b82 100644 --- a/tests/npbench/polybench/atax_test.py +++ b/tests/npbench/polybench/atax_test.py @@ -5,11 +5,10 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -42,6 +41,11 @@ def init_data(M, N): return A, x, y +def atax_jax_kernel(jnp, A, x): + B = (A @ x) @ A + return jnp.sum(B) + + def run_atax(device_type: dace.dtypes.DeviceType): """ Runs ATAX for the given device @@ -59,38 +63,42 @@ def run_atax(device_type: dace.dtypes.DeviceType): sdfg = auto_optimize(sdfg, device_type) y = sdfg(A, x, M=M, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Gemv - Gemv.default_implementation = "FPGA_Accumulate" - sdfg.expand_library_nodes() - sm_applied = sdfg.apply_transformations_repeated([InlineSDFG, StreamingMemory], - [{}, { - 'storage': dace.StorageType.FPGA_Local - }], - print_report=True) - assert sm_applied == 6 # 3 inlines and 3 Streaming memories - - ########################### - # FPGA Auto Opt - fpga_auto_opt.fpga_global_to_local(sdfg) - fpga_auto_opt.fpga_rr_interleave_containers_to_banks(sdfg) - - # specialize the SDFG (needed by the GEMV expansion) - sdfg.specialize(dict(M=M, N=N)) - y = sdfg(A, x) - # Compute ground truth and Validate result y_ref = kernel.f(A, x) assert np.allclose(y, y_ref) return sdfg +def run_atax_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + A, x, y = init_data(M, N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float32[M, N], x: dc.float32[N]): + y = kernel(A, x) + return np.sum(y) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, x, M=M, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda A, x: atax_jax_kernel(jnp, A, x) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_A = jax_grad(A, x) + np.testing.assert_allclose(gradient_A, jax_grad_A, rtol=1e-6, atol=1e-6) + + def test_cpu(): run_atax(dace.dtypes.DeviceType.CPU) @@ -100,28 +108,22 @@ def test_gpu(): run_atax(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_atax(dace.dtypes.DeviceType.FPGA) - - -@xilinx_test(assert_ii_1=False) -def test_xilinx_decoupled_array_interfaces(): - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - return run_atax(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_atax_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_atax(dace.dtypes.DeviceType.CPU) + run_atax_autodiff() elif target == "gpu": run_atax(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_atax(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/bicg_test.py b/tests/npbench/polybench/bicg_test.py index ae7daa10ef..e90b728609 100644 --- a/tests/npbench/polybench/bicg_test.py +++ b/tests/npbench/polybench/bicg_test.py @@ -5,10 +5,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -36,6 +35,11 @@ def bicg_kernel(A: dc.float64[N, M], p: dc.float64[M], r: dc.float64[N]): return r @ A, A @ p +def bicg_jax_kernel(jnp, A, p, r): + B, D = r @ A, A @ p + return jnp.sum(D) + + def run_bicg(device_type: dace.dtypes.DeviceType): ''' Runs BiCG for the given device @@ -52,34 +56,6 @@ def run_bicg(device_type: dace.dtypes.DeviceType): sdfg = auto_optimize(sdfg, device_type) s, q = sdfg(A, p, r, M=M, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - # Note: currently the kernel uses double-precision floating point numbers - sdfg = bicg_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Gemv - Gemv.default_implementation = "FPGA_Accumulate" - sdfg.expand_library_nodes() - - sm_applied = sdfg.apply_transformations_repeated([InlineSDFG, StreamingMemory], - [{}, { - 'storage': dace.StorageType.FPGA_Local - }], - print_report=True) - assert sm_applied == 8 # 3 inlines and 3 Streaming memories - - ########################### - # FPGA Auto Opt - fpga_auto_opt.fpga_global_to_local(sdfg) - fpga_auto_opt.fpga_rr_interleave_containers_to_banks(sdfg) - - # specialize the SDFG (needed by the GEMV expansion) - sdfg.specialize(dict(M=M, N=N)) - s, q = sdfg(A, p, r) - # Compute ground truth and Validate result s_ref, q_ref = bicg_kernel.f(A, p, r) assert np.allclose(s, s_ref) @@ -87,6 +63,41 @@ def run_bicg(device_type: dace.dtypes.DeviceType): return sdfg +def run_bicg_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + A, p, r = initialize(M, N) + + # Initialize gradient computation data + B = np.zeros((M, ), dtype=np.float64) + D = np.zeros((N, ), dtype=np.float64) + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float64[N, M], p: dc.float64[M], r: dc.float64[N]): + B, D = bicg_kernel(A, p, r) + return np.sum(D) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, p, r, M=M, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda A, p, r: bicg_jax_kernel(jnp, A, p, r) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_A = jax_grad(A, p, r) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_bicg(dace.dtypes.DeviceType.CPU) @@ -96,22 +107,22 @@ def test_gpu(): run_bicg(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_bicg(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_bicg_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_bicg(dace.dtypes.DeviceType.CPU) + run_bicg_autodiff() elif target == "gpu": run_bicg(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_bicg(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/cholesky2_test.py b/tests/npbench/polybench/cholesky2_test.py index f0ff0bb187..50486f9126 100644 --- a/tests/npbench/polybench/cholesky2_test.py +++ b/tests/npbench/polybench/cholesky2_test.py @@ -6,9 +6,7 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition +from dace.transformation.interstate import InlineSDFG from dace.transformation.auto.auto_optimize import auto_optimize # Data set sizes @@ -63,20 +61,6 @@ def run_cholesky2(device_type: dace.dtypes.DeviceType): sdfg = cholesky2_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(A=A, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = cholesky2_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Dot - Dot.default_implementation = "FPGA_PartialSums" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG]) - - sdfg(A=A, N=N) - # Compute ground truth and validate result ground_truth(N, gt_A) diff = np.linalg.norm(gt_A - A) / np.linalg.norm(gt_A) @@ -93,16 +77,10 @@ def test_gpu(): run_cholesky2(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Unsupported Lapack calls") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_cholesky2(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -111,5 +89,3 @@ def test_fpga(): run_cholesky2(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_cholesky2(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_cholesky2(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/cholesky_test.py b/tests/npbench/polybench/cholesky_test.py index a83d153338..309bb3197d 100644 --- a/tests/npbench/polybench/cholesky_test.py +++ b/tests/npbench/polybench/cholesky_test.py @@ -6,10 +6,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition +from dace.transformation.interstate import InlineSDFG from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # N @@ -43,6 +42,36 @@ def init_data(N): return A +def cholesky_jax_kernel(jnp, lax, A): + A = A.at[0, 0].set(jnp.sqrt(A[0, 0])) + + def row_update_body(A, i): + + def col_update_body(A, j): + + def do_update(_): + mask = jnp.arange(A.shape[1]) < j + A_i_slice = jnp.where(mask, A[i, :], 0) + A_j_slice = jnp.where(mask, A[j, :], 0) + dot_product = jnp.dot(A_i_slice, A_j_slice) + new_val = (A[i, j] - dot_product) / A[j, j] + return A.at[i, j].set(new_val) + + A = lax.cond(j < i, do_update, lambda _: A, operand=None) + return A, None + + A, _ = lax.scan(col_update_body, A, jnp.arange(A.shape[0])) + + mask = jnp.arange(A.shape[1]) < i + A_i_slice = jnp.where(mask, A[i, :], 0) + dot_product = jnp.dot(A_i_slice, A_i_slice) + A = A.at[i, i].set(jnp.sqrt(A[i, i] - dot_product)) + return A, None + + A, _ = lax.scan(row_update_body, A, jnp.arange(1, A.shape[0])) + return jnp.sum(A) + + def ground_truth(N, A): A[0, 0] = np.sqrt(A[0, 0]) for i in range(1, N): @@ -69,24 +98,6 @@ def run_cholesky(device_type: dace.dtypes.DeviceType): sdfg = kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(A=A, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Dot - platform = dace.config.Config.get("compiler", "fpga", "vendor") - if platform == "intel_fpga": - Dot.default_implementation = "FPGA_Accumulate" - else: - Dot.default_implementation = "FPGA_PartialSums" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG]) - - sdfg(A=A, N=N) - # Compute ground truth and validate result ground_truth(N, gt_A) diff = np.linalg.norm(gt_A - A) / np.linalg.norm(gt_A) @@ -94,6 +105,38 @@ def run_cholesky(device_type: dace.dtypes.DeviceType): return sdfg +def run_cholesky_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + N = 20 + A = init_data(N) + A_jax = jnp.copy(A) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float32[N, N]): + kernel(A) + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg(simplify=True) + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda A: cholesky_jax_kernel(jnp, lax, A) + jax_grad = jax.jit(jax.grad(jax_kernel)) + jax_grad_A = jax_grad(A_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A, rtol=1e-4, atol=1e-4) + + def test_cpu(): run_cholesky(dace.dtypes.DeviceType.CPU) @@ -103,22 +146,22 @@ def test_gpu(): run_cholesky(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_cholesky(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_cholesky_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_cholesky(dace.dtypes.DeviceType.CPU) + run_cholesky_autodiff() elif target == "gpu": run_cholesky(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_cholesky(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/correlation_test.py b/tests/npbench/polybench/correlation_test.py index a5532cf829..2bb9d65e07 100644 --- a/tests/npbench/polybench/correlation_test.py +++ b/tests/npbench/polybench/correlation_test.py @@ -7,6 +7,7 @@ import pytest import argparse from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -42,6 +43,20 @@ def initialize(M, N, datatype=np.float64): return float_n, data +def correlation_jax_kernel(jnp, float_n, data): + mean = jnp.mean(data, axis=0) + M = data.shape[1] + stddev = jnp.sqrt(jnp.mean(jnp.subtract(data, mean)**2, axis=0)) + stddev = jnp.where(stddev <= 0.1, 1.0, stddev) + data = jnp.subtract(data, mean) + data = jnp.divide(data, jnp.sqrt(float_n) * stddev) + corr = jnp.eye(M, dtype=data.dtype) + for i in range(M - 1): + corr = corr.at[i, i + 1:M].set(data[:, i] @ data[:, i + 1:M]) + corr = corr.at[i + 1:M, i].set(corr[i, i + 1:M]) + return jnp.sum(corr) + + def ground_truth(M, float_n, data): mean = np.mean(data, axis=0) @@ -77,9 +92,6 @@ def run_correlation(device_type: dace.dtypes.DeviceType): corr = sdfg(float_n, data, M=M, N=N) os.environ['DACE_testing_serialization'] = last_value - elif device_type == dace.dtypes.DeviceType.FPGA: - pass # Not Yet Implemented - # Compute ground truth and validate result corr_ref = ground_truth(M, float_n_ref, data_ref) @@ -88,6 +100,40 @@ def run_correlation(device_type: dace.dtypes.DeviceType): return sdfg +def run_correlation_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + float_n, data = initialize(M, N) + + # Initialize gradient computation data + gradient_data = np.zeros_like(data) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(float_n: dc.float64, data: dc.float64[N, M]): + corr = correlation_kernel(float_n, data) + return np.sum(corr) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["data"], outputs=["__return"]) + sdfg(float_n, data, M=M, N=N, gradient_data=gradient_data, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda float_n, data: correlation_jax_kernel(jnp, float_n, data) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=(0, )) + _, data_jax = initialize(M, N) + jax_grad_data = jax_grad(float_n, data_jax) + np.testing.assert_allclose(gradient_data, jax_grad_data, rtol=1e-8, atol=1e-8) + + def test_cpu(): run_correlation(dace.dtypes.DeviceType.CPU) @@ -97,17 +143,27 @@ def test_gpu(): run_correlation(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + # Serialization causes issues, we temporarily disable it + # TODO: open an issue to fix the serialization stability problem + last_value = os.environ.get('DACE_testing_serialization', '0') + os.environ['DACE_testing_serialization'] = '0' + run_correlation_autodiff() + os.environ['DACE_testing_serialization'] = last_value + + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_correlation(dace.dtypes.DeviceType.CPU) + run_correlation_autodiff() elif target == "gpu": run_correlation(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_correlation(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/covariance_test.py b/tests/npbench/polybench/covariance_test.py index 66878bad1a..4aeabaf9f6 100644 --- a/tests/npbench/polybench/covariance_test.py +++ b/tests/npbench/polybench/covariance_test.py @@ -1,17 +1,18 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. # Original application code: NPBench - https://github.com/spcl/npbench +import os import dace.dtypes import numpy as np import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.transformation.interstate import InlineSDFG from dace.transformation.dataflow import StreamingMemory, MapFusionVertical, StreamingComposition, PruneConnectors -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.transformation.auto.auto_optimize import auto_optimize from dace.libraries.standard import Reduce from dace.libraries.blas import Gemv +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -36,6 +37,17 @@ def covariance_kernel(float_n: dc.float32, data: dc.float32[N, M]): return cov +def covariance_jax_kernel(jnp, float_n, data): + mean = jnp.mean(data, axis=0) + M = data.shape[1] + data -= mean + cov = jnp.zeros((M, M), dtype=data.dtype) + for i in range(M): + cov = cov.at[i:M, i].set(data[:, i] @ data[:, i:M] / (float_n - 1.0)) + cov = cov.at[i, i:M].set(data[:, i] @ data[:, i:M] / (float_n - 1.0)) + return jnp.sum(cov) + + def ground_truth(M, N, float_n, data): mean = np.empty((M, ), dtype=data.dtype) @@ -91,36 +103,41 @@ def run_covariance(device_type: dace.dtypes.DeviceType): sdfg = auto_optimize(sdfg, device_type) dace_res = sdfg(float_n=float_n, data=data, M=M, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = covariance_kernel.to_sdfg(simplify=False) - sdfg.simplify() - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 + # Compute ground truth and validate result + gt_res = ground_truth(M, N, float_n, gt_data) + assert np.allclose(gt_res, dace_res) + return sdfg - sdfg.apply_transformations([InlineSDFG]) - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - # Reduce.default_implementation = "FPGAPartialReduction" - Gemv.default_implementation = "FPGA_Accumulate" +def run_covariance_autodiff(): + import jax + import jax.numpy as jnp - sdfg.expand_library_nodes() - sdfg.apply_transformations([InlineSDFG]) + # Initialize data (polybench mini size) + M, N = sizes["mini"] + float_n, data = init_data(M, N) + data_jax = np.copy(data) - # Other FPGA auto opt - fpga_auto_opt.fpga_global_to_local(sdfg) - fpga_auto_opt.fpga_rr_interleave_containers_to_banks(sdfg) + # Initialize gradient computation data + gradient_data = np.zeros_like(data) + gradient___return = np.ones((1, ), dtype=np.float32) - # Specialize the SDFG - sdfg.specialize(dict(N=N, M=M)) + # Define sum reduction for the output + @dc.program + def autodiff_kernel(float_n: dc.float32, data: dc.float32[N, M]): + cov = covariance_kernel(float_n, data) + return np.sum(cov) - # run program - dace_res = sdfg(float_n=float_n, data=data) + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["data"], outputs=["__return"]) + sdfg(float_n, data, M=M, N=N, gradient_data=gradient_data, gradient___return=gradient___return) - # Compute ground truth and validate result - gt_res = ground_truth(M, N, float_n, gt_data) - assert np.allclose(gt_res, dace_res) - return sdfg + # Numerically validate vs JAX + jax_kernel = lambda float_n, data: covariance_jax_kernel(jnp, float_n, data) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=(0, )) + jax_grad_data = jax_grad(float_n, data_jax) + np.testing.assert_allclose(gradient_data, jax_grad_data, rtol=1e-5, atol=1e-8) def test_cpu(monkeypatch): @@ -134,22 +151,27 @@ def test_gpu(): run_covariance(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_covariance(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + # Serialization causes issues, we temporarily disable it + # TODO: open an issue to fix the serialization stability problem + last_value = os.environ.get('DACE_testing_serialization', '0') + os.environ['DACE_testing_serialization'] = '0' + run_covariance_autodiff() + os.environ['DACE_testing_serialization'] = last_value if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_covariance(dace.dtypes.DeviceType.CPU) + run_covariance_autodiff() elif target == "gpu": run_covariance(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_covariance(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/deriche_test.py b/tests/npbench/polybench/deriche_test.py index b2fe7d47e2..ac40cd14fc 100644 --- a/tests/npbench/polybench/deriche_test.py +++ b/tests/npbench/polybench/deriche_test.py @@ -5,10 +5,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # W, H @@ -76,6 +75,64 @@ def initialize(W, H, datatype=np.float64): return alpha, imgIn +def deriche_jax_kernel(jnp, lax, alpha, imgIn): + + k = (1.0 - jnp.exp(-alpha))**2 / (1.0 + alpha * jnp.exp(-alpha) - jnp.exp(2.0 * alpha)) + a1 = a5 = k + a2 = a6 = k * jnp.exp(-alpha) * (alpha - 1.0) + a3 = a7 = k * jnp.exp(-alpha) * (alpha + 1.0) + a4 = a8 = -k * jnp.exp(-2.0 * alpha) + b1 = 2.0**(-alpha) + b2 = -jnp.exp(-2.0 * alpha) + c1 = c2 = 1 + + y1 = jnp.empty_like(imgIn) + y1 = y1.at[:, 0].set(a1 * imgIn[:, 0]) + y1 = y1.at[:, 1].set(a1 * imgIn[:, 1] + a2 * imgIn[:, 0] + b1 * y1[:, 0]) + + def horizontal_forward_body(y1, j): + new_y1 = y1.at[:, j].set(a1 * imgIn[:, j] + a2 * imgIn[:, j - 1] + b1 * y1[:, j - 1] + b2 * y1[:, j - 2]) + return new_y1, None + + y1, _ = lax.scan(horizontal_forward_body, y1, jnp.arange(2, imgIn.shape[1])) + + y2 = jnp.empty_like(imgIn) + y2 = y2.at[:, -1].set(0.0) + y2 = y2.at[:, -2].set(a3 * imgIn[:, -1]) + + def horizontal_backward_body(y2, j): + idx = imgIn.shape[1] - 3 - j + new_y2 = y2.at[:, idx].set(a3 * imgIn[:, idx + 1] + a4 * imgIn[:, idx + 2] + b1 * y2[:, idx + 1] + + b2 * y2[:, idx + 2]) + return new_y2, None + + y2, _ = lax.scan(horizontal_backward_body, y2, jnp.arange(0, imgIn.shape[1] - 2)) + + imgOut = c1 * (y1 + y2) + + y1 = y1.at[0, :].set(a5 * imgOut[0, :]) + y1 = y1.at[1, :].set(a5 * imgOut[1, :] + a6 * imgOut[0, :] + b1 * y1[0, :]) + + def vertical_forward_body(y1, i): + new_y1 = y1.at[i, :].set(a5 * imgOut[i, :] + a6 * imgOut[i - 1, :] + b1 * y1[i - 1, :] + b2 * y1[i - 2, :]) + return new_y1, None + + y1, _ = lax.scan(vertical_forward_body, y1, jnp.arange(2, imgIn.shape[0])) + + y2 = y2.at[-1, :].set(0.0) + y2 = y2.at[-2, :].set(a7 * imgOut[-1, :]) + + def vertical_backward_body(y2, i): + idx = imgIn.shape[0] - 3 - i + new_y2 = y2.at[idx, :].set(a7 * imgOut[idx + 1, :] + a8 * imgOut[idx + 2, :] + b1 * y2[idx + 1, :] + + b2 * y2[idx + 2, :]) + return new_y2, None + + y2, _ = lax.scan(vertical_backward_body, y2, jnp.arange(0, imgIn.shape[0] - 2)) + + return jnp.sum(c2 * (y1 + y2)) + + def ground_truth(alpha, imgIn): k = (1.0 - np.exp(-alpha)) * (1.0 - np.exp(-alpha)) / (1.0 + alpha * np.exp(-alpha) - np.exp(2.0 * alpha)) @@ -132,23 +189,6 @@ def run_deriche(device_type: dace.dtypes.DeviceType): sdfg = auto_optimize(sdfg, device_type) imgOut = sdfg(alpha, imgIn, W=W, H=H) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - # Note: currently the kernel uses double-precision floating point numbers and - # works for Xilinx - sdfg = deriche_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - ########################### - # FPGA Auto Opt - fpga_auto_opt.fpga_global_to_local(sdfg) - fpga_auto_opt.fpga_rr_interleave_containers_to_banks(sdfg) - - # specialize the SDFG (needed by the GEMV expansion) - sdfg.specialize(dict(W=W, H=H)) - imgOut = sdfg(alpha, imgIn) - # Compute ground truth and validate result imgOut_ref = ground_truth(alpha, imgIn) @@ -156,6 +196,41 @@ def run_deriche(device_type: dace.dtypes.DeviceType): return sdfg +def run_deriche_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (test size for efficiency) + W, H = sizes["mini"] + alpha, imgIn = initialize(W, H) + + # Initialize gradient computation data + gradient_imgIn = np.zeros_like(imgIn) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output using __return pattern + @dc.program + def autodiff_kernel(alpha: dc.float64, imgIn: dc.float64[W, H]): + imgOut = deriche_kernel(alpha, imgIn) + return np.sum(imgOut) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["imgIn"], outputs=["__return"]) + sdfg(alpha, imgIn, W=W, H=H, gradient_imgIn=gradient_imgIn, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, imgIn: deriche_jax_kernel(jnp, lax, alpha, imgIn) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1)) + alpha_jax, imgIn_jax = initialize(W, H) + jax_grad_imgIn = jax_grad(alpha_jax, imgIn_jax) + np.testing.assert_allclose(gradient_imgIn, jax_grad_imgIn) + + def test_cpu(): run_deriche(dace.dtypes.DeviceType.CPU) @@ -165,22 +240,22 @@ def test_gpu(): run_deriche(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False, intel=False) -def test_fpga(): - return run_deriche(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_deriche_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_deriche(dace.dtypes.DeviceType.CPU) + run_deriche_autodiff() elif target == "gpu": run_deriche(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_deriche(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/doitgen_test.py b/tests/npbench/polybench/doitgen_test.py index 52fffd1d0d..45da2d635f 100644 --- a/tests/npbench/polybench/doitgen_test.py +++ b/tests/npbench/polybench/doitgen_test.py @@ -5,9 +5,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.transformation.interstate import InlineSDFG from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # NQ, NR, NP @@ -38,6 +38,15 @@ def initialize(NR, NQ, NP, datatype=np.float64): return A, C4 +def doitgen_jax_kernel(jnp, A, C4): + NR = A.shape[0] + NQ = A.shape[1] + NP = A.shape[2] + for r in range(NR): + A = A.at[r, :, :].set(jnp.reshape(jnp.reshape(A[r], (NQ, NP)) @ C4, (NQ, NP))) + return jnp.sum(A) + + def ground_truth(NR, NQ, NP, A, C4): A[:] = np.reshape(np.reshape(A, (NR, NQ, 1, NP)) @ C4, (NR, NQ, NP)) @@ -59,30 +68,46 @@ def run_doitgen(device_type: dace.dtypes.DeviceType): sdfg = auto_optimize(sdfg, device_type) sdfg(A, C4, NR=NR, NQ=NQ, NP=NP) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = doitgen_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Gemm - Gemm.default_implementation = "FPGA1DSystolic" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.states()[0].location["is_FPGA_kernel"] = False - # we need to specialize both the top-level SDFG and the nested SDFG - sdfg.specialize(dict(NR=NR, NQ=NQ, NP=NP)) - sdfg.states()[0].nodes()[0].sdfg.specialize(dict(NR=NR, NQ=NQ, NP=NP)) - # TODO: add support for `long long` in Intel FPGA, set systolic array size - sdfg(A, C4) - # Compute ground truth and Validate result ground_truth(NR, NQ, NP, A_ref, C4) assert np.allclose(A, A_ref) return sdfg +def run_doitgen_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + NQ, NR, NP = sizes["mini"] + A, C4 = initialize(NR, NQ, NP) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float64[NR, NQ, NP], C4: dc.float64[NP, NP]): + doitgen_kernel(A, C4) + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, C4, NR=NR, NQ=NQ, NP=NP, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda A, C4: doitgen_jax_kernel(jnp, A, C4) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + A_jax, C4_jax = initialize(NR, NQ, NP) + jax_grad_A = jax_grad(A_jax, C4_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_doitgen(dace.dtypes.DeviceType.CPU) @@ -92,23 +117,22 @@ def test_gpu(): run_doitgen(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="long long support for IntelFPGA") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_doitgen(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_doitgen_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_doitgen(dace.dtypes.DeviceType.CPU) + run_doitgen_autodiff() elif target == "gpu": run_doitgen(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_doitgen(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/durbin_test.py b/tests/npbench/polybench/durbin_test.py index ffeff150d9..c119dd1361 100644 --- a/tests/npbench/polybench/durbin_test.py +++ b/tests/npbench/polybench/durbin_test.py @@ -5,11 +5,10 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Data set sizes # N @@ -63,6 +62,46 @@ def ground_truth(r): return y +def durbin_jax_kernel(jnp, lax, r): + # Initialize y, alpha, and beta. + y = jnp.empty_like(r) + alpha = -r[0] + beta = 1.0 + y = y.at[0].set(-r[0]) + + # Define the scan body. The loop index k will run from 1 to r.shape[0]-1. + def scan_body(carry, k): + alpha, beta, y, r = carry + + # Update beta. + beta = beta * (1.0 - alpha * alpha) + + # Create a mask for indices less than k. + mask = jnp.arange(r.shape[0]) < k + + # Compute the dot product between y and a shifted version of r. + # Note: jnp.roll(jnp.flip(r), [k], 0) is equivalent to shifting along axis 0. + products = jnp.where(mask, y * jnp.roll(jnp.flip(r), k, axis=0), 0.0) + dot_prod = jnp.sum(products) + + # Update alpha based on the k-th element of r and the dot product. + alpha = -(r[k] + dot_prod) / beta + + # Compute an update slice from a shifted version of y. + y_update_slice = jnp.where(mask, jnp.roll(jnp.flip(y), k, axis=0) * alpha, 0.0) + + # Update y by adding the computed slice and setting the k-th element to alpha. + y = y + y_update_slice + y = y.at[k].set(alpha) + + return (alpha, beta, y, r), None + + # Run the scan from k = 1 to r.shape[0]-1. + (alpha, beta, y, r), _ = lax.scan(scan_body, (alpha, beta, y, r), jnp.arange(1, r.shape[0])) + + return jnp.sum(y) + + def run_durbin(device_type: dace.dtypes.DeviceType): ''' Runs Durbin for the given device @@ -80,24 +119,44 @@ def run_durbin(device_type: dace.dtypes.DeviceType): sdfg = auto_optimize(sdfg, device_type) y = sdfg(r, N=N) assert np.allclose(y, y_ref) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = durbin_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Dot - Dot.default_implementation = "FPGA_PartialSums" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N)) - y = sdfg(r) - assert np.allclose(y, y_ref, atol=1e-6) - return sdfg +def run_durbin_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench small size) + N = sizes["small"] + r = initialize(N) + + # Initialize gradient computation data + gradient_r = np.zeros_like(r) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(r: dc.float64[N]): + y = durbin_kernel(r) + return np.sum(y) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg(simplify=True) + add_backward_pass(sdfg=sdfg, inputs=["r"], outputs=["__return"], simplify=False) + sdfg(r, N=N, gradient_r=gradient_r, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda r: durbin_jax_kernel(jnp, lax, r) + jax_grad = jax.jit(jax.grad(jax_kernel)) + r_jax = initialize(N) + jax_grad_r = jax_grad(r_jax) + np.testing.assert_allclose(gradient_r, jax_grad_r) + + def test_cpu(): run_durbin(dace.dtypes.DeviceType.CPU) @@ -107,22 +166,22 @@ def test_gpu(): run_durbin(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_durbin(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_durbin_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_durbin(dace.dtypes.DeviceType.CPU) + run_durbin_autodiff() elif target == "gpu": run_durbin(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_durbin(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/fdtd_2d_test.py b/tests/npbench/polybench/fdtd_2d_test.py index 37db5de743..ff0a9a9913 100644 --- a/tests/npbench/polybench/fdtd_2d_test.py +++ b/tests/npbench/polybench/fdtd_2d_test.py @@ -5,11 +5,11 @@ import numpy as np import dace as dc import pytest -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.transformation.interstate import InlineSDFG from dace.transformation.dataflow import StreamingMemory, StreamingComposition, MapFusionVertical from dace.transformation.auto.auto_optimize import auto_optimize import argparse +from dace.autodiff import add_backward_pass # Data set sizes # TMAX, NX, NY @@ -50,6 +50,26 @@ def init_data(TMAX, NX, NY): return ex, ey, hz, _fict_ +def fdtd_2d_jax_kernel(jnp, lax, ex, ey, hz, _fict_): + """JAX implementation using efficient lax.scan operations""" + TMAX = _fict_.shape[0] + + def scan_body(carry, t): + ex, ey, hz = carry + # Set the top row of ey using _fict_ for the current time step. + ey = ey.at[0, :].set(_fict_[t]) + # Update ey for rows 1 and beyond. + ey = ey.at[1:, :].set(ey[1:, :] - 0.5 * (hz[1:, :] - hz[:-1, :])) + # Update ex for columns 1 and beyond. + ex = ex.at[:, 1:].set(ex[:, 1:] - 0.5 * (hz[:, 1:] - hz[:, :-1])) + # Update hz for the interior (all but last row and col). + hz = hz.at[:-1, :-1].set(hz[:-1, :-1] - 0.7 * ((ex[:-1, 1:] - ex[:-1, :-1]) + (ey[1:, :-1] - ey[:-1, :-1]))) + return (ex, ey, hz), None + + (ex, ey, hz), _ = lax.scan(scan_body, (ex, ey, hz), jnp.arange(TMAX)) + return jnp.sum(hz) + + def ground_truth(TMAX, NX, NY, ex, ey, hz, _fict_): for t in range(TMAX): @@ -77,28 +97,6 @@ def run_fdtd_2d(device_type: dace.dtypes.DeviceType): sdfg = kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(ex=ex, ey=ey, hz=hz, _fict_=_fict_, TMAX=TMAX, NX=NX, NY=NY) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = kernel.to_sdfg(simplify=True) - sdfg.apply_transformations_repeated([MapFusionVertical]) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - sm_applied = sdfg.apply_transformations_repeated([InlineSDFG, StreamingMemory], - [{}, { - 'storage': dace.StorageType.FPGA_Local - }], - print_report=True) - - assert sm_applied > 0 - - sdfg.apply_transformations_repeated([InlineSDFG]) - # In this case, we want to generate the top-level state as an host-based state, - # not an FPGA kernel. We need to explicitly indicate that - sdfg.states()[0].location["is_FPGA_kernel"] = False - - sdfg(ex=ex, ey=ey, hz=hz, _fict_=_fict_, TMAX=TMAX, NX=NX, NY=NY) - # Compute ground truth and validate result ground_truth(TMAX, NX, NY, gt_ex, gt_ey, gt_hz, _fict_=_fict_) diff_ex = np.linalg.norm(gt_ex - ex) / np.linalg.norm(gt_ex) @@ -113,6 +111,39 @@ def run_fdtd_2d(device_type: dace.dtypes.DeviceType): return sdfg +def run_fdtd_2d_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (test size for efficiency) + TMAX, NX, NY = (2, 10, 12) + ex, ey, hz, _fict_ = init_data(TMAX, NX, NY) + + # Initialize gradient computation data + gradient_ex = np.zeros_like(ex) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output using __return pattern + @dc.program + def fdtd_2d_autodiff_kernel(ex: dc.float32[NX, NY], ey: dc.float32[NX, NY], hz: dc.float32[NX, NY], + _fict_: dc.float32[TMAX]): + kernel(ex, ey, hz, _fict_) + return np.sum(hz) + + # Add the backward pass to the SDFG + sdfg = fdtd_2d_autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["ex"], outputs=["__return"]) + sdfg(ex, ey, hz, _fict_, TMAX=TMAX, NX=NX, NY=NY, gradient_ex=gradient_ex, gradient___return=gradient___return) + + # Numerically validate vs JAX (use float32 consistent with kernel) + jax_kernel = lambda ex, ey, hz, _fict_: fdtd_2d_jax_kernel(jnp, lax, ex, ey, hz, _fict_) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + ex_jax, ey_jax, hz_jax, _fict_jax = init_data(TMAX, NX, NY) + jax_grad_ex = jax_grad(ex_jax, ey_jax, hz_jax, _fict_jax) + np.testing.assert_allclose(gradient_ex, jax_grad_ex) + + def test_cpu(): run_fdtd_2d(dace.dtypes.DeviceType.CPU) @@ -122,22 +153,22 @@ def test_gpu(): run_fdtd_2d(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_fdtd_2d(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_fdtd_2d_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_fdtd_2d(dace.dtypes.DeviceType.CPU) + run_fdtd_2d_autodiff() elif target == "gpu": run_fdtd_2d(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_fdtd_2d(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/floyd_warshall_test.py b/tests/npbench/polybench/floyd_warshall_test.py index 66e04cd00e..48a8e4485f 100644 --- a/tests/npbench/polybench/floyd_warshall_test.py +++ b/tests/npbench/polybench/floyd_warshall_test.py @@ -6,10 +6,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG, StateFusion +from dace.transformation.interstate import InlineSDFG, StateFusion from dace.transformation.dataflow import StreamingMemory, MapFusionVertical, StreamingComposition, PruneConnectors -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.transformation.auto.auto_optimize import auto_optimize # Data set sizes # N @@ -64,53 +63,6 @@ def run_floyd_warshall(device_type: dace.dtypes.DeviceType): sdfg = auto_optimize(sdfg, device_type) sdfg(path=path, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = kernel.to_sdfg(simplify=True) - # sdfg.apply_transformations_repeated([MapFusionVertical]) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - sm_applied = sdfg.apply_transformations_repeated([InlineSDFG, StreamingMemory], - [{}, { - 'storage': dace.StorageType.FPGA_Local - }], - print_report=True) - sc_applied = sdfg.apply_transformations_repeated([InlineSDFG, StreamingComposition], - [{}, { - 'storage': dace.StorageType.FPGA_Local - }], - print_report=True, - permissive=True) - assert sc_applied == 1 - - # Prune connectors after Streaming Composition - pruned_conns = sdfg.apply_transformations_repeated(PruneConnectors, - options=[{ - 'remove_unused_containers': True - }]) - - assert pruned_conns == 1 - sdfg.apply_transformations_repeated(StateFusion) - - fpga_auto_opt.fpga_rr_interleave_containers_to_banks(sdfg) - - # In this case, we want to generate the top-level state as an host-based state, - # not an FPGA kernel. We need to explicitly indicate that - for state in sdfg.states(): - if any([isinstance(node, dace.nodes.NestedSDFG) for node in state.nodes()]): - state.location["is_FPGA_kernel"] = False - - # we need to specialize both the top-level SDFG and the nested SDFG - sdfg.specialize(dict(N=N)) - for state in sdfg.states(): - for node in state.nodes(): - if isinstance(node, dace.nodes.NestedSDFG): - node.sdfg.specialize(dict(N=N)) - - # run program - sdfg(path=path) - # Compute ground truth and validate result ground_truth(gt_path, N) assert np.allclose(path, gt_path) @@ -126,15 +78,10 @@ def test_gpu(): run_floyd_warshall(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_floyd_warshall(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -143,5 +90,3 @@ def test_fpga(): run_floyd_warshall(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_floyd_warshall(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_floyd_warshall(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/gemm_npbench_test.py b/tests/npbench/polybench/gemm_npbench_test.py index 58948f295d..7c078d5e0d 100644 --- a/tests/npbench/polybench/gemm_npbench_test.py +++ b/tests/npbench/polybench/gemm_npbench_test.py @@ -5,11 +5,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # NI, NJ, NK @@ -40,6 +38,10 @@ def initialize(NI, NJ, NK, datatype=np.float64): return alpha, beta, C, A, B +def gemm_jax_kernel(jnp, alpha, beta, A, B, C): + return jnp.sum(alpha * A @ B + beta * C) + + def run_gemm(device_type: dace.dtypes.DeviceType): ''' Runs Gemm for the given device @@ -56,26 +58,46 @@ def run_gemm(device_type: dace.dtypes.DeviceType): sdfg = gemm_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(alpha, beta, C, A, B, NI=NI, NJ=NJ, NK=NK) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = gemm_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Gemm - Gemm.default_implementation = "FPGA1DSystolic" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(NI=NI, NJ=NJ, NK=NK)) - sdfg(alpha, beta, C, A, B) - # Compute ground truth and validate gemm_kernel.f(alpha, beta, C_ref, A, B) assert np.allclose(C, C_ref) return sdfg +def run_gemm_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + NI, NJ, NK = sizes["mini"] + alpha, beta, C, A, B = initialize(NI, NJ, NK) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float64, beta: dc.float64, C: dc.float64[NI, NJ], A: dc.float64[NI, NK], + B: dc.float64[NK, NJ]): + gemm_kernel(alpha, beta, C, A, B) + return np.sum(C) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha, beta, C, A, B, NI=NI, NJ=NJ, NK=NK, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, beta, A, B, C: gemm_jax_kernel(jnp, alpha, beta, A, B, C) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=2), static_argnums=(0, 1)) + jax_grad_A = jax_grad(alpha, beta, A, B, C) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_gemm(dace.dtypes.DeviceType.CPU) @@ -85,22 +107,22 @@ def test_gpu(): run_gemm(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False, xilinx=False) -def test_fpga(): - return run_gemm(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_gemm_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_gemm(dace.dtypes.DeviceType.CPU) + run_gemm_autodiff() elif target == "gpu": run_gemm(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_gemm(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/gemver_test.py b/tests/npbench/polybench/gemver_test.py index 58e078fe11..e49cf0c7d7 100644 --- a/tests/npbench/polybench/gemver_test.py +++ b/tests/npbench/polybench/gemver_test.py @@ -5,11 +5,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # N @@ -45,6 +43,13 @@ def initialize(N, datatype=np.float64): return alpha, beta, A, u1, v1, u2, v2, w, x, y, z +def gemver_jax_kernel(jnp, alpha, beta, A, u1, v1, u2, v2, w, x, y, z): + A += jnp.outer(u1, v1) + jnp.outer(u2, v2) + x += beta * y @ A + z + w += alpha * A @ x + return jnp.sum(w) + + def run_gemver(device_type: dace.dtypes.DeviceType): ''' Runs Gemver for the given device @@ -63,20 +68,6 @@ def run_gemver(device_type: dace.dtypes.DeviceType): sdfg = gemver_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(alpha, beta, A, np.copy(u1), v1, u2, v2, w, x, y, z, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = gemver_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Gemv - Gemv.default_implementation = "FPGA_Accumulate" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N)) - sdfg(alpha, beta, A, np.copy(u1), v1, u2, v2, w, x, y, z) - # Compute ground truth and validate gemver_kernel.f(alpha, beta, A_ref, u1, v1, u2, v2, w_ref, x_ref, y, z) assert np.allclose(A, A_ref) @@ -86,6 +77,56 @@ def run_gemver(device_type: dace.dtypes.DeviceType): return sdfg +def run_gemver_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + N = sizes["mini"] + alpha, beta, A, u1, v1, u2, v2, w, x, y, z = initialize(N) + A_jax, u1_jax, v1_jax, u2_jax, v2_jax, w_jax, x_jax, y_jax, z_jax = map(np.copy, (A, u1, v1, u2, v2, w, x, y, z)) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float64, beta: dc.float64, A: dc.float64[N, N], u1: dc.float64[N], v1: dc.float64[N], + u2: dc.float64[N], v2: dc.float64[N], w: dc.float64[N], x: dc.float64[N], y: dc.float64[N], + z: dc.float64[N]): + gemver_kernel(alpha, beta, A, u1, v1, u2, v2, w, x, y, z) + return np.sum(w) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha, + beta, + A, + np.copy(u1), + v1, + u2, + v2, + w, + x, + y, + z, + N=N, + gradient_A=gradient_A, + gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, beta, A, u1, v1, u2, v2, w, x, y, z: gemver_jax_kernel( + jnp, alpha, beta, A, u1, v1, u2, v2, w, x, y, z) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=2)) + jax_grad_A = jax_grad(alpha, beta, A_jax, u1_jax, v1_jax, u2_jax, v2_jax, w_jax, x_jax, y_jax, z_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_gemver(dace.dtypes.DeviceType.CPU) @@ -95,22 +136,22 @@ def test_gpu(): run_gemver(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_gemver(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_gemver_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_gemver(dace.dtypes.DeviceType.CPU) + run_gemver_autodiff() elif target == "gpu": run_gemver(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_gemver(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/gesummv_test.py b/tests/npbench/polybench/gesummv_test.py index 9ba11a6b44..62a13ba452 100644 --- a/tests/npbench/polybench/gesummv_test.py +++ b/tests/npbench/polybench/gesummv_test.py @@ -5,11 +5,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # N @@ -34,6 +32,10 @@ def initialize(N, datatype=np.float64): return alpha, beta, A, B, x +def gesummv_jax_kernel(jnp, alpha, beta, A, B, x): + return jnp.sum(alpha * (A @ x) + beta * (B @ x)) + + def run_gesummv(device_type: dace.dtypes.DeviceType): ''' Runs Gesummv for the given device @@ -49,26 +51,46 @@ def run_gesummv(device_type: dace.dtypes.DeviceType): sdfg = gesummv_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) C = sdfg(alpha, beta, A, B, x, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = gesummv_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Gemv - Gemv.default_implementation = "FPGA_Accumulate" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N)) - C = sdfg(alpha, beta, A, B, x) - # Compute ground truth and validate C_ref = gesummv_kernel.f(alpha, beta, A, B, x) assert np.allclose(C, C_ref) return sdfg +def run_gesummv_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + N = sizes["mini"] + alpha, beta, A, B, x = initialize(N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float64, beta: dc.float64, A: dc.float64[N, N], B: dc.float64[N, N], + x: dc.float64[N]): + C = gesummv_kernel(alpha, beta, A, B, x) + return np.sum(C) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha, beta, A, B, x, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, beta, A, B, x: gesummv_jax_kernel(jnp, alpha, beta, A, B, x) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=2), static_argnums=(0, 1)) + jax_grad_A = jax_grad(alpha, beta, A, B, x) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_gesummv(dace.dtypes.DeviceType.CPU) @@ -78,23 +100,22 @@ def test_gpu(): run_gesummv(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Xilinx synthesis fails") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_gesummv(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_gesummv_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_gesummv(dace.dtypes.DeviceType.CPU) + run_gesummv_autodiff() elif target == "gpu": run_gesummv(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_gesummv(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/gramschmidt_test.py b/tests/npbench/polybench/gramschmidt_test.py index 5217db86f8..79e86b2da3 100644 --- a/tests/npbench/polybench/gramschmidt_test.py +++ b/tests/npbench/polybench/gramschmidt_test.py @@ -5,11 +5,10 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -47,6 +46,39 @@ def initialize(M, N, datatype=np.float64): return A +def gramschmidt_jax_kernel(jnp, lax, A): + n = A.shape[1] + Q = jnp.zeros_like(A) + R = jnp.zeros((n, n), dtype=A.dtype) + + def body_fun(carry, k): + Q, R, A = carry + + nrm = jnp.dot(A[:, k], A[:, k]) + R = R.at[k, k].set(jnp.sqrt(nrm)) + Q = Q.at[:, k].set(A[:, k] / R[k, k]) + + def inner_body_fun(carry_inner, j): + Q, R, A = carry_inner + + def do_update(_): + new_R = R.at[k, j].set(jnp.dot(Q[:, k], A[:, j])) + new_A = A.at[:, j].add(-Q[:, k] * new_R[k, j]) + return (Q, new_R, new_A) + + def no_update(_): + return (Q, R, A) + + Q, R, A = lax.cond(j >= (k + 1), do_update, no_update, operand=None) + return (Q, R, A), None + + (Q, R, A), _ = lax.scan(inner_body_fun, (Q, R, A), jnp.arange(n)) + return (Q, R, A), None + + (Q, R, A), _ = lax.scan(body_fun, (Q, R, A), jnp.arange(n)) + return jnp.sum(A) + + def ground_truth(A): Q = np.zeros_like(A) @@ -79,20 +111,6 @@ def run_gramschmidt(device_type: dace.dtypes.DeviceType): sdfg = gramschmidt_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) Q, R = sdfg(A, M=M, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = gramschmidt_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Dot - Dot.default_implementation = "FPGA_PartialSums" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(M=M, N=N)) - Q, R = sdfg(A) - # Compute ground truth and validate Q_ref, R_ref = ground_truth(A_ref) assert np.allclose(Q, Q_ref) @@ -100,6 +118,41 @@ def run_gramschmidt(device_type: dace.dtypes.DeviceType): return sdfg +def run_gramschmidt_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + A = initialize(M, N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float64[M, N]): + Q, R = gramschmidt_kernel(A) + return np.sum(A) # Sum the modified A matrix + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, M=M, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda A: gramschmidt_jax_kernel(jnp, lax, A) + jax_grad = jax.jit(jax.grad(jax_kernel)) + A_jax = initialize(M, N) + jax_grad_A = jax_grad(A_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_gramschmidt(dace.dtypes.DeviceType.CPU) @@ -109,22 +162,22 @@ def test_gpu(): run_gramschmidt(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_gramschmidt(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_gramschmidt_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_gramschmidt(dace.dtypes.DeviceType.CPU) + run_gramschmidt_autodiff() elif target == "gpu": run_gramschmidt(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_gramschmidt(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/heat_3d_test.py b/tests/npbench/polybench/heat_3d_test.py index 75ad902c4b..84b2d7a11c 100644 --- a/tests/npbench/polybench/heat_3d_test.py +++ b/tests/npbench/polybench/heat_3d_test.py @@ -5,11 +5,10 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Dataset sizes # TSTEPS, N @@ -39,6 +38,29 @@ def initialize(N, datatype=np.float64): return A, B +def heat_3d_jax_kernel(jnp, lax, TSTEPS, A, B): + + def time_step(carry, t): + A, B = carry + B_new = B.at[1:-1, 1:-1, + 1:-1].set(0.125 * (A[2:, 1:-1, 1:-1] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[:-2, 1:-1, 1:-1]) + 0.125 * + (A[1:-1, 2:, 1:-1] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[1:-1, :-2, 1:-1]) + 0.125 * + (A[1:-1, 1:-1, 2:] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[1:-1, 1:-1, :-2]) + + A[1:-1, 1:-1, 1:-1]) + A_new = A.at[1:-1, 1:-1, + 1:-1].set(0.125 * + (B_new[2:, 1:-1, 1:-1] - 2.0 * B_new[1:-1, 1:-1, 1:-1] + B_new[:-2, 1:-1, 1:-1]) + + 0.125 * + (B_new[1:-1, 2:, 1:-1] - 2.0 * B_new[1:-1, 1:-1, 1:-1] + B_new[1:-1, :-2, 1:-1]) + + 0.125 * + (B_new[1:-1, 1:-1, 2:] - 2.0 * B_new[1:-1, 1:-1, 1:-1] + B_new[1:-1, 1:-1, :-2]) + + B_new[1:-1, 1:-1, 1:-1]) + return (A_new, B_new), None + + (A_final, B_final), _ = lax.scan(time_step, (A, B), jnp.arange(1, TSTEPS)) + return jnp.sum(A_final) + + def ground_truth(TSTEPS, A, B): for t in range(1, TSTEPS): @@ -81,26 +103,47 @@ def count_maps(sdfg: dc.SDFG) -> int: after_maps = count_maps(sdfg) assert after_maps < initial_maps, f"Expected less maps, initially {initial_maps} many maps, but after optimization {after_maps}" sdfg(TSTEPS, A, B, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = heat_3d_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Dot - Dot.default_implementation = "FPGA_PartialSums" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N)) - sdfg(TSTEPS, A, B) - # Compute ground truth and validate ground_truth(TSTEPS, A_ref, B_ref) assert np.allclose(A, A_ref) return sdfg +def run_heat_3d_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench small size) + TSTEPS, N = sizes["small"] + A, B = initialize(N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(TSTEPS: dc.int64, A: dc.float64[N, N, N], B: dc.float64[N, N, N]): + heat_3d_kernel(TSTEPS, A, B) + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(TSTEPS, A, B, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda TSTEPS, A, B: heat_3d_jax_kernel(jnp, lax, TSTEPS, A, B) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=(0, )) + A_jax, B_jax = initialize(N) + jax_grad_A = jax_grad(TSTEPS, A_jax, B_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_heat_3d(dace.dtypes.DeviceType.CPU) @@ -110,22 +153,22 @@ def test_gpu(): run_heat_3d(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_heat_3d(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_heat_3d_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_heat_3d(dace.dtypes.DeviceType.CPU) + run_heat_3d_autodiff() elif target == "gpu": run_heat_3d(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_heat_3d(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/jacobi_1d_test.py b/tests/npbench/polybench/jacobi_1d_test.py index 61f7ba1211..ef324c2763 100644 --- a/tests/npbench/polybench/jacobi_1d_test.py +++ b/tests/npbench/polybench/jacobi_1d_test.py @@ -5,11 +5,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition +from dace.transformation.interstate import InlineSDFG from dace.transformation.auto.auto_optimize import auto_optimize -from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Dataset sizes # TSTEPS, N @@ -25,6 +23,15 @@ def jacobi_1d_kernel(TSTEPS: dc.int64, A: dc.float64[N], B: dc.float64[N]): A[1:-1] = 0.33333 * (B[:-2] + B[1:-1] + B[2:]) +def jacobi_1d_jax_kernel(jax, jnp, TSTEPS, A, B): + + for t in range(1, TSTEPS): + B = B.at[1:-1].set(0.33333 * (A[:-2] + A[1:-1] + A[2:])) + A = A.at[1:-1].set(0.33333 * (B[:-2] + B[1:-1] + B[2:])) + + return jax.block_until_ready(jnp.sum(A)) + + def initialize(N, datatype=np.float64): A = np.fromfunction(lambda i: (i + 2) / N, (N, ), dtype=datatype) B = np.fromfunction(lambda i: (i + 3) / N, (N, ), dtype=datatype) @@ -56,26 +63,46 @@ def run_jacobi_1d(device_type: dace.dtypes.DeviceType): sdfg = jacobi_1d_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(TSTEPS, A, B, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = jacobi_1d_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Dot - Dot.default_implementation = "FPGA_PartialSums" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N)) - sdfg(TSTEPS=TSTEPS, A=A, B=B) - # Compute ground truth and validate ground_truth(TSTEPS, A_ref, B_ref) assert np.allclose(A, A_ref) return sdfg +def run_jacobi_1d_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + TSTEPS, N = (20, 30) + A, B = initialize(N) + jax_A, jax_B = np.copy(A), np.copy(B) + + # Intiialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(TSTEPS: dc.int64, A: dc.float64[N], B: dc.float64[N]): + jacobi_1d_kernel(TSTEPS, A, B) + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(TSTEPS, A, B, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda TSTEPS, A, B: jacobi_1d_jax_kernel(jax, jnp, TSTEPS, A, B) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=0) + jax_grad_A = jax_grad(TSTEPS, jax_A, jax_B) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_jacobi_1d(dace.dtypes.DeviceType.CPU) @@ -85,22 +112,22 @@ def test_gpu(): run_jacobi_1d(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_jacobi_1d(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_jacobi_1d_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_jacobi_1d(dace.dtypes.DeviceType.CPU) + run_jacobi_1d_autodiff() elif target == "gpu": run_jacobi_1d(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_jacobi_1d(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/jacobi_2d_test.py b/tests/npbench/polybench/jacobi_2d_test.py index 58ceb4c365..6ad77c4163 100644 --- a/tests/npbench/polybench/jacobi_2d_test.py +++ b/tests/npbench/polybench/jacobi_2d_test.py @@ -5,10 +5,10 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.transformation.interstate import InlineSDFG from dace.transformation.dataflow import StreamingMemory, MapFusionVertical from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass N = dc.symbol('N', dtype=dc.int32) @@ -21,6 +21,20 @@ def kernel(TSTEPS: dc.int32, A: dc.float32[N, N], B: dc.float32[N, N]): A[1:-1, 1:-1] = 0.2 * (B[1:-1, 1:-1] + B[1:-1, :-2] + B[1:-1, 2:] + B[2:, 1:-1] + B[:-2, 1:-1]) +def kernel_jax(jnp, lax, TSTEPS, A, B): + + def body_fn(carry, t): + A, B = carry + + B = B.at[1:-1, 1:-1].set(0.2 * (A[1:-1, 1:-1] + A[1:-1, :-2] + A[1:-1, 2:] + A[2:, 1:-1] + A[:-2, 1:-1])) + + A = A.at[1:-1, 1:-1].set(0.2 * (B[1:-1, 1:-1] + B[1:-1, :-2] + B[1:-1, 2:] + B[2:, 1:-1] + B[:-2, 1:-1])) + return (A, B), None + + (A, B), _ = lax.scan(body_fn, (A, B), jnp.arange(1, TSTEPS)) + return jnp.sum(A) + + def init_data(N): A = np.empty((N, N), dtype=np.float32) B = np.empty((N, N), dtype=np.float32) @@ -50,30 +64,6 @@ def run_jacobi_2d(device_type: dace.dtypes.DeviceType): sdfg(A=A, B=B, TSTEPS=TSTEPS, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = kernel.to_sdfg(simplify=True) - sdfg.apply_transformations_repeated([MapFusionVertical]) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - sm_applied = sdfg.apply_transformations_repeated([InlineSDFG, StreamingMemory], - [{}, { - 'storage': dace.StorageType.FPGA_Local - }], - print_report=True) - - assert sm_applied > 0 - - # In this case, we want to generate the top-level state as an host-based state, - # not an FPGA kernel. We need to explicitly indicate that - sdfg.states()[0].location["is_FPGA_kernel"] = False - # we need to specialize both the top-level SDFG and the nested SDFG - for sd in sdfg.all_sdfgs_recursive(): - sd.specialize(dict(N=N)) - # run program - sdfg(A=A, B=B, TSTEPS=TSTEPS) - # Compute ground truth and validate result kernel.f(TSTEPS, np_A, np_B) assert np.allclose(A, np_A) @@ -81,6 +71,38 @@ def run_jacobi_2d(device_type: dace.dtypes.DeviceType): return sdfg +def run_jacobi_2d_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + TSTEPS, N = (20, 30) + A, B = init_data(N) + jax_A, jax_B = np.copy(A), np.copy(B) + + # Intiialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def jacobi_2d_autodiff_kernel(TSTEPS: dc.int32, A: dc.float32[N, N], B: dc.float32[N, N]): + kernel(TSTEPS, A, B) + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = jacobi_2d_autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(TSTEPS, A, B, gradient_A=gradient_A, gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda TSTEPS, A, B: kernel_jax(jnp, lax, TSTEPS, A, B) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=0) + jax_grad_A = jax_grad(TSTEPS, jax_A, jax_B) + np.testing.assert_allclose(gradient_A, jax_grad_A, rtol=1e-6, atol=1e-6) + + def test_cpu(): run_jacobi_2d(dace.dtypes.DeviceType.CPU) @@ -90,22 +112,22 @@ def test_gpu(): run_jacobi_2d(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_jacobi_2d(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_jacobi_2d_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_jacobi_2d(dace.dtypes.DeviceType.CPU) + run_jacobi_2d_autodiff() elif target == "gpu": run_jacobi_2d(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_jacobi_2d(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/k2mm_test.py b/tests/npbench/polybench/k2mm_test.py index e7a26833fb..7fedd26df1 100644 --- a/tests/npbench/polybench/k2mm_test.py +++ b/tests/npbench/polybench/k2mm_test.py @@ -5,11 +5,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # NI, NJ, NK, NL @@ -31,6 +29,11 @@ def k2mm_kernel(alpha: dc.float64, beta: dc.float64, A: dc.float64[NI, NK], B: d D[:] = alpha * A @ B @ C + beta * D +def k2mm_jax(jnp, alpha, beta, A, B, C, D): + D = alpha * A @ B @ C + beta * D + return jnp.sum(D) + + def initialize(NI, NJ, NK, NL, datatype=np.float64): alpha = datatype(1.5) beta = datatype(1.2) @@ -58,25 +61,57 @@ def run_k2mm(device_type: dace.dtypes.DeviceType): sdfg = k2mm_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(alpha, beta, A, B, C, D, NI=NI, NJ=NJ, NK=NK, NL=NL) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = k2mm_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Gemm - Gemm.default_implementation = "FPGA1DSystolic" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(NI=NI, NJ=NJ, NK=NK, NL=NL)) - sdfg(alpha, beta, A, B, C, D) # Compute ground truth and validate k2mm_kernel.f(alpha, beta, A, B, C, D_ref) assert np.allclose(D, D_ref) return sdfg +def run_k2mm_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize forward data + NI, NJ, NK, NL = sizes["small"] + alpha, beta, A, B, C, D = initialize(NI, NJ, NK, NL) + + # Intiialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float64, beta: dc.float64, A: dc.float64[NI, NK], B: dc.float64[NK, NJ], + C: dc.float64[NJ, NL], D: dc.float64[NI, NL]): + k2mm_kernel(alpha, beta, A, B, C, D) + return np.sum(D) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha, + beta, + A, + B, + C, + D, + NI=NI, + NJ=NJ, + NK=NK, + NL=NL, + gradient_A=gradient_A, + gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, beta, A, B, C, D: k2mm_jax(jnp, alpha, beta, A, B, C, D) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=2)) + jax_grad_A = jax_grad(alpha, beta, A, B, C, D) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_k2mm(dace.dtypes.DeviceType.CPU) @@ -86,22 +121,22 @@ def test_gpu(): run_k2mm(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False, xilinx=False) -def test_fpga(): - return run_k2mm(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_k2mm_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_k2mm(dace.dtypes.DeviceType.CPU) + run_k2mm_autodiff() elif target == "gpu": run_k2mm(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_k2mm(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/k3mm_test.py b/tests/npbench/polybench/k3mm_test.py index 398b30e107..c726c3d9c4 100644 --- a/tests/npbench/polybench/k3mm_test.py +++ b/tests/npbench/polybench/k3mm_test.py @@ -5,11 +5,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # NI, NJ, NK, NL, NM @@ -30,6 +28,11 @@ def k3mm_kernel(A: dc.float64[NI, NK], B: dc.float64[NK, NJ], C: dc.float64[NJ, return A @ B @ C @ D +def k3mm_jax(jnp, A, B, C, D): + E = A @ B @ C @ D + return jnp.sum(E) + + def initialize(NI, NJ, NK, NL, NM, datatype=np.float64): A = np.fromfunction(lambda i, j: ((i * j + 1) % NI) / (5 * NI), (NI, NK), dtype=datatype) B = np.fromfunction(lambda i, j: ((i * (j + 1) + 2) % NJ) / (5 * NJ), (NK, NJ), dtype=datatype) @@ -54,25 +57,45 @@ def run_k3mm(device_type: dace.dtypes.DeviceType): sdfg = k3mm_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) E = sdfg(A, B, C, D, NI=NI, NJ=NJ, NK=NK, NL=NL, NM=NM) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = k3mm_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Gemm - Gemm.default_implementation = "FPGA1DSystolic" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(NI=NI, NJ=NJ, NK=NK, NL=NL, NM=NM)) - E = sdfg(A, B, C, D) # Compute ground truth and validate E_ref = k3mm_kernel.f(A, B, C, D) assert np.allclose(E, E_ref) return sdfg +def run_k3mm_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize forward data + NI, NJ, NK, NL, NM = sizes["small"] + A, B, C, D = initialize(NI, NJ, NK, NL, NM) + + # Intiialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float64[NI, NK], B: dc.float64[NK, NJ], C: dc.float64[NJ, NM], D: dc.float64[NM, NL]): + E = k3mm_kernel(A, B, C, D) + return np.sum(E) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, B, C, D, NI=NI, NJ=NJ, NK=NK, NL=NL, NM=NM, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda A, B, C, D: k3mm_jax(jnp, A, B, C, D) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_A = jax_grad(A, B, C, D) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_k3mm(dace.dtypes.DeviceType.CPU) @@ -82,22 +105,22 @@ def test_gpu(): run_k3mm(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False, xilinx=False) -def test_fpga(): - return run_k3mm(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_k3mm_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_k3mm(dace.dtypes.DeviceType.CPU) + run_k3mm_autodiff() elif target == "gpu": run_k3mm(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_k3mm(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/lu_test.py b/tests/npbench/polybench/lu_test.py index 1786503918..6f812d0f25 100644 --- a/tests/npbench/polybench/lu_test.py +++ b/tests/npbench/polybench/lu_test.py @@ -6,10 +6,10 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.transformation.interstate import InlineSDFG from dace.transformation.dataflow import StreamingMemory, MapFusionVertical, StreamingComposition, PruneConnectors -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass N = dc.symbol('N', dtype=dc.int32) @@ -52,6 +52,43 @@ def init_data(N): return A +def lu_jax_kernel(jnp, lax, A): + n = A.shape[0] + + def outer_loop_body(A, i): + + def inner_loop_1_body(A, j): + + def update_fn(_): + mask = jnp.arange(n) < j + A_slice_1 = jnp.where(mask, A[i, :], 0.0) + A_slice_2 = jnp.where(mask, A[:, j], 0.0) + new_val = (A[i, j] - A_slice_1 @ A_slice_2) / A[j, j] + return A.at[i, j].set(new_val) + + A = lax.cond(j < i, lambda _: update_fn(None), lambda _: A, operand=None) + return A, None + + def inner_loop_2_body(A, j): + + def update_fn(_): + mask = jnp.arange(n) < i + A_slice_1 = jnp.where(mask, A[i, :], 0.0) + A_slice_2 = jnp.where(mask, A[:, j], 0.0) + new_val = A[i, j] - A_slice_1 @ A_slice_2 + return A.at[i, j].set(new_val) + + A = lax.cond(j >= i, lambda _: update_fn(None), lambda _: A, operand=None) + return A, None + + A, _ = lax.scan(inner_loop_1_body, A, jnp.arange(n)) + A, _ = lax.scan(inner_loop_2_body, A, jnp.arange(n)) + return A, None + + A, _ = lax.scan(outer_loop_body, A, jnp.arange(n)) + return jnp.sum(A) + + def run_lu(device_type: dace.dtypes.DeviceType): """ Runs LU for the given device @@ -70,35 +107,43 @@ def run_lu(device_type: dace.dtypes.DeviceType): auto_optimize(sdfg, device=device_type) dace_res = sdfg(A=A, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = lu_kernel.to_sdfg(simplify=True) + # Compute ground truth and validate result + ground_truth(N, gt_A) + diff = np.linalg.norm(gt_A - A) / np.linalg.norm(gt_A) + assert diff < 1e-5 + return sdfg + - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 +def run_lu_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Dot - platform = dace.config.Config.get("compiler", "fpga", "vendor") - if platform == "intel_fpga": - Dot.default_implementation = "FPGA_Accumulate" - else: - Dot.default_implementation = "FPGA_PartialSums" + # Initialize data (polybench mini size) + N = 5 + A = init_data(N) + A_jax = jnp.copy(A) - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG]) + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float32) - fpga_auto_opt.fpga_rr_interleave_containers_to_banks(sdfg) - fpga_auto_opt.fpga_global_to_local(sdfg) + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float32[N, N]): + lu_kernel(A) + return np.sum(A) - sdfg.specialize(dict(N=N)) - dace_res = sdfg(A=A) + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, N=N, gradient_A=gradient_A, gradient___return=gradient___return) - # Compute ground truth and validate result - ground_truth(N, gt_A) - diff = np.linalg.norm(gt_A - A) / np.linalg.norm(gt_A) - assert diff < 1e-5 - return sdfg + # Numerically validate vs JAX + jax_kernel = lambda A: lu_jax_kernel(jnp, lax, A) + jax_grad = jax.jit(jax.grad(jax_kernel)) + jax_grad_A = jax_grad(A_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A, rtol=1e-5, atol=1e-5) def test_cpu(): @@ -110,22 +155,22 @@ def test_gpu(): run_lu(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False, xilinx=False) -def test_fpga(): - return run_lu(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_lu_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_lu(dace.dtypes.DeviceType.CPU) + run_lu_autodiff() elif target == "gpu": run_lu(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_lu(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/ludcmp_test.py b/tests/npbench/polybench/ludcmp_test.py index 2ffa681616..509ab75f98 100644 --- a/tests/npbench/polybench/ludcmp_test.py +++ b/tests/npbench/polybench/ludcmp_test.py @@ -1,15 +1,15 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. # Original application code: NPBench - https://github.com/spcl/npbench +import os import dace.dtypes import numpy as np import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition +from dace.transformation.interstate import InlineSDFG from dace.transformation.auto.auto_optimize import auto_optimize from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Dataset sizes # TSTEPS, N @@ -69,6 +69,100 @@ def ground_truth(A, b): return x, y +def ludcmp_jax_kernel(jnp, lax, A, b): + n = A.shape[0] + x = jnp.zeros_like(b) + y = jnp.zeros_like(b) + + def outer_loop_body_1(A, i): + + def inner_loop_1_body(A, j): + + def update(): + A_slice_1 = jnp.where(jnp.arange(n) < j, A[i, :], 0.0) + A_slice_2 = jnp.where(jnp.arange(n) < j, A[:, j], 0.0) + new_val = (A[i, j] - A_slice_1 @ A_slice_2) / A[j, j] + return A.at[i, j].set(new_val) + + A = lax.cond(j < i, lambda _: update(), lambda _: A, operand=None) + return A, None + + def inner_loop_2_body(A, j): + + def update(): + A_slice_1 = jnp.where(jnp.arange(n) < i, A[i, :], 0.0) + A_slice_2 = jnp.where(jnp.arange(n) < i, A[:, j], 0.0) + new_val = A[i, j] - A_slice_1 @ A_slice_2 + return A.at[i, j].set(new_val) + + A = lax.cond(j >= i, lambda _: update(), lambda _: A, operand=None) + return A, None + + A, _ = lax.scan(inner_loop_1_body, A, jnp.arange(n)) + A, _ = lax.scan(inner_loop_2_body, A, jnp.arange(n)) + return A, None + + A, _ = lax.scan(outer_loop_body_1, A, jnp.arange(n)) + + def loop_body_2_scan(loop_vars, i): + A, y, b = loop_vars + A_slice = jnp.where(jnp.arange(n) < i, A[i, :], 0.0) + y_slice = jnp.where(jnp.arange(n) < i, y, 0.0) + new_y = b[i] - A_slice @ y_slice + y = y.at[i].set(new_y) + return (A, y, b), None + + (A, y, b), _ = lax.scan(loop_body_2_scan, (A, y, b), jnp.arange(n)) + + def loop_body_3_scan(loop_vars, t): + A, x, y = loop_vars + i = n - 1 - t # reverse order + A_slice = jnp.where(jnp.arange(n) > i, A[i, :], 0.0) + x_slice = jnp.where(jnp.arange(n) > i, x, 0.0) + new_x = (y[i] - A_slice @ x_slice) / A[i, i] + x = x.at[i].set(new_x) + return (A, x, y), None + + (A, x, y), _ = lax.scan(loop_body_3_scan, (A, x, y), jnp.arange(n)) + + return jnp.sum(x) + + +def run_ludcmp_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + N = sizes["mini"] + A, b = initialize(N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float64[N, N], b: dc.float64[N]): + x, y = ludcmp_kernel(A, b) + return np.sum(x) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, b, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda A, b: ludcmp_jax_kernel(jnp, lax, A, b) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + A_jax, b_jax = initialize(N) + jax_grad_A = jax_grad(A_jax, b_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def run_ludcmp(device_type: dace.dtypes.DeviceType): ''' Runs Ludcmp for the given device @@ -85,20 +179,6 @@ def run_ludcmp(device_type: dace.dtypes.DeviceType): sdfg = ludcmp_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) x, y = sdfg(A, b, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = ludcmp_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Dot - Dot.default_implementation = "FPGA_PartialSums" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N)) - x, y = sdfg(A, b) - # Compute ground truth and validate x_ref, y_ref = ground_truth(A_ref, b) assert np.allclose(x, x_ref) @@ -115,22 +195,27 @@ def test_gpu(): run_ludcmp(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_ludcmp(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + # Serialization causes issues, we temporarily disable it + # TODO: open an issue to fix the serialization stability problem + last_value = os.environ.get('DACE_testing_serialization', '0') + os.environ['DACE_testing_serialization'] = '0' + run_ludcmp_autodiff() + os.environ['DACE_testing_serialization'] = last_value if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_ludcmp(dace.dtypes.DeviceType.CPU) + run_ludcmp_autodiff() elif target == "gpu": run_ludcmp(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_ludcmp(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/mvt_test.py b/tests/npbench/polybench/mvt_test.py index a45024a15c..f3c35dfbd5 100644 --- a/tests/npbench/polybench/mvt_test.py +++ b/tests/npbench/polybench/mvt_test.py @@ -5,11 +5,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # N @@ -35,6 +33,12 @@ def initialize(N, datatype=np.float64): return x1, x2, y_1, y_2, A +def mvt_jax_kernel(jnp, x1, x2, y_1, y_2, A): + x1 += A @ y_1 + x2 += y_2 @ A + return jnp.sum(x2) + + def run_mvt(device_type: dace.dtypes.DeviceType): ''' Runs MVT for the given device @@ -52,20 +56,6 @@ def run_mvt(device_type: dace.dtypes.DeviceType): sdfg = mvt_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(x1, x2, y_1, y_2, A, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = mvt_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Gemv - Gemv.default_implementation = "FPGA_Accumulate" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N)) - sdfg(x1, x2, y_1, y_2, A) - # Compute ground truth and validate mvt_kernel.f(x1_ref, x2_ref, y_1, y_2, A) assert np.allclose(x1, x1_ref) @@ -73,6 +63,41 @@ def run_mvt(device_type: dace.dtypes.DeviceType): return sdfg +def run_mvt_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + N = sizes["mini"] + x1, x2, y_1, y_2, A = initialize(N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(x1: dc.float64[N], x2: dc.float64[N], y_1: dc.float64[N], y_2: dc.float64[N], A: dc.float64[N, + N]): + mvt_kernel(x1, x2, y_1, y_2, A) + return np.sum(x2) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(x1, x2, y_1, y_2, A, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda x1, x2, y_1, y_2, A: mvt_jax_kernel(jnp, x1, x2, y_1, y_2, A) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=4)) + x1_jax, x2_jax, y_1_jax, y_2_jax, A_jax = initialize(N) + jax_grad_A = jax_grad(x1_jax, x2_jax, y_1_jax, y_2_jax, A_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_mvt(dace.dtypes.DeviceType.CPU) @@ -82,22 +107,22 @@ def test_gpu(): run_mvt(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_mvt(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_mvt_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_mvt(dace.dtypes.DeviceType.CPU) + run_mvt_autodiff() elif target == "gpu": run_mvt(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_mvt(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/nussinov_test.py b/tests/npbench/polybench/nussinov_test.py index 0e52cafc99..6b96fd94ee 100644 --- a/tests/npbench/polybench/nussinov_test.py +++ b/tests/npbench/polybench/nussinov_test.py @@ -6,10 +6,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.transformation.interstate import InlineSDFG from dace.transformation.dataflow import StreamingMemory, MapFusionVertical, StreamingComposition, PruneConnectors -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.transformation.auto.auto_optimize import auto_optimize from dace.config import set_temporary N = dc.symbol('N', dtype=dc.int32) @@ -100,18 +99,6 @@ def run_nussinov(device_type: dace.dtypes.DeviceType): sdfg = auto_optimize(sdfg, device_type) dace_res = sdfg(seq=seq, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - fpga_auto_opt.fpga_global_to_local(sdfg) # Necessary - fpga_auto_opt.fpga_rr_interleave_containers_to_banks(sdfg) - - sdfg.specialize(dict(N=N)) - dace_res = sdfg(seq=seq) - # Compute ground truth and validate result gt_res = ground_truth(N, seq) @@ -128,21 +115,10 @@ def test_gpu(): run_nussinov(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_nussinov(dace.dtypes.DeviceType.FPGA) - - -@xilinx_test(assert_ii_1=False) -def test_xilinx_decoupled_array_interfaces(): - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - return run_nussinov(dace.dtypes.DeviceType.FPGA) - - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] @@ -151,5 +127,3 @@ def test_xilinx_decoupled_array_interfaces(): run_nussinov(dace.dtypes.DeviceType.CPU) elif target == "gpu": run_nussinov(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_nussinov(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/seidel_2d_test.py b/tests/npbench/polybench/seidel_2d_test.py index 7d1f8a8389..0b4b0aeff0 100644 --- a/tests/npbench/polybench/seidel_2d_test.py +++ b/tests/npbench/polybench/seidel_2d_test.py @@ -5,11 +5,10 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition +from dace.transformation.interstate import InlineSDFG from dace.transformation.auto.auto_optimize import auto_optimize from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Dataset sizes # TSTEPS, N @@ -35,6 +34,32 @@ def initialize(N, datatype=np.float64): return A +def seidel_2d_jax_kernel(jnp, lax, TSTEPS, A): + """JAX implementation using efficient lax.scan operations""" + N = A.shape[0] + + def loop1_body(A, t): + + def loop2_body(A, i): + update_val = (A[i, 1:-1] + (A[i - 1, :-2] + A[i - 1, 1:-1] + A[i - 1, 2:] + A[i, 2:] + A[i + 1, :-2] + + A[i + 1, 1:-1] + A[i + 1, 2:])) + A = A.at[i, 1:-1].set(update_val) + + def loop3_body(A, j): + new_val = (A[i, j] + A[i, j - 1]) / 9.0 + A = A.at[i, j].set(new_val) + return A, None + + A, _ = lax.scan(loop3_body, A, jnp.arange(1, N - 1)) + return A, None + + A, _ = lax.scan(loop2_body, A, jnp.arange(1, N - 1)) + return A, None + + A, _ = lax.scan(loop1_body, A, jnp.arange(TSTEPS - 1)) + return jnp.sum(A) + + def ground_truth(TSTEPS, N, A): for t in range(0, TSTEPS - 1): @@ -62,20 +87,6 @@ def run_seidel_2d(device_type: dace.dtypes.DeviceType): sdfg = seidel_2d_kernel.to_sdfg() # sdfg = auto_optimize(sdfg, device_type) # TBD sdfg(TSTEPS, A, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = seidel_2d_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Dot - Dot.default_implementation = "FPGA_PartialSums" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N)) - sdfg(TSTEPS, A) - # Compute ground truth and validate ground_truth( TSTEPS, @@ -87,6 +98,47 @@ def run_seidel_2d(device_type: dace.dtypes.DeviceType): return sdfg +def run_seidel_2d_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (test size for efficiency) + TSTEPS, N = (2, 8) + A = initialize(N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output using __return pattern + @dc.program + def autodiff_kernel(TSTEPS: dc.int64, A: dc.float64[N, N]): + for t in range(0, TSTEPS - 1): + for i in range(1, N - 1): + A[i, 1:-1] += (A[i - 1, :-2] + A[i - 1, 1:-1] + A[i - 1, 2:] + A[i, 2:] + A[i + 1, :-2] + + A[i + 1, 1:-1] + A[i + 1, 2:]) + for j in range(1, N - 1): + A[i, j] += A[i, j - 1] + A[i, j] /= 9.0 + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(TSTEPS, A, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda TSTEPS, A: seidel_2d_jax_kernel(jnp, lax, TSTEPS, A) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=(0, )) + A_jax = initialize(N) + jax_grad_A = jax_grad(TSTEPS, A_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_seidel_2d(dace.dtypes.DeviceType.CPU) @@ -96,22 +148,22 @@ def test_gpu(): run_seidel_2d(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_seidel_2d(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_seidel_2d_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_seidel_2d(dace.dtypes.DeviceType.CPU) + run_seidel_2d_autodiff() elif target == "gpu": run_seidel_2d(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_seidel_2d(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/symm_test.py b/tests/npbench/polybench/symm_test.py index d0bae1edfc..17a4a5232b 100644 --- a/tests/npbench/polybench/symm_test.py +++ b/tests/npbench/polybench/symm_test.py @@ -5,11 +5,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -43,6 +41,32 @@ def initialize(M, N, datatype=np.float64): return alpha, beta, C, A, B +def symm_jax_kernel(jnp, lax, alpha, beta, C, A, B): + temp2 = jnp.empty((C.shape[1], ), dtype=C.dtype) + C = C * beta + + def row_update_body(carry, i): + C, temp2 = carry + + def col_update_body(carry_inner, j): + C, temp2 = carry_inner + + A_slice = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0.0) + B_slice = jnp.where(jnp.arange(B.shape[0]) < i, B[:, j], 0.0) + + updated_col = C[:, j] + (alpha * B[i, j] * A_slice) + C = lax.dynamic_update_slice(C, updated_col[:, None], (0, j)) + temp2 = temp2.at[j].set(B_slice @ A_slice) + return (C, temp2), jnp.array(0) + + (C, temp2), _ = lax.scan(col_update_body, (C, temp2), jnp.arange(C.shape[1])) + C = C.at[i, :].add(alpha * B[i, :] * A[i, i] + alpha * temp2) + return (C, temp2), jnp.array(0) + + (C, temp2), _ = lax.scan(row_update_body, (C, temp2), jnp.arange(C.shape[0])) + return jnp.sum(C) + + def ground_truth(alpha, beta, C, A, B): temp2 = np.empty((C.shape[1], ), dtype=C.dtype) @@ -70,26 +94,48 @@ def run_symm(device_type: dace.dtypes.DeviceType): sdfg = symm_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(alpha, beta, C, A, B, M=M, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = symm_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Dot - Dot.default_implementation = "FPGA_PartialSums" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(M=M, N=N)) - sdfg(alpha, beta, C, A, B) - # Compute ground truth and validate ground_truth(alpha, beta, C_ref, A, B) assert np.allclose(C, C_ref) return sdfg +def run_symm_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + alpha, beta, C, A, B = initialize(M, N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float64, beta: dc.float64, C: dc.float64[M, N], A: dc.float64[M, M], + B: dc.float64[M, N]): + symm_kernel(alpha, beta, C, A, B) + return np.sum(C) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha, beta, C, A, B, M=M, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, beta, C, A, B: symm_jax_kernel(jnp, lax, alpha, beta, C, A, B) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=3), static_argnums=(0, 1)) + alpha, beta, C_jax, A_jax, B_jax = initialize(M, N) + jax_grad_A = jax_grad(alpha, beta, C_jax, A_jax, B_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_symm(dace.dtypes.DeviceType.CPU) @@ -99,22 +145,22 @@ def test_gpu(): run_symm(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_symm(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_symm_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_symm(dace.dtypes.DeviceType.CPU) + run_symm_autodiff() elif target == "gpu": run_symm(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_symm(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/syr2k_test.py b/tests/npbench/polybench/syr2k_test.py index 0edb5a3045..411cc206e6 100644 --- a/tests/npbench/polybench/syr2k_test.py +++ b/tests/npbench/polybench/syr2k_test.py @@ -5,11 +5,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -37,6 +35,51 @@ def initialize(M, N, datatype=np.float64): return alpha, beta, C, A, B +def syr2k_jax_kernel(jnp, lax, alpha, beta, C, A, B): + m = A.shape[0] # outer loop range + n = A.shape[1] # inner loop range + + def outer_body_fun(carry, i): + # Unpack loop variables for the outer loop. + alpha, beta, C, A, B = carry + + # Outer-loop update: scale row i of C by beta, but only for columns < i+1. + C_slice = jnp.where(jnp.arange(C.shape[1]) < (i + 1), C[i, :], 0.0) + C_slice = C_slice * beta + C_slice = jnp.where(jnp.arange(C.shape[1]) < (i + 1), C_slice, C[i, :]) + C = lax.dynamic_update_slice(C, C_slice[None, :], (i, 0)) + + # Define the inner scan that will update row i of C using index k. + def inner_body_fun(inner_carry, k): + # Unpack inner loop variables. + alpha_inner, C_inner, A_inner, B_inner = inner_carry + + # For A_update_slice and B_update_slice, only entries for indices < i+1 are used. + A_update_slice = jnp.where(jnp.arange(A_inner.shape[0]) < (i + 1), A_inner[:, k], 0.0) + A_update_slice = A_update_slice * (alpha_inner * B_inner[i, k]) + + B_update_slice = jnp.where(jnp.arange(B_inner.shape[0]) < (i + 1), B_inner[:, k], 0.0) + B_update_slice = B_update_slice * (alpha_inner * A_inner[i, k]) + + # Compute an update for row i of C: take its current values (only for indices < i+1) + # and add the contributions from A_update_slice and B_update_slice. + C_update_slice = jnp.where(jnp.arange(C_inner.shape[1]) < (i + 1), C_inner[i, :], 0.0) + C_update_slice = C_update_slice + A_update_slice + B_update_slice + # For indices not less than i+1, keep the original C[i, :]. + C_update_slice = jnp.where(jnp.arange(C_inner.shape[1]) < (i + 1), C_update_slice, C_inner[i, :]) + # Update row i of C. + C_inner = lax.dynamic_update_slice(C_inner, C_update_slice[None, :], (i, 0)) + return (alpha_inner, C_inner, A_inner, B_inner), None + + # Run the inner scan over k from 0 to n-1. + (alpha, C, A, B), _ = lax.scan(inner_body_fun, (alpha, C, A, B), jnp.arange(n)) + return (alpha, beta, C, A, B), None + + # Run the outer scan over i from 0 to m-1. + (alpha, beta, C, A, B), _ = lax.scan(outer_body_fun, (alpha, beta, C, A, B), jnp.arange(m)) + return jnp.sum(C) + + def ground_truth(alpha, beta, C, A, B): for i in range(A.shape[0]): @@ -61,23 +104,48 @@ def run_syr2k(device_type: dace.dtypes.DeviceType): sdfg = syr2k_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(alpha, beta, C, A, B, M=M, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = syr2k_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # No libnodes expansion for this kernel - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(M=M, N=N)) - sdfg(alpha, beta, C, A, B) - # Compute ground truth and validate ground_truth(alpha, beta, C_ref, A, B) assert np.allclose(C, C_ref) return sdfg +def run_syr2k_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + alpha, beta, C, A, B = initialize(M, N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float64, beta: dc.float64, C: dc.float64[N, N], A: dc.float64[N, M], + B: dc.float64[N, M]): + syr2k_kernel(alpha, beta, C, A, B) + return np.sum(C) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha, beta, C, A, B, M=M, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, beta, C, A, B: syr2k_jax_kernel(jnp, lax, alpha, beta, C, A, B) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=3), static_argnums=(0, 1)) + alpha, beta, C_jax, A_jax, B_jax = initialize(M, N) + jax_grad_A = jax_grad(alpha, beta, C_jax, A_jax, B_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_syr2k(dace.dtypes.DeviceType.CPU) @@ -87,22 +155,22 @@ def test_gpu(): run_syr2k(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_syr2k(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_syr2k_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_syr2k(dace.dtypes.DeviceType.CPU) + run_syr2k_autodiff() elif target == "gpu": run_syr2k(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_syr2k(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/syrk_test.py b/tests/npbench/polybench/syrk_test.py index 6e92411128..33771b37be 100644 --- a/tests/npbench/polybench/syrk_test.py +++ b/tests/npbench/polybench/syrk_test.py @@ -6,10 +6,10 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.transformation.interstate import InlineSDFG from dace.transformation.dataflow import StreamingMemory, MapFusionVertical, StreamingComposition, PruneConnectors -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # M, N sizes = {"mini": (20, 30), "small": (60, 80), "medium": (200, 240), "large": (1000, 1200), "extra-large": (2000, 2600)} @@ -42,6 +42,48 @@ def init_data(N, M): return alpha, beta, C, A +def syrk_jax_kernel(jnp, lax, alpha, beta, C, A): + m = A.shape[0] # number of rows + n = A.shape[1] # number of columns + + def outer_body_fun(carry, i): + # Unpack outer loop carry. + alpha, beta, C, A = carry + + # Outer loop update: scale row i of C by beta for indices < i+1. + col_mask = jnp.arange(C.shape[1]) < (i + 1) + C_slice = jnp.where(col_mask, C[i, :], 0.0) + C_slice = C_slice * beta + # Preserve the original values for indices >= i+1. + C_slice = jnp.where(col_mask, C_slice, C[i, :]) + C = lax.dynamic_update_slice(C, C_slice[None, :], (i, 0)) + + # Define the inner loop which updates row i of C using column updates from A. + def inner_body_fun(inner_carry, k): + alpha_inner, C_inner, A_inner = inner_carry + + # Compute an update slice from A[:, k] for rows < i+1. + row_mask = jnp.arange(A_inner.shape[0]) < (i + 1) + A_update_slice = jnp.where(row_mask, A_inner[:, k], 0.0) + A_update_slice = A_update_slice * (alpha_inner * A_inner[i, k]) + + # Update C[i, :] by adding the A_update_slice, only for columns < i+1. + col_mask_inner = jnp.arange(C_inner.shape[1]) < (i + 1) + C_update_slice = jnp.where(col_mask_inner, C_inner[i, :], 0.0) + C_update_slice = C_update_slice + A_update_slice + C_update_slice = jnp.where(col_mask_inner, C_update_slice, C_inner[i, :]) + C_inner = lax.dynamic_update_slice(C_inner, C_update_slice[None, :], (i, 0)) + return (alpha_inner, C_inner, A_inner), None + + # Run the inner loop over k = 0,..., n-1. + (alpha, C, A), _ = lax.scan(inner_body_fun, (alpha, C, A), jnp.arange(n)) + return (alpha, beta, C, A), None + + # Run the outer loop over i = 0,..., m-1. + (alpha, beta, C, A), _ = lax.scan(outer_body_fun, (alpha, beta, C, A), jnp.arange(m)) + return jnp.sum(C) + + def ground_truth(N, M, alpha, beta, C, A): for i in range(N): @@ -68,24 +110,44 @@ def run_syrk(device_type: dace.dtypes.DeviceType): sdfg = auto_optimize(sdfg, device_type) sdfg(alpha=alpha, beta=beta, C=C, A=A, M=M, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - fpga_auto_opt.fpga_global_to_local(sdfg) - fpga_auto_opt.fpga_rr_interleave_containers_to_banks(sdfg) - sdfg.specialize(dict(N=N, M=M)) - # run program - sdfg(alpha=alpha, beta=beta, C=C, A=A) - # Compute ground truth and validate result ground_truth(N, M, alpha, beta, gt_C, A) assert np.allclose(C, gt_C) return sdfg +def run_syrk_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) - note the order swap for this test + M, N = sizes["mini"] + alpha, beta, C, A = init_data(N, M) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float32, beta: dc.float32, C: dc.float32[N, N], A: dc.float32[N, M]): + kernel(alpha, beta, C, A) + return np.sum(C) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha=alpha, beta=beta, C=C, A=A, M=M, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, beta, C, A: syrk_jax_kernel(jnp, lax, alpha, beta, C, A) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=3), static_argnums=(0, 1)) + alpha, beta, C_jax, A_jax = init_data(N, M) + jax_grad_A = jax_grad(alpha, beta, C_jax, A_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A, rtol=1e-6, atol=1e-5) + + def test_cpu(): run_syrk(dace.dtypes.DeviceType.CPU) @@ -95,22 +157,22 @@ def test_gpu(): run_syrk(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_syrk(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_syrk_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_syrk(dace.dtypes.DeviceType.CPU) + run_syrk_autodiff() elif target == "gpu": run_syrk(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_syrk(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/trisolv_test.py b/tests/npbench/polybench/trisolv_test.py index d9ec2a1802..af3be43b0a 100644 --- a/tests/npbench/polybench/trisolv_test.py +++ b/tests/npbench/polybench/trisolv_test.py @@ -5,11 +5,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # N @@ -30,6 +28,20 @@ def initialize(N, datatype=np.float64): return L, x, b +def trisolv_jax_kernel(jnp, lax, L, x, b): + + def scan_body(carry, i): + L, x, b = carry + mask = jnp.arange(x.shape[0]) < i + products = jnp.where(mask, L[i, :] * x, 0.0) + dot_product = jnp.sum(products) + x = x.at[i].set((b[i] - dot_product) / L[i, i]) + return (L, x, b), None + + (L, x, b), _ = lax.scan(scan_body, (L, x, b), jnp.arange(x.shape[0])) + return jnp.sum(x) + + def ground_truth(L, x, b): for i in range(x.shape[0]): x[i] = (b[i] - L[i, :i] @ x[:i]) / L[i, i] @@ -51,26 +63,47 @@ def run_trisolv(device_type: dace.dtypes.DeviceType): sdfg = trisolv_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(L, x, np.copy(b), N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = trisolv_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Dot - Dot.default_implementation = "FPGA_PartialSums" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(N=N)) - sdfg(L, x, np.copy(b)) - # Compute ground truth and validate ground_truth(L, x_ref, b) assert np.allclose(x, x_ref) return sdfg +def run_trisolv_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + N = sizes["mini"] + L, x, b = initialize(N) + + # Initialize gradient computation data + gradient_L = np.zeros_like(L) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(L: dc.float64[N, N], x: dc.float64[N], b: dc.float64[N]): + trisolv_kernel(L, x, b) + return np.sum(x) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["L"], outputs=["__return"]) + sdfg(L, x, np.copy(b), N=N, gradient_L=gradient_L, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda L, x, b: trisolv_jax_kernel(jnp, lax, L, x, b) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + L_jax, x_jax, b_jax = initialize(N) + jax_grad_L = jax_grad(L_jax, x_jax, b_jax) + np.testing.assert_allclose(gradient_L, jax_grad_L) + + def test_cpu(): run_trisolv(dace.dtypes.DeviceType.CPU) @@ -80,22 +113,22 @@ def test_gpu(): run_trisolv(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False, xilinx=False) -def test_fpga(): - return run_trisolv(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_trisolv_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_trisolv(dace.dtypes.DeviceType.CPU) + run_trisolv_autodiff() elif target == "gpu": run_trisolv(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_trisolv(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/polybench/trmm_test.py b/tests/npbench/polybench/trmm_test.py index 51e2367df1..110591d07e 100644 --- a/tests/npbench/polybench/trmm_test.py +++ b/tests/npbench/polybench/trmm_test.py @@ -1,15 +1,14 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. # Original application code: NPBench - https://github.com/spcl/npbench +import os import dace.dtypes import numpy as np import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.interstate import InlineSDFG +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -37,6 +36,27 @@ def initialize(M, N, datatype=np.float64): return alpha, A, B +def trmm_jax_kernel(jnp, lax, alpha, A, B): + + def outer_body(carry, i): + B = carry + + def inner_body(B, j): + + mask = (jnp.arange(A.shape[0]) > i).astype(A.dtype) + dot_val = jnp.sum(A[:, i] * B[:, j] * mask) + new_val = B[i, j] + dot_val + B = B.at[i, j].set(new_val) + return B, jnp.array(0) + + B, _ = lax.scan(inner_body, B, jnp.arange(B.shape[1])) + return B, jnp.array(0) + + B, _ = lax.scan(outer_body, B, jnp.arange(B.shape[0])) + B = B * alpha + return jnp.sum(B) + + def ground_truth(alpha, A, B): for i in range(B.shape[0]): for j in range(B.shape[1]): @@ -60,26 +80,47 @@ def run_trmm(device_type: dace.dtypes.DeviceType): sdfg = trmm_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(alpha, A, B, M=M, N=N) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = trmm_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - # Use FPGA Expansion for lib nodes, and expand them to enable further optimizations - from dace.libraries.blas import Dot - Dot.default_implementation = "FPGA_PartialSums" - sdfg.expand_library_nodes() - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(M=M, N=N)) - sdfg(alpha, A, B) - # Compute ground truth and validate ground_truth(alpha, A, B_ref) assert np.allclose(B, B_ref) return sdfg +def run_trmm_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + alpha, A, B = initialize(M, N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float64, A: dc.float64[M, M], B: dc.float64[M, N]): + trmm_kernel(alpha, A, B) + return np.sum(B) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha, A, B, M=M, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, A, B: trmm_jax_kernel(jnp, lax, alpha, A, B) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=(0, )) + alpha, A_jax, B_jax = initialize(M, N) + jax_grad_A = jax_grad(alpha, A_jax, B_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_trmm(dace.dtypes.DeviceType.CPU) @@ -89,22 +130,27 @@ def test_gpu(): run_trmm(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_trmm(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + # Serialization causes issues, we temporarily disable it + # TODO: open an issue to fix the serialization stability problem + last_value = os.environ.get('DACE_testing_serialization', '0') + os.environ['DACE_testing_serialization'] = '0' + run_trmm_autodiff() + os.environ['DACE_testing_serialization'] = last_value if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_trmm(dace.dtypes.DeviceType.CPU) + run_trmm_autodiff() elif target == "gpu": run_trmm(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_trmm(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/weather_stencils/hdiff_test.py b/tests/npbench/weather_stencils/hdiff_test.py index bd0150af91..8e16f3a098 100644 --- a/tests/npbench/weather_stencils/hdiff_test.py +++ b/tests/npbench/weather_stencils/hdiff_test.py @@ -5,10 +5,8 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass I, J, K = (dc.symbol(s, dtype=dc.int64) for s in ('I', 'J', 'K')) @@ -38,6 +36,32 @@ def hdiff_kernel(in_field: dc.float64[I + 4, J + 4, K], out_field: dc.float64[I, fly_field[:, 1:, :] - fly_field[:, :-1, :]) +def hdiff_jax_kernel(jnp, in_field, out_field, coeff): + I, J, K = out_field.shape[0], out_field.shape[1], out_field.shape[2] + lap_field = 4.0 * in_field[1:I + 3, 1:J + 3, :] - (in_field[2:I + 4, 1:J + 3, :] + in_field[0:I + 2, 1:J + 3, :] + + in_field[1:I + 3, 2:J + 4, :] + in_field[1:I + 3, 0:J + 2, :]) + + res = lap_field[1:, 1:J + 1, :] - lap_field[:-1, 1:J + 1, :] + flx_field = jnp.where( + (res * (in_field[2:I + 3, 2:J + 2, :] - in_field[1:I + 2, 2:J + 2, :])) > 0, + 0, + res, + ) + + res = lap_field[1:I + 1, 1:, :] - lap_field[1:I + 1, :-1, :] + fly_field = jnp.where( + (res * (in_field[2:I + 2, 2:J + 3, :] - in_field[2:I + 2, 1:J + 2, :])) > 0, + 0, + res, + ) + + out_field = out_field.at[:, :, :].set( + in_field[2:I + 2, 2:J + 2, :] - coeff[:, :, :] * + (flx_field[1:, :, :] - flx_field[:-1, :, :] + fly_field[:, 1:, :] - fly_field[:, :-1, :])) + + return jnp.sum(out_field) + + def initialize(I, J, K): from numpy.random import default_rng rng = default_rng(42) @@ -89,15 +113,6 @@ def run_hdiff(device_type: dace.dtypes.DeviceType): sdfg = hdiff_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(in_field, out_field, coeff, I=I, J=J, K=K) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = hdiff_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(I=I, J=J, K=K)) - sdfg(in_field, out_field, coeff) # Compute ground truth and validate ground_truth(in_field, out_field_ref, coeff) @@ -105,6 +120,47 @@ def run_hdiff(device_type: dace.dtypes.DeviceType): return sdfg +def run_hdiff_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (npbench small size) + I, J, K = 64, 64, 60 + in_field, out_field, coeff = initialize(I, J, K) + + # Initialize gradient computation data + gradient_in_field = np.zeros_like(in_field) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(in_field: dc.float64[I + 4, J + 4, K], out_field: dc.float64[I, J, K], coeff: dc.float64[I, J, + K]): + hdiff_kernel(in_field, out_field, coeff) + return np.sum(out_field) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["in_field"], outputs=["__return"]) + sdfg(in_field, + out_field, + coeff, + I=I, + J=J, + K=K, + gradient_in_field=gradient_in_field, + gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda in_field, out_field, coeff: hdiff_jax_kernel(jnp, in_field, out_field, coeff) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_in_field = jax_grad(in_field, out_field, coeff) + np.testing.assert_allclose(gradient_in_field, jax_grad_in_field) + + def test_cpu(): run_hdiff(dace.dtypes.DeviceType.CPU) @@ -114,22 +170,22 @@ def test_gpu(): run_hdiff(dace.dtypes.DeviceType.GPU) -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_hdiff(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_hdiff_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_hdiff(dace.dtypes.DeviceType.CPU) + run_hdiff_autodiff() elif target == "gpu": run_hdiff(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_hdiff(dace.dtypes.DeviceType.FPGA) diff --git a/tests/npbench/weather_stencils/vadv_test.py b/tests/npbench/weather_stencils/vadv_test.py index cf01a0cd31..213c2fb783 100644 --- a/tests/npbench/weather_stencils/vadv_test.py +++ b/tests/npbench/weather_stencils/vadv_test.py @@ -5,10 +5,9 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test -from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass + # Sample constants BET_M = 0.5 BET_P = 0.5 @@ -89,9 +88,92 @@ def vadv_kernel(utens_stage: dc.float64[I, J, K], u_stage: dc.float64[I, J, K], for k in range(K - 2, -1, -1): # datacol = dcol[:, :, k] - ccol[:, :, k] * data_col[:, :] - datacol[:] = dcol[:, :, k] - ccol[:, :, k] * data_col[:, :] - data_col[:] = datacol - utens_stage[:, :, k] = dtr_stage * (datacol - u_pos[:, :, k]) + data_col[:] = dcol[:, :, k] - ccol[:, :, k] * data_col[:, :] + utens_stage[:, :, k] = dtr_stage * (data_col - u_pos[:, :, k]) + + +def vadv_jax_kernel(jnp, lax, utens_stage, u_stage, wcon, u_pos, utens, dtr_stage): + I, J, K = utens_stage.shape[0], utens_stage.shape[1], utens_stage.shape[2] + # Allocate working arrays. + ccol = jnp.empty((I, J, K), dtype=utens_stage.dtype) + dcol = jnp.empty((I, J, K), dtype=utens_stage.dtype) + data_col = jnp.empty((I, J), dtype=utens_stage.dtype) + + # --- Loop 1: for k in range(0, 1) --- + def loop1_body(carry, k): + ccol, dcol = carry + # Note: 0+1 is just 1. + gcv = 0.25 * (wcon[1:, :, 1] + wcon[:-1, :, 1]) + cs = gcv * BET_M + bs = gcv * BET_P + bcol = dtr_stage - bs + # update the d column correction term. + correction_term = -cs * (u_stage[:, :, k + 1] - u_stage[:, :, k]) + divided = 1.0 / bcol + ccol = ccol.at[:, :, k].set(bs * divided) + dcol = dcol.at[:, :, k].set( + (dtr_stage * u_pos[:, :, k] + utens[:, :, k] + utens_stage[:, :, k] + correction_term) * divided) + return (ccol, dcol), None + + (ccol, dcol), _ = lax.scan(loop1_body, (ccol, dcol), jnp.arange(0, 1)) + + # --- Loop 2: for k in range(1, K-1) --- + def loop2_body(carry, k): + ccol, dcol = carry + gav = -0.25 * (wcon[1:, :, k] + wcon[:-1, :, k]) + gcv = 0.25 * (wcon[1:, :, k + 1] + wcon[:-1, :, k + 1]) + as_ = gav * BET_M + cs = gcv * BET_M + bs = gcv * BET_P + acol = gav * BET_P + bcol = dtr_stage - acol - bs + correction_term = (-as_ * (u_stage[:, :, k - 1] - u_stage[:, :, k]) - cs * + (u_stage[:, :, k + 1] - u_stage[:, :, k])) + divided = 1.0 / (bcol - ccol[:, :, k - 1] * acol) + ccol = ccol.at[:, :, k].set(bs * divided) + dcol = dcol.at[:, :, k].set( + ((dtr_stage * u_pos[:, :, k] + utens[:, :, k] + utens_stage[:, :, k] + correction_term) - + dcol[:, :, k - 1] * acol) * divided) + return (ccol, dcol), None + + (ccol, dcol), _ = lax.scan(loop2_body, (ccol, dcol), jnp.arange(1, K - 1)) + + # --- Loop 3: for k in range(K-1, K) --- + def loop3_body(dcol, k): + gav = -0.25 * (wcon[1:, :, k] + wcon[:-1, :, k]) + as_ = gav * BET_M + acol = gav * BET_P + bcol = dtr_stage - acol + correction_term = -as_ * (u_stage[:, :, k - 1] - u_stage[:, :, k]) + divided = 1.0 / (bcol - ccol[:, :, k - 1] * acol) + dcol = dcol.at[:, :, k].set( + ((dtr_stage * u_pos[:, :, k] + utens[:, :, k] + utens_stage[:, :, k] + correction_term) - + dcol[:, :, k - 1] * acol) * divided) + return dcol, None + + dcol, _ = lax.scan(loop3_body, dcol, jnp.arange(K - 1, K)) + + # --- Loop 4: for k in range(K-1, K) --- + def loop4_body(carry, k): + data_col, utens_stage = carry + datacol = dcol[:, :, k] + data_col = data_col.at[:].set(datacol) + utens_stage = utens_stage.at[:, :, k].set(dtr_stage * (datacol - u_pos[:, :, k])) + return (data_col, utens_stage), None + + (data_col, utens_stage), _ = lax.scan(loop4_body, (data_col, utens_stage), jnp.arange(K - 1, K)) + + # --- Loop 5: for k in range(0, K-1) with reverse order --- + def loop5_body(carry, k): + data_col, utens_stage = carry + idx = (K - 2) - k # Reverse order: when k=0, idx=K-2; when k=K-2, idx=0. + datacol = dcol[:, :, idx] - ccol[:, :, idx] * data_col[:, :] + data_col = data_col.at[:].set(datacol) + utens_stage = utens_stage.at[:, :, idx].set(dtr_stage * (datacol - u_pos[:, :, idx])) + return (data_col, utens_stage), None + + (data_col, utens_stage), _ = lax.scan(loop5_body, (data_col, utens_stage), jnp.arange(0, K - 1)) + return jnp.sum(utens_stage) def initialize(I, J, K): @@ -195,15 +277,6 @@ def run_vadv(device_type: dace.dtypes.DeviceType): sdfg = vadv_kernel.to_sdfg() sdfg = auto_optimize(sdfg, device_type) sdfg(utens_stage, u_stage, wcon, u_pos, utens, dtr_stage, I=I, J=J, K=K) - elif device_type == dace.dtypes.DeviceType.FPGA: - # Parse SDFG and apply FPGA friendly optimization - sdfg = vadv_kernel.to_sdfg(simplify=True) - applied = sdfg.apply_transformations([FPGATransformSDFG]) - assert applied == 1 - - sdfg.apply_transformations_repeated([InlineSDFG], print_report=True) - sdfg.specialize(dict(I=I, J=J, K=K)) - sdfg(utens_stage, u_stage, wcon, u_pos, utens, dtr_stage) # Compute ground truth and validate ground_truth(utens_stage_ref, u_stage, wcon, u_pos, utens, dtr_stage) @@ -211,6 +284,55 @@ def run_vadv(device_type: dace.dtypes.DeviceType): return sdfg +def run_vadv_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (npbench small size) + I, J, K = 4, 4, 3 + dtr_stage, utens_stage, u_stage, wcon, u_pos, utens = initialize(I, J, K) + dtr_stage_jax, utens_stage_jax, u_stage_jax, wcon_jax, u_pos_jax, utens_jax = [ + np.copy(arr) for arr in (dtr_stage, utens_stage, u_stage, wcon, u_pos, utens) + ] + + # Initialize gradient computation data + gradient_utens = np.zeros_like(utens) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(utens_stage: dc.float64[I, J, K], u_stage: dc.float64[I, J, K], wcon: dc.float64[I + 1, J, K], + u_pos: dc.float64[I, J, K], utens: dc.float64[I, J, K], dtr_stage: dc.float64): + vadv_kernel(utens_stage, u_stage, wcon, u_pos, utens, dtr_stage) + return np.sum(utens_stage) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["utens"], outputs=["__return"]) + sdfg(utens_stage, + u_stage, + wcon, + u_pos, + utens, + dtr_stage, + I=I, + J=J, + K=K, + gradient_utens=gradient_utens, + gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda utens_stage, u_stage, wcon, u_pos, utens, dtr_stage: vadv_jax_kernel( + jnp, lax, utens_stage, u_stage, wcon, u_pos, utens, dtr_stage) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=4)) + jax_grad_utens = jax_grad(utens_stage_jax, u_stage_jax, wcon_jax, u_pos_jax, utens_jax, dtr_stage_jax) + np.testing.assert_allclose(gradient_utens, jax_grad_utens) + + def test_cpu(monkeypatch): # NOTE: Serialization fails because of "k - k" expression simplified to "0" monkeypatch.setenv("DACE_testing_serialization", 0) @@ -222,23 +344,22 @@ def test_gpu(): run_vadv(dace.dtypes.DeviceType.GPU) -@pytest.mark.skip(reason="Xilinx internal compiler error") -@fpga_test(assert_ii_1=False) -def test_fpga(): - return run_vadv(dace.dtypes.DeviceType.FPGA) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_vadv_autodiff() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu', 'fpga'], help='Target platform') + parser.add_argument("-t", "--target", default='cpu', choices=['cpu', 'gpu'], help='Target platform') args = vars(parser.parse_args()) target = args["target"] if target == "cpu": run_vadv(dace.dtypes.DeviceType.CPU) + run_vadv_autodiff() elif target == "gpu": run_vadv(dace.dtypes.DeviceType.GPU) - elif target == "fpga": - run_vadv(dace.dtypes.DeviceType.FPGA) diff --git a/tests/onnx/pure_expansions/test_conv_expansion.py b/tests/onnx/pure_expansions/test_conv_expansion.py new file mode 100644 index 0000000000..8ca3616b29 --- /dev/null +++ b/tests/onnx/pure_expansions/test_conv_expansion.py @@ -0,0 +1,61 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import dace +import dace.libraries.onnx as donnx +import torch +import torch.nn.functional as F +import numpy as np + + +@pytest.mark.onnx +@pytest.mark.parametrize("num_in_channels, kernel_size, num_filters, bias", + [(1, (3, 3), 8, True), (8, (3, 3), 3, False), (8, (5, 5), 3, True), (8, (4, 4), 3, False)]) +def test_conv_simple(num_in_channels, kernel_size, num_filters, bias): + + batch_size = 8 + + X = np.random.rand(batch_size, num_in_channels, 32, 32).astype(np.float32) + W = np.random.rand(num_filters, num_in_channels, *kernel_size).astype(np.float32) + + if bias: + B = np.random.rand(num_filters).astype(np.float32) + torch_Z = F.conv2d(torch.from_numpy(X), torch.from_numpy(W), bias=torch.from_numpy(B)).numpy() + else: + B = None + torch_Z = F.conv2d(torch.from_numpy(X), torch.from_numpy(W)).numpy() + + dace_Z = np.zeros_like(torch_Z) + + if bias: + + @dace.program + def conv(X_: dace.float32[tuple(X.shape)], W_: dace.float32[tuple(W.shape)], B_: dace.float32[tuple(B.shape)], + Z_: dace.float32[tuple(torch_Z.shape)]): + donnx.ONNXConv(X=X_, W=W_, B=B_, Y=Z_) + else: + + @dace.program + def conv(X_: dace.float32[tuple(X.shape)], W_: dace.float32[tuple(W.shape)], + Z_: dace.float32[tuple(torch_Z.shape)]): + donnx.ONNXConv(X=X_, W=W_, Y=Z_) + + sdfg = conv.to_sdfg() + sdfg.expand_library_nodes() + + if bias: + sdfg(X_=X, W_=W, Z_=dace_Z, B_=B) + else: + sdfg(X_=X, W_=W, Z_=dace_Z) + + print(torch_Z - dace_Z) + assert np.allclose(torch_Z, dace_Z) + + +if __name__ == "__main__": + # Test with different parameter combinations + params = [(1, (3, 3), 8, True), (8, (3, 3), 3, False), (8, (5, 5), 3, True), (8, (4, 4), 3, False)] + for num_in_channels, kernel_size, num_filters, bias in params: + test_conv_simple(num_in_channels, kernel_size, num_filters, bias) diff --git a/tests/onnx/pure_expansions/test_expansion_utils.py b/tests/onnx/pure_expansions/test_expansion_utils.py new file mode 100644 index 0000000000..d71c54dc5a --- /dev/null +++ b/tests/onnx/pure_expansions/test_expansion_utils.py @@ -0,0 +1,42 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +import numpy as np + +import dace +import dace.libraries.onnx as donnx + + +@pytest.mark.onnx +def test_sqrt_expansion(): + # sqrt expansion makes use of the program_for_node function + sdfg = dace.SDFG("test_sqrt_expansion") + + sdfg.add_array("inp", [2, 4], dace.float32) + sdfg.add_array("__return", [2, 4], dace.float32) + + state = sdfg.add_state() + access_in = state.add_access("inp") + access_result = state.add_access("__return") + + op_node = donnx.ONNXSqrt("sqrt") + + state.add_node(op_node) + state.add_edge(access_in, None, op_node, "X", sdfg.make_array_memlet("inp")) + + state.add_edge(op_node, "Y", access_result, None, sdfg.make_array_memlet("__return")) + + X = np.random.rand(2, 4).astype(np.float32) + + sdfg.expand_library_nodes() + # check that the expansion worked. The default ORT expansion wouldn't produce a map + assert any(isinstance(n, dace.nodes.MapEntry) for n, _ in sdfg.all_nodes_recursive()) + + result = sdfg(inp=X) + + assert np.allclose(np.sqrt(X), result) + + +if __name__ == "__main__": + test_sqrt_expansion() diff --git a/tests/onnx/pure_expansions/test_expansions.py b/tests/onnx/pure_expansions/test_expansions.py new file mode 100644 index 0000000000..1c70738c92 --- /dev/null +++ b/tests/onnx/pure_expansions/test_expansions.py @@ -0,0 +1,561 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import copy +import numpy as np + +import dace +from dace import transformation, data as dt +from dace.libraries import blas +import dace.library + +import dace.libraries.onnx as donnx +from dace.transformation.onnx import expand_onnx_nodes + + +def assert_allclose(a, b, rtol=1e-5, atol=1e-8): + np.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + + +@pytest.mark.onnx +@pytest.mark.parametrize("a_shape, b_shape", [([2, 4], [4, 3])]) +def test_matmul_expansion(a_shape, b_shape): + blas.Gemm.default_implementation = "pure" + sdfg = dace.SDFG("test_matmul_expansion") + + X = np.random.rand(*a_shape).astype(np.float32) + Z = np.random.rand(*b_shape).astype(np.float32) + expected_result = X @ Z + sdfg.add_array("X", a_shape, dace.float32) + sdfg.add_array("Z", b_shape, dace.float32) + sdfg.add_array("__return", expected_result.shape, dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_Z = state.add_access("Z") + access_result = state.add_access("__return") + + op_node = donnx.ONNXMatMul("Matmul") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "A", sdfg.make_array_memlet("X")) + state.add_edge(access_Z, None, op_node, "B", sdfg.make_array_memlet("Z")) + + state.add_edge(op_node, "Y", access_result, None, sdfg.make_array_memlet("__return")) + + with dace.library.change_default(blas, "pure"): + sdfg.expand_library_nodes() + # check that the expansion worked. The default ORT expansion contains a Tasklet with suffix _onnx_code + assert not any( + isinstance(n, dace.nodes.Tasklet) and n.name.endswith("_onnx_code") for n, _ in sdfg.all_nodes_recursive()) + + result = sdfg(X=X, Z=Z) + + assert_allclose(expected_result, result) + + +@pytest.mark.onnx +def test_cast_int_to_float(): + sdfg = dace.SDFG("test_cast_int_to_float") + + sdfg.add_array("X", [2, 4], dace.int32) + sdfg.add_array("__return", [2, 4], dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_result = state.add_access("__return") + + op_node = donnx.ONNXCast("Cast") + op_node.to = donnx.converters.typeclass_to_onnx_tensor_type_int(dace.float32) + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "input", sdfg.make_array_memlet("X")) + + state.add_edge(op_node, "output", access_result, None, sdfg.make_array_memlet("__return")) + + X = np.random.randint(0, 10, size=(2, 4), dtype=np.int32) + + sdfg.expand_library_nodes() + # check that the expansion worked. The default ORT expansion contains a Tasklet with suffix _onnx_code + assert not any( + isinstance(n, dace.nodes.Tasklet) and n.name.endswith("_onnx_code") for n, _ in sdfg.all_nodes_recursive()) + + result = sdfg(X=X) + + assert_allclose(X.astype(np.float32), result) + + +@pytest.mark.onnx +def test_cast_float_to_int(): + sdfg = dace.SDFG("test_cast_float_to_int") + + sdfg.add_array("X", [2, 4], dace.float32) + sdfg.add_array("__return", [2, 4], dace.int32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_result = state.add_access("__return") + + op_node = donnx.ONNXCast("Cast") + op_node.to = donnx.converters.typeclass_to_onnx_tensor_type_int(dace.int32) + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "input", sdfg.make_array_memlet("X")) + + state.add_edge(op_node, "output", access_result, None, sdfg.make_array_memlet("__return")) + + X = np.random.normal(scale=10, size=(2, 4)).astype(np.float32) + + sdfg.expand_library_nodes() + # check that the expansion worked. The default ORT expansion contains a Tasklet with suffix _onnx_code + assert not any( + isinstance(n, dace.nodes.Tasklet) and n.name.endswith("_onnx_code") for n, _ in sdfg.all_nodes_recursive()) + + result = sdfg(X=X) + + assert_allclose(X.astype(np.int32), result) + + +@pytest.mark.onnx +def test_cast_float_to_long(): + sdfg = dace.SDFG("test_cast_float_to_long") + + sdfg.add_array("X", [2, 4], dace.float32) + sdfg.add_array("__return", [2, 4], dace.int64) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_result = state.add_access("__return") + + op_node = donnx.ONNXCast("Cast") + op_node.to = donnx.converters.typeclass_to_onnx_tensor_type_int(dace.int64) + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "input", sdfg.make_array_memlet("X")) + + state.add_edge(op_node, "output", access_result, None, sdfg.make_array_memlet("__return")) + + X = np.random.normal(scale=10, size=(2, 4)).astype(np.float32) + + sdfg.expand_library_nodes() + # check that the expansion worked. The default ORT expansion contains a Tasklet with suffix _onnx_code + assert not any( + isinstance(n, dace.nodes.Tasklet) and n.name.endswith("_onnx_code") for n, _ in sdfg.all_nodes_recursive()) + + result = sdfg(X=X) + + assert_allclose(X.astype(np.int64), result) + + +@pytest.mark.onnx +#+yapf: disable +@pytest.mark.parametrize("reduce_type, keepdims, axes", + [('Sum', True, [0]), + ('Sum', False, [-1]), + ('Sum', True, [0, -1]), + ('Max', False, [0, -1]), + ('Max', True, [0]), + ('Max', True, [-1]), + ('Mean', True, [-1]), + ('Mean', True, [0, -1]), + ('Mean', False, [0])]) +#+yapf: enable +def test_reduce(keepdims, reduce_type, axes): + + X = np.random.normal(scale=10, size=(2, 4, 10)).astype(np.float32) + + sdfg = dace.SDFG("test_reduce") + + sdfg.add_array("X", [2, 4, 10], dace.float32) + + numpy_func = getattr(np, reduce_type.lower()) + numpy_result = numpy_func(X.copy(), axis=tuple(axes), keepdims=keepdims) + + resulting_shape = numpy_result.shape + + sdfg.add_array("__return", resulting_shape, dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_result = state.add_access("__return") + + op_node = getattr(donnx, "ONNXReduce" + reduce_type)("reduce") + op_node.axes = axes + op_node.keepdims = 1 if keepdims else 0 + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "data", sdfg.make_array_memlet("X")) + + state.add_edge(op_node, "reduced", access_result, None, sdfg.make_array_memlet("__return")) + + sdfg.expand_library_nodes() + # check that the expansion worked. The default ORT expansion contains a Tasklet with suffix _onnx_code + assert not any( + isinstance(n, dace.nodes.Tasklet) and n.name.endswith("_onnx_code") for n, _ in sdfg.all_nodes_recursive()) + result = sdfg(X=X) + + assert_allclose(numpy_result, result, rtol=1e-5, atol=1e-5) + + +@pytest.mark.onnx +def test_reduce_scalar(): + X = np.random.normal(scale=10, size=(2, 4, 10)).astype(np.float32) + + sdfg = dace.SDFG("test_reduce_scalar") + + numpy_result = np.mean(X) + + sdfg.add_array("X", [2, 4, 10], dace.float32) + sdfg.add_scalar("Y", dace.float32, transient=True) + sdfg.add_array("__return", [1], dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_Y = state.add_access("Y") + access_result = state.add_access("__return") + + op_node = donnx.ONNXReduceMean("mean") + op_node.keepdims = 0 + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "data", sdfg.make_array_memlet("X")) + + state.add_edge(op_node, "reduced", access_Y, None, sdfg.make_array_memlet("Y")) + + state.add_edge(access_Y, None, access_result, None, sdfg.make_array_memlet("__return")) + + sdfg.expand_library_nodes() + + result = sdfg(X=X) + + assert_allclose(numpy_result, result, rtol=1e-5, atol=1e-5) + + +@pytest.mark.onnx +@pytest.mark.parametrize("new_shape", [[8, 10], [80], [2, 40]]) +def test_reshape(new_shape): + X = np.random.normal(scale=10, size=(2, 4, 10)).astype(np.float32) + + sdfg = dace.SDFG("test_reshape") + + numpy_result = X.reshape(*new_shape) + + sdfg.add_array("X", [2, 4, 10], dace.float32) + sdfg.add_array("shape", [len(new_shape)], dace.int64) + sdfg.add_array("__return", new_shape, dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_shape = state.add_access("shape") + access_result = state.add_access("__return") + + op_node = donnx.ONNXReshape("reshape") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "data", sdfg.make_array_memlet("X")) + state.add_edge(access_shape, None, op_node, "shape", sdfg.make_array_memlet("shape")) + + state.add_edge(op_node, "reshaped", access_result, None, sdfg.make_array_memlet("__return")) + + sdfg.expand_library_nodes() + + # we don't need shape anymore + del sdfg.arrays["shape"] + + result = sdfg(X=X) + + assert_allclose(numpy_result, result) + + +@pytest.mark.onnx +def test_flatten(): + + new_shape = [2, 40] + X = np.random.normal(scale=10, size=(2, 4, 10)).astype(np.float32) + + sdfg = dace.SDFG("test_flatten") + + numpy_result = X.reshape(*new_shape) + + sdfg.add_array("X", [2, 4, 10], dace.float32) + sdfg.add_array("__return", new_shape, dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_result = state.add_access("__return") + + op_node = donnx.ONNXFlatten("flatten") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "input", sdfg.make_array_memlet("X")) + + state.add_edge(op_node, "output", access_result, None, sdfg.make_array_memlet("__return")) + + sdfg.expand_library_nodes() + + result = sdfg(X=X) + + assert_allclose(numpy_result, result) + + +@pytest.mark.onnx +def test_reciprocal(): + X = np.random.normal(scale=10, size=(2, 4, 10)).astype(np.float32) + + numpy_result = 1 / X + sdfg = dace.SDFG("test_reciprocal") + + sdfg.add_array("X", [2, 4, 10], dace.float32) + sdfg.add_array("__return", numpy_result.shape, dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_result = state.add_access("__return") + + op_node = donnx.ONNXReciprocal("reciprocal") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "X", sdfg.make_array_memlet("X")) + + state.add_edge(op_node, "Y", access_result, None, sdfg.make_array_memlet("__return")) + + sdfg.expand_library_nodes() + + # check that the expansion worked. The default ORT expansion contains a Tasklet with suffix _onnx_code + assert not any( + isinstance(n, dace.nodes.Tasklet) and n.name.endswith("_onnx_code") for n, _ in sdfg.all_nodes_recursive()) + + result = sdfg(X=X) + + assert_allclose(numpy_result, result) + + +@pytest.mark.onnx +def test_einsum(): + + @dace.program + def test_einsum(A: dace.float64[5, 4, 3], B: dace.float64[3, 2]): + Y = dace.define_local([5, 4, 2], dace.float64) + donnx.ONNXEinsum(Inputs__0=A, Inputs__1=B, Output=Y, equation="bij, jk -> bik") + return Y + + sdfg = test_einsum.to_sdfg() + expand_onnx_nodes(sdfg) + assert any(isinstance(n, blas.Gemm) for n, _ in sdfg.all_nodes_recursive()) + + A = np.random.rand(5, 4, 3).astype(np.float64) + B = np.random.rand(3, 2).astype(np.float64) + result = test_einsum(A.copy(), B.copy()) + assert_allclose(result, np.einsum("bij ,jk -> bik", A, B)) + + +@pytest.mark.onnx +def test_reshape_add(): + + @dace.program + def add_reshape(inp: dace.float64[9], bias: dace.float64[3], target_shape: dace.int64[2]): + reshaped = dace.define_local([3, 3], dace.float64) + donnx.ONNXReshape(data=inp, shape=target_shape, reshaped=reshaped) + + return reshaped + bias + + sdfg: dace.SDFG = add_reshape.to_sdfg(simplify=False) + + sdfg.apply_transformations_repeated([transformation.interstate.StateFusion]) + + inp = np.arange(9).astype(np.float64) + bias = np.arange(3).astype(np.float64) + result = sdfg(inp=inp.copy(), bias=bias.copy(), target_shape=np.array([3, 3]).astype(np.int64)) + + assert_allclose(result, inp.reshape(3, 3) + bias) + + +@pytest.mark.onnx +@pytest.mark.parametrize("input_desc", [dace.float32[2, 3], dace.float32[1], dace.float32]) +def test_sum_arrays(input_desc): + + if isinstance(input_desc, dt.Array): + shape = input_desc.shape + else: + shape = [1] + + def prog(inp0: copy.deepcopy(input_desc), inp1: copy.deepcopy(input_desc), inp2: copy.deepcopy(input_desc)): + result = dace.define_local(shape, dace.float32) + donnx.ONNXSum(data_0__0=inp0, data_0__1=inp1, data_0__2=inp2, sum=result) + return result + + prog.__name__ = "test_sum_arrays" + prog = dace.program(prog) + + inputs = [np.random.randn(*shape).astype(np.float32) for _ in range(3)] + if not isinstance(input_desc, dt.Array): + inputs = [i[0] for i in inputs] + np_result = (inputs[0] + inputs[1]) + inputs[2] + result = prog(*inputs) + + assert_allclose(result, np_result) + + +@pytest.mark.onnx +def test_shape(): + + @dace.program + def shape(inp: dace.float64[9, 5, 3]): + shp = dace.define_local([3], dace.int64) + donnx.ONNXShape(data=inp, shape=shp) + return shp + + sdfg: dace.SDFG = shape.to_sdfg() + sdfg.expand_library_nodes() + sdfg.simplify() + + inp = np.random.rand(9, 5, 3).astype(np.float64) + result = sdfg(inp=inp.copy()) + assert_allclose(result, [9, 5, 3]), result + + +@pytest.mark.onnx +def test_gather_onnx_1(): + # gather in ONNX operators.md + @dace.program + def gather(inp: dace.float64[3, 2], indices: dace.int64[2, 2]): + output = dace.define_local([2, 2, 2], dace.float64) + donnx.ONNXGather(data=inp, output=output, indices=indices, axis=0) + return output + + sdfg: dace.SDFG = gather.to_sdfg() + sdfg.expand_library_nodes() + sdfg.simplify() + + data = np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]]) + indices = np.array([[0, 1], [1, 2]]) + result = sdfg(inp=data.copy(), indices=indices.copy()) + assert_allclose(result, data[indices]) + + +@pytest.mark.onnx +def test_gather_bert(): + # gather found at start of bert model + @dace.program + def gather(embs: dace.float64[64, 8], input_ids: dace.int64[8, 16]): + output = dace.define_local([8, 16, 8], dace.float64) + donnx.ONNXGather(data=embs, output=output, indices=input_ids, axis=0) + return output + + sdfg: dace.SDFG = gather.to_sdfg() + sdfg.expand_library_nodes() + sdfg.simplify() + + embs = np.random.rand(64, 8).astype(np.float64) + input_ids = np.random.randint(low=0, high=64, size=(8, 16)).astype(np.int64) + result = sdfg(embs=embs.copy(), input_ids=input_ids.copy()) + assert_allclose(result, embs[input_ids]) + + +@pytest.mark.onnx +def test_gather_scalar(): + # gather test 2 in BERT model (third last op) + @dace.program + def gather(inp: dace.float64[1, 8, 32], indices: dace.int64): + output = dace.define_local([1, 32], dace.float64) + donnx.ONNXGather(data=inp, output=output, indices=indices, axis=1) + return output + + sdfg: dace.SDFG = gather.to_sdfg() + sdfg.expand_library_nodes() + sdfg.simplify() + + data = np.random.rand(1, 8, 32) + indices = np.int64(5) + result = sdfg(inp=data.copy(), indices=indices.copy()) + np_result = np.take(data, indices, axis=1) + + assert_allclose(result, np_result) + + +@pytest.mark.onnx +def test_gather_onnx_2(): + # gather test 2 in ONNX operators.md + @dace.program + def gather(inp: dace.float64[3, 3], indices: dace.int64[1, 2]): + output = dace.define_local([3, 1, 2], dace.float64) + donnx.ONNXGather(data=inp, output=output, indices=indices, axis=1) + return output + + sdfg: dace.SDFG = gather.to_sdfg() + sdfg.expand_library_nodes() + sdfg.simplify() + + data = np.array([ + [1.0, 1.2, 1.9], + [2.3, 3.4, 3.9], + [4.5, 5.7, 5.9], + ]) + indices = np.array([[0, 2]]) + result = sdfg(inp=data.copy(), indices=indices.copy()) + np_result = np.take(data, indices, axis=1) + + assert_allclose(result, np_result) + + +@pytest.mark.onnx +def test_unsqueeze(): + + @dace.program + def unsqueeze(inp: dace.float64[3, 3]): + output = dace.define_local([3, 1, 3, 1], dace.float64) + axes = dace.define_local([2], dace.int64) + axes[0] = 1 + axes[1] = 3 + donnx.ONNXUnsqueeze(data=inp, expanded=output, axes=axes) + return output + + sdfg: dace.SDFG = unsqueeze.to_sdfg() + + data = np.array([ + [1.0, 1.2, 1.9], + [2.3, 3.4, 3.9], + [4.5, 5.7, 5.9], + ]) + + np_result = np.reshape(data, [3, 1, 3, 1]) + + result = sdfg(inp=data.copy()) + assert result.shape == (3, 1, 3, 1) + assert_allclose(result, np_result) + + +if __name__ == "__main__": + test_matmul_expansion(a_shape=[2, 4], b_shape=[4, 3]) + test_cast_int_to_float() + test_cast_float_to_int() + test_cast_float_to_long() + + reduce_params = [(True, 'Sum', [0]), (False, 'Sum', [-1]), (True, 'Sum', [0, -1]), (False, 'Max', [0, -1]), + (True, 'Max', [0]), (True, 'Max', [-1]), (True, 'Mean', [-1]), (True, 'Mean', [0, -1]), + (False, 'Mean', [0])] + for keepdims, reduce_type, axes in reduce_params: + test_reduce(keepdims=keepdims, reduce_type=reduce_type, axes=axes) + + test_reduce_scalar() + + for new_shape in [[8, 10], [80], [2, 40]]: + test_reshape(new_shape=new_shape) + + test_flatten() + test_reciprocal() + test_einsum() + test_reshape_add() + + for input_desc in [dace.float32[2, 3], dace.float32[1], dace.float32]: + test_sum_arrays(input_desc=input_desc) + + test_shape() + test_gather_onnx_1() + test_gather_bert() + test_gather_scalar() + test_gather_onnx_2() + test_unsqueeze() diff --git a/tests/onnx/test_bert_subgraphs.py b/tests/onnx/test_bert_subgraphs.py new file mode 100644 index 0000000000..8d3d90a13f --- /dev/null +++ b/tests/onnx/test_bert_subgraphs.py @@ -0,0 +1,106 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Regression tests for BERT subgraphs +""" +import numpy as np +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +from onnx import helper, numpy_helper, TensorProto +import torch +from dace.ml import ONNXModel + + +def make_slice_model(): + """Create a simple ONNX model with a Slice operation.""" + data_input = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1]) + + starts = numpy_helper.from_array(np.array([0], dtype=np.int64), name='starts') + ends = numpy_helper.from_array(np.array([1], dtype=np.int64), name='ends') + axes = numpy_helper.from_array(np.array([0], dtype=np.int64), name='axes') + + slice_node = helper.make_node('Slice', inputs=['data', 'starts', 'ends', 'axes'], outputs=['output']) + + graph = helper.make_graph([slice_node], + 'slice_graph', + inputs=[data_input], + outputs=[output], + initializer=[starts, ends, axes]) + + model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 12)]) + model.ir_version = 7 + return model + + +def make_reshape_model(): + """Create an ONNX model simulating BERT embedding reshape operations.""" + output = helper.make_tensor_value_info('bert/embeddings/Reshape_4:0', TensorProto.FLOAT, [1, 256, 768]) + + position_embeddings = numpy_helper.from_array(np.random.randn(512, 768).astype(np.float32), + name='bert/embeddings/position_embeddings:0') + slice_starts = numpy_helper.from_array(np.array([0, 0], dtype=np.int64), name='const_slice__40') + slice_ends = numpy_helper.from_array(np.array([256, 2147483647], dtype=np.int64), name='const_slice__41') + reshape_shape = numpy_helper.from_array(np.array([1, 256, 768], dtype=np.int32), + name='bert/embeddings/Reshape_4/shape:0') + + slice_node = helper.make_node( + 'Slice', + inputs=['bert/embeddings/position_embeddings:0', 'const_slice__40', 'const_slice__41'], + outputs=['bert/embeddings/Slice:0']) + + cast_node = helper.make_node('Cast', + inputs=['bert/embeddings/Reshape_4/shape:0'], + outputs=['bert/embeddings/Reshape_4__42:0'], + to=TensorProto.INT64) + + reshape_node = helper.make_node('Reshape', + inputs=['bert/embeddings/Slice:0', 'bert/embeddings/Reshape_4__42:0'], + outputs=['bert/embeddings/Reshape_4:0']) + + graph = helper.make_graph([slice_node, cast_node, reshape_node], + 'reshape_graph', + inputs=[], + outputs=[output], + initializer=[position_embeddings, slice_starts, slice_ends, reshape_shape]) + + model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 12)]) + model.ir_version = 7 + return model + + +@pytest.mark.onnx +def test_slice(): + model = make_slice_model() + dace_model = ONNXModel("test_slice", model, onnx_simplify=False) + + data = torch.ones(2) + + out = dace_model(data=data) + assert out.shape == (1, ), f"Expected output shape (1,), got {out.shape}" + assert out[0] == 1.0, f"Expected output value 1.0, got {out[0]}" + + +@pytest.mark.onnx +def test_reshape(): + model = make_reshape_model() + dace_model = ONNXModel("test_reshape", model) + dace_model() + + +@pytest.mark.onnx +def test_save_transients(): + model = make_reshape_model() + transients = {} + dace_model = ONNXModel("test_save_transients", model, save_transients=transients) + dace_model() + assert torch.allclose(transients["bertSLASHembeddingsSLASHReshape_4COLON0"].cpu(), + dace_model.weights["bert/embeddings/Reshape_4:0"]) + + +if __name__ == "__main__": + test_slice() + test_reshape() + test_save_transients() diff --git a/tests/onnx/test_input_outputs.py b/tests/onnx/test_input_outputs.py new file mode 100644 index 0000000000..bc979a7ac8 --- /dev/null +++ b/tests/onnx/test_input_outputs.py @@ -0,0 +1,229 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Testing input and output combinations for onnx Ops + +| Output / Input | Array CPU | +|----------------+-----------| +| Scalar CPU | Shape | +| Array CPU | Add | +""" +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") + +import numpy as np +import pytest + +import dace +import dace.libraries.onnx as donnx + + +@pytest.mark.onnx +@pytest.mark.parametrize("simplify", [True, False]) +def test_squeeze(simplify: bool): + + sdfg = dace.SDFG("test_squeeze") + + sdfg.add_array("X_arr", [1], dace.float32) + sdfg.add_array("axes", [1], dace.int64, transient=True) + sdfg.add_scalar("scalar", dace.float32, transient=True) + sdfg.add_array("__return", [1], dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X_arr") + access_axes = state.add_access("axes") + access_scalar = state.add_access("scalar") + + access_result = state.add_access("__return") + + # Tasklet to initialize axes + init_axes = state.add_tasklet("init_axes", + inputs={}, + outputs={"__axes": dace.pointer(dace.int64)}, + code="__axes[0] = 0;", + language=dace.Language.CPP) + + state.add_edge(init_axes, "__axes", access_axes, None, sdfg.make_array_memlet("axes")) + + op_node = donnx.ONNXSqueeze("Squeeze") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "data", sdfg.make_array_memlet("X_arr")) + + state.add_edge(op_node, "squeezed", access_scalar, None, sdfg.make_array_memlet("scalar")) + + unsqueeze_op = donnx.ONNXUnsqueeze("Unsqueeze") + state.add_node(unsqueeze_op) + state.add_edge(access_scalar, None, unsqueeze_op, "data", sdfg.make_array_memlet("scalar")) + state.add_edge(access_axes, None, unsqueeze_op, "axes", sdfg.make_array_memlet("axes")) + state.add_edge(unsqueeze_op, "expanded", access_result, None, sdfg.make_array_memlet("__return")) + + X = np.random.rand(1).astype(np.float32) + + if simplify: + sdfg.expand_library_nodes() + sdfg.simplify() + + sdfg.expand_library_nodes() + result = sdfg(X_arr=X) + + assert result.shape == (1, ), f"Expected shape (1,), got {result.shape}" + assert result[0] == X, f"Expected value {X}, got {result[0]}" + + +@pytest.mark.onnx +@pytest.mark.parametrize("simplify", [True, False]) +def test_shape(simplify: bool): + sdfg = dace.SDFG("test_shape") + + sdfg.add_array("X_arr", [2, 4], dace.float32) + sdfg.add_array("__return", [2], dace.int64) + + state = sdfg.add_state() + access_X = state.add_access("X_arr") + + access_result = state.add_access("__return") + + op_node = donnx.ONNXShape("Shape") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "data", sdfg.make_array_memlet("X_arr")) + + state.add_edge(op_node, "shape", access_result, None, sdfg.make_array_memlet("__return")) + + X = np.random.rand(2, 4).astype(np.float32) + + if simplify: + sdfg.expand_library_nodes() + sdfg.simplify() + + result = sdfg(X_arr=X) + + assert np.all(result == (2, 4)) + + +@pytest.mark.onnx +@pytest.mark.parametrize("simplify", [True, False]) +def test_unsqueeze(simplify: bool): + sdfg = dace.SDFG("test_unsqueeze") + + sdfg.add_scalar("X_arr", dace.float32) + sdfg.add_array("axes", [1], dace.int64, transient=True) + sdfg.add_array("__return", [1], dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X_arr") + access_axes = state.add_access("axes") + + access_result = state.add_access("__return") + + # Tasklet to initialize axes + init_axes = state.add_tasklet("init_axes", + inputs={}, + outputs={"__axes": dace.pointer(dace.int64)}, + code="__axes[0] = 0;", + language=dace.Language.CPP) + + state.add_edge(init_axes, "__axes", access_axes, None, sdfg.make_array_memlet("axes")) + + op_node = donnx.ONNXUnsqueeze("Unsqueeze") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "data", sdfg.make_array_memlet("X_arr")) + state.add_edge(access_axes, None, op_node, "axes", sdfg.make_array_memlet("axes")) + + state.add_edge(op_node, "expanded", access_result, None, sdfg.make_array_memlet("__return")) + + X = np.float32(np.random.rand()) + + if simplify: + sdfg.expand_library_nodes() + sdfg.simplify() + + result = sdfg(X_arr=X) + + assert result.shape == (1, ), f"Expected shape (1,), got {result.shape}" + assert X == result[0], f"Expected value {X}, got {result[0]}" + + +@pytest.mark.onnx +@pytest.mark.parametrize("scalars", [True, False]) +@pytest.mark.parametrize("simplify", [True, False]) +def test_add(scalars: bool, simplify: bool): + sdfg = dace.SDFG("test_add") + + if scalars: + sdfg.add_scalar("X_arr", dace.float32) + sdfg.add_scalar("W_arr", dace.float32) + sdfg.add_scalar("Z_arr", dace.float32, transient=True) + sdfg.add_array("axes", [1], dace.int64, transient=True) + sdfg.add_array("__return", [1], dace.float32) + else: + sdfg.add_array("X_arr", [2, 2], dace.float32) + sdfg.add_array("W_arr", [2, 2], dace.float32) + sdfg.add_array("__return", [2, 2], dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X_arr") + access_W = state.add_access("W_arr") + + if scalars: + access_Z = state.add_access("Z_arr") + access_axes = state.add_access("axes") + + access_result = state.add_access("__return") + + op_node = donnx.ONNXAdd("Add") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "A", sdfg.make_array_memlet("X_arr")) + state.add_edge(access_W, None, op_node, "B", sdfg.make_array_memlet("W_arr")) + + if scalars: + state.add_edge(op_node, "C", access_Z, None, sdfg.make_array_memlet("Z_arr")) + else: + state.add_edge(op_node, "C", access_result, None, sdfg.make_array_memlet("__return")) + + if scalars: + # Tasklet to initialize axes + init_axes = state.add_tasklet("init_axes", + inputs={}, + outputs={"__axes": dace.pointer(dace.int64)}, + code="__axes[0] = 0;", + language=dace.Language.CPP) + + state.add_edge(init_axes, "__axes", access_axes, None, sdfg.make_array_memlet("axes")) + + unsqueeze_op = donnx.ONNXUnsqueeze("Unsqueeze") + state.add_node(unsqueeze_op) + state.add_edge(access_Z, None, unsqueeze_op, "data", sdfg.make_array_memlet("Z_arr")) + state.add_edge(access_axes, None, unsqueeze_op, "axes", sdfg.make_array_memlet("axes")) + state.add_edge(unsqueeze_op, "expanded", access_result, None, sdfg.make_array_memlet("__return")) + + shapes = [] if scalars else [2, 2] + X = np.random.rand(*shapes) + W = np.random.rand(*shapes) + if not scalars: + X = X.astype(np.float32) + W = W.astype(np.float32) + + if simplify: + sdfg.expand_library_nodes() + sdfg.simplify() + + result = sdfg(X_arr=X, W_arr=W) + + numpy_result = X + W + + assert np.allclose(result, numpy_result) + + +if __name__ == "__main__": + for simplify in [True, False]: + test_squeeze(simplify=simplify) + test_shape(simplify=simplify) + test_unsqueeze(simplify=simplify) + + for scalars in [True, False]: + for simplify in [True, False]: + test_add(scalars=scalars, simplify=simplify) diff --git a/tests/onnx/test_models/test_bert.py b/tests/onnx/test_models/test_bert.py new file mode 100644 index 0000000000..1c447ac0e3 --- /dev/null +++ b/tests/onnx/test_models/test_bert.py @@ -0,0 +1,92 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Test a full model including indexing and input preparation. The model also includes lots of symbolic dimensions. +""" + +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +pytest.importorskip("onnxsim", reason="ONNX Simplifier not installed. Please install with: pip install dace[ml]") +pytest.importorskip("transformers", + reason="transformers not installed. Please install with: pip install dace[ml-testing]") +import os + +import onnx +import onnxsim +import pathlib +import urllib + +import torch +from transformers import BertTokenizer, BertModel + +import dace +import dace.libraries.onnx as donnx +from tests.utils import torch_tensors_close + + +def get_data_file(url, directory_name=None) -> str: + """ Get a data file from ``url``, cache it locally and return the local file path to it. + + :param url: the url to download from. + :param directory_name: an optional relative directory path where the file will be downloaded to. + :return: the path of the downloaded file. + """ + + data_directory = (pathlib.Path(dace.__file__).parent.parent / 'tests' / 'data') + + if directory_name is not None: + data_directory /= directory_name + + data_directory.mkdir(exist_ok=True, parents=True) + + file_name = os.path.basename(urllib.parse.urlparse(url).path) + file_path = str(data_directory / file_name) + + if not os.path.exists(file_path): + urllib.request.urlretrieve(url, file_path) + return file_path + + +@pytest.mark.xdist_group("large_ML_models") +@pytest.mark.onnx +def test_bert_full(): + bert_tiny_root = 'http://spclstorage.inf.ethz.ch/~rauscho/bert-tiny' + get_data_file(bert_tiny_root + "/config.json", directory_name='bert-tiny') + vocab = get_data_file(bert_tiny_root + "/vocab.txt", directory_name='bert-tiny') + bert_path = get_data_file(bert_tiny_root + "/bert-tiny.onnx", directory_name='bert-tiny') + get_data_file(bert_tiny_root + "/pytorch_model.bin", directory_name='bert-tiny') + model_dir = os.path.dirname(vocab) + + tokenizer = BertTokenizer.from_pretrained(vocab) + pt_model = BertModel.from_pretrained(model_dir) + + text = "[CLS] how are you today [SEP] dude [SEP]" + tokenized_text = tokenizer.tokenize(text) + indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text) + segment_ids = [0] * 6 + [1] * 2 + + tokens_tensor = torch.tensor([indexed_tokens]) + segments_tensors = torch.tensor([segment_ids]) + attention_mask = torch.ones(1, 8, dtype=torch.int64) + + model = onnx.load(bert_path) + # infer shapes + model, _ = onnxsim.simplify(model, + skip_fuse_bn=True, + input_shapes=dict(input_ids=tokens_tensor.shape, + token_type_ids=segments_tensors.shape, + attention_mask=attention_mask.shape)) + + dace_model = donnx.ONNXModel("test_bert_full", model, auto_merge=True) + + dace_output = dace_model(input_ids=tokens_tensor, token_type_ids=segments_tensors, attention_mask=attention_mask) + + output = pt_model(tokens_tensor, token_type_ids=segments_tensors, attention_mask=attention_mask) + + torch_tensors_close("output_0", output[0], dace_output[0]) + torch_tensors_close("output_1", output[1], dace_output[1]) + + +if __name__ == "__main__": + test_bert_full() diff --git a/tests/onnx/test_name_shadowing.py b/tests/onnx/test_name_shadowing.py new file mode 100644 index 0000000000..aed2be894e --- /dev/null +++ b/tests/onnx/test_name_shadowing.py @@ -0,0 +1,37 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") + +import dace + +import dace.libraries.onnx as donnx + + +@pytest.mark.onnx +def test_shadowing(): + new_shape = [8, 10] + sdfg = dace.SDFG("test_shadowing") + + sdfg.add_array("X", [2, 4, 10], dace.float32) + sdfg.add_array("shape", [len(new_shape)], dace.int64) + sdfg.add_array("__return", new_shape, dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_shape = state.add_access("shape") + access_result = state.add_access("__return") + + op_node = donnx.ONNXReshape("reshape") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "data", sdfg.make_array_memlet("X")) + state.add_edge(access_shape, None, op_node, "shape", sdfg.make_array_memlet("shape")) + + state.add_edge(op_node, "reshaped", access_result, None, sdfg.make_array_memlet("__return")) + + sdfg.compile() + + +if __name__ == "__main__": + test_shadowing() diff --git a/tests/onnx/test_onnx_return_scalars.py b/tests/onnx/test_onnx_return_scalars.py new file mode 100644 index 0000000000..bc0f9cfd19 --- /dev/null +++ b/tests/onnx/test_onnx_return_scalars.py @@ -0,0 +1,61 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import torch +import onnx + +from dace.libraries import onnx as donnx + + +@pytest.mark.onnx +def test_onnx_return_scalars(): + # Dace programs can't return scalars. + # this test checks that we correctly copy out the scalars using a size [1] array + + # we will have a single operator that computes the sum of a 1D tensor + X = onnx.helper.make_tensor_value_info('X', onnx.TensorProto.FLOAT, [5]) + + # Create axes constant with value 0 + axes_constant = onnx.helper.make_tensor( + name='axes', + data_type=onnx.TensorProto.INT64, + dims=[1], # Single element array + vals=[0] # Reduce along axis 0 + ) + + # return value is a scalar + Y = onnx.helper.make_tensor_value_info('Y', onnx.TensorProto.FLOAT, []) + + node_def = onnx.helper.make_node( + 'ReduceSum', + ['X', "axes"], + ['Y'], + keepdims=0, + ) + + graph_def = onnx.helper.make_graph( + [node_def], + 'test-scalar-return', + [X], # inputs + [Y], # outputs + [axes_constant] # initializers (constants) + ) + + model_def = onnx.helper.make_model(graph_def, ir_version=10, opset_imports=[onnx.helper.make_opsetid('', 13)]) + + onnx.checker.check_model(model_def) + + # now we can test the backend + dace_model = donnx.ONNXModel("test_onnx_return_scalars", model_def) + inp = torch.arange(5).type(torch.float32) + + result = dace_model(inp) + assert result.shape == (), f"Expected scalar shape (), got {result.shape}" + assert result[()] == 1 + 2 + 3 + 4, f"Expected sum 10, got {result[()]}" + + +if __name__ == "__main__": + test_onnx_return_scalars() diff --git a/tests/onnx/test_python_frontend.py b/tests/onnx/test_python_frontend.py new file mode 100644 index 0000000000..495e5cddc3 --- /dev/null +++ b/tests/onnx/test_python_frontend.py @@ -0,0 +1,31 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Test the python frontend of onnx nodes +""" + +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +import numpy as np + +import dace +import dace.libraries.onnx as donnx + + +@pytest.mark.onnx +def test_matmul(): + + @dace + def matmul(inp1: dace.float32[5, 5], inp2: dace.float32[5, 3]): + out = dace.define_local([5, 3], dace.float32) + donnx.ONNXMatMul(A=inp1, B=inp2, Y=out) + return out + + A = np.random.normal(size=(5, 5)).astype(np.float32) + B = np.random.normal(size=(5, 3)).astype(np.float32) + result = matmul(inp1=A.copy(), inp2=B.copy()) + np.testing.assert_allclose(A @ B, result, atol=1e-5, rtol=1e-5, err_msg="MatMul output mismatch") + + +if __name__ == "__main__": + test_matmul() diff --git a/tests/onnx/test_shared_input_output.py b/tests/onnx/test_shared_input_output.py new file mode 100644 index 0000000000..1e53cdd326 --- /dev/null +++ b/tests/onnx/test_shared_input_output.py @@ -0,0 +1,111 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Batch Norm is the only op that has a shared name between inputs and outputs. Test that prepending "in_" and "out_" works +""" + +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +import dace +import dace.libraries.onnx as donnx +from dace.ml import DaceModule + +from tests.utils import torch_tensors_close + + +@pytest.mark.onnx +@pytest.mark.parametrize("training_mode", [True, False]) +def test_bn_standalone(training_mode: bool): + + if training_mode: + + @dace.program + def test_bn_standalone(X: dace.float32[8, 3, 32, + 32], scale: dace.float32[3], B: dace.float32[3], mean: dace.float32[3], + var: dace.float32[3], running_mean: dace.float32[3], running_var: dace.float32[3]): + Y = dace.define_local([8, 3, 32, 32], dace.float32) + donnx.ONNXBatchNormalization( + X=X, + scale=scale, + B=B, + input_mean=mean, + input_var=var, + Y=Y, + running_mean=running_mean, + running_var=running_var, + training_mode=True, + ) + return Y + else: + + @dace.program + def test_bn_standalone(X: dace.float32[8, 3, 32, 32], scale: dace.float32[3], B: dace.float32[3], + mean: dace.float32[3], var: dace.float32[3]): + + Y = dace.define_local([8, 3, 32, 32], dace.float32) + donnx.ONNXBatchNormalization(X=X, + scale=scale, + B=B, + input_mean=mean, + input_var=var, + Y=Y, + training_mode=training_mode) + return Y + + X = torch.randn(8, 3, 32, 32) + scale = torch.randn(3) + B = torch.randn(3) + mean = torch.randn(3) + var = torch.abs(torch.randn(3)) + X_torch, scale_torch, B_torch, mean_torch, var_torch = X.clone(), scale.clone(), B.clone(), mean.clone(), var.clone( + ) + if training_mode: + running_mean = np.zeros(3, dtype=np.float32) + running_var = np.ones(3, dtype=np.float32) + dace_result = test_bn_standalone(X, scale, B, mean, var, running_mean, running_var) + else: + dace_result = test_bn_standalone(X, scale, B, mean, var) + + pt_result = F.batch_norm(X_torch, mean_torch, var_torch, scale_torch, B_torch, training=training_mode) + torch_tensors_close("output", pt_result, torch.from_numpy(dace_result)) + + +@pytest.mark.onnx +def test_bn_in_import(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.bn = nn.BatchNorm2d(3, track_running_stats=True) + + def forward(self, x): + return self.bn(x) + + pt_module = Module() + pt_module.eval() + dace_module = Module() + dace_module.eval() + + dace_module.load_state_dict(pt_module.state_dict()) + + dace_module = DaceModule(dace_module, sdfg_name="test_bn_in_import") + + X = torch.randn(8, 3, 32, 32) + pt_result = pt_module(X) + dace_result = dace_module(X) + + torch_tensors_close("output", pt_result, dace_result) + + +if __name__ == "__main__": + for training_mode in [True, False]: + test_bn_standalone(training_mode=training_mode) + test_bn_in_import() diff --git a/tests/onnx/test_variadic.py b/tests/onnx/test_variadic.py new file mode 100644 index 0000000000..635af8b9ce --- /dev/null +++ b/tests/onnx/test_variadic.py @@ -0,0 +1,53 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +import numpy as np + +import dace +import dace.libraries.onnx as donnx + + +@pytest.mark.onnx +def test_sum(): + sdfg = dace.SDFG("test_sum") + + sdfg.add_array("A_arr", [2, 2], dace.float32) + sdfg.add_array("B_arr", [2, 2], dace.float32) + sdfg.add_array("C_arr", [2, 2], dace.float32) + sdfg.add_array("__return", [2, 2], dace.float32) + + state = sdfg.add_state() + access_A = state.add_access("A_arr") + access_B = state.add_access("B_arr") + access_C = state.add_access("C_arr") + + access_result = state.add_access("__return") + + op_node = donnx.ONNXSum("Sum") + + state.add_node(op_node) + for i in range(3): + op_node.add_in_connector("data_0__{}".format(i)) + state.add_edge(access_A, None, op_node, "data_0__0", sdfg.make_array_memlet("A_arr")) + state.add_edge(access_B, None, op_node, "data_0__1", sdfg.make_array_memlet("B_arr")) + state.add_edge(access_C, None, op_node, "data_0__2", sdfg.make_array_memlet("C_arr")) + + state.add_edge(op_node, "sum", access_result, None, sdfg.make_array_memlet("__return")) + + A = np.random.rand(2, 2).astype(np.float32) + B = np.random.rand(2, 2).astype(np.float32) + C = np.random.rand(2, 2).astype(np.float32) + + sdfg.validate() + + result = sdfg(A_arr=A, B_arr=B, C_arr=C) + + numpy_result = A + B + C + + assert np.allclose(result, + numpy_result), f"Variadic sum mismatch: max diff = {np.max(np.abs(result - numpy_result))}" + + +if __name__ == "__main__": + test_sum() diff --git a/tests/python_frontend/return_value_test.py b/tests/python_frontend/return_value_test.py index 4a845bea0b..4e704287bc 100644 --- a/tests/python_frontend/return_value_test.py +++ b/tests/python_frontend/return_value_test.py @@ -9,7 +9,15 @@ def test_return_scalar(): def return_scalar(): return 5 - assert return_scalar() == 5 + res = return_scalar() + assert res == 5 + + # Don't be fooled by the test above the return value is an array. If you would + # add the return value annotation to the program, i.e. `-> dace.int32` you would + # get a validation error. + assert isinstance(res, np.ndarray) + assert res.shape == (1, ) + assert res.dtype == np.int64 def test_return_scalar_in_nested_function(): @@ -22,7 +30,15 @@ def nested_function() -> dace.int32: def return_scalar(): return nested_function() - assert return_scalar() == 5 + res = return_scalar() + assert res == 5 + + # Don't be fooled by the test above the return value is an array. If you would + # add the return value annotation to the program, i.e. `-> dace.int32` you would + # get a validation error. + assert isinstance(res, np.ndarray) + assert res.shape == (1, ) + assert res.dtype == np.int32 def test_return_array(): @@ -42,6 +58,8 @@ def return_tuple(): return 5, 6 res = return_tuple() + assert isinstance(res, tuple) + assert len(res) == 2 assert res == (5, 6) @@ -52,6 +70,8 @@ def return_array_tuple(): return 5 * np.ones(5), 6 * np.ones(6) res = return_array_tuple() + assert isinstance(res, tuple) + assert len(res) == 2 assert np.allclose(res[0], 5 * np.ones(5)) assert np.allclose(res[1], 6 * np.ones(6)) @@ -66,10 +86,25 @@ def return_void(a: dace.float64[20]): a = np.random.rand(20) ref = a + 1 - return_void(a) + res = return_void(a) + assert res is None assert np.allclose(a, ref) +def test_return_tuple_1_element(): + + @dace.program + def return_one_element_tuple(a: dace.float64[20]): + return (a + 3.5, ) + + a = np.random.rand(20) + ref = a + 3.5 + res = return_one_element_tuple(a) + assert isinstance(res, tuple) + assert len(res) == 1 + assert np.allclose(res[0], ref) + + def test_return_void_in_if(): @dace.program diff --git a/tests/python_frontend/structures/structure_python_test.py b/tests/python_frontend/structures/structure_python_test.py index 9505f8cab7..af317be7d8 100644 --- a/tests/python_frontend/structures/structure_python_test.py +++ b/tests/python_frontend/structures/structure_python_test.py @@ -257,6 +257,92 @@ def csr_to_dense_python(A: CSR, B: dace.float32[M, N]): assert np.allclose(B, ref) +def test_write_structure_in_map(): + M = dace.symbol('M') + N = dace.symbol('N') + Bundle = dace.data.Structure(members={ + "data": dace.data.Array(dace.float32, (M, N)), + "size": dace.data.Scalar(dace.int64) + }, + name="BundleType") + + @dace.program + def init_prog(bundle: Bundle, fill_value: int) -> None: + for index in dace.map[0:bundle.size]: + bundle.data[index, :] = fill_value + + data = np.zeros((10, 5), dtype=np.float32) + fill_value = 42 + inp_struct = Bundle.dtype.base_type.as_ctypes()( + data=data.__array_interface__['data'][0], + size=9, + ) + ref = np.zeros((10, 5), dtype=np.float32) + ref[:9, :] = fill_value + + init_prog.compile()(inp_struct, fill_value, M=10, N=5) + + assert np.allclose(data, ref) + + +def test_readwrite_structure_in_map(): + M = dace.symbol('M') + N = dace.symbol('N') + Bundle = dace.data.Structure(members={ + "data": dace.data.Array(dace.float32, (M, N)), + "data2": dace.data.Array(dace.float32, (M, N)), + "size": dace.data.Scalar(dace.int64) + }, + name="BundleType") + + @dace.program + def copy_prog(bundle: Bundle) -> None: + for index in dace.map[0:bundle.size]: + bundle.data[index, :] = bundle.data2[index, :] + 5 + + data = np.zeros((10, 5), dtype=np.float32) + data2 = np.ones((10, 5), dtype=np.float32) + inp_struct = Bundle.dtype.base_type.as_ctypes()( + data=data.__array_interface__['data'][0], + data2=data2.__array_interface__['data'][0], + size=6, + ) + ref = np.zeros((10, 5), dtype=np.float32) + ref[:6, :] = 6.0 + + copy_prog.compile()(inp_struct, M=10, N=5) + + assert np.allclose(data, ref) + + +def test_write_structure_in_loop(): + M = dace.symbol('M') + N = dace.symbol('N') + Bundle = dace.data.Structure(members={ + "data": dace.data.Array(dace.float32, (M, N)), + "size": dace.data.Scalar(dace.int64) + }, + name="BundleType") + + @dace.program + def init_prog(bundle: Bundle, fill_value: int) -> None: + for index in range(bundle.size): + bundle.data[index, :] = fill_value + + data = np.zeros((10, 5), dtype=np.float32) + fill_value = 42 + inp_struct = Bundle.dtype.base_type.as_ctypes()( + data=data.__array_interface__['data'][0], + size=6, + ) + ref = np.zeros((10, 5), dtype=np.float32) + ref[:6, :] = fill_value + + init_prog.compile()(inp_struct, fill_value, M=10, N=5) + + assert np.allclose(data, ref) + + def test_struct_interface(): M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) CSR = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), @@ -370,6 +456,9 @@ def struct_recursive(A: Struct, B: Struct): test_local_structure() test_rgf() # test_read_structure_gpu() + test_write_structure_in_map() + test_readwrite_structure_in_map() + test_write_structure_in_loop() test_struct_interface() test_struct_recursive() test_struct_recursive_from_dataclass() diff --git a/tests/rtl/hardware_test.py b/tests/rtl/hardware_test.py deleted file mode 100644 index b17da2050f..0000000000 --- a/tests/rtl/hardware_test.py +++ /dev/null @@ -1,526 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" - Test suite for testing RTL integration with DaCe targeting Xilinx FPGAs. -""" -import dace -from dace.fpga_testing import rtl_test -import numpy as np -import importlib.util -from pathlib import Path -import pytest -from dace.transformation.dataflow import StreamingMemory, Vectorization -from dace.transformation.interstate import FPGATransformState -from dace.transformation.subgraph import TemporalVectorization - - -def make_vadd_sdfg(N: dace.symbol, veclen: int = 8): - ''' - Function for generating a simple vector addition SDFG that adds a vector `A` of `N` elements to a scalar `B` into a vector `C` of `N` elements, all using SystemVerilog. - The tasklet creates `veclen` instances of a floating point adder that operates on `N` elements. - - :param N: The number of elements the SDFG takes as input and output. - :param veclen: The number of floating point adders to instantiate. - :return: An SDFG that has arguments `A`, `B` and `C`. - ''' - # add sdfg - sdfg = dace.SDFG('floating_point_vector_plus_scalar') - - # add state - state = sdfg.add_state('device_state') - - # add parameter - sdfg.add_constant('VECLEN', veclen) - - # add arrays - sdfg.add_array('A', [N // veclen], dtype=dace.vector(dace.float32, veclen), storage=dace.StorageType.CPU_Heap) - sdfg.add_scalar('B', dace.float32, storage=dace.StorageType.FPGA_Global) - sdfg.add_array('C', [N // veclen], dtype=dace.vector(dace.float32, veclen), storage=dace.StorageType.CPU_Heap) - sdfg.add_array('fpga_A', [N // veclen], - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Global) - sdfg.add_array('fpga_C', [N // veclen], - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Global) - - # add streams - sdfg.add_stream('A_stream', - buffer_size=32, - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Local) - sdfg.add_stream('C_stream', - buffer_size=32, - dtype=dace.vector(dace.float32, veclen), - transient=True, - storage=dace.StorageType.FPGA_Local) - - # add custom rtl tasklet - rtl_tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a', 'b'}, - outputs={'c'}, - code=''' - assign ap_done = 1; - wire ap_aresetn = ~ap_areset; - - wire [VECLEN-1:0] a_tvalid; - wire [VECLEN-1:0][31:0] a_tdata; - wire [VECLEN-1:0] a_tready; - - wire [VECLEN-1:0] c_tvalid; - wire [VECLEN-1:0][31:0] c_tdata; - wire [VECLEN-1:0] c_tready; - - axis_broadcaster_0 ab0( - .aclk (ap_aclk), - .aresetn (ap_aresetn), - - .s_axis_tvalid (s_axis_a_tvalid), - .s_axis_tdata (s_axis_a_tdata), - .s_axis_tready (s_axis_a_tready), - - .m_axis_tvalid (a_tvalid), - .m_axis_tdata (a_tdata), - .m_axis_tready (a_tready) - ); - - generate - for (genvar i = 0; i < VECLEN; i = i + 1) begin - floating_point_add add( - .aclk (ap_aclk), - .aresetn (ap_aresetn), - - .s_axis_a_tvalid (a_tvalid[i]), - .s_axis_a_tdata (a_tdata[i]), - .s_axis_a_tready (a_tready[i]), - - .s_axis_b_tvalid (scalars_valid), - .s_axis_b_tdata (b), - - .m_axis_result_tvalid (c_tvalid[i]), - .m_axis_result_tdata (c_tdata[i]), - .m_axis_result_tready (c_tready[i]) - ); - end - endgenerate - - axis_combiner_0 ac0( - .aclk (ap_aclk), - .aresetn (ap_aresetn), - - .s_axis_tvalid (c_tvalid), - .s_axis_tdata (c_tdata), - .s_axis_tready (c_tready), - - .m_axis_tvalid (m_axis_c_tvalid), - .m_axis_tdata (m_axis_c_tdata), - .m_axis_tready (m_axis_c_tready) - ); - ''', - language=dace.Language.SystemVerilog) - - rtl_tasklet.add_ip_core( - 'floating_point_add', 'floating_point', 'xilinx.com', '7.1', { - 'CONFIG.Add_Sub_Value': 'Add', - 'CONFIG.Has_ARESETn': 'true', - 'CONFIG.Axi_Optimize_Goal': 'Performance', - 'CONFIG.C_Latency': '14' - }) - rtl_tasklet.add_ip_core( - 'axis_broadcaster_0', 'axis_broadcaster', 'xilinx.com', '1.1', - dict({ - 'CONFIG.NUM_MI': f'{veclen}', - 'CONFIG.M_TDATA_NUM_BYTES': '4', - 'CONFIG.S_TDATA_NUM_BYTES': f'{veclen*4}' - }, **{f'CONFIG.M{i:02}_TDATA_REMAP': f'tdata[{((i+1)*32)-1}:{i*32}]' - for i in range(veclen)})) - rtl_tasklet.add_ip_core('axis_combiner_0', 'axis_combiner', 'xilinx.com', '1.1', { - 'CONFIG.TDATA_NUM_BYTES': '4', - 'CONFIG.NUM_SI': f'{veclen}' - }) - - # add read and write tasklets - read_a = state.add_tasklet('read_a', {'inp'}, {'out'}, 'out = inp') - write_c = state.add_tasklet('write_c', {'inp'}, {'out'}, 'out = inp') - - # add read and write maps - read_a_entry, read_a_exit = state.add_map('read_a_map', - dict(i='0:N//VECLEN'), - schedule=dace.ScheduleType.FPGA_Device) - write_c_entry, write_c_exit = state.add_map('write_c_map', - dict(i='0:N//VECLEN'), - schedule=dace.ScheduleType.FPGA_Device) - - # add read_a memlets and access nodes - read_a_inp = state.add_read('fpga_A') - read_a_out = state.add_write('A_stream') - state.add_memlet_path(read_a_inp, read_a_entry, read_a, dst_conn='inp', memlet=dace.Memlet('fpga_A[i]')) - state.add_memlet_path(read_a, read_a_exit, read_a_out, src_conn='out', memlet=dace.Memlet('A_stream[0]')) - - # add tasklet memlets - A = state.add_read('A_stream') - B = state.add_read('B') - C = state.add_write('C_stream') - state.add_memlet_path(A, rtl_tasklet, dst_conn='a', memlet=dace.Memlet('A_stream[0]')) - state.add_memlet_path(B, rtl_tasklet, dst_conn='b', memlet=dace.Memlet('B[0]')) - state.add_memlet_path(rtl_tasklet, C, src_conn='c', memlet=dace.Memlet('C_stream[0]')) - - # add write_c memlets and access nodes - write_c_inp = state.add_read('C_stream') - write_c_out = state.add_write('fpga_C') - state.add_memlet_path(write_c_inp, write_c_entry, write_c, dst_conn='inp', memlet=dace.Memlet('C_stream[0]')) - state.add_memlet_path(write_c, write_c_exit, write_c_out, src_conn='out', memlet=dace.Memlet('fpga_C[i]')) - - # add copy to device state - copy_to_device = sdfg.add_state('copy_to_device') - cpu_a = copy_to_device.add_read('A') - dev_a = copy_to_device.add_write('fpga_A') - copy_to_device.add_memlet_path(cpu_a, dev_a, memlet=dace.Memlet('A[0:N//VECLEN]')) - sdfg.add_edge(copy_to_device, state, dace.InterstateEdge()) - - # add copy to host state - copy_to_host = sdfg.add_state('copy_to_host') - dev_c = copy_to_host.add_read('fpga_C') - cpu_c = copy_to_host.add_write('C') - copy_to_host.add_memlet_path(dev_c, cpu_c, memlet=dace.Memlet('C[0:N//VECLEN]')) - sdfg.add_edge(state, copy_to_host, dace.InterstateEdge()) - - # validate sdfg - sdfg.validate() - - return sdfg - - -def make_vadd_multi_sdfg(N, M, n, m): - ''' - Function for constructing an SDFG that adds a constant (42) to a an array `A` of `N` elements into a vector `B`. - Each instance of an adder is within its own Processing Element (PE), along with a reader and writer to/from global memory. - This SDFG also utilizes array of streams, giving each compute PE its own set of in- and output streams. - - :param N: The number of elements to compute on. - :param M: The number of compute PEs to initialize. - :return: An SDFG that has arguments `A` and `B`. - ''' - # add sdfg - sdfg = dace.SDFG(f'integer_vector_plus_42_multiple_kernels_{n // m}') - - # add state - state = sdfg.add_state('device_state') - - # add arrays - sdfg.add_array('A', [N], dtype=dace.int32, storage=dace.StorageType.CPU_Heap) - sdfg.add_array('B', [N], dtype=dace.int32, storage=dace.StorageType.CPU_Heap) - sdfg.add_array('fpga_A', [N], dtype=dace.int32, transient=True, storage=dace.StorageType.FPGA_Global) - sdfg.add_array('fpga_B', [N], dtype=dace.int32, transient=True, storage=dace.StorageType.FPGA_Global) - - # add streams - sdfg.add_stream('A_stream', - shape=(int(n / m), ), - dtype=dace.int32, - transient=True, - storage=dace.StorageType.FPGA_Local) - sdfg.add_stream('B_stream', - shape=(int(n / m), ), - dtype=dace.int32, - transient=True, - storage=dace.StorageType.FPGA_Local) - - # add custom rtl tasklet - rtl_tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a'}, - outputs={'b'}, - code=''' - /* - Convention: - |--------------------------------------------------------| - | | - -->| ap_aclk (clock input) | - -->| ap_areset (reset input, rst on high) | - -->| ap_start (start pulse from host) | - <--| ap_done (tells the host that the kernel is done) | - | | - | For each input: For each output: | - | | - -->| s_axis_{input}_tvalid reg m_axis_{output}_tvalid |--> - -->| s_axis_{input}_tdata reg m_axis_{output}_tdata |--> - <--| reg s_axis_{input}_tready m_axis_{output}_tready |<-- - -->| s_axis_{input}_tkeep reg m_axis_{output}_tkeep |--> - -->| s_axis_{input}_tlast reg m_axis_{output}_tlast |--> - | | - |--------------------------------------------------------| - */ - - assign ap_done = 1; // free-running kernel - - always@(posedge ap_aclk) begin - if (ap_areset) begin // case: reset - s_axis_a_tready <= 1'b1; - m_axis_b_tvalid <= 1'b0; - m_axis_b_tdata <= 0; - end else if (s_axis_a_tvalid && s_axis_a_tready) begin - s_axis_a_tready <= 1'b0; - m_axis_b_tvalid <= 1'b1; - m_axis_b_tdata <= s_axis_a_tdata + 42; - end else if (!s_axis_a_tready && m_axis_b_tvalid && m_axis_b_tready) begin - s_axis_a_tready <= 1'b1; - m_axis_b_tvalid <= 1'b0; - end - end - ''', - language=dace.Language.SystemVerilog) - - # add read and write tasklets - read_a = state.add_tasklet('read_a', {'inp'}, {'out'}, 'out = inp') - write_b = state.add_tasklet('write_b', {'inp'}, {'out'}, 'out = inp') - - # add read and write maps - read_a_entry, read_a_exit = state.add_map('read_a_map', - dict(i='0:N//M', j='0:M'), - schedule=dace.ScheduleType.FPGA_Device) - write_b_entry, write_b_exit = state.add_map('write_b_map', - dict(i='0:N//M', j='0:M'), - schedule=dace.ScheduleType.FPGA_Device) - compute_entry, compute_exit = state.add_map('compute_map', - dict(i='0:N//M'), - schedule=dace.ScheduleType.FPGA_Device, - unroll=True) - - # add read_a memlets and access nodes - read_a_inp = state.add_read('fpga_A') - read_a_out = state.add_write('A_stream') - state.add_memlet_path(read_a_inp, read_a_entry, read_a, dst_conn='inp', memlet=dace.Memlet('fpga_A[i*M+j]')) - state.add_memlet_path(read_a, read_a_exit, read_a_out, src_conn='out', memlet=dace.Memlet('A_stream[i]')) - - # add tasklet memlets - A = state.add_read('A_stream') - B = state.add_write('B_stream') - state.add_memlet_path(A, compute_entry, rtl_tasklet, dst_conn='a', memlet=dace.Memlet('A_stream[i]')) - state.add_memlet_path(rtl_tasklet, compute_exit, B, src_conn='b', memlet=dace.Memlet('B_stream[i]')) - - # add write_b memlets and access nodes - write_b_inp = state.add_read('B_stream') - write_b_out = state.add_write('fpga_B') - state.add_memlet_path(write_b_inp, write_b_entry, write_b, dst_conn='inp', memlet=dace.Memlet('B_stream[i]')) - state.add_memlet_path(write_b, write_b_exit, write_b_out, src_conn='out', memlet=dace.Memlet('fpga_B[i*M+j]')) - - # add copy to device state - copy_to_device = sdfg.add_state('copy_to_device') - cpu_a = copy_to_device.add_read('A') - dev_a = copy_to_device.add_write('fpga_A') - copy_to_device.add_memlet_path(cpu_a, dev_a, memlet=dace.Memlet('A[0:N]')) - sdfg.add_edge(copy_to_device, state, dace.InterstateEdge()) - - # add copy to host state - copy_to_host = sdfg.add_state('copy_to_host') - dev_b = copy_to_host.add_read('fpga_B') - cpu_b = copy_to_host.add_write('B') - copy_to_host.add_memlet_path(dev_b, cpu_b, memlet=dace.Memlet('B[0:N]')) - sdfg.add_edge(state, copy_to_host, dace.InterstateEdge()) - - return sdfg - - -@rtl_test() -def test_hardware_vadd(): - ''' - Test for the simple vector addition. - ''' - - # add symbol - N = dace.symbol('N') - n = 32 - veclen = 4 - sdfg = make_vadd_sdfg(N, veclen) - a = np.random.randint(0, 100, n).astype(np.float32) - b = np.random.randint(1, 100, 1)[0].astype(np.float32) - c = np.zeros((n, )).astype(np.float32) - - # call program - sdfg(A=a, B=b, C=c, N=N) - - expected = a + b - diff = np.linalg.norm(expected - c) / n - assert diff <= 1e-5 - - return sdfg - - -@rtl_test() -def test_hardware_add42_single(): - ''' - Test for adding a constant using a single PE. - ''' - N = dace.symbol('N') - M = dace.symbol('M') - - # init data structures - n = 32 # elements - m = 32 # elements per kernel - a = np.random.randint(0, 100, n).astype(np.int32) - b = np.zeros((n, )).astype(np.int32) - sdfg = make_vadd_multi_sdfg(N, M, n, m) - sdfg.specialize(dict(N=n, M=m)) - - # call program - sdfg(A=a, B=b) - - # check result - for i in range(n): - assert b[i] == a[i] + 42 - - return sdfg - - -def _hardware_axpy_double_pump(veclen=2): - ''' - Tests manual application of the multi-pumping optimization applied to the AXPY program from BLAS. - - :param veclen: The vectorization length to instantiate. Must be a multiple of 2. - ''' - with dace.config.set_temporary('compiler', 'xilinx', 'frequency', value='"0:300\\|1:600"'): - # Grab the double pumped AXPY implementation the samples directory - spec = importlib.util.spec_from_file_location( - "axpy", - Path(__file__).parent.parent.parent / "samples" / "fpga" / "rtl" / "axpy_double_pump.py") - axpy = importlib.util.module_from_spec(spec) - spec.loader.exec_module(axpy) - - # init data structures - N = dace.symbol('N') - n = 32 - a = np.random.rand(1)[0].astype(np.float32) - x = np.random.rand(n).astype(np.float32) - y = np.random.rand(n).astype(np.float32) - result = np.zeros((n, )).astype(np.float32) - - # Build the SDFG - sdfg = axpy.make_sdfg(veclen) - - # call program - sdfg(a=a, x=x, y=y, result=result, N=N) - - # check result - expected = a * x + y - diff = np.linalg.norm(expected - result) / n - - assert diff <= 1e-5 - - return sdfg - - -@rtl_test() -def test_hardware_axpy_double_pump_vec2(): - ''' - Tests double pumping with a vector length of 2. - ''' - return _hardware_axpy_double_pump(veclen=2) - - -@rtl_test() -def test_hardware_axpy_double_pump_vec4(): - ''' - Tests double pumping with a vector length of 4. - ''' - return _hardware_axpy_double_pump(veclen=4) - - -@rtl_test() -def test_hardware_vadd_temporal_vectorization(): - ''' - Tests whether the multi-pumping optimization can be applied automatically by applying the temporal vectorization transformation. It starts from a numpy vector addition for generating the SDFG. This SDFG is then optimized by applying the vectorization, streaming memory, fpga and temporal vectorization transformations in that order. - ''' - # TODO !!!!! THIS TEST STALLS IN HARDWARE EMULATION WITH VITIS 2021.2 and 2022.1 !!!!! - # But it works fine for 2020.2, 2022.2, and 2023.1. It seems like - # everything but the last transaction correctly goes through just fine. The - # last transaction is never output by the floating point adder, but the - # inputs are consumed. - with dace.config.set_temporary('compiler', 'xilinx', 'frequency', value='"0:300\\|1:600"'): - # Generate the test data and expected results - size_n = 1024 - veclen = 4 - N = dace.symbol('N') - n = size_n - x = np.random.rand(n).astype(np.float32) - y = np.random.rand(n).astype(np.float32) - result = np.zeros(n, dtype=np.float32) - expected = x + y - - # Generate the initial SDFG - def np_vadd(x: dace.float32[N], y: dace.float32[N]): - return x + y - - sdfg = dace.program(np_vadd).to_sdfg() - - # Apply vectorization transformation - ambles = size_n % veclen != 0 - map_entry = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.MapEntry)][0] - applied = sdfg.apply_transformations( - Vectorization, { - 'vector_len': veclen, - 'preamble': ambles, - 'postamble': ambles, - 'propagate_parent': True, - 'strided_map': False, - 'map_entry': map_entry - }) - assert (applied == 1) - - # Transform to an FPGA implementation - applied = sdfg.apply_transformations(FPGATransformState) - assert (applied == 1) - - # Apply streaming memory transformation - applied = sdfg.apply_transformations_repeated(StreamingMemory, { - 'storage': dace.StorageType.FPGA_Local, - 'buffer_size': 1 - }) - assert (applied == 3) - - # Apply temporal vectorization transformation - sgs = dace.sdfg.concurrent_subgraphs(sdfg.states()[0]) - sf = TemporalVectorization() - cba = [TemporalVectorization.can_be_applied(sf, sdfg, sg) for sg in sgs] - assert (sum(cba) == 1) - [TemporalVectorization.apply_to(sdfg, sg) for i, sg in enumerate(sgs) if cba[i]] - - # Run the program and verify the results - sdfg.specialize({'N': n}) - sdfg(x=x, y=y, __return=result) - assert (np.allclose(expected, result)) - - return sdfg - - -# Disabled due to problem with array of streams in Vitis 2021.1 -#rtl_test() -#def test_hardware_add42_multi(): -# N = dace.symbol('N') -# M = dace.symbol('M') -# -# # init data structures -# n = 32 # elements -# n = 8 # elements per kernel -# a = np.random.randint(0, 100, n).astype(np.int32) -# b = np.zeros((n, )).astype(np.int32) -# sdfg = make_vadd_multi_sdfg(N, M, n, m) -# sdfg.specialize(dict(N=N, M=M)) -# -# # call program -# sdfg(A=a, B=b) -# -# # check result -# for i in range(n): -# assert b[i] == a[i] + 42 -# -# return sdfg - -if __name__ == '__main__': - # These tests should only be run in hardware* mode - with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='hardware_emulation'): - test_hardware_vadd(None) - test_hardware_vadd_temporal_vectorization(None) - test_hardware_add42_single(None) - #test_hardware_add42_multi(None) - test_hardware_axpy_double_pump_vec2(None) - test_hardware_axpy_double_pump_vec4(None) diff --git a/tests/rtl/simulation_test.py b/tests/rtl/simulation_test.py deleted file mode 100644 index d9fcf90eff..0000000000 --- a/tests/rtl/simulation_test.py +++ /dev/null @@ -1,698 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" - Test suite for testing RTL tasklets in DaCe with Verilator as a backend for simulation. -""" -import dace -import numpy as np -import pytest - - -@pytest.mark.verilator -def test_tasklet_array(): - """ - Test the simple array execution sample. - """ - - n = 128 - N = dace.symbol('N') - - # add sdfg - sdfg = dace.SDFG('rtl_tasklet_array') - - # add state - state = sdfg.add_state() - - # add arrays - sdfg.add_array('A', [N], dtype=dace.int32) - sdfg.add_array('B', [N], dtype=dace.int32) - - # add custom cpp tasklet - tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a'}, - outputs={'b'}, - code=''' - always@(posedge ap_aclk) begin - if (ap_areset) begin - s_axis_a_tready <= 1; - m_axis_b_tvalid <= 0; - m_axis_b_tdata <= 0; - end else if (s_axis_a_tvalid && s_axis_a_tready) begin - s_axis_a_tready <= 0; - m_axis_b_tvalid <= 1; - m_axis_b_tdata <= s_axis_a_tdata + 42; - end else if (m_axis_b_tvalid && m_axis_b_tready) begin - s_axis_a_tready <= 1; - m_axis_b_tvalid <= 0; - m_axis_b_tdata <= 0; - end - end - ''', - language=dace.Language.SystemVerilog) - - # add input/output array - A = state.add_read('A') - B = state.add_write('B') - - # connect input/output array with the tasklet - state.add_edge(A, None, tasklet, 'a', dace.Memlet('A[0:N]')) - state.add_edge(tasklet, 'b', B, None, dace.Memlet('B[0:N]')) - - # validate sdfg - sdfg.specialize({'N': n}) - sdfg.validate() - - # init data structures - a = np.random.randint(0, 100, n).astype(np.int32) - b = np.zeros((n, )).astype(np.int32) - - # call program - sdfg(A=a, B=b) - - # check result - assert (b == a + 42).all() - - -@pytest.mark.skip -@pytest.mark.verilator -def test_tasklet_double_clk_counters(): - """ - Test double clock functionality utilizing two counters, one for each clock. - The first 16 bits of the result should contain the count from the "slow" clock. - The last 16 bits of the result should contain the count from the "fast" clock, i.e. slow count * 2 - """ - old_freq = dace.config.Config.get('compiler', 'xilinx', 'frequency') - dace.config.Config.set('compiler', 'xilinx', 'frequency', value='"0:300\\|1:600"') - sdfg = dace.SDFG('rtl_tasklet_double_clk_counters') - state = sdfg.add_state() - sdfg.add_array('A', [1], dtype=dace.int32) - sdfg.add_array('B', [1], dtype=dace.int32) - - tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a'}, - outputs={'b'}, - code=''' - - reg [31:0] max_cnt; - reg [15:0] s_cnt; - reg s_done; - reg [15:0] d_cnt; - reg d_done; - - always @(posedge ap_aclk) begin - if (ap_areset) begin - s_axis_a_tready <= 1; - end else if (s_axis_a_tvalid && s_axis_a_tready) begin - max_cnt <= s_axis_a_tdata; - s_axis_a_tready <= 0; - end else if (m_axis_b_tvalid && m_axis_b_tready) begin - s_axis_a_tready <= 1; - end - end - - always @(posedge ap_aclk) begin - if (ap_areset) begin - s_cnt <= 0; - s_done <= 0; - end else if (s_cnt < max_cnt[15:0]) begin - s_cnt <= s_cnt + 1; - s_done <= 0; - end else begin - s_done <= max_cnt > 0; - end - end - - always @(posedge ap_aclk_2) begin - if (ap_areset) begin - d_cnt <= 0; - d_done <= 0; - end else if (s_cnt < max_cnt[15:0]) begin - d_cnt <= d_cnt + 1; - d_done <= 0; - end else begin - d_done <= max_cnt > 0; - end - end - - always @(posedge ap_aclk) begin - if (ap_areset) begin - m_axis_b_tvalid <= 0; - m_axis_b_tdata <= 0; - end else begin - m_axis_b_tvalid <= s_done && d_done; - m_axis_b_tdata[15:0] <= s_cnt; - m_axis_b_tdata[31:16] <= d_cnt; - end - end - ''', - language=dace.Language.SystemVerilog) - A = state.add_read('A') - B = state.add_write('B') - - state.add_edge(A, None, tasklet, 'a', dace.Memlet('A[0]')) - state.add_edge(tasklet, 'b', B, None, dace.Memlet('B[0]')) - - sdfg.validate() - - a = np.random.randint(0, 100, 1).astype(np.int32) - b = np.zeros((1, )).astype(np.int32) - - sdfg(A=a, B=b) - - dace.config.Config.set('compiler', 'xilinx', 'frequency', value=old_freq) - - assert b[0] & 0xFFFF == a[0] - assert (b[0] >> 16) & 0xFFFF == a[0] * 2 - - -@pytest.mark.verilator -def test_tasklet_scalar(): - """ - Test the simple scalar execution sample. - """ - - # add sdfg - sdfg = dace.SDFG('rtl_tasklet_scalar') - - # add state - state = sdfg.add_state() - - # add arrays - sdfg.add_scalar('A', dtype=dace.int32) - sdfg.add_array('B', [1], dtype=dace.int32) - - # add custom cpp tasklet - tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a'}, - outputs={'b'}, - code=''' - always@(posedge ap_aclk) begin - if (ap_areset) begin // case: reset - m_axis_b_tvalid <= 0; - m_axis_b_tdata <= 0; - end else if (m_axis_b_tdata < a) begin // case: increment counter b - m_axis_b_tvalid <= 0; - m_axis_b_tdata <= m_axis_b_tdata + 1; - end else begin - m_axis_b_tvalid <= 1; - end - end - ''', - language=dace.Language.SystemVerilog) - - # add input/output array - A = state.add_read('A') - B = state.add_write('B') - - # connect input/output array with the tasklet - state.add_edge(A, None, tasklet, 'a', dace.Memlet('A[0]')) - state.add_edge(tasklet, 'b', B, None, dace.Memlet('B[0]')) - - # validate sdfg - sdfg.validate() - - # Execute - - # init data structures - a = np.random.randint(0, 100, 1).astype(np.int32) - b = np.zeros((1, )).astype(np.int32) - - # call program - sdfg(A=a[0], B=b) - - # check result - assert b[0] == a[0] - - -@pytest.mark.verilator -def test_tasklet_parameter(): - """ - Test the sv parameter support. - """ - - # add sdfg - sdfg = dace.SDFG('rtl_tasklet_parameter') - - # add state - state = sdfg.add_state() - - # add arrays - sdfg.add_array('A', [1], dtype=dace.int32) - sdfg.add_array('B', [1], dtype=dace.int32) - - # add parameter(s) - sdfg.add_constant("MAX_VAL", 42) - - # add custom cpp tasklet - tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a'}, - outputs={'b'}, - code=''' - /* - Convention: - |---------------------------------------------------------------------| - -->| ap_aclk (clock input) | - -->| ap_areset (reset input, rst on high) | - | | - -->| {inputs} reg {outputs} |--> - | | - <--| s_axis_a_tready (ready for data) (data avail) m_axis_b_tvalid |--> - -->| s_axis_a_tvalid (new data avail) (data consumed) m_axis_b_tready |<-- - |---------------------------------------------------------------------| - */ - - typedef enum [1:0] {READY, BUSY, DONE} state_e; - state_e state; - - always@(posedge ap_aclk) begin - if (ap_areset) begin // case: reset - m_axis_b_tdata <= 0; - s_axis_a_tready <= 1'b1; - state <= READY; - end else if (s_axis_a_tvalid && state == READY) begin // case: load a - m_axis_b_tdata <= s_axis_a_tdata; - s_axis_a_tready <= 1'b0; - state <= BUSY; - end else if (m_axis_b_tdata < MAX_VAL) // case: increment counter b - m_axis_b_tdata <= m_axis_b_tdata + 1; - else - m_axis_b_tdata <= m_axis_b_tdata; - state <= DONE; - end - - assign m_axis_b_tvalid = (m_axis_b_tdata >= MAX_VAL) ? 1'b1:1'b0; - ''', - language=dace.Language.SystemVerilog) - - # add input/output array - A = state.add_read('A') - B = state.add_write('B') - - # connect input/output array with the tasklet - state.add_edge(A, None, tasklet, 'a', dace.Memlet('A[0]')) - state.add_edge(tasklet, 'b', B, None, dace.Memlet('B[0]')) - - # validate sdfg - sdfg.validate() - - # execute - - # init data structures - a = np.random.randint(0, 100, 1).astype(np.int32) - b = np.random.randint(0, 100, 1).astype(np.int32) - - # call program - sdfg(A=a, B=b) - - # check result - assert b == sdfg.constants["MAX_VAL"] - - -@pytest.mark.verilator -def test_tasklet_vector_add(): - """ - Test rtl tasklet vector support. - """ - - # add symbol - W = dace.symbol('W') - - # add sdfg - sdfg = dace.SDFG('rtl_tasklet_vector_add') - - # define compile-time constant - sdfg.specialize(dict(W=4)) - - # add state - state = sdfg.add_state() - - # add arrays - sdfg.add_array('A', [1], dtype=dace.vector(dace.int32, dace.symbolic.evaluate(W, sdfg.constants))) - sdfg.add_array('B', [1], dtype=dace.vector(dace.int32, dace.symbolic.evaluate(W, sdfg.constants))) - - # add custom cpp tasklet - tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a'}, - outputs={'b'}, - code=''' - always@(posedge ap_aclk) begin - if (ap_areset) begin - s_axis_a_tready <= 1; - m_axis_b_tvalid <= 0; - m_axis_b_tdata <= 0; - end else if (s_axis_a_tvalid && s_axis_a_tready) begin - s_axis_a_tready <= 0; - m_axis_b_tvalid <= 1; - for (int i = 0; i < W; i++) begin - m_axis_b_tdata[i] <= s_axis_a_tdata[i] + 42; - end - end else if (m_axis_b_tvalid && m_axis_b_tready) begin - s_axis_a_tready <= 1; - m_axis_b_tvalid <= 0; - m_axis_b_tdata <= 0; - end - end - ''', - language=dace.Language.SystemVerilog) - - # add input/output array - A = state.add_read('A') - B = state.add_write('B') - - # connect input/output array with the tasklet - state.add_edge(A, None, tasklet, 'a', dace.Memlet('A[0]')) - state.add_edge(tasklet, 'b', B, None, dace.Memlet('B[0]')) - - # validate sdfg - sdfg.validate() - - # Execute - - # init data structures - a = np.random.randint(0, 100, (dace.symbolic.evaluate(W, sdfg.constants), )).astype(np.int32) - b = np.zeros((dace.symbolic.evaluate(W, sdfg.constants), )).astype(np.int32) - - # call program - sdfg(A=a, B=b) - - # check result - print(a) - print(b) - assert (b == a + 42).all() - - -@pytest.mark.verilator -def test_tasklet_vector_conversion(): - """ - Test rtl tasklet vector conversion support. - """ - - # add symbol - N = dace.symbol('N') - - # add sdfg - sdfg = dace.SDFG('rtl_tasklet_vector_conversion') - - # define compile-time constant - sdfg.specialize(dict(N=4)) - - # add state - state = sdfg.add_state() - - # add arrays - sdfg.add_array('A', [N], dtype=dace.int32) - sdfg.add_array('B', [1], dtype=dace.int32) - - # add custom cpp tasklet - tasklet = state.add_tasklet(name='rtl_tasklet', - inputs={'a': dace.vector(dace.int32, N)}, - outputs={'b'}, - code=''' - /* - Convention: - |---------------------------------------------------------------------| - -->| ap_aclk (clock input) | - -->| ap_areset (reset input, rst on high) | - | | - -->| {inputs} reg {outputs} |--> - | | - <--| s_axis_a_tready (ready for data) (data avail) m_axis_b_tvalid |--> - -->| s_axis_a_tvalid (new data avail) (data consumed) m_axis_b_tready |<-- - |---------------------------------------------------------------------| - */ - - typedef enum [1:0] {READY, BUSY, DONE} state_e; - state_e state; - - always@(posedge ap_aclk) begin - if (ap_areset) begin // case: reset - m_axis_b_tdata <= 0; - s_axis_a_tready <= 1'b1; - state <= READY; - end else if (s_axis_a_tvalid && state == READY) begin // case: load a - m_axis_b_tdata <= s_axis_a_tdata[0]; - s_axis_a_tready <= 1'b0; - state <= BUSY; - end else if (m_axis_b_tdata < s_axis_a_tdata[0] + s_axis_a_tdata[1] && state == BUSY) begin // case: increment counter b - m_axis_b_tdata <= m_axis_b_tdata + 1; - end else if (state == BUSY) begin - m_axis_b_tdata <= m_axis_b_tdata; - state <= DONE; - end - end - - assign m_axis_b_tvalid = (m_axis_b_tdata >= s_axis_a_tdata[0] + s_axis_a_tdata[1] && (state == BUSY || state == DONE)) ? 1'b1:1'b0; - ''', - language=dace.Language.SystemVerilog) - - # add input/output array - A = state.add_read('A') - B = state.add_write('B') - - # connect input/output array with the tasklet - state.add_edge(A, None, tasklet, 'a', dace.Memlet('A[0:N]')) - state.add_edge(tasklet, 'b', B, None, dace.Memlet('B[0]')) - - # validate sdfg - sdfg.validate() - - # Execute - - # init data structures - a = np.random.randint(0, 100, dace.symbolic.evaluate(N, sdfg.constants)).astype(np.int32) - b = np.array([0]).astype(np.int32) - - # call program - sdfg(A=a, B=b) - - # check result - assert b == a[0] + a[1] - - -@pytest.mark.verilator -def test_multi_tasklet(): - """ - Test multiple rtl tasklet support. - """ - - # add sdfg - sdfg = dace.SDFG('rtl_multi_tasklet') - - # add state - state = sdfg.add_state() - - # add arrays - sdfg.add_array('A', [1], dtype=dace.int32) - sdfg.add_array('B', [1], dtype=dace.int32) - sdfg.add_array('C', [1], dtype=dace.int32) - - # add custom cpp tasklet - tasklet0 = state.add_tasklet(name='rtl_tasklet0', - inputs={'a'}, - outputs={'b'}, - code=''' - typedef enum [1:0] {READY, BUSY, DONE} state_e; - state_e state; - - always@(posedge ap_aclk) begin - if (ap_areset) begin // case: reset - m_axis_b_tdata <= 0; - s_axis_a_tready <= 1'b1; - state <= READY; - end else if (s_axis_a_tvalid && state == READY) begin // case: load a - m_axis_b_tdata <= s_axis_a_tdata; - s_axis_a_tready <= 1'b0; - state <= BUSY; - end else if (m_axis_b_tdata < 80) // case: increment counter b - m_axis_b_tdata <= m_axis_b_tdata + 1; - else - m_axis_b_tdata <= m_axis_b_tdata; - state <= DONE; - end - - assign m_axis_b_tvalid = (m_axis_b_tdata >= 80) ? 1'b1:1'b0; - ''', - language=dace.Language.SystemVerilog) - - tasklet1 = state.add_tasklet(name='rtl_tasklet1', - inputs={'b'}, - outputs={'c'}, - code=''' - typedef enum [1:0] {READY, BUSY, DONE} state_e; - state_e state; - - always@(posedge ap_aclk) begin - if (ap_areset) begin // case: reset - m_axis_c_tdata <= 0; - s_axis_b_tready <= 1'b1; - state <= READY; - end else if (s_axis_b_tvalid && state == READY) begin // case: load a - m_axis_c_tdata <= s_axis_b_tdata; - s_axis_b_tready <= 1'b0; - state <= BUSY; - end else if (m_axis_c_tdata < 100) // case: increment counter b - m_axis_c_tdata <= m_axis_c_tdata + 1; - else - m_axis_c_tdata <= m_axis_c_tdata; - state <= DONE; - end - - assign m_axis_c_tvalid = (m_axis_c_tdata >= 100) ? 1'b1:1'b0; - ''', - language=dace.Language.SystemVerilog) - - # add input/output array - A = state.add_read('A') - B_w = state.add_write('B') - B_r = state.add_read('B') - C = state.add_write('C') - - # connect input/output array with the tasklet - state.add_edge(A, None, tasklet0, 'a', dace.Memlet('A[0]')) - state.add_edge(tasklet0, 'b', B_w, None, dace.Memlet('B[0]')) - state.add_edge(B_r, None, tasklet1, 'b', dace.Memlet('B[0]')) - state.add_edge(tasklet1, 'c', C, None, dace.Memlet('C[0]')) - - # validate sdfg - sdfg.validate() - - # Execute - - # init data structures - a = np.random.randint(0, 80, 1).astype(np.int32) - b = np.array([0]).astype(np.int32) - c = np.array([0]).astype(np.int32) - - # call program - sdfg(A=a, B=b, C=c) - - # check result - assert b == 80 - assert c == 100 - - -@pytest.mark.verilator -def test_tasklet_map(): - """ - Test the unrolled map support for M tasklets on N vectors of size W. - """ - # add symbols - n = 512 - m = 8 - w = 4 - N = dace.symbol('N') - M = dace.symbol('M') - W = dace.symbol('W') - - # add sdfg - sdfg = dace.SDFG('rtl_tasklet_map') - - # add state - state = sdfg.add_state() - - # add arrays - sdfg.add_array('A', [M, N], dtype=dace.vector(dace.int32, w)) - sdfg.add_array('B', [M, N], dtype=dace.vector(dace.int32, w)) - sdfg.add_array('C', [M, N], dtype=dace.vector(dace.int32, w)) - - mentry, mexit = state.add_map('compute_map', {'k': '0:M'}) - - tasklet = state.add_tasklet(name='rtl_tasklet1', - inputs={'a', 'b'}, - outputs={'c'}, - code=''' -reg [W-1:0][31:0] a_data; -reg a_valid; -reg [W-1:0][31:0] b_data; -reg b_valid; - -// Read A -always@(posedge ap_aclk) begin - if (ap_areset) begin - s_axis_a_tready <= 0; - a_valid <= 0; - a_data <= 0; - end else begin - if (s_axis_a_tready && s_axis_a_tvalid) begin - a_valid <= 1; - a_data <= s_axis_a_tdata; - s_axis_a_tready <= 0; - end else if (m_axis_c_tvalid && m_axis_c_tready) begin - a_valid <= 0; - s_axis_a_tready <= 1; - end else begin - s_axis_a_tready <= ~a_valid; - end - end -end - -// Read B -always@(posedge ap_aclk) begin - if (ap_areset) begin - s_axis_b_tready <= 0; - b_valid <= 0; - b_data <= 0; - end else begin - if (s_axis_b_tready && s_axis_b_tvalid) begin - b_valid <= 1; - b_data <= s_axis_b_tdata; - s_axis_b_tready <= 0; - end else if (m_axis_c_tvalid && m_axis_c_tready) begin - b_valid <= 0; - b_data <= 0; - s_axis_b_tready <= 1; - end else begin - s_axis_b_tready <= ~b_valid; - end - end -end - -// Compute and write C -always@(posedge ap_aclk) begin - if (ap_areset) begin - m_axis_c_tvalid <= 0; - m_axis_c_tdata <= 0; - end else begin - if (m_axis_c_tvalid && m_axis_c_tready) begin - m_axis_c_tvalid <= 0; - end else if (a_valid && b_valid) begin - m_axis_c_tvalid <= 1; - m_axis_c_tdata <= a_data + b_data; - end - end -end''', - language=dace.Language.SystemVerilog) - - A = state.add_read('A') - B = state.add_read('B') - C = state.add_write('C') - - state.add_memlet_path(A, mentry, tasklet, memlet=dace.Memlet('A[k,0:N]'), dst_conn='a') - state.add_memlet_path(B, mentry, tasklet, memlet=dace.Memlet('B[k,0:N]'), dst_conn='b') - state.add_memlet_path(tasklet, mexit, C, memlet=dace.Memlet('C[k,0:N]'), src_conn='c') - - sdfg.specialize({'M': m, 'N': n, 'W': w}) - sdfg.validate() - - # init data structures - a = np.random.randint(0, 100, m * n * w).reshape((m, n, w)).astype(np.int32) - b = np.random.randint(0, 100, m * n * w).reshape((m, n, w)).astype(np.int32) - c = np.zeros((m, n, w)).astype(np.int32) - - # call program - sdfg(A=a, B=b, C=c) - - # check result - assert (c == a + b).all() - - -if __name__ == '__main__': - # These tests should only be run in simulation mode - with dace.config.set_temporary('compiler', 'xilinx', 'mode', value='simulation'): - test_multi_tasklet() - test_tasklet_array() - test_tasklet_double_clk_counters() - test_tasklet_map() - test_tasklet_parameter() - test_tasklet_scalar() - test_tasklet_vector_add() - test_tasklet_vector_conversion() diff --git a/tests/symbol_dependent_transients_test.py b/tests/symbol_dependent_transients_test.py index f67b0dc416..4f8438f430 100644 --- a/tests/symbol_dependent_transients_test.py +++ b/tests/symbol_dependent_transients_test.py @@ -48,8 +48,6 @@ def _make_sdfg(name, storage=dace.dtypes.StorageType.CPU_Heap, isview=False): rednode = standard.Reduce('sum', wcr='lambda a, b : a + b', identity=0) if storage == dace.dtypes.StorageType.GPU_Global: rednode.implementation = 'CUDA (device)' - elif storage == dace.dtypes.StorageType.FPGA_Global: - rednode.implementation = 'FPGAPartialReduction' body2_state.add_node(rednode) write_tmp2 = body2_state.add_write('tmp2') body2_state.add_nedge(read_tmp1, rednode, dace.Memlet.from_array('tmp1', tmp1)) @@ -173,22 +171,6 @@ def test_symbol_dependent_gpu_view(): assert (np.allclose(B, B_ref)) -@pytest.mark.skip('FPGA compiler error') -def test_symbol_dependent_fpga_global_array(): - A = np.random.randn(10, 10, 10) - B = np.ndarray(10, dtype=np.float64) - sdfg = _make_sdfg("symbol_dependent_fpga_global_array", storage=dace.dtypes.StorageType.FPGA_Global) - # Compile manually to avoid simplification - sdfg_exec = sdfg.compile() - sdfg_exec(A=A, B=B, N=10) - del sdfg_exec - B_ref = np.ndarray(10, dtype=np.float64) - for i in range(10): - tmp = A[2:-2, 2:-2, i:] - B_ref[i] = np.sum(tmp) - assert (np.allclose(B, B_ref)) - - def test_symbol_dependent_array_in_map(): @dace.program @@ -221,5 +203,4 @@ def symbol_dependent_array_in_map(A: dace.float32[10]): test_symbol_dependent_gpu_global_array() test_symbol_dependent_pinned_array() # test_symbol_dependent_gpu_view() - # test_symbol_dependent_fpga_global_array() test_symbol_dependent_array_in_map() diff --git a/tests/tensorflow/callback_test.py b/tests/tensorflow/callback_test.py index 3dc359aac8..01b706e765 100644 --- a/tests/tensorflow/callback_test.py +++ b/tests/tensorflow/callback_test.py @@ -6,7 +6,7 @@ @pytest.mark.tensorflow def test_callback(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession input_image = tf.constant(0.69, tf.float64, [2, 2, 5, 5, 2]) conv_filter = tf.constant(0.01, tf.float64, [1, 1, 1, 2, 2]) diff --git a/tests/tensorflow/compile_test.py b/tests/tensorflow/compile_test.py index 6f58597a3c..1fd00d32ed 100644 --- a/tests/tensorflow/compile_test.py +++ b/tests/tensorflow/compile_test.py @@ -6,7 +6,7 @@ @pytest.mark.tensorflow def test_compile(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession print('DaCe Tensorflow frontend compile API test') diff --git a/tests/tensorflow/conv_test.py b/tests/tensorflow/conv_test.py index d7c44a98c4..41a6e384d7 100644 --- a/tests/tensorflow/conv_test.py +++ b/tests/tensorflow/conv_test.py @@ -7,7 +7,7 @@ def test_conv(): import tensorflow as tf from tensorflow.python.ops import gen_nn_ops - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession inp_shape = [10, 10, 10, 10] filter_shape = [3, 3, 10, 3] strides = [1, 3, 3, 1] diff --git a/tests/tensorflow/fbn_test.py b/tests/tensorflow/fbn_test.py index d3373745fb..fcc5f56fde 100644 --- a/tests/tensorflow/fbn_test.py +++ b/tests/tensorflow/fbn_test.py @@ -7,7 +7,7 @@ def test_fused_batch_norm(): import tensorflow as tf from tensorflow.python.ops import gen_nn_ops - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession num_channels = 3 size = [8, 224, 224, num_channels] diff --git a/tests/tensorflow/ops_test.py b/tests/tensorflow/ops_test.py index e685572f64..a34e882967 100644 --- a/tests/tensorflow/ops_test.py +++ b/tests/tensorflow/ops_test.py @@ -6,7 +6,7 @@ @pytest.mark.tensorflow def test_shapen(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession myshape = [69, 96, 666] num_inputs = 5 @@ -28,7 +28,7 @@ def test_shapen(): @pytest.mark.tensorflow def test_mean(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession shape = [10, 11, 12, 13] inp = tf.placeholder(tf.float64, shape) @@ -58,7 +58,7 @@ def test_mean(): @pytest.mark.tensorflow def test_addn(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession shape = [10, 11, 12, 13] inputs = [np.random.rand(*shape) for _ in range(10)] addn_test_0 = tf.add_n(inputs) @@ -81,7 +81,7 @@ def test_addn(): @pytest.mark.tensorflow def test_slice(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession t = tf.placeholder(tf.int32, [3, 2, 3]) b = tf.placeholder(tf.int32, [3]) s = tf.placeholder(tf.int32, [3]) diff --git a/tests/tensorflow/pool_test.py b/tests/tensorflow/pool_test.py index d9c1a8f4d4..b30b5f01fb 100644 --- a/tests/tensorflow/pool_test.py +++ b/tests/tensorflow/pool_test.py @@ -6,7 +6,7 @@ @pytest.mark.tensorflow def test_pooling(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession size_in = [1, 112, 112, 3] # size_in = [4, 4, 4, 4] np.random.seed(0) diff --git a/tests/tensorflow/simple_test.py b/tests/tensorflow/simple_test.py index 92917936a4..31abdb8513 100644 --- a/tests/tensorflow/simple_test.py +++ b/tests/tensorflow/simple_test.py @@ -6,7 +6,7 @@ @pytest.mark.tensorflow def test_simple(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession print('DaCe Tensorflow frontend test') A = np.random.rand(16, 16).astype(np.float32) diff --git a/tests/torch_forward/test_attn.py b/tests/torch_forward/test_attn.py new file mode 100644 index 0000000000..df673e057c --- /dev/null +++ b/tests/torch_forward/test_attn.py @@ -0,0 +1,39 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch + +from dace.ml import DaceModule + +from dace.transformation.dataflow import RedundantSecondArray +from tests.utils import torch_tensors_close + + +@pytest.mark.torch +def test_attn(use_cpp_dispatcher: bool): + B = 2 + H = 16 + P = 64 + N = P * H + SM, SN = 512, 512 + K, Q, V = [torch.randn([SM, B, N]), torch.randn([SN, B, N]), torch.randn([SM, B, N])] + ptmodel = torch.nn.MultiheadAttention(N, H, bias=False) + + pt_outputs = ptmodel(Q, K, V) + + dispatcher_suffix = "cpp" if use_cpp_dispatcher else "ctypes" + dace_model = DaceModule(ptmodel, + sdfg_name=f"test_attn_{dispatcher_suffix}", + compile_torch_extension=use_cpp_dispatcher, + auto_optimize=False) + + dace_outputs = dace_model(Q, K, V) + + torch_tensors_close("outputs_0", pt_outputs[0], dace_outputs[0]) + torch_tensors_close("outputs_1", pt_outputs[1], dace_outputs[1]) + + +if __name__ == "__main__": + test_attn(use_cpp_dispatcher=True) + test_attn(use_cpp_dispatcher=False) diff --git a/tests/torch_forward/test_conv2d.py b/tests/torch_forward/test_conv2d.py new file mode 100644 index 0000000000..89a21bc27e --- /dev/null +++ b/tests/torch_forward/test_conv2d.py @@ -0,0 +1,55 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +import dace +from dace.ml import DaceModule + + +@pytest.mark.torch +def test_conv2d(use_cpp_dispatcher: bool): + + class Model(nn.Module): + + def __init__(self): + super(Model, self).__init__() + self.conv1 = nn.Conv2d(1, 4, 3) + self.conv2 = nn.Conv2d(4, 4, 3) + + def forward(self, x): + x = F.relu(self.conv1(x)) + return F.relu(self.conv2(x)) + + ptmodel = Model() + x = torch.rand(1, 1, 8, 8) + + dispatcher_suffix = "cpp" if use_cpp_dispatcher else "ctypes" + + @dace.ml.module(sdfg_name=f"test_conv2d_decorator_{dispatcher_suffix}") + class TestDecorator(Model): + pass + + dace_model = DaceModule(ptmodel, + sdfg_name=f"test_conv2d_{dispatcher_suffix}", + compile_torch_extension=use_cpp_dispatcher) + dace_output = dace_model(x) + + dace_model_decorated = TestDecorator() + dace_model_decorated(x) + + torch_output = ptmodel(x) + + np.testing.assert_allclose(torch_output.detach().numpy(), + dace_output.detach().numpy(), + atol=1e-06, + err_msg="Conv2d output mismatch between PyTorch and DaCe") + + +if __name__ == "__main__": + test_conv2d(use_cpp_dispatcher=True) + test_conv2d(use_cpp_dispatcher=False) diff --git a/tests/torch_forward/test_cpp_extension.py b/tests/torch_forward/test_cpp_extension.py new file mode 100644 index 0000000000..c8b1b624f4 --- /dev/null +++ b/tests/torch_forward/test_cpp_extension.py @@ -0,0 +1,120 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import os + +import numpy as np +import torch +import torch.utils.cpp_extension +from dace.codegen import targets, compiler +from dace.codegen.codeobject import CodeObject +from torch import nn + +import dace +from dace.libraries.torch import PyTorch +from tests.utils import torch_tensors_close + +op_source = """ +#include +#include + +#include + +using torch::Tensor; +using torch::DeviceType; +using torch::autograd::tensor_list; +using torch::autograd::AutogradContext; + +Tensor myadd(const Tensor& self, const Tensor& other) { + static auto op = torch::Dispatcher::singleton() + .findSchemaOrThrow("myops::myadd", "") + .typed(); + return op.call(self, other); +} + +TORCH_LIBRARY(myops, m) { + m.def("myadd(Tensor self, Tensor other) -> Tensor"); +} + +Tensor myadd_cpu(const Tensor& self_, const Tensor& other_) { + TORCH_CHECK(self_.sizes() == other_.sizes()); + TORCH_INTERNAL_ASSERT(self_.device().type() == DeviceType::CPU); + TORCH_INTERNAL_ASSERT(other_.device().type() == DeviceType::CPU); + Tensor self = self_.contiguous(); + Tensor other = other_.contiguous(); + Tensor result = torch::empty(self.sizes(), self.options()); + const float* self_ptr = self.data_ptr(); + const float* other_ptr = other.data_ptr(); + float* result_ptr = result.data_ptr(); + for (int64_t i = 0; i < result.numel(); i++) { + result_ptr[i] = self_ptr[i] + other_ptr[i]; + } + return result; +} + +class MyAddFunction : public torch::autograd::Function { + public: + static Tensor forward( + AutogradContext *ctx, torch::Tensor self, torch::Tensor other) { + at::AutoDispatchBelowADInplaceOrView g; + return myadd(self, other); + } + + static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) { + auto grad_output = grad_outputs[0]; + return {grad_output, grad_output}; + } +}; + +Tensor myadd_autograd(const Tensor& self, const Tensor& other) { + return MyAddFunction::apply(self, other)[0]; +} + +TORCH_LIBRARY_IMPL(myops, CPU, m) { + m.impl("myadd", myadd_cpu); +} + +TORCH_LIBRARY_IMPL(myops, Autograd, m) { + m.impl("myadd", myadd_autograd); +} +""" + + +@pytest.mark.torch +def test_extension(): + program = CodeObject("myadd", + op_source, + "cpp", + targets.cpu.CPUCodeGen, + "MyAddFunction", + environments={PyTorch.full_class_path()}) + + BUILD_PATH = os.path.join('.dacecache', "pt_extension") + compiler.generate_program_folder(None, [program], BUILD_PATH) + torch.utils.cpp_extension.load( + name='pt_extension', + sources=[os.path.join(BUILD_PATH, 'src', 'cpu', 'myadd.cpp')], + is_python_module=False, + ) + torch.ops.myops.myadd(torch.randn(32, 32), torch.rand(32, 32)) + + +@pytest.mark.torch +def test_module_with_constant(): + + @dace.ml.module(sdfg_name="test_module_with_constant") + class Module(nn.Module): + + def forward(self, x): + return x + 1 + + inp = torch.ones((5, 5)) + output = Module()(inp) + + torch_tensors_close("output", inp + 1, output.cpu()) + + +if __name__ == "__main__": + test_extension() + test_module_with_constant() diff --git a/tests/torch_forward/test_debug_transients.py b/tests/torch_forward/test_debug_transients.py new file mode 100644 index 0000000000..995b986b8d --- /dev/null +++ b/tests/torch_forward/test_debug_transients.py @@ -0,0 +1,36 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +from torch import nn +import numpy as np + +import dace +from tests.utils import torch_tensors_close + + +@dace.ml.module(debug_transients=True, sdfg_name="test_debug_transients") +class Module(nn.Module): + + def forward(self, x): + y = x + 3 + return y * 5 + + +@pytest.mark.torch +def test_debug_transients(): + + module = Module() + + x = torch.rand(5, 5) + outputs = module(x) + output, y, y2 = outputs + + torch_tensors_close("output", (x + 3) * 5, output) + torch_tensors_close("y2", (x + 3) * 5, y2) + torch_tensors_close("y", x + 3, y) + + +if __name__ == "__main__": + test_debug_transients() diff --git a/tests/torch_forward/test_dlpack.py b/tests/torch_forward/test_dlpack.py new file mode 100644 index 0000000000..879417ea80 --- /dev/null +++ b/tests/torch_forward/test_dlpack.py @@ -0,0 +1,26 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import ctypes + +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import dace +import numpy as np + +from dace.libraries.torch.dlpack import array_to_torch_tensor + + +@pytest.mark.torch +def test_desc_to_dlpack(): + mydata = np.arange(6).reshape(2, 3).astype(np.float32) + + ptr = ctypes.c_void_p(mydata.__array_interface__["data"][0]) + tensor = array_to_torch_tensor(ptr, dace.float32[2, 3]) + np.testing.assert_allclose(tensor, mydata), "Initial DLPack tensor conversion failed" + mydata += 1 + np.testing.assert_allclose(tensor, mydata), "DLPack tensor does not share memory with numpy array" + + +if __name__ == "__main__": + test_desc_to_dlpack() diff --git a/tests/torch_forward/test_efficientnet_block.py b/tests/torch_forward/test_efficientnet_block.py new file mode 100644 index 0000000000..6401f561a6 --- /dev/null +++ b/tests/torch_forward/test_efficientnet_block.py @@ -0,0 +1,114 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +pytest.importorskip("efficientnet_pytorch", + reason="efficientnet_pytorch not installed. Please install with: pip install dace[ml-testing]") +import torch +import numpy as np +from dace.transformation.dataflow import TrivialMapElimination +from dace.transformation.interstate import HoistState +from efficientnet_pytorch import get_model_params +from efficientnet_pytorch.model import MBConvBlock + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +@pytest.mark.torch +def test_mbconv(use_cpp_dispatcher: bool): + + with torch.no_grad(): + dace_inputs = torch.rand(8, 32, 224, 224) + torch_inputs = torch.clone(dace_inputs) + + block_params, global_params = get_model_params("efficientnet-b0", {}) + + torch_model = MBConvBlock(block_params[0], global_params).eval() + torch_model.set_swish(memory_efficient=False) + dace_model = MBConvBlock(block_params[0], global_params).eval() + dace_model.set_swish(memory_efficient=False) + + # Get the DaceModule + sdfg_name = f"efficientnet_mbconv_{use_cpp_dispatcher}" + dace_model = DaceModule(dace_model, sdfg_name=sdfg_name, compile_torch_extension=use_cpp_dispatcher) + dace_model.model.load_state_dict(torch_model.state_dict()) + + for (dace_name, dace_value), (torch_name, value) in zip(dace_model.model.state_dict().items(), + torch_model.state_dict().items()): + assert dace_name == torch_name, f"Parameter name mismatch: {dace_name} != {torch_name}" + np.testing.assert_allclose(value, dace_value, err_msg=f"{dace_name} tensors do not match") + + dace_output = dace_model(dace_inputs) + + torch_output = torch_model(torch_inputs) + np.testing.assert_allclose(torch_output.detach(), + dace_output.detach(), + rtol=1e-3, + atol=1e-3, + err_msg="output tensors do not match") + + # check that the batch norm running means and so on are written out correctly + for (dace_name, dace_value), (torch_name, value) in zip(dace_model.model.state_dict().items(), + torch_model.state_dict().items()): + + assert dace_name == torch_name, f"Parameter name mismatch after inference: {dace_name} != {torch_name}" + if "num_batches_tracked" in dace_name: + # we don't update this parameter + continue + np.testing.assert_allclose(value, dace_value, err_msg=f"{dace_name} tensors do not match") + + +@pytest.mark.torch +def test_fast_mb(use_cpp_dispatcher: bool): + with torch.no_grad(): + dace_inputs = torch.rand(8, 32, 224, 224) + torch_inputs = torch.clone(dace_inputs) + + block_params, global_params = get_model_params("efficientnet-b0", {}) + + torch_model = MBConvBlock(block_params[0], global_params).eval() + torch_model.set_swish(memory_efficient=False) + dace_model = MBConvBlock(block_params[0], global_params).eval() + dace_model.set_swish(memory_efficient=False) + + # Get the DaceModule + sdfg_name = f"efficientnet_fast_mbconv_{use_cpp_dispatcher}" + dace_model = DaceModule(dace_model, sdfg_name=sdfg_name, compile_torch_extension=use_cpp_dispatcher) + dace_model.model.load_state_dict(torch_model.state_dict()) + + for (dace_name, dace_value), (torch_name, value) in zip(dace_model.model.state_dict().items(), + torch_model.state_dict().items()): + assert dace_name == torch_name, f"Parameter name mismatch: {dace_name} != {torch_name}" + torch_tensors_close(dace_name, value, dace_value) + + def fuse_everything(module: DaceModule): + sdfg = module.sdfg + + sdfg.apply_transformations_repeated(HoistState) + sdfg.apply_transformations_repeated(TrivialMapElimination) + sdfg.simplify() + + dace_model.append_post_onnx_hook("fuse_sg", fuse_everything) + + dace_output = dace_model(dace_inputs) + + torch_output = torch_model(torch_inputs) + torch_tensors_close("output", torch_output, dace_output, rtol=1e-3, atol=1e-3) + + # check that the batch norm running means and so on are written out correctly + for (dace_name, dace_value), (torch_name, value) in zip(dace_model.model.state_dict().items(), + torch_model.state_dict().items()): + + assert dace_name == torch_name, f"Parameter name mismatch after inference: {dace_name} != {torch_name}" + if "num_batches_tracked" in dace_name: + # we don't update this parameter + continue + torch_tensors_close(dace_name, value, dace_value) + + +if __name__ == "__main__": + test_mbconv(use_cpp_dispatcher=True) + test_mbconv(use_cpp_dispatcher=False) + test_fast_mb(use_cpp_dispatcher=True) + test_fast_mb(use_cpp_dispatcher=False) diff --git a/tests/torch_forward/test_img_op_implementations.py b/tests/torch_forward/test_img_op_implementations.py new file mode 100644 index 0000000000..4ec120b0f9 --- /dev/null +++ b/tests/torch_forward/test_img_op_implementations.py @@ -0,0 +1,95 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +from torch import nn +import numpy as np + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +class CustomBatchNorm(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, running_mean, running_var, weight, bias, training, momentum, eps): + output = torch.nn.functional.batch_norm(x, running_mean, running_var, weight, bias, training, momentum, eps) + return output, running_mean, running_var + + @staticmethod + def symbolic(g, x, running_mean, running_var, weight, bias, training, momentum, eps): + outputs = g.op("BatchNormalization", + x, + weight, + bias, + running_mean, + running_var, + training_mode_i=int(training), + momentum_f=momentum, + epsilon_f=eps, + outputs=3) + y, new_running_mean, new_running_var = outputs + return y, new_running_mean, new_running_var + + +class BatchNorm2dMeanVar(nn.Module): + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): + super(BatchNorm2dMeanVar, self).__init__() + self.bn = nn.BatchNorm2d(num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats) + + def forward(self, x): + return CustomBatchNorm.apply(x, self.bn.running_mean, self.bn.running_var, self.bn.weight, self.bn.bias, + self.bn.training, self.bn.momentum, self.bn.eps) + + +@pytest.mark.torch +def test_bn(): + + inputs = torch.rand(1, 64, 60, 60) + + # pytorch and onnx specification differ in the way they use momentum: + # pytorch_momentum = 1 - onnx_momentum + # to guarantee matching behavior, we set the momentum to 0.5 + + pt_model = BatchNorm2dMeanVar(64, momentum=0.5) + dace_model = BatchNorm2dMeanVar(64, momentum=0.5) + pt_model.train() + dace_model.train() + + dace_model.load_state_dict(pt_model.state_dict()) + + dace_model = DaceModule(dace_model, sdfg_name="test_bn", training=True) + dace_output, dace_mean, dace_var = dace_model(inputs) + pt_output, pt_mean, pt_var = pt_model(inputs) + + torch_tensors_close("output", pt_output, dace_output) + torch_tensors_close("mean", pt_mean, dace_mean) + torch_tensors_close("var", pt_var, dace_var) + + +@pytest.mark.torch +def test_global_avg_pool(): + inputs = torch.rand(1, 64, 60, 60) + + pt_model = nn.AdaptiveAvgPool2d(1) + dace_model = nn.AdaptiveAvgPool2d(1) + + # Note: AdaptiveAvgPool2d has no parameters, but load_state_dict ensures compatibility + dace_model.load_state_dict(pt_model.state_dict()) + + dace_model = DaceModule(dace_model, sdfg_name="test_global_avg_pool", training=True) + dace_output = dace_model(inputs) + pt_output = pt_model(inputs) + + torch_tensors_close("output", pt_output, dace_output) + + +if __name__ == "__main__": + test_bn() + test_global_avg_pool() diff --git a/tests/torch_forward/test_lenet.py b/tests/torch_forward/test_lenet.py new file mode 100644 index 0000000000..e6db37f740 --- /dev/null +++ b/tests/torch_forward/test_lenet.py @@ -0,0 +1,60 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +from dace.ml import DaceModule + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from tests.utils import torch_tensors_close + + +class LeNet(nn.Module): + + def __init__(self): + super(LeNet, self).__init__() + self.conv1 = nn.Conv2d(1, 6, (3, 3)) + self.conv2 = nn.Conv2d(6, 16, (3, 3)) + self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = F.max_pool2d(F.relu(self.conv1(x)), 2) + x = F.max_pool2d(F.relu(self.conv2(x)), 2) + + x = x.view(-1, 16 * 6 * 6) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + x = F.log_softmax(x, dim=1) + return x + + +@pytest.mark.torch +def test_lenet(use_cpp_dispatcher: bool): + + input = torch.rand(8, 1, 32, 32, dtype=torch.float32) + + net = LeNet() + dace_net = LeNet() + dace_net.load_state_dict(net.state_dict()) + dispatcher_suffix = "cpp" if use_cpp_dispatcher else "ctypes" + dace_net = DaceModule(dace_net, + sdfg_name=f"test_lenet_{dispatcher_suffix}", + compile_torch_extension=use_cpp_dispatcher) + + torch_output = net(torch.clone(input)) + dace_output = dace_net(torch.clone(input)) + dace_net.sdfg.expand_library_nodes() + + torch_tensors_close("output", torch_output, dace_output) + + +if __name__ == "__main__": + test_lenet(use_cpp_dispatcher=True) + test_lenet(use_cpp_dispatcher=False) diff --git a/tests/torch_forward/test_module_dace_program.py b/tests/torch_forward/test_module_dace_program.py new file mode 100644 index 0000000000..d4dfd0c363 --- /dev/null +++ b/tests/torch_forward/test_module_dace_program.py @@ -0,0 +1,63 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import numpy as np +import torch +from torch import nn + +import dace + +from dace.ml import DaceModule +from tests.utils import tensors_close, torch_tensors_close + + +@pytest.mark.torch +def test_parse_forward_simple(): + torch_module = torch.nn.Sequential(torch.nn.Linear(12, 24), torch.nn.Linear(24, 2)) + dace_module = DaceModule(torch_module, sdfg_name='test_parse_forward_simple') + x = torch.randn(2, 12) + expected = torch_module(x) + result = dace_module(x) + + torch_tensors_close('output', expected, result) + + @dace + def train_step(y): + # output is potentially a gpu tensor + output = dace_module(y) + cpu = np.empty_like(output) + cpu[:] = output + return cpu.sum() + + result = train_step(x) + tensors_close('parsed', expected.sum(), result) + + +@pytest.mark.torch +def test_parse_forward_nested(): + + torch_module = torch.nn.Sequential(torch.nn.Sequential(torch.nn.Linear(12, 24), torch.nn.Linear(24, 2)), + nn.Softmax(dim=1)) + dace_module2 = DaceModule(torch_module, sdfg_name='test_parse_forward_nested') + x = torch.randn(2, 12) + expected = torch_module(x) + result = dace_module2(x) + + torch_tensors_close('output', expected, result) + + @dace + def train_step(y): + output = dace_module2(y) + cpu = np.empty_like(output) + cpu[:] = output + return cpu.sum() + + result = train_step(x) + tensors_close('parsed', expected.sum(), result) + + +if __name__ == "__main__": + test_parse_forward_simple() + test_parse_forward_nested() diff --git a/tests/torch_forward/test_multi_output.py b/tests/torch_forward/test_multi_output.py new file mode 100644 index 0000000000..070902b8da --- /dev/null +++ b/tests/torch_forward/test_multi_output.py @@ -0,0 +1,44 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +from torch import nn + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +class Model(nn.Module): + + def __init__(self, new_shape): + super(Model, self).__init__() + self.new_shape = new_shape + + def forward(self, x): + return x + 1, x + 2 + + +@pytest.mark.torch +def test_multiple_outputs(use_cpp_dispatcher: bool): + + ptmodel = Model([5, 5]) + x = torch.rand([25]) + + torch_outputs = ptmodel(torch.clone(x)) + + dispatcher_suffix = "cpp" if use_cpp_dispatcher else "ctypes" + dace_model = DaceModule(ptmodel, + sdfg_name=f"test_multi_output_{dispatcher_suffix}", + auto_optimize=False, + compile_torch_extension=use_cpp_dispatcher) + + dace_outputs = dace_model(x) + + torch_tensors_close("output_0", torch_outputs[0], dace_outputs[0]) + torch_tensors_close("output_1", torch_outputs[1], dace_outputs[1]) + + +if __name__ == "__main__": + test_multiple_outputs(use_cpp_dispatcher=True) + test_multiple_outputs(use_cpp_dispatcher=False) diff --git a/tests/torch_forward/test_reshape.py b/tests/torch_forward/test_reshape.py new file mode 100644 index 0000000000..ad48d6c453 --- /dev/null +++ b/tests/torch_forward/test_reshape.py @@ -0,0 +1,38 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +from torch import nn +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +class Model(nn.Module): + + def __init__(self, new_shape): + super(Model, self).__init__() + self.new_shape = new_shape + + def forward(self, x): + x = x.reshape(self.new_shape) + return x + + +@pytest.mark.torch +def test_reshape_module(): + + ptmodel = Model([5, 5]) + x = torch.rand([25]) + + torch_output = ptmodel(torch.clone(x)) + + dace_model = DaceModule(ptmodel, sdfg_name="test_reshape_module", auto_optimize=False, dummy_inputs=(x, )) + + dace_output = dace_model(x) + + torch_tensors_close("output", torch_output, dace_output) + + +if __name__ == "__main__": + test_reshape_module() diff --git a/tests/transformations/interstate/loop_unroll_test.py b/tests/transformations/interstate/loop_unroll_test.py index 6d0ead3d80..d744361d36 100644 --- a/tests/transformations/interstate/loop_unroll_test.py +++ b/tests/transformations/interstate/loop_unroll_test.py @@ -1,5 +1,7 @@ # Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import re + import dace from dace.memlet import Memlet from dace.properties import CodeBlock @@ -80,7 +82,23 @@ def test_empty_loop(): assert len(loops) == 0 +def test_compiler_unroll_pragma(): + sdfg = _get_sdfg(add_state_before=False, l=5) + code = sdfg.generate_code()[0].clean_code + unroll_pragma = re.search(r'#pragma unroll', code) is None + assert unroll_pragma, "Unroll pragma found in generated code." + loops = {n for n in sdfg.all_control_flow_regions() if isinstance(n, LoopRegion)} + assert len(loops) == 1 + loop = next(iter(loops)) + loop.unroll = True + loop.unroll_factor = 5 + unrolled_loop_code = sdfg.generate_code()[0].clean_code + unroll_pragma = re.search(r'#pragma unroll 5', unrolled_loop_code) is not None + assert unroll_pragma, "Unroll pragma not found in generated code after setting unroll_pragma to True." + + if __name__ == "__main__": test_if_block_inside_for() test_empty_loop() test_top_level_for() + test_compiler_unroll_pragma() diff --git a/tests/transformations/mapfusion_fpga_test.py b/tests/transformations/mapfusion_fpga_test.py deleted file mode 100644 index 6d6447923a..0000000000 --- a/tests/transformations/mapfusion_fpga_test.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from dace.fpga_testing import fpga_test, xilinx_test -from dace.transformation.dataflow import MapFusionVertical -from dace.transformation.interstate import FPGATransformSDFG -from .map_fusion_vertical_test import multiple_fusions, fusion_with_transient -import numpy as np -from dace.config import set_temporary - - -@fpga_test() -def test_multiple_fusions_fpga(): - sdfg = multiple_fusions.to_sdfg() - sdfg.simplify() - assert sdfg.apply_transformations_repeated(MapFusionVertical) >= 2 - assert sdfg.apply_transformations_repeated(FPGATransformSDFG) == 1 - A = np.random.rand(10, 20).astype(np.float32) - B = np.zeros_like(A) - C = np.zeros_like(A) - out = np.zeros(shape=1, dtype=np.float32) - sdfg(A=A, B=B, C=C, out=out) - diff1 = np.linalg.norm(A * A + 1 - B) - diff2 = np.linalg.norm(A * A + 2 - C) - assert diff1 <= 1e-4 - assert diff2 <= 1e-4 - return sdfg - - -@fpga_test(assert_ii_1=False) -def test_fusion_with_transient_fpga(): - # To achieve II=1 with Xilinx, we need to decouple reads/writes from memory - A = np.random.rand(2, 20) - expected = A * A * 2 - sdfg = fusion_with_transient.to_sdfg() - sdfg.simplify() - assert sdfg.apply_transformations_repeated(MapFusionVertical) >= 2 - assert sdfg.apply_transformations_repeated(FPGATransformSDFG) == 1 - sdfg(A=A) - assert np.allclose(A, expected) - return sdfg - - -@xilinx_test(assert_ii_1=True) -def test_fusion_with_transient_fpga_decoupled(): - - A = np.random.rand(2, 20) - expected = A * A * 2 - sdfg = fusion_with_transient.to_sdfg() - sdfg.simplify() - assert sdfg.apply_transformations_repeated(MapFusionVertical) >= 2 - assert sdfg.apply_transformations_repeated(FPGATransformSDFG) == 1 - with set_temporary("compiler", "xilinx", "decouple_array_interfaces", value=True): - sdfg(A=A) - assert np.allclose(A, expected) - return sdfg - - -if __name__ == "__main__": - multiple_fusions_fpga(None) - fusion_with_transient_fpga(None) diff --git a/tests/type_inference_test.py b/tests/type_inference_test.py index 6cfc13b13d..f58fb42a33 100644 --- a/tests/type_inference_test.py +++ b/tests/type_inference_test.py @@ -3,7 +3,7 @@ import numpy as np import sympy as sp from dace.config import Config -from dace.codegen.tools import type_inference +from dace.sdfg import type_inference from dace import dtypes import ast diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000000..b4ddb24314 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,60 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +import os +import typing +import urllib.request, urllib.parse +import pathlib +import pytest +import dace +import numpy as np + + +def get_data_file(url, directory_name=None) -> str: + """ Get a data file from ``url``, cache it locally and return the local file path to it. + + :param url: the url to download from. + :param directory_name: an optional relative directory path where the file will be downloaded to. + :return: the path of the downloaded file. + """ + + data_directory = (pathlib.Path(dace.__file__).parent.parent / 'tests' / 'data') + + if directory_name is not None: + data_directory /= directory_name + + data_directory.mkdir(exist_ok=True, parents=True) + + file_name = os.path.basename(urllib.parse.urlparse(url).path) + file_path = str(data_directory / file_name) + + if not os.path.exists(file_path): + urllib.request.urlretrieve(url, file_path) + return file_path + + +def tensors_close(name, expected, result, rtol=1e-5, atol=1e-5): + + def to_numpy(x): + if hasattr(x, 'detach'): + x = x.detach() + if hasattr(x, 'cpu'): + x = x.cpu() + if hasattr(x, 'numpy'): + x = x.numpy() + return x + + expected = to_numpy(expected) + result = to_numpy(result) + np.testing.assert_allclose(result, expected, rtol=rtol, atol=atol, err_msg=f'{name} not close') + + +def torch_tensors_close(name, torch_v, dace_v, rtol=1e-5, atol=1e-4): + """ + Assert that the two torch tensors are close. Prints a nice error string if not. + """ + # check that the device is correct + assert torch_v.device == dace_v.device, "Tensors are on different devices" + + torch_v = torch_v.detach().cpu().numpy() + dace_v = dace_v.detach().cpu().numpy() + np.testing.assert_allclose(dace_v, torch_v, rtol=rtol, atol=atol, err_msg=f'{name} not close') diff --git a/tutorials/codegen.ipynb b/tutorials/codegen.ipynb index 84b2cf7f01..e3909a544c 100644 --- a/tutorials/codegen.ipynb +++ b/tutorials/codegen.ipynb @@ -47,7 +47,7 @@ "from dace import registry\n", "from dace.sdfg.scope import ScopeSubgraphView\n", "from dace.codegen.prettycode import CodeIOStream\n", - "from dace.codegen.targets.target import TargetCodeGenerator\n", + "from dace.codegen.target import TargetCodeGenerator\n", "from dace.codegen.targets.framecode import DaCeCodeGenerator\n", "from dace.codegen.targets.cpp import sym2cpp" ]