diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index dd65c0298b..c06fec37af 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1210,7 +1210,6 @@ def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: cond &= kernel_ctx.pipeline.F_vlayout == "row" cond &= kernel_ctx.pipeline.F_bias in ["no", "bias"] cond &= kernel_ctx.pipeline.F_qscale == "no" - cond &= problem_ctx.mode == "batch" cond &= kernel_ctx.pipeline.F_skip == "f" cond &= kernel_ctx.pipeline.F_logits == "f" return cond