#!/usr/bin/env python3
"""
Test script to verify that the optimized shot and asset routers properly
support custom task statuses.

This test ensures that:
1. Custom task statuses are included in aggregated data
2. Both default and custom statuses appear in task_status and task_details
3. Custom status sorting works correctly
4. Custom status filtering works correctly

Run as a script: the process exits with status 0 when every check passes and
1 otherwise.
"""

import sys
import os

# Put this script's directory on sys.path so the local `database`, `models`
# and `routers` packages imported below can be resolved.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import asyncio
import json
import traceback

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from database import Base
from models.user import User, UserRole
from models.project import Project, ProjectType, ProjectStatus, ProjectMember
from models.task import Task
from models.asset import Asset, AssetCategory, AssetStatus
from models.episode import Episode, EpisodeStatus
from models.shot import Shot, ShotStatus

# Statuses these checks classify as built-in defaults; any status starting
# with CUSTOM_STATUS_PREFIX is classified as a project-defined custom status.
DEFAULT_STATUSES = frozenset(
    {"not_started", "in_progress", "submitted", "approved", "retake"}
)
CUSTOM_STATUS_PREFIX = "custom_"

# Create test database (a fresh SQLite file; note this runs at import time).
TEST_DB = "test_custom_status_optimization.db"
if os.path.exists(TEST_DB):
    os.remove(TEST_DB)

engine = create_engine(f"sqlite:///{TEST_DB}")
Base.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine)


def setup_test_data():
    """Populate the test database with a project that defines custom task statuses.

    Creates one coordinator user (added as a project member), a project whose
    ``custom_task_statuses`` holds three custom statuses, one episode with
    three shots, and three assets. Each shot and asset gets tasks whose
    statuses mix default values (``not_started``, ``approved``, ...) with the
    project's ``custom_*`` statuses.

    Returns:
        tuple: ``(user, project, episode)``. They are refreshed before the
        session is closed so their column values remain readable afterwards.

    Raises:
        Exception: any error during insertion is printed, the transaction is
        rolled back, and the error is re-raised.
    """
    print("Setting up test data...")

    db = SessionLocal()
    try:
        # Create test user
        user = User(
            email="coordinator@test.com",
            password_hash="test",
            first_name="Test",
            last_name="Coordinator",
            role=UserRole.COORDINATOR,
            is_approved=True,
        )
        db.add(user)
        db.commit()

        # Create project with custom task statuses
        custom_statuses = [
            {
                "id": "custom_review",
                "name": "Custom Review",
                "color": "#FF6B35",
                "order": 0,
                "is_default": False,
            },
            {
                "id": "custom_approved",
                "name": "Custom Approved",
                "color": "#4CAF50",
                "order": 1,
                "is_default": False,
            },
            {
                "id": "custom_blocked",
                "name": "Custom Blocked",
                "color": "#F44336",
                "order": 2,
                "is_default": False,
            },
        ]

        project = Project(
            name="Custom Status Test Project",
            code_name="CSTP",
            client_name="Test Client",
            project_type=ProjectType.CINEMA,
            status=ProjectStatus.IN_PROGRESS,
            custom_task_statuses=json.dumps(custom_statuses),
        )
        db.add(project)
        db.commit()

        # Add user as project member
        member = ProjectMember(user_id=user.id, project_id=project.id)
        db.add(member)
        db.commit()

        # Create episode
        episode = Episode(
            project_id=project.id,
            name="Test Episode",
            episode_number=1,
            status=EpisodeStatus.IN_PROGRESS,
        )
        db.add(episode)
        db.commit()

        # Create shots with tasks using both default and custom statuses
        shots_data = [
            {"name": "shot_001", "tasks": [
                {"type": "layout", "status": "not_started"},
                {"type": "animation", "status": "custom_review"},
                {"type": "lighting", "status": "custom_approved"},
            ]},
            {"name": "shot_002", "tasks": [
                {"type": "layout", "status": "in_progress"},
                {"type": "animation", "status": "custom_blocked"},
                {"type": "lighting", "status": "submitted"},
            ]},
            {"name": "shot_003", "tasks": [
                {"type": "layout", "status": "custom_approved"},
                {"type": "animation", "status": "approved"},
                {"type": "lighting", "status": "custom_review"},
            ]},
        ]

        for shot_data in shots_data:
            shot = Shot(
                project_id=project.id,
                episode_id=episode.id,
                name=shot_data["name"],
                frame_start=1001,
                frame_end=1100,
                status=ShotStatus.IN_PROGRESS,
            )
            db.add(shot)
            db.flush()  # assigns shot.id, needed by the tasks below

            # Create tasks with mixed default and custom statuses
            for task_data in shot_data["tasks"]:
                db.add(Task(
                    project_id=project.id,
                    episode_id=episode.id,
                    shot_id=shot.id,
                    name=f"{shot.name}_{task_data['type']}",
                    task_type=task_data["type"],
                    status=task_data["status"],
                ))

        # Create assets with tasks using both default and custom statuses
        assets_data = [
            {"name": "character_hero", "category": AssetCategory.CHARACTERS, "tasks": [
                {"type": "modeling", "status": "custom_review"},
                {"type": "surfacing", "status": "not_started"},
                {"type": "rigging", "status": "custom_approved"},
            ]},
            {"name": "prop_sword", "category": AssetCategory.PROPS, "tasks": [
                {"type": "modeling", "status": "approved"},
                {"type": "surfacing", "status": "custom_blocked"},
            ]},
            {"name": "set_castle", "category": AssetCategory.SETS, "tasks": [
                {"type": "modeling", "status": "custom_approved"},
                {"type": "surfacing", "status": "in_progress"},
            ]},
        ]

        for asset_data in assets_data:
            asset = Asset(
                project_id=project.id,
                name=asset_data["name"],
                category=asset_data["category"],
                status=AssetStatus.IN_PROGRESS,
            )
            db.add(asset)
            db.flush()  # assigns asset.id, needed by the tasks below

            # Create tasks with mixed default and custom statuses
            for task_data in asset_data["tasks"]:
                db.add(Task(
                    project_id=project.id,
                    asset_id=asset.id,
                    name=f"{asset.name}_{task_data['type']}",
                    task_type=task_data["type"],
                    status=task_data["status"],
                ))

        db.commit()

        # commit() expires loaded attributes (sessionmaker default) and close()
        # below detaches the objects; reload them so callers can read them.
        for obj in (user, project, episode):
            db.refresh(obj)

        print("✅ Test data setup complete")
        return user, project, episode

    except Exception as e:
        db.rollback()
        print(f"❌ Failed to setup test data: {e}")
        raise
    finally:
        db.close()


def _list_shots(db, user, project, episode, *, task_status_filter=None, sort_by=None):
    """Await ``routers.shots.list_shots`` for the first 100 shots (ascending).

    ``db`` and ``current_user`` are passed explicitly rather than injected.
    The import is local so an import failure is reported by the calling test.
    """
    from routers.shots import list_shots

    return asyncio.run(list_shots(
        episode_id=episode.id,
        project_id=project.id,
        task_status_filter=task_status_filter,
        sort_by=sort_by,
        sort_direction="asc",
        skip=0,
        limit=100,
        db=db,
        current_user=user,
    ))


def _list_assets(db, user, project):
    """Await ``routers.assets.list_assets`` for the first 100 assets, unfiltered."""
    from routers.assets import list_assets

    return asyncio.run(list_assets(
        project_id=project.id,
        category=None,
        task_status_filter=None,
        sort_by=None,
        sort_direction="asc",
        skip=0,
        limit=100,
        db=db,
        current_user=user,
    ))


def _collect_status_kinds(items, label):
    """Print a summary of each item and split the statuses it reports.

    Statuses are read from both ``item.task_status`` (its values) and the
    ``status`` attribute of every entry in ``item.task_details``.

    Args:
        items: shots or assets returned by a list endpoint.
        label: prefix for the per-item printout (e.g. ``"Shot"``).

    Returns:
        tuple[set, set]: ``(custom_found, default_found)`` - statuses starting
        with ``custom_`` and statuses in DEFAULT_STATUSES. Others are ignored.
    """
    custom_found = set()
    default_found = set()

    for item in items:
        print(f"\n{label}: {item.name}")
        print(f" Task Status: {item.task_status}")
        print(f" Task Details: {len(item.task_details)} tasks")

        statuses = list(item.task_status.values())
        statuses.extend(detail.status for detail in item.task_details)
        for status in statuses:
            if status.startswith(CUSTOM_STATUS_PREFIX):
                custom_found.add(status)
            elif status in DEFAULT_STATUSES:
                default_found.add(status)

    return custom_found, default_found


def _check_status_mix(custom_found, default_found, kind):
    """Print the collected statuses and assert both kinds are present.

    Raises:
        AssertionError: if no custom or no default status was found, or none
        of the seeded custom statuses appears.
    """
    print(f"\n✅ Custom statuses found: {custom_found}")
    print(f"✅ Default statuses found: {default_found}")

    # Verify we found both types
    assert len(custom_found) > 0, f"No custom statuses found in {kind} aggregation"
    assert len(default_found) > 0, f"No default statuses found in {kind} aggregation"

    # Verify specific custom statuses are present
    expected_custom = {'custom_review', 'custom_approved', 'custom_blocked'}
    found_custom = custom_found & expected_custom
    assert len(found_custom) > 0, f"Expected custom statuses {expected_custom} not found"


def test_shot_custom_status_aggregation():
    """Test that shot list includes custom task statuses in aggregated data."""
    print("\n=== Test 1: Shot Custom Status Aggregation ===")

    db = SessionLocal()
    try:
        project = db.query(Project).first()
        episode = db.query(Episode).first()
        user = db.query(User).first()

        shots = _list_shots(db, user, project, episode)
        print(f"✅ Retrieved {len(shots)} shots")

        custom_found, default_found = _collect_status_kinds(shots, "Shot")
        _check_status_mix(custom_found, default_found, "shot")

        print("✅ Test 1 PASSED: Shot custom status aggregation works correctly")
        return True

    except Exception as e:
        print(f"❌ Test 1 FAILED: {e}")
        traceback.print_exc()
        return False
    finally:
        db.close()


def test_asset_custom_status_aggregation():
    """Test that asset list includes custom task statuses in aggregated data."""
    print("\n=== Test 2: Asset Custom Status Aggregation ===")

    db = SessionLocal()
    try:
        project = db.query(Project).first()
        user = db.query(User).first()

        assets = _list_assets(db, user, project)
        print(f"✅ Retrieved {len(assets)} assets")

        custom_found, default_found = _collect_status_kinds(assets, "Asset")
        _check_status_mix(custom_found, default_found, "asset")

        print("✅ Test 2 PASSED: Asset custom status aggregation works correctly")
        return True

    except Exception as e:
        print(f"❌ Test 2 FAILED: {e}")
        traceback.print_exc()
        return False
    finally:
        db.close()


def test_custom_status_filtering():
    """Test that custom status filtering works in optimized queries."""
    print("\n=== Test 3: Custom Status Filtering ===")

    db = SessionLocal()
    try:
        project = db.query(Project).first()
        episode = db.query(Episode).first()
        user = db.query(User).first()

        # Test filtering by custom status
        filtered_shots = _list_shots(
            db, user, project, episode,
            task_status_filter="animation:custom_review",
        )
        print(f"✅ Retrieved {len(filtered_shots)} shots with custom_review animation status")

        # shot_001 is seeded with animation=custom_review, so an empty result
        # means the filter dropped a matching shot (and the loop below would
        # otherwise pass vacuously).
        assert filtered_shots, "Filter 'animation:custom_review' returned no shots"
        for shot in filtered_shots:
            animation_status = shot.task_status.get('animation')
            assert animation_status == 'custom_review', f"Expected 'custom_review', got '{animation_status}' for shot {shot.name}"
            print(f" Shot {shot.name}: animation = {animation_status}")

        # Test filtering by default status
        default_filtered_shots = _list_shots(
            db, user, project, episode,
            task_status_filter="layout:not_started",
        )
        print(f"✅ Retrieved {len(default_filtered_shots)} shots with not_started layout status")

        # shot_001 is seeded with layout=not_started.
        assert default_filtered_shots, "Filter 'layout:not_started' returned no shots"
        for shot in default_filtered_shots:
            layout_status = shot.task_status.get('layout')
            assert layout_status == 'not_started', f"Expected 'not_started', got '{layout_status}' for shot {shot.name}"
            print(f" Shot {shot.name}: layout = {layout_status}")

        print("✅ Test 3 PASSED: Custom status filtering works correctly")
        return True

    except Exception as e:
        print(f"❌ Test 3 FAILED: {e}")
        traceback.print_exc()
        return False
    finally:
        db.close()


def test_custom_status_sorting():
    """Test that custom status sorting works in optimized queries."""
    print("\n=== Test 4: Custom Status Sorting ===")

    db = SessionLocal()
    try:
        project = db.query(Project).first()
        episode = db.query(Episode).first()
        user = db.query(User).first()

        from routers.shots import get_status_sort_order, get_project_custom_statuses

        # Test sorting by animation task status
        sorted_shots = _list_shots(db, user, project, episode, sort_by="animation_status")
        print(f"✅ Retrieved {len(sorted_shots)} shots sorted by animation status")
        assert sorted_shots, "Sorting by animation_status returned no shots"

        # The project's custom statuses are fixed for this check; look them
        # up once instead of once per shot.
        custom_statuses = get_project_custom_statuses(project.id, db)

        # Verify ascending order (should handle both default and custom statuses).
        # None means "no previous shot yet", so no assumption about the
        # minimum order value is needed.
        previous_order = None
        for shot in sorted_shots:
            animation_status = shot.task_status.get('animation', 'not_started')
            current_order = get_status_sort_order(animation_status, custom_statuses)
            print(f" Shot {shot.name}: animation = {animation_status} (order: {current_order})")

            if previous_order is not None:
                assert current_order >= previous_order, f"Sorting order violated: {current_order} < {previous_order}"
            previous_order = current_order

        print("✅ Test 4 PASSED: Custom status sorting works correctly")
        return True

    except Exception as e:
        print(f"❌ Test 4 FAILED: {e}")
        traceback.print_exc()
        return False
    finally:
        db.close()


def main():
    """Run all tests.

    Returns:
        bool: True when data setup succeeded and every test passed.
    """
    print("🚀 Starting Custom Status Optimization Support Tests")
    print("=" * 60)

    try:
        setup_test_data()

        tests = [
            test_shot_custom_status_aggregation,
            test_asset_custom_status_aggregation,
            test_custom_status_filtering,
            test_custom_status_sorting,
        ]

        passed = 0
        failed = 0
        for test in tests:
            try:
                if test():
                    passed += 1
                else:
                    failed += 1
            except Exception as e:
                print(f"❌ Test failed with exception: {e}")
                failed += 1

        print("\n" + "=" * 60)
        print(f"📊 Test Results: {passed} passed, {failed} failed")

        if failed == 0:
            print("🎉 All tests passed! Custom status support is working correctly.")
            return True
        print("❌ Some tests failed. Custom status support needs fixes.")
        return False

    except Exception as e:
        print(f"❌ Test setup failed: {e}")
        traceback.print_exc()
        return False
    finally:
        # Close pooled SQLite connections before deleting the file; an open
        # handle can make the removal fail on some platforms (e.g. Windows).
        engine.dispose()
        if os.path.exists(TEST_DB):
            os.remove(TEST_DB)


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)