From ca33cfe3e997503208f031d270018c10d0611989 Mon Sep 17 00:00:00 2001 From: Kbz-8 Date: Mon, 11 May 2026 21:36:04 +0200 Subject: [PATCH] adding some matrix operations --- src/Module.zig | 1 + src/Value.zig | 28 ++++++ src/opcodes.zig | 220 +++++++++++++++++++++++++++--------------------- test/maths.zig | 61 ++++++++++++++ test/root.zig | 6 +- 5 files changed, 218 insertions(+), 98 deletions(-) diff --git a/src/Module.zig b/src/Module.zig index 5557930..8353ca6 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -19,6 +19,7 @@ const WordIterator = @import("WordIterator.zig"); const Self = @This(); pub const ModuleOptions = struct { + /// Also affects matrices use_simd_vectors_specializations: bool = true, }; diff --git a/src/Value.zig b/src/Value.zig index 354d752..ec30762 100644 --- a/src/Value.zig +++ b/src/Value.zig @@ -847,6 +847,10 @@ pub const Value = union(Type) { } } + pub fn getPrimitiveFieldConst(comptime T: PrimitiveType, comptime BitCount: SpvWord, v: *const Value) RuntimeError!*const getPrimitiveFieldType(T, BitCount) { + return getPrimitiveField(T, BitCount, @constCast(v)); + } + pub fn getPrimitiveField(comptime T: PrimitiveType, comptime BitCount: SpvWord, v: *Value) RuntimeError!*getPrimitiveFieldType(T, BitCount) { if (std.meta.activeTag(v.*) == .Pointer) { return switch (v.Pointer.ptr) { @@ -926,4 +930,28 @@ pub const Value = union(Type) { else => .unsigned, }; } + + pub inline fn getVectorSpecialization(self: *const Self, comptime N: usize, comptime T: type) @Vector(N, T) { + return switch (T) { + f32 => switch (N) { + inline 4 => self.Vector4f32, + inline 3 => self.Vector3f32, + inline 2 => self.Vector2f32, + else => unreachable, + }, + i32 => switch (N) { + inline 4 => self.Vector4i32, + inline 3 => self.Vector3i32, + inline 2 => self.Vector2i32, + else => unreachable, + }, + u32 => switch (N) { + inline 4 => self.Vector4u32, + inline 3 => self.Vector3u32, + inline 2 => self.Vector2u32, + else => unreachable, + }, + else => unreachable, + }; + } }; diff --git a/src/opcodes.zig b/src/opcodes.zig index b6b208e..b892af6 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -291,7 +291,7 @@ pub fn initRuntimeDispatcher() void { runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalOr)] = CondEngine(.Bool, .LogicalOr).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesMatrix)] = MathEngine(.Float, .MatrixTimesMatrix, false).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesScalar)] = MathEngine(.Float, .MatrixTimesScalar, false).op; // TODO - runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesVector)] = MathEngine(.Float, .MatrixTimesVector, false).op; // TODO + runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesVector)] = MathEngine(.Float, .MatrixTimesVector, false).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.Not)] = BitEngine(.UInt, .Not).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.Phi)] = opPhi; runtime_dispatcher[@intFromEnum(spv.SpvOp.Return)] = opReturn; @@ -916,32 +916,39 @@ fn MathEngine(comptime T: PrimitiveType, comptime Op: MathOp, comptime IsAtomic: const operator = struct { fn operation(comptime TT: type, op1: TT, op2: TT) RuntimeError!TT { + const is_int = @typeInfo(TT) == .int or (@typeInfo(TT) == .vector and @typeInfo(std.meta.Child(TT)) == .int); + const op2_is_zero = if (@typeInfo(TT) == .vector) std.simd.countElementsWithValue(op2, 0) != 0 else op2 == 0; + return switch (Op) { - .Add => if (@typeInfo(TT) == .int) @addWithOverflow(op1, op2)[0] else op1 + op2, - .Sub => if (@typeInfo(TT) == .int) @subWithOverflow(op1, op2)[0] else op1 - op2, + .Add => if (comptime is_int) @addWithOverflow(op1, op2)[0] else op1 + op2, + .Sub => if (comptime is_int) @subWithOverflow(op1, op2)[0] else op1 - op2, .Mul, .MatrixTimesMatrix, - => if (@typeInfo(TT) == .int) @mulWithOverflow(op1, op2)[0] else op1 * op2, + .MatrixTimesVector, + => if (comptime is_int) @mulWithOverflow(op1, op2)[0] else op1 * op2, .Div => blk: { - if (op2 == 0) return RuntimeError.DivisionByZero; - break :blk if (@typeInfo(TT) == .int) @divTrunc(op1, op2) else op1 / op2; + if (op2_is_zero) return RuntimeError.DivisionByZero; + break :blk if (comptime is_int) @divTrunc(op1, op2) else op1 / op2; }, - .Mod => if (op2 == 0) return RuntimeError.DivisionByZero else @mod(op1, op2), - .Rem => if (op2 == 0) return RuntimeError.DivisionByZero else @rem(op1, op2), + .Mod => if (op2_is_zero) return RuntimeError.DivisionByZero else @mod(op1, op2), + .Rem => if (op2_is_zero) return RuntimeError.DivisionByZero else @rem(op1, op2), else => return RuntimeError.InvalidSpirV, }; } - fn applyScalar(bit_count: SpvWord, d: *Value, l: *Value, r: *Value) RuntimeError!void { + fn applyScalarRaw(comptime BitCount: SpvWord, l: *const Value, r: *const Value) RuntimeError!Value.getPrimitiveFieldType(T, BitCount) { + const ScalarT = Value.getPrimitiveFieldType(T, BitCount); + const l_field = try Value.getPrimitiveFieldConst(T, BitCount, l); + const r_field = try Value.getPrimitiveFieldConst(T, BitCount, r); + return try operation(ScalarT, l_field.*, r_field.*); + } + + fn applyScalar(bit_count: SpvWord, d: *Value, l: *const Value, r: *const Value) RuntimeError!void { switch (bit_count) { inline 8, 16, 32, 64 => |bits| { - if (bits == 8 and T == .Float) return RuntimeError.UnsupportedSpirV; - - const ScalarT = Value.getPrimitiveFieldType(T, bits); + if (comptime bits == 8 and T == .Float) return RuntimeError.UnsupportedSpirV; const d_field = try Value.getPrimitiveField(T, bits, d); - const l_field = try Value.getPrimitiveField(T, bits, l); - const r_field = try Value.getPrimitiveField(T, bits, r); - d_field.* = try operation(ScalarT, l_field.*, r_field.*); + d_field.* = try applyScalarRaw(bits, l, r); }, else => return RuntimeError.UnsupportedSpirV, } @@ -950,81 +957,94 @@ fn MathEngine(comptime T: PrimitiveType, comptime Op: MathOp, comptime IsAtomic: inline fn applyVectorTimesScalarFloat(comptime bit_count: SpvWord, d: []Value, l: []const Value, r_v: *const Value) RuntimeError!void { for (d, l) |*d_v, l_v| { switch (bit_count) { - 16 => d_v.Float.value.float16 = l_v.Float.value.float16 * r_v.Float.value.float16, - 32 => d_v.Float.value.float32 = l_v.Float.value.float32 * r_v.Float.value.float32, - 64 => d_v.Float.value.float64 = l_v.Float.value.float64 * r_v.Float.value.float64, + inline 16 => d_v.Float.value.float16 = l_v.Float.value.float16 * r_v.Float.value.float16, + inline 32 => d_v.Float.value.float32 = l_v.Float.value.float32 * r_v.Float.value.float32, + inline 64 => d_v.Float.value.float64 = l_v.Float.value.float64 * r_v.Float.value.float64, else => return RuntimeError.UnsupportedSpirV, } } } - inline fn applySIMDVector(comptime ElemT: type, comptime N: usize, d: *@Vector(N, ElemT), l: *const @Vector(N, ElemT), r: *const @Vector(N, ElemT)) RuntimeError!void { - inline for (0..N) |i| { - d[i] = try operation(ElemT, l[i], r[i]); - } + inline fn applySIMDVector(comptime ElemT: type, comptime N: usize, d: *@Vector(N, ElemT), l: @Vector(N, ElemT), r: @Vector(N, ElemT)) RuntimeError!void { + d.* = try operation(@Vector(N, ElemT), l, r); } - inline fn applyVectorSIMDTimesScalarF32(comptime N: usize, d: *@Vector(N, f32), l: *const @Vector(N, f32), r: f32) void { - inline for (0..N) |i| { - d[i] = l[i] * r; - } - } - - fn applySIMDVectorf32(comptime N: usize, d: *@Vector(N, f32), l: *const @Vector(N, f32), r: *const Value) RuntimeError!void { + fn applySIMDVectorf32(comptime N: usize, d: *@Vector(N, f32), l: *const Value, r: *const Value) RuntimeError!void { switch (Op) { - .VectorTimesScalar => applyVectorSIMDTimesScalarF32(N, d, l, r.Float.value.float32), - else => { - const rh: *const @Vector(N, f32) = switch (N) { - 2 => &r.Vector2f32, - 3 => &r.Vector3f32, - 4 => &r.Vector4f32, - else => unreachable, - }; - try applySIMDVector(f32, N, d, l, rh); + .MatrixTimesVector => inline for (0..N) |i| { + d[i] = @reduce(.Add, l.Matrix[i].getVectorSpecialization(N, f32) * r.getVectorSpecialization(N, f32)); }, + else => try applyDirectSIMDVectorf32(N, d, l.getVectorSpecialization(N, f32), r), + } + } + + fn applyDirectSIMDVectorf32(comptime N: usize, d: *@Vector(N, f32), l: @Vector(N, f32), r: *const Value) RuntimeError!void { + switch (Op) { + .VectorTimesScalar => d.* = l * @as(@Vector(N, f32), @splat(r.Float.value.float32)), + else => try applySIMDVector(f32, N, d, l, r.getVectorSpecialization(N, f32)), } } }; + const vectorRoutines = struct { + fn routines(dst2: *Value, lhs2: *const Value, rhs2: *const Value, lane_bits2: SpvWord) RuntimeError!void { + switch (dst2.*) { + .Vector => |dst_vec| switch (Op) { + .VectorTimesScalar => switch (lane_bits2) { + inline 16, 32, 64 => |bits_count| try operator.applyVectorTimesScalarFloat(bits_count, dst_vec, lhs2.Vector, rhs2), + else => return RuntimeError.UnsupportedSpirV, + }, + .MatrixTimesVector => for (dst_vec, lhs2.Matrix) |*d_lane, *l_mat| { + switch (lane_bits2) { + inline 8, 16, 32, 64 => |bits| { + if (comptime bits == 8 and T == .Float) return RuntimeError.UnsupportedSpirV; + const d_field = try Value.getPrimitiveField(T, bits, d_lane); + + d_field.* = 0; + + for (l_mat.Vector[0..], rhs2.Vector) |*l_lane, *r_lane| { + d_field.* += try operator.applyScalarRaw(bits, l_lane, r_lane); + } + }, + else => return RuntimeError.UnsupportedSpirV, + } + }, + else => for (dst_vec, lhs2.Vector, rhs2.Vector) |*d_lane, *l_lane, *r_lane| { + try operator.applyScalar(lane_bits2, d_lane, l_lane, r_lane); + }, + }, + + .Vector4f32 => |*d| try operator.applySIMDVectorf32(4, d, lhs2, rhs2), + .Vector3f32 => |*d| try operator.applySIMDVectorf32(3, d, lhs2, rhs2), + .Vector2f32 => |*d| try operator.applySIMDVectorf32(2, d, lhs2, rhs2), + + .Vector4i32 => |*d| try operator.applySIMDVector(i32, 4, d, lhs2.Vector4i32, rhs2.Vector4i32), + .Vector3i32 => |*d| try operator.applySIMDVector(i32, 3, d, lhs2.Vector3i32, rhs2.Vector3i32), + .Vector2i32 => |*d| try operator.applySIMDVector(i32, 2, d, lhs2.Vector2i32, rhs2.Vector2i32), + + .Vector4u32 => |*d| try operator.applySIMDVector(u32, 4, d, lhs2.Vector4u32, rhs2.Vector4u32), + .Vector3u32 => |*d| try operator.applySIMDVector(u32, 3, d, lhs2.Vector3u32, rhs2.Vector3u32), + .Vector2u32 => |*d| try operator.applySIMDVector(u32, 2, d, lhs2.Vector2u32, rhs2.Vector2u32), + + else => return RuntimeError.InvalidValueType, + } + } + }.routines; + switch (dst.*) { .Int, .Float => try operator.applyScalar(lane_bits, dst, lhs, rhs), - .Vector => |dst_vec| switch (Op) { - .VectorTimesScalar => switch (lane_bits) { - inline 16, 32, 64 => |bits_count| try operator.applyVectorTimesScalarFloat(bits_count, dst_vec, lhs.Vector, rhs), - else => return RuntimeError.UnsupportedSpirV, - }, - else => for (dst_vec, lhs.Vector, rhs.Vector) |*d_lane, *l_lane, *r_lane| { - try operator.applyScalar(lane_bits, d_lane, l_lane, r_lane); - }, - }, - - .Vector4f32 => |*d| try operator.applySIMDVectorf32(4, d, &lhs.Vector4f32, rhs), - .Vector3f32 => |*d| try operator.applySIMDVectorf32(3, d, &lhs.Vector3f32, rhs), - .Vector2f32 => |*d| try operator.applySIMDVectorf32(2, d, &lhs.Vector2f32, rhs), - - .Vector4i32 => |*d| try operator.applySIMDVector(i32, 4, d, &lhs.Vector4i32, &rhs.Vector4i32), - .Vector3i32 => |*d| try operator.applySIMDVector(i32, 3, d, &lhs.Vector3i32, &rhs.Vector3i32), - .Vector2i32 => |*d| try operator.applySIMDVector(i32, 2, d, &lhs.Vector2i32, &rhs.Vector2i32), - - .Vector4u32 => |*d| try operator.applySIMDVector(u32, 4, d, &lhs.Vector4u32, &rhs.Vector4u32), - .Vector3u32 => |*d| try operator.applySIMDVector(u32, 3, d, &lhs.Vector3u32, &rhs.Vector3u32), - .Vector2u32 => |*d| try operator.applySIMDVector(u32, 2, d, &lhs.Vector2u32, &rhs.Vector2u32), - .Matrix => |dst_m| switch (Op) { .MatrixTimesMatrix => { for (dst_m, lhs.Matrix, rhs.Matrix) |*dst_vec, *lhs_vec, *rhs_vec| { - for (dst_vec.Vector, lhs_vec.Vector, rhs_vec.Vector) |*d_lane, *l_lane, *r_lane| { - try operator.applyScalar(lane_bits, d_lane, l_lane, r_lane); - } + try vectorRoutines(dst_vec, lhs_vec, rhs_vec, lane_bits); } }, - // TODO : matrix times vector // TODO : matrix times scalar else => return RuntimeError.ToDo, }, - else => return RuntimeError.InvalidSpirV, + else => try vectorRoutines(dst, lhs, rhs, lane_bits), } if (comptime IsAtomic) { @@ -1614,10 +1634,45 @@ fn opCompositeConstruct(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) return; } + const vectorRoutines = struct { + fn routines(value2: *Value, rt2: *Runtime) RuntimeError!void { + switch (value2.*) { + .Vector4f32 => |*vec| inline for (0..4) |i| { + vec[i] = (try rt2.results[try rt2.it.next()].getVariant()).Constant.value.Float.value.float32; + }, + .Vector3f32 => |*vec| inline for (0..3) |i| { + vec[i] = (try rt2.results[try rt2.it.next()].getVariant()).Constant.value.Float.value.float32; + }, + .Vector2f32 => |*vec| inline for (0..2) |i| { + vec[i] = (try rt2.results[try rt2.it.next()].getVariant()).Constant.value.Float.value.float32; + }, + .Vector4i32 => |*vec| inline for (0..4) |i| { + vec[i] = (try rt2.results[try rt2.it.next()].getVariant()).Constant.value.Int.value.sint32; + }, + .Vector3i32 => |*vec| inline for (0..3) |i| { + vec[i] = (try rt2.results[try rt2.it.next()].getVariant()).Constant.value.Int.value.sint32; + }, + .Vector2i32 => |*vec| inline for (0..2) |i| { + vec[i] = (try rt2.results[try rt2.it.next()].getVariant()).Constant.value.Int.value.sint32; + }, + .Vector4u32 => |*vec| inline for (0..4) |i| { + vec[i] = (try rt2.results[try rt2.it.next()].getVariant()).Constant.value.Int.value.uint32; + }, + .Vector3u32 => |*vec| inline for (0..3) |i| { + vec[i] = (try rt2.results[try rt2.it.next()].getVariant()).Constant.value.Int.value.uint32; + }, + .Vector2u32 => |*vec| inline for (0..2) |i| { + vec[i] = (try rt2.results[try rt2.it.next()].getVariant()).Constant.value.Int.value.uint32; + }, + else => return RuntimeError.InvalidValueType, + } + } + }.routines; + switch (value.*) { - .Matrix => |m| { + .Matrix => |*m| { var index: SpvWord = 0; - for (m[0..]) |mat_elem| { + for (m.*[0..]) |*mat_elem| { if (mat_elem.getCompositeDataOrNull()) |vec| { for (vec[0..]) |*elem| { const elem_value = (try rt.results[try rt.it.next()].getVariant()).Constant.value; @@ -1626,6 +1681,8 @@ fn opCompositeConstruct(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) if (index == index_count) return; } + } else { + try vectorRoutines(mat_elem, rt); } } }, @@ -1639,34 +1696,7 @@ fn opCompositeConstruct(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) offset += arr.stride; } }, - .Vector4f32 => |*vec| inline for (0..4) |i| { - vec[i] = (try rt.results[try rt.it.next()].getVariant()).Constant.value.Float.value.float32; - }, - .Vector3f32 => |*vec| inline for (0..3) |i| { - vec[i] = (try rt.results[try rt.it.next()].getVariant()).Constant.value.Float.value.float32; - }, - .Vector2f32 => |*vec| inline for (0..2) |i| { - vec[i] = (try rt.results[try rt.it.next()].getVariant()).Constant.value.Float.value.float32; - }, - .Vector4i32 => |*vec| inline for (0..4) |i| { - vec[i] = (try rt.results[try rt.it.next()].getVariant()).Constant.value.Int.value.sint32; - }, - .Vector3i32 => |*vec| inline for (0..3) |i| { - vec[i] = (try rt.results[try rt.it.next()].getVariant()).Constant.value.Int.value.sint32; - }, - .Vector2i32 => |*vec| inline for (0..2) |i| { - vec[i] = (try rt.results[try rt.it.next()].getVariant()).Constant.value.Int.value.sint32; - }, - .Vector4u32 => |*vec| inline for (0..4) |i| { - vec[i] = (try rt.results[try rt.it.next()].getVariant()).Constant.value.Int.value.uint32; - }, - .Vector3u32 => |*vec| inline for (0..3) |i| { - vec[i] = (try rt.results[try rt.it.next()].getVariant()).Constant.value.Int.value.uint32; - }, - .Vector2u32 => |*vec| inline for (0..2) |i| { - vec[i] = (try rt.results[try rt.it.next()].getVariant()).Constant.value.Int.value.uint32; - }, - else => return RuntimeError.InvalidValueType, + else => try vectorRoutines(value, rt), } } diff --git a/test/maths.zig b/test/maths.zig index e436ae0..757f3b5 100644 --- a/test/maths.zig +++ b/test/maths.zig @@ -1,5 +1,6 @@ const std = @import("std"); const root = @import("root.zig"); +const zm = @import("zmath"); const compileNzsl = root.compileNzsl; const case = root.case; @@ -294,3 +295,63 @@ test "Maths matrices" { } } } + +// Tests all mathematical operation on mat3/4 with all NZSL supported vectors +test "Maths matrices with vectors" { + const allocator = std.testing.allocator; + const types = [_]type{ f32, f64 }; + + inline for (3..5) |L| { + inline for (types) |T| { + const base: case.Mat(L, T) = .{ .val = case.random([L][L]T) }; + const ratio: case.Vec(L, T) = .{ .val = case.random(@Vector(L, T)) }; + var expected: @Vector(L, T) = undefined; + + expected[0] = (base.val[0][0] * ratio.val[0]) + (base.val[0][1] * ratio.val[1]) + (base.val[0][2] * ratio.val[2]) + if (L == 4) (base.val[0][3] * ratio.val[3]) else 0.0; + expected[1] = (base.val[1][0] * ratio.val[0]) + (base.val[1][1] * ratio.val[1]) + (base.val[1][2] * ratio.val[2]) + if (L == 4) (base.val[1][3] * ratio.val[3]) else 0.0; + expected[2] = (base.val[2][0] * ratio.val[0]) + (base.val[2][1] * ratio.val[1]) + (base.val[2][2] * ratio.val[2]) + if (L == 4) (base.val[2][3] * ratio.val[3]) else 0.0; + if (L == 4) + expected[3] = (base.val[3][0] * ratio.val[0]) + (base.val[3][1] * ratio.val[1]) + (base.val[3][2] * ratio.val[2]) + (base.val[3][3] * ratio.val[3]); + + const shader = try std.fmt.allocPrint( + allocator, + \\ [nzsl_version("1.1")] + \\ [feature(float64)] + \\ module; + \\ + \\ struct FragOut + \\ {{ + \\ [location(0)] value: vec{d}[{s}] + \\ }} + \\ + \\ [entry(frag)] + \\ fn main() -> FragOut + \\ {{ + \\ let output: FragOut; + \\ output.value = mat{d}[{s}]({f}) * vec{d}[{s}]({f}); + \\ return output; + \\ }} + , + .{ + L, + @typeName(T), + L, + @typeName(T), + base, + L, + @typeName(T), + ratio, + }, + ); + defer allocator.free(shader); + const code = try compileNzsl(allocator, shader); + defer allocator.free(code); + try case.expect(.{ + .source = code, + .expected_outputs = &.{ + std.mem.asBytes(&@as([L]T, expected)), + }, + }); + } + } +} diff --git a/test/root.zig b/test/root.zig index 4f42d73..b6abd7e 100644 --- a/test/root.zig +++ b/test/root.zig @@ -36,9 +36,9 @@ pub const case = struct { .{ .use_simd_vectors_specializations = false, }, - //.{ - // .use_simd_vectors_specializations = true, - //}, + .{ + .use_simd_vectors_specializations = true, + }, }; for (module_options) |opt| {