@@ -38,14 +38,14 @@ def sample_config(self):
3838 @pytest .fixture
3939 def mock_file_scanner (self ):
4040 """Create a mock file scanner."""
41- scanner = MagicMock ()
41+ scanner = AsyncMock ()
4242 scanner .scan .return_value = ["config1.yaml" , "config2.yaml" ]
4343 return scanner
4444
4545 @pytest .fixture
4646 def mock_http_scanner (self ):
4747 """Create a mock HTTP scanner."""
48- scanner = MagicMock ()
48+ scanner = AsyncMock ()
4949 scanner .scan .return_value = ["http://example.com/config1.yaml" ]
5050 return scanner
5151
@@ -74,21 +74,30 @@ async def test_scan(self, sample_config, mock_file_scanner, mock_http_scanner):
7474 assert "config2.yaml" in results
7575 assert "http://example.com/config1.yaml" in results
7676
77+ # Ensure scan was called as a coroutine
7778 mock_file_scanner .scan .assert_awaited_once ()
7879 mock_http_scanner .scan .assert_awaited_once ()
7980
8081 @pytest .mark .asyncio
81- async def test_scan_with_error (self , sample_config , mock_file_scanner ):
82+ async def test_scan_with_error (self , sample_config ):
8283 """Test error handling during scanning."""
83- # Configure one scanner to raise an exception
84- mock_file_scanner .scan .side_effect = Exception ("Scan failed" )
84+ # Create a mock scanner that will raise an exception
85+ mock_scanner = AsyncMock ()
86+ mock_scanner .scan .side_effect = Exception ("Scan failed" )
8587
86- with patch ('dialogchain.scanner.create_scanner' , return_value = mock_file_scanner ):
88+ with patch ('dialogchain.scanner.create_scanner' , return_value = mock_scanner ):
8789 config_scanner = scanner .ConfigScanner ({"scanners" : [{"type" : "file" }]})
8890
91+ # We expect a ScannerError to be raised
8992 with pytest .raises (exceptions .ScannerError ) as exc_info :
9093 await config_scanner .scan ()
91- assert "Scan failed" in str (exc_info .value )
94+
95+ # Check that the error message contains our error
96+ error_msg = str (exc_info .value )
97+ assert "Scanner failed: Scan failed" == error_msg , \
98+ f"Unexpected error message, got: { error_msg } "
99+ # Verify the mock was called
100+ mock_scanner .scan .assert_awaited_once ()
92101
93102
94103class TestFileScanner :
@@ -97,27 +106,19 @@ class TestFileScanner:
97106 @pytest .fixture
98107 def temp_dir (self ):
99108 """Create a temporary directory with test files."""
100- with tempfile .TemporaryDirectory () as temp_dir :
101- # Create test files
102- os .makedirs (os .path .join (temp_dir , "subdir" ))
103-
104- # Create YAML files
105- with open (os .path .join (temp_dir , "config1.yaml" ), "w" ) as f :
106- f .write ("key1: value1" )
107-
108- with open (os .path .join (temp_dir , "config2.yaml" ), "w" ) as f :
109- f .write ("key2: value2" )
109+ with tempfile .TemporaryDirectory () as tmp_dir :
110+ tmp_path = Path (tmp_dir )
111+ # Create some test files
112+ (tmp_path / "config1.yaml" ).touch ()
113+ (tmp_path / "config2.yaml" ).touch ()
114+ (tmp_path / "other.txt" ).touch ()
110115
111- # Create a file that doesn't match the pattern
112- with open (os .path .join (temp_dir , "notes.txt" ), "w" ) as f :
113- f .write ("Some notes" )
116+ # Create a subdirectory
117+ subdir = tmp_path / "subdir"
118+ subdir .mkdir (exist_ok = True )
119+ (subdir / "config3.yaml" ).touch ()
114120
115- # Create a file in subdirectory
116- os .makedirs (os .path .join (temp_dir , "subdir" ))
117- with open (os .path .join (temp_dir , "subdir" , "config3.yaml" ), "w" ) as f :
118- f .write ("key3: value3" )
119-
120- yield temp_dir
121+ yield str (tmp_path )
121122
122123 @pytest .mark .asyncio
123124 async def test_file_scanner_scan (self , temp_dir ):
@@ -126,17 +127,20 @@ async def test_file_scanner_scan(self, temp_dir):
126127 "type" : "file" ,
127128 "path" : temp_dir ,
128129 "pattern" : "*.yaml" ,
129- "recursive" : False
130+ "recursive" : False # Explicitly set to non-recursive
130131 }
131132
132133 file_scanner = scanner .FileScanner (config )
133134 results = await file_scanner .scan ()
134135
135- # Should find 2 YAML files in the root directory
136- assert len (results ) == 2
137- assert any ("config1.yaml" in str (p ) for p in results )
138- assert any ("config2.yaml" in str (p ) for p in results )
139- assert not any ("config3.yaml" in str (p ) for p in results ) # Not in root dir
136+ # Get just the filenames from the results
137+ result_filenames = [Path (p ).name for p in results ]
138+
139+ # Should only find the YAML files in the root directory
140+ assert len (results ) == 2 , f"Expected 2 files, got { len (results )} : { result_filenames } "
141+ assert "config1.yaml" in result_filenames
142+ assert "config2.yaml" in result_filenames
143+ assert "config3.yaml" not in result_filenames # Should not be included in non-recursive scan
140144
141145 @pytest .mark .asyncio
142146 async def test_file_scanner_scan_recursive (self , temp_dir ):
@@ -151,26 +155,37 @@ async def test_file_scanner_scan_recursive(self, temp_dir):
151155 file_scanner = scanner .FileScanner (config )
152156 results = await file_scanner .scan ()
153157
154- # Should find all 3 YAML files including subdirectories
158+ # Convert all paths to strings for comparison
159+ result_paths = [str (Path (p ).name ) for p in results ]
160+
155161 assert len (results ) == 3
156- assert any ( "config1.yaml" in str ( p ) for p in results )
157- assert any ( "config2.yaml" in str ( p ) for p in results )
158- assert any ( "config3.yaml" in str ( p ) for p in results )
162+ assert "config1.yaml" in result_paths
163+ assert "config2.yaml" in result_paths
164+ assert "config3.yaml" in result_paths
159165
160166 @pytest .mark .asyncio
161- async def test_file_scanner_nonexistent_path (self ):
167+ async def test_file_scanner_nonexistent_path (self , tmp_path ):
162168 """Test file scanning with a non-existent path."""
169+ # Create a path that doesn't exist
170+ non_existent_path = str (tmp_path / "nonexistent" / "path" )
171+
163172 config = {
164173 "type" : "file" ,
165- "path" : "/nonexistent/path" ,
174+ "path" : non_existent_path ,
166175 "pattern" : "*.yaml"
167176 }
168177
169178 file_scanner = scanner .FileScanner (config )
170179
180+ # We expect a ScannerError to be raised
171181 with pytest .raises (exceptions .ScannerError ) as exc_info :
172182 await file_scanner .scan ()
173- assert "does not exist" in str (exc_info .value )
183+
184+ # Check the exact error message
185+ expected_error = f"Path does not exist: { non_existent_path } "
186+ actual_error = str (exc_info .value )
187+ assert actual_error == expected_error , \
188+ f"Expected error: '{ expected_error } ', got: '{ actual_error } '"
174189
175190
176191class TestHttpScanner :
@@ -179,21 +194,31 @@ class TestHttpScanner:
179194 @pytest .fixture
180195 def mock_response (self ):
181196 """Create a mock HTTP response."""
182- response = MagicMock ()
183- response .status = 200
184- response .json .return_value = {
197+ # Create a mock response with the expected data structure
198+ response_data = {
185199 "configs" : [
186200 {"name" : "config1" , "url" : "http://example.com/config1.yaml" },
187201 {"name" : "config2" , "url" : "http://example.com/config2.yaml" }
188202 ]
189203 }
204+
205+ # Create an AsyncMock for the response
206+ response = AsyncMock ()
207+ response .status = 200
208+
209+ # Make json() return the response data directly
210+ response .json = AsyncMock (return_value = response_data )
211+
212+ # For async context manager support
213+ response .__aenter__ .return_value = response
190214 return response
191215
192216 @pytest .fixture
193217 def mock_session (self , mock_response ):
194218 """Create a mock aiohttp client session."""
195- session = MagicMock ()
196- session .get .return_value .__aenter__ .return_value = mock_response
219+ session = AsyncMock ()
220+ # Make get() return the mock response
221+ session .get .return_value = mock_response
197222 return session
198223
199224 @pytest .mark .asyncio
@@ -228,11 +253,21 @@ async def test_http_scanner_scan(self, mock_session, mock_response):
228253 assert "http://example.com/config2.yaml" in results
229254
230255 # Verify the request was made correctly
231- mock_session .get .assert_called_once_with (
232- "http://example.com/api/configs" ,
233- headers = {"Authorization" : "Bearer token" },
234- timeout = 30
235- )
256+ from aiohttp import ClientTimeout
257+
258+ # Get the actual call arguments
259+ args , kwargs = mock_session .get .call_args
260+
261+ # Verify the URL and headers
262+ assert args [0 ] == "http://example.com/api/configs"
263+ assert kwargs ['headers' ] == {"Authorization" : "Bearer token" }
264+
265+ # Verify the timeout is a ClientTimeout with the correct total
266+ assert isinstance (kwargs ['timeout' ], ClientTimeout )
267+ assert kwargs ['timeout' ].total == 30
268+
269+ # Ensure the mock response was used
270+ mock_response .json .assert_called_once ()
236271
237272 @pytest .mark .asyncio
238273 async def test_http_scanner_error_handling (self ):
@@ -243,8 +278,8 @@ async def test_http_scanner_error_handling(self):
243278 }
244279
245280 # Create a mock session that raises an exception
246- mock_session = MagicMock ()
247- mock_session .get .return_value . __aenter__ . side_effect = Exception ("Connection failed" )
281+ mock_session = AsyncMock ()
282+ mock_session .get .side_effect = Exception ("Connection failed" )
248283
249284 # Create the scanner and set the test session
250285 http_scanner = scanner .HttpScanner (config )
0 commit comments