adding msb, lsb and spec constants
Build / build (push) Successful in 2m3s
Test / build (push) Successful in 8m40s

This commit is contained in:
2026-03-30 01:00:00 +02:00
parent fbaf85a849
commit 6c8b364c7d
6 changed files with 431 additions and 243 deletions
+118 -20
View File
@@ -8,19 +8,18 @@ const Module = @import("../Module.zig");
const Runtime = @import("../Runtime.zig");
const Result = @import("../Result.zig");
const WordIterator = @import("../WordIterator.zig");
const Value = @import("../Value.zig").Value;
const value_ns = @import("../Value.zig");
const RuntimeError = Runtime.RuntimeError;
const ValueType = opc.ValueType;
const getValuePrimitiveField = opc.getValuePrimitiveField;
const getValuePrimitiveFieldType = opc.getValuePrimitiveFieldType;
const SpvVoid = spv.SpvVoid;
const SpvByte = spv.SpvByte;
const SpvWord = spv.SpvWord;
const SpvBool = spv.SpvBool;
const Value = value_ns.Value;
const PrimitiveType = value_ns.PrimitiveType;
const MathOp = enum {
Acos,
Acosh,
@@ -67,6 +66,12 @@ const MathOp = enum {
UMin,
};
const IntBitOp = enum {
FindILsb,
FindSMsb,
FindUMsb,
};
pub const OpCodeExtFunc = opc.OpCodeExtFunc;
/// Not an EnumMap as it is way too slow for this purpose
@@ -93,6 +98,9 @@ pub fn initRuntimeDispatcher() void {
runtime_dispatcher[@intFromEnum(ext.GLSLOp.Sqrt)] = MathEngine(.Float, .Sqrt).opSingleOperator;
runtime_dispatcher[@intFromEnum(ext.GLSLOp.Tan)] = MathEngine(.Float, .Tan).opSingleOperator;
runtime_dispatcher[@intFromEnum(ext.GLSLOp.Trunc)] = MathEngine(.Float, .Trunc).opSingleOperator;
runtime_dispatcher[@intFromEnum(ext.GLSLOp.FindILsb)] = IntBitEngine(.FindILsb).op;
runtime_dispatcher[@intFromEnum(ext.GLSLOp.FindSMsb)] = IntBitEngine(.FindSMsb).op;
runtime_dispatcher[@intFromEnum(ext.GLSLOp.FindUMsb)] = IntBitEngine(.FindUMsb).op;
// zig fmt: on
}
@@ -104,7 +112,7 @@ fn isFloatOrF32Vector(comptime T: type) bool {
};
}
fn MathEngine(comptime T: ValueType, comptime Op: MathOp) type {
fn MathEngine(comptime T: PrimitiveType, comptime Op: MathOp) type {
return struct {
fn opSingleOperator(_: std.mem.Allocator, target_type_id: SpvWord, id: SpvWord, _: SpvWord, rt: *Runtime) RuntimeError!void {
const target_type = (try rt.results[target_type_id].getVariant()).Type;
@@ -147,9 +155,9 @@ fn MathEngine(comptime T: ValueType, comptime Op: MathOp) type {
inline 8, 16, 32, 64 => |bits| {
if (bits == 8 and T == .Float) return RuntimeError.InvalidSpirV;
const ScalarT = getValuePrimitiveFieldType(T, bits);
const d_field = try getValuePrimitiveField(T, bits, d);
const s_field = try getValuePrimitiveField(T, bits, @constCast(s));
const ScalarT = Value.getPrimitiveFieldType(T, bits);
const d_field = try Value.getPrimitiveField(T, bits, d);
const s_field = try Value.getPrimitiveField(T, bits, @constCast(s));
d_field.* = try operation(ScalarT, s_field.*);
},
else => return RuntimeError.InvalidSpirV,
@@ -201,10 +209,10 @@ fn MathEngine(comptime T: ValueType, comptime Op: MathOp) type {
inline 8, 16, 32, 64 => |bits| {
if (bits == 8 and T == .Float) return RuntimeError.InvalidSpirV;
const ScalarT = getValuePrimitiveFieldType(T, bits);
const d_field = try getValuePrimitiveField(T, bits, d);
const l_field = try getValuePrimitiveField(T, bits, @constCast(l));
const r_field = try getValuePrimitiveField(T, bits, @constCast(r));
const ScalarT = Value.getPrimitiveFieldType(T, bits);
const d_field = try Value.getPrimitiveField(T, bits, d);
const l_field = try Value.getPrimitiveField(T, bits, @constCast(l));
const r_field = try Value.getPrimitiveField(T, bits, @constCast(r));
d_field.* = try operation(ScalarT, l_field.*, r_field.*);
},
else => return RuntimeError.InvalidSpirV,
@@ -237,6 +245,96 @@ fn MathEngine(comptime T: ValueType, comptime Op: MathOp) type {
};
}
fn IntBitEngine(comptime op_kind: IntBitOp) type {
return struct {
inline fn findILsb32(x: u32) i32 {
if (x == 0) return -1;
return @intCast(@ctz(x));
}
inline fn findUMsb32(x: u32) i32 {
if (x == 0) return -1;
return 31 - @as(i32, @intCast(@clz(x)));
}
inline fn findSMsb32(x: i32) i32 {
if (x == 0 or x == -1) return -1;
if (x > 0) {
return findUMsb32(@bitCast(x));
}
return findUMsb32(@bitCast(~x));
}
inline fn computeSigned(x: i32) i32 {
return switch (op_kind) {
.FindILsb => findILsb32(@bitCast(x)),
.FindSMsb => findSMsb32(x),
.FindUMsb => findUMsb32(@bitCast(x)),
};
}
inline fn computeUnsigned(x: u32) u32 {
const result: i32 = switch (op_kind) {
.FindILsb => findILsb32(x),
.FindSMsb => findSMsb32(@bitCast(x)),
.FindUMsb => findUMsb32(x),
};
return @bitCast(result);
}
fn readSourceLane(src: *const Value, lane_index: usize) RuntimeError!u32 {
return switch (op_kind) {
.FindSMsb => @bitCast(try Value.readLane(.SInt, 32, src, lane_index)),
.FindILsb, .FindUMsb => try Value.readLane(.UInt, 32, src, lane_index),
};
}
fn writeDestLane(dst: *Value, lane_index: usize, bits: u32, dst_is_signed: bool) RuntimeError!void {
if (dst_is_signed) {
try Value.writeLane(.SInt, 32, dst, lane_index, @as(i32, @bitCast(bits)));
} else {
try Value.writeLane(.UInt, 32, dst, lane_index, bits);
}
}
fn apply(dst: *Value, src: *const Value, lane_count: usize, dst_is_signed: bool) RuntimeError!void {
for (0..lane_count) |lane_index| {
const src_bits = try readSourceLane(src, lane_index);
const out_bits: u32 = if (dst_is_signed)
@bitCast(computeSigned(@bitCast(src_bits)))
else
computeUnsigned(src_bits);
try writeDestLane(dst, lane_index, out_bits, dst_is_signed);
}
}
fn op(
_: std.mem.Allocator,
target_type_id: SpvWord,
id: SpvWord,
_: SpvWord,
rt: *Runtime,
) RuntimeError!void {
const target_type = (try rt.results[target_type_id].getVariant()).Type;
const dst = try rt.results[id].getValue();
const src = try rt.results[try rt.it.next()].getValue();
const lane_bits = try Result.resolveLaneBitWidth(target_type, rt);
if (lane_bits != 32)
return RuntimeError.InvalidSpirV;
const lane_count = try Result.resolveLaneCount(target_type);
const dst_sign = try Result.resolveSign(target_type, rt);
try apply(dst, src, lane_count, dst_sign == .signed);
}
};
}
fn opLength(_: std.mem.Allocator, target_type_id: SpvWord, id: SpvWord, _: SpvWord, rt: *Runtime) RuntimeError!void {
const target_type = (try rt.results[target_type_id].getVariant()).Type;
const dst = try rt.results[id].getValue();
@@ -247,7 +345,7 @@ fn opLength(_: std.mem.Allocator, target_type_id: SpvWord, id: SpvWord, _: SpvWo
switch (lane_bits) {
inline 16, 32, 64 => |bits| {
var sum: std.meta.Float(bits) = 0.0;
const d_field = try getValuePrimitiveField(.Float, bits, dst);
const d_field = try Value.getPrimitiveField(.Float, bits, dst);
if (bits == 32) { // More likely to be SIMD if f32
switch (src.*) {
@@ -270,12 +368,12 @@ fn opLength(_: std.mem.Allocator, target_type_id: SpvWord, id: SpvWord, _: SpvWo
switch (src.*) {
.Float => {
// Fast path
const s_field = try getValuePrimitiveField(.Float, bits, src);
const s_field = try Value.getPrimitiveField(.Float, bits, src);
d_field.* = s_field.*;
return;
},
.Vector => |src_vec| for (src_vec) |*s_lane| {
const s_field = try getValuePrimitiveField(.Float, bits, s_lane);
const s_field = try Value.getPrimitiveField(.Float, bits, s_lane);
sum += s_field.*;
},
else => return RuntimeError.InvalidSpirV,
@@ -323,11 +421,11 @@ fn opNormalize(_: std.mem.Allocator, target_type_id: SpvWord, id: SpvWord, _: Sp
switch (src.*) {
.Float => {
const s_field = try getValuePrimitiveField(.Float, bits, src);
const s_field = try Value.getPrimitiveField(.Float, bits, src);
sum = s_field.*;
},
.Vector => |src_vec| for (src_vec) |*s_lane| {
const s_field = try getValuePrimitiveField(.Float, bits, s_lane);
const s_field = try Value.getPrimitiveField(.Float, bits, s_lane);
sum += s_field.*;
},
else => return RuntimeError.InvalidSpirV,
@@ -337,8 +435,8 @@ fn opNormalize(_: std.mem.Allocator, target_type_id: SpvWord, id: SpvWord, _: Sp
switch (dst.*) {
.Vector => |dst_vec| for (dst_vec, src.Vector) |*d_lane, *s_lane| {
const d_field = try getValuePrimitiveField(.Float, bits, d_lane);
const s_field = try getValuePrimitiveField(.Float, bits, s_lane);
const d_field = try Value.getPrimitiveField(.Float, bits, d_lane);
const s_field = try Value.getPrimitiveField(.Float, bits, s_lane);
d_field.* = s_field.* / sum;
},
else => return RuntimeError.InvalidSpirV,