281 lines
9.0 KiB
Python
281 lines
9.0 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test the optimized shot router endpoints to verify:
|
|
1. Single query optimization works correctly
|
|
2. Backward compatibility is maintained
|
|
3. Task status aggregation is accurate
|
|
4. Performance improvements are achieved
|
|
"""
|
|
|
|
import json
import os
import time

import requests
from sqlalchemy.orm import Session

from database import get_db
from models.episode import Episode
from models.project import Project
from models.shot import Shot
from models.task import Task
from models.user import User, UserRole
|
|
|
|
# Backend under test. Defaults to a local server; set SHOT_TEST_BASE_URL to
# point the checks at another deployment without editing this file.
BASE_URL = os.environ.get("SHOT_TEST_BASE_URL", "http://localhost:8000")

# Test credentials. The defaults are the previously hard-coded account;
# SHOT_TEST_EMAIL / SHOT_TEST_PASSWORD override them so real secrets never
# need to be committed here.
LOGIN_DATA = {
    "email": os.environ.get("SHOT_TEST_EMAIL", "admin@vfx.com"),
    "password": os.environ.get("SHOT_TEST_PASSWORD", "admin123"),
}
|
|
|
|
def login(*, timeout=10):
    """Log in with LOGIN_DATA and return the bearer access token.

    Args:
        timeout: Seconds to wait for the login request before giving up.

    Returns:
        The ``access_token`` string from the login response, or ``None`` when
        the request does not return HTTP 200 or the body carries no token.
    """
    response = requests.post(f"{BASE_URL}/auth/login", json=LOGIN_DATA, timeout=timeout)
    if response.status_code != 200:
        print(f"Login failed: {response.status_code}")
        return None

    try:
        return response.json()["access_token"]
    except (ValueError, KeyError, TypeError) as exc:
        # A 200 with a non-JSON or token-less body is still a failed login;
        # report it like the status-code failure instead of raising.
        print(f"Login failed: unexpected response body ({exc!r})")
        return None
|
|
|
|
def test_optimized_list_shots():
    """Check that GET /shots/ succeeds and returns well-formed task aggregates.

    Runs "Test 1" (the list request for episode 1 returns HTTP 200) and
    "Test 2" (each of the first three shots has ``id``, ``name``,
    ``task_count``, a dict ``task_status`` and a list ``task_details`` whose
    entries carry ``task_type``, ``status`` and ``task_id``).

    Returns:
        True if the checks pass or no shots exist to check; False on a failed
        login, a non-200 response, or a malformed shot.
    """
    print("\n=== Test 1: Basic list shots functionality ===")

    token = login()
    if not token:
        print("❌ Login failed")
        return False

    headers = {"Authorization": f"Bearer {token}"}

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which makes it the right clock for timing a single request.
    start_time = time.perf_counter()
    response = requests.get(f"{BASE_URL}/shots/?episode_id=1", headers=headers, timeout=30)
    end_time = time.perf_counter()

    if response.status_code != 200:
        print(f"❌ Request failed: {response.status_code}")
        # Error bodies are not guaranteed to be JSON; response.json() would
        # raise here and hide the actual failure.
        print(f"Response: {response.text}")
        return False

    shots = response.json()

    print(f"Response time: {(end_time - start_time) * 1000:.2f}ms")
    print(f"Number of shots returned: {len(shots)}")

    if not shots:
        print("⚠️ No shots found - creating test data might be needed")
        return True

    print("\n=== Test 2: Task status aggregation ===")
    required_fields = ['id', 'name', 'task_count', 'task_status', 'task_details']
    required_task_fields = ['task_type', 'status', 'task_id']

    for shot in shots[:3]:  # Check first 3 shots
        print(f"Shot {shot.get('name', '<missing name>')}:")

        # Validate before indexing, so a missing field is reported as a test
        # failure rather than escaping as a KeyError from the prints below.
        for field in required_fields:
            if field not in shot:
                print(f"❌ Missing field: {field}")
                return False

        if not isinstance(shot['task_status'], dict):
            print(f"❌ task_status should be dict, got {type(shot['task_status'])}")
            return False

        if not isinstance(shot['task_details'], list):
            print(f"❌ task_details should be list, got {type(shot['task_details'])}")
            return False

        print(f" Task count: {shot['task_count']}")
        print(f" Task status keys: {list(shot['task_status'].keys())}")
        print(f" Task details count: {len(shot['task_details'])}")

        for task_detail in shot['task_details']:
            for field in required_task_fields:
                if field not in task_detail:
                    print(f"❌ Missing task detail field: {field}")
                    return False

    print("✅ Task status aggregation test passed!")
    return True
|
|
|
|
def test_optimized_get_shot():
    """Check that GET /shots/{id} returns the complete record for a known shot.

    Picks the first shot of episode 1 from the list endpoint, fetches it by
    ID, and verifies the required fields are present and the ID round-trips.

    Returns:
        True if the checks pass or there is no shot to test; False on a failed
        login, a non-200 response from either endpoint, or bad shot data.
    """
    print("\n=== Test 3: Optimized get_shot endpoint ===")

    token = login()
    if not token:
        print("❌ Login failed")
        return False

    headers = {"Authorization": f"Bearer {token}"}

    # First get a list of shots to find a valid shot ID
    response = requests.get(f"{BASE_URL}/shots/?episode_id=1&limit=1", headers=headers, timeout=30)
    if response.status_code != 200:
        # An error from the list endpoint is a failure, not "no test data".
        print(f"❌ Listing shots failed: {response.status_code}")
        return False

    listed_shots = response.json()
    if not listed_shots:
        print("⚠️ No shots found for testing get_shot")
        return True

    shot_id = listed_shots[0]['id']

    # Test get_shot optimization (perf_counter: monotonic interval clock)
    start_time = time.perf_counter()
    response = requests.get(f"{BASE_URL}/shots/{shot_id}", headers=headers, timeout=30)
    end_time = time.perf_counter()

    if response.status_code != 200:
        print(f"❌ get_shot failed: {response.status_code}")
        # Error bodies may not be JSON; print the raw text instead.
        print(f"Response: {response.text}")
        return False

    shot = response.json()
    print(f"Response time: {(end_time - start_time) * 1000:.2f}ms")

    # Verify shot data is complete before reading individual fields.
    required_fields = ['id', 'name', 'project_id', 'episode_id', 'task_count']
    for field in required_fields:
        if field not in shot:
            print(f"❌ Missing field: {field}")
            return False

    print(f"Shot: {shot['name']}")
    print(f"Task count: {shot['task_count']}")

    if shot['id'] != shot_id:
        print(f"❌ Shot ID mismatch: expected {shot_id}, got {shot['id']}")
        return False

    print("✅ get_shot optimization test passed!")
    return True
|
|
|
|
def test_performance(*, runs=5, threshold_ms=500.0, strict=False):
    """Time repeated GET /shots/?episode_id=1 calls and report latency stats.

    Args:
        runs: Number of timed requests; must be at least 1.
        threshold_ms: Average-latency budget in milliseconds.
        strict: When True, an average at or above ``threshold_ms`` fails the
            test. By default it only prints a warning.

    Returns:
        True unless login fails, a request returns non-200, or ``strict`` is
        set and the latency budget is not met.

    Raises:
        ValueError: If ``runs`` is less than 1 (the average would be undefined).
    """
    print("\n=== Test 4: Performance measurement ===")

    if runs < 1:
        raise ValueError(f"runs must be at least 1, got {runs}")

    token = login()
    if not token:
        print("❌ Login failed")
        return False

    headers = {"Authorization": f"Bearer {token}"}

    # Measure optimized query performance; perf_counter is monotonic and
    # high-resolution, so it is unaffected by wall-clock adjustments.
    times = []
    for _ in range(runs):
        start_time = time.perf_counter()
        response = requests.get(f"{BASE_URL}/shots/?episode_id=1", headers=headers, timeout=30)
        elapsed_ms = (time.perf_counter() - start_time) * 1000

        if response.status_code != 200:
            print(f"❌ Performance test failed: {response.status_code}")
            return False

        times.append(elapsed_ms)

    avg_time = sum(times) / len(times)
    min_time = min(times)
    max_time = max(times)

    print(f"Average query time: {avg_time:.2f}ms")
    print(f"Min query time: {min_time:.2f}ms")
    print(f"Max query time: {max_time:.2f}ms")

    if avg_time >= threshold_ms:
        # The budget is met or exceeded here, not merely "approached".
        print(f"⚠️ Query time {avg_time:.2f}ms does not meet the {threshold_ms:g}ms requirement")
        return not strict

    print(f"✅ Query time {avg_time:.2f}ms is under the {threshold_ms:g}ms requirement")
    return True
|
|
|
|
def test_backward_compatibility():
    """Check that existing /shots/ query parameters and response fields still work.

    Requests episode 1 with pagination and with both sort directions, and for
    each non-empty result checks that the first shot exposes every field of
    the established response format.

    Returns:
        True if every parameter combination succeeds with the full field set;
        False on a failed login, a non-200 response, or a missing field.
    """
    print("\n=== Test 5: Backward compatibility ===")

    token = login()
    if not token:
        print("❌ Login failed")
        return False

    headers = {"Authorization": f"Bearer {token}"}

    # Test all existing parameters still work
    test_params = [
        "?episode_id=1&skip=0&limit=5",
        "?episode_id=1&sort_by=name&sort_direction=desc",
        "?episode_id=1&sort_by=name&sort_direction=asc",
    ]

    # Fields the list response exposed before the optimization (loop-invariant).
    required_fields = ['id', 'name', 'description', 'frame_start', 'frame_end',
                       'status', 'project_id', 'episode_id', 'created_at', 'updated_at',
                       'task_count', 'task_status', 'task_details']

    for params in test_params:
        response = requests.get(f"{BASE_URL}/shots/{params}", headers=headers, timeout=30)
        if response.status_code != 200:
            print(f"❌ Parameter test failed for {params}: {response.status_code}")
            return False

        shots = response.json()

        # Verify response format is unchanged BEFORE reporting success, so a
        # "✅" line is never followed by a failure for the same parameters.
        if shots:
            for field in required_fields:
                if field not in shots[0]:
                    print(f"❌ Missing field in response: {field}")
                    return False

        print(f"✅ Parameters {params} work correctly ({len(shots)} shots)")

    print("✅ Backward compatibility test passed!")
    return True
|
|
|
|
def main():
    """Run all optimization tests against the backend at BASE_URL.

    Returns:
        Process exit status: 0 when every test passes, 1 otherwise (including
        when the server cannot be reached), so CI can detect failures.
    """
    print("Shot Router Optimization Tests")
    print("=" * 50)

    # Check if server is running; bound the probe so a hung server cannot
    # block the run forever.
    try:
        response = requests.get(f"{BASE_URL}/docs", timeout=10)
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        print("❌ Cannot connect to server. Please start the backend server first.")
        return 1

    if response.status_code != 200:
        # The server answered, so it is running; the probe endpoint failed.
        print(f"❌ Server health check failed: GET /docs returned {response.status_code}")
        return 1

    tests = [
        test_optimized_list_shots,
        test_optimized_get_shot,
        test_performance,
        test_backward_compatibility,
    ]

    passed = 0
    total = len(tests)

    for test in tests:
        try:
            ok = test()
        except Exception as e:  # top-level boundary: report and keep going
            print(f"❌ Test {test.__name__} failed with exception: {e}")
            continue

        if ok:
            passed += 1
        else:
            print(f"❌ Test {test.__name__} failed")

    print("\n" + "=" * 50)
    if passed == total:
        print("✅ ALL OPTIMIZATION TESTS PASSED!")
    else:
        print(f"❌ {total - passed} out of {total} tests failed")
    print("=" * 50)

    return 0 if passed == total else 1


if __name__ == "__main__":
    # Propagate the result as the process exit status (None -> 0).
    raise SystemExit(main())