from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Iterable, Optional

import requests


# Ingress root used when Graph() is not given ``base_url``; packets are POSTed
# to f"{base_url}/ingest".
DEFAULT_BASE_URL: str = "http://graph-ingress:8081"
# Request timeout in seconds, used when neither Graph(timeout_seconds=...) nor
# a per-call ``timeout`` is supplied.
DEFAULT_TIMEOUT_SECONDS: float = 5


class HTTPStatusError(Exception):
    """Raised by ``Graph.send_packet`` when the ingress answers with a non-2xx status.

    Attributes:
        status_code: HTTP status code of the rejected request.
        body: Response body text, kept for diagnostics.
    """

    def __init__(self, status_code: int, body: str) -> None:
        super().__init__(f"failior: request rejected ({status_code})")
        self.status_code = status_code
        self.body = body

    def __reduce__(self) -> tuple:
        # BaseException's default reduce rebuilds the object as cls(*self.args).
        # Here args holds only the formatted message, so pickle/copy would call
        # HTTPStatusError(message) and fail with a TypeError. Rebuild from the
        # real constructor arguments instead, and keep any extra instance state.
        return (type(self), (self.status_code, self.body), self.__dict__)


@dataclass()
class PacketBatchResult:
    """Outcome of ``Graph.send_packet_batch``.

    ``accepted + failed`` equals the number of packets submitted, and
    ``errors`` holds one exception per failed packet.
    """

    accepted: int  # packets sent and acknowledged with a 2xx response
    failed: int  # packets whose send raised (validation, transport, or HTTP status)
    errors: list[Exception]  # exceptions behind ``failed``, in completion order


def _trim_trailing_slash(value: str) -> str:
    """Return *value* with every trailing ``/`` removed (``"a//"`` -> ``"a"``)."""
    trimmed = value
    while trimmed.endswith("/"):
        trimmed = trimmed[:-1]
    return trimmed


def _now_rfc3339() -> str:
    """Return the current UTC time as an ISO-8601 string carrying a ``+00:00`` offset."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()


class Tracker:
    """Collect node ids for one unit of work and report them in a single packet.

    Usually obtained from ``Graph.track()``. Call :meth:`node` for each node the
    work touched, then :meth:`end` once; later ``end`` calls are ignored. If no
    node was recorded, ``end`` sends nothing.
    """

    def __init__(self, graph: "Graph", timeout: Optional[float] = None) -> None:
        self._graph = graph
        # Per-request timeout forwarded to Graph.send_packet; None lets the
        # graph apply its own default.
        self._timeout = timeout
        self._nodes: list[str] = []
        self._ended = False

    def node(self, node_id: str) -> None:
        """Record *node_id* (surrounding whitespace stripped) for the final packet.

        Raises:
            ValueError: If *node_id* is empty, blank, or None.
        """
        value = (node_id or "").strip()
        if not value:
            raise ValueError("failior: node_id required")
        self._nodes.append(value)

    def end(self, err: Optional[BaseException] = None) -> None:
        """Send one packet covering every recorded node.

        Args:
            err: The failure that ended the work, or None on success. Its text
                becomes ``packet_msg``. If the text is blank, the exception type
                name is used instead.

        Exceptions raised by ``Graph.send_packet`` propagate. The tracker is
        marked ended before sending, so a failed send is not retried by a
        second ``end()`` call.
        """
        if self._ended:
            return
        self._ended = True
        if not self._nodes:
            return

        self._graph.send_packet(
            {
                "did_error": err is not None,
                "packet_msg": self._describe_error(err),
                "graph_id": self._graph.graph_id,
                "node_id_list": list(self._nodes),
                "timestamp": _now_rfc3339(),
            },
            timeout=self._timeout,
        )

    @staticmethod
    def _describe_error(err: Optional[BaseException]) -> str:
        """Return the packet message for *err*: "" for None, otherwise non-blank text."""
        if err is None:
            return ""
        text = str(err)
        # Error packets require a non-blank message. Exceptions raised without
        # arguments (e.g. ValueError()) stringify to "", which would make the
        # send fail validation and drop the error report entirely.
        return text if text.strip() else type(err).__name__


class Graph:
    """Client for a single failior graph.

    Validates outgoing packets and POSTs them as JSON to ``{base_url}/ingest``.
    A ``requests.Session`` created by the constructor belongs to this object
    and is released by :meth:`close`, or on leaving a ``with Graph(...)`` block.
    A caller-supplied session is never closed here.
    """

    def __init__(
        self,
        graph_id: str,
        *,
        base_url: str = DEFAULT_BASE_URL,
        ingress_key: str = "",
        timeout_seconds: float = DEFAULT_TIMEOUT_SECONDS,
        session: Optional[requests.Session] = None,
    ) -> None:
        """Validate and store connection settings.

        Args:
            graph_id: Graph identifier; surrounding whitespace is stripped.
            base_url: Ingress root URL; surrounding whitespace and trailing
                slashes are removed.
            ingress_key: Optional key, sent as the ``X-Ingress-Key`` header when
                non-blank.
            timeout_seconds: Default request timeout passed to ``requests``.
            session: Session to send requests with. A new one is created (and
                owned by this Graph) when omitted.

        Raises:
            ValueError: If ``graph_id`` or ``base_url`` is blank.
        """
        value = (graph_id or "").strip()
        if not value:
            raise ValueError("failior: graph_id required")
        # Strip whitespace before trimming slashes. Otherwise a blank-but-non-empty
        # base_url passes validation and yields a URL such as " /ingest".
        base = _trim_trailing_slash((base_url or "").strip())
        if not base:
            raise ValueError("failior: base_url required")
        self.graph_id = value
        self.base_url = base
        self.ingress_key = (ingress_key or "").strip()
        self.timeout_seconds = timeout_seconds
        # Remember whether we created the session, so close() never tears down
        # a session the caller still uses.
        self._owns_session = session is None
        self.session = session if session is not None else requests.Session()

    def close(self) -> None:
        """Close the HTTP session if this Graph created it; otherwise do nothing."""
        if self._owns_session:
            self.session.close()

    def __enter__(self) -> "Graph":
        return self

    def __exit__(self, *exc_info: Any) -> None:
        # Returns None, so exceptions raised inside the ``with`` block propagate.
        self.close()

    def track(self, *, timeout: Optional[float] = None) -> Tracker:
        """Start a :class:`Tracker` bound to this graph.

        ``timeout`` overrides ``timeout_seconds`` for the packet the tracker sends.
        """
        return Tracker(self, timeout=timeout)

    def inform(self, node_id: str, status: str, message: str = "", *, timeout: Optional[float] = None) -> None:
        """Report the status of a single node.

        Args:
            node_id: Node to report; surrounding whitespace is stripped.
            status: ``"ok"`` or ``"error"`` (surrounding whitespace ignored).
            message: Packet message; must be non-blank when status is ``"error"``.
            timeout: Request timeout; defaults to ``timeout_seconds``.

        Raises:
            ValueError: For an unknown status, a missing node id, or an error
                status without a message.
            HTTPStatusError: If the ingress answers with a non-2xx status.
        """
        status_value = (status or "").strip()
        if status_value not in ("ok", "error"):
            raise ValueError("failior: status must be ok or error")
        # Validate here, as Tracker.node does, so None is rejected instead of
        # being stringified into the node id "None" during normalization.
        node_value = "" if node_id is None else str(node_id).strip()
        if not node_value:
            raise ValueError("failior: node_id required")
        self.send_packet(
            {
                "did_error": status_value == "error",
                "packet_msg": message,
                "graph_id": self.graph_id,
                "node_id_list": [node_value],
                "timestamp": _now_rfc3339(),
            },
            timeout=timeout,
        )

    def inform_up(self, node_id: str, message: str = "", *, timeout: Optional[float] = None) -> None:
        """Report *node_id* as healthy (``status="ok"``)."""
        self.inform(node_id, "ok", message, timeout=timeout)

    def inform_error(self, node_id: str, message: str, *, timeout: Optional[float] = None) -> None:
        """Report *node_id* as failed; *message* must be non-blank."""
        self.inform(node_id, "error", message, timeout=timeout)

    def send_packet(self, packet: dict[str, Any], *, timeout: Optional[float] = None) -> None:
        """Validate *packet* and POST it as JSON to ``{base_url}/ingest``.

        Args:
            packet: Mapping with ``did_error``, ``packet_msg``, ``graph_id``,
                ``node_id_list`` and ``timestamp`` (see ``_normalize_packet``).
            timeout: Request timeout; defaults to ``timeout_seconds``.

        Raises:
            ValueError: If the packet fails validation.
            HTTPStatusError: If the response status is not 2xx.

        Transport errors raised by ``requests`` (timeouts, connection failures)
        propagate unchanged.
        """
        payload = self._normalize_packet(packet)
        headers = {"Content-Type": "application/json"}
        if self.ingress_key:
            headers["X-Ingress-Key"] = self.ingress_key
        used_timeout = timeout if timeout is not None else self.timeout_seconds
        response = self.session.post(
            f"{self.base_url}/ingest",
            json=payload,
            headers=headers,
            timeout=used_timeout,
        )
        if response.status_code // 100 != 2:
            raise HTTPStatusError(response.status_code, response.text)

    def send_packet_batch(
        self,
        packets: Iterable[dict[str, Any]],
        workers: int = 8,
        *,
        timeout: Optional[float] = None,
    ) -> PacketBatchResult:
        """Send *packets* concurrently and tally the outcome.

        Uses a thread pool of ``min(workers, len(packets))`` threads, with a
        minimum of 1. Each packet goes through :meth:`send_packet`. Any
        exception it raises (validation, transport or HTTP status) counts as a
        failure and is collected, so this method does not raise for individual
        packets. ``errors`` is in completion order, not input order.
        """
        packet_list = list(packets)
        if not packet_list:
            return PacketBatchResult(accepted=0, failed=0, errors=[])

        width = max(1, min(int(workers), len(packet_list)))
        errors: list[Exception] = []
        accepted = 0
        # NOTE(review): one requests.Session is shared by all worker threads.
        # requests does not document Session as thread-safe; confirm this is
        # acceptable for the deployed adapter configuration.
        with ThreadPoolExecutor(max_workers=width) as pool:
            futures = [pool.submit(self.send_packet, packet, timeout=timeout) for packet in packet_list]
            for future in as_completed(futures):
                try:
                    future.result()
                    accepted += 1
                except Exception as err:  # noqa: BLE001 - per-packet failures are reported, not raised
                    errors.append(err)
        return PacketBatchResult(accepted=accepted, failed=len(packet_list) - accepted, errors=errors)

    def _normalize_packet(self, packet: dict[str, Any]) -> dict[str, Any]:
        """Validate *packet* and return a clean, JSON-serializable copy.

        - ``graph_id`` falls back to this graph's id when missing or empty.
        - ``node_id_list`` must be a non-empty list or tuple. None items and
          blank ids are rejected; other items are converted with ``str`` and stripped.
        - ``packet_msg`` must be non-blank when ``did_error`` is truthy.
        - ``timestamp`` defaults to the current UTC time. A ``datetime`` is
          converted to UTC and ISO-formatted. Naive datetimes are interpreted as
          local time, per ``datetime.astimezone``.

        Raises:
            ValueError: If any of the rules above is violated.
        """
        graph_id = str(packet.get("graph_id") or self.graph_id).strip()
        if not graph_id:
            raise ValueError("failior: graph_id required")

        node_list_raw = packet.get("node_id_list")
        if not isinstance(node_list_raw, (list, tuple)) or not node_list_raw:
            raise ValueError("failior: node_id_list cannot be empty")
        # Map None to "" so it is rejected below instead of becoming the node id "None".
        node_list = ["" if item is None else str(item).strip() for item in node_list_raw]
        if not all(node_list):
            raise ValueError("failior: node_id_list contains empty node id")

        did_error = bool(packet.get("did_error"))
        packet_msg = str(packet.get("packet_msg") or "")
        if did_error and not packet_msg.strip():
            raise ValueError("failior: packet_msg required when did_error is true")

        timestamp = packet.get("timestamp")
        if isinstance(timestamp, datetime):
            # A raw datetime is not JSON-serializable; render it like _now_rfc3339 does.
            timestamp = timestamp.astimezone(timezone.utc).isoformat()
        elif not timestamp:
            timestamp = _now_rfc3339()

        return {
            "did_error": did_error,
            "packet_msg": packet_msg,
            "graph_id": graph_id,
            "node_id_list": node_list,
            "timestamp": timestamp,
        }


def load(graph_id: str, **kwargs: Any) -> Graph:
    """Construct a :class:`Graph` for *graph_id*.

    All keyword options (``base_url``, ``ingress_key``, ``timeout_seconds``,
    ``session``) are forwarded unchanged to the constructor.
    """
    graph = Graph(graph_id, **kwargs)
    return graph


def track(graph_id: str, *, timeout: Optional[float] = None, **kwargs: Any) -> Tracker:
    """Create a Graph for *graph_id* and return a new Tracker on it.

    Args:
        graph_id: Graph identifier passed to :func:`load`.
        timeout: Request timeout for the packet the tracker sends; ``None``
            uses the graph's ``timeout_seconds``. This mirrors ``Graph.track``.
        **kwargs: Keyword options forwarded to the :class:`Graph` constructor.

    Each call builds a fresh Graph and, unless ``session`` is passed, a fresh
    HTTP session.
    """
    return load(graph_id, **kwargs).track(timeout=timeout)
