From 5c86d04ebc0ae92be4fdd26e690ced48c6d83b13 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 17:01:45 -0800 Subject: [PATCH 1/7] update CI --- .github/workflows/format.yml | 25 -- .github/workflows/lint.yml | 49 +++ .github/workflows/static.yml | 43 +++ .github/workflows/test.yml | 48 +++ .github/workflows/tests.yml | 30 -- .pylintrc | 642 +++++++++++++++++++++++++++++++++++ 6 files changed, 782 insertions(+), 55 deletions(-) delete mode 100644 .github/workflows/format.yml create mode 100644 .github/workflows/lint.yml create mode 100644 .github/workflows/static.yml create mode 100644 .github/workflows/test.yml delete mode 100644 .github/workflows/tests.yml create mode 100644 .pylintrc diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml deleted file mode 100644 index 6dd643b..0000000 --- a/.github/workflows/format.yml +++ /dev/null @@ -1,25 +0,0 @@ -name: flake8 - -on: - push: - branches: - - main - pull_request: - paths: - - "**.py" - -jobs: - flake8-lint: - runs-on: ubuntu-latest - steps: - - name: Check out source repo - uses: actions/checkout@v2 - - - name: Set up Python all python version - uses: actions/setup-python@v2 - with: - python-version: 3.9 - architecture: x64 - - - name: flake8-lint - uses: py-actions/flake8@v2 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..60a6488 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,49 @@ +name: lint + +on: + push: + branches: + - main + pull_request: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Check out source repo + uses: actions/checkout@v4 + with: + token: ${{github.token}} + + - name: Set up Python all python version + uses: actions/setup-python@v4 + with: + python-version: 3.11 + token: ${{github.token}} + architecture: x64 + + - name: Install packages + run: | + python -m venv --upgrade-deps .venv + source .venv/bin/activate + pip install .[all] + + - name: Run black + run: | + source
.venv/bin/activate + python -m black --config pyproject.toml --check --diff . + + - name: Get all Python files + id: list_files + run: | + echo "files=$(git ls-files '*.py' '*.pyi' | xargs)" >> $GITHUB_OUTPUT + + - name: Run Pylint on files + run: | + source .venv/bin/activate + files="${{ steps.list_files.outputs.files }}" + if [ -n "$files" ]; then + pylint --rcfile=.pylintrc $files + else + echo "No Python files found." + fi diff --git a/.github/workflows/static.yml b/.github/workflows/static.yml new file mode 100644 index 0000000..bb22bbd --- /dev/null +++ b/.github/workflows/static.yml @@ -0,0 +1,43 @@ +name: static + +on: + push: + branches: + - main + pull_request: + +jobs: + static: + name: static + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + token: ${{github.token}} + + - name: set up python + uses: actions/setup-python@v4 + with: + python-version: 3.11 + token: ${{github.token}} + + - name: install requirements + run: | + python -m venv --upgrade-deps .venv + source .venv/bin/activate + pip install .[all] + + - name: Get all Python files + id: list_files + run: | + echo "files=$(git ls-files '*.py' '*.pyi' | xargs)" >> $GITHUB_OUTPUT + + - name: analysing code with pyright + run: | + source .venv/bin/activate + files="${{ steps.list_files.outputs.files }}" + if [ -n "$files" ]; then + python -m pyright $files + else + echo "No Python files found." 
+ fi diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..bb34b5c --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,48 @@ +name: test + +on: + push: + branches: + - main + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + token: ${{github.token}} + + - name: Set up Python all python version + uses: actions/setup-python@v4 + with: + python-version: 3.11 + token: ${{github.token}} + architecture: x64 + + - name: Install dependencies + run: | + python -m venv --upgrade-deps .venv + source .venv/bin/activate + pip install .[all] + + - name: Run pytest with coverage + run: | + source .venv/bin/activate + IN_CI=true coverage run -m pytest + - name: generate coverage report + + - name: Generate coverage report + run: | + source .venv/bin/activate + coverage xml -i + coverage html -i + + # TODO: This must be done after we have an RCTN admin create a Codecov token + # - name: upload coverage report to Codecov + # uses: codecov/codecov-action@v4 + # with: + # flags: unittests + # fail_ci_if_error: false + # token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml deleted file mode 100644 index 9c4026b..0000000 --- a/.github/workflows/tests.yml +++ /dev/null @@ -1,30 +0,0 @@ -name: "Testing" - -on: - push: - branches: - - main - pull_request: - -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Set up Python all python version - uses: actions/setup-python@v4 - with: - python-version: 3.11 - architecture: x64 - - - name: Install dependencies - run: | - python -m venv --upgrade-deps .venv - source .venv/bin/activate - pip install .[all] - - - name: Run Test - run: | - source .venv/bin/activate - python -m pytest . 
diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..ba547ff --- /dev/null +++ b/.pylintrc @@ -0,0 +1,642 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. 
+ignore=CVS + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\' represents the directory delimiter on Windows systems, it +# can't be used as an escape character. +ignore-paths=docs,.venv + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins=pylint.extensions.docparams, pylint.extensions.docstyle + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.8 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. 
+suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format= + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. 
For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". + +disable=too-few-public-methods, + wrong-import-order, + raw-checker-failed, + duplicate-code, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + consider-using-from-import + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=new + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the 'python-enchant' package. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. 
+spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. 
The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=cls + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. 
+ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^.*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=120 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). 
+int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException, + builtins.Exception + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). 
+max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. 
+class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. 
If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^(test)?_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + +[tool.pylint.parameter_documentation] +# Whether to accept totally missing parameter documentation +# in the docstring of a function that has parameters. +accept-no-param-doc = false + +# Whether to accept totally missing raises documentation +# in the docstring of a function that raises an exception. 
+# accept-no-raise-doc = true + +# Whether to accept totally missing return documentation in +# the docstring of a function that returns a statement. +accept-no-return-doc = false + +# Whether to accept totally missing yields documentation +# in the docstring of a generator. +accept-no-yields-doc = false + +# If the docstring type cannot be guessed the +# specified docstring type will be used. +default-docstring-type = numpy \ No newline at end of file From 1964492041bca1da243210f4fab3dedc2f04961a Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 17:06:29 -0800 Subject: [PATCH 2/7] exclude tutorial and example folders --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c6dabd8..cb53707 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,13 +77,13 @@ line-length = 120 extend-exclude = "\\.ipynb" [tool.pylint] -exclude = [".venv", ".vscode", "docs"] +exclude = [".venv", ".vscode", "docs", "tutorials", "examples"] [tool.pylance] -exclude = [".venv", ".vscode", "docs"] +exclude = [".venv", ".vscode", "docs", "tutorials", "examples"] [tool.pyright] -exclude = [".venv", ".vscode", "docs"] +exclude = [".venv", ".vscode", "docs", "tutorials", "examples"] [tool.isort] line_length = 120 From 6de270a8e917533aa4958d3062a6d262be81cb55 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 14 Jan 2025 16:28:32 -0800 Subject: [PATCH 3/7] fix typing & import order --- sparsecoding/datasets.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sparsecoding/datasets.py b/sparsecoding/datasets.py index 13c82a8..56fd29b 100644 --- a/sparsecoding/datasets.py +++ b/sparsecoding/datasets.py @@ -1,10 +1,11 @@ -import torch import os + +import torch from scipy.io import loadmat -from sparsecoding.transforms import patchify from torch.utils.data import Dataset from sparsecoding.priors import Prior +from sparsecoding.transforms import patchify class BarsDataset(Dataset): @@ -94,7 +95,7 
@@ def __init__( self, root: str, patch_size: int = 8, - stride: int = None, + stride: int | None = None, ): self.P = patch_size if stride is None: From 8b90428c33e47b10b4aa6a601ac4d47189f878d9 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 14 Jan 2025 16:28:45 -0800 Subject: [PATCH 4/7] add overloaded fns --- sparsecoding/transforms/whiten.py | 68 ++++++++++++++++++++++++------- 1 file changed, 54 insertions(+), 14 deletions(-) diff --git a/sparsecoding/transforms/whiten.py b/sparsecoding/transforms/whiten.py index 4206adc..7ca2bfd 100644 --- a/sparsecoding/transforms/whiten.py +++ b/sparsecoding/transforms/whiten.py @@ -1,9 +1,9 @@ +from typing import Literal, overload + import torch -from typing import Dict def compute_whitening_stats(X: torch.Tensor): - """ Given a tensor of data, compute statistics for whitening transform. @@ -29,15 +29,49 @@ def compute_whitening_stats(X: torch.Tensor): return {"mean": mean, "eigenvalues": eigenvalues, "eigenvectors": eigenvectors, "covariance": Sigma} +@overload +def whiten( + X: torch.Tensor, + algorithm: str = "zca", + stats: dict | None = None, + n_components: float | None = None, + epsilon: float = 0.0, + return_W: Literal[False] = False, +) -> torch.Tensor: ... + + +@overload +def whiten( + X: torch.Tensor, + algorithm: str = "zca", + stats: dict | None = None, + n_components: float | None = None, + epsilon: float = 0.0, + return_W: Literal[True] = True, +) -> tuple[torch.Tensor, torch.Tensor]: ... + + +# The last overload is a fallback in case the caller +# provides a regular bool: +@overload def whiten( X: torch.Tensor, algorithm: str = "zca", - stats: Dict = None, - n_components: float = None, + stats: dict | None = None, + n_components: float | None = None, epsilon: float = 0.0, return_W: bool = False, -) -> torch.Tensor: +) -> torch.Tensor: ... 
+ +def whiten( + X: torch.Tensor, + algorithm: str = "zca", + stats: dict | None = None, + n_components: float | None = None, + epsilon: float = 0.0, + return_W: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ Apply whitening transform to data using pre-computed statistics. @@ -45,7 +79,7 @@ def whiten( ---------- X: Input data of shape [N, D] where N are unique data elements of dimensionality D algorithm: Whitening transform we want to apply, one of ['zca', 'pca', or 'cholesky'] - stats: Dict containing precomputed whitening statistics (mean, eigenvectors, eigenvalues) + stats: dict containing precomputed whitening statistics (mean, eigenvectors, eigenvalues) n_components: Number of principal components to keep. If None, keep all components. If int, keep that many components. If float between 0 and 1, keep components that explain that fraction of variance. @@ -73,18 +107,20 @@ def whiten( if algorithm == "pca" or algorithm == "zca": - scaling = 1.0 / torch.sqrt(stats.get("eigenvalues") + epsilon) + eigenvalues = stats.get("eigenvalues", None) + assert eigenvalues is not None # type narrowing + scaling = 1.0 / torch.sqrt(eigenvalues + epsilon) if n_components is not None: if isinstance(n_components, float): if not 0 < n_components <= 1: raise ValueError("If n_components is float, it must be between 0 and 1") - explained_variance_ratio = stats.get("eigenvalues") / torch.sum(stats.get("eigenvalues")) + explained_variance_ratio = stats.get("eigenvalues") / torch.sum(eigenvalues) cumulative_variance_ratio = torch.cumsum(explained_variance_ratio, dim=0) - n_components = torch.sum(cumulative_variance_ratio <= n_components) + 1 + n_components = float(torch.sum(cumulative_variance_ratio <= n_components)) + 1 elif isinstance(n_components, int): - if not 0 < n_components <= len(stats.get("eigenvalues")): - raise ValueError(f"n_components must be between 1 and {len(stats.get('eigenvalues'))}") + if not 0 < n_components <= len(eigenvalues): + raise 
ValueError(f"n_components must be between 1 and {len(eigenvalues)}") else: raise ValueError("n_components must be int or float") @@ -93,16 +129,20 @@ def whiten( scaling = scaling * mask scaling = torch.diag(scaling) + eigenvectors = stats.get("eigenvectors", None) + assert eigenvectors is not None # type narrowing if algorithm == "pca": # For PCA: project onto eigenvectors and scale - W = scaling @ stats.get("eigenvectors").T + W = scaling @ eigenvectors.T else: # For ZCA: project, scale, and rotate back - W = stats.get("eigenvectors") @ scaling @ stats.get("eigenvectors").T + W = eigenvectors @ scaling @ eigenvectors.T elif algorithm == "cholesky": # Based on Cholesky decomp, related to QR decomp - L = torch.linalg.cholesky(stats.get("covariance")) + covariance = stats.get("covariance", None) + assert covariance is not None # type narrowing + L = torch.linalg.cholesky(covariance) Identity = torch.eye(L.shape[0], device=L.device, dtype=L.dtype) # Solve L @ W = I for W, more stable and quicker than inv(L) W = torch.linalg.solve_triangular(L, Identity, upper=False) From 2740691b77e46c5b4b3da93dd3e52686c1122107 Mon Sep 17 00:00:00 2001 From: Dylan Date: Fri, 7 Feb 2025 14:10:05 -0800 Subject: [PATCH 5/7] fix install command --- .github/workflows/lint.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 60a6488..6e15470 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -24,9 +24,9 @@ jobs: - name: Install packages run: | - python -v venv --upgrade-deps .venv - source .vinv/bin/activate - pip install .[all] + python -m venv --upgrade-deps .venv + source .venv/bin/activate + pip install ".[all]" - name: Run black run: | From 8a9425b6c2701d826ddbe1c0d3113b7b85eb636d Mon Sep 17 00:00:00 2001 From: Dylan Date: Fri, 7 Feb 2025 15:01:35 -0800 Subject: [PATCH 6/7] run black --- sparsecoding/dictionaries.py | 2 +- sparsecoding/inference/ista_test.py | 4 +-- 
sparsecoding/inference/lca_test.py | 4 +-- sparsecoding/inference/lsm_test.py | 4 +-- .../inference/pytorch_optimizer_test.py | 4 +-- sparsecoding/inference/vanilla_test.py | 4 +-- sparsecoding/models.py | 23 +++++++++----- sparsecoding/visualization.py | 27 ++++++++-------- tutorials/vanilla/src/utils.py | 31 ++++++------------- 9 files changed, 49 insertions(+), 54 deletions(-) diff --git a/sparsecoding/dictionaries.py b/sparsecoding/dictionaries.py index 2c934d2..7833c9a 100644 --- a/sparsecoding/dictionaries.py +++ b/sparsecoding/dictionaries.py @@ -8,7 +8,7 @@ def load_dictionary_from_pickle(path): - dictionary_file = open(path, 'rb') + dictionary_file = open(path, "rb") numpy_dictionary = pkl.load(dictionary_file) dictionary_file.close() dictionary = torch.tensor(numpy_dictionary.astype(np.float32)) diff --git a/sparsecoding/inference/ista_test.py b/sparsecoding/inference/ista_test.py index 397c555..a2c8211 100644 --- a/sparsecoding/inference/ista_test.py +++ b/sparsecoding/inference/ista_test.py @@ -15,7 +15,7 @@ def test_shape( """Test that ISTA inference returns expected shapes.""" N_ITER = 10 - for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + for data, dataset in zip(bars_datas_fixture, bars_datasets_fixture): inference_method = inference.ISTA(N_ITER) a = inference_method.infer(data, bars_dictionary_fixture) assert_shape_equal(a, dataset.weights) @@ -32,7 +32,7 @@ def test_inference( ): """Test that ISTA inference recovers the correct weights.""" N_ITER = 5000 - for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + for data, dataset in zip(bars_datas_fixture, bars_datasets_fixture): inference_method = inference.ISTA(n_iter=N_ITER) a = inference_method.infer(data, bars_dictionary_fixture) diff --git a/sparsecoding/inference/lca_test.py b/sparsecoding/inference/lca_test.py index 0834ddd..627cfe4 100644 --- a/sparsecoding/inference/lca_test.py +++ b/sparsecoding/inference/lca_test.py @@ -17,7 +17,7 @@ def 
test_shape( """ N_ITER = 10 - for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + for data, dataset in zip(bars_datas_fixture, bars_datasets_fixture): inference_method = inference.LCA(N_ITER) a = inference_method.infer(data, bars_dictionary_fixture) assert_shape_equal(a, dataset.weights) @@ -40,7 +40,7 @@ def test_inference( THRESHOLD = 0.1 N_ITER = 1000 - for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + for data, dataset in zip(bars_datas_fixture, bars_datasets_fixture): inference_method = inference.LCA( coeff_lr=LR, threshold=THRESHOLD, diff --git a/sparsecoding/inference/lsm_test.py b/sparsecoding/inference/lsm_test.py index 8cd3470..8ebd5c1 100644 --- a/sparsecoding/inference/lsm_test.py +++ b/sparsecoding/inference/lsm_test.py @@ -17,7 +17,7 @@ def test_shape( """ N_ITER = 10 - for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + for data, dataset in zip(bars_datas_fixture, bars_datasets_fixture): inference_method = inference.LSM(N_ITER) a = inference_method.infer(data, bars_dictionary_fixture) assert_shape_equal(a, dataset.weights) @@ -33,7 +33,7 @@ def test_inference( """ N_ITER = 1000 - for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + for data, dataset in zip(bars_datas_fixture, bars_datasets_fixture): inference_method = inference.LSM(n_iter=N_ITER) a = inference_method.infer(data, bars_dictionary_fixture) diff --git a/sparsecoding/inference/pytorch_optimizer_test.py b/sparsecoding/inference/pytorch_optimizer_test.py index 87961af..32a5319 100644 --- a/sparsecoding/inference/pytorch_optimizer_test.py +++ b/sparsecoding/inference/pytorch_optimizer_test.py @@ -48,7 +48,7 @@ def test_shape( """ Test that PyTorchOptimizer inference returns expected shapes. 
""" - for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + for data, dataset in zip(bars_datas_fixture, bars_datasets_fixture): inference_method = inference.PyTorchOptimizer( optimizer_fn, loss_fn, @@ -68,7 +68,7 @@ def test_inference( """ N_ITER = 1000 - for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + for data, dataset in zip(bars_datas_fixture, bars_datasets_fixture): inference_method = inference.PyTorchOptimizer( optimizer_fn, loss_fn, diff --git a/sparsecoding/inference/vanilla_test.py b/sparsecoding/inference/vanilla_test.py index 9c556e5..76278e1 100644 --- a/sparsecoding/inference/vanilla_test.py +++ b/sparsecoding/inference/vanilla_test.py @@ -17,7 +17,7 @@ def test_shape( """ N_ITER = 10 - for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + for data, dataset in zip(bars_datas_fixture, bars_datasets_fixture): inference_method = inference.Vanilla(N_ITER) a = inference_method.infer(data, bars_dictionary_fixture) assert_shape_equal(a, dataset.weights) @@ -38,7 +38,7 @@ def test_inference( LR = 5e-2 N_ITER = 1000 - for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + for data, dataset in zip(bars_datas_fixture, bars_datasets_fixture): inference_method = inference.Vanilla(coeff_lr=LR, n_iter=N_ITER) a = inference_method.infer(data, bars_dictionary_fixture) diff --git a/sparsecoding/models.py b/sparsecoding/models.py index 5f19db1..78d18cb 100644 --- a/sparsecoding/models.py +++ b/sparsecoding/models.py @@ -6,8 +6,16 @@ class SparseCoding(torch.nn.Module): - def __init__(self, inference_method, n_basis, n_features, - sparsity_penalty=0.2, device=None, check_for_dictionary_nan=False, **kwargs): + def __init__( + self, + inference_method, + n_basis, + n_features, + sparsity_penalty=0.2, + device=None, + check_for_dictionary_nan=False, + **kwargs, + ): """Class for learning a sparse code via dictionary learning Parameters @@ -72,8 +80,7 @@ def update_dictionary(self, data, a): 
Already-inferred coefficients """ dictionary_grad = self.compute_grad_dict(data, a) - self.dictionary = torch.add(self.dictionary, - self.dictionary_lr*dictionary_grad) + self.dictionary = torch.add(self.dictionary, self.dictionary_lr * dictionary_grad) if self.check_for_dictionary_nan: self.checknan() @@ -115,7 +122,7 @@ def learn_dictionary(self, dataset, n_epoch, batch_size): self.normalize_dictionary() # compute current loss loss += self.compute_loss(batch, a) - losses.append(loss/len(dataloader)) + losses.append(loss / len(dataloader)) return np.asarray(losses) def compute_loss(self, data, a): @@ -135,10 +142,10 @@ def compute_loss(self, data, a): """ batch_size, _ = data.shape - MSE_loss = torch.square(torch.linalg.vector_norm(data-torch.mm(self.dictionary, a.t()).t(), dim=1)) - sparsity_loss = self.sparsity_penalty*torch.abs(a).sum(dim=1) + MSE_loss = torch.square(torch.linalg.vector_norm(data - torch.mm(self.dictionary, a.t()).t(), dim=1)) + sparsity_loss = self.sparsity_penalty * torch.abs(a).sum(dim=1) total_loss = torch.sum(MSE_loss + sparsity_loss) - return total_loss.item()/batch_size + return total_loss.item() / batch_size def get_numpy_dictionary(self): """Returns dictionary as numpy array diff --git a/sparsecoding/visualization.py b/sparsecoding/visualization.py index 1eb70c9..7fdb3e4 100644 --- a/sparsecoding/visualization.py +++ b/sparsecoding/visualization.py @@ -7,8 +7,9 @@ # TODO: Combine/refactor plot_dictionary and plot_patches; lots of repeated code. # TODO: Add method for visualizing coefficients. # TODO: Add method for visualizing reconstructions and original patches. 
-def plot_dictionary(dictionary, color=False, nrow=30, normalize=True, - scale_each=True, fig=None, ax=None, title="", size=8): +def plot_dictionary( + dictionary, color=False, nrow=30, normalize=True, scale_each=True, fig=None, ax=None, title="", size=8 +): """Plot all elements of dictionary in grid Parameters @@ -45,12 +46,12 @@ def plot_dictionary(dictionary, color=False, nrow=30, normalize=True, if color: nch = 3 - patch_size = int(np.sqrt(n_features//nch)) + patch_size = int(np.sqrt(n_features // nch)) - D_imgs = dictionary.T.reshape([n_basis, patch_size, patch_size, nch]).permute([ - 0, 3, 1, 2]) # swap channel dims for torch - grid_img = torchvision.utils.make_grid( - D_imgs, nrow=nrow, normalize=normalize, scale_each=scale_each).cpu() + D_imgs = dictionary.T.reshape([n_basis, patch_size, patch_size, nch]).permute( + [0, 3, 1, 2] + ) # swap channel dims for torch + grid_img = torchvision.utils.make_grid(D_imgs, nrow=nrow, normalize=normalize, scale_each=scale_each).cpu() if fig is None or ax is None: fig, ax = plt.subplots(1, 1, figsize=(size, size)) @@ -64,8 +65,7 @@ def plot_dictionary(dictionary, color=False, nrow=30, normalize=True, return fig, ax -def plot_patches(patches, color=False, normalize=True, scale_each=True, - fig=None, ax=None, title="", size=8): +def plot_patches(patches, color=False, normalize=True, scale_each=True, fig=None, ax=None, title="", size=8): """ Parameters ---------- @@ -104,11 +104,10 @@ def plot_patches(patches, color=False, normalize=True, scale_each=True, patch_size = int(np.sqrt(patches.size(1))) - D_imgs = patches.reshape( - [batch_size, patch_size, patch_size, nch]).permute([ - 0, 3, 1, 2]) # swap channel dims for torch - grid_img = make_grid( - D_imgs, nrow=nrow, normalize=normalize, scale_each=scale_each).cpu() + D_imgs = patches.reshape([batch_size, patch_size, patch_size, nch]).permute( + [0, 3, 1, 2] + ) # swap channel dims for torch + grid_img = make_grid(D_imgs, nrow=nrow, normalize=normalize, 
scale_each=scale_each).cpu() if fig is None or ax is None: fig, ax = plt.subplots(1, 1, figsize=(size, size)) diff --git a/tutorials/vanilla/src/utils.py b/tutorials/vanilla/src/utils.py index e949e9d..d338c9a 100644 --- a/tutorials/vanilla/src/utils.py +++ b/tutorials/vanilla/src/utils.py @@ -19,8 +19,7 @@ def show_patches(patches, title=""): size = int(np.sqrt(patches.size(1))) batch_size = patches.size(0) img_grid = torch.reshape(patches, (-1, 1, size, size)) - out = make_grid(img_grid, padding=1, nrow=int( - np.sqrt(batch_size)), pad_value=torch.min(patches))[0] + out = make_grid(img_grid, padding=1, nrow=int(np.sqrt(batch_size)), pad_value=torch.min(patches))[0] display(out, bar=False, title=title) @@ -33,12 +32,10 @@ def show_patches_sbs(orig, recon, title="", dpi=200): batch_size = orig.size(0) img_grid = torch.reshape(orig, (-1, 1, size, size)) - orig_out = make_grid(img_grid, padding=1, nrow=int( - np.sqrt(batch_size)), pad_value=torch.min(orig))[0] + orig_out = make_grid(img_grid, padding=1, nrow=int(np.sqrt(batch_size)), pad_value=torch.min(orig))[0] img_grid = torch.reshape(recon, (-1, 1, size, size)) - recon_out = make_grid(img_grid, padding=1, nrow=int( - np.sqrt(batch_size)), pad_value=torch.min(recon))[0] + recon_out = make_grid(img_grid, padding=1, nrow=int(np.sqrt(batch_size)), pad_value=torch.min(recon))[0] display_sbs(orig_out, recon_out, bar=False, title=title, dpi=dpi) @@ -145,8 +142,7 @@ def create_patches(imgs, epochs, batch_size, N, rng): patches : Tensor of size (epochs, batch_size, pixels_per_patch). """ # TODO: use rng here when sample_random_patches supports it. 
- patches = sample_random_patches(int(np.sqrt(N)), batch_size*epochs, - torch.unsqueeze(imgs, 1)) + patches = sample_random_patches(int(np.sqrt(N)), batch_size * epochs, torch.unsqueeze(imgs, 1)) patches = patches.reshape(epochs, batch_size, N) return patches @@ -161,8 +157,7 @@ def load_data(img_path): """ imgs = sio.loadmat(img_path)["IMAGES"] # (512, 512, 10) # normalize to mean 0 var 1 - imgs = (imgs - np.mean(imgs, axis=(0, 1), keepdims=True)) / np.std( - imgs, axis=(0, 1), keepdims=True) + imgs = (imgs - np.mean(imgs, axis=(0, 1), keepdims=True)) / np.std(imgs, axis=(0, 1), keepdims=True) return torch.Tensor(imgs).permute(2, 0, 1) @@ -210,11 +205,7 @@ def plot_coeffs(coeffs, title="patch_coefficients"): """ plt.stem(coeffs, use_line_collection=True) plt.title(title) - plt.tick_params( - axis='x', - which='both', - bottom=False, - labelbottom=False) + plt.tick_params(axis="x", which="both", bottom=False, labelbottom=False) plt.show() @@ -224,11 +215,11 @@ def coeff_grid(coeffs): coeffs is Tensor of shape (batch_size, number_of_bases). 
""" batch_size = coeffs.shape[0] - num_to_plot = batch_size//2 + num_to_plot = batch_size // 2 num_cols = 4 fig = plt.figure() - gs = fig.add_gridspec(num_to_plot//num_cols, num_cols, hspace=0, wspace=0) + gs = fig.add_gridspec(num_to_plot // num_cols, num_cols, hspace=0, wspace=0) subplots = gs.subplots(sharex="col", sharey=True).flatten() for i, subplot in enumerate(subplots[:num_to_plot]): subplot.stem(coeffs[i], use_line_collection=True) @@ -248,10 +239,8 @@ def show_components(phi, a, dpi): phi = phi[:, order] weighted_phi = (phi * a.T).T - weighted_phi = weighted_phi.reshape( - -1, 1, patch_size, patch_size) - components = make_grid(weighted_phi, ncol=int(np.sqrt(a.shape[0])), - padding=1, pad_value=-1)[0] + weighted_phi = weighted_phi.reshape(-1, 1, patch_size, patch_size) + components = make_grid(weighted_phi, ncol=int(np.sqrt(a.shape[0])), padding=1, pad_value=-1)[0] vmax = torch.max(weighted_phi) vmin = torch.min(weighted_phi) From 30fb699c78a4ddc2e332937267dd1ddbcbb24048 Mon Sep 17 00:00:00 2001 From: Dylan Date: Fri, 7 Feb 2025 15:44:55 -0800 Subject: [PATCH 7/7] turn on codecov --- .coveragerc | 15 +++++++++++++++ .github/workflows/test.yml | 13 ++++++------- 2 files changed, 21 insertions(+), 7 deletions(-) create mode 100644 .coveragerc diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..a76d6d2 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,15 @@ +[run] +omit = + */examples/*.* + */docs/*.* + */tutorials/*.* + +[report] +omit = + *test* + **/__init__.py +exclude_lines = + if typing: + if TYPE_CHECKING: + +ignore_errors = True \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bb34b5c..b4055ad 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,10 +39,9 @@ jobs: coverage xml -i coverage html -i - # TODO: This must be done after we have an RCTN admin create a Codecov token - # - name: upload coverage report to Codecov - # uses: codecov/codecov-action@v4 
- # with: - # flags: unittests - # fail_ci_if_error: false - # token: ${{ secrets.CODECOV_TOKEN }} + - name: upload coverage report to Codecov + uses: codecov/codecov-action@v4 + with: + flags: unittests + fail_ci_if_error: false + token: ${{ secrets.CODECOV_TOKEN }}