Skip to content

Commit c1d3bae

Browse files
committed
Merge branch 'MmpTlsFiber'
2 parents b2e4ed4 + 4fa5e97 commit c1d3bae

16 files changed

Lines changed: 556 additions & 299 deletions

MemoryModule/ImportTable.cpp

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
#include "stdafx.h"
2+
3+
typedef struct _MMP_IAT_HANDLE {
4+
5+
HMODULE hModule;
6+
PMM_IAT_RESOLVER lpResolver;
7+
8+
}MMP_IAT_HANDLE, * PMMP_IAT_HANDLE;
9+
10+
HMODULE MmpLoadLibraryA(
11+
_In_ LPCSTR lpModuleName,
12+
_Out_ PMM_IAT_RESOLVER* lpModuleResolver) {
13+
14+
HMODULE hModule = nullptr;
15+
PMM_IAT_RESOLVER resolver = nullptr;
16+
17+
EnterCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
18+
19+
PLIST_ENTRY lpResolver = MmpGlobalDataPtr->MmpIat->MmpIatResolverList.Flink;
20+
while (lpResolver != &MmpGlobalDataPtr->MmpIat->MmpIatResolverList) {
21+
PMM_IAT_RESOLVER entry = CONTAINING_RECORD(lpResolver, MM_IAT_RESOLVER, MM_IAT_RESOLVER::InMmpIatResolverList);
22+
23+
hModule = entry->LoadLibraryProv(lpModuleName);
24+
if (hModule) {
25+
resolver = entry;
26+
++entry->ReferenceCount;
27+
break;
28+
}
29+
30+
lpResolver = lpResolver->Flink;
31+
}
32+
33+
LeaveCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
34+
35+
*lpModuleResolver = resolver;
36+
return hModule;
37+
}
38+
39+
VOID MemoryFreeImportTable(_In_ PMEMORYMODULE hMemoryModule) {
40+
41+
EnterCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
42+
43+
PMMP_IAT_HANDLE list = (PMMP_IAT_HANDLE)hMemoryModule->hModulesList;
44+
for (DWORD i = 0; i < hMemoryModule->dwModulesCount; ++i) {
45+
auto entry = list[i];
46+
entry.lpResolver->FreeLibraryProv(entry.hModule);
47+
--entry.lpResolver->ReferenceCount;
48+
}
49+
50+
LeaveCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
51+
52+
53+
RtlFreeHeap(NtCurrentPeb()->ProcessHeap, 0, hMemoryModule->hModulesList);
54+
hMemoryModule->hModulesList = nullptr;
55+
hMemoryModule->dwModulesCount = 0;
56+
}
57+
58+
NTSTATUS MemoryResolveImportTable(
59+
_In_ LPBYTE base,
60+
_In_ PIMAGE_NT_HEADERS lpNtHeaders,
61+
_In_ PMEMORYMODULE hMemoryModule) {
62+
NTSTATUS status = STATUS_SUCCESS;
63+
PIMAGE_IMPORT_DESCRIPTOR importDesc = nullptr;
64+
DWORD count = 0;
65+
66+
do {
67+
__try {
68+
PIMAGE_DATA_DIRECTORY dir = &lpNtHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT];
69+
PIMAGE_IMPORT_DESCRIPTOR iat = nullptr;
70+
71+
if (dir && dir->Size) {
72+
iat = importDesc = PIMAGE_IMPORT_DESCRIPTOR(lpNtHeaders->OptionalHeader.ImageBase + dir->VirtualAddress);
73+
}
74+
75+
if (iat) {
76+
while (iat->Name) {
77+
++count;
78+
++iat;
79+
}
80+
}
81+
82+
if (importDesc && count) {
83+
PMMP_IAT_HANDLE handles = (PMMP_IAT_HANDLE)RtlAllocateHeap(NtCurrentPeb()->ProcessHeap, HEAP_ZERO_MEMORY, sizeof(MMP_IAT_HANDLE) * count);
84+
hMemoryModule->hModulesList = handles;
85+
if (!hMemoryModule->hModulesList) {
86+
status = STATUS_NO_MEMORY;
87+
break;
88+
}
89+
90+
for (DWORD i = 0; i < count; ++i, ++importDesc) {
91+
uintptr_t* thunkRef;
92+
FARPROC* funcRef;
93+
PMM_IAT_RESOLVER resolver;
94+
HMODULE handle = MmpLoadLibraryA((LPCSTR)(base + importDesc->Name), &resolver);
95+
96+
if (!handle) {
97+
status = STATUS_DLL_NOT_FOUND;
98+
break;
99+
}
100+
101+
handles[hMemoryModule->dwModulesCount].hModule = handle;
102+
handles[hMemoryModule->dwModulesCount++].lpResolver = resolver;
103+
thunkRef = (uintptr_t*)(base + (importDesc->OriginalFirstThunk ? importDesc->OriginalFirstThunk : importDesc->FirstThunk));
104+
funcRef = (FARPROC*)(base + importDesc->FirstThunk);
105+
while (*thunkRef) {
106+
*funcRef = GetProcAddress(
107+
handle,
108+
IMAGE_SNAP_BY_ORDINAL(*thunkRef) ? (LPCSTR)IMAGE_ORDINAL(*thunkRef) : (LPCSTR)PIMAGE_IMPORT_BY_NAME(base + (*thunkRef))->Name
109+
);
110+
if (!*funcRef) {
111+
status = STATUS_ENTRYPOINT_NOT_FOUND;
112+
break;
113+
}
114+
++thunkRef;
115+
++funcRef;
116+
}
117+
118+
if (!NT_SUCCESS(status))break;
119+
}
120+
121+
}
122+
}
123+
__except (EXCEPTION_EXECUTE_HANDLER) {
124+
status = GetExceptionCode();
125+
}
126+
} while (false);
127+
128+
if (!NT_SUCCESS(status)) {
129+
MemoryFreeImportTable(hMemoryModule);
130+
}
131+
132+
return status;
133+
}
134+
135+
HANDLE WINAPI MmRegisterImportTableResolver(
136+
_In_ MM_IAT_RESOLVER_ENTRY LoadLibraryProv,
137+
_In_ MM_IAT_FREE_ENTRY FreeLibraryProv) {
138+
139+
HANDLE heap = RtlProcessHeap();
140+
PMM_IAT_RESOLVER resolver = (PMM_IAT_RESOLVER)RtlAllocateHeap(heap, 0, sizeof(MM_IAT_RESOLVER));
141+
142+
if (resolver) {
143+
EnterCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
144+
145+
resolver->ReferenceCount = 1;
146+
resolver->LoadLibraryProv = LoadLibraryProv;
147+
resolver->FreeLibraryProv = FreeLibraryProv;
148+
InsertTailList(&MmpGlobalDataPtr->MmpIat->MmpIatResolverList, &resolver->InMmpIatResolverList);
149+
150+
LeaveCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
151+
}
152+
153+
return resolver;
154+
}
155+
156+
_Success_(return)
157+
BOOL WINAPI MmRemoveImportTableResolver(_In_ HANDLE hMmIatResolver) {
158+
159+
HANDLE heap = RtlProcessHeap();
160+
161+
if (hMmIatResolver == &MmpGlobalDataPtr->MmpIat->MmpIatResolverHead) {
162+
return FALSE;
163+
}
164+
165+
PMM_IAT_RESOLVER resolver = CONTAINING_RECORD(hMmIatResolver, MM_IAT_RESOLVER, MM_IAT_RESOLVER::InMmpIatResolverList);
166+
167+
EnterCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
168+
169+
if (resolver->ReferenceCount > 1) {
170+
LeaveCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
171+
return FALSE;
172+
}
173+
174+
RemoveHeadList(&resolver->InMmpIatResolverList);
175+
LeaveCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
176+
177+
return RtlFreeHeap(heap, 0, hMmIatResolver);
178+
}

MemoryModule/ImportTable.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#pragma once
2+
3+
typedef HMODULE(WINAPI* MM_IAT_RESOLVER_ENTRY)(LPCSTR lpModuleName);
4+
typedef BOOL(WINAPI* MM_IAT_FREE_ENTRY)(HMODULE hModule);
5+
6+
typedef struct _MM_IAT_RESOLVER {
7+
8+
LIST_ENTRY InMmpIatResolverList;
9+
10+
MM_IAT_RESOLVER_ENTRY LoadLibraryProv;
11+
MM_IAT_FREE_ENTRY FreeLibraryProv;
12+
13+
DWORD ReferenceCount;
14+
15+
}MM_IAT_RESOLVER, * PMM_IAT_RESOLVER;
16+
17+
VOID MemoryFreeImportTable(_In_ PMEMORYMODULE hMemoryModule);
18+
19+
NTSTATUS MemoryResolveImportTable(
20+
_In_ LPBYTE base,
21+
_In_ PIMAGE_NT_HEADERS lpNtHeaders,
22+
_In_ PMEMORYMODULE hMemoryModule
23+
);
24+
25+
HANDLE WINAPI MmRegisterImportTableResolver(
26+
_In_ MM_IAT_RESOLVER_ENTRY LoadLibraryProv,
27+
_In_ MM_IAT_FREE_ENTRY FreeLibraryProv
28+
);
29+
30+
_Success_(return)
31+
BOOL WINAPI MmRemoveImportTableResolver(_In_ HANDLE hMmIatResolver);

MemoryModule/Initialize.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55

66
PMMP_GLOBAL_DATA MmpGlobalDataPtr;
77

8+
#if MEMORY_MODULE_IS_PREVIEW(MEMORY_MODULE_MINOR_VERSION)
9+
#pragma message("WARNING: You are using a preview version of MemoryModulePP.")
10+
#endif
11+
812
PRTL_RB_TREE FindLdrpModuleBaseAddressIndex() {
913
PRTL_RB_TREE LdrpModuleBaseAddressIndex = nullptr;
1014
PLDR_DATA_TABLE_ENTRY_WIN10 nt10 = decltype(nt10)(MmpGlobalDataPtr->MmpBaseAddressIndex->NtdllLdrEntry);
@@ -427,8 +431,10 @@ NTSTATUS InitializeLockHeld() {
427431
status = MmpAllocateGlobalData();
428432
if (!NT_SUCCESS(status)) {
429433
if (status == STATUS_ALREADY_INITIALIZED) {
430-
if ((MmpGlobalDataPtr->MajorVersion < MEMORY_MODULE_MAJOR_VERSION) ||
431-
(MmpGlobalDataPtr->MajorVersion == MEMORY_MODULE_MAJOR_VERSION && MmpGlobalDataPtr->MinorVersion < MEMORY_MODULE_MINOR_VERSION)) {
434+
if ((MmpGlobalDataPtr->MajorVersion != MEMORY_MODULE_MAJOR_VERSION) ||
435+
MEMORY_MODULE_IS_PREVIEW(MmpGlobalDataPtr->MinorVersion) != MEMORY_MODULE_IS_PREVIEW(MEMORY_MODULE_MINOR_VERSION) ||
436+
(MEMORY_MODULE_IS_PREVIEW(MEMORY_MODULE_MINOR_VERSION) ? MmpGlobalDataPtr->MinorVersion != MEMORY_MODULE_MINOR_VERSION :
437+
MmpGlobalDataPtr->MinorVersion < MEMORY_MODULE_MINOR_VERSION)) {
432438
status = STATUS_NOT_SUPPORTED;
433439
}
434440
else {
@@ -458,6 +464,7 @@ NTSTATUS InitializeLockHeld() {
458464
MmpGlobalDataPtr->MmpTls = (PMMP_TLS_DATA)((LPBYTE)MmpGlobalDataPtr->MmpLdrEntry + sizeof(MMP_LDR_ENTRY_DATA));
459465
MmpGlobalDataPtr->MmpDotNet = (PMMP_DOT_NET_DATA)((LPBYTE)MmpGlobalDataPtr->MmpTls + sizeof(MMP_TLS_DATA));
460466
MmpGlobalDataPtr->MmpFunctions = (PMMP_FUNCTIONS)((LPBYTE)MmpGlobalDataPtr->MmpDotNet + sizeof(MMP_DOT_NET_DATA));
467+
MmpGlobalDataPtr->MmpIat = (PMMP_IAT_DATA)((LPBYTE)MmpGlobalDataPtr->MmpFunctions + sizeof(MMP_FUNCTIONS));
461468

462469
PLDR_DATA_TABLE_ENTRY pNtdllEntry = RtlFindLdrTableEntryByBaseName(L"ntdll.dll");
463470
MmpGlobalDataPtr->MmpBaseAddressIndex->NtdllLdrEntry = pNtdllEntry;
@@ -480,6 +487,14 @@ NTSTATUS InitializeLockHeld() {
480487
MmpGlobalDataPtr->MmpFunctions->_MmpHandleTlsData = MmpHandleTlsData;
481488
MmpGlobalDataPtr->MmpFunctions->_MmpReleaseTlsEntry = MmpReleaseTlsEntry;
482489

490+
InitializeCriticalSection(&MmpGlobalDataPtr->MmpIat->MmpIatResolverListLock);
491+
InitializeListHead(&MmpGlobalDataPtr->MmpIat->MmpIatResolverList);
492+
InitializeListHead(&MmpGlobalDataPtr->MmpIat->MmpIatResolverHead.InMmpIatResolverList);
493+
MmpGlobalDataPtr->MmpIat->MmpIatResolverHead.LoadLibraryProv = LoadLibraryA;
494+
MmpGlobalDataPtr->MmpIat->MmpIatResolverHead.FreeLibraryProv = FreeLibrary;
495+
MmpGlobalDataPtr->MmpIat->MmpIatResolverHead.ReferenceCount = 1;
496+
InsertTailList(&MmpGlobalDataPtr->MmpIat->MmpIatResolverList, &MmpGlobalDataPtr->MmpIat->MmpIatResolverHead.InMmpIatResolverList);
497+
483498
MmpTlsInitialize();
484499

485500
MmpGlobalDataPtr->MmpDotNet->Initialized = MmpGlobalDataPtr->MmpDotNet->PreHooked = FALSE;

MemoryModule/LdrEntry.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ BOOL NTAPI RtlInitializeLdrDataTableEntry(
151151
entry->DdagNode->LoadCount = 1;
152152
if (IsWin8) ((_LDR_DDAG_NODE_WIN8*)(entry->DdagNode))->ReferenceCount = 1;
153153
entry->ImageDll = entry->LoadNotificationsSent = entry->EntryProcessed =
154-
entry->InLegacyLists = entry->InIndexes = entry->ProcessAttachCalled = true;
154+
entry->InLegacyLists = entry->InIndexes = true;
155+
entry->ProcessAttachCalled = headers->OptionalHeader.AddressOfEntryPoint != 0;
155156
entry->InExceptionTable = !(dwFlags & LOAD_FLAGS_NOT_ADD_INVERTED_FUNCTION);
156157
entry->CorImage = CorImage;
157158
entry->CorILOnly = CorIL;
@@ -184,7 +185,10 @@ BOOL NTAPI RtlInitializeLdrDataTableEntry(
184185
LdrEntry->EntryPoint = (PLDR_INIT_ROUTINE)((size_t)BaseAddress + headers->OptionalHeader.AddressOfEntryPoint);
185186
LdrEntry->ObsoleteLoadCount = 1;
186187
if (!FlagsProcessed) {
187-
LdrEntry->Flags = LDRP_IMAGE_DLL | LDRP_ENTRY_INSERTED | LDRP_ENTRY_PROCESSED | LDRP_PROCESS_ATTACH_CALLED;
188+
LdrEntry->Flags = LDRP_IMAGE_DLL | LDRP_ENTRY_INSERTED | LDRP_ENTRY_PROCESSED;
189+
190+
if (headers->OptionalHeader.AddressOfEntryPoint != 0)LdrEntry->Flags |= LDRP_PROCESS_ATTACH_CALLED;
191+
188192
if (CorImage)LdrEntry->Flags |= LDRP_COR_IMAGE;
189193
}
190194
InitializeListHead(&LdrEntry->HashLinks);

MemoryModule/Loader.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "stdafx.h"
2+
#include <cmath>
23

34
NTSTATUS NTAPI LdrMapDllMemory(
45
_In_ HMEMORYMODULE ViewBase,
@@ -64,6 +65,7 @@ NTSTATUS NTAPI LdrLoadDllMemoryExW(
6465
if (dwFlags & LOAD_FLAGS_USE_DLL_NAME && (!DllName || !DllFullName))return STATUS_INVALID_PARAMETER_3;
6566

6667
if (DllName) {
68+
int length = (int)wcslen(DllName);
6769
PLIST_ENTRY ListHead = &NtCurrentPeb()->Ldr->InLoadOrderModuleList, ListEntry = ListHead->Flink;
6870
PIMAGE_NT_HEADERS h1 = RtlImageNtHeader(BufferAddress), h2 = nullptr;
6971
if (!h1)return STATUS_INVALID_IMAGE_FORMAT;
@@ -74,11 +76,19 @@ NTSTATUS NTAPI LdrLoadDllMemoryExW(
7476

7577
/* Check if it's being unloaded */
7678
if (!CurEntry->InMemoryOrderLinks.Flink) continue;
79+
80+
auto dist = (CurEntry->BaseDllName.Length / sizeof(wchar_t)) - length;
81+
bool equal = false;
82+
if (dist == 0 || dist == 4) {
83+
equal = !_wcsnicmp(DllName, CurEntry->BaseDllName.Buffer, length);
84+
}
85+
else {
86+
continue;
87+
}
7788

7889
/* Check if name matches */
79-
if (!_wcsnicmp(DllName, CurEntry->BaseDllName.Buffer, (CurEntry->BaseDllName.Length / sizeof(wchar_t)) - 4) ||
80-
!_wcsnicmp(DllName, CurEntry->BaseDllName.Buffer, CurEntry->BaseDllName.Length / sizeof(wchar_t))) {
81-
90+
if (equal) {
91+
8292
/* Let's compare their headers */
8393
if (!(h2 = RtlImageNtHeader(CurEntry->DllBase)))continue;
8494
if (!(module = MapMemoryModuleHandle((HMEMORYMODULE)CurEntry->DllBase)))continue;

0 commit comments

Comments
 (0)