Skip to content

Commit 58379ac

Browse files
committed
Add support for thread pool(ntdll!TppWorkerThread)
1 parent a32f925 commit 58379ac

4 files changed

Lines changed: 106 additions & 118 deletions

File tree

MemoryModule/Loader.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,26 +64,30 @@ NTSTATUS NTAPI LdrLoadDllMemoryExW(
6464
if (dwFlags & LOAD_FLAGS_USE_DLL_NAME && (!DllName || !DllFullName))return STATUS_INVALID_PARAMETER_3;
6565

6666
if (DllName) {
67-
PLIST_ENTRY ListHead, ListEntry;
68-
PLDR_DATA_TABLE_ENTRY CurEntry;
67+
PLIST_ENTRY ListHead = &NtCurrentPeb()->Ldr->InLoadOrderModuleList, ListEntry = ListHead->Flink;
6968
PIMAGE_NT_HEADERS h1 = RtlImageNtHeader(BufferAddress), h2 = nullptr;
7069
if (!h1)return STATUS_INVALID_IMAGE_FORMAT;
71-
ListEntry = (ListHead = &NtCurrentPeb()->Ldr->InLoadOrderModuleList)->Flink;
70+
7271
while (ListEntry != ListHead) {
73-
CurEntry = CONTAINING_RECORD(ListEntry, LDR_DATA_TABLE_ENTRY, InLoadOrderLinks);
72+
PLDR_DATA_TABLE_ENTRY CurEntry = CONTAINING_RECORD(ListEntry, LDR_DATA_TABLE_ENTRY, InLoadOrderLinks);
7473
ListEntry = ListEntry->Flink;
74+
7575
/* Check if it's being unloaded */
7676
if (!CurEntry->InMemoryOrderLinks.Flink) continue;
77+
7778
/* Check if name matches */
7879
if (!_wcsnicmp(DllName, CurEntry->BaseDllName.Buffer, (CurEntry->BaseDllName.Length / sizeof(wchar_t)) - 4) ||
7980
!_wcsnicmp(DllName, CurEntry->BaseDllName.Buffer, CurEntry->BaseDllName.Length / sizeof(wchar_t))) {
81+
8082
/* Let's compare their headers */
8183
if (!(h2 = RtlImageNtHeader(CurEntry->DllBase)))continue;
8284
if (!(module = MapMemoryModuleHandle((HMEMORYMODULE)CurEntry->DllBase)))continue;
8385
if ((h1->OptionalHeader.SizeOfCode == h2->OptionalHeader.SizeOfCode) &&
8486
(h1->OptionalHeader.SizeOfHeaders == h2->OptionalHeader.SizeOfHeaders)) {
87+
8588
/* This is our entry!, update load count and return success */
8689
if (!module->UseReferenceCount || dwFlags & LOAD_FLAGS_NOT_USE_REFERENCE_COUNT)return STATUS_INVALID_PARAMETER_3;
90+
8791
RtlUpdateReferenceCount(module, FLAG_REFERENCE);
8892
*BaseAddress = (HMEMORYMODULE)CurEntry->DllBase;
8993
if (LdrEntry)*LdrEntry = CurEntry;

MemoryModule/MmpGlobalData.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ typedef struct _MMP_TLS_DATA {
2929
DWORD MmpActiveThreadCount;
3030

3131
struct {
32-
decltype(&NtCreateThread) OriginNtCreateThread;
33-
decltype(&NtCreateThreadEx) OriginNtCreateThreadEx;
32+
PVOID HookReserved1;
33+
PVOID HookReserved2;
3434
decltype(&NtSetInformationProcess) OriginNtSetInformationProcess;
3535
decltype(&LdrShutdownThread) OriginLdrShutdownThread;
36+
decltype(&RtlUserThreadStart) OriginRtlUserThreadStart;
3637
}Hooks;
3738
}MMP_TLS_DATA, * PMMP_TLS_DATA;
3839

@@ -77,7 +78,7 @@ typedef enum class _WINDOWS_VERSION :BYTE {
7778
}WINDOWS_VERSION;
7879

7980
#define MEMORY_MODULE_MAJOR_VERSION 1
80-
#define MEMORY_MODULE_MINOR_VERSION 1
81+
#define MEMORY_MODULE_MINOR_VERSION 2
8182

8283
typedef struct _MMP_GLOBAL_DATA {
8384

MemoryModule/MmpTls.cpp

Lines changed: 72 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,37 @@ DWORD NTAPI MmpGetThreadCount() {
114114
while (true) {
115115

116116
if (p->UniqueProcessId == pid) {
117-
result = p->NumberOfThreads;
117+
OBJECT_ATTRIBUTES oa{};
118+
InitializeObjectAttributes(&oa, nullptr, 0, nullptr, nullptr);
119+
120+
THREAD_BASIC_INFORMATION tbi{};
121+
122+
NTSTATUS status;
123+
for (ULONG i = 0; i < p->NumberOfThreads; ++i) {
124+
HANDLE hThread;
125+
status = NtOpenThread(
126+
&hThread,
127+
THREAD_QUERY_INFORMATION,
128+
&oa,
129+
&p->Threads[i].ClientId
130+
);
131+
132+
if (NT_SUCCESS(status)) {
133+
status = NtQueryInformationThread(
134+
hThread,
135+
ThreadBasicInformation,
136+
&tbi,
137+
sizeof(tbi),
138+
nullptr
139+
);
140+
if (NT_SUCCESS(status) && !!tbi.TebBaseAddress->ThreadLocalStoragePointer) {
141+
++result;
142+
}
143+
144+
NtClose(hThread);
145+
}
146+
}
147+
118148
break;
119149
}
120150

@@ -128,6 +158,22 @@ DWORD NTAPI MmpGetThreadCount() {
128158
return result;
129159
}
130160

161+
PMMP_TLSP_RECORD MmpFindTlspRecordLockHeld() {
162+
PLIST_ENTRY entry = MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer.Flink;
163+
while (entry != &MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer) {
164+
165+
auto p = CONTAINING_RECORD(entry, MMP_TLSP_RECORD, InMmpThreadLocalStoragePointer);
166+
if (p->UniqueThread == NtCurrentThreadId()) {
167+
assert(p->TlspMmpBlock == NtCurrentTeb()->ThreadLocalStoragePointer);
168+
return p;
169+
}
170+
171+
entry = entry->Flink;
172+
}
173+
174+
return nullptr;
175+
}
176+
131177
DWORD NTAPI MmpUserThreadStart(LPVOID lpThreadParameter) {
132178

133179
THREAD_CONTEXT Context;
@@ -140,8 +186,6 @@ DWORD NTAPI MmpUserThreadStart(LPVOID lpThreadParameter) {
140186
lpThreadParameter,
141187
sizeof(Context)
142188
);
143-
144-
RtlFreeHeap(RtlProcessHeap(), 0, lpThreadParameter);
145189
}
146190
__except (EXCEPTION_EXECUTE_HANDLER) {
147191
return GetExceptionCode();
@@ -151,6 +195,15 @@ DWORD NTAPI MmpUserThreadStart(LPVOID lpThreadParameter) {
151195
goto __skip_tls;
152196
}
153197

198+
//
199+
// Check if we have already initialized
200+
//
201+
EnterCriticalSection(&MmpGlobalDataPtr->MmpTls->MmpTlspLock);
202+
record = MmpFindTlspRecordLockHeld();
203+
LeaveCriticalSection(&MmpGlobalDataPtr->MmpTls->MmpTlspLock);
204+
205+
if (!!record)goto __skip_tls;
206+
154207
//
155208
// Allocate and replace ThreadLocalStoragePointer for new thread
156209
//
@@ -162,7 +215,6 @@ DWORD NTAPI MmpUserThreadStart(LPVOID lpThreadParameter) {
162215
record->TlspMmpBlock = (PVOID*)MmpAllocateTlsp();
163216
record->UniqueThread = NtCurrentThreadId();
164217
if (record->TlspMmpBlock) {
165-
166218
auto size = CONTAINING_RECORD(record->TlspLdrBlock, TLS_VECTOR, ModuleTlsData)->Length;
167219
if ((HANDLE)(ULONG_PTR)size != NtCurrentThreadId()) {
168220
RtlCopyMemory(
@@ -230,91 +282,14 @@ DWORD NTAPI MmpUserThreadStart(LPVOID lpThreadParameter) {
230282
return Context.ThreadStartRoutine(Context.ThreadParameter);
231283
}
232284

233-
NTSTATUS NTAPI HookNtCreateThread(
234-
_Out_ PHANDLE ThreadHandle,
235-
_In_ ACCESS_MASK DesiredAccess,
236-
_In_opt_ POBJECT_ATTRIBUTES ObjectAttributes,
237-
_In_ HANDLE ProcessHandle,
238-
_Out_ PCLIENT_ID ClientId,
239-
_In_ PCONTEXT ThreadContext,
240-
_In_ PVOID InitialTeb,
241-
_In_ BOOLEAN CreateSuspended) {
242-
CONTEXT Context = *ThreadContext;
243-
PTHREAD_CONTEXT _Context = PTHREAD_CONTEXT(RtlAllocateHeap(RtlProcessHeap(), 0, sizeof(*_Context)));
244-
NTSTATUS status;
245-
246-
if (!_Context)return STATUS_NO_MEMORY;
247-
248-
#ifndef _WIN64
249-
_Context->ThreadStartRoutine = PTHREAD_START_ROUTINE(Context.Eax);
250-
_Context->ThreadParameter = LPVOID(Context.Ebx);
251-
252-
Context.Eax = DWORD(MmpUserThreadStart);
253-
Context.Ebx = DWORD(_Context);
254-
255-
#else
256-
_Context->ThreadStartRoutine = PTHREAD_START_ROUTINE(Context.Rcx);
257-
_Context->ThreadParameter = LPVOID(Context.Rdx);
258-
259-
Context.Rcx = ULONG64(MmpUserThreadStart);
260-
Context.Rdx = ULONG64(_Context);
261-
#endif
262-
263-
status = MmpGlobalDataPtr->MmpTls->Hooks.OriginNtCreateThread(
264-
ThreadHandle,
265-
DesiredAccess,
266-
ObjectAttributes,
267-
ProcessHandle,
268-
ClientId,
269-
&Context,
270-
(PINITIAL_TEB)InitialTeb,
271-
CreateSuspended
272-
);
273-
if (!NT_SUCCESS(status)) {
274-
RtlFreeHeap(RtlProcessHeap(), 0, _Context);
275-
}
276-
277-
return status;
278-
}
279-
280-
NTSTATUS NTAPI HookNtCreateThreadEx(
281-
_Out_ PHANDLE ThreadHandle,
282-
_In_ ACCESS_MASK DesiredAccess,
283-
_In_opt_ POBJECT_ATTRIBUTES ObjectAttributes,
284-
_In_ HANDLE ProcessHandle,
285-
_In_ PVOID StartRoutine,
286-
_In_opt_ PVOID Argument,
287-
_In_ ULONG CreateFlags,
288-
_In_ SIZE_T ZeroBits,
289-
_In_ SIZE_T StackSize,
290-
_In_ SIZE_T MaximumStackSize,
291-
_In_opt_ PVOID AttributeList) {
292-
PTHREAD_CONTEXT Context = PTHREAD_CONTEXT(RtlAllocateHeap(RtlProcessHeap(), 0, sizeof(*Context)));
293-
if (!Context) {
294-
return STATUS_NO_MEMORY;
295-
}
296-
297-
Context->ThreadStartRoutine = PTHREAD_START_ROUTINE(StartRoutine);
298-
Context->ThreadParameter = Argument;
299-
300-
NTSTATUS status = MmpGlobalDataPtr->MmpTls->Hooks.OriginNtCreateThreadEx(
301-
ThreadHandle,
302-
DesiredAccess,
303-
ObjectAttributes,
304-
ProcessHandle,
305-
MmpUserThreadStart,
306-
Context,
307-
CreateFlags,
308-
ZeroBits,
309-
StackSize,
310-
MaximumStackSize,
311-
(PPS_ATTRIBUTE_LIST)AttributeList
312-
);
313-
if (!NT_SUCCESS(status)) {
314-
RtlFreeHeap(RtlProcessHeap(), 0, Context);
315-
}
285+
VOID NTAPI HookRtlUserThreadStart(
286+
_In_ PTHREAD_START_ROUTINE Function,
287+
_In_ PVOID Parameter) {
288+
THREAD_CONTEXT Context;
289+
Context.ThreadStartRoutine = PTHREAD_START_ROUTINE(Function);
290+
Context.ThreadParameter = Parameter;
316291

317-
return status;
292+
return MmpGlobalDataPtr->MmpTls->Hooks.OriginRtlUserThreadStart(MmpUserThreadStart, &Context);
318293
}
319294

320295
VOID NTAPI HookLdrShutdownThread(VOID) {
@@ -327,27 +302,16 @@ VOID NTAPI HookLdrShutdownThread(VOID) {
327302
//
328303
EnterCriticalSection(&MmpGlobalDataPtr->MmpTls->MmpTlspLock);
329304

330-
entry = MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer.Flink;
331-
while (entry != &MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer) {
332-
333-
auto p = CONTAINING_RECORD(entry, MMP_TLSP_RECORD, InMmpThreadLocalStoragePointer);
334-
if (p->UniqueThread == NtCurrentThreadId()) {
335-
assert(p->TlspMmpBlock == NtCurrentTeb()->ThreadLocalStoragePointer);
336-
337-
//
338-
// Restore tlsp
339-
//
340-
NtCurrentTeb()->ThreadLocalStoragePointer = p->TlspLdrBlock;
305+
record = MmpFindTlspRecordLockHeld();
306+
if (record) {
341307

342-
RemoveEntryList(&p->InMmpThreadLocalStoragePointer);
343-
record = p;
344-
break;
345-
}
308+
//
309+
// Restore tlsp
310+
//
346311

347-
entry = entry->Flink;
348-
}
312+
NtCurrentTeb()->ThreadLocalStoragePointer = record->TlspLdrBlock;
313+
RemoveEntryList(&record->InMmpThreadLocalStoragePointer);
349314

350-
if (record) {
351315
--MmpGlobalDataPtr->MmpTls->MmpActiveThreadCount;
352316
}
353317

@@ -829,17 +793,15 @@ BOOL NTAPI MmpTlsInitialize() {
829793
// Hook functions
830794
//
831795

832-
MmpGlobalDataPtr->MmpTls->Hooks.OriginNtCreateThread = NtCreateThread;
833-
MmpGlobalDataPtr->MmpTls->Hooks.OriginNtCreateThreadEx = NtCreateThreadEx;
834796
MmpGlobalDataPtr->MmpTls->Hooks.OriginLdrShutdownThread = LdrShutdownThread;
835797
MmpGlobalDataPtr->MmpTls->Hooks.OriginNtSetInformationProcess = NtSetInformationProcess;
798+
MmpGlobalDataPtr->MmpTls->Hooks.OriginRtlUserThreadStart = (decltype(&RtlUserThreadStart))GetProcAddress((HMODULE)MmpGlobalDataPtr->MmpBaseAddressIndex->NtdllLdrEntry->DllBase, "RtlUserThreadStart");
836799

837800
DetourTransactionBegin();
838801
DetourUpdateThread(NtCurrentThread());
839-
DetourAttach((PVOID*)&MmpGlobalDataPtr->MmpTls->Hooks.OriginNtCreateThread, HookNtCreateThread);
840-
DetourAttach((PVOID*)&MmpGlobalDataPtr->MmpTls->Hooks.OriginNtCreateThreadEx, HookNtCreateThreadEx);
841802
DetourAttach((PVOID*)&MmpGlobalDataPtr->MmpTls->Hooks.OriginLdrShutdownThread, HookLdrShutdownThread);
842803
DetourAttach((PVOID*)&MmpGlobalDataPtr->MmpTls->Hooks.OriginNtSetInformationProcess, HookNtSetInformationProcess);
804+
DetourAttach((PVOID*)&MmpGlobalDataPtr->MmpTls->Hooks.OriginRtlUserThreadStart, HookRtlUserThreadStart);
843805
DetourTransactionCommit();
844806

845807
return TRUE;

test/test.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#include "../MemoryModule/LoadDllMemoryApi.h"
33
#include <cstdio>
44

5+
//PMMP_GLOBAL_DATA MmpGlobalDataPtr = *(PMMP_GLOBAL_DATA*)GetProcAddress(GetModuleHandleA("MemoryModule.dll"), "MmpGlobalDataPtr");
6+
57
static PVOID ReadDllFile(LPCSTR FileName) {
68
LPVOID buffer;
79
size_t size;
@@ -130,8 +132,27 @@ void test_uef() {
130132
return;
131133
}
132134

135+
void Tp() {
136+
auto pool = CreateThreadpool(nullptr);
137+
if (pool) {
138+
139+
SetThreadpoolThreadMaximum(pool, 1);
140+
SetThreadpoolThreadMinimum(pool, 1);
141+
142+
Sleep(1000);
143+
144+
CloseThreadpool(pool);
145+
}
146+
}
147+
133148
int main() {
134-
test_uef();
149+
150+
DisplayStatus();
151+
test();
152+
153+
Tp();
154+
155+
WaitForSingleObject(NtCurrentProcess(), INFINITE);
135156

136157
return 0;
137158
}

0 commit comments

Comments
 (0)