-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathkernel.jl
More file actions
326 lines (280 loc) · 12.5 KB
/
kernel.jl
File metadata and controls
326 lines (280 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
# kernel and argument handling
"""
    emit_kernel!(writer, func_buf, sci, rettype; name, sm_arch=nothing, is_entry=true,
                 num_ctas=nothing, occupancy=nothing, cache)

Compile a StructuredIRCode to Tile IR bytecode.

Kernel arguments are translated into function parameters as follows:
- arguments whose widened type satisfies `is_ghost_type` produce no parameter;
- arguments for which `should_destructure` holds are flattened field by field: the
  `sizes` and `strides` fields each contribute `ndims` `Int32` parameters (`ndims` being
  the second type parameter of the argument type), and every other field contributes
  `flat_field_count(ftype)` parameters;
- every remaining argument becomes exactly one parameter.

# Keywords
- `name`: function name handed to `add_function!`.
- `sm_arch`: stored in the codegen context and passed to `encode_entry_hints`.
- `is_entry`: forwarded to `add_function!`.
- `num_ctas`, `occupancy`: wrapped in `EntryHints` and encoded as entry hints.
- `cache`: required `CacheView`, stored in the codegen context.

Returns whatever `finalize_function!` returns. Throws `IRError` if a non-destructured
argument ends up with a number of flat values different from one.
"""
function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
                      sci::StructuredIRCode, rettype::Type;
                      name::String,
                      sm_arch::Union{String, Nothing} = nothing,
                      is_entry::Bool = true,
                      num_ctas::Union{Int, Nothing} = nothing,
                      occupancy::Union{Int, Nothing} = nothing,
                      cache::CacheView)
    tt = writer.type_table
    # Provisional builder; replaced by the function's own builder once it exists.
    ctx = CGCtx(; cb = CodeBuilder(writer.string_table, writer.constant_table, tt),
                tt, sci, sm_arch, cache)

    # Every non-ghost argument must have a concrete type.
    for (i, argtype) in enumerate(sci.argtypes)
        if !is_ghost_type(CC.widenconst(argtype))
            require_concrete_type(argtype, "kernel argument $i")
        end
    end

    # Flat parameter list. `param_mapping[k]` records which argument (and, for
    # destructured arguments, which field) parameter `k` came from.
    param_types = TypeId[]
    param_mapping = Tuple{Int, Union{Nothing, Symbol}}[]
    for (i, argtype) in enumerate(sci.argtypes)
        T = CC.widenconst(argtype)
        is_ghost_type(T) && continue

        if !should_destructure(T)
            # Plain argument: one parameter, no field name.
            push!(param_types, tile_type_for_julia!(ctx, T))
            push!(param_mapping, (i, nothing))
            continue
        end

        # Destructured argument: one run of parameters per field.
        params = T.parameters
        rank = params[2]::Integer
        for fi in 1:fieldcount(T)
            fname = fieldname(T, fi)
            ftype = fieldtype(T, fi)
            if fname === :sizes || fname === :strides
                fcount = rank
                elem_type = Int32
            else
                fcount = flat_field_count(ftype)
                if ftype <: Ptr
                    elem_type = Ptr{params[1]}
                elseif ftype <: Tuple
                    elem_type = eltype(ftype)
                else
                    elem_type = ftype
                end
            end
            for _ in 1:fcount
                push!(param_types, tile_type_for_julia!(ctx, elem_type))
                push!(param_mapping, (i, fname))
            end
        end
        ctx.arg_types[i] = T
    end

    # Result types: none for `Nothing` or `Union{}` returns.
    result_types = TypeId[]
    if !(rettype === Nothing || rettype === Union{})
        push!(result_types, tile_type_for_julia!(ctx, rettype))
    end

    # Entry hints, then the function itself; from here on the context emits into it.
    hints = EntryHints(; num_ctas, occupancy)
    entry_hints = encode_entry_hints(writer, sm_arch, hints)
    func_cb = add_function!(writer, func_buf, name, param_types, result_types;
                            is_entry, entry_hints)
    ctx.cb = func_cb

    # One block argument per flat parameter, grouped by (argument index, field).
    arg_values = make_block_args!(func_cb, length(param_types))
    field_values = Dict{Tuple{Int, Union{Nothing, Symbol}}, Vector{Value}}()
    for (param_idx, val) in enumerate(arg_values)
        push!(get!(() -> Value[], field_values, param_mapping[param_idx]), val)
    end

    # Record the flat values; plain arguments also get a concrete CGVal right away.
    # arg_idx indexes argtypes directly, which matches SlotNumber/Argument numbering.
    for (key, values) in field_values
        arg_idx, field = key
        ctx.arg_flat_values[key] = values
        field === nothing || continue

        length(values) == 1 ||
            throw(IRError("Expected exactly one value for argument $arg_idx, got $(length(values))"))
        val = values[1]
        type_id = tile_type_for_julia!(ctx, sci.argtypes[arg_idx])
        tv = CGVal(val, type_id, sci.argtypes[arg_idx])
        ctx[SlotNumber(arg_idx)] = tv
        ctx[Argument(arg_idx)] = tv
    end

    # Destructured arguments are represented lazily by an argument reference with an
    # empty access chain.
    for (arg_idx, argtype) in ctx.arg_types
        lazy_tv = arg_ref_value(arg_idx, Union{Symbol, Int}[], argtype)
        ctx[SlotNumber(arg_idx)] = lazy_tv
        ctx[Argument(arg_idx)] = lazy_tv
    end

    # Tensor views for all destructured arguments are created up front at kernel entry.
    for (arg_idx, _) in ctx.arg_types
        cache_tensor_view!(ctx, arg_idx)
    end

    # Memory ordering token.
    token_type = Token(tt)
    ctx.token_type = token_type
    ctx.token = encode_MakeTokenOp!(func_cb, token_type)

    # Emit the structured IR (uses original Julia SSA indices everywhere).
    emit_block!(ctx, ctx.sci.entry)

    return finalize_function!(func_buf, func_cb, writer.debug_info)
end
# getfield for destructured arguments (lazy chain extension)
"""
    emit_getfield!(ctx, args, result_type)

Lower a `getfield` call whose object is `args[1]` and whose field is the constant
`args[2]`.

Handled cases, tried in this order:
1. whatever `emit_loop_getfield!` handles (multi-valued loops extract results this way);
2. integer indexing into a CGVal that carries tuple components;
3. symbol access on a lazy argument reference: materialized immediately when the result
   type is not a tuple and the field has exactly one flat value, otherwise returned as
   an argument reference with the field appended to its chain;
4. integer access on a lazy argument reference whose chain ends in a field name:
   materialized from that field's flat values when the index is in range.

Returns `nothing` when fewer than two arguments are given or none of the cases apply.
"""
function emit_getfield!(ctx::CGCtx, args, @nospecialize(result_type))
    length(args) < 2 && return nothing

    # Multi-valued loops rely on getfield to extract values; let that path go first.
    loop_tv = emit_loop_getfield!(ctx, args)
    loop_tv === nothing || return loop_tv

    # Constant field name or index first, then the object being accessed.
    field = get_constant(ctx, args[2])
    obj_tv = emit_value!(ctx, args[1])
    obj_tv === nothing && return nothing

    # Tuple-valued CGVal: pick the component by integer index.
    if obj_tv.tuple !== nothing && field isa Integer
        return emit_value!(ctx, obj_tv.tuple[field])
    end

    # Everything below only applies to lazy argument references.
    is_arg_ref(obj_tv) || return nothing
    arg_idx, chain = obj_tv.arg_ref

    if field isa Symbol
        extended = Union{Symbol, Int}[chain..., field]
        rt = CC.widenconst(result_type)
        # Tuple-typed fields stay lazy: they need an index before they become a value.
        if !(rt <: Tuple)
            flat = get_arg_flat_values(ctx, arg_idx, field)
            if flat !== nothing && length(flat) == 1
                # Scalar leaf: materialize right away.
                type_id = tile_type_for_julia!(ctx, rt)
                return CGVal(flat[1], type_id, rt)
            end
        end
        return arg_ref_value(arg_idx, extended, rt)
    end

    if field isa Integer && !isempty(chain) && chain[end] isa Symbol
        # Indexing into a tuple field named by the end of the chain: this is a leaf.
        flat = get_arg_flat_values(ctx, arg_idx, chain[end])
        if flat !== nothing && 1 <= field <= length(flat)
            rt = CC.widenconst(result_type)
            type_id = tile_type_for_julia!(ctx, rt)
            return CGVal(flat[field], type_id, rt)
        end
    end

    return nothing
end
# getindex for tuple field access (lazy chain extension)
"""
    emit_getindex!(ctx, args, result_type)

Lower a `getindex` call with a constant integer index (`args[2]`) on a lazy argument
reference (`args[1]`).

When the reference's chain ends in a field name and the index falls within that field's
flat values, the selected value is materialized as a CGVal. Otherwise the index is
appended to the chain and a new lazy argument reference is returned.

Returns `nothing` when fewer than two arguments are given, the index is not a constant
`Integer`, the object cannot be resolved, or the object is not an argument reference.
"""
function emit_getindex!(ctx::CGCtx, args, @nospecialize(result_type))
    length(args) < 2 && return nothing

    # Only constant integer indices are supported.
    index = get_constant(ctx, args[2])
    index isa Integer || return nothing

    # Only lazy argument references are handled here.
    obj_tv = emit_value!(ctx, args[1])
    (obj_tv === nothing || !is_arg_ref(obj_tv)) && return nothing

    arg_idx, chain = obj_tv.arg_ref

    # Chain ends in a field name: we are indexing a tuple field, so try to materialize.
    if !isempty(chain) && chain[end] isa Symbol
        flat = get_arg_flat_values(ctx, arg_idx, chain[end])
        if flat !== nothing && 1 <= index <= length(flat)
            rt = CC.widenconst(result_type)
            type_id = tile_type_for_julia!(ctx, rt)
            return CGVal(flat[index], type_id, rt)
        end
    end

    # Could not materialize: stay lazy and record the index in the chain.
    extended = Union{Symbol, Int}[chain..., Int(index)]
    return arg_ref_value(arg_idx, extended, CC.widenconst(result_type))
end
#=============================================================================
Subprogram compilation
=============================================================================#
"""
    emit_subprogram!(ctx, func, arg_types, block_args, block_type_ids) -> Vector{Value}

Compile a Julia function into the current region body. Resolves `func` to a method
instance (trying `cuTileMethodTable` first, then a lookup without an explicit method
table), fetches its StructuredIRCode from `ctx.cache` via `emit_ir`, creates a
sub-context sharing the parent's builder, token and caches, maps `block_args` to the
function's positional arguments, emits the body, and returns the yielded result values.

- `func`: the Julia function to compile (e.g., `+`, `max`, a lambda)
- `arg_types`: Julia types for each block arg (e.g., `[Tile{Float32,()}]` repeated)
- `block_args`: IR `Value`s from the enclosing region (e.g., `[acc, elem]`)
- `block_type_ids`: `TypeId`s corresponding to each block arg

Ghost-typed arguments of the callee receive a `ghost_value`; non-ghost arguments consume
`block_args` in order. For a varargs method, the block args left over after the fixed
arguments are packed into a tuple value bound to the last argument.

A `YieldOp` is emitted with the return value(s).

# Throws
- `ErrorException` if no method matches, or if the method instance is not already
  present in `ctx.cache`.
- `IRError` if the return value, or one of its tuple components, cannot be resolved.
"""
function emit_subprogram!(ctx::CGCtx, func, arg_types::Vector,
                          block_args::Vector{Value}, block_type_ids::Vector{TypeId})
    # 1. Resolve method instance
    argtuple = Tuple{arg_types...}
    world = ctx.cache.world
    mi = @something(
        method_instance(func, argtuple;
                        world, method_table=cuTileMethodTable),
        method_instance(func, argtuple; world),
        error("No method found for $func($(join(arg_types, ", ")))")
    )

    # 2. Compile through cuTile pipeline (cached)
    if !haskey(ctx.cache, mi)
        error("Expected $func($(join(arg_types, ", "))) to be cached already by inference.")
    end
    sci, _ = emit_ir(ctx.cache, mi)

    # 3. Create sub-context
    sub_ctx = CGCtx(; ctx.cb, ctx.tt, sci,
                    ctx.token, ctx.token_type,
                    ctx.type_cache, ctx.sm_arch,
                    ctx.cache)

    # 4. Map arguments dynamically: ghost args get ghost_value, non-ghost args
    #    consume block_args sequentially. For a varargs method the last argtype is
    #    the varargs tuple, so only the first n_argtypes-1 entries are fixed.
    n_argtypes = length(sci.argtypes)
    isva = mi.def.isva
    n_fixed = isva ? n_argtypes - 1 : n_argtypes
    block_idx = 1  # cursor into block_args
    for i in 1:n_fixed
        argtype = sci.argtypes[i]
        if is_ghost_type(CC.widenconst(argtype))
            sub_ctx[Argument(i)] = ghost_value(argtype)
        else
            sub_ctx[Argument(i)] = CGVal(block_args[block_idx], block_type_ids[block_idx],
                                         arg_types[block_idx])
            block_idx += 1
        end
    end

    if isva
        # Pack remaining block_args into a virtual tuple for the varargs argument.
        # Components are registered under high Argument indices to avoid colliding
        # with the callee's real arguments.
        va_offset = n_argtypes + length(block_args)
        tuple_components = Any[]
        for j in block_idx:length(block_args)
            va_arg = Argument(va_offset + j - block_idx + 1)
            sub_ctx[va_arg] = CGVal(block_args[j], block_type_ids[j], arg_types[j])
            push!(tuple_components, va_arg)
        end
        constants = Any[nothing for _ in tuple_components]
        sub_ctx[Argument(n_argtypes)] =
            tuple_value(sci.argtypes[end], tuple_components, constants)
    end

    # 5. Emit body (skip terminator — we yield manually)
    emit_block!(sub_ctx, sci.entry; skip_terminator=true)

    # 6. Extract return value and yield
    ret = sci.entry.terminator::ReturnNode
    tv = emit_value!(sub_ctx, ret.val)
    # emit_value! may return nothing; report that as an IR error instead of failing
    # on the field access below.
    tv === nothing && throw(IRError("Cannot resolve return value in subprogram"))
    if tv.tuple !== nothing
        # Tuple return: resolve each component to a concrete Value
        results = Value[]
        for ref in tv.tuple
            component = emit_value!(sub_ctx, ref)
            component === nothing &&
                throw(IRError("Cannot resolve tuple component in subprogram return"))
            push!(results, component.v::Value)
        end
    else
        results = tv.v isa Vector ? tv.v : [tv.v]
    end
    encode_YieldOp!(ctx.cb, results)
    return results
end