#!/usr/bin/env python3
# This script has been tested with Python 3.9.7, but should be compatible with most Python 3 versions
# and easily portable to Python 2.7 if needed.
from __future__ import print_function, with_statement, unicode_literals
import binascii
import glob
import json
import os
import shutil
import time
import urllib
from ssl import SSLContext
from urllib.parse import urlencode
from urllib.request import urlopen, Request

import zipfile

import subprocess

# Connection settings shared by every request helper; main() fills them in from the command line
HOST, AUTH_TOKEN = None, None


def fetch_raw(url, query_params=None, data=None, content_type=None, method=None):
    """
    Send an authenticated request to the server and return the raw urllib response

    :param url: path appended verbatim to HOST (e.g. "/api/v1/machine/")
    :param query_params: optional mapping of query-string parameters. Values are URL-encoded
                         here, so pass them unescaped
    :param data: optional request body (bytes)
    :param content_type: optional Content-Type header value describing ``data``
    :param method: optional HTTP method override; without it urllib sends GET, or POST when
                   ``data`` is given
    :return: the response object from urlopen(); the caller must read and close it
    :raises urllib.error.HTTPError: if the server answers with an error status
    """
    import ssl

    full_url = "{}{}".format(HOST, url)
    if query_params:
        full_url += "?" + urlencode(query_params)

    print("Fetching {}".format(full_url))

    req = Request(full_url, data=data)
    req.add_header("Authorization", "Bearer {}".format(AUTH_TOKEN))
    if content_type:
        req.add_header("Content-Type", content_type)

    if method:
        req.method = method

    # NOTE: certificate and hostname verification are deliberately DISABLED here (presumably so
    # servers with self-signed certificates work). This is the same behavior the old bare
    # SSLContext() gave, spelled out explicitly to avoid its DeprecationWarning on Python 3.10+.
    # Use ssl.create_default_context() instead if the server has a trusted certificate.
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ssl_context.check_hostname = False  # must be turned off before verify_mode can be CERT_NONE
    ssl_context.verify_mode = ssl.CERT_NONE
    return urlopen(req, context=ssl_context)


def get(url, query_params=None):
    """
    Send a GET request and parse the response body as JSON

    :param url: path relative to HOST
    :param query_params: optional mapping of query-string parameters (passed unescaped)
    :return: the decoded JSON value, or None when the server sent a 204 or an empty body
    """
    # Close the response as soon as it is read so connections are not leaked
    with fetch_raw(url, query_params=query_params) as resp:
        body = resp.read()
        status = resp.status
    # Ideally an empty result would always be a 204, but the server doesn't send one reliably,
    # so an empty body is treated the same way
    if status != 204 and body:
        return json.loads(body)
    return None


def post(url, data=None, content_type=None):
    """
    Send a POST request and parse the response body as JSON

    :param url: path relative to HOST
    :param data: optional request body (bytes)
    :param content_type: optional Content-Type header value describing ``data``
    :return: the decoded JSON value, or None when the server sent a 204 or an empty body
    """
    # Close the response as soon as it is read so connections are not leaked
    with fetch_raw(url, data=data, content_type=content_type, method="POST") as resp:
        body = resp.read()
        status = resp.status
    # Ideally an empty result would always be a 204, but the server doesn't send one reliably,
    # so an empty body is treated the same way
    if status != 204 and body:
        return json.loads(body)
    return None


def wait_task_completion(task_id, timeout=None, poll_interval=2):
    """
    Poll the task-status endpoint until the given background task finishes

    :param task_id: ID of the server-side background task to watch
    :param timeout: optional maximum number of seconds to keep polling; None (the default)
                    waits indefinitely
    :param poll_interval: seconds to sleep between status checks
    :return: the final task status dict returned by the server
    :raises ValueError: if the server reports that the task failed (its traceback is printed)
    :raises TimeoutError: if ``timeout`` is given and the task is still running when it expires
    """
    print("Waiting for task {} to complete".format(task_id))
    # monotonic() is immune to wall-clock adjustments while we wait
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        resp = get("/api/v1/tasks/{}/".format(task_id))
        if resp["failed"]:
            print("Background task failed! Exception:")
            print(resp["traceback"])
            raise ValueError("Background task failed")
        elif resp["complete"]:
            return resp
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                "Task {} did not complete within {} seconds".format(task_id, timeout)
            )
        time.sleep(poll_interval)


def find_machine(client_pk, machine_name):
    """
    Find a single machine within a client by its host name

    Other name fields (e.g. long name, short name, FQDN) could presumably be matched by changing
    the query parameter below; confirm the parameter names against the API.

    :param client_pk: ID of the client the machine belongs to
    :param machine_name: host name of the machine as saved on the server
    :return: the machine record (a dict containing at least "pk")
    :raises ValueError: if no machine, or more than one machine, matches
    """
    # Pass raw values: fetch_raw() URL-encodes the query string itself, so quoting them here as
    # well would double-encode any special characters (e.g. "a b" -> "a%2520b")
    resp = get("/api/v1/machine/", query_params=dict(
        client_id=client_pk,
        host_name=machine_name,
    ))
    if not resp["results"]:
        raise ValueError("No machine with name {} found".format(machine_name))
    if len(resp["results"]) > 1:
        raise ValueError("More than one machine with name {} found".format(machine_name))
    return resp["results"][0]


def get_script(machine_pk):
    """
    Have the server build the collection script for a machine, then download it

    :param machine_pk: primary key of the machine (as returned by find_machine)
    :return: path to the downloaded script ZIP ("script.zip" in the current directory)
    :raises ValueError: if the server-side build task fails
    """
    # The script endpoint returns a background task ID; the finished task carries the download path
    task_id = get("/api/v1/machine/{}/script/".format(machine_pk))
    resp = wait_task_completion(task_id)

    # Download it to the current folder
    download_path = "script.zip"
    print("Saving script to {}".format(download_path))
    # Issue the request before opening the file so a failed download doesn't leave an empty
    # ZIP behind, close the response afterwards, and stream to disk rather than buffering it all
    with fetch_raw(resp["download"]) as file_resp, open(download_path, "wb") as f:
        shutil.copyfileobj(file_resp, f)

    return download_path


def extract_and_run_script(script_path):
    """
    Extract a collection-script ZIP into ./extracted, run it with bash, and locate its results

    We assume we're already root, so the script is run directly without privilege escalation.

    :param script_path: path to the script ZIP (as returned by get_script)
    :return: path to the single .xylok results file found in the extraction folder
    :raises ValueError: if no .xylok file, or more than one, exists after the script runs
    """
    extracted_path = "extracted"
    # Start from an empty folder so results left over from a previous run are never picked up
    if os.path.exists(extracted_path):
        shutil.rmtree(extracted_path)
    os.mkdir(extracted_path)

    print("Extracting script to '{}'".format(extracted_path))
    with zipfile.ZipFile(script_path, 'r') as f:
        f.extractall(extracted_path)

    sh_path = os.path.join(extracted_path, "xylok-collect.sh")
    # flush=True so the progress line and its dots appear immediately instead of only when
    # the terminal line is finished
    print("Running {}".format(sh_path), end="", flush=True)
    proc = subprocess.Popen(["/bin/bash", sh_path])
    while proc.poll() is None:
        time.sleep(1)
        print(".", end="", flush=True)
    print(" done")
    if proc.returncode != 0:
        # Not fatal: the script may still have written usable results, but make it visible
        print("Warning: {} exited with status {}".format(sh_path, proc.returncode))

    # Find results path
    result_files = glob.glob(os.path.join(extracted_path, "*.xylok"))
    if not result_files:
        raise ValueError("Unable to find .xylok results file")
    if len(result_files) > 1:
        raise ValueError("More than one .xylok results file")
    return result_files[0]


def upload_results(results_path):
    """
    Upload a results file to the server and wait for the server to process it

    :param results_path: path to the .xylok results file produced by the collection script
    :return: final status dict of the server-side processing task
    :raises ValueError: if the server returns no task ID or more than one, or the task fails
    """
    def encode_multipart_formdata(fields):
        """
        Build a multipart/form-data body in which every field is sent as a file part

        :param fields: mapping of field name -> raw file contents (bytes)
        :return: tuple of (body bytes, Content-Type header value)
        """
        boundary = binascii.hexlify(os.urandom(16)).decode('ascii')
        # RFC 2046: each delimiter line is "--" followed by exactly the boundary that is
        # declared in the Content-Type header; the closing delimiter adds a trailing "--"
        delimiter = "--{}".format(boundary).encode("ascii")

        body_parts = []
        for name, data in fields.items():
            part_headers = (
                'Content-Disposition: form-data; name="{name}"; filename="{name}"\r\n'
                'Content-Type: application/octet-stream\r\n'
            ).format(name=name)
            # File contents are appended as raw bytes (never decoded) so binary or
            # non-UTF-8 results are transmitted unchanged
            body_parts.extend([
                delimiter, b"\r\n",
                part_headers.encode("utf-8"), b"\r\n",
                data, b"\r\n",
            ])

        body_parts.extend([delimiter, b"--\r\n"])
        body = b"".join(body_parts)

        content_type = "multipart/form-data; boundary={}".format(boundary)
        return body, content_type

    with open(results_path, "rb") as f:
        result_contents = f.read()

    data, content_type = encode_multipart_formdata({"file": result_contents})
    # The upload endpoint answers with a list of background task IDs; exactly one is expected
    task_ids = post("/api/v1/upload/", data=data, content_type=content_type)
    if not task_ids:
        raise ValueError("No task IDs returned from upload")
    if len(task_ids) != 1:
        raise ValueError("More than one task returned from upload")

    task_id = task_ids[0]
    return wait_task_completion(task_id)


def get_machine_scans(machine_pk):
    """
    Return the primary keys of (up to) the two newest scans of a machine

    The list is in the server's order, which is expected to be newest first, so element 0 is
    the latest scan and element 1 (if present) the one before it.

    :param machine_pk: primary key of the machine
    :raises ValueError: if the machine has no scans at all
    """
    print("Fetching most recent scans for machine {}".format(machine_pk))
    scans_url = "/api/v1/machine/{}/scans/".format(machine_pk)
    page = get(scans_url, query_params={"limit": 2})
    scans = page["results"]
    if not scans:
        raise ValueError("No scans found for machine")
    return [entry["pk"] for entry in scans]


def copy_scan_answers(old_scan_pk, new_scan_pk):
    """
    Ask the server to copy interview/user-editable answers from one scan onto another

    :param old_scan_pk: scan whose answers are copied
    :param new_scan_pk: scan that receives the answers
    :return: the parsed JSON body of the POST response, or None if it was empty
    """
    print("Copying answers from {} to {}".format(old_scan_pk, new_scan_pk))
    endpoint = "/api/v1/scans/{src}/copy-answers-to/{dst}/".format(
        src=old_scan_pk,
        dst=new_scan_pk,
    )
    return post(endpoint)


def aa_scan(scan_pk):
    """
    Start automatic analysis (AA) on a scan

    :param scan_pk: primary key of the scan to analyze
    :return: the parsed JSON body of the execute endpoint's response; main() treats it as a
             background task ID and passes it to wait_task_completion()
    """
    print("Running AA on scan {}".format(scan_pk))
    return post("/api/v1/scans/{}/aa/execute/".format(scan_pk))


def main():
    """
    Command-line entry point: scan one machine end to end and run automatic analysis (AA) on it

    Steps:
    - find the machine;
    - download and run its collection script;
    - upload the results;
    - copy answers forward from the previous scan (if there is one);
    - run AA on the new scan.

    :return: process exit status for sys.exit(): 0 if AA reported success, 1 otherwise
    """
    global HOST
    global AUTH_TOKEN
    import argparse

    parser = argparse.ArgumentParser(description='Automatically scan a machine')
    parser.add_argument('--client', required=True, help='ID of client')
    parser.add_argument('--host-name', required=True, help='hostname of machine to scan (as saved in Xylok)')
    parser.add_argument('--api-server', required=True, help='Xylok server base URL (ie, http://localhost)')
    parser.add_argument('--api-token', required=True, help='API token for accessing server')

    args = parser.parse_args()

    # The API paths used in this script all start with "/", so drop any trailing slash from the
    # base URL to avoid requesting "...//api/..."
    HOST = args.api_server.rstrip("/")
    AUTH_TOKEN = args.api_token

    # --client is the ID of the client the machine belongs to. You can get it from the URL of
    # the client details page
    machine = find_machine(args.client, args.host_name)
    machine_pk = machine["pk"]

    zip_path = get_script(machine_pk)
    results_path = extract_and_run_script(zip_path)

    # You could probably save the redirect URL this gives you and parse out the scan ID, but
    # instead we're going to just find the most recent scans for our machine
    upload_results(results_path)

    scan_pks = get_machine_scans(machine_pk)
    new_scan_pk = scan_pks[0]
    if len(scan_pks) >= 2:
        old_scan_pk = scan_pks[1]
        copy_scan_answers(old_scan_pk, new_scan_pk)

    aa_task = aa_scan(new_scan_pk)
    aa_task_result = wait_task_completion(aa_task)

    if aa_task_result["success"]:
        print("Scan run, uploaded, and automatically analyzed successfully")
        return 0

    print("Scan AA failed:")
    print(aa_task_result)
    # A non-zero exit status lets schedulers / CI pipelines detect the failure
    return 1


if __name__ == "__main__":
    # Equivalent to sys.exit(main()): main()'s return value becomes the process exit status
    raise SystemExit(main())
