Skip to content

Commit 1046dd6

Browse files
authored
Add ability to filter on branches (#2)
* Add ability to filter on branches * Remove get script
1 parent 866a3b3 commit 1046dd6

2 files changed

Lines changed: 16 additions & 17 deletions

File tree

get.sh

Lines changed: 0 additions & 12 deletions
This file was deleted.

memprof_plotter/plotter.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,16 @@ def download_artefact(url: str) -> bytes | None:
6262
else:
6363
print("Artefact does not contain required TSP database")
6464
return None
65-
6665

6766

68-
def get_artefacts(nruns: int, workflow: github.Workflow.Workflow, artefact: str) -> dict[int, bytes]:
67+
def get_artefacts(nruns: int, workflow: github.Workflow.Workflow, artefact: str, filter: list[str]) -> dict[int, bytes]:
6968
irun = 0
7069
runs = {}
7170
for run in workflow.get_runs(status="success"):
72-
k = run.run_number
71+
if filter:
72+
if run.head_branch not in filter:
73+
continue
74+
k = f"{run.run_number} - {run.head_branch}"
7375
for gha in run.get_artifacts():
7476
if gha.name == artefact:
7577
artefact_data = download_artefact(gha.archive_download_url)
@@ -83,7 +85,6 @@ def get_artefacts(nruns: int, workflow: github.Workflow.Workflow, artefact: str)
8385

8486

8587
def main():
86-
8788
if gh_token == "BAD_KEY":
8889
raise KeyError("GH_TOKEN must be set in environment")
8990

@@ -110,6 +111,14 @@ def main():
110111
parser.add_argument(
111112
"-a", "--artefact", required=False, type=str, default="run-log", help="Name of artefact containing memprof data"
112113
)
114+
parser.add_argument(
115+
"-f",
116+
"--filter",
117+
required=False,
118+
type=str,
119+
default="",
120+
help="Comma separated list of branch names to filter runs on",
121+
)
113122

114123
ns = parser.parse_args(sys.argv[1:])
115124

@@ -118,7 +127,9 @@ def main():
118127
gh = github.Github(auth=auth)
119128
repo = gh.get_repo(ns.repo)
120129

121-
runs = get_artefacts(ns.nruns, repo.get_workflow(ns.workflow), ns.artefact)
130+
filter = ns.filter.split(",") if ns.filter else []
131+
132+
runs = get_artefacts(ns.nruns, repo.get_workflow(ns.workflow), ns.artefact, filter)
122133

123134
d_times = defaultdict(dict)
124135
d_rss = defaultdict(dict)

0 commit comments

Comments
 (0)