#!/usr/bin/env python3
"""Utility to forward Codex messages to the logging daemon."""

from __future__ import annotations

import argparse
import json
import os
import sys
import urllib.error
import urllib.request
from typing import Any, Dict


def parse_args(argv: list[str]) -> argparse.Namespace:
    """Parse the command line describing a single turn to log.

    ``--endpoint`` falls back to ``$CODEX_LOG_ENDPOINT`` and then to the local
    daemon URL. ``--role`` and ``--text`` are mandatory. The remaining flags
    default to ``None`` when absent. The token-count flags are converted to
    ``int`` here but are not range-checked.
    """
    fallback_endpoint = os.environ.get("CODEX_LOG_ENDPOINT", "http://127.0.0.1:8421/ingest")
    cli = argparse.ArgumentParser(description="Send a single Codex turn to the logger.")
    cli.add_argument("--endpoint", default=fallback_endpoint)
    cli.add_argument("--role", required=True, choices=["user", "assistant", "system"])
    cli.add_argument("--text", required=True, help="Message content")

    # Free-form string options, registered in the order they appear in --help.
    string_options = (
        ("--session-id", "Session identifier"),
        ("--conversation-id", "Conversation identifier"),
        ("--channel", "Channel (cli, api, gui, etc.)"),
        ("--metadata", "JSON dict with extra metadata"),
        ("--token-usage", "JSON dict describing token usage"),
    )
    for flag, description in string_options:
        cli.add_argument(flag, help=description)

    # Integer token counters.
    count_options = (
        ("--prompt-tokens", "Prompt token count"),
        ("--completion-tokens", "Completion token count"),
        ("--total-tokens", "Total token count (defaults to sum if others provided)"),
    )
    for flag, description in count_options:
        cli.add_argument(flag, type=int, help=description)

    return cli.parse_args(argv)


def build_payload(ns: argparse.Namespace) -> Dict[str, Any]:
    """Translate parsed CLI arguments into the JSON body sent to the logger.

    ``role`` and ``text`` are always present. ``session_id``,
    ``conversation_id``, ``channel`` and ``metadata`` are included only when a
    non-empty value was supplied. ``token_usage`` is included only when
    ``_build_token_usage`` returns something other than ``None``.

    Args:
        ns: Namespace produced by ``parse_args``.

    Returns:
        The payload dict, ready for ``json.dumps``.

    Raises:
        SystemExit: if ``--metadata`` is not valid JSON or does not decode to
            a JSON object, or if the token-usage flags are rejected.
    """
    payload: Dict[str, Any] = {"role": ns.role, "text": ns.text}
    # Optional plain-string fields; empty strings are treated as "not given".
    for key in ("session_id", "conversation_id", "channel"):
        value = getattr(ns, key)
        if value:
            payload[key] = value
    if ns.metadata:
        try:
            metadata = json.loads(ns.metadata)
        except json.JSONDecodeError as exc:
            raise SystemExit(f"Invalid metadata JSON: {exc}") from exc
        # --metadata is documented as a JSON dict, so reject lists and scalars
        # here, mirroring the object check applied to --token-usage.
        if not isinstance(metadata, dict):
            raise SystemExit("metadata must decode to a JSON object.")
        payload["metadata"] = metadata
    token_usage = _build_token_usage(ns)
    if token_usage is not None:
        payload["token_usage"] = token_usage
    return payload


def _build_token_usage(ns: argparse.Namespace) -> Dict[str, int] | None:
    """Assemble the ``token_usage`` payload entry from the parsed CLI flags.

    Two mutually exclusive sources are accepted:

    * ``--token-usage``: a JSON object. It is copied after checking that
      every value is a non-negative integer.
    * ``--prompt-tokens`` / ``--completion-tokens`` / ``--total-tokens``:
      individual counts. When ``--total-tokens`` is omitted, it is filled in
      as the sum of whichever of the other two were given.

    Args:
        ns: Namespace produced by ``parse_args``.

    Returns:
        A dict of token counts, or ``None`` when no token information was given.

    Raises:
        SystemExit: if both sources are used, the JSON is malformed or not an
            object, or any count is not a non-negative integer.
    """
    individual = {
        "prompt_tokens": ns.prompt_tokens,
        "completion_tokens": ns.completion_tokens,
        "total_tokens": ns.total_tokens,
    }
    if ns.token_usage:
        if any(value is not None for value in individual.values()):
            raise SystemExit("Provide either --token-usage or the individual token count flags, not both.")
        try:
            parsed = json.loads(ns.token_usage)
        except json.JSONDecodeError as exc:
            raise SystemExit(f"Invalid token_usage JSON: {exc}") from exc
        if not isinstance(parsed, dict):
            raise SystemExit("token_usage must decode to a JSON object.")
        for key, value in parsed.items():
            # bool is a subclass of int, so JSON true/false would otherwise pass.
            if isinstance(value, bool) or not isinstance(value, int):
                raise SystemExit("token_usage values must be integers.")
            # Same non-negativity rule as the individual count flags below.
            if value < 0:
                raise SystemExit(f"{key} must be non-negative.")
        return dict(parsed)

    token_usage = {key: value for key, value in individual.items() if value is not None}
    if not token_usage:
        return None
    for key, value in token_usage.items():
        if value < 0:
            raise SystemExit(f"{key} must be non-negative.")
    # Derive the total from whichever partial counts were supplied.
    token_usage.setdefault(
        "total_tokens",
        token_usage.get("prompt_tokens", 0) + token_usage.get("completion_tokens", 0),
    )
    return token_usage


def main(argv: list[str] | None = None, *, timeout: float | None = 10.0) -> None:
    """Parse CLI arguments, build the payload, and POST it to the logger.

    Args:
        argv: Argument list without the program name. Defaults to
            ``sys.argv[1:]``.
        timeout: Seconds to wait on network operations before giving up.
            ``None`` waits indefinitely.

    Raises:
        SystemExit: on invalid arguments or endpoint, when the logger cannot
            be reached or times out, or when it answers with an HTTP error.
    """
    if argv is None:
        argv = sys.argv[1:]
    ns = parse_args(argv)
    payload = build_payload(ns)
    data = json.dumps(payload).encode("utf-8")
    try:
        request = urllib.request.Request(ns.endpoint, data=data, headers={"Content-Type": "application/json"})
        # The default opener raises HTTPError for error statuses instead of
        # returning a response, so leaving this block normally means success.
        with urllib.request.urlopen(request, timeout=timeout):
            pass
    except urllib.error.HTTPError as exc:
        # HTTPError subclasses URLError; handle it first so a server-side
        # rejection is not reported as a connectivity failure.
        try:
            detail = exc.read().decode("utf-8", errors="replace")
        except Exception:  # best-effort: the body only enriches the message
            detail = ""
        finally:
            exc.close()
        raise SystemExit(f"Server returned status {exc.code}: {detail}") from exc
    except OSError as exc:
        # URLError (refused connection, DNS failure, ...) and timeouts raised
        # while reading the response are all OSError subclasses.
        raise SystemExit(f"Failed to reach logger: {exc}") from exc
    except ValueError as exc:
        # Request() raises ValueError for URLs without a recognisable scheme.
        raise SystemExit(f"Invalid logger endpoint {ns.endpoint!r}: {exc}") from exc


if __name__ == "__main__":
    # Script entry point: main() reads its arguments from sys.argv[1:] and
    # forwards a single message to the configured endpoint.
    main()
