# kafka_client.py
import asyncio
import json
from typing import Any, Dict, Optional, AsyncGenerator
from aiokafka import AIOKafkaProducer, AIOKafkaConsumer
from app.core.config import settings
from app.core.logger import logger


class KafkaClient:
    """Async Kafka wrapper around aiokafka producers and consumers.

    One instance owns:

    * a lazily started shared producer (``start_producer`` / ``send_message``),
    * an optional single ad-hoc consumer (``start_consumer`` /
      ``consume_messages``), and
    * a set of background consumers, one per message type, started by
      ``start_all_consumers`` and each driven by its own asyncio task.

    Connection and topic settings are read from ``settings`` in ``__init__``.
    """

    def __init__(self):
        """Read configuration and build the shared client config.

        Nothing is started here; producers and consumers are created on demand
        by the ``start_*`` methods.
        """
        self.bootstrap_servers = settings.KAFKA_BROKER
        self.default_topic = settings.KAFKA_TOPIC
        self.group_id = settings.KAFKA_GROUP_ID
        self.security_protocol = settings.KAFKA_SECURITY_PROTOCOL
        self.sasl_mechanism = settings.KAFKA_SASL_MECHANISM
        self.username = settings.KAFKA_USER
        self.password = settings.KAFKA_PASSWORD

        # Topic configurations
        self.webhook_topic = settings.KAFKA_WEBHOOK_TOPIC
        self.sync_topic = settings.KAFKA_SYNC_TOPIC
        self.push_topic = settings.KAFKA_PUSH_TOPIC
        self.ast_topic = settings.KAFKA_AST_TOPIC
        self.chunking_topic = settings.KAFKA_CHUNKING_TOPIC

        # Keyword arguments shared by every producer/consumer constructor.
        self.common_config = {
            "bootstrap_servers": self.bootstrap_servers,
        }

        # Only add security config if provided
        if self.security_protocol:
            self.common_config["security_protocol"] = self.security_protocol

        if self.username and self.password:
            self.common_config.update({
                "sasl_mechanism": self.sasl_mechanism or "PLAIN",
                "sasl_plain_username": self.username,
                "sasl_plain_password": self.password,
            })

        self.producer: Optional[AIOKafkaProducer] = None
        self.consumer: Optional[AIOKafkaConsumer] = None
        # NOTE(review): these locks are created when the instance is
        # constructed, which for the module-level instance is import time.
        # On Python < 3.10 asyncio.Lock binds to the event loop current at
        # construction -- confirm the runtime version if a different loop
        # (e.g. one created by asyncio.run) is used later.
        self._producer_lock = asyncio.Lock()
        self._consumer_lock = asyncio.Lock()

        # Background consumer management: both dicts are keyed by consumer type
        # ("webhook", "sync", "push", "ast", "chunking").
        self.active_consumers: dict = {}
        self.consumer_tasks: dict = {}
        self.is_running = False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _encode_value(value: Any) -> bytes:
        """Return ``value`` as bytes ready to be sent to Kafka.

        ``bytes`` pass through unchanged, ``str`` is UTF-8 encoded, and
        anything else is JSON-serialised first (``default=str`` stringifies
        objects the JSON encoder does not know).
        """
        if not isinstance(value, (str, bytes)):
            value = json.dumps(value, default=str)
        if isinstance(value, str):
            value = value.encode('utf-8')
        return value

    @staticmethod
    def _decode_key(raw: Optional[bytes]) -> Optional[str]:
        """Decode a message key as UTF-8; empty or missing keys become ``None``."""
        return raw.decode('utf-8') if raw else None

    @staticmethod
    def _decode_value(raw: Optional[bytes]) -> Any:
        """Decode a message value as UTF-8 and parse it as JSON when possible.

        Returns ``None`` for an empty/missing value, the parsed JSON object
        when the text is valid JSON, and the plain decoded string otherwise.
        """
        if not raw:
            return None
        text = raw.decode('utf-8')
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return text  # Keep as string if not JSON

    @staticmethod
    async def _close_quietly(client: Any, label: str) -> None:
        """Best-effort ``stop()`` of a client whose ``start()`` failed.

        Errors are only logged at debug level so they never mask the original
        start-up failure.
        """
        try:
            await client.stop()
        except Exception as e:
            logger.debug(f"Ignoring error while cleaning up {label}: {str(e)}")

    def _consumer_kwargs(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Build the ``AIOKafkaConsumer`` keyword arguments for one consumer.

        ``config`` is one entry of the table in ``start_all_consumers``.
        ``max_poll_records`` defaults to 10; ``max_poll_interval_ms`` is
        forwarded only when the entry defines it, so the library default
        applies otherwise.
        """
        kwargs: Dict[str, Any] = dict(self.common_config)
        kwargs.update({
            "group_id": config["group_id"],
            "auto_offset_reset": "earliest",
            "enable_auto_commit": False,
            "session_timeout_ms": config["session_timeout_ms"],
            "heartbeat_interval_ms": config["heartbeat_interval_ms"],
            "max_poll_records": config.get("max_poll_records", 10),
        })
        if "max_poll_interval_ms" in config:
            kwargs["max_poll_interval_ms"] = config["max_poll_interval_ms"]
        return kwargs

    # ------------------------------------------------------------------
    # Producer
    # ------------------------------------------------------------------

    async def start_producer(self):
        """Create and start the shared producer if it is not running yet.

        Safe to call repeatedly; concurrent callers are serialised by a lock.

        Raises:
            Exception: whatever the aiokafka producer raised while being
                constructed or started. ``self.producer`` stays ``None`` in
                that case so a later call can retry.
        """
        async with self._producer_lock:
            if self.producer is not None:
                return

            producer = None
            try:
                producer = AIOKafkaProducer(**self.common_config)
                await producer.start()
            except Exception as e:
                logger.error(f"Failed to start Kafka producer: {str(e)}")
                if producer is not None:
                    await self._close_quietly(producer, "producer")
                raise

            # Publish the producer only once it is actually usable.
            self.producer = producer
            logger.info("Kafka producer started successfully")

    async def send_message(self, key: str, value: Any, topic: Optional[str] = None) -> bool:
        """Send one message and wait for delivery.

        Args:
            key: Message key; ``str`` keys are UTF-8 encoded, other values
                (e.g. ``bytes`` or ``None``) are passed through unchanged.
            value: Payload; see ``_encode_value`` for how it is serialised.
            topic: Target topic; defaults to ``self.default_topic``.

        Returns:
            ``True`` when the broker acknowledged the message, ``False`` when
            serialisation or sending failed (the error is logged).

        Raises:
            Exception: if the producer had to be started and starting failed.
        """
        if not self.producer:
            await self.start_producer()

        target_topic = topic or self.default_topic

        try:
            payload = self._encode_value(value)
            key_bytes = key.encode('utf-8') if isinstance(key, str) else key

            # send() returns a future for the delivery result; awaiting it
            # surfaces delivery errors here.
            future = await self.producer.send(target_topic, key=key_bytes, value=payload)
            await future
            logger.debug(f"Message sent to Kafka topic {target_topic}, key: {key}")
            return True

        except Exception as e:
            logger.error(f"Failed to send message to Kafka topic {target_topic}: {str(e)}")
            return False

    async def health_check(self) -> bool:
        """Report whether cluster metadata with at least one broker can be fetched.

        Starts the producer if necessary. Never raises: any error is logged as
        a warning and reported as ``False``.
        """
        try:
            if not self.producer:
                await self.start_producer()

            metadata = await self.producer.client.fetch_all_metadata()

            if not hasattr(metadata, 'brokers'):
                logger.warning("Kafka health check: No brokers attribute in metadata")
                return False

            # brokers is a method, not a collection - call it to get the brokers
            brokers = metadata.brokers()
            if not brokers:
                logger.warning("Kafka health check: No brokers found")
                return False

            logger.info(f"Kafka health check passed. Found {len(brokers)} brokers")
            return True

        except Exception as e:
            logger.warning(f"Kafka health check failed: {str(e)}")
            return False

    async def stop_producer(self):
        """Stop the shared producer, if any.

        Errors while stopping are logged, not raised. The producer reference
        is dropped either way so that a later ``send_message`` starts a fresh
        producer instead of reusing one in an unknown state.
        """
        async with self._producer_lock:
            if not self.producer:
                return
            try:
                await self.producer.stop()
                logger.info("Kafka producer stopped successfully")
            except Exception as e:
                logger.error(f"Error stopping Kafka producer: {str(e)}")
            finally:
                self.producer = None

    # ------------------------------------------------------------------
    # Background per-topic consumers
    # ------------------------------------------------------------------

    async def start_all_consumers(self):
        """Start one background consumer (and processing task) per message type.

        A consumer that fails to start is logged and skipped; the others keep
        running. If no consumer could be started at all, ``is_running`` is
        reset so the call can be retried. Does nothing when already running.
        """
        if self.is_running:
            logger.info("Kafka consumers already running")
            return

        self.is_running = True

        # Per-type consumer settings (all durations in milliseconds).
        # NOTE(review): the session timeouts must be allowed by the broker's
        # group session-timeout limits -- confirm against the cluster config.
        consumer_configs = {
            "webhook": {
                "group_id": "webhook_consumers",
                "topics": [self.webhook_topic],
                "session_timeout_ms": 600000,  # 10 minutes
                "heartbeat_interval_ms": 10000,  # 10 seconds
                "max_poll_records": 10
            },
            "sync": {
                "group_id": "sync_consumers",
                "topics": [self.sync_topic],
                "session_timeout_ms": 600000,  # 10 minutes
                "heartbeat_interval_ms": 10000,  # 10 seconds
                "max_poll_records": 5
            },
            "push": {
                "group_id": "push_consumers",
                "topics": [self.push_topic],
                "session_timeout_ms": 600000,  # 10 minutes
                "heartbeat_interval_ms": 10000,  # 10 seconds
                "max_poll_records": 5
            },
            "ast": {
                "group_id": "ast_consumers",
                "topics": [self.ast_topic],
                "session_timeout_ms": 600000,  # 10 minutes for long AST processing
                # NOTE(review): 5 minutes is half the session timeout; the
                # usual guidance is at most a third -- confirm this is intended.
                "heartbeat_interval_ms": 300000,  # 5 minutes
                "max_poll_interval_ms": 7200000,  # 2 hours
                "max_poll_records": 1  # Process one at a time
            },
            "chunking": {
                "group_id": "chunking_consumers",
                "topics": [self.chunking_topic],
                "session_timeout_ms": 600000,  # 10 minutes for chunking
                "heartbeat_interval_ms": 60000,  # 1 minute
                "max_poll_records": 1
            }
        }

        failed = []
        for consumer_type, config in consumer_configs.items():
            consumer = None
            try:
                consumer = AIOKafkaConsumer(
                    *config["topics"],
                    **self._consumer_kwargs(config)
                )
                await consumer.start()
            except Exception as e:
                logger.error(f"Failed to start {consumer_type} consumer: {str(e)}")
                failed.append(consumer_type)
                if consumer is not None:
                    # Release whatever the half-started consumer acquired.
                    await self._close_quietly(consumer, f"{consumer_type} consumer")
                continue

            self.active_consumers[consumer_type] = consumer

            # Each consumer gets its own long-running processing task.
            self.consumer_tasks[consumer_type] = asyncio.create_task(
                self._process_messages_loop(consumer_type, consumer)
            )
            logger.info(f"Started {consumer_type} consumer for topics: {config['topics']}")

        if not failed:
            logger.info("All Kafka consumers started successfully")
            return

        started = list(self.active_consumers)
        logger.warning(
            f"Kafka consumers started with failures; started: {started}, failed: {failed}"
        )
        if not self.active_consumers:
            # Nothing is running, so allow a later retry.
            self.is_running = False

    async def _process_messages_loop(self, consumer_type: str, consumer: "AIOKafkaConsumer"):
        """Consume messages from ``consumer`` forever and hand them to the router.

        The offset is committed after every routed message. Because
        ``_route_message`` logs processor errors instead of raising them, a
        message whose processor failed is committed as well and is not
        redelivered. Errors in decoding or committing are logged and the loop
        moves on to the next message.
        """
        try:
            async for msg in consumer:
                try:
                    key = self._decode_key(msg.key)
                    value = self._decode_value(msg.value)

                    await self._route_message(consumer_type, msg.topic, key, value)

                    # Auto-commit is disabled, so commit manually after routing.
                    await consumer.commit()

                except Exception as e:
                    logger.error(f"Error processing {consumer_type} message: {str(e)}")
                    continue

        except Exception as e:
            if self.is_running:  # Only log if we're supposed to be running
                logger.error(f"{consumer_type} consumer loop error: {str(e)}")

    async def _route_message(self, consumer_type: str, topic: str, key: Optional[str], value: Any):
        """Dispatch a decoded message to the processor for ``consumer_type``.

        ``topic`` is accepted for context but not used for routing. Processor
        errors are logged and never propagate to the caller.
        """
        handlers = {
            "webhook": self._process_webhook_message,
            "sync": self._process_sync_message,
            "push": self._process_push_message,
            "ast": self._process_ast_message,
            "chunking": self._process_chunking_message,
        }
        handler = handlers.get(consumer_type)
        if handler is None:
            logger.warning(f"No processor registered for consumer type: {consumer_type}")
            return

        try:
            await handler(key, value)
        except Exception as e:
            logger.error(f"Error routing {consumer_type} message: {str(e)}")

    async def _process_webhook_message(self, key: Optional[str], value: Any):
        """Pass a webhook message value to ``process_webhook_message``.

        The processor is imported lazily; if it cannot be imported the message
        is skipped with a warning. Processor errors are logged, not raised.
        """
        # NOTE(review): this logs the full payload at info level -- confirm
        # webhook payloads contain nothing sensitive.
        logger.info(f"Processing webhook: {key} - {value}")
        try:
            from app.services.kafka.kafka_webhook_consumer import process_webhook_message
        except ImportError:
            logger.warning("Webhook processor not available, skipping message")
            return

        try:
            await process_webhook_message(value)
        except Exception as e:
            logger.error(f"Error processing webhook message: {str(e)}")

    async def _process_sync_message(self, key: Optional[str], value: Any):
        """Pass a sync message (key and value) to ``process_sync_message``.

        The processor is imported lazily; if it cannot be imported the message
        is skipped with a warning. Processor errors are logged, not raised.
        """
        try:
            from app.services.kafka.kafka_sync_consumer import process_sync_message
        except ImportError:
            logger.warning("Sync processor not available, skipping message")
            return

        try:
            await process_sync_message(key, value)
        except Exception as e:
            logger.error(f"Error processing sync message: {str(e)}")

    async def _process_push_message(self, key: Optional[str], value: Any):
        """Pass a push message value to ``process_push_message``.

        The processor is imported lazily; if it cannot be imported the message
        is skipped with a warning. Processor errors are logged, not raised.
        """
        logger.info(f"Processing push: {key}")
        try:
            from app.services.kafka.kafka_push_consumer import process_push_message
        except ImportError:
            logger.warning("Push processor not available, skipping message")
            return

        try:
            await process_push_message(value)
        except Exception as e:
            logger.error(f"Error processing push message: {str(e)}")

    async def _process_ast_message(self, key: Optional[str], value: Any):
        """Pass an AST message value to ``process_ast_message``.

        The processor is imported lazily; if it cannot be imported the message
        is skipped with a warning. Processor errors are logged, not raised.
        """
        try:
            from app.services.kafka.kafka_ast_consumer import process_ast_message
        except ImportError:
            logger.warning("AST processor not available, skipping message")
            return

        try:
            await process_ast_message(value)
        except Exception as e:
            logger.error(f"Error processing AST message: {str(e)}")

    async def _process_chunking_message(self, key: Optional[str], value: Any):
        """Pass a chunking message value to ``process_chunking_message``.

        The processor is imported lazily; if it cannot be imported the message
        is skipped with a warning. Processor errors are logged, not raised.
        """
        logger.info(f"Processing chunking: {key} - {value}")
        try:
            from app.services.kafka.kafka_chunking_consumer import process_chunking_message
        except ImportError:
            logger.warning("Chunking processor not available, skipping message")
            return

        try:
            await process_chunking_message(value)
        except Exception as e:
            logger.error(f"Error processing chunking message: {str(e)}")

    async def stop_all_consumers(self):
        """Cancel every processing task and stop every background consumer.

        Errors while stopping an individual consumer are logged and do not
        prevent the remaining consumers from being stopped. Does nothing when
        the consumers are not running.
        """
        if not self.is_running:
            logger.info("Kafka consumers already stopped")
            return

        # Flip the flag first so loops ending due to cancellation stay quiet.
        self.is_running = False
        logger.info("Stopping all Kafka consumers...")

        # Cancel all consumer tasks and wait for each one to unwind.
        for consumer_type, task in self.consumer_tasks.items():
            if task.done():
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            logger.info(f"Cancelled {consumer_type} consumer task")

        # Stop all consumers
        for consumer_type, consumer in self.active_consumers.items():
            try:
                await consumer.stop()
                logger.info(f"Stopped {consumer_type} consumer")
            except Exception as e:
                logger.error(f"Error stopping {consumer_type} consumer: {str(e)}")

        self.active_consumers.clear()
        self.consumer_tasks.clear()

        logger.info("All Kafka consumers stopped successfully")

    # ------------------------------------------------------------------
    # Single ad-hoc consumer
    # ------------------------------------------------------------------

    async def start_consumer(self, group_id: Optional[str] = None, topics: Optional[list] = None):
        """Create and start the single ad-hoc consumer if none exists yet.

        Args:
            group_id: Consumer group; defaults to ``self.group_id``.
            topics: Topics to subscribe to; defaults to ``[self.default_topic]``.

        When a consumer already exists the arguments are ignored and the
        existing consumer is kept.

        Raises:
            Exception: whatever the aiokafka consumer raised while being
                constructed or started. ``self.consumer`` stays ``None`` in
                that case so a later call can retry.
        """
        async with self._consumer_lock:
            if self.consumer is not None:
                return

            consumer = None
            try:
                target_topics = topics or [self.default_topic]
                target_group_id = group_id or self.group_id

                consumer = AIOKafkaConsumer(
                    *target_topics,
                    group_id=target_group_id,
                    **self.common_config,
                    auto_offset_reset="earliest",
                    enable_auto_commit=False,
                    session_timeout_ms=30000,
                    heartbeat_interval_ms=10000
                )
                await consumer.start()
            except Exception as e:
                logger.error(f"Failed to start Kafka consumer: {str(e)}")
                if consumer is not None:
                    await self._close_quietly(consumer, "consumer")
                raise

            # Publish the consumer only once it is actually usable.
            self.consumer = consumer
            logger.info(f"Kafka consumer started for topics: {target_topics}")

    async def consume_messages(self, timeout_ms: Optional[int] = None) -> AsyncGenerator[Dict, None]:
        """Yield messages from the ad-hoc consumer as plain dicts.

        Starts the consumer with default settings when necessary. Each yielded
        dict has the keys ``topic``, ``partition``, ``offset``, ``key``,
        ``value``, ``timestamp`` and ``headers``. The offset is committed when
        the caller asks for the next message, i.e. after the caller has
        finished handling the previous one.

        Args:
            timeout_ms: Currently ignored; kept for interface compatibility.
                TODO: decide on the intended timeout semantics and implement.

        Raises:
            Exception: errors from the consumer iteration itself are logged
                and re-raised; per-message errors are logged and skipped.
        """
        if not self.consumer:
            await self.start_consumer()

        # Hold a local reference so a concurrent stop_consumer() resetting
        # self.consumer cannot turn the commit below into an AttributeError.
        consumer = self.consumer

        try:
            async for msg in consumer:
                try:
                    message_data = {
                        'topic': msg.topic,
                        'partition': msg.partition,
                        'offset': msg.offset,
                        'key': self._decode_key(msg.key),
                        'value': self._decode_value(msg.value),
                        'timestamp': msg.timestamp,
                        'headers': dict(msg.headers) if msg.headers else {}
                    }

                    yield message_data

                    # Auto-commit is disabled, so commit manually once the
                    # caller resumes the generator.
                    await consumer.commit()

                except Exception as e:
                    logger.error(f"Error processing Kafka message: {str(e)}")
                    continue

        except Exception as e:
            logger.error(f"Kafka consumer error: {str(e)}")
            raise

    async def stop_consumer(self):
        """Stop the ad-hoc consumer, if any.

        Errors while stopping are logged, not raised. The consumer reference
        is dropped either way so that a later ``start_consumer`` creates a
        fresh consumer.
        """
        async with self._consumer_lock:
            if not self.consumer:
                return
            try:
                await self.consumer.stop()
                logger.info("Kafka consumer stopped successfully")
            except Exception as e:
                logger.error(f"Error stopping Kafka consumer: {str(e)}")
            finally:
                self.consumer = None

    # ------------------------------------------------------------------
    # Async context manager
    # ------------------------------------------------------------------

    async def __aenter__(self):
        """Start the producer and return the client."""
        await self.start_producer()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop the producer and the ad-hoc consumer.

        Background consumers started by ``start_all_consumers`` are not
        stopped here; use ``stop_all_consumers`` for those.
        """
        await self.stop_producer()
        await self.stop_consumer()


# Module-level shared client used by the helper functions below.
# It is constructed at import time; construction only reads settings and
# builds configuration -- no producer or consumer is started until requested.
kafka_client = KafkaClient()


# Utility functions for all topic types
async def send_webhook_message(webhook_id: int, data: Dict) -> bool:
    """Publish ``data`` on the webhook topic, keyed as ``webhook_<webhook_id>``.

    Returns the result of ``kafka_client.send_message``.
    """
    client = kafka_client
    message_key = f"webhook_{webhook_id}"
    delivered = await client.send_message(
        key=message_key, value=data, topic=client.webhook_topic
    )
    return delivered

async def send_sync_message(user_id: str, data: Dict) -> bool:
    """Publish ``data`` on the sync topic, keyed as ``sync_<user_id>``.

    Returns the result of ``kafka_client.send_message``.
    """
    client = kafka_client
    message_key = f"sync_{user_id}"
    delivered = await client.send_message(
        key=message_key, value=data, topic=client.sync_topic
    )
    return delivered

async def send_push_message(repo_id: str, data: Dict) -> bool:
    """Publish ``data`` on the push topic, keyed as ``push_<repo_id>``.

    Returns the result of ``kafka_client.send_message``.
    """
    client = kafka_client
    message_key = f"push_{repo_id}"
    delivered = await client.send_message(
        key=message_key, value=data, topic=client.push_topic
    )
    return delivered

async def send_ast_processing_message(file_id: str, data: Dict) -> bool:
    """Publish ``data`` on the AST topic, keyed as ``ast_<file_id>``.

    Returns the result of ``kafka_client.send_message``.
    """
    client = kafka_client
    message_key = f"ast_{file_id}"
    delivered = await client.send_message(
        key=message_key, value=data, topic=client.ast_topic
    )
    return delivered

async def send_chunking_message(file_id: str, data: Dict) -> bool:
    """Publish ``data`` on the chunking topic, keyed as ``chunk_<file_id>``.

    Returns the result of ``kafka_client.send_message``.
    """
    client = kafka_client
    message_key = f"chunk_{file_id}"
    delivered = await client.send_message(
        key=message_key, value=data, topic=client.chunking_topic
    )
    return delivered


# Consumer management functions
async def start_all_consumers():
    """Delegate to ``start_all_consumers`` on the shared ``kafka_client``."""
    client = kafka_client
    await client.start_all_consumers()

async def stop_all_consumers():
    """Delegate to ``stop_all_consumers`` on the shared ``kafka_client``."""
    client = kafka_client
    await client.stop_all_consumers()


# App event handlers
async def on_app_startup():
    """Application startup hook: log, then start the shared client's consumers."""
    logger.info("Starting Kafka consumers...")
    await kafka_client.start_all_consumers()

async def on_app_shutdown():
    """Application shutdown hook: log, then stop the shared client's consumers."""
    logger.info("Stopping Kafka consumers...")
    await kafka_client.stop_all_consumers()

# Manual smoke test: sends sample messages, runs the consumers briefly, then shuts down
async def main():
    """Run a manual smoke test against the configured Kafka cluster.

    Runs a health check, sends the sample messages, starts all consumers for
    five seconds, and then always stops the consumers and the producer.
    The final log line reflects whether any step failed.
    """
    # Message type -> sender helper. An unknown type is reported instead of
    # reusing (or never binding) the previous iteration's result.
    senders = {
        "webhook": send_webhook_message,
        "sync": send_sync_message,
        "push": send_push_message,
        "ast": send_ast_processing_message,
        "chunking": send_chunking_message,
    }
    had_errors = False

    try:
        logger.info("Starting Kafka client test...")

        is_healthy = await kafka_client.health_check()
        if not is_healthy:
            logger.warning("Kafka health check failed")

        # Sample messages as (key, message type, payload). The commented-out
        # entries are kept as examples for the other message types.
        test_messages = [
            # (123, "webhook", {"event_type": "installation", "user_id": "user_123"}),
            # ("user_123", "sync", {"sync_type": "manual", "repositories": ["repo1"]}),
            # ("repo_789", "push", {"branch": "main", "commits": ["abc123"]}),
            ("fbcf50d0001df0ba44165bcc22bd84f3", "ast", {
                "file_id": "fbcf50d0001df0ba44165bcc22bd84f3",
                "action": "process_ast",
            }),
            # ("file_xyz", "chunking", {"chunk_count": 5, "file_size": 1024})
        ]

        for key, msg_type, data in test_messages:
            sender = senders.get(msg_type)
            if sender is None:
                had_errors = True
                logger.error(f"Unknown message type: {msg_type}")
                continue

            if await sender(key, data):
                logger.info(f"✓ Successfully sent {msg_type} message")
            else:
                had_errors = True
                logger.error(f"✗ Failed to send {msg_type} message")

        # Start consumers and wait briefly
        await start_all_consumers()
        await asyncio.sleep(5)  # Let consumers process for 5 seconds

    except Exception as e:
        had_errors = True
        logger.error(f"Test failed: {str(e)}")
    finally:
        # Always shut down, even when the test body failed.
        await stop_all_consumers()
        await kafka_client.stop_producer()
        if had_errors:
            logger.warning("Test completed with errors")
        else:
            logger.info("Test completed successfully")


# Run the manual smoke test only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())