Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ disable=E0611,
[DESIGN]
max-attributes=15
max-public-methods=25
max-branches=15
3 changes: 2 additions & 1 deletion src/pyrobusta/bindings/http_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ async def _run_state_machine(self):
self._engine.set_response_body(b"Read error: " + str(e).encode("ascii"))

# [2] process request by state machine
for _ in self._engine.run(self._recv_buf):
while True:
self._engine.run(self._recv_buf)
if self._prev_state == self._engine.state:
# No state transition occurred, read more data
break
Expand Down
89 changes: 41 additions & 48 deletions src/pyrobusta/protocol/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,54 +357,54 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]:
"""
Basic parser to extract HTTP/MIME headers.
"""
header_lines = bytes(raw_headers).split(b"\r\n")
headers = {}
for line in header_lines:
# pylint: disable=W0511
if any(c > 127 for c in line):
start = 0
n = len(raw_headers)

while start < n:
end = start
colon = -1
while end < n:
c = raw_headers[end]
if c > 127:
raise InvalidHeaders()
if c == 58 and colon == -1:
colon = end
if end + 1 < n and c == 13 and raw_headers[end + 1] == 10:
break
end += 1

if colon in (-1, start):
raise InvalidHeaders()
if b":" not in line:
raise InvalidHeaders()
name, value = line.split(b":", 1)
if not name:
raise InvalidHeaders()
for c in name:
if (

for i in range(start, colon):
c = raw_headers[i]
if not (
48 <= c <= 57 # 0-9
or 65 <= c <= 90 # A-Z
or 97 <= c <= 122 # a-z
or c in (45, 95) # -_
):
continue
raise InvalidHeaders()
name = name.strip().lower().decode("ascii")
if any((c < 32 and c != 9) or c == 127 for c in value):
raise InvalidHeaders()

name = bytes(raw_headers[start:colon]).strip(b" ").lower().decode("ascii")
value_bytes = bytes(raw_headers[colon + 1 : end]).strip(b" ")

if any((c < 32 and c != 9) or c == 127 for c in value_bytes):
raise InvalidHeaders()
if name == "content-length":
value = int(value.strip())
if not all(48 <= c <= 57 for c in value_bytes):
raise InvalidHeaders()
value = int(value_bytes)
else:
value = value.strip().decode("ascii")
value = value_bytes.decode("ascii")
if name not in headers and value:
headers[name] = value
elif value:
headers[name] += ", " + value # Combined field value
return headers

@classmethod
def _parse_body_part(cls, part: memoryview) -> tuple[dict, bytes]:
"""
Parse part headers and body and return them as a tuple.
"""
blank_idx = -1
for i in range(len(part) - 3):
if part[i : i + 4] == b"\r\n\r\n":
blank_idx = i
break
if blank_idx == -1:
raise InvalidHeaders()
headers = cls._parse_headers(part[:blank_idx])
body = part[blank_idx + 4 :]
return headers, body
start = end + 2
return headers

# =========================================
# Helpers for state machine termination
Expand Down Expand Up @@ -469,8 +469,6 @@ def set_response_body(
:param body: body to be sent in the response
:param content_type: content-type of the body
"""
self._unset_response_handler()

if not body:
body_encoded = b""
if isinstance(body, (bytes, bytearray, memoryview)):
Expand All @@ -486,20 +484,17 @@ def set_response_body(
b"content-length", str(len(body_encoded)).encode("ascii")
)

# Unset and clean up existing handler if set
if type(self.resp_handler).__name__ in ("FileIO", "BytesIO"):
self.resp_handler.close()
self.resp_handler = None

if len(body_encoded):
self.set_response_header(b"content-type", content_type.encode("ascii"))

if self.method != self.HEAD:
self.resp_handler = BytesIO(body_encoded)

def _unset_response_handler(self):
"""
Unset the response handler (if set).
"""
if type(self.resp_handler).__name__ in ("FileIO", "BytesIO"):
self.resp_handler.close()
self.resp_handler = None

def do_keep_alive(self):
"""
Determine if the connection should be kept alive
Expand Down Expand Up @@ -553,7 +548,7 @@ def abort(self, status_code: int):
:param status_code: HTTP status code
"""
self.resp_headers = []
self._unset_response_handler()
self.set_response_body(b"")
self.terminate(status_code)

def is_request_empty(self):
Expand All @@ -572,15 +567,13 @@ def run(self, rx):
"""
Run the state machine, consuming the content of a request buffer (rx).
Unlike individual states, this method does not raise an exception.
This method yields on every state transition allowing the calling side
This method returns on every state transition allowing the calling side
to flush the response buffer.
"""
if self.is_terminated():
return
try:
while not self.is_terminated():
self.state(rx)
yield
self.state(rx)
except BufferFullError:
self.abort(500)
self.set_response_body(b"Buffer full")
Expand Down
15 changes: 13 additions & 2 deletions src/pyrobusta/protocol/http_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# pylint: disable=W0212,R0401

from pyrobusta.protocol import http
from pyrobusta.utils.helpers import add_method, add_property, patch_extra_property
from pyrobusta.utils.patch import add_method, add_property, patch_extra_property


def generate_multipart_response(self, callback: callable, dtype: str):
Expand Down Expand Up @@ -147,6 +147,7 @@ def _parse_boundary_st(self, rx):
return

if is_last and self.content_len_cnt + rx.size() < self.headers["content-length"]:
# Wait for optional trailing newline
return

self.state = self._parse_complete_part_st
Expand All @@ -165,7 +166,17 @@ def _parse_complete_part_st(self, rx):
and rx.peek(len(self.mp_last_delimiter)) == self.mp_last_delimiter
)

part_headers, part_body = http.HttpEngine._parse_body_part(part)
# Parse part headers and part body
blank_idx = -1
for i in range(len(part) - 3):
if part[i : i + 4] == b"\r\n\r\n":
blank_idx = i
break
if blank_idx == -1:
raise http.InvalidHeaders()
part_headers = http.HttpEngine._parse_headers(part[:blank_idx])
part_body = part[blank_idx + 4 :]

handler = http.HttpEngine._get_handler(self.url, self.method)

# Process complete part
Expand Down
42 changes: 0 additions & 42 deletions src/pyrobusta/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,45 +77,3 @@ def is_path_segment_valid(filename: str):
):
return False
return True


def add_method(cls, func: callable, method_type="instance"):
"""
Helper to patch/extend classes with additional methods and states.
:param func: function to add
:param method_type: type of the method (instance, static, class)
"""
if method_type == "instance":
setattr(cls, func.__name__, func)
elif method_type == "static":
setattr(cls, func.__name__, staticmethod(func))
elif method_type == "class":
setattr(cls, func.__name__, classmethod(func))
else:
raise ValueError("Invalid type")


def add_property(cls, getter: callable, setter: callable = None):
"""
Add a property to a class.
"""
setattr(cls, getter.__name__, property(getter, setter))


# pylint: disable=W0212
def patch_extra_property(cls, name):
"""
Add a property to 'cls' that stores its value in the instance's
'_extras' dictionary. Intended for '__slots__' classes that cannot
have arbitrary instance attributes.
"""

def getter(self):
return self._extras.get(name) if self._extras else None

def setter(self, value):
if self._extras is None:
self._extras = {}
self._extras[name] = value

setattr(cls, name, property(getter, setter))
46 changes: 46 additions & 0 deletions src/pyrobusta/utils/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
Helper methods for patching classes
"""

# pylint: disable=W0212


def add_method(cls, func: callable, method_type="instance"):
"""
Helper to patch/extend classes with additional methods and states.
:param func: function to add
:param method_type: type of the method (instance, static, class)
"""
if method_type == "instance":
setattr(cls, func.__name__, func)
elif method_type == "static":
setattr(cls, func.__name__, staticmethod(func))
elif method_type == "class":
setattr(cls, func.__name__, classmethod(func))
else:
raise ValueError("Invalid type")


def add_property(cls, getter: callable, setter: callable = None):
"""
Add a property to a class.
"""
setattr(cls, getter.__name__, property(getter, setter))


def patch_extra_property(cls, name):
"""
Add a property to 'cls' that stores its value in the instance's
'_extras' dictionary. Intended for '__slots__' classes that cannot
have arbitrary instance attributes.
"""

def getter(self):
return self._extras.get(name) if self._extras else None

def setter(self, value):
if self._extras is None:
self._extras = {}
self._extras[name] = value

setattr(cls, name, property(getter, setter))
1 change: 0 additions & 1 deletion tests/unit/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ def test_header_parsing_incomplete_header(self):

def test_header_parsing_error(self):
for case in (
b"",
b":",
b": value",
b" leading-space: value",
Expand Down
Loading