Skip to content

Commit 0e17e2b

Browse files
authored
Merge pull request #195 from tcdent/more-coverage
Get to 100% coverage on inputs.py, agents.py, tasks.py, proj_templates.py
2 parents 12732c0 + ef1709c commit 0e17e2b

6 files changed

Lines changed: 392 additions & 33 deletions

File tree

agentstack/cli/cli.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,7 @@ def init_project_builder(
5555

5656
template_data = None
5757
if template is not None:
58-
if template.startswith("https://"):
59-
try:
60-
template_data = TemplateConfig.from_url(template)
61-
except Exception as e:
62-
raise Exception(f"Failed to fetch template data from {template}.\n{e}")
63-
else:
64-
try:
65-
template_data = TemplateConfig.from_template_name(template)
66-
except Exception as e:
67-
raise Exception(f"Failed to load template {template}.\n{e}")
58+
template_data = TemplateConfig.from_user_input(template)
6859

6960
if template_data:
7061
project_details = {
@@ -115,7 +106,6 @@ def init_project_builder(
115106

116107

117108
def welcome_message():
118-
#os.system("cls" if os.name == "nt" else "clear")
119109
title = text2art("AgentStack", font="smisome1")
120110
tagline = "The easiest way to build a robust agent application!"
121111
border = "-" * len(tagline)
@@ -400,7 +390,7 @@ def insert_template(
400390
f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env.example',
401391
f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env',
402392
)
403-
393+
404394
cookiecutter(str(template_path), no_input=True, extra_context=None)
405395

406396
# TODO: inits a git repo in the directory the command was run in

agentstack/proj_templates.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ def to_v3(self) -> 'TemplateConfig':
6969
framework=self.framework,
7070
method=self.method,
7171
manager_agent=None,
72-
agents=[TemplateConfig.Agent(**agent.dict()) for agent in self.agents],
73-
tasks=[TemplateConfig.Task(**task.dict()) for task in self.tasks],
74-
tools=[TemplateConfig.Tool(**tool.dict()) for tool in self.tools],
72+
agents=[TemplateConfig.Agent(**agent.model_dump()) for agent in self.agents],
73+
tasks=[TemplateConfig.Task(**task.model_dump()) for task in self.tasks],
74+
tools=[TemplateConfig.Tool(**tool.model_dump()) for tool in self.tools],
7575
inputs=self.inputs,
7676
)
7777

@@ -144,17 +144,22 @@ def write_to_file(self, filename: Path):
144144
f.write(json.dumps(model_dump, indent=4))
145145

146146
@classmethod
147-
def from_template_name(cls, name: str) -> 'TemplateConfig':
148-
# if url
149-
if name.startswith('https://'):
150-
return cls.from_url(name)
151-
152-
# if .json file
153-
if name.endswith('.json'):
154-
path = os.getcwd() / Path(name)
147+
def from_user_input(cls, identifier: str):
148+
"""
149+
Load a template from a user-provided identifier.
150+
Three cases will be tried: A URL, a file path, or a template name.
151+
"""
152+
if identifier.startswith('https://'):
153+
return cls.from_url(identifier)
154+
155+
if identifier.endswith('.json'):
156+
path = Path() / identifier
155157
return cls.from_file(path)
156158

157-
# if named template
159+
return cls.from_template_name(identifier)
160+
161+
@classmethod
162+
def from_template_name(cls, name: str) -> 'TemplateConfig':
158163
path = get_package_path() / f'templates/proj_templates/{name}.json'
159164
if not name in get_all_template_names():
160165
raise ValidationError(f"Template {name} not bundled with agentstack.")
@@ -164,8 +169,11 @@ def from_template_name(cls, name: str) -> 'TemplateConfig':
164169
def from_file(cls, path: Path) -> 'TemplateConfig':
165170
if not os.path.exists(path):
166171
raise ValidationError(f"Template {path} not found.")
167-
with open(path, 'r') as f:
168-
return cls.from_json(json.load(f))
172+
try:
173+
with open(path, 'r') as f:
174+
return cls.from_json(json.load(f))
175+
except json.JSONDecodeError as e:
176+
raise ValidationError(f"Error decoding template JSON.\n{e}")
169177

170178
@classmethod
171179
def from_url(cls, url: str) -> 'TemplateConfig':
@@ -174,7 +182,10 @@ def from_url(cls, url: str) -> 'TemplateConfig':
174182
response = requests.get(url)
175183
if response.status_code != 200:
176184
raise ValidationError(f"Failed to fetch template from {url}")
177-
return cls.from_json(response.json())
185+
try:
186+
return cls.from_json(response.json())
187+
except json.JSONDecodeError as e:
188+
raise ValidationError(f"Error decoding template JSON.\n{e}")
178189

179190
@classmethod
180191
def from_json(cls, data: dict) -> 'TemplateConfig':
@@ -193,8 +204,6 @@ def from_json(cls, data: dict) -> 'TemplateConfig':
193204
for error in e.errors():
194205
err_msg += f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}\n"
195206
raise ValidationError(err_msg)
196-
except json.JSONDecodeError as e:
197-
raise ValidationError(f"Error decoding template JSON.\n{e}")
198207

199208

200209
def get_all_template_paths() -> list[Path]:

tests/test_agents_config.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import importlib.resources
66
from pathlib import Path
77
from agentstack import conf
8-
from agentstack.agents import AgentConfig, AGENTS_FILENAME
8+
from agentstack.agents import AgentConfig, AGENTS_FILENAME, get_all_agent_names, get_all_agents
9+
from agentstack.exceptions import ValidationError
910

1011
BASE_PATH = Path(__file__).parent
1112

@@ -83,3 +84,58 @@ def test_write_none_values(self):
8384
llm:
8485
"""
8586
)
87+
88+
def test_yaml_error(self):
89+
# Create an invalid YAML file
90+
with open(self.project_dir / AGENTS_FILENAME, 'w') as f:
91+
f.write("""
92+
agent_name:
93+
role: "This is a valid line"
94+
invalid_yaml: "This line is missing a colon"
95+
nested_key: "This will cause a YAML error"
96+
""")
97+
98+
# Attempt to load the config, which should raise a ValidationError
99+
with self.assertRaises(ValidationError) as context:
100+
AgentConfig("agent_name")
101+
102+
def test_pydantic_validation_error(self):
103+
# Create a YAML file with an invalid field type
104+
with open(self.project_dir / AGENTS_FILENAME, 'w') as f:
105+
f.write("""
106+
agent_name:
107+
role: "This is a valid role"
108+
goal: "This is a valid goal"
109+
backstory: "This is a valid backstory"
110+
llm: 123 # This should be a string, not an integer
111+
""")
112+
113+
# Attempt to load the config, which should raise a ValidationError
114+
with self.assertRaises(ValidationError) as context:
115+
AgentConfig("agent_name")
116+
117+
def test_get_all_agent_names(self):
118+
shutil.copy(BASE_PATH / "fixtures/agents_max.yaml", self.project_dir / AGENTS_FILENAME)
119+
120+
agent_names = get_all_agent_names()
121+
self.assertEqual(set(agent_names), {"agent_name", "second_agent_name"})
122+
self.assertEqual(agent_names, ["agent_name", "second_agent_name"])
123+
124+
def test_get_all_agent_names_missing_file(self):
125+
if os.path.exists(self.project_dir / AGENTS_FILENAME):
126+
os.remove(self.project_dir / AGENTS_FILENAME)
127+
non_existent_file_agent_names = get_all_agent_names()
128+
self.assertEqual(non_existent_file_agent_names, [])
129+
130+
def test_get_all_agent_names_empty_file(self):
131+
with open(self.project_dir / AGENTS_FILENAME, 'w') as f:
132+
f.write("")
133+
134+
empty_agent_names = get_all_agent_names()
135+
self.assertEqual(empty_agent_names, [])
136+
137+
def test_get_all_agents(self):
138+
shutil.copy(BASE_PATH / "fixtures/agents_max.yaml", self.project_dir / AGENTS_FILENAME)
139+
140+
for agent in get_all_agents():
141+
self.assertIsInstance(agent, AgentConfig)

tests/test_inputs_config.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import unittest
44
from pathlib import Path
55
from agentstack import conf
6-
from agentstack.inputs import InputsConfig
6+
from agentstack.inputs import InputsConfig, get_inputs, add_input_for_run
7+
from agentstack.exceptions import ValidationError
78

89
BASE_PATH = Path(__file__).parent
910

@@ -30,3 +31,62 @@ def test_maximal_input_config(self):
3031
assert config['input_name'] == "This in an input"
3132
assert config['input_name_2'] == "This is another input"
3233
assert config.to_dict() == {'input_name': "This in an input", 'input_name_2': "This is another input"}
34+
35+
def test_yaml_error(self):
36+
# Create an invalid YAML file
37+
with open(self.project_dir / "src/config/inputs.yaml", 'w') as f:
38+
f.write("""
39+
input_name: "This is a valid line"
40+
invalid_yaml: "This line is missing a colon"
41+
nested_key: "This will cause a YAML error"
42+
""")
43+
44+
# Attempt to load the config, which should raise a ValidationError
45+
with self.assertRaises(ValidationError) as context:
46+
InputsConfig()
47+
48+
def test_create_inputs_file_if_not_exists(self):
49+
# Ensure the inputs file doesn't exist
50+
inputs_file = self.project_dir / "src/config/inputs.yaml"
51+
if inputs_file.exists():
52+
inputs_file.unlink()
53+
54+
# Create an InputsConfig instance and set a value
55+
with InputsConfig() as config:
56+
config['test_key'] = 'test_value'
57+
58+
# Check that the file was created
59+
self.assertTrue(inputs_file.exists())
60+
61+
def test_inputs_config_contains(self):
62+
# Create an InputsConfig instance and set some values
63+
with InputsConfig() as config:
64+
config['existing_key'] = 'some_value'
65+
config['another_key'] = 'another_value'
66+
67+
# Test the __contains__ method
68+
self.assertTrue('existing_key' in config)
69+
self.assertTrue('another_key' in config)
70+
self.assertFalse('non_existing_key' in config)
71+
72+
def test_get_inputs(self):
73+
# Set up some initial inputs
74+
with InputsConfig() as config:
75+
config['saved_key'] = 'saved_value'
76+
77+
# Test get_inputs without run inputs
78+
inputs = get_inputs()
79+
self.assertEqual(inputs['saved_key'], 'saved_value')
80+
81+
# Add a run input
82+
add_input_for_run('run_key', 'run_value')
83+
84+
# Test get_inputs with run inputs
85+
inputs = get_inputs()
86+
self.assertEqual(inputs['saved_key'], 'saved_value')
87+
self.assertEqual(inputs['run_key'], 'run_value')
88+
89+
# Test that run inputs override saved inputs
90+
add_input_for_run('saved_key', 'overridden_value')
91+
inputs = get_inputs()
92+
self.assertEqual(inputs['saved_key'], 'overridden_value')

tests/test_tasks_config.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import importlib.resources
66
from pathlib import Path
77
from agentstack import conf
8-
from agentstack.tasks import TaskConfig, TASKS_FILENAME
8+
from agentstack.tasks import TaskConfig, TASKS_FILENAME, get_all_task_names, get_all_tasks
9+
from agentstack.exceptions import ValidationError
910

1011
BASE_PATH = Path(__file__).parent
1112

@@ -76,3 +77,56 @@ def test_write_none_values(self):
7677
agent: >
7778
"""
7879
)
80+
81+
def test_yaml_error(self):
82+
# Create an invalid YAML file
83+
with open(self.project_dir / TASKS_FILENAME, 'w') as f:
84+
f.write("""
85+
task_name:
86+
description: "This is a valid line"
87+
invalid_yaml: "This line is missing a colon"
88+
nested_key: "This will cause a YAML error"
89+
""")
90+
91+
# Attempt to load the config, which should raise a ValidationError
92+
with self.assertRaises(ValidationError) as context:
93+
TaskConfig("task_name")
94+
95+
def test_pydantic_validation_error(self):
96+
# Create a YAML file with an invalid field type
97+
with open(self.project_dir / TASKS_FILENAME, 'w') as f:
98+
f.write("""
99+
task_name:
100+
description: "This is a valid description"
101+
expected_output: "This is a valid expected output"
102+
agent: 123 # This should be a string, not an integer
103+
""")
104+
105+
# Attempt to load the config, which should raise a ValidationError
106+
with self.assertRaises(ValidationError) as context:
107+
TaskConfig("task_name")
108+
109+
def test_get_all_task_names(self):
110+
shutil.copy(BASE_PATH / "fixtures/tasks_max.yaml", self.project_dir / TASKS_FILENAME)
111+
112+
task_names = get_all_task_names()
113+
self.assertEqual(set(task_names), {"task_name", "task_name_two"})
114+
self.assertEqual(task_names, ["task_name", "task_name_two"])
115+
116+
def test_get_all_task_names_missing_file(self):
117+
if os.path.exists(self.project_dir / TASKS_FILENAME):
118+
os.remove(self.project_dir / TASKS_FILENAME)
119+
non_existent_file_task_names = get_all_task_names()
120+
self.assertEqual(non_existent_file_task_names, [])
121+
122+
def test_get_all_task_names_empty_file(self):
123+
with open(self.project_dir / TASKS_FILENAME, 'w') as f:
124+
f.write("")
125+
126+
empty_task_names = get_all_task_names()
127+
self.assertEqual(empty_task_names, [])
128+
129+
def test_get_all_tasks(self):
130+
shutil.copy(BASE_PATH / "fixtures/tasks_max.yaml", self.project_dir / TASKS_FILENAME)
131+
for task in get_all_tasks():
132+
self.assertIsInstance(task, TaskConfig)

0 commit comments

Comments
 (0)