"""FastAPI application entry point for the Airfield PPR API.

Wires up CORS, the versioned API router, a health check, and a WebSocket
endpoint whose updates are fanned out to every worker via Redis pub/sub.
"""
import asyncio
import contextlib
import json
import logging
from typing import List

import redis.asyncio as redis
from fastapi import FastAPI, Depends, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware

from app.core.config import settings
from app.api.api import api_router
# Import models to ensure they're registered with SQLAlchemy
from app.models.ppr import PPRRecord, User, Airport, Aircraft
from app.models.journal import JournalEntry
from app.models.local_flight import LocalFlight
from app.models.departure import Departure
from app.models.arrival import Arrival

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Redis pub/sub channel shared by every worker process.
REDIS_CHANNEL = "ppr_updates"
# Seconds to wait before reconnecting after the Redis listener fails.
REDIS_RETRY_DELAY_SECONDS = 5

# Redis client for pub/sub (cross-worker communication). Set by the listener
# while connected; None when Redis is unavailable.
redis_client = None
pubsub = None

app = FastAPI(
    title=settings.project_name,
    openapi_url=f"{settings.api_v1_str}/openapi.json",
    description="Prior Permission Required (PPR) system API for aircraft operations management",
    version="2.0.0",
)

# Set up CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; restrict allow_origins to known front-end origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure this properly for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


async def _close_redis(*, unsubscribe: bool = False) -> None:
    """Best-effort close of the shared Redis pub/sub and client.

    The module globals are cleared first so that ``broadcast()`` calls made
    meanwhile fall back to local delivery instead of using a closing client.
    Errors are logged, not raised, because this runs on failure and shutdown
    paths.

    Args:
        unsubscribe: Unsubscribe from ``REDIS_CHANNEL`` before closing the
            pub/sub (used on a clean shutdown).
    """
    global redis_client, pubsub
    client, ps = redis_client, pubsub
    redis_client = None
    pubsub = None
    if ps is not None:
        try:
            if unsubscribe:
                await ps.unsubscribe(REDIS_CHANNEL)
            await ps.close()
        except Exception as e:
            logger.warning("Error closing Redis pub/sub: %s", e)
    if client is not None:
        try:
            await client.close()
        except Exception as e:
            logger.warning("Error closing Redis client: %s", e)


# WebSocket connection manager for real-time updates
class ConnectionManager:
    """Track this worker's WebSocket clients and relay updates to them.

    Outgoing updates are published to Redis; every worker's listener receives
    them and re-broadcasts to its own local connections.
    """

    def __init__(self):
        # WebSockets accepted by this worker process.
        self.active_connections: List[WebSocket] = []
        # Background task running start_redis_listener(); set at startup.
        self.redis_listener_task = None

    async def connect(self, websocket: WebSocket):
        """Accept ``websocket`` and register it for broadcasts."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info("WebSocket connected. Total connections: %d", len(self.active_connections))

    def disconnect(self, websocket: WebSocket):
        """Unregister ``websocket``; a no-op if it is not registered."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            logger.info("WebSocket disconnected. Total connections: %d", len(self.active_connections))

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send ``message`` to a single ``websocket``."""
        await websocket.send_text(message)

    async def broadcast_local(self, message_str: str):
        """Send ``message_str`` to every connection on this worker only.

        Connections whose send raises are dropped from ``active_connections``.
        """
        dead_connections = []
        # Iterate over a snapshot: each send awaits, and connect()/disconnect()
        # may mutate the list meanwhile, which would make the loop skip clients.
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message_str)
            except Exception as e:
                logger.warning("Failed to send to connection: %s", e)
                dead_connections.append(connection)

        # Remove dead connections
        for connection in dead_connections:
            if connection in self.active_connections:
                self.active_connections.remove(connection)
        if dead_connections:
            logger.info("Removed %d dead connections", len(dead_connections))

    async def broadcast(self, message: dict):
        """Publish ``message`` as JSON to Redis so that all workers relay it.

        Falls back to delivering only to this worker's connections when no
        Redis client is available or the publish fails.

        Raises:
            TypeError: if ``message`` is not JSON-serializable.
        """
        message_str = json.dumps(message)
        logger.info("Publishing message to Redis channel: %s", message.get('type', 'unknown'))

        # Snapshot the global: the listener may replace or clear it on reconnect.
        client = redis_client
        if client is None:
            logger.warning("Redis not available, falling back to local broadcast")
            await self.broadcast_local(message_str)
            return

        try:
            await client.publish(REDIS_CHANNEL, message_str)
        except Exception as e:
            logger.error("Failed to publish to Redis: %s", e)
            await self.broadcast_local(message_str)
        else:
            logger.info("Message published to Redis")

    async def start_redis_listener(self):
        """Relay Redis pub/sub messages to local connections, reconnecting on failure.

        Runs until its task is cancelled (see ``shutdown_event``). After an
        error, or if the subscription stream ends, the old client is closed,
        the loop waits ``REDIS_RETRY_DELAY_SECONDS`` and then reconnects.
        """
        global redis_client, pubsub
        redis_url = settings.redis_url or "redis://redis:6379"
        while True:
            try:
                logger.info("Connecting to Redis at: %s", redis_url)
                redis_client = await redis.from_url(redis_url, encoding="utf-8", decode_responses=True)
                pubsub = redis_client.pubsub()
                await pubsub.subscribe(REDIS_CHANNEL)
                logger.info("Redis listener started for PPR updates")

                async for message in pubsub.listen():
                    # Skip subscribe/unsubscribe confirmations; relay payloads only.
                    if message['type'] == 'message':
                        logger.info(
                            "Received Redis message, broadcasting to %d local connections",
                            len(self.active_connections),
                        )
                        await self.broadcast_local(message['data'])
                logger.warning("Redis pub/sub stream ended; reconnecting")
            except Exception as e:
                # asyncio.CancelledError derives from BaseException (Python 3.8+),
                # so task cancellation at shutdown is not swallowed here.
                logger.error("Redis listener error: %s", e)

            # Discard the broken connection; broadcast() falls back to local
            # delivery until the next successful reconnect.
            await _close_redis()
            await asyncio.sleep(REDIS_RETRY_DELAY_SECONDS)


manager = ConnectionManager()


@app.on_event("startup")
async def startup_event():
    """Start the background Redis listener when the application starts."""
    logger.info("Starting application and Redis listener...")
    # Keep a reference so the task is not garbage-collected and can be cancelled.
    manager.redis_listener_task = asyncio.create_task(manager.start_redis_listener())


@app.on_event("shutdown")
async def shutdown_event():
    """Stop the Redis listener and close Redis connections on shutdown."""
    logger.info("Shutting down application...")
    task = manager.redis_listener_task
    if task is not None:
        task.cancel()
        # Wait for the cancellation to land so the listener is not still using
        # the client while it is being closed below.
        with contextlib.suppress(asyncio.CancelledError):
            await task
        manager.redis_listener_task = None
    await _close_redis(unsubscribe=True)


@app.websocket("/ws/tower-updates")
async def websocket_endpoint(websocket: WebSocket):
    """Register a tower client for updates and echo its messages as heartbeats."""
    await manager.connect(websocket)
    try:
        while True:
            # Keep connection alive
            data = await websocket.receive_text()
            # Echo back for heartbeat
            await websocket.send_text(f"Heartbeat: {data}")
    except WebSocketDisconnect:
        pass
    finally:
        # Always unregister, including on unexpected errors, so broadcasts
        # never target a dead socket.
        manager.disconnect(websocket)


@app.get("/")
async def root():
    """Return basic API identification and the docs location."""
    return {
        "message": "Airfield PPR API",
        "version": "2.0.0",
        "docs": "/docs",
    }


@app.get("/health")
def health_check():
    """Health check endpoint with database connectivity verification.

    Always responds with HTTP 200; ``status`` is ``"unhealthy"`` and
    ``error`` is set when the database probe fails. Declared as a plain
    ``def`` so FastAPI runs the blocking database call in its threadpool
    instead of on the event loop.
    """
    from datetime import datetime, timezone

    from sqlalchemy import text

    from app.db.session import SessionLocal

    health_status = {
        "status": "healthy",
        # Same "...Z" ISO-8601 format as before, without deprecated utcnow().
        "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "version": "2.0.0",
    }

    # Check database connectivity
    try:
        db = SessionLocal()
        try:
            db.execute(text("SELECT 1"))
        finally:
            # Release the session even when the probe query fails.
            db.close()
        health_status["database"] = "connected"
    except Exception as e:
        logger.error("Health check database probe failed: %s", e)
        health_status["status"] = "unhealthy"
        health_status["database"] = "disconnected"
        health_status["error"] = str(e)

    return health_status


# Include API router
app.include_router(api_router, prefix=settings.api_v1_str)

# Make connection manager available to the app
app.state.connection_manager = manager