From dacd67b858fa917f0793cae6182b3fa95190f412 Mon Sep 17 00:00:00 2001 From: Kbz-8 Date: Tue, 13 Jan 2026 00:06:49 +0100 Subject: [PATCH] adding proper unit testing --- build.zig | 12 ++++++- build.zig.zon | 5 +++ src/opcodes.zig | 83 +++++++++++++++++++++++++++++++------------------ test/basics.zig | 29 +++++++++++++++++ test/maths.zig | 58 ++++++++++++++++++++++++++++++++++ test/root.zig | 43 +++++++++++++++++++++++++ 6 files changed, 198 insertions(+), 32 deletions(-) create mode 100644 test/basics.zig create mode 100644 test/maths.zig create mode 100644 test/root.zig diff --git a/build.zig b/build.zig index b9ed751..6d2689c 100644 --- a/build.zig +++ b/build.zig @@ -46,7 +46,17 @@ pub fn build(b: *std.Build) void { // Zig unit tests setup - const lib_tests = b.addTest(.{ .root_module = mod }); + const nzsl = b.lazyDependency("NZSL", .{}) orelse return; + const test_mod = b.createModule(.{ + .root_source_file = b.path("test/root.zig"), + .target = target, + .optimize = optimize, + .imports = &.{ + .{ .name = "spv", .module = mod }, + .{ .name = "nzsl", .module = nzsl.module("nzigsl") }, + }, + }); + const lib_tests = b.addTest(.{ .root_module = test_mod }); const run_tests = b.addRunArtifact(lib_tests); const test_step = b.step("test", "Run Zig unit tests"); test_step.dependOn(&run_tests.step); diff --git a/build.zig.zon b/build.zig.zon index a4a23e5..4a0581f 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -6,6 +6,11 @@ .url = "git+https://github.com/Kbz-8/pretty#117674465efd4d07d5ae9d9d8ca59c2c323a65ba", .hash = "pretty-0.10.6-Tm65r99UAQDEJMgZysD10qE8dinBHr064fPM6YkxVPfB", }, + .NZSL = .{ // For unit tests + .url = "git+https://github.com/Kbz-8/NZigSL#60a82680901e806f322789b585bf32547f5f5442", + .hash = "NZSL-1.1.1-N0xSVGd6AABOa_qLFkTeeprKLj_2YqayosICqh10_nbB", + .lazy = true, + }, }, .minimum_zig_version = "0.15.2", .paths = .{ diff --git a/src/opcodes.zig b/src/opcodes.zig index 604b20d..0644011 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -24,8 +24,10 @@ pub const SetupDispatcher = block: { .Decorate = opDecorate, .EntryPoint = opEntryPoint, .ExecutionMode = opExecutionMode, + .FMul = autoSetupConstant, .Function = opFunction, .FunctionEnd = opFunctionEnd, + .IMul = autoSetupConstant, .Label = opLabel, .Load = autoSetupConstant, .MemberDecorate = opDecorateMember, @@ -44,7 +46,6 @@ pub const SetupDispatcher = block: { .TypeVector = opTypeVector, .TypeVoid = opTypeVoid, .Variable = opVariable, - .FMul = autoSetupConstant, }); }; @@ -54,10 +55,11 @@ pub const RuntimeDispatcher = block: { .AccessChain = opAccessChain, .CompositeConstruct = opCompositeConstruct, .CompositeExtract = opCompositeExtract, + .FMul = opFMul, + .IMul = opIMul, .Load = opLoad, .Return = opReturn, .Store = opStore, - .FMul = opFMul, }); }; @@ -591,43 +593,62 @@ fn opReturn(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { } fn opFMul(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { - const res_type = try rt.it.next(); - const id = try rt.it.next(); - const op1 = try rt.it.next(); - const op2 = try rt.it.next(); + const target_type = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type; + const value = try rt.results[try rt.it.next()].getValue(); + const op1_value = try rt.results[try rt.it.next()].getValue(); + const op2_value = try rt.results[try rt.it.next()].getValue(); - const target_type = (rt.results[res_type].variant orelse return RuntimeError.InvalidSpirV).Type; - - const target = &rt.results[id]; - const value = try target.getValue(); - - const op1_target = &rt.results[op1]; - const op1_value = try op1_target.getValue(); - - const op2_target = &rt.results[op2]; - const op2_value = try op2_target.getValue(); - - const float_size = sw: switch (target_type) { + const size = sw: switch (target_type) { .Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type, .Float => |f| f.bit_length, else => return RuntimeError.InvalidSpirV, }; - switch (value.*) { - .Float => switch (float_size) { - 16 => value.Float.float16 = op1_value.Float.float16 * op2_value.Float.float16, - 32 => value.Float.float32 = op1_value.Float.float32 * op2_value.Float.float32, - 64 => value.Float.float64 = op1_value.Float.float64 * op2_value.Float.float64, - else => return RuntimeError.InvalidSpirV, - }, - .Vector => |vec| for (vec, op1_value.Vector, op2_value.Vector) |*val, op1_v, op2_v| { - switch (float_size) { - 16 => val.Float.float16 = op1_v.Float.float16 * op2_v.Float.float16, - 32 => val.Float.float32 = op1_v.Float.float32 * op2_v.Float.float32, - 64 => val.Float.float64 = op1_v.Float.float64 * op2_v.Float.float64, + const operator = struct { + fn process(bit_count: SpvWord, v: *Result.Value, op1_v: *const Result.Value, op2_v: *const Result.Value) RuntimeError!void { + switch (bit_count) { + 16 => v.Float.float16 = op1_v.Float.float16 * op2_v.Float.float16, + 32 => v.Float.float32 = op1_v.Float.float32 * op2_v.Float.float32, + 64 => v.Float.float64 = op1_v.Float.float64 * op2_v.Float.float64, else => return RuntimeError.InvalidSpirV, } - }, + } + }; + + switch (value.*) { + .Float => try operator.process(size, value, op1_value, op2_value), + .Vector => |vec| for (vec, op1_value.Vector, op2_value.Vector) |*val, op1_v, op2_v| try operator.process(size, val, &op1_v, &op2_v), + else => return RuntimeError.InvalidSpirV, + } +} + +fn opIMul(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + const target_type = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type; + const value = try rt.results[try rt.it.next()].getValue(); + const op1_value = try rt.results[try rt.it.next()].getValue(); + const op2_value = try rt.results[try rt.it.next()].getValue(); + + const size = sw: switch (target_type) { + .Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type, + .Int => |i| i.bit_length, + else => return RuntimeError.InvalidSpirV, + }; + + const operator = struct { + fn process(bit_count: SpvWord, v: *Result.Value, op1_v: *const Result.Value, op2_v: *const Result.Value) RuntimeError!void { + switch (bit_count) { + 8 => v.Int.sint8 = op1_v.Int.sint8 * op2_v.Int.sint8, + 16 => v.Int.sint16 = op1_v.Int.sint16 * op2_v.Int.sint16, + 32 => v.Int.sint32 = op1_v.Int.sint32 * op2_v.Int.sint32, + 64 => v.Int.sint64 = op1_v.Int.sint64 * op2_v.Int.sint64, + else => return RuntimeError.InvalidSpirV, + } + } + }; + + switch (value.*) { + .Int => try operator.process(size, value, op1_value, op2_value), + .Vector => |vec| for (vec, op1_value.Vector, op2_value.Vector) |*val, op1_v, op2_v| try operator.process(size, val, &op1_v, &op2_v), else => return RuntimeError.InvalidSpirV, } } diff --git a/test/basics.zig b/test/basics.zig new file mode 100644 index 0000000..5e1b76f --- /dev/null +++ b/test/basics.zig @@ -0,0 +1,29 @@ +const std = @import("std"); +const root = @import("root.zig"); +const compileNzsl = root.compileNzsl; +const case = root.case; + +test "FMul vec4[f32]" { + const allocator = std.testing.allocator; + const shader = + \\ [nzsl_version("1.1")] + \\ module; + \\ + \\ struct FragOut + \\ { + \\ [location(0)] color: vec4[f32] + \\ } + \\ + \\ [entry(frag)] + \\ fn main() -> FragOut + \\ { + \\ let output: FragOut; + \\ output.color = vec4[f32](4.0, 3.0, 2.0, 1.0); + \\ return output; + \\ } + ; + const code = try compileNzsl(allocator, shader); + defer allocator.free(code); + + try case.expectOutput(f32, code, "color", &.{ 4, 3, 2, 1 }); +} diff --git a/test/maths.zig b/test/maths.zig new file mode 100644 index 0000000..fde85f9 --- /dev/null +++ b/test/maths.zig @@ -0,0 +1,58 @@ +const std = @import("std"); +const root = @import("root.zig"); +const compileNzsl = root.compileNzsl; +const case = root.case; + +test "FMul vec4[f32]" { + const allocator = std.testing.allocator; + const shader = + \\ [nzsl_version("1.1")] + \\ module; + \\ + \\ struct FragOut + \\ { + \\ [location(0)] color: vec4[f32] + \\ } + \\ + \\ [entry(frag)] + \\ fn main() -> FragOut + \\ { + \\ let ratio = vec4[f32](2.0, 2.0, 8.0, 0.25); + \\ + \\ let output: FragOut; + \\ output.color = vec4[f32](4.0, 3.0, 2.0, 1.0) * ratio; + \\ return output; + \\ } + ; + const code = try compileNzsl(allocator, shader); + defer allocator.free(code); + + try case.expectOutput(f32, code, "color", &.{ 8, 6, 16, 0.25 }); +} + +test "IMul vec4[i32]" { + const allocator = std.testing.allocator; + const shader = + \\ [nzsl_version("1.1")] + \\ module; + \\ + \\ struct FragOut + \\ { + \\ [location(0)] color: vec4[i32] + \\ } + \\ + \\ [entry(frag)] + \\ fn main() -> FragOut + \\ { + \\ let ratio = vec4[i32](2, 2, 8, 25); + \\ + \\ let output: FragOut; + \\ output.color = vec4[i32](4, 3, 2, 1) * ratio; + \\ return output; + \\ } + ; + const code = try compileNzsl(allocator, shader); + defer allocator.free(code); + + try case.expectOutput(i32, code, "color", &.{ 8, 6, 16, 25 }); +} diff --git a/test/root.zig b/test/root.zig new file mode 100644 index 0000000..d82c19e --- /dev/null +++ b/test/root.zig @@ -0,0 +1,43 @@ +const std = @import("std"); +const spv = @import("spv"); +const nzsl = @import("nzsl"); + +pub fn compileNzsl(allocator: std.mem.Allocator, source: []const u8) ![]const u32 { + const module = try nzsl.parser.parseSource(source); + defer module.deinit(); + + const params = try nzsl.BackendParameters.init(); + defer params.deinit(); + params.setDebugLevel(.full); + + const writer = try nzsl.SpirvWriter.init(); + defer writer.deinit(); + + const output = try writer.generate(module, params); + defer output.deinit(); + + return allocator.dupe(u32, output.getCode()); +} + +pub const case = struct { + pub fn expectOutput(comptime T: type, source: []const u32, output_name: []const u8, comptime expected: []const T) !void { + const allocator = std.testing.allocator; + + var module = try spv.Module.init(allocator, source); + defer module.deinit(allocator); + + var rt = try spv.Runtime.init(allocator, &module); + defer rt.deinit(allocator); + + try rt.callEntryPoint(allocator, try rt.getEntryPointByName("main")); + var output: [expected.len]T = undefined; + try rt.readOutput(T, output[0..output.len], try rt.getResultByName(output_name)); + + try std.testing.expectEqualSlices(T, expected, &output); + } +}; + +test { + std.testing.refAllDecls(@import("basics.zig")); + std.testing.refAllDecls(@import("maths.zig")); +}