Gone live
This commit is contained in:
@@ -2,9 +2,20 @@ from fastapi import FastAPI, Depends, HTTPException, WebSocket, WebSocketDisconn
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from typing import List
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
import redis.asyncio as redis
|
||||
from app.core.config import settings
|
||||
from app.api.api import api_router
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis client for pub/sub (cross-worker communication)
|
||||
redis_client = None
|
||||
pubsub = None
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.project_name,
|
||||
openapi_url=f"{settings.api_v1_str}/openapi.json",
|
||||
@@ -25,28 +36,117 @@ app.add_middleware(
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: List[WebSocket] = []
|
||||
self.redis_listener_task = None
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
self.active_connections.append(websocket)
|
||||
logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}")
|
||||
|
||||
def disconnect(self, websocket: WebSocket):
|
||||
self.active_connections.remove(websocket)
|
||||
if websocket in self.active_connections:
|
||||
self.active_connections.remove(websocket)
|
||||
logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}")
|
||||
|
||||
async def send_personal_message(self, message: str, websocket: WebSocket):
|
||||
await websocket.send_text(message)
|
||||
|
||||
async def broadcast(self, message: dict):
|
||||
message_str = json.dumps(message)
|
||||
async def broadcast_local(self, message_str: str):
|
||||
"""Broadcast to connections on this worker only"""
|
||||
dead_connections = []
|
||||
for connection in self.active_connections:
|
||||
try:
|
||||
await connection.send_text(message_str)
|
||||
except:
|
||||
# Connection is dead, remove it
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send to connection: {e}")
|
||||
dead_connections.append(connection)
|
||||
|
||||
# Remove dead connections
|
||||
for connection in dead_connections:
|
||||
if connection in self.active_connections:
|
||||
self.active_connections.remove(connection)
|
||||
|
||||
if dead_connections:
|
||||
logger.info(f"Removed {len(dead_connections)} dead connections")
|
||||
|
||||
async def broadcast(self, message: dict):
|
||||
"""Broadcast via Redis pub/sub to all workers"""
|
||||
message_str = json.dumps(message)
|
||||
print(f"Publishing message to Redis channel: {message.get('type', 'unknown')}")
|
||||
logger.info(f"Publishing message to Redis channel: {message.get('type', 'unknown')}")
|
||||
|
||||
try:
|
||||
if redis_client:
|
||||
await redis_client.publish('ppr_updates', message_str)
|
||||
print(f"✓ Message published to Redis")
|
||||
else:
|
||||
# Fallback to local broadcast if Redis not available
|
||||
print("⚠ Redis not available, falling back to local broadcast")
|
||||
logger.warning("Redis not available, falling back to local broadcast")
|
||||
await self.broadcast_local(message_str)
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to publish to Redis: {e}")
|
||||
logger.error(f"Failed to publish to Redis: {e}")
|
||||
# Fallback to local broadcast
|
||||
await self.broadcast_local(message_str)
|
||||
|
||||
async def start_redis_listener(self):
|
||||
"""Listen for Redis pub/sub messages and broadcast to local connections"""
|
||||
global redis_client, pubsub
|
||||
|
||||
try:
|
||||
# Connect to Redis
|
||||
redis_url = settings.redis_url or "redis://redis:6379"
|
||||
print(f"Connecting to Redis at: {redis_url}")
|
||||
redis_client = await redis.from_url(redis_url, encoding="utf-8", decode_responses=True)
|
||||
pubsub = redis_client.pubsub()
|
||||
await pubsub.subscribe('ppr_updates')
|
||||
|
||||
print("✓ Redis listener started for PPR updates")
|
||||
logger.info("Redis listener started for PPR updates")
|
||||
|
||||
async for message in pubsub.listen():
|
||||
if message['type'] == 'message':
|
||||
message_data = message['data']
|
||||
print(f"Received Redis message, broadcasting to {len(self.active_connections)} local connections")
|
||||
logger.info(f"Received Redis message, broadcasting to {len(self.active_connections)} local connections")
|
||||
await self.broadcast_local(message_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Redis listener error: {e}")
|
||||
logger.error(f"Redis listener error: {e}")
|
||||
await asyncio.sleep(5) # Wait before retry
|
||||
# Retry connection
|
||||
if self.redis_listener_task and not self.redis_listener_task.done():
|
||||
asyncio.create_task(self.start_redis_listener())
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Start Redis listener when application starts"""
|
||||
print("=" * 50)
|
||||
print("STARTUP: Starting application and Redis listener...")
|
||||
print("=" * 50)
|
||||
logger.info("Starting application and Redis listener...")
|
||||
manager.redis_listener_task = asyncio.create_task(manager.start_redis_listener())
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Clean up Redis connections on shutdown"""
|
||||
logger.info("Shutting down application...")
|
||||
global redis_client, pubsub
|
||||
|
||||
if manager.redis_listener_task:
|
||||
manager.redis_listener_task.cancel()
|
||||
|
||||
if pubsub:
|
||||
await pubsub.unsubscribe('ppr_updates')
|
||||
await pubsub.close()
|
||||
|
||||
if redis_client:
|
||||
await redis_client.close()
|
||||
|
||||
@app.websocket("/ws/tower-updates")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await manager.connect(websocket)
|
||||
|
||||
Reference in New Issue
Block a user