Skip to content

Conversation

@kurapov-peter
Copy link
Contributor

Putting up this dirty draft for early feedback/questions. I'm putting together some tests to run a e2e llama3.1 going through linalg on tensors. The goal is to generate some nice linalg that would be optimization friendly. At the moment, there are just functional blocks and pieces that are just smoke-tested. These include naive implementations for rotary embeddings, feed forward, rms, and a bunch of other small snippets that are useful to implement the model. These are already enough to put an attention block together. It'd be nice to test it against the original implementation, but that'd require fairscale as a dependency. For now I only added pytest and kept the pipeline as simple as possible. I also reused the example with the schedule, so now it is a part of every test.

@rengolin
Copy link
Member

Should this be in examples?

@kurapov-peter
Copy link
Contributor Author

The e2e should be, yup, but this is mostly tests and getters.

@kurapov-peter
Copy link
Contributor Author

I moved the whole thing to examples and added attention the list of tests.

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Copy link
Contributor

@rolfmorel rolfmorel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Have left some comments inline.

[xq_scores_map, keys_scores_map, scores_map],
[parallel, parallel, parallel, parallel, reduction],
)
def compute_scores(q_val, k_val, score_val):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be written as a linalg.contract, right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can move generics to contract and elementwise later. TPP-MLIR has linalg generalization because some passes don't work with the new ops.

Comment on lines 1200 to 1205
module = generate_module(ctx, ir_type)
bufferize_module(ctx, module)
schedule = create_schedule(ctx)
apply_schedule(module, schedule)
pm = create_pass_pipeline(ctx)
pm.run(module.operation)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
module = generate_module(ctx, ir_type)
bufferize_module(ctx, module)
schedule = create_schedule(ctx)
apply_schedule(module, schedule)
pm = create_pass_pipeline(ctx)
pm.run(module.operation)
module = generate_module(ctx, ir_type)
schedule = create_schedule(ctx)
apply_schedule(module, schedule)

Just move the passes from inside bufferize_module(ctx, module) and create_pass_pipeline(ctx) into the start and end of the schedule, i.e. with transform.apply_registered_pass.

I know this antipattern originates in an example script we merged, but we should not let this proliferate. It clearly is already confusing people.

Comment on lines 115 to 144
return schedule


def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
interpreter.apply_named_sequence(
payload_root=kernel,
transform_root=schedule.body.operations[0],
transform_module=schedule,
)
Copy link
Contributor

@rolfmorel rolfmorel Nov 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return schedule
def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
interpreter.apply_named_sequence(
payload_root=kernel,
transform_root=schedule.body.operations[0],
transform_module=schedule,
)
return named_seq

If we do this, you can simply do:

schedule = create_schedule()
schedule.apply(module)

If you need access to the Module around the named_sequence, just ask for its .parent.

@@ -1,2 +1,2 @@
import ctypes
import torch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this PR didn't introduce it, though looking at it now, I feel we should think about compartmentalizing code that depends on heavy dependencies a bit more. That is, not have it in the same module with code that doesn't have the dependency, e.g. get_packed_arg.

Comment on lines 75 to 85
def create_schedule(ctx: ir.Context) -> ir.Module:
"""
Create an MLIR module containing transformation schedule.
The schedule provides partial lowering to scalar operations.
Args:
ctx: MLIR context.
"""
with ctx, ir.Location.unknown(context=ctx):
# Create transform module.
schedule = ir.Module.create()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def create_schedule(ctx: ir.Context) -> ir.Module:
"""
Create an MLIR module containing transformation schedule.
The schedule provides partial lowering to scalar operations.
Args:
ctx: MLIR context.
"""
with ctx, ir.Location.unknown(context=ctx):
# Create transform module.
schedule = ir.Module.create()
def create_schedule() -> ir.Module:
schedule = ir.Module.create()

And just de-indent the rest of the function.

Comment on lines 133 to 135
def bufferize_module(ctx: ir.Context, kernel: ir.Module) -> None:
with ctx:
pm = PassManager("builtin.module")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def bufferize_module(ctx: ir.Context, kernel: ir.Module) -> None:
with ctx:
pm = PassManager("builtin.module")
def bufferize_module(kernel: ir.Module) -> None:
pm = PassManager("builtin.module")

Comment on lines 1160 to 1164
def to_ir_type(type_str, ctx):
if type_str == "f32":
return ir.F32Type.get(context=ctx)
elif type_str == "f64":
return ir.F64Type.get(context=ctx)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def to_ir_type(type_str, ctx):
if type_str == "f32":
return ir.F32Type.get(context=ctx)
elif type_str == "f64":
return ir.F64Type.get(context=ctx)
def to_ir_type(type_str):
if type_str == "f32":
return ir.F32Type.get()
elif type_str == "f64":
return ir.F64Type.get()

In effect, these .get() methods are doing a ir.Context.current under the hood when you don't pass a context explicitly (just like the Op builders).

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there's a lot of smaller comments that we can leave for post-merge review. This is an example, and a complicated one at that, and we don't want to over-engineer something that will soon move to a better program.

[xq_scores_map, keys_scores_map, scores_map],
[parallel, parallel, parallel, parallel, reduction],
)
def compute_scores(q_val, k_val, score_val):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can move generics to contract and elementwise later. TPP-MLIR has linalg generalization because some passes don't work with the new ops.

@rengolin rengolin merged commit bd87f3f into llvm:main Nov 27, 2025
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants