-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSplats.cpp
More file actions
107 lines (84 loc) · 4.51 KB
/
Splats.cpp
File metadata and controls
107 lines (84 loc) · 4.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#include "Splats.h"
#include "DX12Lib/Helpers.h"
// Copies the CPU-side splat records into an upload buffer labelled "SplatBuffer"
// (stored in splatBuffer.resource), then attaches a debug name to that resource.
void Splats::CreateRenderMesh(Renderer& renderer)
{
    // Total payload: every splat record, tightly packed.
    const auto splatBytes = splatData.size() * sizeof(splatData[0]);
    auto* const d3dDevice = renderer.device.Get();

    AllocateUploadBuffer(d3dDevice, splatData.data(), splatBytes, &splatBuffer.resource, L"SplatBuffer");

    // Keep the macro argument spelled exactly like this: the macro may stringize it for the debug name.
    NAME_D3D12_OBJECT(splatBuffer.resource);
}
void Splats::Reset()
{
splatBuffer.resource.Reset();
bottomLevelAccelerationStructure.Reset();
}
void Splats::BuildAccelerationStructures(Renderer& renderer)
{
auto device = renderer.device;
std::vector<D3D12_RAYTRACING_AABB> aabs;
aabs.reserve(splatData.size());
for (const auto& el: splatData)
{
D3D12_RAYTRACING_AABB aabb;
aabb.MinX = el.position.x - el.radius;
aabb.MinY = el.position.y - el.radius;
aabb.MinZ = el.position.z - el.radius;
aabb.MaxX = el.position.x + el.radius;
aabb.MaxY = el.position.y + el.radius;
aabb.MaxZ = el.position.z + el.radius;
aabs.push_back(aabb);
}
AllocateUploadBuffer(device.Get(), aabs.data(), aabs.size() * sizeof(aabs[0]), &aabbBuffer.resource, L"SplatAABBs");
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS buildFlags = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_TRACE;
auto commandQueue = renderer.directCommandQueue;
assert(commandQueue);
auto commandList = commandQueue->GetCommandList();
// PERFORMANCE TIP: mark geometry as opaque whenever applicable as it can enable important ray processing optimizations.
D3D12_RAYTRACING_GEOMETRY_FLAGS geometryFlags = D3D12_RAYTRACING_GEOMETRY_FLAG_OPAQUE;// D3D12_RAYTRACING_GEOMETRY_FLAG_NONE;
// geometryFlags |= D3D12_RAYTRACING_GEOMETRY_FLAG_NO_DUPLICATE_ANYHIT_INVOCATION;
D3D12_RAYTRACING_GEOMETRY_DESC geometryDesc = {};
geometryDesc.Type = D3D12_RAYTRACING_GEOMETRY_TYPE_PROCEDURAL_PRIMITIVE_AABBS;
geometryDesc.AABBs.AABBCount = (UINT64)splatData.size();
geometryDesc.AABBs.AABBs.StrideInBytes = sizeof(D3D12_RAYTRACING_AABB);
geometryDesc.AABBs.AABBs.StartAddress = aabbBuffer.resource->GetGPUVirtualAddress();
geometryDesc.Flags = geometryFlags;
D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO bottomLevelPrebuildInfo = {};
D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS bottomLevelInputs = {};
bottomLevelInputs.DescsLayout = D3D12_ELEMENTS_LAYOUT_ARRAY;
bottomLevelInputs.Flags = buildFlags;
bottomLevelInputs.NumDescs = 1;
bottomLevelInputs.Type = D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL;
bottomLevelInputs.pGeometryDescs = &geometryDesc;
renderer.dxrDevice->GetRaytracingAccelerationStructurePrebuildInfo(&bottomLevelInputs, &bottomLevelPrebuildInfo);
ThrowIfFalse(bottomLevelPrebuildInfo.ResultDataMaxSizeInBytes > 0);
ComPtr<ID3D12Resource> scratchResource;
AllocateUAVBuffer(device.Get(), commandList.Get(), bottomLevelPrebuildInfo.ScratchDataSizeInBytes, &scratchResource, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, L"ScratchResource");
{
D3D12_RESOURCE_STATES initialResourceState = D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE;
AllocateUAVBuffer(device.Get(), commandList.Get(), bottomLevelPrebuildInfo.ResultDataMaxSizeInBytes, &bottomLevelAccelerationStructure, initialResourceState, L"BLAS");
}
// Bottom Level Acceleration Structure desc
D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC bottomLevelBuildDesc = {};
{
bottomLevelBuildDesc.Inputs = bottomLevelInputs;
bottomLevelBuildDesc.ScratchAccelerationStructureData = scratchResource->GetGPUVirtualAddress();
bottomLevelBuildDesc.DestAccelerationStructureData = bottomLevelAccelerationStructure->GetGPUVirtualAddress();
}
// Build acceleration structure.
// hack
ComPtr<ID3D12GraphicsCommandList4> m_dxrCommandList;
ThrowIfFailed(commandList->QueryInterface(IID_PPV_ARGS(&m_dxrCommandList)), L"Couldn't get DirectX Raytracing interface for the command list.\n");
// BuildAccelerationStructure
{
auto* raytracingCommandList = m_dxrCommandList.Get();
raytracingCommandList->BuildRaytracingAccelerationStructure(&bottomLevelBuildDesc, 0, nullptr);
CD3DX12_RESOURCE_BARRIER a = CD3DX12_RESOURCE_BARRIER::UAV(bottomLevelAccelerationStructure.Get());
commandList->ResourceBarrier(1, &a);
}
// Kick off acceleration structure construction.
auto fenceValue = commandQueue->ExecuteCommandList(commandList);
commandQueue->WaitForFenceValue(fenceValue);
}
// Creates a buffer SRV over splatBuffer via renderer.CreateBufferSRV.
// - Element count: the number of splats.
// - Stride: the size of one splat record.
// Returns the UINT produced by CreateBufferSRV (presumably a descriptor index; confirm
// against Renderer).
UINT Splats::CreateSRVs(Renderer& renderer)
{
    const auto elementCount = static_cast<UINT>(splatData.size());
    const auto elementStride = sizeof(splatData[0]);
    return renderer.CreateBufferSRV(&splatBuffer, elementCount, elementStride);
}