diff --git a/src/Runtime.zig b/src/Runtime.zig index 9afbca9..8651497 100644 --- a/src/Runtime.zig +++ b/src/Runtime.zig @@ -55,6 +55,9 @@ current_parameter_index: SpvWord, current_function: ?*Result, function_stack: std.ArrayList(Function), +current_label: ?SpvWord, +previous_label: ?SpvWord, + specialization_constants: std.AutoHashMapUnmanaged(u32, []const u8), pub fn init(allocator: std.mem.Allocator, module: *Module) RuntimeError!Self { @@ -71,6 +74,8 @@ pub fn init(allocator: std.mem.Allocator, module: *Module) RuntimeError!Self { .current_parameter_index = 0, .current_function = null, .function_stack = .empty, + .current_label = null, + .previous_label = null, .specialization_constants = .empty, }; } @@ -257,4 +262,6 @@ pub fn flushDescriptorSets(self: *const Self, allocator: std.mem.Allocator) Runt fn reset(self: *Self) void { self.function_stack.clearRetainingCapacity(); self.current_function = null; + self.current_label = null; + self.previous_label = null; } diff --git a/src/opcodes.zig b/src/opcodes.zig index dac27e5..970904b 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -276,6 +276,7 @@ pub fn initRuntimeDispatcher() void { runtime_dispatcher[@intFromEnum(spv.SpvOp.IsNan)] = CondEngine(.Float, .IsNan).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.IsNormal)] = CondEngine(.Float, .IsNan).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.Kill)] = opKill; + runtime_dispatcher[@intFromEnum(spv.SpvOp.Label)] = opLabel; runtime_dispatcher[@intFromEnum(spv.SpvOp.Load)] = opLoad; runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalAnd)] = CondEngine(.Bool, .LogicalAnd).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalEqual)] = CondEngine(.Bool, .LogicalEqual).op; @@ -286,6 +287,7 @@ pub fn initRuntimeDispatcher() void { runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesScalar)] = MathEngine(.Float, .MatrixTimesScalar, false).op; // TODO runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesVector)] = MathEngine(.Float, .MatrixTimesVector, false).op; // TODO runtime_dispatcher[@intFromEnum(spv.SpvOp.Not)] = BitEngine(.UInt, .Not).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.Phi)] = opPhi; runtime_dispatcher[@intFromEnum(spv.SpvOp.Return)] = opReturn; runtime_dispatcher[@intFromEnum(spv.SpvOp.ReturnValue)] = opReturnValue; runtime_dispatcher[@intFromEnum(spv.SpvOp.SConvert)] = ConversionEngine(.SInt, .SInt).op; @@ -1395,6 +1397,7 @@ fn opBitcast(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeErro fn opBranch(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { const id = try rt.it.next(); + rt.previous_label = rt.current_label; _ = rt.it.jumpToSourceLocation(switch ((try rt.results[id].getVariant()).*) { .Label => |l| l.source_location, else => return RuntimeError.InvalidSpirV, @@ -1411,6 +1414,7 @@ fn opBranchConditional(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeEr .Label => |l| l.source_location, else => return RuntimeError.InvalidSpirV, }; + rt.previous_label = rt.current_label; if (cond_value.Bool) { _ = rt.it.jumpToSourceLocation(true_branch); } else { @@ -2064,11 +2068,14 @@ fn opFunctionParameter(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeEr fn opLabel(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { const id = try rt.it.next(); - rt.results[id].variant = .{ - .Label = .{ - .source_location = rt.it.emitSourceLocation() - 2, // Original label location - }, - }; + rt.current_label = id; + if (rt.results[id].variant == null) { + rt.results[id].variant = .{ + .Label = .{ + .source_location = rt.it.emitSourceLocation() - 2, // Original label location + }, + }; + } } fn opKill(_: std.mem.Allocator, _: SpvWord, _: *Runtime) RuntimeError!void { @@ -2126,6 +2133,25 @@ fn opName(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) Runti result.name = try readStringN(allocator, &rt.it, word_count - 1); } +fn opPhi(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void { + _ = try rt.it.next(); // result type + const id = try rt.it.next(); + + const predecessor = rt.previous_label orelse return RuntimeError.InvalidSpirV; + const pair_count = @divExact(word_count - 2, 2); + + for (0..pair_count) |_| { + const value_id = try rt.it.next(); + const parent_label_id = try rt.it.next(); + + if (parent_label_id == predecessor) { + try copyValue(try rt.results[id].getValue(), try rt.results[value_id].getValue()); + return; + } + } + return RuntimeError.InvalidSpirV; +} + fn opReturn(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { _ = rt.function_stack.pop(); if (rt.function_stack.getLastOrNull()) |function| {