-
Notifications
You must be signed in to change notification settings - Fork 73
Combine for RaggedIterDomain #5716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: raggediterdomain_clone
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1054,6 +1054,106 @@ std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition( | |||||||||||
| return {component_id, ragged_id}; | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| IterDomain* RaggedIterDomain::combine( | ||||||||||||
| IterDomain* component, | ||||||||||||
| RaggedIterDomain* ragged) { | ||||||||||||
| NVF_ERROR(component != nullptr, "combine: component IterDomain is null"); | ||||||||||||
| NVF_ERROR(ragged != nullptr, "combine: ragged IterDomain is null"); | ||||||||||||
|
|
||||||||||||
| NVF_ERROR( | ||||||||||||
| !component->isA<RaggedIterDomain>(), | ||||||||||||
| "combine: component must be a regular IterDomain, got RaggedIterDomain: ", | ||||||||||||
| component->toString()); | ||||||||||||
|
|
||||||||||||
| // Validate that component and ragged have compatible properties | ||||||||||||
| NVF_ERROR_EQ( | ||||||||||||
| component->getParallelType(), | ||||||||||||
| ParallelType::Serial, | ||||||||||||
| "Combining parallelized IterDomain not supported: ", | ||||||||||||
| component->toString()); | ||||||||||||
|
|
||||||||||||
| NVF_ERROR_EQ( | ||||||||||||
| ragged->getParallelType(), | ||||||||||||
| ParallelType::Serial, | ||||||||||||
| "Combining parallelized RaggedIterDomain not supported: ", | ||||||||||||
| ragged->toString()); | ||||||||||||
|
|
||||||||||||
| NVF_ERROR_EQ( | ||||||||||||
| component->getIterType(), | ||||||||||||
| IterType::Iteration, | ||||||||||||
| "combine: only IterType::Iteration is supported for component, got ", | ||||||||||||
| component->getIterType(), | ||||||||||||
| " for IterDomain: ", | ||||||||||||
| component->toString()); | ||||||||||||
|
|
||||||||||||
| NVF_ERROR_EQ( | ||||||||||||
| ragged->getIterType(), | ||||||||||||
| IterType::Iteration, | ||||||||||||
| "combine: only IterType::Iteration is supported for ragged, got ", | ||||||||||||
| ragged->getIterType(), | ||||||||||||
| " for RaggedIterDomain: ", | ||||||||||||
| ragged->toString()); | ||||||||||||
|
|
||||||||||||
| // Validate component-ragged pairing when Partition definition is available | ||||||||||||
| // (Option 3 of doc/dev/ragged_iter_domain_combine_design_doc.md). | ||||||||||||
| // Only validate when the RaggedIterDomain has a direct Partition definition. | ||||||||||||
| // After propagation (e.g., set() operations), the definition may be nullptr, | ||||||||||||
| // in which case we trust the user to provide the correct component. | ||||||||||||
| if (ragged->definition() != nullptr && | ||||||||||||
| ragged->definition()->isA<Partition>()) { | ||||||||||||
| auto* partition = ragged->definition()->as<Partition>(); | ||||||||||||
| IterDomain* expected_component = partition->component(); | ||||||||||||
|
|
||||||||||||
| NVF_ERROR( | ||||||||||||
| component == expected_component, | ||||||||||||
| "combine: component mismatch. The provided component does not match ", | ||||||||||||
| "the component from the Partition that created this " | ||||||||||||
| "RaggedIterDomain.\n", | ||||||||||||
| " Provided component: ", | ||||||||||||
| component->toString(), | ||||||||||||
| "\n", | ||||||||||||
| " Expected component: ", | ||||||||||||
| expected_component->toString()); | ||||||||||||
| } | ||||||||||||
| // If no Partition definition (after set, in segmented fusion, or external | ||||||||||||
| // input), trust the user and proceed without validation | ||||||||||||
|
|
||||||||||||
| // The combined extent is the sum of all extents in the ragged dimension | ||||||||||||
| // For a 1D extents tensor [e0, e1, ..., en-1], the total is sum(extents) | ||||||||||||
| TensorView* extents_tv = ragged->extents(); | ||||||||||||
| NVF_ERROR(extents_tv != nullptr, "combine: ragged extents tensor is null"); | ||||||||||||
|
|
||||||||||||
| // It is still assumed the extents tensor is just 1D | ||||||||||||
| NVF_ERROR_EQ( | ||||||||||||
| std::ssize(extents_tv->getLogicalDomain()), | ||||||||||||
| 1, | ||||||||||||
| "Unexpected rank of extent tensor: ", | ||||||||||||
| extents_tv->toString()); | ||||||||||||
|
|
||||||||||||
| auto container = component->container(); | ||||||||||||
| auto zero = container->zeroVal(DataType::Index); | ||||||||||||
|
|
||||||||||||
| // Create a symbolic extent for the combined IterDomain | ||||||||||||
| // This represents the sum of all ragged extents, i.e., | ||||||||||||
| // sum(extents_tv, {0}). We could use the sum output as the extent | ||||||||||||
| // but we would need to extract the scalar value out of the 0-dim | ||||||||||||
| // tensor. For now, we leave it as a symbolic Val. | ||||||||||||
| Val* combined_extent = | ||||||||||||
| IrBuilder::createInContainer<Val>(container, DataType::Index); | ||||||||||||
|
Comment on lines
+1141
to
+1142
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: The symbolic extent approach leaves the combined extent as an opaque Check whether this symbolic extent will cause issues during:
|
||||||||||||
|
|
||||||||||||
| // Create the combined IterDomain with the symbolic extent | ||||||||||||
| IterDomain* combined_id = IterDomainBuilder(zero, combined_extent) | ||||||||||||
| .parallel_type(ParallelType::Serial) | ||||||||||||
| .iter_type(IterType::Iteration) | ||||||||||||
| .build(); | ||||||||||||
|
|
||||||||||||
| // Create the Combine expression linking component + ragged -> combined | ||||||||||||
| IrBuilder::createInContainer<Combine>( | ||||||||||||
| container, combined_id, component, ragged); | ||||||||||||
naoyam marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+1151
to
+1152
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: Unlike If the
Suggested change
Note: This would require updating the Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||||||||||||
|
|
||||||||||||
| return combined_id; | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| TensorDomain::TensorDomain( | ||||||||||||
| IrBuilderPasskey passkey, | ||||||||||||
| std::vector<IterDomain*> logical_domain, | ||||||||||||
|
|
||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: Missing validation that
componentandraggedare semantically compatible or from the same partition. Consider verifying:Partitionoperation (component->definition() == ragged->definition())Without this, arbitrary unrelated IterDomains could be combined.