diff --git a/pybatfish/client/internal.py b/pybatfish/client/internal.py index 9c7bf96a..1f94d28e 100644 --- a/pybatfish/client/internal.py +++ b/pybatfish/client/internal.py @@ -15,44 +15,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from pybatfish.client import restv2helper -from pybatfish.datamodel.answer import Answer -from pybatfish.util import get_uuid - -from . import workhelper -from .options import Options if TYPE_CHECKING: from pybatfish.client.session import Session -def _bf_answer_obj( - session: Session, - question_str: str, - question_name: str, - background: bool, - snapshot: str, - reference_snapshot: str | None, - extra_args: dict[str, Any] | None, -) -> Answer | str: - if not question_name: - question_name = Options.default_question_prefix + "_" + get_uuid() - - # Upload the question - restv2helper.upload_question(session, question_name, question_str) - - # Answer the question - work_item = workhelper.get_workitem_answer(session, question_name, snapshot, reference_snapshot) - workhelper.execute(work_item, session, background, extra_args) - - if background: - return work_item.id - - # get the answer - return session.get_answer(question_name, snapshot, reference_snapshot) - - def _bf_get_question_templates(session: Session, verbose: bool = False) -> dict: return restv2helper.get_question_templates(session, verbose) diff --git a/pybatfish/client/session.py b/pybatfish/client/session.py index 8f7d3756..fa1c482c 100644 --- a/pybatfish/client/session.py +++ b/pybatfish/client/session.py @@ -632,6 +632,45 @@ def generate_dataplane( answer_dict = workhelper.execute(work_item, self, extra_args=extra_args) return str(answer_dict["status"].value) + def answer_question( + self, + question_str: str, + question_name: str, + background: bool, + snapshot: str, + reference_snapshot: str | None, + extra_args: dict[str, Any] | None, + ) -> Answer | str: + """ + Upload, execute, and return the answer for a question. + + Subclasses can override this method to change how questions are + answered (e.g. to use a different backend). + + :param question_str: JSON string representing the question + :param question_name: unique name for the question + :param background: if True, return immediately with work item ID + :param snapshot: snapshot on which to answer the question + :param reference_snapshot: reference snapshot for differential questions + :param extra_args: extra arguments to pass with the question + :return: Answer object, or work item ID string if background=True + """ + if not question_name: + question_name = Options.default_question_prefix + "_" + get_uuid() + + # Upload the question + restv2helper.upload_question(self, question_name, question_str) + + # Answer the question + work_item = workhelper.get_workitem_answer(self, question_name, snapshot, reference_snapshot) + workhelper.execute(work_item, self, background, extra_args) + + if background: + return work_item.id + + # Get the answer + return self.get_answer(question_name, snapshot, reference_snapshot) + def get_answer(self, question: str, snapshot: str, reference_snapshot: str | None = None) -> Answer: """ Get the answer for a previously asked question. diff --git a/pybatfish/question/question.py b/pybatfish/question/question.py index 9212516a..2856f617 100644 --- a/pybatfish/question/question.py +++ b/pybatfish/question/question.py @@ -25,7 +25,7 @@ import attr -from pybatfish.client.internal import _bf_answer_obj, _bf_get_question_templates +from pybatfish.client.internal import _bf_get_question_templates from pybatfish.datamodel import Assertion, AssertionType, BgpRoute, VariableType from pybatfish.datamodel.answer.base import Answer from pybatfish.exception import QuestionValidationException @@ -120,7 +120,7 @@ def __dir__(self): class QuestionBase: """All questions inherit functionality from this class.""" - def __init__(self, dictionary, session): + def __init__(self, dictionary: dict, session: "Session"): self._dict = deepcopy(dictionary) self._session = session @@ -160,8 +160,7 @@ def answer( _validate(self.dict()) if include_one_table_keys is not None: self._set_include_one_table_keys(include_one_table_keys) - return _bf_answer_obj( - session=session, + return session.answer_question( question_str=self.json(), question_name=self.get_name(), background=background,