LinkDesk/backend/test_shot_router_optimizati...

155 lines
5.4 KiB
Python

#!/usr/bin/env python3
"""
Test script to verify the shot router optimization implementation.
This tests the optimized query patterns without requiring a running server.
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
def test_shot_router_imports():
    """Verify that the optimized shot handlers are importable.

    Imports ``list_shots``, ``get_shot`` and ``update_shot`` from
    ``routers.shots``. Any exception raised while importing (missing module,
    missing name, or an error raised while the router module loads) is
    printed and treated as a failed check instead of being propagated.

    Returns:
        bool: True if all three names were imported, False otherwise.
    """
    try:
        # Only the import itself is under test; the names are not used further.
        from routers.shots import list_shots, get_shot, update_shot  # noqa: F401
        print("✅ Shot router functions imported successfully")
        outcome = True
    except Exception as exc:
        print(f"❌ Failed to import shot router functions: {exc}")
        outcome = False
    return outcome
# Substring markers searched for in each handler's source code, paired with the
# human-readable label printed when the marker is present. Order is preserved
# in the printed report.
_LIST_SHOTS_MARKERS = (
    ('selectinload', "selectinload for eager loading"),
    ('joinedload', "joinedload for related data"),
    ('project_data', "optimized project data fetching"),
    ('shots_with_tasks', "single query with JOIN"),
)
_GET_SHOT_MARKERS = (
    ('selectinload', "selectinload for tasks"),
    ('joinedload', "joinedload for episode/project"),
    ('active_tasks', "relationship-based task counting"),
)
_UPDATE_SHOT_MARKERS = (
    ('selectinload', "selectinload for tasks"),
    ('active_tasks', "relationship-based task counting"),
)


def _find_optimizations(source, markers):
    """Return the labels of every ``(marker, label)`` whose marker occurs in *source*."""
    return [label for marker, label in markers if marker in source]


def test_optimization_patterns():
    """Check that the shot handlers' source contains the expected query patterns.

    The source of ``list_shots``, ``get_shot`` and ``update_shot`` is scanned
    for plain substrings (e.g. ``selectinload``). This is a textual heuristic,
    not proof that the queries are actually optimized.

    Returns:
        bool: True only if every handler contains at least one recognised
        pattern; False if any handler has none, or if importing or reading
        the source fails (the error is printed).
    """
    try:
        import inspect
        from routers.shots import list_shots, get_shot, update_shot

        checks = (
            ("list_shots", list_shots, _LIST_SHOTS_MARKERS),
            ("get_shot", get_shot, _GET_SHOT_MARKERS),
            ("update_shot", update_shot, _UPDATE_SHOT_MARKERS),
        )
        all_found = True
        for name, handler, markers in checks:
            found = _find_optimizations(inspect.getsource(handler), markers)
            if found:
                print(f"✅ {name} optimizations found: {', '.join(found)}")
            else:
                # Previously this case still reported success, so the check
                # could never fail; a handler with no pattern is now a failure.
                print(f"❌ {name}: no optimization patterns found")
                all_found = False
        return all_found
    except Exception as e:
        print(f"❌ Failed to analyze optimization patterns: {e}")
        return False
def test_backward_compatibility():
    """Confirm the shot handlers still accept their historical parameter names.

    ``list_shots`` and ``get_shot`` are checked, in that order, against fixed
    lists of required parameter names. Only presence is checked (extra
    parameters are allowed). The first missing name is reported and the
    check stops there.

    Returns:
        bool: True if every required name is present; False on the first
        missing name or on any error while importing/inspecting (printed).
    """
    try:
        import inspect
        # update_shot is imported too, so a missing update_shot fails the check.
        from routers.shots import list_shots, get_shot, update_shot  # noqa: F401

        required = (
            ("list_shots", list_shots, (
                'episode_id', 'project_id', 'task_status_filter', 'sort_by',
                'sort_direction', 'skip', 'limit', 'db', 'current_user',
            )),
            ("get_shot", get_shot, ('shot_id', 'db', 'current_user')),
        )
        for label, handler, names in required:
            present = inspect.signature(handler).parameters
            missing = next((name for name in names if name not in present), None)
            if missing is not None:
                print(f"❌ Missing parameter in {label}: {missing}")
                return False
            print(f"✅ {label} maintains backward compatible signature")
        return True
    except Exception as exc:
        print(f"❌ Failed to check backward compatibility: {exc}")
        return False
def main(tests=None):
    """Run the optimization verification checks and print a summary.

    Args:
        tests: Optional iterable of zero-argument callables; a truthy return
            value means the check passed. Defaults to the three checks
            defined in this module.

    Returns:
        int: 0 when every check passed, 1 otherwise (suitable for ``sys.exit``).
    """
    print("Shot Router Optimization Verification")
    print("=" * 50)
    if tests is None:
        tests = [
            test_shot_router_imports,
            test_optimization_patterns,
            test_backward_compatibility,
        ]
    tests = list(tests)
    passed = 0
    total = len(tests)
    for test in tests:
        try:
            if test():
                passed += 1
            else:
                print(f"❌ Test {test.__name__} failed")
        except Exception as e:
            # A crashing check counts as a failure but must not stop the rest.
            print(f"❌ Test {test.__name__} failed with exception: {e}")
    print("\n" + "=" * 50)
    all_passed = passed == total
    if all_passed:
        print("✅ ALL OPTIMIZATION VERIFICATION TESTS PASSED!")
        print("\nOptimizations implemented:")
        print("- Single query with JOIN for shots and tasks")
        print("- Eager loading with selectinload and joinedload")
        print("- Optimized project data fetching")
        print("- Relationship-based task counting")
        print("- Backward compatibility maintained")
    else:
        print(f"{total - passed} out of {total} tests failed")
    print("=" * 50)
    return 0 if all_passed else 1


if __name__ == "__main__":
    # Propagate the result as the process exit status so CI can detect failures.
    sys.exit(main())