import cv2
import numpy as np
from typing import List, Tuple, Union
import os
import base64
import binascii
from fastapi import HTTPException, status
from app.config import settings
from app.schemas.face_detection import FaceBox
from app.schemas.image_input import ImageInputBase64
from loguru import logger


class FaceDetectionService:
    """Face detection backed by OpenCV's YuNet model (``cv2.FaceDetectorYN``).

    Images arrive either as encoded bytes (anything ``cv2.imdecode`` can read)
    or as a base64 string of such bytes. Failures are reported as FastAPI
    ``HTTPException``s: 400 for input that cannot be decoded, 413 for oversized
    input, and 500 for unexpected errors during detection.
    """

    def __init__(self):
        self.model_path = settings.yunet_model_path
        # Minimum detection score: passed to YuNet and re-checked per detection.
        self.confidence_threshold = 0.6
        self.detector = None
        self._load_model()

    def _load_model(self):
        """Create the YuNet detector from ``self.model_path``.

        Raises:
            FileNotFoundError: if the model file does not exist.
            HTTPException: 500 if OpenCV fails to create the detector.
        """
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"YuNet model not found at {self.model_path}")

        try:
            self.detector = cv2.FaceDetectorYN.create(
                self.model_path,
                "",  # no separate config file
                (320, 240),  # placeholder; reset to each image's size before detect()
                self.confidence_threshold,
            )
        except Exception as e:
            logger.exception(f"Failed to load YuNet model: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to load face detection model",
            ) from e
        logger.info("YuNet face detection model loaded successfully")

    @staticmethod
    def _decode_image(image_bytes: bytes, invalid_detail: str) -> np.ndarray:
        """Decode encoded image bytes into a 3-channel OpenCV image.

        Args:
            image_bytes: Encoded image data (e.g. JPEG/PNG file contents).
            invalid_detail: ``detail`` text for the 400 error on failure.

        Raises:
            HTTPException: 400 with ``invalid_detail`` when the bytes are empty
                or OpenCV cannot decode them.
        """
        buffer = np.frombuffer(image_bytes, np.uint8)
        image = None
        # OpenCV asserts on an empty buffer rather than returning None, so empty
        # input is rejected here as a client error.
        if buffer.size > 0:
            try:
                image = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
            except cv2.error:
                image = None
        if image is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=invalid_detail,
            )
        return image

    def _detect_from_encoded(
        self, image_bytes: bytes, invalid_detail: str, source: str
    ) -> List[FaceBox]:
        """Decode ``image_bytes`` and run detection, mapping failures to HTTP errors.

        Args:
            image_bytes: Encoded image data.
            invalid_detail: ``detail`` text for the 400 raised on undecodable input.
            source: Label used only in the error log line ("bytes" / "base64").

        Raises:
            HTTPException: 400 for undecodable input; 500 for any other failure.
        """
        try:
            image = self._decode_image(image_bytes, invalid_detail)
            return self._detect_faces_from_image(image)
        except HTTPException:
            # Keep the original status (e.g. 400 for bad input) instead of
            # letting the generic handler below turn it into a 500.
            raise
        except Exception as e:
            logger.exception(f"Face detection error from {source}: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Face detection failed: {str(e)}",
            ) from e

    def detect_faces_from_bytes(self, image_bytes: bytes) -> List[FaceBox]:
        """Detect faces in an encoded image supplied as raw bytes.

        Raises:
            HTTPException: 400 ("Invalid image format") if the bytes cannot be
                decoded; 500 if detection fails unexpectedly.
        """
        return self._detect_from_encoded(image_bytes, "Invalid image format", "bytes")

    def detect_faces_from_base64(self, base64_data: str) -> List[FaceBox]:
        """Detect faces in an encoded image supplied as a base64 string.

        Raises:
            HTTPException: 400 for invalid base64 or an undecodable image;
                500 if detection fails unexpectedly.
        """
        try:
            image_bytes = base64.b64decode(base64_data)
        except (ValueError, TypeError) as e:
            # binascii.Error (bad padding/characters) subclasses ValueError; a str
            # containing non-ASCII characters raises a plain ValueError, and a
            # non-string/non-bytes argument raises TypeError.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid base64 encoding",
            ) from e
        return self._detect_from_encoded(
            image_bytes, "Invalid base64 image format", "base64"
        )

    def _detect_faces_from_image(self, image: np.ndarray) -> List[FaceBox]:
        """Run YuNet on a decoded image and return boxes at or above the threshold.

        Each box is clipped to the image bounds, so its width/height shrink when
        a detection extends past an edge.
        """
        img_h, img_w = image.shape[:2]

        # The detector's input size must match the image being scored.
        # NOTE(review): this mutates the shared detector; if the service is
        # called concurrently from several threads, setInputSize/detect calls
        # could interleave — confirm how the service is invoked.
        self.detector.setInputSize((img_w, img_h))
        _, faces = self.detector.detect(image)

        face_boxes = []
        if faces is not None:
            for face in faces:
                # YuNet returns: [x, y, w, h, x_re, y_re, x_le, y_le, x_nt, y_nt,
                #                 x_rcm, y_rcm, x_lcm, y_lcm, confidence]
                confidence = float(face[14])
                if confidence < self.confidence_threshold:
                    continue

                # Native ints (not numpy scalars) for the FaceBox schema.
                x, y, w, h = (int(v) for v in face[:4])
                # Clip both corners; clamping only x/y would shift the box.
                left = min(max(x, 0), img_w)
                top = min(max(y, 0), img_h)
                right = min(max(x + w, 0), img_w)
                bottom = min(max(y + h, 0), img_h)

                face_boxes.append(
                    FaceBox(
                        x=left,
                        y=top,
                        width=max(0, right - left),
                        height=max(0, bottom - top),
                        confidence=confidence,
                    )
                )

        logger.info(f"Detected {len(face_boxes)} faces in image")
        return face_boxes

    # Backward compatibility
    def detect_faces(self, image_bytes: bytes) -> List[FaceBox]:
        """Alias of :meth:`detect_faces_from_bytes` kept for older callers."""
        return self.detect_faces_from_bytes(image_bytes)

    def validate_image_file(self, filename: str, file_size: int) -> bool:
        """Check an uploaded file's size and extension against the settings.

        Args:
            filename: Client-supplied file name; a missing name (None/empty) is
                treated as having no extension.
            file_size: Size of the upload in bytes.

        Returns:
            True when the file passes both checks.

        Raises:
            HTTPException: 413 if ``file_size`` exceeds ``settings.max_file_size``;
                400 if the extension is not in ``settings.allowed_extensions``.
        """
        if file_size > settings.max_file_size:
            raise HTTPException(
                status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                detail=f"File size exceeds maximum allowed size of {settings.max_file_size} bytes"
            )

        # Normalise the comma-separated setting so entries such as " PNG" or
        # ".jpg" still match, and ignore empty entries from stray commas.
        allowed_extensions = [
            ext.strip().lstrip('.').lower()
            for ext in settings.allowed_extensions.split(',')
            if ext.strip().lstrip('.')
        ]
        name = (filename or '').lower()
        file_extension = name.rsplit('.', 1)[-1] if '.' in name else ''

        if file_extension not in allowed_extensions:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"File type not allowed. Allowed types: {', '.join(allowed_extensions)}"
            )

        return True

    def validate_base64_image(self, base64_data: str, max_size_mb: int = 10) -> bool:
        """Check that ``base64_data`` is within the size limit and decodes to an image.

        Args:
            base64_data: Base64 text of an encoded image.
            max_size_mb: Not used; the limit is ``settings.max_file_size``.
                Kept so existing callers that pass it keep working.

        Returns:
            True when the data is valid.

        Raises:
            HTTPException: 413 if the estimated decoded size exceeds
                ``settings.max_file_size``; 400 for any other validation failure.
        """
        try:
            # 4 base64 characters encode 3 bytes, so this approximates the decoded
            # size without decoding (slight overestimate for padding/whitespace).
            estimated_file_size = int(len(base64_data) * 0.75)
            if estimated_file_size > settings.max_file_size:
                raise HTTPException(
                    status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                    detail=f"Image size exceeds maximum allowed size of {settings.max_file_size} bytes"
                )

            try:
                image_bytes = base64.b64decode(base64_data)
            except (ValueError, TypeError) as e:
                # binascii.Error subclasses ValueError.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid base64 image data"
                ) from e

            # Decode fully so unreadable image payloads are rejected here.
            self._decode_image(image_bytes, "Invalid image format in base64 data")
            return True

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Base64 validation error: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Base64 image validation failed"
            ) from e