adding base matrix management
Build / build (push) Successful in 1m33s
Test / build (push) Successful in 10m31s

This commit is contained in:
2026-05-11 01:48:13 +02:00
parent 9d20363ae8
commit 769009ad5e
8 changed files with 145 additions and 11 deletions
+1 -1
View File
@@ -20,4 +20,4 @@ jobs:
- uses: https://codeberg.org/mlugg/setup-zig@v2 - uses: https://codeberg.org/mlugg/setup-zig@v2
- name: Test - name: Test
run: zig build test -Dno-example=true run: zig build test -Dno-example=true --release=fast
+8 -1
View File
@@ -153,9 +153,15 @@ fn addZigTests(
const no_test = b.option(bool, "no-test", "Skip unit test dependencies fetch") orelse false; const no_test = b.option(bool, "no-test", "Skip unit test dependencies fetch") orelse false;
if (no_test) return; if (no_test) return;
const test_filter = b.option(
[]const u8,
"test-filter",
"Only run tests whose name contains this substring",
);
const nzsl = b.lazyDependency("NZSL", .{ const nzsl = b.lazyDependency("NZSL", .{
.target = target, .target = target,
.optimize = optimize, .optimize = .ReleaseFast,
}) orelse return; }) orelse return;
const tests = b.addTest(.{ const tests = b.addTest(.{
@@ -173,6 +179,7 @@ fn addZigTests(
.path = b.path("test/test_runner.zig"), .path = b.path("test/test_runner.zig"),
.mode = .simple, .mode = .simple,
}, },
.filters = if (test_filter) |filter| &.{filter} else &.{},
.use_llvm = use_llvm, .use_llvm = use_llvm,
}); });
+2 -2
View File
@@ -7,8 +7,8 @@
.hash = "zmath-0.11.0-dev-wjwivdMsAwD-xaLj76YHUq3t9JDH-X16xuMTmnDzqbu2", .hash = "zmath-0.11.0-dev-wjwivdMsAwD-xaLj76YHUq3t9JDH-X16xuMTmnDzqbu2",
}, },
.NZSL = .{ // For unit tests .NZSL = .{ // For unit tests
.url = "git+https://git.kbz8.me/kbz_8/NZigSL#ab95fc3734da46079fda2a4cd0f14143d92bf633", .url = "git+https://git.kbz8.me/kbz_8/NZigSL#68f6c0ae2d0fc6b91eaa9df5c0fcd68f3529c5b8",
.hash = "NZSL-1.1.2-N0xSVCR7AACeI_Wa6JPggJzy9_MPCpWC-2OHkMowwX-7", .hash = "NZSL-1.1.4-N0xSVC97AACi1-SuJ_ifNAOBRdCMPIXN1vMgVprfABhH",
.lazy = true, .lazy = true,
}, },
//.sdl3 = .{ //.sdl3 = .{
+3
View File
@@ -390,6 +390,7 @@ pub fn resolveLaneBitWidth(target_type: TypeData, rt: *const Runtime) RuntimeErr
.Float => |f| f.bit_length, .Float => |f| f.bit_length,
.Int => |i| i.bit_length, .Int => |i| i.bit_length,
.Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type, .Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type,
.Matrix => |m| continue :sw (try rt.results[m.column_type_word].getVariant()).Type,
.Vector4f32, .Vector4f32,
.Vector3f32, .Vector3f32,
.Vector2f32, .Vector2f32,
@@ -408,6 +409,7 @@ pub fn resolveLaneCount(target_type: TypeData) RuntimeError!SpvWord {
return switch (target_type) { return switch (target_type) {
.Bool, .Float, .Int => 1, .Bool, .Float, .Int => 1,
.Vector => |v| v.member_count, .Vector => |v| v.member_count,
.Matrix => |m| m.member_count,
.Vector4f32, .Vector4i32, .Vector4u32 => 4, .Vector4f32, .Vector4i32, .Vector4u32 => 4,
.Vector3f32, .Vector3i32, .Vector3u32 => 3, .Vector3f32, .Vector3i32, .Vector3u32 => 3,
.Vector2f32, .Vector2i32, .Vector2u32 => 2, .Vector2f32, .Vector2i32, .Vector2u32 => 2,
@@ -419,6 +421,7 @@ pub fn resolveSign(target_type: TypeData, rt: *const Runtime) RuntimeError!enum
return sw: switch (target_type) { return sw: switch (target_type) {
.Int => |i| if (i.is_signed) .signed else .unsigned, .Int => |i| if (i.is_signed) .signed else .unsigned,
.Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type, .Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type,
.Matrix => |m| continue :sw (try rt.results[m.column_type_word].getVariant()).Type,
.Vector4i32 => .signed, .Vector4i32 => .signed,
.Vector3i32 => .signed, .Vector3i32 => .signed,
.Vector2i32 => .signed, .Vector2i32 => .signed,
+1 -1
View File
@@ -136,7 +136,7 @@ pub const Value = union(Type) {
return switch (self.*) { return switch (self.*) {
.Structure => |*s| s.values, .Structure => |*s| s.values,
.Array => |*a| a.values, .Array => |*a| a.values,
.Vector, .Matrix => |v| v, .Vector => |v| v,
else => null, else => null,
}; };
} }
+31 -2
View File
@@ -289,7 +289,7 @@ pub fn initRuntimeDispatcher() void {
runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalNot)] = CondEngine(.Bool, .LogicalNot).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalNot)] = CondEngine(.Bool, .LogicalNot).op;
runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalNotEqual)] = CondEngine(.Bool, .LogicalNotEqual).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalNotEqual)] = CondEngine(.Bool, .LogicalNotEqual).op;
runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalOr)] = CondEngine(.Bool, .LogicalOr).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalOr)] = CondEngine(.Bool, .LogicalOr).op;
runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesMatrix)] = MathEngine(.Float, .MatrixTimesMatrix, false).op; // TODO runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesMatrix)] = MathEngine(.Float, .MatrixTimesMatrix, false).op;
runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesScalar)] = MathEngine(.Float, .MatrixTimesScalar, false).op; // TODO runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesScalar)] = MathEngine(.Float, .MatrixTimesScalar, false).op; // TODO
runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesVector)] = MathEngine(.Float, .MatrixTimesVector, false).op; // TODO runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesVector)] = MathEngine(.Float, .MatrixTimesVector, false).op; // TODO
runtime_dispatcher[@intFromEnum(spv.SpvOp.Not)] = BitEngine(.UInt, .Not).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.Not)] = BitEngine(.UInt, .Not).op;
@@ -919,7 +919,9 @@ fn MathEngine(comptime T: PrimitiveType, comptime Op: MathOp, comptime IsAtomic:
return switch (Op) { return switch (Op) {
.Add => if (@typeInfo(TT) == .int) @addWithOverflow(op1, op2)[0] else op1 + op2, .Add => if (@typeInfo(TT) == .int) @addWithOverflow(op1, op2)[0] else op1 + op2,
.Sub => if (@typeInfo(TT) == .int) @subWithOverflow(op1, op2)[0] else op1 - op2, .Sub => if (@typeInfo(TT) == .int) @subWithOverflow(op1, op2)[0] else op1 - op2,
.Mul => if (@typeInfo(TT) == .int) @mulWithOverflow(op1, op2)[0] else op1 * op2, .Mul,
.MatrixTimesMatrix,
=> if (@typeInfo(TT) == .int) @mulWithOverflow(op1, op2)[0] else op1 * op2,
.Div => blk: { .Div => blk: {
if (op2 == 0) return RuntimeError.DivisionByZero; if (op2 == 0) return RuntimeError.DivisionByZero;
break :blk if (@typeInfo(TT) == .int) @divTrunc(op1, op2) else op1 / op2; break :blk if (@typeInfo(TT) == .int) @divTrunc(op1, op2) else op1 / op2;
@@ -1009,6 +1011,19 @@ fn MathEngine(comptime T: PrimitiveType, comptime Op: MathOp, comptime IsAtomic:
.Vector3u32 => |*d| try operator.applySIMDVector(u32, 3, d, &lhs.Vector3u32, &rhs.Vector3u32), .Vector3u32 => |*d| try operator.applySIMDVector(u32, 3, d, &lhs.Vector3u32, &rhs.Vector3u32),
.Vector2u32 => |*d| try operator.applySIMDVector(u32, 2, d, &lhs.Vector2u32, &rhs.Vector2u32), .Vector2u32 => |*d| try operator.applySIMDVector(u32, 2, d, &lhs.Vector2u32, &rhs.Vector2u32),
.Matrix => |dst_m| switch (Op) {
.MatrixTimesMatrix => {
for (dst_m, lhs.Matrix, rhs.Matrix) |*dst_vec, *lhs_vec, *rhs_vec| {
for (dst_vec.Vector, lhs_vec.Vector, rhs_vec.Vector) |*d_lane, *l_lane, *r_lane| {
try operator.applyScalar(lane_bits, d_lane, l_lane, r_lane);
}
}
},
// TODO : matrix times vector
// TODO : matrix times scalar
else => return RuntimeError.ToDo,
},
else => return RuntimeError.InvalidSpirV, else => return RuntimeError.InvalidSpirV,
} }
@@ -1600,6 +1615,20 @@ fn opCompositeConstruct(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime)
} }
switch (value.*) { switch (value.*) {
.Matrix => |m| {
var index: SpvWord = 0;
for (m[0..]) |mat_elem| {
if (mat_elem.getCompositeDataOrNull()) |vec| {
for (vec[0..]) |*elem| {
const elem_value = (try rt.results[try rt.it.next()].getVariant()).Constant.value;
elem.* = elem_value;
index += 1;
if (index == index_count)
return;
}
}
}
},
.RuntimeArray => |arr| { .RuntimeArray => |arr| {
var offset: usize = 0; var offset: usize = 0;
+73
View File
@@ -221,3 +221,76 @@ test "Maths vectors with scalars" {
} }
} }
} }
// Tests all mathematical operation on mat3/4 with all NZSL supported primitive types
test "Maths matrices" {
const allocator = std.testing.allocator;
const types = [_]type{ f32, f64 };
var operations = std.EnumMap(Operations, u8).init(.{
.Add = '+',
.Sub = '-',
.Mul = '*',
});
var it = operations.iterator();
while (it.next()) |op| {
inline for (3..5) |L| {
inline for (types) |T| {
const base: case.Mat(L, T) = .{ .val = case.random([L][L]T) };
const ratio: case.Mat(L, T) = .{ .val = case.random([L][L]T) };
var expected: case.Mat(L, T) = undefined;
for (expected.val[0..], base.val[0..], ratio.val[0..]) |*ec, bc, rc| {
for (ec[0..], bc[0..], rc[0..]) |*e, b, r| {
e.* = switch (op.key) {
.Add => b + r,
.Sub => b - r,
.Mul => b * r,
else => unreachable,
};
}
}
const shader = try std.fmt.allocPrint(
allocator,
\\ [nzsl_version("1.1")]
\\ [feature(float64)]
\\ module;
\\
\\ struct FragOut
\\ {{
\\ [location(0)] value: mat{d}[{s}]
\\ }}
\\
\\ [entry(frag)]
\\ fn main() -> FragOut
\\ {{
\\ let output: FragOut;
\\ output.value = mat{d}[{s}]({f}) {c} mat{d}[{s}]({f});
\\ return output;
\\ }}
,
.{
L,
@typeName(T),
L,
@typeName(T),
base,
op.value.*,
L,
@typeName(T),
ratio,
},
);
defer allocator.free(shader);
const code = try compileNzsl(allocator, shader);
defer allocator.free(code);
try case.expect(.{
.source = code,
.expected_outputs = &.{
std.mem.asBytes(&expected),
},
});
}
}
}
}
+26 -4
View File
@@ -33,12 +33,12 @@ pub const case = struct {
// To test with all important module options // To test with all important module options
const module_options = [_]spv.Module.ModuleOptions{ const module_options = [_]spv.Module.ModuleOptions{
.{
.use_simd_vectors_specializations = true,
},
.{ .{
.use_simd_vectors_specializations = false, .use_simd_vectors_specializations = false,
}, },
//.{
// .use_simd_vectors_specializations = true,
//},
}; };
for (module_options) |opt| { for (module_options) |opt| {
@@ -78,7 +78,7 @@ pub const case = struct {
} }
pub fn random(comptime T: type) T { pub fn random(comptime T: type) T {
var prng: std.Random.DefaultPrng = .init(@intCast(std.Io.Timestamp.now(std.testing.io, .real).toMicroseconds())); var prng: std.Random.DefaultPrng = .init(@intCast(std.Io.Timestamp.now(std.testing.io, .real).toNanoseconds()));
const rand = prng.random(); const rand = prng.random();
return switch (@typeInfo(T)) { return switch (@typeInfo(T)) {
@@ -91,6 +91,13 @@ pub const case = struct {
} }
break :blk vec; break :blk vec;
}, },
.array => |a| blk: {
var arr: [a.len]a.child = undefined;
inline for (0..a.len) |i| {
arr[i] = random(a.child);
}
break :blk arr;
},
inline else => unreachable, inline else => unreachable,
}; };
} }
@@ -107,6 +114,21 @@ pub const case = struct {
} }
}; };
} }
pub fn Mat(comptime len: usize, comptime T: type) type {
return struct {
const Self = @This();
val: [len][len]T,
pub fn format(self: *const Self, w: *std.Io.Writer) std.Io.Writer.Error!void {
inline for (0..len) |i| {
inline for (0..len) |j| {
try w.print("{d}", .{self.val[i][j]});
if (i < len - 1 or j < len - 1) try w.writeAll(", ");
}
}
}
};
}
}; };
test { test {