Skip to content

Commit a39b535

Browse files
author
Dylan Huang
authored
Refactor directory utility functions to always use the user's home folder for .eval_protocol and datasets directories. Add unit tests to verify functionality and directory creation. (#257)
1 parent 1ae8143 commit a39b535

File tree

2 files changed

+101
-22
lines changed

2 files changed

+101
-22
lines changed

eval_protocol/directory_utils.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,13 @@
99

1010
def find_eval_protocol_dir() -> str:
1111
"""
12-
Find the .eval_protocol directory by looking up the directory tree.
12+
Find the .eval_protocol directory in the user's home folder.
1313
1414
Returns:
15-
Path to the .eval_protocol directory
15+
Path to the .eval_protocol directory in the user's home folder
1616
"""
17-
# recursively look up for a .eval_protocol directory
18-
current_dir = os.path.dirname(os.path.abspath(__file__))
19-
while current_dir != "/":
20-
if os.path.exists(os.path.join(current_dir, EVAL_PROTOCOL_DIR)):
21-
log_dir = os.path.join(current_dir, EVAL_PROTOCOL_DIR)
22-
break
23-
current_dir = os.path.dirname(current_dir)
24-
else:
25-
# if not found, recursively look up until a pyproject.toml or requirements.txt is found
26-
current_dir = os.path.dirname(os.path.abspath(__file__))
27-
while current_dir != "/":
28-
if any(os.path.exists(os.path.join(current_dir, f)) for f in PYTHON_FILES):
29-
log_dir = os.path.join(current_dir, EVAL_PROTOCOL_DIR)
30-
break
31-
current_dir = os.path.dirname(current_dir)
32-
else:
33-
# get the PWD that this python process is running in
34-
log_dir = os.path.join(os.getcwd(), EVAL_PROTOCOL_DIR)
17+
# Always use the home folder for .eval_protocol directory
18+
log_dir = os.path.expanduser(os.path.join("~", EVAL_PROTOCOL_DIR))
3519

3620
# create the .eval_protocol directory if it doesn't exist
3721
os.makedirs(log_dir, exist_ok=True)
@@ -41,10 +25,10 @@ def find_eval_protocol_dir() -> str:
4125

4226
def find_eval_protocol_datasets_dir() -> str:
4327
"""
44-
Find the .eval_protocol/datasets directory by looking up the directory tree.
28+
Find the .eval_protocol/datasets directory in the user's home folder.
4529
4630
Returns:
47-
Path to the .eval_protocol/datasets directory
31+
Path to the .eval_protocol/datasets directory in the user's home folder
4832
"""
4933
log_dir = find_eval_protocol_dir()
5034

tests/test_directory_utils.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import os
2+
import tempfile
3+
from unittest.mock import patch
4+
import pytest
5+
6+
from eval_protocol.directory_utils import find_eval_protocol_dir, find_eval_protocol_datasets_dir
7+
8+
9+
class TestDirectoryUtils:
10+
"""Test directory utility functions."""
11+
12+
def test_find_eval_protocol_dir_uses_home_folder(self):
13+
"""Test that find_eval_protocol_dir always maps to home folder."""
14+
with tempfile.TemporaryDirectory() as temp_dir:
15+
with patch.dict(os.environ, {"HOME": temp_dir}):
16+
result = find_eval_protocol_dir()
17+
expected = os.path.expanduser("~/.eval_protocol")
18+
assert result == expected
19+
assert result.endswith(".eval_protocol")
20+
assert os.path.exists(result)
21+
22+
def test_find_eval_protocol_dir_creates_directory(self):
23+
"""Test that find_eval_protocol_dir creates the directory if it doesn't exist."""
24+
with tempfile.TemporaryDirectory() as temp_dir:
25+
with patch.dict(os.environ, {"HOME": temp_dir}):
26+
# Ensure the directory doesn't exist initially
27+
eval_protocol_dir = os.path.join(temp_dir, ".eval_protocol")
28+
if os.path.exists(eval_protocol_dir):
29+
os.rmdir(eval_protocol_dir)
30+
31+
# Call the function
32+
result = find_eval_protocol_dir()
33+
34+
# Verify the directory was created
35+
assert result == eval_protocol_dir
36+
assert os.path.exists(result)
37+
assert os.path.isdir(result)
38+
39+
def test_find_eval_protocol_dir_handles_tilde_expansion(self):
40+
"""Test that find_eval_protocol_dir properly handles tilde expansion."""
41+
with tempfile.TemporaryDirectory() as temp_dir:
42+
with patch.dict(os.environ, {"HOME": temp_dir}):
43+
result = find_eval_protocol_dir()
44+
expected = os.path.expanduser("~/.eval_protocol")
45+
assert result == expected
46+
assert result.startswith(temp_dir)
47+
48+
def test_find_eval_protocol_datasets_dir_uses_home_folder(self):
49+
"""Test that find_eval_protocol_datasets_dir also uses home folder."""
50+
with tempfile.TemporaryDirectory() as temp_dir:
51+
with patch.dict(os.environ, {"HOME": temp_dir}):
52+
result = find_eval_protocol_datasets_dir()
53+
expected = os.path.expanduser("~/.eval_protocol/datasets")
54+
assert result == expected
55+
assert result.endswith(".eval_protocol/datasets")
56+
assert os.path.exists(result)
57+
assert os.path.isdir(result)
58+
59+
def test_find_eval_protocol_datasets_dir_creates_directory(self):
60+
"""Test that find_eval_protocol_datasets_dir creates the datasets directory if it doesn't exist."""
61+
with tempfile.TemporaryDirectory() as temp_dir:
62+
with patch.dict(os.environ, {"HOME": temp_dir}):
63+
# Ensure the directories don't exist initially
64+
eval_protocol_dir = os.path.join(temp_dir, ".eval_protocol")
65+
datasets_dir = os.path.join(eval_protocol_dir, "datasets")
66+
if os.path.exists(datasets_dir):
67+
os.rmdir(datasets_dir)
68+
if os.path.exists(eval_protocol_dir):
69+
os.rmdir(eval_protocol_dir)
70+
71+
# Call the function
72+
result = find_eval_protocol_datasets_dir()
73+
74+
# Verify both directories were created
75+
assert result == datasets_dir
76+
assert os.path.exists(result)
77+
assert os.path.isdir(result)
78+
assert os.path.exists(eval_protocol_dir)
79+
assert os.path.isdir(eval_protocol_dir)
80+
81+
def test_find_eval_protocol_dir_consistency(self):
82+
"""Test that multiple calls to find_eval_protocol_dir return the same path."""
83+
with tempfile.TemporaryDirectory() as temp_dir:
84+
with patch.dict(os.environ, {"HOME": temp_dir}):
85+
result1 = find_eval_protocol_dir()
86+
result2 = find_eval_protocol_dir()
87+
assert result1 == result2
88+
89+
def test_find_eval_protocol_datasets_dir_consistency(self):
90+
"""Test that multiple calls to find_eval_protocol_datasets_dir return the same path."""
91+
with tempfile.TemporaryDirectory() as temp_dir:
92+
with patch.dict(os.environ, {"HOME": temp_dir}):
93+
result1 = find_eval_protocol_datasets_dir()
94+
result2 = find_eval_protocol_datasets_dir()
95+
assert result1 == result2

0 commit comments

Comments
 (0)