Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module CUDAExt

using cuTile
using cuTile: TileArray, Constant, CGOpts, CuTileResults, emit_code
using cuTile: TileArray, Constant, CGOpts, CuTileResults, emit_code, sanitize_name

using CompilerCaching: CacheView, method_instance, results

Expand Down Expand Up @@ -51,7 +51,7 @@ function emit_function(cache::CacheView, mi::Core.MethodInstance)
res = results(cache, ci)
res.cuda_func !== nothing && return res.cuda_func

kernel_name = string(mi.def.name)
kernel_name = sanitize_name(string(mi.def.name))
cumod = CuModule(cubin)
cufunc = CuFunction(cumod, kernel_name)
res.cuda_func = cufunc
Expand Down
5 changes: 4 additions & 1 deletion src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,9 @@ function emit_ir(cache::CacheView, mi::Core.MethodInstance)
return res.julia_ir
end

# Encode characters outside [a-zA-Z0-9_] as _XX hex escapes for PTX/MLIR compatibility.
sanitize_name(name::String) = replace(name, r"[^a-zA-Z0-9_]" => c -> "_$(string(UInt8(only(c)); base=16, pad=2))")

"""
emit_code(cache, mi) -> Vector{UInt8}

Expand All @@ -357,7 +360,7 @@ function emit_code(cache::CacheView, mi::Core.MethodInstance)
# Generate Tile IR bytecode
bytecode = write_bytecode!(1) do writer, func_buf
emit_kernel!(writer, func_buf, sci, rettype;
name = string(mi.def.name),
name = sanitize_name(string(mi.def.name)),
sm_arch = opts.sm_arch,
num_ctas = opts.num_ctas,
occupancy = opts.occupancy,
Expand Down
2 changes: 1 addition & 1 deletion src/cuTile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ include("language/operations.jl")
include("language/atomics.jl")

public launch
launch() = error("Please import CUDA.jl before using `cuTile.launch`.")
launch(args...) = error("Please import CUDA.jl before using `cuTile.launch`.")

end # module cuTile
7 changes: 7 additions & 0 deletions test/execution/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1058,3 +1058,10 @@ const _EXEC_TEST_GLOBAL_CONST = Float32(1 / log(2))
@test Array(b) ≈ Array(a) .* (scale * _EXEC_TEST_GLOBAL_CONST)
end

@testset "kernel name with !" begin
function kernel!()
return
end
ct.launch(kernel!, 1)
end