diff --git a/example/main.zig b/example/main.zig index c4aede2..dd7ba41 100644 --- a/example/main.zig +++ b/example/main.zig @@ -17,11 +17,9 @@ pub fn main() !void { defer rt.deinit(allocator); try rt.callEntryPoint(allocator, try rt.getEntryPointByName("main")); - var value: f32 = undefined; - var value2: f32 = undefined; - try rt.readOutput(f32, @as([*]f32, @ptrCast(&value))[0..1], try rt.getResultByName("value")); - try rt.readOutput(f32, @as([*]f32, @ptrCast(&value2))[0..1], try rt.getResultByName("value2")); - std.log.info("Output: {d} {d}", .{ value, value2 }); + var output: [4]i32 = undefined; + try rt.readOutput(i32, output[0..], try rt.getResultByName("color")); + std.log.info("Output: Vec4{any}", .{output}); } std.log.info("Successfully executed", .{}); } diff --git a/example/shader.nzsl b/example/shader.nzsl index d71158b..4e5c3e8 100644 --- a/example/shader.nzsl +++ b/example/shader.nzsl @@ -10,7 +10,7 @@ struct FragOut fn main() -> FragOut { let base: i32 = 4; - let value: i32 = base << 3; + let value: i32 = base >> 3; let output: FragOut; output.color = vec4[i32](value, value, value, value); return output; diff --git a/example/shader.spv b/example/shader.spv index 04ee6a5..09b51e0 100644 Binary files a/example/shader.spv and b/example/shader.spv differ diff --git a/example/shader.spv.txt b/example/shader.spv.txt index a9fdff1..592e083 100644 --- a/example/shader.spv.txt +++ b/example/shader.spv.txt @@ -34,7 +34,7 @@ Schema: 0 %17 = OpVariable %11 StorageClass(Function) OpStore %15 %8 %18 = OpLoad %3 %15 -%19 = OpShiftLeftLogical %3 %18 %10 +%19 = OpShiftRightArithmetic %3 %18 %10 OpStore %16 %19 %20 = OpLoad %3 %16 %21 = OpLoad %3 %16 diff --git a/src/Runtime.zig b/src/Runtime.zig index 8e678ee..e66d8cf 100644 --- a/src/Runtime.zig +++ b/src/Runtime.zig @@ -166,7 +166,7 @@ fn readValue(self: *const Self, comptime T: type, output: []T, value: *const Res if (T == bool) { output[0] = b; } else { - unreachable; + unreachable; // Wanted value may not be composed of booleans } }, .Int => |i| { @@ -179,7 +179,7 @@ fn readValue(self: *const Self, comptime T: type, output: []T, value: *const Res u16 => output[0] = i.uint16, u32 => output[0] = i.uint32, u64 => output[0] = i.uint64, - inline else => unreachable, + inline else => unreachable, // Wanted value may not be composed of ints } }, .Float => |f| { @@ -187,7 +187,7 @@ fn readValue(self: *const Self, comptime T: type, output: []T, value: *const Res f16 => output[0] = f.float16, f32 => output[0] = f.float32, f64 => output[0] = f.float64, - inline else => unreachable, + inline else => unreachable, // Wanted value may not be composed of floats } }, .Vector => |values| for (values, 0..) |v, i| self.readValue(T, output[i..], &v), diff --git a/src/opcodes.zig b/src/opcodes.zig index bbe1d4b..d475317 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -43,6 +43,7 @@ const BitOp = enum { BitFieldUExtract, BitReverse, BitwiseAnd, + BitwiseOr, BitwiseXor, Not, ShiftLeft, @@ -55,7 +56,15 @@ pub const OpCodeFunc = *const fn (std.mem.Allocator, SpvWord, *Runtime) RuntimeE pub const SetupDispatcher = block: { @setEvalBranchQuota(65535); break :block std.EnumMap(spv.SpvOp, OpCodeFunc).init(.{ + .BitCount = autoSetupConstant, + .BitFieldInsert = autoSetupConstant, + .BitFieldSExtract = autoSetupConstant, + .BitFieldUExtract = autoSetupConstant, + .BitReverse = autoSetupConstant, .Bitcast = autoSetupConstant, + .BitwiseAnd = autoSetupConstant, + .BitwiseOr = autoSetupConstant, + .BitwiseXor = autoSetupConstant, .Capability = opCapability, .CompositeConstruct = autoSetupConstant, .Constant = opConstant, @@ -101,6 +110,7 @@ pub const SetupDispatcher = block: { .MemberName = opMemberName, .MemoryModel = opMemoryModel, .Name = opName, + .Not = autoSetupConstant, .QuantizeToF16 = autoSetupConstant, .SConvert = autoSetupConstant, .SDiv = autoSetupConstant, @@ -111,6 +121,9 @@ pub const SetupDispatcher = block: { .SMod = autoSetupConstant, .SatConvertSToU = autoSetupConstant, .SatConvertUToS = autoSetupConstant, + .ShiftLeftLogical = autoSetupConstant, + .ShiftRightArithmetic = autoSetupConstant, + .ShiftRightLogical = autoSetupConstant, .Source = opSource, .SourceExtension = opSourceExtension, .TypeArray = opTypeArray, @@ -131,17 +144,6 @@ pub const SetupDispatcher = block: { .ULessThanEqual = autoSetupConstant, .UMod = autoSetupConstant, .Variable = opVariable, - .ShiftLeftLogical = autoSetupConstant, - .ShiftRightLogical = autoSetupConstant, - .ShiftRightArithmetic = autoSetupConstant, - .BitwiseAnd = autoSetupConstant, - .BitwiseXor = autoSetupConstant, - .Not = autoSetupConstant, - .BitFieldInsert = autoSetupConstant, - .BitFieldSExtract = autoSetupConstant, - .BitFieldUExtract = autoSetupConstant, - .BitReverse = autoSetupConstant, - .BitCount = autoSetupConstant, }); }; @@ -149,7 +151,15 @@ pub const RuntimeDispatcher = block: { @setEvalBranchQuota(65535); break :block std.EnumMap(spv.SpvOp, OpCodeFunc).init(.{ .AccessChain = opAccessChain, + .BitCount = BitEngine(.UInt, .BitCount).op, + .BitFieldInsert = BitEngine(.UInt, .BitFieldInsert).op, + .BitFieldSExtract = BitEngine(.SInt, .BitFieldSExtract).op, + .BitFieldUExtract = BitEngine(.UInt, .BitFieldUExtract).op, + .BitReverse = BitEngine(.UInt, .BitReverse).op, .Bitcast = opBitcast, + .BitwiseAnd = BitEngine(.UInt, .BitwiseAnd).op, + .BitwiseOr = BitEngine(.UInt, .BitwiseOr).op, + .BitwiseXor = BitEngine(.UInt, .BitwiseXor).op, .Branch = opBranch, .BranchConditional = opBranchConditional, .CompositeConstruct = opCompositeConstruct, @@ -183,6 +193,7 @@ pub const RuntimeDispatcher = block: { .INotEqual = CondEngine(.SInt, .NotEqual).op, .ISub = MathEngine(.SInt, .Sub).op, .Load = opLoad, + .Not = BitEngine(.UInt, .Not).op, .Return = opReturn, .ReturnValue = opReturnValue, .SConvert = ConversionEngine(.SInt, .SInt).op, @@ -192,6 +203,9 @@ pub const RuntimeDispatcher = block: { .SLessThan = CondEngine(.SInt, .Less).op, .SLessThanEqual = CondEngine(.SInt, .LessEqual).op, .SMod = MathEngine(.SInt, .Mod).op, + .ShiftLeftLogical = BitEngine(.UInt, .ShiftLeft).op, + .ShiftRightArithmetic = BitEngine(.SInt, .ShiftRightArithmetic).op, + .ShiftRightLogical = BitEngine(.UInt, .ShiftRight).op, .Store = opStore, .UConvert = ConversionEngine(.UInt, .UInt).op, .UDiv = MathEngine(.UInt, .Div).op, @@ -210,47 +224,83 @@ pub const RuntimeDispatcher = block: { }; fn BitEngine(comptime T: ValueType, comptime Op: BitOp) type { + if (T == .Float) @compileError("Invalid value type"); return struct { fn op(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { - _ = rt.it.skip(); + const target_type = (try rt.results[try rt.it.next()].getVariant()).Type; const value = try rt.results[try rt.it.next()].getValue(); const op1_value = try rt.results[try rt.it.next()].getValue(); - const op2_value = if (Op == .Not) null else try rt.results[try rt.it.next()].getValue(); + const op2_value: ?*Result.Value = switch (Op) { + .Not, .BitCount, .BitReverse => null, + else => try rt.results[try rt.it.next()].getValue(), + }; const size = sw: switch (target_type) { .Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type, - .Float => |f| if (T == .Float) f.bit_length else return RuntimeError.InvalidSpirV, - .Int => |i| if (T == .SInt or T == .UInt) i.bit_length else return RuntimeError.InvalidSpirV, + .Int => |i| i.bit_length, else => return RuntimeError.InvalidSpirV, }; const operator = struct { - fn operation(comptime TT: type, op1: TT, op2: TT) RuntimeError!TT { - return switch (Op) { - .Add => if (@typeInfo(TT) == .int) @addWithOverflow(op1, op2)[0] else op1 + op2, - .Sub => if (@typeInfo(TT) == .int) @subWithOverflow(op1, op2)[0] else op1 - op2, - .Mul => if (@typeInfo(TT) == .int) @mulWithOverflow(op1, op2)[0] else op1 * op2, - .Div => blk: { - if (op2 == 0) return RuntimeError.DivisionByZero; - break :blk if (@typeInfo(TT) == .int) @divTrunc(op1, op2) else op1 / op2; - }, - .Mod => blk: { - if (op2 == 0) return RuntimeError.DivisionByZero; - break :blk @mod(op1, op2); - }, - }; + inline fn bitMask(bits: u64) u64 { + return if (bits >= 32) ~@as(u64, 0) else (@as(u64, 0x1) << @intCast(bits)) - 1; } - fn process(bit_count: SpvWord, v: *Result.Value, op1_v: *const Result.Value, op2_v: *const Result.Value) RuntimeError!void { + inline fn bitInsert(comptime TT: type, base: TT, insert: TT, offset: u64, count: u64) TT { + const mask = bitMask(count) << @intCast(offset); + return @as(TT, @intCast((base & ~mask) | ((insert << @intCast(offset)) & mask))); + } + + inline fn bitExtract(comptime TT: type, v: TT, offset: TT, count: u64) TT { + return (v >> @intCast(offset)) & @as(TT, @intCast(bitMask(count))); + } + + fn operation(comptime TT: type, rt2: *Runtime, op1: TT, op2: ?TT) RuntimeError!TT { + switch (Op) { + .BitCount => return @bitSizeOf(TT), + .BitReverse => return @bitReverse(op1), + .Not => return ~op1, + else => {}, + } + return if (op2) |v2| + switch (Op) { + .BitFieldInsert => blk: { + const offset = try rt2.results[try rt2.it.next()].getValue(); + const count = try rt2.results[try rt2.it.next()].getValue(); + break :blk bitInsert(TT, op1, v2, offset.Int.uint64, count.Int.uint64); + }, + .BitFieldSExtract => blk: { + if (T == .UInt) return RuntimeError.InvalidSpirV; + const count = try rt2.results[try rt2.it.next()].getValue(); + break :blk bitExtract(TT, op1, v2, count.Int.uint64); + }, + .BitFieldUExtract => blk: { + if (T == .SInt) return RuntimeError.InvalidSpirV; + const count = try rt2.results[try rt2.it.next()].getValue(); + break :blk bitExtract(TT, op1, v2, count.Int.uint64); + }, + .BitwiseAnd => op1 & v2, + .BitwiseOr => op1 | v2, + .BitwiseXor => op1 ^ v2, + .ShiftLeft => op1 << @intCast(v2), + .ShiftRight, .ShiftRightArithmetic => op1 >> @intCast(v2), + else => return RuntimeError.InvalidSpirV, + } + else + RuntimeError.InvalidSpirV; + } + + fn process(rt2: *Runtime, bit_count: SpvWord, v: *Result.Value, op1_v: *const Result.Value, op2_v: ?*const Result.Value) RuntimeError!void { switch (bit_count) { inline 8, 16, 32, 64 => |i| { - if (i == 8 and T == .Float) { // No f8 - return RuntimeError.InvalidSpirV; - } (try getValuePrimitiveField(T, i, v)).* = try operation( getValuePrimitiveFieldType(T, i), + rt2, (try getValuePrimitiveField(T, i, @constCast(op1_v))).*, - (try getValuePrimitiveField(T, i, @constCast(op2_v))).*, + if (op2_v) |v2| + (try getValuePrimitiveField(T, i, @constCast(v2))).* + else + null, ); }, else => return RuntimeError.InvalidSpirV, @@ -259,9 +309,9 @@ fn BitEngine(comptime T: ValueType, comptime Op: BitOp) type { }; switch (value.*) { - .Float => if (T == .Float) try operator.process(size, value, op1_value, op2_value) else return RuntimeError.InvalidSpirV, - .Int => if (T == .SInt or T == .UInt) try operator.process(size, value, op1_value, op2_value) else return RuntimeError.InvalidSpirV, - .Vector => |vec| for (vec, op1_value.Vector, op2_value.Vector) |*val, op1_v, op2_v| try operator.process(size, val, &op1_v, &op2_v), + .Int => try operator.process(rt, size, value, op1_value, op2_value), + .Vector => |vec| for (vec, op1_value.Vector, 0..) |*val, op1_v, i| + try operator.process(rt, size, val, &op1_v, if (op2_value) |op2_v| &op2_v.Vector[i] else null), else => return RuntimeError.InvalidSpirV, } } diff --git a/test/bitwise.zig b/test/bitwise.zig new file mode 100644 index 0000000..a648e56 --- /dev/null +++ b/test/bitwise.zig @@ -0,0 +1,149 @@ +const std = @import("std"); +const root = @import("root.zig"); +const compileNzsl = root.compileNzsl; +const case = root.case; + +const Operations = enum { + BitwiseAnd, + BitwiseOr, + BitwiseXor, + ShiftLeft, + ShiftRight, + ShiftRightArithmetic, +}; + +test "Bitwise primitives" { + const allocator = std.testing.allocator; + const types = [_]type{ i32, u32 }; + var operations = std.EnumMap(Operations, []const u8).init(.{ + .BitwiseAnd = "&", + .BitwiseOr = "|", + .BitwiseXor = "^", + .ShiftLeft = "<<", + .ShiftRight = ">>", + .ShiftRightArithmetic = ">>", + }); + + var it = operations.iterator(); + while (it.next()) |op| { + inline for (types) |T| { + const op1: T = case.random(T); + const op2: T = @mod(case.random(T), @bitSizeOf(T)); + const expected = switch (op.key) { + .BitwiseAnd => op1 & op2, + .BitwiseOr => op1 | op2, + .BitwiseXor => op1 ^ op2, + .ShiftLeft => op1 << @intCast(op2), + .ShiftRight, .ShiftRightArithmetic => op1 >> @intCast(op2), + }; + + const shader = try std.fmt.allocPrint( + allocator, + \\ [nzsl_version("1.1")] + \\ [feature(float64)] + \\ module; + \\ + \\ struct FragOut + \\ {{ + \\ [location(0)] color: vec4[{s}] + \\ }} + \\ + \\ [entry(frag)] + \\ fn main() -> FragOut + \\ {{ + \\ let op1: {s} = {d}; + \\ let op2: {s} = {d}; + \\ let color = op1 {s} op2; + \\ + \\ let output: FragOut; + \\ output.color = vec4[{s}](color, color, color, color); + \\ return output; + \\ }} + , + .{ + @typeName(T), + @typeName(T), + op1, + @typeName(T), + op2, + op.value.*, + @typeName(T), + }, + ); + defer allocator.free(shader); + const code = try compileNzsl(allocator, shader); + defer allocator.free(code); + try case.expectOutput(T, 4, code, "color", &.{ expected, expected, expected, expected }); + } + } +} + +test "Bitwise vectors" { + const allocator = std.testing.allocator; + const types = [_]type{ i32, u32 }; + var operations = std.EnumMap(Operations, []const u8).init(.{ + .BitwiseAnd = "&", + .BitwiseOr = "|", + .BitwiseXor = "^", + .ShiftLeft = "<<", + .ShiftRight = ">>", + .ShiftRightArithmetic = ">>", + }); + + var it = operations.iterator(); + while (it.next()) |op| { + inline for (2..5) |L| { + inline for (types) |T| { + const op1: case.Vec(L, T) = .{ .val = case.random(@Vector(L, T)) }; + var op2: case.Vec(L, T) = .{ .val = case.random(@Vector(L, T)) }; + for (0..L) |i| op2.val[i] = @mod(op2.val[i], @bitSizeOf(T)); + const expected = switch (op.key) { + .BitwiseAnd => op1.val & op2.val, + .BitwiseOr => op1.val | op2.val, + .BitwiseXor => op1.val ^ op2.val, + .ShiftLeft => op1.val << @intCast(op2.val), + .ShiftRight, .ShiftRightArithmetic => op1.val >> @intCast(op2.val), + }; + + const shader = try std.fmt.allocPrint( + allocator, + \\ [nzsl_version("1.1")] + \\ [feature(float64)] + \\ module; + \\ + \\ struct FragOut + \\ {{ + \\ [location(0)] color: vec{d}[{s}] + \\ }} + \\ + \\ [entry(frag)] + \\ fn main() -> FragOut + \\ {{ + \\ let op1 = vec{d}[{s}]({f}); + \\ let op2 = vec{d}[{s}]({f}); + \\ + \\ let output: FragOut; + \\ output.color = op1 {s} op2; + \\ return output; + \\ }} + , + .{ + L, + @typeName(T), + L, + @typeName(T), + op1, + L, + @typeName(T), + op2, + op.value.*, + }, + ); + defer allocator.free(shader); + const code = try compileNzsl(allocator, shader); + defer allocator.free(code); + try case.expectOutput(T, L, code, "color", &@as([L]T, expected)); + } + } + } +} diff --git a/test/maths.zig b/test/maths.zig index d55b40d..d38f25d 100644 --- a/test/maths.zig +++ b/test/maths.zig @@ -11,19 +11,6 @@ const Operations = enum { Mod, }; -fn Vec(comptime len: usize, comptime T: type) type { - return struct { - const Self = @This(); - val: @Vector(len, T), - pub fn format(self: *const Self, w: *std.Io.Writer) std.Io.Writer.Error!void { - inline for (0..len) |i| { - try w.print("{d}", .{self.val[i]}); - if (i < len - 1) try w.writeAll(", "); - } - } - }; -} - // Tests all mathematical operation on all NZSL supported primitive types test "Maths primitives" { const allocator = std.testing.allocator; @@ -106,8 +93,8 @@ test "Maths vectors" { while (it.next()) |op| { inline for (2..5) |L| { inline for (types) |T| { - const base_color: Vec(L, T) = .{ .val = case.random(@Vector(L, T)) }; - const ratio: Vec(L, T) = .{ .val = case.random(@Vector(L, T)) }; + const base_color: case.Vec(L, T) = .{ .val = case.random(@Vector(L, T)) }; + const ratio: case.Vec(L, T) = .{ .val = case.random(@Vector(L, T)) }; const expected = switch (op.key) { .Add => if (@typeInfo(T) == .int) @addWithOverflow(base_color.val, ratio.val)[0] else base_color.val + ratio.val, .Sub => if (@typeInfo(T) == .int) @subWithOverflow(base_color.val, ratio.val)[0] else base_color.val - ratio.val, diff --git a/test/root.zig b/test/root.zig index 9b057e7..0f62a4a 100644 --- a/test/root.zig +++ b/test/root.zig @@ -53,11 +53,25 @@ pub const case = struct { inline else => unreachable, }; } + + pub fn Vec(comptime len: usize, comptime T: type) type { + return struct { + const Self = @This(); + val: @Vector(len, T), + pub fn format(self: *const Self, w: *std.Io.Writer) std.Io.Writer.Error!void { + inline for (0..len) |i| { + try w.print("{d}", .{self.val[i]}); + if (i < len - 1) try w.writeAll(", "); + } + } + }; + } }; test { std.testing.refAllDecls(@import("arrays.zig")); std.testing.refAllDecls(@import("basics.zig")); + std.testing.refAllDecls(@import("bitwise.zig")); std.testing.refAllDecls(@import("branching.zig")); std.testing.refAllDecls(@import("casts.zig")); std.testing.refAllDecls(@import("functions.zig"));