#!/usr/bin/env python3
"""
Test script to verify the shot router optimization implementation.

This tests the optimized query patterns without requiring a running server.
The checks import ``routers.shots`` and inspect the source text and
signatures of ``list_shots``, ``get_shot`` and ``update_shot``.

When run as a script, the process exits with status 0 if every check passed
and 1 otherwise.
"""
import inspect
import os
import sys
from typing import Callable, Iterable, List, Optional, Sequence, Tuple

# Make packages that live next to this file (``routers``) importable when the
# script is launched from another working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Only ``main`` is exported by ``from <module> import *``; the ``test_*``
# checks return booleans for ``main`` and are not meant to be re-exported
# into other modules.  They remain importable by explicit name.
__all__ = ["main"]

_RULE = "=" * 50

# (substring searched for in the function source, description printed).
_LIST_SHOTS_MARKERS = (
    ("selectinload", "selectinload for eager loading"),
    ("joinedload", "joinedload for related data"),
    ("project_data", "optimized project data fetching"),
    ("shots_with_tasks", "single query with JOIN"),
)
_GET_SHOT_MARKERS = (
    ("selectinload", "selectinload for tasks"),
    ("joinedload", "joinedload for episode/project"),
    ("active_tasks", "relationship-based task counting"),
)
_UPDATE_SHOT_MARKERS = (
    ("selectinload", "selectinload for tasks"),
    ("active_tasks", "relationship-based task counting"),
)

# Parameter names each endpoint signature is expected to keep declaring.
_LIST_SHOTS_PARAMS = (
    "episode_id", "project_id", "task_status_filter", "sort_by",
    "sort_direction", "skip", "limit", "db", "current_user",
)
_GET_SHOT_PARAMS = ("shot_id", "db", "current_user")


def _found_markers(source: str, markers: Sequence[Tuple[str, str]]) -> List[str]:
    """Return the description of every marker whose substring occurs in *source*."""
    return [label for needle, label in markers if needle in source]


def _missing_params(func: Callable, expected: Sequence[str]) -> List[str]:
    """Return the names in *expected* that *func*'s signature does not declare."""
    declared = inspect.signature(func).parameters
    return [name for name in expected if name not in declared]


def test_shot_router_imports() -> bool:
    """Test that the optimized shot router can be imported successfully.

    Returns:
        True when ``list_shots``, ``get_shot`` and ``update_shot`` can all be
        imported from ``routers.shots``; False (after printing the error) when
        the import raises.
    """
    try:
        # The names are imported only to prove that they exist.
        from routers.shots import list_shots, get_shot, update_shot  # noqa: F401
    except Exception as e:
        # Broad on purpose: any failure while importing the router (not just
        # ImportError) means the check failed and should be reported.
        print(f"❌ Failed to import shot router functions: {e}")
        return False
    print("✅ Shot router functions imported successfully")
    return True


def test_optimization_patterns() -> bool:
    """Test that the optimization patterns are present in the code.

    The source text of each router function is searched for a set of marker
    substrings.  This is a textual check only; it does not execute queries.

    Returns:
        True when every inspected function contains at least one of its
        markers.  False when a function contains none of them, or when the
        import / source lookup raises.
    """
    try:
        from routers.shots import list_shots, get_shot, update_shot

        checks = (
            ("list_shots", list_shots, _LIST_SHOTS_MARKERS),
            ("get_shot", get_shot, _GET_SHOT_MARKERS),
            ("update_shot", update_shot, _UPDATE_SHOT_MARKERS),
        )
        all_present = True
        for name, func, markers in checks:
            found = _found_markers(inspect.getsource(func), markers)
            if found:
                print(f"✅ {name} optimizations found: {', '.join(found)}")
            else:
                # An empty result used to be reported as a success; a function
                # with no marker at all is what this check exists to catch.
                print(f"❌ No optimization patterns found in {name}")
                all_present = False
        return all_present
    except Exception as e:
        print(f"❌ Failed to analyze optimization patterns: {e}")
        return False


def test_backward_compatibility() -> bool:
    """Test that the function signatures maintain backward compatibility.

    Only the presence of the expected parameter names is checked; extra
    parameters, ordering and defaults are not inspected.

    Returns:
        True when ``list_shots`` and ``get_shot`` declare every expected
        parameter; False when one is missing or the inspection raises.
    """
    try:
        from routers.shots import list_shots, get_shot

        checks = (
            ("list_shots", list_shots, _LIST_SHOTS_PARAMS),
            ("get_shot", get_shot, _GET_SHOT_PARAMS),
        )
        for name, func, expected in checks:
            missing = _missing_params(func, expected)
            if missing:
                for param in missing:
                    print(f"❌ Missing parameter in {name}: {param}")
                return False
            print(f"✅ {name} maintains backward compatible signature")
        return True
    except Exception as e:
        print(f"❌ Failed to check backward compatibility: {e}")
        return False


def main(tests: Optional[Iterable[Callable[[], bool]]] = None) -> bool:
    """Run all optimization verification tests.

    Args:
        tests: Zero-argument callables returning a truthy value on success.
            Defaults to the three checks defined in this module.

    Returns:
        True when every check returned a truthy value, False otherwise.  A
        check that raises is counted as failed and does not stop the run.
    """
    if tests is None:
        tests = (
            test_shot_router_imports,
            test_optimization_patterns,
            test_backward_compatibility,
        )
    tests = list(tests)

    print("Shot Router Optimization Verification")
    print(_RULE)

    passed = 0
    total = len(tests)
    for test in tests:
        # Arbitrary callables (e.g. functools.partial) may lack __name__.
        name = getattr(test, "__name__", repr(test))
        try:
            if test():
                passed += 1
            else:
                print(f"❌ Test {name} failed")
        except Exception as e:
            print(f"❌ Test {name} failed with exception: {e}")

    success = passed == total

    print("\n" + _RULE)
    if success:
        print("✅ ALL OPTIMIZATION VERIFICATION TESTS PASSED!")
        print("\nOptimizations implemented:")
        print("- Single query with JOIN for shots and tasks")
        print("- Eager loading with selectinload and joinedload")
        print("- Optimized project data fetching")
        print("- Relationship-based task counting")
        print("- Backward compatibility maintained")
    else:
        print(f"❌ {total - passed} out of {total} tests failed")
    print(_RULE)

    return success


if __name__ == "__main__":
    # Propagate the result so shells and CI can detect a failed verification.
    sys.exit(0 if main() else 1)