diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 6d128a0..1430942 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -51,11 +51,19 @@ storage-e2e/ │ │ └── vm_block_device.go │ │ │ ├── infrastructure/ # Infrastructure layer -│ │ └── ssh/ # SSH operations +│ │ └── ssh/ # SSH operations (legacy) │ │ ├── client.go │ │ ├── interface.go │ │ ├── tunnel.go -│ │ └── types.go +│ │ ├── types.go +│ │ └── v2/ # Self-healing SSH client (Dialer/Route + Tunnel) +│ │ ├── client.go # New, Client, Close + package docs +│ │ ├── conn.go # connection core: snapshot/refresh/keepalive + withConn +│ │ ├── dialer.go # Dialer interface, Route, chain closer +│ │ ├── endpoint.go # Endpoint, auth, host/key resolution +│ │ ├── errors.go # transient classification +│ │ ├── options.go # functional options +│ │ └── tunnel.go # Tunnel, accept loop │ │ │ └── logger/ # Structured logging │ ├── logger.go # Logger implementation @@ -445,10 +453,18 @@ internal/kubernetes/ # Internal Kubernetes clients ``` infrastructure/ssh/ -├── client.go # SSH client implementation (Exec, ExecCapture, tunnels) -├── interface.go # SSH client interface -├── tunnel.go # Port forwarding and tunneling -└── types.go # SSH-related types +├── client.go # SSH client implementation (Exec, ExecCapture, tunnels) [legacy] +├── interface.go # SSH client interface [legacy] +├── tunnel.go # Port forwarding and tunneling [legacy] +├── types.go # SSH-related types [legacy] +└── v2/ # Self-healing SSH client (see below) + ├── client.go # New, Client, Close + package docs + ├── conn.go # connection core: snapshot/refresh/keepalive + withConn executor + ├── dialer.go # Dialer interface, Route, chain closer + ├── endpoint.go # Endpoint, auth, host/key resolution + ├── errors.go # transient classification + ├── options.go # functional options + └── tunnel.go # Tunnel, accept loop ``` **Responsibilities**: @@ -466,6 +482,48 @@ infrastructure/ssh/ - `ExecCapture` keeps stdout and stderr separate while preserving retry/reconnect behavior - Proper resource cleanup +#### 3.4.1 Self-healing SSH client (`internal/infrastructure/ssh/v2/`) + +A ground-up rewrite that lives in parallel with the legacy package (no consumers +migrated yet). It separates **how we connect** (directly or via jump hosts) from +**what we do over the connection** (currently only tunneling), and hides every +reconnect from callers. + +**Design**: + +- `Dialer` is the injection point: `Dial(ctx) (*ssh.Client, io.Closer, error)` + + `Describe()`. `Route(first Endpoint, more ...Endpoint)` builds the built-in + implementation; the last hop is always the target, so the `(first, more...)` + signature guarantees at least one hop at compile time. The returned `io.Closer` + tears down the whole chain (target + every jump + ssh-agent connections). +- `Endpoint` describes a single host: `User`, `Addr` (`host` or `host:port`, + default `:22`), `KeyPath` (`~` expanded), optional `Passphrase` + (falls back to `SSH_PASSPHRASE` then ssh-agent), optional per-hop `HostKey`. +- The unexported `conn` core owns the current `*ssh.Client`, its chain `Closer`, + and a generation counter under a mutex. `snapshot` reads them; `refresh` + re-dials via `singleflight` keyed on the failed generation so concurrent + reconnects collapse into one and a stale generation never tears down a freshly + healed link. The slow `Dial` runs outside the lock on a detached context + (`context.WithoutCancel` + timeout) so one caller's cancellation can't abort + the shared flight. +- A single generic executor `withConn[T]` runs an operation against the live + client and heals on transient failures (bounded by `WithRetries`); the tunnel + uses it today and `Run`/`Upload` are designed to reuse it unchanged. +- Optional keepalive (`WithKeepalive`) probes the link and heals through the same + `refresh` path; every heal is logged at WARN. + +**Public API v1**: `New(ctx, Dialer, ...Option)`, `Client.Tunnel(ctx, remotePort)` +(self-healing local forward on a free `127.0.0.1` port; `Tunnel.LocalAddr`, +`Tunnel.Close`), `Client.Close`. Options: `WithKeepalive`, `WithRetries`, +`WithLogger`, `WithHostKeyCallback`, `WithInsecureIgnoreHostKey` (host key +defaults to `InsecureIgnoreHostKey` — a conscious default for ephemeral e2e VMs). + +**Extension points (designed, not yet implemented)**: `Run` (transparent retry +only when the session fails to open; mid-flight drops heal but surface the error +to avoid double side effects; opt-in `Idempotent` for true retry) and `Upload`. +Transient-error classification uses `errors.Is`/`errors.As` against standard +types — never error-string matching. + ### 3.5 Logger Module (`internal/logger/`) ``` diff --git a/docs/WORKLOG.md b/docs/WORKLOG.md index c019ab6..c217718 100644 --- a/docs/WORKLOG.md +++ b/docs/WORKLOG.md @@ -4,12 +4,6 @@ All notable changes to this repository are documented here. New entries are appe --- -## 2026-06-07 - -- **Update** `.github/workflows/unit-tests.yml`: integrate GitHub native code coverage (per-push) — add `code-quality: write` + `pull-requests: read` permissions, convert `coverage.out` to Cobertura XML via `boumenot/gocover-cobertura`, and publish with `actions/upload-code-coverage@v1`; coverage artifact now also includes `coverage.xml` - ---- - ## 2026-05-06 - **Add** `UploadPrivate` on `ssh.SSHClient` (`internal/infrastructure/ssh`): SFTP `Chmod` immediately after `Create`, before payload copy; `uploadOverSFTPOnce`, `uploadWithSFTPRetries`, `jumpUploadWithSFTPRetries`; passphrase `BootstrapCluster` uses it with `install -d -m 0700` staging (`pkg/cluster/setup.go`); ARCHITECTURE mentions ssh uploads @@ -111,54 +105,6 @@ All notable changes to this repository are documented here. New entries are appe --- -## 2026-06-03 - -- **Add** `.github/workflows/unit-tests.yml`: mandatory CI workflow that builds, vets and runs unit tests on every push (any branch) and on PRs to `main`; uses `go-version-file: go.mod`, `-race -shuffle=on -covermode=atomic`, uploads `coverage.out` artifact, scoped to `./internal/... ./pkg/...` so e2e suites stay off CI. -- **Add** `Makefile`: `test` / `cover` / `vet` / `build` / `e2e` / `clean` targets mirroring the CI commands; `.gitignore` for `coverage.out` / `coverage.html`. -- **Add** Wave 1 unit tests (`pkg/retry/retry_test.go`, `pkg/kubernetes/{apply,modules,poll}_test.go`, `pkg/cluster/vms_test.go`, `pkg/testkit/stress_tests_test.go`, `internal/config/types_yaml_test.go`, `internal/kubernetes/commander/client_test.go`, `internal/logger/level_test.go`): hermetic table-driven coverage of `retry.Do/IsRetryable/IsSSHConnectionError/WithRetryAfter`, YAML doc splitting/env-var scanning, module graph + topo sort + cycle detection, `cluster/vms` pure helpers, `commander` mappers / base64 / `NewClientWithOptions` validation, `stress-tests.Config.Validate` / `DefaultConfig`, `LevelToString` round-trip, `ClusterNode`/`ClusterDefinition` YAML unmarshal validation. -- **Add** Wave 2 httptest tests (`internal/kubernetes/commander/client_http_test.go`): drives the Commander HTTP client (`GetClusterByID`, `ListClustersAPI` array/items/data/garbage, `GetClusterByName`, `CreateClusterFromTemplate`, `DeleteClusterByID`, `GetClusterKubeconfigByID` + cluster-details fallback, `GetRegistryByName`, `GetClusterConnectionInfo` precedence + defaults) and all five `setAuthHeaders` paths via a real `httptest.Server`. -- **Update** `docs/TESTS_IMPLEMENTATION_PLAN.md`: triggers changed from `push → main` to push-on-any-branch + `pull_request → main`; status header refreshed; rollout phases marked Done/Pending; exact `gh api` branch-protection command documented. -- **Update** `.github/workflows/gitleaks-scan-on-pr.yml` → renamed to `.github/workflows/gitleaks.yml`: workflow `name` shortened to `Gitleaks`, added `push: {}` trigger so secret scanning runs on every push (any branch), not only on PRs; added cancel-in-progress concurrency group. -- **Update** `.github/workflows/gitleaks.yml`: split into two jobs gated by `github.event_name` — `gitleaks_diff` (`scan_mode: diff`) for `pull_request`, `gitleaks_full` (`scan_mode: full`) for `push`; fixes `fatal: invalid refspec '+refs/pull//merge:...'` that broke push runs because the upstream action's diff mode needs `github.event.number`. Both jobs share check name `Gitleaks scan`. -- **Update** `.github/workflows/gitleaks.yml`: reverted to `pull_request`-only (single `gitleaks_scan` job, `scan_mode: diff`); dropped the `push` trigger because the upstream action's diff mode needs `github.event.number` and fails on push with `fatal: invalid refspec '+refs/pull//merge:...'`. -- **Add** `.gitleaksignore`: ignores the `generic-api-key` false positive on `internal/kubernetes/commander/client_test.go:75` (base64 test fixture) by fingerprint at commit `5f1edc2`; the diff scan flags the introducing commit, so the later inline `gitleaks:allow` could not suppress it. - ---- - ## 2026-06-08 - **Add** `pkg/config/config_test.go`: unit tests for `config.New` covering provider parsing, missing required `TEST_CLUSTER_PROVIDER` (error), empty-value handling, and table-driven provider values. - ---- - -## 2026-06-19 - -- **Bugfix** `internal/config.ResolveModulePullOverrides`: detect malformed `${...}` on the original string (stripping - valid refs first) instead of the resolved value, avoiding a false "malformed" error when an env value itself contains - `${...}`. -- **Add** `pkg/clusterprovider/registry/registry_test.go`: table/unit tests for `Registry` covering `NewRegistry` - seeding the built-in DVP provider, `Get` for registered/unregistered modes, `Register` add + replace semantics, - `DefaultRegistry` contents, and a race-detector concurrency test for `Register`/`Get` - -## 2026-06-22 - -- **Add** `.github/workflows/e2e-reusable.yml`: reusable three-job E2E pipeline (`create-cluster` mocked, `run-tests` mirrors `build_dev` flow, `teardown-cluster` mocked); SSH tunnel, `go mod replace`, Ginkgo label filter, 90m minimum suite timeout. -- **Add** `.github/scripts/e2e-prepare-env.sh`, `.github/scripts/e2e-prepare-workspace.sh`: helper scripts for secrets materialisation and self-hosted runner workspace cleanup. -- **Add** `docs/CI.md`: documents the reusable workflow design, inputs, secrets, and run-tests flow. -- **Update** `README.md`: add CI section linking to `docs/CI.md`. -- **Update** `.github/workflows/e2e-reusable.yml`: add `noop` pipeline_mode (all jobs echo mocked, no real steps run); add `test_suite` input (default `TestSdsNodeConfigurator`) to decouple hardcoded suite name from workflow. -- **Add** `.github/workflows/e2e-self-test.yml`: self-test caller that triggers the reusable workflow in `noop` mode on PRs touching CI files. -- **Update** `.github/workflows/e2e-reusable.yml`: add `skip_storage_e2e_replace` boolean input; gate `checkout storage-e2e`, `go mod edit -replace`, and `setup-go` (with dual-path cache) on this flag so storage-e2e can call the workflow without circular self-reference. -- **Update** `.github/workflows/e2e-self-test.yml`: set `skip_storage_e2e_replace: true`, `test_package: ./tests/test-template/`, `test_suite: TestTemplate`. ---- - -## 2026-06-23 - -- **Add** `gitleaks.toml`: content-based allowlist (`[extend] useDefault=true` + `regexTarget="line"` regex for `dXNlcjp0b2tlbg==`) for the base64 test fixture in `internal/kubernetes/commander/client_test.go`. Replaces the commit-pinned `.gitleaksignore` fingerprint, which broke after rebasing `unit-tests` onto `main` (the introducing commit's SHA changed `5f1edc2`→`35e9bc7`). The regex allowlist survives history rewrites. -- **Bugfix** lint fixes in unit-test files surfaced by `main`'s golangci-lint config (after rebase): `pkg/retry/retry_test.go` (gocritic paramTypeCombine on `statusErr`, `cancelled`→`canceled` misspellings), `internal/kubernetes/commander/client_http_test.go` (`behaviour`→`behavior`, gofmt), `pkg/testkit/stress_tests_test.go` (gofmt), `pkg/kubernetes/apply_test.go` (dropped ineffectual `got` assignment in `FindUnsetEnvVars` test), `pkg/cluster/vms_test.go` (staticcheck QF1001 De Morgan simplification). - ---- - -## 2026-06-24 - -- **Remove** `.github/workflows/unit-tests.yml` per PR #20 review: `main`'s `.github/workflows/go-checks.yml` already runs lint + race-enabled unit tests + coverage publishing, so the dedicated workflow was a duplicate. Updated the `Makefile` header comment to point at `go-checks.yml` instead of the removed workflow. diff --git a/go.mod b/go.mod index dfeba55..256cc65 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/onsi/gomega v1.39.1 github.com/pkg/sftp v1.13.10 golang.org/x/crypto v0.52.0 + golang.org/x/sync v0.21.0 golang.org/x/term v0.43.0 gopkg.in/yaml.v3 v3.0.1 k8s.io/api v0.34.2 @@ -248,7 +249,6 @@ require ( golang.org/x/mod v0.35.0 // indirect golang.org/x/net v0.54.0 // indirect golang.org/x/oauth2 v0.30.0 // indirect - golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.45.0 // indirect golang.org/x/text v0.37.0 // indirect golang.org/x/time v0.12.0 // indirect diff --git a/go.sum b/go.sum index 2c9ebec..c9dd39c 100644 --- a/go.sum +++ b/go.sum @@ -716,8 +716,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= -golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM= +golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/internal/infrastructure/ssh/v2/client.go b/internal/infrastructure/ssh/v2/client.go new file mode 100644 index 0000000..5c761a5 --- /dev/null +++ b/internal/infrastructure/ssh/v2/client.go @@ -0,0 +1,74 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package ssh provides a self-healing SSH client whose connection strategy +// ("how we connect" — directly or through jump hosts) is separated from the +// operations performed over it ("what we do" — currently tunneling). +// +// The injection point is the Dialer: Route builds one for a direct connection or +// an arbitrary chain of jump hosts. New opens a Client over a Dialer and hides +// every reconnect: callers invoke methods and never reason about reconnection. +// All operations funnel through a single reconnect-aware executor (withConn) over +// a shared connection core (conn), so future operations such as Run and Upload +// can be added without touching the healing logic. +// +// The primary use case is opening a tunnel to the API server of a closed +// Kubernetes cluster and pointing a kubeconfig at it: +// +// c, _ := ssh.New(ctx, ssh.Route(jumpEp, targetEp)) +// defer c.Close() +// t, _ := c.OpenTunnel(ctx, 6443) +// defer t.Close() +// rest := &rest.Config{Host: "https://" + t.LocalAddr()} +package ssh + +import ( + "context" + "errors" + "log/slog" +) + +type Client struct { + conn *conn + retries int + log *slog.Logger +} + +func New(ctx context.Context, d Dialer, opts ...Option) (*Client, error) { + if d == nil { + return nil, errors.New("ssh: nil dialer") + } + + o := defaultOptions() + for _, opt := range opts { + opt(&o) + } + + if hkd, ok := d.(hostKeyDefaulter); ok { + hkd.setDefaultHostKey(o.hostKey) + } + + core, err := newConn(ctx, d, o) + if err != nil { + return nil, err + } + + return &Client{conn: core, retries: o.retries, log: o.log}, nil +} + +func (c *Client) Close() error { + return c.conn.Close() +} diff --git a/internal/infrastructure/ssh/v2/conn.go b/internal/infrastructure/ssh/v2/conn.go new file mode 100644 index 0000000..b8f5cb9 --- /dev/null +++ b/internal/infrastructure/ssh/v2/conn.go @@ -0,0 +1,261 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "fmt" + "io" + "log/slog" + "strconv" + "sync" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/sync/singleflight" +) + +type conn struct { + dialer Dialer + log *slog.Logger + dialTimeout time.Duration + + flight singleflight.Group + + mu sync.Mutex + client *ssh.Client + closer io.Closer + gen uint64 + closed bool + + kaCancel context.CancelFunc + wg sync.WaitGroup +} + +func newConn(ctx context.Context, d Dialer, o options) (*conn, error) { + client, closer, err := d.Dial(ctx) + if err != nil { + return nil, fmt.Errorf("connect to %s: %w", d.Describe(), err) + } + + c := &conn{ + dialer: d, + log: o.log, + dialTimeout: o.dialTimeout, + client: client, + closer: closer, + gen: 1, + } + + if o.keepalive > 0 { + kaCtx, cancel := context.WithCancel(context.Background()) + c.kaCancel = cancel + c.wg.Add(1) + go c.keepaliveLoop(kaCtx, o.keepalive) + } + + return c, nil +} + +func (c *conn) snapshot() (client *ssh.Client, gen uint64) { + c.mu.Lock() + defer c.mu.Unlock() + return c.client, c.gen +} + +func (c *conn) refresh(ctx context.Context, failedGen uint64) (*ssh.Client, uint64, error) { + key := strconv.FormatUint(failedGen, 10) + + type healed struct { + client *ssh.Client + gen uint64 + } + + v, err, _ := c.flight.Do(key, func() (interface{}, error) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, errClosed + } + if c.gen != failedGen { + cur := healed{client: c.client, gen: c.gen} + c.mu.Unlock() + return cur, nil + } + c.mu.Unlock() + + dialCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), c.dialTimeout) + defer cancel() + + client, closer, dialErr := c.dialer.Dial(dialCtx) + if dialErr != nil { + return nil, fmt.Errorf("reconnect to %s: %w", c.dialer.Describe(), dialErr) + } + + c.mu.Lock() + if c.closed { + c.mu.Unlock() + _ = closer.Close() + return nil, errClosed + } + old := c.closer + c.client = client + c.closer = closer + c.gen++ + newGen := c.gen + c.mu.Unlock() + + if old != nil { + _ = old.Close() + } + c.log.Warn("ssh: connection re-established", + "route", c.dialer.Describe(), "generation", newGen) + + return healed{client: client, gen: newGen}, nil + }) + if err != nil { + return nil, 0, err + } + r, ok := v.(healed) + if !ok { + return nil, 0, fmt.Errorf("ssh: unexpected refresh result type %T", v) + } + return r.client, r.gen, nil +} + +func (c *conn) keepaliveLoop(ctx context.Context, interval time.Duration) { + defer c.wg.Done() + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + client, gen := c.snapshot() + if client == nil { + continue + } + if err := probeKeepalive(ctx, client, interval); err == nil { + continue + } + if ctx.Err() != nil { + return + } + c.log.Warn("ssh: keepalive failed, healing connection", + "route", c.dialer.Describe()) + if _, _, err := c.refresh(ctx, gen); err != nil { + if c.isClosed() || ctx.Err() != nil { + return + } + c.log.Warn("ssh: keepalive-triggered reconnect failed", + "route", c.dialer.Describe(), "err", err) + } + } + } +} + +func probeKeepalive(ctx context.Context, client *ssh.Client, timeout time.Duration) error { + errc := make(chan error, 1) + go func() { + _, _, err := client.SendRequest("keepalive@openssh.com", true, nil) + errc <- err + }() + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return fmt.Errorf("ssh: keepalive probe timed out after %s", timeout) + case err := <-errc: + return err + } +} + +func (c *conn) isClosed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +func (c *conn) Close() error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil + } + c.closed = true + closer := c.closer + cancel := c.kaCancel + c.client = nil + c.closer = nil + c.mu.Unlock() + + if cancel != nil { + cancel() + } + c.wg.Wait() + + if closer != nil { + if err := closer.Close(); err != nil && !isTransient(err) { + return err + } + } + return nil +} + +func withConn[T any](ctx context.Context, c *conn, retries int, op func(context.Context, *ssh.Client) (T, error)) (T, error) { + var zero T + + client, gen := c.snapshot() + for attempt := 0; ; attempt++ { + if err := ctx.Err(); err != nil { + return zero, err + } + if client == nil { + return zero, errClosed + } + + result, err := op(ctx, client) + if err == nil { + return result, nil + } + + if ctx.Err() != nil { + return zero, ctx.Err() + } + if !isTransient(err) { + return zero, err + } + if attempt >= retries { + return zero, fmt.Errorf("after %d attempt(s): %w", attempt+1, err) + } + + c.log.Warn("ssh: operation failed on broken connection, healing", + "route", c.dialer.Describe(), "attempt", attempt+1, "err", err) + + client, gen, err = c.refresh(ctx, gen) + if err != nil { + return zero, fmt.Errorf("heal connection: %w", err) + } + } +} diff --git a/internal/infrastructure/ssh/v2/conn_test.go b/internal/infrastructure/ssh/v2/conn_test.go new file mode 100644 index 0000000..c8ced89 --- /dev/null +++ b/internal/infrastructure/ssh/v2/conn_test.go @@ -0,0 +1,269 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "errors" + "io" + "sync" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +func newTestConn(t *testing.T, d Dialer, keepalive time.Duration) *conn { + t.Helper() + o := defaultOptions() + o.log = quietLogger() + o.keepalive = keepalive + o.dialTimeout = 5 * time.Second + c, err := newConn(context.Background(), d, o) + if err != nil { + t.Fatalf("newConn: %v", err) + } + t.Cleanup(func() { _ = c.Close() }) + return c +} + +func TestConnSnapshotInitialGeneration(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + + c := newTestConn(t, d, 0) + client, gen := c.snapshot() + if client == nil { + t.Fatalf("snapshot returned nil client") + } + if gen != 1 { + t.Fatalf("initial generation = %d, want 1", gen) + } + if d.dialCount() != 1 { + t.Fatalf("dial count = %d, want 1", d.dialCount()) + } +} + +func TestConnRefreshStaleGenerationDoesNotReconnect(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestConn(t, d, 0) + + client, gen, err := c.refresh(context.Background(), 0) + if err != nil { + t.Fatalf("refresh: %v", err) + } + if gen != 1 { + t.Fatalf("generation = %d, want 1 (unchanged)", gen) + } + if client == nil { + t.Fatalf("refresh returned nil client") + } + if d.dialCount() != 1 { + t.Fatalf("dial count = %d, want 1 (no reconnect)", d.dialCount()) + } +} + +func TestConnRefreshDeduplicatesConcurrentReconnects(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestConn(t, d, 0) + + if d.dialCount() != 1 { + t.Fatalf("setup dial count = %d, want 1", d.dialCount()) + } + + gate := make(chan struct{}) + d.setGate(gate) + + const n = 8 + var wg sync.WaitGroup + gens := make([]uint64, n) + errs := make([]error, n) + start := make(chan struct{}) + + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + <-start + _, gen, err := c.refresh(context.Background(), 1) + gens[i] = gen + errs[i] = err + }(i) + } + + close(start) + waitFor(t, 2*time.Second, func() bool { return d.dialCount() == 2 }) + close(gate) + wg.Wait() + + for i := 0; i < n; i++ { + if errs[i] != nil { + t.Fatalf("refresher %d error: %v", i, errs[i]) + } + if gens[i] != 2 { + t.Fatalf("refresher %d generation = %d, want 2", i, gens[i]) + } + } + if d.dialCount() != 2 { + t.Fatalf("dial count = %d, want 2 (one reconnect for all callers)", d.dialCount()) + } +} + +func TestWithConnHealsOnTransientFailure(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestConn(t, d, 0) + + var calls int + got, err := withConn(context.Background(), c, 3, func(_ context.Context, client *ssh.Client) (string, error) { + calls++ + if calls == 1 { + return "", io.EOF // looks like a dropped session + } + if client == nil { + return "", errors.New("nil client after heal") + } + return "ok", nil + }) + if err != nil { + t.Fatalf("withConn: %v", err) + } + if got != "ok" { + t.Fatalf("result = %q, want ok", got) + } + if calls != 2 { + t.Fatalf("op calls = %d, want 2", calls) + } + if d.dialCount() != 2 { + t.Fatalf("dial count = %d, want 2 (one heal)", d.dialCount()) + } +} + +func TestWithConnDoesNotRetryNonTransient(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestConn(t, d, 0) + + sentinel := errors.New("application error") + var calls int + _, err := withConn(context.Background(), c, 3, func(_ context.Context, _ *ssh.Client) (struct{}, error) { + calls++ + return struct{}{}, sentinel + }) + if !errors.Is(err, sentinel) { + t.Fatalf("err = %v, want %v", err, sentinel) + } + if calls != 1 { + t.Fatalf("op calls = %d, want 1 (no retry)", calls) + } + if d.dialCount() != 1 { + t.Fatalf("dial count = %d, want 1 (no reconnect)", d.dialCount()) + } +} + +func TestWithConnRespectsContextCancellation(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestConn(t, d, 0) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + var calls int + _, err := withConn(ctx, c, 3, func(_ context.Context, _ *ssh.Client) (struct{}, error) { + calls++ + return struct{}{}, nil + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("err = %v, want context.Canceled", err) + } + if calls != 0 { + t.Fatalf("op calls = %d, want 0 (ctx already canceled)", calls) + } +} + +func TestWithConnExhaustsRetries(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestConn(t, d, 0) + + var calls int + _, err := withConn(context.Background(), c, 2, func(_ context.Context, _ *ssh.Client) (struct{}, error) { + calls++ + return struct{}{}, io.EOF + }) + if err == nil { + t.Fatalf("expected error after exhausting retries") + } + if !errors.Is(err, io.EOF) { + t.Fatalf("err = %v, want wrapped io.EOF", err) + } + if calls != 3 { + t.Fatalf("op calls = %d, want 3", calls) + } +} + +func TestConnCloseIsIdempotent(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestConn(t, d, 0) + + if err := c.Close(); err != nil { + t.Fatalf("first Close: %v", err) + } + if err := c.Close(); err != nil { + t.Fatalf("second Close: %v", err) + } + if _, _, err := c.refresh(context.Background(), 1); !errors.Is(err, errClosed) { + t.Fatalf("refresh after close = %v, want errClosed", err) + } +} + +func TestKeepaliveHealsDroppedConnection(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + _ = newTestConn(t, d, 100*time.Millisecond) + + srv.dropConns() + + waitFor(t, 5*time.Second, func() bool { return d.dialCount() >= 2 }) + if d.dialCount() < 2 { + t.Fatalf("dial count = %d, want >= 2 (keepalive heal)", d.dialCount()) + } +} + +func waitFor(t *testing.T, timeout time.Duration, cond func() bool) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if cond() { + return + } + time.Sleep(5 * time.Millisecond) + } +} diff --git a/internal/infrastructure/ssh/v2/dialer.go b/internal/infrastructure/ssh/v2/dialer.go new file mode 100644 index 0000000..37981e9 --- /dev/null +++ b/internal/infrastructure/ssh/v2/dialer.go @@ -0,0 +1,178 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "strings" + "time" + + "golang.org/x/crypto/ssh" +) + +type Dialer interface { + Dial(ctx context.Context) (*ssh.Client, io.Closer, error) + Describe() string +} + +type hostKeyDefaulter interface { + setDefaultHostKey(ssh.HostKeyCallback) +} + +func Route(first Endpoint, more ...Endpoint) Dialer { + hops := make([]Endpoint, 0, 1+len(more)) + hops = append(hops, first) + hops = append(hops, more...) + return &route{hops: hops} +} + +type route struct { + hops []Endpoint + defaultHostKey ssh.HostKeyCallback +} + +func (r *route) setDefaultHostKey(cb ssh.HostKeyCallback) { r.defaultHostKey = cb } + +func (r *route) Describe() string { + labels := make([]string, len(r.hops)) + for i, hop := range r.hops { + labels[i] = hop.label() + } + return strings.Join(labels, " -> ") +} + +func (r *route) Dial(ctx context.Context) (cl *ssh.Client, closer io.Closer, err error) { + chain := &chainCloser{} + defer func() { + if err != nil { + _ = chain.Close() + } + }() + + first := r.hops[0] + cfg, agentCloser, cfgErr := first.clientConfig(ctx, r.defaultHostKey) + if cfgErr != nil { + return nil, nil, fmt.Errorf("build config for %s: %w", first.label(), cfgErr) + } + chain.add(agentCloser) + + current, dialErr := dialSSH(ctx, first.addr(), cfg) + if dialErr != nil { + return nil, nil, fmt.Errorf("dial %s: %w", first.label(), dialErr) + } + chain.add(current) + + for _, hop := range r.hops[1:] { + hopCfg, hopAgentCloser, hopErr := hop.clientConfig(ctx, r.defaultHostKey) + if hopErr != nil { + return nil, nil, fmt.Errorf("build config for %s: %w", hop.label(), hopErr) + } + chain.add(hopAgentCloser) + + next, jumpErr := dialThroughJump(ctx, current, hop.addr()) + if jumpErr != nil { + return nil, nil, fmt.Errorf("dial %s via %s: %w", hop.label(), first.label(), jumpErr) + } + + hopClient, handshakeErr := handshakeOver(ctx, next, hop.addr(), hopCfg) + if handshakeErr != nil { + _ = next.Close() + return nil, nil, fmt.Errorf("handshake to %s: %w", hop.label(), handshakeErr) + } + chain.add(hopClient) + current = hopClient + } + + return current, chain, nil +} + +func dialSSH(ctx context.Context, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + client, err := handshakeOver(ctx, conn, addr, cfg) + if err != nil { + _ = conn.Close() + return nil, err + } + return client, nil +} + +func handshakeOver(ctx context.Context, conn net.Conn, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + if deadline, ok := ctx.Deadline(); ok { + _ = conn.SetDeadline(deadline) + } + sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, cfg) + if err != nil { + return nil, err + } + _ = conn.SetDeadline(time.Time{}) + return ssh.NewClient(sshConn, chans, reqs), nil +} + +func dialThroughJump(ctx context.Context, jump *ssh.Client, addr string) (net.Conn, error) { + type result struct { + conn net.Conn + err error + } + ch := make(chan result, 1) + go func() { + conn, err := jump.Dial("tcp", addr) + ch <- result{conn: conn, err: err} + }() + + select { + case <-ctx.Done(): + go func() { + if r := <-ch; r.conn != nil { + _ = r.conn.Close() + } + }() + return nil, ctx.Err() + case r := <-ch: + return r.conn, r.err + } +} + +type chainCloser struct { + closers []io.Closer +} + +func (cc *chainCloser) add(c io.Closer) { + if c != nil { + cc.closers = append(cc.closers, c) + } +} + +func (cc *chainCloser) Close() error { + var errs []error + for i := len(cc.closers) - 1; i >= 0; i-- { + if err := cc.closers[i].Close(); err != nil && !isTransient(err) { + errs = append(errs, err) + } + } + if len(errs) == 0 { + return nil + } + return fmt.Errorf("close ssh chain: %w", errors.Join(errs...)) +} diff --git a/internal/infrastructure/ssh/v2/dialer_test.go b/internal/infrastructure/ssh/v2/dialer_test.go new file mode 100644 index 0000000..0c41c9c --- /dev/null +++ b/internal/infrastructure/ssh/v2/dialer_test.go @@ -0,0 +1,151 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "errors" + "fmt" + "io" + "sync" + "testing" +) + +func TestRouteHopsAndDescribe(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + first Endpoint + more []Endpoint + wantHops int + wantDesc string + }{ + { + name: "direct", + first: Endpoint{User: "root", Addr: "target"}, + wantHops: 1, + wantDesc: "root@target:22", + }, + { + name: "single jump", + first: Endpoint{User: "bastion", Addr: "jump:2222"}, + more: []Endpoint{{User: "root", Addr: "target"}}, + wantHops: 2, + wantDesc: "bastion@jump:2222 -> root@target:22", + }, + { + name: "two jumps preserve order", + first: Endpoint{User: "a", Addr: "h1"}, + more: []Endpoint{ + {User: "b", Addr: "h2"}, + {User: "c", Addr: "h3"}, + }, + wantHops: 3, + wantDesc: "a@h1:22 -> b@h2:22 -> c@h3:22", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + d := Route(tc.first, tc.more...) + r, ok := d.(*route) + if !ok { + t.Fatalf("Route returned %T, want *route", d) + } + if len(r.hops) != tc.wantHops { + t.Fatalf("hops = %d, want %d", len(r.hops), tc.wantHops) + } + if got := d.Describe(); got != tc.wantDesc { + t.Fatalf("Describe() = %q, want %q", got, tc.wantDesc) + } + }) + } +} + +type recordCloser struct { + id int + order *[]int + mu *sync.Mutex + err error +} + +func (c recordCloser) Close() error { + c.mu.Lock() + *c.order = append(*c.order, c.id) + c.mu.Unlock() + return c.err +} + +func TestChainCloserReverseOrderAndNilSkip(t *testing.T) { + t.Parallel() + + var order []int + var mu sync.Mutex + cc := &chainCloser{} + + cc.add(recordCloser{id: 1, order: &order, mu: &mu}) + cc.add(nil) // must be skipped without panicking + cc.add(recordCloser{id: 2, order: &order, mu: &mu}) + cc.add(recordCloser{id: 3, order: &order, mu: &mu}) + + if err := cc.Close(); err != nil { + t.Fatalf("Close() unexpected error: %v", err) + } + + want := []int{3, 2, 1} + if len(order) != len(want) { + t.Fatalf("close order = %v, want %v", order, want) + } + for i := range want { + if order[i] != want[i] { + t.Fatalf("close order = %v, want %v", order, want) + } + } +} + +func TestChainCloserAggregatesErrors(t *testing.T) { + t.Parallel() + + var order []int + var mu sync.Mutex + boom := errors.New("close boom") + cc := &chainCloser{} + cc.add(recordCloser{id: 1, order: &order, mu: &mu, err: boom}) + cc.add(recordCloser{id: 2, order: &order, mu: &mu}) + + err := cc.Close() + if err == nil || !errors.Is(err, boom) { + t.Fatalf("Close() = %v, want error wrapping %v", err, boom) + } +} + +// transientCloser returns a transient error from Close; chainCloser must ignore +// it (an already-dead peer is not a close failure worth surfacing). +type transientCloser struct{} + +func (transientCloser) Close() error { return fmt.Errorf("read: %w", io.EOF) } + +func TestChainCloserIgnoresTransientCloseErrors(t *testing.T) { + t.Parallel() + + cc := &chainCloser{} + cc.add(transientCloser{}) + if err := cc.Close(); err != nil { + t.Fatalf("Close() = %v, want nil (transient close errors ignored)", err) + } +} diff --git a/internal/infrastructure/ssh/v2/endpoint.go b/internal/infrastructure/ssh/v2/endpoint.go new file mode 100644 index 0000000..01773b7 --- /dev/null +++ b/internal/infrastructure/ssh/v2/endpoint.go @@ -0,0 +1,142 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "os/user" + "path/filepath" + "strings" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +type Endpoint struct { + User string + Addr string + KeyPath string + Passphrase string + HostKey ssh.HostKeyCallback +} + +func (e Endpoint) addr() string { + if e.Addr == "" { + return "" + } + if _, _, err := net.SplitHostPort(e.Addr); err == nil { + return e.Addr + } + return net.JoinHostPort(e.Addr, "22") +} + +func (e Endpoint) label() string { + return fmt.Sprintf("%s@%s", e.User, e.addr()) +} + +func (e Endpoint) clientConfig(ctx context.Context, defaultHostKey ssh.HostKeyCallback) (*ssh.ClientConfig, io.Closer, error) { + var signers []ssh.Signer + + if e.KeyPath != "" { + keyPath, err := expandTilde(e.KeyPath) + if err != nil { + return nil, nil, fmt.Errorf("resolve key path %q: %w", e.KeyPath, err) + } + raw, err := os.ReadFile(keyPath) + if err != nil { + return nil, nil, fmt.Errorf("read private key %q: %w", keyPath, err) + } + signer, err := parseSigner(raw, e.Passphrase) + if err != nil { + return nil, nil, fmt.Errorf("parse private key %q: %w", keyPath, err) + } + if signer != nil { + signers = append(signers, signer) + } + } + + agentCloser := io.Closer(nil) + if sock := os.Getenv("SSH_AUTH_SOCK"); sock != "" { + var dialer net.Dialer + if conn, err := dialer.DialContext(ctx, "unix", sock); err == nil { + if agentSigners, err := agent.NewClient(conn).Signers(); err == nil { + signers = append(signers, agentSigners...) + } + agentCloser = conn + } + } + + if len(signers) == 0 { + return nil, nil, fmt.Errorf("no usable credentials for %s: set KeyPath or start an ssh-agent", e.label()) + } + + hostKey := e.HostKey + if hostKey == nil { + hostKey = defaultHostKey + } + if hostKey == nil { + hostKey = ssh.InsecureIgnoreHostKey() + } + + cfg := &ssh.ClientConfig{ + User: e.User, + Auth: []ssh.AuthMethod{ssh.PublicKeys(signers...)}, + HostKeyCallback: hostKey, + Timeout: defaultDialTimeout, + } + return cfg, agentCloser, nil +} + +func parseSigner(raw []byte, passphrase string) (ssh.Signer, error) { + signer, err := ssh.ParsePrivateKey(raw) + if err == nil { + return signer, nil + } + + if _, ok := errors.AsType[*ssh.PassphraseMissingError](err); !ok { + return nil, err + } + + if passphrase == "" { + return nil, nil + } + + signer, err = ssh.ParsePrivateKeyWithPassphrase(raw, []byte(passphrase)) + if err != nil { + return nil, fmt.Errorf("decrypt private key with passphrase: %w", err) + } + return signer, nil +} + +func expandTilde(path string) (string, error) { + if !strings.HasPrefix(path, "~") { + return path, nil + } + usr, err := user.Current() + if err != nil { + return "", fmt.Errorf("look up current user: %w", err) + } + if path == "~" { + return usr.HomeDir, nil + } + return filepath.Join(usr.HomeDir, strings.TrimPrefix(path, "~/")), nil +} diff --git a/internal/infrastructure/ssh/v2/endpoint_test.go b/internal/infrastructure/ssh/v2/endpoint_test.go new file mode 100644 index 0000000..a72d17a --- /dev/null +++ b/internal/infrastructure/ssh/v2/endpoint_test.go @@ -0,0 +1,122 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/pem" + "testing" + + "golang.org/x/crypto/ssh" +) + +func TestEndpointAddr(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + addr string + want string + }{ + {name: "host only gets default port", addr: "example.com", want: "example.com:22"}, + {name: "host with port preserved", addr: "example.com:2222", want: "example.com:2222"}, + {name: "ipv4 with port", addr: "10.0.0.1:6443", want: "10.0.0.1:6443"}, + {name: "empty stays empty", addr: "", want: ""}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + e := Endpoint{User: "u", Addr: tc.addr} + if got := e.addr(); got != tc.want { + t.Fatalf("addr() = %q, want %q", got, tc.want) + } + }) + } +} + +func TestExpandTilde(t *testing.T) { + t.Parallel() + + t.Run("no tilde unchanged", func(t *testing.T) { + t.Parallel() + got, err := expandTilde("/etc/ssh/key") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "/etc/ssh/key" { + t.Fatalf("got %q, want /etc/ssh/key", got) + } + }) + + t.Run("tilde expands to home", func(t *testing.T) { + t.Parallel() + got, err := expandTilde("~/.ssh/id_ed25519") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got == "~/.ssh/id_ed25519" { + t.Fatalf("tilde was not expanded: %q", got) + } + }) +} + +func TestParseSigner(t *testing.T) { + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + plain, err := ssh.MarshalPrivateKey(priv, "") + if err != nil { + t.Fatalf("marshal plain key: %v", err) + } + plainPEM := pem.EncodeToMemory(plain) + + encrypted, err := ssh.MarshalPrivateKeyWithPassphrase(priv, "", []byte("s3cret")) + if err != nil { + t.Fatalf("marshal encrypted key: %v", err) + } + encryptedPEM := pem.EncodeToMemory(encrypted) + + t.Run("plain key parses", func(t *testing.T) { + signer, err := parseSigner(plainPEM, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if signer == nil { + t.Fatalf("expected a signer, got nil") + } + }) + + t.Run("encrypted with explicit passphrase parses", func(t *testing.T) { + signer, err := parseSigner(encryptedPEM, "s3cret") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if signer == nil { + t.Fatalf("expected a signer, got nil") + } + }) + + t.Run("garbage fails", func(t *testing.T) { + if _, err := parseSigner([]byte("not a key"), ""); err == nil { + t.Fatalf("expected error for garbage input") + } + }) +} diff --git a/internal/infrastructure/ssh/v2/errors.go b/internal/infrastructure/ssh/v2/errors.go new file mode 100644 index 0000000..6d42020 --- /dev/null +++ b/internal/infrastructure/ssh/v2/errors.go @@ -0,0 +1,66 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "errors" + "io" + "net" + "syscall" +) + +var errClosed = errors.New("ssh: client is closed") + +// isTransient reports whether err denotes a recoverable transport failure that +// healing the SSH connection might fix (a dropped session, a reset peer, a +// timed-out read, …). Classification is done structurally via errors.Is and +// errors.As against standard error values and types — never by matching error +// text — so it stays correct as wrapping changes. +func isTransient(err error) bool { + if err == nil { + return false + } + + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return true + } + + if errors.Is(err, net.ErrClosed) { + return true + } + + if errors.Is(err, syscall.ECONNRESET) || + errors.Is(err, syscall.ECONNREFUSED) || + errors.Is(err, syscall.ECONNABORTED) || + errors.Is(err, syscall.EPIPE) || + errors.Is(err, syscall.ETIMEDOUT) || + errors.Is(err, syscall.EHOSTUNREACH) || + errors.Is(err, syscall.ENETUNREACH) { + return true + } + + if nerr, ok := errors.AsType[net.Error](err); ok && nerr.Timeout() { + return true + } + + return false +} diff --git a/internal/infrastructure/ssh/v2/errors_test.go b/internal/infrastructure/ssh/v2/errors_test.go new file mode 100644 index 0000000..6303177 --- /dev/null +++ b/internal/infrastructure/ssh/v2/errors_test.go @@ -0,0 +1,66 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "syscall" + "testing" +) + +func TestIsTransient(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + {name: "nil", err: nil, want: false}, + {name: "io.EOF", err: io.EOF, want: true}, + {name: "wrapped EOF", err: fmt.Errorf("dial: %w", io.EOF), want: true}, + {name: "unexpected EOF", err: io.ErrUnexpectedEOF, want: true}, + {name: "net closed", err: net.ErrClosed, want: true}, + {name: "wrapped net closed", err: fmt.Errorf("accept: %w", net.ErrClosed), want: true}, + {name: "ECONNRESET", err: syscall.ECONNRESET, want: true}, + {name: "ECONNREFUSED", err: syscall.ECONNREFUSED, want: true}, + {name: "EPIPE", err: syscall.EPIPE, want: true}, + {name: "timeout net error", err: timeoutErr{}, want: true}, + {name: "context canceled", err: context.Canceled, want: false}, + {name: "context deadline", err: context.DeadlineExceeded, want: false}, + {name: "plain error", err: errors.New("boom"), want: false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isTransient(tc.err); got != tc.want { + t.Fatalf("isTransient(%v) = %v, want %v", tc.err, got, tc.want) + } + }) + } +} + +type timeoutErr struct{} + +func (timeoutErr) Error() string { return "i/o timeout" } +func (timeoutErr) Timeout() bool { return true } +func (timeoutErr) Temporary() bool { return true } diff --git a/internal/infrastructure/ssh/v2/options.go b/internal/infrastructure/ssh/v2/options.go new file mode 100644 index 0000000..2f802ee --- /dev/null +++ b/internal/infrastructure/ssh/v2/options.go @@ -0,0 +1,83 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "log/slog" + "time" + + "golang.org/x/crypto/ssh" + + "github.com/deckhouse/storage-e2e/internal/config" + "github.com/deckhouse/storage-e2e/internal/logger" +) + +const defaultDialTimeout = 30 * time.Second + +type options struct { + keepalive time.Duration + retries int + log *slog.Logger + hostKey ssh.HostKeyCallback + dialTimeout time.Duration +} + +func defaultOptions() options { + return options{ + keepalive: 0, + retries: config.SSHRetryCount, + log: logger.GetLogger(), + hostKey: ssh.InsecureIgnoreHostKey(), + dialTimeout: defaultDialTimeout, + } +} + +type Option func(*options) + +func WithKeepalive(d time.Duration) Option { + return func(o *options) { o.keepalive = d } +} + +func WithRetries(n int) Option { + return func(o *options) { + if n < 0 { + n = 0 + } + o.retries = n + } +} + +func WithLogger(l *slog.Logger) Option { + return func(o *options) { + if l != nil { + o.log = l + } + } +} + +func WithHostKeyCallback(cb ssh.HostKeyCallback) Option { + return func(o *options) { + if cb != nil { + o.hostKey = cb + } + } +} + +func WithInsecureIgnoreHostKey() Option { + //nolint:gosec // G106: explicit opt-in to skip host key verification. + return func(o *options) { o.hostKey = ssh.InsecureIgnoreHostKey() } +} diff --git a/internal/infrastructure/ssh/v2/testserver_test.go b/internal/infrastructure/ssh/v2/testserver_test.go new file mode 100644 index 0000000..c51191b --- /dev/null +++ b/internal/infrastructure/ssh/v2/testserver_test.go @@ -0,0 +1,241 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "io" + "log/slog" + "net" + "strconv" + "sync" + "testing" + + "golang.org/x/crypto/ssh" +) + +func quietLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +type testServer struct { + ln net.Listener + cfg *ssh.ServerConfig + wg sync.WaitGroup + closeOnce sync.Once + + mu sync.Mutex + conns []net.Conn +} + +func newTestServer(t *testing.T) *testServer { + t.Helper() + + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("generate host key: %v", err) + } + signer, err := ssh.NewSignerFromSigner(priv) + if err != nil { + t.Fatalf("build host signer: %v", err) + } + + cfg := &ssh.ServerConfig{NoClientAuth: true} + cfg.AddHostKey(signer) + + var lc net.ListenConfig + ln, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + + s := &testServer{ln: ln, cfg: cfg} + s.wg.Add(1) + go s.acceptLoop() + t.Cleanup(s.Close) + return s +} + +func (s *testServer) addr() string { return s.ln.Addr().String() } + +func (s *testServer) acceptLoop() { + defer s.wg.Done() + for { + nConn, err := s.ln.Accept() + if err != nil { + return + } + s.mu.Lock() + s.conns = append(s.conns, nConn) + s.mu.Unlock() + + s.wg.Add(1) + go s.handleConn(nConn) + } +} + +func (s *testServer) handleConn(nConn net.Conn) { + defer s.wg.Done() + + sconn, chans, reqs, err := ssh.NewServerConn(nConn, s.cfg) + if err != nil { + _ = nConn.Close() + return + } + defer sconn.Close() + + go func() { + for req := range reqs { + if req.WantReply { + _ = req.Reply(true, nil) + } + } + }() + + for newCh := range chans { + if newCh.ChannelType() != "direct-tcpip" { + _ = newCh.Reject(ssh.UnknownChannelType, "only direct-tcpip is supported") + continue + } + go handleDirectTCPIP(newCh) + } +} + +type directTCPIPMsg struct { + DestAddr string + DestPort uint32 + OrigAddr string + OrigPort uint32 +} + +func handleDirectTCPIP(newCh ssh.NewChannel) { + var msg directTCPIPMsg + if err := ssh.Unmarshal(newCh.ExtraData(), &msg); err != nil { + _ = newCh.Reject(ssh.ConnectionFailed, "bad direct-tcpip payload") + return + } + + target := net.JoinHostPort(msg.DestAddr, strconv.Itoa(int(msg.DestPort))) + var dialer net.Dialer + remote, err := dialer.DialContext(context.Background(), "tcp", target) + if err != nil { + _ = newCh.Reject(ssh.ConnectionFailed, err.Error()) + return + } + + ch, reqs, err := newCh.Accept() + if err != nil { + _ = remote.Close() + return + } + go ssh.DiscardRequests(reqs) + + go func() { + _, _ = io.Copy(ch, remote) + _ = ch.Close() + }() + go func() { + _, _ = io.Copy(remote, ch) + _ = remote.Close() + }() +} + +func (s *testServer) dropConns() { + s.mu.Lock() + defer s.mu.Unlock() + for _, c := range s.conns { + _ = c.Close() + } + s.conns = nil +} + +func (s *testServer) Close() { + s.closeOnce.Do(func() { + _ = s.ln.Close() + s.dropConns() + s.wg.Wait() + }) +} + +type serverDialer struct { + addr string + + mu sync.Mutex + dials int + gate chan struct{} +} + +func (d *serverDialer) Dial(ctx context.Context) (*ssh.Client, io.Closer, error) { + d.mu.Lock() + d.dials++ + gate := d.gate + d.mu.Unlock() + + if gate != nil { + <-gate + } + + client, err := dialSSH(ctx, d.addr, &ssh.ClientConfig{ + User: "test", + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + if err != nil { + return nil, nil, err + } + return client, client, nil +} + +func (d *serverDialer) Describe() string { return "test://" + d.addr } + +func (d *serverDialer) dialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dials +} + +func (d *serverDialer) setGate(gate chan struct{}) { + d.mu.Lock() + d.gate = gate + d.mu.Unlock() +} + +func newEchoServer(t *testing.T) int { + t.Helper() + var lc net.ListenConfig + ln, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("echo listen: %v", err) + } + t.Cleanup(func() { _ = ln.Close() }) + + go func() { + for { + c, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + _, _ = io.Copy(c, c) + }(c) + } + }() + + return ln.Addr().(*net.TCPAddr).Port +} diff --git a/internal/infrastructure/ssh/v2/tunnel.go b/internal/infrastructure/ssh/v2/tunnel.go new file mode 100644 index 0000000..9f7fd10 --- /dev/null +++ b/internal/infrastructure/ssh/v2/tunnel.go @@ -0,0 +1,197 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "errors" + "fmt" + "io" + "log/slog" + "net" + "strconv" + "sync" + "time" + + "golang.org/x/crypto/ssh" +) + +const acceptDeadline = 500 * time.Millisecond + +type Tunnel struct { + LocalPort int + RemotePort int + + listener net.Listener + cancel context.CancelFunc + wg sync.WaitGroup + closeOnce sync.Once + closeErr error + + lnCloseOnce sync.Once + lnCloseErr error +} + +func (c *Client) OpenTunnel(ctx context.Context, remotePort int) (*Tunnel, error) { + if err := ctx.Err(); err != nil { + return nil, fmt.Errorf("tunnel setup: %w", err) + } + + var lc net.ListenConfig + listener, err := lc.Listen(ctx, "tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("listen on local port: %w", err) + } + tcpAddr, ok := listener.Addr().(*net.TCPAddr) + if !ok { + _ = listener.Close() + return nil, fmt.Errorf("unexpected listener address type %T", listener.Addr()) + } + localPort := tcpAddr.Port + + serveCtx, cancel := context.WithCancel(ctx) + + t := &Tunnel{ + LocalPort: localPort, + RemotePort: remotePort, + listener: listener, + cancel: cancel, + } + + t.wg.Add(1) + go t.serve(serveCtx, c.conn, c.retries, c.log) + + c.log.Info("ssh: tunnel established", + "local", t.LocalAddr(), "remote_port", remotePort, "route", c.conn.dialer.Describe()) + + return t, nil +} + +func (t *Tunnel) LocalAddr() string { + return "127.0.0.1:" + strconv.Itoa(t.LocalPort) +} + +func (t *Tunnel) Close() error { + t.closeOnce.Do(func() { + t.cancel() + t.closeErr = t.closeListener() + t.wg.Wait() + }) + return t.closeErr +} + +func (t *Tunnel) closeListener() error { + t.lnCloseOnce.Do(func() { + t.lnCloseErr = t.listener.Close() + }) + return t.lnCloseErr +} + +func (t *Tunnel) serve(ctx context.Context, core *conn, retries int, log *slog.Logger) { + defer t.wg.Done() + defer func() { _ = t.closeListener() }() + + for { + select { + case <-ctx.Done(): + return + default: + } + + if tcp, ok := t.listener.(*net.TCPListener); ok { + _ = tcp.SetDeadline(time.Now().Add(acceptDeadline)) + } + + local, err := t.listener.Accept() + if err != nil { + if ctx.Err() != nil { + return + } + if ne, ok := errors.AsType[net.Error](err); ok && ne.Timeout() { + continue + } + return + } + + t.wg.Add(1) + go func() { + defer t.wg.Done() + t.handle(ctx, core, retries, local, log) + }() + } +} + +func (t *Tunnel) handle(ctx context.Context, core *conn, retries int, local net.Conn, log *slog.Logger) { + defer local.Close() + + remotePort := t.RemotePort + remote, err := withConn(ctx, core, retries, func(ctx context.Context, client *ssh.Client) (net.Conn, error) { + return dialChannel(ctx, client, "127.0.0.1:"+strconv.Itoa(remotePort)) + }) + if err != nil { + if ctx.Err() == nil { + log.Warn("ssh: tunnel forward failed", + "local", t.LocalAddr(), "remote_port", remotePort, "err", err) + } + return + } + defer remote.Close() + + stop := make(chan struct{}) + defer close(stop) + go func() { + select { + case <-ctx.Done(): + _ = local.Close() + _ = remote.Close() + case <-stop: + } + }() + + done := make(chan struct{}, 2) + go func() { _, _ = io.Copy(remote, local); done <- struct{}{} }() + go func() { _, _ = io.Copy(local, remote); done <- struct{}{} }() + + <-done + _ = local.Close() + _ = remote.Close() + <-done +} + +func dialChannel(ctx context.Context, client *ssh.Client, addr string) (net.Conn, error) { + type result struct { + conn net.Conn + err error + } + ch := make(chan result, 1) + go func() { + conn, err := client.Dial("tcp", addr) + ch <- result{conn: conn, err: err} + }() + + select { + case <-ctx.Done(): + go func() { + if r := <-ch; r.conn != nil { + _ = r.conn.Close() + } + }() + return nil, ctx.Err() + case r := <-ch: + return r.conn, r.err + } +} diff --git a/internal/infrastructure/ssh/v2/tunnel_test.go b/internal/infrastructure/ssh/v2/tunnel_test.go new file mode 100644 index 0000000..93877d5 --- /dev/null +++ b/internal/infrastructure/ssh/v2/tunnel_test.go @@ -0,0 +1,247 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "net" + "testing" + "time" +) + +func newTestClient(t *testing.T, d Dialer, keepalive time.Duration) *Client { + t.Helper() + c, err := New(context.Background(), d, + WithLogger(quietLogger()), + WithKeepalive(keepalive), + ) + if err != nil { + t.Fatalf("New: %v", err) + } + t.Cleanup(func() { _ = c.Close() }) + return c +} + +func dialTimeout(addr string, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + var d net.Dialer + return d.DialContext(ctx, "tcp", addr) +} + +func roundtrip(t *testing.T, addr, payload string) string { + t.Helper() + conn, err := dialTimeout(addr, 3*time.Second) + if err != nil { + t.Fatalf("dial tunnel %s: %v", addr, err) + } + defer conn.Close() + + _ = conn.SetDeadline(time.Now().Add(3 * time.Second)) + if _, err := conn.Write([]byte(payload)); err != nil { + t.Fatalf("write: %v", err) + } + buf := make([]byte, len(payload)) + if _, err := readFull(conn, buf); err != nil { + t.Fatalf("read: %v", err) + } + return string(buf) +} + +func readFull(conn net.Conn, buf []byte) (int, error) { + total := 0 + for total < len(buf) { + n, err := conn.Read(buf[total:]) + total += n + if err != nil { + return total, err + } + } + return total, nil +} + +func TestTunnelForwardsTraffic(t *testing.T) { + t.Parallel() + echoPort := newEchoServer(t) + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestClient(t, d, 0) + + tun, err := c.OpenTunnel(context.Background(), echoPort) + if err != nil { + t.Fatalf("OpenTunnel: %v", err) + } + defer tun.Close() + + if tun.LocalPort == 0 { + t.Fatalf("expected a non-zero local port") + } + if tun.RemotePort != echoPort { + t.Fatalf("RemotePort = %d, want %d", tun.RemotePort, echoPort) + } + if got := tun.LocalAddr(); got == "" { + t.Fatalf("LocalAddr empty") + } + + if got := roundtrip(t, tun.LocalAddr(), "hello-tunnel"); got != "hello-tunnel" { + t.Fatalf("echo = %q, want hello-tunnel", got) + } +} + +func TestTunnelHealsAfterDroppedSession(t *testing.T) { + t.Parallel() + echoPort := newEchoServer(t) + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestClient(t, d, 0) + + tun, err := c.OpenTunnel(context.Background(), echoPort) + if err != nil { + t.Fatalf("OpenTunnel: %v", err) + } + defer tun.Close() + + if got := roundtrip(t, tun.LocalAddr(), "before"); got != "before" { + t.Fatalf("echo before drop = %q, want before", got) + } + + srv.dropConns() + + var lastErr error + deadline := time.Now().Add(8 * time.Second) + for time.Now().Before(deadline) { + got, err := tryRoundtrip(tun.LocalAddr(), "after") + if err == nil && got == "after" { + lastErr = nil + break + } + lastErr = err + time.Sleep(20 * time.Millisecond) + } + if lastErr != nil { + t.Fatalf("tunnel did not heal after dropped session: %v", lastErr) + } + if d.dialCount() < 2 { + t.Fatalf("dial count = %d, want >= 2 (healed)", d.dialCount()) + } +} + +func tryRoundtrip(addr, payload string) (string, error) { + conn, err := dialTimeout(addr, 2*time.Second) + if err != nil { + return "", err + } + defer conn.Close() + _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) + if _, err := conn.Write([]byte(payload)); err != nil { + return "", err + } + buf := make([]byte, len(payload)) + if _, err := readFull(conn, buf); err != nil { + return "", err + } + return string(buf), nil +} + +func TestTunnelCloseIsIdempotentAndStopsListener(t *testing.T) { + t.Parallel() + echoPort := newEchoServer(t) + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestClient(t, d, 0) + + tun, err := c.OpenTunnel(context.Background(), echoPort) + if err != nil { + t.Fatalf("OpenTunnel: %v", err) + } + addr := tun.LocalAddr() + + if err := tun.Close(); err != nil { + t.Fatalf("first Close: %v", err) + } + if err := tun.Close(); err != nil { + t.Fatalf("second Close: %v", err) + } + + waitFor(t, 2*time.Second, func() bool { + conn, err := dialTimeout(addr, 200*time.Millisecond) + if err != nil { + return true + } + _ = conn.Close() + return false + }) + if conn, err := dialTimeout(addr, 200*time.Millisecond); err == nil { + _ = conn.Close() + t.Fatalf("listener still accepting after Close") + } +} + +func TestTunnelStopsWhenContextCancelled(t *testing.T) { + t.Parallel() + echoPort := newEchoServer(t) + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestClient(t, d, 0) + + ctx, cancel := context.WithCancel(context.Background()) + tun, err := c.OpenTunnel(ctx, echoPort) + if err != nil { + t.Fatalf("OpenTunnel: %v", err) + } + defer tun.Close() + + addr := tun.LocalAddr() + cancel() + + waitFor(t, 2*time.Second, func() bool { + conn, err := dialTimeout(addr, 200*time.Millisecond) + if err != nil { + return true + } + _ = conn.Close() + return false + }) + if conn, err := dialTimeout(addr, 200*time.Millisecond); err == nil { + _ = conn.Close() + t.Fatalf("listener still accepting after context cancel") + } +} + +func TestNewRejectsNilDialer(t *testing.T) { + t.Parallel() + _, err := New(context.Background(), nil) + if err == nil { + t.Fatalf("expected error for nil dialer") + } +} + +func TestClientCloseIdempotent(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c, err := New(context.Background(), d, WithLogger(quietLogger())) + if err != nil { + t.Fatalf("New: %v", err) + } + if err := c.Close(); err != nil { + t.Fatalf("first Close: %v", err) + } + if err := c.Close(); err != nil { + t.Fatalf("second Close: %v", err) + } +} diff --git a/internal/provisioning/dvp/config.go b/internal/provisioning/dvp/config.go index e34fcc9..0c470b2 100644 --- a/internal/provisioning/dvp/config.go +++ b/internal/provisioning/dvp/config.go @@ -16,49 +16,24 @@ limitations under the License. package dvp -import ( - "fmt" - "os" -) +const apiServerRemotePort = 6445 type Config struct { SSHUser string `env:"E2E_DVP_BASE_CLUSTER_SSH_USER,required"` SSHHost string `env:"E2E_DVP_BASE_CLUSTER_SSH_HOST,required"` - SSHKeyPath string `env:"E2E_DVP_BASE_CLUSTER_SSH_KEY_PATH,required"` + SSHKeyPath string `env:"E2E_DVP_BASE_CLUSTER_SSH_PRIVATE_KEY_PATH,required"` SSHPassphrase string `env:"E2E_DVP_BASE_CLUSTER_SSH_PASSPHRASE"` - SSHJumpHost string `env:"E2E_DVP_BASE_CLUSTER_SSH_JUMP_HOST"` - SSHJumpUser string `env:"E2E_DVP_BASE_CLUSTER_SSH_JUMP_USER"` - SSHJumpKeyPath string `env:"E2E_DVP_BASE_CLUSTER_SSH_JUMP_KEY_PATH"` + SSHJumpHost string `env:"E2E_DVP_BASE_CLUSTER_SSH_JUMP_HOST"` + SSHJumpUser string `env:"E2E_DVP_BASE_CLUSTER_SSH_JUMP_USER"` + SSHJumpKeyPath string `env:"E2E_DVP_BASE_CLUSTER_SSH_JUMP_PRIVATE_KEY_PATH"` + SSHJumpPassphrase string `env:"E2E_DVP_BASE_CLUSTER_SSH_JUMP_KEY_PASSPHRASE"` KubeConfigPath string `env:"E2E_DVP_BASE_CLUSTER_KUBECONFIG_PATH,required"` Namespace string `env:"E2E_DVP_BASE_CLUSTER_NAMESPACE" envDefault:"e2e-test-cluster"` } -func (c *Config) SetPassphrase() error { - if c.SSHPassphrase == "" { - return nil - } - if err := os.Setenv("SSH_PASSPHRASE", c.SSHPassphrase); err != nil { - return fmt.Errorf("failed to set SSH_PASSPHRASE: %w", err) - } - return nil -} - -func (c *Config) baseEndpoint() sshEndpoint { - ep := sshEndpoint{User: c.SSHUser, Host: c.SSHHost, KeyPath: c.SSHKeyPath} - if c.SSHJumpHost == "" { - return ep - } - - jump := sshEndpoint{User: c.SSHJumpUser, Host: c.SSHJumpHost, KeyPath: c.SSHJumpKeyPath} - if jump.User == "" { - jump.User = c.SSHUser - } - if jump.KeyPath == "" { - jump.KeyPath = c.SSHKeyPath - } - ep.Jump = &jump - return ep +func (c *Config) HasJumpHost() bool { + return c.SSHJumpUser != "" && c.SSHJumpHost != "" && c.SSHJumpKeyPath != "" } diff --git a/internal/provisioning/dvp/connection.go b/internal/provisioning/dvp/connection.go deleted file mode 100644 index 4bab332..0000000 --- a/internal/provisioning/dvp/connection.go +++ /dev/null @@ -1,85 +0,0 @@ -/* -Copyright 2026 Flant JSC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package dvp - -import ( - "context" - "errors" - "fmt" - - "github.com/deckhouse/storage-e2e/internal/infrastructure/ssh" -) - -const apiServerRemotePort = "6445" - -type sshEndpoint struct { - User string - Host string - KeyPath string - Jump *sshEndpoint -} - -func (e sshEndpoint) dial() (ssh.SSHClient, error) { - if e.Jump != nil { - return ssh.NewClientWithJumpHost( - e.Jump.User, e.Jump.Host, e.Jump.KeyPath, - e.User, e.Host, e.KeyPath, - ) - } - return ssh.NewClient(e.User, e.Host, e.KeyPath) -} - -type clusterConnection struct { - ssh ssh.SSHClient - tunnel *ssh.TunnelInfo -} - -func openTunnel(ctx context.Context, ep sshEndpoint) (*clusterConnection, error) { - sshClient, err := ep.dial() - if err != nil { - return nil, fmt.Errorf("ssh dial %s@%s: %w", ep.User, ep.Host, err) - } - - conn := &clusterConnection{ssh: sshClient} - - conn.tunnel, err = sshClient.OpenTunnel(ctx, apiServerRemotePort) - if err != nil { - _ = conn.Close() - return nil, fmt.Errorf("establish API server tunnel: %w", err) - } - - return conn, nil -} - -func (c *clusterConnection) Close() error { - if c == nil { - return nil - } - - var errs []error - if c.tunnel != nil && c.tunnel.StopFunc != nil { - if err := c.tunnel.StopFunc(); err != nil { - errs = append(errs, fmt.Errorf("stop API server tunnel: %w", err)) - } - } - if c.ssh != nil { - if err := c.ssh.Close(); err != nil { - errs = append(errs, fmt.Errorf("close ssh client: %w", err)) - } - } - return errors.Join(errs...) -} diff --git a/internal/provisioning/dvp/kubeconfig.go b/internal/provisioning/dvp/kubeconfig.go index fd94755..89016d1 100644 --- a/internal/provisioning/dvp/kubeconfig.go +++ b/internal/provisioning/dvp/kubeconfig.go @@ -25,7 +25,6 @@ import ( "time" "k8s.io/client-go/rest" - "k8s.io/client-go/tools/clientcmd" ) func readKubeconfig(path string) ([]byte, error) { @@ -58,57 +57,7 @@ func expandUserPath(path string) (string, error) { return filepath.Join(home, strings.TrimPrefix(expanded, "~/")), nil } -func loadKubeconfigViaTunnel(localPort int, kubeconfigDir, host, kubeconfigSrcPath string) (*rest.Config, string, error) { - raw, err := readKubeconfig(kubeconfigSrcPath) - if err != nil { - return nil, "", fmt.Errorf("load base cluster kubeconfig: %w", err) - } - - path, err := kubeconfigFilePath(kubeconfigDir, host) - if err != nil { - return nil, "", err - } - - server := fmt.Sprintf("https://127.0.0.1:%d", localPort) - cfg, err := buildKubeconfig(raw, server, path) - if err != nil { - return nil, "", fmt.Errorf("build kubeconfig: %w", err) - } - return cfg, path, nil -} - -func buildKubeconfig(raw []byte, server, path string) (*rest.Config, error) { - apiCfg, err := clientcmd.Load(raw) - if err != nil { - return nil, fmt.Errorf("parse kubeconfig: %w", err) - } - for _, cluster := range apiCfg.Clusters { - cluster.Server = server - } - - if writeErr := clientcmd.WriteToFile(*apiCfg, path); writeErr != nil { - return nil, fmt.Errorf("write kubeconfig %q: %w", path, writeErr) - } - - restCfg, err := clientcmd.NewDefaultClientConfig(*apiCfg, &clientcmd.ConfigOverrides{}).ClientConfig() - - if err != nil { - return nil, fmt.Errorf("build rest config: %w", err) - } - configureTunnelTimeouts(restCfg) - return restCfg, nil -} - -func kubeconfigFilePath(dir, host string) (string, error) { - if err := os.MkdirAll(dir, 0o700); err != nil { - return "", fmt.Errorf("create kubeconfig dir %q: %w", dir, err) - } - return filepath.Join(dir, fmt.Sprintf("kubeconfig-%s.yml", host)), nil -} - func configureTunnelTimeouts(cfg *rest.Config) { - cfg.Timeout = 2 * time.Minute - prev := cfg.WrapTransport cfg.WrapTransport = func(rt http.RoundTripper) http.RoundTripper { if prev != nil { diff --git a/internal/provisioning/dvp/provider.go b/internal/provisioning/dvp/provider.go index f4b20a8..cc87310 100644 --- a/internal/provisioning/dvp/provider.go +++ b/internal/provisioning/dvp/provider.go @@ -20,10 +20,15 @@ import ( "context" "fmt" "log/slog" + "time" "github.com/caarlos0/env/v11" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + clientcmdapi "k8s.io/client-go/tools/clientcmd/api" "github.com/deckhouse/storage-e2e/internal/config" + "github.com/deckhouse/storage-e2e/internal/infrastructure/ssh/v2" "github.com/deckhouse/storage-e2e/pkg/clusterprovider" "github.com/deckhouse/storage-e2e/pkg/kubernetes" ) @@ -39,10 +44,6 @@ func NewDVPProvider(logger *slog.Logger, cfg *clusterprovider.ClusterConfig) (cl if err := env.Parse(dvpConf); err != nil { return nil, err } - err := dvpConf.SetPassphrase() - if err != nil { - return nil, err - } return &dvpProvider{ cfg: cfg, @@ -53,6 +54,63 @@ func NewDVPProvider(logger *slog.Logger, cfg *clusterprovider.ClusterConfig) (cl func (p *dvpProvider) Name() string { return clusterprovider.ModeDVP } +func (p *dvpProvider) buildSshClient(ctx context.Context) (*ssh.Client, error) { + var dialer ssh.Dialer + if p.dvpConf.HasJumpHost() { + dialer = ssh.Route(ssh.Endpoint{ + User: p.dvpConf.SSHJumpUser, + Addr: p.dvpConf.SSHJumpHost, + KeyPath: p.dvpConf.SSHJumpKeyPath, + Passphrase: p.dvpConf.SSHJumpPassphrase, + }, ssh.Endpoint{ + User: p.dvpConf.SSHUser, + Addr: p.dvpConf.SSHHost, + KeyPath: p.dvpConf.SSHKeyPath, + Passphrase: p.dvpConf.SSHPassphrase, + }) + } else { + dialer = ssh.Route(ssh.Endpoint{ + User: p.dvpConf.SSHUser, + Addr: p.dvpConf.SSHHost, + KeyPath: p.dvpConf.SSHKeyPath, + Passphrase: p.dvpConf.SSHPassphrase, + }) + } + + sshClient, sshNewErr := ssh.New(ctx, dialer) + if sshNewErr != nil { + return nil, fmt.Errorf("creating ssh client: %w", sshNewErr) + } + return sshClient, nil +} + +func (p *dvpProvider) buildRestConfig(tun *ssh.Tunnel) (*rest.Config, error) { + rawKubeconfig, readErr := readKubeconfig(p.dvpConf.KubeConfigPath) + if readErr != nil { + return nil, fmt.Errorf("reading kubeconfig: %w", readErr) + } + + apiCfg, err := clientcmd.Load(rawKubeconfig) + overrides := &clientcmd.ConfigOverrides{ + ClusterInfo: clientcmdapi.Cluster{ + Server: tun.LocalAddr(), + }, + Timeout: (2 * time.Minute).String(), + } + + if err != nil { + return nil, fmt.Errorf("parsing kubeconfig: %w", err) + } + + restConfig, clientConfigErr := clientcmd.NewDefaultClientConfig(*apiCfg, overrides).ClientConfig() + if clientConfigErr != nil { + return nil, fmt.Errorf("creating client config: %w", clientConfigErr) + } + + configureTunnelTimeouts(restConfig) + return restConfig, nil +} + func (p *dvpProvider) Bootstrap(ctx context.Context) error { clusterDef, err := config.LoadClusterDefinition(p.cfg.ClusterBootstrapConfigPath) if err != nil { @@ -70,26 +128,28 @@ func (p *dvpProvider) Bootstrap(ctx context.Context) error { "jumpHost", p.dvpConf.SSHJumpHost, "kubeconfigSource", p.dvpConf.KubeConfigPath, ) - conn, err := openTunnel(ctx, p.dvpConf.baseEndpoint()) - if err != nil { - return fmt.Errorf("open tunnel to DVP base cluster: %w", err) + + sshClient, sshNewErr := p.buildSshClient(ctx) + if sshNewErr != nil { + return fmt.Errorf("creating ssh client: %w", sshNewErr) + } + + tun, tunErr := sshClient.OpenTunnel(ctx, apiServerRemotePort) + + if tunErr != nil { + return fmt.Errorf("creating tunnel: %w", tunErr) } defer func() { - if cerr := conn.Close(); cerr != nil { - p.logger.Warn("close DVP base cluster connection", "err", cerr) + tunCloseErr := tun.Close() + if tunCloseErr != nil { + p.logger.Warn("failed to close tunnel", "err", tunCloseErr) } }() - kubeconfig, kubeconfigPath, err := loadKubeconfigViaTunnel( - conn.tunnel.LocalPort, config.E2ETempDir, p.dvpConf.SSHHost, p.dvpConf.KubeConfigPath, - ) - if err != nil { - return fmt.Errorf("build kubeconfig for DVP base cluster: %w", err) + kubeconfig, buildRestConfErr := p.buildRestConfig(tun) + if buildRestConfErr != nil { + return fmt.Errorf("creating rest config: %w", buildRestConfErr) } - p.logger.Info("connected to DVP base cluster", - "kubeconfig", kubeconfigPath, - "apiServer", kubeconfig.Host, - ) p.logger.Info("waiting for virtualization module to become ready", "timeout", config.ModuleCheckTimeout, diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index 327f321..d06f002 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -39,6 +39,8 @@ import ( "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" + "github.com/deckhouse/virtualization/api/core/v1alpha2" + internalcluster "github.com/deckhouse/storage-e2e/internal/cluster" "github.com/deckhouse/storage-e2e/internal/config" "github.com/deckhouse/storage-e2e/internal/infrastructure/ssh" @@ -47,7 +49,6 @@ import ( "github.com/deckhouse/storage-e2e/internal/logger" "github.com/deckhouse/storage-e2e/pkg/kubernetes" "github.com/deckhouse/storage-e2e/pkg/testkit" - "github.com/deckhouse/virtualization/api/core/v1alpha2" ) // extraCommanderValues stores additional values to be passed to Commander cluster creation @@ -1607,7 +1608,7 @@ func CleanupTestCluster(ctx context.Context, resources *TestClusterResources) er } } } else { - // Tunnel already exists, use it + // OpenTunnel already exists, use it logger.Success("Base cluster tunnel already exists") baseTunnel = resources.BaseTunnelInfo cleanupKubeconfig = resources.BaseKubeconfig diff --git a/pkg/kubernetes/modules.go b/pkg/kubernetes/modules.go index 896ab8f..2285f83 100644 --- a/pkg/kubernetes/modules.go +++ b/pkg/kubernetes/modules.go @@ -590,9 +590,6 @@ const moduleReadyPollInterval = 2 * time.Second // - On timeout the error carries the last observed phase and the IsReady // condition message so a stuck module is diagnosable from logs alone. func WaitForModuleReady(ctx context.Context, kubeconfig *rest.Config, moduleName string, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - var lastPhase, lastCondition string // ready re-reads the module and reports whether it has converged, recording