forked from mlcommons/algorithmic-efficiency
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpyproject.toml
More file actions
187 lines (173 loc) · 5.66 KB
/
pyproject.toml
File metadata and controls
187 lines (173 loc) · 5.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
###############################################################################
# MLCommons Algorithmic Efficiency. #
###############################################################################
[project]
name = "algoperf"
# The version is not written here; it is supplied by setuptools_scm
# (see [tool.setuptools_scm]).
dynamic = ["version"]
description = "Codebase for the AlgoPerf: Training Algorithms benchmark"
authors = [
  { name = "MLCommons Algorithms Working Group", email = "algorithms@mlcommons.org" },
]
license = { text = "Apache 2.0" }
readme = "README.md"
requires-python = ">=3.11"
keywords = [
  "algoperf",
  "algorithmic-efficiency",
  "machine-learning",
  "deep-learning",
  "optimization",
  "benchmarking",
  "training-methods",
]
classifiers = [
  "Development Status :: 3 - Alpha",
  "Intended Audience :: Developers",
  "Intended Audience :: Science/Research",
  "License :: OSI Approved :: Apache Software License",
  "Operating System :: OS Independent",
  "Programming Language :: Python :: 3.11",
  "Programming Language :: Python :: 3.12",
  # Python 3.13 is deliberately not advertised: the core pin
  # tensorflow==2.19.0 only publishes wheels for CPython 3.9-3.12, so an
  # install on 3.13 cannot resolve. Re-add it once tensorflow is bumped to a
  # release with 3.13 wheels.
  "Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
  "absl-py==2.1.0",
  "networkx==3.2.1",
  "docker==7.1.0",
  "numpy>=2.0.2",
  "pandas>=2.0.1",
  "tensorflow==2.19.0",
  "tensorflow-datasets==4.9.9",
  "tensorflow-probability==0.20.0",
  "gputil==1.4.0",
  "psutil==6.1.0",
  "clu==0.0.12",
  "matplotlib>=3.9.2",
  "tabulate==0.9.0",
  # wandb is also pinned by the `wandb` extra in
  # [project.optional-dependencies]; keep both pins identical, otherwise
  # `pip install algoperf[wandb]` asks for two versions and cannot resolve.
  "wandb==0.21.0",
]
[build-system]
# Lower bounds matter here:
# - setuptools>=61 is the first release that reads PEP 621 `[project]`
#   metadata, and >=64 adds PEP 660 editable installs.
# - setuptools_scm>=8 is required for the `version_file` key used in
#   [tool.setuptools_scm]; older releases only knew `write_to` and reject
#   unknown keys.
# These match the configuration recommended by the setuptools_scm docs.
requires = ["setuptools>=64", "setuptools_scm>=8"]
build-backend = "setuptools.build_meta"
[tool.setuptools]
# Besides the discovered packages, also ship the single-file module
# `submission_runner`.
py-modules = ["submission_runner"]
include-package-data = true
zip-safe = false
# Intentionally empty: automatic package discovery runs with its defaults,
# which include scanning implicit namespace packages.
[tool.setuptools.packages.find]
[tool.setuptools_scm]
# setuptools_scm derives the package version from Git tags and writes it to
# this module.
version_file = "algoperf/_version.py"
###############################################################################
# (Optional) Dependencies #
###############################################################################
[project.optional-dependencies]
# All workloads
full = ["algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]"]
# All workloads plus development dependencies
full_dev = ["algoperf[full,dev]"]
# Dependencies for developing the package
dev = ["ruff==0.12.0", "pytest==8.3.3", "pre-commit==4.0.1"]
# wandb is also a core dependency (see `dependencies` in [project]). This extra
# is kept for backward compatibility and MUST use the same pin as the core
# list; a different pin makes `pip install algoperf[wandb]` unresolvable.
wandb = ["wandb==0.21.0"]
# Workloads
criteo1tb = ["scikit-learn==1.5.2"]
fastmri = ["h5py==3.12.0", "scikit-image==0.24.0"]
ogbg = ["jraph==0.0.6.dev0", "scikit-learn==1.5.2"]
librispeech_conformer = [
  "sentencepiece==0.2.0",
  "tensorflow-text==2.19.0",
  "pydub==0.25.1",
]
wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"]
# Frameworks
jax_core_deps = [
  "flax==0.10.7",
  "optax==0.2.2",
  "chex==0.1.86",
  "ml_dtypes==0.5.1",
  "protobuf==4.25.5",
]
jax_cpu = [
  "jax==0.7.0",
  "algoperf[jax_core_deps]",
]
jax_gpu = [
  "jax[cuda12]==0.7.0",
  "algoperf[jax_core_deps]",
  # Temporary workaround for https://github.com/jax-ml/jax/issues/30663
  "nvidia-cudnn-cu12==9.10.2.21",
]
pytorch_cpu = [
  "torch==2.5.1",
  "torchvision==0.20.1",
]
# No CUDA suffix (e.g. `+cuXXX`) is pinned here on purpose: the CUDA build you
# get depends on the PyTorch wheel index you install from, so pick the index
# that matches your local CUDA setup.
pytorch_gpu = [
  "torch==2.5.1",
  "torchvision==0.20.1",
]
###############################################################################
# Linting & Formatting Configurations #
###############################################################################
[tool.ruff]
line-length = 80
indent-width = 2
# `extend-exclude` (not `exclude`) so Ruff's built-in exclusions (.git, .venv,
# build, ...) stay in effect. `_version.py` is generated by setuptools_scm
# (see [tool.setuptools_scm]) and should not be linted or formatted.
extend-exclude = ["_version.py"]
# Keep in line with `requires-python = ">=3.11"` in [project].
target-version = "py311"

[tool.ruff.format]
quote-style = "single"
[tool.ruff.lint]
# Enabled rule families. The commented-out families are candidates for
# adoption later.
extend-select = [
  "BLE", # flake8-blind-except: disallow catch-all exceptions
  "COM", # flake8-commas: trailing comma rules
  "F", # Pyflakes
  "FA", # flake8-future-annotations: require `from __future__ import annotations` where needed
  "I", # isort: import ordering
  "ICN", # flake8-import-conventions: common import aliases
  "PLE", # Pylint errors
  "TID", # flake8-tidy-imports: import hygiene
  # "A", # flake8-builtins: detect shadowed builtins
  # "B", # flake8-bugbear: likely bugs and design problems
  # "C4", # flake8-comprehensions: catch incorrect use of comprehensions
  # "D", # pydocstyle
  # "DOC", # pydoclint
  # "DTZ", # flake8-datetimez: strict timezone handling with datetime
  # "E", # pycodestyle errors
  # "FBT", # flake8-boolean-trap: detect boolean traps
  # "ISC", # flake8-implicit-str-concat: sensible string concatenation
  # "N", # pep8-naming: naming conventions
  # "NPY", # NumPy-specific rules
  # "PL", # all Pylint rules
  # "PLC", # Pylint conventions
  # "PLR", # Pylint refactors
  # "PLW", # Pylint warnings
  # "PTH", # flake8-use-pathlib: use pathlib instead of os.path
  # "RET", # flake8-return: good return practices
  # "S", # flake8-bandit: security checks
  # "SIM", # flake8-simplify: common simplifications
  # "TC", # flake8-type-checking: move type-only imports into a TYPE_CHECKING block
  # "TD", # flake8-todo: well-formed TODO comments
  # "UP", # pyupgrade: flag code that can be modernized for newer Python versions
  # "W", # pycodestyle warnings
]
ignore = [
  # Rules that conflict with Ruff's formatter
  # (see https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules).
  "COM812",
  "COM819",
  "D206",
  "D300",
  "E111",
  "E114",
  "E117",
  "ISC001",
  "ISC002",
  "Q000",
  "Q001",
  "Q002",
  "Q003",
  "W191",
  # Not formatter conflicts: project opt-outs from individual boolean-trap
  # (FBT001, FBT003) and TODO-link (TD003) checks. They only matter once the
  # FBT / TD families above are enabled.
  "FBT001",
  "FBT003",
  "TD003",
]