diff --git a/src/Value.zig b/src/Value.zig index bffacd7..685e7a4 100644 --- a/src/Value.zig +++ b/src/Value.zig @@ -291,73 +291,100 @@ pub const Value = union(Type) { .Int => |i| { switch (i.bit_count) { 8 => output[0] = @bitCast(i.value.uint8), - 16 => std.mem.copyForwards(u8, output[0..], std.mem.asBytes(&i.value.uint16)), - 32 => std.mem.copyForwards(u8, output[0..], std.mem.asBytes(&i.value.uint32)), - 64 => std.mem.copyForwards(u8, output[0..], std.mem.asBytes(&i.value.uint64)), + 16 => @memcpy(output[0..2], std.mem.asBytes(&i.value.uint16)), + 32 => @memcpy(output[0..4], std.mem.asBytes(&i.value.uint32)), + 64 => @memcpy(output[0..8], std.mem.asBytes(&i.value.uint64)), else => return RuntimeError.InvalidValueType, } return @divExact(i.bit_count, 8); }, .Float => |f| { switch (f.bit_count) { - 16 => std.mem.copyForwards(u8, output[0..], std.mem.asBytes(&f.value.float16)), - 32 => std.mem.copyForwards(u8, output[0..], std.mem.asBytes(&f.value.float32)), - 64 => std.mem.copyForwards(u8, output[0..], std.mem.asBytes(&f.value.float64)), + 16 => @memcpy(output[0..2], std.mem.asBytes(&f.value.float16)), + 32 => @memcpy(output[0..4], std.mem.asBytes(&f.value.float32)), + 64 => @memcpy(output[0..8], std.mem.asBytes(&f.value.float64)), else => return RuntimeError.InvalidValueType, } return @divExact(f.bit_count, 8); }, .Vector4f32 => |vec| { inline for (0..4) |i| { - std.mem.copyForwards(u8, output[(i * 4)..], std.mem.asBytes(&vec[i])); + const start = i * 4; + const end = (i + 1) * 4; + if (start >= output.len or end > output.len) return RuntimeError.OutOfBounds; + @memcpy(output[start..end], std.mem.asBytes(&vec[i])); } return 4 * 4; }, .Vector3f32 => |vec| { inline for (0..3) |i| { - std.mem.copyForwards(u8, output[(i * 4)..], std.mem.asBytes(&vec[i])); + const start = i * 4; + const end = (i + 1) * 4; + if (start >= output.len or end > output.len) return RuntimeError.OutOfBounds; + @memcpy(output[start..end], std.mem.asBytes(&vec[i])); } return 3 * 4; }, .Vector2f32 => |vec| { inline for (0..2) |i| { - std.mem.copyForwards(u8, output[(i * 4)..], std.mem.asBytes(&vec[i])); + const start = i * 4; + const end = (i + 1) * 4; + if (start >= output.len or end > output.len) return RuntimeError.OutOfBounds; + @memcpy(output[start..end], std.mem.asBytes(&vec[i])); } return 2 * 4; }, .Vector4i32 => |vec| { inline for (0..4) |i| { - std.mem.copyForwards(u8, output[(i * 4)..], std.mem.asBytes(&vec[i])); + const start = i * 4; + const end = (i + 1) * 4; + if (start >= output.len or end > output.len) return RuntimeError.OutOfBounds; + @memcpy(output[start..end], std.mem.asBytes(&vec[i])); } return 4 * 4; }, .Vector3i32 => |vec| { inline for (0..3) |i| { - std.mem.copyForwards(u8, output[(i * 4)..], std.mem.asBytes(&vec[i])); + const start = i * 4; + const end = (i + 1) * 4; + if (start >= output.len or end > output.len) return RuntimeError.OutOfBounds; + @memcpy(output[start..end], std.mem.asBytes(&vec[i])); } return 3 * 4; }, .Vector2i32 => |vec| { inline for (0..2) |i| { - std.mem.copyForwards(u8, output[(i * 4)..], std.mem.asBytes(&vec[i])); + const start = i * 4; + const end = (i + 1) * 4; + if (start >= output.len or end > output.len) return RuntimeError.OutOfBounds; + @memcpy(output[start..end], std.mem.asBytes(&vec[i])); } return 2 * 4; }, .Vector4u32 => |vec| { inline for (0..4) |i| { - std.mem.copyForwards(u8, output[(i * 4)..], std.mem.asBytes(&vec[i])); + const start = i * 4; + const end = (i + 1) * 4; + if (start >= output.len or end > output.len) return RuntimeError.OutOfBounds; + @memcpy(output[start..end], std.mem.asBytes(&vec[i])); } return 4 * 4; }, .Vector3u32 => |vec| { inline for (0..3) |i| { - std.mem.copyForwards(u8, output[(i * 4)..], std.mem.asBytes(&vec[i])); + const start = i * 4; + const end = (i + 1) * 4; + if (start >= output.len or end > output.len) return RuntimeError.OutOfBounds; + @memcpy(output[start..end], std.mem.asBytes(&vec[i])); } return 3 * 4; }, .Vector2u32 => |vec| { inline for (0..2) |i| { - std.mem.copyForwards(u8, output[(i * 4)..], std.mem.asBytes(&vec[i])); + const start = i * 4; + const end = (i + 1) * 4; + if (start >= output.len or end > output.len) return RuntimeError.OutOfBounds; + @memcpy(output[start..end], std.mem.asBytes(&vec[i])); } return 2 * 4; }, @@ -408,18 +435,18 @@ pub const Value = union(Type) { .Int => |*i| { switch (i.bit_count) { 8 => i.value.uint8 = @bitCast(input[0]), - 16 => std.mem.copyForwards(u8, std.mem.asBytes(&i.value.uint16), input[0..2]), - 32 => std.mem.copyForwards(u8, std.mem.asBytes(&i.value.uint32), input[0..4]), - 64 => std.mem.copyForwards(u8, std.mem.asBytes(&i.value.uint64), input[0..8]), + 16 => @memcpy(std.mem.asBytes(&i.value.uint16), input[0..2]), + 32 => @memcpy(std.mem.asBytes(&i.value.uint32), input[0..4]), + 64 => @memcpy(std.mem.asBytes(&i.value.uint64), input[0..8]), else => return RuntimeError.InvalidValueType, } return @divExact(i.bit_count, 8); }, .Float => |*f| { switch (f.bit_count) { - 16 => std.mem.copyForwards(u8, std.mem.asBytes(&f.value.float16), input[0..2]), - 32 => std.mem.copyForwards(u8, std.mem.asBytes(&f.value.float32), input[0..4]), - 64 => std.mem.copyForwards(u8, std.mem.asBytes(&f.value.float64), input[0..8]), + 16 => @memcpy(std.mem.asBytes(&f.value.float16), input[0..2]), + 32 => @memcpy(std.mem.asBytes(&f.value.float32), input[0..4]), + 64 => @memcpy(std.mem.asBytes(&f.value.float64), input[0..8]), else => return RuntimeError.InvalidValueType, } return @divExact(f.bit_count, 8); @@ -429,7 +456,7 @@ pub const Value = union(Type) { const start = i * 4; const end = (i + 1) * 4; if (start >= input.len or end > input.len) return RuntimeError.OutOfBounds; - std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); + @memcpy(std.mem.asBytes(&vec[i]), input[start..end]); } return 4 * 4; }, @@ -438,7 +465,7 @@ pub const Value = union(Type) { const start = i * 4; const end = (i + 1) * 4; if (start >= input.len or end > input.len) return RuntimeError.OutOfBounds; - std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); + @memcpy(std.mem.asBytes(&vec[i]), input[start..end]); } return 3 * 4; }, @@ -447,7 +474,7 @@ pub const Value = union(Type) { const start = i * 4; const end = (i + 1) * 4; if (start >= input.len or end > input.len) return RuntimeError.OutOfBounds; - std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); + @memcpy(std.mem.asBytes(&vec[i]), input[start..end]); } return 2 * 4; }, @@ -456,7 +483,7 @@ pub const Value = union(Type) { const start = i * 4; const end = (i + 1) * 4; if (start >= input.len or end > input.len) return RuntimeError.OutOfBounds; - std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); + @memcpy(std.mem.asBytes(&vec[i]), input[start..end]); } return 4 * 4; }, @@ -465,7 +492,7 @@ pub const Value = union(Type) { const start = i * 4; const end = (i + 1) * 4; if (start >= input.len or end > input.len) return RuntimeError.OutOfBounds; - std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); + @memcpy(std.mem.asBytes(&vec[i]), input[start..end]); } return 3 * 4; }, @@ -474,7 +501,7 @@ pub const Value = union(Type) { const start = i * 4; const end = (i + 1) * 4; if (start >= input.len or end > input.len) return RuntimeError.OutOfBounds; - std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); + @memcpy(std.mem.asBytes(&vec[i]), input[start..end]); } return 2 * 4; }, @@ -483,7 +510,7 @@ pub const Value = union(Type) { const start = i * 4; const end = (i + 1) * 4; if (start >= input.len or end > input.len) return RuntimeError.OutOfBounds; - std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); + @memcpy(std.mem.asBytes(&vec[i]), input[start..end]); } return 4 * 4; }, @@ -492,7 +519,7 @@ pub const Value = union(Type) { const start = i * 4; const end = (i + 1) * 4; if (start >= input.len or end > input.len) return RuntimeError.OutOfBounds; - std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); + @memcpy(std.mem.asBytes(&vec[i]), input[start..end]); } return 3 * 4; }, @@ -501,7 +528,7 @@ pub const Value = union(Type) { const start = i * 4; const end = (i + 1) * 4; if (start >= input.len or end > input.len) return RuntimeError.OutOfBounds; - std.mem.copyForwards(u8, std.mem.asBytes(&vec[i]), input[start..end]); + @memcpy(std.mem.asBytes(&vec[i]), input[start..end]); } return 2 * 4; }, @@ -790,8 +817,8 @@ pub const Value = union(Type) { pub fn resolveLaneBitWidth(self: *const Self) RuntimeError!SpvWord { return switch (self.*) { .Bool => 8, - .Float => |f| f.bit_length, - .Int => |i| i.bit_length, + .Float => |f| @intCast(f.bit_count), + .Int => |i| @intCast(i.bit_count), .Vector => |v| v[0].resolveLaneBitWidth(), .Vector4f32, .Vector3f32, diff --git a/src/opcodes.zig b/src/opcodes.zig index 22ae025..ab8cd98 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -324,6 +324,7 @@ pub fn initRuntimeDispatcher() void { runtime_dispatcher[@intFromEnum(spv.SpvOp.VectorTimesMatrix)] = MathEngine(.Float, .VectorTimesMatrix, false).op; // TODO runtime_dispatcher[@intFromEnum(spv.SpvOp.VectorTimesScalar)] = MathEngine(.Float, .VectorTimesScalar, false).op; runtime_dispatcher[@intFromEnum(spv.SpvOp.SMulExtended)] = opSMulExtended; + runtime_dispatcher[@intFromEnum(spv.SpvOp.UMulExtended)] = opUMulExtended; runtime_dispatcher[@intFromEnum(spv.SpvOp.ImageRead)] = opImageRead; runtime_dispatcher[@intFromEnum(spv.SpvOp.ImageWrite)] = opImageWrite; // zig fmt: on @@ -1781,58 +1782,147 @@ fn opConstantComposite(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) R } } -fn opSMulExtended(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { - const result_type_id = try rt.it.next(); - const id = try rt.it.next(); +fn writeMulExtendedBits(comptime bits: u32, dst: *Value, lane_index: usize, value: Value.getPrimitiveFieldType(.UInt, bits)) RuntimeError!void { + switch (dst.*) { + .Int => |*i| { + if (i.bit_count != bits) return RuntimeError.InvalidSpirV; + if (i.is_signed) { + switch (bits) { + 8 => i.value.sint8 = @bitCast(value), + 16 => i.value.sint16 = @bitCast(value), + 32 => i.value.sint32 = @bitCast(value), + 64 => i.value.sint64 = @bitCast(value), + else => unreachable, + } + } else { + switch (bits) { + 8 => i.value.uint8 = value, + 16 => i.value.uint16 = value, + 32 => i.value.uint32 = value, + 64 => i.value.uint64 = value, + else => unreachable, + } + } + }, + .Vector => |lanes| try writeMulExtendedBits(bits, &lanes[lane_index], 0, value), + .Vector2i32 => |*v| switch (lane_index) { + inline 0...1 => |i| if (bits == 32) { + v[i] = @bitCast(value); + } else { + return RuntimeError.InvalidSpirV; + }, + else => return RuntimeError.InvalidSpirV, + }, + .Vector3i32 => |*v| switch (lane_index) { + inline 0...2 => |i| if (bits == 32) { + v[i] = @bitCast(value); + } else { + return RuntimeError.InvalidSpirV; + }, + else => return RuntimeError.InvalidSpirV, + }, + .Vector4i32 => |*v| switch (lane_index) { + inline 0...3 => |i| if (bits == 32) { + v[i] = @bitCast(value); + } else { + return RuntimeError.InvalidSpirV; + }, + else => return RuntimeError.InvalidSpirV, + }, + .Vector2u32 => |*v| switch (lane_index) { + inline 0...1 => |i| if (bits == 32) { + v[i] = @bitCast(value); + } else { + return RuntimeError.InvalidSpirV; + }, + else => return RuntimeError.InvalidSpirV, + }, + .Vector3u32 => |*v| switch (lane_index) { + inline 0...2 => |i| if (bits == 32) { + v[i] = @bitCast(value); + } else { + return RuntimeError.InvalidSpirV; + }, + else => return RuntimeError.InvalidSpirV, + }, + .Vector4u32 => |*v| switch (lane_index) { + inline 0...3 => |i| if (bits == 32) { + v[i] = @bitCast(value); + } else { + return RuntimeError.InvalidSpirV; + }, + else => return RuntimeError.InvalidSpirV, + }, + else => return RuntimeError.InvalidSpirV, + } +} + +fn opMulExtended(comptime is_signed: bool, rt: *Runtime) RuntimeError!void { + _ = try rt.it.next(); // Result Type + const result_id = try rt.it.next(); const lhs = try rt.results[try rt.it.next()].getValue(); const rhs = try rt.results[try rt.it.next()].getValue(); - const dst = try rt.results[id].getValue(); - const result_members = switch (dst.*) { - .Structure => |s| s.values, + const result = try rt.results[result_id].getValue(); + const result_members = switch (result.*) { + .Structure => |*s| s.values, else => return RuntimeError.InvalidSpirV, }; if (result_members.len != 2) return RuntimeError.InvalidSpirV; - const lsb_dst = &result_members[0]; - const msb_dst = &result_members[1]; + const low_dst = &result_members[0]; + const high_dst = &result_members[1]; - const result_type = (try rt.results[result_type_id].getVariant()).Type; - const member_types = switch (result_type) { - .Structure => |s| s.members_type_word, - else => return RuntimeError.InvalidSpirV, - }; - if (member_types.len != 2) return RuntimeError.InvalidSpirV; + const lane_count = try lhs.resolveLaneCount(); + if (try rhs.resolveLaneCount() != lane_count) return RuntimeError.InvalidSpirV; + if (try low_dst.resolveLaneCount() != lane_count) return RuntimeError.InvalidSpirV; + if (try high_dst.resolveLaneCount() != lane_count) return RuntimeError.InvalidSpirV; - const value_type = (try rt.results[member_types[0]].getVariant()).Type; - const lane_count = try Result.resolveLaneCount(value_type); - const lane_bits = try Result.resolveLaneBitWidth(value_type, rt); + const lane_bits = try lhs.resolveLaneBitWidth(); + if (try rhs.resolveLaneBitWidth() != lane_bits) return RuntimeError.InvalidSpirV; + if (try low_dst.resolveLaneBitWidth() != lane_bits) return RuntimeError.InvalidSpirV; + if (try high_dst.resolveLaneBitWidth() != lane_bits) return RuntimeError.InvalidSpirV; switch (lane_bits) { inline 8, 16, 32, 64 => |bits| { - //const SIntT = Value.getPrimitiveFieldType(.SInt, bits); const UIntT = Value.getPrimitiveFieldType(.UInt, bits); - const WideSIntT = std.meta.Int(.signed, bits * 2); const WideUIntT = std.meta.Int(.unsigned, bits * 2); for (0..lane_count) |lane_index| { - const l = try Value.readLane(.SInt, bits, lhs, lane_index); - const r = try Value.readLane(.SInt, bits, rhs, lane_index); + const product_bits: WideUIntT = if (is_signed) blk: { + const SIntT = Value.getPrimitiveFieldType(.SInt, bits); + const WideSIntT = std.meta.Int(.signed, bits * 2); + const l: SIntT = try Value.readLane(.SInt, bits, lhs, lane_index); + const r: SIntT = try Value.readLane(.SInt, bits, rhs, lane_index); + const product: WideSIntT = @as(WideSIntT, l) * @as(WideSIntT, r); + break :blk @bitCast(product); + } else blk: { + const l: UIntT = try Value.readLane(.UInt, bits, lhs, lane_index); + const r: UIntT = try Value.readLane(.UInt, bits, rhs, lane_index); + break :blk @as(WideUIntT, l) * @as(WideUIntT, r); + }; - const product: WideSIntT = @as(WideSIntT, l) * @as(WideSIntT, r); - const product_bits: WideUIntT = @bitCast(product); + const low: UIntT = @truncate(product_bits); + const high: UIntT = @truncate(product_bits >> bits); - const lsb_bits: UIntT = @truncate(product_bits); - const msb_bits: UIntT = @truncate(product_bits >> bits); + std.debug.print("test 0x{X} - 0x{X}\n", .{ high, low }); - try Value.writeLane(.SInt, bits, lsb_dst, lane_index, @bitCast(lsb_bits)); - try Value.writeLane(.SInt, bits, msb_dst, lane_index, @bitCast(msb_bits)); + try writeMulExtendedBits(bits, low_dst, lane_index, low); + try writeMulExtendedBits(bits, high_dst, lane_index, high); } }, else => return RuntimeError.InvalidSpirV, } } +fn opUMulExtended(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + try opMulExtended(false, rt); +} + +fn opSMulExtended(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + try opMulExtended(true, rt); +} + fn opSpecConstant(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void { const location = rt.it.emitSourceLocation(); _ = rt.it.skip(); @@ -2258,7 +2348,6 @@ fn opImageRead(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void } } - fn opImageWrite(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { const image = &rt.results[try rt.it.next()]; const coordinate = try rt.results[try rt.it.next()].getValue(); @@ -2396,8 +2485,19 @@ fn opImageWrite(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!voi const z = helpers.readCoordLane(coordinate, 2) catch 0; switch (texel.*) { - .Float, .Vector4f32, .Vector3f32, .Vector2f32 => try rt.image_api.writeImageFloat4(driver_image, x, y, z, try helpers.readFloatTexel(texel)), - .Int, .Vector4i32, .Vector3i32, .Vector2i32, .Vector4u32, .Vector3u32, .Vector2u32 => try rt.image_api.writeImageInt4(driver_image, x, y, z, try helpers.readIntTexel(texel)), + .Float, + .Vector4f32, + .Vector3f32, + .Vector2f32, + => try rt.image_api.writeImageFloat4(driver_image, x, y, z, try helpers.readFloatTexel(texel)), + .Int, + .Vector4i32, + .Vector3i32, + .Vector2i32, + .Vector4u32, + .Vector3u32, + .Vector2u32, + => try rt.image_api.writeImageInt4(driver_image, x, y, z, try helpers.readIntTexel(texel)), .Vector => |lanes| { if (lanes.len == 0) return RuntimeError.InvalidSpirV; switch (lanes[0]) {