Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions dist/pyrobusta/assets/www/examples.html
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ <h2>Demo Application</h2>
).lower()

if is_detailed not in ("true", "false"):
http_ctx.terminate(400, True)
http_ctx.terminate(400)
return "text/plain", "Invalid query"

include_server_version = is_detailed.lower() == "true"
Expand All @@ -115,10 +115,10 @@ <h2>Demo Application</h2>
@HttpEngine.route("/{app_or_server}/version", "GET")
def version(http_ctx, _):
include_server_version = False
resource = http_ctx.url.split(b"/")[1]
resource = http_ctx.path_segment(0)

if resource not in (b"app", b"server"):
http_ctx.terminate(404, True)
http_ctx.terminate(404)
return "text/plain", "Not found"

version_string = APP_VERSION if resource == b"app" else PYROBUSTA_VERSION
Expand Down
8 changes: 4 additions & 4 deletions example/demo_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def version(http_ctx, _):
).lower()

if is_detailed not in ("true", "false"):
http_ctx.terminate(400, True)
http_ctx.terminate(400)
return "text/plain", "Invalid query"

include_server_version = is_detailed.lower() == "true"
Expand All @@ -37,10 +37,10 @@ def version(http_ctx, _):
@HttpEngine.route("/{app_or_server}/version", "GET")
def version(http_ctx, _):
include_server_version = False
resource = http_ctx.url.split(b"/")[1]
resource = http_ctx.path_segment(0)

if resource not in (b"app", b"server"):
http_ctx.terminate(404, True)
http_ctx.terminate(404)
return "text/plain", "Not found"

version_string = APP_VERSION if resource == b"app" else PYROBUSTA_VERSION
Expand All @@ -59,4 +59,4 @@ async def main():


if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())
106 changes: 72 additions & 34 deletions src/pyrobusta/protocol/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class HttpEngine:
"status_code",
"resp_headers",
"resp_handler",
"aborted",
"version",
"headers",
"method",
Expand All @@ -66,6 +65,7 @@ class HttpEngine:
"content_len_cnt",
"recv_chunk_size",
"is_req_empty",
"_is_req_complete",
"mp_boundary",
"mp_is_first",
"mp_is_last",
Expand Down Expand Up @@ -151,7 +151,6 @@ def __init__(self):
self.status_code = None
self.resp_headers = []
self.resp_handler = None
self.aborted = False

# [Recived request]
self.version = None
Expand All @@ -162,6 +161,7 @@ def __init__(self):
self.content_len_cnt = 0
self.recv_chunk_size = 0
self.is_req_empty = True
self._is_req_complete = False

# [Multipart state]
self.mp_boundary = None
Expand All @@ -175,7 +175,6 @@ def reset(self):
self.status_code = None
self.resp_headers.clear()
self.resp_handler = None
self.aborted = False
self.version = None
self.headers.clear()
self.method = None
Expand All @@ -184,6 +183,7 @@ def reset(self):
self.content_len_cnt = 0
self.recv_chunk_size = 0
self.is_req_empty = True
self._is_req_complete = False
self.mp_boundary = None

# =========================================
Expand Down Expand Up @@ -255,6 +255,16 @@ def percent_decode(s: str):
i += 1
return "".join(out)

def path_segment(self, idx: int):
"""
Return the nth path segment of the URL path.
The index is shifted by one to ignore the first
empty segment before the leading slash ('/').
:param idx: index of the segment
:return: string path segment
"""
return self.url.split(b"/")[idx + 1].decode("ascii")

def get_query_param(self, key: str, default: str = None) -> str:
"""
Parse a query and return the value belonging to a key
Expand Down Expand Up @@ -493,8 +503,10 @@ def set_response_body(
:param body: body to be sent in the response
:param content_type: content-type of the body
"""
if body is None:
return
self._unset_response_handler()

if not body:
body_encoded = b""
if isinstance(body, (bytes, bytearray, memoryview)):
body_encoded = body
elif isinstance(body, str):
Expand All @@ -503,19 +515,31 @@ def set_response_body(
body_encoded = dumps(body).encode()
else:
raise ValueError("Unhandled body type")

self.set_response_header(
b"content-length", str(len(body_encoded)).encode("ascii")
)
self.set_response_header(b"content-type", content_type.encode("ascii"))
if self.method != self.HEAD:
self.resp_handler = BytesIO(body_encoded)

if len(body_encoded):
self.set_response_header(b"content-type", content_type.encode("ascii"))

if self.method != self.HEAD:
self.resp_handler = BytesIO(body_encoded)

def _unset_response_handler(self):
"""
Unset the response handler (if set).
"""
if type(self.resp_handler).__name__ in ("FileIO", "BytesIO"):
self.resp_handler.close()
self.resp_handler = None

def do_keep_alive(self):
"""
Determine if the connection should be kept alive
depending on the HTTP version and headers sent in the request.
"""
if self.aborted:
if self.is_terminated() and not self._is_req_complete:
return False

connection_tokens = [
Expand All @@ -533,8 +557,7 @@ def _handle_route_response(self, callback_response: tuple | None):
set a status code, default to HTTP 200. If the handler returns
a response body and content type, set them accordingly.
"""
if not self.is_terminated():
self.terminate(200, True)
self.terminate(self.status_code or 200)

if callback_response is None:
return
Expand All @@ -547,36 +570,25 @@ def _handle_route_response(self, callback_response: tuple | None):

self.set_response_body(data, content_type=dtype)

def terminate(self, status_code: int, request_complete: bool = False):
def terminate(self, status_code: int):
"""
Regular state machine termination with a specific status code.
:param status_code: HTTP status code
:param request_complete: true if the complete request is processed
"""
self.state = None
self.state = self._terminal_st
if not isinstance(status_code, int) or status_code not in self.RESP_HEADERS:
raise ValueError("Invalid status")
self.status_code = status_code

if self.version == b"HTTP/1.0" and self.do_keep_alive() and request_complete:
self.set_response_header(b"connection", b"keep-alive")
elif (
self.version == b"HTTP/1.1"
and not self.do_keep_alive()
and not request_complete
):
self.set_response_header(b"connection", b"close")

def abort(self, status_code: int):
"""
Abort state machine due to runtime errors.
Reset any header or response body set earlier.
:param status_code: HTTP status code
"""
self.aborted = True
self.resp_headers = []
if type(self.resp_handler).__name__ in ("FileIO", "BytesIO"):
self.resp_handler.close()
self.resp_handler = None
self.terminate(status_code, False)
self._unset_response_handler()
self.terminate(status_code)

def is_request_empty(self):
"""
Expand All @@ -588,7 +600,7 @@ def is_terminated(self):
"""
Returns true if the state machine is terminated.
"""
return self.state is None and self.status_code
return self.state is None

def run(self, rx):
"""
Expand Down Expand Up @@ -651,6 +663,7 @@ def _consume_payload(self, rx, size, last=False):
of content length when the last flag is set. When the request is chunked,
the content length should not be set, otherwise it is ignored.
"""
assert not self._is_req_complete
if (
not self.is_chunked()
and "content-length" in self.headers
Expand All @@ -665,6 +678,7 @@ def _consume_payload(self, rx, size, last=False):
raise InvalidContentLength()
self.content_len_cnt += size
rx.consume(size)
self._is_req_complete = last

# ================================================================================
# Parser states
Expand Down Expand Up @@ -737,7 +751,7 @@ def _route_request_st(self, _):
if self.method == self.OPTIONS:
supported_methods = self._supported_methods(self.url)
self.set_response_header(b"allow", b", ".join(supported_methods))
self.terminate(204, True)
self.terminate(204)
return
if self.has_payload():
if self.method in (self.GET, self.HEAD):
Expand Down Expand Up @@ -820,7 +834,7 @@ def _app_endpoint_st(self, rx):
self, bytes(rx.peek(self.recv_chunk_size))
)
self._consume_payload(rx, self.recv_chunk_size + 2)
if not self.is_terminated():
if not self.state == self._terminal_st: # pylint: disable=W0143
self.state = self._recv_chunk_size_st
return
else:
Expand All @@ -834,6 +848,7 @@ def _app_endpoint_st(self, rx):
self._consume_payload(rx, self.headers["content-length"], last=True)
else:
callback_response = callback(self, b"")
self._is_req_complete = True

self._handle_route_response(callback_response)

Expand All @@ -855,7 +870,7 @@ def _fs_retrieve_st(self, _):
try:
if not is_path_served:
stat(norm_path)
self.terminate(403, True)
self.terminate(403)
return

try:
Expand All @@ -870,12 +885,12 @@ def _fs_retrieve_st(self, _):
b"content-length", str(stat(norm_path)[6]).encode("ascii")
)
self.set_response_header(b"content-type", content_type)
self.terminate(200, True)
self.terminate(200)
if self.method != self.HEAD:
self.resp_handler = open(norm_path, "rb") # pylint: disable=R1732
return
except OSError:
self.terminate(404, True)
self.terminate(404)

def _start_multipart_parser_st(self, rx): # pylint: disable=W0613
"""
Expand All @@ -889,6 +904,29 @@ def generate_multipart_response(self, callback, dtype): # pylint: disable=W0613
"""
self.abort(503)

def _terminal_st(self, rx): # pylint: disable=W0613
"""
Terminal state for finalizing request/response processing.
"""
if (
self.version == b"HTTP/1.0"
and self.do_keep_alive()
and self._is_req_complete
):
self.set_response_header(b"connection", b"keep-alive")
elif self.version == b"HTTP/1.1" and (
not self.do_keep_alive() or not self._is_req_complete
):
self.set_response_header(b"connection", b"close")

if (
self.get_response_header(b"transfer-encoding") != b"chunked"
and self.get_response_header(b"content-length") is None
):
self.set_response_header(b"content-length", b"0")

self.state = None


def enable_optional_features():
"""
Expand Down
Loading
Loading