"""
Shared runner for processing S3-backed files using whole-file download (legacy approach).
"""

import os
import tempfile
from collections.abc import Awaitable, Callable
from typing import Any

from app.core.config import settings
from app.core.logging import get_logger
from app.services.aws_service import AWSService
from app.utils.s3_utils import normalize_key
from app.utils.temp_file_manager import temp_file_context

logger = get_logger("processing_runner")


async def process_s3_file(
    job_id: str,
    file_type: str,
    s3_path: str,
    analyze: Callable[[str], Awaitable[dict[str, Any]]],
    cleanup: bool = True,  # noqa: ARG001
) -> dict[str, Any]:
    """
    Download an S3 object to a temporary file and run ``analyze`` on it.

    The whole object is downloaded in one go (no streaming or chunking) into
    a temp file provided by ``temp_file_context``. That context manager is
    relied on to remove the file when the ``async with`` block exits, whether
    ``analyze`` returns or raises.

    Args:
        job_id: Job identifier; used only to tag log lines.
        file_type: Human-readable file kind used in log and error messages.
        s3_path: ``s3://bucket/key`` URI, or a bare key resolved against
            ``settings.AWS_S3_BUCKET``.
        analyze: Coroutine function invoked with the local temp-file path.
        cleanup: Unused; kept for backward compatibility (temp-file removal
            is always delegated to ``temp_file_context``).

    Returns:
        The dict returned by ``analyze``.

    Raises:
        RuntimeError: If ``download_file_with_fallback`` reports failure.
        Exception: Any other error (key normalization, download, analysis)
            is logged and re-raised unchanged.
    """
    import time

    try:
        logger.info(f"[JOB {job_id}] Starting {file_type} file processing")

        # Split "s3://bucket/key"; anything else is a key in the default bucket.
        if s3_path.startswith("s3://"):
            bucket, _, key = s3_path[5:].partition("/")
        else:
            bucket = settings.AWS_S3_BUCKET
            key = s3_path

        # Normalize the S3 key using the shared utility function
        normalized_key = normalize_key(key)
        logger.info(
            f"[JOB {job_id}] S3 path normalized | "
            f"original='{key}' normalized='{normalized_key}'"
        )

        # Keep the object's extension on the temp file
        _, ext = os.path.splitext(normalized_key)

        async with temp_file_context(suffix=ext) as temp_file:
            logger.info(
                f"[JOB {job_id}] Downloading from S3 | "
                f"s3://{bucket}/{normalized_key} -> {temp_file}"
            )

            # perf_counter is monotonic, so durations are immune to
            # wall-clock adjustments (NTP sync, DST, manual changes).
            download_start = time.perf_counter()
            aws = AWSService()
            success = await aws.download_file_with_fallback(
                bucket, normalized_key, temp_file
            )
            download_duration = time.perf_counter() - download_start

            if not success:
                raise RuntimeError(f"Failed to download {file_type} from S3")

            file_size_mb = os.path.getsize(temp_file) / (1024 * 1024)
            # Avoid division by zero for very fast downloads
            download_speed = (
                file_size_mb / download_duration if download_duration > 0 else 0
            )
            logger.info(
                f"[JOB {job_id}] S3 download completed | "
                f"size={file_size_mb:.2f}MB duration={download_duration:.2f}s "
                f"speed={download_speed:.2f}MB/s"
            )

            logger.info(f"[JOB {job_id}] Starting {file_type} analysis")
            analysis_start = time.perf_counter()

            result = await analyze(temp_file)

            analysis_duration = time.perf_counter() - analysis_start
            logger.info(
                f"[JOB {job_id}] {file_type.capitalize()} analysis completed | "
                f"duration={analysis_duration:.2f}s"
            )

            return result

    except Exception as e:
        logger.error(
            f"[JOB {job_id}] {file_type.capitalize()} processing failed | "
            f"error={type(e).__name__}: {e}"
        )
        raise


async def _download_whole_from_s3(bucket: str, key: str) -> str | None:
    """
    Download an entire S3 object to a new temporary file.

    The temp file uses the key's extension (if any) as its suffix. On any
    failure the temp file is removed and ``None`` is returned. On success
    the caller owns the returned path and is responsible for deleting it.

    Args:
        bucket: Source S3 bucket.
        key: Object key inside ``bucket``.

    Returns:
        Local path of the downloaded file, or ``None`` if the download failed
        or raised an ``Exception``.
    """
    temp_path: str | None = None
    succeeded = False
    try:
        # Preserve the key's extension as the temp-file suffix ("" if none).
        _, suffix = os.path.splitext(key)

        # delete=False: the file must outlive this handle. It is closed
        # immediately so the downloader can open the path itself.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle:
            temp_path = handle.name

        aws = AWSService()
        succeeded = await aws.download_file_with_fallback(bucket, key, temp_path)
        return temp_path if succeeded else None

    except Exception as e:
        logger.error(f"Failed to download from S3: {e}")
        return None

    finally:
        # Runs on failure, on error, and on task cancellation. CancelledError
        # is a BaseException and is not caught above, so this cleanup is
        # what prevents the temp file from leaking.
        if not succeeded and temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)


async def process_s3_file_chunked(
    job_id: str,
    file_type: str,
    s3_path: str,
    analyze: Callable[[str], Awaitable[dict[str, Any]]],
    chunk_size_mb: int = 50,
    cleanup: bool = True,
) -> dict[str, Any]:
    """
    Backward-compatible "chunked" entry point that no longer chunks.

    Despite its name, this logs the resolved S3 location and then delegates
    to ``process_s3_file``, which downloads the whole object and analyzes it
    once.

    Args:
        job_id: Job identifier used in log messages.
        file_type: Human-readable file kind used in log messages.
        s3_path: ``s3://bucket/key`` URI, or a bare key resolved against
            ``settings.AWS_S3_BUCKET``.
        analyze: Coroutine function invoked with the local temp-file path.
        chunk_size_mb: Ignored; accepted only for API compatibility.
        cleanup: Forwarded unchanged to ``process_s3_file``.

    Returns:
        The result of ``process_s3_file``.

    Raises:
        Exception: Anything raised while resolving the path or by
            ``process_s3_file`` is logged and re-raised.
    """
    try:
        logger.info(f"Starting chunked {file_type} processing for job {job_id}")

        # Resolve the location only for the diagnostic log below;
        # process_s3_file re-parses and normalizes s3_path itself.
        if s3_path.startswith("s3://"):
            bucket, _, key = s3_path[5:].partition("/")
        else:
            bucket = settings.AWS_S3_BUCKET
            key = s3_path

        normalized_key = normalize_key(key)
        logger.info(
            f"Legacy chunk wrapper: bucket={bucket}, '{key}' -> '{normalized_key}'"
        )

        # Silence linter: parameter kept for backward-compat API
        _ = chunk_size_mb

        # Legacy approach: do not chunk; download whole file then analyze once.
        # Temp-file cleanup is handled entirely inside process_s3_file.
        return await process_s3_file(job_id, file_type, s3_path, analyze, cleanup)

    except Exception as e:
        logger.error(f"Chunked processing failed for job {job_id}: {e}")
        raise


async def _stream_chunk_from_s3(
    bucket: str, key: str, _start_byte: int, _end_byte: int, _chunk_num: int
) -> str | None:
    """
    Deprecated shim: ignore the byte range and download the whole object.

    ``_start_byte``, ``_end_byte`` and ``_chunk_num`` are accepted only so
    existing callers keep working; they are not used.

    Returns:
        The temp-file path produced by ``_download_whole_from_s3``, or
        ``None`` if that download failed.
    """
    notice = "Chunked streaming disabled; falling back to whole-file download"
    logger.info(notice)
    downloaded_path = await _download_whole_from_s3(bucket, key)
    return downloaded_path


def _combine_chunk_results(
    chunk_results: list[dict], total_file_size: int
) -> dict[str, Any]:
    """Combine results from multiple chunks into a single result."""
    try:
        # Initialize combined result
        combined = {
            "status": "success",
            "processing_method": "chunked",
            "total_chunks": len(chunk_results),
            "total_file_size_bytes": total_file_size,
            "chunk_results": chunk_results,
        }

        # Check for errors
        errors = [r for r in chunk_results if "error" in r]
        if errors:
            combined["status"] = "partial_success"
            combined["errors"] = errors
            combined["successful_chunks"] = len(chunk_results) - len(errors)

        # Combine specific analysis results if available
        successful_results = [
            r for r in chunk_results if "result" in r and "error" not in r
        ]
        if successful_results:
            # This is a simplified combination - you may need to customize based on your analysis needs
            combined["combined_analysis"] = _merge_analysis_results(successful_results)

        return combined

    except Exception as e:
        logger.error(f"Failed to combine chunk results: {e}")
        return {
            "status": "error",
            "message": f"Failed to combine chunk results: {e!s}",
            "chunk_results": chunk_results,
        }


def _merge_analysis_results(successful_results: list[dict]) -> dict[str, Any]:
    """Merge analysis results from multiple chunks."""
    try:
        merged = {}

        # Simple merging strategy - you may need to customize this
        for result in successful_results:
            chunk_result = result.get("result", {})
            for key, value in chunk_result.items():
                if key not in merged:
                    merged[key] = value
                elif isinstance(value, int | float):
                    # For numeric values, sum them
                    merged[key] = merged.get(key, 0) + value
                elif isinstance(value, list):
                    # For lists, extend them
                    if key not in merged:
                        merged[key] = []
                    merged[key].extend(value)

        return merged

    except Exception as e:
        logger.error(f"Failed to merge analysis results: {e}")
        return {"merge_error": str(e)}
