diff --git a/src/Module.zig b/src/Module.zig index 274c60b..70e7989 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -37,44 +37,6 @@ pub const ModuleError = error{ OutOfMemory, }; -const AllocatorWrapper = struct { - child_allocator: std.mem.Allocator, - total_bytes_allocated: usize = 0, - - pub fn allocator(self: *AllocatorWrapper) std.mem.Allocator { - return .{ - .ptr = self, - .vtable = &.{ - .alloc = alloc, - .resize = resize, - .remap = remap, - .free = free, - }, - }; - } - - fn alloc(ctx: *anyopaque, n: usize, alignment: std.mem.Alignment, ra: usize) ?[*]u8 { - const self: *AllocatorWrapper = @ptrCast(@alignCast(ctx)); - self.total_bytes_allocated += alignment.toByteUnits() + n; - return self.child_allocator.rawAlloc(n, alignment, ra); - } - - fn resize(ctx: *anyopaque, buf: []u8, alignment: std.mem.Alignment, new_len: usize, ret_addr: usize) bool { - const self: *AllocatorWrapper = @ptrCast(@alignCast(ctx)); - return self.child_allocator.rawResize(buf, alignment, new_len, ret_addr); - } - - fn remap(context: *anyopaque, memory: []u8, alignment: std.mem.Alignment, new_len: usize, return_address: usize) ?[*]u8 { - const self: *AllocatorWrapper = @ptrCast(@alignCast(context)); - return self.child_allocator.rawRemap(memory, alignment, new_len, return_address); - } - - fn free(ctx: *anyopaque, buf: []u8, alignment: std.mem.Alignment, ret_addr: usize) void { - const self: *AllocatorWrapper = @ptrCast(@alignCast(ctx)); - return self.child_allocator.rawFree(buf, alignment, ret_addr); - } -}; - options: ModuleOptions, it: WordIterator, @@ -114,8 +76,6 @@ bindings: [lib.SPIRV_MAX_SET][lib.SPIRV_MAX_SET_BINDINGS]SpvWord, builtins: std.EnumMap(spv.SpvBuiltIn, SpvWord), push_constants: []Value, -needed_runtime_bytes: usize, - pub fn init(allocator: std.mem.Allocator, source: []const SpvWord, options: ModuleOptions) ModuleError!Self { var self: Self = std.mem.zeroInit(Self, .{ .options = options, @@ -131,8 +91,6 @@ pub fn init(allocator: std.mem.Allocator, source: []const SpvWord, options: Modu op.initRuntimeDispatcher(); - var wrapped_allocator: AllocatorWrapper = .{ .child_allocator = allocator }; - self.it = WordIterator.init(self.code); const magic = self.it.next() catch return ModuleError.InvalidSpirV; @@ -152,7 +110,7 @@ pub fn init(allocator: std.mem.Allocator, source: []const SpvWord, options: Modu self.generator_version = @intCast(generator & 0x0000FFFF); self.bound = self.it.next() catch return ModuleError.InvalidSpirV; - self.results = wrapped_allocator.allocator().alloc(Result, self.bound) catch return ModuleError.OutOfMemory; + self.results = allocator.alloc(Result, self.bound) catch return ModuleError.OutOfMemory; errdefer allocator.free(self.results); for (self.results) |*result| { @@ -208,8 +166,6 @@ pub fn init(allocator: std.mem.Allocator, source: []const SpvWord, options: Modu }); } - self.needed_runtime_bytes += wrapped_allocator.total_bytes_allocated; - //@import("pretty").print(allocator, self.results, .{ // .tab_size = 4, // .max_depth = 0, @@ -235,8 +191,6 @@ fn pass(self: *Self, allocator: std.mem.Allocator) ModuleError!void { var rt = Runtime.init(allocator, self) catch return ModuleError.OutOfMemory; defer rt.deinit(allocator); - var wrapped_allocator: AllocatorWrapper = .{ .child_allocator = allocator }; - while (rt.it.nextOrNull()) |opcode_data| { const word_count = ((opcode_data & (~spv.SpvOpCodeMask)) >> spv.SpvWordCountShift) - 1; const opcode = (opcode_data & spv.SpvOpCodeMask); @@ -244,14 +198,12 @@ fn pass(self: *Self, allocator: std.mem.Allocator) ModuleError!void { var it_tmp = rt.it; // Save because operations may iter on this iterator if (std.enums.fromInt(spv.SpvOp, opcode)) |spv_op| { if (op.SetupDispatcher.get(spv_op)) |pfn| { - pfn(wrapped_allocator.allocator(), word_count, &rt) catch return ModuleError.InvalidSpirV; + pfn(allocator, word_count, &rt) catch return ModuleError.InvalidSpirV; } } _ = it_tmp.skipN(word_count); rt.it = it_tmp; } - - self.needed_runtime_bytes += wrapped_allocator.total_bytes_allocated; } fn applyDecorations(self: *Self) ModuleError!void { diff --git a/src/Result.zig b/src/Result.zig index 7b9948c..f856eba 100644 --- a/src/Result.zig +++ b/src/Result.zig @@ -72,7 +72,7 @@ const ImageInfo = struct { access: spv.SpvAccessQualifier, }; -const Decoration = struct { +pub const Decoration = struct { rtype: spv.SpvDecoration, literal_1: SpvWord, literal_2: ?SpvWord, @@ -112,6 +112,7 @@ pub const TypeData = union(Type) { components_type_word: SpvWord, components_type: Type, member_count: SpvWord, + stride: SpvWord, }, RuntimeArray: struct { components_type_word: SpvWord, @@ -142,7 +143,7 @@ pub const TypeData = union(Type) { .Int => |i| @divExact(i.bit_length, 8), .Float => |f| @divExact(f.bit_length, 8), .Vector => |v| results[v.components_type_word].variant.?.Type.getSize(results), - .Array => |a| results[a.components_type_word].variant.?.Type.getSize(results), + .Array => |a| a.stride, .Matrix => |m| results[m.column_type_word].variant.?.Type.getSize(results), .RuntimeArray => |a| a.stride, .Structure => |s| blk: { diff --git a/src/Value.zig b/src/Value.zig index e00ba95..35caa57 100644 --- a/src/Value.zig +++ b/src/Value.zig @@ -220,11 +220,10 @@ pub const Value = union(Type) { }, .Structure => |s| .{ .Structure = blk: { - const offsets = allocator.dupe(?SpvWord, s.offsets) catch return RuntimeError.OutOfMemory; const values = allocator.dupe(Self, s.values) catch return RuntimeError.OutOfMemory; for (values, s.values) |*new_value, value| new_value.* = try value.dupe(allocator); break :blk .{ - .offsets = offsets, + .offsets = allocator.dupe(?SpvWord, s.offsets) catch return RuntimeError.OutOfMemory, .values = values, }; }, diff --git a/src/WordIterator.zig b/src/WordIterator.zig index 07cecaf..12f5985 100644 --- a/src/WordIterator.zig +++ b/src/WordIterator.zig @@ -59,6 +59,7 @@ pub inline fn skipN(self: *Self, count: usize) bool { pub inline fn skipToEnd(self: *Self) void { self.index = self.buffer.len; + self.did_jump = true; } pub inline fn emitSourceLocation(self: *const Self) usize { diff --git a/src/opcodes.zig b/src/opcodes.zig index df85ebf..0c34f19 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -89,6 +89,7 @@ pub const SetupDispatcher = block: { .BitwiseXor = autoSetupConstant, .Capability = opCapability, .CompositeConstruct = autoSetupConstant, + .CompositeInsert = autoSetupConstant, .Constant = opConstant, .ConstantComposite = opConstantComposite, .ConvertFToS = autoSetupConstant, @@ -98,6 +99,7 @@ pub const SetupDispatcher = block: { .ConvertUToF = autoSetupConstant, .ConvertUToPtr = autoSetupConstant, .Decorate = opDecorate, + .DecorationGroup = opDecorationGroup, .Dot = autoSetupConstant, .EntryPoint = opEntryPoint, .ExecutionMode = opExecutionMode, @@ -126,6 +128,8 @@ pub const SetupDispatcher = block: { .FunctionCall = autoSetupConstant, .FunctionEnd = opFunctionEnd, .FunctionParameter = opFunctionParameter, + .GroupDecorate = opGroupDecorate, + .GroupMemberDecorate = opGroupMemberDecorate, .IAdd = autoSetupConstant, .IEqual = autoSetupConstant, .IMul = autoSetupConstant, @@ -209,6 +213,7 @@ pub fn initRuntimeDispatcher() void { runtime_dispatcher[@intFromEnum(spv.SpvOp.BranchConditional)] = opBranchConditional; runtime_dispatcher[@intFromEnum(spv.SpvOp.CompositeConstruct)] = opCompositeConstruct; runtime_dispatcher[@intFromEnum(spv.SpvOp.CompositeExtract)] = opCompositeExtract; + runtime_dispatcher[@intFromEnum(spv.SpvOp.CompositeInsert)] = opCompositeInsert; runtime_dispatcher[@intFromEnum(spv.SpvOp.ConvertFToS)] = ConversionEngine(.Float, .SInt).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.ConvertFToU)] = ConversionEngine(.Float, .UInt).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.ConvertSToF)] = ConversionEngine(.SInt, .Float).op; @@ -241,6 +246,7 @@ pub fn initRuntimeDispatcher() void { runtime_dispatcher[@intFromEnum(spv.SpvOp.IMul)] = MathEngine(.SInt, .Mul).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.INotEqual)] = CondEngine(.SInt, .NotEqual).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.ISub)] = MathEngine(.SInt, .Sub).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.InBoundsAccessChain)] = opAccessChain; runtime_dispatcher[@intFromEnum(spv.SpvOp.IsFinite)] = CondEngine(.Float, .IsNan).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.IsInf)] = CondEngine(.Float, .IsInf).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.IsNan)] = CondEngine(.Float, .IsNan).op; @@ -1119,6 +1125,12 @@ fn addDecoration(allocator: std.mem.Allocator, rt: *Runtime, target: SpvWord, de } } +fn cloneDecorationTo(allocator: std.mem.Allocator, rt: *Runtime, target: SpvWord, decoration: *const Result.Decoration, member: ?SpvWord) RuntimeError!void { + const out = rt.mod.results[target].decorations.addOne(allocator) catch return RuntimeError.OutOfMemory; + out.* = decoration.*; + out.index = if (member) |m| m else decoration.index; +} + fn autoSetupConstant(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { _ = try setupConstant(allocator, rt); } @@ -1542,6 +1554,113 @@ fn opCompositeExtract(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Ru }; } +fn opCompositeInsert(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void { + _ = try rt.it.next(); + const id = try rt.it.next(); + const object = try rt.results[try rt.it.next()].getValue(); + const composite = try rt.results[try rt.it.next()].getValue(); + + const target = try rt.results[id].getValue(); + + copyValue(target, composite); + + const index_count = word_count - 4; + + var arena = std.heap.ArenaAllocator.init(allocator); + defer arena.deinit(); + + const helpers = struct { + fn insertAt( + alloc: std.mem.Allocator, + results: []const Result, + current: *Value, + object_value: *const Value, + indices: []const SpvWord, + ) RuntimeError!void { + if (indices.len == 0) { + copyValue(current, object_value); + return; + } + + const index = indices[0]; + + if (current.getCompositeDataOrNull()) |children| { + if (index >= children.len) return RuntimeError.OutOfBounds; + return insertAt(alloc, results, &children[index], object_value, indices[1..]); + } + + switch (current.*) { + .Structure => |*s| { + if (index >= s.values.len) return RuntimeError.OutOfBounds; + return insertAt(alloc, results, &s.values[index], object_value, indices[1..]); + }, + + .RuntimeArray => |*arr| { + if (index >= arr.getLen()) return RuntimeError.OutOfBounds; + + const elem_offset = arr.getOffsetOfIndex(index); + + if (indices.len == 1) { + _ = try object_value.read(arr.data[elem_offset..]); + return; + } + + var elem_value = try Value.init(alloc, results, arr.type_word); + _ = try elem_value.writeConst(arr.data[elem_offset..]); + try insertAt(alloc, results, &elem_value, object_value, indices[1..]); + _ = try elem_value.read(arr.data[elem_offset..]); + }, + + .Vector4f32 => |*v| { + if (index >= 4 or indices.len != 1) return RuntimeError.InvalidSpirV; + v[index] = (try getValuePrimitiveField(.Float, 32, @constCast(object_value))).*; + }, + .Vector3f32 => |*v| { + if (index >= 3 or indices.len != 1) return RuntimeError.InvalidSpirV; + v[index] = (try getValuePrimitiveField(.Float, 32, @constCast(object_value))).*; + }, + .Vector2f32 => |*v| { + if (index >= 2 or indices.len != 1) return RuntimeError.InvalidSpirV; + v[index] = (try getValuePrimitiveField(.Float, 32, @constCast(object_value))).*; + }, + + .Vector4i32 => |*v| { + if (index >= 4 or indices.len != 1) return RuntimeError.InvalidSpirV; + v[index] = (try getValuePrimitiveField(.SInt, 32, @constCast(object_value))).*; + }, + .Vector3i32 => |*v| { + if (index >= 3 or indices.len != 1) return RuntimeError.InvalidSpirV; + v[index] = (try getValuePrimitiveField(.SInt, 32, @constCast(object_value))).*; + }, + .Vector2i32 => |*v| { + if (index >= 2 or indices.len != 1) return RuntimeError.InvalidSpirV; + v[index] = (try getValuePrimitiveField(.SInt, 32, @constCast(object_value))).*; + }, + + .Vector4u32 => |*v| { + if (index >= 4 or indices.len != 1) return RuntimeError.InvalidSpirV; + v[index] = (try getValuePrimitiveField(.UInt, 32, @constCast(object_value))).*; + }, + .Vector3u32 => |*v| { + if (index >= 3 or indices.len != 1) return RuntimeError.InvalidSpirV; + v[index] = (try getValuePrimitiveField(.UInt, 32, @constCast(object_value))).*; + }, + .Vector2u32 => |*v| { + if (index >= 2 or indices.len != 1) return RuntimeError.InvalidSpirV; + v[index] = (try getValuePrimitiveField(.UInt, 32, @constCast(object_value))).*; + }, + + else => return RuntimeError.InvalidValueType, + } + } + }; + + const indices = try arena.allocator().alloc(SpvWord, index_count); + for (indices) |*idx| idx.* = try rt.it.next(); + + try helpers.insertAt(arena.allocator(), rt.results, target, object, indices); +} + fn opConstant(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void { const target = try setupConstant(allocator, rt); switch (target.variant.?.Constant.value) { @@ -1628,6 +1747,45 @@ fn opDecorateMember(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) Runt try addDecoration(allocator, rt, target, decoration_type, member); } +fn opDecorationGroup(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + _ = rt.it.skip(); +} + +fn opGroupDecorate(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void { + const decoration_group = try rt.it.next(); + + if (word_count < 2) return RuntimeError.InvalidSpirV; + + const group_result = &rt.mod.results[decoration_group]; + + for (0..(word_count - 1)) |_| { + const target = try rt.it.next(); + + for (group_result.decorations.items) |*decoration| { + try cloneDecorationTo(allocator, rt, target, decoration, null); + } + } +} + +fn opGroupMemberDecorate(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void { + const decoration_group = try rt.it.next(); + + if (word_count < 3) return RuntimeError.InvalidSpirV; + if (((word_count - 1) % 2) != 0) return RuntimeError.InvalidSpirV; + + const group_result = &rt.mod.results[decoration_group]; + const pair_count = @divExact(word_count - 1, 2); + + for (0..pair_count) |_| { + const target = try rt.it.next(); + const member = try rt.it.next(); + + for (group_result.decorations.items) |*decoration| { + try cloneDecorationTo(allocator, rt, target, decoration, member); + } + } +} + fn opDot(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { const target_type = (try rt.results[try rt.it.next()].getVariant()).Type; var value = try rt.results[try rt.it.next()].getValue(); @@ -1895,6 +2053,8 @@ fn opReturnValue(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!vo if (rt.function_stack.getLastOrNull()) |function| { var ret_res = rt.results[try rt.it.next()]; copyValue(try function.ret.getValue(), try ret_res.getValue()); + } else { + return RuntimeError.InvalidSpirV; // No current function ??? } _ = rt.function_stack.pop(); @@ -2024,8 +2184,10 @@ fn opStore(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { fn opTypeArray(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { const id = try rt.it.next(); + var target = &rt.mod.results[id]; const components_type_word = try rt.it.next(); - rt.mod.results[id].variant = .{ + const components_type_data = &((try rt.mod.results[components_type_word].getVariant()).*).Type; + target.variant = .{ .Type = .{ .Array = .{ .components_type_word = components_type_word, @@ -2034,6 +2196,13 @@ fn opTypeArray(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void else => return RuntimeError.InvalidSpirV, }, .member_count = try rt.it.next(), + .stride = blk: { + for (target.decorations.items) |decoration| { + if (decoration.rtype == .ArrayStride) + break :blk decoration.literal_1; + } + break :blk @intCast(components_type_data.getSize(rt.mod.results)); + }, }, }, };