Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions salt/netapi/rest_cherrypy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import logging
import os

import salt.utils.args

from salt.utils.versions import Version

# pylint: disable=C0103
Expand All @@ -27,6 +29,22 @@
cpy_min = "3.2.2"


def parse_timeout(value, name="timeout"):
"""
Parse a timeout value from config or request data.
"""
if value in (None, ""):
return None
parsed = salt.utils.args.yamlify_arg(value)
if parsed is None:
return None
if isinstance(parsed, bool) or not isinstance(parsed, (int, float)):
raise ValueError(f"{name} must be a number")
Comment on lines +41 to +42
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't isinstance(parsed, bool) match not isinstance(parsed, (int, float)) so shouldn't not isinstance(parsed, (int, float)) be sufficient?

    if not isinstance(parsed, (int, float)):
        raise ValueError(f"{name} must be a number")

if parsed < 0:
raise ValueError(f"{name} must be >= 0")
return parsed


def __virtual__():
short_name = __name__.rsplit(".", maxsplit=1)[-1]
mod_opts = __opts__.get(short_name, {})
Expand Down
18 changes: 18 additions & 0 deletions salt/netapi/rest_cherrypy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,7 @@
import salt.utils.stringutils
import salt.utils.versions
import salt.utils.yaml
from . import parse_timeout

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1125,6 +1126,23 @@ def lowdata_fmt():
else:
cherrypy.serving.request.lowstate = data

_normalize_lowstate_timeouts(cherrypy.serving.request.lowstate)


def _normalize_lowstate_timeouts(lowstate):
if not lowstate:
return
chunks = lowstate if isinstance(lowstate, list) else [lowstate]
for chunk in chunks:
if not isinstance(chunk, Mapping):
continue
if "timeout" not in chunk:
continue
try:
chunk["timeout"] = parse_timeout(chunk["timeout"])
except ValueError as exc:
raise cherrypy.HTTPError(400, str(exc))


tools_config = {
"on_start_resource": [
Expand Down
35 changes: 24 additions & 11 deletions salt/netapi/rest_cherrypy/event_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,21 @@

import salt.netapi
import salt.utils.json
from . import parse_timeout

logger = logging.getLogger(__name__)


def _get_rest_timeout(opts):
try:
return parse_timeout(opts.get("rest_timeout"), name="rest_timeout")
except ValueError as exc:
logger.warning(
"Invalid rest_timeout value %r: %s", opts.get("rest_timeout"), exc
)
return None


class SaltInfo:
"""
Class to handle processing and publishing of "real time" Salt upates.
Expand Down Expand Up @@ -166,17 +177,19 @@ def process_presence_events(self, event_data, token, opts):
if tgt:
changed = True
client = salt.netapi.NetapiClient(opts)
client.run(
{
"fun": "grains.items",
"tgt": tgt,
"expr_type": "list",
"mode": "client",
"client": "local",
"asynchronous": "local_async",
"token": token,
}
)
low = {
"fun": "grains.items",
"tgt": tgt,
"expr_type": "list",
"mode": "client",
"client": "local",
"asynchronous": "local_async",
"token": token,
}
timeout = _get_rest_timeout(opts)
if timeout is not None:
low["timeout"] = timeout
client.run(low)

if changed:
self.publish_minions()
Expand Down
26 changes: 26 additions & 0 deletions tests/pytests/unit/netapi/cherrypy/test_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest

from salt.netapi.rest_cherrypy import parse_timeout


@pytest.mark.parametrize(
"value, expected",
[
("60", 60),
(60, 60),
("2.5", 2.5),
("0", 0),
("None", None),
("", None),
(None, None),
],
)
def test_parse_timeout_valid(value, expected):
assert parse_timeout(value) == expected


@pytest.mark.parametrize("value", ["nope", {}, [], True, -1, "-5"])
def test_parse_timeout_invalid(value):
with pytest.raises(ValueError):
parse_timeout(value)