adding spec constant management
This commit is contained in:
@@ -33,6 +33,12 @@ pub const RuntimeError = error{
|
|||||||
Unknown,
|
Unknown,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub const SpecializationEntry = struct {
|
||||||
|
id: SpvWord,
|
||||||
|
offset: usize,
|
||||||
|
size: usize,
|
||||||
|
};
|
||||||
|
|
||||||
pub const Function = struct {
|
pub const Function = struct {
|
||||||
source_location: usize,
|
source_location: usize,
|
||||||
result: *Result,
|
result: *Result,
|
||||||
@@ -49,6 +55,8 @@ current_parameter_index: SpvWord,
|
|||||||
current_function: ?*Result,
|
current_function: ?*Result,
|
||||||
function_stack: std.ArrayList(Function),
|
function_stack: std.ArrayList(Function),
|
||||||
|
|
||||||
|
specialization_constants: std.AutoHashMapUnmanaged(u32, []const u8),
|
||||||
|
|
||||||
pub fn init(allocator: std.mem.Allocator, module: *Module) RuntimeError!Self {
|
pub fn init(allocator: std.mem.Allocator, module: *Module) RuntimeError!Self {
|
||||||
return .{
|
return .{
|
||||||
.mod = module,
|
.mod = module,
|
||||||
@@ -63,6 +71,7 @@ pub fn init(allocator: std.mem.Allocator, module: *Module) RuntimeError!Self {
|
|||||||
.current_parameter_index = 0,
|
.current_parameter_index = 0,
|
||||||
.current_function = null,
|
.current_function = null,
|
||||||
.function_stack = .empty,
|
.function_stack = .empty,
|
||||||
|
.specialization_constants = .empty,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,6 +81,16 @@ pub fn deinit(self: *Self, allocator: std.mem.Allocator) void {
|
|||||||
}
|
}
|
||||||
allocator.free(self.results);
|
allocator.free(self.results);
|
||||||
self.function_stack.deinit(allocator);
|
self.function_stack.deinit(allocator);
|
||||||
|
var it = self.specialization_constants.iterator();
|
||||||
|
while (it.next()) |entry| {
|
||||||
|
allocator.free(entry.value_ptr.*);
|
||||||
|
}
|
||||||
|
self.specialization_constants.deinit(allocator);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn addSpecializationInfo(self: *Self, allocator: std.mem.Allocator, entry: SpecializationEntry, data: []const u8) RuntimeError!void {
|
||||||
|
const slice = allocator.dupe(u8, data[entry.offset .. entry.offset + entry.size]) catch return RuntimeError.OutOfMemory;
|
||||||
|
self.specialization_constants.put(allocator, entry.id, slice) catch return RuntimeError.OutOfMemory;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getEntryPointByName(self: *const Self, name: []const u8) error{NotFound}!SpvWord {
|
pub fn getEntryPointByName(self: *const Self, name: []const u8) error{NotFound}!SpvWord {
|
||||||
@@ -81,6 +100,7 @@ pub fn getEntryPointByName(self: *const Self, name: []const u8) error{NotFound}!
|
|||||||
for (0..@min(name.len, entry_point.name.len)) |j| {
|
for (0..@min(name.len, entry_point.name.len)) |j| {
|
||||||
if (name[j] != entry_point.name[j]) break :blk false;
|
if (name[j] != entry_point.name[j]) break :blk false;
|
||||||
}
|
}
|
||||||
|
if (entry_point.name.len != name.len and entry_point.name[name.len] != 0) break :blk false;
|
||||||
break :blk true;
|
break :blk true;
|
||||||
}) return @intCast(i);
|
}) return @intCast(i);
|
||||||
}
|
}
|
||||||
@@ -121,6 +141,12 @@ pub fn callEntryPoint(self: *Self, allocator: std.mem.Allocator, entry_point_ind
|
|||||||
if (entry_point_index > self.mod.entry_points.items.len)
|
if (entry_point_index > self.mod.entry_points.items.len)
|
||||||
return RuntimeError.InvalidEntryPoint;
|
return RuntimeError.InvalidEntryPoint;
|
||||||
|
|
||||||
|
// Spec constants pass
|
||||||
|
try self.pass(allocator, .initMany(&.{
|
||||||
|
.SpecConstant,
|
||||||
|
.SpecConstantOp,
|
||||||
|
}));
|
||||||
|
|
||||||
{
|
{
|
||||||
const entry_point_desc = &self.mod.entry_points.items[entry_point_index];
|
const entry_point_desc = &self.mod.entry_points.items[entry_point_index];
|
||||||
const entry_point_result = &self.mod.results[entry_point_desc.id];
|
const entry_point_result = &self.mod.results[entry_point_desc.id];
|
||||||
@@ -142,11 +168,28 @@ pub fn callEntryPoint(self: *Self, allocator: std.mem.Allocator, entry_point_ind
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execution pass
|
||||||
|
try self.pass(allocator, .initFull());
|
||||||
|
|
||||||
|
//@import("pretty").print(allocator, self.results, .{
|
||||||
|
// .tab_size = 4,
|
||||||
|
// .max_depth = 0,
|
||||||
|
// .struct_max_len = 0,
|
||||||
|
// .array_max_len = 0,
|
||||||
|
//}) catch return RuntimeError.OutOfMemory;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pass(self: *Self, allocator: std.mem.Allocator, op_set: std.EnumSet(spv.SpvOp)) RuntimeError!void {
|
||||||
self.it.did_jump = false; // To reset function jump
|
self.it.did_jump = false; // To reset function jump
|
||||||
while (self.it.nextOrNull()) |opcode_data| {
|
while (self.it.nextOrNull()) |opcode_data| {
|
||||||
const word_count = ((opcode_data & (~spv.SpvOpCodeMask)) >> spv.SpvWordCountShift) - 1;
|
const word_count = ((opcode_data & (~spv.SpvOpCodeMask)) >> spv.SpvWordCountShift) - 1;
|
||||||
const opcode = (opcode_data & spv.SpvOpCodeMask);
|
const opcode = (opcode_data & spv.SpvOpCodeMask);
|
||||||
|
|
||||||
|
if (!op_set.contains(@enumFromInt(opcode))) {
|
||||||
|
_ = self.it.skipN(word_count);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
var it_tmp = self.it; // Save because operations may iter on this iterator
|
var it_tmp = self.it; // Save because operations may iter on this iterator
|
||||||
if (op.runtime_dispatcher[opcode]) |pfn| {
|
if (op.runtime_dispatcher[opcode]) |pfn| {
|
||||||
try pfn(allocator, word_count, self);
|
try pfn(allocator, word_count, self);
|
||||||
@@ -158,13 +201,6 @@ pub fn callEntryPoint(self: *Self, allocator: std.mem.Allocator, entry_point_ind
|
|||||||
self.it.did_jump = false;
|
self.it.did_jump = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//@import("pretty").print(allocator, self.results, .{
|
|
||||||
// .tab_size = 4,
|
|
||||||
// .max_depth = 0,
|
|
||||||
// .struct_max_len = 0,
|
|
||||||
// .array_max_len = 0,
|
|
||||||
//}) catch return RuntimeError.OutOfMemory;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn writeDescriptorSet(self: *const Self, input: []u8, set: SpvWord, binding: SpvWord, descriptor_index: SpvWord) RuntimeError!void {
|
pub fn writeDescriptorSet(self: *const Self, input: []u8, set: SpvWord, binding: SpvWord, descriptor_index: SpvWord) RuntimeError!void {
|
||||||
|
|||||||
@@ -187,8 +187,6 @@ pub const SetupDispatcher = block: {
|
|||||||
.Variable = opVariable,
|
.Variable = opVariable,
|
||||||
.VectorTimesMatrix = autoSetupConstant,
|
.VectorTimesMatrix = autoSetupConstant,
|
||||||
.VectorTimesScalar = autoSetupConstant,
|
.VectorTimesScalar = autoSetupConstant,
|
||||||
.SpecConstant = opConstant,
|
|
||||||
.SpecConstantOp = opSpecConstantOp,
|
|
||||||
.SpecConstantTrue = opSpecConstantTrue,
|
.SpecConstantTrue = opSpecConstantTrue,
|
||||||
.SpecConstantFalse = opSpecConstantFalse,
|
.SpecConstantFalse = opSpecConstantFalse,
|
||||||
.SpecConstantComposite = opConstantComposite,
|
.SpecConstantComposite = opConstantComposite,
|
||||||
@@ -287,6 +285,8 @@ pub fn initRuntimeDispatcher() void {
|
|||||||
runtime_dispatcher[@intFromEnum(spv.SpvOp.UMod)] = MathEngine(.UInt, .Mod).op;
|
runtime_dispatcher[@intFromEnum(spv.SpvOp.UMod)] = MathEngine(.UInt, .Mod).op;
|
||||||
runtime_dispatcher[@intFromEnum(spv.SpvOp.VectorTimesMatrix)] = MathEngine(.Float, .VectorTimesMatrix).op; // TODO
|
runtime_dispatcher[@intFromEnum(spv.SpvOp.VectorTimesMatrix)] = MathEngine(.Float, .VectorTimesMatrix).op; // TODO
|
||||||
runtime_dispatcher[@intFromEnum(spv.SpvOp.VectorTimesScalar)] = MathEngine(.Float, .VectorTimesScalar).op;
|
runtime_dispatcher[@intFromEnum(spv.SpvOp.VectorTimesScalar)] = MathEngine(.Float, .VectorTimesScalar).op;
|
||||||
|
runtime_dispatcher[@intFromEnum(spv.SpvOp.SpecConstant)] = opSpecConstant;
|
||||||
|
runtime_dispatcher[@intFromEnum(spv.SpvOp.SpecConstantOp)] = opSpecConstantOp;
|
||||||
// zig fmt: on
|
// zig fmt: on
|
||||||
|
|
||||||
// Extensions init
|
// Extensions init
|
||||||
@@ -574,10 +574,7 @@ fn CondOperator(comptime T: PrimitiveType, comptime Op: CondOp) type {
|
|||||||
}
|
}
|
||||||
return switch (Op) {
|
return switch (Op) {
|
||||||
.IsFinite => std.math.isFinite(a),
|
.IsFinite => std.math.isFinite(a),
|
||||||
.IsInf => blk: {
|
.IsInf => std.math.isInf(a),
|
||||||
//std.debug.print("test {s} - {d} - {s}\n", .{ @typeName(TT), a, if (std.math.isInf(a)) "true" else "false" });
|
|
||||||
break :blk std.math.isInf(a);
|
|
||||||
},
|
|
||||||
.IsNan => std.math.isNan(a),
|
.IsNan => std.math.isNan(a),
|
||||||
.IsNormal => std.math.isNormal(a),
|
.IsNormal => std.math.isNormal(a),
|
||||||
else => RuntimeError.InvalidSpirV,
|
else => RuntimeError.InvalidSpirV,
|
||||||
@@ -1644,6 +1641,25 @@ fn opConstantComposite(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) R
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn opSpecConstant(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void {
|
||||||
|
const location = rt.it.emitSourceLocation();
|
||||||
|
_ = rt.it.skip();
|
||||||
|
const result_id = try rt.it.next();
|
||||||
|
_ = rt.it.goToSourceLocation(location);
|
||||||
|
|
||||||
|
try opConstant(allocator, word_count, rt);
|
||||||
|
|
||||||
|
const result = &rt.results[result_id];
|
||||||
|
|
||||||
|
for (result.decorations.items) |decoration| {
|
||||||
|
if (decoration.rtype == .SpecId) {
|
||||||
|
if (rt.specialization_constants.get(decoration.literal_1)) |data| {
|
||||||
|
_ = try (try result.getValue()).writeConst(data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn opSpecConstantTrue(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void {
|
fn opSpecConstantTrue(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void {
|
||||||
const target = try setupConstant(allocator, rt);
|
const target = try setupConstant(allocator, rt);
|
||||||
switch (target.variant.?.Constant.value) {
|
switch (target.variant.?.Constant.value) {
|
||||||
|
|||||||
Reference in New Issue
Block a user