#!/usr/bin/python

from argparse import ArgumentTypeError, ArgumentParser
from collections import OrderedDict
import re
import shutil
from subprocess import check_output, STDOUT, PIPE, Popen
import sys

from argcomplete import autocomplete
import pyperclip

from bashfuscator.common.colors import bold, yellow
from bashfuscator.common.messages import activateQuietMode, printInfo, printWarning, printError, printExitMsg
from bashfuscator.core.engine.obfuscation_handler import ObfuscationHandler
from bashfuscator.core.utils import import_mutators


# import_mutators() supplies five collections of mutators, one per kind. They
# are kept as module-level globals: the completion and listing helpers in this
# file read them, and the script entry point passes them to ObfuscationHandler.
commandObfuscators, stringObfuscators, tokenObfuscators, encoders, compressors = import_mutators()


def check_positive(value):
    """
    argparse ``type`` callback that only accepts integers greater than zero.

    Returns the value converted with int(). Raises ArgumentTypeError when the
    converted number is zero or negative; the ValueError raised by int() for
    non-numeric input is left to propagate.
    """
    number = int(value)
    if number > 0:
        return number

    raise ArgumentTypeError(f"{number} is an invalid positive int value")


def getMutators(prefix, parsed_args, **kwargs):
    """
    Completer callback for the cli: returns the longName attribute of
    every mutator of every kind. The prefix and parsed_args arguments
    are accepted but not used to filter the result.
    """
    everyMutator = commandObfuscators + stringObfuscators + tokenObfuscators + encoders + compressors

    return [candidate.longName for candidate in everyMutator]


def getMutatorsAndStubs(prefix, parsed_args, **kwargs):
    """
    Completer callback for the cli: returns one "<obfuscator>/<stub>"
    entry for every stub of every CommandObfuscator, followed by the
    longName attribute of every other mutator. The prefix and
    parsed_args arguments are accepted but not used to filter the result.
    """
    stubPaths = [
        cmdOb.longName + "/" + stub.longName
        for cmdOb in commandObfuscators
        for stub in cmdOb.stubs
    ]

    otherMutators = stringObfuscators + tokenObfuscators + encoders + compressors
    otherNames = [candidate.longName for candidate in otherMutators]

    return stubPaths + otherNames


def listMutators():
    """
    Print every mutator, grouped under a bold heading per mutator kind.
    """
    # headings after the first start with a newline to separate the groups
    sections = (
        ("Command Obfuscators:", commandObfuscators),
        ("\nString Obfuscators:", stringObfuscators),
        ("\nToken Obfuscators:", tokenObfuscators),
        ("\nEncoders:", encoders),
        ("\nCompressors:", compressors),
    )

    for heading, group in sections:
        print(bold(heading))
        listMutatorType(group)


def listMutatorType(mutators):
    """
    Print the details of each mutator in the given list.

    The list is sorted by name in place before printing. Mutators whose
    mutatorType is "command" additionally get a listing of their stubs;
    all other mutators get their own binaries and file write lines.
    """
    mutators.sort(key=lambda entry: entry.name)

    for entry in mutators:
        isCommand = entry.mutatorType == "command"

        print(bold(f"\nName: {entry.name}"))
        print(f"Description: {entry.description}")
        print(f"Size Rating: {entry.sizeRating}")
        print(f"Time Rating: {entry.timeRating}")

        if not isCommand:
            binariesText = ", ".join(entry.binariesUsed) if entry.binariesUsed else "None"
            print(f"Binaries used: {binariesText}")
            print(f"File write: {entry.fileWrite}")

        if entry.notes is not None:
            print(f"Notes: {entry.notes}")
        if entry.author is not None:
            print(f"Author: {entry.author}")
        if entry.credits is not None:
            print("Credits: ", end="")

            # the first credit continues the "Credits: " line, the rest are
            # indented by the width of that label so they line up under it
            for position, credit in enumerate(entry.credits):
                print(credit if position == 0 else " "*9 + credit)

        if isCommand:
            print("\nStubs:")
            for stub in entry.stubs:
                stubBinaries = ", ".join(stub.binariesUsed) if stub.binariesUsed else "None"

                print(f"\n\tName: {stub.name}")
                print(f"\tSize Rating: {stub.sizeRating}")
                print(f"\tTime Rating: {stub.timeRating}")
                print(f"\tFile write: {stub.fileWrite}")
                print(f"\tBinaries Used: {stubBinaries}")


def printArgError(errorMsg):
    """
    Print errorMsg prefixed with "bashfuscator: error: " and terminate
    the program with exit status 1. Never returns.
    """
    print("bashfuscator: error: {}".format(errorMsg))
    raise SystemExit(1)


def checkArguments(args):
    """
    Validate the combination of parsed command line options.

    Exits through printArgError when no input source was given or when
    --test is combined with -q/--quiet. Prints warnings for options that are
    redundant or that cannot take effect because the feature they tune is
    disabled. Range options that will take effect are validated and replaced
    on ``args`` by the tuple checkRangeArgs returns.

    Relies on the module-level ``parser`` created by the script entry point
    to print the usage line.

    :param args: the namespace returned by parser.parse_args(); modified in
        place for the range options
    """
    if not args.command and not args.file and not args.stdin:
        parser.print_usage()
        printArgError("one of the arguments -c/--command, -f/--file, or --stdin is required")

    if args.quiet and args.test:
        parser.print_usage()
        printArgError("--test is not allowed with -q/--quiet")

    # Every --no-* option is declared with action="store_false": its attribute
    # is True by default and becomes False once the flag is passed, so
    # "not args.no_x" means "the user passed --no-x".
    manglingOff = not args.no_mangling
    integerManglingOff = not args.no_integer_mangling

    if manglingOff:
        impliedByNoMangling = (
            (args.no_binary_mangling, "--no-binary-mangling"),
            (args.no_random_whitespace, "--no-random-whitespace"),
            (args.no_insert_chars, "--no-insert-chars"),
            (args.no_integer_mangling, "--no-integer-mangling"),
            (args.no_integer_expansion, "--no-integer-expansion"),
            (args.no_integer_base_randomization, "--no-integer-base-randomization"),
            (args.no_terminator_randomization, "--no-terminator-randomization"),
        )
        for flagValue, flagName in impliedByNoMangling:
            if not flagValue:
                printWarning(f"--no-mangling implies {flagName}")

    elif integerManglingOff:
        impliedByNoIntegerMangling = (
            (args.no_integer_base_randomization, "--no-integer-base-randomization"),
            (args.no_integer_expansion, "--no-integer-expansion"),
        )
        for flagValue, flagName in impliedByNoIntegerMangling:
            if not flagValue:
                printWarning(f"--no-integer-mangling implies {flagName}")

    if args.binary_mangle_percent:
        if manglingOff or not args.no_binary_mangling:
            printWarning("binary mangling is disabled, --binary-mangle-percent argument will not take effect")

    # The range options are compared against None rather than tested for
    # truthiness so that an empty string is rejected by checkRangeArgs instead
    # of slipping through unvalidated.
    if args.random_whitespace_range is not None:
        if manglingOff or not args.no_random_whitespace:
            printWarning("random whitespace is disabled, --random-whitespace-range argument will not take effect")
        else:
            args.random_whitespace_range = checkRangeArgs("--random-whitespace-range", args.random_whitespace_range)

    if args.insert_chars_range is not None:
        if manglingOff or not args.no_insert_chars:
            printWarning("random character insertion is disabled, --insert-chars-range argument will not take effect")
        else:
            args.insert_chars_range = checkRangeArgs("--insert-chars-range", args.insert_chars_range)

    if args.misleading_commands_range is not None:
        if manglingOff or not args.no_misleading_commands:
            printWarning("misleading command insertion is disabled, --misleading-commands-range argument will not take effect")
        else:
            args.misleading_commands_range = checkRangeArgs("--misleading-commands-range", args.misleading_commands_range)

    if args.integer_expansion_depth:
        if manglingOff or integerManglingOff or not args.no_integer_expansion:
            printWarning("integer arithmetic expansion is disabled, --integer-expansion-depth argument will not take effect")


def checkRangeArgs(argName, argValue):
    """
    Parse and validate the value of a NUM,NUM range option.

    :param argName: option name, used only in error messages
    :param argValue: raw option string, expected to be two integers
        separated by a comma
    :returns: a (lower, upper) tuple of ints

    Any invalid value is reported through printArgError, which exits the
    program: anything that is not exactly two integers, a negative bound,
    an upper bound of zero, or a lower bound above the upper bound.
    """
    try:
        rangeValues = tuple(int(part) for part in argValue.split(","))
    except ValueError:
        # non-numeric or empty component such as "a,b" or "1,": treat it like
        # any other malformed value instead of crashing with a traceback
        rangeValues = ()

    # also rejects a missing comma (one value) and extra components
    if len(rangeValues) != 2:
        printArgError(f"Invalid value for {argName} option: format is NUM,NUM")

    lower, upper = rangeValues

    if lower < 0 or upper < 0:
        printArgError(f"Invalid value for {argName} option: range values must be positive")

    if upper == 0:
        printArgError(f"Invalid value for {argName} option: upper bound must be greater than zero")

    if lower > upper:
        printArgError(f"Invalid value for {argName} option: invalid range")

    return rangeValues


def printPayload(payload, args):
    """
    Print the generated payload followed by its size.

    When the payload is also delivered through --clip or --outfile and is
    long relative to the terminal, only its head and tail are printed, joined
    by a "[Not shown: N characters]" marker where N is the number of
    characters left out. In quiet mode the payload is always printed in full.

    :param payload: the payload text
    :param args: parsed arguments; ``quiet``, ``clip`` and ``outfile`` are read
    """
    # shutil falls back to a default size when no terminal is attached.
    # `stty size` exits non-zero when stdin is not a tty (always the case when
    # the command was piped in for --stdin), which made check_output raise.
    columns, rows = shutil.get_terminal_size()
    columns = max(columns, 1)   # guards the modulo operations below

    printInfo("Payload:\n")

    payloadLen = len(payload)
    payloadPrintLen = int((rows * (columns * .50)) / 2)
    payloadPrintLen = payloadPrintLen + (payloadPrintLen % columns)

    abbreviated = None
    if not args.quiet and (args.clip or args.outfile) and payloadLen > payloadPrintLen * 2:
        # we don't need to print the whole payload, the user is given the full payload via a different method
        # thanks Dbo for most of the logic and idea: https://github.com/danielbohannon/Invoke-Obfuscation

        # Measure the marker before colouring it: yellow() presumably wraps the
        # text in escape sequences, which take up no columns on screen.
        estimatedHidden = payloadLen - 2 * payloadPrintLen
        markerWidth = len(f"[Not shown: {estimatedHidden} characters]")

        # shift the end of the head so the marker starts at the column that centers it
        notShownMsgCenteredIndex = int((columns - markerWidth) / 2)
        notShownMsgCurrentIndex = payloadPrintLen % columns
        payloadPrintLen = max(payloadPrintLen + notShownMsgCenteredIndex - notShownMsgCurrentIndex, 0)

        # head and tail are payloadPrintLen characters each; everything in between is hidden
        hiddenLen = payloadLen - 2 * payloadPrintLen

        # centering may have widened head and tail until they meet, in which
        # case the payload is simply printed in full
        if hiddenLen > 0:
            notShownMessage = yellow(f"[Not shown: {hiddenLen} characters]")
            abbreviated = payload[:payloadPrintLen] + notShownMessage + payload[payloadLen - payloadPrintLen:]

    print(payload if abbreviated is None else abbreviated)

    if not args.quiet:
        print("")

    printInfo(f"Payload size: {payloadLen} characters")


# Script entry point: build the argument parser, read the command to obfuscate,
# generate the payload, then deliver it (file, clipboard, terminal) and
# optionally run it.
if __name__ == "__main__":
    # `parser` must stay a module-level name: checkArguments() uses it to print usage.
    parser = ArgumentParser()
    progOpts = parser.add_argument_group("Program Options")
    inptOpts = progOpts.add_mutually_exclusive_group()
    progOpts.add_argument("-l", "--list", action="store_true", help="List all the available obfuscators, compressors, and encoders")
    inptOpts.add_argument("-c", "--command", type=str, help="Command to obfuscate")
    inptOpts.add_argument("-f", "--file", type=str, help="Name of the script to obfuscate")
    inptOpts.add_argument("--stdin", action="store_true", help="Obfuscate stdin")
    progOpts.add_argument("-o", "--outfile", type=str, help="File to write payload to")
    progOpts.add_argument("-q", "--quiet", action="store_true", help="Print only the payload")
    progOpts.add_argument("--clip", action="store_true", help="Copy the payload to clipboard")
    progOpts.add_argument("--test", action="store_true", help="Test the payload after running it. Not compatible with -q")

    # NOTE: every --no-* option below uses action="store_false", so its
    # attribute on `args` is True by default and False when the flag is passed.
    obOpts = parser.add_argument_group("Obfuscation Options")
    obOpts.add_argument("-s", "--payload-size", default=2, type=int, choices=range(1, 4), help="Desired size of the payload. Default: 2")
    obOpts.add_argument("-t", "--execution-time", default=2, type=int, choices=range(1, 4), help="Desired speed of the payload. Default: 2")
    obOpts.add_argument("--layers", type=check_positive, help="Number of layers of obfuscation to apply. Default is 1 when --choose-mutators is used, otherwise: 2")
    binOpts = obOpts.add_mutually_exclusive_group()
    binOpts.add_argument("--include-binaries", nargs="+", type=str, metavar="BINARIES", help="Binaries you exclusively want used in the generated payload")
    binOpts.add_argument("--exclude-binaries", nargs="+", type=str, metavar="BINARIES", help="Binaries you don't want to be used in the generated payload")
    obOpts.add_argument("--no-file-write", action="store_false", help="Don't use obfuscators that require writing to files")
    obOpts.add_argument("--write-dir", default="/tmp/", type=str, help="Directory to use if Mutators need to write to or create files")

    advancedOpts = parser.add_argument_group("Advanced Options")
    chooseOpts = advancedOpts.add_mutually_exclusive_group()
    chooseOpts.add_argument("--choose-mutators", nargs="+", metavar="MUTATOR", help="Manually choose what mutators are used in what order").completer = getMutators
    chooseOpts.add_argument("--choose-all", nargs="+", metavar="MUTATOR", help="Manually choose what mutators and their stubs if applicable").completer = getMutatorsAndStubs
    advancedOpts.add_argument("--no-mangling", action="store_false", help="Don't perform binary mangling and don't insert random whitespace and characters")
    advancedOpts.add_argument("--no-binary-mangling", action="store_false", help="Don't obfuscate binary/builtin names")
    advancedOpts.add_argument("--binary-mangle-percent", type=int, choices=range(1, 101), metavar="{1..100}", help="Percentage of chars of binary/builtin names to mangle")
    advancedOpts.add_argument("--no-random-whitespace", action="store_false", help="Don't randomly add whitespace")
    advancedOpts.add_argument("--random-whitespace-range", type=str, metavar="NUM,NUM", help="The range of random whitespace to add")
    advancedOpts.add_argument("--no-insert-chars", action="store_false", help="Don't randomly add variables and characters that Bash ignores")
    advancedOpts.add_argument("--insert-chars-range", type=str, metavar="NUM,NUM", help="The range of random variables and characters to add")
    advancedOpts.add_argument("--no-misleading-commands", action="store_false", help="Don't randomly insert misleading commands that do nothing. NOT YET IMPLEMENTED")
    advancedOpts.add_argument("--misleading-commands-range", type=str, metavar="NUM,NUM", help="The range of random misleading commands to add. NOT YET IMPLEMENTED")
    advancedOpts.add_argument("--no-integer-mangling", action="store_false", help="Don't expand integers into arithmetic expressions with multiple number bases")
    advancedOpts.add_argument("--no-integer-expansion", action="store_false", help="Don't expand integers into arithmetic expressions")
    advancedOpts.add_argument("--no-integer-base-randomization", action="store_false", help="Don't randomize the number bases of integers")
    advancedOpts.add_argument("--integer-expansion-depth", type=check_positive, help="Number of times to recursively expand integers into arithmetic expansions")
    advancedOpts.add_argument("--no-terminator-randomization", action="store_false", help="Don't randomize the symbols that terminate commands")
    advancedOpts.add_argument("--full-ascii-strings", action="store_true", help="Use the full ASCII character set (with a few exceptions) when generating random strings")

    debugOpts = parser.add_argument_group("Debugging Options")
    debugOpts.add_argument("--debug", action="store_true", help="Automatically terminate every command in the payload with a newline for easier payload debugging")

    # shell completion hook; must run before parse_args()
    autocomplete(parser)
    args = parser.parse_args()

    if args.list:
        listMutators()
        # sys.exit rather than the exit() helper, which only exists when the site module is loaded
        sys.exit(0)

    if args.quiet:
        activateQuietMode()

    checkArguments(args)

    if args.layers is None:
        if args.choose_mutators is not None or args.choose_all is not None:
            args.layers = 1
        else:
            args.layers = 2

    # binaryPref is (binary names, True to use only these / False to exclude these), or None for no preference
    if args.include_binaries is not None:
        args.binaryPref = (args.include_binaries, True)
    elif args.exclude_binaries is not None:
        args.binaryPref = (args.exclude_binaries, False)
    else:
        args.binaryPref = None

    # written in front of the payload when --outfile is used; replaced below
    # by the input script's own shebang line when it has one
    shebang = "#!/bin/bash\n"

    if args.stdin:
        args.command = sys.stdin.read()
        sys.stdin.close()

    elif args.file:
        with open(args.file, "rb") as infile:
            firstLine = infile.readline().decode("utf-8")
            if re.match(r"\s*#!", firstLine) is not None:
                # keep the script's shebang out of the command to obfuscate
                shebang = firstLine
            else:
                infile.seek(0)

            args.command = infile.read().decode("utf-8")

    if args.write_dir[-1:] != "/":
        args.write_dir += "/"

    obHandler = ObfuscationHandler(commandObfuscators, stringObfuscators, tokenObfuscators, encoders, compressors, args)
    try:
        payload = obHandler.generatePayload()
    except KeyboardInterrupt:
        # NOTE(review): the code below assumes printExitMsg terminates the
        # program; otherwise `payload` would be unbound. Confirm.
        printExitMsg("\nCtrl-C caught... stopping payload generation...")

    if not args.choose_mutators and not args.choose_all:
        # join instead of appending " -> " and slicing it off, which cut into
        # the label itself when the mutator list was empty
        mutatorChain = " -> ".join(
            f"{mutator.mutatorType.title()}/{mutator.name}" for mutator in obHandler.mutatorList
        )
        printInfo(f"Mutators used: {mutatorChain}")

    if args.outfile:
        with open(args.outfile, "wb") as outfile:
            outfile.write(shebang.encode("utf-8"))
            outfile.write(payload.encode("utf-8"))

        printInfo(f"Payload written to {args.outfile}")

    if args.clip:
        pyperclip.copy(payload)
        printInfo("Payload copied to clipboard")

    printPayload(payload, args)

    if args.test:
        try:
            printInfo("Testing payload:\n")
            if args.outfile:
                # run the written file so the shebang-prefixed output is what gets tested
                proc = Popen(["bash", args.outfile], stdout=PIPE, stderr=STDOUT, shell=False, universal_newlines=True)
            else:
                proc = Popen(payload, executable="bash", stdout=PIPE, stderr=STDOUT, shell=True, universal_newlines=True)

            stdout, __ = proc.communicate()
            print(stdout)
        except KeyboardInterrupt:
            printExitMsg("\nCtrl-C caught... stopping payload testing...")
