#!/usr/bin/env python3

import argparse
import base64
import gzip
import json
import pathlib
import shutil
import subprocess
import tempfile
from datetime import datetime
from pathlib import Path
from typing import List, Optional


def _get_all_k8s_pods(
    namespace: str,
    running_only: Optional[bool] = False,
    deployment_name: Optional[str] = None,
) -> List[str]:
    """Return the names of pods in *namespace*.

    Args:
        namespace: k8s namespace to query.
        running_only: when True, only pods in phase Running are listed.
        deployment_name: when given, keep only pods whose name starts with
            this prefix (kubectl pod names are "<deployment>-<hash>-<id>").

    Returns:
        List of pod names; empty if no pods match.

    Raises:
        subprocess.CalledProcessError: if kubectl exits non-zero.
    """
    pod_bytes = subprocess.check_output(
        [
            "kubectl",
            "-n",
            namespace,
            "get",
            "pods",
            "-o",
            "jsonpath={.items[*].metadata.name}",
        ]
        + (["--field-selector=status.phase==Running"] if running_only else [])
    )
    # Bare split() (instead of split(" ")) returns [] for empty output,
    # where split(" ") would return [""] — a bogus empty pod name.
    return [
        pod
        for pod in pod_bytes.decode("utf-8").split()
        if deployment_name is None or pod.startswith(deployment_name)
    ]


def get_command_for_pod_logs(namespace: str, pod_name: str) -> List[str]:
    """Build the kubectl argv that prints the logs of one pod."""
    target = f"pod/{pod_name}"
    return ["kubectl", "-n", namespace, "logs", target]


def _get_logs(
    deployment: str,
    running_pod: str,
    namespace: str,
    pod_logs: str,
    all_logs: bool = False,
) -> bool:
    """Copy volume logs from *running_pod* into *pod_logs*.

    Dispatches to the full-directory copy when ``all_logs`` is set,
    otherwise to the latest-session copy. Returns True on success.
    """
    fetch = _get_log_dir if all_logs else _get_latest_logs
    return fetch(deployment, running_pod, namespace, pod_logs)


def _get_latest_logs(
    deployment: str, running_pod: str, namespace: str, output_path: Path
) -> bool:
    """Copy the latest service log files (and latest Ray session logs) out of
    *running_pod*'s data volume into *output_path*.

    Returns True when all service log files copied; False on the first
    permission/copy failure. Ray session copies are best-effort (failures
    are ignored).
    """
    LOG_DIR = "/data/.logs"
    log_files_to_fetch = [
        "authorization_api.log",
        "engine-tiny.log",
        "engine.log",
        "flow-ui.log",
        "model-registry.log",
        "model-trainer.log",
        "notebook.log",
        "ray-head.log",
        "ray-worker.log",
        "ray-gpu-worker.log",
        "storage_api.log",
        "studio-api.log",
        "studio-ray-head.log",
        "studio-ray-worker.log",
        "tdm_api.log",
    ]

    # Resolve the "session_latest" symlink for each Ray head flavor so we
    # only copy the most recent session's logs, not every session.
    ray_logs = []
    for ray_head_dir in ("ray-head", "studio-ray-head"):
        resolved = subprocess.run(
            [
                "kubectl",
                "-n",
                namespace,
                "exec",
                running_pod,
                "--",
                "readlink",
                "-f",
                f"{LOG_DIR}/ray/{ray_head_dir}/session_latest",
            ],
            capture_output=True,
        )
        if resolved.returncode == 0:
            ray_logs.append(
                pathlib.PosixPath(resolved.stdout.decode().strip(), "logs")
            )

    # Extract latest log file
    for log_file in log_files_to_fetch:
        copy_result = subprocess.run(
            [
                "kubectl",
                "-n",
                namespace,
                "cp",
                f"{running_pod}:{LOG_DIR}/{log_file}",
                pathlib.Path(output_path, log_file).as_posix(),
            ],
            capture_output=True,
        )
        if copy_result.returncode != 0:
            print(
                f"Warning: user does not have sufficient permissions for pod: {running_pod}",
            )
            print(copy_result.stdout.decode())
            return False

    # Extract Ray Logs (best-effort; return code intentionally ignored)
    for ray_log in ray_logs:
        subprocess.run(
            [
                "kubectl",
                "-n",
                namespace,
                "cp",
                f"{running_pod}:{ray_log}",
                pathlib.Path(output_path, ray_log.relative_to(LOG_DIR)).as_posix(),
            ],
            capture_output=True,
        )
    print(
        f"Volume logs retrieved from deployment: {deployment}",
    )
    return True


def _get_log_dir(
    deployment: str, running_pod: str, namespace: str, output_path: Path
) -> bool:
    """Copy the entire log directory out of *running_pod* into *output_path*.

    Returns True on success; on any failure prints a warning and returns
    False (typically a permissions problem on the pod).
    """
    copy_command = _get_command_for_log_directory(running_pod, namespace, output_path)
    try:
        subprocess.check_call(
            copy_command,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.STDOUT,
        )
    except Exception as err:
        print(
            f"Warning: user does not have sufficient permissions for pod: {running_pod}",
        )
        print(err)
        return False
    print(
        f"Volume logs retrieved from deployment: {deployment}",
    )
    return True


def _get_command_for_log_directory(
    running_pod_with_volume: str, namespace: str, output_path: Path
) -> List[str]:
    """Build the kubectl argv that copies the whole /data/.logs tree from a
    pod mounting the data volume into *output_path*."""
    log_services_path = "/data/.logs"
    source = f"{running_pod_with_volume}:{log_services_path}"
    return ["kubectl", "-n", namespace, "cp", source, output_path.as_posix()]


def _describe_k8s_object(object: str, namespace: str, k8s_dir: Path) -> None:
    """Write `kubectl -n <namespace> describe <object>` output to
    k8s_dir/describe_<object>.log, best-effort.

    If kubectl cannot be run at all, the error text is written to the log
    file instead of raising.
    NOTE: the parameter name `object` shadows the builtin but is kept for
    interface compatibility.
    """
    describe_log = k8s_dir / f"describe_{object}.log"
    with open(describe_log, "w") as out:
        try:
            subprocess.run(
                ["kubectl", "-n", namespace, "describe", object],
                stdout=out,
                stderr=subprocess.STDOUT,
            )
        except Exception as err:
            out.write(f"Could not describe {object}\n")
            out.write(str(err))


def _get_helm_secret_names(namespace: str) -> Optional[List[str]]:
    """Return the names of helm release secrets in *namespace*.

    Returns None (after printing a warning) when the kubectl call fails.
    """
    get_secrets_command = [
        "kubectl",
        "-n",
        namespace,
        "get",
        "secret",
        "--field-selector",
        "type=helm.sh/release.v1",
        "-o",
        "jsonpath={.items[*].metadata.name}",
    ]
    try:
        raw_names = subprocess.check_output(
            get_secrets_command, stderr=subprocess.DEVNULL
        )
    except Exception:
        print("Failed to get helm secrets")
        return None
    return raw_names.decode().split()


def _get_helm_secret(
    output_file: pathlib.Path, namespace: str, release_secret: str
) -> None:
    """Fetch one helm release secret and write its decoded release JSON to
    *output_file*, best-effort (prints a warning on any failure)."""
    get_secret_command = [
        "kubectl",
        "-n",
        namespace,
        "get",
        "secret",
        release_secret,
        "-o",
        "jsonpath={.data.release}",
    ]
    try:
        encoded = subprocess.check_output(get_secret_command, stderr=subprocess.DEVNULL)
        # Double base64: kubectl jsonpath returns secret data base64-encoded,
        # and helm itself stores the release as base64(gzip(json)).
        payload = gzip.decompress(base64.decodebytes(base64.decodebytes(encoded)))
        release = json.loads(payload)
        with output_file.open("w") as fp:
            json.dump(obj=release, fp=fp)
    except Exception:
        print(f"Failed to process helm secret {release_secret}")


def _get_helm_secrets_if_exists(
    base_dir: pathlib.Path, namespace: str, all_logs: bool = False
) -> None:
    """Dump decoded helm release secrets into base_dir/helm_releases/.

    Args:
        base_dir: bundle root; files are written as helm_releases/v<N>.json.
        namespace: k8s namespace (assumed to equal the helm release name).
        all_logs: when True, dump every release; otherwise only the last 3.
    """
    # technically release name... but this is usually the namespace for most installs
    release_secret_prefix = f"sh.helm.release.v1.{namespace}.v"
    release_secrets = _get_helm_secret_names(namespace=namespace)
    if not release_secrets:
        return

    # Filter by prefix BEFORE sorting/truncating: a secret from some other
    # helm release in the namespace would make int(...) raise ValueError in
    # the sort key, and truncating first could crowd out the secrets we want.
    release_secrets = [
        secret_name
        for secret_name in release_secrets
        if secret_name.startswith(release_secret_prefix)
    ]
    release_secrets.sort(
        key=lambda secret_name: int(secret_name[len(release_secret_prefix) :])
    )
    # If not fetching historical logs, get up to the last 3 helm releases
    if not all_logs:
        release_secrets = release_secrets[-3:]

    pathlib.Path(base_dir, "helm_releases").mkdir(exist_ok=True)
    for release_secret in release_secrets:
        output_file = pathlib.Path(
            base_dir,
            "helm_releases",
            f"v{release_secret[len(release_secret_prefix):]}.json",
        )
        _get_helm_secret(output_file, namespace, release_secret)


def main(namespace: str, helm_values: Optional[Path], all_logs: bool) -> None:
    """Gather a Snorkel support bundle and write it as a zip archive.

    Collects pod logs (current and previous), logs stored on the shared data
    volume, `kubectl describe` output for common object kinds, and decoded
    helm release secrets, then zips everything into
    ./snorkelflow_support_bundles/.

    Args:
        namespace: k8s namespace of the install.
        helm_values: optional path to the user's helm values file to include.
        all_logs: when True, collect historical volume logs and all helm
            releases instead of only the latest.
    """
    with tempfile.TemporaryDirectory() as zipdir:
        base_dir = pathlib.Path(zipdir) / "snorkelflow_support_bundle"
        base_dir.mkdir()
        logs_dir = base_dir / "logs"
        logs_dir.mkdir()

        pod_logs_dir = logs_dir / "pods"
        pod_logs_dir.mkdir()
        previous_pod_logs_dir = logs_dir / "pods_previous"
        previous_pod_logs_dir.mkdir()

        if helm_values is not None:
            shutil.copy(helm_values, base_dir)
            print("Copied helm values")
        else:
            print("No helm values to copy, skipped")

        # Get pod and previous pod logs
        # Each pod name is in the form "pod/{pod name}" i.e. pod/db-7575dd945c-8rdxb
        print("Retrieving k8s pod logs...")
        for pod in _get_all_k8s_pods(namespace):
            with open(pod_logs_dir / f"{pod}.log", "w") as f:
                subprocess.run(
                    get_command_for_pod_logs(namespace, pod),
                    stderr=subprocess.STDOUT,
                    stdout=f,
                )
            with open(previous_pod_logs_dir / f"{pod}-previous.log", "w") as f:
                subprocess.run(
                    get_command_for_pod_logs(namespace, pod) + ["--previous"],
                    stderr=subprocess.STDOUT,
                    stdout=f,
                )
            print("Retrieved pod logs from", pod)

        print("Finished retrieving k8s pod logs")
        # Get logs. Note Helm as of now only allows /data as the SNORKELFLOW_MOUNT_DIR

        # The existing support bundle does logic on getting the most recent session and only pulling
        # those logs, but for now we can just pull every session since we have to perform the
        # operations via the kubectl CLI. If we run into an issue where this is too big we can always
        # do the same here.
        pod_logs = logs_dir / "volume"
        pod_logs.mkdir()

        # Fixed typo: "depolyments" -> "deployments" (local name only)
        deployments_with_data_volume = [
            "authorization-api",
            "studio-api",
            "tdm-api",
            "engine",
            "flow-ui",
            "ray-head",
            "studio-ray-head",
        ]
        logs_from_volume_copied = False

        print("Retrieving logs from k8s volume...")
        # Only one successful copy is needed: all these deployments mount the
        # same data volume, so stop at the first pod that works.
        for deployment in deployments_with_data_volume:
            running_pods = _get_all_k8s_pods(
                namespace, running_only=True, deployment_name=deployment
            )

            for running_pod in running_pods:
                logs_from_volume_copied = _get_logs(
                    deployment, running_pod, namespace, pod_logs, all_logs
                )
                if logs_from_volume_copied:
                    break

            if logs_from_volume_copied:
                break

        if not logs_from_volume_copied:
            # Was an f-string with no placeholders; plain string is equivalent.
            print(
                "Warning: could not get k8s volume logs. Could not connect to a pod with access to the data volume."
            )

        k8s_dir = base_dir / "k8s"
        k8s_dir.mkdir()
        command = ["kubectl", "-n", namespace]
        print("Describing k8s objects...")
        with open(k8s_dir / "get_pods.log", "w") as f:
            subprocess.run(
                command + ["get", "pods"], stderr=subprocess.STDOUT, stdout=f
            )
        k8s_objects = [
            "pod",
            "node",
            "deployment",
            "service",
            "pvc",
            "ingress",
            "networkpolicy",
        ]
        for k8s_object in k8s_objects:
            _describe_k8s_object(k8s_object, namespace, k8s_dir)

        print("Described k8s objects")
        _get_helm_secrets_if_exists(base_dir, namespace, all_logs)

        # Not sure if this is strictly necessary, but given the explanation from the old
        # gather_support_bundle function in snorkel-install it's better safe than sorry.
        # (zip cannot store timestamps before 1980, so bump them to "now".)
        for p in Path(zipdir).rglob("*"):
            year = datetime.fromtimestamp(p.lstat().st_mtime).year
            if year < 1980:
                p.touch()

        # NOTE(review): the ":" characters in the timestamp are invalid in
        # Windows filenames — confirm whether Windows extraction matters
        # before changing the format.
        archive_name = "_".join(
            [
                "snorkelflow_support_bundle",
                datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
            ]
        )

        shutil.make_archive(archive_name, "zip", zipdir)
        bundles_dir = pathlib.Path("snorkelflow_support_bundles")
        bundles_dir.mkdir(exist_ok=True)
        shutil.move(archive_name + ".zip", bundles_dir)
        print(
            f"Written support bundle to snorkelflow_support_bundles/{archive_name}.zip"
        )


if __name__ == "__main__":
    # CLI entry point: parse arguments and hand off to main().
    arg_parser = argparse.ArgumentParser(
        description="Generate a Snorkel support bundle."
    )
    arg_parser.add_argument(
        "--namespace",
        type=str,
        required=True,
        help="Namespace of the Snorkel AI Data Development Platform install",
    )
    arg_parser.add_argument(
        "--helm-values",
        type=Path,
        required=False,
        default=None,
        help="Path to user helm Values.yaml.",
    )
    arg_parser.add_argument(
        "--all", help="collect historical logs", action="store_true"
    )
    parsed_args = arg_parser.parse_args()
    main(parsed_args.namespace, parsed_args.helm_values, parsed_args.all)
