From 9dfe81770a8337a7a469eb3bac0ae9599cc0f61c Mon Sep 17 00:00:00 2001 From: gdkchan Date: Thu, 29 Dec 2022 12:09:34 -0300 Subject: [PATCH] Use vector outputs for texture operations (#3939) * Change AggregateType to include vector type counts * Replace VariableType uses with AggregateType and delete VariableType * Support new local vector types on SPIR-V and GLSL * Start using vector outputs for texture operations * Use vectors on more texture operations * Use vector output for ImageLoad operations * Replace all uses of single destination texture constructors with multi destination ones * Update textureGatherOffsets replacement to split vector operations * Shader cache version bump Co-authored-by: Ac_K --- .../Shader/DiskCache/DiskCacheHostStorage.cs | 2 +- Ryujinx.Graphics.Shader/AttributeType.cs | 12 - .../CodeGen/Glsl/Declarations.cs | 36 ++- .../CodeGen/Glsl/GlslGenerator.cs | 8 +- .../CodeGen/Glsl/Instructions/InstGen.cs | 11 +- .../Glsl/Instructions/InstGenBallot.cs | 3 +- .../Glsl/Instructions/InstGenHelper.cs | 14 +- .../Glsl/Instructions/InstGenMemory.cs | 126 +++++--- .../Glsl/Instructions/InstGenVector.cs | 32 ++ .../CodeGen/Glsl/NumberFormatter.cs | 18 +- .../CodeGen/Glsl/OperandManager.cs | 139 +++++---- .../CodeGen/Glsl/TypeConversion.cs | 33 +- .../CodeGen/Spirv/CodeGenContext.cs | 54 +++- .../CodeGen/Spirv/Declarations.cs | 19 +- .../CodeGen/Spirv/EnumConversion.cs | 17 +- .../CodeGen/Spirv/Instructions.cs | 147 +++++++-- .../CodeGen/Spirv/SpirvGenerator.cs | 8 +- .../Instructions/InstEmitSurface.cs | 116 +++---- .../Instructions/InstEmitTexture.cs | 197 ++++++------ .../IntermediateRepresentation/Instruction.cs | 1 + .../IntermediateRepresentation/Operation.cs | 21 +- .../TextureOperation.cs | 8 +- Ryujinx.Graphics.Shader/SamplerType.cs | 8 +- .../StructuredIr/AstHelper.cs | 3 +- .../StructuredIr/AstOperand.cs | 5 +- .../StructuredIr/AstOperation.cs | 18 ++ .../StructuredIr/InstructionInfo.cs | 272 ++++++++-------- .../StructuredIr/OperandInfo.cs | 17 +- .../StructuredIr/StructuredFunction.cs | 15 +- .../StructuredIr/StructuredProgram.cs | 84 ++++- .../StructuredIr/StructuredProgramContext.cs | 10 +- .../StructuredIr/VariableType.cs | 14 - Ryujinx.Graphics.Shader/TextureFormat.cs | 10 +- .../Translation/AggregateType.cs | 11 +- .../Translation/AttributeInfo.cs | 60 ++-- .../Translation/EmitterContext.cs | 8 +- .../Translation/Rewriter.cs | 290 +++++++++++------- 37 files changed, 1100 insertions(+), 747 deletions(-) create mode 100644 Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenVector.cs delete mode 100644 Ryujinx.Graphics.Shader/StructuredIr/VariableType.cs diff --git a/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs b/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs index bdc6f4d6..2622ea3e 100644 --- a/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs +++ b/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs @@ -22,7 +22,7 @@ namespace Ryujinx.Graphics.Gpu.Shader.DiskCache private const ushort FileFormatVersionMajor = 1; private const ushort FileFormatVersionMinor = 2; private const uint FileFormatVersionPacked = ((uint)FileFormatVersionMajor << 16) | FileFormatVersionMinor; - private const uint CodeGenVersion = 4106; + private const uint CodeGenVersion = 3939; private const string SharedTocFileName = "shared.toc"; private const string SharedDataFileName = "shared.data"; diff --git a/Ryujinx.Graphics.Shader/AttributeType.cs b/Ryujinx.Graphics.Shader/AttributeType.cs index 1ede1560..4e6cad59 100644 --- a/Ryujinx.Graphics.Shader/AttributeType.cs +++ b/Ryujinx.Graphics.Shader/AttributeType.cs @@ -1,4 +1,3 @@ -using Ryujinx.Graphics.Shader.StructuredIr; using Ryujinx.Graphics.Shader.Translation; using System; @@ -25,17 +24,6 @@ namespace Ryujinx.Graphics.Shader }; } - public static VariableType ToVariableType(this AttributeType type) - { - return type switch - { - AttributeType.Float => VariableType.F32, - AttributeType.Sint => VariableType.S32, - AttributeType.Uint => VariableType.U32, - _ => throw new ArgumentException($"Invalid attribute type \"{type}\".") - }; - } - public static AggregateType ToAggregateType(this AttributeType type) { return type switch diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs index c6e3b339..4da21cb7 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs @@ -350,19 +350,33 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl } } - public static string GetVarTypeName(VariableType type) + public static string GetVarTypeName(AggregateType type, bool precise = true) { - switch (type) + return type switch { - case VariableType.Bool: return "bool"; - case VariableType.F32: return "precise float"; - case VariableType.F64: return "double"; - case VariableType.None: return "void"; - case VariableType.S32: return "int"; - case VariableType.U32: return "uint"; - } - - throw new ArgumentException($"Invalid variable type \"{type}\"."); + AggregateType.Void => "void", + AggregateType.Bool => "bool", + AggregateType.FP32 => precise ? "precise float" : "float", + AggregateType.FP64 => "double", + AggregateType.S32 => "int", + AggregateType.U32 => "uint", + AggregateType.Vector2 | AggregateType.Bool => "bvec2", + AggregateType.Vector2 | AggregateType.FP32 => precise ? "precise vec2" : "vec2", + AggregateType.Vector2 | AggregateType.FP64 => "dvec2", + AggregateType.Vector2 | AggregateType.S32 => "ivec2", + AggregateType.Vector2 | AggregateType.U32 => "uvec2", + AggregateType.Vector3 | AggregateType.Bool => "bvec3", + AggregateType.Vector3 | AggregateType.FP32 => precise ? "precise vec3" : "vec3", + AggregateType.Vector3 | AggregateType.FP64 => "dvec3", + AggregateType.Vector3 | AggregateType.S32 => "ivec3", + AggregateType.Vector3 | AggregateType.U32 => "uvec3", + AggregateType.Vector4 | AggregateType.Bool => "bvec4", + AggregateType.Vector4 | AggregateType.FP32 => precise ? "precise vec4" : "vec4", + AggregateType.Vector4 | AggregateType.FP64 => "dvec4", + AggregateType.Vector4 | AggregateType.S32 => "ivec4", + AggregateType.Vector4 | AggregateType.U32 => "uvec4", + _ => throw new ArgumentException($"Invalid variable type \"{type}\".") + }; } private static void DeclareUniforms(CodeGenContext context, BufferDescriptor[] descriptors) diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs index e1b8eb6e..90826a15 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs @@ -126,8 +126,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl } else if (node is AstAssignment assignment) { - VariableType srcType = OperandManager.GetNodeDestType(context, assignment.Source); - VariableType dstType = OperandManager.GetNodeDestType(context, assignment.Destination, isAsgDest: true); + AggregateType srcType = OperandManager.GetNodeDestType(context, assignment.Source); + AggregateType dstType = OperandManager.GetNodeDestType(context, assignment.Destination, isAsgDest: true); string dest; @@ -158,9 +158,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl private static string GetCondExpr(CodeGenContext context, IAstNode cond) { - VariableType srcType = OperandManager.GetNodeDestType(context, cond); + AggregateType srcType = OperandManager.GetNodeDestType(context, cond); - return ReinterpretCast(context, cond, srcType, VariableType.Bool); + return ReinterpretCast(context, cond, srcType, AggregateType.Bool); } } } \ No newline at end of file diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs index b890b015..9ca4618d 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs @@ -1,5 +1,6 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.StructuredIr; +using Ryujinx.Graphics.Shader.Translation; using System; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenBallot; @@ -8,6 +9,7 @@ using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenFSI; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenHelper; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenMemory; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenPacking; +using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenVector; using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo; namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions @@ -32,12 +34,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions { IAstNode src = operation.GetSource(0); - VariableType type = GetSrcVarType(operation.Inst, 0); + AggregateType type = GetSrcVarType(operation.Inst, 0); string srcExpr = GetSoureExpr(context, src, type); string zero; - if (type == VariableType.F64) + if (type == AggregateType.FP64) { zero = "0.0"; } @@ -95,7 +97,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions } else { - VariableType dstType = GetSrcVarType(inst, argIndex); + AggregateType dstType = GetSrcVarType(inst, argIndex); args += GetSoureExpr(context, operation.GetSource(argIndex), dstType); } @@ -226,6 +228,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions case Instruction.UnpackHalf2x16: return UnpackHalf2x16(context, operation); + + case Instruction.VectorExtract: + return VectorExtract(context, operation); } } diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenBallot.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenBallot.cs index 51e7bd21..68793c5d 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenBallot.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenBallot.cs @@ -1,4 +1,5 @@ using Ryujinx.Graphics.Shader.StructuredIr; +using Ryujinx.Graphics.Shader.Translation; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenHelper; using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo; @@ -9,7 +10,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions { public static string Ballot(CodeGenContext context, AstOperation operation) { - VariableType dstType = GetSrcVarType(operation.Inst, 0); + AggregateType dstType = GetSrcVarType(operation.Inst, 0); string arg = GetSoureExpr(context, operation.GetSource(0), dstType); diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs index c40f96f1..743b695c 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs @@ -1,5 +1,6 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.StructuredIr; +using Ryujinx.Graphics.Shader.Translation; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.TypeConversion; @@ -7,11 +8,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions { static class InstGenHelper { - private static readonly InstInfo[] InfoTable; + private static readonly InstInfo[] _infoTable; static InstGenHelper() { - InfoTable = new InstInfo[(int)Instruction.Count]; + _infoTable = new InstInfo[(int)Instruction.Count]; Add(Instruction.AtomicAdd, InstType.AtomicBinary, "atomicAdd"); Add(Instruction.AtomicAnd, InstType.AtomicBinary, "atomicAnd"); @@ -132,6 +133,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions Add(Instruction.Truncate, InstType.CallUnary, "trunc"); Add(Instruction.UnpackDouble2x32, InstType.Special); Add(Instruction.UnpackHalf2x16, InstType.Special); + Add(Instruction.VectorExtract, InstType.Special); Add(Instruction.VoteAll, InstType.CallUnary, "allInvocationsARB"); Add(Instruction.VoteAllEqual, InstType.CallUnary, "allInvocationsEqualARB"); Add(Instruction.VoteAny, InstType.CallUnary, "anyInvocationARB"); @@ -139,15 +141,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions private static void Add(Instruction inst, InstType flags, string opName = null, int precedence = 0) { - InfoTable[(int)inst] = new InstInfo(flags, opName, precedence); + _infoTable[(int)inst] = new InstInfo(flags, opName, precedence); } public static InstInfo GetInstructionInfo(Instruction inst) { - return InfoTable[(int)(inst & Instruction.Mask)]; + return _infoTable[(int)(inst & Instruction.Mask)]; } - public static string GetSoureExpr(CodeGenContext context, IAstNode node, VariableType dstType) + public static string GetSoureExpr(CodeGenContext context, IAstNode node, AggregateType dstType) { return ReinterpretCast(context, node, OperandManager.GetNodeDestType(context, node), dstType); } @@ -191,7 +193,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions return false; } - InstInfo info = InfoTable[(int)(operation.Inst & Instruction.Mask)]; + InstInfo info = _infoTable[(int)(operation.Inst & Instruction.Mask)]; if ((info.Type & (InstType.Call | InstType.Special)) != 0) { diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs index 022e3a44..f667d080 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs @@ -1,5 +1,6 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.StructuredIr; +using Ryujinx.Graphics.Shader.Translation; using System; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenHelper; @@ -23,7 +24,17 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions case Instruction.ImageStore: return "// imageStore(bindless)"; case Instruction.ImageLoad: - NumberFormatter.TryFormat(0, texOp.Format.GetComponentType(), out string imageConst); + AggregateType componentType = texOp.Format.GetComponentType(); + + NumberFormatter.TryFormat(0, componentType, out string imageConst); + + AggregateType outputType = texOp.GetVectorType(componentType); + + if ((outputType & AggregateType.ElementCountMask) != 0) + { + return $"{Declarations.GetVarTypeName(outputType, precise: false)}({imageConst})"; + } + return imageConst; default: return NumberFormatter.FormatInt(0); @@ -58,7 +69,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions int srcIndex = isBindless ? 1 : 0; - string Src(VariableType type) + string Src(AggregateType type) { return GetSoureExpr(context, texOp.GetSource(srcIndex++), type); } @@ -67,7 +78,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions if (isIndexed) { - indexExpr = Src(VariableType.S32); + indexExpr = Src(AggregateType.S32); } string imageName = OperandManager.GetImageName(context.Config.Stage, texOp, indexExpr); @@ -113,19 +124,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions for (int index = 0; index < pCount; index++) { - elems[index] = Src(VariableType.S32); + elems[index] = Src(AggregateType.S32); } Append(ApplyScaling("ivec" + pCount + "(" + string.Join(", ", elems) + ")")); } else { - Append(Src(VariableType.S32)); + Append(Src(AggregateType.S32)); } if (texOp.Inst == Instruction.ImageStore) { - VariableType type = texOp.Format.GetComponentType(); + AggregateType type = texOp.Format.GetComponentType(); string[] cElems = new string[4]; @@ -139,8 +150,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions { cElems[index] = type switch { - VariableType.S32 => NumberFormatter.FormatInt(0), - VariableType.U32 => NumberFormatter.FormatUint(0), + AggregateType.S32 => NumberFormatter.FormatInt(0), + AggregateType.U32 => NumberFormatter.FormatUint(0), _ => NumberFormatter.FormatFloat(0) }; } @@ -148,8 +159,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions string prefix = type switch { - VariableType.S32 => "i", - VariableType.U32 => "u", + AggregateType.S32 => "i", + AggregateType.U32 => "u", _ => string.Empty }; @@ -158,7 +169,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions if (texOp.Inst == Instruction.ImageAtomic) { - VariableType type = texOp.Format.GetComponentType(); + AggregateType type = texOp.Format.GetComponentType(); if ((texOp.Flags & TextureFlags.AtomicMask) == TextureFlags.CAS) { @@ -176,14 +187,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions texCall += ")"; - if (type != VariableType.S32) + if (type != AggregateType.S32) { texCall = "int(" + texCall + ")"; } } else { - texCall += ")" + (texOp.Inst == Instruction.ImageLoad ? GetMask(texOp.Index) : ""); + texCall += ")" + (texOp.Inst == Instruction.ImageLoad ? GetMaskMultiDest(texOp.Index) : ""); } return texCall; @@ -288,7 +299,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions if (isIndexed) { - indexExpr = GetSoureExpr(context, texOp.GetSource(0), VariableType.S32); + indexExpr = GetSoureExpr(context, texOp.GetSource(0), AggregateType.S32); } string samplerName = OperandManager.GetSamplerName(context.Config.Stage, texOp, indexExpr); @@ -303,14 +314,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions for (int index = 0; index < coordsCount; index++) { - elems[index] = GetSoureExpr(context, texOp.GetSource(coordsIndex + index), VariableType.F32); + elems[index] = GetSoureExpr(context, texOp.GetSource(coordsIndex + index), AggregateType.FP32); } coordsExpr = "vec" + coordsCount + "(" + string.Join(", ", elems) + ")"; } else { - coordsExpr = GetSoureExpr(context, texOp.GetSource(coordsIndex), VariableType.F32); + coordsExpr = GetSoureExpr(context, texOp.GetSource(coordsIndex), AggregateType.FP32); } return $"textureQueryLod({samplerName}, {coordsExpr}){GetMask(texOp.Index)}"; @@ -362,9 +373,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions string offsetExpr = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 0)); - VariableType srcType = OperandManager.GetNodeDestType(context, src2); + AggregateType srcType = OperandManager.GetNodeDestType(context, src2); - string src = TypeConversion.ReinterpretCast(context, src2, srcType, VariableType.U32); + string src = TypeConversion.ReinterpretCast(context, src2, srcType, AggregateType.U32); return $"{arrayName}[{offsetExpr}] = {src}"; } @@ -376,9 +387,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions string offsetExpr = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 0)); - VariableType srcType = OperandManager.GetNodeDestType(context, src2); + AggregateType srcType = OperandManager.GetNodeDestType(context, src2); - string src = TypeConversion.ReinterpretCast(context, src2, srcType, VariableType.U32); + string src = TypeConversion.ReinterpretCast(context, src2, srcType, AggregateType.U32); return $"{HelperFunctionNames.StoreShared16}({offsetExpr}, {src})"; } @@ -390,9 +401,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions string offsetExpr = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 0)); - VariableType srcType = OperandManager.GetNodeDestType(context, src2); + AggregateType srcType = OperandManager.GetNodeDestType(context, src2); - string src = TypeConversion.ReinterpretCast(context, src2, srcType, VariableType.U32); + string src = TypeConversion.ReinterpretCast(context, src2, srcType, AggregateType.U32); return $"{HelperFunctionNames.StoreShared8}({offsetExpr}, {src})"; } @@ -406,9 +417,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions string indexExpr = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 0)); string offsetExpr = GetSoureExpr(context, src2, GetSrcVarType(operation.Inst, 1)); - VariableType srcType = OperandManager.GetNodeDestType(context, src3); + AggregateType srcType = OperandManager.GetNodeDestType(context, src3); - string src = TypeConversion.ReinterpretCast(context, src3, srcType, VariableType.U32); + string src = TypeConversion.ReinterpretCast(context, src3, srcType, AggregateType.U32); string sb = GetStorageBufferAccessor(indexExpr, offsetExpr, context.Config.Stage); @@ -424,9 +435,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions string indexExpr = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 0)); string offsetExpr = GetSoureExpr(context, src2, GetSrcVarType(operation.Inst, 1)); - VariableType srcType = OperandManager.GetNodeDestType(context, src3); + AggregateType srcType = OperandManager.GetNodeDestType(context, src3); - string src = TypeConversion.ReinterpretCast(context, src3, srcType, VariableType.U32); + string src = TypeConversion.ReinterpretCast(context, src3, srcType, AggregateType.U32); string sb = GetStorageBufferAccessor(indexExpr, offsetExpr, context.Config.Stage); @@ -442,9 +453,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions string indexExpr = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 0)); string offsetExpr = GetSoureExpr(context, src2, GetSrcVarType(operation.Inst, 1)); - VariableType srcType = OperandManager.GetNodeDestType(context, src3); + AggregateType srcType = OperandManager.GetNodeDestType(context, src3); - string src = TypeConversion.ReinterpretCast(context, src3, srcType, VariableType.U32); + string src = TypeConversion.ReinterpretCast(context, src3, srcType, AggregateType.U32); string sb = GetStorageBufferAccessor(indexExpr, offsetExpr, context.Config.Stage); @@ -469,6 +480,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions bool isMultisample = (texOp.Type & SamplerType.Multisample) != 0; bool isShadow = (texOp.Type & SamplerType.Shadow) != 0; + bool colorIsVector = isGather || !isShadow; + SamplerType type = texOp.Type & SamplerType.Mask; bool is2D = type == SamplerType.Texture2D; @@ -492,7 +505,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions // TODO: Bindless texture support. For now we just return 0. if (isBindless) { - return NumberFormatter.FormatFloat(0); + string scalarValue = NumberFormatter.FormatFloat(0); + + if (colorIsVector) + { + AggregateType outputType = texOp.GetVectorType(AggregateType.FP32); + + if ((outputType & AggregateType.ElementCountMask) != 0) + { + return $"{Declarations.GetVarTypeName(outputType, precise: false)}({scalarValue})"; + } + } + + return scalarValue; } string texCall = intCoords ? "texelFetch" : "texture"; @@ -521,7 +546,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions int srcIndex = isBindless ? 1 : 0; - string Src(VariableType type) + string Src(AggregateType type) { return GetSoureExpr(context, texOp.GetSource(srcIndex++), type); } @@ -530,7 +555,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions if (isIndexed) { - indexExpr = Src(VariableType.S32); + indexExpr = Src(AggregateType.S32); } string samplerName = OperandManager.GetSamplerName(context.Config.Stage, texOp, indexExpr); @@ -578,7 +603,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions texCall += ", " + str; } - VariableType coordType = intCoords ? VariableType.S32 : VariableType.F32; + AggregateType coordType = intCoords ? AggregateType.S32 : AggregateType.FP32; string AssemblePVector(int count) { @@ -590,7 +615,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions { if (arrayIndexElem == index) { - elems[index] = Src(VariableType.S32); + elems[index] = Src(AggregateType.S32); if (!intCoords) { @@ -652,20 +677,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions for (int index = 0; index < count; index++) { - elems[index] = Src(VariableType.F32); + elems[index] = Src(AggregateType.FP32); } return "vec" + count + "(" + string.Join(", ", elems) + ")"; } else { - return Src(VariableType.F32); + return Src(AggregateType.FP32); } } if (hasExtraCompareArg) { - Append(Src(VariableType.F32)); + Append(Src(AggregateType.FP32)); } if (hasDerivatives) @@ -676,7 +701,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions if (isMultisample) { - Append(Src(VariableType.S32)); + Append(Src(AggregateType.S32)); } else if (hasLodLevel) { @@ -691,14 +716,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions for (int index = 0; index < count; index++) { - elems[index] = Src(VariableType.S32); + elems[index] = Src(AggregateType.S32); } return "ivec" + count + "(" + string.Join(", ", elems) + ")"; } else { - return Src(VariableType.S32); + return Src(AggregateType.S32); } } @@ -718,17 +743,17 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions if (hasLodBias) { - Append(Src(VariableType.F32)); + Append(Src(AggregateType.FP32)); } // textureGather* optional extra component index, // not needed for shadow samplers. if (isGather && !isShadow) { - Append(Src(VariableType.S32)); + Append(Src(AggregateType.S32)); } - texCall += ")" + (isGather || !isShadow ? GetMask(texOp.Index) : ""); + texCall += ")" + (colorIsVector ? GetMaskMultiDest(texOp.Index) : ""); return texCall; } @@ -751,7 +776,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions if (isIndexed) { - indexExpr = GetSoureExpr(context, texOp.GetSource(0), VariableType.S32); + indexExpr = GetSoureExpr(context, texOp.GetSource(0), AggregateType.S32); } string samplerName = OperandManager.GetSamplerName(context.Config.Stage, texOp, indexExpr); @@ -804,5 +829,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions { return '.' + "rgba".Substring(index, 1); } + + private static string GetMaskMultiDest(int mask) + { + string swizzle = "."; + + for (int i = 0; i < 4; i++) + { + if ((mask & (1 << i)) != 0) + { + swizzle += "xyzw"[i]; + } + } + + return swizzle; + } } } \ No newline at end of file diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenVector.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenVector.cs new file mode 100644 index 00000000..f09ea2e8 --- /dev/null +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenVector.cs @@ -0,0 +1,32 @@ +using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using Ryujinx.Graphics.Shader.StructuredIr; + +using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenHelper; +using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo; + +namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions +{ + static class InstGenVector + { + public static string VectorExtract(CodeGenContext context, AstOperation operation) + { + IAstNode vector = operation.GetSource(0); + IAstNode index = operation.GetSource(1); + + string vectorExpr = GetSoureExpr(context, vector, OperandManager.GetNodeDestType(context, vector)); + + if (index is AstOperand indexOperand && indexOperand.Type == OperandType.Constant) + { + char elem = "xyzw"[indexOperand.Value]; + + return $"{vectorExpr}.{elem}"; + } + else + { + string indexExpr = GetSoureExpr(context, index, GetSrcVarType(operation.Inst, 1)); + + return $"{vectorExpr}[{indexExpr}]"; + } + } + } +} \ No newline at end of file diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/NumberFormatter.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/NumberFormatter.cs index 2ec44277..eb27e9bf 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/NumberFormatter.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/NumberFormatter.cs @@ -1,4 +1,4 @@ -using Ryujinx.Graphics.Shader.StructuredIr; +using Ryujinx.Graphics.Shader.Translation; using System; using System.Globalization; @@ -8,21 +8,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl { private const int MaxDecimal = 256; - public static bool TryFormat(int value, VariableType dstType, out string formatted) + public static bool TryFormat(int value, AggregateType dstType, out string formatted) { - if (dstType == VariableType.F32) + if (dstType == AggregateType.FP32) { return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted); } - else if (dstType == VariableType.S32) + else if (dstType == AggregateType.S32) { formatted = FormatInt(value); } - else if (dstType == VariableType.U32) + else if (dstType == AggregateType.U32) { formatted = FormatUint((uint)value); } - else if (dstType == VariableType.Bool) + else if (dstType == AggregateType.Bool) { formatted = value != 0 ? "true" : "false"; } @@ -65,13 +65,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl return true; } - public static string FormatInt(int value, VariableType dstType) + public static string FormatInt(int value, AggregateType dstType) { - if (dstType == VariableType.S32) + if (dstType == AggregateType.S32) { return FormatInt(value); } - else if (dstType == VariableType.U32) + else if (dstType == AggregateType.U32) { return FormatUint((uint)value); } diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs index ccc87a7f..080f1708 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs @@ -4,6 +4,7 @@ using Ryujinx.Graphics.Shader.Translation; using System; using System.Collections.Generic; using System.Diagnostics; +using System.Numerics; using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo; @@ -17,9 +18,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl { public string Name { get; } - public VariableType Type { get; } + public AggregateType Type { get; } - public BuiltInAttribute(string name, VariableType type) + public BuiltInAttribute(string name, AggregateType type) { Name = name; Type = type; @@ -28,64 +29,64 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl private static Dictionary _builtInAttributes = new Dictionary() { - { AttributeConsts.Layer, new BuiltInAttribute("gl_Layer", VariableType.S32) }, - { AttributeConsts.PointSize, new BuiltInAttribute("gl_PointSize", VariableType.F32) }, - { AttributeConsts.PositionX, new BuiltInAttribute("gl_Position.x", VariableType.F32) }, - { AttributeConsts.PositionY, new BuiltInAttribute("gl_Position.y", VariableType.F32) }, - { AttributeConsts.PositionZ, new BuiltInAttribute("gl_Position.z", VariableType.F32) }, - { AttributeConsts.PositionW, new BuiltInAttribute("gl_Position.w", VariableType.F32) }, - { AttributeConsts.ClipDistance0, new BuiltInAttribute("gl_ClipDistance[0]", VariableType.F32) }, - { AttributeConsts.ClipDistance1, new BuiltInAttribute("gl_ClipDistance[1]", VariableType.F32) }, - { AttributeConsts.ClipDistance2, new BuiltInAttribute("gl_ClipDistance[2]", VariableType.F32) }, - { AttributeConsts.ClipDistance3, new BuiltInAttribute("gl_ClipDistance[3]", VariableType.F32) }, - { AttributeConsts.ClipDistance4, new BuiltInAttribute("gl_ClipDistance[4]", VariableType.F32) }, - { AttributeConsts.ClipDistance5, new BuiltInAttribute("gl_ClipDistance[5]", VariableType.F32) }, - { AttributeConsts.ClipDistance6, new BuiltInAttribute("gl_ClipDistance[6]", VariableType.F32) }, - { AttributeConsts.ClipDistance7, new BuiltInAttribute("gl_ClipDistance[7]", VariableType.F32) }, - { AttributeConsts.PointCoordX, new BuiltInAttribute("gl_PointCoord.x", VariableType.F32) }, - { AttributeConsts.PointCoordY, new BuiltInAttribute("gl_PointCoord.y", VariableType.F32) }, - { AttributeConsts.TessCoordX, new BuiltInAttribute("gl_TessCoord.x", VariableType.F32) }, - { AttributeConsts.TessCoordY, new BuiltInAttribute("gl_TessCoord.y", VariableType.F32) }, - { AttributeConsts.InstanceId, new BuiltInAttribute("gl_InstanceID", VariableType.S32) }, - { AttributeConsts.VertexId, new BuiltInAttribute("gl_VertexID", VariableType.S32) }, - { AttributeConsts.BaseInstance, new BuiltInAttribute("gl_BaseInstanceARB", VariableType.S32) }, - { AttributeConsts.BaseVertex, new BuiltInAttribute("gl_BaseVertexARB", VariableType.S32) }, - { AttributeConsts.InstanceIndex, new BuiltInAttribute("gl_InstanceIndex", VariableType.S32) }, - { AttributeConsts.VertexIndex, new BuiltInAttribute("gl_VertexIndex", VariableType.S32) }, - { AttributeConsts.DrawIndex, new BuiltInAttribute("gl_DrawIDARB", VariableType.S32) }, - { AttributeConsts.FrontFacing, new BuiltInAttribute("gl_FrontFacing", VariableType.Bool) }, + { AttributeConsts.Layer, new BuiltInAttribute("gl_Layer", AggregateType.S32) }, + { AttributeConsts.PointSize, new BuiltInAttribute("gl_PointSize", AggregateType.FP32) }, + { AttributeConsts.PositionX, new BuiltInAttribute("gl_Position.x", AggregateType.FP32) }, + { AttributeConsts.PositionY, new BuiltInAttribute("gl_Position.y", AggregateType.FP32) }, + { AttributeConsts.PositionZ, new BuiltInAttribute("gl_Position.z", AggregateType.FP32) }, + { AttributeConsts.PositionW, new BuiltInAttribute("gl_Position.w", AggregateType.FP32) }, + { AttributeConsts.ClipDistance0, new BuiltInAttribute("gl_ClipDistance[0]", AggregateType.FP32) }, + { AttributeConsts.ClipDistance1, new BuiltInAttribute("gl_ClipDistance[1]", AggregateType.FP32) }, + { AttributeConsts.ClipDistance2, new BuiltInAttribute("gl_ClipDistance[2]", AggregateType.FP32) }, + { AttributeConsts.ClipDistance3, new BuiltInAttribute("gl_ClipDistance[3]", AggregateType.FP32) }, + { AttributeConsts.ClipDistance4, new BuiltInAttribute("gl_ClipDistance[4]", AggregateType.FP32) }, + { AttributeConsts.ClipDistance5, new BuiltInAttribute("gl_ClipDistance[5]", AggregateType.FP32) }, + { AttributeConsts.ClipDistance6, new BuiltInAttribute("gl_ClipDistance[6]", AggregateType.FP32) }, + { AttributeConsts.ClipDistance7, new BuiltInAttribute("gl_ClipDistance[7]", AggregateType.FP32) }, + { AttributeConsts.PointCoordX, new BuiltInAttribute("gl_PointCoord.x", AggregateType.FP32) }, + { AttributeConsts.PointCoordY, new BuiltInAttribute("gl_PointCoord.y", AggregateType.FP32) }, + { AttributeConsts.TessCoordX, new BuiltInAttribute("gl_TessCoord.x", AggregateType.FP32) }, + { AttributeConsts.TessCoordY, new BuiltInAttribute("gl_TessCoord.y", AggregateType.FP32) }, + { AttributeConsts.InstanceId, new BuiltInAttribute("gl_InstanceID", AggregateType.S32) }, + { AttributeConsts.VertexId, new BuiltInAttribute("gl_VertexID", AggregateType.S32) }, + { AttributeConsts.BaseInstance, new BuiltInAttribute("gl_BaseInstanceARB", AggregateType.S32) }, + { AttributeConsts.BaseVertex, new BuiltInAttribute("gl_BaseVertexARB", AggregateType.S32) }, + { AttributeConsts.InstanceIndex, new BuiltInAttribute("gl_InstanceIndex", AggregateType.S32) }, + { AttributeConsts.VertexIndex, new BuiltInAttribute("gl_VertexIndex", AggregateType.S32) }, + { AttributeConsts.DrawIndex, new BuiltInAttribute("gl_DrawIDARB", AggregateType.S32) }, + { AttributeConsts.FrontFacing, new BuiltInAttribute("gl_FrontFacing", AggregateType.Bool) }, // Special. - { AttributeConsts.FragmentOutputDepth, new BuiltInAttribute("gl_FragDepth", VariableType.F32) }, - { AttributeConsts.ThreadKill, new BuiltInAttribute("gl_HelperInvocation", VariableType.Bool) }, - { AttributeConsts.ThreadIdX, new BuiltInAttribute("gl_LocalInvocationID.x", VariableType.U32) }, - { AttributeConsts.ThreadIdY, new BuiltInAttribute("gl_LocalInvocationID.y", VariableType.U32) }, - { AttributeConsts.ThreadIdZ, new BuiltInAttribute("gl_LocalInvocationID.z", VariableType.U32) }, - { AttributeConsts.CtaIdX, new BuiltInAttribute("gl_WorkGroupID.x", VariableType.U32) }, - { AttributeConsts.CtaIdY, new BuiltInAttribute("gl_WorkGroupID.y", VariableType.U32) }, - { AttributeConsts.CtaIdZ, new BuiltInAttribute("gl_WorkGroupID.z", VariableType.U32) }, - { AttributeConsts.LaneId, new BuiltInAttribute(null, VariableType.U32) }, - { AttributeConsts.InvocationId, new BuiltInAttribute("gl_InvocationID", VariableType.S32) }, - { AttributeConsts.PrimitiveId, new BuiltInAttribute("gl_PrimitiveID", VariableType.S32) }, - { AttributeConsts.PatchVerticesIn, new BuiltInAttribute("gl_PatchVerticesIn", VariableType.S32) }, - { AttributeConsts.EqMask, new BuiltInAttribute(null, VariableType.U32) }, - { AttributeConsts.GeMask, new BuiltInAttribute(null, VariableType.U32) }, - { AttributeConsts.GtMask, new BuiltInAttribute(null, VariableType.U32) }, - { AttributeConsts.LeMask, new BuiltInAttribute(null, VariableType.U32) }, - { AttributeConsts.LtMask, new BuiltInAttribute(null, VariableType.U32) }, + { AttributeConsts.FragmentOutputDepth, new BuiltInAttribute("gl_FragDepth", AggregateType.FP32) }, + { AttributeConsts.ThreadKill, new BuiltInAttribute("gl_HelperInvocation", AggregateType.Bool) }, + { AttributeConsts.ThreadIdX, new BuiltInAttribute("gl_LocalInvocationID.x", AggregateType.U32) }, + { AttributeConsts.ThreadIdY, new BuiltInAttribute("gl_LocalInvocationID.y", AggregateType.U32) }, + { AttributeConsts.ThreadIdZ, new BuiltInAttribute("gl_LocalInvocationID.z", AggregateType.U32) }, + { AttributeConsts.CtaIdX, new BuiltInAttribute("gl_WorkGroupID.x", AggregateType.U32) }, + { AttributeConsts.CtaIdY, new BuiltInAttribute("gl_WorkGroupID.y", AggregateType.U32) }, + { AttributeConsts.CtaIdZ, new BuiltInAttribute("gl_WorkGroupID.z", AggregateType.U32) }, + { AttributeConsts.LaneId, new BuiltInAttribute(null, AggregateType.U32) }, + { AttributeConsts.InvocationId, new BuiltInAttribute("gl_InvocationID", AggregateType.S32) }, + { AttributeConsts.PrimitiveId, new BuiltInAttribute("gl_PrimitiveID", AggregateType.S32) }, + { AttributeConsts.PatchVerticesIn, new BuiltInAttribute("gl_PatchVerticesIn", AggregateType.S32) }, + { AttributeConsts.EqMask, new BuiltInAttribute(null, AggregateType.U32) }, + { AttributeConsts.GeMask, new BuiltInAttribute(null, AggregateType.U32) }, + { AttributeConsts.GtMask, new BuiltInAttribute(null, AggregateType.U32) }, + { AttributeConsts.LeMask, new BuiltInAttribute(null, AggregateType.U32) }, + { AttributeConsts.LtMask, new BuiltInAttribute(null, AggregateType.U32) }, // Support uniforms. - { AttributeConsts.FragmentOutputIsBgraBase + 0, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[0]", VariableType.Bool) }, - { AttributeConsts.FragmentOutputIsBgraBase + 4, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[1]", VariableType.Bool) }, - { AttributeConsts.FragmentOutputIsBgraBase + 8, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[2]", VariableType.Bool) }, - { AttributeConsts.FragmentOutputIsBgraBase + 12, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[3]", VariableType.Bool) }, - { AttributeConsts.FragmentOutputIsBgraBase + 16, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[4]", VariableType.Bool) }, - { AttributeConsts.FragmentOutputIsBgraBase + 20, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[5]", VariableType.Bool) }, - { AttributeConsts.FragmentOutputIsBgraBase + 24, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[6]", VariableType.Bool) }, - { AttributeConsts.FragmentOutputIsBgraBase + 28, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[7]", VariableType.Bool) }, + { AttributeConsts.FragmentOutputIsBgraBase + 0, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[0]", AggregateType.Bool) }, + { AttributeConsts.FragmentOutputIsBgraBase + 4, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[1]", AggregateType.Bool) }, + { AttributeConsts.FragmentOutputIsBgraBase + 8, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[2]", AggregateType.Bool) }, + { AttributeConsts.FragmentOutputIsBgraBase + 12, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[3]", AggregateType.Bool) }, + { AttributeConsts.FragmentOutputIsBgraBase + 16, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[4]", AggregateType.Bool) }, + { AttributeConsts.FragmentOutputIsBgraBase + 20, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[5]", AggregateType.Bool) }, + { AttributeConsts.FragmentOutputIsBgraBase + 24, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[6]", AggregateType.Bool) }, + { AttributeConsts.FragmentOutputIsBgraBase + 28, new BuiltInAttribute($"{DefaultNames.SupportBlockIsBgraName}[7]", AggregateType.Bool) }, - { AttributeConsts.SupportBlockViewInverseX, new BuiltInAttribute($"{DefaultNames.SupportBlockViewportInverse}.x", VariableType.F32) }, - { AttributeConsts.SupportBlockViewInverseY, new BuiltInAttribute($"{DefaultNames.SupportBlockViewportInverse}.y", VariableType.F32) } + { AttributeConsts.SupportBlockViewInverseX, new BuiltInAttribute($"{DefaultNames.SupportBlockViewportInverse}.x", AggregateType.FP32) }, + { AttributeConsts.SupportBlockViewInverseY, new BuiltInAttribute($"{DefaultNames.SupportBlockViewportInverse}.y", AggregateType.FP32) } }; private Dictionary _locals; @@ -329,7 +330,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl { if (cbIndexable) { - return GetUbName(stage, NumberFormatter.FormatInt(slot, VariableType.S32)); + return GetUbName(stage, NumberFormatter.FormatInt(slot, AggregateType.S32)); } return $"{GetShaderStagePrefix(stage)}_{DefaultNames.UniformNamePrefix}{slot}_{DefaultNames.UniformNameSuffix}"; @@ -404,7 +405,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl return $"{DefaultNames.ArgumentNamePrefix}{argIndex}"; } - public static VariableType GetNodeDestType(CodeGenContext context, IAstNode node, bool isAsgDest = false) + public static AggregateType GetNodeDestType(CodeGenContext context, IAstNode node, bool isAsgDest = false) { if (node is AstOperation operation) { @@ -431,12 +432,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl return context.GetFunction(funcId.Value).ReturnType; } - else if (operation is AstTextureOperation texOp && - (texOp.Inst == Instruction.ImageLoad || - texOp.Inst == Instruction.ImageStore || - texOp.Inst == Instruction.ImageAtomic)) + else if (operation.Inst == Instruction.VectorExtract) { - return texOp.Format.GetComponentType(); + return GetNodeDestType(context, operation.GetSource(0)) & ~AggregateType.ElementCountMask; + } + else if (operation is AstTextureOperation texOp) + { + if (texOp.Inst == Instruction.ImageLoad || + texOp.Inst == Instruction.ImageStore || + texOp.Inst == Instruction.ImageAtomic) + { + return texOp.GetVectorType(texOp.Format.GetComponentType()); + } + else if (texOp.Inst == Instruction.TextureSample) + { + return texOp.GetVectorType(GetDestVarType(operation.Inst)); + } } return GetDestVarType(operation.Inst); @@ -458,7 +469,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl } } - private static VariableType GetOperandVarType(CodeGenContext context, AstOperand operand, bool isAsgDest = false) + private static AggregateType GetOperandVarType(CodeGenContext context, AstOperand operand, bool isAsgDest = false) { if (operand.Type == OperandType.Attribute) { @@ -474,7 +485,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl AttributeType type = context.Config.GpuAccessor.QueryAttributeType(location); - return type.ToVariableType(); + return type.ToAggregateType(); } } diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/TypeConversion.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/TypeConversion.cs index b13a74f4..22c8623c 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/TypeConversion.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/TypeConversion.cs @@ -1,6 +1,7 @@ using Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions; using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.StructuredIr; +using Ryujinx.Graphics.Shader.Translation; using System; namespace Ryujinx.Graphics.Shader.CodeGen.Glsl @@ -10,8 +11,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl public static string ReinterpretCast( CodeGenContext context, IAstNode node, - VariableType srcType, - VariableType dstType) + AggregateType srcType, + AggregateType dstType) { if (node is AstOperand operand && operand.Type == OperandType.Constant) { @@ -26,46 +27,46 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl return ReinterpretCast(expr, node, srcType, dstType); } - private static string ReinterpretCast(string expr, IAstNode node, VariableType srcType, VariableType dstType) + private static string ReinterpretCast(string expr, IAstNode node, AggregateType srcType, AggregateType dstType) { if (srcType == dstType) { return expr; } - if (srcType == VariableType.F32) + if (srcType == AggregateType.FP32) { switch (dstType) { - case VariableType.Bool: return $"(floatBitsToInt({expr}) != 0)"; - case VariableType.S32: return $"floatBitsToInt({expr})"; - case VariableType.U32: return $"floatBitsToUint({expr})"; + case AggregateType.Bool: return $"(floatBitsToInt({expr}) != 0)"; + case AggregateType.S32: return $"floatBitsToInt({expr})"; + case AggregateType.U32: return $"floatBitsToUint({expr})"; } } - else if (dstType == VariableType.F32) + else if (dstType == AggregateType.FP32) { switch (srcType) { - case VariableType.Bool: return $"intBitsToFloat({ReinterpretBoolToInt(expr, node, VariableType.S32)})"; - case VariableType.S32: return $"intBitsToFloat({expr})"; - case VariableType.U32: return $"uintBitsToFloat({expr})"; + case AggregateType.Bool: return $"intBitsToFloat({ReinterpretBoolToInt(expr, node, AggregateType.S32)})"; + case AggregateType.S32: return $"intBitsToFloat({expr})"; + case AggregateType.U32: return $"uintBitsToFloat({expr})"; } } - else if (srcType == VariableType.Bool) + else if (srcType == AggregateType.Bool) { return ReinterpretBoolToInt(expr, node, dstType); } - else if (dstType == VariableType.Bool) + else if (dstType == AggregateType.Bool) { expr = InstGenHelper.Enclose(expr, node, Instruction.CompareNotEqual, isLhs: true); return $"({expr} != 0)"; } - else if (dstType == VariableType.S32) + else if (dstType == AggregateType.S32) { return $"int({expr})"; } - else if (dstType == VariableType.U32) + else if (dstType == AggregateType.U32) { return $"uint({expr})"; } @@ -73,7 +74,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl throw new ArgumentException($"Invalid reinterpret cast from \"{srcType}\" to \"{dstType}\"."); } - private static string ReinterpretBoolToInt(string expr, IAstNode node, VariableType dstType) + private static string ReinterpretBoolToInt(string expr, IAstNode node, AggregateType dstType) { string trueExpr = NumberFormatter.FormatInt(IrConsts.True, dstType); string falseExpr = NumberFormatter.FormatInt(IrConsts.False, dstType); diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs index dff5474a..41afdf18 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs @@ -241,6 +241,29 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv throw new NotImplementedException(node.GetType().Name); } + public Instruction GetWithType(IAstNode node, out AggregateType type) + { + if (node is AstOperation operation) + { + var opResult = Instructions.Generate(this, operation); + type = opResult.Type; + return opResult.Value; + } + else if (node is AstOperand operand) + { + switch (operand.Type) + { + case IrOperandType.LocalVariable: + type = operand.VarType; + return GetLocal(type, operand); + default: + throw new ArgumentException($"Invalid operand type \"{operand.Type}\"."); + } + } + + throw new NotImplementedException(node.GetType().Name); + } + private Instruction GetUndefined(AggregateType type) { return type switch @@ -325,7 +348,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv if (components > 1) { attrOffset &= ~0xf; - type = AggregateType.Vector | AggregateType.FP32; + type = components switch + { + 2 => AggregateType.Vector2 | AggregateType.FP32, + 3 => AggregateType.Vector3 | AggregateType.FP32, + 4 => AggregateType.Vector4 | AggregateType.FP32, + _ => AggregateType.FP32 + }; + attrInfo = new AttributeInfo(attrOffset, (attr - attrOffset) / 4, components, type, false); } } @@ -335,7 +365,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv bool isIndexed = AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr) && (!attrInfo.IsBuiltin || AttributeInfo.IsArrayBuiltIn(attr)); - if ((type & (AggregateType.Array | AggregateType.Vector)) == 0) + if ((type & (AggregateType.Array | AggregateType.ElementCountMask)) == 0) { if (invocationId != null) { @@ -452,7 +482,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv elemType = attrInfo.Type & AggregateType.ElementTypeMask; - if ((attrInfo.Type & (AggregateType.Array | AggregateType.Vector)) == 0) + if ((attrInfo.Type & (AggregateType.Array | AggregateType.ElementCountMask)) == 0) { return ioVariable; } @@ -533,13 +563,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv public Instruction GetLocal(AggregateType dstType, AstOperand local) { - var srcType = local.VarType.Convert(); + var srcType = local.VarType; return BitcastIfNeeded(dstType, srcType, Load(GetType(srcType), GetLocalPointer(local))); } public Instruction GetArgument(AggregateType dstType, AstOperand funcArg) { - var srcType = funcArg.VarType.Convert(); + var srcType = funcArg.VarType; return BitcastIfNeeded(dstType, srcType, Load(GetType(srcType), GetArgumentPointer(funcArg))); } @@ -550,13 +580,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv public Instruction GetType(AggregateType type, int length = 1) { - if (type.HasFlag(AggregateType.Array)) + if ((type & AggregateType.Array) != 0) { return TypeArray(GetType(type & ~AggregateType.Array), Constant(TypeU32(), length)); } - else if (type.HasFlag(AggregateType.Vector)) + else if ((type & AggregateType.ElementCountMask) != 0) { - return TypeVector(GetType(type & ~AggregateType.Vector), length); + int vectorLength = (type & AggregateType.ElementCountMask) switch + { + AggregateType.Vector2 => 2, + AggregateType.Vector3 => 3, + AggregateType.Vector4 => 4, + _ => 1 + }; + + return TypeVector(GetType(type & ~AggregateType.ElementCountMask), vectorLength); } return type switch diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs index 819ece41..3da72b40 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs @@ -23,11 +23,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv DeclareParameters(context, function.OutArguments, function.InArguments.Length); } - private static void DeclareParameters(CodeGenContext context, IEnumerable argTypes, int argIndex) + private static void DeclareParameters(CodeGenContext context, IEnumerable argTypes, int argIndex) { foreach (var argType in argTypes) { - var argPointerType = context.TypePointer(StorageClass.Function, context.GetType(argType.Convert())); + var argPointerType = context.TypePointer(StorageClass.Function, context.GetType(argType)); var spvArg = context.FunctionParameter(argPointerType); context.DeclareArgument(argIndex++, spvArg); @@ -38,7 +38,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv { foreach (AstOperand local in function.Locals) { - var localPointerType = context.TypePointer(StorageClass.Function, context.GetType(local.VarType.Convert())); + var localPointerType = context.TypePointer(StorageClass.Function, context.GetType(local.VarType)); var spvLocal = context.Variable(localPointerType, StorageClass.Function); context.AddLocalVariable(spvLocal); @@ -62,7 +62,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv for (int i = 0; i < function.InArguments.Length; i++) { - var type = function.GetArgumentType(i).Convert(); + var type = function.GetArgumentType(i); var localPointerType = context.TypePointer(StorageClass.Function, context.GetType(type)); var spvLocal = context.Variable(localPointerType, StorageClass.Function); @@ -303,7 +303,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var dim = GetDim(descriptor.Type); var imageType = context.TypeImage( - context.GetType(meta.Format.GetComponentType().Convert()), + context.GetType(meta.Format.GetComponentType()), dim, descriptor.Type.HasFlag(SamplerType.Shadow), descriptor.Type.HasFlag(SamplerType.Array), @@ -652,7 +652,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv if (components > 1) { attr &= ~0xf; - type = AggregateType.Vector | AggregateType.FP32; + type = components switch + { + 2 => AggregateType.Vector2 | AggregateType.FP32, + 3 => AggregateType.Vector3 | AggregateType.FP32, + 4 => AggregateType.Vector4 | AggregateType.FP32, + _ => AggregateType.FP32 + }; + hasComponent = false; } } diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/EnumConversion.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/EnumConversion.cs index 0ddb4264..aa3d046a 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/EnumConversion.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/EnumConversion.cs @@ -1,5 +1,4 @@ -using Ryujinx.Graphics.Shader.StructuredIr; -using Ryujinx.Graphics.Shader.Translation; +using Ryujinx.Graphics.Shader.Translation; using System; using static Spv.Specification; @@ -7,20 +6,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv { static class EnumConversion { - public static AggregateType Convert(this VariableType type) - { - return type switch - { - VariableType.None => AggregateType.Void, - VariableType.Bool => AggregateType.Bool, - VariableType.F32 => AggregateType.FP32, - VariableType.F64 => AggregateType.FP64, - VariableType.S32 => AggregateType.S32, - VariableType.U32 => AggregateType.U32, - _ => throw new ArgumentException($"Invalid variable type \"{type}\".") - }; - } - public static ExecutionModel Convert(this ShaderStage stage) { return stage switch diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs index ae280377..a02c4c22 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs @@ -4,6 +4,7 @@ using Ryujinx.Graphics.Shader.Translation; using System; using System.Collections.Generic; using System.Diagnostics; +using System.Numerics; using static Spv.Specification; namespace Ryujinx.Graphics.Shader.CodeGen.Spirv @@ -146,6 +147,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv Add(Instruction.Truncate, GenerateTruncate); Add(Instruction.UnpackDouble2x32, GenerateUnpackDouble2x32); Add(Instruction.UnpackHalf2x16, GenerateUnpackHalf2x16); + Add(Instruction.VectorExtract, GenerateVectorExtract); Add(Instruction.VoteAll, GenerateVoteAll); Add(Instruction.VoteAllEqual, GenerateVoteAllEqual); Add(Instruction.VoteAny, GenerateVoteAny); @@ -317,7 +319,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv } else { - var type = function.GetArgumentType(i).Convert(); + var type = function.GetArgumentType(i); var value = context.Get(type, operand); var spvLocal = spvLocals[i]; @@ -327,7 +329,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv } } - var retType = function.ReturnType.Convert(); + var retType = function.ReturnType; var result = context.FunctionCall(context.GetType(retType), spvFunc, args); return new OperationResult(retType, result); } @@ -604,10 +606,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv // TODO: Bindless texture support. For now we just return 0/do nothing. if (isBindless) { - return new OperationResult(componentType.Convert(), componentType switch + return new OperationResult(componentType, componentType switch { - VariableType.S32 => context.Constant(context.TypeS32(), 0), - VariableType.U32 => context.Constant(context.TypeU32(), 0u), + AggregateType.S32 => context.Constant(context.TypeS32(), 0), + AggregateType.U32 => context.Constant(context.TypeU32(), 0u), _ => context.Constant(context.TypeFP32(), 0f), }); } @@ -652,13 +654,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv pCoords = Src(AggregateType.S32); } - SpvInstruction value = Src(componentType.Convert()); + SpvInstruction value = Src(componentType); (var imageType, var imageVariable) = context.Images[new TextureMeta(texOp.CbufSlot, texOp.Handle, texOp.Format)]; var image = context.Load(imageType, imageVariable); - SpvInstruction resultType = context.GetType(componentType.Convert()); + SpvInstruction resultType = context.GetType(componentType); SpvInstruction imagePointerType = context.TypePointer(StorageClass.Image, resultType); var pointer = context.ImageTexelPointer(imagePointerType, imageVariable, pCoords, context.Constant(context.TypeU32(), 0)); @@ -668,10 +670,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var result = (texOp.Flags & TextureFlags.AtomicMask) switch { TextureFlags.Add => context.AtomicIAdd(resultType, pointer, one, zero, value), - TextureFlags.Minimum => componentType == VariableType.S32 + TextureFlags.Minimum => componentType == AggregateType.S32 ? context.AtomicSMin(resultType, pointer, one, zero, value) : context.AtomicUMin(resultType, pointer, one, zero, value), - TextureFlags.Maximum => componentType == VariableType.S32 + TextureFlags.Maximum => componentType == AggregateType.S32 ? context.AtomicSMax(resultType, pointer, one, zero, value) : context.AtomicUMax(resultType, pointer, one, zero, value), TextureFlags.Increment => context.AtomicIIncrement(resultType, pointer, one, zero), @@ -680,11 +682,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv TextureFlags.BitwiseOr => context.AtomicOr(resultType, pointer, one, zero, value), TextureFlags.BitwiseXor => context.AtomicXor(resultType, pointer, one, zero, value), TextureFlags.Swap => context.AtomicExchange(resultType, pointer, one, zero, value), - TextureFlags.CAS => context.AtomicCompareExchange(resultType, pointer, one, zero, zero, Src(componentType.Convert()), value), + TextureFlags.CAS => context.AtomicCompareExchange(resultType, pointer, one, zero, zero, Src(componentType), value), _ => context.AtomicIAdd(resultType, pointer, one, zero, value), }; - return new OperationResult(componentType.Convert(), result); + return new OperationResult(componentType, result); } private static OperationResult GenerateImageLoad(CodeGenContext context, AstOperation operation) @@ -698,14 +700,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv // TODO: Bindless texture support. For now we just return 0/do nothing. if (isBindless) { - var zero = componentType switch - { - VariableType.S32 => context.Constant(context.TypeS32(), 0), - VariableType.U32 => context.Constant(context.TypeU32(), 0u), - _ => context.Constant(context.TypeFP32(), 0f), - }; - - return new OperationResult(componentType.Convert(), zero); + return GetZeroOperationResult(context, texOp, componentType, isVector: true); } bool isArray = (texOp.Type & SamplerType.Array) != 0; @@ -753,12 +748,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv (var imageType, var imageVariable) = context.Images[new TextureMeta(texOp.CbufSlot, texOp.Handle, texOp.Format)]; var image = context.Load(imageType, imageVariable); - var imageComponentType = context.GetType(componentType.Convert()); + var imageComponentType = context.GetType(componentType); + var swizzledResultType = texOp.GetVectorType(componentType); var texel = context.ImageRead(context.TypeVector(imageComponentType, 4), image, pCoords, ImageOperandsMask.MaskNone); - var result = context.CompositeExtract(imageComponentType, texel, (SpvLiteralInteger)texOp.Index); + var result = GetSwizzledResult(context, texel, swizzledResultType, texOp.Index); - return new OperationResult(componentType.Convert(), result); + return new OperationResult(componentType, result); } private static OperationResult GenerateImageStore(CodeGenContext context, AstOperation operation) @@ -823,20 +819,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv { if (srcIndex < texOp.SourcesCount) { - cElems[i] = Src(componentType.Convert()); + cElems[i] = Src(componentType); } else { cElems[i] = componentType switch { - VariableType.S32 => context.Constant(context.TypeS32(), 0), - VariableType.U32 => context.Constant(context.TypeU32(), 0u), + AggregateType.S32 => context.Constant(context.TypeS32(), 0), + AggregateType.U32 => context.Constant(context.TypeU32(), 0u), _ => context.Constant(context.TypeFP32(), 0f), }; } } - var texel = context.CompositeConstruct(context.TypeVector(context.GetType(componentType.Convert()), ComponentsCount), cElems); + var texel = context.CompositeConstruct(context.TypeVector(context.GetType(componentType), ComponentsCount), cElems); (var imageType, var imageVariable) = context.Images[new TextureMeta(texOp.CbufSlot, texOp.Handle, texOp.Format)]; @@ -1238,7 +1234,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var validLocal = (AstOperand)operation.GetSource(3); - context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType.Convert(), AggregateType.Bool, valid)); + context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid)); return new OperationResult(AggregateType.FP32, result); } @@ -1268,7 +1264,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var validLocal = (AstOperand)operation.GetSource(3); - context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType.Convert(), AggregateType.Bool, valid)); + context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid)); return new OperationResult(AggregateType.FP32, result); } @@ -1294,7 +1290,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var validLocal = (AstOperand)operation.GetSource(3); - context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType.Convert(), AggregateType.Bool, valid)); + context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid)); return new OperationResult(AggregateType.FP32, result); } @@ -1324,7 +1320,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var validLocal = (AstOperand)operation.GetSource(3); - context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType.Convert(), AggregateType.Bool, valid)); + context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid)); return new OperationResult(AggregateType.FP32, result); } @@ -1485,10 +1481,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv bool isMultisample = (texOp.Type & SamplerType.Multisample) != 0; bool isShadow = (texOp.Type & SamplerType.Shadow) != 0; + bool colorIsVector = isGather || !isShadow; + // TODO: Bindless texture support. For now we just return 0. if (isBindless) { - return new OperationResult(AggregateType.FP32, context.Constant(context.TypeFP32(), 0f)); + return GetZeroOperationResult(context, texOp, AggregateType.FP32, colorIsVector); } // This combination is valid, but not available on GLSL. @@ -1705,7 +1703,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv operandsList.Add(sample); } - bool colorIsVector = isGather || !isShadow; var resultType = colorIsVector ? context.TypeVector(context.TypeFP32(), 4) : context.TypeFP32(); var meta = new TextureMeta(texOp.CbufSlot, texOp.Handle, texOp.Format); @@ -1758,12 +1755,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv result = context.ImageSampleImplicitLod(resultType, image, pCoords, operandsMask, operands); } + var swizzledResultType = AggregateType.FP32; + if (colorIsVector) { - result = context.CompositeExtract(context.TypeFP32(), result, (SpvLiteralInteger)texOp.Index); + swizzledResultType = texOp.GetVectorType(swizzledResultType); + + result = GetSwizzledResult(context, result, swizzledResultType, texOp.Index); } - return new OperationResult(AggregateType.FP32, result); + return new OperationResult(swizzledResultType, result); } private static OperationResult GenerateTextureSize(CodeGenContext context, AstOperation operation) @@ -1862,6 +1863,26 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv return new OperationResult(AggregateType.FP32, result); } + private static OperationResult GenerateVectorExtract(CodeGenContext context, AstOperation operation) + { + var vector = context.GetWithType(operation.GetSource(0), out AggregateType vectorType); + var scalarType = vectorType & ~AggregateType.ElementCountMask; + var resultType = context.GetType(scalarType); + SpvInstruction result; + + if (operation.GetSource(1) is AstOperand indexOperand && indexOperand.Type == OperandType.Constant) + { + result = context.CompositeExtract(resultType, vector, (SpvLiteralInteger)indexOperand.Value); + } + else + { + var index = context.Get(AggregateType.S32, operation.GetSource(1)); + result = context.VectorExtractDynamic(resultType, vector, index); + } + + return new OperationResult(scalarType, result); + } + private static OperationResult GenerateVoteAll(CodeGenContext context, AstOperation operation) { var execution = context.Constant(context.TypeU32(), Scope.Subgroup); @@ -2044,6 +2065,64 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv context.AddLabel(loopEnd); } + private static OperationResult GetZeroOperationResult( + CodeGenContext context, + AstTextureOperation texOp, + AggregateType scalarType, + bool isVector) + { + var zero = scalarType switch + { + AggregateType.S32 => context.Constant(context.TypeS32(), 0), + AggregateType.U32 => context.Constant(context.TypeU32(), 0u), + _ => context.Constant(context.TypeFP32(), 0f), + }; + + if (isVector) + { + AggregateType outputType = texOp.GetVectorType(scalarType); + + if ((outputType & AggregateType.ElementCountMask) != 0) + { + int componentsCount = BitOperations.PopCount((uint)texOp.Index); + + SpvInstruction[] values = new SpvInstruction[componentsCount]; + + values.AsSpan().Fill(zero); + + return new OperationResult(outputType, context.ConstantComposite(context.GetType(outputType), values)); + } + } + + return new OperationResult(scalarType, zero); + } + + private static SpvInstruction GetSwizzledResult(CodeGenContext context, SpvInstruction vector, AggregateType swizzledResultType, int mask) + { + if ((swizzledResultType & AggregateType.ElementCountMask) != 0) + { + SpvLiteralInteger[] components = new SpvLiteralInteger[BitOperations.PopCount((uint)mask)]; + + int componentIndex = 0; + + for (int i = 0; i < 4; i++) + { + if ((mask & (1 << i)) != 0) + { + components[componentIndex++] = i; + } + } + + return context.VectorShuffle(context.GetType(swizzledResultType), vector, vector, components); + } + else + { + int componentIndex = (int)BitOperations.TrailingZeroCount(mask); + + return context.CompositeExtract(context.GetType(swizzledResultType), vector, (SpvLiteralInteger)componentIndex); + } + } + private static SpvInstruction GetStorageElemPointer(CodeGenContext context, AstOperation operation) { var sbVariable = context.StorageBuffersArray; diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs index 6e1db972..9f08b319 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs @@ -104,13 +104,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv for (int funcIndex = 0; funcIndex < info.Functions.Count; funcIndex++) { var function = info.Functions[funcIndex]; - var retType = context.GetType(function.ReturnType.Convert()); + var retType = context.GetType(function.ReturnType); var funcArgs = new SpvInstruction[function.InArguments.Length + function.OutArguments.Length]; for (int argIndex = 0; argIndex < funcArgs.Length; argIndex++) { - var argType = context.GetType(function.GetArgumentType(argIndex).Convert()); + var argType = context.GetType(function.GetArgumentType(argIndex)); var argPointerType = context.TypePointer(StorageClass.Function, argType); funcArgs[argIndex] = argPointerType; } @@ -387,7 +387,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv if (dest.Type == OperandType.LocalVariable) { - var source = context.Get(dest.VarType.Convert(), assignment.Source); + var source = context.Get(dest.VarType, assignment.Source); context.Store(context.GetLocalPointer(dest), source); } else if (dest.Type == OperandType.Attribute || dest.Type == OperandType.AttributePerPatch) @@ -407,7 +407,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv } else if (dest.Type == OperandType.Argument) { - var source = context.Get(dest.VarType.Convert(), assignment.Source); + var source = context.Get(dest.VarType, assignment.Source); context.Store(context.GetArgumentPointer(dest), source); } else diff --git a/Ryujinx.Graphics.Shader/Instructions/InstEmitSurface.cs b/Ryujinx.Graphics.Shader/Instructions/InstEmitSurface.cs index a81362b8..3d94b893 100644 --- a/Ryujinx.Graphics.Shader/Instructions/InstEmitSurface.cs +++ b/Ryujinx.Graphics.Shader/Instructions/InstEmitSurface.cs @@ -1,7 +1,9 @@ using Ryujinx.Graphics.Shader.Decoders; using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.Translation; +using System; using System.Collections.Generic; +using System.Numerics; using static Ryujinx.Graphics.Shader.Instructions.InstEmitHelper; using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper; @@ -217,15 +219,7 @@ namespace Ryujinx.Graphics.Shader.Instructions return context.Copy(Register(srcB++, RegisterType.Gpr)); } - Operand GetDest() - { - if (dest >= RegisterConsts.RegisterZeroIndex) - { - return null; - } - - return Register(dest++, RegisterType.Gpr); - } + Operand destOperand = dest != RegisterConsts.RegisterZeroIndex ? Register(dest, RegisterType.Gpr) : null; List sourcesList = new List(); @@ -291,7 +285,7 @@ namespace Ryujinx.Graphics.Shader.Instructions flags, imm, 0, - GetDest(), + new[] { destOperand }, sources); context.Add(operation); @@ -371,36 +365,40 @@ namespace Ryujinx.Graphics.Shader.Instructions if (useComponents) { - for (int compMask = (int)componentMask, compIndex = 0; compMask != 0; compMask >>= 1, compIndex++) - { - if ((compMask & 1) == 0) - { - continue; - } + Operand[] dests = new Operand[BitOperations.PopCount((uint)componentMask)]; - if (srcB == RegisterConsts.RegisterZeroIndex) + int outputIndex = 0; + + for (int i = 0; i < dests.Length; i++) + { + if (srcB + i >= RegisterConsts.RegisterZeroIndex) { break; } - Operand rd = Register(srcB++, RegisterType.Gpr); - - TextureOperation operation = context.CreateTextureOperation( - Instruction.ImageLoad, - type, - flags, - handle, - compIndex, - rd, - sources); - - if (!isBindless) - { - operation.Format = context.Config.GetTextureFormat(handle); - } - - context.Add(operation); + dests[outputIndex++] = Register(srcB + i, RegisterType.Gpr); } + + if (outputIndex != dests.Length) + { + Array.Resize(ref dests, outputIndex); + } + + TextureOperation operation = context.CreateTextureOperation( + Instruction.ImageLoad, + type, + flags, + handle, + (int)componentMask, + dests, + sources); + + if (!isBindless) + { + operation.Format = context.Config.GetTextureFormat(handle); + } + + context.Add(operation); } else { @@ -412,35 +410,45 @@ namespace Ryujinx.Graphics.Shader.Instructions } int components = GetComponents(size); + int compMask = (1 << components) - 1; - for (int compIndex = 0; compIndex < components; compIndex++) + Operand[] dests = new Operand[components]; + + int outputIndex = 0; + + for (int i = 0; i < dests.Length; i++) { - if (srcB == RegisterConsts.RegisterZeroIndex) + if (srcB + i >= RegisterConsts.RegisterZeroIndex) { break; } - Operand rd = Register(srcB++, RegisterType.Gpr); + dests[outputIndex++] = Register(srcB + i, RegisterType.Gpr); + } - TextureOperation operation = context.CreateTextureOperation( - Instruction.ImageLoad, - type, - GetTextureFormat(size), - flags, - handle, - compIndex, - rd, - sources); + if (outputIndex != dests.Length) + { + Array.Resize(ref dests, outputIndex); + } - context.Add(operation); + TextureOperation operation = context.CreateTextureOperation( + Instruction.ImageLoad, + type, + GetTextureFormat(size), + flags, + handle, + compMask, + dests, + sources); - switch (size) - { - case SuSize.U8: context.Copy(rd, ZeroExtendTo32(context, rd, 8)); break; - case SuSize.U16: context.Copy(rd, ZeroExtendTo32(context, rd, 16)); break; - case SuSize.S8: context.Copy(rd, SignExtendTo32(context, rd, 8)); break; - case SuSize.S16: context.Copy(rd, SignExtendTo32(context, rd, 16)); break; - } + context.Add(operation); + + switch (size) + { + case SuSize.U8: context.Copy(dests[0], ZeroExtendTo32(context, dests[0], 8)); break; + case SuSize.U16: context.Copy(dests[0], ZeroExtendTo32(context, dests[0], 16)); break; + case SuSize.S8: context.Copy(dests[0], SignExtendTo32(context, dests[0], 8)); break; + case SuSize.S16: context.Copy(dests[0], SignExtendTo32(context, dests[0], 16)); break; } } } diff --git a/Ryujinx.Graphics.Shader/Instructions/InstEmitTexture.cs b/Ryujinx.Graphics.Shader/Instructions/InstEmitTexture.cs index c54b79cd..caa9a775 100644 --- a/Ryujinx.Graphics.Shader/Instructions/InstEmitTexture.cs +++ b/Ryujinx.Graphics.Shader/Instructions/InstEmitTexture.cs @@ -3,6 +3,7 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.Translation; using System; using System.Collections.Generic; +using System.Numerics; using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper; @@ -303,42 +304,37 @@ namespace Ryujinx.Graphics.Shader.Instructions } Operand[] sources = sourcesList.ToArray(); + Operand[] dests = new Operand[BitOperations.PopCount((uint)componentMask)]; - Operand GetDest() + int outputIndex = 0; + + for (int i = 0; i < dests.Length; i++) { - if (rdIndex >= RegisterConsts.RegisterZeroIndex) + if (rdIndex + i >= RegisterConsts.RegisterZeroIndex) { - return null; + break; } - return Register(rdIndex++, RegisterType.Gpr); + dests[outputIndex++] = Register(rdIndex + i, RegisterType.Gpr); + } + + if (outputIndex != dests.Length) + { + Array.Resize(ref dests, outputIndex); } int handle = !isBindless ? imm : 0; - for (int compMask = componentMask, compIndex = 0; compMask != 0; compMask >>= 1, compIndex++) - { - if ((compMask & 1) != 0) - { - Operand dest = GetDest(); + TextureOperation operation = context.CreateTextureOperation( + Instruction.TextureSample, + type, + flags, + handle, + componentMask, + dests, + sources); - if (dest == null) - { - break; - } - - TextureOperation operation = context.CreateTextureOperation( - Instruction.TextureSample, - type, - flags, - handle, - compIndex, - dest, - sources); - - context.Add(operation); - } - } + context.Add(operation); } private static void EmitTexs( @@ -624,18 +620,23 @@ namespace Ryujinx.Graphics.Shader.Instructions Operand[] rd0 = new Operand[2] { ConstF(0), ConstF(0) }; Operand[] rd1 = new Operand[2] { ConstF(0), ConstF(0) }; - int destIncrement = 0; + int handle = imm; + int componentMask = _maskLut[dest2 == RegisterConsts.RegisterZeroIndex ? 0 : 1, writeMask]; - Operand GetDest() + int componentsCount = BitOperations.PopCount((uint)componentMask); + + Operand[] dests = new Operand[componentsCount]; + + int outputIndex = 0; + + for (int i = 0; i < componentsCount; i++) { - int high = destIncrement >> 1; - int low = destIncrement & 1; - - destIncrement++; + int high = i >> 1; + int low = i & 1; if (isF16) { - return high != 0 + dests[outputIndex++] = high != 0 ? (rd1[low] = Local()) : (rd0[low] = Local()); } @@ -648,30 +649,26 @@ namespace Ryujinx.Graphics.Shader.Instructions rdIndex += low; } - return Register(rdIndex, RegisterType.Gpr); + dests[outputIndex++] = Register(rdIndex, RegisterType.Gpr); } } - int handle = imm; - int componentMask = _maskLut[dest2 == RegisterConsts.RegisterZeroIndex ? 0 : 1, writeMask]; - - for (int compMask = componentMask, compIndex = 0; compMask != 0; compMask >>= 1, compIndex++) + if (outputIndex != dests.Length) { - if ((compMask & 1) != 0) - { - TextureOperation operation = context.CreateTextureOperation( - Instruction.TextureSample, - type, - flags, - handle, - compIndex, - GetDest(), - sources); - - context.Add(operation); - } + Array.Resize(ref dests, outputIndex); } + TextureOperation operation = context.CreateTextureOperation( + Instruction.TextureSample, + type, + flags, + handle, + componentMask, + dests, + sources); + + context.Add(operation); + if (isF16) { context.Copy(Register(dest, RegisterType.Gpr), context.PackHalf2x16(rd0[0], rd0[1])); @@ -797,42 +794,37 @@ namespace Ryujinx.Graphics.Shader.Instructions sourcesList.Add(Const((int)component)); Operand[] sources = sourcesList.ToArray(); + Operand[] dests = new Operand[BitOperations.PopCount((uint)componentMask)]; - Operand GetDest() + int outputIndex = 0; + + for (int i = 0; i < dests.Length; i++) { - if (dest >= RegisterConsts.RegisterZeroIndex) + if (dest + i >= RegisterConsts.RegisterZeroIndex) { - return null; + break; } - return Register(dest++, RegisterType.Gpr); + dests[outputIndex++] = Register(dest + i, RegisterType.Gpr); + } + + if (outputIndex != dests.Length) + { + Array.Resize(ref dests, outputIndex); } int handle = imm; - for (int compMask = componentMask, compIndex = 0; compMask != 0; compMask >>= 1, compIndex++) - { - if ((compMask & 1) != 0) - { - Operand destOperand = GetDest(); + TextureOperation operation = context.CreateTextureOperation( + Instruction.TextureSample, + type, + flags, + handle, + componentMask, + dests, + sources); - if (destOperand == null) - { - break; - } - - TextureOperation operation = context.CreateTextureOperation( - Instruction.TextureSample, - type, - flags, - handle, - compIndex, - destOperand, - sources); - - context.Add(operation); - } - } + context.Add(operation); } private static void EmitTmml( @@ -951,7 +943,7 @@ namespace Ryujinx.Graphics.Shader.Instructions flags, handle, compIndex ^ 1, // The instruction component order is the inverse of GLSL's. - tempDest, + new[] { tempDest }, sources); context.Add(operation); @@ -1071,42 +1063,37 @@ namespace Ryujinx.Graphics.Shader.Instructions } Operand[] sources = sourcesList.ToArray(); + Operand[] dests = new Operand[BitOperations.PopCount((uint)componentMask)]; - Operand GetDest() + int outputIndex = 0; + + for (int i = 0; i < dests.Length; i++) { - if (dest >= RegisterConsts.RegisterZeroIndex) + if (dest + i >= RegisterConsts.RegisterZeroIndex) { - return null; + break; } - return Register(dest++, RegisterType.Gpr); + dests[outputIndex++] = Register(dest + i, RegisterType.Gpr); + } + + if (outputIndex != dests.Length) + { + Array.Resize(ref dests, outputIndex); } int handle = imm; - for (int compMask = componentMask, compIndex = 0; compMask != 0; compMask >>= 1, compIndex++) - { - if ((compMask & 1) != 0) - { - Operand destOperand = GetDest(); + TextureOperation operation = context.CreateTextureOperation( + Instruction.TextureSample, + type, + flags, + handle, + componentMask, + dests, + sources); - if (destOperand == null) - { - break; - } - - TextureOperation operation = context.CreateTextureOperation( - Instruction.TextureSample, - type, - flags, - handle, - compIndex, - destOperand, - sources); - - context.Add(operation); - } - } + context.Add(operation); } private static void EmitTxq( @@ -1188,7 +1175,7 @@ namespace Ryujinx.Graphics.Shader.Instructions flags, imm, compIndex, - destOperand, + new[] { destOperand }, sources); context.Add(operation); diff --git a/Ryujinx.Graphics.Shader/IntermediateRepresentation/Instruction.cs b/Ryujinx.Graphics.Shader/IntermediateRepresentation/Instruction.cs index 9a2c844d..aa9776bc 100644 --- a/Ryujinx.Graphics.Shader/IntermediateRepresentation/Instruction.cs +++ b/Ryujinx.Graphics.Shader/IntermediateRepresentation/Instruction.cs @@ -134,6 +134,7 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation Truncate, UnpackDouble2x32, UnpackHalf2x16, + VectorExtract, VoteAll, VoteAllEqual, VoteAny, diff --git a/Ryujinx.Graphics.Shader/IntermediateRepresentation/Operation.cs b/Ryujinx.Graphics.Shader/IntermediateRepresentation/Operation.cs index 96132633..18e203a7 100644 --- a/Ryujinx.Graphics.Shader/IntermediateRepresentation/Operation.cs +++ b/Ryujinx.Graphics.Shader/IntermediateRepresentation/Operation.cs @@ -62,18 +62,25 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation Inst = inst; Index = index; - // The array may be modified externally, so we store a copy. - _dests = (Operand[])dests.Clone(); - - for (int dstIndex = 0; dstIndex < dests.Length; dstIndex++) + if (dests != null) { - Operand dest = dests[dstIndex]; + // The array may be modified externally, so we store a copy. + _dests = (Operand[])dests.Clone(); - if (dest != null && dest.Type == OperandType.LocalVariable) + for (int dstIndex = 0; dstIndex < dests.Length; dstIndex++) { - dest.AsgOp = this; + Operand dest = dests[dstIndex]; + + if (dest != null && dest.Type == OperandType.LocalVariable) + { + dest.AsgOp = this; + } } } + else + { + _dests = Array.Empty(); + } } public Operation(Instruction inst, Operand dest, params Operand[] sources) : this(sources) diff --git a/Ryujinx.Graphics.Shader/IntermediateRepresentation/TextureOperation.cs b/Ryujinx.Graphics.Shader/IntermediateRepresentation/TextureOperation.cs index 8cfcb0e9..6ab868cd 100644 --- a/Ryujinx.Graphics.Shader/IntermediateRepresentation/TextureOperation.cs +++ b/Ryujinx.Graphics.Shader/IntermediateRepresentation/TextureOperation.cs @@ -19,8 +19,8 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation int cbufSlot, int handle, int compIndex, - Operand dest, - Operand[] sources) : base(inst, compIndex, dest, sources) + Operand[] dests, + Operand[] sources) : base(inst, compIndex, dests, sources) { Type = type; Format = format; @@ -36,8 +36,8 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation TextureFlags flags, int handle, int compIndex, - Operand dest, - Operand[] sources) : this(inst, type, format, flags, DefaultCbufSlot, handle, compIndex, dest, sources) + Operand[] dests, + Operand[] sources) : this(inst, type, format, flags, DefaultCbufSlot, handle, compIndex, dests, sources) { } diff --git a/Ryujinx.Graphics.Shader/SamplerType.cs b/Ryujinx.Graphics.Shader/SamplerType.cs index d04b16b3..620f4ccf 100644 --- a/Ryujinx.Graphics.Shader/SamplerType.cs +++ b/Ryujinx.Graphics.Shader/SamplerType.cs @@ -1,4 +1,4 @@ -using Ryujinx.Graphics.Shader.StructuredIr; +using Ryujinx.Graphics.Shader.Translation; using System; namespace Ryujinx.Graphics.Shader @@ -66,7 +66,7 @@ namespace Ryujinx.Graphics.Shader return typeName; } - public static string ToGlslImageType(this SamplerType type, VariableType componentType) + public static string ToGlslImageType(this SamplerType type, AggregateType componentType) { string typeName = (type & SamplerType.Mask) switch { @@ -90,8 +90,8 @@ namespace Ryujinx.Graphics.Shader switch (componentType) { - case VariableType.U32: typeName = 'u' + typeName; break; - case VariableType.S32: typeName = 'i' + typeName; break; + case AggregateType.U32: typeName = 'u' + typeName; break; + case AggregateType.S32: typeName = 'i' + typeName; break; } return typeName; diff --git a/Ryujinx.Graphics.Shader/StructuredIr/AstHelper.cs b/Ryujinx.Graphics.Shader/StructuredIr/AstHelper.cs index 9d3148e1..7aa0409b 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/AstHelper.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/AstHelper.cs @@ -1,4 +1,5 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using Ryujinx.Graphics.Shader.Translation; namespace Ryujinx.Graphics.Shader.StructuredIr { @@ -46,7 +47,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr return new AstOperand(OperandType.Constant, value); } - public static AstOperand Local(VariableType type) + public static AstOperand Local(AggregateType type) { AstOperand local = new AstOperand(OperandType.LocalVariable); diff --git a/Ryujinx.Graphics.Shader/StructuredIr/AstOperand.cs b/Ryujinx.Graphics.Shader/StructuredIr/AstOperand.cs index 97ff3ca9..1fc0035f 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/AstOperand.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/AstOperand.cs @@ -1,4 +1,5 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using Ryujinx.Graphics.Shader.Translation; using System.Collections.Generic; namespace Ryujinx.Graphics.Shader.StructuredIr @@ -10,7 +11,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr public OperandType Type { get; } - public VariableType VarType { get; set; } + public AggregateType VarType { get; set; } public int Value { get; } @@ -22,7 +23,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr Defs = new HashSet(); Uses = new HashSet(); - VarType = VariableType.S32; + VarType = AggregateType.S32; } public AstOperand(Operand operand) : this() diff --git a/Ryujinx.Graphics.Shader/StructuredIr/AstOperation.cs b/Ryujinx.Graphics.Shader/StructuredIr/AstOperation.cs index a8474955..19397256 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/AstOperation.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/AstOperation.cs @@ -1,4 +1,6 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using Ryujinx.Graphics.Shader.Translation; +using System.Numerics; using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper; @@ -56,5 +58,21 @@ namespace Ryujinx.Graphics.Shader.StructuredIr _sources[index] = source; } + + public AggregateType GetVectorType(AggregateType scalarType) + { + int componentsCount = BitOperations.PopCount((uint)Index); + + AggregateType type = scalarType; + + switch (componentsCount) + { + case 2: type |= AggregateType.Vector2; break; + case 3: type |= AggregateType.Vector3; break; + case 4: type |= AggregateType.Vector4; break; + } + + return type; + } } } \ No newline at end of file diff --git a/Ryujinx.Graphics.Shader/StructuredIr/InstructionInfo.cs b/Ryujinx.Graphics.Shader/StructuredIr/InstructionInfo.cs index aea36423..0a9a9e51 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/InstructionInfo.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/InstructionInfo.cs @@ -1,4 +1,5 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using Ryujinx.Graphics.Shader.Translation; using System; namespace Ryujinx.Graphics.Shader.StructuredIr @@ -7,11 +8,11 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { private readonly struct InstInfo { - public VariableType DestType { get; } + public AggregateType DestType { get; } - public VariableType[] SrcTypes { get; } + public AggregateType[] SrcTypes { get; } - public InstInfo(VariableType destType, params VariableType[] srcTypes) + public InstInfo(AggregateType destType, params AggregateType[] srcTypes) { DestType = destType; SrcTypes = srcTypes; @@ -24,176 +25,173 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { _infoTbl = new InstInfo[(int)Instruction.Count]; - // Inst Destination type Source 1 type Source 2 type Source 3 type Source 4 type - Add(Instruction.AtomicAdd, VariableType.U32, VariableType.S32, VariableType.S32, VariableType.U32); - Add(Instruction.AtomicAnd, VariableType.U32, VariableType.S32, VariableType.S32, VariableType.U32); - Add(Instruction.AtomicCompareAndSwap, VariableType.U32, VariableType.S32, VariableType.S32, VariableType.U32, VariableType.U32); - Add(Instruction.AtomicMaxS32, VariableType.S32, VariableType.S32, VariableType.S32, VariableType.S32); - Add(Instruction.AtomicMaxU32, VariableType.U32, VariableType.S32, VariableType.S32, VariableType.U32); - Add(Instruction.AtomicMinS32, VariableType.S32, VariableType.S32, VariableType.S32, VariableType.S32); - Add(Instruction.AtomicMinU32, VariableType.U32, VariableType.S32, VariableType.S32, VariableType.U32); - Add(Instruction.AtomicOr, VariableType.U32, VariableType.S32, VariableType.S32, VariableType.U32); - Add(Instruction.AtomicSwap, VariableType.U32, VariableType.S32, VariableType.S32, VariableType.U32); - Add(Instruction.AtomicXor, VariableType.U32, VariableType.S32, VariableType.S32, VariableType.U32); - Add(Instruction.Absolute, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.Add, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.Ballot, VariableType.U32, VariableType.Bool); - Add(Instruction.BitCount, VariableType.Int, VariableType.Int); - Add(Instruction.BitfieldExtractS32, VariableType.S32, VariableType.S32, VariableType.S32, VariableType.S32); - Add(Instruction.BitfieldExtractU32, VariableType.U32, VariableType.U32, VariableType.S32, VariableType.S32); - Add(Instruction.BitfieldInsert, VariableType.Int, VariableType.Int, VariableType.Int, VariableType.S32, VariableType.S32); - Add(Instruction.BitfieldReverse, VariableType.Int, VariableType.Int); - Add(Instruction.BitwiseAnd, VariableType.Int, VariableType.Int, VariableType.Int); - Add(Instruction.BitwiseExclusiveOr, VariableType.Int, VariableType.Int, VariableType.Int); - Add(Instruction.BitwiseNot, VariableType.Int, VariableType.Int); - Add(Instruction.BitwiseOr, VariableType.Int, VariableType.Int, VariableType.Int); - Add(Instruction.BranchIfTrue, VariableType.None, VariableType.Bool); - Add(Instruction.BranchIfFalse, VariableType.None, VariableType.Bool); - Add(Instruction.Call, VariableType.Scalar); - Add(Instruction.Ceiling, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.Clamp, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.ClampU32, VariableType.U32, VariableType.U32, VariableType.U32, VariableType.U32); - Add(Instruction.CompareEqual, VariableType.Bool, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.CompareGreater, VariableType.Bool, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.CompareGreaterOrEqual, VariableType.Bool, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.CompareGreaterOrEqualU32, VariableType.Bool, VariableType.U32, VariableType.U32); - Add(Instruction.CompareGreaterU32, VariableType.Bool, VariableType.U32, VariableType.U32); - Add(Instruction.CompareLess, VariableType.Bool, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.CompareLessOrEqual, VariableType.Bool, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.CompareLessOrEqualU32, VariableType.Bool, VariableType.U32, VariableType.U32); - Add(Instruction.CompareLessU32, VariableType.Bool, VariableType.U32, VariableType.U32); - Add(Instruction.CompareNotEqual, VariableType.Bool, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.ConditionalSelect, VariableType.Scalar, VariableType.Bool, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.ConvertFP32ToFP64, VariableType.F64, VariableType.F32); - Add(Instruction.ConvertFP64ToFP32, VariableType.F32, VariableType.F64); - Add(Instruction.ConvertFP32ToS32, VariableType.S32, VariableType.F32); - Add(Instruction.ConvertFP32ToU32, VariableType.U32, VariableType.F32); - Add(Instruction.ConvertFP64ToS32, VariableType.S32, VariableType.F64); - Add(Instruction.ConvertFP64ToU32, VariableType.U32, VariableType.F64); - Add(Instruction.ConvertS32ToFP32, VariableType.F32, VariableType.S32); - Add(Instruction.ConvertS32ToFP64, VariableType.F64, VariableType.S32); - Add(Instruction.ConvertU32ToFP32, VariableType.F32, VariableType.U32); - Add(Instruction.ConvertU32ToFP64, VariableType.F64, VariableType.U32); - Add(Instruction.Cosine, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.Ddx, VariableType.F32, VariableType.F32); - Add(Instruction.Ddy, VariableType.F32, VariableType.F32); - Add(Instruction.Divide, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.ExponentB2, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.FindLSB, VariableType.Int, VariableType.Int); - Add(Instruction.FindMSBS32, VariableType.S32, VariableType.S32); - Add(Instruction.FindMSBU32, VariableType.S32, VariableType.U32); - Add(Instruction.Floor, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.FusedMultiplyAdd, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.ImageLoad, VariableType.F32); - Add(Instruction.ImageStore, VariableType.None); - Add(Instruction.ImageAtomic, VariableType.S32); - Add(Instruction.IsNan, VariableType.Bool, VariableType.Scalar); - Add(Instruction.LoadAttribute, VariableType.F32, VariableType.S32, VariableType.S32, VariableType.S32); - Add(Instruction.LoadConstant, VariableType.F32, VariableType.S32, VariableType.S32); - Add(Instruction.LoadGlobal, VariableType.U32, VariableType.S32, VariableType.S32); - Add(Instruction.LoadLocal, VariableType.U32, VariableType.S32); - Add(Instruction.LoadShared, VariableType.U32, VariableType.S32); - Add(Instruction.LoadStorage, VariableType.U32, VariableType.S32, VariableType.S32); - Add(Instruction.Lod, VariableType.F32); - Add(Instruction.LogarithmB2, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.LogicalAnd, VariableType.Bool, VariableType.Bool, VariableType.Bool); - Add(Instruction.LogicalExclusiveOr, VariableType.Bool, VariableType.Bool, VariableType.Bool); - Add(Instruction.LogicalNot, VariableType.Bool, VariableType.Bool); - Add(Instruction.LogicalOr, VariableType.Bool, VariableType.Bool, VariableType.Bool); - Add(Instruction.Maximum, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.MaximumU32, VariableType.U32, VariableType.U32, VariableType.U32); - Add(Instruction.Minimum, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.MinimumU32, VariableType.U32, VariableType.U32, VariableType.U32); - Add(Instruction.Multiply, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.MultiplyHighS32, VariableType.S32, VariableType.S32, VariableType.S32); - Add(Instruction.MultiplyHighU32, VariableType.U32, VariableType.U32, VariableType.U32); - Add(Instruction.Negate, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.PackDouble2x32, VariableType.F64, VariableType.U32, VariableType.U32); - Add(Instruction.PackHalf2x16, VariableType.U32, VariableType.F32, VariableType.F32); - Add(Instruction.ReciprocalSquareRoot, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.Round, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.ShiftLeft, VariableType.Int, VariableType.Int, VariableType.Int); - Add(Instruction.ShiftRightS32, VariableType.S32, VariableType.S32, VariableType.Int); - Add(Instruction.ShiftRightU32, VariableType.U32, VariableType.U32, VariableType.Int); - Add(Instruction.Shuffle, VariableType.F32, VariableType.F32, VariableType.U32, VariableType.U32, VariableType.Bool); - Add(Instruction.ShuffleDown, VariableType.F32, VariableType.F32, VariableType.U32, VariableType.U32, VariableType.Bool); - Add(Instruction.ShuffleUp, VariableType.F32, VariableType.F32, VariableType.U32, VariableType.U32, VariableType.Bool); - Add(Instruction.ShuffleXor, VariableType.F32, VariableType.F32, VariableType.U32, VariableType.U32, VariableType.Bool); - Add(Instruction.Sine, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.SquareRoot, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.StoreAttribute, VariableType.None, VariableType.S32, VariableType.S32, VariableType.F32); - Add(Instruction.StoreGlobal, VariableType.None, VariableType.S32, VariableType.S32, VariableType.U32); - Add(Instruction.StoreLocal, VariableType.None, VariableType.S32, VariableType.U32); - Add(Instruction.StoreShared, VariableType.None, VariableType.S32, VariableType.U32); - Add(Instruction.StoreShared16, VariableType.None, VariableType.S32, VariableType.U32); - Add(Instruction.StoreShared8, VariableType.None, VariableType.S32, VariableType.U32); - Add(Instruction.StoreStorage, VariableType.None, VariableType.S32, VariableType.S32, VariableType.U32); - Add(Instruction.StoreStorage16, VariableType.None, VariableType.S32, VariableType.S32, VariableType.U32); - Add(Instruction.StoreStorage8, VariableType.None, VariableType.S32, VariableType.S32, VariableType.U32); - Add(Instruction.Subtract, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.SwizzleAdd, VariableType.F32, VariableType.F32, VariableType.F32, VariableType.S32); - Add(Instruction.TextureSample, VariableType.F32); - Add(Instruction.TextureSize, VariableType.S32, VariableType.S32, VariableType.S32); - Add(Instruction.Truncate, VariableType.Scalar, VariableType.Scalar); - Add(Instruction.UnpackDouble2x32, VariableType.U32, VariableType.F64); - Add(Instruction.UnpackHalf2x16, VariableType.F32, VariableType.U32); - Add(Instruction.VoteAll, VariableType.Bool, VariableType.Bool); - Add(Instruction.VoteAllEqual, VariableType.Bool, VariableType.Bool); - Add(Instruction.VoteAny, VariableType.Bool, VariableType.Bool); + // Inst Destination type Source 1 type Source 2 type Source 3 type Source 4 type + Add(Instruction.AtomicAdd, AggregateType.U32, AggregateType.S32, AggregateType.S32, AggregateType.U32); + Add(Instruction.AtomicAnd, AggregateType.U32, AggregateType.S32, AggregateType.S32, AggregateType.U32); + Add(Instruction.AtomicCompareAndSwap, AggregateType.U32, AggregateType.S32, AggregateType.S32, AggregateType.U32, AggregateType.U32); + Add(Instruction.AtomicMaxS32, AggregateType.S32, AggregateType.S32, AggregateType.S32, AggregateType.S32); + Add(Instruction.AtomicMaxU32, AggregateType.U32, AggregateType.S32, AggregateType.S32, AggregateType.U32); + Add(Instruction.AtomicMinS32, AggregateType.S32, AggregateType.S32, AggregateType.S32, AggregateType.S32); + Add(Instruction.AtomicMinU32, AggregateType.U32, AggregateType.S32, AggregateType.S32, AggregateType.U32); + Add(Instruction.AtomicOr, AggregateType.U32, AggregateType.S32, AggregateType.S32, AggregateType.U32); + Add(Instruction.AtomicSwap, AggregateType.U32, AggregateType.S32, AggregateType.S32, AggregateType.U32); + Add(Instruction.AtomicXor, AggregateType.U32, AggregateType.S32, AggregateType.S32, AggregateType.U32); + Add(Instruction.Absolute, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.Add, AggregateType.Scalar, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.Ballot, AggregateType.U32, AggregateType.Bool); + Add(Instruction.BitCount, AggregateType.S32, AggregateType.S32); + Add(Instruction.BitfieldExtractS32, AggregateType.S32, AggregateType.S32, AggregateType.S32, AggregateType.S32); + Add(Instruction.BitfieldExtractU32, AggregateType.U32, AggregateType.U32, AggregateType.S32, AggregateType.S32); + Add(Instruction.BitfieldInsert, AggregateType.S32, AggregateType.S32, AggregateType.S32, AggregateType.S32, AggregateType.S32); + Add(Instruction.BitfieldReverse, AggregateType.S32, AggregateType.S32); + Add(Instruction.BitwiseAnd, AggregateType.S32, AggregateType.S32, AggregateType.S32); + Add(Instruction.BitwiseExclusiveOr, AggregateType.S32, AggregateType.S32, AggregateType.S32); + Add(Instruction.BitwiseNot, AggregateType.S32, AggregateType.S32); + Add(Instruction.BitwiseOr, AggregateType.S32, AggregateType.S32, AggregateType.S32); + Add(Instruction.BranchIfTrue, AggregateType.Void, AggregateType.Bool); + Add(Instruction.BranchIfFalse, AggregateType.Void, AggregateType.Bool); + Add(Instruction.Call, AggregateType.Scalar); + Add(Instruction.Ceiling, AggregateType.Scalar, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.Clamp, AggregateType.Scalar, AggregateType.Scalar, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.ClampU32, AggregateType.U32, AggregateType.U32, AggregateType.U32, AggregateType.U32); + Add(Instruction.CompareEqual, AggregateType.Bool, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.CompareGreater, AggregateType.Bool, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.CompareGreaterOrEqual, AggregateType.Bool, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.CompareGreaterOrEqualU32, AggregateType.Bool, AggregateType.U32, AggregateType.U32); + Add(Instruction.CompareGreaterU32, AggregateType.Bool, AggregateType.U32, AggregateType.U32); + Add(Instruction.CompareLess, AggregateType.Bool, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.CompareLessOrEqual, AggregateType.Bool, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.CompareLessOrEqualU32, AggregateType.Bool, AggregateType.U32, AggregateType.U32); + Add(Instruction.CompareLessU32, AggregateType.Bool, AggregateType.U32, AggregateType.U32); + Add(Instruction.CompareNotEqual, AggregateType.Bool, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.ConditionalSelect, AggregateType.Scalar, AggregateType.Bool, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.ConvertFP32ToFP64, AggregateType.FP64, AggregateType.FP32); + Add(Instruction.ConvertFP64ToFP32, AggregateType.FP32, AggregateType.FP64); + Add(Instruction.ConvertFP32ToS32, AggregateType.S32, AggregateType.FP32); + Add(Instruction.ConvertFP32ToU32, AggregateType.U32, AggregateType.FP32); + Add(Instruction.ConvertFP64ToS32, AggregateType.S32, AggregateType.FP64); + Add(Instruction.ConvertFP64ToU32, AggregateType.U32, AggregateType.FP64); + Add(Instruction.ConvertS32ToFP32, AggregateType.FP32, AggregateType.S32); + Add(Instruction.ConvertS32ToFP64, AggregateType.FP64, AggregateType.S32); + Add(Instruction.ConvertU32ToFP32, AggregateType.FP32, AggregateType.U32); + Add(Instruction.ConvertU32ToFP64, AggregateType.FP64, AggregateType.U32); + Add(Instruction.Cosine, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.Ddx, AggregateType.FP32, AggregateType.FP32); + Add(Instruction.Ddy, AggregateType.FP32, AggregateType.FP32); + Add(Instruction.Divide, AggregateType.Scalar, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.ExponentB2, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.FindLSB, AggregateType.S32, AggregateType.S32); + Add(Instruction.FindMSBS32, AggregateType.S32, AggregateType.S32); + Add(Instruction.FindMSBU32, AggregateType.S32, AggregateType.U32); + Add(Instruction.Floor, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.FusedMultiplyAdd, AggregateType.Scalar, AggregateType.Scalar, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.ImageLoad, AggregateType.FP32); + Add(Instruction.ImageStore, AggregateType.Void); + Add(Instruction.ImageAtomic, AggregateType.S32); + Add(Instruction.IsNan, AggregateType.Bool, AggregateType.Scalar); + Add(Instruction.LoadAttribute, AggregateType.FP32, AggregateType.S32, AggregateType.S32, AggregateType.S32); + Add(Instruction.LoadConstant, AggregateType.FP32, AggregateType.S32, AggregateType.S32); + Add(Instruction.LoadGlobal, AggregateType.U32, AggregateType.S32, AggregateType.S32); + Add(Instruction.LoadLocal, AggregateType.U32, AggregateType.S32); + Add(Instruction.LoadShared, AggregateType.U32, AggregateType.S32); + Add(Instruction.LoadStorage, AggregateType.U32, AggregateType.S32, AggregateType.S32); + Add(Instruction.Lod, AggregateType.FP32); + Add(Instruction.LogarithmB2, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.LogicalAnd, AggregateType.Bool, AggregateType.Bool, AggregateType.Bool); + Add(Instruction.LogicalExclusiveOr, AggregateType.Bool, AggregateType.Bool, AggregateType.Bool); + Add(Instruction.LogicalNot, AggregateType.Bool, AggregateType.Bool); + Add(Instruction.LogicalOr, AggregateType.Bool, AggregateType.Bool, AggregateType.Bool); + Add(Instruction.Maximum, AggregateType.Scalar, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.MaximumU32, AggregateType.U32, AggregateType.U32, AggregateType.U32); + Add(Instruction.Minimum, AggregateType.Scalar, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.MinimumU32, AggregateType.U32, AggregateType.U32, AggregateType.U32); + Add(Instruction.Multiply, AggregateType.Scalar, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.MultiplyHighS32, AggregateType.S32, AggregateType.S32, AggregateType.S32); + Add(Instruction.MultiplyHighU32, AggregateType.U32, AggregateType.U32, AggregateType.U32); + Add(Instruction.Negate, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.PackDouble2x32, AggregateType.FP64, AggregateType.U32, AggregateType.U32); + Add(Instruction.PackHalf2x16, AggregateType.U32, AggregateType.FP32, AggregateType.FP32); + Add(Instruction.ReciprocalSquareRoot, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.Round, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.ShiftLeft, AggregateType.S32, AggregateType.S32, AggregateType.S32); + Add(Instruction.ShiftRightS32, AggregateType.S32, AggregateType.S32, AggregateType.S32); + Add(Instruction.ShiftRightU32, AggregateType.U32, AggregateType.U32, AggregateType.S32); + Add(Instruction.Shuffle, AggregateType.FP32, AggregateType.FP32, AggregateType.U32, AggregateType.U32, AggregateType.Bool); + Add(Instruction.ShuffleDown, AggregateType.FP32, AggregateType.FP32, AggregateType.U32, AggregateType.U32, AggregateType.Bool); + Add(Instruction.ShuffleUp, AggregateType.FP32, AggregateType.FP32, AggregateType.U32, AggregateType.U32, AggregateType.Bool); + Add(Instruction.ShuffleXor, AggregateType.FP32, AggregateType.FP32, AggregateType.U32, AggregateType.U32, AggregateType.Bool); + Add(Instruction.Sine, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.SquareRoot, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.StoreAttribute, AggregateType.Void, AggregateType.S32, AggregateType.S32, AggregateType.FP32); + Add(Instruction.StoreGlobal, AggregateType.Void, AggregateType.S32, AggregateType.S32, AggregateType.U32); + Add(Instruction.StoreLocal, AggregateType.Void, AggregateType.S32, AggregateType.U32); + Add(Instruction.StoreShared, AggregateType.Void, AggregateType.S32, AggregateType.U32); + Add(Instruction.StoreShared16, AggregateType.Void, AggregateType.S32, AggregateType.U32); + Add(Instruction.StoreShared8, AggregateType.Void, AggregateType.S32, AggregateType.U32); + Add(Instruction.StoreStorage, AggregateType.Void, AggregateType.S32, AggregateType.S32, AggregateType.U32); + Add(Instruction.StoreStorage16, AggregateType.Void, AggregateType.S32, AggregateType.S32, AggregateType.U32); + Add(Instruction.StoreStorage8, AggregateType.Void, AggregateType.S32, AggregateType.S32, AggregateType.U32); + Add(Instruction.Subtract, AggregateType.Scalar, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.SwizzleAdd, AggregateType.FP32, AggregateType.FP32, AggregateType.FP32, AggregateType.S32); + Add(Instruction.TextureSample, AggregateType.FP32); + Add(Instruction.TextureSize, AggregateType.S32, AggregateType.S32, AggregateType.S32); + Add(Instruction.Truncate, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.UnpackDouble2x32, AggregateType.U32, AggregateType.FP64); + Add(Instruction.UnpackHalf2x16, AggregateType.FP32, AggregateType.U32); + Add(Instruction.VectorExtract, AggregateType.Scalar, AggregateType.Vector4, AggregateType.S32); + Add(Instruction.VoteAll, AggregateType.Bool, AggregateType.Bool); + Add(Instruction.VoteAllEqual, AggregateType.Bool, AggregateType.Bool); + Add(Instruction.VoteAny, AggregateType.Bool, AggregateType.Bool); } - private static void Add(Instruction inst, VariableType destType, params VariableType[] srcTypes) + private static void Add(Instruction inst, AggregateType destType, params AggregateType[] srcTypes) { _infoTbl[(int)inst] = new InstInfo(destType, srcTypes); } - public static VariableType GetDestVarType(Instruction inst) + public static AggregateType GetDestVarType(Instruction inst) { return GetFinalVarType(_infoTbl[(int)(inst & Instruction.Mask)].DestType, inst); } - public static VariableType GetSrcVarType(Instruction inst, int index) + public static AggregateType GetSrcVarType(Instruction inst, int index) { // TODO: Return correct type depending on source index, // that can improve the decompiler output. - if (inst == Instruction.ImageLoad || - inst == Instruction.ImageStore || + if (inst == Instruction.ImageLoad || + inst == Instruction.ImageStore || inst == Instruction.ImageAtomic || - inst == Instruction.Lod || + inst == Instruction.Lod || inst == Instruction.TextureSample) { - return VariableType.F32; + return AggregateType.FP32; } else if (inst == Instruction.Call) { - return VariableType.S32; + return AggregateType.S32; } return GetFinalVarType(_infoTbl[(int)(inst & Instruction.Mask)].SrcTypes[index], inst); } - private static VariableType GetFinalVarType(VariableType type, Instruction inst) + private static AggregateType GetFinalVarType(AggregateType type, Instruction inst) { - if (type == VariableType.Scalar) + if (type == AggregateType.Scalar) { if ((inst & Instruction.FP32) != 0) { - return VariableType.F32; + return AggregateType.FP32; } else if ((inst & Instruction.FP64) != 0) { - return VariableType.F64; + return AggregateType.FP64; } else { - return VariableType.S32; + return AggregateType.S32; } } - else if (type == VariableType.Int) - { - return VariableType.S32; - } - else if (type == VariableType.None) + else if (type == AggregateType.Void) { throw new ArgumentException($"Invalid operand for instruction \"{inst}\"."); } diff --git a/Ryujinx.Graphics.Shader/StructuredIr/OperandInfo.cs b/Ryujinx.Graphics.Shader/StructuredIr/OperandInfo.cs index 34428815..730468a4 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/OperandInfo.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/OperandInfo.cs @@ -1,11 +1,12 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using Ryujinx.Graphics.Shader.Translation; using System; namespace Ryujinx.Graphics.Shader.StructuredIr { static class OperandInfo { - public static VariableType GetVarType(AstOperand operand) + public static AggregateType GetVarType(AstOperand operand) { if (operand.Type == OperandType.LocalVariable) { @@ -17,16 +18,16 @@ namespace Ryujinx.Graphics.Shader.StructuredIr } } - public static VariableType GetVarType(OperandType type) + public static AggregateType GetVarType(OperandType type) { return type switch { - OperandType.Argument => VariableType.S32, - OperandType.Attribute => VariableType.F32, - OperandType.AttributePerPatch => VariableType.F32, - OperandType.Constant => VariableType.S32, - OperandType.ConstantBuffer => VariableType.F32, - OperandType.Undefined => VariableType.S32, + OperandType.Argument => AggregateType.S32, + OperandType.Attribute => AggregateType.FP32, + OperandType.AttributePerPatch => AggregateType.FP32, + OperandType.Constant => AggregateType.S32, + OperandType.ConstantBuffer => AggregateType.FP32, + OperandType.Undefined => AggregateType.S32, _ => throw new ArgumentException($"Invalid operand type \"{type}\".") }; } diff --git a/Ryujinx.Graphics.Shader/StructuredIr/StructuredFunction.cs b/Ryujinx.Graphics.Shader/StructuredIr/StructuredFunction.cs index 3723f259..61c4fed7 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/StructuredFunction.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/StructuredFunction.cs @@ -1,3 +1,4 @@ +using Ryujinx.Graphics.Shader.Translation; using System.Collections.Generic; namespace Ryujinx.Graphics.Shader.StructuredIr @@ -8,19 +9,19 @@ namespace Ryujinx.Graphics.Shader.StructuredIr public string Name { get; } - public VariableType ReturnType { get; } + public AggregateType ReturnType { get; } - public VariableType[] InArguments { get; } - public VariableType[] OutArguments { get; } + public AggregateType[] InArguments { get; } + public AggregateType[] OutArguments { get; } public HashSet Locals { get; } public StructuredFunction( AstBlock mainBlock, string name, - VariableType returnType, - VariableType[] inArguments, - VariableType[] outArguments) + AggregateType returnType, + AggregateType[] inArguments, + AggregateType[] outArguments) { MainBlock = mainBlock; Name = name; @@ -31,7 +32,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr Locals = new HashSet(); } - public VariableType GetArgumentType(int index) + public AggregateType GetArgumentType(int index) { return index >= InArguments.Length ? OutArguments[index - InArguments.Length] diff --git a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs index 7678a4bf..ec989cca 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs @@ -2,6 +2,7 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.Translation; using System; using System.Collections.Generic; +using System.Numerics; namespace Ryujinx.Graphics.Shader.StructuredIr { @@ -17,19 +18,19 @@ namespace Ryujinx.Graphics.Shader.StructuredIr BasicBlock[] blocks = function.Blocks; - VariableType returnType = function.ReturnsValue ? VariableType.S32 : VariableType.None; + AggregateType returnType = function.ReturnsValue ? AggregateType.S32 : AggregateType.Void; - VariableType[] inArguments = new VariableType[function.InArgumentsCount]; - VariableType[] outArguments = new VariableType[function.OutArgumentsCount]; + AggregateType[] inArguments = new AggregateType[function.InArgumentsCount]; + AggregateType[] outArguments = new AggregateType[function.OutArgumentsCount]; for (int i = 0; i < inArguments.Length; i++) { - inArguments[i] = VariableType.S32; + inArguments[i] = AggregateType.S32; } for (int i = 0; i < outArguments.Length; i++) { - outArguments[i] = VariableType.S32; + outArguments[i] = AggregateType.S32; } context.EnterFunction(blocks.Length, function.Name, returnType, inArguments, outArguments); @@ -109,8 +110,10 @@ namespace Ryujinx.Graphics.Shader.StructuredIr } } + bool vectorDest = IsVectorDestInst(inst); + int sourcesCount = operation.SourcesCount; - int outDestsCount = operation.DestsCount != 0 ? operation.DestsCount - 1 : 0; + int outDestsCount = operation.DestsCount != 0 && !vectorDest ? operation.DestsCount - 1 : 0; IAstNode[] sources = new IAstNode[sourcesCount + outDestsCount]; @@ -141,7 +144,52 @@ namespace Ryujinx.Graphics.Shader.StructuredIr sources); } - if (operation.Dest != null) + int componentsCount = BitOperations.PopCount((uint)operation.Index); + + if (vectorDest && componentsCount > 1) + { + AggregateType destType = InstructionInfo.GetDestVarType(inst); + + IAstNode source; + + if (operation is TextureOperation texOp) + { + if (texOp.Inst == Instruction.ImageLoad) + { + destType = texOp.Format.GetComponentType(); + } + + source = GetAstTextureOperation(texOp); + } + else + { + source = new AstOperation(inst, operation.Index, sources, operation.SourcesCount); + } + + AggregateType destElemType = destType; + + switch (componentsCount) + { + case 2: destType |= AggregateType.Vector2; break; + case 3: destType |= AggregateType.Vector3; break; + case 4: destType |= AggregateType.Vector4; break; + } + + AstOperand destVec = context.NewTemp(destType); + + context.AddNode(new AstAssignment(destVec, source)); + + for (int i = 0; i < operation.DestsCount; i++) + { + AstOperand dest = context.GetOperandDef(operation.GetDest(i)); + AstOperand index = new AstOperand(OperandType.Constant, i); + + dest.VarType = destElemType; + + context.AddNode(new AstAssignment(dest, new AstOperation(Instruction.VectorExtract, new[] { destVec, index }, 2))); + } + } + else if (operation.Dest != null) { AstOperand dest = context.GetOperandDef(operation.Dest); @@ -149,7 +197,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr // logical operations, rather than forcing a cast to int and doing // a bitwise operation with the value, as it is likely to be used as // a bool in the end. - if (IsBitwiseInst(inst) && AreAllSourceTypesEqual(sources, VariableType.Bool)) + if (IsBitwiseInst(inst) && AreAllSourceTypesEqual(sources, AggregateType.Bool)) { inst = GetLogicalFromBitwiseInst(inst); } @@ -159,9 +207,9 @@ namespace Ryujinx.Graphics.Shader.StructuredIr if (isCondSel || isCopy) { - VariableType type = GetVarTypeFromUses(operation.Dest); + AggregateType type = GetVarTypeFromUses(operation.Dest); - if (isCondSel && type == VariableType.F32) + if (isCondSel && type == AggregateType.FP32) { inst |= Instruction.FP32; } @@ -259,7 +307,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr } } - private static VariableType GetVarTypeFromUses(Operand dest) + private static AggregateType GetVarTypeFromUses(Operand dest) { HashSet visited = new HashSet(); @@ -315,10 +363,10 @@ namespace Ryujinx.Graphics.Shader.StructuredIr } } - return VariableType.S32; + return AggregateType.S32; } - private static bool AreAllSourceTypesEqual(IAstNode[] sources, VariableType type) + private static bool AreAllSourceTypesEqual(IAstNode[] sources, AggregateType type) { foreach (IAstNode node in sources) { @@ -336,6 +384,16 @@ namespace Ryujinx.Graphics.Shader.StructuredIr return true; } + private static bool IsVectorDestInst(Instruction inst) + { + return inst switch + { + Instruction.ImageLoad or + Instruction.TextureSample => true, + _ => false + }; + } + private static bool IsBranchInst(Instruction inst) { return inst switch diff --git a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs index e9f8467d..ce57a578 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs @@ -80,9 +80,9 @@ namespace Ryujinx.Graphics.Shader.StructuredIr public void EnterFunction( int blocksCount, string name, - VariableType returnType, - VariableType[] inArguments, - VariableType[] outArguments) + AggregateType returnType, + AggregateType[] inArguments, + AggregateType[] outArguments) { _loopTails = new HashSet(); @@ -218,7 +218,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr return gotoTempAsg; } - AstOperand gotoTemp = NewTemp(VariableType.Bool); + AstOperand gotoTemp = NewTemp(AggregateType.Bool); gotoTempAsg = Assign(gotoTemp, Const(IrConsts.False)); @@ -306,7 +306,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr return _gotos.ToArray(); } - private AstOperand NewTemp(VariableType type) + public AstOperand NewTemp(AggregateType type) { AstOperand newTemp = Local(type); diff --git a/Ryujinx.Graphics.Shader/StructuredIr/VariableType.cs b/Ryujinx.Graphics.Shader/StructuredIr/VariableType.cs deleted file mode 100644 index 0afafb2b..00000000 --- a/Ryujinx.Graphics.Shader/StructuredIr/VariableType.cs +++ /dev/null @@ -1,14 +0,0 @@ -namespace Ryujinx.Graphics.Shader.StructuredIr -{ - enum VariableType - { - None, - Bool, - Scalar, - Int, - F32, - F64, - S32, - U32 - } -} \ No newline at end of file diff --git a/Ryujinx.Graphics.Shader/TextureFormat.cs b/Ryujinx.Graphics.Shader/TextureFormat.cs index ecdd15cb..d4c8b96b 100644 --- a/Ryujinx.Graphics.Shader/TextureFormat.cs +++ b/Ryujinx.Graphics.Shader/TextureFormat.cs @@ -1,4 +1,4 @@ -using Ryujinx.Graphics.Shader.StructuredIr; +using Ryujinx.Graphics.Shader.Translation; namespace Ryujinx.Graphics.Shader { @@ -95,7 +95,7 @@ namespace Ryujinx.Graphics.Shader }; } - public static VariableType GetComponentType(this TextureFormat format) + public static AggregateType GetComponentType(this TextureFormat format) { switch (format) { @@ -109,7 +109,7 @@ namespace Ryujinx.Graphics.Shader case TextureFormat.R16G16B16A16Uint: case TextureFormat.R32G32B32A32Uint: case TextureFormat.R10G10B10A2Uint: - return VariableType.U32; + return AggregateType.U32; case TextureFormat.R8Sint: case TextureFormat.R16Sint: case TextureFormat.R32Sint: @@ -119,10 +119,10 @@ namespace Ryujinx.Graphics.Shader case TextureFormat.R8G8B8A8Sint: case TextureFormat.R16G16B16A16Sint: case TextureFormat.R32G32B32A32Sint: - return VariableType.S32; + return AggregateType.S32; } - return VariableType.F32; + return AggregateType.FP32; } } } diff --git a/Ryujinx.Graphics.Shader/Translation/AggregateType.cs b/Ryujinx.Graphics.Shader/Translation/AggregateType.cs index dcd1e0bd..24993e00 100644 --- a/Ryujinx.Graphics.Shader/Translation/AggregateType.cs +++ b/Ryujinx.Graphics.Shader/Translation/AggregateType.cs @@ -12,7 +12,14 @@ ElementTypeMask = 0xff, - Vector = 1 << 8, - Array = 1 << 9 + ElementCountShift = 8, + ElementCountMask = 3 << ElementCountShift, + + Scalar = 0 << ElementCountShift, + Vector2 = 1 << ElementCountShift, + Vector3 = 2 << ElementCountShift, + Vector4 = 3 << ElementCountShift, + + Array = 1 << 10 } } diff --git a/Ryujinx.Graphics.Shader/Translation/AttributeInfo.cs b/Ryujinx.Graphics.Shader/Translation/AttributeInfo.cs index 1647f656..fa1ae17f 100644 --- a/Ryujinx.Graphics.Shader/Translation/AttributeInfo.cs +++ b/Ryujinx.Graphics.Shader/Translation/AttributeInfo.cs @@ -9,22 +9,22 @@ namespace Ryujinx.Graphics.Shader.Translation { AttributeConsts.Layer, new AttributeInfo(AttributeConsts.Layer, 0, 1, AggregateType.S32) }, { AttributeConsts.ViewportIndex, new AttributeInfo(AttributeConsts.ViewportIndex, 0, 1, AggregateType.S32) }, { AttributeConsts.PointSize, new AttributeInfo(AttributeConsts.PointSize, 0, 1, AggregateType.FP32) }, - { AttributeConsts.PositionX, new AttributeInfo(AttributeConsts.PositionX, 0, 4, AggregateType.Vector | AggregateType.FP32) }, - { AttributeConsts.PositionY, new AttributeInfo(AttributeConsts.PositionX, 1, 4, AggregateType.Vector | AggregateType.FP32) }, - { AttributeConsts.PositionZ, new AttributeInfo(AttributeConsts.PositionX, 2, 4, AggregateType.Vector | AggregateType.FP32) }, - { AttributeConsts.PositionW, new AttributeInfo(AttributeConsts.PositionX, 3, 4, AggregateType.Vector | AggregateType.FP32) }, - { AttributeConsts.ClipDistance0, new AttributeInfo(AttributeConsts.ClipDistance0, 0, 8, AggregateType.Array | AggregateType.FP32) }, - { AttributeConsts.ClipDistance1, new AttributeInfo(AttributeConsts.ClipDistance0, 1, 8, AggregateType.Array | AggregateType.FP32) }, - { AttributeConsts.ClipDistance2, new AttributeInfo(AttributeConsts.ClipDistance0, 2, 8, AggregateType.Array | AggregateType.FP32) }, - { AttributeConsts.ClipDistance3, new AttributeInfo(AttributeConsts.ClipDistance0, 3, 8, AggregateType.Array | AggregateType.FP32) }, - { AttributeConsts.ClipDistance4, new AttributeInfo(AttributeConsts.ClipDistance0, 4, 8, AggregateType.Array | AggregateType.FP32) }, - { AttributeConsts.ClipDistance5, new AttributeInfo(AttributeConsts.ClipDistance0, 5, 8, AggregateType.Array | AggregateType.FP32) }, - { AttributeConsts.ClipDistance6, new AttributeInfo(AttributeConsts.ClipDistance0, 6, 8, AggregateType.Array | AggregateType.FP32) }, - { AttributeConsts.ClipDistance7, new AttributeInfo(AttributeConsts.ClipDistance0, 7, 8, AggregateType.Array | AggregateType.FP32) }, - { AttributeConsts.PointCoordX, new AttributeInfo(AttributeConsts.PointCoordX, 0, 2, AggregateType.Vector | AggregateType.FP32) }, - { AttributeConsts.PointCoordY, new AttributeInfo(AttributeConsts.PointCoordX, 1, 2, AggregateType.Vector | AggregateType.FP32) }, - { AttributeConsts.TessCoordX, new AttributeInfo(AttributeConsts.TessCoordX, 0, 3, AggregateType.Vector | AggregateType.FP32) }, - { AttributeConsts.TessCoordY, new AttributeInfo(AttributeConsts.TessCoordX, 1, 3, AggregateType.Vector | AggregateType.FP32) }, + { AttributeConsts.PositionX, new AttributeInfo(AttributeConsts.PositionX, 0, 4, AggregateType.Vector4 | AggregateType.FP32) }, + { AttributeConsts.PositionY, new AttributeInfo(AttributeConsts.PositionX, 1, 4, AggregateType.Vector4 | AggregateType.FP32) }, + { AttributeConsts.PositionZ, new AttributeInfo(AttributeConsts.PositionX, 2, 4, AggregateType.Vector4 | AggregateType.FP32) }, + { AttributeConsts.PositionW, new AttributeInfo(AttributeConsts.PositionX, 3, 4, AggregateType.Vector4 | AggregateType.FP32) }, + { AttributeConsts.ClipDistance0, new AttributeInfo(AttributeConsts.ClipDistance0, 0, 8, AggregateType.Array | AggregateType.FP32) }, + { AttributeConsts.ClipDistance1, new AttributeInfo(AttributeConsts.ClipDistance0, 1, 8, AggregateType.Array | AggregateType.FP32) }, + { AttributeConsts.ClipDistance2, new AttributeInfo(AttributeConsts.ClipDistance0, 2, 8, AggregateType.Array | AggregateType.FP32) }, + { AttributeConsts.ClipDistance3, new AttributeInfo(AttributeConsts.ClipDistance0, 3, 8, AggregateType.Array | AggregateType.FP32) }, + { AttributeConsts.ClipDistance4, new AttributeInfo(AttributeConsts.ClipDistance0, 4, 8, AggregateType.Array | AggregateType.FP32) }, + { AttributeConsts.ClipDistance5, new AttributeInfo(AttributeConsts.ClipDistance0, 5, 8, AggregateType.Array | AggregateType.FP32) }, + { AttributeConsts.ClipDistance6, new AttributeInfo(AttributeConsts.ClipDistance0, 6, 8, AggregateType.Array | AggregateType.FP32) }, + { AttributeConsts.ClipDistance7, new AttributeInfo(AttributeConsts.ClipDistance0, 7, 8, AggregateType.Array | AggregateType.FP32) }, + { AttributeConsts.PointCoordX, new AttributeInfo(AttributeConsts.PointCoordX, 0, 2, AggregateType.Vector4 | AggregateType.FP32) }, + { AttributeConsts.PointCoordY, new AttributeInfo(AttributeConsts.PointCoordX, 1, 2, AggregateType.Vector4 | AggregateType.FP32) }, + { AttributeConsts.TessCoordX, new AttributeInfo(AttributeConsts.TessCoordX, 0, 3, AggregateType.Vector4 | AggregateType.FP32) }, + { AttributeConsts.TessCoordY, new AttributeInfo(AttributeConsts.TessCoordX, 1, 3, AggregateType.Vector4 | AggregateType.FP32) }, { AttributeConsts.InstanceId, new AttributeInfo(AttributeConsts.InstanceId, 0, 1, AggregateType.S32) }, { AttributeConsts.VertexId, new AttributeInfo(AttributeConsts.VertexId, 0, 1, AggregateType.S32) }, { AttributeConsts.BaseInstance, new AttributeInfo(AttributeConsts.BaseInstance, 0, 1, AggregateType.S32) }, @@ -37,21 +37,21 @@ namespace Ryujinx.Graphics.Shader.Translation // Special. { AttributeConsts.FragmentOutputDepth, new AttributeInfo(AttributeConsts.FragmentOutputDepth, 0, 1, AggregateType.FP32) }, { AttributeConsts.ThreadKill, new AttributeInfo(AttributeConsts.ThreadKill, 0, 1, AggregateType.Bool) }, - { AttributeConsts.ThreadIdX, new AttributeInfo(AttributeConsts.ThreadIdX, 0, 3, AggregateType.Vector | AggregateType.U32) }, - { AttributeConsts.ThreadIdY, new AttributeInfo(AttributeConsts.ThreadIdX, 1, 3, AggregateType.Vector | AggregateType.U32) }, - { AttributeConsts.ThreadIdZ, new AttributeInfo(AttributeConsts.ThreadIdX, 2, 3, AggregateType.Vector | AggregateType.U32) }, - { AttributeConsts.CtaIdX, new AttributeInfo(AttributeConsts.CtaIdX, 0, 3, AggregateType.Vector | AggregateType.U32) }, - { AttributeConsts.CtaIdY, new AttributeInfo(AttributeConsts.CtaIdX, 1, 3, AggregateType.Vector | AggregateType.U32) }, - { AttributeConsts.CtaIdZ, new AttributeInfo(AttributeConsts.CtaIdX, 2, 3, AggregateType.Vector | AggregateType.U32) }, + { AttributeConsts.ThreadIdX, new AttributeInfo(AttributeConsts.ThreadIdX, 0, 3, AggregateType.Vector3 | AggregateType.U32) }, + { AttributeConsts.ThreadIdY, new AttributeInfo(AttributeConsts.ThreadIdX, 1, 3, AggregateType.Vector3 | AggregateType.U32) }, + { AttributeConsts.ThreadIdZ, new AttributeInfo(AttributeConsts.ThreadIdX, 2, 3, AggregateType.Vector3 | AggregateType.U32) }, + { AttributeConsts.CtaIdX, new AttributeInfo(AttributeConsts.CtaIdX, 0, 3, AggregateType.Vector3 | AggregateType.U32) }, + { AttributeConsts.CtaIdY, new AttributeInfo(AttributeConsts.CtaIdX, 1, 3, AggregateType.Vector3 | AggregateType.U32) }, + { AttributeConsts.CtaIdZ, new AttributeInfo(AttributeConsts.CtaIdX, 2, 3, AggregateType.Vector3 | AggregateType.U32) }, { AttributeConsts.LaneId, new AttributeInfo(AttributeConsts.LaneId, 0, 1, AggregateType.U32) }, { AttributeConsts.InvocationId, new AttributeInfo(AttributeConsts.InvocationId, 0, 1, AggregateType.S32) }, { AttributeConsts.PrimitiveId, new AttributeInfo(AttributeConsts.PrimitiveId, 0, 1, AggregateType.S32) }, { AttributeConsts.PatchVerticesIn, new AttributeInfo(AttributeConsts.PatchVerticesIn, 0, 1, AggregateType.S32) }, - { AttributeConsts.EqMask, new AttributeInfo(AttributeConsts.EqMask, 0, 4, AggregateType.Vector | AggregateType.U32) }, - { AttributeConsts.GeMask, new AttributeInfo(AttributeConsts.GeMask, 0, 4, AggregateType.Vector | AggregateType.U32) }, - { AttributeConsts.GtMask, new AttributeInfo(AttributeConsts.GtMask, 0, 4, AggregateType.Vector | AggregateType.U32) }, - { AttributeConsts.LeMask, new AttributeInfo(AttributeConsts.LeMask, 0, 4, AggregateType.Vector | AggregateType.U32) }, - { AttributeConsts.LtMask, new AttributeInfo(AttributeConsts.LtMask, 0, 4, AggregateType.Vector | AggregateType.U32) }, + { AttributeConsts.EqMask, new AttributeInfo(AttributeConsts.EqMask, 0, 4, AggregateType.Vector4 | AggregateType.U32) }, + { AttributeConsts.GeMask, new AttributeInfo(AttributeConsts.GeMask, 0, 4, AggregateType.Vector4 | AggregateType.U32) }, + { AttributeConsts.GtMask, new AttributeInfo(AttributeConsts.GtMask, 0, 4, AggregateType.Vector4 | AggregateType.U32) }, + { AttributeConsts.LeMask, new AttributeInfo(AttributeConsts.LeMask, 0, 4, AggregateType.Vector4 | AggregateType.U32) }, + { AttributeConsts.LtMask, new AttributeInfo(AttributeConsts.LtMask, 0, 4, AggregateType.Vector4 | AggregateType.U32) }, }; private static readonly Dictionary _builtInAttributesPerPatch = new Dictionary() @@ -124,11 +124,11 @@ namespace Ryujinx.Graphics.Shader.Translation elemType = AggregateType.FP32; } - return new AttributeInfo(value & ~0xf, (value >> 2) & 3, 4, AggregateType.Vector | elemType, false); + return new AttributeInfo(value & ~0xf, (value >> 2) & 3, 4, AggregateType.Vector4 | elemType, false); } else if (value >= AttributeConsts.FragmentOutputColorBase && value < AttributeConsts.FragmentOutputColorEnd) { - return new AttributeInfo(value & ~0xf, (value >> 2) & 3, 4, AggregateType.Vector | AggregateType.FP32, false); + return new AttributeInfo(value & ~0xf, (value >> 2) & 3, 4, AggregateType.Vector4 | AggregateType.FP32, false); } else if (value == AttributeConsts.SupportBlockViewInverseX || value == AttributeConsts.SupportBlockViewInverseY) { @@ -149,7 +149,7 @@ namespace Ryujinx.Graphics.Shader.Translation if (value >= AttributeConsts.UserAttributePerPatchBase && value < AttributeConsts.UserAttributePerPatchEnd) { int offset = (value - AttributeConsts.UserAttributePerPatchBase) & 0xf; - return new AttributeInfo(value - offset, offset >> 2, 4, AggregateType.Vector | AggregateType.FP32, false); + return new AttributeInfo(value - offset, offset >> 2, 4, AggregateType.Vector4 | AggregateType.FP32, false); } else if (_builtInAttributesPerPatch.TryGetValue(value, out AttributeInfo info)) { diff --git a/Ryujinx.Graphics.Shader/Translation/EmitterContext.cs b/Ryujinx.Graphics.Shader/Translation/EmitterContext.cs index 7961ada8..ad55c010 100644 --- a/Ryujinx.Graphics.Shader/Translation/EmitterContext.cs +++ b/Ryujinx.Graphics.Shader/Translation/EmitterContext.cs @@ -109,10 +109,10 @@ namespace Ryujinx.Graphics.Shader.Translation TextureFlags flags, int handle, int compIndex, - Operand dest, + Operand[] dests, params Operand[] sources) { - return CreateTextureOperation(inst, type, TextureFormat.Unknown, flags, handle, compIndex, dest, sources); + return CreateTextureOperation(inst, type, TextureFormat.Unknown, flags, handle, compIndex, dests, sources); } public TextureOperation CreateTextureOperation( @@ -122,7 +122,7 @@ namespace Ryujinx.Graphics.Shader.Translation TextureFlags flags, int handle, int compIndex, - Operand dest, + Operand[] dests, params Operand[] sources) { if (!flags.HasFlag(TextureFlags.Bindless)) @@ -130,7 +130,7 @@ namespace Ryujinx.Graphics.Shader.Translation Config.SetUsedTexture(inst, type, format, flags, TextureOperation.DefaultCbufSlot, handle); } - return new TextureOperation(inst, type, format, flags, handle, compIndex, dest, sources); + return new TextureOperation(inst, type, format, flags, handle, compIndex, dests, sources); } public void FlagAttributeRead(int attribute) diff --git a/Ryujinx.Graphics.Shader/Translation/Rewriter.cs b/Ryujinx.Graphics.Shader/Translation/Rewriter.cs index 0c3c4a57..3ec4e49a 100644 --- a/Ryujinx.Graphics.Shader/Translation/Rewriter.cs +++ b/Ryujinx.Graphics.Shader/Translation/Rewriter.cs @@ -385,15 +385,6 @@ namespace Ryujinx.Graphics.Shader.Translation int componentIndex = texOp.Index; - Operand Int(Operand value) - { - Operand res = Local(); - - node.List.AddBefore(node, new Operation(Instruction.ConvertFP32ToS32, res, value)); - - return res; - } - Operand Float(Operand value) { Operand res = Local(); @@ -436,7 +427,7 @@ namespace Ryujinx.Graphics.Shader.Translation texOp.CbufSlot, texOp.Handle, index, - coordSize, + new[] { coordSize }, texSizeSources)); config.SetUsedTexture(Instruction.TextureSize, texOp.Type, texOp.Format, texOp.Flags, texOp.CbufSlot, texOp.Handle); @@ -451,80 +442,53 @@ namespace Ryujinx.Graphics.Shader.Translation } } + Operand[] dests = new Operand[texOp.DestsCount]; + + for (int i = 0; i < texOp.DestsCount; i++) + { + dests[i] = texOp.GetDest(i); + } + + Operand bindlessHandle = isBindless || isIndexed ? sources[0] : null; + + LinkedListNode oldNode = node; + // Technically, non-constant texture offsets are not allowed (according to the spec), // however some GPUs does support that. // For GPUs where it is not supported, we can replace the instruction with the following: // For texture*Offset, we replace it by texture*, and add the offset to the P coords. // The offset can be calculated as offset / textureSize(lod), where lod = textureQueryLod(coords). // For texelFetchOffset, we replace it by texelFetch and add the offset to the P coords directly. - // For textureGatherOffset, we take advantage of the fact that the operation is already broken down - // to read the 4 pixels separately, and just replace it with 4 textureGather with a different offset - // for each pixel. - if (hasInvalidOffset) + // For textureGatherOffset, we split the operation into up to 4 operations, one for each component + // that is accessed, where each textureGather operation has a different offset for each pixel. + if (hasInvalidOffset && isGather && !isShadow) { - if (intCoords) + config.SetUsedFeature(FeatureFlags.IntegerSampling); + + Operand[] newSources = new Operand[sources.Length]; + + sources.CopyTo(newSources, 0); + + Operand[] texSizes = InsertTextureSize(node, texOp, lodSources, bindlessHandle, coordsCount); + + int destIndex = 0; + + for (int compIndex = 0; compIndex < 4; compIndex++) { - for (int index = 0; index < coordsCount; index++) + if (((texOp.Index >> compIndex) & 1) == 0) { - Operand source = sources[coordsIndex + index]; - - Operand coordPlusOffset = Local(); - - node.List.AddBefore(node, new Operation(Instruction.Add, coordPlusOffset, source, offsets[index])); - - sources[coordsIndex + index] = coordPlusOffset; + continue; } - } - else - { - config.SetUsedFeature(FeatureFlags.IntegerSampling); - - Operand lod = Local(); - - node.List.AddBefore(node, new TextureOperation( - Instruction.Lod, - texOp.Type, - texOp.Format, - texOp.Flags, - texOp.CbufSlot, - texOp.Handle, - 0, - lod, - lodSources)); for (int index = 0; index < coordsCount; index++) { - Operand coordSize = Local(); - - Operand[] texSizeSources; - - if (isBindless || isIndexed) - { - texSizeSources = new Operand[] { sources[0], Int(lod) }; - } - else - { - texSizeSources = new Operand[] { Int(lod) }; - } - - node.List.AddBefore(node, new TextureOperation( - Instruction.TextureSize, - texOp.Type, - texOp.Format, - texOp.Flags, - texOp.CbufSlot, - texOp.Handle, - index, - coordSize, - texSizeSources)); - config.SetUsedTexture(Instruction.TextureSize, texOp.Type, texOp.Format, texOp.Flags, texOp.CbufSlot, texOp.Handle); Operand offset = Local(); - Operand intOffset = offsets[index + (hasOffsets ? texOp.Index * coordsCount : 0)]; + Operand intOffset = offsets[index + (hasOffsets ? compIndex * coordsCount : 0)]; - node.List.AddBefore(node, new Operation(Instruction.FP32 | Instruction.Divide, offset, Float(intOffset), Float(coordSize))); + node.List.AddBefore(node, new Operation(Instruction.FP32 | Instruction.Divide, offset, Float(intOffset), Float(texSizes[index]))); Operand source = sources[coordsIndex + index]; @@ -532,45 +496,152 @@ namespace Ryujinx.Graphics.Shader.Translation node.List.AddBefore(node, new Operation(Instruction.FP32 | Instruction.Add, coordPlusOffset, source, offset)); - sources[coordsIndex + index] = coordPlusOffset; + newSources[coordsIndex + index] = coordPlusOffset; + } + + TextureOperation newTexOp = new TextureOperation( + Instruction.TextureSample, + texOp.Type, + texOp.Format, + texOp.Flags & ~(TextureFlags.Offset | TextureFlags.Offsets), + texOp.CbufSlot, + texOp.Handle, + 1, + new[] { dests[destIndex++] }, + newSources); + + node = node.List.AddBefore(node, newTexOp); + } + } + else + { + if (hasInvalidOffset) + { + if (intCoords) + { + for (int index = 0; index < coordsCount; index++) + { + Operand source = sources[coordsIndex + index]; + + Operand coordPlusOffset = Local(); + + node.List.AddBefore(node, new Operation(Instruction.Add, coordPlusOffset, source, offsets[index])); + + sources[coordsIndex + index] = coordPlusOffset; + } + } + else + { + config.SetUsedFeature(FeatureFlags.IntegerSampling); + + Operand[] texSizes = InsertTextureSize(node, texOp, lodSources, bindlessHandle, coordsCount); + + for (int index = 0; index < coordsCount; index++) + { + config.SetUsedTexture(Instruction.TextureSize, texOp.Type, texOp.Format, texOp.Flags, texOp.CbufSlot, texOp.Handle); + + Operand offset = Local(); + + Operand intOffset = offsets[index]; + + node.List.AddBefore(node, new Operation(Instruction.FP32 | Instruction.Divide, offset, Float(intOffset), Float(texSizes[index]))); + + Operand source = sources[coordsIndex + index]; + + Operand coordPlusOffset = Local(); + + node.List.AddBefore(node, new Operation(Instruction.FP32 | Instruction.Add, coordPlusOffset, source, offset)); + + sources[coordsIndex + index] = coordPlusOffset; + } } } - if (isGather && !isShadow) - { - Operand gatherComponent = sources[dstIndex - 1]; + TextureOperation newTexOp = new TextureOperation( + Instruction.TextureSample, + texOp.Type, + texOp.Format, + texOp.Flags & ~(TextureFlags.Offset | TextureFlags.Offsets), + texOp.CbufSlot, + texOp.Handle, + componentIndex, + dests, + sources); - Debug.Assert(gatherComponent.Type == OperandType.Constant); - - componentIndex = gatherComponent.Value; - } + node = node.List.AddBefore(node, newTexOp); } - TextureOperation newTexOp = new TextureOperation( - Instruction.TextureSample, - texOp.Type, - texOp.Format, - texOp.Flags & ~(TextureFlags.Offset | TextureFlags.Offsets), - texOp.CbufSlot, - texOp.Handle, - componentIndex, - texOp.Dest, - sources); + node.List.Remove(oldNode); for (int index = 0; index < texOp.SourcesCount; index++) { texOp.SetSource(index, null); } - LinkedListNode oldNode = node; - - node = node.List.AddBefore(node, newTexOp); - - node.List.Remove(oldNode); - return node; } + private static Operand[] InsertTextureSize( + LinkedListNode node, + TextureOperation texOp, + Operand[] lodSources, + Operand bindlessHandle, + int coordsCount) + { + Operand Int(Operand value) + { + Operand res = Local(); + + node.List.AddBefore(node, new Operation(Instruction.ConvertFP32ToS32, res, value)); + + return res; + } + + Operand[] texSizes = new Operand[coordsCount]; + + Operand lod = Local(); + + node.List.AddBefore(node, new TextureOperation( + Instruction.Lod, + texOp.Type, + texOp.Format, + texOp.Flags, + texOp.CbufSlot, + texOp.Handle, + 0, + new[] { lod }, + lodSources)); + + for (int index = 0; index < coordsCount; index++) + { + texSizes[index] = Local(); + + Operand[] texSizeSources; + + if (bindlessHandle != null) + { + texSizeSources = new Operand[] { bindlessHandle, Int(lod) }; + } + else + { + texSizeSources = new Operand[] { Int(lod) }; + } + + node.List.AddBefore(node, new TextureOperation( + Instruction.TextureSize, + texOp.Type, + texOp.Format, + texOp.Flags, + texOp.CbufSlot, + texOp.Handle, + index, + new[] { texSizes[index] }, + texSizeSources)); + } + + return texSizes; + } + private static LinkedListNode InsertSnormNormalization(LinkedListNode node, ShaderConfig config) { TextureOperation texOp = (TextureOperation)node.Value; @@ -604,27 +675,32 @@ namespace Ryujinx.Graphics.Shader.Translation // Do normalization. We assume SINT formats are being used // as replacement for SNORM (which is not supported). - INode[] uses = texOp.Dest.UseOps.ToArray(); - - Operation convOp = new Operation(Instruction.ConvertS32ToFP32, Local(), texOp.Dest); - Operation normOp = new Operation(Instruction.FP32 | Instruction.Multiply, Local(), convOp.Dest, ConstF(1f / maxPositive)); - - node = node.List.AddAfter(node, convOp); - node = node.List.AddAfter(node, normOp); - - foreach (INode useOp in uses) + for (int i = 0; i < texOp.DestsCount; i++) { - if (useOp is not Operation op) - { - continue; - } + Operand dest = texOp.GetDest(i); - // Replace all uses of the texture pixel value with the normalized value. - for (int index = 0; index < op.SourcesCount; index++) + INode[] uses = dest.UseOps.ToArray(); + + Operation convOp = new Operation(Instruction.ConvertS32ToFP32, Local(), dest); + Operation normOp = new Operation(Instruction.FP32 | Instruction.Multiply, Local(), convOp.Dest, ConstF(1f / maxPositive)); + + node = node.List.AddAfter(node, convOp); + node = node.List.AddAfter(node, normOp); + + foreach (INode useOp in uses) { - if (op.GetSource(index) == texOp.Dest) + if (useOp is not Operation op) { - op.SetSource(index, normOp.Dest); + continue; + } + + // Replace all uses of the texture pixel value with the normalized value. + for (int index = 0; index < op.SourcesCount; index++) + { + if (op.GetSource(index) == dest) + { + op.SetSource(index, normOp.Dest); + } } } }