diff --git a/test/arrays.zig b/test/arrays.zig index bf5d592..0a7ccb7 100644 --- a/test/arrays.zig +++ b/test/arrays.zig @@ -26,5 +26,10 @@ test "Simple array" { const code = try compileNzsl(allocator, shader); defer allocator.free(code); - try case.expectOutput(f32, 4, code, "color", &.{ 4, 3, 2, 1 }); + try case.expect(.{ + .source = code, + .expected_outputs = &.{ + std.mem.asBytes(&[_]f32{ 4, 3, 2, 1 }), + }, + }); } diff --git a/test/basics.zig b/test/basics.zig index 518237d..09f4b27 100644 --- a/test/basics.zig +++ b/test/basics.zig @@ -25,5 +25,10 @@ test "Simple fragment shader" { const code = try compileNzsl(allocator, shader); defer allocator.free(code); - try case.expectOutput(f32, 4, code, "color", &.{ 4, 3, 2, 1 }); + try case.expect(.{ + .source = code, + .expected_outputs = &.{ + std.mem.asBytes(&[_]f32{ 4, 3, 2, 1 }), + }, + }); } diff --git a/test/bitwise.zig b/test/bitwise.zig index a648e56..f88756d 100644 --- a/test/bitwise.zig +++ b/test/bitwise.zig @@ -73,7 +73,12 @@ test "Bitwise primitives" { defer allocator.free(shader); const code = try compileNzsl(allocator, shader); defer allocator.free(code); - try case.expectOutput(T, 4, code, "color", &.{ expected, expected, expected, expected }); + try case.expect(.{ + .source = code, + .expected_outputs = &.{ + std.mem.asBytes(&[_]T{ expected, expected, expected, expected }), + }, + }); } } } @@ -142,7 +147,12 @@ test "Bitwise vectors" { defer allocator.free(shader); const code = try compileNzsl(allocator, shader); defer allocator.free(code); - try case.expectOutput(T, L, code, "color", &@as([L]T, expected)); + try case.expect(.{ + .source = code, + .expected_outputs = &.{ + std.mem.asBytes(&@as([L]T, expected)), + }, + }); } } } diff --git a/test/branching.zig b/test/branching.zig index cecc6d3..0c00f7f 100644 --- a/test/branching.zig +++ b/test/branching.zig @@ -93,7 +93,12 @@ test "Simple branching" { defer allocator.free(shader); const code = try compileNzsl(allocator, shader); defer allocator.free(code); - try case.expectOutput(T, 4, code, "color", &.{ expected, expected, expected, expected }); + try case.expect(.{ + .source = code, + .expected_outputs = &.{ + std.mem.asBytes(&[_]T{ expected, expected, expected, expected }), + }, + }); } } } diff --git a/test/casts.zig b/test/casts.zig index 100a5e3..16683bd 100644 --- a/test/casts.zig +++ b/test/casts.zig @@ -55,7 +55,12 @@ test "Primitives casts" { defer allocator.free(shader); const code = try compileNzsl(allocator, shader); defer allocator.free(code); - try case.expectOutput(T[1], 4, code, "color", &.{ expected, expected, expected, expected }); + try case.expect(.{ + .source = code, + .expected_outputs = &.{ + std.mem.asBytes(&[_]T[1]{ expected, expected, expected, expected }), + }, + }); } } @@ -103,6 +108,11 @@ test "Primitives bitcasts" { defer allocator.free(shader); const code = try compileNzsl(allocator, shader); defer allocator.free(code); - try case.expectOutput(T[1], 4, code, "color", &.{ expected, expected, expected, expected }); + try case.expect(.{ + .source = code, + .expected_outputs = &.{ + std.mem.asBytes(&[_]T[1]{ expected, expected, expected, expected }), + }, + }); } } diff --git a/test/functions.zig b/test/functions.zig index 96f3a39..441205b 100644 --- a/test/functions.zig +++ b/test/functions.zig @@ -44,7 +44,12 @@ test "Simple function calls" { defer allocator.free(shader); const code = try compileNzsl(allocator, shader); defer allocator.free(code); - try case.expectOutput(T, 4, code, "color", &.{ n, n, n, n }); + try case.expect(.{ + .source = code, + .expected_outputs = &.{ + std.mem.asBytes(&[_]T{ n, n, n, n }), + }, + }); } } @@ -95,6 +100,11 @@ test "Nested function calls" { defer allocator.free(shader); const code = try compileNzsl(allocator, shader); defer allocator.free(code); - try case.expectOutput(T, 4, code, "color", &.{ n, n, n, n }); + try case.expect(.{ + .source = code, + .expected_outputs = &.{ + std.mem.asBytes(&[_]T{ n, n, n, n }), + }, + }); } } diff --git a/test/inputs.zig b/test/inputs.zig index af9e100..3bc0ec3 100644 --- a/test/inputs.zig +++ b/test/inputs.zig @@ -45,7 +45,15 @@ test "Inputs" { defer allocator.free(shader); const code = try compileNzsl(allocator, shader); defer allocator.free(code); - try case.expectOutputWithInput(T, L, code, "color", &@as([L]T, input.val), "pos", &@as([L]T, input.val)); + try case.expect(.{ + .source = code, + .inputs = &.{ + std.mem.asBytes(&@as([L]T, input.val)), + }, + .expected_outputs = &.{ + std.mem.asBytes(&@as([L]T, input.val)), + }, + }); } } } diff --git a/test/loops.zig b/test/loops.zig index 30c335e..6598165 100644 --- a/test/loops.zig +++ b/test/loops.zig @@ -47,5 +47,10 @@ test "Simple while loop" { defer allocator.free(shader); const code = try compileNzsl(allocator, shader); defer allocator.free(code); - try case.expectOutput(f32, 4, code, "color", &.{ expected, expected, expected, expected }); + try case.expect(.{ + .source = code, + .expected_outputs = &.{ + std.mem.asBytes(&[_]f32{ expected, expected, expected, expected }), + }, + }); } diff --git a/test/maths.zig b/test/maths.zig index d38f25d..f1d0c80 100644 --- a/test/maths.zig +++ b/test/maths.zig @@ -72,7 +72,12 @@ test "Maths primitives" { defer allocator.free(shader); const code = try compileNzsl(allocator, shader); defer allocator.free(code); - try case.expectOutput(T, 4, code, "color", &.{ expected, expected, expected, expected }); + try case.expect(.{ + .source = code, + .expected_outputs = &.{ + std.mem.asBytes(&[_]T{ expected, expected, expected, expected }), + }, + }); } } } @@ -139,7 +144,12 @@ test "Maths vectors" { defer allocator.free(shader); const code = try compileNzsl(allocator, shader); defer allocator.free(code); - try case.expectOutput(T, L, code, "color", &@as([L]T, expected)); + try case.expect(.{ + .source = code, + .expected_outputs = &.{ + std.mem.asBytes(&@as([L]T, expected)), + }, + }); } } } diff --git a/test/root.zig b/test/root.zig index dc52187..0dbfa57 100644 --- a/test/root.zig +++ b/test/root.zig @@ -20,34 +20,13 @@ pub fn compileNzsl(allocator: std.mem.Allocator, source: []const u8) ![]const u3 } pub const case = struct { - pub fn expectOutput(comptime T: type, comptime len: usize, source: []const u32, output_name: []const u8, expected: []const T) !void { - const allocator = std.testing.allocator; + pub const Config = struct { + source: []const u32, + inputs: []const []const u8 = &.{}, + expected_outputs: []const []const u8, + }; - const module_options = [_]spv.Module.ModuleOptions{ - .{ - .use_simd_vectors_specializations = true, - }, - .{ - .use_simd_vectors_specializations = false, - }, - }; - - for (module_options) |opt| { - var module = try spv.Module.init(allocator, source, opt); - defer module.deinit(allocator); - - var rt = try spv.Runtime.init(allocator, &module); - defer rt.deinit(allocator); - - try rt.callEntryPoint(allocator, try rt.getEntryPointByName("main")); - var output: [len]T = undefined; - try rt.readOutput(std.mem.sliceAsBytes(output[0..]), try rt.getResultByName(output_name)); - - try std.testing.expectEqualSlices(T, expected, &output); - } - } - - pub fn expectOutputWithInput(comptime T: type, comptime len: usize, source: []const u32, output_name: []const u8, expected: []const T, input_name: []const u8, input: []const T) !void { + pub fn expect(config: Config) !void { const allocator = std.testing.allocator; // To test with all important module options @@ -61,19 +40,25 @@ pub const case = struct { }; for (module_options) |opt| { - var module = try spv.Module.init(allocator, source, opt); + var module = try spv.Module.init(allocator, config.source, opt); defer module.deinit(allocator); var rt = try spv.Runtime.init(allocator, &module); defer rt.deinit(allocator); - try rt.writeInput(std.mem.sliceAsBytes(input[0..len]), try rt.getResultByName(input_name)); + for (config.inputs, 0..) |input, n| { + try rt.writeInput(input[0..], module.input_locations[n]); + } try rt.callEntryPoint(allocator, try rt.getEntryPointByName("main")); - var output: [len]T = undefined; - try rt.readOutput(std.mem.sliceAsBytes(output[0..]), try rt.getResultByName(output_name)); - try std.testing.expectEqualSlices(T, expected, &output); + for (config.expected_outputs, 0..) |expected, n| { + const output = try allocator.alloc(u8, expected.len); + defer allocator.free(output); + + try rt.readOutput(output[0..], module.output_locations[n]); + try std.testing.expectEqualSlices(u8, expected, output); + } } }