Skip to content

Commit

Permalink
Fixes from pre-commit rules on current main (#154)
Browse files Browse the repository at this point in the history
* format fixes

* Update 20241028204107_684a330f9e9f_new_migration.py
  • Loading branch information
vineetshar authored Nov 1, 2024
1 parent da00fb8 commit 6b76511
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 69 deletions.
2 changes: 1 addition & 1 deletion GETTING_STARTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ curl -X 'POST' \
```
## Step 6: Send Messages in a Conversation

This API returns a stream response for the
This API returns a stream response for the
```bash
curl -X 'POST' \
'http://localhost:8001/api/v1/conversations/1234/message/' \
Expand Down
18 changes: 11 additions & 7 deletions app/alembic/versions/20241028204107_684a330f9e9f_new_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,34 @@
Create Date: 2024-10-28 20:41:07.469748
"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '20241028204107_684a330f9e9f'
down_revision: Union[str, None] = '20241003153813_827623103002'
revision: str = "20241028204107_684a330f9e9f"
down_revision: Union[str, None] = "20241003153813_827623103002"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.execute("CREATE TYPE visibility AS ENUM ('PRIVATE', 'PUBLIC')")
op.add_column('conversations', sa.Column('visibility', sa.Enum('PRIVATE', 'PUBLIC', name='visibility'), nullable=False, server_default='PRIVATE'))
op.add_column(
"conversations",
sa.Column(
"visibility", sa.Enum("PRIVATE", "PUBLIC", name="visibility"), nullable=True
),
)

# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('conversations', sa.Column('is_public', sa.BOOLEAN(), autoincrement=False, nullable=True))
op.drop_column('conversations', 'visibility')
op.drop_column("conversations", "visibility")
op.execute("DROP TYPE visibility")
# ### end Alembic commands ###
9 changes: 3 additions & 6 deletions app/modules/auth/auth_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@

from app.core.database import get_db
from app.modules.auth.auth_schema import LoginRequest
from app.modules.auth.auth_service import auth_handler
from app.modules.users.user_schema import CreateUser
from app.modules.users.user_service import UserService
from app.modules.utils.APIRouter import APIRouter
from app.modules.utils.posthog_helper import PostHogClient
from app.modules.auth.auth_service import auth_handler

SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL", None)

Expand All @@ -29,7 +29,6 @@ async def send_slack_message(message: str):


class AuthAPI:

@auth_router.post("/login")
async def login(login_request: LoginRequest):
email, password = login_request.email, login_request.password
Expand All @@ -39,10 +38,8 @@ async def login(login_request: LoginRequest):
id_token = res.get("idToken")
return JSONResponse(content={"token": id_token}, status_code=200)
except Exception as e:
return JSONResponse(
content={"error": f"ERROR: {str(e)}"}, status_code=400
)

return JSONResponse(content={"error": f"ERROR: {str(e)}"}, status_code=400)

@auth_router.post("/signup")
async def signup(request: Request, db: Session = Depends(get_db)):
body = json.loads(await request.body())
Expand Down
11 changes: 7 additions & 4 deletions app/modules/conversations/access/access_schema.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import List
from typing import List, Optional

from pydantic import BaseModel, EmailStr
from typing import Optional, List

from app.modules.conversations.conversation.conversation_model import Visibility


class ShareChatRequest(BaseModel):
conversation_id: str
recipientEmails: Optional[List[EmailStr]]=None
conversation_id: str
recipientEmails: Optional[List[EmailStr]] = None
visibility: Visibility


class ShareChatResponse(BaseModel):
message: str
sharedID: str
Expand All @@ -17,5 +19,6 @@ class ShareChatResponse(BaseModel):
class SharedChatResponse(BaseModel):
chat: dict


class RemoveAccessRequest(BaseModel):
emails: List[EmailStr]
104 changes: 70 additions & 34 deletions app/modules/conversations/access/access_service.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import List
from uuid6 import uuid7
from app.modules.conversations.conversation.conversation_model import Conversation, Visibility
from fastapi import HTTPException

from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.modules.conversations.conversation.conversation_model import Conversation
from app.modules.conversations.conversation.conversation_model import (
Conversation,
Visibility,
)


class ShareChatServiceError(Exception):
Expand All @@ -17,16 +18,28 @@ class ShareChatService:
def __init__(self, db: Session):
self.db = db

async def share_chat(self, conversation_id: str, user_id: str, recipient_emails: List[str] = None, visibility: Visibility = Visibility.PRIVATE) -> str:
chat = self.db.query(Conversation).filter_by(id=conversation_id, user_id=user_id).first()
async def share_chat(
self,
conversation_id: str,
user_id: str,
recipient_emails: List[str] = None,
visibility: Visibility = Visibility.PRIVATE,
) -> str:
chat = (
self.db.query(Conversation)
.filter_by(id=conversation_id, user_id=user_id)
.first()
)
if not chat:
raise HTTPException(404,"Chat does not exist or you are not authorized to access it.")

raise HTTPException(
404, "Chat does not exist or you are not authorized to access it."
)

if visibility == Visibility.PUBLIC:
chat.visibility = Visibility.PUBLIC
self.db.commit()
return conversation_id

if visibility == Visibility.PRIVATE:
chat.visibility = Visibility.PRIVATE
if recipient_emails:
Expand All @@ -35,66 +48,89 @@ async def share_chat(self, conversation_id: str, user_id: str, recipient_emails:
unique_new_emails_set = set(recipient_emails)

if unique_new_emails_set.issubset(existing_emails_set):
raise ShareChatServiceError("All provided emails have already been shared.")
raise ShareChatServiceError(
"All provided emails have already been shared."
)

to_share = unique_new_emails_set - existing_emails_set
if to_share:
try:
updated_emails = existing_emails + list(to_share)
self.db.query(Conversation).filter_by(id=conversation_id).update(
{Conversation.shared_with_emails: updated_emails, Conversation.visibility: visibility},
synchronize_session=False
self.db.query(Conversation).filter_by(
id=conversation_id
).update(
{
Conversation.shared_with_emails: updated_emails,
Conversation.visibility: visibility,
},
synchronize_session=False,
)
self.db.commit()
self.db.commit()
except IntegrityError as e:
self.db.rollback()
raise ShareChatServiceError("Failed to update shared chat due to a database integrity error.") from e
raise ShareChatServiceError(
"Failed to update shared chat due to a database integrity error."
) from e
self.db.commit()
return conversation_id
else:
self.db.query(Conversation).filter_by(id=conversation_id).update(
{Conversation.visibility: visibility},
synchronize_session=False
)
self.db.commit()
self.db.query(Conversation).filter_by(id=conversation_id).update(
{Conversation.visibility: visibility}, synchronize_session=False
)
self.db.commit()

return conversation_id

async def get_shared_emails(self, conversation_id: str, user_id: str) -> List[str]:

chat = self.db.query(Conversation).filter_by(id=conversation_id, user_id=user_id).first()
chat = (
self.db.query(Conversation)
.filter_by(id=conversation_id, user_id=user_id)
.first()
)
if not chat:
raise HTTPException(404,"Chat does not exist or you are not authorized to access it.")

raise HTTPException(
404, "Chat does not exist or you are not authorized to access it."
)

return chat.shared_with_emails or []

async def remove_access(self, conversation_id: str, user_id: str, emails_to_remove: List[str]) -> bool:
async def remove_access(
self, conversation_id: str, user_id: str, emails_to_remove: List[str]
) -> bool:
"""Remove access for specified emails from a conversation."""
chat = self.db.query(Conversation).filter_by(id=conversation_id, user_id=user_id).first()
chat = (
self.db.query(Conversation)
.filter_by(id=conversation_id, user_id=user_id)
.first()
)
if not chat:
raise HTTPException(
status_code=404,
detail="Chat does not exist or you are not authorized to access it."
detail="Chat does not exist or you are not authorized to access it.",
)

if not chat.shared_with_emails:
raise ShareChatServiceError("Chat has no shared access to remove.")

existing_emails = set(chat.shared_with_emails)
emails_to_remove_set = set(emails_to_remove)

# Check if any of the emails to remove actually have access
if not emails_to_remove_set.intersection(existing_emails):
raise ShareChatServiceError("None of the specified emails have access to this chat.")

raise ShareChatServiceError(
"None of the specified emails have access to this chat."
)

try:
updated_emails = list(existing_emails - emails_to_remove_set)
self.db.query(Conversation).filter_by(id=conversation_id).update(
{Conversation.shared_with_emails: updated_emails},
synchronize_session=False
synchronize_session=False,
)
self.db.commit()
return True
except IntegrityError as e:
self.db.rollback()
raise ShareChatServiceError("Failed to update shared chat due to a database integrity error.") from e
raise ShareChatServiceError(
"Failed to update shared chat due to a database integrity error."
) from e
8 changes: 6 additions & 2 deletions app/modules/conversations/conversation/conversation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlalchemy import ARRAY, TIMESTAMP, Column
from sqlalchemy import Enum as SQLAEnum
from sqlalchemy import ForeignKey, String, func, Boolean
from sqlalchemy import ForeignKey, String, func
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship

Expand All @@ -15,10 +15,12 @@ class ConversationStatus(enum.Enum):
ARCHIVED = "archived"
DELETED = "deleted"


class Visibility(enum.Enum):
PRIVATE = "private"
PUBLIC = "public"


class Conversation(Base):
__tablename__ = "conversations"

Expand All @@ -43,7 +45,9 @@ class Conversation(Base):
nullable=False,
)
shared_with_emails = Column(ARRAY(String), nullable=True)
visibility = Column(SQLAEnum(Visibility), default=Visibility.PRIVATE, nullable=False)
visibility = Column(
SQLAEnum(Visibility), default=Visibility.PRIVATE, nullable=False
)
# Relationships
user = relationship("User", back_populates="conversations")
messages = relationship(
Expand Down
10 changes: 6 additions & 4 deletions app/modules/conversations/conversation/conversation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from app.modules.conversations.conversation.conversation_model import (
Conversation,
ConversationStatus,
Visibility
Visibility,
)
from app.modules.conversations.conversation.conversation_schema import (
ConversationAccessType,
Expand Down Expand Up @@ -104,11 +104,13 @@ async def check_conversation_access(
self.sql_db.query(Conversation).filter_by(id=conversation_id).first()
)
if not conversation:
return ConversationAccessType.NOT_FOUND # Return 'not found' if conversation doesn't exist

return (
ConversationAccessType.NOT_FOUND
) # Return 'not found' if conversation doesn't exist

if conversation.visibility == Visibility.PUBLIC:
return ConversationAccessType.READ

if user_id == conversation.user_id: # Check if the user is the creator
return ConversationAccessType.WRITE # Creator can write
# Check if the conversation is shared
Expand Down
15 changes: 11 additions & 4 deletions app/modules/conversations/conversations_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from app.core.database import get_db
from app.modules.auth.auth_service import AuthService
from app.modules.conversations.access.access_schema import (
RemoveAccessRequest,
ShareChatRequest,
ShareChatResponse,
RemoveAccessRequest,
)
from app.modules.conversations.access.access_service import (
ShareChatService,
Expand Down Expand Up @@ -163,13 +163,19 @@ async def share_chat(
user_id = user["user_id"]
service = ShareChatService(db)
try:
shared_conversation = await service.share_chat(request.conversation_id, user_id, request.recipientEmails, request.visibility)
shared_conversation = await service.share_chat(
request.conversation_id,
user_id,
request.recipientEmails,
request.visibility,
)
return ShareChatResponse(
message="Chat shared successfully!", sharedID=shared_conversation
)
except ShareChatServiceError as e:
raise HTTPException(status_code=400, detail=str(e))


@router.get("/conversations/{conversation_id}/shared-emails", response_model=List[str])
async def get_shared_emails(
conversation_id: str,
Expand All @@ -181,12 +187,13 @@ async def get_shared_emails(
shared_emails = await service.get_shared_emails(conversation_id, user_id)
return shared_emails


@router.delete("/conversations/{conversation_id}/access")
async def remove_access(
conversation_id: str,
request: RemoveAccessRequest,
user: str = Depends(AuthService.check_auth),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
) -> dict:
"""Remove access for specified emails from a conversation."""
share_service = ShareChatService(db)
Expand All @@ -195,7 +202,7 @@ async def remove_access(
await share_service.remove_access(
conversation_id=conversation_id,
user_id=current_user_id,
emails_to_remove=request.emails
emails_to_remove=request.emails,
)
return {"message": "Access removed successfully"}
except ShareChatServiceError as e:
Expand Down
Loading

0 comments on commit 6b76511

Please sign in to comment.