This commit is contained in:
2026-04-29 23:44:55 +02:00
parent 046b1c8f9e
commit 11a59d8d7f
2 changed files with 189 additions and 62 deletions
+130 -30
View File
@@ -324,6 +324,7 @@ pub fn initRuntimeDispatcher() void {
runtime_dispatcher[@intFromEnum(spv.SpvOp.VectorTimesMatrix)] = MathEngine(.Float, .VectorTimesMatrix, false).op; // TODO
runtime_dispatcher[@intFromEnum(spv.SpvOp.VectorTimesScalar)] = MathEngine(.Float, .VectorTimesScalar, false).op;
runtime_dispatcher[@intFromEnum(spv.SpvOp.SMulExtended)] = opSMulExtended;
runtime_dispatcher[@intFromEnum(spv.SpvOp.UMulExtended)] = opUMulExtended;
runtime_dispatcher[@intFromEnum(spv.SpvOp.ImageRead)] = opImageRead;
runtime_dispatcher[@intFromEnum(spv.SpvOp.ImageWrite)] = opImageWrite;
// zig fmt: on
@@ -1781,58 +1782,147 @@ fn opConstantComposite(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) R
}
}
fn opSMulExtended(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void {
const result_type_id = try rt.it.next();
const id = try rt.it.next();
fn writeMulExtendedBits(comptime bits: u32, dst: *Value, lane_index: usize, value: Value.getPrimitiveFieldType(.UInt, bits)) RuntimeError!void {
switch (dst.*) {
.Int => |*i| {
if (i.bit_count != bits) return RuntimeError.InvalidSpirV;
if (i.is_signed) {
switch (bits) {
8 => i.value.sint8 = @bitCast(value),
16 => i.value.sint16 = @bitCast(value),
32 => i.value.sint32 = @bitCast(value),
64 => i.value.sint64 = @bitCast(value),
else => unreachable,
}
} else {
switch (bits) {
8 => i.value.uint8 = value,
16 => i.value.uint16 = value,
32 => i.value.uint32 = value,
64 => i.value.uint64 = value,
else => unreachable,
}
}
},
.Vector => |lanes| try writeMulExtendedBits(bits, &lanes[lane_index], 0, value),
.Vector2i32 => |*v| switch (lane_index) {
inline 0...1 => |i| if (bits == 32) {
v[i] = @bitCast(value);
} else {
return RuntimeError.InvalidSpirV;
},
else => return RuntimeError.InvalidSpirV,
},
.Vector3i32 => |*v| switch (lane_index) {
inline 0...2 => |i| if (bits == 32) {
v[i] = @bitCast(value);
} else {
return RuntimeError.InvalidSpirV;
},
else => return RuntimeError.InvalidSpirV,
},
.Vector4i32 => |*v| switch (lane_index) {
inline 0...3 => |i| if (bits == 32) {
v[i] = @bitCast(value);
} else {
return RuntimeError.InvalidSpirV;
},
else => return RuntimeError.InvalidSpirV,
},
.Vector2u32 => |*v| switch (lane_index) {
inline 0...1 => |i| if (bits == 32) {
v[i] = @bitCast(value);
} else {
return RuntimeError.InvalidSpirV;
},
else => return RuntimeError.InvalidSpirV,
},
.Vector3u32 => |*v| switch (lane_index) {
inline 0...2 => |i| if (bits == 32) {
v[i] = @bitCast(value);
} else {
return RuntimeError.InvalidSpirV;
},
else => return RuntimeError.InvalidSpirV,
},
.Vector4u32 => |*v| switch (lane_index) {
inline 0...3 => |i| if (bits == 32) {
v[i] = @bitCast(value);
} else {
return RuntimeError.InvalidSpirV;
},
else => return RuntimeError.InvalidSpirV,
},
else => return RuntimeError.InvalidSpirV,
}
}
fn opMulExtended(comptime is_signed: bool, rt: *Runtime) RuntimeError!void {
_ = try rt.it.next(); // Result Type
const result_id = try rt.it.next();
const lhs = try rt.results[try rt.it.next()].getValue();
const rhs = try rt.results[try rt.it.next()].getValue();
const dst = try rt.results[id].getValue();
const result_members = switch (dst.*) {
.Structure => |s| s.values,
const result = try rt.results[result_id].getValue();
const result_members = switch (result.*) {
.Structure => |*s| s.values,
else => return RuntimeError.InvalidSpirV,
};
if (result_members.len != 2) return RuntimeError.InvalidSpirV;
const lsb_dst = &result_members[0];
const msb_dst = &result_members[1];
const low_dst = &result_members[0];
const high_dst = &result_members[1];
const result_type = (try rt.results[result_type_id].getVariant()).Type;
const member_types = switch (result_type) {
.Structure => |s| s.members_type_word,
else => return RuntimeError.InvalidSpirV,
};
if (member_types.len != 2) return RuntimeError.InvalidSpirV;
const lane_count = try lhs.resolveLaneCount();
if (try rhs.resolveLaneCount() != lane_count) return RuntimeError.InvalidSpirV;
if (try low_dst.resolveLaneCount() != lane_count) return RuntimeError.InvalidSpirV;
if (try high_dst.resolveLaneCount() != lane_count) return RuntimeError.InvalidSpirV;
const value_type = (try rt.results[member_types[0]].getVariant()).Type;
const lane_count = try Result.resolveLaneCount(value_type);
const lane_bits = try Result.resolveLaneBitWidth(value_type, rt);
const lane_bits = try lhs.resolveLaneBitWidth();
if (try rhs.resolveLaneBitWidth() != lane_bits) return RuntimeError.InvalidSpirV;
if (try low_dst.resolveLaneBitWidth() != lane_bits) return RuntimeError.InvalidSpirV;
if (try high_dst.resolveLaneBitWidth() != lane_bits) return RuntimeError.InvalidSpirV;
switch (lane_bits) {
inline 8, 16, 32, 64 => |bits| {
//const SIntT = Value.getPrimitiveFieldType(.SInt, bits);
const UIntT = Value.getPrimitiveFieldType(.UInt, bits);
const WideSIntT = std.meta.Int(.signed, bits * 2);
const WideUIntT = std.meta.Int(.unsigned, bits * 2);
for (0..lane_count) |lane_index| {
const l = try Value.readLane(.SInt, bits, lhs, lane_index);
const r = try Value.readLane(.SInt, bits, rhs, lane_index);
const product_bits: WideUIntT = if (is_signed) blk: {
const SIntT = Value.getPrimitiveFieldType(.SInt, bits);
const WideSIntT = std.meta.Int(.signed, bits * 2);
const l: SIntT = try Value.readLane(.SInt, bits, lhs, lane_index);
const r: SIntT = try Value.readLane(.SInt, bits, rhs, lane_index);
const product: WideSIntT = @as(WideSIntT, l) * @as(WideSIntT, r);
break :blk @bitCast(product);
} else blk: {
const l: UIntT = try Value.readLane(.UInt, bits, lhs, lane_index);
const r: UIntT = try Value.readLane(.UInt, bits, rhs, lane_index);
break :blk @as(WideUIntT, l) * @as(WideUIntT, r);
};
const product: WideSIntT = @as(WideSIntT, l) * @as(WideSIntT, r);
const product_bits: WideUIntT = @bitCast(product);
const low: UIntT = @truncate(product_bits);
const high: UIntT = @truncate(product_bits >> bits);
const lsb_bits: UIntT = @truncate(product_bits);
const msb_bits: UIntT = @truncate(product_bits >> bits);
std.debug.print("test 0x{X} - 0x{X}\n", .{ high, low });
try Value.writeLane(.SInt, bits, lsb_dst, lane_index, @bitCast(lsb_bits));
try Value.writeLane(.SInt, bits, msb_dst, lane_index, @bitCast(msb_bits));
try writeMulExtendedBits(bits, low_dst, lane_index, low);
try writeMulExtendedBits(bits, high_dst, lane_index, high);
}
},
else => return RuntimeError.InvalidSpirV,
}
}
fn opUMulExtended(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void {
try opMulExtended(false, rt);
}
fn opSMulExtended(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void {
try opMulExtended(true, rt);
}
fn opSpecConstant(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void {
const location = rt.it.emitSourceLocation();
_ = rt.it.skip();
@@ -2258,7 +2348,6 @@ fn opImageRead(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void
}
}
fn opImageWrite(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void {
const image = &rt.results[try rt.it.next()];
const coordinate = try rt.results[try rt.it.next()].getValue();
@@ -2396,8 +2485,19 @@ fn opImageWrite(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!voi
const z = helpers.readCoordLane(coordinate, 2) catch 0;
switch (texel.*) {
.Float, .Vector4f32, .Vector3f32, .Vector2f32 => try rt.image_api.writeImageFloat4(driver_image, x, y, z, try helpers.readFloatTexel(texel)),
.Int, .Vector4i32, .Vector3i32, .Vector2i32, .Vector4u32, .Vector3u32, .Vector2u32 => try rt.image_api.writeImageInt4(driver_image, x, y, z, try helpers.readIntTexel(texel)),
.Float,
.Vector4f32,
.Vector3f32,
.Vector2f32,
=> try rt.image_api.writeImageFloat4(driver_image, x, y, z, try helpers.readFloatTexel(texel)),
.Int,
.Vector4i32,
.Vector3i32,
.Vector2i32,
.Vector4u32,
.Vector3u32,
.Vector2u32,
=> try rt.image_api.writeImageInt4(driver_image, x, y, z, try helpers.readIntTexel(texel)),
.Vector => |lanes| {
if (lanes.len == 0) return RuntimeError.InvalidSpirV;
switch (lanes[0]) {