From d5d468cf2cbe235ee149dbd37951389d2a7e61da Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Mon, 15 Feb 2021 00:09:11 -0300
Subject: [PATCH] shader: Improve object pool

---
 .../frontend/ir/structured_control_flow.cpp   | 10 +--
 src/shader_recompiler/main.cpp                | 22 ++---
 src/shader_recompiler/object_pool.h           | 84 +++++++++++--------
 3 files changed, 66 insertions(+), 50 deletions(-)

diff --git a/src/shader_recompiler/frontend/ir/structured_control_flow.cpp b/src/shader_recompiler/frontend/ir/structured_control_flow.cpp
index 2e9ce2525..d145095d1 100644
--- a/src/shader_recompiler/frontend/ir/structured_control_flow.cpp
+++ b/src/shader_recompiler/frontend/ir/structured_control_flow.cpp
@@ -269,7 +269,7 @@ bool SearchNode(const Tree& tree, ConstNode stmt, size_t& offset) {
 
 class GotoPass {
 public:
-    explicit GotoPass(std::span<Block* const> blocks, ObjectPool<Statement, 64>& stmt_pool)
+    explicit GotoPass(std::span<Block* const> blocks, ObjectPool<Statement>& stmt_pool)
         : pool{stmt_pool} {
         std::vector gotos{BuildUnorderedTreeGetGotos(blocks)};
         fmt::print(stdout, "BEFORE\n{}\n", DumpTree(root_stmt.children));
@@ -554,7 +554,7 @@ private:
         return offset;
     }
 
-    ObjectPool<Statement, 64>& pool;
+    ObjectPool<Statement>& pool;
     Statement root_stmt{FunctionTag{}};
 };
 
@@ -589,7 +589,7 @@ Block* TryFindForwardBlock(const Statement& stmt) {
 class TranslatePass {
 public:
     TranslatePass(ObjectPool<Inst>& inst_pool_, ObjectPool<Block>& block_pool_,
-                  ObjectPool<Statement, 64>& stmt_pool_, Statement& root_stmt,
+                  ObjectPool<Statement>& stmt_pool_, Statement& root_stmt,
                   const std::function<void(IR::Block*)>& func_, BlockList& block_list_)
         : stmt_pool{stmt_pool_}, inst_pool{inst_pool_}, block_pool{block_pool_}, func{func_},
           block_list{block_list_} {
@@ -720,7 +720,7 @@ private:
         return block;
     }
 
-    ObjectPool<Statement, 64>& stmt_pool;
+    ObjectPool<Statement>& stmt_pool;
     ObjectPool<Inst>& inst_pool;
     ObjectPool<Block>& block_pool;
     const std::function<void(IR::Block*)>& func;
@@ -731,7 +731,7 @@ private:
 BlockList VisitAST(ObjectPool<Inst>& inst_pool, ObjectPool<Block>& block_pool,
                    std::span<Block* const> unordered_blocks,
                    const std::function<void(Block*)>& func) {
-    ObjectPool<Statement, 64> stmt_pool;
+    ObjectPool<Statement> stmt_pool{64};
     GotoPass goto_pass{unordered_blocks, stmt_pool};
     BlockList block_list;
     TranslatePass translate_pass{inst_pool, block_pool, stmt_pool, goto_pass.RootStatement(),
diff --git a/src/shader_recompiler/main.cpp b/src/shader_recompiler/main.cpp
index 3b110af61..216345e91 100644
--- a/src/shader_recompiler/main.cpp
+++ b/src/shader_recompiler/main.cpp
@@ -37,7 +37,7 @@ void RunDatabase() {
     ForEachFile("D:\\Shaders\\Database", [&](const std::filesystem::path& path) {
         map.emplace_back(std::make_unique<FileEnvironment>(path.string().c_str()));
     });
-    auto block_pool{std::make_unique<ObjectPool<Flow::Block>>()};
+    ObjectPool<Flow::Block> block_pool;
     using namespace std::chrono;
     auto t0 = high_resolution_clock::now();
     int N = 1;
@@ -48,8 +48,8 @@ void RunDatabase() {
             // fmt::print(stdout, "Decoding {}\n", path.string());
 
             const Location start_address{0};
-            block_pool->ReleaseContents();
-            Flow::CFG cfg{*env, *block_pool, start_address};
+            block_pool.ReleaseContents();
+            Flow::CFG cfg{*env, block_pool, start_address};
             // fmt::print(stdout, "{}\n", cfg->Dot());
             // IR::Program program{env, cfg};
             // Optimize(program);
@@ -63,18 +63,18 @@ void RunDatabase() {
 int main() {
     // RunDatabase();
 
-    auto flow_block_pool{std::make_unique<ObjectPool<Flow::Block>>()};
-    auto inst_pool{std::make_unique<ObjectPool<IR::Inst>>()};
-    auto block_pool{std::make_unique<ObjectPool<IR::Block>>()};
+    ObjectPool<Flow::Block> flow_block_pool;
+    ObjectPool<IR::Inst> inst_pool;
+    ObjectPool<IR::Block> block_pool;
 
     // FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS8F146B41DB6BD826.bin"};
     FileEnvironment env{"D:\\Shaders\\shader.bin"};
-    block_pool->ReleaseContents();
-    inst_pool->ReleaseContents();
-    flow_block_pool->ReleaseContents();
-    Flow::CFG cfg{env, *flow_block_pool, 0};
+    block_pool.ReleaseContents();
+    inst_pool.ReleaseContents();
+    flow_block_pool.ReleaseContents();
+    Flow::CFG cfg{env, flow_block_pool, 0};
     fmt::print(stdout, "{}\n", cfg.Dot());
-    IR::Program program{TranslateProgram(*inst_pool, *block_pool, env, cfg)};
+    IR::Program program{TranslateProgram(inst_pool, block_pool, env, cfg)};
     fmt::print(stdout, "{}\n", IR::DumpProgram(program));
     Backend::SPIRV::EmitSPIRV spirv{program};
 }
diff --git a/src/shader_recompiler/object_pool.h b/src/shader_recompiler/object_pool.h
index a573add32..f78813b5f 100644
--- a/src/shader_recompiler/object_pool.h
+++ b/src/shader_recompiler/object_pool.h
@@ -10,19 +10,11 @@
 
 namespace Shader {
 
-template <typename T, size_t chunk_size = 8192>
+template <typename T>
 requires std::is_destructible_v<T> class ObjectPool {
 public:
-    ~ObjectPool() {
-        std::unique_ptr<Chunk> tree_owner;
-        Chunk* chunk{&root};
-        while (chunk) {
-            for (size_t obj_id = chunk->free_objects; obj_id < chunk_size; ++obj_id) {
-                chunk->storage[obj_id].object.~T();
-            }
-            tree_owner = std::move(chunk->next);
-            chunk = tree_owner.get();
-        }
+    explicit ObjectPool(size_t chunk_size = 8192) : new_chunk_size{chunk_size} {
+        node = &chunks.emplace_back(new_chunk_size);
     }
 
     template <typename... Args>
@@ -31,17 +23,21 @@ public:
     }
 
     void ReleaseContents() {
-        Chunk* chunk{&root};
-        while (chunk) {
-            if (chunk->free_objects == chunk_size) {
-                break;
-            }
-            for (; chunk->free_objects < chunk_size; ++chunk->free_objects) {
-                chunk->storage[chunk->free_objects].object.~T();
-            }
-            chunk = chunk->next.get();
+        if (chunks.empty()) {
+            return;
+        }
+        Chunk& root{chunks.front()};
+        if (root.used_objects == root.num_objects) {
+            // Root chunk has been filled, squash allocations into it
+            const size_t total_objects{root.num_objects + new_chunk_size * (chunks.size() - 1)};
+            chunks.clear(); // Invalidates root; it must not be touched past this point
+            chunks.emplace_back(total_objects);
+        } else {
+            root.Release();
+            chunks.resize(1); // Drops every chunk after the root
         }
-        node = &root;
+        chunks.shrink_to_fit(); // Shared tail of both branches; may reallocate the vector
+        node = &chunks.front(); // Re-seat: the old target may be destroyed or moved
     }
 
 private:
@@ -58,31 +54,51 @@ private:
     };
 
     struct Chunk {
-        size_t free_objects = chunk_size;
-        std::array<Storage, chunk_size> storage;
-        std::unique_ptr<Chunk> next;
+        explicit Chunk() = default; // Empty chunk: null storage, zero capacity
+        explicit Chunk(size_t size)
+            : num_objects{size}, storage{std::make_unique<Storage[]>(size)} {}
+
+        Chunk& operator=(Chunk&& rhs) noexcept {
+            Release(); // Destroy what this chunk currently holds before adopting rhs
+            used_objects = std::exchange(rhs.used_objects, 0);
+            num_objects = std::exchange(rhs.num_objects, 0);
+            storage = std::move(rhs.storage);
+            return *this; // Required: flowing off the end of this function is UB
+        }
+
+        Chunk(Chunk&& rhs) noexcept
+            : used_objects{std::exchange(rhs.used_objects, 0)},
+              num_objects{std::exchange(rhs.num_objects, 0)}, storage{std::move(rhs.storage)} {}
+
+        ~Chunk() {
+            Release();
+        }
+
+        void Release() {
+            std::destroy_n(storage.get(), std::exchange(used_objects, 0)); // Count reset first
+        }
+
+        size_t used_objects{}; // Number of leading storage slots currently in use
+        size_t num_objects{};  // Capacity of storage, in slots
+        std::unique_ptr<Storage[]> storage;
     };
 
     [[nodiscard]] T* Memory() {
         Chunk* const chunk{FreeChunk()};
-        return &chunk->storage[--chunk->free_objects].object;
+        return &chunk->storage[chunk->used_objects++].object;
     }
 
     [[nodiscard]] Chunk* FreeChunk() {
-        if (node->free_objects > 0) {
+        if (node->used_objects != node->num_objects) {
             return node;
         }
-        if (node->next) {
-            node = node->next.get();
-            return node;
-        }
-        node->next = std::make_unique<Chunk>();
-        node = node->next.get();
+        node = &chunks.emplace_back(new_chunk_size);
         return node;
     }
 
-    Chunk* node{&root};
-    Chunk root;
+    Chunk* node{};
+    std::vector<Chunk> chunks;
+    size_t new_chunk_size{};
 };
 
 } // namespace Shader