diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 20a3cc1217be4..b5064c16d2c54 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -2651,6 +2651,11 @@ def setUpClass(cls): ) ) + @classmethod + def tearDownClass(cls): + cls._stack.close() + super().tearDownClass() + def check_code(self, code_str, num_kernels, num_allocs, num_deallocs): FileCheck().check(get_func_call()).check_count( get_kernel_launch(), @@ -2996,6 +3001,804 @@ def foo(x, y): self.check_code(code[0], num_kernels=3, num_allocs=3, num_deallocs=4) +<<<<<<< HEAD +======= +def autotune_select_algorithm_wrapper_return_multi(): + def wrapper(*args, **kwargs): + kwargs["return_multi_template"] = True + return autotune_select_algorithm(*args, **kwargs) + + return wrapper + + +def benchmark_choice_override_timings(benchmark_request, *args, aten_time, triton_time): + if isinstance( + benchmark_request, (ExternKernelBenchmarkRequest, ExternKernelCaller) + ): + return aten_time + elif isinstance(benchmark_request, (TritonBenchmarkRequest, TritonTemplateCaller)): + return triton_time + else: + return float("inf") + + +def mock_benchmark_choice_wrapper(aten_time, triton_time): + return functools.partial( + benchmark_choice_override_timings, aten_time=aten_time, triton_time=triton_time + ) + + +@instantiate_parametrized_tests +class TestEpilogueFusionStaticAnalysis(TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._stack = contextlib.ExitStack() + cls._stack.enter_context( + config.patch( + { + "max_autotune": True, + "autotune_fallback_to_aten": False, + "benchmark_epilogue_fusion": False, + "prologue_fusion": False, + } + ) + ) + + @classmethod + def tearDownClass(cls): + cls._stack.close() + super().tearDownClass() + + @contextlib.contextmanager + def get_common_patches( + self, + async_compile: bool, + persistent_tma: bool, + *, + aten_time: float | None = None, + triton_time: float | None = None, + mock_n_spills: int | None = None, + mock_fused_n_regs: int | None = None, + mock_unfused_n_regs: int | None = None, + epilogue_runtime: float | None = None, + ): + from torch._inductor.autotune_process import TritonBenchmarkRequest + from torch._inductor.runtime.triton_heuristics import CachingAutotuner + from torch._inductor.scheduler import BaseSchedulerNode + + common_patches = [ + config.patch( + { + "triton.enable_persistent_tma_matmul": persistent_tma, + "compile_threads": 1 + if not async_compile + else config.compile_threads, + } + ), + mock.patch( + "torch._inductor.kernel.mm.autotune_select_algorithm", + autotune_select_algorithm_wrapper_return_multi(), + ), + fresh_cache(), + ] + + if aten_time is not None and triton_time is not None: + common_patches.extend( + [ + mock.patch.object( + AlgorithmSelectorCache, + "benchmark_choice", + mock_benchmark_choice_wrapper(aten_time, triton_time), + ), + mock.patch( + "torch._inductor.autotune_process.run_autotune_in_subprocess", + mock_benchmark_choice_wrapper(aten_time, triton_time), + ), + ] + ) + + if mock_n_spills is not None or mock_fused_n_regs is not None: + original_precompile = CachingAutotuner.precompile + + def mock_precompile(self, *args, **kwargs): + original_precompile(self, *args, **kwargs) + for launcher in self.launchers: + if mock_n_spills is not None: + launcher.n_spills = mock_n_spills + if mock_fused_n_regs is not None: + launcher.n_regs = mock_fused_n_regs + + common_patches.append( + mock.patch.object(CachingAutotuner, "precompile", mock_precompile) + ) + + if mock_unfused_n_regs is not None: + original_bmreq_precompile = TritonBenchmarkRequest.precompile + + def mock_bmreq_precompile(self): + original_bmreq_precompile(self) + self.n_regs = mock_unfused_n_regs + + common_patches.append( + mock.patch.object( + TritonBenchmarkRequest, "precompile", mock_bmreq_precompile + ) + ) + + if epilogue_runtime is not None: + common_patches.append( + mock.patch.object( + BaseSchedulerNode, + "_get_estimated_runtime", + lambda node: epilogue_runtime, + ) + ) + + with contextlib.ExitStack() as stack: + for p in common_patches: + stack.enter_context(p) + + yield + + def _get_mm_inputs(self): + """Common matmul inputs for epilogue fusion tests.""" + a = torch.randn(512, 1024, device=GPU_TYPE, dtype=torch.bfloat16) + b = torch.randn(1024, 2048, device=GPU_TYPE, dtype=torch.bfloat16) + return a, b + + def _get_mm_with_epilogue_fn(self): + """Common function: matmul with type cast and add epilogue.""" + + def f(a, b): + return (a @ b).to(torch.float32) + 1.0 + + return f + + @contextlib.contextmanager + def _setup_mm_heuristic(self, use_async_compile: bool): + """Setup MM heuristic with single GemmConfig and handle cleanup.""" + mm_heuristic = CUDAMMTemplateConfigHeuristic() + original_mm_configs = mm_heuristic.mm_configs + gemm_config = GemmConfig(64, 64, 32, 2, 4, group_m=8) + mm_heuristic.mm_configs = [gemm_config] + + if use_async_compile: + torch._inductor.async_compile.AsyncCompile.wait_pool_ready() + + try: + yield + finally: + mm_heuristic.mm_configs = original_mm_configs + + @unittest.skipIf(not has_triton_tma_device(), "Need TMA support in Triton") + @skipIfXpu(msg="Bad tma config can be covered by XPU TMA") + @parametrize("use_async_compile", (True, False)) + def test_template_bad_epilogue_fusion(self, use_async_compile: bool): + def f(a, b): + return (a @ b).to(torch.float32) + + a = torch.randn(512, 1152, device=GPU_TYPE, dtype=torch.bfloat16) + b = torch.randn(1152, 7680, device=GPU_TYPE, dtype=torch.bfloat16) + + if GPU_TYPE == "xpu": + tma_heuristic = XPUPersistentTMATemplateConfigHeuristic() + mm_heuristic = XPUMMTemplateConfigHeuristic() + else: + tma_heuristic = CUDAPersistentTMATemplateConfigHeuristic() + mm_heuristic = CUDAMMTemplateConfigHeuristic() + + # Save original configs to restore later + original_tma_mm_configs = tma_heuristic.mm_configs + original_mm_mm_configs = mm_heuristic.mm_configs + + good_tma_config = GemmConfig(128, 64, 64, 4, 8, group_m=8) + if use_async_compile: + torch._inductor.async_compile.AsyncCompile.wait_pool_ready() + + original_compile_kernel = Scheduler.compile_kernel + + for simulate_fusion_failure in [True, False]: + torch._dynamo.reset() + tma_heuristic.mm_configs = [good_tma_config] + # Regular mm template gets no configs + mm_heuristic.mm_configs = [] + + def mock_compile_kernel_fail_fusion(self, nodes, hint_override=None): + fut, mod = original_compile_kernel(self, nodes, hint_override) + + if simulate_fusion_failure and len(nodes) > 1: + if fut is not None: + + def failing_result_fn(): + raise RuntimeError + + return torch._inductor.codecache.LambdaFuture( + failing_result_fn, fut.future + ), mod + else: + + class FailingPrecompile: + def precompile(self): + raise RuntimeError + + mod.triton_ = FailingPrecompile() + return None, mod + + return fut, mod + + # Different paths: + # benchmark_epilogue_fusion: True -> always multi_template + # causes benchmarking always + # benchmark_epilogue_fusion: False -> TritonTemplateBuffer + # returns speedup_from_fusion automatically as True + # What we want: force multi template -> no benchmarking with safety + try: + with ( + self.get_common_patches(use_async_compile, True), + mock.patch( + "torch._inductor.autotune_process.run_autotune_in_subprocess", + mock_benchmark_choice_wrapper( + aten_time=float("inf"), triton_time=0.1 + ), + ), + mock.patch.object( + AlgorithmSelectorCache, + "benchmark_choice", + mock_benchmark_choice_wrapper( + aten_time=float("inf"), triton_time=0.1 + ), + ), + mock.patch.object( + Scheduler, + "compile_kernel", + mock_compile_kernel_fail_fusion, + ), + ): + compiled_f = torch.compile(f, mode="max-autotune") + out, code = run_and_get_code(compiled_f, a, b) + + if not simulate_fusion_failure: + # Fusion should occur + FileCheck().check("triton_tem_fused__to_copy_mm").run(code[0]) + else: + # Fusion should fail to occur, unfused kernels + FileCheck().check("triton_tem_fused_mm").check( + "triton_poi_fused__to_copy" + ).run(code[0]) + + if not config.cpp_wrapper: + torch.testing.assert_close(out, f(a, b), atol=1e-2, rtol=1e-2) + finally: + # Restore original configs + tma_heuristic.mm_configs = original_tma_mm_configs + mm_heuristic.mm_configs = original_mm_mm_configs + + @unittest.skipIf( + not HAS_CUDA_AND_TRITON, "Scheduler static analysis only tested on cuda" + ) + @parametrize( + "test_case", + [ + "spills_reject", # High register spillage should reject fusion + "timing_reject", # Triton much slower than aten should reject fusion + "accept_with_triton_faster", # Low spills and good timing should accept fusion + "accept_with_aten_faster", # Fusion even if aten is slightly faster + ], + ) + @parametrize("use_async_compile", (True, False)) + def test_template_epilogue_fusion_static_analysis( + self, test_case: str, use_async_compile: bool + ): + """ + Test static analysis decisions for matmul epilogue fusions. + + Tests the scheduler logic that decides whether to fuse epilogues without + benchmarking, based on: + 1. Register spillage (n_spills <= 8 required for fusion) + 2. Runtime comparison (epilogue_runtime + ms_min_choice > choice_timings[choice]) + """ + if test_case == "spills_reject": + mock_n_spills = 100 + triton_time = 0.1 + aten_time = float("inf") + expect_fusion = False + elif test_case == "timing_reject": + mock_n_spills = 0 + triton_time = 100.0 + aten_time = 0.001 + expect_fusion = False + elif test_case == "accept_with_triton_faster": + mock_n_spills = 0 + triton_time = 0.1 + aten_time = float("inf") + expect_fusion = True + elif test_case == "accept_with_aten_faster": + mock_n_spills = 0 + triton_time = 0.1 + aten_time = 0.09999 + expect_fusion = True + else: + raise RuntimeError("Invalid test case") + + f = self._get_mm_with_epilogue_fn() + a, b = self._get_mm_inputs() + + with self._setup_mm_heuristic(use_async_compile): + with self.get_common_patches( + use_async_compile, + False, + aten_time=aten_time, + triton_time=triton_time, + mock_n_spills=mock_n_spills, + ): + compiled_f = torch.compile(f) + _, code = run_and_get_code(compiled_f, a, b) + + if expect_fusion: + FileCheck().check("triton_tem_fused__to_copy_add_mm_0").run(code[0]) + elif triton_time < aten_time: + FileCheck().check("triton_tem_fused_mm").check( + "triton_poi_fused__to_copy" + ).run(code[0]) + else: + FileCheck().check_not("triton_tem_fused_mm").check( + "triton_poi_fused__to_copy" + ).run(code[0]) + + @unittest.skipIf( + not HAS_CUDA_AND_TRITON, "Scheduler static analysis only tested on cuda" + ) + @skipIfRocm(msg="Scheduler static analysis needs investigation on ROCm") + @parametrize("fuse_epilogue", (True, False)) + @parametrize("use_async_compile", (True, False)) + def test_template_epilogue_fusion_extra_reads( + self, fuse_epilogue: bool, use_async_compile: bool + ): + """Test epilogue fusion with extra reads (bias and scale tensors).""" + + def fn(x, w, bias, scale): + out = torch.mm(x, w) + return out * scale + bias + + torch._dynamo.reset() + + x = torch.randn(512, 1024, device=GPU_TYPE, dtype=torch.bfloat16) + w = torch.randn(1024, 2048, device=GPU_TYPE, dtype=torch.bfloat16) + bias = torch.randn(512, 2048, device=GPU_TYPE, dtype=torch.bfloat16) + scale = torch.randn(512, 2048, device=GPU_TYPE, dtype=torch.bfloat16) + + epilogue_runtime = 0.5 + aten_time = 1.0 + unfused_time = aten_time + epilogue_runtime + # 2 times extra bytes / 3 + estimated_fused = epilogue_runtime * 2 / 3 + + if fuse_epilogue: + # triton + 1 read / 2 extra memory ratio * epilogue_runtime + # < aten_time + epilogue_runtime + triton_time = unfused_time - estimated_fused - 0.01 + else: + triton_time = unfused_time - estimated_fused + 0.01 + + with self._setup_mm_heuristic(use_async_compile): + with self.get_common_patches( + use_async_compile, + False, + aten_time=aten_time, + triton_time=triton_time, + epilogue_runtime=epilogue_runtime, + ): + compiled_fn = torch.compile(fn) + _, code = run_and_get_code(compiled_fn, x, w, bias, scale) + + if fuse_epilogue: + FileCheck().check("triton_tem_fused_add_mm_mul").run(code[0]) + else: + FileCheck().check_not("triton_tem").check( + "triton_poi_fused_add_mul" + ).run(code[0]) + + @unittest.skipIf( + not HAS_CUDA_AND_TRITON, "Scheduler static analysis only tested on cuda" + ) + @skipIfRocm(msg="Scheduler static analysis needs investigation on ROCm") + @parametrize( + "test_case", + [ + "occupancy_ratio_accept", # ratio > 0.5, accept via Branch B + "occupancy_ratio_reject", # ratio <= 0.5, reject + ], + ) + @parametrize("use_async_compile", (True, False)) + def test_template_epilogue_fusion_occupancy_ratio( + self, test_case: str, use_async_compile: bool + ): + """ + Test occupancy ratio branch of _fuse_epilogue. + + Occupancy calculation (assuming regs_per_sm = 65536): + blocks = regs_per_sm // (n_regs * threads_per_block) + threads_per_block = num_warps * warp_size = 4 * 32 = 128 + """ + triton_time = 0.1 + epilogue_runtime = triton_time + + if test_case == "occupancy_ratio_accept": + # blocks_unfused=5, blocks_fused=3, ratio=0.6 > 0.5 -> accept + # aten slightly faster to verify fusion picks triton even when aten wins + mock_unfused_n_regs, mock_fused_n_regs = 100, 160 + aten_time = 0.09 + expect_fusion = True + elif test_case == "occupancy_ratio_reject": + # blocks_unfused=8, blocks_fused=2, ratio=0.25 < 0.5 -> reject + mock_unfused_n_regs, mock_fused_n_regs = 64, 200 + aten_time = 0.11 + expect_fusion = False + else: + raise RuntimeError("Invalid test case") + + f = self._get_mm_with_epilogue_fn() + a, b = self._get_mm_inputs() + + with self._setup_mm_heuristic(use_async_compile): + with self.get_common_patches( + use_async_compile, + False, + aten_time=aten_time, + triton_time=triton_time, + mock_n_spills=0, + mock_fused_n_regs=mock_fused_n_regs, + mock_unfused_n_regs=mock_unfused_n_regs, + epilogue_runtime=epilogue_runtime, + ): + compiled_f = torch.compile(f) + _, code = run_and_get_code(compiled_f, a, b) + + if expect_fusion: + FileCheck().check("triton_tem_fused__to_copy_add_mm").run(code[0]) + else: + FileCheck().check("triton_tem_fused_mm").check( + "triton_poi_fused__to_copy" + ).run(code[0]) + + @unittest.skipIf( + not HAS_CUDA_AND_TRITON, "Scheduler static analysis only tested on cuda" + ) + @skipIfRocm(msg="Scheduler static analysis needs investigation on ROCm") + @parametrize( + "test_case", + [ + "memory_bound_accept", + "memory_bound_reject_low_occupancy", + ], + ) + @parametrize("use_async_compile", (True, False)) + def test_template_epilogue_fusion_dominating_epilogue( + self, test_case: str, use_async_compile: bool + ): + """ + Test memory-bound epilogue branch of _fuse_epilogue (Branch C). + + When Branches A and B fail (low occupancy), fusion can still be accepted + if the epilogue is memory-bound (ms2 > 2*ms1) AND blocks_fused > 1. + """ + triton_time = 0.1 + + if test_case == "memory_bound_accept": + # blocks_fused=2 > 1, ms2=0.3 > 2*ms1=0.2 -> Branch C accepts + # aten slightly faster to verify fusion picks triton even when aten wins + mock_unfused_n_regs, mock_fused_n_regs = 64, 256 + aten_time = 0.09 + epilogue_runtime = 0.3 + expect_fusion = True + elif test_case == "memory_bound_reject_low_occupancy": + # blocks_fused=1, ms2=0.3 > 2*ms1=0.2 BUT blocks_fused <= 1 -> reject + mock_unfused_n_regs, mock_fused_n_regs = 64, 512 + aten_time = 0.11 + epilogue_runtime = 0.3 + expect_fusion = False + else: + raise RuntimeError("Invalid test case") + + f = self._get_mm_with_epilogue_fn() + a, b = self._get_mm_inputs() + + with self._setup_mm_heuristic(use_async_compile): + with self.get_common_patches( + use_async_compile, + False, + aten_time=aten_time, + triton_time=triton_time, + mock_n_spills=0, + mock_fused_n_regs=mock_fused_n_regs, + mock_unfused_n_regs=mock_unfused_n_regs, + epilogue_runtime=epilogue_runtime, + ): + compiled_f = torch.compile(f) + _, code = run_and_get_code(compiled_f, a, b) + + if expect_fusion: + FileCheck().check("triton_tem_fused__to_copy_add_mm").run(code[0]) + else: + FileCheck().check("triton_tem_fused_mm").check( + "triton_poi_fused__to_copy" + ).run(code[0]) + + @unittest.skipIf( + not HAS_CUDA_AND_TRITON, "Scheduler static analysis only tested on cuda" + ) + @parametrize("use_async_compile", (True, False)) + def test_epilogue_prologue_fusion_cache_preserved(self, use_async_compile: bool): + def f(a, b): + # Prologue: pointwise operation on input 'a' before matmul + a_transformed = a + 1.0 + # Matmul + mm_result = a_transformed @ b + # Epilogue: pointwise operation on output after matmul + return mm_result + 2.0 + + torch._dynamo.reset() + + # Use float32 to avoid low precision heuristic rejection + a = torch.randn(512, 1024, device=GPU_TYPE, dtype=torch.float32) + b = torch.randn(1024, 2048, device=GPU_TYPE, dtype=torch.float32) + + triton_time = 0.1 + aten_time = float("inf") + epilogue_runtime = 0.05 + + # Always allow prologue fusion heuristics + def always_allow_prologue(*args): + return True + + with self._setup_mm_heuristic(use_async_compile): + with self.get_common_patches( + use_async_compile, + False, + aten_time=aten_time, + triton_time=triton_time, + mock_n_spills=0, + epilogue_runtime=epilogue_runtime, + ): + # Enable prologue fusion so both epilogue and prologue are considered + with config.patch(prologue_fusion=True): + # Bypass prologue heuristics that might reject the fusion + with mock.patch.object( + Scheduler, + "check_prologue_fusion_heuristics_fusable", + always_allow_prologue, + ): + compiled_f = torch.compile(f) + # If the bug exists, this will fail with: + # "ValueError: min() arg is an empty sequence" + # when get_min_choice() is called during prologue fusion + run_and_get_code(compiled_f, a, b) + + +def simple_fn(): + return 42 + + +class TestMaxAutotuneAsyncPipelined(TestMaxAutotune, TestEpilogueFusionStaticAnalysis): + """Tests for AsyncPipelinedAutotuning path.""" + + SKIP_TESTS = { + "test_inf_timing": "Logs not consistent with async pipelined autotuning", + "test_non_contiguous_input_mm_plus_mm": "Flaky on trunk", + "test_autotune_device_guard": "Flaky on trunk", + "test_template_bad_epilogue_fusion": "Benchmarking path is different", + "test_persistent_tma_epilogue_fusion_store_cache": "Epilogue fusion disabled in async pipelining", + # XPU specific skips due to lack of multiprocess tensor reduction support (issue #170636) + "test_max_autotune_addmm_persistent_tma": "No XPU implementation for multiprocess tensor reduction", + "test_max_autotune_regular_mm_persistent_tma": "No XPU implementation for multiprocess tensor reduction", + "test_max_autotune_regular_mm_persistent_tma_strided": "No XPU implementation for multiprocess tensor reduction", + "test_max_autotune_addmm_tma_dynamic_outer_dim": "No XPU implementation for multiprocess tensor reduction", + "test_max_autotune_regular_mm_tma_dynamic_outer_dim": "No XPU implementation for multiprocess tensor reduction", + } + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._async_config = config.patch( + { + "pipeline_max_autotune_gemm": True, + "benchmark_epilogue_fusion": False, + "test_configs.max_mm_configs": 1, + } + ) + cls._async_config.__enter__() + + @classmethod + def tearDownClass(cls): + cls._async_config.__exit__(None, None, None) + super().tearDownClass() + + def setUp(self): + super().setUp() + test_name = self._testMethodName + for skip_test_name in self.SKIP_TESTS: + if skip_test_name in test_name or TEST_XPU: + self.skipTest(self.SKIP_TESTS[skip_test_name]) + + def tearDown(self): + super().tearDown() + AutotuneProcessPool.shutdown_instance() + # Clear the AsyncAutotuner cache to prevent test pollution + AsyncAutotuner.choice_hash_to_future.clear() + + @config.patch(max_autotune_gemm=True) + def test_async_autotuner_cache_same_inputs(self): + M, K, N = 128, 64, 256 + M2, K2, N2 = 256, 128, 64 + + def three_matmuls(a1, b1, a2, b2, a3, b3): + return torch.mm(a1, b1), torch.mm(a2, b2), torch.mm(a3, b3) + + # Same shapes for first two matmuls + a1 = torch.randn(M, K, device=GPU_TYPE, dtype=torch.bfloat16) + b1 = torch.randn(K, N, device=GPU_TYPE, dtype=torch.bfloat16) + a2 = torch.randn(M, K, device=GPU_TYPE, dtype=torch.bfloat16) + b2 = torch.randn(K, N, device=GPU_TYPE, dtype=torch.bfloat16) + + # Different shapes for third matmul + a3 = torch.randn(M2, K2, device=GPU_TYPE, dtype=torch.bfloat16) + b3 = torch.randn(K2, N2, device=GPU_TYPE, dtype=torch.bfloat16) + + compiled_fn = torch.compile(three_matmuls) + result = compiled_fn(a1, b1, a2, b2, a3, b3) + + # Verify correctness + expected = three_matmuls(a1, b1, a2, b2, a3, b3) + for r, e in zip(result, expected): + torch.testing.assert_close(r, e, atol=1e-2, rtol=1e-2) + + # With max_mm_configs=1, we get 2 configs total (1 per unique shape) + # First two matmuls share the same shape, third has different shape + # 1 aten, 1 triton config + cache_size = len(AsyncAutotuner.choice_hash_to_future) + self.assertEqual( + cache_size, 4, "Cache should have 2 entries (one per unique input shape)" + ) + + @patch( + "torch._inductor.autotune_process.AUTOTUNE_POOL_INACTIVITY_TIMEOUT", + 2, + ) + def test_autotune_process_pool_inactivity_shutdown(self): + AutotuneProcessPool.shutdown_instance() + AutotuneProcessPool._shutdown_for_inactivity = False + + pool_instance = AutotuneProcessPool.get_instance() + pool_instance.warm_up() + + future = pool_instance.submit(simple_fn) + result = future.result() + self.assertEqual(result, 42) + self.assertIsNotNone(pool_instance._pool) + + time.sleep(5) + + self.assertIsNone(pool_instance._pool) + self.assertIsNone(pool_instance._timer) + self.assertTrue(AutotuneProcessPool._shutdown_for_inactivity) + + @patch( + "torch._inductor.autotune_process.AUTOTUNE_POOL_INACTIVITY_TIMEOUT", + 2, + ) + def test_autotune_process_pool_inactivity_shutdown_warmup_only(self): + """Test that the pool shuts down from inactivity even when only warmup is called.""" + AutotuneProcessPool.shutdown_instance() + AutotuneProcessPool._shutdown_for_inactivity = False + + pool_instance = AutotuneProcessPool.get_instance() + warmup_future = pool_instance.warm_up() + warmup_future.result() + + self.assertIsNotNone(pool_instance._pool) + + time.sleep(5) + + self.assertIsNone(pool_instance._pool) + self.assertIsNone(pool_instance._timer) + self.assertTrue(AutotuneProcessPool._shutdown_for_inactivity) + + @patch( + "torch._inductor.autotune_process.AUTOTUNE_POOL_INACTIVITY_TIMEOUT", + 2, + ) + @config.patch(max_autotune=True) + def test_compilation_after_inactivity(self): + """Test that compilation after pool inactivity shutdown uses synchronous path.""" + + # Reset state + AutotuneProcessPool.shutdown_instance() + AutotuneProcessPool._shutdown_for_inactivity = False + AsyncAutotuner.choice_hash_to_future.clear() + torch._dynamo.reset() + + # First compilation - should use pipelined path + self.assertTrue(use_pipelined_autotuning()) + + def matmul_fn(a, b): + return torch.mm(a, b) + + a1 = torch.randn(64, 32, device=GPU_TYPE, dtype=torch.bfloat16) + b1 = torch.randn(32, 64, device=GPU_TYPE, dtype=torch.bfloat16) + a2 = torch.randn(128, 64, device=GPU_TYPE, dtype=torch.bfloat16) + b2 = torch.randn(64, 128, device=GPU_TYPE, dtype=torch.bfloat16) + + compiled_fn = torch.compile(matmul_fn) + compiled_fn(a1, b1) + + # Verify pipelined path was used (cache should have entries) + cache_entries_after_first = len(AsyncAutotuner.choice_hash_to_future) + self.assertGreater(cache_entries_after_first, 0) + + # Wait for inactivity shutdown + time.sleep(5) + + self.assertTrue(AutotuneProcessPool._shutdown_for_inactivity) + self.assertFalse(use_pipelined_autotuning()) + + AsyncAutotuner.choice_hash_to_future.clear() + torch._dynamo.reset() + + compiled_fn2 = torch.compile(matmul_fn) + compiled_fn2(a2, b2) + + cache_entries_after_second = len(AsyncAutotuner.choice_hash_to_future) + self.assertEqual(cache_entries_after_second, 0) + + @config.patch(max_autotune_gemm=True) + def test_triton_error_precompilation_and_autotuning(self): + """ + Test error handling when do_autotuning throws NoValidChoicesError + for Triton choices. The fallback to extern kernels should still work. + """ + + def mock_do_autotuning(*args, **kwargs): + raise NoValidChoicesError("Simulated: all Triton choices failed") + + a = torch.randn(64, 32, device=GPU_TYPE, dtype=torch.bfloat16) + b = torch.randn(32, 64, device=GPU_TYPE, dtype=torch.bfloat16) + + def mm_func(a, b, epilogue): + if epilogue: + return torch.mm(a, b) + 1.0 + else: + return torch.mm(a, b) + + def test_aten_chosen(): + for epilogue in (True, False): + torch._dynamo.reset() + compiled_fn = torch.compile(mm_func) + out, code = run_and_get_code(compiled_fn, a, b, epilogue) + FileCheck().check_not("triton_tem").run(code[0]) + + with mock.patch.object( + AlgorithmSelectorCache, "do_autotuning", mock_do_autotuning + ): + test_aten_chosen() + + original_start = AsyncAutotuner.start + bmreq = _TestBenchmarkRequest( + exc=RuntimeError("Simulated benchmark failure in subprocess") + ) + bmreq.module_cache_key = "" + + def mock_start(choices, inputs_key): + for choice in choices: + if isinstance(choice, TritonTemplateCaller): + choice.bmreq = bmreq + return original_start(choices, inputs_key) + + with mock.patch.object(AsyncAutotuner, "start", mock_start): + test_aten_chosen() + + +>>>>>>> f7baaadc64c ([Inductor] Fix flaky epilogue fusion tests by adding missing tearDown… (#3244)) if __name__ == "__main__": from torch._inductor.utils import is_big_gpu