diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 4bb7e61..03a20fe 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -1,7 +1,7 @@ module CUDAExt using cuTile -using cuTile: TileArray, Constant, CGOpts, CuTileResults, emit_code +using cuTile: TileArray, Constant, CGOpts, CuTileResults, emit_code, sanitize_name using CompilerCaching: CacheView, method_instance, results @@ -51,7 +51,7 @@ function emit_function(cache::CacheView, mi::Core.MethodInstance) res = results(cache, ci) res.cuda_func !== nothing && return res.cuda_func - kernel_name = string(mi.def.name) + kernel_name = sanitize_name(string(mi.def.name)) cumod = CuModule(cubin) cufunc = CuFunction(cumod, kernel_name) res.cuda_func = cufunc diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 86f8f8b..52ecb4b 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -337,6 +337,9 @@ function emit_ir(cache::CacheView, mi::Core.MethodInstance) return res.julia_ir end +# Encode characters outside [a-zA-Z0-9_] as _XX hex escapes for PTX/MLIR compatibility. +sanitize_name(name::String) = replace(name, r"[^a-zA-Z0-9_]" => c -> "_$(string(UInt8(only(c)); base=16, pad=2))") + """ emit_code(cache, mi) -> Vector{UInt8} @@ -357,7 +360,7 @@ function emit_code(cache::CacheView, mi::Core.MethodInstance) # Generate Tile IR bytecode bytecode = write_bytecode!(1) do writer, func_buf emit_kernel!(writer, func_buf, sci, rettype; - name = string(mi.def.name), + name = sanitize_name(string(mi.def.name)), sm_arch = opts.sm_arch, num_ctas = opts.num_ctas, occupancy = opts.occupancy, diff --git a/src/cuTile.jl b/src/cuTile.jl index c040c05..d4ae69e 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -39,6 +39,6 @@ include("language/operations.jl") include("language/atomics.jl") public launch -launch() = error("Please import CUDA.jl before using `cuTile.launch`.") +launch(args...) = error("Please import CUDA.jl before using `cuTile.launch`.") end # module cuTile diff --git a/test/execution/basic.jl b/test/execution/basic.jl index 750f5d3..8140fe1 100644 --- a/test/execution/basic.jl +++ b/test/execution/basic.jl @@ -1058,3 +1058,10 @@ const _EXEC_TEST_GLOBAL_CONST = Float32(1 / log(2)) @test Array(b) ≈ Array(a) .* (scale * _EXEC_TEST_GLOBAL_CONST) end +@testset "kernel name with !" begin + function kernel!() + return + end + ct.launch(kernel!, 1) +end +