diff --git a/src/Result.zig b/src/Result.zig index b8358ec..dc4e322 100644 --- a/src/Result.zig +++ b/src/Result.zig @@ -374,6 +374,7 @@ pub const Value = union(Type) { inline for (0..4) |i| { const start = i * 4; const end = (i + 1) * 4; + if (start >= input.len or end >= input.len) return RuntimeError.OutOfBounds; std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); } return 4 * 4; @@ -382,6 +383,7 @@ pub const Value = union(Type) { inline for (0..3) |i| { const start = i * 4; const end = (i + 1) * 4; + if (start >= input.len or end >= input.len) return RuntimeError.OutOfBounds; std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); } return 3 * 4; @@ -390,6 +392,7 @@ pub const Value = union(Type) { inline for (0..2) |i| { const start = i * 4; const end = (i + 1) * 4; + if (start >= input.len or end >= input.len) return RuntimeError.OutOfBounds; std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); } return 2 * 4; @@ -398,6 +401,7 @@ pub const Value = union(Type) { inline for (0..4) |i| { const start = i * 4; const end = (i + 1) * 4; + if (start >= input.len or end >= input.len) return RuntimeError.OutOfBounds; std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); } return 4 * 4; @@ -406,6 +410,7 @@ pub const Value = union(Type) { inline for (0..3) |i| { const start = i * 4; const end = (i + 1) * 4; + if (start >= input.len or end >= input.len) return RuntimeError.OutOfBounds; std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); } return 3 * 4; @@ -414,6 +419,7 @@ pub const Value = union(Type) { inline for (0..2) |i| { const start = i * 4; const end = (i + 1) * 4; + if (start >= input.len or end >= input.len) return RuntimeError.OutOfBounds; std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); } return 2 * 4; @@ -422,6 +428,7 @@ pub const Value = union(Type) { inline for (0..4) |i| { const start = i * 4; const end = (i + 1) * 4; + if (start >= input.len or end >= input.len) return RuntimeError.OutOfBounds; std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); } return 4 * 4; @@ -430,6 +437,7 @@ pub const Value = union(Type) { inline for (0..3) |i| { const start = i * 4; const end = (i + 1) * 4; + if (start >= input.len or end >= input.len) return RuntimeError.OutOfBounds; std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); } return 3 * 4; @@ -438,6 +446,7 @@ pub const Value = union(Type) { inline for (0..2) |i| { const start = i * 4; const end = (i + 1) * 4; + if (start >= input.len or end >= input.len) return RuntimeError.OutOfBounds; std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); } return 2 * 4; @@ -796,17 +805,21 @@ pub fn resolveSign(target_type: TypeData, rt: *const Runtime) RuntimeError!enum }; } -pub fn resolveType(self: *const Self, results: []const Self) *const Self { +pub inline fn resolveType(self: *const Self, results: []const Self) *const Self { + return if (self.resolveTypeWordOrNull()) |word| &results[word] else self; +} + +pub fn resolveTypeWordOrNull(self: *const Self) ?SpvWord { return if (self.variant) |variant| switch (variant) { .Type => |t| switch (t) { - .Pointer => |ptr| &results[ptr.target], - else => self, + .Pointer => |ptr| ptr.target, + else => null, }, - else => self, + else => null, } else - self; + null; } pub fn getMemberCounts(self: *const Self) usize { diff --git a/src/opcodes.zig b/src/opcodes.zig index 25a2aab..b6ea946 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -42,14 +42,18 @@ const CondOp = enum { Equal, Greater, GreaterEqual, + IsFinite, + IsInf, + IsNan, + IsNormal, Less, LessEqual, - NotEqual, - LogicalEqual, - LogicalNotEqual, LogicalAnd, - LogicalOr, + LogicalEqual, LogicalNot, + LogicalNotEqual, + LogicalOr, + NotEqual, }; const BitOp = enum { @@ -85,6 +89,7 @@ pub const SetupDispatcher = block: { .Capability = opCapability, .CompositeConstruct = autoSetupConstant, .Constant = opConstant, + .ConstantComposite = opConstantComposite, .ConvertFToS = autoSetupConstant, .ConvertFToU = autoSetupConstant, .ConvertPtrToU = autoSetupConstant, @@ -125,6 +130,10 @@ pub const SetupDispatcher = block: { .IMul = autoSetupConstant, .INotEqual = autoSetupConstant, .ISub = autoSetupConstant, + .IsFinite = autoSetupConstant, + .IsInf = autoSetupConstant, + .IsNan = autoSetupConstant, + .IsNormal = autoSetupConstant, .Label = opLabel, .Load = autoSetupConstant, .LogicalAnd = autoSetupConstant, @@ -151,6 +160,7 @@ pub const SetupDispatcher = block: { .SNegate = autoSetupConstant, .SatConvertSToU = autoSetupConstant, .SatConvertUToS = autoSetupConstant, + .Select = autoSetupConstant, .ShiftLeftLogical = autoSetupConstant, .ShiftRightArithmetic = autoSetupConstant, .ShiftRightLogical = autoSetupConstant, @@ -230,6 +240,10 @@ pub fn initRuntimeDispatcher() void { runtime_dispatcher[@intFromEnum(spv.SpvOp.IMul)] = MathEngine(.SInt, .Mul).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.INotEqual)] = CondEngine(.SInt, .NotEqual).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.ISub)] = MathEngine(.SInt, .Sub).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.IsFinite)] = CondEngine(.Float, .IsNan).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.IsInf)] = CondEngine(.Float, .IsInf).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.IsNan)] = CondEngine(.Float, .IsNan).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.IsNormal)] = CondEngine(.Float, .IsNan).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.Kill)] = opKill; runtime_dispatcher[@intFromEnum(spv.SpvOp.Load)] = opLoad; runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalAnd)] = CondEngine(.Bool, .LogicalAnd).op; @@ -251,6 +265,7 @@ pub fn initRuntimeDispatcher() void { runtime_dispatcher[@intFromEnum(spv.SpvOp.SLessThanEqual)] = CondEngine(.SInt, .LessEqual).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.SMod)] = MathEngine(.SInt, .Mod).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.SNegate)] = MathEngine(.SInt, .Negate).opSingle; + runtime_dispatcher[@intFromEnum(spv.SpvOp.Select)] = opSelect; runtime_dispatcher[@intFromEnum(spv.SpvOp.ShiftLeftLogical)] = BitEngine(.UInt, .ShiftLeft).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.ShiftRightArithmetic)] = BitEngine(.SInt, .ShiftRightArithmetic).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.ShiftRightLogical)] = BitEngine(.UInt, .ShiftRight).op; @@ -479,6 +494,14 @@ fn BitEngine(comptime T: ValueType, comptime Op: BitOp) type { }; } +const unary_condition_set = std.EnumSet(CondOp).initMany(&.{ + .IsFinite, + .IsInf, + .IsNan, + .IsNormal, + .LogicalNot, +}); + fn CondOperator(comptime T: ValueType, comptime Op: CondOp) type { return struct { fn operation(comptime TT: type, a: TT, b: TT) RuntimeError!bool { @@ -497,6 +520,10 @@ fn CondOperator(comptime T: ValueType, comptime Op: CondOp) type { fn operationUnary(comptime TT: type, a: TT) RuntimeError!bool { return switch (Op) { + .IsFinite => std.math.isFinite(a), + .IsInf => std.math.isInf(a), + .IsNan => std.math.isNan(a), + .IsNormal => std.math.isNormal(a), .LogicalNot => !a, else => RuntimeError.InvalidSpirV, }; @@ -510,7 +537,7 @@ fn CondOperator(comptime T: ValueType, comptime Op: CondOp) type { const TT = getValuePrimitiveFieldType(T, bits); const a = (try getValuePrimitiveField(T, bits, @constCast(a_v))).*; - if (comptime Op == .LogicalNot) { + if (unary_condition_set.contains(Op)) { dst_bool.Bool = try operationUnary(TT, a); } else { const b_ptr = b_v orelse return RuntimeError.InvalidSpirV; @@ -545,7 +572,7 @@ fn CondEngine(comptime T: ValueType, comptime Op: CondOp) type { const op1_type = try op1_result.getValueTypeWord(); const op1_value = try op1_result.getValue(); - const op2_value: ?*Result.Value = if (comptime Op == .LogicalNot) null else try rt.results[try rt.it.next()].getValue(); + const op2_value: ?*Result.Value = if (unary_condition_set.contains(Op)) null else try rt.results[try rt.it.next()].getValue(); const lane_bits = try Result.resolveLaneBitWidth((try rt.results[op1_type].getVariant()).Type, rt); @@ -1350,6 +1377,16 @@ fn opConstant(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) R } } +fn opConstantComposite(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + const target = try setupConstant(allocator, rt); + const target_value = try target.getValue(); + if (target_value.getCompositeDataOrNull()) |*values| { + for (values.*) |*element| { + copyValue(element, try rt.mod.results[try rt.it.next()].getValue()); + } + } +} + fn opCopyMemory(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { const target = try rt.it.next(); const source = try rt.it.next(); @@ -1574,6 +1611,21 @@ fn opLoad(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { _ = rt.it.skip(); const id = try rt.it.next(); const ptr_id = try rt.it.next(); + + //std.debug.print("\n{d} - {d}\n", .{ id, ptr_id }); + //@import("pretty").print(std.heap.page_allocator, rt.results[id], .{ + // .tab_size = 4, + // .max_depth = 0, + // .struct_max_len = 0, + // .array_max_len = 0, + //}) catch unreachable; + + //@import("pretty").print(std.heap.page_allocator, rt.results[ptr_id], .{ + // .tab_size = 4, + // .max_depth = 0, + // .struct_max_len = 0, + // .array_max_len = 0, + //}) catch unreachable; copyValue(try rt.results[id].getValue(), try rt.results[ptr_id].getValue()); } @@ -1649,6 +1701,30 @@ fn opReturnValue(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!vo } } +fn opSelect(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + _ = rt.it.skip(); + const id = try rt.it.next(); + const cond = try rt.it.next(); + const obj1 = try rt.it.next(); + const obj2 = try rt.it.next(); + + const target_val = try rt.results[id].getValue(); + const cond_val = try rt.results[cond].getValue(); + const obj1_val = try rt.results[obj1].getValue(); + const obj2_val = try rt.results[obj2].getValue(); + + if (target_val.getCompositeDataOrNull()) |*targets| { + for ( + targets.*, + cond_val.getCompositeDataOrNull().?, + obj1_val.getCompositeDataOrNull().?, + obj2_val.getCompositeDataOrNull().?, + ) |*t, c, o1, o2| { + copyValue(t, if (c.Bool) &o1 else &o2); + } + } +} + fn opSourceExtension(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void { rt.mod.extensions.append(allocator, try readStringN(allocator, &rt.it, word_count)) catch return RuntimeError.OutOfMemory; } @@ -1872,12 +1948,13 @@ fn opVariable(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) R const target = &rt.mod.results[id]; - const resolved = rt.mod.results[var_type].resolveType(rt.mod.results); + const resolved_word = if (rt.mod.results[var_type].resolveTypeWordOrNull()) |word| word else var_type; + const resolved = &rt.mod.results[resolved_word]; const member_count = resolved.getMemberCounts(); target.variant = .{ .Variable = .{ .storage_class = storage_class, - .type_word = var_type, + .type_word = resolved_word, .type = switch ((try resolved.getConstVariant()).*) { .Type => |t| @as(Result.Type, t), else => return RuntimeError.InvalidSpirV,