"""FastAPI dependency injection factories. Provides DB session, current user extraction from JWT, and role/access guards. """ import logging from collections.abc import AsyncGenerator from fastapi import Depends from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from jose import JWTError, jwt from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.config import settings from app.database import get_session from app.exceptions import AuthenticationError, AuthorizationError from app.models.user import User logger = logging.getLogger(__name__) _bearer_scheme = HTTPBearer(auto_error=False) JWT_ALGORITHM = "HS256" async def get_db() -> AsyncGenerator[AsyncSession, None]: """Yield an async DB session.""" async for session in get_session(): yield session async def get_current_user( credentials: HTTPAuthorizationCredentials | None = Depends(_bearer_scheme), db: AsyncSession = Depends(get_db), ) -> User: """Extract and validate JWT from Authorization header, return the User.""" if credentials is None: raise AuthenticationError("Authentication required") token = credentials.credentials try: payload = jwt.decode( token, settings.jwt_secret, algorithms=[JWT_ALGORITHM], ) user_id_str: str | None = payload.get("sub") if user_id_str is None: raise AuthenticationError("Invalid token: missing subject") user_id = int(user_id_str) except JWTError as exc: if "expired" in str(exc).lower(): raise AuthenticationError("Token expired") from exc raise AuthenticationError("Invalid token") from exc except (ValueError, TypeError) as exc: raise AuthenticationError("Invalid token: bad subject") from exc result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if user is None: raise AuthenticationError("User not found") return user async def require_admin( user: User = Depends(get_current_user), ) -> User: """Guard that ensures the current user has the admin role.""" if user.role != "admin": raise AuthorizationError("Insufficient permissions") return user async def require_access( user: User = Depends(get_current_user), ) -> User: """Guard that ensures the current user has API access granted.""" if not user.has_access: raise AuthorizationError("Insufficient permissions") return user