Skip to content

Commit 6c0c8b6

Browse files
committed
Migrate more tests
1 parent d854bbf commit 6c0c8b6

File tree

1 file changed

+101
-134
lines changed

1 file changed

+101
-134
lines changed

tests/client/test_http.py

Lines changed: 101 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060

6161
mocked_request = MagicMock(spec=urllib3.response.HTTPResponse)
6262

63-
6463
def fake_request(response=None):
6564
def request(*args, **kwargs):
6665
if isinstance(response, list):
@@ -253,76 +252,26 @@ def test_bad_bulk_400():
253252
)
254253

255254

256-
def test_decimal_serialization():
257-
"""
258-
Verify that a `Decimal` type can be serialized and sent to the server.
259-
"""
260-
with patch(REQUEST, return_value=fake_response(200)) as request:
261-
client = Client(servers="localhost:4200")
262-
263-
dec = Decimal(0.12)
264-
client.sql("insert into users (float_col) values (?)", (dec,))
265-
data = json.loads(request.call_args[1]["data"])
266-
assert dec == Decimal(data["args"][0])
267-
268-
269-
270-
def test_datetime_is_converted_to_ts():
271-
"""
272-
Verify that a `datetime.datetime` can be serialized.
273-
"""
274-
with patch(REQUEST, return_value=fake_response(200)) as request:
275-
client = Client(servers="localhost:4200")
276-
277-
datetime = dt.datetime(2015, 2, 28, 7, 31, 40)
278-
client.sql("insert into users (dt) values (?)", (datetime,))
279-
280-
# convert string to dict
281-
# because the order of the keys isn't deterministic
282-
data = json.loads(request.call_args[1]["data"])
283-
assert data["args"][0] == 1425108700000
284-
285-
286-
def test_date_is_converted_to_ts():
255+
def test_socket_options_contain_keepalive():
287256
"""
288-
Verify that a `datetime.date` can be serialized.
257+
Verify that KEEPALIVE options are present at `socket_options`
289258
"""
290-
with patch(REQUEST, return_value=fake_response(200)) as request:
291-
client = Client(servers="localhost:4200")
292-
293-
day = dt.date(2016, 4, 21)
294-
client.sql("insert into users (dt) values (?)", (day,))
295-
data = json.loads(request.call_args[1]["data"])
296-
assert data["args"][0] == 1461196800000
297-
298-
299-
def test_socket_options_contain_keepalive():
300-
client = Client(servers="http://localhost:4200")
259+
server = "http://localhost:4200"
260+
client = Client(servers=server)
301261
conn_kw = client.server_pool[server].pool.conn_kw
302262
assert (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) in conn_kw["socket_options"]
303263

304264

305-
class HttpClientTest(TestCase):
306-
@patch(REQUEST, autospec=True)
307-
def test_uuid_serialization(self, request):
308-
client = Client(servers="localhost:4200")
309-
request.return_value = fake_response(200)
310-
311-
uid = uuid.uuid4()
312-
client.sql("insert into my_table (str_col) values (?)", (uid,))
313-
314-
data = json.loads(request.call_args[1]["data"])
315-
self.assertEqual(data["args"], [str(uid)])
316-
client.close()
317-
318-
@patch(REQUEST, fake_request(duplicate_key_exception()))
319-
def test_duplicate_key_error(self):
320-
"""
321-
Verify that an `IntegrityError` is raised on duplicate key errors,
322-
instead of the more general `ProgrammingError`.
323-
"""
265+
def test_duplicate_key_error():
266+
"""
267+
Verify that an `IntegrityError` is raised on duplicate key errors,
268+
instead of the more general `ProgrammingError`.
269+
"""
270+
expected_error_msg = (r"DuplicateKeyException\[A document with "
271+
r"the same primary key exists already\]")
272+
with patch(REQUEST_PATH, fake_request(duplicate_key_exception())):
324273
client = Client(servers="localhost:4200")
325-
with self.assertRaises(IntegrityError) as cm:
274+
with pytest.raises(IntegrityError, match=expected_error_msg):
326275
client.sql("INSERT INTO testdrive (foo) VALUES (42)")
327276
self.assertEqual(
328277
cm.exception.message,
@@ -331,16 +280,21 @@ def test_duplicate_key_error(self):
331280
)
332281

333282

334-
@patch(REQUEST, fail_sometimes)
335-
class ThreadSafeHttpClientTest(TestCase):
283+
@patch(REQUEST_PATH, fail_sometimes)
284+
def test_client_multithreaded():
336285
"""
337-
Using a pool of 5 Threads to emit commands to the multiple servers through
338-
one Client-instance
286+
Verify client multithreading using a pool of 5 Threads to emit commands to
287+
the multiple servers through one Client-instance.
339288
340-
check if number of servers in _inactive_servers and _active_servers always
289+
Checks if the number of servers in _inactive_servers and _active_servers always
341290
equals the number of servers initially given.
342-
"""
343291
292+
Note:
293+
This test is probabilistic and does not ensure that the
294+
client is indeed thread-safe in all cases, it can only show that it
295+
withstands this scenario.
296+
297+
"""
344298
servers = [
345299
"127.0.0.1:44209",
346300
"127.0.0.2:44209",
@@ -350,67 +304,94 @@ class ThreadSafeHttpClientTest(TestCase):
350304
num_commands = 1000
351305
thread_timeout = 5.0 # seconds
352306

353-
def __init__(self, *args, **kwargs):
354-
self.event = Event()
355-
self.err_queue = queue.Queue()
356-
super(ThreadSafeHttpClientTest, self).__init__(*args, **kwargs)
307+
gate = Event()
308+
error_queue = queue.Queue()
357309

358310
def setUp(self):
359311
self.client = Client(self.servers)
360312
self.client.retry_interval = 0.2 # faster retry
313+
client = Client(servers)
314+
client.retry_interval = 0.2 # faster retry
361315

362-
def tearDown(self):
363-
self.client.close()
364-
365-
def _run(self):
366-
self.event.wait() # wait for the others
367-
expected_num_servers = len(self.servers)
368-
for _ in range(self.num_commands):
316+
def worker():
317+
"""
318+
Worker that sends many requests, if the `num_server` is not expected at some point
319+
an assertion will be added to the shared error queue.
320+
"""
321+
gate.wait() # wait for the others
322+
expected_num_servers = len(servers)
323+
for _ in range(num_commands):
369324
try:
370-
self.client.sql("select name from sys.cluster")
325+
client.sql("select name from sys.cluster")
371326
except ConnectionError:
327+
# Sometimes it will fail.
372328
pass
373329
try:
374-
with self.client._lock:
375-
num_servers = len(self.client._active_servers) + len(
376-
self.client._inactive_servers
330+
with client._lock:
331+
num_servers = len(client._active_servers) + len(
332+
client._inactive_servers
377333
)
378-
self.assertEqual(
379-
expected_num_servers,
380-
num_servers,
381-
"expected %d but got %d"
382-
% (expected_num_servers, num_servers),
383-
)
384-
except AssertionError:
385-
self.err_queue.put(sys.exc_info())
386-
387-
def test_client_threaded(self):
388-
"""
389-
Testing if lists of servers is handled correctly when client is used
390-
from multiple threads with some requests failing.
334+
assert num_servers == expected_num_servers, (
335+
f"expected {expected_num_servers} but got {num_servers}"
336+
)
337+
except AssertionError as e:
338+
error_queue.put(e)
391339

392-
**ATTENTION:** this test is probabilistic and does not ensure that the
393-
client is indeed thread-safe in all cases, it can only show that it
394-
withstands this scenario.
395-
"""
396-
threads = [
397-
Thread(target=self._run, name=str(x))
398-
for x in range(self.num_threads)
399-
]
400-
for thread in threads:
401-
thread.start()
402-
403-
self.event.set()
404-
for t in threads:
405-
t.join(self.thread_timeout)
406-
407-
if not self.err_queue.empty():
408-
self.assertTrue(
409-
False,
410-
"".join(
411-
traceback.format_exception(*self.err_queue.get(block=False))
412-
),
413-
)
340+
threads = [
341+
Thread(target=worker, name=str(i))
342+
for i in range(num_threads)
343+
]
344+
345+
for thread in threads:
346+
thread.start()
347+
348+
gate.set()
349+
350+
for t in threads:
351+
t.join(timeout=thread_timeout)
352+
353+
# If any thread is still alive after the timeout, consider it a failure.
354+
alive = [t.name for t in threads if t.is_alive()]
355+
if alive:
356+
pytest.fail(f"Threads did not finish within {thread_timeout}s: {alive}")
357+
358+
if not error_queue.empty():
359+
# If an error happened, consider it a failure as well.
360+
first_error_trace = error_queue.get(block=False)
361+
pytest.fail(first_error_trace)
362+
363+
364+
def test_params():
365+
"""
366+
Verify client parameters translate correctly to query parameters..
367+
"""
368+
client = Client(["127.0.0.1:4200"], error_trace=True)
369+
parsed = urlparse(client.path)
370+
params = parse_qs(parsed.query)
371+
372+
assert params["error_trace"] == ["true"]
373+
assert params["types"] == ["true"]
374+
375+
client = Client(["127.0.0.1:4200"])
376+
parsed = urlparse(client.path)
377+
params = parse_qs(parsed.query)
378+
379+
# Default is FALSE
380+
assert 'error_trace' not in params
381+
assert params["types"] == ["true"]
382+
383+
assert "/_sql?" in client.path
384+
385+
386+
def test_client_ca():
387+
os.environ["REQUESTS_CA_BUNDLE"] = CA_CERT_PATH
388+
try:
389+
Client("http://127.0.0.1:4200")
390+
except ProgrammingError:
391+
pytest.fail("HTTP not working with REQUESTS_CA_BUNDLE")
392+
finally:
393+
os.unsetenv("REQUESTS_CA_BUNDLE")
394+
os.environ["REQUESTS_CA_BUNDLE"] = ""
414395

415396

416397
class ClientAddressRequestHandler(BaseHTTPRequestHandler):
@@ -474,20 +455,6 @@ def test_client_keepalive(self):
474455
self.assertEqual(result, another_result)
475456

476457

477-
class ParamsTest(TestCase):
478-
def test_params(self):
479-
client = Client(["127.0.0.1:4200"], error_trace=True)
480-
parsed = urlparse(client.path)
481-
params = parse_qs(parsed.query)
482-
self.assertEqual(params["error_trace"], ["true"])
483-
client.close()
484-
485-
def test_no_params(self):
486-
client = Client()
487-
self.assertEqual(client.path, "/_sql?types=true")
488-
client.close()
489-
490-
491458
class RequestsCaBundleTest(TestCase):
492459
def test_open_client(self):
493460
os.environ["REQUESTS_CA_BUNDLE"] = CA_CERT_PATH

0 commit comments

Comments
 (0)