adding more mathematical operations and unit tests with them
Some checks failed
Build / build (push) Successful in 56s
Test / build (push) Failing after 4m15s

This commit is contained in:
2026-01-15 00:35:13 +01:00
parent 88e847e2d9
commit e570b7f19d
4 changed files with 228 additions and 74 deletions

View File

@@ -21,6 +21,7 @@ pub const RuntimeError = error{
Killed, Killed,
InvalidEntryPoint, InvalidEntryPoint,
ToDo, ToDo,
DivisionByZero,
}; };
pub const Function = struct { pub const Function = struct {

View File

@@ -19,6 +19,14 @@ const MathType = enum {
UInt, UInt,
}; };
const MathOp = enum {
Add,
Sub,
Mul,
Div,
Mod,
};
pub const OpCodeFunc = *const fn (std.mem.Allocator, SpvWord, *Runtime) RuntimeError!void; pub const OpCodeFunc = *const fn (std.mem.Allocator, SpvWord, *Runtime) RuntimeError!void;
pub const SetupDispatcher = block: { pub const SetupDispatcher = block: {
@@ -30,10 +38,8 @@ pub const SetupDispatcher = block: {
.Decorate = opDecorate, .Decorate = opDecorate,
.EntryPoint = opEntryPoint, .EntryPoint = opEntryPoint,
.ExecutionMode = opExecutionMode, .ExecutionMode = opExecutionMode,
.FMul = autoSetupConstant,
.Function = opFunction, .Function = opFunction,
.FunctionEnd = opFunctionEnd, .FunctionEnd = opFunctionEnd,
.IMul = autoSetupConstant,
.Label = opLabel, .Label = opLabel,
.Load = autoSetupConstant, .Load = autoSetupConstant,
.MemberDecorate = opDecorateMember, .MemberDecorate = opDecorateMember,
@@ -52,6 +58,18 @@ pub const SetupDispatcher = block: {
.TypeVector = opTypeVector, .TypeVector = opTypeVector,
.TypeVoid = opTypeVoid, .TypeVoid = opTypeVoid,
.Variable = opVariable, .Variable = opVariable,
.FAdd = autoSetupConstant,
.FDiv = autoSetupConstant,
.FMod = autoSetupConstant,
.FMul = autoSetupConstant,
.FSub = autoSetupConstant,
.IAdd = autoSetupConstant,
.IMul = autoSetupConstant,
.ISub = autoSetupConstant,
.SDiv = autoSetupConstant,
.SMod = autoSetupConstant,
.UDiv = autoSetupConstant,
.UMod = autoSetupConstant,
}); });
}; };
@@ -61,11 +79,21 @@ pub const RuntimeDispatcher = block: {
.AccessChain = opAccessChain, .AccessChain = opAccessChain,
.CompositeConstruct = opCompositeConstruct, .CompositeConstruct = opCompositeConstruct,
.CompositeExtract = opCompositeExtract, .CompositeExtract = opCompositeExtract,
.FMul = maths(.Float).opMul, .FAdd = maths(.Float, .Add).op,
.IMul = maths(.SInt).opMul, .FDiv = maths(.Float, .Div).op,
.FMod = maths(.Float, .Mod).op,
.FMul = maths(.Float, .Mul).op,
.FSub = maths(.Float, .Sub).op,
.IAdd = maths(.SInt, .Add).op,
.IMul = maths(.SInt, .Mul).op,
.ISub = maths(.SInt, .Sub).op,
.Load = opLoad, .Load = opLoad,
.Return = opReturn, .Return = opReturn,
.SDiv = maths(.SInt, .Div).op,
.SMod = maths(.SInt, .Mod).op,
.Store = opStore, .Store = opStore,
.UDiv = maths(.UInt, .Div).op,
.UMod = maths(.UInt, .Mod).op,
}); });
}; };
@@ -387,14 +415,18 @@ fn opConstant(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) R
switch (target.variant.?.Constant) { switch (target.variant.?.Constant) {
.Int => |*i| { .Int => |*i| {
if (word_count - 2 != 1) { if (word_count - 2 != 1) {
i.uint64 = @as(u64, try rt.it.next()) | (@as(u64, try rt.it.next()) >> 32); const low = @as(u64, try rt.it.next());
const high = @as(u64, try rt.it.next());
i.uint64 = (high << 32) | low;
} else { } else {
i.uint32 = try rt.it.next(); i.uint32 = try rt.it.next();
} }
}, },
.Float => |*f| { .Float => |*f| {
if (word_count - 2 != 1) { if (word_count - 2 != 1) {
f.float64 = @bitCast(@as(u64, try rt.it.next()) | (@as(u64, try rt.it.next()) >> 32)); const low = @as(u64, try rt.it.next());
const high = @as(u64, try rt.it.next());
f.float64 = @bitCast((high << 32) | low);
} else { } else {
f.float32 = @bitCast(try rt.it.next()); f.float32 = @bitCast(try rt.it.next());
} }
@@ -598,9 +630,9 @@ fn opReturn(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void {
} }
} }
fn maths(comptime T: MathType) type { fn maths(comptime T: MathType, comptime Op: MathOp) type {
return struct { return struct {
fn opMul(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { fn op(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void {
const target_type = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type; const target_type = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type;
const value = try rt.results[try rt.it.next()].getValue(); const value = try rt.results[try rt.it.next()].getValue();
const op1_value = try rt.results[try rt.it.next()].getValue(); const op1_value = try rt.results[try rt.it.next()].getValue();
@@ -614,26 +646,52 @@ fn maths(comptime T: MathType) type {
}; };
const operator = struct { const operator = struct {
fn operation(comptime TT: type, op1: TT, op2: TT) RuntimeError!TT {
return switch (Op) {
.Add => if (@typeInfo(TT) == .int) @addWithOverflow(op1, op2)[0] else op1 + op2,
.Sub => if (@typeInfo(TT) == .int) @subWithOverflow(op1, op2)[0] else op1 - op2,
.Mul => if (@typeInfo(TT) == .int) @mulWithOverflow(op1, op2)[0] else op1 * op2,
.Div => blk: {
if (op2 == 0) return RuntimeError.DivisionByZero;
break :blk if (@typeInfo(TT) == .int) @divTrunc(op1, op2) else op1 / op2;
},
.Mod => blk: {
if (op2 == 0) return RuntimeError.DivisionByZero;
break :blk @mod(op1, op2);
},
};
}
fn process(bit_count: SpvWord, v: *Result.Value, op1_v: *const Result.Value, op2_v: *const Result.Value) RuntimeError!void { fn process(bit_count: SpvWord, v: *Result.Value, op1_v: *const Result.Value, op2_v: *const Result.Value) RuntimeError!void {
switch (T) { switch (T) {
.Float => switch (bit_count) { .Float => switch (bit_count) {
16 => v.Float.float16 = op1_v.Float.float16 * op2_v.Float.float16, inline 16, 32, 64 => |i| @field(v.Float, std.fmt.comptimePrint("float{}", .{i})) = try operation(
32 => v.Float.float32 = op1_v.Float.float32 * op2_v.Float.float32, @Type(.{ .float = .{ .bits = i } }),
64 => v.Float.float64 = op1_v.Float.float64 * op2_v.Float.float64, @field(op1_v.Float, std.fmt.comptimePrint("float{}", .{i})),
@field(op2_v.Float, std.fmt.comptimePrint("float{}", .{i})),
),
else => return RuntimeError.InvalidSpirV, else => return RuntimeError.InvalidSpirV,
}, },
.SInt => switch (bit_count) { .SInt => switch (bit_count) {
8 => v.Int.sint8 = @mulWithOverflow(op1_v.Int.sint8, op2_v.Int.sint8)[0], inline 8, 16, 32, 64 => |i| @field(v.Int, std.fmt.comptimePrint("sint{}", .{i})) = try operation(
16 => v.Int.sint16 = @mulWithOverflow(op1_v.Int.sint16, op2_v.Int.sint16)[0], @Type(.{ .int = .{
32 => v.Int.sint32 = @mulWithOverflow(op1_v.Int.sint32, op2_v.Int.sint32)[0], .signedness = .signed,
64 => v.Int.sint64 = @mulWithOverflow(op1_v.Int.sint64, op2_v.Int.sint64)[0], .bits = i,
} }),
@field(op1_v.Int, std.fmt.comptimePrint("sint{}", .{i})),
@field(op2_v.Int, std.fmt.comptimePrint("sint{}", .{i})),
),
else => return RuntimeError.InvalidSpirV, else => return RuntimeError.InvalidSpirV,
}, },
.UInt => switch (bit_count) { .UInt => switch (bit_count) {
8 => v.Int.uint8 = @mulWithOverflow(op1_v.Int.uint8, op2_v.Int.uint8)[0], inline 8, 16, 32, 64 => |i| @field(v.Int, std.fmt.comptimePrint("uint{}", .{i})) = try operation(
16 => v.Int.uint16 = @mulWithOverflow(op1_v.Int.uint16, op2_v.Int.uint16)[0], @Type(.{ .int = .{
32 => v.Int.uint32 = @mulWithOverflow(op1_v.Int.uint32, op2_v.Int.uint32)[0], .signedness = .unsigned,
64 => v.Int.uint64 = @mulWithOverflow(op1_v.Int.uint64, op2_v.Int.uint64)[0], .bits = i,
} }),
@field(op1_v.Int, std.fmt.comptimePrint("uint{}", .{i})),
@field(op2_v.Int, std.fmt.comptimePrint("uint{}", .{i})),
),
else => return RuntimeError.InvalidSpirV, else => return RuntimeError.InvalidSpirV,
}, },
} }

View File

@@ -3,7 +3,7 @@ const root = @import("root.zig");
const compileNzsl = root.compileNzsl; const compileNzsl = root.compileNzsl;
const case = root.case; const case = root.case;
test "FMul vec4[f32]" { test "Simple fragment shader" {
const allocator = std.testing.allocator; const allocator = std.testing.allocator;
const shader = const shader =
\\ [nzsl_version("1.1")] \\ [nzsl_version("1.1")]

View File

@@ -3,62 +3,157 @@ const root = @import("root.zig");
const compileNzsl = root.compileNzsl; const compileNzsl = root.compileNzsl;
const case = root.case; const case = root.case;
test "Mul vec4" { const Operations = enum {
const allocator = std.testing.allocator; Add,
const types = [_]type{ Sub,
f32, Mul,
//f64, Div,
i32, Mod,
u32, };
fn Vec(comptime len: usize, comptime T: type) type {
return struct {
const Self = @This();
val: @Vector(len, T),
pub fn format(self: *const Self, w: *std.Io.Writer) std.Io.Writer.Error!void {
inline for (0..len) |i| {
try w.print("{d}", .{self.val[i]});
if (i < len - 1) try w.writeAll(", ");
}
}
}; };
}
inline for (types) |T| { // Tests all mathematical operation on all NZSL supported primitive types
const base_color = case.random(@Vector(4, T)); test "Maths primitives" {
const ratio = case.random(@Vector(4, T)); const allocator = std.testing.allocator;
const expected = switch (@typeInfo(T)) { const types = [_]type{ f32, f64, i32, u32 };
.float => base_color * ratio, var operations = std.EnumMap(Operations, u8).init(.{
.int => @mulWithOverflow(base_color, ratio)[0], .Add = '+',
else => unreachable, .Sub = '-',
}; .Mul = '*',
.Div = '/',
.Mod = '%',
});
const shader = try std.fmt.allocPrint( var it = operations.iterator();
allocator, while (it.next()) |op| {
\\ [nzsl_version("1.1")] inline for (types) |T| {
\\ [feature(float64)] const base: T = case.random(T);
\\ module; const ratio: T = case.random(T);
\\ const expected = switch (op.key) {
\\ struct FragOut .Add => if (@typeInfo(T) == .int) @addWithOverflow(base, ratio)[0] else base + ratio,
\\ {{ .Sub => if (@typeInfo(T) == .int) @subWithOverflow(base, ratio)[0] else base - ratio,
\\ [location(0)] color: vec4[{s}] .Mul => if (@typeInfo(T) == .int) @mulWithOverflow(base, ratio)[0] else base * ratio,
\\ }} .Div => if (@typeInfo(T) == .int) @divTrunc(base, ratio) else base / ratio,
\\ .Mod => @mod(base, ratio),
\\ [entry(frag)] };
\\ fn main() -> FragOut
\\ {{ const shader = try std.fmt.allocPrint(
\\ let ratio = vec4[{s}]({d}, {d}, {d}, {d}); allocator,
\\ \\ [nzsl_version("1.1")]
\\ let output: FragOut; \\ [feature(float64)]
\\ output.color = vec4[{s}]({d}, {d}, {d}, {d}) * ratio; \\ module;
\\ return output; \\
\\ }} \\ struct FragOut
, \\ {{
.{ \\ [location(0)] color: vec4[{s}]
@typeName(T), \\ }}
@typeName(T), \\
ratio[0], \\ [entry(frag)]
ratio[1], \\ fn main() -> FragOut
ratio[2], \\ {{
ratio[3], \\ let ratio: {s} = {d};
@typeName(T), \\ let base: {s} = {d};
base_color[0], \\ let color = base {c} ratio;
base_color[1], \\
base_color[2], \\ let output: FragOut;
base_color[3], \\ output.color = vec4[{s}](color, color, color, color);
}, \\ return output;
); \\ }}
defer allocator.free(shader); ,
const code = try compileNzsl(allocator, shader); .{
defer allocator.free(code); @typeName(T),
try case.expectOutput(T, 4, code, "color", &@as([4]T, expected)); @typeName(T),
ratio,
@typeName(T),
base,
op.value.*,
@typeName(T),
},
);
defer allocator.free(shader);
const code = try compileNzsl(allocator, shader);
defer allocator.free(code);
try case.expectOutput(T, 4, code, "color", &.{ expected, expected, expected, expected });
}
}
}
// Tests all mathematical operation on vec2/3/4 with all NZSL supported primitive types
test "Maths vectors" {
const allocator = std.testing.allocator;
const types = [_]type{ f32, f64, i32, u32 };
var operations = std.EnumMap(Operations, u8).init(.{
.Add = '+',
.Sub = '-',
.Mul = '*',
.Div = '/',
.Mod = '%',
});
var it = operations.iterator();
while (it.next()) |op| {
inline for (2..5) |L| {
inline for (types) |T| {
const base_color: Vec(L, T) = .{ .val = case.random(@Vector(L, T)) };
const ratio: Vec(L, T) = .{ .val = case.random(@Vector(L, T)) };
const expected = switch (op.key) {
.Add => if (@typeInfo(T) == .int) @addWithOverflow(base_color.val, ratio.val)[0] else base_color.val + ratio.val,
.Sub => if (@typeInfo(T) == .int) @subWithOverflow(base_color.val, ratio.val)[0] else base_color.val - ratio.val,
.Mul => if (@typeInfo(T) == .int) @mulWithOverflow(base_color.val, ratio.val)[0] else base_color.val * ratio.val,
.Div => if (@typeInfo(T) == .int) @divTrunc(base_color.val, ratio.val) else base_color.val / ratio.val,
.Mod => @mod(base_color.val, ratio.val),
};
const shader = try std.fmt.allocPrint(
allocator,
\\ [nzsl_version("1.1")]
\\ [feature(float64)]
\\ module;
\\
\\ struct FragOut
\\ {{
\\ [location(0)] color: vec{d}[{s}]
\\ }}
\\
\\ [entry(frag)]
\\ fn main() -> FragOut
\\ {{
\\ let ratio = vec{d}[{s}]({f});
\\
\\ let output: FragOut;
\\ output.color = vec{d}[{s}]({f}) {c} ratio;
\\ return output;
\\ }}
,
.{
L,
@typeName(T),
L,
@typeName(T),
ratio,
L,
@typeName(T),
base_color,
op.value.*,
},
);
defer allocator.free(shader);
const code = try compileNzsl(allocator, shader);
defer allocator.free(code);
try case.expectOutput(T, L, code, "color", &@as([L]T, expected));
}
}
} }
} }