"""
Temporary file management with guaranteed cleanup.

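This module provides utilities for managing temporary files with layered cleanup:
each file is removed when its context exits, any still-tracked files are removed at
normal interpreter shutdown, and files orphaned by a crash or hard kill (which skips
shutdown hooks) can be swept by a periodic age-based cleanup.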
"""

import atexit
import os
import tempfile
import time
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

from app.core.logging import get_logger

logger = get_logger("temp_file_manager")

# Module-wide registry of paths created by temp_file_context that have not yet
# been removed. temp_file_context adds an entry on creation and discards it on
# exit; cleanup_all_temp_files (run via atexit) removes whatever is still
# listed at interpreter shutdown and then clears the set.
_temp_files: set[str] = set()


def cleanup_all_temp_files():
    """
    Remove every temp file still tracked in the module registry.

    This is registered with ``atexit``, so it runs during normal interpreter
    shutdown (including ``sys.exit()`` and an unhandled exception in the main
    thread). ``atexit`` handlers do NOT run when the process is killed by a
    signal Python does not handle, crashes hard, or calls ``os._exit()``;
    files orphaned that way must be swept by ``cleanup_old_temp_files``.

    Removal is best-effort: a failure on one file is logged and the remaining
    files are still attempted. The registry is emptied afterwards either way.
    """
    if not _temp_files:
        return

    logger.info(f"Shutdown cleanup: removing {len(_temp_files)} temp files")

    # Iterate over a snapshot so the registry is never mutated mid-iteration.
    for temp_file in list(_temp_files):
        try:
            # Remove directly instead of checking os.path.exists first: the
            # check-then-remove pair races with anything else deleting the file.
            os.remove(temp_file)
        except FileNotFoundError:
            # Already gone (e.g. removed by its owner) -- nothing to do.
            continue
        except Exception as e:
            # Shutdown boundary: log and keep going rather than abort exit.
            logger.error(f"Failed to cleanup {temp_file}: {e}")
        else:
            logger.debug(f"Removed temp file: {temp_file}")

    _temp_files.clear()
    logger.info("Shutdown cleanup completed")


# Run cleanup_all_temp_files at normal interpreter shutdown. atexit handlers are
# skipped when the process is killed by a signal Python does not handle (e.g.
# SIGKILL), crashes hard, or calls os._exit(); files orphaned that way need the
# periodic age-based sweep in cleanup_old_temp_files.
atexit.register(cleanup_all_temp_files)


@asynccontextmanager
async def temp_file_context(suffix: str = "") -> AsyncIterator[str]:
    """
    Create a temporary file and delete it when the ``async with`` block exits.

    The file is created empty (and its descriptor closed) in the default temp
    directory before the path is yielded. The caller therefore gets a path
    that already exists and can be opened, overwritten, or handed to another
    process. The path is also recorded in the module-level registry so
    ``cleanup_all_temp_files`` can remove it at interpreter shutdown if this
    context is never exited.

    Cleanup runs whether the block completes normally or raises. A file the
    caller already deleted is ignored. Any other removal failure is logged
    rather than raised, so it cannot mask an exception from the block.

    Args:
        suffix: Optional file extension (e.g., ".mp4", ".wav")

    Yields:
        str: Path to the temporary file

    Example:
        async with temp_file_context(suffix=".mp4") as temp_path:
            # Download or process file at temp_path
            await download_file(temp_path)
            # File is automatically cleaned up after this block
    """
    # mkstemp creates the file securely and returns an open descriptor plus
    # the path. If it raises, nothing was created, so there is nothing to clean.
    fd, temp_file = tempfile.mkstemp(suffix=suffix)
    try:
        # Only the path is handed out; don't leak the descriptor.
        os.close(fd)

        # Track it globally for shutdown cleanup
        _temp_files.add(temp_file)
        logger.debug(f"Created temp file: {temp_file}")

        yield temp_file

    finally:
        # Always cleanup
        _temp_files.discard(temp_file)
        try:
            # Remove directly (no exists() pre-check) to avoid a check/act race.
            os.remove(temp_file)
        except FileNotFoundError:
            pass  # caller (or someone else) already removed it
        except Exception as e:
            logger.error(f"Failed to cleanup temp file {temp_file}: {e}")
        else:
            logger.debug(f"Cleaned up temp file: {temp_file}")


async def cleanup_old_temp_files(
    max_age_hours: int = 24,
    prefix: str = "tmp",
) -> dict:
    """
    Remove stale temporary files left behind in the system temp directory.

    This should be called periodically (e.g., daily) to sweep up orphaned
    files. These are files left behind by crashes or hard kills, which skip
    both the ``temp_file_context`` cleanup and the ``atexit`` hook.

    A file is removed only if all of these hold:
    - it is a regular file directly inside ``tempfile.gettempdir()``;
    - its name starts with ``prefix``;
    - its modification time is more than ``max_age_hours`` ago.

    The system temp directory is shared with other programs, so the prefix
    filter keeps this sweep from deleting files it does not own. Symlinks are
    skipped (the link's target is never examined). Paths still tracked by a
    live ``temp_file_context`` in this process are also skipped.

    Args:
        max_age_hours: Maximum age of temp files in hours; files whose mtime
            is older than this are removed.
        prefix: Filename prefix a file must have to be considered. Defaults
            to ``"tmp"``, the prefix ``tempfile`` uses by default (and thus
            the prefix of files made by ``temp_file_context``). Pass ``""``
            to consider every file in the directory.

    Returns:
        dict: ``status`` ("success" or "error"), ``cleaned`` (files removed),
        ``failed`` (per-file errors), and ``temp_dir``. On "error" the dict
        also has ``message`` explaining why the directory scan failed; the
        counts then reflect work done before the failure.

    Note:
        The scan uses blocking filesystem calls on the calling (event-loop)
        thread.
    """
    temp_dir = tempfile.gettempdir()
    # Equivalent to "(now - mtime) / 3600 > max_age_hours".
    cutoff = time.time() - max_age_hours * 3600
    cleaned = 0
    failed = 0

    logger.info(f"Starting cleanup of temp files older than {max_age_hours} hours")

    try:
        with os.scandir(temp_dir) as entries:
            for entry in entries:
                if not entry.name.startswith(prefix):
                    continue

                filepath = entry.path
                try:
                    # follow_symlinks=False: skip links instead of judging by
                    # (or being tricked via) whatever they point at.
                    if not entry.is_file(follow_symlinks=False):
                        continue
                    # Still owned by an active temp_file_context in this process.
                    if filepath in _temp_files:
                        continue
                    if entry.stat(follow_symlinks=False).st_mtime < cutoff:
                        os.remove(filepath)
                        cleaned += 1
                        logger.debug(f"Removed old temp file: {filepath}")
                except FileNotFoundError:
                    # Vanished between listing and removal -- not an error.
                    continue
                except Exception as e:
                    failed += 1
                    logger.warning(f"Error processing {filepath}: {e}")

    except Exception as e:
        logger.error(f"Temp file cleanup failed: {e}")
        return {
            "status": "error",
            "message": str(e),
            "cleaned": cleaned,
            "failed": failed,
            "temp_dir": temp_dir,
        }

    logger.info(f"Temp file cleanup complete: {cleaned} removed, {failed} errors")

    return {
        "status": "success",
        "cleaned": cleaned,
        "failed": failed,
        "temp_dir": temp_dir,
    }


def get_temp_file_count() -> int:
    """
    Report how many temp files the module registry is tracking right now.

    Returns:
        int: Number of paths created by ``temp_file_context`` that are still
        registered (not yet discarded or cleared).
    """
    tracked = _temp_files
    count = len(tracked)
    return count


def get_temp_files() -> list[str]:
    """
    Return a snapshot of every temp file path currently in the registry.

    Returns:
        list[str]: A fresh list of tracked paths (order is arbitrary, since
        the registry is a set). Mutating it does not affect the registry.
    """
    snapshot = [*_temp_files]
    return snapshot
