diff --git a/.pylintrc b/.pylintrc
index 60a1103..ce7476a 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -6,4 +6,5 @@ disable=E0611,
R1710
[DESIGN]
-max-attributes=15
\ No newline at end of file
+max-attributes=15
+max-public-methods=25
diff --git a/dist/pyrobusta/assets/www/examples.html b/dist/pyrobusta/assets/www/examples.html
index e0cdfd9..1ade6de 100644
--- a/dist/pyrobusta/assets/www/examples.html
+++ b/dist/pyrobusta/assets/www/examples.html
@@ -90,8 +90,8 @@
Demo Application
include_server_version = False
if http_ctx.query:
- is_detailed = http_ctx.get_url_encoded_query_param(
- http_ctx.query, "detailed", default="false"
+ is_detailed = http_ctx.get_query_param(
+ "detailed", default="false"
).lower()
if is_detailed not in ("true", "false"):
diff --git a/example/demo_app/app.py b/example/demo_app/app.py
index 6817989..8cf96db 100644
--- a/example/demo_app/app.py
+++ b/example/demo_app/app.py
@@ -12,8 +12,8 @@ def version(http_ctx, _):
include_server_version = False
if http_ctx.query:
- is_detailed = http_ctx.get_url_encoded_query_param(
- http_ctx.query, "detailed", default="false"
+ is_detailed = http_ctx.get_query_param(
+ "detailed", default="false"
).lower()
if is_detailed not in ("true", "false"):
diff --git a/example/mem_usage/app.py b/example/mem_usage/app.py
index bfde590..f4bb084 100644
--- a/example/mem_usage/app.py
+++ b/example/mem_usage/app.py
@@ -13,13 +13,11 @@ def mem_usage(http_ctx, _):
usage_percentage = 100 * used / (free + used)
if http_ctx.query:
- value_format = http_ctx.get_url_encoded_query_param(
- http_ctx.query, "format", "bytes"
- )
+ value_format = http_ctx.get_query_param("format", "bytes")
if value_format not in ("%", "bytes"):
raise ValueError("invalid format")
- selector = http_ctx.get_url_encoded_query_param(http_ctx.query, "key", "")
+ selector = http_ctx.get_query_param("key", "")
if selector == "free":
if value_format == "%":
free = round(100 * free / (used + free), 2)
diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py
index 12ff6d2..8e7146b 100644
--- a/src/pyrobusta/protocol/http.py
+++ b/src/pyrobusta/protocol/http.py
@@ -236,7 +236,7 @@ def decorator(func):
return decorator
# =========================================
- # Static helpers for parsing
+ # Helpers for parsing
# =========================================
@staticmethod
@@ -255,18 +255,20 @@ def percent_decode(s: str):
i += 1
return "".join(out)
- @staticmethod
- def get_url_encoded_query_param(query: str, key: str, default: str = None):
+ def get_query_param(self, key: str, default: str = None) -> str:
"""
Parse a query and return the value belonging to a key
according to the x-www-form-urlencoded format.
- :param query: query part
:param key: key to parse from the query
:param default: default value to return when key is not present
+ :return: value of the key or default
"""
- if query.startswith(key + "="):
+ if not self.query or not key:
+ return default
+
+ if self.query.startswith(key + "="):
idx_start = 0
- elif (idx_start := query.find("&" + key + "=")) != -1:
+ elif (idx_start := self.query.find("&" + key + "=")) != -1:
idx_start += 1
elif default is None:
raise KeyError()
@@ -274,10 +276,10 @@ def get_url_encoded_query_param(query: str, key: str, default: str = None):
return default
idx_end = -1
- idx_end = query.find("&", idx_start)
+ idx_end = self.query.find("&", idx_start)
if idx_end > -1:
- return query[idx_start + len(key) + 1 : idx_end]
- return query[idx_start + len(key) + 1 :]
+ return self.query[idx_start + len(key) + 1 : idx_end]
+ return self.query[idx_start + len(key) + 1 :]
@staticmethod
def _is_matching_url_path(path: bytes, pattern: bytes) -> bool:
@@ -524,6 +526,27 @@ def do_keep_alive(self):
self.version == b"HTTP/1.1" and "close" not in connection_tokens
)
+ def _handle_route_response(self, callback_response: tuple | None):
+ """
+ Terminate the state machine based on the return value of a
+ user-defined route handler. If the handler does not explicitly
+ set a status code, default to HTTP 200. If the handler returns
+ a response body and content type, set them accordingly.
+ """
+ if not self.is_terminated():
+ self.terminate(200, True)
+
+ if callback_response is None:
+ return
+
+ dtype, data = callback_response
+ if dtype.startswith("multipart/") and callable(data):
+ self.set_response_header(b"transfer-encoding", b"chunked")
+ self.generate_multipart_response(data, dtype)
+ return
+
+ self.set_response_body(data, content_type=dtype)
+
def terminate(self, status_code: int, request_complete: bool = False):
"""
Regular state machine termination with a specific status code.
@@ -605,7 +628,13 @@ def is_chunked(self):
"""
Determines if the request has a payload with chunked transfer-encoding.
"""
- return self.headers.get("transfer-encoding") == "chunked"
+ return self.headers.get("transfer-encoding", "").lower() == "chunked"
+
+ def is_multipart(self):
+ """
+ Determines if the request has a multipart payload.
+ """
+ return self.headers.get("content-type", "").lower().startswith("multipart/")
def has_payload(self):
"""
@@ -787,13 +816,17 @@ def _app_endpoint_st(self, rx):
if self.has_payload():
if self.is_chunked():
if self.recv_chunk_size:
- callback(self, bytes(rx.peek(self.recv_chunk_size)))
+ callback_response = callback(
+ self, bytes(rx.peek(self.recv_chunk_size))
+ )
self._consume_payload(rx, self.recv_chunk_size + 2)
- self.state = self._recv_chunk_size_st
- return
- # Last chunk, callback with empty body to signal end of request body
- callback_response = callback(self, b"")
- self._consume_payload(rx, self.recv_chunk_size + 2, last=True)
+ if not self.is_terminated():
+ self.state = self._recv_chunk_size_st
+ return
+ else:
+ # Last chunk, callback with empty body to signal end of request body
+ callback_response = callback(self, b"")
+ self._consume_payload(rx, self.recv_chunk_size + 2, last=True)
else:
callback_response = callback(
self, bytes(rx.peek(self.headers["content-length"]))
@@ -802,19 +835,7 @@ def _app_endpoint_st(self, rx):
else:
callback_response = callback(self, b"")
- if not self.is_terminated():
- self.terminate(200, True)
-
- if callback_response is None:
- return
-
- dtype, data = callback_response
- if dtype.startswith("multipart/") and callable(data):
- self.set_response_header(b"transfer-encoding", b"chunked")
- self.state = lambda _rx: self._generate_multipart_response(_rx, data, dtype)
- return
-
- self.set_response_body(data, content_type=dtype)
+ self._handle_route_response(callback_response)
def _fs_retrieve_st(self, _):
"""
@@ -860,15 +881,13 @@ def _start_multipart_parser_st(self, rx): # pylint: disable=W0613
"""
Initial state for processing multipart requests (placeholder).
"""
- self.terminate(503)
+ self.abort(503)
- def _generate_multipart_response(
- self, rx, callback, dtype
- ): # pylint: disable=W0613
+ def generate_multipart_response(self, callback, dtype): # pylint: disable=W0613
"""
- Generate multipart response depening on the exact content type (placeholder).
+ Generate multipart response depending on the exact content type (placeholder).
"""
- self.terminate(503, True)
+ self.abort(503)
def enable_optional_features():
diff --git a/src/pyrobusta/protocol/http_file_server.py b/src/pyrobusta/protocol/http_file_server.py
index 83bf3b3..a674d31 100644
--- a/src/pyrobusta/protocol/http_file_server.py
+++ b/src/pyrobusta/protocol/http_file_server.py
@@ -113,42 +113,33 @@ def upload_file(http_ctx, payload: bytes):
Callback function for handling single file uploads, supporting chunked transfer encoding.
Uploads are saved to _UPLOAD_ROOT, with the name determined by the URL path.
"""
- content_type = http_ctx.headers.get("content-type")
- if content_type and content_type.lower().startswith("multipart/"):
+ target_path = http_ctx.url.decode("ascii")[6:]
+
+ if http_ctx.is_multipart() or not is_file_path_valid(target_path):
http_ctx.terminate(400)
return "text/plain", "Bad request"
- is_chunked = http_ctx.headers.get("transfer-encoding") == "chunked"
-
- if is_chunked:
- url_path = http_ctx.url.decode("ascii")
- file_name_idx = url_path.rfind("/") + 1
- if not file_name_idx:
- http_ctx.terminate(400)
- return "text/plain", "Bad request"
- file_path = _TMP_DIR + "/" + f"{url_path[file_name_idx:]}.{http_ctx.id}"
- else:
- file_path = normalize_path(http_ctx.url.decode("ascii")[6:])
-
- if not is_file_path_valid(file_path):
- http_ctx.terminate(400)
- return "text/plain", "Invalid or missing filename"
+ if not normalize_path(target_path).startswith(_UPLOAD_ROOT):
+ http_ctx.terminate(403, True)
+ return "text/plain", "Forbidden"
try:
- if not file_path.startswith(_UPLOAD_ROOT) and not file_path.startswith(
- _TMP_DIR
- ):
- http_ctx.terminate(403, True)
- return "text/plain", "Forbidden"
+ if http_ctx.is_chunked():
+ file_name_idx = target_path.rfind("/") + 1
+ if not file_name_idx:
+ http_ctx.terminate(400)
+ return "text/plain", "Bad request"
+
+ tmp_path = _TMP_DIR + "/" + f"{target_path[file_name_idx:]}.{http_ctx.id}"
- if is_chunked:
- if not payload: # Last chunk received, finalize upload
- rename(file_path, normalize_path(http_ctx.url.decode("ascii")[6:]))
- else:
- with open(file_path, "ab") as f:
+ if payload: # Wait for more chunks before setting response status
+ with open(tmp_path, "ab") as f:
f.write(payload)
+ return
+ # Last chunk received, finalize upload
+ rename(tmp_path, normalize_path(target_path))
else:
- with open(file_path, "wb") as f:
+ with open(normalize_path(target_path), "wb") as f:
f.write(payload)
http_ctx.terminate(201, True)
@@ -166,8 +157,7 @@ def bulk_upload_file(http_ctx, payload: tuple):
same file name, the content of the second part is appended to the first part.
Split files to multiple parts for chunking large files to avoid HTTP 413 errors.
"""
- content_type = http_ctx.headers.get("content-type")
- if not content_type or not content_type.lower().startswith("multipart/form-data"):
+ if not http_ctx.is_multipart():
http_ctx.terminate(400)
return "text/plain", "Bad request"
diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py
index c21f264..86af41b 100644
--- a/src/pyrobusta/protocol/http_multipart.py
+++ b/src/pyrobusta/protocol/http_multipart.py
@@ -15,16 +15,16 @@
from pyrobusta.utils.helpers import add_method
-def _generate_multipart_response(self, _, callback: callable, dtype: str):
+def generate_multipart_response(self, callback: callable, dtype: str):
"""
- Generate multipart response depening on the exact content type.
+ Generate multipart response depending on the exact content type.
The callback function is called without arguments, and it must return bytes-like objects.
:param callback: function for part generation, each call generates a separate part
:param dtype: exact multipart content-type (multipart/*)
"""
- if type(callback).__name__ not in ("function", "closure"):
+ if not callable(callback):
raise ValueError("Invalid response handler")
- self.terminate(200, True)
+
boundary = self.MULTIPART_BOUNDARY
self.set_response_header(
b"content-type", dtype.encode("ascii") + b"; boundary=" + boundary
@@ -134,12 +134,16 @@ def _parse_complete_part_st(self, rx):
# Process complete part
if not is_final:
- callback(self, (part_headers, part_body))
+ callback_response = callback(self, (part_headers, part_body))
if rx.peek(len(self.mp_delimiter)) != self.mp_delimiter:
raise http.MalformedRequest()
self._consume_payload(rx, len(self.mp_delimiter))
self.mp_is_first = False
- self.state = self._parse_boundary_st
+ if not self.is_terminated():
+ # Proceed to next part if there is no early termination
+ self.state = self._parse_boundary_st
+ elif callback_response:
+ self._handle_route_response(callback_response)
return
# Process last part
@@ -155,11 +159,9 @@ def _parse_complete_part_st(self, rx):
self._consume_payload(rx, 0, last=True)
self.mp_is_last = True
- dtype, data = callback(self, (part_headers, part_body))
+ callback_response = callback(self, (part_headers, part_body))
- if not self.is_terminated():
- self.terminate(200, True)
- self.set_response_body(data, dtype)
+ self._handle_route_response(callback_response)
def apply_patches():
@@ -178,7 +180,7 @@ def new_init(self, *args, **kwargs):
http.HttpEngine.__init__ = new_init
http.HttpEngine.MULTIPART_BOUNDARY = b"pyrobusta-boundary"
- add_method(http.HttpEngine, _generate_multipart_response)
+ add_method(http.HttpEngine, generate_multipart_response)
add_method(http.HttpEngine, _multipart_wrapper_factory, "static")
add_method(http.HttpEngine, _start_multipart_parser_st)
add_method(http.HttpEngine, _parse_boundary_st)
diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py
index 2d85271..7214a41 100644
--- a/tests/unit/test_helpers.py
+++ b/tests/unit/test_helpers.py
@@ -121,7 +121,14 @@ def test_path_segment_validation(self):
def test_file_path_validation(self):
valid_paths = ["/file", "/dir1/file", "/dir-2/file", "/dir_3/file"]
- invalid_paths = ["file", "dir1/file", "/dir\\segment/file"]
+ invalid_paths = [
+ "file",
+ "dir1/file",
+ "/dir\\segment/file",
+ "/",
+ "/dir/",
+ "/dir/file/",
+ ]
for path in valid_paths:
self.assertTrue(self.helpers_module.is_file_path_valid(path))
diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py
index a02907c..db8ad5d 100644
--- a/tests/unit/test_http.py
+++ b/tests/unit/test_http.py
@@ -319,9 +319,7 @@ def test_single_url_encoded_query_parameter(self):
self.rx.write(request[i : i + 1])
self.engine.state(self.rx)
- self.assertEqual(
- self.engine.get_url_encoded_query_param(self.engine.query, "param"), "value"
- )
+ self.assertEqual(self.engine.get_query_param("param"), "value")
def test_multiple_url_encoded_query_parameter(self):
request = (
@@ -333,15 +331,15 @@ def test_multiple_url_encoded_query_parameter(self):
self.engine.state(self.rx)
self.assertEqual(
- self.engine.get_url_encoded_query_param(self.engine.query, "param1"),
+ self.engine.get_query_param("param1"),
"value1",
)
self.assertEqual(
- self.engine.get_url_encoded_query_param(self.engine.query, "param2"),
+ self.engine.get_query_param("param2"),
"value2",
)
self.assertEqual(
- self.engine.get_url_encoded_query_param(self.engine.query, "param3"),
+ self.engine.get_query_param("param3"),
"value3",
)
@@ -353,22 +351,20 @@ def test_empty_or_missing_url_encoded_query_parameter(self):
self.engine.state(self.rx)
self.assertEqual(
- self.engine.get_url_encoded_query_param(self.engine.query, "param1"),
+ self.engine.get_query_param("param1"),
"",
)
self.assertEqual(
- self.engine.get_url_encoded_query_param(self.engine.query, "param2"),
+ self.engine.get_query_param("param2"),
"",
)
self.assertEqual(
- self.engine.get_url_encoded_query_param(
- self.engine.query, "param3", "default"
- ),
+ self.engine.get_query_param("param3", "default"),
"default",
)
with self.assertRaises(KeyError):
- self.engine.get_url_encoded_query_param(self.engine.query, "param3")
+ self.engine.get_query_param("param3")
def test_overlapping_url_encoded_query_parameter(self):
request = b"GET /api/test?data=value1&ta=value2&a=value3 HTTP/1.1\r\n"
@@ -378,15 +374,15 @@ def test_overlapping_url_encoded_query_parameter(self):
self.engine.state(self.rx)
self.assertEqual(
- self.engine.get_url_encoded_query_param(self.engine.query, "data"),
+ self.engine.get_query_param("data"),
"value1",
)
self.assertEqual(
- self.engine.get_url_encoded_query_param(self.engine.query, "ta"),
+ self.engine.get_query_param("ta"),
"value2",
)
self.assertEqual(
- self.engine.get_url_encoded_query_param(self.engine.query, "a"),
+ self.engine.get_query_param("a"),
"value3",
)
diff --git a/tests/unit/test_http_file_server.py b/tests/unit/test_http_file_server.py
index 2bfb210..597f7ed 100644
--- a/tests/unit/test_http_file_server.py
+++ b/tests/unit/test_http_file_server.py
@@ -276,6 +276,23 @@ def test_file_serving_complete_file_invalid_name(self, *_):
self.assertEqual(self.engine.status_code, 400)
+ def test_file_serving_complete_file_invalid_path(self, *_):
+ self.engine.url = b"/files/www/user_data/file/"
+ self.engine.method = b"PUT"
+ self.engine.version = b"HTTP/1.1"
+
+ self.engine.headers["content-length"] = 28
+ self.engine.headers["content-type"] = "application/octet-stream"
+
+ self.engine.state = self.engine._app_endpoint_st
+ body_part = b"File uploaded for testing.\r\n"
+ self.rx.write(body_part)
+
+ while self.engine.state is not None:
+ self.engine.state(self.rx)
+
+ self.assertEqual(self.engine.status_code, 400)
+
def test_file_serving_chunked_file_upload(self, *_):
self.engine.url = b"/files/www/user_data/chunked.txt"
self.engine.method = b"PUT"