Skip to content

Commit 4e3aeb7

Browse files
committed
Refactor warning/error config checks format
- Removes RuleValidator class and uses a dict data structure with key as error message and value as the check - Extracts [endpoint element/protocol version/application identifier/node country] fetching to methods since they are part of the eIDASConfig and will be overriden by the eIDASIdPConfig
1 parent faba67f commit 4e3aeb7

File tree

3 files changed

+58
-99
lines changed

3 files changed

+58
-99
lines changed

src/saml2/config.py

Lines changed: 56 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from saml2.mdstore import MetadataStore
2525
from saml2.saml import NAME_FORMAT_URI
2626
from saml2.virtual_org import VirtualOrg
27-
from saml2.utility.config import RuleValidator, should_warning, must_error
27+
from saml2.utility.config import ConfigValidationError
2828

2929
logger = logging.getLogger(__name__)
3030

@@ -583,15 +583,17 @@ def ecp_endpoint(self, ipaddress):
583583

584584

585585
class eIDASConfig(Config):
586-
@classmethod
587-
def assert_not_declared(cls, error_signal):
588-
return (lambda x: x is None,
589-
partial(error_signal, message="not be declared"))
586+
def get_endpoint_element(self, element):
587+
pass
588+
589+
def get_protocol_version(self):
590+
pass
590591

591-
@classmethod
592-
def assert_declared(cls, error_signal):
593-
return (lambda x: x is not None,
594-
partial(error_signal, message="be declared"))
592+
def get_application_identifier(self):
593+
pass
594+
595+
def get_node_country(self):
596+
pass
595597

596598
@staticmethod
597599
def validate_node_country_format(node_country):
@@ -613,57 +615,54 @@ class eIDASSPConfig(SPConfig, eIDASConfig):
613615
def get_endpoint_element(self, element):
614616
return getattr(self, "_sp_endpoints", {}).get(element, None)
615617

618+
def get_application_identifier(self):
619+
return getattr(self, "_sp_application_identifier", None)
620+
621+
def get_protocol_version(self):
622+
return getattr(self, "_sp_protocol_version", None)
623+
624+
def get_node_country(self):
625+
return getattr(self, "_sp_node_country", None)
626+
616627
def validate(self):
617-
validators = [
618-
RuleValidator(
619-
"single_logout_service",
620-
self.get_endpoint_element("single_logout_service"),
621-
*self.assert_not_declared(should_warning)
622-
),
623-
RuleValidator(
624-
"artifact_resolution_service",
625-
self.get_endpoint_element("artifact_resolution_service"),
626-
*self.assert_not_declared(should_warning)
627-
),
628-
RuleValidator(
629-
"manage_name_id_service",
630-
self.get_endpoint_element("manage_name_id_service"),
631-
*self.assert_not_declared(should_warning)
632-
),
633-
RuleValidator(
634-
"KeyDescriptor",
635-
self.cert_file or self.encryption_keypairs,
636-
*self.assert_declared(must_error)
637-
),
638-
RuleValidator(
639-
"node_country",
640-
getattr(self, "_sp_node_country", None),
641-
self.validate_node_country_format,
642-
partial(must_error,
643-
message="be declared in ISO 3166-1 alpha-2 format")
644-
),
645-
RuleValidator(
646-
"application_identifier",
647-
getattr(self, "_sp_application_identifier", None),
648-
*self.assert_declared(should_warning)
649-
),
650-
RuleValidator(
651-
"application_identifier",
652-
getattr(self, "_sp_application_identifier", None),
653-
self.validate_application_identifier_format,
654-
partial(must_error,
655-
message="be in the form <vendor name>:<software identifier>"
656-
":<major-version>.<minor-version>[.<patch-version>]”")
657-
),
658-
RuleValidator(
659-
"protocol_version",
660-
getattr(self, "_sp_protocol_version", None),
661-
*self.assert_declared(should_warning)
628+
warning_validators = {
629+
"single_logout_service SHOULD NOT be declared":
630+
self.get_endpoint_element("single_logout_service") is None,
631+
"artifact_resolution_service SHOULD NOT be declared":
632+
self.get_endpoint_element("artifact_resolution_service") is None,
633+
"manage_name_id_service SHOULD NOT be declared":
634+
self.get_endpoint_element("manage_name_id_service") is None,
635+
"application_identifier SHOULD be declared":
636+
self.get_application_identifier() is not None,
637+
"protocol_version SHOULD be declared":
638+
self.get_protocol_version() is not None,
639+
}
640+
641+
if not all(warning_validators.values()):
642+
logger.warning(
643+
"Configuration validation warnings occurred: {}".format(
644+
[msg for msg, check in warning_validators.items()
645+
if check is not True]
646+
)
662647
)
663-
]
664648

665-
for validator in validators:
666-
validator.validate()
649+
error_validators = {
650+
"KeyDescriptor MUST be declared":
651+
self.cert_file or self.encryption_keypairs,
652+
"node_country MUST be declared in ISO 3166-1 alpha-2 format":
653+
self.validate_node_country_format(self.get_node_country()),
654+
"application_identifier MUST be in the form <vendor name>:<software "
655+
"identifier>:<major-version>.<minor-version>[.<patch-version>]":
656+
self.validate_application_identifier_format(
657+
self.get_application_identifier())
658+
}
659+
660+
if not all(error_validators.values()):
661+
error = "Configuration validation errors occurred:".format(
662+
[msg for msg, check in error_validators.items()
663+
if check is not True])
664+
logger.error(error)
665+
raise ConfigValidationError(error)
667666

668667

669668
class IdPConfig(Config):

src/saml2/utility/config.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,2 @@
1-
import logging
2-
3-
4-
logger = logging.getLogger(__name__)
5-
6-
71
class ConfigValidationError(Exception):
82
pass
9-
10-
11-
class RuleValidator(object):
12-
def __init__(self, element_name, element_value, validator, error_signal):
13-
"""
14-
:param element_name: the name of the element that will be
15-
validated
16-
:param element_value: function to be called
17-
with config as parameter to fetch an element value
18-
:param validator: function to be called
19-
with a config element value as a parameter
20-
:param error_signal: function to be called
21-
with an element name and value to signal an error (can be a log
22-
function, raise an error etc)
23-
"""
24-
self.element_name = element_name
25-
self.element_value = element_value
26-
self.validator = validator
27-
self.error_signal = error_signal
28-
29-
def validate(self):
30-
if not self.validator(self.element_value):
31-
self.error_signal(self.element_name)
32-
33-
34-
def should_warning(element_name, message):
35-
logger.warning("{element} SHOULD {message}".format(
36-
element=element_name, message=message))
37-
38-
39-
def must_error(element_name, message):
40-
error = "{element} MUST {message}".format(
41-
element=element_name, message=message)
42-
logger.error(error)
43-
raise ConfigValidationError(error)

tests/eidas/test_sp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,13 @@ def test_protocol_version_in_metadata(self, config):
109109
assert {str(conf._sp_protocol_version)} \
110110
== set([x.text for x in protocol_version.attribute_value])
111111

112+
112113
class TestSPConfig:
113114
@pytest.fixture(scope="function")
114115
def raise_error_on_warning(self, monkeypatch):
115116
def r(*args, **kwargs):
116117
raise ConfigValidationError()
117-
monkeypatch.setattr("saml2.utility.config.logger.warning", r)
118+
monkeypatch.setattr("saml2.config.logger.warning", r)
118119

119120
@pytest.fixture(scope="function")
120121
def config(self):

0 commit comments

Comments
 (0)