diff --git a/eval_protocol/directory_utils.py b/eval_protocol/directory_utils.py index 74f691b9..a83aac1d 100644 --- a/eval_protocol/directory_utils.py +++ b/eval_protocol/directory_utils.py @@ -9,29 +9,13 @@ def find_eval_protocol_dir() -> str: """ - Find the .eval_protocol directory by looking up the directory tree. + Find the .eval_protocol directory in the user's home folder. Returns: - Path to the .eval_protocol directory + Path to the .eval_protocol directory in the user's home folder """ - # recursively look up for a .eval_protocol directory - current_dir = os.path.dirname(os.path.abspath(__file__)) - while current_dir != "/": - if os.path.exists(os.path.join(current_dir, EVAL_PROTOCOL_DIR)): - log_dir = os.path.join(current_dir, EVAL_PROTOCOL_DIR) - break - current_dir = os.path.dirname(current_dir) - else: - # if not found, recursively look up until a pyproject.toml or requirements.txt is found - current_dir = os.path.dirname(os.path.abspath(__file__)) - while current_dir != "/": - if any(os.path.exists(os.path.join(current_dir, f)) for f in PYTHON_FILES): - log_dir = os.path.join(current_dir, EVAL_PROTOCOL_DIR) - break - current_dir = os.path.dirname(current_dir) - else: - # get the PWD that this python process is running in - log_dir = os.path.join(os.getcwd(), EVAL_PROTOCOL_DIR) + # Always use the home folder for .eval_protocol directory + log_dir = os.path.expanduser(os.path.join("~", EVAL_PROTOCOL_DIR)) # create the .eval_protocol directory if it doesn't exist os.makedirs(log_dir, exist_ok=True) @@ -41,10 +25,10 @@ def find_eval_protocol_dir() -> str: def find_eval_protocol_datasets_dir() -> str: """ - Find the .eval_protocol/datasets directory by looking up the directory tree. + Find the .eval_protocol/datasets directory in the user's home folder. Returns: - Path to the .eval_protocol/datasets directory + Path to the .eval_protocol/datasets directory in the user's home folder """ log_dir = find_eval_protocol_dir() diff --git a/tests/test_directory_utils.py b/tests/test_directory_utils.py new file mode 100644 index 00000000..dcae2cdc --- /dev/null +++ b/tests/test_directory_utils.py @@ -0,0 +1,95 @@ +import os +import tempfile +from unittest.mock import patch +import pytest + +from eval_protocol.directory_utils import find_eval_protocol_dir, find_eval_protocol_datasets_dir + + +class TestDirectoryUtils: + """Test directory utility functions.""" + + def test_find_eval_protocol_dir_uses_home_folder(self): + """Test that find_eval_protocol_dir always maps to home folder.""" + with tempfile.TemporaryDirectory() as temp_dir: + with patch.dict(os.environ, {"HOME": temp_dir}): + result = find_eval_protocol_dir() + expected = os.path.expanduser("~/.eval_protocol") + assert result == expected + assert result.endswith(".eval_protocol") + assert os.path.exists(result) + + def test_find_eval_protocol_dir_creates_directory(self): + """Test that find_eval_protocol_dir creates the directory if it doesn't exist.""" + with tempfile.TemporaryDirectory() as temp_dir: + with patch.dict(os.environ, {"HOME": temp_dir}): + # Ensure the directory doesn't exist initially + eval_protocol_dir = os.path.join(temp_dir, ".eval_protocol") + if os.path.exists(eval_protocol_dir): + os.rmdir(eval_protocol_dir) + + # Call the function + result = find_eval_protocol_dir() + + # Verify the directory was created + assert result == eval_protocol_dir + assert os.path.exists(result) + assert os.path.isdir(result) + + def test_find_eval_protocol_dir_handles_tilde_expansion(self): + """Test that find_eval_protocol_dir properly handles tilde expansion.""" + with tempfile.TemporaryDirectory() as temp_dir: + with patch.dict(os.environ, {"HOME": temp_dir}): + result = find_eval_protocol_dir() + expected = os.path.expanduser("~/.eval_protocol") + assert result == expected + assert result.startswith(temp_dir) + + def test_find_eval_protocol_datasets_dir_uses_home_folder(self): + """Test that find_eval_protocol_datasets_dir also uses home folder.""" + with tempfile.TemporaryDirectory() as temp_dir: + with patch.dict(os.environ, {"HOME": temp_dir}): + result = find_eval_protocol_datasets_dir() + expected = os.path.expanduser("~/.eval_protocol/datasets") + assert result == expected + assert result.endswith(".eval_protocol/datasets") + assert os.path.exists(result) + assert os.path.isdir(result) + + def test_find_eval_protocol_datasets_dir_creates_directory(self): + """Test that find_eval_protocol_datasets_dir creates the datasets directory if it doesn't exist.""" + with tempfile.TemporaryDirectory() as temp_dir: + with patch.dict(os.environ, {"HOME": temp_dir}): + # Ensure the directories don't exist initially + eval_protocol_dir = os.path.join(temp_dir, ".eval_protocol") + datasets_dir = os.path.join(eval_protocol_dir, "datasets") + if os.path.exists(datasets_dir): + os.rmdir(datasets_dir) + if os.path.exists(eval_protocol_dir): + os.rmdir(eval_protocol_dir) + + # Call the function + result = find_eval_protocol_datasets_dir() + + # Verify both directories were created + assert result == datasets_dir + assert os.path.exists(result) + assert os.path.isdir(result) + assert os.path.exists(eval_protocol_dir) + assert os.path.isdir(eval_protocol_dir) + + def test_find_eval_protocol_dir_consistency(self): + """Test that multiple calls to find_eval_protocol_dir return the same path.""" + with tempfile.TemporaryDirectory() as temp_dir: + with patch.dict(os.environ, {"HOME": temp_dir}): + result1 = find_eval_protocol_dir() + result2 = find_eval_protocol_dir() + assert result1 == result2 + + def test_find_eval_protocol_datasets_dir_consistency(self): + """Test that multiple calls to find_eval_protocol_datasets_dir return the same path.""" + with tempfile.TemporaryDirectory() as temp_dir: + with patch.dict(os.environ, {"HOME": temp_dir}): + result1 = find_eval_protocol_datasets_dir() + result2 = find_eval_protocol_datasets_dir() + assert result1 == result2