Skip to content

Commit 8d1454f

Browse files
committed
MmpPostponedTls shuts down gracefully.
1 parent 41e666e commit 8d1454f

2 files changed

Lines changed: 110 additions & 68 deletions

File tree

MemoryModule/MmpTlsFiber.cpp

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,42 +27,48 @@ DWORD WINAPI MmpReleasePostponedTlsWorker(PVOID) {
2727

2828
EnterCriticalSection(&MmpPostponedTlsLock);
2929

30-
auto iter = MmpPostponedTlsList->begin();
30+
if (MmpPostponedTlsList) {
31+
auto iter = MmpPostponedTlsList->begin();
3132

32-
while (iter != MmpPostponedTlsList->end()) {
33-
const auto& item = *iter;
34-
GetExitCodeThread(item.hThread, &code);
33+
while (iter != MmpPostponedTlsList->end()) {
34+
const auto& item = *iter;
35+
GetExitCodeThread(item.hThread, &code);
3536

36-
if (code == STILL_ACTIVE) {
37-
++iter;
38-
}
39-
else {
37+
if (code == STILL_ACTIVE) {
38+
++iter;
39+
}
40+
else {
4041

41-
RtlAcquireSRWLockExclusive(&MmpGlobalDataPtr->MmpTls->MmpTlsListLock);
42+
RtlAcquireSRWLockExclusive(&MmpGlobalDataPtr->MmpTls->MmpTlsListLock);
4243

43-
auto TlspMmpBlock = (PVOID*)item.lpOldTlsVector->ModuleTlsData;
44-
auto entry = MmpGlobalDataPtr->MmpTls->MmpTlsList.Flink;
45-
while (entry != &MmpGlobalDataPtr->MmpTls->MmpTlsList) {
44+
auto TlspMmpBlock = (PVOID*)item.lpOldTlsVector->ModuleTlsData;
45+
auto entry = MmpGlobalDataPtr->MmpTls->MmpTlsList.Flink;
46+
while (entry != &MmpGlobalDataPtr->MmpTls->MmpTlsList) {
4647

47-
auto p = CONTAINING_RECORD(entry, TLS_ENTRY, TlsEntryLinks);
48-
RtlFreeHeap(RtlProcessHeap(), 0, TlspMmpBlock[p->TlsDirectory.Characteristics]);
48+
auto p = CONTAINING_RECORD(entry, TLS_ENTRY, TlsEntryLinks);
49+
RtlFreeHeap(RtlProcessHeap(), 0, TlspMmpBlock[p->TlsDirectory.Characteristics]);
4950

50-
entry = entry->Flink;
51-
}
51+
entry = entry->Flink;
52+
}
5253

53-
RtlFreeHeap(RtlProcessHeap(), 0, CONTAINING_RECORD(item.lpTlsRecord->TlspLdrBlock, TLS_VECTOR, TLS_VECTOR::ModuleTlsData));
54-
RtlFreeHeap(RtlProcessHeap(), 0, item.lpTlsRecord);
55-
RtlFreeHeap(RtlProcessHeap(), 0, item.lpOldTlsVector);
54+
RtlFreeHeap(RtlProcessHeap(), 0, CONTAINING_RECORD(item.lpTlsRecord->TlspLdrBlock, TLS_VECTOR, TLS_VECTOR::ModuleTlsData));
55+
RtlFreeHeap(RtlProcessHeap(), 0, item.lpTlsRecord);
56+
RtlFreeHeap(RtlProcessHeap(), 0, item.lpOldTlsVector);
5657

57-
RtlReleaseSRWLockExclusive(&MmpGlobalDataPtr->MmpTls->MmpTlsListLock);
58+
RtlReleaseSRWLockExclusive(&MmpGlobalDataPtr->MmpTls->MmpTlsListLock);
59+
60+
CloseHandle(item.hThread);
61+
iter = MmpPostponedTlsList->erase(iter);
62+
}
5863

59-
CloseHandle(item.hThread);
60-
iter = MmpPostponedTlsList->erase(iter);
6164
}
6265

66+
waitTime = MmpPostponedTlsList->empty() ? INFINITE : 1000;
67+
}
68+
else {
69+
LeaveCriticalSection(&MmpPostponedTlsLock);
70+
break;
6371
}
64-
65-
waitTime = MmpPostponedTlsList->empty() ? INFINITE : 1000;
6672

6773
LeaveCriticalSection(&MmpPostponedTlsLock);
6874
}
@@ -100,16 +106,35 @@ VOID WINAPI MmpQueuePostponedTls(PMMP_TLSP_RECORD record) {
100106

101107
EnterCriticalSection(&MmpPostponedTlsLock);
102108

103-
MmpPostponedTlsList->push_back(item);
104-
SetEvent(MmpPostponedTlsEvent);
109+
if (MmpPostponedTlsList) {
110+
MmpPostponedTlsList->push_back(item);
111+
SetEvent(MmpPostponedTlsEvent);
112+
}
113+
else {
114+
RtlFreeHeap(RtlProcessHeap(), 0, item.lpOldTlsVector);
115+
}
105116

106117
LeaveCriticalSection(&MmpPostponedTlsLock);
107118
}
108119

120+
VOID OnExit() {
121+
EnterCriticalSection(&MmpPostponedTlsLock);
122+
123+
MmpPostponedTlsList->~vector();
124+
HeapFree(RtlProcessHeap(), 0, MmpPostponedTlsList);
125+
MmpPostponedTlsList = nullptr;
126+
127+
CloseHandle(MmpPostponedTlsEvent);
128+
129+
LeaveCriticalSection(&MmpPostponedTlsLock);
130+
}
131+
109132
VOID MmpTlsFiberInitialize() {
110133
InitializeCriticalSection(&MmpPostponedTlsLock);
111134
MmpPostponedTlsEvent = CreateEvent(nullptr, FALSE, FALSE, nullptr);
112135
MmpPostponedTlsList = new(HeapAlloc(GetProcessHeap(), 0, sizeof(std::vector<MMP_POSTPONED_TLS>))) std::vector<MMP_POSTPONED_TLS>();
136+
137+
atexit(OnExit);
113138

114139
CreateThread(nullptr, 0, MmpReleasePostponedTlsWorker_Wrap, nullptr, 0, nullptr);
115140
}

test/test.cpp

Lines changed: 59 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -62,62 +62,79 @@ PVOID ReadDllFile2(LPCSTR FileName) {
6262
return nullptr;
6363
}
6464

65-
#define LIBRARY_PATH "\\\\DESKTOP-1145141919810\\Debug\\"
65+
int test() {
66+
LPVOID buffer = ReadDllFile2("a.vmp.dll");
67+
68+
HMODULE hModule = nullptr;
69+
FARPROC pfn = nullptr;
70+
DWORD MemoryModuleFeatures = 0;
71+
72+
typedef int(*_exception)(int code);
73+
_exception exception = nullptr;
74+
HRSRC hRsrc;
75+
DWORD SizeofRes;
76+
HGLOBAL gRes;
77+
char str[10];
78+
79+
LdrQuerySystemMemoryModuleFeatures(&MemoryModuleFeatures);
80+
if (MemoryModuleFeatures != MEMORY_FEATURE_ALL) {
81+
printf("not support all features on this version of windows.\n");
82+
}
83+
84+
if (!NT_SUCCESS(LdrLoadDllMemoryExW(&hModule, nullptr, 0, buffer, 0, L"kernel64", nullptr))) goto end;
6685

67-
HMODULE WINAPI MyLoadLibrary(LPCSTR lpModuleName) {
68-
HMODULE hModule;
69-
PVOID buffer;
86+
//forward export
87+
pfn = (decltype(pfn))(GetProcAddress(hModule, "Socket")); //ws2_32.WSASocketW
88+
pfn = (decltype(pfn))(GetProcAddress(hModule, "VerifyTruse")); //wintrust.WinVerifyTrust
7089

71-
if (0 == _stricmp(lpModuleName, "CTestClassLibrary1.dll")) {
72-
buffer = ReadDllFile(LIBRARY_PATH"CTestClassLibrary1.dll");
90+
//exception
91+
exception = (_exception)GetProcAddress(hModule, "exception");
92+
if (exception) {
93+
for (int i = 0; i < 5; ++i)exception(i);
7394
}
74-
else if (0 == _stricmp(lpModuleName, "CTestClassLibrary2.dll")) {
75-
buffer = ReadDllFile(LIBRARY_PATH"CTestClassLibrary2.dll");
95+
96+
//tls
97+
pfn = GetProcAddress(hModule, "thread");
98+
if (pfn && pfn()) {
99+
printf("thread test failed.\n");
76100
}
77-
else if (0 == _stricmp(lpModuleName, "CTestClassLibrary1Dep.dll")) {
78-
buffer = ReadDllFile(LIBRARY_PATH"CTestClassLibrary1Dep.dll");
101+
102+
//resource
103+
if (!LoadStringA(hModule, 101, str, 10)) {
104+
printf("load string failed.\n");
79105
}
80106
else {
81-
return nullptr;
107+
printf("%s\n", str);
82108
}
83-
84-
hModule = LoadLibraryMemoryExA(buffer, 0, lpModuleName, nullptr, 0);
85-
delete[]buffer;
86-
return hModule;
87-
}
88-
89-
VOID TestImportTableResolver() {
90-
91-
//
92-
// Register the import table resolver.
93-
//
94-
HANDLE hResolver = MmRegisterImportTableResolver(MyLoadLibrary, FreeLibraryMemory);
95-
96-
//
97-
// |-> CTestClassLibrary1.dll -> CTestClassLibrary1Dep.dll
98-
// CTestClient.dll -|
99-
// |-> CTestClassLibrary2.dll
100-
//
101-
102-
PVOID Client = ReadDllFile2("CTestClient.dll");
103-
HMODULE hm = LoadLibraryMemoryEx(Client, 0, TEXT("CTestClient.dll"), nullptr, 0);
104-
delete[]Client;
105-
106-
if (hm) {
107-
auto pfn = GetProcAddress(hm, "TestProc");
108-
if (pfn) {
109-
pfn();
109+
if (!(hRsrc = FindResourceA(hModule, MAKEINTRESOURCEA(102), "BINARY"))) {
110+
printf("find binary resource failed.\n");
111+
}
112+
else {
113+
if ((SizeofRes = SizeofResource(hModule, hRsrc)) != 0x10) {
114+
printf("invalid res size.\n");
115+
}
116+
else {
117+
if (!(gRes = LoadResource(hModule, hRsrc))) {
118+
printf("load res failed.\n");
119+
}
120+
else {
121+
if (!LockResource(gRes))printf("lock res failed.\n");
122+
else {
123+
printf("resource test success.\n");
124+
}
125+
}
110126
}
111-
112-
FreeLibraryMemory(hm);
113127
}
114128

115-
MmRemoveImportTableResolver(hResolver);
129+
end:
130+
LdrUnloadDllMemory(hModule);
131+
delete[]buffer;
132+
return 0;
116133
}
117134

118135
int main() {
119136
DisplayStatus();
120-
TestImportTableResolver();
137+
test();
121138

122139
return 0;
123140
}

0 commit comments

Comments
 (0)