diff --git a/src/Runtime.zig b/src/Runtime.zig index 96c23f5..1cfb6b3 100644 --- a/src/Runtime.zig +++ b/src/Runtime.zig @@ -21,6 +21,7 @@ pub const RuntimeError = error{ Killed, InvalidEntryPoint, ToDo, + DivisionByZero, }; pub const Function = struct { diff --git a/src/opcodes.zig b/src/opcodes.zig index 5c2813a..fa86865 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -19,6 +19,14 @@ const MathType = enum { UInt, }; +const MathOp = enum { + Add, + Sub, + Mul, + Div, + Mod, +}; + pub const OpCodeFunc = *const fn (std.mem.Allocator, SpvWord, *Runtime) RuntimeError!void; pub const SetupDispatcher = block: { @@ -30,10 +38,8 @@ pub const SetupDispatcher = block: { .Decorate = opDecorate, .EntryPoint = opEntryPoint, .ExecutionMode = opExecutionMode, - .FMul = autoSetupConstant, .Function = opFunction, .FunctionEnd = opFunctionEnd, - .IMul = autoSetupConstant, .Label = opLabel, .Load = autoSetupConstant, .MemberDecorate = opDecorateMember, @@ -52,6 +58,18 @@ pub const SetupDispatcher = block: { .TypeVector = opTypeVector, .TypeVoid = opTypeVoid, .Variable = opVariable, + .FAdd = autoSetupConstant, + .FDiv = autoSetupConstant, + .FMod = autoSetupConstant, + .FMul = autoSetupConstant, + .FSub = autoSetupConstant, + .IAdd = autoSetupConstant, + .IMul = autoSetupConstant, + .ISub = autoSetupConstant, + .SDiv = autoSetupConstant, + .SMod = autoSetupConstant, + .UDiv = autoSetupConstant, + .UMod = autoSetupConstant, }); }; @@ -61,11 +79,21 @@ pub const RuntimeDispatcher = block: { .AccessChain = opAccessChain, .CompositeConstruct = opCompositeConstruct, .CompositeExtract = opCompositeExtract, - .FMul = maths(.Float).opMul, - .IMul = maths(.SInt).opMul, + .FAdd = maths(.Float, .Add).op, + .FDiv = maths(.Float, .Div).op, + .FMod = maths(.Float, .Mod).op, + .FMul = maths(.Float, .Mul).op, + .FSub = maths(.Float, .Sub).op, + .IAdd = maths(.SInt, .Add).op, + .IMul = maths(.SInt, .Mul).op, + .ISub = maths(.SInt, .Sub).op, .Load = opLoad, .Return = opReturn, + .SDiv = maths(.SInt, .Div).op, + .SMod = maths(.SInt, .Mod).op, .Store = opStore, + .UDiv = maths(.UInt, .Div).op, + .UMod = maths(.UInt, .Mod).op, }); }; @@ -387,14 +415,18 @@ fn opConstant(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) R switch (target.variant.?.Constant) { .Int => |*i| { if (word_count - 2 != 1) { - i.uint64 = @as(u64, try rt.it.next()) | (@as(u64, try rt.it.next()) >> 32); + const low = @as(u64, try rt.it.next()); + const high = @as(u64, try rt.it.next()); + i.uint64 = (high << 32) | low; } else { i.uint32 = try rt.it.next(); } }, .Float => |*f| { if (word_count - 2 != 1) { - f.float64 = @bitCast(@as(u64, try rt.it.next()) | (@as(u64, try rt.it.next()) >> 32)); + const low = @as(u64, try rt.it.next()); + const high = @as(u64, try rt.it.next()); + f.float64 = @bitCast((high << 32) | low); } else { f.float32 = @bitCast(try rt.it.next()); } @@ -598,9 +630,9 @@ fn opReturn(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { } } -fn maths(comptime T: MathType) type { +fn maths(comptime T: MathType, comptime Op: MathOp) type { return struct { - fn opMul(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + fn op(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { const target_type = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type; const value = try rt.results[try rt.it.next()].getValue(); const op1_value = try rt.results[try rt.it.next()].getValue(); @@ -614,26 +646,52 @@ fn maths(comptime T: MathType) type { }; const operator = struct { + fn operation(comptime TT: type, op1: TT, op2: TT) RuntimeError!TT { + return switch (Op) { + .Add => if (@typeInfo(TT) == .int) @addWithOverflow(op1, op2)[0] else op1 + op2, + .Sub => if (@typeInfo(TT) == .int) @subWithOverflow(op1, op2)[0] else op1 - op2, + .Mul => if (@typeInfo(TT) == .int) @mulWithOverflow(op1, op2)[0] else op1 * op2, + .Div => blk: { + if (op2 == 0) return RuntimeError.DivisionByZero; + break :blk if (@typeInfo(TT) == .int) @divTrunc(op1, op2) else op1 / op2; + }, + .Mod => blk: { + if (op2 == 0) return RuntimeError.DivisionByZero; + break :blk @mod(op1, op2); + }, + }; + } + fn process(bit_count: SpvWord, v: *Result.Value, op1_v: *const Result.Value, op2_v: *const Result.Value) RuntimeError!void { switch (T) { .Float => switch (bit_count) { - 16 => v.Float.float16 = op1_v.Float.float16 * op2_v.Float.float16, - 32 => v.Float.float32 = op1_v.Float.float32 * op2_v.Float.float32, - 64 => v.Float.float64 = op1_v.Float.float64 * op2_v.Float.float64, + inline 16, 32, 64 => |i| @field(v.Float, std.fmt.comptimePrint("float{}", .{i})) = try operation( + @Type(.{ .float = .{ .bits = i } }), + @field(op1_v.Float, std.fmt.comptimePrint("float{}", .{i})), + @field(op2_v.Float, std.fmt.comptimePrint("float{}", .{i})), + ), else => return RuntimeError.InvalidSpirV, }, .SInt => switch (bit_count) { - 8 => v.Int.sint8 = @mulWithOverflow(op1_v.Int.sint8, op2_v.Int.sint8)[0], - 16 => v.Int.sint16 = @mulWithOverflow(op1_v.Int.sint16, op2_v.Int.sint16)[0], - 32 => v.Int.sint32 = @mulWithOverflow(op1_v.Int.sint32, op2_v.Int.sint32)[0], - 64 => v.Int.sint64 = @mulWithOverflow(op1_v.Int.sint64, op2_v.Int.sint64)[0], + inline 8, 16, 32, 64 => |i| @field(v.Int, std.fmt.comptimePrint("sint{}", .{i})) = try operation( + @Type(.{ .int = .{ + .signedness = .signed, + .bits = i, + } }), + @field(op1_v.Int, std.fmt.comptimePrint("sint{}", .{i})), + @field(op2_v.Int, std.fmt.comptimePrint("sint{}", .{i})), + ), else => return RuntimeError.InvalidSpirV, }, .UInt => switch (bit_count) { - 8 => v.Int.uint8 = @mulWithOverflow(op1_v.Int.uint8, op2_v.Int.uint8)[0], - 16 => v.Int.uint16 = @mulWithOverflow(op1_v.Int.uint16, op2_v.Int.uint16)[0], - 32 => v.Int.uint32 = @mulWithOverflow(op1_v.Int.uint32, op2_v.Int.uint32)[0], - 64 => v.Int.uint64 = @mulWithOverflow(op1_v.Int.uint64, op2_v.Int.uint64)[0], + inline 8, 16, 32, 64 => |i| @field(v.Int, std.fmt.comptimePrint("uint{}", .{i})) = try operation( + @Type(.{ .int = .{ + .signedness = .unsigned, + .bits = i, + } }), + @field(op1_v.Int, std.fmt.comptimePrint("uint{}", .{i})), + @field(op2_v.Int, std.fmt.comptimePrint("uint{}", .{i})), + ), else => return RuntimeError.InvalidSpirV, }, } diff --git a/test/basics.zig b/test/basics.zig index ff74aae..518237d 100644 --- a/test/basics.zig +++ b/test/basics.zig @@ -3,7 +3,7 @@ const root = @import("root.zig"); const compileNzsl = root.compileNzsl; const case = root.case; -test "FMul vec4[f32]" { +test "Simple fragment shader" { const allocator = std.testing.allocator; const shader = \\ [nzsl_version("1.1")] diff --git a/test/maths.zig b/test/maths.zig index f0c0853..d55b40d 100644 --- a/test/maths.zig +++ b/test/maths.zig @@ -3,62 +3,157 @@ const root = @import("root.zig"); const compileNzsl = root.compileNzsl; const case = root.case; -test "Mul vec4" { - const allocator = std.testing.allocator; - const types = [_]type{ - f32, - //f64, - i32, - u32, +const Operations = enum { + Add, + Sub, + Mul, + Div, + Mod, +}; + +fn Vec(comptime len: usize, comptime T: type) type { + return struct { + const Self = @This(); + val: @Vector(len, T), + pub fn format(self: *const Self, w: *std.Io.Writer) std.Io.Writer.Error!void { + inline for (0..len) |i| { + try w.print("{d}", .{self.val[i]}); + if (i < len - 1) try w.writeAll(", "); + } + } }; +} - inline for (types) |T| { - const base_color = case.random(@Vector(4, T)); - const ratio = case.random(@Vector(4, T)); - const expected = switch (@typeInfo(T)) { - .float => base_color * ratio, - .int => @mulWithOverflow(base_color, ratio)[0], - else => unreachable, - }; +// Tests all mathematical operation on all NZSL supported primitive types +test "Maths primitives" { + const allocator = std.testing.allocator; + const types = [_]type{ f32, f64, i32, u32 }; + var operations = std.EnumMap(Operations, u8).init(.{ + .Add = '+', + .Sub = '-', + .Mul = '*', + .Div = '/', + .Mod = '%', + }); - const shader = try std.fmt.allocPrint( - allocator, - \\ [nzsl_version("1.1")] - \\ [feature(float64)] - \\ module; - \\ - \\ struct FragOut - \\ {{ - \\ [location(0)] color: vec4[{s}] - \\ }} - \\ - \\ [entry(frag)] - \\ fn main() -> FragOut - \\ {{ - \\ let ratio = vec4[{s}]({d}, {d}, {d}, {d}); - \\ - \\ let output: FragOut; - \\ output.color = vec4[{s}]({d}, {d}, {d}, {d}) * ratio; - \\ return output; - \\ }} - , - .{ - @typeName(T), - @typeName(T), - ratio[0], - ratio[1], - ratio[2], - ratio[3], - @typeName(T), - base_color[0], - base_color[1], - base_color[2], - base_color[3], - }, - ); - defer allocator.free(shader); - const code = try compileNzsl(allocator, shader); - defer allocator.free(code); - try case.expectOutput(T, 4, code, "color", &@as([4]T, expected)); + var it = operations.iterator(); + while (it.next()) |op| { + inline for (types) |T| { + const base: T = case.random(T); + const ratio: T = case.random(T); + const expected = switch (op.key) { + .Add => if (@typeInfo(T) == .int) @addWithOverflow(base, ratio)[0] else base + ratio, + .Sub => if (@typeInfo(T) == .int) @subWithOverflow(base, ratio)[0] else base - ratio, + .Mul => if (@typeInfo(T) == .int) @mulWithOverflow(base, ratio)[0] else base * ratio, + .Div => if (@typeInfo(T) == .int) @divTrunc(base, ratio) else base / ratio, + .Mod => @mod(base, ratio), + }; + + const shader = try std.fmt.allocPrint( + allocator, + \\ [nzsl_version("1.1")] + \\ [feature(float64)] + \\ module; + \\ + \\ struct FragOut + \\ {{ + \\ [location(0)] color: vec4[{s}] + \\ }} + \\ + \\ [entry(frag)] + \\ fn main() -> FragOut + \\ {{ + \\ let ratio: {s} = {d}; + \\ let base: {s} = {d}; + \\ let color = base {c} ratio; + \\ + \\ let output: FragOut; + \\ output.color = vec4[{s}](color, color, color, color); + \\ return output; + \\ }} + , + .{ + @typeName(T), + @typeName(T), + ratio, + @typeName(T), + base, + op.value.*, + @typeName(T), + }, + ); + defer allocator.free(shader); + const code = try compileNzsl(allocator, shader); + defer allocator.free(code); + try case.expectOutput(T, 4, code, "color", &.{ expected, expected, expected, expected }); + } + } +} + +// Tests all mathematical operation on vec2/3/4 with all NZSL supported primitive types +test "Maths vectors" { + const allocator = std.testing.allocator; + const types = [_]type{ f32, f64, i32, u32 }; + var operations = std.EnumMap(Operations, u8).init(.{ + .Add = '+', + .Sub = '-', + .Mul = '*', + .Div = '/', + .Mod = '%', + }); + + var it = operations.iterator(); + while (it.next()) |op| { + inline for (2..5) |L| { + inline for (types) |T| { + const base_color: Vec(L, T) = .{ .val = case.random(@Vector(L, T)) }; + const ratio: Vec(L, T) = .{ .val = case.random(@Vector(L, T)) }; + const expected = switch (op.key) { + .Add => if (@typeInfo(T) == .int) @addWithOverflow(base_color.val, ratio.val)[0] else base_color.val + ratio.val, + .Sub => if (@typeInfo(T) == .int) @subWithOverflow(base_color.val, ratio.val)[0] else base_color.val - ratio.val, + .Mul => if (@typeInfo(T) == .int) @mulWithOverflow(base_color.val, ratio.val)[0] else base_color.val * ratio.val, + .Div => if (@typeInfo(T) == .int) @divTrunc(base_color.val, ratio.val) else base_color.val / ratio.val, + .Mod => @mod(base_color.val, ratio.val), + }; + + const shader = try std.fmt.allocPrint( + allocator, + \\ [nzsl_version("1.1")] + \\ [feature(float64)] + \\ module; + \\ + \\ struct FragOut + \\ {{ + \\ [location(0)] color: vec{d}[{s}] + \\ }} + \\ + \\ [entry(frag)] + \\ fn main() -> FragOut + \\ {{ + \\ let ratio = vec{d}[{s}]({f}); + \\ + \\ let output: FragOut; + \\ output.color = vec{d}[{s}]({f}) {c} ratio; + \\ return output; + \\ }} + , + .{ + L, + @typeName(T), + L, + @typeName(T), + ratio, + L, + @typeName(T), + base_color, + op.value.*, + }, + ); + defer allocator.free(shader); + const code = try compileNzsl(allocator, shader); + defer allocator.free(code); + try case.expectOutput(T, L, code, "color", &@as([L]T, expected)); + } + } } }