diff --git a/sandbox/main.zig b/sandbox/main.zig index 388f295..a11df59 100644 --- a/sandbox/main.zig +++ b/sandbox/main.zig @@ -3,12 +3,8 @@ const spv = @import("spv"); const shader_source = @embedFile("shader.spv"); -const Input = struct { - value: [4]i32 = [4]i32{ 0, 0, 0, 0 }, -}; - -const Output = struct { - value: [4]i32 = [4]i32{ 0, 0, 0, 0 }, +const SSBO = struct { + value: [256]i32 = [_]i32{0} ** 256, }; pub fn main() !void { @@ -28,19 +24,31 @@ pub fn main() !void { const entry = try rt.getEntryPointByName("main"); - var input: Input = .{}; - var output: Output = .{}; + var ssbo: SSBO = .{}; - try rt.writeDescriptorSet(allocator, std.mem.asBytes(&input), 0, 0); - try rt.writeDescriptorSet(allocator, std.mem.asBytes(&output), 0, 1); + for (0..16) |i| { + for (0..16) |x| { + for (0..16) |y| { + const global_invocation_indices = [3]i32{ + @as(i32, @intCast(i * 16 + x)), + @as(i32, @intCast(y)), + 1, + }; - try rt.callEntryPoint(allocator, entry); + try rt.writeBuiltIn(std.mem.asBytes(&global_invocation_indices), .GlobalInvocationId); + try rt.writeDescriptorSet(allocator, std.mem.asBytes(&ssbo), 0, 0); + rt.callEntryPoint(allocator, entry) catch |err| switch (err) { + spv.Runtime.RuntimeError.OutOfBounds => continue, + else => return err, + }; + try rt.readDescriptorSet(std.mem.asBytes(&ssbo), 0, 0); + } + } + } - try rt.readDescriptorSet(std.mem.asBytes(&output), 0, 1); + std.log.info("Output: {any}", .{ssbo}); - std.log.info("Output: {any}", .{output}); - - std.log.info("\nTotal memory used: {d:.3} KB\n", .{@as(f32, @floatFromInt(gpa.total_requested_bytes)) / 1000.0}); + std.log.info("Total memory used: {d:.3} KB\n", .{@as(f32, @floatFromInt(gpa.total_requested_bytes)) / 1000.0}); } std.log.info("Successfully executed", .{}); } diff --git a/sandbox/shader.nzsl b/sandbox/shader.nzsl index 7bf3104..6dbc17b 100644 --- a/sandbox/shader.nzsl +++ b/sandbox/shader.nzsl @@ -1,68 +1,25 @@ -[sudo mkswap /swapfilenzsl_version("1.1")] +[nzsl_version("1.1")] module; 
-struct FragIn +struct Input { - [location(0)] time: f32, - [location(1)] res: vec2[f32], - [location(2)] pos: vec2[f32], + [builtin(global_invocation_indices)] indices: vec3[u32] } -struct FragOut +[layout(std430)] +struct SSBO { - [location(0)] color: vec4[f32] + data: dyn_array[i32] } -[entry(frag)] -fn main(input: FragIn) -> FragOut +external { - const I: i32 = 32; - const A: f32 = 7.5; - const MA: f32 = 2.0; - const MI: f32 = 0.001; - - let uv0 = input.pos / input.res * 2.0 - vec2[f32](1.0, 1.0); - let uv = vec2[f32](uv0.x * (input.res.x / input.res.y), uv0.y); - - let col = vec4[f32](0.0, 0.0, 0.0, 0.0); - let ro = vec4[f32](0.0, 0.0, -2.0, 0.0); - let rd = vec4[f32](uv.x, uv.y, 1.0, 0.0); - let dt = 0.0; - let ds = 0.0; - let dm = -1.0; - let p = ro; - let c = vec4[f32](0.0, 0.0, 0.0, 0.0); - - let l = vec4[f32](0.0, sin(input.time * 0.2) * 4.0, cos(input.time * 0.2) * 4.0, 0.0); - - for i in 0 -> I - { - p = ro + rd * dt; - ds = length(c - p) - 1.0; - dt += ds; - - if (dm == -1.0 || ds < dm) - dm = ds; - - if (ds <= MI) - { - let value = max(dot(normalize(c - p), normalize(p - l)), 0.0); - col = vec4[f32](value, value, value, 1.0); - break; - } - - if (ds >= MA) - { - if (dot(normalize(rd), normalize(l - ro)) < 1.0) - { - let value = max(dot(normalize(rd), normalize(l - ro)) + 0.15, 0.0) / 1.15 * max(1.0 - dm * A, 0.0); - col = vec4[f32](value, value, value, 1.0); - } - break; - } - } - - let output: FragOut; - output.color = col; - return output; + [set(0), binding(0)] ssbo: storage[SSBO], +} + +[entry(compute)] +[workgroup(16, 16, 1)] +fn main(input: Input) +{ + ssbo.data[input.indices.x * input.indices.y] = i32(input.indices.x * input.indices.y); } diff --git a/sandbox/shader.spv b/sandbox/shader.spv index d89916a..bc1bdc8 100644 Binary files a/sandbox/shader.spv and b/sandbox/shader.spv differ diff --git a/sandbox/shader.spv.txt b/sandbox/shader.spv.txt index a5ff195..784b4e1 100644 --- a/sandbox/shader.spv.txt +++ b/sandbox/shader.spv.txt @@ -1,83 
+1,68 @@ -OpCapability Shader -%1 = OpExtInstImport "GLSL.std.450" -OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %4 "main" %11 %20 -OpExecutionMode %4 LocalSize 1 1 1 -OpDecorate %11 BuiltIn NumWorkgroups -OpDecorate %20 BuiltIn WorkgroupId -OpMemberDecorate %37 0 Offset 0 -OpDecorate %38 ArrayStride 4 -OpDecorate %39 BufferBlock -OpMemberDecorate %39 0 Offset 0 -OpDecorate %41 Binding 0 -OpDecorate %41 DescriptorSet 0 -OpMemberDecorate %50 0 Offset 0 -OpDecorate %51 ArrayStride 4 -OpDecorate %52 BufferBlock -OpMemberDecorate %52 0 Offset 0 -OpDecorate %54 Binding 1 -OpDecorate %54 DescriptorSet 0 -OpDecorate %58 BuiltIn WorkgroupSize -%2 = OpTypeVoid -%3 = OpTypeFunction %2 -%6 = OpTypeInt 32 0 -%7 = OpTypePointer Function %6 -%9 = OpTypeVector %6 3 -%10 = OpTypePointer Input %9 -%11 = OpVariable %10 Input -%12 = OpConstant %6 0 -%13 = OpTypePointer Input %6 -%16 = OpConstant %6 1 -%20 = OpVariable %10 Input -%21 = OpConstant %6 2 -%34 = OpTypeInt 32 1 -%35 = OpTypePointer Function %34 -%37 = OpTypeStruct %34 -%38 = OpTypeRuntimeArray %37 -%39 = OpTypeStruct %38 -%40 = OpTypePointer Uniform %39 -%41 = OpVariable %40 Uniform -%42 = OpConstant %34 0 -%44 = OpTypePointer Uniform %34 -%50 = OpTypeStruct %34 -%51 = OpTypeRuntimeArray %50 -%52 = OpTypeStruct %51 -%53 = OpTypePointer Uniform %52 -%54 = OpVariable %53 Uniform -%58 = OpConstantComposite %9 %16 %16 %16 -%4 = OpFunction %2 None %3 -%5 = OpLabel -%8 = OpVariable %7 Function -%36 = OpVariable %35 Function -%47 = OpVariable %35 Function -%14 = OpAccessChain %13 %11 %12 -%15 = OpLoad %6 %14 -%17 = OpAccessChain %13 %11 %16 -%18 = OpLoad %6 %17 -%19 = OpIMul %6 %15 %18 -%22 = OpAccessChain %13 %20 %21 -%23 = OpLoad %6 %22 -%24 = OpIMul %6 %19 %23 -%25 = OpAccessChain %13 %11 %12 -%26 = OpLoad %6 %25 -%27 = OpAccessChain %13 %20 %16 -%28 = OpLoad %6 %27 -%29 = OpIMul %6 %26 %28 -%30 = OpIAdd %6 %24 %29 -%31 = OpAccessChain %13 %20 %12 -%32 = OpLoad %6 %31 -%33 = OpIAdd %6 %30 %32 -OpStore %8 %33 -%43 = 
OpLoad %6 %8 -%45 = OpAccessChain %44 %41 %42 %43 %42 -%46 = OpLoad %34 %45 -OpStore %36 %46 -%48 = OpLoad %34 %36 -%49 = OpExtInst %34 %1 SAbs %48 -OpStore %47 %49 -%55 = OpLoad %6 %8 -%56 = OpLoad %34 %47 -%57 = OpAccessChain %44 %54 %42 %55 %42 -OpStore %57 %56 -OpReturn -OpFunctionEnd +Version 1.0 +Generator: 2560130 +Bound: 41 +Schema: 0 + OpCapability Capability(Shader) + OpMemoryModel AddressingModel(Logical) MemoryModel(GLSL450) + OpEntryPoint ExecutionModel(GLCompute) %18 "main" %11 + OpExecutionMode %18 ExecutionMode(LocalSize) 16 16 1 + OpSource SourceLanguage(NZSL) 4198400 + OpSourceExtension "Version: 1.1" + OpName %3 "SSBO" + OpMemberName %3 0 "data" + OpName %14 "Input" + OpMemberName %14 0 "indices" + OpName %5 "ssbo" + OpName %11 "global_invocation_indices" + OpName %18 "main" + OpDecorate %5 Decoration(Binding) 0 + OpDecorate %5 Decoration(DescriptorSet) 0 + OpDecorate %11 Decoration(BuiltIn) BuiltIn(GlobalInvocationId) + OpDecorate %2 Decoration(ArrayStride) 4 + OpDecorate %3 Decoration(BufferBlock) + OpMemberDecorate %3 0 Decoration(Offset) 0 + OpMemberDecorate %14 0 Decoration(Offset) 0 + %1 = OpTypeInt 32 1 + %2 = OpTypeRuntimeArray %1 + %3 = OpTypeStruct %2 + %4 = OpTypePointer StorageClass(Uniform) %3 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 0 + %9 = OpTypeVector %8 3 +%10 = OpTypePointer StorageClass(Input) %9 +%12 = OpConstant %1 i32(0) +%13 = OpTypePointer StorageClass(Function) %9 +%14 = OpTypeStruct %9 +%15 = OpTypePointer StorageClass(Function) %14 +%16 = OpTypeRuntimeArray %1 +%17 = OpConstant %1 i32(1) +%31 = OpTypePointer StorageClass(Uniform) %2 +%40 = OpTypePointer StorageClass(Uniform) %1 + %5 = OpVariable %4 StorageClass(Uniform) +%11 = OpVariable %10 StorageClass(Input) +%18 = OpFunction %6 FunctionControl(0) %7 +%19 = OpLabel +%20 = OpVariable %15 StorageClass(Function) +%21 = OpAccessChain %13 %20 %12 + OpCopyMemory %21 %11 +%22 = OpAccessChain %13 %20 %12 +%23 = OpLoad %9 %22 +%24 = OpCompositeExtract 
%8 %23 0 +%25 = OpAccessChain %13 %20 %12 +%26 = OpLoad %9 %25 +%27 = OpCompositeExtract %8 %26 1 +%28 = OpIMul %8 %24 %27 +%29 = OpBitcast %1 %28 +%30 = OpAccessChain %31 %5 %12 +%32 = OpAccessChain %13 %20 %12 +%33 = OpLoad %9 %32 +%34 = OpCompositeExtract %8 %33 0 +%35 = OpAccessChain %13 %20 %12 +%36 = OpLoad %9 %35 +%37 = OpCompositeExtract %8 %36 1 +%38 = OpIMul %8 %34 %37 +%39 = OpAccessChain %40 %30 %38 + OpStore %39 %29 + OpReturn + OpFunctionEnd diff --git a/src/Module.zig b/src/Module.zig index 6f65b76..df18c6d 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -74,6 +74,7 @@ geometry_output: SpvWord, input_locations: [lib.SPIRV_MAX_INPUT_LOCATIONS]SpvWord, output_locations: [lib.SPIRV_MAX_OUTPUT_LOCATIONS]SpvWord, bindings: [lib.SPIRV_MAX_SET][lib.SPIRV_MAX_SET_BINDINGS]SpvWord, +builtins: std.EnumMap(spv.SpvBuiltIn, SpvWord), push_constants: []Value, pub fn init(allocator: std.mem.Allocator, source: []const SpvWord, options: ModuleOptions) ModuleError!Self { @@ -216,8 +217,14 @@ fn populateMaps(self: *Self) ModuleError!void { for (result.decorations.items) |decoration| { switch (result.variant.?.Variable.storage_class) { .Input => { - if (decoration.rtype == .Location) - self.input_locations[decoration.literal_1] = @intCast(id); + switch (decoration.rtype) { + .BuiltIn => self.builtins.put( + std.enums.fromInt(spv.SpvBuiltIn, decoration.literal_1) orelse return ModuleError.InvalidSpirV, + @intCast(id), + ), + .Location => self.input_locations[decoration.literal_1] = @intCast(id), + else => {}, + } }, .Output => { if (decoration.rtype == .Location) diff --git a/src/Result.zig b/src/Result.zig index b687f2b..4d01646 100644 --- a/src/Result.zig +++ b/src/Result.zig @@ -316,6 +316,29 @@ pub const TypeData = union(Type) { storage_class: spv.SpvStorageClass, target: SpvWord, }, + + pub fn getSize(self: *const TypeData, results: []const Self) usize { + return switch (self.*) { + .Bool => 1, + .Int => |i| @divExact(i.bit_length, 8), + .Float => |f| 
@divExact(f.bit_length, 8), + .Vector => |v| results[v.components_type_word].variant.?.Type.getSize(results), + .Array => |a| results[a.components_type_word].variant.?.Type.getSize(results), + .Matrix => |m| results[m.column_type_word].variant.?.Type.getSize(results), + .RuntimeArray => |a| results[a.components_type_word].variant.?.Type.getSize(results), + .Structure => |s| blk: { + var total: usize = 0; + for (s.members_type_word) |type_word| { + total += results[type_word].variant.?.Type.getSize(results); + } + break :blk total; + }, + .Vector4f32, .Vector4i32, .Vector4u32 => 4 * 4, + .Vector3f32, .Vector3i32, .Vector3u32 => 3 * 4, + .Vector2f32, .Vector2i32, .Vector2u32 => 2 * 4, + else => 0, + }; + } }; pub const VariantData = union(Variant) { diff --git a/src/Runtime.zig b/src/Runtime.zig index 319b594..23afe35 100644 --- a/src/Runtime.zig +++ b/src/Runtime.zig @@ -24,6 +24,7 @@ pub const RuntimeError = error{ Killed, NotFound, OutOfMemory, + OutOfBounds, ToDo, Unreachable, UnsupportedSpirV, @@ -167,8 +168,9 @@ pub fn writeDescriptorSet(self: *const Self, allocator: std.mem.Allocator, input const resolved = results[type_word].resolveType(results); switch (value.*) { - .RuntimeArray => { - value.* = try Result.initValue(allocator2, len, results, resolved); + .RuntimeArray => |a| if (a == null) { + const elem_size = resolved.variant.?.Type.getSize(results); + value.* = try Result.initValue(allocator2, @divExact(len, elem_size), results, resolved); }, .Structure => |*s| for (s.*, 0..) 
|*elem, i| { try @This().init(allocator2, len, elem, resolved.variant.?.Type.Structure.members_type_word[i], results); @@ -179,12 +181,12 @@ pub fn writeDescriptorSet(self: *const Self, allocator: std.mem.Allocator, input }; try helper.init(allocator, input.len, &variable.value, variable.type_word, self.results); - @import("pretty").print(allocator, variable, .{ - .tab_size = 4, - .max_depth = 0, - .struct_max_len = 0, - .array_max_len = 0, - }) catch return RuntimeError.OutOfMemory; + //@import("pretty").print(allocator, variable, .{ + // .tab_size = 4, + // .max_depth = 0, + // .struct_max_len = 0, + // .array_max_len = 0, + //}) catch return RuntimeError.OutOfMemory; _ = try self.writeValue(input, &variable.value); } else { return RuntimeError.NotFound; @@ -192,7 +194,7 @@ pub fn writeDescriptorSet(self: *const Self, allocator: std.mem.Allocator, input } pub fn readOutput(self: *const Self, output: []u8, result: SpvWord) RuntimeError!void { - if (std.mem.indexOf(SpvWord, &self.mod.output_locations, &.{result})) |_| { + if (std.mem.indexOfScalar(SpvWord, &self.mod.output_locations, result)) |_| { _ = try self.readValue(output, &self.results[result].variant.?.Variable.value); } else { return RuntimeError.NotFound; @@ -200,7 +202,15 @@ pub fn readOutput(self: *const Self, output: []u8, result: SpvWord) RuntimeError } pub fn writeInput(self: *const Self, input: []const u8, result: SpvWord) RuntimeError!void { - if (std.mem.indexOf(SpvWord, &self.mod.input_locations, &.{result})) |_| { + if (std.mem.indexOfScalar(SpvWord, &self.mod.input_locations, result)) |_| { + _ = try self.writeValue(input, &self.results[result].variant.?.Variable.value); + } else { + return RuntimeError.NotFound; + } +} + +pub fn writeBuiltIn(self: *const Self, input: []const u8, builtin: spv.SpvBuiltIn) RuntimeError!void { + if (self.mod.builtins.get(builtin)) |result| { _ = try self.writeValue(input, &self.results[result].variant.?.Variable.value); } else { return RuntimeError.NotFound; diff 
--git a/src/opcodes.zig b/src/opcodes.zig index c02f026..50aec47 100644 --- a/src/opcodes.zig +++ b/src/opcodes.zig @@ -1030,51 +1030,56 @@ fn opAccessChain(_: std.mem.Allocator, word_count: SpvWord, rt: *Runtime) Runtim .Variable => |v| &v.value, else => return RuntimeError.InvalidSpirV, }; + switch (member_value.*) { .Int => |i| { + if (std.meta.activeTag(value_ptr.*) == .Pointer) { + value_ptr = value_ptr.Pointer.common; // Don't know if I should check for specialized pointers + } + switch (value_ptr.*) { .Vector, .Matrix, .Array, .Structure => |v| { - if (i.value.uint32 > v.len) return RuntimeError.InvalidSpirV; + if (i.value.uint32 >= v.len) return RuntimeError.OutOfBounds; value_ptr = &v[i.value.uint32]; }, - .RuntimeArray => |opt_v| if (opt_v) |v| { - if (i.value.uint32 > v.len) return RuntimeError.InvalidSpirV; - value_ptr = &v[i.value.uint32]; + .RuntimeArray => |opt_a| if (opt_a) |a| { + if (i.value.uint32 >= a.len) return RuntimeError.OutOfBounds; + value_ptr = &a[i.value.uint32]; } else return RuntimeError.InvalidSpirV, .Vector4f32 => |*v| { - if (i.value.uint32 > 4) return RuntimeError.InvalidSpirV; + if (i.value.uint32 >= 4) return RuntimeError.OutOfBounds; break :blk .{ .Pointer = .{ .f32_ptr = &v[i.value.uint32] } }; }, .Vector3f32 => |*v| { - if (i.value.uint32 > 3) return RuntimeError.InvalidSpirV; + if (i.value.uint32 >= 3) return RuntimeError.OutOfBounds; break :blk .{ .Pointer = .{ .f32_ptr = &v[i.value.uint32] } }; }, .Vector2f32 => |*v| { - if (i.value.uint32 > 2) return RuntimeError.InvalidSpirV; + if (i.value.uint32 >= 2) return RuntimeError.OutOfBounds; break :blk .{ .Pointer = .{ .f32_ptr = &v[i.value.uint32] } }; }, .Vector4i32 => |*v| { - if (i.value.uint32 > 4) return RuntimeError.InvalidSpirV; + if (i.value.uint32 >= 4) return RuntimeError.OutOfBounds; break :blk .{ .Pointer = .{ .i32_ptr = &v[i.value.uint32] } }; }, .Vector3i32 => |*v| { - if (i.value.uint32 > 3) return RuntimeError.InvalidSpirV; + if (i.value.uint32 >= 3) return
RuntimeError.OutOfBounds; break :blk .{ .Pointer = .{ .i32_ptr = &v[i.value.uint32] } }; }, .Vector2i32 => |*v| { - if (i.value.uint32 > 2) return RuntimeError.InvalidSpirV; + if (i.value.uint32 >= 2) return RuntimeError.OutOfBounds; break :blk .{ .Pointer = .{ .i32_ptr = &v[i.value.uint32] } }; }, .Vector4u32 => |*v| { - if (i.value.uint32 > 4) return RuntimeError.InvalidSpirV; + if (i.value.uint32 >= 4) return RuntimeError.OutOfBounds; break :blk .{ .Pointer = .{ .u32_ptr = &v[i.value.uint32] } }; }, .Vector3u32 => |*v| { - if (i.value.uint32 > 3) return RuntimeError.InvalidSpirV; + if (i.value.uint32 >= 3) return RuntimeError.OutOfBounds; break :blk .{ .Pointer = .{ .u32_ptr = &v[i.value.uint32] } }; }, .Vector2u32 => |*v| { - if (i.value.uint32 > 2) return RuntimeError.InvalidSpirV; + if (i.value.uint32 >= 2) return RuntimeError.OutOfBounds; break :blk .{ .Pointer = .{ .u32_ptr = &v[i.value.uint32] } }; }, else => return RuntimeError.InvalidSpirV,