Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 1 addition & 32 deletions pybatfish/client/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,44 +15,13 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from pybatfish.client import restv2helper
from pybatfish.datamodel.answer import Answer
from pybatfish.util import get_uuid

from . import workhelper
from .options import Options

if TYPE_CHECKING:
from pybatfish.client.session import Session


def _bf_answer_obj(
session: Session,
question_str: str,
question_name: str,
background: bool,
snapshot: str,
reference_snapshot: str | None,
extra_args: dict[str, Any] | None,
) -> Answer | str:
if not question_name:
question_name = Options.default_question_prefix + "_" + get_uuid()

# Upload the question
restv2helper.upload_question(session, question_name, question_str)

# Answer the question
work_item = workhelper.get_workitem_answer(session, question_name, snapshot, reference_snapshot)
workhelper.execute(work_item, session, background, extra_args)

if background:
return work_item.id

# get the answer
return session.get_answer(question_name, snapshot, reference_snapshot)


def _bf_get_question_templates(session: Session, verbose: bool = False) -> dict:
return restv2helper.get_question_templates(session, verbose)
39 changes: 39 additions & 0 deletions pybatfish/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,45 @@ def generate_dataplane(
answer_dict = workhelper.execute(work_item, self, extra_args=extra_args)
return str(answer_dict["status"].value)

def answer_question(
self,
question_str: str,
question_name: str,
background: bool,
snapshot: str,
reference_snapshot: str | None,
extra_args: dict[str, Any] | None,
) -> Answer | str:
"""
Upload, execute, and return the answer for a question.

Subclasses can override this method to change how questions are
answered (e.g. to use a different backend).

:param question_str: JSON string representing the question
:param question_name: unique name for the question
:param background: if True, return immediately with work item ID
:param snapshot: snapshot on which to answer the question
:param reference_snapshot: reference snapshot for differential questions
:param extra_args: extra arguments to pass with the question
:return: Answer object, or work item ID string if background=True
"""
if not question_name:
question_name = Options.default_question_prefix + "_" + get_uuid()

# Upload the question
restv2helper.upload_question(self, question_name, question_str)

# Answer the question
work_item = workhelper.get_workitem_answer(self, question_name, snapshot, reference_snapshot)
workhelper.execute(work_item, self, background, extra_args)

if background:
return work_item.id

# Get the answer
return self.get_answer(question_name, snapshot, reference_snapshot)

def get_answer(self, question: str, snapshot: str, reference_snapshot: str | None = None) -> Answer:
"""
Get the answer for a previously asked question.
Expand Down
7 changes: 3 additions & 4 deletions pybatfish/question/question.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import attr

from pybatfish.client.internal import _bf_answer_obj, _bf_get_question_templates
from pybatfish.client.internal import _bf_get_question_templates
from pybatfish.datamodel import Assertion, AssertionType, BgpRoute, VariableType
from pybatfish.datamodel.answer.base import Answer
from pybatfish.exception import QuestionValidationException
Expand Down Expand Up @@ -120,7 +120,7 @@ def __dir__(self):
class QuestionBase:
"""All questions inherit functionality from this class."""

def __init__(self, dictionary, session):
def __init__(self, dictionary: dict, session: "Session"):
self._dict = deepcopy(dictionary)
self._session = session

Expand Down Expand Up @@ -160,8 +160,7 @@ def answer(
_validate(self.dict())
if include_one_table_keys is not None:
self._set_include_one_table_keys(include_one_table_keys)
return _bf_answer_obj(
session=session,
return session.answer_question(
question_str=self.json(),
question_name=self.get_name(),
background=background,
Expand Down