265 lines
9.8 KiB
Zig
265 lines
9.8 KiB
Zig
const std = @import("std");
|
|
const vk = @import("vulkan");
|
|
const base = @import("base");
|
|
const spv = @import("spv");
|
|
const lib = @import("../lib.zig");
|
|
|
|
const PipelineState = @import("PipelineState.zig");
|
|
|
|
const SoftDevice = @import("../SoftDevice.zig");
|
|
const SoftPipeline = @import("../SoftPipeline.zig");
|
|
|
|
const VkError = base.VkError;
|
|
const SpvRuntimeError = spv.Runtime.RuntimeError;
|
|
|
|
const Self = @This();
|
|
|
|
/// Per-worker arguments for one `run` invocation, passed by value to each
/// thread spawned by `dispatch`.
const RunData = struct {
    /// The command processor that owns this dispatch.
    self: *Self,
    /// This worker's slot in the batch; also the first flat workgroup index it
    /// executes (workers stride by `batch_size`, see `run`).
    batch_id: usize,
    /// Total number of workgroups in the dispatch (x * y * z).
    group_count: usize,
    group_count_x: usize,
    group_count_y: usize,
    group_count_z: usize,
    /// local_size_x * local_size_y * local_size_z of the compute module.
    invocations_per_workgroup: usize,
    /// Pipeline that was bound when `dispatch` was called.
    pipeline: *SoftPipeline,
};
|
|
|
|
/// Owning logical device; provides the worker thread pool and device allocator.
device: *SoftDevice,
/// Bound pipeline / descriptor-set state for this command processor.
state: *PipelineState,
/// Number of per-worker SPIR-V runtimes; refreshed at the start of each
/// `dispatch` from the bound shader stage.
batch_size: usize,

/// Hands out a unique, monotonically increasing index to every shader
/// invocation across all worker threads; reset to 0 per dispatch.
invocation_index: std.atomic.Value(usize),

/// Invocation index at which to dump the result table before (early) or after
/// (final) executing the entry point; parsed from env vars in `init`,
/// null disables the dump.
early_dump: ?u32,
final_dump: ?u32,
|
|
|
|
/// Creates a compute command processor bound to `device` and `state`.
/// The result-table dump triggers are read once from the environment here;
/// an absent or unparsable variable simply disables that dump.
pub fn init(device: *SoftDevice, state: *PipelineState) Self {
    const early_trigger: ?u32 =
        std.process.parseEnvVarInt(lib.DUMP_EARLY_RESULT_TABLE_ENV_NAME, u32, 10) catch null;
    const final_trigger: ?u32 =
        std.process.parseEnvVarInt(lib.DUMP_FINAL_RESULT_TABLE_ENV_NAME, u32, 10) catch null;

    return .{
        .device = device,
        .state = state,
        .batch_size = 0,
        .invocation_index = .init(0),
        .early_dump = early_trigger,
        .final_dump = final_trigger,
    };
}
|
|
|
|
/// Releases resources owned by the processor. Currently a no-op: the device
/// and pipeline state are borrowed, and this struct heap-allocates nothing.
pub fn destroy(self: *Self) void {
    _ = self;
}
|
|
|
|
/// Executes a compute dispatch of `x * y * z` workgroups.
///
/// Work is distributed over at most `batch_size` worker runtimes (one per
/// spawned task); worker `batch_id` strides through flat workgroup indices
/// batch_id, batch_id + batch_size, ... (see `run`). Blocks until all workers
/// finish. Returns `InvalidPipelineDrv` when no pipeline or no compute stage
/// is bound.
pub fn dispatch(self: *Self, group_count_x: u32, group_count_y: u32, group_count_z: u32) VkError!void {
    // Widen before multiplying: the product of three u32 group counts can
    // overflow u32, while the original cast-after-multiply would trap.
    const group_count = @as(usize, group_count_x) * @as(usize, group_count_y) * @as(usize, group_count_z);

    const pipeline = self.state.pipeline orelse return VkError.InvalidPipelineDrv;
    const shader = pipeline.stages.getPtr(.compute) orelse return VkError.InvalidPipelineDrv;
    const spv_module = &shader.module.module;
    self.batch_size = shader.runtimes.len;

    const invocations_per_workgroup = spv_module.local_size_x * spv_module.local_size_y * spv_module.local_size_z;

    // Fresh invocation numbering for this dispatch.
    self.invocation_index.store(0, .monotonic);

    // Hoisted out of the loop: the environment cannot change mid-dispatch.
    const single_threaded = std.process.hasEnvVarConstant(lib.SINGLE_THREAD_COMPUTE_EXECUTION_ENV_NAME);

    var wg: std.Thread.WaitGroup = .{};
    // Never spawn more workers than there are workgroups.
    for (0..@min(self.batch_size, group_count)) |batch_id| {
        const run_data = RunData{
            .self = self,
            .batch_id = batch_id,
            .group_count = group_count,
            .group_count_x = group_count_x,
            .group_count_y = group_count_y,
            .group_count_z = group_count_z,
            .invocations_per_workgroup = invocations_per_workgroup,
            .pipeline = pipeline,
        };

        if (single_threaded) {
            @branchHint(.cold); // Should only be reached for debugging
            runWrapper(run_data);
        } else {
            self.device.workers.spawnWg(&wg, runWrapper, .{run_data});
        }
    }
    self.device.workers.waitAndWork(&wg);
}
|
|
|
|
/// Thread-pool entry point: runs `run` and reports — rather than propagates —
/// any error, since a worker thread has no caller to return an error to.
/// Dumps the error return trace when one is available.
fn runWrapper(data: RunData) void {
    @call(.always_inline, run, .{data}) catch |err| {
        // "caught" (was the grammar bug "catched") — message fix only.
        std.log.scoped(.@"SPIR-V runtime").err("SPIR-V runtime caught a '{s}'", .{@errorName(err)});
        if (@errorReturnTrace()) |trace| {
            std.debug.dumpStackTrace(trace.*);
        }
    };
}
|
|
|
|
/// Worker body for one batch slot: executes every workgroup assigned to it.
/// Workgroups are distributed round-robin — this worker handles flat indices
/// batch_id, batch_id + batch_size, batch_id + 2*batch_size, ...
/// Each workgroup runs its invocations sequentially on this thread's runtime.
inline fn run(data: RunData) !void {
    const allocator = data.self.device.device_allocator.allocator();

    const shader = data.pipeline.stages.getPtrAssertContains(.compute);
    // Each batch slot owns a dedicated runtime, so no locking is needed here.
    const rt = &shader.runtimes[data.batch_id];

    const entry = try rt.getEntryPointByName(shader.entry);

    try data.self.writeDescriptorSets(rt);

    var group_index: usize = data.batch_id;
    while (group_index < data.group_count) : (group_index += data.self.batch_size) {
        // Decompose the flat group index into (x, y, z) coordinates;
        // x varies fastest, then y, then z.
        var modulo: usize = group_index;

        const group_z = @divTrunc(modulo, data.group_count_x * data.group_count_y);

        modulo -= group_z * data.group_count_x * data.group_count_y;
        const group_y = @divTrunc(modulo, data.group_count_x);

        modulo -= group_y * data.group_count_x;
        const group_x = modulo;

        try setupWorkgroupBuiltins(data.self, rt, .{
            @as(u32, @intCast(data.group_count_x)),
            @as(u32, @intCast(data.group_count_y)),
            @as(u32, @intCast(data.group_count_z)),
        }, .{
            @as(u32, @intCast(group_x)),
            @as(u32, @intCast(group_y)),
            @as(u32, @intCast(group_z)),
        });

        for (0..data.invocations_per_workgroup) |i| {
            // Globally unique invocation number across all workers; used only
            // to match the early/final dump trigger indices.
            const invocation_index = data.self.invocation_index.fetchAdd(1, .monotonic);

            try setupSubgroupBuiltins(data.self, rt, .{
                @as(u32, @intCast(group_x)),
                @as(u32, @intCast(group_y)),
                @as(u32, @intCast(group_z)),
            }, i);

            // Optional debug dump of the result table BEFORE running the
            // entry point (triggered by env var, see init).
            if (data.self.early_dump != null and data.self.early_dump.? == invocation_index) {
                @branchHint(.cold);
                try dumpResultsTable(allocator, rt, true);
            }

            rt.callEntryPoint(allocator, entry) catch |err| switch (err) {
                // Some errors can be ignored: out-of-bounds accesses and
                // killed invocations terminate the invocation, not the dispatch.
                SpvRuntimeError.OutOfBounds,
                SpvRuntimeError.Killed,
                => {},
                else => return err,
            };

            // Optional debug dump AFTER running the entry point.
            if (data.self.final_dump != null and data.self.final_dump.? == invocation_index) {
                @branchHint(.cold);
                try dumpResultsTable(allocator, rt, false);
            }

            // Write back descriptor memory after every invocation.
            try rt.flushDescriptorSets(allocator);
        }
    }
}
|
|
|
|
/// Debug helper: writes the runtime's result table to
/// `{early|final}_compute_result_table_dump.txt` in the current working
/// directory, overwriting any previous dump.
inline fn dumpResultsTable(allocator: std.mem.Allocator, rt: *spv.Runtime, is_early: bool) !void {
    @branchHint(.cold);
    // Pick the name at runtime. The previous comptimePrint on `is_early` only
    // compiled because callers pass comptime-known bools through an inline fn;
    // a plain runtime select has no such fragility.
    const path: []const u8 = if (is_early)
        "early_compute_result_table_dump.txt"
    else
        "final_compute_result_table_dump.txt";

    const file = try std.fs.cwd().createFile(path, .{ .truncate = true });
    defer file.close();

    var buffer = [_]u8{0} ** 1024;
    var writer = file.writer(buffer[0..]);
    try rt.dumpResultsTable(allocator, &writer.interface);
    // Bug fix: the buffered file writer must be flushed, otherwise up to
    // 1 KiB of the dump never reaches the file.
    try writer.interface.flush();
}
|
|
|
|
/// Maps the device memory backing every bound buffer/image descriptor and
/// hands the mappings to the SPIR-V runtime. Unbound sets and empty
/// descriptor slots are skipped; a binding whose resource has no memory bound
/// aborts the remaining descriptors of that binding (`continue :bindings`,
/// preserving the original semantics).
fn writeDescriptorSets(self: *Self, rt: *spv.Runtime) !void {
    for (self.state.sets[0..], 0..) |maybe_set, set_index| {
        const set = maybe_set orelse continue;

        bindings: for (set.descriptors[0..], 0..) |binding, binding_index| {
            switch (binding) {
                .buffer => |buffer_data_array| for (buffer_data_array, 0..) |buffer_data, descriptor_index| {
                    const buffer = buffer_data.object orelse continue;
                    // No device memory bound: skip the rest of this binding.
                    const memory = buffer.interface.memory orelse continue :bindings;
                    const map: []u8 = @as([*]u8, @ptrCast(try memory.map(buffer_data.offset, buffer_data.size)))[0..buffer_data.size];
                    try rt.writeDescriptorSet(
                        map,
                        @as(u32, @intCast(set_index)),
                        @as(u32, @intCast(binding_index)),
                        @as(u32, @intCast(descriptor_index)),
                    );
                },
                .image => |image_data_array| for (image_data_array, 0..) |image_data, descriptor_index| {
                    const image = image_data.object orelse continue;
                    const memory = image.interface.memory orelse continue :bindings;
                    // Query the (fallible) total size once instead of twice.
                    const total_size = try image.interface.getTotalSize();
                    const map: []u8 = @as([*]u8, @ptrCast(try memory.map(image.interface.memory_offset, total_size)))[0..total_size];
                    try rt.writeDescriptorSet(
                        map,
                        @as(u32, @intCast(set_index)),
                        @as(u32, @intCast(binding_index)),
                        @as(u32, @intCast(descriptor_index)),
                    );
                },
                else => {},
            }
        }
    }
}
|
|
|
|
/// Publishes the per-dispatch workgroup builtins (WorkgroupSize,
/// NumWorkgroups, WorkgroupId) to the runtime. A builtin the shader does not
/// declare is silently skipped — the `writeBuiltIn` errors are intentionally
/// discarded.
fn setupWorkgroupBuiltins(
    self: *Self,
    rt: *spv.Runtime,
    group_count: @Vector(3, u32),
    group_id: @Vector(3, u32),
) spv.Runtime.RuntimeError!void {
    const module = &self.state.pipeline.?.stages.getPtrAssertContains(.compute).module.module;
    const local_size: @Vector(3, u32) = .{
        module.local_size_x,
        module.local_size_y,
        module.local_size_z,
    };

    rt.writeBuiltIn(std.mem.asBytes(&local_size), .WorkgroupSize) catch {};
    rt.writeBuiltIn(std.mem.asBytes(&group_count), .NumWorkgroups) catch {};
    rt.writeBuiltIn(std.mem.asBytes(&group_id), .WorkgroupId) catch {};
}
|
|
|
|
/// Computes and publishes GlobalInvocationId for the invocation identified by
/// `local_invocation_index` inside workgroup `group_id`. The flat local index
/// decodes with x varying fastest, then y, then z. Write errors are ignored
/// (the shader may simply not declare the builtin).
fn setupSubgroupBuiltins(
    self: *Self,
    rt: *spv.Runtime,
    group_id: @Vector(3, u32),
    local_invocation_index: usize,
) spv.Runtime.RuntimeError!void {
    const module = &self.state.pipeline.?.stages.getPtrAssertContains(.compute).module.module;
    const size: @Vector(3, u32) = .{
        module.local_size_x,
        module.local_size_y,
        module.local_size_z,
    };

    // Decode the flat local index into (x, y, z) with div/mod arithmetic
    // (equivalent to the iterated subtraction form).
    const flat: u32 = @intCast(local_invocation_index);
    const plane = size[0] * size[1];
    const within_plane = flat % plane;
    const local = @Vector(3, u32){
        within_plane % size[0],
        within_plane / size[0],
        flat / plane,
    };

    // GlobalInvocationId = workgroup origin + local coordinate.
    const global_invocation_id = size * group_id + local;

    rt.writeBuiltIn(std.mem.asBytes(&global_invocation_id), .GlobalInvocationId) catch {};
}
|