Files
ppr-ng/backend/app/main.py
2025-12-12 11:18:28 -05:00

206 lines
7.4 KiB
Python

from fastapi import FastAPI, Depends, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from typing import List
import json
import logging
import asyncio
import redis.asyncio as redis
from app.core.config import settings
from app.api.api import api_router
# Import models to ensure they're registered with SQLAlchemy
from app.models.ppr import PPRRecord, User, Journal, Airport, Aircraft
from app.models.local_flight import LocalFlight
from app.models.departure import Departure
from app.models.arrival import Arrival
# Set up logging: root handler at INFO so this module's logger.info(...) calls are emitted
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Redis client for pub/sub (cross-worker communication).
# Both stay None until ConnectionManager.start_redis_listener() assigns them;
# while redis_client is None, ConnectionManager.broadcast() falls back to
# delivering only to this worker's sockets.
redis_client = None
pubsub = None
# FastAPI application; the OpenAPI schema is served under settings.api_v1_str.
app = FastAPI(
title=settings.project_name,
openapi_url=f"{settings.api_v1_str}/openapi.json",
description="Prior Permission Required (PPR) system API for aircraft operations management",
version="2.0.0"
)
# Set up CORS middleware.
# NOTE(review): any origin is allowed AND credentials are allowed; this is very
# permissive. Restrict allow_origins to the real front-end origin(s) before
# production use.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Configure this properly for production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# WebSocket connection manager for real-time updates
class ConnectionManager:
    """Track this worker's WebSocket clients and fan out updates to them.

    ``broadcast()`` publishes to the Redis ``ppr_updates`` channel so every
    worker process receives the message. Each worker's
    ``start_redis_listener()`` task then forwards received messages to that
    worker's own sockets via ``broadcast_local()``. When Redis is unavailable,
    ``broadcast()`` falls back to this worker's sockets only.
    """

    # Seconds to wait before reconnecting after the Redis listener fails.
    redis_retry_delay = 5

    def __init__(self):
        # WebSockets accepted by this worker process.
        self.active_connections: List[WebSocket] = []
        # asyncio.Task running start_redis_listener(); assigned by the
        # application's startup hook and cancelled by its shutdown hook.
        self.redis_listener_task = None

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and start tracking *websocket*."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; a no-op if it is not tracked."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}")

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send *message* to a single socket."""
        await websocket.send_text(message)

    async def broadcast_local(self, message_str: str):
        """Broadcast to connections on this worker only.

        Sockets whose send raises are removed from ``active_connections``.
        """
        dead_connections = []
        # Iterate over a snapshot: every send awaits, and connect()/disconnect()
        # can mutate the live list from other coroutines in the meantime.
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message_str)
            except Exception as e:
                logger.warning(f"Failed to send to connection: {e}")
                dead_connections.append(connection)
        # Remove dead connections (they may already be gone via disconnect()).
        for connection in dead_connections:
            if connection in self.active_connections:
                self.active_connections.remove(connection)
        if dead_connections:
            logger.info(f"Removed {len(dead_connections)} dead connections")

    async def broadcast(self, message: dict):
        """Broadcast *message* (JSON-serialised) via Redis pub/sub to all workers.

        Falls back to ``broadcast_local()`` when no Redis client exists or when
        publishing fails.
        """
        message_str = json.dumps(message)
        logger.info(f"Publishing message to Redis channel: {message.get('type', 'unknown')}")
        if redis_client is None:
            logger.warning("Redis not available, falling back to local broadcast")
            await self.broadcast_local(message_str)
            return
        try:
            await redis_client.publish('ppr_updates', message_str)
        except Exception as e:
            logger.error(f"Failed to publish to Redis: {e}")
            # Fallback to local broadcast
            await self.broadcast_local(message_str)
        else:
            logger.debug("Message published to Redis")

    async def start_redis_listener(self):
        """Listen for Redis pub/sub messages and broadcast them to local connections.

        Runs until cancelled. Each attempt connects to Redis and stores the
        client and pub/sub in the module-level ``redis_client`` / ``pubsub``
        globals, which ``broadcast()`` and the shutdown hook use. It then
        forwards every published message to ``broadcast_local()``.

        If an error occurs, or the subscription stream ends, the method waits
        ``redis_retry_delay`` seconds and reconnects *within the same task*.
        Cancelling ``redis_listener_task`` therefore stops it for good.
        """
        global redis_client, pubsub
        while True:
            try:
                # Release the previous attempt's connections before replacing them.
                await self._close_redis()
                redis_url = settings.redis_url or "redis://redis:6379"
                logger.info(f"Connecting to Redis at: {redis_url}")
                redis_client = await redis.from_url(redis_url, encoding="utf-8", decode_responses=True)
                pubsub = redis_client.pubsub()
                await pubsub.subscribe('ppr_updates')
                logger.info("Redis listener started for PPR updates")
                async for message in pubsub.listen():
                    # Skip subscribe/unsubscribe confirmations; relay payloads only.
                    if message['type'] == 'message':
                        logger.info(f"Received Redis message, broadcasting to {len(self.active_connections)} local connections")
                        await self.broadcast_local(message['data'])
                logger.warning("Redis subscription ended, reconnecting")
            except Exception as e:
                # asyncio.CancelledError is not an Exception subclass (Python 3.8+),
                # so task cancellation still propagates and ends this loop.
                logger.error(f"Redis listener error: {e}")
            await asyncio.sleep(self.redis_retry_delay)

    async def _close_redis(self):
        """Best-effort close of the module-level ``pubsub`` and ``redis_client``.

        Both globals are reset to ``None`` first, so ``broadcast()`` uses local
        delivery until a new client is connected.
        """
        global redis_client, pubsub
        old_pubsub, old_client = pubsub, redis_client
        pubsub = None
        redis_client = None
        for resource in (old_pubsub, old_client):
            if resource is None:
                continue
            try:
                await resource.close()
            except Exception as e:
                logger.warning(f"Error closing Redis connection: {e}")
# Single per-process manager shared by the WebSocket endpoint and the
# startup/shutdown hooks; also exposed as app.state.connection_manager.
manager = ConnectionManager()
@app.on_event("startup")
async def startup_event():
    """Launch the Redis pub/sub listener as a background task at startup.

    The task handle is kept on ``manager.redis_listener_task`` so the
    shutdown hook can cancel it.
    """
    banner = "=" * 50
    for line in (banner, "STARTUP: Starting application and Redis listener...", banner):
        print(line)
    logger.info("Starting application and Redis listener...")
    listener = manager.start_redis_listener()
    manager.redis_listener_task = asyncio.create_task(listener)
@app.on_event("shutdown")
async def shutdown_event():
    """Clean up Redis connections on shutdown.

    The listener task is cancelled and then awaited (bounded wait), so it is
    no longer reading from the pub/sub when that pub/sub is closed. Every
    close step is best-effort: a failure is logged and the remaining cleanup
    still runs.
    """
    logger.info("Shutting down application...")
    task = manager.redis_listener_task
    if task is not None and not task.done():
        task.cancel()
        # asyncio.wait() does not re-raise the task's CancelledError, yet
        # cancellation of this coroutine itself still propagates normally.
        await asyncio.wait([task], timeout=5)
    if pubsub:
        try:
            await pubsub.unsubscribe('ppr_updates')
        except Exception as e:
            logger.warning(f"Failed to unsubscribe from Redis: {e}")
        try:
            await pubsub.close()
        except Exception as e:
            logger.warning(f"Failed to close Redis pub/sub: {e}")
    if redis_client:
        try:
            await redis_client.close()
        except Exception as e:
            logger.warning(f"Failed to close Redis client: {e}")
@app.websocket("/ws/tower-updates")
async def websocket_endpoint(websocket: WebSocket):
    """Register a client for tower updates and answer its heartbeats.

    Every text frame received is echoed back prefixed with ``"Heartbeat: "``.
    The socket is unregistered from ``manager`` however the loop ends, so a
    failed receive or send cannot leave a dead socket in the broadcast list.
    """
    await manager.connect(websocket)
    try:
        while True:
            # Keep connection alive
            data = await websocket.receive_text()
            # Echo back for heartbeat
            await websocket.send_text(f"Heartbeat: {data}")
    except WebSocketDisconnect:
        # Normal client disconnect; nothing else to do.
        pass
    finally:
        manager.disconnect(websocket)
@app.get("/")
async def root():
    """Identify the API and point clients at the interactive docs."""
    info = dict(message="Airfield PPR API", version="2.0.0", docs="/docs")
    return info
@app.get("/health")
async def health_check():
    """Health check endpoint with database connectivity verification.

    Returns a dict containing:

    - ``status``: "healthy" or "unhealthy".
    - ``timestamp``: the current UTC time in ISO format, ending in ``Z``.
    - ``version``.
    - ``database``: "connected" or "disconnected".
    - ``error``: the exception text, present only on failure.

    NOTE(review): the endpoint returns HTTP 200 even when unhealthy; confirm
    that monitoring inspects the body rather than the status code.
    """
    from datetime import datetime, timezone
    from sqlalchemy import text
    from app.db.session import SessionLocal

    def _ping_database():
        # Synchronous session; always closed, even when the query fails.
        db = SessionLocal()
        try:
            db.execute(text("SELECT 1"))
        finally:
            db.close()

    health_status = {
        "status": "healthy",
        # Same "...Z" format as naive utcnow().isoformat() + "Z", without the
        # deprecated datetime.utcnow().
        "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "version": "2.0.0"
    }
    # Check database connectivity in a worker thread, so a slow or unreachable
    # database does not block the event loop (and every WebSocket with it).
    try:
        await asyncio.get_running_loop().run_in_executor(None, _ping_database)
        health_status["database"] = "connected"
    except Exception as e:
        health_status["status"] = "unhealthy"
        health_status["database"] = "disconnected"
        health_status["error"] = str(e)
    return health_status
# Include API router: mount the versioned API routes under settings.api_v1_str
app.include_router(api_router, prefix=settings.api_v1_str)
# Make the shared connection manager reachable through app.state, presumably
# so route handlers can broadcast updates (verify against callers in app.api)
app.state.connection_manager = manager