#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import socket, time, json, statistics, sys, os, argparse, importlib
from datetime import datetime

# Hosts probed when --hosts is not given on the command line.
HOSTS_DEFAULT = ["wts.cigam.com.br", "wts2.cigam.com.br"]

def parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            an explicit list lets the parser be exercised in tests.

    Returns:
        argparse.Namespace with ``hosts``, ``port``, ``attempts``,
        ``timeout``, ``plot`` and ``outdir``.

    Exits with status 2 (via ``ArgumentParser.error``, message on stderr, so
    stdout stays JSON-only) when ``--port`` is outside 1-65535,
    ``--attempts`` is negative or ``--timeout`` is not positive. Without
    these checks every probe would fail silently: ``settimeout(0)`` makes
    the socket non-blocking, a negative timeout makes ``settimeout`` raise,
    and an out-of-range port makes ``connect`` raise -- all of which the
    connection probe swallows and reports as "no connection".
    """
    p = argparse.ArgumentParser(description="Teste de latência TCP com gráficos opcionais")
    p.add_argument("--hosts", nargs="+", default=HOSTS_DEFAULT, help="Hosts a testar")
    p.add_argument("--port", type=int, default=33038, help="Porta TCP (ex.: 33038, 443, 3389)")
    p.add_argument("--attempts", type=int, default=5, help="Nº de tentativas")
    p.add_argument("--timeout", type=float, default=2.0, help="Timeout por tentativa (s)")
    p.add_argument("--plot", type=int, default=1, help="Tentar gerar gráficos (1) ou não (0)")
    p.add_argument("--outdir", default=".", help="Diretório de saída p/ PNGs")
    args = p.parse_args(argv)
    if not 1 <= args.port <= 65535:
        p.error("--port deve estar entre 1 e 65535")
    if args.attempts < 0:
        p.error("--attempts não pode ser negativo")
    if args.timeout <= 0:
        p.error("--timeout deve ser maior que 0")
    return args

def resolve(host):
    """Resolve *host* with ``getaddrinfo`` (TCP) and time the lookup.

    Returns:
        On success ``{"ok": True, "ms": <lookup time in ms>, "ips": [...]}``,
        where ``ips`` lists each distinct address once, in the order
        ``getaddrinfo`` returned them.
        On failure ``{"ok": False, "error": <message>}``.
    """
    t0 = time.perf_counter()
    try:
        infos = socket.getaddrinfo(host, None, proto=socket.IPPROTO_TCP)
    except (OSError, ValueError) as e:
        # socket.gaierror is an OSError; malformed names (e.g. a label longer
        # than 63 chars) fail during IDNA encoding with UnicodeError, which is
        # a ValueError subclass.
        return {"ok": False, "error": str(e)}
    dt = (time.perf_counter() - t0) * 1000
    # dict.fromkeys de-duplicates while keeping the resolver's preference
    # order; a set would reorder the addresses differently on each run
    # because string hashing is randomized per process.
    ips = list(dict.fromkeys(info[4][0] for info in infos))
    return {"ok": True, "ms": dt, "ips": ips}

def try_connect(ip, port, timeout):
    """Open one TCP connection to ``(ip, port)``, close it, and time it.

    Args:
        ip: Literal IPv4 or IPv6 address; any ``":"`` selects ``AF_INET6``.
        port: TCP port number.
        timeout: Per-connection timeout in seconds, passed to ``settimeout``.

    Returns:
        Elapsed milliseconds (socket creation + connect + close), or ``None``
        if the connection could not be established.
    """
    family = socket.AF_INET6 if ":" in ip else socket.AF_INET
    t0 = time.perf_counter()
    try:
        # The with-block closes the socket on every path. Previously a failed
        # connect() skipped s.close(), leaving the descriptor to the garbage
        # collector (and triggering a ResourceWarning).
        with socket.socket(family, socket.SOCK_STREAM) as s:
            s.settimeout(timeout)
            s.connect((ip, port))
    except (OSError, ValueError, OverflowError):
        # OSError: refused / unreachable / timed out; ValueError: invalid
        # timeout; OverflowError: port outside 0-65535.
        return None
    return (time.perf_counter() - t0) * 1000

def _latency_stats(times, attempts):
    """Aggregate successful round times (ms) out of *attempts* rounds.

    Produces the statistic keys shared by the success and failure reports of
    ``test_host``; every statistic is ``None`` when *times* is empty.
    """
    return {
        "success": len(times),
        "success_rate": len(times) / attempts if attempts > 0 else 0,
        "min_ms": min(times) if times else None,
        "avg_ms": statistics.mean(times) if times else None,
        # Only reported with >= 20 samples: with fewer points the default
        # "exclusive" method of statistics.quantiles extrapolates past the
        # observed maximum.
        "p95_ms": statistics.quantiles(times, n=20)[18] if len(times) >= 20 else None,
        "max_ms": max(times) if times else None,
    }

def test_host(host, port, attempts, timeout):
    """Resolve *host* and measure TCP connect latency to its addresses.

    Each of the *attempts* rounds tries every resolved IP once, back to back;
    the round's value is its fastest successful connect, or ``None`` if every
    IP failed. Aggregate statistics cover the successful rounds only.

    Returns:
        A JSON-serialisable dict. The same statistic keys (``success``,
        ``success_rate``, ``min_ms``, ``avg_ms``, ``p95_ms``, ``max_ms``) are
        present whether or not DNS resolution succeeded; the failure report
        additionally carries ``error``.
    """
    res = resolve(host)
    if not res["ok"] or not res["ips"]:
        return {
            "host": host,
            "resolve_ok": False,
            "dns_ms": None,
            "ips": res.get("ips", []),
            "error": res.get("error", "DNS failed"),
            "attempts": attempts,
            # Empty statistics keep the schema identical to a successful run.
            **_latency_stats([], attempts),
            "round_times": [],
            "per_ip": {},
        }

    ips = res["ips"]
    per_ip_times = {ip: [] for ip in ips}
    round_times = []
    for _ in range(attempts):
        round_best = None
        for ip in ips:
            dt = try_connect(ip, port, timeout)
            if dt is None:
                continue
            per_ip_times[ip].append(dt)
            if round_best is None or dt < round_best:
                round_best = dt
        round_times.append(round_best)
    times = [t for t in round_times if t is not None]

    return {
        "host": host,
        "resolve_ok": True,
        "dns_ms": res["ms"],
        "ips": ips,
        "attempts": attempts,
        **_latency_stats(times, attempts),
        "round_times": round_times,
        "per_ip": {
            ip: {
                "success": len(v),
                "min_ms": min(v) if v else None,
                "avg_ms": statistics.mean(v) if v else None,
            }
            for ip, v in per_ip_times.items()
        },
    }

def _short(host: str) -> str:
    """Return the first dot-separated label of *host*.

    ``"wts2.cigam.com.br"`` becomes ``"wts2"``; a name without dots is
    returned whole, and falsy input (``""`` or ``None``) is returned as-is.
    """
    if not host:
        return host
    label, _sep, _rest = host.partition(".")
    return label

def compute_conclusion(results):
    """Compare the two fastest hosts and describe the gap between them.

    Hosts without ``min_ms`` (no successful connection) are ignored. The rest
    are ranked by ``min_ms``, with ties broken by ``avg_ms`` (a missing
    average sorts last).

    Returns:
        ``(text, diff)``: ``text`` is a short Portuguese sentence naming the
        winner by its first DNS label; ``diff`` holds ``winner``,
        ``runner_up``, ``d_min_ms`` and ``d_avg_ms`` (runner-up minus winner,
        rounded to 3 decimals; ``d_avg_ms`` is negative when the winner has
        the better minimum but the worse average, and ``None`` when either
        average is missing). ``(None, None)`` when fewer than two hosts have
        data.
    """
    ranked = sorted(
        (r for r in results if r.get("min_ms") is not None),
        key=lambda r: (r["min_ms"], r["avg_ms"] if r.get("avg_ms") is not None else float("inf")),
    )
    if len(ranked) < 2:
        return None, None
    best, second = ranked[0], ranked[1]

    # Both entries passed the min_ms filter, so this delta is always defined.
    d_min = second["min_ms"] - best["min_ms"]
    d_avg = None
    if best.get("avg_ms") is not None and second.get("avg_ms") is not None:
        d_avg = second["avg_ms"] - best["avg_ms"]

    name = _short(best["host"])
    if d_avg is None:
        text = f'{name} mais rápido por {d_min:.3f} ms no min'
    elif d_avg < 0:
        # Faster on min but slower on average: don't claim "faster by -X ms".
        text = f'{name} mais rápido por {d_min:.3f} ms no min, mas mais lento por {-d_avg:.3f} ms no avg'
    else:
        text = f'{name} mais rápido por {d_min:.3f} ms no min e {d_avg:.3f} ms no avg'

    diff_obj = {
        "winner": best["host"],
        "runner_up": second["host"],
        "d_min_ms": round(d_min, 3),
        "d_avg_ms": round(d_avg, 3) if d_avg is not None else None
    }
    return text, diff_obj

def make_plots(results, ts, outdir):
    """Render the latency charts as PNG files and return their paths.

    Args:
        results: Per-host dicts as produced by ``test_host``; this function
            reads ``host``, ``min_ms``/``avg_ms``/``max_ms``, ``round_times``,
            ``dns_ms`` and ``resolve_ok``.
        ts: Timestamp string embedded in every output file name.
        outdir: Output directory, created if missing.

    Returns:
        List of paths of the PNGs actually written (possibly empty). Each
        chart is skipped when it would have no data. Returns ``[]`` without
        drawing anything if ``outdir`` cannot be created.

    Errors raised by matplotlib itself (import, rendering, ``savefig``) are
    not caught here and propagate to the caller.
    """
    # Import matplotlib only here, and always headless (Agg backend), so the
    # rest of the script never depends on it and no display is required.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    plots = []
    try:
        os.makedirs(outdir, exist_ok=True)
    except Exception:
        # No permission (or any other failure) -> no plots.
        # NOTE(review): the cause is discarded, so the caller cannot report
        # why nothing was drawn.
        return plots

    # 1) Min/Avg/Max per host as grouped bars. Missing values are drawn as 0
    #    so every host keeps its slot on the x axis.
    labels, mins, avgs, maxs = [], [], [], []
    for r in results:
        labels.append(r["host"])
        mins.append(r["min_ms"] if r.get("min_ms") is not None else 0)
        avgs.append(r["avg_ms"] if r.get("avg_ms") is not None else 0)
        maxs.append(r["max_ms"] if r.get("max_ms") is not None else 0)

    # Only draw this chart when at least one value is positive.
    if any(x > 0 for x in mins + avgs + maxs):
        x = range(len(labels))
        width = 0.25
        plt.figure()
        # Three bars per host, offset by one bar width around the tick.
        plt.bar([i - width for i in x], mins, width=width, label="Min (ms)")
        plt.bar(x, avgs, width=width, label="Avg (ms)")
        plt.bar([i + width for i in x], maxs, width=width, label="Max (ms)")
        plt.xticks(list(x), labels)
        plt.ylabel("ms")
        plt.title("Latência TCP — Min/Avg/Max por Host")
        plt.legend()
        f1 = os.path.join(outdir, f"latency_summary_{ts}.png")
        plt.tight_layout(); plt.savefig(f1); plt.close()
        plots.append(f1)

    # 2) Best latency per attempt, one line per host. Failed rounds (None)
    #    are replaced by math.nan, presumably so matplotlib leaves a gap
    #    instead of failing on None.
    import math
    any_series = False
    plt.figure()
    for r in results:
        y = [(v if v is not None else math.nan) for v in r.get("round_times", [])]
        # Hosts with no successful round at all get no line.
        if any(v is not None for v in r.get("round_times", [])):
            plt.plot(range(1, len(y)+1), y, marker="o", label=r["host"])
            any_series = True
    if any_series:
        plt.xlabel("Tentativa"); plt.ylabel("ms"); plt.title("Melhor latência por tentativa"); plt.legend()
        f2 = os.path.join(outdir, f"attempts_{ts}.png")
        plt.tight_layout(); plt.savefig(f2); plt.close()
        plots.append(f2)
    else:
        # Nothing was plotted: discard the empty figure.
        plt.close()

    # 3) DNS resolution time per host (only hosts that resolved).
    labels_dns, dns_vals = [], []
    for r in results:
        if r.get("dns_ms") is not None and r.get("resolve_ok"):
            labels_dns.append(r["host"]); dns_vals.append(r["dns_ms"])
    if labels_dns:
        plt.figure()
        plt.bar(range(len(labels_dns)), dns_vals)
        plt.xticks(range(len(labels_dns)), labels_dns)
        plt.ylabel("ms"); plt.title("Tempo de resolução DNS")
        f3 = os.path.join(outdir, f"dns_{ts}.png")
        plt.tight_layout(); plt.savefig(f3); plt.close()
        plots.append(f3)

    return plots

def main():
    """Run the latency test described by the CLI options; print a JSON report.

    Steps: parse the options; enable plotting only for ``--plot 1`` when
    matplotlib can be imported; probe every host; pick the overall winner;
    build the textual conclusion; render PNGs if enabled; then print the
    whole report as a single JSON document on stdout.
    """
    opts = parse_args()

    # Plots are opt-in and degrade cleanly: a missing matplotlib disables
    # them and records the reason in the report.
    plots_on = False
    why_off = None
    if opts.plot == 1:
        try:
            importlib.import_module("matplotlib")
        except ImportError:
            why_off = "matplotlib não está instalado"
        else:
            plots_on = True

    results = [test_host(host, opts.port, opts.attempts, opts.timeout) for host in opts.hosts]

    # Overall winner: lowest min_ms, ties going to the higher success_rate.
    # min() returns the first of several equal keys, in input order.
    reachable = [r for r in results if r.get("min_ms") is not None]
    faster = (
        min(reachable, key=lambda r: (r["min_ms"], -r["success_rate"]))["host"]
        if reachable
        else None
    )

    # Human-readable comparison of the two fastest hosts plus raw deltas.
    conclusion_text, diff_obj = compute_conclusion(results)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    plots = []
    if plots_on:
        try:
            plots = make_plots(results, stamp, opts.outdir)
        except Exception as exc:
            # Plotting is best-effort: record the failure instead of crashing.
            # `plots` keeps its empty value because the assignment never ran.
            plots_on = False
            why_off = f"falha ao plotar: {exc.__class__.__name__}: {exc}"

    report = {
        "port": opts.port,
        "attempts": opts.attempts,
        "timeout_s": opts.timeout,
        "hosts": opts.hosts,
        "results": results,
        "faster": faster,
        "conclusion": conclusion_text,
        "diff_ms": diff_obj,
        "plots": plots,
        "plots_enabled": plots_on,
        "plots_disabled_reason": why_off,
    }
    # stdout must carry only the JSON document (no auxiliary prints).
    print(json.dumps(report, ensure_ascii=False, indent=2))

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
