From 8bdea7b1fc0f284f68fc51a87310142c1372ea4a Mon Sep 17 00:00:00 2001 From: Kbz-8 Date: Sun, 18 Jan 2026 00:30:15 +0100 Subject: [PATCH] fixing casts unit tests --- src/Result.zig | 36 +++++++++++++++++++++------------ src/opcodes.zig | 54 ++++++++++++++++++++++++------------------------- test/casts.zig | 50 +++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 98 insertions(+), 42 deletions(-) diff --git a/src/Result.zig b/src/Result.zig index bfcb813..81b9fc3 100644 --- a/src/Result.zig +++ b/src/Result.zig @@ -185,15 +185,7 @@ pub const Value = union(Type) { } }; -const Self = @This(); - -name: ?[]const u8, - -decorations: std.ArrayList(Decoration), - -parent: ?*const Self, - -variant: ?union(Variant) { +pub const VariantData = union(Variant) { String: []const u8, Extension: struct {}, Type: union(Type) { @@ -265,7 +257,17 @@ variant: ?union(Variant) { Label: struct { source_location: usize, }, -}, +}; + +const Self = @This(); + +name: ?[]const u8, + +decorations: std.ArrayList(Decoration), + +parent: ?*const Self, + +variant: ?VariantData, pub fn init() Self { return .{ @@ -305,7 +307,7 @@ pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { } pub fn getValueTypeWord(self: *Self) RuntimeError!SpvWord { - return switch (self.variant orelse return RuntimeError.InvalidSpirV) { + return switch ((try self.getVariant()).*) { .Variable => |v| v.type_word, .Constant => |c| c.type_word, .AccessChain => |*a| a.target, @@ -315,7 +317,7 @@ pub fn getValueTypeWord(self: *Self) RuntimeError!SpvWord { } pub fn getValueType(self: *Self) RuntimeError!Type { - return switch (self.variant orelse return RuntimeError.InvalidSpirV) { + return switch ((try self.getVariant()).*) { .Variable => |v| v.type, .Constant => |c| c.type, .FunctionParameter => |p| p.type, @@ -324,7 +326,7 @@ pub fn getValueType(self: *Self) RuntimeError!Type { } pub fn getValue(self: *Self) RuntimeError!*Value { - return switch (self.variant orelse return RuntimeError.InvalidSpirV) { + return switch ((try self.getVariant()).*) { .Variable => |*v| &v.value, .Constant => |*c| &c.value, .AccessChain => |*a| &a.value, @@ -333,6 +335,14 @@ pub fn getValue(self: *Self) RuntimeError!*Value { }; } +pub inline fn getVariant(self: *Self) RuntimeError!*VariantData { + return &(self.variant orelse return RuntimeError.InvalidSpirV); +} + +pub inline fn getConstVariant(self: *const Self) RuntimeError!*const VariantData { + return &(self.variant orelse return RuntimeError.InvalidSpirV); +} + /// Performs a deep copy pub fn dupe(self: *const Self, allocator: std.mem.Allocator) RuntimeError!Self { return .{ diff --git a/src/opcodes.zig b/src/opcodes.zig index b097a6a..bffabcc 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -186,8 +186,8 @@ pub const RuntimeDispatcher = block: { fn CondEngine(comptime T: ValueType, comptime Op: CondOp) type { return struct { fn op(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { - sw: switch ((rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type) { - .Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type, + sw: switch ((try rt.results[try rt.it.next()].getVariant()).Type) { + .Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type, .Bool => {}, else => return RuntimeError.InvalidSpirV, } @@ -198,8 +198,8 @@ fn CondEngine(comptime T: ValueType, comptime Op: CondOp) type { const op1_value = try op1_result.getValue(); const op2_value = try rt.results[try rt.it.next()].getValue(); - const size = sw: switch ((rt.results[op1_type].variant orelse return RuntimeError.InvalidSpirV).Type) { - .Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type, + const size = sw: switch ((try rt.results[op1_type].getVariant()).Type) { + .Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type, .Float => |f| if (T == .Float) f.bit_length else return RuntimeError.InvalidSpirV, .Int => |i| if (T == .SInt or T == .UInt) i.bit_length else return RuntimeError.InvalidSpirV, else => return RuntimeError.InvalidSpirV, @@ -246,21 +246,21 @@ fn CondEngine(comptime T: ValueType, comptime Op: CondOp) type { fn ConversionEngine(comptime From: ValueType, comptime To: ValueType) type { return struct { fn op(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { - const target_type = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type; + const target_type = (try rt.results[try rt.it.next()].getVariant()).Type; const value = try rt.results[try rt.it.next()].getValue(); const op_result = &rt.results[try rt.it.next()]; const op_type = try op_result.getValueTypeWord(); const op_value = try op_result.getValue(); - const from_size = sw: switch ((rt.results[op_type].variant orelse return RuntimeError.InvalidSpirV).Type) { - .Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type, + const from_size = sw: switch ((try rt.results[op_type].getVariant()).Type) { + .Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type, .Float => |f| if (From == .Float) f.bit_length else return RuntimeError.InvalidSpirV, .Int => |i| if (From == .SInt or From == .UInt) i.bit_length else return RuntimeError.InvalidSpirV, else => return RuntimeError.InvalidSpirV, }; const to_size = sw: switch (target_type) { - .Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type, + .Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type, .Float => |f| if (To == .Float) f.bit_length else return RuntimeError.InvalidSpirV, .Int => |i| if (To == .SInt or To == .UInt) i.bit_length else return RuntimeError.InvalidSpirV, else => return RuntimeError.InvalidSpirV, @@ -306,13 +306,13 @@ fn ConversionEngine(comptime From: ValueType, comptime To: ValueType) type { fn MathEngine(comptime T: ValueType, comptime Op: MathOp) type { return struct { fn op(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { - const target_type = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type; + const target_type = (try rt.results[try rt.it.next()].getVariant()).Type; const value = try rt.results[try rt.it.next()].getValue(); const op1_value = try rt.results[try rt.it.next()].getValue(); const op2_value = try rt.results[try rt.it.next()].getValue(); const size = sw: switch (target_type) { - .Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type, + .Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type, .Float => |f| if (T == .Float) f.bit_length else return RuntimeError.InvalidSpirV, .Int => |i| if (T == .SInt or T == .UInt) i.bit_length else return RuntimeError.InvalidSpirV, else => return RuntimeError.InvalidSpirV, @@ -489,7 +489,7 @@ fn opAccessChain(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) Runtim const index_count = word_count - 3; for (0..index_count) |_| { const member = &rt.results[try rt.it.next()]; - const member_value = switch (member.variant orelse return RuntimeError.InvalidSpirV) { + const member_value = switch ((try member.getVariant()).*) { .Constant => |c| &c.value, .Variable => |v| &v.value, else => return RuntimeError.InvalidSpirV, @@ -527,7 +527,7 @@ fn opAccessChain(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) Runtim fn opBranch(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { const id = try rt.it.next(); - _ = rt.it.jumpToSourceLocation(switch (rt.results[id].variant orelse return RuntimeError.InvalidSpirV) { + _ = rt.it.jumpToSourceLocation(switch ((try rt.results[id].getVariant()).*) { .Label => |l| l.source_location, else => return RuntimeError.InvalidSpirV, }); @@ -535,11 +535,11 @@ fn opBranch(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { fn opBranchConditional(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { const cond_value = try rt.results[try rt.it.next()].getValue(); - const true_branch = switch (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV) { + const true_branch = switch ((try rt.results[try rt.it.next()].getVariant()).*) { .Label => |l| l.source_location, else => return RuntimeError.InvalidSpirV, }; - const false_branch = switch (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV) { + const false_branch = switch ((try rt.results[try rt.it.next()].getVariant()).*) { .Label => |l| l.source_location, else => return RuntimeError.InvalidSpirV, }; @@ -559,9 +559,9 @@ fn opCompositeConstruct(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) const id = try rt.it.next(); const index_count = word_count - 2; - const target = (rt.results[id].variant orelse return RuntimeError.InvalidSpirV).Constant.value.getCompositeDataOrNull() orelse return RuntimeError.InvalidSpirV; + const target = (try rt.results[id].getVariant()).Constant.value.getCompositeDataOrNull() orelse return RuntimeError.InvalidSpirV; for (target[0..index_count]) |*elem| { - const value = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Constant.value; + const value = (try rt.results[try rt.it.next()].getVariant()).Constant.value; elem.* = value; } } @@ -572,7 +572,7 @@ fn opCompositeExtract(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Ru const composite_id = try rt.it.next(); const index_count = word_count - 3; - var composite = (rt.results[composite_id].variant orelse return RuntimeError.InvalidSpirV).Constant.value; + var composite = (try rt.results[composite_id].getVariant()).Constant.value; for (0..index_count) |_| { const member_id = try rt.it.next(); composite = (composite.getCompositeDataOrNull() orelse return RuntimeError.InvalidSpirV)[member_id]; @@ -580,7 +580,7 @@ fn opCompositeExtract(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Ru rt.results[id].variant = .{ .Constant = .{ .type_word = res_type, - .type = switch (rt.results[res_type].variant orelse return RuntimeError.InvalidSpirV) { + .type = switch ((try rt.results[res_type].getVariant()).*) { .Type => |t| @as(Result.Type, t), else => return RuntimeError.InvalidSpirV, }, @@ -703,12 +703,12 @@ fn opFunctionCall(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) Runtim const ret = &rt.results[try rt.it.next()]; const func = &rt.results[try rt.it.next()]; - for ((func.variant orelse return RuntimeError.InvalidSpirV).Function.params) |param| { + for ((try func.getVariant()).Function.params) |param| { const arg = &rt.results[try rt.it.next()]; - (rt.results[param].variant orelse return RuntimeError.InvalidSpirV).FunctionParameter.value_ptr = try arg.getValue(); + ((try rt.results[param].getVariant()).*).FunctionParameter.value_ptr = try arg.getValue(); } rt.function_stack.items[rt.function_stack.items.len - 1].source_location = rt.it.emitSourceLocation(); - const source_location = (func.variant orelse return RuntimeError.InvalidSpirV).Function.source_location; + const source_location = (try func.getVariant()).Function.source_location; rt.function_stack.append(allocator, .{ .source_location = source_location, .result = func, @@ -736,14 +736,14 @@ fn opFunctionParameter(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeEr target.variant = .{ .FunctionParameter = .{ .type_word = var_type, - .type = switch (resolved.variant orelse return RuntimeError.InvalidSpirV) { + .type = switch ((try resolved.getConstVariant()).*) { .Type => |t| @as(Result.Type, t), else => return RuntimeError.InvalidSpirV, }, .value_ptr = null, }, }; - ((rt.current_function orelse return RuntimeError.InvalidSpirV).variant orelse return RuntimeError.InvalidSpirV).Function.params[rt.current_parameter_index] = id; + (try (rt.current_function orelse return RuntimeError.InvalidSpirV).getVariant()).Function.params[rt.current_parameter_index] = id; rt.current_parameter_index += 1; } @@ -925,7 +925,7 @@ fn opTypeMatrix(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!voi .Type = .{ .Matrix = .{ .column_type_word = column_type_word, - .column_type = switch (rt.mod.results[column_type_word].variant orelse return RuntimeError.InvalidSpirV) { + .column_type = switch ((try rt.mod.results[column_type_word].getVariant()).*) { .Type => |t| @as(Result.Type, t), else => return RuntimeError.InvalidSpirV, }, @@ -994,7 +994,7 @@ fn opTypeVector(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!voi .Type = .{ .Vector = .{ .components_type_word = components_type_word, - .components_type = switch (rt.mod.results[components_type_word].variant orelse return RuntimeError.InvalidSpirV) { + .components_type = switch ((try rt.mod.results[components_type_word].getVariant()).*) { .Type => |t| @as(Result.Type, t), else => return RuntimeError.InvalidSpirV, }, @@ -1030,7 +1030,7 @@ fn opVariable(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) R .Variable = .{ .storage_class = storage_class, .type_word = var_type, - .type = switch (resolved.variant orelse return RuntimeError.InvalidSpirV) { + .type = switch ((try resolved.getConstVariant()).*) { .Type => |t| @as(Result.Type, t), else => return RuntimeError.InvalidSpirV, }, @@ -1083,7 +1083,7 @@ fn setupConstant(allocator: std.mem.Allocator, rt: *Runtime) RuntimeError!*Resul .Constant = .{ .value = try Result.initValue(allocator, member_count, rt.mod.results, resolved), .type_word = res_type, - .type = switch (resolved.variant orelse return RuntimeError.InvalidSpirV) { + .type = switch ((try resolved.getConstVariant()).*) { .Type => |t| @as(Result.Type, t), else => return RuntimeError.InvalidSpirV, }, diff --git a/test/casts.zig b/test/casts.zig index e99d712..100a5e3 100644 --- a/test/casts.zig +++ b/test/casts.zig @@ -9,9 +9,7 @@ test "Primitives casts" { [2]type{ f32, u32 }, [2]type{ f32, i32 }, [2]type{ u32, f32 }, - [2]type{ u32, i32 }, [2]type{ i32, f32 }, - [2]type{ i32, u32 }, [2]type{ f32, f64 }, [2]type{ f64, f32 }, [2]type{ f64, u32 }, @@ -60,3 +58,51 @@ test "Primitives casts" { try case.expectOutput(T[1], 4, code, "color", &.{ expected, expected, expected, expected }); } } + +test "Primitives bitcasts" { + const allocator = std.testing.allocator; + const types = [_][2]type{ + [2]type{ u32, i32 }, + [2]type{ i32, u32 }, + }; + + inline for (types) |T| { + const base = case.random(T[0]); + const expected = @as(T[1], @bitCast(base)); + + const shader = try std.fmt.allocPrint( + allocator, + \\ [nzsl_version("1.1")] + \\ [feature(float64)] + \\ module; + \\ + \\ struct FragOut + \\ {{ + \\ [location(0)] color: vec4[{s}] + \\ }} + \\ + \\ [entry(frag)] + \\ fn main() -> FragOut + \\ {{ + \\ let base = {s}({d}); + \\ let color = {s}(base); + \\ + \\ let output: FragOut; + \\ output.color = vec4[{s}](color, color, color, color); + \\ return output; + \\ }} + , + .{ + @typeName(T[1]), + @typeName(T[0]), + base, + @typeName(T[1]), + @typeName(T[1]), + }, + ); + defer allocator.free(shader); + const code = try compileNzsl(allocator, shader); + defer allocator.free(code); + try case.expectOutput(T[1], 4, code, "color", &.{ expected, expected, expected, expected }); + } +}