diff --git a/externals/sirit b/externals/sirit
index f7c4b07a7..e1a6729df 160000
--- a/externals/sirit
+++ b/externals/sirit
@@ -1 +1 @@
-Subproject commit f7c4b07a7e14edb1dcd93bc9879c823423705c2e
+Subproject commit e1a6729df7f11e33f6dc0939b18995a57c8bf3d8
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
index 76894275b..8f517bdc1 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
@@ -3,8 +3,10 @@
 // Refer to the license.txt file included.
 
 #include <functional>
+#include <limits>
 #include <map>
-#include <set>
+#include <type_traits>
+#include <utility>
 
 #include <fmt/format.h>
 
@@ -23,7 +25,9 @@
 #include "video_core/shader/node.h"
 #include "video_core/shader/shader_ir.h"
 
-namespace Vulkan::VKShader {
+namespace Vulkan {
+
+namespace {
 
 using Sirit::Id;
 using Tegra::Engines::ShaderType;
@@ -35,22 +39,60 @@ using namespace VideoCommon::Shader;
 using Maxwell = Tegra::Engines::Maxwell3D::Regs;
 using Operation = const OperationNode&;
 
+class ASTDecompiler;
+class ExprDecompiler;
+
 // TODO(Rodrigo): Use rasterizer's value
-constexpr u32 MAX_CONSTBUFFER_FLOATS = 0x4000;
-constexpr u32 MAX_CONSTBUFFER_ELEMENTS = MAX_CONSTBUFFER_FLOATS / 4;
-constexpr u32 STAGE_BINDING_STRIDE = 0x100;
+constexpr u32 MaxConstBufferFloats = 0x4000;
+constexpr u32 MaxConstBufferElements = MaxConstBufferFloats / 4;
 
-enum class Type { Bool, Bool2, Float, Int, Uint, HalfFloat };
+constexpr u32 NumInputPatches = 32; // This value seems to be the standard
 
-struct SamplerImage {
-    Id image_type;
-    Id sampled_image_type;
-    Id sampler;
+enum class Type { Void, Bool, Bool2, Float, Int, Uint, HalfFloat };
+
+class Expression final {
+public:
+    Expression(Id id, Type type) : id{id}, type{type} {
+        ASSERT(type != Type::Void);
+    }
+    Expression() : type{Type::Void} {}
+
+    Id id{};
+    Type type{};
+};
+static_assert(std::is_standard_layout_v<Expression>);
+
+struct TexelBuffer {
+    Id image_type{};
+    Id image{};
 };
 
-namespace {
+struct SampledImage {
+    Id image_type{};
+    Id sampled_image_type{};
+    Id sampler{};
+};
+
+struct StorageImage {
+    Id image_type{};
+    Id image{};
+};
+
+struct AttributeType {
+    Type type;
+    Id scalar;
+    Id vector;
+};
+
+struct VertexIndices {
+    std::optional<u32> position;
+    std::optional<u32> viewport;
+    std::optional<u32> point_size;
+    std::optional<u32> clip_distances;
+};
 
 spv::Dim GetSamplerDim(const Sampler& sampler) {
+    ASSERT(!sampler.IsBuffer());
     switch (sampler.GetType()) {
     case Tegra::Shader::TextureType::Texture1D:
         return spv::Dim::Dim1D;
@@ -66,6 +108,138 @@ spv::Dim GetSamplerDim(const Sampler& sampler) {
     }
 }
 
+std::pair<spv::Dim, bool> GetImageDim(const Image& image) {
+    switch (image.GetType()) {
+    case Tegra::Shader::ImageType::Texture1D:
+        return {spv::Dim::Dim1D, false};
+    case Tegra::Shader::ImageType::TextureBuffer:
+        return {spv::Dim::Buffer, false};
+    case Tegra::Shader::ImageType::Texture1DArray:
+        return {spv::Dim::Dim1D, true};
+    case Tegra::Shader::ImageType::Texture2D:
+        return {spv::Dim::Dim2D, false};
+    case Tegra::Shader::ImageType::Texture2DArray:
+        return {spv::Dim::Dim2D, true};
+    case Tegra::Shader::ImageType::Texture3D:
+        return {spv::Dim::Dim3D, false};
+    default:
+        UNIMPLEMENTED_MSG("Unimplemented image type={}", static_cast<u32>(image.GetType()));
+        return {spv::Dim::Dim2D, false};
+    }
+}
+
+/// Returns the number of vertices present in a primitive topology.
+u32 GetNumPrimitiveTopologyVertices(Maxwell::PrimitiveTopology primitive_topology) {
+    switch (primitive_topology) {
+    case Maxwell::PrimitiveTopology::Points:
+        return 1;
+    case Maxwell::PrimitiveTopology::Lines:
+    case Maxwell::PrimitiveTopology::LineLoop:
+    case Maxwell::PrimitiveTopology::LineStrip:
+        return 2;
+    case Maxwell::PrimitiveTopology::Triangles:
+    case Maxwell::PrimitiveTopology::TriangleStrip:
+    case Maxwell::PrimitiveTopology::TriangleFan:
+        return 3;
+    case Maxwell::PrimitiveTopology::LinesAdjacency:
+    case Maxwell::PrimitiveTopology::LineStripAdjacency:
+        return 4;
+    case Maxwell::PrimitiveTopology::TrianglesAdjacency:
+    case Maxwell::PrimitiveTopology::TriangleStripAdjacency:
+        return 6;
+    case Maxwell::PrimitiveTopology::Quads:
+        UNIMPLEMENTED_MSG("Quads");
+        return 3;
+    case Maxwell::PrimitiveTopology::QuadStrip:
+        UNIMPLEMENTED_MSG("QuadStrip");
+        return 3;
+    case Maxwell::PrimitiveTopology::Polygon:
+        UNIMPLEMENTED_MSG("Polygon");
+        return 3;
+    case Maxwell::PrimitiveTopology::Patches:
+        UNIMPLEMENTED_MSG("Patches");
+        return 3;
+    default:
+        UNREACHABLE();
+        return 3;
+    }
+}
+
+spv::ExecutionMode GetExecutionMode(Maxwell::TessellationPrimitive primitive) {
+    switch (primitive) {
+    case Maxwell::TessellationPrimitive::Isolines:
+        return spv::ExecutionMode::Isolines;
+    case Maxwell::TessellationPrimitive::Triangles:
+        return spv::ExecutionMode::Triangles;
+    case Maxwell::TessellationPrimitive::Quads:
+        return spv::ExecutionMode::Quads;
+    }
+    UNREACHABLE();
+    return spv::ExecutionMode::Triangles;
+}
+
+spv::ExecutionMode GetExecutionMode(Maxwell::TessellationSpacing spacing) {
+    switch (spacing) {
+    case Maxwell::TessellationSpacing::Equal:
+        return spv::ExecutionMode::SpacingEqual;
+    case Maxwell::TessellationSpacing::FractionalOdd:
+        return spv::ExecutionMode::SpacingFractionalOdd;
+    case Maxwell::TessellationSpacing::FractionalEven:
+        return spv::ExecutionMode::SpacingFractionalEven;
+    }
+    UNREACHABLE();
+    return spv::ExecutionMode::SpacingEqual;
+}
+
+spv::ExecutionMode GetExecutionMode(Maxwell::PrimitiveTopology input_topology) {
+    switch (input_topology) {
+    case Maxwell::PrimitiveTopology::Points:
+        return spv::ExecutionMode::InputPoints;
+    case Maxwell::PrimitiveTopology::Lines:
+    case Maxwell::PrimitiveTopology::LineLoop:
+    case Maxwell::PrimitiveTopology::LineStrip:
+        return spv::ExecutionMode::InputLines;
+    case Maxwell::PrimitiveTopology::Triangles:
+    case Maxwell::PrimitiveTopology::TriangleStrip:
+    case Maxwell::PrimitiveTopology::TriangleFan:
+        return spv::ExecutionMode::Triangles;
+    case Maxwell::PrimitiveTopology::LinesAdjacency:
+    case Maxwell::PrimitiveTopology::LineStripAdjacency:
+        return spv::ExecutionMode::InputLinesAdjacency;
+    case Maxwell::PrimitiveTopology::TrianglesAdjacency:
+    case Maxwell::PrimitiveTopology::TriangleStripAdjacency:
+        return spv::ExecutionMode::InputTrianglesAdjacency;
+    case Maxwell::PrimitiveTopology::Quads:
+        UNIMPLEMENTED_MSG("Quads");
+        return spv::ExecutionMode::Triangles;
+    case Maxwell::PrimitiveTopology::QuadStrip:
+        UNIMPLEMENTED_MSG("QuadStrip");
+        return spv::ExecutionMode::Triangles;
+    case Maxwell::PrimitiveTopology::Polygon:
+        UNIMPLEMENTED_MSG("Polygon");
+        return spv::ExecutionMode::Triangles;
+    case Maxwell::PrimitiveTopology::Patches:
+        UNIMPLEMENTED_MSG("Patches");
+        return spv::ExecutionMode::Triangles;
+    }
+    UNREACHABLE();
+    return spv::ExecutionMode::Triangles;
+}
+
+spv::ExecutionMode GetExecutionMode(Tegra::Shader::OutputTopology output_topology) {
+    switch (output_topology) {
+    case Tegra::Shader::OutputTopology::PointList:
+        return spv::ExecutionMode::OutputPoints;
+    case Tegra::Shader::OutputTopology::LineStrip:
+        return spv::ExecutionMode::OutputLineStrip;
+    case Tegra::Shader::OutputTopology::TriangleStrip:
+        return spv::ExecutionMode::OutputTriangleStrip;
+    default:
+        UNREACHABLE();
+        return spv::ExecutionMode::OutputPoints;
+    }
+}
+
 /// Returns true if an attribute index is one of the 32 generic attributes
 constexpr bool IsGenericAttribute(Attribute::Index attribute) {
     return attribute >= Attribute::Index::Attribute_0 &&
@@ -73,7 +247,7 @@ constexpr bool IsGenericAttribute(Attribute::Index attribute) {
 }
 
 /// Returns the location of a generic attribute
-constexpr u32 GetGenericAttributeLocation(Attribute::Index attribute) {
+u32 GetGenericAttributeLocation(Attribute::Index attribute) {
     ASSERT(IsGenericAttribute(attribute));
     return static_cast<u32>(attribute) - static_cast<u32>(Attribute::Index::Attribute_0);
 }
@@ -87,20 +261,146 @@ bool IsPrecise(Operation operand) {
     return false;
 }
 
-} // namespace
-
-class ASTDecompiler;
-class ExprDecompiler;
-
-class SPIRVDecompiler : public Sirit::Module {
+class SPIRVDecompiler final : public Sirit::Module {
 public:
-    explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderType stage)
-        : Module(0x00010300), device{device}, ir{ir}, stage{stage}, header{ir.GetHeader()} {
+    explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderType stage,
+                             const Specialization& specialization)
+        : Module(0x00010300), device{device}, ir{ir}, stage{stage}, header{ir.GetHeader()},
+          specialization{specialization} {
         AddCapability(spv::Capability::Shader);
+        AddCapability(spv::Capability::UniformAndStorageBuffer16BitAccess);
+        AddCapability(spv::Capability::ImageQuery);
+        AddCapability(spv::Capability::Image1D);
+        AddCapability(spv::Capability::ImageBuffer);
+        AddCapability(spv::Capability::ImageGatherExtended);
+        AddCapability(spv::Capability::SampledBuffer);
+        AddCapability(spv::Capability::StorageImageWriteWithoutFormat);
+        AddCapability(spv::Capability::SubgroupBallotKHR);
+        AddCapability(spv::Capability::SubgroupVoteKHR);
+        AddExtension("SPV_KHR_shader_ballot");
+        AddExtension("SPV_KHR_subgroup_vote");
         AddExtension("SPV_KHR_storage_buffer_storage_class");
         AddExtension("SPV_KHR_variable_pointers");
+
+        if (ir.UsesViewportIndex()) {
+            AddCapability(spv::Capability::MultiViewport);
+            if (device.IsExtShaderViewportIndexLayerSupported()) {
+                AddExtension("SPV_EXT_shader_viewport_index_layer");
+                AddCapability(spv::Capability::ShaderViewportIndexLayerEXT);
+            }
+        }
+
+        if (device.IsFloat16Supported()) {
+            AddCapability(spv::Capability::Float16);
+        }
+        t_scalar_half = Name(TypeFloat(device.IsFloat16Supported() ? 16 : 32), "scalar_half");
+        t_half = Name(TypeVector(t_scalar_half, 2), "half");
+
+        const Id main = Decompile();
+
+        switch (stage) {
+        case ShaderType::Vertex:
+            AddEntryPoint(spv::ExecutionModel::Vertex, main, "main", interfaces);
+            break;
+        case ShaderType::TesselationControl:
+            AddCapability(spv::Capability::Tessellation);
+            AddEntryPoint(spv::ExecutionModel::TessellationControl, main, "main", interfaces);
+            AddExecutionMode(main, spv::ExecutionMode::OutputVertices,
+                             header.common2.threads_per_input_primitive);
+            break;
+        case ShaderType::TesselationEval:
+            AddCapability(spv::Capability::Tessellation);
+            AddEntryPoint(spv::ExecutionModel::TessellationEvaluation, main, "main", interfaces);
+            AddExecutionMode(main, GetExecutionMode(specialization.tessellation.primitive));
+            AddExecutionMode(main, GetExecutionMode(specialization.tessellation.spacing));
+            AddExecutionMode(main, specialization.tessellation.clockwise
+                                       ? spv::ExecutionMode::VertexOrderCw
+                                       : spv::ExecutionMode::VertexOrderCcw);
+            break;
+        case ShaderType::Geometry:
+            AddCapability(spv::Capability::Geometry);
+            AddEntryPoint(spv::ExecutionModel::Geometry, main, "main", interfaces);
+            AddExecutionMode(main, GetExecutionMode(specialization.primitive_topology));
+            AddExecutionMode(main, GetExecutionMode(header.common3.output_topology));
+            AddExecutionMode(main, spv::ExecutionMode::OutputVertices,
+                             header.common4.max_output_vertices);
+            // TODO(Rodrigo): Where can we get this info from?
+            AddExecutionMode(main, spv::ExecutionMode::Invocations, 1U);
+            break;
+        case ShaderType::Fragment:
+            AddEntryPoint(spv::ExecutionModel::Fragment, main, "main", interfaces);
+            AddExecutionMode(main, spv::ExecutionMode::OriginUpperLeft);
+            if (header.ps.omap.depth) {
+                AddExecutionMode(main, spv::ExecutionMode::DepthReplacing);
+            }
+            break;
+        case ShaderType::Compute:
+            const auto workgroup_size = specialization.workgroup_size;
+            AddExecutionMode(main, spv::ExecutionMode::LocalSize, workgroup_size[0],
+                             workgroup_size[1], workgroup_size[2]);
+            AddEntryPoint(spv::ExecutionModel::GLCompute, main, "main", interfaces);
+            break;
+        }
     }
 
+private:
+    Id Decompile() {
+        DeclareCommon();
+        DeclareVertex();
+        DeclareTessControl();
+        DeclareTessEval();
+        DeclareGeometry();
+        DeclareFragment();
+        DeclareCompute();
+        DeclareRegisters();
+        DeclarePredicates();
+        DeclareLocalMemory();
+        DeclareSharedMemory();
+        DeclareInternalFlags();
+        DeclareInputAttributes();
+        DeclareOutputAttributes();
+
+        u32 binding = specialization.base_binding;
+        binding = DeclareConstantBuffers(binding);
+        binding = DeclareGlobalBuffers(binding);
+        binding = DeclareTexelBuffers(binding);
+        binding = DeclareSamplers(binding);
+        binding = DeclareImages(binding);
+
+        const Id main = OpFunction(t_void, {}, TypeFunction(t_void));
+        AddLabel();
+
+        if (ir.IsDecompiled()) {
+            DeclareFlowVariables();
+            DecompileAST();
+        } else {
+            AllocateLabels();
+            DecompileBranchMode();
+        }
+
+        OpReturn();
+        OpFunctionEnd();
+
+        return main;
+    }
+
+    void DefinePrologue() {
+        if (stage == ShaderType::Vertex) {
+            // Clear Position to avoid reading trash on the Z conversion.
+            const auto position_index = out_indices.position.value();
+            const Id position = AccessElement(t_out_float4, out_vertex, position_index);
+            OpStore(position, v_varying_default);
+
+            if (specialization.point_size) {
+                const u32 point_size_index = out_indices.point_size.value();
+                const Id out_point_size = AccessElement(t_out_float, out_vertex, point_size_index);
+                OpStore(out_point_size, Constant(t_float, *specialization.point_size));
+            }
+        }
+    }
+
+    void DecompileAST();
+
     void DecompileBranchMode() {
         const u32 first_address = ir.GetBasicBlocks().begin()->first;
         const Id loop_label = OpLabel("loop");
@@ -111,14 +411,15 @@ public:
 
         std::vector<Sirit::Literal> literals;
         std::vector<Id> branch_labels;
-        for (const auto& pair : labels) {
-            const auto [literal, label] = pair;
+        for (const auto& [literal, label] : labels) {
             literals.push_back(literal);
             branch_labels.push_back(label);
         }
 
-        jmp_to = Emit(OpVariable(TypePointer(spv::StorageClass::Function, t_uint),
-                                 spv::StorageClass::Function, Constant(t_uint, first_address)));
+        jmp_to = OpVariable(TypePointer(spv::StorageClass::Function, t_uint),
+                            spv::StorageClass::Function, Constant(t_uint, first_address));
+        AddLocalVariable(jmp_to);
+
         std::tie(ssy_flow_stack, ssy_flow_stack_top) = CreateFlowStack();
         std::tie(pbk_flow_stack, pbk_flow_stack_top) = CreateFlowStack();
 
@@ -128,102 +429,37 @@ public:
         Name(pbk_flow_stack, "pbk_flow_stack");
         Name(pbk_flow_stack_top, "pbk_flow_stack_top");
 
-        Emit(OpBranch(loop_label));
-        Emit(loop_label);
-        Emit(OpLoopMerge(merge_label, continue_label, spv::LoopControlMask::Unroll));
-        Emit(OpBranch(dummy_label));
+        DefinePrologue();
 
-        Emit(dummy_label);
+        OpBranch(loop_label);
+        AddLabel(loop_label);
+        OpLoopMerge(merge_label, continue_label, spv::LoopControlMask::MaskNone);
+        OpBranch(dummy_label);
+
+        AddLabel(dummy_label);
         const Id default_branch = OpLabel();
-        const Id jmp_to_load = Emit(OpLoad(t_uint, jmp_to));
-        Emit(OpSelectionMerge(jump_label, spv::SelectionControlMask::MaskNone));
-        Emit(OpSwitch(jmp_to_load, default_branch, literals, branch_labels));
+        const Id jmp_to_load = OpLoad(t_uint, jmp_to);
+        OpSelectionMerge(jump_label, spv::SelectionControlMask::MaskNone);
+        OpSwitch(jmp_to_load, default_branch, literals, branch_labels);
 
-        Emit(default_branch);
-        Emit(OpReturn());
+        AddLabel(default_branch);
+        OpReturn();
 
-        for (const auto& pair : ir.GetBasicBlocks()) {
-            const auto& [address, bb] = pair;
-            Emit(labels.at(address));
+        for (const auto& [address, bb] : ir.GetBasicBlocks()) {
+            AddLabel(labels.at(address));
 
             VisitBasicBlock(bb);
 
             const auto next_it = labels.lower_bound(address + 1);
             const Id next_label = next_it != labels.end() ? next_it->second : default_branch;
-            Emit(OpBranch(next_label));
+            OpBranch(next_label);
         }
 
-        Emit(jump_label);
-        Emit(OpBranch(continue_label));
-        Emit(continue_label);
-        Emit(OpBranch(loop_label));
-        Emit(merge_label);
-    }
-
-    void DecompileAST();
-
-    void Decompile() {
-        const bool is_fully_decompiled = ir.IsDecompiled();
-        AllocateBindings();
-        if (!is_fully_decompiled) {
-            AllocateLabels();
-        }
-
-        DeclareVertex();
-        DeclareGeometry();
-        DeclareFragment();
-        DeclareRegisters();
-        DeclarePredicates();
-        if (is_fully_decompiled) {
-            DeclareFlowVariables();
-        }
-        DeclareLocalMemory();
-        DeclareInternalFlags();
-        DeclareInputAttributes();
-        DeclareOutputAttributes();
-        DeclareConstantBuffers();
-        DeclareGlobalBuffers();
-        DeclareSamplers();
-
-        execute_function =
-            Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
-        Emit(OpLabel());
-
-        if (is_fully_decompiled) {
-            DecompileAST();
-        } else {
-            DecompileBranchMode();
-        }
-
-        Emit(OpReturn());
-        Emit(OpFunctionEnd());
-    }
-
-    ShaderEntries GetShaderEntries() const {
-        ShaderEntries entries;
-        entries.const_buffers_base_binding = const_buffers_base_binding;
-        entries.global_buffers_base_binding = global_buffers_base_binding;
-        entries.samplers_base_binding = samplers_base_binding;
-        for (const auto& cbuf : ir.GetConstantBuffers()) {
-            entries.const_buffers.emplace_back(cbuf.second, cbuf.first);
-        }
-        for (const auto& gmem_pair : ir.GetGlobalMemory()) {
-            const auto& [base, usage] = gmem_pair;
-            entries.global_buffers.emplace_back(base.cbuf_index, base.cbuf_offset);
-        }
-        for (const auto& sampler : ir.GetSamplers()) {
-            entries.samplers.emplace_back(sampler);
-        }
-        for (const auto& attribute : ir.GetInputAttributes()) {
-            if (IsGenericAttribute(attribute)) {
-                entries.attributes.insert(GetGenericAttributeLocation(attribute));
-            }
-        }
-        entries.clip_distances = ir.GetClipDistances();
-        entries.shader_length = ir.GetLength();
-        entries.entry_function = execute_function;
-        entries.interfaces = interfaces;
-        return entries;
+        AddLabel(jump_label);
+        OpBranch(continue_label);
+        AddLabel(continue_label);
+        OpBranch(loop_label);
+        AddLabel(merge_label);
     }
 
 private:
@@ -232,23 +468,6 @@ private:
 
     static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
 
-    void AllocateBindings() {
-        const u32 binding_base = static_cast<u32>(stage) * STAGE_BINDING_STRIDE;
-        u32 binding_iterator = binding_base;
-
-        const auto Allocate = [&binding_iterator](std::size_t count) {
-            const u32 current_binding = binding_iterator;
-            binding_iterator += static_cast<u32>(count);
-            return current_binding;
-        };
-        const_buffers_base_binding = Allocate(ir.GetConstantBuffers().size());
-        global_buffers_base_binding = Allocate(ir.GetGlobalMemory().size());
-        samplers_base_binding = Allocate(ir.GetSamplers().size());
-
-        ASSERT_MSG(binding_iterator - binding_base < STAGE_BINDING_STRIDE,
-                   "Stage binding stride is too small");
-    }
-
     void AllocateLabels() {
         for (const auto& pair : ir.GetBasicBlocks()) {
             const u32 address = pair.first;
@@ -256,23 +475,72 @@ private:
         }
     }
 
-    void DeclareVertex() {
-        if (stage != ShaderType::Vertex)
-            return;
+    void DeclareCommon() {
+        thread_id =
+            DeclareInputBuiltIn(spv::BuiltIn::SubgroupLocalInvocationId, t_in_uint, "thread_id");
+    }
 
-        DeclareVertexRedeclarations();
+    void DeclareVertex() {
+        if (stage != ShaderType::Vertex) {
+            return;
+        }
+        Id out_vertex_struct;
+        std::tie(out_vertex_struct, out_indices) = DeclareVertexStruct();
+        const Id vertex_ptr = TypePointer(spv::StorageClass::Output, out_vertex_struct);
+        out_vertex = OpVariable(vertex_ptr, spv::StorageClass::Output);
+        interfaces.push_back(AddGlobalVariable(Name(out_vertex, "out_vertex")));
+
+        // Declare input attributes
+        vertex_index = DeclareInputBuiltIn(spv::BuiltIn::VertexIndex, t_in_uint, "vertex_index");
+        instance_index =
+            DeclareInputBuiltIn(spv::BuiltIn::InstanceIndex, t_in_uint, "instance_index");
+    }
+
+    void DeclareTessControl() {
+        if (stage != ShaderType::TesselationControl) {
+            return;
+        }
+        DeclareInputVertexArray(NumInputPatches);
+        DeclareOutputVertexArray(header.common2.threads_per_input_primitive);
+
+        tess_level_outer = DeclareBuiltIn(
+            spv::BuiltIn::TessLevelOuter, spv::StorageClass::Output,
+            TypePointer(spv::StorageClass::Output, TypeArray(t_float, Constant(t_uint, 4U))),
+            "tess_level_outer");
+        Decorate(tess_level_outer, spv::Decoration::Patch);
+
+        tess_level_inner = DeclareBuiltIn(
+            spv::BuiltIn::TessLevelInner, spv::StorageClass::Output,
+            TypePointer(spv::StorageClass::Output, TypeArray(t_float, Constant(t_uint, 2U))),
+            "tess_level_inner");
+        Decorate(tess_level_inner, spv::Decoration::Patch);
+
+        invocation_id = DeclareInputBuiltIn(spv::BuiltIn::InvocationId, t_in_int, "invocation_id");
+    }
+
+    void DeclareTessEval() {
+        if (stage != ShaderType::TesselationEval) {
+            return;
+        }
+        DeclareInputVertexArray(NumInputPatches);
+        DeclareOutputVertex();
+
+        tess_coord = DeclareInputBuiltIn(spv::BuiltIn::TessCoord, t_in_float3, "tess_coord");
     }
 
     void DeclareGeometry() {
-        if (stage != ShaderType::Geometry)
+        if (stage != ShaderType::Geometry) {
             return;
-
-        UNIMPLEMENTED();
+        }
+        const u32 num_input = GetNumPrimitiveTopologyVertices(specialization.primitive_topology);
+        DeclareInputVertexArray(num_input);
+        DeclareOutputVertex();
     }
 
     void DeclareFragment() {
-        if (stage != ShaderType::Fragment)
+        if (stage != ShaderType::Fragment) {
             return;
+        }
 
         for (u32 rt = 0; rt < static_cast<u32>(frag_colors.size()); ++rt) {
             if (!IsRenderTargetUsed(rt)) {
@@ -296,10 +564,19 @@ private:
             interfaces.push_back(frag_depth);
         }
 
-        frag_coord = DeclareBuiltIn(spv::BuiltIn::FragCoord, spv::StorageClass::Input, t_in_float4,
-                                    "frag_coord");
-        front_facing = DeclareBuiltIn(spv::BuiltIn::FrontFacing, spv::StorageClass::Input,
-                                      t_in_bool, "front_facing");
+        frag_coord = DeclareInputBuiltIn(spv::BuiltIn::FragCoord, t_in_float4, "frag_coord");
+        front_facing = DeclareInputBuiltIn(spv::BuiltIn::FrontFacing, t_in_bool, "front_facing");
+        point_coord = DeclareInputBuiltIn(spv::BuiltIn::PointCoord, t_in_float2, "point_coord");
+    }
+
+    void DeclareCompute() {
+        if (stage != ShaderType::Compute) {
+            return;
+        }
+
+        workgroup_id = DeclareInputBuiltIn(spv::BuiltIn::WorkgroupId, t_in_uint3, "workgroup_id");
+        local_invocation_id =
+            DeclareInputBuiltIn(spv::BuiltIn::LocalInvocationId, t_in_uint3, "local_invocation_id");
     }
 
     void DeclareRegisters() {
@@ -327,21 +604,44 @@ private:
     }
 
     void DeclareLocalMemory() {
-        if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) {
-            const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4);
-            const Id type_array = TypeArray(t_float, Constant(t_uint, element_count));
-            const Id type_pointer = TypePointer(spv::StorageClass::Private, type_array);
-            Name(type_pointer, "LocalMemory");
-
-            local_memory =
-                OpVariable(type_pointer, spv::StorageClass::Private, ConstantNull(type_array));
-            AddGlobalVariable(Name(local_memory, "local_memory"));
+        // TODO(Rodrigo): Unstub kernel local memory size and pass it from a register at
+        // specialization time.
+        const u64 lmem_size = stage == ShaderType::Compute ? 0x400 : header.GetLocalMemorySize();
+        if (lmem_size == 0) {
+            return;
         }
+        const auto element_count = static_cast<u32>(Common::AlignUp(lmem_size, 4) / 4);
+        const Id type_array = TypeArray(t_float, Constant(t_uint, element_count));
+        const Id type_pointer = TypePointer(spv::StorageClass::Private, type_array);
+        Name(type_pointer, "LocalMemory");
+
+        local_memory =
+            OpVariable(type_pointer, spv::StorageClass::Private, ConstantNull(type_array));
+        AddGlobalVariable(Name(local_memory, "local_memory"));
+    }
+
+    void DeclareSharedMemory() {
+        if (stage != ShaderType::Compute) {
+            return;
+        }
+        t_smem_uint = TypePointer(spv::StorageClass::Workgroup, t_uint);
+
+        const u32 smem_size = specialization.shared_memory_size;
+        if (smem_size == 0) {
+            // Avoid declaring an empty array.
+            return;
+        }
+        const auto element_count = static_cast<u32>(Common::AlignUp(smem_size, 4) / 4);
+        const Id type_array = TypeArray(t_uint, Constant(t_uint, element_count));
+        const Id type_pointer = TypePointer(spv::StorageClass::Workgroup, type_array);
+        Name(type_pointer, "SharedMemory");
+
+        shared_memory = OpVariable(type_pointer, spv::StorageClass::Workgroup);
+        AddGlobalVariable(Name(shared_memory, "shared_memory"));
     }
 
     void DeclareInternalFlags() {
-        constexpr std::array<const char*, INTERNAL_FLAGS_COUNT> names = {"zero", "sign", "carry",
-                                                                         "overflow"};
+        constexpr std::array names = {"zero", "sign", "carry", "overflow"};
         for (std::size_t flag = 0; flag < INTERNAL_FLAGS_COUNT; ++flag) {
             const auto flag_code = static_cast<InternalFlag>(flag);
             const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
@@ -349,17 +649,53 @@ private:
         }
     }
 
+    void DeclareInputVertexArray(u32 length) {
+        constexpr auto storage = spv::StorageClass::Input;
+        std::tie(in_indices, in_vertex) = DeclareVertexArray(storage, "in_indices", length);
+    }
+
+    void DeclareOutputVertexArray(u32 length) {
+        constexpr auto storage = spv::StorageClass::Output;
+        std::tie(out_indices, out_vertex) = DeclareVertexArray(storage, "out_indices", length);
+    }
+
+    std::tuple<VertexIndices, Id> DeclareVertexArray(spv::StorageClass storage_class,
+                                                     std::string name, u32 length) {
+        const auto [struct_id, indices] = DeclareVertexStruct();
+        const Id vertex_array = TypeArray(struct_id, Constant(t_uint, length));
+        const Id vertex_ptr = TypePointer(storage_class, vertex_array);
+        const Id vertex = OpVariable(vertex_ptr, storage_class);
+        AddGlobalVariable(Name(vertex, std::move(name)));
+        interfaces.push_back(vertex);
+        return {indices, vertex};
+    }
+
+    void DeclareOutputVertex() {
+        Id out_vertex_struct;
+        std::tie(out_vertex_struct, out_indices) = DeclareVertexStruct();
+        const Id out_vertex_ptr = TypePointer(spv::StorageClass::Output, out_vertex_struct);
+        out_vertex = OpVariable(out_vertex_ptr, spv::StorageClass::Output);
+        interfaces.push_back(AddGlobalVariable(Name(out_vertex, "out_vertex")));
+    }
+
     void DeclareInputAttributes() {
         for (const auto index : ir.GetInputAttributes()) {
             if (!IsGenericAttribute(index)) {
                 continue;
             }
 
-            UNIMPLEMENTED_IF(stage == ShaderType::Geometry);
-
             const u32 location = GetGenericAttributeLocation(index);
-            const Id id = OpVariable(t_in_float4, spv::StorageClass::Input);
-            Name(AddGlobalVariable(id), fmt::format("in_attr{}", location));
+            const auto type_descriptor = GetAttributeType(location);
+            Id type;
+            if (IsInputAttributeArray()) {
+                type = GetTypeVectorDefinitionLut(type_descriptor.type).at(3);
+                type = TypeArray(type, Constant(t_uint, GetNumInputVertices()));
+                type = TypePointer(spv::StorageClass::Input, type);
+            } else {
+                type = type_descriptor.vector;
+            }
+            const Id id = OpVariable(type, spv::StorageClass::Input);
+            AddGlobalVariable(Name(id, fmt::format("in_attr{}", location)));
             input_attributes.emplace(index, id);
             interfaces.push_back(id);
 
@@ -389,8 +725,21 @@ private:
             if (!IsGenericAttribute(index)) {
                 continue;
             }
-            const auto location = GetGenericAttributeLocation(index);
-            const Id id = OpVariable(t_out_float4, spv::StorageClass::Output);
+            const u32 location = GetGenericAttributeLocation(index);
+            Id type = t_float4;
+            Id varying_default = v_varying_default;
+            if (IsOutputAttributeArray()) {
+                const u32 num = GetNumOutputVertices();
+                type = TypeArray(type, Constant(t_uint, num));
+                if (device.GetDriverID() != vk::DriverIdKHR::eIntelProprietaryWindows) {
+                    // Intel's proprietary driver fails to setup defaults for arrayed output
+                    // attributes.
+                    varying_default = ConstantComposite(type, std::vector(num, varying_default));
+                }
+            }
+            type = TypePointer(spv::StorageClass::Output, type);
+
+            const Id id = OpVariable(type, spv::StorageClass::Output, varying_default);
             Name(AddGlobalVariable(id), fmt::format("out_attr{}", location));
             output_attributes.emplace(index, id);
             interfaces.push_back(id);
@@ -399,10 +748,8 @@ private:
         }
     }
 
-    void DeclareConstantBuffers() {
-        u32 binding = const_buffers_base_binding;
-        for (const auto& entry : ir.GetConstantBuffers()) {
-            const auto [index, size] = entry;
+    u32 DeclareConstantBuffers(u32 binding) {
+        for (const auto& [index, size] : ir.GetConstantBuffers()) {
             const Id type = device.IsKhrUniformBufferStandardLayoutSupported() ? t_cbuf_scalar_ubo
                                                                                : t_cbuf_std140_ubo;
             const Id id = OpVariable(type, spv::StorageClass::Uniform);
@@ -412,12 +759,11 @@ private:
             Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
             constant_buffers.emplace(index, id);
         }
+        return binding;
     }
 
-    void DeclareGlobalBuffers() {
-        u32 binding = global_buffers_base_binding;
-        for (const auto& entry : ir.GetGlobalMemory()) {
-            const auto [base, usage] = entry;
+    u32 DeclareGlobalBuffers(u32 binding) {
+        for (const auto& [base, usage] : ir.GetGlobalMemory()) {
             const Id id = OpVariable(t_gmem_ssbo, spv::StorageClass::StorageBuffer);
             AddGlobalVariable(
                 Name(id, fmt::format("gmem_{}_{}", base.cbuf_index, base.cbuf_offset)));
@@ -426,89 +772,187 @@ private:
             Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
             global_buffers.emplace(base, id);
         }
+        return binding;
     }
 
-    void DeclareSamplers() {
-        u32 binding = samplers_base_binding;
+    u32 DeclareTexelBuffers(u32 binding) {
         for (const auto& sampler : ir.GetSamplers()) {
+            if (!sampler.IsBuffer()) {
+                continue;
+            }
+            ASSERT(!sampler.IsArray());
+            ASSERT(!sampler.IsShadow());
+
+            constexpr auto dim = spv::Dim::Buffer;
+            constexpr int depth = 0;
+            constexpr int arrayed = 0;
+            constexpr bool ms = false;
+            constexpr int sampled = 1;
+            constexpr auto format = spv::ImageFormat::Unknown;
+            const Id image_type = TypeImage(t_float, dim, depth, arrayed, ms, sampled, format);
+            const Id pointer_type = TypePointer(spv::StorageClass::UniformConstant, image_type);
+            const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant);
+            AddGlobalVariable(Name(id, fmt::format("sampler_{}", sampler.GetIndex())));
+            Decorate(id, spv::Decoration::Binding, binding++);
+            Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
+
+            texel_buffers.emplace(sampler.GetIndex(), TexelBuffer{image_type, id});
+        }
+        return binding;
+    }
+
+    u32 DeclareSamplers(u32 binding) {
+        for (const auto& sampler : ir.GetSamplers()) {
+            if (sampler.IsBuffer()) {
+                continue;
+            }
             const auto dim = GetSamplerDim(sampler);
             const int depth = sampler.IsShadow() ? 1 : 0;
             const int arrayed = sampler.IsArray() ? 1 : 0;
-            // TODO(Rodrigo): Sampled 1 indicates that the image will be used with a sampler. When
-            // SULD and SUST instructions are implemented, replace this value.
-            const int sampled = 1;
-            const Id image_type =
-                TypeImage(t_float, dim, depth, arrayed, false, sampled, spv::ImageFormat::Unknown);
+            constexpr bool ms = false;
+            constexpr int sampled = 1;
+            constexpr auto format = spv::ImageFormat::Unknown;
+            const Id image_type = TypeImage(t_float, dim, depth, arrayed, ms, sampled, format);
             const Id sampled_image_type = TypeSampledImage(image_type);
             const Id pointer_type =
                 TypePointer(spv::StorageClass::UniformConstant, sampled_image_type);
             const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant);
             AddGlobalVariable(Name(id, fmt::format("sampler_{}", sampler.GetIndex())));
+            Decorate(id, spv::Decoration::Binding, binding++);
+            Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
 
-            sampler_images.insert(
-                {static_cast<u32>(sampler.GetIndex()), {image_type, sampled_image_type, id}});
+            sampled_images.emplace(sampler.GetIndex(),
+                                   SampledImage{image_type, sampled_image_type, id});
+        }
+        return binding;
+    }
+
+    u32 DeclareImages(u32 binding) {
+        for (const auto& image : ir.GetImages()) {
+            const auto [dim, arrayed] = GetImageDim(image);
+            constexpr int depth = 0;
+            constexpr bool ms = false;
+            constexpr int sampled = 2; // This won't be accessed with a sampler
+            constexpr auto format = spv::ImageFormat::Unknown;
+            const Id image_type = TypeImage(t_uint, dim, depth, arrayed, ms, sampled, format, {});
+            const Id pointer_type = TypePointer(spv::StorageClass::UniformConstant, image_type);
+            const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant);
+            AddGlobalVariable(Name(id, fmt::format("image_{}", image.GetIndex())));
 
             Decorate(id, spv::Decoration::Binding, binding++);
             Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
+            if (image.IsRead() && !image.IsWritten()) {
+                Decorate(id, spv::Decoration::NonWritable);
+            } else if (image.IsWritten() && !image.IsRead()) {
+                Decorate(id, spv::Decoration::NonReadable);
+            }
+
+            images.emplace(static_cast<u32>(image.GetIndex()), StorageImage{image_type, id});
+        }
+        return binding;
+    }
+
+    bool IsInputAttributeArray() const {
+        return stage == ShaderType::TesselationControl || stage == ShaderType::TesselationEval ||
+               stage == ShaderType::Geometry;
+    }
+
+    bool IsOutputAttributeArray() const {
+        return stage == ShaderType::TesselationControl;
+    }
+
+    u32 GetNumInputVertices() const {
+        switch (stage) {
+        case ShaderType::Geometry:
+            return GetNumPrimitiveTopologyVertices(specialization.primitive_topology);
+        case ShaderType::TesselationControl:
+        case ShaderType::TesselationEval:
+            return NumInputPatches;
+        default:
+            UNREACHABLE();
+            return 1;
         }
     }
 
-    void DeclareVertexRedeclarations() {
-        vertex_index = DeclareBuiltIn(spv::BuiltIn::VertexIndex, spv::StorageClass::Input,
-                                      t_in_uint, "vertex_index");
-        instance_index = DeclareBuiltIn(spv::BuiltIn::InstanceIndex, spv::StorageClass::Input,
-                                        t_in_uint, "instance_index");
+    u32 GetNumOutputVertices() const {
+        switch (stage) {
+        case ShaderType::TesselationControl:
+            return header.common2.threads_per_input_primitive;
+        default:
+            UNREACHABLE();
+            return 1;
+        }
+    }
 
-        bool is_clip_distances_declared = false;
-        for (const auto index : ir.GetOutputAttributes()) {
-            if (index == Attribute::Index::ClipDistances0123 ||
-                index == Attribute::Index::ClipDistances4567) {
-                is_clip_distances_declared = true;
+    std::tuple<Id, VertexIndices> DeclareVertexStruct() {
+        struct BuiltIn {
+            Id type;
+            spv::BuiltIn builtin;
+            const char* name;
+        };
+        std::vector<BuiltIn> members;
+        members.reserve(4);
+
+        const auto AddBuiltIn = [&](Id type, spv::BuiltIn builtin, const char* name) {
+            const auto index = static_cast<u32>(members.size());
+            members.push_back(BuiltIn{type, builtin, name});
+            return index;
+        };
+
+        VertexIndices indices;
+        indices.position = AddBuiltIn(t_float4, spv::BuiltIn::Position, "position");
+
+        if (ir.UsesViewportIndex()) {
+            if (stage != ShaderType::Vertex || device.IsExtShaderViewportIndexLayerSupported()) {
+                indices.viewport = AddBuiltIn(t_int, spv::BuiltIn::ViewportIndex, "viewport_index");
+            } else {
+                LOG_ERROR(Render_Vulkan,
+                          "Shader requires ViewportIndex but it's not supported on this "
+                          "stage with this device.");
             }
         }
 
-        std::vector<Id> members;
-        members.push_back(t_float4);
-        if (ir.UsesPointSize()) {
-            members.push_back(t_float);
-        }
-        if (is_clip_distances_declared) {
-            members.push_back(TypeArray(t_float, Constant(t_uint, 8)));
+        if (ir.UsesPointSize() || specialization.point_size) {
+            indices.point_size = AddBuiltIn(t_float, spv::BuiltIn::PointSize, "point_size");
         }
 
-        const Id gl_per_vertex_struct = Name(TypeStruct(members), "PerVertex");
-        Decorate(gl_per_vertex_struct, spv::Decoration::Block);
+        const auto& output_attributes = ir.GetOutputAttributes();
+        const bool declare_clip_distances =
+            std::any_of(output_attributes.begin(), output_attributes.end(), [](const auto& index) {
+                return index == Attribute::Index::ClipDistances0123 ||
+                       index == Attribute::Index::ClipDistances4567;
+            });
+        if (declare_clip_distances) {
+            indices.clip_distances = AddBuiltIn(TypeArray(t_float, Constant(t_uint, 8)),
+                                                spv::BuiltIn::ClipDistance, "clip_distances");
+        }
 
-        u32 declaration_index = 0;
-        const auto MemberDecorateBuiltIn = [&](spv::BuiltIn builtin, std::string name,
-                                               bool condition) {
-            if (!condition)
-                return u32{};
-            MemberName(gl_per_vertex_struct, declaration_index, name);
-            MemberDecorate(gl_per_vertex_struct, declaration_index, spv::Decoration::BuiltIn,
-                           static_cast<u32>(builtin));
-            return declaration_index++;
-        };
+        std::vector<Id> member_types;
+        member_types.reserve(members.size());
+        for (std::size_t i = 0; i < members.size(); ++i) {
+            member_types.push_back(members[i].type);
+        }
+        const Id per_vertex_struct = Name(TypeStruct(member_types), "PerVertex");
+        Decorate(per_vertex_struct, spv::Decoration::Block);
 
-        position_index = MemberDecorateBuiltIn(spv::BuiltIn::Position, "position", true);
-        point_size_index =
-            MemberDecorateBuiltIn(spv::BuiltIn::PointSize, "point_size", ir.UsesPointSize());
-        clip_distances_index = MemberDecorateBuiltIn(spv::BuiltIn::ClipDistance, "clip_distances",
-                                                     is_clip_distances_declared);
+        for (std::size_t index = 0; index < members.size(); ++index) {
+            const auto& member = members[index];
+            MemberName(per_vertex_struct, static_cast<u32>(index), member.name);
+            MemberDecorate(per_vertex_struct, static_cast<u32>(index), spv::Decoration::BuiltIn,
+                           static_cast<u32>(member.builtin));
+        }
 
-        const Id type_pointer = TypePointer(spv::StorageClass::Output, gl_per_vertex_struct);
-        per_vertex = OpVariable(type_pointer, spv::StorageClass::Output);
-        AddGlobalVariable(Name(per_vertex, "per_vertex"));
-        interfaces.push_back(per_vertex);
+        return {per_vertex_struct, indices};
     }
 
     void VisitBasicBlock(const NodeBlock& bb) {
         for (const auto& node : bb) {
-            static_cast<void>(Visit(node));
+            [[maybe_unused]] const Type type = Visit(node).type;
+            ASSERT(type == Type::Void);
         }
     }
 
-    Id Visit(const Node& node) {
+    Expression Visit(const Node& node) {
         if (const auto operation = std::get_if<OperationNode>(&*node)) {
             const auto operation_index = static_cast<std::size_t>(operation->GetCode());
             const auto decompiler = operation_decompilers[operation_index];
@@ -516,18 +960,21 @@ private:
                 UNREACHABLE_MSG("Operation decompiler {} not defined", operation_index);
             }
             return (this->*decompiler)(*operation);
+        }
 
-        } else if (const auto gpr = std::get_if<GprNode>(&*node)) {
+        if (const auto gpr = std::get_if<GprNode>(&*node)) {
             const u32 index = gpr->GetIndex();
             if (index == Register::ZeroIndex) {
-                return Constant(t_float, 0.0f);
+                return {v_float_zero, Type::Float};
             }
-            return Emit(OpLoad(t_float, registers.at(index)));
+            return {OpLoad(t_float, registers.at(index)), Type::Float};
+        }
 
-        } else if (const auto immediate = std::get_if<ImmediateNode>(&*node)) {
-            return BitcastTo<Type::Float>(Constant(t_uint, immediate->GetValue()));
+        if (const auto immediate = std::get_if<ImmediateNode>(&*node)) {
+            return {Constant(t_uint, immediate->GetValue()), Type::Uint};
+        }
 
-        } else if (const auto predicate = std::get_if<PredicateNode>(&*node)) {
+        if (const auto predicate = std::get_if<PredicateNode>(&*node)) {
             const auto value = [&]() -> Id {
                 switch (const auto index = predicate->GetIndex(); index) {
                 case Tegra::Shader::Pred::UnusedIndex:
@@ -535,74 +982,107 @@ private:
                 case Tegra::Shader::Pred::NeverExecute:
                     return v_false;
                 default:
-                    return Emit(OpLoad(t_bool, predicates.at(index)));
+                    return OpLoad(t_bool, predicates.at(index));
                 }
             }();
             if (predicate->IsNegated()) {
-                return Emit(OpLogicalNot(t_bool, value));
+                return {OpLogicalNot(t_bool, value), Type::Bool};
             }
-            return value;
+            return {value, Type::Bool};
+        }
 
-        } else if (const auto abuf = std::get_if<AbufNode>(&*node)) {
+        if (const auto abuf = std::get_if<AbufNode>(&*node)) {
             const auto attribute = abuf->GetIndex();
-            const auto element = abuf->GetElement();
+            const u32 element = abuf->GetElement();
+            const auto& buffer = abuf->GetBuffer();
+
+            const auto ArrayPass = [&](Id pointer_type, Id composite, std::vector<u32> indices) {
+                std::vector<Id> members;
+                members.reserve(std::size(indices) + 1);
+
+                if (buffer && IsInputAttributeArray()) {
+                    members.push_back(AsUint(Visit(buffer)));
+                }
+                for (const u32 index : indices) {
+                    members.push_back(Constant(t_uint, index));
+                }
+                return OpAccessChain(pointer_type, composite, members);
+            };
 
             switch (attribute) {
-            case Attribute::Index::Position:
-                if (stage != ShaderType::Fragment) {
-                    UNIMPLEMENTED();
-                    break;
-                } else {
+            case Attribute::Index::Position: {
+                if (stage == ShaderType::Fragment) {
                     if (element == 3) {
-                        return Constant(t_float, 1.0f);
+                        return {Constant(t_float, 1.0f), Type::Float};
                     }
-                    return Emit(OpLoad(t_float, AccessElement(t_in_float, frag_coord, element)));
+                    return {OpLoad(t_float, AccessElement(t_in_float, frag_coord, element)),
+                            Type::Float};
                 }
+                const auto elements = {in_indices.position.value(), element};
+                return {OpLoad(t_float, ArrayPass(t_in_float, in_vertex, elements)), Type::Float};
+            }
+            case Attribute::Index::PointCoord: {
+                switch (element) {
+                case 0:
+                case 1:
+                    return {OpCompositeExtract(t_float, OpLoad(t_float2, point_coord), element),
+                            Type::Float};
+                }
+                UNIMPLEMENTED_MSG("Unimplemented point coord element={}", element);
+                return {v_float_zero, Type::Float};
+            }
             case Attribute::Index::TessCoordInstanceIDVertexID:
                 // TODO(Subv): Find out what the values are for the first two elements when inside a
                 // vertex shader, and what's the value of the fourth element when inside a Tess Eval
                 // shader.
-                ASSERT(stage == ShaderType::Vertex);
                 switch (element) {
+                case 0:
+                case 1:
+                    return {OpLoad(t_float, AccessElement(t_in_float, tess_coord, element)),
+                            Type::Float};
                 case 2:
-                    return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, instance_index)));
+                    return {OpLoad(t_uint, instance_index), Type::Uint};
                 case 3:
-                    return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, vertex_index)));
+                    return {OpLoad(t_uint, vertex_index), Type::Uint};
                 }
                 UNIMPLEMENTED_MSG("Unmanaged TessCoordInstanceIDVertexID element={}", element);
-                return Constant(t_float, 0);
+                return {Constant(t_uint, 0U), Type::Uint};
             case Attribute::Index::FrontFacing:
                 // TODO(Subv): Find out what the values are for the other elements.
                 ASSERT(stage == ShaderType::Fragment);
                 if (element == 3) {
-                    const Id is_front_facing = Emit(OpLoad(t_bool, front_facing));
-                    const Id true_value =
-                        BitcastTo<Type::Float>(Constant(t_int, static_cast<s32>(-1)));
-                    const Id false_value = BitcastTo<Type::Float>(Constant(t_int, 0));
-                    return Emit(OpSelect(t_float, is_front_facing, true_value, false_value));
+                    const Id is_front_facing = OpLoad(t_bool, front_facing);
+                    const Id true_value = Constant(t_int, static_cast<s32>(-1));
+                    const Id false_value = Constant(t_int, 0);
+                    return {OpSelect(t_int, is_front_facing, true_value, false_value), Type::Int};
                 }
                 UNIMPLEMENTED_MSG("Unmanaged FrontFacing element={}", element);
-                return Constant(t_float, 0.0f);
+                return {v_float_zero, Type::Float};
             default:
                 if (IsGenericAttribute(attribute)) {
-                    const Id pointer =
-                        AccessElement(t_in_float, input_attributes.at(attribute), element);
-                    return Emit(OpLoad(t_float, pointer));
+                    const u32 location = GetGenericAttributeLocation(attribute);
+                    const auto type_descriptor = GetAttributeType(location);
+                    const Type type = type_descriptor.type;
+                    const Id attribute_id = input_attributes.at(attribute);
+                    const Id pointer = ArrayPass(type_descriptor.scalar, attribute_id, {element});
+                    return {OpLoad(GetTypeDefinition(type), pointer), type};
                 }
                 break;
             }
             UNIMPLEMENTED_MSG("Unhandled input attribute: {}", static_cast<u32>(attribute));
+            return {v_float_zero, Type::Float};
+        }
 
-        } else if (const auto cbuf = std::get_if<CbufNode>(&*node)) {
+        if (const auto cbuf = std::get_if<CbufNode>(&*node)) {
             const Node& offset = cbuf->GetOffset();
             const Id buffer_id = constant_buffers.at(cbuf->GetIndex());
 
             Id pointer{};
             if (device.IsKhrUniformBufferStandardLayoutSupported()) {
-                const Id buffer_offset = Emit(OpShiftRightLogical(
-                    t_uint, BitcastTo<Type::Uint>(Visit(offset)), Constant(t_uint, 2u)));
-                pointer = Emit(
-                    OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0u), buffer_offset));
+                const Id buffer_offset =
+                    OpShiftRightLogical(t_uint, AsUint(Visit(offset)), Constant(t_uint, 2U));
+                pointer =
+                    OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0U), buffer_offset);
             } else {
                 Id buffer_index{};
                 Id buffer_element{};
@@ -614,53 +1094,76 @@ private:
                     buffer_element = Constant(t_uint, (offset_imm / 4) % 4);
                 } else if (std::holds_alternative<OperationNode>(*offset)) {
                     // Indirect access
-                    const Id offset_id = BitcastTo<Type::Uint>(Visit(offset));
-                    const Id unsafe_offset = Emit(OpUDiv(t_uint, offset_id, Constant(t_uint, 4)));
-                    const Id final_offset = Emit(OpUMod(
-                        t_uint, unsafe_offset, Constant(t_uint, MAX_CONSTBUFFER_ELEMENTS - 1)));
-                    buffer_index = Emit(OpUDiv(t_uint, final_offset, Constant(t_uint, 4)));
-                    buffer_element = Emit(OpUMod(t_uint, final_offset, Constant(t_uint, 4)));
+                    const Id offset_id = AsUint(Visit(offset));
+                    const Id unsafe_offset = OpUDiv(t_uint, offset_id, Constant(t_uint, 4));
+                    const Id final_offset =
+                        OpUMod(t_uint, unsafe_offset, Constant(t_uint, MaxConstBufferElements - 1));
+                    buffer_index = OpUDiv(t_uint, final_offset, Constant(t_uint, 4));
+                    buffer_element = OpUMod(t_uint, final_offset, Constant(t_uint, 4));
                 } else {
                     UNREACHABLE_MSG("Unmanaged offset node type");
                 }
-                pointer = Emit(OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0),
-                                             buffer_index, buffer_element));
+                pointer = OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0), buffer_index,
+                                        buffer_element);
             }
-            return Emit(OpLoad(t_float, pointer));
+            return {OpLoad(t_float, pointer), Type::Float};
+        }
 
-        } else if (const auto gmem = std::get_if<GmemNode>(&*node)) {
+        if (const auto gmem = std::get_if<GmemNode>(&*node)) {
             const Id gmem_buffer = global_buffers.at(gmem->GetDescriptor());
-            const Id real = BitcastTo<Type::Uint>(Visit(gmem->GetRealAddress()));
-            const Id base = BitcastTo<Type::Uint>(Visit(gmem->GetBaseAddress()));
+            const Id real = AsUint(Visit(gmem->GetRealAddress()));
+            const Id base = AsUint(Visit(gmem->GetBaseAddress()));
 
-            Id offset = Emit(OpISub(t_uint, real, base));
-            offset = Emit(OpUDiv(t_uint, offset, Constant(t_uint, 4u)));
-            return Emit(OpLoad(t_float, Emit(OpAccessChain(t_gmem_float, gmem_buffer,
-                                                           Constant(t_uint, 0u), offset))));
+            Id offset = OpISub(t_uint, real, base);
+            offset = OpUDiv(t_uint, offset, Constant(t_uint, 4U));
+            return {OpLoad(t_float,
+                           OpAccessChain(t_gmem_float, gmem_buffer, Constant(t_uint, 0U), offset)),
+                    Type::Float};
+        }
 
-        } else if (const auto conditional = std::get_if<ConditionalNode>(&*node)) {
+        if (const auto lmem = std::get_if<LmemNode>(&*node)) {
+            Id address = AsUint(Visit(lmem->GetAddress()));
+            address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U));
+            const Id pointer = OpAccessChain(t_prv_float, local_memory, address);
+            return {OpLoad(t_float, pointer), Type::Float};
+        }
+
+        if (const auto smem = std::get_if<SmemNode>(&*node)) {
+            Id address = AsUint(Visit(smem->GetAddress()));
+            address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U));
+            const Id pointer = OpAccessChain(t_smem_uint, shared_memory, address);
+            return {OpLoad(t_uint, pointer), Type::Uint};
+        }
+
+        if (const auto internal_flag = std::get_if<InternalFlagNode>(&*node)) {
+            const Id flag = internal_flags.at(static_cast<std::size_t>(internal_flag->GetFlag()));
+            return {OpLoad(t_bool, flag), Type::Bool};
+        }
+
+        if (const auto conditional = std::get_if<ConditionalNode>(&*node)) {
             // It's invalid to call conditional on nested nodes, use an operation instead
             const Id true_label = OpLabel();
             const Id skip_label = OpLabel();
-            const Id condition = Visit(conditional->GetCondition());
-            Emit(OpSelectionMerge(skip_label, spv::SelectionControlMask::MaskNone));
-            Emit(OpBranchConditional(condition, true_label, skip_label));
-            Emit(true_label);
+            const Id condition = AsBool(Visit(conditional->GetCondition()));
+            OpSelectionMerge(skip_label, spv::SelectionControlMask::MaskNone);
+            OpBranchConditional(condition, true_label, skip_label);
+            AddLabel(true_label);
 
-            ++conditional_nest_count;
+            conditional_branch_set = true;
+            inside_branch = false;
             VisitBasicBlock(conditional->GetCode());
-            --conditional_nest_count;
-
-            if (inside_branch == 0) {
-                Emit(OpBranch(skip_label));
+            conditional_branch_set = false;
+            if (!inside_branch) {
+                OpBranch(skip_label);
             } else {
-                inside_branch--;
+                inside_branch = false;
             }
-            Emit(skip_label);
+            AddLabel(skip_label);
             return {};
+        }
 
-        } else if (const auto comment = std::get_if<CommentNode>(&*node)) {
-            Name(Emit(OpUndef(t_void)), comment->GetText());
+        if (const auto comment = std::get_if<CommentNode>(&*node)) {
+            Name(OpUndef(t_void), comment->GetText());
             return {};
         }
 
@@ -669,94 +1172,126 @@ private:
     }
 
     template <Id (Module::*func)(Id, Id), Type result_type, Type type_a = result_type>
-    Id Unary(Operation operation) {
+    Expression Unary(Operation operation) {
         const Id type_def = GetTypeDefinition(result_type);
-        const Id op_a = VisitOperand<type_a>(operation, 0);
+        const Id op_a = As(Visit(operation[0]), type_a);
 
-        const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a)));
+        const Id value = (this->*func)(type_def, op_a);
         if (IsPrecise(operation)) {
             Decorate(value, spv::Decoration::NoContraction);
         }
-        return value;
+        return {value, result_type};
     }
 
     template <Id (Module::*func)(Id, Id, Id), Type result_type, Type type_a = result_type,
               Type type_b = type_a>
-    Id Binary(Operation operation) {
+    Expression Binary(Operation operation) {
         const Id type_def = GetTypeDefinition(result_type);
-        const Id op_a = VisitOperand<type_a>(operation, 0);
-        const Id op_b = VisitOperand<type_b>(operation, 1);
+        const Id op_a = As(Visit(operation[0]), type_a);
+        const Id op_b = As(Visit(operation[1]), type_b);
 
-        const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b)));
+        const Id value = (this->*func)(type_def, op_a, op_b);
         if (IsPrecise(operation)) {
             Decorate(value, spv::Decoration::NoContraction);
         }
-        return value;
+        return {value, result_type};
     }
 
     template <Id (Module::*func)(Id, Id, Id, Id), Type result_type, Type type_a = result_type,
               Type type_b = type_a, Type type_c = type_b>
-    Id Ternary(Operation operation) {
+    Expression Ternary(Operation operation) {
         const Id type_def = GetTypeDefinition(result_type);
-        const Id op_a = VisitOperand<type_a>(operation, 0);
-        const Id op_b = VisitOperand<type_b>(operation, 1);
-        const Id op_c = VisitOperand<type_c>(operation, 2);
+        const Id op_a = As(Visit(operation[0]), type_a);
+        const Id op_b = As(Visit(operation[1]), type_b);
+        const Id op_c = As(Visit(operation[2]), type_c);
 
-        const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c)));
+        const Id value = (this->*func)(type_def, op_a, op_b, op_c);
         if (IsPrecise(operation)) {
             Decorate(value, spv::Decoration::NoContraction);
         }
-        return value;
+        return {value, result_type};
     }
 
     template <Id (Module::*func)(Id, Id, Id, Id, Id), Type result_type, Type type_a = result_type,
               Type type_b = type_a, Type type_c = type_b, Type type_d = type_c>
-    Id Quaternary(Operation operation) {
+    Expression Quaternary(Operation operation) {
         const Id type_def = GetTypeDefinition(result_type);
-        const Id op_a = VisitOperand<type_a>(operation, 0);
-        const Id op_b = VisitOperand<type_b>(operation, 1);
-        const Id op_c = VisitOperand<type_c>(operation, 2);
-        const Id op_d = VisitOperand<type_d>(operation, 3);
+        const Id op_a = As(Visit(operation[0]), type_a);
+        const Id op_b = As(Visit(operation[1]), type_b);
+        const Id op_c = As(Visit(operation[2]), type_c);
+        const Id op_d = As(Visit(operation[3]), type_d);
 
-        const Id value =
-            BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c, op_d)));
+        const Id value = (this->*func)(type_def, op_a, op_b, op_c, op_d);
         if (IsPrecise(operation)) {
             Decorate(value, spv::Decoration::NoContraction);
         }
-        return value;
+        return {value, result_type};
     }
 
-    Id Assign(Operation operation) {
+    Expression Assign(Operation operation) {
         const Node& dest = operation[0];
         const Node& src = operation[1];
 
-        Id target{};
+        Expression target{};
         if (const auto gpr = std::get_if<GprNode>(&*dest)) {
             if (gpr->GetIndex() == Register::ZeroIndex) {
                 // Writing to Register::ZeroIndex is a no op
                 return {};
             }
-            target = registers.at(gpr->GetIndex());
+            target = {registers.at(gpr->GetIndex()), Type::Float};
 
         } else if (const auto abuf = std::get_if<AbufNode>(&*dest)) {
-            target = [&]() -> Id {
+            const auto& buffer = abuf->GetBuffer();
+            const auto ArrayPass = [&](Id pointer_type, Id composite, std::vector<u32> indices) {
+                std::vector<Id> members;
+                members.reserve(std::size(indices) + 1);
+
+                if (buffer && IsOutputAttributeArray()) {
+                    members.push_back(AsUint(Visit(buffer)));
+                }
+                for (const u32 index : indices) {
+                    members.push_back(Constant(t_uint, index));
+                }
+                return OpAccessChain(pointer_type, composite, members);
+            };
+
+            target = [&]() -> Expression {
+                const u32 element = abuf->GetElement();
                 switch (const auto attribute = abuf->GetIndex(); attribute) {
-                case Attribute::Index::Position:
-                    return AccessElement(t_out_float, per_vertex, position_index,
-                                         abuf->GetElement());
+                case Attribute::Index::Position: {
+                    const u32 index = out_indices.position.value();
+                    return {ArrayPass(t_out_float, out_vertex, {index, element}), Type::Float};
+                }
                 case Attribute::Index::LayerViewportPointSize:
-                    UNIMPLEMENTED_IF(abuf->GetElement() != 3);
-                    return AccessElement(t_out_float, per_vertex, point_size_index);
-                case Attribute::Index::ClipDistances0123:
-                    return AccessElement(t_out_float, per_vertex, clip_distances_index,
-                                         abuf->GetElement());
-                case Attribute::Index::ClipDistances4567:
-                    return AccessElement(t_out_float, per_vertex, clip_distances_index,
-                                         abuf->GetElement() + 4);
+                    switch (element) {
+                    case 2: {
+                        if (!out_indices.viewport) {
+                            return {};
+                        }
+                        const u32 index = out_indices.viewport.value();
+                        return {AccessElement(t_out_int, out_vertex, index), Type::Int};
+                    }
+                    case 3: {
+                        const auto index = out_indices.point_size.value();
+                        return {AccessElement(t_out_float, out_vertex, index), Type::Float};
+                    }
+                    default:
+                        UNIMPLEMENTED_MSG("LayerViewportPoint element={}", abuf->GetElement());
+                        return {};
+                    }
+                case Attribute::Index::ClipDistances0123: {
+                    const u32 index = out_indices.clip_distances.value();
+                    return {AccessElement(t_out_float, out_vertex, index, element), Type::Float};
+                }
+                case Attribute::Index::ClipDistances4567: {
+                    const u32 index = out_indices.clip_distances.value();
+                    return {AccessElement(t_out_float, out_vertex, index, element + 4),
+                            Type::Float};
+                }
                 default:
                     if (IsGenericAttribute(attribute)) {
-                        return AccessElement(t_out_float, output_attributes.at(attribute),
-                                             abuf->GetElement());
+                        const Id composite = output_attributes.at(attribute);
+                        return {ArrayPass(t_out_float, composite, {element}), Type::Float};
                     }
                     UNIMPLEMENTED_MSG("Unhandled output attribute: {}",
                                       static_cast<u32>(attribute));
@@ -764,72 +1299,154 @@ private:
                 }
             }();
 
+        } else if (const auto patch = std::get_if<PatchNode>(&*dest)) {
+            target = [&]() -> Expression {
+                const u32 offset = patch->GetOffset();
+                switch (offset) {
+                case 0:
+                case 1:
+                case 2:
+                case 3:
+                    return {AccessElement(t_out_float, tess_level_outer, offset % 4), Type::Float};
+                case 4:
+                case 5:
+                    return {AccessElement(t_out_float, tess_level_inner, offset % 4), Type::Float};
+                }
+                UNIMPLEMENTED_MSG("Unhandled patch output offset: {}", offset);
+                return {};
+            }();
+
         } else if (const auto lmem = std::get_if<LmemNode>(&*dest)) {
-            Id address = BitcastTo<Type::Uint>(Visit(lmem->GetAddress()));
-            address = Emit(OpUDiv(t_uint, address, Constant(t_uint, 4)));
-            target = Emit(OpAccessChain(t_prv_float, local_memory, {address}));
+            Id address = AsUint(Visit(lmem->GetAddress()));
+            address = OpUDiv(t_uint, address, Constant(t_uint, 4));
+            target = {OpAccessChain(t_prv_float, local_memory, address), Type::Float};
+
+        } else if (const auto smem = std::get_if<SmemNode>(&*dest)) {
+            ASSERT(stage == ShaderType::Compute);
+            Id address = AsUint(Visit(smem->GetAddress()));
+            address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U));
+            target = {OpAccessChain(t_smem_uint, shared_memory, address), Type::Uint};
+
+        } else if (const auto gmem = std::get_if<GmemNode>(&*dest)) {
+            const Id real = AsUint(Visit(gmem->GetRealAddress()));
+            const Id base = AsUint(Visit(gmem->GetBaseAddress()));
+            const Id diff = OpISub(t_uint, real, base);
+            const Id offset = OpShiftRightLogical(t_uint, diff, Constant(t_uint, 2));
+
+            const Id gmem_buffer = global_buffers.at(gmem->GetDescriptor());
+            target = {OpAccessChain(t_gmem_float, gmem_buffer, Constant(t_uint, 0), offset),
+                      Type::Float};
+
+        } else {
+            UNIMPLEMENTED();
         }
 
-        Emit(OpStore(target, Visit(src)));
+        OpStore(target.id, As(Visit(src), target.type));
         return {};
     }
 
-    Id FCastHalf0(Operation operation) {
-        UNIMPLEMENTED();
-        return {};
+    template <u32 offset>
+    Expression FCastHalf(Operation operation) {
+        const Id value = AsHalfFloat(Visit(operation[0]));
+        return {GetFloatFromHalfScalar(OpCompositeExtract(t_scalar_half, value, offset)),
+                Type::Float};
     }
 
-    Id FCastHalf1(Operation operation) {
-        UNIMPLEMENTED();
-        return {};
+    Expression FSwizzleAdd(Operation operation) {
+        const Id minus = Constant(t_float, -1.0f);
+        const Id plus = v_float_one;
+        const Id zero = v_float_zero;
+        const Id lut_a = ConstantComposite(t_float4, minus, plus, minus, zero);
+        const Id lut_b = ConstantComposite(t_float4, minus, minus, plus, minus);
+
+        Id mask = OpLoad(t_uint, thread_id);
+        mask = OpBitwiseAnd(t_uint, mask, Constant(t_uint, 3));
+        mask = OpShiftLeftLogical(t_uint, mask, Constant(t_uint, 1));
+        mask = OpShiftRightLogical(t_uint, AsUint(Visit(operation[2])), mask);
+        mask = OpBitwiseAnd(t_uint, mask, Constant(t_uint, 3));
+
+        const Id modifier_a = OpVectorExtractDynamic(t_float, lut_a, mask);
+        const Id modifier_b = OpVectorExtractDynamic(t_float, lut_b, mask);
+
+        const Id op_a = OpFMul(t_float, AsFloat(Visit(operation[0])), modifier_a);
+        const Id op_b = OpFMul(t_float, AsFloat(Visit(operation[1])), modifier_b);
+        return {OpFAdd(t_float, op_a, op_b), Type::Float};
     }
 
-    Id FSwizzleAdd(Operation operation) {
-        UNIMPLEMENTED();
-        return {};
+    Expression HNegate(Operation operation) {
+        const bool is_f16 = device.IsFloat16Supported();
+        const Id minus_one = Constant(t_scalar_half, is_f16 ? 0xbc00 : 0xbf800000);
+        const Id one = Constant(t_scalar_half, is_f16 ? 0x3c00 : 0x3f800000);
+        const auto GetNegate = [&](std::size_t index) {
+            return OpSelect(t_scalar_half, AsBool(Visit(operation[index])), minus_one, one);
+        };
+        const Id negation = OpCompositeConstruct(t_half, GetNegate(1), GetNegate(2));
+        return {OpFMul(t_half, AsHalfFloat(Visit(operation[0])), negation), Type::HalfFloat};
     }
 
-    Id HNegate(Operation operation) {
-        UNIMPLEMENTED();
-        return {};
+    Expression HClamp(Operation operation) {
+        const auto Pack = [&](std::size_t index) {
+            const Id scalar = GetHalfScalarFromFloat(AsFloat(Visit(operation[index])));
+            return OpCompositeConstruct(t_half, scalar, scalar);
+        };
+        const Id value = AsHalfFloat(Visit(operation[0]));
+        const Id min = Pack(1);
+        const Id max = Pack(2);
+
+        const Id clamped = OpFClamp(t_half, value, min, max);
+        if (IsPrecise(operation)) {
+            Decorate(clamped, spv::Decoration::NoContraction);
+        }
+        return {clamped, Type::HalfFloat};
     }
 
-    Id HClamp(Operation operation) {
-        UNIMPLEMENTED();
-        return {};
+    Expression HCastFloat(Operation operation) {
+        const Id value = GetHalfScalarFromFloat(AsFloat(Visit(operation[0])));
+        return {OpCompositeConstruct(t_half, value, Constant(t_scalar_half, 0)), Type::HalfFloat};
     }
 
-    Id HCastFloat(Operation operation) {
-        UNIMPLEMENTED();
-        return {};
+    Expression HUnpack(Operation operation) {
+        Expression operand = Visit(operation[0]);
+        const auto type = std::get<Tegra::Shader::HalfType>(operation.GetMeta());
+        if (type == Tegra::Shader::HalfType::H0_H1) {
+            return operand;
+        }
+        const auto value = [&] {
+            switch (std::get<Tegra::Shader::HalfType>(operation.GetMeta())) {
+            case Tegra::Shader::HalfType::F32:
+                return GetHalfScalarFromFloat(AsFloat(operand));
+            case Tegra::Shader::HalfType::H0_H0:
+                return OpCompositeExtract(t_scalar_half, AsHalfFloat(operand), 0);
+            case Tegra::Shader::HalfType::H1_H1:
+                return OpCompositeExtract(t_scalar_half, AsHalfFloat(operand), 1);
+            default:
+                UNREACHABLE();
+                return ConstantNull(t_half);
+            }
+        }();
+        return {OpCompositeConstruct(t_half, value, value), Type::HalfFloat};
     }
 
-    Id HUnpack(Operation operation) {
-        UNIMPLEMENTED();
-        return {};
+    Expression HMergeF32(Operation operation) {
+        const Id value = AsHalfFloat(Visit(operation[0]));
+        return {GetFloatFromHalfScalar(OpCompositeExtract(t_scalar_half, value, 0)), Type::Float};
     }
 
-    Id HMergeF32(Operation operation) {
-        UNIMPLEMENTED();
-        return {};
+    template <u32 offset>
+    Expression HMergeHN(Operation operation) {
+        const Id target = AsHalfFloat(Visit(operation[0]));
+        const Id source = AsHalfFloat(Visit(operation[1]));
+        const Id object = OpCompositeExtract(t_scalar_half, source, offset);
+        return {OpCompositeInsert(t_half, object, target, offset), Type::HalfFloat};
     }
 
-    Id HMergeH0(Operation operation) {
-        UNIMPLEMENTED();
-        return {};
+    Expression HPack2(Operation operation) {
+        const Id low = GetHalfScalarFromFloat(AsFloat(Visit(operation[0])));
+        const Id high = GetHalfScalarFromFloat(AsFloat(Visit(operation[1])));
+        return {OpCompositeConstruct(t_half, low, high), Type::HalfFloat};
     }
 
-    Id HMergeH1(Operation operation) {
-        UNIMPLEMENTED();
-        return {};
-    }
-
-    Id HPack2(Operation operation) {
-        UNIMPLEMENTED();
-        return {};
-    }
-
-    Id LogicalAssign(Operation operation) {
+    Expression LogicalAssign(Operation operation) {
         const Node& dest = operation[0];
         const Node& src = operation[1];
 
@@ -850,106 +1467,190 @@ private:
             target = internal_flags.at(static_cast<u32>(flag->GetFlag()));
         }
 
-        Emit(OpStore(target, Visit(src)));
-        return {};
-    }
-
-    Id LogicalPick2(Operation operation) {
-        UNIMPLEMENTED();
-        return {};
-    }
-
-    Id LogicalAnd2(Operation operation) {
-        UNIMPLEMENTED();
+        OpStore(target, AsBool(Visit(src)));
         return {};
     }
 
     Id GetTextureSampler(Operation operation) {
-        const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
-        const auto entry = sampler_images.at(static_cast<u32>(meta->sampler.GetIndex()));
-        return Emit(OpLoad(entry.sampled_image_type, entry.sampler));
+        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+        ASSERT(!meta.sampler.IsBuffer());
+
+        const auto& entry = sampled_images.at(meta.sampler.GetIndex());
+        return OpLoad(entry.sampled_image_type, entry.sampler);
     }
 
     Id GetTextureImage(Operation operation) {
-        const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
-        const auto entry = sampler_images.at(static_cast<u32>(meta->sampler.GetIndex()));
-        return Emit(OpImage(entry.image_type, GetTextureSampler(operation)));
+        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+        const u32 index = meta.sampler.GetIndex();
+        if (meta.sampler.IsBuffer()) {
+            const auto& entry = texel_buffers.at(index);
+            return OpLoad(entry.image_type, entry.image);
+        } else {
+            const auto& entry = sampled_images.at(index);
+            return OpImage(entry.image_type, GetTextureSampler(operation));
+        }
     }
 
-    Id GetTextureCoordinates(Operation operation) {
-        const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
+    Id GetImage(Operation operation) {
+        const auto& meta = std::get<MetaImage>(operation.GetMeta());
+        const auto entry = images.at(meta.image.GetIndex());
+        return OpLoad(entry.image_type, entry.image);
+    }
+
+    Id AssembleVector(const std::vector<Id>& coords, Type type) {
+        const Id coords_type = GetTypeVectorDefinitionLut(type).at(coords.size() - 1);
+        return coords.size() == 1 ? coords[0] : OpCompositeConstruct(coords_type, coords);
+    }
+
+    Id GetCoordinates(Operation operation, Type type) {
         std::vector<Id> coords;
         for (std::size_t i = 0; i < operation.GetOperandsCount(); ++i) {
-            coords.push_back(Visit(operation[i]));
+            coords.push_back(As(Visit(operation[i]), type));
         }
-        if (meta->sampler.IsArray()) {
-            const Id array_integer = BitcastTo<Type::Int>(Visit(meta->array));
-            coords.push_back(Emit(OpConvertSToF(t_float, array_integer)));
+        if (const auto meta = std::get_if<MetaTexture>(&operation.GetMeta())) {
+            // Add array coordinate for textures
+            if (meta->sampler.IsArray()) {
+                Id array = AsInt(Visit(meta->array));
+                if (type == Type::Float) {
+                    array = OpConvertSToF(t_float, array);
+                }
+                coords.push_back(array);
+            }
         }
-        if (meta->sampler.IsShadow()) {
-            coords.push_back(Visit(meta->depth_compare));
+        return AssembleVector(coords, type);
+    }
+
+    Id GetOffsetCoordinates(Operation operation) {
+        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+        std::vector<Id> coords;
+        coords.reserve(meta.aoffi.size());
+        for (const auto& coord : meta.aoffi) {
+            coords.push_back(AsInt(Visit(coord)));
+        }
+        return AssembleVector(coords, Type::Int);
+    }
+
+    std::pair<Id, Id> GetDerivatives(Operation operation) {
+        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+        const auto& derivatives = meta.derivates;
+        ASSERT(derivatives.size() % 2 == 0);
+
+        const std::size_t components = derivatives.size() / 2;
+        std::vector<Id> dx, dy;
+        dx.reserve(components);
+        dy.reserve(components);
+        for (std::size_t index = 0; index < components; ++index) {
+            dx.push_back(AsFloat(Visit(derivatives.at(index * 2 + 0))));
+            dy.push_back(AsFloat(Visit(derivatives.at(index * 2 + 1))));
+        }
+        return {AssembleVector(dx, Type::Float), AssembleVector(dy, Type::Float)};
+    }
+
+    Expression GetTextureElement(Operation operation, Id sample_value, Type type) {
+        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+        const auto type_def = GetTypeDefinition(type);
+        return {OpCompositeExtract(type_def, sample_value, meta.element), type};
+    }
+
+    Expression Texture(Operation operation) {
+        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+        UNIMPLEMENTED_IF(!meta.aoffi.empty());
+
+        const bool can_implicit = stage == ShaderType::Fragment;
+        const Id sampler = GetTextureSampler(operation);
+        const Id coords = GetCoordinates(operation, Type::Float);
+
+        if (meta.depth_compare) {
+            // Depth sampling
+            UNIMPLEMENTED_IF(meta.bias);
+            const Id dref = AsFloat(Visit(meta.depth_compare));
+            if (can_implicit) {
+                return {OpImageSampleDrefImplicitLod(t_float, sampler, coords, dref, {}),
+                        Type::Float};
+            } else {
+                return {OpImageSampleDrefExplicitLod(t_float, sampler, coords, dref,
+                                                     spv::ImageOperandsMask::Lod, v_float_zero),
+                        Type::Float};
+            }
         }
 
-        const std::array<Id, 4> t_float_lut = {nullptr, t_float2, t_float3, t_float4};
-        return coords.size() == 1
-                   ? coords[0]
-                   : Emit(OpCompositeConstruct(t_float_lut.at(coords.size() - 1), coords));
-    }
-
-    Id GetTextureElement(Operation operation, Id sample_value) {
-        const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
-        ASSERT(meta);
-        return Emit(OpCompositeExtract(t_float, sample_value, meta->element));
-    }
-
-    Id Texture(Operation operation) {
-        const Id texture = Emit(OpImageSampleImplicitLod(t_float4, GetTextureSampler(operation),
-                                                         GetTextureCoordinates(operation)));
-        return GetTextureElement(operation, texture);
-    }
-
-    Id TextureLod(Operation operation) {
-        const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
-        const Id texture = Emit(OpImageSampleExplicitLod(
-            t_float4, GetTextureSampler(operation), GetTextureCoordinates(operation),
-            spv::ImageOperandsMask::Lod, Visit(meta->lod)));
-        return GetTextureElement(operation, texture);
-    }
-
-    Id TextureGather(Operation operation) {
-        const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
-        const auto coords = GetTextureCoordinates(operation);
+        std::vector<Id> operands;
+        spv::ImageOperandsMask mask{};
+        if (meta.bias) {
+            mask = mask | spv::ImageOperandsMask::Bias;
+            operands.push_back(AsFloat(Visit(meta.bias)));
+        }
 
         Id texture;
-        if (meta->sampler.IsShadow()) {
-            texture = Emit(OpImageDrefGather(t_float4, GetTextureSampler(operation), coords,
-                                             Visit(meta->component)));
+        if (can_implicit) {
+            texture = OpImageSampleImplicitLod(t_float4, sampler, coords, mask, operands);
+        } else {
+            texture = OpImageSampleExplicitLod(t_float4, sampler, coords,
+                                               mask | spv::ImageOperandsMask::Lod, v_float_zero,
+                                               operands);
+        }
+        return GetTextureElement(operation, texture, Type::Float);
+    }
+
+    Expression TextureLod(Operation operation) {
+        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+
+        const Id sampler = GetTextureSampler(operation);
+        const Id coords = GetCoordinates(operation, Type::Float);
+        const Id lod = AsFloat(Visit(meta.lod));
+
+        spv::ImageOperandsMask mask = spv::ImageOperandsMask::Lod;
+        std::vector<Id> operands;
+        if (!meta.aoffi.empty()) {
+            mask = mask | spv::ImageOperandsMask::Offset;
+            operands.push_back(GetOffsetCoordinates(operation));
+        }
+
+        if (meta.sampler.IsShadow()) {
+            const Id dref = AsFloat(Visit(meta.depth_compare));
+            return {
+                OpImageSampleDrefExplicitLod(t_float, sampler, coords, dref, mask, lod, operands),
+                Type::Float};
+        }
+        const Id texture = OpImageSampleExplicitLod(t_float4, sampler, coords, mask, lod, operands);
+        return GetTextureElement(operation, texture, Type::Float);
+    }
+
+    Expression TextureGather(Operation operation) {
+        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+        UNIMPLEMENTED_IF(!meta.aoffi.empty());
+
+        const Id coords = GetCoordinates(operation, Type::Float);
+        Id texture{};
+        if (meta.sampler.IsShadow()) {
+            texture = OpImageDrefGather(t_float4, GetTextureSampler(operation), coords,
+                                        AsFloat(Visit(meta.depth_compare)));
         } else {
             u32 component_value = 0;
-            if (meta->component) {
-                const auto component = std::get_if<ImmediateNode>(&*meta->component);
+            if (meta.component) {
+                const auto component = std::get_if<ImmediateNode>(&*meta.component);
                 ASSERT_MSG(component, "Component is not an immediate value");
                 component_value = component->GetValue();
             }
-            texture = Emit(OpImageGather(t_float4, GetTextureSampler(operation), coords,
-                                         Constant(t_uint, component_value)));
+            texture = OpImageGather(t_float4, GetTextureSampler(operation), coords,
+                                    Constant(t_uint, component_value));
         }
-
-        return GetTextureElement(operation, texture);
+        return GetTextureElement(operation, texture, Type::Float);
     }
 
-    Id TextureQueryDimensions(Operation operation) {
-        const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
-        const auto image_id = GetTextureImage(operation);
-        AddCapability(spv::Capability::ImageQuery);
+    Expression TextureQueryDimensions(Operation operation) {
+        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+        UNIMPLEMENTED_IF(!meta.aoffi.empty());
+        UNIMPLEMENTED_IF(meta.depth_compare);
 
-        if (meta->element == 3) {
-            return BitcastTo<Type::Float>(Emit(OpImageQueryLevels(t_int, image_id)));
+        const auto image_id = GetTextureImage(operation);
+        if (meta.element == 3) {
+            return {OpImageQueryLevels(t_int, image_id), Type::Int};
         }
 
-        const Id lod = VisitOperand<Type::Uint>(operation, 0);
+        const Id lod = AsUint(Visit(operation[0]));
         const std::size_t coords_count = [&]() {
-            switch (const auto type = meta->sampler.GetType(); type) {
+            switch (const auto type = meta.sampler.GetType(); type) {
             case Tegra::Shader::TextureType::Texture1D:
                 return 1;
             case Tegra::Shader::TextureType::Texture2D:
@@ -963,141 +1664,190 @@ private:
             }
         }();
 
-        if (meta->element >= coords_count) {
-            return Constant(t_float, 0.0f);
+        if (meta.element >= coords_count) {
+            return {v_float_zero, Type::Float};
         }
 
         const std::array<Id, 3> types = {t_int, t_int2, t_int3};
-        const Id sizes = Emit(OpImageQuerySizeLod(types.at(coords_count - 1), image_id, lod));
-        const Id size = Emit(OpCompositeExtract(t_int, sizes, meta->element));
-        return BitcastTo<Type::Float>(size);
+        const Id sizes = OpImageQuerySizeLod(types.at(coords_count - 1), image_id, lod);
+        const Id size = OpCompositeExtract(t_int, sizes, meta.element);
+        return {size, Type::Int};
     }
 
-    Id TextureQueryLod(Operation operation) {
+    Expression TextureQueryLod(Operation operation) {
+        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+        UNIMPLEMENTED_IF(!meta.aoffi.empty());
+        UNIMPLEMENTED_IF(meta.depth_compare);
+
+        if (meta.element >= 2) {
+            UNREACHABLE_MSG("Invalid element");
+            return {v_float_zero, Type::Float};
+        }
+        const auto sampler_id = GetTextureSampler(operation);
+
+        const Id multiplier = Constant(t_float, 256.0f);
+        const Id multipliers = ConstantComposite(t_float2, multiplier, multiplier);
+
+        const Id coords = GetCoordinates(operation, Type::Float);
+        Id size = OpImageQueryLod(t_float2, sampler_id, coords);
+        size = OpFMul(t_float2, size, multipliers);
+        size = OpConvertFToS(t_int2, size);
+        return GetTextureElement(operation, size, Type::Int);
+    }
+
+    Expression TexelFetch(Operation operation) {
+        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+        UNIMPLEMENTED_IF(meta.depth_compare);
+
+        const Id image = GetTextureImage(operation);
+        const Id coords = GetCoordinates(operation, Type::Int);
+        Id fetch;
+        if (meta.lod && !meta.sampler.IsBuffer()) {
+            fetch = OpImageFetch(t_float4, image, coords, spv::ImageOperandsMask::Lod,
+                                 AsInt(Visit(meta.lod)));
+        } else {
+            fetch = OpImageFetch(t_float4, image, coords);
+        }
+        return GetTextureElement(operation, fetch, Type::Float);
+    }
+
+    Expression TextureGradient(Operation operation) {
+        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
+        UNIMPLEMENTED_IF(!meta.aoffi.empty());
+
+        const Id sampler = GetTextureSampler(operation);
+        const Id coords = GetCoordinates(operation, Type::Float);
+        const auto [dx, dy] = GetDerivatives(operation);
+        const std::vector grad = {dx, dy};
+
+        static constexpr auto mask = spv::ImageOperandsMask::Grad;
+        const Id texture = OpImageSampleImplicitLod(t_float4, sampler, coords, mask, grad);
+        return GetTextureElement(operation, texture, Type::Float);
+    }
+
+    Expression ImageLoad(Operation operation) {
         UNIMPLEMENTED();
         return {};
     }
 
-    Id TexelFetch(Operation operation) {
+    Expression ImageStore(Operation operation) {
+        const auto meta{std::get<MetaImage>(operation.GetMeta())};
+        std::vector<Id> colors;
+        for (const auto& value : meta.values) {
+            colors.push_back(AsUint(Visit(value)));
+        }
+
+        const Id coords = GetCoordinates(operation, Type::Int);
+        const Id texel = OpCompositeConstruct(t_uint4, colors);
+
+        OpImageWrite(GetImage(operation), coords, texel, {});
+        return {};
+    }
+
+    Expression AtomicImageAdd(Operation operation) {
         UNIMPLEMENTED();
         return {};
     }
 
-    Id TextureGradient(Operation operation) {
+    Expression AtomicImageMin(Operation operation) {
         UNIMPLEMENTED();
         return {};
     }
 
-    Id ImageLoad(Operation operation) {
+    Expression AtomicImageMax(Operation operation) {
         UNIMPLEMENTED();
         return {};
     }
 
-    Id ImageStore(Operation operation) {
+    Expression AtomicImageAnd(Operation operation) {
         UNIMPLEMENTED();
         return {};
     }
 
-    Id AtomicImageAdd(Operation operation) {
+    Expression AtomicImageOr(Operation operation) {
         UNIMPLEMENTED();
         return {};
     }
 
-    Id AtomicImageAnd(Operation operation) {
+    Expression AtomicImageXor(Operation operation) {
         UNIMPLEMENTED();
         return {};
     }
 
-    Id AtomicImageOr(Operation operation) {
+    Expression AtomicImageExchange(Operation operation) {
         UNIMPLEMENTED();
         return {};
     }
 
-    Id AtomicImageXor(Operation operation) {
-        UNIMPLEMENTED();
-        return {};
-    }
-
-    Id AtomicImageExchange(Operation operation) {
-        UNIMPLEMENTED();
-        return {};
-    }
-
-    Id Branch(Operation operation) {
-        const auto target = std::get_if<ImmediateNode>(&*operation[0]);
-        UNIMPLEMENTED_IF(!target);
-
-        Emit(OpStore(jmp_to, Constant(t_uint, target->GetValue())));
-        Emit(OpBranch(continue_label));
-        inside_branch = conditional_nest_count;
-        if (conditional_nest_count == 0) {
-            Emit(OpLabel());
+    Expression Branch(Operation operation) {
+        const auto& target = std::get<ImmediateNode>(*operation[0]);
+        OpStore(jmp_to, Constant(t_uint, target.GetValue()));
+        OpBranch(continue_label);
+        inside_branch = true;
+        if (!conditional_branch_set) {
+            AddLabel();
         }
         return {};
     }
 
-    Id BranchIndirect(Operation operation) {
-        const Id op_a = VisitOperand<Type::Uint>(operation, 0);
+    Expression BranchIndirect(Operation operation) {
+        const Id op_a = AsUint(Visit(operation[0]));
 
-        Emit(OpStore(jmp_to, op_a));
-        Emit(OpBranch(continue_label));
-        inside_branch = conditional_nest_count;
-        if (conditional_nest_count == 0) {
-            Emit(OpLabel());
+        OpStore(jmp_to, op_a);
+        OpBranch(continue_label);
+        inside_branch = true;
+        if (!conditional_branch_set) {
+            AddLabel();
         }
         return {};
     }
 
-    Id PushFlowStack(Operation operation) {
-        const auto target = std::get_if<ImmediateNode>(&*operation[0]);
-        ASSERT(target);
-
+    Expression PushFlowStack(Operation operation) {
+        const auto& target = std::get<ImmediateNode>(*operation[0]);
         const auto [flow_stack, flow_stack_top] = GetFlowStack(operation);
-        const Id current = Emit(OpLoad(t_uint, flow_stack_top));
-        const Id next = Emit(OpIAdd(t_uint, current, Constant(t_uint, 1)));
-        const Id access = Emit(OpAccessChain(t_func_uint, flow_stack, current));
+        const Id current = OpLoad(t_uint, flow_stack_top);
+        const Id next = OpIAdd(t_uint, current, Constant(t_uint, 1));
+        const Id access = OpAccessChain(t_func_uint, flow_stack, current);
 
-        Emit(OpStore(access, Constant(t_uint, target->GetValue())));
-        Emit(OpStore(flow_stack_top, next));
+        OpStore(access, Constant(t_uint, target.GetValue()));
+        OpStore(flow_stack_top, next);
         return {};
     }
 
-    Id PopFlowStack(Operation operation) {
+    Expression PopFlowStack(Operation operation) {
         const auto [flow_stack, flow_stack_top] = GetFlowStack(operation);
-        const Id current = Emit(OpLoad(t_uint, flow_stack_top));
-        const Id previous = Emit(OpISub(t_uint, current, Constant(t_uint, 1)));
-        const Id access = Emit(OpAccessChain(t_func_uint, flow_stack, previous));
-        const Id target = Emit(OpLoad(t_uint, access));
+        const Id current = OpLoad(t_uint, flow_stack_top);
+        const Id previous = OpISub(t_uint, current, Constant(t_uint, 1));
+        const Id access = OpAccessChain(t_func_uint, flow_stack, previous);
+        const Id target = OpLoad(t_uint, access);
 
-        Emit(OpStore(flow_stack_top, previous));
-        Emit(OpStore(jmp_to, target));
-        Emit(OpBranch(continue_label));
-        inside_branch = conditional_nest_count;
-        if (conditional_nest_count == 0) {
-            Emit(OpLabel());
+        OpStore(flow_stack_top, previous);
+        OpStore(jmp_to, target);
+        OpBranch(continue_label);
+        inside_branch = true;
+        if (!conditional_branch_set) {
+            AddLabel();
         }
         return {};
     }
 
-    Id PreExit() {
-        switch (stage) {
-        case ShaderType::Vertex: {
-            // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't
-            // seem to be working on Nvidia's drivers and Intel (mesa and blob) doesn't support it.
-            const Id z_pointer = AccessElement(t_out_float, per_vertex, position_index, 2u);
-            Id depth = Emit(OpLoad(t_float, z_pointer));
-            depth = Emit(OpFAdd(t_float, depth, Constant(t_float, 1.0f)));
-            depth = Emit(OpFMul(t_float, depth, Constant(t_float, 0.5f)));
-            Emit(OpStore(z_pointer, depth));
-            break;
+    void PreExit() {
+        if (stage == ShaderType::Vertex) {
+            const u32 position_index = out_indices.position.value();
+            const Id z_pointer = AccessElement(t_out_float, out_vertex, position_index, 2U);
+            const Id w_pointer = AccessElement(t_out_float, out_vertex, position_index, 3U);
+            Id depth = OpLoad(t_float, z_pointer);
+            depth = OpFAdd(t_float, depth, OpLoad(t_float, w_pointer));
+            depth = OpFMul(t_float, depth, Constant(t_float, 0.5f));
+            OpStore(z_pointer, depth);
         }
-        case ShaderType::Fragment: {
+        if (stage == ShaderType::Fragment) {
             const auto SafeGetRegister = [&](u32 reg) {
                 // TODO(Rodrigo): Replace with contains once C++20 releases
                 if (const auto it = registers.find(reg); it != registers.end()) {
-                    return Emit(OpLoad(t_float, it->second));
+                    return OpLoad(t_float, it->second);
                 }
-                return Constant(t_float, 0.0f);
+                return v_float_zero;
             };
 
             UNIMPLEMENTED_IF_MSG(header.ps.omap.sample_mask != 0,
@@ -1112,8 +1862,8 @@ private:
                 // TODO(Subv): Figure out how dual-source blending is configured in the Switch.
                 for (u32 component = 0; component < 4; ++component) {
                     if (header.ps.IsColorComponentOutputEnabled(rt, component)) {
-                        Emit(OpStore(AccessElement(t_out_float, frag_colors.at(rt), component),
-                                     SafeGetRegister(current_reg)));
+                        OpStore(AccessElement(t_out_float, frag_colors.at(rt), component),
+                                SafeGetRegister(current_reg));
                         ++current_reg;
                     }
                 }
@@ -1121,110 +1871,117 @@ private:
             if (header.ps.omap.depth) {
                 // The depth output is always 2 registers after the last color output, and
                 // current_reg already contains one past the last color register.
-                Emit(OpStore(frag_depth, SafeGetRegister(current_reg + 1)));
+                OpStore(frag_depth, SafeGetRegister(current_reg + 1));
             }
-            break;
         }
-        }
-
-        return {};
     }
 
-    Id Exit(Operation operation) {
+    Expression Exit(Operation operation) {
         PreExit();
-        inside_branch = conditional_nest_count;
-        if (conditional_nest_count > 0) {
-            Emit(OpReturn());
+        inside_branch = true;
+        if (conditional_branch_set) {
+            OpReturn();
         } else {
             const Id dummy = OpLabel();
-            Emit(OpBranch(dummy));
-            Emit(dummy);
-            Emit(OpReturn());
-            Emit(OpLabel());
+            OpBranch(dummy);
+            AddLabel(dummy);
+            OpReturn();
+            AddLabel();
         }
         return {};
     }
 
-    Id Discard(Operation operation) {
-        inside_branch = conditional_nest_count;
-        if (conditional_nest_count > 0) {
-            Emit(OpKill());
+    Expression Discard(Operation operation) {
+        inside_branch = true;
+        if (conditional_branch_set) {
+            OpKill();
         } else {
             const Id dummy = OpLabel();
-            Emit(OpBranch(dummy));
-            Emit(dummy);
-            Emit(OpKill());
-            Emit(OpLabel());
+            OpBranch(dummy);
+            AddLabel(dummy);
+            OpKill();
+            AddLabel();
         }
         return {};
     }
 
-    Id EmitVertex(Operation operation) {
-        UNIMPLEMENTED();
+    Expression EmitVertex(Operation) {
+        OpEmitVertex();
         return {};
     }
 
-    Id EndPrimitive(Operation operation) {
-        UNIMPLEMENTED();
+    Expression EndPrimitive(Operation operation) {
+        OpEndPrimitive();
         return {};
     }
 
-    Id YNegate(Operation operation) {
+    Expression InvocationId(Operation) {
+        return {OpLoad(t_int, invocation_id), Type::Int};
+    }
+
+    Expression YNegate(Operation) {
         UNIMPLEMENTED();
-        return {};
+        return {Constant(t_float, 1.0f), Type::Float};
     }
 
     template <u32 element>
-    Id LocalInvocationId(Operation) {
-        UNIMPLEMENTED();
-        return {};
+    Expression LocalInvocationId(Operation) {
+        const Id id = OpLoad(t_uint3, local_invocation_id);
+        return {OpCompositeExtract(t_uint, id, element), Type::Uint};
     }
 
     template <u32 element>
-    Id WorkGroupId(Operation) {
-        UNIMPLEMENTED();
-        return {};
+    Expression WorkGroupId(Operation operation) {
+        const Id id = OpLoad(t_uint3, workgroup_id);
+        return {OpCompositeExtract(t_uint, id, element), Type::Uint};
     }
 
-    Id BallotThread(Operation) {
-        UNIMPLEMENTED();
-        return {};
+    Expression BallotThread(Operation operation) {
+        const Id predicate = AsBool(Visit(operation[0]));
+        const Id ballot = OpSubgroupBallotKHR(t_uint4, predicate);
+
+        if (!device.IsWarpSizePotentiallyBiggerThanGuest()) {
+            // Guest-like devices can just return the first index.
+            return {OpCompositeExtract(t_uint, ballot, 0U), Type::Uint};
+        }
+
+        // The others will have to return what is local to the current thread.
+        // For instance a device with a warp size of 64 will return the upper uint when the current
+        // thread is 38.
+        const Id tid = OpLoad(t_uint, thread_id);
+        const Id thread_index = OpShiftRightLogical(t_uint, tid, Constant(t_uint, 5));
+        return {OpVectorExtractDynamic(t_uint, ballot, thread_index), Type::Uint};
     }
 
-    Id VoteAll(Operation) {
-        UNIMPLEMENTED();
-        return {};
+    template <Id (Module::*func)(Id, Id)>
+    Expression Vote(Operation operation) {
+        // TODO(Rodrigo): Handle devices with different warp sizes
+        const Id predicate = AsBool(Visit(operation[0]));
+        return {(this->*func)(t_bool, predicate), Type::Bool};
     }
 
-    Id VoteAny(Operation) {
-        UNIMPLEMENTED();
-        return {};
+    Expression ThreadId(Operation) {
+        return {OpLoad(t_uint, thread_id), Type::Uint};
     }
 
-    Id VoteEqual(Operation) {
-        UNIMPLEMENTED();
-        return {};
+    Expression ShuffleIndexed(Operation operation) {
+        const Id value = AsFloat(Visit(operation[0]));
+        const Id index = AsUint(Visit(operation[1]));
+        return {OpSubgroupReadInvocationKHR(t_float, value, index), Type::Float};
     }
 
-    Id ThreadId(Operation) {
-        UNIMPLEMENTED();
-        return {};
-    }
-
-    Id ShuffleIndexed(Operation) {
-        UNIMPLEMENTED();
-        return {};
-    }
-
-    Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type,
-                      const std::string& name) {
+    Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type, std::string name) {
         const Id id = OpVariable(type, storage);
         Decorate(id, spv::Decoration::BuiltIn, static_cast<u32>(builtin));
-        AddGlobalVariable(Name(id, name));
+        AddGlobalVariable(Name(id, std::move(name)));
         interfaces.push_back(id);
         return id;
     }
 
+    Id DeclareInputBuiltIn(spv::BuiltIn builtin, Id type, std::string name) {
+        return DeclareBuiltIn(builtin, spv::StorageClass::Input, type, std::move(name));
+    }
+
     bool IsRenderTargetUsed(u32 rt) const {
         for (u32 component = 0; component < 4; ++component) {
             if (header.ps.IsColorComponentOutputEnabled(rt, component)) {
@@ -1242,66 +1999,148 @@ private:
             members.push_back(Constant(t_uint, element));
         }
 
-        return Emit(OpAccessChain(pointer_type, composite, members));
+        return OpAccessChain(pointer_type, composite, members);
     }
 
-    template <Type type>
-    Id VisitOperand(Operation operation, std::size_t operand_index) {
-        const Id value = Visit(operation[operand_index]);
-
-        switch (type) {
+    Id As(Expression expr, Type wanted_type) {
+        switch (wanted_type) {
         case Type::Bool:
+            return AsBool(expr);
         case Type::Bool2:
+            return AsBool2(expr);
         case Type::Float:
-            return value;
+            return AsFloat(expr);
         case Type::Int:
-            return Emit(OpBitcast(t_int, value));
+            return AsInt(expr);
         case Type::Uint:
-            return Emit(OpBitcast(t_uint, value));
+            return AsUint(expr);
         case Type::HalfFloat:
-            UNIMPLEMENTED();
-        }
-        UNREACHABLE();
-        return value;
-    }
-
-    template <Type type>
-    Id BitcastFrom(Id value) {
-        switch (type) {
-        case Type::Bool:
-        case Type::Bool2:
-        case Type::Float:
-            return value;
-        case Type::Int:
-        case Type::Uint:
-            return Emit(OpBitcast(t_float, value));
-        case Type::HalfFloat:
-            UNIMPLEMENTED();
-        }
-        UNREACHABLE();
-        return value;
-    }
-
-    template <Type type>
-    Id BitcastTo(Id value) {
-        switch (type) {
-        case Type::Bool:
-        case Type::Bool2:
+            return AsHalfFloat(expr);
+        default:
             UNREACHABLE();
-        case Type::Float:
-            return Emit(OpBitcast(t_float, value));
-        case Type::Int:
-            return Emit(OpBitcast(t_int, value));
-        case Type::Uint:
-            return Emit(OpBitcast(t_uint, value));
-        case Type::HalfFloat:
-            UNIMPLEMENTED();
+            return expr.id;
+        }
+    }
+
+    Id AsBool(Expression expr) {
+        ASSERT(expr.type == Type::Bool);
+        return expr.id;
+    }
+
+    Id AsBool2(Expression expr) {
+        ASSERT(expr.type == Type::Bool2);
+        return expr.id;
+    }
+
+    Id AsFloat(Expression expr) {
+        switch (expr.type) {
+        case Type::Float:
+            return expr.id;
+        case Type::Int:
+        case Type::Uint:
+            return OpBitcast(t_float, expr.id);
+        case Type::HalfFloat:
+            if (device.IsFloat16Supported()) {
+                return OpBitcast(t_float, expr.id);
+            }
+            return OpBitcast(t_float, OpPackHalf2x16(t_uint, expr.id));
+        default:
+            UNREACHABLE();
+            return expr.id;
+        }
+    }
+
+    Id AsInt(Expression expr) {
+        switch (expr.type) {
+        case Type::Int:
+            return expr.id;
+        case Type::Float:
+        case Type::Uint:
+            return OpBitcast(t_int, expr.id);
+        case Type::HalfFloat:
+            if (device.IsFloat16Supported()) {
+                return OpBitcast(t_int, expr.id);
+            }
+            return OpPackHalf2x16(t_int, expr.id);
+        default:
+            UNREACHABLE();
+            return expr.id;
+        }
+    }
+
+    Id AsUint(Expression expr) {
+        switch (expr.type) {
+        case Type::Uint:
+            return expr.id;
+        case Type::Float:
+        case Type::Int:
+            return OpBitcast(t_uint, expr.id);
+        case Type::HalfFloat:
+            if (device.IsFloat16Supported()) {
+                return OpBitcast(t_uint, expr.id);
+            }
+            return OpPackHalf2x16(t_uint, expr.id);
+        default:
+            UNREACHABLE();
+            return expr.id;
+        }
+    }
+
+    Id AsHalfFloat(Expression expr) {
+        switch (expr.type) {
+        case Type::HalfFloat:
+            return expr.id;
+        case Type::Float:
+        case Type::Int:
+        case Type::Uint:
+            if (device.IsFloat16Supported()) {
+                return OpBitcast(t_half, expr.id);
+            }
+            return OpUnpackHalf2x16(t_half, AsUint(expr));
+        default:
+            UNREACHABLE();
+            return expr.id;
+        }
+    }
+
+    Id GetHalfScalarFromFloat(Id value) {
+        if (device.IsFloat16Supported()) {
+            return OpFConvert(t_scalar_half, value);
         }
-        UNREACHABLE();
         return value;
     }
 
-    Id GetTypeDefinition(Type type) {
+    Id GetFloatFromHalfScalar(Id value) {
+        if (device.IsFloat16Supported()) {
+            return OpFConvert(t_float, value);
+        }
+        return value;
+    }
+
+    AttributeType GetAttributeType(u32 location) const {
+        if (stage != ShaderType::Vertex) {
+            return {Type::Float, t_in_float, t_in_float4};
+        }
+        switch (specialization.attribute_types.at(location)) {
+        case Maxwell::VertexAttribute::Type::SignedNorm:
+        case Maxwell::VertexAttribute::Type::UnsignedNorm:
+        case Maxwell::VertexAttribute::Type::Float:
+            return {Type::Float, t_in_float, t_in_float4};
+        case Maxwell::VertexAttribute::Type::SignedInt:
+            return {Type::Int, t_in_int, t_in_int4};
+        case Maxwell::VertexAttribute::Type::UnsignedInt:
+            return {Type::Uint, t_in_uint, t_in_uint4};
+        case Maxwell::VertexAttribute::Type::UnsignedScaled:
+        case Maxwell::VertexAttribute::Type::SignedScaled:
+            UNIMPLEMENTED();
+            return {Type::Float, t_in_float, t_in_float4};
+        default:
+            UNREACHABLE();
+            return {Type::Float, t_in_float, t_in_float4};
+        }
+    }
+
+    Id GetTypeDefinition(Type type) const {
         switch (type) {
         case Type::Bool:
             return t_bool;
@@ -1314,10 +2153,25 @@ private:
         case Type::Uint:
             return t_uint;
         case Type::HalfFloat:
-            UNIMPLEMENTED();
+            return t_half;
+        default:
+            UNREACHABLE();
+            return {};
+        }
+    }
+
+    std::array<Id, 4> GetTypeVectorDefinitionLut(Type type) const {
+        switch (type) {
+        case Type::Float:
+            return {nullptr, t_float2, t_float3, t_float4};
+        case Type::Int:
+            return {nullptr, t_int2, t_int3, t_int4};
+        case Type::Uint:
+            return {nullptr, t_uint2, t_uint3, t_uint4};
+        default:
+            UNIMPLEMENTED();
+            return {};
         }
-        UNREACHABLE();
-        return {};
     }
 
     std::tuple<Id, Id> CreateFlowStack() {
@@ -1327,9 +2181,11 @@ private:
         constexpr auto storage_class = spv::StorageClass::Function;
 
         const Id flow_stack_type = TypeArray(t_uint, Constant(t_uint, FLOW_STACK_SIZE));
-        const Id stack = Emit(OpVariable(TypePointer(storage_class, flow_stack_type), storage_class,
-                                         ConstantNull(flow_stack_type)));
-        const Id top = Emit(OpVariable(t_func_uint, storage_class, Constant(t_uint, 0)));
+        const Id stack = OpVariable(TypePointer(storage_class, flow_stack_type), storage_class,
+                                    ConstantNull(flow_stack_type));
+        const Id top = OpVariable(t_func_uint, storage_class, Constant(t_uint, 0));
+        AddLocalVariable(stack);
+        AddLocalVariable(top);
         return std::tie(stack, top);
     }
 
@@ -1358,8 +2214,8 @@ private:
         &SPIRVDecompiler::Unary<&Module::OpFNegate, Type::Float>,
         &SPIRVDecompiler::Unary<&Module::OpFAbs, Type::Float>,
         &SPIRVDecompiler::Ternary<&Module::OpFClamp, Type::Float>,
-        &SPIRVDecompiler::FCastHalf0,
-        &SPIRVDecompiler::FCastHalf1,
+        &SPIRVDecompiler::FCastHalf<0>,
+        &SPIRVDecompiler::FCastHalf<1>,
         &SPIRVDecompiler::Binary<&Module::OpFMin, Type::Float>,
         &SPIRVDecompiler::Binary<&Module::OpFMax, Type::Float>,
         &SPIRVDecompiler::Unary<&Module::OpCos, Type::Float>,
@@ -1407,7 +2263,7 @@ private:
         &SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Uint, Type::Int>,
         &SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Uint>,
         &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>,
-        &SPIRVDecompiler::Binary<&Module::OpShiftRightArithmetic, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>,
         &SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Uint>,
         &SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Uint>,
         &SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Uint>,
@@ -1426,8 +2282,8 @@ private:
         &SPIRVDecompiler::HCastFloat,
         &SPIRVDecompiler::HUnpack,
         &SPIRVDecompiler::HMergeF32,
-        &SPIRVDecompiler::HMergeH0,
-        &SPIRVDecompiler::HMergeH1,
+        &SPIRVDecompiler::HMergeHN<0>,
+        &SPIRVDecompiler::HMergeHN<1>,
         &SPIRVDecompiler::HPack2,
 
         &SPIRVDecompiler::LogicalAssign,
@@ -1435,8 +2291,9 @@ private:
         &SPIRVDecompiler::Binary<&Module::OpLogicalOr, Type::Bool>,
         &SPIRVDecompiler::Binary<&Module::OpLogicalNotEqual, Type::Bool>,
         &SPIRVDecompiler::Unary<&Module::OpLogicalNot, Type::Bool>,
-        &SPIRVDecompiler::LogicalPick2,
-        &SPIRVDecompiler::LogicalAnd2,
+        &SPIRVDecompiler::Binary<&Module::OpVectorExtractDynamic, Type::Bool, Type::Bool2,
+                                 Type::Uint>,
+        &SPIRVDecompiler::Unary<&Module::OpAll, Type::Bool, Type::Bool2>,
 
         &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::Float>,
         &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::Float>,
@@ -1444,7 +2301,7 @@ private:
         &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::Float>,
         &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::Float>,
         &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::Float>,
-        &SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool>,
+        &SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool, Type::Float>,
 
         &SPIRVDecompiler::Binary<&Module::OpSLessThan, Type::Bool, Type::Int>,
         &SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Int>,
@@ -1460,19 +2317,19 @@ private:
         &SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Uint>,
         &SPIRVDecompiler::Binary<&Module::OpUGreaterThanEqual, Type::Bool, Type::Uint>,
 
-        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::HalfFloat>,
-        &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::HalfFloat>,
-        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::HalfFloat>,
-        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::HalfFloat>,
-        &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::HalfFloat>,
-        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool2, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool2, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool2, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool2, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool2, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool2, Type::HalfFloat>,
         // TODO(Rodrigo): Should these use the OpFUnord* variants?
-        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::HalfFloat>,
-        &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::HalfFloat>,
-        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::HalfFloat>,
-        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::HalfFloat>,
-        &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::HalfFloat>,
-        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool2, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool2, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool2, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool2, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool2, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool2, Type::HalfFloat>,
 
         &SPIRVDecompiler::Texture,
         &SPIRVDecompiler::TextureLod,
@@ -1509,9 +2366,9 @@ private:
         &SPIRVDecompiler::WorkGroupId<2>,
 
         &SPIRVDecompiler::BallotThread,
-        &SPIRVDecompiler::VoteAll,
-        &SPIRVDecompiler::VoteAny,
-        &SPIRVDecompiler::VoteEqual,
+        &SPIRVDecompiler::Vote<&Module::OpSubgroupAllKHR>,
+        &SPIRVDecompiler::Vote<&Module::OpSubgroupAnyKHR>,
+        &SPIRVDecompiler::Vote<&Module::OpSubgroupAllEqualKHR>,
 
         &SPIRVDecompiler::ThreadId,
         &SPIRVDecompiler::ShuffleIndexed,
@@ -1522,8 +2379,7 @@ private:
     const ShaderIR& ir;
     const ShaderType stage;
     const Tegra::Shader::Header header;
-    u64 conditional_nest_count{};
-    u64 inside_branch{};
+    const Specialization& specialization;
 
     const Id t_void = Name(TypeVoid(), "void");
 
@@ -1551,20 +2407,28 @@ private:
     const Id t_func_uint = Name(TypePointer(spv::StorageClass::Function, t_uint), "func_uint");
 
     const Id t_in_bool = Name(TypePointer(spv::StorageClass::Input, t_bool), "in_bool");
+    const Id t_in_int = Name(TypePointer(spv::StorageClass::Input, t_int), "in_int");
+    const Id t_in_int4 = Name(TypePointer(spv::StorageClass::Input, t_int4), "in_int4");
     const Id t_in_uint = Name(TypePointer(spv::StorageClass::Input, t_uint), "in_uint");
+    const Id t_in_uint3 = Name(TypePointer(spv::StorageClass::Input, t_uint3), "in_uint3");
+    const Id t_in_uint4 = Name(TypePointer(spv::StorageClass::Input, t_uint4), "in_uint4");
     const Id t_in_float = Name(TypePointer(spv::StorageClass::Input, t_float), "in_float");
+    const Id t_in_float2 = Name(TypePointer(spv::StorageClass::Input, t_float2), "in_float2");
+    const Id t_in_float3 = Name(TypePointer(spv::StorageClass::Input, t_float3), "in_float3");
     const Id t_in_float4 = Name(TypePointer(spv::StorageClass::Input, t_float4), "in_float4");
 
+    const Id t_out_int = Name(TypePointer(spv::StorageClass::Output, t_int), "out_int");
+
     const Id t_out_float = Name(TypePointer(spv::StorageClass::Output, t_float), "out_float");
     const Id t_out_float4 = Name(TypePointer(spv::StorageClass::Output, t_float4), "out_float4");
 
     const Id t_cbuf_float = TypePointer(spv::StorageClass::Uniform, t_float);
     const Id t_cbuf_std140 = Decorate(
-        Name(TypeArray(t_float4, Constant(t_uint, MAX_CONSTBUFFER_ELEMENTS)), "CbufStd140Array"),
-        spv::Decoration::ArrayStride, 16u);
+        Name(TypeArray(t_float4, Constant(t_uint, MaxConstBufferElements)), "CbufStd140Array"),
+        spv::Decoration::ArrayStride, 16U);
     const Id t_cbuf_scalar = Decorate(
-        Name(TypeArray(t_float, Constant(t_uint, MAX_CONSTBUFFER_FLOATS)), "CbufScalarArray"),
-        spv::Decoration::ArrayStride, 4u);
+        Name(TypeArray(t_float, Constant(t_uint, MaxConstBufferFloats)), "CbufScalarArray"),
+        spv::Decoration::ArrayStride, 4U);
     const Id t_cbuf_std140_struct = MemberDecorate(
         Decorate(TypeStruct(t_cbuf_std140), spv::Decoration::Block), 0, spv::Decoration::Offset, 0);
     const Id t_cbuf_scalar_struct = MemberDecorate(
@@ -1572,28 +2436,43 @@ private:
     const Id t_cbuf_std140_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_std140_struct);
     const Id t_cbuf_scalar_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_scalar_struct);
 
+    Id t_smem_uint{};
+
     const Id t_gmem_float = TypePointer(spv::StorageClass::StorageBuffer, t_float);
     const Id t_gmem_array =
-        Name(Decorate(TypeRuntimeArray(t_float), spv::Decoration::ArrayStride, 4u), "GmemArray");
+        Name(Decorate(TypeRuntimeArray(t_float), spv::Decoration::ArrayStride, 4U), "GmemArray");
     const Id t_gmem_struct = MemberDecorate(
         Decorate(TypeStruct(t_gmem_array), spv::Decoration::Block), 0, spv::Decoration::Offset, 0);
     const Id t_gmem_ssbo = TypePointer(spv::StorageClass::StorageBuffer, t_gmem_struct);
 
     const Id v_float_zero = Constant(t_float, 0.0f);
+    const Id v_float_one = Constant(t_float, 1.0f);
+
+    // Nvidia uses these defaults for varyings (e.g. position and generic attributes)
+    const Id v_varying_default =
+        ConstantComposite(t_float4, v_float_zero, v_float_zero, v_float_zero, v_float_one);
+
     const Id v_true = ConstantTrue(t_bool);
     const Id v_false = ConstantFalse(t_bool);
 
-    Id per_vertex{};
+    Id t_scalar_half{};
+    Id t_half{};
+
+    Id out_vertex{};
+    Id in_vertex{};
     std::map<u32, Id> registers;
     std::map<Tegra::Shader::Pred, Id> predicates;
     std::map<u32, Id> flow_variables;
     Id local_memory{};
+    Id shared_memory{};
     std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{};
     std::map<Attribute::Index, Id> input_attributes;
     std::map<Attribute::Index, Id> output_attributes;
     std::map<u32, Id> constant_buffers;
     std::map<GlobalMemoryBase, Id> global_buffers;
-    std::map<u32, SamplerImage> sampler_images;
+    std::map<u32, TexelBuffer> texel_buffers;
+    std::map<u32, SampledImage> sampled_images;
+    std::map<u32, StorageImage> images;
 
     Id instance_index{};
     Id vertex_index{};
@@ -1601,18 +2480,20 @@ private:
     Id frag_depth{};
     Id frag_coord{};
     Id front_facing{};
+    Id point_coord{};
+    Id tess_level_outer{};
+    Id tess_level_inner{};
+    Id tess_coord{};
+    Id invocation_id{};
+    Id workgroup_id{};
+    Id local_invocation_id{};
+    Id thread_id{};
 
-    u32 position_index{};
-    u32 point_size_index{};
-    u32 clip_distances_index{};
+    VertexIndices in_indices;
+    VertexIndices out_indices;
 
     std::vector<Id> interfaces;
 
-    u32 const_buffers_base_binding{};
-    u32 global_buffers_base_binding{};
-    u32 samplers_base_binding{};
-
-    Id execute_function{};
     Id jmp_to{};
     Id ssy_flow_stack_top{};
     Id pbk_flow_stack_top{};
@@ -1620,6 +2501,9 @@ private:
     Id pbk_flow_stack{};
     Id continue_label{};
     std::map<u32, Id> labels;
+
+    bool conditional_branch_set{};
+    bool inside_branch{};
 };
 
 class ExprDecompiler {
@@ -1630,25 +2514,25 @@ public:
         const Id type_def = decomp.GetTypeDefinition(Type::Bool);
         const Id op1 = Visit(expr.operand1);
         const Id op2 = Visit(expr.operand2);
-        return decomp.Emit(decomp.OpLogicalAnd(type_def, op1, op2));
+        return decomp.OpLogicalAnd(type_def, op1, op2);
     }
 
     Id operator()(const ExprOr& expr) {
         const Id type_def = decomp.GetTypeDefinition(Type::Bool);
         const Id op1 = Visit(expr.operand1);
         const Id op2 = Visit(expr.operand2);
-        return decomp.Emit(decomp.OpLogicalOr(type_def, op1, op2));
+        return decomp.OpLogicalOr(type_def, op1, op2);
     }
 
     Id operator()(const ExprNot& expr) {
         const Id type_def = decomp.GetTypeDefinition(Type::Bool);
         const Id op1 = Visit(expr.operand1);
-        return decomp.Emit(decomp.OpLogicalNot(type_def, op1));
+        return decomp.OpLogicalNot(type_def, op1);
     }
 
     Id operator()(const ExprPredicate& expr) {
         const auto pred = static_cast<Tegra::Shader::Pred>(expr.predicate);
-        return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred)));
+        return decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred));
     }
 
     Id operator()(const ExprCondCode& expr) {
@@ -1670,12 +2554,15 @@ public:
             }
         } else if (const auto flag = std::get_if<InternalFlagNode>(&*cc)) {
             target = decomp.internal_flags.at(static_cast<u32>(flag->GetFlag()));
+        } else {
+            UNREACHABLE();
         }
-        return decomp.Emit(decomp.OpLoad(decomp.t_bool, target));
+
+        return decomp.OpLoad(decomp.t_bool, target);
     }
 
     Id operator()(const ExprVar& expr) {
-        return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index)));
+        return decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index));
     }
 
     Id operator()(const ExprBoolean& expr) {
@@ -1684,9 +2571,9 @@ public:
 
     Id operator()(const ExprGprEqual& expr) {
         const Id target = decomp.Constant(decomp.t_uint, expr.value);
-        const Id gpr = decomp.BitcastTo<Type::Uint>(
-            decomp.Emit(decomp.OpLoad(decomp.t_float, decomp.registers.at(expr.gpr))));
-        return decomp.Emit(decomp.OpLogicalEqual(decomp.t_uint, gpr, target));
+        Id gpr = decomp.OpLoad(decomp.t_float, decomp.registers.at(expr.gpr));
+        gpr = decomp.OpBitcast(decomp.t_uint, gpr);
+        return decomp.OpLogicalEqual(decomp.t_uint, gpr, target);
     }
 
     Id Visit(const Expr& node) {
@@ -1714,16 +2601,16 @@ public:
         const Id condition = expr_parser.Visit(ast.condition);
         const Id then_label = decomp.OpLabel();
         const Id endif_label = decomp.OpLabel();
-        decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
-        decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
-        decomp.Emit(then_label);
+        decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
+        decomp.OpBranchConditional(condition, then_label, endif_label);
+        decomp.AddLabel(then_label);
         ASTNode current = ast.nodes.GetFirst();
         while (current) {
             Visit(current);
             current = current->GetNext();
         }
-        decomp.Emit(decomp.OpBranch(endif_label));
-        decomp.Emit(endif_label);
+        decomp.OpBranch(endif_label);
+        decomp.AddLabel(endif_label);
     }
 
     void operator()([[maybe_unused]] const ASTIfElse& ast) {
@@ -1741,7 +2628,7 @@ public:
     void operator()(const ASTVarSet& ast) {
         ExprDecompiler expr_parser{decomp};
         const Id condition = expr_parser.Visit(ast.condition);
-        decomp.Emit(decomp.OpStore(decomp.flow_variables.at(ast.index), condition));
+        decomp.OpStore(decomp.flow_variables.at(ast.index), condition);
     }
 
     void operator()([[maybe_unused]] const ASTLabel& ast) {
@@ -1758,12 +2645,11 @@ public:
         const Id loop_start_block = decomp.OpLabel();
         const Id loop_end_block = decomp.OpLabel();
         current_loop_exit = endloop_label;
-        decomp.Emit(decomp.OpBranch(loop_label));
-        decomp.Emit(loop_label);
-        decomp.Emit(
-            decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone));
-        decomp.Emit(decomp.OpBranch(loop_start_block));
-        decomp.Emit(loop_start_block);
+        decomp.OpBranch(loop_label);
+        decomp.AddLabel(loop_label);
+        decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone);
+        decomp.OpBranch(loop_start_block);
+        decomp.AddLabel(loop_start_block);
         ASTNode current = ast.nodes.GetFirst();
         while (current) {
             Visit(current);
@@ -1771,8 +2657,8 @@ public:
         }
         ExprDecompiler expr_parser{decomp};
         const Id condition = expr_parser.Visit(ast.condition);
-        decomp.Emit(decomp.OpBranchConditional(condition, loop_label, endloop_label));
-        decomp.Emit(endloop_label);
+        decomp.OpBranchConditional(condition, loop_label, endloop_label);
+        decomp.AddLabel(endloop_label);
     }
 
     void operator()(const ASTReturn& ast) {
@@ -1781,27 +2667,27 @@ public:
             const Id condition = expr_parser.Visit(ast.condition);
             const Id then_label = decomp.OpLabel();
             const Id endif_label = decomp.OpLabel();
-            decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
-            decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
-            decomp.Emit(then_label);
+            decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
+            decomp.OpBranchConditional(condition, then_label, endif_label);
+            decomp.AddLabel(then_label);
             if (ast.kills) {
-                decomp.Emit(decomp.OpKill());
+                decomp.OpKill();
             } else {
                 decomp.PreExit();
-                decomp.Emit(decomp.OpReturn());
+                decomp.OpReturn();
             }
-            decomp.Emit(endif_label);
+            decomp.AddLabel(endif_label);
         } else {
             const Id next_block = decomp.OpLabel();
-            decomp.Emit(decomp.OpBranch(next_block));
-            decomp.Emit(next_block);
+            decomp.OpBranch(next_block);
+            decomp.AddLabel(next_block);
             if (ast.kills) {
-                decomp.Emit(decomp.OpKill());
+                decomp.OpKill();
             } else {
                 decomp.PreExit();
-                decomp.Emit(decomp.OpReturn());
+                decomp.OpReturn();
             }
-            decomp.Emit(decomp.OpLabel());
+            decomp.AddLabel(decomp.OpLabel());
         }
     }
 
@@ -1811,17 +2697,17 @@ public:
             const Id condition = expr_parser.Visit(ast.condition);
             const Id then_label = decomp.OpLabel();
             const Id endif_label = decomp.OpLabel();
-            decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
-            decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
-            decomp.Emit(then_label);
-            decomp.Emit(decomp.OpBranch(current_loop_exit));
-            decomp.Emit(endif_label);
+            decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
+            decomp.OpBranchConditional(condition, then_label, endif_label);
+            decomp.AddLabel(then_label);
+            decomp.OpBranch(current_loop_exit);
+            decomp.AddLabel(endif_label);
         } else {
             const Id next_block = decomp.OpLabel();
-            decomp.Emit(decomp.OpBranch(next_block));
-            decomp.Emit(next_block);
-            decomp.Emit(decomp.OpBranch(current_loop_exit));
-            decomp.Emit(decomp.OpLabel());
+            decomp.OpBranch(next_block);
+            decomp.AddLabel(next_block);
+            decomp.OpBranch(current_loop_exit);
+            decomp.AddLabel(decomp.OpLabel());
         }
     }
 
@@ -1842,20 +2728,51 @@ void SPIRVDecompiler::DecompileAST() {
         flow_variables.emplace(i, AddGlobalVariable(id));
     }
 
+    DefinePrologue();
+
     const ASTNode program = ir.GetASTProgram();
     ASTDecompiler decompiler{*this};
     decompiler.Visit(program);
 
     const Id next_block = OpLabel();
-    Emit(OpBranch(next_block));
-    Emit(next_block);
+    OpBranch(next_block);
+    AddLabel(next_block);
 }
 
-DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
-                           ShaderType stage) {
-    auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage);
-    decompiler->Decompile();
-    return {std::move(decompiler), decompiler->GetShaderEntries()};
+} // Anonymous namespace
+
+ShaderEntries GenerateShaderEntries(const VideoCommon::Shader::ShaderIR& ir) {
+    ShaderEntries entries;
+    for (const auto& cbuf : ir.GetConstantBuffers()) {
+        entries.const_buffers.emplace_back(cbuf.second, cbuf.first);
+    }
+    for (const auto& [base, usage] : ir.GetGlobalMemory()) {
+        entries.global_buffers.emplace_back(base.cbuf_index, base.cbuf_offset, usage.is_written);
+    }
+    for (const auto& sampler : ir.GetSamplers()) {
+        if (sampler.IsBuffer()) {
+            entries.texel_buffers.emplace_back(sampler);
+        } else {
+            entries.samplers.emplace_back(sampler);
+        }
+    }
+    for (const auto& image : ir.GetImages()) {
+        entries.images.emplace_back(image);
+    }
+    for (const auto& attribute : ir.GetInputAttributes()) {
+        if (IsGenericAttribute(attribute)) {
+            entries.attributes.insert(GetGenericAttributeLocation(attribute));
+        }
+    }
+    entries.clip_distances = ir.GetClipDistances();
+    entries.shader_length = ir.GetLength();
+    entries.uses_warps = ir.UsesWarps();
+    return entries;
 }
 
-} // namespace Vulkan::VKShader
+std::vector<u32> Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
+                           ShaderType stage, const Specialization& specialization) {
+    return SPIRVDecompiler(device, ir, stage, specialization).Assemble();
+}
+
+} // namespace Vulkan
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.h b/src/video_core/renderer_vulkan/vk_shader_decompiler.h
index 203fc00d0..2b01321b6 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.h
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.h
@@ -5,29 +5,28 @@
 #pragma once
 
 #include <array>
+#include <bitset>
 #include <memory>
 #include <set>
+#include <type_traits>
 #include <utility>
 #include <vector>
 
-#include <sirit/sirit.h>
-
 #include "common/common_types.h"
 #include "video_core/engines/maxwell_3d.h"
+#include "video_core/engines/shader_type.h"
 #include "video_core/shader/shader_ir.h"
 
-namespace VideoCommon::Shader {
-class ShaderIR;
-}
-
 namespace Vulkan {
 class VKDevice;
 }
 
-namespace Vulkan::VKShader {
+namespace Vulkan {
 
 using Maxwell = Tegra::Engines::Maxwell3D::Regs;
+using TexelBufferEntry = VideoCommon::Shader::Sampler;
 using SamplerEntry = VideoCommon::Shader::Sampler;
+using ImageEntry = VideoCommon::Shader::Image;
 
 constexpr u32 DESCRIPTOR_SET = 0;
 
@@ -46,39 +45,74 @@ private:
 
 class GlobalBufferEntry {
 public:
-    explicit GlobalBufferEntry(u32 cbuf_index, u32 cbuf_offset)
-        : cbuf_index{cbuf_index}, cbuf_offset{cbuf_offset} {}
+    constexpr explicit GlobalBufferEntry(u32 cbuf_index, u32 cbuf_offset, bool is_written)
+        : cbuf_index{cbuf_index}, cbuf_offset{cbuf_offset}, is_written{is_written} {}
 
-    u32 GetCbufIndex() const {
+    constexpr u32 GetCbufIndex() const {
         return cbuf_index;
     }
 
-    u32 GetCbufOffset() const {
+    constexpr u32 GetCbufOffset() const {
         return cbuf_offset;
     }
 
+    constexpr bool IsWritten() const {
+        return is_written;
+    }
+
 private:
     u32 cbuf_index{};
     u32 cbuf_offset{};
+    bool is_written{};
 };
 
 struct ShaderEntries {
-    u32 const_buffers_base_binding{};
-    u32 global_buffers_base_binding{};
-    u32 samplers_base_binding{};
+    u32 NumBindings() const {
+        return static_cast<u32>(const_buffers.size() + global_buffers.size() +
+                                texel_buffers.size() + samplers.size() + images.size());
+    }
+
     std::vector<ConstBufferEntry> const_buffers;
     std::vector<GlobalBufferEntry> global_buffers;
+    std::vector<TexelBufferEntry> texel_buffers;
     std::vector<SamplerEntry> samplers;
+    std::vector<ImageEntry> images;
     std::set<u32> attributes;
     std::array<bool, Maxwell::NumClipDistances> clip_distances{};
     std::size_t shader_length{};
-    Sirit::Id entry_function{};
-    std::vector<Sirit::Id> interfaces;
+    bool uses_warps{};
 };
 
-using DecompilerResult = std::pair<std::unique_ptr<Sirit::Module>, ShaderEntries>;
+struct Specialization final {
+    u32 base_binding{};
 
-DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
-                           Tegra::Engines::ShaderType stage);
+    // Compute specific
+    std::array<u32, 3> workgroup_size{};
+    u32 shared_memory_size{};
 
-} // namespace Vulkan::VKShader
+    // Graphics specific
+    Maxwell::PrimitiveTopology primitive_topology{};
+    std::optional<float> point_size{};
+    std::array<Maxwell::VertexAttribute::Type, Maxwell::NumVertexAttributes> attribute_types{};
+
+    // Tessellation specific
+    struct {
+        Maxwell::TessellationPrimitive primitive{};
+        Maxwell::TessellationSpacing spacing{};
+        bool clockwise{};
+    } tessellation;
+};
+// Old gcc versions don't consider this trivially copyable.
+// static_assert(std::is_trivially_copyable_v<Specialization>);
+
+struct SPIRVShader {
+    std::vector<u32> code;
+    ShaderEntries entries;
+};
+
+ShaderEntries GenerateShaderEntries(const VideoCommon::Shader::ShaderIR& ir);
+
+std::vector<u32> Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
+                           Tegra::Engines::ShaderType stage, const Specialization& specialization);
+
+} // namespace Vulkan