|
1 | 1 | #include "../executer/ExternalFunctions.h" |
2 | 2 |
|
3 | | - |
4 | 3 | #ifdef WIN |
5 | | -// Windows |
| 4 | +#include <windows.h> |
6 | 5 | #else |
7 | 6 | #include <dlfcn.h> |
8 | 7 | #endif |
@@ -46,17 +45,120 @@ void Arguments::clear() { |
46 | 45 |
|
47 | 46 | #ifdef WIN // Windows |
48 | 47 |
|
49 | | -// TODO: Implement for windows with LoadLibrary and GetProcAddress |
50 | | - |
51 | 48 | ExternalFunctions::ExternalFunctions() = default; |
52 | 49 |
|
53 | 50 | size_t ExternalFunctions::add(const std::string& library, const std::string& functionName) { |
54 | | - throwConstraintViolated("External functions are not supported on this platform"); |
| 51 | + ExternalFunction functionInfo; |
| 52 | + functionInfo.library = library; |
| 53 | + functionInfo.name = functionName; |
| 54 | + functionInfo.functionPtr = nullptr; |
| 55 | + |
| 56 | + std::string libPath = "bin\\lib" + library + ".dll"; |
| 57 | + HINSTANCE libraryHandle = LoadLibrary(libPath.c_str()); |
| 58 | + |
| 59 | + if (!libraryHandle) { |
| 60 | + std::cerr << "Error: Could not load library " << libPath << std::endl; |
| 61 | + throwConstraintViolated("Failed to load library"); |
| 62 | + } |
| 63 | + |
| 64 | + // resolve function address here |
| 65 | + functionInfo.functionPtr = (void*)GetProcAddress(libraryHandle, functionInfo.name.c_str()); |
| 66 | + if (!functionInfo.functionPtr) { |
| 67 | + std::cerr << "Error: Could not located the function " << functionInfo.name << " in " << libPath << std::endl; |
| 68 | + throwConstraintViolated("Failed to find symbol in library"); |
| 69 | + } |
| 70 | + ASSURE_NOT_NULL(functionInfo.functionPtr); |
| 71 | + |
| 72 | + functions.push_back(functionInfo); |
| 73 | + return functions.size() - 1; |
55 | 74 | } |
56 | 75 |
|
57 | 76 | qword_t ExternalFunctions::call(size_t id, const Arguments& args) { |
58 | | - throwConstraintViolated("External functions are not supported on this platform"); |
59 | | - return 0; |
| 77 | + ASSURE(id < functions.size(), "Function ID out of bounds"); |
| 78 | + |
| 79 | + const ExternalFunction& func = functions[id]; |
| 80 | + ASSURE_NOT_NULL(func.functionPtr); |
| 81 | + |
| 82 | + // TODO: We don't support more than 6 arguments yet (put on stack) |
| 83 | + ASSURE(args.getSize() <= 6, "Too many arguments for external function"); |
| 84 | + |
| 85 | + // Windows 64 calling convention: |
| 86 | + // https://learn.microsoft.com/en-us/cpp/build/x64-calling-convention?view=msvc-170 |
| 87 | + // RCX, RDX, R8, and R9, |
| 88 | + // shadow store allocated on stack for those four registers |
| 89 | + // then follow by the rest of the arguments on the stack |
| 90 | + // Floats go into XMM0 - XMM3 |
| 91 | + |
| 92 | + qword_t result; |
| 93 | + __asm__ volatile ( |
| 94 | + "movq %[args_tag], %%R10\n" // Bring arg pointer into R10 |
| 95 | + "movq %[size_tag], %%R11\n" // Bring arg size into R11 |
| 96 | + "movq %%rsp, %%rbx\n" // Clear rbx, used for shadow space |
| 97 | + |
| 98 | + "cmpq $0, (%%R10)\n" // Check if type is None (0) |
| 99 | + "je label_do_call\n" // End of args |
| 100 | + |
| 101 | + "addq $8, %%R10\n" // Bring the pointer to the argument value |
| 102 | + "movq (%%R10), %%rcx\n" // Put into target register |
| 103 | + "add $8, %%R10\n" // Move to next argument |
| 104 | + |
| 105 | + "cmpq $0, (%%R10)\n" // Check if type is None (0) |
| 106 | + "je label_do_call\n" // End of args |
| 107 | + |
| 108 | + "addq $8, %%R10\n" // Bring the pointer to the argument value |
| 109 | + "movq (%%R10), %%rdx\n" // Put into target register |
| 110 | + "add $8, %%R10\n" // Move to next argument |
| 111 | + |
| 112 | + "cmpq $0, (%%R10)\n" // Check if type is None (0) |
| 113 | + "je label_do_call\n" // End of args |
| 114 | + |
| 115 | + "addq $8, %%R10\n" // Bring the pointer to the argument value |
| 116 | + "movq (%%R10), %%r8\n" // Put into target register |
| 117 | + "addq $8, %%R10\n" // Move to next argument |
| 118 | + |
| 119 | + "cmpq $0, (%%R10)\n" // Check if type is None (0) |
| 120 | + "je label_do_call\n" // End of args |
| 121 | + |
| 122 | + "addq $8, %%R10\n" // Bring the pointer to the argument value |
| 123 | + "movq (%%R10), %%r9\n" // Put into target register |
| 124 | + "addq $8, %%R10\n" // Move to next argument |
| 125 | + |
| 126 | + "cmpq $0, (%%R10)\n" // Check if type is None (0) |
| 127 | + "je label_do_call\n" // End of args |
| 128 | + |
| 129 | + // Allocate stack |
| 130 | + "subq $64, %%rsp\n" |
| 131 | + |
| 132 | + // Next argument (stack 32) |
| 133 | + "addq $8, %%R10\n" |
| 134 | + "movq (%%R10), %%r12\n" |
| 135 | + "movq %%r12, 32(%%rsp)\n" |
| 136 | + "addq $8, %%R10\n" |
| 137 | + |
| 138 | + "cmpq $0, (%%R10)\n" |
| 139 | + "je label_do_call\n" |
| 140 | + |
| 141 | + // Next argument (stack 40) |
| 142 | + "addq $8, %%R10\n" |
| 143 | + "movq (%%R10), %%r12\n" |
| 144 | + "movq %%r12, 40(%%rsp)\n" |
| 145 | + "addq $8, %%R10\n" |
| 146 | + |
| 147 | + // Do the call |
| 148 | + "label_do_call:" |
| 149 | + |
| 150 | + "call *%[fn_tag]\n" |
| 151 | + "movq %%rax, %[result_tag]\n" |
| 152 | + |
| 153 | + // Reset stack |
| 154 | + "movq %%rbx, %%rsp\n" |
| 155 | + |
| 156 | + : [result_tag] "=r"(result) |
| 157 | + : [fn_tag] "r"(func.functionPtr), |
| 158 | + [args_tag] "r"(args.getBuffer()), |
| 159 | + [size_tag] "r"(args.getSize())); |
| 160 | + |
| 161 | + return result; |
60 | 162 | } |
61 | 163 |
|
62 | 164 | ExternalFunctions::~ExternalFunctions() = default; |
|
0 commit comments