From 076abf5d6ae0ff130e55931faf50d407a3bbedc6 Mon Sep 17 00:00:00 2001 From: Kbz-8 Date: Fri, 16 Jan 2026 23:41:11 +0100 Subject: [PATCH] adding branching and conversions --- build.zig | 6 +- example/shader.nzsl | 37 ++-- example/shader.spv | Bin 716 -> 972 bytes example/shader.spv.txt | 70 ++++++ example/shader.spvasm | 50 ----- src/Result.zig | 38 +++- src/Runtime.zig | 10 +- src/WordIterator.zig | 3 + src/opcodes.zig | 477 +++++++++++++++++++++++++++++++---------- test/branching.zig | 100 +++++++++ test/root.zig | 1 + 11 files changed, 609 insertions(+), 183 deletions(-) create mode 100644 example/shader.spv.txt delete mode 100644 example/shader.spvasm create mode 100644 test/branching.zig diff --git a/build.zig b/build.zig index 242ad7a..f12cc2b 100644 --- a/build.zig +++ b/build.zig @@ -41,9 +41,13 @@ pub fn build(b: *std.Build) void { const run_example = b.addRunArtifact(example_exe); run_example.step.dependOn(&example_install.step); - const run_example_step = b.step("example", "Run the basic example"); + const run_example_step = b.step("example", "Run the example"); run_example_step.dependOn(&run_example.step); + const compile_shader_cmd = b.addSystemCommand(&[_][]const u8{ "nzslc", "example/shader.nzsl", "--compile=spv,spv-dis", "-o", "example" }); + const compile_shader_step = b.step("example-shader", "Compiles example's shader"); + compile_shader_step.dependOn(&compile_shader_cmd.step); + // Zig unit tests setup const nzsl = b.lazyDependency("NZSL", .{}) orelse return; diff --git a/example/shader.nzsl b/example/shader.nzsl index 0a888e4..0498f61 100644 --- a/example/shader.nzsl +++ b/example/shader.nzsl @@ -1,17 +1,24 @@ -[nzsl_version("1.1")] -module; + [nzsl_version("1.1")] + [feature(float64)] + module; + + struct FragOut + { + [location(0)] color: vec4[f32] + } -struct FragOut -{ - [location(0)] color: vec4[f32] -} + [entry(frag)] + fn main() -> FragOut + { + let op1: f64 = 0.0; + let op2: f64 = 9.0; + let color: f32; + if (op1 == op2) + color = f32(op1); + else + color = f32(op2); -[entry(frag)] -fn main() -> FragOut -{ - let ratio = vec4[f32](2.0, 2.0, 8.0, 0.25); - - let output: FragOut; - output.color = vec4[f32](4.0, 3.0, 2.0, 1.0) * ratio; - return output; -} + let output: FragOut; + output.color = vec4[f32](color, color, color, color); + return output; + } diff --git a/example/shader.spv b/example/shader.spv index c074b4005813c766870f50331298d6e04248f51c..b9de341f9678392c69fb5ff985299f6768a45455 100644 GIT binary patch literal 972 zcmZ9JOH0F05QV3$sqZ$mFP}AS)v5)d8xcjR8+E5{+!qlkXalYO0e_SK%?AjcFUf__ z3zL~Mb7pep-dwglYnHJOzhj3MpOR&jF~cv&7c3VQ$Q3PbGf~a#bub)J$jf|jO7a?} zC};EX=J&<*?BQuV8HS_Npm)?G?I{cIYCL$peScHG7h!6T;Y&D9ILu9B=DUAhYXPu!58`iz3fg*A3L6FE_<)bs5`u&7|y3OvFoM` z4*4(j*X23C>Zi6UTa%%e*v;8e%-o51?8v*5LjzjDKk8u~z07dE+ltZPdIQDeTrW7? z%M5z?PgBNxhg(y)J%x7?e5c#xZFXhMpbswfOWg(Uwm(`%o|{sAqiEq(w1 literal 716 zcmYk2%}c{T6vUs2so!mCZR=-=^x#1ddJ$2q=|#P%7w<(x3fe%b_agW|dl8)9W>c~- z*`0ZBXXhmh`YUEj_8p(tkrk_HffTd4g-+nyu$nDv-xfZ`lPQIo5UU4GCOU<-rrqLH zy^%e?=CesQy-WsY1LR05eYg4e?cwWF_an!ty=3oM?tAd9^6Hy+@}S-mSoal=ej!=Q z=Ra1SKpOcwE}dDTb5|bB!gTD3zEOquR3Rs|Q+?@M6|~ZlyE)B9C0}>t`4F!Soy+_1 zP0!KV;#=Aifn3~&*9BJT*KlmUcrSK5UGQd6zq!(Vvhz>Kc`k-8|BS`_zxuYc(@VA; zfh=&cwLO27%FG>sc>!}W+feS@qgM8HbtWg-=q1y&9Pro`=>Jh3_?|#N`bZi%_66Q% dcKqS=;X%Fwf%=p8IP{t}*U!vof13D0_y;&ZAou_P diff --git a/example/shader.spv.txt b/example/shader.spv.txt new file mode 100644 index 0000000..870c27b --- /dev/null +++ b/example/shader.spv.txt @@ -0,0 +1,70 @@ +Version 1.0 +Generator: 2560130 +Bound: 42 +Schema: 0 + OpCapability Capability(Shader) + OpCapability Capability(Float64) + OpMemoryModel AddressingModel(Logical) MemoryModel(GLSL450) + OpEntryPoint ExecutionModel(Fragment) %17 "main" %6 + OpExecutionMode %17 ExecutionMode(OriginUpperLeft) + OpSource SourceLanguage(NZSL) 4198400 + OpSourceExtension "Version: 1.1" + OpName %7 "FragOut" + OpMemberName %7 0 "color" + OpName %6 "color" + OpName %17 "main" + OpDecorate %6 Decoration(Location) 0 + OpMemberDecorate %7 0 Decoration(Offset) 0 + %1 = OpTypeVoid + %2 = OpTypeFunction %1 + %3 = OpTypeFloat 32 + %4 = OpTypeVector %3 4 + %5 = OpTypePointer StorageClass(Output) %4 + %7 = OpTypeStruct %4 + %8 = OpTypeFloat 64 + %9 = OpConstant %8 f64(0) +%10 = OpTypePointer StorageClass(Function) %8 +%11 = OpConstant %8 f64(9) +%12 = OpTypePointer StorageClass(Function) %3 +%13 = OpTypeBool +%14 = OpTypePointer StorageClass(Function) %7 +%15 = OpTypeInt 32 1 +%16 = OpConstant %15 i32(0) +%39 = OpTypePointer StorageClass(Function) %4 + %6 = OpVariable %5 StorageClass(Output) +%17 = OpFunction %1 FunctionControl(0) %2 +%18 = OpLabel +%19 = OpVariable %10 StorageClass(Function) +%20 = OpVariable %10 StorageClass(Function) +%21 = OpVariable %12 StorageClass(Function) +%22 = OpVariable %14 StorageClass(Function) + OpStore %19 %9 + OpStore %20 %11 +%26 = OpLoad %8 %19 +%27 = OpLoad %8 %20 +%28 = OpFOrdEqual %13 %26 %27 + OpSelectionMerge %23 SelectionControl(0) + OpBranchConditional %28 %24 %25 +%24 = OpLabel +%29 = OpLoad %8 %19 +%30 = OpFConvert %3 %29 + OpStore %21 %30 + OpBranch %23 +%25 = OpLabel +%31 = OpLoad %8 %20 +%32 = OpFConvert %3 %31 + OpStore %21 %32 + OpBranch %23 +%23 = OpLabel +%33 = OpLoad %3 %21 +%34 = OpLoad %3 %21 +%35 = OpLoad %3 %21 +%36 = OpLoad %3 %21 +%37 = OpCompositeConstruct %4 %33 %34 %35 %36 +%38 = OpAccessChain %39 %22 %16 + OpStore %38 %37 +%40 = OpLoad %7 %22 +%41 = OpCompositeExtract %4 %40 0 + OpStore %6 %41 + OpReturn + OpFunctionEnd diff --git a/example/shader.spvasm b/example/shader.spvasm deleted file mode 100644 index f746f7e..0000000 --- a/example/shader.spvasm +++ /dev/null @@ -1,50 +0,0 @@ -; SPIR-V -; Version: 1.0 -; Generator: SirLynix Nazara ShaderLang Compiler; 4226 -; Bound: 29 -; Schema: 0 - OpCapability Shader - OpMemoryModel Logical GLSL450 - OpEntryPoint Fragment %main "main" %color - OpExecutionMode %main OriginUpperLeft - OpSource NZSL 4198400 - OpSourceExtension "Version: 1.1" - OpName %FragOut "FragOut" - OpMemberName %FragOut 0 "color" - OpName %color "color" - OpName %main "main" - OpDecorate %color Location 0 - OpMemberDecorate %FragOut 0 Offset 0 - %void = OpTypeVoid - %2 = OpTypeFunction %void - %float = OpTypeFloat 32 - %v4float = OpTypeVector %float 4 -%_ptr_Output_v4float = OpTypePointer Output %v4float - %FragOut = OpTypeStruct %v4float - %float_2 = OpConstant %float 2 - %float_8 = OpConstant %float 8 - %float_0_25 = OpConstant %float 0.25 -%_ptr_Function_v4float = OpTypePointer Function %v4float -%_ptr_Function_FragOut = OpTypePointer Function %FragOut - %int = OpTypeInt 32 1 - %int_0 = OpConstant %int 0 - %float_4 = OpConstant %float 4 - %float_3 = OpConstant %float 3 - %float_1 = OpConstant %float 1 - %color = OpVariable %_ptr_Output_v4float Output - %main = OpFunction %void None %2 - %19 = OpLabel - %20 = OpVariable %_ptr_Function_v4float Function - %21 = OpVariable %_ptr_Function_FragOut Function - %22 = OpCompositeConstruct %v4float %float_2 %float_2 %float_8 %float_0_25 - OpStore %20 %22 - %23 = OpCompositeConstruct %v4float %float_4 %float_3 %float_2 %float_1 - %24 = OpLoad %v4float %20 - %25 = OpFMul %v4float %23 %24 - %26 = OpAccessChain %_ptr_Function_v4float %21 %int_0 - OpStore %26 %25 - %27 = OpLoad %FragOut %21 - %28 = OpCompositeExtract %v4float %27 0 - OpStore %color %28 - OpReturn - OpFunctionEnd diff --git a/src/Result.zig b/src/Result.zig index 5805c84..a14e0ef 100644 --- a/src/Result.zig +++ b/src/Result.zig @@ -238,9 +238,15 @@ variant: ?union(Variant) { }, Variable: struct { storage_class: spv.SpvStorageClass, + type_word: SpvWord, + type: Type, + value: Value, + }, + Constant: struct { + type_word: SpvWord, + type: Type, value: Value, }, - Constant: Value, Function: struct { source_location: usize, return_type: SpvWord, @@ -284,7 +290,7 @@ pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { }, else => {}, }, - .Constant => |*v| v.deinit(allocator), + .Constant => |*c| c.value.deinit(allocator), .Variable => |*v| v.value.deinit(allocator), //.AccessChain => |*a| a.value.deinit(allocator), else => {}, @@ -293,10 +299,26 @@ pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { self.decorations.deinit(allocator); } +pub fn getValueTypeWord(self: *Self) RuntimeError!SpvWord { + return switch (self.variant orelse return RuntimeError.InvalidSpirV) { + .Variable => |v| v.type_word, + .Constant => |c| c.type_word, + else => RuntimeError.InvalidSpirV, + }; +} + +pub fn getValueType(self: *Self) RuntimeError!Type { + return switch (self.variant orelse return RuntimeError.InvalidSpirV) { + .Variable => |v| v.type, + .Constant => |c| c.type, + else => RuntimeError.InvalidSpirV, + }; +} + pub fn getValue(self: *Self) RuntimeError!*Value { return switch (self.variant orelse return RuntimeError.InvalidSpirV) { .Variable => |*v| &v.value, - .Constant => |*v| v, + .Constant => |*c| &c.value, else => RuntimeError.InvalidSpirV, }; } @@ -343,10 +365,18 @@ pub fn dupe(self: *const Self, allocator: std.mem.Allocator) RuntimeError!Self { .Variable => |v| break :blk .{ .Variable = .{ .storage_class = v.storage_class, + .type_word = v.type_word, + .type = v.type, .value = try v.value.dupe(allocator), }, }, - .Constant => |c| break :blk .{ .Constant = try c.dupe(allocator) }, + .Constant => |c| break :blk .{ + .Constant = .{ + .type_word = c.type_word, + .type = c.type, + .value = try c.value.dupe(allocator), + }, + }, .Function => |f| break :blk .{ .Function = .{ .source_location = f.source_location, diff --git a/src/Runtime.zig b/src/Runtime.zig index 1cfb6b3..0a2b028 100644 --- a/src/Runtime.zig +++ b/src/Runtime.zig @@ -111,6 +111,7 @@ pub fn callEntryPoint(self: *Self, allocator: std.mem.Allocator, entry_point_ind } } + self.it.did_jump = false; // To reset function jump while (self.it.nextOrNull()) |opcode_data| { const word_count = ((opcode_data & (~spv.SpvOpCodeMask)) >> spv.SpvWordCountShift) - 1; const opcode = (opcode_data & spv.SpvOpCodeMask); @@ -121,8 +122,13 @@ pub fn callEntryPoint(self: *Self, allocator: std.mem.Allocator, entry_point_ind try pfn(allocator, word_count, self); } } - _ = it_tmp.skipN(word_count); - self.it = it_tmp; + if (!self.it.did_jump) { + _ = it_tmp.skipN(word_count); + self.it = it_tmp; + } else { + self.it.did_jump = false; + _ = it_tmp.skip(); + } } //@import("pretty").print(allocator, self.results, .{ diff --git a/src/WordIterator.zig b/src/WordIterator.zig index 43efecc..b087d0c 100644 --- a/src/WordIterator.zig +++ b/src/WordIterator.zig @@ -9,11 +9,13 @@ const Self = @This(); buffer: []const SpvWord, index: usize, +did_jump: bool, pub fn init(buffer: []const SpvWord) Self { return .{ .buffer = buffer, .index = 0, + .did_jump = false, }; } @@ -66,5 +68,6 @@ pub inline fn emitSourceLocation(self: *const Self) usize { pub inline fn jumpToSourceLocation(self: *Self, source_location: usize) bool { if (source_location > self.buffer.len) return false; self.index = source_location; + self.did_jump = true; return true; } diff --git a/src/opcodes.zig b/src/opcodes.zig index 61e69d2..970504f 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -13,18 +13,149 @@ const SpvByte = spv.SpvByte; const SpvWord = spv.SpvWord; const SpvBool = spv.SpvBool; -const MathType = enum {Float, +const ValueType = enum { + Float, SInt, UInt, }; -const MathOp = enum {Add, +const MathOp = enum { + Add, Sub, Mul, Div, Mod, }; +const CondOp = enum { + Equal, + NotEqual, + Greater, + GreaterEqual, + Less, + LessEqual, +}; + +fn CondEngine(comptime T: ValueType, comptime Op: CondOp) type { + return struct { + fn op(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + sw: switch ((rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type) { + .Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type, + .Bool => {}, + else => return RuntimeError.InvalidSpirV, + } + + const value = try rt.results[try rt.it.next()].getValue(); + const op1_result = &rt.results[try rt.it.next()]; + const op1_type = try op1_result.getValueTypeWord(); + const op1_value = try op1_result.getValue(); + const op2_value = try rt.results[try rt.it.next()].getValue(); + + const size = sw: switch ((rt.results[op1_type].variant orelse return RuntimeError.InvalidSpirV).Type) { + .Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type, + .Float => |f| if (T == .Float) f.bit_length else return RuntimeError.InvalidSpirV, + .Int => |i| if (T == .SInt or T == .UInt) i.bit_length else return RuntimeError.InvalidSpirV, + else => return RuntimeError.InvalidSpirV, + }; + + const operator = struct { + fn operation(comptime TT: type, op1: TT, op2: TT) RuntimeError!bool { + return switch (Op) { + .Equal => op1 == op2, + .NotEqual => op1 != op2, + .Greater => op1 > op2, + .GreaterEqual => op1 >= op2, + .Less => op1 < op2, + .LessEqual => op1 <= op2, + }; + } + + fn process(bit_count: SpvWord, v: *Result.Value, op1_v: *const Result.Value, op2_v: *const Result.Value) RuntimeError!void { + switch (bit_count) { + inline 8, 16, 32, 64 => |i| { + if (i == 8 and T == .Float) { // No f8 + return RuntimeError.InvalidSpirV; + } + v.Bool = try operation( + getValuePrimitiveFieldType(T, i), + (try getValuePrimitiveField(T, i, @constCast(op1_v))).*, + (try getValuePrimitiveField(T, i, @constCast(op2_v))).*, + ); + }, + else => return RuntimeError.InvalidSpirV, + } + } + }; + + switch (value.*) { + .Bool => try operator.process(size, value, op1_value, op2_value), + .Vector => |vec| for (vec, op1_value.Vector, op2_value.Vector) |*val, op1_v, op2_v| try operator.process(size, val, &op1_v, &op2_v), + else => return RuntimeError.InvalidSpirV, + } + } + }; +} + +fn ConversionEngine(comptime From: ValueType, comptime To: ValueType) type { + return struct { + fn op(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + const target_type = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type; + const value = try rt.results[try rt.it.next()].getValue(); + const op_result = &rt.results[try rt.it.next()]; + const op_type = try op_result.getValueTypeWord(); + const op_value = try op_result.getValue(); + + const from_size = sw: switch ((rt.results[op_type].variant orelse return RuntimeError.InvalidSpirV).Type) { + .Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type, + .Float => |f| if (From == .Float) f.bit_length else return RuntimeError.InvalidSpirV, + .Int => |i| if (From == .SInt or From == .UInt) i.bit_length else return RuntimeError.InvalidSpirV, + else => return RuntimeError.InvalidSpirV, + }; + + const to_size = sw: switch (target_type) { + .Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type, + .Float => |f| if (To == .Float) f.bit_length else return RuntimeError.InvalidSpirV, + .Int => |i| if (To == .SInt or To == .UInt) i.bit_length else return RuntimeError.InvalidSpirV, + else => return RuntimeError.InvalidSpirV, + }; + + const operator = struct { + fn process(from_bit_count: SpvWord, to_bit_count: SpvWord, to: *Result.Value, from: *Result.Value) RuntimeError!void { + switch (to_bit_count) { + inline 8, 16, 32, 64 => |i| { + if (i == 8 and To == .Float) { + return RuntimeError.InvalidSpirV; // No f8 + } + + const ToType = getValuePrimitiveFieldType(To, i); + (try getValuePrimitiveField(To, i, to)).* = std.math.lossyCast( + ToType, + switch (from_bit_count) { + inline 8, 16, 32, 64 => |j| blk: { + if (j == 8 and From == .Float) { + return RuntimeError.InvalidSpirV; // Same + } + break :blk (try getValuePrimitiveField(From, j, from)).*; + }, + else => return RuntimeError.InvalidSpirV, + }, + ); + }, + else => return RuntimeError.InvalidSpirV, + } + } + }; + + switch (value.*) { + .Float => if (To == .Float) try operator.process(from_size, to_size, value, op_value) else return RuntimeError.InvalidSpirV, + .Int => if (To == .SInt or To == .UInt) try operator.process(from_size, to_size, value, op_value) else return RuntimeError.InvalidSpirV, + .Vector => |vec| for (vec, op_value.Vector) |*val, *op_v| try operator.process(from_size, to_size, val, op_v), + else => return RuntimeError.InvalidSpirV, + } + } + }; +} + pub const OpCodeFunc = *const fn (std.mem.Allocator, SpvWord, *Runtime) RuntimeError!void; pub const SetupDispatcher = block: { @@ -36,14 +167,42 @@ pub const SetupDispatcher = block: { .Decorate = opDecorate, .EntryPoint = opEntryPoint, .ExecutionMode = opExecutionMode, + .FAdd = autoSetupConstant, + .FDiv = autoSetupConstant, + .FMod = autoSetupConstant, + .FMul = autoSetupConstant, + .FOrdEqual = autoSetupConstant, + .FOrdGreaterThan = autoSetupConstant, + .FOrdGreaterThanEqual = autoSetupConstant, + .FOrdLessThan = autoSetupConstant, + .FOrdLessThanEqual = autoSetupConstant, + .FOrdNotEqual = autoSetupConstant, + .FSub = autoSetupConstant, + .FUnordEqual = autoSetupConstant, + .FUnordGreaterThan = autoSetupConstant, + .FUnordGreaterThanEqual = autoSetupConstant, + .FUnordLessThan = autoSetupConstant, + .FUnordLessThanEqual = autoSetupConstant, + .FUnordNotEqual = autoSetupConstant, .Function = opFunction, .FunctionEnd = opFunctionEnd, + .IAdd = autoSetupConstant, + .IEqual = autoSetupConstant, + .IMul = autoSetupConstant, + .INotEqual = autoSetupConstant, + .ISub = autoSetupConstant, .Label = opLabel, .Load = autoSetupConstant, .MemberDecorate = opDecorateMember, .MemberName = opMemberName, .MemoryModel = opMemoryModel, .Name = opName, + .SDiv = autoSetupConstant, + .SGreaterThan = autoSetupConstant, + .SGreaterThanEqual = autoSetupConstant, + .SLessThan = autoSetupConstant, + .SLessThanEqual = autoSetupConstant, + .SMod = autoSetupConstant, .Source = opSource, .SourceExtension = opSourceExtension, .TypeBool = opTypeBool, @@ -55,19 +214,26 @@ pub const SetupDispatcher = block: { .TypeStruct = opTypeStruct, .TypeVector = opTypeVector, .TypeVoid = opTypeVoid, - .Variable = opVariable, - .FAdd = autoSetupConstant, - .FDiv = autoSetupConstant, - .FMod = autoSetupConstant, - .FMul = autoSetupConstant, - .FSub = autoSetupConstant, - .IAdd = autoSetupConstant, - .IMul = autoSetupConstant, - .ISub = autoSetupConstant, - .SDiv = autoSetupConstant, - .SMod = autoSetupConstant, .UDiv = autoSetupConstant, + .UGreaterThan = autoSetupConstant, + .UGreaterThanEqual = autoSetupConstant, + .ULessThan = autoSetupConstant, + .ULessThanEqual = autoSetupConstant, .UMod = autoSetupConstant, + .Variable = opVariable, + + .ConvertFToU = autoSetupConstant, + .ConvertFToS = autoSetupConstant, + .ConvertSToF = autoSetupConstant, + .ConvertUToF = autoSetupConstant, + .UConvert = autoSetupConstant, + .SConvert = autoSetupConstant, + .FConvert = autoSetupConstant, + .QuantizeToF16 = autoSetupConstant, + .ConvertPtrToU = autoSetupConstant, + .SatConvertSToU = autoSetupConstant, + .SatConvertUToS = autoSetupConstant, + .ConvertUToPtr = autoSetupConstant, }); }; @@ -75,26 +241,122 @@ pub const RuntimeDispatcher = block: { @setEvalBranchQuota(65535); break :block std.EnumMap(spv.SpvOp, OpCodeFunc).init(.{ .AccessChain = opAccessChain, + .Branch = opBranch, + .BranchConditional = opBranchConditional, .CompositeConstruct = opCompositeConstruct, .CompositeExtract = opCompositeExtract, - .FAdd = maths(.Float, .Add).op, - .FDiv = maths(.Float, .Div).op, - .FMod = maths(.Float, .Mod).op, - .FMul = maths(.Float, .Mul).op, - .FSub = maths(.Float, .Sub).op, - .IAdd = maths(.SInt, .Add).op, - .IMul = maths(.SInt, .Mul).op, - .ISub = maths(.SInt, .Sub).op, + .FAdd = MathEngine(.Float, .Add).op, + .FDiv = MathEngine(.Float, .Div).op, + .FMod = MathEngine(.Float, .Mod).op, + .FMul = MathEngine(.Float, .Mul).op, + .FOrdEqual = CondEngine(.Float, .Equal).op, + .FOrdGreaterThan = CondEngine(.Float, .Greater).op, + .FOrdGreaterThanEqual = CondEngine(.Float, .GreaterEqual).op, + .FOrdLessThan = CondEngine(.Float, .Less).op, + .FOrdLessThanEqual = CondEngine(.Float, .LessEqual).op, + .FOrdNotEqual = CondEngine(.Float, .NotEqual).op, + .FSub = MathEngine(.Float, .Sub).op, + .FUnordEqual = CondEngine(.Float, .Equal).op, + .FUnordGreaterThan = CondEngine(.Float, .Greater).op, + .FUnordGreaterThanEqual = CondEngine(.Float, .GreaterEqual).op, + .FUnordLessThan = CondEngine(.Float, .Less).op, + .FUnordLessThanEqual = CondEngine(.Float, .LessEqual).op, + .FUnordNotEqual = CondEngine(.Float, .NotEqual).op, + .IAdd = MathEngine(.SInt, .Add).op, + .IEqual = CondEngine(.SInt, .Equal).op, + .IMul = MathEngine(.SInt, .Mul).op, + .INotEqual = CondEngine(.SInt, .NotEqual).op, + .ISub = MathEngine(.SInt, .Sub).op, .Load = opLoad, .Return = opReturn, - .SDiv = maths(.SInt, .Div).op, - .SMod = maths(.SInt, .Mod).op, + .SDiv = MathEngine(.SInt, .Div).op, + .SGreaterThan = CondEngine(.SInt, .Greater).op, + .SGreaterThanEqual = CondEngine(.SInt, .GreaterEqual).op, + .SLessThan = CondEngine(.SInt, .Less).op, + .SLessThanEqual = CondEngine(.SInt, .LessEqual).op, + .SMod = MathEngine(.SInt, .Mod).op, .Store = opStore, - .UDiv = maths(.UInt, .Div).op, - .UMod = maths(.UInt, .Mod).op, + .UDiv = MathEngine(.UInt, .Div).op, + .UGreaterThan = CondEngine(.UInt, .Greater).op, + .UGreaterThanEqual = CondEngine(.UInt, .GreaterEqual).op, + .ULessThan = CondEngine(.UInt, .Less).op, + .ULessThanEqual = CondEngine(.UInt, .LessEqual).op, + .UMod = MathEngine(.UInt, .Mod).op, + + .ConvertFToU = ConversionEngine(.Float, .UInt).op, + .ConvertFToS = ConversionEngine(.Float, .SInt).op, + .ConvertSToF = ConversionEngine(.SInt, .Float).op, + .ConvertUToF = ConversionEngine(.UInt, .Float).op, + .UConvert = ConversionEngine(.UInt, .UInt).op, + .SConvert = ConversionEngine(.SInt, .SInt).op, + .FConvert = ConversionEngine(.Float, .Float).op, + //.QuantizeToF16 = autoSetupConstant, + //.ConvertPtrToU = autoSetupConstant, + //.SatConvertSToU = autoSetupConstant, + //.SatConvertUToS = autoSetupConstant, + //.ConvertUToPtr = autoSetupConstant, }); }; +fn MathEngine(comptime T: ValueType, comptime Op: MathOp) type { + return struct { + fn op(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + const target_type = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type; + const value = try rt.results[try rt.it.next()].getValue(); + const op1_value = try rt.results[try rt.it.next()].getValue(); + const op2_value = try rt.results[try rt.it.next()].getValue(); + + const size = sw: switch (target_type) { + .Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type, + .Float => |f| if (T == .Float) f.bit_length else return RuntimeError.InvalidSpirV, + .Int => |i| if (T == .SInt or T == .UInt) i.bit_length else return RuntimeError.InvalidSpirV, + else => return RuntimeError.InvalidSpirV, + }; + + const operator = struct { + fn operation(comptime TT: type, op1: TT, op2: TT) RuntimeError!TT { + return switch (Op) { + .Add => if (@typeInfo(TT) == .int) @addWithOverflow(op1, op2)[0] else op1 + op2, + .Sub => if (@typeInfo(TT) == .int) @subWithOverflow(op1, op2)[0] else op1 - op2, + .Mul => if (@typeInfo(TT) == .int) @mulWithOverflow(op1, op2)[0] else op1 * op2, + .Div => blk: { + if (op2 == 0) return RuntimeError.DivisionByZero; + break :blk if (@typeInfo(TT) == .int) @divTrunc(op1, op2) else op1 / op2; + }, + .Mod => blk: { + if (op2 == 0) return RuntimeError.DivisionByZero; + break :blk @mod(op1, op2); + }, + }; + } + + fn process(bit_count: SpvWord, v: *Result.Value, op1_v: *const Result.Value, op2_v: *const Result.Value) RuntimeError!void { + switch (bit_count) { + inline 8, 16, 32, 64 => |i| { + if (i == 8 and T == .Float) { // No f8 + return RuntimeError.InvalidSpirV; + } + (try getValuePrimitiveField(T, i, v)).* = try operation( + getValuePrimitiveFieldType(T, i), + (try getValuePrimitiveField(T, i, @constCast(op1_v))).*, + (try getValuePrimitiveField(T, i, @constCast(op2_v))).*, + ); + }, + else => return RuntimeError.InvalidSpirV, + } + } + }; + + switch (value.*) { + .Float => if (T == .Float) try operator.process(size, value, op1_value, op2_value) else return RuntimeError.InvalidSpirV, + .Int => if (T == .SInt or T == .UInt) try operator.process(size, value, op1_value, op2_value) else return RuntimeError.InvalidSpirV, + .Vector => |vec| for (vec, op1_value.Vector, op2_value.Vector) |*val, op1_v, op2_v| try operator.process(size, val, &op1_v, &op2_v), + else => return RuntimeError.InvalidSpirV, + } + } + }; +} + fn addDecoration(allocator: std.mem.Allocator, rt: *Runtime, target: SpvWord, decoration_type: spv.SpvDecoration, member: ?SpvWord) RuntimeError!void { var decoration = rt.mod.results[target].decorations.addOne(allocator) catch return RuntimeError.OutOfMemory; decoration.rtype = decoration_type; @@ -156,81 +418,29 @@ fn copyValue(dst: *Result.Value, src: *const Result.Value) void { dst.* = src.*; } } -fn maths(comptime T: MathType, comptime Op: MathOp) type { - return struct { - fn op(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { - const target_type = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Type; - const value = try rt.results[try rt.it.next()].getValue(); - const op1_value = try rt.results[try rt.it.next()].getValue(); - const op2_value = try rt.results[try rt.it.next()].getValue(); - const size = sw: switch (target_type) { - .Vector => |v| continue :sw (rt.results[v.components_type_word].variant orelse return RuntimeError.InvalidSpirV).Type, - .Float => |f| if (T == .Float) f.bit_length else return RuntimeError.InvalidSpirV, - .Int => |i| if (T == .SInt or T == .UInt) i.bit_length else return RuntimeError.InvalidSpirV, - else => return RuntimeError.InvalidSpirV, - }; +fn getValuePrimitiveField(comptime T: ValueType, comptime BitCount: SpvWord, v: *Result.Value) RuntimeError!*getValuePrimitiveFieldType(T, BitCount) { + return switch (T) { + .Float => switch (BitCount) { + inline 16, 32, 64 => |i| &@field(v.Float, std.fmt.comptimePrint("float{}", .{i})), + else => return RuntimeError.InvalidSpirV, + }, + .SInt => switch (BitCount) { + inline 8, 16, 32, 64 => |i| &@field(v.Int, std.fmt.comptimePrint("sint{}", .{i})), + else => return RuntimeError.InvalidSpirV, + }, + .UInt => switch (BitCount) { + inline 8, 16, 32, 64 => |i| &@field(v.Int, std.fmt.comptimePrint("uint{}", .{i})), + else => return RuntimeError.InvalidSpirV, + }, + }; +} - const operator = struct { - fn operation(comptime TT: type, op1: TT, op2: TT) RuntimeError!TT { - return switch (Op) { - .Add => if (@typeInfo(TT) == .int) @addWithOverflow(op1, op2)[0] else op1 + op2, - .Sub => if (@typeInfo(TT) == .int) @subWithOverflow(op1, op2)[0] else op1 - op2, - .Mul => if (@typeInfo(TT) == .int) @mulWithOverflow(op1, op2)[0] else op1 * op2, - .Div => blk: { - if (op2 == 0) return RuntimeError.DivisionByZero; - break :blk if (@typeInfo(TT) == .int) @divTrunc(op1, op2) else op1 / op2; - }, - .Mod => blk: { - if (op2 == 0) return RuntimeError.DivisionByZero; - break :blk @mod(op1, op2); - }, - }; - } - - fn process(bit_count: SpvWord, v: *Result.Value, op1_v: *const Result.Value, op2_v: *const Result.Value) RuntimeError!void { - switch (T) { - .Float => switch (bit_count) { - inline 16, 32, 64 => |i| @field(v.Float, std.fmt.comptimePrint("float{}", .{i})) = try operation( - @Type(.{ .float = .{ .bits = i } }), - @field(op1_v.Float, std.fmt.comptimePrint("float{}", .{i})), - @field(op2_v.Float, std.fmt.comptimePrint("float{}", .{i})), - ), - else => return RuntimeError.InvalidSpirV, - }, - .SInt => switch (bit_count) { - inline 8, 16, 32, 64 => |i| @field(v.Int, std.fmt.comptimePrint("sint{}", .{i})) = try operation( - @Type(.{ .int = .{ - .signedness = .signed, - .bits = i, - } }), - @field(op1_v.Int, std.fmt.comptimePrint("sint{}", .{i})), - @field(op2_v.Int, std.fmt.comptimePrint("sint{}", .{i})), - ), - else => return RuntimeError.InvalidSpirV, - }, - .UInt => switch (bit_count) { - inline 8, 16, 32, 64 => |i| @field(v.Int, std.fmt.comptimePrint("uint{}", .{i})) = try operation( - @Type(.{ .int = .{ - .signedness = .unsigned, - .bits = i, - } }), - @field(op1_v.Int, std.fmt.comptimePrint("uint{}", .{i})), - @field(op2_v.Int, std.fmt.comptimePrint("uint{}", .{i})), - ), - else => return RuntimeError.InvalidSpirV, - }, - } - } - }; - - switch (value.*) { - .Float => if (T == .Float) try operator.process(size, value, op1_value, op2_value) else return RuntimeError.InvalidSpirV, - .Int => if (T == .SInt or T == .UInt) try operator.process(size, value, op1_value, op2_value) else return RuntimeError.InvalidSpirV, - .Vector => |vec| for (vec, op1_value.Vector, op2_value.Vector) |*val, op1_v, op2_v| try operator.process(size, val, &op1_v, &op2_v), - else => return RuntimeError.InvalidSpirV, - } - } +fn getValuePrimitiveFieldType(comptime T: ValueType, comptime BitCount: SpvWord) type { + return switch (T) { + .Float => std.meta.Float(BitCount), + .SInt => std.meta.Int(.signed, BitCount), + .UInt => std.meta.Int(.unsigned, BitCount), }; } @@ -246,7 +456,7 @@ fn opAccessChain(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) Runtim for (0..index_count) |_| { const member = &rt.results[try rt.it.next()]; const member_value = switch (member.variant orelse return RuntimeError.InvalidSpirV) { - .Constant => |c| &c, + .Constant => |c| &c.value, .Variable => |v| &v.value, else => return RuntimeError.InvalidSpirV, }; @@ -281,6 +491,31 @@ fn opAccessChain(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) Runtim }; } +fn opBranch(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + const id = try rt.it.next(); + _ = rt.it.jumpToSourceLocation(switch (rt.results[id].variant orelse return RuntimeError.InvalidSpirV) { + .Label => |l| l.source_location, + else => return RuntimeError.InvalidSpirV, + }); +} + +fn opBranchConditional(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + const cond_value = try rt.results[try rt.it.next()].getValue(); + const true_branch = switch (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV) { + .Label => |l| l.source_location, + else => return RuntimeError.InvalidSpirV, + }; + const false_branch = switch (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV) { + .Label => |l| l.source_location, + else => return RuntimeError.InvalidSpirV, + }; + if (cond_value.Bool) { + _ = rt.it.jumpToSourceLocation(true_branch); + } else { + _ = rt.it.jumpToSourceLocation(false_branch); + } +} + fn opCapability(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { rt.mod.capabilities.insert(try rt.it.nextAs(spv.SpvCapability)); } @@ -290,33 +525,40 @@ fn opCompositeConstruct(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) const id = try rt.it.next(); const index_count = word_count - 2; - const target = (rt.results[id].variant orelse return RuntimeError.InvalidSpirV).Constant.getCompositeDataOrNull() orelse return RuntimeError.InvalidSpirV; + const target = (rt.results[id].variant orelse return RuntimeError.InvalidSpirV).Constant.value.getCompositeDataOrNull() orelse return RuntimeError.InvalidSpirV; for (target[0..index_count]) |*elem| { - const value = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Constant; + const value = (rt.results[try rt.it.next()].variant orelse return RuntimeError.InvalidSpirV).Constant.value; elem.* = value; } } fn opCompositeExtract(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void { - _ = rt.it.skip(); + const res_type = try rt.it.next(); const id = try rt.it.next(); const composite_id = try rt.it.next(); const index_count = word_count - 3; - var composite = (rt.results[composite_id].variant orelse return RuntimeError.InvalidSpirV).Constant; + var composite = (rt.results[composite_id].variant orelse return RuntimeError.InvalidSpirV).Constant.value; for (0..index_count) |_| { const member_id = try rt.it.next(); composite = (composite.getCompositeDataOrNull() orelse return RuntimeError.InvalidSpirV)[member_id]; } rt.results[id].variant = .{ - .Constant = try composite.dupe(allocator), + .Constant = .{ + .type_word = res_type, + .type = switch (rt.results[res_type].variant orelse return RuntimeError.InvalidSpirV) { + .Type => |t| @as(Result.Type, t), + else => return RuntimeError.InvalidSpirV, + }, + .value = try composite.dupe(allocator), + }, }; } fn opConstant(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void { const target = try setupConstant(allocator, rt); // No check on null and sizes, absolute trust in this shit - switch (target.variant.?.Constant) { + switch (target.variant.?.Constant.value) { .Int => |*i| { if (word_count - 2 != 1) { const low = @as(u64, try rt.it.next()); @@ -442,13 +684,13 @@ fn opLoad(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { copyValue( switch (rt.results[id].variant orelse return RuntimeError.InvalidSpirV) { .Variable => |*v| &v.value, - .Constant => |*c| c, + .Constant => |*c| &c.value, .AccessChain => |*a| &a.value, else => return RuntimeError.InvalidSpirV, }, switch (rt.results[ptr_id].variant orelse return RuntimeError.InvalidSpirV) { .Variable => |v| &v.value, - .Constant => |c| &c, + .Constant => |c| &c.value, .AccessChain => |a| &a.value, else => return RuntimeError.InvalidSpirV, }, @@ -540,13 +782,13 @@ fn opStore(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { copyValue( switch (rt.results[ptr_id].variant orelse return RuntimeError.InvalidSpirV) { .Variable => |*v| &v.value, - .Constant => |*c| c, + .Constant => |*c| &c.value, .AccessChain => |*a| &a.value, else => return RuntimeError.InvalidSpirV, }, switch (rt.results[val_id].variant orelse return RuntimeError.InvalidSpirV) { .Variable => |v| &v.value, - .Constant => |c| &c, + .Constant => |c| &c.value, .AccessChain => |a| &a.value, else => return RuntimeError.InvalidSpirV, }, @@ -716,6 +958,11 @@ fn opVariable(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) R target.variant = .{ .Variable = .{ .storage_class = storage_class, + .type_word = var_type, + .type = switch (resolved.variant orelse return RuntimeError.InvalidSpirV) { + .Type => |t| @as(Result.Type, t), + else => return RuntimeError.InvalidSpirV, + }, .value = try Result.initValue(allocator, member_count, rt.mod.results, resolved), }, }; @@ -761,7 +1008,15 @@ fn setupConstant(allocator: std.mem.Allocator, rt: *Runtime) RuntimeError!*Resul if (member_count == 0) { return RuntimeError.InvalidSpirV; } - target.variant = .{ .Constant = try Result.initValue(allocator, member_count, rt.mod.results, resolved) }; + target.variant = .{ + .Constant = .{ + .value = try Result.initValue(allocator, member_count, rt.mod.results, resolved), + .type_word = res_type, + .type = switch (resolved.variant orelse return RuntimeError.InvalidSpirV) { + .Type => |t| @as(Result.Type, t), + else => return RuntimeError.InvalidSpirV, + }, + }, + }; return target; } - diff --git a/test/branching.zig b/test/branching.zig new file mode 100644 index 0000000..cecc6d3 --- /dev/null +++ b/test/branching.zig @@ -0,0 +1,100 @@ +const std = @import("std"); +const root = @import("root.zig"); +const compileNzsl = root.compileNzsl; +const case = root.case; + +const Operations = enum { + Equal, + NotEqual, + Greater, + GreaterEqual, + Less, + LessEqual, +}; + +test "Simple branching" { + const allocator = std.testing.allocator; + const types = [_]type{ f32, f64, i32, u32 }; + var operations = std.EnumMap(Operations, []const u8).init(.{ + .Equal = "==", + .NotEqual = "!=", + .Greater = ">", + .GreaterEqual = ">=", + .Less = "<", + .LessEqual = "<=", + }); + + var it = operations.iterator(); + while (it.next()) |op| { + inline for (types) |T| { + const values = [_][2]T{ + [2]T{ std.math.lossyCast(T, 0), std.math.lossyCast(T, 9) }, + [2]T{ std.math.lossyCast(T, 1), std.math.lossyCast(T, 8) }, + [2]T{ std.math.lossyCast(T, 2), std.math.lossyCast(T, 7) }, + [2]T{ std.math.lossyCast(T, 3), std.math.lossyCast(T, 6) }, + [2]T{ std.math.lossyCast(T, 4), std.math.lossyCast(T, 5) }, + [2]T{ std.math.lossyCast(T, 5), std.math.lossyCast(T, 4) }, + [2]T{ std.math.lossyCast(T, 6), std.math.lossyCast(T, 3) }, + [2]T{ std.math.lossyCast(T, 7), std.math.lossyCast(T, 2) }, + [2]T{ std.math.lossyCast(T, 8), std.math.lossyCast(T, 1) }, + [2]T{ std.math.lossyCast(T, 9), std.math.lossyCast(T, 0) }, + [2]T{ std.math.lossyCast(T, 0), std.math.lossyCast(T, 0) }, + }; + for (values) |v| { + const op1: T = v[0]; + const op2: T = v[1]; + const expected = switch (op.key) { + .Equal => if (op1 == op2) op1 else op2, + .NotEqual => if (op1 != op2) op1 else op2, + .Greater => if (op1 > op2) op1 else op2, + .GreaterEqual => if (op1 >= op2) op1 else op2, + .Less => if (op1 < op2) op1 else op2, + .LessEqual => if (op1 <= op2) op1 else op2, + }; + + const shader = try std.fmt.allocPrint( + allocator, + \\ [nzsl_version("1.1")] + \\ [feature(float64)] + \\ module; + \\ + \\ struct FragOut + \\ {{ + \\ [location(0)] color: vec4[{s}] + \\ }} + \\ + \\ [entry(frag)] + \\ fn main() -> FragOut + \\ {{ + \\ let op1 = {s}({d}); + \\ let op2 = {s}({d}); + \\ let color: {s}; + \\ if (op1 {s} op2) + \\ color = op1; + \\ else + \\ color = op2; + \\ + \\ let output: FragOut; + \\ output.color = vec4[{s}](color, color, color, color); + \\ return output; + \\ }} + , + .{ + @typeName(T), + @typeName(T), + op1, + @typeName(T), + op2, + @typeName(T), + op.value.*, + @typeName(T), + }, + ); + defer allocator.free(shader); + const code = try compileNzsl(allocator, shader); + defer allocator.free(code); + try case.expectOutput(T, 4, code, "color", &.{ expected, expected, expected, expected }); + } + } + } +} diff --git a/test/root.zig b/test/root.zig index 2f3266d..d464ee1 100644 --- a/test/root.zig +++ b/test/root.zig @@ -57,5 +57,6 @@ pub const case = struct { test { std.testing.refAllDecls(@import("basics.zig")); + std.testing.refAllDecls(@import("branching.zig")); std.testing.refAllDecls(@import("maths.zig")); }