diff --git a/src/Runtime.zig b/src/Runtime.zig index 207c641..e38882c 100644 --- a/src/Runtime.zig +++ b/src/Runtime.zig @@ -150,22 +150,6 @@ pub fn callEntryPoint(self: *Self, allocator: std.mem.Allocator, entry_point_ind //}) catch return RuntimeError.OutOfMemory; } -pub fn readOutput(self: *const Self, output: []u8, result: SpvWord) RuntimeError!void { - if (std.mem.indexOfScalar(SpvWord, &self.mod.output_locations, result)) |_| { - try self.readValue(output, &self.results[result]); - } else { - return RuntimeError.NotFound; - } -} - -pub fn writeInput(self: *const Self, comptime T: type, input: []const T, result: SpvWord) RuntimeError!void { - if (std.mem.indexOfScalar(SpvWord, &self.mod.input_locations, result)) |_| { - try self.writeValue(T, input, &self.results[result].variant.?.Variable.value); - } else { - return RuntimeError.NotFound; - } -} - pub fn readDescriptorSet(self: *const Self, comptime T: type, output: *T, set: SpvWord, binding: SpvWord) RuntimeError!void { if (set < lib.SPIRV_MAX_SET and binding < lib.SPIRV_MAX_SET_BINDINGS) { try self.readValue(T, output, &self.results[self.mod.bindings[set][binding]].variant.?.Variable.value); @@ -192,80 +176,94 @@ pub fn writeDescriptorSet(self: *const Self, comptime T: type, allocator: std.me } } +pub fn readOutput(self: *const Self, comptime T: type, output: []T, result: SpvWord) RuntimeError!void { + if (std.mem.indexOf(SpvWord, &self.mod.output_locations, &.{result})) |_| { + try self.readValue(T, output, &self.results[result].variant.?.Variable.value); + } else { + return RuntimeError.NotFound; + } +} + +pub fn writeInput(self: *const Self, comptime T: type, input: []const T, result: SpvWord) RuntimeError!void { + if (std.mem.indexOf(SpvWord, &self.mod.input_locations, &.{result})) |_| { + try self.writeValue(T, input, &self.results[result].variant.?.Variable.value); + } else { + return RuntimeError.NotFound; + } +} + fn reset(self: *Self) void { self.function_stack.clearRetainingCapacity(); self.current_function = null; } -fn readValue(self: *const Self, output: []u8, value: *const Result.Value) RuntimeError!void { - const type_word = try result.getValueTypeWord(); - const value = try result.getValue(); - const lane_bits = try Result.resolveLaneBitWidth((try self.results[type_word].getVariant()).Type, self); - +fn readValue(self: *const Self, comptime T: type, output: []T, value: *const Result.Value) RuntimeError!void { switch (value.*) { - .Bool => |b| output[0] = if (b == true) 1 else 0, + .Bool => |b| { + if (T == bool) { + output[0] = b; + } else { + return RuntimeError.InvalidValueType; + } + }, .Int => |i| { - switch (lane_bits) { - 8 => output[0] = @bitCast(i.uint8), - 16 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(i.uint16)), - 32 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(i.uint32)), - 64 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(i.uint64)), + switch (T) { + i8 => output[0] = i.sint8, + i16 => output[0] = i.sint16, + i32 => output[0] = i.sint32, + i64 => output[0] = i.sint64, + u8 => output[0] = i.uint8, + u16 => output[0] = i.uint16, + u32 => output[0] = i.uint32, + u64 => output[0] = i.uint64, inline else => return RuntimeError.InvalidValueType, } }, .Float => |f| { - switch (lane_bits) { - 16 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(f.float16)), - 32 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(f.float32)), - 64 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(f.float64)), + switch (T) { + f16 => output[0] = f.float16, + f32 => output[0] = f.float32, + f64 => output[0] = f.float64, inline else => return RuntimeError.InvalidValueType, } }, - .Vector4f32 => |vec| inline for (0..4) |i| switch (lane_bits) { - 32 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(vec[i])), + .Vector4f32 => |vec| inline for (0..4) |i| switch (T) { + f32 => output[i] = vec[i], inline else => return RuntimeError.InvalidValueType, }, - .Vector3f32 => |vec| inline for (0..3) |i| switch (lane_bits) { - 32 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(vec[i])), + .Vector3f32 => |vec| inline for (0..3) |i| switch (T) { + f32 => output[i] = vec[i], inline else => return RuntimeError.InvalidValueType, }, - .Vector2f32 => |vec| inline for (0..2) |i| switch (lane_bits) { - 32 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(vec[i])), + .Vector2f32 => |vec| inline for (0..2) |i| switch (T) { + f32 => output[i] = vec[i], inline else => return RuntimeError.InvalidValueType, }, - .Vector4i32 => |vec| inline for (0..4) |i| switch (lane_bits) { - 32 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(vec[i])), + .Vector4i32 => |vec| inline for (0..4) |i| switch (T) { + i32 => output[i] = vec[i], inline else => return RuntimeError.InvalidValueType, }, - .Vector3i32 => |vec| inline for (0..3) |i| switch (lane_bits) { - 32 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(vec[i])), + .Vector3i32 => |vec| inline for (0..3) |i| switch (T) { + i32 => output[i] = vec[i], inline else => return RuntimeError.InvalidValueType, }, - .Vector2i32 => |vec| inline for (0..2) |i| switch (lane_bits) { - 32 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(vec[i])), + .Vector2i32 => |vec| inline for (0..2) |i| switch (T) { + i32 => output[i] = vec[i], inline else => return RuntimeError.InvalidValueType, }, - .Vector4u32 => |vec| inline for (0..4) |i| switch (lane_bits) { - 32 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(vec[i])), + .Vector4u32 => |vec| inline for (0..4) |i| switch (T) { + u32 => output[i] = vec[i], inline else => return RuntimeError.InvalidValueType, }, - .Vector3u32 => |vec| inline for (0..3) |i| switch (lane_bits) { - 32 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(vec[i])), + .Vector3u32 => |vec| inline for (0..3) |i| switch (T) { + u32 => output[i] = vec[i], inline else => return RuntimeError.InvalidValueType, }, - .Vector2u32 => |vec| inline for (0..2) |i| switch (lane_bits) { - 32 => std.mem.copyForward(u8, output[0..], std.mem.asBytes(vec[i])), + .Vector2u32 => |vec| inline for (0..2) |i| switch (T) { + u32 => output[i] = vec[i], inline else => return RuntimeError.InvalidValueType, }, - .Array, - .Matrix, - .Structure, - .Vector, - => |values| for (values, 0..) |v, i| try self.readValue(T, output[i..], &v), - .RuntimeArray => |opt_values| if (opt_values) |values| { - for (values, 0..) |v, i| - try self.readValue(output[i..], result); - }, + .Vector, .Matrix, .Array, .Structure => |values| for (values, 0..) |v, i| try self.readValue(T, output[i..], &v), else => return RuntimeError.InvalidValueType, } } @@ -336,15 +334,7 @@ fn writeValue(self: *const Self, comptime T: type, input: []const T, value: *Res u32 => vec[i] = input[i], inline else => return RuntimeError.InvalidValueType, }, - .Array, - .Matrix, - .Structure, - .Vector, - => |*values| for (values.*, 0..) |*v, i| try self.writeValue(T, input[i..], v), - .RuntimeArray => |opt_values| if (opt_values) |*values| { - for (values.*, 0..) |*v, i| - try self.writeValue(T, input[i..], v); - }, + .Vector, .Matrix, .Array, .Structure => |*values| for (values.*, 0..) |*v, i| try self.writeValue(T, input[i..], v), else => return RuntimeError.InvalidValueType, } } diff --git a/test/root.zig b/test/root.zig index 784074d..8a12d25 100644 --- a/test/root.zig +++ b/test/root.zig @@ -50,6 +50,7 @@ pub const case = struct { pub fn expectOutputWithInput(comptime T: type, comptime len: usize, source: []const u32, output_name: []const u8, expected: []const T, input_name: []const u8, input: []const T) !void { const allocator = std.testing.allocator; + // To test with all important module options const module_options = [_]spv.Module.ModuleOptions{ .{ .use_simd_vectors_specializations = true,