Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions pre_commit_hooks/detect_aws_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import argparse
import configparser
import json
import os
from collections.abc import Sequence
from typing import NamedTuple
Expand All @@ -28,13 +29,44 @@ def get_aws_secrets_from_env() -> set[str]:
"""Extract AWS secrets from environment variables."""
keys = set()
for env_var in (
'AWS_SECRET_ACCESS_KEY', 'AWS_SECURITY_TOKEN', 'AWS_SESSION_TOKEN',
'AWS_SECRET_ACCESS_KEY', 'AWS_SECURITY_TOKEN', 'AWS_SESSION_TOKEN',
):
if os.environ.get(env_var):
keys.add(os.environ[env_var])
return keys


def get_aws_secrets_from_json(json_credentials_file: str) -> set[str]:
"""Extract AWS secrets from JSON configuration files.

Read a JSON-style configuration file and return a set with all found AWS
secret access keys.
"""
aws_credentials_file_path = os.path.expanduser(json_credentials_file)

with open(aws_credentials_file_path, encoding='utf-8') as f:
try:
data = json.load(f)
except json.JSONDecodeError:
return set()

keys = set()
for var in (
'AccessKeyId',
'SecretAccessKey',
'SessionToken',
'accessKeyId',
'secretAccessKey',
'sessionToken',
):
if var in data.get('Credentials', {}):
keys.add(data['Credentials'][var])

if var in data.get('accessToken', {}):
keys.add(data['accessToken'][var])
return keys


def get_aws_secrets_from_file(credentials_file: str) -> set[str]:
"""Extract AWS secrets from configuration files.

Expand All @@ -54,8 +86,8 @@ def get_aws_secrets_from_file(credentials_file: str) -> set[str]:
keys = set()
for section in parser.sections():
for var in (
'aws_secret_access_key', 'aws_security_token',
'aws_session_token',
'aws_secret_access_key', 'aws_security_token',
'aws_session_token',
):
try:
key = parser.get(section, var).strip()
Expand Down Expand Up @@ -104,6 +136,16 @@ def main(argv: Sequence[str] | None = None) -> int:
'secret keys. Can be passed multiple times.'
),
)
parser.add_argument(
'--json-credentials-dir',
dest='json_credential_dir',
action='append',
default=['~/.aws/cli/cache/', '~/.aws/login/cache/'],
help=(
'Location of AWS JSON credential files from which to get '
'secret keys. Can be passed multiple times.'
),
)
parser.add_argument(
'--allow-missing-credentials',
dest='allow_missing_credentials',
Expand All @@ -113,6 +155,16 @@ def main(argv: Sequence[str] | None = None) -> int:
args = parser.parse_args(argv)

credential_files = set(args.credentials_file)
json_credential_dirs = set(args.json_credential_dir)
json_credential_files = set()
for json_credential_dir in json_credential_dirs:
if os.path.isdir(os.path.expanduser(json_credential_dir)):
for file in os.listdir(os.path.expanduser(json_credential_dir)):
if file.endswith('.json'):
(
json_credential_files
.add(os.path.join(json_credential_dir, file))
)

# Add the credentials files configured via environment variables to the set
# of files to to gather AWS secrets from.
Expand All @@ -122,6 +174,8 @@ def main(argv: Sequence[str] | None = None) -> int:
for credential_file in credential_files:
keys |= get_aws_secrets_from_file(credential_file)

for json_credential_file in json_credential_files:
keys |= get_aws_secrets_from_json(json_credential_file)
# Secrets might be part of environment variables, so add such secrets to
# the set of keys.
keys |= get_aws_secrets_from_env()
Expand Down
7 changes: 7 additions & 0 deletions testing/resources/aws_temp_credentials_secrets_file.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"Credentials": {
"AccessKeyId": "tempAccessKeyId",
"SecretAccessKey": "tempSecretAccessKey",
"SessionToken": "tempSessionToken"
}
}
7 changes: 7 additions & 0 deletions testing/resources/aws_temp_secrets_file.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"accessToken": {
"accessKeyId": "tempAccessKeyId",
"secretAccessKey": "tempSecretAccessKey",
"sessionToken": "tempSessionToken"
}
}
48 changes: 48 additions & 0 deletions tests/detect_aws_credentials_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pre_commit_hooks.detect_aws_credentials import get_aws_cred_files_from_env
from pre_commit_hooks.detect_aws_credentials import get_aws_secrets_from_env
from pre_commit_hooks.detect_aws_credentials import get_aws_secrets_from_file
from pre_commit_hooks.detect_aws_credentials import get_aws_secrets_from_json
from pre_commit_hooks.detect_aws_credentials import main
from testing.util import get_resource_path

Expand Down Expand Up @@ -68,6 +69,35 @@ def test_get_aws_secrets_from_env(env_vars, values):
assert get_aws_secrets_from_env() == values


@pytest.mark.parametrize(
('filename', 'expected_keys'),
(
(
'aws_temp_secrets_file.json',
{
'tempAccessKeyId',
'tempSecretAccessKey',
'tempSessionToken',
},
),
(
'aws_temp_credentials_secrets_file.json',
{
'tempAccessKeyId',
'tempSecretAccessKey',
'tempSessionToken',
},
),
('nonsense.txt', set()),
('ok_json.json', set()),
),
)
def test_get_aws_secrets_from_json(filename, expected_keys):
"""Test that reading secrets from files works."""
keys = get_aws_secrets_from_json(get_resource_path(filename))
assert keys == expected_keys


@pytest.mark.parametrize(
('filename', 'expected_keys'),
(
Expand Down Expand Up @@ -119,6 +149,8 @@ def test_detect_aws_credentials(filename, expected_retval):
get_resource_path(filename),
'--credentials-file',
'testing/resources/aws_config_with_multiple_sections.ini',
'--json-credentials-dir',
'testing/resources/',
))
assert ret == expected_retval

Expand All @@ -145,6 +177,7 @@ def test_non_existent_credentials(mock_secrets_env, mock_secrets_file, capsys):
ret = main((
get_resource_path('aws_config_without_secrets.ini'),
'--credentials-file=testing/resources/credentailsfilethatdoesntexist',
'--json-credentials-dir=directorythatdoesnotexist',
))
assert ret == 2
out, _ = capsys.readouterr()
Expand All @@ -168,3 +201,18 @@ def test_non_existent_credentials_with_allow_flag(
'--allow-missing-credentials',
))
assert ret == 0


@patch('pre_commit_hooks.detect_aws_credentials.get_aws_secrets_from_file')
@patch('pre_commit_hooks.detect_aws_credentials.get_aws_secrets_from_env')
def test_non_existent_json_credentials_with_allow_flag(
mock_secrets_env, mock_secrets_file,
):
mock_secrets_env.return_value = set()
mock_secrets_file.return_value = set()
ret = main((
get_resource_path('aws_config_without_secrets.ini'),
'--json-credentials-dir=credentailsfilethatdoesntexist',
'--allow-missing-credentials',
))
assert ret == 0
Loading