Added auth

This commit is contained in:
2026-01-21 16:38:52 -05:00
parent 615c7caee8
commit cc5c7ff42d
7 changed files with 1303 additions and 277 deletions

93
backend/app/auth.py Normal file
View File

@@ -0,0 +1,93 @@
from datetime import datetime, timedelta
from typing import Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from fastapi import Depends, HTTPException, status, Header
from sqlalchemy.orm import Session
from .database import get_db
from .models import User
# Configuration
SECRET_KEY = "your-secret-key-change-in-production" # Change this in production!
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 720
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
def hash_password(password: str) -> str:
"""Hash a password using bcrypt"""
return pwd_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash"""
return pwd_context.verify(plain_password, hashed_password)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
"""Create a JWT access token"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
def get_current_user(authorization: Optional[str] = Header(None), db: Session = Depends(get_db)) -> User:
"""Get the current authenticated user from JWT token"""
if not authorization:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authorization header",
headers={"WWW-Authenticate": "Bearer"},
)
try:
scheme, credentials = authorization.split()
if scheme.lower() != "bearer":
raise ValueError("Invalid auth scheme")
except (ValueError, AttributeError):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authorization header format",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(credentials, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
user = db.query(User).filter(User.username == username).first()
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
return user
def get_current_admin_user(current_user: User = Depends(get_current_user)) -> User:
"""Get the current user and verify they are an admin"""
if not current_user.is_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions. Admin access required."
)
return current_user

View File

@@ -1,10 +1,11 @@
from fastapi import FastAPI, Depends, HTTPException, APIRouter
from fastapi import FastAPI, Depends, HTTPException, APIRouter, status
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.orm import Session
from typing import List, Optional
from datetime import datetime
from datetime import datetime, timedelta
from .database import engine, get_db, Base
from .models import Drug, DrugVariant, Dispensing
from .models import Drug, DrugVariant, Dispensing, User
from .auth import hash_password, verify_password, create_access_token, get_current_user, get_current_admin_user, ACCESS_TOKEN_EXPIRE_MINUTES
from pydantic import BaseModel
# Create tables
@@ -25,18 +26,42 @@ app.add_middleware(
router = APIRouter(prefix="/api")
# Pydantic schemas
class UserCreate(BaseModel):
username: str
password: str
class PasswordChange(BaseModel):
current_password: str
new_password: str
class AdminPasswordChange(BaseModel):
new_password: str
class UserResponse(BaseModel):
id: int
username: str
is_admin: bool
class Config:
from_attributes = True
class TokenResponse(BaseModel):
access_token: str
token_type: str
user: UserResponse
class DrugCreate(BaseModel):
name: str
description: str = None
description: Optional[str] = None
class DrugUpdate(BaseModel):
name: str = None
description: str = None
name: Optional[str] = None
description: Optional[str] = None
class DrugResponse(BaseModel):
id: int
name: str
description: str = None
description: Optional[str] = None
class Config:
from_attributes = True
@@ -67,7 +92,7 @@ class DrugVariantResponse(BaseModel):
class DrugWithVariantsResponse(BaseModel):
id: int
name: str
description: str = None
description: Optional[str] = None
variants: List[DrugVariantResponse] = []
class Config:
@@ -92,13 +117,153 @@ class DispensingResponse(BaseModel):
class Config:
from_attributes = True
# Authentication Routes
@router.post("/auth/register", response_model=TokenResponse)
def register(user_data: UserCreate, db: Session = Depends(get_db)):
"""Register the first admin user (only allowed if no users exist)"""
# Check if users already exist
user_count = db.query(User).count()
if user_count > 0:
raise HTTPException(
status_code=403,
detail="Registration is disabled. Contact an administrator to create an account."
)
# Check if user already exists
existing_user = db.query(User).filter(User.username == user_data.username).first()
if existing_user:
raise HTTPException(status_code=400, detail="Username already registered")
# First (and only allowed) user is admin
hashed_password = hash_password(user_data.password)
db_user = User(
username=user_data.username,
hashed_password=hashed_password,
is_admin=True
)
db.add(db_user)
db.commit()
db.refresh(db_user)
# Create access token
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": db_user.username},
expires_delta=access_token_expires
)
return {
"access_token": access_token,
"token_type": "bearer",
"user": db_user
}
@router.post("/auth/login", response_model=TokenResponse)
def login(user_data: UserCreate, db: Session = Depends(get_db)):
"""Login with username and password"""
user = db.query(User).filter(User.username == user_data.username).first()
if not user or not verify_password(user_data.password, user.hashed_password):
raise HTTPException(status_code=401, detail="Invalid credentials")
# Create access token
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user.username},
expires_delta=access_token_expires
)
return {
"access_token": access_token,
"token_type": "bearer",
"user": user
}
@router.get("/auth/me", response_model=UserResponse)
def get_current_user_info(current_user: User = Depends(get_current_user)):
"""Get current user info"""
return current_user
# User Management Routes (Admin only)
@router.get("/users", response_model=List[UserResponse])
def list_users(db: Session = Depends(get_db), current_user: User = Depends(get_current_admin_user)):
"""List all users (admin only)"""
return db.query(User).all()
@router.post("/users", response_model=UserResponse)
def create_user(user_data: UserCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_admin_user)):
"""Create a new user (admin only)"""
# Check if user already exists
existing_user = db.query(User).filter(User.username == user_data.username).first()
if existing_user:
raise HTTPException(status_code=400, detail="Username already exists")
hashed_password = hash_password(user_data.password)
db_user = User(
username=user_data.username,
hashed_password=hashed_password,
is_admin=False
)
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user
@router.delete("/users/{user_id}")
def delete_user(user_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_admin_user)):
"""Delete a user (admin only)"""
# Don't allow deleting yourself
if current_user.id == user_id:
raise HTTPException(status_code=400, detail="Cannot delete your own user account")
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
db.delete(user)
db.commit()
return {"message": "User deleted successfully"}
@router.post("/auth/change-password")
def change_own_password(password_data: PasswordChange, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Change current user's password"""
user = db.query(User).filter(User.id == current_user.id).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
# Verify current password
if not verify_password(password_data.current_password, user.hashed_password):
raise HTTPException(status_code=401, detail="Current password is incorrect")
# Update password
user.hashed_password = hash_password(password_data.new_password)
db.commit()
return {"message": "Password changed successfully"}
@router.post("/users/{user_id}/change-password")
def admin_change_password(user_id: int, password_data: AdminPasswordChange, db: Session = Depends(get_db), current_user: User = Depends(get_current_admin_user)):
"""Change a user's password (admin only)"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
# Don't allow changing yourself via this endpoint
if current_user.id == user_id:
raise HTTPException(status_code=400, detail="Use /auth/change-password to change your own password")
# Update password
user.hashed_password = hash_password(password_data.new_password)
db.commit()
return {"message": "Password changed successfully"}
# Routes
@router.get("/")
def read_root():
return {"message": "Drug Inventory API"}
@router.get("/drugs", response_model=List[DrugWithVariantsResponse])
def list_drugs(db: Session = Depends(get_db)):
def list_drugs(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Get all drugs with their variants"""
drugs = db.query(Drug).all()
result = []
@@ -110,7 +275,7 @@ def list_drugs(db: Session = Depends(get_db)):
return result
@router.get("/drugs/low-stock", response_model=List[DrugWithVariantsResponse])
def low_stock_drugs(db: Session = Depends(get_db)):
def low_stock_drugs(db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Get drugs with low stock variants"""
# Get variants that are low on stock
low_stock_variants = db.query(DrugVariant).filter(
@@ -130,7 +295,7 @@ def low_stock_drugs(db: Session = Depends(get_db)):
return result
@router.get("/drugs/{drug_id}", response_model=DrugWithVariantsResponse)
def get_drug(drug_id: int, db: Session = Depends(get_db)):
def get_drug(drug_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Get a specific drug with its variants"""
drug = db.query(Drug).filter(Drug.id == drug_id).first()
if not drug:
@@ -142,7 +307,7 @@ def get_drug(drug_id: int, db: Session = Depends(get_db)):
return drug_dict
@router.post("/drugs", response_model=DrugWithVariantsResponse)
def create_drug(drug: DrugCreate, db: Session = Depends(get_db)):
def create_drug(drug: DrugCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Create a new drug"""
# Check if drug name already exists
existing = db.query(Drug).filter(Drug.name == drug.name).first()
@@ -160,7 +325,7 @@ def create_drug(drug: DrugCreate, db: Session = Depends(get_db)):
return drug_dict
@router.put("/drugs/{drug_id}", response_model=DrugWithVariantsResponse)
def update_drug(drug_id: int, drug_update: DrugUpdate, db: Session = Depends(get_db)):
def update_drug(drug_id: int, drug_update: DrugUpdate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Update a drug"""
drug = db.query(Drug).filter(Drug.id == drug_id).first()
if not drug:
@@ -178,7 +343,7 @@ def update_drug(drug_id: int, drug_update: DrugUpdate, db: Session = Depends(get
return drug_dict
@router.delete("/drugs/{drug_id}")
def delete_drug(drug_id: int, db: Session = Depends(get_db)):
def delete_drug(drug_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Delete a drug and all its variants"""
drug = db.query(Drug).filter(Drug.id == drug_id).first()
if not drug:
@@ -194,7 +359,7 @@ def delete_drug(drug_id: int, db: Session = Depends(get_db)):
# Drug Variant endpoints
@router.post("/drugs/{drug_id}/variants", response_model=DrugVariantResponse)
def create_drug_variant(drug_id: int, variant: DrugVariantCreate, db: Session = Depends(get_db)):
def create_drug_variant(drug_id: int, variant: DrugVariantCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Create a new variant for a drug"""
# Check if drug exists
drug = db.query(Drug).filter(Drug.id == drug_id).first()
@@ -222,7 +387,7 @@ def create_drug_variant(drug_id: int, variant: DrugVariantCreate, db: Session =
return db_variant
@router.get("/variants/{variant_id}", response_model=DrugVariantResponse)
def get_drug_variant(variant_id: int, db: Session = Depends(get_db)):
def get_drug_variant(variant_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Get a specific drug variant"""
variant = db.query(DrugVariant).filter(DrugVariant.id == variant_id).first()
if not variant:
@@ -230,7 +395,7 @@ def get_drug_variant(variant_id: int, db: Session = Depends(get_db)):
return variant
@router.put("/variants/{variant_id}", response_model=DrugVariantResponse)
def update_drug_variant(variant_id: int, variant_update: DrugVariantUpdate, db: Session = Depends(get_db)):
def update_drug_variant(variant_id: int, variant_update: DrugVariantUpdate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Update a drug variant"""
variant = db.query(DrugVariant).filter(DrugVariant.id == variant_id).first()
if not variant:
@@ -244,7 +409,7 @@ def update_drug_variant(variant_id: int, variant_update: DrugVariantUpdate, db:
return variant
@router.delete("/variants/{variant_id}")
def delete_drug_variant(variant_id: int, db: Session = Depends(get_db)):
def delete_drug_variant(variant_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Delete a drug variant"""
variant = db.query(DrugVariant).filter(DrugVariant.id == variant_id).first()
if not variant:
@@ -257,7 +422,7 @@ def delete_drug_variant(variant_id: int, db: Session = Depends(get_db)):
# Dispensing endpoints
@router.post("/dispense", response_model=DispensingResponse)
def dispense_drug(dispensing: DispensingCreate, db: Session = Depends(get_db)):
def dispense_drug(dispensing: DispensingCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Record a drug dispensing and reduce inventory"""
# Check if drug variant exists
variant = db.query(DrugVariant).filter(DrugVariant.id == dispensing.drug_variant_id).first()
@@ -283,12 +448,12 @@ def dispense_drug(dispensing: DispensingCreate, db: Session = Depends(get_db)):
return db_dispensing
@router.get("/dispense/history", response_model=List[DispensingResponse])
def list_dispensings(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
def list_dispensings(skip: int = 0, limit: int = 100, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Get dispensing records (audit log)"""
return db.query(Dispensing).order_by(Dispensing.dispensed_at.desc()).offset(skip).limit(limit).all()
@router.get("/drugs/{drug_id}/dispense/history", response_model=List[DispensingResponse])
def get_drug_dispensings(drug_id: int, db: Session = Depends(get_db)):
def get_drug_dispensings(drug_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Get dispensing history for a specific drug (all variants)"""
# Verify drug exists
drug = db.query(Drug).filter(Drug.id == drug_id).first()
@@ -301,7 +466,7 @@ def get_drug_dispensings(drug_id: int, db: Session = Depends(get_db)):
return db.query(Dispensing).filter(Dispensing.drug_variant_id.in_(variant_ids)).order_by(Dispensing.dispensed_at.desc()).all()
@router.get("/variants/{variant_id}/dispense/history", response_model=List[DispensingResponse])
def get_variant_dispensings(variant_id: int, db: Session = Depends(get_db)):
def get_variant_dispensings(variant_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
"""Get dispensing history for a specific drug variant"""
# Verify variant exists
variant = db.query(DrugVariant).filter(DrugVariant.id == variant_id).first()

View File

@@ -1,7 +1,17 @@
from sqlalchemy import Column, Integer, String, Float, DateTime, ForeignKey
from sqlalchemy import Column, Integer, String, Float, DateTime, ForeignKey, Boolean
from sqlalchemy.sql import func
from .database import Base
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
username = Column(String, unique=True, index=True, nullable=False)
hashed_password = Column(String, nullable=False)
is_admin = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
class Drug(Base):
__tablename__ = "drugs"
@@ -31,7 +41,7 @@ class Dispensing(Base):
id = Column(Integer, primary_key=True, index=True)
drug_variant_id = Column(Integer, ForeignKey("drug_variants.id"), nullable=False)
quantity = Column(Float, nullable=False)
animal_name = Column(String, nullable=False) # Name/ID of the animal
animal_name = Column(String, nullable=True) # Name/ID of the animal (optional)
user_name = Column(String, nullable=False) # User who dispensed
dispensed_at = Column(DateTime(timezone=True), server_default=func.now(), index=True)
notes = Column(String, nullable=True)