fixing vector times scalar
This commit is contained in:
@@ -31,7 +31,7 @@ pub fn build(b: *std.Build) void {
|
|||||||
|
|
||||||
addSandbox(b, target, optimize, use_llvm, spv_mod, &install_spv_lib.step);
|
addSandbox(b, target, optimize, use_llvm, spv_mod, &install_spv_lib.step);
|
||||||
addExample(b, target, optimize, use_llvm, spv_mod, &install_spv_lib.step);
|
addExample(b, target, optimize, use_llvm, spv_mod, &install_spv_lib.step);
|
||||||
addZigTests(b, target, optimize, spv_mod, zmath);
|
addZigTests(b, target, optimize, use_llvm, spv_mod, zmath);
|
||||||
addCffi(b, target, optimize, use_llvm, spv_mod);
|
addCffi(b, target, optimize, use_llvm, spv_mod);
|
||||||
addDocs(b, spv_mod);
|
addDocs(b, spv_mod);
|
||||||
}
|
}
|
||||||
@@ -146,6 +146,7 @@ fn addZigTests(
|
|||||||
b: *std.Build,
|
b: *std.Build,
|
||||||
target: std.Build.ResolvedTarget,
|
target: std.Build.ResolvedTarget,
|
||||||
optimize: std.builtin.OptimizeMode,
|
optimize: std.builtin.OptimizeMode,
|
||||||
|
use_llvm: bool,
|
||||||
spv_mod: *std.Build.Module,
|
spv_mod: *std.Build.Module,
|
||||||
zmath: *std.Build.Dependency,
|
zmath: *std.Build.Dependency,
|
||||||
) void {
|
) void {
|
||||||
@@ -172,6 +173,7 @@ fn addZigTests(
|
|||||||
.path = b.path("test/test_runner.zig"),
|
.path = b.path("test/test_runner.zig"),
|
||||||
.mode = .simple,
|
.mode = .simple,
|
||||||
},
|
},
|
||||||
|
.use_llvm = use_llvm,
|
||||||
});
|
});
|
||||||
|
|
||||||
const run_tests = b.addRunArtifact(tests);
|
const run_tests = b.addRunArtifact(tests);
|
||||||
|
|||||||
+15
-7
@@ -933,7 +933,7 @@ fn MathEngine(comptime T: PrimitiveType, comptime Op: MathOp, comptime IsAtomic:
|
|||||||
fn applyScalar(bit_count: SpvWord, d: *Value, l: *Value, r: *Value) RuntimeError!void {
|
fn applyScalar(bit_count: SpvWord, d: *Value, l: *Value, r: *Value) RuntimeError!void {
|
||||||
switch (bit_count) {
|
switch (bit_count) {
|
||||||
inline 8, 16, 32, 64 => |bits| {
|
inline 8, 16, 32, 64 => |bits| {
|
||||||
if (bits == 8 and T == .Float) return RuntimeError.InvalidSpirV;
|
if (bits == 8 and T == .Float) return RuntimeError.UnsupportedSpirV;
|
||||||
|
|
||||||
const ScalarT = Value.getPrimitiveFieldType(T, bits);
|
const ScalarT = Value.getPrimitiveFieldType(T, bits);
|
||||||
const d_field = try Value.getPrimitiveField(T, bits, d);
|
const d_field = try Value.getPrimitiveField(T, bits, d);
|
||||||
@@ -941,13 +941,18 @@ fn MathEngine(comptime T: PrimitiveType, comptime Op: MathOp, comptime IsAtomic:
|
|||||||
const r_field = try Value.getPrimitiveField(T, bits, r);
|
const r_field = try Value.getPrimitiveField(T, bits, r);
|
||||||
d_field.* = try operation(ScalarT, l_field.*, r_field.*);
|
d_field.* = try operation(ScalarT, l_field.*, r_field.*);
|
||||||
},
|
},
|
||||||
else => return RuntimeError.InvalidSpirV,
|
else => return RuntimeError.UnsupportedSpirV,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fn applyVectorTimesScalarF32(d: []Value, l: []const Value, r: f32) void {
|
inline fn applyVectorTimesScalarFloat(comptime bit_count: SpvWord, d: []Value, l: []const Value, r_v: *const Value) RuntimeError!void {
|
||||||
for (d, l) |*d_v, l_v| {
|
for (d, l) |*d_v, l_v| {
|
||||||
d_v.Float.value.float32 = l_v.Float.value.float32 * r;
|
switch (bit_count) {
|
||||||
|
16 => d_v.Float.value.float16 = l_v.Float.value.float16 * r_v.Float.value.float16,
|
||||||
|
32 => d_v.Float.value.float32 = l_v.Float.value.float32 * r_v.Float.value.float32,
|
||||||
|
64 => d_v.Float.value.float64 = l_v.Float.value.float64 * r_v.Float.value.float64,
|
||||||
|
else => return RuntimeError.UnsupportedSpirV,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -963,7 +968,7 @@ fn MathEngine(comptime T: PrimitiveType, comptime Op: MathOp, comptime IsAtomic:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fn applySIMDVectorf32(comptime N: usize, d: *@Vector(N, f32), l: *const @Vector(N, f32), r: *const Value) RuntimeError!void {
|
fn applySIMDVectorf32(comptime N: usize, d: *@Vector(N, f32), l: *const @Vector(N, f32), r: *const Value) RuntimeError!void {
|
||||||
switch (Op) {
|
switch (Op) {
|
||||||
.VectorTimesScalar => applyVectorSIMDTimesScalarF32(N, d, l, r.Float.value.float32),
|
.VectorTimesScalar => applyVectorSIMDTimesScalarF32(N, d, l, r.Float.value.float32),
|
||||||
else => {
|
else => {
|
||||||
@@ -983,7 +988,10 @@ fn MathEngine(comptime T: PrimitiveType, comptime Op: MathOp, comptime IsAtomic:
|
|||||||
.Int, .Float => try operator.applyScalar(lane_bits, dst, lhs, rhs),
|
.Int, .Float => try operator.applyScalar(lane_bits, dst, lhs, rhs),
|
||||||
|
|
||||||
.Vector => |dst_vec| switch (Op) {
|
.Vector => |dst_vec| switch (Op) {
|
||||||
.VectorTimesScalar => operator.applyVectorTimesScalarF32(dst_vec, lhs.Vector, rhs.Float.value.float32),
|
.VectorTimesScalar => switch (lane_bits) {
|
||||||
|
inline 16, 32, 64 => |bits_count| try operator.applyVectorTimesScalarFloat(bits_count, dst_vec, lhs.Vector, rhs),
|
||||||
|
else => return RuntimeError.UnsupportedSpirV,
|
||||||
|
},
|
||||||
else => for (dst_vec, lhs.Vector, rhs.Vector) |*d_lane, *l_lane, *r_lane| {
|
else => for (dst_vec, lhs.Vector, rhs.Vector) |*d_lane, *l_lane, *r_lane| {
|
||||||
try operator.applyScalar(lane_bits, d_lane, l_lane, r_lane);
|
try operator.applyScalar(lane_bits, d_lane, l_lane, r_lane);
|
||||||
},
|
},
|
||||||
@@ -1315,7 +1323,7 @@ fn opAccessChain(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime
|
|||||||
if (a.indexes.len != index_count)
|
if (a.indexes.len != index_count)
|
||||||
return RuntimeError.InvalidSpirV;
|
return RuntimeError.InvalidSpirV;
|
||||||
try a.value.flushPtr(allocator);
|
try a.value.flushPtr(allocator);
|
||||||
//a.value.deinit(allocator);
|
a.value.deinit(allocator);
|
||||||
break :blk .{ a.indexes, false };
|
break :blk .{ a.indexes, false };
|
||||||
},
|
},
|
||||||
else => {},
|
else => {},
|
||||||
|
|||||||
@@ -154,3 +154,70 @@ test "Maths vectors" {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests all mathematical operation on vec2/3/4 with scalars with all NZSL supported primitive types
|
||||||
|
test "Maths vectors with scalars" {
|
||||||
|
const allocator = std.testing.allocator;
|
||||||
|
const types = [_]type{ f32, f64, i32, u32 };
|
||||||
|
var operations = std.EnumMap(Operations, u8).init(.{
|
||||||
|
.Mul = '*',
|
||||||
|
.Div = '/',
|
||||||
|
.Mod = '%',
|
||||||
|
});
|
||||||
|
|
||||||
|
var it = operations.iterator();
|
||||||
|
while (it.next()) |op| {
|
||||||
|
inline for (2..5) |L| {
|
||||||
|
inline for (types) |T| {
|
||||||
|
const base_color: case.Vec(L, T) = .{ .val = case.random(@Vector(L, T)) };
|
||||||
|
const ratio = case.random(T);
|
||||||
|
const splat_ratio = @as(@Vector(L, T), @splat(ratio));
|
||||||
|
const expected = switch (op.key) {
|
||||||
|
.Mul => if (@typeInfo(T) == .int) @mulWithOverflow(base_color.val, splat_ratio)[0] else base_color.val * splat_ratio,
|
||||||
|
.Div => if (@typeInfo(T) == .int) @divTrunc(base_color.val, splat_ratio) else base_color.val / splat_ratio,
|
||||||
|
.Mod => @mod(base_color.val, splat_ratio),
|
||||||
|
else => unreachable,
|
||||||
|
};
|
||||||
|
|
||||||
|
const shader = try std.fmt.allocPrint(
|
||||||
|
allocator,
|
||||||
|
\\ [nzsl_version("1.1")]
|
||||||
|
\\ [feature(float64)]
|
||||||
|
\\ module;
|
||||||
|
\\
|
||||||
|
\\ struct FragOut
|
||||||
|
\\ {{
|
||||||
|
\\ [location(0)] color: vec{d}[{s}]
|
||||||
|
\\ }}
|
||||||
|
\\
|
||||||
|
\\ [entry(frag)]
|
||||||
|
\\ fn main() -> FragOut
|
||||||
|
\\ {{
|
||||||
|
\\ let output: FragOut;
|
||||||
|
\\ output.color = vec{d}[{s}]({f}) {c} {d};
|
||||||
|
\\ return output;
|
||||||
|
\\ }}
|
||||||
|
,
|
||||||
|
.{
|
||||||
|
L,
|
||||||
|
@typeName(T),
|
||||||
|
L,
|
||||||
|
@typeName(T),
|
||||||
|
base_color,
|
||||||
|
op.value.*,
|
||||||
|
ratio,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
defer allocator.free(shader);
|
||||||
|
const code = try compileNzsl(allocator, shader);
|
||||||
|
defer allocator.free(code);
|
||||||
|
try case.expect(.{
|
||||||
|
.source = code,
|
||||||
|
.expected_outputs = &.{
|
||||||
|
std.mem.asBytes(&@as([L]T, expected)),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user