From 72faa35357e5662162e01c36087d85e9035235b3 Mon Sep 17 00:00:00 2001 From: Kbz-8 Date: Thu, 12 Mar 2026 01:06:20 +0100 Subject: [PATCH] fixing example --- example/main.zig | 18 +++++++++++------ src/Module.zig | 52 ++++++++++++++++++++++++++++++++++++++++++++++-- src/Runtime.zig | 3 ++- src/lib.zig | 2 +- 4 files changed, 65 insertions(+), 10 deletions(-) diff --git a/example/main.zig b/example/main.zig index 19b8dc1..ef61322 100644 --- a/example/main.zig +++ b/example/main.zig @@ -30,13 +30,17 @@ pub fn main() !void { var runner_cache: std.ArrayList(Runner) = try .initCapacity(allocator, screen_height); defer { for (runner_cache.items) |*runner| { - runner.rt.deinit(allocator); + allocator.free(runner.heap); } runner_cache.deinit(allocator); } for (0..screen_height) |_| { - var rt = try spv.Runtime.init(allocator, &module); + const heap = try allocator.alloc(u8, module.needed_runtime_bytes); + errdefer allocator.free(heap); + + var buffer_allocator: std.heap.FixedBufferAllocator = .init(heap); + var rt = try spv.Runtime.init(buffer_allocator.allocator(), &module); (try runner_cache.addOne(allocator)).* = .{ .allocator = allocator, .surface = surface, @@ -46,6 +50,7 @@ pub fn main() !void { .time = try rt.getResultByName("time"), .pos = try rt.getResultByName("pos"), .res = try rt.getResultByName("res"), + .heap = heap, }; } @@ -105,6 +110,7 @@ const Runner = struct { time: spv.SpvWord, pos: spv.SpvWord, res: spv.SpvWord, + heap: []u8, fn runWrapper(self: *Self, y: usize, pixel_map: [*]u32, timer: f32) void { @call(.always_inline, Self.run, .{ self, y, pixel_map, timer }) catch |err| { @@ -122,11 +128,11 @@ const Runner = struct { var output: [4]f32 = undefined; for (0..screen_width) |x| { - try rt.writeInput(&.{timer}, self.time); - try rt.writeInput(&.{ @floatFromInt(screen_width), @floatFromInt(screen_height) }, self.res); - try rt.writeInput(&.{ @floatFromInt(x), @floatFromInt(y) }, self.pos); + try rt.writeInput(std.mem.asBytes(&timer), self.time); + try rt.writeInput(std.mem.asBytes(&[_]f32{ @floatFromInt(screen_width), @floatFromInt(screen_height) }), self.res); + try rt.writeInput(std.mem.asBytes(&[_]f32{ @floatFromInt(x), @floatFromInt(y) }), self.pos); try rt.callEntryPoint(self.allocator, self.entry); - try rt.readOutput(output[0..], self.color); + try rt.readOutput(std.mem.asBytes(output[0..]), self.color); const rgba = self.surface.mapRgba( @intCast(@max(@min(@as(i32, @intFromFloat(output[0] * 255.0)), 255), 0)), diff --git a/src/Module.zig b/src/Module.zig index 575985a..c0d26fd 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -38,6 +38,44 @@ pub const ModuleError = error{ OutOfMemory, }; +const AllocatorWrapper = struct { + child_allocator: std.mem.Allocator, + total_bytes_allocated: usize = 0, + + pub fn allocator(self: *AllocatorWrapper) std.mem.Allocator { + return .{ + .ptr = self, + .vtable = &.{ + .alloc = alloc, + .resize = resize, + .remap = remap, + .free = free, + }, + }; + } + + fn alloc(ctx: *anyopaque, n: usize, alignment: std.mem.Alignment, ra: usize) ?[*]u8 { + const self: *AllocatorWrapper = @ptrCast(@alignCast(ctx)); + self.total_bytes_allocated += alignment.toByteUnits() + n; + return self.child_allocator.rawAlloc(n, alignment, ra); + } + + fn resize(ctx: *anyopaque, buf: []u8, alignment: std.mem.Alignment, new_len: usize, ret_addr: usize) bool { + const self: *AllocatorWrapper = @ptrCast(@alignCast(ctx)); + return self.child_allocator.rawResize(buf, alignment, new_len, ret_addr); + } + + fn remap(context: *anyopaque, memory: []u8, alignment: std.mem.Alignment, new_len: usize, return_address: usize) ?[*]u8 { + const self: *AllocatorWrapper = @ptrCast(@alignCast(context)); + return self.child_allocator.rawRemap(memory, alignment, new_len, return_address); + } + + fn free(ctx: *anyopaque, buf: []u8, alignment: std.mem.Alignment, ret_addr: usize) void { + const self: *AllocatorWrapper = @ptrCast(@alignCast(ctx)); + return self.child_allocator.rawFree(buf, alignment, ret_addr); + } +}; + options: ModuleOptions, it: WordIterator, @@ -77,6 +115,8 @@ bindings: [lib.SPIRV_MAX_SET][lib.SPIRV_MAX_SET_BINDINGS]SpvWord, builtins: std.EnumMap(spv.SpvBuiltIn, SpvWord), push_constants: []Value, +needed_runtime_bytes: usize, + pub fn init(allocator: std.mem.Allocator, source: []const SpvWord, options: ModuleOptions) ModuleError!Self { var self: Self = std.mem.zeroInit(Self, .{ .options = options, @@ -92,6 +132,8 @@ pub fn init(allocator: std.mem.Allocator, source: []const SpvWord, options: Modu op.initRuntimeDispatcher(); + var wrapped_allocator: AllocatorWrapper = .{ .child_allocator = allocator }; + self.it = WordIterator.init(self.code); const magic = self.it.next() catch return ModuleError.InvalidSpirV; @@ -111,7 +153,7 @@ pub fn init(allocator: std.mem.Allocator, source: []const SpvWord, options: Modu self.generator_version = @intCast(generator & 0x0000FFFF); self.bound = self.it.next() catch return ModuleError.InvalidSpirV; - self.results = allocator.alloc(Result, self.bound) catch return ModuleError.OutOfMemory; + self.results = wrapped_allocator.allocator().alloc(Result, self.bound) catch return ModuleError.OutOfMemory; errdefer allocator.free(self.results); for (self.results) |*result| { @@ -167,6 +209,8 @@ pub fn init(allocator: std.mem.Allocator, source: []const SpvWord, options: Modu }); } + self.needed_runtime_bytes += wrapped_allocator.total_bytes_allocated; + //@import("pretty").print(allocator, self.results, .{ // .tab_size = 4, // .max_depth = 0, @@ -192,6 +236,8 @@ fn pass(self: *Self, allocator: std.mem.Allocator) ModuleError!void { var rt = Runtime.init(allocator, self) catch return ModuleError.OutOfMemory; defer rt.deinit(allocator); + var wrapped_allocator: AllocatorWrapper = .{ .child_allocator = allocator }; + while (rt.it.nextOrNull()) |opcode_data| { const word_count = ((opcode_data & (~spv.SpvOpCodeMask)) >> spv.SpvWordCountShift) - 1; const opcode = (opcode_data & spv.SpvOpCodeMask); @@ -199,12 +245,14 @@ fn pass(self: *Self, allocator: std.mem.Allocator) ModuleError!void { var it_tmp = rt.it; // Save because operations may iter on this iterator if (std.enums.fromInt(spv.SpvOp, opcode)) |spv_op| { if (op.SetupDispatcher.get(spv_op)) |pfn| { - pfn(allocator, word_count, &rt) catch return ModuleError.InvalidSpirV; + pfn(wrapped_allocator.allocator(), word_count, &rt) catch return ModuleError.InvalidSpirV; } } _ = it_tmp.skipN(word_count); rt.it = it_tmp; } + + self.needed_runtime_bytes += wrapped_allocator.total_bytes_allocated; } fn populateMaps(self: *Self) ModuleError!void { diff --git a/src/Runtime.zig b/src/Runtime.zig index 18a6674..ed4ab5f 100644 --- a/src/Runtime.zig +++ b/src/Runtime.zig @@ -113,7 +113,8 @@ pub fn callEntryPoint(self: *Self, allocator: std.mem.Allocator, entry_point_ind if (entry_point_result.variant) |variant| { switch (variant) { .Function => |f| { - if (!self.it.jumpToSourceLocation(f.source_location)) return RuntimeError.InvalidEntryPoint; + if (!self.it.jumpToSourceLocation(f.source_location)) + return RuntimeError.InvalidEntryPoint; self.function_stack.append(allocator, .{ .source_location = f.source_location, .result = entry_point_result, diff --git a/src/lib.zig b/src/lib.zig index b82e920..3057d63 100644 --- a/src/lib.zig +++ b/src/lib.zig @@ -21,7 +21,7 @@ //! //! try rt.callEntryPoint(allocator, try rt.getEntryPointByName("main")); //! var output: [4]f32 = undefined; -//! try rt.readOutput(f32, output[0..output.len], try rt.getResultByName("color")); +//! try rt.readOutput(std.mem.asBytes(output[0..output.len]), try rt.getResultByName("color")); //! std.log.info("Output: Vec4{any}", .{output}); //! } //! std.log.info("Successfully executed", .{});