diff --git a/.coderabbit.yaml b/.coderabbit.yaml index 3221c5b..81a41ed 100644 --- a/.coderabbit.yaml +++ b/.coderabbit.yaml @@ -121,7 +121,11 @@ reviews: comment/string-aware multi-statement rejection and SELECT/WITH-only (DML behind WithAllowDML) policy. EXPLAIN takes no bind params, so concatenation is by design — the defense is validate() + the rolled-back - read-only tx; do not "fix" it with parameterization. + read-only tx; do not "fix" it with parameterization. Deliberate + carve-out: explain keeps Result.Query (and the inner analyzer.Result + .Query of its findings) RAW — the user typed the query on their own CLI, + it never reaches a log/telemetry sink, and Fingerprint is still set. Do + NOT flag explain findings for not redacting Query; that is intended. - path: "**/*_test.go" instructions: >- diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..f8c374a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,43 @@ +--- +name: Bug report +about: Report incorrect behavior, a false positive/negative, or a crash +title: "" +labels: bug +assignees: "" +--- + +**Do not file security vulnerabilities here** — see [SECURITY.md](../../SECURITY.md). + +## What happened + +A clear description of the bug. + +## Expected behavior + +What you expected instead. For a false positive/negative, say which **rule** +(e.g. `select-star`) fired or failed to fire. + +## Reproduction + +The SQL or Go snippet, and how it was issued: + +```sql +-- query (redacted is fine) +``` + +```go +// minimal repro +``` + +## Environment + +- sqlguard version / commit: +- Affected module(s) (root, `integrations/`, `parsers/`): +- Parser in use (default fallback / pgparser / mysqlparser): +- Entry surface (runtime middleware / CLI `scan` / CLI `explain` / integration): +- Go version: +- Database + dialect (if relevant): + +## Additional context + +Logs (redaction-safe), config (`.sqlguard.yml`), or anything else useful. diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..d9cc535 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: Security vulnerability + url: https://github.com/KARTIKrocks/sqlguard/security/advisories/new + about: Report security issues privately — please do not open a public issue. + - name: Question / discussion + url: https://github.com/KARTIKrocks/sqlguard/discussions + about: Ask usage questions or discuss ideas here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..829c5ac --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,31 @@ +--- +name: Feature request +about: Suggest a new rule, integration, or capability +title: "" +labels: enhancement +assignees: "" +--- + +## Problem + +What are you trying to catch or do that sqlguard can't today? + +## Proposed solution + +What you'd like to see. If you're proposing a **new detection rule**, include: + +- the SQL anti-pattern it should flag, +- example queries that should and should **not** trigger it, +- a suggested severity (info / warning / critical), +- any tunable (and its default). + +If you're proposing a **new integration**, name the ORM/driver and its +hook/seam. + +## Alternatives considered + +Other approaches, workarounds, or existing rules/config that almost fit. + +## Additional context + +Anything else — links, prior art, willingness to send a PR. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..12ab443 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,27 @@ +## Summary + +What does this PR change, and why? + +Closes # + +## Type of change + +- [ ] Bug fix +- [ ] New detection rule +- [ ] New integration / parser +- [ ] Feature / enhancement +- [ ] Docs only +- [ ] Refactor / chore + +## Checklist + +- [ ] `make ci` passes (fmt-check, vet, lint, test-race) across all modules +- [ ] Added/updated tests (and, where practical, a failure-mode check) +- [ ] Updated docs as needed (`README.md`, `CLAUDE.md`, `.sqlguard.example.yml`) +- [ ] Added an entry under `## [Unreleased]` in `CHANGELOG.md` +- [ ] No new third-party deps in `analyzer` / `middleware` / `reporter` +- [ ] Findings stay redaction-safe (no raw literals leak into a `Result`) + +## Notes for reviewers + +Anything reviewers should focus on — tricky areas, trade-offs, follow-ups. diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..e47fffe --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,56 @@ +version: 2 + +updates: + - package-ecosystem: gomod + directory: / + schedule: + interval: weekly + groups: + go-dependencies: + patterns: + - "*" + + - package-ecosystem: gomod + directory: /integrations/gormguard + schedule: + interval: weekly + + - package-ecosystem: gomod + directory: /integrations/sqlxguard + schedule: + interval: weekly + + - package-ecosystem: gomod + directory: /integrations/pgxguard + schedule: + interval: weekly + + - package-ecosystem: gomod + directory: /integrations/bunguard + schedule: + interval: weekly + + - package-ecosystem: gomod + directory: /integrations/xormguard + schedule: + interval: weekly + + - package-ecosystem: gomod + directory: /integrations/entguard + schedule: + interval: weekly + + - package-ecosystem: gomod + directory: /parsers/pgparser + schedule: + interval: weekly + + - package-ecosystem: gomod + directory: /parsers/mysqlparser + schedule: + interval: weekly + + - package-ecosystem: github-actions + directory: / + schedule: + interval: weekly diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..90fdb30 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,109 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + go-version: ["1.26"] + steps: + - uses: actions/checkout@v6 + with: + persist-credentials: false + + - uses: actions/setup-go@v6 + with: + go-version: ${{ matrix.go-version }} + + - name: Run tests + run: go test ./... -count=1 -race + + - name: Test integrations (gormguard) + run: cd integrations/gormguard && go test ./... -count=1 -race + + - name: Test integrations (sqlxguard) + run: cd integrations/sqlxguard && go test ./... -count=1 -race + + - name: Test integrations (pgxguard) + run: cd integrations/pgxguard && go test ./... -count=1 -race + + - name: Test integrations (bunguard) + run: cd integrations/bunguard && go test ./... -count=1 -race + + - name: Test integrations (xormguard) + run: cd integrations/xormguard && go test ./... -count=1 -race + + - name: Test integrations (entguard) + run: cd integrations/entguard && go test ./... -count=1 -race + + - name: Test parsers (pgparser) + run: cd parsers/pgparser && go test ./... -count=1 -race + + - name: Test parsers (mysqlparser) + run: cd parsers/mysqlparser && go test ./... -count=1 -race + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + with: + persist-credentials: false + + - uses: actions/setup-go@v6 + with: + go-version: "1.26" + + - uses: golangci/golangci-lint-action@v9 + with: + version: v2.11 + args: --timeout=5m + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + with: + persist-credentials: false + + - uses: actions/setup-go@v6 + with: + go-version: "1.26" + + - name: Build CLI + run: go build -o bin/sqlguard ./cmd/sqlguard + + coverage: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + with: + persist-credentials: false + + - uses: actions/setup-go@v6 + with: + go-version: "1.26" + + # `make coverage` runs every module and merges into a single coverage.out + # (root go test does not reach the satellite modules). + - name: Generate merged coverage + run: make coverage + + - name: Upload to Codecov + uses: codecov/codecov-action@v5 + with: + files: ./coverage.out + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000..7d5a6da --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,63 @@ +name: CodeQL + +on: + push: + branches: [main] + pull_request: + branches: [main] + schedule: + # Weekly re-scan so newly published CodeQL queries flag old code too. + - cron: "0 6 * * 1" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.event_name != 'schedule' }} + +permissions: + # CodeQL requires security-events: write to upload SARIF results + security-events: write + contents: read + +jobs: + analyze: + name: Analyze (Go) + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v6 + with: + persist-credentials: false + + - name: Setup Go + uses: actions/setup-go@v6 + with: + go-version: "1.26" + + - name: Initialize CodeQL + uses: github/codeql-action/init@v4 + with: + languages: go + # Build the modules ourselves (below) so the tracer sees all nine. + build-mode: manual + queries: security-extended + + # Each integration/parser carries its own go.mod (heavy deps kept opt-in), + # so `go build ./...` from root does NOT reach them. Build every module + # under the CodeQL tracer so all nine are analyzed — same MODULES loop the + # Makefile uses; a satellite must not silently skip scanning. + - name: Build all modules + run: | + set -e + for mod in . \ + ./integrations/gormguard ./integrations/sqlxguard \ + ./integrations/pgxguard ./integrations/bunguard \ + ./integrations/xormguard ./integrations/entguard \ + ./parsers/pgparser ./parsers/mysqlparser; do + echo "==> Building $mod" + (cd "$mod" && go build ./...) + done + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v4 + with: + category: "/language:go" diff --git a/.gitignore b/.gitignore index aaadf73..1151f4d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ *.dll *.so *.dylib +bin # Test binary, built with `go test -c` *.test @@ -17,6 +18,9 @@ coverage.* *.coverprofile profile.cov +# FE +sqlguard-website + # Dependency directories (remove the comment below to include it) # vendor/ diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..4d5beff --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,33 @@ +version: "2" + +linters: + enable: + - errcheck + - govet + - staticcheck + - unused + - ineffassign + - misspell + - gocritic + - gocyclo + - revive + - prealloc + settings: + gocyclo: + min-complexity: 15 + revive: + rules: + - name: exported + exclusions: + rules: + - linters: + - errcheck + path: _test\.go + - linters: + - errcheck + path: examples/ + +formatters: + enable: + - gofmt + - goimports diff --git a/.sqlguard.example.yml b/.sqlguard.example.yml new file mode 100644 index 0000000..556d714 --- /dev/null +++ b/.sqlguard.example.yml @@ -0,0 +1,73 @@ +# sqlguard configuration. Copy to `.sqlguard.yml` at your project root +# (sqlguard discovers it by walking up from the scanned/working directory +# until it hits the file or the git root). +# +# Every key is optional; omitting the file runs all rules at their defaults. + +version: 1 + +# strict: true turns "soft" problems (unknown keys, unknown rule names, +# invalid severities) into hard errors instead of warnings. Leave false so a +# config written for a newer sqlguard still loads on an older binary. +strict: false + +rules: + # Turn rules off entirely. + disable: + - orderby-without-limit + + # Whitelist mode: when non-empty, ONLY these rules run (disable is ignored). + # only: + # - delete-without-where + # - update-without-where + + # Override the reported severity per rule: info | warning | critical | off + # ("off" is equivalent to disabling the rule). + severity: + select-star: info + select-without-limit: "off" + + # Per-rule tunables. Keys are rule-specific. + settings: + leading-wildcard: + # Don't flag LIKE/ILIKE '%x%' style patterns whose searchable term is + # shorter than this many characters. + min-length: 3 + in-list-too-large: + # Flag IN (...) value lists with more than this many elements + # (default 100). Subquery INs are never counted. + max-length: 100 + large-offset: + # Flag a literal OFFSET above this (default 1000) — deep pagination. + # Parameterized offsets (OFFSET $1 / ?) can't be evaluated statically. + threshold: 1000 + +# Redact literal values (strings/numbers) out of Result.Query before it +# reaches any reporter/log. ON by default — leave it on so customer data in +# query literals never lands in your logs. Result.Fingerprint (a PII-free, +# value-free query identity) is emitted regardless. Set to false ONLY for +# local debugging where the query text is trusted. +redact: true + +# Runtime slow-query threshold (middleware). Go duration string. +slow-query: + threshold: 200ms + +# Runtime de-duplication of repeated static findings (middleware). The same +# finding (rule + query fingerprint) is reported at most once per window, so a +# recurring query doesn't flood your logs. Default 1m. Set "0" to disable +# (report every occurrence). Slow-query and N+1 have their own emission policy. +dedup: + window: 1m + +# Static scanner only: skip files whose path matches any of these regexes. +scan: + exclude-paths: + - "(^|/)legacy/" + - "_gen\\.go$" + +# Inline suppressions (no config needed): +# In SQL: SELECT * FROM t -- sqlguard:ignore +# DELETE FROM t /* sqlguard:ignore:delete-without-where */ +# In Go: // sqlguard:ignore (on or above the db call) +# db.Query(q) // sqlguard:ignore:select-star diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..31cfd92 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,61 @@ +# Changelog + +All notable changes to this project are documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +Each Go module in this repo (root, `integrations/*`, `parsers/*`) is tagged with +the same version in lockstep. + +## [Unreleased] + +## [0.1.0] - 2026-06-08 + +Initial public release. + +### Added + +- **Runtime middleware** that intercepts at the `database/sql` **driver** layer + (`Register` / `OpenDB`), so any query — including those issued by ORMs and + query builders — is analyzed and you get back a real `*sql.DB`. Zero + third-party dependencies in the core. +- **Analyzer with 21 detection rules** across static, runtime, and EXPLAIN + surfaces: `select-star`, `leading-wildcard`, `non-sargable-predicate`, + `add-not-null-without-default`, `implicit-join`, `cartesian-join`, + `in-list-too-large`, `large-offset`, `select-distinct`, `delete-without-where`, + `update-without-where`, `insert-without-columns`, `select-without-limit`, + `orderby-without-limit`, `n-plus-one`, `slow-query`, `seq-scan`, + `full-table-scan`, `high-cost`, `no-index-used`, `filesort`. +- **Redaction by default**: every `Result.Query` is redacted (literals → `?`) + before it leaves the process, and every `Result.Fingerprint` is a PII-free, + low-cardinality query identity safe as a metric label. Opt out with + `WithRawQuery()` / `redact: false`. +- **N+1 detection** (windowed) and **slow-query detection** with configurable + thresholds. +- **Finding de-duplication** — each finding (rule + fingerprint) is reported at + most once per window (default 1m) to keep hot queries from flooding logs + (`WithFindingDedup`). +- **Per-query analysis cache** — an LRU keyed on the exact query string so + recurring queries are parsed and checked once (`WithAnalysisCacheSize`). +- **Pluggable parser**: a zero-dependency, never-erroring `FallbackParser` by + default, with opt-in real grammars in separate modules — `parsers/pgparser` + (PostgreSQL) and `parsers/mysqlparser` (MySQL) — via `WithParser`. +- **File configuration** (`.sqlguard.yml`, discovered up to the git root): + enable/disable rules, `only` whitelist, per-rule severity overrides, per-rule + settings, `redact`, `slow-query`, `dedup`, and scanner `exclude-paths`. + Lenient by default; `strict: true` makes unknown keys/rules fatal. +- **Inline suppressions** — in-SQL `-- sqlguard:ignore[:rules]` (honored at + runtime and statically) and Go-source `// sqlguard:ignore[:rules]` (honored by + the scanner). +- **CLI** (`cmd/sqlguard`): `scan` for static analysis of Go source (with + literal/constant resolution via `go/types`) and `explain` for live EXPLAIN + plan analysis. `explain` never executes the statement — it validates input and + runs inside an always-rolled-back read-only transaction. +- **ORM / driver integrations**, each a separate opt-in module built on the + shared `middleware.Guard` core (redaction, fingerprints, parser seam, + slow-query, N+1, and a `ResetN1()` per-request hook): `integrations/gormguard`, + `integrations/sqlxguard`, `integrations/pgxguard` (native pgx / pgxpool), + `integrations/bunguard`, `integrations/xormguard`, `integrations/entguard`. + +[Unreleased]: https://github.com/KARTIKrocks/sqlguard/compare/v0.1.0...HEAD +[0.1.0]: https://github.com/KARTIKrocks/sqlguard/releases/tag/v0.1.0 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..44f70e9 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,96 @@ +# Contributing to sqlguard + +Thanks for your interest in improving sqlguard! This guide covers the +project-specific things that aren't obvious from a quick look at the repo. + +## Project layout + +sqlguard is a **multi-module repo** — nine Go modules on Go 1.26, kept in +lockstep: + +- **root** (`github.com/KARTIKrocks/sqlguard`) — core analyzer, middleware, + reporter, config, and CLI. Deliberately near-zero-dependency. +- **`parsers/pgparser`, `parsers/mysqlparser`** — opt-in real SQL grammars, + isolated so their heavy parser deps never enter a consumer's build. +- **`integrations/{gormguard,sqlxguard,pgxguard,bunguard,xormguard,entguard}`** — + ORM/driver adapters, each a separate module so its deps stay opt-in. + +The satellite modules use a local `replace` directive pointing at the root, so +you can develop across modules without publishing. + +> **Important:** `go test ./...` (and `go build` / `go vet` / `go mod tidy`) +> from the root does **not** reach the satellite modules. Always use the +> Makefile targets, which loop over every module. + +## Development workflow + +```bash +make setup # install pinned golangci-lint + goimports (one-time) +make all # tidy, fmt, vet, lint, build, test across all nine modules +make ci # what CI runs: fmt-check, vet, lint, test-race +make test-race # race detector (required for anything touching middleware) +make help # list every target +``` + +Before opening a PR, run `make ci` and make sure it's green. + +- Run a single test: `go test ./middleware/ -run TestName -count=1`. +- Use `-race` for anything touching `middleware` (the driver chain and + `QueryTracker` are concurrent). +- After any dependency change, run `make tidy` (tidies all nine modules — tidying + only the root leaves the others stale). + +## Conventions + +- **Pre-1.0, no backward-compatibility burden.** Prefer the clean design over + preserving an existing public API; don't add deprecation shims or compat + layers. +- Modern Go idioms are expected (range-over-int, `any`, compile-time interface + asserts `var _ I = (*T)(nil)`). +- Keep the **core dependency-light**: `analyzer`, `middleware`, and `reporter` + must stay free of third-party deps and of YAML. `config` is the only + YAML-aware package. +- **Redaction is the default.** Never let raw literal values reach a `Result` + that leaves the process. There is one canonical normalizer (`analyzer.Redact` + / `Fingerprint`) — don't add a second. +- See [`CLAUDE.md`](CLAUDE.md) for the deeper architecture notes and invariants, + and [`PRODUCTION_READINESS.md`](PRODUCTION_READINESS.md) for the roadmap. + +## Adding a detection rule + +Rules self-register. Write the rule, then add one `analyzer.Register(RuleSpec{ +... })` call in `analyzer/rules.go` (a stable name, default severity, and a +settings-aware factory). Being addressable by name is what makes enable/disable, +severity overrides, per-rule settings, and suppressions all work uniformly — +do **not** hand-maintain a rule list. Rules read the normalized `Statement`, +never raw SQL. + +If your rule has a tunable, read it from `Settings` in the factory and document +it in [`.sqlguard.example.yml`](.sqlguard.example.yml). + +## Adding an integration + +Every integration must build on the exported `middleware.Guard` core — +`integrations/pgxguard` is the reference. Hand-rolling analysis silently loses +redaction, fingerprints, the parser seam, config, N+1, and dedup. Each +integration should expose `ResetN1()` for per-request scoping. + +## Pull requests + +1. Fork and branch from `main`. +2. Keep changes focused; update docs (`README.md`, `CLAUDE.md`, + `.sqlguard.example.yml`) when behavior or config changes. +3. Add tests for new behavior; where practical, also prove the failure mode + (e.g. a bug-reintroduction check). +4. Add a line under `## [Unreleased]` in [`CHANGELOG.md`](CHANGELOG.md). +5. Run `make ci` and ensure it passes. + +## Reporting security issues + +Please do **not** open a public issue for security vulnerabilities. See +[`SECURITY.md`](SECURITY.md) for the private reporting process. + +## License + +By contributing, you agree that your contributions are licensed under the +project's [MIT License](LICENSE). diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..d205223 --- /dev/null +++ b/Makefile @@ -0,0 +1,194 @@ +GOLANGCI_LINT_VERSION := v2.12.2 +GOIMPORTS_VERSION := v0.45.0 + +MODULE_PATH := github.com/KARTIKrocks/sqlguard + +# Sub-modules carry their own go.mod (heavy/opt-in deps kept out of the core +# import graph). `go test ./...` from root does NOT reach them, so every +# all-modules target loops over MODULES. +SUB_MODULES = \ + ./integrations/gormguard \ + ./integrations/sqlxguard \ + ./integrations/pgxguard \ + ./integrations/bunguard \ + ./integrations/xormguard \ + ./integrations/entguard \ + ./parsers/pgparser \ + ./parsers/mysqlparser +MODULES = . $(SUB_MODULES) + +.PHONY: all help setup deps ci test test-v test-race coverage lint lint-fix fix fmt fmt-check vet tidy build cli install bench clean release-prep + +all: tidy fmt vet lint build test + +## Show available targets +help: + @echo "Available targets:" + @echo " all - Tidy, format, vet, lint, build, test (all modules)" + @echo " setup - Install development tools" + @echo " deps - Download module dependencies (all modules)" + @echo " ci - CI pipeline (fmt-check, vet, lint, test-race)" + @echo " test - Run tests across all modules" + @echo " test-v - Run tests with verbose output (all modules)" + @echo " test-race - Run tests with race detector (all modules)" + @echo " coverage - Run tests with merged coverage report (all modules)" + @echo " vet - Run go vet (all modules)" + @echo " lint - Run golangci-lint (all modules)" + @echo " lint-fix - Run golangci-lint with --fix (root module)" + @echo " fix - fmt + lint-fix" + @echo " fmt - Format code (gofmt -s + goimports)" + @echo " fmt-check - Verify formatting without modifying files" + @echo " tidy - Run go mod tidy (all modules)" + @echo " build - Build all packages (all modules)" + @echo " cli - Build the sqlguard CLI to bin/sqlguard" + @echo " install - Install the CLI to \$$GOPATH/bin" + @echo " bench - Run benchmarks (all modules)" + @echo " clean - Remove build/coverage artifacts" + @echo " release-prep - Pin sub-modules to a release version (VERSION=vX.Y.Z)" + +## Install development tools (skips if already present) +setup: + @command -v golangci-lint >/dev/null 2>&1 || { \ + echo "Installing golangci-lint $(GOLANGCI_LINT_VERSION)..."; \ + go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@$(GOLANGCI_LINT_VERSION); \ + } + @command -v goimports >/dev/null 2>&1 || { \ + echo "Installing goimports $(GOIMPORTS_VERSION)..."; \ + go install golang.org/x/tools/cmd/goimports@$(GOIMPORTS_VERSION); \ + } + +## Download module dependencies across all modules +deps: + @for mod in $(MODULES); do \ + echo "==> Downloading deps $$mod"; \ + (cd $$mod && go mod download) || exit 1; \ + done + +## CI: run formatting check, vet, lint and tests with race detector +ci: fmt-check vet lint test-race + +## Build all packages across all modules (compile check) +build: + @for mod in $(MODULES); do \ + echo "==> Building $$mod"; \ + (cd $$mod && go build ./...) || exit 1; \ + done + +## Build the CLI binary +cli: + @echo "==> Building bin/sqlguard" + @go build -o bin/sqlguard ./cmd/sqlguard + +## Install the CLI to $GOPATH/bin +install: + go install ./cmd/sqlguard + +## Run tests across all modules +test: + @for mod in $(MODULES); do \ + echo "==> Testing $$mod"; \ + (cd $$mod && go test -count=1 ./...) || exit 1; \ + done + +## Run tests with verbose output across all modules +test-v: + @for mod in $(MODULES); do \ + echo "==> Testing (verbose) $$mod"; \ + (cd $$mod && go test -v -count=1 ./...) || exit 1; \ + done + +## Run tests with race detector across all modules +test-race: + @for mod in $(MODULES); do \ + echo "==> Testing (race) $$mod"; \ + (cd $$mod && go test -race -count=1 ./...) || exit 1; \ + done + +## Run tests with coverage and generate a merged report across all modules +coverage: + @echo "mode: atomic" > coverage.out + @for mod in $(MODULES); do \ + echo "==> Coverage $$mod"; \ + (cd $$mod && go test -race -covermode=atomic -coverprofile=cover.tmp ./...) || exit 1; \ + if [ -f $$mod/cover.tmp ]; then tail -n +2 $$mod/cover.tmp >> coverage.out && rm $$mod/cover.tmp; fi; \ + done + @go tool cover -func=coverage.out | tail -1 + @echo "Full report: go tool cover -html=coverage.out" + +## Run linter across all modules +lint: setup + @for mod in $(MODULES); do \ + echo "==> Linting $$mod"; \ + (cd $$mod && golangci-lint run --timeout=5m ./...) || exit 1; \ + done + +## Run golangci-lint with auto-fix (root module) +lint-fix: setup + golangci-lint run --fix ./... + +## Fix code formatting and linting issues +fix: fmt lint-fix + +## Format code (recurses the whole tree, all modules) +fmt: setup + @gofmt -s -w . + @goimports -w . + +## Check formatting without modifying files (used in CI) +fmt-check: setup + @test -z "$$(gofmt -s -l . | tee /dev/stderr)" || { echo "Unformatted files found. Run 'make fmt'."; exit 1; } + @test -z "$$(goimports -l . | tee /dev/stderr)" || { echo "Unordered imports found. Run 'make fmt'."; exit 1; } + +## Run go vet across all modules +vet: + @for mod in $(MODULES); do \ + echo "==> Vetting $$mod"; \ + (cd $$mod && go vet ./...) || exit 1; \ + done + +## Run go mod tidy across all modules +tidy: + @for mod in $(MODULES); do \ + echo "==> Tidying $$mod"; \ + (cd $$mod && go mod tidy) || exit 1; \ + done + +## Run benchmarks across all modules +bench: + @for mod in $(MODULES); do \ + echo "==> Benchmarking $$mod"; \ + (cd $$mod && go test -bench=. -benchmem -run='^$$' ./...) || exit 1; \ + done + +## Remove build and coverage artifacts +clean: + @rm -f coverage*.out cover.tmp coverage.txt coverage.html + @find . -name cover.tmp -delete 2>/dev/null || true + @rm -rf dist/ build/ bin/ + +## Prepare sub-modules for release: drop the local replace and pin the parent +## version. Usage: make release-prep VERSION=v0.1.0 +## Run this AFTER the root module tag for VERSION exists and is pushed, then +## commit and tag the sub-modules. Restore replaces afterwards for local dev +## (git checkout -- '**/go.mod') or develop against the published version. +release-prep: +ifndef VERSION + $(error VERSION is required. Usage: make release-prep VERSION=v0.1.0) +endif + @for mod in $(SUB_MODULES); do \ + echo "==> release-prep $$mod"; \ + (cd $$mod && go mod edit -dropreplace $(MODULE_PATH) -require $(MODULE_PATH)@$(VERSION)) || exit 1; \ + done + @echo "" + @echo "Done! Sub-modules now require $(MODULE_PATH)@$(VERSION) (replace dropped)." + @echo "Next steps (root tag $(VERSION) must already be pushed):" + @echo " git add -A && git commit -m 'Prepare release $(VERSION)'" + @echo " git tag integrations/gormguard/$(VERSION)" + @echo " git tag integrations/sqlxguard/$(VERSION)" + @echo " git tag integrations/pgxguard/$(VERSION)" + @echo " git tag integrations/bunguard/$(VERSION)" + @echo " git tag integrations/xormguard/$(VERSION)" + @echo " git tag integrations/entguard/$(VERSION)" + @echo " git tag parsers/pgparser/$(VERSION)" + @echo " git tag parsers/mysqlparser/$(VERSION)" + @echo " git push origin main --tags" diff --git a/README.md b/README.md new file mode 100644 index 0000000..d4f2b00 --- /dev/null +++ b/README.md @@ -0,0 +1,516 @@ +# sqlguard + +[![Go Reference](https://pkg.go.dev/badge/github.com/KARTIKrocks/sqlguard.svg)](https://pkg.go.dev/github.com/KARTIKrocks/sqlguard) +[![Go Report Card](https://goreportcard.com/badge/github.com/KARTIKrocks/sqlguard)](https://goreportcard.com/report/github.com/KARTIKrocks/sqlguard) +[![Go Version](https://img.shields.io/github/go-mod/go-version/KARTIKrocks/sqlguard)](go.mod) +[![CI](https://github.com/KARTIKrocks/sqlguard/actions/workflows/ci.yml/badge.svg)](https://github.com/KARTIKrocks/sqlguard/actions/workflows/ci.yml) +[![CodeQL](https://github.com/KARTIKrocks/sqlguard/actions/workflows/codeql.yml/badge.svg)](https://github.com/KARTIKrocks/sqlguard/actions/workflows/codeql.yml) +[![GitHub tag](https://img.shields.io/github/v/tag/KARTIKrocks/sqlguard)](https://github.com/KARTIKrocks/sqlguard/releases) +[![License](https://img.shields.io/github/license/KARTIKrocks/sqlguard)](LICENSE) +[![codecov](https://codecov.io/gh/KARTIKrocks/sqlguard/branch/main/graph/badge.svg)](https://codecov.io/gh/KARTIKrocks/sqlguard) + +Production-safe SQL query analyzer for Go applications. + +Detects slow queries, dangerous SQL patterns, and performance issues — both at runtime and statically. Think of it as `golangci-lint` for SQL queries. + +## Install + +```bash +go get github.com/KARTIKrocks/sqlguard +``` + +CLI tool: + +```bash +go install github.com/KARTIKrocks/sqlguard/cmd/sqlguard@latest +``` + +## Detection Rules + +| Rule | Severity | Description | +| ------------------------------ | -------- | ------------------------------------------------------------------------------------------------ | +| `select-star` | WARNING | `SELECT *` — selects all columns unnecessarily | +| `leading-wildcard` | WARNING | `LIKE '%...'` (and `ILIKE`) — index cannot be used | +| `non-sargable-predicate` | WARNING | `WHERE LOWER(col) = ...` — function on column defeats its index | +| `add-not-null-without-default` | WARNING | `ALTER TABLE ... ADD COLUMN ... NOT NULL` without `DEFAULT` — fails / rewrites a populated table | +| `implicit-join` | WARNING | `FROM a, b` — comma join; a forgotten condition becomes a cartesian product | +| `cartesian-join` | WARNING | Multiple tables with no join condition or `WHERE` — a cartesian product (incl. `CROSS JOIN`) | +| `in-list-too-large` | WARNING | `IN (...)` value list with more than `max-length` (default 100) elements | +| `large-offset` | WARNING | `OFFSET` above `threshold` (default 1000) — deep pagination scans/discards skipped rows | +| `select-distinct` | INFO | `SELECT DISTINCT` — often masks duplicate rows from an unintended join | +| `delete-without-where` | CRITICAL | `DELETE` without `WHERE` — deletes all rows | +| `update-without-where` | CRITICAL | `UPDATE` without `WHERE` — updates all rows | +| `insert-without-columns` | WARNING | `INSERT` without an explicit column list (`VALUES` or `... SELECT`) — breaks on schema change | +| `select-without-limit` | WARNING | `SELECT` without `LIMIT` or `WHERE` — may return excessive rows | +| `orderby-without-limit` | INFO | `ORDER BY` without `LIMIT` — sorts entire result set | +| `n-plus-one` | WARNING | Same query pattern repeated N times (runtime only) | +| `slow-query` | WARNING | Query exceeds latency threshold (runtime only) | +| `seq-scan` | WARNING | Sequential scan detected via EXPLAIN (postgres) | +| `full-table-scan` | WARNING | Full table scan detected via EXPLAIN (mysql) | +| `high-cost` | WARNING | High cost operation in query plan | +| `no-index-used` | WARNING | No index used for a table access detected via EXPLAIN (mysql) | +| `filesort` | INFO | `Using filesort` in the query plan — `ORDER BY` not covered by an index (mysql) | + +## Configuration + +Drop a `.sqlguard.yml` at your project root. sqlguard discovers it by walking +up from the scanned (or working) directory until it finds the file or the git +root. The CLI takes `--config ` and `--no-config`; the file is optional +— without it every rule runs at its default. A fully-commented template lives +at [`.sqlguard.example.yml`](.sqlguard.example.yml). + +```yaml +version: 1 +rules: + disable: [orderby-without-limit] + severity: + select-star: info # info | warning | critical | off + select-without-limit: "off" # "off" disables the rule + settings: + leading-wildcard: + min-length: 3 # ignore short LIKE '%x%' patterns + in-list-too-large: + max-length: 100 # flag IN (...) lists longer than this + large-offset: + threshold: 1000 # flag literal OFFSET above this +redact: true # redact literals out of Result.Query (default) +slow-query: + threshold: 200ms # runtime middleware threshold +dedup: + window: 1m # report each repeated finding at most once per window ("0" disables) +scan: + exclude-paths: ["(^|/)legacy/"] # static scanner only, regex +``` + +Unknown keys and rule names are warnings, not errors, so a config written for +a newer sqlguard still loads on an older binary; set `strict: true` to make +them fatal. `only: [rule, ...]` switches to whitelist mode. + +**Inline suppressions** — no config required: + +```sql +SELECT * FROM users -- sqlguard:ignore +DELETE FROM users /* sqlguard:ignore:delete-without-where */ +``` + +```go +// sqlguard:ignore +db.Exec("DELETE FROM users") +db.Query("SELECT * FROM users") // sqlguard:ignore:select-star +``` + +In-SQL directives work at runtime _and_ in the static scanner; the Go-source +form is honored by the scanner when it sits on or directly above the call. + +Apply the same config to the runtime middleware: + +```go +opts, _ := config.Middleware("", ".") // discover from cwd +sqlguard.Register("sqlguard-pg", "pgx", opts...) +``` + +## Security & redaction + +sqlguard's findings flow into logs, so by **default it never emits raw +literal values**. Before any `Result` leaves the process its `Query` is +redacted — single-quoted strings and numeric literals become `?`, while +keywords, identifiers (including `"quoted"` / `` `backtick` `` names) and +structure are preserved: + +``` +[SQLGUARD WARNING] select-star + Query: SELECT * FROM users WHERE email = ? +``` + +Every `Result` also carries a `Fingerprint`: the redacted query with +whitespace collapsed and `IN (?, ?, ?)` folded to `(?)`. It is a stable, +PII-free, low-cardinality identity — safe as a metrics label or log key, and +the same value the N+1 detector groups on. The JSON reporter emits it as +`fingerprint`. + +Opt out only where the query text is trusted (local debugging): + +```go +a := analyzer.Default().WithRawQuery() // standalone analyzer +sqlguard.Register("pg", "pgx", middleware.WithAnalyzer(a)) +``` + +or `redact: false` in `.sqlguard.yml`. `Fingerprint` is populated either way. + +## Usage + +### Runtime Middleware + +sqlguard wraps at the `database/sql` **driver** layer, so you get back a real +`*sql.DB` and every query is analyzed automatically — including queries issued +by ORMs and query builders (sqlc, ent, sqlx, gorm, pgx-stdlib). There is no +wrapper type to thread through your code and no method list to keep in sync. + +```go +import ( + "database/sql" + "github.com/KARTIKrocks/sqlguard" + "github.com/KARTIKrocks/sqlguard/middleware" + "time" +) + +func main() { + // Register an analyzed driver by wrapping an existing one... + sqlguard.Register("sqlguard-pg", "postgres", + middleware.WithSlowQueryThreshold(500*time.Millisecond), + middleware.WithN1Detection(5, 2*time.Second), + ) + db, _ := sql.Open("sqlguard-pg", "...") // db is a plain *sql.DB + + // ...or wrap a driver.Connector directly (e.g. pgx stdlib): + // db := sqlguard.OpenDB(connector, middleware.WithN1Detection(5, time.Second)) + + // Use as normal — warnings are logged automatically + db.Query("SELECT * FROM users") + // Output: + // [SQLGUARD WARNING] select-star + // Query: SELECT * FROM users + // Issue: SELECT * detected. Selecting all columns can hurt performance. + // Fix: Select only the columns you need. +} +``` + +### N+1 Query Detection + +The middleware detects when the same query pattern executes repeatedly — a classic N+1 problem: + +```go +sqlguard.Register("sqlguard-pg", "postgres", + middleware.WithN1Detection(5, 2*time.Second), // flag after 5 similar queries in 2s +) +db, _ := sql.Open("sqlguard-pg", "...") +``` + +N+1 patterns are detected within the configured time window. On the raw +`database/sql` driver path you get back a plain `*sql.DB`, so detection is +process-wide (windowed) — there is no handle to scope it per request. The +integration adapters (`gormguard`, `pgxguard`, `sqlxguard`, `bunguard`, +`xormguard`, `entguard`) hold the guard and expose `ResetN1()` to scope +detection to a single unit of work; call it at a request boundary. + +### Noise control (finding de-duplication) + +A recurring query would otherwise re-emit the same static warning on every +execution. By default the runtime middleware reports each finding (rule + query +fingerprint) **at most once per minute**, so a hot query doesn't flood your +logs. Tune or disable it: + +```go +sqlguard.Register("sqlguard-pg", "postgres", + middleware.WithFindingDedup(5*time.Minute), // quieter +) +sqlguard.Register("sqlguard-pg", "postgres", + middleware.WithFindingDedup(0), // disable: report every occurrence +) +``` + +Or set `dedup.window` in `.sqlguard.yml`. Slow-query and N+1 findings have +their own emission policy and are unaffected. + +The middleware also memoizes analysis per distinct query string (an LRU keyed on +the exact query — correct even for the literal-sensitive rules), so a recurring +query is parsed and rule-checked once rather than on every execution. A repeated +query then costs a cache lookup instead of a full parse (≈1000× cheaper, zero +allocations in the repeat case). Default 1024 entries; tune with +`middleware.WithAnalysisCacheSize(n)` or disable with `n == 0`. + +### CLI Static Scanner + +Scan your Go source code for SQL issues without running the application: + +```bash +# Scan current directory +sqlguard scan . + +# Scan specific package +sqlguard scan ./internal/repository + +# JSON output (for CI pipelines) +sqlguard scan --format json ./... +``` + +Exit code is **1** when issues are found, **0** when clean — works with CI/CD pipelines. + +### EXPLAIN Plan Analyzer + +Connect to a live database and analyze query plans: + +```bash +# PostgreSQL +sqlguard explain --db "postgres://user:pass@localhost/mydb?sslmode=disable" \ + "SELECT * FROM orders WHERE user_id = 42" + +# MySQL +sqlguard explain --dialect mysql --db "user:pass@tcp(localhost:3306)/mydb" \ + "SELECT * FROM orders WHERE user_id = 42" + +# JSON output +sqlguard explain --db "..." --format json "SELECT * FROM orders" +``` + +Detects sequential scans, missing indexes, filesort, and high-cost operations. + +For safety the EXPLAIN runs inside a **read-only transaction that is always +rolled back** (Postgres and MySQL), and `ANALYZE` is never used — the +statement is planned, never executed. Input is validated with a +comment- and string-literal-aware multi-statement check (a `;` hidden in a +comment or string can't smuggle a second statement). Only `SELECT`/`WITH` is +allowed by default; pass `--allow-dml` to EXPLAIN an `INSERT/UPDATE/DELETE` +(still rolled back). DDL/`SET`/transaction-control is always refused. + +### GORM Integration + +```bash +go get github.com/KARTIKrocks/sqlguard/integrations/gormguard +``` + +```go +import ( + "github.com/KARTIKrocks/sqlguard/integrations/gormguard" + "github.com/KARTIKrocks/sqlguard/middleware" +) + +gormDB, _ := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + +// Register as GORM plugin — hooks into all queries automatically +gormguard.Register(gormDB) + +// Or customize via the standard middleware options +gormguard.Register(gormDB, + middleware.WithSlowQueryThreshold(500*time.Millisecond), + middleware.WithN1Detection(10, time.Second), +) +``` + +### sqlx Integration + +```bash +go get github.com/KARTIKrocks/sqlguard/integrations/sqlxguard +``` + +```go +import ( + "github.com/KARTIKrocks/sqlguard/integrations/sqlxguard" + "github.com/KARTIKrocks/sqlguard/middleware" +) + +sqlxDB := sqlx.MustConnect("postgres", dsn) + +db := sqlxguard.WrapSqlx(sqlxDB, + middleware.WithSlowQueryThreshold(500*time.Millisecond), +) + +var users []User +db.Select(&users, "SELECT * FROM users") // warns about SELECT * +``` + +### pgx Integration (native pgx / pgxpool) + +The `database/sql` driver wrapper covers pgx-stdlib (`pgx/v5/stdlib`). For the +**native pgx APIs** (`pgxpool.Pool`, `pgx.Conn` — which bypass `database/sql` +entirely) use `pgxguard`. It hooks pgx's own tracer seam, so every +`Query`/`QueryRow`/`Exec` and every `SendBatch` is analyzed without a wrapper +type or a method list. + +```bash +go get github.com/KARTIKrocks/sqlguard/integrations/pgxguard +``` + +```go +import ( + "github.com/KARTIKrocks/sqlguard/integrations/pgxguard" + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/jackc/pgx/v5/pgxpool" +) + +cfg, _ := pgxpool.ParseConfig(dsn) +pgxguard.ApplyPool(cfg, + middleware.WithSlowQueryThreshold(50*time.Millisecond), + middleware.WithN1Detection(10, time.Second), +) +pool, _ := pgxpool.NewWithConfig(ctx, cfg) +``` + +`Apply` (for `*pgx.ConnConfig`) and `ApplyPool` (for `*pgxpool.Config`) +**compose** with any tracer already installed via pgx's own `multitracer`, +so sqlguard coexists with `otelpgx`, `ddtrace` and friends rather than +silently overwriting them. Configuration is the standard `middleware.Option` +set — same as the driver wrapper, no parallel surface to learn. + +Coverage: `Query` / `QueryRow` / `Exec` (via `pgx.QueryTracer`) and +`SendBatch` (via `pgx.BatchTracer`). Prepared-statement execution is already +covered by `QueryTracer`, so `PrepareTracer` is deliberately omitted to avoid +double-reporting. `CopyFrom` carries no SQL and is out of scope. + +### bun / xorm Integrations + +bun and xorm build SQL through their own query layers and expose native +before/after hook seams. `bunguard` and `xormguard` plug into those seams and +run every statement through the same shared core — same `middleware.Option` +set, no parallel surface. + +```bash +go get github.com/KARTIKrocks/sqlguard/integrations/bunguard +go get github.com/KARTIKrocks/sqlguard/integrations/xormguard +``` + +```go +// bun — register a QueryHook +db.AddQueryHook(bunguard.New( + middleware.WithSlowQueryThreshold(500*time.Millisecond), + middleware.WithN1Detection(10, time.Second), +)) + +// xorm — register a Hook +engine.AddHook(xormguard.New( + middleware.WithSlowQueryThreshold(500*time.Millisecond), +)) +``` + +### ent Integration + +ent runs on `database/sql`, so the simplest coverage is to point `entsql` at a +`*sql.DB` from `sqlguard.Register`/`OpenDB`. `entguard` is the dedicated +alternative: it decorates ent's own `dialect.Driver`, so it covers every +`Exec`/`Query` (and transactions it opens) regardless of how the `*sql.DB` was +created. + +```bash +go get github.com/KARTIKrocks/sqlguard/integrations/entguard +``` + +```go +drv, _ := entsql.Open(dialect.Postgres, dsn) +guarded := entguard.Wrap(drv, + middleware.WithSlowQueryThreshold(500*time.Millisecond), + middleware.WithN1Detection(10, time.Second), +) +client := ent.NewClient(ent.Driver(guarded)) +``` + +Every adapter (`gormguard`, `bunguard`, `xormguard`, `entguard`, `pgxguard`, +`sqlxguard`) exposes a `ResetN1()` you can call at a per-request boundary to +scope N+1 detection to one unit of work. + +### SQL Parsers (accuracy vs. zero dependencies) + +By default the analyzer uses a **zero-dependency fallback parser**: it strips +SQL comments and string-literal contents before pattern matching, so keywords +inside comments/strings and identifiers like `update_at` no longer cause false +positives. It never errors — SQL it can't fully understand still yields a +best-effort result, so analysis never breaks your query path. + +For **exact, structural analysis**, opt into a real grammar. These live in +separate modules so the core stays dependency-free: + +```bash +go get github.com/KARTIKrocks/sqlguard/parsers/pgparser # PostgreSQL (pure Go, no cgo) +go get github.com/KARTIKrocks/sqlguard/parsers/mysqlparser # MySQL (pure Go, no cgo) +``` + +```go +import ( + "github.com/KARTIKrocks/sqlguard" + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/KARTIKrocks/sqlguard/parsers/pgparser" +) + +sqlguard.Register("sqlguard-pg", "pgx", middleware.WithParser(pgparser.New())) +db, _ := sql.Open("sqlguard-pg", dsn) + +// Or with the standalone analyzer: +a := analyzer.Default().WithParser(pgparser.New()) +``` + +A real parser drives the false-positive-prone facts (statement kind, +WHERE/LIMIT/ORDER BY/FROM presence, `SELECT *`, `SELECT DISTINCT`, `OFFSET`, +explicit INSERT columns) from the AST instead of regex. CTEs, subqueries, and +dialect syntax are handled correctly; anything the grammar rejects (dynamic SQL, +driver placeholders) transparently degrades to the fallback parser. + +A few facts stay lexical heuristics even with a real parser, because they read +literal values the AST discards or are intentionally text-level: IN-list size +(`in-list-too-large`), comma/cartesian joins (`implicit-join` / +`cartesian-join`), and the literal/text checks (`leading-wildcard`, +`non-sargable-predicate`, `add-not-null-without-default`). These keep their +zero-dependency, best-effort behavior regardless of the parser. + +### Custom Rules + +```go +import "github.com/KARTIKrocks/sqlguard/analyzer" + +// Create analyzer with only the rules you want +a := analyzer.New( + analyzer.CheckDeleteWithoutWhere, + analyzer.CheckUpdateWithoutWhere, +) + +// Or use all defaults +a := analyzer.Default() + +// Analyze a query +results := a.Analyze("DELETE FROM users") +for _, r := range results { + fmt.Printf("[%s] %s: %s\n", r.Severity, r.RuleName, r.Message) +} +``` + +## Development + +```bash +make help # List all targets +make all # tidy, fmt, vet, lint, build, test (all modules) +make build # Compile all modules; `make cli` builds bin/sqlguard +make test # Run tests across all modules (test-race adds -race) +make lint # Run golangci-lint across all modules +make fmt # gofmt -s + goimports +make tidy # go mod tidy across all modules +make install # Install the CLI to $GOPATH/bin +``` + +## Coverage + +The middleware wraps the `database/sql` **driver** chain, so _every_ query +is analyzed regardless of how it's issued (`Query`/`Exec`/`Prepare`/`Tx`, +context variants, and any ORM/query builder on top — sqlc, ent, sqlx, gorm, +pgx-stdlib). There is no method allowlist to keep in sync; you get back a +real `*sql.DB`. + +Opt-in adapter modules, each built on the same `middleware.Guard` core, +extend coverage to APIs that bypass or sit above the `database/sql` driver +path: + +- **`pgxguard`** — native pgx / pgxpool (which never goes through + `database/sql`), via pgx's own tracer seam. Composes with existing tracers + (otelpgx, ddtrace) via `multitracer`. Covers `Query`/`QueryRow`/`Exec` and + `SendBatch`. +- **`gormguard`** / **`bunguard`** / **`xormguard`** — hook each ORM's native + before/after callback seam (`gorm.Plugin`, `bun.QueryHook`, xorm + `contexts.Hook`). +- **`entguard`** — decorates ent's `dialect.Driver` (Exec/Query + the + transactions it opens). +- **`sqlxguard`** — sqlx-only helpers that build SQL outside the driver path: + `Select` / `SelectContext`, `Get` / `GetContext`, `Queryx`, `NamedExec` / + `NamedExecContext`. + +All six inherit redaction-by-default, stable fingerprints, the parser seam, +and slow-query/N+1 detection from the shared core, and expose `ResetN1()` for +per-request scoping. + +## Limitations + +- The static scanner resolves inline literals, same/cross-package constants, + constant concatenation, and `fmt.Sprintf` with a constant format string + (via `go/types`); it cannot resolve values only known at runtime. +- The default fallback parser is best-effort; for exact structural analysis use a real parser module (see _SQL Parsers_ above) +- EXPLAIN analyzer requires a live database connection; only Postgres and MySQL dialects are supported + +## License + +[MIT](LICENSE) diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..87fd9ce --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,56 @@ +# Security Policy + +## Supported versions + +sqlguard is pre-1.0. Security fixes are applied to the latest released minor +version. Until 1.0, only the most recent `0.x` release is supported. + +| Version | Supported | +| ------------ | --------- | +| latest `0.x` | ✅ | +| older | ❌ | + +## Reporting a vulnerability + +**Please do not report security vulnerabilities through public GitHub issues, +discussions, or pull requests.** + +Instead, use one of the following private channels: + +1. **GitHub private vulnerability reporting** (preferred) — open a report via + the repository's **Security → Report a vulnerability** tab + (`https://github.com/KARTIKrocks/sqlguard/security/advisories/new`). +2. **Email** — `kartik.rajput622001@gmail.com` with a subject line starting + `[sqlguard security]`. + +Please include: + +- the affected module(s) and version/commit, +- a description of the issue and its impact, +- steps to reproduce (a minimal repro or PoC is ideal), +- any suggested remediation. + +You can expect an acknowledgement within **5 business days**. We'll keep you +informed as we investigate and work on a fix, and we'll credit you in the +release notes / advisory unless you prefer to remain anonymous. + +## Scope and threat model + +sqlguard is a _defensive_ tool — it analyzes SQL for risky patterns and is +designed to fail safe. A few invariants are part of its security contract; bugs +that break them are in scope: + +- **Redaction by default.** A `Result` must never carry raw literal values out + of the process: `Result.Query` is redacted and `Result.Fingerprint` must be + PII-free. A path that leaks literals to a reporter/log is a security bug. +- **EXPLAIN never executes the statement.** The `explain` analyzer validates + input (comment/string-aware multi-statement rejection, `SELECT`/`WITH`-only by + default) and runs every plan inside an always-rolled-back, read-only + transaction. A way to make `explain` mutate data or run a second statement is + in scope. +- **The middleware must not alter query semantics or results.** It observes; + it must not change what the underlying driver returns. + +Out of scope: vulnerabilities in third-party dependencies (report those +upstream; we'll bump once fixed), and misuse such as deliberately disabling +redaction with `WithRawQuery()` / `redact: false`. diff --git a/analyzer/analyzer.go b/analyzer/analyzer.go new file mode 100644 index 0000000..5be8118 --- /dev/null +++ b/analyzer/analyzer.go @@ -0,0 +1,165 @@ +package analyzer + +import "maps" + +// Rule checks a normalized Statement and returns a Result if an issue is +// found. It returns the result and true if an issue was detected, or a zero +// Result and false otherwise. +// +// Rules operate on the parsed Statement, not the raw SQL string, so a query +// is parsed once per Analyze call and every rule sees the same dialect- +// agnostic view. +type Rule func(s *Statement) (Result, bool) + +// boundRule is a rule together with its registry name and the default +// severity from its RuleSpec. The name is "" for rules supplied directly via +// New (anonymous rules); profile overrides and suppressions only apply to +// named, registry-built rules. hasSeverity distinguishes a registry-built rule +// (whose severity is the spec's DefaultSeverity, the single source of truth) +// from an anonymous rule (which carries its own severity in the Result it +// returns); since SeverityInfo is the zero value, a flag is needed rather than +// a sentinel. +type boundRule struct { + name string + check Rule + severity Severity + hasSeverity bool +} + +// Analyzer holds a set of rules and a Parser, and runs the rules against +// SQL queries. Configuration (disabled rules, severity overrides, per-rule +// settings) is resolved once at construction into the bound rule set and the +// severity map; the per-query Analyze path does no config work. +type Analyzer struct { + rules []boundRule + parser Parser + severity map[string]Severity + // rawQuery, when true, leaves Result.Query unredacted. Default is false + // (redact): the safe default for a tool whose findings flow into logs. + rawQuery bool +} + +// New creates an Analyzer with the given anonymous rules, using the +// zero-dependency FallbackParser. Use WithParser to supply a real dialect +// parser. Rules added this way are not subject to profile overrides (they +// have no registry name); use Default/DefaultWithProfile for configurable +// built-in rules. +func New(rules ...Rule) *Analyzer { + bound := make([]boundRule, len(rules)) + for i, r := range rules { + bound[i] = boundRule{check: r} + } + return &Analyzer{rules: bound, parser: NewFallbackParser()} +} + +// WithParser returns a copy of the Analyzer that uses the given Parser. +// Passing nil resets it to the FallbackParser. +func (a *Analyzer) WithParser(p Parser) *Analyzer { + if p == nil { + p = NewFallbackParser() + } + cp := *a + cp.parser = p + return &cp +} + +// WithRawQuery returns a copy of the Analyzer that leaves Result.Query +// unredacted (the raw SQL, literals and all). Redaction is on by default so +// literal values never reach a log sink; opt out only for local debugging +// where the query text is trusted. Fingerprint is always populated either +// way. +func (a *Analyzer) WithRawQuery() *Analyzer { + cp := *a + cp.rawQuery = true + return &cp +} + +// PrepareQuery returns the query field and fingerprint for a Result built +// outside the rule path (e.g. the runtime slow-query and N+1 findings), +// applying the same redaction policy as Analyze so every emitted Result is +// consistent. display is redacted unless the Analyzer was built +// WithRawQuery; fingerprint is always the PII-free identity. +func (a *Analyzer) PrepareQuery(raw string) (display, fingerprint string) { + fingerprint = Fingerprint(raw) + if a.rawQuery { + return raw, fingerprint + } + return Redact(raw), fingerprint +} + +// Default creates an Analyzer with all registered built-in rules and the +// fallback parser, using each rule's default settings and severity. +func Default() *Analyzer { + return DefaultWithProfile(Profile{}) +} + +// DefaultWithProfile builds an Analyzer from the rule registry with the given +// Profile applied: disabled/whitelisted rules are filtered, per-rule settings +// are passed to each rule's factory, and severity overrides are precomputed. +// The config package uses this to turn a .sqlguard.yml into an Analyzer +// without analyzer ever importing config or YAML. +func DefaultWithProfile(p Profile) *Analyzer { + var bound []boundRule + for _, spec := range specs() { + if p.skip(spec.Name) { + continue + } + bound = append(bound, boundRule{ + name: spec.Name, + check: spec.Factory(p.Settings[spec.Name]), + severity: spec.DefaultSeverity, + hasSeverity: true, + }) + } + var sev map[string]Severity + if len(p.Severity) > 0 { + sev = make(map[string]Severity, len(p.Severity)) + maps.Copy(sev, p.Severity) + } + return &Analyzer{rules: bound, parser: NewFallbackParser(), severity: sev, rawQuery: p.RawQuery} +} + +// Analyze parses the query once and runs all rules against it. If the +// configured parser returns an error, it degrades to the FallbackParser so +// analysis never breaks the caller's query path. Findings for rules named in +// an in-SQL `sqlguard:ignore` directive are suppressed, and severity +// overrides from the active Profile are applied. +func (a *Analyzer) Analyze(query string) []Result { + stmt, err := a.parser.Parse(query) + if err != nil || stmt == nil { + stmt, _ = NewFallbackParser().Parse(query) + } + + ignoreAll, ignored := parseIgnoreDirective(query) + + display, fingerprint := a.PrepareQuery(query) + + results := make([]Result, 0, len(a.rules)) + for _, br := range a.rules { + if ignoreAll { + break + } + r, ok := br.check(stmt) + if !ok { + continue + } + if r.RuleName != "" && ignored[r.RuleName] { + continue + } + // Severity precedence: the spec's DefaultSeverity is the single source + // of truth for a registry-built rule (the rule body no longer sets + // one); a profile override, when present, wins over that. + if br.hasSeverity { + r.Severity = br.severity + } + if a.severity != nil { + if s, has := a.severity[r.RuleName]; has { + r.Severity = s + } + } + r.Query = display + r.Fingerprint = fingerprint + results = append(results, r) + } + return results +} diff --git a/analyzer/analyzer_test.go b/analyzer/analyzer_test.go new file mode 100644 index 0000000..76419c8 --- /dev/null +++ b/analyzer/analyzer_test.go @@ -0,0 +1,434 @@ +package analyzer + +import "testing" + +// run parses q with the fallback parser and applies a single rule, returning +// whether the rule fired. Rules operate on a parsed Statement now, so tests +// go through the parser the same way Analyze does. +func run(t *testing.T, rule Rule, q string) bool { + t.Helper() + st, err := NewFallbackParser().Parse(q) + if err != nil { + t.Fatalf("fallback parser returned error for %q: %v", q, err) + } + _, ok := rule(st) + return ok +} + +func TestCheckSelectStar(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"basic select star", "SELECT * FROM users", true}, + {"lowercase", "select * from users", true}, + {"with where", "SELECT * FROM users WHERE id = 1", true}, + {"qualified star", "SELECT u.* FROM users u", true}, + {"specific columns", "SELECT id, name FROM users", false}, + {"count star", "SELECT COUNT(*) FROM users", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckSelectStar, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckLeadingWildcard(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"leading wildcard", "SELECT * FROM users WHERE email LIKE '%gmail.com%'", true}, + {"trailing only", "SELECT * FROM users WHERE name LIKE 'John%'", false}, + {"double quotes", `SELECT * FROM users WHERE email LIKE "%gmail%"`, true}, + {"ilike leading wildcard", "SELECT * FROM users WHERE email ILIKE '%gmail%'", true}, + {"ilike trailing only", "SELECT * FROM users WHERE name ILIKE 'John%'", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckLeadingWildcard, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckDeleteWithoutWhere(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"no where", "DELETE FROM users", true}, + {"with where", "DELETE FROM users WHERE id = 1", false}, + {"not a delete", "SELECT * FROM users", false}, + {"where in string literal", "DELETE FROM logs WHERE msg = 'no WHERE clause'", false}, + {"fake where in string", "DELETE FROM users SET bio = 'I live WHERE the sun shines'", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckDeleteWithoutWhere, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckUpdateWithoutWhere(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"no where", "UPDATE users SET name = 'test'", true}, + {"with where", "UPDATE users SET name = 'test' WHERE id = 1", false}, + {"not an update", "SELECT * FROM users", false}, + {"where in string literal", "UPDATE users SET bio = 'I live WHERE the sun shines'", true}, + {"real where after string", "UPDATE users SET bio = 'hello' WHERE id = 1", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckUpdateWithoutWhere, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckInsertWithoutColumns(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"no columns", "INSERT INTO users VALUES ('alice', 'alice@test.com')", true}, + {"with columns", "INSERT INTO users (name, email) VALUES ('alice', 'alice@test.com')", false}, + {"not an insert", "SELECT * FROM users", false}, + {"insert select no columns", "INSERT INTO users SELECT name, email FROM staging", true}, + {"insert select with columns", "INSERT INTO users (name, email) SELECT name, email FROM staging", false}, + {"qualified table no columns", "INSERT INTO public.users VALUES ('alice')", true}, + {"mysql set form", "INSERT INTO users SET name = 'alice', email = 'a@test.com'", false}, + {"default values", "INSERT INTO users DEFAULT VALUES", false}, + {"cte insert no columns", "WITH s AS (SELECT 1) INSERT INTO users SELECT * FROM s", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckInsertWithoutColumns, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckSelectWithoutLimit(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"no limit no where", "SELECT id FROM users", true}, + {"with limit", "SELECT id FROM users LIMIT 10", false}, + {"with where", "SELECT id FROM users WHERE id = 1", false}, + {"with both", "SELECT id FROM users WHERE id > 0 LIMIT 10", false}, + {"not a select", "DELETE FROM users", false}, + {"select without from", "SELECT 1", false}, + {"select version", "SELECT version()", false}, + {"select current_timestamp", "SELECT CURRENT_TIMESTAMP", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckSelectWithoutLimit, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckOrderByWithoutLimit(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"order without limit", "SELECT id FROM users ORDER BY name", true}, + {"order with limit", "SELECT id FROM users ORDER BY name LIMIT 10", false}, + {"no order by", "SELECT id FROM users", false}, + {"window order by", "SELECT row_number() OVER (ORDER BY id) FROM users", false}, + {"ordered aggregate", "SELECT GROUP_CONCAT(x ORDER BY y) FROM t", false}, + {"window order by with top-level order by", "SELECT rank() OVER (ORDER BY a) FROM t ORDER BY b", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckOrderByWithoutLimit, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckNonSargablePredicate(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"lower on column", "SELECT id FROM users WHERE LOWER(email) = 'x'", true}, + {"date on column", "SELECT id FROM events WHERE DATE(created_at) = '2020-01-01'", true}, + {"cast on column", "SELECT id FROM users WHERE CAST(id AS text) = '5'", true}, + {"coalesce on column", "SELECT id FROM users WHERE COALESCE(deleted, false) = false", true}, + {"like on wrapped column", "SELECT id FROM users WHERE UPPER(name) LIKE 'A%'", true}, + {"function on value side", "SELECT id FROM users WHERE email = LOWER('X')", false}, + {"now on value side", "SELECT id FROM events WHERE created_at > NOW()", false}, + {"bare column", "SELECT id FROM users WHERE email = 'x'", false}, + {"in list not a function", "SELECT id FROM users WHERE id IN (1, 2, 3)", false}, + {"function in select list", "SELECT LOWER(name) FROM users", false}, + {"function in order by", "SELECT id FROM users WHERE active = true ORDER BY LOWER(name)", false}, + {"commented out predicate", "SELECT id FROM users -- WHERE LOWER(email) = 'x'", false}, + {"predicate after subquery clause", "SELECT id FROM users WHERE id IN (SELECT uid FROM o ORDER BY x LIMIT 1) AND LOWER(name) = 'a'", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckNonSargablePredicate, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckAddNotNullWithoutDefault(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"add not null no default", "ALTER TABLE users ADD COLUMN age int NOT NULL", true}, + {"add not null without column kw", "ALTER TABLE users ADD age int NOT NULL", true}, + {"numeric type with comma", "ALTER TABLE t ADD COLUMN bal numeric(10,2) NOT NULL", true}, + {"multi action one unsafe", "ALTER TABLE t ADD COLUMN a int NOT NULL, ADD COLUMN b int DEFAULT 5", true}, + {"not null with default", "ALTER TABLE users ADD COLUMN age int NOT NULL DEFAULT 0", false}, + {"default before not null", "ALTER TABLE users ADD COLUMN age int DEFAULT 0 NOT NULL", false}, + {"nullable column", "ALTER TABLE users ADD COLUMN age int", false}, + {"set not null on existing", "ALTER TABLE users ALTER COLUMN age SET NOT NULL", false}, + {"add check constraint is not null", "ALTER TABLE users ADD CONSTRAINT chk CHECK (age IS NOT NULL)", false}, + {"not an alter", "INSERT INTO users (age) VALUES (1)", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckAddNotNullWithoutDefault, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckImplicitJoin(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"two table comma join", "SELECT * FROM a, b WHERE a.id = b.id", true}, + {"three tables", "SELECT * FROM a, b, c WHERE a.id = b.id AND b.id = c.id", true}, + {"comma plus explicit join", "SELECT * FROM a, b JOIN c ON b.id = c.id", true}, + {"explicit join only", "SELECT * FROM a JOIN b ON a.id = b.id", false}, + {"single table", "SELECT * FROM users WHERE id = 1", false}, + {"select list comma not from", "SELECT id, name FROM users", false}, + {"comma inside in list", "SELECT * FROM users WHERE id IN (1, 2, 3)", false}, + {"comma inside function", "SELECT * FROM generate_series(1, 10)", false}, + {"comma inside subquery", "SELECT * FROM (SELECT a, b FROM t) sub", false}, + {"from inside extract", "SELECT EXTRACT(YEAR FROM created_at) FROM events", false}, + {"extract then comma join", "SELECT EXTRACT(YEAR FROM ts) FROM events, logs WHERE events.id = logs.id", true}, + {"no from", "SELECT 1, 2", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckImplicitJoin, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckCartesianJoin(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"comma join no where", "SELECT * FROM a, b", true}, + {"three tables no where", "SELECT * FROM a, b, c", true}, + {"explicit cross join", "SELECT * FROM a CROSS JOIN b", true}, + {"cross join with where", "SELECT * FROM a CROSS JOIN b WHERE a.x = 1", false}, + {"comma join with where", "SELECT * FROM a, b WHERE a.id = b.id", false}, + {"join with on", "SELECT * FROM a JOIN b ON a.id = b.id", false}, + {"join with using", "SELECT * FROM a JOIN b USING (id)", false}, + {"natural join", "SELECT * FROM a NATURAL JOIN b", false}, + {"single table", "SELECT * FROM users", false}, + {"subquery cross product", "SELECT * FROM (SELECT * FROM x WHERE y = 1) sub, t", true}, + {"cross join only in subquery", "SELECT x FROM (SELECT * FROM a CROSS JOIN b) sub", false}, + {"conditioned join only in subquery", "SELECT x FROM (SELECT * FROM a JOIN b ON a.id = b.id) sub", false}, + {"no from", "SELECT 1, 2", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckCartesianJoin, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckInListTooLarge(t *testing.T) { + rule := inListRule(5) // flag IN lists with more than 5 elements + tests := []struct { + name string + query string + wantHit bool + }{ + {"over threshold", "SELECT * FROM t WHERE id IN (1, 2, 3, 4, 5, 6)", true}, + {"at threshold", "SELECT * FROM t WHERE id IN (1, 2, 3, 4, 5)", false}, + {"under threshold", "SELECT * FROM t WHERE id IN (1, 2, 3)", false}, + {"not in over threshold", "SELECT * FROM t WHERE id NOT IN (1, 2, 3, 4, 5, 6)", true}, + {"placeholders over threshold", "SELECT * FROM t WHERE id IN (?, ?, ?, ?, ?, ?)", true}, + {"string literals over threshold", "SELECT * FROM t WHERE c IN ('a', 'b', 'c', 'd', 'e', 'f')", true}, + {"subquery not counted", "SELECT * FROM t WHERE id IN (SELECT id FROM other)", false}, + {"no in list", "SELECT * FROM t WHERE id = 5", false}, + {"function commas not an in list", "SELECT * FROM t WHERE x = greatest(1, 2, 3, 4, 5, 6, 7)", false}, + {"largest of multiple lists", "SELECT * FROM t WHERE a IN (1, 2) AND b IN (1, 2, 3, 4, 5, 6, 7)", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, rule, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckLargeOffset(t *testing.T) { + rule := largeOffsetRule(1000) + tests := []struct { + name string + query string + wantHit bool + }{ + {"large offset", "SELECT * FROM t ORDER BY id LIMIT 20 OFFSET 5000", true}, + {"at threshold", "SELECT * FROM t ORDER BY id LIMIT 20 OFFSET 1000", false}, + {"small offset", "SELECT * FROM t ORDER BY id LIMIT 20 OFFSET 40", false}, + {"no offset", "SELECT * FROM t ORDER BY id LIMIT 20", false}, + {"parameterized offset", "SELECT * FROM t ORDER BY id LIMIT 20 OFFSET $1", false}, + {"offset rows fetch", "SELECT * FROM t ORDER BY id OFFSET 5000 ROWS FETCH NEXT 20 ROWS ONLY", true}, + {"mysql limit offset comma", "SELECT * FROM t ORDER BY id LIMIT 5000, 20", true}, + {"mysql limit small offset", "SELECT * FROM t ORDER BY id LIMIT 40, 20", false}, + {"offset as column name", "SELECT offset FROM t WHERE offset = 5000", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, rule, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestCheckSelectDistinct(t *testing.T) { + tests := []struct { + name string + query string + wantHit bool + }{ + {"basic distinct", "SELECT DISTINCT name FROM users", true}, + {"lowercase", "select distinct id from t", true}, + {"distinct on postgres", "SELECT DISTINCT ON (dept) name FROM emp", true}, + {"distinct parens", "SELECT DISTINCT(name) FROM users", true}, + {"distinctrow mysql", "SELECT DISTINCTROW name FROM users", true}, + {"distinct in subquery", "SELECT * FROM (SELECT DISTINCT x FROM t) s", true}, + {"no distinct", "SELECT name FROM users", false}, + {"count distinct aggregate", "SELECT COUNT(DISTINCT name) FROM users", false}, + {"distinct in aggregate with group", "SELECT id, COUNT(DISTINCT x) FROM t GROUP BY id", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := run(t, CheckSelectDistinct, tt.query); got != tt.wantHit { + t.Errorf("got hit=%v, want %v for query: %s", got, tt.wantHit, tt.query) + } + }) + } +} + +func TestDefaultAnalyzer(t *testing.T) { + a := Default() + + results := a.Analyze("DELETE FROM users") + if len(results) == 0 { + t.Fatal("expected at least one result for DELETE without WHERE") + } + if results[0].Severity != SeverityCritical { + t.Errorf("expected critical severity, got %s", results[0].Severity) + } + + results = a.Analyze("SELECT id FROM users WHERE id = 1") + if len(results) != 0 { + t.Errorf("expected no results for safe query, got %d", len(results)) + } +} + +// TestSpecDefaultSeverityIsAuthoritative locks in that a registry-built rule's +// reported severity comes from its RuleSpec.DefaultSeverity — the single +// source of truth — not from a literal in the rule body. The rule here +// deliberately returns the zero severity (Info); Analyze must report Critical. +func TestSpecDefaultSeverityIsAuthoritative(t *testing.T) { + const name = "zz-spec-severity-probe" + Register(RuleSpec{ + Name: name, + DefaultSeverity: SeverityCritical, + Factory: func(Settings) Rule { + return func(*Statement) (Result, bool) { + return Result{RuleName: name}, true // no Severity set + } + }, + }) + // Don't leak the probe into the global registry; Default() would pick it up. + t.Cleanup(func() { + registryMu.Lock() + delete(registry, name) + registryMu.Unlock() + }) + + a := DefaultWithProfile(Profile{Only: map[string]bool{name: true}}) + got := a.Analyze("SELECT 1") + if len(got) != 1 { + t.Fatalf("expected 1 result, got %d", len(got)) + } + if got[0].Severity != SeverityCritical { + t.Errorf("severity = %s, want CRITICAL (from spec DefaultSeverity)", got[0].Severity) + } + + // A profile override still wins over the spec default. + a = DefaultWithProfile(Profile{ + Only: map[string]bool{name: true}, + Severity: map[string]Severity{name: SeverityInfo}, + }) + if got := a.Analyze("SELECT 1"); len(got) != 1 || got[0].Severity != SeverityInfo { + t.Errorf("profile override not applied: %+v", got) + } +} diff --git a/analyzer/fallback.go b/analyzer/fallback.go new file mode 100644 index 0000000..53e43df --- /dev/null +++ b/analyzer/fallback.go @@ -0,0 +1,576 @@ +package analyzer + +import ( + "regexp" + "strconv" + "strings" +) + +// FallbackParser is the zero-dependency Parser. It removes SQL comments and +// string-literal contents before pattern matching, so keywords inside +// comments or strings (and identifiers like update_at) no longer cause +// false positives. It is best-effort and never returns an error: SQL it +// cannot fully understand still yields a usable Statement with Exact=false. +type FallbackParser struct{} + +// NewFallbackParser returns the default zero-dependency parser. +func NewFallbackParser() *FallbackParser { return &FallbackParser{} } + +var ( + // I?LIKE matches both LIKE and Postgres' case-insensitive ILIKE; the \b + // before it keeps the "I" from matching inside words like DISLIKE. + fbLeadingWildcardRe = regexp.MustCompile(`(?i)\bI?LIKE\s+['"]\s*%`) + // Best-effort capture of a LIKE/ILIKE pattern's literal body. Does not model + // embedded/escaped quotes; the fallback is heuristic by contract. + fbLikeLiteralRe = regexp.MustCompile(`(?i)\bI?LIKE\s+['"]([^'"]*)['"]`) + fbWhereRe = regexp.MustCompile(`(?i)\bWHERE\b`) + fbLimitRe = regexp.MustCompile(`(?i)\bLIMIT\b`) + fbOrderByRe = regexp.MustCompile(`(?i)\bORDER\s+BY\b`) + fbFromRe = regexp.MustCompile(`(?i)\bFROM\b`) + fbSelectStarRe = regexp.MustCompile(`(?i)\bSELECT\s+(?:DISTINCT\s+)?(?:[a-z_][a-z0-9_]*\s*\.\s*)?\*`) + // fbSelectDistinctRe anchors DISTINCT to right after SELECT, so an + // aggregate-level DISTINCT (COUNT(DISTINCT x)) does not match. + fbSelectDistinctRe = regexp.MustCompile(`(?i)\bSELECT\s+DISTINCT(?:ROW)?\b`) + fbIntoRe = regexp.MustCompile(`(?i)\bINTO\b`) + // fbInsertDataRe marks the start of an INSERT's data clause, after the + // target table (and its optional column list). VALUES? covers MySQL's + // singular VALUE; SELECT/WITH/TABLE cover INSERT ... SELECT and friends. + fbInsertDataRe = regexp.MustCompile(`(?i)\b(VALUES?|SELECT|WITH|TABLE|SET|DEFAULT)\b`) + fbLeadKindRe = regexp.MustCompile(`(?i)^\s*\(*\s*(SELECT|INSERT|UPDATE|DELETE|WITH)\b`) + fbDMLWordRe = regexp.MustCompile(`(?i)\b(INSERT|UPDATE|DELETE)\b`) + + // fbWhereRegionEndRe marks the first clause keyword that ends the WHERE + // region, so a function in ORDER BY / GROUP BY / HAVING isn't read as a + // WHERE predicate. + fbWhereRegionEndRe = regexp.MustCompile(`(?i)\b(GROUP\s+BY|ORDER\s+BY|HAVING|LIMIT|OFFSET|WINDOW|FETCH|FOR\s+UPDATE)\b`) + // fbFuncOnColumnRe matches IDENT(args): a function/cast call whose + // closing paren is immediately followed by a comparison operator. The + // operator-after-paren shape is what restricts it to the column side of a + // predicate (WHERE LOWER(c) = ...), not the value side (WHERE c = ABS(x)). + fbFuncOnColumnRe = regexp.MustCompile(`(?i)\b([a-z_][a-z0-9_]*)\s*\(([^()]*)\)\s*(?:=|<>|!=|<=|>=|<|>|\bLIKE\b|\bIN\b|\bBETWEEN\b)`) + // fbArgIdentRe checks that a function's arguments contain a column-like + // identifier, so NOW() and LOWER('x') (literal blanked to '') are skipped. + fbArgIdentRe = regexp.MustCompile(`[a-zA-Z_]`) + + fbAlterTableRe = regexp.MustCompile(`(?i)^\s*ALTER\s+TABLE\b`) + // fbAddActionRe matches the start of an ALTER action and captures the + // first token after ADD [COLUMN] — a column name for a column add, or a + // keyword (CONSTRAINT, CHECK, ...) for the forms we must skip. + fbAddActionRe = regexp.MustCompile(`(?i)\bADD\s+(?:COLUMN\s+)?(\w+)`) + fbNotNullRe = regexp.MustCompile(`(?i)\bNOT\s+NULL\b`) + fbDefaultRe = regexp.MustCompile(`(?i)\bDEFAULT\b`) + + // fbFromRegionEndRe marks the first clause keyword that ends the FROM + // region, so commas after it (an IN list, GROUP BY, etc.) aren't read as + // join separators. + fbFromRegionEndRe = regexp.MustCompile(`(?i)\b(WHERE|GROUP\s+BY|ORDER\s+BY|HAVING|LIMIT|OFFSET|WINDOW|FETCH|FOR|UNION|EXCEPT|INTERSECT)\b`) + fbJoinRe = regexp.MustCompile(`(?i)\bJOIN\b`) + // fbJoinCondRe matches anything that conditions a join — an ON/USING + // predicate or a NATURAL join (which joins on common columns) — so a join + // carrying one of these is not treated as a cartesian product. + fbJoinCondRe = regexp.MustCompile(`(?i)\b(ON|USING|NATURAL)\b`) + // fbInListRe matches the opening of an IN value list (NOT IN matches too). + fbInListRe = regexp.MustCompile(`(?i)\bIN\s*\(`) + // fbSubqueryStartRe recognizes an IN (...) body that is a subquery (or set) + // rather than a value list, so it is not counted. + fbSubqueryStartRe = regexp.MustCompile(`(?i)^\(*\s*(SELECT|WITH|VALUES|TABLE)\b`) + // fbOffsetRe captures a literal standard OFFSET n (incl. OFFSET n ROWS). + fbOffsetRe = regexp.MustCompile(`(?i)\bOFFSET\s+(\d+)`) + // fbLimitOffsetRe captures the offset of MySQL's LIMIT offset, count form. + fbLimitOffsetRe = regexp.MustCompile(`(?i)\bLIMIT\s+(\d+)\s*,\s*\d+`) +) + +// fbNonSargableSkipFuncs are tokens that can appear as IDENT before "(" but +// are SQL keywords, not functions wrapping a column. +var fbNonSargableSkipFuncs = map[string]bool{ + "in": true, "exists": true, "any": true, "all": true, + "some": true, "and": true, "or": true, "not": true, +} + +// fbAddNonColumnKeywords are the tokens following ADD that mean the action is +// not a column add (so a stray NOT NULL, e.g. inside a CHECK constraint, isn't +// mistaken for a NOT NULL column). +var fbAddNonColumnKeywords = map[string]bool{ + "constraint": true, "primary": true, "foreign": true, + "unique": true, "check": true, "key": true, "index": true, +} + +// Parse implements Parser. It always returns a non-nil Statement and a nil +// error. +func (p *FallbackParser) Parse(sql string) (*Statement, error) { + st := &Statement{Raw: sql, Exact: false} + + noComments := stripComments(sql) + + // Leading-wildcard LIKE is detected before literal contents are blanked, + // because the pattern lives inside the literal. Comments are already gone, + // so a commented-out LIKE won't trigger. + st.LeadingWildcardLike = fbLeadingWildcardRe.MatchString(noComments) + if st.LeadingWildcardLike { + st.LeadingWildcardTermLen = leadingWildcardTermLen(noComments) + } + + sanitized := blankStringLiterals(noComments) + + st.Kind = detectKind(sanitized) + st.HasWhere = fbWhereRe.MatchString(sanitized) + st.HasLimit = fbLimitRe.MatchString(sanitized) + st.HasOrderBy = hasTopLevelOrderBy(sanitized) + st.HasFrom = fbFromRe.MatchString(sanitized) + st.SelectStar = fbSelectStarRe.MatchString(sanitized) + st.SelectDistinct = fbSelectDistinctRe.MatchString(sanitized) + st.NonSargablePredicate = hasNonSargablePredicate(sanitized) + st.AddNotNullNoDefault = hasUnsafeAddNotNull(sanitized) + st.ImplicitCommaJoin = hasImplicitCommaJoin(sanitized) + st.CartesianJoin = hasCartesianJoin(sanitized) + st.MaxInListLen = maxInListLen(sanitized) + st.OffsetValue = maxOffset(sanitized) + + if st.Kind == StmtInsert { + st.InsertColumnsListed = insertColumnsListed(sanitized) + } + + return st, nil +} + +// insertColumnsListed reports whether an INSERT names its target columns +// explicitly. It inspects the span between INTO and the data clause +// (VALUES / SELECT / WITH / TABLE): an explicit column list shows up there as a +// "(". The "VALUES"-only shape the old regex matched missed INSERT ... SELECT +// (and CTE-prefixed inserts), which carry the same schema-change risk. MySQL's +// "SET col = ..." names its columns, and "DEFAULT VALUES" inserts no data, so +// both count as listed (no positional column-order risk to warn about). +// Comment-free, literal-blanked input expected; heuristic by contract. +func insertColumnsListed(sanitized string) bool { + loc := fbIntoRe.FindStringIndex(sanitized) + if loc == nil { + return true // no INTO found — can't tell, don't flag + } + rest := sanitized[loc[1]:] + data := fbInsertDataRe.FindStringIndex(rest) + if data == nil { + return true // no recognizable data clause — don't flag + } + switch strings.ToUpper(strings.TrimSpace(rest[data[0]:data[1]])) { + case "SET", "DEFAULT": + return true + } + // Columns are listed iff a "(" appears between the table name and the data + // clause. A bare table reference (incl. schema.table) has no parens there. + return strings.Contains(rest[:data[0]], "(") +} + +// leadingWildcardTermLen returns the length of the longest searchable term +// (the LIKE literal with surrounding '%' trimmed) among patterns that begin +// with a wildcard. Comment-free input is expected. +func leadingWildcardTermLen(noComments string) int { + max := 0 + for _, m := range fbLikeLiteralRe.FindAllStringSubmatch(noComments, -1) { + body := strings.TrimSpace(m[1]) + if !strings.HasPrefix(body, "%") { + continue + } + if n := len(strings.Trim(body, "%")); n > max { + max = n + } + } + return max +} + +// hasNonSargablePredicate reports whether the WHERE clause applies a function +// or cast to a column (WHERE LOWER(email) = ...), which defeats an index on +// that column. Input must be comment-free and have its string literals +// blanked. Scope is limited to the WHERE region so functions in the SELECT +// list, ORDER BY, or GROUP BY don't false-fire. +func hasNonSargablePredicate(sanitized string) bool { + region := whereRegion(sanitized) + if region == "" { + return false + } + for _, m := range fbFuncOnColumnRe.FindAllStringSubmatch(region, -1) { + if fbNonSargableSkipFuncs[strings.ToLower(m[1])] { + continue // a keyword like IN(...) / EXISTS(...), not a function + } + if !fbArgIdentRe.MatchString(m[2]) { + continue // no column in the args (e.g. NOW(), LOWER('x')) + } + return true + } + return false +} + +// whereRegion returns the slice of sanitized SQL from the WHERE keyword up to +// the next clause keyword (ORDER BY, GROUP BY, HAVING, LIMIT, ...), or "" when +// there is no WHERE clause. +func whereRegion(sanitized string) string { + loc := fbWhereRe.FindStringIndex(sanitized) + if loc == nil { + return "" + } + region := sanitized[loc[1]:] + // End the region at the first clause keyword that sits at the WHERE's own + // nesting level. A keyword inside a subquery in the WHERE (e.g. ORDER BY / + // LIMIT in "WHERE id IN (SELECT ... ORDER BY x LIMIT 1)") is at depth > 0 + // and must not cut the region short, which would drop predicates after it. + for _, end := range fbWhereRegionEndRe.FindAllStringIndex(region, -1) { + if parenDepthBefore(region, end[0]) == 0 { + return region[:end[0]] + } + } + return region +} + +// hasUnsafeAddNotNull reports whether an ALTER TABLE adds a NOT NULL column +// without a DEFAULT — which errors or rewrites the table on a populated table. +// Input must be comment-free with string literals blanked. The statement is +// split on top-level commas so each ADD action is judged independently (and a +// numeric type's own comma, e.g. NUMERIC(10,2), isn't a split point). +func hasUnsafeAddNotNull(sanitized string) bool { + if !fbAlterTableRe.MatchString(sanitized) { + return false + } + for _, seg := range splitTopLevelCommas(sanitized) { + m := fbAddActionRe.FindStringSubmatch(seg) + if m == nil { + continue // not an ADD action + } + if fbAddNonColumnKeywords[strings.ToLower(m[1])] { + continue // ADD CONSTRAINT / CHECK / KEY / ... — not a column add + } + if fbNotNullRe.MatchString(seg) && !fbDefaultRe.MatchString(seg) { + return true + } + } + return false +} + +// splitTopLevelCommas splits s on commas that are not nested inside +// parentheses. Used to separate the actions of a multi-action ALTER TABLE +// without breaking parenthesized type specs or expressions. +func splitTopLevelCommas(s string) []string { + var segs []string + depth, start := 0, 0 + for i := 0; i < len(s); i++ { + switch s[i] { + case '(': + depth++ + case ')': + if depth > 0 { + depth-- + } + case ',': + if depth == 0 { + segs = append(segs, s[start:i]) + start = i + 1 + } + } + } + return append(segs, s[start:]) +} + +// hasImplicitCommaJoin reports whether the FROM clause lists multiple tables +// separated by top-level commas (FROM a, b) rather than explicit JOIN syntax. +// Input must be comment-free with string literals blanked. The FROM region is +// found and bounded paren-aware so that a FROM inside EXTRACT(... FROM ...) or +// a subquery, and commas inside function calls / subqueries / IN lists, are +// not mistaken for join separators. +func hasImplicitCommaJoin(sanitized string) bool { + region := fromRegion(sanitized) + if region == "" { + return false + } + return len(splitTopLevelCommas(region)) > 1 +} + +// hasCartesianJoin reports whether the FROM clause joins multiple tables (via +// a top-level comma, CROSS JOIN, or a bare JOIN) with no join condition +// (ON/USING/NATURAL) anywhere in the FROM region and no top-level WHERE — an +// unconditioned cartesian product. It is deliberately conservative: if any +// join condition or WHERE filter is present it does not fire, so mixed queries +// (some joins conditioned) yield a false negative rather than a false positive. +// Input must be comment-free with string literals blanked. +func hasCartesianJoin(sanitized string) bool { + region := fromRegion(sanitized) + if region == "" { + return false + } + multiTable := len(splitTopLevelCommas(region)) > 1 || hasTopLevelJoin(region) + if !multiTable { + return false + } + if fbJoinCondRe.MatchString(region) || hasTopLevelWhere(sanitized) { + return false + } + return true +} + +// hasTopLevelWhere reports whether a WHERE keyword appears at parenthesis depth +// zero (a WHERE inside a subquery does not filter an outer cartesian product). +func hasTopLevelWhere(sanitized string) bool { + for _, loc := range fbWhereRe.FindAllStringIndex(sanitized, -1) { + if parenDepthBefore(sanitized, loc[0]) == 0 { + return true + } + } + return false +} + +// hasTopLevelJoin reports whether a JOIN keyword appears at parenthesis depth +// zero within the FROM region — an outer-level join (incl. CROSS/bare JOIN), +// not one inside a subquery. region is a depth-zero slice from fromRegion, so +// depth is measured relative to it. This is the JOIN counterpart of the +// paren-aware comma split: without it a JOIN in a FROM-clause subquery would be +// read as an outer cartesian product. +func hasTopLevelJoin(region string) bool { + for _, loc := range fbJoinRe.FindAllStringIndex(region, -1) { + if parenDepthBefore(region, loc[0]) == 0 { + return true + } + } + return false +} + +// hasTopLevelOrderBy reports whether an ORDER BY appears at parenthesis depth +// zero — a result-set sort, not a window-function (OVER (ORDER BY ...)), +// ordered-aggregate (GROUP_CONCAT(... ORDER BY ...), WITHIN GROUP (ORDER BY +// ...)), or subquery ordering, none of which sort the statement's result set. +func hasTopLevelOrderBy(sanitized string) bool { + for _, loc := range fbOrderByRe.FindAllStringIndex(sanitized, -1) { + if parenDepthBefore(sanitized, loc[0]) == 0 { + return true + } + } + return false +} + +// maxInListLen returns the largest element count among the statement's +// IN (...) value lists. IN (SELECT ...) / IN (VALUES ...) subqueries are not +// counted. Input must be comment-free with string literals blanked — commas +// between blanked literals survive, so element counting is unaffected. +func maxInListLen(sanitized string) int { + max := 0 + for _, loc := range fbInListRe.FindAllStringIndex(sanitized, -1) { + inner, ok := parenContent(sanitized, loc[1]-1) // loc[1]-1 is the "(" + if !ok { + continue + } + if strings.TrimSpace(inner) == "" || fbSubqueryStartRe.MatchString(strings.TrimSpace(inner)) { + continue + } + if n := len(splitTopLevelCommas(inner)); n > max { + max = n + } + } + return max +} + +// maxOffset returns the largest literal offset in the statement, considering +// both the standard OFFSET n form and MySQL's LIMIT offset, count form. A +// parameterized offset (OFFSET $1 / ?) matches neither and yields 0. Input +// must be comment-free with string literals blanked (numeric literals, which +// is what an offset is, are left intact). An offset literal too large for int +// is ignored (treated as no offset) — a rare, harmless false negative. +func maxOffset(sanitized string) int { + max := 0 + consider := func(re *regexp.Regexp) { + for _, m := range re.FindAllStringSubmatch(sanitized, -1) { + if n, err := strconv.Atoi(m[1]); err == nil && n > max { + max = n + } + } + } + consider(fbOffsetRe) + consider(fbLimitOffsetRe) + return max +} + +// parenContent returns the substring between the parenthesis at index open and +// its matching close paren (exclusive of both), and true, or "" and false if +// unbalanced. +func parenContent(s string, open int) (string, bool) { + depth := 0 + for i := open; i < len(s); i++ { + switch s[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + return s[open+1 : i], true + } + } + } + return "", false +} + +// fromRegion returns the slice of sanitized SQL between the first top-level +// FROM keyword and the next top-level clause keyword (WHERE, GROUP BY, a set +// operator, ...), or "" when there is no top-level FROM. "Top-level" means at +// parenthesis depth zero, so subquery and function-argument keywords are +// ignored. +func fromRegion(sanitized string) string { + fromEnd := -1 + for _, loc := range fbFromRe.FindAllStringIndex(sanitized, -1) { + if parenDepthBefore(sanitized, loc[0]) == 0 { + fromEnd = loc[1] + break + } + } + if fromEnd == -1 { + return "" + } + region := sanitized[fromEnd:] + for _, loc := range fbFromRegionEndRe.FindAllStringIndex(region, -1) { + if parenDepthBefore(region, loc[0]) == 0 { + return region[:loc[0]] + } + } + return region +} + +// parenDepthBefore returns the net parenthesis nesting depth at index idx +// (count of unmatched '(' in s[:idx]). +func parenDepthBefore(s string, idx int) int { + depth := 0 + for i := range idx { + switch s[i] { + case '(': + depth++ + case ')': + if depth > 0 { + depth-- + } + } + } + return depth +} + +func detectKind(sanitized string) StmtKind { + m := fbLeadKindRe.FindStringSubmatch(sanitized) + if m == nil { + return StmtOther + } + switch strings.ToUpper(m[1]) { + case "SELECT": + return StmtSelect + case "INSERT": + return StmtInsert + case "UPDATE": + return StmtUpdate + case "DELETE": + return StmtDelete + case "WITH": + // A CTE feeds a main statement. Best-effort: if an INSERT/UPDATE/DELETE + // keyword appears anywhere, treat it as that; otherwise a SELECT. + if w := fbDMLWordRe.FindString(sanitized); w != "" { + switch strings.ToUpper(w) { + case "INSERT": + return StmtInsert + case "UPDATE": + return StmtUpdate + case "DELETE": + return StmtDelete + } + } + return StmtSelect + } + return StmtOther +} + +// stripComments removes -- line comments and /* */ block comments, replacing +// each with a single space so token boundaries are preserved. It does not +// remove comment markers that appear inside string literals. +func stripComments(s string) string { + var b strings.Builder + b.Grow(len(s)) + for i := 0; i < len(s); { + switch c := s[i]; { + case c == '\'' || c == '"': + i = copyStringLiteral(&b, s, i) + case c == '-' && i+1 < len(s) && s[i+1] == '-': + i = skipLineComment(s, i) + b.WriteByte(' ') + case c == '/' && i+1 < len(s) && s[i+1] == '*': + i = skipBlockComment(s, i) + b.WriteByte(' ') + default: + b.WriteByte(c) + i++ + } + } + return b.String() +} + +// copyStringLiteral writes the string literal that begins at s[i] (a quote +// byte) verbatim, treating a doubled quote (two of the same quote byte in a +// row) as an escaped quote rather than the terminator, and returns the index +// just past the literal. +func copyStringLiteral(b *strings.Builder, s string, i int) int { + q := s[i] + b.WriteByte(q) + i++ + for i < len(s) { + b.WriteByte(s[i]) + if s[i] == q { + if i+1 < len(s) && s[i+1] == q { // doubled-quote escape + b.WriteByte(s[i+1]) + i += 2 + continue + } + return i + 1 + } + i++ + } + return i +} + +// skipLineComment returns the index of the newline (or end of input) that +// terminates the -- comment starting at i. +func skipLineComment(s string, i int) int { + for i < len(s) && s[i] != '\n' { + i++ + } + return i +} + +// skipBlockComment returns the index just past the */ that closes the block +// comment starting at i (or end of input if unterminated). +func skipBlockComment(s string, i int) int { + i += 2 + for i+1 < len(s) && (s[i] != '*' || s[i+1] != '/') { + i++ + } + return i + 2 +} + +// blankStringLiterals replaces the contents of every string literal with an +// empty literal, so SQL keywords that appear inside string values cannot be +// mistaken for clauses. Input must already be comment-free. +func blankStringLiterals(s string) string { + var b strings.Builder + b.Grow(len(s)) + for i := 0; i < len(s); { + c := s[i] + if c == '\'' || c == '"' { + q := c + b.WriteByte(q) + i++ + for i < len(s) { + if s[i] == q { + if i+1 < len(s) && s[i+1] == q { // doubled-quote escape + i += 2 + continue + } + i++ + break + } + i++ + } + b.WriteByte(q) + continue + } + b.WriteByte(c) + i++ + } + return b.String() +} diff --git a/analyzer/fallback_test.go b/analyzer/fallback_test.go new file mode 100644 index 0000000..a6d0476 --- /dev/null +++ b/analyzer/fallback_test.go @@ -0,0 +1,111 @@ +package analyzer + +import "testing" + +// These cover the exact false-positive classes the production-grade review +// flagged for the old raw-regex engine: comments, keyword-like identifiers, +// CTEs, subqueries, multi-statement input, and driver placeholders. The +// fallback parser must not misfire and must never panic or error. + +func TestFallback_CommentsDoNotTriggerRules(t *testing.T) { + a := Default() + tests := []struct { + name string + query string + }{ + {"line comment with DELETE", "SELECT id FROM users WHERE id = 1 -- DELETE FROM users everything"}, + {"block comment with WHERE", "DELETE FROM users /* no WHERE here on purpose */ WHERE id = 1"}, + {"line comment hiding where", "UPDATE users SET active = false WHERE id = 1 -- WHERE"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for _, r := range a.Analyze(tt.query) { + if r.RuleName == "delete-without-where" || r.RuleName == "update-without-where" { + t.Errorf("%s: unexpected %s on commented query: %s", tt.name, r.RuleName, tt.query) + } + } + }) + } +} + +func TestFallback_CommentedOutClausesAreNotCounted(t *testing.T) { + a := New(CheckDeleteWithoutWhere) + // The only WHERE is inside a comment, so this DELETE is genuinely unsafe. + got := a.Analyze("DELETE FROM users -- WHERE id = 1") + if len(got) != 1 || got[0].RuleName != "delete-without-where" { + t.Errorf("expected delete-without-where when WHERE is only in a comment, got %+v", got) + } +} + +func TestFallback_KeywordLikeIdentifiers(t *testing.T) { + a := Default() + // Column/table names containing keyword substrings must not be parsed + // as clauses. + queries := []string{ + "SELECT id, update_at, where_clause FROM orders WHERE id = 1 LIMIT 1", + "SELECT limited, ordered_by FROM report WHERE k = 1 LIMIT 10", + "UPDATE wherehouse SET stock = 0 WHERE id = 7", + } + for _, q := range queries { + results := a.Analyze(q) + for _, r := range results { + if r.RuleName == "update-without-where" || r.RuleName == "delete-without-where" { + t.Errorf("keyword-like identifier misparsed in %q: got %s", q, r.RuleName) + } + } + } +} + +func TestFallback_CTEAndSubquery(t *testing.T) { + p := NewFallbackParser() + + st, err := p.Parse("WITH recent AS (SELECT id FROM orders WHERE ts > now()) DELETE FROM orders WHERE id IN (SELECT id FROM recent)") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if st.Kind != StmtDelete { + t.Errorf("CTE-wrapped DELETE: got kind %v, want StmtDelete", st.Kind) + } + if !st.HasWhere { + t.Error("CTE-wrapped DELETE: WHERE clause not detected") + } + + st, _ = p.Parse("WITH t AS (SELECT 1) SELECT id FROM t WHERE id = 1") + if st.Kind != StmtSelect { + t.Errorf("CTE SELECT: got kind %v, want StmtSelect", st.Kind) + } +} + +func TestFallback_PlaceholdersNeverErrorOrPanic(t *testing.T) { + p := NewFallbackParser() + queries := []string{ + "SELECT * FROM users WHERE id = $1", + "SELECT * FROM users WHERE id = ? AND name = ?", + "SELECT * FROM users WHERE id = :id", + "INSERT INTO t VALUES ($1, $2); DELETE FROM other", + "", + "not even sql", + "SELECT '%' || $1 || '%'", + } + for _, q := range queries { + st, err := p.Parse(q) + if err != nil { + t.Errorf("fallback returned error for %q: %v (it must never error)", q, err) + } + if st == nil { + t.Errorf("fallback returned nil Statement for %q", q) + continue + } + if st.Exact { + t.Errorf("fallback Statement for %q must have Exact=false", q) + } + } +} + +func TestFallback_MultiStatementLeadingKind(t *testing.T) { + p := NewFallbackParser() + st, _ := p.Parse("DELETE FROM a WHERE id = 1; DROP TABLE b") + if st.Kind != StmtDelete { + t.Errorf("multi-statement: got kind %v, want StmtDelete (leading statement)", st.Kind) + } +} diff --git a/analyzer/parser.go b/analyzer/parser.go new file mode 100644 index 0000000..979732c --- /dev/null +++ b/analyzer/parser.go @@ -0,0 +1,18 @@ +package analyzer + +// Parser turns a raw SQL string into sqlguard's normalized Statement. +// +// Implementations: +// +// - FallbackParser (this package): zero-dependency, best-effort, never +// returns an error. +// - parsers/pgparser, parsers/mysqlparser (optional modules): real +// dialect ASTs, exact analysis, fall back to FallbackParser on parse +// failure. +// +// A Parser used on the runtime query path MUST NOT panic and SHOULD avoid +// returning an error for SQL it merely doesn't understand — degrade to a +// best-effort Statement instead, so analysis never breaks db.Query. +type Parser interface { + Parse(sql string) (*Statement, error) +} diff --git a/analyzer/profile_test.go b/analyzer/profile_test.go new file mode 100644 index 0000000..6f3c277 --- /dev/null +++ b/analyzer/profile_test.go @@ -0,0 +1,144 @@ +package analyzer + +import ( + "slices" + "testing" +) + +func TestRuleNamesCoversBuiltins(t *testing.T) { + names := RuleNames() + for _, want := range []string{ + "select-star", "leading-wildcard", "delete-without-where", + "update-without-where", "insert-without-columns", + "select-without-limit", "orderby-without-limit", + } { + if !slices.Contains(names, want) { + t.Errorf("rule %q not registered; got %v", want, names) + } + } +} + +func TestDefaultMatchesRegistry(t *testing.T) { + // Default() must behave exactly as before the registry refactor. + a := Default() + if got := a.Analyze("DELETE FROM users"); len(got) == 0 || got[0].Severity != SeverityCritical { + t.Fatalf("expected critical delete-without-where, got %+v", got) + } + if got := a.Analyze("SELECT id FROM users WHERE id = 1"); len(got) != 0 { + t.Errorf("expected no findings, got %+v", got) + } +} + +func TestProfileDisable(t *testing.T) { + a := DefaultWithProfile(Profile{Disabled: map[string]bool{"select-star": true}}) + for _, r := range a.Analyze("SELECT * FROM users") { + if r.RuleName == "select-star" { + t.Fatal("select-star should be disabled") + } + } +} + +func TestProfileOnlyWhitelist(t *testing.T) { + a := DefaultWithProfile(Profile{Only: map[string]bool{"select-star": true}}) + // delete-without-where must not run; only select-star is whitelisted. + got := a.Analyze("DELETE FROM users") + if len(got) != 0 { + t.Errorf("expected no findings with whitelist, got %+v", got) + } + if got := a.Analyze("SELECT * FROM users"); len(got) != 1 || got[0].RuleName != "select-star" { + t.Errorf("expected only select-star, got %+v", got) + } +} + +func TestProfileSeverityOverride(t *testing.T) { + a := DefaultWithProfile(Profile{Severity: map[string]Severity{"select-star": SeverityInfo}}) + got := a.Analyze("SELECT * FROM users") + if len(got) == 0 || got[0].RuleName != "select-star" || got[0].Severity != SeverityInfo { + t.Fatalf("expected select-star downgraded to INFO, got %+v", got) + } +} + +func TestProfileSettingsLeadingWildcardMinLength(t *testing.T) { + a := DefaultWithProfile(Profile{ + Settings: map[string]Settings{"leading-wildcard": {"min-length": 5}}, + }) + // 2-char term -> below threshold, not flagged. + if hits := filterByRule(a.Analyze("SELECT id FROM t WHERE x LIKE '%ab%'"), "leading-wildcard"); hits != 0 { + t.Errorf("short pattern should be tolerated with min-length=5") + } + // 6-char term -> flagged. + if hits := filterByRule(a.Analyze("SELECT id FROM t WHERE x LIKE '%abcdef%'"), "leading-wildcard"); hits != 1 { + t.Errorf("long pattern should still be flagged with min-length=5") + } +} + +func TestProfileSettingsInListMaxLength(t *testing.T) { + a := DefaultWithProfile(Profile{ + Settings: map[string]Settings{"in-list-too-large": {"max-length": 3}}, + }) + // 3 elements -> at threshold, not flagged. + if hits := filterByRule(a.Analyze("SELECT id FROM t WHERE id IN (1, 2, 3)"), "in-list-too-large"); hits != 0 { + t.Errorf("list at threshold should be tolerated with max-length=3") + } + // 4 elements -> over threshold, flagged. + if hits := filterByRule(a.Analyze("SELECT id FROM t WHERE id IN (1, 2, 3, 4)"), "in-list-too-large"); hits != 1 { + t.Errorf("list over threshold should be flagged with max-length=3") + } +} + +func TestProfileSettingsLargeOffsetThreshold(t *testing.T) { + a := DefaultWithProfile(Profile{ + Settings: map[string]Settings{"large-offset": {"threshold": 100}}, + }) + // offset 100 -> at threshold, not flagged. + if hits := filterByRule(a.Analyze("SELECT id FROM t LIMIT 10 OFFSET 100"), "large-offset"); hits != 0 { + t.Errorf("offset at threshold should be tolerated with threshold=100") + } + // offset 200 -> over threshold, flagged. + if hits := filterByRule(a.Analyze("SELECT id FROM t LIMIT 10 OFFSET 200"), "large-offset"); hits != 1 { + t.Errorf("offset over threshold should be flagged with threshold=100") + } +} + +func TestInlineSuppression(t *testing.T) { + a := Default() + + if got := a.Analyze("SELECT * FROM users -- sqlguard:ignore"); len(got) != 0 { + t.Errorf("bare ignore should suppress all, got %+v", got) + } + // Scoped: suppress select-star only; delete-without-where still fires. + q := "DELETE FROM users /* sqlguard:ignore:select-star */" + if got := a.Analyze(q); len(got) != 1 || got[0].RuleName != "delete-without-where" { + t.Errorf("scoped ignore should keep delete-without-where, got %+v", got) + } + if got := a.Analyze("SELECT * FROM users WHERE id = 1 /* sqlguard:ignore:select-star */"); len(got) != 0 { + t.Errorf("select-star should be suppressed, got %+v", got) + } + // The token inside a string literal must NOT suppress (no comment marker). + if got := a.Analyze("SELECT * FROM users WHERE note = 'sqlguard:ignore'"); len(got) == 0 { + t.Error("string-literal text must not act as a suppression directive") + } +} + +func TestParseIgnoreComment(t *testing.T) { + if all, _, found := ParseIgnoreComment("// sqlguard:ignore"); !found || !all { + t.Error("expected bare directive parsed as all") + } + all, rules, found := ParseIgnoreComment("// noise sqlguard:ignore:select-star, leading-wildcard") + if !found || all || !rules["select-star"] || !rules["leading-wildcard"] { + t.Errorf("expected scoped rules, got all=%v rules=%v found=%v", all, rules, found) + } + if _, _, found := ParseIgnoreComment("// just a normal comment"); found { + t.Error("non-directive comment should not be found") + } +} + +func filterByRule(rs []Result, name string) int { + n := 0 + for _, r := range rs { + if r.RuleName == name { + n++ + } + } + return n +} diff --git a/analyzer/redact.go b/analyzer/redact.go new file mode 100644 index 0000000..14f81b3 --- /dev/null +++ b/analyzer/redact.go @@ -0,0 +1,158 @@ +package analyzer + +import ( + "regexp" + "strings" +) + +// Redact returns sql with comments stripped and every single-quoted string +// literal and numeric literal replaced by a single "?" placeholder. Query +// structure, keywords, and identifiers (including double-quoted and +// backtick-quoted identifiers) are preserved, so the result stays readable +// and analyzable but carries no literal values — no emails, tokens, or other +// PII reach a log sink. +// +// It is a zero-dependency lexical pass, not a full parser: it is +// intentionally conservative (e.g. it does not special-case hex/scientific +// forms beyond a simple exponent) and never errors. Use it whenever a query +// is about to leave the process. +func Redact(sql string) string { + s := stripComments(sql) + var b strings.Builder + b.Grow(len(s)) + + var prev byte // last byte written to output, 0 at start + for i := 0; i < len(s); { + c := s[i] + switch { + case c == '\'': + // String literal — the classic PII carrier. Replace its whole + // body (honoring '' escapes) with one placeholder. + i = skipSingleQuoted(s, i) + b.WriteByte('?') + prev = '?' + case c == '"' || c == '`': + // Quoted identifier (ANSI double-quote / MySQL backtick). Copy + // verbatim so a quote-enclosed name or a stray ' inside it does + // not corrupt structure or trip the literal branch. + j := skipQuoted(s, i, c) + b.WriteString(s[i:j]) + prev = s[j-1] + i = j + case isDigit(c) && !suppressesNumber(prev): + j := scanNumber(s, i) + b.WriteByte('?') + prev = '?' + i = j + default: + b.WriteByte(c) + prev = c + i++ + } + } + return b.String() +} + +var fpListRe = regexp.MustCompile(`\(\?(?:, ?\?)+\)`) + +// Fingerprint returns a stable, PII-free identity for sql: it is Redact +// followed by whitespace collapsing and IN/VALUES-list folding +// ("(?, ?, ?)" -> "(?)") so that queries differing only in literal values or +// list length share one fingerprint. A trailing ";" is trimmed. +// +// The result is safe to use as a low-cardinality metric label or log key — +// it is the canonical query identity the runtime, the N+1 tracker, and any +// metrics/observability adapter group on. +func Fingerprint(sql string) string { + r := Redact(sql) + r = strings.Join(strings.Fields(r), " ") + r = fpListRe.ReplaceAllString(r, "(?)") + return strings.TrimRight(r, "; ") +} + +// IsMultiStatement reports whether sql contains more than one SQL statement, +// i.e. a ";" statement separator followed by further non-whitespace content. +// Comments and string-literal bodies are removed first (reusing the same +// comment/literal-aware lexer the parser uses), so the check cannot be +// defeated by a ";" hidden in a -- / /* */ comment or inside a string +// literal — the evasion the brittle strings.Contains(query, ";") check +// allowed. A single trailing ";" is not multi-statement. +func IsMultiStatement(sql string) bool { + s := blankStringLiterals(stripComments(sql)) + if _, rest, found := strings.Cut(s, ";"); found { + return strings.TrimSpace(rest) != "" + } + return false +} + +func isDigit(c byte) bool { return c >= '0' && c <= '9' } + +// suppressesNumber reports whether a digit following prev is part of an +// identifier (col1, int8) or a bind placeholder ($1, @p1) rather than a +// numeric literal, so it must not be redacted. +func suppressesNumber(prev byte) bool { + switch { + case prev >= 'a' && prev <= 'z', prev >= 'A' && prev <= 'Z', + prev >= '0' && prev <= '9': + return true + case prev == '_' || prev == '$' || prev == '@': + return true + } + return false +} + +// scanNumber returns the index just past the numeric literal starting at i +// (digits, an optional decimal point, and an optional e[+-]?digits exponent). +func scanNumber(s string, i int) int { + for i < len(s) && (isDigit(s[i]) || s[i] == '.') { + i++ + } + if i < len(s) && (s[i] == 'e' || s[i] == 'E') { + j := i + 1 + if j < len(s) && (s[j] == '+' || s[j] == '-') { + j++ + } + if j < len(s) && isDigit(s[j]) { + for j < len(s) && isDigit(s[j]) { + j++ + } + i = j + } + } + return i +} + +// skipSingleQuoted returns the index just past the single-quoted string +// literal that opens at s[i], treating a doubled single-quote (two in a row) +// as an escaped quote rather than the terminator. +func skipSingleQuoted(s string, i int) int { + i++ // opening quote + for i < len(s) { + if s[i] == '\'' { + if i+1 < len(s) && s[i+1] == '\'' { + i += 2 + continue + } + return i + 1 + } + i++ + } + return i +} + +// skipQuoted returns the index just past the quoted run starting at s[i] == q +// (q is '"' or '`'), honoring doubled-quote escapes for the same quote. +func skipQuoted(s string, i int, q byte) int { + i++ // opening quote + for i < len(s) { + if s[i] == q { + if i+1 < len(s) && s[i+1] == q { + i += 2 + continue + } + return i + 1 + } + i++ + } + return i +} diff --git a/analyzer/redact_policy_test.go b/analyzer/redact_policy_test.go new file mode 100644 index 0000000..c456b95 --- /dev/null +++ b/analyzer/redact_policy_test.go @@ -0,0 +1,51 @@ +package analyzer + +import "testing" + +func TestAnalyzeRedactsByDefault(t *testing.T) { + q := `SELECT * FROM users WHERE email = 'alice@acme.com'` + res := Default().Analyze(q) + if len(res) == 0 { + t.Fatal("expected at least one finding (select-star)") + } + for _, r := range res { + if contains(r.Query, "alice@acme.com") { + t.Errorf("default Analyze leaked literal in Query: %q", r.Query) + } + if r.Fingerprint == "" { + t.Error("Fingerprint not populated") + } + if contains(r.Fingerprint, "alice@acme.com") { + t.Errorf("Fingerprint leaked literal: %q", r.Fingerprint) + } + } +} + +func TestWithRawQueryKeepsLiterals(t *testing.T) { + q := `SELECT * FROM users WHERE email = 'alice@acme.com'` + res := Default().WithRawQuery().Analyze(q) + if len(res) == 0 { + t.Fatal("expected a finding") + } + if !contains(res[0].Query, "alice@acme.com") { + t.Errorf("WithRawQuery should keep raw SQL, got %q", res[0].Query) + } + if res[0].Fingerprint == "" { + t.Error("Fingerprint must still be set in raw mode") + } +} + +func TestPrepareQueryPolicy(t *testing.T) { + raw := `SELECT 'x' FROM t WHERE id = 9` + d, fp := Default().PrepareQuery(raw) + if contains(d, "'x'") { + t.Errorf("default PrepareQuery should redact: %q", d) + } + if fp == "" { + t.Error("fingerprint empty") + } + d2, _ := Default().WithRawQuery().PrepareQuery(raw) + if d2 != raw { + t.Errorf("raw PrepareQuery = %q, want %q", d2, raw) + } +} diff --git a/analyzer/redact_test.go b/analyzer/redact_test.go new file mode 100644 index 0000000..ef3b35f --- /dev/null +++ b/analyzer/redact_test.go @@ -0,0 +1,115 @@ +package analyzer + +import "testing" + +func TestRedact(t *testing.T) { + cases := []struct { + name string + in string + want string + }{ + {"string literal", `SELECT * FROM users WHERE email = 'alice@acme.com'`, + `SELECT * FROM users WHERE email = ?`}, + {"numeric literal", `SELECT * FROM t WHERE id = 42 AND age > 18`, + `SELECT * FROM t WHERE id = ? AND age > ?`}, + {"float and exponent", `SELECT * FROM t WHERE x = 3.14 AND y = 1e10`, + `SELECT * FROM t WHERE x = ? AND y = ?`}, + {"identifier with digits kept", `SELECT col1, int8_v FROM t1 WHERE a2 = 5`, + `SELECT col1, int8_v FROM t1 WHERE a2 = ?`}, + {"bind placeholders kept", `SELECT * FROM t WHERE a = $1 AND b = @p2`, + `SELECT * FROM t WHERE a = $1 AND b = @p2`}, + {"quoted identifier preserved", `SELECT "weird;col" FROM t WHERE n = 'x'`, + `SELECT "weird;col" FROM t WHERE n = ?`}, + {"backtick identifier preserved", "SELECT `from` FROM t WHERE n = 'x'", + "SELECT `from` FROM t WHERE n = ?"}, + {"escaped quote in literal", `SELECT * FROM t WHERE s = 'O''Brien'`, + `SELECT * FROM t WHERE s = ?`}, + {"comment stripped", "SELECT a -- secret 'tok'\nFROM t WHERE id = 9", + "SELECT a \nFROM t WHERE id = ?"}, + {"semicolon inside literal not structural", `SELECT * FROM t WHERE s = 'a;b'`, + `SELECT * FROM t WHERE s = ?`}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := Redact(c.in); got != c.want { + t.Errorf("Redact(%q)\n got: %q\nwant: %q", c.in, got, c.want) + } + }) + } +} + +func TestRedactNoPII(t *testing.T) { + pii := []string{"alice@acme.com", "123-45-6789", "4111111111111111", "secret"} + q := `SELECT * FROM users WHERE email='alice@acme.com' AND ssn='123-45-6789' + AND card='4111111111111111' /* secret */ LIMIT 10` + got := Redact(q) + for _, p := range pii { + if contains(got, p) { + t.Errorf("Redact leaked %q: %q", p, got) + } + } +} + +func TestFingerprint(t *testing.T) { + cases := []struct{ name, in, want string }{ + {"collapse whitespace", "SELECT *\n FROM t WHERE id = 1", + "SELECT * FROM t WHERE id = ?"}, + {"fold IN list", `SELECT * FROM t WHERE id IN (1, 2, 3, 4)`, + `SELECT * FROM t WHERE id IN (?)`}, + {"fold VALUES tuple", `INSERT INTO t VALUES ('a', 'b', 'c')`, + `INSERT INTO t VALUES (?)`}, + {"trailing semicolon trimmed", `SELECT 1;`, `SELECT ?`}, + {"differing literals same fp", + `SELECT * FROM t WHERE name = 'bob' AND age = 7`, + `SELECT * FROM t WHERE name = ? AND age = ?`}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := Fingerprint(c.in); got != c.want { + t.Errorf("Fingerprint(%q)\n got: %q\nwant: %q", c.in, got, c.want) + } + }) + } + + // Stability: queries differing only in values/list length share a fp. + a := Fingerprint(`SELECT * FROM t WHERE id IN (1,2,3) AND s = 'x'`) + b := Fingerprint(`SELECT * FROM t WHERE id IN (9,8) AND s = 'zzzzz'`) + if a != b { + t.Errorf("fingerprints should match:\n a=%q\n b=%q", a, b) + } +} + +func TestIsMultiStatement(t *testing.T) { + cases := []struct { + name string + in string + want bool + }{ + {"single", `SELECT * FROM t WHERE id = 1`, false}, + {"trailing semicolon", `SELECT * FROM t;`, false}, + {"trailing semicolon + ws", "SELECT 1; \n\t", false}, + {"stacked", `SELECT 1; DROP TABLE users`, true}, + {"stacked no space", `SELECT 1;DELETE FROM t`, true}, + {"semicolon in line comment", "SELECT 1 -- a; b\n", false}, + {"semicolon in block comment", `SELECT 1 /* a ; b */`, false}, + {"semicolon in string literal", `SELECT * FROM t WHERE s = 'a; DROP'`, false}, + {"comment hides stacking attempt", "SELECT 1 -- ;\nfrom t", false}, + {"real stack after string", `SELECT 'a;b'; DELETE FROM t`, true}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := IsMultiStatement(c.in); got != c.want { + t.Errorf("IsMultiStatement(%q) = %v, want %v", c.in, got, c.want) + } + }) + } +} + +func contains(s, sub string) bool { + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/analyzer/registry.go b/analyzer/registry.go new file mode 100644 index 0000000..12fe055 --- /dev/null +++ b/analyzer/registry.go @@ -0,0 +1,151 @@ +package analyzer + +import ( + "sort" + "sync" + "time" +) + +// Settings holds rule-specific configuration as a generic key/value map so +// new tunables can be added without changing this type or the config schema. +// Accessors are nil-safe and fall back to the provided default, so a rule +// can always be constructed even with no settings supplied. +type Settings map[string]any + +// Int returns the setting as an int, or def if missing or not numeric. +// YAML decodes integers as int and JSON as float64, so both are accepted. +func (s Settings) Int(key string, def int) int { + if s == nil { + return def + } + switch v := s[key].(type) { + case int: + return v + case int64: + return int(v) + case float64: + return int(v) + default: + return def + } +} + +// Bool returns the setting as a bool, or def if missing or not a bool. +func (s Settings) Bool(key string, def bool) bool { + if s == nil { + return def + } + if v, ok := s[key].(bool); ok { + return v + } + return def +} + +// String returns the setting as a string, or def if missing or not a string. +func (s Settings) String(key, def string) string { + if s == nil { + return def + } + if v, ok := s[key].(string); ok { + return v + } + return def +} + +// Duration returns the setting parsed as a time.Duration. It accepts a +// duration string ("200ms") or a number interpreted as milliseconds. Returns +// def if missing or unparseable. +func (s Settings) Duration(key string, def time.Duration) time.Duration { + if s == nil { + return def + } + switch v := s[key].(type) { + case string: + if d, err := time.ParseDuration(v); err == nil { + return d + } + case int: + return time.Duration(v) * time.Millisecond + case int64: + return time.Duration(v) * time.Millisecond + case float64: + return time.Duration(v) * time.Millisecond + } + return def +} + +// RuleSpec describes a built-in rule: its stable name (used in config, +// suppressions and reports), its default severity, and a factory that builds +// the rule from its settings. Keeping construction behind a factory is what +// makes per-rule settings work uniformly for every present and future rule. +type RuleSpec struct { + Name string + DefaultSeverity Severity + Factory func(Settings) Rule +} + +var ( + registryMu sync.RWMutex + registry = map[string]RuleSpec{} +) + +// Register adds a rule to the global registry. Built-in rules call this from +// init(); third-party rules may call it too. A duplicate name overwrites the +// previous spec, so a custom rule can replace a built-in one by name. +func Register(spec RuleSpec) { + registryMu.Lock() + defer registryMu.Unlock() + registry[spec.Name] = spec +} + +// RuleNames returns all registered rule names, sorted. Used by the config +// loader to validate rule references and by tooling to list rules. +func RuleNames() []string { + registryMu.RLock() + names := make([]string, 0, len(registry)) + for n := range registry { + names = append(names, n) + } + registryMu.RUnlock() + sort.Strings(names) + return names +} + +// specs returns all registered specs sorted by name, for deterministic +// analyzer construction and stable report ordering. +func specs() []RuleSpec { + registryMu.RLock() + out := make([]RuleSpec, 0, len(registry)) + for _, s := range registry { + out = append(out, s) + } + registryMu.RUnlock() + sort.Slice(out, func(i, j int) bool { return out[i].Name < out[j].Name }) + return out +} + +// Profile is the resolved, parser-independent view of configuration applied +// to an Analyzer at construction time. The config package builds it from +// .sqlguard.yml; analyzer never imports config or YAML. All maps are keyed +// by rule name. Resolution happens once here, never on the per-query path. +type Profile struct { + // Disabled rules are not constructed or run. + Disabled map[string]bool + // Only, when non-empty, is a whitelist: only these rules run. + Only map[string]bool + // Severity overrides a rule's reported severity. + Severity map[string]Severity + // Settings holds per-rule tunables. + Settings map[string]Settings + // RawQuery, when true, disables Result.Query redaction (literals are + // left in the reported SQL). Default (false) redacts — see + // Analyzer.WithRawQuery. + RawQuery bool +} + +func (p Profile) skip(name string) bool { + if len(p.Only) > 0 && !p.Only[name] { + return true + } + return p.Disabled[name] +} diff --git a/analyzer/result.go b/analyzer/result.go new file mode 100644 index 0000000..30a8645 --- /dev/null +++ b/analyzer/result.go @@ -0,0 +1,19 @@ +package analyzer + +// Result represents a single finding from query analysis. +type Result struct { + RuleName string + Severity Severity + // Query is the offending SQL as surfaced to reporters. By default it is + // redacted (string/numeric literals replaced with "?") so literal values + // never reach a log sink; an Analyzer built WithRawQuery leaves it raw. + Query string + // Fingerprint is the redacted, whitespace-collapsed, list-folded query + // identity (see analyzer.Fingerprint). It is always set, never carries + // PII, and is safe as a metric label or log key. + Fingerprint string + Message string + Suggestion string + File string // populated only in static analysis mode + Line int // populated only in static analysis mode +} diff --git a/analyzer/rules.go b/analyzer/rules.go new file mode 100644 index 0000000..872df4e --- /dev/null +++ b/analyzer/rules.go @@ -0,0 +1,269 @@ +package analyzer + +import "fmt" + +// Built-in rules self-register so they are addressable by name for config +// (enable/disable, severity overrides, settings) and suppressions. Adding a +// new rule is just another Register call here — no other plumbing changes. +func init() { + Register(RuleSpec{Name: "select-star", DefaultSeverity: SeverityWarning, + Factory: func(Settings) Rule { return CheckSelectStar }}) + Register(RuleSpec{Name: "leading-wildcard", DefaultSeverity: SeverityWarning, + Factory: func(s Settings) Rule { return leadingWildcardRule(s.Int("min-length", 0)) }}) + Register(RuleSpec{Name: "delete-without-where", DefaultSeverity: SeverityCritical, + Factory: func(Settings) Rule { return CheckDeleteWithoutWhere }}) + Register(RuleSpec{Name: "update-without-where", DefaultSeverity: SeverityCritical, + Factory: func(Settings) Rule { return CheckUpdateWithoutWhere }}) + Register(RuleSpec{Name: "insert-without-columns", DefaultSeverity: SeverityWarning, + Factory: func(Settings) Rule { return CheckInsertWithoutColumns }}) + Register(RuleSpec{Name: "select-without-limit", DefaultSeverity: SeverityWarning, + Factory: func(Settings) Rule { return CheckSelectWithoutLimit }}) + Register(RuleSpec{Name: "orderby-without-limit", DefaultSeverity: SeverityInfo, + Factory: func(Settings) Rule { return CheckOrderByWithoutLimit }}) + Register(RuleSpec{Name: "non-sargable-predicate", DefaultSeverity: SeverityWarning, + Factory: func(Settings) Rule { return CheckNonSargablePredicate }}) + Register(RuleSpec{Name: "add-not-null-without-default", DefaultSeverity: SeverityWarning, + Factory: func(Settings) Rule { return CheckAddNotNullWithoutDefault }}) + Register(RuleSpec{Name: "implicit-join", DefaultSeverity: SeverityWarning, + Factory: func(Settings) Rule { return CheckImplicitJoin }}) + Register(RuleSpec{Name: "cartesian-join", DefaultSeverity: SeverityWarning, + Factory: func(Settings) Rule { return CheckCartesianJoin }}) + Register(RuleSpec{Name: "in-list-too-large", DefaultSeverity: SeverityWarning, + Factory: func(s Settings) Rule { return inListRule(s.Int("max-length", 100)) }}) + Register(RuleSpec{Name: "large-offset", DefaultSeverity: SeverityWarning, + Factory: func(s Settings) Rule { return largeOffsetRule(s.Int("threshold", 1000)) }}) + Register(RuleSpec{Name: "select-distinct", DefaultSeverity: SeverityInfo, + Factory: func(Settings) Rule { return CheckSelectDistinct }}) +} + +// CheckSelectStar detects SELECT * usage. +func CheckSelectStar(s *Statement) (Result, bool) { + if s.SelectStar { + return Result{ + RuleName: "select-star", + Query: s.Raw, + Message: "SELECT * detected. Selecting all columns can hurt performance.", + Suggestion: "Select only the columns you need.", + }, true + } + return Result{}, false +} + +// CheckLeadingWildcard detects LIKE patterns with leading wildcards, using +// the rule's default settings (no minimum term length). +func CheckLeadingWildcard(s *Statement) (Result, bool) { + return leadingWildcardRule(0)(s) +} + +// leadingWildcardRule builds the leading-wildcard rule. When minLen > 0, a +// leading-wildcard LIKE is flagged only if its searchable term is at least +// minLen characters long, so short patterns like LIKE '%x%' can be tolerated. +// A statement whose term length is unknown (0, e.g. produced by a real +// parser that did not compute it) is still flagged, to avoid false negatives. +func leadingWildcardRule(minLen int) Rule { + return func(s *Statement) (Result, bool) { + if !s.LeadingWildcardLike { + return Result{}, false + } + if minLen > 0 && s.LeadingWildcardTermLen > 0 && s.LeadingWildcardTermLen < minLen { + return Result{}, false + } + return Result{ + RuleName: "leading-wildcard", + Query: s.Raw, + Message: "LIKE with leading wildcard detected. Index cannot be used.", + Suggestion: "Use prefix search or a full-text index.", + }, true + } +} + +// CheckDeleteWithoutWhere detects DELETE statements without a WHERE clause. +func CheckDeleteWithoutWhere(s *Statement) (Result, bool) { + if s.Kind == StmtDelete && !s.HasWhere { + return Result{ + RuleName: "delete-without-where", + Query: s.Raw, + Message: "DELETE without WHERE clause detected. This will delete all rows.", + Suggestion: "Add a WHERE clause to limit the scope of the delete.", + }, true + } + return Result{}, false +} + +// CheckUpdateWithoutWhere detects UPDATE statements without a WHERE clause. +func CheckUpdateWithoutWhere(s *Statement) (Result, bool) { + if s.Kind == StmtUpdate && !s.HasWhere { + return Result{ + RuleName: "update-without-where", + Query: s.Raw, + Message: "UPDATE without WHERE clause detected. This will update all rows.", + Suggestion: "Add a WHERE clause to limit the scope of the update.", + }, true + } + return Result{}, false +} + +// CheckInsertWithoutColumns detects INSERT statements without an explicit +// column list. +func CheckInsertWithoutColumns(s *Statement) (Result, bool) { + if s.Kind == StmtInsert && !s.InsertColumnsListed { + return Result{ + RuleName: "insert-without-columns", + Query: s.Raw, + Message: "INSERT without explicit column list. This breaks if table schema changes.", + Suggestion: "Specify columns explicitly: INSERT INTO table (col1, col2) VALUES (...).", + }, true + } + return Result{}, false +} + +// CheckSelectWithoutLimit detects SELECT statements without a LIMIT clause. +// Only flags queries that have a FROM clause (to skip SELECT 1, SELECT +// version(), etc.) and don't have WHERE, to reduce noise. +func CheckSelectWithoutLimit(s *Statement) (Result, bool) { + if s.Kind == StmtSelect && s.HasFrom && !s.HasLimit && !s.HasWhere { + return Result{ + RuleName: "select-without-limit", + Query: s.Raw, + Message: "SELECT without LIMIT or WHERE clause. May return excessive rows.", + Suggestion: "Add a LIMIT clause or WHERE filter to restrict results.", + }, true + } + return Result{}, false +} + +// CheckNonSargablePredicate detects a function or cast applied to a column on +// the column side of a WHERE comparison (e.g. WHERE LOWER(email) = ...), which +// prevents an ordinary index on that column from being used. +func CheckNonSargablePredicate(s *Statement) (Result, bool) { + if s.NonSargablePredicate { + return Result{ + RuleName: "non-sargable-predicate", + Query: s.Raw, + Message: "Function applied to a column in WHERE prevents index use.", + Suggestion: "Compare the bare column instead, or add a matching expression/function index.", + }, true + } + return Result{}, false +} + +// CheckAddNotNullWithoutDefault detects an ALTER TABLE that adds a NOT NULL +// column with no DEFAULT, which errors or forces a full table rewrite on a +// populated table. +func CheckAddNotNullWithoutDefault(s *Statement) (Result, bool) { + if s.AddNotNullNoDefault { + return Result{ + RuleName: "add-not-null-without-default", + Query: s.Raw, + Message: "ADD COLUMN ... NOT NULL without DEFAULT fails or rewrites the table on a populated table.", + Suggestion: "Add a DEFAULT, or split into: add the column nullable, backfill, then SET NOT NULL.", + }, true + } + return Result{}, false +} + +// CheckInListTooLarge detects an IN (...) value list with more elements than +// the default threshold (100). Use the registry / config to tune max-length. +func CheckInListTooLarge(s *Statement) (Result, bool) { + return inListRule(100)(s) +} + +// inListRule builds the in-list-too-large rule. It flags a statement whose +// largest IN (...) value list has more than maxLen elements. A maxLen of 0 +// flags any value-list IN; subquery INs are never counted (MaxInListLen +// excludes them). +func inListRule(maxLen int) Rule { + return func(s *Statement) (Result, bool) { + if s.MaxInListLen <= maxLen { + return Result{}, false + } + return Result{ + RuleName: "in-list-too-large", + Query: s.Raw, + Message: fmt.Sprintf("IN list has %d elements (threshold %d). Large IN lists hurt query planning.", s.MaxInListLen, maxLen), + Suggestion: "Use a JOIN against a temp table / VALUES list, or a parameterized array such as = ANY($1).", + }, true + } +} + +// CheckSelectDistinct detects a select-level DISTINCT, which is often added to +// hide duplicate rows produced by an unintended join fan-out rather than to +// express a genuine need for distinct results. INFO by default. +func CheckSelectDistinct(s *Statement) (Result, bool) { + if s.SelectDistinct { + return Result{ + RuleName: "select-distinct", + Query: s.Raw, + Message: "SELECT DISTINCT detected. It often masks duplicate rows from an unintended join.", + Suggestion: "Confirm the duplicates aren't a join fan-out; prefer fixing the join or using EXISTS/GROUP BY.", + }, true + } + return Result{}, false +} + +// CheckLargeOffset detects a literal OFFSET larger than the default threshold +// (1000). Use the registry / config to tune threshold. +func CheckLargeOffset(s *Statement) (Result, bool) { + return largeOffsetRule(1000)(s) +} + +// largeOffsetRule builds the large-offset rule. It flags a statement whose +// literal OFFSET exceeds threshold — deep pagination, where the database scans +// and discards every skipped row. Parameterized offsets (OffsetValue == 0) are +// never flagged. +func largeOffsetRule(threshold int) Rule { + return func(s *Statement) (Result, bool) { + if s.OffsetValue <= threshold { + return Result{}, false + } + return Result{ + RuleName: "large-offset", + Query: s.Raw, + Message: fmt.Sprintf("OFFSET %d exceeds %d. Deep pagination scans and discards all skipped rows.", s.OffsetValue, threshold), + Suggestion: "Use keyset (cursor) pagination: WHERE id > $last ORDER BY id LIMIT n.", + }, true + } +} + +// CheckCartesianJoin detects a multi-table FROM with no join condition and no +// WHERE filter — an unconditioned cartesian product (incl. CROSS JOIN). +func CheckCartesianJoin(s *Statement) (Result, bool) { + if s.CartesianJoin { + return Result{ + RuleName: "cartesian-join", + Query: s.Raw, + Message: "Cartesian product: multiple tables joined with no join condition or WHERE filter.", + Suggestion: "Add a JOIN ... ON condition (or a WHERE clause relating the tables).", + }, true + } + return Result{}, false +} + +// CheckImplicitJoin detects a FROM clause that joins tables with commas +// (FROM a, b) instead of explicit JOIN syntax — error-prone because a +// forgotten join condition silently yields a cartesian product. +func CheckImplicitJoin(s *Statement) (Result, bool) { + if s.ImplicitCommaJoin { + return Result{ + RuleName: "implicit-join", + Query: s.Raw, + Message: "Implicit comma join in FROM. A missing join condition silently becomes a cartesian product.", + Suggestion: "Use explicit JOIN ... ON syntax.", + }, true + } + return Result{}, false +} + +// CheckOrderByWithoutLimit detects ORDER BY without LIMIT, which sorts the +// entire result set. +func CheckOrderByWithoutLimit(s *Statement) (Result, bool) { + if s.HasOrderBy && !s.HasLimit { + return Result{ + RuleName: "orderby-without-limit", + Query: s.Raw, + Message: "ORDER BY without LIMIT sorts the entire result set.", + Suggestion: "Add a LIMIT clause if you only need a subset of rows.", + }, true + } + return Result{}, false +} diff --git a/analyzer/severity.go b/analyzer/severity.go new file mode 100644 index 0000000..cde1e06 --- /dev/null +++ b/analyzer/severity.go @@ -0,0 +1,26 @@ +package analyzer + +// Severity represents the importance level of an analysis finding. +type Severity int + +const ( + // SeverityInfo is an advisory finding worth noting but not necessarily acting on. + SeverityInfo Severity = iota + // SeverityWarning is a likely problem that should be reviewed. + SeverityWarning + // SeverityCritical is a serious problem likely to cause incorrect or destructive behavior. + SeverityCritical +) + +func (s Severity) String() string { + switch s { + case SeverityInfo: + return "INFO" + case SeverityWarning: + return "WARNING" + case SeverityCritical: + return "CRITICAL" + default: + return "UNKNOWN" + } +} diff --git a/analyzer/statement.go b/analyzer/statement.go new file mode 100644 index 0000000..3b54401 --- /dev/null +++ b/analyzer/statement.go @@ -0,0 +1,138 @@ +package analyzer + +// StmtKind is the top-level kind of a SQL statement. +type StmtKind int + +const ( + // StmtUnknown means the parser could not determine the statement kind. + StmtUnknown StmtKind = iota + // StmtSelect is a SELECT (or WITH ... SELECT) query. + StmtSelect + // StmtInsert is an INSERT statement. + StmtInsert + // StmtUpdate is an UPDATE statement. + StmtUpdate + // StmtDelete is a DELETE statement. + StmtDelete + // StmtOther is a recognized statement that none of the rules target + // (DDL, transaction control, etc.). + StmtOther +) + +// Statement is sqlguard's normalized, dialect-agnostic view of a single SQL +// statement. It carries only the semantic facts the rules need — not a full +// AST. Every Parser (the zero-dependency fallback and the optional real +// dialect parsers) populates this same struct, so rules never depend on a +// particular parser or dialect. +// +// Boolean fields are best-effort: a fallback-produced Statement may leave a +// field false when it genuinely cannot tell. Rules must treat "false" as +// "not detected", never as "proven absent", to avoid false positives. +type Statement struct { + // Raw is the original, untouched SQL string. Reported back to users. + Raw string + + // Kind is the statement's top-level kind. + Kind StmtKind + + // HasWhere reports whether the statement has a WHERE clause. + HasWhere bool + + // HasLimit reports whether the statement has a LIMIT clause. + HasLimit bool + + // HasOrderBy reports whether the statement has an ORDER BY clause. + HasOrderBy bool + + // HasFrom reports whether a SELECT has a FROM clause. Distinguishes + // "SELECT * FROM t" from "SELECT 1" / "SELECT version()". + HasFrom bool + + // SelectStar reports an unqualified "SELECT *" / "SELECT t.*" of columns. + // It is false for aggregate forms like COUNT(*). + SelectStar bool + + // SelectDistinct reports a select-level DISTINCT (SELECT DISTINCT ..., + // incl. Postgres DISTINCT ON and MySQL DISTINCTROW). It is false for an + // aggregate-level DISTINCT such as COUNT(DISTINCT col), which is unrelated. + // The dialect parsers compute it from the AST; the fallback approximates it + // lexically. + SelectDistinct bool + + // InsertColumnsListed reports whether an INSERT names its target columns + // explicitly: INSERT INTO t (a, b) VALUES (...). Only meaningful when + // Kind == StmtInsert. + InsertColumnsListed bool + + // LeadingWildcardLike reports a LIKE pattern beginning with a wildcard + // (e.g. LIKE '%foo'), which prevents index use. + LeadingWildcardLike bool + + // NonSargablePredicate reports a function or cast applied to a column on + // the column side of a WHERE comparison (e.g. WHERE LOWER(email) = ...), + // which prevents the use of an ordinary index on that column. Like the + // LIKE fields, this is a literal/text-level heuristic the real parsers' + // ASTs discard, so it is computed by the fallback lexer and preserved by + // the dialect parsers rather than recomputed structurally. + NonSargablePredicate bool + + // AddNotNullNoDefault reports an ALTER TABLE that adds a NOT NULL column + // with no DEFAULT (e.g. ALTER TABLE t ADD COLUMN c int NOT NULL), which + // fails or forces a table rewrite on a populated table. Like the other + // text-level fields above, it is computed by the fallback lexer and + // preserved by the dialect parsers. + AddNotNullNoDefault bool + + // ImplicitCommaJoin reports a FROM clause that lists multiple tables + // separated by top-level commas (FROM a, b) instead of explicit JOIN + // syntax — the old-style join that silently produces a cartesian product + // when its join condition is forgotten. Computed by the fallback lexer and + // preserved (not recomputed from the AST) by the dialect parsers, so it + // stays a best-effort heuristic even when Exact is true. + ImplicitCommaJoin bool + + // CartesianJoin reports a multi-table FROM (comma join, CROSS JOIN, or a + // bare JOIN) with no join condition (ON/USING/NATURAL) and no top-level + // WHERE filter — an unconditioned cartesian product. It is the high- + // confidence subset of ImplicitCommaJoin and also covers CROSS/bare JOIN. + // Like ImplicitCommaJoin, it is a fallback-lexer heuristic preserved by the + // dialect parsers, so it stays best-effort even when Exact is true. + CartesianJoin bool + + // MaxInListLen is the largest element count among the statement's IN (...) + // value lists (IN (SELECT ...) subqueries are excluded). It powers the + // in-list-too-large rule's max-length threshold. Zero means no value-list + // IN was found. Like the other counts, rules read it, never raw SQL. It is a + // fallback-lexer heuristic preserved by the dialect parsers (the AST discards + // the literal list it counts), so it stays best-effort even when Exact is true. + MaxInListLen int + + // OffsetValue is the largest literal OFFSET seen (standard OFFSET n or + // MySQL's LIMIT offset, count), powering the large-offset rule. Zero means + // no offset, OFFSET 0, or a parameterized offset (OFFSET $1 / ?), which + // cannot be evaluated statically and is therefore never flagged. The dialect + // parsers read it from the AST's limit clause; the fallback scans for it. + OffsetValue int + + // LeadingWildcardTermLen is the length of the longest searchable term + // (the literal with surrounding % wildcards trimmed) across all + // leading-wildcard LIKE patterns in the statement. It powers the + // leading-wildcard rule's min-length setting. Zero means "unknown" + // (e.g. produced by a real parser that did not compute it); rules must + // treat zero as unknown and not as "short", to avoid false negatives. + LeadingWildcardTermLen int + + // Exact is true when the Statement was produced by a real SQL parser + // (structural analysis), false when produced by the regex fallback + // (best-effort). Rules may use this to suppress lower-confidence findings. + // + // "Exact" covers the structural facts the dialect parsers derive from the + // AST: Kind, HasWhere/HasLimit/HasOrderBy/HasFrom, SelectStar, + // SelectDistinct, OffsetValue, and InsertColumnsListed. A few facts stay + // lexical heuristics even when Exact is true — MaxInListLen, + // ImplicitCommaJoin, CartesianJoin, and the literal/text-level fields + // (LeadingWildcard*, NonSargablePredicate, AddNotNullNoDefault) — because + // they read literal values the AST discards or are intentionally text-level. + // Each such field documents this. + Exact bool +} diff --git a/analyzer/suppress.go b/analyzer/suppress.go new file mode 100644 index 0000000..5502116 --- /dev/null +++ b/analyzer/suppress.go @@ -0,0 +1,69 @@ +package analyzer + +import ( + "regexp" + "strings" +) + +// ignoreDirectiveRe matches a sqlguard:ignore directive inside a SQL or Go +// comment. The leading comment marker (--, /*, #, //) anchors it so the +// token is honored only in comment context, not when the literal text +// happens to appear inside a string. An optional `:rule-a, rule-b` list +// scopes the suppression to specific rules; without it, all rules are +// suppressed for the statement. +var ignoreDirectiveRe = regexp.MustCompile(`(?i)(?:--|/\*|#|//)[^\n]*?sqlguard:ignore(?::\s*([a-z0-9_,\s-]+))?`) + +// ignoreTokenRe matches the bare directive in text that is already known to +// be a comment (e.g. go/ast comment text with the marker stripped). No +// comment marker is required here because the whole string is comment +// context. +var ignoreTokenRe = regexp.MustCompile(`(?i)sqlguard:ignore(?::\s*([a-z0-9_,\s-]+))?`) + +// parseIgnoreDirective scans raw SQL for `sqlguard:ignore` directives. +// It returns ignoreAll=true if any directive has no rule list, otherwise a +// set of rule names to suppress. The result is empty when no directive is +// present, so the common path allocates nothing. +func parseIgnoreDirective(sql string) (ignoreAll bool, ignored map[string]bool) { + if !strings.Contains(strings.ToLower(sql), "sqlguard:ignore") { + return false, nil + } + for _, m := range ignoreDirectiveRe.FindAllStringSubmatch(sql, -1) { + list := strings.TrimSpace(m[1]) + if list == "" { + return true, nil + } + if ignored == nil { + ignored = make(map[string]bool) + } + for name := range strings.SplitSeq(list, ",") { + if name = strings.TrimSpace(name); name != "" { + ignored[name] = true + } + } + } + return false, ignored +} + +// ParseIgnoreComment parses the text of a single comment for a +// sqlguard:ignore directive. It is used by the static scanner to honor +// `// sqlguard:ignore` / `// sqlguard:ignore:rule-a,rule-b` annotations in Go +// source. found reports whether a directive was present; all is true for a +// bare directive (suppress every rule); rules holds the named rules +// otherwise. +func ParseIgnoreComment(text string) (all bool, rules map[string]bool, found bool) { + m := ignoreTokenRe.FindStringSubmatch(text) + if m == nil { + return false, nil, false + } + list := strings.TrimSpace(m[1]) + if list == "" { + return true, nil, true + } + rules = make(map[string]bool) + for name := range strings.SplitSeq(list, ",") { + if name = strings.TrimSpace(name); name != "" { + rules[name] = true + } + } + return false, rules, true +} diff --git a/cmd/sqlguard/db.go b/cmd/sqlguard/db.go new file mode 100644 index 0000000..649a0ef --- /dev/null +++ b/cmd/sqlguard/db.go @@ -0,0 +1,31 @@ +package main + +import ( + "database/sql" + "fmt" +) + +// openDB opens a database connection using the appropriate driver. +func openDB(dialect, dsn string) (*sql.DB, error) { + var driverName string + switch dialect { + case "postgres": + driverName = "postgres" + case "mysql": + driverName = "mysql" + default: + return nil, fmt.Errorf("unsupported dialect: %s", dialect) + } + + db, err := sql.Open(driverName, dsn) + if err != nil { + return nil, err + } + + if err := db.Ping(); err != nil { + _ = db.Close() + return nil, fmt.Errorf("cannot reach database: %w", err) + } + + return db, nil +} diff --git a/cmd/sqlguard/explain.go b/cmd/sqlguard/explain.go new file mode 100644 index 0000000..8379b73 --- /dev/null +++ b/cmd/sqlguard/explain.go @@ -0,0 +1,87 @@ +package main + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/KARTIKrocks/sqlguard/explain" + "github.com/KARTIKrocks/sqlguard/reporter" + "github.com/spf13/cobra" +) + +var ( + explainDSN string + explainDialect string + explainFormat string + explainAllowDML bool +) + +var explainCmd = &cobra.Command{ + Use: `explain "SQL QUERY"`, + Short: "Run EXPLAIN on a query against a live database", + Long: "Connects to a database and runs EXPLAIN to detect performance issues like sequential scans and missing indexes.", + Args: cobra.ExactArgs(1), + RunE: runExplain, +} + +func init() { + explainCmd.Flags().StringVar(&explainDSN, "db", "", "Database connection string (required)") + explainCmd.Flags().StringVar(&explainDialect, "dialect", "postgres", "Database dialect: postgres or mysql") + explainCmd.Flags().StringVar(&explainFormat, "format", "console", "Output format: console or json") + explainCmd.Flags().BoolVar(&explainAllowDML, "allow-dml", false, "Allow EXPLAIN on INSERT/UPDATE/DELETE (still run in an always-rolled-back transaction); refused by default") + _ = explainCmd.MarkFlagRequired("db") +} + +func runExplain(cmd *cobra.Command, args []string) error { + // Args are valid past this point; don't dump usage for runtime errors or + // the errIssuesFound sentinel. (Arg-parse errors still show usage.) + cmd.SilenceUsage = true + + query := args[0] + + db, err := openDB(explainDialect, explainDSN) + if err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + defer func() { _ = db.Close() }() + + var explainOpts []explain.Option + if explainAllowDML { + explainOpts = append(explainOpts, explain.WithAllowDML()) + } + analyzer, err := explain.New(db, explainDialect, explainOpts...) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + result, err := analyzer.Analyze(ctx, query) + if err != nil { + return err + } + + var rep reporter.Reporter + switch explainFormat { + case "json": + rep = reporter.NewJSONReporter() + default: + rep = reporter.NewConsoleReporter() + } + + if len(result.Issues) > 0 { + rep.Report(result.Issues) + if explainFormat != "json" { + fmt.Fprintf(os.Stderr, "\n%d issue(s) found in query plan\n", len(result.Issues)) + } + return errIssuesFound + } + + if explainFormat != "json" { + fmt.Fprintln(os.Stderr, "No issues found in query plan") + } + return nil +} diff --git a/cmd/sqlguard/main.go b/cmd/sqlguard/main.go new file mode 100644 index 0000000..927a276 --- /dev/null +++ b/cmd/sqlguard/main.go @@ -0,0 +1,18 @@ +package main + +import ( + "errors" + "fmt" + "os" +) + +func main() { + if err := rootCmd.Execute(); err != nil { + // If issues were found, exit with code 1 silently (already reported). + if errors.Is(err, errIssuesFound) { + os.Exit(1) + } + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} diff --git a/cmd/sqlguard/resolve_test.go b/cmd/sqlguard/resolve_test.go new file mode 100644 index 0000000..749010a --- /dev/null +++ b/cmd/sqlguard/resolve_test.go @@ -0,0 +1,178 @@ +package main + +import ( + "errors" + "os" + "path/filepath" + "strings" + "testing" +) + +// createModule turns dir into a loadable Go module so the type-aware +// (go/packages) scan path runs instead of the AST fallback. +func createModule(t *testing.T, dir, modPath string) { + t.Helper() + createTestFile(t, dir, "go.mod", "module "+modPath+"\n\ngo 1.26\n") +} + +func createFileInSubdir(t *testing.T, dir, rel, content string) { + t.Helper() + full := filepath.Join(dir, rel) + if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(full, []byte(content), 0o644); err != nil { + t.Fatal(err) + } +} + +func TestScan_ResolvesSamePackageConst(t *testing.T) { + dir := t.TempDir() + createModule(t, dir, "example.com/m") + createTestFile(t, dir, "q.go", `package example +import "database/sql" + +const userQuery = "SELECT * FROM users WHERE id = 1" + +func f(db *sql.DB) { + db.Query(userQuery) +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Fatalf("expected issues from a resolved const, got %v\n%s", err, out) + } + if !strings.Contains(out, "select-star") { + t.Errorf("expected select-star from resolved const, got:\n%s", out) + } +} + +func TestScan_ResolvesConstConcatenation(t *testing.T) { + dir := t.TempDir() + createModule(t, dir, "example.com/m") + createTestFile(t, dir, "q.go", `package example +import "database/sql" + +const ( + cols = "*" + q = "SELECT " + cols + " FROM users WHERE id = 1" +) + +func f(db *sql.DB) { + db.Query(q) +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Fatalf("expected issues from folded concatenation, got %v\n%s", err, out) + } + if !strings.Contains(out, "select-star") { + t.Errorf("expected select-star from concatenated const, got:\n%s", out) + } +} + +func TestScan_ResolvesCrossPackageConst(t *testing.T) { + dir := t.TempDir() + createModule(t, dir, "example.com/m") + createFileInSubdir(t, dir, "queries/queries.go", `package queries + +const GetUser = "SELECT * FROM users WHERE id = 1" +`) + createTestFile(t, dir, "main.go", `package example +import ( + "database/sql" + + "example.com/m/queries" +) + +func f(db *sql.DB) { + db.Query(queries.GetUser) +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Fatalf("expected issues from cross-package const, got %v\n%s", err, out) + } + if !strings.Contains(out, "select-star") { + t.Errorf("expected select-star from cross-package const, got:\n%s", out) + } +} + +func TestScan_ResolvesSprintfFormat(t *testing.T) { + dir := t.TempDir() + createModule(t, dir, "example.com/m") + createTestFile(t, dir, "q.go", `package example +import ( + "database/sql" + "fmt" +) + +func f(db *sql.DB, table string) { + db.Query(fmt.Sprintf("SELECT * FROM %s WHERE id = %d", table, 1)) +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Fatalf("expected issues from Sprintf format string, got %v\n%s", err, out) + } + if !strings.Contains(out, "select-star") { + t.Errorf("expected select-star from Sprintf format, got:\n%s", out) + } +} + +// A safe query held in a constant must stay clean — proves resolution does not +// introduce false positives. +func TestScan_ResolvedConstNoFalsePositive(t *testing.T) { + dir := t.TempDir() + createModule(t, dir, "example.com/m") + createTestFile(t, dir, "q.go", `package example +import "database/sql" + +const safe = "SELECT id, name FROM users WHERE id = ? LIMIT 10" + +func f(db *sql.DB) { + db.Query(safe, 1) +} +`) + + out, err := captureScanOutput(t, dir) + + if err != nil { + t.Fatalf("expected clean exit for safe resolved const, got %v\n%s", err, out) + } + if strings.Contains(out, "SQLGUARD") { + t.Errorf("expected no findings for safe const, got:\n%s", out) + } +} + +// Inline suppression must still apply when the query is a resolved const. +func TestScan_SuppressionWithResolvedConst(t *testing.T) { + dir := t.TempDir() + createModule(t, dir, "example.com/m") + createTestFile(t, dir, "q.go", `package example +import "database/sql" + +const userQuery = "SELECT * FROM users WHERE id = 1" + +func f(db *sql.DB) { + db.Query(userQuery) // sqlguard:ignore:select-star +} +`) + + out, err := captureScanOutput(t, dir) + + if err != nil { + t.Fatalf("expected clean exit, finding suppressed, got %v\n%s", err, out) + } + if strings.Contains(out, "select-star") { + t.Errorf("inline directive should suppress finding on resolved const, got:\n%s", out) + } +} diff --git a/cmd/sqlguard/root.go b/cmd/sqlguard/root.go new file mode 100644 index 0000000..ff44072 --- /dev/null +++ b/cmd/sqlguard/root.go @@ -0,0 +1,58 @@ +package main + +import ( + "fmt" + "os" + + "github.com/KARTIKrocks/sqlguard/config" + "github.com/spf13/cobra" +) + +var ( + configPathFlag string + noConfigFlag bool +) + +var rootCmd = &cobra.Command{ + Use: "sqlguard", + Short: "Production-safe SQL query analyzer for Go applications", + Long: "sqlguard detects slow queries, dangerous SQL patterns, and performance issues in Go applications.", + // main() owns error printing and exit codes. Without this, cobra prints + // "Error: issues found" for the errIssuesFound sentinel, which is a normal + // outcome (issues were already reported), not a CLI error. + SilenceErrors: true, +} + +func init() { + rootCmd.PersistentFlags().StringVar(&configPathFlag, "config", "", "path to .sqlguard.yml (default: auto-discover)") + rootCmd.PersistentFlags().BoolVar(&noConfigFlag, "no-config", false, "ignore any .sqlguard.yml and use built-in defaults") + rootCmd.AddCommand(scanCmd) + rootCmd.AddCommand(explainCmd) +} + +// resolveConfig loads configuration honoring --config / --no-config, falling +// back to discovery from startDir. Warnings are printed to stderr; a load +// error is returned to abort the command. +func resolveConfig(startDir string) (*config.Config, error) { + switch { + case noConfigFlag: + return config.Default(), nil + case configPathFlag != "": + return config.Load(configPathFlag) + default: + c, path, err := config.Discover(startDir) + if err != nil { + return nil, err + } + if path != "" { + _, _ = fmt.Fprintf(os.Stderr, "Using config %s\n", path) + } + return c, nil + } +} + +func printConfigWarnings(c *config.Config) { + for _, w := range c.Warnings() { + _, _ = fmt.Fprintf(os.Stderr, "sqlguard: config warning: %s\n", w) + } +} diff --git a/cmd/sqlguard/scan.go b/cmd/sqlguard/scan.go new file mode 100644 index 0000000..f1fd6e9 --- /dev/null +++ b/cmd/sqlguard/scan.go @@ -0,0 +1,401 @@ +package main + +import ( + "errors" + "fmt" + "go/ast" + "go/constant" + "go/parser" + "go/token" + "go/types" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/reporter" + "github.com/spf13/cobra" + "golang.org/x/tools/go/packages" +) + +// SQL method names we look for on any receiver. +var sqlMethods = map[string]bool{ + "Query": true, + "QueryContext": true, + "QueryRow": true, + "QueryRowContext": true, + "Exec": true, + "ExecContext": true, + "Prepare": true, + "PrepareContext": true, +} + +var formatFlag string + +var scanCmd = &cobra.Command{ + Use: "scan [path]", + Short: "Scan Go source files for SQL query issues", + Long: "Statically analyzes Go source files to find SQL queries and check them for common issues.", + Args: cobra.MaximumNArgs(1), + RunE: runScan, +} + +func init() { + scanCmd.Flags().StringVar(&formatFlag, "format", "console", "Output format: console or json") +} + +// errIssuesFound is returned when the scan finds issues, to signal a non-zero exit code. +var errIssuesFound = errors.New("issues found") + +func runScan(cmd *cobra.Command, args []string) error { + // Args are valid past this point; don't dump usage for runtime errors or + // the errIssuesFound sentinel. (Arg-parse errors still show usage.) + cmd.SilenceUsage = true + + dir := "." + if len(args) > 0 { + dir = args[0] + } + + rep, err := newReporter(formatFlag) + if err != nil { + return err + } + + cfg, err := resolveConfig(dir) + if err != nil { + return err + } + a, err := cfg.Analyzer() + if err != nil { + return err + } + printConfigWarnings(cfg) + exclude, err := cfg.ExcludeMatcher() + if err != nil { + return err + } + + allResults, totalFiles, err := scanDir(dir, a, exclude) + if err != nil { + return fmt.Errorf("scan failed: %w", err) + } + + if len(allResults) > 0 { + rep.Report(allResults) + if formatFlag != "json" { + _, _ = fmt.Fprintf(os.Stderr, "\n%d issue(s) found (%d file(s) scanned)\n", len(allResults), totalFiles) + } + return errIssuesFound + } + + if formatFlag != "json" { + _, _ = fmt.Fprintf(os.Stderr, "No issues found (%d file(s) scanned)\n", totalFiles) + } + return nil +} + +func newReporter(format string) (reporter.Reporter, error) { + switch format { + case "json": + return reporter.NewJSONReporter(), nil + case "console", "": + return reporter.NewConsoleReporter(), nil + default: + return nil, fmt.Errorf("unknown format %q: use 'console' or 'json'", format) + } +} + +// scanDir type-checks the target with golang.org/x/tools/go/packages so query +// arguments that are constants, cross-package constants, constant +// concatenations, or fmt.Sprintf literal format strings all resolve. If the +// target is not a loadable module (no go.mod, ad-hoc files), it degrades to a +// dependency-free go/parser walk that still handles inline string literals, so +// a broken or module-less tree is never silently skipped. +func scanDir(dir string, a *analyzer.Analyzer, exclude func(string) bool) ([]analyzer.Result, int, error) { + absDir, err := filepath.Abs(dir) + if err != nil { + return nil, 0, fmt.Errorf("cannot resolve absolute path: %w", err) + } + + if results, n, ok := scanViaPackages(absDir, a, exclude); ok { + return results, n, nil + } + results, n, err := scanViaAST(dir, absDir, a, exclude) + return results, n, err +} + +// scanViaPackages is the primary, type-aware path. ok is false when the target +// cannot be loaded as a module at all (caller then falls back to the AST walk); +// individual packages with type errors are still scanned, degrading per-file to +// literal-only resolution. +func scanViaPackages(absDir string, a *analyzer.Analyzer, exclude func(string) bool) (results []analyzer.Result, totalFiles int, ok bool) { + cfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | + packages.NeedImports | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo, + Dir: absDir, + Tests: false, + } + pkgs, err := packages.Load(cfg, "./...") + if err != nil || len(pkgs) == 0 { + return nil, 0, false + } + + seen := map[string]struct{}{} + scannedAny := false + for _, pkg := range pkgs { + if len(pkg.Syntax) == 0 { + continue + } + scannedAny = true + // Degraded package (type errors): TypesInfo may be partial or nil; + // constString falls back to *ast.BasicLit when info lacks the value. + info := pkg.TypesInfo + for _, file := range pkg.Syntax { + path := pkg.Fset.Position(file.Pos()).Filename + if !keepFile(path, absDir, exclude) { + continue + } + if _, dup := seen[path]; dup { + continue + } + seen[path] = struct{}{} + totalFiles++ + results = append(results, scanASTFile(pkg.Fset, file, info, a)...) + } + } + if !scannedAny { + return nil, 0, false + } + return results, totalFiles, true +} + +// scanViaAST is the dependency-free fallback for module-less / unbuildable +// trees: parse each file in isolation and resolve only inline string literals +// (info is nil, so scanASTFile degrades accordingly). +func scanViaAST(dir, absDir string, a *analyzer.Analyzer, exclude func(string) bool) ([]analyzer.Result, int, error) { + fset := token.NewFileSet() + var results []analyzer.Result + totalFiles := 0 + + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return shouldSkipDir(path, absDir) + } + if !keepFile(path, absDir, exclude) { + return nil + } + f, perr := parser.ParseFile(fset, path, nil, parser.ParseComments) + if perr != nil { + return nil + } + totalFiles++ + results = append(results, scanASTFile(fset, f, nil, a)...) + return nil + }) + + return results, totalFiles, err +} + +// keepFile reports whether a .go file should be analyzed: skip non-Go and +// _test.go files, then apply the configured exclude matcher against the path +// relative to the scan root (so regexes behave identically whether the path +// came from go list (absolute) or the walk (relative)). +func keepFile(path, absDir string, exclude func(string) bool) bool { + if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { + return false + } + if exclude != nil { + rel := path + if abs, err := filepath.Abs(path); err == nil { + if r, rerr := filepath.Rel(absDir, abs); rerr == nil { + rel = r + } + } + if exclude(filepath.ToSlash(rel)) { + return false + } + } + return true +} + +func shouldSkipDir(path, absDir string) error { + absPath, _ := filepath.Abs(path) + if absPath != absDir { + base := filepath.Base(path) + if strings.HasPrefix(base, ".") || base == "vendor" || base == "node_modules" { + return filepath.SkipDir + } + } + return nil +} + +// scanASTFile walks one parsed file for SQL-method calls and resolves each +// query argument via resolveQuery. info may be nil (fallback / degraded +// package), in which case resolution is limited to inline string literals. +func scanASTFile(fset *token.FileSet, f *ast.File, info *types.Info, a *analyzer.Analyzer) []analyzer.Result { + suppress := buildSuppressor(fset, f) + + var results []analyzer.Result + ast.Inspect(f, func(n ast.Node) bool { + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok || !sqlMethods[sel.Sel.Name] { + return true + } + + arg := queryArgExpr(sel.Sel.Name, call.Args) + if arg == nil { + return true + } + query := resolveQuery(info, arg) + if query == "" { + return true + } + + found := a.Analyze(query) + pos := fset.Position(call.Pos()) + all, rules := suppress(pos.Line) + for _, r := range found { + if all || rules[r.RuleName] { + continue + } + r.File = pos.Filename + r.Line = pos.Line + results = append(results, r) + } + return true + }) + + return results +} + +// buildSuppressor returns a lookup that, for a given source line, reports +// whether a `// sqlguard:ignore` directive applies — either trailing on that +// line or on the line directly above the call. This is the static-analysis +// counterpart to the in-SQL directive the analyzer handles at runtime. +func buildSuppressor(fset *token.FileSet, f *ast.File) func(line int) (bool, map[string]bool) { + type directive struct { + all bool + rules map[string]bool + } + byLine := map[int]directive{} + for _, cg := range f.Comments { + all, rules, found := analyzer.ParseIgnoreComment(cg.Text()) + if !found { + continue + } + end := fset.Position(cg.End()).Line + // Apply to the comment's own line (trailing) and the next line + // (comment sitting directly above the call). + byLine[end] = directive{all, rules} + byLine[end+1] = directive{all, rules} + } + return func(line int) (bool, map[string]bool) { + d, ok := byLine[line] + if !ok { + return false, nil + } + return d.all, d.rules + } +} + +// queryArgExpr returns the expression holding the SQL string for a given SQL +// method (the first arg, or the second for *Context variants). +func queryArgExpr(methodName string, args []ast.Expr) ast.Expr { + argIdx := 0 + if strings.HasSuffix(methodName, "Context") { + argIdx = 1 + } + if argIdx >= len(args) { + return nil + } + return args[argIdx] +} + +// resolveQuery turns a query-argument expression into SQL text. The single +// go/constant lookup in constString already covers inline literals, +// same-package constants, cross-package constants, and constant concatenation +// (the type checker folded them). fmt.Sprintf with a constant format string is +// resolved by neutralizing its verbs so the SQL stays structurally analyzable. +func resolveQuery(info *types.Info, e ast.Expr) string { + if s, ok := constString(info, e); ok { + return s + } + if ce, ok := e.(*ast.CallExpr); ok { + if fa, ok := sprintfFormatArg(info, ce); ok { + if f, ok := constString(info, fa); ok { + return neutralizeFormat(f) + } + } + } + return "" +} + +// constString resolves any constant string-valued expression. With type info +// this is one map lookup that the compiler already folded; without it (nil +// info or value absent) it degrades to a raw string literal. +func constString(info *types.Info, e ast.Expr) (string, bool) { + if info != nil { + if tv, ok := info.Types[e]; ok && tv.Value != nil && tv.Value.Kind() == constant.String { + return constant.StringVal(tv.Value), true + } + } + if bl, ok := e.(*ast.BasicLit); ok && bl.Kind == token.STRING { + if s, err := strconv.Unquote(bl.Value); err == nil { + return s, true + } + } + return "", false +} + +// sprintfFormatArg returns the format-string argument if ce is a call to +// fmt.Sprintf. With type info the callee is verified to be package "fmt"; +// without it, a conservative `fmt.Sprintf` selector-name heuristic is used. +func sprintfFormatArg(info *types.Info, ce *ast.CallExpr) (ast.Expr, bool) { + sel, ok := ce.Fun.(*ast.SelectorExpr) + if !ok || sel.Sel.Name != "Sprintf" || len(ce.Args) == 0 { + return nil, false + } + if info != nil { + if obj := info.Uses[sel.Sel]; obj != nil { + fn, ok := obj.(*types.Func) + if !ok || fn.Pkg() == nil || fn.Pkg().Path() != "fmt" { + return nil, false + } + return ce.Args[0], true + } + } + if id, ok := sel.X.(*ast.Ident); ok && id.Name == "fmt" { + return ce.Args[0], true + } + return nil, false +} + +var formatVerb = regexp.MustCompile(`%[-+# 0]*[\d.*]*[a-zA-Z%]`) + +// neutralizeFormat replaces fmt verbs in a constant format string with benign +// placeholders so the remaining SQL keeps its structure for the rule engine. +// Numeric verbs become 0; everything else becomes a harmless identifier; %% +// collapses to a literal %. +func neutralizeFormat(format string) string { + return formatVerb.ReplaceAllStringFunc(format, func(v string) string { + switch v[len(v)-1] { + case '%': + return "%" + case 'b', 'c', 'd', 'o', 'O', 'x', 'X', 'U', 'e', 'E', 'f', 'F', 'g', 'G', 'p': + return "0" + default: + return "sqlguard" + } + }) +} diff --git a/cmd/sqlguard/scan_test.go b/cmd/sqlguard/scan_test.go new file mode 100644 index 0000000..7cbafc8 --- /dev/null +++ b/cmd/sqlguard/scan_test.go @@ -0,0 +1,399 @@ +package main + +import ( + "bytes" + "errors" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/spf13/cobra" +) + +func createTestFile(t *testing.T, dir, name, content string) { + t.Helper() + err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0644) + if err != nil { + t.Fatalf("failed to create test file: %v", err) + } +} + +func TestScan_DetectsSelectStar(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT * FROM users WHERE id = 1") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit (errIssuesFound)") + } + if !strings.Contains(out, "select-star") { + t.Errorf("expected select-star warning, got:\n%s", out) + } +} + +func TestScan_DetectsDeleteWithoutWhere(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Exec("DELETE FROM users") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "delete-without-where") { + t.Errorf("expected delete-without-where warning, got:\n%s", out) + } + if !strings.Contains(out, "CRITICAL") { + t.Errorf("expected CRITICAL severity, got:\n%s", out) + } +} + +func TestScan_DetectsLeadingWildcard(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT id FROM users WHERE name LIKE '%test%'") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "leading-wildcard") { + t.Errorf("expected leading-wildcard warning, got:\n%s", out) + } +} + +func TestScan_DetectsUpdateWithoutWhere(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Exec("UPDATE users SET name = 'test'") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "update-without-where") { + t.Errorf("expected update-without-where warning, got:\n%s", out) + } +} + +func TestScan_DetectsInsertWithoutColumns(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Exec("INSERT INTO users VALUES ('alice', 'alice@test.com')") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "insert-without-columns") { + t.Errorf("expected insert-without-columns warning, got:\n%s", out) + } +} + +func TestScan_DetectsSelectWithoutLimit(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT id, name FROM users") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "select-without-limit") { + t.Errorf("expected select-without-limit warning, got:\n%s", out) + } +} + +func TestScan_DetectsOrderByWithoutLimit(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT id FROM users WHERE active = true ORDER BY name") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "orderby-without-limit") { + t.Errorf("expected orderby-without-limit warning, got:\n%s", out) + } +} + +func TestScan_NoWarningsForSafeQuery(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "good.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT id, name FROM users WHERE id = ? LIMIT 10", 1) +} +`) + + out, err := captureScanOutput(t, dir) + + if err != nil { + t.Errorf("expected nil error for safe query, got: %v", err) + } + if strings.Contains(out, "SQLGUARD") { + t.Errorf("expected no warnings for safe query, got:\n%s", out) + } + if !strings.Contains(out, "No issues found (") { + t.Errorf("expected 'No issues found' message, got:\n%s", out) + } +} + +func TestScan_SkipsTestFiles(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad_test.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT * FROM users WHERE id = 1") +} +`) + + out, err := captureScanOutput(t, dir) + + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } + if strings.Contains(out, "select-star") { + t.Errorf("should skip _test.go files, got:\n%s", out) + } +} + +func TestScan_SkipsVendorDir(t *testing.T) { + vendorDir := filepath.Join(t.TempDir(), "vendor") + os.MkdirAll(vendorDir, 0755) + createTestFile(t, vendorDir, "bad.go", `package vendor +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT * FROM users WHERE id = 1") +} +`) + + out, err := captureScanOutput(t, filepath.Dir(vendorDir)) + + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } + if strings.Contains(out, "select-star") { + t.Errorf("should skip vendor directory, got:\n%s", out) + } +} + +func TestScan_HandlesContextMethods(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "ctx.go", `package example +import ( + "context" + "database/sql" +) +func f(db *sql.DB) { + db.QueryContext(context.Background(), "SELECT * FROM users WHERE id = 1") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "select-star") { + t.Errorf("expected select-star for QueryContext, got:\n%s", out) + } +} + +func TestScan_MultipleIssuesInOneFile(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "multi.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT * FROM users WHERE email LIKE '%test%'") + db.Exec("DELETE FROM orders") +} +`) + + out, err := captureScanOutput(t, dir) + + if !errors.Is(err, errIssuesFound) { + t.Error("expected non-zero exit") + } + if !strings.Contains(out, "select-star") { + t.Error("expected select-star warning") + } + if !strings.Contains(out, "leading-wildcard") { + t.Error("expected leading-wildcard warning") + } + if !strings.Contains(out, "delete-without-where") { + t.Error("expected delete-without-where warning") + } +} + +// boundDir creates a temp dir with a .git marker so config.Discover does not +// escape it while walking parents. +func boundDir(t *testing.T) string { + t.Helper() + dir := t.TempDir() + if err := os.Mkdir(filepath.Join(dir, ".git"), 0o755); err != nil { + t.Fatal(err) + } + return dir +} + +func TestScan_ConfigDisablesRule(t *testing.T) { + dir := boundDir(t) + createTestFile(t, dir, ".sqlguard.yml", "rules:\n disable: [select-star]\n") + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT * FROM users WHERE id = 1") +} +`) + + out, err := captureScanOutput(t, dir) + + if err != nil { + t.Errorf("expected clean exit when rule disabled by config, got %v\n%s", err, out) + } + if strings.Contains(out, "select-star") { + t.Errorf("select-star should be disabled via .sqlguard.yml, got:\n%s", out) + } +} + +func TestScan_InlineSuppressionComment(t *testing.T) { + dir := boundDir(t) + createTestFile(t, dir, "bad.go", `package example +import "database/sql" +func f(db *sql.DB) { + // sqlguard:ignore + db.Exec("DELETE FROM users") + db.Query("SELECT * FROM users WHERE id = 1") // sqlguard:ignore:select-star +} +`) + + out, err := captureScanOutput(t, dir) + + if err != nil { + t.Errorf("expected clean exit, all findings suppressed, got %v\n%s", err, out) + } + if strings.Contains(out, "delete-without-where") || strings.Contains(out, "select-star") { + t.Errorf("inline directives should suppress findings, got:\n%s", out) + } +} + +func TestScan_ExitCodeZeroWhenClean(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "clean.go", `package example +import "database/sql" +func f(db *sql.DB) { + db.Query("SELECT id, name FROM users WHERE id = ? LIMIT 10", 1) +} +`) + + _, err := captureScanOutput(t, dir) + + if err != nil { + t.Errorf("expected exit code 0 for clean code, got error: %v", err) + } +} + +// captureScanOutput runs the scan command and captures stderr output. +// Returns the output and the error (errIssuesFound if issues were found). +func captureScanOutput(t *testing.T, dir string) (string, error) { + t.Helper() + + // Reset format flag to default for each test + formatFlag = "console" + + old := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + + err := runScan(&cobra.Command{}, []string{dir}) + + w.Close() + os.Stderr = old + + var buf bytes.Buffer + buf.ReadFrom(r) + + // Only fail on unexpected errors, not errIssuesFound + if err != nil && !errors.Is(err, errIssuesFound) { + t.Fatalf("scan failed unexpectedly: %v", err) + } + + return buf.String(), err +} + +// TestScanCommand_NoUsageDumpOnIssues runs the real command tree (rootCmd.Execute) +// and asserts that a normal "issues found" outcome does NOT print cobra's usage +// text. Regression guard for the SilenceErrors/SilenceUsage wiring: without it, +// returning errIssuesFound from RunE makes cobra dump "Error: issues found" +// followed by the full usage, which looks like a CLI misuse. +func TestScanCommand_NoUsageDumpOnIssues(t *testing.T) { + dir := t.TempDir() + createTestFile(t, dir, "bad.go", + "package bad\n\nimport \"database/sql\"\n\nfunc r(d *sql.DB) { d.Exec(\"DELETE FROM x\") }\n") + + formatFlag = "console" + noConfigFlag = true + t.Cleanup(func() { noConfigFlag = false }) + + old := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + + rootCmd.SetArgs([]string{"scan", "--no-config", dir}) + err := rootCmd.Execute() + + w.Close() + os.Stderr = old + var buf bytes.Buffer + _, _ = buf.ReadFrom(r) + out := buf.String() + + if !errors.Is(err, errIssuesFound) { + t.Fatalf("expected errIssuesFound, got %v", err) + } + if strings.Contains(out, "Usage:") { + t.Errorf("scan dumped usage text on an issues-found result:\n%s", out) + } + if !strings.Contains(out, "delete-without-where") { + t.Errorf("expected the finding in output, got:\n%s", out) + } +} diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..cce2abc --- /dev/null +++ b/codecov.yml @@ -0,0 +1,31 @@ +# Codecov configuration for sqlguard. +# Docs: https://docs.codecov.com/docs/codecov-yaml +# Coverage is produced by `make coverage` (merged across all nine modules) and +# uploaded from the CI "coverage" job. + +codecov: + require_ci_to_pass: true + +coverage: + precision: 2 + round: down + range: "70...100" + status: + project: + default: + target: auto + threshold: 5% + patch: + default: + target: auto + threshold: 5% + +comment: + layout: "reach,diff,flags,files" + behavior: default + require_changes: true + +ignore: + - "examples/**" + - "**/*_test.go" + - "cmd/sqlguard/main.go" diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..4319da1 --- /dev/null +++ b/config/config.go @@ -0,0 +1,317 @@ +// Package config loads and applies .sqlguard.yml configuration. +// +// It is the only package that depends on a YAML library. Importing +// sqlguard/analyzer or sqlguard/middleware does NOT pull YAML in; only code +// that opts into file-based configuration through this package does. The +// analyzer stays parser- and config-agnostic: config translates a Config +// into an analyzer.Profile, which the analyzer applies once at construction. +package config + +import ( + "bytes" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "regexp" + "strings" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "gopkg.in/yaml.v3" +) + +// ConfigFileNames are the file names Discover looks for, in order. +var ConfigFileNames = []string{".sqlguard.yml", ".sqlguard.yaml"} + +// Config mirrors the .sqlguard.yml schema. The Version field is reserved for +// forward compatibility: older binaries reading a newer config degrade with +// warnings rather than failing, unless Strict is set. +type Config struct { + Version int `yaml:"version"` + Strict bool `yaml:"strict"` + Rules RulesConfig `yaml:"rules"` + SlowQuery SlowQueryConfig `yaml:"slow-query"` + Dedup DedupConfig `yaml:"dedup"` + Scan ScanConfig `yaml:"scan"` + // Redact controls Result.Query literal redaction. Pointer so an unset + // key means "use the safe default" (redact). Set `redact: false` only + // when the query text is trusted (local debugging). + Redact *bool `yaml:"redact"` + + warnings []string +} + +// RulesConfig configures which rules run, their severity, and per-rule +// settings. +type RulesConfig struct { + // Disable turns off the named rules. + Disable []string `yaml:"disable"` + // Only, when non-empty, is a whitelist: only these rules run. + Only []string `yaml:"only"` + // Severity overrides per rule: info | warning | critical | off + // ("off" is equivalent to disabling the rule). + Severity map[string]string `yaml:"severity"` + // Settings holds per-rule tunables, e.g. leading-wildcard.min-length. + Settings map[string]map[string]any `yaml:"settings"` +} + +// SlowQueryConfig configures the middleware slow-query threshold. +type SlowQueryConfig struct { + // Threshold is a Go duration string, e.g. "200ms". + Threshold string `yaml:"threshold"` +} + +// DedupConfig configures runtime suppression of repeated static findings. +type DedupConfig struct { + // Window is a Go duration string, e.g. "1m". The same finding (rule + + // query fingerprint) is reported at most once per window. "0" disables + // dedup (report every occurrence). Unset keeps the middleware default. + Window string `yaml:"window"` +} + +// ScanConfig holds settings that apply only to the static scanner. +type ScanConfig struct { + // ExcludePaths is a list of regular expressions matched against scanned + // file paths; matching files are skipped. + ExcludePaths []string `yaml:"exclude-paths"` +} + +// Default returns an empty configuration: every rule enabled at its default +// severity and settings. Used when no .sqlguard.yml is found. +func Default() *Config { return &Config{Version: 1} } + +// Load reads and parses the config at path. Parsing is lenient by default so +// a config written for a newer sqlguard still loads on an older binary; +// unknown top-level keys become warnings. If the file sets `strict: true`, +// unknown keys are a hard error instead. +func Load(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("sqlguard config: %w", err) + } + + var c Config + if err := yaml.Unmarshal(data, &c); err != nil { + return nil, fmt.Errorf("sqlguard config %s: %w", path, err) + } + + // Detect unknown fields with a second strict decode. yaml.v3 surfaces the + // first unknown field as an error; we treat it as fatal only in strict + // mode, otherwise as a warning so forward-compatible configs still work. + if strictErr := strictDecode(data); strictErr != nil { + if c.Strict { + return nil, fmt.Errorf("sqlguard config %s (strict): %w", path, strictErr) + } + c.warnings = append(c.warnings, strictErr.Error()) + } + return &c, nil +} + +func strictDecode(data []byte) error { + dec := yaml.NewDecoder(bytes.NewReader(data)) + dec.KnownFields(true) + var probe Config + if err := dec.Decode(&probe); err != nil && !errors.Is(err, io.EOF) { + return err + } + return nil +} + +// Discover walks startDir and its parents looking for a config file. It stops +// at a directory containing a .git entry (project root) after checking that +// directory, or at the filesystem root. It returns Default() and an empty +// path when no config file is found. +func Discover(startDir string) (cfg *Config, path string, err error) { + dir, err := filepath.Abs(startDir) + if err != nil { + return nil, "", err + } + for { + for _, name := range ConfigFileNames { + p := filepath.Join(dir, name) + if st, statErr := os.Stat(p); statErr == nil && !st.IsDir() { + c, loadErr := Load(p) + return c, p, loadErr + } + } + if isProjectRoot(dir) { + break + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + return Default(), "", nil +} + +func isProjectRoot(dir string) bool { + _, err := os.Stat(filepath.Join(dir, ".git")) + return err == nil +} + +// Warnings returns non-fatal issues collected while loading or resolving the +// config (unknown keys in lenient mode, unknown rule names, bad severities). +// Callers should surface these to the user. +func (c *Config) Warnings() []string { return c.warnings } + +// Profile resolves the config into an analyzer.Profile. Unknown rule names +// and unparseable severities are warnings (or errors if Strict). A severity +// of "off" disables the rule. The returned Profile is ready to pass to +// analyzer.DefaultWithProfile. +func (c *Config) Profile() (analyzer.Profile, error) { + known := make(map[string]bool) + for _, n := range analyzer.RuleNames() { + known[n] = true + } + + p := analyzer.Profile{ + Disabled: map[string]bool{}, + Only: map[string]bool{}, + Severity: map[string]analyzer.Severity{}, + Settings: map[string]analyzer.Settings{}, + RawQuery: c.rawQuery(), + } + + warn := func(format string, args ...any) error { + msg := fmt.Sprintf(format, args...) + if c.Strict { + return errors.New(msg) + } + c.warnings = append(c.warnings, msg) + return nil + } + + checkName := func(name string) error { + if !known[name] { + return warn("unknown rule %q (known: %s)", name, strings.Join(analyzer.RuleNames(), ", ")) + } + return nil + } + + for _, name := range c.Rules.Disable { + if err := checkName(name); err != nil { + return p, err + } + p.Disabled[name] = true + } + for _, name := range c.Rules.Only { + if err := checkName(name); err != nil { + return p, err + } + p.Only[name] = true + } + for name, sevStr := range c.Rules.Severity { + if err := checkName(name); err != nil { + return p, err + } + sev, off, ok := parseSeverity(sevStr) + if !ok { + if err := warn("rule %q: invalid severity %q", name, sevStr); err != nil { + return p, err + } + continue + } + if off { + p.Disabled[name] = true + continue + } + p.Severity[name] = sev + } + for name, kv := range c.Rules.Settings { + if err := checkName(name); err != nil { + return p, err + } + p.Settings[name] = analyzer.Settings(kv) + } + return p, nil +} + +// rawQuery reports whether Result.Query redaction is disabled. Redaction is +// the default (PII-safe); only an explicit `redact: false` turns it off. +func (c *Config) rawQuery() bool { return c.Redact != nil && !*c.Redact } + +// Analyzer is a convenience that builds an analyzer from the config's +// Profile using the fallback parser. Callers wanting a real dialect parser +// should take the Profile and combine with analyzer.DefaultWithProfile + +// WithParser themselves. +func (c *Config) Analyzer() (*analyzer.Analyzer, error) { + p, err := c.Profile() + if err != nil { + return nil, err + } + return analyzer.DefaultWithProfile(p), nil +} + +// SlowQueryThreshold returns the configured slow-query threshold. ok is false +// when unset, in which case the caller keeps its own default. +func (c *Config) SlowQueryThreshold() (d time.Duration, ok bool, err error) { + s := strings.TrimSpace(c.SlowQuery.Threshold) + if s == "" { + return 0, false, nil + } + d, err = time.ParseDuration(s) + if err != nil { + return 0, false, fmt.Errorf("sqlguard config: slow-query.threshold %q: %w", s, err) + } + return d, true, nil +} + +// DedupWindow returns the configured static-finding dedup window. ok is false +// when unset, in which case the middleware keeps its own default. A configured +// "0" returns ok=true with d=0, which disables dedup (report every occurrence). +func (c *Config) DedupWindow() (d time.Duration, ok bool, err error) { + s := strings.TrimSpace(c.Dedup.Window) + if s == "" { + return 0, false, nil + } + d, err = time.ParseDuration(s) + if err != nil { + return 0, false, fmt.Errorf("sqlguard config: dedup.window %q: %w", s, err) + } + return d, true, nil +} + +// ExcludeMatcher compiles Scan.ExcludePaths into a single predicate. It +// returns a nil func (never excludes) when no patterns are configured. +func (c *Config) ExcludeMatcher() (func(path string) bool, error) { + if len(c.Scan.ExcludePaths) == 0 { + return nil, nil + } + res := make([]*regexp.Regexp, 0, len(c.Scan.ExcludePaths)) + for _, pat := range c.Scan.ExcludePaths { + re, err := regexp.Compile(pat) + if err != nil { + return nil, fmt.Errorf("sqlguard config: scan.exclude-paths %q: %w", pat, err) + } + res = append(res, re) + } + return func(path string) bool { + for _, re := range res { + if re.MatchString(path) { + return true + } + } + return false + }, nil +} + +// parseSeverity maps a config severity string to an analyzer.Severity. +// "off" / "none" / "disabled" report off=true (disable the rule). +func parseSeverity(s string) (sev analyzer.Severity, off bool, ok bool) { + switch strings.ToLower(strings.TrimSpace(s)) { + case "info": + return analyzer.SeverityInfo, false, true + case "warning", "warn": + return analyzer.SeverityWarning, false, true + case "critical", "error": + return analyzer.SeverityCritical, false, true + case "off", "none", "disabled": + return 0, true, true + default: + return 0, false, false + } +} diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 0000000..42ecdda --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,197 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +func writeConfig(t *testing.T, dir, body string) string { + t.Helper() + p := filepath.Join(dir, ".sqlguard.yml") + if err := os.WriteFile(p, []byte(body), 0o644); err != nil { + t.Fatalf("write config: %v", err) + } + return p +} + +func TestLoadAndProfile(t *testing.T) { + dir := t.TempDir() + p := writeConfig(t, dir, ` +version: 1 +rules: + disable: [orderby-without-limit] + severity: + select-star: info + select-without-limit: "off" + settings: + leading-wildcard: + min-length: 4 +slow-query: + threshold: 350ms +`) + c, err := Load(p) + if err != nil { + t.Fatalf("Load: %v", err) + } + + prof, err := c.Profile() + if err != nil { + t.Fatalf("Profile: %v", err) + } + if !prof.Disabled["orderby-without-limit"] { + t.Error("orderby-without-limit should be disabled") + } + if !prof.Disabled["select-without-limit"] { + t.Error(`severity "off" should disable select-without-limit`) + } + if prof.Severity["select-star"] != analyzer.SeverityInfo { + t.Errorf("select-star severity = %v, want INFO", prof.Severity["select-star"]) + } + if prof.Settings["leading-wildcard"].Int("min-length", 0) != 4 { + t.Error("min-length setting not carried into profile") + } + + d, ok, err := c.SlowQueryThreshold() + if err != nil || !ok || d != 350*time.Millisecond { + t.Errorf("SlowQueryThreshold = %v, %v, %v; want 350ms,true,nil", d, ok, err) + } + + // End-to-end: the built analyzer respects the profile. + a := analyzer.DefaultWithProfile(prof) + got := a.Analyze("SELECT * FROM users") + if len(got) != 1 || got[0].RuleName != "select-star" || got[0].Severity != analyzer.SeverityInfo { + t.Errorf("expected single INFO select-star, got %+v", got) + } +} + +func TestDedupWindow(t *testing.T) { + t.Run("set", func(t *testing.T) { + c := &Config{Dedup: DedupConfig{Window: "30s"}} + d, ok, err := c.DedupWindow() + if err != nil || !ok || d != 30*time.Second { + t.Errorf("DedupWindow = %v, %v, %v; want 30s,true,nil", d, ok, err) + } + }) + t.Run("unset keeps default", func(t *testing.T) { + c := &Config{} + if d, ok, err := c.DedupWindow(); err != nil || ok || d != 0 { + t.Errorf("DedupWindow = %v, %v, %v; want 0,false,nil", d, ok, err) + } + }) + t.Run("zero disables", func(t *testing.T) { + c := &Config{Dedup: DedupConfig{Window: "0"}} + if d, ok, err := c.DedupWindow(); err != nil || !ok || d != 0 { + t.Errorf("DedupWindow = %v, %v, %v; want 0,true,nil (explicit disable)", d, ok, err) + } + }) + t.Run("invalid errors", func(t *testing.T) { + c := &Config{Dedup: DedupConfig{Window: "soon"}} + if _, _, err := c.DedupWindow(); err == nil { + t.Error("expected error for invalid dedup.window") + } + }) +} + +func TestUnknownRuleLenientVsStrict(t *testing.T) { + dir := t.TempDir() + body := "rules:\n disable: [no-such-rule]\n" + + c, err := Load(writeConfig(t, dir, body)) + if err != nil { + t.Fatalf("Load: %v", err) + } + if _, err := c.Profile(); err != nil { + t.Fatalf("lenient Profile should not error: %v", err) + } + if len(c.Warnings()) == 0 { + t.Error("expected a warning for unknown rule in lenient mode") + } + + strict := &Config{Strict: true, Rules: RulesConfig{Disable: []string{"no-such-rule"}}} + if _, err := strict.Profile(); err == nil { + t.Error("expected error for unknown rule in strict mode") + } +} + +func TestUnknownKeyLenientWarnsStrictFails(t *testing.T) { + dir := t.TempDir() + + c, err := Load(writeConfig(t, dir, "bananas: true\n")) + if err != nil { + t.Fatalf("lenient load should succeed: %v", err) + } + if len(c.Warnings()) == 0 { + t.Error("expected warning for unknown top-level key") + } + + if _, err := Load(writeConfig(t, dir, "strict: true\nbananas: true\n")); err == nil { + t.Error("expected strict load to fail on unknown key") + } +} + +func TestDiscoverWalksUpAndStopsAtGitRoot(t *testing.T) { + root := t.TempDir() + if err := os.Mkdir(filepath.Join(root, ".git"), 0o755); err != nil { + t.Fatal(err) + } + writeConfig(t, root, "rules:\n disable: [select-star]\n") + deep := filepath.Join(root, "a", "b", "c") + if err := os.MkdirAll(deep, 0o755); err != nil { + t.Fatal(err) + } + + c, path, err := Discover(deep) + if err != nil { + t.Fatalf("Discover: %v", err) + } + if path == "" { + t.Fatal("expected to find config by walking up") + } + prof, _ := c.Profile() + if !prof.Disabled["select-star"] { + t.Error("discovered config not applied") + } +} + +func TestDiscoverNoConfigReturnsDefault(t *testing.T) { + dir := t.TempDir() + // .git marks the boundary so Discover does not escape the temp dir. + _ = os.Mkdir(filepath.Join(dir, ".git"), 0o755) + + c, path, err := Discover(dir) + if err != nil { + t.Fatalf("Discover: %v", err) + } + if path != "" { + t.Errorf("expected no config path, got %q", path) + } + if _, err := c.Profile(); err != nil { + t.Errorf("default profile should be valid: %v", err) + } +} + +func TestExcludeMatcher(t *testing.T) { + c := &Config{Scan: ScanConfig{ExcludePaths: []string{`(^|/)legacy/`, `_gen\.go$`}}} + m, err := c.ExcludeMatcher() + if err != nil { + t.Fatalf("ExcludeMatcher: %v", err) + } + if !m("pkg/legacy/old.go") || !m("api/types_gen.go") { + t.Error("expected matches for excluded paths") + } + if m("pkg/service/user.go") { + t.Error("did not expect match for normal path") + } + + none, err := (&Config{}).ExcludeMatcher() + if err != nil { + t.Errorf("no patterns should not error: %v", err) + } + if none != nil { + t.Error("no patterns should yield a nil matcher") + } +} diff --git a/config/middleware.go b/config/middleware.go new file mode 100644 index 0000000..439d91d --- /dev/null +++ b/config/middleware.go @@ -0,0 +1,61 @@ +package config + +import ( + "github.com/KARTIKrocks/sqlguard/middleware" +) + +// MiddlewareOptions translates this config into middleware options: an +// analyzer built from the rule Profile, and the slow-query threshold when +// configured. Combine with other middleware options as needed, e.g.: +// +// opts, _ := cfg.MiddlewareOptions() +// opts = append(opts, middleware.WithParser(pgparser.New())) +// sqlguard.Register("sqlguard-pg", "pgx", opts...) +// +// Keeping this in the config package (not middleware) keeps YAML out of the +// middleware import graph for users who do not use file configuration. +func (c *Config) MiddlewareOptions() ([]middleware.Option, error) { + a, err := c.Analyzer() + if err != nil { + return nil, err + } + opts := []middleware.Option{middleware.WithAnalyzer(a)} + + d, ok, err := c.SlowQueryThreshold() + if err != nil { + return nil, err + } + if ok { + opts = append(opts, middleware.WithSlowQueryThreshold(d)) + } + + dw, ok, err := c.DedupWindow() + if err != nil { + return nil, err + } + if ok { + opts = append(opts, middleware.WithFindingDedup(dw)) + } + return opts, nil +} + +// Middleware loads configuration and returns ready-to-use middleware +// options. If path is non-empty it is loaded directly; otherwise config is +// discovered by walking up from startDir (use "." for the working +// directory). A missing config is not an error — it yields options +// equivalent to the built-in defaults. +func Middleware(path, startDir string) ([]middleware.Option, error) { + var ( + c *Config + err error + ) + if path != "" { + c, err = Load(path) + } else { + c, _, err = Discover(startDir) + } + if err != nil { + return nil, err + } + return c.MiddlewareOptions() +} diff --git a/config/middleware_test.go b/config/middleware_test.go new file mode 100644 index 0000000..cfe4b9f --- /dev/null +++ b/config/middleware_test.go @@ -0,0 +1,50 @@ +package config + +import ( + "database/sql" + "path/filepath" + "strings" + "testing" + + "github.com/KARTIKrocks/sqlguard" + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/KARTIKrocks/sqlguard/reporter" + + _ "github.com/mattn/go-sqlite3" +) + +func TestMiddlewareOptionsAppliesProfile(t *testing.T) { + dir := t.TempDir() + writeConfig(t, dir, "rules:\n disable: [select-star]\n") + + opts, err := Middleware("", dir) + if err != nil { + t.Fatalf("Middleware: %v", err) + } + + var buf strings.Builder + opts = append(opts, middleware.WithReporter(reporter.NewConsoleReporterTo(&buf))) + + name := "sqlguard-cfg-test" + if err := sqlguard.Register(name, "sqlite3", opts...); err != nil { + t.Fatalf("Register: %v", err) + } + db, err := sql.Open(name, filepath.Join(dir, "t.db")) + if err != nil { + t.Fatalf("open: %v", err) + } + defer db.Close() + if _, err := db.Exec("CREATE TABLE u (id INTEGER, name TEXT)"); err != nil { + t.Fatalf("create: %v", err) + } + + rows, err := db.Query("SELECT * FROM u WHERE id = 1") + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if strings.Contains(buf.String(), "select-star") { + t.Errorf("select-star should be disabled via config, got:\n%s", buf.String()) + } +} diff --git a/explain/explain.go b/explain/explain.go new file mode 100644 index 0000000..86a0a00 --- /dev/null +++ b/explain/explain.go @@ -0,0 +1,295 @@ +// Package explain provides SQL EXPLAIN plan analysis. +// It connects to a live database to run EXPLAIN on queries and detect +// performance issues like sequential scans and high-cost operations. +package explain + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +// PlanAnalyzer runs EXPLAIN on queries against a live database. +type PlanAnalyzer struct { + db *sql.DB + dialect string // "postgres" or "mysql" + allowDML bool +} + +// Option configures a PlanAnalyzer. +type Option func(*PlanAnalyzer) + +// WithAllowDML permits EXPLAIN on INSERT/UPDATE/DELETE statements. It is OFF +// by default: only SELECT/WITH are explained, because feeding DML to a prod +// database — even under plain EXPLAIN — is a footgun (and EXPLAIN ANALYZE +// would execute it). When enabled, DML EXPLAINs still run inside a +// transaction that is always rolled back (see analyzePostgres/analyzeMySQL), +// so nothing is committed regardless. +func WithAllowDML() Option { + return func(p *PlanAnalyzer) { p.allowDML = true } +} + +// New creates a PlanAnalyzer for the given database connection. +// dialect must be "postgres" or "mysql". +func New(db *sql.DB, dialect string, opts ...Option) (*PlanAnalyzer, error) { + if db == nil { + return nil, fmt.Errorf("explain: db is nil") + } + dialect = strings.ToLower(dialect) + if dialect != "postgres" && dialect != "mysql" { + return nil, fmt.Errorf("explain: unsupported dialect %q (use 'postgres' or 'mysql')", dialect) + } + p := &PlanAnalyzer{db: db, dialect: dialect} + for _, o := range opts { + o(p) + } + return p, nil +} + +// Result holds the parsed EXPLAIN output and any detected issues. +type Result struct { + Query string + RawPlan string + Issues []analyzer.Result +} + +// Analyze runs EXPLAIN on the given query and returns detected issues. The +// query is validated (see validate) and the EXPLAIN is run inside an +// always-rolled-back transaction, so a query passed here cannot mutate the +// target database. +func (p *PlanAnalyzer) Analyze(ctx context.Context, query string) (*Result, error) { + safe, err := p.validate(query) + if err != nil { + return nil, err + } + + var res *Result + switch p.dialect { + case "postgres": + res, err = p.analyzePostgres(ctx, safe) + case "mysql": + res, err = p.analyzeMySQL(ctx, safe) + default: + return nil, fmt.Errorf("explain: unsupported dialect %q", p.dialect) + } + if res != nil { + fp := analyzer.Fingerprint(query) + for i := range res.Issues { + res.Issues[i].Fingerprint = fp + } + } + return res, err +} + +// validate enforces the EXPLAIN safety policy and returns the single, +// terminator-stripped statement that is safe to concatenate into an EXPLAIN +// prefix. +// +// EXPLAIN cannot take bind parameters, so the query is necessarily +// string-concatenated; the defense is therefore strict input validation, not +// parameterization: +// +// - Reject empty input. +// - Reject multi-statement input using a comment- and string-literal-aware +// check (analyzer.IsMultiStatement). The previous +// strings.Contains(query, ";") check was defeated by a ";" inside a +// -- / /* */ comment or a string literal, and over-rejected a harmless +// trailing ";". +// - Classify the statement via the same parser the analyzer uses. Only +// SELECT/WITH are allowed by default; INSERT/UPDATE/DELETE require +// WithAllowDML; DDL/SET/other is always refused. +func (p *PlanAnalyzer) validate(query string) (string, error) { + q := strings.TrimSpace(query) + if q == "" { + return "", fmt.Errorf("explain: refusing to explain an empty query") + } + if analyzer.IsMultiStatement(q) { + return "", fmt.Errorf("explain: refusing to explain multi-statement input") + } + q = strings.TrimRight(q, "; \t\r\n") + + st, _ := analyzer.NewFallbackParser().Parse(q) + switch st.Kind { + case analyzer.StmtSelect: + return q, nil + case analyzer.StmtInsert, analyzer.StmtUpdate, analyzer.StmtDelete: + if !p.allowDML { + return "", fmt.Errorf("explain: refusing to EXPLAIN a data-modifying statement by default; construct the analyzer with explain.WithAllowDML to opt in") + } + return q, nil + default: + return "", fmt.Errorf("explain: refusing to explain a non-SELECT/WITH/DML statement (DDL, SET, transaction control, or unrecognized)") + } +} + +// PostgreSQL EXPLAIN JSON structures +type pgPlan struct { + Plan pgPlanNode `json:"Plan"` +} + +type pgPlanNode struct { + NodeType string `json:"Node Type"` + TotalCost float64 `json:"Total Cost"` + PlanRows int64 `json:"Plan Rows"` + Plans []pgPlanNode `json:"Plans"` +} + +func (p *PlanAnalyzer) analyzePostgres(ctx context.Context, query string) (*Result, error) { + // query is the validated, single, terminator-free statement from + // validate(). EXPLAIN takes no bind parameters, so concatenation is + // unavoidable; safety comes from validate() plus the rolled-back, + // read-only transaction below. We never use EXPLAIN ANALYZE, so the + // statement is planned, not executed. + explainQuery := "EXPLAIN (FORMAT JSON) " + query + + tx, err := p.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + return nil, fmt.Errorf("explain: failed to begin read-only transaction: %w", err) + } + // Always roll back: an EXPLAIN must never commit anything. + defer func() { _ = tx.Rollback() }() + + var rawJSON string + if err := tx.QueryRowContext(ctx, explainQuery).Scan(&rawJSON); err != nil { + return nil, fmt.Errorf("explain: failed to run EXPLAIN: %w", err) + } + + result := &Result{ + Query: query, + RawPlan: rawJSON, + } + + var plans []pgPlan + if err := json.Unmarshal([]byte(rawJSON), &plans); err != nil { + return result, fmt.Errorf("explain: failed to parse EXPLAIN JSON: %w", err) + } + + if len(plans) > 0 { + p.walkPgPlan(&plans[0].Plan, query, &result.Issues) + } + + return result, nil +} + +func (p *PlanAnalyzer) walkPgPlan(node *pgPlanNode, query string, issues *[]analyzer.Result) { + if node == nil { + return + } + + // Detect sequential scans + if node.NodeType == "Seq Scan" { + severity := analyzer.SeverityInfo + if node.PlanRows > 1000 { + severity = analyzer.SeverityWarning + } + *issues = append(*issues, analyzer.Result{ + RuleName: "seq-scan", + Severity: severity, + Query: query, + Message: fmt.Sprintf("Sequential scan detected (estimated %d rows, cost %.1f)", node.PlanRows, node.TotalCost), + Suggestion: "Consider adding an index to avoid full table scan.", + }) + } + + // Detect high cost operations + if node.TotalCost > 10000 { + *issues = append(*issues, analyzer.Result{ + RuleName: "high-cost", + Severity: analyzer.SeverityWarning, + Query: query, + Message: fmt.Sprintf("High cost operation: %s (cost %.1f)", node.NodeType, node.TotalCost), + Suggestion: "Review query plan and consider optimization.", + }) + } + + // Recurse into child plans + for i := range node.Plans { + p.walkPgPlan(&node.Plans[i], query, issues) + } +} + +func (p *PlanAnalyzer) analyzeMySQL(ctx context.Context, query string) (*Result, error) { + // See analyzePostgres: validated single statement, no ANALYZE, run in an + // always-rolled-back read-only transaction so EXPLAIN cannot mutate data. + explainQuery := "EXPLAIN " + query + + tx, err := p.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + return nil, fmt.Errorf("explain: failed to begin read-only transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + rows, err := tx.QueryContext(ctx, explainQuery) + if err != nil { + return nil, fmt.Errorf("explain: failed to run EXPLAIN: %w", err) + } + defer func() { _ = rows.Close() }() + + result := &Result{ + Query: query, + } + + for rows.Next() { + var ( + id int + selectType string + table sql.NullString + partitions sql.NullString + accessType sql.NullString + possibleKeys sql.NullString + key sql.NullString + keyLen sql.NullString + ref sql.NullString + rowCount sql.NullInt64 + filtered sql.NullFloat64 + extra sql.NullString + ) + + if err := rows.Scan(&id, &selectType, &table, &partitions, &accessType, &possibleKeys, &key, &keyLen, &ref, &rowCount, &filtered, &extra); err != nil { + return result, fmt.Errorf("explain: failed to scan row: %w", err) + } + + // Detect full table scans (type = ALL) + if accessType.Valid && accessType.String == "ALL" { + result.Issues = append(result.Issues, analyzer.Result{ + RuleName: "full-table-scan", + Severity: analyzer.SeverityWarning, + Query: query, + Message: fmt.Sprintf("Full table scan on %s (estimated %d rows)", table.String, rowCount.Int64), + Suggestion: "Consider adding an index to avoid full table scan.", + }) + } + + // Detect missing indexes + if (!key.Valid || key.String == "") && (!possibleKeys.Valid || possibleKeys.String == "") && table.Valid && table.String != "" { + result.Issues = append(result.Issues, analyzer.Result{ + RuleName: "no-index-used", + Severity: analyzer.SeverityWarning, + Query: query, + Message: fmt.Sprintf("No index used on table %s", table.String), + Suggestion: "Consider adding an index on the filtered/joined columns.", + }) + } + + // Detect filesort + if strings.Contains(extra.String, "Using filesort") { + result.Issues = append(result.Issues, analyzer.Result{ + RuleName: "filesort", + Severity: analyzer.SeverityInfo, + Query: query, + Message: fmt.Sprintf("Filesort detected on table %s", table.String), + Suggestion: "Consider adding an index that covers the ORDER BY columns.", + }) + } + } + + if err := rows.Err(); err != nil { + return result, fmt.Errorf("explain: error reading rows: %w", err) + } + + return result, nil +} diff --git a/explain/explain_test.go b/explain/explain_test.go new file mode 100644 index 0000000..d251bcb --- /dev/null +++ b/explain/explain_test.go @@ -0,0 +1,52 @@ +package explain + +import ( + "strings" + "testing" +) + +func TestValidate(t *testing.T) { + cases := []struct { + name string + query string + allowDML bool + wantErr string // substring; "" means no error + wantSafe string // expected returned statement when no error + }{ + {"select ok", `SELECT * FROM t WHERE id = 1`, false, "", `SELECT * FROM t WHERE id = 1`}, + {"with ok", `WITH c AS (SELECT 1) SELECT * FROM c`, false, "", `WITH c AS (SELECT 1) SELECT * FROM c`}, + {"trailing semicolon trimmed", `SELECT 1;`, false, "", `SELECT 1`}, + {"empty", ` `, false, "empty", ""}, + {"stacked statements", `SELECT 1; DROP TABLE users`, false, "multi-statement", ""}, + {"semicolon in comment is fine", "SELECT 1 -- ; DROP\n", false, "", "SELECT 1 -- ; DROP"}, + {"semicolon in string is fine", `SELECT * FROM t WHERE s = 'a;b'`, false, "", `SELECT * FROM t WHERE s = 'a;b'`}, + {"stack hidden after string", `SELECT 'a;b'; DELETE FROM t`, false, "multi-statement", ""}, + {"dml refused by default", `DELETE FROM t WHERE id = 1`, false, "data-modifying", ""}, + {"update refused by default", `UPDATE t SET a = 1`, false, "data-modifying", ""}, + {"dml allowed with opt-in", `DELETE FROM t WHERE id = 1`, true, "", `DELETE FROM t WHERE id = 1`}, + {"ddl always refused", `DROP TABLE users`, true, "non-SELECT", ""}, + {"set always refused", `SET search_path = x`, true, "non-SELECT", ""}, + {"truncate refused", `TRUNCATE t`, true, "non-SELECT", ""}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + p := &PlanAnalyzer{dialect: "postgres", allowDML: c.allowDML} + safe, err := p.validate(c.query) + if c.wantErr == "" { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if safe != c.wantSafe { + t.Errorf("safe = %q, want %q", safe, c.wantSafe) + } + return + } + if err == nil { + t.Fatalf("expected error containing %q, got nil", c.wantErr) + } + if !strings.Contains(err.Error(), c.wantErr) { + t.Errorf("error %q does not contain %q", err, c.wantErr) + } + }) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..b178be2 --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module github.com/KARTIKrocks/sqlguard + +go 1.26 + +require ( + github.com/mattn/go-sqlite3 v1.14.45 + github.com/spf13/cobra v1.10.2 + golang.org/x/tools v0.45.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + golang.org/x/mod v0.36.0 // indirect + golang.org/x/sync v0.20.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..37b285f --- /dev/null +++ b/go.sum @@ -0,0 +1,24 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk= +github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4= +golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8= +golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/integrations/bunguard/bunguard.go b/integrations/bunguard/bunguard.go new file mode 100644 index 0000000..dc3a5bf --- /dev/null +++ b/integrations/bunguard/bunguard.go @@ -0,0 +1,75 @@ +// Package bunguard integrates sqlguard with bun (github.com/uptrace/bun). +// +// Analysis is driven by the single shared sqlguard core (middleware.Guard), +// so redaction-by-default, stable fingerprints, the pluggable real-grammar +// parser, slow-query timing and N+1 detection behave identically to the +// database/sql driver wrapper, pgxguard and gormguard. There is no parallel +// option surface — configure with the standard middleware options: +// +// sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) +// db := bun.NewDB(sqldb, pgdialect.New()) +// db.AddQueryHook(bunguard.New( +// middleware.WithSlowQueryThreshold(500*time.Millisecond), +// middleware.WithN1Detection(10, time.Second), +// )) +// +// bun exposes the final rendered SQL and a start timestamp on the QueryEvent +// in its AfterQuery hook, so this uses the explicit Check+CheckLatency pair +// (matching gormguard) rather than middleware.Guard.Observe: static rules run +// on every query, latency is reported only on success. +package bunguard + +import ( + "context" + "time" + + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/uptrace/bun" +) + +// QueryHook implements bun.QueryHook and drives every traced statement +// through the shared sqlguard analysis core. +type QueryHook struct { + g *middleware.Guard +} + +// Compile-time proof we satisfy bun.QueryHook. +var _ bun.QueryHook = (*QueryHook)(nil) + +// New creates a new sqlguard bun query hook. It accepts the standard sqlguard +// middleware options (WithAnalyzer, WithReporter, WithSlowQueryThreshold, +// WithParser, WithN1Detection, …) — the same option set the database/sql +// driver wrapper, pgxguard and gormguard use, so there is no parallel +// configuration surface to drift. +func New(opts ...middleware.Option) *QueryHook { + return &QueryHook{g: middleware.NewGuard(opts...)} +} + +// ResetN1 clears N+1 tracker state. Call it at a per-request boundary +// (e.g. end of an HTTP handler) to scope N+1 detection to one unit of work. +// No-op unless WithN1Detection was passed to New. +func (h *QueryHook) ResetN1() { h.g.ResetN1() } + +// BeforeQuery implements bun.QueryHook. bun stamps event.StartTime itself +// before invoking the hook, so there is nothing to stash here. +func (h *QueryHook) BeforeQuery(ctx context.Context, _ *bun.QueryEvent) context.Context { + return ctx +} + +// AfterQuery implements bun.QueryHook. event.Query holds the rendered SQL. +func (h *QueryHook) AfterQuery(_ context.Context, event *bun.QueryEvent) { + sql := event.Query + if sql == "" { + return + } + + // Static rules + N+1 run on every call (matches Observe semantics). + h.g.Check(sql) + + // Latency is reported only on success — a failed query's duration is + // meaningless. This mirrors middleware.Guard.Observe. + if event.Err != nil { + return + } + h.g.CheckLatency(sql, time.Since(event.StartTime)) +} diff --git a/integrations/bunguard/bunguard_test.go b/integrations/bunguard/bunguard_test.go new file mode 100644 index 0000000..5b5fbd5 --- /dev/null +++ b/integrations/bunguard/bunguard_test.go @@ -0,0 +1,185 @@ +package bunguard + +import ( + "context" + "database/sql" + "strings" + "sync" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/middleware" + _ "github.com/mattn/go-sqlite3" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" +) + +// capture is a thread-safe in-memory Reporter for assertions. +type capture struct { + mu sync.Mutex + r []analyzer.Result +} + +func (c *capture) Report(rs []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + c.r = append(c.r, rs...) +} + +func (c *capture) snapshot() []analyzer.Result { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]analyzer.Result, len(c.r)) + copy(out, c.r) + return out +} + +func (c *capture) has(rule string) bool { + for _, r := range c.snapshot() { + if r.RuleName == rule { + return true + } + } + return false +} + +type user struct { + bun.BaseModel `bun:"table:users"` + ID int64 `bun:"id,pk"` + Email string `bun:"email"` +} + +// newDBWithCapture spins up an in-memory sqlite-backed *bun.DB with the +// sqlguard hook registered, so the integration runs end-to-end (QueryHook +// seam → driver round trip) rather than mocked. +func newDBWithCapture(t *testing.T, opts ...middleware.Option) (*bun.DB, *capture, *QueryHook) { + t.Helper() + sqldb, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("sql.Open: %v", err) + } + t.Cleanup(func() { _ = sqldb.Close() }) + db := bun.NewDB(sqldb, sqlitedialect.New()) + + ctx := context.Background() + if _, err := db.NewCreateTable().Model((*user)(nil)).Exec(ctx); err != nil { + t.Fatalf("create table: %v", err) + } + if _, err := db.NewInsert().Model(&user{ID: 1, Email: "leak@example.com"}).Exec(ctx); err != nil { + t.Fatalf("seed: %v", err) + } + + cap := &capture{} + opts = append([]middleware.Option{middleware.WithReporter(cap)}, opts...) + hook := New(opts...) + db.AddQueryHook(hook) + // Hook registered after seeding, so capture starts clean — every test + // asserts only on findings from its own queries. + return db, cap, hook +} + +func TestHook_DetectsRawSelectStar(t *testing.T) { + db, cap, _ := newDBWithCapture(t) + var us []user + if err := db.NewRaw("SELECT * FROM users").Scan(context.Background(), &us); err != nil { + t.Fatalf("Raw: %v", err) + } + if !cap.has("select-star") { + t.Fatalf("expected select-star finding, got %+v", cap.snapshot()) + } +} + +// TestHook_RedactsLiteralsByDefault asserts the headline redaction guarantee: +// single-quoted literals never reach Result.Query and Fingerprint is always +// populated. +func TestHook_RedactsLiteralsByDefault(t *testing.T) { + db, cap, _ := newDBWithCapture(t) + var us []user + if err := db.NewRaw("SELECT * FROM users WHERE email = 'leak@example.com'").Scan(context.Background(), &us); err != nil { + t.Fatalf("Raw: %v", err) + } + results := cap.snapshot() + if len(results) == 0 { + t.Fatal("expected at least one finding") + } + for _, r := range results { + if strings.Contains(r.Query, "leak@example.com") { + t.Errorf("literal leaked into Result.Query: %q (rule=%s)", r.Query, r.RuleName) + } + if r.Fingerprint == "" { + t.Errorf("Fingerprint must always be populated, got empty for rule %s", r.RuleName) + } + } +} + +func TestHook_SlowQueryReportedOnSuccess(t *testing.T) { + db, cap, _ := newDBWithCapture(t, middleware.WithSlowQueryThreshold(0)) + var u user + if err := db.NewSelect().Model(&u).Where("id = ?", 1).Scan(context.Background()); err != nil { + t.Fatalf("Select: %v", err) + } + if !cap.has("slow-query") { + t.Fatalf("expected slow-query finding with zero threshold, got %+v", cap.snapshot()) + } +} + +func TestHook_SlowQuerySuppressedOnError(t *testing.T) { + db, cap, _ := newDBWithCapture(t, middleware.WithSlowQueryThreshold(0)) + var dst int + err := db.NewRaw("SELECT id FROM no_such_table_xyz WHERE id = 1").Scan(context.Background(), &dst) + if err == nil { + t.Fatal("expected error from selecting a missing table") + } + if cap.has("slow-query") { + t.Fatalf("slow-query must not fire when the query failed; got %+v", cap.snapshot()) + } +} + +func TestHook_NPlusOneAcrossCalls(t *testing.T) { + db, cap, _ := newDBWithCapture(t, middleware.WithN1Detection(3, time.Second)) + var u user + for range 3 { + if err := db.NewRaw("SELECT id FROM users WHERE id = 1").Scan(context.Background(), &u); err != nil { + t.Fatalf("Raw: %v", err) + } + } + if !cap.has("n-plus-one") { + t.Fatalf("expected n-plus-one finding after 3 identical queries, got %+v", cap.snapshot()) + } +} + +func TestHook_ResetN1ClearsState(t *testing.T) { + db, cap, hook := newDBWithCapture(t, middleware.WithN1Detection(3, time.Second)) + var u user + for range 2 { + if err := db.NewRaw("SELECT id FROM users WHERE id = 1").Scan(context.Background(), &u); err != nil { + t.Fatalf("Raw: %v", err) + } + } + hook.ResetN1() + if err := db.NewRaw("SELECT id FROM users WHERE id = 1").Scan(context.Background(), &u); err != nil { + t.Fatalf("Raw: %v", err) + } + if cap.has("n-plus-one") { + t.Fatalf("n-plus-one should not fire — ResetN1 zeroed the counter; got %+v", cap.snapshot()) + } +} + +// Proves UPDATE / DELETE statements also flow through Guard. +func TestHook_UpdateAndDeleteAnalyzed(t *testing.T) { + db, cap, _ := newDBWithCapture(t) + ctx := context.Background() + if _, err := db.NewRaw("UPDATE users SET email = 'x'").Exec(ctx); err != nil { + t.Fatalf("UPDATE: %v", err) + } + if !cap.has("update-without-where") { + t.Fatalf("expected update-without-where, got %+v", cap.snapshot()) + } + if _, err := db.NewRaw("DELETE FROM users").Exec(ctx); err != nil { + t.Fatalf("DELETE: %v", err) + } + if !cap.has("delete-without-where") { + t.Fatalf("expected delete-without-where, got %+v", cap.snapshot()) + } +} diff --git a/integrations/bunguard/go.mod b/integrations/bunguard/go.mod new file mode 100644 index 0000000..c089ec0 --- /dev/null +++ b/integrations/bunguard/go.mod @@ -0,0 +1,21 @@ +module github.com/KARTIKrocks/sqlguard/integrations/bunguard + +go 1.26 + +require ( + github.com/KARTIKrocks/sqlguard v0.0.0 + github.com/mattn/go-sqlite3 v1.14.45 + github.com/uptrace/bun v1.2.18 + github.com/uptrace/bun/dialect/sqlitedialect v1.2.18 +) + +require ( + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect + github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect + github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + golang.org/x/sys v0.41.0 // indirect +) + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/integrations/bunguard/go.sum b/integrations/bunguard/go.sum new file mode 100644 index 0000000..f284f00 --- /dev/null +++ b/integrations/bunguard/go.sum @@ -0,0 +1,26 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk= +github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= +github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= +github.com/uptrace/bun v1.2.18 h1:3HnRcMfS6OBPMG1eSOzlbFJ/X/AyMEJb7rMxE6VQvDU= +github.com/uptrace/bun v1.2.18/go.mod h1:wNltaKJk4JtOt4SG5I5zmA7v0/Mzjh1+/S906Rayd3Y= +github.com/uptrace/bun/dialect/sqlitedialect v1.2.18 h1:Z33SY/U++XK9uGWqS4h8OZVxfCXguIG+sU9cYq2PGFQ= +github.com/uptrace/bun/dialect/sqlitedialect v1.2.18/go.mod h1:1MVOS/Ncy4FZbkJcgUFH6OqYoQinYNjkEwsmNQEXz2A= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/integrations/entguard/entguard.go b/integrations/entguard/entguard.go new file mode 100644 index 0000000..68191a1 --- /dev/null +++ b/integrations/entguard/entguard.go @@ -0,0 +1,123 @@ +// Package entguard integrates sqlguard with ent (entgo.io/ent). +// +// ent runs on database/sql, so the simplest coverage is already available by +// pointing entsql at a *sql.DB obtained from sqlguard.Register / OpenDB. This +// package is the dedicated alternative: it decorates ent's own +// dialect.Driver seam, so it works regardless of how the underlying *sql.DB +// was opened (including ent's dialect.DebugDriver chain) and mirrors ent's +// built-in dialect.Debug wrapper. +// +// Analysis is driven by the single shared sqlguard core (middleware.Guard), +// so redaction-by-default, stable fingerprints, the pluggable real-grammar +// parser, slow-query timing and N+1 detection behave identically to every +// other sqlguard surface. There is no parallel option surface — configure +// with the standard middleware options: +// +// drv, _ := entsql.Open(dialect.Postgres, dsn) +// guarded := entguard.Wrap(drv, +// middleware.WithSlowQueryThreshold(500*time.Millisecond), +// middleware.WithN1Detection(10, time.Second), +// ) +// client := ent.NewClient(ent.Driver(guarded)) +// +// Every Exec/Query — on the driver and on transactions it opens — flows +// through middleware.Guard.Observe: static rules run on every call, latency +// is recorded only on success. +package entguard + +import ( + "context" + "database/sql" + + "entgo.io/ent/dialect" + "github.com/KARTIKrocks/sqlguard/middleware" +) + +// Driver decorates an ent dialect.Driver, routing every statement through the +// shared sqlguard analysis core. +type Driver struct { + dialect.Driver + g *middleware.Guard +} + +// Compile-time proof we still satisfy ent's driver contract. +var _ dialect.Driver = (*Driver)(nil) + +// Wrap decorates an ent dialect.Driver. It accepts the standard sqlguard +// middleware options (WithAnalyzer, WithReporter, WithSlowQueryThreshold, +// WithParser, WithN1Detection, …) — the same option set every other sqlguard +// surface uses, so there is no parallel configuration surface to drift. +func Wrap(d dialect.Driver, opts ...middleware.Option) *Driver { + return &Driver{Driver: d, g: middleware.NewGuard(opts...)} +} + +// ResetN1 clears N+1 tracker state. Call it at a per-request boundary +// (e.g. end of an HTTP handler) to scope N+1 detection to one unit of work. +// No-op unless WithN1Detection was passed to Wrap. +func (d *Driver) ResetN1() { d.g.ResetN1() } + +// Exec implements dialect.ExecQuerier. +func (d *Driver) Exec(ctx context.Context, query string, args, v any) error { + done := d.g.Observe(query) + err := d.Driver.Exec(ctx, query, args, v) + done(err) + return err +} + +// Query implements dialect.ExecQuerier. +func (d *Driver) Query(ctx context.Context, query string, args, v any) error { + done := d.g.Observe(query) + err := d.Driver.Query(ctx, query, args, v) + done(err) + return err +} + +// Tx wraps the transaction so statements executed inside it are analyzed too. +func (d *Driver) Tx(ctx context.Context) (dialect.Tx, error) { + t, err := d.Driver.Tx(ctx) + if err != nil { + return nil, err + } + return &tx{Tx: t, g: d.g}, nil +} + +// BeginTx forwards to the wrapped driver's BeginTx when it implements one +// (entsql.Driver does — this is how ent honours read-only / isolation +// options), and wraps the resulting transaction. It degrades to Tx when the +// base driver has no BeginTx, matching ent's own fallback. +func (d *Driver) BeginTx(ctx context.Context, opts *sql.TxOptions) (dialect.Tx, error) { + bt, ok := d.Driver.(interface { + BeginTx(context.Context, *sql.TxOptions) (dialect.Tx, error) + }) + if !ok { + return d.Tx(ctx) + } + t, err := bt.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &tx{Tx: t, g: d.g}, nil +} + +// tx decorates a dialect.Tx so in-transaction Exec/Query are analyzed. +// Commit/Rollback are inherited from the embedded transaction unchanged. +type tx struct { + dialect.Tx + g *middleware.Guard +} + +// Exec implements dialect.ExecQuerier. +func (t *tx) Exec(ctx context.Context, query string, args, v any) error { + done := t.g.Observe(query) + err := t.Tx.Exec(ctx, query, args, v) + done(err) + return err +} + +// Query implements dialect.ExecQuerier. +func (t *tx) Query(ctx context.Context, query string, args, v any) error { + done := t.g.Observe(query) + err := t.Tx.Query(ctx, query, args, v) + done(err) + return err +} diff --git a/integrations/entguard/entguard_test.go b/integrations/entguard/entguard_test.go new file mode 100644 index 0000000..aa135b8 --- /dev/null +++ b/integrations/entguard/entguard_test.go @@ -0,0 +1,193 @@ +package entguard + +import ( + "context" + "strings" + "sync" + "testing" + "time" + + entsql "entgo.io/ent/dialect/sql" + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/middleware" + _ "github.com/mattn/go-sqlite3" +) + +// capture is a thread-safe in-memory Reporter for assertions. +type capture struct { + mu sync.Mutex + r []analyzer.Result +} + +func (c *capture) Report(rs []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + c.r = append(c.r, rs...) +} + +func (c *capture) snapshot() []analyzer.Result { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]analyzer.Result, len(c.r)) + copy(out, c.r) + return out +} + +func (c *capture) has(rule string) bool { + for _, r := range c.snapshot() { + if r.RuleName == rule { + return true + } + } + return false +} + +// newDriverWithCapture opens a real sqlite-backed ent dialect.Driver, seeds it +// through the *unwrapped* driver (so the capture starts clean), then wraps it +// with the sqlguard decorator. The integration thus runs end-to-end +// (dialect.Driver seam → database/sql round trip) rather than mocked. +func newDriverWithCapture(t *testing.T, opts ...middleware.Option) (*Driver, *capture) { + t.Helper() + drv, err := entsql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("entsql.Open: %v", err) + } + t.Cleanup(func() { _ = drv.Close() }) + + ctx := context.Background() + if err := drv.Exec(ctx, "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT)", []any{}, nil); err != nil { + t.Fatalf("create table: %v", err) + } + if err := drv.Exec(ctx, "INSERT INTO users (id, email) VALUES (?, ?)", []any{1, "leak@example.com"}, nil); err != nil { + t.Fatalf("seed: %v", err) + } + + cap := &capture{} + opts = append([]middleware.Option{middleware.WithReporter(cap)}, opts...) + return Wrap(drv, opts...), cap +} + +func query(t *testing.T, ctx context.Context, q interface { + Query(context.Context, string, any, any) error +}, sqlText string) error { + t.Helper() + var rows entsql.Rows + err := q.Query(ctx, sqlText, []any{}, &rows) + if err == nil { + _ = rows.Close() + } + return err +} + +func TestDriver_DetectsSelectStar(t *testing.T) { + drv, cap := newDriverWithCapture(t) + if err := query(t, context.Background(), drv, "SELECT * FROM users"); err != nil { + t.Fatalf("Query: %v", err) + } + if !cap.has("select-star") { + t.Fatalf("expected select-star finding, got %+v", cap.snapshot()) + } +} + +// TestDriver_RedactsLiteralsByDefault asserts the headline redaction +// guarantee: single-quoted literals never reach Result.Query and Fingerprint +// is always populated. +func TestDriver_RedactsLiteralsByDefault(t *testing.T) { + drv, cap := newDriverWithCapture(t) + if err := query(t, context.Background(), drv, "SELECT * FROM users WHERE email = 'leak@example.com'"); err != nil { + t.Fatalf("Query: %v", err) + } + results := cap.snapshot() + if len(results) == 0 { + t.Fatal("expected at least one finding") + } + for _, r := range results { + if strings.Contains(r.Query, "leak@example.com") { + t.Errorf("literal leaked into Result.Query: %q (rule=%s)", r.Query, r.RuleName) + } + if r.Fingerprint == "" { + t.Errorf("Fingerprint must always be populated, got empty for rule %s", r.RuleName) + } + } +} + +func TestDriver_SlowQueryReportedOnSuccess(t *testing.T) { + drv, cap := newDriverWithCapture(t, middleware.WithSlowQueryThreshold(0)) + if err := query(t, context.Background(), drv, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Query: %v", err) + } + if !cap.has("slow-query") { + t.Fatalf("expected slow-query finding with zero threshold, got %+v", cap.snapshot()) + } +} + +func TestDriver_SlowQuerySuppressedOnError(t *testing.T) { + drv, cap := newDriverWithCapture(t, middleware.WithSlowQueryThreshold(0)) + err := query(t, context.Background(), drv, "SELECT id FROM no_such_table_xyz WHERE id = 1") + if err == nil { + t.Fatal("expected error from selecting a missing table") + } + if cap.has("slow-query") { + t.Fatalf("slow-query must not fire when the query failed; got %+v", cap.snapshot()) + } +} + +func TestDriver_NPlusOneAcrossCalls(t *testing.T) { + drv, cap := newDriverWithCapture(t, middleware.WithN1Detection(3, time.Second)) + for range 3 { + if err := query(t, context.Background(), drv, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Query: %v", err) + } + } + if !cap.has("n-plus-one") { + t.Fatalf("expected n-plus-one finding after 3 identical queries, got %+v", cap.snapshot()) + } +} + +func TestDriver_ResetN1ClearsState(t *testing.T) { + drv, cap := newDriverWithCapture(t, middleware.WithN1Detection(3, time.Second)) + for range 2 { + if err := query(t, context.Background(), drv, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Query: %v", err) + } + } + drv.ResetN1() + if err := query(t, context.Background(), drv, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Query: %v", err) + } + if cap.has("n-plus-one") { + t.Fatalf("n-plus-one should not fire — ResetN1 zeroed the counter; got %+v", cap.snapshot()) + } +} + +func TestDriver_ExecUpdateAnalyzed(t *testing.T) { + drv, cap := newDriverWithCapture(t) + if err := drv.Exec(context.Background(), "UPDATE users SET email = 'x'", []any{}, nil); err != nil { + t.Fatalf("Exec: %v", err) + } + if !cap.has("update-without-where") { + t.Fatalf("expected update-without-where, got %+v", cap.snapshot()) + } +} + +// TestDriver_TxQueriesAnalyzed proves the transaction wrapper also routes +// in-tx statements through Guard — a query class the database/sql-only path +// would miss if Tx weren't decorated. +func TestDriver_TxQueriesAnalyzed(t *testing.T) { + drv, cap := newDriverWithCapture(t) + ctx := context.Background() + tx, err := drv.Tx(ctx) + if err != nil { + t.Fatalf("Tx: %v", err) + } + if err := query(t, ctx, tx, "SELECT * FROM users"); err != nil { + _ = tx.Rollback() + t.Fatalf("tx Query: %v", err) + } + if err := tx.Commit(); err != nil { + t.Fatalf("Commit: %v", err) + } + if !cap.has("select-star") { + t.Fatalf("expected select-star finding from in-tx query, got %+v", cap.snapshot()) + } +} diff --git a/integrations/entguard/go.mod b/integrations/entguard/go.mod new file mode 100644 index 0000000..a58e275 --- /dev/null +++ b/integrations/entguard/go.mod @@ -0,0 +1,13 @@ +module github.com/KARTIKrocks/sqlguard/integrations/entguard + +go 1.26 + +require ( + entgo.io/ent v0.14.6 + github.com/KARTIKrocks/sqlguard v0.0.0 + github.com/mattn/go-sqlite3 v1.14.45 +) + +require github.com/google/uuid v1.3.0 // indirect + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/integrations/entguard/go.sum b/integrations/entguard/go.sum new file mode 100644 index 0000000..90db7ad --- /dev/null +++ b/integrations/entguard/go.sum @@ -0,0 +1,16 @@ +entgo.io/ent v0.14.6 h1:/f2696BpwuWAEEG6PVGWflg6+Inrpq4pRWuNlWz/Skk= +entgo.io/ent v0.14.6/go.mod h1:z46QBUdGC+BATwsedbDuREfSS0oSCV+csdEYlL4p73s= +github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk= +github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/integrations/gormguard/go.mod b/integrations/gormguard/go.mod new file mode 100644 index 0000000..5fd0221 --- /dev/null +++ b/integrations/gormguard/go.mod @@ -0,0 +1,18 @@ +module github.com/KARTIKrocks/sqlguard/integrations/gormguard + +go 1.26 + +require ( + github.com/KARTIKrocks/sqlguard v0.0.0 + gorm.io/driver/sqlite v1.6.0 + gorm.io/gorm v1.31.1 +) + +require ( + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/mattn/go-sqlite3 v1.14.45 // indirect + golang.org/x/text v0.20.0 // indirect +) + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/integrations/gormguard/go.sum b/integrations/gormguard/go.sum new file mode 100644 index 0000000..3bb1e0e --- /dev/null +++ b/integrations/gormguard/go.sum @@ -0,0 +1,12 @@ +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk= +github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/integrations/gormguard/gormguard.go b/integrations/gormguard/gormguard.go new file mode 100644 index 0000000..29237d2 --- /dev/null +++ b/integrations/gormguard/gormguard.go @@ -0,0 +1,134 @@ +// Package gormguard integrates sqlguard with GORM. +// +// Analysis is driven by the single shared sqlguard core (middleware.Guard), +// so redaction-by-default, stable fingerprints, the pluggable real-grammar +// parser, slow-query timing and N+1 detection behave identically to the +// database/sql driver wrapper and to pgxguard. There is no parallel option +// surface — configure with the standard middleware options: +// +// gormDB, _ := gorm.Open(postgres.Open(dsn), &gorm.Config{}) +// gormguard.Register(gormDB, +// middleware.WithSlowQueryThreshold(500*time.Millisecond), +// middleware.WithN1Detection(10, time.Second), +// ) +// +// GORM only exposes the final built SQL in its after-callback (it has not +// been generated when the before-callback fires), so this plugin uses the +// explicit Check+CheckLatency pair rather than middleware.Guard.Observe. +// Behaviour matches Observe semantically: static rules run on every call, +// latency is reported only on success. +package gormguard + +import ( + "time" + + "github.com/KARTIKrocks/sqlguard/middleware" + "gorm.io/gorm" +) + +// Plugin implements gorm.Plugin and drives every traced statement through +// the shared sqlguard analysis core. +type Plugin struct { + g *middleware.Guard +} + +// Compile-time proof we satisfy gorm.Plugin. +var _ gorm.Plugin = (*Plugin)(nil) + +// New creates a new sqlguard GORM plugin. It accepts the standard sqlguard +// middleware options (WithAnalyzer, WithReporter, WithSlowQueryThreshold, +// WithParser, WithN1Detection, …) — the same option set the database/sql +// driver wrapper and pgxguard use, so there is no parallel configuration +// surface to drift. +func New(opts ...middleware.Option) *Plugin { + return &Plugin{g: middleware.NewGuard(opts...)} +} + +// Name implements gorm.Plugin. +func (p *Plugin) Name() string { return "sqlguard" } + +// ResetN1 clears N+1 tracker state. Call it at a per-request boundary +// (e.g. end of an HTTP handler) to scope N+1 detection to one unit of work. +// No-op unless WithN1Detection was passed to New / Register. +func (p *Plugin) ResetN1() { p.g.ResetN1() } + +// Initialize registers before/after callbacks on every GORM callback chain. +// +// GORM v2 routes operations through six distinct callback chains: +// - Create/Update/Delete — ORM-style mutating operations +// - Query — ORM-style reads (First/Find/Take/…) +// - Row — raw SQL that returns rows (db.Raw().Scan / .Row) +// - Raw — raw SQL without rows (db.Exec) +// +// Missing any chain silently uncovers a query class — pre-rewrite, only +// Create/Query/Update/Delete were hooked, so every db.Raw and db.Exec +// bypassed analysis (and there were no tests to catch it). All six chains +// are now registered. +// +// SQL is analyzed in the after-callback because GORM has not yet rendered +// db.Statement.SQL when the before-callback fires for the ORM chains. +func (p *Plugin) Initialize(db *gorm.DB) error { + cb := db.Callback() + registrations := []struct { + before, after func(name string, fn func(*gorm.DB)) error + chain string + }{ + {before: cb.Create().Before("gorm:create").Register, after: cb.Create().After("gorm:create").Register, chain: "create"}, + {before: cb.Query().Before("gorm:query").Register, after: cb.Query().After("gorm:query").Register, chain: "query"}, + {before: cb.Update().Before("gorm:update").Register, after: cb.Update().After("gorm:update").Register, chain: "update"}, + {before: cb.Delete().Before("gorm:delete").Register, after: cb.Delete().After("gorm:delete").Register, chain: "delete"}, + {before: cb.Row().Before("gorm:row").Register, after: cb.Row().After("gorm:row").Register, chain: "row"}, + {before: cb.Raw().Before("gorm:raw").Register, after: cb.Raw().After("gorm:raw").Register, chain: "raw"}, + } + for _, r := range registrations { + if err := r.before("sqlguard:before_"+r.chain, p.before); err != nil { + return err + } + if err := r.after("sqlguard:after_"+r.chain, p.after); err != nil { + return err + } + } + return nil +} + +// startTimeKey is the per-statement context key under which the before +// callback stashes the start timestamp. Unexported so it can't collide with +// keys set by other plugins. +const startTimeKey = "sqlguard:start_time" + +func (p *Plugin) before(db *gorm.DB) { + db.Set(startTimeKey, time.Now()) +} + +func (p *Plugin) after(db *gorm.DB) { + if db.Statement == nil { + return + } + sql := db.Statement.SQL.String() + if sql == "" { + return + } + + // Static rules + N+1 run on every call (matches Observe semantics). + p.g.Check(sql) + + // Latency is reported only on success — a failed query's duration is + // meaningless. This mirrors middleware.Guard.Observe. + if db.Error != nil { + return + } + val, ok := db.Get(startTimeKey) + if !ok { + return + } + start, ok := val.(time.Time) + if !ok { + return + } + p.g.CheckLatency(sql, time.Since(start)) +} + +// Register is a convenience function to create and register the plugin. +func Register(db *gorm.DB, opts ...middleware.Option) error { + return db.Use(New(opts...)) +} diff --git a/integrations/gormguard/gormguard_test.go b/integrations/gormguard/gormguard_test.go new file mode 100644 index 0000000..c0e03cf --- /dev/null +++ b/integrations/gormguard/gormguard_test.go @@ -0,0 +1,200 @@ +package gormguard + +import ( + "context" + "strings" + "sync" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/middleware" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +// capture is a thread-safe in-memory Reporter for assertions. +type capture struct { + mu sync.Mutex + r []analyzer.Result +} + +func (c *capture) Report(rs []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + c.r = append(c.r, rs...) +} + +func (c *capture) snapshot() []analyzer.Result { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]analyzer.Result, len(c.r)) + copy(out, c.r) + return out +} + +func (c *capture) has(rule string) bool { + for _, r := range c.snapshot() { + if r.RuleName == rule { + return true + } + } + return false +} + +type user struct { + ID int64 `gorm:"primaryKey"` + Email string +} + +// newDBWithCapture spins up an in-memory sqlite-backed *gorm.DB with the +// sqlguard plugin registered, so the integration runs end-to-end (callback +// seam → driver round trip) rather than mocked. +func newDBWithCapture(t *testing.T, opts ...middleware.Option) (*gorm.DB, *capture, *Plugin) { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("gorm.Open: %v", err) + } + if err := db.AutoMigrate(&user{}); err != nil { + t.Fatalf("AutoMigrate: %v", err) + } + if err := db.Create(&user{ID: 1, Email: "leak@example.com"}).Error; err != nil { + t.Fatalf("seed: %v", err) + } + + cap := &capture{} + opts = append([]middleware.Option{middleware.WithReporter(cap)}, opts...) + plugin := New(opts...) + if err := db.Use(plugin); err != nil { + t.Fatalf("db.Use: %v", err) + } + // Reset capture so the seed INSERT's findings don't pollute test + // assertions — every test wants only the findings from its own queries. + cap.mu.Lock() + cap.r = nil + cap.mu.Unlock() + return db, cap, plugin +} + +func TestPlugin_DetectsRawSelectStar(t *testing.T) { + db, cap, _ := newDBWithCapture(t) + var us []user + if err := db.Raw("SELECT * FROM users").Scan(&us).Error; err != nil { + t.Fatalf("Raw: %v", err) + } + if !cap.has("select-star") { + t.Fatalf("expected select-star finding, got %+v", cap.snapshot()) + } +} + +// TestPlugin_RedactsLiteralsByDefault is the headline 11.1 regression: +// the old hand-rolled after() set Result.Query to the raw SQL, so single- +// quoted literals leaked into log sinks. After the Guard rewrite Query +// must be the redacted form and Fingerprint must always be populated. +func TestPlugin_RedactsLiteralsByDefault(t *testing.T) { + db, cap, _ := newDBWithCapture(t) + var us []user + if err := db.Raw("SELECT * FROM users WHERE email = 'leak@example.com'").Scan(&us).Error; err != nil { + t.Fatalf("Raw: %v", err) + } + results := cap.snapshot() + if len(results) == 0 { + t.Fatal("expected at least one finding") + } + for _, r := range results { + if strings.Contains(r.Query, "leak@example.com") { + t.Errorf("literal leaked into Result.Query: %q (rule=%s)", r.Query, r.RuleName) + } + if r.Fingerprint == "" { + t.Errorf("Fingerprint must always be populated, got empty for rule %s", r.RuleName) + } + } +} + +// TestPlugin_SlowQueryReportedOnSuccess uses a zero threshold so any +// successful query trips the slow-query path. Threshold arithmetic is +// covered by middleware.Guard's own tests — here we only assert that the +// integration's after-callback drives CheckLatency on success. +func TestPlugin_SlowQueryReportedOnSuccess(t *testing.T) { + db, cap, _ := newDBWithCapture(t, middleware.WithSlowQueryThreshold(0)) + var u user + if err := db.First(&u, 1).Error; err != nil { + t.Fatalf("First: %v", err) + } + if !cap.has("slow-query") { + t.Fatalf("expected slow-query finding with zero threshold, got %+v", cap.snapshot()) + } +} + +func TestPlugin_SlowQuerySuppressedOnError(t *testing.T) { + db, cap, _ := newDBWithCapture(t, middleware.WithSlowQueryThreshold(0)) + // Force a SQL error: SELECT from a missing table via Raw so we hit the + // after-callback with db.Error != nil. + var dst int + err := db.Raw("SELECT id FROM no_such_table_xyz WHERE id = 1").Scan(&dst).Error + if err == nil { + t.Fatal("expected error from selecting a missing table") + } + if cap.has("slow-query") { + t.Fatalf("slow-query must not fire when the query failed; got %+v", cap.snapshot()) + } +} + +func TestPlugin_NPlusOneAcrossCalls(t *testing.T) { + db, cap, _ := newDBWithCapture(t, middleware.WithN1Detection(3, time.Second)) + var u user + for range 3 { + if err := db.Raw("SELECT id FROM users WHERE id = 1").Scan(&u).Error; err != nil { + t.Fatalf("Raw: %v", err) + } + } + if !cap.has("n-plus-one") { + t.Fatalf("expected n-plus-one finding after 3 identical queries, got %+v", cap.snapshot()) + } +} + +func TestPlugin_ResetN1ClearsState(t *testing.T) { + db, cap, plugin := newDBWithCapture(t, middleware.WithN1Detection(3, time.Second)) + var u user + for range 2 { + if err := db.Raw("SELECT id FROM users WHERE id = 1").Scan(&u).Error; err != nil { + t.Fatalf("Raw: %v", err) + } + } + plugin.ResetN1() + if err := db.Raw("SELECT id FROM users WHERE id = 1").Scan(&u).Error; err != nil { + t.Fatalf("Raw: %v", err) + } + if cap.has("n-plus-one") { + t.Fatalf("n-plus-one should not fire — ResetN1 zeroed the counter; got %+v", cap.snapshot()) + } +} + +// Proves the UPDATE / DELETE callbacks also flow through Guard. +func TestPlugin_UpdateAndDeleteCallbacksAnalyzed(t *testing.T) { + db, cap, _ := newDBWithCapture(t) + if err := db.WithContext(context.Background()).Exec("UPDATE users SET email = 'x'").Error; err != nil { + t.Fatalf("UPDATE: %v", err) + } + if !cap.has("update-without-where") { + t.Fatalf("expected update-without-where from update callback, got %+v", cap.snapshot()) + } + + if err := db.Exec("DELETE FROM users").Error; err != nil { + t.Fatalf("DELETE: %v", err) + } + if !cap.has("delete-without-where") { + t.Fatalf("expected delete-without-where from delete callback, got %+v", cap.snapshot()) + } +} + +func TestRegister_ReturnsNoError(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("gorm.Open: %v", err) + } + if err := Register(db); err != nil { + t.Fatalf("Register: %v", err) + } +} diff --git a/integrations/pgxguard/go.mod b/integrations/pgxguard/go.mod new file mode 100644 index 0000000..136572a --- /dev/null +++ b/integrations/pgxguard/go.mod @@ -0,0 +1,19 @@ +module github.com/KARTIKrocks/sqlguard/integrations/pgxguard + +go 1.26 + +require ( + github.com/KARTIKrocks/sqlguard v0.0.0 + github.com/jackc/pgx/v5 v5.7.6 +) + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + golang.org/x/crypto v0.37.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/text v0.27.0 // indirect +) + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/integrations/pgxguard/go.sum b/integrations/pgxguard/go.sum new file mode 100644 index 0000000..f06d480 --- /dev/null +++ b/integrations/pgxguard/go.sum @@ -0,0 +1,30 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= +github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk= +github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/integrations/pgxguard/pgxguard.go b/integrations/pgxguard/pgxguard.go new file mode 100644 index 0000000..426110a --- /dev/null +++ b/integrations/pgxguard/pgxguard.go @@ -0,0 +1,145 @@ +// Package pgxguard integrates sqlguard with pgx/v5 — the native, dominant +// PostgreSQL driver for Go (pgx/pgxpool, not the database/sql shim). +// +// It hooks pgx's own tracer seam (pgx.QueryTracer + pgx.BatchTracer), which +// is the idiomatic extension point every pgx ecosystem tool uses, so every +// Query/QueryRow/Exec and every SendBatch is analyzed without a method list +// or a wrapper type. +// +// Composability is a first-class concern: pgx allows exactly one Tracer per +// config, and production services usually already set one (otelpgx). Apply +// and ApplyPool therefore *compose* with any existing tracer via pgx's own +// multitracer rather than overwriting it. +// +// Usage with a pool: +// +// cfg, _ := pgxpool.ParseConfig(dsn) +// pgxguard.ApplyPool(cfg) // composes with cfg.ConnConfig.Tracer if set +// pool, _ := pgxpool.NewWithConfig(ctx, cfg) +// +// Usage with a single connection: +// +// cfg, _ := pgx.ParseConfig(dsn) +// pgxguard.Apply(cfg) +// conn, _ := pgx.ConnectConfig(ctx, cfg) +// +// Analysis is driven by the single sqlguard core (middleware.Guard), so +// redaction-by-default, stable fingerprints, the pluggable real-grammar +// parser, slow-query and N+1 detection all behave identically to the +// database/sql driver wrapper. Configure with the standard middleware +// options: +// +// pgxguard.NewTracer( +// middleware.WithSlowQueryThreshold(50*time.Millisecond), +// middleware.WithN1Detection(10, time.Second), +// ) +package pgxguard + +import ( + "context" + + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/multitracer" + "github.com/jackc/pgx/v5/pgxpool" +) + +// Tracer implements pgx.QueryTracer and pgx.BatchTracer, driving every +// traced statement through the shared sqlguard analysis core. +// +// It deliberately does not implement pgx.PrepareTracer: prepared statements +// are still analyzed when executed (execution routes through QueryTracer), +// so tracing Prepare as well would double-report findings and inflate N+1 +// counts. CopyFrom carries no SQL and is out of scope by nature. +type Tracer struct { + g *middleware.Guard +} + +// Compile-time proof we satisfy the pgx tracer interfaces we claim. +var ( + _ pgx.QueryTracer = (*Tracer)(nil) + _ pgx.BatchTracer = (*Tracer)(nil) +) + +// NewTracer builds a Tracer. It accepts the standard sqlguard middleware +// options (WithAnalyzer, WithReporter, WithSlowQueryThreshold, WithParser, +// WithN1Detection, …) — the same option set the database/sql driver wrapper +// uses, so there is no parallel configuration surface to drift. +func NewTracer(opts ...middleware.Option) *Tracer { + return &Tracer{g: middleware.NewGuard(opts...)} +} + +// ResetN1 clears N+1 tracker state. Call it at a per-request boundary +// (e.g. end of an HTTP handler) to scope N+1 detection to one unit of work. +// No-op unless WithN1Detection was passed to NewTracer. +func (t *Tracer) ResetN1() { t.g.ResetN1() } + +// ctxKey is unexported so the stashed latency closure can't collide with +// any other package's context values. +type ctxKey struct{} + +// TraceQueryStart runs static analysis + N+1 tracking and starts the latency +// timer, stashing the end closure in the returned context. +func (t *Tracer) TraceQueryStart(ctx context.Context, _ *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + done := t.g.Observe(data.SQL) + return context.WithValue(ctx, ctxKey{}, done) +} + +// TraceQueryEnd closes the latency window. Latency is recorded only on +// success (Guard.Observe drops it when data.Err != nil — a failed query's +// duration is meaningless). +func (t *Tracer) TraceQueryEnd(ctx context.Context, _ *pgx.Conn, data pgx.TraceQueryEndData) { + if done, ok := ctx.Value(ctxKey{}).(func(error)); ok { + done(data.Err) + } +} + +// TraceBatchStart is a no-op: the batch's SQL is only known per-query, in +// TraceBatchQuery. +func (t *Tracer) TraceBatchStart(ctx context.Context, _ *pgx.Conn, _ pgx.TraceBatchStartData) context.Context { + return ctx +} + +// TraceBatchQuery analyzes each statement in a batch (static rules + N+1). +// Per-statement latency is not exposed by pgx's batch tracer — only the +// whole-batch round trip — so slow-query timing is intentionally not +// reported here rather than reported wrongly. +func (t *Tracer) TraceBatchQuery(_ context.Context, _ *pgx.Conn, data pgx.TraceBatchQueryData) { + t.g.Check(data.SQL) +} + +// TraceBatchEnd is a no-op (per-statement analysis happens in TraceBatchQuery). +func (t *Tracer) TraceBatchEnd(_ context.Context, _ *pgx.Conn, _ pgx.TraceBatchEndData) {} + +// Apply installs a sqlguard Tracer on a *pgx.ConnConfig, composing with any +// tracer already configured (via pgx's multitracer) instead of overwriting +// it — so it coexists with otelpgx and friends. opts are the standard +// middleware options. Returns the same cfg for chaining. +func Apply(cfg *pgx.ConnConfig, opts ...middleware.Option) *pgx.ConnConfig { + if cfg == nil { + panic("pgxguard: Apply called with nil *pgx.ConnConfig") + } + cfg.Tracer = compose(cfg.Tracer, NewTracer(opts...)) + return cfg +} + +// ApplyPool installs a sqlguard Tracer on a *pgxpool.Config (delegating to +// Apply on the embedded ConnConfig), composing with any existing tracer. +// Returns the same cfg for chaining. +func ApplyPool(cfg *pgxpool.Config, opts ...middleware.Option) *pgxpool.Config { + if cfg == nil { + panic("pgxguard: ApplyPool called with nil *pgxpool.Config") + } + Apply(cfg.ConnConfig, opts...) + return cfg +} + +// compose merges an existing tracer with ours. multitracer.New fans each +// call out to every wrapped tracer and routes by interface type-assertion, +// so the existing tracer keeps receiving exactly the events it did before. +func compose(existing pgx.QueryTracer, ours pgx.QueryTracer) pgx.QueryTracer { + if existing == nil { + return ours + } + return multitracer.New(existing, ours) +} diff --git a/integrations/pgxguard/pgxguard_test.go b/integrations/pgxguard/pgxguard_test.go new file mode 100644 index 0000000..6a2120f --- /dev/null +++ b/integrations/pgxguard/pgxguard_test.go @@ -0,0 +1,247 @@ +package pgxguard + +import ( + "context" + "errors" + "strings" + "sync" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/multitracer" + "github.com/jackc/pgx/v5/pgxpool" +) + +// capture is a thread-safe in-memory Reporter for assertions. +type capture struct { + mu sync.Mutex + r []analyzer.Result +} + +func (c *capture) Report(rs []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + c.r = append(c.r, rs...) +} + +func (c *capture) snapshot() []analyzer.Result { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]analyzer.Result, len(c.r)) + copy(out, c.r) + return out +} + +func (c *capture) has(rule string) bool { + for _, r := range c.snapshot() { + if r.RuleName == rule { + return true + } + } + return false +} + +// stubTracer is a fake existing pgx.QueryTracer used to prove Apply composes +// instead of clobbering. +type stubTracer struct { + mu sync.Mutex + starts int + ends int +} + +func (s *stubTracer) TraceQueryStart(ctx context.Context, _ *pgx.Conn, _ pgx.TraceQueryStartData) context.Context { + s.mu.Lock() + s.starts++ + s.mu.Unlock() + return ctx +} + +func (s *stubTracer) TraceQueryEnd(_ context.Context, _ *pgx.Conn, _ pgx.TraceQueryEndData) { + s.mu.Lock() + s.ends++ + s.mu.Unlock() +} + +func newTracerWithCapture(t *testing.T, opts ...middleware.Option) (*Tracer, *capture) { + t.Helper() + cap := &capture{} + opts = append([]middleware.Option{middleware.WithReporter(cap)}, opts...) + return NewTracer(opts...), cap +} + +// driveQuery runs a full Start→End round trip with no error. +func driveQuery(tr *Tracer, sql string, err error) { + ctx := tr.TraceQueryStart(context.Background(), nil, pgx.TraceQueryStartData{SQL: sql}) + tr.TraceQueryEnd(ctx, nil, pgx.TraceQueryEndData{Err: err}) +} + +func TestTracer_DetectsSelectStarOnQueryStart(t *testing.T) { + tr, cap := newTracerWithCapture(t) + driveQuery(tr, "SELECT * FROM users", nil) + if !cap.has("select-star") { + t.Fatalf("expected select-star finding, got %+v", cap.snapshot()) + } +} + +func TestTracer_RedactsLiteralsByDefault(t *testing.T) { + tr, cap := newTracerWithCapture(t) + driveQuery(tr, "SELECT * FROM users WHERE email = 'leak@example.com'", nil) + results := cap.snapshot() + if len(results) == 0 { + t.Fatal("expected at least one finding") + } + for _, r := range results { + if strings.Contains(r.Query, "leak@example.com") { + t.Errorf("literal leaked into Result.Query: %q", r.Query) + } + if r.Fingerprint == "" { + t.Errorf("Fingerprint must always be populated, got empty for rule %s", r.RuleName) + } + } +} + +func TestTracer_SlowQueryReportedOnEnd(t *testing.T) { + tr, cap := newTracerWithCapture(t, middleware.WithSlowQueryThreshold(1*time.Millisecond)) + ctx := tr.TraceQueryStart(context.Background(), nil, pgx.TraceQueryStartData{SQL: "SELECT id FROM users WHERE id = 1"}) + time.Sleep(5 * time.Millisecond) + tr.TraceQueryEnd(ctx, nil, pgx.TraceQueryEndData{Err: nil}) + if !cap.has("slow-query") { + t.Fatalf("expected slow-query finding, got %+v", cap.snapshot()) + } +} + +func TestTracer_SlowQuerySuppressedOnError(t *testing.T) { + tr, cap := newTracerWithCapture(t, middleware.WithSlowQueryThreshold(1*time.Millisecond)) + ctx := tr.TraceQueryStart(context.Background(), nil, pgx.TraceQueryStartData{SQL: "SELECT id FROM users WHERE id = 1"}) + time.Sleep(5 * time.Millisecond) + tr.TraceQueryEnd(ctx, nil, pgx.TraceQueryEndData{Err: errors.New("boom")}) + if cap.has("slow-query") { + t.Fatalf("slow-query should not fire when the query failed; got %+v", cap.snapshot()) + } +} + +func TestTracer_NPlusOneAcrossCalls(t *testing.T) { + tr, cap := newTracerWithCapture(t, middleware.WithN1Detection(3, time.Second)) + for range 3 { + driveQuery(tr, "SELECT id FROM users WHERE id = 1", nil) + } + if !cap.has("n-plus-one") { + t.Fatalf("expected n-plus-one finding after 3 identical queries, got %+v", cap.snapshot()) + } +} + +func TestTracer_ResetN1ClearsState(t *testing.T) { + tr, cap := newTracerWithCapture(t, middleware.WithN1Detection(3, time.Second)) + for range 2 { + driveQuery(tr, "SELECT id FROM users WHERE id = 1", nil) + } + tr.ResetN1() + driveQuery(tr, "SELECT id FROM users WHERE id = 1", nil) + if cap.has("n-plus-one") { + t.Fatalf("n-plus-one should not fire — reset zeroed the counter; got %+v", cap.snapshot()) + } +} + +func TestTracer_BatchQueryAnalyzed(t *testing.T) { + tr, cap := newTracerWithCapture(t) + ctx := tr.TraceBatchStart(context.Background(), nil, pgx.TraceBatchStartData{}) + tr.TraceBatchQuery(ctx, nil, pgx.TraceBatchQueryData{SQL: "SELECT * FROM users"}) + tr.TraceBatchEnd(ctx, nil, pgx.TraceBatchEndData{}) + if !cap.has("select-star") { + t.Fatalf("expected select-star finding from batch path, got %+v", cap.snapshot()) + } +} + +func TestApply_NilExistingSetsOursDirectly(t *testing.T) { + cfg, err := pgx.ParseConfig("postgres://u:p@localhost:5432/db") + if err != nil { + t.Fatalf("ParseConfig: %v", err) + } + if cfg.Tracer != nil { + t.Fatalf("baseline assumption broken: ParseConfig set a tracer (%T)", cfg.Tracer) + } + Apply(cfg) + if _, ok := cfg.Tracer.(*Tracer); !ok { + t.Fatalf("expected *pgxguard.Tracer, got %T", cfg.Tracer) + } +} + +// TestApply_ComposesWithExistingTracer is the headline community-fitness +// guarantee: if the user has already wired e.g. otelpgx, Apply must NOT +// silently overwrite it. +func TestApply_ComposesWithExistingTracer(t *testing.T) { + cfg, err := pgx.ParseConfig("postgres://u:p@localhost:5432/db") + if err != nil { + t.Fatalf("ParseConfig: %v", err) + } + stub := &stubTracer{} + cfg.Tracer = stub + + Apply(cfg) + + mt, ok := cfg.Tracer.(*multitracer.Tracer) + if !ok { + t.Fatalf("expected *multitracer.Tracer after composition, got %T", cfg.Tracer) + } + + var sawStub, sawOurs bool + for _, qt := range mt.QueryTracers { + switch qt.(type) { + case *stubTracer: + sawStub = true + case *Tracer: + sawOurs = true + } + } + if !sawStub { + t.Error("existing tracer was dropped by Apply — community-fitness contract violated") + } + if !sawOurs { + t.Error("our tracer was not installed by Apply") + } + + // And drive it: the existing stub must still receive Start/End events. + ctx := cfg.Tracer.TraceQueryStart(context.Background(), nil, pgx.TraceQueryStartData{SQL: "SELECT 1"}) + cfg.Tracer.TraceQueryEnd(ctx, nil, pgx.TraceQueryEndData{}) + stub.mu.Lock() + defer stub.mu.Unlock() + if stub.starts != 1 || stub.ends != 1 { + t.Errorf("existing tracer not driven through composition: starts=%d ends=%d", stub.starts, stub.ends) + } +} + +func TestApplyPool_DelegatesAndComposes(t *testing.T) { + cfg, err := pgxpool.ParseConfig("postgres://u:p@localhost:5432/db") + if err != nil { + t.Fatalf("pgxpool.ParseConfig: %v", err) + } + stub := &stubTracer{} + cfg.ConnConfig.Tracer = stub + + ApplyPool(cfg) + + if _, ok := cfg.ConnConfig.Tracer.(*multitracer.Tracer); !ok { + t.Fatalf("ApplyPool did not compose: got %T", cfg.ConnConfig.Tracer) + } +} + +func TestApply_NilConfigPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Fatal("expected panic on nil *pgx.ConnConfig") + } + }() + Apply(nil) +} + +func TestApplyPool_NilConfigPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Fatal("expected panic on nil *pgxpool.Config") + } + }() + ApplyPool(nil) +} diff --git a/integrations/sqlxguard/go.mod b/integrations/sqlxguard/go.mod new file mode 100644 index 0000000..8fb1885 --- /dev/null +++ b/integrations/sqlxguard/go.mod @@ -0,0 +1,12 @@ +module github.com/KARTIKrocks/sqlguard/integrations/sqlxguard + +go 1.26 + +require ( + github.com/KARTIKrocks/sqlguard v0.0.0 + github.com/jmoiron/sqlx v1.4.0 +) + +require github.com/mattn/go-sqlite3 v1.14.45 + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/integrations/sqlxguard/go.sum b/integrations/sqlxguard/go.sum new file mode 100644 index 0000000..af18b98 --- /dev/null +++ b/integrations/sqlxguard/go.sum @@ -0,0 +1,11 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= +github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk= +github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= diff --git a/integrations/sqlxguard/sqlxguard.go b/integrations/sqlxguard/sqlxguard.go new file mode 100644 index 0000000..6ddcdbb --- /dev/null +++ b/integrations/sqlxguard/sqlxguard.go @@ -0,0 +1,157 @@ +// Package sqlxguard integrates sqlguard with sqlx. +// +// Every wrapped method routes through the shared sqlguard analysis core +// (middleware.Guard), so redaction-by-default, stable fingerprints, the +// pluggable real-grammar parser, slow-query timing and N+1 detection behave +// identically to the database/sql driver wrapper and to pgxguard. There is +// no parallel option surface — configure with the standard middleware +// options: +// +// db := sqlxguard.WrapSqlx(sqlxDB, +// middleware.WithSlowQueryThreshold(50*time.Millisecond), +// middleware.WithN1Detection(10, time.Second), +// ) +// +// Coverage note: WrappedDB exposes the sqlx-specific extension methods +// (Select/Get/Queryx/NamedExec and their *Context variants, plus Query/Exec +// passthrough). For full surface coverage — including QueryRow*, NamedQuery, +// MustExec and the transaction helpers — layer sqlx on top of the sqlguard +// driver chain instead: +// +// sqlguard.Register("sqlguard-pgx", pq.Driver{}, opts...) +// sqlDB, _ := sql.Open("sqlguard-pgx", dsn) +// db := sqlx.NewDb(sqlDB, "postgres") +// +// That path covers every sqlx method automatically because interception +// happens at the database/sql driver layer. +package sqlxguard + +import ( + "context" + "database/sql" + + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/jmoiron/sqlx" +) + +// WrappedDB wraps a *sqlx.DB with sqlguard analysis. Every analysis-bearing +// method drives the shared middleware.Guard, so behavior matches pgxguard +// and the database/sql driver chain exactly. +type WrappedDB struct { + db *sqlx.DB + g *middleware.Guard +} + +// WrapSqlx creates a new WrappedDB around the given sqlx connection. +// It accepts the standard sqlguard middleware options (WithAnalyzer, +// WithReporter, WithSlowQueryThreshold, WithParser, WithN1Detection, …) — +// the same option set the database/sql driver wrapper and pgxguard use, so +// there is no parallel configuration surface to drift. +func WrapSqlx(db *sqlx.DB, opts ...middleware.Option) *WrappedDB { + if db == nil { + panic("sqlxguard: WrapSqlx called with nil *sqlx.DB") + } + return &WrappedDB{db: db, g: middleware.NewGuard(opts...)} +} + +// DB returns the underlying *sqlx.DB. +func (w *WrappedDB) DB() *sqlx.DB { return w.db } + +// ResetN1 clears N+1 tracker state. Call it at a per-request boundary +// (e.g. end of an HTTP handler) to scope N+1 detection to one unit of work. +// No-op unless WithN1Detection was passed to WrapSqlx. +func (w *WrappedDB) ResetN1() { w.g.ResetN1() } + +// Select executes a query and scans the results into dest. +func (w *WrappedDB) Select(dest any, query string, args ...any) error { + done := w.g.Observe(query) + err := w.db.Select(dest, query, args...) + done(err) + return err +} + +// SelectContext executes a query with context and scans the results into dest. +func (w *WrappedDB) SelectContext(ctx context.Context, dest any, query string, args ...any) error { + done := w.g.Observe(query) + err := w.db.SelectContext(ctx, dest, query, args...) + done(err) + return err +} + +// Get executes a query and scans a single row into dest. +func (w *WrappedDB) Get(dest any, query string, args ...any) error { + done := w.g.Observe(query) + err := w.db.Get(dest, query, args...) + done(err) + return err +} + +// GetContext executes a query with context and scans a single row into dest. +func (w *WrappedDB) GetContext(ctx context.Context, dest any, query string, args ...any) error { + done := w.g.Observe(query) + err := w.db.GetContext(ctx, dest, query, args...) + done(err) + return err +} + +// Query executes a query that returns rows. +func (w *WrappedDB) Query(query string, args ...any) (*sql.Rows, error) { + done := w.g.Observe(query) + rows, err := w.db.Query(query, args...) + done(err) + return rows, err +} + +// QueryContext executes a query with context that returns rows. +func (w *WrappedDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + done := w.g.Observe(query) + rows, err := w.db.QueryContext(ctx, query, args...) + done(err) + return rows, err +} + +// Queryx executes a query that returns sqlx.Rows. +func (w *WrappedDB) Queryx(query string, args ...any) (*sqlx.Rows, error) { + done := w.g.Observe(query) + rows, err := w.db.Queryx(query, args...) + done(err) + return rows, err +} + +// Exec executes a query without returning rows. +func (w *WrappedDB) Exec(query string, args ...any) (sql.Result, error) { + done := w.g.Observe(query) + result, err := w.db.Exec(query, args...) + done(err) + return result, err +} + +// ExecContext executes a query with context without returning rows. +func (w *WrappedDB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + done := w.g.Observe(query) + result, err := w.db.ExecContext(ctx, query, args...) + done(err) + return result, err +} + +// NamedExec executes a named query. +func (w *WrappedDB) NamedExec(query string, arg any) (sql.Result, error) { + done := w.g.Observe(query) + result, err := w.db.NamedExec(query, arg) + done(err) + return result, err +} + +// NamedExecContext executes a named query with context. +func (w *WrappedDB) NamedExecContext(ctx context.Context, query string, arg any) (sql.Result, error) { + done := w.g.Observe(query) + result, err := w.db.NamedExecContext(ctx, query, arg) + done(err) + return result, err +} + +// Ping verifies the database connection. +func (w *WrappedDB) Ping() error { return w.db.Ping() } + +// Close closes the database connection. +func (w *WrappedDB) Close() error { return w.db.Close() } diff --git a/integrations/sqlxguard/sqlxguard_test.go b/integrations/sqlxguard/sqlxguard_test.go new file mode 100644 index 0000000..8f692ce --- /dev/null +++ b/integrations/sqlxguard/sqlxguard_test.go @@ -0,0 +1,187 @@ +package sqlxguard + +import ( + "context" + "strings" + "sync" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/middleware" + "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" +) + +// capture is a thread-safe in-memory Reporter for assertions. +type capture struct { + mu sync.Mutex + r []analyzer.Result +} + +func (c *capture) Report(rs []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + c.r = append(c.r, rs...) +} + +func (c *capture) snapshot() []analyzer.Result { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]analyzer.Result, len(c.r)) + copy(out, c.r) + return out +} + +func (c *capture) has(rule string) bool { + for _, r := range c.snapshot() { + if r.RuleName == rule { + return true + } + } + return false +} + +// newWrappedWithCapture spins up an in-memory sqlite-backed *sqlx.DB so the +// integration is exercised end-to-end (sqlx extension method → database/sql → +// real driver round trip) rather than mocked. +func newWrappedWithCapture(t *testing.T, opts ...middleware.Option) (*WrappedDB, *capture) { + t.Helper() + sqlxDB, err := sqlx.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("sqlx.Open: %v", err) + } + t.Cleanup(func() { _ = sqlxDB.Close() }) + if _, err := sqlxDB.Exec(`CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT)`); err != nil { + t.Fatalf("create table: %v", err) + } + if _, err := sqlxDB.Exec(`INSERT INTO users (id, email) VALUES (1, 'leak@example.com')`); err != nil { + t.Fatalf("seed: %v", err) + } + cap := &capture{} + opts = append([]middleware.Option{middleware.WithReporter(cap)}, opts...) + return WrapSqlx(sqlxDB, opts...), cap +} + +type user struct { + ID int64 `db:"id"` + Email string `db:"email"` +} + +func TestWrappedDB_DetectsSelectStar(t *testing.T) { + w, cap := newWrappedWithCapture(t) + var us []user + if err := w.Select(&us, "SELECT * FROM users"); err != nil { + t.Fatalf("Select: %v", err) + } + if !cap.has("select-star") { + t.Fatalf("expected select-star finding, got %+v", cap.snapshot()) + } +} + +// TestWrappedDB_RedactsLiteralsByDefault is the headline 11.1 regression: +// the old hand-rolled check() set Result.Query to the raw SQL, so single- +// quoted literals leaked into log sinks. After the Guard rewrite Query +// must be the redacted form and Fingerprint must always be populated. +func TestWrappedDB_RedactsLiteralsByDefault(t *testing.T) { + w, cap := newWrappedWithCapture(t) + var us []user + if err := w.Select(&us, "SELECT * FROM users WHERE email = 'leak@example.com'"); err != nil { + t.Fatalf("Select: %v", err) + } + results := cap.snapshot() + if len(results) == 0 { + t.Fatal("expected at least one finding") + } + for _, r := range results { + if strings.Contains(r.Query, "leak@example.com") { + t.Errorf("literal leaked into Result.Query: %q (rule=%s)", r.Query, r.RuleName) + } + if r.Fingerprint == "" { + t.Errorf("Fingerprint must always be populated, got empty for rule %s", r.RuleName) + } + } +} + +// TestWrappedDB_SlowQueryReportedOnSuccess uses a zero threshold so any +// successful round trip trips the slow-query path. The integration-level +// claim under test is "slow-query check runs on success", not the threshold +// arithmetic itself (that lives in middleware.Guard's own tests). +func TestWrappedDB_SlowQueryReportedOnSuccess(t *testing.T) { + w, cap := newWrappedWithCapture(t, middleware.WithSlowQueryThreshold(0)) + var u user + if err := w.Get(&u, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Get: %v", err) + } + if !cap.has("slow-query") { + t.Fatalf("expected slow-query finding with zero threshold, got %+v", cap.snapshot()) + } +} + +func TestWrappedDB_SlowQuerySuppressedOnError(t *testing.T) { + w, cap := newWrappedWithCapture(t, middleware.WithSlowQueryThreshold(0)) + var u user + if err := w.Get(&u, "SELECT id FROM no_such_table_xyz WHERE id = 1"); err == nil { + t.Fatal("expected error from selecting a missing table") + } + if cap.has("slow-query") { + t.Fatalf("slow-query must not fire when the query failed; got %+v", cap.snapshot()) + } +} + +func TestWrappedDB_NPlusOneAcrossCalls(t *testing.T) { + w, cap := newWrappedWithCapture(t, middleware.WithN1Detection(3, time.Second)) + var u user + for range 3 { + if err := w.Get(&u, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Get: %v", err) + } + } + if !cap.has("n-plus-one") { + t.Fatalf("expected n-plus-one finding after 3 identical queries, got %+v", cap.snapshot()) + } +} + +func TestWrappedDB_ResetN1ClearsState(t *testing.T) { + w, cap := newWrappedWithCapture(t, middleware.WithN1Detection(3, time.Second)) + var u user + for range 2 { + if err := w.Get(&u, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Get: %v", err) + } + } + w.ResetN1() + if err := w.Get(&u, "SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("Get: %v", err) + } + if cap.has("n-plus-one") { + t.Fatalf("n-plus-one should not fire — ResetN1 zeroed the counter; got %+v", cap.snapshot()) + } +} + +// Proves the non-SELECT and *Context paths also flow through Guard. +func TestWrappedDB_ExecAndContextVariantsAnalyzed(t *testing.T) { + w, cap := newWrappedWithCapture(t) + if _, err := w.Exec("DELETE FROM users"); err != nil { + t.Fatalf("Exec: %v", err) + } + if !cap.has("delete-without-where") { + t.Fatalf("expected delete-without-where from Exec path, got %+v", cap.snapshot()) + } + + if _, err := w.ExecContext(context.Background(), "UPDATE users SET email = 'x'"); err != nil { + t.Fatalf("ExecContext: %v", err) + } + if !cap.has("update-without-where") { + t.Fatalf("expected update-without-where from ExecContext path, got %+v", cap.snapshot()) + } +} + +func TestWrapSqlx_NilPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Fatal("expected panic on nil *sqlx.DB") + } + }() + WrapSqlx(nil) +} diff --git a/integrations/xormguard/go.mod b/integrations/xormguard/go.mod new file mode 100644 index 0000000..8cfd74c --- /dev/null +++ b/integrations/xormguard/go.mod @@ -0,0 +1,18 @@ +module github.com/KARTIKrocks/sqlguard/integrations/xormguard + +go 1.26 + +require ( + github.com/KARTIKrocks/sqlguard v0.0.0 + github.com/mattn/go-sqlite3 v1.14.45 + xorm.io/xorm v1.3.11 +) + +require ( + github.com/goccy/go-json v0.10.5 // indirect + github.com/golang/snappy v0.0.4 // indirect + github.com/syndtr/goleveldb v1.0.0 // indirect + xorm.io/builder v0.3.13 // indirect +) + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/integrations/xormguard/go.sum b/integrations/xormguard/go.sum new file mode 100644 index 0000000..eb82c3b --- /dev/null +++ b/integrations/xormguard/go.sum @@ -0,0 +1,94 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s= +gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.45 h1:6KA/spDguL3KV8rnybG7ezSaE4SeMR3KC9VbUoAQaIk= +github.com/mattn/go-sqlite3 v1.14.45/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU= +github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= +github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= +golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4= +golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8= +golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= +lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw= +modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0= +modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw= +modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY= +modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U= +modernc.org/libc v1.55.3/go.mod h1:qFXepLhz+JjFThQ4kzwzOjA/y/artDeg+pcYnY+Q83w= +modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= +modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= +modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E= +modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU= +modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sqlite v1.20.4 h1:J8+m2trkN+KKoE7jglyHYYYiaq5xmz2HoHJIiBlRzbE= +modernc.org/sqlite v1.20.4/go.mod h1:zKcGyrICaxNTMEHSr1HQ2GUraP0j+845GYw37+EyT6A= +modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= +modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +xorm.io/builder v0.3.13 h1:a3jmiVVL19psGeXx8GIurTp7p0IIgqeDmwhcR6BAOAo= +xorm.io/builder v0.3.13/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= +xorm.io/xorm v1.3.11 h1:i4tlVUASogb0ZZFJHA7dZqoRU2pUpUsutnNdaOlFyMI= +xorm.io/xorm v1.3.11/go.mod h1:cs0ePc8O4a0jD78cNvD+0VFwhqotTvLQZv372QsDw7Q= diff --git a/integrations/xormguard/xormguard.go b/integrations/xormguard/xormguard.go new file mode 100644 index 0000000..047bdbc --- /dev/null +++ b/integrations/xormguard/xormguard.go @@ -0,0 +1,74 @@ +// Package xormguard integrates sqlguard with xorm (xorm.io/xorm). +// +// Analysis is driven by the single shared sqlguard core (middleware.Guard), +// so redaction-by-default, stable fingerprints, the pluggable real-grammar +// parser, slow-query timing and N+1 detection behave identically to the +// database/sql driver wrapper, pgxguard, gormguard and bunguard. There is no +// parallel option surface — configure with the standard middleware options: +// +// engine, _ := xorm.NewEngine("postgres", dsn) +// engine.AddHook(xormguard.New( +// middleware.WithSlowQueryThreshold(500*time.Millisecond), +// middleware.WithN1Detection(10, time.Second), +// )) +// +// xorm's contexts.Hook exposes the rendered SQL and the measured execution +// time on the ContextHook in AfterProcess, so this uses the explicit +// Check+CheckLatency pair (matching gormguard): static rules run on every +// query, latency is reported only on success. +package xormguard + +import ( + "context" + + "github.com/KARTIKrocks/sqlguard/middleware" + "xorm.io/xorm/contexts" +) + +// Hook implements xorm's contexts.Hook and drives every traced statement +// through the shared sqlguard analysis core. +type Hook struct { + g *middleware.Guard +} + +// Compile-time proof we satisfy contexts.Hook. +var _ contexts.Hook = (*Hook)(nil) + +// New creates a new sqlguard xorm hook. It accepts the standard sqlguard +// middleware options (WithAnalyzer, WithReporter, WithSlowQueryThreshold, +// WithParser, WithN1Detection, …) — the same option set every other sqlguard +// surface uses, so there is no parallel configuration surface to drift. +func New(opts ...middleware.Option) *Hook { + return &Hook{g: middleware.NewGuard(opts...)} +} + +// ResetN1 clears N+1 tracker state. Call it at a per-request boundary +// (e.g. end of an HTTP handler) to scope N+1 detection to one unit of work. +// No-op unless WithN1Detection was passed to New. +func (h *Hook) ResetN1() { h.g.ResetN1() } + +// BeforeProcess implements contexts.Hook. xorm stamps the start time itself +// and reports the elapsed duration as ContextHook.ExecuteTime in +// AfterProcess, so there is nothing to do here but pass the context through. +func (h *Hook) BeforeProcess(c *contexts.ContextHook) (context.Context, error) { + return c.Ctx, nil +} + +// AfterProcess implements contexts.Hook. c.SQL holds the rendered SQL, +// c.ExecuteTime the measured latency, and c.Err the query error (which is +// returned unchanged so the hook never swallows it). +func (h *Hook) AfterProcess(c *contexts.ContextHook) error { + if c.SQL == "" { + return c.Err + } + + // Static rules + N+1 run on every call (matches Observe semantics). + h.g.Check(c.SQL) + + // Latency is reported only on success — a failed query's duration is + // meaningless. This mirrors middleware.Guard.Observe. + if c.Err == nil { + h.g.CheckLatency(c.SQL, c.ExecuteTime) + } + return c.Err +} diff --git a/integrations/xormguard/xormguard_test.go b/integrations/xormguard/xormguard_test.go new file mode 100644 index 0000000..c9b28f6 --- /dev/null +++ b/integrations/xormguard/xormguard_test.go @@ -0,0 +1,166 @@ +package xormguard + +import ( + "strings" + "sync" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/middleware" + _ "github.com/mattn/go-sqlite3" + "xorm.io/xorm" +) + +// capture is a thread-safe in-memory Reporter for assertions. +type capture struct { + mu sync.Mutex + r []analyzer.Result +} + +func (c *capture) Report(rs []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + c.r = append(c.r, rs...) +} + +func (c *capture) snapshot() []analyzer.Result { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]analyzer.Result, len(c.r)) + copy(out, c.r) + return out +} + +func (c *capture) has(rule string) bool { + for _, r := range c.snapshot() { + if r.RuleName == rule { + return true + } + } + return false +} + +// newEngineWithCapture spins up an in-memory sqlite-backed *xorm.Engine with +// the sqlguard hook registered, so the integration runs end-to-end +// (contexts.Hook seam → driver round trip) rather than mocked. The hook is +// added after seeding so the capture starts clean. +func newEngineWithCapture(t *testing.T, opts ...middleware.Option) (*xorm.Engine, *capture, *Hook) { + t.Helper() + engine, err := xorm.NewEngine("sqlite3", ":memory:") + if err != nil { + t.Fatalf("NewEngine: %v", err) + } + t.Cleanup(func() { _ = engine.Close() }) + + if _, err := engine.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT)"); err != nil { + t.Fatalf("create table: %v", err) + } + if _, err := engine.Exec("INSERT INTO users (id, email) VALUES (?, ?)", 1, "leak@example.com"); err != nil { + t.Fatalf("seed: %v", err) + } + + cap := &capture{} + opts = append([]middleware.Option{middleware.WithReporter(cap)}, opts...) + hook := New(opts...) + engine.AddHook(hook) + return engine, cap, hook +} + +func TestHook_DetectsRawSelectStar(t *testing.T) { + engine, cap, _ := newEngineWithCapture(t) + if _, err := engine.QueryString("SELECT * FROM users"); err != nil { + t.Fatalf("QueryString: %v", err) + } + if !cap.has("select-star") { + t.Fatalf("expected select-star finding, got %+v", cap.snapshot()) + } +} + +// TestHook_RedactsLiteralsByDefault asserts the headline redaction guarantee: +// single-quoted literals never reach Result.Query and Fingerprint is always +// populated. +func TestHook_RedactsLiteralsByDefault(t *testing.T) { + engine, cap, _ := newEngineWithCapture(t) + if _, err := engine.QueryString("SELECT * FROM users WHERE email = 'leak@example.com'"); err != nil { + t.Fatalf("QueryString: %v", err) + } + results := cap.snapshot() + if len(results) == 0 { + t.Fatal("expected at least one finding") + } + for _, r := range results { + if strings.Contains(r.Query, "leak@example.com") { + t.Errorf("literal leaked into Result.Query: %q (rule=%s)", r.Query, r.RuleName) + } + if r.Fingerprint == "" { + t.Errorf("Fingerprint must always be populated, got empty for rule %s", r.RuleName) + } + } +} + +func TestHook_SlowQueryReportedOnSuccess(t *testing.T) { + engine, cap, _ := newEngineWithCapture(t, middleware.WithSlowQueryThreshold(0)) + if _, err := engine.QueryString("SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("QueryString: %v", err) + } + if !cap.has("slow-query") { + t.Fatalf("expected slow-query finding with zero threshold, got %+v", cap.snapshot()) + } +} + +func TestHook_SlowQuerySuppressedOnError(t *testing.T) { + engine, cap, _ := newEngineWithCapture(t, middleware.WithSlowQueryThreshold(0)) + _, err := engine.QueryString("SELECT id FROM no_such_table_xyz WHERE id = 1") + if err == nil { + t.Fatal("expected error from selecting a missing table") + } + if cap.has("slow-query") { + t.Fatalf("slow-query must not fire when the query failed; got %+v", cap.snapshot()) + } +} + +func TestHook_NPlusOneAcrossCalls(t *testing.T) { + engine, cap, _ := newEngineWithCapture(t, middleware.WithN1Detection(3, time.Second)) + for range 3 { + if _, err := engine.QueryString("SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("QueryString: %v", err) + } + } + if !cap.has("n-plus-one") { + t.Fatalf("expected n-plus-one finding after 3 identical queries, got %+v", cap.snapshot()) + } +} + +func TestHook_ResetN1ClearsState(t *testing.T) { + engine, cap, hook := newEngineWithCapture(t, middleware.WithN1Detection(3, time.Second)) + for range 2 { + if _, err := engine.QueryString("SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("QueryString: %v", err) + } + } + hook.ResetN1() + if _, err := engine.QueryString("SELECT id FROM users WHERE id = 1"); err != nil { + t.Fatalf("QueryString: %v", err) + } + if cap.has("n-plus-one") { + t.Fatalf("n-plus-one should not fire — ResetN1 zeroed the counter; got %+v", cap.snapshot()) + } +} + +// Proves UPDATE / DELETE statements also flow through Guard. +func TestHook_UpdateAndDeleteAnalyzed(t *testing.T) { + engine, cap, _ := newEngineWithCapture(t) + if _, err := engine.Exec("UPDATE users SET email = 'x'"); err != nil { + t.Fatalf("UPDATE: %v", err) + } + if !cap.has("update-without-where") { + t.Fatalf("expected update-without-where, got %+v", cap.snapshot()) + } + if _, err := engine.Exec("DELETE FROM users"); err != nil { + t.Fatalf("DELETE: %v", err) + } + if !cap.has("delete-without-where") { + t.Fatalf("expected delete-without-where, got %+v", cap.snapshot()) + } +} diff --git a/middleware/cache.go b/middleware/cache.go new file mode 100644 index 0000000..2389f1c --- /dev/null +++ b/middleware/cache.go @@ -0,0 +1,84 @@ +package middleware + +import ( + "container/list" + "sync" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +// analysisCache memoizes analyzer.Analyze results so each distinct query is +// parsed and rule-checked once instead of on every execution. It is a bounded +// LRU keyed on the **exact** query string. +// +// Why the exact string and not the fingerprint: the fingerprint folds away +// literal values, but a few rules read literal-derived facts the fingerprint +// discards — large-offset (OffsetValue), in-list-too-large (MaxInListLen), and +// leading-wildcard min-length (LeadingWildcardTermLen). Two queries can share a +// fingerprint yet warrant different findings, so fingerprint-keying would cache +// a wrong verdict. Identical query strings always analyze identically, which +// makes the exact string the only fully-correct key — and an effective one, +// since parameterized queries (the common case) and repeated identical queries +// hit while varying-literal queries miss (and those need re-analysis anyway). +// +// Cached result slices are shared and must be treated as read-only by callers +// (Guard.report and the reporters do). An analysisCache is safe for concurrent +// use. +type analysisCache struct { + mu sync.Mutex + ll *list.List + items map[string]*list.Element + capacity int +} + +type cacheEntry struct { + key string + results []analyzer.Result +} + +func newAnalysisCache(capacity int) *analysisCache { + return &analysisCache{ + ll: list.New(), + items: make(map[string]*list.Element), + capacity: capacity, + } +} + +// get returns the cached results for query and true, or nil and false on a +// miss. A cached empty/nil slice is a hit (the query was analyzed and produced +// no findings) — exactly the common case worth memoizing. +func (c *analysisCache) get(query string) ([]analyzer.Result, bool) { + c.mu.Lock() + defer c.mu.Unlock() + if el, ok := c.items[query]; ok { + c.ll.MoveToFront(el) + return el.Value.(*cacheEntry).results, true + } + return nil, false +} + +// put stores results for query, evicting the least-recently-used entry when the +// cache exceeds its capacity. +func (c *analysisCache) put(query string, results []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + if el, ok := c.items[query]; ok { + c.ll.MoveToFront(el) + el.Value.(*cacheEntry).results = results + return + } + el := c.ll.PushFront(&cacheEntry{key: query, results: results}) + c.items[query] = el + if c.ll.Len() > c.capacity { + if oldest := c.ll.Back(); oldest != nil { + c.ll.Remove(oldest) + delete(c.items, oldest.Value.(*cacheEntry).key) + } + } +} + +func (c *analysisCache) len() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.ll.Len() +} diff --git a/middleware/cache_test.go b/middleware/cache_test.go new file mode 100644 index 0000000..2ee4634 --- /dev/null +++ b/middleware/cache_test.go @@ -0,0 +1,125 @@ +package middleware + +import ( + "testing" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +func TestAnalysisCache_HitMissAndStore(t *testing.T) { + c := newAnalysisCache(4) + + if _, ok := c.get("q"); ok { + t.Fatal("empty cache should miss") + } + + res := []analyzer.Result{{RuleName: "select-star"}} + c.put("q", res) + + got, ok := c.get("q") + if !ok { + t.Fatal("expected a hit after put") + } + if len(got) != 1 || got[0].RuleName != "select-star" { + t.Errorf("cached results mismatch: %+v", got) + } +} + +func TestAnalysisCache_CachesEmptyResults(t *testing.T) { + c := newAnalysisCache(4) + c.put("clean", nil) // a query that produced no findings is still worth caching + + got, ok := c.get("clean") + if !ok { + t.Fatal("a cached no-findings query must be a hit, not a miss") + } + if len(got) != 0 { + t.Errorf("expected zero findings, got %d", len(got)) + } +} + +func TestAnalysisCache_LRUEviction(t *testing.T) { + c := newAnalysisCache(2) + c.put("a", nil) + c.put("b", nil) + // Touch "a" so "b" becomes least-recently-used. + if _, ok := c.get("a"); !ok { + t.Fatal("a should still be present") + } + c.put("c", nil) // exceeds capacity -> evict LRU ("b") + + if _, ok := c.get("b"); ok { + t.Error("b should have been evicted as least-recently-used") + } + if _, ok := c.get("a"); !ok { + t.Error("a should survive (recently used)") + } + if _, ok := c.get("c"); !ok { + t.Error("c should be present (just added)") + } + if c.len() > 2 { + t.Errorf("cache exceeded capacity: %d", c.len()) + } +} + +// The cache must not change which findings are produced, including for the +// literal-sensitive rules whose verdict the fingerprint would have folded away. +func TestGuard_CacheCorrectForLiteralSensitiveRules(t *testing.T) { + rep := &countingReporter{} + g := NewGuard(WithReporter(rep), WithFindingDedup(0)) // dedup off to count every finding + + // Same fingerprint ("... OFFSET ?"), different OffsetValue: only the first + // crosses the large-offset threshold (default 1000). WHERE + LIMIT keep the + // only finding large-offset. If the cache keyed on fingerprint, the second + // would wrongly inherit the first's finding. + g.Check("SELECT id FROM users WHERE tenant = ? ORDER BY id LIMIT 10 OFFSET 5000") + if rep.count() != 1 { + t.Fatalf("expected exactly large-offset on OFFSET 5000, got %d findings", rep.count()) + } + g.Check("SELECT id FROM users WHERE tenant = ? ORDER BY id LIMIT 10 OFFSET 10") + if rep.count() != 1 { + t.Errorf("OFFSET 10 must not inherit a cached large-offset finding; total findings = %d", rep.count()) + } +} + +func TestGuard_CacheReturnsConsistentFindingsOnRepeat(t *testing.T) { + rep := &countingReporter{} + g := NewGuard(WithReporter(rep), WithFindingDedup(0)) + + for range 5 { + g.Check("DELETE FROM accounts") + } + // Cache must not swallow findings: dedup is off, so all 5 are reported. + if got := rep.count(); got != 5 { + t.Errorf("expected 5 findings across 5 identical calls, got %d", got) + } +} + +func TestGuard_CacheDisabled(t *testing.T) { + g := NewGuard(WithAnalysisCacheSize(0)) + if g.cache != nil { + t.Error("cache size 0 should leave the cache nil (disabled)") + } + // Still functions without a cache. + g.Check("DELETE FROM accounts") +} + +// benchQuery is a clean, parameterized query: representative of the prod-common +// case and produces no findings, so Check reduces to the analyze path. +const benchQuery = "SELECT id, name FROM users WHERE id = ? AND tenant = ?" + +func BenchmarkGuardCheck_Cached(b *testing.B) { + g := NewGuard(WithReporter(&countingReporter{}), WithFindingDedup(0)) + b.ReportAllocs() + for b.Loop() { + g.Check(benchQuery) + } +} + +func BenchmarkGuardCheck_Uncached(b *testing.B) { + g := NewGuard(WithReporter(&countingReporter{}), WithFindingDedup(0), WithAnalysisCacheSize(0)) + b.ReportAllocs() + for b.Loop() { + g.Check(benchQuery) + } +} diff --git a/middleware/dedup.go b/middleware/dedup.go new file mode 100644 index 0000000..f63f275 --- /dev/null +++ b/middleware/dedup.go @@ -0,0 +1,74 @@ +package middleware + +import ( + "sync" + "time" +) + +// deduper suppresses repeat emission of the same finding within a time window. +// A finding's identity is (fingerprint, ruleName): the same rule firing on the +// same canonical query shape. Without it, Guard.Check would re-emit every +// static finding on every execution of a recurring query (or every Exec of a +// prepared statement) and flood the log sink. The N+1 detector already +// self-dedups and slow-query is intentionally per-execution; this covers the +// per-query static rules. It reuses the QueryTracker windowing shape. +// +// A deduper is safe for concurrent use. +type deduper struct { + mu sync.Mutex + seen map[string]time.Time // key -> time the finding was last allowed + window time.Duration + maxKeys int +} + +func newDeduper(window time.Duration) *deduper { + return &deduper{ + seen: make(map[string]time.Time), + window: window, + maxKeys: 10000, + } +} + +// allow reports whether a finding identified by (fingerprint, rule) should be +// emitted at time now. It returns true the first time the finding is seen and +// again only after window has elapsed since it was last allowed. A window <= 0 +// disables dedup, so every call returns true (the legacy report-every-time +// behavior). +func (d *deduper) allow(fingerprint, rule string, now time.Time) bool { + if d.window <= 0 { + return true + } + key := fingerprint + "\x00" + rule + + d.mu.Lock() + defer d.mu.Unlock() + + if len(d.seen) >= d.maxKeys { + d.evictExpired(now) + // If eviction freed nothing (every entry still in-window) and this is a + // new key, drop the finding rather than grow the map without bound. A + // key already present is still updated below — never lose dedup state + // for a finding we're actively tracking. + if len(d.seen) >= d.maxKeys { + if _, ok := d.seen[key]; !ok { + return false + } + } + } + + last, ok := d.seen[key] + if !ok || now.Sub(last) > d.window { + d.seen[key] = now + return true + } + return false +} + +// evictExpired removes entries whose window has elapsed. Caller holds the lock. +func (d *deduper) evictExpired(now time.Time) { + for k, t := range d.seen { + if now.Sub(t) > d.window { + delete(d.seen, k) + } + } +} diff --git a/middleware/dedup_test.go b/middleware/dedup_test.go new file mode 100644 index 0000000..b400563 --- /dev/null +++ b/middleware/dedup_test.go @@ -0,0 +1,163 @@ +package middleware + +import ( + "strings" + "sync" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +// countingReporter records every Result it is handed, concurrency-safe. +type countingReporter struct { + mu sync.Mutex + results []analyzer.Result +} + +func (c *countingReporter) Report(rs []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + c.results = append(c.results, rs...) +} + +func (c *countingReporter) count() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.results) +} + +func TestDeduper_AllowsFirstSuppressesRepeatThenReReportsAfterWindow(t *testing.T) { + d := newDeduper(time.Minute) + now := time.Now() + + if !d.allow("fp", "select-star", now) { + t.Fatal("first occurrence should be allowed") + } + if d.allow("fp", "select-star", now) { + t.Error("repeat within window should be suppressed") + } + if !d.allow("fp", "select-star", now.Add(2*time.Minute)) { + t.Error("occurrence after window elapsed should be allowed again") + } +} + +func TestDeduper_DistinctIdentitiesIndependent(t *testing.T) { + d := newDeduper(time.Minute) + now := time.Now() + + // Different rule, same fingerprint. + if !d.allow("fp", "select-star", now) || !d.allow("fp", "select-without-limit", now) { + t.Error("distinct rules on the same fingerprint should each be allowed once") + } + // Different fingerprint, same rule. + if !d.allow("fp2", "select-star", now) { + t.Error("same rule on a distinct fingerprint should be allowed") + } +} + +func TestDeduper_DisabledWindowAlwaysAllows(t *testing.T) { + d := newDeduper(0) + now := time.Now() + for i := range 5 { + if !d.allow("fp", "select-star", now) { + t.Errorf("window<=0 disables dedup; call %d should be allowed", i) + } + } +} + +func TestDeduper_BoundedAtMaxKeys(t *testing.T) { + // All entries stay in-window, so eviction frees nothing: new keys past the + // cap must be dropped rather than grow the map without bound. + d := &deduper{seen: map[string]time.Time{}, window: time.Hour, maxKeys: 2} + now := time.Now() + + if !d.allow("a", "r", now) || !d.allow("b", "r", now) { + t.Fatal("first two distinct keys should be allowed") + } + if d.allow("c", "r", now) { + t.Error("a new key past maxKeys with nothing to evict should be dropped") + } + if len(d.seen) > d.maxKeys { + t.Errorf("map grew past maxKeys: %d", len(d.seen)) + } + // An already-tracked key is still served (dedup state is not lost). + if d.allow("a", "r", now) { + t.Error("an in-window tracked key should remain suppressed, not re-reported") + } +} + +func TestGuard_DedupSuppressesRepeatStaticFindings(t *testing.T) { + rep := &countingReporter{} + g := NewGuard(WithReporter(rep)) // default dedup window = 1m + + // DELETE without WHERE triggers exactly one rule (delete-without-where). + for range 10 { + g.Check("DELETE FROM accounts") + } + + if got := rep.count(); got != 1 { + t.Errorf("expected 1 static finding for a repeated query, got %d", got) + } +} + +func TestGuard_DedupDisabledReportsEveryTime(t *testing.T) { + rep := &countingReporter{} + g := NewGuard(WithReporter(rep), WithFindingDedup(0)) + + for range 5 { + g.Check("DELETE FROM accounts") + } + + if got := rep.count(); got != 5 { + t.Errorf("with dedup disabled expected 5 findings, got %d", got) + } +} + +func TestGuard_DedupIsPerIdentityNotPerQuery(t *testing.T) { + rep := &countingReporter{} + g := NewGuard(WithReporter(rep)) + + // Two literal variants share one fingerprint -> one select-star finding. + g.Check("SELECT * FROM users WHERE id = 1") + g.Check("SELECT * FROM users WHERE id = 2") + // A genuinely different flagged query is reported independently. + g.Check("DELETE FROM accounts") + + if got := rep.count(); got != 2 { + t.Errorf("expected 2 findings (one per identity), got %d", got) + } +} + +func TestGuard_DedupConcurrent(t *testing.T) { + rep := &countingReporter{} + g := NewGuard(WithReporter(rep)) + + var wg sync.WaitGroup + for range 100 { + wg.Go(func() { + g.Check("DELETE FROM accounts") + }) + } + wg.Wait() + + if got := rep.count(); got != 1 { + t.Errorf("expected exactly 1 finding under concurrency, got %d", got) + } +} + +func TestDriver_DedupRepeatedStaticFinding(t *testing.T) { + db, buf := guardedWithBuffer(t) + + for range 10 { + rows, err := db.Query("SELECT * FROM users") + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + } + + if n := strings.Count(buf.String(), "select-star"); n != 1 { + t.Errorf("expected select-star reported once across 10 executions, got %d", n) + } +} diff --git a/middleware/driver.go b/middleware/driver.go new file mode 100644 index 0000000..8be10ea --- /dev/null +++ b/middleware/driver.go @@ -0,0 +1,390 @@ +package middleware + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" +) + +// This file implements the standard database/sql driver-wrapping pattern +// (the approach used by ngrok/sqlmw, luna-duclos/instrumentedsql and +// OpenTelemetry's otelsql), hand-written with zero dependencies. +// +// Wrapping at the driver.Driver layer means every query — including those +// issued by ORMs and query builders through database/sql internals — flows +// through the analyzer automatically. There is no method list to keep in +// sync with database/sql, and the result is a real *sql.DB that composes +// with sqlc, ent, sqlx, gorm, pgx-stdlib and anything else. +// +// Optional driver interfaces (QueryerContext, Pinger, SessionResetter, …) +// are forwarded only when the wrapped driver implements them. Because the +// wrapper type structurally implements every optional interface, database/sql +// will always call them; the wrapper returns driver.ErrSkip (or the documented +// no-op) when the base does not support an operation, so database/sql falls +// back exactly as it would for the bare driver. This preserves the base +// driver's behavior without the combinatorial type-switch other libraries use. + +// Register wraps the database/sql driver currently registered under +// baseDriver and registers the analyzed result under name. Afterwards +// sql.Open(name, dsn) yields a *sql.DB whose every query is analyzed. +// +// middleware.Register("sqlguard-sqlite", "sqlite3") +// db, _ := sql.Open("sqlguard-sqlite", ":memory:") +// +// It returns an error if name is already registered or baseDriver is not +// a known driver. +func Register(name, baseDriver string, opts ...Option) (err error) { + // sql.Open does not connect; it only resolves the registered driver, + // so this is a cheap way to obtain the base driver.Driver by name. + probe, oerr := sql.Open(baseDriver, "") + if oerr != nil { + return fmt.Errorf("sqlguard: base driver %q: %w", baseDriver, oerr) + } + base := probe.Driver() + _ = probe.Close() + + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("sqlguard: register %q: %v", name, r) + } + }() + sql.Register(name, WrapDriver(base, opts...)) + return nil +} + +// OpenDB wraps a driver.Connector and returns an analyzed *sql.DB. Use this +// when you already hold a connector — for example pgx's stdlib.GetConnector +// or a driver-specific Connector — and don't want a global registration. +// +// connector := stdlib.GetConnector(*pgxConfig) +// db := middleware.OpenDB(connector) +func OpenDB(c driver.Connector, opts ...Option) *sql.DB { + return sql.OpenDB(WrapConnector(c, opts...)) +} + +// WrapDriver returns a driver.Driver that analyzes every query executed +// through it. The returned driver also implements driver.DriverContext so +// connector-based pooling is preserved. +func WrapDriver(base driver.Driver, opts ...Option) driver.Driver { + return &wDriver{base: base, g: NewGuard(opts...)} +} + +// WrapConnector returns a driver.Connector that analyzes every query +// executed through connections it produces. +func WrapConnector(base driver.Connector, opts ...Option) driver.Connector { + return &wConnector{base: base, g: NewGuard(opts...)} +} + +// ---- driver.Driver / driver.DriverContext ---- + +type wDriver struct { + base driver.Driver + g *Guard +} + +var ( + _ driver.Driver = (*wDriver)(nil) + _ driver.DriverContext = (*wDriver)(nil) +) + +func (d *wDriver) Open(name string) (driver.Conn, error) { + c, err := d.base.Open(name) + if err != nil { + return nil, err + } + return &wConn{base: c, g: d.g}, nil +} + +// OpenConnector implements driver.DriverContext. If the base driver supports +// connectors we wrap its connector; otherwise we synthesize a DSN-based +// connector equivalent to the one database/sql builds internally. +func (d *wDriver) OpenConnector(name string) (driver.Connector, error) { + if dc, ok := d.base.(driver.DriverContext); ok { + bc, err := dc.OpenConnector(name) + if err != nil { + return nil, err + } + return &wConnector{base: bc, g: d.g}, nil + } + return &wConnector{base: dsnConnector{dsn: name, driver: d.base}, g: d.g}, nil +} + +// dsnConnector mirrors database/sql's internal dsnConnector for base drivers +// that do not implement driver.DriverContext. +type dsnConnector struct { + dsn string + driver driver.Driver +} + +func (c dsnConnector) Connect(_ context.Context) (driver.Conn, error) { + return c.driver.Open(c.dsn) +} +func (c dsnConnector) Driver() driver.Driver { return c.driver } + +// ---- driver.Connector ---- + +type wConnector struct { + base driver.Connector + g *Guard +} + +var _ driver.Connector = (*wConnector)(nil) + +func (c *wConnector) Connect(ctx context.Context) (driver.Conn, error) { + conn, err := c.base.Connect(ctx) + if err != nil { + return nil, err + } + return &wConn{base: conn, g: c.g}, nil +} + +func (c *wConnector) Driver() driver.Driver { + return &wDriver{base: c.base.Driver(), g: c.g} +} + +// ---- driver.Conn and its optional interfaces ---- + +type wConn struct { + base driver.Conn + g *Guard +} + +var ( + _ driver.Conn = (*wConn)(nil) + _ driver.ConnPrepareContext = (*wConn)(nil) + _ driver.ConnBeginTx = (*wConn)(nil) + _ driver.QueryerContext = (*wConn)(nil) + _ driver.ExecerContext = (*wConn)(nil) + _ driver.Pinger = (*wConn)(nil) + _ driver.SessionResetter = (*wConn)(nil) + _ driver.Validator = (*wConn)(nil) + _ driver.NamedValueChecker = (*wConn)(nil) +) + +func (c *wConn) Prepare(query string) (driver.Stmt, error) { + s, err := c.base.Prepare(query) + if err != nil { + return nil, err + } + return &wStmt{base: s, query: query, g: c.g}, nil +} + +func (c *wConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + var ( + s driver.Stmt + err error + ) + if cpc, ok := c.base.(driver.ConnPrepareContext); ok { + s, err = cpc.PrepareContext(ctx, query) + } else { + s, err = c.base.Prepare(query) + } + if err != nil { + return nil, err + } + return &wStmt{base: s, query: query, g: c.g}, nil +} + +func (c *wConn) Close() error { return c.base.Close() } + +func (c *wConn) Begin() (driver.Tx, error) { + tx, err := c.base.Begin() //nolint:staticcheck // delegated deprecated path + if err != nil { + return nil, err + } + return &wTx{base: tx}, nil +} + +func (c *wConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var ( + tx driver.Tx + err error + ) + if cbt, ok := c.base.(driver.ConnBeginTx); ok { + tx, err = cbt.BeginTx(ctx, opts) + } else { + tx, err = c.base.Begin() //nolint:staticcheck // delegated deprecated path + } + if err != nil { + return nil, err + } + return &wTx{base: tx}, nil +} + +func (c *wConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + // Observe only on a path that actually executes. When the base has no direct + // Query entry point we return driver.ErrSkip *without* analyzing, so + // database/sql's Prepare+Query fallback — which re-enters through wStmt — is + // the single place this query is analyzed. Analyzing here too would count + // the same logical query twice (a duplicate finding and an inflated N+1). + if qc, ok := c.base.(driver.QueryerContext); ok { + done := c.g.Observe(query) + rows, err := qc.QueryContext(ctx, query, args) + done(err) + return rows, err + } + if q, ok := c.base.(driver.Queryer); ok { //nolint:staticcheck // legacy fallback + values, verr := namedToValues(args) + if verr != nil { + return nil, verr + } + done := c.g.Observe(query) + rows, err := q.Query(query, values) //nolint:staticcheck // legacy fallback + done(err) + return rows, err + } + return nil, driver.ErrSkip +} + +func (c *wConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + // See QueryContext: analyze only when this path executes. Returning ErrSkip + // without analyzing lets the Prepare+Exec fallback (via wStmt) be the single + // analysis point, avoiding a double count. + if ec, ok := c.base.(driver.ExecerContext); ok { + done := c.g.Observe(query) + res, err := ec.ExecContext(ctx, query, args) + done(err) + return res, err + } + if e, ok := c.base.(driver.Execer); ok { //nolint:staticcheck // legacy fallback + values, verr := namedToValues(args) + if verr != nil { + return nil, verr + } + done := c.g.Observe(query) + res, err := e.Exec(query, values) //nolint:staticcheck // legacy fallback + done(err) + return res, err + } + return nil, driver.ErrSkip +} + +func (c *wConn) Ping(ctx context.Context) error { + if p, ok := c.base.(driver.Pinger); ok { + return p.Ping(ctx) + } + // Base is not a Pinger; ErrSkip tells database/sql ping is unsupported + // and the connection should be assumed valid, matching the bare driver. + return driver.ErrSkip +} + +func (c *wConn) ResetSession(ctx context.Context) error { + if r, ok := c.base.(driver.SessionResetter); ok { + return r.ResetSession(ctx) + } + return nil +} + +func (c *wConn) IsValid() bool { + if v, ok := c.base.(driver.Validator); ok { + return v.IsValid() + } + return true +} + +func (c *wConn) CheckNamedValue(nv *driver.NamedValue) error { + if ck, ok := c.base.(driver.NamedValueChecker); ok { + return ck.CheckNamedValue(nv) + } + // Defer to database/sql's default argument conversion. + return driver.ErrSkip +} + +// ---- driver.Stmt and its optional interfaces ---- + +type wStmt struct { + base driver.Stmt + query string + g *Guard +} + +var ( + _ driver.Stmt = (*wStmt)(nil) + _ driver.StmtExecContext = (*wStmt)(nil) + _ driver.StmtQueryContext = (*wStmt)(nil) + _ driver.NamedValueChecker = (*wStmt)(nil) +) + +func (s *wStmt) Close() error { return s.base.Close() } +func (s *wStmt) NumInput() int { return s.base.NumInput() } + +func (s *wStmt) Exec(args []driver.Value) (driver.Result, error) { + done := s.g.Observe(s.query) + res, err := s.base.Exec(args) //nolint:staticcheck // delegated deprecated path + done(err) + return res, err +} + +func (s *wStmt) Query(args []driver.Value) (driver.Rows, error) { + done := s.g.Observe(s.query) + rows, err := s.base.Query(args) //nolint:staticcheck // delegated deprecated path + done(err) + return rows, err +} + +func (s *wStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + done := s.g.Observe(s.query) + if ec, ok := s.base.(driver.StmtExecContext); ok { + res, err := ec.ExecContext(ctx, args) + done(err) + return res, err + } + values, verr := namedToValues(args) + if verr != nil { + return nil, verr + } + res, err := s.base.Exec(values) //nolint:staticcheck // legacy fallback + done(err) + return res, err +} + +func (s *wStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + done := s.g.Observe(s.query) + if qc, ok := s.base.(driver.StmtQueryContext); ok { + rows, err := qc.QueryContext(ctx, args) + done(err) + return rows, err + } + values, verr := namedToValues(args) + if verr != nil { + return nil, verr + } + rows, err := s.base.Query(values) //nolint:staticcheck // legacy fallback + done(err) + return rows, err +} + +func (s *wStmt) CheckNamedValue(nv *driver.NamedValue) error { + if ck, ok := s.base.(driver.NamedValueChecker); ok { + return ck.CheckNamedValue(nv) + } + return driver.ErrSkip +} + +// ---- driver.Tx ---- + +type wTx struct { + base driver.Tx +} + +var _ driver.Tx = (*wTx)(nil) + +func (t *wTx) Commit() error { return t.base.Commit() } +func (t *wTx) Rollback() error { return t.base.Rollback() } + +// ---- helpers ---- + +// namedToValues converts named values to positional values for the legacy +// Queryer/Execer/Stmt fallback paths, which predate named parameters. +func namedToValues(named []driver.NamedValue) ([]driver.Value, error) { + values := make([]driver.Value, len(named)) + for i, nv := range named { + if nv.Name != "" { + return nil, errors.New("sqlguard: driver does not support named parameters") + } + values[i] = nv.Value + } + return values, nil +} diff --git a/middleware/driver_fallback_test.go b/middleware/driver_fallback_test.go new file mode 100644 index 0000000..071ee3b --- /dev/null +++ b/middleware/driver_fallback_test.go @@ -0,0 +1,111 @@ +package middleware + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "io" + "testing" + "time" +) + +// fakeNoQueryerDriver is a minimal driver whose Conn implements neither +// QueryerContext/ExecerContext nor the legacy Queryer/Execer. database/sql is +// therefore forced down its Prepare+Stmt fallback path for every Query/Exec — +// the path where wConn.{Query,Exec}Context return driver.ErrSkip. It exists to +// prove a single logical query is analyzed exactly once even then. +type fakeNoQueryerDriver struct{} + +func (fakeNoQueryerDriver) Open(string) (driver.Conn, error) { return &fakeConn{}, nil } + +type fakeConn struct{} + +func (*fakeConn) Prepare(string) (driver.Stmt, error) { return &fakeStmt{}, nil } +func (*fakeConn) Close() error { return nil } +func (*fakeConn) Begin() (driver.Tx, error) { return &fakeTx{}, nil } + +type fakeStmt struct{} + +func (*fakeStmt) Close() error { return nil } +func (*fakeStmt) NumInput() int { return -1 } // skip arg-count checking +func (*fakeStmt) Exec([]driver.Value) (driver.Result, error) { return driver.RowsAffected(0), nil } +func (*fakeStmt) Query([]driver.Value) (driver.Rows, error) { return &fakeRows{}, nil } + +type fakeRows struct{} + +func (*fakeRows) Columns() []string { return nil } +func (*fakeRows) Close() error { return nil } +func (*fakeRows) Next([]driver.Value) error { return io.EOF } + +type fakeTx struct{} + +func (*fakeTx) Commit() error { return nil } +func (*fakeTx) Rollback() error { return nil } + +// openFakeGuarded registers a wrapped fakeNoQueryerDriver and returns the DB +// plus the reporter that records findings. Dedup is off so every analysis is +// counted (the bug would surface as 2 findings for one query). +func openFakeGuarded(t *testing.T) (*sql.DB, *countingReporter) { + t.Helper() + rep := &countingReporter{} + name := fmt.Sprintf("sqlguard-fake-%d", driverSeq.Add(1)) + sql.Register(name, WrapDriver(fakeNoQueryerDriver{}, WithReporter(rep), WithFindingDedup(0))) + db, err := sql.Open(name, "") + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { db.Close() }) + return db, rep +} + +func TestDriver_NoQueryerContextAnalyzedOnce(t *testing.T) { + db, rep := openFakeGuarded(t) + + rows, err := db.Query("DELETE FROM accounts") // flagged: delete-without-where + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if got := rep.count(); got != 1 { + t.Errorf("expected one logical query analyzed once via the prepare fallback, got %d", got) + } +} + +func TestDriver_NoExecerContextAnalyzedOnce(t *testing.T) { + db, rep := openFakeGuarded(t) + + if _, err := db.Exec("DELETE FROM accounts"); err != nil { + t.Fatalf("exec: %v", err) + } + + if got := rep.count(); got != 1 { + t.Errorf("expected one logical query analyzed once via the prepare fallback, got %d", got) + } +} + +// With N+1 enabled, each logical query must increment the counter once. If the +// ErrSkip path double-counted, threshold=2 would trip after a single query. +func TestDriver_NoQueryerContextN1CountedOnce(t *testing.T) { + rep := &countingReporter{} + name := fmt.Sprintf("sqlguard-fake-%d", driverSeq.Add(1)) + sql.Register(name, WrapDriver(fakeNoQueryerDriver{}, + WithReporter(rep), WithFindingDedup(0), WithN1Detection(2, time.Minute))) + db, err := sql.Open(name, "") + if err != nil { + t.Fatalf("open: %v", err) + } + defer db.Close() + + // One execution of a non-flagged query: no static finding, and the N+1 + // counter should be at 1 (below threshold 2), so nothing is reported. + rows, err := db.Query("SELECT id, name FROM users WHERE id = ?", 1) + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if got := rep.count(); got != 0 { + t.Errorf("one logical query must not trip N+1 (threshold 2); got %d reports", got) + } +} diff --git a/middleware/driver_test.go b/middleware/driver_test.go new file mode 100644 index 0000000..a33ed1b --- /dev/null +++ b/middleware/driver_test.go @@ -0,0 +1,265 @@ +package middleware + +import ( + "bytes" + "database/sql" + "fmt" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/reporter" + + _ "github.com/mattn/go-sqlite3" +) + +var driverSeq atomic.Int64 + +// newGuardedDB registers a uniquely-named wrapped sqlite3 driver with the +// given options and returns an analyzed *sql.DB backed by a temp-file +// database (so the connection pool sees a consistent schema). +func newGuardedDB(t *testing.T, opts ...Option) *sql.DB { + t.Helper() + name := fmt.Sprintf("sqlguard-test-%d", driverSeq.Add(1)) + if err := Register(name, "sqlite3", opts...); err != nil { + t.Fatalf("Register: %v", err) + } + dsn := filepath.Join(t.TempDir(), "test.db") + db, err := sql.Open(name, dsn) + if err != nil { + t.Fatalf("sql.Open: %v", err) + } + t.Cleanup(func() { db.Close() }) + + if _, err := db.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, email TEXT)"); err != nil { + t.Fatalf("create table: %v", err) + } + if _, err := db.Exec("INSERT INTO users (name, email) VALUES ('alice', 'alice@example.com')"); err != nil { + t.Fatalf("insert: %v", err) + } + return db +} + +func guardedWithBuffer(t *testing.T, extra ...Option) (*sql.DB, *bytes.Buffer) { + t.Helper() + var buf bytes.Buffer + opts := append([]Option{WithReporter(reporter.NewConsoleReporterTo(&buf))}, extra...) + return newGuardedDB(t, opts...), &buf +} + +func TestDriver_ReturnsRealSQLDB(t *testing.T) { + db, _ := guardedWithBuffer(t) + // The whole point: Register/sql.Open yield a real *sql.DB, usable + // anywhere one is expected (no wrapper type to thread through). + if db == nil { + t.Fatal("expected a *sql.DB") + } +} + +func TestDriver_QueryDetectsSelectStar(t *testing.T) { + db, buf := guardedWithBuffer(t) + + rows, err := db.Query("SELECT * FROM users") + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if !strings.Contains(buf.String(), "select-star") { + t.Errorf("expected select-star warning, got: %q", buf.String()) + } +} + +func TestDriver_NoWarningForSafeQuery(t *testing.T) { + db, buf := guardedWithBuffer(t) + + rows, err := db.Query("SELECT id, name FROM users WHERE id = ?", 1) + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if buf.Len() != 0 { + t.Errorf("expected no warnings, got: %q", buf.String()) + } +} + +func TestDriver_ExecDetectsDeleteWithoutWhere(t *testing.T) { + db, buf := guardedWithBuffer(t) + + if _, err := db.Exec("DELETE FROM users"); err != nil { + t.Fatalf("exec: %v", err) + } + + if !strings.Contains(buf.String(), "delete-without-where") { + t.Errorf("expected delete-without-where, got: %q", buf.String()) + } + if !strings.Contains(buf.String(), "CRITICAL") { + t.Error("expected CRITICAL severity") + } +} + +func TestDriver_QueryRowDetectsLeadingWildcard(t *testing.T) { + db, buf := guardedWithBuffer(t) + + _ = db.QueryRow("SELECT id FROM users WHERE email LIKE '%gmail%'") + + if !strings.Contains(buf.String(), "leading-wildcard") { + t.Errorf("expected leading-wildcard, got: %q", buf.String()) + } +} + +func TestDriver_PreparedStatementIsAnalyzed(t *testing.T) { + db, buf := guardedWithBuffer(t) + + stmt, err := db.Prepare("SELECT * FROM users") + if err != nil { + t.Fatalf("prepare: %v", err) + } + defer stmt.Close() + + rows, err := stmt.Query() + if err != nil { + t.Fatalf("stmt query: %v", err) + } + rows.Close() + + if !strings.Contains(buf.String(), "select-star") { + t.Errorf("expected select-star on prepared exec, got: %q", buf.String()) + } +} + +func TestDriver_TransactionIsAnalyzed(t *testing.T) { + db, buf := guardedWithBuffer(t) + + tx, err := db.Begin() + if err != nil { + t.Fatalf("begin: %v", err) + } + if _, err := tx.Exec("DELETE FROM users"); err != nil { + t.Fatalf("tx exec: %v", err) + } + if err := tx.Rollback(); err != nil { + t.Fatalf("rollback: %v", err) + } + + if !strings.Contains(buf.String(), "delete-without-where") { + t.Errorf("expected delete-without-where in tx, got: %q", buf.String()) + } +} + +func TestDriver_TransactionCommitRollback(t *testing.T) { + db, _ := guardedWithBuffer(t) + + tx, err := db.Begin() + if err != nil { + t.Fatalf("begin: %v", err) + } + if _, err := tx.Exec("INSERT INTO users (name, email) VALUES (?, ?)", "bob", "bob@example.com"); err != nil { + t.Fatalf("exec: %v", err) + } + if err := tx.Commit(); err != nil { + t.Fatalf("commit: %v", err) + } + + var count int + if err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count); err != nil { + t.Fatalf("scan: %v", err) + } + if count != 2 { + t.Errorf("expected 2 rows after commit, got %d", count) + } + + tx2, err := db.Begin() + if err != nil { + t.Fatalf("begin: %v", err) + } + if _, err := tx2.Exec("DELETE FROM users WHERE name = ?", "bob"); err != nil { + t.Fatalf("exec: %v", err) + } + if err := tx2.Rollback(); err != nil { + t.Fatalf("rollback: %v", err) + } + + if err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count); err != nil { + t.Fatalf("scan: %v", err) + } + if count != 2 { + t.Errorf("expected 2 rows after rollback, got %d", count) + } +} + +func TestDriver_SlowQueryDetection(t *testing.T) { + db, buf := guardedWithBuffer(t, WithSlowQueryThreshold(1*time.Nanosecond)) + + rows, err := db.Query("SELECT id FROM users WHERE id = ?", 1) + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if !strings.Contains(buf.String(), "slow-query") { + t.Errorf("expected slow-query with 1ns threshold, got: %q", buf.String()) + } +} + +func TestDriver_NoSlowQueryBelowThreshold(t *testing.T) { + db, buf := guardedWithBuffer(t, WithSlowQueryThreshold(1*time.Hour)) + + rows, err := db.Query("SELECT id FROM users WHERE id = ?", 1) + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if strings.Contains(buf.String(), "slow-query") { + t.Errorf("did not expect slow-query, got: %q", buf.String()) + } +} + +func TestDriver_CustomAnalyzer(t *testing.T) { + db, buf := guardedWithBuffer(t, WithAnalyzer(analyzer.New(analyzer.CheckDeleteWithoutWhere))) + + rows, err := db.Query("SELECT * FROM users") + if err != nil { + t.Fatalf("query: %v", err) + } + rows.Close() + + if strings.Contains(buf.String(), "select-star") { + t.Errorf("did not expect select-star with custom analyzer, got: %q", buf.String()) + } +} + +func TestDriver_N1Detection(t *testing.T) { + db, buf := guardedWithBuffer(t, WithN1Detection(3, time.Second)) + + for i := range 5 { + row := db.QueryRow("SELECT name FROM users WHERE id = ?", i) + var name string + _ = row.Scan(&name) + } + + if !strings.Contains(buf.String(), "n-plus-one") { + t.Errorf("expected n-plus-one warning, got: %q", buf.String()) + } +} + +func TestRegister_DuplicateNameErrors(t *testing.T) { + name := fmt.Sprintf("sqlguard-dup-%d", driverSeq.Add(1)) + if err := Register(name, "sqlite3"); err != nil { + t.Fatalf("first Register: %v", err) + } + if err := Register(name, "sqlite3"); err == nil { + t.Error("expected error registering duplicate name") + } +} + +func TestRegister_UnknownBaseDriverErrors(t *testing.T) { + if err := Register("sqlguard-x", "no-such-driver"); err == nil { + t.Error("expected error for unknown base driver") + } +} diff --git a/middleware/guard.go b/middleware/guard.go new file mode 100644 index 0000000..e29e92b --- /dev/null +++ b/middleware/guard.go @@ -0,0 +1,137 @@ +package middleware + +import ( + "fmt" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +// Guard is the single shared analysis core. It runs the configured analyzer +// and reporter against every executed query, measures latency, and feeds the +// N+1 tracker. Every interception point — the database/sql driver chain and +// every out-of-tree integration (pgxguard, …) — drives the same Guard so +// analysis logic, redaction, fingerprinting, N+1, the parser seam and config +// live here exactly once. Integrations must build on Guard rather than +// re-implementing check/latency by hand (that path silently loses +// redaction-by-default and fingerprints). +// +// A Guard is safe for concurrent use. +type Guard struct { + opts options + tracker *QueryTracker + deduper *deduper + cache *analysisCache +} + +// NewGuard builds a Guard from the given options. +func NewGuard(opts ...Option) *Guard { + o := defaultOptions() + for _, opt := range opts { + opt(&o) + } + if o.parser != nil { + o.analyzer = o.analyzer.WithParser(o.parser) + } + g := &Guard{opts: o, deduper: newDeduper(o.dedupWindow)} + if o.cacheSize > 0 { + g.cache = newAnalysisCache(o.cacheSize) + } + if o.enableN1 { + g.tracker = NewQueryTracker(o.n1Threshold, o.n1Window, func(results []analyzer.Result) { + o.reporter.Report(results) + }) + } + return g +} + +// Analyzer returns the configured analyzer. Useful for integrations that need +// the canonical redact/fingerprint helpers without re-deriving policy. +func (g *Guard) Analyzer() *analyzer.Analyzer { return g.opts.analyzer } + +// Check runs the static rules against the query and feeds the N+1 tracker. +func (g *Guard) Check(query string) { + results := g.analyze(query) + if len(results) > 0 { + g.report(results) + } + if g.tracker != nil { + g.tracker.Track(query) + } +} + +// analyze returns the static findings for query, memoizing per distinct query +// string so a recurring query is parsed and rule-checked once. The cache is +// keyed on the exact query string because a few rules read literal-derived +// facts the fingerprint folds away (see analysisCache). The returned slice may +// be shared from the cache and must be treated as read-only. +func (g *Guard) analyze(query string) []analyzer.Result { + if g.cache == nil { + return g.opts.analyzer.Analyze(query) + } + if cached, ok := g.cache.get(query); ok { + return cached + } + results := g.opts.analyzer.Analyze(query) + g.cache.put(query, results) + return results +} + +// report emits static findings, suppressing repeats of the same +// (fingerprint, rule) within the dedup window so a recurring query does not +// flood the reporter. results may be a shared cache entry, so it is never +// mutated; kept is allocated only when a finding actually passes dedup (rare +// after the first occurrence, and never for the common no-findings case). +func (g *Guard) report(results []analyzer.Result) { + now := time.Now() + var kept []analyzer.Result + for _, r := range results { + if g.deduper.allow(r.Fingerprint, r.RuleName, now) { + kept = append(kept, r) + } + } + if len(kept) > 0 { + g.opts.reporter.Report(kept) + } +} + +// CheckLatency reports a slow-query finding if elapsed exceeds the threshold. +func (g *Guard) CheckLatency(query string, elapsed time.Duration) { + if elapsed >= g.opts.slowThreshold { + display, fingerprint := g.opts.analyzer.PrepareQuery(query) + g.opts.reporter.Report([]analyzer.Result{{ + RuleName: "slow-query", + Severity: analyzer.SeverityWarning, + Query: display, + Fingerprint: fingerprint, + Message: fmt.Sprintf("Query took %s (threshold: %s)", elapsed.Round(time.Millisecond), g.opts.slowThreshold), + Suggestion: "Consider adding indexes or optimizing the query.", + }}) + } +} + +// Observe analyzes a query and times its execution. The returned function +// must be called once the underlying operation completes; it records latency +// only when err is nil (a failed query's latency is meaningless). It is +// designed for split start/end interception points such as pgx tracers: +// call Observe in the start hook, stash the closure, invoke it in the end +// hook with the operation error. +func (g *Guard) Observe(query string) func(err error) { + g.Check(query) + start := time.Now() + return func(err error) { + if err == nil { + g.CheckLatency(query, time.Since(start)) + } + } +} + +// ResetN1 clears the N+1 tracker's accumulated state. Call this at a +// per-request boundary (e.g. end of an HTTP handler) so N+1 detection is +// scoped to a single logical unit of work rather than process-global. It is +// a no-op when N+1 detection is not enabled. +func (g *Guard) ResetN1() { + if g.tracker != nil { + g.tracker.Reset() + } +} diff --git a/middleware/n_plus_one.go b/middleware/n_plus_one.go new file mode 100644 index 0000000..38bcaba --- /dev/null +++ b/middleware/n_plus_one.go @@ -0,0 +1,124 @@ +package middleware + +import ( + "fmt" + "sync" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +// normalizeQuery is the N+1 grouping key: the canonical, literal-free query +// fingerprint. It delegates to analyzer.Fingerprint so there is a single +// normalizer in the codebase (the comment/string-literal-aware one) rather +// than a second, subtly different regex pass. +func normalizeQuery(query string) string { + return analyzer.Fingerprint(query) +} + +type queryRecord struct { + count int + firstSeen time.Time + reported bool +} + +// QueryTracker detects N+1 query patterns at runtime. +// It tracks normalized query patterns and flags when the same pattern +// is executed more than a threshold number of times within a time window. +type QueryTracker struct { + mu sync.Mutex + queries map[string]*queryRecord + threshold int + window time.Duration + maxKeys int + reporter func(results []analyzer.Result) +} + +// NewQueryTracker creates a tracker that flags when the same query pattern +// appears more than threshold times within the given window. +func NewQueryTracker(threshold int, window time.Duration, reportFn func([]analyzer.Result)) *QueryTracker { + return &QueryTracker{ + queries: make(map[string]*queryRecord), + threshold: threshold, + window: window, + maxKeys: 10000, + reporter: reportFn, + } +} + +// Track records a query execution and reports if N+1 pattern is detected. +func (qt *QueryTracker) Track(query string) { + normalized := normalizeQuery(query) + + qt.mu.Lock() + + now := time.Now() + + // Bound memory: when the map is at capacity, evict expired entries first. + if len(qt.queries) >= qt.maxKeys { + qt.evictExpired(now) + } + + rec, exists := qt.queries[normalized] + if !exists { + // A new key past the cap (eviction freed nothing — every entry is still + // in-window) is dropped rather than grown without bound: a rare, + // harmless false negative under pathological query-shape cardinality. + // Already-tracked keys (the exists path below) are always honored, so + // in-flight N+1 detection is never lost. + if len(qt.queries) >= qt.maxKeys { + qt.mu.Unlock() + return + } + qt.queries[normalized] = &queryRecord{count: 1, firstSeen: now} + qt.mu.Unlock() + return + } + + // If outside the window, reset + if now.Sub(rec.firstSeen) > qt.window { + rec.count = 1 + rec.firstSeen = now + rec.reported = false + qt.mu.Unlock() + return + } + + rec.count++ + + shouldReport := rec.count >= qt.threshold && !rec.reported + if shouldReport { + rec.reported = true + } + + // Release lock before calling reporter to avoid holding mutex during I/O + count := rec.count + qt.mu.Unlock() + + if shouldReport { + qt.reporter([]analyzer.Result{{ + RuleName: "n-plus-one", + Severity: analyzer.SeverityWarning, + Query: normalized, + Fingerprint: normalized, + Message: fmt.Sprintf("Possible N+1 query detected: same pattern executed %d times in %s", count, qt.window), + Suggestion: "Consider using a JOIN or IN clause to batch these queries.", + }}) + } +} + +// evictExpired removes entries older than the window. Must be called with mutex held. +func (qt *QueryTracker) evictExpired(now time.Time) { + for key, rec := range qt.queries { + if now.Sub(rec.firstSeen) > qt.window { + delete(qt.queries, key) + } + } +} + +// Reset clears all tracked queries. Call this between requests. +func (qt *QueryTracker) Reset() { + qt.mu.Lock() + defer qt.mu.Unlock() + qt.queries = make(map[string]*queryRecord) +} diff --git a/middleware/n_plus_one_test.go b/middleware/n_plus_one_test.go new file mode 100644 index 0000000..141b2c1 --- /dev/null +++ b/middleware/n_plus_one_test.go @@ -0,0 +1,168 @@ +package middleware + +import ( + "fmt" + "testing" + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +func TestNormalizeQuery(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"numbers", "SELECT * FROM users WHERE id = 42", "SELECT * FROM users WHERE id = ?"}, + {"strings", "SELECT * FROM users WHERE name = 'alice'", "SELECT * FROM users WHERE name = ?"}, + {"mixed", "SELECT * FROM users WHERE id = 1 AND name = 'bob'", "SELECT * FROM users WHERE id = ? AND name = ?"}, + {"no literals", "SELECT * FROM users WHERE id = ?", "SELECT * FROM users WHERE id = ?"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeQuery(tt.input) + if got != tt.want { + t.Errorf("normalizeQuery(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestQueryTracker_DetectsN1(t *testing.T) { + var reported []analyzer.Result + tracker := NewQueryTracker(3, 5*time.Second, func(results []analyzer.Result) { + reported = append(reported, results...) + }) + + // Same pattern 3 times should trigger + tracker.Track("SELECT * FROM orders WHERE user_id = 1") + tracker.Track("SELECT * FROM orders WHERE user_id = 2") + tracker.Track("SELECT * FROM orders WHERE user_id = 3") + + if len(reported) != 1 { + t.Fatalf("expected 1 N+1 report, got %d", len(reported)) + } + if reported[0].RuleName != "n-plus-one" { + t.Errorf("expected rule n-plus-one, got %s", reported[0].RuleName) + } +} + +func TestQueryTracker_DifferentPatterns(t *testing.T) { + var reported []analyzer.Result + tracker := NewQueryTracker(3, 5*time.Second, func(results []analyzer.Result) { + reported = append(reported, results...) + }) + + // Different patterns should not trigger + tracker.Track("SELECT * FROM orders WHERE user_id = 1") + tracker.Track("SELECT * FROM users WHERE id = 1") + tracker.Track("SELECT * FROM products WHERE id = 1") + + if len(reported) != 0 { + t.Errorf("expected no reports for different patterns, got %d", len(reported)) + } +} + +func TestQueryTracker_BelowThreshold(t *testing.T) { + var reported []analyzer.Result + tracker := NewQueryTracker(5, 5*time.Second, func(results []analyzer.Result) { + reported = append(reported, results...) + }) + + // Only 3 of same pattern, threshold is 5 + tracker.Track("SELECT * FROM orders WHERE user_id = 1") + tracker.Track("SELECT * FROM orders WHERE user_id = 2") + tracker.Track("SELECT * FROM orders WHERE user_id = 3") + + if len(reported) != 0 { + t.Errorf("expected no reports below threshold, got %d", len(reported)) + } +} + +func TestQueryTracker_ReportsOnlyOnce(t *testing.T) { + var reported []analyzer.Result + tracker := NewQueryTracker(2, 5*time.Second, func(results []analyzer.Result) { + reported = append(reported, results...) + }) + + tracker.Track("SELECT * FROM orders WHERE user_id = 1") + tracker.Track("SELECT * FROM orders WHERE user_id = 2") + tracker.Track("SELECT * FROM orders WHERE user_id = 3") + tracker.Track("SELECT * FROM orders WHERE user_id = 4") + + if len(reported) != 1 { + t.Errorf("expected exactly 1 report (not per-query), got %d", len(reported)) + } +} + +func TestQueryTracker_Reset(t *testing.T) { + var reported []analyzer.Result + tracker := NewQueryTracker(2, 5*time.Second, func(results []analyzer.Result) { + reported = append(reported, results...) + }) + + tracker.Track("SELECT * FROM orders WHERE user_id = 1") + tracker.Reset() + tracker.Track("SELECT * FROM orders WHERE user_id = 2") + + // After reset, count should restart + if len(reported) != 0 { + t.Errorf("expected no reports after reset, got %d", len(reported)) + } +} + +// N+1 detection through the driver path is covered by +// TestDriver_N1Detection in driver_test.go. QueryTracker.Reset is +// exercised directly above; per-request reset is no longer exposed on +// the *sql.DB returned by the driver wrapper. + +func TestQueryTracker_BoundedAtMaxKeys(t *testing.T) { + qt := &QueryTracker{ + queries: make(map[string]*queryRecord), + threshold: 1000, // high so nothing reports + window: time.Hour, // long so nothing expires (eviction frees nothing) + maxKeys: 3, + reporter: func([]analyzer.Result) {}, + } + + // Distinct query *shapes* (distinct column names → distinct fingerprints), + // far more than maxKeys, all in-window. The map must stay capped. + for i := range 50 { + qt.Track(fmt.Sprintf("SELECT col%d FROM t WHERE id = ?", i)) + } + + if len(qt.queries) > qt.maxKeys { + t.Errorf("tracker map grew past maxKeys: %d > %d", len(qt.queries), qt.maxKeys) + } +} + +func TestQueryTracker_TrackedKeyHonoredAtCap(t *testing.T) { + var reports int + qt := &QueryTracker{ + queries: make(map[string]*queryRecord), + threshold: 3, + window: time.Hour, + maxKeys: 2, + reporter: func([]analyzer.Result) { reports++ }, + } + + // "cola" gets tracked to count 2 (below threshold), then the cap fills. + qt.Track("SELECT cola FROM t") + qt.Track("SELECT cola FROM t") + qt.Track("SELECT colb FROM t") // map now {cola, colb}, at cap + + // A brand-new key at the cap is dropped (map stays bounded)... + qt.Track("SELECT colc FROM t") + if len(qt.queries) > qt.maxKeys { + t.Fatalf("map exceeded cap: %d", len(qt.queries)) + } + + // ...but the already-tracked "cola" still increments to threshold and fires + // exactly once — in-flight detection is never lost to the cap. + qt.Track("SELECT cola FROM t") + if reports != 1 { + t.Errorf("expected the tracked key to still reach threshold and report once, got %d", reports) + } +} diff --git a/middleware/options.go b/middleware/options.go new file mode 100644 index 0000000..8285bf6 --- /dev/null +++ b/middleware/options.go @@ -0,0 +1,98 @@ +package middleware + +import ( + "time" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/KARTIKrocks/sqlguard/reporter" +) + +type options struct { + slowThreshold time.Duration + reporter reporter.Reporter + analyzer *analyzer.Analyzer + parser analyzer.Parser + n1Threshold int + n1Window time.Duration + enableN1 bool + dedupWindow time.Duration + cacheSize int +} + +// Option configures the runtime guard. +type Option func(*options) + +// WithSlowQueryThreshold sets the duration above which a query is flagged as slow. +// Default is 200ms. +func WithSlowQueryThreshold(d time.Duration) Option { + return func(o *options) { + o.slowThreshold = d + } +} + +// WithReporter sets a custom reporter. Default is ConsoleReporter. +func WithReporter(r reporter.Reporter) Option { + return func(o *options) { + o.reporter = r + } +} + +// WithAnalyzer sets a custom analyzer. Default is analyzer.Default(). +func WithAnalyzer(a *analyzer.Analyzer) Option { + return func(o *options) { + o.analyzer = a + } +} + +// WithParser sets the SQL parser the analyzer uses. Default is the +// zero-dependency analyzer.FallbackParser. Pass a real dialect parser +// (e.g. from sqlguard/parsers/pgparser) for exact, structural analysis. +func WithParser(p analyzer.Parser) Option { + return func(o *options) { + o.parser = p + } +} + +// WithN1Detection enables N+1 query detection with the given threshold and window. +// When the same query pattern is executed threshold times within window, a warning is reported. +func WithN1Detection(threshold int, window time.Duration) Option { + return func(o *options) { + o.enableN1 = true + o.n1Threshold = threshold + o.n1Window = window + } +} + +// WithFindingDedup sets the window within which a repeated static finding — +// the same rule firing on the same canonical query shape — is reported at most +// once. This keeps a recurring query (or a prepared statement run in a loop) +// from flooding the log sink with the same warning on every execution. The +// default is one minute. Pass 0 to disable dedup and report every occurrence +// (the legacy behavior). Slow-query and N+1 findings have their own emission +// policy and are unaffected. +func WithFindingDedup(window time.Duration) Option { + return func(o *options) { + o.dedupWindow = window + } +} + +// WithAnalysisCacheSize sets the maximum number of distinct query strings whose +// analysis results are memoized, so a recurring query is parsed and rule-checked +// once instead of on every execution. The cache is an LRU keyed on the exact +// query string (correct even for the literal-sensitive rules). Default is 1024. +// Pass 0 to disable the cache and analyze every query. +func WithAnalysisCacheSize(n int) Option { + return func(o *options) { + o.cacheSize = n + } +} + +func defaultOptions() options { + return options{ + slowThreshold: 200 * time.Millisecond, + reporter: reporter.NewConsoleReporter(), + analyzer: analyzer.Default(), + dedupWindow: time.Minute, + cacheSize: 1024, + } +} diff --git a/parsers/mysqlparser/go.mod b/parsers/mysqlparser/go.mod new file mode 100644 index 0000000..6360f11 --- /dev/null +++ b/parsers/mysqlparser/go.mod @@ -0,0 +1,9 @@ +module github.com/KARTIKrocks/sqlguard/parsers/mysqlparser + +go 1.26 + +require github.com/KARTIKrocks/sqlguard v0.0.0 + +require github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/parsers/mysqlparser/go.sum b/parsers/mysqlparser/go.sum new file mode 100644 index 0000000..6354a21 --- /dev/null +++ b/parsers/mysqlparser/go.sum @@ -0,0 +1,2 @@ +github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 h1:zzrxE1FKn5ryBNl9eKOeqQ58Y/Qpo3Q9QNxKHX5uzzQ= +github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2/go.mod h1:hzfGeIUDq/j97IG+FhNqkowIyEcD88LrW6fyU3K3WqY= diff --git a/parsers/mysqlparser/mysqlparser.go b/parsers/mysqlparser/mysqlparser.go new file mode 100644 index 0000000..7943340 --- /dev/null +++ b/parsers/mysqlparser/mysqlparser.go @@ -0,0 +1,139 @@ +// Package mysqlparser is an optional sqlguard Parser backed by a real +// MySQL grammar (github.com/xwb1989/sqlparser — a pure-Go, no-cgo, +// lightweight Vitess-derived MySQL parser). +// +// It produces exact, structural answers for the false-positive-prone facts +// (statement kind, WHERE/LIMIT/ORDER BY/FROM presence, SELECT *, explicit +// INSERT columns) instead of regex guesses. SQL the grammar rejects — +// CTEs it doesn't support, dynamic fragments, dialect extensions — +// transparently degrades to sqlguard's zero-dependency FallbackParser, so +// analysis never breaks the caller's query path. +// +// Usage: +// +// sqlguard.Register("sqlguard-mysql", "mysql", middleware.WithParser(mysqlparser.New())) +// db, _ := sql.Open("sqlguard-mysql", dsn) +package mysqlparser + +import ( + "strconv" + "strings" + + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/xwb1989/sqlparser" +) + +// Parser implements analyzer.Parser using a MySQL grammar. +type Parser struct { + fallback analyzer.Parser +} + +// New returns a MySQL-dialect Parser that falls back to the +// zero-dependency FallbackParser on parse failure. +func New() *Parser { + return &Parser{fallback: analyzer.NewFallbackParser()} +} + +var _ analyzer.Parser = (*Parser)(nil) + +// Parse implements analyzer.Parser. It never returns an error: unparseable +// SQL yields the fallback parser's best-effort Statement (Exact=false). +func (p *Parser) Parse(sql string) (*analyzer.Statement, error) { + // Baseline from the fallback. It detects the literal/text-level fields + // (leading-wildcard LIKE, non-sargable predicates, unsafe NOT NULL adds) + // that the AST loses after parsing, so those fields are kept; only + // structural fields are overwritten. + st, _ := p.fallback.Parse(sql) + if st == nil { + st = &analyzer.Statement{Raw: sql} + } + + ast, err := sqlparser.Parse(sql) + if err != nil || ast == nil { + return st, nil // keep best-effort fallback Statement + } + + st.Kind = analyzer.StmtOther + st.HasWhere = false + st.HasLimit = false + st.HasOrderBy = false + st.HasFrom = false + st.SelectStar = false + st.SelectDistinct = false + st.OffsetValue = 0 + st.InsertColumnsListed = false + + switch n := ast.(type) { + case *sqlparser.Select: + st.Kind = analyzer.StmtSelect + st.HasWhere = n.Where != nil + st.HasLimit = n.Limit != nil + st.HasOrderBy = len(n.OrderBy) > 0 + st.HasFrom = hasRealFrom(n.From) + st.SelectDistinct = n.Distinct != "" + st.OffsetValue = offsetValue(n.Limit) + for _, e := range n.SelectExprs { + if _, ok := e.(*sqlparser.StarExpr); ok { // '*' or 'table.*' + st.SelectStar = true + } + } + case *sqlparser.Delete: + st.Kind = analyzer.StmtDelete + st.HasWhere = n.Where != nil + st.HasLimit = n.Limit != nil + st.HasOrderBy = len(n.OrderBy) > 0 + st.OffsetValue = offsetValue(n.Limit) + case *sqlparser.Update: + st.Kind = analyzer.StmtUpdate + st.HasWhere = n.Where != nil + st.HasLimit = n.Limit != nil + st.HasOrderBy = len(n.OrderBy) > 0 + st.OffsetValue = offsetValue(n.Limit) + case *sqlparser.Insert: + st.Kind = analyzer.StmtInsert + st.InsertColumnsListed = len(n.Columns) > 0 + } + + st.Exact = true + return st, nil +} + +// offsetValue extracts a literal OFFSET as an int, or 0 when there is no limit +// clause, no offset, or a non-literal (parameterized) offset — matching the +// large-offset rule's contract that only statically-known offsets are flagged. +// Covers both "LIMIT count OFFSET n" and MySQL's "LIMIT n, count" (the parser +// puts n in Offset for both). +func offsetValue(lim *sqlparser.Limit) int { + if lim == nil || lim.Offset == nil { + return 0 + } + v, ok := lim.Offset.(*sqlparser.SQLVal) + if !ok || v.Type != sqlparser.IntVal { + return 0 + } + n, err := strconv.Atoi(string(v.Val)) + if err != nil || n < 0 { + return 0 + } + return n +} + +// hasRealFrom reports whether a FROM clause references a real table, not the +// implicit "dual" the parser injects for FROM-less selects like SELECT 1. +func hasRealFrom(from sqlparser.TableExprs) bool { + for _, te := range from { + ate, ok := te.(*sqlparser.AliasedTableExpr) + if !ok { + return true // join / subquery / etc. — a real source + } + if tn, ok := ate.Expr.(sqlparser.TableName); ok { + // Case-insensitive: sqlparser preserves the casing of backticked + // identifiers, so `DUAL` would otherwise read as a real table. + if strings.EqualFold(tn.Name.String(), "dual") { + continue + } + } + return true + } + return false +} diff --git a/parsers/mysqlparser/mysqlparser_test.go b/parsers/mysqlparser/mysqlparser_test.go new file mode 100644 index 0000000..abb1271 --- /dev/null +++ b/parsers/mysqlparser/mysqlparser_test.go @@ -0,0 +1,151 @@ +package mysqlparser + +import ( + "testing" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +func TestParser_ExactStructuralFacts(t *testing.T) { + p := New() + tests := []struct { + name string + sql string + want analyzer.Statement + }{ + { + name: "delete without where", + sql: "DELETE FROM users", + want: analyzer.Statement{Kind: analyzer.StmtDelete, Exact: true}, + }, + { + name: "delete with where", + sql: "DELETE FROM users WHERE id = 1", + want: analyzer.Statement{Kind: analyzer.StmtDelete, HasWhere: true, Exact: true}, + }, + { + name: "update without where", + sql: "UPDATE users SET name = 'x'", + want: analyzer.Statement{Kind: analyzer.StmtUpdate, Exact: true}, + }, + { + name: "select star with from", + sql: "SELECT * FROM users", + want: analyzer.Statement{Kind: analyzer.StmtSelect, SelectStar: true, HasFrom: true, Exact: true}, + }, + { + name: "qualified star", + sql: "SELECT u.* FROM users u", + want: analyzer.Statement{Kind: analyzer.StmtSelect, SelectStar: true, HasFrom: true, Exact: true}, + }, + { + name: "count star is not select star", + sql: "SELECT COUNT(*) FROM users WHERE id = 1", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasWhere: true, Exact: true}, + }, + { + name: "select 1 has no real from", + sql: "SELECT 1", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: false, Exact: true}, + }, + { + name: "explicit dual is not a real from", + sql: "SELECT 1 FROM dual", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: false, Exact: true}, + }, + { + name: "uppercase DUAL is not a real from", + sql: "SELECT 1 FROM DUAL", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: false, Exact: true}, + }, + { + name: "backticked DUAL is not a real from", + sql: "SELECT 1 FROM `DUAL`", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: false, Exact: true}, + }, + { + name: "insert with columns", + sql: "INSERT INTO users (name) VALUES ('a')", + want: analyzer.Statement{Kind: analyzer.StmtInsert, InsertColumnsListed: true, Exact: true}, + }, + { + name: "insert without columns", + sql: "INSERT INTO users VALUES ('a')", + want: analyzer.Statement{Kind: analyzer.StmtInsert, Exact: true}, + }, + { + name: "order by without limit", + sql: "SELECT id FROM users ORDER BY name", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasOrderBy: true, Exact: true}, + }, + { + name: "select distinct", + sql: "SELECT DISTINCT name FROM users", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, SelectDistinct: true, Exact: true}, + }, + { + name: "count distinct is not select distinct", + sql: "SELECT COUNT(DISTINCT id) FROM users WHERE id = 1", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasWhere: true, Exact: true}, + }, + { + name: "literal offset (OFFSET form)", + sql: "SELECT id FROM users WHERE x = 1 ORDER BY id LIMIT 10 OFFSET 5000", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasWhere: true, HasOrderBy: true, HasLimit: true, OffsetValue: 5000, Exact: true}, + }, + { + name: "literal offset (LIMIT n, count form)", + sql: "SELECT id FROM users WHERE x = 1 LIMIT 5000, 10", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasWhere: true, HasLimit: true, OffsetValue: 5000, Exact: true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + st, err := p.Parse(tt.sql) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if st.Kind != tt.want.Kind || + st.HasWhere != tt.want.HasWhere || + st.HasLimit != tt.want.HasLimit || + st.HasOrderBy != tt.want.HasOrderBy || + st.HasFrom != tt.want.HasFrom || + st.SelectStar != tt.want.SelectStar || + st.SelectDistinct != tt.want.SelectDistinct || + st.OffsetValue != tt.want.OffsetValue || + st.InsertColumnsListed != tt.want.InsertColumnsListed || + st.Exact != tt.want.Exact { + t.Errorf("Parse(%q)\n got: %+v\nwant: %+v", tt.sql, *st, tt.want) + } + }) + } +} + +func TestParser_FallsBackOnUnparseable(t *testing.T) { + p := New() + // Postgres-style placeholders the MySQL grammar rejects must not error + // and must come back as a best-effort (non-exact) Statement. + st, err := p.Parse("SELECT * FROM t WHERE id = $1") + if err != nil { + t.Fatalf("fallback path must not error: %v", err) + } + if st == nil || st.Exact { + t.Errorf("expected non-nil, non-exact fallback statement, got %+v", st) + } +} + +func TestParser_IntegratesWithAnalyzer(t *testing.T) { + a := analyzer.Default().WithParser(New()) + + got := a.Analyze("UPDATE users SET active = 0 /* WHERE id = 1 */") + found := false + for _, r := range got { + if r.RuleName == "update-without-where" { + found = true + } + } + if !found { + t.Errorf("expected update-without-where (WHERE only in comment), got %+v", got) + } +} diff --git a/parsers/pgparser/go.mod b/parsers/pgparser/go.mod new file mode 100644 index 0000000..88cf8cb --- /dev/null +++ b/parsers/pgparser/go.mod @@ -0,0 +1,35 @@ +module github.com/KARTIKrocks/sqlguard/parsers/pgparser + +go 1.26 + +require github.com/KARTIKrocks/sqlguard v0.0.0 + +require ( + github.com/auxten/postgresql-parser v1.0.1 + github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054 // indirect + github.com/cockroachdb/apd v1.1.1-0.20181017181144-bced77f817b4 // indirect + github.com/cockroachdb/errors v1.8.2 // indirect + github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f // indirect + github.com/cockroachdb/redact v1.0.8 // indirect + github.com/cockroachdb/sentry-go v0.6.1-cockroachdb.2 // indirect + github.com/dustin/go-humanize v1.0.0 // indirect + github.com/getsentry/raven-go v0.2.0 // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/protobuf v1.4.3 // indirect + github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect + github.com/konsorten/go-windows-terminal-sequences v1.0.3 // indirect + github.com/kr/pretty v0.2.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/lib/pq v1.9.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/sirupsen/logrus v1.6.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e // indirect + golang.org/x/text v0.3.4 // indirect + google.golang.org/genproto v0.0.0-20200911024640-645f7a48b24f // indirect + google.golang.org/grpc v1.33.1 // indirect + google.golang.org/protobuf v1.25.0 // indirect +) + +replace github.com/KARTIKrocks/sqlguard => ../.. diff --git a/parsers/pgparser/go.sum b/parsers/pgparser/go.sum new file mode 100644 index 0000000..c8df32f --- /dev/null +++ b/parsers/pgparser/go.sum @@ -0,0 +1,347 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/CloudyKit/fastprinter v0.0.0-20170127035650-74b38d55f37a/go.mod h1:EFZQ978U7x8IRnstaskI3IysnWY5Ao3QgZUKOXlsAdw= +github.com/CloudyKit/jet v2.1.3-0.20180809161101-62edd43e4f88+incompatible/go.mod h1:HPYO+50pSWkPoj9Q/eq0aRGByCL6ScRlUmiEX5Zgm+w= +github.com/Joker/hpp v1.0.0/go.mod h1:8x5n+M1Hp5hC0g8okX3sR3vFQwynaX/UgSOM9MeBKzY= +github.com/Joker/jade v1.0.1-0.20190614124447-d475f43051e7/go.mod h1:6E6s8o2AE4KhCrqr6GRJjdC/gNfTdxkIXvuGZZda2VM= +github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0= +github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= +github.com/auxten/postgresql-parser v1.0.1 h1:x+qiEHAe2cH55Kly64dWh4tGvUKEQwMmJgma7a1kbj4= +github.com/auxten/postgresql-parser v1.0.1/go.mod h1:Nf27dtv8EU1C+xNkoLD3zEwfgJfDDVi8Zl86gznxPvI= +github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054 h1:uH66TXeswKn5PW5zdZ39xEwfS9an067BirqA+P4QaLI= +github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/cockroachdb/apd v1.1.1-0.20181017181144-bced77f817b4 h1:XWEdfNxDkZI3DXXlpo0hZJ1xdaH/f3CKuZpk93pS/Y0= +github.com/cockroachdb/apd v1.1.1-0.20181017181144-bced77f817b4/go.mod h1:mdGz2CnkJrefFtlLevmE7JpL2zB9tKofya/6w7wWzNA= +github.com/cockroachdb/datadriven v1.0.0/go.mod h1:5Ib8Meh+jk1RlHIXej6Pzevx/NLlNvQB9pmSBZErGA4= +github.com/cockroachdb/errors v1.6.1/go.mod h1:tm6FTP5G81vwJ5lC0SizQo374JNCOPrHyXGitRJoDqM= +github.com/cockroachdb/errors v1.8.2 h1:rnnWK9Nn5kEMOGz9531HuDx/FOleL4NVH20VsDexVC8= +github.com/cockroachdb/errors v1.8.2/go.mod h1:qGwQn6JmZ+oMjuLwjWzUNqblqk0xl4CVV3SQbGwK7Ac= +github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f h1:o/kfcElHqOiXqcou5a3rIlMc7oJbMQkeLk0VQJ7zgqY= +github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f/go.mod h1:i/u985jwjWRlyHXQbwatDASoW0RMlZ/3i9yJHE2xLkI= +github.com/cockroachdb/redact v1.0.8 h1:8QG/764wK+vmEYoOlfobpe12EQcS81ukx/a4hdVMxNw= +github.com/cockroachdb/redact v1.0.8/go.mod h1:BVNblN9mBWFyMyqK1k3AAiSxhvhfK2oOZZ2lK+dpvRg= +github.com/cockroachdb/sentry-go v0.6.1-cockroachdb.2 h1:IKgmqgMQlVJIZj19CdocBeSfSaiCbEBZGKODaixqtHM= +github.com/cockroachdb/sentry-go v0.6.1-cockroachdb.2/go.mod h1:8BT+cPK6xvFOcRlk0R8eg+OTkcqI6baNH4xAkpiYVvQ= +github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0/go.mod h1:4Zcjuz89kmFXt9morQgcfYZAYZ5n8WHjt81YYWIwtTM= +github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= +github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= +github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385/go.mod h1:0vRUJqYpeSZifjYj7uP3BG/gKcuzL9xWVV/Y+cK33KM= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/etcd-io/bbolt v1.3.3/go.mod h1:ZF2nL25h33cCyBtcyWeZ2/I3HQOfTP+0PIEvHjkjCrw= +github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072/go.mod h1:duJ4Jxv5lDcvg4QuQr0oowTf7dz4/CR8NtyCooz9HL8= +github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= +github.com/flosch/pongo2 v0.0.0-20190707114632-bbf5a6c351f4/go.mod h1:T9YF2M40nIgbVgp3rreNmTged+9HrbNTIQf1PsaIiTA= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc= +github.com/getsentry/raven-go v0.2.0 h1:no+xWJRb5ZI7eE8TWgIq1jLulQiIoLG0IfYxv5JYMGs= +github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s= +github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= +github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w= +github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= +github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= +github.com/gogo/googleapis v0.0.0-20180223154316-0cd9801be74a/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= +github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/gogo/status v1.1.0/go.mod h1:BFv9nrluPLmrS0EmGVvLaPNmRosr9KapBYd5/hpY1WM= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/gomodule/redigo v1.7.1-0.20190724094224-574c33c3df38/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k= +github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/hydrogen18/memlistener v0.0.0-20141126152155-54553eb933fb/go.mod h1:qEIFzExnS6016fRpRfxrExeVn2gbClQA99gQhnIcdhE= +github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/iris-contrib/blackfriday v2.0.0+incompatible/go.mod h1:UzZ2bDEoaSGPbkg6SAB4att1aAwTmVIx/5gCVqeyUdI= +github.com/iris-contrib/go.uuid v2.0.0+incompatible/go.mod h1:iz2lgM/1UnEf1kP0L/+fafWORmlnuysV2EMP8MW+qe0= +github.com/iris-contrib/i18n v0.0.0-20171121225848-987a633949d0/go.mod h1:pMCz62A0xJL6I+umB2YTlFRwWXaDFA0jy+5HzGiJjqI= +github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/juju/errors v0.0.0-20181118221551-089d3ea4e4d5/go.mod h1:W54LbzXuIE0boCoNJfwqpmkKJ1O4TCTZMetAt6jGk7Q= +github.com/juju/loggo v0.0.0-20180524022052-584905176618/go.mod h1:vgyd7OREkbtVEN/8IXZe5Ooef3LQePvuBm9UWj6ZL8U= +github.com/juju/testing v0.0.0-20180920084828-472a3e8b2073/go.mod h1:63prj8cnj0tU0S9OHjGJn+b1h0ZghCndfnbQolrYTwA= +github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k= +github.com/kataras/golog v0.0.9/go.mod h1:12HJgwBIZFNGL0EJnMRhmvGA0PQGx8VFwrZtM4CqbAk= +github.com/kataras/iris/v12 v12.0.1/go.mod h1:udK4vLQKkdDqMGJJVd/msuMtN6hpYJhg/lSzuxjhO+U= +github.com/kataras/neffos v0.0.10/go.mod h1:ZYmJC07hQPW67eKuzlfY7SO3bC0mw83A3j6im82hfqw= +github.com/kataras/pio v0.0.0-20190103105442-ea782b38602d/go.mod h1:NV88laa9UiiDuX9AhMbDPkGYSPugBOV6yTZB1l2K9Z0= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.8.2/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/klauspost/compress v1.9.0/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= +github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8= +github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= +github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/labstack/echo/v4 v4.1.11/go.mod h1:i541M3Fj6f76NZtHSj7TXnyM8n2gaodfvfxNnFqi74g= +github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= +github.com/lib/pq v1.9.0 h1:L8nSXQQzAYByakOFMTwpjRoHsMJklur4Gi59b6VivR8= +github.com/lib/pq v1.9.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= +github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= +github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= +github.com/mediocregopher/mediocre-go-lib v0.0.0-20181029021733-cb65787f37ed/go.mod h1:dSsfyI2zABAdhcbvkXqgxOxrCsbYeHCPgrZkku60dSg= +github.com/mediocregopher/radix/v3 v3.3.0/go.mod h1:EmfVyvspXz1uZEyPBMyGK+kjWiKQGvsUt6O3Pj+LDCQ= +github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ= +github.com/nats-io/nats.go v1.8.1/go.mod h1:BrFz9vVn0fU3AcH9Vn4Kd7W0NpJ651tD5omQ3M8LwxM= +github.com/nats-io/nkeys v0.0.2/go.mod h1:dab7URMsZm6Z/jp9Z5UGa87Uutgc2mVpXLC4B7TDb/4= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 h1:q2e307iGHPdTGp0hoxKjt1H5pDo6utceo3dQVK3I5XQ= +github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5/go.mod h1:jvVRKCrJTQWu0XVbaOlby/2lO20uSCHEMzzplHXte1o= +github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= +github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= +github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= +github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= +github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= +github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= +github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= +github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= +github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= +github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBnvPM1Su9w= +github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= +github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= +github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI= +github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg= +github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM= +github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZkTdatxwunjIkc= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190327091125-710a502c58a2/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e h1:AyodaIpKjppX+cBfTASF2E1US3H2JFBj920Ot3rtDjs= +golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4 h1:0YWbFKbhXG/wIiuHDSKpS0Iy7FSA+u45VtBMfQcFTTc= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190327201419-c70d86f8b7cf/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180518175338-11a468237815/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20200911024640-645f7a48b24f h1:Yv4xsIx7HZOoyUGSJ2ksDyWE2qIBXROsZKt2ny3hCGM= +google.golang.org/genproto v0.0.0-20200911024640-645f7a48b24f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/grpc v1.12.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= +google.golang.org/grpc v1.33.1 h1:DGeFlSan2f+WEtCERJ4J9GJWk15TxUi8QGagfI87Xyc= +google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= +google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= +gopkg.in/go-playground/validator.v8 v8.18.2/go.mod h1:RX2a/7Ha8BgOhfk7j780h4/u/RRjR0eouCJSH80/M2Y= +gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/parsers/pgparser/pgparser.go b/parsers/pgparser/pgparser.go new file mode 100644 index 0000000..e0b7123 --- /dev/null +++ b/parsers/pgparser/pgparser.go @@ -0,0 +1,149 @@ +// Package pgparser is an optional sqlguard Parser backed by a real +// PostgreSQL grammar (github.com/auxten/postgresql-parser, pure Go, no cgo). +// +// It produces exact, structural answers for the false-positive-prone facts +// (statement kind, WHERE/LIMIT/ORDER BY/FROM presence, SELECT *, explicit +// INSERT columns) instead of regex guesses. SQL the grammar rejects — +// dynamic fragments, dialect extensions, driver placeholders it can't +// handle — transparently degrades to sqlguard's zero-dependency +// FallbackParser, so analysis never breaks the caller's query path. +// +// Usage: +// +// sqlguard.Register("sqlguard-pg", "pgx", middleware.WithParser(pgparser.New())) +// db, _ := sql.Open("sqlguard-pg", dsn) +package pgparser + +import ( + "github.com/KARTIKrocks/sqlguard/analyzer" + "github.com/auxten/postgresql-parser/pkg/sql/parser" + "github.com/auxten/postgresql-parser/pkg/sql/sem/tree" +) + +// Parser implements analyzer.Parser using a PostgreSQL grammar. +type Parser struct { + fallback analyzer.Parser +} + +// New returns a Postgres-dialect Parser that falls back to the +// zero-dependency FallbackParser on parse failure. +func New() *Parser { + return &Parser{fallback: analyzer.NewFallbackParser()} +} + +var _ analyzer.Parser = (*Parser)(nil) + +// Parse implements analyzer.Parser. It never returns an error: unparseable +// SQL yields the fallback parser's best-effort Statement (Exact=false). +func (p *Parser) Parse(sql string) (*analyzer.Statement, error) { + // The fallback result is the baseline. It already detects the literal/text- + // level fields (leading-wildcard LIKE, non-sargable predicates, unsafe + // NOT NULL adds) that the AST loses after parsing, so we keep those fields + // and overwrite only the structural ones. + st, _ := p.fallback.Parse(sql) + if st == nil { + st = &analyzer.Statement{Raw: sql} + } + + stmts, err := parser.Parse(sql) + if err != nil || len(stmts) == 0 || stmts[0].AST == nil { + return st, nil // keep best-effort fallback Statement + } + + st.Kind = analyzer.StmtOther + st.HasWhere = false + st.HasLimit = false + st.HasOrderBy = false + st.HasFrom = false + st.SelectStar = false + st.SelectDistinct = false + st.OffsetValue = 0 + st.InsertColumnsListed = false + + switch n := stmts[0].AST.(type) { + case *tree.Select: + st.Kind = analyzer.StmtSelect + st.HasOrderBy = len(n.OrderBy) > 0 + st.HasLimit = n.Limit != nil + st.OffsetValue = offsetValue(n.Limit) + fillSelectBody(st, n.Select) + case *tree.SelectClause: + st.Kind = analyzer.StmtSelect + fillSelectClause(st, n) + case *tree.Delete: + st.Kind = analyzer.StmtDelete + st.HasWhere = n.Where != nil + st.HasLimit = n.Limit != nil + st.HasOrderBy = len(n.OrderBy) > 0 + st.OffsetValue = offsetValue(n.Limit) + case *tree.Update: + st.Kind = analyzer.StmtUpdate + st.HasWhere = n.Where != nil + st.HasLimit = n.Limit != nil + st.HasOrderBy = len(n.OrderBy) > 0 + st.OffsetValue = offsetValue(n.Limit) + case *tree.Insert: + st.Kind = analyzer.StmtInsert + st.InsertColumnsListed = len(n.Columns) > 0 + } + + st.Exact = true + return st, nil +} + +// fillSelectBody unwraps the inner SelectStatement of a *tree.Select. +func fillSelectBody(st *analyzer.Statement, sel tree.SelectStatement) { + switch c := sel.(type) { + case *tree.SelectClause: + fillSelectClause(st, c) + case *tree.ParenSelect: + if c.Select != nil { + st.HasOrderBy = st.HasOrderBy || len(c.Select.OrderBy) > 0 + st.HasLimit = st.HasLimit || c.Select.Limit != nil + if v := offsetValue(c.Select.Limit); v > st.OffsetValue { + st.OffsetValue = v + } + fillSelectBody(st, c.Select.Select) + } + } + // UnionClause / ValuesClause: leave structural defaults; the rules that + // matter for those forms don't trigger on set operations. +} + +// offsetValue extracts a literal OFFSET as an int, or 0 when there is no limit +// clause, no offset, or a non-literal (parameterized) offset — matching the +// large-offset rule's contract that only statically-known offsets are flagged. +func offsetValue(lim *tree.Limit) int { + if lim == nil { + return 0 + } + nv, ok := lim.Offset.(*tree.NumVal) + if !ok { + return 0 + } + n, err := nv.AsInt64() + if err != nil || n < 0 { + return 0 + } + return int(n) +} + +func fillSelectClause(st *analyzer.Statement, c *tree.SelectClause) { + st.HasWhere = c.Where != nil + st.HasFrom = len(c.From.Tables) > 0 + // DISTINCT and DISTINCT ON both set the select-level distinct flag; an + // aggregate-level DISTINCT (count(DISTINCT x)) lives in the expr, not here. + st.SelectDistinct = c.Distinct || len(c.DistinctOn) > 0 + for _, e := range c.Exprs { + switch ex := e.Expr.(type) { + case tree.UnqualifiedStar, *tree.UnqualifiedStar: + st.SelectStar = true // SELECT * + case *tree.AllColumnsSelector: + st.SelectStar = true // SELECT t.* (resolved form) + case *tree.UnresolvedName: + if ex.Star { // SELECT t.* (unresolved form) + st.SelectStar = true + } + } + } +} diff --git a/parsers/pgparser/pgparser_test.go b/parsers/pgparser/pgparser_test.go new file mode 100644 index 0000000..b2e51b2 --- /dev/null +++ b/parsers/pgparser/pgparser_test.go @@ -0,0 +1,137 @@ +package pgparser + +import ( + "testing" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +func TestParser_ExactStructuralFacts(t *testing.T) { + p := New() + tests := []struct { + name string + sql string + want analyzer.Statement + }{ + { + name: "cte-wrapped delete with where", + sql: "WITH r AS (SELECT id FROM o WHERE ts > now()) DELETE FROM o WHERE id IN (SELECT id FROM r)", + want: analyzer.Statement{Kind: analyzer.StmtDelete, HasWhere: true, Exact: true}, + }, + { + name: "delete without where", + sql: "DELETE FROM users", + want: analyzer.Statement{Kind: analyzer.StmtDelete, HasWhere: false, Exact: true}, + }, + { + name: "select star with from", + sql: "SELECT * FROM users", + want: analyzer.Statement{Kind: analyzer.StmtSelect, SelectStar: true, HasFrom: true, Exact: true}, + }, + { + name: "qualified star", + sql: "SELECT u.* FROM users u", + want: analyzer.Statement{Kind: analyzer.StmtSelect, SelectStar: true, HasFrom: true, Exact: true}, + }, + { + name: "count star is not select star", + sql: "SELECT count(*) FROM users", + want: analyzer.Statement{Kind: analyzer.StmtSelect, SelectStar: false, HasFrom: true, Exact: true}, + }, + { + name: "select no from", + sql: "SELECT 1", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: false, Exact: true}, + }, + { + name: "insert with columns", + sql: "INSERT INTO users (name) VALUES ('a')", + want: analyzer.Statement{Kind: analyzer.StmtInsert, InsertColumnsListed: true, Exact: true}, + }, + { + name: "insert without columns", + sql: "INSERT INTO users VALUES ('a')", + want: analyzer.Statement{Kind: analyzer.StmtInsert, InsertColumnsListed: false, Exact: true}, + }, + { + name: "order by without limit", + sql: "SELECT id FROM users ORDER BY name", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasOrderBy: true, Exact: true}, + }, + { + name: "select distinct", + sql: "SELECT DISTINCT name FROM users", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, SelectDistinct: true, Exact: true}, + }, + { + name: "distinct on", + sql: "SELECT DISTINCT ON (dept) dept, name FROM emp", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, SelectDistinct: true, Exact: true}, + }, + { + name: "count distinct is not select distinct", + sql: "SELECT count(DISTINCT id) FROM users", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, Exact: true}, + }, + { + name: "literal offset", + sql: "SELECT id FROM users WHERE x = 1 ORDER BY id LIMIT 10 OFFSET 5000", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasWhere: true, HasOrderBy: true, HasLimit: true, OffsetValue: 5000, Exact: true}, + }, + { + name: "parameterized offset is zero", + sql: "SELECT id FROM users WHERE x = 1 LIMIT 10 OFFSET $1", + want: analyzer.Statement{Kind: analyzer.StmtSelect, HasFrom: true, HasWhere: true, HasLimit: true, Exact: true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + st, err := p.Parse(tt.sql) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if st.Kind != tt.want.Kind || + st.HasWhere != tt.want.HasWhere || + st.HasLimit != tt.want.HasLimit || + st.HasOrderBy != tt.want.HasOrderBy || + st.HasFrom != tt.want.HasFrom || + st.SelectStar != tt.want.SelectStar || + st.SelectDistinct != tt.want.SelectDistinct || + st.OffsetValue != tt.want.OffsetValue || + st.InsertColumnsListed != tt.want.InsertColumnsListed || + st.Exact != tt.want.Exact { + t.Errorf("Parse(%q)\n got: %+v\nwant: %+v", tt.sql, *st, tt.want) + } + }) + } +} + +func TestParser_FallsBackOnUnparseable(t *testing.T) { + p := New() + // Driver placeholders the PG grammar won't accept as-is still must not + // error, and must come back as a best-effort (non-exact) Statement. + st, err := p.Parse("SELECT * FROM t WHERE id = ?") + if err != nil { + t.Fatalf("fallback path must not error: %v", err) + } + if st == nil { + t.Fatal("nil statement") + } + if st.Exact { + t.Error("expected Exact=false when grammar rejected the SQL") + } +} + +func TestParser_IntegratesWithAnalyzer(t *testing.T) { + a := analyzer.Default().WithParser(New()) + + got := a.Analyze("DELETE FROM users -- WHERE id = 1") + if len(got) != 1 || got[0].RuleName != "delete-without-where" { + t.Errorf("expected delete-without-where (WHERE only in comment), got %+v", got) + } + + if r := a.Analyze("SELECT id FROM users WHERE id = 1 LIMIT 1"); len(r) != 0 { + t.Errorf("expected no findings for safe query, got %+v", r) + } +} diff --git a/reporter/console.go b/reporter/console.go new file mode 100644 index 0000000..286aa13 --- /dev/null +++ b/reporter/console.go @@ -0,0 +1,64 @@ +package reporter + +import ( + "fmt" + "io" + "os" + "sync" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +const ( + colorReset = "\033[0m" + colorRed = "\033[31m" + colorYellow = "\033[33m" + colorCyan = "\033[36m" +) + +// ConsoleReporter prints analysis results to the terminal with color. +// The output writer is fixed at construction so Report is safe for concurrent +// use (the writer cannot be swapped out from under an in-flight Report). +type ConsoleReporter struct { + out io.Writer + mu sync.Mutex +} + +// NewConsoleReporter creates a ConsoleReporter that writes to stderr. +func NewConsoleReporter() *ConsoleReporter { + return NewConsoleReporterTo(os.Stderr) +} + +// NewConsoleReporterTo creates a ConsoleReporter that writes to w. +func NewConsoleReporterTo(w io.Writer) *ConsoleReporter { + return &ConsoleReporter{out: w} +} + +// Report writes each result to the configured output, colored by severity. +func (c *ConsoleReporter) Report(results []analyzer.Result) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, r := range results { + color := colorCyan + switch r.Severity { + case analyzer.SeverityWarning: + color = colorYellow + case analyzer.SeverityCritical: + color = colorRed + } + + _, _ = fmt.Fprintf(c.out, "\n%s[SQLGUARD %s]%s %s\n", color, r.Severity, colorReset, r.RuleName) + + if r.File != "" { + _, _ = fmt.Fprintf(c.out, " File: %s:%d\n", r.File, r.Line) + } + + _, _ = fmt.Fprintf(c.out, " Query: %s\n", r.Query) + _, _ = fmt.Fprintf(c.out, " Issue: %s\n", r.Message) + + if r.Suggestion != "" { + _, _ = fmt.Fprintf(c.out, " Fix: %s\n", r.Suggestion) + } + } +} diff --git a/reporter/console_test.go b/reporter/console_test.go new file mode 100644 index 0000000..a1d0f56 --- /dev/null +++ b/reporter/console_test.go @@ -0,0 +1,86 @@ +package reporter + +import ( + "bytes" + "strings" + "testing" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +func TestConsoleReporter_Report(t *testing.T) { + var buf bytes.Buffer + rep := NewConsoleReporterTo(&buf) + + results := []analyzer.Result{ + { + RuleName: "select-star", + Severity: analyzer.SeverityWarning, + Query: "SELECT * FROM users", + Message: "SELECT * detected.", + Suggestion: "Select only needed columns.", + }, + } + + rep.Report(results) + output := buf.String() + + if !strings.Contains(output, "SQLGUARD WARNING") { + t.Error("expected WARNING label in output") + } + if !strings.Contains(output, "select-star") { + t.Error("expected rule name in output") + } + if !strings.Contains(output, "SELECT * FROM users") { + t.Error("expected query in output") + } + if !strings.Contains(output, "Select only needed columns.") { + t.Error("expected suggestion in output") + } +} + +func TestConsoleReporter_CriticalSeverity(t *testing.T) { + var buf bytes.Buffer + rep := NewConsoleReporterTo(&buf) + + rep.Report([]analyzer.Result{{ + RuleName: "delete-without-where", + Severity: analyzer.SeverityCritical, + Query: "DELETE FROM users", + Message: "DELETE without WHERE.", + }}) + + if !strings.Contains(buf.String(), "SQLGUARD CRITICAL") { + t.Error("expected CRITICAL label in output") + } +} + +func TestConsoleReporter_WithFileInfo(t *testing.T) { + var buf bytes.Buffer + rep := NewConsoleReporterTo(&buf) + + rep.Report([]analyzer.Result{{ + RuleName: "select-star", + Severity: analyzer.SeverityWarning, + Query: "SELECT * FROM users", + Message: "SELECT * detected.", + File: "repo/user.go", + Line: 42, + }}) + + output := buf.String() + if !strings.Contains(output, "repo/user.go:42") { + t.Error("expected file:line in output") + } +} + +func TestConsoleReporter_EmptyResults(t *testing.T) { + var buf bytes.Buffer + rep := NewConsoleReporterTo(&buf) + + rep.Report(nil) + + if buf.Len() != 0 { + t.Error("expected no output for empty results") + } +} diff --git a/reporter/json.go b/reporter/json.go new file mode 100644 index 0000000..2bdd04f --- /dev/null +++ b/reporter/json.go @@ -0,0 +1,67 @@ +package reporter + +import ( + "encoding/json" + "fmt" + "io" + "os" + "sync" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +// JSONReporter outputs analysis results as JSON. +// The output writer is fixed at construction so Report is safe for concurrent +// use (the writer cannot be swapped out from under an in-flight Report). +type JSONReporter struct { + out io.Writer + mu sync.Mutex +} + +// NewJSONReporter creates a JSONReporter that writes to stderr. +func NewJSONReporter() *JSONReporter { + return NewJSONReporterTo(os.Stderr) +} + +// NewJSONReporterTo creates a JSONReporter that writes to w. +func NewJSONReporterTo(w io.Writer) *JSONReporter { + return &JSONReporter{out: w} +} + +type jsonResult struct { + Rule string `json:"rule"` + Severity string `json:"severity"` + Query string `json:"query"` + Fingerprint string `json:"fingerprint,omitempty"` + Message string `json:"message"` + Suggestion string `json:"suggestion,omitempty"` + File string `json:"file,omitempty"` + Line int `json:"line,omitempty"` +} + +// Report writes the results to the configured output as a JSON array. +func (j *JSONReporter) Report(results []analyzer.Result) { + j.mu.Lock() + defer j.mu.Unlock() + + out := make([]jsonResult, len(results)) + for i, r := range results { + out[i] = jsonResult{ + Rule: r.RuleName, + Severity: r.Severity.String(), + Query: r.Query, + Fingerprint: r.Fingerprint, + Message: r.Message, + Suggestion: r.Suggestion, + File: r.File, + Line: r.Line, + } + } + + enc := json.NewEncoder(j.out) + enc.SetIndent("", " ") + if err := enc.Encode(out); err != nil { + // Fallback: log encoding failure since Reporter interface can't return error + fmt.Fprintf(os.Stderr, "sqlguard: failed to encode JSON report: %v\n", err) + } +} diff --git a/reporter/json_test.go b/reporter/json_test.go new file mode 100644 index 0000000..b7c30a0 --- /dev/null +++ b/reporter/json_test.go @@ -0,0 +1,77 @@ +package reporter + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/KARTIKrocks/sqlguard/analyzer" +) + +func TestJSONReporter_Report(t *testing.T) { + var buf bytes.Buffer + rep := NewJSONReporterTo(&buf) + + results := []analyzer.Result{ + { + RuleName: "select-star", + Severity: analyzer.SeverityWarning, + Query: "SELECT * FROM users", + Message: "SELECT * detected.", + Suggestion: "Select only needed columns.", + File: "user.go", + Line: 10, + }, + { + RuleName: "delete-without-where", + Severity: analyzer.SeverityCritical, + Query: "DELETE FROM users", + Message: "DELETE without WHERE.", + }, + } + + rep.Report(results) + + var parsed []map[string]any + if err := json.Unmarshal(buf.Bytes(), &parsed); err != nil { + t.Fatalf("invalid JSON output: %v\nGot: %s", err, buf.String()) + } + + if len(parsed) != 2 { + t.Fatalf("expected 2 results, got %d", len(parsed)) + } + + if parsed[0]["rule"] != "select-star" { + t.Errorf("expected rule 'select-star', got %v", parsed[0]["rule"]) + } + if parsed[0]["severity"] != "WARNING" { + t.Errorf("expected severity 'WARNING', got %v", parsed[0]["severity"]) + } + if parsed[0]["file"] != "user.go" { + t.Errorf("expected file 'user.go', got %v", parsed[0]["file"]) + } + + if parsed[1]["severity"] != "CRITICAL" { + t.Errorf("expected severity 'CRITICAL', got %v", parsed[1]["severity"]) + } + // file should be omitted (empty) + if _, ok := parsed[1]["file"]; ok && parsed[1]["file"] != "" { + t.Errorf("expected file to be omitted, got %v", parsed[1]["file"]) + } +} + +func TestJSONReporter_EmptyResults(t *testing.T) { + var buf bytes.Buffer + rep := NewJSONReporterTo(&buf) + + rep.Report([]analyzer.Result{}) + + var parsed []map[string]any + if err := json.Unmarshal(buf.Bytes(), &parsed); err != nil { + t.Fatalf("invalid JSON output: %v", err) + } + + if len(parsed) != 0 { + t.Errorf("expected empty array, got %d items", len(parsed)) + } +} diff --git a/reporter/reporter.go b/reporter/reporter.go new file mode 100644 index 0000000..fec8431 --- /dev/null +++ b/reporter/reporter.go @@ -0,0 +1,8 @@ +package reporter + +import "github.com/KARTIKrocks/sqlguard/analyzer" + +// Reporter defines the interface for reporting analysis results. +type Reporter interface { + Report(results []analyzer.Result) +} diff --git a/sqlguard.go b/sqlguard.go new file mode 100644 index 0000000..912ac95 --- /dev/null +++ b/sqlguard.go @@ -0,0 +1,40 @@ +// Package sqlguard is a production-safe SQL query analyzer for Go applications. +// +// It detects slow queries, dangerous SQL patterns, and performance issues +// both at runtime (via a database/sql driver wrapper) and statically +// (via the CLI). +// +// The runtime guard wraps at the driver.Driver layer, so it returns a real +// *sql.DB and analyzes every query — including those issued by ORMs and +// query builders — without a method list to keep in sync. +// +// Register a wrapped driver by name: +// +// sqlguard.Register("sqlguard-pg", "pgx") +// db, _ := sql.Open("sqlguard-pg", dsn) +// db.Query("SELECT * FROM users") // logs warning about SELECT * +// +// Or wrap an existing driver.Connector directly: +// +// db := sqlguard.OpenDB(connector) +package sqlguard + +import ( + "database/sql" + "database/sql/driver" + + "github.com/KARTIKrocks/sqlguard/middleware" +) + +// Register wraps the database/sql driver registered under baseDriver and +// registers the analyzed result under name. Afterwards sql.Open(name, dsn) +// yields a *sql.DB whose every query is analyzed. +func Register(name, baseDriver string, opts ...middleware.Option) error { + return middleware.Register(name, baseDriver, opts...) +} + +// OpenDB wraps a driver.Connector and returns an analyzed *sql.DB. Use this +// when you already hold a connector (e.g. pgx's stdlib.GetConnector). +func OpenDB(c driver.Connector, opts ...middleware.Option) *sql.DB { + return middleware.OpenDB(c, opts...) +}