Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions include/kf/FltCommunicationPort.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "ObjectAttributes.h"
#include "ScopeExit.h"
#include "VariableSizeStruct.h"
#include "IWinApi.h"

namespace kf
{
Expand Down Expand Up @@ -69,7 +70,7 @@ namespace kf
class FltCommunicationPort
{
public:
FltCommunicationPort() : m_filter(), m_port()
FltCommunicationPort(IWinApi& api) : m_api(api), m_filter(), m_port()
{
}

Expand All @@ -85,26 +86,26 @@ namespace kf
m_filter = filter;

PSECURITY_DESCRIPTOR securityDescriptor = nullptr;
NTSTATUS status = ::FltBuildDefaultSecurityDescriptor(&securityDescriptor, FLT_PORT_ALL_ACCESS);
NTSTATUS status = m_api.FltBuildDefaultSecurityDescriptor(&securityDescriptor, FLT_PORT_ALL_ACCESS);

if (!NT_SUCCESS(status))
{
return status;
}

SCOPE_EXIT{ ::FltFreeSecurityDescriptor(securityDescriptor); };
SCOPE_EXIT{ m_api.FltFreeSecurityDescriptor(securityDescriptor); };

VariableSizeStruct<SYSTEM_MANDATORY_LABEL_ACE, PagedPool> lowIntegrityAce;
VariableSizeStruct<ACL, PagedPool> sacl;
if (allowNonAdmins)
{
status = RtlSetDaclSecurityDescriptor(securityDescriptor, true, nullptr, false);
status = m_api.RtlSetDaclSecurityDescriptor(securityDescriptor, true, nullptr, false);
if (!NT_SUCCESS(status))
{
return status;
}

const auto lowMandatorySidLength = RtlLengthSid(SeExports->SeLowMandatorySid);
const auto lowMandatorySidLength = m_api.RtlLengthSid(SeExports->SeLowMandatorySid);
status = lowIntegrityAce.emplace(FIELD_OFFSET(SYSTEM_MANDATORY_LABEL_ACE, SidStart) + lowMandatorySidLength);
if (!NT_SUCCESS(status))
{
Expand All @@ -114,7 +115,7 @@ namespace kf
lowIntegrityAce->Header.AceType = SYSTEM_MANDATORY_LABEL_ACE_TYPE;
lowIntegrityAce->Header.AceSize = static_cast<USHORT>(FIELD_OFFSET(SYSTEM_MANDATORY_LABEL_ACE, SidStart) + lowMandatorySidLength);
lowIntegrityAce->Mask = 0;
status = RtlCopySid(lowMandatorySidLength, &lowIntegrityAce->SidStart, SeExports->SeLowMandatorySid);
status = m_api.RtlCopySid(lowMandatorySidLength, &lowIntegrityAce->SidStart, SeExports->SeLowMandatorySid);
if (!NT_SUCCESS(status))
{
return status;
Expand All @@ -126,19 +127,19 @@ namespace kf
{
return status;
}
status = RtlCreateAcl(sacl.get(), saclSize, ACL_REVISION);
status = m_api.RtlCreateAcl(sacl.get(), saclSize, ACL_REVISION);
if (!NT_SUCCESS(status))
{
return status;
}

status = RtlAddAce(sacl.get(), ACL_REVISION, 0, static_cast<PVOID>(lowIntegrityAce.get()), lowIntegrityAce->Header.AceSize);
status = m_api.RtlAddAce(sacl.get(), ACL_REVISION, 0, static_cast<PVOID>(lowIntegrityAce.get()), lowIntegrityAce->Header.AceSize);
if (!NT_SUCCESS(status))
{
return status;
}

status = RtlSetSaclSecurityDescriptor(securityDescriptor, true, sacl.get(), false);
status = m_api.RtlSetSaclSecurityDescriptor(securityDescriptor, true, sacl.get(), false);
if (!NT_SUCCESS(status))
{
return status;
Expand All @@ -147,14 +148,14 @@ namespace kf

ObjectAttributes oa(&name, securityDescriptor);

return ::FltCreateCommunicationPort(filter, &m_port, &oa, this, connectNotify, disconnectNotify, messageNotify, maxConnections);
return m_api.FltCreateCommunicationPort(filter, &m_port, &oa, this, connectNotify, disconnectNotify, messageNotify, maxConnections);
}

void close()
{
if (m_port)
{
::FltCloseCommunicationPort(m_port);
m_api.FltCloseCommunicationPort(m_port);
m_port = nullptr;
}

Expand All @@ -175,7 +176,7 @@ namespace kf
{
ASSERT(serverPortCookie);
auto self = static_cast<FltCommunicationPort*>(serverPortCookie);
return Handler::onConnect(self->m_filter, clientPort, connectionContext, connectionContextLength, reinterpret_cast<Handler**>(connectionCookie));
return Handler::onConnect(self->m_filter, clientPort, connectionContext, connectionContextLength, reinterpret_cast<Handler**>(connectionCookie), self->m_api);
}

static VOID FLTAPI disconnectNotify(
Expand Down Expand Up @@ -214,15 +215,15 @@ namespace kf
{
if (inputBufferLength)
{
inputMdl = IoAllocateMdl(inputBuffer, inputBufferLength, false, false, nullptr);
inputMdl = handler->m_api.IoAllocateMdl(inputBuffer, inputBufferLength, false, false, nullptr);
if (!inputMdl)
{
return STATUS_INSUFFICIENT_RESOURCES;
}

MmProbeAndLockPages(inputMdl, KernelMode, IoReadAccess);
handler->m_api.MmProbeAndLockPages(inputMdl, KernelMode, IoReadAccess);

inputBuffer = MmGetSystemAddressForMdlSafe(inputMdl, NormalPagePriority | MdlMappingNoExecute | MdlMappingNoWrite);
inputBuffer = handler->m_api.MmGetSystemAddressForMdlSafe(inputMdl, NormalPagePriority | MdlMappingNoExecute | MdlMappingNoWrite);
if (!inputBuffer)
{
return STATUS_INSUFFICIENT_RESOURCES;
Expand All @@ -231,15 +232,15 @@ namespace kf

if (outputBufferLength)
{
outputMdl = IoAllocateMdl(outputBuffer, outputBufferLength, false, false, nullptr);
outputMdl = handler->m_api.IoAllocateMdl(outputBuffer, outputBufferLength, false, false, nullptr);
if (!outputMdl)
{
return STATUS_INSUFFICIENT_RESOURCES;
}

MmProbeAndLockPages(outputMdl, KernelMode, IoWriteAccess);
handler->m_api.MmProbeAndLockPages(outputMdl, KernelMode, IoWriteAccess);

outputBuffer = MmGetSystemAddressForMdlSafe(outputMdl, NormalPagePriority | MdlMappingNoExecute);
outputBuffer = handler->m_api.MmGetSystemAddressForMdlSafe(outputMdl, NormalPagePriority | MdlMappingNoExecute);
if (!outputBuffer)
{
return STATUS_INSUFFICIENT_RESOURCES;
Expand All @@ -258,16 +259,16 @@ namespace kf
// Cleanup
//

auto freeMdl = [](PMDL& mdl)
auto freeMdl = [&handler](PMDL& mdl)
{
if (mdl)
{
if (FlagOn(mdl->MdlFlags, MDL_PAGES_LOCKED))
{
MmUnlockPages(mdl);
handler->m_api.MmUnlockPages(mdl);
}

IoFreeMdl(mdl);
handler->m_api.IoFreeMdl(mdl);
mdl = nullptr;
}
};
Expand All @@ -281,5 +282,6 @@ namespace kf
private:
PFLT_FILTER m_filter;
PFLT_PORT m_port;
IWinApi& m_api;
};
} // namespace
50 changes: 50 additions & 0 deletions include/kf/IWinApi.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#pragma once

namespace kf
{
////////////////////////////////////////////////////
// Interface for Windows API calls
class IWinApi
{
public:
virtual NTSTATUS FltBuildDefaultSecurityDescriptor(PSECURITY_DESCRIPTOR* sd, ACCESS_MASK access) = 0;

virtual VOID FltFreeSecurityDescriptor(PSECURITY_DESCRIPTOR sd) = 0;

virtual NTSTATUS FltCreateCommunicationPort(
PFLT_FILTER filter,
PFLT_PORT* serverPort,
POBJECT_ATTRIBUTES oa,
PVOID serverPortCookie,
PFLT_CONNECT_NOTIFY connectNotify,
PFLT_DISCONNECT_NOTIFY disconnectNotify,
PFLT_MESSAGE_NOTIFY messageNotify,
LONG maxConnections) = 0;

virtual VOID FltCloseCommunicationPort(PFLT_PORT port) = 0;

virtual VOID FltCloseClientPort(PFLT_FILTER filter, PFLT_PORT* port) = 0;

virtual NTSTATUS RtlSetDaclSecurityDescriptor(PSECURITY_DESCRIPTOR sd, BOOLEAN daclPresent, PACL dacl, BOOLEAN daclDefaulted) = 0;

virtual NTSTATUS RtlSetSaclSecurityDescriptor(PSECURITY_DESCRIPTOR sd, BOOLEAN saclPresent, PACL sacl, BOOLEAN saclDefaulted) = 0;

virtual ULONG RtlLengthSid(PSID sid) = 0;

virtual NTSTATUS RtlCopySid(ULONG len, PSID dest, PSID src) = 0;

virtual NTSTATUS RtlCreateAcl(PACL acl, ULONG size, ULONG rev) = 0;

virtual NTSTATUS RtlAddAce(PACL acl, ULONG rev, ULONG start, PVOID ace, ULONG aceSize) = 0;

virtual PMDL IoAllocateMdl(PVOID va, ULONG len, BOOLEAN secondary, BOOLEAN chargeQuota, PIRP irp) = 0;

virtual VOID IoFreeMdl(PMDL mdl) = 0;

virtual VOID MmProbeAndLockPages(PMDL mdl, KPROCESSOR_MODE mode, LOCK_OPERATION op) = 0;

virtual VOID MmUnlockPages(PMDL mdl) = 0;

virtual PVOID MmGetSystemAddressForMdlSafe(PMDL mdl, ULONG priority) = 0;
};
}
98 changes: 98 additions & 0 deletions include/kf/WinApi.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#pragma once
#include "IWinApi.h"

namespace kf
{
//////////////////////////////////////////////////////////////////////////
// Wrapper for Windows API calls to allow mocking in unit tests
class WinApi : public IWinApi
{
public:
NTSTATUS FltBuildDefaultSecurityDescriptor(PSECURITY_DESCRIPTOR* sd, ACCESS_MASK access)
{
return ::FltBuildDefaultSecurityDescriptor(sd, access);
}

VOID FltFreeSecurityDescriptor(PSECURITY_DESCRIPTOR sd)
{
::FltFreeSecurityDescriptor(sd);
}
NTSTATUS FltCreateCommunicationPort(
PFLT_FILTER filter,
PFLT_PORT* serverPort,
POBJECT_ATTRIBUTES oa,
PVOID serverPortCookie,
PFLT_CONNECT_NOTIFY connectNotify,
PFLT_DISCONNECT_NOTIFY disconnectNotify,
PFLT_MESSAGE_NOTIFY messageNotify,
LONG maxConnections)
{
return ::FltCreateCommunicationPort(filter, serverPort, oa, serverPortCookie, connectNotify, disconnectNotify, messageNotify, maxConnections);
}

VOID FltCloseCommunicationPort(PFLT_PORT port)
{
::FltCloseCommunicationPort(port);
}

VOID FltCloseClientPort(PFLT_FILTER filter, PFLT_PORT* port)
{
::FltCloseClientPort(filter, port);
}

NTSTATUS RtlSetDaclSecurityDescriptor(PSECURITY_DESCRIPTOR sd, BOOLEAN daclPresent, PACL dacl, BOOLEAN daclDefaulted)
{
return ::RtlSetDaclSecurityDescriptor(sd, daclPresent, dacl, daclDefaulted);
}

NTSTATUS RtlSetSaclSecurityDescriptor(PSECURITY_DESCRIPTOR sd, BOOLEAN saclPresent, PACL sacl, BOOLEAN saclDefaulted)
{
return ::RtlSetSaclSecurityDescriptor(sd, saclPresent, sacl, saclDefaulted);
}

ULONG RtlLengthSid(PSID sid)
{
return ::RtlLengthSid(sid);
}

NTSTATUS RtlCopySid(ULONG len, PSID dest, PSID src)
{
return ::RtlCopySid(len, dest, src);
}

NTSTATUS RtlCreateAcl(PACL acl, ULONG size, ULONG rev)
{
return ::RtlCreateAcl(acl, size, rev);
}

NTSTATUS RtlAddAce(PACL acl, ULONG rev, ULONG start, PVOID ace, ULONG aceSize)
{
return ::RtlAddAce(acl, rev, start, ace, aceSize);
}

PMDL IoAllocateMdl(PVOID va, ULONG len, BOOLEAN secondary, BOOLEAN chargeQuota, PIRP irp)
{
return ::IoAllocateMdl(va, len, secondary, chargeQuota, irp);
}

VOID IoFreeMdl(PMDL mdl)
{
::IoFreeMdl(mdl);
}

VOID MmProbeAndLockPages(PMDL mdl, KPROCESSOR_MODE mode, LOCK_OPERATION op)
{
::MmProbeAndLockPages(mdl, mode, op);
}

VOID MmUnlockPages(PMDL mdl)
{
::MmUnlockPages(mdl);
}

PVOID MmGetSystemAddressForMdlSafe(PMDL mdl, ULONG priority)
{
return ::MmGetSystemAddressForMdlSafe(mdl, priority);
}
};
}
2 changes: 2 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ wdk_add_driver(kf-test WINVER NTDDI_WIN10 STL
AutoSpinLockTest.cpp
EResourceSharedLockTest.cpp
RecursiveAutoSpinLockTest.cpp
FltCommunicationPortTest.cpp
WinApiMock.h
)

target_link_libraries(kf-test kf::kf kmtest::kmtest)
Expand Down
Loading
Loading