Skip to content

Commit ef19111

Browse files
committed
feat: enable bulk adding users
1 parent 7822be0 commit ef19111

File tree

4 files changed

+170
-5
lines changed

4 files changed

+170
-5
lines changed

tableauserverclient/server/endpoint/users_endpoint.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import copy
2+
import csv
3+
import io
4+
import itertools
25
import logging
3-
from typing import List, Optional, Tuple
6+
from pathlib import Path
7+
import re
8+
from typing import List, Iterable, Optional, Tuple, Union
49

5-
from .endpoint import QuerysetEndpoint, api
10+
from tableauserverclient.server.endpoint.endpoint import QuerysetEndpoint, api
611
from .exceptions import MissingRequiredFieldError, ServerResponseError
712
from tableauserverclient.server import RequestFactory, RequestOptions
8-
from tableauserverclient.models import UserItem, WorkbookItem, PaginationItem, GroupItem
9-
from ..pager import Pager
13+
from tableauserverclient.models import UserItem, WorkbookItem, PaginationItem, GroupItem, JobItem
14+
from tableauserverclient.server.pager import Pager
1015

1116
from tableauserverclient.helpers.logging import logger
1217

@@ -95,8 +100,25 @@ def add_all(self, users: List[UserItem]):
95100

96101
# helping the user by parsing a file they could have used to add users through the UI
97102
# line format: Username [required], password, display name, license, admin, publish
103+
@api(version="3.15")
104+
def bulk_add(self, users: Iterable[UserItem]) -> JobItem:
105+
"""
106+
line format: Username [required], password, display name, license, admin, publish
107+
"""
108+
url = f"{self.baseurl}/import"
109+
# Allow for iterators to be passed into the function
110+
csv_users, xml_users = itertools.tee(users, 2)
111+
csv_content = create_users_csv(csv_users)
112+
113+
xml_request, content_type = RequestFactory.User.import_from_csv_req(csv_content, xml_users)
114+
server_response = self.post_request(url, xml_request, content_type)
115+
return JobItem.from_response(server_response.content, self.parent_srv.namespace).pop()
116+
98117
@api(version="2.0")
99118
def create_from_file(self, filepath: str) -> Tuple[List[UserItem], List[Tuple[UserItem, ServerResponseError]]]:
119+
import warnings
120+
121+
warnings.warn("This method is deprecated, use bulk_add instead", DeprecationWarning)
100122
created = []
101123
failed = []
102124
if not filepath.find("csv"):
@@ -166,3 +188,44 @@ def _get_groups_for_user(
166188
group_item = GroupItem.from_response(server_response.content, self.parent_srv.namespace)
167189
pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace)
168190
return group_item, pagination_item
191+
192+
193+
def create_users_csv(users: Iterable[UserItem], identity_pool=None) -> bytes:
194+
"""
195+
Create a CSV byte string from an Iterable of UserItem objects
196+
"""
197+
if identity_pool is not None:
198+
raise NotImplementedError("Identity pool is not supported in this version")
199+
with io.StringIO() as output:
200+
writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
201+
for user in users:
202+
site_role = user.site_role or "Unlicensed"
203+
if site_role == "ServerAdministrator":
204+
license = "Creator"
205+
admin_level = "System"
206+
elif site_role.startswith("SiteAdministrator"):
207+
admin_level = "Site"
208+
license = site_role.replace("SiteAdministrator", "")
209+
else:
210+
license = site_role
211+
admin_level = ""
212+
213+
if any(x in site_role for x in ("Creator", "Admin", "Publish")):
214+
publish = 1
215+
else:
216+
publish = 0
217+
218+
writer.writerow(
219+
(
220+
user.name,
221+
getattr(user, "password", ""),
222+
user.fullname,
223+
license,
224+
admin_level,
225+
publish,
226+
user.email,
227+
)
228+
)
229+
output.seek(0)
230+
result = output.read().encode("utf-8")
231+
return result

tableauserverclient/server/request_factory.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,21 @@ def add_req(self, user_item: UserItem) -> bytes:
882882
user_element.attrib["authSetting"] = user_item.auth_setting
883883
return ET.tostring(xml_request)
884884

885+
def import_from_csv_req(self, csv_content: bytes, users: Iterable[UserItem]):
886+
xml_request = ET.Element("tsRequest")
887+
for user in users:
888+
if user.name is None:
889+
raise ValueError("User name must be populated.")
890+
user_element = ET.SubElement(xml_request, "user")
891+
user_element.attrib["name"] = user.name
892+
user_element.attrib["authSetting"] = user.auth_setting or "ServerDefault"
893+
894+
parts = {
895+
"tableau_user_import": ("tsc_users_file.csv", csv_content, "file"),
896+
"request_payload": ("", ET.tostring(xml_request), "text/xml"),
897+
}
898+
return _add_multipart(parts)
899+
885900

886901
class WorkbookRequest(object):
887902
def _generate_xml(

test/assets/users_bulk_add_job.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<?xml version='1.0' encoding='UTF-8'?>
2+
<tsResponse xmlns="http://tableau.com/api" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://tableau.com/api https://help.tableau.com/samples/en-us/rest_api/ts-api_3_20.xsd">
3+
<job id="16a3479e-0ff9-4685-a0e4-1533b3c2eb96" mode="Asynchronous" type="UserImport" progress="0" createdAt="2024-06-27T03:21:02Z" finishCode="1"/>
4+
</tsResponse>

test/test_user.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
import csv
12
import io
23
import os
4+
from pathlib import Path
35
import unittest
46
from typing import List
57
from unittest.mock import MagicMock
68

9+
from defusedxml.ElementTree import fromstring
710
import requests_mock
811

912
import tableauserverclient as TSC
1013
from tableauserverclient.datetime_helpers import format_datetime
1114

12-
TEST_ASSET_DIR = os.path.join(os.path.dirname(__file__), "assets")
15+
TEST_ASSET_DIR = Path(__file__).resolve().parent / "assets"
1316

17+
BULK_ADD_XML = TEST_ASSET_DIR / "users_bulk_add_job.xml"
1418
GET_XML = os.path.join(TEST_ASSET_DIR, "user_get.xml")
1519
GET_EMPTY_XML = os.path.join(TEST_ASSET_DIR, "user_get_empty.xml")
1620
GET_BY_ID_XML = os.path.join(TEST_ASSET_DIR, "user_get_by_id.xml")
@@ -236,3 +240,82 @@ def test_get_users_from_file(self):
236240
users, failures = self.server.users.create_from_file(USERS)
237241
assert users[0].name == "Cassie", users
238242
assert failures == []
243+
244+
def test_bulk_add(self):
245+
self.server.version = "3.15"
246+
users = [
247+
TSC.UserItem(
248+
"test",
249+
"Viewer",
250+
)
251+
]
252+
with requests_mock.mock() as m:
253+
m.post(f"{self.server.users.baseurl}/import", text=BULK_ADD_XML.read_text())
254+
255+
job = self.server.users.bulk_add(users)
256+
257+
assert m.last_request.method == "POST"
258+
assert m.last_request.url == f"{self.server.users.baseurl}/import"
259+
260+
body = m.last_request.body.replace(b"\r\n", b"\n")
261+
assert body.startswith(b"--") # Check if it's a multipart request
262+
boundary = body.split(b"\n")[0].strip()
263+
264+
# Body starts and ends with a boundary string. Split the body into
265+
# segments and ignore the empty sections at the start and end.
266+
segments = [seg for s in body.split(boundary) if (seg := s.strip()) not in [b"", b"--"]]
267+
assert len(segments) == 2 # Check if there are two segments
268+
269+
# Check if the first segment is the csv file and the second segment is the xml
270+
assert b'Content-Disposition: form-data; name="tableau_user_import"' in segments[0]
271+
assert b'Content-Disposition: form-data; name="request_payload"' in segments[1]
272+
assert b"Content-Type: file" in segments[0]
273+
assert b"Content-Type: text/xml" in segments[1]
274+
275+
xml_string = segments[1].split(b"\n\n")[1].strip()
276+
xml = fromstring(xml_string)
277+
xml_users = xml.findall(".//user", namespaces={})
278+
assert len(xml_users) == len(users)
279+
280+
for user, xml_user in zip(users, xml_users):
281+
assert user.name == xml_user.get("name")
282+
assert xml_user.get("authSetting") == (user.auth_setting or "ServerDefault")
283+
284+
license_map = {
285+
"Viewer": "Viewer",
286+
"Explorer": "Explorer",
287+
"ExplorerCanPublish": "Explorer",
288+
"Creator": "Creator",
289+
"SiteAdministratorExplorer": "Explorer",
290+
"SiteAdministratorCreator": "Creator",
291+
"ServerAdministrator": "Creator",
292+
"Unlicensed": "Unlicensed",
293+
}
294+
publish_map = {
295+
"Unlicensed": 0,
296+
"Viewer": 0,
297+
"Explorer": 0,
298+
"Creator": 1,
299+
"ExplorerCanPublish": 1,
300+
"SiteAdministratorExplorer": 1,
301+
"SiteAdministratorCreator": 1,
302+
"ServerAdministrator": 1,
303+
}
304+
admin_map = {
305+
"SiteAdministratorExplorer": "Site",
306+
"SiteAdministratorCreator": "Site",
307+
"ServerAdministrator": "System",
308+
}
309+
310+
csv_columns = ["name", "password", "fullname", "license", "admin", "publish", "email"]
311+
csv_file = io.StringIO(segments[0].split(b"\n\n")[1].decode("utf-8"))
312+
csv_reader = csv.reader(csv_file)
313+
for user, row in zip(users, csv_reader):
314+
site_role = user.site_role or "Unlicensed"
315+
csv_user = dict(zip(csv_columns, row))
316+
assert user.name == csv_user["name"]
317+
assert (user.fullname or "") == csv_user["fullname"]
318+
assert (user.email or "") == csv_user["email"]
319+
assert license_map[site_role] == csv_user["license"]
320+
assert admin_map.get(site_role, "") == csv_user["admin"]
321+
assert publish_map[site_role] == int(csv_user["publish"])

0 commit comments

Comments
 (0)