diff --git a/build.zig b/build.zig index 527b344..f81b6de 100644 --- a/build.zig +++ b/build.zig @@ -31,7 +31,7 @@ pub fn build(b: *std.Build) void { addSandbox(b, target, optimize, use_llvm, spv_mod, &install_spv_lib.step); addExample(b, target, optimize, use_llvm, spv_mod, &install_spv_lib.step); - addZigTests(b, target, optimize, spv_mod, zmath); + addZigTests(b, target, optimize, use_llvm, spv_mod, zmath); addCffi(b, target, optimize, use_llvm, spv_mod); addDocs(b, spv_mod); } @@ -146,6 +146,7 @@ fn addZigTests( b: *std.Build, target: std.Build.ResolvedTarget, optimize: std.builtin.OptimizeMode, + use_llvm: bool, spv_mod: *std.Build.Module, zmath: *std.Build.Dependency, ) void { @@ -172,6 +173,7 @@ fn addZigTests( .path = b.path("test/test_runner.zig"), .mode = .simple, }, + .use_llvm = use_llvm, }); const run_tests = b.addRunArtifact(tests); diff --git a/src/opcodes.zig b/src/opcodes.zig index 4dff70a..dfa1100 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -933,7 +933,7 @@ fn MathEngine(comptime T: PrimitiveType, comptime Op: MathOp, comptime IsAtomic: fn applyScalar(bit_count: SpvWord, d: *Value, l: *Value, r: *Value) RuntimeError!void { switch (bit_count) { inline 8, 16, 32, 64 => |bits| { - if (bits == 8 and T == .Float) return RuntimeError.InvalidSpirV; + if (bits == 8 and T == .Float) return RuntimeError.UnsupportedSpirV; const ScalarT = Value.getPrimitiveFieldType(T, bits); const d_field = try Value.getPrimitiveField(T, bits, d); @@ -941,13 +941,18 @@ fn MathEngine(comptime T: PrimitiveType, comptime Op: MathOp, comptime IsAtomic: const r_field = try Value.getPrimitiveField(T, bits, r); d_field.* = try operation(ScalarT, l_field.*, r_field.*); }, - else => return RuntimeError.InvalidSpirV, + else => return RuntimeError.UnsupportedSpirV, } } - inline fn applyVectorTimesScalarF32(d: []Value, l: []const Value, r: f32) void { + inline fn applyVectorTimesScalarFloat(comptime bit_count: SpvWord, d: []Value, l: []const Value, r_v: *const Value) RuntimeError!void { for (d, l) |*d_v, l_v| { - d_v.Float.value.float32 = l_v.Float.value.float32 * r; + switch (bit_count) { + 16 => d_v.Float.value.float16 = l_v.Float.value.float16 * r_v.Float.value.float16, + 32 => d_v.Float.value.float32 = l_v.Float.value.float32 * r_v.Float.value.float32, + 64 => d_v.Float.value.float64 = l_v.Float.value.float64 * r_v.Float.value.float64, + else => return RuntimeError.UnsupportedSpirV, + } } } @@ -963,7 +968,7 @@ fn MathEngine(comptime T: PrimitiveType, comptime Op: MathOp, comptime IsAtomic: } } - inline fn applySIMDVectorf32(comptime N: usize, d: *@Vector(N, f32), l: *const @Vector(N, f32), r: *const Value) RuntimeError!void { + fn applySIMDVectorf32(comptime N: usize, d: *@Vector(N, f32), l: *const @Vector(N, f32), r: *const Value) RuntimeError!void { switch (Op) { .VectorTimesScalar => applyVectorSIMDTimesScalarF32(N, d, l, r.Float.value.float32), else => { @@ -983,7 +988,10 @@ fn MathEngine(comptime T: PrimitiveType, comptime Op: MathOp, comptime IsAtomic: .Int, .Float => try operator.applyScalar(lane_bits, dst, lhs, rhs), .Vector => |dst_vec| switch (Op) { - .VectorTimesScalar => operator.applyVectorTimesScalarF32(dst_vec, lhs.Vector, rhs.Float.value.float32), + .VectorTimesScalar => switch (lane_bits) { + inline 16, 32, 64 => |bits_count| try operator.applyVectorTimesScalarFloat(bits_count, dst_vec, lhs.Vector, rhs), + else => return RuntimeError.UnsupportedSpirV, + }, else => for (dst_vec, lhs.Vector, rhs.Vector) |*d_lane, *l_lane, *r_lane| { try operator.applyScalar(lane_bits, d_lane, l_lane, r_lane); }, @@ -1315,7 +1323,7 @@ fn opAccessChain(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime if (a.indexes.len != index_count) return RuntimeError.InvalidSpirV; try a.value.flushPtr(allocator); - //a.value.deinit(allocator); + a.value.deinit(allocator); break :blk .{ a.indexes, false }; }, else => {}, diff --git a/test/maths.zig b/test/maths.zig index f1d0c80..69f0500 100644 --- a/test/maths.zig +++ b/test/maths.zig @@ -154,3 +154,70 @@ test "Maths vectors" { } } } + +// Tests all mathematical operation on vec2/3/4 with scalars with all NZSL supported primitive types +test "Maths vectors with scalars" { + const allocator = std.testing.allocator; + const types = [_]type{ f32, f64, i32, u32 }; + var operations = std.EnumMap(Operations, u8).init(.{ + .Mul = '*', + .Div = '/', + .Mod = '%', + }); + + var it = operations.iterator(); + while (it.next()) |op| { + inline for (2..5) |L| { + inline for (types) |T| { + const base_color: case.Vec(L, T) = .{ .val = case.random(@Vector(L, T)) }; + const ratio = case.random(T); + const splat_ratio = @as(@Vector(L, T), @splat(ratio)); + const expected = switch (op.key) { + .Mul => if (@typeInfo(T) == .int) @mulWithOverflow(base_color.val, splat_ratio)[0] else base_color.val * splat_ratio, + .Div => if (@typeInfo(T) == .int) @divTrunc(base_color.val, splat_ratio) else base_color.val / splat_ratio, + .Mod => @mod(base_color.val, splat_ratio), + else => unreachable, + }; + + const shader = try std.fmt.allocPrint( + allocator, + \\ [nzsl_version("1.1")] + \\ [feature(float64)] + \\ module; + \\ + \\ struct FragOut + \\ {{ + \\ [location(0)] color: vec{d}[{s}] + \\ }} + \\ + \\ [entry(frag)] + \\ fn main() -> FragOut + \\ {{ + \\ let output: FragOut; + \\ output.color = vec{d}[{s}]({f}) {c} {d}; + \\ return output; + \\ }} + , + .{ + L, + @typeName(T), + L, + @typeName(T), + base_color, + op.value.*, + ratio, + }, + ); + defer allocator.free(shader); + const code = try compileNzsl(allocator, shader); + defer allocator.free(code); + try case.expect(.{ + .source = code, + .expected_outputs = &.{ + std.mem.asBytes(&@as([L]T, expected)), + }, + }); + } + } + } +}