-
Notifications
You must be signed in to change notification settings - Fork 0
/
harbor.py
215 lines (184 loc) · 8.62 KB
/
harbor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import imghdr
from starlette.applications import Starlette
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse, PlainTextResponse, RedirectResponse
from starlette.config import Config
import json
import requests
from uvicorn import run as uvi_run
from gino import Gino
from sqlalchemy import and_
# Database Management
# Single Gino instance; the ORM model below and the startup/shutdown hooks use it.
db = Gino()


class Model(db.Model):
    """ORM mapping for the ``models`` table.

    One row per servable model: display metadata (title, desc, category),
    the published ``versions``, the parameter spec, two usage counters and
    the name used to address the model on the Clipper server.
    """

    __tablename__ = 'models'

    id = db.Column(db.Integer(), primary_key=True)
    title = db.Column(db.String())
    desc = db.Column(db.String())
    versions = db.Column(db.ARRAY(db.Float()))
    category = db.Column(db.String())
    params = db.Column(db.ARRAY(db.JSON()))
    views = db.Column(db.Integer(), default=0)  # view counter
    requests = db.Column(db.Integer(), default=0)  # query counter
    output_type = db.Column(db.String())
    output_attr = db.Column(db.JSON())
    clipper_model_name = db.Column(db.String())

    def __repr__(self):
        return str(self)

    def __str__(self):
        # Human-readable dump of a subset of the columns, one per line.
        pairs = (
            ("id", self.id),
            ("title", self.title),
            ("desc", self.desc),
            ("versions", self.versions),
            ("category", self.category),
            ("params", self.params),
            ("views", self.views),
            ("requests", self.requests),
        )
        body = ",\n".join(" {0}: {1}".format(label, value) for label, value in pairs)
        return "{\n" + body + "\n}"

    def to_json(self):
        """Return the row as a plain dict keyed by column name."""
        keys = (
            "id", "title", "desc", "versions", "output_type",
            "clipper_model_name", "output_attr", "category",
            "params", "views", "requests",
        )
        return {key: getattr(self, key) for key in keys}
# Lookup of ORM model classes by name (currently only ``Model``).
models = dict(models=Model)

# Routing and Backend Logic
app = Starlette()
# Permissive CORS: any origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Value of CLIPPER_URL from harbor.env; populated by the startup hook.
CLIPPER_URL = None
@app.on_event("startup")
async def startup():
    """Load settings from ``harbor.env`` and bind the Gino engine.

    Required settings: DB_USER, PASSWORD, DB_NAME, CLIPPER_URL.
    Optional setting: DB_HOST (defaults to ``localhost``).

    Stores CLIPPER_URL in the module-level global, binds ``db`` to the
    PostgreSQL database and creates any missing tables (``checkfirst=True``).
    The database itself must already exist (e.g. ``$ createdb gino``).

    This is a coroutine because Gino's ``set_bind`` and ``create_all`` are
    awaitable. Running them in the startup hook binds the engine on the
    server's event loop before any request handler needs it.
    """
    from urllib.parse import quote

    global CLIPPER_URL
    harbor_config = Config('harbor.env')
    db_user = harbor_config('DB_USER')
    password = harbor_config('PASSWORD')
    db_name = harbor_config('DB_NAME')
    CLIPPER_URL = harbor_config('CLIPPER_URL')
    db_host = harbor_config('DB_HOST', default='localhost')
    # Percent-encode the credentials so characters such as '@', ':' or '/'
    # in a password cannot corrupt the URL. SQLAlchemy decodes them again
    # when it parses the URL. Alphanumeric values pass through unchanged.
    dsn = 'postgresql://{0}:{1}@{2}/{3}'.format(
        quote(db_user, safe=''), quote(password, safe=''), db_host, db_name)
    await db.set_bind(dsn)
    await db.gino.create_all(checkfirst=True)
    # NOTE: whether DB_USER/PASSWORD are needed depends on the local
    # PostgreSQL auth setup (OS-dependent). Consider documenting a dedicated
    # "harbor" database role as a prerequisite.
@app.on_event("shutdown")
async def shutdown():
    """Detach the engine bound to ``db`` and close it."""
    engine = db.pop_bind()
    await engine.close()
@app.route('/')
async def homepage(request):
    """Root endpoint: reply with a fixed plain-text greeting."""
    greeting = "Hello, world!"
    return PlainTextResponse(greeting)
@app.route('/models', methods=["GET"])
@app.route('/models/', methods=["GET"])
async def get_models(request):
    """List models, optionally filtered by title and/or category.

    ``name`` and ``category`` may each be given several times
    (``?name=a&name=b``). Every occurrence may also hold several values
    separated by a literal '+'. Standard query-string decoding turns an
    unescaped '+' into a space, so send '%2B' to separate values.
    ``?name=a+b`` is therefore the single title "a b".

    Every value must be alphanumeric apart from whitespace, else 400.
    When both filters are present they are ANDed: a model must match one of
    the names AND one of the categories.
    """
    def _split_param(key):
        # All occurrences of the parameter, each split on '+'.
        return [part
                for value in request.query_params.getlist(key)
                for part in value.split('+')]

    names = _split_param('name')
    categories = _split_param('category')
    if not alpha_num_validator(names + categories):
        return PlainTextResponse("400 Bad Request\nName and category values must be alphanumeric.", status_code=400)

    conditions = []
    if names:
        conditions.append(Model.title.in_(names))
    if categories:
        conditions.append(Model.category.in_(categories))
    query = Model.query
    if conditions:
        query = query.where(and_(*conditions))
    found = await query.gino.all()
    return JSONResponse({"models": [model.to_json() for model in found]})
@app.route('/models/popular', methods=["GET"])
@app.route('/models/popular/', methods=["GET"])
async def get_popular(request):
    """Return a page of models ranked by popularity, most popular first.

    Optional query params:
      count      -- page size (default 5)
      start_rank -- number of top-ranked models to skip (default 0)
      metric     -- 'views' ranks by views; any other value ranks by requests

    A non-integer or negative count/start_rank yields a 400. Without this
    check they would cause an unhandled ValueError or a PostgreSQL error,
    since LIMIT/OFFSET must not be negative.
    """
    bad_paging = "400 Bad Request\ncount and start_rank must be non-negative integers."
    try:
        count = int(request.query_params.get('count', 5))
        start_rank = int(request.query_params.get('start_rank', 0))
    except ValueError:
        return PlainTextResponse(bad_paging, status_code=400)
    if count < 0 or start_rank < 0:
        return PlainTextResponse(bad_paging, status_code=400)
    metric = Model.views if request.query_params.get('metric', '') == 'views' else Model.requests
    ranked = await Model.query.order_by(metric.desc()).offset(start_rank).limit(count).gino.all()
    return JSONResponse({"models": [model.to_json() for model in ranked]})
@app.route('/model')
@app.route('/model/')
async def forgot_id(request):
    """Handle /model requests that omit the model ID: always 404."""
    message = "404 Not Found\nMust provide model ID"
    return PlainTextResponse(message, status_code=404)
@app.route('/model/{id:int}', methods=["GET"])
@app.route('/model/{id:int}/', methods=["GET"])
async def get_model(request):
    """Return a single model by ID and increment its view counter.

    The ``{id:int}`` route convertor already delivers an ``int``, so no
    further parsing is needed.

    Responds with a plain-text 400 when no row has that ID. Otherwise it
    returns ``{"<id>": <model JSON>}``; the JSON encoder stringifies the int key.
    """
    model_id = request.path_params["id"]
    model = await Model.get(model_id)
    if model is None:
        return PlainTextResponse("400 Bad Request\nModel with given ID not found", status_code=400)
    # NOTE(review): read-modify-write; concurrent requests can lose increments.
    await model.update(views=model.views + 1).apply()
    return JSONResponse({model_id: model.to_json()})
# Seconds to wait for Clipper before treating a prediction request as failed.
CLIPPER_TIMEOUT_SECONDS = 30


def _decode_image_bytes(img):
    """Return the raw bytes of an image value taken from a JSON query, or None.

    JSON cannot carry raw bytes, so a ``str`` is assumed to be base64
    (presumably the encoding Clipper expects for byte inputs -- confirm
    against the front end). ``bytes``/``bytearray`` pass through unchanged.
    Any other type, or a string that is not decodable base64, yields None.
    """
    import base64

    if isinstance(img, (bytes, bytearray)):
        return bytes(img)
    if isinstance(img, str):
        try:
            return base64.b64decode(img)
        except ValueError:  # binascii.Error (bad padding) subclasses ValueError
            return None
    return None


@app.route('/query', methods=["POST", "OPTIONS"])
@app.route('/query/', methods=["POST", "OPTIONS"])
async def query_clipper(request):
    """Forward a prediction query for one model to Clipper.

    Body: a JSON object with ``id`` (model ID), ``version`` (required but
    not otherwise used here) and ``query``. ``query`` is JSON-encoded and
    posted verbatim to the model's Clipper ``/predict`` endpoint.

    Responses:
      * Plain-text 400 for a malformed body, non-integer ``id``, unknown
        model, or a non-JPEG ``img`` field.
      * JSON ``{"error", "req_json"}`` (status 200) when Clipper is
        unreachable, times out, or returns non-JSON.
      * Otherwise JSON ``{"model", "req_json", "url", "data"}``.
    """
    import asyncio
    import functools

    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError subclass ValueError
        return PlainTextResponse("400 Bad Request\nPlease provide parameters in request body as a JSON.", status_code=400)
    if not isinstance(body, dict) or any(key not in body for key in ("id", "version", "query")):
        return PlainTextResponse("400 Bad Request\nIncomplete query provided.", status_code=400)
    try:
        model_id = int(body["id"])
    except (TypeError, ValueError):
        return PlainTextResponse("400 Bad Request\nMust provide integer ID.", status_code=400)
    query = body["query"]
    model = await Model.get(model_id)
    if model is None:
        return PlainTextResponse("400 Bad Request\nModel with given ID not found", status_code=400)
    # NOTE(review): read-modify-write; concurrent requests can lose increments.
    await model.update(requests=model.requests + 1).apply()

    # Image verification. TODO: a per-model scheme declaring which params need
    # verification would scale better, especially for user-supplied models.
    # imghdr.what requires a positional ``file`` argument (None when passing
    # ``h``) and bytes data. NOTE: imghdr is removed in Python 3.13.
    if isinstance(query, dict) and 'img' in query:
        img_bytes = _decode_image_bytes(query['img'])
        if img_bytes is None or imghdr.what(None, h=img_bytes) != 'jpeg':
            return PlainTextResponse("400 Bad Request\nModel only accepts JPEG images.", status_code=400)

    # Accessing Clipper.
    # TODO(review): startup() loads CLIPPER_URL, but the address is hard-coded here.
    addr = "http://18.213.175.138:1337/%s/predict" % (model.clipper_model_name)
    req_headers = {"Content-Type": "application/json"}
    req_json = json.dumps(query)
    send = functools.partial(requests.post, addr, headers=req_headers,
                             data=req_json, timeout=CLIPPER_TIMEOUT_SECONDS)
    try:
        # requests is blocking: run it in the default thread pool so the
        # event loop keeps serving other requests meanwhile.
        response = await asyncio.get_event_loop().run_in_executor(None, send)
        clipper_response = response.json()
    except (requests.RequestException, ValueError):  # network/timeout errors, non-JSON reply
        return JSONResponse({
            "error": "clipper returned an error",
            "req_json": req_json
        })
    return JSONResponse({
        "model": model.to_json(),
        "req_json": req_json,
        "url": addr,
        "data": clipper_response
    })
@app.route('/model/create', methods=["POST"])
@app.route('/model/create/', methods=["POST"])
async def create_model(request):
    """Create a model, or register a new version of an existing one.

    The body must be a JSON object with a numeric ``version``.

    * Without ``id``, these fields are required: ``title``, ``desc``,
      ``category``, ``params`` (a list of objects), ``clipper_model_name``
      and ``output_type``. ``output_attr`` is optional. A new row is
      inserted with ``versions=[version]``.
    * With an integer ``id``, ``version`` is appended to that model's
      ``versions`` unless it is already present.

    Responds ``{"success": true}`` on success and with a plain-text 400
    otherwise.
    """
    common_sad_path = PlainTextResponse("400 Bad Request\nRequired parameters not provided.", status_code=400)
    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError subclass ValueError
        return PlainTextResponse("400 Bad Request\nPlease provide parameters in request body as a JSON.", status_code=400)
    if not isinstance(body, dict):
        return common_sad_path

    version = body.get("version")
    # JSON has one number type, and JavaScript serialises 1.0 as ``1``, so
    # integral versions arrive as int. Accept them (but not bools) as floats.
    if isinstance(version, bool) or not isinstance(version, (int, float)):
        return common_sad_path
    version = float(version)

    if "id" not in body:
        required = ("title", "desc", "category", "params", "clipper_model_name", "output_type")
        if any(key not in body for key in required):
            return common_sad_path
        params = body["params"]
        if not isinstance(params, list) or not all(isinstance(elem, dict) for elem in params):
            return common_sad_path
        await Model.create(
            title=body["title"],
            desc=body["desc"],
            versions=[version],
            clipper_model_name=body["clipper_model_name"],
            output_attr=body.get("output_attr"),  # optional; NULL when absent
            category=body["category"],
            params=params,
            output_type=body["output_type"],
        )
        return JSONResponse({"success": True})

    model_id = body["id"]
    if not isinstance(model_id, int):
        return common_sad_path
    model = await Model.get(model_id)
    if model is None:
        return common_sad_path
    existing = model.versions or []
    if version not in existing:
        await model.update(versions=existing + [version]).apply()
    return JSONResponse({"success": True})
def alpha_num_validator(arg):
    """Check that *arg* is alphanumeric.

    For a list: True when it is empty, or when every element is made only of
    whitespace-separated alphanumeric words (empty and whitespace-only
    elements pass). For anything else: ``arg.isalnum()``.
    """
    if not isinstance(arg, list):
        return arg.isalnum()
    if not arg:
        return True
    # Every element is checked before reducing, matching an eager evaluation.
    verdicts = []
    for elem in arg:
        words_ok = all(word.isalnum() for word in elem.split())
        verdicts.append(words_ok or not elem)
    return all(verdicts)
def string_is_float(num):
    """Return True if *num* is an unsigned decimal number such as "3" or "1.5".

    At most one '.' is allowed, and it must have digits on both sides.
    Digits are checked with ``str.isdecimal`` because ``float()`` accepts
    exactly those, whereas ``isdigit`` also admits characters such as "²".
    Signs, exponents and underscores are not accepted.
    """
    parts = num.split(".")
    return len(parts) <= 2 and all(part.isdecimal() for part in parts)
if __name__ == '__main__':
    # Run directly with uvicorn, listening on every interface at port 8000.
    listen_host = '0.0.0.0'
    listen_port = 8000
    uvi_run(app, host=listen_host, port=listen_port)