Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion src/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L

add_argument_metadata!(job, mod, entry)

add_globals_metadata!(job, mod, entry)

add_module_metadata!(job, mod)
end

Expand Down Expand Up @@ -550,6 +552,96 @@ function argument_type_name(typ)
end
end

# global metadata generation
#
# module metadata is used to identify global buffers that are used as kernel arguments.
function add_globals_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
entry::LLVM.Function)
entry_ft = function_type(entry)

## argument info
arg_infos = Metadata[]


# Iterate through arguments and create metadata for them
globs = globals(mod)
@show globs
i = 1
for gv in globs
@show gv
gv_typ = global_value_type(gv)
(isconstant(gv) && addrspace(gv_typ) == 3) || continue
# if job.config.optimize
# @assert parameters(entry_ft)[arg.idx] isa LLVM.PointerType
# else
# parameters(entry_ft)[arg.idx] isa LLVM.PointerType || continue
# end

# # NOTE: we emit the bare minimum of argument metadata to support
# # bindless argument encoding. Actually using the argument encoder
# # APIs (deprecated in Metal 3) turned out too difficult, given the
# # undocumented nature of the argument metadata, and the complex
# # arguments we encounter with typical Julia kernels.
global_infos = Metadata[]

push!(global_infos, MDString("air.global_binding"))
push!(global_infos, Metadata(gv))

md = Metadata[]

# argument index
push!(md, Metadata(ConstantInt(Int32(-1))))

push!(md, MDString("air.buffer"))

push!(md, MDString("air.location_index"))
push!(md, Metadata(ConstantInt(Int32(i-1))))

# XXX: unknown
push!(md, Metadata(ConstantInt(Int32(1))))

push!(md, MDString("air.read_write")) # TODO: Check for const array

push!(md, MDString("air.address_space"))
push!(md, Metadata(ConstantInt(Int32(addrspace(global_value_type(gv))))))

val_type = global_value_type(gv)
# val_type = if value_type(gv) <: Core.LLVMPtr
# arg.typ.parameters[1]
# else
# arg.typ
# end

@show gv_typ
@show isconstant(gv)
# @show isconstant(gv_typ)
# @show Int32(alignment(gv))

push!(md, MDString("air.arg_type_size"))
push!(md, Metadata(ConstantInt(Int32(4))))

push!(md, MDString("air.arg_type_align_size"))
push!(md, Metadata(ConstantInt(Int32(alignment(gv)))))

push!(md, MDString("air.arg_type_name"))
# push!(md, MDString(repr(arg.typ)))

push!(md, MDString("air.arg_name"))
push!(md, MDString(String(LLVM.name(gv))))

push!(arg_infos, MDNode(md))

i += 1
end

println()
arg_infos = MDNode(arg_infos)

push!(metadata(mod)["air.global_bindings"], arg_infos)

return
end

# argument metadata generation
#
# module metadata is used to identify buffers that are passed as kernel arguments.
Expand All @@ -565,7 +657,7 @@ function add_argument_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Modul
args = classify_arguments(job, entry_ft; post_optimization=job.config.optimize)
i = 1
for arg in args
arg.idx === nothing && continue
arg.idx === nothing && continue
if job.config.optimize
@assert parameters(entry_ft)[arg.idx] isa LLVM.PointerType
else
Expand Down
Loading