Skip to content

Commit

Permalink
Wait for renderer to be ready before hooking + startup improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
praydog committed Jun 5, 2024
1 parent 5c2eea8 commit b630984
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 18 deletions.
48 changes: 43 additions & 5 deletions shared/sdk/REContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,21 @@ namespace sdk {

sdk::VM* VM::get() {
update_pointers();

if (s_global_context == nullptr) {
return nullptr;
}

return *s_global_context;
}

REThreadContext* VM::get_thread_context(int32_t unk /*= -1*/) {
update_pointers();

if (s_get_thread_context == nullptr) {
return nullptr;
}

return s_get_thread_context(this, unk);
}

Expand Down Expand Up @@ -86,7 +95,7 @@ namespace sdk {
std::unique_lock lock{ s_mutex };

utility::ScopeGuard sg{ [&]() {
s_fully_updated_pointers = true;
s_fully_updated_pointers = s_global_context != nullptr && s_get_thread_context != nullptr;
}
};

Expand Down Expand Up @@ -148,11 +157,16 @@ namespace sdk {
return;
}

s_global_context = (decltype(s_global_context))utility::calculate_absolute(*ref + context_pattern->ctx_offset);
s_get_thread_context = (decltype(s_get_thread_context))utility::calculate_absolute(*ref + context_pattern->get_thread_context_offset);
const auto potential_context = (decltype(s_global_context))utility::calculate_absolute(*ref + context_pattern->ctx_offset);
bool found_tdb = false;

if (*potential_context == nullptr) {
spdlog::info("[VM::update_pointers] Context is null.");
return;
}

for (auto i = 0; i < 0x20000; i += sizeof(void*)) {
auto ptr = *(sdk::RETypeDB**)((uintptr_t)*s_global_context + i);
auto ptr = *(sdk::RETypeDB**)((uintptr_t)*potential_context + i);

if (ptr == nullptr || IsBadReadPtr(ptr, sizeof(void*)) || ((uintptr_t)ptr & (sizeof(void*) - 1)) != 0) {
continue;
Expand All @@ -164,6 +178,7 @@ namespace sdk {
s_tdb_version = version;
s_type_db_offset = i;
s_static_tbl_offset = s_type_db_offset - 0x30; // hope this holds true for the older gameS!!!!!!!!!!!!!!!!!!!
found_tdb = true;
spdlog::info("[VM::update_pointers] s_type_db_offset: {:x}", s_type_db_offset);
spdlog::info("[VM::update_pointers] s_static_tbl_offset: {:x}", s_static_tbl_offset);
spdlog::info("[VM::update_pointers] TDB Version: {}", version);
Expand All @@ -172,6 +187,14 @@ namespace sdk {
}
}

if (!found_tdb) {
spdlog::error("[VM::update_pointers] Unable to find TDB inside VM");
return;
}

s_global_context = potential_context;
s_get_thread_context = (decltype(s_get_thread_context))utility::calculate_absolute(*ref + context_pattern->get_thread_context_offset);

spdlog::info("[VM::update_pointers] s_global_context: {:x}", (uintptr_t)s_global_context);
spdlog::info("[VM::update_pointers] s_get_thread_context: {:x}", (uintptr_t)s_get_thread_context);

Expand All @@ -181,13 +204,14 @@ namespace sdk {
#if TDB_VER >= 71
if (s_global_context != nullptr && *s_global_context != nullptr) {
auto static_tbl = (REStaticTbl**)((uintptr_t)*s_global_context + s_static_tbl_offset);
bool found_static_tbl_offset = false;
if (IsBadReadPtr(*static_tbl, sizeof(void*)) || ((uintptr_t)*static_tbl & (sizeof(void*) - 1)) != 0) {
spdlog::info("[VM::update_pointers] Static table offset is bad, correcting...");

// We are looking for the two arrays, the static field table, and the static field "initialized table"
// The initialized table tells whether a specific entry in the static field table has been initialized or not
// so they both should have the same size, easy to find
for (auto i = sizeof(void*); i < 0x100; i+= sizeof(void*)) {
for (auto i = sizeof(void*); i < 0x100; i+= sizeof(void*)) try {
const auto& ptr = *(REStaticTbl**)((uintptr_t)*s_global_context + (s_type_db_offset - i));

if (IsBadReadPtr(ptr, sizeof(void*)) || ((uintptr_t)ptr & (sizeof(void*) - 1)) != 0) {
Expand All @@ -210,9 +234,23 @@ namespace sdk {
if (previous_count == potential_count) {
spdlog::info("[VM::update_pointers] Found static table at {:x} (offset {:x})", (uintptr_t)ptr, previous_offset);
s_static_tbl_offset = previous_offset;
found_static_tbl_offset = true;
break;
}
} catch (...) {
continue;
}
} else {
found_static_tbl_offset = true;
}

// Just make it return null if we can't find it
// We do this so the consumer can do while(sdk::VM::get() != nullptr) { ... } to wait for everything to be valid
if (!found_static_tbl_offset) {
spdlog::error("[VM::update_pointers] Unable to find static table offset.");
s_global_context = nullptr;
s_get_thread_context = nullptr;
return;
}
}
#endif
Expand Down
11 changes: 11 additions & 0 deletions shared/sdk/Renderer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,17 @@ void RenderContext::copy_texture(Texture* dest, Texture* src, Fence& fence) {
func(this, dest, src, fence);
}

std::optional<uint32_t> Renderer::get_render_frame() const {
static auto tdef = sdk::find_type_definition("via.render.Renderer");
static auto m = tdef != nullptr ? tdef->get_method("get_RenderFrame") : nullptr;

if (m == nullptr) {
return std::nullopt;
}

return m->call<uint32_t>(sdk::get_thread_context(), this);
}

ConstantBuffer* Renderer::get_constant_buffer(std::string_view name) const {
static auto tdef = sdk::find_type_definition("via.render.Renderer");
static auto t = tdef->get_type();
Expand Down
6 changes: 6 additions & 0 deletions shared/sdk/Renderer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,12 @@ class RenderContext {

class Renderer {
public:
void* get_device() const {
return *(void**)((uintptr_t)this + sizeof(void*)); // simple!
}

std::optional<uint32_t> get_render_frame() const;

ConstantBuffer* get_constant_buffer(std::string_view name) const;

ConstantBuffer* get_scene_info() const {
Expand Down
33 changes: 30 additions & 3 deletions src/D3D12Hook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ void* D3D12Hook::Streamline::link_swapchain_to_cmd_queue(void* rcx, void* rdx, v
return hook->get_original<decltype(link_swapchain_to_cmd_queue)>()(rcx, rdx, r8, r9);
}

while (g_framework == nullptr) {
std::this_thread::yield();
}

std::scoped_lock _{g_framework->get_hook_monitor_mutex()};

spdlog::info("[Streamline] linkSwapchainToCmdQueue: {:x}", (uintptr_t)_ReturnAddress());
Expand Down Expand Up @@ -67,6 +71,10 @@ HRESULT WINAPI D3D12Hook::create_swapchain(IDXGIFactory4* factory, IUnknown* dev

spdlog::info("create_swapchain called");

while (g_framework == nullptr) {
std::this_thread::yield();
}

std::scoped_lock _{g_framework->get_hook_monitor_mutex()};

bool hook_was_nullptr = g_d3d12_hook == nullptr;
Expand Down Expand Up @@ -354,10 +362,13 @@ bool D3D12Hook::hook() {
return false;
}

const auto ti = utility::rtti::get_type_info(swap_chain1);

try {
const auto swapchain_classname = ti != nullptr ? std::string_view{ti->name()} : "unknown";
const auto ti = utility::rtti::get_type_info(swap_chain1);
const auto swapchain_classname = ti != nullptr && ti->name() != nullptr ? std::string_view{ti->name()} : "unknown";
const auto raw_name = ti != nullptr && ti->raw_name() != nullptr ? std::string_view{ti->raw_name()} : "unknown";

spdlog::info("Swapchain type info: {}", swapchain_classname);
spdlog::info("Swapchain raw type info: {}", raw_name);

if (swapchain_classname.contains("interposer::DXGISwapChain")) { // DLSS3
spdlog::info("Found Streamline (DLSSFG) swapchain during dummy initialization: {:x}", (uintptr_t)swap_chain1);
Expand Down Expand Up @@ -505,6 +516,10 @@ bool D3D12Hook::hook() {
}

bool D3D12Hook::unhook() {
while (g_framework == nullptr) {
std::this_thread::yield();
}

std::scoped_lock _{g_framework->get_hook_monitor_mutex()};

if (!m_hooked) {
Expand All @@ -525,6 +540,10 @@ bool D3D12Hook::unhook() {
thread_local int32_t g_present_depth = 0;

HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, uint64_t sync_interval, uint64_t flags, void* r9) {
while (g_framework == nullptr) {
std::this_thread::yield();
}

std::scoped_lock _{g_framework->get_hook_monitor_mutex()};

auto d3d12 = g_d3d12_hook;
Expand Down Expand Up @@ -657,6 +676,10 @@ HRESULT WINAPI D3D12Hook::present(IDXGISwapChain3* swap_chain, uint64_t sync_int
thread_local int32_t g_resize_buffers_depth = 0;

HRESULT WINAPI D3D12Hook::resize_buffers(IDXGISwapChain3* swap_chain, UINT buffer_count, UINT width, UINT height, DXGI_FORMAT new_format, UINT swap_chain_flags) {
while (g_framework == nullptr) {
std::this_thread::yield();
}

std::scoped_lock _{g_framework->get_hook_monitor_mutex()};

spdlog::info("D3D12 resize buffers called");
Expand Down Expand Up @@ -742,6 +765,10 @@ HRESULT WINAPI D3D12Hook::resize_buffers(IDXGISwapChain3* swap_chain, UINT buffe
thread_local int32_t g_resize_target_depth = 0;

HRESULT WINAPI D3D12Hook::resize_target(IDXGISwapChain3* swap_chain, const DXGI_MODE_DESC* new_target_parameters) {
while (g_framework == nullptr) {
std::this_thread::yield();
}

std::scoped_lock _{g_framework->get_hook_monitor_mutex()};

spdlog::info("D3D12 resize target called");
Expand Down
135 changes: 125 additions & 10 deletions src/REFramework.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,11 @@ try {
if (NotificationReason == LDR_DLL_NOTIFICATION_REASON_LOADED) {
if (NotificationData->Loaded.BaseDllName != nullptr && NotificationData->Loaded.BaseDllName->Buffer != nullptr) {
std::wstring base_dll_name = NotificationData->Loaded.BaseDllName->Buffer;
std::wstring lower_base_dll_name = base_dll_name;
std::transform(lower_base_dll_name.begin(), lower_base_dll_name.end(), lower_base_dll_name.begin(), ::towlower);
spdlog::info("LdrRegisterDllNotification: Loaded: {}", utility::narrow(base_dll_name));

if (base_dll_name.find(L"sl.dlss_g.dll") != std::wstring::npos) {
if (lower_base_dll_name.find(L"sl.dlss_g.dll") != std::wstring::npos) {
spdlog::info("LdrRegisterDllNotification: Detected DLSS DLL loaded");

D3D12Hook::hook_streamline((HMODULE)NotificationData->Loaded.DllBase);
Expand Down Expand Up @@ -508,21 +510,134 @@ REFramework::REFramework(HMODULE reframework_module)
suspender.resume();
#endif

// Hooking D3D12 initially because we need to retrieve the command queue before the first frame then switch to D3D11 if it failed later
// on
// addendum: now we don't need to do that, we just grab the command queue offset from the swapchain we create
/*if (!hook_d3d12()) {
spdlog::error("Failed to hook D3D12 for initial test.");
}*/
// Load the plugins early right after executable unpacking
PluginLoader::get()->early_init();

// Wait for TDB and render device to be initialized before allowing D3D hooking
const auto start_time = std::chrono::high_resolution_clock::now();

while (true) {
try {
if (sdk::VM::get() != nullptr) {
break;
}
} catch(...) {
}

if (std::chrono::high_resolution_clock::now() - start_time > std::chrono::seconds(30)) {
spdlog::error("Timed out waiting for VM to initialize.");
throw std::runtime_error("Timed out waiting for VM to initialize.");
}

//std::this_thread::sleep_for(std::chrono::milliseconds(100));
std::this_thread::yield();
}

spdlog::info("VM initialized, waiting for renderer to initialize...");
sdk::RETypeDefinition* renderer_t = nullptr;
sdk::renderer::Renderer* renderer = nullptr;
bool found_renderer = false;

while (true) try {
const auto tdb = sdk::RETypeDB::get();

if (tdb == nullptr) {
spdlog::error("TypeDB not found");
break;
}

// We have to manually look through the types because get_FullName
// will crash if we call it this early, which is used in get_type(name)
if (renderer_t == nullptr) {
for (auto i = 0; i < tdb->get_num_types(); ++i) {
const auto t = tdb->get_type(i);

if (t == nullptr || t->get_name() == nullptr || t->get_namespace() == nullptr) {
continue;
}

if (std::string_view{t->get_name()} == "Renderer" && std::string_view{t->get_namespace()} == "via.render") {
spdlog::info("Renderer type found manually");
renderer_t = t;
break;
}
}
}

if (renderer_t == nullptr) {
spdlog::error("Renderer type not found");
break;
}

const auto renderer_has_instance = renderer_t->get_method("hasInstance");

if (renderer_has_instance == nullptr) {
spdlog::error("Renderer::hasInstance not found");
break;
}

if (renderer_has_instance->get_function() == nullptr) {
continue;
}

const auto has_instance = renderer_has_instance->call<bool>(nullptr, nullptr); // static

if (!has_instance) {
std::this_thread::yield();
continue;
}

renderer = sdk::renderer::get_renderer();

if (renderer != nullptr) {
found_renderer = true;
break;
}

spdlog::info("waiting for renderer");
std::this_thread::sleep_for(std::chrono::milliseconds(100));
} catch(...) {
spdlog::warn("Exception occurred while waiting for renderer");
continue;
}

spdlog::info("Found renderer, waiting for first frame...");

bool valid_render_frame = false;

while (found_renderer) try {
const auto render_frame = renderer->get_render_frame();

if (!render_frame.has_value()) {
spdlog::warn("Render frame property not found");
break;
}

if (*render_frame > 0) {
spdlog::info("Render frame: {}", *render_frame);
valid_render_frame = true;
break;
}

std::this_thread::yield();
} catch(...) {
spdlog::warn("Exception occurred while waiting for render frame");
break;
}

// If all is good, we can immediately hook D3D12 very early
// else, defer to the hook monitor if anything in the chain failed
if (valid_render_frame) {
// We can guaranteed hook at this point
std::scoped_lock _{m_hook_monitor_mutex};
hook_d3d12();
}

std::scoped_lock _{m_hook_monitor_mutex};

m_last_present_time = std::chrono::steady_clock::now();
m_last_message_time = std::chrono::steady_clock::now();
m_d3d_monitor_thread = std::make_unique<std::jthread>([this](std::stop_token stop_token) {
// Load the plugins early right after executable unpacking
PluginLoader::get()->early_init();

while (!stop_token.stop_requested() && !m_terminating) {
this->hook_monitor();
std::this_thread::sleep_for(std::chrono::milliseconds(500));
Expand Down

0 comments on commit b630984

Please sign in to comment.