6060
6161mocked_request = MagicMock (spec = urllib3 .response .HTTPResponse )
6262
63-
6463def fake_request (response = None ):
6564 def request (* args , ** kwargs ):
6665 if isinstance (response , list ):
@@ -253,76 +252,26 @@ def test_bad_bulk_400():
253252 )
254253
255254
256- def test_decimal_serialization ():
257- """
258- Verify that a `Decimal` type can be serialized and sent to the server.
259- """
260- with patch (REQUEST , return_value = fake_response (200 )) as request :
261- client = Client (servers = "localhost:4200" )
262-
263- dec = Decimal (0.12 )
264- client .sql ("insert into users (float_col) values (?)" , (dec ,))
265- data = json .loads (request .call_args [1 ]["data" ])
266- assert dec == Decimal (data ["args" ][0 ])
267-
268-
269-
270- def test_datetime_is_converted_to_ts ():
271- """
272- Verify that a `datetime.datetime` can be serialized.
273- """
274- with patch (REQUEST , return_value = fake_response (200 )) as request :
275- client = Client (servers = "localhost:4200" )
276-
277- datetime = dt .datetime (2015 , 2 , 28 , 7 , 31 , 40 )
278- client .sql ("insert into users (dt) values (?)" , (datetime ,))
279-
280- # convert string to dict
281- # because the order of the keys isn't deterministic
282- data = json .loads (request .call_args [1 ]["data" ])
283- assert data ["args" ][0 ] == 1425108700000
284-
285-
286- def test_date_is_converted_to_ts ():
255+ def test_socket_options_contain_keepalive ():
287256 """
288- Verify that a `datetime.date` can be serialized.
257+ Verify that KEEPALIVE options are present at `socket_options`
289258 """
290- with patch (REQUEST , return_value = fake_response (200 )) as request :
291- client = Client (servers = "localhost:4200" )
292-
293- day = dt .date (2016 , 4 , 21 )
294- client .sql ("insert into users (dt) values (?)" , (day ,))
295- data = json .loads (request .call_args [1 ]["data" ])
296- assert data ["args" ][0 ] == 1461196800000
297-
298-
299- def test_socket_options_contain_keepalive ():
300- client = Client (servers = "http://localhost:4200" )
259+ server = "http://localhost:4200"
260+ client = Client (servers = server )
301261 conn_kw = client .server_pool [server ].pool .conn_kw
302262 assert (socket .SOL_SOCKET , socket .SO_KEEPALIVE , 1 ) in conn_kw ["socket_options" ]
303263
304264
305- class HttpClientTest (TestCase ):
306- @patch (REQUEST , autospec = True )
307- def test_uuid_serialization (self , request ):
308- client = Client (servers = "localhost:4200" )
309- request .return_value = fake_response (200 )
310-
311- uid = uuid .uuid4 ()
312- client .sql ("insert into my_table (str_col) values (?)" , (uid ,))
313-
314- data = json .loads (request .call_args [1 ]["data" ])
315- self .assertEqual (data ["args" ], [str (uid )])
316- client .close ()
317-
318- @patch (REQUEST , fake_request (duplicate_key_exception ()))
319- def test_duplicate_key_error (self ):
320- """
321- Verify that an `IntegrityError` is raised on duplicate key errors,
322- instead of the more general `ProgrammingError`.
323- """
265+ def test_duplicate_key_error ():
266+ """
267+ Verify that an `IntegrityError` is raised on duplicate key errors,
268+ instead of the more general `ProgrammingError`.
269+ """
270+ expected_error_msg = (r"DuplicateKeyException\[A document with "
271+ r"the same primary key exists already\]" )
272+ with patch (REQUEST_PATH , fake_request (duplicate_key_exception ())):
324273 client = Client (servers = "localhost:4200" )
325- with self . assertRaises (IntegrityError ) as cm :
274+ with pytest . raises (IntegrityError , match = expected_error_msg ) :
326275 client .sql ("INSERT INTO testdrive (foo) VALUES (42)" )
327276 self .assertEqual (
328277 cm .exception .message ,
@@ -331,16 +280,21 @@ def test_duplicate_key_error(self):
331280 )
332281
333282
334- @patch (REQUEST , fail_sometimes )
335- class ThreadSafeHttpClientTest ( TestCase ):
283+ @patch (REQUEST_PATH , fail_sometimes )
284+ def test_client_multithreaded ( ):
336285 """
337- Using a pool of 5 Threads to emit commands to the multiple servers through
338- one Client-instance
286+ Verify client multithreading using a pool of 5 Threads to emit commands to
287+ the multiple servers through one Client-instance.
339288
340- check if number of servers in _inactive_servers and _active_servers always
289+ Checks if the number of servers in _inactive_servers and _active_servers always
341290 equals the number of servers initially given.
342- """
343291
292+ Note:
293+ This test is probabilistic and does not ensure that the
294+ client is indeed thread-safe in all cases, it can only show that it
295+ withstands this scenario.
296+
297+ """
344298 servers = [
345299 "127.0.0.1:44209" ,
346300 "127.0.0.2:44209" ,
@@ -350,67 +304,94 @@ class ThreadSafeHttpClientTest(TestCase):
350304 num_commands = 1000
351305 thread_timeout = 5.0 # seconds
352306
353- def __init__ (self , * args , ** kwargs ):
354- self .event = Event ()
355- self .err_queue = queue .Queue ()
356- super (ThreadSafeHttpClientTest , self ).__init__ (* args , ** kwargs )
307+ gate = Event ()
308+ error_queue = queue .Queue ()
357309
358310 def setUp (self ):
359311 self .client = Client (self .servers )
360312 self .client .retry_interval = 0.2 # faster retry
313+ client = Client (servers )
314+ client .retry_interval = 0.2 # faster retry
361315
362- def tearDown (self ):
363- self .client .close ()
364-
365- def _run (self ):
366- self .event .wait () # wait for the others
367- expected_num_servers = len (self .servers )
368- for _ in range (self .num_commands ):
316+ def worker ():
317+ """
318+ Worker that sends many requests, if the `num_server` is not expected at some point
319+ an assertion will be added to the shared error queue.
320+ """
321+ gate .wait () # wait for the others
322+ expected_num_servers = len (servers )
323+ for _ in range (num_commands ):
369324 try :
370- self . client .sql ("select name from sys.cluster" )
325+ client .sql ("select name from sys.cluster" )
371326 except ConnectionError :
327+ # Sometimes it will fail.
372328 pass
373329 try :
374- with self . client ._lock :
375- num_servers = len (self . client ._active_servers ) + len (
376- self . client ._inactive_servers
330+ with client ._lock :
331+ num_servers = len (client ._active_servers ) + len (
332+ client ._inactive_servers
377333 )
378- self .assertEqual (
379- expected_num_servers ,
380- num_servers ,
381- "expected %d but got %d"
382- % (expected_num_servers , num_servers ),
383- )
384- except AssertionError :
385- self .err_queue .put (sys .exc_info ())
386-
387- def test_client_threaded (self ):
388- """
389- Testing if lists of servers is handled correctly when client is used
390- from multiple threads with some requests failing.
334+ assert num_servers == expected_num_servers , (
335+ f"expected { expected_num_servers } but got { num_servers } "
336+ )
337+ except AssertionError as e :
338+ error_queue .put (e )
391339
392- **ATTENTION:** this test is probabilistic and does not ensure that the
393- client is indeed thread-safe in all cases, it can only show that it
394- withstands this scenario.
395- """
396- threads = [
397- Thread (target = self ._run , name = str (x ))
398- for x in range (self .num_threads )
399- ]
400- for thread in threads :
401- thread .start ()
402-
403- self .event .set ()
404- for t in threads :
405- t .join (self .thread_timeout )
406-
407- if not self .err_queue .empty ():
408- self .assertTrue (
409- False ,
410- "" .join (
411- traceback .format_exception (* self .err_queue .get (block = False ))
412- ),
413- )
340+ threads = [
341+ Thread (target = worker , name = str (i ))
342+ for i in range (num_threads )
343+ ]
344+
345+ for thread in threads :
346+ thread .start ()
347+
348+ gate .set ()
349+
350+ for t in threads :
351+ t .join (timeout = thread_timeout )
352+
353+ # If any thread is still alive after the timeout, consider it a failure.
354+ alive = [t .name for t in threads if t .is_alive ()]
355+ if alive :
356+ pytest .fail (f"Threads did not finish within { thread_timeout } s: { alive } " )
357+
358+ if not error_queue .empty ():
359+ # If an error happened, consider it a failure as well.
360+ first_error_trace = error_queue .get (block = False )
361+ pytest .fail (first_error_trace )
362+
363+
364+ def test_params ():
365+ """
366+ Verify client parameters translate correctly to query parameters..
367+ """
368+ client = Client (["127.0.0.1:4200" ], error_trace = True )
369+ parsed = urlparse (client .path )
370+ params = parse_qs (parsed .query )
371+
372+ assert params ["error_trace" ] == ["true" ]
373+ assert params ["types" ] == ["true" ]
374+
375+ client = Client (["127.0.0.1:4200" ])
376+ parsed = urlparse (client .path )
377+ params = parse_qs (parsed .query )
378+
379+ # Default is FALSE
380+ assert 'error_trace' not in params
381+ assert params ["types" ] == ["true" ]
382+
383+ assert "/_sql?" in client .path
384+
385+
386+ def test_client_ca ():
387+ os .environ ["REQUESTS_CA_BUNDLE" ] = CA_CERT_PATH
388+ try :
389+ Client ("http://127.0.0.1:4200" )
390+ except ProgrammingError :
391+ pytest .fail ("HTTP not working with REQUESTS_CA_BUNDLE" )
392+ finally :
393+ os .unsetenv ("REQUESTS_CA_BUNDLE" )
394+ os .environ ["REQUESTS_CA_BUNDLE" ] = ""
414395
415396
416397class ClientAddressRequestHandler (BaseHTTPRequestHandler ):
@@ -474,20 +455,6 @@ def test_client_keepalive(self):
474455 self .assertEqual (result , another_result )
475456
476457
477- class ParamsTest (TestCase ):
478- def test_params (self ):
479- client = Client (["127.0.0.1:4200" ], error_trace = True )
480- parsed = urlparse (client .path )
481- params = parse_qs (parsed .query )
482- self .assertEqual (params ["error_trace" ], ["true" ])
483- client .close ()
484-
485- def test_no_params (self ):
486- client = Client ()
487- self .assertEqual (client .path , "/_sql?types=true" )
488- client .close ()
489-
490-
491458class RequestsCaBundleTest (TestCase ):
492459 def test_open_client (self ):
493460 os .environ ["REQUESTS_CA_BUNDLE" ] = CA_CERT_PATH
0 commit comments