Skip to content

Commit 3c57a7f

Browse files
committed
Add a troubling test of multithreaded SwaggerClient
Really feel like this one should pass...
1 parent bd6c072 commit 3c57a7f

2 files changed

Lines changed: 51 additions & 7 deletions

File tree

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
flake8==3.5.0
22
moto==1.3.3
3+
futures; python_version < '3.2'
34
coverage
45
pyyaml
56
responses

test/integration/util/test_swagger_client.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
# coding: utf-8
33

44
import argparse
5+
import concurrent.futures
56
import json
67
import os
78
import requests
89
import sys
10+
import time
911
import unittest
1012

1113
pkg_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) # noqa
@@ -45,6 +47,7 @@ class TestSwaggerClient(unittest.TestCase):
4547
dummy_response = requests.models.Response()
4648
dummy_response.status_code = 200
4749
dummy_response._content = "content"
50+
dummy_response.headers["content-type"] = "audio/vnd.rn-realaudio"
4851

4952
generated_method_names = [
5053
# method names corresponding to all `paths`
@@ -69,13 +72,21 @@ def setUpClass(cls):
6972
content = fh.read()
7073
swagger_response._content = content
7174
cls.test_swagger_json = json.loads(content.decode("utf-8"))
72-
75+
cls.test_swagger_response = swagger_response
7376
cls.url_base = (cls.test_swagger_json['schemes'][0] + "://" +
7477
cls.test_swagger_json['host'] +
7578
cls.test_swagger_json['basePath'])
7679

80+
cls.client = cls.create_client(
81+
cls.swagger_url, cls.test_swagger_response, cls.subparsers, cls.open_fn_name)
82+
83+
@staticmethod
84+
def create_client(swagger_url, swagger_response, subparsers, open_fn_name):
85+
"""
86+
Create and return a new SwaggerClient
87+
"""
7788
with mock.patch('requests.Session.get') as mock_get, \
78-
mock.patch(cls.open_fn_name, mock_open()), \
89+
mock.patch(open_fn_name, mock_open()), \
7990
mock.patch('hca.util.fs.atomic_write'), \
8091
mock.patch('hca.dss.SwaggerClient.load_swagger_json') as mock_load_swagger_json:
8192
# init SwaggerClient with test swagger JSON file
@@ -84,9 +95,10 @@ def setUpClass(cls):
8495

8596
config = HCAConfig(save_on_exit=False)
8697
config['SwaggerClient'] = {}
87-
config['SwaggerClient'].swagger_url = cls.swagger_url
88-
cls.client = hca.util.SwaggerClient(config)
89-
cls.client.build_argparse_subparsers(cls.subparsers)
98+
config['SwaggerClient'].swagger_url = swagger_url
99+
client = hca.util.SwaggerClient(config)
100+
client.build_argparse_subparsers(subparsers)
101+
return client
90102

91103
@classmethod
92104
def tearDownClass(cls):
@@ -97,8 +109,8 @@ def setUp(self):
97109

98110
def test_client_methods_exist(self):
99111
for method_name in self.generated_method_names:
100-
self.assertTrue(hasattr(self.client.__class__, method_name) and
101-
callable(getattr(self.client.__class__, method_name)))
112+
self.assertTrue(hasattr(self.client, method_name) and
113+
callable(getattr(self.client, method_name)))
102114

103115
def test_get_with_path_query_params(self):
104116
http_method = "get"
@@ -331,6 +343,37 @@ def test_put_with_invalid_enum_param(self):
331343
'--query-param', query_param_invalid])
332344
self.assertEqual(e.exception.code, 2)
333345

346+
def test_multithreaded(self):
347+
http_method = "get"
348+
path = "/with/path/query/params"
349+
path_param = "path"
350+
url = self.url_base + path + "/" + path_param
351+
num_threads = 32
352+
num_attempts = 400
353+
354+
with concurrent.futures.ThreadPoolExecutor(num_threads) as exe, \
355+
mock.patch('requests.Session.request') as mock_request:
356+
357+
mock_request.return_value = self.dummy_response
358+
359+
def call_with_query_param(param):
360+
client = self.create_client(
361+
self.swagger_url, self.test_swagger_response, self.subparsers,
362+
self.open_fn_name)
363+
with client.get_with_path_query_params.stream(
364+
path_param=path_param, query_param=param):
365+
pass
366+
367+
futures = [exe.submit(call_with_query_param, str(i)) for i in range(num_attempts)]
368+
369+
while any(not f.done() for f in futures):
370+
time.sleep(.5)
371+
372+
called_query_params = set()
373+
for call in mock_request.mock_calls:
374+
if 'params' in call[2]:
375+
called_query_params.add(call[2]["params"]["query_param"])
376+
self.assertSetEqual(called_query_params, set(str(i) for i in range(num_attempts)))
334377

335378
if __name__ == "__main__":
336379
unittest.main()

0 commit comments

Comments
 (0)