Skip to content

Commit c98558e

Browse files
authored
Use COM marshalling to exchange handles (#40056)
* Prototype * Save state * Save state * Save state * Save state * Save state * Save state * Remove zeroing * Apply PR feedback * Add test coverage * Format * Add test coverage for null handles
1 parent 51dc5f8 commit c98558e

16 files changed

Lines changed: 390 additions & 199 deletions

src/windows/WslcSDK/IOCallback.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ bool IOCallback::HasIOCallback(const WslcContainerProcessIOCallbackOptions& opti
104104

105105
wil::unique_handle IOCallback::GetIOHandle(IWSLCProcess* process, WslcProcessIOHandle ioHandle)
106106
{
107-
ULONG ulongHandle = 0;
107+
wsl::windows::common::wslutil::COMOutputHandle handle;
108108

109-
THROW_IF_FAILED(process->GetStdHandle(static_cast<ULONG>(static_cast<std::underlying_type_t<WslcProcessIOHandle>>(ioHandle)), &ulongHandle));
109+
THROW_IF_FAILED(process->GetStdHandle(static_cast<WSLCFD>(static_cast<std::underlying_type_t<WslcProcessIOHandle>>(ioHandle)), &handle));
110110

111-
return wil::unique_handle{ULongToHandle(ulongHandle)};
111+
return handle.Release();
112112
}

src/windows/WslcSDK/wslcsdk.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Module Name:
2020
#include "wslutil.h"
2121

2222
using namespace std::string_view_literals;
23+
using namespace wsl::windows::common::wslutil;
2324

2425
namespace {
2526
constexpr uint32_t s_DefaultCPUCount = 2;
@@ -1129,7 +1130,7 @@ static HRESULT WslcImportSessionImageImpl(
11291130
auto progressCallback = ProgressCallback::CreateIf(options);
11301131

11311132
return errorInfoWrapper.CaptureResult(internalSession->session->ImportImage(
1132-
HandleToULong(imageFile.Handle()), imageName, progressCallback.get(), imageFile.Length()));
1133+
ToCOMInputHandle(imageFile.Handle()), imageName, progressCallback.get(), imageFile.Length()));
11331134
}
11341135

11351136
STDAPI WslcImportSessionImage(
@@ -1167,7 +1168,7 @@ static HRESULT WslcLoadSessionImageImpl(
11671168
auto progressCallback = ProgressCallback::CreateIf(options);
11681169

11691170
return errorInfoWrapper.CaptureResult(
1170-
internalSession->session->LoadImage(HandleToULong(imageFile.Handle()), progressCallback.get(), imageFile.Length()));
1171+
internalSession->session->LoadImage(ToCOMInputHandle(imageFile.Handle()), progressCallback.get(), imageFile.Length()));
11711172
}
11721173

11731174
STDAPI WslcLoadSessionImage(

src/windows/common/WSLCProcessLauncher.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,10 @@ ClientRunningWSLCProcess::ClientRunningWSLCProcess(wil::com_ptr<IWSLCProcess>&&
221221

222222
wil::unique_handle ClientRunningWSLCProcess::GetStdHandle(int Index)
223223
{
224-
ULONG handle{};
225-
THROW_IF_FAILED_MSG(m_process->GetStdHandle(Index, &handle), "Failed to get handle: %i", Index);
224+
wslutil::COMOutputHandle handle;
225+
THROW_IF_FAILED_MSG(m_process->GetStdHandle(static_cast<WSLCFD>(Index), &handle), "Failed to get handle: %i", Index);
226226

227-
return wil::unique_handle{ULongToHandle(handle)};
227+
return handle.Release();
228228
}
229229

230230
wil::unique_event ClientRunningWSLCProcess::GetExitEvent()

src/windows/common/wslutil.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,23 @@ wsl::windows::common::ErrorStrings wsl::windows::common::wslutil::ErrorToString(
529529
return errorStrings;
530530
}
531531

532+
[[nodiscard]] HANDLE wsl::windows::common::wslutil::FromCOMInputHandle(WSLCHandle Handle)
533+
{
534+
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_HANDLE), Handle.Handle.File == nullptr || Handle.Handle.File == INVALID_HANDLE_VALUE);
535+
536+
switch (Handle.Type)
537+
{
538+
case WSLCHandleTypeFile:
539+
return Handle.Handle.File;
540+
case WSLCHandleTypePipe:
541+
return Handle.Handle.Pipe;
542+
case WSLCHandleTypeSocket:
543+
return Handle.Handle.Socket;
544+
default:
545+
THROW_HR_MSG(E_UNEXPECTED, "Unsupported handle type: %d", Handle.Type);
546+
}
547+
}
548+
532549
std::wstring wsl::windows::common::wslutil::ConstructPipePath(std::wstring_view PipeName)
533550
{
534551
return c_pipePrefix + std::wstring(PipeName);
@@ -1311,6 +1328,51 @@ wil::unique_hlocal_string wsl::windows::common::wslutil::SidToString(_In_ PSID U
13111328
return sid;
13121329
}
13131330

1331+
WSLCHandle wsl::windows::common::wslutil::ToCOMOutputHandle(HANDLE Handle, DWORD Access)
1332+
{
1333+
wil::unique_handle duplicatedHandle{DuplicateHandle(Handle, Access)};
1334+
1335+
// N.B. COM closes the handle when returning an out parameter.
1336+
// The return value of this method should always be passed to a COM out parameter.
1337+
auto comHandle = ToCOMInputHandle(duplicatedHandle.release());
1338+
1339+
return comHandle;
1340+
}
1341+
1342+
WSLCHandle wsl::windows::common::wslutil::ToCOMInputHandle(HANDLE Handle)
1343+
{
1344+
auto type = GetFileType(Handle);
1345+
if (type == FILE_TYPE_PIPE)
1346+
{
1347+
int socketType{};
1348+
int len = sizeof(socketType);
1349+
1350+
// N.B. FILE_TYPE_PIPE can describe a pipe, a named pipe, or a socket.
1351+
// Check for a named pipe first, since getsockopt() can return success for a named pipe.
1352+
1353+
if (GetNamedPipeInfo(Handle, nullptr, nullptr, nullptr, nullptr))
1354+
{
1355+
return WSLCHandle{.Type = WSLCHandleTypePipe, .Handle = {.Pipe = Handle}};
1356+
}
1357+
else if (getsockopt(reinterpret_cast<SOCKET>(Handle), SOL_SOCKET, SO_TYPE, reinterpret_cast<char*>(&socketType), &len) == 0)
1358+
{
1359+
return WSLCHandle{.Type = WSLCHandleTypeSocket, .Handle = {.Socket = Handle}};
1360+
}
1361+
else
1362+
{
1363+
return WSLCHandle{.Type = WSLCHandleTypePipe, .Handle = {.Pipe = Handle}};
1364+
}
1365+
}
1366+
else if (type == FILE_TYPE_DISK)
1367+
{
1368+
return WSLCHandle{.Type = WSLCHandleTypeFile, .Handle = {.File = Handle}};
1369+
}
1370+
else
1371+
{
1372+
THROW_HR_MSG(HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), "Unsupported handle type: %d", type);
1373+
}
1374+
}
1375+
13141376
winrt::Windows::Management::Deployment::PackageVolume wsl::windows::common::wslutil::GetSystemVolume()
13151377
try
13161378
{

src/windows/common/wslutil.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,54 @@ struct COMErrorInfo
7272
wil::unique_bstr Source;
7373
};
7474

75+
static_assert(sizeof(WSLCHandle::Handle) == sizeof(HANDLE));
76+
static_assert(sizeof(FILE_HANDLE) == sizeof(HANDLE));
77+
static_assert(sizeof(PIPE_HANDLE) == sizeof(HANDLE));
78+
static_assert(sizeof(SOCKET_HANDLE) == sizeof(HANDLE));
79+
80+
struct COMOutputHandle : public WSLCHandle
81+
{
82+
NON_COPYABLE(COMOutputHandle);
83+
NON_MOVABLE(COMOutputHandle);
84+
COMOutputHandle()
85+
{
86+
ZeroMemory(&Handle, sizeof(Handle));
87+
Type = WSLCHandleTypeUnknown;
88+
}
89+
90+
~COMOutputHandle()
91+
{
92+
Reset();
93+
}
94+
95+
void Reset() noexcept
96+
{
97+
if (!Empty())
98+
{
99+
LOG_IF_WIN32_BOOL_FALSE(CloseHandle(Handle.File));
100+
Handle.File = nullptr;
101+
}
102+
}
103+
104+
[[nodiscard]] wil::unique_handle Release() noexcept
105+
{
106+
wil::unique_handle handle(Handle.File);
107+
Handle.File = nullptr;
108+
109+
return handle;
110+
}
111+
112+
HANDLE Get() const noexcept
113+
{
114+
return Handle.File;
115+
}
116+
117+
bool Empty() const noexcept
118+
{
119+
return Handle.File == nullptr || Handle.File == INVALID_HANDLE_VALUE;
120+
}
121+
};
122+
75123
struct PruneResult
76124
{
77125
NON_COPYABLE(PruneResult);
@@ -170,6 +218,8 @@ std::wstring ErrorCodeToString(HRESULT Error);
170218

171219
ErrorStrings ErrorToString(const Error& error);
172220

221+
[[nodiscard]] HANDLE FromCOMInputHandle(WSLCHandle Handle);
222+
173223
std::filesystem::path GetBasePath();
174224

175225
std::optional<COMErrorInfo> GetCOMErrorInfo();
@@ -267,6 +317,9 @@ void SetThreadDescription(LPCWSTR Name);
267317

268318
wil::unique_hlocal_string SidToString(_In_ PSID Sid);
269319

320+
WSLCHandle ToCOMInputHandle(HANDLE Handle);
321+
[[nodiscard]] WSLCHandle ToCOMOutputHandle(HANDLE Handle, DWORD Access);
322+
270323
winrt::Windows::Management::Deployment::PackageVolume GetSystemVolume();
271324

272325
} // namespace wsl::windows::common::wslutil

src/windows/service/exe/WSLCSessionManager.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings,
124124
TraceLoggingValue(Settings->DisplayName, "Name"),
125125
TraceLoggingValue(stopWatch.ElapsedMilliseconds(), "CreationTimeMs"),
126126
TraceLoggingValue(creationResult, "Result"),
127+
TraceLoggingValue(tokenInfo.Elevated, "Elevated"),
127128
TraceLoggingValue(static_cast<uint32_t>(Flags), "Flags"));
128129

129130
THROW_IF_FAILED_MSG(creationResult, "Failed to create session: %ls", Settings->DisplayName);

src/windows/service/inc/wslc.idl

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,35 @@ typedef struct _WSLCContainerEntry
290290
WSLCContainerState State;
291291
} WSLCContainerEntry;
292292

293+
typedef [system_handle(sh_file)] HANDLE FILE_HANDLE;
294+
typedef [system_handle(sh_pipe)] HANDLE PIPE_HANDLE;
295+
typedef [system_handle(sh_socket)] HANDLE SOCKET_HANDLE;
296+
297+
typedef enum _WSLCHandleType
298+
{
299+
WSLCHandleTypeUnknown = 0,
300+
WSLCHandleTypeFile = 1,
301+
WSLCHandleTypePipe = 2,
302+
WSLCHandleTypeSocket = 3
303+
} WSLCHandleType;
304+
305+
typedef struct _WSLCHandle
306+
{
307+
WSLCHandleType Type;
308+
309+
[switch_type(WSLCHandleType), switch_is(Type)]
310+
union
311+
{
312+
[case(WSLCHandleTypeFile)]
313+
FILE_HANDLE File;
314+
[case(WSLCHandleTypePipe)]
315+
PIPE_HANDLE Pipe;
316+
[case(WSLCHandleTypeSocket)]
317+
SOCKET_HANDLE Socket;
318+
[default];
319+
} Handle;
320+
} WSLCHandle;
321+
293322
typedef enum _WSLCProcessState
294323
{
295324
WslcProcessStateUnknown = 0,
@@ -307,7 +336,7 @@ interface IWSLCProcess : IUnknown
307336
{
308337
HRESULT Signal([in] int Signal);
309338
HRESULT GetExitEvent([out, system_handle(sh_event)] HANDLE* EventHandle);
310-
HRESULT GetStdHandle([in] ULONG Index, [out] ULONG* Handle);
339+
HRESULT GetStdHandle([in] WSLCFD Fd, [out] WSLCHandle* Handle);
311340
HRESULT GetFlags([out] WSLCProcessFlags* Flags);
312341
HRESULT GetPid([out] int* Pid);
313342
HRESULT GetState([out] WSLCProcessState* State, [out] int* Code);
@@ -417,16 +446,16 @@ typedef enum _WSLCDeleteFlags
417446
]
418447
interface IWSLCContainer : IUnknown
419448
{
420-
HRESULT Attach([in, unique] LPCSTR DetachKeys, [out] ULONG* StdIn, [out] ULONG* StdOut, [out] ULONG* StdErr);
449+
HRESULT Attach([in, unique] LPCSTR DetachKeys, [out] WSLCHandle* StdIn, [out] WSLCHandle* StdOut, [out] WSLCHandle* StdErr);
421450
HRESULT Stop([in] WSLCSignal Signal, [in] LONG TimeoutSeconds);
422451
HRESULT Start([in] WSLCContainerStartFlags Flags, [in, unique] LPCSTR DetachKeys);
423452
HRESULT Delete([in] WSLCDeleteFlags Flags);
424-
HRESULT Export([in] ULONG TarHandle);
453+
HRESULT Export([in] WSLCHandle TarHandle);
425454
HRESULT GetState([out] WSLCContainerState* State);
426455
HRESULT GetInitProcess([out] IWSLCProcess** Process);
427456
HRESULT Exec([in, ref] const WSLCProcessOptions* Options, [in, unique] LPCSTR DetachKeys, [out] IWSLCProcess** Process);
428457
HRESULT Inspect([out] LPSTR* Output);
429-
HRESULT Logs([in] WSLCLogsFlags Flags, [out] ULONG* Stdout, [out] ULONG* Stderr, [in] ULONGLONG Since, [in] ULONGLONG Until, [in] ULONGLONG Tail);
458+
HRESULT Logs([in] WSLCLogsFlags Flags, [out] WSLCHandle* Stdout, [out] WSLCHandle* Stderr, [in] ULONGLONG Since, [in] ULONGLONG Until, [in] ULONGLONG Tail);
430459
HRESULT GetId([out, string] WSLCContainerId Id);
431460
HRESULT GetName([out, string] LPSTR* Name);
432461
HRESULT GetLabels([out, size_is(, *Count)] WSLCLabelInformation** Labels, [out] ULONG* Count);
@@ -462,7 +491,7 @@ typedef struct _WSLCDeleteImageOptions
462491
typedef struct _WSLCBuildImageOptions
463492
{
464493
LPCWSTR ContextPath;
465-
ULONG DockerfileHandle;
494+
WSLCHandle DockerfileHandle;
466495
WSLCStringArray Tags;
467496
WSLCStringArray BuildArgs; // KEY=VALUE pairs passed as --build-arg to docker.
468497
BOOL Verbose; // Show all build progress including internal steps and all statuses.
@@ -529,9 +558,9 @@ interface IWSLCSession : IUnknown
529558
// Image management.
530559
HRESULT PullImage([in] LPCSTR Image, [in, unique] LPCSTR RegistryAuthenticationInformation, [in, unique] IProgressCallback* ProgressCallback);
531560
HRESULT BuildImage([in] const WSLCBuildImageOptions* Options, [in, unique] IProgressCallback* ProgressCallback, [in, unique, system_handle(sh_event)] HANDLE CancelEvent);
532-
HRESULT LoadImage([in] ULONG ImageHandle, [in, unique] IProgressCallback* ProgressCallback, [in] ULONGLONG ContentLength);
533-
HRESULT ImportImage([in] ULONG ImageHandle, [in] LPCSTR ImageName, [in, unique] IProgressCallback* ProgressCallback, [in] ULONGLONG ContentLength);
534-
HRESULT SaveImage([in] ULONG OutputHandle, [in] LPCSTR ImageNameOrID, [in, unique] IProgressCallback * ProgressCallback, [in, unique, system_handle(sh_event)] HANDLE CancelEvent);
561+
HRESULT LoadImage([in] WSLCHandle ImageHandle, [in, unique] IProgressCallback* ProgressCallback, [in] ULONGLONG ContentLength);
562+
HRESULT ImportImage([in] WSLCHandle ImageHandle, [in] LPCSTR ImageName, [in, unique] IProgressCallback* ProgressCallback, [in] ULONGLONG ContentLength);
563+
HRESULT SaveImage([in] WSLCHandle OutputHandle, [in] LPCSTR ImageNameOrID, [in, unique] IProgressCallback * ProgressCallback, [in, unique, system_handle(sh_event)] HANDLE CancelEvent);
535564
HRESULT ListImages([in, unique] const WSLCListImageOptions* Options, [out, size_is(, *Count)] WSLCImageInformation** Images, [out] ULONG* Count);
536565
HRESULT DeleteImage([in] const WSLCDeleteImageOptions* Options, [out, size_is(, *Count)] WSLCDeletedImageInformation** DeletedImages, [out] ULONG* Count);
537566
HRESULT TagImage([in] const WSLCTagImageOptions* Options);

src/windows/wslc/services/ContainerService.cpp

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ namespace wsl::windows::wslc::services {
2727
using wsl::windows::common::ClientRunningWSLCProcess;
2828
using wsl::windows::common::wslc_schema::InspectContainer;
2929
using wsl::windows::common::wslutil::PrintMessage;
30+
using namespace wsl::windows::common::wslutil;
3031
using namespace wsl::windows::wslc::models;
3132
using namespace std::chrono_literals;
3233

@@ -198,26 +199,22 @@ int ContainerService::Attach(Session& session, const std::string& id)
198199

199200
ClientRunningWSLCProcess runningProcess(std::move(process), processFlags);
200201

201-
ULONG stdinLogsHandle = 0;
202-
ULONG stdoutLogsHandle = 0;
203-
ULONG stderrLogsHandle = 0;
204-
THROW_IF_FAILED(container->Attach(nullptr, &stdinLogsHandle, &stdoutLogsHandle, &stderrLogsHandle));
202+
COMOutputHandle stdinLogs{};
203+
COMOutputHandle stdoutLogs{};
204+
COMOutputHandle stderrLogs{};
205+
THROW_IF_FAILED(container->Attach(nullptr, &stdinLogs, &stdoutLogs, &stderrLogs));
205206

206-
wil::unique_handle stdinLogs(ULongToHandle(stdinLogsHandle));
207-
wil::unique_handle stdoutLogs(ULongToHandle(stdoutLogsHandle));
208-
wil::unique_handle stderrLogs(ULongToHandle(stderrLogsHandle));
209-
210-
if (stdoutLogs)
207+
if (!stdoutLogs.Empty())
211208
{
212209
// Non-TTY process - relay separate stdout/stderr streams
213-
WI_ASSERT(!!stderrLogs);
214-
ConsoleService::RelayNonTtyProcess(std::move(stdinLogs), std::move(stdoutLogs), std::move(stderrLogs));
210+
WI_ASSERT(!stderrLogs.Empty());
211+
ConsoleService::RelayNonTtyProcess(stdinLogs.Release(), stdoutLogs.Release(), stderrLogs.Release());
215212
}
216213
else
217214
{
218215
// TTY process - relay using interactive TTY handling
219-
WI_ASSERT(!stderrLogs);
220-
if (!ConsoleService::RelayInteractiveTty(runningProcess, stdinLogs.get(), true))
216+
WI_ASSERT(stderrLogs.Empty());
217+
if (!ConsoleService::RelayInteractiveTty(runningProcess, stdinLogs.Release().get(), true))
221218
{
222219
wsl::windows::common::wslutil::PrintMessage(L"[detached]", stderr);
223220
return 0; // Exit early if user detached
@@ -386,22 +383,21 @@ void ContainerService::Logs(Session& session, const std::string& id, bool follow
386383
wil::com_ptr<IWSLCContainer> container;
387384
THROW_IF_FAILED(session.Get()->OpenContainer(id.c_str(), &container));
388385

389-
ULONG stdoutLogsHandle = 0;
390-
ULONG stderrLogsHandle = 0;
386+
COMOutputHandle stdoutHandle;
387+
COMOutputHandle stderrHandle;
391388
WSLCLogsFlags flags = WSLCLogsFlagsNone;
392389
WI_SetFlagIf(flags, WSLCLogsFlagsFollow, follow);
393390

394-
THROW_IF_FAILED(container->Logs(flags, &stdoutLogsHandle, &stderrLogsHandle, 0, 0, 0));
395-
wil::unique_handle stdoutLogs(ULongToHandle(stdoutLogsHandle));
396-
wil::unique_handle stderrLogs(ULongToHandle(stderrLogsHandle));
391+
THROW_IF_FAILED(container->Logs(flags, &stdoutHandle, &stderrHandle, 0, 0, 0));
397392

398393
wsl::windows::common::relay::MultiHandleWait io;
399394
io.AddHandle(std::make_unique<wsl::windows::common::relay::RelayHandle<wsl::windows::common::relay::ReadHandle>>(
400-
std::move(stdoutLogs), GetStdHandle(STD_OUTPUT_HANDLE)));
401-
if (stderrLogs) // This handle is only used for non-tty processes.
395+
stdoutHandle.Release(), GetStdHandle(STD_OUTPUT_HANDLE)));
396+
397+
if (!stderrHandle.Empty()) // This handle is only used for non-tty processes.
402398
{
403399
io.AddHandle(std::make_unique<wsl::windows::common::relay::RelayHandle<wsl::windows::common::relay::ReadHandle>>(
404-
std::move(stderrLogs), GetStdHandle(STD_ERROR_HANDLE)));
400+
stderrHandle.Release(), GetStdHandle(STD_ERROR_HANDLE)));
405401
}
406402

407403
// TODO: Handle ctrl-c.

0 commit comments

Comments
 (0)