Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ disable=E0611,
R1710

[DESIGN]
max-attributes=15
max-attributes=15
max-public-methods=25
4 changes: 2 additions & 2 deletions dist/pyrobusta/assets/www/examples.html
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ <h2>Demo Application</h2>
include_server_version = False

if http_ctx.query:
is_detailed = http_ctx.get_url_encoded_query_param(
http_ctx.query, "detailed", default="false"
is_detailed = http_ctx.get_query_param(
"detailed", default="false"
).lower()

if is_detailed not in ("true", "false"):
Expand Down
4 changes: 2 additions & 2 deletions example/demo_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def version(http_ctx, _):
include_server_version = False

if http_ctx.query:
is_detailed = http_ctx.get_url_encoded_query_param(
http_ctx.query, "detailed", default="false"
is_detailed = http_ctx.get_query_param(
"detailed", default="false"
).lower()

if is_detailed not in ("true", "false"):
Expand Down
6 changes: 2 additions & 4 deletions example/mem_usage/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@ def mem_usage(http_ctx, _):
usage_percentage = 100 * used / (free + used)

if http_ctx.query:
value_format = http_ctx.get_url_encoded_query_param(
http_ctx.query, "format", "bytes"
)
value_format = http_ctx.get_query_param("format", "bytes")
if value_format not in ("%", "bytes"):
raise ValueError("invalid format")

selector = http_ctx.get_url_encoded_query_param(http_ctx.query, "key", "")
selector = http_ctx.get_query_param("key", "")
if selector == "free":
if value_format == "%":
free = round(100 * free / (used + free), 2)
Expand Down
87 changes: 53 additions & 34 deletions src/pyrobusta/protocol/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def decorator(func):
return decorator

# =========================================
# Static helpers for parsing
# Helpers for parsing
# =========================================

@staticmethod
Expand All @@ -255,29 +255,31 @@ def percent_decode(s: str):
i += 1
return "".join(out)

@staticmethod
def get_url_encoded_query_param(query: str, key: str, default: str = None):
def get_query_param(self, key: str, default: str = None) -> str:
"""
Parse a query and return the value belonging to a key
according to the x-www-form-urlencoded format.
:param query: query part
:param key: key to parse from the query
:param default: default value to return when key is not present
:return: value of the key or default
"""
if query.startswith(key + "="):
if not self.query or not key:
return default

if self.query.startswith(key + "="):
idx_start = 0
elif (idx_start := query.find("&" + key + "=")) != -1:
elif (idx_start := self.query.find("&" + key + "=")) != -1:
idx_start += 1
elif default is None:
raise KeyError()
else:
return default

idx_end = -1
idx_end = query.find("&", idx_start)
idx_end = self.query.find("&", idx_start)
if idx_end > -1:
return query[idx_start + len(key) + 1 : idx_end]
return query[idx_start + len(key) + 1 :]
return self.query[idx_start + len(key) + 1 : idx_end]
return self.query[idx_start + len(key) + 1 :]

@staticmethod
def _is_matching_url_path(path: bytes, pattern: bytes) -> bool:
Expand Down Expand Up @@ -524,6 +526,27 @@ def do_keep_alive(self):
self.version == b"HTTP/1.1" and "close" not in connection_tokens
)

def _handle_route_response(self, callback_response: tuple | None):
"""
Terminate the state machine based on the return value of a
user-defined route handler. If the handler does not explicitly
set a status code, default to HTTP 200. If the handler returns
a response body and content type, set them accordingly.
"""
if not self.is_terminated():
self.terminate(200, True)

if callback_response is None:
return

dtype, data = callback_response
if dtype.startswith("multipart/") and callable(data):
self.set_response_header(b"transfer-encoding", b"chunked")
self.generate_multipart_response(data, dtype)
return

self.set_response_body(data, content_type=dtype)

def terminate(self, status_code: int, request_complete: bool = False):
"""
Regular state machine termination with a specific status code.
Expand Down Expand Up @@ -605,7 +628,13 @@ def is_chunked(self):
"""
Determines if the request has a payload with chunked transfer-encoding.
"""
return self.headers.get("transfer-encoding") == "chunked"
return self.headers.get("transfer-encoding", "").lower() == "chunked"

def is_multipart(self):
    """
    Tell whether the request carries a multipart payload, judged by a
    case-insensitive "multipart/" prefix of its content-type header.
    """
    content_type = self.headers.get("content-type", "")
    return content_type.lower().startswith("multipart/")

def has_payload(self):
"""
Expand Down Expand Up @@ -787,13 +816,17 @@ def _app_endpoint_st(self, rx):
if self.has_payload():
if self.is_chunked():
if self.recv_chunk_size:
callback(self, bytes(rx.peek(self.recv_chunk_size)))
callback_response = callback(
self, bytes(rx.peek(self.recv_chunk_size))
)
self._consume_payload(rx, self.recv_chunk_size + 2)
self.state = self._recv_chunk_size_st
return
# Last chunk, callback with empty body to signal end of request body
callback_response = callback(self, b"")
self._consume_payload(rx, self.recv_chunk_size + 2, last=True)
if not self.is_terminated():
self.state = self._recv_chunk_size_st
return
else:
# Last chunk, callback with empty body to signal end of request body
callback_response = callback(self, b"")
self._consume_payload(rx, self.recv_chunk_size + 2, last=True)
else:
callback_response = callback(
self, bytes(rx.peek(self.headers["content-length"]))
Expand All @@ -802,19 +835,7 @@ def _app_endpoint_st(self, rx):
else:
callback_response = callback(self, b"")

if not self.is_terminated():
self.terminate(200, True)

if callback_response is None:
return

dtype, data = callback_response
if dtype.startswith("multipart/") and callable(data):
self.set_response_header(b"transfer-encoding", b"chunked")
self.state = lambda _rx: self._generate_multipart_response(_rx, data, dtype)
return

self.set_response_body(data, content_type=dtype)
self._handle_route_response(callback_response)

def _fs_retrieve_st(self, _):
"""
Expand Down Expand Up @@ -860,15 +881,13 @@ def _start_multipart_parser_st(self, rx): # pylint: disable=W0613
"""
Initial state for processing multipart requests (placeholder).
"""
self.terminate(503)
self.abort(503)

def _generate_multipart_response(
self, rx, callback, dtype
): # pylint: disable=W0613
def generate_multipart_response(self, callback, dtype): # pylint: disable=W0613
"""
Generate multipart response depending on the exact content type (placeholder).
"""
self.terminate(503, True)
self.abort(503)


def enable_optional_features():
Expand Down
50 changes: 20 additions & 30 deletions src/pyrobusta/protocol/http_file_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,42 +113,33 @@ def upload_file(http_ctx, payload: bytes):
Callback function for handling single file uploads, supporting chunked transfer encoding.
Uploads are saved to _UPLOAD_ROOT, with the name determined by the URL path.
"""
content_type = http_ctx.headers.get("content-type")
if content_type and content_type.lower().startswith("multipart/"):
target_path = http_ctx.url.decode("ascii")[6:]

if http_ctx.is_multipart() or not is_file_path_valid(target_path):
http_ctx.terminate(400)
return "text/plain", "Bad request"

is_chunked = http_ctx.headers.get("transfer-encoding") == "chunked"

if is_chunked:
url_path = http_ctx.url.decode("ascii")
file_name_idx = url_path.rfind("/") + 1
if not file_name_idx:
http_ctx.terminate(400)
return "text/plain", "Bad request"
file_path = _TMP_DIR + "/" + f"{url_path[file_name_idx:]}.{http_ctx.id}"
else:
file_path = normalize_path(http_ctx.url.decode("ascii")[6:])

if not is_file_path_valid(file_path):
http_ctx.terminate(400)
return "text/plain", "Invalid or missing filename"
if not normalize_path(target_path).startswith(_UPLOAD_ROOT):
http_ctx.terminate(403, True)
return "text/plain", "Forbidden"

try:
if not file_path.startswith(_UPLOAD_ROOT) and not file_path.startswith(
_TMP_DIR
):
http_ctx.terminate(403, True)
return "text/plain", "Forbidden"
if http_ctx.is_chunked():
file_name_idx = target_path.rfind("/") + 1
if not file_name_idx:
http_ctx.terminate(400)
return "text/plain", "Bad request"

tmp_path = _TMP_DIR + "/" + f"{target_path[file_name_idx:]}.{http_ctx.id}"

if is_chunked:
if not payload: # Last chunk received, finalize upload
rename(file_path, normalize_path(http_ctx.url.decode("ascii")[6:]))
else:
with open(file_path, "ab") as f:
if payload: # Wait for more chunks before setting response status
with open(tmp_path, "ab") as f:
f.write(payload)
return
# Last chunk received, finalize upload
rename(tmp_path, normalize_path(target_path))
else:
with open(file_path, "wb") as f:
with open(normalize_path(target_path), "wb") as f:
f.write(payload)

http_ctx.terminate(201, True)
Expand All @@ -166,8 +157,7 @@ def bulk_upload_file(http_ctx, payload: tuple):
same file name, the content of the second part is appended to the first part.
Split large files into multiple parts (chunking) to avoid HTTP 413 errors.
"""
content_type = http_ctx.headers.get("content-type")
if not content_type or not content_type.lower().startswith("multipart/form-data"):
if not http_ctx.is_multipart():
http_ctx.terminate(400)
return "text/plain", "Bad request"

Expand Down
22 changes: 12 additions & 10 deletions src/pyrobusta/protocol/http_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@
from pyrobusta.utils.helpers import add_method


def _generate_multipart_response(self, _, callback: callable, dtype: str):
def generate_multipart_response(self, callback: callable, dtype: str):
"""
Generate multipart response depending on the exact content type.
The callback function is called without arguments, and it must return bytes-like objects.
:param callback: function for part generation, each call generates a separate part
:param dtype: exact multipart content-type (multipart/*)
"""
if type(callback).__name__ not in ("function", "closure"):
if not callable(callback):
raise ValueError("Invalid response handler")
self.terminate(200, True)

boundary = self.MULTIPART_BOUNDARY
self.set_response_header(
b"content-type", dtype.encode("ascii") + b"; boundary=" + boundary
Expand Down Expand Up @@ -134,12 +134,16 @@ def _parse_complete_part_st(self, rx):

# Process complete part
if not is_final:
callback(self, (part_headers, part_body))
callback_response = callback(self, (part_headers, part_body))
if rx.peek(len(self.mp_delimiter)) != self.mp_delimiter:
raise http.MalformedRequest()
self._consume_payload(rx, len(self.mp_delimiter))
self.mp_is_first = False
self.state = self._parse_boundary_st
if not self.is_terminated():
# Proceed to next part if there is no early termination
self.state = self._parse_boundary_st
elif callback_response:
self._handle_route_response(callback_response)
return

# Process last part
Expand All @@ -155,11 +159,9 @@ def _parse_complete_part_st(self, rx):
self._consume_payload(rx, 0, last=True)

self.mp_is_last = True
dtype, data = callback(self, (part_headers, part_body))
callback_response = callback(self, (part_headers, part_body))

if not self.is_terminated():
self.terminate(200, True)
self.set_response_body(data, dtype)
self._handle_route_response(callback_response)


def apply_patches():
Expand All @@ -178,7 +180,7 @@ def new_init(self, *args, **kwargs):
http.HttpEngine.__init__ = new_init
http.HttpEngine.MULTIPART_BOUNDARY = b"pyrobusta-boundary"

add_method(http.HttpEngine, _generate_multipart_response)
add_method(http.HttpEngine, generate_multipart_response)
add_method(http.HttpEngine, _multipart_wrapper_factory, "static")
add_method(http.HttpEngine, _start_multipart_parser_st)
add_method(http.HttpEngine, _parse_boundary_st)
Expand Down
9 changes: 8 additions & 1 deletion tests/unit/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,14 @@ def test_path_segment_validation(self):

def test_file_path_validation(self):
valid_paths = ["/file", "/dir1/file", "/dir-2/file", "/dir_3/file"]
invalid_paths = ["file", "dir1/file", "/dir\\segment/file"]
invalid_paths = [
"file",
"dir1/file",
"/dir\\segment/file",
"/",
"/dir/",
"/dir/file/",
]

for path in valid_paths:
self.assertTrue(self.helpers_module.is_file_path_valid(path))
Expand Down
Loading
Loading