"""
Gaze Tracking Service - Ported from legacy system.
Provides eye tracking and gaze direction analysis using computer vision.
"""

import asyncio
import os
from collections import deque
from typing import Any

import cv2
import dlib
import numpy as np

from app.core.logging import get_logger
from app.services.analysis_utils import AnalysisUtils
from app.utils.video_utils import foreach_frame

# Module-level logger used by the classes in this module.
logger = get_logger("gaze_tracking_service")


class Eye:
    """Isolate one eye in a grayscale frame and measure its pupil and openness.

    Construction does all the work. Afterwards the instance exposes:

    - ``frame``: the cropped eye ROI; pixels outside the eye polygon are 255.
    - ``origin``: ``(x, y)`` of the ROI's top-left corner in full-frame pixels.
    - ``center``, ``width``, ``height``: ROI geometry in ROI pixels.
    - ``pupil``: object with ``x``/``y`` (ints, ROI pixels); both are ``None``
      when detection failed or was judged unreliable.
    - ``iris_quality``: 1.0 when the best threshold hit the target iris size
      exactly, decreasing towards 0.0 as it diverges.
    - ``blinking``: eye aspect ratio ``(|p1-p5| + |p2-p4|) / (2 * |p0-p3|)``
      over the six eye landmarks (smaller = more closed); 0 when the
      horizontal extent ``|p0-p3|`` is zero or the computation fails.
    """

    # Landmark index ranges taken from the full landmark list for each side.
    _LEFT_EYE_POINTS = slice(36, 42)
    _RIGHT_EYE_POINTS = slice(42, 48)
    # Black-pixel fraction the threshold search aims for (legacy value).
    _TARGET_IRIS_SIZE = 0.48
    # Detections whose iris_quality falls below this are discarded.
    _MIN_IRIS_QUALITY = 0.6
    # Pixels ignored on each ROI border when measuring iris size (legacy value).
    _IRIS_BORDER = 5
    # Candidate binarization thresholds searched per frame (legacy range).
    _THRESHOLDS = range(5, 100, 5)

    class _Pupil:
        """Pupil position in ROI pixels; ``None`` coordinates mean "unknown"."""

        __slots__ = ("x", "y")

        def __init__(self, x=None, y=None):
            self.x = x
            self.y = y

        def __repr__(self):
            return f"Pupil(x={self.x!r}, y={self.y!r})"

    def __init__(self, frame, landmarks, side, calibration):
        """Build the eye from a grayscale ``frame`` and the face ``landmarks``.

        Args:
            frame: Grayscale image containing the face.
            landmarks: Sequence of ``(x, y)`` landmark points for the face.
            side: 0 selects points 36-41 (left eye); anything else 42-47.
            calibration: Stored as-is on the instance.
        """
        self.frame = frame
        self.calibration = calibration
        self.side = side
        self.iris_quality = 0.0

        points = self._LEFT_EYE_POINTS if side == 0 else self._RIGHT_EYE_POINTS
        self.landmarks = landmarks[points]

        self._analyze()

    def _analyze(self):
        """Run ROI extraction, pupil detection and openness measurement in order."""
        self._get_eye_region()
        self._get_pupil()
        self._get_blinking_ratio()

    def _get_eye_region(self):
        """Crop the eye polygon (plus a margin) out of ``self.frame``.

        Pixels outside the polygon are painted white (legacy masking), the
        bounding box is clamped to the frame, and ``frame``/``origin``/
        ``center``/``width``/``height`` are replaced with ROI values.
        """
        points = np.array(self.landmarks, dtype=np.int32)

        # Mask is 255 everywhere except the eye polygon; bitwise_not of an
        # all-black image writes 255 only where the mask is set, so everything
        # outside the eye becomes white while the eye keeps its pixels.
        height, width = self.frame.shape[:2]
        black_frame = np.zeros((height, width), np.uint8)
        mask = np.full((height, width), 255, np.uint8)
        cv2.fillPoly(mask, [points], (0, 0, 0))
        eye_only = cv2.bitwise_not(black_frame, self.frame.copy(), mask=mask)

        # Bounding box with a 5 px margin (legacy), clamped to the frame.
        margin = 5
        min_x = max(int(np.min(points[:, 0]) - margin), 0)
        max_x = min(int(np.max(points[:, 0]) + margin), width)
        min_y = max(int(np.min(points[:, 1]) - margin), 0)
        max_y = min(int(np.max(points[:, 1]) + margin), height)

        self.frame = eye_only[min_y:max_y, min_x:max_x]
        self.origin = (min_x, min_y)

        roi_h, roi_w = self.frame.shape[:2]
        self.center = (roi_w / 2.0, roi_h / 2.0)
        self.width = roi_w
        self.height = roi_h

    def _get_pupil(self):
        """Locate the pupil inside the eye ROI.

        Searches thresholds 5, 10, ..., 95 for the one whose binarized ROI
        (ignoring a 5 px border) has a black-pixel fraction closest to the
        target, then takes the centroid of the second-largest contour (the
        largest if there is only one). Sets ``self.pupil`` and
        ``self.iris_quality``; the pupil is left unknown when the quality is
        below the minimum or anything fails.
        """
        self.pupil = self._Pupil()
        self.iris_quality = 0.0
        try:
            eye_region = self.frame
            if eye_region is None or eye_region.size == 0:
                return

            # Filtering does not depend on the threshold: compute it once
            # instead of once per candidate threshold.
            kernel = np.ones((3, 3), np.uint8)
            filtered = cv2.bilateralFilter(eye_region, 10, 15, 15)
            filtered = cv2.erode(filtered, kernel, iterations=3)

            border = self._IRIS_BORDER
            crop = filtered[border:-border, border:-border]
            num_pixels = crop.shape[0] * crop.shape[1]
            if num_pixels == 0:
                # ROI too small to measure iris size; such a detection would
                # get quality 0 and be discarded anyway.
                return

            target = self._TARGET_IRIS_SIZE
            best_thr, best_diff = 50, float("inf")
            for thr in self._THRESHOLDS:
                # THRESH_BINARY maps src > thr to 255, so black pixels are
                # exactly those with src <= thr.
                iris_size = np.count_nonzero(crop <= thr) / float(num_pixels)
                diff = abs(iris_size - target)
                if diff < best_diff:
                    best_thr, best_diff = thr, diff

            binary = cv2.threshold(filtered, int(best_thr), 255, cv2.THRESH_BINARY)[1]

            # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4
            # returns (contours, hierarchy); [-2] is the contours in both.
            found = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
            contours = sorted(
                found[-2] if isinstance(found, tuple) else [], key=cv2.contourArea
            )
            if not contours:
                return
            # Legacy prefers the second-largest contour.
            selected = contours[-2] if len(contours) >= 2 else contours[-1]
            m = cv2.moments(selected)
            if m.get("m00", 0) == 0:
                return

            cx = int(m["m10"] / m["m00"])
            cy = int(m["m01"] / m["m00"])
            self.iris_quality = max(0.0, 1.0 - (best_diff / max(target, 1e-6)))
            if self.iris_quality >= self._MIN_IRIS_QUALITY:
                self.pupil = self._Pupil(cx, cy)

        except Exception as e:
            logger.warning(f"Pupil detection failed: {e}")
            self.pupil = self._Pupil()

    def _get_blinking_ratio(self):
        """Store the eye aspect ratio of the six eye landmarks in ``blinking``.

        EAR = (|p1 - p5| + |p2 - p4|) / (2 * |p0 - p3|); 0 when the horizontal
        distance is zero or the computation fails.
        """
        try:
            pts = [np.array(p) for p in self.landmarks]
            vertical_a = np.linalg.norm(pts[1] - pts[5])
            vertical_b = np.linalg.norm(pts[2] - pts[4])
            horizontal = np.linalg.norm(pts[0] - pts[3])

            if horizontal != 0:
                self.blinking = (vertical_a + vertical_b) / (2.0 * horizontal)
            else:
                self.blinking = 0

        except Exception as e:
            logger.warning(f"Blinking calculation failed: {e}")
            self.blinking = 0

    def compute_gaze_ratio(self) -> float:
        """Return the left/right balance of relatively dark pixels in the ROI.

        The ROI is blurred and adaptively thresholded with THRESH_BINARY_INV,
        so pixels darker than their local mean become non-zero. The result is
        ``(nonzero_left + 1) / (nonzero_right + 1)`` over the two column
        halves. Returns 1.0 for an empty ROI or on error.
        """
        try:
            # self.frame is already the cropped ROI (_get_eye_region replaced
            # it). Slicing it again by the full-frame origin would index past
            # the crop and yield an empty array for almost every frame.
            eye_region = self.frame
            if eye_region is None or eye_region.size == 0:
                return 1.0
            if len(eye_region.shape) == 3:
                eye_region = cv2.cvtColor(eye_region, cv2.COLOR_BGR2GRAY)
            eye_blur = cv2.GaussianBlur(eye_region, (5, 5), 0)
            thresh = cv2.adaptiveThreshold(
                eye_blur, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 11, 5
            )
            mid = thresh.shape[1] // 2
            # +1 keeps the ratio finite when a half has no dark pixels.
            left_white = cv2.countNonZero(thresh[:, :mid]) + 1
            right_white = cv2.countNonZero(thresh[:, mid:]) + 1
            return left_white / right_white
        except Exception as e:
            logger.warning(f"Gaze ratio failed: {e}")
            return 1.0


class Calibration:
    """Running average of the per-frame gaze ratio of both eyes.

    Each :meth:`update` averages ``compute_gaze_ratio()`` of the two eyes and
    appends it to a 30-sample window. ``mean_ratio`` is the mean of that
    window (1.0 before any sample). ``n`` counts successful updates (capped).
    """

    # Number of updates after which calibration is considered complete.
    _REQUIRED_SAMPLES = 9
    # Upper bound for the update counter.
    _MAX_COUNT = 1_000_000

    def __init__(self):
        self.n = 0
        self.mean_ratio = 1.0
        self.window = deque(maxlen=30)

    def is_complete(self):
        """Return True once at least 9 samples have been recorded."""
        return self.n >= self._REQUIRED_SAMPLES

    def update(self, eye_left, eye_right):
        """Add one sample from both eyes; does nothing if either eye is missing.

        Best-effort: a failing sample is logged at debug level and skipped
        rather than propagated, so tracking is never aborted by calibration.
        """
        try:
            if not eye_left or not eye_right:
                return
            ratio = (eye_left.compute_gaze_ratio() + eye_right.compute_gaze_ratio()) / 2.0
        except Exception as e:
            logger.debug(f"Calibration sample skipped: {e}")
            return
        self.window.append(ratio)
        self.n = min(self.n + 1, self._MAX_COUNT)
        # The window is non-empty here: a sample was just appended.
        self.mean_ratio = float(np.mean(self.window))


class GazeTrackingService:
    """
    Gaze Tracking Service - Ported from legacy system.
    Tracks user's gaze direction using computer vision techniques.
    """

    def __init__(self):
        """Initialize the gaze tracking service."""
        self.webcam = None
        self.frame = None
        self.eye_left = None
        self.eye_right = None
        self.gaze_ratio = 0.0
        self.gaze_ratio_left = 0.0
        self.gaze_ratio_right = 0.0
        self.total_frames_processed = 0
        self.successful_frames = 0
        self.error_frames = 0
        # Smooth recent horizontal ratios for stability
        self._ratio_window = deque(maxlen=15)

        # Initialize dlib models
        self._face_detector = None
        self._predictor = None
        self._init_models()

        # Initialize calibration
        self.calibration = Calibration()
        self._ratio_window = deque(maxlen=5)

    def _init_models(self):
        """Initialize computer vision models."""
        try:
            self._face_detector = dlib.get_frontal_face_detector()
            model_path = self._get_model_path("shape_predictor_68_face_landmarks.dat")
            if not os.path.exists(model_path):
                logger.error("shape_predictor_68_face_landmarks.dat not found")
                self._predictor = None
                return
            self._predictor = dlib.shape_predictor(model_path)
            logger.info("dlib gaze tracking models initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize dlib gaze tracking models: {e}")
            self._face_detector = None
            self._predictor = None

    def _get_model_path(self, model_name: str) -> str:
        """Get the path to a model file using shared helper."""
        if model_name == "shape_predictor_68_face_landmarks.dat":
            return AnalysisUtils.get_resource_path("models", "dlib", model_name)
        return AnalysisUtils.get_resource_path("models", model_name)

    @property
    def pupils_located(self) -> bool:
        """Check if pupils have been located."""
        try:
            if not self.eye_left or not self.eye_right:
                return False

            # Check if pupil coordinates are valid
            int(self.eye_left.pupil.x)
            int(self.eye_left.pupil.y)
            int(self.eye_right.pupil.x)
            int(self.eye_right.pupil.y)
            return True

        except (ValueError, AttributeError, TypeError):
            return False

    def refresh(self, frame):
        """Refresh the frame and analyze it."""
        self.frame = frame
        self._analyze()

    def _analyze(self):
        """Detect face and initialize Eye-like objects using legacy-compatible logic."""
        if self.frame is None:
            return
        try:
            gray = cv2.cvtColor(self.frame, cv2.COLOR_BGR2GRAY)
            if self._face_detector is None or self._predictor is None:
                self.eye_left = None
                self.eye_right = None
                return
            faces = self._face_detector(gray)
            if len(faces) == 0:
                self.eye_left = None
                self.eye_right = None
                return
            shape = self._predictor(gray, faces[0])
            parts = [(p.x, p.y) for p in shape.parts()]
            self.eye_left = Eye(gray, parts, 0, self.calibration)
            self.eye_right = Eye(gray, parts, 1, self.calibration)
            # Use legacy horizontal ratio directly (pupil position relative to eye center)
            self.gaze_ratio = self.horizontal_ratio() or 0.5
            self.gaze_ratio_left = self.gaze_ratio
            self.gaze_ratio_right = self.gaze_ratio
            # Update smoothing window only when detections are reliable
            try:
                if (
                    self.pupils_located
                    and getattr(self.eye_left, "iris_quality", 1.0) >= 0.6
                    and getattr(self.eye_right, "iris_quality", 1.0) >= 0.6
                ):
                    self._ratio_window.append(self.gaze_ratio)
            except Exception:
                pass  # nosec B110 - Intentionally ignore exceptions in gaze ratio calculations
            self.total_frames_processed += 1
            self.successful_frames += 1
        except Exception as e:
            logger.warning(f"dlib-based gaze analysis failed: {e}")
            self.eye_left = None
            self.eye_right = None

    def horizontal_ratio(self) -> float:
        """
        Legacy horizontal gaze ratio in [0.0, 1.0]:
        - 0.0 = extreme right
        - 0.5 = center
        - 1.0 = extreme left
        Computed from pupil X relative to eye center, matching legacy implementation.
        """
        if not self.pupils_located:
            return 0.5

        try:
            # Mirror legacy formula:
            # pupil_left = self.eye_left.pupil.x / (self.eye_left.center[0] * 2 - 10)
            # pupil_right = self.eye_right.pupil.x / (self.eye_right.center[0] * 2 - 10)
            # ratio = (pupil_left + pupil_right) / 2
            denom_left = max(self.eye_left.center[0] * 2 - 10, 1)
            denom_right = max(self.eye_right.center[0] * 2 - 10, 1)
            pupil_left = self.eye_left.pupil.x / float(denom_left)
            pupil_right = self.eye_right.pupil.x / float(denom_right)
            return (pupil_left + pupil_right) / 2.0

        except Exception as e:
            logger.warning(f"Horizontal ratio calculation failed: {e}")
            return 0.5

    def vertical_ratio(self) -> float:
        """
        Returns a number between 0.0 and 1.0 indicating vertical gaze direction.
        - 0.0 = extreme top
        - 0.5 = center
        - 1.0 = extreme bottom
        """
        if not self.pupils_located:
            return 0.5

        try:
            # Calculate vertical ratio using eye region height
            left_height = max(self.eye_left.height, 1)
            right_height = max(self.eye_right.height, 1)
            pupil_left = self.eye_left.pupil.y / left_height
            pupil_right = self.eye_right.pupil.y / right_height
            # Return average of both eyes
            return (pupil_left + pupil_right) / 2

        except Exception as e:
            logger.warning(f"Vertical ratio calculation failed: {e}")
            return 0.5

    def is_right(self) -> bool:
        """Returns True if user is looking to the right (legacy thresholds)."""
        if not self.pupils_located:
            return False
        ratio = (
            float(sum(self._ratio_window) / len(self._ratio_window))
            if len(self._ratio_window) > 0
            else self.horizontal_ratio()
        )
        return ratio <= 0.35

    def is_left(self) -> bool:
        """Returns True if user is looking to the left (legacy thresholds)."""
        if not self.pupils_located:
            return False
        ratio = (
            float(sum(self._ratio_window) / len(self._ratio_window))
            if len(self._ratio_window) > 0
            else self.horizontal_ratio()
        )
        return ratio >= 0.65

    def is_center(self) -> bool:
        """Returns True if user is looking to the center (legacy logic)."""
        if not self.pupils_located:
            return False
        ratio = (
            float(sum(self._ratio_window) / len(self._ratio_window))
            if len(self._ratio_window) > 0
            else self.horizontal_ratio()
        )
        return 0.35 < ratio < 0.65

    def is_blinking(self) -> bool:
        """Returns True if user is blinking."""
        if not self.pupils_located:
            return False

        try:
            blinking_ratio = (self.eye_left.blinking + self.eye_right.blinking) / 2
            return blinking_ratio > 3.8

        except Exception as e:
            logger.warning(f"Blinking detection failed: {e}")
            return False

    def get_gaze_direction(self) -> str:
        """Get current gaze direction as a string."""
        if not self.pupils_located:
            return "unknown"

        if self.is_right():
            return "right"
        elif self.is_left():
            return "left"
        elif self.is_center():
            return "center"
        else:
            return "unknown"

    def get_pupil_coordinates(self) -> dict[str, tuple[int, int] | None]:
        """Get pupil coordinates for both eyes."""
        coords = {"left": None, "right": None}

        if self.pupils_located:
            try:
                if self.eye_left:
                    x = self.eye_left.origin[0] + self.eye_left.pupil.x
                    y = self.eye_left.origin[1] + self.eye_left.pupil.y
                    coords["left"] = (x, y)

                if self.eye_right:
                    x = self.eye_right.origin[0] + self.eye_right.pupil.x
                    y = self.eye_right.origin[1] + self.eye_right.pupil.y
                    coords["right"] = (x, y)

            except Exception as e:
                logger.warning(f"Failed to get pupil coordinates: {e}")

        return coords

    def annotated_frame(self):
        """Returns the frame with pupils highlighted."""
        if self.frame is None:
            return None

        frame = self.frame.copy()

        if self.pupils_located:
            try:
                color = (0, 255, 0)  # Green
                coords = self.get_pupil_coordinates()

                # Draw left pupil
                if coords["left"]:
                    x, y = coords["left"]
                    cv2.line(frame, (x - 5, y), (x + 5, y), color, 2)
                    cv2.line(frame, (x, y - 5), (x, y + 5), color, 2)

                # Draw right pupil
                if coords["right"]:
                    x, y = coords["right"]
                    cv2.line(frame, (x - 5, y), (x + 5, y), color, 2)
                    cv2.line(frame, (x, y - 5), (x, y + 5), color, 2)

            except Exception as e:
                logger.warning(f"Failed to annotate frame: {e}")

        return frame

    def get_analysis_summary(self) -> dict[str, Any]:
        """Get a summary of the gaze tracking analysis."""
        return {
            "pupils_located": self.pupils_located,
            "gaze_ratio": self.gaze_ratio,
            "gaze_ratio_left": self.gaze_ratio_left,
            "gaze_ratio_right": self.gaze_ratio_right,
            "total_frames_processed": self.total_frames_processed,
            "successful_frames": self.successful_frames,
            "error_frames": self.error_frames,
        }

    async def analyze_video_gaze(self, video_path: str) -> dict[str, Any]:
        """
        Analyze gaze tracking for an entire video file.

        Args:
            video_path: Path to the video file

        Returns:
            Dict containing gaze analysis results
        """
        try:
            # Initialize counters
            center_count = 0
            left_count = 0
            right_count = 0
            total_processed = 0
            error_frames = 0

            def on_frame(idx: int, frame):
                nonlocal \
                    center_count, \
                    left_count, \
                    right_count, \
                    total_processed, \
                    error_frames
                try:
                    self.refresh(frame)
                    if self.pupils_located:
                        if self.is_center():
                            center_count += 1
                        elif self.is_left():
                            left_count += 1
                        elif self.is_right():
                            right_count += 1
                        total_processed += 1
                    else:
                        error_frames += 1
                except Exception as e:
                    logger.warning(f"Frame {idx} analysis failed: {e}")
                    error_frames += 1

            stats = await asyncio.to_thread(foreach_frame, video_path, on_frame)

            # Calculate percentages (legacy used strings rounded to 2 decimals)
            total_frames = stats["total_frames"] or 0
            if total_frames > 0:
                total_eye_percentage = round((total_processed / total_frames) * 100, 2)
                center_percentage = round((center_count / total_frames) * 100, 2)
                left_percentage = round((left_count / total_frames) * 100, 2)
                right_percentage = round((right_count / total_frames) * 100, 2)
            else:
                total_eye_percentage = center_percentage = left_percentage = (
                    right_percentage
                ) = 0

            # Determine eye contact behavior (legacy logic)
            if center_percentage > max(left_percentage, right_percentage):
                eye_contact_behavior = (
                    "Good eye contact [based on the Total Eye Capture]"
                )
            elif left_percentage > max(center_percentage, right_percentage):
                eye_contact_behavior = (
                    "Leftside eye contact [based on the Total Eye Capture]"
                )
            elif right_percentage > max(center_percentage, left_percentage):
                eye_contact_behavior = (
                    "Rightside eye contact [based on the Total Eye Capture]"
                )
            else:
                eye_contact_behavior = "Mixed eye contact patterns"

            logger.info(
                f"Gaze analysis completed: {total_processed} successful frames, {error_frames} errors"
            )

            return {
                "left_side": left_count,
                "right_side": right_count,
                "center": center_count,
                "total_percentage": total_eye_percentage,
                "center_percentage": center_percentage,
                "left_percentage": left_percentage,
                "right_percentage": right_percentage,
                "behavior": eye_contact_behavior,
                "total_frames": total_frames,
                "processed_frames": total_processed,
                "error_frames": error_frames,
            }

        except Exception as e:
            logger.error(f"Video gaze analysis failed: {e}")
            return {
                "left_side": 0,
                "right_side": 0,
                "center": 0,
                "total_percentage": 0,
                "center_percentage": 0,
                "left_percentage": 0,
                "right_percentage": 0,
                "behavior": "Error in analysis",
                "total_frames": 0,
                "processed_frames": 0,
                "error_frames": 0,
            }
