//! Compute dispatcher: runs the compute stage of the bound pipeline on the
//! SPIR-V interpreter runtimes attached to that stage, one batch per runtime.

const std = @import("std");
const vk = @import("vulkan");
const base = @import("base");
const spv = @import("spv");
const lib = @import("../lib.zig");

const ExecutionDevice = @import("Device.zig");
const PipelineState = ExecutionDevice.PipelineState;

const SoftDevice = @import("../SoftDevice.zig");
const SoftPipeline = @import("../SoftPipeline.zig");

const VkError = base.VkError;
const SpvRuntimeError = spv.Runtime.RuntimeError;

const Self = @This();

/// Per-batch parameters handed to `runWrapper` / `run`.
const RunData = struct {
    self: *Self,
    /// Index of the shader runtime used by this batch; also the first workgroup it handles.
    batch_id: usize,
    /// Total number of workgroups (x * y * z).
    group_count: usize,
    group_count_x: usize,
    group_count_y: usize,
    group_count_z: usize,
    /// Product of the module's local sizes.
    invocations_per_workgroup: usize,
    pipeline: *SoftPipeline,
};

/// Device providing the allocator and the I/O interface used while dispatching.
device: *SoftDevice,
/// Pipeline state; its `pipeline` field is read on every dispatch.
state: *PipelineState,
/// Number of runtimes of the compute stage; set by `dispatch` and used as the
/// stride between the workgroups processed by one batch.
batch_size: usize,

/// Counter of executed invocations, reset at the start of each dispatch.
/// Only advanced on the non-barrier execution path.
invocation_index: std.atomic.Value(usize),

/// Invocation index at which the results table is dumped before the entry point runs (null: never).
early_dump: ?u32,
/// Invocation index at which the results table is dumped after the entry point ran (null: never).
final_dump: ?u32,

/// Creates a dispatcher bound to `device` and `state`.
/// The dump targets are taken from `base.config`.
pub fn init(device: *SoftDevice, state: *PipelineState) Self {
    return .{
        .device = device,
        .state = state,
        .batch_size = 0,
        .invocation_index = .init(0),
        .early_dump = base.config.compute_dump_early_results_table,
        .final_dump = base.config.compute_dump_final_results_table,
    };
}

/// Runs `group_count_x * group_count_y * group_count_z` workgroups of the bound compute pipeline.
///
/// Workgroups are spread over at most `shader.runtimes.len` batches; each batch is run
/// either inline (`base.config.single_threaded_compute`) or as an async task of an I/O group.
///
/// Returns `VkError.InvalidPipelineDrv` when no pipeline or no compute stage is bound and
/// `VkError.DeviceLost` when awaiting the task group fails.
pub fn dispatch(self: *Self, group_count_x: u32, group_count_y: u32, group_count_z: u32) VkError!void {
    // Widen before multiplying: the product of three u32 counts easily exceeds u32.
    const count_x = @as(usize, group_count_x);
    const count_y = @as(usize, group_count_y);
    const count_z = @as(usize, group_count_z);
    const group_count: usize = count_x * count_y * count_z;

    const pipeline = self.state.pipeline orelse return VkError.InvalidPipelineDrv;
    const shader = pipeline.stages.getPtr(.compute) orelse return VkError.InvalidPipelineDrv;
    const spv_module = &shader.module.module;
    self.batch_size = shader.runtimes.len;

    const invocations_per_workgroup: usize =
        @as(usize, spv_module.local_size_x) * spv_module.local_size_y * spv_module.local_size_z;

    self.invocation_index.store(0, .monotonic);

    const io = self.device.interface.io();

    const timer = std.Io.Timestamp.now(io, .real);
    defer if (comptime base.config.logs != .none) {
        const duration = timer.untilNow(io, .real);
        const elapsed_us: f32 = @floatFromInt(duration.toMicroseconds());
        std.log.scoped(.ComputeDispatcher).debug("Compute dispatch took {}ms", .{elapsed_us / 1000});
    };

    var wg: std.Io.Group = .init;
    // Never start more batches than there are workgroups.
    for (0..@min(self.batch_size, group_count)) |batch_id| {
        const run_data: RunData = .{
            .self = self,
            .batch_id = batch_id,
            .group_count = group_count,
            .group_count_x = count_x,
            .group_count_y = count_y,
            .group_count_z = count_z,
            .invocations_per_workgroup = invocations_per_workgroup,
            .pipeline = pipeline,
        };

        if (comptime base.config.single_threaded_compute)
            runWrapper(run_data)
        else
            wg.async(io, runWrapper, .{run_data});
    }
    wg.await(io) catch return VkError.DeviceLost;
}

/// Task entry point: runs one batch and logs (instead of propagating) any error it raises.
fn runWrapper(data: RunData) void {
    @call(.always_inline, run, .{data}) catch |err| {
        std.log.scoped(.@"SPIR-V runtime").err("SPIR-V runtime catched a '{s}'", .{@errorName(err)});
        if (comptime base.config.logs == .verbose) {
            if (@errorReturnTrace()) |trace| {
                std.debug.dumpErrorReturnTrace(trace);
            }
        }
    };
}

/// Executes every workgroup assigned to batch `data.batch_id`, i.e. the workgroups
/// `batch_id, batch_id + batch_size, batch_id + 2 * batch_size, ...`.
///
/// Modules containing an `OpControlBarrier` are executed through `runBarrierWorkgroup`
/// with one dedicated runtime per invocation; other modules run their invocations
/// sequentially on the batch runtime.
inline fn run(data: RunData) !void {
    const allocator = data.self.device.device_allocator.allocator();
    const io = data.self.device.interface.io();

    const shader = data.pipeline.stages.getPtrAssertContains(.compute);
    const rt = &shader.runtimes[data.batch_id].rt;

    const entry = try rt.getEntryPointByName(shader.entry);
    const uses_control_barrier = hasControlBarrier(rt.mod.code);

    var barrier_runtimes: []spv.Runtime = &.{};
    var barrier_statuses: []spv.Runtime.EntryPointStatus = &.{};
    // Number of leading entries of `barrier_runtimes` that hold an initialized runtime.
    var initialized_runtimes: usize = 0;

    // Registered before any fallible setup so that a failure half-way through releases
    // what was already acquired, and only deinitializes runtimes that were actually built.
    defer {
        for (barrier_runtimes[0..initialized_runtimes]) |*barrier_rt| {
            barrier_rt.deinit(allocator);
        }
        allocator.free(barrier_runtimes);
        allocator.free(barrier_statuses);
    }

    if (uses_control_barrier) {
        barrier_runtimes = try allocator.alloc(spv.Runtime, data.invocations_per_workgroup);
        barrier_statuses = try allocator.alloc(spv.Runtime.EntryPointStatus, data.invocations_per_workgroup);

        for (barrier_runtimes) |*barrier_rt| {
            barrier_rt.* = try spv.Runtime.init(allocator, rt.mod, rt.image_api);
            initialized_runtimes += 1;
            try barrier_rt.copySpecializationConstantsFrom(allocator, rt);
        }
    }

    // On the barrier path descriptor sets are written per runtime in `runBarrierWorkgroup`.
    if (!uses_control_barrier)
        try ExecutionDevice.writeDescriptorSets(data.self.state, rt);

    const group_count_vec = @Vector(3, u32){
        @as(u32, @intCast(data.group_count_x)),
        @as(u32, @intCast(data.group_count_y)),
        @as(u32, @intCast(data.group_count_z)),
    };

    var group_index: usize = data.batch_id;
    while (group_index < data.group_count) : (group_index += data.self.batch_size) {
        // Decompose the linear workgroup index into (x, y, z), x varying fastest.
        var modulo: usize = group_index;

        const group_z = @divTrunc(modulo, data.group_count_x * data.group_count_y);
        modulo -= group_z * data.group_count_x * data.group_count_y;
        const group_y = @divTrunc(modulo, data.group_count_x);
        modulo -= group_y * data.group_count_x;
        const group_x = modulo;

        const group_id_vec = @Vector(3, u32){
            @as(u32, @intCast(group_x)),
            @as(u32, @intCast(group_y)),
            @as(u32, @intCast(group_z)),
        };

        if (uses_control_barrier) {
            try runBarrierWorkgroup(data, barrier_runtimes, barrier_statuses, entry, group_count_vec, group_id_vec);
            continue;
        }

        try setupWorkgroupBuiltins(data.self, rt, group_count_vec, group_id_vec);

        for (0..data.invocations_per_workgroup) |i| {
            const invocation_index = data.self.invocation_index.fetchAdd(1, .monotonic);

            try setupSubgroupBuiltins(data.self, rt, group_id_vec, i);

            if (data.self.early_dump) |target| {
                if (target == invocation_index) {
                    @branchHint(.cold);
                    try dumpResultsTable(allocator, io, rt, true);
                }
            }

            rt.callEntryPoint(allocator, entry) catch |err| switch (err) {
                // Some errors can be ignored
                SpvRuntimeError.OutOfBounds,
                SpvRuntimeError.Killed,
                => {},
                else => return err,
            };

            if (data.self.final_dump) |target| {
                if (target == invocation_index) {
                    @branchHint(.cold);
                    try dumpResultsTable(allocator, io, rt, false);
                }
            }

            try rt.flushDescriptorSets(allocator);
        }
    }
}

/// Runs one workgroup of a module that uses control barriers.
///
/// Each invocation owns one entry of `runtimes`; `statuses[i]` tracks the state of
/// invocation `i`. All invocations are started, then every non-completed invocation is
/// resumed round after round for as long as at least one of them reports `.barrier`.
fn runBarrierWorkgroup(
    data: RunData,
    runtimes: []spv.Runtime,
    statuses: []spv.Runtime.EntryPointStatus,
    entry: spv.SpvWord,
    group_count: @Vector(3, u32),
    group_id: @Vector(3, u32),
) !void {
    const allocator = data.self.device.device_allocator.allocator();

    // Start every invocation of the workgroup.
    for (runtimes, 0..) |*rt, i| {
        try ExecutionDevice.writeDescriptorSets(data.self.state, rt);
        try setupWorkgroupBuiltins(data.self, rt, group_count, group_id);
        try setupSubgroupBuiltins(data.self, rt, group_id, i);
        statuses[i] = try rt.beginEntryPoint(allocator, entry);
        try rt.flushDescriptorSets(allocator);
    }

    while (true) {
        var pending = false;
        for (statuses) |status| {
            if (status == .barrier) {
                pending = true;
                break;
            }
        }
        if (!pending)
            break;

        // Resume everything that has not completed yet.
        for (runtimes, 0..) |*rt, i| {
            if (statuses[i] == .completed)
                continue;
            statuses[i] = try rt.continueEntryPoint(allocator);
            try rt.flushDescriptorSets(allocator);
        }
    }
}

/// Returns true when the instruction stream `code` contains an `OpControlBarrier`.
///
/// Scanning starts at word 5 (past the module header) and advances by each
/// instruction's word count; a zero word count advances by one word so the scan
/// always terminates.
///
/// TODO: Move this in the SPIR-V Interpreter
fn hasControlBarrier(code: []const spv.SpvWord) bool {
    var i: usize = 5;
    while (i < code.len) {
        const opcode_data = code[i];
        const word_count = (opcode_data & (~spv.spv.SpvOpCodeMask)) >> spv.spv.SpvWordCountShift;
        const opcode: spv.spv.SpvOp = @enumFromInt(opcode_data & spv.spv.SpvOpCodeMask);
        if (opcode == .ControlBarrier)
            return true;
        i += @max(word_count, 1);
    }
    return false;
}

/// Writes the results table of `rt` to `early_compute_result_table_dump.txt` or
/// `final_compute_result_table_dump.txt` in the current working directory,
/// truncating any previous dump.
///
/// `is_early` must be comptime-known at the call site: it selects the file name
/// through `std.fmt.comptimePrint`, which is why this function is `inline`.
inline fn dumpResultsTable(allocator: std.mem.Allocator, io: std.Io, rt: *spv.Runtime, is_early: bool) !void {
    @branchHint(.cold);
    const file = try std.Io.Dir.cwd().createFile(
        io,
        std.fmt.comptimePrint("{s}_compute_result_table_dump.txt", .{if (is_early) "early" else "final"}),
        .{ .truncate = true },
    );
    defer file.close(io);

    var buffer = [_]u8{0} ** 1024;
    var writer = file.writer(io, buffer[0..]);
    try rt.dumpResultsTable(allocator, &writer.interface);
    // The writer is buffered: push out whatever is still pending before the file is closed.
    try writer.interface.flush();
}

/// Writes the `WorkgroupSize`, `NumWorkgroups` and `WorkgroupId` built-ins into `rt`.
///
/// Write failures are ignored.
/// NOTE(review): presumably because a module need not declare every built-in — confirm.
fn setupWorkgroupBuiltins(
    self: *Self,
    rt: *spv.Runtime,
    group_count: @Vector(3, u32),
    group_id: @Vector(3, u32),
) spv.Runtime.RuntimeError!void {
    const spv_module = &self.state.pipeline.?.stages.getPtrAssertContains(.compute).module.module;
    const workgroup_size = @Vector(3, u32){
        spv_module.local_size_x,
        spv_module.local_size_y,
        spv_module.local_size_z,
    };

    rt.writeBuiltIn(std.mem.asBytes(&workgroup_size), .WorkgroupSize) catch {};
    rt.writeBuiltIn(std.mem.asBytes(&group_count), .NumWorkgroups) catch {};
    rt.writeBuiltIn(std.mem.asBytes(&group_id), .WorkgroupId) catch {};
}

/// Writes the `GlobalInvocationId` built-in of one invocation into `rt`.
///
/// `local_invocation_index` is the linear index of the invocation inside its workgroup
/// (x varying fastest); the global id is `workgroup_size * group_id + local id`.
///
/// Write failures are ignored.
/// NOTE(review): presumably because a module need not declare this built-in — confirm.
fn setupSubgroupBuiltins(
    self: *Self,
    rt: *spv.Runtime,
    group_id: @Vector(3, u32),
    local_invocation_index: usize,
) spv.Runtime.RuntimeError!void {
    const spv_module = &self.state.pipeline.?.stages.getPtrAssertContains(.compute).module.module;
    const workgroup_size = @Vector(3, u32){
        spv_module.local_size_x,
        spv_module.local_size_y,
        spv_module.local_size_z,
    };
    const local_base = workgroup_size * group_id;
    var local_invocation = @Vector(3, u32){ 0, 0, 0 };

    // Decompose the linear local index into (x, y, z).
    var idx: u32 = @intCast(local_invocation_index);
    local_invocation[2] = @divTrunc(idx, workgroup_size[0] * workgroup_size[1]);
    idx -= local_invocation[2] * workgroup_size[0] * workgroup_size[1];
    local_invocation[1] = @divTrunc(idx, workgroup_size[0]);
    idx -= local_invocation[1] * workgroup_size[0];
    local_invocation[0] = idx;

    const global_invocation_index = local_base + local_invocation;

    rt.writeBuiltIn(std.mem.asBytes(&global_invocation_index), .GlobalInvocationId) catch {};
}