Skip to content

Commit 2f4e892

Browse files
authored
proper account detection (#322)
1 parent 69e53a7 commit 2f4e892

File tree

4 files changed

+96
-24
lines changed

4 files changed

+96
-24
lines changed

eval_protocol/auth.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,56 @@ def _get_credential_from_config_file(key_name: str) -> Optional[str]:
136136
return None
137137

138138

139+
def _get_credentials_from_config_file() -> Dict[str, Optional[str]]:
140+
"""
141+
Retrieve both api_key and account_id from auth.ini with a single read/parse.
142+
Tries simple parsing first for both keys, then falls back to configparser for any missing ones.
143+
Returns a dict with up to two keys: 'api_key' and 'account_id'.
144+
"""
145+
results: Dict[str, Optional[str]] = {}
146+
auth_ini_path = _get_auth_ini_file()
147+
if not auth_ini_path.exists():
148+
return results
149+
150+
# 1) Simple key=value parsing
151+
try:
152+
simple_creds = _parse_simple_auth_file(auth_ini_path)
153+
if "api_key" in simple_creds and simple_creds["api_key"]:
154+
results["api_key"] = simple_creds["api_key"]
155+
if "account_id" in simple_creds and simple_creds["account_id"]:
156+
results["account_id"] = simple_creds["account_id"]
157+
if "api_key" in results and "account_id" in results:
158+
return results
159+
except Exception as e:
160+
logger.warning("Error during simple parsing of %s: %s", str(auth_ini_path), e)
161+
162+
# 2) ConfigParser for any missing keys
163+
try:
164+
config = configparser.ConfigParser()
165+
config.read(auth_ini_path)
166+
for key_name in ("api_key", "account_id"):
167+
if key_name in results and results[key_name]:
168+
continue
169+
if "fireworks" in config and config.has_option("fireworks", key_name):
170+
value_from_file = config.get("fireworks", key_name)
171+
if value_from_file:
172+
results[key_name] = value_from_file
173+
continue
174+
if config.has_option(config.default_section, key_name):
175+
value_from_default = config.get(config.default_section, key_name)
176+
if value_from_default:
177+
results[key_name] = value_from_default
178+
except configparser.MissingSectionHeaderError:
179+
# Purely key=value file without section headers; simple parsing should have handled it already.
180+
logger.debug("%s has no section headers; falling back to simple parsing results.", str(auth_ini_path))
181+
except configparser.Error as e_config:
182+
logger.warning("Configparser error reading %s: %s", str(auth_ini_path), e_config)
183+
except Exception as e_general:
184+
logger.warning("Unexpected error reading %s: %s", str(auth_ini_path), e_general)
185+
186+
return results
187+
188+
139189
def get_fireworks_api_key() -> Optional[str]:
140190
"""
141191
Retrieves the Fireworks API key.
@@ -177,13 +227,15 @@ def get_fireworks_account_id() -> Optional[str]:
177227
The Account ID is sourced in the following order:
178228
1. FIREWORKS_ACCOUNT_ID environment variable.
179229
2. 'account_id' from the [fireworks] section of ~/.fireworks/auth.ini.
230+
3. If an API key is available (env or auth.ini), resolve via verifyApiKey.
180231
181232
Returns:
182233
The Account ID if found, otherwise None.
183234
"""
184235
# If a profile is active, prefer profile file first, then env
185236
if _is_profile_active():
186-
account_id_from_file = _get_credential_from_config_file("account_id")
237+
creds = _get_credentials_from_config_file()
238+
account_id_from_file = creds.get("account_id")
187239
if account_id_from_file:
188240
return account_id_from_file
189241
account_id = os.environ.get("FIREWORKS_ACCOUNT_ID")
@@ -196,11 +248,24 @@ def get_fireworks_account_id() -> Optional[str]:
196248
if account_id:
197249
logger.debug("Using FIREWORKS_ACCOUNT_ID from environment variable.")
198250
return account_id
199-
account_id_from_file = _get_credential_from_config_file("account_id")
251+
creds = _get_credentials_from_config_file()
252+
account_id_from_file = creds.get("account_id")
200253
if account_id_from_file:
201254
return account_id_from_file
202255

203-
logger.debug("Fireworks Account ID not found in environment variables or auth.ini.")
256+
# 3) Fallback: if API key is present, attempt to resolve via verifyApiKey (env or auth.ini)
257+
try:
258+
# Intentionally use get_fireworks_api_key to centralize precedence (env vs file)
259+
api_key_for_verify = get_fireworks_api_key()
260+
if api_key_for_verify:
261+
resolved = verify_api_key_and_get_account_id(api_key=api_key_for_verify, api_base=get_fireworks_api_base())
262+
if resolved:
263+
logger.debug("Using FIREWORKS_ACCOUNT_ID resolved via verifyApiKey: %s", resolved)
264+
return resolved
265+
except Exception as e:
266+
logger.debug("Failed to resolve FIREWORKS_ACCOUNT_ID via verifyApiKey: %s", e)
267+
268+
logger.debug("Fireworks Account ID not found in environment variables, auth.ini, or via verifyApiKey.")
204269
return None
205270

206271

eval_protocol/evaluation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,9 @@ def _ensure_requirements_present(source_dir: str) -> None:
595595
logger.error("Missing requirements.txt in upload directory: %s", source_dir)
596596
raise ValueError(
597597
"Upload requires requirements.txt in the project root. "
598-
"Please add requirements.txt and re-run ep upload."
598+
"Create a requirements.txt (it can be empty) and rerun 'eval-protocol upload' "
599+
"or 'eval-protocol create rft'. If you're running in a notebook (e.g., Colab), "
600+
f"create the file in your working directory (e.g., {source_dir}/requirements.txt)."
599601
)
600602

601603
@staticmethod

eval_protocol/pytest/handle_persist_flow.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111
from eval_protocol.directory_utils import find_eval_protocol_dir
1212
from eval_protocol.models import EvaluationRow
1313
from eval_protocol.pytest.store_experiment_link import store_experiment_link
14+
from eval_protocol.auth import (
15+
get_fireworks_api_key,
16+
get_fireworks_account_id,
17+
verify_api_key_and_get_account_id,
18+
get_fireworks_api_base,
19+
)
1420

1521
import requests
1622

@@ -90,22 +96,16 @@ def handle_persist_flow(all_results: list[list[EvaluationRow]], test_func_name:
9096
if not should_upload:
9197
continue
9298

93-
def get_auth_value(key: str) -> str | None:
94-
"""Get auth value from config file or environment."""
99+
# Resolve credentials using centralized auth helpers with verification fallback
100+
fireworks_api_key = get_fireworks_api_key()
101+
fireworks_account_id = get_fireworks_account_id()
102+
if not fireworks_account_id and fireworks_api_key:
95103
try:
96-
config_path = Path.home() / ".fireworks" / "auth.ini"
97-
if config_path.exists():
98-
config = configparser.ConfigParser() # noqa: F821
99-
config.read(config_path)
100-
for section in ["DEFAULT", "auth"]:
101-
if config.has_section(section) and config.has_option(section, key):
102-
return config.get(section, key)
104+
fireworks_account_id = verify_api_key_and_get_account_id(
105+
api_key=fireworks_api_key, api_base=get_fireworks_api_base()
106+
)
103107
except Exception:
104-
pass
105-
return os.getenv(key)
106-
107-
fireworks_api_key = get_auth_value("FIREWORKS_API_KEY")
108-
fireworks_account_id = get_auth_value("FIREWORKS_ACCOUNT_ID")
108+
fireworks_account_id = None
109109

110110
if not fireworks_api_key and not fireworks_account_id:
111111
store_experiment_link(
@@ -129,7 +129,7 @@ def get_auth_value(key: str) -> str | None:
129129
)
130130
continue
131131

132-
api_base = "https://api.fireworks.ai"
132+
api_base = get_fireworks_api_base()
133133
headers = {
134134
"Authorization": f"Bearer {fireworks_api_key}",
135135
"Content-Type": "application/json",

tests/test_auth.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,8 @@ def test_get_account_id_not_found(mock_path_exists):
255255
with patch("eval_protocol.auth._parse_simple_auth_file", return_value={}) as mock_parse_simple:
256256
assert get_fireworks_account_id() is None
257257
mock_parse_simple.assert_not_called()
258-
mock_path_exists.assert_called_once_with()
258+
# With verify fallback using get_fireworks_api_key, exists() may be checked more than once
259+
assert mock_path_exists.call_count >= 1
259260

260261

261262
@patch("pathlib.Path.exists", return_value=True)
@@ -269,7 +270,8 @@ def test_get_account_id_ini_exists_no_section(mock_parse_simple, mock_ConfigPars
269270
mock_open(read_data="other_key = some_val_but_no_section_header\nanother=val"),
270271
):
271272
assert get_fireworks_account_id() is None
272-
mock_parse_simple.assert_called_once_with(AUTH_INI_FILE)
273+
# Fallback verify path may trigger a second simple parse for api_key; ensure at least one call
274+
assert mock_parse_simple.call_count >= 1
273275

274276

275277
@patch("pathlib.Path.exists", return_value=True)
@@ -283,7 +285,8 @@ def test_get_account_id_ini_exists_no_id_option(mock_parse_simple, mock_ConfigPa
283285

284286
with patch("builtins.open", mock_open(read_data="[fireworks]\nsome_other_key=foo")):
285287
assert get_fireworks_account_id() is None
286-
mock_parse_simple.assert_called_once_with(AUTH_INI_FILE)
288+
# Fallback verify path may trigger a second simple parse for api_key; ensure at least one call
289+
assert mock_parse_simple.call_count >= 1
287290

288291

289292
@patch("pathlib.Path.exists", return_value=True)
@@ -301,7 +304,8 @@ def test_get_account_id_ini_empty_value(mock_parse_simple, mock_ConfigParser_cla
301304
)
302305
with patch("builtins.open", mock_open(read_data="[fireworks]\naccount_id=")):
303306
assert get_fireworks_account_id() is None
304-
mock_parse_simple.assert_called_once_with(AUTH_INI_FILE)
307+
# Fallback verify path may trigger a second simple parse for api_key; ensure at least one call
308+
assert mock_parse_simple.call_count >= 1
305309

306310

307311
@patch("pathlib.Path.exists", return_value=True)
@@ -372,7 +376,8 @@ def test_get_account_id_ini_parse_error(mock_parse_simple, mock_ConfigParser_cla
372376
assert get_fireworks_account_id() is None
373377
assert "Configparser error reading" in caplog.text
374378
assert "Mocked Parsing Error" in caplog.text
375-
mock_parse_simple.assert_called_once_with(AUTH_INI_FILE)
379+
# Fallback verify path may trigger a second simple parse for api_key; ensure at least one call
380+
assert mock_parse_simple.call_count >= 1
376381

377382

378383
@patch("pathlib.Path.exists", return_value=True)

0 commit comments

Comments
 (0)