﻿using System;
using System.Collections.Generic;
using ProcessHacker.Common;
using ProcessHacker.Native.Api;

namespace PatchPae
{
    class Program
    {
        /// <summary>
        /// Signature of a routine that modifies a mapped image in place.
        /// Patch invokes it between mapping and unmapping the file.
        /// </summary>
        /// <param name="loadedImage">The mapped image to patch.</param>
        /// <param name="success">Set by the routine to report whether the patch was applied.</param>
        delegate void PatchDelegate(ref LoadedImage loadedImage, out bool success);

        /// <summary>
        /// Entry point. Copies the input image to the output path and patches the copy.
        /// </summary>
        /// <remarks>
        /// Recognised arguments: the unnamed argument is the input file, "-o" is the
        /// output file and the optional "-type" is either "kernel" (the default) or
        /// "loader". Every error, including any exception, is reported through Fail,
        /// which terminates the process with exit code 1.
        /// </remarks>
        /// <param name="originalArgs">The raw command line arguments.</param>
        static void Main(string[] originalArgs)
        {
            try
            {
                Dictionary<string, string> args = Utils.ParseCommandLine(originalArgs);
                string sourceFileName;
                string targetFileName;
                string type;

                // One lookup per option; an empty value is treated like a missing one
                // so the user gets a clear message instead of a stack trace.
                if (!args.TryGetValue("", out sourceFileName) || string.IsNullOrEmpty(sourceFileName))
                    Fail("Input file not specified!");
                if (!args.TryGetValue("-o", out targetFileName) || string.IsNullOrEmpty(targetFileName))
                    Fail("Output file not specified!");
                if (!args.TryGetValue("-type", out type))
                    type = "kernel";

                // Culture-invariant lower-casing: the type is an identifier, not
                // linguistic text. A null value becomes "" and is rejected below.
                type = (type ?? "").ToLowerInvariant();

                if (type != "kernel" && type != "loader")
                    Fail("Wrong type. Must be \"kernel\" or \"loader\".");

                // Only the copy is ever modified; the input file is left untouched.
                System.IO.File.Copy(sourceFileName, targetFileName, true);

                switch (type)
                {
                    case "kernel":
                        Patch(targetFileName, PatchKernel);
                        break;
                    case "loader":
                        {
                            // The loader patch is selected by the product build
                            // number recorded in the file's version resource.
                            int build = System.Diagnostics.FileVersionInfo
                                .GetVersionInfo(targetFileName).ProductBuildPart;

                            if (build == 7601)
                                Patch(targetFileName, PatchLoader7601);
                            else if (build == 7600)
                                Patch(targetFileName, PatchLoader7600);
                            else
                                Patch(targetFileName, PatchLoader);
                        }
                        break;
                    default:
                        // Defensive: the type was validated above.
                        throw new NotSupportedException();
                }
            }
            catch (Exception ex)
            {
                Fail(ex.ToString());
            }
        }

        /// <summary>
        /// Maps <paramref name="fileName"/>, runs <paramref name="action"/> on the
        /// mapped image, unmaps it again and reports the outcome.
        /// </summary>
        /// <remarks>
        /// Prints "Patched." when the action reports success. When it does not,
        /// the failure goes through Fail so the process exits with code 1, like
        /// every other error in this program, instead of exiting with code 0.
        /// </remarks>
        /// <param name="fileName">Path of the file to patch in place.</param>
        /// <param name="action">The patch routine to run on the mapped image.</param>
        static void Patch(string fileName, PatchDelegate action)
        {
            bool success = false;
            LoadedImage loadedImage;

            // NOTE(review): the two trailing false arguments presumably correspond to
            // imagehlp MapAndLoad's DotDll and ReadOnly flags - confirm against the
            // Win32 wrapper.
            if (!Win32.MapAndLoad(fileName, null, out loadedImage, false, false))
                Win32.Throw();

            try
            {
                action(ref loadedImage, out success);
            }
            finally
            {
                // This will unload the image and fix the checksum. It runs even if
                // the action throws, so the mapping is always released.
                Win32.UnMapAndLoad(ref loadedImage);
            }

            if (success)
                Console.WriteLine("Patched.");
            else
                Fail("Failed.");
        }

        /// <summary>
        /// Patches the kernel's memory license check (MxMemoryLicense) so that the
        /// memory limit is always taken to be 0x20000, whatever the license says.
        /// </summary>
        /// <remarks>
        /// Two code sites are rewritten: the one inside the target byte sequence and
        /// the next "mov eax, [ebp+var_4]; test eax, eax" found within 100 bytes of
        /// the start of the match. Both sites are located before anything is written,
        /// so a failed search leaves the image unmodified. All reads and writes stay
        /// within the first SizeOfImage bytes of the mapping.
        /// </remarks>
        /// <param name="loadedImage">The mapped kernel image to patch in place.</param>
        /// <param name="success">Set to true only if both sites were found and patched.</param>
        static void PatchKernel(ref LoadedImage loadedImage, out bool success)
        {
            success = false;

            unsafe
            {
                // MxMemoryLicense

                // Basically, the portion of code we are going to patch 
                // queries the NT license value for the allowed memory.
                // If there is a limit, it sets MiTotalPagesAllowed to 
                // that limit times 256. If there is no specified limit, 
                // it sets MiTotalPagesAllowed to 0x80000 (2 GB).
                //
                // We will patch the limit to be 0x20000 << 8 pages (128 GB).

                byte[] target = new byte[]
                {
                    // test eax, eax ; did ZwQueryLicenseValue succeed?
                    0x85, 0xc0,
                    // jl short loc_75644b ; if it didn't go to the default case
                    0x7c, 0x11,
                    // mov eax, [ebp+var_4] ; get the returned memory limit
                    0x8b, 0x45, 0xfc,
                    // test eax, eax ; is it non-zero?
                    0x85, 0xc0,
                    // jz short loc_75644b ; if it's zero, go to the default case
                    0x74, 0x0a,
                    // shl eax, 8 ; multiply by 256
                    0xc1, 0xe0, 0x08
                    // mov ds:_MiTotalPagesAllowed, eax ; store in MiTotalPagesAllowed
                    // 0xa3, 0x2c, 0x76, 0x53, 0x00
                    // jmp short loc_756455 ; go to the next bit
                    // 0xeb, 0x0a
                    // loc_75644b: mov ds:_MiTotalPagesAllowed, 0x80000
                    // 0xc7, 0x05, 0x2c, 0x76, 0x53, 0x00, 0x00, 0x00, 0x08, 0x00
                };
                // Offset of "mov eax, [ebp+var_4]" inside the target sequence.
                int movOffset = 4;
                // Each site is rewritten as a 5-byte mov plus two nops.
                const int patchLength = 7;
                // How far past the start of the match the second site is searched for.
                const int searchWindow = 100;
                byte* imageBase = (byte*)loadedImage.MappedAddress.ToPointer();
                long imageSize = loadedImage.SizeOfImage;

                // "<=" so that a match ending exactly at the end of the image is
                // also considered.
                for (long start = 0; start + target.Length <= imageSize; start++)
                {
                    byte* ptr = imageBase + start;
                    int matched = 0;

                    while (matched < target.Length && ptr[matched] == target[matched])
                        matched++;

                    if (matched != target.Length)
                        continue;

                    // Found the first site. Locate the next
                    // "mov eax, [ebp+var_4]; test eax, eax" occurrence before
                    // writing anything. The search can start after the target
                    // because the only 0x8b byte inside the target belongs to
                    // the first site itself.
                    int secondOffset = -1;

                    for (int k = target.Length; k < searchWindow; k++)
                    {
                        // Stop once a patch at k would run past the image.
                        if (start + k + patchLength > imageSize)
                            break;

                        if (
                            ptr[k] == 0x8b &&
                            ptr[k + 1] == 0x45 &&
                            ptr[k + 2] == 0xfc &&
                            ptr[k + 3] == 0x85 &&
                            ptr[k + 4] == 0xc0
                            )
                        {
                            secondOffset = k;
                            break;
                        }
                    }

                    if (secondOffset >= 0)
                    {
                        int[] sites = new int[] { movOffset, secondOffset };

                        foreach (int site in sites)
                        {
                            // mov eax, [ebp+var_4] -> mov eax, 0x20000
                            ptr[site] = 0xb8;
                            *(int*)&ptr[site + 1] = 0x20000;
                            // nop out the jz
                            ptr[site + 5] = 0x90;
                            ptr[site + 6] = 0x90;
                        }

                        success = true;
                    }

                    // Only the first match of the target is ever considered.
                    break;
                }
            }
        }

        /// <summary>
        /// Patches the loader (BlImgLoadPEImageEx) so that the result of
        /// ImgpValidateImageHash is ignored.
        /// </summary>
        /// <remarks>
        /// The scan only accepts a match when the whole patched region
        /// (movOffset through movOffset + 7) lies within the first SizeOfImage
        /// bytes, so the patch can never write past the end of the image.
        /// </remarks>
        /// <param name="loadedImage">The mapped loader image to patch in place.</param>
        /// <param name="success">Set to true if the target was found and patched.</param>
        static void PatchLoader(ref LoadedImage loadedImage, out bool success)
        {
            success = false;

            unsafe
            {
                // BlImgLoadPEImageEx

                // There is a function called ImgpValidateImageHash. We are 
                // going to patch BlImgLoadPEImageEx so that it doesn't care 
                // what the result of the function is.

                byte[] target = new byte[]
                {
                    // sub esi, [ebx+4]
                    0x2b, 0x73, 0x04,
                    // push eax
                    0x50,
                    // add esi, [ebp+var_18]
                    0x03, 0x75, 0xe8,
                    // lea eax, [ebp+Source1]
                    0x8d, 0x45, 0x8c,
                    // push eax
                    0x50,
                    // push esi
                    0x56,
                    // mov eax, ebx
                    0x8b, 0xc3
                    // call _ImgpValidateImageHash@16
                    // 0xe8, 0x59, 0x0b, 0x00, 0x00
                    // mov ecx, eax ; copy return status into ecx
                    // test ecx, ecx ; did ImgpValidateImageHash succeed?
                    // mov [ebp+arg_0], ecx ; store the NT status into a variable
                    // jge short loc_42109f ; if the function succeeded, go there
                };
                // Bytes written at movOffset:
                //   0xc7, 0x45, 0x08, 0x00, 0x00, 0x00, 0x00 = mov [ebp+arg_0], 0
                //     (written over the bytes that start at "mov ecx, eax")
                //   0xeb = jmp short, replacing the opcode byte of
                //     "jge short loc_42109f" (0x7d, rel8 -> 0xeb, rel8)
                byte[] replacement = new byte[]
                {
                    0xc7, 0x45, 0x08, 0x00, 0x00, 0x00, 0x00,
                    0xeb
                };
                // The target (14 bytes) is followed by the 5-byte call.
                int movOffset = 19;
                // Number of bytes from the start of a match that must be in range.
                int requiredLength = movOffset + replacement.Length;
                byte* ptr = (byte*)loadedImage.MappedAddress.ToPointer();
                long imageSize = loadedImage.SizeOfImage;

                for (long i = 0; i + requiredLength <= imageSize; i++, ptr++)
                {
                    int j = 0;

                    while (j < target.Length && ptr[j] == target[j])
                        j++;

                    if (j != target.Length)
                        continue;

                    // Found it. Patch the code.
                    for (int n = 0; n < replacement.Length; n++)
                        ptr[movOffset + n] = replacement[n];

                    success = true;

                    break;
                }
            }
        }

        /// <summary>
        /// Patches the build 7600 loader (BlImgLoadPEImage) so that the result of
        /// ImgpValidateImageHash is ignored.
        /// </summary>
        /// <remarks>
        /// The scan only accepts a match when the patched byte at jgeOffset lies
        /// within the first SizeOfImage bytes, so the patch can never write past
        /// the end of the image.
        /// </remarks>
        /// <param name="loadedImage">The mapped loader image to patch in place.</param>
        /// <param name="success">Set to true if the target was found and patched.</param>
        static void PatchLoader7600(ref LoadedImage loadedImage, out bool success)
        {
            success = false;

            unsafe
            {
                // BlImgLoadPEImage

                // There is a function called ImgpValidateImageHash. We are 
                // going to patch BlImgLoadPEImage so that it doesn't care 
                // what the result of the function is.

                byte[] target = new byte[]
                {
                    // push eax
                    0x50,
                    // lea eax, [ebp+Source1]
                    0x8d, 0x85, 0x94, 0xfe, 0xff, 0xff,
                    // push eax
                    0x50,
                    // push [ebp+var_12c]
                    0xff, 0xb5, 0xd4, 0xfe, 0xff, 0xff,
                    // mov eax, [ebp+var_24]
                    0x8b, 0x45, 0xdc,
                    // push [ebp+var_18]
                    0xff, 0x75, 0xe8,
                    // call _ImgpValidateImageHash@24
                    // 0xe8, 0x63, 0x05, 0x00, 0x00
                    // mov [ebp+var_8], eax ; copy return status into var_8
                    // 0x89, 0x45, 0xf8
                    // test eax, eax ; did ImgpValidateImageHash succeed?
                    // 0x85, 0xc0
                    // jge short loc_428ee5 ; if the function succeeded, go there
                    // 0x7d, 0x2e
                };
                // Target (20 bytes) + call (5) + mov (3) + test (2) = offset of the jge.
                int jgeOffset = 30;
                // Number of bytes from the start of a match that must be in range.
                int requiredLength = jgeOffset + 1;
                byte* ptr = (byte*)loadedImage.MappedAddress.ToPointer();
                long imageSize = loadedImage.SizeOfImage;

                for (long i = 0; i + requiredLength <= imageSize; i++, ptr++)
                {
                    int j = 0;

                    while (j < target.Length && ptr[j] == target[j])
                        j++;

                    if (j != target.Length)
                        continue;

                    // Found it. Patch the code.
                    // Note that we don't need to update var_8 as it is 
                    // a temporary status variable which will be overwritten 
                    // very shortly.

                    // jge short loc_428ee5 -> jmp short loc_428ee5
                    // 0x7d, 0x2e -> 0xeb, 0x2e
                    ptr[jgeOffset] = 0xeb;

                    success = true;

                    break;
                }
            }
        }

        /// <summary>
        /// Patches the build 7601 loader (ImgpLoadPEImage) so that the result of
        /// ImgpValidateImageHash is ignored.
        /// </summary>
        /// <remarks>
        /// The scan only accepts a match when the patched byte at jgeOffset lies
        /// within the first SizeOfImage bytes, so the patch can never write past
        /// the end of the image.
        /// </remarks>
        /// <param name="loadedImage">The mapped loader image to patch in place.</param>
        /// <param name="success">Set to true if the target was found and patched.</param>
        static void PatchLoader7601(ref LoadedImage loadedImage, out bool success)
        {
            success = false;

            unsafe
            {
                // ImgpLoadPEImage

                // There is a function called ImgpValidateImageHash. We are 
                // going to patch ImgpLoadPEImage so that it doesn't care 
                // what the result of the function is.

                byte[] target = new byte[]
                {
                    // push eax
                    0x50,
                    // lea eax, [ebp+Source1]
                    0x8d, 0x85, 0x94, 0xfe, 0xff, 0xff,
                    // push eax
                    0x50,
                    // push [ebp+var_12c]
                    0xff, 0xb5, 0xd4, 0xfe, 0xff, 0xff,
                    // mov eax, [ebp+var_24]
                    0x8b, 0x45, 0xdc,
                    // push [ebp+var_18]
                    0xff, 0x75, 0xe8,
                    // call _ImgpValidateImageHash@24
                    // 0xe8, 0x63, 0x05, 0x00, 0x00
                    // mov [ebp+var_8], eax ; copy return status into var_8
                    // 0x89, 0x45, 0xf8
                    // test eax, eax ; did ImgpValidateImageHash succeed?
                    // 0x85, 0xc0
                    // jge short loc_428f57 ; if the function succeeded, go there
                    // 0x7d, 0x2e
                };
                // Target (20 bytes) + call (5) + mov (3) + test (2) = offset of the jge.
                int jgeOffset = 30;
                // Number of bytes from the start of a match that must be in range.
                int requiredLength = jgeOffset + 1;
                byte* ptr = (byte*)loadedImage.MappedAddress.ToPointer();
                long imageSize = loadedImage.SizeOfImage;

                for (long i = 0; i + requiredLength <= imageSize; i++, ptr++)
                {
                    int j = 0;

                    while (j < target.Length && ptr[j] == target[j])
                        j++;

                    if (j != target.Length)
                        continue;

                    // Found it. Patch the code.
                    // Note that we don't need to update var_8 as it is 
                    // a temporary status variable which will be overwritten 
                    // very shortly.

                    // jge short loc_428f57 -> jmp short loc_428f57
                    // 0x7d, 0x2e -> 0xeb, 0x2e
                    ptr[jgeOffset] = 0xeb;

                    success = true;

                    break;
                }
            }
        }

        /// <summary>
        /// Writes <paramref name="message"/> to standard output and terminates
        /// the process with exit code 1.
        /// </summary>
        /// <param name="message">The text to print before exiting.</param>
        static void Fail(string message)
        {
            const int failureExitCode = 1;

            Console.Out.WriteLine(message);
            Environment.Exit(failureExitCode);
        }
    }
}
