-
Notifications
You must be signed in to change notification settings - Fork 37
Simplify tabular API #43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,59 +2,66 @@ | |
| import csv | ||
| import warnings | ||
|
|
||
| from dowel import TabularInput | ||
| import numpy as np | ||
|
|
||
| from dowel.simple_outputs import FileOutput | ||
| from dowel.tabular import Tabular | ||
| from dowel.utils import colorize | ||
|
|
||
|
|
||
| class CsvOutput(FileOutput): | ||
| """CSV file output for logger. | ||
|
|
||
| :param file_name: The file this output should log to. | ||
| :param keys_accepted: Regex for which keys this output should accept. | ||
| """ | ||
|
|
||
| def __init__(self, file_name): | ||
| super().__init__(file_name) | ||
| def __init__(self, file_name, keys_accepted=r'^\S+$'): | ||
| super().__init__(file_name, keys_accepted=keys_accepted) | ||
| self._writer = None | ||
| self._fieldnames = None | ||
| self._warned_once = set() | ||
| self._disable_warnings = False | ||
| self.tabular = Tabular() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| @property | ||
| def types_accepted(self): | ||
| """Accept TabularInput objects only.""" | ||
| return (TabularInput, ) | ||
|
|
||
| def record(self, data, prefix=''): | ||
| """Log tabular data to CSV.""" | ||
| if isinstance(data, TabularInput): | ||
| to_csv = data.as_primitive_dict | ||
|
|
||
| if not to_csv.keys() and not self._writer: | ||
| return | ||
|
|
||
| if not self._writer: | ||
| self._fieldnames = set(to_csv.keys()) | ||
| self._writer = csv.DictWriter( | ||
| self._log_file, | ||
| fieldnames=self._fieldnames, | ||
| extrasaction='ignore') | ||
| self._writer.writeheader() | ||
|
|
||
| if to_csv.keys() != self._fieldnames: | ||
| self._warn('Inconsistent TabularInput keys detected. ' | ||
| 'CsvOutput keys: {}. ' | ||
| 'TabularInput keys: {}. ' | ||
| 'Did you change key sets after your first ' | ||
| 'logger.log(TabularInput)?'.format( | ||
| set(self._fieldnames), set(to_csv.keys()))) | ||
|
|
||
| self._writer.writerow(to_csv) | ||
|
|
||
| for k in to_csv.keys(): | ||
| data.mark(k) | ||
| else: | ||
| raise ValueError('Unacceptable type.') | ||
| """Accept str and scalar objects.""" | ||
| return (str, ) + np.ScalarType | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just make this a tuple rather than using addition? |
||
|
|
||
| def record(self, key, value, prefix=''): | ||
| """Log data to a csv file.""" | ||
| self.tabular.record(key, value) | ||
|
|
||
| def dump(self, step=None): | ||
| """Flush data to log file.""" | ||
| if self.tabular.empty: | ||
| return | ||
|
|
||
| to_csv = self.tabular.as_primitive_dict | ||
|
|
||
| if not to_csv.keys() and not self._writer: | ||
| return | ||
|
|
||
| if not self._writer: | ||
| self._fieldnames = set(to_csv.keys()) | ||
| self._writer = csv.DictWriter(self._log_file, | ||
| fieldnames=self._fieldnames, | ||
| extrasaction='ignore') | ||
| self._writer.writeheader() | ||
|
|
||
| if to_csv.keys() != self._fieldnames: | ||
| self._warn('Inconsistent Tabular keys detected. ' | ||
| 'CsvOutput keys: {}. ' | ||
| 'Tabular keys: {}. ' | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Your user now has no idea what a Tabular is, so this message needs to be updated. |
||
| 'Did you change key sets after your first ' | ||
| 'logger.log(Tabular)?'.format(set(self._fieldnames), | ||
| set(to_csv.keys()))) | ||
|
|
||
| self._writer.writerow(to_csv) | ||
|
|
||
| self._log_file.flush() | ||
| self.tabular.clear() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think we should clear the table between calls to dump, because keeping the previous values allows us to provide a value for a key even if someone doesn't update it. Basically, it's okay if KV pairs are not all updated at the same rate, and we don't need to output an error. |
||
|
|
||
| def _warn(self, msg): | ||
| """Warns the user using warnings.warn. | ||
|
|
@@ -63,8 +70,9 @@ def _warn(self, msg): | |
| is the one printed. | ||
| """ | ||
| if not self._disable_warnings and msg not in self._warned_once: | ||
| warnings.warn( | ||
| colorize(msg, 'yellow'), CsvOutputWarning, stacklevel=3) | ||
| warnings.warn(colorize(msg, 'yellow'), | ||
| CsvOutputWarning, | ||
| stacklevel=3) | ||
| self._warned_once.add(msg) | ||
| return msg | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,17 +5,18 @@ | |
|
|
||
| The logger has 4 major steps: | ||
|
|
||
| 1. Inputs, such as a simple string or something more complicated like | ||
| TabularInput, are passed to the log() method of an instantiated Logger. | ||
| 1. Inputs, such as a simple string or something more complicated like | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please make sure this giant comment renders nicely in Sphinx. Here's what it looks like now: Perhaps you can actually move this content (+example) to the title page of the documentation. If you do that, it will probably render fine (markdown is supported for pages but not docstrings). Anyway, to render the docs just do |
||
| a distribution, are passed to the log() or logkv() method of an | ||
| instantiated Logger. | ||
|
|
||
| 2. The Logger class checks for any outputs that have been added to it, and | ||
| calls the record() method of any outputs that accept the type of input. | ||
| 2. The Logger class checks for any outputs that have been added to it, and | ||
| calls the record() method of any outputs that accept the type of input. | ||
|
|
||
| 3. The output (a subclass of LogOutput) receives the input via its record() | ||
| method and handles it in whatever way is expected. | ||
| 3. The output (a subclass of LogOutput) receives the input via its record() | ||
| method and handles it in whatever way is expected. | ||
|
|
||
| 4. (only in some cases) The dump method is used to dump the output to file. | ||
| It is necessary for some LogOutput subclasses, like TensorBoardOutput. | ||
| 4. (only in some cases) The dump method is used to dump the output to file | ||
| and to log any key-value pairs that have been stored. | ||
|
|
||
|
|
||
| # Here's a demonstration of dowel: | ||
|
|
@@ -61,8 +62,8 @@ | |
|
|
||
| # And another output. | ||
|
|
||
| from dowel import CsvOutput | ||
| logger.add_output(CsvOutput('log_folder/table.csv')) | ||
| from dowel import TensorBoardOutput | ||
| logger.add_output(TensorBoardOutput('log_folder/tensorboard')) | ||
|
|
||
| +---------+ | ||
| +------>StdOutput| | ||
|
|
@@ -72,13 +73,16 @@ | |
| |logger+------>TextOutput| | ||
| +------+ +----------+ | ||
| | | ||
| | +---------+ | ||
| +------>CsvOutput| | ||
| +---------+ | ||
| | +-----------------+ | ||
| +------>TensorBoardOutput| | ||
| +-----------------+ | ||
|
|
||
| # The logger will record anything passed to logger.log to all outputs that | ||
| # accept its type. | ||
|
|
||
|
|
||
| # Now let's try logging a string again. | ||
|
|
||
| logger.log('test') | ||
|
|
||
| +---------+ | ||
|
|
@@ -89,38 +93,36 @@ | |
| |logger+---'test'--->TextOutput| | ||
| +------+ +----------+ | ||
| | | ||
| | +---------+ | ||
| +-----!!----->CsvOutput| | ||
| +---------+ | ||
| | +-----------------+ | ||
| +-----!!----->TensorBoardOutput| | ||
| +-----------------+ | ||
|
|
||
| # !! Note that the logger knows not to send CsvOutput the string 'test' | ||
| # Similarly, more complex objects like tf.tensor won't be sent to (for | ||
| # !! Note that the logger knows not to send 'test' to TensorBoardOutput. | ||
| # Similarly, more complex objects like tf.Graph won't be sent to (for | ||
| # example) TextOutput. | ||
| # This behavior is defined in each output's types_accepted property | ||
|
|
||
| # Here's a more complex example. | ||
| # TabularInput, instantiated for you as the tabular, can log key/value pairs. | ||
| # We can log key-value pairs using logger.logkv | ||
|
|
||
| from dowel import tabular | ||
| tabular.record('key', 72) | ||
| tabular.record('foo', 'bar') | ||
| logger.log(tabular) | ||
| logger.logkv('key', 72) | ||
| logger.logkv('foo', 'bar') | ||
| logger.dump_all() | ||
|
|
||
| +---------+ | ||
| +---tabular--->StdOutput| | ||
| | +---------+ | ||
| +---------+ | ||
| +------>StdOutput| | ||
| | +---------+ | ||
| | | ||
| +------+ +----------+ | ||
| |logger+---tabular--->TextOutput| | ||
| +------+ +----------+ | ||
| +------+ +----------+ | ||
| |logger+------>TextOutput| | ||
| +------+ +----------+ | ||
| | | ||
| | +---------+ | ||
| +---tabular--->CsvOutput| | ||
| +---------+ | ||
| | +---------+ | ||
| +------>CsvOutput| | ||
| +---------+ | ||
|
|
||
| # Note that LogOutputs which consume TabularInputs must call | ||
| # TabularInput.mark() on each key they log. This helps the logger detect when | ||
| # tabular data is not logged. | ||
| # Note that the key-value pairs are saved in each output until we call | ||
| # dump_all(). | ||
|
|
||
| # Console Output: | ||
| --- --- | ||
|
|
@@ -133,29 +135,37 @@ | |
| """ | ||
| import abc | ||
| import contextlib | ||
| import re | ||
| import warnings | ||
|
|
||
| from dowel.utils import colorize | ||
|
|
||
|
|
||
| class LogOutput(abc.ABC): | ||
| """Abstract class for Logger Outputs.""" | ||
| """Abstract class for Logger Outputs. | ||
|
|
||
| @property | ||
| def types_accepted(self): | ||
| """Pass these types to this logger output. | ||
| :param keys_accepted: Regex for which keys this output should accept. | ||
| """ | ||
|
|
||
| The types in this tuple will be accepted by this output. | ||
| def __init__(self, keys_accepted=r'^$'): | ||
| self._keys_accepted = keys_accepted | ||
|
|
||
| :return: A tuple containing all valid input types. | ||
| """ | ||
| @property | ||
| def types_accepted(self): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should probably be an |
||
| """Returns a tuple containing all valid input value types.""" | ||
| return () | ||
|
|
||
| @property | ||
| def keys_accepted(self): | ||
| """Returns a regex string matching keys to be sent to this output.""" | ||
| return self._keys_accepted | ||
|
|
||
| @abc.abstractmethod | ||
| def record(self, data, prefix=''): | ||
| def record(self, key, value, prefix=''): | ||
| """Pass logger data to this output. | ||
|
|
||
| :param data: The data to be logged by the output. | ||
| :param key: The key to be logged by the output. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please use Google-style docstrings. |
||
| :param value: The value to be logged by the output. | ||
| :param prefix: A prefix placed before a log entry in text outputs. | ||
| """ | ||
| pass | ||
|
|
@@ -186,7 +196,7 @@ def __init__(self): | |
| self._warned_once = set() | ||
| self._disable_warnings = False | ||
|
|
||
| def log(self, data): | ||
| def logkv(self, key, value): | ||
| """Magic method that takes in all different types of input. | ||
|
|
||
| This method is the main API for the logger. Any data to be logged goes | ||
|
|
@@ -195,24 +205,30 @@ def log(self, data): | |
| Any data sent to this method is sent to all outputs that accept its | ||
| type (defined in the types_accepted property). | ||
|
|
||
| :param data: Data to be logged. This can be any type specified in the | ||
| :param key: Key to be logged. This must be a string. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please update these docstrings to use Google style |
||
| :param value: Value to be logged. This can be any type specified in the | ||
| types_accepted property of any of the logger outputs. | ||
| """ | ||
| if not self._outputs: | ||
| self._warn('No outputs have been added to the logger.') | ||
|
|
||
| at_least_one_logged = False | ||
| for output in self._outputs: | ||
| if isinstance(data, output.types_accepted): | ||
| output.record(data, prefix=self._prefix_str) | ||
| if isinstance(value, output.types_accepted) and re.match( | ||
| output.keys_accepted, key): | ||
| output.record(key, value, prefix=self._prefix_str) | ||
| at_least_one_logged = True | ||
|
|
||
| if not at_least_one_logged: | ||
| warning = ( | ||
| 'Log data of type {} was not accepted by any output'.format( | ||
| type(data).__name__)) | ||
| type(value).__name__)) | ||
| self._warn(warning) | ||
|
|
||
| def log(self, value): | ||
| """Log just a value without a key.""" | ||
| self.logkv('', value) | ||
|
|
||
| def add_output(self, output): | ||
| """Add a new output to the logger. | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is now a private API, so it should not be in
__all__, which is only for things people should be importing from your package.