From 8bbfa8ccc82503808bf38345e7bf29b7091bf241 Mon Sep 17 00:00:00 2001 From: ericcccliu Date: Sun, 31 Mar 2024 02:27:20 -0500 Subject: [PATCH] implement rate limiting --- api/models/__pycache__/user.cpython-311.pyc | Bin 571 -> 808 bytes api/models/user.py | 6 +- .../__pycache__/auth_routes.cpython-311.pyc | Bin 6302 -> 6304 bytes .../conversation_routes.cpython-311.pyc | Bin 6703 -> 8500 bytes api/routes/auth_routes.py | 2 + api/routes/conversation_routes.py | 32 +++- .../__pycache__/auth_utils.cpython-311.pyc | Bin 1862 -> 2764 bytes .../conversation_utils.cpython-311.pyc | Bin 5930 -> 7934 bytes .../__pycache__/db_utils.cpython-311.pyc | Bin 1092 -> 1094 bytes .../llm_provider_info.cpython-311.pyc | Bin 0 -> 1511 bytes .../__pycache__/llm_providers.cpython-311.pyc | Bin 0 -> 1507 bytes .../__pycache__/llm_utils.cpython-311.pyc | Bin 3425 -> 3427 bytes api/utils/auth_utils.py | 24 ++- api/utils/conversation_utils.py | 42 ++++- api/utils/db_utils.py | 2 + api/utils/llm_provider_info.py | 16 ++ .../__pycache__/anthropic.cpython-311.pyc | Bin 2492 -> 2993 bytes .../__pycache__/openai.cpython-311.pyc | Bin 1430 -> 2477 bytes api/utils/llm_providers/anthropic.py | 13 +- api/utils/llm_providers/openai.py | 21 ++- api/utils/llm_utils.py | 2 + app/llm_providers.tsx | 8 +- app/page.tsx | 150 ++++++++++++++---- app/utils.ts | 18 ++- next.config.js | 2 + requirements.txt | 4 +- 26 files changed, 293 insertions(+), 49 deletions(-) create mode 100644 api/utils/__pycache__/llm_provider_info.cpython-311.pyc create mode 100644 api/utils/__pycache__/llm_providers.cpython-311.pyc create mode 100644 api/utils/llm_provider_info.py diff --git a/api/models/__pycache__/user.cpython-311.pyc b/api/models/__pycache__/user.cpython-311.pyc index d66d344168aafdbcfe7798bd732627f5a54d1eab..2573f7adbdf3b7f0d76ed53bba795cff8b0a6ff0 100644 GIT binary patch delta 509 zcmZusze~eV5Wd&Eq-kp_(iU4R4y^^_prcrdgNuV$2NwxQ2)?uy5-s^LbW+g4v7v+h z1uD3^_)nx;2neojEjT%OFQJ3zs2sLX-2($g zOt6Sb2N2fp=qnH()jTyYT59c}E9jjOQU3~1#6 zf!l;O!)opNU{nrh4Q)be;c7jg<&hguyz7$~k=UhqvyLK0GGD^&TSh!Ak_SVkU%34< zo6f83B4_cVGO};EZCx_YTsL;@3TlGxoFYsGLdXYHgWnd8XQa3-k;WAM!rjh{N5B3_ H#S?!3P%e1# delta 286 zcmZ3%wwr}-IWI340}$xk$<0qG x%}KQ@5(07=fwqepbfZ%^adRjC{&W3=AKb S85tQrurV<5Zf+2hV*&uzd=2gZ delta 76 zcmZ2rIM0xGIWI340}#|bV8`iw&rqTu zqYjDCAu)E>l}W3_pE0H~trF8TRr}F@V^TN0AWvmaL*0Hb_1BcDo3vGvwsY^Z9S6g9 z_1<~+p09h(x#yhwem?MvuRDIy*w}zzd~UNY{=<9I(aH}QZ}!akIV2*{oOR@!31`lg zaOK#3RXY2h_b!b+VD@0Gw59yu$$& z-ZD1Cqj4lUt{~AV+3pyC25*@bW;7Sj+;y}Fqxn5R^wtrhQpfF`#LasRNc53*Lu3&4 zHcG+mdoS{0lN1)4r5%e#vE_>CDSJC>l(bSML(fstCUuGJQj{qYM1P&4Zcx-itsyt~ zX0qGh+RKSlCLnQAV_U^rRF)6c$;Uwcg|ast|Id(+6e%kT*C~1t6z!@F%fNFyq9NhM zNS%V+H42y|&xqD76>_}z(@8A2Zab* zk{jkW@&jRhwv{gFsbD{Vm<8{k^8f`w3K+eE&O#JaIdPaSnX)LIbyh$d88#o$cj<&( zPLRKugT5`kZpueNkThE&I$t*#w*2mZ9gvzy%B3<{W{bWx-7sdvF)H=~<>1BBco$4$ zBY+SL93`XH4pT2I!z5$HT!j3|8s~NqTSF)+0QTGPelgD@bk%$Wap+y7RO=BT#XxR& zgJjU{=(=RwP5`m;5RIbgp5YFtc%LVI?q=U5lVYk65n4lQ{82sW;H!*oEd1CI+_-Jx! z^62F8RYQNgFUDzRQCdtDvWn)`dAG6biT9;z4jR!$yBFRsjsbXw(^Fxsh)#wsaS({I zod{jy&sN3G!jbFa9Gm3@;H-A+dGs#Fq3dB0L4k79kfzjm&6LZFlAgnwEh8rvv#BNd zOlBo!(D-P84Gf&9s=!rxZ;9@1cod2N;ca1JHaT-M1R3BvYapQp*>y zhceMB)@(WxqhXr1xQBv93d|I^0YC$z%>nmOK>G!b1JK&V;xl$gGS4)3jp8!OUXUeR z2|YdlDr8!2&~4~A$WBUg*BuW&3QS#;o2F3tK%RuVetcft8-GPT{+imI*yv7_0ES;L zg`{C5rNz%kV^v*DO2Ic8ML%oA_>ZQ%n)sc%g6j&G1*yP11J#_VD4F*#$8a(^b* z?FU%jUX78zSbNCd?eB3t?Lk;C9Q9-&|eK+J$?(_<&hP7vyYoy|1G(zDUbav`0D z{I2cg|7!G`W`KB*v^h>gFJE)~71XDl-6x(>FJKo{v~_{eRMKw_nlCYvzS}ItRx<(9 z;+JVtW+w@E$EVH?({Wf&5rgXr7rDRga&Z&1acK={ETt4lOr{iCx*1Vd{!O68cuem^ z&6<((i+P*_4jvCujFFbMOq&I}nLc$7ai_@+kJuK`#b{?|A%XDdR*@fi`~od~t%Yp5 zTLgRx=1KC2=Qw!3%X@X=DZdR{g|jWJTRyUU-MSvB>gFX$A;+vf(*mr{5ub0~K@D0J zdRaZCdD8C<4y>WGfYb-ft8EbGto>06ZmV70Hn2av3vOp|nS;M$?negbh5R&lH%^%+ z0ch@QN>=ouP8L=~h$*B(m5fhQ#smO}VbzIh2&n=mSYMWBZ4jpctH3MM*yw$46p}l< zyA(R~s3lrz>3wwM)zXoS8jx>&s}y+YlfdZ1!01L`yc8H$;mOOJW}{u$L^?1DtO-uX z`p6r;;fM?R8_r4!^xJZ)qQ$RSRxXOEWhIlwR`Q`Qz%`IZz6kk`uXUJZqUKGlWa0%S zla=GyY%a-WtSM>!?UWUqKc5ky!DN;pc|0WE)ZRkxWz9+X*zF<5X$QhSl5A=%Hc^_O zdydCo!G7oTUoQG9r5RW9vV!SNz|xYnAWGD-fBY&K21qAJ?$2WQ>O1%>gE X1GsD4nzI?z+zhgte1(OfSvLF^<7hU~ delta 2072 zcmZ`(U1%It6rMXfv)SDwoBjXEX0w~$X459E617c38Va^hHG$TM~d#T zutrie5Jd&!Q9-4mK19KQLVS`wC`g|aR0L(9YG1_1XhEn?o^xlDY!Y*q`R>m-_dDl) z_w4=R{AVMz7atf!LDwh*9OcIa`1)pfAafYCSK-#l3>Ll3p z6jUW5PtZj`Hz6So_nRJbBlN%}-?V?15Z?FY0)6C1&<;QNx?R6MxbORJv>6v2-7FiW zMVX5mngSLyLBfgw@=*D#7{r04A>^@%t~UZXeF$e+=nV)-d3^wGMDuVf;AewyJCG0P zXnqmDd;h?P@%Smumf5x4Qr=WIR4i@$-1fX?lYhM}E6W!Q=9{Oiqb-5H^uzQ6LgXH$b@Yc{2G?`G5fz%Q_ZZxUi{({MMKT*1rmBai4O zXdxh4ah!k>5g`IELcuE~5yP}VDv1*aPD{-)MX_vJmT}c|TIMQC&zQn8>{4aPqHUxe z=|{d{gYzBNqNvntGbYZWp&LKzJ%m5;cSC`b;W)bkkA}P8-B2(41zrh{x{4R-iu>SR z_$KRz^+?aq2<{UP5j>1g@CbVO+ylt8H8<@sS{;D9kwllwC0?s?!!~!)vD|#bIdmI> zRCJc*;koD^7-KeeZ2CY9uV^E-y|R{aQEj_Q>2AmgTG&-ejN{azJFhDOzs7Rv=vlA{}jhxvMTcxQAQ8dQHM`C8cAcBrd@2Tin`EKXnY!C>8G0VIGQXB!sqSd&A#fh+*qdgdo~dRvtnV_OrsH(5y0B@3Z^LH`4$vXhGd diff --git a/api/routes/auth_routes.py b/api/routes/auth_routes.py index 1d4bbc2..0a7982d 100644 --- a/api/routes/auth_routes.py +++ b/api/routes/auth_routes.py @@ -1,3 +1,5 @@ +#api/routes/auth_routes.py + import httpx from fastapi import APIRouter, Request, HTTPException, Depends from authlib.integrations.starlette_client import OAuth diff --git a/api/routes/conversation_routes.py b/api/routes/conversation_routes.py index 58caab0..cf55b0e 100644 --- a/api/routes/conversation_routes.py +++ b/api/routes/conversation_routes.py @@ -3,15 +3,19 @@ from typing import List from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse +from starlette.config import Config from pydantic import BaseModel from api.models.conversation import Message from api.utils.conversation_utils import update_conversation_messages, create_conversation, add_message, get_conversation_by_id, get_conversations_by_user, update_conversation_model from api.utils.llm_utils import generate_response_stream +from api.utils.llm_provider_info import LLM_PROVIDERS from api.utils.auth_utils import get_current_user from api.models.user import User from api.utils.db_utils import get_db + router = APIRouter() +config = Config('.env') class ConversationCreate(BaseModel): model_provider: str @@ -33,10 +37,25 @@ async def create_conversation_route(conversation_create: ConversationCreate, cur @router.post("/conversations/{conversation_id}/message") async def add_message_route(conversation_id: str, message_create: MessageCreate, current_user: User = Depends(get_current_user), db = Depends(get_db)): + daily_flagship_usage_limit = float(config('DAILY_FLAGSHIP_USAGE_LIMIT', default='1.0')) + daily_usage_limit = float(config('DAILY_USAGE_LIMIT', default='2.0')) + + conversation = await get_conversation_by_id(conversation_id, current_user.email) + + # Check if the user has exceeded their daily flagship limit + if current_user.daily_flagship_usage >= daily_flagship_usage_limit: + if any(p.model_name == conversation.model.name and p.is_flagship for p in LLM_PROVIDERS): + raise HTTPException(status_code=429, detail="Daily flagship usage limit exceeded") + + # Check if the user has exceeded their daily general limit + if current_user.daily_usage >= daily_usage_limit: + raise HTTPException(status_code=429, detail="Daily usage limit exceeded") + user_message = Message(role='user', content=message_create.message) + await add_message(conversation_id, user_message) - conversation = await get_conversation_by_id(conversation_id, current_user.email) + if conversation: return StreamingResponse(generate_response_stream(conversation), media_type="text/event-stream") else: @@ -91,4 +110,13 @@ async def get_conversation_route(conversation_id: str, current_user: User = Depe if conversation: return conversation else: - raise HTTPException(status_code=404, detail="Conversation not found") \ No newline at end of file + raise HTTPException(status_code=404, detail="Conversation not found") + + +@router.get("/usage") +async def get_usage_route(current_user: User = Depends(get_current_user), db = Depends(get_db)): + return { + "daily_flagship_usage": current_user.daily_flagship_usage, + "daily_usage": current_user.daily_usage, + "last_usage_update": current_user.last_usage_update.isoformat() if current_user.last_usage_update else None, + } \ No newline at end of file diff --git a/api/utils/__pycache__/auth_utils.cpython-311.pyc b/api/utils/__pycache__/auth_utils.cpython-311.pyc index 18962857991145a5090f02c2d8f69186fe8165e3..6d88f7bfb5133c7e4276a6aab55fb8963ce9c9f7 100644 GIT binary patch delta 1318 zcmah}&1>UE6d&2LBwLm&e`T{K8=Pd-jlgQMg?^-2df9I2LN~i?5<*>Cqk81TsT`S+ zOo?%DwjbbyvNUK9dnnY0lD2FQg(p7!WMw(HlFIu{da9BVI`aIOi#(tq@Y1_J= z8W|-++qQPf)u(9O&@6kT7iv{f!#XFwt;!WuhTj4ug;-4GGf*#@P)CCEWpJ+P)75>u z?sKHtb#qTgq&9wV&-~D*34v*Tp+&ZFo32B%2_t55)Q;*QwbxI zq|a8heZ$lsnj%&JzVAylOGj3rXc~rIw9BURTQHIFBgIk11sqApHlem}kyk03br`xd zrDB!PS-;{B;4$upU=WyhKBDhYo)Zt{ntw82Ds%tc_LXjW*-bCEA}7-H{pf>yS6XzX z#lEEWB(*CoxzbW!TIoqEU1`;oR$C#DkGcF@hgZJYaMi1Q^+r#<(N%A{>dnsWt)9Bo zQMXPL^ZkU{OQ>Fa`hnb$vR?e0w{X$Rsh;xufTc5u0cZ{-9Z&uqVw2(@0E+=&h5r)? zkezG2?nY<4MTN}riLAQvOW)Ui-0jNOUHN)nUhBzgUHO(P-y#n4r?F&<|LcTL4geL} zpvKbK2wf7c<*r=@KV6=ALkKEq=Wgh5b680&Cp*%uV9n>dn|@>HkhX3feW3<#{>M~u%Cuwt*OtM fyR)Co9x|Tu3$y&<0!2b4vb`{x|68J$@lpQ(_p(J9 delta 469 zcmX>jdW?^6IWI340}xDzVM!BanaC%>xM!kzXgw1{DsKw=GDZf5)j$jZQG7sAmK2UP zoM__wKykhlE;LaApeSQ1KdNXNV+wbwU>Z{jPYY|5PzrA_gC^fgknx&Kx7b}$3sUn^ zif{3HOy*|{oSe&O#mx$oY6jxZEsUEtF_GLm*iurH^HWlbcqVUURTM8`16jdd#0!#QFG#G+$xlov5}jWU?3w*Eg(tt_Xqa5WDPhb5)Ow4#xTL5^0mv#60}`Ksn4w4>Nc`fk$<0qG z%}KQ@(gt!tAzS>KVe)>?dcFso{2gKw*gEBUOit%AFco49Vf?@VA{Qt_X#8Z6 H7SKEZH@Is8 diff --git a/api/utils/__pycache__/conversation_utils.cpython-311.pyc b/api/utils/__pycache__/conversation_utils.cpython-311.pyc index ea543244c446664da54d1d52b81492f180591b64..5a0769ac016f2ff311d627e1e9efc4280b4649ab 100644 GIT binary patch delta 3101 zcma)8U2Gf25#A;5$UE|m|FZt9=*Maz*0yXnQS2C&{1+>=E&s|t0FjoVc}Isbf8@QR zQwt}MXk8>}gEq#cuzGVR^8D?%N0 zV&YW#?D;d7j=y^B{G@%H9;2rhenvl}dqOxzU@BA$+m2W)YkI*6WSI`jXz7fZ3^E@M z>4+`sth7-(aTDqtkM&_!{z73owfw82|)xLdg#u^0H3p@jDL>!33J}GyKF9D-xAukVm z#!wsEi3HU=AW6~Pl0C@cin_5S1A!rKCEp@fh5h$n{}8us%T?V>XUnA-c$fZ?{K~T#&kw2Pv9i(>&&en`BN8$zTV6|=!r0=4^8%J^1!kZ>8`MP2X9e># zPiGa7!&A!(3tB--(ZqE_H2(n@vHbHY+8VHk<#OvqCWSKX+^V+Bi8lVK_4t{2AD--hmilOB3)swVhFG6|OnWZLlQnZU)#T}Snj z$ub2;tQc^dvui#BwdmUy;rskMKz>0_5=!n6i+nFSOD#GDJh!~-aj|`@TJ#dRLn(PL zh8iBw(-7_GmZX;*gUX@RY2G(*<-WcePnOT*16TEu{)1^YHZ!<#;t)%^g4uKUK%tz? z6pXRqHK{jYvat)um!FZ(NR8Ax&cc80`X}c;ef`g=#gQ}hku!5qAnSiF!M2eqxG;RN zbgfi=tJKG@i2Z#t!eG9Xb%I(ZU%2TibrNnTSIAs7-po(CT)31(y$DiHw2(1Omzj1E z^9;n!11N?E9t=Kxl-Iy+XoQ zn^;zelxX^Uarq;(wfFEl%KMj^($;#hYU8LO)bwwir;Gmny1)NPsAnlOv=|y{sfqWe>Pml0jnAj*m$S{w z<@)g(kGdXp)hjoa2|4YV^gy-c!a7widRp=JWj_fgpAsP)_<{&W!+~WY$p1ePi8KH( z;PgSLodS~Vw(lqQRJ$|N`QeJ0FBroc!z!IGLF8k+JMAdJPpL7y@wvh9D&o%+{+e@y zD@M8GUiSVRoJ^R;FzqtD^ca8B3}6c%VhrCy_)8tVwtOWQ7Ek!HD|quc+q1o0sjI)f;S6(yf?9>)rI-(oA=(l_nY_L{F>a(o12Ca zRN%YuMR(tu8)lEEt!(7ZuGRFSUJDchM4<}xRfDxqF(hTJ8m<{d17yErQXN(RR?snB zorc!cViyg=-f;Y`N=?{DoLSi$~-_^Ap)5h5=r-%EItthIruZ*8(W& z;)QoA($V7HMCcOVy;%|$33bCa0E$rarMg8(R9w?WFJh?}=7ng-P##3+Lr4K|6(9p2 z%Ye`dl_@RN9L~#24zodH!Qzj_E7cZl%_Z69b^jfbih_2_!-nA3p@`D1I>?;{u>Gpf zYBFt3Oa_wILSVDCPV$lc6Z;HUKj^pyfHd}qSj?%HYfkCra{ac$c)3wouCvqNltgcf zh}UZF`A)zyg;dxH1mv)#*gl4_*cgIjaTFyT08Fw1vV5KV!U_)zXTa9#UUC|TvXvI^ zmMWB;17jW*-EI!5- z(q~|Alg@;BTAuuv959aLv6o`pnkGeY-?FdZ#TRB^3}t@~UubVWrZM_I<@|3x#5XG& zoB_{ySo|IIWkh_A92b>14`33T%-P-Ick-0D+Sj}0$B@{5hA*h4E^?S8Swz2E=8f`7 zh2exUIVSjghWIA2~p(iYfT5cNbKZ4*M;?XCFpu320_j!bVg%P$EATlC6Nf1nuBt_EqQ-H$= zTYzQ&oguT(vj?C#k|%SdK#B))=K)e;>h9asC-03F=t0hS8 z!O5WtwRpnG3Btojz`!!x1V&0uc5)!nW>mwnC`#fY^l=c&l#nSXKWQ-ZKeiAYqXDP^%QAQA4k6SCJC@Y|Bo-zd~I|t1$ zGXtia_OT^vomW>`j8qES^Q+Z$OQu#;QyrOBFBy{4lXcswN`|J%id^lA(&n<-GeyY` zV#TzpqN!FLTh)N2^@Qk61y%1DOXeLvBS5Iyg5WN`AjeWP%dT&o2eXC7ub8@l}33ih(kn{ZAC`%rCB3;pQ0!=BBKl{&=wS=kJf3AG?PoO>~G{ zX;ci;u`2aieYsYzU+`x{-R@GuR3$ebmXsPLR-!!ZKG}Zo*R$p{z!%-4@fwz)>#`lM z;ud0Rh$!L4drkL9%;}2irV~;cwe#*=ObLLsG9M`5=i!GDu3Y=##%Gs4x^|r&V$#G) zJ`us(&#yYGANeL6Vn-6URTbBMJO)+X81JinD^J-FRJC{C5h literal 0 HcmV?d00001 diff --git a/api/utils/__pycache__/llm_providers.cpython-311.pyc b/api/utils/__pycache__/llm_providers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..392871e42ac168e7910b2c63041efcc04756eb12 GIT binary patch literal 1507 zcmaJG;&kp;5RVtA|8E&2)cs^caKNtS6Ja_0U{&9lLWyeNm3+zF9kTP za0F-u&>1odGkXY{BY84M3Z!@_cOK9M@(RquxQl=uISf^jEsh+$j>^ZR=qgOdRV{&Y z7j6z&sKpabP7oee0v48`5?CoY*~yXg9o$Nh4D2-HXIEF(u2SQcN@VKK_6#Cxf-d%C ze=fAm2+I-GGBxpbh!$1dbZo&kwq#w949oWC4aa^dlvOREqlt>uRZS`V+@0A|LvzqU zE4Xr>>mwnC`#gADOQu#;QyrOBuNacklXcswN`|J%id^lA($=!tGeyY` zY{j&zqN!FLTh)N2^@Qk6ttIoepAiIAS8YLX7oTgzQWRySTNgoWqw%X}ckqve@tN~G zc$7IgK6Q@i6XQ~i>C@xa>P(**pM0I^<6iI`02Qr6nQ#v4t04HZf&jgBG-l@n;TuQP zrbxoFsh{jMp=TR+cNZx9b-wCY5ig^0$7=|UyQ2=;LEj03SL>!&~O@Ba12!{&$XQAraWB3Bv} z!*r}ly;fhY)$5o18Bw>p)G$@a&4($aMv0{;kGqd{?*ILyIW_QQ_jtUAW$3zW$GfMEVc<_GHJr;AiqPpdTlt%5MI~P*|V7<%-9q{w;!w6TeefGsCS3bCQogQJWi97W) z^Iq8HCOrW_(5~OXm?SukoA7B4-uu7eF^fI<7g$8e;`WMHJiFW7&p-HXu;P8SIm9eW6uJ_KzgY({(#OrJgu{OfmB%7M&xASZU)jb2s0hBEO literal 0 HcmV?d00001 diff --git a/api/utils/__pycache__/llm_utils.cpython-311.pyc b/api/utils/__pycache__/llm_utils.cpython-311.pyc index b936cb2b8878e316108841bbc4f55c1801ccbda6..8f97935c95bd2d21bc0a289ee5d75d0a7d5bfd3d 100644 GIT binary patch delta 62 zcmaDT^;n8`IWI340}zyb=SVZ%$ZO2X$hFy#m5qr}aae OM#dWqe3KvWGy?#WnGJ9N diff --git a/api/utils/auth_utils.py b/api/utils/auth_utils.py index e4f3977..7b21847 100644 --- a/api/utils/auth_utils.py +++ b/api/utils/auth_utils.py @@ -4,9 +4,10 @@ from api.utils.db_utils import get_db from api.models.user import User from starlette.config import Config +from datetime import datetime +import pytz config = Config('.env') - oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") SECRET_KEY = config('SECRET_KEY') @@ -20,12 +21,29 @@ async def get_current_user(token: str = Depends(oauth2_scheme)): db = await get_db() users_collection = db["users"] user_dict = await users_collection.find_one({"email": email}) - if user_dict is None: raise HTTPException(status_code=401, detail="User not found") user = User(**user_dict) - return user + # Check if the day has changed and reset the API usage if necessary + central_tz = pytz.timezone('US/Central') + today = datetime.now(central_tz).date().isoformat() + if user.last_usage_update is None or user.last_usage_update.isoformat() < today: + # Reset the daily usage and flagship usage + await users_collection.update_one( + {"email": email}, + {"$set": { + "daily_usage": 0, + "daily_flagship_usage": 0, + "last_usage_update": today + }} + ) + # Update the user object with the reset values + user.daily_usage = 0 + user.daily_flagship_usage = 0 + user.last_usage_update = datetime.strptime(today, "%Y-%m-%d").date() + + return user except JWTError: raise HTTPException(status_code=401, detail="Invalid authentication token") \ No newline at end of file diff --git a/api/utils/conversation_utils.py b/api/utils/conversation_utils.py index e5f3245..95e3d4d 100644 --- a/api/utils/conversation_utils.py +++ b/api/utils/conversation_utils.py @@ -4,9 +4,11 @@ from typing import List from api.utils.db_utils import get_db +from api.utils.llm_provider_info import LLM_PROVIDERS from api.models.conversation import Message, Conversation, LanguageModel from bson import ObjectId from fastapi import HTTPException +import pytz async def create_conversation(user_email: str, name: str, model_provider: str, model_name: str): db = await get_db() @@ -91,4 +93,42 @@ async def update_conversation_messages(conversation_id: str, updated_messages: L ) return True else: - return False \ No newline at end of file + return False + +async def update_user_usage(user_email: str, model_name: str, input_tokens: int, output_tokens: int): + db = await get_db() + central_tz = pytz.timezone('US/Central') + today = datetime.now(central_tz).date().isoformat() + user_collection = db['users'] + + llm_provider = next((p for p in LLM_PROVIDERS if p.model_name == model_name), None) + if not llm_provider: + raise ValueError(f"Unknown model: {model_name}") + + input_cost = input_tokens * llm_provider.input_token_cost + output_cost = output_tokens * llm_provider.output_token_cost + total_cost = input_cost + output_cost + + update_query = { + '$inc': { + 'daily_usage': total_cost, + } + } + + if llm_provider.is_flagship: + update_query['$inc']['daily_flagship_usage'] = total_cost + + # Update the user's daily usage + result = await user_collection.update_one( + {'email': user_email, 'last_usage_update': today}, + update_query, + ) + + # If no document was updated, it means it's a new day or the user doesn't exist + if result.modified_count == 0: + # Set the daily usage to the current cost and update the last usage update date + await user_collection.update_one( + {'email': user_email}, + {'$set': {'daily_usage': total_cost, 'daily_flagship_usage': total_cost if llm_provider.is_flagship else 0, 'last_usage_update': today}}, + upsert=True, + ) \ No newline at end of file diff --git a/api/utils/db_utils.py b/api/utils/db_utils.py index 5d4c0a2..1573f31 100644 --- a/api/utils/db_utils.py +++ b/api/utils/db_utils.py @@ -1,3 +1,5 @@ +#api/utils/db_utils.py + from motor.motor_asyncio import AsyncIOMotorClient from starlette.config import Config diff --git a/api/utils/llm_provider_info.py b/api/utils/llm_provider_info.py new file mode 100644 index 0000000..2d17f5e --- /dev/null +++ b/api/utils/llm_provider_info.py @@ -0,0 +1,16 @@ +class LLMProvider: + def __init__(self, model_name: str, model_provider: str, display_name: str, input_token_cost: float, output_token_cost: float, is_flagship: bool): + self.model_name = model_name + self.model_provider = model_provider + self.display_name = display_name + self.input_token_cost = input_token_cost + self.output_token_cost = output_token_cost + self.is_flagship = is_flagship + +LLM_PROVIDERS = [ + LLMProvider("gpt-4-0125-preview", "openai", "gpt-4 turbo", 0.03 / 1000, 0.06 / 1000, True), + LLMProvider("gpt-3.5-turbo-0125", "openai", "gpt-3.5 turbo", 0.002 / 1000, 0.002 / 1000, False), + LLMProvider("claude-3-opus-20240229", "anthropic", "claude 3 opus", 0.02 / 1000, 0.04 / 1000, True), + LLMProvider("claude-3-sonnet-20240229", "anthropic", "claude 3 sonnet", 0.001 / 1000, 0.001 / 1000, False), + LLMProvider("claude-3-haiku-20240307", "anthropic", "claude 3 haiku", 0.001 / 1000, 0.001 / 1000, False), +] \ No newline at end of file diff --git a/api/utils/llm_providers/__pycache__/anthropic.cpython-311.pyc b/api/utils/llm_providers/__pycache__/anthropic.cpython-311.pyc index ee9332f24e6cabaa13a683ae22a52b6c3429572e..f82d6133a328b740f2e2b766a90ca189ce1adaed 100644 GIT binary patch delta 1192 zcmaJ>OK1~87@pa89@{o)8`~ze3HFtuB^oMPUnojZQ?Uq&u^1P3l+ByjP4SUZ4MOJJa|$>LB)d?FW$U}X$ctiDtho#4}yAfX8SNb=jye>S306e2}btEN5f9%S1Rh7g;FIGVJft%sP<-(dj!chVNB7vd5%t=^uuE! z-T%2rZivAJp)BkIaiD{n{8{jTzaz|xWic(TpG%iA9MiF%qS88=mNztIU7P2sgPS7d zwMdnsfBjt>{(N@&F zz;k7uUaL=m^hT5!7x@G)Ko5y?FQAKj<_>e1BN2XQznViPBPeBB8n*0>rH-3@#O)ef$hvNjm}*QU-e>Ex#*#bTaHs3SKVOq| z-#b-0GW>c;KyZw8NR)TDutn3tMiF}{gxeSctY^bv>|=nRT$18&h}@HAU?=%19jE;e zlW$Bpm6XiuriF4&t#N?4v(TutR(!2kM0p(tnIXgkqaJT#E<7G$dN=tl2cVBAN?43h zV36!k9?P9nI~8>kTNzD;w4yOyvW%=b-df?51FbfKL=9JrtE>C4GGdnMDZng`bBo`~ t{#-#XWzj*rn_}$0VcsQsUEBRlFj?RH5=7xXn1L&xMeZSwTs;Cy_y>8*BKrUU delta 797 zcmaJ<&1(}u6raiNOg7mj=F2uoNy#CgYwJQSD2NE5_;o23KML_sVs@%YlTDaigK1KF z5%EwYVGjKRgos!a@#4XYpf?Z2lkCZZ7X_ggFCN4e zr;$h)!TF7XYbi2_rg3d>EVmUvOtF+9M(8=*(J85bVz{cM9?tBN!*PjN^Z_bSD+o4d zY1fbu`X)&X>_v-C$F3A@_bzX?sue)!V$)u$-q}!1X1Bd~aq0Ty!B-Q15kopFF5(L$ zU@5z*=;P@dfnMMgT0j=+DmT%N@=)DWdPsK}3E*lo?v&kX)AlL&q--Iq4K#cLd`Iq(mxsq> zN=%cSo(E6P?tDaIkr?D>7^SkJORmlf}d2RwG->%aY8GTNy V=yS6%46^+-g(JAIkN!12+aEzQv9tgH diff --git a/api/utils/llm_providers/__pycache__/openai.cpython-311.pyc b/api/utils/llm_providers/__pycache__/openai.cpython-311.pyc index 46de2b9e4e9919bd9856c2071fe73e07d374bbad..ee105d9f755c08e7d973d430fcac02d50ccc8444 100644 GIT binary patch literal 2477 zcma)7U2GFa5Z?Ro`JBXdVkfafen?BJ#t>}&;{K&c(?9`115zoZZY#^Z+c@W(eb(I> zL9iuDMXjk&)k+{D)QUW%LE(XyzE^$X(MBh6I;j#;Rf)bCrAQS|ojv~}QgyvMH?upl zJ3I5u`e(Ph3qf-{<(Gb85&DA&jpnE`+n0d3g+wG$GK!HeEmLVaMpJ}mWG2nVSes{M zN1BUq6rvGkMaKq(&=!nZi#cJIhT#;s&Xs%4PC3Sl{8bcliPXm^h6egz4=!kkC9n6KFalSBnbV!@GT4xz*CD_ovS*{~s}h>oG=VX5Y=QcX__ z&b+foiyVB}b$3&>jmf`|ehTFs`e;kx8o%ysNiFV0yUu8>!dc+E^8$kVP!ziUWfbYJ zv{|E6UqjjY)=RyaRNX__RY)=BYg^~cFFZv2po@xOJ_$K&7 z^dtQPeccBeQFm#JRTUXqPC?1&IHRj1#v$6G!v`#98fygD8o}B4`15m-PXO#%M3!<` zjT}{kTpBF2h@i+aPU6UNLRdK>rLze^xA|mNj{x)|IbD*$N>0bKs&Y{h$sP(+R6+`8 zFF)^^mL*LWlyr7xb+R@cPvQ(#Np-1M%PJWSho@ylNXXg@Fb!>zZS4qN8kegJkAR&^IP;|}MA#{yiZ<9c^ z-T>OD1m4xoRutrEysDCARvgyD5*MORkfg1*LO z;!BDex2@FxQDHNQG`1Kmm$o=TmH=25D=a5;tD8VW#<~PJfF;;nCvaU>Bmrxd1CG}d zb`UKtkp)zVD!HQCPY9_DtIK|(aXKN%A(zzyK)(nGOppz0-t|bCY)+5c4r(e%qSb3- zTF|i=7nXCG70q%hIlZN2IqYfsTy+2}H1Z0L*W3GchOFRDX;Dd`l4nc10$8diqDnUa z`Z%!12uxNTAU`F*{+p2EiDx~iH)weGZcamfjC~(eyz_>4zIfty|6qG=9$W&cGPq~} z^)DL!#p3Bd{X;k2HT-)Y`42wyA2dC?49~dXdEM+AGQIw)6ZIc@ir9WQ2S^pMo}EeY z!)gx-?kcnWX8%C(BQr2sMVw>szex$3zQB$B>y+Vp)$mOiz5^wv+3hRMnuEd8agz@| z;&(sfcdt7u{Di?zl=%shANa|?wsnCOpBq!hjj7WQbM@{{s4ZQ>p*} delta 669 zcmaJ<&1=*^6o2!PuU!-R;p$e)(z+EI5f7Gn5D}K1dQcC7g^1viN$tjLQYPI6af>a8 zc<|6NU{9harC9$6Z(e(o+&p@6!J8-FL?RTyUznHQoA-V%@6F`X^y|9!!F6kZ*!^a& z=XBp|J`?C2fB-=f=#zjD&1pjWEMNpEL`<;v1mGjCKL|MP5?oyHlaV9iq#qca^P7F$)O;4>A4*sf^oRBW4% z_zo3@ey;&mz%zQacn#~}RHj;C_lcw4vO{94Ro+owj0LkMr% z+{k)iloy7GlC7?+W5_p(Q(=}4Bbj%%dRba3<+K(V;l?0+TxR8Z!Z{7;=}=S6@}?(S z$n_lgF&Z^#E%F+iKBorOtqZL&aO2GXjQU~Cx4p(ow|d47G+TSaywBonQ8iekKM([]); const [loading, setLoading] = useState(false); + const [userUsage, setUserUsage] = useState<{ + daily_flagship_usage: number; + daily_usage: number; + } | null>(null); + + useEffect(() => { + const fetchUserUsage = async () => { + try { + const response = await fetch(`${process.env.BACKEND_URL}/usage`, { + headers: { + Authorization: `Bearer ${cookies.token}`, + }, + mode: "cors", + credentials: "include", + }); + + if (response.ok) { + const data = await response.json(); + setUserUsage(data); + } else { + console.error("Failed to fetch user usage"); + } + } catch (error) { + console.error("Error fetching user usage:", error); + } + }; + + if (cookies.token) { + fetchUserUsage(); + } + }, [cookies.token]); const handleChange = (event: React.ChangeEvent) => { setTextValue(event.target.value); @@ -85,18 +116,18 @@ export default function Home() { } }; - const handleEditMessageWrapper = async (index: number) => { - if (conversationId) { - await handleEditMessage( - index, - messages, - conversationId, - setMessages, - setTextValue, - cookies - ); - } - }; + const handleEditMessageWrapper = async (index: number) => { + if (conversationId) { + await handleEditMessage( + index, + messages, + conversationId, + setMessages, + setTextValue, + cookies + ); + } + }; useEffect(() => { const token = searchParams.get("token"); @@ -232,27 +263,76 @@ export default function Home() { py={4} > - + + + {userUsage && ( + <> + {LLMProviders.find( + (model) => model.model_name === selectedModel + )?.isFlagship && + userUsage.daily_flagship_usage > + parseFloat(process.env.DAILY_FLAGSHIP_USAGE_LIMIT || "0") / + 2 && ( + + parseFloat( + process.env.DAILY_FLAGSHIP_USAGE_LIMIT || "0" + ) + ? "red.700" + : "black" + } + > + ${userUsage.daily_flagship_usage.toFixed(2)} /{" "} + {process.env.DAILY_FLAGSHIP_USAGE_LIMIT} + + )} + {!LLMProviders.find( + (model) => model.model_name === selectedModel + )?.isFlagship && + userUsage.daily_usage > + (parseFloat(process.env.DAILY_USAGE_LIMIT || "0") * 3) / + 4 && ( + + parseFloat(process.env.DAILY_USAGE_LIMIT || "0") + ? "red.700" + : "black" + } + > + ${userUsage.daily_usage.toFixed(2)} /{" "} + {process.env.DAILY_USAGE_LIMIT} + + )} + + )} + prevMessages.length === 0 ? [...prevMessages, newMessage] : [...prevMessages]); + setMessages((prevMessages) => + prevMessages.length === 0 + ? [...prevMessages, newMessage] + : [...prevMessages] + ); await handleStreamingResponse(messageResponse, setMessages); + } else if (response.status === 429) { + // Handle 429 error + setMessages((prevMessages) => prevMessages.slice(0, -1)); // Remove the last message + setTextValue(newMessage.content); // Put the message back in the input + // Display an error message or alert to the user + console.error("API usage limit exceeded"); } } } else { @@ -177,6 +187,12 @@ export const handleSendMessage = async ( if (response.ok) { await handleStreamingResponse(response, setMessages); + } else if (response.status === 429) { + // Handle 429 error + setMessages((prevMessages) => prevMessages.slice(0, -1)); // Remove the last message + setTextValue(newMessage.content); // Put the message back in the input + // Display an error message or alert to the user + console.error("API usage limit exceeded"); } } } diff --git a/next.config.js b/next.config.js index e25a4ea..1a296b4 100644 --- a/next.config.js +++ b/next.config.js @@ -27,6 +27,8 @@ const nextConfig = { }, env: { BACKEND_URL: process.env.BACKEND_URL, + DAILY_FLAGSHIP_USAGE_LIMIT: process.env.DAILY_FLAGSHIP_USAGE_LIMIT, + DAILY_USAGE_LIMIT: process.env.DAILY_USAGE_LIMIT }, experimental: { outputFileTracingExcludes: { diff --git a/requirements.txt b/requirements.txt index 11cd95f..bf9ffca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ anthropic python-multipart itsdangerous fastapi-cors -motor \ No newline at end of file +motor +tiktoken +pytz \ No newline at end of file