Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions roadlib/roadtools/roadlib/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(self, username=None, password=None, tenant=None, client_id='1b73095
self.claims = None
self.use_pkce = False
self.pkce_secret = None
self.custom_headers = {}

# For cert based app auth
self.appprivkey = None
Expand Down Expand Up @@ -127,6 +128,13 @@ def get_redirect_for_client(self, client_id, interactive=False, broker=False):
except Exception:
return 'https://login.microsoftonline.com/common/oauth2/nativeclient'

def set_custom_headers(self, headers_dict):
"""
Sets custom headers to include in all HTTP requests
"""
if headers_dict:
self.custom_headers = headers_dict

def user_discovery_v1(self, username):
"""
Discover whether this is a federated user
Expand Down Expand Up @@ -1595,6 +1603,9 @@ def get_sub_argparse(auth_parser, for_rr=False):
auth_parser.add_argument('--tokens-stdout',
action='store_true',
help='Do not store tokens on disk, pipe to stdout instead')
auth_parser.add_argument('--headers',
action='store',
help='Custom headers in JSON format, e.g., \'{"X-AnchorMailbox": "Oid:objectid@tenantid"}\'')
auth_parser.add_argument('--debug',
action='store_true',
help='Enable debug logging to disk')
Expand Down Expand Up @@ -1801,6 +1812,10 @@ def requests_get(self, *args, **kwargs):
headers = kwargs.get('headers',{})
headers['User-Agent'] = self.user_agent
kwargs['headers'] = headers
if self.custom_headers:
headers = kwargs.get('headers',{})
headers.update(self.custom_headers)
kwargs['headers'] = headers
return requests.get(*args, timeout=30.0, **kwargs)

def requests_post(self, *args, **kwargs):
Expand All @@ -1817,6 +1832,10 @@ def requests_post(self, *args, **kwargs):
headers = kwargs.get('headers',{})
headers['Origin'] = self.origin
kwargs['headers'] = headers
if self.custom_headers:
headers = kwargs.get('headers',{})
headers.update(self.custom_headers)
kwargs['headers'] = headers
return requests.post(*args, timeout=30.0, **kwargs)

def requests_put(self, *args, **kwargs):
Expand All @@ -1833,6 +1852,10 @@ def requests_put(self, *args, **kwargs):
headers = kwargs.get('headers',{})
headers['Origin'] = self.origin
kwargs['headers'] = headers
if self.custom_headers:
headers = kwargs.get('headers',{})
headers.update(self.custom_headers)
kwargs['headers'] = headers
return requests.put(*args, timeout=30.0, **kwargs)

def requests_patch(self, *args, **kwargs):
Expand All @@ -1849,6 +1872,10 @@ def requests_patch(self, *args, **kwargs):
headers = kwargs.get('headers',{})
headers['Origin'] = self.origin
kwargs['headers'] = headers
if self.custom_headers:
headers = kwargs.get('headers',{})
headers.update(self.custom_headers)
kwargs['headers'] = headers
return requests.patch(*args, timeout=30.0, **kwargs)

def requests_delete(self, *args, **kwargs):
Expand All @@ -1861,6 +1888,10 @@ def requests_delete(self, *args, **kwargs):
headers = kwargs.get('headers',{})
headers['User-Agent'] = self.user_agent
kwargs['headers'] = headers
if self.custom_headers:
headers = kwargs.get('headers',{})
headers.update(self.custom_headers)
kwargs['headers'] = headers
return requests.delete(*args, timeout=30.0, **kwargs)

def parse_args(self, args):
Expand All @@ -1877,6 +1908,12 @@ def parse_args(self, args):
self.set_resource_uri(args.resource)
self.set_scope(args.scope)
self.set_user_agent(args.user_agent)
if hasattr(args, 'headers') and args.headers:
try:
headers_dict = json.loads(args.headers)
self.set_custom_headers(headers_dict)
except json.JSONDecodeError:
print(f'Error: Invalid JSON format for headers: {args.headers}')
if args.cae:
self.set_cae()
if args.force_mfa:
Expand Down
20 changes: 20 additions & 0 deletions roadtx/roadtools/roadtx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def main():
'--broker-redirect-url',
action='store',
help='Broker redirect URL (for Nested App Auth)')
rttsauth_parser.add_argument('--headers',
action='store',
help='Custom headers in JSON format, e.g., \'{"X-AnchorMailbox": "Oid:objectid@tenantid"}\'')

# Construct device module
device_parser = subparsers.add_parser('device', help='Register or join devices to Azure AD')
Expand Down Expand Up @@ -708,6 +711,9 @@ def main():
intauth_parser.add_argument('--otpseed',
action='store',
help='TOTP seed to calculate MFA code when prompted')
intauth_parser.add_argument('--headers',
action='store',
help='Custom headers in JSON format, e.g., \'{"X-AnchorMailbox": "Oid:objectid@tenantid"}\'')

# Interactive auth using Selenium - creds from keepass
kdbauth_parser = subparsers.add_parser('keepassauth', help='Selenium based authentication with credentials from a KeePass database')
Expand Down Expand Up @@ -1091,6 +1097,13 @@ def main():
auth.set_user_agent(args.user_agent)
auth.set_scope(args.scope)
auth.outfile = args.tokenfile
if args.headers:
try:
headers_dict = json.loads(args.headers)
auth.set_custom_headers(headers_dict)
except json.JSONDecodeError:
print(f'Error: Invalid JSON format for headers: {args.headers}')
return
if args.origin:
auth.set_origin_value(args.origin)
elif 'originheader' in tokenobject:
Expand Down Expand Up @@ -1553,6 +1566,13 @@ def main():
auth.set_user_agent(args.user_agent)
auth.tenant = args.tenant
auth.use_pkce = args.pkce
if args.headers:
try:
headers_dict = json.loads(args.headers)
auth.set_custom_headers(headers_dict)
except json.JSONDecodeError:
print(f'Error: Invalid JSON format for headers: {args.headers}')
return
if args.origin:
auth.set_origin_value(args.origin, args.redirect_url)
if args.cae:
Expand Down