diff --git a/.changes/next-release/enhancement-cloudtrail-28063.json b/.changes/next-release/enhancement-cloudtrail-28063.json new file mode 100644 index 000000000000..388733f396f0 --- /dev/null +++ b/.changes/next-release/enhancement-cloudtrail-28063.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "cloudtrail", + "description": "Added support for validating backfill digest files in ``validate-logs`` command" +} diff --git a/awscli/customizations/cloudtrail/validation.py b/awscli/customizations/cloudtrail/validation.py index d5255f4b8438..50ca8a7d21db 100644 --- a/awscli/customizations/cloudtrail/validation.py +++ b/awscli/customizations/cloudtrail/validation.py @@ -26,12 +26,11 @@ from dateutil import parser, tz from awscli.customizations.cloudtrail.utils import ( - PublicKeyProvider, get_account_id_from_arn, get_trail_by_arn, ) from awscli.customizations.commands import BasicCommand -from awscli.customizations.exceptions import ParamValidationError +from awscli.schema import ParameterRequiredError LOG = logging.getLogger(__name__) DATE_FORMAT = '%Y%m%dT%H%M%SZ' @@ -53,23 +52,35 @@ def normalize_date(date): return date.replace(tzinfo=tz.tzutc()) +def is_backfill_digest_key(digest_key): + """Utility function to determine if a digest key represents a backfill digest file""" + return digest_key.endswith('_backfill.json.gz') + + def extract_digest_key_date(digest_s3_key): """Extract the timestamp portion of a manifest file. Manifest file names take the following form: AWSLogs/{account}/CloudTrail-Digest/{region}/{ymd}/{account}_CloudTrail \ -Digest_{region}_{name}_region_{date}.json.gz + + For backfill files: + AWSLogs/{account}/CloudTrail-Digest/{region}/{ymd}/{account}_CloudTrail \ + -Digest_{region}_{name}_region_{date}_backfill.json.gz """ - return digest_s3_key[-24:-8] + if is_backfill_digest_key(digest_s3_key): + # Backfill files have _backfill suffix before .json.gz + return digest_s3_key[-33:-17] + else: + # Regular digest files + return digest_s3_key[-24:-8] def parse_date(date_string): try: return parser.parse(date_string) except ValueError: - raise ParamValidationError( - f'Unable to parse date value: {date_string}' - ) + raise ValueError(f'Unable to parse date value: {date_string}') def assert_cloudtrail_arn_is_valid(trail_arn): @@ -78,7 +89,7 @@ def assert_cloudtrail_arn_is_valid(trail_arn): ARNs look like: arn:aws:cloudtrail:us-east-1:123456789012:trail/foo""" pattern = re.compile(r'arn:.+:cloudtrail:.+:\d{12}:trail/.+') if not pattern.match(trail_arn): - raise ParamValidationError(f'Invalid trail ARN provided: {trail_arn}') + raise ValueError(f'Invalid trail ARN provided: {trail_arn}') def create_digest_traverser( @@ -134,13 +145,13 @@ def create_digest_traverser( if bucket is None: # Determine the bucket and prefix based on the trail arn. 
trail_info = get_trail_by_arn(cloudtrail_client, trail_arn) - LOG.debug('Loaded trail info: %s', trail_info) + LOG.debug(f'Loaded trail info: {trail_info}') bucket = trail_info['S3BucketName'] prefix = trail_info.get('S3KeyPrefix', None) is_org_trail = trail_info.get('IsOrganizationTrail') if is_org_trail: if not account_id: - raise ParamValidationError( + raise ParameterRequiredError( "Missing required parameter for organization " "trail: '--account-id'" ) @@ -205,7 +216,9 @@ def _get_bucket_region(self, bucket_name): def _create_client(self, region_name): """Creates an Amazon S3 client for the given region name""" if region_name not in self._client_cache: - client = self._session.create_client('s3', region_name) + client = self._session.create_client( + 's3', region_name=region_name + ) # Remove the CLI error event that prevents exceptions. self._client_cache[region_name] = client return self._client_cache[region_name] @@ -225,7 +238,7 @@ def __init__(self, bucket, key): f'Digest file\ts3://{bucket}/{key}\tINVALID: signature verification ' 'failed' ) - super(DigestSignatureError, self).__init__(message) + super().__init__(message) class InvalidDigestFormat(DigestError): @@ -233,7 +246,32 @@ class InvalidDigestFormat(DigestError): def __init__(self, bucket, key): message = f'Digest file\ts3://{bucket}/{key}\tINVALID: invalid format' - super(InvalidDigestFormat, self).__init__(message) + super().__init__(message) + + +class PublicKeyProvider: + """Retrieves public keys from CloudTrail within a date range.""" + + def __init__(self, cloudtrail_client): + self._cloudtrail_client = cloudtrail_client + + def get_public_keys(self, start_date, end_date): + """Loads public keys in a date range into a returned dict. + + :type start_date: datetime + :param start_date: Start date of a date range. + :type end_date: datetime + :param end_date: End date of a date range. + :rtype: dict + :return: Returns a dict where each key is the fingerprint of the + public key, and each value is a dict of public key data. + """ + public_keys = self._cloudtrail_client.list_public_keys( + StartTime=start_date, EndTime=end_date + ) + public_keys_in_range = public_keys['PublicKeyList'] + LOG.debug(f'Loaded public keys in range: {public_keys_in_range}') + return dict((key['Fingerprint'], key) for key in public_keys_in_range) class DigestProvider: @@ -261,18 +299,25 @@ def __init__( self.trail_home_region = trail_home_region self.trail_source_region = trail_source_region or trail_home_region self.organization_id = organization_id + self._digest_cache = {} - def load_digest_keys_in_range(self, bucket, prefix, start_date, end_date): - """Returns a list of digest keys in the date range. + def load_all_digest_keys_in_range( + self, bucket, prefix, start_date, end_date + ): + """Load all digest keys and separate into standard and backfill lists. - This method uses a list_objects API call and provides a Marker - parameter that is calculated based on the start_date provided. - Amazon S3 then returns all keys in the bucket that start after - the given key (non-inclusive). We then iterate over the keys - until the date extracted from the yielded keys is greater than - the given end_date. + Performs a single S3 list operation and separates keys into standard + and backfill digest lists in one pass, avoiding a second listing.
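+ Backfill keys live under the same prefix as standard keys and differ + only by the ``_backfill`` suffix, so this one listing returns both.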
+ + :param bucket: S3 bucket name + :param prefix: S3 key prefix + :param start_date: Start date for digest range + :param end_date: End date for digest range + :return: Tuple of (standard_digests, backfill_digests) lists + :rtype: tuple """ - digests = [] + standard_digests = [] + backfill_digests = [] marker = self._create_digest_key(start_date, prefix) s3_digest_files_prefix = self._create_digest_prefix(start_date, prefix) client = self._client_provider.get_client(bucket) @@ -288,20 +333,62 @@ def load_digest_keys_in_range(self, bucket, prefix, start_date, end_date): # Ensure digests are from the same trail. digest_key_regex = re.compile(self._create_digest_key_regex(prefix)) for key in key_filter: - if key and digest_key_regex.match(key): - # Use a lexicographic comparison to know when to stop. - extracted_date = extract_digest_key_date(key) - if extracted_date > target_end_date: - break - # Only append digests after the start date. - if extracted_date >= target_start_date: - digests.append(key) - return digests + if not (key and digest_key_regex.match(key)): + continue + # Use a lexicographic comparison to know when to stop. + extracted_date = extract_digest_key_date(key) + if extracted_date > target_end_date: + break + # Only append digests after the start date. + if extracted_date < target_start_date: + continue + if is_backfill_digest_key(key): + backfill_digests.append(key) + else: + standard_digests.append(key) + return standard_digests, backfill_digests + + def load_digest_keys_in_range( + self, bucket, prefix, start_date, end_date, is_backfill=False + ): + """Returns a list of digest keys in the date range. + + This method uses caching to avoid duplicate S3 list operations. + On first call, it loads all digest keys and caches them separated + by type. Subsequent calls return the appropriate cached list. + + :param bucket: S3 bucket name + :param prefix: S3 key prefix + :param start_date: Start date for digest range + :param end_date: End date for digest range + :param is_backfill: Optional filter - True for backfill digests only, + False for standard digests only + :return: List of digest keys matching the specified type + :rtype: list + """ + cache_key = (bucket, prefix, start_date, end_date) + + if cache_key not in self._digest_cache: + standard_digests, backfill_digests = ( + self.load_all_digest_keys_in_range( + bucket, prefix, start_date, end_date + ) + ) + self._digest_cache[cache_key] = { + 'standard': standard_digests, + 'backfill': backfill_digests, + } + + if is_backfill: + return self._digest_cache[cache_key]['backfill'] + else: + return self._digest_cache[cache_key]['standard'] def fetch_digest(self, bucket, key): """Loads a digest by key from S3. Returns the JSON decode data and GZIP inflated raw content. + For backfill digests, also extracts the backfill-generation-timestamp. 
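+ + Illustrative sketch (variable names assumed, not part of the API): + ``data, raw = provider.fetch_digest(bucket, key)``; for a backfill + key, ``data['_backfill_generation_timestamp']`` holds the value of the + ``backfill-generation-timestamp`` S3 metadata entry, and a missing + entry raises ``InvalidDigestFormat``.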
""" client = self._client_provider.get_client(bucket) result = client.get_object(Bucket=bucket, Key=key) @@ -323,6 +410,15 @@ def fetch_digest(self, bucket, key): digest_data['_signature_algorithm'] = result['Metadata'][ 'signature-algorithm' ] + + if is_backfill_digest_key(key): + if 'backfill-generation-timestamp' in result['Metadata']: + digest_data['_backfill_generation_timestamp'] = result[ + 'Metadata' + ]['backfill-generation-timestamp'] + else: + raise InvalidDigestFormat(bucket, key) + return digest_data, digest def _create_digest_key(self, start_date, key_prefix): @@ -337,24 +433,27 @@ def _create_digest_key(self, start_date, key_prefix): """ # Subtract one minute to ensure the dates are inclusive. date = start_date - timedelta(minutes=1) - template = 'AWSLogs/' - template_params = { - 'account_id': self.account_id, - 'date': format_date(date), - 'ymd': date.strftime('%Y/%m/%d'), - 'source_region': self.trail_source_region, - 'home_region': self.trail_home_region, - 'name': self.trail_name, - } + account_id = self.account_id + date_str = format_date(date) + ymd = date.strftime('%Y/%m/%d') + source_region = self.trail_source_region + home_region = self.trail_home_region + name = self.trail_name + if self.organization_id: - template += '{organization_id}/' - template_params['organization_id'] = self.organization_id - template += ( - '{account_id}/CloudTrail-Digest/{source_region}/' - '{ymd}/{account_id}_CloudTrail-Digest_{source_region}_{name}_' - '{home_region}_{date}.json.gz' - ) - key = template.format(**template_params) + organization_id = self.organization_id + key = ( + f'AWSLogs/{organization_id}/{account_id}/CloudTrail-Digest/' + f'{source_region}/{ymd}/{account_id}_CloudTrail-Digest_' + f'{source_region}_{name}_{home_region}_{date_str}.json.gz' + ) + else: + key = ( + f'AWSLogs/{account_id}/CloudTrail-Digest/{source_region}/' + f'{ymd}/{account_id}_CloudTrail-Digest_{source_region}_{name}_' + f'{home_region}_{date_str}.json.gz' + ) + if key_prefix: key = key_prefix + '/' + key return key @@ -379,23 +478,26 @@ def _create_digest_prefix(self, start_date, key_prefix): return prefix def _create_digest_key_regex(self, key_prefix): - """Creates a regular expression used to match against S3 keys""" - template = 'AWSLogs/' - template_params = { - 'account_id': re.escape(self.account_id), - 'source_region': re.escape(self.trail_source_region), - 'home_region': re.escape(self.trail_home_region), - 'name': re.escape(self.trail_name), - } + """Creates a regular expression used to match against S3 keys for both standard and backfill digests""" + account_id = re.escape(self.account_id) + source_region = re.escape(self.trail_source_region) + home_region = re.escape(self.trail_home_region) + name = re.escape(self.trail_name) + if self.organization_id: - template += '{organization_id}/' - template_params['organization_id'] = self.organization_id - template += ( - '{account_id}/CloudTrail\\-Digest/{source_region}/' - '\\d+/\\d+/\\d+/{account_id}_CloudTrail\\-Digest_' - '{source_region}_{name}_{home_region}_.+\\.json\\.gz' - ) - key = template.format(**template_params) + organization_id = self.organization_id + key = ( + f'AWSLogs/{organization_id}/{account_id}/CloudTrail\\-Digest/' + f'{source_region}/\\d+/\\d+/\\d+/{account_id}_CloudTrail\\-Digest_' + f'{source_region}_{name}_{home_region}_.+(?:_backfill)?\\.json\\.gz' + ) + else: + key = ( + f'AWSLogs/{account_id}/CloudTrail\\-Digest/{source_region}/' + f'\\d+/\\d+/\\d+/{account_id}_CloudTrail\\-Digest_' + 
f'{source_region}_{name}_{home_region}_.+(?:_backfill)?\\.json\\.gz' + ) + if key_prefix: key = re.escape(key_prefix) + '/' + key return '^' + key + '$' @@ -449,7 +551,7 @@ def __init__( digest_validator = Sha256RSADigestValidator() self._digest_validator = digest_validator - def traverse(self, start_date, end_date=None): + def traverse_digests(self, start_date, end_date=None, is_backfill=False): """Creates and returns a generator that yields validated digest data. Each yielded digest dictionary contains information about the digest @@ -460,8 +562,10 @@ def traverse(self, start_date, end_date=None): :type start_date: datetime :param start_date: Date to start validating from (inclusive). - :type start_date: datetime + :type end_date: datetime :param end_date: Date to stop validating at (inclusive). + :type is_backfill: bool + :param is_backfill: Flag indicating whether to process backfill digests only. """ if end_date is None: end_date = datetime.utcnow() @@ -469,21 +573,53 @@ def traverse(self, start_date, end_date=None): start_date = normalize_date(start_date) bucket = self.starting_bucket prefix = self.starting_prefix - digests = self._load_digests(bucket, prefix, start_date, end_date) - public_keys = self._load_public_keys(start_date, end_date) + + digests = self._load_digests( + bucket, prefix, start_date, end_date, is_backfill=is_backfill + ) + + # For regular digests, pre-load public keys. For backfill, start with empty dict + public_keys = ( + {} if is_backfill else self._load_public_keys(start_date, end_date) + ) + + yield from self._traverse_digest_chain( + digests, + bucket, + prefix, + start_date, + public_keys, + is_backfill=is_backfill, + ) + + def _traverse_digest_chain( + self, + digests, + bucket, + prefix, + start_date, + public_keys, + is_backfill=False, + ): + """Traverses a single chain of digests + + :param is_backfill: Boolean indicating whether this chain contains backfill digests + """ key, end_date = self._get_last_digest(digests) last_start_date = end_date + while key and start_date <= last_start_date: try: digest, end_date = self._load_and_validate_digest( - public_keys, bucket, key + public_keys, bucket, key, is_backfill=is_backfill ) last_start_date = normalize_date( parse_date(digest['digestStartTime']) ) previous_bucket = digest.get('previousDigestS3Bucket', None) + previous_key = digest.get('previousDigestS3Object', None) yield digest - if previous_bucket is None: + if previous_bucket is None or previous_key is None: # The chain is broken, so find next in digest store. key, end_date = self._find_next_digest( digests=digests, @@ -492,14 +628,19 @@ def traverse(self, start_date, end_date=None): last_start_date=last_start_date, cb=self._on_gap, is_cb_conditional=True, + is_backfill=is_backfill, ) else: - key = digest['previousDigestS3Object'] + key = previous_key if previous_bucket != bucket: bucket = previous_bucket # The bucket changed so reload the digest list. 
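+ # The digest-key cache is keyed by (bucket, prefix, start_date, + # end_date), so reloading under the new bucket performs a fresh + # S3 listing instead of reusing a stale entry.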
digests = self._load_digests( - bucket, prefix, start_date, end_date + bucket, + prefix, + start_date, + end_date, + is_backfill=is_backfill, ) except ClientError as e: if e.response['Error']['Code'] != 'NoSuchKey': @@ -511,6 +652,7 @@ def traverse(self, start_date, end_date=None): last_start_date=last_start_date, cb=self._on_missing, message=str(e), + is_backfill=is_backfill, ) except DigestError as e: key, end_date = self._find_next_digest( @@ -520,6 +662,7 @@ def traverse(self, start_date, end_date=None): last_start_date=last_start_date, cb=self._on_invalid, message=str(e), + is_backfill=is_backfill, ) except Exception as e: # Any other unexpected errors. @@ -530,14 +673,18 @@ def traverse(self, start_date, end_date=None): last_start_date=last_start_date, cb=self._on_invalid, message=f'Digest file\ts3://{bucket}/{key}\tINVALID: {str(e)}', + is_backfill=is_backfill, ) - def _load_digests(self, bucket, prefix, start_date, end_date): + def _load_digests( + self, bucket, prefix, start_date, end_date, is_backfill=False + ): return self.digest_provider.load_digest_keys_in_range( bucket=bucket, prefix=prefix, start_date=start_date, end_date=end_date, + is_backfill=is_backfill, ) def _find_next_digest( @@ -549,6 +696,7 @@ def _find_next_digest( cb=None, is_cb_conditional=False, message=None, + is_backfill=False, ): """Finds the next digest in the bucket and invokes any callback.""" next_key, next_end_date = self._get_last_digest(digests, last_key) @@ -560,6 +708,7 @@ def _find_next_digest( next_end_date=next_end_date, last_start_date=last_start_date, message=message, + is_backfill=is_backfill, ) return next_key, next_end_date @@ -586,22 +735,30 @@ def _get_last_digest(self, digests, before_key=None): parse_date(extract_digest_key_date(next_key)) ) if next_key_date < before_key_date: - LOG.debug("Next found key: %s", next_key) + LOG.debug(f"Next found key: {next_key}") return next_key, next_key_date return None, None - def _load_and_validate_digest(self, public_keys, bucket, key): + def _load_and_validate_digest( + self, public_keys, bucket, key, is_backfill=False + ): """Loads and validates a digest from S3. :param public_keys: Public key dictionary of fingerprint to dict. + :param bucket: S3 bucket name + :param key: S3 key for the digest file + :param is_backfill: Flag indicating if this is a backfill digest :return: Returns a tuple of the digest data as a dict and end_date :rtype: tuple """ digest_data, digest = self.digest_provider.fetch_digest(bucket, key) + + # Validate required keys are present for required_key in self.required_digest_keys: if required_key not in digest_data: raise InvalidDigestFormat(bucket, key) - # Ensure the bucket and key are the same as what's expected. + + # Ensure the bucket and key are the same as what's expected if ( digest_data['digestS3Bucket'] != bucket or digest_data['digestS3Object'] != key @@ -610,17 +767,29 @@ def _load_and_validate_digest(self, public_keys, bucket, key): f'Digest file\ts3://{bucket}/{key}\tINVALID: has been moved from its ' 'original location' ) - # Get the public keys in the given time range. 
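+ # A backfill digest may be signed with a key outside the originally + # requested range, so an unknown fingerprint triggers a targeted + # key fetch around the backfill generation timestamp.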
+ fingerprint = digest_data['digestPublicKeyFingerprint'] + if fingerprint not in public_keys and is_backfill: + # Backfill-specific logic to fetch public keys + backfill_timestamp = normalize_date( + parse_date(digest_data['_backfill_generation_timestamp']) + ) + start_time = backfill_timestamp - timedelta(hours=1) + end_time = backfill_timestamp + timedelta(hours=1) + public_keys.update(self._load_public_keys(start_time, end_time)) + if fingerprint not in public_keys: - raise DigestError( + error_message = ( f'Digest file\ts3://{bucket}/{key}\tINVALID: public key not found in ' f'region {self.digest_provider.trail_home_region} for fingerprint {fingerprint}' ) + raise DigestError(error_message) + public_key_hex = public_keys[fingerprint]['Value'] self._digest_validator.validate( bucket, key, public_key_hex, digest_data, digest ) + end_date = normalize_date(parse_date(digest_data['digestEndTime'])) return digest_data, end_date @@ -687,14 +856,9 @@ def _create_string_to_sign(self, digest_data, inflated_digest): if previous_signature is None: # The value must be 'null' to match the Java implementation. previous_signature = 'null' - string_to_sign = "{}\n{}/{}\n{}\n{}".format( - digest_data['digestEndTime'], - digest_data['digestS3Bucket'], - digest_data['digestS3Object'], - hashlib.sha256(inflated_digest).hexdigest(), - previous_signature, - ) - LOG.debug('Digest string to sign: %s', string_to_sign) + + string_to_sign = f"{digest_data['digestEndTime']}\n{digest_data['digestS3Bucket']}/{digest_data['digestS3Object']}\n{hashlib.sha256(inflated_digest).hexdigest()}\n{previous_signature}" + LOG.debug(f'Digest string to sign: {string_to_sign}') return string_to_sign.encode() @@ -708,7 +872,8 @@ class CloudTrailValidateLogs(BasicCommand): Validates CloudTrail logs for a given period of time. This command uses the digest files delivered to your S3 bucket to perform - the validation. + the validation. It supports validation of both digest files and + backfill digest files in a single run. The AWS CLI allows you to detect the following types of changes: @@ -813,7 +978,7 @@ class CloudTrailValidateLogs(BasicCommand): ] def __init__(self, session): - super(CloudTrailValidateLogs, self).__init__(session) + super().__init__(session) self.trail_arn = None self.is_verbose = False self.start_time = None @@ -826,6 +991,8 @@ def __init__(self, session): self._source_region = None self._valid_digests = 0 self._invalid_digests = 0 + self._valid_backfill_digests = 0 + self._invalid_backfill_digests = 0 self._valid_logs = 0 self._invalid_logs = 0 self._is_last_status_double_space = True @@ -836,7 +1003,10 @@ def _run_main(self, args, parsed_globals): self.handle_args(args) self.setup_services(parsed_globals) self._call() - if self._invalid_digests > 0 or self._invalid_logs > 0: + total_invalid_digests = ( + self._invalid_digests + self._invalid_backfill_digests + ) + if total_invalid_digests > 0 or self._invalid_logs > 0: return 1 return 0 @@ -852,16 +1022,10 @@ def handle_args(self, args): else: self.end_time = normalize_date(datetime.utcnow()) if self.start_time > self.end_time: - raise ParamValidationError( + raise ValueError( 'Invalid time range specified: start-time must ' 'occur before end-time' ) - # Found start time always defaults to the given start time. This value - # may change if the earliest found digest is after the given start - # time. 
Note that the summary output report of what date ranges were - # actually found is only shown if a valid digest is encountered, - # thereby setting self._found_end_time to a value. - self._found_start_time = self.start_time def setup_services(self, parsed_globals): self._source_region = parsed_globals.region @@ -898,34 +1062,63 @@ def _call(self): account_id=self.account_id, ) self._write_startup_text() - digests = traverser.traverse(self.start_time, self.end_time) + + digests = traverser.traverse_digests( + self.start_time, self.end_time, is_backfill=False + ) for digest in digests: # Only valid digests are yielded and only valid digests can adjust # the found times that are reported in the CLI output summary. self._track_found_times(digest) + self._valid_digests += 1 + self._write_status( - 'Digest file\ts3://{}/{}\tvalid'.format( - digest['digestS3Bucket'], digest['digestS3Object'] - ) + f'Digest file\ts3://{digest["digestS3Bucket"]}/{digest["digestS3Object"]}\tvalid' + ) + + if not digest['logFiles']: + continue + for log in digest['logFiles']: + self._download_log(log) + + backfill_digests = traverser.traverse_digests( + self.start_time, self.end_time, is_backfill=True + ) + for digest in backfill_digests: + # Only valid digests are yielded and only valid digests can adjust + # the found times that are reported in the CLI output summary. + self._track_found_times(digest) + + self._valid_backfill_digests += 1 + + self._write_status( + f'(backfill) Digest file\ts3://{digest["digestS3Bucket"]}/{digest["digestS3Object"]}\tvalid' ) + if not digest['logFiles']: continue for log in digest['logFiles']: self._download_log(log) + self._write_summary_text() def _track_found_times(self, digest): # Track the earliest found start time, but do not use a date before # the user supplied start date. digest_start_time = parse_date(digest['digestStartTime']) - if digest_start_time > self.start_time: - self._found_start_time = digest_start_time - # Only use the last found end time if it is less than the - # user supplied end time (or the current date). - if not self._found_end_time: - digest_end_time = parse_date(digest['digestEndTime']) - self._found_end_time = min(digest_end_time, self.end_time) + earliest_start_time = max(digest_start_time, self.start_time) + if ( + not self._found_start_time + or earliest_start_time < self._found_start_time + ): + self._found_start_time = earliest_start_time + # Track the latest found end time from all digest types, but do not exceed + # the user supplied end time (or the current date). 
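+ # Runs for both the standard and the backfill traversal, so the + # summary range reflects whichever pass found the widest coverage.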
+ digest_end_time = parse_date(digest['digestEndTime']) + latest_end_time = min(digest_end_time, self.end_time) + if not self._found_end_time or latest_end_time > self._found_end_time: + self._found_end_time = latest_end_time def _download_log(self, log): """Download a log, decompress, and compare SHA256 checksums""" @@ -949,7 +1142,7 @@ def _download_log(self, log): else: self._valid_logs += 1 self._write_status( - f"Log file\ts3://{log['s3Bucket']}/{log['s3Object']}\tvalid" + f'Log file\ts3://{log["s3Bucket"]}/{log["s3Object"]}\tvalid' ) except ClientError as e: if e.response['Error']['Code'] != 'NoSuchKey': @@ -980,7 +1173,15 @@ def _write_summary_text(self): sys.stdout.write( f'Results requested for {format_display_date(self.start_time)} to {format_display_date(self.end_time)}\n' ) - if not self._valid_digests and not self._invalid_digests: + + total_valid_digests = ( + self._valid_digests + self._valid_backfill_digests + ) + total_invalid_digests = ( + self._invalid_digests + self._invalid_backfill_digests + ) + + if not total_valid_digests and not total_invalid_digests: sys.stdout.write('No digests found\n') return if not self._found_start_time or not self._found_end_time: @@ -989,64 +1190,69 @@ def _write_summary_text(self): sys.stdout.write( f'Results found for {format_display_date(self._found_start_time)} to {format_display_date(self._found_end_time)}:\n' ) + self._write_ratio(self._valid_digests, self._invalid_digests, 'digest') + self._write_ratio( + self._valid_backfill_digests, + self._invalid_backfill_digests, + 'backfill digest', + ) self._write_ratio(self._valid_logs, self._invalid_logs, 'log') + sys.stdout.write('\n') def _write_ratio(self, valid, invalid, name): total = valid + invalid if total > 0: - sys.stdout.write('\n%d/%d %s files valid' % (valid, total, name)) + sys.stdout.write(f'\n{valid}/{total} {name} files valid') if invalid > 0: - sys.stdout.write( - ', %d/%d %s files INVALID' % (invalid, total, name) - ) + sys.stdout.write(f', {invalid}/{total} {name} files INVALID') - def _on_missing_digest(self, bucket, last_key, **kwargs): - self._invalid_digests += 1 + def _on_missing_digest( + self, bucket, last_key, is_backfill=False, **kwargs + ): + if is_backfill: + self._invalid_backfill_digests += 1 + else: + self._invalid_digests += 1 + digest_type = '(backfill) ' if is_backfill else '' self._write_status( - f'Digest file\ts3://{bucket}/{last_key}\tINVALID: not found', + f'{digest_type}Digest file\ts3://{bucket}/{last_key}\tINVALID: not found', True, ) - def _on_digest_gap(self, **kwargs): + def _on_digest_gap(self, is_backfill=False, **kwargs): + log_type = '(backfill) ' if is_backfill else '' self._write_status( - 'No log files were delivered by CloudTrail between {} and {}'.format( - format_display_date(kwargs['next_end_date']), - format_display_date(kwargs['last_start_date']), - ), + f'{log_type}No log files were delivered by CloudTrail between {format_display_date(kwargs["next_end_date"])} and {format_display_date(kwargs["last_start_date"])}', True, ) - def _on_invalid_digest(self, message, **kwargs): - self._invalid_digests += 1 - self._write_status(message, True) + def _on_invalid_digest(self, message, is_backfill=False, **kwargs): + if is_backfill: + self._invalid_backfill_digests += 1 + else: + self._invalid_digests += 1 + digest_type = '(backfill) ' if is_backfill else '' + self._write_status(f'{digest_type}{message}', True) def _on_invalid_log_format(self, log_data): self._invalid_logs += 1 self._write_status( - ( - 'Log file\ts3://{}/{}\tINVALID: 
invalid format'.format( - log_data['s3Bucket'], log_data['s3Object'] - ) - ), + f'Log file\ts3://{log_data["s3Bucket"]}/{log_data["s3Object"]}\tINVALID: invalid format', True, ) def _on_log_invalid(self, log_data): self._invalid_logs += 1 self._write_status( - "Log file\ts3://{}/{}\tINVALID: hash value doesn't match".format( - log_data['s3Bucket'], log_data['s3Object'] - ), + f"Log file\ts3://{log_data['s3Bucket']}/{log_data['s3Object']}\tINVALID: hash value doesn't match", True, ) def _on_missing_log(self, log_data): self._invalid_logs += 1 self._write_status( - 'Log file\ts3://{}/{}\tINVALID: not found'.format( - log_data['s3Bucket'], log_data['s3Object'] - ), + f'Log file\ts3://{log_data["s3Bucket"]}/{log_data["s3Object"]}\tINVALID: not found', True, ) diff --git a/tests/functional/cloudtrail/test_validation.py b/tests/functional/cloudtrail/test_validation.py index 6ca343628939..ac9d2db2a77d 100644 --- a/tests/functional/cloudtrail/test_validation.py +++ b/tests/functional/cloudtrail/test_validation.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import gzip +import json from botocore.exceptions import ClientError from botocore.handlers import parse_get_bucket_location @@ -82,7 +83,7 @@ def mock_create( class BaseCloudTrailCommandTest(BaseAWSCommandParamsTest): def setUp(self): - super(BaseCloudTrailCommandTest, self).setUp() + super().setUp() # We need to remove this handler to ensure that we can mock out the # get_bucket_location operation. self.driver.session.unregister( @@ -121,18 +122,19 @@ def setUp(self): class TestCloudTrailCommand(BaseCloudTrailCommandTest): def setUp(self): - super(TestCloudTrailCommand, self).setUp() + super().setUp() self._traverser_patch = mock.patch(RETRIEVER_FUNCTION) self._mock_traverser = self._traverser_patch.start() def tearDown(self): - super(TestCloudTrailCommand, self).tearDown() + super().tearDown() self._traverser_patch.stop() def test_verbose_output_shows_happy_case(self): self.parsed_responses = [ {'LocationConstraint': 'us-east-1'}, {'Body': BytesIO(_gz_compress(self._logs[0]['_raw_value']))}, + {'Body': BytesIO(_gz_compress(self._logs[0]['_raw_value']))}, ] key_provider, digest_provider, validator = create_scenario( ['gap', 'link'], [[], [self._logs[0]]] @@ -141,17 +143,31 @@ def test_verbose_output_shows_happy_case(self): self._mock_traverser, key_provider, digest_provider, validator ) stdout, stderr, rc = self.run_cmd( - ( - "cloudtrail validate-logs --trail-arn %s --start-time %s " - "--region us-east-1 --verbose" - ) - % (TEST_TRAIL_ARN, START_TIME_ARG), + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG} " + "--region us-east-1 --verbose", 0, ) self.assertIn( - 'Digest file\ts3://1/%s\tvalid' % digest_provider.digests[0], + f'Digest file\ts3://1/{digest_provider.digests[0]}\tvalid', + stdout, + ) + self.assertIn('2/2 digest files valid', stdout) + self.assertIn('2/2 backfill digest files valid', stdout) + self.assertIn('2/2 log files valid', stdout) + self.assertIn( + f'Digest file\ts3://1/{digest_provider.digests[1]}\tvalid', + stdout, + ) + self.assertIn('Log file\ts3://1/key1\tvalid', stdout) + self.assertIn( + f'(backfill) Digest file\ts3://1/{digest_provider.backfill_digests[0]}\tvalid', stdout, ) + self.assertIn( + f'(backfill) Digest file\ts3://1/{digest_provider.backfill_digests[1]}\tvalid', + stdout, + ) + self.assertIn('Log file\ts3://1/key1\tvalid', stdout) def 
test_verbose_output_shows_valid_digests(self): key_provider, digest_provider, validator = create_scenario(['gap'], []) @@ -159,14 +175,19 @@ def test_verbose_output_shows_valid_digests(self): self._mock_traverser, key_provider, digest_provider, validator ) stdout, stderr, rc = self.run_cmd( - "cloudtrail validate-logs --trail-arn %s --start-time %s --verbose" - % (TEST_TRAIL_ARN, START_TIME_ARG), + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG} --verbose", 0, ) self.assertIn( - 'Digest file\ts3://1/%s\tvalid' % digest_provider.digests[0], + f'Digest file\ts3://1/{digest_provider.digests[0]}\tvalid', stdout, ) + self.assertIn( + f'(backfill) Digest file\ts3://1/{digest_provider.backfill_digests[0]}\tvalid', + stdout, + ) + self.assertIn('1/1 digest files valid', stdout) + self.assertIn('1/1 backfill digest files valid', stdout) def test_warns_when_digest_deleted(self): key_provider, digest_provider, validator = create_scenario( @@ -176,20 +197,32 @@ def test_warns_when_digest_deleted(self): self._mock_traverser, key_provider, digest_provider, validator ) stdout, stderr, rc = self.run_cmd( - "cloudtrail validate-logs --trail-arn %s --start-time %s --verbose" - % (TEST_TRAIL_ARN, START_TIME_ARG), + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG} --verbose", 1, ) self.assertIn( - 'Digest file\ts3://1/%s\tINVALID: not found' - % digest_provider.digests[1], + f'Digest file\ts3://1/{digest_provider.digests[1]}\tINVALID: not found', stderr, ) self.assertIn( - 'Digest file\ts3://1/%s\tINVALID: not found' - % digest_provider.digests[3], + f'Digest file\ts3://1/{digest_provider.digests[3]}\tINVALID: not found', stderr, ) + self.assertIn( + f'(backfill) Digest file\ts3://1/{digest_provider.backfill_digests[1]}\tINVALID: not found', + stderr, + ) + self.assertIn( + f'(backfill) Digest file\ts3://1/{digest_provider.backfill_digests[3]}\tINVALID: not found', + stderr, + ) + self.assertIn( + '2/4 digest files valid, 2/4 digest files INVALID', stdout + ) + self.assertIn( + '2/4 backfill digest files valid, 2/4 backfill digest files INVALID', + stdout, + ) def test_warns_when_no_digests_in_gap(self): key_provider, digest_provider, validator = create_scenario( @@ -199,8 +232,7 @@ def test_warns_when_no_digests_in_gap(self): self._mock_traverser, key_provider, digest_provider, validator ) stdout, stderr, rc = self.run_cmd( - "cloudtrail validate-logs --trail-arn %s --start-time '%s'" - % (TEST_TRAIL_ARN, START_TIME_ARG), + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time '{START_TIME_ARG}'", 0, ) self.assertIn( @@ -219,18 +251,21 @@ def test_warns_when_digest_invalid(self): self._mock_traverser, key_provider, digest_provider, validator ) stdout, stderr, rc = self.run_cmd( - "cloudtrail validate-logs --trail-arn %s --start-time %s" - % (TEST_TRAIL_ARN, START_TIME_ARG), + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG}", 1, ) self.assertIn('invalid error', stderr) self.assertIn( - 'Results requested for %s to ' % format_display_date(START_DATE), + f'Results requested for {format_display_date(START_DATE)} to ', stdout, ) self.assertIn( '2/3 digest files valid, 1/3 digest files INVALID', stdout ) + self.assertIn( + '2/3 backfill digest files valid, 1/3 backfill digest files INVALID', + stdout, + ) def test_shows_successful_summary(self): key_provider, digest_provider, validator = create_scenario( @@ -240,11 +275,8 @@ def test_shows_successful_summary(self): self._mock_traverser, 
key_provider, digest_provider, validator ) stdout, stderr, rc = self.run_cmd( - ( - "cloudtrail validate-logs --trail-arn %s --start-time %s " - "--end-time %s --verbose" - ) - % (TEST_TRAIL_ARN, START_TIME_ARG, END_TIME_ARG), + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG} " + f"--end-time {END_TIME_ARG} --verbose", 0, ) self.assertIn( @@ -255,8 +287,9 @@ def test_shows_successful_summary(self): stdout, ) self.assertIn('2/2 digest files valid', stdout) + self.assertIn('2/2 backfill digest files valid', stdout) self.assertIn( - 'Results found for 2014-08-10T01:00:00Z to 2014-08-10T02:30:00Z', + 'Results found for 2014-08-10T00:00:00Z to 2014-08-10T02:30:00Z', stdout, ) @@ -270,16 +303,12 @@ def test_warns_when_no_digests_after_start_date(self): self._mock_traverser, key_provider, digest_provider, validator ) stdout, stderr, rc = self.run_cmd( - ( - 'cloudtrail validate-logs --trail-arn %s --start-time %s ' - '--end-time %s' - ) - % (TEST_TRAIL_ARN, START_TIME_ARG, END_TIME_ARG), + f'cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG} ' + f'--end-time {END_TIME_ARG}', 0, ) self.assertIn( - 'Results requested for %s to %s\nNo digests found' - % (format_display_date(START_DATE), format_display_date(END_DATE)), + f'Results requested for {format_display_date(START_DATE)} to {format_display_date(END_DATE)}\nNo digests found', stdout, ) @@ -293,16 +322,12 @@ def test_warns_when_no_digests_found_in_range(self): self._mock_traverser, key_provider, digest_provider, validator ) stdout, stderr, rc = self.run_cmd( - ( - "cloudtrail validate-logs --trail-arn %s --start-time '%s' " - "--end-time '%s'" - ) - % (TEST_TRAIL_ARN, START_TIME_ARG, END_TIME_ARG), + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time '{START_TIME_ARG}' " + f"--end-time '{END_TIME_ARG}'", 0, ) self.assertIn( - 'Results requested for %s to %s\nNo digests found' - % (format_display_date(START_DATE), format_display_date(END_DATE)), + f'Results requested for {format_display_date(START_DATE)} to {format_display_date(END_DATE)}\nNo digests found', stdout, ) @@ -314,16 +339,21 @@ def test_warns_when_no_valid_digests_found_in_range(self): self._mock_traverser, key_provider, digest_provider, validator ) stdout, stderr, rc = self.run_cmd( - ( - "cloudtrail validate-logs --trail-arn %s --start-time '%s' " - "--end-time '%s'" - ) - % (TEST_TRAIL_ARN, START_TIME_ARG, END_TIME_ARG), + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time '{START_TIME_ARG}' " + f"--end-time '{END_TIME_ARG}'", 1, ) self.assertIn( - 'Results requested for %s to %s\nNo valid digests found in range' - % (format_display_date(START_DATE), format_display_date(END_DATE)), + f'Results requested for {format_display_date(START_DATE)} to {format_display_date(END_DATE)}\nNo valid digests found in range', + stdout, + ) + self.assertIn('(backfill) invalid error', stderr) + self.assertIn('invalid error', stderr) + self.assertIn( + '0/1 digest files valid, 1/1 digest files INVALID', stdout + ) + self.assertIn( + '0/1 backfill digest files valid, 1/1 backfill digest files INVALID', stdout, ) @@ -334,21 +364,21 @@ def test_fails_and_warns_when_log_hash_is_invalid(self): self.parsed_responses = [ {'LocationConstraint': ''}, {'Body': BytesIO(_gz_compress('does not match'))}, + {'Body': BytesIO(_gz_compress('does not match'))}, ] _setup_mock_traverser( self._mock_traverser, key_provider, digest_provider, validator ) stdout, stderr, rc = self.run_cmd( - ( - "cloudtrail 
validate-logs --trail-arn %s --start-time " - "--region us-east-1 '%s'" - ) - % (TEST_TRAIL_ARN, START_TIME_ARG), + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG} --region us-east-1", 1, ) self.assertIn( 'Log file\ts3://1/key1\tINVALID: hash value doesn\'t match', stderr ) + self.assertIn('1/1 digest files valid', stdout) + self.assertIn('1/1 backfill digest files valid', stdout) + self.assertIn('0/2 log files valid, 2/2 log files INVALID', stdout) def test_validates_valid_log_files(self): key_provider, digest_provider, validator = create_scenario( @@ -360,31 +390,41 @@ def test_validates_valid_log_files(self): {'Body': BytesIO(_gz_compress(self._logs[0]['_raw_value']))}, {'Body': BytesIO(_gz_compress(self._logs[1]['_raw_value']))}, {'Body': BytesIO(_gz_compress(self._logs[2]['_raw_value']))}, + {'Body': BytesIO(_gz_compress(self._logs[0]['_raw_value']))}, + {'Body': BytesIO(_gz_compress(self._logs[1]['_raw_value']))}, + {'Body': BytesIO(_gz_compress(self._logs[2]['_raw_value']))}, ] _setup_mock_traverser( self._mock_traverser, key_provider, digest_provider, validator ) stdout, stderr, rc = self.run_cmd( - "cloudtrail validate-logs --trail-arn %s --start-time %s --verbose" - % (TEST_TRAIL_ARN, START_TIME_ARG), + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG} --verbose", 0, ) self.assertIn('s3://1/key1', stdout) self.assertIn('s3://1/key2', stdout) self.assertIn('s3://1/key3', stdout) + self.assertIn('Log file\ts3://1/key1\tvalid', stdout) + self.assertIn('Log file\ts3://1/key2\tvalid', stdout) + self.assertIn('Log file\ts3://1/key3\tvalid', stdout) + self.assertIn('3/3 digest files valid', stdout) + self.assertIn('3/3 backfill digest files valid', stdout) + self.assertIn('6/6 log files valid', stdout) def test_ensures_start_time_before_end_time(self): stdout, stderr, rc = self.run_cmd( - ( - "cloudtrail validate-logs --trail-arn %s --start-time 2015-01-01 " - "--end-time 2014-01-01" - ), - 252, + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time 2015-01-01 " + "--end-time 2014-01-01", + 255, ) self.assertIn('start-time must occur before end-time', stderr) def test_fails_when_digest_not_from_same_location_as_json_contents(self): - key_name = END_TIME_ARG + '.json.gz' + key_provider, digest_provider, validator = create_scenario(['gap'], []) + + # Use the actual digest keys from the provider + key_name = digest_provider.digests[0] + backfill_key_name = digest_provider.backfill_digests[0] digest = { 'digestPublicKeyFingerprint': 'a', 'digestS3Bucket': 'not_same', @@ -393,49 +433,90 @@ def test_fails_when_digest_not_from_same_location_as_json_contents(self): 'digestStartTime': '...', 'digestEndTime': '...', } - digest_provider = mock.Mock() - digest_provider.load_digest_keys_in_range.return_value = [key_name] - digest_provider.fetch_digest.return_value = (digest, key_name) + backfill_digest = { + 'digestPublicKeyFingerprint': 'a', + 'digestS3Bucket': 'not_same', + 'digestS3Object': backfill_key_name, + 'previousDigestSignature': '...', + 'digestStartTime': '...', + 'digestEndTime': '...', + } + + def mock_fetch(bucket, key): + if '_backfill' in key: + return (backfill_digest, json.dumps(backfill_digest).encode()) + return (digest, json.dumps(digest).encode()) + + digest_provider.fetch_digest = mock_fetch + _setup_mock_traverser( self._mock_traverser, mock.Mock(), digest_provider, mock.Mock() ) stdout, stderr, rc = self.run_cmd( - "cloudtrail validate-logs --trail-arn %s --start-time %s" - % 
(TEST_TRAIL_ARN, START_TIME_ARG), + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG}", 1, ) self.assertIn( - ( - 'Digest file\ts3://1/%s\tINVALID: has been moved from its ' - 'original location' % key_name - ), + f'Digest file\ts3://1/{key_name}\tINVALID: has been moved from its ' + 'original location', stderr, ) + self.assertIn( + f'(backfill) Digest file\ts3://1/{backfill_key_name}\tINVALID: has been moved from its ' + 'original location', + stderr, + ) + self.assertIn( + '0/1 digest files valid, 1/1 digest files INVALID', stdout + ) + self.assertIn( + '0/1 backfill digest files valid, 1/1 backfill digest files INVALID', + stdout, + ) def test_fails_when_digest_is_missing_keys_before_validation(self): - digest = {} - digest_provider = mock.Mock() - key_name = END_TIME_ARG + '.json.gz' - digest_provider.load_digest_keys_in_range.return_value = [key_name] - digest_provider.fetch_digest.return_value = (digest, key_name) + key_provider, digest_provider, validator = create_scenario(['gap'], []) + + def mock_fetch(bucket, key): + return ({}, b'{}') + + digest_provider.fetch_digest = mock_fetch + _setup_mock_traverser( self._mock_traverser, mock.Mock(), digest_provider, mock.Mock() ) stdout, stderr, rc = self.run_cmd( - "cloudtrail validate-logs --trail-arn %s --start-time %s" - % (TEST_TRAIL_ARN, START_TIME_ARG), + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG}", 1, ) self.assertIn( - 'Digest file\ts3://1/%s\tINVALID: invalid format' % key_name, + f'Digest file\ts3://1/{digest_provider.digests[0]}\tINVALID: invalid format', + stderr, + ) + self.assertIn( + f'(backfill) Digest file\ts3://1/{digest_provider.backfill_digests[0]}\tINVALID: invalid format', stderr, ) + self.assertIn( + '0/1 digest files valid, 1/1 digest files INVALID', stdout + ) + self.assertIn( + '0/1 backfill digest files valid, 1/1 backfill digest files INVALID', + stdout, + ) def test_fails_when_digest_metadata_is_missing(self): key = MockDigestProvider([]).get_key_at_position(1) + backfill_key = MockDigestProvider([]).get_key_at_position( + 1, is_backfill=True + ) self.parsed_responses = [ {'LocationConstraint': ''}, - {'Contents': [{'Key': key}]}, + {'Contents': [{'Key': key}, {'Key': backfill_key}]}, + { + 'Body': BytesIO(_gz_compress(self._logs[0]['_raw_value'])), + 'Metadata': {}, + }, { 'Body': BytesIO(_gz_compress(self._logs[0]['_raw_value'])), 'Metadata': {}, @@ -453,25 +534,31 @@ def test_fails_when_digest_metadata_is_missing(self): self._mock_traverser, key_provider, digest_provider, mock.Mock() ) stdout, stderr, rc = self.run_cmd( - ( - "cloudtrail validate-logs --trail-arn %s --start-time %s " - "--region us-east-1" - ) - % (TEST_TRAIL_ARN, START_TIME_ARG), + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG} " + "--region us-east-1", 1, ) self.assertIn( - 'Digest file\ts3://1/%s\tINVALID: signature verification failed' - % key, + f'Digest file\ts3://1/{key}\tINVALID: signature verification failed', stderr, ) + self.assertIn( + f'(backfill) Digest file\ts3://1/{backfill_key}\tINVALID: signature verification failed', + stderr, + ) + self.assertIn( + '0/1 digest files valid, 1/1 digest files INVALID', stdout + ) + self.assertIn( + '0/1 backfill digest files valid, 1/1 backfill digest files INVALID', + stdout, + ) def test_follows_trails_when_bucket_changes(self): self.parsed_responses = [ {'LocationConstraint': 'us-east-1'}, {'Body': BytesIO(_gz_compress(self._logs[0]['_raw_value']))}, - 
{'LocationConstraint': 'us-west-2'}, - {'LocationConstraint': 'eu-west-1'}, + {'Body': BytesIO(_gz_compress(self._logs[0]['_raw_value']))}, ] key_provider, digest_provider, validator = create_scenario( ['gap', 'bucket_change', 'link', 'bucket_change', 'link'], @@ -481,31 +568,249 @@ def test_follows_trails_when_bucket_changes(self): self._mock_traverser, key_provider, digest_provider, validator ) stdout, stderr, rc = self.run_cmd( - ( - "cloudtrail validate-logs --trail-arn %s --start-time %s " - "--region us-east-1 --verbose" + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG} " + "--region us-east-1 --verbose", + 0, + ) + self.assertIn( + f'Digest file\ts3://3/{digest_provider.digests[0]}\tvalid', + stdout, + ) + self.assertIn( + f'Digest file\ts3://2/{digest_provider.digests[1]}\tvalid', + stdout, + ) + self.assertIn( + f'Digest file\ts3://2/{digest_provider.digests[2]}\tvalid', + stdout, + ) + self.assertIn( + f'Digest file\ts3://1/{digest_provider.digests[3]}\tvalid', + stdout, + ) + self.assertIn( + f'Digest file\ts3://1/{digest_provider.digests[4]}\tvalid', + stdout, + ) + self.assertIn( + f'(backfill) Digest file\ts3://3/{digest_provider.backfill_digests[0]}\tvalid', + stdout, + ) + self.assertIn( + f'(backfill) Digest file\ts3://2/{digest_provider.backfill_digests[1]}\tvalid', + stdout, + ) + self.assertIn( + f'(backfill) Digest file\ts3://2/{digest_provider.backfill_digests[2]}\tvalid', + stdout, + ) + self.assertIn( + f'(backfill) Digest file\ts3://1/{digest_provider.backfill_digests[3]}\tvalid', + stdout, + ) + self.assertIn( + f'(backfill) Digest file\ts3://1/{digest_provider.backfill_digests[4]}\tvalid', + stdout, + ) + self.assertIn('Log file\ts3://1/key1\tvalid', stdout) + self.assertIn('Log file\ts3://1/key1\tvalid', stdout) + self.assertIn('5/5 digest files valid', stdout) + self.assertIn('5/5 backfill digest files valid', stdout) + self.assertIn('2/2 log files valid', stdout) + + def test_validates_standard_digests_only(self): + key_provider, digest_provider, validator = create_scenario( + ['gap', 'link'], [[], [self._logs[0]]] + ) + original_load = digest_provider.load_digest_keys_in_range + + def mock_load_keys( + bucket, prefix, start_date, end_date, is_backfill=False + ): + if is_backfill: + return [] + return original_load( + bucket, prefix, start_date, end_date, is_backfill=False ) - % (TEST_TRAIL_ARN, START_TIME_ARG), + + digest_provider.load_digest_keys_in_range = mock_load_keys + + self.parsed_responses = [ + {'LocationConstraint': 'us-east-1'}, + {'Body': BytesIO(_gz_compress(self._logs[0]['_raw_value']))}, + ] + _setup_mock_traverser( + self._mock_traverser, key_provider, digest_provider, validator + ) + stdout, stderr, rc = self.run_cmd( + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG} --verbose", 0, ) + self.assertIn('2/2 digest files valid', stdout) + self.assertIn('1/1 log files valid', stdout) + self.assertNotIn('(backfill)', stdout) self.assertIn( - 'Digest file\ts3://3/%s\tvalid' % digest_provider.digests[0], + 'Results found for 2014-08-10T00:00:00Z to 2014-08-10T02:30:00Z', stdout, ) + + def test_validates_backfill_digests_only(self): + key_provider, digest_provider, validator = create_scenario( + ['gap', 'link'], [[], [self._logs[0]]] + ) + original_load = digest_provider.load_digest_keys_in_range + + def mock_load_keys( + bucket, prefix, start_date, end_date, is_backfill=False + ): + if not is_backfill: + return [] + return original_load( + bucket, prefix, start_date, 
end_date, is_backfill=True + ) + + digest_provider.load_digest_keys_in_range = mock_load_keys + + self.parsed_responses = [ + {'LocationConstraint': 'us-east-1'}, + {'Body': BytesIO(_gz_compress(self._logs[0]['_raw_value']))}, + ] + _setup_mock_traverser( + self._mock_traverser, key_provider, digest_provider, validator + ) + stdout, stderr, rc = self.run_cmd( + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG} --verbose", + 0, + ) + self.assertIn('(backfill) Digest file\ts3://1/', stdout) + self.assertIn('Log file\ts3://1/key1\tvalid', stdout) + self.assertIn('2/2 backfill digest files valid', stdout) + self.assertIn('1/1 log files valid', stdout) self.assertIn( - 'Digest file\ts3://2/%s\tvalid' % digest_provider.digests[1], + 'Results found for 2014-08-10T00:00:00Z to 2014-08-10T02:30:00Z', stdout, ) + + def _setup_mixed_period_providers( + self, standard_actions, backfill_actions, logs=None + ): + """Helper method to set up mixed period test scenarios with proper key alignment.""" + key_provider, digest_provider, validator = create_scenario( + standard_actions, logs or [] + ) + key_provider_bf, digest_provider_bf, _ = create_scenario( + backfill_actions, logs or [] + ) + + combined_keys = {} + combined_keys.update(key_provider.get_public_keys.return_value) + combined_keys.update(key_provider_bf.get_public_keys.return_value) + key_provider.get_public_keys.return_value = combined_keys + + original_load = digest_provider.load_digest_keys_in_range + original_load_bf = digest_provider_bf.load_digest_keys_in_range + original_fetch = digest_provider.fetch_digest + original_fetch_bf = digest_provider_bf.fetch_digest + + def mock_load_keys( + bucket, prefix, start_date, end_date, is_backfill=False + ): + if is_backfill: + return original_load_bf( + bucket, prefix, start_date, end_date, is_backfill=True + ) + return original_load( + bucket, prefix, start_date, end_date, is_backfill=False + ) + + def mock_fetch(bucket, key): + if '_backfill' in key: + return original_fetch_bf(bucket, key) + return original_fetch(bucket, key) + + digest_provider.load_digest_keys_in_range = mock_load_keys + digest_provider.fetch_digest = mock_fetch + + return key_provider, digest_provider, validator + + def test_validates_mixed_digests_standard_smaller_period(self): + """Test validation when standard digests have smaller time period than backfill digests.""" + key_provider, digest_provider, validator = ( + self._setup_mixed_period_providers( + ['gap'], ['gap', 'link', 'link'], [[]] * 4 + ) + ) + + _setup_mock_traverser( + self._mock_traverser, key_provider, digest_provider, validator + ) + stdout, stderr, rc = self.run_cmd( + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG} --verbose", + 0, + ) + self.assertIn('1/1 digest files valid', stdout) + self.assertIn('3/3 backfill digest files valid', stdout) self.assertIn( - 'Digest file\ts3://2/%s\tvalid' % digest_provider.digests[2], + 'Results found for 2014-08-10T00:00:00Z to 2014-08-10T03:30:00Z', stdout, ) + + def test_validates_mixed_digests_backfill_smaller_period(self): + key_provider, digest_provider, validator = ( + self._setup_mixed_period_providers( + ['gap', 'link', 'link'], ['gap'], [[], [], [], []] + ) + ) + + _setup_mock_traverser( + self._mock_traverser, key_provider, digest_provider, validator + ) + stdout, stderr, rc = self.run_cmd( + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG} --verbose", + 0, + ) + self.assertIn('3/3 digest files valid', 
stdout) + self.assertIn('1/1 backfill digest files valid', stdout) self.assertIn( - 'Digest file\ts3://1/%s\tvalid' % digest_provider.digests[3], + 'Results found for 2014-08-10T00:00:00Z to 2014-08-10T03:30:00Z', stdout, ) + + def test_fails_when_digest_public_key_not_found(self): + key_provider, digest_provider, validator = create_scenario(['gap'], []) + key_provider.get_public_keys.return_value = { + 'wrong_fp': {'Fingerprint': 'wrong_fp', 'Value': 'wrong_key'} + } + + _setup_mock_traverser( + self._mock_traverser, key_provider, digest_provider, validator + ) + stdout, stderr, rc = self.run_cmd( + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time {START_TIME_ARG}", + 1, + ) + error_lines = stderr.split('\n') + standard_errors = [ + line + for line in error_lines + if 'public key not found' in line and '(backfill)' not in line + ] + backfill_errors = [ + line + for line in error_lines + if 'public key not found' in line and '(backfill)' in line + ] + + self.assertEqual(1, len(standard_errors)) + self.assertEqual(1, len(backfill_errors)) + self.assertIn('public key not found', standard_errors[0]) + self.assertIn('public key not found', backfill_errors[0]) + + self.assertIn( + '0/1 digest files valid, 1/1 digest files INVALID', stdout + ) self.assertIn( - 'Digest file\ts3://1/%s\tvalid' % digest_provider.digests[4], + '0/1 backfill digest files valid, 1/1 backfill digest files INVALID', stdout, ) @@ -530,13 +835,19 @@ def test_fails_and_warns_when_log_is_deleted(self): validator, ) stdout, stderr, rc = self.run_cmd( - "cloudtrail validate-logs --trail-arn %s --start-time '%s'" - % (TEST_TRAIL_ARN, START_TIME_ARG), + f"cloudtrail validate-logs --trail-arn {TEST_TRAIL_ARN} --start-time '{START_TIME_ARG}'", 1, ) self.assertIn( 'Log file\ts3://1/key1\tINVALID: not found\n\n', stderr ) + self.assertIn( + 'Log file\ts3://1/key1\tINVALID: not found\n\n', + stderr, + ) + self.assertIn('1/1 digest files valid', stdout) + self.assertIn('1/1 backfill digest files valid', stdout) + self.assertIn('0/2 log files valid, 2/2 log files INVALID', stdout) def patch_make_request(self): """Override the default request patching because we need to diff --git a/tests/unit/customizations/cloudtrail/test_validation.py b/tests/unit/customizations/cloudtrail/test_validation.py index 41a8f485cb4a..60f79bc12987 100644 --- a/tests/unit/customizations/cloudtrail/test_validation.py +++ b/tests/unit/customizations/cloudtrail/test_validation.py @@ -31,16 +31,18 @@ DigestSignatureError, DigestTraverser, InvalidDigestFormat, + PublicKeyProvider, S3ClientProvider, Sha256RSADigestValidator, assert_cloudtrail_arn_is_valid, create_digest_traverser, extract_digest_key_date, format_date, + is_backfill_digest_key, normalize_date, parse_date, ) -from awscli.customizations.exceptions import ParamValidationError +from awscli.schema import ParameterRequiredError from awscli.testutils import BaseAWSCommandParamsTest, mock, unittest from tests import PublicPrivateKeyLoader @@ -49,7 +51,7 @@ START_DATE = parser.parse('20140810T000000Z') END_DATE = parser.parse('20150810T000000Z') TEST_ACCOUNT_ID = '123456789012' -TEST_TRAIL_ARN = 'arn:aws:cloudtrail:us-east-1:%s:trail/foo' % TEST_ACCOUNT_ID +TEST_TRAIL_ARN = f'arn:aws:cloudtrail:us-east-1:{TEST_ACCOUNT_ID}:trail/foo' VALID_TEST_KEY = ( 'MIIBCgKCAQEAn11L2YZ9h7onug2ILi1MWyHiMRsTQjfWE+pHVRLk1QjfW' 'hirG+lpOa8NrwQ/r7Ah5bNL6HepznOU9XTDSfmmnP97mqyc7z/upfZdS/' @@ -112,20 +114,29 @@ def __init__(self, actions, logs=None): self.actions = actions self.calls = {'fetch_digest': [], 
'load_digest_keys_in_range': []} self.digests = [] + self.backfill_digests = [] + self.trail_home_region = 'us-east-1' for i in range(len(self.actions)): self.digests.append(self.get_key_at_position(i)) + self.backfill_digests.append( + self.get_key_at_position(i, is_backfill=True) + ) - def get_key_at_position(self, position): + def get_key_at_position(self, position, is_backfill=False): + """Get digest key at position, generating backfill keys if needed.""" dt = START_DATE + timedelta(hours=position) - key = ( - 'AWSLogs/{account}/CloudTrail-Digest/us-east-1/{ymd}/{account}_' - 'CloudTrail-Digest_us-east-1_foo_us-east-1_{date}.json.gz' - ) - return key.format( - account=TEST_ACCOUNT_ID, - ymd=dt.strftime('%Y/%m/%d'), - date=dt.strftime(DATE_FORMAT), - ) + + if is_backfill: + key = ( + f'AWSLogs/{TEST_ACCOUNT_ID}/CloudTrail-Digest/us-east-1/{dt.strftime("%Y/%m/%d")}/' + f'{TEST_ACCOUNT_ID}_CloudTrail-Digest_us-east-1_foo_us-east-1_{dt.strftime(DATE_FORMAT)}_backfill.json.gz' + ) + else: + key = ( + f'AWSLogs/{TEST_ACCOUNT_ID}/CloudTrail-Digest/us-east-1/{dt.strftime("%Y/%m/%d")}/' + f'{TEST_ACCOUNT_ID}_CloudTrail-Digest_us-east-1_foo_us-east-1_{dt.strftime(DATE_FORMAT)}.json.gz' + ) + return key @staticmethod def create_digest( @@ -179,28 +190,46 @@ def create_link( next_key=next_key, logs=digest_logs, ) - # Mark the digest as invalid if specified in the action. if action == 'invalid': digest['_invalid'] = True + + digest['_signature'] = 'mock_signature' + digest['_signature_algorithm'] = 'SHA256' + + if is_backfill_digest_key(key): + digest['_backfill_generation_timestamp'] = '2025-09-01T00:00:00Z' + return digest, json.dumps(digest) - def load_digest_keys_in_range(self, bucket, prefix, start_date, end_date): + def load_digest_keys_in_range( + self, bucket, prefix, start_date, end_date, is_backfill=False + ): self.calls['load_digest_keys_in_range'].append(locals()) + if is_backfill: + return list(self.backfill_digests) return list(self.digests) def fetch_digest(self, bucket, key): self.calls['fetch_digest'].append(key) - position = self.digests.index(key) + position = ( + self.backfill_digests.index(key) + if is_backfill_digest_key(key) + else self.digests.index(key) + ) action = self.actions[position] # Simulate a digest missing from S3 if action == 'missing': raise ClientError( {'Error': {'Code': 'NoSuchKey', 'Message': 'foo'}}, 'GetObject' ) - next_key = self.get_key_at_position(position - 1) + + next_key = self.get_key_at_position( + position - 1, is_backfill_digest_key(key) + ) next_bucket = int(bucket) if action == 'bucket_change': next_bucket += 1 + return self.create_link( key, next_key, @@ -221,7 +250,7 @@ def test_parses_dates_with_better_error_message(self): try: parse_date('foo') self.fail('Should have failed to parse') - except ParamValidationError as e: + except ValueError as e: self.assertIn('Unable to parse date value: foo', str(e)) def test_parses_dates(self): @@ -232,21 +261,21 @@ def test_ensures_cloudtrail_arns_are_valid(self): try: assert_cloudtrail_arn_is_valid('foo:bar:baz') self.fail('Should have failed') - except ParamValidationError as e: + except ValueError as e: self.assertIn('Invalid trail ARN provided: foo:bar:baz', str(e)) def test_ensures_cloudtrail_arns_are_valid_when_missing_resource(self): try: assert_cloudtrail_arn_is_valid( - 'arn:aws:cloudtrail:us-east-1:%s:foo' % TEST_ACCOUNT_ID + f'arn:aws:cloudtrail:us-east-1:{TEST_ACCOUNT_ID}:foo' ) self.fail('Should have failed') - except ParamValidationError as e: + except ValueError as e: 
             self.assertIn('Invalid trail ARN provided', str(e))

     def test_allows_valid_arns(self):
         assert_cloudtrail_arn_is_valid(
-            'arn:aws:cloudtrail:us-east-1:%s:trail/foo' % TEST_ACCOUNT_ID
+            f'arn:aws:cloudtrail:us-east-1:{TEST_ACCOUNT_ID}:trail/foo'
         )

     def test_normalizes_date_timezones(self):
@@ -262,6 +291,51 @@ def test_extracts_dates_from_digest_keys(self):
         )
         self.assertEqual('20150816T230550Z', extract_digest_key_date(arn))

+    def test_is_backfill_digest_key_identifies_standard_digest(self):
+        standard_key = (
+            f'AWSLogs/{TEST_ACCOUNT_ID}/CloudTrail-Digest/us-east-1/2015/08/'
+            f'16/{TEST_ACCOUNT_ID}_CloudTrail-Digest_us-east-1_foo_us-east-1_'
+            '20150816T230550Z.json.gz'
+        )
+        self.assertFalse(is_backfill_digest_key(standard_key))
+
+    def test_is_backfill_digest_key_identifies_backfill_digest(self):
+        backfill_key = (
+            f'AWSLogs/{TEST_ACCOUNT_ID}/CloudTrail-Digest/us-east-1/2015/08/'
+            f'16/{TEST_ACCOUNT_ID}_CloudTrail-Digest_us-east-1_foo_us-east-1_'
+            '20150816T230550Z_backfill.json.gz'
+        )
+        self.assertTrue(is_backfill_digest_key(backfill_key))
+
+    def test_is_backfill_digest_key_handles_edge_cases(self):
+        # Test with different file extensions
+        self.assertFalse(is_backfill_digest_key('file.txt'))
+        self.assertFalse(is_backfill_digest_key('file_backfill.txt'))
+        self.assertFalse(is_backfill_digest_key('file.json.gz'))
+
+        # Test with backfill in middle of filename
+        self.assertFalse(is_backfill_digest_key('file_backfill_other.json.gz'))
+
+    def test_extracts_dates_from_backfill_digest_keys(self):
+        backfill_key = (
+            f'AWSLogs/{TEST_ACCOUNT_ID}/CloudTrail-Digest/us-east-1/2015/08/'
+            f'16/{TEST_ACCOUNT_ID}_CloudTrail-Digest_us-east-1_foo_us-east-1_'
+            '20150816T230550Z_backfill.json.gz'
+        )
+        self.assertEqual(
+            '20150816T230550Z', extract_digest_key_date(backfill_key)
+        )
+
+    def test_extracts_dates_from_standard_digest_keys(self):
+        standard_key = (
+            f'AWSLogs/{TEST_ACCOUNT_ID}/CloudTrail-Digest/us-east-1/2015/08/'
+            f'16/{TEST_ACCOUNT_ID}_CloudTrail-Digest_us-east-1_foo_us-east-1_'
+            '20150816T230550Z.json.gz'
+        )
+        self.assertEqual(
+            '20150816T230550Z', extract_digest_key_date(standard_key)
+        )
+
     def test_creates_traverser(self):
         mock_s3_provider = mock.Mock()
         traverser = create_digest_traverser(
@@ -420,7 +494,7 @@ def test_creates_traverser_organization_trail_missing_account_id(self):
                 "Id": TEST_ORGANIZATION_ID,
             }
         }
-        with self.assertRaises(ParamValidationError):
+        with self.assertRaises(ParameterRequiredError):
             create_digest_traverser(
                 trail_arn=TEST_TRAIL_ARN,
                 trail_source_region='us-east-1',
@@ -430,6 +504,33 @@
             )


+class TestPublicKeyProvider(unittest.TestCase):
+    def test_returns_public_key_in_range(self):
+        cloudtrail_client = mock.Mock()
+        cloudtrail_client.list_public_keys.return_value = {
+            'PublicKeyList': [
+                {'Fingerprint': 'a', 'OtherData': 'a', 'Value': 'a'},
+                {'Fingerprint': 'b', 'OtherData': 'b', 'Value': 'b'},
+                {'Fingerprint': 'c', 'OtherData': 'c', 'Value': 'c'},
+            ]
+        }
+        provider = PublicKeyProvider(cloudtrail_client)
+        start_date = START_DATE
+        end_date = start_date + timedelta(days=2)
+        keys = provider.get_public_keys(start_date, end_date)
+        self.assertEqual(
+            {
+                'a': {'Fingerprint': 'a', 'OtherData': 'a', 'Value': 'a'},
+                'b': {'Fingerprint': 'b', 'OtherData': 'b', 'Value': 'b'},
+                'c': {'Fingerprint': 'c', 'OtherData': 'c', 'Value': 'c'},
+            },
+            keys,
+        )
+        cloudtrail_client.list_public_keys.assert_has_calls(
+            [mock.call(EndTime=end_date, StartTime=start_date)]
+        )
+
+
 class TestSha256RSADigestValidator(unittest.TestCase):
     def setUp(self):
         self._digest_data = {
@@ -452,13 +553,13 @@ def test_validates_digests(self):
             get_private_key_path(), get_public_key_path()
         )
         sha256_hash = hashlib.sha256(self._inflated_digest)
-        string_to_sign = "%s\n%s/%s\n%s\n%s" % (
-            self._digest_data['digestEndTime'],
-            self._digest_data['digestS3Bucket'],
-            self._digest_data['digestS3Object'],
-            sha256_hash.hexdigest(),
-            self._digest_data['previousDigestSignature'],
-        )
+        string_to_sign = (
+            f"{self._digest_data['digestEndTime']}\n"
+            f"{self._digest_data['digestS3Bucket']}/"
+            f"{self._digest_data['digestS3Object']}\n"
+            f"{sha256_hash.hexdigest()}\n"
+            f"{self._digest_data['previousDigestSignature']}"
+        )
         to_sign = string_to_sign.encode()
         signature = private_key.sign(
             signature_algorithm=RSASignatureAlgorithm.PKCS1_5_SHA256,
@@ -573,10 +668,10 @@ def test_returns_digests_in_range(self):
                     {"Key": bad_region},  # skip (regex (source))
                     {"Key": keys[3]},
                     {"Key": keys[4]},  # hour is +1, but keep
-                    {"Key": keys[5]},
+                    {"Key": keys[5]},  # skip (date >)
                 ]
             }
-        ]  # skip (date >)
+        ]
         self.patch_make_request()
         provider = self._get_mock_provider(s3_client)
         digests = provider.load_digest_keys_in_range(
@@ -749,6 +844,133 @@ def test_fetches_digests(self):
         )
         self.assertEqual(json_str.encode(), result[1])

+    def _fake_backfill_key(self, date):
+        parsed = parser.parse(date)
+        return (
+            f'prefix/AWSLogs/{TEST_ACCOUNT_ID}/CloudTrail-Digest/us-east-1/'
+            f'{parsed.strftime("%Y/%m/%d")}/{TEST_ACCOUNT_ID}_CloudTrail-'
+            f'Digest_us-east-1_foo_us-east-1_{date}_backfill.json.gz'
+        )
+
+    def test_load_all_digest_keys_in_range_separates_standard_and_backfill(
+        self,
+    ):
+        s3_client = self.driver.session.create_client('s3')
+        standard_keys = [
+            self._fake_key(format_date(START_DATE + timedelta(days=1))),
+            self._fake_key(format_date(START_DATE + timedelta(days=2))),
+        ]
+        backfill_keys = [
+            self._fake_backfill_key(
+                format_date(START_DATE + timedelta(days=1))
+            ),
+            self._fake_backfill_key(
+                format_date(START_DATE + timedelta(days=2))
+            ),
+        ]
+
+        self.parsed_responses = [
+            {
+                "Contents": [
+                    {"Key": standard_keys[0]},
+                    {"Key": backfill_keys[0]},
+                    {"Key": standard_keys[1]},
+                    {"Key": backfill_keys[1]},
+                ]
+            }
+        ]
+        self.patch_make_request()
+        provider = self._get_mock_provider(s3_client)
+
+        standard_digests, backfill_digests = (
+            provider.load_all_digest_keys_in_range(
+                'foo', 'prefix', START_DATE, END_DATE
+            )
+        )
+
+        self.assertEqual(standard_keys, standard_digests)
+        self.assertEqual(backfill_keys, backfill_digests)
+
+    def test_load_digest_keys_in_range_uses_cache(self):
+        s3_client = mock.Mock()
+        mock_paginate = s3_client.get_paginator.return_value.paginate
+        mock_search = mock_paginate.return_value.search
+        mock_search.return_value = []
+        provider = self._get_mock_provider(s3_client)
+
+        provider.load_digest_keys_in_range(
+            'bucket', 'prefix', START_DATE, END_DATE, is_backfill=False
+        )
+        self.assertEqual(1, mock_paginate.call_count)
+
+        # Same bucket/prefix/range: served from the cache, no new S3 call.
+        provider.load_digest_keys_in_range(
+            'bucket', 'prefix', START_DATE, END_DATE, is_backfill=True
+        )
+        self.assertEqual(1, mock_paginate.call_count)
+
+        # Different date range: cache miss triggers a second S3 list call.
+        provider.load_digest_keys_in_range(
+            'bucket', 'prefix', START_DATE, START_DATE, is_backfill=False
+        )
+        self.assertEqual(2, mock_paginate.call_count)
+
+    def test_fetches_backfill_digests_with_metadata(self):
+        json_str = '{"foo":"bar"}'
+        out = BytesIO()
+        f = gzip.GzipFile(fileobj=out, mode="wb")
+        f.write(json_str.encode())
+        f.close()
+        gzipped_data = out.getvalue()
+        s3_client = mock.Mock()
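+        # The mocked response carries the backfill signature metadata that
+        # fetch_digest is expected to copy into the returned digest dict.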
+        s3_client.get_object.return_value = {
+            'Body': BytesIO(gzipped_data),
+            'Metadata': {
+                'signature': 'abc',
+                'signature-algorithm': 'SHA256',
+                'backfill-generation-timestamp': '2025-09-01T00:00:00Z',
+            },
+        }
+        provider = self._get_mock_provider(s3_client)
+        backfill_key = self._fake_backfill_key(format_date(START_DATE))
+
+        result = provider.fetch_digest('bucket', backfill_key)
+
+        self.assertEqual(
+            {
+                'foo': 'bar',
+                '_signature': 'abc',
+                '_signature_algorithm': 'SHA256',
+                '_backfill_generation_timestamp': '2025-09-01T00:00:00Z',
+            },
+            result[0],
+        )
+        self.assertEqual(json_str.encode(), result[1])
+
+    def test_ensures_backfill_digest_has_proper_metadata(self):
+        json_str = '{"foo":"bar"}'
+        out = BytesIO()
+        f = gzip.GzipFile(fileobj=out, mode="wb")
+        f.write(json_str.encode())
+        f.close()
+        gzipped_data = out.getvalue()
+        s3_client = mock.Mock()
+        s3_client.get_object.return_value = {
+            'Body': BytesIO(gzipped_data),
+            'Metadata': {
+                'signature': 'abc',
+                'signature-algorithm': 'SHA256',
+                # Missing backfill-generation-timestamp
+            },
+        }
+        provider = self._get_mock_provider(s3_client)
+        backfill_key = self._fake_backfill_key(format_date(START_DATE))
+
+        with self.assertRaises(InvalidDigestFormat):
+            provider.fetch_digest('bucket', backfill_key)


 class TestDigestTraverser(unittest.TestCase):
     def test_initializes_with_default_validator(self):
@@ -775,7 +993,7 @@ def test_ensures_public_keys_are_loaded(self):
             starting_prefix='baz',
             public_key_provider=key_provider,
         )
-        digest_iter = traverser.traverse(start_date, end_date)
+        digest_iter = traverser.traverse_digests(start_date, end_date)
         with self.assertRaises(RuntimeError):
             next(digest_iter)
         key_provider.get_public_keys.assert_called_with(start_date, end_date)
@@ -810,15 +1028,13 @@ def test_ensures_public_key_is_found(self):
             public_key_provider=key_provider,
             on_invalid=on_invalid,
         )
-        digest_iter = traverser.traverse(start_date, end_date)
+        digest_iter = traverser.traverse_digests(start_date, end_date)
         with self.assertRaises(StopIteration):
             next(digest_iter)
         self.assertEqual(1, len(calls))
         self.assertEqual(
-            (
-                'Digest file\ts3://1/%s\tINVALID: public key not '
-                'found in region %s for fingerprint abc' % (key_name, region)
-            ),
+            f'Digest file\ts3://1/{key_name}\tINVALID: public key not '
+            f'found in region {region} for fingerprint abc',
             calls[0]['message'],
         )

@@ -850,7 +1066,7 @@ def test_invokes_digest_validator(self):
             public_key_provider=key_provider,
             digest_validator=digest_validator,
         )
-        digest_iter = traverser.traverse(start_date, end_date)
+        digest_iter = traverser.traverse_digests(start_date, end_date)
         self.assertEqual(digest, next(digest_iter))
         digest_validator.validate.assert_called_with(
             '1', key_name, public_keys['a']['Value'], digest, key_name
@@ -880,11 +1096,11 @@ def test_ensures_digest_from_same_location_as_json_contents(self):
             digest_validator=digest_validator,
             on_invalid=callback,
         )
-        digest_iter = traverser.traverse(start_date, end_date)
+        digest_iter = traverser.traverse_digests(start_date, end_date)
         self.assertIsNone(next(digest_iter, None))
         self.assertEqual(1, len(collected))
         self.assertEqual(
-            'Digest file\ts3://1/%s\tINVALID: invalid format' % key_name,
+            f'Digest file\ts3://1/{key_name}\tINVALID: invalid format',
             collected[0]['message'],
         )

@@ -901,7 +1117,7 @@ def test_loads_digests_in_range(self):
             public_key_provider=key_provider,
             digest_validator=validator,
         )
-        collected = list(traverser.traverse(start_date, end_date))
+        collected = list(traverser.traverse_digests(start_date, end_date))
         self.assertEqual(1, key_provider.get_public_keys.call_count)
         self.assertEqual(
             1, len(digest_provider.calls['load_digest_keys_in_range'])
         )
@@ -924,7 +1140,7 @@ def test_invokes_cb_and_continues_when_missing(self):
             digest_validator=validator,
             on_missing=on_missing,
         )
-        collected = list(traverser.traverse(start_date, end_date))
+        collected = list(traverser.traverse_digests(start_date, end_date))
         self.assertEqual(3, len(collected))
         self.assertEqual(1, key_provider.get_public_keys.call_count)
         self.assertEqual(1, len(missing_calls))
@@ -960,7 +1176,7 @@ def test_invokes_cb_and_continues_when_invalid(self):
             digest_validator=validator,
             on_invalid=on_invalid,
         )
-        collected = list(traverser.traverse(start_date, end_date))
+        collected = list(traverser.traverse_digests(start_date, end_date))
         self.assertEqual(3, len(collected))
         self.assertEqual(1, key_provider.get_public_keys.call_count)
         self.assertEqual(2, len(invalid_calls))
@@ -1002,7 +1218,7 @@ def test_invokes_cb_and_continues_when_gap(self):
             digest_validator=validator,
             on_gap=on_gap,
         )
-        collected = list(traverser.traverse(start_date, end_date))
+        collected = list(traverser.traverse_digests(start_date, end_date))
         self.assertEqual(4, len(collected))
         self.assertEqual(1, key_provider.get_public_keys.call_count)
         self.assertEqual(2, len(gap_calls))
@@ -1037,7 +1253,7 @@ def test_reloads_objects_on_bucket_change(self):
             public_key_provider=key_provider,
             digest_validator=validator,
         )
-        collected = list(traverser.traverse(start_date, end_date))
+        collected = list(traverser.traverse_digests(start_date, end_date))
         self.assertEqual(4, len(collected))
         self.assertEqual(1, key_provider.get_public_keys.call_count)
         # Ensure the provider was called correctly
@@ -1082,13 +1298,222 @@ def test_does_not_hard_fail_on_invalid_signature(self):
             digest_validator=digest_validator,
             on_invalid=on_invalid,
         )
-        digest_iter = traverser.traverse(start_date, end_date)
+        digest_iter = traverser.traverse_digests(start_date, end_date)
         next(digest_iter, None)
         self.assertIn(
-            'Digest file\ts3://1/%s\tINVALID: ' % end_timestamp,
+            f'Digest file\ts3://1/{end_timestamp}\tINVALID: ',
             calls[0]['message'],
         )

+    def test_traverse_backfill_digests_basic(self):
+        start_date = START_DATE
+        end_date = START_DATE + timedelta(hours=4)
+        key_provider, digest_provider, validator = create_scenario(
+            ['gap', 'link', 'link', 'link']
+        )
+
+        traverser = DigestTraverser(
+            digest_provider=digest_provider,
+            starting_bucket='1',
+            starting_prefix='baz',
+            public_key_provider=key_provider,
+            digest_validator=validator,
+        )
+
+        collected = list(
+            traverser.traverse_digests(start_date, end_date, True)
+        )
+
+        self.assertEqual(4, len(collected))
+        self.assertEqual(1, key_provider.get_public_keys.call_count)
+        self.assertEqual(
+            1, len(digest_provider.calls['load_digest_keys_in_range'])
+        )
+        self.assertEqual(4, len(digest_provider.calls['fetch_digest']))
+
+    def test_traverse_backfill_digests_with_missing(self):
+        start_date = START_DATE
+        end_date = START_DATE + timedelta(hours=4)
+        key_provider, digest_provider, validator = create_scenario(
+            ['gap', 'link', 'missing', 'link']
+        )
+
+        on_missing, missing_calls = collecting_callback()
+        traverser = DigestTraverser(
+            digest_provider=digest_provider,
+            starting_bucket='1',
+            starting_prefix='baz',
+            public_key_provider=key_provider,
+            digest_validator=validator,
+            on_missing=on_missing,
+        )
+
+        collected = list(
+            traverser.traverse_digests(start_date, end_date, True)
+        )
+
+        self.assertEqual(3, len(collected))
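+        # Public keys are fetched once for the whole range, even with a gap.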
+        self.assertEqual(1, key_provider.get_public_keys.call_count)
+        self.assertEqual(1, len(missing_calls))
+        self.assertIn('bucket', missing_calls[0])
+        self.assertIn('next_end_date', missing_calls[0])
+
+    def test_traverse_backfill_digests_with_invalid(self):
+        start_date = START_DATE
+        end_date = START_DATE + timedelta(hours=5)
+        key_provider, digest_provider, validator = create_scenario(
+            ['gap', 'link', 'invalid', 'link', 'invalid']
+        )
+
+        on_invalid, invalid_calls = collecting_callback()
+        traverser = DigestTraverser(
+            digest_provider=digest_provider,
+            starting_bucket='1',
+            starting_prefix='baz',
+            public_key_provider=key_provider,
+            digest_validator=validator,
+            on_invalid=on_invalid,
+        )
+
+        collected = list(
+            traverser.traverse_digests(start_date, end_date, True)
+        )
+
+        self.assertEqual(3, len(collected))
+        self.assertEqual(1, key_provider.get_public_keys.call_count)
+        self.assertEqual(2, len(invalid_calls))
+
+    def test_traverse_backfill_digests_with_gaps(self):
+        start_date = START_DATE
+        end_date = START_DATE + timedelta(hours=4)
+        key_provider, digest_provider, validator = create_scenario(
+            ['gap', 'link', 'gap', 'gap']
+        )
+
+        on_gap, gap_calls = collecting_callback()
+        traverser = DigestTraverser(
+            digest_provider=digest_provider,
+            starting_bucket='1',
+            starting_prefix='baz',
+            public_key_provider=key_provider,
+            digest_validator=validator,
+            on_gap=on_gap,
+        )
+
+        collected = list(
+            traverser.traverse_digests(start_date, end_date, True)
+        )
+
+        self.assertEqual(4, len(collected))
+        self.assertEqual(1, key_provider.get_public_keys.call_count)
+        self.assertEqual(2, len(gap_calls))
+        for gap_call in gap_calls:
+            self.assertIn('bucket', gap_call)
+            self.assertIn('next_key', gap_call)
+
+    def test_traverse_backfill_digests_bucket_change(self):
+        start_date = START_DATE
+        end_date = START_DATE + timedelta(hours=4)
+        key_provider, digest_provider, validator = create_scenario(
+            ['gap', 'link', 'bucket_change', 'link']
+        )
+
+        traverser = DigestTraverser(
+            digest_provider=digest_provider,
+            starting_bucket='1',
+            starting_prefix='baz',
+            public_key_provider=key_provider,
+            digest_validator=validator,
+        )
+
+        collected = list(
+            traverser.traverse_digests(start_date, end_date, True)
+        )
+
+        self.assertEqual(4, len(collected))
+        self.assertEqual(1, key_provider.get_public_keys.call_count)
+        self.assertEqual(
+            2, len(digest_provider.calls['load_digest_keys_in_range'])
+        )
+        self.assertEqual(
+            ['1', '1', '2', '2'], [c['digestS3Bucket'] for c in collected]
+        )
+
+    def test_traverse_mixed_standard_and_backfill_digests(self):
+        start_date = START_DATE
+        end_date = START_DATE + timedelta(hours=3)
+        key_provider, digest_provider, validator = create_scenario(
+            ['gap', 'link', 'link']
+        )
+
+        traverser = DigestTraverser(
+            digest_provider=digest_provider,
+            starting_bucket='1',
+            starting_prefix='baz',
+            public_key_provider=key_provider,
+            digest_validator=validator,
+        )
+
+        standard_digests = list(
+            traverser.traverse_digests(start_date, end_date)
+        )
+        backfill_digests = list(
+            traverser.traverse_digests(start_date, end_date, True)
+        )
+
+        self.assertEqual(3, len(standard_digests))
+        self.assertEqual(3, len(backfill_digests))
+        self.assertEqual(2, key_provider.get_public_keys.call_count)
+        self.assertEqual(
+            2, len(digest_provider.calls['load_digest_keys_in_range'])
+        )
+        self.assertEqual(6, len(digest_provider.calls['fetch_digest']))
+
+    def test_traverse_backfill_digests_cache_miss_triggers_multiple_api_calls(
+        self,
+    ):
+        start_date = START_DATE
+        end_date = START_DATE + timedelta(hours=3)
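+        # Each lookup below returns a different key set, so the traverser
+        # cannot reuse cached keys and must call the API again per digest.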
+        key_provider, digest_provider, validator = create_scenario(
+            ['gap', 'link', 'link']
+        )
+
+        call_count = 0
+
+        def mock_get_public_keys(start_date, end_date):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                return {'2': {'Fingerprint': '2', 'Value': 'ffaa02'}}
+            elif call_count == 2:
+                return {'1': {'Fingerprint': '1', 'Value': 'ffaa01'}}
+            else:
+                return {'0': {'Fingerprint': '0', 'Value': 'ffaa00'}}
+
+        key_provider.get_public_keys.side_effect = mock_get_public_keys
+
+        traverser = DigestTraverser(
+            digest_provider=digest_provider,
+            starting_bucket='1',
+            starting_prefix='baz',
+            public_key_provider=key_provider,
+            digest_validator=validator,
+        )
+
+        collected = list(
+            traverser.traverse_digests(start_date, end_date, True)
+        )
+
+        self.assertEqual(3, len(collected))
+        self.assertEqual(3, key_provider.get_public_keys.call_count)
+        self.assertEqual(
+            1, len(digest_provider.calls['load_digest_keys_in_range'])
+        )
+        self.assertEqual(3, len(digest_provider.calls['fetch_digest']))
+

 class TestCloudTrailCommand(BaseAWSCommandParamsTest):
     def test_s3_client_created_lazily(self):
@@ -1163,7 +1585,9 @@ def test_creates_clients_for_buckets_in_us_east_1(self):
         created_client = provider.get_client('foo')
         self.assertEqual(s3_client, created_client)
         create_client_calls = session.create_client.call_args_list
-        self.assertEqual(create_client_calls, [mock.call('s3', 'us-east-1')])
+        self.assertEqual(
+            create_client_calls, [mock.call('s3', region_name='us-east-1')]
+        )
         self.assertEqual(1, s3_client.get_bucket_location.call_count)

     def test_creates_clients_for_buckets_outside_us_east_1(self):
@@ -1179,7 +1603,10 @@ def test_creates_clients_for_buckets_outside_us_east_1(self):
         create_client_calls = session.create_client.call_args_list
         self.assertEqual(
             create_client_calls,
-            [mock.call('s3', 'us-west-1'), mock.call('s3', 'us-west-2')],
+            [
+                mock.call('s3', region_name='us-west-1'),
+                mock.call('s3', region_name='us-west-2'),
+            ],
         )
         self.assertEqual(1, s3_client.get_bucket_location.call_count)

@@ -1217,4 +1644,4 @@ def test_removes_cli_error_events(self):
         session.create_client.return_value = s3_client
         s3_client.get_bucket_location.return_value = {'LocationConstraint': ''}
         provider = S3ClientProvider(session)
-        client = provider.get_client('foo')
+        provider.get_client('foo')
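Editor's note (appended, not part of the patch): the fixed-width slices that
extract_digest_key_date relies on, and that the key-format tests above
exercise, are easy to misread. Below is a minimal standalone Python sketch of
that naming convention; the helper bodies mirror the patch, and the sample key
is illustrative only.

# The timestamp portion of a digest key is 16 characters ('%Y%m%dT%H%M%SZ').
# '.json.gz' is 8 characters and '_backfill.json.gz' is 17, which is where
# the [-24:-8] and [-33:-17] slices come from.

def is_backfill_digest_key(digest_key):
    # Backfill digests differ from standard digests only by this suffix.
    return digest_key.endswith('_backfill.json.gz')


def extract_digest_key_date(digest_s3_key):
    # Strip the fixed-width suffix to recover the 16-character timestamp.
    if is_backfill_digest_key(digest_s3_key):
        return digest_s3_key[-33:-17]
    return digest_s3_key[-24:-8]


standard = (
    'AWSLogs/123456789012/CloudTrail-Digest/us-east-1/2015/08/16/'
    '123456789012_CloudTrail-Digest_us-east-1_foo_us-east-1_'
    '20150816T230550Z.json.gz'
)
backfill = standard[:-len('.json.gz')] + '_backfill.json.gz'

assert not is_backfill_digest_key(standard)
assert is_backfill_digest_key(backfill)
assert extract_digest_key_date(standard) == '20150816T230550Z'
assert extract_digest_key_date(backfill) == '20150816T230550Z'

The assertions hold because the timestamp sits immediately before the suffix
in either naming scheme; any change to DATE_FORMAT or to the suffix would
require updating both slice offsets together.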