Refactor Generated Code

This commit is contained in:
TheBrokenRail 2025-02-24 05:23:00 -05:00
parent 27c4f4115b
commit 8dc7b17251
9 changed files with 254 additions and 184 deletions

116
data/function.h Normal file
View File

@ -0,0 +1,116 @@
#include <functional>
#include <utility>
// Information Interface
template <typename Ret, typename... Args>
class __FunctionInfo {
typedef Ret (*type)(Args...);
public:
[[nodiscard]] virtual bool can_overwrite() const = 0;
[[nodiscard]] virtual type get() const = 0;
[[nodiscard]] virtual type *get_addr() const = 0;
virtual void update(type new_func) = 0;
};
// Thunks
typedef void *(*thunk_enabler_t)(void *target, void *thunk);
extern thunk_enabler_t thunk_enabler;
// Function
template <unsigned int, typename T>
class __Function;
template <unsigned int discriminator, typename Ret, typename... Args>
class __Function<discriminator, Ret(Args...)> final {
// Prevent Copying
__PREVENT_COPY(__Function);
__PREVENT_DESTRUCTION(__Function);
// Instance
static __Function<discriminator, Ret(Args...)> *instance;
// Current Function
typedef __FunctionInfo<Ret, Args...> *func_t;
const func_t func;
public:
// Types
typedef Ret (*ptr_type)(Args...);
typedef std::function<Ret(Args...)> type;
typedef std::function<Ret(const type &, Args...)> overwrite_type;
// State
const bool enabled;
const char *const name;
// Backup Of Original Function Pointer
const ptr_type backup;
#ifdef {{ BUILDING_SYMBOLS_GUARD }}
// Constructor
__Function(const char *const name_, const func_t func_):
func(func_),
enabled(func->can_overwrite()),
name(name_),
backup(func->get())
{
instance = this;
}
#else
// Prevent Construction
__PREVENT_JUST_CONSTRUCTION(__Function);
#endif
// Overwrite Function
[[nodiscard]] bool overwrite(const overwrite_type &target) {
// Check If Enabled
if (!enabled) {
return false;
}
// Enable Thunk
enable_thunk();
// Overwrite
type original = get_thunk_target();
thunk_target = [original, target](Args... args) {
return target(original, std::forward<Args>(args)...);
};
return true;
}
// Getters
[[nodiscard]] ptr_type get(const bool result_will_be_stored) {
if (!enabled) {
return nullptr;
} else {
if (result_will_be_stored) {
enable_thunk();
}
return func->get();
}
}
[[nodiscard]] ptr_type *get_vtable_addr() const {
return func->get_addr();
}
private:
// Thunk
[[nodiscard]] type get_thunk_target() const {
if (thunk_target) {
return thunk_target;
} else {
return backup;
}
}
static Ret thunk(Args... args) {
return instance->get_thunk_target()(std::forward<Args>(args)...);
}
// Enable Thunk
type thunk_target;
bool thunk_enabled = false;
void enable_thunk() {
if (!thunk_enabled && enabled) {
ptr_type real_thunk = (ptr_type) thunk_enabler((void *) backup, (void *) thunk);
func->update(real_thunk);
thunk_enabled = true;
}
}
};

10
data/internal.h Normal file
View File

@ -0,0 +1,10 @@
#define __PREVENT_DESTRUCTION(self) \
~self() = delete
#define __PREVENT_JUST_CONSTRUCTION(self) \
self() = delete
#define __PREVENT_CONSTRUCTION(self) \
__PREVENT_JUST_CONSTRUCTION(self); \
__PREVENT_DESTRUCTION(self)
#define __PREVENT_COPY(self) \
self(const self &) = delete; \
self &operator=(const self &) = delete

View File

@ -1,10 +1,66 @@
#define LEAN_SYMBOLS_HEADER #define {{ BUILDING_SYMBOLS_GUARD }}
#include "{{ headerPath }}" #include "{{ headerPath }}"
// Thunk Template // Global Instance
template <auto *const *func> template <unsigned int id, typename Ret, typename... Args>
decltype(auto) __thunk(auto... args) { __Function<id, Ret(Args...)> *__Function<id, Ret(Args...)>::instance;
return (*func)->get_thunk_target()(std::forward<decltype(args)>(args)...);
} // Normal Function Information
template <typename Ret, typename... Args>
class __NormalFunctionInfo final : public __FunctionInfo<Ret, Args...> {
typedef Ret (*type)(Args...);
type func;
public:
// Constructor
explicit __NormalFunctionInfo(const type func_):
func(func_) {}
// Functions
[[nodiscard]] bool can_overwrite() const override {
return true;
}
[[nodiscard]] type get() const override {
return func;
}
[[nodiscard]] type *get_addr() const override {
return nullptr;
}
void update(const type new_func) override {
func = new_func;
}
};
// Virtual Function Information
template <typename Ret, typename... Args>
class __VirtualFunctionInfo final : public __FunctionInfo<Ret, Args...> {
typedef Ret (*type)(Args...);
type *const addr;
void *const parent;
public:
// Constructor
__VirtualFunctionInfo(type *const addr_, void *const parent_):
addr(addr_),
parent(parent_) {}
// Functions
[[nodiscard]] bool can_overwrite() const override {
// If this function's address matches its parent's,
// then it was just inherited and does not actually exist.
// Overwriting this function would also overwrite its parent
// which would cause undesired behavior.
return get() != parent;
}
[[nodiscard]] type get() const override {
return *get_addr();
}
[[nodiscard]] type *get_addr() const override {
return addr;
}
void update(const type new_func) override {
// Address Should Have Already Been Updated
if (get() != new_func) {
__builtin_trap();
}
}
};
#undef super
{{ main }} {{ main }}

View File

@ -5,153 +5,19 @@
#error "Symbols Are ARM-Only" #error "Symbols Are ARM-Only"
#endif #endif
// Internal Macros
{{ include internal.h }}
// Function Object
{{ include function.h }}
// Headers // Headers
#include <variant>
#include <functional>
#include <cstddef> #include <cstddef>
#include <string> #include <string>
#include <vector> #include <vector>
#include <map> #include <map>
#include <type_traits>
#include <cstring> #include <cstring>
// Internal Macros
#define __PREVENT_DESTRUCTION(self) \
~self() = delete
#define __PREVENT_CONSTRUCTION(self) \
self() = delete; \
__PREVENT_DESTRUCTION(self)
#define __PREVENT_COPY(self) \
self(const self &) = delete; \
self &operator=(const self &) = delete
// Virtual Function Information
struct __VirtualFunctionInfo {
// Constructors
template <typename Ret, typename Self, typename Super, typename... Args>
__VirtualFunctionInfo(Ret (**const addr_)(Self, Args...), Ret (*const parent_)(Super, Args...)):
addr((void **) addr_),
parent((void *) parent_) {}
template <typename T>
__VirtualFunctionInfo(T **const addr_, const std::nullptr_t parent_):
__VirtualFunctionInfo(addr_, (T *) parent_) {}
// Method
[[nodiscard]] bool can_overwrite() const {
return *addr != parent;
}
// Properties
void **const addr;
void *const parent;
};
// Thunks
typedef void *(*thunk_enabler_t)(void *target, void *thunk);
extern thunk_enabler_t thunk_enabler;
// Function Information
template <typename T>
class __Function;
template <typename Ret, typename... Args>
class __Function<Ret(Args...)> {
// Prevent Copying
__PREVENT_COPY(__Function);
__PREVENT_DESTRUCTION(__Function);
public:
// Types
typedef Ret (*ptr_type)(Args...);
typedef std::function<Ret(Args...)> type;
typedef std::function<Ret(const type &, Args...)> overwrite_type;
// Normal Function
__Function(const std::string name_, const ptr_type thunk_, const ptr_type func_):
func(func_),
enabled(true),
name(name_),
backup(func_),
thunk(thunk_) {}
// Virtual Function
template <typename Parent>
__Function(const std::string name_, const ptr_type thunk_, ptr_type *const func_, const Parent parent):
func(__VirtualFunctionInfo(func_, parent)),
enabled(std::get<__VirtualFunctionInfo>(func).can_overwrite()),
name(name_),
backup(*get_vtable_addr()),
thunk(thunk_) {}
// Overwrite Function
[[nodiscard]] bool overwrite(overwrite_type target) {
// Check If Enabled
if (!enabled) {
return false;
}
// Enable Thunk
enable_thunk();
// Overwrite
type original = get_thunk_target();
thunk_target = [original, target](Args... args) {
return target(original, std::forward<Args>(args)...);
};
return true;
}
// Getters
[[nodiscard]] ptr_type get(bool result_will_be_stored) {
if (!enabled) {
return nullptr;
} else {
if (result_will_be_stored) {
enable_thunk();
}
if (is_virtual()) {
return *get_vtable_addr();
} else {
return std::get<ptr_type>(func);
}
}
}
[[nodiscard]] ptr_type *get_vtable_addr() const {
return (ptr_type *) std::get<__VirtualFunctionInfo>(func).addr;
}
[[nodiscard]] type get_thunk_target() const {
if (thunk_target) {
return thunk_target;
} else {
return backup;
}
}
private:
// Current Function
std::variant<ptr_type, __VirtualFunctionInfo> func;
[[nodiscard]] bool is_virtual() const {
return func.index() == 1;
}
public:
// State
const bool enabled;
const std::string name;
// Backup Of Original Function Pointer
const ptr_type backup;
private:
// Thunk
const ptr_type thunk;
type thunk_target;
bool thunk_enabled = false;
void enable_thunk() {
if (!thunk_enabled && enabled) {
ptr_type real_thunk = (ptr_type) thunk_enabler((void *) backup, (void *) thunk);
if (!is_virtual()) {
func = real_thunk;
}
thunk_enabled = true;
}
}
};
// Shortcuts // Shortcuts
typedef unsigned char uchar; typedef unsigned char uchar;
typedef unsigned short ushort; typedef unsigned short ushort;

View File

@ -9,7 +9,7 @@ export const EXTENSION = '.def';
export const STRUCTURE_FILES: Record<string, string> = {}; export const STRUCTURE_FILES: Record<string, string> = {};
export const COMMENT = '//'; export const COMMENT = '//';
export const INTERNAL = '__'; export const INTERNAL = '__';
export const LEAN_HEADER_GUARD = '#ifndef LEAN_SYMBOLS_HEADER\n'; export const BUILDING_SYMBOLS_GUARD = 'BUILDING_SYMBOLS_LIB';
// Read Definition File // Read Definition File
export function readDefinition(name: string) { export function readDefinition(name: string) {
if (!STRUCTURE_FILES[name]) { if (!STRUCTURE_FILES[name]) {
@ -117,20 +117,23 @@ export function getDataDir() {
return path.join(__dirname, '..', 'data'); return path.join(__dirname, '..', 'data');
} }
// Format File From Data Directory // Format File From Data Directory
export function formatFile(file: string, options: Record<string, string>) { export function formatFile(file: string, options: Record<string, string>, includeOtherFiles = true) {
// Include Other Files const newOptions = Object.assign({}, options);
const dataDir = getDataDir(); const dataDir = getDataDir();
const otherFiles = fs.readdirSync(dataDir); // Include Other Files
for (let otherFile of otherFiles) { if (includeOtherFiles) {
otherFile = path.join(dataDir, otherFile); const otherFiles = fs.readdirSync(dataDir);
options[`include ${otherFile}`] = fs.readFileSync(otherFile, 'utf8'); for (const otherFile of otherFiles) {
newOptions[`include ${otherFile}`] = formatFile(otherFile, options, false);
}
} }
// Format // Format
file = path.join(dataDir, file); file = path.join(dataDir, file);
let data = fs.readFileSync(file, 'utf8'); let data = fs.readFileSync(file, 'utf8');
for (const key in options) { for (const key in newOptions) {
const value = options[key]; let value = newOptions[key];
if (value) { if (value) {
value = value.trim();
data = data.replace(`{{ ${key} }}`, value); data = data.replace(`{{ ${key} }}`, value);
} }
} }

View File

@ -1,6 +1,6 @@
import * as fs from 'node:fs'; import * as fs from 'node:fs';
import * as path from 'node:path'; import * as path from 'node:path';
import { STRUCTURE_FILES, EXTENSION, formatFile, getDataDir, extendErrorMessage } from './common'; import { STRUCTURE_FILES, EXTENSION, formatFile, getDataDir, extendErrorMessage, BUILDING_SYMBOLS_GUARD } from './common';
import { getStructure } from './map'; import { getStructure } from './map';
import { Struct } from './struct'; import { Struct } from './struct';
@ -41,7 +41,7 @@ while (process.argv.length > 0) {
const fullName = file.base; const fullName = file.base;
const name = file.name; const name = file.name;
// Store // Store
if (name in STRUCTURE_FILES) { if (STRUCTURE_FILES[name]) {
throw new Error(`Multiple Definition Files Provided: ${fullName}`); throw new Error(`Multiple Definition Files Provided: ${fullName}`);
} }
STRUCTURE_FILES[name] = filePath; STRUCTURE_FILES[name] = filePath;
@ -129,7 +129,7 @@ function makeHeaderPart() {
try { try {
structures += structure.generate(); structures += structure.generate();
} catch (e) { } catch (e) {
throw new Error(extendErrorMessage(e, 'Error Generating Header: ' + name)); throw new Error(extendErrorMessage(e, 'Generating Header: ' + name));
} }
structures += '\n'; structures += '\n';
} }
@ -155,7 +155,7 @@ function makeMainHeader(output: string) {
// Main // Main
const main = makeHeaderPart().trim(); const main = makeHeaderPart().trim();
// Write // Write
const result = formatFile('out.h', {forwardDeclarations, extraHeaders, main, data: getDataDir()}); const result = formatFile('out.h', {BUILDING_SYMBOLS_GUARD, forwardDeclarations, extraHeaders, main, data: getDataDir()});
fs.writeFileSync(output, result); fs.writeFileSync(output, result);
} }
makeMainHeader(headerOutput); makeMainHeader(headerOutput);
@ -182,14 +182,14 @@ function makeCompiledCode(outputDir: string) {
try { try {
declarations += structure.generateCode().trim(); declarations += structure.generateCode().trim();
} catch (e) { } catch (e) {
throw new Error(extendErrorMessage(e, 'Error Generating Code: ' + name)); throw new Error(extendErrorMessage(e, 'Generating Code: ' + name));
} }
declarations += '\n'; declarations += '\n';
// Write // Write
const headerPath = fs.realpathSync(headerOutput); const headerPath = fs.realpathSync(headerOutput);
const main = declarations.trim(); const main = declarations.trim();
const result = formatFile('out.cpp', {headerPath, main, data: getDataDir()}); const result = formatFile('out.cpp', {BUILDING_SYMBOLS_GUARD, headerPath, main, data: getDataDir()});
const output = path.join(outputDir, name + '.cpp'); const output = path.join(outputDir, name + '.cpp');
fs.writeFileSync(output, result); fs.writeFileSync(output, result);
} }

View File

@ -1,6 +1,11 @@
import { INDENT, INTERNAL, formatType, toHex } from './common'; import { INDENT, INTERNAL, formatType, toHex } from './common';
// A Template Parameter So Each Template Instantiation Is Unique
let nextDiscriminator = 0;
// An Individual Method
export class Method { export class Method {
readonly #discriminator: number;
readonly self: string; readonly self: string;
readonly shortName: string; readonly shortName: string;
readonly returnType: string; readonly returnType: string;
@ -11,6 +16,7 @@ export class Method {
// Constructor // Constructor
constructor(self: string, name: string, returnType: string, args: string, address: number, isInherited: boolean, isStatic: boolean) { constructor(self: string, name: string, returnType: string, args: string, address: number, isInherited: boolean, isStatic: boolean) {
this.#discriminator = nextDiscriminator++;
this.self = self; this.self = self;
this.shortName = name; this.shortName = name;
this.returnType = returnType; this.returnType = returnType;
@ -34,7 +40,7 @@ export class Method {
return `${INDENT}${this.#getRawType()} *${this.shortName};\n`; return `${INDENT}${this.#getRawType()} *${this.shortName};\n`;
} }
#getFullType() { #getFullType() {
return `${INTERNAL}Function<${this.#getRawType()}>`; return `${INTERNAL}Function<${this.#discriminator.toString()}, ${this.#getRawType()}>`;
} }
// Typedefs // Typedefs
@ -52,16 +58,14 @@ export class Method {
generate(code: boolean, isVirtual: boolean, parentSelf?: string) { generate(code: boolean, isVirtual: boolean, parentSelf?: string) {
let out = ''; let out = '';
out += 'extern '; out += 'extern ';
const type = this.#getFullType(); out += `${this.#getFullType()} *const ${this.getName()}`;
out += `${type} *const ${this.getName()}`;
if (code) { if (code) {
out += ` = new ${type}(${JSON.stringify(this.getName('::'))}, `; out += ` = new ${this.#getFullType()}(${JSON.stringify(this.getName('::'))}, `;
out += `${INTERNAL}thunk<&${this.getName()}>, `;
if (isVirtual) { if (isVirtual) {
const parentMethod = parentSelf ? this.#getVirtualCall(parentSelf) : 'nullptr'; const parentMethod = parentSelf ? this.#getVirtualCall(parentSelf) : 'nullptr';
out += `&${this.#getVirtualCall()}, ${parentMethod}`; out += `new ${INTERNAL}VirtualFunctionInfo(&${this.#getVirtualCall()}, (void *) ${parentMethod})`;
} else { } else {
out += `(${this.#getRawType()} *) ${toHex(this.address)}`; out += `new ${INTERNAL}NormalFunctionInfo((${this.#getRawType()} *) ${toHex(this.address)})`;
} }
out += ')'; out += ')';
} }

View File

@ -1,4 +1,4 @@
import { INDENT, STRUCTURE_FILES, toHex, assertSize, INTERNAL, preventConstruction, LEAN_HEADER_GUARD } from './common'; import { INDENT, STRUCTURE_FILES, toHex, assertSize, INTERNAL, preventConstruction, BUILDING_SYMBOLS_GUARD } from './common';
import { Method } from './method'; import { Method } from './method';
import { Property, StaticProperty } from './property'; import { Property, StaticProperty } from './property';
import { VTable } from './vtable'; import { VTable } from './vtable';
@ -71,7 +71,7 @@ export class Struct {
this.#properties.push(property); this.#properties.push(property);
// Add Dependency If Needed // Add Dependency If Needed
const type = property.rawType(); const type = property.rawType();
if (type in STRUCTURE_FILES) { if (STRUCTURE_FILES[type]) {
this.#addDependency(type); this.#addDependency(type);
} }
} }
@ -165,7 +165,7 @@ export class Struct {
} }
#generateMethods() { #generateMethods() {
let out = ''; let out = '';
out += LEAN_HEADER_GUARD; out += `#ifndef ${BUILDING_SYMBOLS_GUARD}\n`;
// Normal Methods // Normal Methods
for (const method of this.#methods) { for (const method of this.#methods) {
out += this.#generateMethod(method, false); out += this.#generateMethod(method, false);
@ -174,7 +174,9 @@ export class Struct {
if (this.#vtable !== null) { if (this.#vtable !== null) {
const virtualMethods = this.#vtable.getMethods(); const virtualMethods = this.#vtable.getMethods();
for (const method of virtualMethods) { for (const method of virtualMethods) {
out += this.#generateMethod(method, true); if (method) {
out += this.#generateMethod(method, true);
}
} }
} }
// Allocation Method // Allocation Method
@ -223,7 +225,7 @@ export class Struct {
} }
} }
out += typedefs; out += typedefs;
out += LEAN_HEADER_GUARD; out += `#ifndef ${BUILDING_SYMBOLS_GUARD}\n`;
out += methodsOut; out += methodsOut;
out += '#endif\n'; out += '#endif\n';

View File

@ -1,4 +1,4 @@
import { INDENT, LEAN_HEADER_GUARD, POINTER_SIZE, assertSize, getSelfArg, preventConstruction, toHex } from './common'; import { BUILDING_SYMBOLS_GUARD, INDENT, POINTER_SIZE, assertSize, getSelfArg, preventConstruction, toHex } from './common';
import { Method } from './method'; import { Method } from './method';
import { Property } from './property'; import { Property } from './property';
@ -7,9 +7,10 @@ export class VTable {
readonly #self: string; readonly #self: string;
#address: number | null; #address: number | null;
#size: number | null; #size: number | null;
readonly #methods: Method[]; readonly #methods: (Method | undefined)[];
readonly property: Property; readonly property: Property;
#destructorOffset: number; #destructorOffset: number;
readonly #destructors: Method[];
// Constructor // Constructor
constructor(self: string) { constructor(self: string) {
@ -18,6 +19,7 @@ export class VTable {
this.#size = null; this.#size = null;
this.#methods = []; this.#methods = [];
this.#destructorOffset = 0; this.#destructorOffset = 0;
this.#destructors = [];
// Create Property // Create Property
this.property = new Property(0, this.#getName() + ' *', 'vtable', this.#self); this.property = new Property(0, this.#getName() + ' *', 'vtable', this.#self);
} }
@ -35,7 +37,7 @@ export class VTable {
} }
// Add To VTable // Add To VTable
#add(target: Method[], method: Method) { #add(target: (Method | undefined)[], method: Method) {
// Check Offset // Check Offset
const offset = method.address; const offset = method.address;
if ((offset % POINTER_SIZE) !== 0) { if ((offset % POINTER_SIZE) !== 0) {
@ -43,7 +45,7 @@ export class VTable {
} }
// Add // Add
const index = offset / POINTER_SIZE; const index = offset / POINTER_SIZE;
if (index in target) { if (target[index]) {
throw new Error(`Duplicate Virtual Method At Offset: ${toHex(offset)}`); throw new Error(`Duplicate Virtual Method At Offset: ${toHex(offset)}`);
} }
target[index] = method; target[index] = method;
@ -80,8 +82,15 @@ export class VTable {
// Add Destructors (https://stackoverflow.com/a/17960941) // Add Destructors (https://stackoverflow.com/a/17960941)
const destructor_return = `${this.#self} *`; const destructor_return = `${this.#self} *`;
const destructor_args = `(${getSelfArg(this.#self)})`; const destructor_args = `(${getSelfArg(this.#self)})`;
this.#add(out, new Method(this.#self, 'destructor_complete', destructor_return, destructor_args, 0x0 + this.#destructorOffset, false, false)); if (this.#destructors.length === 0) {
this.#add(out, new Method(this.#self, 'destructor_deleting', destructor_return, destructor_args, 0x4 + this.#destructorOffset, false, false)); this.#destructors.push(
new Method(this.#self, 'destructor_complete', destructor_return, destructor_args, 0x0 + this.#destructorOffset, false, false),
new Method(this.#self, 'destructor_deleting', destructor_return, destructor_args, 0x4 + this.#destructorOffset, false, false)
);
}
for (const destructor of this.#destructors) {
this.#add(out, destructor);
}
// Return // Return
return out; return out;
} }
@ -119,13 +128,15 @@ export class VTable {
let typedefs = ''; let typedefs = '';
let methodsOut = ''; let methodsOut = '';
for (const info of methods) { for (const info of methods) {
typedefs += info.generateTypedefs(); if (info) {
if (this.canGenerateWrappers()) { typedefs += info.generateTypedefs();
methodsOut += info.generate(false, true); if (this.canGenerateWrappers()) {
methodsOut += info.generate(false, true);
}
} }
} }
out += typedefs; out += typedefs;
out += LEAN_HEADER_GUARD; out += `#ifndef ${BUILDING_SYMBOLS_GUARD}\n`;
out += methodsOut; out += methodsOut;
out += '#endif\n'; out += '#endif\n';
@ -176,7 +187,9 @@ export class VTable {
if (this.canGenerateWrappers()) { if (this.canGenerateWrappers()) {
const methods = this.getMethods(); const methods = this.getMethods();
for (const info of methods) { for (const info of methods) {
out += info.generate(true, true, this.#getParentSelf(info, directParent)); if (info) {
out += info.generate(true, true, this.#getParentSelf(info, directParent));
}
} }
} }