#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Author: Quentin Kaiser <kaiserquentin@gmail.com>
#
# let's disable 'invalid constant name' shenanigans
# pylint: disable=invalid-name
"""
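Cottontail command line entry point: capture all RabbitMQ messages being sent
through a broker, using the rabbitmq_management HTTP API to discover vhosts,
queues and exchanges.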
"""
import multiprocessing
import signal
import logging
try:
    from urlparse import urlparse
except ImportError:
    from urllib.parse import urlparse
import argparse
import socket
from uuid import uuid4
import re
import pika
import coloredlogs
import time
import verboselogs
from cottontail.rabbitmq_management import *
from cottontail import __version__, __author__, __email__

# ASCII-art banner printed when the script starts. The three positional
# placeholders are filled, in order, with the package version, author and
# contact e-mail imported from the cottontail package.
_HEADER_TEMPLATE = r"""
       /\ /|
       \ V/
       | "")     Cottontail v{}
       /  |      {} ({})
      /  \\
    *(__\_\)
    """
HEADER = _HEADER_TEMPLATE.format(__version__, __author__, __email__)

def subproc(host, port, ssl, username, password, vhost_name):
    """
    Function that is launched within a process. We launch one process per
    vhost, each using a blocking connection.

    The function consumes from every readable queue of the vhost and binds
    exclusive server-named queues to every readable exchange, logging each
    message it receives. Messages received through the default exchange
    (empty exchange name) are re-published with a unique marker header so
    that other consumers still get them.

    It relies on two module-level names, ``logger`` and ``rbmq``, that are
    created in the ``__main__`` block of this file; they are therefore only
    available in worker processes that inherit the parent's globals.

    Args:
        host (str): AMQP server hostname or IP address
        port (int): AMQP server listening port
        ssl (bool): indicates if AMQP over SSL
        username (str): AMQP credentials
        password (str): AMQP credentials
        vhost_name (str): vhost name to which our rabbitmq connection binds to

    Returns:
        None. Returns once consuming stops (user hits ctrl-c) or when an
        error occurs; errors are logged, never raised to the caller.
    """
    logger.verbose("Connecting to amqp{}://{}:{}/{}".format(
        "s" if ssl else "", host, port, vhost_name))

    ssl_options = None
    if ssl:
        # the 'ssl' parameter shadows the stdlib module, hence the alias
        import ssl as ssl_module
        # pika >= 1.0 only accepts an SSLOptions instance (or None) here.
        # Certificate and hostname validation are deliberately disabled,
        # matching the CERT_NONE setting this code always used.
        context = ssl_module.create_default_context()
        context.check_hostname = False
        context.verify_mode = ssl_module.CERT_NONE
        ssl_options = pika.SSLOptions(context)

    # bound before the try block so that cleanup can tell whether the
    # connection was ever established
    connection = None
    try:
        credentials = pika.PlainCredentials(username, password)
        parameters = pika.ConnectionParameters(
            host=host, port=port, virtual_host=vhost_name,
            credentials=credentials, ssl_options=ssl_options)
        connection = pika.BlockingConnection(parameters)
        channel = connection.channel()

        # setup ctrl-c handling
        def term(*args, **kwargs):
            """ Called when user hit ctrl-c, gracefully close AMQP connection.
            """
            logger.info("Closing connection to vhost '{}'".format(vhost_name))
            channel.stop_consuming()
        signal.signal(signal.SIGINT, term)

        # random header name used to recognize messages we re-published
        unique_header = str(uuid4())
        permissions = rbmq.get_permissions(vhost_name, username)

        def callback(ch, method, properties, body):
            """
            Callback function called when we receive a message from RabbitMQ.

            Logs the message, then re-publishes it when it came through the
            default exchange and we are allowed to write to it.

            Todo:
                * better handling of body depending on mime type ?

            Args:
                ch (object): pika channel object
                method (object): RabbitMQ message method
                properties (object): RabbitMQ message properties
                body (object): RabbitMQ message body

            Returns:
                None
            """
            # NOTE(review): consumers are registered with auto_ack=False but
            # nothing here acknowledges the delivery - confirm this is
            # intended.
            received_headers = properties.headers
            if received_headers is not None and unique_header in received_headers:
                # this is a message we requeued ourselves, stay silent
                return

            logger.info(
                "Message from [vhost={}][exchange={}][routing_key={}]: {}".format(
                    vhost_name, method.exchange, method.routing_key, body)
            )

            logger.debug("\tContent-type: {}".format(properties.content_type))
            logger.debug("\tContent-encoding: {}".format(properties.content_encoding))
            logger.debug("\tHeaders: {}".format(properties.headers))
            logger.debug("\tDelivery-mode: {}".format("persistent" \
                if properties.delivery_mode == 2 else "non persistent"))
            logger.debug("\tPriority: {}".format(properties.priority))
            logger.debug("\tCorrelation-id: {}".format(properties.correlation_id))
            logger.debug("\tReply-to: {}".format(properties.reply_to))
            logger.debug("\tExpiration: {}".format(properties.expiration))
            logger.debug("\tMessage-id: {}".format(properties.message_id))
            logger.debug("\tTimestamp: {}".format(properties.timestamp))
            logger.debug("\tType: {}".format(properties.type))
            logger.debug("\tUser-id: {}".format(properties.user_id))
            logger.debug("\tApp-id: {}".format(properties.app_id))
            logger.debug("\tCluster-id: {}".format(properties.cluster_id))

            # if other consumers are present, we requeue the message so we don't
            # mess things up.
            # NOTE: we only requeue messages whose exchange name is empty
            # (default exchange); messages received through a named exchange
            # reached us via our own exclusive queue and need no requeue.
            if method.exchange:
                return

            # we insert our unique header, working on a copy so the received
            # properties are left untouched
            headers = dict(received_headers) if received_headers is not None else {}
            headers[unique_header] = 1

            # if user_id is set to a different value than the user we authenticated
            # with, we drop it so we don't trigger a "PRECONDITION_FAILED -
            # user_id property set to 'xxx' but authenticated user was 'xxx'"
            user_id = properties.user_id if properties.user_id == username else None

            # check 'write' permission on exchange name
            if re.match(permissions["write"], method.exchange) is None:
                logger.debug("Cannot write to queue, not requeuing")
                return

            ch.basic_publish(
                exchange=method.exchange,
                routing_key=method.routing_key,
                properties=pika.BasicProperties(
                    content_type=properties.content_type,
                    content_encoding=properties.content_encoding,
                    headers=headers,
                    delivery_mode=properties.delivery_mode,
                    priority=properties.priority,
                    correlation_id=properties.correlation_id,
                    reply_to=properties.reply_to,
                    expiration=properties.expiration,
                    message_id=properties.message_id,
                    timestamp=properties.timestamp,
                    type=properties.type,
                    user_id=user_id,
                    app_id=properties.app_id,
                    cluster_id=properties.cluster_id
                ),
                body=body,
            )

        for queue in rbmq.get_queues(vhost=vhost_name):
            # skip broker-reserved queues and queues we cannot read from
            if queue["name"].startswith("amq.") or\
                re.match(permissions["read"], queue["name"]) is None:
                continue

            if re.match(permissions["write"], "") is None:
                # if we declare the queue and start consuming without being
                # authorized to write, we will intercept messages that will
                # never reach their legitimate consumer given that we won't
                # be able to re-queue that message. Therefore, we do not
                # declare queues that we do not have write permissions on.
                # It's a design decision, feel free to remove those checks
                # if you feel like being aggressive today :)
                logger.warning(
                    "Not delaring queue [vhost={}][queue={}] due to "\
                    "missing write permissions."\
                    .format(vhost_name, queue["name"])
                )
                continue

            # don't try to bind on queues with exclusive lock set
            if "exclusive_consumer_tag" in queue:
                logger.warning(
                    "Not declaring queue [vhost={}][queue={}] due to "\
                    "exclusive lock.".format(vhost_name, queue["name"])
                )
                continue

            # the queue is re-declared with the attributes reported by the
            # management API.
            # NOTE(review): earlier comments called this declaration
            # "passive", but passive=True is not passed - confirm whether
            # the 'configure' permission is really not needed here.
            logger.info("Declaring queue [vhost={}][queue={}]".format(
                vhost_name, queue["name"]))
            channel.queue_declare(
                queue=queue["name"],
                durable=queue["durable"],
                auto_delete=queue["auto_delete"],
                exclusive=False,
                arguments=queue["arguments"]
            )
            channel.basic_consume(
                on_message_callback=callback,
                queue=queue["name"],
                exclusive=False,
                auto_ack=False
            )

        # bindings of the vhost, fetched at most once and only when needed
        bindings = None
        for exchange in rbmq.get_exchanges(vhost=vhost_name):
            # skip broker-reserved exchanges, the default exchange, exchanges
            # we cannot read from, and bail out if we would not be allowed to
            # read from the auto-generated ('amq.gen-...') queues we bind.
            if exchange["name"].startswith("amq.") or\
                exchange["name"] == '' or\
                re.match(permissions["read"], exchange["name"]) is None or\
                re.match(permissions["read"], "amq.gen-") is None:
                continue

            # the exchange is re-declared with the attributes reported by
            # the management API.
            # NOTE(review): earlier comments called this declaration
            # "passive", but passive=True is not passed - confirm.
            channel.exchange_declare(
                exchange=exchange["name"],
                exchange_type=exchange["type"],
                durable=exchange["durable"],
                internal=exchange["internal"],
                auto_delete=exchange["auto_delete"],
                arguments=exchange["arguments"]
            )
            if exchange["type"] == "direct":
                # direct exchanges: reuse the routing keys of existing bindings
                if bindings is None:
                    bindings = rbmq.get_bindings(vhost_name)
                routing_keys = [
                    binding["routing_key"] for binding in bindings
                    if binding["source"] == exchange["name"]
                ]
            else:
                routing_keys = ["#"]

            for routing_key in routing_keys:
                # an empty queue name asks the broker to generate one; the
                # queue is exclusive to this connection
                result = channel.queue_declare(exclusive=True, queue='')
                queue_name = result.method.queue
                logger.info(
                    "Binding queue [vhost={}][exchange={}][queue={}]"\
                    "[routing_key={}]".format(
                        exchange["vhost"],
                        exchange["name"],
                        queue_name,
                        routing_key
                    )
                )
                # we checked 'read' permission on exchange name so we can
                # go on.
                channel.queue_bind(
                    exchange=exchange["name"],
                    queue=queue_name,
                    routing_key=routing_key
                )
                # we checked 'read' permission on auto-generated queue
                # name format
                channel.basic_consume(
                    on_message_callback=callback,
                    queue=queue_name,
                    auto_ack=False
                )

        logger.warning(
            "[{}] Waiting for messages. To exit press CTRL+C".format(vhost_name)
        )
        # blocks until term() calls stop_consuming()
        channel.start_consuming()
    except pika.exceptions.ProbableAccessDeniedError:
        logger.warning("Access to vhost '{}' denied for user '{}'".format(
            vhost_name, username))
    except pika.exceptions.ConnectionClosed:
        logger.warning("Connection to vhost '{}' got closed.".format(vhost_name))
    except Exception as e:
        logger.error("An exception occured: {}".format(e))
    finally:
        # single cleanup point: only close a connection that was actually
        # established and is still open, whatever path led us here
        if connection is not None and connection.is_open:
            try:
                connection.close()
            except pika.exceptions.AMQPError as e:
                logger.debug("Error while closing connection to vhost '{}': {}".format(
                    vhost_name, e))

def init_worker():
    """
    Pool worker initializer: make the current process ignore SIGINT.

    Passed as the initializer of the multiprocessing pool created in the
    __main__ block, where KeyboardInterrupt is handled by the parent.

    Returns:
        None
    """
    ignore_handler = signal.SIG_IGN
    signal.signal(signal.SIGINT, ignore_handler)


if __name__ == "__main__":
    # Script entry point: parse the command line, query the
    # rabbitmq_management API, then either sniff every vhost over AMQP (one
    # worker process per vhost) or, when no AMQP listener is reachable, dump
    # pending messages through the HTTP API.
    # NOTE: 'logger' and 'rbmq' are module-level names that subproc() reads.

    print(HEADER)
    parser = argparse.ArgumentParser(description=\
        "Capture all RabbitMQ messages being sent through a broker.")
    parser.add_argument('url', type=str, help="rabbitmq_management URL")
    parser.add_argument('--username', type=str, default="guest",\
        help="rabbitmq_management username")
    parser.add_argument('--password', type=str, default="guest",\
        help="rabbitmq_management password")
    parser.add_argument('-v', '--verbose', help="increase output verbosity",\
        action='store_true')
    args = parser.parse_args()

    logger = verboselogs.VerboseLogger('cottontail')
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.VERBOSE)
    coloredlogs.install(
        fmt='%(asctime)s %(levelname)s %(message)s',
        logger=logger,
        level='debug' if args.verbose else 'verbose'
    )

    # an explicit port is required in the management URL
    o = urlparse(args.url)
    if o.port is None:
        raise Exception("Invalid URL provided.")

    rbmq = RabbitMQManagementClient(
        o.hostname,
        port=o.port,
        username=args.username,
        password=args.password,
        ssl=(o.scheme == "https")
    )

    try:
        overview = rbmq.get_overview()
        user = rbmq.whoami()
        vhosts = rbmq.get_vhosts()

        # Some useful information
        logger.verbose("Successful connection to '{}' as user '{}'".format(
            o.geturl(), user["name"]))
        logger.verbose("node: {}".format(overview["node"]))
        logger.verbose("version: RabbitMQ {}, Erlang {}".format(
            overview["rabbitmq_version"], overview["erlang_version"]))
        logger.verbose("{} vhosts detected: {}".format(
            len(vhosts), ", ".join([vhost["name"] for vhost in vhosts])))

        # pick the first AMQP listener that accepts a TCP connection
        rabbit_ip = socket.gethostbyname(o.hostname)
        amqp_listener = None
        for listener in rbmq.get_amqp_listeners():
            # if bound to all, we try to reach on rabbit_ip
            if listener["ip_address"] in ("::", "0.0.0.0"):
                listener["ip_address"] = rabbit_ip
            # we attempt only low level tcp connect here.
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                status = sock.connect_ex(
                    (listener["ip_address"], listener["port"]))
            except socket.error as e:
                # e.g. an address this IPv4 socket cannot resolve: treat the
                # listener as unreachable instead of aborting the whole run
                logger.debug("Cannot probe listener {}:{}: {}".format(
                    listener["ip_address"], listener["port"], e))
                status = -1
            finally:
                sock.close()
            # port is open, let's use that listener
            if status == 0:
                amqp_listener = listener
                break

        if amqp_listener is not None:
            # Launch one process per vhost. Pool() rejects a size of 0, so
            # keep at least one worker even when no vhost was returned.
            pool = multiprocessing.Pool(max(1, len(vhosts)), init_worker)
            try:
                for vhost in vhosts:
                    # we check the permissions and only launch a process if
                    # our user is allowed access to this specific vhost.
                    permissions = rbmq.get_permissions(
                        vhost["name"],
                        args.username
                    )
                    if permissions is None:
                        logger.warning(
                            "User '{}' is not authorized to access vhost '{}'"\
                            .format(args.username, vhost["name"])
                        )
                        continue
                    pool.apply_async(subproc, \
                        (amqp_listener["ip_address"], amqp_listener["port"],\
                        amqp_listener["protocol"] == "amqp/ssl",\
                        args.username, args.password, vhost["name"],))
                pool.close()
                pool.join()
            except KeyboardInterrupt:
                logger.info("Caught KeyboardInterrupt, terminating workers")
                pool.terminate()
                pool.join()
        else:
            logger.warning("AMQP listener not reachable."\
                " Dumping queues via HTTP API. Note that only messages that"\
                " haven't been consumed yet will be shown.")
            for vhost in vhosts:
                for queue in rbmq.get_queues(vhost["name"]):
                    # skip broker-reserved queues
                    if queue["name"].startswith("amq."):
                        continue
                    for message in rbmq.get_messages(vhost["name"],\
                            queue["name"], count=10000):
                        logger.info("Message from [vhost={}][exchange={}]"\
                                "[routing_key={}]: {}".format(
                                    vhost["name"], message["exchange"],
                                    message["routing_key"],
                                    message["payload"]))

    except UnauthorizedAccessException as e:
        # exceptions have no implicit 'message' attribute on Python 3, so
        # fall back to the string form when the class does not provide one
        logger.error(getattr(e, "message", None) or str(e))
