diff --git a/dist/pyrobusta/assets/www/examples.html b/dist/pyrobusta/assets/www/examples.html
index 1ade6de..99e0481 100644
--- a/dist/pyrobusta/assets/www/examples.html
+++ b/dist/pyrobusta/assets/www/examples.html
@@ -95,7 +95,7 @@
Demo Application
).lower()
if is_detailed not in ("true", "false"):
- http_ctx.terminate(400, True)
+ http_ctx.terminate(400)
return "text/plain", "Invalid query"
include_server_version = is_detailed.lower() == "true"
@@ -115,10 +115,10 @@ Demo Application
@HttpEngine.route("/{app_or_server}/version", "GET")
def version(http_ctx, _):
include_server_version = False
- resource = http_ctx.url.split(b"/")[1]
+ resource = http_ctx.path_segment(0)
if resource not in (b"app", b"server"):
- http_ctx.terminate(404, True)
+ http_ctx.terminate(404)
return "text/plain", "Not found"
version_string = APP_VERSION if resource == b"app" else PYROBUSTA_VERSION
diff --git a/example/demo_app/app.py b/example/demo_app/app.py
index 8cf96db..12426df 100644
--- a/example/demo_app/app.py
+++ b/example/demo_app/app.py
@@ -17,7 +17,7 @@ def version(http_ctx, _):
).lower()
if is_detailed not in ("true", "false"):
- http_ctx.terminate(400, True)
+ http_ctx.terminate(400)
return "text/plain", "Invalid query"
include_server_version = is_detailed.lower() == "true"
@@ -37,10 +37,10 @@ def version(http_ctx, _):
@HttpEngine.route("/{app_or_server}/version", "GET")
def version(http_ctx, _):
include_server_version = False
- resource = http_ctx.url.split(b"/")[1]
+ resource = http_ctx.path_segment(0)
if resource not in (b"app", b"server"):
- http_ctx.terminate(404, True)
+ http_ctx.terminate(404)
return "text/plain", "Not found"
version_string = APP_VERSION if resource == b"app" else PYROBUSTA_VERSION
@@ -59,4 +59,4 @@ async def main():
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ asyncio.run(main())
diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py
index 8e7146b..81ca0dc 100644
--- a/src/pyrobusta/protocol/http.py
+++ b/src/pyrobusta/protocol/http.py
@@ -57,7 +57,6 @@ class HttpEngine:
"status_code",
"resp_headers",
"resp_handler",
- "aborted",
"version",
"headers",
"method",
@@ -66,6 +65,7 @@ class HttpEngine:
"content_len_cnt",
"recv_chunk_size",
"is_req_empty",
+ "_is_req_complete",
"mp_boundary",
"mp_is_first",
"mp_is_last",
@@ -151,7 +151,6 @@ def __init__(self):
self.status_code = None
self.resp_headers = []
self.resp_handler = None
- self.aborted = False
# [Recived request]
self.version = None
@@ -162,6 +161,7 @@ def __init__(self):
self.content_len_cnt = 0
self.recv_chunk_size = 0
self.is_req_empty = True
+ self._is_req_complete = False
# [Multipart state]
self.mp_boundary = None
@@ -175,7 +175,6 @@ def reset(self):
self.status_code = None
self.resp_headers.clear()
self.resp_handler = None
- self.aborted = False
self.version = None
self.headers.clear()
self.method = None
@@ -184,6 +183,7 @@ def reset(self):
self.content_len_cnt = 0
self.recv_chunk_size = 0
self.is_req_empty = True
+ self._is_req_complete = False
self.mp_boundary = None
# =========================================
@@ -255,6 +255,16 @@ def percent_decode(s: str):
i += 1
return "".join(out)
+ def path_segment(self, idx: int):
+ """
+ Return the nth path segment of the URL path.
+ The index is shifted by one to ignore the first
+ empty segment before the leading slash ('/').
+ :param idx: index of the segment
+ :return: string path segment
+ """
+ return self.url.split(b"/")[idx + 1].decode("ascii")
+
def get_query_param(self, key: str, default: str = None) -> str:
"""
Parse a query and return the value belonging to a key
@@ -493,8 +503,10 @@ def set_response_body(
:param body: body to be sent in the response
:param content_type: content-type of the body
"""
- if body is None:
- return
+ self._unset_response_handler()
+
+ if not body:
+ body_encoded = b""
if isinstance(body, (bytes, bytearray, memoryview)):
body_encoded = body
elif isinstance(body, str):
@@ -503,19 +515,31 @@ def set_response_body(
body_encoded = dumps(body).encode()
else:
raise ValueError("Unhandled body type")
+
self.set_response_header(
b"content-length", str(len(body_encoded)).encode("ascii")
)
- self.set_response_header(b"content-type", content_type.encode("ascii"))
- if self.method != self.HEAD:
- self.resp_handler = BytesIO(body_encoded)
+
+ if len(body_encoded):
+ self.set_response_header(b"content-type", content_type.encode("ascii"))
+
+ if self.method != self.HEAD:
+ self.resp_handler = BytesIO(body_encoded)
+
+ def _unset_response_handler(self):
+ """
+ Unset the response handler (if set).
+ """
+ if type(self.resp_handler).__name__ in ("FileIO", "BytesIO"):
+ self.resp_handler.close()
+ self.resp_handler = None
def do_keep_alive(self):
"""
Determine if the connection should be kept alive
depending on the HTTP version and headers sent in the request.
"""
- if self.aborted:
+ if self.is_terminated() and not self._is_req_complete:
return False
connection_tokens = [
@@ -533,8 +557,7 @@ def _handle_route_response(self, callback_response: tuple | None):
set a status code, default to HTTP 200. If the handler returns
a response body and content type, set them accordingly.
"""
- if not self.is_terminated():
- self.terminate(200, True)
+ self.terminate(self.status_code or 200)
if callback_response is None:
return
@@ -547,36 +570,25 @@ def _handle_route_response(self, callback_response: tuple | None):
self.set_response_body(data, content_type=dtype)
- def terminate(self, status_code: int, request_complete: bool = False):
+ def terminate(self, status_code: int):
"""
Regular state machine termination with a specific status code.
:param status_code: HTTP status code
- :param request_complete: true if the complete request is processed
"""
- self.state = None
+ self.state = self._terminal_st
+ if not isinstance(status_code, int) or status_code not in self.RESP_HEADERS:
+ raise ValueError("Invalid status")
self.status_code = status_code
- if self.version == b"HTTP/1.0" and self.do_keep_alive() and request_complete:
- self.set_response_header(b"connection", b"keep-alive")
- elif (
- self.version == b"HTTP/1.1"
- and not self.do_keep_alive()
- and not request_complete
- ):
- self.set_response_header(b"connection", b"close")
-
def abort(self, status_code: int):
"""
Abort state machine due to runtime errors.
Reset any header or response body set earlier.
:param status_code: HTTP status code
"""
- self.aborted = True
self.resp_headers = []
- if type(self.resp_handler).__name__ in ("FileIO", "BytesIO"):
- self.resp_handler.close()
- self.resp_handler = None
- self.terminate(status_code, False)
+ self._unset_response_handler()
+ self.terminate(status_code)
def is_request_empty(self):
"""
@@ -588,7 +600,7 @@ def is_terminated(self):
"""
Returns true if the state machine is terminated.
"""
- return self.state is None and self.status_code
+ return self.state is None
def run(self, rx):
"""
@@ -651,6 +663,7 @@ def _consume_payload(self, rx, size, last=False):
of content length when the last flag is set. When the request is chunked,
the content length should not be set, otherwise it is ignored.
"""
+ assert not self._is_req_complete
if (
not self.is_chunked()
and "content-length" in self.headers
@@ -665,6 +678,7 @@ def _consume_payload(self, rx, size, last=False):
raise InvalidContentLength()
self.content_len_cnt += size
rx.consume(size)
+ self._is_req_complete = last
# ================================================================================
# Parser states
@@ -737,7 +751,7 @@ def _route_request_st(self, _):
if self.method == self.OPTIONS:
supported_methods = self._supported_methods(self.url)
self.set_response_header(b"allow", b", ".join(supported_methods))
- self.terminate(204, True)
+ self.terminate(204)
return
if self.has_payload():
if self.method in (self.GET, self.HEAD):
@@ -820,7 +834,7 @@ def _app_endpoint_st(self, rx):
self, bytes(rx.peek(self.recv_chunk_size))
)
self._consume_payload(rx, self.recv_chunk_size + 2)
- if not self.is_terminated():
+ if not self.state == self._terminal_st: # pylint: disable=W0143
self.state = self._recv_chunk_size_st
return
else:
@@ -834,6 +848,7 @@ def _app_endpoint_st(self, rx):
self._consume_payload(rx, self.headers["content-length"], last=True)
else:
callback_response = callback(self, b"")
+ self._is_req_complete = True
self._handle_route_response(callback_response)
@@ -855,7 +870,7 @@ def _fs_retrieve_st(self, _):
try:
if not is_path_served:
stat(norm_path)
- self.terminate(403, True)
+ self.terminate(403)
return
try:
@@ -870,12 +885,12 @@ def _fs_retrieve_st(self, _):
b"content-length", str(stat(norm_path)[6]).encode("ascii")
)
self.set_response_header(b"content-type", content_type)
- self.terminate(200, True)
+ self.terminate(200)
if self.method != self.HEAD:
self.resp_handler = open(norm_path, "rb") # pylint: disable=R1732
return
except OSError:
- self.terminate(404, True)
+ self.terminate(404)
def _start_multipart_parser_st(self, rx): # pylint: disable=W0613
"""
@@ -889,6 +904,29 @@ def generate_multipart_response(self, callback, dtype): # pylint: disable=W0613
"""
self.abort(503)
+ def _terminal_st(self, rx): # pylint: disable=W0613
+ """
+ Terminal state for finalizing request/response processing.
+ """
+ if (
+ self.version == b"HTTP/1.0"
+ and self.do_keep_alive()
+ and self._is_req_complete
+ ):
+ self.set_response_header(b"connection", b"keep-alive")
+ elif self.version == b"HTTP/1.1" and (
+ not self.do_keep_alive() or not self._is_req_complete
+ ):
+ self.set_response_header(b"connection", b"close")
+
+ if (
+ self.get_response_header(b"transfer-encoding") != b"chunked"
+ and self.get_response_header(b"content-length") is None
+ ):
+ self.set_response_header(b"content-length", b"0")
+
+ self.state = None
+
def enable_optional_features():
"""
diff --git a/src/pyrobusta/protocol/http_file_server.py b/src/pyrobusta/protocol/http_file_server.py
index a674d31..2fd2df1 100644
--- a/src/pyrobusta/protocol/http_file_server.py
+++ b/src/pyrobusta/protocol/http_file_server.py
@@ -43,14 +43,14 @@ def fs_retrieve(http_ctx, _):
try:
if not is_path_served:
stat(norm_path)
- http_ctx.terminate(403, True)
+ http_ctx.terminate(403)
return "text/plain", "Forbidden"
# Retrieve directory structure
if stat(norm_path)[0] & 0x4000:
http_ctx.set_response_header(b"content-type", b"application/json")
http_ctx.set_response_header(b"transfer-encoding", b"chunked")
- http_ctx.terminate(200, True)
+ http_ctx.terminate(200)
http_ctx.resp_handler = _traverse_dir_factory(norm_path)
return
@@ -67,11 +67,11 @@ def fs_retrieve(http_ctx, _):
b"content-length", str(stat(norm_path)[6]).encode("ascii")
)
http_ctx.set_response_header(b"content-type", content_type)
- http_ctx.terminate(200, True)
+ http_ctx.terminate(200)
if http_ctx.method != http_ctx.HEAD:
http_ctx.resp_handler = open(norm_path, "rb") # pylint: disable=R1732
except OSError:
- http_ctx.terminate(404, True)
+ http_ctx.terminate(404)
return "text/plain", "Not found"
@@ -87,24 +87,24 @@ def delete_file(http_ctx, _):
try:
if not fs_path.startswith(_UPLOAD_ROOT):
stat(fs_path)
- http_ctx.terminate(403, True)
+ http_ctx.terminate(403)
return "text/plain", "Forbidden"
# Delete directory structure
if stat(fs_path)[0] & 0x4000:
if listdir(fs_path):
- http_ctx.terminate(400, True)
+ http_ctx.terminate(400)
return "text/plain", "Directory not empty"
rmdir(fs_path)
- http_ctx.terminate(204, True)
+ http_ctx.terminate(204)
return "text/plain", "Deleted"
# Delete file
remove(fs_path)
- http_ctx.terminate(204, True)
+ http_ctx.terminate(204)
return "text/plain", "Deleted"
except OSError:
- http_ctx.terminate(404, True)
+ http_ctx.terminate(404)
return "text/plain", "Not found"
@@ -120,7 +120,7 @@ def upload_file(http_ctx, payload: bytes):
return "text/plain", "Bad request"
if not normalize_path(target_path).startswith(_UPLOAD_ROOT):
- http_ctx.terminate(403, True)
+ http_ctx.terminate(403)
return "text/plain", "Forbidden"
try:
@@ -142,10 +142,10 @@ def upload_file(http_ctx, payload: bytes):
with open(normalize_path(target_path), "wb") as f:
f.write(payload)
- http_ctx.terminate(201, True)
+ http_ctx.terminate(201)
return "text/plain", "OK"
except OSError:
- http_ctx.terminate(404, True)
+ http_ctx.terminate(404)
return "text/plain", "Not found"
@@ -191,7 +191,7 @@ def bulk_upload_file(http_ctx, payload: tuple):
if file.endswith(suffix):
rename(_TMP_DIR + "/" + file, _UPLOAD_ROOT + "/" + file[: -len(suffix)])
- http_ctx.terminate(201, True)
+ http_ctx.terminate(201)
return "text/plain", "OK"
diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py
index 86af41b..52985d7 100644
--- a/src/pyrobusta/protocol/http_multipart.py
+++ b/src/pyrobusta/protocol/http_multipart.py
@@ -139,7 +139,7 @@ def _parse_complete_part_st(self, rx):
raise http.MalformedRequest()
self._consume_payload(rx, len(self.mp_delimiter))
self.mp_is_first = False
- if not self.is_terminated():
+ if not self.state == self._terminal_st:
# Proceed to next part if there is no early termination
self.state = self._parse_boundary_st
elif callback_response:
diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py
index db8ad5d..cd194a8 100644
--- a/tests/unit/test_http.py
+++ b/tests/unit/test_http.py
@@ -141,7 +141,7 @@ def test_status_parsing_unsupported_method(self):
self.assertEqual(self.engine.method, b"NOTSUPORTED")
self.assertEqual(self.engine.url, b"/index.html")
self.assertEqual(self.engine.version, b"HTTP/1.1")
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
self.assertEqual(self.engine.status_code, 405)
def test_status_parsing_unsupported_version(self):
@@ -156,7 +156,7 @@ def test_status_parsing_unsupported_version(self):
self.assertEqual(self.engine.method, b"GET")
self.assertEqual(self.engine.url, b"/index.html")
self.assertEqual(self.engine.version, b"HTTP/2")
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
self.assertEqual(self.engine.status_code, 505)
def test_header_parsing_valid(self):
@@ -219,7 +219,7 @@ def test_routing_unsupported_method(self):
self.engine.state(self.rx)
self.assertEqual(self.engine.status_code, 405)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
self.assertIn(b"allow", self.engine.resp_headers)
self.assertIn(b"POST", self.engine.resp_headers)
@@ -237,7 +237,7 @@ def test_routing_options_method(self):
self.engine.state(self.rx)
self.assertEqual(self.engine.status_code, 204)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
self.assertIn(b"allow", self.engine.resp_headers)
self.assertIn(b"GET, POST, PUT", self.engine.resp_headers)
@@ -256,7 +256,6 @@ def test_routing_get_method(self):
self.engine.state(self.rx)
self.assertEqual(self.engine.status_code, 200)
- self.assertEqual(self.engine.state, None)
self.assertEqual(
int(self.engine._lookup(self.engine.resp_headers, b"content-length")),
len(test_response),
@@ -278,7 +277,6 @@ def test_routing_head_method(self):
self.engine.state(self.rx)
self.assertEqual(self.engine.status_code, 200)
- self.assertEqual(self.engine.state, None)
self.assertEqual(
int(self.engine._lookup(self.engine.resp_headers, b"content-length")),
len(test_response),
@@ -445,7 +443,7 @@ def test_chunked_transfer_encoding_valid(self):
)
self.assertEqual(self.engine.status_code, 200)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
def test_chunked_transfer_encoding_invalid_chunk_size_smaller(self):
self.engine.url = b"/api/test"
@@ -502,7 +500,6 @@ def test_payload_length_matches_content_length(self):
self.engine.state(self.rx)
self.assertEqual(self.engine.status_code, 200)
- self.assertEqual(self.engine.state, None)
test_callback.assert_called_with(self.engine, payload)
def test_payload_length_exceeds_content_length(self):
@@ -533,7 +530,6 @@ def test_payload_length_exceeds_content_length(self):
self.engine.state(self.rx)
self.assertEqual(self.engine.status_code, 200)
- self.assertEqual(self.engine.state, None)
test_callback.assert_called_with(self.engine, b"hello world")
self.assertEqual(
self.rx.peek(), b"!"
@@ -610,7 +606,7 @@ def test_file_serving_root(self, *_):
self.assertEqual(self.engine.resp_handler.read(), file_content)
self.assertEqual(self.engine.status_code, 200)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
@patch_os_stat()
def test_file_serving_subdir(self, *_):
@@ -626,7 +622,7 @@ def test_file_serving_subdir(self, *_):
self.assertEqual(self.engine.resp_handler.read(), file_content)
self.assertEqual(self.engine.status_code, 200)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
@patch_os_stat()
def test_file_serving_missing_file(self, *_):
@@ -638,7 +634,7 @@ def test_file_serving_missing_file(self, *_):
self.engine.state(self.rx)
self.assertEqual(self.engine.status_code, 404)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
@patch_os_stat()
def test_file_serving_known_content_type(self, *_):
@@ -658,7 +654,7 @@ def test_file_serving_known_content_type(self, *_):
b"application/javascript",
)
self.assertEqual(self.engine.status_code, 200)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
@patch_os_stat()
def test_file_serving_fallback_content_type(self, *_):
@@ -678,7 +674,7 @@ def test_file_serving_fallback_content_type(self, *_):
b"application/octet-stream",
)
self.assertEqual(self.engine.status_code, 200)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
if __name__ == "__main__":
diff --git a/tests/unit/test_http_file_server.py b/tests/unit/test_http_file_server.py
index 597f7ed..4131cac 100644
--- a/tests/unit/test_http_file_server.py
+++ b/tests/unit/test_http_file_server.py
@@ -44,7 +44,7 @@ def test_file_serving_missing_file(self, *_):
self.engine.state(self.rx)
self.assertEqual(self.engine.status_code, 404)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
@patch_os_stat()
def test_file_serving_files_endpoint(self, *_):
@@ -60,7 +60,7 @@ def test_file_serving_files_endpoint(self, *_):
self.assertEqual(self.engine.resp_handler.read(), file_content)
self.assertEqual(self.engine.status_code, 200)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
@patch_os_stat()
def test_file_serving_known_content_type(self, *_):
@@ -80,7 +80,7 @@ def test_file_serving_known_content_type(self, *_):
b"application/javascript",
)
self.assertEqual(self.engine.status_code, 200)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
@patch_os_stat()
def test_file_serving_fallback_content_type(self, *_):
@@ -100,7 +100,7 @@ def test_file_serving_fallback_content_type(self, *_):
b"application/octet-stream",
)
self.assertEqual(self.engine.status_code, 200)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
@patch_os_stat()
def test_file_serving_unserved_content_rejected(self, *_):
@@ -116,7 +116,7 @@ def test_file_serving_unserved_content_rejected(self, *_):
self.assertNotEqual(self.engine.resp_handler.read(), file_content)
self.assertEqual(self.engine.status_code, 403)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
@patch_os_stat(stat_is_file=False)
def test_file_serving_directory_path(self):
@@ -133,7 +133,7 @@ def test_file_serving_directory_path(self):
b"application/json",
)
self.assertEqual(self.engine.status_code, 200)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
@patch_os_stat()
def test_file_serving_directory_traversal(self):
@@ -179,7 +179,7 @@ def test_file_serving_missing_file(self, *_):
self.engine.state(self.rx)
self.assertEqual(self.engine.status_code, 404)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
@patch_os_stat()
def test_file_serving_non_user_data_rejected(self, *_):
@@ -191,7 +191,7 @@ def test_file_serving_non_user_data_rejected(self, *_):
self.engine.state(self.rx)
self.assertEqual(self.engine.status_code, 403)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
@patch_os_stat()
def test_file_serving_user_data_deleted(self, *_):
@@ -205,7 +205,7 @@ def test_file_serving_user_data_deleted(self, *_):
m.assert_called_once_with("/www/user_data/user_content.json")
self.assertEqual(self.engine.status_code, 204)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
@patch_os_stat(stat_is_file=True)
def test_file_serving_user_directory_deleted(self, *_):
@@ -219,7 +219,7 @@ def test_file_serving_user_directory_deleted(self, *_):
m.assert_called_once_with("/www/user_data/user_dir")
self.assertEqual(self.engine.status_code, 204)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
class TestFileServerUpload(TestHttpBase):
diff --git a/tests/unit/test_http_multipart.py b/tests/unit/test_http_multipart.py
index 740c488..406a2e3 100644
--- a/tests/unit/test_http_multipart.py
+++ b/tests/unit/test_http_multipart.py
@@ -150,7 +150,7 @@ def test_multipart_receiver_last_part(self):
self.engine.state(self.rx)
- self.assertEqual(self.engine.state, None)
+ self.assertEqual(self.engine.state, self.engine._terminal_st)
self.assertEqual(self.engine.status_code, 200)
test_callback.assert_called_once_with(
self.engine,