Skip to content

Commit

Permalink
Add middlewares (#157)
Browse files Browse the repository at this point in the history
* Add middlewares

* Refactor the routers

* Add AFTER REQUEST

* Add arcs

* Add body params

* Add comments for decorator
  • Loading branch information
sansyrox authored Feb 14, 2022
1 parent 67975e9 commit 40f123b
Show file tree
Hide file tree
Showing 11 changed files with 512 additions and 99 deletions.
19 changes: 18 additions & 1 deletion integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,27 @@ def message():
async def hello(request):
global callCount
callCount += 1
_message = "Called " + str(callCount) + " times"
message = "Called " + str(callCount) + " times"
print(message)
return jsonify(request)


@app.before_request("/")
async def hello_before_request(request):
global callCount
callCount += 1
print(request)
return ""


@app.after_request("/")
async def hello_after_request(request):
global callCount
callCount += 1
print(request)
return ""


@app.get("/test/:id")
async def test(request):
print(request)
Expand Down
83 changes: 82 additions & 1 deletion robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ def __init__(self, file_object):
self.routes = []
self.headers = []
self.routes = []
self.directories = []
self.middlewares = []
self.web_sockets = {}
self.directories = []
self.event_handlers = {}

def add_route(self, route_type, endpoint, handler):
Expand All @@ -61,6 +62,84 @@ def add_route(self, route_type, endpoint, handler):
)
)

def add_middleware_route(self, route_type, endpoint, handler):
"""
[This is base handler for the middleware decorator]
:param route_type [str]: [??]
:param endpoint [str]: [endpoint for the route added]
:param handler [function]: [represents the sync or async function passed as a handler for the route]
"""

""" We will add the status code here only
"""
number_of_params = len(signature(handler).parameters)
self.middlewares.append(
(
route_type,
endpoint,
handler,
asyncio.iscoroutinefunction(handler),
number_of_params,
)
)

def before_request(self, endpoint):
"""
[The @app.before_request decorator to add a get route]
:param endpoint [str]: [endpoint to server the route]
"""

# This inner function is basically a wrapper arround the closure(decorator)
# being returned.
# It takes in a handler and converts it in into a closure
# and returns the arguments.
# Arguments are returned as they could be modified by the middlewares.
def inner(handler):
async def async_inner_handler(*args):
await handler(args)
return args

def inner_handler(*args):
handler(*args)
return args

if asyncio.iscoroutinefunction(handler):
self.add_middleware_route("BEFORE_REQUEST", endpoint, async_inner_handler)
else:
self.add_middleware_route("BEFORE_REQUEST", endpoint, inner_handler)

return inner

def after_request(self, endpoint):
"""
[The @app.after_request decorator to add a get route]
:param endpoint [str]: [endpoint to server the route]
"""

# This inner function is basically a wrapper arround the closure(decorator)
# being returned.
# It takes in a handler and converts it in into a closure
# and returns the arguments.
# Arguments are returned as they could be modified by the middlewares.
def inner(handler):
async def async_inner_handler(*args):
await handler(args)
return args

def inner_handler(*args):
handler(*args)
return args

if asyncio.iscoroutinefunction(handler):
self.add_middleware_route("AFTER_REQUEST", endpoint, async_inner_handler)
else:
self.add_middleware_route("AFTER_REQUEST", endpoint, inner_handler)

return inner

def add_directory(
self, route, directory_path, index_file=None, show_files_listing=False
):
Expand Down Expand Up @@ -95,6 +174,7 @@ def start(self, url="127.0.0.1", port=5000):
if not self.dev:
workers = self.workers
socket = SocketHeld(url, port)
print(self.middlewares)
for _ in range(self.processes):
copied_socket = socket.try_clone()
p = Process(
Expand All @@ -103,6 +183,7 @@ def start(self, url="127.0.0.1", port=5000):
self.directories,
self.headers,
self.routes,
self.middlewares,
self.web_sockets,
self.event_handlers,
copied_socket,
Expand Down
7 changes: 6 additions & 1 deletion robyn/processpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def spawn_process(
directories, headers, routes, web_sockets, event_handlers, socket, workers
directories, headers, routes, middlewares, web_sockets, event_handlers, socket, workers
):
"""
This function is called by the main process handler to create a server runtime.
Expand All @@ -21,6 +21,7 @@ def spawn_process(
:param directories tuple: the list of all the directories and related data in a tuple
:param headers tuple: All the global headers in a tuple
:param routes tuple: The routes touple, containing the description about every route.
:param middlewares tuple: The middleware router touple, containing the description about every route.
:param web_sockets list: This is a list of all the web socket routes
:param event_handlers Dict: This is an event dict that contains the event handlers
:param socket Socket: This is the main tcp socket, which is being shared across multiple processes.
Expand Down Expand Up @@ -53,6 +54,10 @@ def spawn_process(
route_type, endpoint, handler, is_async, number_of_params = route
server.add_route(route_type, endpoint, handler, is_async, number_of_params)

for route in middlewares:
route_type, endpoint, handler, is_async, number_of_params = route
server.add_middleware_route(route_type, endpoint, handler, is_async, number_of_params)

if "startup" in event_handlers:
server.add_startup_handler(event_handlers[Events.STARTUP][0], event_handlers[Events.STARTUP][1])

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod processor;
mod router;
mod routers;
mod server;
mod shared_socket;
mod types;
Expand Down
141 changes: 135 additions & 6 deletions src/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use anyhow::{bail, Result};
use crate::types::{Headers, PyFunction};
use futures_util::stream::StreamExt;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::types::{PyDict, PyTuple};

use std::fs::File;
use std::io::Read;
Expand Down Expand Up @@ -40,7 +40,7 @@ pub async fn handle_request(
payload: &mut web::Payload,
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: HashMap<&str, &str>,
queries: HashMap<String, String>,
) -> HttpResponse {
let contents = match execute_http_function(
function,
Expand All @@ -67,6 +67,36 @@ pub async fn handle_request(
response.body(contents)
}

pub async fn handle_middleware_request(
function: PyFunction,
number_of_params: u8,
headers: &Arc<Headers>,
payload: &mut web::Payload,
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: HashMap<String, String>,
) -> Py<PyTuple> {
let contents = match execute_middleware_function(
function,
payload,
headers,
req,
route_params,
queries,
number_of_params,
)
.await
{
Ok(res) => res,
Err(err) => Python::with_gil(|py| {
println!("{:?}", err);
PyTuple::empty(py).into_py(py)
}),
};

contents
}

// ideally this should be async
/// A function to read lossy files and serve it as a html response
///
Expand All @@ -81,6 +111,101 @@ fn read_file(file_path: &str) -> String {
String::from_utf8_lossy(&buf).to_string()
}

async fn execute_middleware_function<'a>(
function: PyFunction,
payload: &mut web::Payload,
headers: &Headers,
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: HashMap<String, String>,
number_of_params: u8,
) -> Result<Py<PyTuple>> {
// TODO:
// try executing the first version of middleware(s) here
// with just headers as params

let mut data: Option<Vec<u8>> = None;

if req.method() == Method::POST
|| req.method() == Method::PUT
|| req.method() == Method::PATCH
|| req.method() == Method::DELETE
{
let mut body = web::BytesMut::new();
while let Some(chunk) = payload.next().await {
let chunk = chunk?;
// limit max size of in-memory payload
if (body.len() + chunk.len()) > MAX_SIZE {
bail!("Body content Overflow");
}
body.extend_from_slice(&chunk);
}

data = Some(body.to_vec())
}

// request object accessible while creating routes
let mut request = HashMap::new();
let mut headers_python = HashMap::new();
for elem in headers.into_iter() {
headers_python.insert(elem.key().clone(), elem.value().clone());
}

match function {
PyFunction::CoRoutine(handler) => {
let output = Python::with_gil(|py| {
let handler = handler.as_ref(py);
request.insert("params", route_params.into_py(py));
request.insert("queries", queries.into_py(py));
request.insert("headers", headers_python.into_py(py));
request.insert("body", data.into_py(py));

// this makes the request object to be accessible across every route
let coro: PyResult<&PyAny> = match number_of_params {
0 => handler.call0(),
1 => handler.call1((request,)),
// this is done to accomodate any future params
2_u8..=u8::MAX => handler.call1((request,)),
};
pyo3_asyncio::tokio::into_future(coro?)
})?;

let output = output.await?;

let res = Python::with_gil(|py| -> PyResult<Py<PyTuple>> {
let output: Py<PyTuple> = output.extract(py).unwrap();
Ok(output)
})?;

Ok(res)
}

PyFunction::SyncFunction(handler) => {
tokio::task::spawn_blocking(move || {
Python::with_gil(|py| {
let handler = handler.as_ref(py);
request.insert("params", route_params.into_py(py));
request.insert("queries", queries.into_py(py));
request.insert("headers", headers_python.into_py(py));
request.insert("body", data.into_py(py));

let output: PyResult<&PyAny> = match number_of_params {
0 => handler.call0(),
1 => handler.call1((request,)),
// this is done to accomodate any future params
2_u8..=u8::MAX => handler.call1((request,)),
};

let output: Py<PyTuple> = output?.extract().unwrap();

Ok(output)
})
})
.await?
}
}
}

// Change this!
#[inline]
async fn execute_http_function(
Expand All @@ -89,7 +214,7 @@ async fn execute_http_function(
headers: &Headers,
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: HashMap<&str, &str>,
queries: HashMap<String, String>,
number_of_params: u8,
) -> Result<String> {
let mut data: Option<Vec<u8>> = None;
Expand Down Expand Up @@ -211,9 +336,12 @@ async fn execute_http_function(
}
}

pub async fn execute_event_handler(event_handler: Option<PyFunction>, event_loop: Py<PyAny>) {
pub async fn execute_event_handler(
event_handler: Option<Arc<PyFunction>>,
event_loop: Arc<Py<PyAny>>,
) {
match event_handler {
Some(handler) => match handler {
Some(handler) => match &(*handler) {
PyFunction::SyncFunction(function) => {
println!("Startup event handler");
Python::with_gil(|py| {
Expand All @@ -225,7 +353,8 @@ pub async fn execute_event_handler(event_handler: Option<PyFunction>, event_loop
println!("Startup event handler async");

let coroutine = function.as_ref(py).call0().unwrap();
pyo3_asyncio::into_future_with_loop(event_loop.as_ref(py), coroutine).unwrap()
pyo3_asyncio::into_future_with_loop((*event_loop).as_ref(py), coroutine)
.unwrap()
});
future.await.unwrap();
}
Expand Down
Loading

0 comments on commit 40f123b

Please sign in to comment.