Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 202 additions & 57 deletions cadetrdm/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from types import ModuleType
from typing import List, Optional, Any
from urllib.request import urlretrieve
import time
import uuid

from semantic_version import Version

Expand Down Expand Up @@ -393,53 +395,117 @@ def test_for_uncommitted_changes(self):
if self.has_uncomitted_changes:
raise RuntimeError(f"Found uncommitted changes in the repository {self.path}.")

def push(self, remote=None, local_branch=None, remote_branch=None, push_all=True):
def push(
self,
remote: str | None = None,
local_branch: str | None = None,
remote_branch: str | None = None,
push_all: bool = True,
) -> None:
"""
Push local branch to remote.

:param remote:
Name of the remote to push to.
:param local_branch:
Name of the local branch to push.
:param remote_branch:
Name of the remote branch to push to.
:return:
Push local changes to configured remotes.

Behavior:
- Project repo: fetch then push. No implicit pull/merge/rebase is performed.
This avoids failures caused by missing Git pull strategy configuration.
- Output repo (ProjectRepo only, when push_all=True):
1) Push the current output (run) branch first.
2) Update and push the output main branch using a retry strategy to handle
concurrent updates to shared log files (e.g., log.tsv, run_history).

Args:
remote: Optional remote name. If None, all configured remotes are used.
local_branch: Local branch name to push when push_all is False.
Defaults to the currently active project branch.
remote_branch: Remote branch name to push to when push_all is False.
Defaults to local_branch.
push_all: If True, push all branches/refs of the project repo and also push
output repo state. If False, only push local_branch -> remote_branch
for the project repo and do not push the output repo.

Raises:
RuntimeError: If no project repo remote is configured, or if pushing the output
repo is requested but no output remote is configured.
Exception: Re-raises errors from pushing the output run branch.
"""
if local_branch is None:
local_branch = self.active_branch
local_branch = str(self.active_branch)
if remote_branch is None:
remote_branch = local_branch

if remote is None:
if len(self._git_repo.remotes) == 0:
if not self._git_repo.remotes:
raise RuntimeError("No remote has been set for this repository yet.")
remote_list = [str(remote.name) for remote in self._git_repo.remotes]
remote_list = [str(r.name) for r in self._git_repo.remotes]
else:
remote_list = [remote]

if local_branch == self.main_branch or push_all:
if push_all:
self.checkout(self.main_branch)
for remote in remote_list:
remote_interface = self._git_repo.remotes[remote]
try:
remote_interface.pull()
except Exception as e:
print("Pulling from this remote failed with the following error:")
print(e)

for remote in remote_list:
remote_interface = self._git_repo.remotes[remote]
# -------------------
# Project repo push
# -------------------
if push_all:
self.checkout(self.main_branch)

for remote_name in remote_list:
remote_interface = self._git_repo.remotes[remote_name]

# Fetch
try:
remote_interface.fetch()
except Exception as exc: # noqa: BLE001
print("Fetching from this remote failed with the following error:")
print(exc)

# Push
if push_all:
push_results = remote_interface.push(all=True)
else:
push_results = remote_interface.push(refspec=f'{local_branch}:{remote_branch}')
push_results = remote_interface.push(
refspec=f"{local_branch}:{remote_branch}"
)

for push_res in push_results:
print(push_res.summary)

if hasattr(self, "output_repo") and push_all:
self.output_repo.push()
# ------------------------------------------------------------------
# Output repo push (only for ProjectRepo, and only when push_all)
# ------------------------------------------------------------------
if not (hasattr(self, "output_repo") and push_all):
return

if not self.output_repo._git_repo.remotes:
raise RuntimeError("No remote has been set for the output repository yet.")

# Publish the currently checked out output branch
output_branch_name = str(self.output_repo.active_branch)

out_origin = self.output_repo._git_repo.remotes.origin
try:
out_origin.fetch()
except Exception as exc: # noqa: BLE001
print("Fetching output remote failed with the following error:")
print(exc)

# 1) Push the run branch first
try:
out_origin.push(output_branch_name)
except Exception as exc: # noqa: BLE001
print("Pushing output run branch failed with the following error:")
print(exc)
raise

# 2) Update and push output main with retry
self._push_output_main_with_retry(
output_branch_name=output_branch_name,
output_dict={},
options=None,
)

# Return to the run branch after updating main
try:
self.output_repo._git.checkout(output_branch_name)
except Exception: # noqa: BLE001
return

def delete_active_branch_if_branch_is_empty(self):
"""
Expand Down Expand Up @@ -1000,8 +1066,8 @@ def get_new_output_branch_name(self, branch_prefix: str | None = None) -> str:
"""
project_repo_hash = str(self.head.commit)
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

branch_name = f"{timestamp}_{self.active_branch}_{project_repo_hash[:7]}"
suffix = uuid.uuid4().hex[:6]
branch_name = f"{timestamp}_{self.active_branch}_{project_repo_hash[:7]}_{suffix}"

if branch_prefix:
branch_name = f"{branch_prefix}_{branch_name}"
Expand Down Expand Up @@ -1064,32 +1130,64 @@ def output_log_file(self):
def output_log(self):
return self.output_repo.output_log

def update_output_main_logs(
self,
output_dict: dict = None,
options: Options | None = None,
):
def _fetch_output_origin(self) -> None:
"""
Dumps all the metadata information about the project repositories state and
the commit hash and branch name of the ouput repository into the main branch of
the output repository.
:param output_dict:
Dictionary containing key-value pairs to be added to the log.
Fetch from the output repository origin remote.

This is used as the first step in the conflict-safe update of the output
repository main branch. It assumes that an ``origin`` remote exists.
"""
if output_dict is None:
output_dict = {}
origin = self.output_repo._git_repo.remotes.origin
origin.fetch()

output_branch_name = str(self.output_repo.active_branch)
def _reset_output_main_to_origin(self) -> None:
"""
Reset the local output repository main branch to match the remote.

output_repo_hash = str(self.output_repo.head.commit)
output_commit_message = self.output_repo.active_branch.commit.message
output_commit_message = output_commit_message.replace("\n", "; ")
This checks out the output main branch and hard-resets it to
``origin/<main_branch>`` so the subsequent log update is applied on top
of the latest remote state.

Note:
This assumes that ``origin/<main_branch>`` exists. In tests/CI, ensure the
bare output remote HEAD points to main and that main has been pushed once.
"""
self.output_repo._git.checkout(self.output_repo.main_branch)
self.output_repo._git_repo.git.reset(
"--hard",
f"origin/{self.output_repo.main_branch}",
)

def _apply_log_update_on_main(
self,
output_branch_name: str,
output_dict: dict | None,
options: Options | None,
) -> None:
"""
Apply a log update on output main and commit it.

Preconditions:
- ``self.output_repo`` is currently checked out on its main branch.
- The working tree is clean.

Effects:
- Writes ``run_history/<output_branch_name>/metadata.json`` (and options.json).
- Updates ``log.tsv`` by adding an entry for the given output branch.
- Commits the changes on output main.

Args:
output_branch_name: Name of the output branch that contains the results.
output_dict: Optional extra key/value pairs added to the log entry.
options: Optional case options used to compute an options hash and persist
an options JSON file in the run history.
"""
logs_dir = self.output_repo.path / "run_history" / output_branch_name
if not logs_dir.exists():
os.makedirs(logs_dir)
logs_dir.mkdir(parents=True, exist_ok=True)

commit_obj = self.output_repo._git_repo.commit(output_branch_name)
output_repo_hash = commit_obj.hexsha
output_commit_message = commit_obj.message.replace("\n", "; ")

entry = LogEntry(
output_repo_commit_message=output_commit_message,
Expand All @@ -1103,21 +1201,20 @@ def update_output_main_logs(
tags=", ".join(self.tags),
options_hash=options.get_hash() if options else None,
filepath=None,
**output_dict
**(output_dict or {}),
)

with open(logs_dir / "metadata.json", "w", encoding="utf-8") as f:
json.dump(entry.to_dict(), f, indent=2)
with open(logs_dir / "metadata.json", "w", encoding="utf-8") as handle:
json.dump(entry.to_dict(), handle, indent=2)

if options:
options.dump_json_file(logs_dir / "options.json", indent=2)

log = OutputLog(self.output_log_file)
log = OutputLog(self.output_repo.output_log_file_path)
log.entries[output_branch_name] = entry
log.write()

self.dump_package_list(logs_dir)

self._copy_code(logs_dir)

self.output_repo.add(".")
Expand All @@ -1126,8 +1223,57 @@ def update_output_main_logs(
f"log for '{output_commit_message}' of branch '{output_branch_name}'",
)

self.output_repo._git.checkout(output_branch_name)
self._most_recent_branch = output_branch_name
def _push_output_main_with_retry(
self,
output_branch_name: str,
output_dict: dict | None,
options: Options | None,
max_attempts: int = 8,
backoff_seconds: float = 0.25,
) -> None:
"""
Push an updated output main branch with retries on non-fast-forward rejection.

This implements an eventually-consistent update of shared log files on the
output repository main branch when multiple machines push concurrently.

Strategy:
fetch -> reset local main to origin/main -> apply log update -> push main
Retry when the push is rejected due to a concurrent update.

Args:
output_branch_name: Output branch that contains the newly produced results.
output_dict: Optional extra key/value pairs added to the log entry.
options: Optional case options written to run history and hashed into the log.
max_attempts: Maximum number of retry attempts.
backoff_seconds: Base backoff time (seconds). The sleep time scales with
the attempt count.

Raises:
git.exc.GitCommandError: If the update fails for reasons other than a push
rejection, or if the maximum number of attempts is exceeded.
"""
origin = self.output_repo._git_repo.remotes.origin

for attempt in range(1, max_attempts + 1):
try:
self._fetch_output_origin()
self._reset_output_main_to_origin()

if self.output_repo.has_uncomitted_changes:
self.output_repo._reset_hard_to_head(force_entry=True)

self._apply_log_update_on_main(output_branch_name, output_dict, options)

origin.push(self.output_repo.main_branch)
return

except git.exc.GitCommandError as exc:
msg = str(exc).lower()
is_reject = "non-fast-forward" in msg or "rejected" in msg
if not is_reject or attempt == max_attempts:
raise
time.sleep(backoff_seconds * attempt)

def _copy_code(self, target_path):
"""
Expand Down Expand Up @@ -1463,7 +1609,6 @@ def _commit_output_data(
# This has to be using ._git.commit to raise an error if no results have been written.
commit_return = self.output_repo._git.commit("-m", message)
self.copy_data_to_cache()
self.update_output_main_logs(output_dict, options)
main_cach_path = self.path / (self.output_directory + "_cached") / self.output_repo.main_branch
if main_cach_path.exists():
delete_path(main_cach_path)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_options.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import uuid

import numpy as np
import pytest
Expand Down Expand Up @@ -90,7 +91,7 @@ def test_branch_name(clean_repo):
new_branch = clean_repo.get_new_output_branch_name()

escaped_branch = re.escape(active_branch)
pattern = rf"^\d{{4}}-\d{{2}}-\d{{2}}_\d{{2}}-\d{{2}}-\d{{2}}_{escaped_branch}_{hash}$"
pattern = rf"^\d{{4}}-\d{{2}}-\d{{2}}_\d{{2}}-\d{{2}}-\d{{2}}_{escaped_branch}_{hash}_[0-9a-f]{{6}}$"
assert re.match(pattern, new_branch), f"Branch name '{new_branch}' does not match expected format"


Expand All @@ -107,7 +108,7 @@ def test_branch_name_with_prefix(clean_repo):
new_branch = clean_repo.get_new_output_branch_name(options.branch_prefix)

escaped_branch = re.escape(active_branch)
pattern = rf"^Test_Prefix_\d{{4}}-\d{{2}}-\d{{2}}_\d{{2}}-\d{{2}}-\d{{2}}_{escaped_branch}_{hash}$"
pattern = rf"^Test_Prefix_\d{{4}}-\d{{2}}-\d{{2}}_\d{{2}}-\d{{2}}-\d{{2}}_{escaped_branch}_{hash}_[0-9a-f]{{6}}$"
assert re.match(pattern, new_branch), f"Branch name '{new_branch}' does not match expected format"


Expand Down
Loading