Skip to content

Commit 37eb094

Browse files
committed
feat: enable bulk adding users
1 parent ac8dccd commit 37eb094

File tree

4 files changed

+169
-3
lines changed

4 files changed

+169
-3
lines changed

tableauserverclient/server/endpoint/users_endpoint.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
1+
from collections.abc import Iterable
12
import copy
3+
import csv
4+
import io
5+
import itertools
26
import logging
37
from typing import Optional
8+
from pathlib import Path
9+
import re
410

511
from tableauserverclient.server.query import QuerySet
612

713
from .endpoint import QuerysetEndpoint, api
814
from .exceptions import MissingRequiredFieldError, ServerResponseError
915
from tableauserverclient.server import RequestFactory, RequestOptions
10-
from tableauserverclient.models import UserItem, WorkbookItem, PaginationItem, GroupItem
11-
from ..pager import Pager
16+
from tableauserverclient.models import UserItem, WorkbookItem, PaginationItem, GroupItem, JobItem
17+
from tableauserverclient.server.pager import Pager
1218

1319
from tableauserverclient.helpers.logging import logger
1420

@@ -97,8 +103,25 @@ def add_all(self, users: list[UserItem]):
97103

98104
# helping the user by parsing a file they could have used to add users through the UI
99105
# line format: Username [required], password, display name, license, admin, publish
106+
@api(version="3.15")
107+
def bulk_add(self, users: Iterable[UserItem]) -> JobItem:
108+
"""
109+
line format: Username [required], password, display name, license, admin, publish
110+
"""
111+
url = f"{self.baseurl}/import"
112+
# Allow for iterators to be passed into the function
113+
csv_users, xml_users = itertools.tee(users, 2)
114+
csv_content = create_users_csv(csv_users)
115+
116+
xml_request, content_type = RequestFactory.User.import_from_csv_req(csv_content, xml_users)
117+
server_response = self.post_request(url, xml_request, content_type)
118+
return JobItem.from_response(server_response.content, self.parent_srv.namespace).pop()
119+
100120
@api(version="2.0")
101121
def create_from_file(self, filepath: str) -> tuple[list[UserItem], list[tuple[UserItem, ServerResponseError]]]:
122+
import warnings
123+
124+
warnings.warn("This method is deprecated, use bulk_add instead", DeprecationWarning)
102125
created = []
103126
failed = []
104127
if not filepath.find("csv"):
@@ -205,3 +228,43 @@ def filter(self, *invalid, page_size: Optional[int] = None, **kwargs) -> QuerySe
205228
"""
206229

207230
return super().filter(*invalid, page_size=page_size, **kwargs)
231+
232+
def create_users_csv(users: Iterable[UserItem], identity_pool=None) -> bytes:
233+
"""
234+
Create a CSV byte string from an Iterable of UserItem objects
235+
"""
236+
if identity_pool is not None:
237+
raise NotImplementedError("Identity pool is not supported in this version")
238+
with io.StringIO() as output:
239+
writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
240+
for user in users:
241+
site_role = user.site_role or "Unlicensed"
242+
if site_role == "ServerAdministrator":
243+
license = "Creator"
244+
admin_level = "System"
245+
elif site_role.startswith("SiteAdministrator"):
246+
admin_level = "Site"
247+
license = site_role.replace("SiteAdministrator", "")
248+
else:
249+
license = site_role
250+
admin_level = ""
251+
252+
if any(x in site_role for x in ("Creator", "Admin", "Publish")):
253+
publish = 1
254+
else:
255+
publish = 0
256+
257+
writer.writerow(
258+
(
259+
user.name,
260+
getattr(user, "password", ""),
261+
user.fullname,
262+
license,
263+
admin_level,
264+
publish,
265+
user.email,
266+
)
267+
)
268+
output.seek(0)
269+
result = output.read().encode("utf-8")
270+
return result

tableauserverclient/server/request_factory.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,21 @@ def add_req(self, user_item: UserItem) -> bytes:
931931
user_element.attrib["authSetting"] = user_item.auth_setting
932932
return ET.tostring(xml_request)
933933

934+
def import_from_csv_req(self, csv_content: bytes, users: Iterable[UserItem]):
935+
xml_request = ET.Element("tsRequest")
936+
for user in users:
937+
if user.name is None:
938+
raise ValueError("User name must be populated.")
939+
user_element = ET.SubElement(xml_request, "user")
940+
user_element.attrib["name"] = user.name
941+
user_element.attrib["authSetting"] = user.auth_setting or "ServerDefault"
942+
943+
parts = {
944+
"tableau_user_import": ("tsc_users_file.csv", csv_content, "file"),
945+
"request_payload": ("", ET.tostring(xml_request), "text/xml"),
946+
}
947+
return _add_multipart(parts)
948+
934949

935950
class WorkbookRequest:
936951
def _generate_xml(

test/assets/users_bulk_add_job.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<?xml version='1.0' encoding='UTF-8'?>
2+
<tsResponse xmlns="http://tableau.com/api" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://tableau.com/api https://help.tableau.com/samples/en-us/rest_api/ts-api_3_20.xsd">
3+
<job id="16a3479e-0ff9-4685-a0e4-1533b3c2eb96" mode="Asynchronous" type="UserImport" progress="0" createdAt="2024-06-27T03:21:02Z" finishCode="1"/>
4+
</tsResponse>

test/test_user.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1+
import csv
2+
import io
13
import os
4+
from pathlib import Path
25
import unittest
36

7+
from defusedxml.ElementTree import fromstring
48
import requests_mock
59

610
import tableauserverclient as TSC
711
from tableauserverclient.datetime_helpers import format_datetime
812

9-
TEST_ASSET_DIR = os.path.join(os.path.dirname(__file__), "assets")
13+
TEST_ASSET_DIR = Path(__file__).resolve().parent / "assets"
1014

15+
BULK_ADD_XML = TEST_ASSET_DIR / "users_bulk_add_job.xml"
1116
GET_XML = os.path.join(TEST_ASSET_DIR, "user_get.xml")
1217
GET_EMPTY_XML = os.path.join(TEST_ASSET_DIR, "user_get_empty.xml")
1318
GET_BY_ID_XML = os.path.join(TEST_ASSET_DIR, "user_get_by_id.xml")
@@ -233,3 +238,82 @@ def test_get_users_from_file(self):
233238
users, failures = self.server.users.create_from_file(USERS)
234239
assert users[0].name == "Cassie", users
235240
assert failures == []
241+
242+
def test_bulk_add(self):
243+
self.server.version = "3.15"
244+
users = [
245+
TSC.UserItem(
246+
"test",
247+
"Viewer",
248+
)
249+
]
250+
with requests_mock.mock() as m:
251+
m.post(f"{self.server.users.baseurl}/import", text=BULK_ADD_XML.read_text())
252+
253+
job = self.server.users.bulk_add(users)
254+
255+
assert m.last_request.method == "POST"
256+
assert m.last_request.url == f"{self.server.users.baseurl}/import"
257+
258+
body = m.last_request.body.replace(b"\r\n", b"\n")
259+
assert body.startswith(b"--") # Check if it's a multipart request
260+
boundary = body.split(b"\n")[0].strip()
261+
262+
# Body starts and ends with a boundary string. Split the body into
263+
# segments and ignore the empty sections at the start and end.
264+
segments = [seg for s in body.split(boundary) if (seg := s.strip()) not in [b"", b"--"]]
265+
assert len(segments) == 2 # Check if there are two segments
266+
267+
# Check if the first segment is the csv file and the second segment is the xml
268+
assert b'Content-Disposition: form-data; name="tableau_user_import"' in segments[0]
269+
assert b'Content-Disposition: form-data; name="request_payload"' in segments[1]
270+
assert b"Content-Type: file" in segments[0]
271+
assert b"Content-Type: text/xml" in segments[1]
272+
273+
xml_string = segments[1].split(b"\n\n")[1].strip()
274+
xml = fromstring(xml_string)
275+
xml_users = xml.findall(".//user", namespaces={})
276+
assert len(xml_users) == len(users)
277+
278+
for user, xml_user in zip(users, xml_users):
279+
assert user.name == xml_user.get("name")
280+
assert xml_user.get("authSetting") == (user.auth_setting or "ServerDefault")
281+
282+
license_map = {
283+
"Viewer": "Viewer",
284+
"Explorer": "Explorer",
285+
"ExplorerCanPublish": "Explorer",
286+
"Creator": "Creator",
287+
"SiteAdministratorExplorer": "Explorer",
288+
"SiteAdministratorCreator": "Creator",
289+
"ServerAdministrator": "Creator",
290+
"Unlicensed": "Unlicensed",
291+
}
292+
publish_map = {
293+
"Unlicensed": 0,
294+
"Viewer": 0,
295+
"Explorer": 0,
296+
"Creator": 1,
297+
"ExplorerCanPublish": 1,
298+
"SiteAdministratorExplorer": 1,
299+
"SiteAdministratorCreator": 1,
300+
"ServerAdministrator": 1,
301+
}
302+
admin_map = {
303+
"SiteAdministratorExplorer": "Site",
304+
"SiteAdministratorCreator": "Site",
305+
"ServerAdministrator": "System",
306+
}
307+
308+
csv_columns = ["name", "password", "fullname", "license", "admin", "publish", "email"]
309+
csv_file = io.StringIO(segments[0].split(b"\n\n")[1].decode("utf-8"))
310+
csv_reader = csv.reader(csv_file)
311+
for user, row in zip(users, csv_reader):
312+
site_role = user.site_role or "Unlicensed"
313+
csv_user = dict(zip(csv_columns, row))
314+
assert user.name == csv_user["name"]
315+
assert (user.fullname or "") == csv_user["fullname"]
316+
assert (user.email or "") == csv_user["email"]
317+
assert license_map[site_role] == csv_user["license"]
318+
assert admin_map.get(site_role, "") == csv_user["admin"]
319+
assert publish_map[site_role] == int(csv_user["publish"])

0 commit comments

Comments
 (0)