"""Storage utilities for writing conversation logs to disk."""

from __future__ import annotations

import json
import threading
import time
from dataclasses import dataclass
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Dict, List, Optional

from .config import LoggerConfig


def _ensure_dir(path: Path) -> None:
    path.mkdir(parents=True, exist_ok=True)


def _utc_now() -> datetime:
    return datetime.now(timezone.utc)


def _window_start(ts: datetime, seconds: int) -> datetime:
    epoch_sec = int(ts.timestamp())
    window_sec = epoch_sec - (epoch_sec % seconds)
    return datetime.fromtimestamp(window_sec, tz=timezone.utc)


def _stamp(ts: datetime, fmt: str) -> str:
    return ts.strftime(fmt)


@dataclass
class LogEntry:
    """A single conversation turn destined for the on-disk log.

    Required fields are always serialized; the optional fields are emitted
    only when set. Note `metadata` is emitted only when *truthy* (an empty
    dict is omitted), while `token_usage` is emitted whenever it is not None.
    """

    turn_index: int
    timestamp: str
    role: str
    text: str
    session_id: Optional[str] = None
    conversation_id: Optional[str] = None
    channel: Optional[str] = None
    metadata: Optional[Dict[str, object]] = None
    token_usage: Optional[Dict[str, int]] = None

    def to_dict(self) -> Dict[str, object]:
        """Serialize to a plain dict, omitting unset optional fields."""
        result: Dict[str, object] = {
            "turn_index": self.turn_index,
            "timestamp": self.timestamp,
            "role": self.role,
            "text": self.text,
        }
        for key, value in (
            ("session_id", self.session_id),
            ("conversation_id", self.conversation_id),
            ("channel", self.channel),
        ):
            if value is not None:
                result[key] = value
        # Truthiness check (not `is not None`): an empty metadata dict is dropped.
        if self.metadata:
            result["metadata"] = self.metadata
        if self.token_usage is not None:
            result["token_usage"] = self.token_usage
        return result


class ConversationStorage:
    """Buffers conversation entries and writes them to rotating JSON files.

    Entries are accumulated in memory and flushed to a window-stamped JSON
    file when the buffered size reaches ``flush.max_buffer_bytes`` or
    ``flush.max_interval_seconds`` have passed since the last flush. The
    time check runs only inside ``add_entry``/``flush`` — there is no
    background timer thread, so an idle logger flushes on the next call.
    Writes are atomic (temp file + rename). All public methods are
    thread-safe via a single internal lock.
    """

    def __init__(self, config: LoggerConfig) -> None:
        self.config = config
        self._buffer: List[Dict[str, object]] = []  # serialized entries awaiting flush
        self._buffer_bytes = 0  # UTF-8 encoded size of the buffered entries
        self._turn_counter = 0  # per-window turn numbering, see _flush_locked
        self._lock = threading.Lock()
        self._last_flush = time.monotonic()
        self._current_window_start: Optional[datetime] = None

    def add_entry(self, entry: LogEntry) -> None:
        """Stamp *entry* with the next turn index, buffer it, and maybe flush."""
        with self._lock:
            self._turn_counter += 1
            entry.turn_index = self._turn_counter
            payload = entry.to_dict()
            encoded = json.dumps(payload, ensure_ascii=False).encode("utf-8")
            self._buffer.append(payload)
            self._buffer_bytes += len(encoded)
            self._maybe_flush_locked()

    def _maybe_flush_locked(self) -> None:
        """Flush if the size or time threshold is reached. Caller holds the lock."""
        now = time.monotonic()
        flush_policy = self.config.flush
        if (
            self._buffer_bytes >= flush_policy.max_buffer_bytes
            or now - self._last_flush >= flush_policy.max_interval_seconds
        ):
            self._flush_locked()

    def flush(self) -> None:
        """Force any buffered entries out to disk immediately."""
        with self._lock:
            self._flush_locked()

    def _flush_locked(self) -> None:
        """Write the buffer to the current window's file. Caller holds the lock."""
        if not self._buffer:
            # Nothing to write; still reset the timer so the interval check
            # measures time since the last *attempted* flush.
            self._last_flush = time.monotonic()
            return
        now = _utc_now()
        rotation = self.config.rotation
        window_start = _window_start(now, rotation.window_seconds)
        if self._current_window_start is None or window_start != self._current_window_start:
            # New rotation window: restart turn numbering for the new file.
            # Buffered entries may still carry indices from the previous
            # window; _build_payload re-bases indices on collision.
            self._current_window_start = window_start
            self._turn_counter = 0

        file_path = self._current_file_path(window_start)
        payload = self._build_payload(window_start)
        _ensure_dir(file_path.parent)
        # Write to a temp file then rename so readers never see a partial file.
        temp_path = file_path.with_suffix(file_path.suffix + self.config.tmp_suffix)
        with temp_path.open("w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=2)
            handle.write("\n")
        temp_path.replace(file_path)
        self._buffer.clear()
        self._buffer_bytes = 0
        self._last_flush = time.monotonic()
        if rotation.retention_days is not None:
            self._cleanup_old_files(now)

    def _current_file_path(self, window_start: datetime) -> Path:
        """Return the log file path for the window beginning at *window_start*."""
        stamp = _stamp(window_start, self.config.stamp_format)
        filename = self.config.filename_template.format(stamp=stamp)
        return self.config.log_dir_path / filename

    def _build_payload(self, window_start: datetime) -> Dict[str, object]:
        """Merge buffered entries with any already on disk for this window.

        Fix: previously, if the target file already held entries (process
        restart mid-window, or a window change that reset the counter), the
        buffered entries' turn indices could collide with indices already on
        disk, producing duplicate ``turn_index`` values in one file. Buffered
        indices are now re-based past the existing maximum, and the counter
        is advanced so subsequent entries continue from the shifted values.
        """
        file_path = self._current_file_path(window_start)
        entries = self._load_existing_entries(file_path)
        if entries and self._buffer:
            highest_existing = max(
                (e["turn_index"] for e in entries if isinstance(e.get("turn_index"), int)),
                default=0,
            )
            lowest_buffered = min(e["turn_index"] for e in self._buffer)
            if lowest_buffered <= highest_existing:
                shift = highest_existing - lowest_buffered + 1
                for buffered in self._buffer:
                    buffered["turn_index"] = buffered["turn_index"] + shift
                self._turn_counter += shift
        entries.extend(self._buffer)
        payload: Dict[str, object] = {
            "file_id": file_path.stem,
            "created_utc": window_start.isoformat(),
            "rotation_window_sec": self.config.rotation.window_seconds,
            "entries": entries,
        }
        return payload

    def _load_existing_entries(self, file_path: Path) -> List[Dict[str, object]]:
        """Best-effort read of the entries already written for this window.

        Returns ``[]`` when the file is missing, unreadable, or malformed —
        a corrupt window file is overwritten rather than crashing the logger.
        """
        if not file_path.exists():
            return []
        try:
            with file_path.open("r", encoding="utf-8") as handle:
                data = json.load(handle)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError; deliberately best-effort.
            return []
        if isinstance(data, dict):
            entries = data.get("entries")
            if isinstance(entries, list):
                return entries
        return []

    def _cleanup_old_files(self, now: datetime) -> None:
        """Delete window files — and orphaned temp files — older than retention.

        Fix: previously only ``*.json`` was swept, so a temp file left behind
        by a crash between the write and the rename lived forever. Stale
        ``*.json<tmp_suffix>`` files are now collected as well.
        """
        retention_days = self.config.rotation.retention_days
        if retention_days is None:
            return
        cutoff_ts = (now - timedelta(days=retention_days)).timestamp()
        patterns = {"*.json", "*.json" + self.config.tmp_suffix}
        for pattern in patterns:
            for path in self.config.log_dir_path.glob(pattern):
                try:
                    if path.stat().st_mtime < cutoff_ts:
                        path.unlink(missing_ok=True)
                except OSError:
                    # Deletion is best-effort; skip files we cannot stat/remove.
                    continue
