From e8a08d78851dd07e72de7b40821cd2af10f38866 Mon Sep 17 00:00:00 2001 From: Kbz-8 Date: Mon, 16 Mar 2026 03:52:59 +0100 Subject: [PATCH] fixing opSelect and opConstantComposite --- src/opcodes.zig | 246 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 200 insertions(+), 46 deletions(-) diff --git a/src/opcodes.zig b/src/opcodes.zig index 45937b7..f004141 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -323,7 +323,7 @@ fn BitOperator(comptime T: ValueType, comptime Op: BitOp) type { return (v >> @intCast(offset)) & @as(TT, @intCast(bitMask(count))); } - fn operationUnary(comptime TT: type, op1: TT) RuntimeError!TT { + inline fn operationUnary(comptime TT: type, op1: TT) RuntimeError!TT { return switch (Op) { .BitCount => @as(TT, @intCast(@bitSizeOf(TT))), // keep return type TT .BitReverse => @bitReverse(op1), @@ -332,7 +332,7 @@ fn BitOperator(comptime T: ValueType, comptime Op: BitOp) type { }; } - fn operationBinary(comptime TT: type, rt: *Runtime, op1: TT, op2: TT) RuntimeError!TT { + inline fn operationBinary(comptime TT: type, rt: *Runtime, op1: TT, op2: TT) RuntimeError!TT { return switch (Op) { .BitFieldInsert => blk: { const offset = try rt.results[try rt.it.next()].getValue(); @@ -380,29 +380,13 @@ fn BitOperator(comptime T: ValueType, comptime Op: BitOp) type { } } - fn laneRhsPtr(op2_value: ?*Value, index: usize) ?*const Value { + inline fn laneRhsPtr(op2_value: ?*Value, index: usize) ?*const Value { if (comptime isUnaryOp()) return null; const v = op2_value orelse return null; return &v.Vector[index]; } - fn applyFixedVector(comptime ElemT: type, comptime N: usize, dst: *[N]ElemT, op1: *[N]ElemT, op2_value: ?*Value) RuntimeError!void { - if (comptime isUnaryOp()) { - inline for (0..N) |i| dst[i] = try operationUnary(ElemT, op1[i]); - } else { - const op2 = op2_value orelse return RuntimeError.InvalidSpirV; - const b: *const [N]ElemT = switch (N) { - 2 => &op2.*.Vector2u32, // will be overridden by call sites per ElemT/tag - 3 => &op2.*.Vector3u32, - 4 => &op2.*.Vector4u32, - else => unreachable, - }; - _ = b; - return RuntimeError.InvalidSpirV; - } - } - - fn applyFixedVectorBinary( + inline fn applyFixedVectorBinary( comptime ElemT: type, comptime N: usize, rt: *Runtime, @@ -413,7 +397,7 @@ fn BitOperator(comptime T: ValueType, comptime Op: BitOp) type { inline for (0..N) |i| dst[i] = try operationBinary(ElemT, rt, op1[i], op2[i]); } - fn applyFixedVectorUnary( + inline fn applyFixedVectorUnary( comptime ElemT: type, comptime N: usize, dst: *[N]ElemT, @@ -495,17 +479,23 @@ fn BitEngine(comptime T: ValueType, comptime Op: BitOp) type { }; } -const unary_condition_set = std.EnumSet(CondOp).initMany(&.{ - .IsFinite, - .IsInf, - .IsNan, - .IsNormal, - .LogicalNot, -}); - fn CondOperator(comptime T: ValueType, comptime Op: CondOp) type { return struct { - fn operation(comptime TT: type, a: TT, b: TT) RuntimeError!bool { + inline fn isUnaryOp() bool { + return comptime switch (Op) { + .IsFinite, .IsInf, .IsNan, .IsNormal, .LogicalNot => true, + else => false, + }; + } + + inline fn operationBinary(comptime TT: type, a: TT, b: TT) RuntimeError!bool { + if (comptime TT == bool) { + switch (Op) { + .LogicalAnd => return a and b, + .LogicalOr => return a or b, + else => {}, + } + } return switch (Op) { .Equal, .LogicalEqual => a == b, .NotEqual, .LogicalNotEqual => a != b, @@ -513,24 +503,27 @@ fn CondOperator(comptime T: ValueType, comptime Op: CondOp) type { .GreaterEqual => a >= b, .Less => a < b, .LessEqual => a <= b, - .LogicalAnd => a and b, - .LogicalOr => a or b, else => RuntimeError.InvalidSpirV, }; } - fn operationUnary(comptime TT: type, a: TT) RuntimeError!bool { + inline fn operationUnary(comptime TT: type, a: TT) RuntimeError!bool { + if (comptime TT == bool) { + switch (Op) { + .LogicalNot => return !a, + else => {}, + } + } return switch (Op) { .IsFinite => std.math.isFinite(a), .IsInf => std.math.isInf(a), .IsNan => std.math.isNan(a), .IsNormal => std.math.isNormal(a), - .LogicalNot => !a, else => RuntimeError.InvalidSpirV, }; } - fn applyLane(bit_count: SpvWord, dst_bool: *Value, a_v: *const Value, b_v: ?*const Value) RuntimeError!void { + fn applyScalarBits(bit_count: SpvWord, dst_bool: *Value, a_v: *const Value, b_v: ?*const Value) RuntimeError!void { switch (bit_count) { inline 8, 16, 32, 64 => |bits| { if (bits == 8 and T == .Float) return RuntimeError.InvalidSpirV; @@ -538,23 +531,42 @@ fn CondOperator(comptime T: ValueType, comptime Op: CondOp) type { const TT = getValuePrimitiveFieldType(T, bits); const a = (try getValuePrimitiveField(T, bits, @constCast(a_v))).*; - if (unary_condition_set.contains(Op)) { + if (comptime isUnaryOp()) { dst_bool.Bool = try operationUnary(TT, a); } else { const b_ptr = b_v orelse return RuntimeError.InvalidSpirV; const b = (try getValuePrimitiveField(T, bits, @constCast(b_ptr))).*; - dst_bool.Bool = try operation(TT, a, b); + dst_bool.Bool = try operationBinary(TT, a, b); } }, else => return RuntimeError.InvalidSpirV, } } - fn laneRhsPtr(op2_value: ?*Value, index: usize) ?*const Value { + inline fn laneRhsPtr(op2_value: ?*Value, index: usize) ?*const Value { if (comptime Op == .LogicalNot) return null; const v = op2_value orelse return null; return &v.Vector[index]; } + + inline fn applyFixedVectorBinary( + comptime ElemT: type, + comptime N: usize, + dst: []Value, + op1: *[N]ElemT, + op2: *[N]ElemT, + ) RuntimeError!void { + inline for (0..N) |i| dst[i].Bool = try operationBinary(ElemT, op1[i], op2[i]); + } + + inline fn applyFixedVectorUnary( + comptime ElemT: type, + comptime N: usize, + dst: []Value, + op1: *[N]ElemT, + ) RuntimeError!void { + inline for (0..N) |i| dst[i].Bool = try operationUnary(ElemT, op1[i]); + } }; } @@ -573,18 +585,80 @@ fn CondEngine(comptime T: ValueType, comptime Op: CondOp) type { const op1_type = try op1_result.getValueTypeWord(); const op1_value = try op1_result.getValue(); - const op2_value: ?*Value = if (unary_condition_set.contains(Op)) null else try rt.results[try rt.it.next()].getValue(); + const operator = CondOperator(T, Op); + + const op2_value: ?*Value = if (comptime operator.isUnaryOp()) null else try rt.results[try rt.it.next()].getValue(); const lane_bits = try Result.resolveLaneBitWidth((try rt.results[op1_type].getVariant()).Type, rt); - const operator = CondOperator(T, Op); - switch (dst.*) { - .Bool => try operator.applyLane(lane_bits, dst, op1_value, op2_value), + .Bool => try operator.applyScalarBits(lane_bits, dst, op1_value, op2_value), - .Vector => |dst_vec| for (dst_vec, op1_value.Vector, 0..) |*d_lane, a_lane, i| { - const b_ptr = operator.laneRhsPtr(op2_value, i); - try operator.applyLane(lane_bits, d_lane, &a_lane, b_ptr); + .Vector => |dst_vec| { + switch (op1_value.*) { + .Vector => |op1_vec| for (dst_vec, op1_vec, 0..) |*d_lane, a_lane, i| { + const b_ptr = operator.laneRhsPtr(op2_value, i); + try operator.applyScalarBits(lane_bits, d_lane, &a_lane, b_ptr); + }, + + .Vector4f32 => |*op1_vec| { + if (comptime operator.isUnaryOp()) + try operator.applyFixedVectorUnary(f32, 4, dst_vec, op1_vec) + else + try operator.applyFixedVectorBinary(f32, 4, dst_vec, op1_vec, &op2_value.?.Vector4f32); + }, + .Vector3f32 => |*op1_vec| { + if (comptime operator.isUnaryOp()) + try operator.applyFixedVectorUnary(f32, 3, dst_vec, op1_vec) + else + try operator.applyFixedVectorBinary(f32, 3, dst_vec, op1_vec, &op2_value.?.Vector3f32); + }, + .Vector2f32 => |*op1_vec| { + if (comptime operator.isUnaryOp()) + try operator.applyFixedVectorUnary(f32, 2, dst_vec, op1_vec) + else + try operator.applyFixedVectorBinary(f32, 2, dst_vec, op1_vec, &op2_value.?.Vector2f32); + }, + + //.Vector4i32 => |*op1_vec| { + // if (comptime operator.isUnaryOp()) + // try operator.applyFixedVectorUnary(i32, 4, dst_vec, op1_vec) + // else + // try operator.applyFixedVectorBinary(i32, 4, dst_vec, op1_vec, &op2_value.?.Vector4i32); + //}, + //.Vector3i32 => |*op1_vec| { + // if (comptime operator.isUnaryOp()) + // try operator.applyFixedVectorUnary(i32, 3, dst_vec, op1_vec) + // else + // try operator.applyFixedVectorBinary(i32, 3, dst_vec, op1_vec, &op2_value.?.Vector3i32); + //}, + //.Vector2i32 => |*op1_vec| { + // if (comptime operator.isUnaryOp()) + // try operator.applyFixedVectorUnary(i32, 2, dst_vec, op1_vec) + // else + // try operator.applyFixedVectorBinary(i32, 2, dst_vec, op1_vec, &op2_value.?.Vector2i32); + //}, + + //.Vector4u32 => |*op1_vec| { + // if (comptime operator.isUnaryOp()) + // try operator.applyFixedVectorUnary(u32, 4, dst_vec, op1_vec) + // else + // try operator.applyFixedVectorBinary(u32, 4, dst_vec, op1_vec, &op2_value.?.Vector4u32); + //}, + //.Vector3u32 => |*op1_vec| { + // if (comptime operator.isUnaryOp()) + // try operator.applyFixedVectorUnary(u32, 3, dst_vec, op1_vec) + // else + // try operator.applyFixedVectorBinary(u32, 3, dst_vec, op1_vec, &op2_value.?.Vector3u32); + //}, + //.Vector2u32 => |*op1_vec| { + // if (comptime operator.isUnaryOp()) + // try operator.applyFixedVectorUnary(u32, 2, dst_vec, op1_vec) + // else + // try operator.applyFixedVectorBinary(u32, 2, dst_vec, op1_vec, &op2_value.?.Vector2u32); + //}, + else => return RuntimeError.InvalidSpirV, + } }, else => return RuntimeError.InvalidSpirV, @@ -1338,7 +1412,6 @@ fn opCompositeExtract(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Ru fn opConstant(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void { const target = try setupConstant(allocator, rt); - // No check on null and sizes, absolute trust in this shit switch (target.variant.?.Constant.value) { .Int => |*i| { if (word_count - 2 != 1) { @@ -1739,6 +1812,87 @@ fn opSelect(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { ) |*t, c, o1, o2| { copyValue(t, if (c.Bool) &o1 else &o2); } + return; + } + + switch (target_val.*) { + .Bool, .Int, .Float => copyValue(target_val, if (cond_val.Bool) obj1_val else obj2_val), + + .Vector4f32 => |*v| { + const cond_vec = @Vector(4, bool){ + cond_val.Vector[0].Bool, + cond_val.Vector[1].Bool, + cond_val.Vector[2].Bool, + cond_val.Vector[3].Bool, + }; + v.* = @select(f32, cond_vec, obj1_val.Vector4f32, obj2_val.Vector4f32); + }, + .Vector3f32 => |*v| { + const cond_vec = @Vector(3, bool){ + cond_val.Vector[0].Bool, + cond_val.Vector[1].Bool, + cond_val.Vector[2].Bool, + }; + v.* = @select(f32, cond_vec, obj1_val.Vector3f32, obj2_val.Vector3f32); + }, + .Vector2f32 => |*v| { + const cond_vec = @Vector(2, bool){ + cond_val.Vector[0].Bool, + cond_val.Vector[1].Bool, + }; + v.* = @select(f32, cond_vec, obj1_val.Vector2f32, obj2_val.Vector2f32); + }, + + .Vector4i32 => |*v| { + const cond_vec = @Vector(4, bool){ + cond_val.Vector[0].Bool, + cond_val.Vector[1].Bool, + cond_val.Vector[2].Bool, + cond_val.Vector[3].Bool, + }; + v.* = @select(i32, cond_vec, obj1_val.Vector4i32, obj2_val.Vector4i32); + }, + .Vector3i32 => |*v| { + const cond_vec = @Vector(3, bool){ + cond_val.Vector[0].Bool, + cond_val.Vector[1].Bool, + cond_val.Vector[2].Bool, + }; + v.* = @select(i32, cond_vec, obj1_val.Vector3i32, obj2_val.Vector3i32); + }, + .Vector2i32 => |*v| { + const cond_vec = @Vector(2, bool){ + cond_val.Vector[0].Bool, + cond_val.Vector[1].Bool, + }; + v.* = @select(i32, cond_vec, obj1_val.Vector2i32, obj2_val.Vector2i32); + }, + + .Vector4u32 => |*v| { + const cond_vec = @Vector(4, bool){ + cond_val.Vector[0].Bool, + cond_val.Vector[1].Bool, + cond_val.Vector[2].Bool, + cond_val.Vector[3].Bool, + }; + v.* = @select(u32, cond_vec, obj1_val.Vector4u32, obj2_val.Vector4u32); + }, + .Vector3u32 => |*v| { + const cond_vec = @Vector(3, bool){ + cond_val.Vector[0].Bool, + cond_val.Vector[1].Bool, + cond_val.Vector[2].Bool, + }; + v.* = @select(u32, cond_vec, obj1_val.Vector3u32, obj2_val.Vector3u32); + }, + .Vector2u32 => |*v| { + const cond_vec = @Vector(2, bool){ + cond_val.Vector[0].Bool, + cond_val.Vector[1].Bool, + }; + v.* = @select(u32, cond_vec, obj1_val.Vector2u32, obj2_val.Vector2u32); + }, + else => return RuntimeError.InvalidSpirV, } }