diff --git a/roadlib/roadtools/roadlib/auth.py b/roadlib/roadtools/roadlib/auth.py index bfa017ba..04066667 100755 --- a/roadlib/roadtools/roadlib/auth.py +++ b/roadlib/roadtools/roadlib/auth.py @@ -66,6 +66,7 @@ def __init__(self, username=None, password=None, tenant=None, client_id='1b73095 self.claims = None self.use_pkce = False self.pkce_secret = None + self.custom_headers = {} # For cert based app auth self.appprivkey = None @@ -127,6 +128,13 @@ def get_redirect_for_client(self, client_id, interactive=False, broker=False): except Exception: return 'https://login.microsoftonline.com/common/oauth2/nativeclient' + def set_custom_headers(self, headers_dict): + """ + Sets custom headers to include in all HTTP requests + """ + if headers_dict: + self.custom_headers = headers_dict + def user_discovery_v1(self, username): """ Discover whether this is a federated user @@ -1595,6 +1603,9 @@ def get_sub_argparse(auth_parser, for_rr=False): auth_parser.add_argument('--tokens-stdout', action='store_true', help='Do not store tokens on disk, pipe to stdout instead') + auth_parser.add_argument('--headers', + action='store', + help='Custom headers in JSON format, e.g., \'{"X-AnchorMailbox": "Oid:objectid@tenantid"}\'') auth_parser.add_argument('--debug', action='store_true', help='Enable debug logging to disk') @@ -1801,6 +1812,10 @@ def requests_get(self, *args, **kwargs): headers = kwargs.get('headers',{}) headers['User-Agent'] = self.user_agent kwargs['headers'] = headers + if self.custom_headers: + headers = kwargs.get('headers',{}) + headers.update(self.custom_headers) + kwargs['headers'] = headers return requests.get(*args, timeout=30.0, **kwargs) def requests_post(self, *args, **kwargs): @@ -1817,6 +1832,10 @@ def requests_post(self, *args, **kwargs): headers = kwargs.get('headers',{}) headers['Origin'] = self.origin kwargs['headers'] = headers + if self.custom_headers: + headers = kwargs.get('headers',{}) + headers.update(self.custom_headers) + kwargs['headers'] = headers return requests.post(*args, timeout=30.0, **kwargs) def requests_put(self, *args, **kwargs): @@ -1833,6 +1852,10 @@ def requests_put(self, *args, **kwargs): headers = kwargs.get('headers',{}) headers['Origin'] = self.origin kwargs['headers'] = headers + if self.custom_headers: + headers = kwargs.get('headers',{}) + headers.update(self.custom_headers) + kwargs['headers'] = headers return requests.put(*args, timeout=30.0, **kwargs) def requests_patch(self, *args, **kwargs): @@ -1849,6 +1872,10 @@ def requests_patch(self, *args, **kwargs): headers = kwargs.get('headers',{}) headers['Origin'] = self.origin kwargs['headers'] = headers + if self.custom_headers: + headers = kwargs.get('headers',{}) + headers.update(self.custom_headers) + kwargs['headers'] = headers return requests.patch(*args, timeout=30.0, **kwargs) def requests_delete(self, *args, **kwargs): @@ -1861,6 +1888,10 @@ def requests_delete(self, *args, **kwargs): headers = kwargs.get('headers',{}) headers['User-Agent'] = self.user_agent kwargs['headers'] = headers + if self.custom_headers: + headers = kwargs.get('headers',{}) + headers.update(self.custom_headers) + kwargs['headers'] = headers return requests.delete(*args, timeout=30.0, **kwargs) def parse_args(self, args): @@ -1877,6 +1908,12 @@ def parse_args(self, args): self.set_resource_uri(args.resource) self.set_scope(args.scope) self.set_user_agent(args.user_agent) + if hasattr(args, 'headers') and args.headers: + try: + headers_dict = json.loads(args.headers) + self.set_custom_headers(headers_dict) + except json.JSONDecodeError: + print(f'Error: Invalid JSON format for headers: {args.headers}') if args.cae: self.set_cae() if args.force_mfa: diff --git a/roadtx/roadtools/roadtx/main.py b/roadtx/roadtools/roadtx/main.py index 79c225e8..13df8515 100644 --- a/roadtx/roadtools/roadtx/main.py +++ b/roadtx/roadtools/roadtx/main.py @@ -84,6 +84,9 @@ def main(): '--broker-redirect-url', action='store', help='Broker redirect URL (for Nested App Auth)') + rttsauth_parser.add_argument('--headers', + action='store', + help='Custom headers in JSON format, e.g., \'{"X-AnchorMailbox": "Oid:objectid@tenantid"}\'') # Construct device module device_parser = subparsers.add_parser('device', help='Register or join devices to Azure AD') @@ -708,6 +711,9 @@ def main(): intauth_parser.add_argument('--otpseed', action='store', help='TOTP seed to calculate MFA code when prompted') + intauth_parser.add_argument('--headers', + action='store', + help='Custom headers in JSON format, e.g., \'{"X-AnchorMailbox": "Oid:objectid@tenantid"}\'') # Interactive auth using Selenium - creds from keepass kdbauth_parser = subparsers.add_parser('keepassauth', help='Selenium based authentication with credentials from a KeePass database') @@ -1091,6 +1097,13 @@ def main(): auth.set_user_agent(args.user_agent) auth.set_scope(args.scope) auth.outfile = args.tokenfile + if args.headers: + try: + headers_dict = json.loads(args.headers) + auth.set_custom_headers(headers_dict) + except json.JSONDecodeError: + print(f'Error: Invalid JSON format for headers: {args.headers}') + return if args.origin: auth.set_origin_value(args.origin) elif 'originheader' in tokenobject: @@ -1553,6 +1566,13 @@ def main(): auth.set_user_agent(args.user_agent) auth.tenant = args.tenant auth.use_pkce = args.pkce + if args.headers: + try: + headers_dict = json.loads(args.headers) + auth.set_custom_headers(headers_dict) + except json.JSONDecodeError: + print(f'Error: Invalid JSON format for headers: {args.headers}') + return if args.origin: auth.set_origin_value(args.origin, args.redirect_url) if args.cae: