Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 17 additions & 47 deletions src/pyrobusta/protocol/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,7 @@ class HttpEngine:
"recv_chunk_size",
"is_req_empty",
"_is_req_complete",
"mp_boundary",
"mp_is_first",
"mp_is_last",
"mp_delimiter",
"mp_last_delimiter",
"_extras",
)

ROUTES = [] # (route, handler, HTTP method)
Expand All @@ -93,6 +89,8 @@ class HttpEngine:
b"408 Request Timeout",
413,
b"413 Content Too Large",
415,
b"415 Unsupported Media Type",
500,
b"500 Internal Server Error",
503,
Expand Down Expand Up @@ -163,8 +161,8 @@ def __init__(self):
self.is_req_empty = True
self._is_req_complete = False

# [Multipart state]
self.mp_boundary = None
# [Extras]
self._extras = None

def reset(self):
"""
Expand All @@ -184,7 +182,7 @@ def reset(self):
self.recv_chunk_size = 0
self.is_req_empty = True
self._is_req_complete = False
self.mp_boundary = None
self._extras = None

# =========================================
# Methods/decorators for routing
Expand Down Expand Up @@ -364,12 +362,12 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]:
for line in header_lines:
# pylint: disable=W0511
if any(c > 127 for c in line):
raise InvalidHeaders("Non-ASCII character")
raise InvalidHeaders()
if b":" not in line:
raise InvalidHeaders()
name, value = line.split(b":", 1)
if not name:
raise InvalidHeaders("Empty header name")
raise InvalidHeaders()
for c in name:
if (
48 <= c <= 57 # 0-9
Expand All @@ -378,10 +376,10 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]:
or c in (45, 95) # -_
):
continue
raise InvalidHeaders("Invalid header name")
raise InvalidHeaders()
name = name.strip().lower().decode("ascii")
if any((c < 32 and c != 9) or c == 127 for c in value):
raise InvalidHeaders("Invalid header value")
raise InvalidHeaders()
if name == "content-length":
value = int(value.strip())
else:
Expand All @@ -392,38 +390,6 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]:
headers[name] += ", " + value # Combined field value
return headers

@staticmethod
def _get_mp_boundary(headers: dict) -> str:
"""
Determine from the headers if a request is multipart,
and return the boundary value.
"""
content_type = headers.get("content-type")
if not content_type or not content_type.lower().startswith("multipart/"):
return None

parts = content_type.split(";")
for part in parts[1:]:
if "=" not in part:
continue
key, value = part.strip().split("=", 1)

if key.strip().lower() != "boundary":
continue
value = value.strip()

if value.startswith('"'):
if len(value) < 2 or not value.endswith('"'):
raise InvalidHeaders()
value = value[1:-1]
elif value.endswith('"'):
raise InvalidHeaders()

if not value:
raise InvalidHeaders()
return value
raise InvalidHeaders()

@classmethod
def _parse_body_part(cls, part: memoryview) -> tuple[dict, bytes]:
"""
Expand Down Expand Up @@ -756,9 +722,7 @@ def _route_request_st(self, _):
if self.has_payload():
if self.method in (self.GET, self.HEAD):
raise MalformedRequest()
if mp_boundary := self._get_mp_boundary(self.headers):
# Request body is multipart
self.mp_boundary = mp_boundary.encode("ascii")
if self.is_multipart():
self.state = self._start_multipart_parser_st
elif self.is_chunked():
# Request body is chunked
Expand All @@ -784,6 +748,9 @@ def _route_request_st(self, _):
return
# Fallback: serve file
if self.method in (self.GET, self.HEAD):
if self.has_payload():
raise MalformedRequest()
self._is_req_complete = True
self.state = self._fs_retrieve_st
return
self.terminate(404)
Expand Down Expand Up @@ -925,6 +892,9 @@ def _terminal_st(self, rx): # pylint: disable=W0613
):
self.set_response_header(b"content-length", b"0")

if not self.get_response_header(b"cache-control"):
self.set_response_header(b"cache-control", b"no-store")

self.state = None


Expand Down
4 changes: 2 additions & 2 deletions src/pyrobusta/protocol/http_file_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ def delete_file(http_ctx, _):
return "text/plain", "Directory not empty"
rmdir(fs_path)
http_ctx.terminate(204)
return "text/plain", "Deleted"
return "text/plain", "OK"

# Delete file
remove(fs_path)
http_ctx.terminate(204)
return "text/plain", "Deleted"
return "text/plain", "OK"
except OSError:
http_ctx.terminate(404)
return "text/plain", "Not found"
Expand Down
71 changes: 58 additions & 13 deletions src/pyrobusta/protocol/http_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# pylint: disable=W0212,R0401

from pyrobusta.protocol import http
from pyrobusta.utils.helpers import add_method
from pyrobusta.utils.helpers import add_method, add_property, patch_extra_property


def generate_multipart_response(self, callback: callable, dtype: str):
Expand All @@ -23,7 +23,7 @@ def generate_multipart_response(self, callback: callable, dtype: str):
:param dtype: exact multipart content-type (multipart/*)
"""
if not callable(callback):
raise ValueError("Invalid function callback")
raise ValueError("Invalid callback")

boundary = self.MULTIPART_BOUNDARY
self.set_response_header(
Expand Down Expand Up @@ -82,6 +82,38 @@ def _multipart_wrapper(tx):
return _multipart_wrapper


def _get_mp_boundary(headers: dict) -> str:
"""
Determine from the headers if a request is multipart,
and return the boundary value.
"""
content_type = headers.get("content-type")
if not content_type or not content_type.lower().startswith("multipart/"):
return None

parts = content_type.split(";")
for part in parts[1:]:
if "=" not in part:
continue
key, value = part.strip().split("=", 1)

if key.strip().lower() != "boundary":
continue
value = value.strip()

if value.startswith('"'):
if len(value) < 2 or not value.endswith('"'):
raise http.InvalidHeaders()
value = value[1:-1]
elif value.endswith('"'):
raise http.InvalidHeaders()

if not value:
raise http.InvalidHeaders()
return value
raise http.InvalidHeaders()


def _start_multipart_parser_st(self, rx):
"""
Initial state for processing multipart requests.
Expand All @@ -90,13 +122,17 @@ def _start_multipart_parser_st(self, rx):
"""
if not "content-length" in self.headers:
raise http.InvalidContentLength()

self.mp_boundary = _get_mp_boundary(self.headers).encode("ascii")

if (start_delimiter := rx.find(b"\r\n")) == -1:
return
self.mp_delimiter = b"--" + self.mp_boundary + b"\r\n"
self.mp_last_delimiter = b"--" + self.mp_boundary + b"--"

if rx.peek(start_delimiter + 2) != self.mp_delimiter:
raise http.MalformedRequest()
self._consume_payload(rx, start_delimiter + 2)
self.mp_is_first = True
self.mp_is_last = False
self.state = self._parse_boundary_st


Expand Down Expand Up @@ -168,20 +204,29 @@ def apply_patches():
"""
Apply patches to class attributes for multipart parsing.
"""
orig_init = http.HttpEngine.__init__

def new_init(self, *args, **kwargs):
orig_init(self, *args, **kwargs)
self.mp_is_first = True
self.mp_is_last = False
self.mp_delimiter = None
self.mp_last_delimiter = None
def mp_delimiter(self):
if self.mp_boundary is None:
return None
return b"--" + self.mp_boundary + b"\r\n"

http.HttpEngine.__init__ = new_init
http.HttpEngine.MULTIPART_BOUNDARY = b"pyrobusta-boundary"
def mp_last_delimiter(self):
if self.mp_boundary is None:
return None
return b"--" + self.mp_boundary + b"--"

add_property(http.HttpEngine, mp_delimiter)
add_property(http.HttpEngine, mp_last_delimiter)

patch_extra_property(http.HttpEngine, "mp_boundary")
patch_extra_property(http.HttpEngine, "mp_is_first")
patch_extra_property(http.HttpEngine, "mp_is_last")

add_method(http.HttpEngine, generate_multipart_response)
add_method(http.HttpEngine, _get_mp_boundary, "static")
add_method(http.HttpEngine, _multipart_wrapper_factory, "static")
add_method(http.HttpEngine, _start_multipart_parser_st)
add_method(http.HttpEngine, _parse_boundary_st)
add_method(http.HttpEngine, _parse_complete_part_st)

http.HttpEngine.MULTIPART_BOUNDARY = b"pyrobusta-boundary"
26 changes: 26 additions & 0 deletions src/pyrobusta/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,29 @@ def add_method(cls, func: callable, method_type="instance"):
setattr(cls, func.__name__, classmethod(func))
else:
raise ValueError("Invalid type")


def add_property(cls, getter: callable, setter: callable = None):
    """
    Attach a property to a class under the getter's own name.

    :param cls: class receiving the property
    :param getter: function used to read the property; its __name__
        becomes the attribute name on 'cls'
    :param setter: optional function used to assign the property;
        when omitted the property has no setter
    """
    prop = property(fget=getter, fset=setter)
    setattr(cls, getter.__name__, prop)


# pylint: disable=W0212
def patch_extra_property(cls, name):
    """
    Install a property called 'name' on 'cls' whose value is kept in the
    instance's '_extras' dictionary instead of a dedicated attribute.
    Meant for '__slots__' classes, which cannot take arbitrary instance
    attributes.

    Reading yields None while '_extras' is None or empty, or has no entry
    for 'name'. The first write creates the dictionary when '_extras' is
    None.

    :param cls: class to patch; instances must expose '_extras'
    :param name: property name, also used as the key in '_extras'
    """

    def _read(self):
        extras = self._extras
        if not extras:
            return None
        return extras.get(name)

    def _write(self, value):
        extras = self._extras
        if extras is None:
            # Allocate lazily so untouched instances carry no dict.
            extras = self._extras = {}
        extras[name] = value

    setattr(cls, name, property(_read, _write))
32 changes: 12 additions & 20 deletions tests/unit/test_http_file_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,9 @@ def test_file_serving_single_file_upload(self, *_):
self.engine.version = b"HTTP/1.1"

self.engine.headers["content-length"] = 151
self.engine.headers["content-type"] = "multipart/form-data"

self.engine.mp_boundary = b"test-boundary"
self.engine.mp_delimiter = b"--test-boundary\r\n"
self.engine.mp_last_delimiter = b"--test-boundary--"
self.engine.headers["content-type"] = (
'multipart/form-data;boundary="test-boundary"'
)

self.engine.state = self.engine._start_multipart_parser_st
body_part = (
Expand Down Expand Up @@ -421,11 +419,9 @@ def test_file_serving_multiple_file_upload(self, *_):
self.engine.version = b"HTTP/1.1"

self.engine.headers["content-length"] = 287
self.engine.headers["content-type"] = "multipart/form-data"

self.engine.mp_boundary = b"test-boundary"
self.engine.mp_delimiter = b"--test-boundary\r\n"
self.engine.mp_last_delimiter = b"--test-boundary--"
self.engine.headers["content-type"] = (
'multipart/form-data;boundary="test-boundary"'
)

self.engine.state = self.engine._start_multipart_parser_st
body_part = (
Expand Down Expand Up @@ -488,11 +484,9 @@ def test_file_serving_single_file_multiple_parts_upload(self, *_):
self.engine.version = b"HTTP/1.1"

self.engine.headers["content-length"] = 285
self.engine.headers["content-type"] = "multipart/form-data"

self.engine.mp_boundary = b"test-boundary"
self.engine.mp_delimiter = b"--test-boundary\r\n"
self.engine.mp_last_delimiter = b"--test-boundary--"
self.engine.headers["content-type"] = (
'multipart/form-data;boundary="test-boundary"'
)

self.engine.state = self.engine._start_multipart_parser_st
body_part = (
Expand Down Expand Up @@ -543,11 +537,9 @@ def test_file_serving_multiple_file_chunked_upload(self, *_):
self.engine.method = b"POST"
self.engine.version = b"HTTP/1.1"
self.engine.headers["content-length"] = 565
self.engine.headers["content-type"] = "multipart/form-data"

self.engine.mp_boundary = b"test-boundary"
self.engine.mp_delimiter = b"--test-boundary\r\n"
self.engine.mp_last_delimiter = b"--test-boundary--"
self.engine.headers["content-type"] = (
'multipart/form-data;boundary="test-boundary"'
)

self.engine.state = self.engine._start_multipart_parser_st
body_part = (
Expand Down
Loading
Loading