LinkDesk/backend/test_shot_optimization.py

281 lines
9.0 KiB
Python

#!/usr/bin/env python3
"""
Test the optimized shot router endpoints to verify:
1. Single query optimization works correctly
2. Backward compatibility is maintained
3. Task status aggregation is accurate
4. Performance improvements are achieved
"""
import requests
import json
import time
from sqlalchemy.orm import Session
from database import get_db
from models.user import User, UserRole
from models.project import Project
from models.episode import Episode
from models.shot import Shot
from models.task import Task
# Root address of the backend server that every request in this script targets.
BASE_URL = "http://localhost:8000"

# Credentials sent as the JSON body of the login request.
LOGIN_DATA = {
    "email": "admin@vfx.com",
    "password": "admin123",
}
def login():
    """Authenticate against the backend and return a bearer token.

    Posts ``LOGIN_DATA`` as JSON to ``{BASE_URL}/auth/login``.

    Returns:
        The ``access_token`` value from the JSON response when the server
        answers with HTTP 200; otherwise ``None``, after printing the
        status code.
    """
    # requests applies no timeout by default, so bound the wait: an
    # unresponsive server should fail the run rather than hang it forever.
    response = requests.post(f"{BASE_URL}/auth/login", json=LOGIN_DATA, timeout=10)
    if response.status_code != 200:
        print(f"Login failed: {response.status_code}")
        return None
    return response.json()["access_token"]
def test_optimized_list_shots():
    """Check the list-shots endpoint and its per-shot task aggregation.

    Logs in, requests ``/shots/?episode_id=1`` and, for at most the first
    three shots returned, verifies that the aggregation fields are present,
    that ``task_status`` is a dict and ``task_details`` a list, and that each
    ``task_details`` entry carries ``task_type``, ``status`` and ``task_id``.

    Returns:
        ``True`` when every check passes, or when the endpoint returns no
        shots (nothing to verify). ``False`` on login failure, a non-200
        response, or any failed check.
    """
    print("\n=== Test 1: Basic list shots functionality ===")
    token = login()
    if not token:
        print("❌ Login failed")
        return False
    headers = {"Authorization": f"Bearer {token}"}

    # Test basic list shots functionality.
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    started = time.perf_counter()
    response = requests.get(
        f"{BASE_URL}/shots/?episode_id=1", headers=headers, timeout=10
    )
    elapsed_ms = (time.perf_counter() - started) * 1000

    if response.status_code != 200:
        print(f"❌ Request failed: {response.status_code}")
        # An error body is not guaranteed to be JSON; print the raw text so
        # the diagnostic itself cannot raise.
        print(f"Response: {response.text}")
        return False

    shots = response.json()
    print(f"Response time: {elapsed_ms:.2f}ms")
    print(f"Number of shots returned: {len(shots)}")
    if not shots:
        print("⚠️ No shots found - creating test data might be needed")
        return True

    # Verify task status aggregation
    print("\n=== Test 2: Task status aggregation ===")
    required_fields = ['id', 'name', 'task_count', 'task_status', 'task_details']
    required_task_fields = ['task_type', 'status', 'task_id']
    for shot in shots[:3]:  # Check first 3 shots
        # Validate before printing: the summary lines below index these keys
        # and call .keys()/len() on them, which would raise before a missing
        # or mistyped field could be reported.
        for field in required_fields:
            if field not in shot:
                print(f"❌ Missing field: {field}")
                return False
        if not isinstance(shot['task_status'], dict):
            print(f"❌ task_status should be dict, got {type(shot['task_status'])}")
            return False
        if not isinstance(shot['task_details'], list):
            print(f"❌ task_details should be list, got {type(shot['task_details'])}")
            return False

        print(f"Shot {shot['name']}:")
        print(f" Task count: {shot['task_count']}")
        print(f" Task status keys: {list(shot['task_status'].keys())}")
        print(f" Task details count: {len(shot['task_details'])}")

        # Verify task_details structure
        for task_detail in shot['task_details']:
            for field in required_task_fields:
                if field not in task_detail:
                    print(f"❌ Missing task detail field: {field}")
                    return False

    print("✅ Task status aggregation test passed!")
    return True
def test_optimized_get_shot():
    """Check that fetching a single shot returns a complete record.

    Logs in, lists one shot from ``/shots/?episode_id=1&limit=1`` to obtain
    an id, fetches ``/shots/{id}`` and verifies the required fields are
    present and that the returned id matches the requested one.

    Returns:
        ``True`` when the checks pass, or when the listing is non-200 or
        empty (the test is skipped). ``False`` on login failure, a non-200
        response from the single-shot request, a missing field, or an id
        mismatch.
    """
    print("\n=== Test 3: Optimized get_shot endpoint ===")
    token = login()
    if not token:
        print("❌ Login failed")
        return False
    headers = {"Authorization": f"Bearer {token}"}

    # First get a list of shots to find a valid shot ID
    listing = requests.get(
        f"{BASE_URL}/shots/?episode_id=1&limit=1", headers=headers, timeout=10
    )
    # Parse the body once, and only for a 200 response.
    candidates = listing.json() if listing.status_code == 200 else None
    if not candidates:
        print("⚠️ No shots found for testing get_shot")
        return True
    shot_id = candidates[0]['id']

    # Test get_shot optimization.
    # perf_counter is monotonic, unlike time.time(), so the duration cannot be
    # distorted by wall-clock adjustments.
    started = time.perf_counter()
    response = requests.get(f"{BASE_URL}/shots/{shot_id}", headers=headers, timeout=10)
    elapsed_ms = (time.perf_counter() - started) * 1000

    if response.status_code != 200:
        print(f"❌ get_shot failed: {response.status_code}")
        # An error body is not guaranteed to be JSON; print the raw text so
        # the diagnostic itself cannot raise.
        print(f"Response: {response.text}")
        return False

    shot = response.json()
    print(f"Response time: {elapsed_ms:.2f}ms")

    # Verify shot data is complete before reading any field, so a missing key
    # is reported as a failed check instead of surfacing as a KeyError.
    required_fields = ['id', 'name', 'project_id', 'episode_id', 'task_count']
    for field in required_fields:
        if field not in shot:
            print(f"❌ Missing field: {field}")
            return False

    print(f"Shot: {shot['name']}")
    print(f"Task count: {shot['task_count']}")

    if shot['id'] != shot_id:
        print(f"❌ Shot ID mismatch: expected {shot_id}, got {shot['id']}")
        return False

    print("✅ get_shot optimization test passed!")
    return True
def test_performance():
    """Measure the response time of the list-shots endpoint.

    Logs in, requests ``/shots/?episode_id=1`` five times and prints the
    average, minimum and maximum duration in milliseconds.

    Returns:
        ``False`` on login failure or if any request returns a non-200
        status; ``True`` otherwise. An average of 500ms or more only
        produces a warning and does not fail the test.
    """
    print("\n=== Test 4: Performance measurement ===")
    token = login()
    if not token:
        print("❌ Login failed")
        return False
    headers = {"Authorization": f"Bearer {token}"}

    # Measure optimized query performance
    samples_ms = []
    for _ in range(5):  # Run 5 times to get average
        # perf_counter is the monotonic, high-resolution clock intended for
        # interval timing; time.time() can jump if the system clock changes.
        started = time.perf_counter()
        response = requests.get(
            f"{BASE_URL}/shots/?episode_id=1", headers=headers, timeout=10
        )
        elapsed_ms = (time.perf_counter() - started) * 1000
        if response.status_code != 200:
            print(f"❌ Performance test failed: {response.status_code}")
            return False
        samples_ms.append(elapsed_ms)

    avg_time = sum(samples_ms) / len(samples_ms)
    min_time = min(samples_ms)
    max_time = max(samples_ms)
    print(f"Average query time: {avg_time:.2f}ms")
    print(f"Min query time: {min_time:.2f}ms")
    print(f"Max query time: {max_time:.2f}ms")

    # Verify performance is reasonable (should be under 500ms as per requirements).
    # This is a soft check: a slow average is reported but still passes.
    if avg_time >= 500:
        print(f"⚠️ Query time {avg_time:.2f}ms approaches 500ms requirement")
    else:
        print(f"✅ Query time {avg_time:.2f}ms is well under 500ms requirement")
    return True
def test_backward_compatibility():
    """Check that existing query parameters and response fields still work.

    Logs in, then requests ``/shots/`` with each pagination and sorting
    parameter combination. Every request must return HTTP 200, and when a
    response contains shots, its first shot must expose all of the expected
    fields.

    Returns:
        ``True`` when every request succeeds and no field is missing;
        ``False`` on login failure, a non-200 response, or a missing field.
    """
    print("\n=== Test 5: Backward compatibility ===")
    token = login()
    if not token:
        print("❌ Login failed")
        return False
    headers = {"Authorization": f"Bearer {token}"}

    # Test all existing parameters still work
    test_params = [
        "?episode_id=1&skip=0&limit=5",
        "?episode_id=1&sort_by=name&sort_direction=desc",
        "?episode_id=1&sort_by=name&sort_direction=asc",
    ]
    # The expected fields do not depend on the query, so build the list once.
    required_fields = [
        'id', 'name', 'description', 'frame_start', 'frame_end',
        'status', 'project_id', 'episode_id', 'created_at', 'updated_at',
        'task_count', 'task_status', 'task_details',
    ]

    for params in test_params:
        # requests applies no timeout by default; bound the wait so a stuck
        # server fails the test instead of hanging it.
        response = requests.get(
            f"{BASE_URL}/shots/{params}", headers=headers, timeout=10
        )
        if response.status_code != 200:
            print(f"❌ Parameter test failed for {params}: {response.status_code}")
            return False

        shots = response.json()
        print(f"✅ Parameters {params} work correctly ({len(shots)} shots)")

        # Verify response format is unchanged (only checkable when the
        # response contains at least one shot).
        if shots:
            shot = shots[0]
            for field in required_fields:
                if field not in shot:
                    print(f"❌ Missing field in response: {field}")
                    return False

    print("✅ Backward compatibility test passed!")
    return True
def main():
    """Run all optimization tests and print a pass/fail summary.

    Probes ``{BASE_URL}/docs`` first and stops early if the server is
    unreachable or does not answer with HTTP 200. Otherwise each test
    function runs in turn; a test counts as passed only when it returns a
    truthy value without raising.

    Returns:
        ``True`` when the server was reachable and every test passed,
        ``False`` otherwise.
    """
    print("Shot Router Optimization Tests")
    print("=" * 50)

    # Check if server is running. The probe is bounded by a timeout, so a
    # timeout is handled the same way as a refused connection.
    try:
        response = requests.get(f"{BASE_URL}/docs", timeout=10)
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        print("❌ Cannot connect to server. Please start the backend server first.")
        return False
    if response.status_code != 200:
        print("❌ Server is not running. Please start the backend server first.")
        return False

    tests = [
        test_optimized_list_shots,
        test_optimized_get_shot,
        test_performance,
        test_backward_compatibility,
    ]
    passed = 0
    total = len(tests)
    for test in tests:
        # Deliberately broad: one crashing test is reported and must not
        # prevent the remaining tests from running.
        try:
            if test():
                passed += 1
            else:
                print(f"❌ Test {test.__name__} failed")
        except Exception as e:
            print(f"❌ Test {test.__name__} failed with exception: {e}")

    print("\n" + "=" * 50)
    if passed == total:
        print("✅ ALL OPTIMIZATION TESTS PASSED!")
    else:
        print(f"{total - passed} out of {total} tests failed")
    print("=" * 50)
    return passed == total


if __name__ == "__main__":
    import sys

    # Propagate the outcome as the process exit status so shells and CI can
    # detect a failed run (0 = all passed, 1 = any failure).
    sys.exit(0 if main() else 1)