diff --git a/PyVOTCA/schemas.py b/PyVOTCA/schemas.py new file mode 100644 index 0000000..0398cff --- /dev/null +++ b/PyVOTCA/schemas.py @@ -0,0 +1,48 @@ + +"""Module with the schemas to validate the user input.""" + +from multiprocessing import cpu_count +from numbers import Integral +from typing import Any, Dict + +import yaml +from schema import Optional, Or, Schema, SchemaError + +__all__ = ["validate_input"] + + +input_schema = Schema({ + # Path to the molecule in xyz format + "molecule": str, + + # Number of Threads to run the application + Optional("threads", default=cpu_count()): Integral, + + # Functional + Optional("functional", default="PBE"): str, + + # Basisset + Optional("basis", default=None): Or(str, None), + + # AuxBasisset + Optional("auxbasis", default=None): Or(str, None), + + # GW + Optional("gw", default=None): Or(str, None), + + # BSE + Optional("bse", default=None): Or(str, None), +}) + + +def validate_input(file_input: str) -> Dict[str, Any]: + """Check the input validation against an schema.""" + with open(file_input, 'r') as handler: + dict_input = yaml.load(handler.read(), Loader=yaml.FullLoader) + try: + inp = input_schema.validate(dict_input) + return inp + except SchemaError as err: + msg = f"There was an error in the input yaml provided:\n{err}" + print(msg) + raise diff --git a/PyVOTCA/xtp_gradient.py b/PyVOTCA/xtp_gradient.py index 3cebcbc..22eaca1 100644 --- a/PyVOTCA/xtp_gradient.py +++ b/PyVOTCA/xtp_gradient.py @@ -6,6 +6,7 @@ import sys from .numerical_gradient import NumericalGradient +from .schemas import validate_input def exists(input_file: str) -> Path: @@ -35,22 +36,24 @@ def xtp_gradient(args: argparse.Namespace): def parse_user_arguments() -> argparse.Namespace: """Read the user arguments.""" parser = argparse.ArgumentParser("xtp_gradient") - parser.add_argument("-n", "--name", help="Molecule name") - parser.add_argument("-t", "--threads", help="Number of threads") + parser.add_argument( + "-i", "--input", help="Input file in YAML format", type=exists) # Read the arguments args = parser.parse_args() - if args.name is None: + if args.input is None: parser.print_help() sys.exit() - return args + return args.input def main(): - args = parse_user_arguments() - xtp_gradient(args) + inp = parse_user_arguments() + args = validate_input(inp) + print(args) + # xtp_gradient(args) if __name__ == "__main__": diff --git a/setup.py b/setup.py index ba347b1..8163790 100644 --- a/setup.py +++ b/setup.py @@ -38,8 +38,7 @@ 'xtp_gradient=PyVOTCA.xtp_gradient:main', ] }, - - install_requires=["h5py", "matplotlib", "numpy", "scipy"], + install_requires=["h5py", "matplotlib", "numpy", "pyyaml", "schema", "scipy"], extras_require={ 'test': ['coverage', 'mypy', 'pycodestyle', 'pytest>=3.9', 'pytest-asyncio', 'pytest-cov', 'pytest-mock'], diff --git a/tests/files/example.yml b/tests/files/example.yml new file mode 100644 index 0000000..7a7aeab --- /dev/null +++ b/tests/files/example.yml @@ -0,0 +1,2 @@ +molecule: test/files/ethylene.xyz +