Gone live

This commit is contained in:
James Pattinson
2025-12-07 15:02:51 +00:00
parent 3780b3cf2f
commit 4d71d59d90
10 changed files with 542 additions and 401 deletions

View File

@@ -2,9 +2,20 @@ from fastapi import FastAPI, Depends, HTTPException, WebSocket, WebSocketDisconn
from fastapi.middleware.cors import CORSMiddleware
from typing import List
import json
import logging
import asyncio
import redis.asyncio as redis
from app.core.config import settings
from app.api.api import api_router
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Redis client for pub/sub (cross-worker communication)
redis_client = None
pubsub = None
app = FastAPI(
title=settings.project_name,
openapi_url=f"{settings.api_v1_str}/openapi.json",
@@ -25,28 +36,117 @@ app.add_middleware(
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
self.redis_listener_task = None
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}")
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
if websocket in self.active_connections:
self.active_connections.remove(websocket)
logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}")
async def send_personal_message(self, message: str, websocket: WebSocket):
await websocket.send_text(message)
async def broadcast(self, message: dict):
message_str = json.dumps(message)
async def broadcast_local(self, message_str: str):
"""Broadcast to connections on this worker only"""
dead_connections = []
for connection in self.active_connections:
try:
await connection.send_text(message_str)
except:
# Connection is dead, remove it
except Exception as e:
logger.warning(f"Failed to send to connection: {e}")
dead_connections.append(connection)
# Remove dead connections
for connection in dead_connections:
if connection in self.active_connections:
self.active_connections.remove(connection)
if dead_connections:
logger.info(f"Removed {len(dead_connections)} dead connections")
async def broadcast(self, message: dict):
"""Broadcast via Redis pub/sub to all workers"""
message_str = json.dumps(message)
print(f"Publishing message to Redis channel: {message.get('type', 'unknown')}")
logger.info(f"Publishing message to Redis channel: {message.get('type', 'unknown')}")
try:
if redis_client:
await redis_client.publish('ppr_updates', message_str)
print(f"✓ Message published to Redis")
else:
# Fallback to local broadcast if Redis not available
print("⚠ Redis not available, falling back to local broadcast")
logger.warning("Redis not available, falling back to local broadcast")
await self.broadcast_local(message_str)
except Exception as e:
print(f"✗ Failed to publish to Redis: {e}")
logger.error(f"Failed to publish to Redis: {e}")
# Fallback to local broadcast
await self.broadcast_local(message_str)
async def start_redis_listener(self):
"""Listen for Redis pub/sub messages and broadcast to local connections"""
global redis_client, pubsub
try:
# Connect to Redis
redis_url = settings.redis_url or "redis://redis:6379"
print(f"Connecting to Redis at: {redis_url}")
redis_client = await redis.from_url(redis_url, encoding="utf-8", decode_responses=True)
pubsub = redis_client.pubsub()
await pubsub.subscribe('ppr_updates')
print("✓ Redis listener started for PPR updates")
logger.info("Redis listener started for PPR updates")
async for message in pubsub.listen():
if message['type'] == 'message':
message_data = message['data']
print(f"Received Redis message, broadcasting to {len(self.active_connections)} local connections")
logger.info(f"Received Redis message, broadcasting to {len(self.active_connections)} local connections")
await self.broadcast_local(message_data)
except Exception as e:
print(f"Redis listener error: {e}")
logger.error(f"Redis listener error: {e}")
await asyncio.sleep(5) # Wait before retry
# Retry connection
if self.redis_listener_task and not self.redis_listener_task.done():
asyncio.create_task(self.start_redis_listener())
manager = ConnectionManager()
@app.on_event("startup")
async def startup_event():
"""Start Redis listener when application starts"""
print("=" * 50)
print("STARTUP: Starting application and Redis listener...")
print("=" * 50)
logger.info("Starting application and Redis listener...")
manager.redis_listener_task = asyncio.create_task(manager.start_redis_listener())
@app.on_event("shutdown")
async def shutdown_event():
"""Clean up Redis connections on shutdown"""
logger.info("Shutting down application...")
global redis_client, pubsub
if manager.redis_listener_task:
manager.redis_listener_task.cancel()
if pubsub:
await pubsub.unsubscribe('ppr_updates')
await pubsub.close()
if redis_client:
await redis_client.close()
@app.websocket("/ws/tower-updates")
async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket)