473 lines
15 KiB
Python
473 lines
15 KiB
Python
from passlib.context import CryptContext
|
|
from jose import JWTError, jwt
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, Union
|
|
from fastapi import HTTPException, status, Depends, Request
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from sqlalchemy.orm import Session
|
|
import os
|
|
import secrets
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
|
|
# Setup logger
logger = logging.getLogger("vfx_auth")

# Password hashing (bcrypt via passlib)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# JWT settings.
# The fallback values below are public placeholders: anyone who has read this
# file can forge tokens signed with them, so deployments must provide the
# SECRET_KEY and REFRESH_SECRET_KEY environment variables.
_DEFAULT_SECRET_KEY = "your-secret-key-here-change-in-production"
_DEFAULT_REFRESH_SECRET_KEY = "your-refresh-secret-key-here-change-in-production"

SECRET_KEY = os.getenv("SECRET_KEY", _DEFAULT_SECRET_KEY)
REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY", _DEFAULT_REFRESH_SECRET_KEY)

if SECRET_KEY == _DEFAULT_SECRET_KEY or REFRESH_SECRET_KEY == _DEFAULT_REFRESH_SECRET_KEY:
    # Not fatal (that would break existing dev setups), but make the insecure
    # configuration visible instead of silently signing with a known secret.
    logger.warning(
        "🔐 SECRET_KEY and/or REFRESH_SECRET_KEY is not set; falling back to the "
        "built-in placeholder secret. Tokens can be forged - set both environment "
        "variables in production."
    )

ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

# Security scheme for "Authorization: Bearer <token>" (FastAPI's default
# auto_error=True is in effect for this instance).
security = HTTPBearer()

# API Key settings
API_KEY_PREFIX = "vfx_"
# Number of random bytes handed to secrets.token_urlsafe() in generate_api_key;
# the encoded random part of the key is longer than this.
API_KEY_LENGTH = 32
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when ``plain_password`` matches the stored ``hashed_password``.

    The check is delegated to the module-level passlib ``pwd_context``.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
|
|
|
|
|
def get_password_hash(password: str) -> str:
    """Return the hash of ``password`` produced by ``pwd_context`` (bcrypt scheme)."""
    hashed = pwd_context.hash(password)
    return hashed
|
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed (e.g. ``{"sub": "<user id>"}``). The dict is
            copied, so the caller's object is never mutated.
        expires_delta: Token lifetime. ``None`` selects the default of
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. Any other value, including
            a zero or negative delta, is used as given.

    Returns:
        The encoded token, signed with ``SECRET_KEY`` using ``ALGORITHM``. It
        carries ``"type": "access"``, which ``verify_token`` checks, so it is not
        accepted where a refresh token is expected.
    """
    from datetime import timezone

    # `is None` rather than truthiness: timedelta(0) is falsy but is a real value.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = {**data, "exp": expire, "type": "access"}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT refresh token.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is
            never mutated.
        expires_delta: Token lifetime. ``None`` selects the default of
            ``REFRESH_TOKEN_EXPIRE_DAYS`` days. Any other value, including a
            zero or negative delta, is used as given.

    Returns:
        The encoded token, signed with ``REFRESH_SECRET_KEY`` (not the access
        secret) using ``ALGORITHM``, and carrying ``"type": "refresh"``.
    """
    from datetime import timezone

    # `is None` rather than truthiness: timedelta(0) is falsy but is a real value.
    if expires_delta is None:
        expires_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    # Aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = {**data, "exp": expire, "type": "refresh"}
    return jwt.encode(to_encode, REFRESH_SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
def verify_token(token: str, token_type: str = "access") -> Optional[dict]:
    """Verify and decode a JWT token.

    Args:
        token: The encoded JWT.
        token_type: ``"access"`` verifies against ``SECRET_KEY``; any other
            value verifies against ``REFRESH_SECRET_KEY``. The decoded ``type``
            claim must equal this value.

    Returns:
        The decoded claims, or ``None`` if decoding fails with a ``JWTError``
        (bad signature, malformed or expired token, ...) or the ``type`` claim
        does not match ``token_type``.

    Neither the signing secret nor the decoded claims (which can include the
    user's email) are written to the logs.
    """
    secret_key = SECRET_KEY if token_type == "access" else REFRESH_SECRET_KEY
    try:
        payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
    except JWTError as e:
        logger.warning("🔐 JWT decode error: %s", e)
        return None

    actual_type = payload.get("type")
    if actual_type != token_type:
        logger.warning("🔐 Token type mismatch: expected %s, got %s", token_type, actual_type)
        return None

    logger.debug("🔐 %s token decoded successfully (sub=%s)", token_type, payload.get("sub"))
    return payload
|
|
|
|
|
|
def get_current_user_from_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Turn bearer credentials into a minimal identity dict.

    Returns:
        ``{"user_id": int, "email": <"email" claim or None>}`` taken from a valid
        access token.

    Raises:
        HTTPException: 401 "Could not validate credentials" (with a
            ``WWW-Authenticate: Bearer`` header) when the token is invalid, has no
            ``sub`` claim, has a non-integer ``sub``, or any unexpected error
            occurs while processing it.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        raw_token = credentials.credentials
        logger.debug(f"🔐 Verifying token: {raw_token[:20]}...")
        claims = verify_token(raw_token, "access")
        if claims is None:
            logger.warning("🔐 Token verification failed - invalid token")
            raise unauthorized
        logger.debug(f"🔐 Token payload: {claims}")

        subject = claims.get("sub")
        if subject is None:
            logger.warning("🔐 Token verification failed - no sub field")
            raise unauthorized

        # "sub" is carried as a string; the user id is an integer.
        try:
            user_id = int(subject)
        except (ValueError, TypeError):
            logger.warning(f"🔐 Token verification failed - invalid user_id: {subject}")
            raise unauthorized
        logger.debug(f"🔐 Extracted user_id: {user_id}")

        return {"user_id": user_id, "email": claims.get("email")}
    except HTTPException:
        # Already the intended 401 - do not re-wrap it.
        raise
    except Exception as e:
        logger.error(f"🔐 Token verification error: {e}")
        raise unauthorized
|
|
|
|
|
|
|
|
|
|
|
|
def get_current_user(
    token_data: dict = Depends(get_current_user_from_token),
    db: Session = Depends(lambda: None)
):
    """Return the approved ``User`` for the authenticated bearer token.

    The ``db`` dependency is a placeholder that resolves to ``None``. In that
    case a session is obtained from ``database.get_db`` and closed before
    returning. A session passed in explicitly is used as-is and left open.
    """
    from database import get_db

    if db is not None:
        return _get_user_from_db(db, token_data["user_id"])

    # Keep a reference to the generator so it is not finalized while the
    # session is still in use.
    session_source = get_db()
    session = next(session_source)
    try:
        return _get_user_from_db(session, token_data["user_id"])
    finally:
        session.close()
|
|
|
|
|
|
def _get_user_from_db(db: Session, user_id: int):
    """Load the user with ``user_id`` and ensure the account is approved.

    Returns:
        The ``models.user.User`` instance.

    Raises:
        HTTPException: 401 "User not found" when no row matches, 403
            "User account not approved" when ``is_approved`` is falsy.
    """
    from models.user import User

    logger.debug(f"🔐 Looking up user_id: {user_id}")
    record = db.query(User).filter(User.id == user_id).first()

    if record is None:
        logger.warning(f"🔐 User not found: {user_id}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
        )

    logger.debug(f"🔐 Found user: {record.email} (approved: {record.is_approved})")
    if record.is_approved:
        return record

    logger.warning(f"🔐 User not approved: {record.email}")
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="User account not approved",
    )
|
|
|
|
|
|
def get_current_user_with_db(
    token_data: dict = Depends(get_current_user_from_token),
    db: Session = Depends(lambda: None)
):
    """Get the approved current user, using ``db`` when a session is supplied.

    When ``db`` is ``None`` a short-lived session is obtained from
    ``database.get_db`` and closed before returning. A session passed in
    explicitly is used as-is and left open.

    Raises:
        HTTPException: 401 / 403 from the user lookup.
    """
    from database import get_db

    if db is not None:
        return _get_user_from_db(db, token_data["user_id"])

    # `Depends(lambda: None)` above always resolves to None under FastAPI, so
    # rejecting a missing session (HTTP 500) made this dependency fail on every
    # request. Open one instead, as the other dependencies in this module do.
    db_gen = get_db()
    session = next(db_gen)
    try:
        return _get_user_from_db(session, token_data["user_id"])
    finally:
        session.close()
|
|
|
|
|
|
def require_role(required_roles: list):
    """Build a dependency that only admits users whose ``role`` is in ``required_roles``.

    The returned ``role_checker`` is meant for ``Depends(...)``. It resolves the
    approved current user and returns it, or raises HTTP 403
    "Insufficient permissions".
    """
    def role_checker(
        token_data: dict = Depends(get_current_user_from_token),
        db: Session = Depends(lambda: None)
    ):
        from database import get_db

        if db is not None:
            current_user = _get_user_from_db(db, token_data["user_id"])
        else:
            # Placeholder dependency gave no session: open and close our own.
            session_source = get_db()
            session = next(session_source)
            try:
                current_user = _get_user_from_db(session, token_data["user_id"])
            finally:
                session.close()

        if current_user.role in required_roles:
            return current_user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Insufficient permissions"
        )

    return role_checker
|
|
|
|
|
|
def require_admin_permission():
    """Build a dependency that only admits users whose ``is_admin`` flag is truthy.

    ``role`` is not consulted. The returned ``admin_checker`` resolves the
    approved current user and returns it, or raises HTTP 403
    "Admin permission required".
    """
    def admin_checker(
        token_data: dict = Depends(get_current_user_from_token),
        db: Session = Depends(lambda: None)
    ):
        from database import get_db

        if db is not None:
            current_user = _get_user_from_db(db, token_data["user_id"])
        else:
            # Placeholder dependency gave no session: open and close our own.
            session_source = get_db()
            session = next(session_source)
            try:
                current_user = _get_user_from_db(session, token_data["user_id"])
            finally:
                session.close()

        if current_user.is_admin:
            return current_user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin permission required"
        )

    return admin_checker
|
|
|
|
|
|
def create_role_dependency(required_roles: list):
    """Create a dependency that requires specific user roles with proper DB injection.

    Returns a ``role_checker`` callable for ``Depends(...)``. It resolves the
    approved current user and returns it, or raises HTTP 403
    "Insufficient permissions" when ``user.role`` is not in ``required_roles``.
    """
    from fastapi import params

    def role_checker(
        token_data: dict = Depends(get_current_user_from_token),
        # A bare `db: Session = None` makes FastAPI treat `db` as a request
        # parameter and fail to build a field for `Session` when the route is
        # registered. Declare it as a dependency, like the sibling checkers.
        db: Session = Depends(lambda: None)
    ):
        from database import get_db

        # A direct (non-FastAPI) call that omits `db` receives the unresolved
        # Depends marker as its default. Treat that exactly like None.
        if db is None or isinstance(db, params.Depends):
            db_gen = get_db()
            session = next(db_gen)
            try:
                current_user = _get_user_from_db(session, token_data["user_id"])
            finally:
                session.close()
        else:
            current_user = _get_user_from_db(db, token_data["user_id"])

        if current_user.role not in required_roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Insufficient permissions"
            )
        return current_user

    return role_checker
|
|
|
|
|
|
# API Key utilities
|
|
def generate_api_key() -> str:
    """Return a new random API key: ``API_KEY_PREFIX`` plus a URL-safe token.

    The random part comes from ``secrets.token_urlsafe(API_KEY_LENGTH)``. That
    call draws ``API_KEY_LENGTH`` random bytes and base64url-encodes them, so
    the random part is longer than ``API_KEY_LENGTH`` characters.
    """
    return API_KEY_PREFIX + secrets.token_urlsafe(API_KEY_LENGTH)
|
|
|
|
|
|
def hash_api_key(api_key: str) -> str:
    """Return the hex SHA-256 digest of ``api_key`` (UTF-8 encoded).

    Only this digest is stored. Incoming keys are hashed the same way and
    looked up by digest. An unsalted fast hash is used because the keys are
    random tokens from ``secrets``, not user-chosen passwords.
    """
    digest = hashlib.sha256()
    digest.update(api_key.encode("utf-8"))
    return digest.hexdigest()
|
|
|
|
|
|
def verify_api_key_format(api_key: str) -> bool:
    """Return True when ``api_key`` starts with ``API_KEY_PREFIX`` and has at least
    one character after it. Only the shape is checked, not whether the key exists.
    """
    prefix = API_KEY_PREFIX
    return api_key.startswith(prefix) and bool(api_key[len(prefix):])
|
|
|
|
|
|
def get_current_user_from_api_key(
    request: Request,
    db: Session = Depends(lambda: None)
) -> Optional[dict]:
    """Authenticate the request via its ``X-API-Key`` header.

    Returns ``None`` when the header is absent or empty, or does not have the
    ``API_KEY_PREFIX`` shape. Otherwise returns the result of
    ``_verify_api_key_and_get_user``. When ``db`` is ``None`` a session is
    obtained from ``database.get_db`` and closed afterwards.
    """
    from database import get_db
    # The two model imports below are not used in this function itself.
    from models.api_key import APIKey  # noqa: F401
    from models.api_key_usage import APIKeyUsage  # noqa: F401

    api_key = request.headers.get("X-API-Key")
    if not api_key or not verify_api_key_format(api_key):
        return None

    if db is not None:
        return _verify_api_key_and_get_user(db, api_key, request)

    session_source = get_db()
    session = next(session_source)
    try:
        return _verify_api_key_and_get_user(session, api_key, request)
    finally:
        session.close()
|
|
|
|
|
|
def _verify_api_key_and_get_user(db: Session, api_key: str, request: Request) -> Optional[dict]:
    """Resolve a raw API key to its owner and record the usage.

    Args:
        db: Open SQLAlchemy session. This function commits on it.
        api_key: The raw key taken from the ``X-API-Key`` header.
        request: Current request, used for the usage log (path, method,
            client host, User-Agent).

    Returns:
        ``None`` when the key is unknown, inactive or expired, or its owner is
        missing or not approved. Otherwise a dict with ``user_id``, ``email``,
        ``role``, ``api_key_id``, ``scopes`` (always a list) and
        ``auth_type="api_key"``.

    Side effects:
        Adds an ``APIKeyUsage`` row, sets ``last_used_at`` and commits.
    """
    from datetime import timezone

    from models.api_key import APIKey
    from models.api_key_usage import APIKeyUsage
    from models.user import User

    # Keys are stored only as SHA-256 digests, so look the key up by digest.
    key_hash = hash_api_key(api_key)
    api_key_record = db.query(APIKey).filter(
        APIKey.key_hash == key_hash,
        APIKey.is_active == True,  # noqa: E712 - SQLAlchemy needs `==`, not `is`
    ).first()
    if not api_key_record:
        return None

    now_utc = datetime.now(timezone.utc)
    now_naive_utc = now_utc.replace(tzinfo=None)  # same value datetime.utcnow() gives

    # Compare naive timestamps with naive UTC and aware ones with aware UTC:
    # mixing the two raises TypeError instead of rejecting the expired key.
    expires_at = api_key_record.expires_at
    if expires_at:
        now = now_utc if expires_at.tzinfo is not None else now_naive_utc
        if expires_at < now:
            return None

    user = db.query(User).filter(User.id == api_key_record.user_id).first()
    if not user or not user.is_approved:
        return None

    # Record the usage and touch last_used_at in the same commit.
    db.add(APIKeyUsage(
        api_key_id=api_key_record.id,
        endpoint=str(request.url.path),
        method=request.method,
        ip_address=request.client.host if request.client else None,
        user_agent=request.headers.get("User-Agent"),
    ))
    api_key_record.last_used_at = now_naive_utc
    db.commit()

    return {
        "user_id": user.id,
        "email": user.email,
        "role": user.role,
        "api_key_id": api_key_record.id,
        "scopes": _parse_api_key_scopes(api_key_record.scopes),
        "auth_type": "api_key",
    }


def _parse_api_key_scopes(raw_scopes) -> list:
    """Normalise the stored ``scopes`` value into a list of scope names.

    NULL, empty or malformed JSON yields ``[]``. An already-decoded value (for
    example from a JSON-typed column) is accepted as-is. A bare JSON string
    becomes a one-element list, so later ``in`` checks match whole scopes
    rather than substrings. A JSON object contributes its keys.
    """
    if not raw_scopes:
        return []
    if isinstance(raw_scopes, (str, bytes, bytearray)):
        try:
            parsed = json.loads(raw_scopes)
        except ValueError:  # json.JSONDecodeError and UnicodeDecodeError are ValueErrors
            return []
    else:
        parsed = raw_scopes
    if isinstance(parsed, str):
        return [parsed]
    if isinstance(parsed, (list, tuple, set, dict)):
        return list(parsed)
    return []
|
|
|
|
|
|
def get_current_user_flexible(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)),
    db: Session = Depends(lambda: None)
):
    """Get the current user from either an ``X-API-Key`` header or a JWT bearer token.

    The bearer scheme is declared with ``auto_error=False``. The module-level
    ``security`` scheme (``HTTPBearer()``) rejects a request that has no
    ``Authorization: Bearer`` header before this function runs, which made
    API-key-only requests impossible. With ``auto_error=False``, ``credentials``
    is simply ``None`` in that case and the API-key path is tried.

    When ``db`` is ``None`` a session is obtained from ``database.get_db`` and
    closed before returning.

    Raises:
        HTTPException: from ``_get_current_user_flexible_with_db`` (401 when
            neither method authenticates).
    """
    from database import get_db

    if db is not None:
        return _get_current_user_flexible_with_db(request, credentials, db)

    db_gen = get_db()
    session = next(db_gen)
    try:
        return _get_current_user_flexible_with_db(request, credentials, session)
    finally:
        session.close()
|
|
|
|
|
|
def _get_current_user_flexible_with_db(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials],
    db: Session
):
    """Authenticate with an API key first, then fall back to a JWT bearer token.

    Args:
        request: Incoming request. Its ``X-API-Key`` header is checked first.
        credentials: Parsed bearer credentials, or ``None``.
        db: Open database session.

    Returns:
        The ``User`` instance with ``auth_type`` set to ``"api_key"`` or
        ``"jwt"``. API-key users also get ``api_key_id`` and ``scopes``.

    Raises:
        HTTPException: errors from the user lookup propagate (e.g. 403 "User
            account not approved" for a valid JWT of an unapproved account).
            401 "Could not validate credentials" is raised when neither method
            authenticates.
    """
    # Try API key first
    api_key_user = get_current_user_from_api_key(request, db)
    if api_key_user:
        user = _get_user_from_db(db, api_key_user["user_id"])
        # Attach API-key specific data for check_api_key_scope().
        user.api_key_id = api_key_user["api_key_id"]
        user.scopes = api_key_user["scopes"]
        user.auth_type = "api_key"
        return user

    # Fall back to a JWT bearer token.
    if credentials:
        try:
            payload = verify_token(credentials.credentials, "access")
            if payload:
                # "sub" is a string claim; convert it as get_current_user_from_token
                # does instead of comparing the integer id column against a str.
                try:
                    user_id = int(payload.get("sub"))
                except (TypeError, ValueError):
                    logger.warning(f"🔐 Token verification failed - invalid user_id: {payload.get('sub')}")
                else:
                    user = _get_user_from_db(db, user_id)
                    user.auth_type = "jwt"
                    return user
        except HTTPException:
            # Surface the real reason (e.g. 403 not approved) instead of a generic 401.
            raise
        except Exception:
            # Best effort: any other failure means "not authenticated via JWT",
            # but record it rather than silently discarding it.
            logger.exception("🔐 Unexpected error during JWT authentication")

    # No valid authentication found
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
|
|
|
|
def check_api_key_scope(user, required_scope: str) -> bool:
    """Return whether ``user`` may access an endpoint that needs ``required_scope``.

    Only API-key users (``auth_type == 'api_key'``) are scope-restricted.
    Everyone else, including JWT users and objects without ``auth_type``, passes.
    An API-key user passes when its ``scopes`` contain ``required_scope`` or the
    catch-all ``"full:access"``. Such a user without a ``scopes`` attribute is
    denied.
    """
    if getattr(user, 'auth_type', None) != 'api_key':
        # JWT tokens have full access based on user role
        return True

    try:
        granted = user.scopes
    except AttributeError:
        return False

    return any(scope in granted for scope in (required_scope, "full:access"))
|
|
|
|
|
|
def require_api_key_scope(required_scope: str):
    """Build a dependency that enforces an API-key scope.

    The returned ``scope_checker`` authenticates the request with an
    ``X-API-Key`` header or a JWT bearer token. It raises HTTP 403 when an
    API-key caller has neither ``required_scope`` nor ``"full:access"``. JWT
    callers are not scope-restricted. The authenticated user is returned.

    The bearer scheme is declared with ``auto_error=False``. The module-level
    ``security`` scheme would reject a request carrying only ``X-API-Key``
    before the API-key check could run.
    """
    def scope_checker(
        request: Request,
        credentials: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)),
        db: Session = Depends(lambda: None)
    ):
        from database import get_db

        if db is not None:
            user = _get_current_user_flexible_with_db(request, credentials, db)
        else:
            db_gen = get_db()
            session = next(db_gen)
            try:
                user = _get_current_user_flexible_with_db(request, credentials, session)
            finally:
                session.close()

        if not check_api_key_scope(user, required_scope):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Insufficient permissions. Required scope: {required_scope}"
            )
        return user

    return scope_checker