diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml new file mode 100644 index 0000000..b26d6d7 --- /dev/null +++ b/.github/workflows/pylint.yml @@ -0,0 +1,29 @@ +name: Pylint Style Check + +on: + pull_request: + paths: + - '**.py' + +jobs: + pylint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + cache: 'pip' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pylint + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + + - name: Run Pylint + run: | + # Use the repository's .pylintrc rules on all python files + pylint $(git ls-files '*.py') diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a5e3241 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + +- repo: local + hooks: + - id: pylint + name: pylint + entry: pylint + language: system + types: [python] + require_serial: true + # Optional: You can list specific files to exclude here if needed + # exclude: ^tests/ diff --git a/Ironwood/src/benchmark_collectives.py b/Ironwood/src/benchmark_collectives.py index 59bc36d..f1a10fe 100644 --- a/Ironwood/src/benchmark_collectives.py +++ b/Ironwood/src/benchmark_collectives.py @@ -57,7 +57,8 @@ def create_mesh(ici_size: int, mesh_shape: str) -> Mesh: def get_sharding_axis(dim_str: str, mesh: Mesh) -> tuple[str, ...]: - """Computes sharding axis names from dimension string like '1x4' and mesh.""" + """Computes sharding axis names from dimension string and mesh.""" + # Example of a dimension string is '1x4' dim_tuple = dim_str.split("x") dim_tuple = tuple(int(dim) for dim in dim_tuple) sharding_axis = tuple( @@ -203,6 +204,7 @@ def psum_benchmark( num_runs: int = 1, trace_dir: str = None, ) 
-> Dict[str, Any]: + # pylint: disable=unused-argument """Benchmarks the psum collective operation. Args: @@ -354,6 +356,7 @@ def psum_scatter_benchmark( num_runs: int = 1, trace_dir: str = None, ) -> Dict[str, Any]: + # pylint: disable=unused-argument """Benchmarks the psum_scatter collective operation. Args: @@ -376,7 +379,7 @@ def psum_scatter_benchmark( "--xla_sc_disable_megacore_partitioning=true", "--xla_tpu_disable_sparse_core_collective_offload_remover=true", "--xla_tpu_enable_reduce_scatter_offload_tracing=true", - "--xla_tpu_enable_sparse_core_collective_offload_nd_reduce_scatter=true", + "--xla_tpu_enable_sparse_core_collective_offload_nd_reduce_scatter=true", # pylint: disable=line-too-long "--xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true", "--xla_tpu_enable_sparse_core_reduce_scatter_v2=true", "--xla_tpu_use_tc_device_shape_on_sc=true", @@ -470,6 +473,7 @@ def all_gather_benchmark( num_runs: int = 1, trace_dir: str = None, ) -> Dict[str, Any]: + # pylint: disable=unused-argument """Benchmarks the all_gather collective operation. Args: @@ -586,6 +590,7 @@ def all_to_all_benchmark( num_runs: int = 1, trace_dir: str = None, ) -> Dict[str, Any]: + # pylint: disable=unused-argument """Benchmarks the all_to_all collective operation. 
Args: diff --git a/Ironwood/src/benchmark_compute.py b/Ironwood/src/benchmark_compute.py index 458db80..2c42d19 100644 --- a/Ironwood/src/benchmark_compute.py +++ b/Ironwood/src/benchmark_compute.py @@ -318,11 +318,11 @@ def swiglu_fwd( def f(x): with jax.named_scope(MARKER): - A, B = jnp.split(x, 2, axis=-1) - A_fp32 = A.astype(jnp.float32) - B_fp32 = B.astype(jnp.float32) - Y_fp32 = jax.nn.silu(A_fp32) * B_fp32 - return Y_fp32.astype(jnp.bfloat16) + a, b = jnp.split(x, 2, axis=-1) + a_fp32 = a.astype(jnp.float32) + b_fp32 = b.astype(jnp.float32) + y_fp32 = jax.nn.silu(a_fp32) * b_fp32 + return y_fp32.astype(jnp.bfloat16) mesh = create_mesh(SHARDING_STRATEGY) x_sharding = get_rowwise_named_shading(mesh, SHARDING_STRATEGY) @@ -379,16 +379,17 @@ def swiglu_bwd( num_runs: int = 1, trace_dir: str = None, ) -> Dict[str, Any]: + # pylint: disable=invalid-name """ Inverse of swiglu_fwd """ def f_fwd(x): - A, B = jnp.split(x, 2, axis=-1) - A_fp32 = A.astype(jnp.float32) - B_fp32 = B.astype(jnp.float32) - Y_fp32 = jax.nn.silu(A_fp32) * B_fp32 - return Y_fp32.astype(jnp.bfloat16) + a, b = jnp.split(x, 2, axis=-1) + a_fp32 = a.astype(jnp.float32) + b_fp32 = b.astype(jnp.float32) + y_fp32 = jax.nn.silu(a_fp32) * b_fp32 + return y_fp32.astype(jnp.bfloat16) def f(x: jax.Array, dy: jax.Array) -> jax.Array: """ @@ -397,7 +398,10 @@ def f(x: jax.Array, dy: jax.Array) -> jax.Array: """ # Get the VJP "pullback" function # We ignore the forward result (_y) - _y, pullback_fn = jax.vjp(f_fwd, x) + # pylint: disable=unused-variable,invalid-name + _y, pullback_fn = jax.vjp( + f_fwd, x + ) with jax.named_scope(MARKER): # Call the pullback function with the upstream gradient # This IS the backward pass. 
@@ -555,7 +559,10 @@ def f(x: jax.Array, dy: jax.Array) -> jax.Array: """ # Get the VJP "pullback" function # We ignore the forward result (_y) - _y, pullback_fn = jax.vjp(f_fwd, x) + # pylint: disable=unused-variable,invalid-name + _y, pullback_fn = jax.vjp( + f_fwd, x + ) with jax.named_scope(MARKER): # Call the pullback function with the upstream gradient # This IS the backward pass. diff --git a/Ironwood/src/benchmark_gemm.py b/Ironwood/src/benchmark_gemm.py index c965a44..c79bda5 100644 --- a/Ironwood/src/benchmark_gemm.py +++ b/Ironwood/src/benchmark_gemm.py @@ -69,7 +69,8 @@ def gemm_multiple_run( ) -> Dict[str, Any]: """Benchmarks the OUT:BF16 = IN0 dtype x IN1:dtype.""" - """Accumulation is FP32. Current supported dtype: float8_e4m3fn, bfloat16.""" + # Accumulation is FP32. Current supported dtype: float8_e4m3fn, + # bfloat16. def f(x, y): with jax.named_scope(MARKER): @@ -170,8 +171,7 @@ def gemm_simple( trace_dir: str = None, ) -> Dict[str, Any]: """Benchmarks the OUT:BF16 = IN0:FP8 x IN1:FP8.""" - - """Accumulation is FP32.""" + # Accumulation is FP32. def f(x, y): with jax.named_scope(MARKER): @@ -266,8 +266,7 @@ def gemm_simple_with_dtype( trace_dir: str = None, ) -> Dict[str, Any]: """Benchmarks the OUT:BF16 = IN0:FP8 x IN1:FP8.""" - - """Accumulation is FP32.""" + # Accumulation is FP32. 
# Convert string dtypes to jnp dtypes lhs_dtype = str_to_dtype(in_dtype_str) @@ -368,7 +367,8 @@ def gemm_simple_with_dtype_calculate_metrics( def gemm( m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None ) -> Dict[str, Any]: - """OUT:BF16 = matmul(IN0:FP8, IN1:FP8) * outer_product(SF0:FP32 * SF1<1, N>:FP32).""" + """OUT:BF16 = matmul(IN0:FP8, IN1:FP8) * + outer_product(SF0:FP32 * SF1<1, N>:FP32).""" def f(x, y, scale_m, scale_n): with jax.named_scope(MARKER): @@ -473,7 +473,8 @@ def gemm_accum( num_runs: int = 1, trace_dir: str = None, ) -> Dict[str, Any]: - """OUT:FP32 += matmul(IN0:FP8, IN1:FP8) * outer_product(SF0:FP32 * SF1<1, N>:FP32).""" + """OUT:FP32 += matmul(IN0:FP8, IN1:FP8) * + outer_product(SF0:FP32 * SF1<1, N>:FP32).""" def f(out_buffer, x, y, scale_m, scale_n): with jax.named_scope(MARKER): diff --git a/Ironwood/src/benchmark_gemm_numerics.py b/Ironwood/src/benchmark_gemm_numerics.py index 8444021..0e1d54f 100644 --- a/Ironwood/src/benchmark_gemm_numerics.py +++ b/Ironwood/src/benchmark_gemm_numerics.py @@ -273,8 +273,7 @@ def gemm_fp8_b128_fp32( m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None ) -> Dict[str, Any]: """FP8 GEMM as DeepSeek-stype quantization, block size: 1x128.""" - - """Use dynamic scaling factors.""" + # Use dynamic scaling factors. def f(x, y): with jax.named_scope(MARKER): @@ -387,8 +386,7 @@ def gemm_fp8_b128_fp32_static_scaling( m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None ) -> Dict[str, Any]: """FP8 GEMM as DeepSeek-stype quantization, block size: 1x128.""" - - """Use static scaling factors.""" + # Use static scaling factors. 
def f(x, y): with jax.named_scope(MARKER): diff --git a/Ironwood/src/benchmark_hbm.py b/Ironwood/src/benchmark_hbm.py index a99d25a..f72f768 100644 --- a/Ironwood/src/benchmark_hbm.py +++ b/Ironwood/src/benchmark_hbm.py @@ -89,8 +89,8 @@ def single_device_hbm_copy_calculate_metrics( ) print( f"Tensor size: {tensor_size_bytes / 1024**2} MB, " - f"time taken (median): {time_statistics.statistics['p50']:.4f} ms, " - f"bandwidth (median): {statistics.statistics['p50']:.3f} GB/s" + f"time taken (median): {time_statistics.statistics["p50"]:.4f} ms, " + f"bandwidth (median): {statistics.statistics["p50"]:.3f} GB/s" ) print() # Gather the metrics to report. diff --git a/Ironwood/src/benchmark_host_device.py b/Ironwood/src/benchmark_host_device.py index de2db0a..cc1e505 100644 --- a/Ironwood/src/benchmark_host_device.py +++ b/Ironwood/src/benchmark_host_device.py @@ -1,4 +1,7 @@ -"""Benchmarks Host-to-Device and Device-to-Host transfer performance (Simple Baseline).""" +""" +Benchmarks Host-to-Device and Device-to-Host transfer performance +(Simple Baseline). 
+""" import time import os @@ -123,8 +126,8 @@ def add_metric(name, ms_list): ] stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)") print( - f" {name}_bw (GiB/s) median: {stats_bw.statistics['p50']}, " - f"P95: {stats_bw.statistics['p95']}", + f"{name}_bw (GiB/s) median: {stats_bw.statistics["p50"]}, " + f"P95: {stats_bw.statistics["p95"]}", flush=True, ) metrics.update(stats_bw.serialize_statistics()) diff --git a/Ironwood/src/benchmark_inference_compute.py b/Ironwood/src/benchmark_inference_compute.py index eb399bb..f2ae12f 100644 --- a/Ironwood/src/benchmark_inference_compute.py +++ b/Ironwood/src/benchmark_inference_compute.py @@ -347,89 +347,3 @@ def sigmoid_calculate_metrics( dtype=dtype.dtype.name, ) - -# def get_output_named_shading(mesh, strategy: ShardingStrategy): -# match strategy: -# case ShardingStrategy.NO_SHARDING: -# return NamedSharding(mesh, P(None)) -# case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M: -# return NamedSharding(mesh, P("device")) -# case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M: -# return NamedSharding(mesh, P("device")) -# case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_N: -# assert False, f"ShardingStrategy is wrong for this ops: {strategy}" -# case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N: -# assert False, f"ShardingStrategy is wrong for this ops: {strategy}" - -# def get_out_sharding(strategy: ShardingStrategy): -# match strategy: -# case ShardingStrategy.NO_SHARDING: -# return P(None) -# case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M: -# return P("device") -# case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M: -# return P("device") -# case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_N: -# assert False, f"ShardingStrategy is wrong for this ops: {strategy}" -# case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N: -# assert False, f"ShardingStrategy is wrong for this ops: {strategy}" - -# def add(m: int, dtype: jnp.dtype, num_runs: int = 1, trace_dir: str = None, -# ) -> Dict[str, Any]: -# 
""" -# Z = X + Y -# """ -# def f(x, y): -# with jax.named_scope(MARKER): -# return x + y - -# mesh = create_mesh(SHARDING_STRATEGY) -# x_sharding = get_output_named_shading(mesh, SHARDING_STRATEGY) -# y_sharding = get_output_named_shading(mesh, SHARDING_STRATEGY) -# out_sharding = get_out_sharding(SHARDING_STRATEGY) -# jit_sharded_f = jax.jit( -# shard_map( -# f, -# mesh, -# in_specs=(x_sharding.spec, y_sharding.spec), -# out_specs=out_sharding, -# check_rep=False, -# ) -# ) -# x_shape = (m) -# y_shape = (m) -# x_dtype = dtype -# y_dtype = dtype - -# key = jax.random.key(SEED) - -# def data_generator(): -# """Creates new random data on host and puts it on device.""" -# nonlocal key # Use and update the outer 'key' -# key, k1, k2 = jax.random.split(key, 3) - -# x_host = jax.random.normal(k1, x_shape).astype(x_dtype) -# y_host = jax.random.normal(k2, y_shape).astype(y_dtype) - -# x_device = jax.device_put(x_host, x_sharding) -# y_device = jax.device_put(y_host, y_sharding) - -# return (x_device, y_device) - -# time_ms_list = iteration_timeit( -# jit_sharded_f, -# data_generator, -# matrix_dim=f"{m}", -# tries=num_runs, -# task="add", -# trace_dir=trace_dir, -# ) -# return {"time_ms_list": time_ms_list} - -# def add_calculate_metrics( -# m: int, dtype: jnp.dtype, time_ms_list: list[float] -# ) -> Dict[str, Any]: -# scale = 2 if dtype == jnp.bfloat16 else 1 -# total_bytes = scale * 3 * m -# total_bytes, total_bytes_all_devices = handle_based_on_sharding(total_bytes, SHARDING_STRATEGY) -# return unified_bytes_metrics(m, 0, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name) diff --git a/Ironwood/src/benchmark_send_recv.py b/Ironwood/src/benchmark_send_recv.py index 1eaa2b3..6374271 100644 --- a/Ironwood/src/benchmark_send_recv.py +++ b/Ironwood/src/benchmark_send_recv.py @@ -84,7 +84,11 @@ def send_recv_benchmark( dtype: jnp.dtype, trace_dir: str, ): - """Runs p2p communication, sending tensor_size_bytes from source to target device.""" + # 
pylint: disable=unused-argument + """ + Runs p2p communication, sending tensor_size_bytes from source to target + device. + """ device_count = jax.local_device_count() devices = mesh_utils.create_device_mesh((device_count,)) mesh = jax.sharding.Mesh(devices, "x") @@ -120,14 +124,14 @@ def p2p_send(source_id, target_id): target_recv_sizes, no_recvs, ) - input = jax.random.normal( + random_input = jax.random.normal( jax.random.key(0), (1, 8, last_dim), dtype=dtype ) output = jnp.zeros((1, 8, last_dim), dtype=dtype) with jax.named_scope(MARKER): ra2a = jax.lax.ragged_all_to_all( - operand=input, + operand=random_input, output=output, input_offsets=input_offsets, send_sizes=final_send_sizes, @@ -158,10 +162,10 @@ def p2p_send(source_id, target_id): def send_recv_benchmark_calculate_metrics( - source_id: int, - target_id: int, + source_id: int, # pylint: disable=unused-argument + target_id: int, # pylint: disable=unused-argument num_elements: int, - n_repeats: int, + n_repeats: int, # pylint: disable=unused-argument dtype: jnp.dtype, runtime_ms: float, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index 151fd4f..4927e68 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -33,10 +33,10 @@ def get_real_dtype_bytes(dtype) -> float: """Returns the real byte size of a dtype, handling sub-byte types.""" try: return jnp.finfo(dtype).bits / 8 - except Exception: + except (ValueError, TypeError): try: return jnp.iinfo(dtype).bits / 8 - except Exception: + except (ValueError, TypeError): return dtype.itemsize @@ -72,7 +72,7 @@ def multiple_iteration_timeit_from_trace_throttling( gap_strategy: str = None, ) -> list[float]: """Time a function with jax.profiler and get the run time from the trace.""" - LOCAL_TRACE_DIR = "/tmp/microbenchmarks_tmptrace" + local_trace_dir = "/tmp/microbenchmarks_tmptrace" if matrix_dim is not None: trace_name = f"{task}_dim_{matrix_dim}" @@ 
-86,7 +86,7 @@ def multiple_iteration_timeit_from_trace_throttling( # If the trace_dir isn't a local path, create one for dumping the trace for # parsing and getting metrics. if trace_dir and not is_local_directory_path(trace_dir): - tmp_trace_dir = f"{LOCAL_TRACE_DIR}/{trace_name}" + tmp_trace_dir = f"{local_trace_dir}/{trace_name}" if gap_strategy == "data_gen_once_block_every_iter": data_args = data_generator() @@ -162,7 +162,7 @@ def multiple_iteration_timeit_from_trace( """ Time a function with jax.profiler and get the run time from the trace. """ - LOCAL_TRACE_DIR = "/tmp/microbenchmarks_tmptrace" + local_trace_dir = "/tmp/microbenchmarks_tmptrace" if matrix_dim is not None: trace_name = f"{task}_dim_{matrix_dim}" @@ -176,7 +176,7 @@ def multiple_iteration_timeit_from_trace( # If the trace_dir isn't a local path, create one for dumping the trace for # parsing and getting metrics. if trace_dir and not is_local_directory_path(trace_dir): - tmp_trace_dir = f"{LOCAL_TRACE_DIR}/{trace_name}" + tmp_trace_dir = f"{local_trace_dir}/{trace_name}" # data_args = data_generator() with jax.profiler.trace(tmp_trace_dir): for i in range(tries): @@ -274,7 +274,7 @@ def iteration_timeit_from_trace( """ Time a function with jax.profiler and get the run time from the trace. """ - LOCAL_TRACE_DIR = "/tmp/microbenchmarks_tmptrace" + local_trace_dir = "/tmp/microbenchmarks_tmptrace" if matrix_dim is not None: trace_name = f"{task}_dim_{matrix_dim}" @@ -288,7 +288,7 @@ def iteration_timeit_from_trace( # If the trace_dir isn't a local path, create one for dumping the trace for # parsing and getting metrics. 
if trace_dir and not is_local_directory_path(trace_dir): - tmp_trace_dir = f"{LOCAL_TRACE_DIR}/{trace_name}" + tmp_trace_dir = f"{local_trace_dir}/{trace_name}" with jax.profiler.trace(tmp_trace_dir): for _ in range(tries): data_args = data_generator() @@ -380,6 +380,7 @@ def iteration_get_event_metrics_from_trace( trace: dict[str, Any], event_name_str_list: list[str], ) -> list[float]: + # pylint: disable=unused-variable # Rename the storage variable to reflect its contents selected_events = [] @@ -494,7 +495,7 @@ def iteration_timeit( outcomes_ms = [] print(f"[{task}] Running measurement loop with {tries} tries...") - for i in range(tries): + for i in range(tries): # pylint: disable=unused-variable # 1. Generate NEW random data (meets "no cache hit" rule) data_args = data_generator() jax.devices() # Force synchronization across devices @@ -632,7 +633,7 @@ def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]: try: task = TARGET_TASK_NAME_COLLECTIVES_MAP[task] return get_metrics_from_trace_tpu(trace, task) - except: + except (KeyError, ValueError, TypeError): return [-1.0] event_matcher = re.compile(task) @@ -693,15 +694,19 @@ def get_metrics_from_trace_tpu(trace: dict[str, Any], task: str) -> list[float]: return durations_ms -def is_local_directory_path(dir: str) -> bool: +def is_local_directory_path(directory: str) -> bool: """ Returns true if the path is a local path. """ - if not dir: # Handle None or empty string + if not directory: # Handle None or empty string return False # Heuristics for local paths - return dir.startswith("/") or dir.startswith("./") or dir.startswith("../") + return ( + directory.startswith("/") + or directory.startswith("./") + or directory.startswith("../") + ) def timeit_from_trace( @@ -716,7 +721,7 @@ def timeit_from_trace( """ Time a function with jax.profiler and get the run time from the trace. 
""" - LOCAL_TRACE_DIR = "/tmp/microbenchmarks_tmptrace" + local_trace_dir = "/tmp/microbenchmarks_tmptrace" jax.block_until_ready(f(*args)) # warm it up! @@ -732,7 +737,7 @@ def timeit_from_trace( # If the trace_dir isn't a local path, create one for dumping the trace for # parsing and getting metrics. if trace_dir and not is_local_directory_path(trace_dir): - tmp_trace_dir = f"{LOCAL_TRACE_DIR}/{trace_name}" + tmp_trace_dir = f"{local_trace_dir}/{trace_name}" print(trace_dir) with jax.profiler.trace(tmp_trace_dir): for _ in range(tries): @@ -757,7 +762,9 @@ def timeit_from_trace( def maybe_write_metrics_file( metrics_dir, metrics, metadata, test_name, test_start_time, test_end_time ): - """Writes metrics to a JSONL file to be consumed by the XLML metrics pipeline.""" + """ + Writes metrics to a JSONL file to be consumed by the XLML metrics pipeline. + """ # Only write metrics from one host. if jax.process_index() != 0: @@ -814,7 +821,7 @@ def upload_to_storage(trace_dir: str, local_file: str): def load_yaml_config(config_path: str) -> Dict[str, Any] | None: """Loads a YAML config file.""" try: - with open(config_path, "r") as f: + with open(config_path, "r", encoding="utf-8") as f: return yaml.safe_load(f) except FileNotFoundError: print(f"Warning: Config file not found at {config_path}") @@ -965,7 +972,7 @@ def rename_xla_dump( try: os.makedirs(dest_xla_dump_dir, exist_ok=True) shutil.copy(original_filepath, new_filepath) - except Exception as e: + except OSError as e: print( f"An unexpected error occurred while copy " f"'{original_filepath}': {e}" @@ -1014,7 +1021,7 @@ def extract_hlo_features_from_file( first_replica_group = None try: - with open(hlo_file_path, "r") as f: + with open(hlo_file_path, "r", encoding="utf-8") as f: content = f.read() except FileNotFoundError: print(f"Error: HLO file not found at {hlo_file_path}") @@ -1050,7 +1057,7 @@ def extract_hlo_features_from_file( content_rg = replica_groups_str[2:-2] first_group_str = 
content_rg.split("},{")[0] first_replica_group = [int(x) for x in first_group_str.split(",")] - except Exception as e: + except ValueError as e: print(f"Could not parse replica_groups in hlo_text: {e}") first_replica_group = None else: @@ -1190,6 +1197,7 @@ def create_mesh(strategy: ShardingStrategy) -> Mesh: def get_metrics_helper( params: Dict[str, Any], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + # pylint: disable=invalid-name """Helper function to build the metrics and metadata for the benchmark.""" exclude_param_keys = { "time_ms_list", @@ -1214,9 +1222,10 @@ def unified_flops_metrics( time_ms_list: list[float], total_flops: int, total_flops_all_devices: int, - peak_TFLOPS_per_device: float, + peak_TFLOPS_per_device: float, # pylint: disable=invalid-name dtype: str = None, ) -> Dict[str, Any]: + # pylint: disable=unused-argument """Calculates the metrics for the naive matmul benchmark.""" # Build dictionary of all the parameters in the function params = locals().items() @@ -1262,8 +1271,6 @@ def unified_flops_metrics( f"TFLOP / second, " f"MFU: {mfu_statistics.statistics["p50"]:.2%}" ) - # print() - # time_ms_list = # Gather the metrics to report. metadata.update( @@ -1277,7 +1284,6 @@ def unified_flops_metrics( ), "MFU": mfu_statistics.statistics["p50"], "total_flops": total_flops, - # "all_time_ms_list": f"{json.dumps(time_ms_list)}", } ) metrics.update(average_time_ms_statistics.serialize_statistics()) @@ -1291,6 +1297,7 @@ def unified_flops_metrics( def unified_bytes_metrics( + # pylint: disable=unused-argument m: int, n: int, time_ms_list: list[float], @@ -1405,5 +1412,5 @@ def get_peak_flops_multiplier(in_dtype_str: str) -> float: return 0.25 else: raise RuntimeError( - f"{in_dtype_lower} is not supported for setting peak_flops_multiplier." + f"No support for {in_dtype_lower} in setting peak_flops_multiplier." 
) diff --git a/Ironwood/src/run_benchmark.py b/Ironwood/src/run_benchmark.py index 01b75e2..c584fbb 100644 --- a/Ironwood/src/run_benchmark.py +++ b/Ironwood/src/run_benchmark.py @@ -58,7 +58,9 @@ ), } ATTENTION_BENCHMARK_MAP = { - "tokamax_splash_attention": "benchmark_attention.tokamax_splash_attention_benchmark", + "tokamax_splash_attention": ( + "benchmark_attention.tokamax_splash_attention_benchmark" + ), } HBM_BENCHMARK_MAP = { "single_device_hbm_copy": "benchmark_hbm.single_device_hbm_copy", @@ -134,7 +136,7 @@ def get_benchmark_config(config_path: str) -> Dict[str, Any]: """Load benchmark configuration from a YAML file.""" - with open(config_path, "r") as file: + with open(config_path, "r", encoding="utf-8") as file: return yaml.safe_load(file) @@ -142,7 +144,10 @@ def get_benchmark_config(config_path: str) -> Dict[str, Any]: def get_benchmark_functions( benchmark_name: str, ) -> Tuple[Callable[..., Any], Callable[..., Any]]: - """Dynamically load the benchmark function and its calculate_metrics function from the predefined map.""" + """ + Dynamically load the benchmark function and its calculate_metrics function + from the predefined map. + """ if benchmark_name not in BENCHMARK_MAP: raise ValueError( f"Benchmark {benchmark_name} is not defined in the map." @@ -156,7 +161,7 @@ def get_benchmark_functions( benchmark_func = getattr(module, func_name) except ModuleNotFoundError as e: raise ValueError( - f"Unable to import {module_path}.{func_name}. ModuleNotFoundError {e}." + f"Unable to import {module_path}.{func_name}. ModuleNotFoundError {e}." # pylint: disable=line-too-long ) from e except AttributeError as e: raise ValueError( @@ -206,7 +211,9 @@ def preprocess_benchmark_param( def generate_benchmark_params_sweeping( benchmark_sweep_params: List[Dict[str, Any]], ) -> List[Dict[str, Any]]: - """Generate benchmark parameters by sweeping through the specified ranges.""" + """ + Generate benchmark parameters by sweeping through the specified ranges. 
+ """ generated_params = [] for sweep_params in benchmark_sweep_params: param_sets = {} @@ -275,7 +282,8 @@ def write_to_csv( Args: csv_path: The path to the output CSV file. - calculate_metrics_results: A list of dictionaries with benchmark results. + calculate_metrics_results: A list of dictionaries with benchmark + results. """ if not calculate_metrics_results: raise ValueError("0 metrics results are collected.") @@ -298,7 +306,9 @@ def flatten_dict(current_dict: Dict) -> Dict: return output_dict def convert_dict_to_df(target_dict: Dict) -> pd.DataFrame: - """Converts a single benchmark result dictionary to a pandas DataFrame.""" + """ + Converts a single benchmark result dictionary to a pandas DataFrame. + """ flattened_dict = flatten_dict(target_dict) # This section is specific to collective benchmarks that produce @@ -335,6 +345,7 @@ def convert_dict_to_df(target_dict: Dict) -> pd.DataFrame: def run_single_benchmark(benchmark_config: Dict[str, Any], output_path: str): + # pylint: disable=inconsistent-quotes """Run a single benchmark with one or more configurations.""" # Extract benchmark details benchmark_name = benchmark_config.get("benchmark_name") @@ -360,7 +371,7 @@ def run_single_benchmark(benchmark_config: Dict[str, Any], output_path: str): param["num_runs"] = global_num_runs if not benchmark_name: - raise ValueError("Each benchmark must have a 'benchmark_name'.") + raise ValueError("Each benchmark must have a benchmark_name.") # Get the benchmark function @@ -372,19 +383,16 @@ def run_single_benchmark(benchmark_config: Dict[str, Any], output_path: str): # Run the benchmark calculate_metrics_results = [] - for id, benchmark_param in enumerate(benchmark_params): + for idx, benchmark_param in enumerate(benchmark_params): original_benchmark_param = copy.deepcopy(benchmark_param) benchmark_param = preprocess_benchmark_param( benchmark_param, - trace_dir=os.path.join(trace_dir, f"benchmark_{id}"), - ) - print( - f"Running benchmark: {benchmark_name} with 
params: {benchmark_param}" + trace_dir=os.path.join(trace_dir, f"benchmark_{idx}"), ) + print(f"Running {benchmark_name} with params: {benchmark_param}") test_start_time = ( datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + "Z" ) # "Z" indicates UTC - benchmark_func_params = inspect.signature(benchmark_func).parameters try: benchmark_results = benchmark_func(**benchmark_param) except Exception as e: # pylint: disable=broad-except @@ -447,6 +455,7 @@ def run_single_benchmark(benchmark_config: Dict[str, Any], output_path: str): def main(args): + # pylint: disable=redefined-outer-name """Main function.""" # Load configuration config_path = args.config @@ -472,7 +481,6 @@ def main(args): "XLA_IR_DEBUG": "1", "XLA_HLO_DEBUG": "1", "PJRT_DEVICE": "TPU", - # "LIBTPU_INIT_ARGS": "--xla_tpu_scoped_vmem_limit_kib=25602", }, ) ) @@ -516,6 +524,7 @@ def run_benchmark_multithreaded(benchmark_config, output_path): benchmark_name ) + # pylint: disable=inconsistent-quotes print(f"\n{'=' * 30}Starting benchmark '{benchmark_name}'{'=' * 30}\n") # Start a trace if requested diff --git a/src/all_gather.py b/src/all_gather.py index 0270d60..b408016 100644 --- a/src/all_gather.py +++ b/src/all_gather.py @@ -18,7 +18,9 @@ def all_gather(matrix_dim): - """Performs an all_gather operation and calculates the achieved bandwidth.""" + """ + Performs an all_gather operation and calculates the achieved bandwidth. 
+ """ dtype = jax.numpy.bfloat16 matrix = jax.numpy.arange(matrix_dim * matrix_dim, dtype=dtype).reshape( matrix_dim, matrix_dim @@ -120,6 +122,7 @@ def run_benchmark(): "p90_achieved_bandwidth_gbyte_s": p90_achieved_bandwidth_gbyte_s, } if METRICS_JSONL_DIR: + # pylint: disable=no-value-for-parameter maybe_write_metrics_file( METRICS_JSONL_DIR, metrics, diff --git a/src/all_reduce.py b/src/all_reduce.py index e12f536..e58b4ad 100644 --- a/src/all_reduce.py +++ b/src/all_reduce.py @@ -90,7 +90,7 @@ def run_benchmark(): break except Exception as e: # pylint: disable=broad-exception-caught print( - f"Exception: {e} occurred at size {matrix_size} x {matrix_size}.\n" + f"Exception: {e} occurred at size {matrix_size} x {matrix_size}.\n" # pylint: disable=line-too-long ) break if TRACE_BASE_DIR: @@ -116,6 +116,7 @@ def run_benchmark(): "p90_achieved_bandwidth_gbyte_s": p90_achieved_bandwidth_gbyte_s, } if METRICS_JSONL_DIR: + # pylint: disable=no-value-for-parameter maybe_write_metrics_file( METRICS_JSONL_DIR, metrics, diff --git a/src/benchmark_attention.py b/src/benchmark_attention.py index a6e8bac..df075eb 100644 --- a/src/benchmark_attention.py +++ b/src/benchmark_attention.py @@ -7,9 +7,9 @@ key (k), and value (v) vectors. 2. pallas_flash_attention_benchmark: attention with the pallas flash attention kernel. -(https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py) +(https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py) # pylint: disable=line-too-long 3. splash_attention_benchmark: attention with the splash attention kernel. - (https://github.com/jax-ml/jax/tree/main/jax/experimental/pallas/ops/tpu/splash_attention) +(https://github.com/jax-ml/jax/tree/main/jax/experimental/pallas/ops/tpu/splash_attention) # pylint: disable=line-too-long 4. flax_nnx_attention_benchmark: attention with the flax nnx attention library. 5. 
flax_linen_attention_benchmark: attention with the flax linen attention library. @@ -122,6 +122,7 @@ def f(q, k, v, causal, scale): jax.block_until_ready(output) # Run benchmark + # pylint: disable=unexpected-keyword-arg time_ms_list = simple_timeit( f, q, @@ -179,6 +180,7 @@ def f(q, k, v, causal): jax.block_until_ready(output) # Run benchmark + # pylint: disable=unexpected-keyword-arg time_ms_list = simple_timeit( f, q, @@ -264,6 +266,7 @@ def f(q, k, v, causal): jax.block_until_ready(output) # Run benchmark + # pylint: disable=unexpected-keyword-arg time_ms_list = simple_timeit( f, q, @@ -323,6 +326,7 @@ def f(q, k, v): jax.block_until_ready(output) # Run benchmark + # pylint: disable=unexpected-keyword-arg time_ms_list = simple_timeit( f, q, @@ -379,6 +383,7 @@ def f(q, k, v): jax.block_until_ready(output) # Run benchmark + # pylint: disable=unexpected-keyword-arg time_ms_list = simple_timeit( f, q, @@ -445,6 +450,7 @@ def f(q, k, v, causal): jax.block_until_ready(output) # Run benchmark + # pylint: disable=unexpected-keyword-arg time_ms_list = simple_timeit( f, q, diff --git a/src/benchmark_collectives.py b/src/benchmark_collectives.py index ace4119..1993ec1 100644 --- a/src/benchmark_collectives.py +++ b/src/benchmark_collectives.py @@ -63,7 +63,10 @@ def generate_metrics_statistics( matrix_size_gbyte: float, metrics: Dict[str, Any], ) -> None: - """Calculates statistics for a metrics list, prints p50, and updates the metrics dict.""" + """ + Calculates statistics for a metrics list, prints p50, and updates the + metrics dict. 
+ """ if not metrics_list: return statistics = MetricsStatistics( @@ -72,7 +75,8 @@ def generate_metrics_statistics( ) print( f"{benchmark_name}: Matrix size: {matrix_dim}x{matrix_dim}, {dtype=}, " - f"{matrix_size_gbyte=}, {metrics_name} (median) = {statistics.statistics['p50']}" + f"{matrix_size_gbyte=}, {metrics_name} (median) = " + f"{statistics.statistics["p50"]}" ) metrics.update(statistics.serialize_statistics()) @@ -92,6 +96,7 @@ def benchmark_collective( warmup_tries: int = 10, trace_dir: str = None, ) -> list[float]: + # pylint: disable=unexpected-keyword-arg """ Helper function to run a collective benchmark on DCN and ICI. @@ -101,10 +106,12 @@ def benchmark_collective( mesh: The JAX device mesh to run the benchmark on. matrix: The input array for the collective operation. matrix_dim: The dimension of the input matrix. - axis_name: The name of the axis over which the op is performed (e.g., "dcn" or "ici"). + axis_name: The name of the axis over which the op is performed (e.g., + "dcn" or "ici"). in_specs: The input sharding specs. out_specs: The output sharding specs. - check_rep: Indicate if replication check is needed. Can be skipped in some situations. + check_rep: Indicate if replication check is needed. Can be skipped in some + situations. jax_op_kwargs: Optional keyword arguments for the JAX operation. num_runs: The number of times to run the benchmark operation for timing. warmup_tries: The number of warmup runs before the actual timing. @@ -159,8 +166,8 @@ def psum_benchmark( dtype: The data type of the matrix. dcn_size: The number of DCN nodes, or number of slices. If 1, then no DCN benchmark is run. - ici_size: The number of chips in a single slice. If 1, then no ICI benchmark - is run. The ICI and DCN + ici_size: The number of chips in a single slice. If 1, then no ICI + benchmark is run. The ICI and DCN Returns: The measured time for the DCN and ICI benchmarks. 
@@ -304,8 +311,8 @@ def psum_scatter_benchmark( dtype: The data type of the matrix. dcn_size: The number of DCN nodes, or number of slices. If 1, then no DCN benchmark is run. - ici_size: The number of chips in a single slice. If 1, then no ICI benchmark - is run. The ICI and DCN + ici_size: The number of chips in a single slice. If 1, then no ICI + benchmark is run. The ICI and DCN Returns: The measured time for the DCN and ICI benchmarks. @@ -371,8 +378,8 @@ def psum_scatter_benchmark_calculate_metrics( # Calculate metrics for DCN benchmark if dcn_size > 1 and dcn_time_ms_list is not None: - # each sharded matrix size is matrix_size_gbyte / dcn_size and then it needs - # to use (dcn_size - 1) steps in a ring algorithm + # each sharded matrix size is matrix_size_gbyte / dcn_size and then it + # needs to use (dcn_size - 1) steps in a ring algorithm dcn_bandwidth_gbyte_s_list = [ matrix_size_gbyte * (dcn_size - 1) @@ -403,8 +410,8 @@ def psum_scatter_benchmark_calculate_metrics( # Calculate metrics for ICI benchmark if ici_size > 1 and ici_time_ms_list is not None: - # each sharded matrix size is matrix_size_gbyte / ici_size and then it needs - # to use (ici_size - 1) steps in a ring algorithm + # each sharded matrix size is matrix_size_gbyte / ici_size and then it + # needs to use (ici_size - 1) steps in a ring algorithm ici_bandwidth_gbyte_s_list = [ matrix_size_gbyte * (ici_size - 1) / ici_size / (ici_time_ms / 1e3) for ici_time_ms in ici_time_ms_list @@ -452,8 +459,8 @@ def all_gather_benchmark( dtype: The data type of the matrix. dcn_size: The number of DCN nodes, or number of slices. If 1, then no DCN benchmark is run. - ici_size: The number of chips in a single slice. If 1, then no ICI benchmark - is run. The ICI and DCN + ici_size: The number of chips in a single slice. If 1, then no ICI + benchmark is run. The ICI and DCN Returns: The measured time for the DCN and ICI benchmarks. 
@@ -521,8 +528,8 @@ def all_gather_benchmark_calculate_metrics( # Calculate metrics for DCN benchmark if dcn_size > 1 and dcn_time_ms_list is not None: - # each sharded matrix size is matrix_size_gbyte / dcn_size and then it needs - # to use (dcn_size - 1) steps in a ring algorithm + # each sharded matrix size is matrix_size_gbyte / dcn_size and then it + # needs to use (dcn_size - 1) steps in a ring algorithm dcn_bandwidth_gbyte_s_list = [ matrix_size_gbyte * (dcn_size - 1) / dcn_size / (dcn_time_ms / 1e3) for dcn_time_ms in dcn_time_ms_list @@ -549,8 +556,8 @@ def all_gather_benchmark_calculate_metrics( # Calculate metrics for ICI benchmark if ici_size > 1 and ici_time_ms_list is not None: - # each sharded matrix size is matrix_size_gbyte / ici_size and then it needs - # to use (ici_size - 1) steps in a ring algorithm + # each sharded matrix size is matrix_size_gbyte / ici_size and then it + # needs to use (ici_size - 1) steps in a ring algorithm ici_bandwidth_gbyte_s_list = [ matrix_size_gbyte * (ici_size - 1) / ici_size / (ici_time_ms / 1e3) for ici_time_ms in ici_time_ms_list @@ -598,8 +605,8 @@ def ppermute_benchmark( dtype: The data type of the matrix. dcn_size: The number of DCN nodes, or number of slices. If 1, then no DCN benchmark is run. - ici_size: The number of chips in a single slice. If 1, then no ICI benchmark - is run. The ICI and DCN + ici_size: The number of chips in a single slice. If 1, then no ICI + benchmark is run. The ICI and DCN Returns: The measured time for the DCN and ICI benchmarks. 
@@ -665,8 +672,8 @@ def ppermute_benchmark_calculate_metrics( # Calculate metrics for DCN benchmark if dcn_size > 1 and dcn_time_ms_list is not None: - # each sharded matrix size is matrix_size_gbyte / dcn_size and then it needs - # to use 1 step + # each sharded matrix size is matrix_size_gbyte / dcn_size and then it + # needs to use 1 step dcn_bandwidth_gbyte_s_list = [ matrix_size_gbyte / dcn_size / (dcn_time_ms / 1e3) for dcn_time_ms in dcn_time_ms_list @@ -693,8 +700,8 @@ def ppermute_benchmark_calculate_metrics( # Calculate metrics for ICI benchmark if ici_size > 1 and ici_time_ms_list is not None: - # each sharded matrix size is matrix_size_gbyte / ici_size and then it needs - # to use 1 step + # each sharded matrix size is matrix_size_gbyte / ici_size and then it + # needs to use 1 step ici_bandwidth_gbyte_s_list = [ matrix_size_gbyte / (ici_time_ms / 1e3) for ici_time_ms in ici_time_ms_list @@ -738,8 +745,8 @@ def all_to_all_benchmark( dtype: The data type of the matrix. dcn_size: The number of DCN nodes, or number of slices. If 1, then no DCN benchmark is run. - ici_size: The number of chips in a single slice. If 1, then no ICI benchmark - is run. The ICI and DCN + ici_size: The number of chips in a single slice. If 1, then no ICI + benchmark is run. The ICI and DCN Returns: The measured time for the DCN and ICI benchmarks. 
diff --git a/src/benchmark_convolution.py b/src/benchmark_convolution.py index fa77f5f..3329cc5 100644 --- a/src/benchmark_convolution.py +++ b/src/benchmark_convolution.py @@ -61,11 +61,14 @@ def f(x, kernel, mode): print(f"{task_name} Benchmark:") print( - f"Input Shape: {input_shape}, Kernel Shape: {kernel_shape}, Output Shape:" - f" {output.shape}, Padding Mode: {padding_mode}" + f"Input Shape: {input_shape}, " + f"Kernel Shape: {kernel_shape}, " + f"Output Shape: {output.shape}, " + f"Padding Mode: {padding_mode}" ) # Time the operation + # pylint: disable=unexpected-keyword-arg time_ms_list = simple_timeit( f, x, @@ -88,7 +91,9 @@ def convolve_common_calculate_metrics( time_ms_list: list[float], # pylint: disable=unused-argument ) -> Dict[str, Any]: - """Helper function to calculate the metrics for the convolution benchmarks.""" + """ + Helper function to calculate the metrics for the convolution benchmarks. + """ # Build dictionary of all the parameters in the function params = locals().items() exclude_param_keys = {"time_ms_list"} @@ -122,10 +127,10 @@ def convolve_common_calculate_metrics( # Print results print(f"Total flops: {flops}") print( - f"Average Execution Time: {time_ms_statistics.statistics['p50']:.4f} ms" + f"Average Execution Time: {time_ms_statistics.statistics["p50"]:.4f} ms" ) print( - f"FLOPS Utilization(median): {gflops_per_sec_statistics.statistics['p50']:.2f} GFLOPS/sec\n" + f"FLOPS Utilization(median): {gflops_per_sec_statistics.statistics["p50"]:.2f} GFLOPS/sec\n" # pylint: disable=line-too-long ) # Gather the metrics to report. 
metadata.update({"total_flops": flops}) @@ -321,12 +326,13 @@ def f(x, kernel, stride, dilation, mode): print("lax_conv_general_dilated Benchmark:") print( - f"Input Shape: {input_shape}, Kernel Shape: {kernel_shape}, Output shape:" + f"Input Shape: {input_shape}, Kernel Shape: {kernel_shape}, Output shape:" # pylint: disable=line-too-long f" {output.shape} Stride: {stride}, Dilation: {dilation}, Padding Mode:" f" {padding_mode}" ) # Time the operation + # pylint: disable=unexpected-keyword-arg time_ms_list = simple_timeit( f, x, @@ -398,11 +404,9 @@ def lax_conv_general_dilated_calculate_metrics( ) # Print results print(f"Total flops: {flops}") + print(f"Average Execution Time: {time_ms_statistics.statistics["p50"]:.4f} ms") # pylint: disable=line-too-long print( - f"Average Execution Time: {time_ms_statistics.statistics['p50']:.4f} ms" - ) - print( - f"FLOPS Utilization(median): {gflops_per_sec_statistics.statistics['p50']:.2f} GFLOPS/sec\n" + f"FLOPS Utilization(median): {gflops_per_sec_statistics.statistics["p50"]:.2f} GFLOPS/sec\n" # pylint: disable=line-too-long ) # Gather the metrics to report. 
metadata.update({"total_flops": flops}) diff --git a/src/benchmark_hbm.py b/src/benchmark_hbm.py index dcad0c8..051013d 100644 --- a/src/benchmark_hbm.py +++ b/src/benchmark_hbm.py @@ -41,6 +41,7 @@ def f(a): jax.block_until_ready(output) # Run the benchmark + # pylint: disable=unexpected-keyword-arg time_ms_list = simple_timeit( jitted_f, a, @@ -74,8 +75,9 @@ def single_chip_hbm_copy_calculate_metrics( metrics_list=bw_gbyte_sec_list, metrics_name="bw_gbyte_sec" ) print( - f"Tensor size: {tensor_size_bytes / 1024**2} MB, time taken (median):" - f" {time_statistics.statistics['p50']:.4f} ms, bandwidth (median): {statistics.statistics['p50']:.3f} GB/s" + f"Tensor size: {tensor_size_bytes / 1024**2} MB, " + f"time taken (median): {time_statistics.statistics["p50"]:.4f} ms, " + f"bandwidth (median): {statistics.statistics["p50"]:.3f} GB/s" ) print() # Gather the metrics to report. diff --git a/src/benchmark_matmul.py b/src/benchmark_matmul.py index 48dd0c1..18c8ea4 100644 --- a/src/benchmark_matmul.py +++ b/src/benchmark_matmul.py @@ -73,6 +73,7 @@ def naive_matmul( trace_dir: str = None, warmup_tries: int = 10, ) -> Dict[str, Any]: + # pylint: disable=unexpected-keyword-arg """Benchmarks the jax.numpy.einsum.""" def f(x, y): @@ -144,9 +145,9 @@ def naive_matmul_calculate_metrics( ) print( f"Total floating-point ops: {total_flops}, Performance (median):" - f" {tflops_per_sec_statistics.statistics['p50']:.2f} TFLOPs / second, Total GBs transferred (median):" + f" {tflops_per_sec_statistics.statistics["p50"]:.2f} TFLOPs / second, Total GBs transferred (median):" # pylint: disable=line-too-long f" {total_gigabytes_transferred:.2f} GB, GBs per second:" - f" {data_transfer_gbyte_sec_statistics.statistics['p50']:.2f} GB/s" + f" {data_transfer_gbyte_sec_statistics.statistics["p50"]:.2f} GB/s" ) print() # Gather the metrics to report. 
@@ -172,6 +173,7 @@ def single_host_naive_matmul( trace_dir: str = None, warmup_tries: int = 10, ) -> Dict[str, Any]: + # pylint: disable=unexpected-keyword-arg """Benchmarks matmul on a single device without any sharding.""" def f(x, y): @@ -229,10 +231,12 @@ def single_host_naive_matmul_calculate_metrics( metrics_name="data_transfer_gbyte_sec", ) print( - f"Total floating-point ops: {total_flops}, Performance (median):" - f" {tflops_per_sec_statistics.statistics['p50']:.2f} TFLOPs / second, Total GBs transferred (median):" - f" {total_gigabytes_transferred:.2f} GB, GBs per second:" - f" {data_transfer_gbyte_sec_statistics.statistics['p50']:.2f} GB/s" + f"Total floating-point ops: {total_flops}, " + f"Performance (median): " + f"{tflops_per_sec_statistics.statistics["p50"]:.2f} TFLOPs / second, " + f"Total GBs transferred (median): " + f"{total_gigabytes_transferred:.2f} GB, GBs per second: " + f"{data_transfer_gbyte_sec_statistics.statistics["p50"]:.2f} GB/s" ) print() # Gather the metrics to report. @@ -258,6 +262,7 @@ def collective_matmul_one_direction( trace_dir: str = None, warmup_tries: int = 10, ) -> Dict[str, Any]: + # pylint: disable=unexpected-keyword-arg """Benchmarks the collective matmul that does permute in one direction.""" def f(lhs, rhs): @@ -328,7 +333,9 @@ def scanned_call(i, carrys): def collective_matmul_one_direction_calculate_metrics( m: int, k: int, n: int, time_ms_list: list[float] ) -> Dict[str, Any]: - """Calculates the metrics for the collective matmul one direction benchmark.""" + """ + Calculates the metrics for the collective matmul one direction benchmark. 
+ """ # Build dictionary of all the parameters in the function params = locals().items() metadata = get_metrics_helper(params) @@ -348,7 +355,7 @@ def collective_matmul_one_direction_calculate_metrics( ) print( f"Total floating-point ops: {total_flops}, Performance (median):" - f" {tflops_per_sec_statistics.statistics['p50']:.2f} TFLOPs / second" + f" {tflops_per_sec_statistics.statistics["p50"]:.2f} TFLOPs / second" ) print() # Gather the metrics to report. @@ -372,6 +379,7 @@ def collective_matmul_two_directions( trace_dir: str = None, warmup_tries: int = 10, ) -> Dict[str, Any]: + # pylint: disable=unexpected-keyword-arg """Benchmarks the collective matmul that does permute in two directions.""" def f(activations, weights): @@ -481,7 +489,9 @@ def scanned_call(i, carrys): def collective_matmul_two_directions_calculate_metrics( m: int, k: int, n: int, time_ms_list: list[float] ) -> Dict[str, Any]: - """Calculates the metrics for the collective matmul two direction benchmark.""" + """ + Calculates the metrics for the collective matmul two direction benchmark. + """ # Build dictionary of all the parameters in the function params = locals().items() metadata = get_metrics_helper(params) @@ -501,7 +511,7 @@ def collective_matmul_two_directions_calculate_metrics( ) print( f"Total floating-point ops: {total_flops}, Performance (median):" - f" {tflops_per_sec_statistics.statistics['p50']:.2f} TFLOPs / second" + f" {tflops_per_sec_statistics.statistics["p50"]:.2f} TFLOPs / second" ) print() # Gather the metrics to report. 
@@ -525,6 +535,7 @@ def multilayer_collective_matmul( trace_dir: str = None, warmup_tries: int = 10, ) -> Dict[str, Any]: + # pylint: disable=unexpected-keyword-arg """Benchmarks the multilayer collective matmul.""" def f(act, weights): @@ -601,7 +612,7 @@ def multilayer_collective_matmul_calculate_metrics( ) print( f"Total floating-point ops: {total_flops}, Performance (median):" - f" {tflops_per_sec_statistics.statistics['p50']:.2f} TFLOPs / second" + f" {tflops_per_sec_statistics.statistics["p50"]:.2f} TFLOPs / second" ) print() # Gather the metrics to report. diff --git a/src/benchmark_utils.py b/src/benchmark_utils.py index 01f7711..bfc8b61 100644 --- a/src/benchmark_utils.py +++ b/src/benchmark_utils.py @@ -54,7 +54,8 @@ def simple_timeit( # --- Measurement Loop --- outcomes_ms = [] - # Final barrier after warmup to ensure all hosts are ready to start measuring together. + # Final barrier after warmup to ensure all hosts are ready to start + # measuring together. if is_multihost: multihost_utils.sync_global_devices(f"warmup_done_{task}") @@ -64,7 +65,8 @@ def simple_timeit( jax.block_until_ready(f(*args)) - # Synchronize (Multi-Host Only): Wait for ALL hosts to finish the operation. + # Synchronize (Multi-Host Only): Wait for ALL hosts to finish the + # operation. if is_multihost: multihost_utils.sync_global_devices(f"end_run_{i}_{task}") @@ -120,15 +122,19 @@ def get_metrics_from_trace(trace: dict[str, Any], task: str) -> float: return durations_ms -def is_local_directory_path(dir: str) -> bool: +def is_local_directory_path(directory: str) -> bool: """ Returns true if the path is a local path. 
""" - if not dir: # Handle None or empty string + if not directory: # Handle None or empty string return False # Heuristics for local paths - return dir.startswith("/") or dir.startswith("./") or dir.startswith("../") + return ( + directory.startswith("/") + or directory.startswith("./") + or directory.startswith("../") + ) def timeit_from_trace( @@ -143,7 +149,7 @@ def timeit_from_trace( """ Time a function with jax.profiler and get the run time from the trace. """ - LOCAL_TRACE_DIR = "/tmp/microbenchmarks_tmptrace" + local_trace_dir = "/tmp/microbenchmarks_tmptrace" is_multihost = jax.process_count() > 1 # warmup loop @@ -163,9 +169,10 @@ def timeit_from_trace( trace_full_dir = f"{trace_dir}/{trace_name}" tmp_trace_dir = trace_full_dir - # If the trace_dir isn't a local path, create one for dumping the trace for parsing and getting metrics. + # If the trace_dir isn't a local path, create one for dumping the trace for + # parsing and getting metrics. if trace_dir and not is_local_directory_path(trace_dir): - tmp_trace_dir = f"{LOCAL_TRACE_DIR}/{trace_name}" + tmp_trace_dir = f"{local_trace_dir}/{trace_name}" with jax.profiler.trace(tmp_trace_dir): for i in range(tries): with jax.profiler.TraceAnnotation(task): @@ -183,14 +190,18 @@ def timeit_from_trace( def maybe_write_metrics_file( metrics_dir, metrics, metadata, test_name, test_start_time, test_end_time ): - """Writes metrics to a JSONL file to be consumed by the XLML metrics pipeline.""" + """ + Writes metrics to a JSONL file to be consumed by the XLML metrics pipeline. + """ local_devices = jax.local_devices() tpu_worker_id = int(os.getenv("TPU_WORKER_ID", "0")) is_multislice = hasattr(local_devices[0], "slice_index") - # For multi-slice workload, the result is only written by the first host on the first slice (slice_index=0, tpu_worker_id=0). - # For single-slice workload, the result is only written by the first host (tpu_worker_id=0). 
+ # For multi-slice workload, the result is only written by the first host on + # the first slice (slice_index=0, tpu_worker_id=0). + # For single-slice workload, the result is only written by the first host + # (tpu_worker_id=0). if is_multislice: if local_devices[0].slice_index != 0 or tpu_worker_id != 0: return @@ -239,7 +250,8 @@ def upload_to_storage(trace_dir: str, local_file: str): except subprocess.CalledProcessError as e: print( - f"Failed to upload '{local_file}' to GCS: '{trace_dir}'. Error: {e.stderr.decode()}" + f"Failed to upload '{local_file}' to GCS: '{trace_dir}'. " + f"Error: {e.stderr.decode()}" ) else: raise KeyError(f"{trace_dir} is not a valid GCS path.") @@ -288,8 +300,9 @@ def rename_xla_dump( ): """ Finds the latest XLA dump file matching '*jit_f*before_optimizations*.txt', - then identifies all other files that share the same 'jit_f.[unique_id]' identifier - and renames them to 'benchmark_name_serialized_params.original_suffix_with_extension'. + then identifies all other files that share the same 'jit_f.[unique_id]' + identifier and renames them to + 'benchmark_name_serialized_params.original_suffix_with_extension'. """ serialized_benchmark_param = "_".join( @@ -302,7 +315,8 @@ def rename_xla_dump( if not matching_anchor_files: print( - f"No files found for anchor pattern: '{anchor_pattern}'. No files will be renamed." + f"No files found for anchor pattern: '{anchor_pattern}'. " + f"No files will be renamed." ) return @@ -317,13 +331,15 @@ def rename_xla_dump( if not jit_id_match: print( - f"Could not extract 'jit_f.[unique_id]' from '{filename_base}'. Cannot proceed with renaming." + f"Could not extract 'jit_f.[unique_id]' from '{filename_base}'. " + f"Cannot proceed with renaming." 
) return common_jit_id_prefix = jit_id_match.group(1) - # Find all files in the directory that contain this specific common_jit_id_prefix + # Find all files in the directory that contain this specific + # common_jit_id_prefix all_related_files_pattern = os.path.join( tmp_xla_dump_dir, f"*{common_jit_id_prefix}*" ) @@ -331,7 +347,8 @@ def rename_xla_dump( if not all_related_files: print( - f"No files found containing '{common_jit_id_prefix}'. This is unexpected if an anchor was found." + f"No files found containing '{common_jit_id_prefix}'. " + f"This is unexpected if an anchor was found." ) return @@ -341,26 +358,26 @@ def rename_xla_dump( original_filename = os.path.basename(original_filepath) # Find the specific suffix part *after* the common_jit_id_prefix. - # This regex looks for the common_jit_id_prefix, then captures everything after it, - # ensuring it starts with a dot if there's more. - # Example: if original_filename is 'module_0080.jit_f.cl_747713181.after_codegen.txt' + # This regex looks for the common_jit_id_prefix, then captures + # everything after it, ensuring it starts with a dot if there's more. + # Example: if original_filename is + # 'module_0080.jit_f.cl_747713181.after_codegen.txt' # and common_jit_id_prefix is 'jit_f.cl_747713181' # we want to capture '.after_codegen.txt' suffix_match = re.search( re.escape(common_jit_id_prefix) + r"(\..*)", original_filename ) - + original_suffix_with_extension = "" if suffix_match: - original_suffix_with_extension = suffix_match.group( - 1 - ) # e.g., '.after_codegen.txt' + original_suffix_with_extension = suffix_match.group(1) new_filename = f"{new_base_name}{original_suffix_with_extension}" new_filepath = os.path.join(dest_xla_dump_dir, new_filename) if original_filepath == new_filepath: print( - f"Skipping: '{original_filename}' already has the desired name or path." + f"Skipping: '{original_filename}' already has the desired " + f"name or path." 
) continue @@ -369,9 +386,10 @@ def rename_xla_dump( try: os.makedirs(dest_xla_dump_dir, exist_ok=True) shutil.copy(original_filepath, new_filepath) - except Exception as e: + except OSError as e: print( - f"An unexpected error occurred while copy '{original_filepath}': {e}" + f"An unexpected error occurred while copy " + f"'{original_filepath}': {e}" ) else: upload_to_storage( diff --git a/src/run_benchmark.py b/src/run_benchmark.py index aab7e22..7183a2d 100644 --- a/src/run_benchmark.py +++ b/src/run_benchmark.py @@ -89,7 +89,7 @@ def get_benchmark_config(config_path: str) -> Dict[str, Any]: """Load benchmark configuration from a YAML file.""" - with open(config_path, "r") as file: + with open(config_path, "r", encoding="utf-8") as file: return yaml.safe_load(file) @@ -97,7 +97,10 @@ def get_benchmark_config(config_path: str) -> Dict[str, Any]: def get_benchmark_functions( benchmark_name: str, ) -> Tuple[Callable[..., Any], Callable[..., Any]]: - """Dynamically load the benchmark function and its calculate_metrics function from the predefined map.""" + """ + Dynamically load the benchmark function and its calculate_metrics function + from the predefined map. + """ if benchmark_name not in BENCHMARK_MAP: raise ValueError( f"Benchmark {benchmark_name} is not defined in the map." @@ -111,7 +114,7 @@ def get_benchmark_functions( benchmark_func = getattr(module, func_name) except ModuleNotFoundError as e: raise ValueError( - f"Unable to import {module_path}.{func_name}. ModuleNotFoundError {e}." + f"Unable to import {module_path}.{func_name}. ModuleNotFoundError {e}." 
# pylint: disable=line-too-long ) from e except AttributeError as e: raise ValueError( @@ -161,7 +164,9 @@ def preprocess_benchmark_param( def generate_benchmark_params_sweeping( benchmark_sweep_params: List[Dict[str, Any]], ) -> List[Dict[str, Any]]: - """Generate benchmark parameters by sweeping through the specified ranges.""" + """ + Generate benchmark parameters by sweeping through the specified ranges. + """ generated_params = [] for sweep_params in benchmark_sweep_params: param_sets = {} @@ -186,8 +191,8 @@ def generate_benchmark_params_sweeping( current_value += increase_by else: raise ValueError( - "In sweep mode, user must provide either multiplier or" - " increase_by value." + "In sweep mode, user must provide either multiplier" + " or increase_by value." ) # Add the generated values to the param set param_sets[key] = param_values @@ -217,13 +222,14 @@ def write_to_csv( This function takes a list of dictionaries, where each dictionary contains the 'metadata' and 'metrics' from a benchmark run. It processes each - dictionary by flattening it, and sanitizing the inputs to make sure it's suitable - for Pandas DataFrames creation. All resulting DataFrames are concatenated and written to - the specified CSV file. + dictionary by flattening it, and sanitizing the inputs to make sure it's + suitable for Pandas DataFrames creation. All resulting DataFrames are + concatenated and written to the specified CSV file. Args: csv_path: The path to the output CSV file. - calculate_metrics_results: A list of dictionaries with benchmark results. + calculate_metrics_results: A list of dictionaries with benchmark + results. 
""" if not calculate_metrics_results: raise ValueError("0 metrics results are collected.") @@ -249,15 +255,19 @@ def flatten_and_sanitize_dict(current_dict: Dict) -> Dict: return output_dict def convert_dict_to_df(target_dict: Dict) -> pd.DataFrame: - """Converts a single benchmark result dictionary to a pandas DataFrame.""" + """ + Converts a single benchmark result dictionary to a pandas DataFrame. + """ flattened_dict = flatten_and_sanitize_dict(target_dict) df = pd.DataFrame(flattened_dict, index=[0]) return df # TODO(hylin2002@) - # This is a temporary workaround to generate a properly formatted CSV file for the output metrics. - # We should revert this PR and refactor the code such that metrics object is a flatten dict that can be easily exported as a CSV. - # For other information that requires nested structures, we should serialize it into a json file." + # This is a temporary workaround to generate a properly formatted CSV file + # for the output metrics. We should revert this PR and refactor the code + # such that metrics object is a flatten dict that can be easily exported as + # a CSV. For other information that requires nested structures, we should + # serialize it into a json file. try: df_list = [ convert_dict_to_df(each) for each in calculate_metrics_results @@ -267,13 +277,15 @@ def convert_dict_to_df(target_dict: Dict) -> pd.DataFrame: df.to_csv(csv_path, index=False, sep="\t") print(f"Metrics written to CSV at {csv_path}.") - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught # Temporary workaround to catch all exceptions and print a warning as - # `lax_conv_general_dilated` benchmark fails during nightly test runs at `convert_dict_to_df`. + # `lax_conv_general_dilated` benchmark fails during nightly test runs at + # `convert_dict_to_df`. 
print(f"Failed to write metrics to CSV: {e}") def run_single_benchmark(benchmark_config: Dict[str, Any]): + # pylint: disable=inconsistent-quotes """Run a single benchmark with one or more configurations.""" # Extract benchmark details benchmark_name = benchmark_config.get("benchmark_name") @@ -291,7 +303,7 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]): warmup_tries = warmup_tries if warmup_tries is not None else 10 if not benchmark_name: - raise ValueError("Each benchmark must have a 'benchmark_name'.") + raise ValueError("Each benchmark must have a benchmark_name.") # Get the benchmark function benchmark_func, calculate_metrics_func = get_benchmark_functions( @@ -308,7 +320,7 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]): benchmark_param, trace_dir=trace_dir ) print( - f"Running benchmark: {benchmark_name} with params: {benchmark_param}" + f"Running benchmark: {benchmark_name} with params: {benchmark_param}" # pylint: disable=line-too-long ) test_start_time = ( datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + "Z" @@ -392,7 +404,6 @@ def main(config_path: str, multithreaded: bool): "XLA_IR_DEBUG": "1", "XLA_HLO_DEBUG": "1", "PJRT_DEVICE": "TPU", - # "LIBTPU_INIT_ARGS": "--xla_tpu_scoped_vmem_limit_kib=25602", }, ) ) @@ -431,7 +442,10 @@ def run_benchmark_multithreaded(benchmark_config): benchmark_name ) - print(f"\n{'=' * 30}Starting benchmark '{benchmark_name}'{'=' * 30}\n") + # pylint: disable=inconsistent-quotes + print( + f"\n{'=' * 30}Starting benchmark '{benchmark_name}'{'=' * 30}\n" + ) # Start a trace if requested test_name = f"t_{benchmark_name}_" + "".join( @@ -469,7 +483,8 @@ def run_benchmark_multithreaded(benchmark_config): future.result() ) # Get the result from the future - # Filter benchmark_results to include only keys present in calculate_metrics_func + # Filter benchmark_results to include only keys present in + # calculate_metrics_func calculate_metrics_params = inspect.signature( 
calculate_metrics_func ).parameters @@ -479,7 +494,8 @@ def run_benchmark_multithreaded(benchmark_config): if key in calculate_metrics_params } - # Call calculate_metrics_func with the filtered results and benchmark_param + # Call calculate_metrics_func with the filtered results and + # benchmark_param metadata, metrics = calculate_metrics_func( **benchmark_param, **filtered_benchmark_results )