fixing casts unit tests
All checks were successful
Build / build (push) Successful in 1m36s
Test / build (push) Successful in 5m39s

This commit is contained in:
2026-01-18 00:30:15 +01:00
parent 8b0b0a72ae
commit 8bdea7b1fc
3 changed files with 98 additions and 42 deletions

View File

@@ -185,15 +185,7 @@ pub const Value = union(Type) {
}
};
const Self = @This();
name: ?[]const u8,
decorations: std.ArrayList(Decoration),
parent: ?*const Self,
variant: ?union(Variant) {
pub const VariantData = union(Variant) {
String: []const u8,
Extension: struct {},
Type: union(Type) {
@@ -265,7 +257,17 @@ variant: ?union(Variant) {
Label: struct {
source_location: usize,
},
},
};
const Self = @This();
name: ?[]const u8,
decorations: std.ArrayList(Decoration),
parent: ?*const Self,
variant: ?VariantData,
pub fn init() Self {
return .{
@@ -305,7 +307,7 @@ pub fn deinit(self: *Self, allocator: std.mem.Allocator) void {
}
pub fn getValueTypeWord(self: *Self) RuntimeError!SpvWord {
return switch (self.variant orelse return RuntimeError.InvalidSpirV) {
return switch ((try self.getVariant()).*) {
.Variable => |v| v.type_word,
.Constant => |c| c.type_word,
.AccessChain => |*a| a.target,
@@ -315,7 +317,7 @@ pub fn getValueTypeWord(self: *Self) RuntimeError!SpvWord {
}
pub fn getValueType(self: *Self) RuntimeError!Type {
return switch (self.variant orelse return RuntimeError.InvalidSpirV) {
return switch ((try self.getVariant()).*) {
.Variable => |v| v.type,
.Constant => |c| c.type,
.FunctionParameter => |p| p.type,
@@ -324,7 +326,7 @@ pub fn getValueType(self: *Self) RuntimeError!Type {
}
pub fn getValue(self: *Self) RuntimeError!*Value {
return switch (self.variant orelse return RuntimeError.InvalidSpirV) {
return switch ((try self.getVariant()).*) {
.Variable => |*v| &v.value,
.Constant => |*c| &c.value,
.AccessChain => |*a| &a.value,
@@ -333,6 +335,14 @@ pub fn getValue(self: *Self) RuntimeError!*Value {
};
}
pub inline fn getVariant(self: *Self) RuntimeError!*VariantData {
return &(self.variant orelse return RuntimeError.InvalidSpirV);
}
pub inline fn getConstVariant(self: *const Self) RuntimeError!*const VariantData {
return &(self.variant orelse return RuntimeError.InvalidSpirV);
}
/// Performs a deep copy
pub fn dupe(self: *const Self, allocator: std.mem.Allocator) RuntimeError!Self {
return .{

View File

@@ -186,8 +186,8 @@ pub const RuntimeDispatcher = block: {
fn CondEngine(comptime T: ValueType, comptime Op: CondOp) type {
return struct {
fn op(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void {
sw: switch ((rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type) {
.Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type,
sw: switch ((try rt.results[try rt.it.next()].getVariant()).Type) {
.Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type,
.Bool => {},
else => return RuntimeError.InvalidSpirV,
}
@@ -198,8 +198,8 @@ fn CondEngine(comptime T: ValueType, comptime Op: CondOp) type {
const op1_value = try op1_result.getValue();
const op2_value = try rt.results[try rt.it.next()].getValue();
const size = sw: switch ((rt.results[op1_type].variant orelse return RuntimeError.InvalidSpirV).Type) {
.Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type,
const size = sw: switch ((try rt.results[op1_type].getVariant()).Type) {
.Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type,
.Float => |f| if (T == .Float) f.bit_length else return RuntimeError.InvalidSpirV,
.Int => |i| if (T == .SInt or T == .UInt) i.bit_length else return RuntimeError.InvalidSpirV,
else => return RuntimeError.InvalidSpirV,
@@ -246,21 +246,21 @@ fn CondEngine(comptime T: ValueType, comptime Op: CondOp) type {
fn ConversionEngine(comptime From: ValueType, comptime To: ValueType) type {
return struct {
fn op(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void {
const target_type = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type;
const target_type = (try rt.results[try rt.it.next()].getVariant()).Type;
const value = try rt.results[try rt.it.next()].getValue();
const op_result = &rt.results[try rt.it.next()];
const op_type = try op_result.getValueTypeWord();
const op_value = try op_result.getValue();
const from_size = sw: switch ((rt.results[op_type].variant orelse return RuntimeError.InvalidSpirV).Type) {
.Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type,
const from_size = sw: switch ((try rt.results[op_type].getVariant()).Type) {
.Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type,
.Float => |f| if (From == .Float) f.bit_length else return RuntimeError.InvalidSpirV,
.Int => |i| if (From == .SInt or From == .UInt) i.bit_length else return RuntimeError.InvalidSpirV,
else => return RuntimeError.InvalidSpirV,
};
const to_size = sw: switch (target_type) {
.Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type,
.Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type,
.Float => |f| if (To == .Float) f.bit_length else return RuntimeError.InvalidSpirV,
.Int => |i| if (To == .SInt or To == .UInt) i.bit_length else return RuntimeError.InvalidSpirV,
else => return RuntimeError.InvalidSpirV,
@@ -306,13 +306,13 @@ fn ConversionEngine(comptime From: ValueType, comptime To: ValueType) type {
fn MathEngine(comptime T: ValueType, comptime Op: MathOp) type {
return struct {
fn op(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void {
const target_type = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type;
const target_type = (try rt.results[try rt.it.next()].getVariant()).Type;
const value = try rt.results[try rt.it.next()].getValue();
const op1_value = try rt.results[try rt.it.next()].getValue();
const op2_value = try rt.results[try rt.it.next()].getValue();
const size = sw: switch (target_type) {
.Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type,
.Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type,
.Float => |f| if (T == .Float) f.bit_length else return RuntimeError.InvalidSpirV,
.Int => |i| if (T == .SInt or T == .UInt) i.bit_length else return RuntimeError.InvalidSpirV,
else => return RuntimeError.InvalidSpirV,
@@ -489,7 +489,7 @@ fn opAccessChain(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) Runtim
const index_count = word_count - 3;
for (0..index_count) |_| {
const member = &rt.results[try rt.it.next()];
const member_value = switch (member.variant orelse return RuntimeError.InvalidSpirV) {
const member_value = switch ((try member.getVariant()).*) {
.Constant => |c| &c.value,
.Variable => |v| &v.value,
else => return RuntimeError.InvalidSpirV,
@@ -527,7 +527,7 @@ fn opAccessChain(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) Runtim
fn opBranch(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void {
const id = try rt.it.next();
_ = rt.it.jumpToSourceLocation(switch (rt.results[id].variant orelse return RuntimeError.InvalidSpirV) {
_ = rt.it.jumpToSourceLocation(switch ((try rt.results[id].getVariant()).*) {
.Label => |l| l.source_location,
else => return RuntimeError.InvalidSpirV,
});
@@ -535,11 +535,11 @@ fn opBranch(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void {
fn opBranchConditional(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void {
const cond_value = try rt.results[try rt.it.next()].getValue();
const true_branch = switch (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV) {
const true_branch = switch ((try rt.results[try rt.it.next()].getVariant()).*) {
.Label => |l| l.source_location,
else => return RuntimeError.InvalidSpirV,
};
const false_branch = switch (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV) {
const false_branch = switch ((try rt.results[try rt.it.next()].getVariant()).*) {
.Label => |l| l.source_location,
else => return RuntimeError.InvalidSpirV,
};
@@ -559,9 +559,9 @@ fn opCompositeConstruct(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime)
const id = try rt.it.next();
const index_count = word_count - 2;
const target = (rt.results[id].variant orelse return RuntimeError.InvalidSpirV).Constant.value.getCompositeDataOrNull() orelse return RuntimeError.InvalidSpirV;
const target = (try rt.results[id].getVariant()).Constant.value.getCompositeDataOrNull() orelse return RuntimeError.InvalidSpirV;
for (target[0..index_count]) |*elem| {
const value = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Constant.value;
const value = (try rt.results[try rt.it.next()].getVariant()).Constant.value;
elem.* = value;
}
}
@@ -572,7 +572,7 @@ fn opCompositeExtract(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Ru
const composite_id = try rt.it.next();
const index_count = word_count - 3;
var composite = (rt.results[composite_id].variant orelse return RuntimeError.InvalidSpirV).Constant.value;
var composite = (try rt.results[composite_id].getVariant()).Constant.value;
for (0..index_count) |_| {
const member_id = try rt.it.next();
composite = (composite.getCompositeDataOrNull() orelse return RuntimeError.InvalidSpirV)[member_id];
@@ -580,7 +580,7 @@ fn opCompositeExtract(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Ru
rt.results[id].variant = .{
.Constant = .{
.type_word = res_type,
.type = switch (rt.results[res_type].variant orelse return RuntimeError.InvalidSpirV) {
.type = switch ((try rt.results[res_type].getVariant()).*) {
.Type => |t| @as(Result.Type, t),
else => return RuntimeError.InvalidSpirV,
},
@@ -703,12 +703,12 @@ fn opFunctionCall(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) Runtim
const ret = &rt.results[try rt.it.next()];
const func = &rt.results[try rt.it.next()];
for ((func.variant orelse return RuntimeError.InvalidSpirV).Function.params) |param| {
for ((try func.getVariant()).Function.params) |param| {
const arg = &rt.results[try rt.it.next()];
(rt.results[param].variant orelse return RuntimeError.InvalidSpirV).FunctionParameter.value_ptr = try arg.getValue();
((try rt.results[param].getVariant()).*).FunctionParameter.value_ptr = try arg.getValue();
}
rt.function_stack.items[rt.function_stack.items.len - 1].source_location = rt.it.emitSourceLocation();
const source_location = (func.variant orelse return RuntimeError.InvalidSpirV).Function.source_location;
const source_location = (try func.getVariant()).Function.source_location;
rt.function_stack.append(allocator, .{
.source_location = source_location,
.result = func,
@@ -736,14 +736,14 @@ fn opFunctionParameter(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeEr
target.variant = .{
.FunctionParameter = .{
.type_word = var_type,
.type = switch (resolved.variant orelse return RuntimeError.InvalidSpirV) {
.type = switch ((try resolved.getConstVariant()).*) {
.Type => |t| @as(Result.Type, t),
else => return RuntimeError.InvalidSpirV,
},
.value_ptr = null,
},
};
((rt.current_function orelse return RuntimeError.InvalidSpirV).variant orelse return RuntimeError.InvalidSpirV).Function.params[rt.current_parameter_index] = id;
(try (rt.current_function orelse return RuntimeError.InvalidSpirV).getVariant()).Function.params[rt.current_parameter_index] = id;
rt.current_parameter_index += 1;
}
@@ -925,7 +925,7 @@ fn opTypeMatrix(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!voi
.Type = .{
.Matrix = .{
.column_type_word = column_type_word,
.column_type = switch (rt.mod.results[column_type_word].variant orelse return RuntimeError.InvalidSpirV) {
.column_type = switch ((try rt.mod.results[column_type_word].getVariant()).*) {
.Type => |t| @as(Result.Type, t),
else => return RuntimeError.InvalidSpirV,
},
@@ -994,7 +994,7 @@ fn opTypeVector(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!voi
.Type = .{
.Vector = .{
.components_type_word = components_type_word,
.components_type = switch (rt.mod.results[components_type_word].variant orelse return RuntimeError.InvalidSpirV) {
.components_type = switch ((try rt.mod.results[components_type_word].getVariant()).*) {
.Type => |t| @as(Result.Type, t),
else => return RuntimeError.InvalidSpirV,
},
@@ -1030,7 +1030,7 @@ fn opVariable(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) R
.Variable = .{
.storage_class = storage_class,
.type_word = var_type,
.type = switch (resolved.variant orelse return RuntimeError.InvalidSpirV) {
.type = switch ((try resolved.getConstVariant()).*) {
.Type => |t| @as(Result.Type, t),
else => return RuntimeError.InvalidSpirV,
},
@@ -1083,7 +1083,7 @@ fn setupConstant(allocator: std.mem.Allocator, rt: *Runtime) RuntimeError!*Resul
.Constant = .{
.value = try Result.initValue(allocator, member_count, rt.mod.results, resolved),
.type_word = res_type,
.type = switch (resolved.variant orelse return RuntimeError.InvalidSpirV) {
.type = switch ((try resolved.getConstVariant()).*) {
.Type => |t| @as(Result.Type, t),
else => return RuntimeError.InvalidSpirV,
},

View File

@@ -9,9 +9,7 @@ test "Primitives casts" {
[2]type{ f32, u32 },
[2]type{ f32, i32 },
[2]type{ u32, f32 },
[2]type{ u32, i32 },
[2]type{ i32, f32 },
[2]type{ i32, u32 },
[2]type{ f32, f64 },
[2]type{ f64, f32 },
[2]type{ f64, u32 },
@@ -60,3 +58,51 @@ test "Primitives casts" {
try case.expectOutput(T[1], 4, code, "color", &.{ expected, expected, expected, expected });
}
}
test "Primitives bitcasts" {
const allocator = std.testing.allocator;
const types = [_][2]type{
[2]type{ u32, i32 },
[2]type{ i32, u32 },
};
inline for (types) |T| {
const base = case.random(T[0]);
const expected = @as(T[1], @bitCast(base));
const shader = try std.fmt.allocPrint(
allocator,
\\ [nzsl_version("1.1")]
\\ [feature(float64)]
\\ module;
\\
\\ struct FragOut
\\ {{
\\ [location(0)] color: vec4[{s}]
\\ }}
\\
\\ [entry(frag)]
\\ fn main() -> FragOut
\\ {{
\\ let base = {s}({d});
\\ let color = {s}(base);
\\
\\ let output: FragOut;
\\ output.color = vec4[{s}](color, color, color, color);
\\ return output;
\\ }}
,
.{
@typeName(T[1]),
@typeName(T[0]),
base,
@typeName(T[1]),
@typeName(T[1]),
},
);
defer allocator.free(shader);
const code = try compileNzsl(allocator, shader);
defer allocator.free(code);
try case.expectOutput(T[1], 4, code, "color", &.{ expected, expected, expected, expected });
}
}