Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions backends/metax_gpu/cinn/cinn_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,15 @@ extern C_Status MetaxLaunchKernel(void* dev_ptr,
void* stream);

// --- From passes/pass_manager.cc ---
// Applies custom graph optimization passes
extern C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module);
// Applies a vendor-specific custom pass identified by name.
// `pass_name` selects which pass to run and `ir_func` is the opaque IR handle
// the pass operates on. The definition in passes/pass_manager.cc returns
// C_SUCCESS for a pass name it recognizes and C_FAILED for an unknown name.
extern C_Status MetaxApplyCustomPass(void* dev_ptr,
const char* pass_name,
void* ir_func);

// Queries the vendor's desired ordered pass pipeline.
// `pass_names` is a caller-provided array of 128-byte name slots and `count`
// is in/out: on entry it holds the number of slots available, on return the
// number of passes in the pipeline. Passing a null `pass_names` only reports
// the count. If the buffer has too few slots, `*count` is set to the required
// size and C_FAILED is returned.
extern C_Status MetaxQueryPassPipeline(void* dev_ptr,
char pass_names[][128],
int* count);

// ============================================================
// Interface Initialization
Expand Down Expand Up @@ -102,6 +109,7 @@ void InitCinnInterface(C_DeviceInterface* device_interface) {

// 6. Register Compilation Strategy interface
metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass;
metax_cinn_impl.query_pass_pipeline = MetaxQueryPassPipeline;

// 7. Attach the populated dispatch table to the Paddle device interface
if (device_interface) {
Expand Down
86 changes: 82 additions & 4 deletions backends/metax_gpu/cinn/passes/pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstring>
#include <iostream>

#include "paddle/phi/backends/device_ext.h"
Expand All @@ -20,10 +21,87 @@ namespace paddle {
namespace custom_device {
namespace metax {

// Applies custom graph optimization passes.
// Currently a no-op stub; returns success immediately.
C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module) {
// VLOG(3) << "[MetaX] MetaxApplyCustomPass called (No-op)";
// ============================================================
// MetaxApplyCustomPass
// ============================================================
// Called by the CINN framework when it encounters a pass name in the pipeline
// that is NOT a built-in pass. `ir_func` is a `cinn::ir::LoweredFunc*` cast
// to void*.
// Parameters:
//   dev_ptr   - opaque device handle; not used by this implementation.
//   pass_name - NUL-terminated name of the custom pass to run. A null pointer
//               is rejected with C_FAILED instead of being dereferenced.
//   ir_func   - opaque IR handle the pass operates on; only logged here.
// Returns C_SUCCESS when the pass name is recognized, C_FAILED otherwise.
C_Status MetaxApplyCustomPass(void* dev_ptr,
                              const char* pass_name,
                              void* ir_func) {
  // Guard first: comparing or printing a null C string is undefined behaviour.
  if (pass_name == nullptr) {
    std::cerr << "[MetaX] Custom pass name is null" << std::endl;
    return C_Status::C_FAILED;
  }

  // Compare as C strings so no temporary std::string has to be allocated.
  if (std::strcmp(pass_name, "MetaxDebugLogPass") == 0) {
    // A trivial pass that simply logs the function pointer address.
    // Demonstrates the custom-pass mechanism without modifying IR.
    std::cout << "[MetaX] MetaxDebugLogPass: ir_func=" << ir_func << std::endl;
    return C_Status::C_SUCCESS;
  }

  std::cerr << "[MetaX] Unknown custom pass: " << pass_name << std::endl;
  return C_Status::C_FAILED;
}

// ============================================================
// MetaxQueryPassPipeline
// ============================================================
// Defines the ordered pass pipeline for MetaX GPU hardware.
//
// Rules:
// - Built-in pass names (understood by CINN) are executed by the framework.
// - Unknown names are forwarded to MetaxApplyCustomPass().
// Parameters:
//   dev_ptr    - opaque device handle; not used by this implementation.
//   pass_names - caller-provided array of 128-byte slots that receives the
//                NUL-terminated pass names, or null to query only the count.
//   count      - in/out. On entry: number of slots in `pass_names` (ignored
//                when `pass_names` is null). On return: number of passes in
//                the pipeline. Must not be null.
// Returns C_SUCCESS on success. Returns C_FAILED when `count` is null, or when
// the buffer has too few slots (in which case `*count` holds the required
// number of slots so the caller can retry).
C_Status MetaxQueryPassPipeline(void* dev_ptr,
                                char pass_names[][128],
                                int* count) {
  // Full NVGPU-equivalent pipeline with one custom pass inserted.
  static const char* const kPipeline[] = {
      "Simplify",
      "EliminateInvariantLoop",
      "RealizeCompositeReduce",
      "ReindexTransposeBuffer",
      "ReplaceCrossThreadReduction",
      "ReplaceCrossBlockReduction",
      "SetCudaAxisInfo",
      "RemoveGpuForLoops",
      "CudaSyncThreadsDropIfThenElse",
      "TransBufferWithDynamicShape",
      "SimplifyUnitBlock",
      "MapExternCall",
      "ExternCallMultiOutputShallowStore",
      "Simplify",
      "IfFusion",
      "EntailLoopCondition",
      // Vendor-defined custom pass: forwarded to MetaxApplyCustomPass().
      "MetaxDebugLogPass",
      "RearrangeLoadInstruction",
      "VectorizeForTrans",
      "Simplify",
      "RemoveScheduleBlock",
      "IfFold",
      "LowerIntrin",
      "PrepareBufferCastExprs",
  };
  constexpr int kPipelineSize =
      static_cast<int>(sizeof(kPipeline) / sizeof(kPipeline[0]));
  // Bytes per name slot, taken from the parameter type so it cannot drift
  // from the signature. sizeof does not evaluate its operand, so this is safe
  // even when pass_names is null.
  constexpr std::size_t kNameCapacity = sizeof(pass_names[0]);

  // Every path below writes through `count`; reject a null pointer up front.
  if (count == nullptr) {
    std::cerr << "[MetaX] MetaxQueryPassPipeline: count is null" << std::endl;
    return C_Status::C_FAILED;
  }

  // If pass_names is null the caller only wants the count.
  if (pass_names == nullptr) {
    *count = kPipelineSize;
    return C_Status::C_SUCCESS;
  }

  // Check buffer capacity; report the required size so the caller can retry.
  if (*count < kPipelineSize) {
    *count = kPipelineSize;
    return C_Status::C_FAILED;
  }

  for (int i = 0; i < kPipelineSize; ++i) {
    // strncpy does not terminate on truncation, so terminate explicitly.
    std::strncpy(pass_names[i], kPipeline[i], kNameCapacity - 1);
    pass_names[i][kNameCapacity - 1] = '\0';
  }
  *count = kPipelineSize;
  return C_Status::C_SUCCESS;
}

Expand Down
Loading