From f4ba25bec13cf1053f4bcef951eafafee942a359 Mon Sep 17 00:00:00 2001 From: Patedam Date: Sat, 15 Mar 2025 21:17:44 -0400 Subject: [PATCH] Reload of shaders working! --- Juliet/include/Core/Common/CoreTypes.h | 12 +- Juliet/include/Graphics/Graphics.h | 10 +- .../Graphics/D3D12/D3D12GraphicsDevice.cpp | 1 + .../Graphics/D3D12/D3D12GraphicsPipeline.cpp | 121 +++++++++++++++++- .../Graphics/D3D12/D3D12GraphicsPipeline.h | 15 ++- Juliet/src/Graphics/Graphics.cpp | 8 ++ Juliet/src/Graphics/GraphicsDevice.h | 2 + JulietApp/main.cpp | 24 ++++ 8 files changed, 184 insertions(+), 9 deletions(-) diff --git a/Juliet/include/Core/Common/CoreTypes.h b/Juliet/include/Core/Common/CoreTypes.h index 1cf3cd0..8e60801 100644 --- a/Juliet/include/Core/Common/CoreTypes.h +++ b/Juliet/include/Core/Common/CoreTypes.h @@ -20,8 +20,8 @@ using size_t = std::size_t; struct ByteBuffer { - Byte* Data; - size_t Size; + Byte* Data; + size_t Size; }; using FunctionPtr = auto (*)(void) -> void; @@ -37,7 +37,7 @@ constexpr uint16 uint16Max = MaxValueOf(); constexpr uint32 uint32Max = MaxValueOf(); constexpr uint64 uint64Max = MaxValueOf(); -constexpr int8 int8Max = MaxValueOf(); -constexpr int16 int16Max = MaxValueOf(); -constexpr int32 int32Max = MaxValueOf(); -constexpr int64 int64Max = MaxValueOf(); +constexpr int8 int8Max = MaxValueOf(); +constexpr int16 int16Max = MaxValueOf(); +constexpr int32 int32Max = MaxValueOf(); +constexpr int64 int64Max = MaxValueOf(); diff --git a/Juliet/include/Graphics/Graphics.h b/Juliet/include/Graphics/Graphics.h index be3a915..24b17a8 100644 --- a/Juliet/include/Graphics/Graphics.h +++ b/Juliet/include/Graphics/Graphics.h @@ -10,10 +10,13 @@ #include #include +#ifdef JULIET_DEBUG +#define ALLOW_SHADER_HOT_RELOAD 1 +#endif + // Graphics Interface namespace Juliet { - // Opaque types struct CommandList; struct GraphicsDevice; @@ -129,5 +132,10 @@ namespace Juliet extern JULIET_API GraphicsPipeline* CreateGraphicsPipeline(NonNullPtr device, const GraphicsPipelineCreateInfo& createInfo); extern JULIET_API void DestroyGraphicsPipeline(NonNullPtr device, NonNullPtr graphicsPipeline); +#if ALLOW_SHADER_HOT_RELOAD + // Allows updating the graphics pipeline shaders. Can update either one or both shaders. + extern JULIET_API bool UpdateGraphicsPipelineShaders(NonNullPtr device, NonNullPtr graphicsPipeline, + Shader* optional_vertexShader, Shader* optional_fragmentShader); +#endif } // namespace Juliet diff --git a/Juliet/src/Graphics/D3D12/D3D12GraphicsDevice.cpp b/Juliet/src/Graphics/D3D12/D3D12GraphicsDevice.cpp index 14e27e0..6af389e 100644 --- a/Juliet/src/Graphics/D3D12/D3D12GraphicsDevice.cpp +++ b/Juliet/src/Graphics/D3D12/D3D12GraphicsDevice.cpp @@ -755,6 +755,7 @@ namespace Juliet::D3D12 device->DestroyShader = DestroyShader; device->CreateGraphicsPipeline = CreateGraphicsPipeline; device->DestroyGraphicsPipeline = DestroyGraphicsPipeline; + device->UpdateGraphicsPipelineShaders = UpdateGraphicsPipelineShaders; device->Driver = driver; device->DebugEnabled = enableDebug; diff --git a/Juliet/src/Graphics/D3D12/D3D12GraphicsPipeline.cpp b/Juliet/src/Graphics/D3D12/D3D12GraphicsPipeline.cpp index 7753abb..4ea4351 100644 --- a/Juliet/src/Graphics/D3D12/D3D12GraphicsPipeline.cpp +++ b/Juliet/src/Graphics/D3D12/D3D12GraphicsPipeline.cpp @@ -596,6 +596,25 @@ namespace Juliet::D3D12 } Free(rootSignature.Get()); } + + void CopyShader(NonNullPtr destination, NonNullPtr source) + { + D3D12Shader* src = source.Get(); + D3D12Shader* dst = destination.Get(); + + ByteBuffer dstBuffer = dst->ByteCode; + + if (src->ByteCode.Size != dstBuffer.Size) + { + dstBuffer.Data = static_cast(Realloc(dstBuffer.Data, src->ByteCode.Size)); + dstBuffer.Size = src->ByteCode.Size; + } + // Copy the shader data. Infortunately this will overwrite the bytecode if it exists so we patch it back just after + MemCopy(dst, src, sizeof(D3D12Shader)); + dst->ByteCode = dstBuffer; + + MemCopy(dst->ByteCode.Data, src->ByteCode.Data, src->ByteCode.Size); + } } // namespace GraphicsPipeline* CreateGraphicsPipeline(NonNullPtr driver, const GraphicsPipelineCreateInfo& createInfo) @@ -668,7 +687,7 @@ namespace Juliet::D3D12 return nullptr; } pipeline->RootSignature = rootSignature; - psoDesc.pRootSignature = rootSignature->Handle; + psoDesc.pRootSignature = rootSignature->Handle; ID3D12PipelineState* pipelineState; HRESULT res = ID3D12Device_CreateGraphicsPipelineState(d3d12Driver->D3D12Device, &psoDesc, IID_ID3D12PipelineState, @@ -702,6 +721,16 @@ namespace Juliet::D3D12 pipeline->ReferenceCount = 0; +#if ALLOW_SHADER_HOT_RELOAD + // Save the PSODesc and shaders to be able to recreate the graphics pipeline when needed + pipeline->PSODescTemplate = psoDesc; + + pipeline->VertexShaderCache = static_cast(Calloc(1, sizeof(D3D12Shader))); + pipeline->FragmentShaderCache = static_cast(Calloc(1, sizeof(D3D12Shader))); + CopyShader(pipeline->VertexShaderCache, vertexShader); + CopyShader(pipeline->FragmentShaderCache, fragmentShader); +#endif + return reinterpret_cast(pipeline); } @@ -721,6 +750,90 @@ namespace Juliet::D3D12 d3d12Driver->GraphicsPipelinesToDisposeCount += 1; } +#if ALLOW_SHADER_HOT_RELOAD + bool UpdateGraphicsPipelineShaders(NonNullPtr driver, NonNullPtr graphicsPipeline, + Shader* optional_vertexShader, Shader* optional_fragmentShader) + { + auto d3d12Driver = static_cast(driver.Get()); + auto d3d12GraphicsPipeline = reinterpret_cast(graphicsPipeline.Get()); + + Assert(d3d12GraphicsPipeline->ReferenceCount == 0 && + "Trying to update a d3d12 graphics pipeline that is currently being used! Call WaitUntilGPUIsIdle " + "before updating!"); + + auto vertexShader = reinterpret_cast(optional_vertexShader); + auto fragmentShader = reinterpret_cast(optional_fragmentShader); + + if (!vertexShader) + { + vertexShader = d3d12GraphicsPipeline->VertexShaderCache; + } + + if (!fragmentShader) + { + fragmentShader = d3d12GraphicsPipeline->FragmentShaderCache; + } + + // Recreate a new root signature and pipeline state + D3D12GraphicsRootSignature* rootSignature = CreateGraphicsRootSignature(d3d12Driver, vertexShader, fragmentShader); + if (rootSignature == nullptr) + { + return false; + } + + auto psoDesc = d3d12GraphicsPipeline->PSODescTemplate; + psoDesc.VS.pShaderBytecode = vertexShader->ByteCode.Data; + psoDesc.VS.BytecodeLength = vertexShader->ByteCode.Size; + psoDesc.PS.pShaderBytecode = fragmentShader->ByteCode.Data; + psoDesc.PS.BytecodeLength = fragmentShader->ByteCode.Size; + + psoDesc.pRootSignature = rootSignature->Handle; + + ID3D12PipelineState* pipelineState; + HRESULT res = ID3D12Device_CreateGraphicsPipelineState(d3d12Driver->D3D12Device, &psoDesc, IID_ID3D12PipelineState, + reinterpret_cast(&pipelineState)); + if (FAILED(res)) + { + LogError(d3d12Driver, "Could not create graphics pipeline state", res); + return false; + } + + d3d12GraphicsPipeline->VertexSamplerCount = vertexShader->NumSamplers; + d3d12GraphicsPipeline->VertexStorageTextureCount = vertexShader->NumStorageTextures; + d3d12GraphicsPipeline->VertexStorageBufferCount = vertexShader->NumStorageBuffers; + d3d12GraphicsPipeline->VertexUniformBufferCount = vertexShader->NumUniformBuffers; + + d3d12GraphicsPipeline->FragmentSamplerCount = fragmentShader->NumSamplers; + d3d12GraphicsPipeline->FragmentStorageTextureCount = fragmentShader->NumStorageTextures; + d3d12GraphicsPipeline->FragmentStorageBufferCount = fragmentShader->NumStorageBuffers; + d3d12GraphicsPipeline->FragmentUniformBufferCount = fragmentShader->NumUniformBuffers; + + // If everything worked, we patch the graphics pipeline and destroy everything irrelevant + if (d3d12GraphicsPipeline->PipelineState) + { + ID3D12PipelineState_Release(d3d12GraphicsPipeline->PipelineState); + } + d3d12GraphicsPipeline->PipelineState = pipelineState; + + if (d3d12GraphicsPipeline->RootSignature) + { + DestroyGraphicsRootSignature(d3d12GraphicsPipeline->RootSignature); + } + d3d12GraphicsPipeline->RootSignature = rootSignature; + + if (vertexShader != d3d12GraphicsPipeline->VertexShaderCache) + { + CopyShader(d3d12GraphicsPipeline->VertexShaderCache, vertexShader); + } + if (fragmentShader != d3d12GraphicsPipeline->FragmentShaderCache) + { + CopyShader(d3d12GraphicsPipeline->FragmentShaderCache, fragmentShader); + } + + return true; + } +#endif + namespace Internal { void ReleaseGraphicsPipeline(NonNullPtr d3d12GraphicsPipeline) @@ -730,6 +843,12 @@ namespace Juliet::D3D12 ID3D12PipelineState_Release(d3d12GraphicsPipeline->PipelineState); } DestroyGraphicsRootSignature(d3d12GraphicsPipeline->RootSignature); + +#if ALLOW_SHADER_HOT_RELOAD + SafeFree(d3d12GraphicsPipeline->VertexShaderCache); + SafeFree(d3d12GraphicsPipeline->FragmentShaderCache); +#endif + Free(d3d12GraphicsPipeline.Get()); } } // namespace Internal diff --git a/Juliet/src/Graphics/D3D12/D3D12GraphicsPipeline.h b/Juliet/src/Graphics/D3D12/D3D12GraphicsPipeline.h index 4100ed1..b9749ab 100644 --- a/Juliet/src/Graphics/D3D12/D3D12GraphicsPipeline.h +++ b/Juliet/src/Graphics/D3D12/D3D12GraphicsPipeline.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -35,6 +36,17 @@ namespace Juliet::D3D12 struct D3D12GraphicsPipeline { +#if ALLOW_SHADER_HOT_RELOAD + // Template to recreate an ID3D12PipelineState when shader are hot reloaded + // Stripped out in shipping build as the struct is huge + D3D12_GRAPHICS_PIPELINE_STATE_DESC PSODescTemplate; + + // Keeping shaders byte code to make it easier to recreate the ID3D12PipelineState + // Those will be freed when the pipeline is destroyed or updated + D3D12Shader * VertexShaderCache; + D3D12Shader * FragmentShaderCache; +#endif + ID3D12PipelineState* PipelineState; D3D12GraphicsRootSignature* RootSignature; PrimitiveType PrimitiveType; @@ -56,7 +68,8 @@ namespace Juliet::D3D12 extern GraphicsPipeline* CreateGraphicsPipeline(NonNullPtr driver, const GraphicsPipelineCreateInfo& createInfo); extern void DestroyGraphicsPipeline(NonNullPtr driver, NonNullPtr graphicsPipeline); - + extern bool UpdateGraphicsPipelineShaders(NonNullPtr driver, NonNullPtr graphicsPipeline, + Shader* optional_vertexShader, Shader* optional_fragmentShader); namespace Internal { extern void ReleaseGraphicsPipeline(NonNullPtr d3d12GraphicsPipeline); diff --git a/Juliet/src/Graphics/Graphics.cpp b/Juliet/src/Graphics/Graphics.cpp index 0bcb8d6..e1b23ab 100644 --- a/Juliet/src/Graphics/Graphics.cpp +++ b/Juliet/src/Graphics/Graphics.cpp @@ -310,4 +310,12 @@ namespace Juliet { device->DestroyGraphicsPipeline(device->Driver, graphicsPipeline); } + +#if ALLOW_SHADER_HOT_RELOAD + bool UpdateGraphicsPipelineShaders(NonNullPtr device, NonNullPtr graphicsPipeline, + Shader* optional_vertexShader, Shader* optional_fragmentShader) + { + return device->UpdateGraphicsPipelineShaders(device->Driver, graphicsPipeline, optional_vertexShader, optional_fragmentShader); + } +#endif } // namespace Juliet diff --git a/Juliet/src/Graphics/GraphicsDevice.h b/Juliet/src/Graphics/GraphicsDevice.h index ea51b31..1ea27d4 100644 --- a/Juliet/src/Graphics/GraphicsDevice.h +++ b/Juliet/src/Graphics/GraphicsDevice.h @@ -80,6 +80,8 @@ namespace Juliet // Pipeline GraphicsPipeline* (*CreateGraphicsPipeline)(NonNullPtr driver, const GraphicsPipelineCreateInfo& createInfo); void (*DestroyGraphicsPipeline)(NonNullPtr driver, NonNullPtr pipeline); + bool (*UpdateGraphicsPipelineShaders)(NonNullPtr driver, NonNullPtr graphicsPipeline, + Shader* optional_vertexShader, Shader* optional_fragmentShader); const char* Name = "Unknown"; GPUDriver* Driver = nullptr; diff --git a/JulietApp/main.cpp b/JulietApp/main.cpp index 67b67a8..84c6d39 100644 --- a/JulietApp/main.cpp +++ b/JulietApp/main.cpp @@ -202,6 +202,30 @@ void JulietApplication::Update() { // We need to wait for the gpu to be idle to recreate our graphics pipelines WaitUntilGPUIsIdle(GraphicsDevice); + +#if ALLOW_SHADER_HOT_RELOAD + String entryPoint = WrapString("main"); + ShaderCreateInfo shaderCI = {}; + shaderCI.EntryPoint = entryPoint; + String shaderPath = WrapString("../../../Assets/compiled/Triangle.vert.dxil"); + shaderCI.Stage = ShaderStage::Vertex; + Shader* vertexShader = CreateShader(GraphicsDevice, shaderPath, shaderCI); + + shaderPath = WrapString("../../../Assets/compiled/SolidColor.frag.dxil"); + shaderCI.Stage = ShaderStage::Fragment; + Shader* fragmentShader = CreateShader(GraphicsDevice, shaderPath, shaderCI); + + UpdateGraphicsPipelineShaders(GraphicsDevice, GraphicsPipeline, vertexShader, fragmentShader); + + if (vertexShader) + { + DestroyShader(GraphicsDevice, vertexShader); + } + if (fragmentShader) + { + DestroyShader(GraphicsDevice, fragmentShader); + } +#endif } // Draw here for now