diff --git a/backends/metax_gpu/cinn/cinn_interface.cc b/backends/metax_gpu/cinn/cinn_interface.cc index a01bd0e67e..316888a81c 100644 --- a/backends/metax_gpu/cinn/cinn_interface.cc +++ b/backends/metax_gpu/cinn/cinn_interface.cc @@ -68,8 +68,15 @@ extern C_Status MetaxLaunchKernel(void* dev_ptr, void* stream); // --- From passes/pass_manager.cc --- -// Applies custom graph optimization passes -extern C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module); +// Applies a vendor-specific custom pass identified by name +extern C_Status MetaxApplyCustomPass(void* dev_ptr, + const char* pass_name, + void* ir_func); + +// Queries the vendor's desired ordered pass pipeline +extern C_Status MetaxQueryPassPipeline(void* dev_ptr, + char pass_names[][128], + int* count); // ============================================================ // Interface Initialization @@ -102,6 +109,7 @@ void InitCinnInterface(C_DeviceInterface* device_interface) { // 6. Register Compilation Strategy interface metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass; + metax_cinn_impl.query_pass_pipeline = MetaxQueryPassPipeline; // 7. Attach the populated dispatch table to the Paddle device interface if (device_interface) { diff --git a/backends/metax_gpu/cinn/passes/pass_manager.cc b/backends/metax_gpu/cinn/passes/pass_manager.cc index 15d73d0738..cf4e652130 100644 --- a/backends/metax_gpu/cinn/passes/pass_manager.cc +++ b/backends/metax_gpu/cinn/passes/pass_manager.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include "paddle/phi/backends/device_ext.h" @@ -20,10 +21,87 @@ namespace paddle { namespace custom_device { namespace metax { -// Applies custom graph optimization passes. -// Currently a no-op stub; returns success immediately. -C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module) { - // VLOG(3) << "[MetaX] MetaxApplyCustomPass called (No-op)"; +// ============================================================ +// MetaxApplyCustomPass +// ============================================================ +// Called by the CINN framework when it encounters a pass name in the pipeline +// that is NOT a built-in pass. `ir_func` is a `cinn::ir::LoweredFunc*` cast +// to void*. +C_Status MetaxApplyCustomPass(void* dev_ptr, + const char* pass_name, + void* ir_func) { + std::string name(pass_name); + + if (name == "MetaxDebugLogPass") { + // A trivial pass that simply logs the function pointer address. + // Demonstrates the custom-pass mechanism without modifying IR. + std::cout << "[MetaX] MetaxDebugLogPass: ir_func=" << ir_func << std::endl; + return C_Status::C_SUCCESS; + } + + std::cerr << "[MetaX] Unknown custom pass: " << name << std::endl; + return C_Status::C_FAILED; +} + +// ============================================================ +// MetaxQueryPassPipeline +// ============================================================ +// Defines the ordered pass pipeline for MetaX GPU hardware. +// +// Rules: +// - Built-in pass names (understood by CINN) are executed by the framework. +// - Unknown names are forwarded to MetaxApplyCustomPass(). +C_Status MetaxQueryPassPipeline(void* dev_ptr, + char pass_names[][128], + int* count) { + // Full NVGPU-equivalent pipeline with one custom pass inserted. + static const char* kPipeline[] = { + "Simplify", + "EliminateInvariantLoop", + "RealizeCompositeReduce", + "ReindexTransposeBuffer", + "ReplaceCrossThreadReduction", + "ReplaceCrossBlockReduction", + "SetCudaAxisInfo", + "RemoveGpuForLoops", + "CudaSyncThreadsDropIfThenElse", + "TransBufferWithDynamicShape", + "SimplifyUnitBlock", + "MapExternCall", + "ExternCallMultiOutputShallowStore", + "Simplify", + "IfFusion", + "EntailLoopCondition", + // Vendor-defined custom pass: forwarded to MetaxApplyCustomPass(). + "MetaxDebugLogPass", + "RearrangeLoadInstruction", + "VectorizeForTrans", + "Simplify", + "RemoveScheduleBlock", + "IfFold", + "LowerIntrin", + "PrepareBufferCastExprs", + }; + static const int kPipelineSize = + static_cast(sizeof(kPipeline) / sizeof(kPipeline[0])); + + // If pass_names is null the caller only wants the count. + if (pass_names == nullptr) { + *count = kPipelineSize; + return C_Status::C_SUCCESS; + } + + // Check buffer capacity. + if (*count < kPipelineSize) { + *count = kPipelineSize; + return C_Status::C_FAILED; + } + + for (int i = 0; i < kPipelineSize; ++i) { + std::strncpy(pass_names[i], kPipeline[i], 127); + pass_names[i][127] = '\0'; + } + *count = kPipelineSize; return C_Status::C_SUCCESS; }