Reload of shaders working!

This commit is contained in:
2025-03-15 21:17:44 -04:00
parent c9cd01bb31
commit f4ba25bec1
8 changed files with 184 additions and 9 deletions

View File

@@ -20,8 +20,8 @@ using size_t = std::size_t;
struct ByteBuffer
{
Byte* Data;
size_t Size;
Byte* Data;
size_t Size;
};
using FunctionPtr = auto (*)(void) -> void;
@@ -37,7 +37,7 @@ constexpr uint16 uint16Max = MaxValueOf<uint16>();
constexpr uint32 uint32Max = MaxValueOf<uint32>();
constexpr uint64 uint64Max = MaxValueOf<uint64>();
constexpr int8 int8Max = MaxValueOf<int8>();
constexpr int16 int16Max = MaxValueOf<int16>();
constexpr int32 int32Max = MaxValueOf<int32>();
constexpr int64 int64Max = MaxValueOf<int64>();
constexpr int8 int8Max = MaxValueOf<int8>();
constexpr int16 int16Max = MaxValueOf<int16>();
constexpr int32 int32Max = MaxValueOf<int32>();
constexpr int64 int64Max = MaxValueOf<int64>();

View File

@@ -10,10 +10,13 @@
#include <Graphics/Shader.h>
#include <Juliet.h>
#ifdef JULIET_DEBUG
#define ALLOW_SHADER_HOT_RELOAD 1
#endif
// Graphics Interface
namespace Juliet
{
// Opaque types
struct CommandList;
struct GraphicsDevice;
@@ -129,5 +132,10 @@ namespace Juliet
extern JULIET_API GraphicsPipeline* CreateGraphicsPipeline(NonNullPtr<GraphicsDevice> device,
const GraphicsPipelineCreateInfo& createInfo);
extern JULIET_API void DestroyGraphicsPipeline(NonNullPtr<GraphicsDevice> device, NonNullPtr<GraphicsPipeline> graphicsPipeline);
#if ALLOW_SHADER_HOT_RELOAD
// Allows updating the graphics pipeline shaders. Can update either one or both shaders.
extern JULIET_API bool UpdateGraphicsPipelineShaders(NonNullPtr<GraphicsDevice> device, NonNullPtr<GraphicsPipeline> graphicsPipeline,
Shader* optional_vertexShader, Shader* optional_fragmentShader);
#endif
} // namespace Juliet

View File

@@ -755,6 +755,7 @@ namespace Juliet::D3D12
device->DestroyShader = DestroyShader;
device->CreateGraphicsPipeline = CreateGraphicsPipeline;
device->DestroyGraphicsPipeline = DestroyGraphicsPipeline;
device->UpdateGraphicsPipelineShaders = UpdateGraphicsPipelineShaders;
device->Driver = driver;
device->DebugEnabled = enableDebug;

View File

@@ -596,6 +596,25 @@ namespace Juliet::D3D12
}
Free(rootSignature.Get());
}
void CopyShader(NonNullPtr<D3D12Shader> destination, NonNullPtr<D3D12Shader> source)
{
D3D12Shader* src = source.Get();
D3D12Shader* dst = destination.Get();
ByteBuffer dstBuffer = dst->ByteCode;
if (src->ByteCode.Size != dstBuffer.Size)
{
dstBuffer.Data = static_cast<Byte*>(Realloc(dstBuffer.Data, src->ByteCode.Size));
dstBuffer.Size = src->ByteCode.Size;
}
// Copy the shader data. Infortunately this will overwrite the bytecode if it exists so we patch it back just after
MemCopy(dst, src, sizeof(D3D12Shader));
dst->ByteCode = dstBuffer;
MemCopy(dst->ByteCode.Data, src->ByteCode.Data, src->ByteCode.Size);
}
} // namespace
GraphicsPipeline* CreateGraphicsPipeline(NonNullPtr<GPUDriver> driver, const GraphicsPipelineCreateInfo& createInfo)
@@ -668,7 +687,7 @@ namespace Juliet::D3D12
return nullptr;
}
pipeline->RootSignature = rootSignature;
psoDesc.pRootSignature = rootSignature->Handle;
psoDesc.pRootSignature = rootSignature->Handle;
ID3D12PipelineState* pipelineState;
HRESULT res = ID3D12Device_CreateGraphicsPipelineState(d3d12Driver->D3D12Device, &psoDesc, IID_ID3D12PipelineState,
@@ -702,6 +721,16 @@ namespace Juliet::D3D12
pipeline->ReferenceCount = 0;
#if ALLOW_SHADER_HOT_RELOAD
// Save the PSODesc and shaders to be able to recreate the graphics pipeline when needed
pipeline->PSODescTemplate = psoDesc;
pipeline->VertexShaderCache = static_cast<D3D12Shader*>(Calloc(1, sizeof(D3D12Shader)));
pipeline->FragmentShaderCache = static_cast<D3D12Shader*>(Calloc(1, sizeof(D3D12Shader)));
CopyShader(pipeline->VertexShaderCache, vertexShader);
CopyShader(pipeline->FragmentShaderCache, fragmentShader);
#endif
return reinterpret_cast<GraphicsPipeline*>(pipeline);
}
@@ -721,6 +750,90 @@ namespace Juliet::D3D12
d3d12Driver->GraphicsPipelinesToDisposeCount += 1;
}
#if ALLOW_SHADER_HOT_RELOAD
bool UpdateGraphicsPipelineShaders(NonNullPtr<GPUDriver> driver, NonNullPtr<GraphicsPipeline> graphicsPipeline,
Shader* optional_vertexShader, Shader* optional_fragmentShader)
{
auto d3d12Driver = static_cast<D3D12Driver*>(driver.Get());
auto d3d12GraphicsPipeline = reinterpret_cast<D3D12GraphicsPipeline*>(graphicsPipeline.Get());
Assert(d3d12GraphicsPipeline->ReferenceCount == 0 &&
"Trying to update a d3d12 graphics pipeline that is currently being used! Call WaitUntilGPUIsIdle "
"before updating!");
auto vertexShader = reinterpret_cast<D3D12Shader*>(optional_vertexShader);
auto fragmentShader = reinterpret_cast<D3D12Shader*>(optional_fragmentShader);
if (!vertexShader)
{
vertexShader = d3d12GraphicsPipeline->VertexShaderCache;
}
if (!fragmentShader)
{
fragmentShader = d3d12GraphicsPipeline->FragmentShaderCache;
}
// Recreate a new root signature and pipeline state
D3D12GraphicsRootSignature* rootSignature = CreateGraphicsRootSignature(d3d12Driver, vertexShader, fragmentShader);
if (rootSignature == nullptr)
{
return false;
}
auto psoDesc = d3d12GraphicsPipeline->PSODescTemplate;
psoDesc.VS.pShaderBytecode = vertexShader->ByteCode.Data;
psoDesc.VS.BytecodeLength = vertexShader->ByteCode.Size;
psoDesc.PS.pShaderBytecode = fragmentShader->ByteCode.Data;
psoDesc.PS.BytecodeLength = fragmentShader->ByteCode.Size;
psoDesc.pRootSignature = rootSignature->Handle;
ID3D12PipelineState* pipelineState;
HRESULT res = ID3D12Device_CreateGraphicsPipelineState(d3d12Driver->D3D12Device, &psoDesc, IID_ID3D12PipelineState,
reinterpret_cast<void**>(&pipelineState));
if (FAILED(res))
{
LogError(d3d12Driver, "Could not create graphics pipeline state", res);
return false;
}
d3d12GraphicsPipeline->VertexSamplerCount = vertexShader->NumSamplers;
d3d12GraphicsPipeline->VertexStorageTextureCount = vertexShader->NumStorageTextures;
d3d12GraphicsPipeline->VertexStorageBufferCount = vertexShader->NumStorageBuffers;
d3d12GraphicsPipeline->VertexUniformBufferCount = vertexShader->NumUniformBuffers;
d3d12GraphicsPipeline->FragmentSamplerCount = fragmentShader->NumSamplers;
d3d12GraphicsPipeline->FragmentStorageTextureCount = fragmentShader->NumStorageTextures;
d3d12GraphicsPipeline->FragmentStorageBufferCount = fragmentShader->NumStorageBuffers;
d3d12GraphicsPipeline->FragmentUniformBufferCount = fragmentShader->NumUniformBuffers;
// If everything worked, we patch the graphics pipeline and destroy everything irrelevant
if (d3d12GraphicsPipeline->PipelineState)
{
ID3D12PipelineState_Release(d3d12GraphicsPipeline->PipelineState);
}
d3d12GraphicsPipeline->PipelineState = pipelineState;
if (d3d12GraphicsPipeline->RootSignature)
{
DestroyGraphicsRootSignature(d3d12GraphicsPipeline->RootSignature);
}
d3d12GraphicsPipeline->RootSignature = rootSignature;
if (vertexShader != d3d12GraphicsPipeline->VertexShaderCache)
{
CopyShader(d3d12GraphicsPipeline->VertexShaderCache, vertexShader);
}
if (fragmentShader != d3d12GraphicsPipeline->FragmentShaderCache)
{
CopyShader(d3d12GraphicsPipeline->FragmentShaderCache, fragmentShader);
}
return true;
}
#endif
namespace Internal
{
void ReleaseGraphicsPipeline(NonNullPtr<D3D12GraphicsPipeline> d3d12GraphicsPipeline)
@@ -730,6 +843,12 @@ namespace Juliet::D3D12
ID3D12PipelineState_Release(d3d12GraphicsPipeline->PipelineState);
}
DestroyGraphicsRootSignature(d3d12GraphicsPipeline->RootSignature);
#if ALLOW_SHADER_HOT_RELOAD
SafeFree(d3d12GraphicsPipeline->VertexShaderCache);
SafeFree(d3d12GraphicsPipeline->FragmentShaderCache);
#endif
Free(d3d12GraphicsPipeline.Get());
}
} // namespace Internal

View File

@@ -1,6 +1,7 @@
#pragma once
#include <Core/Common/NonNullPtr.h>
#include <D3D12Shader.h>
#include <Graphics/D3D12/D3D12Includes.h>
#include <Graphics/GraphicsDevice.h>
#include <Graphics/GraphicsPipeline.h>
@@ -35,6 +36,17 @@ namespace Juliet::D3D12
struct D3D12GraphicsPipeline
{
#if ALLOW_SHADER_HOT_RELOAD
// Template to recreate an ID3D12PipelineState when shader are hot reloaded
// Stripped out in shipping build as the struct is huge
D3D12_GRAPHICS_PIPELINE_STATE_DESC PSODescTemplate;
// Keeping shaders byte code to make it easier to recreate the ID3D12PipelineState
// Those will be freed when the pipeline is destroyed or updated
D3D12Shader * VertexShaderCache;
D3D12Shader * FragmentShaderCache;
#endif
ID3D12PipelineState* PipelineState;
D3D12GraphicsRootSignature* RootSignature;
PrimitiveType PrimitiveType;
@@ -56,7 +68,8 @@ namespace Juliet::D3D12
extern GraphicsPipeline* CreateGraphicsPipeline(NonNullPtr<GPUDriver> driver, const GraphicsPipelineCreateInfo& createInfo);
extern void DestroyGraphicsPipeline(NonNullPtr<GPUDriver> driver, NonNullPtr<GraphicsPipeline> graphicsPipeline);
extern bool UpdateGraphicsPipelineShaders(NonNullPtr<GPUDriver> driver, NonNullPtr<GraphicsPipeline> graphicsPipeline,
Shader* optional_vertexShader, Shader* optional_fragmentShader);
namespace Internal
{
extern void ReleaseGraphicsPipeline(NonNullPtr<D3D12GraphicsPipeline> d3d12GraphicsPipeline);

View File

@@ -310,4 +310,12 @@ namespace Juliet
{
device->DestroyGraphicsPipeline(device->Driver, graphicsPipeline);
}
#if ALLOW_SHADER_HOT_RELOAD
bool UpdateGraphicsPipelineShaders(NonNullPtr<GraphicsDevice> device, NonNullPtr<GraphicsPipeline> graphicsPipeline,
Shader* optional_vertexShader, Shader* optional_fragmentShader)
{
return device->UpdateGraphicsPipelineShaders(device->Driver, graphicsPipeline, optional_vertexShader, optional_fragmentShader);
}
#endif
} // namespace Juliet

View File

@@ -80,6 +80,8 @@ namespace Juliet
// Pipeline
GraphicsPipeline* (*CreateGraphicsPipeline)(NonNullPtr<GPUDriver> driver, const GraphicsPipelineCreateInfo& createInfo);
void (*DestroyGraphicsPipeline)(NonNullPtr<GPUDriver> driver, NonNullPtr<GraphicsPipeline> pipeline);
bool (*UpdateGraphicsPipelineShaders)(NonNullPtr<GPUDriver> driver, NonNullPtr<GraphicsPipeline> graphicsPipeline,
Shader* optional_vertexShader, Shader* optional_fragmentShader);
const char* Name = "Unknown";
GPUDriver* Driver = nullptr;

View File

@@ -202,6 +202,30 @@ void JulietApplication::Update()
{
// We need to wait for the gpu to be idle to recreate our graphics pipelines
WaitUntilGPUIsIdle(GraphicsDevice);
#if ALLOW_SHADER_HOT_RELOAD
String entryPoint = WrapString("main");
ShaderCreateInfo shaderCI = {};
shaderCI.EntryPoint = entryPoint;
String shaderPath = WrapString("../../../Assets/compiled/Triangle.vert.dxil");
shaderCI.Stage = ShaderStage::Vertex;
Shader* vertexShader = CreateShader(GraphicsDevice, shaderPath, shaderCI);
shaderPath = WrapString("../../../Assets/compiled/SolidColor.frag.dxil");
shaderCI.Stage = ShaderStage::Fragment;
Shader* fragmentShader = CreateShader(GraphicsDevice, shaderPath, shaderCI);
UpdateGraphicsPipelineShaders(GraphicsDevice, GraphicsPipeline, vertexShader, fragmentShader);
if (vertexShader)
{
DestroyShader(GraphicsDevice, vertexShader);
}
if (fragmentShader)
{
DestroyShader(GraphicsDevice, fragmentShader);
}
#endif
}
// Draw here for now