Add user roles

This commit is contained in:
2026-03-26 15:56:56 -04:00
parent 511e8ebde4
commit 1c9fbbda6c
7 changed files with 207 additions and 36 deletions
+11 -1
View File
@@ -85,9 +85,19 @@ def get_current_user(authorization: Optional[str] = Header(None), db: Session =
def get_current_admin_user(current_user: User = Depends(get_current_user)) -> User:
"""Get the current user and verify they are an admin"""
if not current_user.is_admin:
if current_user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions. Admin access required."
)
return current_user
def get_current_non_readonly_user(current_user: User = Depends(get_current_user)) -> User:
"""Get the current user and verify they are not read-only"""
if current_user.role == "readonly":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Read-only users cannot perform this action."
)
return current_user
+27 -13
View File
@@ -5,13 +5,20 @@ from typing import List, Optional, Dict, Any
from datetime import datetime, timedelta
from .database import engine, get_db, Base
from .models import Drug, DrugVariant, Dispensing, User
from .auth import hash_password, verify_password, create_access_token, get_current_user, get_current_admin_user, ACCESS_TOKEN_EXPIRE_MINUTES
from .auth import hash_password, verify_password, create_access_token, get_current_user, get_current_admin_user, get_current_non_readonly_user, ACCESS_TOKEN_EXPIRE_MINUTES
from .mqtt_service import publish_label_print_with_response
from .migrate_to_roles import migrate_users_table
from pydantic import BaseModel
# Create tables
Base.metadata.create_all(bind=engine)
# Run migration to convert is_admin to role
try:
migrate_users_table()
except Exception as e:
print(f"Warning: Migration failed: {e}. Continuing anyway...")
app = FastAPI(title="Drug Inventory API")
# CORS middleware for frontend
@@ -30,6 +37,7 @@ router = APIRouter(prefix="/api")
class UserCreate(BaseModel):
username: str
password: str
role: Optional[str] = "user" # admin, user, readonly
class PasswordChange(BaseModel):
current_password: str
@@ -41,7 +49,7 @@ class AdminPasswordChange(BaseModel):
class UserResponse(BaseModel):
id: int
username: str
is_admin: bool
role: str
class Config:
from_attributes = True
@@ -166,7 +174,7 @@ def register(user_data: UserCreate, db: Session = Depends(get_db)):
db_user = User(
username=user_data.username,
hashed_password=hashed_password,
is_admin=True
role="admin"
)
db.add(db_user)
db.commit()
@@ -224,11 +232,17 @@ def create_user(user_data: UserCreate, db: Session = Depends(get_db), current_us
if existing_user:
raise HTTPException(status_code=400, detail="Username already exists")
# Validate role
valid_roles = ["admin", "user", "readonly"]
role = user_data.role or "user"
if role not in valid_roles:
raise HTTPException(status_code=400, detail=f"Invalid role. Must be one of: {', '.join(valid_roles)}")
hashed_password = hash_password(user_data.password)
db_user = User(
username=user_data.username,
hashed_password=hashed_password,
is_admin=False
role=role
)
db.add(db_user)
db.commit()
@@ -334,7 +348,7 @@ def get_drug(drug_id: int, db: Session = Depends(get_db), current_user: User = D
return drug_dict
@router.post("/drugs", response_model=DrugWithVariantsResponse)
def create_drug(drug: DrugCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
def create_drug(drug: DrugCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_non_readonly_user)):
"""Create a new drug"""
# Check if drug name already exists
existing = db.query(Drug).filter(Drug.name == drug.name).first()
@@ -352,7 +366,7 @@ def create_drug(drug: DrugCreate, db: Session = Depends(get_db), current_user: U
return drug_dict
@router.put("/drugs/{drug_id}", response_model=DrugWithVariantsResponse)
def update_drug(drug_id: int, drug_update: DrugUpdate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
def update_drug(drug_id: int, drug_update: DrugUpdate, db: Session = Depends(get_db), current_user: User = Depends(get_current_non_readonly_user)):
"""Update a drug"""
drug = db.query(Drug).filter(Drug.id == drug_id).first()
if not drug:
@@ -370,7 +384,7 @@ def update_drug(drug_id: int, drug_update: DrugUpdate, db: Session = Depends(get
return drug_dict
@router.delete("/drugs/{drug_id}")
def delete_drug(drug_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
def delete_drug(drug_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_non_readonly_user)):
"""Delete a drug and all its variants"""
drug = db.query(Drug).filter(Drug.id == drug_id).first()
if not drug:
@@ -386,7 +400,7 @@ def delete_drug(drug_id: int, db: Session = Depends(get_db), current_user: User
# Drug Variant endpoints
@router.post("/drugs/{drug_id}/variants", response_model=DrugVariantResponse)
def create_drug_variant(drug_id: int, variant: DrugVariantCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
def create_drug_variant(drug_id: int, variant: DrugVariantCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_non_readonly_user)):
"""Create a new variant for a drug"""
# Check if drug exists
drug = db.query(Drug).filter(Drug.id == drug_id).first()
@@ -422,7 +436,7 @@ def get_drug_variant(variant_id: int, db: Session = Depends(get_db), current_use
return variant
@router.put("/variants/{variant_id}", response_model=DrugVariantResponse)
def update_drug_variant(variant_id: int, variant_update: DrugVariantUpdate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
def update_drug_variant(variant_id: int, variant_update: DrugVariantUpdate, db: Session = Depends(get_db), current_user: User = Depends(get_current_non_readonly_user)):
"""Update a drug variant"""
variant = db.query(DrugVariant).filter(DrugVariant.id == variant_id).first()
if not variant:
@@ -436,7 +450,7 @@ def update_drug_variant(variant_id: int, variant_update: DrugVariantUpdate, db:
return variant
@router.delete("/variants/{variant_id}")
def delete_drug_variant(variant_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
def delete_drug_variant(variant_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_non_readonly_user)):
"""Delete a drug variant"""
variant = db.query(DrugVariant).filter(DrugVariant.id == variant_id).first()
if not variant:
@@ -449,7 +463,7 @@ def delete_drug_variant(variant_id: int, db: Session = Depends(get_db), current_
# Dispensing endpoints
@router.post("/dispense", response_model=DispensingResponse)
def dispense_drug(dispensing: DispensingCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
def dispense_drug(dispensing: DispensingCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_non_readonly_user)):
"""Record a drug dispensing and reduce inventory"""
# Check if drug variant exists
variant = db.query(DrugVariant).filter(DrugVariant.id == dispensing.drug_variant_id).first()
@@ -523,7 +537,7 @@ def capitalize_label_text(text: str) -> str:
# Label printing endpoint
@router.post("/labels/print", response_model=LabelPrintResponse)
def print_label(label_request: LabelPrintRequest, current_user: User = Depends(get_current_user)):
def print_label(label_request: LabelPrintRequest, current_user: User = Depends(get_current_non_readonly_user)):
"""
Print a drug label by publishing an MQTT message
@@ -588,7 +602,7 @@ def print_label(label_request: LabelPrintRequest, current_user: User = Depends(g
# Notes printing endpoint
@router.post("/notes/print", response_model=NotesPrintResponse)
def print_notes(notes_request: NotesPrintRequest, current_user: User = Depends(get_current_user)):
def print_notes(notes_request: NotesPrintRequest, current_user: User = Depends(get_current_non_readonly_user)):
"""
Print notes by publishing an MQTT message
+98
View File
@@ -0,0 +1,98 @@
"""
Migration script to convert is_admin boolean field to role string field
"""
import sqlite3
import os
from pathlib import Path
from sqlalchemy.engine.url import make_url
def migrate_users_table():
"""Add role column to users table and migrate data from is_admin"""
# Get database path from environment or use default
db_url = os.getenv("DATABASE_URL", "sqlite:///./data/drugs.db")
# Parse SQLite URL to get the file path
if db_url.startswith("sqlite:///"):
db_path = db_url.replace("sqlite:///", "")
# Handle relative paths
if not db_path.startswith("/"):
db_path = Path("/app/data") / "drugs.db"
else:
db_path = Path(db_path)
else:
print(f"Unsupported database URL: {db_url}")
return
if not db_path.exists():
print(f"Database does not exist at {db_path}, skipping migration")
return
print(f"Connecting to database at {db_path}")
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
try:
# Check if role column already exists
cursor.execute("PRAGMA table_info(users)")
columns = [col[1] for col in cursor.fetchall()]
if "role" in columns:
print("Role column already exists, skipping migration")
conn.close()
return
if not columns:
print("Users table does not exist yet, skipping migration")
conn.close()
return
print("Migrating users table: adding role column...")
# Add role column with default value
cursor.execute("ALTER TABLE users ADD COLUMN role VARCHAR DEFAULT 'user'")
# Migrate data from is_admin to role
if "is_admin" in columns:
print("Migrating data from is_admin to role...")
cursor.execute("""
UPDATE users
SET role = CASE
WHEN is_admin = 1 THEN 'admin'
ELSE 'user'
END
""")
# Drop the old is_admin column
# SQLite doesn't support DROP COLUMN directly in older versions,
# so we use a workaround
cursor.execute("ALTER TABLE users RENAME TO users_old")
cursor.execute("""
CREATE TABLE users (
id INTEGER PRIMARY KEY,
username VARCHAR UNIQUE NOT NULL,
hashed_password VARCHAR NOT NULL,
role VARCHAR NOT NULL DEFAULT 'user',
created_at DATETIME
)
""")
cursor.execute("""
INSERT INTO users (id, username, hashed_password, role, created_at)
SELECT id, username, hashed_password, role, created_at FROM users_old
""")
cursor.execute("DROP TABLE users_old")
print("Successfully migrated is_admin to role and cleaned up old column")
conn.commit()
print("Migration completed successfully!")
except sqlite3.OperationalError as e:
print(f"Migration error: {e}")
conn.rollback()
raise
finally:
conn.close()
if __name__ == "__main__":
migrate_users_table()
+1 -1
View File
@@ -8,7 +8,7 @@ class User(Base):
id = Column(Integer, primary_key=True, index=True)
username = Column(String, unique=True, index=True, nullable=False)
hashed_password = Column(String, nullable=False)
is_admin = Column(Boolean, default=False)
role = Column(String, default="user", nullable=False) # admin, user, readonly
created_at = Column(DateTime(timezone=True), server_default=func.now())