# app/services/ast/ast_processor.py
import asyncio
import hashlib
from typing import Dict, List, Any, Optional, Set, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logger import logger
from app.services.postgress_db_service import pg_db_service as database_service
from datetime import datetime

# Tree-sitter is optional: TREE_SITTER_AVAILABLE records whether both packages imported.
try:
    import tree_sitter
    from tree_sitter_languages import get_language, get_parser
    TREE_SITTER_AVAILABLE = True
except ImportError:
    # The requirement is quoted in the hint: unquoted, a shell would treat
    # ">=1.10.2" as an output redirection and drop the version constraint.
    logger.warning(
        "Tree-sitter not available. Install with: "
        "pip install tree-sitter==0.21.0 'tree-sitter-languages>=1.10.2'"
    )
    TREE_SITTER_AVAILABLE = False

class ASTProcessor:
    """Language-agnostic AST processor built on Tree-sitter.

    Maps a file extension to a Tree-sitter grammar, parses the file and turns
    the syntax tree into a "unified JSON" document containing function,
    class, import, variable and comment chunks, CONTAINS relationships and a
    few simple metrics.

    Node kinds are recognised heuristically: a Tree-sitter node type such as
    ``"function_definition"`` is split on ``_`` and its tokens are matched
    against the keyword tables below, so one code path serves every grammar.
    Matching whole tokens (instead of raw substrings) avoids accidental hits
    such as ``"def"`` in ``"default_parameter"``, ``"use"`` in
    ``"else_clause"``, ``"enum"`` in ``"enumerator"`` or ``"if"`` in
    ``"identifier"``.
    """

    # Tokens identifying function-like definitions.
    _FUNCTION_TOKENS = frozenset({"function", "method", "def", "func", "fn", "lambda", "arrow", "procedure"})
    # Tokens identifying class/struct/interface-like definitions.
    _CLASS_TOKENS = frozenset({"class", "struct", "interface", "trait", "enum", "type"})
    # Tokens identifying import-like statements.
    _IMPORT_TOKENS = frozenset({"import", "require", "using", "include", "use", "from"})
    # Tokens identifying variable declarations / assignments.
    _VARIABLE_TOKENS = frozenset({"variable", "declaration", "assignment", "let", "const", "var", "val"})
    # Tokens identifying comments.
    _COMMENT_TOKENS = frozenset({"comment"})
    # Tokens counted by the control-flow complexity score. "elif" is listed
    # explicitly because token matching no longer finds "if" inside it.
    _COMPLEXITY_TOKENS = frozenset({"if", "elif", "for", "while", "switch", "case", "try", "catch", "match", "when"})
    # Tokens marking a *use* or a sub-part of a construct rather than its
    # definition, e.g. "call_expression", "method_invocation", "class_body",
    # "type_parameters", "lambda_parameters", "type_identifier".
    _NON_DEFINITION_TOKENS = frozenset({
        "call", "invocation", "body", "list", "identifier", "annotation",
        "arguments", "argument", "parameters", "parameter", "heritage",
    })
    # A node type containing "type" only counts as a definition when it also
    # carries one of these tokens ("type_alias_declaration", "type_spec",
    # "type_item", ...). This rejects type references such as "generic_type",
    # "function_type" or "struct_type".
    _TYPE_DEFINITION_TOKENS = frozenset({"declaration", "definition", "alias", "spec", "item", "statement"})

    def __init__(self):
        """Build the extension -> grammar table and load every available parser."""
        self.supported_extensions = {
            '.py': 'python',
            '.js': 'javascript', '.mjs': 'javascript', '.cjs': 'javascript',
            '.ts': 'typescript', '.tsx': 'typescript',
            '.jsx': 'javascript',
            '.java': 'java',
            # '.h' is shared by C and C++; it is parsed with the C grammar.
            # It is listed only once: a duplicate dict key silently overrides.
            '.cpp': 'cpp', '.cc': 'cpp', '.cxx': 'cpp', '.hpp': 'cpp',
            '.c': 'c', '.h': 'c',
            '.cs': 'c_sharp',
            '.php': 'php',
            '.rb': 'ruby',
            '.go': 'go',
            '.rs': 'rust',
            '.kt': 'kotlin', '.kts': 'kotlin',
            '.scala': 'scala',
            '.html': 'html', '.htm': 'html',
            '.css': 'css',
            '.json': 'json',
            '.yaml': 'yaml', '.yml': 'yaml',
            '.md': 'markdown',
            '.sh': 'bash', '.bash': 'bash',
            '.sql': 'sql',
            '.r': 'r',
            '.lua': 'lua',
            '.pl': 'perl', '.pm': 'perl',
        }

        # Language name -> Tree-sitter parser, only for grammars that loaded.
        self.parsers = {}
        self._initialize_parsers()

    def _initialize_parsers(self):
        """Load a Tree-sitter parser for every language in ``supported_extensions``.

        Languages whose grammar cannot be loaded are logged and skipped, so
        ``self.parsers`` only holds usable parsers. Does nothing (apart from a
        warning) when Tree-sitter is not installed.
        """
        if not TREE_SITTER_AVAILABLE:
            logger.warning("Tree-sitter not available. AST processing will be limited.")
            return

        # Sorted so start-up logging is deterministic across runs.
        for lang_name in sorted(set(self.supported_extensions.values())):
            try:
                self.parsers[lang_name] = get_parser(lang_name)
                logger.debug(f"Successfully initialized parser for {lang_name}")
            except Exception as e:
                logger.warning(f"Failed to initialize parser for {lang_name}: {str(e)}")

    async def process_file_to_unified_json(
        self,
        file_content: str,
        file_path: str,
        file_id: str,
        user_id: str,
        repo_id: str,
        branch_id: str
    ) -> Dict[str, Any]:
        """Analyze one file and return its unified JSON document.

        Args:
            file_content: Decoded source text of the file.
            file_path: Path of the file; its extension selects the grammar.
            file_id, user_id, repo_id, branch_id: Copied into ``hierarchy``.

        Returns:
            The unified JSON dict. On any failure an empty document (all
            metrics zero, ``ast_analysis_successful`` False) is returned
            instead of raising.
        """
        # Bound before the try so the fallback in `except` can always use it,
        # even if deriving the extension itself fails.
        language_name = 'unknown'
        try:
            file_extension = "." + file_path.split('.')[-1].lower() if '.' in file_path else ''
            language_name = self.supported_extensions.get(file_extension, 'unknown')

            ast_data = await self._analyze_with_tree_sitter(file_content, file_path, language_name)

            return self._generate_unified_json(
                user_id=user_id,
                repo_id=repo_id,
                branch_id=branch_id,
                file_id=file_id,
                file_path=file_path,
                language=language_name,
                file_content=file_content,
                ast_data=ast_data
            )

        except Exception as e:
            logger.error(f"Failed to generate unified JSON for {file_path}: {str(e)}")
            return self._create_empty_unified_json(user_id, repo_id, branch_id, file_id, file_path, language_name)

    def _generate_unified_json(
        self,
        user_id: str,
        repo_id: str,
        branch_id: str,
        file_id: str,
        file_path: str,
        language: str,
        file_content: str,
        ast_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Assemble the unified JSON document from *ast_data*.

        ``nodes`` is ``ast_data["chunks"]``; the per-kind counts in
        ``analysis.metrics`` are recomputed from those chunks.
        """
        lines = file_content.splitlines()
        nodes = ast_data.get("chunks", [])

        def count(kind: str) -> int:
            """Number of chunks whose ``type`` equals *kind*."""
            return sum(1 for n in nodes if n.get("type") == kind)

        return {
            "hierarchy": {
                "user_id": user_id,
                "repo_id": repo_id,
                "branch_id": branch_id,
                "file_id": file_id
            },
            "metadata": {
                "file_path": file_path,
                "language": language,
                "content_hash": self._generate_content_hash(file_content),
                "processed_at": self._utc_now_iso(),
                "version": "1.0",
                "file_size": len(file_content),
                "encoding": "utf-8",
                "ast_analysis_successful": ast_data.get("analysis_successful", False)
            },
            "nodes": nodes,
            "relationships": ast_data.get("relationships", []),
            "analysis": {
                "metrics": {
                    "line_count": len(lines),
                    "function_count": count("function"),
                    "class_count": count("class"),
                    "import_count": count("import"),
                    "variable_count": count("variable"),
                    "comment_count": count("comment"),
                    "node_count": ast_data.get("node_count", 0),
                    "complexity_score": ast_data.get("complexity_score", 0)
                },
                "dependencies": {
                    "internal_calls": ast_data.get("internal_calls", []),
                    "external_imports": ast_data.get("external_imports", [])
                }
            }
        }

    async def _analyze_with_tree_sitter(
        self,
        content: str,
        file_path: str,
        language_name: str
    ) -> Dict[str, Any]:
        """Parse *content* with the grammar for *language_name* and summarize it.

        The work is synchronous (there is no ``await`` inside). When
        Tree-sitter or the grammar is unavailable, or anything fails, a basic
        analysis with ``analysis_successful`` False is returned.
        """
        try:
            if not TREE_SITTER_AVAILABLE:
                return self._create_basic_analysis(language_name, file_path, "Tree-sitter not available")

            parser = self.parsers.get(language_name)
            if parser is None:
                logger.debug(f"Parser not available for {language_name}")
                return self._create_basic_analysis(language_name, file_path, f"Parser not available for {language_name}")

            tree = parser.parse(content.encode('utf-8'))
            collected = self._collect_ast_data(tree.root_node, content, file_path)
            chunks = collected["chunks"]

            def count(kind: str) -> int:
                """Number of collected chunks of the given type."""
                return sum(1 for c in chunks if c.get("type") == kind)

            return {
                "language": language_name,
                "file_path": file_path,
                "chunks": chunks,
                "relationships": collected["relationships"],
                "internal_calls": collected["internal_calls"],
                "external_imports": collected["external_imports"],
                "function_count": count("function"),
                "class_count": count("class"),
                "import_count": count("import"),
                "variable_count": count("variable"),
                "comment_count": count("comment"),
                "node_count": len(chunks),
                "complexity_score": self._calculate_complexity(tree.root_node),
                "analysis_successful": True
            }

        except Exception as e:
            logger.error(f"Tree-sitter analysis failed for {file_path}: {str(e)}")
            return self._create_basic_analysis(language_name, file_path, f"Analysis error: {str(e)}")

    def _is_definition(self, tokens: Set[str], keywords: frozenset) -> bool:
        """Decide whether a node whose type splits into *tokens* defines a *keywords* construct.

        Requires at least one keyword token and no token from
        ``_NON_DEFINITION_TOKENS``; node types mentioning ``type`` must also
        carry a token from ``_TYPE_DEFINITION_TOKENS``.
        """
        if not tokens & keywords or tokens & self._NON_DEFINITION_TOKENS:
            return False
        return "type" not in tokens or bool(tokens & self._TYPE_DEFINITION_TOKENS)

    def _collect_ast_data(self, root, content: str, file_path: str) -> Dict[str, Any]:
        """Walk the tree under *root* and collect chunks, relationships and dependencies.

        The walk is an iterative pre-order traversal (deep trees cannot hit
        Python's recursion limit) that visits every node exactly once. The
        scope handed to a node's children is the nearest enclosing function
        or class chunk, or the file itself at top level; CONTAINS edges use
        that scope's chunk id (or *file_path*) as ``source_id``.

        Args:
            root: Tree-sitter node to start from (normally ``tree.root_node``).
            content: Source text that was parsed.
            file_path: Prefix of every chunk id and id of the file scope.

        Returns:
            Dict with ``chunks``, ``relationships``, ``internal_calls`` and
            ``external_imports``; the last two are sorted for deterministic output.
        """
        source_lines = self._split_source_lines(content)
        chunks: List[Dict[str, Any]] = []
        relationships: List[Dict[str, Any]] = []
        internal_calls: Set[str] = set()
        external_imports: Set[str] = set()
        seen_chunks: Set[str] = set()

        def claim(chunk_id: str) -> bool:
            """Reserve *chunk_id*; False if a chunk with that id already exists."""
            if chunk_id in seen_chunks:
                return False
            seen_chunks.add(chunk_id)
            return True

        def make_chunk(node, chunk_id, chunk_type, name, properties, snippet=None) -> Dict[str, Any]:
            """Build a chunk dict for *node*; the snippet defaults to its full source rows."""
            start_row, end_row = node.start_point[0], node.end_point[0]
            if snippet is None:
                snippet = self._extract_code_snippet(content, start_row, end_row, source_lines)
            return {
                "id": chunk_id,
                "file_path": file_path,
                "type": chunk_type,
                "name": name,
                "start_line": start_row + 1,
                "end_line": end_row + 1,
                "code_snippet": snippet,
                "properties": properties,
            }

        def link(parent_id: Optional[str], parent_type: Optional[str], child_id: str, child_type: str) -> None:
            """Record a CONTAINS edge from the enclosing scope (or the file) to *child_id*."""
            at_top_level = parent_id is None
            relationships.append({
                "id": f"rel_{len(relationships) + 1}",
                "source_id": file_path if at_top_level else parent_id,
                "target_id": child_id,
                "type": "CONTAINS",
                "properties": {
                    "from_type": "file" if at_top_level else parent_type,
                    "to_type": child_type,
                },
            })

        # Entries: (node, chunk id of enclosing function/class or None, its type)
        stack: List[Tuple[Any, Optional[str], Optional[str]]] = [(root, None, None)]
        while stack:
            node, parent_id, parent_type = stack.pop()
            tokens = set(node.type.split("_"))
            start_line = node.start_point[0] + 1
            child_scope = (parent_id, parent_type)

            if self._is_definition(tokens, self._FUNCTION_TOKENS):
                name = self._extract_name(node, content)
                if name:
                    chunk_id = f"{file_path}:{name}:{start_line}"
                    if claim(chunk_id):
                        chunks.append(make_chunk(node, chunk_id, "function", name, {
                            "parameters": self._extract_parameters(node, content),
                            "return_type": "unknown",
                            "visibility": "public",
                        }))
                        link(parent_id, parent_type, chunk_id, "function")
                        self._process_function_body(node, name, internal_calls, content)
                        # Nested definitions/variables belong to this function.
                        child_scope = (chunk_id, "function")

            elif self._is_definition(tokens, self._CLASS_TOKENS):
                name = self._extract_name(node, content)
                if name:
                    chunk_id = f"{file_path}:{name}:{start_line}"
                    if claim(chunk_id):
                        chunks.append(make_chunk(node, chunk_id, "class", name, {
                            "inheritance": [],
                            "visibility": "public",
                        }))
                        link(parent_id, parent_type, chunk_id, "class")
                        # The body is visited once, with this class as its scope.
                        child_scope = (chunk_id, "class")

            elif tokens & self._IMPORT_TOKENS:
                import_data = self._process_imports(node, content)
                chunk_id = f"{file_path}:import:{start_line}"
                if import_data and claim(chunk_id):
                    module = import_data.get("module")
                    chunks.append(make_chunk(node, chunk_id, "import", module or "import", {
                        "module": module or "",
                        "import_type": import_data.get("import_type", "import"),
                    }))
                    if module:
                        external_imports.add(module)

            elif tokens & self._VARIABLE_TOKENS and parent_id is None:
                # Top-level only: inside a function/class the scope is set.
                name = self._extract_variable_name(node, content)
                if name:
                    chunk_id = f"{file_path}:{name}:{start_line}"
                    if claim(chunk_id):
                        chunks.append(make_chunk(node, chunk_id, "variable", name, {
                            "data_type": "unknown",
                            "scope": "global",
                            "is_constant": any(keyword in node.type for keyword in ["const", "final"]),
                        }))

            elif tokens & self._COMMENT_TOKENS:
                comment_text = self._get_node_text(node, content)
                chunk_id = f"{file_path}:comment:{start_line}"
                if comment_text and claim(chunk_id):
                    chunks.append(make_chunk(node, chunk_id, "comment", "comment", {
                        "content": comment_text,
                    }, snippet=comment_text))

            # Reversed so children pop in source order (recursive pre-order).
            stack.extend((child, *child_scope) for child in reversed(node.named_children))

        return {
            "chunks": chunks,
            "relationships": relationships,
            "internal_calls": sorted(internal_calls),
            "external_imports": sorted(external_imports),
        }

    def _extract_name(self, node, content: str) -> Optional[str]:
        """Best-effort name of a definition node.

        Tries the ``name`` field, then the first named child whose type
        mentions identifier/name/id (under 100 chars), then the first word of
        the node text that starts with a letter (under 50 chars; this can be
        a keyword). Returns None if nothing matches or on any error.
        """
        try:
            name_node = node.child_by_field_name("name")
            if name_node:
                return name_node.text.decode('utf-8')

            for child in node.named_children:
                if any(keyword in child.type for keyword in ["identifier", "name", "id"]):
                    text = child.text.decode('utf-8')
                    if text and len(text) < 100:  # Reasonable identifier length
                        return text

            node_text = node.text.decode('utf-8')
            if node_text:
                for word in node_text.split():
                    if word and word[0].isalpha() and len(word) < 50:
                        return word

            return None
        except Exception:
            return None

    def _extract_variable_name(self, node, content: str) -> Optional[str]:
        """Best-effort variable name of a declaration/assignment node.

        Uses the first named child whose type mentions identifier/name/
        variable; when that child has its own ``name`` field (e.g. a JS
        ``variable_declarator``) the field is used so the result is ``x``
        rather than ``x = 5``. Falls back to the word following a
        let/const/var/val/def keyword in the node text.
        """
        try:
            for child in node.named_children:
                if any(keyword in child.type for keyword in ["identifier", "name", "variable"]):
                    named = child.child_by_field_name("name")
                    target = named if named is not None else child
                    text = target.text.decode('utf-8')
                    if text and len(text) < 100:
                        return text

            words = node.text.decode('utf-8').split()
            declaration_keywords = ["let", "const", "var", "val", "def"]
            for i, word in enumerate(words):
                if word in declaration_keywords and i + 1 < len(words):
                    next_word = words[i + 1]
                    if next_word and next_word[0].isalpha():
                        return next_word.split('=')[0].strip()  # "x=1" -> "x"

            return None
        except Exception:
            return None

    def _extract_parameters(self, node, content: str) -> List[str]:
        """Texts of the named children of every parameter-list-like child of *node*."""
        parameters = []
        try:
            param_containers = ["parameters", "parameter", "argument", "param", "formal_parameters"]

            for child in node.named_children:
                if any(container in child.type for container in param_containers):
                    for param_child in child.named_children:
                        param_text = param_child.text.decode('utf-8')
                        if param_text and param_text.strip() and param_text not in ['(', ')', ',', '...']:
                            parameters.append(param_text.strip())

            return parameters
        except Exception:
            return []

    def _process_function_body(self, node, function_name: str, internal_calls: set, content: str):
        """Add the callee of every call node anywhere inside *node* to *internal_calls*.

        Calls sit inside the body block, not directly under the definition
        node, so all descendants are walked (iteratively). *function_name*
        is unused and kept for interface compatibility. Best-effort: errors
        are logged at debug level and otherwise ignored.
        """
        try:
            stack = list(node.named_children)
            while stack:
                current = stack.pop()
                if "call" in current.type:
                    called_function = self._extract_called_function(current, content)
                    if called_function:
                        internal_calls.add(called_function)
                stack.extend(current.named_children)
        except Exception as e:
            logger.debug(f"Failed to collect calls in {function_name}: {str(e)}")

    def _extract_called_function(self, node, content: str) -> Optional[str]:
        """Callee text of a call node (e.g. ``foo`` or ``obj.method``), or None.

        Prefers the grammar's ``function`` field; otherwise returns the first
        named child whose type mentions identifier/name/function. Results of
        100+ characters are ignored.
        """
        try:
            callee = node.child_by_field_name("function")
            if callee is not None:
                text = callee.text.decode('utf-8')
                if text and len(text) < 100:
                    return text

            for child in node.named_children:
                if any(keyword in child.type for keyword in ["identifier", "name", "function"]):
                    text = child.text.decode('utf-8')
                    if text and len(text) < 100:
                        return text

            return None
        except Exception:
            return None

    def _process_imports(self, node, content: str) -> Optional[Dict[str, Any]]:
        """Summarize an import node as ``statement``/``module``/``import_type``; None on error."""
        try:
            import_text = node.text.decode('utf-8')
            words = import_text.split()
            return {
                "statement": import_text,
                "module": self._extract_module_from_import(import_text),
                "import_type": words[0] if words else "import"
            }
        except Exception:
            return None

    def _extract_module_from_import(self, import_statement: str) -> Optional[str]:
        """Extract the imported module/path from the text of an import statement.

        Patterns are tried in order and the first match wins. Quoted
        ``from '...'`` sources (JS/TS) and the module of a Python
        ``from X import y`` are checked before the generic ``import X`` form,
        so the module, not the imported name, is reported.

        Returns:
            The module string, or None when no pattern matches.
        """
        import re

        patterns = (
            r"\bfrom\s+['\"]([^'\"]+)['\"]",           # JS/TS: import x from 'mod'
            r"\bfrom\s+([\w\.]+)",                      # Python: from pkg.mod import name
            r"\bimport\s+['\"]([^'\"]+)['\"]",          # JS side-effect import / Go single import
            r"\bimport\s+(?:static\s+)?([\w\.]+)",      # Python / Java / Kotlin / Scala
            r"\brequire\s*\(?\s*['\"]([^'\"]+)['\"]",   # require('x') / Ruby require 'x'
            r"\busing\s+([\w\.]+)",                     # C#
            r"\binclude\s*[<'\"]([^>'\"]+)[>'\"]",      # C/C++ #include <x> / "x"
            r"\buse\s+([\w\.]+)",                       # Rust / PHP
        )
        try:
            for pattern in patterns:
                match = re.search(pattern, import_statement)
                if match:
                    return match.group(1)
        except TypeError:  # non-string input: keep the "never raise" contract
            return None
        return None

    @staticmethod
    def _split_source_lines(content: str) -> List[str]:
        """Split *content* into rows the way tree-sitter numbers them.

        Tree-sitter only breaks rows on ``\\n``, whereas ``str.splitlines``
        also breaks on ``\\f``, ``\\v``, ``\\x1c``-``\\x1e``, ``\\x85``,
        ``\\u2028`` and ``\\u2029``, which would shift every later snippet.
        A trailing ``\\r`` (CRLF files) is stripped from each row.
        """
        return [line[:-1] if line.endswith("\r") else line for line in content.split("\n")]

    def _extract_code_snippet(
        self,
        content: str,
        start_line: int,
        end_line: int,
        lines: Optional[List[str]] = None
    ) -> str:
        """Return rows ``start_line``..``end_line`` (0-based, inclusive) of *content*.

        Args:
            content: Full source text.
            start_line, end_line: 0-based tree-sitter row numbers.
            lines: Optional precomputed ``_split_source_lines(content)``, so
                callers extracting many snippets split the file only once.
        """
        try:
            source_lines = lines if lines is not None else self._split_source_lines(content)
            return "\n".join(source_lines[max(0, start_line):end_line + 1])
        except Exception:
            return ""

    def _calculate_complexity(self, node) -> int:
        """Count control-flow nodes (``_COMPLEXITY_TOKENS``) in the subtree under *node*.

        Walks named nodes iteratively, so deep trees cannot exhaust the
        recursion limit.
        """
        complexity = 0
        stack = [node]
        while stack:
            current = stack.pop()
            if self._COMPLEXITY_TOKENS & set(current.type.split("_")):
                complexity += 1
            stack.extend(current.named_children)
        return complexity

    def _get_node_text(self, node, content: str) -> Optional[str]:
        """Source text of *node*, decoded as UTF-8 (invalid bytes dropped); None on error.

        Tree-sitter offsets are byte offsets, so the node's own bytes are
        used. If the node carries no text, the byte range is sliced from the
        UTF-8 encoding of *content* instead of indexing the ``str``.
        """
        try:
            raw = getattr(node, "text", None)
            if raw is None:
                raw = content.encode('utf-8')[node.start_byte:node.end_byte]
            return raw.decode('utf-8', errors='ignore')
        except Exception:
            return None

    def _generate_content_hash(self, content: str) -> str:
        """Hex SHA-256 of the UTF-8 encoding of *content*."""
        return hashlib.sha256(content.encode('utf-8')).hexdigest()

    @staticmethod
    def _utc_now_iso() -> str:
        """Current UTC time as a naive ISO-8601 string.

        Same format as the deprecated ``datetime.utcnow().isoformat()`` (no
        ``+00:00`` suffix), so consumers of ``processed_at`` are unaffected.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    def _create_empty_unified_json(
        self,
        user_id: str,
        repo_id: str,
        branch_id: str,
        file_id: str,
        file_path: str,
        language: str
    ) -> Dict[str, Any]:
        """Unified JSON document with empty content, zero metrics and a failed-analysis flag."""
        return {
            "hierarchy": {
                "user_id": user_id,
                "repo_id": repo_id,
                "branch_id": branch_id,
                "file_id": file_id
            },
            "metadata": {
                "file_path": file_path,
                "language": language,
                "content_hash": "",
                "processed_at": self._utc_now_iso(),
                "version": "1.0",
                "file_size": 0,
                "encoding": "utf-8",
                "ast_analysis_successful": False
            },
            "nodes": [],
            "relationships": [],
            "analysis": {
                "metrics": {
                    "line_count": 0,
                    "function_count": 0,
                    "class_count": 0,
                    "import_count": 0,
                    "variable_count": 0,
                    "comment_count": 0,
                    "node_count": 0,
                    "complexity_score": 0
                },
                "dependencies": {
                    "internal_calls": [],
                    "external_imports": []
                }
            }
        }

    def _create_basic_analysis(self, language: str, file_path: str, error_msg: str = "") -> Dict[str, Any]:
        """Analysis dict with no chunks, zero counts, ``analysis_successful`` False and *error_msg*."""
        return {
            "language": language,
            "file_path": file_path,
            "chunks": [],
            "relationships": [],
            "internal_calls": [],
            "external_imports": [],
            "function_count": 0,
            "class_count": 0,
            "import_count": 0,
            "variable_count": 0,
            "comment_count": 0,
            "node_count": 0,
            "complexity_score": 0,
            "analysis_successful": False,
            "error": error_msg
        }



# Shared processor instance, constructed once when this module is first
# imported (ASTProcessor.__init__ loads the available parsers eagerly).
ast_processor: ASTProcessor = ASTProcessor()