"""
AWS service for S3 operations, Secrets Manager, and other AWS integrations.
"""

import asyncio
import contextlib
import json
import os
import time
from typing import Any

import aioboto3
import aiofiles
from botocore.exceptions import ClientError, NoCredentialsError

from app.core.logging import get_logger
from app.utils.s3_utils import normalize_key

# Module-level logger used by every AWSService method below.
logger = get_logger("aws_service")


class AWSService:
    """Service for AWS operations including S3, Secrets Manager, and other AWS integrations."""

    def __init__(self):
        self.region = os.getenv("AWS_DEFAULT_REGION", "us-east-1")
        self.session = None
        self._credentials_cache = {}
        self._cache_ttl = 3600  # 1 hour cache

    async def initialize_clients(self) -> bool:
        """
        Initialize AWS session with environment credentials.

        Returns:
            bool: True if session was initialized successfully, False otherwise
        """
        if self.session is not None:
            logger.debug("AWS session already initialized")
            return True

        try:
            # Try to create session with environment credentials
            aws_key = os.getenv("AWS_ACCESS_KEY_ID")
            aws_secret = os.getenv("AWS_SECRET_ACCESS_KEY")

            if aws_key and aws_secret:
                self.session = aioboto3.Session(
                    aws_access_key_id=aws_key,
                    aws_secret_access_key=aws_secret,
                    region_name=self.region,
                )
                logger.info("AWS session initialized with environment credentials")
            else:
                # Fallback to default credential chain (IAM roles, etc.)
                logger.info(
                    "No explicit AWS credentials found, using default credential chain"
                )
                self.session = aioboto3.Session(region_name=self.region)
                logger.info("AWS session initialized with default credential chain")

            return True

        except NoCredentialsError:
            logger.error(
                "AWS credentials not found and default credential chain failed"
            )
            self.session = None
            return False
        except Exception as e:
            logger.error(f"Failed to initialize AWS session: {e}")
            self.session = None
            return False

    def _validate_session(self) -> bool:
        """Validate that an AWS session is available."""
        if not self.session:
            logger.error("AWS session not initialized")
            return False
        return True

    def _handle_aws_error(
        self, operation: str, error: Exception, default_return: any = False
    ) -> any:
        """Handle AWS operation errors consistently."""
        if isinstance(error, ClientError):
            logger.error(f"AWS {operation} error: {error}")
        else:
            logger.error(f"Unexpected error during {operation}: {error}")
        return default_return

    async def download_file(self, bucket: str, key: str, local_path: str) -> bool:
        """Download a file from S3 with retry logic for ContentLengthError."""
        if not await self.initialize_clients() or not self._validate_session():
            return False

        max_retries = 3
        retry_delay = 1  # seconds

        for attempt in range(max_retries):
            try:
                # Ensure the directory exists
                os.makedirs(os.path.dirname(local_path), exist_ok=True)

                async with self.session.client("s3", region_name=self.region) as s3:
                    await s3.download_file(bucket, key, local_path)

                logger.info(
                    f"Successfully downloaded s3://{bucket}/{key} to {local_path}"
                )
                return True

            except Exception as e:
                error_msg = str(e)
                is_content_length_error = (
                    "ContentLengthError" in error_msg
                    or "Not enough data to satisfy content length header" in error_msg
                )

                if is_content_length_error and attempt < max_retries - 1:
                    logger.warning(
                        f"ContentLengthError on attempt {attempt + 1}/{max_retries} for s3://{bucket}/{key}, retrying in {retry_delay}s..."
                    )
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                    continue
                else:
                    return self._handle_aws_error("download", e, False)

        return False

    async def download_file_with_fallback(
        self, bucket: str, key: str, local_path: str
    ) -> bool:
        """Download a file from S3 with fallback method for ContentLengthError."""
        # First try the standard download method
        if await self.download_file(bucket, key, local_path):
            return True

        logger.warning(
            f"Standard download failed for s3://{bucket}/{key}, trying fallback method..."
        )

        try:
            # Fallback: Use get_object and write to file manually
            async with self.session.client("s3", region_name=self.region) as s3:
                response = await s3.get_object(Bucket=bucket, Key=key)

                # Ensure the directory exists
                os.makedirs(os.path.dirname(local_path), exist_ok=True)

                # Write the response body to file
                async with aiofiles.open(local_path, "wb") as f:
                    async for chunk in response["Body"]:
                        await f.write(chunk)

                logger.info(
                    f"Successfully downloaded s3://{bucket}/{key} to {local_path} using fallback method"
                )
                return True

        except Exception as e:
            logger.error(
                f"Fallback download method also failed for s3://{bucket}/{key}: {e}"
            )
            return False

    async def generate_presigned_url(
        self, bucket: str, key: str, expires_in: int = 3600
    ) -> str | None:
        """Generate a presigned URL for a file in S3."""
        if not await self.initialize_clients() or not self._validate_session():
            return None

        try:
            async with self.session.client("s3", region_name=self.region) as s3:
                url = await s3.generate_presigned_url(
                    "get_object",
                    {"Bucket": bucket, "Key": key},
                    ExpiresIn=expires_in,
                )

            logger.info(f"Generated presigned URL for s3://{bucket}/{key}")
            return url

        except Exception as e:
            return self._handle_aws_error("presigned URL generation", e, None)

    async def get_secret(self, secret_name: str) -> dict | None:
        """
        Retrieve a secret from AWS Secrets Manager.

        Args:
            secret_name: Name of the secret to retrieve

        Returns:
            Dictionary containing the secret values or None if failed
        """
        if not await self.initialize_clients() or not self._validate_session():
            return None

        try:
            async with self.session.client(
                "secretsmanager", region_name=self.region
            ) as client:
                response = await client.get_secret_value(SecretId=secret_name)

                if "SecretString" in response:
                    secret = json.loads(response["SecretString"])
                    logger.info(f"Successfully retrieved secret: {secret_name}")
                    return secret
                else:
                    logger.error(f"Secret {secret_name} not found or invalid format")
                    return None

        except Exception as e:
            return self._handle_aws_error(f"secret retrieval ({secret_name})", e, None)

    async def get_database_credentials(
        self, secret_name: str | None = None
    ) -> dict | None:
        """
        Retrieve database credentials from AWS Secrets Manager.

        Args:
            secret_name: Name of the secret containing database credentials.
                        Defaults to 'DB_CREDENTIALS' if not specified.

        Returns:
            Dictionary containing database connection parameters or None if failed
        """
        if secret_name is None:
            secret_name = "DB_CREDENTIALS"  # nosec B105 - This is a secret name, not a password

        # Check cache first
        if secret_name in self._credentials_cache:
            logger.debug(f"Using cached credentials for {secret_name}")
            return self._credentials_cache[secret_name]

        try:
            credentials = await self.get_secret(secret_name)

            if credentials:
                # Cache the credentials
                self._credentials_cache[secret_name] = credentials

                # Schedule cache invalidation (fire and forget)
                _ = asyncio.create_task(  # noqa: RUF006
                    self._invalidate_cache_after_ttl(secret_name)
                )

                logger.info("Database credentials retrieved from Secrets Manager")
                return credentials

            logger.warning(f"Failed to retrieve credentials for {secret_name}")
            return None

        except Exception as e:
            logger.error(f"Failed to retrieve database credentials: {e}")
            return None

    async def _invalidate_cache_after_ttl(self, secret_name: str):
        """Invalidate credentials cache after TTL."""
        await asyncio.sleep(self._cache_ttl)
        if secret_name in self._credentials_cache:
            del self._credentials_cache[secret_name]
            logger.debug(f"Credentials cache invalidated for {secret_name}")

    async def refresh_credentials(self, secret_name: str | None = None):
        """Force refresh of cached credentials."""
        if secret_name is None:
            secret_name = "DB_CREDENTIALS"  # nosec B105 - This is a secret name, not a password

        if secret_name in self._credentials_cache:
            del self._credentials_cache[secret_name]
            logger.info(f"Credentials cache cleared for {secret_name}")
        else:
            logger.info(f"No cached credentials found for {secret_name}")

    async def check_file_exists(
        self, file_name: str, bucket: str | None = None
    ) -> bool:
        """Check if a file exists in S3."""
        if not await self.initialize_clients() or not self._validate_session():
            return False

        if not bucket:
            bucket = os.getenv("AWS_S3_BUCKET", "pub.myinterviewpractice.com")

        try:
            s3_key = normalize_key(file_name)

            async with self.session.client("s3", region_name=self.region) as s3:
                await s3.head_object(Bucket=bucket, Key=s3_key)

            logger.debug(f"File {file_name} exists in S3 bucket {bucket}")
            return True

        except ClientError as e:
            if e.response["Error"]["Code"] == "404":
                logger.debug(f"File {file_name} not found in S3 bucket {bucket}")
                return False
            else:
                logger.error(f"AWS S3 head_object error for {file_name}: {e}")
                return False
        except Exception as e:
            logger.error(
                f"Unexpected error checking file existence for {file_name}: {e}"
            )
            return False
