随着VS2017的普及，编写驱动已经不必完全使用C语言了。
在C语言时代，最难处理的是资源泄漏：分配的资源都必须手动编写代码进行回收。
现在有了C++，完全可以借助构造/析构函数（RAII）自动回收资源，从而提升代码质量。
(1)首先要注意，C++隐藏了太多细节（例如构造/析构函数的隐式调用、new/delete的底层实现），这使得用C++写驱动更容易蓝屏。
所以用C++写驱动时，一定要实现一个简化的C++运行环境：内核中没有完整的C++运行库，全局对象的构造/析构不会被自动调用，operator new/delete也必须自己提供。
(2)接下来我们实现一个防护型驱动：基于进程创建回调（CreateProcessCallback）禁止指定的进程运行。
首先实现全局环境：
struct CxxGlobal { public: CxxGlobal(); ~CxxGlobal(); public: //在这里放置所有的全局对象 wstring str_DeviceName; wstring str_SYMBOLIC_LINK_NAME; CMyEResourceLock g_lock_MapCreateProcessCBKRule;//CreateProcessCallback 的规则链表锁 CLinkListMapByString g_MapCpCBKRule;//CreateProcessCallback的规则表 };
我们通过重载operator new（其中包括placement new），让这个全局对象在DriverEntry中被显式地构造在预先分配好的全局数据区里，而不是依赖编译器自动构造全局对象。
CxxGlobal* CreateCxxGlobal_OnDriverEntry(); //请在DriverEntry里调用,确保一定成功后再,否则DriverEntry应该返回失败 void DeleteCxxGlobal_OnDriverUnload();//请在DriverUnload里调用,再此之前请做好Uninit的操作 void * __cdecl operator new(size_t sizex) { return Alloc((DWORD)sizex); } void * __cdecl operator new(size_t sizex, void* p) { return p; } void __cdecl operator delete(void * ptr, size_t sizex) { Free(ptr); } unsigned char g_cxxGlobalObj[sizeof(CxxGlobal)]; CxxGlobal& Global = *((CxxGlobal*)(&(g_cxxGlobalObj[0]))); CxxGlobal* g_pGlobal = 0; CxxGlobal * CreateCxxGlobal_OnDriverEntry() { if (!g_pGlobal) { g_pGlobal = new((CxxGlobal*)(&(g_cxxGlobalObj[0]))) CxxGlobal(); } return g_pGlobal; } void DeleteCxxGlobal_OnDriverUnload() { if (g_pGlobal) g_pGlobal->~CxxGlobal(); g_pGlobal = NULL; }
(3)字符串类的实现
// Returns the number of characters before the terminating NUL of pstr,
// or 0 when pstr is null.
template<typename chType>
__inline int Tstrlen(const chType* pstr)
{
    int r = 0;
    if (pstr)
    {
        while (*pstr)
        {
            r++;
            pstr++;
        }
    }
    return r;
}

// Minimal owning, NUL-terminated string for code that cannot use
// std::basic_string. Storage comes from Alloc()/Free(). When a copy or
// assignment cannot allocate, the string is left empty (m_p == nullptr,
// length == 0), so callers can detect the failure through c_str().
template<typename chType>
class _chStringT
{
public:
    chType* m_p;   // owned buffer of length + 1 characters, or nullptr
    int length;    // character count, excluding the terminator
public:
    _chStringT() : m_p(nullptr), length(0) {}

    _chStringT(const _chStringT& src) : m_p(nullptr), length(0)
    {
        _copyFrom(src.m_p, src.length);
    }

    _chStringT(const chType* pStr) : m_p(nullptr), length(0)
    {
        _copyFrom(pStr, Tstrlen(pStr));
    }

    _chStringT& operator=(const _chStringT& src)
    {
        if (this != &src)
            _copyFrom(src.m_p, src.length);
        return (*this);
    }

    // Safe even when pStr points into this string's own buffer.
    _chStringT& operator=(const chType* pStr)
    {
        _copyFrom(pStr, Tstrlen(pStr));
        return (*this);
    }

    ~_chStringT() { __freedata(); }

    // Replaces the contents with len copies of initch. A negative len is
    // ignored; if allocation fails the previous contents are kept.
    void resize(int len, chType initch)
    {
        if (len < 0)
            return;
        chType* pNewBuffer = (chType*)Alloc((len + 1) * sizeof(chType));
        if (pNewBuffer)
        {
            __freedata();
            m_p = pNewBuffer;
            length = len;
            for (int i = 0; i < len; i++)
                pNewBuffer[i] = initch;
            pNewBuffer[len] = 0; // terminator
        }
    }

    // Out-of-range or empty-string access returns a shared static dummy
    // character instead of faulting.
    chType& operator[](int index)
    {
        if (m_p && index >= 0 && index < length)
            return m_p[index];
        return __dummy_ref_char_for_guard_safe;
    }

    const chType& operator[](int index) const
    {
        if (m_p && index >= 0 && index < length)
            return m_p[index];
        return __dummy_ref_char_for_guard_safe;
    }

    const chType* c_str() const { return m_p; }

protected:
    // Replaces the contents with the first len characters of pSrc.
    // Allocates and copies BEFORE releasing the old buffer, so pSrc may
    // alias m_p. On allocation failure the string becomes empty.
    void _copyFrom(const chType* pSrc, int len)
    {
        chType* pNew = (chType*)Alloc((len + 1) * sizeof(chType));
        if (!pNew)
        {
            __freedata();
            return;
        }
        for (int i = 0; i < len; i++)
            pNew[i] = pSrc[i];
        pNew[len] = 0;
        __freedata();
        m_p = pNew;
        length = len;
    }

    void __freedata()
    {
        if (m_p)
        {
            Free(m_p);
            m_p = nullptr;
            length = 0;
        }
    }

    static chType __dummy_ref_char_for_guard_safe;
};

template<typename chType>
chType _chStringT<chType>::__dummy_ref_char_for_guard_safe = 0;

typedef _chStringT<char> string;
typedef _chStringT<wchar_t> wstring;
(4)基于RAII思想设计的资源锁
class CMyEResourceLock { public: CMyEResourceLock(); ~CMyEResourceLock(); ERESOURCE* GetResource(){return &m_lock_eresource;} private: ERESOURCE m_lock_eresource; }; #define MY_ENTER_CRITICAL_SECTION(section) KeEnterCriticalRegion();ExAcquireResourceSharedLite(section, TRUE) #define MY_LEAVE_CRITICAL_SECTION(section) ExReleaseResourceLite(section);KeLeaveCriticalRegion() class CAutoLockEResource { public: CAutoLockEResource(ERESOURCE* pResource) { m_pRes = pResource; MY_ENTER_CRITICAL_SECTION(m_pRes); } ~CAutoLockEResource() { if (m_pRes) { MY_LEAVE_CRITICAL_SECTION(m_pRes); } } private: ERESOURCE* m_pRes; };
(5)链表类
typedef struct _LinkListMapObj { WCHAR szPath[260]; BOOL bLetItOK; } LinkListMapObj; typedef BOOL PFN_LinkListMapTravelCallback(void* pContext, LinkListMapObj* pDataObj, _LinkListEntryT<LinkListMapObj>* pCurrent, _LinkListEntryT<LinkListMapObj>* pParentOfCurrent); class CLinkListMapByString { public: CLinkListMapByString(); ~CLinkListMapByString(); public: BOOL AddString(LPCWSTR lpString, BOOL bLetItOK); LinkListMapObj* FindByName(LPCWSTR lpFind); BOOL Erase(LPCWSTR lpString); BOOL DeleteEntry(_LinkListEntryT<LinkListMapObj> * pEntry, _LinkListEntryT<LinkListMapObj> * pParentOfEntry); void Clear(); void Travel(PFN_LinkListMapTravelCallback fnCallback, void * pContext); int GetCount(); private: LinkListT<LinkListMapObj> m_list; };
(6)实现CreateProcessCallback
// Process-creation notify routine. When a process is being created and the
// callback has not been marked uninitialised (g_CpCbkUninit), it resolves
// the new process's image path and terminates the process if
// IsMatchProcessInList() reports a match.
//   hParentId  - parent process id (not used).
//   hProcessId - id of the process being created or exiting.
//   bCreate    - TRUE on creation, FALSE on exit.
// NOTE(review): on Vista SP1+ PsSetCreateProcessNotifyRoutineEx lets a driver
// fail creation via CreationStatus instead of terminating afterwards; consider
// migrating.
VOID ProcessMonitorCallback(IN HANDLE hParentId, IN HANDLE hProcessId, IN BOOLEAN bCreate)
{
    (void)hParentId;

    // Only creation events are filtered, and only while the rule table is alive.
    if (!bCreate || g_CpCbkUninit)
        return;

    // OBJ_KERNEL_HANDLE: this routine runs in the context of the creating
    // thread, so without it the PROCESS_ALL_ACCESS handle would be placed in
    // that (user-mode) process's handle table and be usable from user mode.
    OBJECT_ATTRIBUTES Obja;
    InitializeObjectAttributes(&Obja, NULL, OBJ_KERNEL_HANDLE, NULL, NULL);

    CLIENT_ID ClientId;
    ClientId.UniqueProcess = hProcessId;
    ClientId.UniqueThread = NULL;

    // Open the new process by PID.
    HANDLE procHandle = NULL;
    NTSTATUS status = ZwOpenProcess(&procHandle, PROCESS_ALL_ACCESS, &Obja, &ClientId);
    if (!NT_SUCCESS(status) || procHandle == NULL)
        return;

    // Zero-filled buffer. MaximumLength (in bytes) covers the whole buffer
    // except one wchar_t, so a terminating NUL always remains and Buffer can
    // be passed on as a C string below.
    wchar_t szPath[520];
    memset(szPath, 0, sizeof(szPath));
    UNICODE_STRING us_ProcName;
    us_ProcName.Length = 0;
    us_ProcName.MaximumLength = (USHORT)(sizeof(szPath) - sizeof(wchar_t));
    us_ProcName.Buffer = szPath;

    if (STATUS_SUCCESS == GetProcessImageName(&us_ProcName, procHandle))
    {
        // Try to translate the volume path into a DOS path; the helper
        // allocates usDosPath.Buffer, which is released with Free() below.
        UNICODE_STRING usDosPath = {};
        InitUnicode_AllocIoVolPathToDosPath(&us_ProcName, &usDosPath);

        // Prefer the DOS path when one was produced. NOTE(review): assumes the
        // helper NUL-terminates usDosPath.Buffer — confirm.
        wchar_t* pMatchPath = usDosPath.Buffer ? usDosPath.Buffer : us_ProcName.Buffer;

        // Terminate the process if it matches the rule list.
        if (IsMatchProcessInList(pMatchPath))
        {
            ZwTerminateProcess(procHandle, 0);
        }

        if (usDosPath.Buffer)
            Free(usDosPath.Buffer);
    }

    ZwClose(procHandle);
}