diff --git a/src/main.zig b/src/main.zig index 4f20dd5..a663ccd 100644 --- a/src/main.zig +++ b/src/main.zig @@ -69,7 +69,7 @@ const Weights = struct { // (optional) classifier weights for the logits, on the last layer wcls: [*]f32, // (vocab_size, dim) - fn init(config: *const Config, data: []u8, shared_weights: bool) Weights { + fn init(config: *const Config, data: []f32, shared_weights: bool) Weights { const vocab_size: usize = config.vocab_size; const dim: usize = config.dim; const hidden_dim: usize = config.hidden_dim; @@ -81,7 +81,7 @@ const Weights = struct { var weights: Weights = undefined; - var ptr: [*]f32 = @alignCast(@ptrCast(data)); + var ptr: [*]f32 = data.ptr; weights.token_embedding_table = ptr; ptr += vocab_size * dim; weights.rms_att_weight = ptr; @@ -875,16 +875,16 @@ pub fn main() !void { log("SIMD vector size: {d}\n", .{DEFAULT_VECTOR_WIDTH}); log("\n", .{}); - const data: []align(mem.page_size) u8 = blk: { + const data = blk: { const weights_size: usize = file_size - @sizeOf(ConfigReader); - const buffer = try allocator.alignedAlloc(u8, mem.page_size, weights_size); + const buffer = try allocator.alignedAlloc(u8, @alignOf(f32), weights_size); const read_len = try checkpoint.readAll(buffer); - if (read_len != weights_size) { + if (read_len != weights_size or read_len % @alignOf(f32) != 0) { std.debug.print("error: failed to read checkpoint file\n", .{}); std.process.exit(1); } checkpoint.close(); - break :blk buffer; + break :blk std.mem.bytesAsSlice(f32, buffer); // mmap seems slower // break :blk try std.os.mmap(null, file_size, std.os.PROT.READ, std.os.MAP.PRIVATE, mapped_checkpoint.handle, 0); };