From 9cdb683f3f24cf82774f1bf31b8c51d3863a9cfe Mon Sep 17 00:00:00 2001 From: Kbz-8 Date: Mon, 27 Apr 2026 15:42:29 +0200 Subject: [PATCH] adding decoration members propagation --- src/Module.zig | 128 ++++++++++++++++++++++++++++++++++++++++-------- src/Result.zig | 15 +++++- src/Runtime.zig | 45 +++++++++++++++-- src/opcodes.zig | 65 +++++++++++++++++++++--- 4 files changed, 220 insertions(+), 33 deletions(-) diff --git a/src/Module.zig b/src/Module.zig index 7d96fe3..6053483 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -166,6 +166,106 @@ fn pass(self: *Self, allocator: std.mem.Allocator) ModuleError!void { } } +fn resolveConstantWord(self: *const Self, id: SpvWord) ?SpvWord { + if (id >= self.results.len) return null; + + const variant = self.results[id].variant orelse return null; + return switch (variant) { + .Constant => |c| switch (c.value) { + .Int => |i| i.value.uint32, + else => null, + }, + else => null, + }; +} + +fn findAccessChainToMember(self: *const Self, base_id: SpvWord, member_index: SpvWord) ?SpvWord { + for (self.results, 0..) |result, id| { + const variant = result.variant orelse continue; + + switch (variant) { + .AccessChain => |a| { + if (a.base != base_id or a.indexes.len == 0) continue; + + const first_index = self.resolveConstantWord(a.indexes[0]) orelse continue; + if (first_index == member_index) return @intCast(id); + }, + else => {}, + } + } + + return null; +} + +fn applyInterfaceDecoration( + self: *Self, + storage_class: spv.SpvStorageClass, + decoration: Result.Decoration, + id: SpvWord, +) ModuleError!void { + switch (storage_class) { + .Input => switch (decoration.rtype) { + .BuiltIn => self.builtins.put( + std.enums.fromInt(spv.SpvBuiltIn, decoration.literal_1) orelse return ModuleError.InvalidSpirV, + id, + ), + .Location => self.input_locations[decoration.literal_1] = id, + else => {}, + }, + .Output => switch (decoration.rtype) { + .BuiltIn => self.builtins.put( + std.enums.fromInt(spv.SpvBuiltIn, decoration.literal_1) orelse return ModuleError.InvalidSpirV, + id, + ), + .Location => self.output_locations[decoration.literal_1] = id, + else => {}, + }, + else => {}, + } +} + +fn applyStructMemberInterfaceDecorations( + self: *Self, + storage_class: spv.SpvStorageClass, + type_word: SpvWord, + id: SpvWord, +) ModuleError!void { + switch (storage_class) { + .Input, .Output => {}, + else => return, + } + + const type_result = &self.results[type_word]; + const target_type_word = if (type_result.variant) |variant| switch (variant) { + .Type => |t| switch (t) { + .Pointer => |ptr| ptr.target, + else => type_word, + }, + else => type_word, + } else type_word; + + const target_result = &self.results[target_type_word]; + if (target_result.variant) |variant| { + switch (variant) { + .Type => |t| switch (t) { + .Structure => { + for (target_result.decorations.items) |decoration| { + switch (decoration.rtype) { + .BuiltIn, .Location => { + const member_id = self.findAccessChainToMember(id, decoration.index) orelse continue; + try self.applyInterfaceDecoration(storage_class, decoration, member_id); + }, + else => {}, + } + } + }, + else => {}, + }, + else => {}, + } + } +} + fn applyDecorations(self: *Self) ModuleError!void { for (self.results, 0..) |result, id| { if (result.variant == null) @@ -177,27 +277,9 @@ fn applyDecorations(self: *Self) ModuleError!void { for (result.decorations.items) |decoration| { switch (result.variant.?) { .Variable => |v| { + try self.applyInterfaceDecoration(v.storage_class, decoration, @intCast(id)); + switch (v.storage_class) { - .Input => { - switch (decoration.rtype) { - .BuiltIn => self.builtins.put( - std.enums.fromInt(spv.SpvBuiltIn, decoration.literal_1) orelse return ModuleError.InvalidSpirV, - @intCast(id), - ), - .Location => self.input_locations[decoration.literal_1] = @intCast(id), - else => {}, - } - }, - .Output => { - switch (decoration.rtype) { - .BuiltIn => self.builtins.put( - std.enums.fromInt(spv.SpvBuiltIn, decoration.literal_1) orelse return ModuleError.InvalidSpirV, - @intCast(id), - ), - .Location => self.output_locations[decoration.literal_1] = @intCast(id), - else => {}, - } - }, .StorageBuffer, .Uniform, .UniformConstant => { switch (decoration.rtype) { .Binding => binding = decoration.literal_1, @@ -221,6 +303,12 @@ fn applyDecorations(self: *Self) ModuleError!void { else => {}, } } + + switch (result.variant.?) { + .Variable => |v| try self.applyStructMemberInterfaceDecorations(v.storage_class, v.type_word, @intCast(id)), + else => {}, + } + if (set != null and binding != null) { self.bindings[set.?][binding.?] = @intCast(id); } diff --git a/src/Result.zig b/src/Result.zig index 07cc93d..c0c4cd0 100644 --- a/src/Result.zig +++ b/src/Result.zig @@ -201,6 +201,8 @@ pub const VariantData = union(Variant) { }, AccessChain: struct { target: SpvWord, + base: SpvWord, + indexes: []SpvWord, value: Value, }, FunctionParameter: struct { @@ -247,7 +249,10 @@ pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { }, .Constant => |*c| c.value.deinit(allocator), .Variable => |*v| v.value.deinit(allocator), - .AccessChain => |*a| a.value.deinit(allocator), + .AccessChain => |*a| { + allocator.free(a.indexes); + a.value.deinit(allocator); + }, .Function => |f| allocator.free(f.params), else => {}, } @@ -363,6 +368,14 @@ pub fn dupe(self: *const Self, allocator: std.mem.Allocator) RuntimeError!Self { .params = allocator.dupe(SpvWord, f.params) catch return RuntimeError.OutOfMemory, }, }, + .AccessChain => |a| break :blk .{ + .AccessChain = .{ + .target = a.target, + .base = a.base, + .indexes = allocator.dupe(SpvWord, a.indexes) catch return RuntimeError.OutOfMemory, + .value = try a.value.dupe(allocator), + }, + }, else => break :blk variant, } } diff --git a/src/Runtime.zig b/src/Runtime.zig index 63db18a..62371b1 100644 --- a/src/Runtime.zig +++ b/src/Runtime.zig @@ -242,9 +242,46 @@ pub fn writeDescriptorSet(self: *const Self, input: []const u8, set: SpvWord, bi } } +fn readResultValue(self: *const Self, output: []u8, result: SpvWord) RuntimeError!void { + const variant = self.results[result].variant orelse return RuntimeError.InvalidSpirV; + switch (variant) { + .Variable => |v| _ = try v.value.read(output), + .AccessChain => |a| switch (a.value) { + .Pointer => |ptr| switch (ptr.ptr) { + .common => |value_ptr| _ = try value_ptr.read(output), + .f32_ptr => |value_ptr| std.mem.copyForwards(u8, output[0..@sizeOf(f32)], std.mem.asBytes(value_ptr)), + .i32_ptr => |value_ptr| std.mem.copyForwards(u8, output[0..@sizeOf(i32)], std.mem.asBytes(value_ptr)), + .u32_ptr => |value_ptr| std.mem.copyForwards(u8, output[0..@sizeOf(u32)], std.mem.asBytes(value_ptr)), + }, + else => _ = try a.value.read(output), + }, + else => return RuntimeError.InvalidSpirV, + } +} + +fn writeResultValue(self: *const Self, input: []const u8, result: SpvWord) RuntimeError!void { + if (self.results[result].variant) |*variant| { + switch (variant.*) { + .Variable => |*v| _ = try v.value.writeConst(input), + .AccessChain => |*a| switch (a.value) { + .Pointer => |ptr| switch (ptr.ptr) { + .common => |value_ptr| _ = try value_ptr.writeConst(input), + .f32_ptr => |value_ptr| std.mem.copyForwards(u8, std.mem.asBytes(value_ptr), input[0..@sizeOf(f32)]), + .i32_ptr => |value_ptr| std.mem.copyForwards(u8, std.mem.asBytes(value_ptr), input[0..@sizeOf(i32)]), + .u32_ptr => |value_ptr| std.mem.copyForwards(u8, std.mem.asBytes(value_ptr), input[0..@sizeOf(u32)]), + }, + else => _ = try a.value.writeConst(input), + }, + else => return RuntimeError.InvalidSpirV, + } + } else { + return RuntimeError.InvalidSpirV; + } +} + pub fn readOutput(self: *const Self, output: []u8, result: SpvWord) RuntimeError!void { if (std.mem.indexOfScalar(SpvWord, &self.mod.output_locations, result)) |_| { - _ = try self.results[result].variant.?.Variable.value.read(output); + try self.readResultValue(output, result); } else { return RuntimeError.NotFound; } @@ -252,7 +289,7 @@ pub fn readOutput(self: *const Self, output: []u8, result: SpvWord) RuntimeError pub fn readBuiltIn(self: *const Self, output: []u8, builtin: spv.SpvBuiltIn) RuntimeError!void { if (self.mod.builtins.get(builtin)) |result| { - _ = try self.results[result].variant.?.Variable.value.read(output); + try self.readResultValue(output, result); } else { return RuntimeError.NotFound; } @@ -260,7 +297,7 @@ pub fn readBuiltIn(self: *const Self, output: []u8, builtin: spv.SpvBuiltIn) Run pub fn writeInput(self: *const Self, input: []const u8, result: SpvWord) RuntimeError!void { if (std.mem.indexOfScalar(SpvWord, &self.mod.input_locations, result)) |_| { - _ = try self.results[result].variant.?.Variable.value.writeConst(input); + try self.writeResultValue(input, result); } else { return RuntimeError.NotFound; } @@ -268,7 +305,7 @@ pub fn writeInput(self: *const Self, input: []const u8, result: SpvWord) Runtime pub fn writeBuiltIn(self: *const Self, input: []const u8, builtin: spv.SpvBuiltIn) RuntimeError!void { if (self.mod.builtins.get(builtin)) |result| { - _ = try self.results[result].variant.?.Variable.value.writeConst(input); + try self.writeResultValue(input, result); } else { return RuntimeError.NotFound; } diff --git a/src/opcodes.zig b/src/opcodes.zig index 93fb032..6c2192c 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -88,6 +88,7 @@ pub const SetupDispatcher = block: { .AtomicUMax = autoSetupConstant, .AtomicUMin = autoSetupConstant, .AtomicXor = autoSetupConstant, + .AccessChain = setupAccessChain, .BitCount = autoSetupConstant, .BitFieldInsert = autoSetupConstant, .BitFieldSExtract = autoSetupConstant, @@ -145,6 +146,7 @@ pub const SetupDispatcher = block: { .IAddCarry = autoSetupConstant, .IEqual = autoSetupConstant, .ImageRead = autoSetupConstant, + .InBoundsAccessChain = setupAccessChain, .IMul = autoSetupConstant, .INotEqual = autoSetupConstant, .ISub = autoSetupConstant, @@ -1119,6 +1121,39 @@ fn autoSetupConstant(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) Run _ = try setupConstant(allocator, rt); } +fn setupAccessChain(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void { + const var_type = try rt.it.next(); + const id = try rt.it.next(); + const base_id = try rt.it.next(); + + const index_count: usize = @intCast(word_count - 3); + const indexes = allocator.alloc(SpvWord, index_count) catch return RuntimeError.OutOfMemory; + errdefer allocator.free(indexes); + + for (indexes) |*index| { + index.* = try rt.it.next(); + } + + if (rt.results[id].variant) |*variant| { + switch (variant.*) { + .AccessChain => |*a| { + allocator.free(a.indexes); + a.value.deinit(allocator); + }, + else => {}, + } + } + + rt.results[id].variant = .{ + .AccessChain = .{ + .target = var_type, + .base = base_id, + .indexes = indexes, + .value = try Value.init(allocator, rt.results, var_type, false), + }, + }; +} + fn copyValue(dst: *Value, src: *const Value) RuntimeError!void { const helpers = struct { inline fn copySlice(dst_slice: []Value, src_slice: []const Value) RuntimeError!void { @@ -1269,25 +1304,39 @@ fn opAccessChain(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime var arena = std.heap.ArenaAllocator.init(allocator); defer arena.deinit(); - const index_count = word_count - 3; + const index_count: usize = @intCast(word_count - 3); - if (rt.results[id].variant) |*variant| { - switch (variant.*) { - .AccessChain => |*a| try a.value.flushPtr(allocator), - else => {}, + const indexes, const free_responsability = blk: { + if (rt.results[id].variant) |*variant| { + switch (variant.*) { + .AccessChain => |*a| { + if (a.indexes.len != index_count) + return RuntimeError.InvalidSpirV; + try a.value.flushPtr(allocator); + a.value.deinit(allocator); + break :blk .{ a.indexes, false }; + }, + else => {}, + } } - } + break :blk .{ allocator.alloc(SpvWord, index_count) catch return RuntimeError.OutOfMemory, true }; + }; + errdefer if (free_responsability) allocator.free(indexes); rt.results[id].variant = .{ .AccessChain = .{ .target = var_type, + .base = base_id, + .indexes = indexes, .value = blk: { var is_owner_of_uniform_slice = false; var uniform_slice_window: ?[]u8 = null; for (0..index_count) |index| { const is_last = (index == index_count - 1); - const member = &rt.results[try rt.it.next()]; + const index_id = try rt.it.next(); + indexes[index] = index_id; + const member = &rt.results[index_id]; const member_value = switch ((try member.getVariant()).*) { .Constant => |c| &c.value, .Variable => |v| &v.value, @@ -1488,7 +1537,7 @@ fn opCompositeExtract(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Ru const res_type = try rt.it.next(); const id = try rt.it.next(); const composite_id = try rt.it.next(); - const index_count = word_count - 3; + const index_count: usize = @intCast(word_count - 3); var arena = std.heap.ArenaAllocator.init(allocator); defer arena.deinit();