155 lines
5.4 KiB
Python
155 lines
5.4 KiB
Python
#!/usr/bin/env python3
"""
Test script to verify the shot router optimization implementation.

It checks the optimized query patterns by inspecting the router source code,
so it does not require a running server.
"""

import os
import sys

# Make the directory containing this script importable so the local
# ``routers`` package can be found. Prepend rather than append so the local
# package takes precedence over any same-named package already on sys.path,
# and skip the insert when the directory is already present (avoids
# duplicate entries if this module is imported more than once).
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if _SCRIPT_DIR not in sys.path:
    sys.path.insert(0, _SCRIPT_DIR)


def test_shot_router_imports():
    """Check that the shot router endpoints can be imported.

    Attempts to import ``list_shots``, ``get_shot`` and ``update_shot`` from
    ``routers.shots``. Any exception raised while importing is reported
    rather than propagated.

    Returns:
        True when all three names import; False otherwise (the error text is
        printed).
    """
    try:
        # Imported only to prove they exist; the names are not used further.
        from routers.shots import list_shots, get_shot, update_shot  # noqa: F401
    except Exception as exc:
        print(f"❌ Failed to import shot router functions: {exc}")
        return False
    else:
        print("✅ Shot router functions imported successfully")
        return True


def _detect_patterns(func, markers):
    """Return the labels of the markers whose substring occurs in ``func``'s source.

    Args:
        func: Function whose source text is read with ``inspect.getsource``.
        markers: Sequence of ``(substring, label)`` pairs, checked in order.

    Returns:
        List of labels, in ``markers`` order, whose substring was found.

    Raises:
        Whatever ``inspect.getsource`` raises (e.g. ``OSError`` or
        ``TypeError``) when the source cannot be retrieved.
    """
    import inspect

    source = inspect.getsource(func)
    return [label for needle, label in markers if needle in source]


def test_optimization_patterns():
    """Check that the optimized query patterns appear in the shot router code.

    This is a text-based heuristic: it searches the source of ``list_shots``,
    ``get_shot`` and ``update_shot`` for marker substrings such as
    ``selectinload``. It does not execute the endpoints.

    Returns:
        True when every endpoint contains at least one of its expected
        markers. False when any endpoint contains none of them, or when the
        router cannot be imported or inspected (the error is printed).
    """
    try:
        from routers.shots import list_shots, get_shot, update_shot

        list_shots_markers = (
            ("selectinload", "selectinload for eager loading"),
            ("joinedload", "joinedload for related data"),
            ("project_data", "optimized project data fetching"),
            ("shots_with_tasks", "single query with JOIN"),
        )
        get_shot_markers = (
            ("selectinload", "selectinload for tasks"),
            ("joinedload", "joinedload for episode/project"),
            ("active_tasks", "relationship-based task counting"),
        )
        update_shot_markers = (
            ("selectinload", "selectinload for tasks"),
            ("active_tasks", "relationship-based task counting"),
        )

        checks = (
            ("list_shots", list_shots, list_shots_markers),
            ("get_shot", get_shot, get_shot_markers),
            ("update_shot", update_shot, update_shot_markers),
        )

        all_present = True
        # Check every endpoint before returning so all results are reported.
        for name, func, markers in checks:
            found = _detect_patterns(func, markers)
            if found:
                print(f"✅ {name} optimizations found: {', '.join(found)}")
            else:
                # No marker at all means the optimization is absent; the
                # check must fail instead of reporting an empty list.
                print(f"❌ No optimization patterns found in {name}")
                all_present = False

        return all_present

    except Exception as e:
        print(f"❌ Failed to analyze optimization patterns: {e}")
        return False


def test_backward_compatibility():
    """Check that the router endpoints still accept their historical parameters.

    For ``list_shots`` and then ``get_shot``, every expected parameter name
    must be present in the function's signature. Extra parameters are allowed;
    order is not checked. ``update_shot`` is imported too, so a missing
    ``update_shot`` also fails this check.

    Returns:
        True when both signatures contain all expected names. False at the
        first missing name (which is printed), or when importing/inspecting
        fails (the error is printed).
    """
    try:
        import inspect
        from routers.shots import list_shots, get_shot, update_shot  # noqa: F401

        signature_checks = (
            (
                "list_shots",
                list_shots,
                (
                    'episode_id', 'project_id', 'task_status_filter', 'sort_by',
                    'sort_direction', 'skip', 'limit', 'db', 'current_user',
                ),
            ),
            ("get_shot", get_shot, ('shot_id', 'db', 'current_user')),
        )

        for label, endpoint, required in signature_checks:
            present = inspect.signature(endpoint).parameters
            absent = next((name for name in required if name not in present), None)
            if absent is not None:
                print(f"❌ Missing parameter in {label}: {absent}")
                return False
            print(f"✅ {label} maintains backward compatible signature")

        return True

    except Exception as e:
        print(f"❌ Failed to check backward compatibility: {e}")
        return False


def main():
    """Run all optimization verification checks and print a summary.

    Each check is called in turn. A check passes when it returns a truthy
    value. A falsy return or an exception counts as a failure and is reported.

    Returns:
        0 when every check passed, 1 otherwise, suitable for ``sys.exit``.
    """
    print("Shot Router Optimization Verification")
    print("=" * 50)

    tests = [
        test_shot_router_imports,
        test_optimization_patterns,
        test_backward_compatibility,
    ]

    passed = 0
    total = len(tests)

    for test in tests:
        try:
            if test():
                passed += 1
            else:
                print(f"❌ Test {test.__name__} failed")
        except Exception as e:
            # The checks catch their own errors; this guards against a check
            # that raises anyway so the remaining checks still run.
            print(f"❌ Test {test.__name__} failed with exception: {e}")

    all_passed = passed == total

    print("\n" + "=" * 50)
    if all_passed:
        print("✅ ALL OPTIMIZATION VERIFICATION TESTS PASSED!")
        print("\nOptimizations implemented:")
        print("- Single query with JOIN for shots and tasks")
        print("- Eager loading with selectinload and joinedload")
        print("- Optimized project data fetching")
        print("- Relationship-based task counting")
        print("- Backward compatibility maintained")
    else:
        print(f"❌ {total - passed} out of {total} tests failed")
    print("=" * 50)

    return 0 if all_passed else 1


if __name__ == "__main__":
    # Propagate the result as the process exit status so CI can detect failures.
    sys.exit(main())