From 147126c06f174e8f2b8183001ad600a951ccee1d Mon Sep 17 00:00:00 2001 From: Kbz-8 Date: Sun, 26 Apr 2026 02:39:21 +0200 Subject: [PATCH] adding result get from location --- ffi/SpirvInterpreter.h | 7 +++++++ ffi/runtime.zig | 12 ++++++++++++ src/Runtime.zig | 20 ++++++++++++++++---- 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/ffi/SpirvInterpreter.h b/ffi/SpirvInterpreter.h index b134439..2bfee61 100644 --- a/ffi/SpirvInterpreter.h +++ b/ffi/SpirvInterpreter.h @@ -78,6 +78,12 @@ typedef struct SpvSize size; } SpvRuntimeSpecializationEntry; +typedef enum +{ + SPV_LOCATION_INPUT = 0, + SPV_LOCATION_OUTPUT = 1, +} SpvLocationType; + typedef void* SpvModule; typedef void* SpvRuntime; @@ -92,6 +98,7 @@ SPV_API SpvResult SpvFlushDescriptorSets(SpvRuntime runtime); SPV_API SpvResult SpvAddSpecializationInfo(SpvRuntime runtime, SpvRuntimeSpecializationEntry entry, const SpvByte* data, SpvSize data_size); SPV_API SpvResult SpvGetResultByName(SpvRuntime runtime, const char* name, SpvWord* result); +SPV_API SpvResult SpvGetResultLocation(SpvRuntime runtime, SpvWord location, SpvLocationType type, SpvWord* result); SPV_API SpvResult SpvGetEntryPointByName(SpvRuntime runtime, const char* name, SpvWord* result); SPV_API SpvResult SpvCallEntryPoint(SpvRuntime runtime, SpvWord entry_point_index); diff --git a/ffi/runtime.zig b/ffi/runtime.zig index 5d2548b..d885d03 100644 --- a/ffi/runtime.zig +++ b/ffi/runtime.zig @@ -8,6 +8,11 @@ const CSpecializationEntry = extern struct { size: u32, }; +const LocationType = enum(c_int) { + input = 0, + output = 1, +}; + fn toCResult(err: spv.Runtime.RuntimeError) ffi.Result { return switch (err) { spv.Runtime.RuntimeError.DivisionByZero => ffi.Result.DivisionByZero, @@ -63,6 +68,13 @@ export fn SpvGetEntryPointByName(rt: *spv.Runtime, name: [*:0]const u8, result: return .Success; } +export fn SpvGetResultByLocation(rt: *spv.Runtime, location: spv.SpvWord, kind: LocationType, result: *spv.SpvWord) callconv(.c) ffi.Result { + result.* = rt.getResultByLocation(location, switch (kind) { + .input => .input, + .output => .output, + }) catch |err| return toCResult(err); + return .Success; +} export fn SpvGetResultByName(rt: *spv.Runtime, name: [*:0]const u8, result: *spv.SpvWord) callconv(.c) ffi.Result { result.* = rt.getResultByName(std.mem.span(name)) catch |err| return toCResult(err); return .Success; diff --git a/src/Runtime.zig b/src/Runtime.zig index 1d264aa..b65e252 100644 --- a/src/Runtime.zig +++ b/src/Runtime.zig @@ -99,7 +99,7 @@ pub fn addSpecializationInfo(self: *Self, allocator: std.mem.Allocator, entry: S self.specialization_constants.put(allocator, entry.id, slice) catch return RuntimeError.OutOfMemory; } -pub fn getEntryPointByName(self: *const Self, name: []const u8) error{NotFound}!SpvWord { +pub fn getEntryPointByName(self: *const Self, name: []const u8) RuntimeError!SpvWord { for (self.mod.entry_points.items, 0..) |entry_point, i| { if (blk: { // Not using std.mem.eql as entry point names may have longer size than their content @@ -112,10 +112,10 @@ pub fn getEntryPointByName(self: *const Self, name: []const u8) error{NotFound}! break :blk true; }) return @intCast(i); } - return error.NotFound; + return RuntimeError.NotFound; } -pub fn getResultByName(self: *const Self, name: []const u8) error{NotFound}!SpvWord { +pub fn getResultByName(self: *const Self, name: []const u8) RuntimeError!SpvWord { for (self.results, 0..) |result, i| { if (result.name) |result_name| { if (blk: { @@ -127,7 +127,19 @@ pub fn getResultByName(self: *const Self, name: []const u8) error{NotFound}!SpvW }) return @intCast(i); } } - return error.NotFound; + return RuntimeError.NotFound; +} + +pub fn getResultByLocation(self: *const Self, location: SpvWord, kind: enum { input, output }) RuntimeError!SpvWord { + switch (kind) { + .input => if (location < self.mod.input_locations.len) { + return self.mod.input_locations[location]; + }, + .output => if (location < self.mod.output_locations.len) { + return self.mod.output_locations[location]; + }, + } + return RuntimeError.NotFound; } pub fn dumpResultsTable(self: *Self, allocator: std.mem.Allocator, writer: *std.Io.Writer) RuntimeError!void {