-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoauth.py
151 lines (124 loc) · 4.81 KB
/
oauth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import logging
from datetime import datetime, timedelta, timezone
from typing import Annotated, Optional

import jwt
from fastapi import Depends, HTTPException, Request, status
from fastapi.concurrency import run_in_threadpool
from fastapi.security import OAuth2PasswordBearer
from sqlmodel import Session, select

from database import engine
from models.base import TokenData, User
from settings import get_settings
# Bearer-token dependency used by get_current_user; tokenUrl names the
# endpoint clients use to obtain a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")

# Application settings, resolved once when this module is imported.
settings = get_settings()

# Name of the cookie that get_current_user_from_cookie reads the JWT from.
cookie_name = settings.COOKIE_NAME
def _encode_token(data: dict, lifetime: timedelta, secret: str) -> str:
    """Sign a copy of *data* as a JWT with iat/exp/iss/aud claims added.

    Args:
        data: Caller-supplied claims; copied, never mutated. Any ``iat``,
            ``exp``, ``iss`` or ``aud`` keys in it are overwritten.
        lifetime: How long after issuance the token expires.
        secret: Signing key passed to ``jwt.encode``.

    Returns:
        The encoded token produced by ``jwt.encode``.
    """
    # One timestamp so "iat" and "exp" are derived from the same instant.
    issued_at = datetime.now(tz=timezone.utc)
    claims = {
        **data,
        "iat": issued_at,
        "exp": issued_at + lifetime,
        # The decoders in this module pass audience="sample.com".
        "iss": "sample.com",
        "aud": "sample.com",
    }
    return jwt.encode(claims, secret, settings.JWT_ALGO)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed access token.

    Args:
        data: Claims to embed (e.g. ``sub``, ``user_name``); not mutated.
        expires_delta: Token lifetime. It must be a ``timedelta`` because it
            is added to a datetime. Defaults to ``settings.JWT_EXPIRE``
            minutes.

    Returns:
        A JWT signed with ``settings.JWT_SECRET`` using ``settings.JWT_ALGO``.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.JWT_EXPIRE)
    return _encode_token(data, expires_delta, settings.JWT_SECRET)


def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed refresh token.

    Args:
        data: Claims to embed; not mutated.
        expires_delta: Token lifetime as a ``timedelta``. Defaults to
            ``settings.JWT_REFRESH_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        A JWT signed with ``settings.JWT_REFRESH_SECRET_KEY`` (a key separate
        from the access-token secret) using ``settings.JWT_ALGO``.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.JWT_REFRESH_TOKEN_EXPIRE_MINUTES)
    return _encode_token(data, expires_delta, settings.JWT_REFRESH_SECRET_KEY)
def validate_access_token(token: str) -> bool:
    """Return True if *token* is a valid access token that carries a ``sub`` claim.

    Validation runs ``jwt.decode`` with ``settings.JWT_SECRET``, the
    ``settings.JWT_ALGO`` algorithm, and the ``"sample.com"`` audience, plus
    PyJWT's default registered-claim checks.

    This is a convenience predicate: it never raises for a bad token. Any
    decode failure is logged at debug level and reported as ``False``.
    """
    try:
        payload = jwt.decode(
            token,
            settings.JWT_SECRET,
            algorithms=[settings.JWT_ALGO],
            audience="sample.com",
        )
    except Exception as exc:  # deliberate: the boolean contract must not raise
        logging.getLogger(__name__).debug("Access token rejected: %s", exc)
        return False
    return payload.get("sub") is not None
def verify_access_token(token: str, credentials_exception, credentials_expired):
    """Decode and validate a JWT access token for endpoint dependencies.

    A token is accepted only if all of the following hold:

    * ``jwt.decode`` succeeds with ``settings.JWT_SECRET``,
      ``settings.JWT_ALGO`` and audience ``"sample.com"``;
    * it has a ``sub`` claim;
    * its claims build a ``TokenData``;
    * its ``user_name`` claim matches the ``email`` of an existing ``User``.

    Args:
        token: The encoded JWT.
        credentials_exception: Exception instance raised for every
            validation failure except expiry.
        credentials_expired: Exception instance raised when PyJWT reports
            the token as expired.

    Returns:
        The ``TokenData`` built from the token's claims.

    Raises:
        credentials_expired: if the token has expired.
        credentials_exception: for any other invalid token or unknown user.
        Database errors are not masked as authentication failures; they
        propagate to the caller.
    """
    try:
        payload = jwt.decode(
            token,
            settings.JWT_SECRET,
            algorithms=[settings.JWT_ALGO],
            audience="sample.com",
        )
    except jwt.ExpiredSignatureError as exc:
        # Reported separately so clients can refresh instead of re-authenticating.
        raise credentials_expired from exc
    except jwt.PyJWTError as exc:
        # Bad signature, malformed token, wrong audience, etc.
        raise credentials_exception from exc

    user_id = payload.get("sub")
    if user_id is None:
        raise credentials_exception

    user_name = payload.get("user_name")
    try:
        token_data = TokenData(
            sub=user_id,
            user_name=user_name,
            organization=payload.get("organization"),
            orgid=payload.get("orgid"),
            role=payload.get("role"),
            accepted_tc=payload.get("accepted_tc"),
            impersonated=payload.get("impersonated"),
            impersonated_by=payload.get("impersonated_by"),
        )
    except ValueError as exc:
        # pydantic's ValidationError subclasses ValueError, so claims of the
        # wrong shape count as invalid credentials, as before.
        raise credentials_exception from exc

    with Session(engine) as session:
        user = session.exec(select(User).where(User.email == user_name)).first()
    if user is None:
        # A validly signed token for a user that no longer exists must not
        # authenticate. Previously this silently returned None.
        raise credentials_exception
    return token_data
def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> TokenData:
    """FastAPI dependency: resolve the request's bearer token into user claims.

    Builds two HTTP errors and hands them to ``verify_access_token``:

    * ``credentials_exception``: 403 "Could not validate credentials";
    * ``credentials_expired``: 401 "Credentials have expired".

    Both carry a ``WWW-Authenticate: Bearer`` header. Returns whatever
    ``verify_access_token`` returns.
    """
    expired = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Credentials have expired",
        headers={"WWW-Authenticate": "Bearer"},
    )
    invalid = HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    return verify_access_token(
        token,
        credentials_exception=invalid,
        credentials_expired=expired,
    )
async def get_current_user_from_cookie(request: Request):
    """Authenticate a request using the JWT stored in the ``cookie_name`` cookie.

    Args:
        request: The incoming request whose cookies are inspected.

    Returns:
        Whatever ``get_current_user`` returns for the cookie's token.

    Raises:
        HTTPException: 401 "User is not authenticated" when the cookie is
            missing or empty. Otherwise, whatever ``get_current_user`` raises.
    """
    cookie = request.cookies.get(cookie_name)
    if not cookie:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="User is not authenticated"
        )
    # get_current_user is synchronous and ends in a blocking SQLModel query.
    # Run it in a worker thread so this coroutine does not stall the event loop.
    return await run_in_threadpool(get_current_user, cookie)