diff --git a/src/Runtime.zig b/src/Runtime.zig index fb2aa16..986ee6c 100644 --- a/src/Runtime.zig +++ b/src/Runtime.zig @@ -33,6 +33,12 @@ pub const RuntimeError = error{ Unknown, }; +pub const SpecializationEntry = struct { + id: SpvWord, + offset: usize, + size: usize, +}; + pub const Function = struct { source_location: usize, result: *Result, @@ -49,6 +55,8 @@ current_parameter_index: SpvWord, current_function: ?*Result, function_stack: std.ArrayList(Function), +specialization_constants: std.AutoHashMapUnmanaged(u32, []const u8), + pub fn init(allocator: std.mem.Allocator, module: *Module) RuntimeError!Self { return .{ .mod = module, @@ -63,6 +71,7 @@ pub fn init(allocator: std.mem.Allocator, module: *Module) RuntimeError!Self { .current_parameter_index = 0, .current_function = null, .function_stack = .empty, + .specialization_constants = .empty, }; } @@ -72,6 +81,16 @@ pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { } allocator.free(self.results); self.function_stack.deinit(allocator); + var it = self.specialization_constants.iterator(); + while (it.next()) |entry| { + allocator.free(entry.value_ptr.*); + } + self.specialization_constants.deinit(allocator); +} + +pub fn addSpecializationInfo(self: *Self, allocator: std.mem.Allocator, entry: SpecializationEntry, data: []const u8) RuntimeError!void { + const slice = allocator.dupe(u8, data[entry.offset .. entry.offset + entry.size]) catch return RuntimeError.OutOfMemory; + self.specialization_constants.put(allocator, entry.id, slice) catch return RuntimeError.OutOfMemory; } pub fn getEntryPointByName(self: *const Self, name: []const u8) error{NotFound}!SpvWord { @@ -81,6 +100,7 @@ pub fn getEntryPointByName(self: *const Self, name: []const u8) error{NotFound}! for (0..@min(name.len, entry_point.name.len)) |j| { if (name[j] != entry_point.name[j]) break :blk false; } + if (entry_point.name.len != name.len and entry_point.name[name.len] != 0) break :blk false; break :blk true; }) return @intCast(i); } @@ -121,6 +141,12 @@ pub fn callEntryPoint(self: *Self, allocator: std.mem.Allocator, entry_point_ind if (entry_point_index > self.mod.entry_points.items.len) return RuntimeError.InvalidEntryPoint; + // Spec constants pass + try self.pass(allocator, .initMany(&.{ + .SpecConstant, + .SpecConstantOp, + })); + { const entry_point_desc = &self.mod.entry_points.items[entry_point_index]; const entry_point_result = &self.mod.results[entry_point_desc.id]; @@ -142,11 +168,28 @@ pub fn callEntryPoint(self: *Self, allocator: std.mem.Allocator, entry_point_ind } } + // Execution pass + try self.pass(allocator, .initFull()); + + //@import("pretty").print(allocator, self.results, .{ + // .tab_size = 4, + // .max_depth = 0, + // .struct_max_len = 0, + // .array_max_len = 0, + //}) catch return RuntimeError.OutOfMemory; +} + +fn pass(self: *Self, allocator: std.mem.Allocator, op_set: std.EnumSet(spv.SpvOp)) RuntimeError!void { self.it.did_jump = false; // To reset function jump while (self.it.nextOrNull()) |opcode_data| { const word_count = ((opcode_data & (~spv.SpvOpCodeMask)) >> spv.SpvWordCountShift) - 1; const opcode = (opcode_data & spv.SpvOpCodeMask); + if (!op_set.contains(@enumFromInt(opcode))) { + _ = self.it.skipN(word_count); + continue; + } + var it_tmp = self.it; // Save because operations may iter on this iterator if (op.runtime_dispatcher[opcode]) |pfn| { try pfn(allocator, word_count, self); @@ -158,13 +201,6 @@ pub fn callEntryPoint(self: *Self, allocator: std.mem.Allocator, entry_point_ind self.it.did_jump = false; } } - - //@import("pretty").print(allocator, self.results, .{ - // .tab_size = 4, - // .max_depth = 0, - // .struct_max_len = 0, - // .array_max_len = 0, - //}) catch return RuntimeError.OutOfMemory; } pub fn writeDescriptorSet(self: *const Self, input: []u8, set: SpvWord, binding: SpvWord, descriptor_index: SpvWord) RuntimeError!void { diff --git a/src/opcodes.zig b/src/opcodes.zig index 7141f93..add164d 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -187,8 +187,6 @@ pub const SetupDispatcher = block: { .Variable = opVariable, .VectorTimesMatrix = autoSetupConstant, .VectorTimesScalar = autoSetupConstant, - .SpecConstant = opConstant, - .SpecConstantOp = opSpecConstantOp, .SpecConstantTrue = opSpecConstantTrue, .SpecConstantFalse = opSpecConstantFalse, .SpecConstantComposite = opConstantComposite, @@ -287,6 +285,8 @@ pub fn initRuntimeDispatcher() void { runtime_dispatcher[@intFromEnum(spv.SpvOp.UMod)] = MathEngine(.UInt, .Mod).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.VectorTimesMatrix)] = MathEngine(.Float, .VectorTimesMatrix).op; // TODO runtime_dispatcher[@intFromEnum(spv.SpvOp.VectorTimesScalar)] = MathEngine(.Float, .VectorTimesScalar).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.SpecConstant)] = opSpecConstant; + runtime_dispatcher[@intFromEnum(spv.SpvOp.SpecConstantOp)] = opSpecConstantOp; // zig fmt: on // Extensions init @@ -574,10 +574,7 @@ fn CondOperator(comptime T: PrimitiveType, comptime Op: CondOp) type { } return switch (Op) { .IsFinite => std.math.isFinite(a), - .IsInf => blk: { - //std.debug.print("test {s} - {d} - {s}\n", .{ @typeName(TT), a, if (std.math.isInf(a)) "true" else "false" }); - break :blk std.math.isInf(a); - }, + .IsInf => std.math.isInf(a), .IsNan => std.math.isNan(a), .IsNormal => std.math.isNormal(a), else => RuntimeError.InvalidSpirV, @@ -1644,6 +1641,25 @@ fn opConstantComposite(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) R } } +fn opSpecConstant(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void { + const location = rt.it.emitSourceLocation(); + _ = rt.it.skip(); + const result_id = try rt.it.next(); + _ = rt.it.goToSourceLocation(location); + + try opConstant(allocator, word_count, rt); + + const result = &rt.results[result_id]; + + for (result.decorations.items) |decoration| { + if (decoration.rtype == .SpecId) { + if (rt.specialization_constants.get(decoration.literal_1)) |data| { + _ = try (try result.getValue()).writeConst(data); + } + } + } +} + fn opSpecConstantTrue(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { const target = try setupConstant(allocator, rt); switch (target.variant.?.Constant.value) {