Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions cmd/newtowner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,18 @@ const (
)

var (
providerFlag *string
urlsFilePathFlag *string
awsRegionFlag *string
brightdataCountryFlag *string
updateDBFlag *bool
awsAllRegionsFlag *bool
sshHostFlag *string
sshPortFlag *int
sshUserFlag *string
sshKeyPathFlag *string
sshPassphraseFlag *string
providerFlag *string
urlsFilePathFlag *string
awsRegionFlag *string
brightdataCountryFlag *string
updateDBFlag *bool
awsAllRegionsFlag *bool
sshHostFlag *string
sshPortFlag *int
sshUserFlag *string
sshKeyPathFlag *string
sshPassphraseFlag *string
legacyRenegotiationFlag *bool
// EC2 flags
ec2HostFlag *string
ec2PortFlag *int
Expand All @@ -68,6 +69,7 @@ func init() {
sshUserFlag = pflag.String("ssh-user", "", "SSH user for the SSH provider")
sshKeyPathFlag = pflag.String("ssh-key-path", "", "Path to the SSH private key for the SSH provider")
sshPassphraseFlag = pflag.String("ssh-passphrase", "", "Passphrase for the SSH private key, if protected")
legacyRenegotiationFlag = pflag.Bool("legacy-renegotiation", false, "Allow unsafe legacy TLS renegotiation in the remote Python runner (use only when required)")
// EC2 flags
ec2HostFlag = pflag.String("ec2-host", "", "EC2 host for the EC2 provider")
ec2PortFlag = pflag.Int("ec2-port", 22, "EC2 port for the EC2 provider")
Expand Down Expand Up @@ -119,6 +121,7 @@ func init() {
viper.BindPFlag("ssh_user", pflag.Lookup("ssh-user"))
viper.BindPFlag("ssh_private_key_path", pflag.Lookup("ssh-key-path"))
viper.BindPFlag("ssh_passphrase", pflag.Lookup("ssh-passphrase"))
viper.BindPFlag("legacy_renegotiation", pflag.Lookup("legacy-renegotiation"))

// EC2 Provider Settings
viper.BindPFlag("ec2_host", pflag.Lookup("ec2-host"))
Expand Down Expand Up @@ -246,6 +249,7 @@ func main() {
log.Printf(" EC2 User: %s", cfg.EC2User)
log.Printf(" EC2 Key Path: %s", cfg.EC2PrivateKeyPath)
log.Printf(" EC2 Key Passphrase Provided: %t", cfg.EC2Passphrase != "")
log.Printf(" Legacy TLS renegotiation (remote runner): %t", cfg.LegacyRenegotiation)

selectedProvider := strings.ToLower(viper.GetString("provider"))
urlsFilePath := viper.GetString("urlsfile")
Expand Down Expand Up @@ -377,7 +381,7 @@ func main() {

case "ssh":
log.Println("Initializing SSH provider...")
sshProvider, err := ssh.NewProvider(ctx, cfg.SshHost, cfg.SshPort, cfg.SshUser, cfg.SshPrivateKeyPath, cfg.SshPassphrase)
sshProvider, err := ssh.NewProvider(ctx, cfg.SshHost, cfg.SshPort, cfg.SshUser, cfg.SshPrivateKeyPath, cfg.SshPassphrase, cfg.LegacyRenegotiation)
if err != nil {
log.Fatalf("Error initializing SSH provider: %v", err)
}
Expand Down Expand Up @@ -406,16 +410,17 @@ func main() {
log.Println("Initializing EC2 dual SSH provider...")

ec2Config := aws.EC2ProviderConfig{
EC2Host: cfg.EC2Host,
EC2Port: cfg.EC2Port,
EC2User: cfg.EC2User,
EC2KeyPath: cfg.EC2PrivateKeyPath,
EC2Passphrase: cfg.EC2Passphrase,
SshHost: cfg.SshHost,
SshPort: cfg.SshPort,
SshUser: cfg.SshUser,
SshKeyPath: cfg.SshPrivateKeyPath,
SshPassphrase: cfg.SshPassphrase,
EC2Host: cfg.EC2Host,
EC2Port: cfg.EC2Port,
EC2User: cfg.EC2User,
EC2KeyPath: cfg.EC2PrivateKeyPath,
EC2Passphrase: cfg.EC2Passphrase,
SshHost: cfg.SshHost,
SshPort: cfg.SshPort,
SshUser: cfg.SshUser,
SshKeyPath: cfg.SshPrivateKeyPath,
SshPassphrase: cfg.SshPassphrase,
AllowLegacyRenegotiation: cfg.LegacyRenegotiation,
}

ec2Provider, err := aws.NewEC2Provider(ctx, ec2Config)
Expand Down
3 changes: 3 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,7 @@ type Config struct {
EC2User string `mapstructure:"ec2_user"` // EC2 instance SSH username
EC2PrivateKeyPath string `mapstructure:"ec2_private_key_path"` // Path to the EC2 instance SSH private key
EC2Passphrase string `mapstructure:"ec2_passphrase"` // Passphrase for the EC2 instance SSH private key

// SSL / TLS toggles
LegacyRenegotiation bool `mapstructure:"legacy_renegotiation"` // Allow unsafe legacy TLS renegotiation in remote runner (off by default)
}
19 changes: 13 additions & 6 deletions internal/providers/aws/ec2.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ type EC2ProviderConfig struct {
SshUser string
SshKeyPath string
SshPassphrase string

AllowLegacyRenegotiation bool
}

type EC2URLCheckResult struct {
Expand Down Expand Up @@ -144,6 +146,7 @@ type EC2Provider struct {
sshRemotePythonRunnerPath string
sshGeoLocation string
comparisonRegion string
allowLegacyRenegotiation bool
}

// NewEC2Provider creates a new EC2 dual SSH provider instance.
Expand All @@ -162,12 +165,13 @@ func NewEC2Provider(ctx context.Context, config EC2ProviderConfig) (*EC2Provider
}

p := &EC2Provider{
ec2Host: config.EC2Host,
ec2Port: config.EC2Port,
ec2User: config.EC2User,
sshHost: config.SshHost,
sshPort: config.SshPort,
sshUser: config.SshUser,
ec2Host: config.EC2Host,
ec2Port: config.EC2Port,
ec2User: config.EC2User,
sshHost: config.SshHost,
sshPort: config.SshPort,
sshUser: config.SshUser,
allowLegacyRenegotiation: config.AllowLegacyRenegotiation,
}

if p.ec2Port <= 0 {
Expand Down Expand Up @@ -473,6 +477,9 @@ func (p *EC2Provider) executeRemoteCheck(ctx context.Context, targetURL string,

quotedTargetURL := strconv.Quote(targetURL)
remoteCommand := fmt.Sprintf("python3 %s %s --output_file %s", runnerPath, quotedTargetURL, remoteTempJSONPath)
if p.allowLegacyRenegotiation {
remoteCommand += " --legacy-renegotiation"
}
log.Printf("[EC2Provider] Executing %s remote command: %s", logPrefix, remoteCommand)

session, err := sshClient.NewSession()
Expand Down
27 changes: 16 additions & 11 deletions internal/providers/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,18 @@ func (r URLCheckResult) ShouldSkipBodyDiff() bool {

// Provider struct holds SSH configuration.
type Provider struct {
sshHost string
sshPort int
sshUser string
sshPrivateKeyPath string
remotePythonRunnerPath string
sshClient *ssh.Client
sshPassphrase string
sshHost string
sshPort int
sshUser string
sshPrivateKeyPath string
remotePythonRunnerPath string
sshClient *ssh.Client
sshPassphrase string
allowLegacyRenegotiation bool
}

// NewProvider creates a new SSH Comparison provider instance.
func NewProvider(ctx context.Context, host string, port int, user string, keyPath string, passphrase string) (*Provider, error) {
func NewProvider(ctx context.Context, host string, port int, user string, keyPath string, passphrase string, allowLegacyRenegotiation bool) (*Provider, error) {
if host == "" {
return nil, fmt.Errorf("SSH host must be provided")
}
Expand All @@ -113,9 +114,10 @@ func NewProvider(ctx context.Context, host string, port int, user string, keyPat
}

p := &Provider{
sshHost: host,
sshPort: port,
sshUser: user,
sshHost: host,
sshPort: port,
sshUser: user,
allowLegacyRenegotiation: allowLegacyRenegotiation,
}

if p.sshPort <= 0 {
Expand Down Expand Up @@ -333,6 +335,9 @@ func (p *Provider) executeRemoteSshCheck(ctx context.Context, targetURL string)

quotedTargetURL := strconv.Quote(targetURL)
remoteCommand := fmt.Sprintf("python3 %s %s --output_file %s", p.remotePythonRunnerPath, quotedTargetURL, remoteTempJSONPath)
if p.allowLegacyRenegotiation {
remoteCommand += " --legacy-renegotiation"
}
log.Printf("[SSH Provider] Executing remote command: %s", remoteCommand)

session, err := sshClient.NewSession()
Expand Down
68 changes: 56 additions & 12 deletions scripts/http_check_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,58 @@
import sys
import ssl
import urllib.parse
import urllib.error
import socket
from cryptography import x509
from cryptography.hazmat.backends import default_backend
import base64 # Ensure base64 is imported at the top level
from requests.adapters import HTTPAdapter
from urllib3.util.ssl_ import create_urllib3_context

def make_request(target_url):

def create_ssl_context(allow_legacy: bool) -> ssl.SSLContext:
"""Create an SSL context with verification disabled; optionally allow legacy renegotiation."""
context = create_urllib3_context()
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE

if allow_legacy:
context.options |= getattr(ssl, "OP_LEGACY_SERVER_CONNECT", 0)
# Always set the legacy renegotiation bit for older OpenSSL builds.
# https://github.com/urllib3/urllib3/issues/2653
context.options |= 0x4
no_reno_flag = getattr(ssl, "OP_NO_RENEGOTIATION", 0)
if no_reno_flag:
context.options &= ~no_reno_flag

return context


class CustomHttpAdapter(HTTPAdapter):
"""Transport adapter that injects our SSL context into urllib3."""

def __init__(self, ssl_context: ssl.SSLContext, *args, **kwargs):
self.ssl_context = ssl_context
super().__init__(*args, **kwargs)

def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
pool_kwargs["ssl_context"] = self.ssl_context
return super().init_poolmanager(connections, maxsize, block=block, **pool_kwargs)

def proxy_manager_for(self, proxy, **proxy_kwargs):
proxy_kwargs["ssl_context"] = self.ssl_context
return super().proxy_manager_for(proxy, **proxy_kwargs)


def create_session(allow_legacy: bool):
"""Build a requests session using the supplied SSL context."""
ssl_context = create_ssl_context(allow_legacy)
adapter = CustomHttpAdapter(ssl_context)
session = requests.Session()
session.mount("https://", adapter)
session.mount("http://", adapter)
session.verify = False
return session, ssl_context

def make_request(target_url, allow_legacy_renegotiation=False):
"""Makes an HTTP GET request and returns details as a dictionary."""
details = {
"url": target_url,
Expand All @@ -26,6 +72,8 @@ def make_request(target_url):
"ssl_certificate_error": None,
}

session, ssl_context = create_session(allow_legacy_renegotiation)

# Fetch SSL certificate
parsed_url = None
try:
Expand All @@ -38,12 +86,7 @@ def make_request(target_url):
conn = None
try:
sock = socket.create_connection((hostname, port), timeout=10)
# Create a context that does not verify certificates
context = ssl.create_default_context()
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE

conn = context.wrap_socket(sock, server_hostname=hostname)
conn = ssl_context.wrap_socket(sock, server_hostname=hostname)

der_cert_bin = conn.getpeercert(True) # Get DER-encoded certificate

Expand Down Expand Up @@ -91,7 +134,7 @@ def make_request(target_url):
start_time = time.perf_counter()

try:
response = requests.get(target_url, timeout=30, allow_redirects=True, verify=False)
response = session.get(target_url, timeout=30, allow_redirects=True)
details["status_code"] = response.status_code

# Ensure headers are map[string][]string for Go unmarshalling
Expand Down Expand Up @@ -134,6 +177,7 @@ def make_request(target_url):
parser = argparse.ArgumentParser(description="Make an HTTP request and save results to JSON.")
parser.add_argument("target_url", help="The URL to make the request to. Can be a single URL or comma-separated URLs.")
parser.add_argument("--output_file", default="result.json", help="Path to save the JSON output.")
parser.add_argument("--legacy-renegotiation", action="store_true", help="Allow unsafe legacy TLS renegotiation (use only when necessary).")
args = parser.parse_args()

urls_to_check = [url.strip() for url in args.target_url.split(',') if url.strip()]
Expand All @@ -155,7 +199,7 @@ def make_request(target_url):

for url_to_check_single in urls_to_check:
print(f"Making request to: {url_to_check_single}", file=sys.stderr)
result_data = make_request(url_to_check_single)
result_data = make_request(url_to_check_single, allow_legacy_renegotiation=args.legacy_renegotiation)
all_results.append(result_data)

try:
Expand All @@ -165,4 +209,4 @@ def make_request(target_url):
except IOError as e:
print(f"Error writing to output file {args.output_file}: {e}")
print("Results JSON (stdout fallback):")
print(json.dumps(all_results, indent=4))
print(json.dumps(all_results, indent=4))