|
| 1 | +import csv |
1 | 2 | import io |
2 | 3 | import os |
| 4 | +from pathlib import Path |
3 | 5 | import unittest |
4 | 6 | from typing import List |
5 | 7 | from unittest.mock import MagicMock |
6 | 8 |
|
| 9 | +from defusedxml.ElementTree import fromstring |
7 | 10 | import requests_mock |
8 | 11 |
|
9 | 12 | import tableauserverclient as TSC |
10 | 13 | from tableauserverclient.datetime_helpers import format_datetime |
11 | 14 |
|
12 | | -TEST_ASSET_DIR = os.path.join(os.path.dirname(__file__), "assets") |
| 15 | +TEST_ASSET_DIR = Path(__file__).resolve().parent / "assets" |
13 | 16 |
|
| 17 | +BULK_ADD_XML = TEST_ASSET_DIR / "users_bulk_add_job.xml" |
14 | 18 | GET_XML = os.path.join(TEST_ASSET_DIR, "user_get.xml") |
15 | 19 | GET_EMPTY_XML = os.path.join(TEST_ASSET_DIR, "user_get_empty.xml") |
16 | 20 | GET_BY_ID_XML = os.path.join(TEST_ASSET_DIR, "user_get_by_id.xml") |
@@ -236,3 +240,82 @@ def test_get_users_from_file(self): |
236 | 240 | users, failures = self.server.users.create_from_file(USERS) |
237 | 241 | assert users[0].name == "Cassie", users |
238 | 242 | assert failures == [] |
| 243 | + |
| 244 | + def test_bulk_add(self): |
| 245 | + self.server.version = "3.15" |
| 246 | + users = [ |
| 247 | + TSC.UserItem( |
| 248 | + "test", |
| 249 | + "Viewer", |
| 250 | + ) |
| 251 | + ] |
| 252 | + with requests_mock.mock() as m: |
| 253 | + m.post(f"{self.server.users.baseurl}/import", text=BULK_ADD_XML.read_text()) |
| 254 | + |
| 255 | + job = self.server.users.bulk_add(users) |
| 256 | + |
| 257 | + assert m.last_request.method == "POST" |
| 258 | + assert m.last_request.url == f"{self.server.users.baseurl}/import" |
| 259 | + |
| 260 | + body = m.last_request.body.replace(b"\r\n", b"\n") |
| 261 | + assert body.startswith(b"--") # Check if it's a multipart request |
| 262 | + boundary = body.split(b"\n")[0].strip() |
| 263 | + |
| 264 | + # Body starts and ends with a boundary string. Split the body into |
| 265 | + # segments and ignore the empty sections at the start and end. |
| 266 | + segments = [seg for s in body.split(boundary) if (seg := s.strip()) not in [b"", b"--"]] |
| 267 | + assert len(segments) == 2 # Check if there are two segments |
| 268 | + |
| 269 | + # Check if the first segment is the csv file and the second segment is the xml |
| 270 | + assert b'Content-Disposition: form-data; name="tableau_user_import"' in segments[0] |
| 271 | + assert b'Content-Disposition: form-data; name="request_payload"' in segments[1] |
| 272 | + assert b"Content-Type: file" in segments[0] |
| 273 | + assert b"Content-Type: text/xml" in segments[1] |
| 274 | + |
| 275 | + xml_string = segments[1].split(b"\n\n")[1].strip() |
| 276 | + xml = fromstring(xml_string) |
| 277 | + xml_users = xml.findall(".//user", namespaces={}) |
| 278 | + assert len(xml_users) == len(users) |
| 279 | + |
| 280 | + for user, xml_user in zip(users, xml_users): |
| 281 | + assert user.name == xml_user.get("name") |
| 282 | + assert xml_user.get("authSetting") == (user.auth_setting or "ServerDefault") |
| 283 | + |
| 284 | + license_map = { |
| 285 | + "Viewer": "Viewer", |
| 286 | + "Explorer": "Explorer", |
| 287 | + "ExplorerCanPublish": "Explorer", |
| 288 | + "Creator": "Creator", |
| 289 | + "SiteAdministratorExplorer": "Explorer", |
| 290 | + "SiteAdministratorCreator": "Creator", |
| 291 | + "ServerAdministrator": "Creator", |
| 292 | + "Unlicensed": "Unlicensed", |
| 293 | + } |
| 294 | + publish_map = { |
| 295 | + "Unlicensed": 0, |
| 296 | + "Viewer": 0, |
| 297 | + "Explorer": 0, |
| 298 | + "Creator": 1, |
| 299 | + "ExplorerCanPublish": 1, |
| 300 | + "SiteAdministratorExplorer": 1, |
| 301 | + "SiteAdministratorCreator": 1, |
| 302 | + "ServerAdministrator": 1, |
| 303 | + } |
| 304 | + admin_map = { |
| 305 | + "SiteAdministratorExplorer": "Site", |
| 306 | + "SiteAdministratorCreator": "Site", |
| 307 | + "ServerAdministrator": "System", |
| 308 | + } |
| 309 | + |
| 310 | + csv_columns = ["name", "password", "fullname", "license", "admin", "publish", "email"] |
| 311 | + csv_file = io.StringIO(segments[0].split(b"\n\n")[1].decode("utf-8")) |
| 312 | + csv_reader = csv.reader(csv_file) |
| 313 | + for user, row in zip(users, csv_reader): |
| 314 | + site_role = user.site_role or "Unlicensed" |
| 315 | + csv_user = dict(zip(csv_columns, row)) |
| 316 | + assert user.name == csv_user["name"] |
| 317 | + assert (user.fullname or "") == csv_user["fullname"] |
| 318 | + assert (user.email or "") == csv_user["email"] |
| 319 | + assert license_map[site_role] == csv_user["license"] |
| 320 | + assert admin_map.get(site_role, "") == csv_user["admin"] |
| 321 | + assert publish_map[site_role] == int(csv_user["publish"]) |
0 commit comments