Skip to content

Commit

Permalink
Add baseline Architecture to support auth + User sessions + Basic Auth (
Browse files Browse the repository at this point in the history
cohere-ai#90)

* Auth wip

* Add tests

* working basic auth

* add test coverage: still todo add session tests

* Add session tests

* Add docs

* fix types

* merge alembic

* add secret key fixture

* fix tests
  • Loading branch information
tianjing-li authored and lakshyaag committed May 8, 2024
1 parent 968e2ef commit 110d47b
Show file tree
Hide file tree
Showing 24 changed files with 770 additions and 46 deletions.
5 changes: 4 additions & 1 deletion .env-template
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,7 @@ AZURE_CHAT_ENDPOINT_URL=<ENDPOINT URL>
USE_EXPERIMENTAL_LANGCHAIN=False

# Community features
USE_COMMUNITY_FEATURES='True'
USE_COMMUNITY_FEATURES='True'

# Auth session
SESSION_SECRET_KEY=<GENERATE_A_SECRET_KEY>
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ reset-db:
docker volume rm cohere_toolkit_db
setup:
poetry install --only setup --verbose
poetry run python3 src/backend/cli/main.py
poetry run python3 cli/main.py
lint:
poetry run black .
poetry run isort .
Expand Down
170 changes: 157 additions & 13 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ py-expression-eval = "^0.3.14"
tavily-python = "^0.3.3"
arxiv = "^2.1.0"
xmltodict = "^0.13.0"
authlib = "^1.3.0"
itsdangerous = "^2.2.0"
bcrypt = "^4.1.2"

[tool.poetry.group.dev.dependencies]
pytest = "^7.1.2"
Expand Down
34 changes: 34 additions & 0 deletions src/backend/alembic/versions/b88f00283a27_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""empty message
Revision ID: b88f00283a27
Revises: 2853273872ca
Create Date: 2024-05-02 19:19:52.608062
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "b88f00283a27"
down_revision: Union[str, None] = "2853273872ca"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"users", sa.Column("hashed_password", sa.LargeBinary(), nullable=True)
)
op.create_unique_constraint("unique_user_email", "users", ["email"])
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("unique_user_email", "users", type_="unique")
op.drop_column("users", "hashed_password")
# ### end Alembic commands ###
26 changes: 26 additions & 0 deletions src/backend/alembic/versions/c15b848babe3_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""empty message
Revision ID: c15b848babe3
Revises: 6553b76de6ca, b88f00283a27
Create Date: 2024-05-07 15:59:05.436751
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "c15b848babe3"
down_revision: Union[str, None] = ("6553b76de6ca", "b88f00283a27")
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
pass


def downgrade() -> None:
pass
9 changes: 9 additions & 0 deletions src/backend/config/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from backend.services.auth import BasicAuthentication

# Modify this to enable auth strategies.
ENABLED_AUTH_STRATEGIES = []

# Define the mapping from Auth strategy name to class obj.
# Does not need to be manually modified.
# Ex: {"Basic": BasicAuthentication}
ENABLED_AUTH_STRATEGY_MAPPING = {cls.NAME: cls for cls in ENABLED_AUTH_STRATEGIES}
2 changes: 1 addition & 1 deletion src/backend/crud/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def update_user(db: Session, user: User, new_user: UpdateUser) -> User:
Returns:
User: Updated user.
"""
for attr, value in new_user.model_dump().items():
for attr, value in new_user.model_dump(exclude_none=True).items():
setattr(user, attr, value)
db.commit()
db.refresh(user)
Expand Down
4 changes: 4 additions & 0 deletions src/backend/database_models/user.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

from sqlalchemy import UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column

from backend.database_models.base import Base
Expand All @@ -10,3 +11,6 @@ class User(Base):

fullname: Mapped[str] = mapped_column()
email: Mapped[Optional[str]] = mapped_column()
hashed_password: Mapped[Optional[bytes]] = mapped_column()

__table_args__ = (UniqueConstraint("email", name="unique_user_email"),)
36 changes: 32 additions & 4 deletions src/backend/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import os
from contextlib import asynccontextmanager

from alembic.command import upgrade
from alembic.config import Config
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware

from backend.config.auth import ENABLED_AUTH_STRATEGY_MAPPING
from backend.routers.auth import router as auth_router
from backend.routers.chat import router as chat_router
from backend.routers.conversation import router as conversation_router
from backend.routers.deployment import router as deployment_router
Expand All @@ -15,32 +19,56 @@

load_dotenv()

ORIGINS = ["*"]


@asynccontextmanager
async def lifespan(app: FastAPI):
yield


origins = ["*"]


def create_app():
app = FastAPI(lifespan=lifespan)

# Add routers
app.include_router(auth_router)
app.include_router(chat_router)
app.include_router(user_router)
app.include_router(conversation_router)
app.include_router(tool_router)
app.include_router(deployment_router)
app.include_router(experimental_feature_router)

# Add middleware
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_origins=ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

if ENABLED_AUTH_STRATEGY_MAPPING:
secret_key = os.environ.get("SESSION_SECRET_KEY", None)

if not secret_key:
raise ValueError(
"Missing SESSION_SECRET_KEY environment variable to enable Authentication."
)

# Handle User sessions and Auth
app.add_middleware(
SessionMiddleware,
secret_key=secret_key,
)

# Add auth
for auth in ENABLED_AUTH_STRATEGY_MAPPING.values():
if auth.SHOULD_ATTACH_TO_APP:
# TODO: Add app attachment logic for eg OAuth:
# https://docs.authlib.org/en/latest/client/fastapi.html
pass

return app


Expand Down
12 changes: 6 additions & 6 deletions src/backend/model_deployments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ class BaseDeployment:
@abstractmethod
def rerank_enabled(self) -> bool: ...

@staticmethod
def list_models() -> List[str]: ...

@staticmethod
def is_available() -> bool: ...

@abstractmethod
def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: ...

Expand All @@ -45,9 +51,3 @@ def invoke_rerank(

@abstractmethod
def invoke_tools(self, message: str, tools: List[Any], **kwargs: Any) -> Any: ...

@staticmethod
def list_models() -> List[str]: ...

@staticmethod
def is_available() -> bool: ...
102 changes: 102 additions & 0 deletions src/backend/routers/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from fastapi import APIRouter, Depends, HTTPException
from starlette.requests import Request

from backend.config.auth import ENABLED_AUTH_STRATEGY_MAPPING
from backend.database_models import get_session
from backend.database_models.database import DBSessionDep
from backend.schemas.auth import Login

router = APIRouter(dependencies=[Depends(get_session)])


@router.get("/session")
def get_session(request: Request):
"""
Retrievers the current session user.
Args:
request (Request): current Request object.
Returns:
session: current user session ({} if no active session)
Raises:
401 HTTPException if no user found in session.
"""

if not request.session:
raise HTTPException(status_code=401, detail="Not authenticated.")

return request.session.get("user")


@router.post("/login")
async def login(request: Request, login: Login, session: DBSessionDep):
"""
Logs user in, verifying their credentials and either setting the user session,
or redirecting to /auth endpoint.
Args:er
request (Request): current Request object.
login (Login): Login payload.
session (DBSessionDep): Database session.
Returns:
dict: On success.
Raises:
HTTPException: If the strategy or payload are invalid, or if the login fails.
"""
strategy_name = login.strategy
payload = login.payload

# Check the strategy is valid and enabled
if strategy_name not in ENABLED_AUTH_STRATEGY_MAPPING.keys():
raise HTTPException(
status_code=404, detail=f"Invalid Authentication strategy: {strategy_name}."
)

# Check that the payload required is given
strategy = ENABLED_AUTH_STRATEGY_MAPPING[strategy_name]
strategy_payload = strategy.get_required_payload()
if not set(strategy_payload).issubset(payload):
missing_keys = [key for key in strategy_payload if key not in payload.keys()]
raise HTTPException(
status_code=404,
detail=f"Missing the following keys in the payload: {missing_keys}.",
)

# Do login
user = strategy.login(session, payload)
if not user:
raise HTTPException(
status_code=401,
detail=f"Error performing {strategy_name} authentication with payload: {payload}.",
)

# Set session user
request.session["user"] = user

return {}


@router.post("/auth")
async def auth(request: Request):
# TODO: Implement for OAuth strategies
return {}


@router.get("/logout")
async def logout(request: Request):
"""
Logs out the current user session.
Args:
request (Request): current Request object.
Returns:
dict: On success.
"""
request.session.pop("user", None)

return {}
2 changes: 1 addition & 1 deletion src/backend/routers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def create_user(user: CreateUser, session: DBSessionDep) -> User:
Returns:
User: Created user.
"""
db_user = UserModel(**user.model_dump())
db_user = UserModel(**user.model_dump(exclude_none=True))
db_user = user_crud.create_user(session, db_user)

return db_user
Expand Down
9 changes: 9 additions & 0 deletions src/backend/schemas/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from pydantic import BaseModel


class Login(BaseModel):
strategy: str
payload: dict[str, str]

class Config:
from_attributes = True
24 changes: 21 additions & 3 deletions src/backend/schemas/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from pydantic import BaseModel

from backend.services.auth import BasicAuthentication


class UserBase(BaseModel):
fullname: str
Expand All @@ -18,13 +20,29 @@ class Config:
from_attributes = True


class CreateUser(UserBase):
pass
class UserPassword(BaseModel):
password: Optional[str] = None
hashed_password: Optional[bytes] = None

def __init__(self, **data):
password = data.pop("password", None)

if password is not None:
data["hashed_password"] = BasicAuthentication.hash_and_salt_password(
password
)

super().__init__(**data)

class UpdateUser(UserBase):

class CreateUser(UserBase, UserPassword):
pass


class UpdateUser(UserPassword):
fullname: Optional[str] = None
email: Optional[str] = None


class DeleteUser(BaseModel):
pass
Empty file.
5 changes: 5 additions & 0 deletions src/backend/services/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from backend.services.auth.basic import BasicAuthentication

__all__ = [
"BasicAuthentication",
]

0 comments on commit 110d47b

Please sign in to comment.