Refactor!

This commit is contained in:
TheBrokenRail 2024-07-17 03:47:32 -04:00
parent 6f792dfb16
commit b0814f257a
9 changed files with 217 additions and 158 deletions

40
data/function.cpp Normal file
View File

@ -0,0 +1,40 @@
#include "function.h"
// Virtual Function Information
template <typename T>
__VirtualFunctionInfo<T>::__VirtualFunctionInfo(T *const addr_, void *const parent_):
addr(addr_),
parent(parent_) {}
template <typename T>
bool __VirtualFunctionInfo<T>::can_overwrite() const {
return ((void *) *addr) != parent;
}
// Function Information
template <typename Ret, typename... Args>
__Function<Ret(Args...)>::__Function(const char *const name_, const __Function<Ret(Args...)>::ptr_type func_, const __Function<Ret(Args...)>::ptr_type thunk_):
is_virtual(false),
func(func_),
enabled(true),
name(name_),
backup(func_),
thunk(thunk_) {}
template <typename Ret, typename... Args>
__Function<Ret(Args...)>::__Function(const char *const name_, __Function<Ret(Args...)>::ptr_type *const func_, void *const parent, const __Function<Ret(Args...)>::ptr_type thunk_):
is_virtual(true),
func(__VirtualFunctionInfo(func_, parent)),
enabled(std::get<__VirtualFunctionInfo<ptr_type>>(func).can_overwrite()),
name(name_),
backup(*get_vtable_addr()),
thunk(thunk_) {}
// Thunks
template <typename Ret, typename... Args>
void __Function<Ret(Args...)>::enable_thunk(const thunk_enabler_t &thunk_enabler) {
if (enabled) {
ptr_type real_thunk = (ptr_type) thunk_enabler((void *) get(), (void *) thunk);
if (!is_virtual) {
func = real_thunk;
}
}
}

99
data/function.h Normal file
View File

@ -0,0 +1,99 @@
#pragma once
#include <variant>
#include <functional>
#include <type_traits>
// Virtual Function Information
template <typename T>
class __Function;
template <typename T>
class __VirtualFunctionInfo {
__VirtualFunctionInfo(T *addr_, void *parent_);
[[nodiscard]] bool can_overwrite() const;
T *const addr;
void *const parent;
friend class __Function<std::remove_pointer_t<T>>;
};
// Thunks
typedef void *(*thunk_enabler_t)(void *target, void *thunk);
void enable_all_thunks(const thunk_enabler_t &thunk_enabler);
// Function Information
template <typename Ret, typename... Args>
class __Function<Ret(Args...)> {
public:
// Types
using ptr_type = Ret (*)(Args...);
using type = std::function<Ret(Args...)>;
using overwrite_type = std::function<Ret(type, Args...)>;
// Normal Function
__Function(const char *name_, ptr_type func_, ptr_type thunk_);
// Virtual Function
__Function(const char *name_, ptr_type *func_, void *parent, ptr_type thunk_);
// Overwrite Function
[[nodiscard]] bool overwrite(overwrite_type target) {
// Check If Enabled
if (!enabled) {
return false;
}
// Overwrite
type original = get_thunk_target();
thunk_target = [original, target](Args... args) {
return target(original, args...);
};
return true;
}
// Getters
[[nodiscard]] ptr_type get_backup() const {
return backup;
}
[[nodiscard]] ptr_type get() const {
if (!enabled) {
return nullptr;
} else if (is_virtual) {
return *get_vtable_addr();
} else {
return std::get<ptr_type>(func);
}
}
[[nodiscard]] ptr_type *get_vtable_addr() const {
if (is_virtual) {
return std::get<__VirtualFunctionInfo<ptr_type>>(func).addr;
} else {
return nullptr;
}
}
[[nodiscard]] const char *get_name() const {
return name;
}
[[nodiscard]] type get_thunk_target() const {
if (thunk_target) {
return thunk_target;
} else {
return get_backup();
}
}
private:
// Current Function
const bool is_virtual;
std::variant<ptr_type, __VirtualFunctionInfo<ptr_type>> func;
// State
const bool enabled;
const char *const name;
// Backup Of Original Function Pointer
const ptr_type backup;
// Thunk
const ptr_type thunk;
type thunk_target;
void enable_thunk(const thunk_enabler_t &thunk_enabler);
friend void enable_all_thunks(const thunk_enabler_t &);
};

View File

@ -1,6 +1,9 @@
#include "{{ headerPath }}" #include "{{ headerPath }}"
#include "{{ data }}/function.cpp"
// Thunks // Thunks
template <typename T>
struct __Thunk;
template <typename Ret, typename... Args> template <typename Ret, typename... Args>
struct __Thunk<Ret(Args...)> { struct __Thunk<Ret(Args...)> {
template <__Function<Ret(Args...)> *const *func> template <__Function<Ret(Args...)> *const *func>
@ -9,7 +12,9 @@ struct __Thunk<Ret(Args...)> {
} }
}; };
// Thunk Enabler
thunk_enabler_t thunk_enabler;
{{ main }} {{ main }}
// Enable All Thunks
void enable_all_thunks(const thunk_enabler_t &thunk_enabler) {
{{ enableThunks }}
}

View File

@ -6,147 +6,14 @@
#endif #endif
// Headers // Headers
#include "{{ data }}/function.h"
#include <cstddef> #include <cstddef>
#include <string> #include <string>
#include <vector> #include <vector>
#include <map> #include <map>
#include <functional>
#include <type_traits> #include <type_traits>
#include <cstring> #include <cstring>
// Virtual Function Information
template <typename T>
class __Function;
template <typename T>
class __VirtualFunctionInfo {
__VirtualFunctionInfo(T *const addr_, void *const parent_):
addr(addr_),
parent(parent_) {}
bool can_overwrite() const {
return ((void *) *addr) != parent;
}
T *const addr;
void *const parent;
friend class __Function<std::remove_pointer_t<T>>;
};
// Thunks
template <typename T>
struct __Thunk;
typedef void *(*thunk_enabler_t)(void *target, void *thunk);
extern thunk_enabler_t thunk_enabler;
// Function Information
template <typename Ret, typename... Args>
class __Function<Ret(Args...)> {
public:
// Types
using ptr_type = Ret (*)(Args...);
using type = std::function<Ret(Args...)>;
using overwrite_type = std::function<Ret(type, Args...)>;
// Normal Function
__Function(const char *const name_, const ptr_type func_, const ptr_type thunk_):
enabled(true),
name(name_),
is_virtual(false),
func(func_),
backup(func_),
thunk(thunk_) {}
// Virtual Function
__Function(const char *const name_, const __VirtualFunctionInfo<ptr_type> virtual_info_, const ptr_type thunk_):
enabled(virtual_info_.can_overwrite()),
name(name_),
is_virtual(true),
func(virtual_info_),
backup(*get_vtable_addr()),
thunk(thunk_) {}
__Function(const char *const name_, ptr_type *const func_, void *const parent, const ptr_type thunk_):
__Function(name_, __VirtualFunctionInfo(func_, parent), thunk_) {}
// Overwrite Function
bool overwrite(overwrite_type target) {
// Check If Enabled
if (!enabled) {
return false;
}
// Enable Thunk
enable_thunk();
// Overwrite
type original = get_thunk_target();
thunk_target = [original, target](Args... args) {
return target(original, args...);
};
return true;
}
// Getters
ptr_type get_backup() const {
return backup;
}
ptr_type get() {
if (!enabled) {
return nullptr;
} else {
enable_thunk();
if (is_virtual) {
return *get_vtable_addr();
} else {
return func.normal_addr;
}
}
}
ptr_type *get_vtable_addr() const {
if (is_virtual) {
return func.virtual_info.addr;
} else {
return nullptr;
}
}
const char *get_name() const {
return name;
}
private:
// State
const bool enabled;
const char *const name;
// Current Function
const bool is_virtual;
union __FunctionInfo {
explicit __FunctionInfo(const ptr_type normal_addr_): normal_addr(normal_addr_) {}
explicit __FunctionInfo(const __VirtualFunctionInfo<ptr_type> virtual_info_): virtual_info(virtual_info_) {}
ptr_type normal_addr;
const __VirtualFunctionInfo<ptr_type> virtual_info;
} func;
// Backup Of Original Function Pointer
const ptr_type backup;
// Thunk
const ptr_type thunk;
bool thunk_enabled = false;
type thunk_target;
void enable_thunk() {
if (!thunk_enabled) {
ptr_type real_thunk = (ptr_type) thunk_enabler((void *) backup, (void *) thunk);
if (!is_virtual) {
func.normal_addr = real_thunk;
}
thunk_enabled = true;
}
}
type get_thunk_target() const {
if (thunk_target) {
return thunk_target;
} else {
return backup;
}
}
friend struct __Thunk<Ret(Args...)>;
};
// Shortcuts // Shortcuts
typedef unsigned char uchar; typedef unsigned char uchar;
typedef unsigned short ushort; typedef unsigned short ushort;
@ -157,15 +24,15 @@ typedef unsigned int uint;
template <typename T> template <typename T>
T *dup_vtable(T *vtable) { T *dup_vtable(T *vtable) {
// Check // Check
static_assert(std::is_constructible<T>::value, "Unable To Construct VTable"); static_assert(std::is_constructible_v<T>, "Unable To Construct VTable");
// Get Size // Get Size
uchar *real_vtable = (uchar *) vtable; const uchar *real_vtable = (uchar *) vtable;
real_vtable -= RTTI_SIZE; real_vtable -= RTTI_SIZE;
size_t real_vtable_size = sizeof(T) + RTTI_SIZE; const size_t real_vtable_size = sizeof(T) + RTTI_SIZE;
// Allocate // Allocate
uchar *new_vtable = (uchar *) ::operator new(real_vtable_size); uchar *new_vtable = (uchar *) ::operator new(real_vtable_size);
// Copy // Copy
memcpy((void *) new_vtable, (void *) real_vtable, real_vtable_size); memcpy(new_vtable, real_vtable, real_vtable_size);
// Return // Return
new_vtable += RTTI_SIZE; new_vtable += RTTI_SIZE;
return (T *) new_vtable; return (T *) new_vtable;

View File

@ -86,8 +86,11 @@ export function prependArg(args: string, arg: string) {
} }
return '(' + arg + args.substring(1); return '(' + arg + args.substring(1);
} }
export function getDataDir() {
return path.join(__dirname, '..', 'data');
}
export function formatFile(file: string, options: {[key: string]: string}) { export function formatFile(file: string, options: {[key: string]: string}) {
file = path.join(__dirname, '..', 'data', file); file = path.join(getDataDir(), file);
let data = fs.readFileSync(file, {encoding: 'utf8'}); let data = fs.readFileSync(file, {encoding: 'utf8'});
for (const key in options) { for (const key in options) {
data = data.replace(`{{ ${key} }}`, options[key]!); data = data.replace(`{{ ${key} }}`, options[key]!);

View File

@ -1,5 +1,5 @@
import * as fs from 'node:fs'; import * as fs from 'node:fs';
import { STRUCTURE_FILES, EXTENSION, formatFile } from './common'; import { STRUCTURE_FILES, EXTENSION, formatFile, getDataDir } from './common';
import { getStructure } from './map'; import { getStructure } from './map';
import { Struct } from './struct'; import { Struct } from './struct';
@ -145,7 +145,7 @@ function makeMainHeader(output: string) {
// Main // Main
const main = makeHeaderPart().trim(); const main = makeHeaderPart().trim();
// Write // Write
const result = formatFile('out.h', {forwardDeclarations, extraHeaders, main}); const result = formatFile('out.h', {forwardDeclarations, extraHeaders, main, data: getDataDir()});
fs.writeFileSync(output, result); fs.writeFileSync(output, result);
} }
makeMainHeader(headerOutput); makeMainHeader(headerOutput);
@ -157,23 +157,25 @@ function makeCompiledCode(output: string) {
// Generate // Generate
let declarations = ''; let declarations = '';
let enableThunks = '';
for (const structure of structureObjects) { for (const structure of structureObjects) {
const name = structure.getName(); const name = structure.getName();
declarations += `// ${name}\n`; declarations += `// ${name}\n`;
try { try {
const code = structure.generateCode(); declarations += structure.generateCode();
declarations += code; enableThunks += structure.generateEnableThunks();
} catch (e) { } catch (e) {
console.log(`Error Generating Code: ${name}: ${e instanceof Error ? e.stack : e}`); console.log(`Error Generating Code: ${name}: ${e instanceof Error ? e.stack : e}`);
process.exit(1); process.exit(1);
} }
declarations += '\n'; declarations += '\n';
} }
enableThunks = enableThunks.slice(0, -1); // Remove Last Newline
// Write // Write
const headerPath = fs.realpathSync(headerOutput); const headerPath = fs.realpathSync(headerOutput);
const main = declarations.trim(); const main = declarations.trim();
const result = formatFile('out.cpp', {headerPath, main}); const result = formatFile('out.cpp', {headerPath, main, enableThunks, data: getDataDir()});
fs.writeFileSync(output, result); fs.writeFileSync(output, result);
} }
makeCompiledCode(sourceOutput); makeCompiledCode(sourceOutput);

View File

@ -1,4 +1,4 @@
import { INDENT, INTERNAL, toHex } from './common'; import { INDENT, INTERNAL, formatType, toHex } from './common';
export class Method { export class Method {
readonly self: string; readonly self: string;
@ -27,12 +27,25 @@ export class Method {
getType() { getType() {
return this.getName() + '_t'; return this.getName() + '_t';
} }
getProperty() { getProperty(hasWrapper: boolean) {
return `${INDENT}${this.getWrapperType()}::ptr_type ${this.shortName};\n`; let out = INDENT;
if (hasWrapper) {
out += `${this.getWrapperType()}::ptr_type ${this.shortName}`;
} else {
out += `${formatType(this.returnType.trim())}(*${this.shortName})${this.args.trim()}`;
}
out += ';\n';
return out;
} }
getWrapperType() { getWrapperType() {
return `std::remove_pointer_t<decltype(${this.getName()})>`; return `std::remove_pointer_t<decltype(${this.getName()})>`;
} }
#getSignature() {
return this.returnType.trim() + this.args.trim();
}
#getFullType() {
return `${INTERNAL}Function<${this.#getSignature()}>`;
}
// Overwrite Helper // Overwrite Helper
#getVirtualCall(self: string = this.self) { #getVirtualCall(self: string = this.self) {
@ -43,8 +56,7 @@ export class Method {
if (!code) { if (!code) {
out += 'extern '; out += 'extern ';
} }
const signature = this.returnType.trim() + this.args.trim(); const type = this.#getFullType();
const type = `${INTERNAL}Function<${signature}>`;
out += `${type} *const ${this.getName()}`; out += `${type} *const ${this.getName()}`;
if (code) { if (code) {
out += ` = new ${type}(${JSON.stringify(this.getName('::'))}, `; out += ` = new ${type}(${JSON.stringify(this.getName('::'))}, `;
@ -54,7 +66,7 @@ export class Method {
} else { } else {
out += `${this.getWrapperType()}::ptr_type(${toHex(this.address)})`; out += `${this.getWrapperType()}::ptr_type(${toHex(this.address)})`;
} }
out += `, ${INTERNAL}Thunk<${signature}>::call<&${this.getName()}>)`; out += `, ${INTERNAL}Thunk<${this.#getSignature()}>::call<&${this.getName()}>)`;
} }
out += ';\n'; out += ';\n';
if (!code) { if (!code) {

View File

@ -269,4 +269,27 @@ export class Struct {
setDirectParent(directParent: string) { setDirectParent(directParent: string) {
this.#directParent = directParent; this.#directParent = directParent;
} }
// Generate Part Of enable_all_thunks()
generateEnableThunks() {
// Get All Methods
const allMethods: Method[] = [];
for (const method of this.#methods) {
allMethods.push(method);
}
if (this.#vtable !== null && this.#vtable.canGenerateWrappers()) {
const virtualMethods = this.#vtable.getMethods();
for (const method of virtualMethods) {
if (method) {
allMethods.push(method);
}
}
}
// Generate
let out = '';
for (const method of allMethods) {
out += `${INDENT}${method.getName()}->enable_thunk(thunk_enabler);\n`;
}
return out;
}
} }

View File

@ -105,6 +105,9 @@ export class VTable {
} }
// Generate Header Code // Generate Header Code
canGenerateWrappers() {
return this.#address !== null;
}
generate() { generate() {
let out = ''; let out = '';
@ -113,9 +116,11 @@ export class VTable {
// Wrappers // Wrappers
const methods = this.getMethods(); const methods = this.getMethods();
for (const info of methods) { if (this.canGenerateWrappers()) {
if (info) { for (const info of methods) {
out += info.generate(false, true); if (info) {
out += info.generate(false, true);
}
} }
} }
@ -125,7 +130,7 @@ export class VTable {
for (let i = 0; i < methods.length; i++) { for (let i = 0; i < methods.length; i++) {
const info = methods[i]; const info = methods[i];
if (info) { if (info) {
out += info.getProperty(); out += info.getProperty(this.canGenerateWrappers());
} else { } else {
out += `${INDENT}void *unknown${i};\n`; out += `${INDENT}void *unknown${i};\n`;
} }
@ -162,7 +167,10 @@ export class VTable {
if (this.#address !== null) { if (this.#address !== null) {
// Base // Base
out += `${this.#getName()} *${this.#getName()}_base = (${this.#getName()} *) ${toHex(this.#address)};\n`; out += `${this.#getName()} *${this.#getName()}_base = (${this.#getName()} *) ${toHex(this.#address)};\n`;
// Methods }
// Method Wrappers
if (this.canGenerateWrappers()) {
const methods = this.getMethods(); const methods = this.getMethods();
for (const info of methods) { for (const info of methods) {
if (info) { if (info) {