diff --git a/build.zig b/build.zig index e6057c5..b5ea803 100644 --- a/build.zig +++ b/build.zig @@ -4,7 +4,7 @@ pub fn build(b: *std.Build) void { const target = b.standardTargetOptions(.{}); const optimize = b.standardOptimizeOption(.{}); - const use_llvm = b.option(bool, "use-llvm", "use llvm") orelse false; + const use_llvm = b.option(bool, "use-llvm", "use llvm") orelse (b.release_mode != .off); const mod = b.createModule(.{ .root_source_file = b.path("src/lib.zig"), diff --git a/example/main.zig b/example/main.zig index dd105e0..88723e6 100644 --- a/example/main.zig +++ b/example/main.zig @@ -4,91 +4,120 @@ const spv = @import("spv"); const shader_source = @embedFile("shader.spv"); -const screen_width = 640; -const screen_height = 480; +const screen_width = 1250; +const screen_height = 720; pub fn main() !void { { //var gpa: std.heap.DebugAllocator(.{}) = .init; //defer _ = gpa.deinit(); - var gpa = std.heap.ArenaAllocator.init(std.heap.page_allocator); - defer gpa.deinit(); - defer sdl3.shutdown(); - const init_flags = sdl3.InitFlags{ .video = true }; + const init_flags = sdl3.InitFlags{ .video = true, .events = true }; try sdl3.init(init_flags); defer sdl3.quit(init_flags); const window = try sdl3.video.Window.init("Hello triangle", screen_width, screen_height, .{}); defer window.deinit(); - const allocator = gpa.allocator(); + const surface = try window.getSurface(); + + const allocator = std.heap.smp_allocator; var module = try spv.Module.init(allocator, @ptrCast(@alignCast(shader_source)), .{}); defer module.deinit(allocator); - const surface = try window.getSurface(); - try surface.clear(.{ .r = 0.0, .g = 0.0, .b = 0.0, .a = 0.0 }); - - { - try surface.lock(); - defer surface.unlock(); - - var pixel_map: [*]u32 = @as([*]u32, @ptrCast(@alignCast((surface.getPixels() orelse return).ptr))); - - const margin_x = @divTrunc(screen_width, 3); - const margin_y = @divTrunc(screen_height, 3); - const top_y = margin_y; - const bottom_y = (screen_height - 1) - margin_y; - const center_x = @divTrunc(screen_width, 2); - const tri_h = bottom_y - top_y; - const max_half_w = @divTrunc(screen_width, 2) - margin_x; - - var timer = try std.time.Timer.start(); - defer { - const ns = timer.lap(); - std.log.info("Took {d:.3}s to render", .{@as(f32, @floatFromInt(ns)) / std.time.ns_per_s}); - } - - for (top_y..bottom_y) |y| { - const t: f32 = @as(f32, @floatFromInt(y - top_y)) / @as(f32, @floatFromInt(tri_h)); - const half_w: usize = @intFromFloat((t * @as(f32, @floatFromInt(max_half_w))) + 0.5); - const x0 = std.math.clamp(center_x - half_w, 0, screen_width - 1); - const x1 = std.math.clamp(center_x + half_w, 0, screen_width - 1); - - for (x0..x1) |x| { - var rt = try spv.Runtime.init(allocator, &module); - defer rt.deinit(allocator); - - var output: [4]f32 = undefined; - - const entry = try rt.getEntryPointByName("main"); - const color = try rt.getResultByName("color"); - const dim = try rt.getResultByName("dim"); - const pos = try rt.getResultByName("pos"); - - try rt.writeInput(f32, &.{ @floatFromInt(x1 - x0), @floatFromInt(bottom_y - top_y) }, dim); - try rt.writeInput(f32, &.{ @floatFromInt(x), @floatFromInt(y) }, pos); - - try rt.callEntryPoint(allocator, entry); - try rt.readOutput(f32, output[0..], color); - - const rgba = surface.mapRgba( - @truncate(@as(u32, @intFromFloat(output[0] * 255.0))), - @truncate(@as(u32, @intFromFloat(output[1] * 255.0))), - @truncate(@as(u32, @intFromFloat(output[2] * 255.0))), - @truncate(@as(u32, @intFromFloat(output[3] * 255.0))), - ); - - pixel_map[(y * surface.getWidth()) + x] = rgba.value; - } + var runner_cache: std.ArrayList(Runner) = try .initCapacity(allocator, screen_height); + defer { + for (runner_cache.items) |*runner| { + runner.rt.deinit(allocator); } + runner_cache.deinit(allocator); } - try window.updateSurface(); + for (0..screen_height) |_| { + (try runner_cache.addOne(allocator)).* = .{ + .allocator = allocator, + .surface = surface, + .rt = try spv.Runtime.init(allocator, &module), + }; + } - std.Thread.sleep(2_000_000_000); + var thread_pool: std.Thread.Pool = undefined; + try thread_pool.init(.{ + .allocator = allocator, + }); + + var quit = false; + while (!quit) { + try surface.clear(.{ .r = 0.0, .g = 0.0, .b = 0.0, .a = 0.0 }); + + while (sdl3.events.poll()) |event| + switch (event) { + .quit => quit = true, + .terminating => quit = true, + else => {}, + }; + + { + try surface.lock(); + defer surface.unlock(); + + const pixel_map: [*]u32 = @as([*]u32, @ptrCast(@alignCast((surface.getPixels() orelse return).ptr))); + + var timer = try std.time.Timer.start(); + defer { + const ns = timer.lap(); + const ms = @as(f32, @floatFromInt(ns)) / std.time.ns_per_s; + std.log.info("Took {d:.3}s - {d:.3}fps to render", .{ ms, 1.0 / ms }); + } + + var wait_group: std.Thread.WaitGroup = .{}; + for (0..screen_height) |y| { + const runner = &runner_cache.items[y]; + thread_pool.spawnWg(&wait_group, Runner.run, .{ runner, y, pixel_map }); + } + thread_pool.waitAndWork(&wait_group); + } + + try window.updateSurface(); + } } std.log.info("Successfully executed", .{}); } + +const Runner = struct { + const Self = @This(); + + allocator: std.mem.Allocator, + surface: sdl3.surface.Surface, + rt: spv.Runtime, + + fn run(self: *Self, y: usize, pixel_map: [*]u32) void { + const entry = self.rt.getEntryPointByName("main") catch |err| std.debug.panic("Catch error {s}", .{@errorName(err)}); + const color = self.rt.getResultByName("color") catch |err| std.debug.panic("Catch error {s}", .{@errorName(err)}); + const time = self.rt.getResultByName("time") catch |err| std.debug.panic("Catch error {s}", .{@errorName(err)}); + const pos = self.rt.getResultByName("pos") catch |err| std.debug.panic("Catch error {s}", .{@errorName(err)}); + const res = self.rt.getResultByName("res") catch |err| std.debug.panic("Catch error {s}", .{@errorName(err)}); + var output: [4]f32 = undefined; + + var rt = self.rt; // Copy to avoid pointer access of `self` at runtime. Okay as Runtime contains only pointers and trivially copyable fields + + for (0..screen_width) |x| { + rt.writeInput(f32, &.{@as(f32, @floatFromInt(std.time.milliTimestamp()))}, time) catch |err| std.debug.panic("Catch error {s}", .{@errorName(err)}); + rt.writeInput(f32, &.{ @floatFromInt(screen_width), @floatFromInt(screen_height) }, res) catch |err| std.debug.panic("Catch error {s}", .{@errorName(err)}); + rt.writeInput(f32, &.{ @floatFromInt(x), @floatFromInt(y) }, pos) catch |err| std.debug.panic("Catch error {s}", .{@errorName(err)}); + rt.callEntryPoint(self.allocator, entry) catch |err| std.debug.panic("Catch error {s}", .{@errorName(err)}); + rt.readOutput(f32, output[0..], color) catch |err| std.debug.panic("Catch error {s}", .{@errorName(err)}); + + const rgba = self.surface.mapRgba( + @truncate(@as(u32, @intFromFloat(output[0] * 255.0))), + @truncate(@as(u32, @intFromFloat(output[1] * 255.0))), + @truncate(@as(u32, @intFromFloat(output[2] * 255.0))), + @truncate(@as(u32, @intFromFloat(output[3] * 255.0))), + ); + + pixel_map[(y * self.surface.getWidth()) + x] = rgba.value; + } + } +}; diff --git a/example/shader.nzsl b/example/shader.nzsl index 7571d06..e918355 100644 --- a/example/shader.nzsl +++ b/example/shader.nzsl @@ -3,8 +3,9 @@ module; struct FragIn { - [location(0)] dim: vec2[f32], - [location(1)] pos: vec2[f32], + [location(0)] time: u32, + [location(1)] dim: vec2[f32], + [location(2)] pos: vec2[f32], } struct FragOut diff --git a/example/shader.spv b/example/shader.spv index 361d880..f2a83ee 100644 Binary files a/example/shader.spv and b/example/shader.spv differ diff --git a/example/shader.spv.txt b/example/shader.spv.txt index 8473601..ae71392 100644 --- a/example/shader.spv.txt +++ b/example/shader.spv.txt @@ -1,75 +1,86 @@ Version 1.0 Generator: 2560130 -Bound: 45 +Bound: 51 Schema: 0 OpCapability Capability(Shader) OpMemoryModel AddressingModel(Logical) MemoryModel(GLSL450) - OpEntryPoint ExecutionModel(Fragment) %20 "main" %6 %10 %16 - OpExecutionMode %20 ExecutionMode(OriginUpperLeft) + OpEntryPoint ExecutionModel(Fragment) %25 "main" %5 %12 %15 %21 + OpExecutionMode %25 ExecutionMode(OriginUpperLeft) OpSource SourceLanguage(NZSL) 4198400 OpSourceExtension "Version: 1.1" - OpName %12 "FragIn" - OpMemberName %12 0 "dim" - OpMemberName %12 1 "pos" - OpName %17 "FragOut" - OpMemberName %17 0 "color" - OpName %6 "dim" - OpName %10 "pos" - OpName %16 "color" - OpName %20 "main" - OpDecorate %6 Decoration(Location) 0 - OpDecorate %10 Decoration(Location) 1 - OpDecorate %16 Decoration(Location) 0 - OpMemberDecorate %12 0 Decoration(Offset) 0 - OpMemberDecorate %12 1 Decoration(Offset) 8 + OpName %17 "FragIn" + OpMemberName %17 0 "time" + OpMemberName %17 1 "dim" + OpMemberName %17 2 "pos" + OpName %22 "FragOut" + OpMemberName %22 0 "color" + OpName %5 "time" + OpName %12 "dim" + OpName %15 "pos" + OpName %21 "color" + OpName %25 "main" + OpDecorate %5 Decoration(Location) 0 + OpDecorate %12 Decoration(Location) 1 + OpDecorate %15 Decoration(Location) 2 + OpDecorate %21 Decoration(Location) 0 OpMemberDecorate %17 0 Decoration(Offset) 0 + OpMemberDecorate %17 1 Decoration(Offset) 8 + OpMemberDecorate %17 2 Decoration(Offset) 16 + OpMemberDecorate %22 0 Decoration(Offset) 0 %1 = OpTypeVoid %2 = OpTypeFunction %1 - %3 = OpTypeFloat 32 - %4 = OpTypeVector %3 2 - %5 = OpTypePointer StorageClass(Input) %4 - %7 = OpTypeInt 32 1 - %8 = OpConstant %7 i32(0) - %9 = OpTypePointer StorageClass(Function) %4 -%11 = OpConstant %7 i32(1) -%12 = OpTypeStruct %4 %4 -%13 = OpTypePointer StorageClass(Function) %12 -%14 = OpTypeVector %3 4 -%15 = OpTypePointer StorageClass(Output) %14 -%17 = OpTypeStruct %14 + %3 = OpTypeInt 32 0 + %4 = OpTypePointer StorageClass(Input) %3 + %6 = OpTypeInt 32 1 + %7 = OpConstant %6 i32(0) + %8 = OpTypePointer StorageClass(Function) %3 + %9 = OpTypeFloat 32 +%10 = OpTypeVector %9 2 +%11 = OpTypePointer StorageClass(Input) %10 +%13 = OpConstant %6 i32(1) +%14 = OpTypePointer StorageClass(Function) %10 +%16 = OpConstant %6 i32(2) +%17 = OpTypeStruct %3 %10 %10 %18 = OpTypePointer StorageClass(Function) %17 -%19 = OpConstant %3 f32(1) -%42 = OpTypePointer StorageClass(Function) %14 - %6 = OpVariable %5 StorageClass(Input) -%10 = OpVariable %5 StorageClass(Input) -%16 = OpVariable %15 StorageClass(Output) -%20 = OpFunction %1 FunctionControl(0) %2 -%21 = OpLabel -%22 = OpVariable %18 StorageClass(Function) -%23 = OpVariable %13 StorageClass(Function) -%24 = OpAccessChain %9 %23 %8 - OpCopyMemory %24 %6 -%25 = OpAccessChain %9 %23 %11 - OpCopyMemory %25 %10 -%26 = OpAccessChain %9 %23 %11 -%27 = OpLoad %4 %26 -%28 = OpCompositeExtract %3 %27 0 -%29 = OpAccessChain %9 %23 %8 -%30 = OpLoad %4 %29 -%31 = OpCompositeExtract %3 %30 0 -%32 = OpFDiv %3 %28 %31 -%33 = OpAccessChain %9 %23 %11 -%34 = OpLoad %4 %33 -%35 = OpCompositeExtract %3 %34 1 -%36 = OpAccessChain %9 %23 %8 -%37 = OpLoad %4 %36 -%38 = OpCompositeExtract %3 %37 1 -%39 = OpFDiv %3 %35 %38 -%40 = OpCompositeConstruct %14 %32 %39 %19 %19 -%41 = OpAccessChain %42 %22 %8 - OpStore %41 %40 -%43 = OpLoad %17 %22 -%44 = OpCompositeExtract %14 %43 0 - OpStore %16 %44 +%19 = OpTypeVector %9 4 +%20 = OpTypePointer StorageClass(Output) %19 +%22 = OpTypeStruct %19 +%23 = OpTypePointer StorageClass(Function) %22 +%24 = OpConstant %9 f32(1) +%48 = OpTypePointer StorageClass(Function) %19 + %5 = OpVariable %4 StorageClass(Input) +%12 = OpVariable %11 StorageClass(Input) +%15 = OpVariable %11 StorageClass(Input) +%21 = OpVariable %20 StorageClass(Output) +%25 = OpFunction %1 FunctionControl(0) %2 +%26 = OpLabel +%27 = OpVariable %23 StorageClass(Function) +%28 = OpVariable %18 StorageClass(Function) +%29 = OpAccessChain %8 %28 %7 + OpCopyMemory %29 %5 +%30 = OpAccessChain %14 %28 %13 + OpCopyMemory %30 %12 +%31 = OpAccessChain %14 %28 %16 + OpCopyMemory %31 %15 +%32 = OpAccessChain %14 %28 %16 +%33 = OpLoad %10 %32 +%34 = OpCompositeExtract %9 %33 0 +%35 = OpAccessChain %14 %28 %13 +%36 = OpLoad %10 %35 +%37 = OpCompositeExtract %9 %36 0 +%38 = OpFDiv %9 %34 %37 +%39 = OpAccessChain %14 %28 %16 +%40 = OpLoad %10 %39 +%41 = OpCompositeExtract %9 %40 1 +%42 = OpAccessChain %14 %28 %13 +%43 = OpLoad %10 %42 +%44 = OpCompositeExtract %9 %43 1 +%45 = OpFDiv %9 %41 %44 +%46 = OpCompositeConstruct %19 %38 %45 %24 %24 +%47 = OpAccessChain %48 %27 %7 + OpStore %47 %46 +%49 = OpLoad %22 %27 +%50 = OpCompositeExtract %19 %49 0 + OpStore %21 %50 OpReturn OpFunctionEnd diff --git a/sandbox/main.zig b/sandbox/main.zig index 1f3061e..757bf5b 100644 --- a/sandbox/main.zig +++ b/sandbox/main.zig @@ -16,9 +16,21 @@ pub fn main() !void { var rt = try spv.Runtime.init(allocator, &module); defer rt.deinit(allocator); - try rt.callEntryPoint(allocator, try rt.getEntryPointByName("main")); + const entry = try rt.getEntryPointByName("main"); + const color = try rt.getResultByName("color"); + const time = try rt.getResultByName("time"); + const pos = try rt.getResultByName("pos"); + const res = try rt.getResultByName("res"); + var output: [4]f32 = undefined; - try rt.readOutput(f32, output[0..output.len], try rt.getResultByName("color")); + + try rt.writeInput(f32, &.{@as(f32, @floatFromInt(std.time.milliTimestamp()))}, time); + try rt.writeInput(f32, &.{ 1250.0, 720.0 }, res); + try rt.writeInput(f32, &.{ 0.0, 0.0 }, pos); + + try rt.callEntryPoint(allocator, entry); + + try rt.readOutput(f32, output[0..output.len], color); std.log.info("Output: Vec4{any}", .{output}); } std.log.info("Successfully executed", .{}); diff --git a/sandbox/shader.nzsl b/sandbox/shader.nzsl index cc696f1..c5252e6 100644 --- a/sandbox/shader.nzsl +++ b/sandbox/shader.nzsl @@ -1,15 +1,71 @@ [nzsl_version("1.1")] module; +struct FragIn +{ + [location(0)] time: f32, + [location(1)] res: vec2[f32], + [location(2)] pos: vec2[f32], +} + struct FragOut { - [location(0)] color: vec4[f32] + [location(0)] color: vec4[f32] } [entry(frag)] -fn main() -> FragOut +fn main(input: FragIn) -> FragOut { - let output: FragOut; - output.color = vec4[f32](1.0, 1.0, 1.0, 1.0); - return output; + const I: i32 = 128; + const A: f32 = 7.5; + const MA: f32 = 100.0; + const MI: f32 = 0.001; + + let uv0 = input.pos / input.res * 2.0 - vec2[f32](1.0, 1.0); + let uv = vec2[f32](uv0.x * (input.res.x / input.res.y), uv0.y); + + let col = vec3[f32](0.0, 0.0, 0.0); + let ro = vec3[f32](0.0, 0.0, -2.0); + let rd = vec3[f32](uv.x, uv.y, 1.0); + let dt = 0.0; + let ds = 0.0; + let dm = -1.0; + let p = ro; + let c = vec3[f32](0.0, 0.0, 0.0); + + let l = vec3[f32](0.0, sin(input.time * 0.2) * 4.0, cos(input.time * 0.2) * 4.0); + + for i in 0 -> I + { + p = ro + rd * dt; + ds = length(c - p) - 1.0; + dt += ds; + + if (dm == -1.0 || ds < dm) + dm = ds; + + if (ds <= MI) + { + let value = max(dot(normalize(c - p), normalize(p - l)) - 0.35, 0.0); + col = vec3[f32](value, value, value); + break; + } + + if (ds >= MA) + { + if (dot(normalize(rd), normalize(l - ro)) <= 1.0) + { + let value = max(dot(normalize(rd), normalize(l - ro)) + 0.15, 0.05)/ 1.15 * (1.0 - dm * A); + col = vec3[f32](value, value, value); + } + break; + } + } + + if (col == vec3[f32](0.0, 0.0, 0.0)) + discard; + + let output: FragOut; + output.color = vec4[f32](col.x, col.y, col.z, 1.0); + return output; } diff --git a/sandbox/shader.spv b/sandbox/shader.spv index db7c59e..4306992 100644 Binary files a/sandbox/shader.spv and b/sandbox/shader.spv differ diff --git a/sandbox/shader.spv.txt b/sandbox/shader.spv.txt index 0d65b26..880ef68 100644 --- a/sandbox/shader.spv.txt +++ b/sandbox/shader.spv.txt @@ -1,39 +1,292 @@ Version 1.0 Generator: 2560130 -Bound: 20 +Bound: 210 Schema: 0 - OpCapability Capability(Shader) - OpMemoryModel AddressingModel(Logical) MemoryModel(GLSL450) - OpEntryPoint ExecutionModel(Fragment) %12 "main" %6 - OpExecutionMode %12 ExecutionMode(OriginUpperLeft) - OpSource SourceLanguage(NZSL) 4198400 - OpSourceExtension "Version: 1.1" - OpName %7 "FragOut" - OpMemberName %7 0 "color" - OpName %6 "color" - OpName %12 "main" - OpDecorate %6 Decoration(Location) 0 - OpMemberDecorate %7 0 Decoration(Offset) 0 - %1 = OpTypeVoid - %2 = OpTypeFunction %1 - %3 = OpTypeFloat 32 - %4 = OpTypeVector %3 4 - %5 = OpTypePointer StorageClass(Output) %4 - %7 = OpTypeStruct %4 - %8 = OpTypePointer StorageClass(Function) %7 - %9 = OpTypeInt 32 1 -%10 = OpConstant %9 i32(0) -%11 = OpConstant %3 f32(1) -%17 = OpTypePointer StorageClass(Function) %4 - %6 = OpVariable %5 StorageClass(Output) -%12 = OpFunction %1 FunctionControl(0) %2 -%13 = OpLabel -%14 = OpVariable %8 StorageClass(Function) -%15 = OpCompositeConstruct %4 %11 %11 %11 %11 -%16 = OpAccessChain %17 %14 %10 - OpStore %16 %15 -%18 = OpLoad %7 %14 -%19 = OpCompositeExtract %4 %18 0 - OpStore %6 %19 - OpReturn - OpFunctionEnd + OpCapability Capability(Shader) + %43 = OpExtInstImport "GLSL.std.450" + OpMemoryModel AddressingModel(Logical) MemoryModel(GLSL450) + OpEntryPoint ExecutionModel(Fragment) %44 "main" %5 %11 %14 %20 + OpExecutionMode %44 ExecutionMode(OriginUpperLeft) + OpSource SourceLanguage(NZSL) 4198400 + OpSourceExtension "Version: 1.1" + OpName %16 "FragIn" + OpMemberName %16 0 "time" + OpMemberName %16 1 "res" + OpMemberName %16 2 "pos" + OpName %21 "FragOut" + OpMemberName %21 0 "color" + OpName %5 "time" + OpName %11 "res" + OpName %14 "pos" + OpName %20 "color" + OpName %44 "main" + OpDecorate %5 Decoration(Location) 0 + OpDecorate %11 Decoration(Location) 1 + OpDecorate %14 Decoration(Location) 2 + OpDecorate %20 Decoration(Location) 0 + OpMemberDecorate %16 0 Decoration(Offset) 0 + OpMemberDecorate %16 1 Decoration(Offset) 8 + OpMemberDecorate %16 2 Decoration(Offset) 16 + OpMemberDecorate %21 0 Decoration(Offset) 0 + %1 = OpTypeVoid + %2 = OpTypeFunction %1 + %3 = OpTypeFloat 32 + %4 = OpTypePointer StorageClass(Input) %3 + %6 = OpTypeInt 32 1 + %7 = OpConstant %6 i32(0) + %8 = OpTypePointer StorageClass(Function) %3 + %9 = OpTypeVector %3 2 + %10 = OpTypePointer StorageClass(Input) %9 + %12 = OpConstant %6 i32(1) + %13 = OpTypePointer StorageClass(Function) %9 + %15 = OpConstant %6 i32(2) + %16 = OpTypeStruct %3 %9 %9 + %17 = OpTypePointer StorageClass(Function) %16 + %18 = OpTypeVector %3 4 + %19 = OpTypePointer StorageClass(Output) %18 + %21 = OpTypeStruct %18 + %22 = OpConstant %3 f32(2) + %23 = OpConstant %3 f32(1) + %24 = OpConstant %3 f32(0) + %25 = OpTypeVector %3 3 + %26 = OpTypePointer StorageClass(Function) %25 + %27 = OpConstant %3 f32(-2) + %28 = OpConstant %3 f32(-1) + %29 = OpConstant %3 f32(0.2) + %30 = OpConstant %3 f32(4) + %31 = OpTypePointer StorageClass(Function) %6 + %32 = OpConstant %6 i32(128) + %33 = OpTypeBool + %34 = OpConstant %3 f32(0.001) + %35 = OpConstant %3 f32(0.35) + %36 = OpConstant %3 f32(100) + %37 = OpConstant %3 f32(0.15) + %38 = OpConstant %3 f32(0.05) + %39 = OpConstant %3 f32(1.15) + %40 = OpConstant %3 f32(7.5) + %41 = OpTypeVector %33 3 + %42 = OpTypePointer StorageClass(Function) %21 +%207 = OpTypePointer StorageClass(Function) %18 + %5 = OpVariable %4 StorageClass(Input) + %11 = OpVariable %10 StorageClass(Input) + %14 = OpVariable %10 StorageClass(Input) + %20 = OpVariable %19 StorageClass(Output) + %44 = OpFunction %1 FunctionControl(0) %2 + %45 = OpLabel + %46 = OpVariable %13 StorageClass(Function) + %47 = OpVariable %13 StorageClass(Function) + %48 = OpVariable %26 StorageClass(Function) + %49 = OpVariable %26 StorageClass(Function) + %50 = OpVariable %26 StorageClass(Function) + %51 = OpVariable %8 StorageClass(Function) + %52 = OpVariable %8 StorageClass(Function) + %53 = OpVariable %8 StorageClass(Function) + %54 = OpVariable %26 StorageClass(Function) + %55 = OpVariable %26 StorageClass(Function) + %56 = OpVariable %26 StorageClass(Function) + %57 = OpVariable %31 StorageClass(Function) + %58 = OpVariable %31 StorageClass(Function) + %59 = OpVariable %8 StorageClass(Function) + %60 = OpVariable %8 StorageClass(Function) + %61 = OpVariable %42 StorageClass(Function) + %62 = OpVariable %17 StorageClass(Function) + %63 = OpAccessChain %8 %62 %7 + OpCopyMemory %63 %5 + %64 = OpAccessChain %13 %62 %12 + OpCopyMemory %64 %11 + %65 = OpAccessChain %13 %62 %15 + OpCopyMemory %65 %14 + %66 = OpAccessChain %13 %62 %15 + %67 = OpLoad %9 %66 + %68 = OpAccessChain %13 %62 %12 + %69 = OpLoad %9 %68 + %70 = OpFDiv %9 %67 %69 + %71 = OpVectorTimesScalar %9 %70 %22 + %72 = OpCompositeConstruct %9 %23 %23 + %73 = OpFSub %9 %71 %72 + OpStore %46 %73 + %74 = OpLoad %9 %46 + %75 = OpCompositeExtract %3 %74 0 + %76 = OpAccessChain %13 %62 %12 + %77 = OpLoad %9 %76 + %78 = OpCompositeExtract %3 %77 0 + %79 = OpAccessChain %13 %62 %12 + %80 = OpLoad %9 %79 + %81 = OpCompositeExtract %3 %80 1 + %82 = OpFDiv %3 %78 %81 + %83 = OpFMul %3 %75 %82 + %84 = OpLoad %9 %46 + %85 = OpCompositeExtract %3 %84 1 + %86 = OpCompositeConstruct %9 %83 %85 + OpStore %47 %86 + %87 = OpCompositeConstruct %25 %24 %24 %24 + OpStore %48 %87 + %88 = OpCompositeConstruct %25 %24 %24 %27 + OpStore %49 %88 + %89 = OpLoad %9 %47 + %90 = OpCompositeExtract %3 %89 0 + %91 = OpLoad %9 %47 + %92 = OpCompositeExtract %3 %91 1 + %93 = OpCompositeConstruct %25 %90 %92 %23 + OpStore %50 %93 + OpStore %51 %24 + OpStore %52 %24 + OpStore %53 %28 + %94 = OpLoad %25 %49 + OpStore %54 %94 + %95 = OpCompositeConstruct %25 %24 %24 %24 + OpStore %55 %95 + %96 = OpAccessChain %8 %62 %7 + %97 = OpLoad %3 %96 + %98 = OpFMul %3 %97 %29 + %99 = OpExtInst %3 GLSLstd450 Sin %98 +%100 = OpFMul %3 %99 %30 +%101 = OpAccessChain %8 %62 %7 +%102 = OpLoad %3 %101 +%103 = OpFMul %3 %102 %29 +%104 = OpExtInst %3 GLSLstd450 Cos %103 +%105 = OpFMul %3 %104 %30 +%106 = OpCompositeConstruct %25 %24 %100 %105 + OpStore %56 %106 + OpStore %57 %7 + OpStore %58 %32 + OpBranch %107 +%107 = OpLabel +%111 = OpLoad %6 %57 +%112 = OpLoad %6 %58 +%113 = OpSLessThan %33 %111 %112 + OpLoopMerge %109 %110 LoopControl(0) + OpBranchConditional %113 %108 %109 +%108 = OpLabel +%114 = OpLoad %25 %49 +%115 = OpLoad %25 %50 +%116 = OpLoad %3 %51 +%117 = OpVectorTimesScalar %25 %115 %116 +%118 = OpFAdd %25 %114 %117 + OpStore %54 %118 +%119 = OpLoad %25 %55 +%120 = OpLoad %25 %54 +%121 = OpFSub %25 %119 %120 +%122 = OpExtInst %3 GLSLstd450 Length %121 +%123 = OpFSub %3 %122 %23 + OpStore %52 %123 +%124 = OpLoad %3 %51 +%125 = OpLoad %3 %52 +%126 = OpFAdd %3 %124 %125 + OpStore %51 %126 +%130 = OpLoad %3 %53 +%131 = OpFOrdEqual %33 %130 %28 +%132 = OpLoad %3 %52 +%133 = OpLoad %3 %53 +%134 = OpFOrdLessThan %33 %132 %133 +%135 = OpLogicalOr %33 %131 %134 + OpSelectionMerge %127 SelectionControl(0) + OpBranchConditional %135 %128 %129 +%128 = OpLabel +%136 = OpLoad %3 %52 + OpStore %53 %136 + OpBranch %127 +%129 = OpLabel + OpBranch %127 +%127 = OpLabel +%140 = OpLoad %3 %52 +%141 = OpFOrdLessThanEqual %33 %140 %34 + OpSelectionMerge %137 SelectionControl(0) + OpBranchConditional %141 %138 %139 +%138 = OpLabel +%142 = OpLoad %25 %55 +%143 = OpLoad %25 %54 +%144 = OpFSub %25 %142 %143 +%145 = OpExtInst %25 GLSLstd450 Normalize %144 +%146 = OpLoad %25 %54 +%147 = OpLoad %25 %56 +%148 = OpFSub %25 %146 %147 +%149 = OpExtInst %25 GLSLstd450 Normalize %148 +%150 = OpDot %3 %145 %149 +%151 = OpFSub %3 %150 %35 +%152 = OpExtInst %3 GLSLstd450 FMax %151 %24 + OpStore %59 %152 +%153 = OpLoad %3 %59 +%154 = OpLoad %3 %59 +%155 = OpLoad %3 %59 +%156 = OpCompositeConstruct %25 %153 %154 %155 + OpStore %48 %156 + OpBranch %109 +%139 = OpLabel + OpBranch %137 +%137 = OpLabel +%160 = OpLoad %3 %52 +%161 = OpFOrdGreaterThanEqual %33 %160 %36 + OpSelectionMerge %157 SelectionControl(0) + OpBranchConditional %161 %158 %159 +%158 = OpLabel +%165 = OpLoad %25 %50 +%166 = OpExtInst %25 GLSLstd450 Normalize %165 +%167 = OpLoad %25 %56 +%168 = OpLoad %25 %49 +%169 = OpFSub %25 %167 %168 +%170 = OpExtInst %25 GLSLstd450 Normalize %169 +%171 = OpDot %3 %166 %170 +%172 = OpFOrdLessThanEqual %33 %171 %23 + OpSelectionMerge %162 SelectionControl(0) + OpBranchConditional %172 %163 %164 +%163 = OpLabel +%173 = OpLoad %25 %50 +%174 = OpExtInst %25 GLSLstd450 Normalize %173 +%175 = OpLoad %25 %56 +%176 = OpLoad %25 %49 +%177 = OpFSub %25 %175 %176 +%178 = OpExtInst %25 GLSLstd450 Normalize %177 +%179 = OpDot %3 %174 %178 +%180 = OpFAdd %3 %179 %37 +%181 = OpExtInst %3 GLSLstd450 FMax %180 %38 +%182 = OpFDiv %3 %181 %39 +%183 = OpLoad %3 %53 +%184 = OpFMul %3 %183 %40 +%185 = OpFSub %3 %23 %184 +%186 = OpFMul %3 %182 %185 + OpStore %60 %186 +%187 = OpLoad %3 %60 +%188 = OpLoad %3 %60 +%189 = OpLoad %3 %60 +%190 = OpCompositeConstruct %25 %187 %188 %189 + OpStore %48 %190 + OpBranch %162 +%164 = OpLabel + OpBranch %162 +%162 = OpLabel + OpBranch %109 +%159 = OpLabel + OpBranch %157 +%157 = OpLabel +%191 = OpLoad %6 %57 +%192 = OpIAdd %6 %191 %12 + OpStore %57 %192 + OpBranch %110 +%110 = OpLabel + OpBranch %107 +%109 = OpLabel +%196 = OpLoad %25 %48 +%197 = OpCompositeConstruct %25 %24 %24 %24 +%198 = OpFOrdEqual %41 %196 %197 + OpSelectionMerge %193 SelectionControl(0) + OpBranchConditional %198 %194 %195 +%194 = OpLabel + OpKill +%195 = OpLabel + OpBranch %193 +%193 = OpLabel +%199 = OpLoad %25 %48 +%200 = OpCompositeExtract %3 %199 0 +%201 = OpLoad %25 %48 +%202 = OpCompositeExtract %3 %201 1 +%203 = OpLoad %25 %48 +%204 = OpCompositeExtract %3 %203 2 +%205 = OpCompositeConstruct %18 %200 %202 %204 %23 +%206 = OpAccessChain %207 %61 %7 + OpStore %206 %205 +%208 = OpLoad %21 %61 +%209 = OpCompositeExtract %18 %208 0 + OpStore %20 %209 + OpReturn + OpFunctionEnd diff --git a/src/Module.zig b/src/Module.zig index 2076b94..999db9c 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -100,6 +100,8 @@ pub fn init(allocator: std.mem.Allocator, source: []const SpvWord, options: Modu }); errdefer self.deinit(allocator); + op.initRuntimeDispatcher(); + self.it = WordIterator.init(self.code); const magic = self.it.next() catch return ModuleError.InvalidSpirV; diff --git a/src/Runtime.zig b/src/Runtime.zig index 1af8381..0f12f5a 100644 --- a/src/Runtime.zig +++ b/src/Runtime.zig @@ -1,3 +1,5 @@ +//! A runtime meant for actual shader invocations. + const std = @import("std"); const spv = @import("spv.zig"); const op = @import("opcodes.zig"); @@ -127,10 +129,11 @@ pub fn callEntryPoint(self: *Self, allocator: std.mem.Allocator, entry_point_ind const opcode = (opcode_data & spv.SpvOpCodeMask); var it_tmp = self.it; // Save because operations may iter on this iterator - if (std.enums.fromInt(spv.SpvOp, opcode)) |spv_op| { - if (op.RuntimeDispatcher.get(spv_op)) |pfn| { - try pfn(allocator, word_count, self); - } + if (op.runtime_dispatcher[opcode]) |pfn| { + pfn(allocator, word_count, self) catch |err| switch (err) { + RuntimeError.Killed => return, + else => return err, + }; } if (!self.it.did_jump) { _ = it_tmp.skipN(word_count); diff --git a/src/opcodes.zig b/src/opcodes.zig index a5d2b26..14942e0 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -13,7 +13,15 @@ const SpvByte = spv.SpvByte; const SpvWord = spv.SpvWord; const SpvBool = spv.SpvBool; +// OpExtInstImport +// OpExtInst Sin +// OpExtInst Cos +// OpExtInst Length +// OpExtInst Normalize +// OpExtInst FMax + const ValueType = enum { + Bool, Float, SInt, UInt, @@ -22,10 +30,15 @@ const ValueType = enum { const MathOp = enum { Add, Div, + MatrixTimesMatrix, + MatrixTimesScalar, + MatrixTimesVector, Mod, Mul, Rem, Sub, + VectorTimesMatrix, + VectorTimesScalar, }; const CondOp = enum { @@ -35,6 +48,11 @@ const CondOp = enum { Less, LessEqual, NotEqual, + LogicalEqual, + LogicalNotEqual, + LogicalAnd, + LogicalOr, + LogicalNot, }; const BitOp = enum { @@ -76,6 +94,7 @@ pub const SetupDispatcher = block: { .ConvertUToF = autoSetupConstant, .ConvertUToPtr = autoSetupConstant, .Decorate = opDecorate, + .Dot = autoSetupConstant, .EntryPoint = opEntryPoint, .ExecutionMode = opExecutionMode, .FAdd = autoSetupConstant, @@ -107,6 +126,14 @@ pub const SetupDispatcher = block: { .ISub = autoSetupConstant, .Label = opLabel, .Load = autoSetupConstant, + .LogicalAnd = autoSetupConstant, + .LogicalEqual = autoSetupConstant, + .LogicalNot = autoSetupConstant, + .LogicalNotEqual = autoSetupConstant, + .LogicalOr = autoSetupConstant, + .MatrixTimesMatrix = autoSetupConstant, + .MatrixTimesScalar = autoSetupConstant, + .MatrixTimesVector = autoSetupConstant, .MemberDecorate = opDecorateMember, .MemberName = opMemberName, .MemoryModel = opMemoryModel, @@ -145,85 +172,95 @@ pub const SetupDispatcher = block: { .ULessThanEqual = autoSetupConstant, .UMod = autoSetupConstant, .Variable = opVariable, + .VectorTimesMatrix = autoSetupConstant, + .VectorTimesScalar = autoSetupConstant, }); }; -pub const RuntimeDispatcher = block: { - @setEvalBranchQuota(65535); - break :block std.EnumMap(spv.SpvOp, OpCodeFunc).init(.{ - .AccessChain = opAccessChain, - .BitCount = BitEngine(.UInt, .BitCount).op, - .BitFieldInsert = BitEngine(.UInt, .BitFieldInsert).op, - .BitFieldSExtract = BitEngine(.SInt, .BitFieldSExtract).op, - .BitFieldUExtract = BitEngine(.UInt, .BitFieldUExtract).op, - .BitReverse = BitEngine(.UInt, .BitReverse).op, - .Bitcast = opBitcast, - .BitwiseAnd = BitEngine(.UInt, .BitwiseAnd).op, - .BitwiseOr = BitEngine(.UInt, .BitwiseOr).op, - .BitwiseXor = BitEngine(.UInt, .BitwiseXor).op, - .Branch = opBranch, - .BranchConditional = opBranchConditional, - .CompositeConstruct = opCompositeConstruct, - .CompositeExtract = opCompositeExtract, - .ConvertFToS = ConversionEngine(.Float, .SInt).op, - .ConvertFToU = ConversionEngine(.Float, .UInt).op, - .ConvertSToF = ConversionEngine(.SInt, .Float).op, - .ConvertUToF = ConversionEngine(.UInt, .Float).op, - .CopyMemory = opCopyMemory, - .FAdd = MathEngine(.Float, .Add).op, - .FConvert = ConversionEngine(.Float, .Float).op, - .FDiv = MathEngine(.Float, .Div).op, - .FMod = MathEngine(.Float, .Mod).op, - .FMul = MathEngine(.Float, .Mul).op, - .FOrdEqual = CondEngine(.Float, .Equal).op, - .FOrdGreaterThan = CondEngine(.Float, .Greater).op, - .FOrdGreaterThanEqual = CondEngine(.Float, .GreaterEqual).op, - .FOrdLessThan = CondEngine(.Float, .Less).op, - .FOrdLessThanEqual = CondEngine(.Float, .LessEqual).op, - .FOrdNotEqual = CondEngine(.Float, .NotEqual).op, - .FSub = MathEngine(.Float, .Sub).op, - .FUnordEqual = CondEngine(.Float, .Equal).op, - .FUnordGreaterThan = CondEngine(.Float, .Greater).op, - .FUnordGreaterThanEqual = CondEngine(.Float, .GreaterEqual).op, - .FUnordLessThan = CondEngine(.Float, .Less).op, - .FUnordLessThanEqual = CondEngine(.Float, .LessEqual).op, - .FUnordNotEqual = CondEngine(.Float, .NotEqual).op, - .FunctionCall = opFunctionCall, - .IAdd = MathEngine(.SInt, .Add).op, - .IEqual = CondEngine(.SInt, .Equal).op, - .IMul = MathEngine(.SInt, .Mul).op, - .INotEqual = CondEngine(.SInt, .NotEqual).op, - .ISub = MathEngine(.SInt, .Sub).op, - .Load = opLoad, - .Not = BitEngine(.UInt, .Not).op, - .Return = opReturn, - .ReturnValue = opReturnValue, - .SConvert = ConversionEngine(.SInt, .SInt).op, - .SDiv = MathEngine(.SInt, .Div).op, - .SGreaterThan = CondEngine(.SInt, .Greater).op, - .SGreaterThanEqual = CondEngine(.SInt, .GreaterEqual).op, - .SLessThan = CondEngine(.SInt, .Less).op, - .SLessThanEqual = CondEngine(.SInt, .LessEqual).op, - .SMod = MathEngine(.SInt, .Mod).op, - .ShiftLeftLogical = BitEngine(.UInt, .ShiftLeft).op, - .ShiftRightArithmetic = BitEngine(.SInt, .ShiftRightArithmetic).op, - .ShiftRightLogical = BitEngine(.UInt, .ShiftRight).op, - .Store = opStore, - .UConvert = ConversionEngine(.UInt, .UInt).op, - .UDiv = MathEngine(.UInt, .Div).op, - .UGreaterThan = CondEngine(.UInt, .Greater).op, - .UGreaterThanEqual = CondEngine(.UInt, .GreaterEqual).op, - .ULessThan = CondEngine(.UInt, .Less).op, - .ULessThanEqual = CondEngine(.UInt, .LessEqual).op, - .UMod = MathEngine(.UInt, .Mod).op, +/// Not an EnumMap as it is way too slow for this purpose +pub var runtime_dispatcher = [_]?OpCodeFunc{null} ** spv.SpvOpMaxValue; - //.QuantizeToF16 = , - //.ConvertPtrToU = , - //.SatConvertSToU = , - //.SatConvertUToS = , - //.ConvertUToPtr = , - }); -}; +pub fn initRuntimeDispatcher() void { + // zig fmt: off + runtime_dispatcher[@intFromEnum(spv.SpvOp.AccessChain)] = opAccessChain; + runtime_dispatcher[@intFromEnum(spv.SpvOp.BitCount)] = BitEngine(.UInt, .BitCount).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.BitFieldInsert)] = BitEngine(.UInt, .BitFieldInsert).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.BitFieldSExtract)] = BitEngine(.SInt, .BitFieldSExtract).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.BitFieldUExtract)] = BitEngine(.UInt, .BitFieldUExtract).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.BitReverse)] = BitEngine(.UInt, .BitReverse).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.Bitcast)] = opBitcast; + runtime_dispatcher[@intFromEnum(spv.SpvOp.BitwiseAnd)] = BitEngine(.UInt, .BitwiseAnd).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.BitwiseOr)] = BitEngine(.UInt, .BitwiseOr).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.BitwiseXor)] = BitEngine(.UInt, .BitwiseXor).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.Branch)] = opBranch; + runtime_dispatcher[@intFromEnum(spv.SpvOp.BranchConditional)] = opBranchConditional; + runtime_dispatcher[@intFromEnum(spv.SpvOp.CompositeConstruct)] = opCompositeConstruct; + runtime_dispatcher[@intFromEnum(spv.SpvOp.CompositeExtract)] = opCompositeExtract; + runtime_dispatcher[@intFromEnum(spv.SpvOp.ConvertFToS)] = ConversionEngine(.Float, .SInt).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.ConvertFToU)] = ConversionEngine(.Float, .UInt).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.ConvertSToF)] = ConversionEngine(.SInt, .Float).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.ConvertUToF)] = ConversionEngine(.UInt, .Float).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.CopyMemory)] = opCopyMemory; + runtime_dispatcher[@intFromEnum(spv.SpvOp.Dot)] = opDot; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FAdd)] = MathEngine(.Float, .Add).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FConvert)] = ConversionEngine(.Float, .Float).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FDiv)] = MathEngine(.Float, .Div).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FMod)] = MathEngine(.Float, .Mod).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FMul)] = MathEngine(.Float, .Mul).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FOrdEqual)] = CondEngine(.Float, .Equal).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FOrdGreaterThan)] = CondEngine(.Float, .Greater).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FOrdGreaterThanEqual)] = CondEngine(.Float, .GreaterEqual).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FOrdLessThan)] = CondEngine(.Float, .Less).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FOrdLessThanEqual)] = CondEngine(.Float, .LessEqual).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FOrdNotEqual)] = CondEngine(.Float, .NotEqual).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FSub)] = MathEngine(.Float, .Sub).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FUnordEqual)] = CondEngine(.Float, .Equal).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FUnordGreaterThan)] = CondEngine(.Float, .Greater).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FUnordGreaterThanEqual)] = CondEngine(.Float, .GreaterEqual).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FUnordLessThan)] = CondEngine(.Float, .Less).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FUnordLessThanEqual)] = CondEngine(.Float, .LessEqual).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FUnordNotEqual)] = CondEngine(.Float, .NotEqual).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.FunctionCall)] = opFunctionCall; + runtime_dispatcher[@intFromEnum(spv.SpvOp.IAdd)] = MathEngine(.SInt, .Add).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.IEqual)] = CondEngine(.SInt, .Equal).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.IMul)] = MathEngine(.SInt, .Mul).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.INotEqual)] = CondEngine(.SInt, .NotEqual).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.ISub)] = MathEngine(.SInt, .Sub).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.Kill)] = opKill; + runtime_dispatcher[@intFromEnum(spv.SpvOp.Load)] = opLoad; + runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalAnd)] = CondEngine(.Float, .LogicalAnd).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalEqual)] = CondEngine(.Float, .LogicalEqual).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalNot)] = CondEngine(.Float, .LogicalNot).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalNotEqual)] = CondEngine(.Float, .LogicalNotEqual).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.LogicalOr)] = CondEngine(.Float, .LogicalOr).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesMatrix)] = MathEngine(.Float, .MatrixTimesMatrix).op; // TODO + runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesScalar)] = MathEngine(.Float, .MatrixTimesScalar).op; // TODO + runtime_dispatcher[@intFromEnum(spv.SpvOp.MatrixTimesVector)] = MathEngine(.Float, .MatrixTimesVector).op; // TODO + runtime_dispatcher[@intFromEnum(spv.SpvOp.Not)] = BitEngine(.UInt, .Not).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.Return)] = opReturn; + runtime_dispatcher[@intFromEnum(spv.SpvOp.ReturnValue)] = opReturnValue; + runtime_dispatcher[@intFromEnum(spv.SpvOp.SConvert)] = ConversionEngine(.SInt, .SInt).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.SDiv)] = MathEngine(.SInt, .Div).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.SGreaterThan)] = CondEngine(.SInt, .Greater).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.SGreaterThanEqual)] = CondEngine(.SInt, .GreaterEqual).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.SLessThan)] = CondEngine(.SInt, .Less).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.SLessThanEqual)] = CondEngine(.SInt, .LessEqual).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.SMod)] = MathEngine(.SInt, .Mod).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.ShiftLeftLogical)] = BitEngine(.UInt, .ShiftLeft).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.ShiftRightArithmetic)] = BitEngine(.SInt, .ShiftRightArithmetic).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.ShiftRightLogical)] = BitEngine(.UInt, .ShiftRight).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.Store)] = opStore; + runtime_dispatcher[@intFromEnum(spv.SpvOp.UConvert)] = ConversionEngine(.UInt, .UInt).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.UDiv)] = MathEngine(.UInt, .Div).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.UGreaterThan)] = CondEngine(.UInt, .Greater).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.UGreaterThanEqual)] = CondEngine(.UInt, .GreaterEqual).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.ULessThan)] = CondEngine(.UInt, .Less).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.ULessThanEqual)] = CondEngine(.UInt, .LessEqual).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.UMod)] = MathEngine(.UInt, .Mod).op; + runtime_dispatcher[@intFromEnum(spv.SpvOp.VectorTimesMatrix)] = MathEngine(.Float, .VectorTimesMatrix).op; // TODO + runtime_dispatcher[@intFromEnum(spv.SpvOp.VectorTimesScalar)] = MathEngine(.Float, .VectorTimesScalar).op; + // zig fmt: on +} fn BitEngine(comptime T: ValueType, comptime Op: BitOp) type { if (T == .Float) @compileError("Invalid value type"); @@ -362,7 +399,10 @@ fn CondEngine(comptime T: ValueType, comptime Op: CondOp) type { const op1_result = &rt.results[try rt.it.next()]; const op1_type = try op1_result.getValueTypeWord(); const op1_value = try op1_result.getValue(); - const op2_value = try rt.results[try rt.it.next()].getValue(); + const op2_value: ?*Result.Value = switch (Op) { + .LogicalNot => null, + else => try rt.results[try rt.it.next()].getValue(), + }; const size = sw: switch ((try rt.results[op1_type].getVariant()).Type) { .Vector => |v| continue :sw (try rt.results[v.components_type_word].getVariant()).Type, @@ -382,18 +422,21 @@ fn CondEngine(comptime T: ValueType, comptime Op: CondOp) type { }; const operator = struct { - fn operation(comptime TT: type, op1: TT, op2: TT) RuntimeError!bool { + fn operation(comptime TT: type, op1: TT, op2: ?TT) RuntimeError!bool { return switch (Op) { - .Equal => op1 == op2, - .NotEqual => op1 != op2, - .Greater => op1 > op2, - .GreaterEqual => op1 >= op2, - .Less => op1 < op2, - .LessEqual => op1 <= op2, + .Equal, .LogicalEqual => op1 == op2 orelse return RuntimeError.InvalidSpirV, + .NotEqual, .LogicalNotEqual => op1 != op2 orelse return RuntimeError.InvalidSpirV, + .Greater => op1 > op2 orelse return RuntimeError.InvalidSpirV, + .GreaterEqual => op1 >= op2 orelse return RuntimeError.InvalidSpirV, + .Less => op1 < op2 orelse return RuntimeError.InvalidSpirV, + .LessEqual => op1 <= op2 orelse return RuntimeError.InvalidSpirV, + .LogicalAnd => (op1 != @as(TT, 0)) and ((op2 orelse return RuntimeError.InvalidSpirV) != @as(TT, 0)), + .LogicalOr => (op1 != @as(TT, 0)) or ((op2 orelse return RuntimeError.InvalidSpirV) != @as(TT, 0)), + .LogicalNot => (op1 == @as(TT, 0)), }; } - fn process(bit_count: SpvWord, v: *Result.Value, op1_v: *const Result.Value, op2_v: *const Result.Value) RuntimeError!void { + fn process(bit_count: SpvWord, v: *Result.Value, op1_v: *const Result.Value, op2_v: ?*const Result.Value) RuntimeError!void { switch (bit_count) { inline 8, 16, 32, 64 => |i| { if (i == 8 and T == .Float) { // No f8 @@ -402,7 +445,7 @@ fn CondEngine(comptime T: ValueType, comptime Op: CondOp) type { v.Bool = try operation( getValuePrimitiveFieldType(T, i), (try getValuePrimitiveField(T, i, @constCast(op1_v))).*, - (try getValuePrimitiveField(T, i, @constCast(op2_v))).*, + if (op2_v) |val| (try getValuePrimitiveField(T, i, @constCast(val))).* else null, ); }, else => return RuntimeError.InvalidSpirV, @@ -412,7 +455,9 @@ fn CondEngine(comptime T: ValueType, comptime Op: CondOp) type { switch (value.*) { .Bool => try operator.process(size, value, op1_value, op2_value), - .Vector => |vec| for (vec, op1_value.Vector, op2_value.Vector) |*val, op1_v, op2_v| try operator.process(size, val, &op1_v, &op2_v), + .Vector => |vec| for (vec, op1_value.Vector, 0..) |*val, op1_v, i| { + try operator.process(size, val, &op1_v, if (op2_value) |op2_v| &op2_v.Vector[i] else null); + }, // No Vector specializations for booleans else => return RuntimeError.InvalidSpirV, } @@ -586,6 +631,7 @@ fn MathEngine(comptime T: ValueType, comptime Op: MathOp) type { }, .Mod => if (op2 == 0) return RuntimeError.DivisionByZero else @mod(op1, op2), .Rem => if (op2 == 0) return RuntimeError.DivisionByZero else @rem(op1, op2), + else => return RuntimeError.InvalidSpirV, }; } @@ -609,15 +655,29 @@ fn MathEngine(comptime T: ValueType, comptime Op: MathOp) type { switch (value.*) { .Float => if (T == .Float) try operator.process(size, value, op1_value, op2_value) else return RuntimeError.InvalidSpirV, .Int => if (T == .SInt or T == .UInt) try operator.process(size, value, op1_value, op2_value) else return RuntimeError.InvalidSpirV, - .Vector => |vec| for (vec, op1_value.Vector, op2_value.Vector) |*val, op1_v, op2_v| try operator.process(size, val, &op1_v, &op2_v), + .Vector => |vec| for (vec, op1_value.Vector, 0..) |*val, op1_v, i| { + switch (Op) { + .VectorTimesScalar => try operator.process(size, val, &op1_v, op2_value), + else => try operator.process(size, val, &op1_v, &op2_value.Vector[i]), + } + }, .Vector4f32 => |*vec| inline for (0..4) |i| { - vec[i] = try operator.operation(f32, op1_value.Vector4f32[i], op2_value.Vector4f32[i]); + switch (Op) { + .VectorTimesScalar => vec[i] = op1_value.Vector4f32[i] * op2_value.Float.float32, + else => vec[i] = try operator.operation(f32, op1_value.Vector4f32[i], op2_value.Vector4f32[i]), + } }, .Vector3f32 => |*vec| inline for (0..3) |i| { - vec[i] = try operator.operation(f32, op1_value.Vector3f32[i], op2_value.Vector3f32[i]); + switch (Op) { + .VectorTimesScalar => vec[i] = op1_value.Vector3f32[i] * op2_value.Float.float32, + else => vec[i] = try operator.operation(f32, op1_value.Vector3f32[i], op2_value.Vector3f32[i]), + } }, .Vector2f32 => |*vec| inline for (0..2) |i| { - vec[i] = try operator.operation(f32, op1_value.Vector2f32[i], op2_value.Vector2f32[i]); + switch (Op) { + .VectorTimesScalar => vec[i] = op1_value.Vector2f32[i] * op2_value.Float.float32, + else => vec[i] = try operator.operation(f32, op1_value.Vector2f32[i], op2_value.Vector2f32[i]), + } }, .Vector4i32 => |*vec| inline for (0..4) |i| { vec[i] = try operator.operation(i32, op1_value.Vector4i32[i], op2_value.Vector4i32[i]); @@ -737,6 +797,7 @@ fn copyValue(dst: *Result.Value, src: *const Result.Value) void { fn getValuePrimitiveField(comptime T: ValueType, comptime BitCount: SpvWord, v: *Result.Value) RuntimeError!*getValuePrimitiveFieldType(T, BitCount) { return switch (T) { + .Bool => &v.Bool, .Float => switch (BitCount) { inline 16, 32, 64 => |i| &@field(v.Float, std.fmt.comptimePrint("float{}", .{i})), else => return RuntimeError.InvalidSpirV, @@ -754,6 +815,7 @@ fn getValuePrimitiveField(comptime T: ValueType, comptime BitCount: SpvWord, v: fn getValuePrimitiveFieldType(comptime T: ValueType, comptime BitCount: SpvWord) type { return switch (T) { + .Bool => bool, .Float => std.meta.Float(BitCount), .SInt => std.meta.Int(.signed, BitCount), .UInt => std.meta.Int(.unsigned, BitCount), @@ -971,6 +1033,41 @@ fn opDecorateMember(allocator: std.mem.Allocator, _: SpvWord, rt: *Runtime) Runt try addDecoration(allocator, rt, target, decoration_type, member); } +fn opDot(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { + const target_type = (try rt.results[try rt.it.next()].getVariant()).Type; + var value = try rt.results[try rt.it.next()].getValue(); + const op1_value = try rt.results[try rt.it.next()].getValue(); + const op2_value = try rt.results[try rt.it.next()].getValue(); + + const size = switch (target_type) { + .Float => |f| f.bit_length, + else => return RuntimeError.InvalidSpirV, + }; + + value.Float.float64 = 0.0; + + switch (op1_value.*) { + .Vector => |vec| for (vec, op2_value.Vector) |*op1_v, *op2_v| { + switch (size) { + inline 16, 32, 64 => |i| { + (try getValuePrimitiveField(.Float, i, value)).* += (try getValuePrimitiveField(.Float, i, op1_v)).* * (try getValuePrimitiveField(.Float, i, op2_v)).*; + }, + else => return RuntimeError.InvalidSpirV, + } + }, + .Vector4f32 => |*vec| inline for (0..4) |i| { + value.Float.float32 += vec[i] * op2_value.Vector4f32[i]; + }, + .Vector3f32 => |*vec| inline for (0..3) |i| { + value.Float.float32 += vec[i] * op2_value.Vector3f32[i]; + }, + .Vector2f32 => |*vec| inline for (0..2) |i| { + value.Float.float32 += vec[i] * op2_value.Vector2f32[i]; + }, + else => return RuntimeError.InvalidSpirV, + } +} + fn opEntryPoint(allocator: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) RuntimeError!void { const entry = rt.mod.entry_points.addOne(allocator) catch return RuntimeError.OutOfMemory; entry.exec_model = try rt.it.nextAs(spv.SpvExecutionModel); @@ -1099,6 +1196,10 @@ fn opLabel(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { }; } +fn opKill(_: std.mem.Allocator, _: SpvWord, _: *Runtime) RuntimeError!void { + return RuntimeError.Killed; +} + fn opLoad(_: std.mem.Allocator, _: SpvWord, rt: *Runtime) RuntimeError!void { _ = rt.it.skip(); const id = try rt.it.next(); diff --git a/src/spv.zig b/src/spv.zig index 05ae91b..1e0b83e 100644 --- a/src/spv.zig +++ b/src/spv.zig @@ -2391,3 +2391,5 @@ pub const SpvOp = enum(u32) { ConvertHandleToSampledImageINTEL = 6531, Max = 0x7fffffff, }; + +pub const SpvOpMaxValue: comptime_int = @intFromEnum(SpvOp.ConvertHandleToSampledImageINTEL);