diff --git a/example/main.zig b/example/main.zig index 169a6cc..7f778e5 100644 --- a/example/main.zig +++ b/example/main.zig @@ -17,8 +17,8 @@ pub fn main() !void { defer rt.deinit(allocator); try rt.callEntryPoint(allocator, try rt.getEntryPointByName("main")); - var output: [4]f32 = undefined; - try rt.readOutput(f32, output[0..output.len], try rt.getResultByName("color")); + var output: [4]i32 = undefined; + try rt.readOutput(i32, output[0..output.len], try rt.getResultByName("color")); std.log.info("Output: Vec4{any}", .{output}); } std.log.info("Successfully executed", .{}); diff --git a/example/shader.nzsl b/example/shader.nzsl index ec65fc3..0058b30 100644 --- a/example/shader.nzsl +++ b/example/shader.nzsl @@ -4,18 +4,19 @@ module; struct FragOut { - [location(0)] color: vec4[f32] + [location(0)] color: vec4[i32] } -fn computeColor(val: f32) -> f32 +fn fibonacci(n: i32) -> i32 { - return 2.0 * val; + if (n <= i32(1)) return n; + return fibonacci(n - i32(1)) + fibonacci(n - i32(2)); } [entry(frag)] fn main() -> FragOut { let output: FragOut; - output.color = vec4[f32](computeColor(1.0), computeColor(2.0), computeColor(3.0), computeColor(4.0)); + output.color = vec4[i32](fibonacci(2), fibonacci(2), fibonacci(2), fibonacci(2)); return output; } diff --git a/example/shader.spv b/example/shader.spv index afda241..bc61751 100644 Binary files a/example/shader.spv and b/example/shader.spv differ diff --git a/example/shader.spv.txt b/example/shader.spv.txt index 59c5a0e..1a1ffb5 100644 --- a/example/shader.spv.txt +++ b/example/shader.spv.txt @@ -1,65 +1,82 @@ Version 1.0 Generator: 2560130 -Bound: 38 +Bound: 49 Schema: 0 OpCapability Capability(Shader) OpCapability Capability(Float64) OpMemoryModel AddressingModel(Logical) MemoryModel(GLSL450) - OpEntryPoint ExecutionModel(Fragment) %18 "main" %9 - OpExecutionMode %18 ExecutionMode(OriginUpperLeft) + OpEntryPoint ExecutionModel(Fragment) %16 "main" %11 + OpExecutionMode %16 ExecutionMode(OriginUpperLeft) OpSource SourceLanguage(NZSL) 4198400 OpSourceExtension "Version: 1.1" - OpName %10 "FragOut" - OpMemberName %10 0 "color" - OpName %9 "color" - OpName %17 "computeColor" - OpName %18 "main" - OpDecorate %9 Decoration(Location) 0 - OpMemberDecorate %10 0 Decoration(Offset) 0 - %1 = OpTypeFloat 32 + OpName %12 "FragOut" + OpMemberName %12 0 "color" + OpName %11 "color" + OpName %15 "fibonacci" + OpName %16 "main" + OpDecorate %11 Decoration(Location) 0 + OpMemberDecorate %12 0 Decoration(Offset) 0 + %1 = OpTypeInt 32 1 %2 = OpTypePointer StorageClass(Function) %1 %3 = OpTypeFunction %1 %2 - %4 = OpConstant %1 f32(2) - %5 = OpTypeVoid - %6 = OpTypeFunction %5 - %7 = OpTypeVector %1 4 - %8 = OpTypePointer StorageClass(Output) %7 -%10 = OpTypeStruct %7 -%11 = OpTypePointer StorageClass(Function) %10 -%12 = OpTypeInt 32 1 -%13 = OpConstant %12 i32(0) -%14 = OpConstant %1 f32(1) -%15 = OpConstant %1 f32(3) -%16 = OpConstant %1 f32(4) -%35 = OpTypePointer StorageClass(Function) %7 - %9 = OpVariable %8 StorageClass(Output) -%17 = OpFunction %1 FunctionControl(0) %3 -%19 = OpFunctionParameter %2 -%20 = OpLabel -%21 = OpLoad %1 %19 -%22 = OpFMul %1 %4 %21 - OpReturnValue %22 - OpFunctionEnd -%18 = OpFunction %5 FunctionControl(0) %6 + %4 = OpConstant %1 i32(1) + %5 = OpTypeBool + %6 = OpConstant %1 i32(2) + %7 = OpTypeVoid + %8 = OpTypeFunction %7 + %9 = OpTypeVector %1 4 +%10 = OpTypePointer StorageClass(Output) %9 +%12 = OpTypeStruct %9 +%13 = OpTypePointer StorageClass(Function) %12 +%14 = OpConstant %1 i32(0) +%46 = OpTypePointer StorageClass(Function) %9 +%11 = OpVariable %10 StorageClass(Output) +%15 = OpFunction %1 FunctionControl(0) %3 +%17 = OpFunctionParameter %2 +%18 = OpLabel +%19 = OpVariable %2 StorageClass(Function) +%20 = OpVariable %2 StorageClass(Function) +%24 = OpLoad %1 %17 +%25 = OpSLessThanEqual %5 %24 %4 + OpSelectionMerge %21 SelectionControl(0) + OpBranchConditional %25 %22 %23 +%22 = OpLabel +%26 = OpLoad %1 %17 + OpReturnValue %26 %23 = OpLabel -%24 = OpVariable %11 StorageClass(Function) -%25 = OpVariable %2 StorageClass(Function) -%26 = OpVariable %2 StorageClass(Function) -%27 = OpVariable %2 StorageClass(Function) -%28 = OpVariable %2 StorageClass(Function) - OpStore %25 %14 -%29 = OpFunctionCall %1 %17 %25 - OpStore %26 %4 -%30 = OpFunctionCall %1 %17 %26 - OpStore %27 %15 -%31 = OpFunctionCall %1 %17 %27 - OpStore %28 %16 -%32 = OpFunctionCall %1 %17 %28 -%33 = OpCompositeConstruct %7 %29 %30 %31 %32 -%34 = OpAccessChain %35 %24 %13 - OpStore %34 %33 -%36 = OpLoad %10 %24 -%37 = OpCompositeExtract %7 %36 0 - OpStore %9 %37 + OpBranch %21 +%21 = OpLabel +%27 = OpLoad %1 %17 +%28 = OpISub %1 %27 %4 + OpStore %19 %28 +%29 = OpFunctionCall %1 %15 %19 +%30 = OpLoad %1 %17 +%31 = OpISub %1 %30 %6 + OpStore %20 %31 +%32 = OpFunctionCall %1 %15 %20 +%33 = OpIAdd %1 %29 %32 + OpReturnValue %33 + OpFunctionEnd +%16 = OpFunction %7 FunctionControl(0) %8 +%34 = OpLabel +%35 = OpVariable %13 StorageClass(Function) +%36 = OpVariable %2 StorageClass(Function) +%37 = OpVariable %2 StorageClass(Function) +%38 = OpVariable %2 StorageClass(Function) +%39 = OpVariable %2 StorageClass(Function) + OpStore %36 %6 +%40 = OpFunctionCall %1 %15 %36 + OpStore %37 %6 +%41 = OpFunctionCall %1 %15 %37 + OpStore %38 %6 +%42 = OpFunctionCall %1 %15 %38 + OpStore %39 %6 +%43 = OpFunctionCall %1 %15 %39 +%44 = OpCompositeConstruct %9 %40 %41 %42 %43 +%45 = OpAccessChain %46 %35 %14 + OpStore %45 %44 +%47 = OpLoad %12 %35 +%48 = OpCompositeExtract %9 %47 0 + OpStore %11 %48 OpReturn OpFunctionEnd diff --git a/test/functions.zig b/test/functions.zig new file mode 100644 index 0000000..967d2c4 --- /dev/null +++ b/test/functions.zig @@ -0,0 +1,162 @@ +const std = @import("std"); +const root = @import("root.zig"); +const compileNzsl = root.compileNzsl; +const case = root.case; + +test "Simple function calls" { + const allocator = std.testing.allocator; + const types = [_]type{ i32, u32, f32, f64 }; + + inline for (types) |T| { + const n = case.random(T); + + const shader = try std.fmt.allocPrint( + allocator, + \\ [nzsl_version("1.1")] + \\ [feature(float64)] + \\ module; + \\ + \\ struct FragOut + \\ {{ + \\ [location(0)] color: vec4[{s}] + \\ }} + \\ + \\ fn value() -> {s} + \\ {{ + \\ return {d}; + \\ }} + \\ + \\ [entry(frag)] + \\ fn main() -> FragOut + \\ {{ + \\ let output: FragOut; + \\ output.color = vec4[{s}](value(), value(), value(), value()); + \\ return output; + \\ }} + , + .{ + @typeName(T), + @typeName(T), + n, + @typeName(T), + }, + ); + defer allocator.free(shader); + const code = try compileNzsl(allocator, shader); + defer allocator.free(code); + try case.expectOutput(T, 4, code, "color", &.{ n, n, n, n }); + } +} + +test "Nested function calls" { + const allocator = std.testing.allocator; + const types = [_]type{ i32, u32, f32, f64 }; + + inline for (types) |T| { + const n = case.random(T); + + const shader = try std.fmt.allocPrint( + allocator, + \\ [nzsl_version("1.1")] + \\ [feature(float64)] + \\ module; + \\ + \\ struct FragOut + \\ {{ + \\ [location(0)] color: vec4[{s}] + \\ }} + \\ + \\ fn deepValue() -> {s} + \\ {{ + \\ return {d}; + \\ }} + \\ + \\ fn value() -> {s} + \\ {{ + \\ return deepValue(); + \\ }} + \\ + \\ [entry(frag)] + \\ fn main() -> FragOut + \\ {{ + \\ let output: FragOut; + \\ output.color = vec4[{s}](value(), value(), value(), value()); + \\ return output; + \\ }} + , + .{ + @typeName(T), + @typeName(T), + n, + @typeName(T), + @typeName(T), + }, + ); + defer allocator.free(shader); + const code = try compileNzsl(allocator, shader); + defer allocator.free(code); + try case.expectOutput(T, 4, code, "color", &.{ n, n, n, n }); + } +} + +test "Recursive function calls" { + const allocator = std.testing.allocator; + const types = [_]type{ i32, u32, f32, f64 }; + + inline for (types) |T| { + const iterations = 10; + + const fib = struct { + fn onacci(n: T) T { + if (n <= 0) return n; + return onacci(n - 1) + onacci(n - 2); + } + }; + const expected = fib.onacci(iterations); + + const shader = try std.fmt.allocPrint( + allocator, + \\ [nzsl_version("1.1")] + \\ [feature(float64)] + \\ module; + \\ + \\ struct FragOut + \\ {{ + \\ [location(0)] color: vec4[{s}] + \\ }} + \\ + \\ fn fibonacci(n: {s}) -> {s} + \\ {{ + \\ if (n <= {s}(1)) return n; + \\ return fibonacci(n - {s}(1)) + fibonacci(n - {s}(2)); + \\ }} + \\ + \\ [entry(frag)] + \\ fn main() -> FragOut + \\ {{ + \\ let output: FragOut; + \\ output.color = vec4[{s}](fibonacci({d}), fibonacci({d}), fibonacci({d}), fibonacci({d})); + \\ return output; + \\ }} + , + .{ + @typeName(T), + @typeName(T), + @typeName(T), + @typeName(T), + @typeName(T), + @typeName(T), + @typeName(T), + iterations, + iterations, + iterations, + iterations, + }, + ); + defer allocator.free(shader); + std.debug.print("{s}\n\n", .{shader}); + const code = try compileNzsl(allocator, shader); + defer allocator.free(code); + try case.expectOutput(T, 4, code, "color", &.{ expected, expected, expected, expected }); + } +} diff --git a/test/root.zig b/test/root.zig index 6c7ec9c..becad8d 100644 --- a/test/root.zig +++ b/test/root.zig @@ -59,5 +59,6 @@ test { std.testing.refAllDecls(@import("basics.zig")); std.testing.refAllDecls(@import("branching.zig")); std.testing.refAllDecls(@import("casts.zig")); + std.testing.refAllDecls(@import("functions.zig")); std.testing.refAllDecls(@import("maths.zig")); }