Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tools/submission/power/power_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@


class LineWithoutTimeStamp(Exception):
""" Exception raised when there exists a line without a timestamp in the log file. """
pass


class CheckerWarning(Exception):
""" Exception raised internally when a Checker reports something wrong. """
pass


Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
""" Module for performing accuracy-related checks on MLPerf submission artifacts. """

from .base import BaseCheck
from ..constants import *
from ..loader import SubmissionLogs
Expand Down
7 changes: 7 additions & 0 deletions tools/submission/submission_checker/checks/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Define a base Checker class for MLPerf submission checks. """
from abc import ABC, abstractmethod


Expand All @@ -8,6 +9,12 @@ class BaseCheck(ABC):
"""

def __init__(self, log, path):
"""Initialize checker

Args:
log (Logger): A logger instance for logging check results and errors.
path (str): A path to the submission artifact being checked.
"""
self.checks = []
self.log = log
self.path = path
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

""" Module for performing compliance checks on MLPerf submission artifacts. """
from .base import BaseCheck
from ..constants import *
from ..loader import SubmissionLogs
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Module for performing measurement-related checks on MLPerf submission artifacts. """
from .base import BaseCheck
from ..constants import *
from ..loader import SubmissionLogs
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Module for performing performance-related checks on MLPerf submission artifacts. """
from .base import BaseCheck
from ..constants import *
from ..loader import SubmissionLogs
Expand Down
167 changes: 166 additions & 1 deletion tools/submission/submission_checker/checks/power/power_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("main")


class LineWithoutTimeStamp(Exception):
""" Exception raised when there exists a line without a timestamp in the log file. """
pass


class CheckerWarning(Exception):
""" Exception raised internally when a Checker reports something wrong. """
pass


Expand Down Expand Up @@ -108,6 +109,14 @@ def _sort_dict(x: Dict[str, Any]) -> "OrderedDict[str, Any]":


def hash_dir(dirname: str) -> Dict[str, str]:
"""For all files in a directory, create a dictionary that maps their name to their hash.

Args:
dirname (str): Directory to traverse

Returns:
Dict[str, str]: Map from fname to hash
"""
result: Dict[str, str] = {}

for path, dirs, files in os.walk(dirname, topdown=True):
Expand All @@ -125,6 +134,20 @@ def hash_dir(dirname: str) -> Dict[str, str]:
def get_time_from_line(
line: str, data_regexp: str, file: str, timezone_offset: int
) -> float:
"""Extract time from a given line using regex and save it as a UTC timestamp.

Args:
line (str): Line to be parsed.
data_regexp (str): Date format regex
file (str): File to be searched in. Used for logging and error-checking.
timezone_offset (int): Offset added to timezone as needed.

Raises:
LineWithoutTimeStamp: An exception raised if the regex does not find a timestamp in the line.

Returns:
float: UTC timestamp.
"""
log_time_str = re.search(data_regexp, line)
if log_time_str and log_time_str.group(0):
log_datetime = datetime.strptime(
Expand All @@ -135,13 +158,20 @@ def get_time_from_line(


class SessionDescriptor:
"""Class for holding and checking session descriptor data."""
def __init__(self, path: str):
"""Initialize session descriptor from JSON file.

Args:
path (str): Path to session descriptor.
"""
self.path = path
with open(path, "r") as f:
self.json_object: Dict[str, Any] = json.loads(f.read())
self.required_fields_check()

def required_fields_check(self) -> None:
"""Check that all required fields are present in session descriptor JSON."""
required_fields = [
"version",
"timezone",
Expand All @@ -161,6 +191,13 @@ def required_fields_check(self) -> None:

def compare_dicts_values(
d1: Dict[str, str], d2: Dict[str, str], comment: str) -> None:
"""Assert that all keys in d1 are in d2 and have the same value as the values in d1.

Args:
d1 (Dict[str, str]): Reference dictionary.
d2 (Dict[str, str]): Dictionary to be compared with.
comment (str): Comment for warning that will popup if assert fails.
"""
files_with_diff_check_sum = {k: d1[k]
for k in d1 if k in d2 and d1[k] != d2[k]}
assert len(files_with_diff_check_sum) == 0, f"{comment}" + "".join(
Expand All @@ -173,6 +210,13 @@ def compare_dicts_values(

def compare_dicts(s1: Dict[str, str],
s2: Dict[str, str], comment: str) -> None:
"""Ensure that the keys and values in s1 and s2 are the same.

Args:
s1 (Dict[str, str]): The first dictionary to work with.
s2 (Dict[str, str]): The second dictionary to work with.
comment (str): Comment for warning that will popup if assert fails.
"""
assert (
not s1.keys() - s2.keys()
), f"{comment} Missing {', '.join(sorted(s1.keys() - s2.keys()))!r}"
Expand Down Expand Up @@ -205,6 +249,14 @@ def ptd_messages_check(sd: SessionDescriptor) -> None:
msgs: List[Dict[str, str]] = sd.json_object["ptd_messages"]

def get_ptd_answer(command: str) -> str:
"""From the list of messages, return the reply from the first instance of a given command.

Args:
command (str): Name of command to look for.

Returns:
str: The reply of the command if found, otherwise an empty string.
"""
for msg in msgs:
if msg["cmd"] == command:
return msg["reply"]
Expand All @@ -226,6 +278,12 @@ def get_ptd_answer(command: str) -> str:
), f"Power meter {power_meter_model!r} is not supported. Only {', '.join(SUPPORTED_MODEL.keys())} are supported."

def check_reply(cmd: str, reply: str) -> None:
"""For a given command, look for a particular reply. If the reply is what is expected, continue. Otherwise raise AssertionError.

Args:
cmd (str): Command to look for.
reply (str): Expected reply.
"""
stop_counter = 0
for msg in msgs:
if msg["cmd"].startswith(cmd):
Expand All @@ -248,6 +306,15 @@ def check_reply(cmd: str, reply: str) -> None:
check_reply("Stop", "Stopping untimed measurement")

def get_initial_range(param_num: int, reply: str) -> str:
"""Get the initial range of a value from the power logs.

Args:
param_num (int): Identify which field of the power log we need to look into.
reply (str): Name of source of data, used for assertions

Returns:
str: Initial range when possible, otherwise "Auto".
"""
reply_list = reply.split(",")
try:
if reply_list[param_num] == "0" and float(
Expand All @@ -259,6 +326,15 @@ def get_initial_range(param_num: int, reply: str) -> str:

def get_command_by_value_and_number(
cmd: str, number: int) -> Optional[str]:
"""From the list of msgs, get the `number` occurence of the command if possible.

Args:
cmd (str): Command to look for.
number (int): Index of command. 1-indexed, so 1 will get the first instance of the command.

Returns:
Optional[str]: The full command if found, otherwise None.
"""
command_counter = 0
for msg in msgs:
if msg["cmd"].startswith(cmd):
Expand Down Expand Up @@ -347,6 +423,13 @@ def phases_check(
def compare_time(
phases_client: List[List[float]], phases_server: List[List[float]], mode: str
) -> None:
"""Compare the time difference between each checkpoint on the client and the server. If they are less than the TIME_DELTA_TOLERANCE, raise AssertionError.

Args:
phases_client (List[List[float]]): List of client checkpoints for ranging and testing.
phases_server (List[List[float]]): List of server checkpoints for ranging and testing.
mode (str): Mode information string. Used for assertion message.
"""
assert len(phases_client) == len(
phases_server
), f"Phases amount is not equal for {mode} mode."
Expand All @@ -361,6 +444,15 @@ def compare_time(
compare_time(phases_testing_c, phases_testing_s, TESTING_MODE)

def compare_duration(range_duration: float, test_duration: float) -> None:
"""Compare the duration between the range mode and the test mode. Fail if the duration difference is more than 5 percent.

Args:
range_duration (float): Length of range duration.
test_duration (float): Length of test duration.

Raises:
CheckerWarning: Raised if duration difference is more than 5 percent.
"""
duration_diff = (range_duration - test_duration) / range_duration

if duration_diff > 0.5:
Expand All @@ -373,6 +465,14 @@ def compare_duration(range_duration: float, test_duration: float) -> None:
def compare_time_boundaries(
begin: float, end: float, phases: List[Any], mode: str
) -> None:
"""Temporarily compare time boundaries between beginning and end, and raise an AssertionError if false.

Args:
begin (float): Beginning timestamp
end (float): Ending timestamp
phases (List[Any]): List of phases.
mode (str): Mode name, used for assertion message.
"""
# TODO: temporary workaround, remove when proper DST handling is
# implemented!
assert (
Expand Down Expand Up @@ -409,6 +509,15 @@ def compare_time_boundaries(
compare_duration(ranging_duration_d, testing_duration_d)

def get_avg_power(power_path: str, run_path: str) -> Tuple[float, float]:
"""Get the average power from the power log.

Args:
power_path (str): Path of power log. Unused.
run_path (str): Path of run log.

Returns:
Tuple[float, float]: Return average power and pf.
"""
# parse the power logs

power_begin, power_end = _get_begin_end_time_from_mlperf_log_detail(
Expand Down Expand Up @@ -512,6 +621,15 @@ def messages_check(client_sd: SessionDescriptor,
# Server.json contains all client.json messages and replies. Checked
# earlier.
def get_version(regexp: str, line: str) -> str:
"""Try to get client and server version within server.json using a given regex.

Args:
regexp (str): Regex to look for a particular system in.
line (str): Line of text to search.

Returns:
str: Version information if possible, otherwise returns AssertionError
"""
version_o = re.search(regexp, line)
assert version_o is not None, f"Server version is not defined in:'{line}'"
return version_o.group(1)
Expand Down Expand Up @@ -553,6 +671,11 @@ def results_check(
result_paths_copy.remove("power/client.json")

def remove_optional_path(res: Dict[str, str]) -> None:
"""Given a dictionary of string, hash pairs, delete the paths not in results_paths_copy.

Args:
res (Dict[str, str]): _description_
"""
keys = list(res.keys())
for path in keys:
# Ignore all the optional files.
Expand Down Expand Up @@ -591,6 +714,13 @@ def remove_optional_path(res: Dict[str, str]) -> None:
def result_files_compare(
res: Dict[str, str], ref_res: List[str], path: str
) -> None:
"""Check if all of the required files are present, using the result dictionary.

Args:
res (Dict[str, str]): Result dictionary.
ref_res (List[str]): Reference result dictionary.
path (str): Path to file. Used for AssertionError.
"""
# If a file is required (in ref_res) but is not present in results directory (res),
# then the submission is invalid.
absent_files = set(ref_res) - set(res.keys())
Expand Down Expand Up @@ -626,6 +756,16 @@ def check_ptd_logs(
ptd_log_lines = f.readlines()

def find_error_or_warning(reg_exp: str, line: str, error: bool) -> None:
"""Given a particular log line and a regex and an error, perform error handling. Any known and common testing error is accepted, while any known and common ranging error leads to an AssertionError.

Args:
reg_exp (str): Regex to find lines with errors or warnings.
line (str): Log line to check.
error (bool): Whether the line represents an error.

Raises:
CheckerWarning: If the line is a warning.
"""
problem_line = re.search(reg_exp, line)

if problem_line and problem_line.group(0):
Expand Down Expand Up @@ -672,6 +812,14 @@ def find_error_or_warning(reg_exp: str, line: str, error: bool) -> None:
start_ranging_line = f": Go with mark {ranging_mark!r}"

def get_msg_without_time(line: str) -> Optional[str]:
"""Try to extract the message contents of a log line without getting the timestamp.

Args:
line (str): Line of text to extract from.

Returns:
Optional[str]: Message content if possible, otherwise None.
"""
try:
get_time_from_line(line, date_regexp, file_path, timezone_offset)
except LineWithoutTimeStamp:
Expand Down Expand Up @@ -773,6 +921,15 @@ def debug_check(server_sd: SessionDescriptor) -> None:

def check_with_logging(
check_name: str, check: Callable[[], None]) -> Tuple[bool, bool]:
"""Try running `check`, but log any and all errors to a logfile, including tracebacks.

Args:
check_name (str): Name of check being ran. Used for logging.
check (Callable[[], None]): The check function being ran.

Returns:
Tuple[bool, bool]: A tuple of (No Errors Detected, Warnings Detected)
"""
try:
check()
except AssertionError as e:
Expand All @@ -794,6 +951,14 @@ def check_with_logging(


def check(path: str) -> int:
"""Run the power checker on a particular path.

Args:
path (str): Path to json files to be checked.

Returns:
int: 1 if there is an error otherwise 0. Used in sys.exit() call.
"""
client = SessionDescriptor(os.path.join(path, "power/client.json"))
server = SessionDescriptor(os.path.join(path, "power/server.json"))

Expand Down
2 changes: 2 additions & 0 deletions tools/submission/submission_checker/checks/power_check.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
""" Module for performing power-related checks on MLPerf submission artifacts. """

from .base import BaseCheck
from ..constants import *
from ..loader import SubmissionLogs
Expand Down
Loading