#!/usr/bin/env python3
"""
Test the optimized shot router endpoints to verify:
1. Single query optimization works correctly
2. Backward compatibility is maintained
3. Task status aggregation is accurate
4. Performance improvements are achieved

The tests talk to a running backend over HTTP (see BASE_URL). Run this file
directly: it exits with status 0 when every test passes and 1 otherwise.
The test_* functions report results by returning True/False rather than by
asserting, so collecting them with pytest would not surface their failures.
"""

import json
import sys
import time

import requests
from sqlalchemy.orm import Session

# NOTE(review): none of the names imported below are used by the HTTP tests
# in this script, but importing them requires the backend package (and its
# dependencies) to be importable. Confirm whether they can be dropped.
from database import get_db
from models.user import User, UserRole
from models.project import Project
from models.episode import Episode
from models.shot import Shot
from models.task import Task

BASE_URL = "http://localhost:8000"

# Seconds to wait for any single HTTP request. A hung server then fails the
# run instead of blocking it forever.
REQUEST_TIMEOUT = 30

# Average-latency budget for the list endpoint, in milliseconds.
MAX_AVG_QUERY_MS = 500

# Number of timed requests used by the performance test.
PERF_RUNS = 5

# Test credentials
LOGIN_DATA = {
    "email": "admin@vfx.com",
    "password": "admin123"
}


def login():
    """Log in with LOGIN_DATA and return the access token, or None on failure."""
    response = requests.post(
        f"{BASE_URL}/auth/login", json=LOGIN_DATA, timeout=REQUEST_TIMEOUT
    )
    if response.status_code == 200:
        return response.json()["access_token"]
    print(f"Login failed: {response.status_code}")
    return None


def _auth_headers():
    """Log in and return Authorization headers, or None (after reporting) on failure."""
    token = login()
    if not token:
        print("❌ Login failed")
        return None
    return {"Authorization": f"Bearer {token}"}


def _timed_get(path, headers):
    """GET BASE_URL + path and return (response, elapsed wall time in ms)."""
    start = time.perf_counter()  # monotonic clock, suitable for intervals
    response = requests.get(
        f"{BASE_URL}{path}", headers=headers, timeout=REQUEST_TIMEOUT
    )
    elapsed_ms = (time.perf_counter() - start) * 1000
    return response, elapsed_ms


def _response_body(response):
    """Return the JSON body for diagnostics, or raw text if it is not JSON."""
    try:
        return response.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        return response.text


def _missing_field(obj, fields):
    """Return the first name in ``fields`` absent from ``obj``, or None."""
    return next((field for field in fields if field not in obj), None)


def test_optimized_list_shots():
    """Test that the optimized list_shots endpoint works correctly.

    Requests the shot list for episode 1. Then, for the first three shots,
    checks that the aggregated task fields are present and well-typed.

    Returns:
        True on success, or when no shots exist to inspect. False on any
        failure.
    """
    print("\n=== Test 1: Basic list shots functionality ===")
    headers = _auth_headers()
    if headers is None:
        return False

    # Test basic list shots functionality
    response, elapsed_ms = _timed_get("/shots/?episode_id=1", headers)
    if response.status_code != 200:
        print(f"❌ Request failed: {response.status_code}")
        print(f"Response: {_response_body(response)}")
        return False

    shots = response.json()
    print(f"Response time: {elapsed_ms:.2f}ms")
    print(f"Number of shots returned: {len(shots)}")

    if not shots:
        print("⚠️ No shots found - creating test data might be needed")
        return True

    # Verify task status aggregation
    print("\n=== Test 2: Task status aggregation ===")
    required_fields = ['id', 'name', 'task_count', 'task_status', 'task_details']
    required_task_fields = ['task_type', 'status', 'task_id']
    for shot in shots[:3]:  # Check first 3 shots
        # Validate shape *before* reading fields, so a malformed shot is
        # reported as a failure instead of raising KeyError/AttributeError.
        missing = _missing_field(shot, required_fields)
        if missing is not None:
            print(f"❌ Missing field: {missing}")
            return False

        if not isinstance(shot['task_status'], dict):
            print(f"❌ task_status should be dict, got {type(shot['task_status'])}")
            return False

        if not isinstance(shot['task_details'], list):
            print(f"❌ task_details should be list, got {type(shot['task_details'])}")
            return False

        print(f"Shot {shot['name']}:")
        print(f"  Task count: {shot['task_count']}")
        print(f"  Task status keys: {list(shot['task_status'])}")
        print(f"  Task details count: {len(shot['task_details'])}")

        # Verify task_details structure
        for task_detail in shot['task_details']:
            missing = _missing_field(task_detail, required_task_fields)
            if missing is not None:
                print(f"❌ Missing task detail field: {missing}")
                return False

    print("✅ Task status aggregation test passed!")
    return True


def test_optimized_get_shot():
    """Test that the optimized get_shot endpoint works correctly.

    Picks the first shot of episode 1, fetches it by id, and checks the
    core fields and that the returned id matches.

    Returns:
        True on success, or when there is no shot to test with. False on
        failure, including a failed list request.
    """
    print("\n=== Test 3: Optimized get_shot endpoint ===")
    headers = _auth_headers()
    if headers is None:
        return False

    # First get a list of shots to find a valid shot ID
    response, _ = _timed_get("/shots/?episode_id=1&limit=1", headers)
    if response.status_code != 200:
        # A failing list request is an error, not an empty data set.
        print(f"❌ Listing shots failed: {response.status_code}")
        print(f"Response: {_response_body(response)}")
        return False
    listed = response.json()
    if not listed:
        print("⚠️ No shots found for testing get_shot")
        return True

    shot_id = listed[0]['id']

    # Test get_shot optimization
    response, elapsed_ms = _timed_get(f"/shots/{shot_id}", headers)
    if response.status_code != 200:
        print(f"❌ get_shot failed: {response.status_code}")
        print(f"Response: {_response_body(response)}")
        return False

    shot = response.json()

    # Verify shot data is complete before reading any fields from it.
    required_fields = ['id', 'name', 'project_id', 'episode_id', 'task_count']
    missing = _missing_field(shot, required_fields)
    if missing is not None:
        print(f"❌ Missing field: {missing}")
        return False

    print(f"Response time: {elapsed_ms:.2f}ms")
    print(f"Shot: {shot['name']}")
    print(f"Task count: {shot['task_count']}")

    if shot['id'] != shot_id:
        print(f"❌ Shot ID mismatch: expected {shot_id}, got {shot['id']}")
        return False

    print("✅ get_shot optimization test passed!")
    return True


def test_performance():
    """Test performance of optimized queries.

    Issues PERF_RUNS sequential list requests and reports the average, min
    and max latency in milliseconds. An average at or above
    MAX_AVG_QUERY_MS is reported as a warning only. The test fails solely
    on a non-200 response.
    """
    print("\n=== Test 4: Performance measurement ===")
    headers = _auth_headers()
    if headers is None:
        return False

    times = []
    for _ in range(PERF_RUNS):  # repeat to get an average
        response, elapsed_ms = _timed_get("/shots/?episode_id=1", headers)
        if response.status_code != 200:
            print(f"❌ Performance test failed: {response.status_code}")
            return False
        times.append(elapsed_ms)

    avg_time = sum(times) / len(times)
    print(f"Average query time: {avg_time:.2f}ms")
    print(f"Min query time: {min(times):.2f}ms")
    print(f"Max query time: {max(times):.2f}ms")

    # Soft check against the requirement: exceeding the budget is reported
    # but does not fail the test.
    if avg_time >= MAX_AVG_QUERY_MS:
        print(f"⚠️ Query time {avg_time:.2f}ms exceeds the {MAX_AVG_QUERY_MS}ms requirement")
    else:
        print(f"✅ Query time {avg_time:.2f}ms is under the {MAX_AVG_QUERY_MS}ms requirement")

    return True


def test_backward_compatibility():
    """Test that the optimized endpoints maintain backward compatibility.

    Calls the list endpoint with the pagination and sorting parameters
    below. For each non-empty result, it checks that the first shot still
    carries every field of the established response format.
    """
    print("\n=== Test 5: Backward compatibility ===")
    headers = _auth_headers()
    if headers is None:
        return False

    # Test all existing parameters still work
    test_params = [
        "?episode_id=1&skip=0&limit=5",
        "?episode_id=1&sort_by=name&sort_direction=desc",
        "?episode_id=1&sort_by=name&sort_direction=asc",
    ]
    required_fields = [
        'id', 'name', 'description', 'frame_start', 'frame_end', 'status',
        'project_id', 'episode_id', 'created_at', 'updated_at',
        'task_count', 'task_status', 'task_details',
    ]

    for params in test_params:
        response, _ = _timed_get(f"/shots/{params}", headers)
        if response.status_code != 200:
            print(f"❌ Parameter test failed for {params}: {response.status_code}")
            return False

        shots = response.json()
        print(f"✅ Parameters {params} work correctly ({len(shots)} shots)")

        # Verify response format is unchanged
        if shots:
            missing = _missing_field(shots[0], required_fields)
            if missing is not None:
                print(f"❌ Missing field in response: {missing}")
                return False

    print("✅ Backward compatibility test passed!")
    return True


def main():
    """Run all optimization tests.

    Returns:
        True when every test passed. False otherwise, including when the
        server cannot be reached.
    """
    print("Shot Router Optimization Tests")
    print("=" * 50)

    # Check if server is running. Any requests-level error, including a
    # timeout, means the server is unusable.
    try:
        response = requests.get(f"{BASE_URL}/docs", timeout=REQUEST_TIMEOUT)
    except requests.exceptions.RequestException:
        print("❌ Cannot connect to server. Please start the backend server first.")
        return False
    if response.status_code != 200:
        print("❌ Server is not running. Please start the backend server first.")
        return False

    tests = [
        test_optimized_list_shots,
        test_optimized_get_shot,
        test_performance,
        test_backward_compatibility,
    ]

    passed = 0
    total = len(tests)

    for test in tests:
        try:
            ok = test()
        except Exception as e:  # top-level boundary: report, then run the rest
            print(f"❌ Test {test.__name__} failed with exception: {e}")
            continue
        if ok:
            passed += 1
        else:
            print(f"❌ Test {test.__name__} failed")

    print("\n" + "=" * 50)
    if passed == total:
        print("✅ ALL OPTIMIZATION TESTS PASSED!")
    else:
        print(f"❌ {total - passed} out of {total} tests failed")
    print("=" * 50)

    return passed == total


if __name__ == "__main__":
    # Propagate the result as the process exit status so CI can detect failures.
    sys.exit(0 if main() else 1)