import uuid
from time import perf_counter, time

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from app.core.logger import logger, set_request_id, get_request_id


class LoggingMiddleware(BaseHTTPMiddleware):
    """
    Middleware to log each API request and response.

    For every request it:

    * generates a fresh UUID4 request ID, passes it to ``set_request_id``,
      stores it on ``request.state.request_id`` so endpoints can read it, and
      echoes it back in the ``X-Request-ID`` response header;
    * logs a "Request started" line with method, path and client IP;
    * logs a completion line with method, path, status code, duration in
      milliseconds and client IP. If the downstream app raises, the
      exception is logged with its traceback and the completion line is
      emitted with status 500 before the exception is re-raised.

    The request ID itself is not part of these messages; presumably the app
    logger attaches it from whatever ``set_request_id`` stores — confirm in
    ``app.core.logger``.
    """

    async def dispatch(self, request: Request, call_next):
        """Log *request*, forward it to *call_next*, and log the outcome.

        Args:
            request: The incoming request.
            call_next: Starlette callable that runs the rest of the app and
                returns the response.

        Returns:
            The downstream response, with an ``X-Request-ID`` header added.

        Raises:
            Exception: Anything raised downstream is re-raised unchanged after
                being logged.
        """
        # A new ID per request; client-supplied X-Request-ID headers are
        # deliberately not trusted/reused here.
        request_id = str(uuid.uuid4())
        # NOTE(review): never reset after the request — confirm set_request_id
        # is context-local (e.g. a ContextVar) so IDs do not leak across tasks.
        set_request_id(request_id)
        # Store request ID in request state for access in endpoints.
        request.state.request_id = request_id

        # Monotonic clock: unaffected by wall-clock adjustments mid-request.
        start = perf_counter()
        client_host = request.client.host if request.client else "unknown"

        logger.info(
            "Request started | %(method)s %(url)s | Client: %(client_ip)s",
            {
                "method": request.method,
                "url": request.url.path,
                "client_ip": client_host,
            },
        )

        try:
            response = await call_next(request)
            response.headers["X-Request-ID"] = request_id
        except Exception as e:
            logger.exception(
                "Exception handling request | %(method)s %(url)s | Error: %(error)s",
                {
                    "method": request.method,
                    "url": request.url.path,
                    "error": str(e),
                },
            )
            # Previously the 500 status was assigned but never logged because
            # the raise skipped the completion line; log it before re-raising.
            self._log_completion(request, 500, start, client_host)
            raise

        self._log_completion(request, response.status_code, start, client_host)
        return response

    @staticmethod
    def _log_completion(
        request: Request, status_code: int, start: float, client_host: str
    ) -> None:
        """Log the per-request summary line: status, duration since *start*, client."""
        process_time = (perf_counter() - start) * 1000  # ms
        logger.info(
            "%(method)s %(url)s | Status: %(status_code)s | "
            "Duration: %(duration_ms)s ms | Client: %(client_ip)s",
            {
                "method": request.method,
                "url": request.url.path,
                "status_code": status_code,
                "duration_ms": f"{process_time:.2f}",
                "client_ip": client_host,
            },
        )