From 2aed27c8fe5b32ed46dbfb579c2eddebb7882aa6 Mon Sep 17 00:00:00 2001 From: Vidushi Vashishth Date: Fri, 13 Mar 2026 17:56:25 -0700 Subject: [PATCH 1/7] Adding configuration and initial fixes for pylint --- .pylintrc | 415 ++++++++ Ironwood/src/benchmark_attention.py | 17 +- Ironwood/src/benchmark_collectives.py | 1023 ++++++++++--------- Ironwood/src/benchmark_compute.py | 46 +- Ironwood/src/benchmark_gemm.py | 77 +- Ironwood/src/benchmark_gemm_numerics.py | 58 +- Ironwood/src/benchmark_gemm_throttling.py | 160 +-- Ironwood/src/benchmark_hbm.py | 6 +- Ironwood/src/benchmark_host_device.py | 61 +- Ironwood/src/benchmark_inference_compute.py | 36 +- Ironwood/src/benchmark_send_recv.py | 115 ++- Ironwood/src/benchmark_utils.py | 367 ++++--- Ironwood/src/common.py | 2 + Ironwood/src/run_benchmark.py | 70 +- requirements.txt | 4 + src/all_gather.py | 267 ++--- src/all_reduce.py | 261 ++--- src/benchmark_attention.py | 18 +- src/benchmark_collectives.py | 460 ++++++--- src/benchmark_convolution.py | 32 +- src/benchmark_hbm.py | 4 +- src/benchmark_matmul.py | 129 ++- src/benchmark_utils.py | 50 +- src/run_benchmark.py | 86 +- 24 files changed, 2400 insertions(+), 1364 deletions(-) create mode 100644 .pylintrc diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..70c1cc1 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,415 @@ +# This Pylint rcfile contains a best-effort configuration to uphold the +# best-practices and style described in the Google Python style guide: +# https://google.github.io/styleguide/pyguide.html +# +# Its canonical open-source location is: +# https://google.github.io/styleguide/pylintrc + +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +[MAIN] + +# Files or directories to be skipped. They should be base names, not paths. +ignore=third_party + +# Files or directories matching the regex patterns are skipped. The regex +# matches against base names, not paths. +ignore-patterns= + +# Pickle collected data for later comparisons. +persistent=no + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=4 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +#enable= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. 
For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=R, + abstract-method, + apply-builtin, + arguments-differ, + attribute-defined-outside-init, + backtick, + bad-option-value, + basestring-builtin, + buffer-builtin, + c-extension-no-member, + consider-using-enumerate, + cmp-builtin, + cmp-method, + coerce-builtin, + coerce-method, + delslice-method, + div-method, + eq-without-hash, + execfile-builtin, + file-builtin, + filter-builtin-not-iterating, + fixme, + getslice-method, + global-statement, + hex-method, + idiv-method, + implicit-str-concat, + import-error, + import-self, + import-star-module-level, + input-builtin, + intern-builtin, + invalid-str-codec, + locally-disabled, + long-builtin, + long-suffix, + map-builtin-not-iterating, + misplaced-comparison-constant, + missing-function-docstring, + metaclass-assignment, + next-method-called, + next-method-defined, + no-absolute-import, + no-init, # added + no-member, + no-name-in-module, + no-self-use, + nonzero-method, + oct-method, + old-division, + old-ne-operator, + old-octal-literal, + old-raise-syntax, + parameter-unpacking, + print-statement, + raising-string, + range-builtin-not-iterating, + raw_input-builtin, + rdiv-method, + reduce-builtin, + relative-import, + reload-builtin, + round-builtin, + setslice-method, + signature-differs, + standarderror-builtin, + suppressed-message, + sys-max-int, + trailing-newlines, + unichr-builtin, + unicode-builtin, + unnecessary-pass, + unpacking-in-except, + useless-else-on-loop, + useless-suppression, + using-cmp-argument, + wrong-import-order, + xrange-builtin, + zip-builtin-not-iterating, + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. 
You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages +reports=no + +# Activate the evaluation score. +score=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[BASIC] + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. 
+property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl + +# Regular expression matching correct function names +function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression matching correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct constant names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression matching correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression matching correct module names +module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ + +# Regular expression matching correct method names +method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=12 + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. 
+contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=80 + +# TODO(https://github.com/pylint-dev/pylint/issues/3352): Direct pylint to exempt +# lines made too long by directives to pytype. + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x)( + ^\s*(\#\ )??$| + ^\s*(from\s+\S+\s+)?import\s+.+$) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=yes + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. The internal Google style guide mandates 2 +# spaces. Google's externaly-published style guide says 4, consistent with +# PEP 8. +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. 
+notes=TODO + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=yes + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging,absl.logging,tensorflow.io.logging + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. 
+spelling-store-unknown-words=no + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub, + TERMIOS, + Bastion, + rexec, + sets + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant, absl + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls, + class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs diff --git a/Ironwood/src/benchmark_attention.py b/Ironwood/src/benchmark_attention.py index 95ae524..3e16748 100644 --- a/Ironwood/src/benchmark_attention.py +++ b/Ironwood/src/benchmark_attention.py @@ -1,6 +1,4 @@ -"""A script to benchmark tokamax splash attention implementation. 
- -""" +"""A script to benchmark tokamax splash attention implementation.""" import os @@ -19,13 +17,13 @@ splash_attention_mask as mask_lib, ) import tune_jax + tune_jax.tune_logger.setLevel(logging.ERROR) # pylint: disable=g-importing-member,g-bad-import-order -os.environ["LIBTPU_INIT_ARGS"] = ( - "--xla_tpu_dvfs_p_state=7" -) +os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_dvfs_p_state=7" + def generate_qkv_separate_dims( batch_size: int, @@ -41,7 +39,9 @@ def generate_qkv_separate_dims( key = jax.random.PRNGKey(seed) key_q, key_k, key_v = jax.random.split(key, 3) q = jax.random.normal(key_q, (batch_size, q_heads, q_seq_len, qk_head_dim)) - k = jax.random.normal(key_k, (batch_size, kv_heads, kv_seq_len, qk_head_dim)) + k = jax.random.normal( + key_k, (batch_size, kv_heads, kv_seq_len, qk_head_dim) + ) v = jax.random.normal(key_v, (batch_size, kv_heads, kv_seq_len, v_head_dim)) return q, k, v @@ -87,6 +87,7 @@ def f(q, k, v): kernel_ = jax.vmap(kernel, in_axes=(0, 0, 0)) # batch vmap kernel_ = jax.vmap(kernel_, in_axes=(0, 0, 0)) # mqa vmap return kernel_(q, k, v) + else: kernel = splash.make_splash_mha_single_device(mask, config=config) f = jax.jit(jax.vmap(kernel, in_axes=(0, 0, 0))) @@ -263,7 +264,7 @@ def attention_fn( trace_dir=trace_dir, event_name_str_list=[ f"{event_filter_regex}_no_residuals.1", - ] + ], ) return {"time_ms_list": time_ms_list, "output": output} diff --git a/Ironwood/src/benchmark_collectives.py b/Ironwood/src/benchmark_collectives.py index 0889038..7c97afd 100644 --- a/Ironwood/src/benchmark_collectives.py +++ b/Ironwood/src/benchmark_collectives.py @@ -30,51 +30,54 @@ GLOBAL_PSTATE = 7 LOG_SPARSECORE_USAGE = False + def create_mesh(ici_size: int, mesh_shape: str) -> Mesh: - """Creates a mesh with the given ICI size.""" - devices_needed = ici_size - devices = jax.devices() + """Creates a mesh with the given ICI size.""" + devices_needed = ici_size + devices = jax.devices() - if len(devices) < devices_needed: - raise ValueError(f"Need 
{devices_needed} devices, but found {len(devices)}") - devices = devices[:devices_needed] - mesh_shape = mesh_shape.split("x") - mesh_shape = [int(i) for i in mesh_shape] + if len(devices) < devices_needed: + raise ValueError( + f"Need {devices_needed} devices, but found {len(devices)}" + ) + devices = devices[:devices_needed] + mesh_shape = mesh_shape.split("x") + mesh_shape = [int(i) for i in mesh_shape] - shape = mesh_shape if mesh_shape else (ici_size,) + shape = mesh_shape if mesh_shape else (ici_size,) - axis_names = [f"d_{i}" for i in range(len(shape))] + axis_names = [f"d_{i}" for i in range(len(shape))] - first_device = devices[0] - device_kind = first_device.device_kind - print("Device kind: ", device_kind) - mesh_devices = mesh_utils.create_device_mesh(shape, devices=jax.devices()) - mesh = Mesh(mesh_devices, axis_names) - return mesh + first_device = devices[0] + device_kind = first_device.device_kind + print("Device kind: ", device_kind) + mesh_devices = mesh_utils.create_device_mesh(shape, devices=jax.devices()) + mesh = Mesh(mesh_devices, axis_names) + return mesh def get_sharding_axis(dim_str: str, mesh: Mesh) -> tuple[str, ...]: - """Computes sharding axis names from dimension string like '1x4' and mesh.""" - dim_tuple = dim_str.split("x") - dim_tuple = tuple(int(dim) for dim in dim_tuple) - sharding_axis = tuple( - name for i, name in enumerate(mesh.axis_names) if dim_tuple[i] > 1 - ) - return sharding_axis + """Computes sharding axis names from dimension string like '1x4' and mesh.""" + dim_tuple = dim_str.split("x") + dim_tuple = tuple(int(dim) for dim in dim_tuple) + sharding_axis = tuple( + name for i, name in enumerate(mesh.axis_names) if dim_tuple[i] > 1 + ) + return sharding_axis def get_metrics_helper( params: Dict[str, Any], ) -> Dict[str, Any]: - """Helper function to build the metrics and metadata for the benchmark.""" - exclude_keys = ["ici_average_time_ms" ,"xla_output"] - metadata = { - key: value - for key, value in params - if 
value is not None and key not in exclude_keys - } - metadata["dtype"] = get_real_dtype_bytes(metadata["dtype"].dtype) - return metadata + """Helper function to build the metrics and metadata for the benchmark.""" + exclude_keys = ["ici_average_time_ms", "xla_output"] + metadata = { + key: value + for key, value in params + if value is not None and key not in exclude_keys + } + metadata["dtype"] = get_real_dtype_bytes(metadata["dtype"].dtype) + return metadata def unified_ici_collectives_metrics( @@ -89,105 +92,105 @@ def unified_ici_collectives_metrics( op_type: str, trace_dir: str = None, ) -> Dict[str, Any]: - """Calculates the metrics for the ICI collectives benchmark.""" - + """Calculates the metrics for the ICI collectives benchmark.""" - average_time_ms_statistics = MetricsStatistics( + average_time_ms_statistics = MetricsStatistics( metrics_list=ici_average_time_ms_list, metrics_name="step_time_ms" ) - hlo_input_shape = hlo_output_shape = hlo_replica_groups = None - hlo_first_replica_group = [] - - input_num_elements = matrix_shape[0] * matrix_shape[1] * matrix_shape[2] - dtype_name = dtype.dtype.name - dtype_bytes = get_real_dtype_bytes(dtype.dtype) - if xla_output: - xla_output_json = json.loads(xla_output) - hlo_input_shape = xla_output_json.get("hlo_input_shape") - hlo_output_shape = xla_output_json.get("hlo_output_shape") - hlo_replica_groups = xla_output_json.get("hlo_replica_groups") - hlo_first_replica_group = xla_output_json.get("hlo_first_replica_group") - - rank = max(len(hlo_first_replica_group), 1) - - if all(i % 2 == 0 for i in hlo_first_replica_group): - replica_group_type = "parallel" - else: - replica_group_type = "non-parallel" - - if replica_group_type == "parallel": - participating_ranks = rank - 1 - tf_multiplier = 2 - else: - participating_ranks = rank - 2 - tf_multiplier = 1 - - transferred_data = 0 - if op_type == "AG": - transferred_data = ( - input_num_elements - * participating_ranks - * dtype_bytes - * 0.000000001 - * 
tf_multiplier - ) - elif op_type == "AR": - transferred_data = ( - input_num_elements - * participating_ranks - * dtype_bytes - * 0.000000001 - * tf_multiplier - * 2 - /rank - ) - elif op_type in ["RS", "A2A"]: - transferred_data = ( - input_num_elements - * participating_ranks - * dtype_bytes - * 0.000000001 - * tf_multiplier - / rank - ) - - - sparsecore_used = "NA" - if LOG_SPARSECORE_USAGE: - print("trace_dir: ", trace_dir) - if trace_dir: - sparsecore_used = find_sparsecore_usage_from_xplane(trace_dir) - print("sparsecore_used: ", sparsecore_used) - print("hlo first replica group: ", hlo_first_replica_group) - - metadata = { - "iteration": iteration, - "op_type": op_type, - "replica_group_type": replica_group_type, - "rank": rank, - "mesh_shape": mesh_shape, - "op_dimension": op_dimension, - "sharding_strategy": sharding_strategy, - "input_num_elements": input_num_elements, - "matrix_shape": json.dumps(f"({matrix_shape})"), - "transferred_data (GB)": transferred_data, - "dtype_bytes": dtype_bytes, - "hlo_input_shape": json.dumps(hlo_input_shape), - "hlo_output_shape": json.dumps(hlo_output_shape), - "hlo_replica_groups": json.dumps(hlo_replica_groups), - "sparsecore_used": sparsecore_used, - } - achieved_bw = [transferred_data*1000/my_time for my_time in ici_average_time_ms_list] - achieved_bw_statistics = MetricsStatistics( + hlo_input_shape = hlo_output_shape = hlo_replica_groups = None + hlo_first_replica_group = [] + + input_num_elements = matrix_shape[0] * matrix_shape[1] * matrix_shape[2] + dtype_bytes = get_real_dtype_bytes(dtype.dtype) + if xla_output: + xla_output_json = json.loads(xla_output) + hlo_input_shape = xla_output_json.get("hlo_input_shape") + hlo_output_shape = xla_output_json.get("hlo_output_shape") + hlo_replica_groups = xla_output_json.get("hlo_replica_groups") + hlo_first_replica_group = xla_output_json.get("hlo_first_replica_group") + + rank = max(len(hlo_first_replica_group), 1) + + if all(i % 2 == 0 for i in hlo_first_replica_group): 
+ replica_group_type = "parallel" + else: + replica_group_type = "non-parallel" + + if replica_group_type == "parallel": + participating_ranks = rank - 1 + tf_multiplier = 2 + else: + participating_ranks = rank - 2 + tf_multiplier = 1 + + transferred_data = 0 + if op_type == "AG": + transferred_data = ( + input_num_elements + * participating_ranks + * dtype_bytes + * 0.000000001 + * tf_multiplier + ) + elif op_type == "AR": + transferred_data = ( + input_num_elements + * participating_ranks + * dtype_bytes + * 0.000000001 + * tf_multiplier + * 2 + / rank + ) + elif op_type in ["RS", "A2A"]: + transferred_data = ( + input_num_elements + * participating_ranks + * dtype_bytes + * 0.000000001 + * tf_multiplier + / rank + ) + + sparsecore_used = "NA" + if LOG_SPARSECORE_USAGE: + print("trace_dir: ", trace_dir) + if trace_dir: + sparsecore_used = find_sparsecore_usage_from_xplane(trace_dir) + print("sparsecore_used: ", sparsecore_used) + print("hlo first replica group: ", hlo_first_replica_group) + + metadata = { + "iteration": iteration, + "op_type": op_type, + "replica_group_type": replica_group_type, + "rank": rank, + "mesh_shape": mesh_shape, + "op_dimension": op_dimension, + "sharding_strategy": sharding_strategy, + "input_num_elements": input_num_elements, + "matrix_shape": json.dumps(f"({matrix_shape})"), + "transferred_data (GB)": transferred_data, + "dtype_bytes": dtype_bytes, + "hlo_input_shape": json.dumps(hlo_input_shape), + "hlo_output_shape": json.dumps(hlo_output_shape), + "hlo_replica_groups": json.dumps(hlo_replica_groups), + "sparsecore_used": sparsecore_used, + } + achieved_bw = [ + transferred_data * 1000 / my_time + for my_time in ici_average_time_ms_list + ] + achieved_bw_statistics = MetricsStatistics( metrics_list=achieved_bw, metrics_name="achieved_bw (GB/s)" ) - metrics = {} - metrics.update(average_time_ms_statistics.serialize_statistics()) - metrics.update(achieved_bw_statistics.serialize_statistics()) + metrics = {} + 
metrics.update(average_time_ms_statistics.serialize_statistics()) + metrics.update(achieved_bw_statistics.serialize_statistics()) - print("metadata: ", metadata) - print("metrics: ", metrics) - return metadata, metrics + print("metadata: ", metadata) + print("metrics: ", metrics) + return metadata, metrics def psum_benchmark( @@ -200,114 +203,114 @@ def psum_benchmark( num_runs: int = 1, trace_dir: str = None, ) -> Dict[str, Any]: - """Benchmarks the psum collective operation. - - Args: - matrix_dim: The benchmark is run on a matrix with shape (matrix_dim, - matrix_dim). - mesh_shape: The shape of the mesh. - op_dimension: The dimension of the operation. - ici_size: The number of chips in a single slice. If 1, then no ICI benchmark - is run. - dtype: The data type of the matrix. - num_runs: The number of runs to perform. - trace_dir: The directory to save the trace to. - - Returns: - The measured time for the ICI benchmark. - """ - - libtpu_init_args = [ - "--xla_jf_debug_level=3", - "--xla_sc_disable_megacore_partitioning=true", - "--xla_tpu_disable_sparse_core_collective_offload_remover=true", - "--xla_tpu_enable_all_reduce_offload_tracing=true", - "--xla_tpu_enable_all_reduce_scatter_fusion=false", - "--xla_tpu_enable_sparse_core_collective_offload_all_reduce=true", - "--xla_tpu_pad_operations_input_tiles=true", - "--xla_tpu_sparse_core_all_reduce_offload_min_size_in_bytes=0", - "--xla_tpu_use_tc_device_shape_on_sc=true", - f"--xla_tpu_dvfs_p_state={GLOBAL_PSTATE}", - ] - os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args) - mesh = create_mesh(ici_size, mesh_shape) - key = jax.random.key(SEED) - lhs_sharding = get_lhs_named_shading(mesh, GLOBAL_SHARDING_STRATEGY) - out_sharding = get_out_sharding(GLOBAL_SHARDING_STRATEGY) - sharding_axis = get_sharding_axis(sharding_strategy, mesh) - - # 1. Define the Primitive - zero_crop_p = Primitive("zero_crop") - - # 2. 
Implement Abstract Evaluation (output shape/dtype is same as input) - def zero_crop_abstract_eval(x): - return core.ShapedArray(x.shape, x.dtype) - - zero_crop_p.def_abstract_eval(zero_crop_abstract_eval) - - # 3. Implement the Lowering Rule using jax.ffi - def zero_crop_lowering(ctx, x): - (aval_in,) = ctx.avals_in - (aval_out,) = ctx.avals_out - - return ffi.ffi_lowering( - "ZeroCrop", - operands=[x], - operand_layouts=mlir.default_layouts(ctx, aval_in), - result_layouts=mlir.default_layouts(ctx, aval_out), - )(ctx, x) - - mlir.register_lowering(zero_crop_p, zero_crop_lowering) - - # 4. Create a Python Wrapper using jax.ffi.ffi_call - def zero_crop(x): - return ffi.ffi_call( - "ZeroCrop", - result_shape_dtypes=jax.ShapeDtypeStruct(x.shape, x.dtype), - has_side_effect=True, - )(x) - - def f(x): - with jax.named_scope(MARKER): - y = jax.lax.psum(x, sharding_axis) - # Insert the custom call to prevent y from being a live out buffer - return zero_crop(y) - - jit_sharded_f = jax.jit( - shard_map( - f, - mesh, - in_specs=lhs_sharding.spec, - out_specs=out_sharding, - check_rep=False, - ) - ) - m = matrix_dim - n = BASE_SHAPE[1] - k = BASE_SHAPE[2] - - def data_generator(): - """Creates new random data on host and puts it on device.""" - nonlocal key # Use and update the outer 'key' - - matrix = jnp.ones((m, n, k), dtype=dtype) - return (matrix,) - - print("Running psum benchmark", num_runs, matrix_dim) - time_ms_list = multiple_iteration_timeit_from_trace( - jit_sharded_f, - data_generator, - matrix_dim=f"{m}x{n}x{k}", - tries=num_runs, - task="psum_ici_op", - trace_dir=trace_dir, - ) - return { - "ici_average_time_ms_list": time_ms_list, - "matrix_shape": (m, n, k), - "op_type": "AR", - "trace_dir": trace_dir, - } + """Benchmarks the psum collective operation. + + Args: + matrix_dim: The benchmark is run on a matrix with shape (matrix_dim, + matrix_dim). + mesh_shape: The shape of the mesh. + op_dimension: The dimension of the operation. 
+ ici_size: The number of chips in a single slice. If 1, then no ICI benchmark + is run. + dtype: The data type of the matrix. + num_runs: The number of runs to perform. + trace_dir: The directory to save the trace to. + + Returns: + The measured time for the ICI benchmark. + """ + + libtpu_init_args = [ + "--xla_jf_debug_level=3", + "--xla_sc_disable_megacore_partitioning=true", + "--xla_tpu_disable_sparse_core_collective_offload_remover=true", + "--xla_tpu_enable_all_reduce_offload_tracing=true", + "--xla_tpu_enable_all_reduce_scatter_fusion=false", + "--xla_tpu_enable_sparse_core_collective_offload_all_reduce=true", + "--xla_tpu_pad_operations_input_tiles=true", + "--xla_tpu_sparse_core_all_reduce_offload_min_size_in_bytes=0", + "--xla_tpu_use_tc_device_shape_on_sc=true", + f"--xla_tpu_dvfs_p_state={GLOBAL_PSTATE}", + ] + os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args) + mesh = create_mesh(ici_size, mesh_shape) + key = jax.random.key(SEED) + lhs_sharding = get_lhs_named_shading(mesh, GLOBAL_SHARDING_STRATEGY) + out_sharding = get_out_sharding(GLOBAL_SHARDING_STRATEGY) + sharding_axis = get_sharding_axis(sharding_strategy, mesh) + + # 1. Define the Primitive + zero_crop_p = Primitive("zero_crop") + + # 2. Implement Abstract Evaluation (output shape/dtype is same as input) + def zero_crop_abstract_eval(x): + return core.ShapedArray(x.shape, x.dtype) + + zero_crop_p.def_abstract_eval(zero_crop_abstract_eval) + + # 3. Implement the Lowering Rule using jax.ffi + def zero_crop_lowering(ctx, x): + (aval_in,) = ctx.avals_in + (aval_out,) = ctx.avals_out + + return ffi.ffi_lowering( + "ZeroCrop", + operands=[x], + operand_layouts=mlir.default_layouts(ctx, aval_in), + result_layouts=mlir.default_layouts(ctx, aval_out), + )(ctx, x) + + mlir.register_lowering(zero_crop_p, zero_crop_lowering) + + # 4. 
Create a Python Wrapper using jax.ffi.ffi_call + def zero_crop(x): + return ffi.ffi_call( + "ZeroCrop", + result_shape_dtypes=jax.ShapeDtypeStruct(x.shape, x.dtype), + has_side_effect=True, + )(x) + + def f(x): + with jax.named_scope(MARKER): + y = jax.lax.psum(x, sharding_axis) + # Insert the custom call to prevent y from being a live out buffer + return zero_crop(y) + + jit_sharded_f = jax.jit( + shard_map( + f, + mesh, + in_specs=lhs_sharding.spec, + out_specs=out_sharding, + check_rep=False, + ) + ) + m = matrix_dim + n = BASE_SHAPE[1] + k = BASE_SHAPE[2] + + def data_generator(): + """Creates new random data on host and puts it on device.""" + nonlocal key # Use and update the outer 'key' + + matrix = jnp.ones((m, n, k), dtype=dtype) + return (matrix,) + + print("Running psum benchmark", num_runs, matrix_dim) + time_ms_list = multiple_iteration_timeit_from_trace( + jit_sharded_f, + data_generator, + matrix_dim=f"{m}x{n}x{k}", + tries=num_runs, + task="psum_ici_op", + trace_dir=trace_dir, + ) + return { + "ici_average_time_ms_list": time_ms_list, + "matrix_shape": (m, n, k), + "op_type": "AR", + "trace_dir": trace_dir, + } def psum_benchmark_calculate_metrics( @@ -323,21 +326,22 @@ def psum_benchmark_calculate_metrics( op_type: str, trace_dir: str, ) -> Dict[str, Any]: - """Calculates the metrics for the psum benchmark.""" - # Build dictionary of all the parameters in the function - - return unified_ici_collectives_metrics( - xla_output, - matrix_shape, - dtype, - mesh_shape, - op_dimension, - sharding_strategy, - ici_average_time_ms_list, - matrix_dim, - op_type, - trace_dir, - ) + # pylint: disable=unused-argument + """Calculates the metrics for the psum benchmark.""" + # Build dictionary of all the parameters in the function + + return unified_ici_collectives_metrics( + xla_output, + matrix_shape, + dtype, + mesh_shape, + op_dimension, + sharding_strategy, + ici_average_time_ms_list, + matrix_dim, + op_type, + trace_dir, + ) def psum_scatter_benchmark( @@ 
-350,79 +354,79 @@ def psum_scatter_benchmark( num_runs: int = 1, trace_dir: str = None, ) -> Dict[str, Any]: - """Benchmarks the psum_scatter collective operation. - - Args: - matrix_dim: The benchmark is run on a matrix with shape (matrix_dim, - matrix_dim). - dtype: The data type of the matrix. - ici_size: The number of chips in a single slice. If 1, then no ICI benchmark - is run. - mesh_shape: The shape of the mesh. - op_dimension: The dimension of the operation. - sharding_strategy: The sharding strategy of the operation. - num_runs: The number of runs to perform. - trace_dir: The directory to save the trace to. - - Returns: - The measured time for the ICI benchmark. - """ - libtpu_init_args = [ - "--xla_jf_debug_level=3", - "--xla_sc_disable_megacore_partitioning=true", - "--xla_tpu_disable_sparse_core_collective_offload_remover=true", - "--xla_tpu_enable_reduce_scatter_offload_tracing=true", - "--xla_tpu_enable_sparse_core_collective_offload_nd_reduce_scatter=true", - "--xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true", - "--xla_tpu_enable_sparse_core_reduce_scatter_v2=true", - "--xla_tpu_use_tc_device_shape_on_sc=true", - f"--xla_tpu_dvfs_p_state={GLOBAL_PSTATE}", - ] - os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args) - mesh = create_mesh(ici_size, mesh_shape) - - sharding_axis = get_sharding_axis(sharding_strategy, mesh) - - def f(x): - with jax.named_scope(MARKER): - return jax.lax.psum_scatter(x, sharding_axis, tiled=True) - - jit_sharded_f = jax.jit( - shard_map( - f, - mesh=mesh, - in_specs=P(None, None, None), - out_specs=P(sharding_axis, None, None), - check_rep=False, - ) - ) - sharding_strategy_tuple = tuple(map(int, sharding_strategy.split("x"))) - op_dimension_tuple_multiplier = math.prod(sharding_strategy_tuple) - m = op_dimension_tuple_multiplier - n = matrix_dim - k = 256 - - def data_generator(): - """Creates new random data on host and puts it on device.""" - matrix = jnp.ones((m, n, k), dtype=dtype) - return 
(matrix,) - - time_ms_list = multiple_iteration_timeit_from_trace( - jit_sharded_f, - data_generator, - matrix_dim=f"{m}x{n}x{k}", - tries=num_runs, - task="psum_scatter_ici_op", - trace_dir=trace_dir, - ) - print("Running psum_scatter benchmark", num_runs, matrix_dim) - print("Matrix shape: ", m, n, k) - return { - "ici_average_time_ms_list": time_ms_list, - "matrix_shape": (m, n, k), - "op_type": "RS", - "trace_dir": trace_dir, - } + """Benchmarks the psum_scatter collective operation. + + Args: + matrix_dim: The benchmark is run on a matrix with shape (matrix_dim, + matrix_dim). + dtype: The data type of the matrix. + ici_size: The number of chips in a single slice. If 1, then no ICI benchmark + is run. + mesh_shape: The shape of the mesh. + op_dimension: The dimension of the operation. + sharding_strategy: The sharding strategy of the operation. + num_runs: The number of runs to perform. + trace_dir: The directory to save the trace to. + + Returns: + The measured time for the ICI benchmark. 
+ """ + libtpu_init_args = [ + "--xla_jf_debug_level=3", + "--xla_sc_disable_megacore_partitioning=true", + "--xla_tpu_disable_sparse_core_collective_offload_remover=true", + "--xla_tpu_enable_reduce_scatter_offload_tracing=true", + "--xla_tpu_enable_sparse_core_collective_offload_nd_reduce_scatter=true", + "--xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true", + "--xla_tpu_enable_sparse_core_reduce_scatter_v2=true", + "--xla_tpu_use_tc_device_shape_on_sc=true", + f"--xla_tpu_dvfs_p_state={GLOBAL_PSTATE}", + ] + os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args) + mesh = create_mesh(ici_size, mesh_shape) + + sharding_axis = get_sharding_axis(sharding_strategy, mesh) + + def f(x): + with jax.named_scope(MARKER): + return jax.lax.psum_scatter(x, sharding_axis, tiled=True) + + jit_sharded_f = jax.jit( + shard_map( + f, + mesh=mesh, + in_specs=P(None, None, None), + out_specs=P(sharding_axis, None, None), + check_rep=False, + ) + ) + sharding_strategy_tuple = tuple(map(int, sharding_strategy.split("x"))) + op_dimension_tuple_multiplier = math.prod(sharding_strategy_tuple) + m = op_dimension_tuple_multiplier + n = matrix_dim + k = 256 + + def data_generator(): + """Creates new random data on host and puts it on device.""" + matrix = jnp.ones((m, n, k), dtype=dtype) + return (matrix,) + + time_ms_list = multiple_iteration_timeit_from_trace( + jit_sharded_f, + data_generator, + matrix_dim=f"{m}x{n}x{k}", + tries=num_runs, + task="psum_scatter_ici_op", + trace_dir=trace_dir, + ) + print("Running psum_scatter benchmark", num_runs, matrix_dim) + print("Matrix shape: ", m, n, k) + return { + "ici_average_time_ms_list": time_ms_list, + "matrix_shape": (m, n, k), + "op_type": "RS", + "trace_dir": trace_dir, + } def psum_scatter_benchmark_calculate_metrics( @@ -438,21 +442,23 @@ def psum_scatter_benchmark_calculate_metrics( op_type: str, trace_dir: str, ) -> Dict[str, Any]: - """Calculates the metrics for the psum_scatter benchmark.""" - # Build 
dictionary of all the parameters in the function - - return unified_ici_collectives_metrics( - xla_output, - matrix_shape, - dtype, - mesh_shape, - op_dimension, - sharding_strategy, - ici_average_time_ms_list, - matrix_dim, - op_type, - trace_dir, - ) + # pylint: disable=unused-argument + """Calculates the metrics for the psum_scatter benchmark.""" + # Build dictionary of all the parameters in the function + + return unified_ici_collectives_metrics( + xla_output, + matrix_shape, + dtype, + mesh_shape, + op_dimension, + sharding_strategy, + ici_average_time_ms_list, + matrix_dim, + op_type, + trace_dir, + ) + def all_gather_benchmark( matrix_dim: int, @@ -464,79 +470,79 @@ def all_gather_benchmark( num_runs: int = 1, trace_dir: str = None, ) -> Dict[str, Any]: - """Benchmarks the all_gather collective operation. - - Args: - matrix_dim: The benchmark is run on a matrix with shape (matrix_dim, - matrix_dim). - dtype: The data type of the matrix. - ici_size: The number of chips in a single slice. If 1, then no ICI benchmark - is run. - mesh_shape: The shape of the mesh. - sharding_strategy: The sharding strategy of the operation. - op_dimension: The dimension of the operation. - num_runs: The number of runs to perform. - trace_dir: The directory to save the trace to. - - Returns: - The measured time for the ICI benchmark. 
- """ - libtpu_init_args = [ - "--xla_jf_debug_level=3", - "--xla_sc_disable_megacore_partitioning=true", - "--xla_tpu_disable_sparse_core_collective_offload_remover=true", - "--xla_tpu_enable_all_gather_offload_tracing=true", - "--xla_tpu_enable_sparse_core_collective_offload_2d_all_gather=true", - "--xla_tpu_enable_sparse_core_collective_offload_3d_all_gather=true", - "--xla_tpu_enable_sparse_core_collective_offload_all_gather=true", - "--xla_tpu_use_single_sparse_core_for_all_gather_offload=true", - "--xla_tpu_use_tc_device_shape_on_sc=true", - f"--xla_tpu_dvfs_p_state={GLOBAL_PSTATE}", - "--xla_tpu_scoped_vmem_limit_kib=65536", - ] - # libtpu_init_args=[ ] - os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args) - mesh = create_mesh(ici_size, mesh_shape) - - sharding_axis = get_sharding_axis(sharding_strategy, mesh) - - def f(x): - with jax.named_scope(MARKER): - return jax.lax.all_gather(x, sharding_axis, tiled=True) - - jit_sharded_f = jax.jit( - shard_map( - f, - mesh=mesh, - in_specs=P(None, None, None), - out_specs=P(None, None, None), - check_rep=False, - ) - ) - m = matrix_dim - n = BASE_SHAPE[1] - k = BASE_SHAPE[2] - - def data_generator(): - """Creates new random data on host and puts it on device.""" - matrix = jnp.ones((m, n, k), dtype=dtype) - return (matrix,) - - time_ms_list = multiple_iteration_timeit_from_trace( - jit_sharded_f, - data_generator, - matrix_dim=f"{m}x{n}x{k}", - tries=num_runs, - task="all_gather_ici_op", - trace_dir=trace_dir, - ) - print("Running all_gather benchmark", num_runs, matrix_dim) - return { - "ici_average_time_ms_list": time_ms_list, - "matrix_shape": (m, n, k), - "op_type": "AG", - "trace_dir": trace_dir, - } + """Benchmarks the all_gather collective operation. + + Args: + matrix_dim: The benchmark is run on a matrix with shape (matrix_dim, + matrix_dim). + dtype: The data type of the matrix. + ici_size: The number of chips in a single slice. If 1, then no ICI benchmark + is run. 
+ mesh_shape: The shape of the mesh. + sharding_strategy: The sharding strategy of the operation. + op_dimension: The dimension of the operation. + num_runs: The number of runs to perform. + trace_dir: The directory to save the trace to. + + Returns: + The measured time for the ICI benchmark. + """ + libtpu_init_args = [ + "--xla_jf_debug_level=3", + "--xla_sc_disable_megacore_partitioning=true", + "--xla_tpu_disable_sparse_core_collective_offload_remover=true", + "--xla_tpu_enable_all_gather_offload_tracing=true", + "--xla_tpu_enable_sparse_core_collective_offload_2d_all_gather=true", + "--xla_tpu_enable_sparse_core_collective_offload_3d_all_gather=true", + "--xla_tpu_enable_sparse_core_collective_offload_all_gather=true", + "--xla_tpu_use_single_sparse_core_for_all_gather_offload=true", + "--xla_tpu_use_tc_device_shape_on_sc=true", + f"--xla_tpu_dvfs_p_state={GLOBAL_PSTATE}", + "--xla_tpu_scoped_vmem_limit_kib=65536", + ] + # libtpu_init_args=[ ] + os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args) + mesh = create_mesh(ici_size, mesh_shape) + + sharding_axis = get_sharding_axis(sharding_strategy, mesh) + + def f(x): + with jax.named_scope(MARKER): + return jax.lax.all_gather(x, sharding_axis, tiled=True) + + jit_sharded_f = jax.jit( + shard_map( + f, + mesh=mesh, + in_specs=P(None, None, None), + out_specs=P(None, None, None), + check_rep=False, + ) + ) + m = matrix_dim + n = BASE_SHAPE[1] + k = BASE_SHAPE[2] + + def data_generator(): + """Creates new random data on host and puts it on device.""" + matrix = jnp.ones((m, n, k), dtype=dtype) + return (matrix,) + + time_ms_list = multiple_iteration_timeit_from_trace( + jit_sharded_f, + data_generator, + matrix_dim=f"{m}x{n}x{k}", + tries=num_runs, + task="all_gather_ici_op", + trace_dir=trace_dir, + ) + print("Running all_gather benchmark", num_runs, matrix_dim) + return { + "ici_average_time_ms_list": time_ms_list, + "matrix_shape": (m, n, k), + "op_type": "AG", + "trace_dir": trace_dir, + } def 
all_gather_benchmark_calculate_metrics( @@ -552,21 +558,22 @@ def all_gather_benchmark_calculate_metrics( op_type: str, trace_dir: str, ) -> Dict[str, Any]: - """Calculates the metrics for the all_gather benchmark.""" - # Build dictionary of all the parameters in the function - - return unified_ici_collectives_metrics( - xla_output, - matrix_shape, - dtype, - mesh_shape, - op_dimension, - sharding_strategy, - ici_average_time_ms_list, - matrix_dim, - op_type, - trace_dir, - ) + # pylint: disable=unused-argument + """Calculates the metrics for the all_gather benchmark.""" + # Build dictionary of all the parameters in the function + + return unified_ici_collectives_metrics( + xla_output, + matrix_shape, + dtype, + mesh_shape, + op_dimension, + sharding_strategy, + ici_average_time_ms_list, + matrix_dim, + op_type, + trace_dir, + ) def all_to_all_benchmark( @@ -579,74 +586,74 @@ def all_to_all_benchmark( num_runs: int = 1, trace_dir: str = None, ) -> Dict[str, Any]: - """Benchmarks the all_to_all collective operation. - - Args: - matrix_dim: The benchmark is run on a matrix with shape (matrix_dim, - matrix_dim). - dtype: The data type of the matrix. - ici_size: The number of chips in a single slice. If 1, then no ICI benchmark - is run. - mesh_shape: The shape of the mesh. - op_dimension: The dimension of the operation. - num_runs: The number of runs to perform. - trace_dir: The directory to save the trace to. - - Returns: - The measured time for the ICI benchmark. 
- """ - libtpu_init_args = [ - "--xla_jf_debug_level=3", - f"--xla_tpu_dvfs_p_state={GLOBAL_PSTATE}", - ] - os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args) - mesh = create_mesh(ici_size, mesh_shape) - key = jax.random.key(SEED) - lhs_sharding = get_lhs_named_shading(mesh, GLOBAL_SHARDING_STRATEGY) - out_sharding = get_out_sharding(GLOBAL_SHARDING_STRATEGY) - sharding_axis = get_sharding_axis(sharding_strategy, mesh) - - def f(x): - with jax.named_scope(MARKER): - return jax.lax.all_to_all( - x, sharding_axis, split_axis=0, concat_axis=0, tiled=True - ) - - jit_sharded_f = jax.jit( - shard_map( - f, - mesh, - in_specs=lhs_sharding.spec, - out_specs=out_sharding, - check_rep=False, - ) - ) - m = matrix_dim - n = BASE_SHAPE[1] - k = BASE_SHAPE[2] - - def data_generator(): - """Creates new random data on host and puts it on device.""" - nonlocal key # Use and update the outer 'key' - - matrix = jnp.ones((m, n, k), dtype=dtype) - return (matrix,) - - print("Running all_to_all benchmark", num_runs, matrix_dim) - time_ms_list = multiple_iteration_timeit_from_trace( - jit_sharded_f, - data_generator, - matrix_dim=f"{m}x{n}x{k}", - tries=num_runs, - task="all_to_all_ici_op", - trace_dir=trace_dir, - ) - return { - "ici_average_time_ms_list": time_ms_list, - "matrix_shape": (m, n, k), - "op_type": "A2A", - "trace_dir": trace_dir, - } + """Benchmarks the all_to_all collective operation. + + Args: + matrix_dim: The benchmark is run on a matrix with shape (matrix_dim, + matrix_dim). + dtype: The data type of the matrix. + ici_size: The number of chips in a single slice. If 1, then no ICI benchmark + is run. + mesh_shape: The shape of the mesh. + op_dimension: The dimension of the operation. + num_runs: The number of runs to perform. + trace_dir: The directory to save the trace to. + + Returns: + The measured time for the ICI benchmark. 
+ """ + libtpu_init_args = [ + "--xla_jf_debug_level=3", + f"--xla_tpu_dvfs_p_state={GLOBAL_PSTATE}", + ] + os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args) + mesh = create_mesh(ici_size, mesh_shape) + key = jax.random.key(SEED) + lhs_sharding = get_lhs_named_shading(mesh, GLOBAL_SHARDING_STRATEGY) + out_sharding = get_out_sharding(GLOBAL_SHARDING_STRATEGY) + sharding_axis = get_sharding_axis(sharding_strategy, mesh) + + def f(x): + with jax.named_scope(MARKER): + return jax.lax.all_to_all( + x, sharding_axis, split_axis=0, concat_axis=0, tiled=True + ) + + jit_sharded_f = jax.jit( + shard_map( + f, + mesh, + in_specs=lhs_sharding.spec, + out_specs=out_sharding, + check_rep=False, + ) + ) + m = matrix_dim + n = BASE_SHAPE[1] + k = BASE_SHAPE[2] + + def data_generator(): + """Creates new random data on host and puts it on device.""" + nonlocal key # Use and update the outer 'key' + + matrix = jnp.ones((m, n, k), dtype=dtype) + return (matrix,) + + print("Running all_to_all benchmark", num_runs, matrix_dim) + time_ms_list = multiple_iteration_timeit_from_trace( + jit_sharded_f, + data_generator, + matrix_dim=f"{m}x{n}x{k}", + tries=num_runs, + task="all_to_all_ici_op", + trace_dir=trace_dir, + ) + return { + "ici_average_time_ms_list": time_ms_list, + "matrix_shape": (m, n, k), + "op_type": "A2A", + "trace_dir": trace_dir, + } def all_to_all_benchmark_calculate_metrics( @@ -662,19 +669,19 @@ def all_to_all_benchmark_calculate_metrics( op_type: str, trace_dir: str, ) -> Dict[str, Any]: - """Calculates the metrics for the all_to_all benchmark.""" - # Build dictionary of all the parameters in the function - - return unified_ici_collectives_metrics( - xla_output, - matrix_shape, - dtype, - mesh_shape, - op_dimension, - sharding_strategy, - ici_average_time_ms_list, - matrix_dim, - op_type, - trace_dir, - ) - + # pylint: disable=unused-argument + """Calculates the metrics for the all_to_all benchmark.""" + # Build dictionary of all the parameters in the 
function + + return unified_ici_collectives_metrics( + xla_output, + matrix_shape, + dtype, + mesh_shape, + op_dimension, + sharding_strategy, + ici_average_time_ms_list, + matrix_dim, + op_type, + trace_dir, + ) diff --git a/Ironwood/src/benchmark_compute.py b/Ironwood/src/benchmark_compute.py index 813a076..d2adc33 100644 --- a/Ironwood/src/benchmark_compute.py +++ b/Ironwood/src/benchmark_compute.py @@ -15,7 +15,6 @@ import os from typing import Any, Dict, Callable - # pylint: disable=g-importing-member from benchmark_utils import ( iteration_timeit, @@ -141,24 +140,38 @@ def f(x): ) return qx.qvalue, qx.scale - return fp8_quantization(m, n, f, num_runs, trace_dir, task_name="quantization") + return fp8_quantization( + m, n, f, num_runs, trace_dir, task_name="quantization" + ) def quantization_calculate_metrics( - m: int, n: int, time_ms_list: list[float], quant_dtype: str = "float8_e4m3fn" + m: int, + n: int, + time_ms_list: list[float], + quant_dtype: str = "float8_e4m3fn", ) -> Dict[str, Any]: quant_jnp_dtype = jnp.dtype(quant_dtype) - info_fn = jnp.iinfo if jnp.issubdtype(quant_jnp_dtype, jnp.integer) else jnp.finfo + info_fn = ( + jnp.iinfo if jnp.issubdtype(quant_jnp_dtype, jnp.integer) else jnp.finfo + ) width_in_bytes = info_fn(quant_jnp_dtype).bits / 8 output_flops_based_on_dtype = m * n * width_in_bytes # calculate scale apply quant write quant output write scale factor # NOTE: (2 * m * n) + (2 * m * n) + (1 * m * n) + (4 * m) - total_bytes = (2 * m * n) + (2 * m * n) + (4 * m) + output_flops_based_on_dtype + total_bytes = ( + (2 * m * n) + (2 * m * n) + (4 * m) + output_flops_based_on_dtype + ) total_bytes, total_bytes_all_devices = handle_based_on_sharding( total_bytes, SHARDING_STRATEGY ) return unified_bytes_metrics( - m, n, time_ms_list, total_bytes, total_bytes_all_devices, quant_dtype=quant_dtype + m, + n, + time_ms_list, + total_bytes, + total_bytes_all_devices, + quant_dtype=quant_dtype, ) @@ -271,7 +284,12 @@ def f(x): return qx.qvalue, 
qx.scale return fp8_quantization( - m, n, f, num_runs, trace_dir, task_name="transpose_quantization_static_scaling" + m, + n, + f, + num_runs, + trace_dir, + task_name="transpose_quantization_static_scaling", ) @@ -452,7 +470,10 @@ def rmsnorm_fwd( Y_i = X_i / rms(x_i) """ rms_norm_module = nnx.RMSNorm( - num_features=n, dtype=jnp.bfloat16, param_dtype=jnp.float32, rngs=nnx.Rngs(SEED) + num_features=n, + dtype=jnp.bfloat16, + param_dtype=jnp.float32, + rngs=nnx.Rngs(SEED), ) def f(x): @@ -518,7 +539,10 @@ def rmsnorm_bwd( Inverse of rmsnorm_fwd """ rms_norm_module = nnx.RMSNorm( - num_features=n, dtype=jnp.bfloat16, param_dtype=jnp.float32, rngs=nnx.Rngs(SEED) + num_features=n, + dtype=jnp.bfloat16, + param_dtype=jnp.float32, + rngs=nnx.Rngs(SEED), ) def f_fwd(x): @@ -651,7 +675,9 @@ def data_generator(): return {"time_ms_list": time_ms_list} -def add_calculate_metrics(m: int, n: int, time_ms_list: list[float]) -> Dict[str, Any]: +def add_calculate_metrics( + m: int, n: int, time_ms_list: list[float] +) -> Dict[str, Any]: total_bytes = 6 * m * n total_bytes, total_bytes_all_devices = handle_based_on_sharding( total_bytes, SHARDING_STRATEGY diff --git a/Ironwood/src/benchmark_gemm.py b/Ironwood/src/benchmark_gemm.py index b802ddc..5e9dea1 100644 --- a/Ironwood/src/benchmark_gemm.py +++ b/Ironwood/src/benchmark_gemm.py @@ -23,14 +23,13 @@ handle_based_on_sharding, unified_flops_metrics, str_to_dtype, - get_peak_flops_multiplier + get_peak_flops_multiplier, ) from common import MARKER import jax from jax.experimental.shard_map import shard_map import jax.numpy as jnp - # pylint: disable=g-importing-member os.environ["LIBTPU_INIT_ARGS"] = ( @@ -62,6 +61,7 @@ SEED = 0 PEAK_FLOPS_PER_DEVICE = 2307 # TFLOP/s for single core(device) of FP8 + def gemm_multiple_run( m: int, k: int, @@ -146,7 +146,11 @@ def gemm_multiple_run_calculate_metrics( total_flops, total_flops_all_devices = handle_based_on_sharding( total_flops, SHARDING_STRATEGY ) - peak_flops = PEAK_FLOPS_PER_DEVICE 
if dtype==jax.numpy.float8_e4m3fn else PEAK_FLOPS_PER_DEVICE/2 + peak_flops = ( + PEAK_FLOPS_PER_DEVICE + if dtype == jax.numpy.float8_e4m3fn + else PEAK_FLOPS_PER_DEVICE / 2 + ) return unified_flops_metrics( m, n, @@ -158,6 +162,7 @@ def gemm_multiple_run_calculate_metrics( dtype=dtype.dtype.name, ) + def gemm_simple( m: int, k: int, @@ -213,8 +218,8 @@ def data_generator(): return (lhs_device, rhs_device) # Run the benchmark - num_runs = 1 - ## Need to fix gemm timing logic to handle num_runs > 1 + num_runs = 1 + # Need to fix gemm timing logic to handle num_runs > 1 time_ms_list = iteration_timeit( jit_sharded_f, @@ -251,9 +256,13 @@ def gemm_simple_calculate_metrics( def gemm_simple_with_dtype( - m: int, k: int, n: int, - in_dtype_str: str, out_dtype_str: str, - num_runs: int = 1, trace_dir: str = None + m: int, + k: int, + n: int, + in_dtype_str: str, + out_dtype_str: str, + num_runs: int = 1, + trace_dir: str = None, ) -> Dict[str, Any]: """Benchmarks the OUT:BF16 = IN0:FP8 x IN1:FP8. 
Accumulation is FP32.""" @@ -264,7 +273,9 @@ def gemm_simple_with_dtype( def f(x, y): with jax.named_scope(MARKER): - acc = jax.numpy.einsum("ij,jk->ik", x, y, preferred_element_type=jnp.float32) + acc = jax.numpy.einsum( + "ij,jk->ik", x, y, preferred_element_type=jnp.float32 + ) return acc.astype(out_dtype) mesh = create_mesh(SHARDING_STRATEGY) @@ -289,7 +300,7 @@ def f(x, y): def data_generator(): """Creates new random data on host and puts it on device.""" - nonlocal key # Use and update the outer 'key' + nonlocal key # Use and update the outer 'key' key, key_lhs, key_rhs = jax.random.split(key, 3) # Create random data on host @@ -302,8 +313,8 @@ def data_generator(): return (lhs_device, rhs_device) - num_runs = 1 - ## Need to fix gemm timing logic to handle num_runs > 1 + num_runs = 1 + # Need to fix gemm timing logic to handle num_runs > 1 # Run the benchmark time_ms_list = iteration_timeit( @@ -316,22 +327,33 @@ def data_generator(): ) return {"time_ms_list": time_ms_list} + def gemm_simple_with_dtype_calculate_metrics( - m: int, k: int, n: int, - in_dtype_str: str, out_dtype_str: str, - time_ms_list: list[float] + m: int, + k: int, + n: int, + in_dtype_str: str, + out_dtype_str: str, + time_ms_list: list[float], ) -> Dict[str, Any]: # Calculate FLOPs total_flops = (2 * k - 1) * m * n # Total floating-point operations - total_flops, total_flops_all_devices = handle_based_on_sharding(total_flops, SHARDING_STRATEGY) + total_flops, total_flops_all_devices = handle_based_on_sharding( + total_flops, SHARDING_STRATEGY + ) # Get the multiplier by calling the utility function peak_flops_multiplier = get_peak_flops_multiplier(in_dtype_str) metadata, metrics = unified_flops_metrics( - m, n, k, time_ms_list, - total_flops, total_flops_all_devices, - PEAK_FLOPS_PER_DEVICE * peak_flops_multiplier) + m, + n, + k, + time_ms_list, + total_flops, + total_flops_all_devices, + PEAK_FLOPS_PER_DEVICE * peak_flops_multiplier, + ) # Add dtype info to metadata for logging 
metadata["in_dtype"] = in_dtype_str @@ -407,8 +429,8 @@ def data_generator(): return (lhs_device, rhs_device, sf0_device, sf1_device) - num_runs = 1 - ## Need to fix gemm timing logic to handle num_runs > 1 + num_runs = 1 + # Need to fix gemm timing logic to handle num_runs > 1 time_ms_list = iteration_timeit( jit_sharded_f, @@ -519,11 +541,16 @@ def data_generator(): sf0_device = jax.device_put(sf0_host, sf0_sharding) sf1_device = jax.device_put(sf1_host, sf1_sharding) - return (out_buffer_device, lhs_device, rhs_device, sf0_device, sf1_device) - + return ( + out_buffer_device, + lhs_device, + rhs_device, + sf0_device, + sf1_device, + ) - num_runs = 1 - ## Need to fix gemm timing logic to handle num_runs > 1 + num_runs = 1 + # Need to fix gemm timing logic to handle num_runs > 1 time_ms_list = iteration_timeit( jit_sharded_f, diff --git a/Ironwood/src/benchmark_gemm_numerics.py b/Ironwood/src/benchmark_gemm_numerics.py index 93d680c..a14631a 100644 --- a/Ironwood/src/benchmark_gemm_numerics.py +++ b/Ironwood/src/benchmark_gemm_numerics.py @@ -12,7 +12,6 @@ import os from typing import Any, Dict, Callable - # pylint: disable=g-importing-member from benchmark_utils import ( iteration_timeit, @@ -145,7 +144,10 @@ def f(x, y): channelwise_axes=[1], ) acc = jax.numpy.einsum( - "ij,jk->ik", qx.qvalue, qy.qvalue, preferred_element_type=jnp.float32 + "ij,jk->ik", + qx.qvalue, + qy.qvalue, + preferred_element_type=jnp.float32, ) return acc.astype(jnp.bfloat16) @@ -191,10 +193,13 @@ def f(x, y): qtype=jnp.float8_e4m3fn, scale_dtype=jnp.float32, calibration_method="absmax", - channelwise_axes=[1] + channelwise_axes=[1], ) acc = jax.numpy.einsum( - "ij,jk->ik", qx.qvalue, qy.qvalue, preferred_element_type=jnp.float32 + "ij,jk->ik", + qx.qvalue, + qy.qvalue, + preferred_element_type=jnp.float32, ).astype(jnp.float32) final_result = acc * ( qx.scale.astype(jnp.float32) * qy.scale.astype(jnp.float32) @@ -292,7 +297,10 @@ def f(x, y): tiled_axes={1: 128}, ) acc = 
jax.numpy.einsum( - "ij,jk->ik", qx.qvalue, qy.qvalue, preferred_element_type=jnp.float32 + "ij,jk->ik", + qx.qvalue, + qy.qvalue, + preferred_element_type=jnp.float32, ) return acc.astype(jnp.bfloat16) @@ -341,12 +349,21 @@ def f(x, y): channelwise_axes=[1], ) acc = jax.numpy.einsum( - "ij,jk->ik", qx.qvalue, qy.qvalue, preferred_element_type=jnp.float32 + "ij,jk->ik", + qx.qvalue, + qy.qvalue, + preferred_element_type=jnp.float32, ) return acc.astype(jnp.bfloat16) return gemm_fp8_quantization( - m, k, n, f, num_runs, trace_dir, task_name="gemm_fp8_rowwise_static_scaling" + m, + k, + n, + f, + num_runs, + trace_dir, + task_name="gemm_fp8_rowwise_static_scaling", ) @@ -392,12 +409,21 @@ def f(x, y): tiled_axes={1: 128}, ) acc = jax.numpy.einsum( - "ij,jk->ik", qx.qvalue, qy.qvalue, preferred_element_type=jnp.float32 + "ij,jk->ik", + qx.qvalue, + qy.qvalue, + preferred_element_type=jnp.float32, ) return acc.astype(jnp.bfloat16) return gemm_fp8_quantization( - m, k, n, f, num_runs, trace_dir, task_name="gemm_fp8_b128_fp32_static_scaling" + m, + k, + n, + f, + num_runs, + trace_dir, + task_name="gemm_fp8_b128_fp32_static_scaling", ) @@ -426,11 +452,16 @@ def gemm_mxfp8_b32( def f(x, y): with jax.named_scope(MARKER): - how = qarray.HowToQuantize(qtype="mxfp8", calibration_method="absmax") + how = qarray.HowToQuantize( + qtype="mxfp8", calibration_method="absmax" + ) qx = qarray.quantize(x, how=how) qy = qarray.quantize(y, how=how) acc = jax.numpy.einsum( - "ij,jk->ik", qx.qvalue, qy.qvalue, preferred_element_type=jnp.float32 + "ij,jk->ik", + qx.qvalue, + qy.qvalue, + preferred_element_type=jnp.float32, ) return acc.astype(jnp.bfloat16) @@ -470,7 +501,10 @@ def f(x, y): qx = qarray.quantize(x, how=how) qy = qarray.quantize(y, how=how) acc = jax.numpy.einsum( - "ij,jk->ik", qx.qvalue, qy.qvalue, preferred_element_type=jnp.float32 + "ij,jk->ik", + qx.qvalue, + qy.qvalue, + preferred_element_type=jnp.float32, ) return acc.astype(jnp.bfloat16) diff --git 
a/Ironwood/src/benchmark_gemm_throttling.py b/Ironwood/src/benchmark_gemm_throttling.py index 30769f6..66712d8 100644 --- a/Ironwood/src/benchmark_gemm_throttling.py +++ b/Ironwood/src/benchmark_gemm_throttling.py @@ -17,7 +17,6 @@ from jax.experimental.shard_map import shard_map import jax.numpy as jnp - # pylint: disable=g-importing-member os.environ["LIBTPU_INIT_ARGS"] = ( @@ -50,71 +49,71 @@ def gemm_throttling( gap_strategy: str = "data_gen_every_iter_block_every_iter", trace_dir: str = None, ) -> Dict[str, Any]: - """Benchmarks the OUT:BF16 = IN0:FP8 x IN1:FP8. - - Accumulation is FP32. - """ - - def f(x, y): - with jax.named_scope(MARKER): - acc = jax.numpy.einsum( - "ij,jk->ik", x, y, preferred_element_type=jnp.float32 - ) - return acc.astype(jnp.bfloat16) - - mesh = create_mesh(SHARDING_STRATEGY) - lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY) - rhs_sharding = get_rhs_named_shading(mesh, SHARDING_STRATEGY) - out_sharding = get_out_sharding(SHARDING_STRATEGY) - - jit_sharded_f = jax.jit( - shard_map( - f, - mesh, - in_specs=(lhs_sharding.spec, rhs_sharding.spec), - out_specs=out_sharding, - check_rep=False, - ) - ) - - lhs_shape = (m, k) - rhs_shape = (k, n) - - lhs_dtype = dtype - rhs_dtype = dtype - - key = jax.random.key(SEED) - - def data_generator(): - """Creates new random data on host and puts it on device.""" - nonlocal key # Use and update the outer 'key' - key, key_lhs, key_rhs = jax.random.split(key, 3) - - # Create random data on host - lhs_host = jax.random.normal(key_lhs, lhs_shape).astype(lhs_dtype) - rhs_host = jax.random.normal(key_rhs, rhs_shape).astype(rhs_dtype) - - # Put on device (HBM) - lhs_device = jax.device_put(lhs_host, lhs_sharding) - rhs_device = jax.device_put(rhs_host, rhs_sharding) - - return (lhs_device, rhs_device) - - # Run the benchmark - - print("Running gemm_throttling benchmark", num_runs) - time_ms_list = multiple_iteration_timeit_from_trace_throttling( - jit_sharded_f, - data_generator, - 
matrix_dim=f"{m}x{n}x{k}", - tries=num_runs, - task="gemm_throttling", - trace_dir=trace_dir, - gap_strategy=gap_strategy, - ) - return { - "time_ms_list": time_ms_list, - } + """Benchmarks the OUT:BF16 = IN0:FP8 x IN1:FP8. + + Accumulation is FP32. + """ + + def f(x, y): + with jax.named_scope(MARKER): + acc = jax.numpy.einsum( + "ij,jk->ik", x, y, preferred_element_type=jnp.float32 + ) + return acc.astype(jnp.bfloat16) + + mesh = create_mesh(SHARDING_STRATEGY) + lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY) + rhs_sharding = get_rhs_named_shading(mesh, SHARDING_STRATEGY) + out_sharding = get_out_sharding(SHARDING_STRATEGY) + + jit_sharded_f = jax.jit( + shard_map( + f, + mesh, + in_specs=(lhs_sharding.spec, rhs_sharding.spec), + out_specs=out_sharding, + check_rep=False, + ) + ) + + lhs_shape = (m, k) + rhs_shape = (k, n) + + lhs_dtype = dtype + rhs_dtype = dtype + + key = jax.random.key(SEED) + + def data_generator(): + """Creates new random data on host and puts it on device.""" + nonlocal key # Use and update the outer 'key' + key, key_lhs, key_rhs = jax.random.split(key, 3) + + # Create random data on host + lhs_host = jax.random.normal(key_lhs, lhs_shape).astype(lhs_dtype) + rhs_host = jax.random.normal(key_rhs, rhs_shape).astype(rhs_dtype) + + # Put on device (HBM) + lhs_device = jax.device_put(lhs_host, lhs_sharding) + rhs_device = jax.device_put(rhs_host, rhs_sharding) + + return (lhs_device, rhs_device) + + # Run the benchmark + + print("Running gemm_throttling benchmark", num_runs) + time_ms_list = multiple_iteration_timeit_from_trace_throttling( + jit_sharded_f, + data_generator, + matrix_dim=f"{m}x{n}x{k}", + tries=num_runs, + task="gemm_throttling", + trace_dir=trace_dir, + gap_strategy=gap_strategy, + ) + return { + "time_ms_list": time_ms_list, + } def gemm_throttling_calculate_metrics( @@ -125,17 +124,18 @@ def gemm_throttling_calculate_metrics( dtype: jnp.dtype, time_ms_list: list[float], ) -> Dict[str, Any]: - # Calculate FLOPs - 
total_flops = 2 * m * k * n # Total floating-point operations - total_flops, total_flops_all_devices = handle_based_on_sharding( - total_flops, SHARDING_STRATEGY - ) - return unified_flops_metrics( - m, - n, - k, - time_ms_list, - total_flops, - total_flops_all_devices, - PEAK_FLOPS_PER_DEVICE, - ) + # pylint: disable=unused-argument + # Calculate FLOPs + total_flops = 2 * m * k * n # Total floating-point operations + total_flops, total_flops_all_devices = handle_based_on_sharding( + total_flops, SHARDING_STRATEGY + ) + return unified_flops_metrics( + m, + n, + k, + time_ms_list, + total_flops, + total_flops_all_devices, + PEAK_FLOPS_PER_DEVICE, + ) diff --git a/Ironwood/src/benchmark_hbm.py b/Ironwood/src/benchmark_hbm.py index 67f0429..aaef211 100644 --- a/Ironwood/src/benchmark_hbm.py +++ b/Ironwood/src/benchmark_hbm.py @@ -12,7 +12,6 @@ import jax import jax.numpy as jnp - SEED = 0 os.environ["LIBTPU_INIT_ARGS"] = ( "--xla_tpu_scoped_vmem_limit_kib=65536 " @@ -20,6 +19,7 @@ "--xla_tpu_dvfs_p_state=7 " ) + def get_metrics_helper( params: Dict[str, Any], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -101,5 +101,7 @@ def single_device_hbm_copy_calculate_metrics( ) metrics.update(time_statistics.serialize_statistics()) metrics.update(statistics.serialize_statistics()) - metrics = {key: value for key, value in metrics.items() if value is not None} + metrics = { + key: value for key, value in metrics.items() if value is not None + } return metadata, metrics diff --git a/Ironwood/src/benchmark_host_device.py b/Ironwood/src/benchmark_host_device.py index f745eb4..244877c 100644 --- a/Ironwood/src/benchmark_host_device.py +++ b/Ironwood/src/benchmark_host_device.py @@ -5,11 +5,10 @@ from typing import Any, Dict, Tuple, List import jax -from jax import sharding import numpy as np +import contextlib from benchmark_utils import MetricsStatistics - libtpu_init_args = [ "--xla_tpu_dvfs_p_state=7", ] @@ -25,23 +24,24 @@ def benchmark_host_device( trace_dir: str = None, ) -> 
Dict[str, Any]: """Benchmarks H2D/D2H transfer using simple device_put/device_get.""" - + num_elements = 1024 * 1024 * data_size_mib // np.dtype(np.float32).itemsize - + # Allocate Host Source Buffer column = 128 - host_data = np.random.normal(size=(num_elements // column, column)).astype(np.float32) - + host_data = np.random.normal(size=(num_elements // column, column)).astype( + np.float32 + ) + print( f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations", - flush=True + flush=True, ) # Performance Lists h2d_perf, d2h_perf = [], [] # Profiling Context - import contextlib if trace_dir: profiler_context = jax.profiler.trace(trace_dir) else: @@ -59,34 +59,36 @@ def benchmark_host_device( for i in range(num_runs): # Step Context if trace_dir: - step_context = jax.profiler.StepTraceAnnotation("host_device", step_num=i) + step_context = jax.profiler.StepTraceAnnotation( + "host_device", step_num=i + ) else: step_context = contextlib.nullcontext() - + with step_context: - # H2D + # H2D t0 = time.perf_counter() - + # Simple device_put device_array = jax.device_put(host_data) device_array.block_until_ready() - + t1 = time.perf_counter() h2d_perf.append((t1 - t0) * 1000) - + # Verify H2D shape assert device_array.shape == host_data.shape - + # D2H t2 = time.perf_counter() - + # Simple device_get # Note: device_get returns a numpy array (copy) _ = jax.device_get(device_array) - + t3 = time.perf_counter() d2h_perf.append((t3 - t2) * 1000) - + device_array.delete() return { @@ -94,37 +96,38 @@ def benchmark_host_device( "D2H_Bandwidth_ms": d2h_perf, } + def benchmark_host_device_calculate_metrics( data_size_mib: int, - H2D_Bandwidth_ms: List[float], - D2H_Bandwidth_ms: List[float], + h2d_bandwidth_ms: List[float], + d2h_bandwidth_ms: List[float], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Calculates metrics for Host-Device transfer.""" params = locals().items() - + # Filter out list params from metadata to avoid explosion metadata_keys = { 
- "data_size_mib", + "data_size_mib", } metadata = {k: v for k, v in params if k in metadata_keys} - + metrics = {} - + def add_metric(name, ms_list): # Report Bandwidth (GiB/s) # Handle division by zero if ms is 0 bw_list = [ - ((data_size_mib / 1024) / (ms / 1000)) if ms > 0 else 0.0 + ((data_size_mib / 1024) / (ms / 1000)) if ms > 0 else 0.0 for ms in ms_list ] stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)") print( - f" {name}_bw (GiB/s) median: {stats_bw.statistics['p50']}, P95: {stats_bw.statistics['p95']}", - flush=True + f" {name}_bw (GiB/s) median: {stats_bw.statistics['p50']}, P95: {stats_bw.statistics['p95']}", + flush=True, ) metrics.update(stats_bw.serialize_statistics()) - add_metric("H2D", H2D_Bandwidth_ms) - add_metric("D2H", D2H_Bandwidth_ms) + add_metric("H2D", h2d_bandwidth_ms) + add_metric("D2H", d2h_bandwidth_ms) return metadata, metrics diff --git a/Ironwood/src/benchmark_inference_compute.py b/Ironwood/src/benchmark_inference_compute.py index 8bfa6b0..a5c948d 100644 --- a/Ironwood/src/benchmark_inference_compute.py +++ b/Ironwood/src/benchmark_inference_compute.py @@ -6,7 +6,6 @@ import os from typing import Any, Dict - # pylint: disable=g-importing-member from benchmark_utils import ( iteration_timeit, @@ -53,7 +52,8 @@ SHARDING_STRATEGY = ShardingStrategy.NO_SHARDING SEED = 0 -PEAK_FLOPS_PER_DEVICE = 2307 # TFLOP/s for single core(device) of FP8 under p_state=7 +# TFLOP/s for single core(device) of FP8 under p_state=7 +PEAK_FLOPS_PER_DEVICE = 2307 def add( @@ -124,7 +124,12 @@ def add_calculate_metrics( total_bytes, SHARDING_STRATEGY ) return unified_bytes_metrics( - m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name + m, + n, + time_ms_list, + total_bytes, + total_bytes_all_devices, + dtype=dtype.dtype.name, ) @@ -139,7 +144,9 @@ def rmsnorm( For each row i of N: Y_i = X_i / rms(x_i) """ - rms_norm_module = nnx.RMSNorm(num_features=n, dtype=dtype, rngs=nnx.Rngs(SEED)) + rms_norm_module = 
nnx.RMSNorm( + num_features=n, dtype=dtype, rngs=nnx.Rngs(SEED) + ) def f(x): with jax.named_scope(MARKER): @@ -191,7 +198,12 @@ def rmsnorm_calculate_metrics( total_bytes, SHARDING_STRATEGY ) return unified_bytes_metrics( - m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name + m, + n, + time_ms_list, + total_bytes, + total_bytes_all_devices, + dtype=dtype.dtype.name, ) @@ -264,7 +276,12 @@ def silu_mul_calculate_metrics( total_bytes, SHARDING_STRATEGY ) return unified_bytes_metrics( - m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name + m, + n, + time_ms_list, + total_bytes, + total_bytes_all_devices, + dtype=dtype.dtype.name, ) @@ -325,7 +342,12 @@ def sigmoid_calculate_metrics( total_bytes, SHARDING_STRATEGY ) return unified_bytes_metrics( - m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name + m, + n, + time_ms_list, + total_bytes, + total_bytes_all_devices, + dtype=dtype.dtype.name, ) diff --git a/Ironwood/src/benchmark_send_recv.py b/Ironwood/src/benchmark_send_recv.py index c7dd5db..c01ca0e 100644 --- a/Ironwood/src/benchmark_send_recv.py +++ b/Ironwood/src/benchmark_send_recv.py @@ -15,11 +15,11 @@ P = jax.sharding.PartitionSpec -os.environ['LIBTPU_INIT_ARGS'] = ( - '--xla_tpu_collect_sflag_wait_stats_trace=true ' - '--xla_tpu_force_global_barriers=true ' - '--xla_tpu_ragged_all_to_all_max_rdma_size_kib=-1 ' - '--xla_tpu_dvfs_p_state=7 ' +os.environ["LIBTPU_INIT_ARGS"] = ( + "--xla_tpu_collect_sflag_wait_stats_trace=true " + "--xla_tpu_force_global_barriers=true " + "--xla_tpu_ragged_all_to_all_max_rdma_size_kib=-1 " + "--xla_tpu_dvfs_p_state=7 " ) @@ -29,47 +29,50 @@ def _run_under_xprof( n_repeats: int, task: str, ): - """Runs a function under xprof.""" - # warmup - jax.block_until_ready(function(*inputs)) - with tempfile.TemporaryDirectory() as tmp_trace_dir: - with jax.profiler.trace(tmp_trace_dir, create_perfetto_link=False): - for i in range(n_repeats): - with 
jax.profiler.StepTraceAnnotation(task, step_num=i): - with jax.named_scope(f"{MARKER}_{i}"): - result = function(*inputs) - jax.block_until_ready(result) - jtrace = get_trace(tmp_trace_dir) - - marker_done_events = [] - for event in jtrace["traceEvents"]: - args = event.get("args", {}) - tf_op = args.get("tf_op", "") - if MARKER in tf_op: - marker_done_events.append(event) - # when offloaded to sparse core look for call-done events - marker_call_done_events = [ - e for e in marker_done_events if e.get("name", "").endswith("call-done") - ] - if marker_call_done_events: - marker_done_events = marker_call_done_events - durations_ms = [ - float(e["args"]["device_duration_ps"]) / 1e9 for e in marker_done_events - ] - return max(durations_ms) + """Runs a function under xprof.""" + # warmup + jax.block_until_ready(function(*inputs)) + with tempfile.TemporaryDirectory() as tmp_trace_dir: + with jax.profiler.trace(tmp_trace_dir, create_perfetto_link=False): + for i in range(n_repeats): + with jax.profiler.StepTraceAnnotation(task, step_num=i): + with jax.named_scope(f"{MARKER}_{i}"): + result = function(*inputs) + jax.block_until_ready(result) + jtrace = get_trace(tmp_trace_dir) + + marker_done_events = [] + for event in jtrace["traceEvents"]: + args = event.get("args", {}) + tf_op = args.get("tf_op", "") + if MARKER in tf_op: + marker_done_events.append(event) + # when offloaded to sparse core look for call-done events + marker_call_done_events = [ + e + for e in marker_done_events + if e.get("name", "").endswith("call-done") + ] + if marker_call_done_events: + marker_done_events = marker_call_done_events + durations_ms = [ + float(e["args"]["device_duration_ps"]) / 1e9 + for e in marker_done_events + ] + return max(durations_ms) def get_metrics_helper( params: Dict[str, Any], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Helper function to build the metrics and metadata for the benchmark.""" - exclude_keys = {'time_ms_list'} + exclude_keys = {"time_ms_list"} metadata = { 
key: value for key, value in params if value is not None and key not in exclude_keys } - metadata['dtype'] = get_real_dtype_bytes(metadata['dtype'].dtype) + metadata["dtype"] = get_real_dtype_bytes(metadata["dtype"].dtype) return metadata @@ -84,24 +87,28 @@ def send_recv_benchmark( """Runs p2p communication, sending tensor_size_bytes from source to target device.""" device_count = jax.local_device_count() devices = mesh_utils.create_device_mesh((device_count,)) - mesh = jax.sharding.Mesh(devices, 'x') + mesh = jax.sharding.Mesh(devices, "x") item_size = get_real_dtype_bytes(jnp.dtype(dtype)) tensor_size_bytes = num_elements * item_size last_dim = tensor_size_bytes // (1 * 8 * item_size) def p2p_send(source_id, target_id): # Get the ID of the current device this code is running on - device_id = jax.lax.axis_index('x') - axis_size = jax.lax.axis_size('x') + device_id = jax.lax.axis_index("x") + axis_size = jax.lax.axis_size("x") input_offsets = jnp.zeros((axis_size,), dtype=jnp.int32) output_offsets = jnp.zeros((axis_size,), dtype=jnp.int32) no_sends = jnp.zeros((axis_size,), dtype=jnp.int32) no_recvs = jnp.zeros((axis_size,), dtype=jnp.int32) # Only device `source_id` sends, and it sends to `target_id`. - sender_send_sizes = jax.nn.one_hot(target_id, axis_size, dtype=jnp.int32) + sender_send_sizes = jax.nn.one_hot( + target_id, axis_size, dtype=jnp.int32 + ) # Only device `target_id` receives, and it receives from `source_id`. 
- target_recv_sizes = jax.nn.one_hot(source_id, axis_size, dtype=jnp.int32) + target_recv_sizes = jax.nn.one_hot( + source_id, axis_size, dtype=jnp.int32 + ) final_send_sizes = jax.lax.select( device_id == source_id, @@ -113,7 +120,9 @@ def p2p_send(source_id, target_id): target_recv_sizes, no_recvs, ) - input = jax.random.normal(jax.random.key(0), (1, 8, last_dim), dtype=dtype) + input = jax.random.normal( + jax.random.key(0), (1, 8, last_dim), dtype=dtype + ) output = jnp.zeros((1, 8, last_dim), dtype=dtype) with jax.named_scope(MARKER): @@ -124,7 +133,7 @@ def p2p_send(source_id, target_id): send_sizes=final_send_sizes, output_offsets=output_offsets, recv_sizes=final_recv_sizes, - axis_name='x', + axis_name="x", ) max_val = jax.lax.reduce_max(ra2a, axes=(0, 1, 2)) return max_val @@ -140,12 +149,12 @@ def p2p_send(source_id, target_id): .compile() ) - # Measures the longest wait time in milliseconds, across all the runs. + # Measures the longest wait time in milliseconds, across all the runs. runtime_ms = _run_under_xprof( - compiled_function, [], n_repeats, f'p2p_{source_id}_to_{target_id}' + compiled_function, [], n_repeats, f"p2p_{source_id}_to_{target_id}" ) - return {'runtime_ms': runtime_ms} + return {"runtime_ms": runtime_ms} def send_recv_benchmark_calculate_metrics( @@ -165,16 +174,20 @@ def send_recv_benchmark_calculate_metrics( tensor_size_bytes = num_elements * get_real_dtype_bytes(jnp.dtype(dtype)) tensor_size_gbytes = tensor_size_bytes / 10**9 - metrics['runtime_ms (ms)'] = runtime_ms + metrics["runtime_ms (ms)"] = runtime_ms runtime_s = runtime_ms / 10**3 - metrics['achieved_bw (GB/s)'] = tensor_size_gbytes / runtime_s + metrics["achieved_bw (GB/s)"] = tensor_size_gbytes / runtime_s # Gather the metrics to report. 
- metadata.update({ - 'tensor_size_gbytes': tensor_size_gbytes, - }) + metadata.update( + { + "tensor_size_gbytes": tensor_size_gbytes, + } + ) - metrics = {key: value for key, value in metrics.items() if value is not None} + metrics = { + key: value for key, value in metrics.items() if value is not None + } print(metadata) print(metrics) return metadata, metrics diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index ccd4f4c..9b2ede1 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -30,14 +30,14 @@ def get_real_dtype_bytes(dtype) -> float: - """Returns the real byte size of a dtype, handling sub-byte types.""" - try: - return jnp.finfo(dtype).bits / 8 - except Exception: + """Returns the real byte size of a dtype, handling sub-byte types.""" try: - return jnp.iinfo(dtype).bits / 8 + return jnp.finfo(dtype).bits / 8 except Exception: - return dtype.itemsize + try: + return jnp.iinfo(dtype).bits / 8 + except Exception: + return dtype.itemsize # The dictionary to map a JAX (collective) function to its main HLO. 
@@ -51,15 +51,15 @@ def get_real_dtype_bytes(dtype) -> float: class ShardingStrategy(Enum): - """Defines different sharding strategies for tensors.""" + """Defines different sharding strategies for tensors.""" - NO_SHARDING = auto() - SHARDING_ON_ALL_DEVICES_WITH_M = auto() - SHARDING_ON_SINGLE_CHIP_WITH_M = ( - auto() - ) # Only sharding on the two core of one single chip - SHARDING_ON_ALL_DEVICES_WITH_N = auto() - SHARDING_ON_SINGLE_CHIP_WITH_N = auto() + NO_SHARDING = auto() + SHARDING_ON_ALL_DEVICES_WITH_M = auto() + SHARDING_ON_SINGLE_CHIP_WITH_M = ( + auto() + ) # Only sharding on the two core of one single chip + SHARDING_ON_ALL_DEVICES_WITH_N = auto() + SHARDING_ON_SINGLE_CHIP_WITH_N = auto() def multiple_iteration_timeit_from_trace_throttling( @@ -71,69 +71,73 @@ def multiple_iteration_timeit_from_trace_throttling( trace_dir: str = None, gap_strategy: str = None, ) -> list[float]: - """Time a function with jax.profiler and get the run time from the trace.""" - LOCAL_TRACE_DIR = "/tmp/microbenchmarks_tmptrace" - - if matrix_dim is not None: - trace_name = f"{task}_dim_{matrix_dim}" - else: - trace_name = f"t_{task}_" + "".join( - random.choices(string.ascii_uppercase + string.digits, k=10) - ) + """Time a function with jax.profiler and get the run time from the trace.""" + LOCAL_TRACE_DIR = "/tmp/microbenchmarks_tmptrace" - trace_full_dir = f"{trace_dir}/{trace_name}" - tmp_trace_dir = trace_full_dir - # If the trace_dir isn't a local path, create one for dumping the trace for parsing and getting metrics. 
- if trace_dir and not is_local_directory_path(trace_dir): - tmp_trace_dir = f"{LOCAL_TRACE_DIR}/{trace_name}" + if matrix_dim is not None: + trace_name = f"{task}_dim_{matrix_dim}" + else: + trace_name = f"t_{task}_" + "".join( + random.choices(string.ascii_uppercase + string.digits, k=10) + ) - if gap_strategy == "data_gen_once_block_every_iter": - data_args = data_generator() - with jax.profiler.trace(tmp_trace_dir): - for i in range(tries): - if i % 10 == 0: - print( - f"[{task}] Running iteration {i} of {tries} with {matrix_dim}..." - ) - jax.devices() - with jax.profiler.StepTraceAnnotation(task, step_num=i): - with jax.named_scope(f"{MARKER}_{i}"): - result = compute_func(*data_args) - jax.block_until_ready(result) - elif gap_strategy=='data_gen_once_noblock': - data_args = data_generator() - with jax.profiler.trace(tmp_trace_dir): - results = [] - for i in range(tries): - if i % 10 == 0: - print(f"[{task}] Running iteration {i} of {tries} with {matrix_dim}...") - jax.devices() - with jax.profiler.StepTraceAnnotation(task, step_num=i): - with jax.named_scope(f"{MARKER}_{i}"): - compute_func(*data_args) - results.append(True) + trace_full_dir = f"{trace_dir}/{trace_name}" + tmp_trace_dir = trace_full_dir + # If the trace_dir isn't a local path, create one for dumping the trace for parsing and getting metrics. 
+ if trace_dir and not is_local_directory_path(trace_dir): + tmp_trace_dir = f"{LOCAL_TRACE_DIR}/{trace_name}" - if results: - jax.block_until_ready(results) - elif gap_strategy == "data_gen_every_iter_block_every_iter": - with jax.profiler.trace(tmp_trace_dir): - for i in range(tries): - if i % 10 == 0: - print(f"[{task}] Running iteration {i} of {tries} with {matrix_dim}...") - data_args = data_generator() - jax.devices() - with jax.profiler.StepTraceAnnotation(task, step_num=i): - with jax.named_scope(f"{MARKER}_{i}"): - result = compute_func(*data_args) - jax.block_until_ready(result) - else: - raise ValueError(f"Unknown gap strategy: {gap_strategy}") - trace = get_trace(tmp_trace_dir) + if gap_strategy == "data_gen_once_block_every_iter": + data_args = data_generator() + with jax.profiler.trace(tmp_trace_dir): + for i in range(tries): + if i % 10 == 0: + print( + f"[{task}] Running iteration {i} of {tries} with {matrix_dim}..." + ) + jax.devices() + with jax.profiler.StepTraceAnnotation(task, step_num=i): + with jax.named_scope(f"{MARKER}_{i}"): + result = compute_func(*data_args) + jax.block_until_ready(result) + elif gap_strategy == "data_gen_once_noblock": + data_args = data_generator() + with jax.profiler.trace(tmp_trace_dir): + results = [] + for i in range(tries): + if i % 10 == 0: + print( + f"[{task}] Running iteration {i} of {tries} with {matrix_dim}..." + ) + jax.devices() + with jax.profiler.StepTraceAnnotation(task, step_num=i): + with jax.named_scope(f"{MARKER}_{i}"): + compute_func(*data_args) + results.append(True) + + if results: + jax.block_until_ready(results) + elif gap_strategy == "data_gen_every_iter_block_every_iter": + with jax.profiler.trace(tmp_trace_dir): + for i in range(tries): + if i % 10 == 0: + print( + f"[{task}] Running iteration {i} of {tries} with {matrix_dim}..." 
+ ) + data_args = data_generator() + jax.devices() + with jax.profiler.StepTraceAnnotation(task, step_num=i): + with jax.named_scope(f"{MARKER}_{i}"): + result = compute_func(*data_args) + jax.block_until_ready(result) + else: + raise ValueError(f"Unknown gap strategy: {gap_strategy}") + trace = get_trace(tmp_trace_dir) - if trace_full_dir != tmp_trace_dir: - # Upload the traces to desired location - upload_to_storage(trace_dir=trace_full_dir, local_file=tmp_trace_dir) - return multiple_iteration_get_metrics_from_trace(trace) + if trace_full_dir != tmp_trace_dir: + # Upload the traces to desired location + upload_to_storage(trace_dir=trace_full_dir, local_file=tmp_trace_dir) + return multiple_iteration_get_metrics_from_trace(trace) def clear_jax_memory(): @@ -172,7 +176,9 @@ def multiple_iteration_timeit_from_trace( with jax.profiler.trace(tmp_trace_dir): for i in range(tries): if i % 10 == 0: - print(f"[{task}] Running iteration {i} of {tries} with {matrix_dim}...") + print( + f"[{task}] Running iteration {i} of {tries} with {matrix_dim}..." 
+ ) data_args = data_generator() jax.devices() @@ -192,7 +198,9 @@ def multiple_iteration_timeit_from_trace( return multiple_iteration_get_metrics_from_trace(trace, task) -def multiple_iteration_get_metrics_from_trace(trace: dict[str, Any], task: str = None) -> list[float]: +def multiple_iteration_get_metrics_from_trace( + trace: dict[str, Any], task: str = None +) -> list[float]: marker_done_events = [] for event in trace["traceEvents"]: args = event.get("args", {}) @@ -211,7 +219,7 @@ def multiple_iteration_get_metrics_from_trace(trace: dict[str, Any], task: str = event_matcher = re.compile(task) if "traceEvents" not in trace: - raise KeyError("Key 'traceEvents' not found in trace.") + raise KeyError("Key 'traceEvents' not found in trace.") events = [] for e in trace["traceEvents"]: if "name" in e and event_matcher.match(e["name"]): @@ -223,31 +231,38 @@ def multiple_iteration_get_metrics_from_trace(trace: dict[str, Any], task: str = durations_ms = [] for e in events_from_min_pid: if e.get("args", {}).get("device_duration_ps"): - durations_ms.append(float(e["args"]["device_duration_ps"]) / 1e9) + durations_ms.append( + float(e["args"]["device_duration_ps"]) / 1e9 + ) elif "dur" in e: durations_ms.append(float(e["dur"]) / 1e3) if not durations_ms and events_from_min_pid: - print("Warning: No event duration found in legacy_get_metrics_from_trace_tpu.") + print( + "Warning: No event duration found in legacy_get_metrics_from_trace_tpu." 
+ ) return durations_ms min_pid = min([e["pid"] for e in marker_done_events]) events_from_min_pid = [e for e in marker_done_events if e["pid"] == min_pid] durations_ms = [ - float(e["args"]["device_duration_ps"]) / 1e9 for e in events_from_min_pid + float(e["args"]["device_duration_ps"]) / 1e9 + for e in events_from_min_pid ] print(f"Collected {len(durations_ms)} events from trace for pid {min_pid}.") print(durations_ms) return durations_ms + def iteration_timeit_from_trace( compute_func: Callable, data_generator: Callable, - matrix_dim: str=None, - tries: int=10, + matrix_dim: str = None, + tries: int = 10, task: str = None, trace_dir: str = None, - event_name_str_list: list[str] = None) -> list[float]: + event_name_str_list: list[str] = None, +) -> list[float]: """ Time a function with jax.profiler and get the run time from the trace. """ @@ -279,8 +294,8 @@ def iteration_timeit_from_trace( # Upload the traces to desired location upload_to_storage(trace_dir=trace_full_dir, local_file=tmp_trace_dir) return iteration_get_metrics_from_trace( - trace=trace, - event_name_str_list=event_name_str_list) + trace=trace, event_name_str_list=event_name_str_list + ) def iteration_get_metrics_from_trace( @@ -388,9 +403,12 @@ def iteration_get_event_metrics_from_trace( events = events_by_pid[pid] # Collect the durarion_ms for each run - durations_ms_lists.append([ - float(e["args"].get("device_duration_ps", 0)) / 1e9 for e in events - ]) + durations_ms_lists.append( + [ + float(e["args"].get("device_duration_ps", 0)) / 1e9 + for e in events + ] + ) # 3. 
Print summary from the first device and return print(f"Average Execution time: {np.mean(durations_ms_lists[0]):.6f} ms") @@ -406,7 +424,7 @@ def iteration_timeit( warmup_tries: int = 10, tries: int = 10, task: str = None, - trace_dir: str = None + trace_dir: str = None, ) -> list[float]: """ Simple utility to time a function, ensuring no cache hits @@ -422,7 +440,7 @@ def iteration_timeit( """ assert task is not None print(f"[{task}] Running warmup loop with {warmup_tries} tries...") - result = None # To hold the last result for block_until_ready + result = None # To hold the last result for block_until_ready for _ in range(warmup_tries): # 1. Generate new data for each iteration data_args = data_generator() @@ -445,22 +463,23 @@ def iteration_timeit( if trace_dir is not None: if task == "rmsnorm": - # If the task is RMSNorm, we specifically target "copy-done" events. - # This is often done to capture the time of the asynchronous memory transfer - # needed for the normalization layer's input data. + # If the task is RMSNorm, we specifically target "copy-done" events. + # This is often done to capture the time of the asynchronous memory transfer + # needed for the normalization layer's input data. event_name_str_list = ["copy-done"] else: - # For all other tasks, use an empty list. + # For all other tasks, use an empty list. 
event_name_str_list = [] return iteration_timeit_from_trace( - compute_func, - data_generator, - matrix_dim=matrix_dim, - tries=tries, - task=task, - trace_dir=trace_dir, - event_name_str_list=event_name_str_list) + compute_func, + data_generator, + matrix_dim=matrix_dim, + tries=tries, + task=task, + trace_dir=trace_dir, + event_name_str_list=event_name_str_list, + ) outcomes_ms = [] print(f"[{task}] Running measurement loop with {tries} tries...") @@ -483,6 +502,7 @@ def iteration_timeit( outcomes_ms.append(1000 * (e_time - s_time).total_seconds()) return outcomes_ms + def simple_timeit( f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None ) -> float: @@ -491,7 +511,12 @@ def simple_timeit( if trace_dir: return timeit_from_trace( - f, *args, matrix_dim=matrix_dim, tries=tries, task=task, trace_dir=trace_dir + f, + *args, + matrix_dim=matrix_dim, + tries=tries, + task=task, + trace_dir=trace_dir, ) outcomes_ms = [] @@ -512,7 +537,9 @@ def get_trace(log_dir: str) -> dict[str, Any]: A trace object in JSON format. """ # Navigate to the folder with the latest trace dump to find `trace.json.jz` - trace_folders = (pathlib.Path(log_dir).absolute() / "plugins" / "profile").iterdir() + trace_folders = ( + pathlib.Path(log_dir).absolute() / "plugins" / "profile" + ).iterdir() latest_trace_folder = max(trace_folders, key=os.path.getmtime) trace_jsons = latest_trace_folder.glob("*.trace.json.gz") try: @@ -608,13 +635,18 @@ def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]: events_by_run_id = defaultdict(list) for e in events: - run_id = e["args"]["run_id"] if "args" in e and "run_id" in e["args"] else "0" + run_id = ( + e["args"]["run_id"] + if "args" in e and "run_id" in e["args"] + else "0" + ) events_by_run_id[run_id].append(e) durations_ms = [] try: # Duration is in us. 
durations_ms = [ - max([e["dur"] for e in es]) / 1e3 for run_id, es in events_by_run_id.items() + max([e["dur"] for e in es]) / 1e3 + for run_id, es in events_by_run_id.items() ] except KeyError: print("KeyError: Key 'dur' not found in the event object") @@ -638,10 +670,13 @@ def get_metrics_from_trace_tpu(trace: dict[str, Any], task: str) -> list[float]: events_from_min_pid = [e for e in events if e["pid"] == min_pid] try: durations_ms = [ - float(e["args"]["device_duration_ps"]) / 1e9 for e in events_from_min_pid + float(e["args"]["device_duration_ps"]) / 1e9 + for e in events_from_min_pid ] except KeyError: - print("KeyError: Key 'device_duration_ps' not found in the event object") + print( + "KeyError: Key 'device_duration_ps' not found in the event object" + ) raise return durations_ms @@ -658,7 +693,13 @@ def is_local_directory_path(dir: str) -> bool: def timeit_from_trace( - f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None, event_name_str_list: list[str] = None + f, + *args, + matrix_dim=None, + tries=10, + task=None, + trace_dir=None, + event_name_str_list: list[str] = None, ) -> float: """ Time a function with jax.profiler and get the run time from the trace. 
@@ -693,7 +734,9 @@ def timeit_from_trace( upload_to_storage(trace_dir=trace_full_dir, local_file=tmp_trace_dir) if event_name_str_list is not None: - return iteration_get_event_metrics_from_trace(trace, event_name_str_list=event_name_str_list) + return iteration_get_event_metrics_from_trace( + trace, event_name_str_list=event_name_str_list + ) return iteration_get_metrics_from_trace(trace) @@ -821,7 +864,9 @@ def rename_xla_dump( serialized_benchmark_param = "_".join( f"{key}_{value}" for key, value in benchmark_param.items() ) - anchor_pattern = os.path.join(tmp_xla_dump_dir, "*jit_f*before_optimizations*.txt") + anchor_pattern = os.path.join( + tmp_xla_dump_dir, "*jit_f*before_optimizations*.txt" + ) matching_anchor_files = glob.glob(anchor_pattern) if not matching_anchor_files: @@ -860,7 +905,9 @@ def rename_xla_dump( return new_base_name = f"{benchmark_name}_{serialized_benchmark_param}" - after_optimizations_path = input_shape = output_shape = replica_groups = first_replica_group = None + after_optimizations_path = input_shape = output_shape = replica_groups = ( + first_replica_group + ) = None for original_filepath in all_related_files: original_filename = os.path.basename(original_filepath) @@ -902,7 +949,9 @@ def rename_xla_dump( f"An unexpected error occurred while copy '{original_filepath}': {e}" ) else: - upload_to_storage(trace_dir=new_filepath, local_file=original_filepath) + upload_to_storage( + trace_dir=new_filepath, local_file=original_filepath + ) print(f"The XLA dump is stored in {dest_xla_dump_dir}") if after_optimizations_path: input_shape, output_shape, replica_groups, first_replica_group = ( @@ -913,15 +962,20 @@ def rename_xla_dump( "No files found with 'after_optimizations.txt' suffix. " "Please check the XLA dump directory." 
) - return json.dumps({ - "after_optimizations_path": after_optimizations_path, - "hlo_input_shape": input_shape, - "hlo_output_shape": output_shape, - "hlo_replica_groups": replica_groups, - "hlo_first_replica_group": first_replica_group, - }) - -def extract_hlo_features_from_file(hlo_file_path: str) -> Tuple[str | None, str | None, str | None, list[int] | None]: + return json.dumps( + { + "after_optimizations_path": after_optimizations_path, + "hlo_input_shape": input_shape, + "hlo_output_shape": output_shape, + "hlo_replica_groups": replica_groups, + "hlo_first_replica_group": first_replica_group, + } + ) + + +def extract_hlo_features_from_file( + hlo_file_path: str, +) -> Tuple[str | None, str | None, str | None, list[int] | None]: """ Extracts input shape, output shape, and replica groups from an HLO file. @@ -946,7 +1000,9 @@ def extract_hlo_features_from_file(hlo_file_path: str) -> Tuple[str | None, str # Extract input/output shapes from HloModule line # Example: HloModule jit_f, ..., entry_computation_layout={(f32[32,128]{...})->f32[128,128]{...}} - layout_match = re.search(r"entry_computation_layout={\((.*?)\)->(.*?)}", content) + layout_match = re.search( + r"entry_computation_layout={\((.*?)\)->(.*?)}", content + ) if layout_match: input_shape = layout_match.group(1) output_shape = layout_match.group(2) @@ -954,24 +1010,30 @@ def extract_hlo_features_from_file(hlo_file_path: str) -> Tuple[str | None, str input_shape = re.sub(r"{.*}", "", input_shape) output_shape = re.sub(r"{.*}", "", output_shape) else: - print(f"Could not find entry_computation_layout in {hlo_file_path} to extract shapes.") + print( + f"Could not find entry_computation_layout in {hlo_file_path} to extract shapes." + ) # Extract replica groups # Example: replica_groups={{0,1},{2,3}}, dimensions... 
- rg_match = re.search(r"replica_groups=({{[0-9,]+(?:},{[0-9,]+)*}})", content, re.DOTALL) + rg_match = re.search( + r"replica_groups=({{[0-9,]+(?:},{[0-9,]+)*}})", content, re.DOTALL + ) if rg_match: replica_groups_str = rg_match.group(1) try: content_rg = replica_groups_str[2:-2] - first_group_str = content_rg.split('},{')[0] - first_replica_group = [int(x) for x in first_group_str.split(',')] + first_group_str = content_rg.split("},{")[0] + first_replica_group = [int(x) for x in first_group_str.split(",")] except Exception as e: - print(f'Could not parse replica_groups in hlo_text: {e}') + print(f"Could not parse replica_groups in hlo_text: {e}") first_replica_group = None else: print(f"Could not find replica_groups in {hlo_file_path}.") return input_shape, output_shape, replica_groups_str, first_replica_group + + def get_lhs_named_shading(mesh, strategy: ShardingStrategy): match strategy: case ShardingStrategy.NO_SHARDING: @@ -1056,7 +1118,9 @@ def handle_per_device_based_on_sharding(value, strategy: ShardingStrategy): return value // 2 -def handle_all_devices_based_on_sharding(value: int, strategy: ShardingStrategy): +def handle_all_devices_based_on_sharding( + value: int, strategy: ShardingStrategy +): match strategy: case ShardingStrategy.NO_SHARDING: return value * jax.device_count() @@ -1084,11 +1148,15 @@ def create_mesh(strategy: ShardingStrategy) -> Mesh: or strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N ): num_devices = jax.device_count() - assert num_devices % 2 == 0, "Total devices must be divisible by 2 (chip size)" + assert ( + num_devices % 2 == 0 + ), "Total devices must be divisible by 2 (chip size)" num_chips = num_devices // 2 mesh_shape = (num_chips, 2) mesh_axes = ("chip", "device") - mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(mesh_shape), mesh_axes) + mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape(mesh_shape), mesh_axes + ) else: mesh = Mesh(np.array(jax.devices()), axis_names="device") return mesh @@ 
-1130,9 +1198,12 @@ def unified_flops_metrics( metadata = get_metrics_helper(params) metrics = {} - average_time_s_list = [average_time_ms / 10**3 for average_time_ms in time_ms_list] + average_time_s_list = [ + average_time_ms / 10**3 for average_time_ms in time_ms_list + ] tflops_per_sec_list = [ - total_flops / average_time_s / 10**12 for average_time_s in average_time_s_list + total_flops / average_time_s / 10**12 + for average_time_s in average_time_s_list ] tflops_per_sec_all_devices = [ total_flops_all_devices / average_time_s / 10**12 @@ -1146,7 +1217,8 @@ def unified_flops_metrics( metrics_list=time_ms_list, metrics_name="step_time_ms" ) tflops_per_sec_statistics = MetricsStatistics( - metrics_list=tflops_per_sec_list, metrics_name="tflops_per_sec_pre_device" + metrics_list=tflops_per_sec_list, + metrics_name="tflops_per_sec_pre_device", ) tflops_per_sec_all_devices_statistics = MetricsStatistics( metrics_list=tflops_per_sec_all_devices, metrics_name="tflops_per_sec" @@ -1182,7 +1254,9 @@ def unified_flops_metrics( metrics.update(tflops_per_sec_statistics.serialize_statistics()) metrics.update(tflops_per_sec_all_devices_statistics.serialize_statistics()) metrics.update(mfu_statistics.serialize_statistics()) - metrics = {key: value for key, value in metrics.items() if value is not None} + metrics = { + key: value for key, value in metrics.items() if value is not None + } return metadata, metrics @@ -1201,9 +1275,12 @@ def unified_bytes_metrics( metadata = get_metrics_helper(params) metrics = {} - average_time_s_list = [average_time_ms / 10**3 for average_time_ms in time_ms_list] + average_time_s_list = [ + average_time_ms / 10**3 for average_time_ms in time_ms_list + ] gigabytes_per_sec_list = [ - total_bytes / average_time_s / 10**9 for average_time_s in average_time_s_list + total_bytes / average_time_s / 10**9 + for average_time_s in average_time_s_list ] digabytes_per_sec_all_devices = [ total_bytes_all_devices / average_time_s / 10**9 @@ -1213,10 
+1290,12 @@ def unified_bytes_metrics( metrics_list=time_ms_list, metrics_name="step_time_ms" ) gigabytes_per_sec_statistics = MetricsStatistics( - metrics_list=gigabytes_per_sec_list, metrics_name="Gbytes_per_sec_per_device" + metrics_list=gigabytes_per_sec_list, + metrics_name="Gbytes_per_sec_per_device", ) gigabytes_per_sec_all_devices_statistics = MetricsStatistics( - metrics_list=digabytes_per_sec_all_devices, metrics_name="Gbytes_per_sec" + metrics_list=digabytes_per_sec_all_devices, + metrics_name="Gbytes_per_sec", ) type_prefix = "" # Gather the metrics to report. @@ -1249,10 +1328,15 @@ def unified_bytes_metrics( ) metrics.update(average_time_ms_statistics.serialize_statistics()) metrics.update(gigabytes_per_sec_statistics.serialize_statistics()) - metrics.update(gigabytes_per_sec_all_devices_statistics.serialize_statistics()) - metrics = {key: value for key, value in metrics.items() if value is not None} + metrics.update( + gigabytes_per_sec_all_devices_statistics.serialize_statistics() + ) + metrics = { + key: value for key, value in metrics.items() if value is not None + } return metadata, metrics + def str_to_dtype(dtype_str: str) -> jnp.dtype: """Converts a string identifier to a JAX numpy dtype.""" if dtype_str.lower() == "fp8": @@ -1266,6 +1350,7 @@ def str_to_dtype(dtype_str: str) -> jnp.dtype: else: raise ValueError(f"Unsupported dtype string: {dtype_str}") + def get_peak_flops_multiplier(in_dtype_str: str) -> float: """ Returns the peak FLOPS multiplier relative to the baseline @@ -1284,4 +1369,6 @@ def get_peak_flops_multiplier(in_dtype_str: str) -> float: # FP32 is 4x slower than FP8 peak return 0.25 else: - raise RuntimeError(f"{in_dtype_lower} is not supported for setting peak_flops_multiplier.") + raise RuntimeError( + f"{in_dtype_lower} is not supported for setting peak_flops_multiplier." 
+ ) diff --git a/Ironwood/src/common.py b/Ironwood/src/common.py index fd11258..df7ab4c 100644 --- a/Ironwood/src/common.py +++ b/Ironwood/src/common.py @@ -1 +1,3 @@ +"""Common constants for the microbenchmarks.""" + MARKER = "!!MARKER!!" diff --git a/Ironwood/src/run_benchmark.py b/Ironwood/src/run_benchmark.py index b44aab7..471bee1 100644 --- a/Ironwood/src/run_benchmark.py +++ b/Ironwood/src/run_benchmark.py @@ -12,7 +12,11 @@ import random import string from typing import Any, Callable, Dict, List, Tuple -from benchmark_utils import maybe_write_metrics_file, rename_xla_dump, MetricsStatistics +from benchmark_utils import ( + maybe_write_metrics_file, + rename_xla_dump, + MetricsStatistics, +) import jax import yaml import ray @@ -35,7 +39,9 @@ MATMUL_BENCHMARK_MAP = { "naive_matmul": "benchmark_matmul.naive_matmul", "single_host_naive_matmul": "benchmark_matmul.single_host_naive_matmul", - "multilayer_collective_matmul": ("benchmark_matmul.multilayer_collective_matmul"), + "multilayer_collective_matmul": ( + "benchmark_matmul.multilayer_collective_matmul" + ), "collective_matmul_one_direction": ( "benchmark_matmul.collective_matmul_one_direction" ), @@ -47,7 +53,9 @@ "numpy_convolve": "benchmark_convolution.numpy_convolve", "scipy_signal_convolve": "benchmark_convolution.scipy_signal_convolve", "scipy_signal_convolve2d": "benchmark_convolution.scipy_signal_convolve2d", - "lax_conv_general_dilated": ("benchmark_convolution.lax_conv_general_dilated"), + "lax_conv_general_dilated": ( + "benchmark_convolution.lax_conv_general_dilated" + ), } ATTENTION_BENCHMARK_MAP = { "tokamax_splash_attention": "benchmark_attention.tokamax_splash_attention_benchmark", @@ -136,7 +144,9 @@ def get_benchmark_functions( ) -> Tuple[Callable[..., Any], Callable[..., Any]]: """Dynamically load the benchmark function and its calculate_metrics function from the predefined map.""" if benchmark_name not in BENCHMARK_MAP: - raise ValueError(f"Benchmark {benchmark_name} is not defined in 
the map.") + raise ValueError( + f"Benchmark {benchmark_name} is not defined in the map." + ) module_path, func_name = BENCHMARK_MAP[benchmark_name].rsplit(".", 1) @@ -155,7 +165,9 @@ def get_benchmark_functions( # Get the calculate_metrics function try: - calculate_metrics_func = getattr(module, f"{func_name}_calculate_metrics") + calculate_metrics_func = getattr( + module, f"{func_name}_calculate_metrics" + ) except AttributeError: raise ValueError( f"Calculate metrics function for {benchmark_name} not found." @@ -240,14 +252,18 @@ def generate_benchmark_params_sweeping( # Generate all combinations using itertools.product combinations = [ dict(zip(param_names, values)) - for values in itertools.product(*(param_sets[name] for name in param_names)) + for values in itertools.product( + *(param_sets[name] for name in param_names) + ) ] generated_params += combinations return generated_params -def write_to_csv(csv_path: str, calculate_metrics_results: List[Dict[str, Any]]): +def write_to_csv( + csv_path: str, calculate_metrics_results: List[Dict[str, Any]] +): """Writes benchmark metrics to a CSV file. 
This function takes a list of dictionaries, where each dictionary contains @@ -323,7 +339,9 @@ def run_single_benchmark(benchmark_config: Dict[str, Any], output_path: str): benchmark_params = benchmark_config.get("benchmark_params", []) benchmark_sweep_params = benchmark_config.get("benchmark_sweep_params", {}) if benchmark_sweep_params: - benchmark_params += generate_benchmark_params_sweeping(benchmark_sweep_params) + benchmark_params += generate_benchmark_params_sweeping( + benchmark_sweep_params + ) csv_path = benchmark_config.get("csv_path") trace_dir = benchmark_config.get("trace_dir") xlml_metrics_dir = benchmark_config.get("xlml_metrics_dir") @@ -343,8 +361,10 @@ def run_single_benchmark(benchmark_config: Dict[str, Any], output_path: str): raise ValueError("Each benchmark must have a 'benchmark_name'.") # Get the benchmark function - - benchmark_func, calculate_metrics_func = get_benchmark_functions(benchmark_name) + + benchmark_func, calculate_metrics_func = get_benchmark_functions( + benchmark_name + ) print(f"\n{'=' * 30}Starting benchmark '{benchmark_name}'{'=' * 30}\n") @@ -353,9 +373,12 @@ def run_single_benchmark(benchmark_config: Dict[str, Any], output_path: str): for id, benchmark_param in enumerate(benchmark_params): original_benchmark_param = copy.deepcopy(benchmark_param) benchmark_param = preprocess_benchmark_param( - benchmark_param, trace_dir=os.path.join(trace_dir, f"benchmark_{id}") + benchmark_param, + trace_dir=os.path.join(trace_dir, f"benchmark_{id}"), + ) + print( + f"Running benchmark: {benchmark_name} with params: {benchmark_param}" ) - print(f"Running benchmark: {benchmark_name} with params: {benchmark_param}") test_start_time = ( datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + "Z" ) # "Z" indicates UTC @@ -408,10 +431,9 @@ def run_single_benchmark(benchmark_config: Dict[str, Any], output_path: str): test_end_time, ) # Post process the xla dump - calculate_metrics_results.append({ - "metadata": metadata, - "metrics": 
metrics - }) + calculate_metrics_results.append( + {"metadata": metadata, "metrics": metrics} + ) # Dump metrics to file. if csv_path: @@ -472,7 +494,9 @@ def run_benchmark_multithreaded(benchmark_config, output_path): benchmark_params = benchmark_config.get("benchmark_params", []) benchmark_sweep_params = benchmark_config.get("benchmark_sweep_params", {}) if benchmark_sweep_params: - benchmark_params += generate_benchmark_params_sweeping(benchmark_sweep_params) + benchmark_params += generate_benchmark_params_sweeping( + benchmark_sweep_params + ) csv_path = benchmark_config.get("csv_path") if not benchmark_name: raise ValueError("Each benchmark must have a 'benchmark_name'.") @@ -487,7 +511,9 @@ def run_benchmark_multithreaded(benchmark_config, output_path): param["num_runs"] = global_num_runs # Get the benchmark function - benchmark_func, calculate_metrics_func = get_benchmark_functions(benchmark_name) + benchmark_func, calculate_metrics_func = get_benchmark_functions( + benchmark_name + ) print(f"\n{'=' * 30}Starting benchmark '{benchmark_name}'{'=' * 30}\n") @@ -521,7 +547,9 @@ def run_benchmark_multithreaded(benchmark_config, output_path): benchmark_param = future_to_param[ future ] # Retrieve the corresponding benchmark_param - benchmark_results = future.result() # Get the result from the future + benchmark_results = ( + future.result() + ) # Get the result from the future # Filter benchmark_results to include only keys present in calculate_metrics_func calculate_metrics_params = inspect.signature( @@ -537,7 +565,9 @@ def run_benchmark_multithreaded(benchmark_config, output_path): metadata, metrics = calculate_metrics_func( **benchmark_param, **filtered_benchmark_results ) - calculate_metrics_results.append({"metadata": metadata, "metrics": metrics}) + calculate_metrics_results.append( + {"metadata": metadata, "metrics": metrics} + ) if csv_path: os.makedirs(csv_path, exist_ok=True) diff --git a/requirements.txt b/requirements.txt index 3ae246a..bb971b1 
100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,7 @@ qwix@git+https://github.com/google/qwix.git tokamax tune-jax immutabledict +pylint +black +isort +pre-commit \ No newline at end of file diff --git a/src/all_gather.py b/src/all_gather.py index 68c08fd..0270d60 100644 --- a/src/all_gather.py +++ b/src/all_gather.py @@ -11,7 +11,6 @@ import jax import numpy as np - TRACE_BASE_DIR = None METRICS_JSONL_DIR = None @@ -19,144 +18,154 @@ def all_gather(matrix_dim): - """Performs an all_gather operation and calculates the achieved bandwidth.""" - dtype = jax.numpy.bfloat16 - matrix = jax.numpy.arange(matrix_dim * matrix_dim, dtype=dtype).reshape( - matrix_dim, matrix_dim - ) - - selected_devices = jax.devices() - mesh = jax.sharding.Mesh(selected_devices, "axis") - sharded_sharding = jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec("axis") - ) - unsharded_sharding = jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec(None) - ) - - arrays = [ - jax.device_put(matrix[index], d) - for d, index in sharded_sharding.addressable_devices_indices_map( - matrix.shape - ).items() - ] - - matrix = jax.make_array_from_single_device_arrays( - matrix.shape, sharded_sharding, arrays - ) - - @functools.partial(jax.jit, out_shardings=unsharded_sharding) - def unshard_array(input_matrix): - return input_matrix - - average_time_ms = simple_timeit(unshard_array, matrix, task="unshard_array") - - matrix_size_gbyte = matrix.size * dtype.dtype.itemsize / 1e9 - number_of_devices = len(jax.devices()) - sharded_matrix_size_gbyte = matrix_size_gbyte / number_of_devices - - # Calculate achieved bandwidth - achieved_bandwidth_gbyte_s = ( - sharded_matrix_size_gbyte - * (number_of_devices - 1) - / (average_time_ms / 1e3) - ) - matrix_size_gbyte_to_bandwidth[matrix_size_gbyte] = achieved_bandwidth_gbyte_s - print( - f"Matrix size: {matrix_dim}x{matrix_dim}, {dtype=}, " - f"{matrix_size_gbyte=}, {achieved_bandwidth_gbyte_s=}" - ) + """Performs an all_gather 
operation and calculates the achieved bandwidth.""" + dtype = jax.numpy.bfloat16 + matrix = jax.numpy.arange(matrix_dim * matrix_dim, dtype=dtype).reshape( + matrix_dim, matrix_dim + ) + selected_devices = jax.devices() + mesh = jax.sharding.Mesh(selected_devices, "axis") + sharded_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("axis") + ) + unsharded_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec(None) + ) -# TODO(qinyiyan): Merge common code with all_reduce.py. -def run_benchmark(): - """Runs the all_gather benchmark and saves traces.""" + arrays = [ + jax.device_put(matrix[index], d) + for d, index in sharded_sharding.addressable_devices_indices_map( + matrix.shape + ).items() + ] - trace_dir = None - if TRACE_BASE_DIR: - trace_name = "t_all_gather_" + "".join( - random.choices(string.ascii_uppercase + string.digits, k=10) + matrix = jax.make_array_from_single_device_arrays( + matrix.shape, sharded_sharding, arrays + ) + + @functools.partial(jax.jit, out_shardings=unsharded_sharding) + def unshard_array(input_matrix): + return input_matrix + + average_time_ms = simple_timeit(unshard_array, matrix, task="unshard_array") + + matrix_size_gbyte = matrix.size * dtype.dtype.itemsize / 1e9 + number_of_devices = len(jax.devices()) + sharded_matrix_size_gbyte = matrix_size_gbyte / number_of_devices + + # Calculate achieved bandwidth + achieved_bandwidth_gbyte_s = ( + sharded_matrix_size_gbyte + * (number_of_devices - 1) + / (average_time_ms / 1e3) + ) + matrix_size_gbyte_to_bandwidth[matrix_size_gbyte] = ( + achieved_bandwidth_gbyte_s ) - trace_dir = f"{TRACE_BASE_DIR}/{trace_name}" - jax.profiler.start_trace(str(trace_dir)) - - test_start_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S") - matrix_size = 1024 - try: - while matrix_size <= 30000: - all_gather(matrix_size) - matrix_size += 1024 - except MemoryError: print( - "MemoryError: Failed to create or process matrix of size " - f"{matrix_size} x 
{matrix_size}.\n" + f"Matrix size: {matrix_dim}x{matrix_dim}, {dtype=}, " + f"{matrix_size_gbyte=}, {achieved_bandwidth_gbyte_s=}" + ) + + +# TODO(qinyiyan): Merge common code with all_reduce.py. +def run_benchmark(): + """Runs the all_gather benchmark and saves traces.""" + + trace_dir = None + if TRACE_BASE_DIR: + trace_name = "t_all_gather_" + "".join( + random.choices(string.ascii_uppercase + string.digits, k=10) + ) + trace_dir = f"{TRACE_BASE_DIR}/{trace_name}" + jax.profiler.start_trace(str(trace_dir)) + + test_start_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + matrix_size = 1024 + try: + while matrix_size <= 30000: + all_gather(matrix_size) + matrix_size += 1024 + except MemoryError: + print( + "MemoryError: Failed to create or process matrix of size " + f"{matrix_size} x {matrix_size}.\n" + ) + except Exception as e: # pylint: disable=broad-exception-caught + print( + f"Exception: {e} occurred at size {matrix_size} x {matrix_size}.\n" + ) + + if TRACE_BASE_DIR: + jax.profiler.stop_trace() + print(f"Trace saved to {trace_dir}") + + test_end_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + + # Calculate and write metrics + max_achieved_bandwidth_gbyte_s = max( + matrix_size_gbyte_to_bandwidth.values() ) - except Exception as e: # pylint: disable=broad-exception-caught - print(f"Exception: {e} occurred at size {matrix_size} x {matrix_size}.\n") - - if TRACE_BASE_DIR: - jax.profiler.stop_trace() - print(f"Trace saved to {trace_dir}") - - test_end_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S") - - # Calculate and write metrics - max_achieved_bandwidth_gbyte_s = max(matrix_size_gbyte_to_bandwidth.values()) - median_achieved_bandwidth_gbyte_s = np.percentile( - list(matrix_size_gbyte_to_bandwidth.values()), 50 - ) - p90_achieved_bandwidth_gbyte_s = np.percentile( - list(matrix_size_gbyte_to_bandwidth.values()), 90 - ) - - metrics = { - "max_achieved_bandwidth_gbyte_s": max_achieved_bandwidth_gbyte_s, - 
"median_achieved_bandwidth_gbyte_s": median_achieved_bandwidth_gbyte_s, - "p90_achieved_bandwidth_gbyte_s": p90_achieved_bandwidth_gbyte_s, - } - if METRICS_JSONL_DIR: - maybe_write_metrics_file( - METRICS_JSONL_DIR, metrics, "all_gather", test_start_time, test_end_time + median_achieved_bandwidth_gbyte_s = np.percentile( + list(matrix_size_gbyte_to_bandwidth.values()), 50 ) + p90_achieved_bandwidth_gbyte_s = np.percentile( + list(matrix_size_gbyte_to_bandwidth.values()), 90 + ) + + metrics = { + "max_achieved_bandwidth_gbyte_s": max_achieved_bandwidth_gbyte_s, + "median_achieved_bandwidth_gbyte_s": median_achieved_bandwidth_gbyte_s, + "p90_achieved_bandwidth_gbyte_s": p90_achieved_bandwidth_gbyte_s, + } + if METRICS_JSONL_DIR: + maybe_write_metrics_file( + METRICS_JSONL_DIR, + metrics, + "all_gather", + test_start_time, + test_end_time, + ) def main(): - """Parses arguments and runs the benchmark.""" - parser = argparse.ArgumentParser( - description=( - "A script to analyze the benchmark results and dump the result" - " to a JSONL file." - ), - formatter_class=argparse.RawTextHelpFormatter, - ) - - parser.add_argument( - "--trace_dir", - type=str, - help=( - "Set the output directory, such as" - " `--trace_dir=/tmp/microbenchmark/outputs`" - ), - ) - parser.add_argument( - "--metrics_jsonl_dir", - type=str, - help=( - "The directory to generate the metrics JSONL file, such as" - " `--metrics_jsonl_dir=/tmp/microbenchmark/outputs/`" - ), - ) - - args = parser.parse_args() - - global TRACE_BASE_DIR, METRICS_JSONL_DIR - if args.trace_dir: - TRACE_BASE_DIR = args.trace_dir - if args.metrics_jsonl_dir: - METRICS_JSONL_DIR = args.metrics_jsonl_dir - - run_benchmark() + """Parses arguments and runs the benchmark.""" + parser = argparse.ArgumentParser( + description=( + "A script to analyze the benchmark results and dump the result" + " to a JSONL file." 
+ ), + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--trace_dir", + type=str, + help=( + "Set the output directory, such as" + " `--trace_dir=/tmp/microbenchmark/outputs`" + ), + ) + parser.add_argument( + "--metrics_jsonl_dir", + type=str, + help=( + "The directory to generate the metrics JSONL file, such as" + " `--metrics_jsonl_dir=/tmp/microbenchmark/outputs/`" + ), + ) + + args = parser.parse_args() + + global TRACE_BASE_DIR, METRICS_JSONL_DIR + if args.trace_dir: + TRACE_BASE_DIR = args.trace_dir + if args.metrics_jsonl_dir: + METRICS_JSONL_DIR = args.metrics_jsonl_dir + + run_benchmark() if __name__ == "__main__": - main() + main() diff --git a/src/all_reduce.py b/src/all_reduce.py index 5170e82..e12f536 100644 --- a/src/all_reduce.py +++ b/src/all_reduce.py @@ -12,7 +12,6 @@ import jax.numpy as jnp import numpy as np - TRACE_BASE_DIR = None METRICS_JSONL_DIR = None @@ -20,137 +19,149 @@ def all_reduce_sum(matrix_dim): - """Calculates the sum of a matrix using all_reduce.""" - dtype = jax.numpy.bfloat16 - matrix = jax.numpy.arange( - jax.local_device_count() * matrix_dim * matrix_dim, dtype=dtype - ).reshape(jax.local_device_count(), matrix_dim, matrix_dim) - - @functools.partial(jax.pmap, axis_name="devices") - def parallel_sum(x): - return jax.lax.psum(x, axis_name="devices") - - # Preload the sharded data to devices. This is to avoid the data transfer - # time in the all_reduce operation. 
- matrix_split = jnp.array_split(matrix, jax.local_device_count(), axis=0) - matrix_distributed = jax.device_put_sharded(matrix_split, jax.local_devices()) - - average_time_ms = simple_timeit( - parallel_sum, matrix_distributed, task="parallel_sum" - ) - - print(f"Average time milliseconds: {average_time_ms:.2f}") - - matrix_size_gbyte = matrix.size * dtype.dtype.itemsize / 1e9 - shard_size_gbyte = ( - matrix.size * dtype.dtype.itemsize / 1e9 / jax.local_device_count() - ) - number_of_devices = len(jax.devices()) - # Send the data to all other (N-1) devices. - achieved_bandwidth_gbyte_s = ( - shard_size_gbyte - * (number_of_devices - 1) - / number_of_devices - / (average_time_ms / 1e3) - ) - matrix_size_gbyte_to_bandwidth[matrix_size_gbyte] = achieved_bandwidth_gbyte_s - print( - f"Matrix shape: {matrix.shape}, {dtype=}, {matrix_size_gbyte=}," - f" {achieved_bandwidth_gbyte_s=}" - ) + """Calculates the sum of a matrix using all_reduce.""" + dtype = jax.numpy.bfloat16 + matrix = jax.numpy.arange( + jax.local_device_count() * matrix_dim * matrix_dim, dtype=dtype + ).reshape(jax.local_device_count(), matrix_dim, matrix_dim) + + @functools.partial(jax.pmap, axis_name="devices") + def parallel_sum(x): + return jax.lax.psum(x, axis_name="devices") + + # Preload the sharded data to devices. This is to avoid the data transfer + # time in the all_reduce operation. + matrix_split = jnp.array_split(matrix, jax.local_device_count(), axis=0) + matrix_distributed = jax.device_put_sharded( + matrix_split, jax.local_devices() + ) + + average_time_ms = simple_timeit( + parallel_sum, matrix_distributed, task="parallel_sum" + ) + + print(f"Average time milliseconds: {average_time_ms:.2f}") + + matrix_size_gbyte = matrix.size * dtype.dtype.itemsize / 1e9 + shard_size_gbyte = ( + matrix.size * dtype.dtype.itemsize / 1e9 / jax.local_device_count() + ) + number_of_devices = len(jax.devices()) + # Send the data to all other (N-1) devices. 
+ achieved_bandwidth_gbyte_s = ( + shard_size_gbyte + * (number_of_devices - 1) + / number_of_devices + / (average_time_ms / 1e3) + ) + matrix_size_gbyte_to_bandwidth[matrix_size_gbyte] = ( + achieved_bandwidth_gbyte_s + ) + print( + f"Matrix shape: {matrix.shape}, {dtype=}, {matrix_size_gbyte=}," + f" {achieved_bandwidth_gbyte_s=}" + ) def run_benchmark(): - """Runs the all_reduce benchmark and saves traces.""" - test_start_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S") - trace_name = "t_all_reduce_sum_" + "".join( - random.choice(string.ascii_uppercase + string.digits) for _ in range(10) - ) - trace_dir = None - if TRACE_BASE_DIR: - trace_dir = f"{TRACE_BASE_DIR}/{trace_name}" - jax.profiler.start_trace(trace_dir) - - # Sweep the data size to saturate the bandwidth. - matrix_size = 1024 - while True: - try: - all_reduce_sum(matrix_size) - matrix_size += 1024 - if matrix_size > 10000: - break - except MemoryError: - print( - "MemoryError: Failed to create or process matrix of size" - f" {matrix_size} x {matrix_size}.\n" - ) - break - except Exception as e: # pylint: disable=broad-exception-caught - print(f"Exception: {e} occurred at size {matrix_size} x {matrix_size}.\n") - break - if TRACE_BASE_DIR: - jax.profiler.stop_trace() - print(f"Trace saved to {trace_dir}") - - test_end_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S") - - # Calculate and write metrics - max_achieved_bandwidth_gbyte_s = max(matrix_size_gbyte_to_bandwidth.values()) - median_achieved_bandwidth_gbyte_s = np.percentile( - list(matrix_size_gbyte_to_bandwidth.values()), 50 - ) - p90_achieved_bandwidth_gbyte_s = np.percentile( - list(matrix_size_gbyte_to_bandwidth.values()), 90 - ) - - metrics = { - "max_achieved_bandwidth_gbyte_s": max_achieved_bandwidth_gbyte_s, - "median_achieved_bandwidth_gbyte_s": median_achieved_bandwidth_gbyte_s, - "p90_achieved_bandwidth_gbyte_s": p90_achieved_bandwidth_gbyte_s, - } - if METRICS_JSONL_DIR: - maybe_write_metrics_file( - METRICS_JSONL_DIR, 
metrics, "all_reduce", test_start_time, test_end_time + """Runs the all_reduce benchmark and saves traces.""" + test_start_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + trace_name = "t_all_reduce_sum_" + "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(10) + ) + trace_dir = None + if TRACE_BASE_DIR: + trace_dir = f"{TRACE_BASE_DIR}/{trace_name}" + jax.profiler.start_trace(trace_dir) + + # Sweep the data size to saturate the bandwidth. + matrix_size = 1024 + while True: + try: + all_reduce_sum(matrix_size) + matrix_size += 1024 + if matrix_size > 10000: + break + except MemoryError: + print( + "MemoryError: Failed to create or process matrix of size" + f" {matrix_size} x {matrix_size}.\n" + ) + break + except Exception as e: # pylint: disable=broad-exception-caught + print( + f"Exception: {e} occurred at size {matrix_size} x {matrix_size}.\n" + ) + break + if TRACE_BASE_DIR: + jax.profiler.stop_trace() + print(f"Trace saved to {trace_dir}") + + test_end_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + + # Calculate and write metrics + max_achieved_bandwidth_gbyte_s = max( + matrix_size_gbyte_to_bandwidth.values() ) + median_achieved_bandwidth_gbyte_s = np.percentile( + list(matrix_size_gbyte_to_bandwidth.values()), 50 + ) + p90_achieved_bandwidth_gbyte_s = np.percentile( + list(matrix_size_gbyte_to_bandwidth.values()), 90 + ) + + metrics = { + "max_achieved_bandwidth_gbyte_s": max_achieved_bandwidth_gbyte_s, + "median_achieved_bandwidth_gbyte_s": median_achieved_bandwidth_gbyte_s, + "p90_achieved_bandwidth_gbyte_s": p90_achieved_bandwidth_gbyte_s, + } + if METRICS_JSONL_DIR: + maybe_write_metrics_file( + METRICS_JSONL_DIR, + metrics, + "all_reduce", + test_start_time, + test_end_time, + ) def main(): - """Parses arguments and runs the benchmark.""" - parser = argparse.ArgumentParser( - description=( - "A script to analyze the benchmark results and dump the result" - " to a JSONL file." 
- ), - formatter_class=argparse.RawTextHelpFormatter, - ) - - parser.add_argument( - "--trace_dir", - type=str, - help=( - "Set the output directory, such as" - " `--trace_dir=/tmp/microbenchmark/outputs`" - ), - ) - parser.add_argument( - "--metrics_jsonl_dir", - type=str, - help=( - "The directory to generate the metrics JSONL file, such as" - " `--metrics_jsonl_dir=/tmp/microbenchmark/outputs/metrics.jsonl`" - ), - ) - - args = parser.parse_args() - - global TRACE_BASE_DIR, METRICS_JSONL_DIR - if args.trace_dir: - TRACE_BASE_DIR = args.trace_dir - if args.metrics_jsonl_dir: - METRICS_JSONL_DIR = args.metrics_jsonl_dir - - run_benchmark() + """Parses arguments and runs the benchmark.""" + parser = argparse.ArgumentParser( + description=( + "A script to analyze the benchmark results and dump the result" + " to a JSONL file." + ), + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--trace_dir", + type=str, + help=( + "Set the output directory, such as" + " `--trace_dir=/tmp/microbenchmark/outputs`" + ), + ) + parser.add_argument( + "--metrics_jsonl_dir", + type=str, + help=( + "The directory to generate the metrics JSONL file, such as" + " `--metrics_jsonl_dir=/tmp/microbenchmark/outputs/metrics.jsonl`" + ), + ) + + args = parser.parse_args() + + global TRACE_BASE_DIR, METRICS_JSONL_DIR + if args.trace_dir: + TRACE_BASE_DIR = args.trace_dir + if args.metrics_jsonl_dir: + METRICS_JSONL_DIR = args.metrics_jsonl_dir + + run_benchmark() if __name__ == "__main__": - main() + main() diff --git a/src/benchmark_attention.py b/src/benchmark_attention.py index 1e5dbb4..a6e8bac 100644 --- a/src/benchmark_attention.py +++ b/src/benchmark_attention.py @@ -17,6 +17,7 @@ """ # pylint: disable=g-importing-member,g-bad-import-order +import keras # pylint: disable=g-bad-import-order,g-import-not-at-top from functools import partial import os from typing import Any, Dict, Tuple @@ -25,16 +26,21 @@ from flax import linen from flax import nnx import jax 
-from jax.experimental.pallas.ops.tpu import flash_attention as pallas_flash_attention -from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel -from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask +from jax.experimental.pallas.ops.tpu import ( + flash_attention as pallas_flash_attention, +) +from jax.experimental.pallas.ops.tpu.splash_attention import ( + splash_attention_kernel, +) +from jax.experimental.pallas.ops.tpu.splash_attention import ( + splash_attention_mask, +) import jax.numpy as jnp import numpy as np # pylint: disable=g-importing-member,g-bad-import-order os.environ["KERAS_BACKEND"] = "jax" -import keras # pylint: disable=g-bad-import-order,g-import-not-at-top # Tunable parameters for splash attention. # Kernel block sizes. @@ -95,7 +101,9 @@ def f(q, k, v, causal, scale): scale_factor = 1.0 if scale: scale_factor = 1.0 / jnp.sqrt(k_kv_size) - weights_unnormalized = jax.numpy.einsum("BHSD,BHTD->BHST", q, k) * scale_factor + weights_unnormalized = ( + jax.numpy.einsum("BHSD,BHTD->BHST", q, k) * scale_factor + ) if causal: weights_unnormalized_to_zero_out = jax.numpy.triu( jax.numpy.ones((seq_lengh, seq_lengh), jax.numpy.bfloat16), 1 diff --git a/src/benchmark_collectives.py b/src/benchmark_collectives.py index 30d1296..ace4119 100644 --- a/src/benchmark_collectives.py +++ b/src/benchmark_collectives.py @@ -15,7 +15,9 @@ # pylint: disable=g-importing-member -def create_mesh(dcn_size: int, ici_size: int) -> tuple[Mesh, list[int], list[int]]: +def create_mesh( + dcn_size: int, ici_size: int +) -> tuple[Mesh, list[int], list[int]]: """Creates a hybrid mesh with the given DCN and ICI sizes.""" dcn_parallelism = [dcn_size, 1] ici_parallelism = [1, ici_size] @@ -31,7 +33,9 @@ def create_mesh(dcn_size: int, ici_size: int) -> tuple[Mesh, list[int], list[int ) mesh = Mesh(mesh_devices, ("dcn", "ici")) else: - mesh_devices = mesh_utils.create_device_mesh([ici_size], devices=jax.devices()) + mesh_devices 
= mesh_utils.create_device_mesh( + [ici_size], devices=jax.devices() + ) mesh = Mesh(mesh_devices, "ici") return mesh @@ -49,6 +53,7 @@ def extract_metadata( metadata["dtype"] = metadata["dtype"].dtype.itemsize return metadata + def generate_metrics_statistics( metrics_list: list[float], metrics_name: str, @@ -71,6 +76,7 @@ def generate_metrics_statistics( ) metrics.update(statistics.serialize_statistics()) + def benchmark_collective( benchmark_name: str, jax_op: Any, @@ -135,6 +141,7 @@ def f(x): return time_ms_list + def psum_benchmark( matrix_dim: int, dtype: jnp.dtype, @@ -167,22 +174,37 @@ def psum_benchmark( if dcn_size > 1: results["dcn_time_ms_list"] = benchmark_collective( - benchmark_name="psum", jax_op=jax.lax.psum, mesh=mesh, matrix=matrix, - matrix_dim=matrix_dim, axis_name="dcn", in_specs=P("dcn", None), - out_specs=P(None, None), num_runs=num_runs, warmup_tries=warmup_tries, + benchmark_name="psum", + jax_op=jax.lax.psum, + mesh=mesh, + matrix=matrix, + matrix_dim=matrix_dim, + axis_name="dcn", + in_specs=P("dcn", None), + out_specs=P(None, None), + num_runs=num_runs, + warmup_tries=warmup_tries, trace_dir=trace_dir, ) if ici_size > 1: results["ici_time_ms_list"] = benchmark_collective( - benchmark_name="psum", jax_op=jax.lax.psum, mesh=mesh, matrix=matrix, - matrix_dim=matrix_dim, axis_name="ici", in_specs=P(None, None), - out_specs=P(None, None), num_runs=num_runs, warmup_tries=warmup_tries, + benchmark_name="psum", + jax_op=jax.lax.psum, + mesh=mesh, + matrix=matrix, + matrix_dim=matrix_dim, + axis_name="ici", + in_specs=P(None, None), + out_specs=P(None, None), + num_runs=num_runs, + warmup_tries=warmup_tries, trace_dir=trace_dir, ) return results + def psum_benchmark_calculate_metrics( matrix_dim: int, dtype: jnp.dtype, @@ -202,21 +224,31 @@ def psum_benchmark_calculate_metrics( # bandwidth is claculated as psum can be done via reduce_scatter + # all_gather so bandwidth is the sum of the two (formulas below) dcn_bandwidth_gbyte_s_list = [ - 
matrix_size_gbyte - * (dcn_size - 1) - * 2 - / dcn_size - / dcn_size - / (dcn_time_ms / 1e3) - for dcn_time_ms in dcn_time_ms_list + matrix_size_gbyte + * (dcn_size - 1) + * 2 + / dcn_size + / dcn_size + / (dcn_time_ms / 1e3) + for dcn_time_ms in dcn_time_ms_list ] generate_metrics_statistics( - dcn_bandwidth_gbyte_s_list, "dcn_bandwidth_gbyte_s", "psum_dcn", - matrix_dim, dtype, matrix_size_gbyte, metrics + dcn_bandwidth_gbyte_s_list, + "dcn_bandwidth_gbyte_s", + "psum_dcn", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) generate_metrics_statistics( - dcn_time_ms_list, "dcn_time_ms", "psum_dcn", - matrix_dim, dtype, matrix_size_gbyte, metrics + dcn_time_ms_list, + "dcn_time_ms", + "psum_dcn", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) metrics["dcn_time_ms_list"] = dcn_time_ms_list @@ -225,25 +257,36 @@ def psum_benchmark_calculate_metrics( # bandwidth is claculated as psum can be done via reduce_scatter + # all_gather so bandwidth is the sum of the two (formulas below) ici_bandwidth_gbyte_s_list = [ - matrix_size_gbyte - * (ici_size - 1) - * 2 - / ici_size - / (ici_time_ms / 1e3) - for ici_time_ms in ici_time_ms_list + matrix_size_gbyte + * (ici_size - 1) + * 2 + / ici_size + / (ici_time_ms / 1e3) + for ici_time_ms in ici_time_ms_list ] generate_metrics_statistics( - ici_bandwidth_gbyte_s_list, "ici_bandwidth_gbyte_s", "psum_ici", - matrix_dim, dtype, matrix_size_gbyte, metrics + ici_bandwidth_gbyte_s_list, + "ici_bandwidth_gbyte_s", + "psum_ici", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) generate_metrics_statistics( - ici_time_ms_list, "ici_time_ms", "psum_ici", - matrix_dim, dtype, matrix_size_gbyte, metrics + ici_time_ms_list, + "ici_time_ms", + "psum_ici", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) metrics["ici_time_ms_list"] = ici_time_ms_list return metadata, metrics + def psum_scatter_benchmark( matrix_dim: int, dtype: jnp.dtype, @@ -276,27 +319,41 @@ def psum_scatter_benchmark( psum_scatter_kwargs = 
{"tiled": True} - if dcn_size > 1: results["dcn_time_ms_list"] = benchmark_collective( - benchmark_name="psum_scatter", jax_op=jax.lax.psum_scatter, mesh=mesh, - matrix=matrix, matrix_dim=matrix_dim, axis_name="dcn", - in_specs=P("dcn", None), out_specs=P("dcn", None), - jax_op_kwargs=psum_scatter_kwargs, num_runs=num_runs, - warmup_tries=warmup_tries, trace_dir=trace_dir, + benchmark_name="psum_scatter", + jax_op=jax.lax.psum_scatter, + mesh=mesh, + matrix=matrix, + matrix_dim=matrix_dim, + axis_name="dcn", + in_specs=P("dcn", None), + out_specs=P("dcn", None), + jax_op_kwargs=psum_scatter_kwargs, + num_runs=num_runs, + warmup_tries=warmup_tries, + trace_dir=trace_dir, ) if ici_size > 1: results["ici_time_ms_list"] = benchmark_collective( - benchmark_name="psum_scatter", jax_op=jax.lax.psum_scatter, mesh=mesh, - matrix=matrix, matrix_dim=matrix_dim, axis_name="ici", - in_specs=P(None, None), out_specs=P(None, "ici"), - jax_op_kwargs=psum_scatter_kwargs, num_runs=num_runs, - warmup_tries=warmup_tries, trace_dir=trace_dir, + benchmark_name="psum_scatter", + jax_op=jax.lax.psum_scatter, + mesh=mesh, + matrix=matrix, + matrix_dim=matrix_dim, + axis_name="ici", + in_specs=P(None, None), + out_specs=P(None, "ici"), + jax_op_kwargs=psum_scatter_kwargs, + num_runs=num_runs, + warmup_tries=warmup_tries, + trace_dir=trace_dir, ) return results + def psum_scatter_benchmark_calculate_metrics( matrix_dim: int, dtype: jnp.dtype, @@ -317,20 +374,30 @@ def psum_scatter_benchmark_calculate_metrics( # each sharded matrix size is matrix_size_gbyte / dcn_size and then it needs # to use (dcn_size - 1) steps in a ring algorithm dcn_bandwidth_gbyte_s_list = [ - matrix_size_gbyte - * (dcn_size - 1) - / dcn_size - / dcn_size - / (dcn_time_ms / 1e3) - for dcn_time_ms in dcn_time_ms_list + matrix_size_gbyte + * (dcn_size - 1) + / dcn_size + / dcn_size + / (dcn_time_ms / 1e3) + for dcn_time_ms in dcn_time_ms_list ] generate_metrics_statistics( - dcn_bandwidth_gbyte_s_list, 
"dcn_bandwidth_gbyte_s", "psum_scatter_dcn", - matrix_dim, dtype, matrix_size_gbyte, metrics + dcn_bandwidth_gbyte_s_list, + "dcn_bandwidth_gbyte_s", + "psum_scatter_dcn", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) generate_metrics_statistics( - dcn_time_ms_list, "dcn_time_ms", "psum_scatter_dcn", - matrix_dim, dtype, matrix_size_gbyte, metrics + dcn_time_ms_list, + "dcn_time_ms", + "psum_scatter_dcn", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) metrics["dcn_time_ms_list"] = dcn_time_ms_list @@ -339,23 +406,32 @@ def psum_scatter_benchmark_calculate_metrics( # each sharded matrix size is matrix_size_gbyte / ici_size and then it needs # to use (ici_size - 1) steps in a ring algorithm ici_bandwidth_gbyte_s_list = [ - matrix_size_gbyte - * (ici_size - 1) - / ici_size - / (ici_time_ms / 1e3) - for ici_time_ms in ici_time_ms_list + matrix_size_gbyte * (ici_size - 1) / ici_size / (ici_time_ms / 1e3) + for ici_time_ms in ici_time_ms_list ] generate_metrics_statistics( - ici_bandwidth_gbyte_s_list, "ici_bandwidth_gbyte_s", "psum_scatter_ici", - matrix_dim, dtype, matrix_size_gbyte, metrics + ici_bandwidth_gbyte_s_list, + "ici_bandwidth_gbyte_s", + "psum_scatter_ici", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) generate_metrics_statistics( - ici_time_ms_list, "ici_time_ms", "psum_scatter_ici", - matrix_dim, dtype, matrix_size_gbyte, metrics + ici_time_ms_list, + "ici_time_ms", + "psum_scatter_ici", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) metrics["ici_time_ms_list"] = ici_time_ms_list - metrics = {key: value for key, value in metrics.items() if value is not None} + metrics = { + key: value for key, value in metrics.items() if value is not None + } return metadata, metrics @@ -393,24 +469,41 @@ def all_gather_benchmark( if dcn_size > 1: results["dcn_time_ms_list"] = benchmark_collective( - benchmark_name="all_gather", jax_op=jax.lax.all_gather, mesh=mesh, - matrix=matrix, matrix_dim=matrix_dim, axis_name="dcn", - 
in_specs=P("dcn", None), out_specs=P(None, None), check_rep=False, - jax_op_kwargs=all_gather_kwargs, num_runs=num_runs, - warmup_tries=warmup_tries, trace_dir=trace_dir, + benchmark_name="all_gather", + jax_op=jax.lax.all_gather, + mesh=mesh, + matrix=matrix, + matrix_dim=matrix_dim, + axis_name="dcn", + in_specs=P("dcn", None), + out_specs=P(None, None), + check_rep=False, + jax_op_kwargs=all_gather_kwargs, + num_runs=num_runs, + warmup_tries=warmup_tries, + trace_dir=trace_dir, ) if ici_size > 1: results["ici_time_ms_list"] = benchmark_collective( - benchmark_name="all_gather", jax_op=jax.lax.all_gather, mesh=mesh, - matrix=matrix, matrix_dim=matrix_dim, axis_name="ici", - in_specs=P("ici", None), out_specs=P(None, None), check_rep=False, - jax_op_kwargs=all_gather_kwargs, num_runs=num_runs, - warmup_tries=warmup_tries, trace_dir=trace_dir, + benchmark_name="all_gather", + jax_op=jax.lax.all_gather, + mesh=mesh, + matrix=matrix, + matrix_dim=matrix_dim, + axis_name="ici", + in_specs=P("ici", None), + out_specs=P(None, None), + check_rep=False, + jax_op_kwargs=all_gather_kwargs, + num_runs=num_runs, + warmup_tries=warmup_tries, + trace_dir=trace_dir, ) return results + def all_gather_benchmark_calculate_metrics( matrix_dim: int, dtype: jnp.dtype, @@ -431,19 +524,26 @@ def all_gather_benchmark_calculate_metrics( # each sharded matrix size is matrix_size_gbyte / dcn_size and then it needs # to use (dcn_size - 1) steps in a ring algorithm dcn_bandwidth_gbyte_s_list = [ - matrix_size_gbyte - * (dcn_size - 1) - / dcn_size - / (dcn_time_ms / 1e3) - for dcn_time_ms in dcn_time_ms_list + matrix_size_gbyte * (dcn_size - 1) / dcn_size / (dcn_time_ms / 1e3) + for dcn_time_ms in dcn_time_ms_list ] generate_metrics_statistics( - dcn_bandwidth_gbyte_s_list, "dcn_bandwidth_gbyte_s", "all_gather_dcn", - matrix_dim, dtype, matrix_size_gbyte, metrics + dcn_bandwidth_gbyte_s_list, + "dcn_bandwidth_gbyte_s", + "all_gather_dcn", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, 
) generate_metrics_statistics( - dcn_time_ms_list, "dcn_time_ms", "all_gather_dcn", - matrix_dim, dtype, matrix_size_gbyte, metrics + dcn_time_ms_list, + "dcn_time_ms", + "all_gather_dcn", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) metrics["dcn_time_ms_list"] = dcn_time_ms_list @@ -452,25 +552,35 @@ def all_gather_benchmark_calculate_metrics( # each sharded matrix size is matrix_size_gbyte / ici_size and then it needs # to use (ici_size - 1) steps in a ring algorithm ici_bandwidth_gbyte_s_list = [ - matrix_size_gbyte - * (ici_size - 1) - / ici_size - / (ici_time_ms / 1e3) - for ici_time_ms in ici_time_ms_list + matrix_size_gbyte * (ici_size - 1) / ici_size / (ici_time_ms / 1e3) + for ici_time_ms in ici_time_ms_list ] generate_metrics_statistics( - ici_bandwidth_gbyte_s_list, "ici_bandwidth_gbyte_s", "all_gather_ici", - matrix_dim, dtype, matrix_size_gbyte, metrics + ici_bandwidth_gbyte_s_list, + "ici_bandwidth_gbyte_s", + "all_gather_ici", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) generate_metrics_statistics( - ici_time_ms_list, "ici_time_ms", "all_gather_ici", - matrix_dim, dtype, matrix_size_gbyte, metrics + ici_time_ms_list, + "ici_time_ms", + "all_gather_ici", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) metrics["ici_time_ms_list"] = ici_time_ms_list - metrics = {key: value for key, value in metrics.items() if value is not None} + metrics = { + key: value for key, value in metrics.items() if value is not None + } return metadata, metrics + def ppermute_benchmark( matrix_dim: int, dtype: jnp.dtype, @@ -501,29 +611,43 @@ def ppermute_benchmark( "ici_time_ms_list": None, } - if dcn_size > 1: dcn_perm = [(i, (i + 1) % dcn_size) for i in range(dcn_size)] results["dcn_time_ms_list"] = benchmark_collective( - benchmark_name="ppermute", jax_op=jax.lax.ppermute, mesh=mesh, - matrix=matrix, matrix_dim=matrix_dim, axis_name="dcn", - in_specs=P("dcn", None), out_specs=P("dcn", None), - jax_op_kwargs={"perm": dcn_perm}, 
num_runs=num_runs, - warmup_tries=warmup_tries, trace_dir=trace_dir, + benchmark_name="ppermute", + jax_op=jax.lax.ppermute, + mesh=mesh, + matrix=matrix, + matrix_dim=matrix_dim, + axis_name="dcn", + in_specs=P("dcn", None), + out_specs=P("dcn", None), + jax_op_kwargs={"perm": dcn_perm}, + num_runs=num_runs, + warmup_tries=warmup_tries, + trace_dir=trace_dir, ) if ici_size > 1: ici_perm = [(i, (i + 1) % ici_size) for i in range(ici_size)] results["ici_time_ms_list"] = benchmark_collective( - benchmark_name="ppermute", jax_op=jax.lax.ppermute, mesh=mesh, - matrix=matrix, matrix_dim=matrix_dim, axis_name="ici", - in_specs=P(None, None), out_specs=P(None, "ici"), - jax_op_kwargs={"perm": ici_perm}, num_runs=num_runs, - warmup_tries=warmup_tries, trace_dir=trace_dir, + benchmark_name="ppermute", + jax_op=jax.lax.ppermute, + mesh=mesh, + matrix=matrix, + matrix_dim=matrix_dim, + axis_name="ici", + in_specs=P(None, None), + out_specs=P(None, "ici"), + jax_op_kwargs={"perm": ici_perm}, + num_runs=num_runs, + warmup_tries=warmup_tries, + trace_dir=trace_dir, ) return results + def ppermute_benchmark_calculate_metrics( matrix_dim: int, dtype: jnp.dtype, @@ -544,16 +668,26 @@ def ppermute_benchmark_calculate_metrics( # each sharded matrix size is matrix_size_gbyte / dcn_size and then it needs # to use 1 step dcn_bandwidth_gbyte_s_list = [ - matrix_size_gbyte / dcn_size / (dcn_time_ms / 1e3) - for dcn_time_ms in dcn_time_ms_list + matrix_size_gbyte / dcn_size / (dcn_time_ms / 1e3) + for dcn_time_ms in dcn_time_ms_list ] generate_metrics_statistics( - dcn_bandwidth_gbyte_s_list, "dcn_bandwidth_gbyte_s", "ppermute_dcn", - matrix_dim, dtype, matrix_size_gbyte, metrics + dcn_bandwidth_gbyte_s_list, + "dcn_bandwidth_gbyte_s", + "ppermute_dcn", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) generate_metrics_statistics( - dcn_time_ms_list, "dcn_time_ms", "ppermute_dcn", - matrix_dim, dtype, matrix_size_gbyte, metrics + dcn_time_ms_list, + "dcn_time_ms", + "ppermute_dcn", 
+ matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) metrics["dcn_time_ms_list"] = dcn_time_ms_list @@ -566,16 +700,27 @@ def ppermute_benchmark_calculate_metrics( for ici_time_ms in ici_time_ms_list ] generate_metrics_statistics( - ici_bandwidth_gbyte_s_list, "ici_bandwidth_gbyte_s", "ppermute_ici", - matrix_dim, dtype, matrix_size_gbyte, metrics + ici_bandwidth_gbyte_s_list, + "ici_bandwidth_gbyte_s", + "ppermute_ici", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) generate_metrics_statistics( - ici_time_ms_list, "ici_time_ms", "ppermute_ici", - matrix_dim, dtype, matrix_size_gbyte, metrics + ici_time_ms_list, + "ici_time_ms", + "ppermute_ici", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) metrics["ici_time_ms_list"] = ici_time_ms_list return metadata, metrics + def all_to_all_benchmark( matrix_dim: int, dtype: jnp.dtype, @@ -610,24 +755,40 @@ def all_to_all_benchmark( if dcn_size > 1: results["dcn_time_ms_list"] = benchmark_collective( - benchmark_name="all_to_all", jax_op=jax.lax.all_to_all, mesh=mesh, - matrix=matrix, matrix_dim=matrix_dim, axis_name="dcn", - in_specs=P("dcn", None), out_specs=P("dcn", None), - jax_op_kwargs=all_to_all_kwargs, num_runs=num_runs, - warmup_tries=warmup_tries, trace_dir=trace_dir, + benchmark_name="all_to_all", + jax_op=jax.lax.all_to_all, + mesh=mesh, + matrix=matrix, + matrix_dim=matrix_dim, + axis_name="dcn", + in_specs=P("dcn", None), + out_specs=P("dcn", None), + jax_op_kwargs=all_to_all_kwargs, + num_runs=num_runs, + warmup_tries=warmup_tries, + trace_dir=trace_dir, ) if ici_size > 1: results["ici_time_ms_list"] = benchmark_collective( - benchmark_name="all_to_all", jax_op=jax.lax.all_to_all, mesh=mesh, - matrix=matrix, matrix_dim=matrix_dim, axis_name="ici", - in_specs=P(None, None), out_specs=P(None, None), check_rep=False, - jax_op_kwargs=all_to_all_kwargs, num_runs=num_runs, - warmup_tries=warmup_tries, trace_dir=trace_dir, + benchmark_name="all_to_all", + jax_op=jax.lax.all_to_all, + mesh=mesh, 
+ matrix=matrix, + matrix_dim=matrix_dim, + axis_name="ici", + in_specs=P(None, None), + out_specs=P(None, None), + check_rep=False, + jax_op_kwargs=all_to_all_kwargs, + num_runs=num_runs, + warmup_tries=warmup_tries, + trace_dir=trace_dir, ) return results + def all_to_all_benchmark_calculate_metrics( matrix_dim: int, dtype: jnp.dtype, @@ -646,41 +807,60 @@ def all_to_all_benchmark_calculate_metrics( if dcn_size > 1 and dcn_time_ms_list is not None: dcn_bandwidth_gbyte_s_list = [ - matrix_size_gbyte - * (dcn_size - 1) - / dcn_size - / dcn_size - / (dcn_time_ms / 1e3) - for dcn_time_ms in dcn_time_ms_list + matrix_size_gbyte + * (dcn_size - 1) + / dcn_size + / dcn_size + / (dcn_time_ms / 1e3) + for dcn_time_ms in dcn_time_ms_list ] generate_metrics_statistics( - dcn_bandwidth_gbyte_s_list, "dcn_bandwidth_gbyte_s", "all_to_all_dcn", - matrix_dim, dtype, matrix_size_gbyte, metrics + dcn_bandwidth_gbyte_s_list, + "dcn_bandwidth_gbyte_s", + "all_to_all_dcn", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) generate_metrics_statistics( - dcn_time_ms_list, "dcn_time_ms", "all_to_all_ici", - matrix_dim, dtype, matrix_size_gbyte, metrics + dcn_time_ms_list, + "dcn_time_ms", + "all_to_all_ici", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) metrics["dcn_time_ms_list"] = dcn_time_ms_list # Calculate metrics for ICI benchmark if ici_size > 1 and ici_time_ms_list is not None: ici_bandwidth_gbyte_s_list = [ - matrix_size_gbyte - * (ici_size - 1) - / ici_size - / (ici_time_ms / 1e3) - for ici_time_ms in ici_time_ms_list + matrix_size_gbyte * (ici_size - 1) / ici_size / (ici_time_ms / 1e3) + for ici_time_ms in ici_time_ms_list ] generate_metrics_statistics( - ici_bandwidth_gbyte_s_list, "ici_bandwidth_gbyte_s", "all_to_all_ici", - matrix_dim, dtype, matrix_size_gbyte, metrics + ici_bandwidth_gbyte_s_list, + "ici_bandwidth_gbyte_s", + "all_to_all_ici", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) generate_metrics_statistics( - ici_time_ms_list, 
"ici_time_ms", "all_to_all_ici", - matrix_dim, dtype, matrix_size_gbyte, metrics + ici_time_ms_list, + "ici_time_ms", + "all_to_all_ici", + matrix_dim, + dtype, + matrix_size_gbyte, + metrics, ) metrics["ici_time_ms_list"] = ici_time_ms_list - metrics = {key: value for key, value in metrics.items() if value is not None} + metrics = { + key: value for key, value in metrics.items() if value is not None + } return metadata, metrics diff --git a/src/benchmark_convolution.py b/src/benchmark_convolution.py index 34a39d1..fa77f5f 100644 --- a/src/benchmark_convolution.py +++ b/src/benchmark_convolution.py @@ -45,9 +45,15 @@ def convolve_common( def f(x, kernel, mode): return convolve_fn(x, kernel, mode=mode) - x = jnp.arange(np.prod(input_shape)).reshape(input_shape).astype(jnp.bfloat16) + x = ( + jnp.arange(np.prod(input_shape)) + .reshape(input_shape) + .astype(jnp.bfloat16) + ) kernel = ( - jnp.arange(np.prod(kernel_shape)).reshape(kernel_shape).astype(jnp.bfloat16) + jnp.arange(np.prod(kernel_shape)) + .reshape(kernel_shape) + .astype(jnp.bfloat16) ) # Warm up @@ -104,7 +110,8 @@ def convolve_common_calculate_metrics( # Calculate FLOPS utilization gflops_per_sec_list = [ - flops / (average_time_ms / 1000) / 1e9 for average_time_ms in time_ms_list + flops / (average_time_ms / 1000) / 1e9 + for average_time_ms in time_ms_list ] # Convert ms to seconds gflops_per_sec_statistics = MetricsStatistics( metrics_list=gflops_per_sec_list, metrics_name="gflops_per_sec" @@ -114,14 +121,18 @@ def convolve_common_calculate_metrics( ) # Print results print(f"Total flops: {flops}") - print(f"Average Execution Time: {time_ms_statistics.statistics['p50']:.4f} ms") + print( + f"Average Execution Time: {time_ms_statistics.statistics['p50']:.4f} ms" + ) print( f"FLOPS Utilization(median): {gflops_per_sec_statistics.statistics['p50']:.2f} GFLOPS/sec\n" ) # Gather the metrics to report. 
metadata.update({"total_flops": flops}) metrics.update(gflops_per_sec_statistics.serialize_statistics()) - metrics = {key: value for key, value in metrics.items() if value is not None} + metrics = { + key: value for key, value in metrics.items() if value is not None + } return metadata, metrics @@ -290,7 +301,9 @@ def lax_conv_general_dilated( dilation = (dilation, dilation) x = jnp.arange(np.prod(input_shape)).reshape(input_shape).astype(dtype) - kernel = jnp.arange(np.prod(kernel_shape)).reshape(kernel_shape).astype(dtype) + kernel = ( + jnp.arange(np.prod(kernel_shape)).reshape(kernel_shape).astype(dtype) + ) @partial(jax.jit, static_argnames=["mode", "stride", "dilation"]) def f(x, kernel, stride, dilation, mode): @@ -374,7 +387,8 @@ def lax_conv_general_dilated_calculate_metrics( # Calculate FLOPS utilization gflops_per_sec_list = [ - flops / (average_time_ms / 1000) / 1e9 for average_time_ms in time_ms_list + flops / (average_time_ms / 1000) / 1e9 + for average_time_ms in time_ms_list ] # Convert ms to seconds gflops_per_sec_statistics = MetricsStatistics( metrics_list=gflops_per_sec_list, metrics_name="gflops_per_sec" @@ -384,7 +398,9 @@ def lax_conv_general_dilated_calculate_metrics( ) # Print results print(f"Total flops: {flops}") - print(f"Average Execution Time: {time_ms_statistics.statistics['p50']:.4f} ms") + print( + f"Average Execution Time: {time_ms_statistics.statistics['p50']:.4f} ms" + ) print( f"FLOPS Utilization(median): {gflops_per_sec_statistics.statistics['p50']:.2f} GFLOPS/sec\n" ) diff --git a/src/benchmark_hbm.py b/src/benchmark_hbm.py index 4bc93a0..dcad0c8 100644 --- a/src/benchmark_hbm.py +++ b/src/benchmark_hbm.py @@ -86,5 +86,7 @@ def single_chip_hbm_copy_calculate_metrics( ) metrics.update(time_statistics.serialize_statistics()) metrics.update(statistics.serialize_statistics()) - metrics = {key: value for key, value in metrics.items() if value is not None} + metrics = { + key: value for key, value in metrics.items() if value is not 
None + } return metadata, metrics diff --git a/src/benchmark_matmul.py b/src/benchmark_matmul.py index a6a778e..48dd0c1 100644 --- a/src/benchmark_matmul.py +++ b/src/benchmark_matmul.py @@ -13,7 +13,6 @@ import os from typing import Any, Dict, Tuple - # pylint: disable=g-importing-member from benchmark_utils import simple_timeit, MetricsStatistics import jax @@ -67,7 +66,12 @@ def get_metrics_helper( def naive_matmul( - m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None, warmup_tries: int = 10, + m: int, + k: int, + n: int, + num_runs: int = 1, + trace_dir: str = None, + warmup_tries: int = 10, ) -> Dict[str, Any]: """Benchmarks the jax.numpy.einsum.""" @@ -92,7 +96,8 @@ def f(x, y): ) # Run once. output = jit_sharded_f(lhs, rhs) - jax.block_until_ready(output) # Ensure full completion before printing metrics + # Ensure full completion before printing metrics + jax.block_until_ready(output) print(f"{lhs.shape=} x {rhs.shape=} = {output.shape=}, {output.dtype=}") # Run the benchmark time_ms_list = simple_timeit( @@ -118,9 +123,12 @@ def naive_matmul_calculate_metrics( # Calculate FLOPs total_flops = 2 * m * k * n # Total floating-point operations - average_time_s_list = [average_time_ms / 10**3 for average_time_ms in time_ms_list] + average_time_s_list = [ + average_time_ms / 10**3 for average_time_ms in time_ms_list + ] tflops_per_sec_list = [ - total_flops / average_time_s / 10**12 for average_time_s in average_time_s_list + total_flops / average_time_s / 10**12 + for average_time_s in average_time_s_list ] tflops_per_sec_statistics = MetricsStatistics( metrics_list=tflops_per_sec_list, metrics_name="tflops_per_sec" @@ -150,12 +158,19 @@ def naive_matmul_calculate_metrics( ) metrics.update(tflops_per_sec_statistics.serialize_statistics()) metrics.update(data_transfer_gbyte_sec_statistics.serialize_statistics()) - metrics = {key: value for key, value in metrics.items() if value is not None} + metrics = { + key: value for key, value in metrics.items() 
if value is not None + } return metadata, metrics def single_host_naive_matmul( - m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None, warmup_tries: int = 10, + m: int, + k: int, + n: int, + num_runs: int = 1, + trace_dir: str = None, + warmup_tries: int = 10, ) -> Dict[str, Any]: """Benchmarks matmul on a single device without any sharding.""" @@ -194,9 +209,12 @@ def single_host_naive_matmul_calculate_metrics( # Calculate FLOPs total_flops = 2 * m * k * n # Total floating-point operations - average_time_s_list = [average_time_ms / 10**3 for average_time_ms in time_ms_list] + average_time_s_list = [ + average_time_ms / 10**3 for average_time_ms in time_ms_list + ] tflops_per_sec_list = [ - total_flops / average_time_s / 10**12 for average_time_s in average_time_s_list + total_flops / average_time_s / 10**12 + for average_time_s in average_time_s_list ] tflops_per_sec_statistics = MetricsStatistics( metrics_list=tflops_per_sec_list, metrics_name="tflops_per_sec" @@ -226,12 +244,19 @@ def single_host_naive_matmul_calculate_metrics( ) metrics.update(tflops_per_sec_statistics.serialize_statistics()) metrics.update(data_transfer_gbyte_sec_statistics.serialize_statistics()) - metrics = {key: value for key, value in metrics.items() if value is not None} + metrics = { + key: value for key, value in metrics.items() if value is not None + } return metadata, metrics def collective_matmul_one_direction( - m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None, warmup_tries: int = 10, + m: int, + k: int, + n: int, + num_runs: int = 1, + trace_dir: str = None, + warmup_tries: int = 10, ) -> Dict[str, Any]: """Benchmarks the collective matmul that does permute in one direction.""" @@ -257,7 +282,9 @@ def scanned_call(i, carrys): accum = jax.lax.dynamic_update_slice(accum, update, update_index) return accum, lhs - accum = jnp.zeros((lhs.shape[0] * axis_size, rhs.shape[1]), dtype=lhs.dtype) + accum = jnp.zeros( + (lhs.shape[0] * axis_size, rhs.shape[1]), 
dtype=lhs.dtype + ) for i in range(0, axis_size - 1): accum, lhs = scanned_call(i, (accum, lhs)) # compute the last chunk, without the ppermute @@ -283,7 +310,8 @@ def scanned_call(i, carrys): ) # Run once. output = jit_sharded_f(lhs, rhs) - jax.block_until_ready(output) # Ensure full completion before printing metrics + # Ensure full completion before printing metrics + jax.block_until_ready(output) print(f"{lhs.shape=} x {rhs.shape=} = {output.shape=}, {output.dtype=}") time_ms_list = simple_timeit( jit_sharded_f, @@ -308,16 +336,19 @@ def collective_matmul_one_direction_calculate_metrics( # Calculate FLOPs total_flops = 2 * m * k * n # Total floating-point operations - average_time_s_list = [average_time_ms / 10**3 for average_time_ms in time_ms_list] + average_time_s_list = [ + average_time_ms / 10**3 for average_time_ms in time_ms_list + ] tflops_per_sec_list = [ - total_flops / average_time_s / 10**12 for average_time_s in average_time_s_list + total_flops / average_time_s / 10**12 + for average_time_s in average_time_s_list ] tflops_per_sec_statistics = MetricsStatistics( metrics_list=tflops_per_sec_list, metrics_name="tflops_per_sec" ) print( f"Total floating-point ops: {total_flops}, Performance (median):" - f" {tflops_per_sec_statistics.statistics['p50'] :.2f} TFLOPs / second" + f" {tflops_per_sec_statistics.statistics['p50']:.2f} TFLOPs / second" ) print() # Gather the metrics to report. 
@@ -327,12 +358,19 @@ def collective_matmul_one_direction_calculate_metrics( } ) metrics.update(tflops_per_sec_statistics.serialize_statistics()) - metrics = {key: value for key, value in metrics.items() if value is not None} + metrics = { + key: value for key, value in metrics.items() if value is not None + } return metadata, metrics def collective_matmul_two_directions( - m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None, warmup_tries: int = 10, + m: int, + k: int, + n: int, + num_runs: int = 1, + trace_dir: str = None, + warmup_tries: int = 10, ) -> Dict[str, Any]: """Benchmarks the collective matmul that does permute in two directions.""" @@ -353,7 +391,9 @@ def f(activations, weights): update_index = (axis_index * chunk_size, 0) accum = jax.lax.dynamic_update_slice(accum, update, update_index) # Prepare forward and backward activations for next steps - activation_forward, activation_backward = jnp.split(activations, 2, axis=0) + activation_forward, activation_backward = jnp.split( + activations, 2, axis=0 + ) # Initial ppermute of activations to the next device activation_forward = jax.lax.ppermute( activation_forward, @@ -384,7 +424,9 @@ def scanned_call(i, carrys): perm=[(j, (j - 1) % axis_size) for j in range(axis_size)], ) # Update indices for forward and backward propagation - forward_update_index = ((axis_index - i - 1) % axis_size) * chunk_size + forward_update_index = ( + (axis_index - i - 1) % axis_size + ) * chunk_size backward_update_index = ( (axis_index + i + 1) % axis_size ) * chunk_size + mid_chunk @@ -420,7 +462,8 @@ def scanned_call(i, carrys): ) # Run once. output = jit_sharded_f(lhs, rhs) - jax.block_until_ready(output) # Ensure full completion before printing metrics + # Ensure full completion before printing metrics + jax.block_until_ready(output) print(f"{lhs.shape=} x {rhs.shape=} = {output.shape=}, {output.dtype=}") # Run the benchmark. 
time_ms_list = simple_timeit( @@ -446,16 +489,19 @@ def collective_matmul_two_directions_calculate_metrics( # Calculate FLOPs total_flops = 2 * m * k * n # Total floating-point operations - average_time_s_list = [average_time_ms / 10**3 for average_time_ms in time_ms_list] + average_time_s_list = [ + average_time_ms / 10**3 for average_time_ms in time_ms_list + ] tflops_per_sec_list = [ - total_flops / average_time_s / 10**12 for average_time_s in average_time_s_list + total_flops / average_time_s / 10**12 + for average_time_s in average_time_s_list ] tflops_per_sec_statistics = MetricsStatistics( metrics_list=tflops_per_sec_list, metrics_name="tflops_per_sec" ) print( f"Total floating-point ops: {total_flops}, Performance (median):" - f" {tflops_per_sec_statistics.statistics['p50'] :.2f} TFLOPs / second" + f" {tflops_per_sec_statistics.statistics['p50']:.2f} TFLOPs / second" ) print() # Gather the metrics to report. @@ -465,12 +511,19 @@ def collective_matmul_two_directions_calculate_metrics( } ) metrics.update(tflops_per_sec_statistics.serialize_statistics()) - metrics = {key: value for key, value in metrics.items() if value is not None} + metrics = { + key: value for key, value in metrics.items() if value is not None + } return metadata, metrics def multilayer_collective_matmul( - m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None, warmup_tries: int = 10, + m: int, + k: int, + n: int, + num_runs: int = 1, + trace_dir: str = None, + warmup_tries: int = 10, ) -> Dict[str, Any]: """Benchmarks the multilayer collective matmul.""" @@ -480,12 +533,16 @@ def f(act, weights): return act mesh = create_mesh() - activation = jnp.arange(np.prod((m, k))).reshape((m, k)).astype(jnp.bfloat16) + activation = ( + jnp.arange(np.prod((m, k))).reshape((m, k)).astype(jnp.bfloat16) + ) hidden_layers = [ jnp.arange(np.prod((k, k))).reshape((k, k)).astype(jnp.bfloat16) for _ in range(LAYERS - 1) ] - last_layer = [jnp.arange(np.prod((k, n))).reshape((k, 
n)).astype(jnp.bfloat16)] + last_layer = [ + jnp.arange(np.prod((k, n))).reshape((k, n)).astype(jnp.bfloat16) + ] weights = hidden_layers + last_layer activation_sharding = NamedSharding(mesh, P("i", None)) weight_sharding = NamedSharding(mesh, P(None, "i")) @@ -502,7 +559,8 @@ def f(act, weights): ) # Run once. output = jit_sharded_f(activation, weights) - jax.block_until_ready(output) # Ensure full completion before printing metrics + # Ensure full completion before printing metrics + jax.block_until_ready(output) print(f"Activation shape: {activation.shape}") print("Weights shapes:", [w.shape for w in weights]) print(f"Output shape: {output.shape}, Output dtype: {output.dtype}") @@ -531,16 +589,19 @@ def multilayer_collective_matmul_calculate_metrics( per_layer_flops = 2 * m * k * k # Total floating-point operations last_layer_flops = 2 * m * k * n total_flops = per_layer_flops * (LAYERS - 1) + last_layer_flops - average_time_s_list = [average_time_ms / 10**3 for average_time_ms in time_ms_list] + average_time_s_list = [ + average_time_ms / 10**3 for average_time_ms in time_ms_list + ] tflops_per_sec_list = [ - total_flops / average_time_s / 10**12 for average_time_s in average_time_s_list + total_flops / average_time_s / 10**12 + for average_time_s in average_time_s_list ] tflops_per_sec_statistics = MetricsStatistics( metrics_list=tflops_per_sec_list, metrics_name="tflops_per_sec" ) print( f"Total floating-point ops: {total_flops}, Performance (median):" - f" {tflops_per_sec_statistics.statistics['p50'] :.2f} TFLOPs / second" + f" {tflops_per_sec_statistics.statistics['p50']:.2f} TFLOPs / second" ) print() # Gather the metrics to report. 
@@ -550,5 +611,7 @@ def multilayer_collective_matmul_calculate_metrics( } ) metrics.update(tflops_per_sec_statistics.serialize_statistics()) - metrics = {key: value for key, value in metrics.items() if value is not None} + metrics = { + key: value for key, value in metrics.items() if value is not None + } return metadata, metrics diff --git a/src/benchmark_utils.py b/src/benchmark_utils.py index 1e9719a..01f7711 100644 --- a/src/benchmark_utils.py +++ b/src/benchmark_utils.py @@ -19,12 +19,28 @@ from jax.experimental import multihost_utils -def simple_timeit(f, *args, matrix_dim=None, warmup_tries = 10, tries=10, task=None, trace_dir=None) -> list[float]: +def simple_timeit( + f, + *args, + matrix_dim=None, + warmup_tries=10, + tries=10, + task=None, + trace_dir=None, +) -> list[float]: """Simple utility to time a function for multiple runs.""" assert task is not None if trace_dir: - return timeit_from_trace(f, *args, matrix_dim=matrix_dim, warmup_tries=warmup_tries, tries=tries, task=task, trace_dir=trace_dir) + return timeit_from_trace( + f, + *args, + matrix_dim=matrix_dim, + warmup_tries=warmup_tries, + tries=tries, + task=task, + trace_dir=trace_dir, + ) is_multihost = jax.process_count() > 1 @@ -40,7 +56,7 @@ def simple_timeit(f, *args, matrix_dim=None, warmup_tries = 10, tries=10, task=N # Final barrier after warmup to ensure all hosts are ready to start measuring together. if is_multihost: - multihost_utils.sync_global_devices(f'warmup_done_{task}') + multihost_utils.sync_global_devices(f"warmup_done_{task}") print(f"Running measurement loop with {tries} tries...") for i in range(tries): @@ -50,7 +66,7 @@ def simple_timeit(f, *args, matrix_dim=None, warmup_tries = 10, tries=10, task=N # Synchronize (Multi-Host Only): Wait for ALL hosts to finish the operation. 
if is_multihost: - multihost_utils.sync_global_devices(f'end_run_{i}_{task}') + multihost_utils.sync_global_devices(f"end_run_{i}_{task}") e_time = time.perf_counter() outcomes_ms.append(1000 * (e_time - s_time)) @@ -65,7 +81,9 @@ def get_trace(log_dir: str) -> dict[str, Any]: A trace object in JSON format. """ # Navigate to the folder with the latest trace dump to find `trace.json.jz` - trace_folders = (pathlib.Path(log_dir).absolute() / "plugins" / "profile").iterdir() + trace_folders = ( + pathlib.Path(log_dir).absolute() / "plugins" / "profile" + ).iterdir() latest_trace_folder = max(trace_folders, key=os.path.getmtime) trace_jsons = latest_trace_folder.glob("*.trace.json.gz") try: @@ -113,7 +131,15 @@ def is_local_directory_path(dir: str) -> bool: return dir.startswith("/") or dir.startswith("./") or dir.startswith("../") -def timeit_from_trace(f, *args, matrix_dim=None, warmup_tries=10, tries=10, task=None, trace_dir=None) -> list[float]: +def timeit_from_trace( + f, + *args, + matrix_dim=None, + warmup_tries=10, + tries=10, + task=None, + trace_dir=None, +) -> list[float]: """ Time a function with jax.profiler and get the run time from the trace. 
""" @@ -126,7 +152,7 @@ def timeit_from_trace(f, *args, matrix_dim=None, warmup_tries=10, tries=10, task data = f(*args) jax.block_until_ready(data) if is_multihost: - multihost_utils.sync_global_devices(f'warmup_done_{task}') + multihost_utils.sync_global_devices(f"warmup_done_{task}") if matrix_dim is not None: trace_name = f"{task}_dim_{matrix_dim}" @@ -145,7 +171,7 @@ def timeit_from_trace(f, *args, matrix_dim=None, warmup_tries=10, tries=10, task with jax.profiler.TraceAnnotation(task): jax.block_until_ready(f(*args)) if is_multihost: - multihost_utils.sync_global_devices(f'end_run_{i}_{task}') + multihost_utils.sync_global_devices(f"end_run_{i}_{task}") trace = get_trace(tmp_trace_dir) if trace_full_dir != tmp_trace_dir: @@ -269,7 +295,9 @@ def rename_xla_dump( serialized_benchmark_param = "_".join( f"{key}_{value}" for key, value in benchmark_param.items() ) - anchor_pattern = os.path.join(tmp_xla_dump_dir, "*jit_f*before_optimizations*.txt") + anchor_pattern = os.path.join( + tmp_xla_dump_dir, "*jit_f*before_optimizations*.txt" + ) matching_anchor_files = glob.glob(anchor_pattern) if not matching_anchor_files: @@ -346,5 +374,7 @@ def rename_xla_dump( f"An unexpected error occurred while copy '{original_filepath}': {e}" ) else: - upload_to_storage(trace_dir=new_filepath, local_file=original_filepath) + upload_to_storage( + trace_dir=new_filepath, local_file=original_filepath + ) print(f"The XLA dump is stored in {dest_xla_dump_dir}") diff --git a/src/run_benchmark.py b/src/run_benchmark.py index 5771443..aab7e22 100644 --- a/src/run_benchmark.py +++ b/src/run_benchmark.py @@ -33,7 +33,9 @@ MATMUL_BENCHMARK_MAP = { "naive_matmul": "benchmark_matmul.naive_matmul", "single_host_naive_matmul": "benchmark_matmul.single_host_naive_matmul", - "multilayer_collective_matmul": ("benchmark_matmul.multilayer_collective_matmul"), + "multilayer_collective_matmul": ( + "benchmark_matmul.multilayer_collective_matmul" + ), "collective_matmul_one_direction": ( 
"benchmark_matmul.collective_matmul_one_direction" ), @@ -45,14 +47,20 @@ "numpy_convolve": "benchmark_convolution.numpy_convolve", "scipy_signal_convolve": "benchmark_convolution.scipy_signal_convolve", "scipy_signal_convolve2d": "benchmark_convolution.scipy_signal_convolve2d", - "lax_conv_general_dilated": ("benchmark_convolution.lax_conv_general_dilated"), + "lax_conv_general_dilated": ( + "benchmark_convolution.lax_conv_general_dilated" + ), } ATTENTION_BENCHMARK_MAP = { "naive_attention": "benchmark_attention.naive_attention_benchmark", - "pallas_flash_attention": ("benchmark_attention.pallas_flash_attention_benchmark"), + "pallas_flash_attention": ( + "benchmark_attention.pallas_flash_attention_benchmark" + ), "splash_attention": "benchmark_attention.splash_attention_benchmark", "flax_nnx_attention": "benchmark_attention.flax_nnx_attention_benchmark", - "flax_linen_attention": ("benchmark_attention.flax_linen_attention_benchmark"), + "flax_linen_attention": ( + "benchmark_attention.flax_linen_attention_benchmark" + ), "keras_attention": "benchmark_attention.keras_attention_benchmark", } HBM_BENCHMARK_MAP = { @@ -91,7 +99,9 @@ def get_benchmark_functions( ) -> Tuple[Callable[..., Any], Callable[..., Any]]: """Dynamically load the benchmark function and its calculate_metrics function from the predefined map.""" if benchmark_name not in BENCHMARK_MAP: - raise ValueError(f"Benchmark {benchmark_name} is not defined in the map.") + raise ValueError( + f"Benchmark {benchmark_name} is not defined in the map." + ) module_path, func_name = BENCHMARK_MAP[benchmark_name].rsplit(".", 1) @@ -110,7 +120,9 @@ def get_benchmark_functions( # Get the calculate_metrics function try: - calculate_metrics_func = getattr(module, f"{func_name}_calculate_metrics") + calculate_metrics_func = getattr( + module, f"{func_name}_calculate_metrics" + ) except AttributeError: raise ValueError( f"Calculate metrics function for {benchmark_name} not found." 
@@ -189,14 +201,18 @@ def generate_benchmark_params_sweeping( # Generate all combinations using itertools.product combinations = [ dict(zip(param_names, values)) - for values in itertools.product(*(param_sets[name] for name in param_names)) + for values in itertools.product( + *(param_sets[name] for name in param_names) + ) ] generated_params += combinations return generated_params -def write_to_csv(csv_path: str, calculate_metrics_results: List[Dict[str, Any]]): +def write_to_csv( + csv_path: str, calculate_metrics_results: List[Dict[str, Any]] +): """Writes benchmark metrics to a CSV file. This function takes a list of dictionaries, where each dictionary contains @@ -224,7 +240,11 @@ def flatten_and_sanitize_dict(current_dict: Dict) -> Dict: if isinstance(val, Dict): output_dict.update(flatten_and_sanitize_dict(val)) else: - output_dict[key] = json.dumps(val) if isinstance(val, (list, tuple, set)) else val + output_dict[key] = ( + json.dumps(val) + if isinstance(val, (list, tuple, set)) + else val + ) return output_dict @@ -239,7 +259,9 @@ def convert_dict_to_df(target_dict: Dict) -> pd.DataFrame: # We should revert this PR and refactor the code such that metrics object is a flatten dict that can be easily exported as a CSV. # For other information that requires nested structures, we should serialize it into a json file." 
try: - df_list = [convert_dict_to_df(each) for each in calculate_metrics_results] + df_list = [ + convert_dict_to_df(each) for each in calculate_metrics_results + ] df = pd.concat(df_list, ignore_index=True) df.to_csv(csv_path, index=False, sep="\t") @@ -258,7 +280,9 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]): benchmark_params = benchmark_config.get("benchmark_params", []) benchmark_sweep_params = benchmark_config.get("benchmark_sweep_params", {}) if benchmark_sweep_params: - benchmark_params += generate_benchmark_params_sweeping(benchmark_sweep_params) + benchmark_params += generate_benchmark_params_sweeping( + benchmark_sweep_params + ) csv_path = benchmark_config.get("csv_path") trace_dir = benchmark_config.get("trace_dir") xlml_metrics_dir = benchmark_config.get("xlml_metrics_dir") @@ -270,7 +294,9 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]): raise ValueError("Each benchmark must have a 'benchmark_name'.") # Get the benchmark function - benchmark_func, calculate_metrics_func = get_benchmark_functions(benchmark_name) + benchmark_func, calculate_metrics_func = get_benchmark_functions( + benchmark_name + ) print(f"\n{'=' * 30}Starting benchmark '{benchmark_name}'{'=' * 30}\n") @@ -281,18 +307,24 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]): benchmark_param = preprocess_benchmark_param( benchmark_param, trace_dir=trace_dir ) - print(f"Running benchmark: {benchmark_name} with params: {benchmark_param}") + print( + f"Running benchmark: {benchmark_name} with params: {benchmark_param}" + ) test_start_time = ( datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + "Z" ) # "Z" indicates UTC - benchmark_results = benchmark_func(**benchmark_param, warmup_tries=warmup_tries) + benchmark_results = benchmark_func( + **benchmark_param, warmup_tries=warmup_tries + ) test_end_time = ( datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + "Z" ) # Filter benchmark_results to include only keys present in # 
calculate_metrics_func - calculate_metrics_params = inspect.signature(calculate_metrics_func).parameters + calculate_metrics_params = inspect.signature( + calculate_metrics_func + ).parameters filtered_benchmark_results = { key: value for key, value in benchmark_results.items() @@ -308,7 +340,9 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]): metadata, metrics = calculate_metrics_func( **filtered_benchmark_param, **filtered_benchmark_results ) - calculate_metrics_results.append({"metadata": metadata, "metrics": metrics}) + calculate_metrics_results.append( + {"metadata": metadata, "metrics": metrics} + ) if xlml_metrics_dir: maybe_write_metrics_file( xlml_metrics_dir, @@ -383,7 +417,9 @@ def run_benchmark_multithreaded(benchmark_config): benchmark_params = benchmark_config.get("benchmark_params", []) benchmark_sweep_params = benchmark_config.get("benchmark_sweep_params", {}) if benchmark_sweep_params: - benchmark_params += generate_benchmark_params_sweeping(benchmark_sweep_params) + benchmark_params += generate_benchmark_params_sweeping( + benchmark_sweep_params + ) csv_path = benchmark_config.get("csv_path") if not benchmark_name: raise ValueError("Each benchmark must have a 'benchmark_name'.") @@ -391,7 +427,9 @@ def run_benchmark_multithreaded(benchmark_config): warmup_tries = warmup_tries if warmup_tries is not None else 10 # Get the benchmark function - benchmark_func, calculate_metrics_func = get_benchmark_functions(benchmark_name) + benchmark_func, calculate_metrics_func = get_benchmark_functions( + benchmark_name + ) print(f"\n{'=' * 30}Starting benchmark '{benchmark_name}'{'=' * 30}\n") @@ -416,7 +454,9 @@ def run_benchmark_multithreaded(benchmark_config): with ThreadPoolExecutor(max_workers=num_hosts) as executor: # Create a mapping of futures to their corresponding parameters future_to_param = { - executor.submit(benchmark_func, **benchmark_param, warmup_tries=warmup_tries): benchmark_param + executor.submit( + benchmark_func, 
**benchmark_param, warmup_tries=warmup_tries + ): benchmark_param for benchmark_param in preprocessed_benchmark_params } @@ -425,7 +465,9 @@ def run_benchmark_multithreaded(benchmark_config): benchmark_param = future_to_param[ future ] # Retrieve the corresponding benchmark_param - benchmark_results = future.result() # Get the result from the future + benchmark_results = ( + future.result() + ) # Get the result from the future # Filter benchmark_results to include only keys present in calculate_metrics_func calculate_metrics_params = inspect.signature( @@ -441,7 +483,9 @@ def run_benchmark_multithreaded(benchmark_config): metadata, metrics = calculate_metrics_func( **benchmark_param, **filtered_benchmark_results ) - calculate_metrics_results.append({"metadata": metadata, "metrics": metrics}) + calculate_metrics_results.append( + {"metadata": metadata, "metrics": metrics} + ) if csv_path: write_to_csv(f"{csv_path}/{test_name}.csv", calculate_metrics_results) From 88a03d9167af91608a657b0911fc5b92f80c2004 Mon Sep 17 00:00:00 2001 From: Vidushi Vashishth Date: Mon, 16 Mar 2026 11:50:03 -0700 Subject: [PATCH 2/7] Removing unneeded line breaks --- Ironwood/src/benchmark_attention.py | 1 - Ironwood/src/run_benchmark.py | 1 - 2 files changed, 2 deletions(-) diff --git a/Ironwood/src/benchmark_attention.py b/Ironwood/src/benchmark_attention.py index 3e16748..1d25e86 100644 --- a/Ironwood/src/benchmark_attention.py +++ b/Ironwood/src/benchmark_attention.py @@ -87,7 +87,6 @@ def f(q, k, v): kernel_ = jax.vmap(kernel, in_axes=(0, 0, 0)) # batch vmap kernel_ = jax.vmap(kernel_, in_axes=(0, 0, 0)) # mqa vmap return kernel_(q, k, v) - else: kernel = splash.make_splash_mha_single_device(mask, config=config) f = jax.jit(jax.vmap(kernel, in_axes=(0, 0, 0))) diff --git a/Ironwood/src/run_benchmark.py b/Ironwood/src/run_benchmark.py index 471bee1..0650efe 100644 --- a/Ironwood/src/run_benchmark.py +++ b/Ironwood/src/run_benchmark.py @@ -482,7 +482,6 @@ def main(args): for 
benchmark_config in benchmarks: run_benchmark_multithreaded(benchmark_config, output_path) - else: for benchmark_config in benchmarks: run_single_benchmark(benchmark_config, output_path) From 751efb0ca6fc8f6b080204545e004a59cc76890f Mon Sep 17 00:00:00 2001 From: Vidushi Vashishth Date: Mon, 16 Mar 2026 15:07:39 -0700 Subject: [PATCH 3/7] Fixing remaining style errors that were skipped by the agent --- Ironwood/src/benchmark_attention.py | 24 ++-- Ironwood/src/benchmark_collectives.py | 16 +-- Ironwood/src/benchmark_compute.py | 36 +++--- Ironwood/src/benchmark_gemm.py | 45 ++++---- Ironwood/src/benchmark_gemm_numerics.py | 29 ++--- Ironwood/src/benchmark_gemm_throttling.py | 5 +- Ironwood/src/benchmark_hbm.py | 14 +-- Ironwood/src/benchmark_host_device.py | 6 +- Ironwood/src/benchmark_inference_compute.py | 25 ++--- Ironwood/src/benchmark_send_recv.py | 12 +- Ironwood/src/benchmark_utils.py | 115 +++++++++++++------- Ironwood/src/run_benchmark.py | 12 +- 12 files changed, 183 insertions(+), 156 deletions(-) diff --git a/Ironwood/src/benchmark_attention.py b/Ironwood/src/benchmark_attention.py index 1d25e86..69d8076 100644 --- a/Ironwood/src/benchmark_attention.py +++ b/Ironwood/src/benchmark_attention.py @@ -1,27 +1,21 @@ """A script to benchmark tokamax splash attention implementation.""" -import os - -# pylint: disable=g-importing-member,g-bad-import-order +# pylint: disable=g-importing-member +import dataclasses from functools import partial +import logging +import os from typing import Any, Callable, Dict, Tuple -import dataclasses -from benchmark_utils import timeit_from_trace, MetricsStatistics +from benchmark_utils import MetricsStatistics +from benchmark_utils import timeit_from_trace import jax -import logging -from tokamax._src.ops.experimental.tpu.splash_attention import ( - splash_attention_kernel as splash, -) -from tokamax._src.ops.experimental.tpu.splash_attention import ( - splash_attention_mask as mask_lib, -) +from 
tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as splash +from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as mask_lib import tune_jax tune_jax.tune_logger.setLevel(logging.ERROR) -# pylint: disable=g-importing-member,g-bad-import-order - os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_dvfs_p_state=7" @@ -141,7 +135,7 @@ def tokamax_splash_attention_benchmark( # Attention mask mask = mask_lib.FullMask(_shape=(q_seq_len, kv_seq_len)) if causal: - # Pick offset for causal masks for a "representative" slice of the causal + # Pick offset for causal masks for a representative slice of the causal offset = v.shape[-2] - q.shape[-2] mask = mask_lib.CausalMask(shape=(q_seq_len, kv_seq_len), offset=offset) diff --git a/Ironwood/src/benchmark_collectives.py b/Ironwood/src/benchmark_collectives.py index 7c97afd..59bc36d 100644 --- a/Ironwood/src/benchmark_collectives.py +++ b/Ironwood/src/benchmark_collectives.py @@ -210,8 +210,8 @@ def psum_benchmark( matrix_dim). mesh_shape: The shape of the mesh. op_dimension: The dimension of the operation. - ici_size: The number of chips in a single slice. If 1, then no ICI benchmark - is run. + ici_size: The number of chips in a single slice. If 1, then no ICI + benchmark is run. dtype: The data type of the matrix. num_runs: The number of runs to perform. trace_dir: The directory to save the trace to. @@ -360,8 +360,8 @@ def psum_scatter_benchmark( matrix_dim: The benchmark is run on a matrix with shape (matrix_dim, matrix_dim). dtype: The data type of the matrix. - ici_size: The number of chips in a single slice. If 1, then no ICI benchmark - is run. + ici_size: The number of chips in a single slice. If 1, then no ICI + benchmark is run. mesh_shape: The shape of the mesh. op_dimension: The dimension of the operation. sharding_strategy: The sharding strategy of the operation. 
@@ -476,8 +476,8 @@ def all_gather_benchmark( matrix_dim: The benchmark is run on a matrix with shape (matrix_dim, matrix_dim). dtype: The data type of the matrix. - ici_size: The number of chips in a single slice. If 1, then no ICI benchmark - is run. + ici_size: The number of chips in a single slice. If 1, then no ICI + benchmark is run. mesh_shape: The shape of the mesh. sharding_strategy: The sharding strategy of the operation. op_dimension: The dimension of the operation. @@ -592,8 +592,8 @@ def all_to_all_benchmark( matrix_dim: The benchmark is run on a matrix with shape (matrix_dim, matrix_dim). dtype: The data type of the matrix. - ici_size: The number of chips in a single slice. If 1, then no ICI benchmark - is run. + ici_size: The number of chips in a single slice. If 1, then no ICI + benchmark is run. mesh_shape: The shape of the mesh. op_dimension: The dimension of the operation. num_runs: The number of runs to perform. diff --git a/Ironwood/src/benchmark_compute.py b/Ironwood/src/benchmark_compute.py index d2adc33..458db80 100644 --- a/Ironwood/src/benchmark_compute.py +++ b/Ironwood/src/benchmark_compute.py @@ -13,27 +13,24 @@ """ import os -from typing import Any, Dict, Callable - -# pylint: disable=g-importing-member -from benchmark_utils import ( - iteration_timeit, - ShardingStrategy, - get_out_sharding, - get_rowwise_named_shading, - get_output_named_shading, - create_mesh, - handle_based_on_sharding, - unified_bytes_metrics, -) +from typing import Any, Callable, Dict + +from benchmark_utils import create_mesh +from benchmark_utils import get_out_sharding +from benchmark_utils import get_output_named_shading +from benchmark_utils import get_rowwise_named_shading +from benchmark_utils import handle_based_on_sharding +from benchmark_utils import iteration_timeit +from benchmark_utils import ShardingStrategy +from benchmark_utils import unified_bytes_metrics +from common import MARKER + +from flax import nnx import jax from 
jax.experimental.shard_map import shard_map import jax.numpy as jnp from qwix import pallas as qpl -from flax import nnx -from common import MARKER -# pylint: disable=g-importing-member # Set the environment variable for TPU initialization arguments to optimize # collective matmul. Setting the flags to false will disable the optimization. os.environ["LIBTPU_INIT_ARGS"] = ( @@ -157,8 +154,11 @@ def quantization_calculate_metrics( ) width_in_bytes = info_fn(quant_jnp_dtype).bits / 8 output_flops_based_on_dtype = m * n * width_in_bytes - # calculate scale apply quant write quant output write scale factor - # NOTE: (2 * m * n) + (2 * m * n) + (1 * m * n) + (4 * m) + # NOTE: + # - calculate scale: (2 * m * n) + # - apply quant: (2 * m * n) + # - write quant output: (1 * m * n) + # - write scale factor: (4 * m) total_bytes = ( (2 * m * n) + (2 * m * n) + (4 * m) + output_flops_based_on_dtype ) diff --git a/Ironwood/src/benchmark_gemm.py b/Ironwood/src/benchmark_gemm.py index 5e9dea1..c965a44 100644 --- a/Ironwood/src/benchmark_gemm.py +++ b/Ironwood/src/benchmark_gemm.py @@ -10,27 +10,24 @@ import os from typing import Any, Dict -# pylint: disable=g-importing-member -from benchmark_utils import ( - iteration_timeit, - multiple_iteration_timeit_from_trace, - ShardingStrategy, - get_lhs_named_shading, - get_rhs_named_shading, - get_output_named_shading, - get_out_sharding, - create_mesh, - handle_based_on_sharding, - unified_flops_metrics, - str_to_dtype, - get_peak_flops_multiplier, -) +from benchmark_utils import create_mesh +from benchmark_utils import get_lhs_named_shading +from benchmark_utils import get_out_sharding +from benchmark_utils import get_output_named_shading +from benchmark_utils import get_peak_flops_multiplier +from benchmark_utils import get_rhs_named_shading +from benchmark_utils import handle_based_on_sharding +from benchmark_utils import iteration_timeit +from benchmark_utils import multiple_iteration_timeit_from_trace +from benchmark_utils import 
ShardingStrategy +from benchmark_utils import str_to_dtype +from benchmark_utils import unified_flops_metrics from common import MARKER + import jax from jax.experimental.shard_map import shard_map import jax.numpy as jnp -# pylint: disable=g-importing-member os.environ["LIBTPU_INIT_ARGS"] = ( "--xla_tpu_enable_async_collective_fusion=true " @@ -70,7 +67,9 @@ def gemm_multiple_run( num_runs: int = 1, trace_dir: str = None, ) -> Dict[str, Any]: - """Benchmarks the OUT:BF16 = IN0 dtype x IN1:dtype. Accumulation is FP32. Current supported dtype: float8_e4m3fn, bfloat16.""" + """Benchmarks the OUT:BF16 = IN0 dtype x IN1:dtype. + + Accumulation is FP32. Current supported dtype: float8_e4m3fn, bfloat16.""" def f(x, y): with jax.named_scope(MARKER): @@ -170,7 +169,9 @@ num_runs: int = 1, trace_dir: str = None, ) -> Dict[str, Any]: - """Benchmarks the OUT:BF16 = IN0:FP8 x IN1:FP8. Accumulation is FP32.""" + """Benchmarks the OUT:BF16 = IN0:FP8 x IN1:FP8. + + Accumulation is FP32.""" def f(x, y): with jax.named_scope(MARKER): @@ -264,7 +265,9 @@ def gemm_simple_with_dtype( num_runs: int = 1, trace_dir: str = None, ) -> Dict[str, Any]: - """Benchmarks the OUT:BF16 = IN0:FP8 x IN1:FP8. 
Accumulation is FP32.""" + """Benchmarks the OUT:BF16 = IN0:FP8 x IN1:FP8. + + Accumulation is FP32.""" # Convert string dtypes to jnp dtypes lhs_dtype = str_to_dtype(in_dtype_str) @@ -365,7 +368,7 @@ def gemm_simple_with_dtype_calculate_metrics( def gemm( m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None ) -> Dict[str, Any]: - """OUT:BF16 = matmul(IN0:FP8, IN1:FP8) * outer_product(SF0:FP32 * SF1<1, N>:FP32)""" + """OUT:BF16 = matmul(IN0:FP8, IN1:FP8) * outer_product(SF0:FP32 * SF1<1, N>:FP32).""" def f(x, y, scale_m, scale_n): with jax.named_scope(MARKER): @@ -470,7 +473,7 @@ def gemm_accum( num_runs: int = 1, trace_dir: str = None, ) -> Dict[str, Any]: - """OUT:FP32 += matmul(IN0:FP8, IN1:FP8) * outer_product(SF0:FP32 * SF1<1, N>:FP32)""" + """OUT:FP32 += matmul(IN0:FP8, IN1:FP8) * outer_product(SF0:FP32 * SF1<1, N>:FP32).""" def f(out_buffer, x, y, scale_m, scale_n): with jax.named_scope(MARKER): diff --git a/Ironwood/src/benchmark_gemm_numerics.py b/Ironwood/src/benchmark_gemm_numerics.py index a14631a..65b03a0 100644 --- a/Ironwood/src/benchmark_gemm_numerics.py +++ b/Ironwood/src/benchmark_gemm_numerics.py @@ -10,19 +10,17 @@ """ import os -from typing import Any, Dict, Callable +from typing import Any, Callable, Dict # pylint: disable=g-importing-member -from benchmark_utils import ( - iteration_timeit, - ShardingStrategy, - get_lhs_named_shading, - get_rhs_named_shading, - get_out_sharding, - create_mesh, - handle_based_on_sharding, - unified_flops_metrics, -) +from benchmark_utils import create_mesh +from benchmark_utils import get_lhs_named_shading +from benchmark_utils import get_out_sharding +from benchmark_utils import get_rhs_named_shading +from benchmark_utils import handle_based_on_sharding +from benchmark_utils import iteration_timeit +from benchmark_utils import ShardingStrategy +from benchmark_utils import unified_flops_metrics import jax from jax.experimental.shard_map import shard_map import jax.numpy as jnp @@ -30,7 +28,6 
@@ from qwix._src.core import qarray from common import MARKER -# pylint: disable=g-importing-member # Set the environment variable for TPU initialization arguments to optimize # collective matmul. Setting the flags to false will disable the optimization. os.environ["LIBTPU_INIT_ARGS"] = ( @@ -276,7 +273,9 @@ def gemm_fp8_rowwise_w_dequantize_calculate_metrics( def gemm_fp8_b128_fp32( m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None ) -> Dict[str, Any]: - """FP8 GEMM as DeepSeek-stype quantization, block size: 1x128. Use dynamic scaling factors.""" + """FP8 GEMM as DeepSeek-style quantization, block size: 1x128. + + Use dynamic scaling factors.""" def f(x, y): with jax.named_scope(MARKER): @@ -388,7 +387,9 @@ def gemm_fp8_rowwise_static_scaling_calculate_metrics( def gemm_fp8_b128_fp32_static_scaling( m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None ) -> Dict[str, Any]: - """FP8 GEMM as DeepSeek-stype quantization, block size: 1x128. Use static scaling factors.""" + """FP8 GEMM as DeepSeek-style quantization, block size: 1x128. + + Use static scaling factors.""" def f(x, y): with jax.named_scope(MARKER): diff --git a/Ironwood/src/benchmark_gemm_throttling.py b/Ironwood/src/benchmark_gemm_throttling.py index 66712d8..925cb54 100644 --- a/Ironwood/src/benchmark_gemm_throttling.py +++ b/Ironwood/src/benchmark_gemm_throttling.py @@ -3,7 +3,6 @@ import os from typing import Any, Dict -# pylint: disable=g-importing-member from benchmark_utils import create_mesh from benchmark_utils import get_lhs_named_shading from benchmark_utils import get_out_sharding @@ -13,12 +12,11 @@ from benchmark_utils import ShardingStrategy from benchmark_utils import unified_flops_metrics from common import MARKER + import jax from jax.experimental.shard_map import shard_map import jax.numpy as jnp -# pylint: disable=g-importing-member - os.environ["LIBTPU_INIT_ARGS"] = ( "--xla_tpu_enable_async_collective_fusion=true "
"--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true " @@ -100,7 +98,6 @@ def data_generator(): return (lhs_device, rhs_device) # Run the benchmark - print("Running gemm_throttling benchmark", num_runs) time_ms_list = multiple_iteration_timeit_from_trace_throttling( jit_sharded_f, diff --git a/Ironwood/src/benchmark_hbm.py b/Ironwood/src/benchmark_hbm.py index aaef211..a99d25a 100644 --- a/Ironwood/src/benchmark_hbm.py +++ b/Ironwood/src/benchmark_hbm.py @@ -3,12 +3,11 @@ import os from typing import Any, Dict, Tuple -from benchmark_utils import ( - MetricsStatistics, - multiple_iteration_timeit_from_trace, - get_real_dtype_bytes, -) +from benchmark_utils import get_real_dtype_bytes +from benchmark_utils import MetricsStatistics +from benchmark_utils import multiple_iteration_timeit_from_trace from common import MARKER + import jax import jax.numpy as jnp @@ -89,8 +88,9 @@ def single_device_hbm_copy_calculate_metrics( metrics_list=bw_gbyte_sec_list, metrics_name="bw_gbyte_sec" ) print( - f"Tensor size: {tensor_size_bytes / 1024**2} MB, time taken (median):" - f" {time_statistics.statistics['p50']:.4f} ms, bandwidth (median): {statistics.statistics['p50']:.3f} GB/s" + f"Tensor size: {tensor_size_bytes / 1024**2} MB, " + f"time taken (median): {time_statistics.statistics['p50']:.4f} ms, " + f"bandwidth (median): {statistics.statistics['p50']:.3f} GB/s" ) print() # Gather the metrics to report. 
diff --git a/Ironwood/src/benchmark_host_device.py b/Ironwood/src/benchmark_host_device.py index 244877c..de2db0a 100644 --- a/Ironwood/src/benchmark_host_device.py +++ b/Ironwood/src/benchmark_host_device.py @@ -34,7 +34,8 @@ def benchmark_host_device( ) print( - f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations", + f"Benchmarking Transfer with Data Size: {data_size_mib} MB " + f"for {num_runs} iterations", flush=True, ) @@ -122,7 +123,8 @@ def add_metric(name, ms_list): ] stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)") print( - f" {name}_bw (GiB/s) median: {stats_bw.statistics['p50']}, P95: {stats_bw.statistics['p95']}", + f" {name}_bw (GiB/s) median: {stats_bw.statistics['p50']}, " + f"P95: {stats_bw.statistics['p95']}", flush=True, ) metrics.update(stats_bw.serialize_statistics()) diff --git a/Ironwood/src/benchmark_inference_compute.py b/Ironwood/src/benchmark_inference_compute.py index a5c948d..eb399bb 100644 --- a/Ironwood/src/benchmark_inference_compute.py +++ b/Ironwood/src/benchmark_inference_compute.py @@ -6,24 +6,21 @@ import os from typing import Any, Dict -# pylint: disable=g-importing-member -from benchmark_utils import ( - iteration_timeit, - ShardingStrategy, - create_mesh, - handle_based_on_sharding, - get_rowwise_named_shading, - unified_bytes_metrics, - get_output_named_shading, - get_out_sharding, -) +from benchmark_utils import create_mesh +from benchmark_utils import get_out_sharding +from benchmark_utils import get_output_named_shading +from benchmark_utils import get_rowwise_named_shading +from benchmark_utils import handle_based_on_sharding +from benchmark_utils import iteration_timeit +from benchmark_utils import ShardingStrategy +from benchmark_utils import unified_bytes_metrics +from common import MARKER + +from flax import nnx import jax from jax.experimental.shard_map import shard_map import jax.numpy as jnp -from flax import nnx -from common import MARKER -# pylint: 
disable=g-importing-member # Set the environment variable for TPU initialization arguments to optimize # collective matmul. Setting the flags to false will disable the optimization. os.environ["LIBTPU_INIT_ARGS"] = ( diff --git a/Ironwood/src/benchmark_send_recv.py b/Ironwood/src/benchmark_send_recv.py index c01ca0e..1eaa2b3 100644 --- a/Ironwood/src/benchmark_send_recv.py +++ b/Ironwood/src/benchmark_send_recv.py @@ -1,17 +1,17 @@ """Benchmarking p2p source target transfer.""" import os +import tempfile from typing import Any, Dict, Tuple + +from benchmark_utils import get_real_dtype_bytes +from benchmark_utils import get_trace +from common import MARKER + import jax from jax.experimental import mesh_utils import jax.numpy as jnp import jax.sharding -from benchmark_utils import ( - get_trace, - get_real_dtype_bytes, -) -from common import MARKER -import tempfile P = jax.sharding.PartitionSpec diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index 9b2ede1..788d457 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -83,7 +83,8 @@ def multiple_iteration_timeit_from_trace_throttling( trace_full_dir = f"{trace_dir}/{trace_name}" tmp_trace_dir = trace_full_dir - # If the trace_dir isn't a local path, create one for dumping the trace for parsing and getting metrics. + # If the trace_dir isn't a local path, create one for dumping the trace for + # parsing and getting metrics. if trace_dir and not is_local_directory_path(trace_dir): tmp_trace_dir = f"{LOCAL_TRACE_DIR}/{trace_name}" @@ -93,7 +94,8 @@ def multiple_iteration_timeit_from_trace_throttling( for i in range(tries): if i % 10 == 0: print( - f"[{task}] Running iteration {i} of {tries} with {matrix_dim}..." + f"[{task}] Running iteration {i} of {tries} with " + f"{matrix_dim}..." 
) jax.devices() with jax.profiler.StepTraceAnnotation(task, step_num=i): @@ -107,7 +109,8 @@ for i in range(tries): if i % 10 == 0: print( - f"[{task}] Running iteration {i} of {tries} with {matrix_dim}..." + f"[{task}] Running iteration {i} of {tries} with " + f"{matrix_dim}..." ) jax.devices() with jax.profiler.StepTraceAnnotation(task, step_num=i): @@ -122,7 +125,8 @@ for i in range(tries): if i % 10 == 0: print( - f"[{task}] Running iteration {i} of {tries} with {matrix_dim}..." + f"[{task}] Running iteration {i} of {tries} with " + f"{matrix_dim}..." ) data_args = data_generator() jax.devices() @@ -169,7 +173,8 @@ trace_full_dir = f"{trace_dir}/{trace_name}" tmp_trace_dir = trace_full_dir - # If the trace_dir isn't a local path, create one for dumping the trace for parsing and getting metrics. + # If the trace_dir isn't a local path, create one for dumping the trace for + # parsing and getting metrics. if trace_dir and not is_local_directory_path(trace_dir): tmp_trace_dir = f"{LOCAL_TRACE_DIR}/{trace_name}" # data_args = data_generator() @@ -177,7 +182,8 @@ for i in range(tries): if i % 10 == 0: print( - f"[{task}] Running iteration {i} of {tries} with {matrix_dim}..." + f"[{task}] Running iteration {i} of {tries} with " +
) data_args = data_generator() jax.devices() @@ -224,7 +230,8 @@ def multiple_iteration_get_metrics_from_trace( for e in trace["traceEvents"]: if "name" in e and event_matcher.match(e["name"]): events.append(e) - # For each trace, find the TPU with smallest `pid` value and consider it to be TPU-0 + # For each trace, find the TPU with smallest `pid` value and consider it + # to be TPU-0 min_pid = min([e["pid"] for e in events]) events_from_min_pid = [e for e in events if e["pid"] == min_pid] print(events_from_min_pid) @@ -238,7 +245,8 @@ def multiple_iteration_get_metrics_from_trace( durations_ms.append(float(e["dur"]) / 1e3) if not durations_ms and events_from_min_pid: print( - "Warning: No event duration found in legacy_get_metrics_from_trace_tpu." + "Warning: No event duration found in " + "legacy_get_metrics_from_trace_tpu." ) return durations_ms @@ -277,7 +285,8 @@ def iteration_timeit_from_trace( trace_full_dir = f"{trace_dir}/{trace_name}" tmp_trace_dir = trace_full_dir - # If the trace_dir isn't a local path, create one for dumping the trace for parsing and getting metrics. + # If the trace_dir isn't a local path, create one for dumping the trace for + # parsing and getting metrics. if trace_dir and not is_local_directory_path(trace_dir): tmp_trace_dir = f"{LOCAL_TRACE_DIR}/{trace_name}" with jax.profiler.trace(tmp_trace_dir): @@ -344,11 +353,12 @@ def iteration_get_metrics_from_trace( for pid in sorted(events_by_pid.keys()): events = events_by_pid[pid] - # Sum the device_duration_ps (picoseconds) for all events belonging to this PID - # CAVEAT: If multiple iterations of the op runs for benchmarking, then the next - # instruction will sum it for all the iterations which will not be the expected - # behavior. Find the metadata key which is different for different iteration on - # same PID. Eg: `group_id`. + # Sum the device_duration_ps (picoseconds) for all events belonging to + # this PID. 
+ # CAVEAT: If multiple iterations of the op runs for benchmarking, then + # the next instruction will sum it for all the iterations which will + # not be the expected behavior. Find the metadata key which is + # different for different iteration on same PID. Eg: `group_id`. total_duration_ps = sum( float(e["args"].get("device_duration_ps", 0)) for e in events ) @@ -464,8 +474,8 @@ def iteration_timeit( if trace_dir is not None: if task == "rmsnorm": # If the task is RMSNorm, we specifically target "copy-done" events. - # This is often done to capture the time of the asynchronous memory transfer - # needed for the normalization layer's input data. + # This is often done to capture the time of the asynchronous memory + # transferneeded for the normalization layer's input data. event_name_str_list = ["copy-done"] else: # For all other tasks, use an empty list. @@ -615,7 +625,8 @@ def find_sparsecore_usage_from_xplane(log_dir: str) -> xplane_pb2.XSpace: def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]: - # Check if the given task name is a collective with corresponding TPU opertion. + # Check if the given task name is a collective with corresponding TPU + # opertion. # This is a workaround and should be reverted or refactored in future. 
if task in TARGET_TASK_NAME_COLLECTIVES_MAP: try: @@ -665,7 +676,8 @@ def get_metrics_from_trace_tpu(trace: dict[str, Any], task: str) -> list[float]: if "name" in e and event_matcher.match(e["name"]): events.append(e) - # For each trace, find the TPU with smallest `pid` value and consider it to be TPU-0 + # For each trace, find the TPU with smallest `pid` value and consider it to + # be TPU-0 min_pid = min([e["pid"] for e in events]) events_from_min_pid = [e for e in events if e["pid"] == min_pid] try: @@ -717,7 +729,8 @@ def timeit_from_trace( trace_full_dir = f"{trace_dir}/{trace_name}" tmp_trace_dir = trace_full_dir - # If the trace_dir isn't a local path, create one for dumping the trace for parsing and getting metrics. + # If the trace_dir isn't a local path, create one for dumping the trace for + # parsing and getting metrics. if trace_dir and not is_local_directory_path(trace_dir): tmp_trace_dir = f"{LOCAL_TRACE_DIR}/{trace_name}" print(trace_dir) @@ -791,7 +804,8 @@ def upload_to_storage(trace_dir: str, local_file: str): except subprocess.CalledProcessError as e: print( - f"Failed to upload '{local_file}' to GCS: '{trace_dir}'. Error: {e.stderr.decode()}" + f"Failed to upload '{local_file}' to GCS: '{trace_dir}'. " + f"Error: {e.stderr.decode()}" ) else: raise KeyError(f"{trace_dir} is not a valid GCS path.") @@ -857,8 +871,9 @@ def rename_xla_dump( ): """ Finds the latest XLA dump file matching '*jit_f*before_optimizations*.txt', - then identifies all other files that share the same 'jit_f.[unique_id]' identifier - and renames them to 'benchmark_name_serialized_params.original_suffix_with_extension'. + then identifies all other files that share the same 'jit_f.[unique_id]' + identifier and renames them to + 'benchmark_name_serialized_params.original_suffix_with_extension'. """ serialized_benchmark_param = "_".join( @@ -871,7 +886,8 @@ def rename_xla_dump( if not matching_anchor_files: print( - f"No files found for anchor pattern: '{anchor_pattern}'. 
No files will be renamed." + f"No files found for anchor pattern: '{anchor_pattern}'. No files " + "will be renamed." ) return @@ -886,13 +902,15 @@ def rename_xla_dump( if not jit_id_match: print( - f"Could not extract 'jit_f.[unique_id]' from '{filename_base}'. Cannot proceed with renaming." + f"Could not extract 'jit_f.[unique_id]' from '{filename_base}'." + " Cannot proceed with renaming." ) return common_jit_id_prefix = jit_id_match.group(1) - # Find all files in the directory that contain this specific common_jit_id_prefix + # Find all files in the directory that contain this specific + # common_jit_id_prefix all_related_files_pattern = os.path.join( tmp_xla_dump_dir, f"*{common_jit_id_prefix}*" ) @@ -900,7 +918,8 @@ def rename_xla_dump( if not all_related_files: print( - f"No files found containing '{common_jit_id_prefix}'. This is unexpected if an anchor was found." + f"No files found containing '{common_jit_id_prefix}'. This is " + "unexpected if an anchor was found." ) return @@ -914,9 +933,10 @@ def rename_xla_dump( original_suffix_with_extension = "" # Find the specific suffix part *after* the common_jit_id_prefix. - # This regex looks for the common_jit_id_prefix, then captures everything after it, - # ensuring it starts with a dot if there's more. - # Example: if original_filename is 'module_0080.jit_f.cl_747713181.after_codegen.txt' + # This regex looks for the common_jit_id_prefix, then captures + # everything after it, ensuring it starts with a dot if there's more. + # Example: if original_filename is + # 'module_0080.jit_f.cl_747713181.after_codegen.txt' # and common_jit_id_prefix is 'jit_f.cl_747713181' # we want to capture '.after_codegen.txt' suffix_match = re.search( @@ -935,7 +955,8 @@ def rename_xla_dump( if original_filepath == new_filepath: print( - f"Skipping: '{original_filename}' already has the desired name or path." + f"Skipping: '{original_filename}' already has the desired name " + "or path." 
) continue @@ -946,7 +967,8 @@ def rename_xla_dump( shutil.copy(original_filepath, new_filepath) except Exception as e: print( - f"An unexpected error occurred while copy '{original_filepath}': {e}" + f"An unexpected error occurred while copy " + f"'{original_filepath}': {e}" ) else: upload_to_storage( @@ -983,8 +1005,8 @@ def extract_hlo_features_from_file( hlo_file_path: Path to the HLO dump file (e.g., after_optimizations.txt). Returns: - A tuple containing (input_shape, output_shape, replica_groups_str, first_replica_group), - or (None, None, None, None) if extraction fails. + A tuple containing (input_shape, output_shape, replica_groups_str, + first_replica_group), or (None, None, None, None) if extraction fails. """ input_shape = None output_shape = None @@ -1011,7 +1033,8 @@ def extract_hlo_features_from_file( output_shape = re.sub(r"{.*}", "", output_shape) else: print( - f"Could not find entry_computation_layout in {hlo_file_path} to extract shapes." + f"Could not find entry_computation_layout in {hlo_file_path} to " + "extract shapes." 
) # Extract replica groups @@ -1227,9 +1250,13 @@ def unified_flops_metrics( dtype_prefix = f"[{dtype}] " if dtype is not None else "" print( f"{dtype_prefix}" - f"Total floating-point ops: {total_flops}, Step Time (median): {average_time_ms_statistics.statistics['p50']:.2f}, " - f"Throughput (median): {tflops_per_sec_statistics.statistics['p50']:.2f} TFLOP / second / device, " - f"TotalThroughput (median): {tflops_per_sec_all_devices_statistics.statistics['p50']:.2f} TFLOP / second, " + f"Total floating-point ops: {total_flops}, Step Time (median): " + f"{average_time_ms_statistics.statistics['p50']:.2f}, " + f"Throughput (median): {tflops_per_sec_statistics.statistics['p50']:.2f} " + f"TFLOP / second / device, " + f"TotalThroughput (median): " + f"{tflops_per_sec_all_devices_statistics.statistics['p50']:.2f} " + f"TFLOP / second, " f"MFU: {mfu_statistics.statistics['p50']:.2%}" ) # print() @@ -1309,9 +1336,13 @@ def unified_bytes_metrics( type_prefix = f"[d={dtype}] " print( f"{type_prefix}" - f"Total bytes: {total_bytes}, Step Time (median): {average_time_ms_statistics.statistics['p50']:.2f}, Throughput (median):" - f" {gigabytes_per_sec_statistics.statistics['p50']:.2f} GBytes / second / device," - f" TotalThroughput (median): {gigabytes_per_sec_all_devices_statistics.statistics['p50']:.2f} GBytes / second" + f"Total bytes: {total_bytes}, Step Time (median): " + f"{average_time_ms_statistics.statistics['p50']:.2f}, " + f"Throughput (median): {gigabytes_per_sec_statistics.statistics['p50']:.2f} " + f"GBytes / second / device, " + f"TotalThroughput (median): " + f"{gigabytes_per_sec_all_devices_statistics.statistics['p50']:.2f} " + f"GBytes / second" ) print() metadata.update( @@ -1359,8 +1390,8 @@ def get_peak_flops_multiplier(in_dtype_str: str) -> float: in_dtype_lower = in_dtype_str.lower() if in_dtype_lower == "fp8": # FP8 is 2x faster than BF16 - # The baseline PEAK_FLOPS_PER_DEVICE is 1153.5 * 2 = 2307, which is FP8 peak. 
- # So the multiplier should be 1.0 + # The baseline PEAK_FLOPS_PER_DEVICE is 1153.5 * 2 = 2307, which is FP8 + # peak. So the multiplier should be 1.0 return 1.0 elif in_dtype_lower == "bf16" or in_dtype_lower == "fp16": # BF16/FP16 is 2x slower than FP8 peak diff --git a/Ironwood/src/run_benchmark.py b/Ironwood/src/run_benchmark.py index 0650efe..c74ebd3 100644 --- a/Ironwood/src/run_benchmark.py +++ b/Ironwood/src/run_benchmark.py @@ -237,8 +237,8 @@ def generate_benchmark_params_sweeping( current_value += increase_by else: raise ValueError( - "In sweep mode, user must provide either multiplier or" - " increase_by value." + "In sweep mode, user must provide either multiplier" + " or increase_by value." ) # Add the generated values to the param set param_sets[key] = param_values @@ -321,9 +321,11 @@ def convert_dict_to_df(target_dict: Dict) -> pd.DataFrame: return df # TODO(hylin2002@) - # This is a temporary workaround to generate a properly formatted CSV file for the output metrics. - # We should revert this PR and refactor the code such that metrics object is a flatten dict that can be easily exported as a CSV. - # For other information that requires nested structures, we should serialize it into a json file." + # This is a temporary workaround to generate a properly formatted CSV file + # for the output metrics. We should revert this PR and refactor the code + # such that metrics object is a flatten dict that can be easily exported + # as a CSV. For other information that requires nested structures, we + # should serialize it into a json file." 
df_list = [convert_dict_to_df(each) for each in calculate_metrics_results] df = pd.concat(df_list, ignore_index=True) From adf7d08bd8f93e774b8c59ade2fba46001cc0b8f Mon Sep 17 00:00:00 2001 From: Vidushi Vashishth Date: Mon, 16 Mar 2026 15:18:34 -0700 Subject: [PATCH 4/7] some more style fixes --- Ironwood/src/benchmark_utils.py | 40 ++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index 788d457..8c3222f 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -1021,14 +1021,16 @@ def extract_hlo_features_from_file( return None, None, None, None # Extract input/output shapes from HloModule line - # Example: HloModule jit_f, ..., entry_computation_layout={(f32[32,128]{...})->f32[128,128]{...}} + # Example: HloModule jit_f, ..., + # entry_computation_layout={(f32[32,128]{...})->f32[128,128]{...}} layout_match = re.search( r"entry_computation_layout={\((.*?)\)->(.*?)}", content ) if layout_match: input_shape = layout_match.group(1) output_shape = layout_match.group(2) - # Further clean shape if layout info is present, e.g., f32[1,2]{1,0} -> f32[1,2] + # Further clean shape if layout info is present, e.g., + # f32[1,2]{1,0} -> f32[1,2] input_shape = re.sub(r"{.*}", "", input_shape) output_shape = re.sub(r"{.*}", "", output_shape) else: @@ -1252,8 +1254,9 @@ def unified_flops_metrics( f"{dtype_prefix}" f"Total floating-point ops: {total_flops}, Step Time (median): " f"{average_time_ms_statistics.statistics['p50']:.2f}, " - f"Throughput (median): {tflops_per_sec_statistics.statistics['p50']:.2f} " - f"TFLOP / second / device, " + f"Throughput (median): " + f"{tflops_per_sec_statistics.statistics['p50']:.2f}" + f" TFLOP / second / device, " f"TotalThroughput (median): " f"{tflops_per_sec_all_devices_statistics.statistics['p50']:.2f} " f"TFLOP / second, " @@ -1266,12 +1269,12 @@ def unified_flops_metrics( metadata.update( { 
"StepTime(median,ms)": average_time_ms_statistics.statistics["p50"], - "Throughput(median,TFLOP/s/device)": tflops_per_sec_statistics.statistics[ - "p50" - ], - "TotalThroughput(median,TFLOP/s)": tflops_per_sec_all_devices_statistics.statistics[ - "p50" - ], + "Throughput(median,TFLOP/s/device)": ( + tflops_per_sec_statistics.statistics["p50"] + ), + "TotalThroughput(median,TFLOP/s)": ( + tflops_per_sec_all_devices_statistics.statistics["p50"] + ), "MFU": mfu_statistics.statistics["p50"], "total_flops": total_flops, # "all_time_ms_list": f"{json.dumps(time_ms_list)}", @@ -1338,7 +1341,8 @@ def unified_bytes_metrics( f"{type_prefix}" f"Total bytes: {total_bytes}, Step Time (median): " f"{average_time_ms_statistics.statistics['p50']:.2f}, " - f"Throughput (median): {gigabytes_per_sec_statistics.statistics['p50']:.2f} " + f"Throughput (median):" + f"{gigabytes_per_sec_statistics.statistics['p50']:.2f} " f"GBytes / second / device, " f"TotalThroughput (median): " f"{gigabytes_per_sec_all_devices_statistics.statistics['p50']:.2f} " @@ -1348,12 +1352,12 @@ def unified_bytes_metrics( metadata.update( { "StepTime(median,ms)": average_time_ms_statistics.statistics["p50"], - "Throughput(median,GBytes/s/device)": gigabytes_per_sec_statistics.statistics[ - "p50" - ], - "TotalThroughput(median,GBytes/s)": gigabytes_per_sec_all_devices_statistics.statistics[ - "p50" - ], + "Throughput(median,GBytes/s/device)": ( + gigabytes_per_sec_statistics.statistics["p50"] + ), + "TotalThroughput(median,GBytes/s)": ( + gigabytes_per_sec_all_devices_statistics.statistics["p50"] + ), "total_bytes": total_bytes, } ) @@ -1401,5 +1405,5 @@ def get_peak_flops_multiplier(in_dtype_str: str) -> float: return 0.25 else: raise RuntimeError( - f"{in_dtype_lower} is not supported for setting peak_flops_multiplier." + f"{in_dtype_lower} is not supported for setting " "peak_flops_multiplier." 
) From 36a7437beba352a6b0790d4c537e9267b6dc355a Mon Sep 17 00:00:00 2001 From: Vidushi Vashishth Date: Mon, 16 Mar 2026 15:21:25 -0700 Subject: [PATCH 5/7] Fixing error Quote delimiter ' is inconsistent with the rest of the file (inconsistent-quotes) --- Ironwood/src/benchmark_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index 8c3222f..151fd4f 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -1253,14 +1253,14 @@ def unified_flops_metrics( print( f"{dtype_prefix}" f"Total floating-point ops: {total_flops}, Step Time (median): " - f"{average_time_ms_statistics.statistics['p50']:.2f}, " + f"{average_time_ms_statistics.statistics["p50"]:.2f}, " f"Throughput (median): " - f"{tflops_per_sec_statistics.statistics['p50']:.2f}" + f"{tflops_per_sec_statistics.statistics["p50"]:.2f}" f" TFLOP / second / device, " f"TotalThroughput (median): " - f"{tflops_per_sec_all_devices_statistics.statistics['p50']:.2f} " + f"{tflops_per_sec_all_devices_statistics.statistics["p50"]:.2f} " f"TFLOP / second, " - f"MFU: {mfu_statistics.statistics['p50']:.2%}" + f"MFU: {mfu_statistics.statistics["p50"]:.2%}" ) # print() # time_ms_list = @@ -1340,12 +1340,12 @@ def unified_bytes_metrics( print( f"{type_prefix}" f"Total bytes: {total_bytes}, Step Time (median): " - f"{average_time_ms_statistics.statistics['p50']:.2f}, " + f"{average_time_ms_statistics.statistics["p50"]:.2f}, " f"Throughput (median):" - f"{gigabytes_per_sec_statistics.statistics['p50']:.2f} " + f"{gigabytes_per_sec_statistics.statistics["p50"]:.2f} " f"GBytes / second / device, " f"TotalThroughput (median): " - f"{gigabytes_per_sec_all_devices_statistics.statistics['p50']:.2f} " + f"{gigabytes_per_sec_all_devices_statistics.statistics["p50"]:.2f} " f"GBytes / second" ) print() @@ -1405,5 +1405,5 @@ def get_peak_flops_multiplier(in_dtype_str: str) -> float: return 0.25 else: raise 
RuntimeError( - f"{in_dtype_lower} is not supported for setting " "peak_flops_multiplier." + f"{in_dtype_lower} is not supported for setting peak_flops_multiplier." ) From ae03ba6bebbca230b66c6078478b1e2a89fad6e5 Mon Sep 17 00:00:00 2001 From: Vidushi Vashishth Date: Mon, 16 Mar 2026 17:22:37 -0700 Subject: [PATCH 6/7] Fixing some more comments --- Ironwood/src/benchmark_gemm_numerics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/Ironwood/src/benchmark_gemm_numerics.py b/Ironwood/src/benchmark_gemm_numerics.py index 65b03a0..8444021 100644 --- a/Ironwood/src/benchmark_gemm_numerics.py +++ b/Ironwood/src/benchmark_gemm_numerics.py @@ -12,7 +12,6 @@ import os from typing import Any, Callable, Dict -# pylint: disable=g-importing-member from benchmark_utils import create_mesh from benchmark_utils import get_lhs_named_shading from benchmark_utils import get_out_sharding From b1304ebb7eede4dff72eef9ae98add0cc0a2a73e Mon Sep 17 00:00:00 2001 From: Vidushi Vashishth Date: Mon, 16 Mar 2026 17:24:11 -0700 Subject: [PATCH 7/7] missed fixes --- Ironwood/src/benchmark_attention.py | 1 - Ironwood/src/run_benchmark.py | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Ironwood/src/benchmark_attention.py b/Ironwood/src/benchmark_attention.py index 69d8076..d64b9fa 100644 --- a/Ironwood/src/benchmark_attention.py +++ b/Ironwood/src/benchmark_attention.py @@ -1,6 +1,5 @@ """A script to benchmark tokamax splash attention implementation.""" -# pylint: disable=g-importing-member import dataclasses from functools import partial import logging diff --git a/Ironwood/src/run_benchmark.py b/Ironwood/src/run_benchmark.py index c74ebd3..01b75e2 100644 --- a/Ironwood/src/run_benchmark.py +++ b/Ironwood/src/run_benchmark.py @@ -552,7 +552,8 @@ def run_benchmark_multithreaded(benchmark_config, output_path): future.result() ) # Get the result from the future - # Filter benchmark_results to include only keys present in calculate_metrics_func + # Filter benchmark_results 
to include only keys present in + # calculate_metrics_func calculate_metrics_params = inspect.signature( calculate_metrics_func ).parameters @@ -562,7 +563,8 @@ def run_benchmark_multithreaded(benchmark_config, output_path): if key in calculate_metrics_params } - # Call calculate_metrics_func with the filtered results and benchmark_param + # Call calculate_metrics_func with the filtered results and + # benchmark_param metadata, metrics = calculate_metrics_func( **benchmark_param, **filtered_benchmark_results )