#!/usr/bin/env python3
"""
Unit tests for Orchestration Core

Test coverage:
- Project creation and lifecycle
- State transitions and validation
- Result collection and validation
- Evaluation and scoring
- Iteration management
- Serving registration
- Registry locking and concurrency

Run with: python -m pytest test_orchestrator.py -v
Or: python test_orchestrator.py (fallback to unittest)
"""

import unittest
import json
import tempfile
import shutil
from pathlib import Path
from datetime import datetime
from unittest.mock import patch, MagicMock

# Import orchestrator module
import sys
sys.path.insert(0, str(Path(__file__).parent))
import orchestrator


class TestOrchestratorCore(unittest.TestCase):
    """Test suite for orchestrator core functionality.

    Each test runs against a throwaway temp directory: setUp patches the
    orchestrator module's path globals to point into it, and tearDown
    restores every patched global before removing the directory.
    """

    # Module-level globals on `orchestrator` that setUp overrides.
    # Kept in one place so tearDown restores exactly what setUp patched.
    _PATCHED_ATTRS = (
        "WORKSPACE_ROOT",
        "REGISTRY_PATH",
        "PROJECTS_DIR",
        "TEMPLATES_DIR",
        "SCHEMAS_DIR",
        "SERVING_DIR",
    )

    def setUp(self):
        """Set up test environment with temporary directories."""
        self.test_dir = Path(tempfile.mkdtemp())

        # Save ALL globals we are about to patch. (Previously only
        # WORKSPACE_ROOT and REGISTRY_PATH were saved/restored, so the
        # remaining four leaked a deleted temp path into later tests.)
        self._originals = {
            name: getattr(orchestrator, name) for name in self._PATCHED_ATTRS
        }

        # Override paths to use the test directory
        orchestrator.WORKSPACE_ROOT = self.test_dir
        orchestrator.REGISTRY_PATH = self.test_dir / "registry.json"
        orchestrator.PROJECTS_DIR = self.test_dir / "projects"
        orchestrator.TEMPLATES_DIR = self.test_dir / "templates"
        orchestrator.SCHEMAS_DIR = self.test_dir / "schemas"
        orchestrator.SERVING_DIR = self.test_dir / "serving"

        # Create template directory and a basic project template
        orchestrator.TEMPLATES_DIR.mkdir(parents=True, exist_ok=True)
        template_content = """# {{PROJECT_NAME}}

**ID:** {{PROJECT_ID}}
**Version:** {{VERSION}}
**Created:** {{CREATED_DATE}}

## Description
{{PROJECT_DESCRIPTION}}

## Status
Current status: {{STATUS}}
"""
        with open(orchestrator.TEMPLATES_DIR / "PROJECT.md.tmpl", 'w') as f:
            f.write(template_content)

    def tearDown(self):
        """Restore every patched orchestrator global and remove the temp dir."""
        for name, value in self._originals.items():
            setattr(orchestrator, name, value)

        if self.test_dir.exists():
            shutil.rmtree(self.test_dir)

    def test_state_transition_validation(self):
        """Test state transition validation logic"""
        # Valid transitions
        self.assertTrue(orchestrator.validate_state_transition("pending", "building"))
        self.assertTrue(orchestrator.validate_state_transition("building", "evaluating"))
        self.assertTrue(orchestrator.validate_state_transition("evaluating", "passed"))
        self.assertTrue(orchestrator.validate_state_transition("evaluating", "failed"))
        self.assertTrue(orchestrator.validate_state_transition("failed", "building"))  # Retry

        # Invalid transitions
        self.assertFalse(orchestrator.validate_state_transition("pending", "passed"))
        self.assertFalse(orchestrator.validate_state_transition("building", "passed"))
        self.assertFalse(orchestrator.validate_state_transition("passed", "building"))
        self.assertFalse(orchestrator.validate_state_transition("passed", "failed"))

    def test_project_id_generation(self):
        """Test unique project ID generation"""
        id1 = orchestrator.generate_project_id("Dark Tactics")
        id2 = orchestrator.generate_project_id("Dark Tactics")

        # IDs should start with normalized name
        self.assertTrue(id1.startswith("dark-tactics-"))
        self.assertTrue(id2.startswith("dark-tactics-"))

        # IDs should be unique even for identical names
        self.assertNotEqual(id1, id2)

    def test_create_project_minimal(self):
        """Test creating a minimal project"""
        project = orchestrator.create_project(
            name="Test Project",
            description="A test project"
        )

        # Verify project metadata
        self.assertEqual(project["name"], "Test Project")
        self.assertEqual(project["description"], "A test project")
        self.assertEqual(project["status"], "pending")
        self.assertEqual(project["iterations"], 0)
        self.assertEqual(project["maxIterations"], 3)
        self.assertIsNotNone(project["id"])

        # Verify directory structure created
        project_dir = orchestrator.PROJECTS_DIR / project["id"]
        self.assertTrue(project_dir.exists())
        self.assertTrue((project_dir / "output").exists())
        self.assertTrue((project_dir / "iterations").exists())
        self.assertTrue((project_dir / "spec.json").exists())
        self.assertTrue((project_dir / "PROJECT.md").exists())

        # Verify project in registry
        with orchestrator.locked_registry('r') as registry:
            self.assertEqual(len(registry["projects"]), 1)
            self.assertEqual(registry["projects"][0]["id"], project["id"])

    def test_create_project_with_configs(self):
        """Test creating project with module configurations"""
        # Create test config files
        battle_config = self.test_dir / "battle.json"
        scenario_config = self.test_dir / "scenario.json"

        with open(battle_config, 'w') as f:
            json.dump({"gridCols": 10, "gridRows": 8}, f)

        with open(scenario_config, 'w') as f:
            json.dump({"startStage": "ch1-01"}, f)

        project = orchestrator.create_project(
            name="Full Project",
            battle_config=battle_config,
            scenario_config=scenario_config
        )

        # Verify modules loaded from the config files
        self.assertIn("battleEngine", project["modules"])
        self.assertIn("scenarioData", project["modules"])
        self.assertEqual(project["modules"]["battleEngine"]["gridCols"], 10)
        self.assertEqual(project["modules"]["scenarioData"]["startStage"], "ch1-01")

    def test_get_project(self):
        """Test retrieving project by name or ID"""
        project = orchestrator.create_project(name="Findable Project")

        # Find by ID
        found_by_id = orchestrator.get_project(project["id"])
        self.assertEqual(found_by_id["id"], project["id"])

        # Find by name
        found_by_name = orchestrator.get_project("Findable Project")
        self.assertEqual(found_by_name["name"], "Findable Project")

        # Not found
        with self.assertRaises(orchestrator.ProjectNotFoundError):
            orchestrator.get_project("Nonexistent Project")

    def test_list_projects(self):
        """Test listing projects with filters"""
        # Create multiple projects in different states
        p1 = orchestrator.create_project(name="Project 1")
        p2 = orchestrator.create_project(name="Project 2")
        orchestrator.update_status(p1["id"], "building")
        orchestrator.update_status(p2["id"], "building")
        orchestrator.update_status(p2["id"], "evaluating")

        # List all
        all_projects = orchestrator.list_projects()
        self.assertEqual(len(all_projects), 2)

        # Filter by status
        building = orchestrator.list_projects(status_filter="building")
        self.assertEqual(len(building), 1)
        self.assertEqual(building[0]["id"], p1["id"])

        evaluating = orchestrator.list_projects(status_filter="evaluating")
        self.assertEqual(len(evaluating), 1)
        self.assertEqual(evaluating[0]["id"], p2["id"])

    def test_update_status_valid(self):
        """Test valid status transitions"""
        project = orchestrator.create_project(name="Status Test")

        # pending → building
        updated = orchestrator.update_status(project["id"], "building")
        self.assertEqual(updated["status"], "building")

        # building → evaluating
        updated = orchestrator.update_status(project["id"], "evaluating")
        self.assertEqual(updated["status"], "evaluating")

        # evaluating → passed
        updated = orchestrator.update_status(project["id"], "passed")
        self.assertEqual(updated["status"], "passed")

    def test_update_status_invalid(self):
        """Test invalid status transitions raise errors"""
        project = orchestrator.create_project(name="Invalid Status Test")

        # pending → passed (invalid)
        with self.assertRaises(orchestrator.InvalidStateTransitionError):
            orchestrator.update_status(project["id"], "passed")

    def test_update_status_rollback_disabled(self):
        """Test that rollback can be disabled for forced transitions"""
        project = orchestrator.create_project(name="Rollback Test")

        # Force invalid transition without error
        updated = orchestrator.update_status(project["id"], "passed", rollback_on_error=False)
        self.assertEqual(updated["status"], "passed")

    def test_collect_result(self):
        """Test collecting subagent results"""
        project = orchestrator.create_project(name="Collection Test")

        # Create fake output directory
        output_dir = self.test_dir / "fake_output"
        output_dir.mkdir()

        # Create OUTPUT.json
        output_data = {
            "status": "success",
            "timestamp": datetime.now().isoformat(),
            "data": {"result": "test"}
        }
        with open(output_dir / "OUTPUT.json", 'w') as f:
            json.dump(output_data, f)

        # Create additional files
        with open(output_dir / "result.txt", 'w') as f:
            f.write("Test result")

        # Collect results
        summary = orchestrator.collect_result(project["id"], output_dir)

        # Verify collection summary
        self.assertEqual(summary["project"], project["id"])
        self.assertGreater(summary["collected"], 0)

        # Verify files copied into the project's output directory
        project_output = Path(project["outputPath"])
        self.assertTrue((project_output / "result.txt").exists())
        self.assertTrue(any(f.name.startswith("OUTPUT_") for f in project_output.glob("OUTPUT_*.json")))

    def test_collect_result_validation_error(self):
        """Test OUTPUT.json validation catches missing fields"""
        project = orchestrator.create_project(name="Validation Test")

        output_dir = self.test_dir / "invalid_output"
        output_dir.mkdir()

        # Create invalid OUTPUT.json (missing required fields)
        with open(output_dir / "OUTPUT.json", 'w') as f:
            json.dump({"incomplete": "data"}, f)

        # Should raise validation error
        with self.assertRaises(orchestrator.ValidationError):
            orchestrator.collect_result(project["id"], output_dir, validate_output=True)

    def test_evaluate_project_pass(self):
        """Test project evaluation that passes"""
        project = orchestrator.create_project(name="Eval Pass Test")

        # Create some output to pass evaluation
        output_dir = Path(project["outputPath"])
        with open(output_dir / "output.json", 'w') as f:
            json.dump({"data": "test"}, f)

        # Evaluate
        result = orchestrator.evaluate_project(project["id"])

        # Verify results
        self.assertTrue(result["passed"])
        self.assertEqual(result["overallScore"], 100.0)

        # Verify status updated
        updated_project = orchestrator.get_project(project["id"])
        self.assertEqual(updated_project["status"], "passed")

    def test_evaluate_project_fail(self):
        """Test project evaluation that fails"""
        project = orchestrator.create_project(name="Eval Fail Test")

        # No output files - should fail
        result = orchestrator.evaluate_project(project["id"])

        # Verify results
        self.assertFalse(result["passed"])
        self.assertGreater(len(result["issues"]), 0)

        # Verify status updated
        updated_project = orchestrator.get_project(project["id"])
        self.assertEqual(updated_project["status"], "failed")

    def test_create_iteration(self):
        """Test creating iteration with feedback"""
        project = orchestrator.create_project(name="Iteration Test")

        # Create first iteration
        iteration = orchestrator.create_iteration(project["id"], "Fix balance issues")

        self.assertEqual(iteration["iteration"], 1)
        self.assertEqual(iteration["maxIterations"], 3)
        self.assertFalse(iteration["escalation"])
        self.assertEqual(iteration["feedback"], "Fix balance issues")

        # Verify iteration file created
        iteration_file = Path(iteration["iterationFile"])
        self.assertTrue(iteration_file.exists())

        # Verify content
        with open(iteration_file) as f:
            content = f.read()
            self.assertIn("Fix balance issues", content)
            self.assertIn("Iteration 1/3", content)

    def test_create_iteration_escalation(self):
        """Test iteration escalation when max exceeded"""
        project = orchestrator.create_project(name="Escalation Test")

        # Create iterations up to the max; track the last one explicitly
        # instead of relying on the loop variable after the loop ends.
        last_iteration = None
        for i in range(3):
            last_iteration = orchestrator.create_iteration(project["id"], f"Feedback {i+1}")

        # Last iteration should trigger escalation
        self.assertIsNotNone(last_iteration)
        self.assertTrue(last_iteration["escalation"])

        # Verify escalation marker written to the iteration file
        with open(last_iteration["iterationFile"]) as f:
            content = f.read()
            self.assertIn("ESCALATION REQUIRED", content)

    def test_register_serving_success(self):
        """Test registering passed project for serving"""
        project = orchestrator.create_project(name="Serving Test")

        # Create output and mark as passed
        output_dir = Path(project["outputPath"])
        with open(output_dir / "output.json", 'w') as f:
            json.dump({"data": "test"}, f)

        orchestrator.update_status(project["id"], "building")
        orchestrator.update_status(project["id"], "evaluating")
        orchestrator.update_status(project["id"], "passed")

        # Register for serving
        serving = orchestrator.register_serving(project["id"])

        # Verify serving link created
        serving_path = Path(serving["servingPath"])
        self.assertTrue(serving_path.exists())
        self.assertTrue(serving_path.is_symlink())

        # Verify link points to output directory
        self.assertEqual(serving_path.resolve(), output_dir.resolve())

    def test_register_serving_not_passed(self):
        """Test that only passed projects can be served"""
        project = orchestrator.create_project(name="Not Passed Test")

        # Try to serve pending project
        with self.assertRaises(orchestrator.OrchestratorError):
            orchestrator.register_serving(project["id"])

    def test_registry_locking(self):
        """Test registry file locking mechanism"""
        # Create initial project
        orchestrator.create_project(name="Lock Test 1")

        # Test that registry can be read while locked
        with orchestrator.locked_registry('r') as registry1:
            projects_count = len(registry1["projects"])

            # Nested read lock should work
            with orchestrator.locked_registry('r') as registry2:
                self.assertEqual(len(registry2["projects"]), projects_count)

    def test_template_rendering(self):
        """Test simple template variable substitution"""
        template_path = orchestrator.TEMPLATES_DIR / "test.tmpl"
        with open(template_path, 'w') as f:
            f.write("Hello {{NAME}}, your score is {{SCORE}}!")

        rendered = orchestrator.render_template(
            template_path,
            {"NAME": "Alice", "SCORE": "100"}
        )

        self.assertEqual(rendered, "Hello Alice, your score is 100!")

    def test_concurrent_project_creation(self):
        """Test that concurrent project creation doesn't corrupt registry"""
        import threading

        def create_test_project(index):
            orchestrator.create_project(name=f"Concurrent Project {index}")

        # Create 5 projects concurrently
        threads = []
        for i in range(5):
            t = threading.Thread(target=create_test_project, args=(i,))
            threads.append(t)
            t.start()

        # Wait for all threads
        for t in threads:
            t.join()

        # Verify all projects created (registry not corrupted by races)
        projects = orchestrator.list_projects()
        self.assertEqual(len(projects), 5)

    def test_full_project_lifecycle(self):
        """Integration test: full project lifecycle from creation to serving"""
        # 1. Create project
        project = orchestrator.create_project(
            name="Lifecycle Test",
            description="Full lifecycle test project"
        )
        self.assertEqual(project["status"], "pending")

        # 2. Start building
        orchestrator.update_status(project["id"], "building")

        # 3. Collect results
        output_dir = self.test_dir / "build_output"
        output_dir.mkdir()
        with open(output_dir / "OUTPUT.json", 'w') as f:
            json.dump({"status": "success", "timestamp": datetime.now().isoformat()}, f)
        orchestrator.collect_result(project["id"], output_dir)

        # 4. Evaluate
        result = orchestrator.evaluate_project(project["id"])
        self.assertTrue(result["passed"])

        # 5. Register for serving
        serving = orchestrator.register_serving(project["id"])
        self.assertIsNotNone(serving["servingPath"])

        # Verify final state
        final_project = orchestrator.get_project(project["id"])
        self.assertEqual(final_project["status"], "passed")
        self.assertIsNotNone(final_project["servingPath"])


class TestCLI(unittest.TestCase):
    """Test CLI interface.

    setUp patches the orchestrator path globals into a temp directory;
    tearDown restores every patched global before deleting it.
    """

    # Module-level globals on `orchestrator` that setUp overrides.
    _PATCHED_ATTRS = (
        "WORKSPACE_ROOT",
        "REGISTRY_PATH",
        "PROJECTS_DIR",
        "TEMPLATES_DIR",
        "SERVING_DIR",
    )

    def setUp(self):
        """Set up test environment."""
        self.test_dir = Path(tempfile.mkdtemp())

        # Save ALL patched globals, not just WORKSPACE_ROOT — previously
        # REGISTRY_PATH/PROJECTS_DIR/TEMPLATES_DIR/SERVING_DIR were never
        # restored, leaking a deleted temp path into later test runs.
        self._originals = {
            name: getattr(orchestrator, name) for name in self._PATCHED_ATTRS
        }

        orchestrator.WORKSPACE_ROOT = self.test_dir
        orchestrator.REGISTRY_PATH = self.test_dir / "registry.json"
        orchestrator.PROJECTS_DIR = self.test_dir / "projects"
        orchestrator.TEMPLATES_DIR = self.test_dir / "templates"
        orchestrator.SERVING_DIR = self.test_dir / "serving"

        # Create minimal project template
        orchestrator.TEMPLATES_DIR.mkdir(parents=True, exist_ok=True)
        with open(orchestrator.TEMPLATES_DIR / "PROJECT.md.tmpl", 'w') as f:
            f.write("# {{PROJECT_NAME}}")

    def tearDown(self):
        """Restore every patched orchestrator global and remove the temp dir."""
        for name, value in self._originals.items():
            setattr(orchestrator, name, value)
        if self.test_dir.exists():
            shutil.rmtree(self.test_dir)

    @patch('sys.argv', ['orchestrator.py', 'create', '--name', 'CLI Test'])
    def test_cli_create(self):
        """Test CLI create command"""
        result = orchestrator.main()
        self.assertEqual(result, 0)

        # Verify project created
        projects = orchestrator.list_projects()
        self.assertEqual(len(projects), 1)
        self.assertEqual(projects[0]["name"], "CLI Test")

    @patch('sys.argv', ['orchestrator.py', 'list'])
    def test_cli_list(self):
        """Test CLI list command"""
        orchestrator.create_project(name="List Test 1")
        orchestrator.create_project(name="List Test 2")

        result = orchestrator.main()
        self.assertEqual(result, 0)


def run_tests():
    """Run the full suite, preferring pytest with a unittest fallback.

    Returns a process exit code: 0 when everything passed, non-zero
    otherwise (pytest's own exit code when pytest is available).
    """
    try:
        import pytest
    except ImportError:
        pytest = None

    if pytest is not None:
        return pytest.main([__file__, '-v'])

    # pytest missing: drive the same tests through stdlib unittest.
    print("pytest not available, using unittest")
    suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
    outcome = unittest.TextTestRunner(verbosity=2).run(suite)
    return 0 if outcome.wasSuccessful() else 1


# Script entry point: propagate the suite result as the process exit code.
if __name__ == '__main__':
    sys.exit(run_tests())
