diff --git a/src/opcodes.zig b/src/opcodes.zig index 9755ad5..10e92a1 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -270,6 +270,8 @@ pub fn initRuntimeDispatcher() void { runtime_dispatcher[@intFromEnum(spv.SpvOp.FUnordNotEqual)] = CondEngine(.Float, .NotEqual).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.FunctionCall)] = opFunctionCall; runtime_dispatcher[@intFromEnum(spv.SpvOp.IAdd)] = MathEngine(.SInt, .Add, false).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.IAddCarry)] = opIAddCarry; + runtime_dispatcher[@intFromEnum(spv.SpvOp.ISubBorrow)] = opISubBorrow; runtime_dispatcher[@intFromEnum(spv.SpvOp.IEqual)] = CondEngine(.SInt, .Equal).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.IMul)] = MathEngine(.SInt, .Mul, false).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.INotEqual)] = CondEngine(.SInt, .NotEqual).op; @@ -1953,7 +1955,7 @@ fn writeMulExtendedBits(comptime bits: u32, dst: *Value, lane_index: usize, valu } fn opMulExtended(comptime is_signed: bool, rt: *Runtime) RuntimeError!void { - _ = try rt.it.next(); // Result Type + _ = try rt.it.next(); // result Type const result_id = try rt.it.next(); const lhs = try rt.results[try rt.it.next()].getValue(); const rhs = try rt.results[try rt.it.next()].getValue(); @@ -2008,6 +2010,96 @@ fn opMulExtended(comptime is_signed: bool, rt: *Runtime) RuntimeError!void { } } +fn opIAddCarry(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + _ = try rt.it.next(); // result Type + const result_id = try rt.it.next(); + const lhs = try rt.results[try rt.it.next()].getValue(); + const rhs = try rt.results[try rt.it.next()].getValue(); + + const result = try rt.results[result_id].getValue(); + const result_members = switch (result.*) { + .Structure => |*s| s.values, + else => return RuntimeError.InvalidSpirV, + }; + if (result_members.len != 2) return RuntimeError.InvalidSpirV; + + const value_dst = &result_members[0]; + const carry_dst = &result_members[1]; + + const lane_count = try lhs.resolveLaneCount(); + if (try rhs.resolveLaneCount() != lane_count) return RuntimeError.InvalidSpirV; + if (try value_dst.resolveLaneCount() != lane_count) return RuntimeError.InvalidSpirV; + if (try carry_dst.resolveLaneCount() != lane_count) return RuntimeError.InvalidSpirV; + + const lane_bits = try lhs.resolveLaneBitWidth(); + if (try rhs.resolveLaneBitWidth() != lane_bits) return RuntimeError.InvalidSpirV; + if (try value_dst.resolveLaneBitWidth() != lane_bits) return RuntimeError.InvalidSpirV; + if (try carry_dst.resolveLaneBitWidth() != lane_bits) return RuntimeError.InvalidSpirV; + + switch (lane_bits) { + inline 8, 16, 32, 64 => |bits| { + const UIntT = Value.getPrimitiveFieldType(.UInt, bits); + + for (0..lane_count) |lane_index| { + const l: UIntT = try Value.readLane(.UInt, bits, lhs, lane_index); + const r: UIntT = try Value.readLane(.UInt, bits, rhs, lane_index); + const add_result = @addWithOverflow(l, r); + const sum = add_result[0]; + const carry: UIntT = @intCast(add_result[1]); + + try writeMulExtendedBits(bits, value_dst, lane_index, sum); + try writeMulExtendedBits(bits, carry_dst, lane_index, carry); + } + }, + else => return RuntimeError.InvalidSpirV, + } +} + +fn opISubBorrow(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + _ = try rt.it.next(); // result Type + const result_id = try rt.it.next(); + const lhs = try rt.results[try rt.it.next()].getValue(); + const rhs = try rt.results[try rt.it.next()].getValue(); + + const result = try rt.results[result_id].getValue(); + const result_members = switch (result.*) { + .Structure => |*s| s.values, + else => return RuntimeError.InvalidSpirV, + }; + if (result_members.len != 2) return RuntimeError.InvalidSpirV; + + const value_dst = &result_members[0]; + const borrow_dst = &result_members[1]; + + const lane_count = try lhs.resolveLaneCount(); + if (try rhs.resolveLaneCount() != lane_count) return RuntimeError.InvalidSpirV; + if (try value_dst.resolveLaneCount() != lane_count) return RuntimeError.InvalidSpirV; + if (try borrow_dst.resolveLaneCount() != lane_count) return RuntimeError.InvalidSpirV; + + const lane_bits = try lhs.resolveLaneBitWidth(); + if (try rhs.resolveLaneBitWidth() != lane_bits) return RuntimeError.InvalidSpirV; + if (try value_dst.resolveLaneBitWidth() != lane_bits) return RuntimeError.InvalidSpirV; + if (try borrow_dst.resolveLaneBitWidth() != lane_bits) return RuntimeError.InvalidSpirV; + + switch (lane_bits) { + inline 8, 16, 32, 64 => |bits| { + const UIntT = Value.getPrimitiveFieldType(.UInt, bits); + + for (0..lane_count) |lane_index| { + const l: UIntT = try Value.readLane(.UInt, bits, lhs, lane_index); + const r: UIntT = try Value.readLane(.UInt, bits, rhs, lane_index); + const sub_result = @subWithOverflow(l, r); + const diff = sub_result[0]; + const borrow: UIntT = @intCast(sub_result[1]); + + try writeMulExtendedBits(bits, value_dst, lane_index, diff); + try writeMulExtendedBits(bits, borrow_dst, lane_index, borrow); + } + }, + else => return RuntimeError.InvalidSpirV, + } +} + fn opUMulExtended(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { try opMulExtended(false, rt); }