diff --git a/.github/actions/install-cvc5/action.yml b/.github/actions/install-cvc5/action.yml new file mode 100644 index 0000000000..3ba8055209 --- /dev/null +++ b/.github/actions/install-cvc5/action.yml @@ -0,0 +1,52 @@ +# Copyright Strata Contributors +# SPDX-License-Identifier: Apache-2.0 OR MIT +name: Install cvc5 +description: > + Download a static cvc5 build and put it on the PATH. Supports both + x86_64 and aarch64 Linux runners. Consolidates the cvc5 install logic + previously duplicated across ci.yml and cbmc.yml; intended to also be + adopted by the python-fuzz workflow once that lands (see + https://github.com/strata-org/Strata/pull/984). + +inputs: + version: + description: cvc5 release tag (e.g. "1.2.1"). + required: false + default: "1.2.1" + install-to: + description: > + Where to make the cvc5 binary available. One of: + "path" (default) — prepend the unpacked bin/ directory to $GITHUB_PATH. + "system" — sudo cp the cvc5 binary into /usr/local/bin/. + required: false + default: "path" + +runs: + using: composite + steps: + - name: Download cvc5 + shell: bash + run: | + set -eu + ARCH=$(uname -m) + case "$ARCH" in + x86_64) ARCH_NAME="x86_64" ;; + aarch64|arm64) ARCH_NAME="arm64" ;; + *) echo "Unsupported architecture: $ARCH" >&2; exit 1 ;; + esac + URL="https://github.com/cvc5/cvc5/releases/download/cvc5-${{ inputs.version }}/cvc5-Linux-${ARCH_NAME}-static.zip" + wget -q "$URL" + unzip -q "cvc5-Linux-${ARCH_NAME}-static.zip" + chmod +x "cvc5-Linux-${ARCH_NAME}-static/bin/cvc5" + case "${{ inputs.install-to }}" in + path) + echo "$GITHUB_WORKSPACE/cvc5-Linux-${ARCH_NAME}-static/bin/" >> "$GITHUB_PATH" + ;; + system) + sudo cp "cvc5-Linux-${ARCH_NAME}-static/bin/cvc5" /usr/local/bin/ + ;; + *) + echo "Unknown install-to value: ${{ inputs.install-to }}" >&2 + exit 2 + ;; + esac diff --git a/.github/actions/install-z3/action.yml b/.github/actions/install-z3/action.yml new file mode 100644 index 0000000000..86d6c52839 --- /dev/null +++ b/.github/actions/install-z3/action.yml @@ -0,0 +1,55 @@ +# Copyright Strata Contributors +# SPDX-License-Identifier: Apache-2.0 OR MIT +name: Install z3 +description: > + Download a z3 release and put it on the PATH. Supports x86_64 and + aarch64 Linux runners. Consolidates the z3 install logic previously + duplicated across ci.yml and cbmc.yml. + +inputs: + version: + description: z3 release tag (e.g. "4.15.2"). + required: false + default: "4.15.2" + install-to: + description: > + Where to make the z3 binary available. One of: + "path" (default) — prepend the unpacked bin/ directory to $GITHUB_PATH. + "system" — sudo cp the z3 binary into /usr/local/bin/. + required: false + default: "path" + +runs: + using: composite + steps: + - name: Download z3 + shell: bash + run: | + set -eu + ARCH=$(uname -m) + case "$ARCH" in + x86_64) + URL="https://github.com/Z3Prover/z3/releases/download/z3-${{ inputs.version }}/z3-${{ inputs.version }}-x64-glibc-2.39.zip" + ARCHIVE_NAME="z3-${{ inputs.version }}-x64-glibc-2.39" + ;; + aarch64|arm64) + URL="https://github.com/Z3Prover/z3/releases/download/z3-${{ inputs.version }}/z3-${{ inputs.version }}-arm64-glibc-2.34.zip" + ARCHIVE_NAME="z3-${{ inputs.version }}-arm64-glibc-2.34" + ;; + *) echo "Unsupported architecture: $ARCH" >&2; exit 1 ;; + esac + wget -q "$URL" + unzip -q "${ARCHIVE_NAME}.zip" + chmod +x "${ARCHIVE_NAME}/bin/z3" + case "${{ inputs.install-to }}" in + path) + echo "$GITHUB_WORKSPACE/${ARCHIVE_NAME}/bin/" >> "$GITHUB_PATH" + ;; + system) + sudo cp "${ARCHIVE_NAME}/bin/z3" /usr/local/bin/ + ;; + *) + echo "Unknown install-to value: ${{ inputs.install-to }}" >&2 + exit 2 + ;; + esac diff --git a/.github/actions/restore-lake-cache/action.yml b/.github/actions/restore-lake-cache/action.yml new file mode 100644 index 0000000000..9151866cfd --- /dev/null +++ b/.github/actions/restore-lake-cache/action.yml @@ -0,0 +1,82 @@ +# Copyright Strata Contributors +# SPDX-License-Identifier: Apache-2.0 OR MIT +name: Restore lake cache +description: > + Thin wrapper around actions/cache/restore@v5 that uses the standard + Strata cache-key pattern: + lake------ + with three fallback keys dropping each trailing component in turn. + Consolidates the ~15-line cache block previously duplicated across + ci.yml's build_and_test_lean, check_pending_python, build_python and + cbmc.yml; intended to also be adopted by the python-fuzz workflow once + that lands (see https://github.com/strata-org/Strata/pull/984). + +inputs: + fail-on-cache-miss: + description: > + If 'true', the step fails when no cache entry matches. Use this in + jobs that depend on a cache saved by an upstream job for the same + SHA (see https://github.com/strata-org/Strata/issues/952). + required: false + default: "false" + path: + description: Cache path(s), newline-separated. + required: false + default: ".lake" + key-prefix: + description: > + Prefix used in the cache key. The action also hashes the + repo-root `lean-toolchain` and `lake-manifest.json`, so changing + only this prefix is appropriate for caches keyed on the same + root-level Lean build (e.g. distinguishing different artifact + names with the same source set). Sub-projects with their own + toolchain/manifest do not currently fit this action and should + not reuse it as-is. + required: false + default: "lake" + use-restore-keys: + description: > + Must be the string `'true'` or `'false'`. + + If `'true'` (default), include three fallback `restore-keys` so + that a near match (same toolchain/manifest/.st files but different + SHA) is used when no exact-SHA cache exists. + + Set to `'false'` for downstream jobs that depend on a cache saved + by an upstream job for the *same* SHA (typically together with + `fail-on-cache-miss: 'true'`); see + https://github.com/strata-org/Strata/issues/952. With fallback + keys present, `fail-on-cache-miss` only triggers when every + fallback also misses, which silently allows stale cross-SHA cache + matches and defeats the safety net. + required: false + default: "true" + +outputs: + cache-hit: + description: Whether a cache entry was restored (see actions/cache/restore@v5). + value: ${{ steps.restore-with-fallback.outputs.cache-hit || steps.restore-exact.outputs.cache-hit }} + +runs: + using: composite + steps: + - name: Restore lake cache (with fallback keys) + id: restore-with-fallback + if: inputs.use-restore-keys != 'false' + uses: actions/cache/restore@v5 + with: + path: ${{ inputs.path }} + key: ${{ inputs.key-prefix }}-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('lean-toolchain') }}-${{ hashFiles('lake-manifest.json') }}-${{ hashFiles('**/*.st') }}-${{ github.sha }} + restore-keys: | + ${{ inputs.key-prefix }}-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('lean-toolchain') }}-${{ hashFiles('lake-manifest.json') }}-${{ hashFiles('**/*.st') }} + ${{ inputs.key-prefix }}-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('lean-toolchain') }}-${{ hashFiles('lake-manifest.json') }} + ${{ inputs.key-prefix }}-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('lean-toolchain') }} + fail-on-cache-miss: ${{ inputs.fail-on-cache-miss }} + - name: Restore lake cache (exact SHA only) + id: restore-exact + if: inputs.use-restore-keys == 'false' + uses: actions/cache/restore@v5 + with: + path: ${{ inputs.path }} + key: ${{ inputs.key-prefix }}-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('lean-toolchain') }}-${{ hashFiles('lake-manifest.json') }}-${{ hashFiles('**/*.st') }}-${{ github.sha }} + fail-on-cache-miss: ${{ inputs.fail-on-cache-miss }} diff --git a/.github/actions/save-lake-cache/action.yml b/.github/actions/save-lake-cache/action.yml new file mode 100644 index 0000000000..5754371b87 --- /dev/null +++ b/.github/actions/save-lake-cache/action.yml @@ -0,0 +1,35 @@ +# Copyright Strata Contributors +# SPDX-License-Identifier: Apache-2.0 OR MIT +name: Save lake cache +description: > + Save the lake build cache using the canonical Strata cache-key pattern. + Companion to `restore-lake-cache`: the two actions share the same key + construction so downstream jobs that consume the saved cache via + `restore-lake-cache` with `use-restore-keys: "false"` will hit it + reliably. + + Use this in workflows that produce a fresh build (typically the + `build_and_test_lean` job in ci.yml) to share the result with + downstream jobs at the same SHA. + +inputs: + path: + description: Cache path(s), newline-separated. + required: false + default: ".lake" + key-prefix: + description: > + Cache-key prefix; must match the `key-prefix` passed to the + companion `restore-lake-cache` action so that the exact-SHA + restore keys line up. + required: false + default: "lake" + +runs: + using: composite + steps: + - name: Save lake cache + uses: actions/cache/save@v5 + with: + path: ${{ inputs.path }} + key: ${{ inputs.key-prefix }}-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('lean-toolchain') }}-${{ hashFiles('lake-manifest.json') }}-${{ hashFiles('**/*.st') }}-${{ github.sha }} diff --git a/.github/scripts/testStrataCommand.sh b/.github/scripts/testStrataCommand.sh index ced57ce5f6..f584d591b3 100755 --- a/.github/scripts/testStrataCommand.sh +++ b/.github/scripts/testStrataCommand.sh @@ -81,6 +81,15 @@ $strata print --include Examples/dialects Examples/dialects/Arith.dialect.st > " # Print Ion file and compare with previous run $strata print --include Examples/dialects "$temp_dir/Arith.dialect.st.ion" | cmp - "$temp_dir/Arith.dialect.st" +# --- pyResolveOverloads error handling --- +set +e + +expect_error "pyResolveOverloads missing dispatch file" \ + "nonexistent_dispatch.ion" \ + $strata pyResolveOverloads Examples/SimpleProc.core.st nonexistent_dispatch.ion + +set -e + if [ $failed -ne 0 ]; then echo "Some tests failed." exit 1 diff --git a/.github/workflows/cbmc.yml b/.github/workflows/cbmc.yml index b188473bfe..f4c557045e 100644 --- a/.github/workflows/cbmc.yml +++ b/.github/workflows/cbmc.yml @@ -13,38 +13,9 @@ jobs: - name: Checkout uses: actions/checkout@v6 - name: Install cvc5 - shell: bash - run: | - ARCH=$(uname -m) - if [ "$ARCH" = "x86_64" ]; then - ARCH_NAME="x86_64" - elif [ "$ARCH" = "aarch64" ] || [ "$ARCH" = "arm64" ]; then - ARCH_NAME="arm64" - else - echo "Unsupported architecture: $ARCH" - exit 1 - fi - wget -q https://github.com/cvc5/cvc5/releases/download/cvc5-1.2.1/cvc5-Linux-${ARCH_NAME}-static.zip - unzip -q cvc5-Linux-${ARCH_NAME}-static.zip - chmod +x cvc5-Linux-${ARCH_NAME}-static/bin/cvc5 - echo "$GITHUB_WORKSPACE/cvc5-Linux-${ARCH_NAME}-static/bin/" >> $GITHUB_PATH + uses: ./.github/actions/install-cvc5 - name: Install z3 - shell: bash - run: | - ARCH=$(uname -m) - if [ "$ARCH" = "x86_64" ]; then - wget -q https://github.com/Z3Prover/z3/releases/download/z3-4.15.2/z3-4.15.2-x64-glibc-2.39.zip - ARCHIVE_NAME="z3-4.15.2-x64-glibc-2.39" - elif [ "$ARCH" = "aarch64" ] || [ "$ARCH" = "arm64" ]; then - wget -q https://github.com/Z3Prover/z3/releases/download/z3-4.15.2/z3-4.15.2-arm64-glibc-2.34.zip - ARCHIVE_NAME="z3-4.15.2-arm64-glibc-2.34" - else - echo "Unsupported architecture: $ARCH" - exit 1 - fi - unzip -q "${ARCHIVE_NAME}.zip" - chmod +x "${ARCHIVE_NAME}/bin/z3" - echo "$GITHUB_WORKSPACE/${ARCHIVE_NAME}/bin/" >> $GITHUB_PATH + uses: ./.github/actions/install-z3 - name: Prepare ccache uses: actions/cache@v5 with: @@ -77,11 +48,10 @@ jobs: # The cache is safe to use here because we just saved it for this exact SHA # in the build_and_test_lean job from ci.yml # https://github.com/strata-org/Strata/issues/952 - uses: actions/cache/restore@v5 + uses: ./.github/actions/restore-lake-cache with: - path: .lake - key: lake-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('lean-toolchain') }}-${{ hashFiles('lake-manifest.json') }}-${{ hashFiles('**/*.st') }}-${{ github.sha }} - fail-on-cache-miss: true + fail-on-cache-miss: "true" + use-restore-keys: "false" - name: Build Strata uses: leanprover/lean-action@v1 with: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d01fc1ed7c..6fc9cc26de 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,40 +32,9 @@ jobs: - name: Checkout uses: actions/checkout@v6 - name: Install cvc5 - shell: bash - run: | - ARCH=$(uname -m) - if [ "$ARCH" = "x86_64" ]; then - ARCH_NAME="x86_64" - elif [ "$ARCH" = "aarch64" ] || [ "$ARCH" = "arm64" ]; then - ARCH_NAME="arm64" - else - echo "Unsupported architecture: $ARCH" - exit 1 - fi - wget https://github.com/cvc5/cvc5/releases/download/cvc5-1.2.1/cvc5-Linux-${ARCH_NAME}-static.zip - unzip cvc5-Linux-${ARCH_NAME}-static.zip - chmod +x cvc5-Linux-${ARCH_NAME}-static/bin/cvc5 - echo "$GITHUB_WORKSPACE/cvc5-Linux-${ARCH_NAME}-static/bin/" >> $GITHUB_PATH + uses: ./.github/actions/install-cvc5 - name: Install z3 - shell: bash - run: | - ARCH=$(uname -m) - if [ "$ARCH" = "x86_64" ]; then - ARCH_NAME="x86_64" - wget https://github.com/Z3Prover/z3/releases/download/z3-4.15.2/z3-4.15.2-x64-glibc-2.39.zip - ARCHIVE_NAME="z3-4.15.2-x64-glibc-2.39" - elif [ "$ARCH" = "aarch64" ] || [ "$ARCH" = "arm64" ]; then - ARCH_NAME="arm64" - wget https://github.com/Z3Prover/z3/releases/download/z3-4.15.2/z3-4.15.2-arm64-glibc-2.34.zip - ARCHIVE_NAME="z3-4.15.2-arm64-win" - else - echo "Unsupported architecture: $ARCH" - exit 1 - fi - unzip "${ARCHIVE_NAME}.zip" - chmod +x "${ARCHIVE_NAME}/bin/z3" - echo "$GITHUB_WORKSPACE/${ARCHIVE_NAME}/bin/" >> $GITHUB_PATH + uses: ./.github/actions/install-z3 - name: Install .NET uses: actions/setup-dotnet@v5 with: @@ -74,14 +43,7 @@ jobs: # Only use the caches on PRs because there is a risk of stale results: # https://github.com/strata-org/Strata/issues/952 if: github.event_name == 'pull_request' - uses: actions/cache/restore@v5 - with: - path: .lake - key: lake-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('lean-toolchain') }}-${{ hashFiles('lake-manifest.json') }}-${{ hashFiles('**/*.st') }}-${{ github.sha }} - restore-keys: | - lake-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('lean-toolchain') }}-${{ hashFiles('lake-manifest.json') }}-${{ hashFiles('**/*.st') }} - lake-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('lean-toolchain') }}-${{ hashFiles('lake-manifest.json') }} - lake-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('lean-toolchain') }} + uses: ./.github/actions/restore-lake-cache - name: Download ion-java jar for Java codegen test run: wget -q -O StrataTestExtra/DDM/Integration/Java/testdata/ion-java-1.11.11.jar https://github.com/amazon-ion/ion-java/releases/download/v1.11.11/ion-java-1.11.11.jar - name: Build and test Strata @@ -92,10 +54,7 @@ jobs: - name: Run tests (excluding Python) run: lake test -- --exclude Languages.Python - name: Save lake cache - uses: actions/cache/save@v5 - with: - path: .lake - key: lake-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('lean-toolchain') }}-${{ hashFiles('lake-manifest.json') }}-${{ hashFiles('**/*.st') }}-${{ github.sha }} + uses: ./.github/actions/save-lake-cache - name: Verify Java testdata is up to date run: | StrataTestExtra/DDM/Integration/Java/regenerate-testdata.sh @@ -156,26 +115,19 @@ jobs: # The cache is safe to use here because we just saved it for this exact SHA # in the build_and_test_lean job # https://github.com/strata-org/Strata/issues/952 - uses: actions/cache/restore@v5 + uses: ./.github/actions/restore-lake-cache with: - path: .lake - key: lake-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('lean-toolchain') }}-${{ hashFiles('lake-manifest.json') }}-${{ hashFiles('**/*.st') }}-${{ github.sha }} - fail-on-cache-miss: true + fail-on-cache-miss: "true" + use-restore-keys: "false" - name: Install Lean uses: leanprover/lean-action@v1 with: auto-config: false build: false - name: Install cvc5 - shell: bash - run: | - ARCH=$(uname -m) - if [ "$ARCH" = "x86_64" ]; then ARCH_NAME="x86_64" - elif [ "$ARCH" = "aarch64" ] || [ "$ARCH" = "arm64" ]; then ARCH_NAME="arm64" - else echo "Unsupported architecture: $ARCH"; exit 1; fi - wget -q https://github.com/cvc5/cvc5/releases/download/cvc5-1.2.1/cvc5-Linux-${ARCH_NAME}-static.zip - unzip -q cvc5-Linux-${ARCH_NAME}-static.zip - sudo cp cvc5-Linux-${ARCH_NAME}-static/bin/cvc5 /usr/local/bin/ + uses: ./.github/actions/install-cvc5 + with: + install-to: system - name: Check pending tests for newly passing working-directory: StrataTest/Languages/Python shell: bash @@ -271,11 +223,10 @@ jobs: # The cache is safe to use here because we just saved it for this exact SHA # in the build_and_test_lean job from ci.yml # https://github.com/strata-org/Strata/issues/952 - uses: actions/cache/restore@v5 + uses: ./.github/actions/restore-lake-cache with: - path: .lake - key: lake-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('lean-toolchain') }}-${{ hashFiles('lake-manifest.json') }}-${{ hashFiles('**/*.st') }}-${{ github.sha }} - fail-on-cache-miss: true + fail-on-cache-miss: "true" + use-restore-keys: "false" - name: Install Lean (for lake env) uses: leanprover/lean-action@v1 with: @@ -283,21 +234,9 @@ jobs: build: false use-github-cache: false - name: Install cvc5 - shell: bash - run: | - ARCH=$(uname -m) - if [ "$ARCH" = "x86_64" ]; then - ARCH_NAME="x86_64" - elif [ "$ARCH" = "aarch64" ] || [ "$ARCH" = "arm64" ]; then - ARCH_NAME="arm64" - else - echo "Unsupported architecture: $ARCH" - exit 1 - fi - wget https://github.com/cvc5/cvc5/releases/download/cvc5-1.2.1/cvc5-Linux-${ARCH_NAME}-static.zip - unzip cvc5-Linux-${ARCH_NAME}-static.zip - chmod +x cvc5-Linux-${ARCH_NAME}-static/bin/cvc5 - sudo cp cvc5-Linux-${ARCH_NAME}-static/bin/cvc5 /usr/local/bin/ + uses: ./.github/actions/install-cvc5 + with: + install-to: system - name: Install z3 shell: bash run: | @@ -321,51 +260,3 @@ jobs: permissions: contents: read uses: ./.github/workflows/cbmc.yml - - strata-benchmarks: - name: Run internal benchmarks of Strata - runs-on: ubuntu-latest - permissions: - id-token: write - contents: read - steps: - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v6 - with: - role-to-assume: arn:aws:iam::${{secrets.AWS_BENCHMARK_ACCOUNT}}:role/github-actions-codebuild-role - aws-region: us-east-2 - - - name: Trigger CodeBuild and wait - shell: bash - run: | - aws --version - BUILD_ID=$(aws codebuild start-build \ - --project-name strata-benchmarks \ - --source-type-override GITHUB \ - --source-location-override https://github.com/strata-org/Strata.git \ - --source-version ${{ github.event.pull_request.head.sha || github.sha }} \ - --query 'build.id' --output text \ - --region us-east-2) - echo "Build started: $BUILD_ID" - echo "CodeBuild console: https://us-east-2.console.aws.amazon.com/codesuite/codebuild/projects/strata-benchmarks/build/${BUILD_ID}/?region=us-east-2" - - LOG_KEY="logs/${BUILD_ID}.log" - echo "[View build log in S3](https://s3.console.aws.amazon.com/s3/object/strata-internal-benchmarks-logs?prefix=${LOG_KEY})" >> $GITHUB_STEP_SUMMARY - - while true; do - STATUS=$(aws codebuild batch-get-builds \ - --ids "$BUILD_ID" \ - --query 'builds[0].buildStatus' --output text \ - --region us-east-2) - echo "Current status: [$STATUS]" - case "$STATUS" in - SUCCEEDED) break;; - FAILED|FAULT|TIMED_OUT|STOPPED) echo "Build failed: $STATUS" ; break ;; - IN_PROGRESS) sleep 30 ;; - *) echo "Unexpected status: $STATUS"; sleep 10 ; break ;; - esac - done - - echo "View build log in S3: https://s3.console.aws.amazon.com/s3/object/strata-internal-benchmarks-logs?prefix=${LOG_KEY}" - - test "$STATUS" = "SUCCEEDED" diff --git a/.gitignore b/.gitignore index 3776616d98..f7d8f2cb47 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,4 @@ vcs/*.smt2 *.py.ion *.py.ion.core.st -Strata.code-workspace \ No newline at end of file +Strata.code-workspace diff --git a/Examples/expected/IrrelevantAxioms.removeIrrelevantAxioms.core.st b/Examples/expected/IrrelevantAxioms.removeIrrelevantAxioms.core.st index 3b5e32dd3d..21fde32b64 100644 --- a/Examples/expected/IrrelevantAxioms.removeIrrelevantAxioms.core.st +++ b/Examples/expected/IrrelevantAxioms.removeIrrelevantAxioms.core.st @@ -1,11 +1,11 @@ program Core; function f (x : int) : int; -axiom [f_positive]: forall __q0 : int :: f(__q0) > 0; -axiom [f_monotone]: forall __q0 : int :: forall __q1 : int :: __q0 < __q1 ==> f(__q0) < f(__q1); +axiom [f_positive]: forall x : int :: f(x) > 0; +axiom [f_monotone]: forall x : int :: forall y : int :: x < y ==> f(x) < f(y); function g (x : int) : int; function h (x : int) : int; -axiom [h_def]: forall __q0 : int :: h(__q0) == f(__q0) + 1; +axiom [h_def]: forall x : int :: h(x) == f(x) + 1; procedure TestF (x : int, out result : int) spec { ensures [result_positive]: result > 0; diff --git a/Strata.lean b/Strata.lean index ee70beef36..4f094503e4 100644 --- a/Strata.lean +++ b/Strata.lean @@ -68,4 +68,7 @@ import Strata.MetaVerifier /- Simple API -/ import Strata.SimpleAPI +/- Pipeline -/ +import Strata.Pipeline.PyAnalyzeLaurel + -- noimport: Strata.Util.Random -- deletion candidate: nothing imports this module diff --git a/Strata/DDM/AST.lean b/Strata/DDM/AST.lean index 786ac648cc..8e947f40b8 100644 --- a/Strata/DDM/AST.lean +++ b/Strata/DDM/AST.lean @@ -113,6 +113,9 @@ deriving BEq, Inhabited, Repr namespace TypeExprF +/-- An anonymous type placeholder. -/ +def placeholder {α} (loc : α) : TypeExprF α := .tvar loc "" + def ann {α} : TypeExprF α → α | .ident ann _ _ => ann | .bvar ann _ => ann @@ -1339,6 +1342,8 @@ structure Dialect where -- Names of dialects that are imported into this dialect imports : Array DialectName declarations : Array Decl := #[] + /-- When false, type inference and unification are skipped during elaboration. -/ + typecheck : Bool := true cache : Std.HashMap String Decl := declarations.foldl (init := {}) fun m d => m.insert d.name d @@ -2240,6 +2245,11 @@ def addCommand (dialects : DialectMap) (gctx : GlobalContext) (op : Operation) : end GlobalContext +def computeGlobalContext (dialects : DialectMap) (commands : Array Operation) + : Except String GlobalContext := + commands.foldl (init := (Except.ok {} : Except String GlobalContext)) + fun acc cmd => acc.bind (·.addCommand dialects cmd) + structure Program where mk :: /-- Map from dialect names to the dialect definition. -/ @@ -2249,11 +2259,7 @@ structure Program where /-- Top level commands in file. -/ commands : Array Operation := #[] /-- Final global context for program. -/ - globalContext : GlobalContext := - match commands.foldl (init := (Except.ok {} : Except String GlobalContext)) - fun acc cmd => acc.bind (·.addCommand dialects cmd) with - | .ok gctx => gctx - | .error e => panic! s!"Program.globalContext: {e}" -- nopanic:ok + globalContext : GlobalContext namespace Program @@ -2261,7 +2267,7 @@ instance : BEq Program where beq x y := x.dialect == y.dialect && x.commands == y.commands instance : Inhabited Program where - default := private { dialects := .empty, dialect := default } + default := private { dialects := .empty, dialect := default, globalContext := {} } def addCommand (env : Program) (cmd : Operation) : Program := { env with @@ -2276,7 +2282,11 @@ This creates a program. It is added in addition to `Program.mk` to simplify the `ToExpr Program` instance. -/ def create (dialects : DialectMap) (dialect : DialectName) (commands : Array Operation) : Program := - { dialects, dialect, commands } + let globalContext := + match computeGlobalContext dialects commands with + | .ok gctx => gctx + | .error e => panic! s!"Program.globalContext: {e}" -- nopanic:ok + { dialects, dialect, commands, globalContext } end Program diff --git a/Strata/DDM/BuiltinDialects/StrataDDL.lean b/Strata/DDM/BuiltinDialects/StrataDDL.lean index fd48156467..cacc09eef2 100644 --- a/Strata/DDM/BuiltinDialects/StrataDDL.lean +++ b/Strata/DDM/BuiltinDialects/StrataDDL.lean @@ -66,6 +66,15 @@ def StrataDDL : Dialect := BuiltinM.create! "StrataDDL" #[initDialect] do category := Command, syntaxDef := .ofList [.str "import", .ident 0 0, .str ";"] } + declareOp { + name := "setOptionCommand", + argDecls := .ofArray #[ + { ident := "name", kind := Ident }, + { ident := "value", kind := Ident } + ], + category := Command, + syntaxDef := .ofList [.str "dialect_option", .ident 0 0, .ident 1 0, .str ";"] + } declareOp { name := "categoryCommand", argDecls := .ofArray #[ diff --git a/Strata/DDM/Elab.lean b/Strata/DDM/Elab.lean index 0ab6bfe420..5c18dd3e3b 100644 --- a/Strata/DDM/Elab.lean +++ b/Strata/DDM/Elab.lean @@ -181,7 +181,8 @@ def elabProgramRest let s := DeclState.initDeclState let s := { s with pos := startPos } let s := s.openLoadedDialect! loader d - let ctx : DeclContext := { inputContext, stopPos, loader := loader, missingImport := false } + let ctx : DeclContext := { inputContext, stopPos, loader := loader, missingImport := false, + typecheck := d.typecheck } let (cmds, s) := runCommand leanEnv #[] stopPos ctx s if s.errors.isEmpty then let openDialects := loader.dialects.importedDialects dialect known diff --git a/Strata/DDM/Elab/Core.lean b/Strata/DDM/Elab/Core.lean index 1d7d39aa0d..ac7821282d 100644 --- a/Strata/DDM/Elab/Core.lean +++ b/Strata/DDM/Elab/Core.lean @@ -111,6 +111,14 @@ def applyNArgs (tctx : TypingContext) (e : TypeExpr) (n : Nat) := aux #[] e if argsLt : args.size < n then match tctx.hnf e with | .arrow _ a r => aux (args.push a) r + -- A tvar already represents an unresolved type — filling remaining + -- slots with placeholders is consistent with existing tvar semantics. + -- This runs regardless of the `typecheck` flag; downstream unifyTypes + -- still catches genuine mismatches when typecheck is on. + | .tvar ann _ => + let placeholder := .placeholder ann + let tvars := Array.replicate (n - args.size) placeholder + .ok (⟨args ++ tvars, by simp [tvars]; omega⟩, placeholder) | e => .error (args, e) else if argsGt : args.size > n then @@ -142,6 +150,8 @@ structure ElabContext where globalContext : GlobalContext /-- Flag to indicate we are missing an import (silences some warnings)-/ missingImport : Bool + /-- When false, type inference and unification are skipped during elaboration. -/ + typecheck : Bool := true structure ElabState where -- Errors found in elaboration. @@ -1119,7 +1129,14 @@ partial def elabOperation (tctx : TypingContext) (stx : Syntax) : ElabM Tree := if not success then return default let getKind i := .ofArgDeclKind argDecls[i].kind + let typecheck := (← read).typecheck let ((args, newCtx), success) ← runChecked <| + -- When typecheck is off, skip pre-registration passes. Global context + -- is already populated by `computeGlobalContext` at program creation. + if !typecheck then do + let args ← runSyntaxElaborator (argc := argDecls.size) getKind se tctx stxArgs + return (args, resultContext se tctx args) + else match se.preRegisterTypesScope with | some scopeArgLevel => elaborateWithPreRegistrationCore argDecls se tctx loc stxArgs scopeArgLevel @@ -1155,10 +1172,13 @@ partial def elabSyntaxArg (argIdx : Fin argc) (trees : Vector (Option Tree) argc) : ElabM (Vector (Option Tree) argc) := do + let typecheck := (← read).typecheck match getKind argIdx with | .preType expectedType => let (tree, success) ← runChecked <| elabExpr tctx astx if success then + if !typecheck then + return trees.set argIdx (some tree) let expr := tree.info.asExpr!.expr let inferredType ← inferType tctx expr let dialects := (← read).dialects @@ -1178,6 +1198,8 @@ partial def elabSyntaxArg | .typeExpr expectedType => let (tree, success) ← runChecked <| elabExpr tctx astx if success then + if !typecheck then + return trees.set argIdx (some tree) let expr := tree.info.asExpr!.expr let inferredType ← inferType tctx expr let trees ← unifyTypes isTypeP argIdx @@ -1246,6 +1268,19 @@ partial def runSyntaxElaborator trees ← elabSyntaxArg getKind isTypeP t.resultContext astx ⟨argLevel, argLevelP⟩ trees else trees ← elabSyntaxArg getKind isTypeP tctx0 astx ⟨argLevel, argLevelP⟩ trees + -- Fill unfilled type parameter slots with skip types when type checking is skipped. + if !(← read).typecheck then + for i in Fin.range argc do + if trees[i].isNone ∧ isTypeP i then + let loc := SourceRange.none + -- Synthesize placeholder type expr. + let info : TypeInfo := { + loc, + inputCtx := tctx0, + typeExpr := .placeholder loc, + isInferred := true + } + trees := trees.set i (some (.node (.ofTypeInfo info) #[])) return trees.map (·.getD default) /-- @@ -1778,7 +1813,8 @@ def runElab {α} (action : ElabM α) : DeclM α := do metadataDeclMap := s.metadataDeclMap, globalContext := s.globalContext, inputContext := (←read).inputContext, - missingImport := (← read).missingImport + missingImport := (← read).missingImport, + typecheck := (← read).typecheck } let errors := (←get).errors -- Clear errors from decl diff --git a/Strata/DDM/Elab/DeclM.lean b/Strata/DDM/Elab/DeclM.lean index cdd7e646e7..b4096aceaf 100644 --- a/Strata/DDM/Elab/DeclM.lean +++ b/Strata/DDM/Elab/DeclM.lean @@ -170,6 +170,8 @@ structure DeclContext where loader : LoadedDialects /-- Flag indicating imports are missing (silences some errors). -/ missingImport : Bool + /-- When false, type inference and unification are skipped during elaboration. -/ + typecheck : Bool := true namespace DeclContext diff --git a/Strata/DDM/Elab/DialectM.lean b/Strata/DDM/Elab/DialectM.lean index 54bd169f37..ed3286c482 100644 --- a/Strata/DDM/Elab/DialectM.lean +++ b/Strata/DDM/Elab/DialectM.lean @@ -909,6 +909,19 @@ def elabMdCommand : DialectElab := fun tree => do metadataDeclMap := s.metadataDeclMap.add dialect decl } +def elabSetOptionCommand : DialectElab := fun tree => do + let .isTrue _ := checkTreeSize tree 2 + | logError tree.info.loc "setOptionCommand: unexpected tree size"; return + let nameInfo := tree[0].info.asIdent! + let valueInfo := tree[1].info.asIdent! + match nameInfo.val with + | "typecheck" => + match valueInfo.val with + | "on" => modifyDialect fun d => { d with typecheck := true } + | "off" => modifyDialect fun d => { d with typecheck := false } + | _ => logError valueInfo.loc s!"Expected 'on' or 'off' for option 'typecheck'." + | _ => logError nameInfo.loc s!"Unknown option '{nameInfo.val}'." + def dialectElabs : Std.HashMap QualifiedIdent DialectElab := Std.HashMap.ofList <| [ (q`StrataDDL.importCommand, elabDialectImportCommand), @@ -917,6 +930,7 @@ def dialectElabs : Std.HashMap QualifiedIdent DialectElab := (q`StrataDDL.typeCommand, elabTypeCommand), (q`StrataDDL.fnCommand, elabFnCommand), (q`StrataDDL.mdCommand, elabMdCommand), + (q`StrataDDL.setOptionCommand, elabSetOptionCommand), ] partial def runDialectCommand (leanEnv : Lean.Environment) : DialectM Bool := do diff --git a/Strata/DDM/Ion.lean b/Strata/DDM/Ion.lean index a6aaf93451..baa268e797 100644 --- a/Strata/DDM/Ion.lean +++ b/Strata/DDM/Ion.lean @@ -1485,6 +1485,10 @@ private instance : CachedToIon Dialect where for i in d.imports do a := a.push <| .struct #[(ionSymbol! "type", ionSymbol! "import"), (ionSymbol! "name", .string i)] + if !d.typecheck then + a := a.push <| .struct #[(ionSymbol! "type", ionSymbol! "option"), + (ionSymbol! "name", .string "typecheck"), + (ionSymbol! "value", .string "off")] for decl in d.declarations do a := a.push <| (← ionRef! decl) return .list a @@ -1492,9 +1496,17 @@ private instance : CachedToIon Dialect where def fromIonFragment (dialect : DialectName) (f : Ion.Fragment) : Except String Dialect := do let ctx : FromIonContext := ⟨f.symbols⟩ let tbl := f.symbols - let typeId := tbl.symbolId! "type" - let nameId := tbl.symbolId! "name" - let (imports, decls) ← f.values.foldlM (init := (#[], #[])) (start := f.offset) fun (imports, decls) v => do + let typeId := tbl.symbolId "type" + let nameId := tbl.symbolId "name" + let valueId := tbl.symbolId "value" + if f.values.size > 0 then + if typeId = .zero then + throw s!"Missing type symbol" + if nameId = .zero then + throw s!"Missing name symbol" + let (imports, decls, typecheck) ← f.values.foldlM + (init := (#[], #[], true)) (start := f.offset) + fun (imports, decls, typecheck) v => do let fields ← FromIonM.asStruct0 v ⟨f.symbols⟩ let some (_, val) := fields.find? (·.fst == typeId) | throw s!"Could not find type in {repr fields}" @@ -1503,14 +1515,29 @@ def fromIonFragment (dialect : DialectName) (f : Ion.Fragment) : Except String D let some (_, val) := fields.find? (·.fst == nameId) | throw "Could not find import" let i ← FromIonM.asString "Import name" val ctx - pure (imports.push i, decls) + pure (imports.push i, decls, typecheck) + | "option" => + if valueId = .zero then + throw "Could not find option value" + let some (_, nameVal) := fields.find? (·.fst == nameId) + | throw "Could not find option name" + let optName ← FromIonM.asString "Option name" nameVal ctx + let some (_, valueVal) := fields.find? (·.fst == valueId) + | throw "Could not find option value" + let optValue ← FromIonM.asString "Option value" valueVal ctx + match optName, optValue with + | "typecheck", "off" => pure (imports, decls, false) + | "typecheck", "on" => pure (imports, decls, true) + | "typecheck", v => throw s!"Expected 'on' or 'off' for option 'typecheck', got '{v}'" + | name, _ => throw s!"Unknown option '{name}'" | name => let decl ← Decl.fromIonFields name fields ctx - pure (imports, decls.push decl) + pure (imports, decls.push decl, typecheck) return { name := dialect imports := imports - declarations := decls + declarations := decls + typecheck := typecheck } private instance : FromIon Dialect where @@ -1553,12 +1580,9 @@ def fromIonFragmentCommands (f : Ion.Fragment) : Except String (Array Operation) def fromIonFragment (f : Ion.Fragment) (dialects : DialectMap) - (dialect : DialectName) : Except String Program := - return { - dialects := dialects - dialect := dialect - commands := ← fromIonFragmentCommands f - } + (dialect : DialectName) : Except String Program := do + let commands ← fromIonFragmentCommands f + return .create dialects dialect commands /-- Decodes bytes in the Ion format into a single Strata program. @@ -1601,8 +1625,12 @@ def filesFromIon (dialects : DialectMap) (bytes : ByteArray) : Except String (Li let ⟨filesList, _⟩ ← FromIonM.asList ctx[1]! ionCtx let tbl := symbols - let filePathId := tbl.symbolId! "filePath" - let programId := tbl.symbolId! "program" + let filePathId := tbl.symbolId "filePath" + let programId := tbl.symbolId "program" + if filePathId = .zero then + throw "Missing filePath" + if programId = .zero then + throw "Missing program" filesList.toList.mapM fun fileEntry => do let fields ← FromIonM.asStruct0 fileEntry ionCtx diff --git a/Strata/DDM/Util/Ion.lean b/Strata/DDM/Util/Ion.lean index 0d2517ff8f..6b11c00900 100644 --- a/Strata/DDM/Util/Ion.lean +++ b/Strata/DDM/Util/Ion.lean @@ -9,6 +9,7 @@ public import Strata.DDM.Util.Ion.AST public import Strata.DDM.Util.Ion.Deserialize public import Strata.DDM.Util.Ion.Serialize public import Strata.DDM.Util.Ion.SymbolTable +public import Strata.DDM.Util.Ion.SystemSymbolIds import all Strata.DDM.Util.ByteArray import all Strata.DDM.Util.Fin diff --git a/Strata/DDM/Util/Ion/SymbolTable.lean b/Strata/DDM/Util/Ion/SymbolTable.lean index aa7a65bbba..07ff69fb94 100644 --- a/Strata/DDM/Util/Ion/SymbolTable.lean +++ b/Strata/DDM/Util/Ion/SymbolTable.lean @@ -5,30 +5,30 @@ -/ module -import Lean.Elab.Command -- shake: keep public import Strata.DDM.Util.Ion.AST -import all Strata.DDM.Util.Lean public section namespace Ion structure SymbolTable where - array : Array String - map : Std.HashMap String SymbolId + private mk :: + private array : Array String + private map : Std.HashMap String SymbolId locals : Array String deriving Inhabited namespace SymbolTable -instance : GetElem? SymbolTable SymbolId String (fun tbl idx => idx.value < tbl.array.size) where - getElem tbl idx p := tbl.array[idx.value] - getElem! tbl idx := assert! idx.value < tbl.array.size; tbl.array[idx.value]! - getElem? tbl idx := tbl.array[idx.value]? +def size (tbl : SymbolTable) : Nat := tbl.array.size -def symbolId! (sym : String) (tbl : SymbolTable) : SymbolId := - match tbl.map[sym]? with - | some i => i - | none => panic! s!"Unbound symbol {sym}" +instance : GetElem? SymbolTable SymbolId String (fun tbl idx => idx.value < tbl.size) where + getElem tbl idx p := private tbl.array[idx.value] + getElem! tbl idx := private tbl.array[idx.value]! + getElem? tbl idx := private tbl.array[idx.value]? + +/-- Lookup symbol and return `SymbolId.zero` if not defined. -/ +def symbolId (sym : String) (tbl : SymbolTable) : SymbolId := + tbl.map.getD sym .zero /-- Intern a string into a symbol. @@ -62,33 +62,9 @@ def system : SymbolTable where def ofLocals (locals : Array String) : SymbolTable := locals.foldl (init := .system) (fun tbl sym => tbl.intern sym |>.snd) -public instance : Lean.Quote SymbolTable where +instance : Lean.Quote SymbolTable where quote st := Lean.Syntax.mkCApp ``SymbolTable.ofLocals #[Lean.quote st.locals] end SymbolTable -namespace SymbolId - -def systemSymbolId! (sym : String) : SymbolId := SymbolTable.system |>.symbolId! sym - --- Use metaprogramming to declare `{sym}SymbolId : SymbolId` for each system symbol. -section -open Lean (TSyntax) -open Lean.Elab.Command (elabCommand) - --- Declare all system symbol ids as constants -run_cmd do - for sym in SymbolTable.ionSharedSymbolTableEntries do - -- To simplify name, strip out non-alphanumeric characters. - let simplifiedName : String := .ofList <| sym.toList.filter (·.isAlphanum) - let leanName := Lean.mkLocalDeclId simplifiedName - let cmd : TSyntax `command ← `(command| - public def $(leanName) : SymbolId := systemSymbolId! $(Lean.Syntax.mkStrLit sym) - ) - elabCommand cmd - -end - -end SymbolId - end Ion diff --git a/Strata/DDM/Util/Ion/SystemSymbolIds.lean b/Strata/DDM/Util/Ion/SystemSymbolIds.lean new file mode 100644 index 0000000000..dfb0dad26d --- /dev/null +++ b/Strata/DDM/Util/Ion/SystemSymbolIds.lean @@ -0,0 +1,31 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ +module + +import Lean.Elab.Command -- shake: keep +public import Strata.DDM.Util.Ion.AST +meta import Strata.DDM.Util.Ion.SymbolTable + +-- Use metaprogramming to declare `{sym}SymbolId : SymbolId` for each system symbol. +section +open Lean (TSyntax) +open Lean.Elab.Command (elabCommand) +open Lean.Parser.Category (command) + +-- Declare all system symbol ids as constants +run_cmd do + for sym in Ion.SymbolTable.ionSharedSymbolTableEntries do + -- To simplify name, strip out non-alphanumeric characters. + let simplifiedName : String := .ofList <| sym.toList.filter (·.isAlphanum) + let leanName := Lean.mkIdentFrom (canonical := true) default <| ``Ion.SymbolId |>.str simplifiedName + let idx := Ion.SymbolTable.system.symbolId sym + if idx = .zero then + throwError s!"Unbound symbol {sym}" + elabCommand $ ← `(command| + public def $(leanName) : Ion.SymbolId := ⟨$(Lean.Syntax.mkNatLit idx.value)⟩ + ) + +end diff --git a/Strata/DL/Imperative/CmdEval.lean b/Strata/DL/Imperative/CmdEval.lean index 5ee378ef99..c5c99659f2 100644 --- a/Strata/DL/Imperative/CmdEval.lean +++ b/Strata/DL/Imperative/CmdEval.lean @@ -66,11 +66,7 @@ def Cmd.eval [BEq P.Ident] [EC : EvalContext P S] (σ : S) (c : Cmd P) : Cmd P let e := EC.eval σ e let assumptions := EC.getPathConditions σ let c' := .assert label e md - let propType := match md.getPropertyType with - | some s => if s == MetaData.divisionByZero then .divisionByZero - else if s == MetaData.arithmeticOverflow then .arithmeticOverflow - else .assert - | none => .assert + let propType := convertMetaDataPropertyType md match EC.denoteBool e with | some true => -- Proved via evaluation. (c', EC.deferObligation σ (ProofObligation.mk label propType assumptions e md)) diff --git a/Strata/DL/Imperative/EvalContext.lean b/Strata/DL/Imperative/EvalContext.lean index 1f823cb008..941fa3b15f 100644 --- a/Strata/DL/Imperative/EvalContext.lean +++ b/Strata/DL/Imperative/EvalContext.lean @@ -102,12 +102,13 @@ inductive PropertyType where | assert | divisionByZero | arithmeticOverflow + | outOfBoundsAccess deriving Repr, DecidableEq /-- Whether an unreachable path counts as pass for this property type. Assertions pass vacuously when unreachable; covers fail. -/ def PropertyType.passWhenUnreachable : PropertyType → Bool - | .assert | .divisionByZero | .arithmeticOverflow => true + | .assert | .divisionByZero | .arithmeticOverflow | .outOfBoundsAccess => true | .cover => false instance : ToFormat PropertyType where @@ -116,6 +117,21 @@ instance : ToFormat PropertyType where | .assert => "assert" | .divisionByZero => "division by zero check" | .arithmeticOverflow => "arithmetic overflow check" + | .outOfBoundsAccess => "out-of-bounds access check" + +/-- Convert a `MetaData` entry's property-type classification string to the + `PropertyType` enum. Falls back to `.assert` when the metadata carries + no classification or an unrecognized string; callers that emit + propertyType classifications should add a matching arm here. -/ +def convertMetaDataPropertyType {P : PureExpr} [BEq P.Ident] + (md : MetaData P) : PropertyType := + match md.getPropertyType with + | some s => + if s == MetaData.divisionByZero then .divisionByZero + else if s == MetaData.arithmeticOverflow then .arithmeticOverflow + else if s == MetaData.outOfBoundsAccess then .outOfBoundsAccess + else .assert + | none => .assert /-- A proof obligation can be discharged by some backend solver or a dedicated diff --git a/Strata/DL/Imperative/MetaData.lean b/Strata/DL/Imperative/MetaData.lean index f3d3a384d1..28c15c480b 100644 --- a/Strata/DL/Imperative/MetaData.lean +++ b/Strata/DL/Imperative/MetaData.lean @@ -320,6 +320,9 @@ def MetaData.divisionByZero : String := "divisionByZero" /-- Metadata value for arithmetic-overflow property type classification. -/ def MetaData.arithmeticOverflow : String := "arithmeticOverflow" +/-- Metadata value for out-of-bounds-access property type classification. -/ +def MetaData.outOfBoundsAccess : String := "outOfBoundsAccess" + /-- Read the property type classification from metadata, if present. -/ def MetaData.getPropertyType {P : PureExpr} [BEq P.Ident] (md : MetaData P) : Option String := match md.findElem MetaData.propertyType with diff --git a/Strata/DL/Imperative/SMTUtils.lean b/Strata/DL/Imperative/SMTUtils.lean index 36c08f3828..541ec17d98 100644 --- a/Strata/DL/Imperative/SMTUtils.lean +++ b/Strata/DL/Imperative/SMTUtils.lean @@ -10,6 +10,7 @@ import Strata.DL.SMT.DDMTransform.Parse import Strata.DL.SMT.DDMTransform.Translate import Strata.DDM.Elab import Strata.DDM.Format +public import Strata.Pipeline.Context public import Strata.DL.Imperative.PureExpr public import Strata.DL.Imperative.EvalContext @@ -166,7 +167,7 @@ directly, which avoids the ambiguity that arises when parsing at the Returns a list of (key-string, value-Term) pairs on success. -/ -private def parseModelDDM (modelStr : String) : IO (List (String × Strata.SMT.Term)) := do +def parseModelDDM (modelStr : String) : IO (List (String × Strata.SMT.Term)) := do let inputCtx := Strata.Parser.stringInputContext "solver-model" modelStr let op ← try Strata.Elab.parseCategoryFromDialect @@ -194,7 +195,7 @@ Process a parsed model (list of key-string / value-Term pairs) against the expected variables, matching each variable's SMT-encoded name to its value in the model. -/ -private def processModel {P : PureExpr} [ToFormat P.Ident] +def processModel {P : PureExpr} [ToFormat P.Ident] (typedVarToSMTFn : P.Ident → P.Ty → Except Format (String × Strata.SMT.TermType)) (vars : List P.TypedIdent) (pairs : List (String × Strata.SMT.Term)) (E : Strata.SMT.EncoderState) : Except Format (Model P.Ident) := do @@ -292,6 +293,72 @@ def addLocationInfo {P : PureExpr} [BEq P.Ident] Strata.SMT.Solver.setInfoString message.fst message.snd | .none => pure () +/-- Result of encoding a proof obligation against an `AbstractSolver`. + Returned by the encoder callback passed to `dischargeObligationIncremental`, + consumed by the check-sat orchestration. -/ +structure EncodedObligation where + obligationId : Strata.SMT.Term + assumptionIds : List String + estate : Strata.SMT.EncoderState + +/-- Discharge a proof obligation using a live (incremental) SMT solver. + The encoder callback runs against the spawned solver to emit declarations + and assertions; this helper orchestrates check-sat calls and model parsing. -/ +def dischargeObligationIncremental {P : PureExpr} [ToFormat P.Ident] [BEq P.Ident] + (encodeDecl : Strata.SMT.AbstractSolver Strata.SMT.Term Strata.SMT.TermType + Strata.SMT.IncrementalSolverM → + Strata.SMT.IncrementalSolverM EncodedObligation) + (typedVarToSMTFn : P.Ident → P.Ty → Except Format (String × Strata.SMT.TermType)) + (vars : List P.TypedIdent) + (smtsolver : String) (solverFlags : Array String) + (satisfiabilityCheck validityCheck : Bool) : + IO (Except SolverError (Result P.Ident × Result P.Ident × Strata.SMT.EncoderState)) := do + let solverState ← Strata.SMT.IncrementalSolver.spawn smtsolver solverFlags + let action : Strata.SMT.IncrementalSolverM + (Except SolverError (Result P.Ident × Result P.Ident × Strata.SMT.EncoderState)) := do + let solver := Strata.SMT.IncrementalSolver.mkIncrementalSolver + let { obligationId, assumptionIds, estate } ← encodeDecl solver + let varIds := assumptionIds.map fun id => Strata.SMT.Term.var ⟨id, .bool⟩ + let getModelForVars : Strata.SMT.IncrementalSolverM (Model P.Ident) := do + if varIds.isEmpty then return [] + try + let pairs ← solver.getValue varIds + match pairs with + | [(.prim (.string rawOutput), _)] => + let rawModel ← parseModelDDM rawOutput + match processModel typedVarToSMTFn vars rawModel estate with + | .ok model => return model + | .error _ => return [] + | _ => return [] + catch _ => return [] + let decisionToResult (decision : Strata.SMT.Decision) : + Strata.SMT.IncrementalSolverM (Result P.Ident) := do + match decision with + | .sat => return .sat (← getModelForVars) + | .unknown => + let model ← getModelForVars + return if model.isEmpty then .unknown else .unknown (some model) + | .unsat => return .unsat + let bothChecks := satisfiabilityCheck && validityCheck + let mut satResult : Result P.Ident := .unknown + let mut valResult : Result P.Ident := .unknown + if bothChecks then + satResult ← decisionToResult (← solver.checkSatAssuming [obligationId]) + let negObligation ← solver.mkNot obligationId + valResult ← decisionToResult (← solver.checkSatAssuming [negObligation]) + else + if satisfiabilityCheck then + solver.assert obligationId + satResult ← decisionToResult (← solver.checkSat) + else if validityCheck then + let negObligation ← solver.mkNot obligationId + solver.assert negObligation + valResult ← decisionToResult (← solver.checkSat) + solver.close + return .ok (satResult, valResult, estate) + let (result, _) ← action.run solverState + return result + /-- Writes the proof obligation to file, discharge the obligation using SMT solver, and parse the output of the SMT solver. @@ -307,20 +374,22 @@ def dischargeObligation {P : PureExpr} [ToFormat P.Ident] [BEq P.Ident] (smtsolver filename : String) (solver_options : Array String) (printFilename : Bool) (satisfiabilityCheck validityCheck : Bool) - (skipSolver : Bool := false) : + (skipSolver : Bool := false) + (pctx : Strata.Pipeline.PipelineContext) : IO (Except SolverError (Result P.Ident × Result P.Ident × Strata.SMT.EncoderState)) := do let handle ← IO.FS.Handle.mk filename IO.FS.Mode.write let solver ← Strata.SMT.Solver.fileWriter handle - -- encodeSMT (which calls encodeCore) emits check-sat commands internally - let ((_ids, estate), _solverState) ← encodeSMT.run solver + let ((_ids, estate), _solverState) ← pctx.withPhase "encodeSMT" do + encodeSMT.run solver if printFilename then IO.println s!"Wrote problem to {filename}." if skipSolver then return .ok (.unknown, .unknown, estate) - let solver_output ← runSolver smtsolver (#[filename] ++ solver_options) + let solver_output ← pctx.withPhase "runSolver" do + runSolver smtsolver (#[filename] ++ solver_options) match ← solverResult typedVarToSMTFn vars solver_output estate smtsolver satisfiabilityCheck validityCheck with | .error e => return .error e | .ok (satResult, validityResult) => return .ok (satResult, validityResult, estate) diff --git a/Strata/DL/Lambda/IntBoolFactory.lean b/Strata/DL/Lambda/IntBoolFactory.lean index a5de45f549..b58a2028ee 100644 --- a/Strata/DL/Lambda/IntBoolFactory.lean +++ b/Strata/DL/Lambda/IntBoolFactory.lean @@ -120,22 +120,27 @@ instance (n : Nat) : LambdaLeanType (.bitvec n) (BitVec n) where These build well-formed `WFLFunc`s that have no `concreteEval` or `body`. -/ -/-- General polymorphic unevaluated function with optional axioms. - Handles any arity and any number of type arguments. -/ +/-- General polymorphic unevaluated function with optional axioms and + preconditions. Handles any arity and any number of type arguments. -/ @[inline] def polyUneval (n : T.Identifier) (typeArgs : List String) (inputs : List (T.Identifier × LMonoTy)) (output : LMonoTy) (axioms : List (LExpr T.mono) := []) + (preconditions : + List (FuncPrecondition (LExpr T.mono) T.Metadata) := []) (h_nodup : List.Nodup (inputs.map (·.1.name)) := by first | decide | grind) (h_ta_nodup : List.Nodup typeArgs := by grind) (h_inputs : ∀ ty, ty ∈ ListMap.values inputs → ty.freeVars ⊆ typeArgs := by first | decide | grind) (h_output : output.freeVars ⊆ typeArgs := by first | decide | grind) + (h_precond : ∀ p, p ∈ preconditions → + (LExpr.freeVars p.expr).map (·.1.name) ⊆ inputs.map (·.1.name) + := by first | decide | grind) (h_ta_no_gen : ∀ ta, ta ∈ typeArgs → ¬ ("$__ty".toList.isPrefixOf ta.toList = true) := by first | decide | grind) : WFLFunc T := ⟨{ name := n, typeArgs := typeArgs, inputs := inputs, output := output, - axioms := axioms }, { + axioms := axioms, preconditions := preconditions }, { arg_nodup := h_nodup body_freevars := by intro b hb; simp at hb concreteEval_argmatch := by intro fn _ _ _ hfn; simp at hfn @@ -144,7 +149,7 @@ def polyUneval (n : T.Identifier) (typeArgs : List String) typeArgs_nodup := h_ta_nodup inputs_typevars_in_typeArgs := h_inputs output_typevars_in_typeArgs := h_output - precond_freevars := by intro p hp; simp at hp + precond_freevars := h_precond typeArgs_no_gen_prefix := h_ta_no_gen }⟩ diff --git a/Strata/DL/SMT/AbstractSolver.lean b/Strata/DL/SMT/AbstractSolver.lean new file mode 100644 index 0000000000..79df0cddff --- /dev/null +++ b/Strata/DL/SMT/AbstractSolver.lean @@ -0,0 +1,197 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ +module + +public import Strata.DL.SMT.Solver + +/-! +# Abstract Solver Interface + +Defines `AbstractSolver τ σ m`, a generic solver interface parameterized by an +opaque term type `τ`, an opaque sort type `σ`, and a monad `m`. All operations +that can fail throw `IO.Error` via `MonadExceptOf`. The monad `m` captures any +state or effects the backend needs. + +For the incremental SMT-LIB backend, `τ = SMT.Term`, `σ = SMT.TermType`, +`m = StateT IncrementalSolverState IO`. + +## Design + +- `declareNew` allows shadowing: declaring the same name twice creates two + distinct variables. The backend handles disambiguation internally. +- Models return keys as `(String × Nat)` where `Nat` is the shadow depth + (0 = most recently declared). +- Quantifier bound variables are scoped via a callback pattern. +- Terms (`τ`) are opaque handles whose meaning is backend-specific. They may + be internal addresses and should not be assumed valid across sessions. +- Sorts are first-class: backends can create and pass their own sort + representations via `intSort`, `boolSort`, `bitvecSort`, `arraySort`, etc. +-/ + +namespace Strata.SMT + +public section + +/-- Handles for a single datatype constructor returned by `declareDatatype`. + - `constr` is the constructor function (use with `mkApp` to build values) + - `tester` is the recognizer predicate (use with `mkApp` to test membership) + - `selectors` are the field accessors in declaration order -/ +structure DatatypeConstructorHandles (τ : Type) where + constr : τ + tester : τ + selectors : List τ + +/-- Result of declaring a datatype: the sort and handles for each constructor. -/ +structure DatatypeInfo (τ : Type) (σ : Type) where + sort : σ + constructors : List (DatatypeConstructorHandles τ) + +/-- Abstract solver interface parameterized by term type `τ`, sort type `σ`, +and monad `m`. + +All term constructors are fallible. Solvers might not accept certain constructs +(e.g., wrong sorts, unsupported combinations) and we need to surface the issue +precisely via `MonadExceptOf IO.Error`. -/ +structure AbstractSolver (τ : Type) (σ : Type) (m : Type → Type) [Monad m] [MonadExceptOf IO.Error m] where + -- Configuration (for solvers that support them; ignored otherwise) + setLogic : String → m Unit + setOption : String → String → m Unit + comment : String → m Unit + + -- Sort constructors + boolSort : m σ + intSort : m σ + realSort : m σ + stringSort : m σ + regexSort : m σ + bitvecSort : Nat → m σ + arraySort : σ → σ → m σ + + /-- Construct a sort for a named type (datatype or user-defined sort) + with the given type arguments. -/ + constrSort : String → List σ → m σ + + -- Literal / leaf constructors + mkBool : Bool → m τ + mkInt : Int → m τ + mkPrim : TermPrim → m τ + + /-- Fallback for operations not covered by specific mk* methods + (e.g. bitvectors, strings, regex). The backend receives the raw `Op`, + the already-encoded arguments, and the result sort. -/ + mkAppOp : Op → List τ → σ → m τ + + -- Boolean operations + mkAnd : List τ → m τ + mkOr : List τ → m τ + mkNot : τ → m τ + mkImplies : τ → τ → m τ + + -- Arithmetic operations + mkAdd : List τ → m τ + mkSub : List τ → m τ + mkMul : List τ → m τ + mkDiv : τ → τ → m τ + mkMod : τ → τ → m τ + mkNeg : τ → m τ + mkAbs : τ → m τ + + -- Comparison operations + mkEq : List τ → m τ + mkLt : List τ → m τ + mkLe : List τ → m τ + mkGt : List τ → m τ + mkGe : List τ → m τ + + -- Conditional + mkIte : τ → τ → τ → m τ + + -- Array operations + mkSelect : τ → τ → m τ + mkStore : τ → τ → τ → m τ + + -- Function application (for uninterpreted functions) + mkApp : τ → List τ → m τ + + -- Quantifiers + /-- Construct a universally quantified term. + Takes name-sort pairs for bound variables and a monadic callback that + receives the bound variable terms and returns the body and trigger groups. + The callback is monadic so callers can encode sub-terms using the + bound variable handles. Bound variables cannot escape the quantifier scope. -/ + mkForall : List (String × σ) → (List τ → m (τ × List (List τ))) → m τ + + /-- Construct an existentially quantified term. Same callback pattern as `mkForall`. -/ + mkExists : List (String × σ) → (List τ → m (τ × List (List τ))) → m τ + + /-- Declare a new variable. Shadowing is allowed: declaring the same name + twice creates two distinct variables. The backend handles disambiguation + internally. -/ + declareNew : String → σ → m τ + + /-- Declare an uninterpreted function. -/ + declareFun : String → List σ → σ → m τ + + /-- Define an interpreted function with a body term. -/ + defineFun : String → List (String × σ) → σ → τ → m Unit + + /-- Declare a new sort with the given arity. Returns the declared sort. -/ + declareSort : String → Nat → m σ + + /-- Declare an algebraic datatype. + Takes the datatype name, type parameter names, and a callback that + receives `(selfSort, typeParamSorts)` and returns the constructors. + Returns the declared sort and constructor/tester/selector handles. + This callback pattern (like `mkForall`) allows recursive and parametric + datatypes: the sort being declared does not exist yet when selectors + need to reference it. -/ + declareDatatype : String → List String → + (σ → List σ → Except String (List (String × List (String × σ)))) → + m (DatatypeInfo τ σ) + + /-- Declare mutually recursive algebraic datatypes. + Takes a list of `(name, typeParams)` and a callback that receives + `(selfSorts, typeParamSorts)` and returns constructors for each datatype. + Returns the declared sorts and constructor/tester/selector handles. -/ + declareDatatypes : List (String × List String) → + (List σ → List (List σ) → Except String (List (List (String × List (String × σ))))) → + m (List (DatatypeInfo τ σ)) + + -- Session operations + + /-- Assert a term (must be Bool-typed). -/ + assert : τ → m Unit + + /-- Check satisfiability of the current assertions. -/ + checkSat : m Decision + + /-- Check satisfiability under additional assumptions. -/ + checkSatAssuming : List τ → m Decision + + /-- After an `unsat` result from `checkSatAssuming`, retrieve the subset of + assumptions that contributed to unsatisfiability. -/ + getUnsatAssumptions : m (List τ) + + /-- Retrieve the model after a `sat` result. + Keys are `(name, shadow_depth)` where 0 = most recently declared. -/ + getModel : m (List ((String × Nat) × τ)) + + /-- Get values of specific terms in the current model. -/ + getValue : List τ → m (List (τ × τ)) + + /-- Convert a term to its SMT-LIB string representation, making model values inspectable. + The returned string must be valid SMT-LIB syntax. -/ + termToSMTLibString : τ → m String + + /-- Reset the solver session to its initial state. -/ + reset : m Unit + + /-- Close the solver session and release resources. -/ + close : m Unit + +end + +end Strata.SMT diff --git a/Strata/DL/SMT/Encoder.lean b/Strata/DL/SMT/Encoder.lean index 4b3fce6f1c..67757d404a 100644 --- a/Strata/DL/SMT/Encoder.lean +++ b/Strata/DL/SMT/Encoder.lean @@ -177,7 +177,7 @@ private theorem extractTriggerGroup_sizeOf (t ti : Term) (h : ti ∈ extractTrig · simp_all /-- Every term nested in `extractTriggers t` has `sizeOf ≤ sizeOf t`. -/ -private theorem extractTriggers_sizeOf (t : Term) (ts : List Term) (ti : Term) +theorem extractTriggers_sizeOf (t : Term) (ts : List Term) (ti : Term) (hts : ts ∈ extractTriggers t) (hti : ti ∈ ts) : sizeOf ti ≤ sizeOf t := by unfold extractTriggers at hts diff --git a/Strata/DL/SMT/IncrementalSolver.lean b/Strata/DL/SMT/IncrementalSolver.lean new file mode 100644 index 0000000000..4a1e65dbc7 --- /dev/null +++ b/Strata/DL/SMT/IncrementalSolver.lean @@ -0,0 +1,364 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ +module + +public import Strata.DL.SMT.AbstractSolver +public import Strata.DL.SMT.Factory +import Strata.DDM.Format +import Std.Data.HashMap + +/-! +# Incremental SMT-LIB Backend + +Implements `AbstractSolver Term (StateT IncrementalSolverState IO)` where the +state wraps a live solver process communicating via stdin/stdout. Unlike the +batch pipeline (write file, run solver), this backend sends commands +incrementally and reads responses interactively. + +Variable shadowing is handled by appending `@N` suffixes to disambiguate +repeated declarations of the same name. The shadow depth is tracked per name. +-/ + +namespace Strata.SMT + +public section + +/-- State for the incremental SMT-LIB solver backend. Wraps a live solver + process and tracks variable shadowing for `declareNew`. -/ +structure IncrementalSolverState where + /-- The underlying SMT-LIB solver process. -/ + solver : SMTLibSolver + /-- Caches `Term → SMT-LIB string` conversions. -/ + termStrings : Std.HashMap Term String := {} + /-- Caches `TermType → SMT-LIB string` conversions. -/ + typeStrings : Std.HashMap TermType String := {} + /-- Tracks how many times each variable name has been declared (for shadowing). -/ + shadowCounts : Std.HashMap String Nat := {} + /-- Maps SMT-LIB string → Term for the last `checkSatAssuming` call, + used by `getUnsatAssumptions` to recover terms from solver output. -/ + lastAssumptions : Std.HashMap String Term := {} + +/-- The monad for the incremental solver backend. -/ +abbrev IncrementalSolverM := StateT IncrementalSolverState IO + +namespace IncrementalSolver + +def emitln (str : String) : IncrementalSolverM Unit := do + let st ← get + st.solver.smtLibInput.putStr s!"{str}\n" + st.solver.smtLibInput.flush + +def readln : IncrementalSolverM String := do + let st ← get + match st.solver.smtLibOutput with + | .some stdout => return (← stdout.getLine).trimAscii.toString + | .none => throw (IO.userError "no output stream available") + +private def termToStr (t : Term) : IncrementalSolverM String := do + let st ← get + if let .some s := st.termStrings.get? t then return s + match Strata.SMTDDM.termToString t with + | .ok s => + modify fun st => { st with termStrings := st.termStrings.insert t s } + return s + | .error msg => throw (IO.userError s!"term serialization failed: {msg}") + +private def typeToStr (ty : TermType) : IncrementalSolverM String := do + let st ← get + if let .some s := st.typeStrings.get? ty then return s + match Strata.SMTDDM.termTypeToString ty with + | .ok s => + modify fun st => { st with typeStrings := st.typeStrings.insert ty s } + return s + | .error msg => throw (IO.userError s!"type serialization failed: {msg}") + +/-- Get the disambiguated SMT-LIB name for a variable, handling shadowing. -/ +private def disambiguatedName (name : String) (depth : Nat) : String := + if depth == 0 then name else s!"{name}@{depth}" + +/-- Spawn an incremental solver process. -/ +def spawn (path : String) (args : Array String) : IO IncrementalSolverState := do + let solver ← Solver.spawn path args + return { solver } + +/-- Shared helper for constructing quantified terms. -/ +private def mkQuantHelper (qk : QuantifierKind) + (bindings : List (String × TermType)) + (callback : List Term → IncrementalSolverM (Term × List (List Term))) + : IncrementalSolverM Term := do + let vars := bindings.map fun (name, ty) => TermVar.mk name ty + let varTerms := vars.map Term.var + let (body, triggers) ← callback varTerms + let tr := match triggers with + | [] => Term.app .triggers [] .trigger + | groups => + let triggerTerms := groups.map fun group => Term.app .triggers group .trigger + Term.app .triggers triggerTerms .trigger + return (Term.quant qk vars tr body) + +/-- Shared helper for binary comparison operations. -/ +private def mkBinCmp (op : Op) (opName : String) (ts : List Term) + : IncrementalSolverM Term := + match ts with + | [] | [_] => throw (IO.userError s!"{opName}: need at least two arguments") + | [t1, t2] => return (Term.app op [t1, t2] .bool) + | _ => throw (IO.userError s!"{opName}: pairwise comparison not yet supported") + +/-- Shared helper for variadic arithmetic operations. -/ +private def mkVarArith (op : Op) (opName : String) (ts : List Term) + : IncrementalSolverM Term := + match ts with + | [] => throw (IO.userError s!"{opName}: empty argument list") + | [t] => return t + | t :: rest => return (rest.foldl (fun acc x => Term.app op [acc, x] acc.typeOf) t) + +/-- Parse a solver check-sat response into a `Decision`. -/ +def parseDecision (line : String) : Except String Decision := + match line with + | "sat" => .ok .sat + | "unsat" => .ok .unsat + | "unknown" => .ok .unknown + | other => .error s!"unrecognized solver output: {other}" + +/-- Format datatype constructors as SMT-LIB strings. -/ +private def formatConstrs (constrs : List (String × List (String × TermType))) + : IncrementalSolverM (List String) := do + let mut result := [] + for (cname, fields) in constrs.reverse do + if fields.isEmpty then + result := s!"({cname})" :: result + else do + let mut fieldStrs := [] + for (fname, fty) in fields.reverse do + let tyStr ← typeToStr fty + fieldStrs := s!"({fname} {tyStr})" :: fieldStrs + result := s!"({cname} {String.intercalate " " fieldStrs})" :: result + return result + +/-- Construct the sort for a datatype given its name and type parameter names. -/ +private def mkDatatypeSort (name : String) (params : List String) : TermType × List TermType := + let paramSorts := params.map fun p => TermType.constr p [] + (.constr name paramSorts, paramSorts) + +/-- Build constructor/tester/selector handles for a list of constructors. -/ +private def mkConstructorHandles (selfSort : TermType) + (constrs : List (String × List (String × TermType))) + : List (DatatypeConstructorHandles Term) := + constrs.map fun (cname, fields) => + { constr := Term.app (.datatype_op .constructor cname) [] selfSort + tester := Term.app (.datatype_op .tester cname) [] .bool + selectors := fields.map fun (fname, fty) => + Term.app (.datatype_op .selector fname) [] fty } + +/-- Build the `AbstractSolver` implementation for incremental SMT-LIB. -/ +def mkIncrementalSolver : AbstractSolver Term TermType IncrementalSolverM where + setLogic logic := emitln s!"(set-logic {logic})" + setOption name value := emitln s!"(set-option :{name} {value})" + comment c := emitln s!"; {c.replace "\n" " "}" + + boolSort := return .bool + intSort := return .int + realSort := return .real + stringSort := return .string + regexSort := return .regex + bitvecSort n := return .bitvec n + arraySort k v := return .constr "Array" [k, v] + constrSort name args := return .constr name args + + mkBool b := return Term.bool b + mkInt i := return Term.int i + mkPrim p := return .prim p + mkAppOp op args retTy := return .app op args retTy + + mkAnd ts := return (ts.foldl Factory.and (Term.bool true)) + mkOr ts := return (ts.foldl Factory.or (Term.bool false)) + mkNot t := return (Factory.not t) + mkImplies t1 t2 := return (Factory.implies t1 t2) + + mkAdd ts := mkVarArith .add "mkAdd" ts + mkSub ts := mkVarArith .sub "mkSub" ts + mkMul ts := mkVarArith .mul "mkMul" ts + mkDiv t1 t2 := return (Term.app .div [t1, t2] t1.typeOf) + mkMod t1 t2 := return (Term.app .mod [t1, t2] t1.typeOf) + mkNeg t := return (Term.app .neg [t] t.typeOf) + mkAbs t := return (Term.app .abs [t] t.typeOf) + + mkEq ts := match ts with + | [] | [_] => throw (IO.userError "mkEq: need at least two arguments") + | [t1, t2] => return (Factory.eq t1 t2) + | t1 :: t2 :: rest => + return (rest.foldl (fun acc x => Factory.and acc (Factory.eq t1 x)) (Factory.eq t1 t2)) + mkLt ts := mkBinCmp .lt "mkLt" ts + mkLe ts := mkBinCmp .le "mkLe" ts + mkGt ts := mkBinCmp .gt "mkGt" ts + mkGe ts := mkBinCmp .ge "mkGe" ts + + mkIte c t f := return (Factory.ite c t f) + + mkSelect arr idx := return (Term.app .select [arr, idx] arr.typeOf) + mkStore arr idx val := return (Term.app .store [arr, idx, val] arr.typeOf) + mkApp fn args := match fn with + | .app (.uf uf) _ _ => return (Term.app (.uf uf) args uf.out) + | .app (.datatype_op kind name) _ retTy => return (Term.app (.datatype_op kind name) args retTy) + | _ => throw (IO.userError "mkApp: expected a function handle (uninterpreted function or datatype op)") + + declareNew name ty := do + let st ← get + let count := st.shadowCounts.getD name 0 + let smtName := disambiguatedName name count + set { st with shadowCounts := st.shadowCounts.insert name (count + 1) } + let tyStr ← typeToStr ty + emitln s!"(declare-const {quoteIdent smtName} {tyStr})" + return Term.var ⟨smtName, ty⟩ + + declareFun name argTys retTy := do + let retStr ← typeToStr retTy + if argTys.isEmpty then + emitln s!"(declare-const {quoteIdent name} {retStr})" + else + let mut argStrs := [] + for ty in argTys.reverse do + argStrs := (← typeToStr ty) :: argStrs + let inline := String.intercalate " " argStrs + emitln s!"(declare-fun {quoteIdent name} ({inline}) {retStr})" + return Term.var ⟨name, retTy⟩ + + defineFun name args retTy body := do + let retStr ← typeToStr retTy + let mut typedArgs := [] + for (n, ty) in args.reverse do + let tyStr ← typeToStr ty + typedArgs := s!"({quoteIdent n} {tyStr})" :: typedArgs + let inline := String.intercalate " " typedArgs + let bodyStr ← termToStr body + emitln s!"(define-fun {quoteIdent name} ({inline}) {retStr} {bodyStr})" + + declareSort name arity := do + emitln s!"(declare-sort {name} {arity})" + return (.constr name (List.replicate arity (.constr "_" []))) + + declareDatatype name params callback := do + let (selfSort, paramSorts) := mkDatatypeSort name params + match callback selfSort paramSorts with + | .error msg => throw (IO.userError msg) + | .ok constrs => + let strs ← formatConstrs constrs + let cInline := "\n " ++ String.intercalate "\n " strs + if params.isEmpty then + emitln s!"(declare-datatype {name} ({cInline}))" + else + let pInline := String.intercalate " " params + emitln s!"(declare-datatype {name} (par ({pInline}) ({cInline})))" + return { sort := selfSort, constructors := mkConstructorHandles selfSort constrs } + + declareDatatypes dts callback := do + if dts.isEmpty then return [] + let sortsAndParams := dts.map fun (name, params) => mkDatatypeSort name params + let selfSorts := sortsAndParams.map (·.1) + let paramSorts := sortsAndParams.map (·.2) + match callback selfSorts paramSorts with + | .error msg => throw (IO.userError msg) + | .ok allConstrs => + let sortDecls := dts.map fun (name, params) => s!"({name} {params.length})" + let sortDeclStr := String.intercalate " " sortDecls + let mut bodies := [] + for ((_, params), constrs) in (dts.zip allConstrs).reverse do + let strs ← formatConstrs constrs + let cInline := String.intercalate " " strs + if params.isEmpty then + bodies := s!"({cInline})" :: bodies + else + let pInline := String.intercalate " " params + bodies := s!"(par ({pInline}) ({cInline}))" :: bodies + let bodyStr := String.intercalate "\n " bodies + emitln s!"(declare-datatypes ({sortDeclStr})\n ({bodyStr}))" + return (selfSorts.zip allConstrs |>.map fun (sort, constrs) => + { sort, constructors := mkConstructorHandles sort constrs }) + + mkForall bindings callback := do + mkQuantHelper .all bindings callback + + mkExists bindings callback := do + mkQuantHelper .exist bindings callback + + assert t := do + let s ← termToStr t + emitln s!"(assert {s})" + + checkSat := do + emitln "(check-sat)" + let result ← readln + match parseDecision result with + | .ok d => return d + | .error msg => throw (IO.userError msg) + + checkSatAssuming assumptions := do + let mut strs := [] + let mut assumptionMap : Std.HashMap String Term := {} + for t in assumptions.reverse do + let s ← termToStr t + strs := s :: strs + assumptionMap := assumptionMap.insert s t + modify fun st => { st with lastAssumptions := assumptionMap } + let inline := String.intercalate " " strs + emitln s!"(check-sat-assuming ({inline}))" + let result ← readln + match parseDecision result with + | .ok d => return d + | .error msg => throw (IO.userError msg) + + getModel := throw (IO.userError "getModel: not yet implemented for incremental backend") + + getUnsatAssumptions := do + emitln "(get-unsat-assumptions)" + let response ← readln + -- Response is "(lit1 lit2 ...)" — strip parens and split + let inner := response.replace "(" "" |>.replace ")" "" + if inner.trimAscii.toString.isEmpty then return [] + let literals := inner.trimAscii.toString.splitOn " " |>.filter (!·.isEmpty) + let assumptionMap := (← get).lastAssumptions + let mut result := [] + for lit in literals.reverse do + match assumptionMap.get? lit with + | some t => result := t :: result + | none => throw (IO.userError s!"getUnsatAssumptions: unknown literal '{lit}'") + return result + + getValue ts := do + -- Send get-value command with the given terms + let mut strs := [] + for t in ts.reverse do + strs := (← termToStr t) :: strs + let inline := String.intercalate " " strs + emitln s!"(get-value ({inline}))" + -- Read the response (a single s-expression, possibly multi-line) + let mut modelOutput := "" + let mut reading := true + let mut parenDepth : Int := 0 + while reading do + let respLine ← readln + if respLine.isEmpty then + reading := false + else + modelOutput := modelOutput ++ respLine ++ "\n" + for c in respLine.toList do + if c == '(' then parenDepth := parenDepth + 1 + else if c == ')' then parenDepth := parenDepth - 1 + if parenDepth ≤ 0 then reading := false + -- Return the raw output as a single pair (the verifier parses it) + return [(Term.string modelOutput, Term.string modelOutput)] + + termToSMTLibString t := return (← termToStr t) + + reset := emitln "(reset)" + + close := emitln "(exit)" + +end IncrementalSolver + +end + +end Strata.SMT diff --git a/Strata/DL/SMT/SMT.lean b/Strata/DL/SMT/SMT.lean index 53161c9f56..67fbaf1093 100644 --- a/Strata/DL/SMT/SMT.lean +++ b/Strata/DL/SMT/SMT.lean @@ -5,9 +5,11 @@ -/ module +public import Strata.DL.SMT.AbstractSolver public import Strata.DL.SMT.Encoder public import Strata.DL.SMT.Factory public import Strata.DL.SMT.Function +public import Strata.DL.SMT.IncrementalSolver public import Strata.DL.SMT.Op public import Strata.DL.SMT.Solver public import Strata.DL.SMT.Term diff --git a/Strata/DL/SMT/Solver.lean b/Strata/DL/SMT/Solver.lean index 8b12459bf1..d49e2b8529 100644 --- a/Strata/DL/SMT/Solver.lean +++ b/Strata/DL/SMT/Solver.lean @@ -37,35 +37,55 @@ inductive Decision where deriving DecidableEq, Repr /-- - A Solver is an interpreter for SMTLib scripts, which are passed to the solver - via its `smtLibInput` stream. Solvers optionally have an `smtLibOutput` stream - to which they print the results of executing the commands received on the input - stream. We assume that both the input and output streams conform to the SMTLib - standard: the inputs are SMTLib script commands encoded as s-expressions, and - the outputs are the s-expressions whose shape is determined by the standard for - each command. We don't have an error stream here, since we configure solvers to - run in quiet mode and not print anything to the error stream. + An SMT-LIB solver process wrapper. + + An SMTLibSolver is an interpreter for SMTLib scripts, which are passed to the + solver via its `smtLibInput` stream. Solvers optionally have an `smtLibOutput` + stream to which they print the results of executing the commands received on + the input stream. We assume that both the input and output streams conform to + the SMTLib standard: the inputs are SMTLib script commands encoded as + s-expressions, and the outputs are the s-expressions whose shape is determined + by the standard for each command. We don't have an error stream here, since we + configure solvers to run in quiet mode and not print anything to the error + stream. -/ -structure Solver where +structure SMTLibSolver where smtLibInput : IO.FS.Stream smtLibOutput : Option IO.FS.Stream -/-- State tracked by `SolverM`: caches `Term → SMT-LIB string` and +/-- Backward-compatible alias for `SMTLibSolver`. -/ +abbrev Solver := SMTLibSolver + +/-- State tracked by `SMTLibSolverM`: caches `Term → SMT-LIB string` and `TermType → SMT-LIB string` conversions so that the same term/type is never converted twice. -/ -structure SolverState where +structure SMTLibSolverState where /-- Caches `Term → full SMT-LIB string` via `SMTDDM.termToString`. -/ termStrings : Std.HashMap Term String := {} /-- Caches `TermType → full SMT-LIB string` via `SMTDDM.termTypeToString`. -/ typeStrings : Std.HashMap TermType String := {} -def SolverState.init : SolverState := {} +def SMTLibSolverState.init : SMTLibSolverState := {} + +/-- Backward-compatible alias for `SMTLibSolverState`. -/ +abbrev SolverState := SMTLibSolverState + +/-- Backward-compatible alias. -/ +abbrev SolverState.init := SMTLibSolverState.init -@[expose] abbrev SolverM (α) := StateT SolverState (ReaderT Solver IO) α +/-- SMT-LIB solver monad. Renamed from `SolverM` to `SMTLibSolverM` + to distinguish from the abstract solver interface. -/ +@[expose] abbrev SMTLibSolverM (α) := StateT SolverState (ReaderT Solver IO) α + +/-- Backward-compatible alias for `SMTLibSolverM`. -/ +abbrev SolverM := SMTLibSolverM def SolverM.run (solver : Solver) (x : SolverM α) (state : SolverState := SolverState.init) : IO (α × SolverState) := ReaderT.run (StateT.run x state) solver +/-- Alias for `SolverM.run`. -/ +abbrev SMTLibSolverM.run := @SolverM.run + /-- A typed SMT-LIB datatype constructor: name + typed fields. -/ structure SMTConstructor where name : String @@ -75,11 +95,11 @@ deriving Repr, Inhabited namespace Solver /-- - Returns a Solver for the given path and arguments. This function expects - `path` to point to an SMT solver executable, and `args` to specify valid - arguments to that solver. + Returns an SMTLibSolver for the given path and arguments. This function + expects `path` to point to an SMT solver executable, and `args` to specify + valid arguments to that solver. -/ -def spawn (path : String) (args : Array String) : IO Solver := do +def spawn (path : String) (args : Array String) : IO SMTLibSolver := do try let proc ← IO.Process.spawn { stdin := .piped @@ -97,7 +117,7 @@ def spawn (path : String) (args : Array String) : IO Solver := do Returns an instance of the solver that is backed by the executable specified in the environment variable "SOLVER". -/ -def solver : IO Solver := do +def solver : IO SMTLibSolver := do match (← IO.getEnv "SOLVER") with | .some path => spawn path ["--quiet", "--lang", "smt"].toArray | .none => throw (IO.userError "SOLVER environment variable not defined.") @@ -109,7 +129,7 @@ def solver : IO Solver := do useful). For example, `Solver.checkSat` returns `Decision.unknown`. This function expects `h` to be write-enabled. -/ -def fileWriter (h : IO.FS.Handle) : IO Solver := +def fileWriter (h : IO.FS.Handle) : IO SMTLibSolver := return ⟨IO.FS.Stream.ofHandle h, .none⟩ /-- @@ -118,7 +138,7 @@ def fileWriter (h : IO.FS.Handle) : IO Solver := return values that are sound according to the SMTLIb spec (but generally not useful). For example, `Solver.checkSat` returns `Decision.unknown`. -/ -def bufferWriter (b : IO.Ref IO.FS.Stream.Buffer) : IO Solver := +def bufferWriter (b : IO.Ref IO.FS.Stream.Buffer) : IO SMTLibSolver := return ⟨IO.FS.Stream.ofBuffer b, .none⟩ /-! ## Internal helpers -/ @@ -128,7 +148,7 @@ private def emitln (str : String) : SolverM Unit := do solver.smtLibInput.putStr s!"{str}\n" solver.smtLibInput.flush -/-- Convert a `Term` to its SMT-LIB string, using the `SolverState` cache. -/ +/-- Convert a `Term` to its SMT-LIB string, using the `SMTLibSolverState` cache. -/ def termToSMTString (t : Term) : SolverM String := do if let (.some s) := (← get).termStrings.get? t then return s match Strata.SMTDDM.termToString t with @@ -137,7 +157,7 @@ def termToSMTString (t : Term) : SolverM String := do return s | .error msg => throw (IO.userError s!"Solver.termToSMTString failed: {msg}") -/-- Convert a `TermType` to its SMT-LIB string, using the `SolverState` cache. -/ +/-- Convert a `TermType` to its SMT-LIB string, using the `SMTLibSolverState` cache. -/ def typeToSMTString (ty : TermType) : SolverM String := do if let (.some s) := (← get).typeStrings.get? ty then return s match Strata.SMTDDM.termTypeToString ty with diff --git a/Strata/DL/SMT/Translate.lean b/Strata/DL/SMT/Translate.lean index c6de9007fa..1b761271b0 100644 --- a/Strata/DL/SMT/Translate.lean +++ b/Strata/DL/SMT/Translate.lean @@ -369,8 +369,12 @@ def translateTerm (t : SMT.Term) : TranslateM (Expr × Expr) := do leftAssocOp mkIntMul as | .app .div as _ => leftAssocOp mkIntDiv as + | .app .mod [x, y] _ => + let (α, x) ← translateTerm x + let (_, y) ← translateTerm y + return (α, mkApp2 mkIntMod x y) | .app .mod as _ => - leftAssocOp mkIntMod as + throw m!"Error: 'mod' expects exactly two operands, got '{as.length}'" | .app .abs [a] _ => let (_, a) ← translateTerm a let c := mkApp2 mkIntLT a (toExpr (0 : Int)) @@ -545,16 +549,20 @@ def translateTerm (t : SMT.Term) : TranslateM (Expr × Expr) := do | t => throw m!"Error: unsupported term '{repr t}'" where leftAssocOp (op : Expr) (as : List SMT.Term) : TranslateM (Expr × Expr) := do - let a :: as := as | throw m!"Error: expected at least two arguments for '{op}', got '{as.length}'" + let a :: b :: as := as + | throw m!"Error: expected at least two arguments for '{op}', got '{as.length}'" let (α, a) ← translateTerm a + let (_, b) ← translateTerm b let as ← as.mapM (translateTerm · >>= pure ∘ Prod.snd) - return (α, as.foldl (mkApp2 op) a) + return (α, as.foldl (mkApp2 op) (mkApp2 op a b)) leftAssocOpBitVec (op : Nat → Expr) (as : List SMT.Term) : TranslateM (Expr × Expr) := do - let a :: as := as | throw m!"Error: expected at least two arguments for BitVec op, got '{as.length}'" + let a :: b :: as := as + | throw m!"Error: expected at least two arguments for BitVec op, got '{as.length}'" let (α, a) ← translateTerm a + let (_, b) ← translateTerm b let op := op (← getBitVecWidth α) let as ← as.mapM (translateTerm · >>= pure ∘ Prod.snd) - return (α, as.foldl (mkApp2 op) a) + return (α, as.foldl (mkApp2 op) (mkApp2 op a b)) /-- Translate assumptions and a conclusion into a right-associated implication diff --git a/Strata/Languages/Boole/Verify.lean b/Strata/Languages/Boole/Verify.lean index 927f39fc17..5240629436 100644 --- a/Strata/Languages/Boole/Verify.lean +++ b/Strata/Languages/Boole/Verify.lean @@ -469,9 +469,11 @@ partial def toCoreExpr (e : Boole.Expr) : TranslateM Core.Expression.Expr := do let intSub : Core.Expression.Expr := .op () ⟨"Int.Sub", ()⟩ none return mkCoreApp Core.seqTakeOp [mkCoreApp Core.seqDropOp [s', lo'], mkCoreApp intSub [hi', lo']] - -- Typed empty-sequence constant (Sequence.empty for bv32; other types can be added when needed). - | .seq_empty_bv8 _ | .seq_empty_bv16 _ | .seq_empty_bv32 _ - | .seq_empty_bv64 _ | .seq_empty_int _ => return Core.seqEmptyOp + | .seq_empty_bv8 _ => return Core.seqEmptyOp (some (.bitvec 8)) + | .seq_empty_bv16 _ => return Core.seqEmptyOp (some (.bitvec 16)) + | .seq_empty_bv32 _ => return Core.seqEmptyOp (some (.bitvec 32)) + | .seq_empty_bv64 _ => return Core.seqEmptyOp (some (.bitvec 64)) + | .seq_empty_int _ => return Core.seqEmptyOp (some .int) -- Sequence literals: Sequence.of_bv32[v0, v1, ..., vn] -- Lowers to a left-fold of seq_build over seq_empty. | .seq_of_bv8 _ ⟨_, vs⟩ | .seq_of_bv16 _ ⟨_, vs⟩ | .seq_of_bv32 _ ⟨_, vs⟩ diff --git a/Strata/Languages/Core/DDMTransform/ASTtoCST.lean b/Strata/Languages/Core/DDMTransform/ASTtoCST.lean index b74f973594..805805f413 100644 --- a/Strata/Languages/Core/DDMTransform/ASTtoCST.lean +++ b/Strata/Languages/Core/DDMTransform/ASTtoCST.lean @@ -195,7 +195,9 @@ def funcToCST {M} [Inhabited M] -- Convert preconditions let preconds ← precondsToSpecElts func.preconditions let bodyExpr ← lexprToExpr body 0 - let inline? : Ann (Option (Inline M)) M := ⟨default, none⟩ + let inline? : Ann (Option (Inline M)) M := + if func.attr.any (· == .inline) then ⟨default, some (.inline default)⟩ + else ⟨default, none⟩ pure (.command_fndef default name typeArgs b r preconds bodyExpr inline?) modify ToCSTContext.popScope -- Register function name as free variable. diff --git a/Strata/Languages/Core/DDMTransform/FormatCore.lean b/Strata/Languages/Core/DDMTransform/FormatCore.lean index 0af86ae9b8..0bc517fd8a 100644 --- a/Strata/Languages/Core/DDMTransform/FormatCore.lean +++ b/Strata/Languages/Core/DDMTransform/FormatCore.lean @@ -37,8 +37,6 @@ Known issues: translation in the latter's metadata field and recover them in the future. - Misc. formatting issues - -- Remove extra parentheses around constructors in datatypes, assignments, - etc. -- Remove extra indentation from the last brace of a block or the `end` keyword of a mutual block. -/ @@ -295,11 +293,21 @@ def handleZeroaryOps {M} [Inhabited M] (name : String) | .re .All => pure (.re_all default) | .re .AllChar => pure (.re_allchar default) | .re .None => pure (.re_none default) - -- TODO: seq_empty is not yet parseable (see Grammar.lean); handle here when added. | _ => do ToCSTM.logError "lopToExpr" "0-ary op not found" name pure (.re_none default) +/-- Convert a bitvector width to the corresponding CoreType, logging an error and + falling back to bv64 for unsupported widths. -/ +def bvTypeOfWidth {M} [Inhabited M] (caller : String) (w : Nat) : ToCSTM M (CoreType M) := + match w with + | 1 => pure (CoreType.bv1 default) | 8 => pure (.bv8 default) + | 16 => pure (.bv16 default) | 32 => pure (.bv32 default) + | 64 => pure (.bv64 default) + | _ => do + ToCSTM.logError caller s!"unsupported BV width {w}" (toString w) + pure (.bv64 default) + /-- Handle unary operations -/ def handleUnaryOps {M} [Inhabited M] (name : String) (arg : CoreDDM.Expr M) : ToCSTM M (CoreDDM.Expr M) := @@ -338,8 +346,13 @@ def handleUnaryOps {M} [Inhabited M] (name : String) (arg : CoreDDM.Expr M) | .bv ⟨16, .SafeNeg⟩ | .bv ⟨16, .SafeUNeg⟩ => pure (.safeneg_expr default (.bv16 default) arg) | .bv ⟨32, .SafeNeg⟩ | .bv ⟨32, .SafeUNeg⟩ => pure (.safeneg_expr default (.bv32 default) arg) | .bv ⟨64, .SafeNeg⟩ | .bv ⟨64, .SafeUNeg⟩ => pure (.safeneg_expr default (.bv64 default) arg) - -- Overflow predicates: approximated as Bool.Not for CST printing - | .bv ⟨_, .SNegOverflow⟩ | .bv ⟨_, .UNegOverflow⟩ => pure (.not default arg) + -- Overflow predicates + | .bv ⟨w, .SNegOverflow⟩ => do + let bvTy ← bvTypeOfWidth "handleUnaryOps" w + pure (.bv_neg_overflow default bvTy arg) + | .bv ⟨w, .UNegOverflow⟩ => do + let bvTy ← bvTypeOfWidth "handleUnaryOps" w + pure (.bv_uneg_overflow default bvTy arg) -- Bitvector extract ops | .bvExtract 8 7 7 => pure (.bvextract_7_7 default arg) | .bvExtract 16 15 15 => pure (.bvextract_15_15 default arg) @@ -386,14 +399,14 @@ def bvBinaryOpMap {M} [Inhabited M] : (.SafeUAdd, fun ty arg1 arg2 => .safeadd_expr default ty arg1 arg2), (.SafeUSub, fun ty arg1 arg2 => .safesub_expr default ty arg1 arg2), (.SafeUMul, fun ty arg1 arg2 => .safemul_expr default ty arg1 arg2), - -- Overflow predicates: approximated as boolean ops for CST printing - (.SAddOverflow, fun _ty arg1 arg2 => .le default _ty arg1 arg2), - (.SSubOverflow, fun _ty arg1 arg2 => .le default _ty arg1 arg2), - (.SMulOverflow, fun _ty arg1 arg2 => .le default _ty arg1 arg2), - (.SDivOverflow, fun _ty arg1 arg2 => .le default _ty arg1 arg2), - (.UAddOverflow, fun _ty arg1 arg2 => .le default _ty arg1 arg2), - (.USubOverflow, fun _ty arg1 arg2 => .le default _ty arg1 arg2), - (.UMulOverflow, fun _ty arg1 arg2 => .le default _ty arg1 arg2) + -- Overflow predicates + (.SAddOverflow, fun ty arg1 arg2 => .bv_sadd_overflow default ty arg1 arg2), + (.SSubOverflow, fun ty arg1 arg2 => .bv_ssub_overflow default ty arg1 arg2), + (.SMulOverflow, fun ty arg1 arg2 => .bv_smul_overflow default ty arg1 arg2), + (.SDivOverflow, fun ty arg1 arg2 => .bv_sdiv_overflow default ty arg1 arg2), + (.UAddOverflow, fun ty arg1 arg2 => .bv_uadd_overflow default ty arg1 arg2), + (.USubOverflow, fun ty arg1 arg2 => .bv_usub_overflow default ty arg1 arg2), + (.UMulOverflow, fun ty arg1 arg2 => .bv_umul_overflow default ty arg1 arg2) ] /-- Map from bitvector sizes to their corresponding type constructors -/ @@ -551,11 +564,19 @@ partial def lexprToExpr {M} [Inhabited M] pure (.fvar default (ctx.allFreeVars.size)) | .ite _ c t f => liteToExpr c t f qLevel | .eq _ e1 e2 => leqToExpr e1 e2 qLevel - | .op _ name _ => lopToExpr name.name [] + | .op _ name ty => do + -- seq_empty needs the type annotation to render the explicit type parameter + if name.name == "Sequence.empty" then + let tyCST ← match ty with + | some (.tcons "Sequence" [ety]) => lmonoTyToCoreType ety + | _ => pure (CoreType.tvar default unknownTypeVar) + pure (.seq_empty default tyCST) + else + lopToExpr name.name [] | .app _ _ _ => lappToExpr e qLevel | .abs _ prettyName ty body => labsToExpr prettyName ty body (qLevel + 1) - | .quant _ qkind _ ty trigger body => - lquantToExpr qkind ty trigger body (qLevel + 1) + | .quant _ qkind prettyName ty trigger body => + lquantToExpr qkind prettyName ty trigger body (qLevel + 1) /-- Extract trigger patterns from Lambda's trigger expression representation -/ partial def extractTriggerPatterns {M} [Inhabited M] @@ -609,11 +630,13 @@ partial def labsToExpr {M} [Inhabited M] pure (.lambda default tyExpr dl bodyExpr) partial def lquantToExpr {M} [Inhabited M] - (qkind : Lambda.QuantifierKind) (ty : Option Lambda.LMonoTy) + (qkind : Lambda.QuantifierKind) (prettyName : String) + (ty : Option Lambda.LMonoTy) (trigger : Lambda.LExpr CoreLParams.mono) (body : Lambda.LExpr CoreLParams.mono) (qLevel : Nat) : ToCSTM M (CoreDDM.Expr M) := do - let name : Ann String M := ⟨default, mkQuantVarName (qLevel - 1)⟩ + let varName := if prettyName.isEmpty then mkQuantVarName (qLevel - 1) else prettyName + let name : Ann String M := ⟨default, varName⟩ modify ToCSTContext.pushScope modify (·.addScopedBoundVars #[name.val]) let tyExpr ← match ty with @@ -661,23 +684,23 @@ partial def leqToExpr {M} [Inhabited M] partial def lappToExpr {M} [Inhabited M] (e : Lambda.LExpr CoreLParams.mono) - (qLevel : Nat) (acc : List (CoreDDM.Expr M) := []) - : ToCSTM M (CoreDDM.Expr M) := - match e with - | .app _ (.app m fn e1) e2 => do - let e2Expr ← lexprToExpr e2 qLevel - lappToExpr (.app m fn e1) qLevel (e2Expr :: acc) - | .app _ (.op _ fn _) e1 => do - let e1Expr ← lexprToExpr e1 qLevel - lopToExpr fn.name (e1Expr :: acc) - | .app _ fn e1 => do + (qLevel : Nat) + : ToCSTM M (CoreDDM.Expr M) := do + let (head, args) := Lambda.getLFuncCall e + match head with + | .op _ fn _ => + let argExprs ← args.mapM (lexprToExpr · qLevel) + lopToExpr fn.name argExprs + | .app _ fn arg => + -- getLFuncCall couldn't decompose further (fn is not .app or .op) let fnCST ← lexprToExpr fn qLevel - let e1Expr ← lexprToExpr e1 qLevel - pure <| (e1Expr :: acc).foldl (fun fnAcc arg => .app default fnAcc arg) fnCST - | _ => do - -- Non-application head (e.g. lambda applied to arguments) - let eCST ← lexprToExpr e qLevel - pure <| acc.foldl (fun fnAcc arg => .app default fnAcc arg) eCST + let argCST ← lexprToExpr arg qLevel + let argExprs ← args.mapM (lexprToExpr · qLevel) + pure <| (argCST :: argExprs).foldl (fun fnAcc a => .app default fnAcc a) fnCST + | _ => + let fnCST ← lexprToExpr head qLevel + let argExprs ← args.mapM (lexprToExpr · qLevel) + pure <| argExprs.foldl (fun fnAcc arg => .app default fnAcc arg) fnCST end /-- Convert preconditions to CST spec elements -/ @@ -717,7 +740,9 @@ def funcDeclToStatement {M} [Inhabited M] (decl : Imperative.PureFunc Expression let paramNames := results.map (·.2) let b : Bindings M := .mkBindings default ⟨default, bindings⟩ let r ← lTyToCoreType decl.output - let inline? : Ann (Option (Inline M)) M := ⟨default, none⟩ + let inline? : Ann (Option (Inline M)) M := + if decl.attr.any (· == .inline) then ⟨default, some (.inline default)⟩ + else ⟨default, none⟩ -- Add formals to the context modify (·.addScopedBoundVars (reverse? := false) paramNames) -- Convert preconditions @@ -735,6 +760,24 @@ def funcDeclToStatement {M} [Inhabited M] (decl : Imperative.PureFunc Expression modify (·.pushBoundVar name.val) pure (.funcDecl_statement default name typeArgs b r preconds bodyExpr inline?) +/-- Decompose a single-level `map_update(base, idx, val)` where `base` is (or starts + with) an fvar matching `varName`. Returns `(indices, innerVal)` with indices + in left-to-right order, or `none` if the expression is not this pattern. -/ +private def decomposeMapUpdate (varName : String) + (e : Lambda.LExpr CoreLParams.mono) + : Option (List (Lambda.LExpr CoreLParams.mono) × Lambda.LExpr CoreLParams.mono) := + let (head, args) := Lambda.getLFuncCall e + match head, args with + | .op _ opName _, [base, idx, val] => + if opName.name == "update" then + match base with + | .fvar _ ident _ => + if ident.name == varName then some ([idx], val) + else none + | _ => none + else none + | _, _ => none + mutual /-- Convert `Core.Statement` to `CoreDDM.Statement` -/ partial def stmtToCST {M} [Inhabited M] (s : Core.Statement) @@ -758,9 +801,20 @@ partial def stmtToCST {M} [Inhabited M] (s : Core.Statement) modify (·.pushBoundVar name.toPretty) pure result | .set name expr _md => do - let lhs := Lhs.lhsIdent default ⟨default, name.name⟩ - let exprCST ← lexprToExpr expr 0 - -- Type annotation required by CST but not semantically used. + -- Detect map_update(name, idx, val) pattern to produce lhsArray syntax + let (lhs, exprCST) ← match decomposeMapUpdate name.name expr with + | some (idxs, val) => do + let baseLhs := Lhs.lhsIdent default ⟨default, name.name⟩ + let lhs ← idxs.foldlM (init := baseLhs) fun acc idx => do + let idxCST ← lexprToExpr idx 0 + let tyCST := CoreType.tvar default unknownTypeVar + pure (Lhs.lhsArray default tyCST acc idxCST) + let valCST ← lexprToExpr val 0 + pure (lhs, valCST) + | none => do + let lhs := Lhs.lhsIdent default ⟨default, name.name⟩ + let exprCST ← lexprToExpr expr 0 + pure (lhs, exprCST) let tyCST := CoreType.tvar default unknownTypeVar pure (.assign default tyCST lhs exprCST) | .havoc name _md => do diff --git a/Strata/Languages/Core/DDMTransform/Grammar.lean b/Strata/Languages/Core/DDMTransform/Grammar.lean index 206b72e7d5..ae29f0cfc9 100644 --- a/Strata/Languages/Core/DDMTransform/Grammar.lean +++ b/Strata/Languages/Core/DDMTransform/Grammar.lean @@ -103,9 +103,9 @@ fn map_get (K : Type, V : Type, m : Map K V, k : K) : V => m "[" k "]"; fn map_set (K : Type, V : Type, m : Map K V, k : K, v : V) : Map K V => m "[" k ":=" v "]"; -// TODO: seq_empty is not yet supported in the grammar because the DDM parser -// cannot currently handle 0-ary polymorphic functions (no arguments to infer -// the type parameter from). The Factory definition exists for programmatic use. +// seq_empty uses explicit type annotation syntax since there are no value +// arguments to infer the type parameter from. +fn seq_empty (A : Type) : Sequence A => "Sequence.empty" "<" A ">" "(" ")"; fn seq_length (A : Type, s : Sequence A) : int => "Sequence.length" "(" s ")"; fn seq_select (A : Type, s : Sequence A, i : int) : A => "Sequence.select" "(" s ", " i ")"; fn seq_append (A : Type, s1 : Sequence A, s2 : Sequence A) : Sequence A => @@ -188,6 +188,16 @@ fn bvsle (tp : Type, a : tp, b : tp) : bool => @[prec(20), leftassoc] a " <=s " fn bvsgt (tp : Type, a : tp, b : tp) : bool => @[prec(20), leftassoc] a " >s " b; fn bvsge (tp : Type, a : tp, b : tp) : bool => @[prec(20), leftassoc] a " >=s " b; +fn bv_neg_overflow (tp : Type, a : tp) : bool => "Bv.SNegOverflow" "(" a ")"; +fn bv_uneg_overflow (tp : Type, a : tp) : bool => "Bv.UNegOverflow" "(" a ")"; +fn bv_sadd_overflow (tp : Type, a : tp, b : tp) : bool => "Bv.SAddOverflow" "(" a ", " b ")"; +fn bv_ssub_overflow (tp : Type, a : tp, b : tp) : bool => "Bv.SSubOverflow" "(" a ", " b ")"; +fn bv_smul_overflow (tp : Type, a : tp, b : tp) : bool => "Bv.SMulOverflow" "(" a ", " b ")"; +fn bv_sdiv_overflow (tp : Type, a : tp, b : tp) : bool => "Bv.SDivOverflow" "(" a ", " b ")"; +fn bv_uadd_overflow (tp : Type, a : tp, b : tp) : bool => "Bv.UAddOverflow" "(" a ", " b ")"; +fn bv_usub_overflow (tp : Type, a : tp, b : tp) : bool => "Bv.USubOverflow" "(" a ", " b ")"; +fn bv_umul_overflow (tp : Type, a : tp, b : tp) : bool => "Bv.UMulOverflow" "(" a ", " b ")"; + fn bvconcat8 (a : bv8, b : bv8) : bv16 => "bvconcat{8}{8}" "(" a ", " b ")"; fn bvconcat16 (a : bv16, b : bv16) : bv32 => "bvconcat{16}{16}" "(" a ", " b ")"; fn bvconcat32 (a : bv32, b : bv32) : bv64 => "bvconcat{32}{32}" "(" a ", " b ")"; @@ -354,7 +364,7 @@ op command_fndecl (name : Ident, "function " name typeArgs b " : " r ";\n"; category Inline; -op inline () : Inline => "inline"; +op inline () : Inline => "inline "; // Note: when editing command_fndef, consider whether recfn_decl needs // matching edits. @@ -451,8 +461,11 @@ op datatype_decl (name : Ident, // Unified datatype command: one or more datatype declarations separated by // newlines, ending with a semicolon. +// +// `@[nonempty]` is load-bearing: see +// https://github.com/strata-org/Strata/issues/1146. @[scope(datatypes), preRegisterTypes(datatypes)] -op command_datatypes (datatypes : NewlineSepBy DatatypeDecl) : Command => +op command_datatypes (@[nonempty] datatypes : NewlineSepBy DatatypeDecl) : Command => datatypes ";\n"; #end diff --git a/Strata/Languages/Core/DDMTransform/Translate.lean b/Strata/Languages/Core/DDMTransform/Translate.lean index 8cd2563698..860e36a841 100644 --- a/Strata/Languages/Core/DDMTransform/Translate.lean +++ b/Strata/Languages/Core/DDMTransform/Translate.lean @@ -355,14 +355,6 @@ def translateTypeDecl (bindings : TransBindings) (op : Operation) : --------------------------------------------------------------------- -def translateLhs (arg : Arg) : TransM Core.CoreIdent := do - let .op op := arg - | TransM.error s!"translateLhs expected op {repr arg}" - match op.name, op.args with - | q`Core.lhsIdent, #[id] => translateIdent Core.CoreIdent id - -- (TODO) Implement lhsArray. - | _, _ => TransM.error s!"translateLhs: unimplemented for {repr arg}" - def translateBindMk (bindings : TransBindings) (arg : Arg) : TransM (Core.CoreIdent × List TyIdentifier × LMonoTy) := do let .op op := arg @@ -450,8 +442,8 @@ partial def dealiasTypeExpr (p : Program) (te : TypeExpr) : TypeExpr := match te with | (.fvar _ idx #[]) => match p.globalContext.kindOf! idx with - | .expr te => te - | .type [] (.some te) => te + | .expr te => dealiasTypeExpr p te + | .type [] (.some te) => dealiasTypeExpr p te | _ => te | _ => te @@ -664,6 +656,52 @@ def translateFn (ty? : Option LMonoTy) (q : QualifiedIdent) : TransM Core.Expres | _, q`Core.bvextract_15_0_64 => return Core.bv64Extract_15_0_Op | _, q`Core.bvextract_31_0_64 => return Core.bv64Extract_31_0_Op + | .some .bv1, q`Core.bv_neg_overflow => return Core.bv1SNegOverflowOp + | .some .bv1, q`Core.bv_uneg_overflow => return Core.bv1UNegOverflowOp + | .some .bv1, q`Core.bv_sadd_overflow => return Core.bv1SAddOverflowOp + | .some .bv1, q`Core.bv_ssub_overflow => return Core.bv1SSubOverflowOp + | .some .bv1, q`Core.bv_smul_overflow => return Core.bv1SMulOverflowOp + | .some .bv1, q`Core.bv_sdiv_overflow => return Core.bv1SDivOverflowOp + | .some .bv1, q`Core.bv_uadd_overflow => return Core.bv1UAddOverflowOp + | .some .bv1, q`Core.bv_usub_overflow => return Core.bv1USubOverflowOp + | .some .bv1, q`Core.bv_umul_overflow => return Core.bv1UMulOverflowOp + | .some .bv8, q`Core.bv_neg_overflow => return Core.bv8SNegOverflowOp + | .some .bv8, q`Core.bv_uneg_overflow => return Core.bv8UNegOverflowOp + | .some .bv8, q`Core.bv_sadd_overflow => return Core.bv8SAddOverflowOp + | .some .bv8, q`Core.bv_ssub_overflow => return Core.bv8SSubOverflowOp + | .some .bv8, q`Core.bv_smul_overflow => return Core.bv8SMulOverflowOp + | .some .bv8, q`Core.bv_sdiv_overflow => return Core.bv8SDivOverflowOp + | .some .bv8, q`Core.bv_uadd_overflow => return Core.bv8UAddOverflowOp + | .some .bv8, q`Core.bv_usub_overflow => return Core.bv8USubOverflowOp + | .some .bv8, q`Core.bv_umul_overflow => return Core.bv8UMulOverflowOp + | .some .bv16, q`Core.bv_neg_overflow => return Core.bv16SNegOverflowOp + | .some .bv16, q`Core.bv_uneg_overflow => return Core.bv16UNegOverflowOp + | .some .bv16, q`Core.bv_sadd_overflow => return Core.bv16SAddOverflowOp + | .some .bv16, q`Core.bv_ssub_overflow => return Core.bv16SSubOverflowOp + | .some .bv16, q`Core.bv_smul_overflow => return Core.bv16SMulOverflowOp + | .some .bv16, q`Core.bv_sdiv_overflow => return Core.bv16SDivOverflowOp + | .some .bv16, q`Core.bv_uadd_overflow => return Core.bv16UAddOverflowOp + | .some .bv16, q`Core.bv_usub_overflow => return Core.bv16USubOverflowOp + | .some .bv16, q`Core.bv_umul_overflow => return Core.bv16UMulOverflowOp + | .some .bv32, q`Core.bv_neg_overflow => return Core.bv32SNegOverflowOp + | .some .bv32, q`Core.bv_uneg_overflow => return Core.bv32UNegOverflowOp + | .some .bv32, q`Core.bv_sadd_overflow => return Core.bv32SAddOverflowOp + | .some .bv32, q`Core.bv_ssub_overflow => return Core.bv32SSubOverflowOp + | .some .bv32, q`Core.bv_smul_overflow => return Core.bv32SMulOverflowOp + | .some .bv32, q`Core.bv_sdiv_overflow => return Core.bv32SDivOverflowOp + | .some .bv32, q`Core.bv_uadd_overflow => return Core.bv32UAddOverflowOp + | .some .bv32, q`Core.bv_usub_overflow => return Core.bv32USubOverflowOp + | .some .bv32, q`Core.bv_umul_overflow => return Core.bv32UMulOverflowOp + | .some .bv64, q`Core.bv_neg_overflow => return Core.bv64SNegOverflowOp + | .some .bv64, q`Core.bv_uneg_overflow => return Core.bv64UNegOverflowOp + | .some .bv64, q`Core.bv_sadd_overflow => return Core.bv64SAddOverflowOp + | .some .bv64, q`Core.bv_ssub_overflow => return Core.bv64SSubOverflowOp + | .some .bv64, q`Core.bv_smul_overflow => return Core.bv64SMulOverflowOp + | .some .bv64, q`Core.bv_sdiv_overflow => return Core.bv64SDivOverflowOp + | .some .bv64, q`Core.bv_uadd_overflow => return Core.bv64UAddOverflowOp + | .some .bv64, q`Core.bv_usub_overflow => return Core.bv64USubOverflowOp + | .some .bv64, q`Core.bv_umul_overflow => return Core.bv64UMulOverflowOp + | _, q`Core.str_len => return Core.strLengthOp | _, q`Core.str_concat => return Core.strConcatOp | _, q`Core.str_substr => return Core.strSubstrOp @@ -845,6 +883,13 @@ partial def translateExpr (p : Program) (bindings : TransBindings) (arg : Arg) : | .fn _ q`Core.re_all, [] => let fn ← translateFn .none q`Core.re_all return fn + -- Sequence.empty (1 type arg, 0 value args) + | .fn _ q`Core.seq_empty, [_atp] => + let ety ← translateLMonoTy bindings _atp + let fn : LExpr Core.CoreLParams.mono := + Core.coreOpExpr (.seq .Empty) + (.some (Core.seqTy ety)) + return fn -- Unary function applications | .fn _ fni, [xa] => match fni with @@ -877,6 +922,16 @@ partial def translateExpr (p : Program) (bindings : TransBindings) (arg : Arg) : let fn ← translateFn ty q`Core.safeneg_expr let x ← translateExpr p bindings xa return .mkApp () fn [x] + | .fn _ q`Core.bv_neg_overflow, [tpa, xa] => + let ty ← translateLMonoTy bindings (dealiasTypeArg p tpa) + let fn ← translateFn ty q`Core.bv_neg_overflow + let x ← translateExpr p bindings xa + return .mkApp () fn [x] + | .fn _ q`Core.bv_uneg_overflow, [tpa, xa] => + let ty ← translateLMonoTy bindings (dealiasTypeArg p tpa) + let fn ← translateFn ty q`Core.bv_uneg_overflow + let x ← translateExpr p bindings xa + return .mkApp () fn [x] -- Strings | .fn _ q`Core.str_concat, [xa, ya] => let x ← translateExpr p bindings xa @@ -910,7 +965,6 @@ partial def translateExpr (p : Program) (bindings : TransBindings) (arg : Arg) : let x ← translateExpr p bindings xa return .mkApp () fn [m, i, x] -- Seq operations - -- TODO: seq_empty is not yet parseable (see Grammar.lean); handle here when added. | .fn _ q`Core.seq_length, [_atp, sa] => let ety ← translateLMonoTy bindings _atp let fn : LExpr Core.CoreLParams.mono := @@ -1042,7 +1096,14 @@ partial def translateExpr (p : Program) (bindings : TransBindings) (arg : Arg) : | q`Core.bvsle | q`Core.bvslt | q`Core.bvsgt - | q`Core.bvsge => + | q`Core.bvsge + | q`Core.bv_sadd_overflow + | q`Core.bv_ssub_overflow + | q`Core.bv_smul_overflow + | q`Core.bv_sdiv_overflow + | q`Core.bv_uadd_overflow + | q`Core.bv_usub_overflow + | q`Core.bv_umul_overflow => let ty ← translateLMonoTy bindings (dealiasTypeArg p tpa) if ¬ isArithTy ty then TransM.error s!"translateExpr unexpected type for {repr fni}: {repr args}" @@ -1225,6 +1286,36 @@ private def translateCondBool (p : Program) (bindings : TransBindings) (a : Arg) | q`Core.condDet, #[ca] => pure (.det (← translateExpr p bindings ca)) | _, _ => TransM.error s!"translateCondBool: unexpected {repr op.name}" +/-- Build a nested map-update expression: `nestMapUpdate base [i1, i2] v` produces + `map_update(base, i1, map_update(map_select(base, i1), i2, v))`. -/ +private def nestMapUpdate (base : Core.Expression.Expr) (idxs : List Core.Expression.Expr) + (rhs : Core.Expression.Expr) : Core.Expression.Expr := + let selectOp := Core.coreOpExpr (.map .Select) + let updateOp := Core.coreOpExpr (.map .Update) + match idxs with + | [] => rhs + | [i] => .mkApp () updateOp [base, i, rhs] + | i :: rest => + let inner := .mkApp () selectOp [base, i] + let updatedInner := nestMapUpdate inner rest rhs + .mkApp () updateOp [base, i, updatedInner] + +/-- Decompose an LHS into a base identifier and a (reversed) list of index + expressions. For `m[k1][k2]`, returns `(m, [k2, k1])`. -/ +partial def translateLhsParts (p : Program) (bindings : TransBindings) (arg : Arg) : + TransM (Core.CoreIdent × List Core.Expression.Expr) := do + let .op op := arg + | TransM.error s!"translateLhsParts expected op {repr arg}" + match op.name, op.args with + | q`Core.lhsIdent, #[id] => + let ident ← translateIdent Core.CoreIdent id + return (ident, []) + | q`Core.lhsArray, #[_tpa, lhsa, idxa] => + let (ident, idxsRev) ← translateLhsParts p bindings lhsa + let idx ← translateExpr p bindings idxa + return (ident, idx :: idxsRev) + | _, _ => TransM.error s!"translateLhsParts: unimplemented for {repr arg}" + mutual partial def translateFnPreconds (p : Program) (name : Core.CoreIdent) (bindings : TransBindings) (arg : Arg) : TransM (List (Strata.DL.Util.FuncPrecondition Core.Expression.Expr Core.Expression.ExprMetadata)) := do @@ -1255,10 +1346,13 @@ partial def translateStmt (p : Program) (bindings : TransBindings) (arg : Arg) : | q`Core.initStatement, args => translateInitStatement p bindings args (← getOpMetaData op) | q`Core.assign, #[_tpa, lhsa, ea] => - let lhs ← translateLhs lhsa + let (lhs, idxsRev) ← translateLhsParts p bindings lhsa let val ← translateExpr p bindings ea let md ← getOpMetaData op - return ([.set lhs val md], bindings) + let rhs := match idxsRev.reverse with + | [] => val + | idxs => nestMapUpdate (.fvar () lhs none) idxs val + return ([.set lhs rhs md], bindings) | q`Core.havoc_statement, #[ida] => let id ← translateIdent Core.CoreIdent ida let md ← getOpMetaData op diff --git a/Strata/Languages/Core/Factory.lean b/Strata/Languages/Core/Factory.lean index ef99656a5b..a5ea88df7a 100644 --- a/Strata/Languages/Core/Factory.lean +++ b/Strata/Languages/Core/Factory.lean @@ -432,9 +432,9 @@ def seqLengthFunc : WFLFunc CoreLParams := ]) /- An empty `Sequence` constructor with type `∀a. Sequence a`. - NOTE: This is registered in the Factory for programmatic use, but is not yet - parseable from `.st` files because the DDM grammar cannot currently handle - 0-ary polymorphic functions (no arguments to infer the type parameter from). -/ + `Sequence.empty()` returns an empty sequence of element type `A`. + The `` is surface syntax produced by Grammar.lean and consumed by + Translate.lean; this function itself takes no value parameters. -/ def seqEmptyFunc : WFLFunc CoreLParams := polyUneval "Sequence.empty" ["a"] [] (seqTy mty[%a]) (axioms := [ @@ -499,10 +499,57 @@ def seqAppendFunc : WFLFunc CoreLParams := else #true))] ]) -/- A `Sequence` selection function with type `∀a. Sequence a → int → a`. -/ +/-! ### Sequence bounds preconditions + +`Sequence.select` / `update` / `take` / `drop` carry bounds +preconditions; the other `Sequence.*` ops are total. -/ + +/-- Choice of upper-bound operator in `mkSeqBoundsPrecond`: `Lt` (strict) for + `Sequence.select`/`update`, `Le` (non-strict) for `Sequence.take`/`drop`. + Restricting the parameter to this inductive rather than taking an + arbitrary `WFLFunc` or `LExpr` makes it impossible to attach a partial + operator (which would create a nested precondition obligation) by + accident. -/ +private inductive SeqBoundKind where | Lt | Le + +/-- Returns the *upper-bound* comparison for `mkSeqBoundsPrecond`. + The lower bound is always `0 ≤ x` (see `mkSeqBoundsPrecond`), so this + method characterises only the upper comparison. A future partial + Sequence op requiring a non-`int` comparison (e.g. a bitvector variant) + should introduce a separate helper rather than extend this enum. -/ +private def SeqBoundKind.upperOpExpr : SeqBoundKind → LExpr CoreLParams.mono + | .Lt => (intLtFunc (T := CoreLParams)).opExpr + | .Le => (intLeFunc (T := CoreLParams)).opExpr + +/-- Precondition `0 <= varName && varName `k.upperOpExpr` Sequence.length(seqName)`. + + `seqName` defaults to `"s"` since all four current call sites + (`Sequence.select`/`update`/`take`/`drop`) name their `Sequence a` input + that way. The parameter exists so a future partial Sequence op with a + different input name need only pass it explicitly rather than rely on a + hidden string literal. Either way, mismatches between the function's + declared inputs and the names used here are caught at elaboration by + `polyUneval`'s `h_precond` free-vars check. -/ +private def mkSeqBoundsPrecond + (varName : String) (k : SeqBoundKind) (seqName : String := "s") : + Strata.DL.Util.FuncPrecondition (LExpr CoreLParams.mono) CoreLParams.Metadata := + let sVar : LExpr CoreLParams.mono := .fvar default seqName (some (seqTy mty[%a])) + let xVar : LExpr CoreLParams.mono := .fvar default varName (some mty[int]) + let zero : LExpr CoreLParams.mono := .intConst default 0 + let lenS : LExpr CoreLParams.mono := .app default seqLengthFunc.opExpr sVar + let lower : LExpr CoreLParams.mono := + .app default (.app default (intLeFunc (T := CoreLParams)).opExpr zero) xVar + let upper : LExpr CoreLParams.mono := + .app default (.app default k.upperOpExpr xVar) lenS + ⟨.app default (.app default (boolAndFunc (T := CoreLParams)).opExpr lower) upper, + default⟩ + +/- A `Sequence` selection function with type `∀a. Sequence a → int → a`. + Partial: requires `0 <= i && i < Sequence.length(s)`. -/ def seqSelectFunc : WFLFunc CoreLParams := polyUneval "Sequence.select" ["a"] [("s", seqTy mty[%a]), ("i", mty[int])] mty[%a] + (preconditions := [mkSeqBoundsPrecond "i" .Lt]) /- A `Sequence` build (snoc) function with type `∀a. Sequence a → a → Sequence a`. `build(s, v)` appends a single element `v` to the end of `s`. -/ @@ -555,7 +602,8 @@ def seqBuildFunc : WFLFunc CoreLParams := ]) /- A `Sequence` update function with type `∀a. Sequence a → int → a → Sequence a`. - `update(s, i, v)` returns a sequence identical to `s` except at index `i` where the value is `v`. -/ + `update(s, i, v)` returns a sequence identical to `s` except at index `i` where the value is `v`. + Partial: requires `0 <= i && i < Sequence.length(s)`. -/ def seqUpdateFunc : WFLFunc CoreLParams := polyUneval "Sequence.update" ["a"] [("s", seqTy mty[%a]), ("i", mty[int]), ("v", mty[%a])] @@ -606,6 +654,7 @@ def seqUpdateFunc : WFLFunc CoreLParams := (((~Sequence.select : (Sequence %a) → int → %a) %3) %0) else #true)))] ]) + (preconditions := [mkSeqBoundsPrecond "i" .Lt]) /- A `Sequence` contains function with type `∀a. Sequence a → a → bool`. `contains(s, v)` is true iff there exists an index `i` such that `select(s, i) == v`. -/ @@ -628,7 +677,8 @@ def seqContainsFunc : WFLFunc CoreLParams := ]) /- A `Sequence` take function with type `∀a. Sequence a → int → Sequence a`. - `take(s, n)` returns the first `n` elements of `s`. -/ + `take(s, n)` returns the first `n` elements of `s`. + Partial: requires `0 <= n && n <= Sequence.length(s)`. -/ def seqTakeFunc : WFLFunc CoreLParams := polyUneval "Sequence.take" ["a"] [("s", seqTy mty[%a]), ("n", mty[int])] @@ -664,9 +714,11 @@ def seqTakeFunc : WFLFunc CoreLParams := (((~Sequence.select : (Sequence %a) → int → %a) %2) %0) else #true))] ]) + (preconditions := [mkSeqBoundsPrecond "n" .Le]) /- A `Sequence` drop function with type `∀a. Sequence a → int → Sequence a`. - `drop(s, n)` returns the sequence with the first `n` elements removed. -/ + `drop(s, n)` returns the sequence with the first `n` elements removed. + Partial: requires `0 <= n && n <= Sequence.length(s)`. -/ def seqDropFunc : WFLFunc CoreLParams := polyUneval "Sequence.drop" ["a"] [("s", seqTy mty[%a]), ("n", mty[int])] @@ -709,6 +761,7 @@ def seqDropFunc : WFLFunc CoreLParams := (((~Int.Add : int → int → int) %0) %1)) else #true))] ]) + (preconditions := [mkSeqBoundsPrecond "n" .Le]) def emptyTriggersFunc : WFLFunc CoreLParams := nullaryUneval "Triggers.empty" mty[Triggers] @@ -997,7 +1050,10 @@ def mapConstOp : Expression.Expr := mapConstFunc.opExpr def mapSelectOp : Expression.Expr := mapSelectFunc.opExpr def mapUpdateOp : Expression.Expr := mapUpdateFunc.opExpr def seqLengthOp : Expression.Expr := seqLengthFunc.opExpr -def seqEmptyOp : Expression.Expr := seqEmptyFunc.opExpr +def seqEmptyOp (elemTy : Option LMonoTy := none) : Expression.Expr := + match elemTy with + | none => seqEmptyFunc.opExpr + | some ty => .op default "Sequence.empty" (some (seqTy ty)) def seqAppendOp : Expression.Expr := seqAppendFunc.opExpr def seqSelectOp : Expression.Expr := seqSelectFunc.opExpr def seqBuildOp : Expression.Expr := seqBuildFunc.opExpr diff --git a/Strata/Languages/Core/ObligationExtraction.lean b/Strata/Languages/Core/ObligationExtraction.lean index 34eb371291..5cd3cd05f4 100644 --- a/Strata/Languages/Core/ObligationExtraction.lean +++ b/Strata/Languages/Core/ObligationExtraction.lean @@ -55,11 +55,7 @@ def extractGo (pc : PathConditions Expression) : Statements → | s :: rest, acc => match s with | .cmd (.cmd (.assert label e md)) => - let propType := match md.getPropertyType with - | some s => if s == MetaData.divisionByZero then .divisionByZero - else if s == MetaData.arithmeticOverflow then .arithmeticOverflow - else .assert - | none => .assert + let propType := convertMetaDataPropertyType md extractGo pc rest (acc.push (ProofObligation.mk label propType pc e md)) | .cmd (.cmd (.cover label e md)) => diff --git a/Strata/Languages/Core/Options.lean b/Strata/Languages/Core/Options.lean index 15d80e5c6b..3512de72b9 100644 --- a/Strata/Languages/Core/Options.lean +++ b/Strata/Languages/Core/Options.lean @@ -197,6 +197,14 @@ structure VerifyOptions where outputSarif : Bool /-- Print elapsed time for each verification sub-step. -/ profile : Bool + /-- Use the incremental solver backend (stdin/stdout) instead of the + batch pipeline (write file, run solver). Opt-in via `--incremental`; + disabled automatically with `--no-solve`. -/ + incremental : Bool + /-- Number of parallel solver workers. When > 1, obligations are dispatched + to concurrent solver processes using `IO.asTask`. Each task spawns its + own solver instance. Default 1 (sequential). -/ + parallelWorkers : Nat def VerifyOptions.default : VerifyOptions := { verbose := .normal, @@ -216,7 +224,9 @@ def VerifyOptions.default : VerifyOptions := { uniqueBoundNames := false skipSolver := false profile := false + incremental := false pathCap := .none + parallelWorkers := 1 } instance : Inhabited VerifyOptions where diff --git a/Strata/Languages/Core/SMTEncoder.lean b/Strata/Languages/Core/SMTEncoder.lean index e60e63fbca..20907ceea2 100644 --- a/Strata/Languages/Core/SMTEncoder.lean +++ b/Strata/Languages/Core/SMTEncoder.lean @@ -95,7 +95,7 @@ def SMT.Context.withTypeFactory (ctx : SMT.Context) (tf : @Lambda.TypeFactory Co Helper function to convert LMonoTy to TermType for datatype constructor fields. Handles monomorphic types and type variables (as `.constr tv []`). -/ -private def lMonoTyToTermType (ty : LMonoTy) : TermType := +def lMonoTyToTermType (useArrayTheory : Bool := false) (ty : LMonoTy) : TermType := match ty with | .bitvec n => .bitvec n | .tcons "bool" [] => .bool @@ -103,19 +103,23 @@ private def lMonoTyToTermType (ty : LMonoTy) : TermType := | .tcons "real" [] => .real | .tcons "string" [] => .string | .tcons "regex" [] => .regex - | .tcons name args => .constr name (args.map lMonoTyToTermType) + | .tcons name args => + if name == "Map" && useArrayTheory then + .constr "Array" (args.map $ lMonoTyToTermType useArrayTheory) + else + .constr name (args.map $ lMonoTyToTermType useArrayTheory) | .ftvar tv => .constr tv [] /-- Convert a datatype's constructors to typed SMT constructors. -/ -private def datatypeConstructorsToSMT (d : LDatatype CoreLParams.IDMeta) : List SMTConstructor := +private def datatypeConstructorsToSMT (d : LDatatype CoreLParams.IDMeta) (useArrayTheory : Bool := false): List SMTConstructor := d.constrs.map fun c => let fields := c.args.map fun (name, fieldTy) => - (d.name ++ ".." ++ name.name, lMonoTyToTermType fieldTy) + (d.name ++ ".." ++ name.name, lMonoTyToTermType useArrayTheory fieldTy) { name := c.name.name, args := fields } /-- Ensures that all datatypes in the SMT encoding do not have arrow-typed constructor arguments-/ -private def validateDatatypesForSMT (typeFactory : @Lambda.TypeFactory CoreLParams.IDMeta) +def validateDatatypesForSMT (typeFactory : @Lambda.TypeFactory CoreLParams.IDMeta) (seenDatatypes : Std.HashSet String) : Except Format Unit := do for block in typeFactory.toList do for d in block do @@ -133,7 +137,7 @@ Uses the TypeFactory ordering (already topologically sorted). Only emits datatypes that have been seen (added via addDatatype). Single-element blocks use declare-datatype, multi-element blocks use declare-datatypes. -/ -def SMT.Context.emitDatatypes (ctx : SMT.Context) : Strata.SMT.SolverM Unit := do +def SMT.Context.emitDatatypes (ctx : SMT.Context) (useArrayTheory : Bool := false): Strata.SMT.SolverM Unit := do match validateDatatypesForSMT ctx.typeFactory ctx.seenDatatypes with | .error msg => throw (IO.userError (toString msg)) | .ok () => pure () @@ -142,10 +146,10 @@ def SMT.Context.emitDatatypes (ctx : SMT.Context) : Strata.SMT.SolverM Unit := d match usedBlock with | [] => pure () | [d] => - let constructors := datatypeConstructorsToSMT d + let constructors := datatypeConstructorsToSMT d useArrayTheory Strata.SMT.Solver.declareDatatype d.name d.typeArgs constructors | _ => - let dts := usedBlock.map fun d => (d.name, d.typeArgs, datatypeConstructorsToSMT d) + let dts := usedBlock.map fun d => (d.name, d.typeArgs, datatypeConstructorsToSMT d useArrayTheory) Strata.SMT.Solver.declareDatatypes dts @[expose] abbrev BoundVars := List (String × TermType) @@ -639,7 +643,7 @@ partial def toSMTOp (E : Env) (fn : CoreIdent) (fnty : LMonoTy) (ctx : SMT.Conte -- `.bvar`s. Use substFvarsLifting to properly lift indices under binders. let bvars := (List.range formals.length).map (fun i => LExpr.bvar () i) let body := LExpr.substFvarsLifting body (formals.zip bvars) - let (term, ctx) ← toSMTTerm E bvs body ctx + let (term, ctx) ← toSMTTerm E bvs body ctx useArrayTheory .ok (ctx.addIF uf term, !ctx.ifs.contains ({ uf := uf, body := term })) -- For recursive functions with @[cases], generate per-constructor axioms. -- Int-recursive functions (no @[cases]) are pure UFs with no axioms. @@ -666,7 +670,7 @@ partial def toSMTOp (E : Env) (fn : CoreIdent) (fnty : LMonoTy) (ctx : SMT.Conte let savedSubst := ctx.tySubst let ctx ← allAxioms.foldlM (fun acc_ctx (ax: LExpr CoreLParams.mono) => do let current_axiom_ctx := acc_ctx.addSubst smt_ty_inst - let (axiom_term, new_ctx) ← toSMTTerm E [] ax current_axiom_ctx + let (axiom_term, new_ctx) ← toSMTTerm E [] ax current_axiom_ctx useArrayTheory .ok (new_ctx.addAxiom axiom_term) ) ctx let ctx := ctx.restoreSubst savedSubst diff --git a/Strata/Languages/Core/SarifOutput.lean b/Strata/Languages/Core/SarifOutput.lean index 6e69cf014b..7874a9b912 100644 --- a/Strata/Languages/Core/SarifOutput.lean +++ b/Strata/Languages/Core/SarifOutput.lean @@ -29,7 +29,8 @@ def outcomeToLevel (mode : VerificationMode) (property : Imperative.PropertyType match mode, property, outcome.satisfiabilityProperty, outcome.validityProperty with -- Cover satisfied (sat on P∧Q): always pass | _, .cover, .sat _, _ => .none - -- Unreachable (both unsat): deductive=warning for assert/divisionByZero/arithmeticOverflow, error for cover and bugFinding modes + -- Unreachable (both unsat): deductive=warning for assert-like properties + -- (those that pass vacuously), error for cover and bugFinding modes. | .deductive, p, .unsat, .unsat => if p.passWhenUnreachable then .warning else .error | _, _, .unsat, .unsat => .error -- Pass: validity proven (unsat on P∧¬Q) @@ -88,6 +89,7 @@ def extractLocation (files : Map Strata.Uri Lean.FileMap) (md : Imperative.MetaD def propertyTypeToClassification : Imperative.PropertyType → String | .divisionByZero => "division-by-zero" | .arithmeticOverflow => "arithmetic-overflow" + | .outOfBoundsAccess => "out-of-bounds-access" | .cover => "cover" | .assert => "assert" @@ -111,6 +113,8 @@ def extractRelatedLocations (files : Map Strata.Uri Lean.FileMap) (md : Imperati def vcResultToSarifResult (mode : VerificationMode) (files : Map Strata.Uri Lean.FileMap) (vcr : VCResult) : Strata.Sarif.Result := let ruleId := vcr.obligation.label let relatedLocations := extractRelatedLocations files vcr.obligation.metadata + let properties : Strata.Sarif.PropertyBag := + { propertyType := propertyTypeToClassification vcr.obligation.property } match vcr.outcome with | .error err => let level := .error @@ -119,7 +123,7 @@ def vcResultToSarifResult (mode : VerificationMode) (files : Map Strata.Uri Lean let locations := match extractLocation files vcr.obligation.metadata with | some loc => #[locationToSarif loc] | none => #[] - { ruleId, level, message, locations, relatedLocations } + { ruleId, level, message, locations, relatedLocations, properties } | .ok outcome => let level := outcomeToLevel mode vcr.obligation.property outcome let messageText := outcomeToMessage outcome @@ -127,7 +131,7 @@ def vcResultToSarifResult (mode : VerificationMode) (files : Map Strata.Uri Lean let locations := match extractLocation files vcr.obligation.metadata with | some loc => #[locationToSarif loc] | none => #[] - { ruleId, level, message, locations, relatedLocations } + { ruleId, level, message, locations, relatedLocations, properties } /-- Convert VCResults to a SARIF document -/ def vcResultsToSarif (mode : VerificationMode) (files : Map Strata.Uri Lean.FileMap) (vcResults : VCResults) : Strata.Sarif.SarifDocument := diff --git a/Strata/Languages/Core/StatementEval.lean b/Strata/Languages/Core/StatementEval.lean index ce2b35a975..57f800544d 100644 --- a/Strata/Languages/Core/StatementEval.lean +++ b/Strata/Languages/Core/StatementEval.lean @@ -320,11 +320,7 @@ private def createUnreachableAssertObligations Imperative.ProofObligations Expression := asserts.toArray.map (fun (label, md) => - let propType := match md.getPropertyType with - | some s => if s == Imperative.MetaData.divisionByZero then .divisionByZero - else if s == Imperative.MetaData.arithmeticOverflow then .arithmeticOverflow - else .assert - | _ => .assert + let propType := Imperative.convertMetaDataPropertyType md (Imperative.ProofObligation.mk label propType pathConditions (LExpr.true ()) md)) /-- diff --git a/Strata/Languages/Core/Verifier.lean b/Strata/Languages/Core/Verifier.lean index 92a07e3934..26e50d2c7d 100644 --- a/Strata/Languages/Core/Verifier.lean +++ b/Strata/Languages/Core/Verifier.lean @@ -14,6 +14,7 @@ public import Strata.DL.Imperative.MetaData public import Strata.DL.Imperative.SMTUtils public import Strata.DDM.AST public import Strata.Languages.Core.PipelinePhase +import Strata.DL.SMT.IncrementalSolver import Strata.Transform.CallElim import Strata.Transform.FilterProcedures import Strata.Transform.PrecondElim @@ -22,7 +23,9 @@ import Strata.Transform.LoopElim import Strata.Transform.ANFEncoder import Strata.Languages.Core.ObligationExtraction public import Strata.Transform.IrrelevantAxioms -import Strata.Util.Profile +import Strata.Pipeline.Context + +open Strata.Pipeline (PipelineContext) --------------------------------------------------------------------- @@ -33,110 +36,409 @@ open Strata public section -/-- Encode a verification condition into SMT-LIB format. - -This function encodes the path conditions (P) and obligation (Q) into SMT, -then emits check-sat commands to determine satisfiability and/or validity. - -When both checks are requested, uses check-sat-assuming for efficiency: -- Satisfiability: (check-sat-assuming (Q)) tests if P ∧ Q is satisfiable -- Validity: (check-sat-assuming ((not Q))) tests if P ∧ ¬Q is satisfiable - -When only one check is requested, uses assert + check-sat: -- For satisfiability: (assert Q) (check-sat) tests P ∧ Q -- For validity: (assert (not Q)) (check-sat) tests P ∧ ¬Q - -Note: The obligation term Q is encoded without negation. Negation is applied -when needed for the validity check (line 64 for check-sat-assuming, line 77 for assert). --/ -def encodeCore (ctx : Core.SMT.Context) (prelude : SolverM Unit) +/-- Encoder state for the abstract solver backend. Extends `EncoderState` with + a cache of `τ` handles for declared variables, so that `encodeTerm` can + look up handles by name instead of requiring a `mkVar` method on the solver. -/ +structure AbstractEncoderState (τ : Type) where + /-- The underlying encoder state (UF name mappings). -/ + base : EncoderState + /-- Maps declared variable/function names to their solver handles. + Populated by `encodeUF` / `declareFun`; looked up by `encodeTerm`. -/ + varHandles : Std.HashMap String τ := {} + +/-- Encoder monad over an abstract solver backend. + Parameterized by the underlying monad `m` and the solver's term type `τ` + so the encoder is not tied to any particular solver backend. -/ +abbrev AbstractEncoderM (τ : Type) (m : Type → Type) := StateT (AbstractEncoderState τ) m + +namespace AbstractEncoder + +variable {τ σ : Type} {m : Type → Type} [Monad m] [MonadExceptOf IO.Error m] + +/-- Convert a `TermType` to the solver's sort type `σ` by dispatching on + the sort primitives provided by the solver. This is the sort-level + counterpart of `encodeTerm`: both convert a Strata representation to a + solver-native handle by pattern-matching on constructors. Keeping this + logic in the encoder (rather than in `AbstractSolver`) means backends + only need to implement the one-liner sort primitives, not a full + dispatching method. -/ +def termTypeToSort (solver : AbstractSolver τ σ m) (ty : TermType) : m σ := do + match ty with + | .prim .bool => solver.boolSort + | .prim .int => solver.intSort + | .prim .real => solver.realSort + | .prim .string => solver.stringSort + | .prim .regex => solver.regexSort + | .prim (.bitvec n) => solver.bitvecSort n + | .prim .trigger => solver.boolSort + | .option inner => do + let s ← termTypeToSort solver inner + solver.constrSort "Option" [s] + | .constr name args => do + if name == "Array" then + match args with + | [k, v] => do + let ks ← termTypeToSort solver k + let vs ← termTypeToSort solver v + solver.arraySort ks vs + | _ => solver.constrSort name [] + else + let argSorts ← args.attach.mapM fun ⟨t, _⟩ => termTypeToSort solver t + solver.constrSort name argSorts +termination_by sizeOf ty +decreasing_by + all_goals simp_wf + all_goals (try omega) <;> (have := List.sizeOf_lt_of_mem ‹_›; omega) + +private def encodeUF (solver : AbstractSolver τ σ m) (uf : UF) : AbstractEncoderM τ m String := do + if let .some enc := (← get).base.ufs.get? uf then return enc + let baseName := sanitizeSmtName uf.id + let existingNames := (← get).base.ufs.toList.map (·.2) + let usedNames := Std.HashSet.ofList (existingNames ++ smtReservedKeywords) + let id := Strata.Name.findUnique baseName 1 usedNames + liftM (solver.comment uf.id) + let argSorts ← uf.args.mapM (fun vt => liftM (termTypeToSort solver vt.ty)) + let outSort ← liftM (termTypeToSort solver uf.out) + let handle ← liftM (solver.declareFun id argSorts outSort) + modify fun st => { st with varHandles := st.varHandles.insert id handle } + modifyGet fun state => (id, { state with base := { state.base with ufs := state.base.ufs.insert uf id } }) + +private def defineApp (solver : AbstractSolver τ σ m) (retSort : σ) (op : Op) (tEncs : List τ) : AbstractEncoderM τ m τ := do + -- Pattern: `liftM` lifts solver calls from `m` into `StateT`. + match op, tEncs with + -- Boolean operations + | .and, _ => liftM (solver.mkAnd tEncs) + | .or, _ => liftM (solver.mkOr tEncs) + | .not, [t] => liftM (solver.mkNot t) + | .implies, [a,b] => liftM (solver.mkImplies a b) + | .eq, _ => liftM (solver.mkEq tEncs) + | .ite, [c,t,f] => liftM (solver.mkIte c t f) + -- Arithmetic operations + | .add, _ => liftM (solver.mkAdd tEncs) + | .sub, _ => liftM (solver.mkSub tEncs) + | .mul, _ => liftM (solver.mkMul tEncs) + | .div, [a, b] => liftM (solver.mkDiv a b) + | .mod, [a, b] => liftM (solver.mkMod a b) + | .neg, [t] => liftM (solver.mkNeg t) + | .abs, [t] => liftM (solver.mkAbs t) + -- Comparison operations + | .lt, _ => liftM (solver.mkLt tEncs) + | .le, _ => liftM (solver.mkLe tEncs) + | .gt, _ => liftM (solver.mkGt tEncs) + | .ge, _ => liftM (solver.mkGe tEncs) + -- Array operations + | .select, [a, i] => liftM (solver.mkSelect a i) + | .store, [a,i,v] => liftM (solver.mkStore a i v) + -- Uninterpreted functions: declare and apply + | .uf f, _ => + let ufName ← encodeUF solver f + let ufRef : UF := { id := ufName, args := f.args, out := f.out } + let outSort ← liftM (termTypeToSort solver ufRef.out) + let handle ← liftM (solver.mkAppOp (.uf ufRef) [] outSort) + liftM (solver.mkApp handle tEncs) + -- Datatype operations: build handle and apply + | .datatype_op kind name, _ => + let handle ← liftM (solver.mkAppOp (.datatype_op kind name) [] retSort) + liftM (solver.mkApp handle tEncs) + -- All other operations (bitvectors, strings, etc.): route through mkAppOp + | _, _ => liftM (solver.mkAppOp op tEncs retSort) + +private def defineQuantifierHelper (solver : AbstractSolver τ σ m) (qk : QuantifierKind) + (args : List TermVar) + (encodeBody : AbstractEncoderM τ m τ) + (encodeTriggers : AbstractEncoderM τ m (List (List τ))) + : AbstractEncoderM τ m τ := do + let bindings ← args.mapM fun v => do + let s ← liftM (termTypeToSort solver v.ty) + return (v.id, s) + let mkQuant := match qk with + | .all => solver.mkForall + | .exist => solver.mkExists + -- Capture the encoder state so the callback can encode the body and + -- triggers with the bound variable handles in scope. The inner state + -- is intentionally not propagated back: bound variable handles are scoped + -- to the quantifier, and free variables in the body are already declared + -- before the quantifier is encoded. + let st ← get + liftM (mkQuant bindings (fun vars => do + let stWithVars := { st with + varHandles := args.zip vars |>.foldl + (fun m (v, h) => m.insert v.id h) st.varHandles } + let (bodyEnc, st') ← encodeBody.run stWithVars + let (trEncs, _) ← encodeTriggers.run st' + return (bodyEnc, trEncs))) + +def encodeTerm (solver : AbstractSolver τ σ m) (t : Term) : AbstractEncoderM τ m τ := do + match t with + | .var v => + -- Look up the τ handle cached when the variable was declared via declareFun/declareNew + match (← get).varHandles.get? v.id with + | .some handle => return handle + | .none => + -- Variable not yet declared — declare it now via declareNew + let s ← liftM (termTypeToSort solver v.ty) + let handle ← liftM (solver.declareNew v.id s) + modify fun st => { st with varHandles := st.varHandles.insert v.id handle } + return handle + | .prim p => liftM (solver.mkPrim p) + | .none ty => + -- Option none: use the datatype constructor via mkAppOp + let retSort ← liftM (termTypeToSort solver (.option ty)) + liftM (solver.mkAppOp (.datatype_op .constructor "none") [] retSort) + | .some t₁ => + -- Option some: encode the inner term and apply the constructor via mkAppOp + let t₁Enc ← encodeTerm solver t₁ + let retSort ← liftM (termTypeToSort solver (.option t₁.typeOf)) + let handle ← liftM (solver.mkAppOp (.datatype_op .constructor "some") [] retSort) + liftM (solver.mkApp handle [t₁Enc]) + | .app .re_allchar [] .regex => + let s ← liftM (termTypeToSort solver .regex) + liftM (solver.mkAppOp .re_allchar [] s) + | .app .re_all [] .regex => + let s ← liftM (termTypeToSort solver .regex) + liftM (solver.mkAppOp .re_all [] s) + | .app .re_none [] .regex => + let s ← liftM (termTypeToSort solver .regex) + liftM (solver.mkAppOp .re_none [] s) + | .app .bvnego [inner] .bool => + match inner.typeOf with + | .bitvec n => + let innerEnc ← encodeTerm solver inner + let minVal ← liftM (solver.mkPrim (.bitvec (BitVec.intMin n))) + let retSort ← liftM (termTypeToSort solver .bool) + defineApp solver retSort .eq [innerEnc, minVal] + | _ => liftM (solver.mkBool false) + | .app op ts _ => + let retSort ← liftM (termTypeToSort solver t.typeOf) + defineApp solver retSort op (← mapM₁ ts (fun ⟨tᵢ, _⟩ => encodeTerm solver tᵢ)) + | .quant qk qargs tr body => + let trExprs := if Factory.isSimpleTrigger tr then [] else extractTriggers tr + defineQuantifierHelper solver qk qargs + (encodeTerm solver body) + (mapM₁ trExprs (fun ⟨ts, _⟩ => mapM₁ ts (fun ⟨ti, _⟩ => encodeTerm solver ti))) +termination_by sizeOf t +decreasing_by + all_goals first + | term_by_mem + | add_mem_size_lemmas + have hmem : _ ∈ (if Factory.isSimpleTrigger tr then ([] : List (List Term)) else extractTriggers tr) := ‹_ ∈ trExprs› + split at hmem + · simp at hmem + · have := extractTriggers_sizeOf tr _ _ hmem ‹_ ∈ _› + simp_all; omega + +private def encodeFunction (solver : AbstractSolver τ σ m) (uf : UF) (body : Term) : AbstractEncoderM τ m String := do + if let .some enc := (← get).base.ufs.get? uf then return enc + let id := ufId (← get).base.ufs.size + liftM (solver.comment uf.id) + let argPairs ← uf.args.mapM fun vt => do + let s ← liftM (termTypeToSort solver vt.ty) + return (vt.id, s) + let outSort ← liftM (termTypeToSort solver uf.out) + let bodyEnc ← encodeTerm solver body + liftM (solver.defineFun id argPairs outSort bodyEnc) + modifyGet fun state => (id, { state with base := { state.base with ufs := state.base.ufs.insert uf id } }) + +end AbstractEncoder + +/-- Build constructor declarations for a datatype, converting field types + through the solver's `termTypeToSort`. -/ +private def datatypeConstrsM [Monad m] [MonadExceptOf IO.Error m] (solver : AbstractSolver τ σ m) + (d : Lambda.LDatatype Core.CoreLParams.IDMeta) + : m (List (String × List (String × σ))) := do + let mut result := [] + for c in d.constrs.reverse do + let mut fields := [] + for (name, fieldTy) in c.args.reverse do + let s ← AbstractEncoder.termTypeToSort solver (Core.lMonoTyToTermType (ty := fieldTy)) + fields := (d.name ++ ".." ++ name.name, s) :: fields + result := (c.name.name, fields) :: result + return result + +/-- Emit datatype declarations through the `AbstractSolver` API. -/ +private def emitDatatypesAbstract [Monad m] [MonadExceptOf IO.Error m] + (solver : AbstractSolver τ σ m) (ctx : Core.SMT.Context) : m Unit := do + -- Validate that no datatype has arrow-typed fields (same check as batch path) + match Core.validateDatatypesForSMT ctx.typeFactory ctx.seenDatatypes with + | .error msg => throw (IO.userError (toString msg)) + | .ok () => pure () + for block in ctx.typeFactory.toList do + let usedBlock := block.filter (fun d => ctx.seenDatatypes.contains d.name) + match usedBlock with + | [] => pure () + | [d] => + let constrs ← datatypeConstrsM solver d + let _ ← solver.declareDatatype d.name d.typeArgs + fun _ _ => .ok constrs + | _ => + let dtHeaders := usedBlock.map fun d => (d.name, d.typeArgs) + let allConstrs ← usedBlock.mapM (datatypeConstrsM solver) + let _ ← solver.declareDatatypes dtHeaders + fun _ _ => .ok allConstrs + +/-- Encode declarations and assertions through the `AbstractSolver` API. + Replaces `encodeDeclarations` for the incremental path — all commands + go through `AbstractSolver` methods instead of `SolverM`. + + Parameterized by the solver backend monad `m` and the solver's term/sort + types `τ`/`σ` so any implementation of `AbstractSolver τ σ m` can be used + (e.g. incremental SMT-LIB, cvc5 FFI). + + `prelude` is a deferred monadic action (e.g. solver option settings) + executed after `setLogic` but before declarations. The caller constructs + it inside the solver session and passes it in as a callback. -/ +def encodeDeclarationsAbstract [Monad m] [MonadExceptOf IO.Error m] + (solver : AbstractSolver τ σ m) + (ctx : Core.SMT.Context) + (prelude : m Unit) (assumptionTerms : List Term) (obligationTerm : Term) - (md : Imperative.MetaData Core.Expression) - (satisfiabilityCheck validityCheck : Bool) - (label : String) (varDefinitions : List Core.VarDefinition := []) - (varDeclarations : List Core.VarDeclaration := []) : - SolverM (List String × EncoderState) := do - Solver.setLogic "ALL" + (varDeclarations : List Core.VarDeclaration := []) + : m (τ × List String × EncoderState) := do + solver.setLogic "ALL" prelude - let _ ← ctx.sorts.mapM (fun s => Solver.declareSort s.name s.arity) - ctx.emitDatatypes + for s in ctx.sorts do + -- Skip sorts that will be defined as datatypes by emitDatatypesAbstract, + -- since strict solver APIs (e.g. cvc5 FFI) reject redefinition. + if !ctx.seenDatatypes.contains s.name then + let _ ← solver.declareSort s.name s.arity + emitDatatypesAbstract solver ctx + let initState : AbstractEncoderState τ := { base := EncoderState.init } let varDefNames := varDefinitions.map (·.name) let varDeclNames := varDeclarations.map (·.name) let managedNames := varDefNames ++ varDeclNames -- Filter out managed variables from UF declarations (they will be emitted separately) let ufsToDecl := if managedNames.isEmpty then ctx.ufs else ctx.ufs.filter fun uf => !managedNames.contains uf.id - let (_ufs, estate) ← ufsToDecl.mapM (fun uf => encodeUF uf) |>.run EncoderState.init + let (_ufs, estate) ← ufsToDecl.mapM (fun uf => AbstractEncoder.encodeUF solver uf) |>.run initState -- Pre-populate encoder state with managed variable names so encodeTerm - -- recognizes them without emitting declare-fun + -- recognizes them without emitting declareFun let estate := if managedNames.isEmpty then estate else let managedUfs := ctx.ufs.filter fun uf => managedNames.contains uf.id managedUfs.foldl (init := estate) fun estate uf => - { estate with ufs := estate.ufs.insert uf uf.id } - let (_ifs, estate) ← ctx.ifs.mapM (fun fn => encodeFunction fn.uf fn.body) |>.run estate - let (_axms, estate) ← ctx.axms.mapM (fun ax => encodeTerm ax) |>.run estate + { estate with base := { estate.base with ufs := estate.base.ufs.insert uf uf.id } } + let (_ifs, estate) ← ctx.ifs.mapM (fun fn => AbstractEncoder.encodeFunction solver fn.uf fn.body) |>.run estate + let (_axms, estate) ← ctx.axms.mapM (fun ax => AbstractEncoder.encodeTerm solver ax) |>.run estate + for id in _axms do + solver.assert id + -- Emit variable declarations as declareFun + for decl in varDeclarations do + let sort ← AbstractEncoder.termTypeToSort solver decl.ty + let _ ← solver.declareFun decl.name [] sort + -- Emit variable definitions as defineFun + let estate ← varDefinitions.foldlM (init := estate) fun estate def_ => do + let (bodyEnc, estate) ← (AbstractEncoder.encodeTerm solver def_.body) |>.run estate + let sort ← AbstractEncoder.termTypeToSort solver def_.ty + solver.defineFun def_.name [] sort bodyEnc + pure estate + let (assumptionIds, estate) ← assumptionTerms.mapM (AbstractEncoder.encodeTerm solver) |>.run estate + for id in assumptionIds do + solver.assert id + let (obligationId, estate) ← (AbstractEncoder.encodeTerm solver obligationTerm) |>.run estate + let ids := estate.base.ufs.toList.filterMap fun (uf, id) => + if uf.args.isEmpty && !managedNames.contains uf.id then some id else none + return (obligationId, ids, estate.base) + +/-- Encode a verification condition into SMT-LIB format, including check-sat + commands. Used by the batch pipeline. -/ +def encodeCore (ctx : Core.SMT.Context) (prelude : SolverM Unit) + (assumptionTerms : List Term) (obligationTerm : Term) + (md : Imperative.MetaData Core.Expression) + (useArrayTheory : Bool := false) + (satisfiabilityCheck validityCheck : Bool) + (label : String) + (varDefinitions : List Core.VarDefinition := []) + (varDeclarations : List Core.VarDeclaration := []) + (pctx : PipelineContext) : + SolverM (List String × EncoderState) := do + let phase {α} (name : String) (action : SolverM α) : SolverM α := + pctx.withRepeatedPhase name action + Solver.setLogic "ALL" + phase "prelude" do + prelude + + let _ ← ctx.sorts.mapM (fun s => Solver.declareSort s.name s.arity) + ctx.emitDatatypes useArrayTheory + let varDefNames := varDefinitions.map (·.name) + let varDeclNames := varDeclarations.map (·.name) + let managedNames := varDefNames ++ varDeclNames + + let estate ← phase "encodeUFs" do + let ufsToDecl := if managedNames.isEmpty then ctx.ufs + else ctx.ufs.filter fun uf => !managedNames.contains uf.id + let (_ufs, estate) ← ufsToDecl.mapM (fun uf => encodeUF uf) |>.run EncoderState.init + pure estate + + let estate ← phase "encodeFunctions" do + let estate := if managedNames.isEmpty then estate + else + let managedUfs := ctx.ufs.filter fun uf => managedNames.contains uf.id + managedUfs.foldl (init := estate) fun estate uf => + { estate with ufs := estate.ufs.insert uf uf.id } + let (_ifs, estate) ← ctx.ifs.mapM (fun fn => encodeFunction fn.uf fn.body) |>.run estate + pure estate + + let (_axms, estate) ← phase "encodeAxioms" do + ctx.axms.mapM (fun ax => encodeTerm ax) |>.run estate + for id in _axms do Solver.assert id -- Emit variable declarations as declare-fun for decl in varDeclarations do Solver.declareFun decl.name [] decl.ty + -- Emit variable definitions as define-fun (macro expansions, not constraints) - let estate ← varDefinitions.foldlM (init := estate) fun estate def_ => do - let (bodyEnc, estate) ← (encodeTerm def_.body) |>.run estate - Solver.defineFunTerm def_.name [] def_.ty bodyEnc - pure estate - -- Assert assumption terms - let (assumptionIds, estate) ← assumptionTerms.mapM (encodeTerm) |>.run estate + let estate ← phase "defineFunTerms" do + varDefinitions.foldlM (init := estate) fun estate def_ => do + let (bodyEnc, estate) ← (encodeTerm def_.body) |>.run estate + Solver.defineFunTerm def_.name [] def_.ty bodyEnc + pure estate + + let (assumptionIds, estate) ← phase "encodeAssumptions" do + assumptionTerms.mapM (encodeTerm) |>.run estate + for id in assumptionIds do Solver.assert id - -- Encode the obligation term Q (not negated) - let (obligationId, estate) ← (encodeTerm obligationTerm) |>.run estate - let ids := estate.ufs.toList.filterMap fun (uf, id) => - if uf.args.isEmpty && !managedNames.contains uf.id then some id else none + let (obligationId, estate) ← phase "encodeObligation" do + (encodeTerm obligationTerm) |>.run estate - -- Choose encoding strategy: use check-sat-assuming only when doing both checks - let bothChecks := satisfiabilityCheck && validityCheck - - if bothChecks then - -- Satisfiability check: P ∧ Q satisfiable? - Solver.comment "Satisfiability" - Imperative.SMT.addLocationInfo (P := Core.Expression) (md := md) - (message := ("sat-message", "Property can be satisfied")) - let obligationStr ← Solver.termToSMTString obligationId - let _ ← Solver.checkSatAssuming [obligationStr] ids - - -- Validity check: P ∧ ¬Q satisfiable? - Solver.comment "Validity" - Imperative.SMT.addLocationInfo (P := Core.Expression) (md := md) - (message := ("unsat-message", "Property is always true")) - let negObligationStr := s!"(not {obligationStr})" - let _ ← Solver.checkSatAssuming [negObligationStr] ids - else - if satisfiabilityCheck then - -- P ∧ Q satisfiable? + let ids ← phase "epilog" do + let ids := estate.ufs.toList.filterMap fun (uf, id) => + if uf.args.isEmpty && !managedNames.contains uf.id then some id else none + + let bothChecks := satisfiabilityCheck && validityCheck + + if bothChecks then Solver.comment "Satisfiability" Imperative.SMT.addLocationInfo (P := Core.Expression) (md := md) (message := ("sat-message", "Property can be satisfied")) - Solver.assert obligationId - let _ ← Solver.checkSat ids - else if validityCheck then - -- P ∧ ¬Q satisfiable? + let obligationStr ← Solver.termToSMTString obligationId + let _ ← Solver.checkSatAssuming [obligationStr] ids + Solver.comment "Validity" Imperative.SMT.addLocationInfo (P := Core.Expression) (md := md) (message := ("unsat-message", "Property is always true")) - Solver.assert (← encodeTerm (Factory.not obligationTerm) |>.run estate).1 - let _ ← Solver.checkSat ids - - -- Emit the property summary (or label) as the final message in the SMT-LIB output. - -- Use `setInfoString` so the value is quoted and escaped per SMT-LIB 2.6+ - -- (doubled `""` for embedded quotes). C-style `\"` escaping would be rejected - -- by SMT-LIB consumers: backslash is a literal character in string contexts, - -- and the following `"` would close the string. - let rawMsg := md.getPropertySummary.getD label - Solver.setInfoString "final-message" rawMsg + let negObligationStr := s!"(not {obligationStr})" + let _ ← Solver.checkSatAssuming [negObligationStr] ids + else + if satisfiabilityCheck then + Solver.comment "Satisfiability" + Imperative.SMT.addLocationInfo (P := Core.Expression) (md := md) + (message := ("sat-message", "Property can be satisfied")) + Solver.assert obligationId + let _ ← Solver.checkSat ids + else if validityCheck then + Solver.comment "Validity" + Imperative.SMT.addLocationInfo (P := Core.Expression) (md := md) + (message := ("unsat-message", "Property is always true")) + Solver.assert (← encodeTerm (Factory.not obligationTerm) |>.run estate).1 + let _ ← Solver.checkSat ids + + let rawMsg := md.getPropertySummary.getD label + Solver.setInfoString "final-message" rawMsg + pure ids return (ids, estate) @@ -207,6 +509,7 @@ def dischargeObligation (label : String) (varDefinitions : List VarDefinition := []) (varDeclarations : List VarDeclaration := []) + (pctx : PipelineContext) : IO (Except Imperative.SMT.SolverError (SMT.Result × SMT.Result × EncoderState)) := do -- CVC5 requires --incremental for multiple (check-sat) commands let baseFlags := getSolverFlags options @@ -219,8 +522,9 @@ def dischargeObligation Imperative.SMT.dischargeObligation (P := Core.Expression) (Strata.SMT.Encoder.encodeCore ctx (getSolverPrelude options.solver) - assumptionTerms obligationTerm md satisfiabilityCheck validityCheck - (label := label) (varDefinitions := varDefinitions) (varDeclarations := varDeclarations)) + assumptionTerms obligationTerm md options.useArrayTheory satisfiabilityCheck validityCheck + (label := label) (varDefinitions := varDefinitions) (varDeclarations := varDeclarations) + (pctx := pctx)) (typedVarToSMTFn ctx) vars options.solver @@ -228,6 +532,51 @@ def dischargeObligation solverFlags (options.verbose > .normal) satisfiabilityCheck validityCheck (skipSolver := options.skipSolver) + (pctx := pctx) + +/-- Discharge a proof obligation using the incremental solver backend. + Spawns a live solver process, sends commands via stdin/stdout, and + reads results interactively. Returns the same result triple as the + batch `dischargeObligation`. -/ +def dischargeObligationIncremental + (options : VerifyOptions) + (vars : List Expression.TypedIdent) + (_md : Imperative.MetaData Expression) + (assumptionTerms : List Term) + (obligationTerm : Term) + (ctx : SMT.Context) + (satisfiabilityCheck validityCheck : Bool) + (_label : String) + (varDefinitions : List VarDefinition := []) + (varDeclarations : List VarDeclaration := []) + : IO (Except Imperative.SMT.SolverError (SMT.Result × SMT.Result × EncoderState)) := do + let baseFlags := getSolverFlags options + let needsIncremental := satisfiabilityCheck && validityCheck + let solverSpecificFlags := match options.solver with + | "cvc5" => + let base := #["--quiet", "--lang", "smt"] + if needsIncremental && !baseFlags.contains "--incremental" then + base ++ #["--incremental"] + else base + | "z3" => #["-in"] + | _ => #[] + let allFlags := solverSpecificFlags ++ baseFlags + let encodeDecl (solver : Strata.SMT.AbstractSolver Term TermType + Strata.SMT.IncrementalSolverM) : + Strata.SMT.IncrementalSolverM Imperative.SMT.EncodedObligation := do + let prelude : Strata.SMT.IncrementalSolverM Unit := match options.solver with + | "z3" => do + solver.setOption "smt.mbqi" "false" + solver.setOption "auto_config" "false" + | _ => pure () + let (obligationId, ids, estate) ← + _root_.Strata.SMT.Encoder.encodeDeclarationsAbstract solver ctx prelude + assumptionTerms obligationTerm + (varDefinitions := varDefinitions) (varDeclarations := varDeclarations) + return { obligationId, assumptionIds := ids, estate } + Imperative.SMT.dischargeObligationIncremental (P := Core.Expression) + encodeDecl (typedVarToSMTFn ctx) vars options.solver allFlags + satisfiabilityCheck validityCheck end -- public section end Core.SMT @@ -450,7 +799,7 @@ def label (o : VCOutcome) (property : Imperative.PropertyType) -- Simplified labels for minimal check level else if checkLevel == .minimal then if property.passWhenUnreachable then - -- Assert-like property (assert, divisionByZero, arithmeticOverflow) + -- Assert-like property (i.e. passes vacuously on unreachable paths). if checkMode == .deductive then match o.validityProperty with | .unsat => "pass" @@ -930,6 +1279,61 @@ def SMT.Result.adjustForPhases (r : SMT.Result) | .sat _ | .unknown _ => AbstractedPhase.validateModel phases r obligation | other => (other, []) +/-- A discharge function encapsulates the solver backend. It takes assumption + terms, the obligation term, the SMT context, and the satisfiability/validity + check flags, and returns the solver results. The pipeline is parametrized + by this function so it does not know about SMT-LIB or any specific solver. -/ +abbrev DischargeFn := + List Term → Term → SMT.Context → Bool → Bool → List VarDefinition → List VarDeclaration → + IO (Except Imperative.SMT.SolverError (SMT.Result × SMT.Result × EncoderState)) + +/-- A `CoreSMTSolver` encapsulates the strategy for discharging all proof + obligations extracted from a CoreSMT program. The pipeline is parametrized + by this function so that the solver backend can be swapped — e.g. for a + parallel solver that dispatches obligations concurrently, or an incremental + solver that shares path-condition state across assertions. + + The solver receives the factory extensions (custom functions from external + phases, e.g. `ReFactory`) and the obligation program (in CoreSMT format + after all pipeline transformations), and returns verification results + together with statistics. The factory parameter ensures custom solvers + can build the environment with the same function definitions as the + default solver. -/ +abbrev CoreSMTSolver := + @Lambda.Factory CoreLParams → Program → EIO DiagnosticModel (VCResults × Statistics) + +/-- Factory for discharge functions. Called once per obligation with the + obligation's typed variables, metadata, and label. A custom implementation + can replace the default (batch/incremental SMT-LIB) backend. -/ +abbrev MkDischargeFn := + VerifyOptions → IO.Ref Nat → System.FilePath → + List Expression.TypedIdent → Imperative.MetaData Expression → String → + PipelineContext → DischargeFn + +/-- Construct a `DischargeFn` from verification options. Selects the incremental + (abstract solver) backend or the batch (SMT-LIB file) backend based on + `options.incremental` and `options.alwaysGenerateSMT`. -/ +def mkDischargeFn : MkDischargeFn := fun (options : VerifyOptions) (counter : IO.Ref Nat) + (tempDir : System.FilePath) + (vars : List Expression.TypedIdent) + (md : Imperative.MetaData Expression) + (label : String) + (pctx : PipelineContext) => + fun assumptionTerms obligationTerm ctx satisfiabilityCheck validityCheck + varDefinitions varDeclarations => do + if options.incremental && !options.alwaysGenerateSMT then + SMT.dischargeObligationIncremental options vars md + assumptionTerms obligationTerm ctx satisfiabilityCheck validityCheck label + (varDefinitions := varDefinitions) (varDeclarations := varDeclarations) + else + let counterVal ← counter.get + counter.set (counterVal + 1) + let filename := tempDir / s!"{SMT.sanitizeFilename label}_{counterVal}.smt2" + SMT.dischargeObligation options vars md filename.toString + assumptionTerms obligationTerm ctx satisfiabilityCheck validityCheck + (label := label) (varDefinitions := varDefinitions) (varDeclarations := varDeclarations) + (pctx := pctx) + /-- Invoke a backend engine and get the analysis result for a given proof obligation. @@ -937,36 +1341,18 @@ given proof obligation. def getObligationResult (assumptionTerms : List Term) (obligationTerm : Term) (ctx : SMT.Context) (obligation : ProofObligation Expression) (p : Program) - (options : VerifyOptions) (counter : IO.Ref Nat) - (tempDir : System.FilePath) (satisfiabilityCheck validityCheck : Bool) + (options : VerifyOptions) + (discharge : DischargeFn) + (satisfiabilityCheck validityCheck : Bool) (phases : List AbstractedPhase) (varDefinitions : List VarDefinition := []) (varDeclarations : List VarDeclaration := []) : EIO DiagnosticModel VCResult := do let prog := f!"\n\n[DEBUG] Evaluated program:\n{Core.formatProgram p}" - let counterVal ← counter.get - counter.set (counterVal + 1) - let filename := tempDir / s!"{Core.SMT.sanitizeFilename obligation.label}_{counterVal}.smt2" - let varsInObligation := ProofObligation.getVars obligation - -- Filter out managed variables (they are emitted as define-fun/declare-fun, not via UF declarations) - let managedNames := (varDefinitions.map (·.name)) ++ (varDeclarations.map (·.name)) - let varsInObligation := varsInObligation.filter fun (v, _) => - !managedNames.contains v.name - -- All variables in ProofObligation must have been typed. - let typedVarsInObligation ← varsInObligation.mapM - (fun (v,ty) => do - match ty with - | .some ty => return (v,LTy.forAll [] ty) - | .none => throw (DiagnosticModel.fromMessage s!"{v} untyped")) - let ans ← - IO.toEIO - (fun e => DiagnosticModel.fromFormat f!"{e}") - (SMT.dischargeObligation options - typedVarsInObligation - obligation.metadata - filename.toString - assumptionTerms obligationTerm ctx satisfiabilityCheck validityCheck - (label := obligation.label) (varDefinitions := varDefinitions) (varDeclarations := varDeclarations)) + let ans ← IO.toEIO + (fun e => DiagnosticModel.fromFormat f!"{e}") + (discharge assumptionTerms obligationTerm ctx satisfiabilityCheck validityCheck + varDefinitions varDeclarations) match ans with | .error solverError => let vcError : VCError := match solverError with @@ -1009,7 +1395,92 @@ def getObligationResult (assumptionTerms : List Term) (obligationTerm : Term) lexprModel := model } return result - +/-- Data needed to dispatch a single obligation to the solver. Produced by the + sequential preprocessing phase and consumed by the (potentially parallel) + solver dispatch phase. -/ +private structure SolverJob where + obligation : ProofObligation Expression + assumptionTerms : List Term + obligationTerm : Term + ctx : SMT.Context + needSatCheck : Bool + needValCheck : Bool + peSatResult? : Option SMT.Result + peValResult? : Option SMT.Result + typedVarsInObligation : List Expression.TypedIdent + varDefs : List VarDefinition := [] + varDecls : List VarDeclaration := [] + +/-- Dispatch a single solver job. Spawns a solver process and reads the result. -/ +private def dispatchSolverJob (job : SolverJob) (p : Program) + (options : VerifyOptions) (counter : IO.Ref Nat) (tempDir : System.FilePath) + (phases : List AbstractedPhase) + (mkDischarge : MkDischargeFn := mkDischargeFn) + (pctx : PipelineContext) + : IO (Except DiagnosticModel VCResult) := do + let discharge := mkDischarge options counter tempDir + job.typedVarsInObligation job.obligation.metadata job.obligation.label pctx + let resultOrErr ← (getObligationResult job.assumptionTerms job.obligationTerm job.ctx + job.obligation p options discharge job.needSatCheck job.needValCheck phases + (varDefinitions := job.varDefs) (varDeclarations := job.varDecls)).toBaseIO + match resultOrErr with + | .error diag => return .error diag + | .ok result => + let result := match result.outcome with + | .ok solverOutcome => + let satResult := job.peSatResult?.getD solverOutcome.satisfiabilityProperty + let valResult := job.peValResult?.getD solverOutcome.validityProperty + { result with outcome := .ok { solverOutcome with + satisfiabilityProperty := satResult, + validityProperty := valResult } } + | .error _ => result + return .ok result + +/-- Dispatch solver jobs using a bounded worker pool. Workers pull from a shared + queue; results returned in original order. -/ +private def dispatchJobsParallel (jobs : List SolverJob) (p : Program) + (options : VerifyOptions) (counter : IO.Ref Nat) (tempDir : System.FilePath) + (phases : List AbstractedPhase) (workers : Nat) + (mkDischarge : MkDischargeFn := mkDischargeFn) + (pctx : PipelineContext) + : IO (List (Option (Except DiagnosticModel VCResult))) := do + let queue ← IO.mkRef (jobs.zipIdx : List (SolverJob × Nat)) + let resultMap ← IO.mkRef ({} : Std.HashMap Nat (Except DiagnosticModel VCResult)) + let shouldStop ← IO.mkRef false + let workerFn : IO Unit := do + let mut running := true + while running do + if ← shouldStop.get then break + let entry ← queue.modifyGet fun q => + match q with + | [] => (none, []) + | hd :: tl => (some hd, tl) + match entry with + | none => running := false + | some (job, idx) => + let result ← dispatchSolverJob job p options counter tempDir phases mkDischarge pctx + resultMap.modify (·.insert idx result) + if options.stopOnFirstError then + match result with + | .ok r => if r.isNotSuccess then shouldStop.set true + | .error _ => shouldStop.set true + let numWorkers := min workers jobs.length + let workerTasks ← (List.range numWorkers).mapM fun _ => + IO.asTask (prio := .dedicated) workerFn + -- Join all tasks before throwing to prevent orphaned processes + let mut firstError : Option IO.Error := none + for task in workerTasks do + match task.get with + | .ok () => pure () + | .error e => if firstError.isNone then firstError := some e + if let some e := firstError then throw e + let rmap ← resultMap.get + let mut results : List (Option (Except DiagnosticModel VCResult)) := [] + for idx in (List.range jobs.length).reverse do + results := rmap[idx]? :: results + return results + +private def verifySingleEnv (oblProgram : Program) (moreFns : @Lambda.Factory CoreLParams := Lambda.Factory.default) (options : VerifyOptions) @@ -1022,23 +1493,23 @@ def verifySingleEnv (oblProgram : Program) -- irrelevant axiom removal to determine which axioms to prune. (axiomProgram : Option Program := .none) (externalPhases : List AbstractedPhase := []) - (corePhases : List AbstractedPhase := coreAbstractedPhases) : + (corePhases : List AbstractedPhase := coreAbstractedPhases) + (mkDischarge : MkDischargeFn := mkDischargeFn) + (pctx : PipelineContext) : EIO DiagnosticModel (VCResults × Statistics) := do -- Build SMT encoding context from the obligations program itself let E ← EIO.ofExcept (Core.buildEnv options oblProgram moreFns (registerCustomFunctions := true) |>.map (·.1)) let p := E.program - let profile := options.profile - -- Extract obligations from the obligations program via ObligationExtraction + -- Extract obligations from the obligations program via ObligationExtraction let obligations ← match Core.ObligationExtraction.extractObligations oblProgram with | .ok obs => pure obs | .error e => .error (DiagnosticModel.fromFormat f!"ObligationExtraction error: {e}") let mut stats : Statistics := ({} : Statistics) |>.increment s!"{Evaluator.Stats.verify_numObligations}" obligations.size let mut results := (#[] : VCResults) - let mut preprocessNs : Nat := 0 - let mut smtEncodeNs : Nat := 0 - let mut solverNs : Nat := 0 - let mut peResolvedCount : Nat := 0 + let mut solverJobs : List SolverJob := [] + let mut solverJobIndices : List Nat := [] + let useParallel := options.parallelWorkers > 1 for obligation in obligations do -- Determine which checks to perform based on metadata or check mode/amount let (satisfiabilityCheck, validityCheck) := @@ -1052,10 +1523,8 @@ def verifySingleEnv (oblProgram : Program) | .deductive, _ => if obligation.property.passWhenUnreachable then (false, true) else (true, false) | .bugFinding, _ => (true, false) - let t0 ← IO.monoNanosNow - let (obligation, peSatResult?, peValResult?) ← preprocessObligation obligation p options satisfiabilityCheck validityCheck axiomCache axiomNames axiomProgram - let t1 ← IO.monoNanosNow - preprocessNs := preprocessNs + (t1 - t0) + let (obligation, peSatResult?, peValResult?) ← pctx.withRepeatedPhase "preprocess" do + preprocessObligation obligation p options satisfiabilityCheck validityCheck axiomCache axiomNames axiomProgram -- If evaluator resolved both checks, we're done, unless we always want to generate SMT queries if not options.alwaysGenerateSMT then if let (some peSat, some peVal) := (peSatResult?, peValResult?) then @@ -1071,7 +1540,6 @@ def verifySingleEnv (oblProgram : Program) let result : VCResult := { obligation, outcome := .ok outcome, verbose := options.verbose, checkLevel := options.checkLevel, checkMode := options.checkMode, lexprModel := [] } results := results.push result - peResolvedCount := peResolvedCount + 1 if result.isFailure || result.isImplementationError || result.isTimeout then if options.verbose >= .debug then let prog := f!"\n\n[DEBUG] Evaluated program:\n{Core.formatProgram p}" @@ -1081,10 +1549,9 @@ def verifySingleEnv (oblProgram : Program) -- Need the solver for at least one check let needSatCheck := satisfiabilityCheck && peSatResult?.isNone let needValCheck := validityCheck && peValResult?.isNone - let t2 ← IO.monoNanosNow - let maybeTerms := ProofObligation.toSMTTerms E obligation { SMT.Context.default with uniqueBoundNames := options.uniqueBoundNames } options.useArrayTheory - let t3 ← IO.monoNanosNow - smtEncodeNs := smtEncodeNs + (t3 - t2) + let maybeTerms ← pctx.withRepeatedPhase "smtEncode" do + let smtCtx := { SMT.Context.default with uniqueBoundNames := options.uniqueBoundNames } + pure (ProofObligation.toSMTTerms E obligation smtCtx options.useArrayTheory) match maybeTerms with | .error err => let result := { obligation, @@ -1100,34 +1567,84 @@ def verifySingleEnv (oblProgram : Program) if options.stopOnFirstError then break | .ok (assumptionTerms, varDefs, varDecls, obligationTerm, ctx, encStats) => stats := stats.merge encStats - let t4 ← IO.monoNanosNow - let result ← getObligationResult assumptionTerms obligationTerm ctx obligation p options - counter tempDir needSatCheck needValCheck (externalPhases ++ corePhases) - (varDefinitions := varDefs) (varDeclarations := varDecls) - let t5 ← IO.monoNanosNow - solverNs := solverNs + (t5 - t4) - -- Merge evaluator results with solver results - let result := match result.outcome with - | .ok solverOutcome => - let satResult := peSatResult?.getD solverOutcome.satisfiabilityProperty - let valResult := peValResult?.getD solverOutcome.validityProperty - { result with outcome := .ok { solverOutcome with - satisfiabilityProperty := satResult, - validityProperty := valResult } } - | .error _ => result - results := results.push result - if result.isNotSuccess then - if options.verbose >= .debug then - let prog := f!"\n\n[DEBUG] Evaluated program:\n{Core.formatProgram p}" - dbg_trace f!"\n\nResult: {result}\n{prog}" - if options.stopOnFirstError then break - if profile then - let _ ← (IO.println s!"[profile] Preprocess obligations: {nsToMs preprocessNs}ms" |>.toBaseIO) - let _ ← (IO.println s!"[profile] SMT encoding: {nsToMs smtEncodeNs}ms" |>.toBaseIO) - let _ ← (IO.println s!"[profile] Solver/file writing: {nsToMs solverNs}ms" |>.toBaseIO) - let _ ← (IO.println s!"[profile] Obligations: {obligations.size} total, {peResolvedCount} resolved by evaluator" |>.toBaseIO) + let varsInObligation := ProofObligation.getVars obligation + -- Filter out managed variables (they are emitted as define-fun/declare-fun, not via UF declarations) + let managedNames := (varDefs.map (·.name)) ++ (varDecls.map (·.name)) + let varsInObligation := varsInObligation.filter fun (v, _) => + !managedNames.contains v.name + let typedVarsInObligation ← varsInObligation.mapM + (fun (v,ty) => do + match ty with + | .some ty => return (v,LTy.forAll [] ty) + | .none => throw (DiagnosticModel.fromMessage s!"{v} untyped")) + if useParallel then + let job : SolverJob := { + obligation, assumptionTerms, obligationTerm, ctx, + needSatCheck, needValCheck, peSatResult?, peValResult?, + typedVarsInObligation, varDefs, varDecls } + solverJobs := job :: solverJobs + solverJobIndices := results.size :: solverJobIndices + results := results.push { obligation, outcome := .error (.encoding "pending parallel dispatch"), + verbose := options.verbose, checkLevel := options.checkLevel, + checkMode := options.checkMode, lexprModel := [] } + else + let discharge := mkDischarge options counter tempDir + typedVarsInObligation obligation.metadata obligation.label pctx + let result ← pctx.withRepeatedPhase "solver" do + getObligationResult assumptionTerms obligationTerm ctx obligation p options + discharge needSatCheck needValCheck (externalPhases ++ corePhases) + (varDefinitions := varDefs) (varDeclarations := varDecls) + -- Merge evaluator results with solver results + let result := match result.outcome with + | .ok solverOutcome => + let satResult := peSatResult?.getD solverOutcome.satisfiabilityProperty + let valResult := peValResult?.getD solverOutcome.validityProperty + { result with outcome := .ok { solverOutcome with + satisfiabilityProperty := satResult, + validityProperty := valResult } } + | .error _ => result + results := results.push result + if result.isNotSuccess then + if options.verbose >= .debug then + let prog := f!"\n\n[DEBUG] Evaluated program:\n{Core.formatProgram p}" + dbg_trace f!"\n\nResult: {result}\n{prog}" + if options.stopOnFirstError then break + -- Phase 2: Parallel solver dispatch + if useParallel && !solverJobs.isEmpty then + let phases := externalPhases ++ corePhases + let jobResults ← IO.toEIO (fun e => DiagnosticModel.fromFormat f!"{e}") + (dispatchJobsParallel solverJobs.reverse p options counter tempDir phases options.parallelWorkers mkDischarge pctx) + let mut firstError : Option DiagnosticModel := none + for (jobResult?, jobIdx) in jobResults.zip solverJobIndices.reverse do + match jobResult? with + | some (.ok result) => + results := results.setIfInBounds jobIdx result + | some (.error diag) => + if firstError.isNone then firstError := some diag + | none => pure () + if let some diag := firstError then throw diag return (results, stats) +/-- Construct the default `CoreSMTSolver` that discharges obligations + sequentially using the batch or incremental SMT-LIB backend (selected + by `options.incremental`). This is the standard solver used by `verify` + when no custom solver is provided. -/ +def mkDefaultCoreSMTSolver + (options : VerifyOptions) + (counter : IO.Ref Nat) (tempDir : System.FilePath) + (axiomCache : Option IrrelevantAxioms.Cache := .none) + (axiomNames : List String := []) + (axiomProgram : Option Program := .none) + (externalPhases : List AbstractedPhase := []) + (corePhases : List AbstractedPhase := coreAbstractedPhases) + (mkDischarge : MkDischargeFn := mkDischargeFn) + (pctx : PipelineContext) : + CoreSMTSolver := + fun moreFns oblProgram => + verifySingleEnv oblProgram moreFns options counter tempDir axiomCache + axiomNames axiomProgram externalPhases corePhases + (mkDischarge := mkDischarge) pctx + /-- Run the Strata Core verification pipeline on a program: transform, type-check, partially evaluate, and discharge proof obligations via SMT. All program-wide transformations that occur before any analyses @@ -1143,12 +1660,21 @@ def verify (program : Program) (externalPhases : List AbstractedPhase := []) (prefixPhases : List PipelinePhase := []) (keepAllFilesPrefix : Option String := none) + (solver : Option CoreSMTSolver := none) + (mkDischarge : MkDischargeFn := mkDischargeFn) + (pipelineCtx : Option PipelineContext := none) : EIO DiagnosticModel VCResults := do let profile := options.profile + let pctx ← match pipelineCtx with + | some ctx => pure ctx + | none => + let mode := if profile then Strata.Pipeline.OutputMode.profile else .quiet + (PipelineContext.create (outputMode := mode) : BaseIO _) + let factory ← EIO.ofExcept (Core.Factory.addFactory moreFns) let pipelinePhases := prefixPhases ++ corePipelinePhases (procs := proceduresToVerify) (options := options) (moreFns := moreFns) let phases := pipelinePhases.map (·.phase) - let (oblProgram, pipelineStats) ← profileStep profile " Program transformations" do + let (oblProgram, pipelineStats) ← pctx.withPhase "programTransformations" do if let some pfx := keepAllFilesPrefix then if let some parent := (System.FilePath.mk pfx).parent then IO.toEIO (fun e => DiagnosticModel.fromFormat f!"{e}") @@ -1156,10 +1682,13 @@ def verify (program : Program) let mut current := program let mut state : Transform.CoreTransformState := { Transform.CoreTransformState.emp with factory := some factory } let mut step := 0 + have : Inhabited (Except Transform.Err Program × Transform.CoreTransformState) := + ⟨(.error default, Transform.CoreTransformState.emp)⟩ for pp in pipelinePhases do - let (result, newState) := Transform.runWith current (fun prog => do - let (_, next) ← pp.transform prog - return next) state + let (result, newState) ← pctx.withRepeatedPhasePure pp.phase.name fun () => + Transform.runWith current (fun prog => do + let (_, next) ← pp.transform prog + return next) state match result with | .ok next => current := next @@ -1173,23 +1702,21 @@ def verify (program : Program) throw e .ok (current, state.statistics) let allStats := pipelineStats - -- Extract axiom names from the original program. The oblProgram (output of - -- toCoreProofObligationProgram) inlines axioms as assume statements but does - -- not preserve axiom declarations, so we use the pre-transform program for - -- axiom identity. let axiomNames := program.decls.filterMap fun decl => match decl with | .ax a _ => some a.name | _ => none - -- Build the axiom relevance cache from the original program (which has - -- axiom declarations). The cache is reused across all obligations. - let axiomCache? ← profileStep profile " Build axiom relevance cache" do + let axiomCache? ← pctx.withPhase "buildAxiomCache" do pure (if options.removeIrrelevantAxioms == .Off then .none else .some (IrrelevantAxioms.Cache.build program)) let counter ← IO.toEIO (fun e => DiagnosticModel.fromFormat f!"{e}") (IO.mkRef 0) - let VCss ← profileStep profile " VC discharge" do + let VCss ← pctx.withPhase "vcDischarge" do if options.checkOnly then pure [] else - pure [← verifySingleEnv oblProgram moreFns options counter tempDir axiomCache? axiomNames (axiomProgram := program) externalPhases phases] + let coreSMTSolver := solver.getD + (mkDefaultCoreSMTSolver options counter tempDir axiomCache? + axiomNames (axiomProgram := program) externalPhases phases + (mkDischarge := mkDischarge) pctx) + pure [← coreSMTSolver moreFns oblProgram] let allStats := VCss.foldl (fun acc (_, s) => acc.merge s) allStats if profile then let _ ← (IO.println allStats.format |>.toBaseIO) @@ -1237,6 +1764,8 @@ def verify (moreFns : @Lambda.Factory Core.CoreLParams := Lambda.Factory.default) (externalPhases : List Core.AbstractedPhase := []) (keepAllFilesPrefix : Option String := none) + (solver : Option Core.CoreSMTSolver := none) + (mkDischarge : Core.MkDischargeFn := Core.mkDischargeFn) : IO Core.VCResults := do let (program, errors) := Core.getProgram env ictx if errors.isEmpty then @@ -1244,7 +1773,9 @@ def verify EIO.toIO (fun dm => IO.Error.userError (toString (dm.format (some ictx.fileMap)))) (Core.verify program tempDir proceduresToVerify options moreFns (externalPhases := externalPhases) - (keepAllFilesPrefix := keepAllFilesPrefix)) + (keepAllFilesPrefix := keepAllFilesPrefix) + (solver := solver) + (mkDischarge := mkDischarge)) match options.vcDirectory with | .none => IO.FS.withTempDir runner diff --git a/Strata/Languages/Laurel/ConstrainedTypeElim.lean b/Strata/Languages/Laurel/ConstrainedTypeElim.lean index 7e86c374a1..dce1a2eef3 100644 --- a/Strata/Languages/Laurel/ConstrainedTypeElim.lean +++ b/Strata/Languages/Laurel/ConstrainedTypeElim.lean @@ -224,7 +224,7 @@ private def mkWitnessProc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : { name := mkId s!"$witness_{ct.name.text}" inputs := [] outputs := [] - body := .Transparent ⟨.Block [witnessInit, assert] none, src⟩ + body := .Opaque [] (some ⟨.Block [witnessInit, assert] none, src⟩) [] preconditions := [] isFunctional := false decreases := none } diff --git a/Strata/Languages/Laurel/HeapParameterization.lean b/Strata/Languages/Laurel/HeapParameterization.lean index fecaf5350c..dae58c3caa 100644 --- a/Strata/Languages/Laurel/HeapParameterization.lean +++ b/Strata/Languages/Laurel/HeapParameterization.lean @@ -9,6 +9,7 @@ public import Strata.Languages.Laurel.Laurel public import Strata.Languages.Laurel.Grammar.AbstractToConcreteTreeTranslator public import Strata.Languages.Laurel.LaurelTypes public import Strata.Languages.Laurel.HeapParameterizationConstants +public import Strata.Languages.Laurel.MapStmtExpr public import Strata.Util.Tactics /- @@ -253,6 +254,10 @@ def resolveQualifiedFieldName (model: SemanticModel) (fieldName : Identifier) : | .unresolved _ => none | _ => dbg_trace s!"BUG: resolveQualifiedFieldName {fieldName} did resolved to something other than a field"; none +private def wrapList (source : Option FileRange) : List StmtExprMd → StmtExprMd + | [single] => single + | many => ⟨.Block many none, source⟩ + /-- Transform an expression, adding heap parameters where needed. - `heapVar`: the name of the heap variable to use @@ -260,22 +265,25 @@ Transform an expression, adding heap parameters where needed. - `valueUsed`: whether the result value of this expression is used (affects optimization of heap-writing calls) -/ def heapTransformExpr (heapVar : Identifier) (model: SemanticModel) (expr : StmtExprMd) (valueUsed : Bool := true) : TransformM StmtExprMd := - recurse expr valueUsed + recurseOne expr valueUsed where - recurse (exprMd : StmtExprMd) (valueUsed : Bool := true) : TransformM StmtExprMd := do + recurseOne (exprMd : StmtExprMd) (valueUsed : Bool := true) : TransformM StmtExprMd := + wrapList exprMd.source <$> recurse exprMd valueUsed + termination_by (sizeOf exprMd, 1) + recurse (exprMd : StmtExprMd) (valueUsed : Bool := true) : TransformM (List StmtExprMd) := do let ⟨expr, source⟩ := exprMd match _h : expr with | .Var (.Field selectTarget fieldName) => do let some qualifiedName := resolveQualifiedFieldName model fieldName - | return ⟨ .Hole, source ⟩ + | return [⟨ .Hole, source ⟩] let valTy := (model.get fieldName).getType let readExpr := ⟨ .StaticCall "readField" [mkMd (.Var (.Local heapVar)), selectTarget, mkMd (.StaticCall qualifiedName [])], source ⟩ -- Unwrap Box: apply the appropriate destructor recordBoxConstructor model valTy.val - return mkMd <| .StaticCall (boxDestructorName model valTy.val) [readExpr] + return [mkMd <| .StaticCall (boxDestructorName model valTy.val) [readExpr]] | .StaticCall callee args => - let args' ← args.mapM (recurse ·) + let args' ← args.mapM (recurseOne ·) let calleeReadsHeap ← readsHeap callee let calleeWritesHeap ← writesHeap callee if calleeWritesHeap then @@ -284,7 +292,7 @@ where let callWithHeap := ⟨ .Assign [mkVarMd (.Local heapVar), mkVarMd (.Declare ⟨freshVar, computeExprType model exprMd⟩)] (⟨ .StaticCall callee (mkMd (.Var (.Local heapVar)) :: args'), source ⟩), source ⟩ - return ⟨ .Block [callWithHeap, mkMd (.Var (.Local freshVar))] none, source ⟩ + return [callWithHeap, mkMd (.Var (.Local freshVar))] else -- Generate throwaway Declare targets for any non-heap outputs let procOutputs := match model.get callee with @@ -294,18 +302,18 @@ where let extraTargets ← procOutputs.mapM fun out => do pure (mkVarMd (.Declare ⟨← freshVarName, out.type⟩)) let allTargets := mkVarMd (.Local heapVar) :: extraTargets - return ⟨ .Assign allTargets (⟨ .StaticCall callee (mkMd (.Var (.Local heapVar)) :: args'), source ⟩), source ⟩ + return [⟨ .Assign allTargets (⟨ .StaticCall callee (mkMd (.Var (.Local heapVar)) :: args'), source ⟩), source ⟩] else if calleeReadsHeap then - return ⟨ .StaticCall callee (mkMd (.Var (.Local heapVar)) :: args'), source ⟩ + return [⟨ .StaticCall callee (mkMd (.Var (.Local heapVar)) :: args'), source ⟩] else - return ⟨ .StaticCall callee args', source ⟩ + return [⟨ .StaticCall callee args', source ⟩] | .InstanceCall callTarget callee args => - let t ← recurse callTarget - let args' ← args.mapM (recurse ·) - return ⟨ .InstanceCall t callee args', source ⟩ + let t ← recurseOne callTarget + let args' ← args.mapM (recurseOne ·) + return [⟨ .InstanceCall t callee args', source ⟩] | .IfThenElse c t e => - let e' ← match e with | some x => some <$> recurse x valueUsed | none => pure none - return ⟨ .IfThenElse (← recurse c) (← recurse t valueUsed) e', source ⟩ + let e' ← match e with | some x => some <$> recurseOne x valueUsed | none => pure none + return [⟨ .IfThenElse (← recurseOne c) (← recurseOne t valueUsed) e', source ⟩] | .Block stmts label => let n := stmts.length let rec processStmts (idx : Nat) (remaining : List StmtExprMd) : TransformM (List StmtExprMd) := do @@ -315,16 +323,16 @@ where let isLast := idx == n - 1 let s' ← recurse s (isLast && valueUsed) let rest' ← processStmts (idx + 1) rest - pure (s' :: rest') - termination_by sizeOf remaining + pure (s' ++ rest') + termination_by (sizeOf remaining, 0) let stmts' ← processStmts 0 stmts - return ⟨ .Block stmts' label, source ⟩ + return [⟨ .Block stmts' label, source ⟩] | .While c invs d b => - let invs' ← invs.mapM (recurse ·) - return ⟨ .While (← recurse c) invs' d (← recurse b false), source ⟩ + let invs' ← invs.mapM (recurseOne ·) + return [⟨ .While (← recurseOne c) invs' d (← recurseOne b false), source ⟩] | .Return v => - let v' ← match v with | some x => some <$> recurse x | none => pure none - return ⟨ .Return v', source ⟩ + let v' ← match v with | some x => some <$> recurseOne x | none => pure none + return [⟨ .Return v', source ⟩] | .Assign targets v => -- Process field targets @@ -338,7 +346,7 @@ where let valTy := (model.get fieldName).getType recordBoxConstructor model valTy.val let freshVar ← freshVarName - let target' ← recurse target + let target' ← recurseOne target let boxedVal := mkMd <| .StaticCall (boxConstructorName model valTy.val) [mkMd (.Var (.Local freshVar))] let updateStmt : StmtExprMd := ⟨ .Assign [mkVarMd (.Local heapVar)] (mkMd (.StaticCall "updateField" [mkMd (.Var (.Local heapVar)), target', mkMd (.StaticCall qualifiedName []), boxedVal])), source ⟩ @@ -350,7 +358,7 @@ where -- Detect calls and add a heap argument if needed let (v', addedHeap) <- match _hv : v.val with | .StaticCall callee args => do - let args' <- args.mapM recurse + let args' <- args.mapM recurseOne let calleeWritesHeap ← writesHeap callee let calleeReadsHeap ← readsHeap callee if calleeWritesHeap then @@ -360,11 +368,11 @@ where else pure (⟨ .StaticCall callee args', v.source ⟩, false) | .InstanceCall callTarget _callee args => do - let _callTarget' ← recurse callTarget - let _args' <- args.mapM recurse + let _callTarget' ← recurseOne callTarget + let _args' <- args.mapM recurseOne pure (⟨ .InstanceCall _callTarget' _callee _args', v.source ⟩, false) | _ => - pure (<- recurse v, false) + pure (<- recurseOne v, false) let allTargets := if addedHeap then ⟨ Variable.Local heapVar, v.source ⟩ :: processedTargets else processedTargets @@ -387,15 +395,12 @@ where else updateStatements pure (newAssign, suffixes) - -- Create a block if necessary - if suffixes.length > 0 then - return ⟨ StmtExpr.Block (newAssign :: suffixes) none, source ⟩ - else - return newAssign + -- Return the list of statements directly (flattened into enclosing block) + return newAssign :: suffixes - | .PureFieldUpdate t f v => return ⟨ .PureFieldUpdate (← recurse t) f (← recurse v), source ⟩ + | .PureFieldUpdate t f v => return [⟨ .PureFieldUpdate (← recurseOne t) f (← recurseOne v), source ⟩] | .PrimitiveOp op args => - let args' ← args.mapM (recurse ·) + let args' ← args.mapM (recurseOne ·) -- For == and != on Composite types, compare refs instead match op, args with | .Eq, [e1, _e2] => @@ -404,58 +409,58 @@ where | .UserDefined _ => let ref1 := mkMd (.StaticCall "Composite..ref!" [args'[0]!]) let ref2 := mkMd (.StaticCall "Composite..ref!" [args'[1]!]) - return ⟨ .PrimitiveOp .Eq [ref1, ref2], source ⟩ - | _ => return ⟨ .PrimitiveOp op args', source ⟩ + return [⟨ .PrimitiveOp .Eq [ref1, ref2], source ⟩] + | _ => return [⟨ .PrimitiveOp op args', source ⟩] | .Neq, [e1, _e2] => let ty := (computeExprType model e1).val match ty with | .UserDefined _ => let ref1 := mkMd (.StaticCall "Composite..ref!" [args'[0]!]) let ref2 := mkMd (.StaticCall "Composite..ref!" [args'[1]!]) - return ⟨ .PrimitiveOp .Neq [ref1, ref2], source ⟩ - | _ => return ⟨ .PrimitiveOp op args', source ⟩ - | _, _ => return ⟨ .PrimitiveOp op args', source ⟩ - | .New _ => return exprMd - | .ReferenceEquals l r => return ⟨ .ReferenceEquals (← recurse l) (← recurse r), source ⟩ + return [⟨ .PrimitiveOp .Neq [ref1, ref2], source ⟩] + | _ => return [⟨ .PrimitiveOp op args', source ⟩] + | _, _ => return [⟨ .PrimitiveOp op args', source ⟩] + | .New _ => return [exprMd] + | .ReferenceEquals l r => return [⟨ .ReferenceEquals (← recurseOne l) (← recurseOne r), source ⟩] | .AsType t ty => - let t' ← recurse t valueUsed + let t' ← recurseOne t valueUsed let isCheck := ⟨ .IsType t' ty, source ⟩ let assertStmt := ⟨ .Assert { condition := isCheck }, source ⟩ - return ⟨ .Block [assertStmt, t'] none, source ⟩ - | .IsType t ty => return ⟨ .IsType (← recurse t) ty, source ⟩ + return [⟨ .Block [assertStmt, t'] none, source ⟩] + | .IsType t ty => return [⟨ .IsType (← recurseOne t) ty, source ⟩] | .Quantifier mode p trigger b => - let trigger' ← trigger.attach.mapM fun ⟨t, _⟩ => recurse t - return ⟨.Quantifier mode p trigger' (← recurse b), source⟩ - | .Assigned n => return ⟨ .Assigned (← recurse n), source ⟩ - | .Old v => return ⟨ .Old (← recurse v), source ⟩ - | .Fresh v => return ⟨ .Fresh (← recurse v), source ⟩ + let trigger' ← trigger.attach.mapM fun ⟨t, _⟩ => recurseOne t + return [⟨.Quantifier mode p trigger' (← recurseOne b), source⟩] + | .Assigned n => return [⟨ .Assigned (← recurseOne n), source ⟩] + | .Old v => return [⟨ .Old (← recurseOne v), source ⟩] + | .Fresh v => return [⟨ .Fresh (← recurseOne v), source ⟩] | .Assert ⟨condExpr, summary⟩ => - return ⟨ .Assert { condition := ← recurse condExpr, summary }, source ⟩ - | .Assume c => return ⟨ .Assume (← recurse c), source ⟩ - | .ProveBy v p => return ⟨ .ProveBy (← recurse v) (← recurse p), source ⟩ - | .ContractOf ty f => return ⟨ .ContractOf ty (← recurse f), source ⟩ - | _ => return exprMd - termination_by sizeOf exprMd - decreasing_by - all_goals simp_wf - all_goals (try have := AstNode.sizeOf_val_lt exprMd) - all_goals (try have := AstNode.sizeOf_val_lt v) - all_goals (try term_by_mem) - all_goals (try (cases exprMd; simp_all; omega)) - -- For field inner expressions in attach-based: - all_goals (try ( - have := List.sizeOf_lt_of_mem ‹_› - have := Variable.sizeOf_field_target_lt_of_eq _htv - omega)) - -- Remaining goals - all_goals ( - cases exprMd with | mk val src mmd => - simp_all - omega) + return [⟨ .Assert { condition := ← recurseOne condExpr, summary }, source ⟩] + | .Assume c => return [⟨ .Assume (← recurseOne c), source ⟩] + | .ProveBy v p => return [⟨ .ProveBy (← recurseOne v) (← recurseOne p), source ⟩] + | .ContractOf ty f => return [⟨ .ContractOf ty (← recurseOne f), source ⟩] + | _ => return [exprMd] + termination_by (sizeOf exprMd, 0) + decreasing_by + all_goals simp_wf + all_goals (try have := AstNode.sizeOf_val_lt exprMd) + all_goals (try have := AstNode.sizeOf_val_lt v) + all_goals (try term_by_mem) + all_goals (try (cases exprMd; simp_all; omega)) + -- For field inner expressions in attach-based: + all_goals (try ( + have := List.sizeOf_lt_of_mem ‹_› + have := Variable.sizeOf_field_target_lt_of_eq _htv + omega)) + -- Remaining goals + all_goals ( + cases exprMd with | mk val src => + simp_all + omega) def heapTransformProcedure (model: SemanticModel) (proc : Procedure) : TransformM Procedure := do - let heapName : Identifier := "$heap" - let heapInName : Identifier := "$heap_in" + let heapName := heapVarName + let heapInName := heapInVarName let readsHeap := (← get).heapReaders.contains proc.name let writesHeap := (← get).heapWriters.contains proc.name diff --git a/Strata/Languages/Laurel/HeapParameterizationConstants.lean b/Strata/Languages/Laurel/HeapParameterizationConstants.lean index 758aa149a1..bfa76a4a59 100644 --- a/Strata/Languages/Laurel/HeapParameterizationConstants.lean +++ b/Strata/Languages/Laurel/HeapParameterizationConstants.lean @@ -15,6 +15,12 @@ namespace Strata.Laurel public section +/-- The name of the heap variable used by the heap parameterization pass. -/ +def heapVarName : Identifier := "$heap" + +/-- The name of the input heap parameter used by the heap parameterization pass. -/ +def heapInVarName : Identifier := "$heap_in" + /-- The Laurel Core prelude defines the heap model types and operations used by the Laurel-to-Core translator. These declarations are expressed diff --git a/Strata/Languages/Laurel/InferHoleTypes.lean b/Strata/Languages/Laurel/InferHoleTypes.lean index d56ad86881..ff80f37c5e 100644 --- a/Strata/Languages/Laurel/InferHoleTypes.lean +++ b/Strata/Languages/Laurel/InferHoleTypes.lean @@ -35,10 +35,20 @@ private def inferComparisonArgType (model : SemanticModel) (args : List StmtExpr args.findSome? (fun a => match a.val with | .Hole _ _ => none | _ => some (computeExprType model a)) |>.getD ⟨ .TInt, source ⟩ -- use Int as a default type for comparisons where both operands are holes -/-- Get the expected type for each argument of a call from the callee's parameter list. -/ +/-- Get the expected type for each argument of a call from the callee's parameter list. + + Auto-generated datatype destructors (`TypeName..fieldName[!]`) and testers + (`TypeName..isCtor`) are unary, taking the datatype itself as their single + input. Their `ResolvedNode` (`.datatypeDestructor` / `.datatypeConstructor`) + carries the resolved type Identifier (with its `uniqueId`), so we can + construct the input `HighType` directly without falling back to textual + decoding of the override name. -/ private def calleeParamTypes (model : SemanticModel) (callee : Identifier) : Option (List HighTypeMd) := match model.get callee with | .staticProcedure proc => some (proc.inputs.map (·.type)) + | .datatypeConstructor typeName _ + | .datatypeDestructor typeName _ => + some [⟨.UserDefined typeName, callee.source⟩] | _ => none inductive InferHoleTypesStats where diff --git a/Strata/Languages/Laurel/LaurelCompilationPipeline.lean b/Strata/Languages/Laurel/LaurelCompilationPipeline.lean index c6984120fe..54b97fdfd9 100644 --- a/Strata/Languages/Laurel/LaurelCompilationPipeline.lean +++ b/Strata/Languages/Laurel/LaurelCompilationPipeline.lean @@ -12,7 +12,6 @@ import Strata.Languages.Laurel.EliminateValueReturns import Strata.Languages.Laurel.ConstrainedTypeElim import Strata.Languages.Laurel.TypeAliasElim import Strata.Languages.Core.Verifier -import Strata.Util.Profile import Strata.Util.Statistics /-! @@ -144,7 +143,8 @@ When `keepAllFilesPrefix` is provided (via the `PipelineM` context), the program state after each named Laurel pass is written to `{prefix}.{n}.{passName}.laurel.st`. -/ -private def runLaurelPasses (options : LaurelTranslateOptions) (program : Program) +private def runLaurelPasses (options : LaurelTranslateOptions) + (pctx : Strata.Pipeline.PipelineContext) (program : Program) : PipelineM (Program × SemanticModel × List DiagnosticModel × Statistics) := do let program := { program with staticProcedures := coreDefinitionsForLaurel.staticProcedures ++ program.staticProcedures, @@ -172,14 +172,16 @@ private def runLaurelPasses (options : LaurelTranslateOptions) (program : Progra let mut allStats : Statistics := {} for pass in laurelPipeline do - let (program', diags, stats) ← profileStep options.profile s!" {pass.name}" do - pure (pass.run program model) + let (program', diags, stats) ← pctx.withPhase pass.name do pure (pass.run program model) program := program' allDiags := allDiags ++ diags allStats := allStats.merge stats -- Run resolve after the pass if needed if pass.needsResolves then let result := resolve program (some model) + let newErrors := result.errors.filter fun e => !resolutionErrors.contains e + if !newErrors.isEmpty then + emit pass.name "laurel.st" program program := result.program model := result.model emit pass.name "laurel.st" program @@ -193,9 +195,13 @@ When `keepAllFilesPrefix` is provided, the program state after each named Laurel-to-Laurel pass is written to `{prefix}.{n}.{passName}.laurel.st`. -/ def translateWithLaurel (options : LaurelTranslateOptions) (program : Program) - : IO TranslateResultWithLaurel := + (pipelineCtx : Option Strata.Pipeline.PipelineContext := none) + : IO TranslateResultWithLaurel := do + let pctx ← match pipelineCtx with + | some ctx => pure ctx + | none => Strata.Pipeline.PipelineContext.create (outputMode := .quiet) runPipelineM options.keepAllFilesPrefix do - let (program, model, passDiags, stats) ← runLaurelPasses options program + let (program, model, passDiags, stats) ← runLaurelPasses options pctx program let ordered := orderProgram program -- This early return is a simple way to protect against duplicative errors. Without this return, diff --git a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean index beb50ad9b0..9e02d9a825 100644 --- a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean +++ b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean @@ -25,6 +25,7 @@ public import Strata.Languages.Laurel.CoreDefinitionsForLaurel public import Strata.Languages.Laurel.CoreGroupingAndOrdering import Strata.DDM.Util.DecimalRat import Strata.DL.Imperative.Stmt +import Strata.Pipeline.Messages import Strata.DL.Imperative.MetaData import Strata.DL.Lambda.LExpr import Strata.Languages.Laurel.Grammar.AbstractToConcreteTreeTranslator @@ -92,7 +93,7 @@ def translateType (ty : HighTypeMd) : TranslateM LMonoTy := do | .TSet elementType => return Core.mapTy (← translateType elementType) LMonoTy.bool | .TMap keyType valueType => return Core.mapTy (← translateType keyType) (← translateType valueType) | .UserDefined name => - match name.uniqueId.bind model.refToDef.get? with + match model.get? name with | some (.compositeType _) => return .tcons "Composite" [] | some (.datatypeDefinition dt) => return .tcons dt.name.text [] | some (.datatypeConstructor typeName _) => return .tcons typeName.text [] @@ -624,7 +625,6 @@ structure LaurelTranslateOptions where inlineFunctionsWhenPossible : Bool := false overflowChecks : Core.OverflowChecks := {} keepAllFilesPrefix : Option String := none - profile : Bool := false instance : Inhabited LaurelTranslateOptions where default := {} diff --git a/Strata/Languages/Laurel/LaurelTypes.lean b/Strata/Languages/Laurel/LaurelTypes.lean index ff07ae5171..9bbdc86a83 100644 --- a/Strata/Languages/Laurel/LaurelTypes.lean +++ b/Strata/Languages/Laurel/LaurelTypes.lean @@ -24,6 +24,7 @@ namespace Strata.Laurel def getCallType (source : Option FileRange) (model : SemanticModel) (callee : Identifier): HighTypeMd := match model.get callee with | .datatypeConstructor t _ => ⟨ .UserDefined t, source ⟩ + | .datatypeDestructor _ fld => fld.type | .parameter p => p.type | .staticProcedure proc => match proc.outputs with | [singleOutput] => singleOutput.type diff --git a/Strata/Languages/Laurel/LiftImperativeExpressions.lean b/Strata/Languages/Laurel/LiftImperativeExpressions.lean index 7c63b6870f..e87b24d480 100644 --- a/Strata/Languages/Laurel/LiftImperativeExpressions.lean +++ b/Strata/Languages/Laurel/LiftImperativeExpressions.lean @@ -174,6 +174,24 @@ def containsAssignment (expr : StmtExprMd) : Bool := decreasing_by all_goals ((try cases x); simp_all; try term_by_mem) +/-- Like containsAssignment but does NOT recurse into Blocks (treats them as opaque). + Used by assert/assume handlers to allow generated Block wrappers through. -/ +def containsBareAssignment (expr : StmtExprMd) : Bool := + match expr with + | AstNode.mk val _ => + match val with + | .Assign .. => true + | .StaticCall _ args => args.attach.any (fun x => containsBareAssignment x.val) + | .PrimitiveOp _ args => args.attach.any (fun x => containsBareAssignment x.val) + | .Block _ _ => false + | .IfThenElse cond th el => + containsBareAssignment cond || containsBareAssignment th || + match el with | some e => containsBareAssignment e | none => false + | _ => false + termination_by expr + decreasing_by + all_goals ((try cases x); simp_all; try term_by_mem) + /-- Check if an expression contains any non-functional procedure calls (recursively). -/ def containsImperativeCall (model : SemanticModel) (expr : StmtExprMd) : Bool := match expr with @@ -345,7 +363,30 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do | .Block stmts labelOption => let newStmts := (← stmts.reverse.mapM transformExpr).reverse - return ⟨ .Block (← onlyKeepSideEffectStmtsAndLast newStmts) labelOption, source⟩ + -- Flatten generated multi-output call wrappers BEFORE onlyKeepSideEffectStmtsAndLast + -- which would drop the multi-target assign. Pattern: [VarDecl, MultiAssign, VarRef]. + match newStmts with + | [decl, assign, last] => + match decl.val, assign.val with + | .Assign [t] _, .Assign targets _ => + match t.val with + | .Declare _ => + if targets.length ≥ 2 then + prepend assign + prepend decl + return last + else + let filtered ← onlyKeepSideEffectStmtsAndLast newStmts + return ⟨ .Block filtered labelOption, source⟩ + | _ => + let filtered ← onlyKeepSideEffectStmtsAndLast newStmts + return ⟨ .Block filtered labelOption, source⟩ + | _, _ => + let filtered ← onlyKeepSideEffectStmtsAndLast newStmts + return ⟨ .Block filtered labelOption, source⟩ + | _ => + let filtered ← onlyKeepSideEffectStmtsAndLast newStmts + return ⟨ .Block filtered labelOption, source⟩ | .Var (.Declare param) => -- If the substitution map has an entry for this variable, it was @@ -372,9 +413,11 @@ def transformStmt (stmt : StmtExprMd) : LiftM (List StmtExprMd) := do | AstNode.mk val source => match val with | .Assert cond => - -- Do not transform assert conditions with assignments — they must be rejected. - -- But nondeterministic holes and imperative calls need to be lifted. - if !containsAssignment cond.condition then + -- Do not transform assert conditions with bare assignments — they are + -- semantic errors that should be rejected downstream. + -- But Blocks with assignments (generated multi-output call wrappers) + -- are handled by the Block case in transformExpr above. + if !containsBareAssignment cond.condition then let seqCond ← transformExpr cond.condition let prepends ← takePrepends modify fun s => { s with subst := [] } @@ -383,9 +426,7 @@ def transformStmt (stmt : StmtExprMd) : LiftM (List StmtExprMd) := do return [stmt] | .Assume cond => - -- Do not transform assume conditions with assignments — they must be rejected. - -- But nondeterministic holes and imperative calls need to be lifted. - if !containsAssignment cond then + if !containsBareAssignment cond then let seqCond ← transformExpr cond let prepends ← takePrepends modify fun s => { s with subst := [] } diff --git a/Strata/Languages/Laurel/ModifiesClauses.lean b/Strata/Languages/Laurel/ModifiesClauses.lean index 20fd01445d..5fb37c60ad 100644 --- a/Strata/Languages/Laurel/ModifiesClauses.lean +++ b/Strata/Languages/Laurel/ModifiesClauses.lean @@ -9,6 +9,7 @@ public import Strata.Languages.Laurel.Laurel public import Strata.Languages.Laurel.LaurelTypes public import Strata.Languages.Core.Verifier public import Strata.Languages.Laurel.Resolution +import Strata.Languages.Laurel.HeapParameterizationConstants /- Modifies clause transformation (Laurel → Laurel). @@ -159,8 +160,8 @@ def transformModifiesClauses (model: SemanticModel) -- modifies * means the procedure can modify anything; no frame condition .ok { proc with body := .Opaque postconds impl [] } else if hasHeapOut proc then - let heapInName : Identifier := "$heap_in" - let heapName : Identifier := "$heap" + let heapInName := heapInVarName + let heapName := heapVarName let frameCondition := buildModifiesEnsures proc model modifiesExprs heapInName heapName let postconds' := match frameCondition with | some frame => postconds ++ [{ condition := frame, summary := "modifies clause" }] diff --git a/Strata/Languages/Laurel/Resolution.lean b/Strata/Languages/Laurel/Resolution.lean index 16bcf1333f..75e6ac1292 100644 --- a/Strata/Languages/Laurel/Resolution.lean +++ b/Strata/Languages/Laurel/Resolution.lean @@ -75,6 +75,7 @@ inductive ResolvedNodeKind where | constrainedType | datatypeDefinition | datatypeConstructor + | datatypeDestructor | typeAlias | constant | quantifierVar @@ -91,6 +92,7 @@ def ResolvedNodeKind.name : ResolvedNodeKind → String | .constrainedType => "constrained type" | .datatypeDefinition => "datatype definition" | .datatypeConstructor => "datatype constructor" + | .datatypeDestructor => "datatype destructor" | .typeAlias => "type alias" | .constant => "constant" | .quantifierVar => "quantifier variable" @@ -116,6 +118,10 @@ inductive ResolvedNode where | datatypeDefinition (ty : DatatypeDefinition) /-- A datatype constructor. -/ | datatypeConstructor (typeName : Identifier) (ctor : DatatypeConstructor) + /-- An auto-generated destructor (or unsafe `!`-destructor) for a datatype field. + `typeName` is the resolved Identifier of the parent datatype (with its + `uniqueId`), and `field` is the underlying constructor parameter. -/ + | datatypeDestructor (typeName : Identifier) (field : Parameter) /-- A type alias. -/ | typeAlias (ty : TypeAlias) /-- A constant. -/ @@ -139,6 +145,7 @@ def ResolvedNode.kind : ResolvedNode → ResolvedNodeKind | .constrainedType .. => .constrainedType | .datatypeDefinition .. => .datatypeDefinition | .datatypeConstructor .. => .datatypeConstructor + | .datatypeDestructor .. => .datatypeDestructor | .typeAlias .. => .typeAlias | .constant .. => .constant | .quantifierVar .. => .quantifierVar @@ -149,29 +156,35 @@ def ResolvedNode.getType (node: ResolvedNode): HighTypeMd := match node with | .parameter p => p.type | .field _ f => f.type | .datatypeConstructor type _ => ⟨ .UserDefined type, none ⟩ + | .datatypeDestructor _ fld => fld.type | .constant c => c.type | .quantifierVar _ type => type | .unresolved source => ⟨ .Unknown, source ⟩ - | _ => dbg_trace s!"SOUND BUG: getType called on {repr node}"; default + | .staticProcedure _ | .instanceProcedure _ _ | .compositeType _ + | .constrainedType _ | .datatypeDefinition _ | .typeAlias _ => ⟨ .Unknown, none ⟩ /-! ## Resolution result -/ structure SemanticModel where nextId: Nat compositeCount: Nat - refToDef: Std.HashMap Nat ResolvedNode + private refToDef: Std.HashMap Nat ResolvedNode deriving Repr +/-- Look up the resolved node for an identifier, returning `none` if the identifier + has no `uniqueId` or is not in the model. -/ +def SemanticModel.get? (model: SemanticModel) (iden: Identifier): Option ResolvedNode := + iden.uniqueId.bind model.refToDef.get? + def SemanticModel.get (model: SemanticModel) (iden: Identifier): ResolvedNode := - match iden.uniqueId with - | some key => (model.refToDef.get? key).getD default - | none => default + (model.get? iden).getD default def SemanticModel.isFunction (model: SemanticModel) (id: Identifier): Bool := match model.get id with | .staticProcedure proc => proc.isFunctional | .parameter _ => true | .datatypeConstructor _ _ => true + | .datatypeDestructor _ _ => true | .constant _ => true | .unresolved _ => true -- functions calls are more permissive, so true avoids possibly incorrect errors | node => @@ -213,6 +226,10 @@ structure ResolveState where /-- When resolving inside an instance procedure, the owning composite type name. Used by `resolveFieldRef` to resolve `self.field` when `self` has type `Any`. -/ instanceTypeName : Option String := none + /-- True when resolving inside an expression where the value is used (e.g., as an + argument to another call or operator). Multi-output calls are only diagnosed + in value context, not in statement position or direct assignment RHS. -/ + inValueContext : Bool := false @[expose] abbrev ResolveM := StateM ResolveState @@ -358,7 +375,10 @@ def resolveStmtExpr (exprMd : StmtExprMd) : ResolveM StmtExprMd := do | AstNode.mk expr source => let val' ← match _: expr with | .IfThenElse cond thenBr elseBr => + let saved := (← get).inValueContext + modify fun s => { s with inValueContext := true } let cond' ← resolveStmtExpr cond + modify fun s => { s with inValueContext := saved } let thenBr' ← resolveStmtExpr thenBr let elseBr' ← elseBr.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) pure (.IfThenElse cond' thenBr' elseBr') @@ -367,7 +387,10 @@ def resolveStmtExpr (exprMd : StmtExprMd) : ResolveM StmtExprMd := do let stmts' ← stmts.mapM resolveStmtExpr pure (.Block stmts' label) | .While cond invs dec body => + let saved := (← get).inValueContext + modify fun s => { s with inValueContext := true } let cond' ← resolveStmtExpr cond + modify fun s => { s with inValueContext := saved } let invs' ← invs.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) let dec' ← dec.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) let body' ← resolveStmtExpr body @@ -436,11 +459,31 @@ def resolveStmtExpr (exprMd : StmtExprMd) : ResolveM StmtExprMd := do pure (.PureFieldUpdate target' fieldName' newVal') | .StaticCall callee args => let callee' ← resolveRef callee source - (expected := #[.parameter, .staticProcedure, .datatypeConstructor, .constant]) + (expected := #[.parameter, .staticProcedure, .datatypeConstructor, .datatypeDestructor, .constant]) + -- Resolve arguments in value context (their results are used as values) + let saved := (← get).inValueContext + modify fun s => { s with inValueContext := true } let args' ← args.mapM resolveStmtExpr + modify fun s => { s with inValueContext := saved } + -- Multi-output procedures must not appear in value context: the extra + -- outputs (e.g. error channels) would be silently discarded. + let s ← get + if s.inValueContext then + let outputCount := match s.scope.get? callee'.text with + | some (_, .staticProcedure proc) => proc.outputs.length + | some (_, .instanceProcedure _ proc) => proc.outputs.length + | _ => 0 + if outputCount > 1 then + let diag := diagnosticFromSource source + s!"Multi-output procedure '{callee'.text}' used in expression position; it returns {outputCount} values but only one can be used here. Use a multi-target assignment instead." + modify fun s => { s with errors := s.errors.push diag } pure (.StaticCall callee' args') | .PrimitiveOp op args => + -- Resolve arguments in value context + let saved := (← get).inValueContext + modify fun s => { s with inValueContext := true } let args' ← args.mapM resolveStmtExpr + modify fun s => { s with inValueContext := saved } pure (.PrimitiveOp op args') | .New ref => let ref' ← resolveRef ref source @@ -482,10 +525,16 @@ def resolveStmtExpr (exprMd : StmtExprMd) : ResolveM StmtExprMd := do let val' ← resolveStmtExpr val pure (.Fresh val') | .Assert ⟨condExpr, summary⟩ => + let saved := (← get).inValueContext + modify fun s => { s with inValueContext := true } let cond' ← resolveStmtExpr condExpr + modify fun s => { s with inValueContext := saved } pure (.Assert { condition := cond', summary }) | .Assume cond => + let saved := (← get).inValueContext + modify fun s => { s with inValueContext := true } let cond' ← resolveStmtExpr cond + modify fun s => { s with inValueContext := saved } pure (.Assume cond') | .ProveBy val proof => let val' ← resolveStmtExpr val @@ -785,7 +834,11 @@ private def collectTypeDefinition (map : Std.HashMap Nat ResolvedNode) (td : Typ dt.constructors.foldl (fun map ctor => let map := register map ctor.name (.datatypeConstructor dt.name ctor) ctor.args.foldl (fun map p => - let map := register map p.name (.parameter p) + -- The constructor parameter's `uniqueId` (set by `resolveTypeDefinition`) + -- is the shared uniqueId of the safe/unsafe destructor scope entries, + -- so registering it here as `.datatypeDestructor` covers calls of the + -- form `TypeName..fieldName` and `TypeName..fieldName!`. + let map := register map p.name (.datatypeDestructor dt.name p) collectHighType map p.type ) map ) map @@ -837,12 +890,19 @@ private def preRegisterTopLevel (program : Program) : ResolveM Unit := do | .Datatype dt => let _ ← defineNameCheckDup dt.name (.datatypeDefinition dt) for ctor in dt.constructors do - _ ← defineNameCheckDup ctor.name (.datatypeConstructor dt.name ctor) (some (dt.testerName ctor)) - let _ ← defineNameCheckDup ctor.name (.datatypeConstructor dt.name ctor) + -- Register the tester override first; the second call reuses the + -- returned Identifier (now carrying a uniqueId) so the unprefixed + -- constructor name and the `TypeName..isCtor` tester name resolve to + -- the same uniqueId, which `buildRefToDef` in turn maps to + -- `.datatypeConstructor`. + let ctorName ← defineNameCheckDup ctor.name (.datatypeConstructor dt.name ctor) (some (dt.testerName ctor)) + let _ ← defineNameCheckDup ctorName (.datatypeConstructor dt.name ctor) for p in ctor.args do - let _ ← defineNameCheckDup p.name (.parameter p) (some (dt.destructorName p)) - -- unsafeDestructorId - let _ ← defineNameCheckDup p.name (.parameter p) (some (dt.unsafeDestructorName p)) + -- Same chaining trick for the safe and unsafe destructor names: both + -- point to the same uniqueId so `IntList..head` and `IntList..head!` + -- resolve to the same `.datatypeDestructor` model entry. + let pName ← defineNameCheckDup p.name (.datatypeDestructor dt.name p) (some (dt.destructorName p)) + let _ ← defineNameCheckDup pName (.datatypeDestructor dt.name p) (some (dt.unsafeDestructorName p)) | .Alias ta => let _ ← defineNameCheckDup ta.name (.typeAlias ta) -- Pre-register constants diff --git a/Strata/Languages/Laurel/TypeHierarchy.lean b/Strata/Languages/Laurel/TypeHierarchy.lean index 411c61b95f..26b72ff23f 100644 --- a/Strata/Languages/Laurel/TypeHierarchy.lean +++ b/Strata/Languages/Laurel/TypeHierarchy.lean @@ -8,6 +8,7 @@ module public import Strata.Languages.Laurel.MapStmtExpr public import Strata.Languages.Laurel.LaurelTypes public import Strata.DL.Imperative.MetaData +import Strata.Languages.Laurel.HeapParameterizationConstants import Strata.Util.Tactics public section @@ -235,7 +236,7 @@ Lower `New name` to a block that: 3. Constructs a `MkComposite(counter, name_TypeTag())` value -/ def lowerNew (name : Identifier) (source : Option FileRange) : THM StmtExprMd := do - let heapVar : Identifier := "$heap" + let heapVar := heapVarName let freshVar ← freshVarName let getCounter := mkMd (.StaticCall "Heap..nextReference!" [mkMd (.Var (.Local heapVar))]) let saveCounter := mkMd (.Assign [mkVarMd (.Declare ⟨freshVar, ⟨.TInt, none⟩⟩)] getCounter) diff --git a/Strata/Languages/Python/OverloadTable.lean b/Strata/Languages/Python/OverloadTable.lean index 4375358854..6edb232236 100644 --- a/Strata/Languages/Python/OverloadTable.lean +++ b/Strata/Languages/Python/OverloadTable.lean @@ -5,36 +5,12 @@ -/ module public import Std.Data.HashMap.Basic +public import Strata.Languages.Python.PythonIdent public section namespace Strata.Python -/-- -A fully-qualified Python identifier consisting of a module path and a name. -For example, `typing.List` has module "typing" and name "List". --/ -structure PythonIdent where - pythonModule : String - name : String - deriving DecidableEq, Hashable, Inhabited, Ord, Repr - -namespace PythonIdent - -protected def ofString (s : String) : Option PythonIdent := - match s.revFind? '.' with - | none => none - | some idx => - some { - pythonModule := s.extract s.startPos idx - name := s.extract idx.next! s.endPos - } - -instance : ToString PythonIdent where - toString i := s!"{i.pythonModule}.{i.name}" - -end PythonIdent - /-- All overloads for a single function name: maps a string literal argument value to the return type (`PythonIdent`), together with @@ -44,10 +20,9 @@ N.B. Current limitations: dispatch is always on the first positional argument or the matching keyword argument, and only string literal values are extracted. -/ structure FunctionOverloads where /-- Expected keyword argument name for dispatch (from the PySpec). -/ - paramName : Option String := none + paramName : String /-- Literal value → return type. -/ entries : Std.HashMap String PythonIdent := {} -deriving Inhabited /-- Find the dispatch argument value from positional or keyword arguments. Prefers the first positional arg; falls back to the keyword arg whose @@ -56,10 +31,15 @@ def FunctionOverloads.findDispatchArg (fo : FunctionOverloads) (positionalArgs : Array α) (kwargPairs : List (Option String × α)) : Option α := - if h : positionalArgs.size > 0 then some positionalArgs[0] - else fo.paramName.bind fun expected => + if h : positionalArgs.size > 0 then + some positionalArgs[0] + else + let expected := fo.paramName kwargPairs.findSome? fun (name?, value) => - if name? == some expected then some value else none + if name? == some expected then + some value + else + none /-- Dispatch table: function name → its overloads. -/ @[expose] abbrev OverloadTable := Std.HashMap String FunctionOverloads diff --git a/Strata/Languages/Python/PySpecPipeline.lean b/Strata/Languages/Python/PySpecPipeline.lean index 21129ebe39..5c05cfc37a 100644 --- a/Strata/Languages/Python/PySpecPipeline.lean +++ b/Strata/Languages/Python/PySpecPipeline.lean @@ -5,20 +5,21 @@ -/ module +import all Strata.DDM.Util.String import Strata.Languages.Laurel.FilterPrelude import Strata.Languages.Laurel.LaurelCompilationPipeline -public import Strata.Util.Statistics public import Strata.Languages.Python.PythonToLaurel import Strata.Languages.Python.ReadPython import Strata.Languages.Python.PythonLaurelCorePrelude import Strata.Languages.Python.PythonRuntimeLaurelPart import Strata.Languages.Python.Specs import Strata.Languages.Python.Specs.DDM -public import Strata.Languages.Python.Specs.Error import Strata.Languages.Python.Specs.IdentifyOverloads +import Strata.Languages.Python.Specs.MessageKind import Strata.Languages.Python.Specs.ToLaurel +public import Strata.Pipeline.Context import Strata.Util.DecideProp -import Strata.Util.Profile +public import Strata.Util.Statistics /-! ## PySpec Pipeline @@ -29,6 +30,7 @@ and translates through to Core for verification. namespace Strata +open Pipeline (emitMessage emitMessageAndAbort) open Python (OverloadTable) /-! ### Types -/ @@ -42,8 +44,7 @@ public structure PySpecLaurelResult where typeAliases : Std.HashMap String String := {} /-- Classes whose spec is considered exhaustive (lists all methods). -/ exhaustiveClasses : Std.HashSet String := {} - /-- Warnings collected during PySpec translation. -/ - pyspecWarnings : Array Python.Specs.SpecError := #[] + deriving Inhabited /-! ### Private Helpers -/ @@ -94,8 +95,8 @@ private def funcDeclToFunctionDecl (name : String) (args : Python.Specs.ArgDecls Handles both top-level functions and class methods. Strips `self` from class methods and expands `**kwargs` TypedDict fields. -/ private def extractFunctionSignatures (sigs : Array Python.Specs.Signature) - (modulePrefix : String) : Except String (Array Python.PythonFunctionDecl) := do - let funcPrefix := if modulePrefix.isEmpty then "" else modulePrefix ++ "_" + (moduleName : Python.ModuleName) : Except String (Array Python.PythonFunctionDecl) := do + let funcPrefix := moduleName.toString (sep := "_") ++ "_" let mut result : Array Python.PythonFunctionDecl := #[] for sig in sigs do match sig with @@ -117,12 +118,11 @@ private def extractFunctionSignatures (sigs : Array Python.Specs.Signature) private def mergeOverloads (old new : OverloadTable) : OverloadTable := new.fold (init := old) fun o name n => - o.alter name fun s => - let existing := s.getD {} - some { paramName := existing.paramName <|> n.paramName - entries := existing.entries.union n.entries } - - + o.alter name fun + | some existing => + some { paramName := existing.paramName + entries := existing.entries.union n.entries } + | none => some n /-- Read PySpec Ion files and collect their Laurel declarations and overload tables into a single combined result. Each Ion file is parsed and translated @@ -130,105 +130,140 @@ private def mergeOverloads (old new : OverloadTable) : OverloadTable := accumulated into one `Laurel.Program`, and overload dispatch entries are merged into a single table. - Each entry is a `(modulePrefix, ionPath)` pair. The `modulePrefix` is used + Each entry is a `(moduleName, ionPath)` pair. The module name is used to namespace all generated Laurel names (e.g., `"servicelib_Storage"` for module `servicelib.Storage`). -/ -public def buildPySpecLaurel (pyspecEntries : Array (String × String)) - (overloads : OverloadTable) : EIO String PySpecLaurelResult := do +private def buildPySpecLaurelM (pyspecEntries : Array (Python.ModuleName × String)) + (overloads : OverloadTable) : Pipeline.PipelineM PySpecLaurelResult := do let mut combinedProcedures : Array (Laurel.Procedure × String) := #[] let mut combinedTypes : Array (Laurel.TypeDefinition × String) := #[] let mut allOverloads := overloads let mut funcSigs : Array Python.PythonFunctionDecl := #[] let mut allTypeAliases : Std.HashMap String String := {} let mut allExhaustiveClasses : Std.HashSet String := {} - let mut allWarnings : Array Python.Specs.SpecError := #[] - for (modulePrefix, ionPath) in pyspecEntries do + for (moduleName, ionPath) in pyspecEntries do let ionFile : System.FilePath := ionPath let sigs ← match ← Python.Specs.readDDM ionFile |>.toBaseIO with | .ok t => pure t - | .error msg => throw s!"Could not read {ionFile}: {msg}" + | .error msg => + emitMessageAndAbort .pySpecReadError msg (file := ionFile) let { program, errors, overloads, typeAliases, exhaustiveClasses } := - Python.Specs.ToLaurel.signaturesToLaurel ionPath sigs modulePrefix - allWarnings := allWarnings ++ errors + Python.Specs.ToLaurel.signaturesToLaurel ionPath sigs moduleName + for msg in errors do + Pipeline.addMessage msg + if msg.kind.impact.isFatal then throw () allOverloads := mergeOverloads allOverloads overloads allTypeAliases := typeAliases.fold (init := allTypeAliases) fun m k v => m.insert k v allExhaustiveClasses := exhaustiveClasses.fold (init := allExhaustiveClasses) fun s name => s.insert name - match extractFunctionSignatures sigs modulePrefix with + match extractFunctionSignatures sigs moduleName with | .ok fs => funcSigs := funcSigs ++ fs - | .error msg => throw msg + | .error msg => + emitMessageAndAbort .functionSignatureError msg (file := ionFile) for td in program.types do combinedTypes := combinedTypes.push (td, ionPath) for proc in program.staticProcedures do combinedProcedures := combinedProcedures.push (proc, ionPath) - -- Reject name collisions across PySpec files + -- Reject name collisions across PySpec files (first-wins) let mut seenTypes : Std.HashMap String String := {} + let mut dedupedTypes : Array (Laurel.TypeDefinition × String) := #[] for (td, srcFile) in combinedTypes do - let name := match td with - | .Composite ct => ct.name.text - | .Constrained ct => ct.name.text - | .Datatype dt => dt.name.text - | .Alias ta => ta.name.text - match seenTypes.get? name with + let ident := match td with + | .Composite ct => ct.name + | .Constrained ct => ct.name + | .Datatype dt => dt.name + | .Alias ta => ta.name + match seenTypes.get? ident.text with | some prevFile => - throw s!"PySpec type name collision: '{name}' defined in both {prevFile} and {srcFile}" - | none => pure () - seenTypes := seenTypes.insert name srcFile + emitMessageAndAbort .typeNameCollision s!"'{ident.text}' already defined in {prevFile}" + (file := srcFile) (loc := ident.source.map (·.range) |>.getD default) + | none => + seenTypes := seenTypes.insert ident.text srcFile + dedupedTypes := dedupedTypes.push (td, srcFile) let mut seenProcs : Std.HashMap String String := {} + let mut dedupedProcs : Array (Laurel.Procedure × String) := #[] for (proc, srcFile) in combinedProcedures do - match seenProcs.get? proc.name.text with + match seenProcs[proc.name.text]? with | some prevFile => - throw s!"PySpec procedure name collision: '{proc.name.text}' defined in both {prevFile} and {srcFile}" - | none => pure () - seenProcs := seenProcs.insert proc.name.text srcFile + emitMessageAndAbort .procedureNameCollision s!"'{proc.name.text}' already defined in {prevFile}" + (file := srcFile) (loc := proc.name.source.map (·.range) |>.getD default) + | none => + seenProcs := seenProcs.insert proc.name.text srcFile + dedupedProcs := dedupedProcs.push (proc, srcFile) let combinedLaurel : Laurel.Program := { - staticProcedures := Strata.Python.pythonRuntimeLaurelPart.staticProcedures ++ combinedProcedures.toList.map Prod.fst + staticProcedures := Strata.Python.pythonRuntimeLaurelPart.staticProcedures ++ dedupedProcs.toList.map Prod.fst staticFields := [] - types := Strata.Python.pythonRuntimeLaurelPart.types ++ combinedTypes.toList.map Prod.fst + types := Strata.Python.pythonRuntimeLaurelPart.types ++ dedupedTypes.toList.map Prod.fst constants := [] } return { laurelProgram := combinedLaurel, overloads := allOverloads functionSignatures := funcSigs.toList, typeAliases := allTypeAliases - exhaustiveClasses := allExhaustiveClasses - pyspecWarnings := allWarnings } + exhaustiveClasses := allExhaustiveClasses } + +/-- Read PySpec Ion files and collect their Laurel declarations and overload + tables into a single combined result. -/ +public def buildPySpecLaurel + (ctx : Pipeline.PipelineContext) + (pyspecEntries : Array (Python.ModuleName × String)) + (overloads : OverloadTable) : EIO Unit PySpecLaurelResult := + buildPySpecLaurelM pyspecEntries overloads |>.run ctx /-- Read dispatch Ion files and merge their overload tables. -/ -public def readDispatchOverloads - (dispatchPaths : Array String) - : EIO String (OverloadTable × Array Python.Specs.SpecError) := do +private def readDispatchOverloadsM + (dispatchPaths : Array String) : Pipeline.PipelineM OverloadTable := do let mut tbl : OverloadTable := {} - let mut allWarnings : Array Python.Specs.SpecError := #[] for dispatchPath in dispatchPaths do let ionFile : System.FilePath := dispatchPath let sigs ← match ← Python.Specs.readDDM ionFile |>.toBaseIO with | .ok t => pure t - | .error msg => throw s!"Could not read dispatch file {ionFile}: {msg}" - let (overloads, errors) := - Python.Specs.ToLaurel.extractOverloads dispatchPath sigs - allWarnings := allWarnings ++ errors + | .error msg => + emitMessageAndAbort .pySpecReadError msg (file := ionFile) + let (overloads, errors) := Python.Specs.ToLaurel.extractOverloads dispatchPath sigs + for msg in errors do + Pipeline.addMessage msg + if msg.kind.impact.isFatal then throw () tbl := mergeOverloads tbl overloads - return (tbl, allWarnings) - -/-- Resolve a module name to a `(modulePrefix, ionPath)` pair for - `buildPySpecLaurel`. Returns `none` if the pyspec file is not found. -/ -private def resolveModuleEntry (modName : String) (specDir : System.FilePath) - (quiet : Bool := false) - : EIO String (Option (String × String)) := do - match Python.Specs.ModuleName.ofString modName with - | .error _ => - if !quiet then - let _ ← IO.eprintln - s!"warning: invalid module name '{modName}', skipping" |>.toBaseIO - return none - | .ok mod => - match ← mod.specIonPath specDir with - | some specPath => - let pfx := "_".intercalate mod.components.toList - return some (pfx, specPath.toString) - | none => return none + return tbl + +/-- Read dispatch Ion files and merge their overload tables. -/ +public def readDispatchOverloads + (ctx : Pipeline.PipelineContext) + (dispatchPaths : Array String) : EIO Unit OverloadTable := + readDispatchOverloadsM dispatchPaths |>.run ctx + +/-- Resolve a parsed module name to its .ion path. + Returns `none` if the file is not found on disk. -/ +private def resolveModuleEntry (mod : Python.ModuleName) (specDir : System.FilePath) + : Pipeline.PipelineM (Option (Python.ModuleName × String)) := do + match ← mod.specIonPath specDir with + | some specPath => + return some (mod, specPath.toString) + | none => return none + +/-- Resolve already-parsed module names that must exist. Fatal on missing file. -/ +private def resolveModuleNames (modules : Array Python.ModuleName) (specDir : System.FilePath) + : Pipeline.PipelineM (Array (Python.ModuleName × String)) := do + let mut entries : Array (Python.ModuleName × String) := #[] + for mod in modules do + let some entry ← resolveModuleEntry mod specDir + | emitMessageAndAbort .missingPySpecModule + s!"PySpec module '{mod}' not found in {specDir}" (file := specDir) + entries := entries.push entry + return entries + +/-- Resolve module name strings that must exist. Fatal on invalid name or missing file. -/ +private def resolveModules (modules : Array String) (specDir : System.FilePath) + : Pipeline.PipelineM (Array (Python.ModuleName × String)) := do + let mut parsed : Array Python.ModuleName := #[] + for modName in modules do + let some mod := Python.ModuleName.ofString? modName + | emitMessageAndAbort .invalidModuleName s!"invalid module name '{modName}'" (file := specDir) + parsed := parsed.push mod + resolveModuleNames parsed specDir + /-- Build dispatch overload table, auto-resolve pyspec files from the program AST, and return combined Laurel declarations @@ -243,40 +278,24 @@ public def resolveAndBuildLaurelPrelude (pyspecModules : Array String) (stmts : Array (Python.stmt SourceRange)) (specDir : System.FilePath := ".") - (quiet : Bool := false) - : EIO String PySpecLaurelResult := do - -- Resolve dispatch module names to Ion paths - let mut dispatchPaths : Array String := #[] - for modName in dispatchModules do - match ← resolveModuleEntry modName specDir (quiet := quiet) with - | some (_, path) => dispatchPaths := dispatchPaths.push path - | none => throw s!"Dispatch module '{modName}' not found in {specDir}" - let (dispatchOverloads, dispatchWarnings) ← readDispatchOverloads dispatchPaths + : Pipeline.PipelineM PySpecLaurelResult := do + -- Dispatch modules (fatal on invalid name or missing file) + let dispatchEntries ← resolveModules dispatchModules specDir + let dispatchPaths := dispatchEntries.map (·.2) + let dispatchOverloads ← readDispatchOverloadsM dispatchPaths let resolveState := Python.Specs.IdentifyOverloads.resolveOverloads dispatchOverloads stmts - if !quiet then - for w in resolveState.warnings do - let _ ← IO.eprintln s!"warning: {w}" |>.toBaseIO - -- Auto-resolve pyspec modules from overload table - let mut autoSpecEntries : Array (String × String) := #[] - if dispatchModules.size > 0 then - let resolvedMods := resolveState.modules.toArray.qsort (· < ·) - for modName in resolvedMods do - match ← resolveModuleEntry modName specDir (quiet := quiet) with - | some entry => autoSpecEntries := autoSpecEntries.push entry - | none => - if !quiet then - let _ ← IO.eprintln - s!"warning: auto-resolved pyspec not found for module '{modName}'" |>.toBaseIO - -- Resolve explicit pyspec module names - let mut explicitEntries : Array (String × String) := #[] - for modName in pyspecModules do - match ← resolveModuleEntry modName specDir (quiet := quiet) with - | some entry => explicitEntries := explicitEntries.push entry - | none => throw s!"PySpec module '{modName}' not found in {specDir}" - let allSpecEntries := autoSpecEntries ++ explicitEntries - let result ← buildPySpecLaurel allSpecEntries dispatchOverloads - return { result with pyspecWarnings := dispatchWarnings ++ result.pyspecWarnings } + for w in resolveState.warnings do + emitMessage .overloadResolveWarning w (file := specDir) + -- Auto-resolved from dispatch overload table + let autoSpecEntries ← + if dispatchModules.size > 0 then + let resolvedMods := resolveState.modules.toArray.qsort (· < ·) + resolveModuleNames resolvedMods specDir + else pure #[] + -- Explicit pyspec modules (fatal on invalid name or missing file) + let explicitEntries ← resolveModules pyspecModules specDir + buildPySpecLaurelM (autoSpecEntries ++ explicitEntries) dispatchOverloads /-! ### Pipeline Steps -/ @@ -376,35 +395,20 @@ public def splitProcNames (prog : Core.Program) Laurel pass is written to `{prefix}.{n}.{passName}.laurel.st`. -/ public def translateCombinedLaurelWithLowered (combined : Laurel.Program) (keepAllFilesPrefix : Option String := none) - (profile : Bool := false) + (pipelineCtx : Option Pipeline.PipelineContext := none) : IO (Option Core.Program × List DiagnosticModel × Laurel.Program × Statistics) := do let (coreOption, errors, lowered, stats) ← - Laurel.translateWithLaurel { inlineFunctionsWhenPossible := true, keepAllFilesPrefix, profile } combined + Laurel.translateWithLaurel { inlineFunctionsWhenPossible := true, keepAllFilesPrefix } + combined (pipelineCtx := pipelineCtx) return (coreOption.map appendCorePartOfRuntime, errors, lowered, stats) /-- Translate a combined Laurel program to Core and prepend the full runtime prelude. -/ public def translateCombinedLaurel (combined : Laurel.Program) - (profile : Bool := false) : IO (Option Core.Program × List DiagnosticModel) := do - let (coreOption, errors, _, _) ← translateCombinedLaurelWithLowered combined (profile := profile) + let (coreOption, errors, _, _) ← translateCombinedLaurelWithLowered combined return (coreOption, errors) -/-- Errors from the pyAnalyzeLaurel pipeline. -/ -public inductive PipelineError where - /-- The Python source contains invalid code (bad method name, wrong arguments, etc.). -/ - | userCode (range : SourceRange := .none) (msg : String) - /-- The pipeline encountered a Python construct it intentionally does not yet support. -/ - | knownLimitation (msg : String) - /-- An unexpected failure — likely a bug in the tool itself. -/ - | internal (msg : String) - -public instance : ToString PipelineError where - toString - | .userCode _ msg => s!"User code error: {msg}" - | .knownLimitation msg => s!"Known limitation: {msg}" - | .internal msg => msg - /-- Run the pyAnalyzeLaurel pipeline: read a Python Ion program, resolve overloads from dispatch files, load PySpec declarations, translate Python to Laurel, and combine with PySpec Laurel. @@ -416,73 +420,46 @@ public instance : ToString PipelineError where Laurel metadata (useful when the Ion file was generated from a `.py` source and you want line numbers to refer to the original). - When `warningSummaryFile` is provided, writes a JSON summary of - PySpec translation warnings to that path. The summary is written - after pyspec resolution, before Python-to-Laurel translation, so - it is produced even when later pipeline stages fail. -/ + Runs in `PipelineM`. Fatal errors abort via `emitMessageAndAbort`. -/ public def pythonAndSpecToLaurel (pythonIonPath : String) (dispatchModules : Array String := #[]) (pyspecModules : Array String := #[]) (sourcePath : Option String := none) (specDir : System.FilePath := ".") - (profile : Bool := false) - (quiet : Bool := false) - (warningSummaryFile : Option String := none) - : EIO PipelineError Laurel.Program := do - let stmts ← profileStep profile "Read Python Ion" do + : Pipeline.PipelineM Laurel.Program := do + let stmts ← Pipeline.withPhase "readPythonIon" do match ← Python.readPythonStrata pythonIonPath |>.toBaseIO with | .ok r => pure r - | .error msg => throw (.internal msg) + | .error msg => + emitMessageAndAbort (file := pythonIonPath) .pySpecParsingError msg - let result ← profileStep profile "Resolve and build Laurel prelude" do - match ← resolveAndBuildLaurelPrelude dispatchModules pyspecModules stmts specDir (quiet := quiet) |>.toBaseIO with - | .ok r => pure r - | .error msg => throw (.internal msg) - - -- Print and write PySpec warnings before later stages can fail - let pyspecWarnings := result.pyspecWarnings - if pyspecWarnings.size > 0 && !quiet then - let _ ← IO.eprintln - s!"{pyspecWarnings.size} PySpec translation warning(s):" |>.toBaseIO - for err in pyspecWarnings do - let _ ← IO.eprintln s!" {err.file}: {err.kind.phase}.{err.kind.category}: {err.message}" |>.toBaseIO - if let some warnFile := warningSummaryFile then - let counts : Std.HashMap _ Nat := pyspecWarnings.foldl (init := {}) - fun acc err => acc.alter err.kind fun mv => some (mv.getD 0 + 1) - let entries := counts.toArray.qsort fun ⟨a, _⟩ ⟨b, _⟩ => a < b - let jsonEntries : Array Lean.Json := entries.map fun (kind, count) => - Lean.Json.mkObj [ - ("phase", .str kind.phase), - ("category", .str kind.category), - ("count", .num count) - ] - let json := Lean.Json.mkObj [ - ("pyspecWarningSummary", .arr jsonEntries), - ("totalWarnings", .num pyspecWarnings.size) - ] - match ← IO.FS.writeFile warnFile (json.compress ++ "\n") |>.toBaseIO with - | .ok () => pure () - | .error e => - let _ ← IO.eprintln s!"warning: failed to write warning summary to {warnFile}: {e}" |>.toBaseIO + let result ← Pipeline.withPhase "resolveAndBuildPrelude" do + resolveAndBuildLaurelPrelude dispatchModules pyspecModules stmts specDir let preludeInfo := buildPreludeInfo result - let metadataPath := sourcePath.getD pythonIonPath - let (laurelProgram, _ctx) ← profileStep profile "Translate Python to Laurel" do + + let (laurelProgram, _ctx) ← match Python.pythonToLaurel preludeInfo stmts metadataPath result.overloads with - | .error (.userPythonError range msg) => throw (.userCode range msg) + | .error (.userPythonError range msg) => + emitMessageAndAbort (file := sourcePath.getD pythonIonPath) (loc := range) + .laurelLoweringUserError msg | .error (.unsupportedConstruct msg ast) => - throw (.knownLimitation s!"Unsupported construct: {msg}\nAST: {ast}") - | .error e => throw (.internal s!"Python to Laurel translation failed: {e}") + emitMessageAndAbort (file := sourcePath.getD pythonIonPath) + .laurelLoweringNotImpl s!"Unsupported construct: {msg}\nAST: {ast}" + | .error e => + emitMessageAndAbort (file := sourcePath.getD pythonIonPath) + .laurelLoweringError s!"Python to Laurel translation failed: {e}" | .ok result => pure result - let filteredPrelude ← profileStep profile "Filter prelude" do + let filteredPrelude ← match Laurel.filterPrelude result.laurelProgram laurelProgram with | .ok prog => pure prog - | .error msg => throw (.internal msg) + | .error msg => + emitMessageAndAbort (file := sourcePath.getD pythonIonPath) .laurelLoweringError msg - profileStep profile "Combine PySpec and user Laurel" do - return combinePySpecLaurel filteredPrelude laurelProgram + let combined := combinePySpecLaurel filteredPrelude laurelProgram + return combined end Strata diff --git a/Strata/Languages/Python/PythonIdent.lean b/Strata/Languages/Python/PythonIdent.lean new file mode 100644 index 0000000000..58960fe207 --- /dev/null +++ b/Strata/Languages/Python/PythonIdent.lean @@ -0,0 +1,251 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ +module + +public section +namespace Strata.Python + +abbrev ModuleComponent := { nm : String // nm ≠ "" } + +def ModuleComponent.ofString (s : String) (h : s ≠ "" := by decide) : ModuleComponent := ⟨s, h⟩ + +/-- +A Python module name split into its dot-separated components. +For example, `typing.List` has components `["typing", "List"]`. +The size constraint ensures at least one component exists. +-/ +structure ModuleName where + mk :: + components : Array ModuleComponent + components_size_pos : components.size > 0 + deriving DecidableEq, Hashable, Ord + +namespace ModuleName + +instance : LT ModuleName where + lt a b := compare a b == .lt + +instance (a b : ModuleName) : Decidable (a < b) := + inferInstanceAs (Decidable (compare a b == .lt)) + +instance : Inhabited ModuleName where + default := private { + components := #[⟨"placeholder", by simp⟩], + components_size_pos := by simp + } + + +private +def ofSliceAux (mod : String.Slice) (a : Array ModuleComponent) (start cur : mod.Pos) : Option ModuleName := + if h : cur.IsAtEnd then + let r := mod.extract start cur + if ne : r = "" then + .none + else + some { + components := a.push ⟨r, ne⟩ + components_size_pos := by simp + } + else + let c := cur.get h + if _ : c = '.' then + let r := mod.extract start cur + if ne : r = "" then + .none + else + let next := cur.next h + ofSliceAux mod (a.push ⟨r, ne⟩) next next + else + let next := cur.next h + ofSliceAux mod a start next + termination_by cur + +/-- Parses a dot-separated module name string (e.g., "typing.List"). -/ +def ofSlice? (mod : String.Slice) : Option ModuleName := + ofSliceAux mod #[] mod.startPos mod.startPos + +/-- Parses a dot-separated module name string (e.g., "typing.List"). -/ +def ofString? (mod : String) : Option ModuleName := + ofSlice? mod.toSlice + +/-- +Parses a dot-separated module name string (e.g., "typing.List") +and panics if parsing fails. +-/ +def ofString! (mod : String) : ModuleName := + match ofString? mod with + | .some m => m + | .none => panic! s!"Malformed module {mod}" -- nopanic:ok + +/-- Convert a module name to a string, joining components with `sep` (default `"."`). -/ +protected def toString (m : ModuleName) (sep : String := ".") : String := + let p : m.components.size > 0 := m.components_size_pos + m.components.foldl (init := m.components[0]) (start := 1) fun a m => + a ++ sep ++ m.val + +instance : ToString ModuleName where + toString m := m.toString + +/-- The last component of the module name. E.g., `"typing.List"` → `"List"`. -/ +def back (m : ModuleName) : String := + let p := m.components_size_pos + m.components.back.val + +/-- Drop the last `n` components. Returns `none` if fewer than `n` components remain. -/ +def parent (m : ModuleName) (n : Nat := 1) : Option ModuleName := + let c := m.components.take (m.components.size - n) + if h : c.size > 0 then + some ⟨c, h⟩ + else + none + +#guard (ModuleName.ofString! "a.b.c" |>.parent).map ModuleName.toString = some "a.b" +#guard (ModuleName.ofString! "a" |>.parent) = none +#guard (ModuleName.ofString! "a.b.c" |>.parent (n := 2)).map ModuleName.toString = some "a" +#guard (ModuleName.ofString! "a.b.c" |>.parent (n := 3)) = none + +/-- Create a single-component module name. -/ +def ofComponent (c : ModuleComponent) : ModuleName := + ⟨#[c], by simp⟩ + +/-- Append a component to the end. E.g., `"typing".push "List"` → `"typing.List"`. -/ +def push (m : ModuleName) (c : ModuleComponent) : ModuleName := + ⟨m.components.push c, by simp⟩ + +/-- Concatenate two module names. E.g., `"a.b" ++ "c.d"` → `"a.b.c.d"`. -/ +def append (m1 m2 : ModuleName) : ModuleName := + ⟨m1.components ++ m2.components, by have p := m1.components_size_pos; grind⟩ + +instance : HAppend ModuleName ModuleName ModuleName where + hAppend := append + +instance : Repr ModuleName where + reprPrec m prec := Repr.addAppParen s!"Strata.ModuleName.ofString! {m}" prec + +/-- +Result of parsing a Python file path into a module name. +`isInit` indicates whether the file is a package `__init__.py`. +-/ +structure ModuleOfPath where + moduleName : ModuleName + isInit : Bool + deriving DecidableEq, Repr + +namespace ModuleOfPath + +/-- The package prefix for relative import resolution. + For `__init__.py` files, this is the full module name's components. + For regular files, this is the module name minus the last component (may be empty). -/ +def modulePrefix (m : ModuleOfPath) : Array ModuleComponent := + if m.isInit then + m.moduleName.components + else + m.moduleName.components.take (m.moduleName.components.size - 1) + +/-- Package prefix as a ModuleName, or none for top-level modules. -/ +def modulePrefix? (m : ModuleOfPath) : Option ModuleName := + if m.isInit then some m.moduleName + else m.moduleName.parent + +end ModuleOfPath + +/-- Derive a `ModuleName` from a file path relative to a search root. + + Examples: + "module.py" → .ok { moduleName := "module", isInit := false } + "service/__init__.py" → .ok { moduleName := "service", isInit := true } + "service/sub/module.py" → .ok { moduleName := "service.sub.module", isInit := false } + "service/sub/__init__.py" → .ok { moduleName := "service.sub", isInit := true } + + Fails if the path doesn't end in `.py` or would produce an empty component. -/ +def ofRelativePath (relativePath : System.FilePath) : Except String ModuleOfPath := do + let parts := relativePath.components |>.toArray + let some last := parts.back? + | throw s!"empty path: {relativePath}" + let (stems, isInit) ← + if last == "__init__.py" then + pure (parts.pop, true) + else if last.endsWith ".py" then + pure (parts.pop.push (last.dropEnd 3 |>.toString), false) + else + throw s!"path does not end in .py: {relativePath}" + let components : Array ModuleComponent ← stems.mapM fun s => + if h : s = "" then + throw s!"empty component in path: {relativePath}" + else + return ⟨s, h⟩ + if h : components.size > 0 then + .ok { moduleName := ⟨components, h⟩, isInit } + else + throw s!"no module components in path: {relativePath}" + +private def testOfRelativePath (path : String) (expectedMod : String) (expectedInit : Bool) : Bool := + match ofRelativePath path with + | .ok info => info.moduleName.toString == expectedMod && info.isInit == expectedInit + | .error _ => false + +#guard testOfRelativePath "module.py" "module" false +#guard testOfRelativePath "service/__init__.py" "service" true +#guard testOfRelativePath "service/sub/module.py" "service.sub.module" false +#guard testOfRelativePath "service/sub/__init__.py" "service.sub" true +#guard ofRelativePath "readme.txt" |>.isOk |>.not +#guard ofRelativePath "__init__.py" |>.isOk |>.not + +#guard (ModuleName.ofString! "a.b.c").toString = "a.b.c" +#guard (ModuleName.ofString! "a").toString = "a" +#guard ModuleName.ofString? "" = none +#guard ModuleName.ofString? "." = none +#guard ModuleName.ofString? "a." = none +#guard ModuleName.ofString? ".a" = none +#guard ModuleName.ofString? "a..b" = none +#guard (ModuleName.ofString! "a.b.c").back = "c" +#guard (ModuleName.ofComponent ⟨"x", by decide⟩).back = "x" +#guard ((ModuleName.ofString! "a") ++ (ModuleName.ofString! "b.c")).toString = "a.b.c" + +end ModuleName + +/-- +A fully-qualified Python identifier consisting of a module path and a name. +For example, `typing.List` has module "typing" and name "List". +-/ +structure PythonIdent where + mkRaw :: + pythonModule : ModuleName + name : String + deriving DecidableEq, Hashable, Ord, Repr + +namespace PythonIdent + +instance : Inhabited PythonIdent where + default := { + pythonModule := default + name := "default" + } + +/-- Construct from a single-component module name. Compile-time error if `mod` is empty. -/ +def ofComponent (mod : String) (name : String) (h : mod ≠ "" := by decide) : PythonIdent := + { pythonModule := .ofComponent ⟨mod, h⟩, name } + +protected def ofString (s : String) : Option PythonIdent := do + let idx ← s.revFind? '.' + let m ← ModuleName.ofString? (s.extract s.startPos idx) + let next ← idx.next? + some { + pythonModule := m + name := s.extract next s.endPos + } + +/-- Convert to a string, joining module components and name with `sep` (default `"."`). -/ +protected def toString (i : PythonIdent) (sep : String := ".") : String := + i.pythonModule.toString sep ++ sep ++ i.name + +instance : ToString PythonIdent where + toString := PythonIdent.toString + +end PythonIdent + +end Strata.Python +end diff --git a/Strata/Languages/Python/PythonToLaurel.lean b/Strata/Languages/Python/PythonToLaurel.lean index 0310fb473f..0c1a030899 100644 --- a/Strata/Languages/Python/PythonToLaurel.lean +++ b/Strata/Languages/Python/PythonToLaurel.lean @@ -473,12 +473,12 @@ def resolveDispatch (ctx : TranslationContext) | some fnOverloads => let kwPairs := kwords.map Python.keyword.nameAndValue let some firstArg := fnOverloads.findDispatchArg args kwPairs - | let msg := match fnOverloads.paramName, kwPairs.filterMap (·.1) with - | some expected, provided@(_ :: _) => + | let msg := match kwPairs.filterMap (·.1) with + | provided@(_ :: _) => s!"Dispatched function '{funcName}' called with wrong \ - keyword argument, expected '{expected}' but got \ + keyword argument, expected '{fnOverloads.paramName}' but got \ '{String.intercalate "', '" provided}'" - | _, _ => + | _ => s!"Dispatched function '{funcName}' called with no \ arguments (expected a string literal first argument)" throw (.typeError msg) @@ -489,12 +489,7 @@ def resolveDispatch (ctx : TranslationContext) let suffix := if fnOverloads.entries.size > 2 then s!" ... ({fnOverloads.entries.size} total)" else "" throwUserError range s!"'{funcName}' called with unknown string \"{s.val}\"; known services: {knownServices}{suffix}" - let className := - if ident.pythonModule.isEmpty then - ident.name - else - ident.pythonModule.replace "." "_" ++ "_" ++ ident.name - return some className + return some <| ident.toString (sep := "_") | _ => return none /-! ## Expression Translation -/ @@ -1336,6 +1331,78 @@ def freeVarExpr (name: String) := mkStmtExprMd (.Var (.Local name)) def maybeExceptVar := freeVarMd "maybe_except" def nullcall_var := freeVarMd "nullcall_ret" +/-- Walk an expression tree and extract any nested multi-output procedure calls + into preceding multi-target assignments. Returns (preamble, rewritten expr). + Uses a mutable counter for unique variable names. -/ +def extractMultiOutputCalls (ctx : TranslationContext) (e : StmtExprMd) + : StateM Nat (List StmtExprMd × StmtExprMd) := do + match _h : e.val with + | .StaticCall callee args => + if withException ctx callee.text then + -- Multi-output call: extract into a temp assignment and add exception check + let n ← get + set (n + 1) + let varName := s!"$mo_{n}" + let varDecl := mkVarDeclInit varName AnyTy AnyNone + let assign := mkStmtExprMdWithLoc (StmtExpr.Assign + [mkVariableMd (.Local varName), maybeExceptVar] + (mkStmtExprMdWithLoc (.StaticCall callee args) e.source)) e.source + let varRef := mkStmtExprMdWithLoc (StmtExpr.Var (.Local varName)) e.source + return ([varDecl, assign], varRef) + else + -- Recurse into arguments + let results ← args.attach.mapM fun ⟨arg, _⟩ => extractMultiOutputCalls ctx arg + let preamble := (results.map (fun (pre, _) => pre)).flatten + let newArgs := results.map (·.2) + if preamble.isEmpty then + return ([], e) + else + return (preamble, mkStmtExprMdWithLoc (.StaticCall callee.text newArgs) e.source) + | .PrimitiveOp op args => + let results ← args.attach.mapM fun ⟨arg, _⟩ => extractMultiOutputCalls ctx arg + let preamble := (results.map (fun (pre, _) => pre)).flatten + let newArgs := results.map (·.2) + if preamble.isEmpty then + return ([], e) + else + return (preamble, mkStmtExprMdWithLoc (.PrimitiveOp op newArgs) e.source) + | .IfThenElse cond thenBr elseBr => + let (preCond, cond') ← extractMultiOutputCalls ctx cond + let (preThen, then') ← extractMultiOutputCalls ctx thenBr + let preElse ← elseBr.attach.mapM fun ⟨br, _⟩ => extractMultiOutputCalls ctx br + let thenExpr := + if preThen.isEmpty then + then' + else + mkStmtExprMdWithLoc (.Block (preThen ++ [then']) none) thenBr.source + let elseExpr := preElse.map fun (pre, else') => + if pre.isEmpty then + else' + else + mkStmtExprMdWithLoc (.Block (pre ++ [else']) none) else'.source + let anyRewrite := !preCond.isEmpty || !preThen.isEmpty || + preElse.any (fun (pre, _) => !pre.isEmpty) + if anyRewrite then + return (preCond, mkStmtExprMdWithLoc + (.IfThenElse cond' thenExpr elseExpr) e.source) + else + return ([], e) + | _ => return ([], e) +termination_by sizeOf e +decreasing_by + all_goals simp_wf + all_goals (try have := AstNode.sizeOf_val_lt e) + all_goals (try term_by_mem) + all_goals (cases e; simp_all; omega) + +/-- Translate an expression and extract any nested multi-output calls into + preceding statements. -/ +def translateExprExtractingCalls (ctx : TranslationContext) (e : Python.expr SourceRange) + (counter : Nat) : Except TranslationError (List StmtExprMd × StmtExprMd × Nat) := do + let expr ← translateExpr ctx e + let ((preamble, expr'), cnt) := (extractMultiOutputCalls ctx expr).run counter + return (preamble, expr', cnt) + partial def translateAssign (ctx : TranslationContext) (lhs: Python.expr SourceRange) (annotation: Option (Python.expr SourceRange) ) @@ -1371,7 +1438,18 @@ partial def translateAssign (ctx : TranslationContext) | .Attribute _ (.Name _ name _) _ _ => name.val == "self" && ctx.currentClassName.isSome | _ => false let rhsCtx := if isSelfFieldAssign then {ctx with suppressDispatch := true} else ctx - let rhs_trans ← translateExpr rhsCtx rhs + let extractionSeed := + if rhs.ann.isNone then + -- Fallback: hash the expression text to get a unique-enough seed + let text := pyExprToString lhs ++ " <- " ++ pyExprToString rhs + text.foldl (fun acc ch => acc * 131 + ch.toNat) 0 + else + -- Use byte offset directly — globally unique per source position + rhs.ann.start.byteIdx + let (moExtracts, rhs_trans, _) ← translateExprExtractingCalls rhsCtx rhs extractionSeed + -- Use the statement's source location for extracted assignments so that + -- diagnostics (e.g. requires checks) report the statement position. + let moExtracts := moExtracts.map fun s => ⟨s.val, source⟩ -- When an unmodeled call produces a Hole, also havoc maybe_except since -- the call is a black box that could throw any exception. let rhsIsCall := match rhs with | .Call _ _ _ _ => true | _ => false @@ -1439,7 +1517,7 @@ partial def translateAssign (ctx : TranslationContext) {newctx with variableTypes:= newctx.variableTypes ++ [(n.val, className.text)]} | _=> newctx if n.val ∈ newctx.variableTypes.unzip.1 then - return (newctx, assignStmts, true) + return (newctx, moExtracts ++ assignStmts, true) else let inferType ← inferExprType ctx rhs let type := match annotation with @@ -1451,7 +1529,7 @@ partial def translateAssign (ctx : TranslationContext) if isKnownType ctx annStr then annStr else inferType let initStmt := mkVarDeclInit n.val AnyTy AnyNone newctx := {ctx with variableTypes:=(n.val, type)::ctx.variableTypes} - return (newctx, initStmt :: assignStmts, true) + return (newctx, moExtracts ++ (initStmt :: assignStmts), true) | .Subscript _ _ _ _ => match getSubscriptList lhs with | target :: slices => @@ -1460,7 +1538,7 @@ partial def translateAssign (ctx : TranslationContext) let source := sourceRangeToSource ctx.filePath lhs.toAst.ann let anySetsExpr := mkStmtExprMdWithLoc (StmtExpr.StaticCall "Any_sets!" [ListAny_mk slices, target, rhs_trans]) source let assignStmts := [mkStmtExprMdWithLoc (StmtExpr.Assign [← stmtExprToVar target] anySetsExpr) source] - return (ctx,assignStmts, false) + return (ctx, moExtracts ++ assignStmts, false) | _ => throw (.internalError "Invalid Subscript Expr") | .Attribute _ obj attr _ => match obj with @@ -1480,11 +1558,11 @@ partial def translateAssign (ctx : TranslationContext) else pure rhs_trans | none => pure rhs_trans let assignStmt := mkStmtExprMdWithLoc (StmtExpr.Assign [fieldAccess] rhs') source - return (ctx, [assignStmt], true) + return (ctx, moExtracts ++ [assignStmt], true) else let targetExpr ← translateExpr ctx lhs -- This will handle self.field via translateExpr let assignStmt := mkStmtExprMdWithLoc (StmtExpr.Assign [← stmtExprToVar targetExpr] rhs_trans) source - return (ctx, [assignStmt], true) + return (ctx, moExtracts ++ [assignStmt], true) | _ => throw (.unsupportedConstruct "Assignment targets not yet supported" (toString (repr lhs))) | _ => throw (.unsupportedConstruct "Assignment targets not yet supported" (toString (repr lhs))) @@ -1587,13 +1665,16 @@ partial def getExceptionAssertions (ctx : TranslationContext) (e : StmtExprMd) : mkExceptionCheckAssert mbe s!"Check {funcName} exception" /-- Check whether an expression tree contains a `StaticCall` to a user-defined - function (procedure). Such calls are disallowed in pure contexts (e.g. - assert bodies), so exception-check assertions that embed them must first - extract the expression into a temporary variable. See issue #1000. -/ + function (procedure) or a multi-output prelude procedure. Such calls are + disallowed in pure contexts (e.g. assert bodies), so exception-check + assertions that embed them must first extract the expression into a + temporary variable. See issue #1000. -/ partial def containsUserCall (ctx : TranslationContext) (e : StmtExprMd) : Bool := match e.val with | .StaticCall callee args => - callee.text ∈ ctx.userFunctions || args.any (containsUserCall ctx) + callee.text ∈ ctx.userFunctions || + withException ctx callee.text || + args.any (containsUserCall ctx) | .PrimitiveOp _ args => args.any (containsUserCall ctx) | .IfThenElse cond thenBranch elseBranch => containsUserCall ctx cond || containsUserCall ctx thenBranch || @@ -1621,10 +1702,26 @@ def withExceptionChecks (ctx : TranslationContext) (result : TranslationContext × List StmtExprMd) : TranslationContext × List StmtExprMd := let (newctx, stmts) := result - let rhs_exprs := stmts.flatMap fun s => - match s.val with | .Assign _ value => [value] | _ => [] + -- Generate exception checks for the last assignment's RHS. + -- Find the last Assign in the list (there may be trailing type assertions). + let lastAssignIdx := stmts.reverse.findIdx? fun s => + match s.val with | .Assign _ _ => true | _ => false + let rhs_exprs := match lastAssignIdx with + | some revIdx => + let idx := stmts.length - 1 - revIdx + match stmts[idx]!.val with | .Assign _ value => [value] | _ => [] + | none => [] let exceptionCheck := rhs_exprs.flatMap $ getExceptionAssertions ctx - (newctx, exceptionCheck ++ stmts) + if exceptionCheck.isEmpty then + (newctx, stmts) + else + match lastAssignIdx with + | some revIdx => + let idx := stmts.length - 1 - revIdx + let before := stmts.take idx + let rest := stmts.drop idx + (newctx, before ++ exceptionCheck ++ rest) + | none => (newctx, exceptionCheck ++ stmts) mutual @@ -2752,10 +2849,7 @@ def pythonToLaurel (info : PreludeInfo) let overloadCompositeType := Std.HashSet.ofList $ (overloadTable.values.flatMap (·.entries.values)).map fun ident => - if ident.pythonModule.isEmpty then - ident.name - else - ident.pythonModule ++ "_" ++ ident.name + ident.toString (sep := "_") let mut compositeTypeNames := info.compositeTypes.union overloadCompositeType -- FIRST PASS: Collect all class definitions and field type info diff --git a/Strata/Languages/Python/Specs.lean b/Strata/Languages/Python/Specs.lean index 00ec5f786d..92663d1706 100644 --- a/Strata/Languages/Python/Specs.lean +++ b/Strata/Languages/Python/Specs.lean @@ -8,132 +8,52 @@ module import Strata.DDM.Format import all Strata.DDM.Util.Fin import Strata.Languages.Python.ReadPython -import Strata.Languages.Python.Specs.DDM +import Strata.Languages.Python.Specs.DDM public import Strata.Languages.Python.Specs.Decls -import Strata.Languages.Python.Specs.Error +import Strata.Languages.Python.Specs.MessageKind +import Strata.Pipeline.Messages import Strata.Util.DecideProp -namespace Strata.Python.Specs - -/-- Type class for monads that support PySpec error and warning reporting. -/ -public class PySpecMClass (m : Type → Type) where - /-- Report an error at a specific source location. -/ - specError (loc : SourceRange) (message : String) : m Unit - /-- Report a warning at a specific source location. -/ - specWarning (loc : SourceRange) (message : String) : m Unit - /-- Run an action and check if any new errors were reported. -/ - runChecked {α} (act : m α) : m (Bool × α) - /-- Run an action and return `true` if no new errors or warnings were reported. -/ - runNoWarn {α} (act : m α) : m (Bool × α) - -open PySpecMClass (specError specWarning runChecked runNoWarn) - -/-- String identifier for event types. -/ -public abbrev EventType := String - -/-- Event type for module imports. -/ -def importEvent : EventType := "import" - -/-- -Log message for event type if enabled in the given event set. -Output format: `[event]: message` --/ -public def baseLogEvent (events : Std.HashSet EventType) - (event : EventType) (message : String) : BaseIO Unit := do - if event ∈ events then - let _ ← IO.eprintln s!"[{event}]: {message}" |>.toBaseIO - pure () - -/-- -Creates `PythonToStrataOptions` from an event set. - -Enables `logPerf` when `"perf"` is present. --/ -def PythonToStrataOptions.ofEventSet (events : Std.HashSet EventType) : PythonToStrataOptions where - logPerf := events.contains "perf" - -/-- -A Python module name split into its dot-separated components. -For example, `typing.List` has components `["typing", "List"]`. -The size constraint ensures at least one component exists. --/ -public structure ModuleName where - components : Array String - componentsSizePos : components.size > 0 - -namespace ModuleName - -def ofStringAux (mod : String) (a : Array String) (start cur : mod.Pos) : Except String ModuleName := - if h : cur.IsAtEnd then - let r := mod.extract start cur - pure { - components := a.push r - componentsSizePos := by simp - } - else - let c := cur.get h - if _ : c = '.' then - let r := mod.extract start cur - let next := cur.next h - ofStringAux mod (a.push r) next next - else - let next := cur.next h - ofStringAux mod a start next - termination_by cur - -/-- Parses a dot-separated module name string (e.g., "typing.List"). -/ -public def ofString (mod : String) : Except String ModuleName := - ofStringAux mod #[] mod.startPos mod.startPos - -public instance : ToString ModuleName where - toString m := - let p : m.components.size > 0 := m.componentsSizePos - m.components.foldl (init := m.components[0]) (start := 1) (s!"{.}.{.}") +namespace Strata.Python.ModuleName public def foldlDirs {α} (mod : ModuleName) (init : α) (f : α → String → α) : α := - mod.components.foldl (init := init) (stop := mod.components.size - 1) f + mod.components.foldl (init := init) (stop := mod.components.size - 1) fun a c => f a c.val def foldlMDirs {α m} [Monad m] (mod : ModuleName) (init : α) (f : α → String → m α) : m α := do - mod.components.foldlM (init := init) (stop := mod.components.size - 1) f - -def fileRoot (mod : ModuleName) : String := - let p := mod.componentsSizePos - mod.components.back + mod.components.foldlM (init := init) (stop := mod.components.size - 1) fun a c => f a c.val /-- Locate the Python source file for a module within `searchPath`. Navigates subdirectories for intermediate components, then looks for `{leaf}.py`. Falls back to `{leaf}/__init__.py` for packages. -Returns `(filePath, modulePrefix)` where `modulePrefix` is the array -of package components for resolving relative imports. For `__init__.py` -packages this is all components; for regular files it is all but the last. +Returns `(filePath, isInit)` where `isInit` indicates whether the resolved +file is a package `__init__.py`. -/ public def findInPath (mod : ModuleName) (searchPath : System.FilePath) - : EIO String (System.FilePath × Array String) := do + : EIO String (System.FilePath × Bool) := do let findComponent path comp := do let newPath := path / comp if !(← newPath.isDir) then throw s!"Directory {newPath} not found" return newPath let dir ← mod.foldlMDirs (init := searchPath) findComponent - let file := dir / s!"{mod.fileRoot}.py" + let file := dir / s!"{mod.back}.py" if let .ok md ← file.metadata |>.toBaseIO then if md.type != .file then throw s!"{file} is not a regular file." - let modulePrefix := mod.components.toSubarray (stop := mod.components.size - 1) |>.toArray - return (file, modulePrefix) + return (file, false) -- Fall back to __init__.py for packages (directories) - let pkgDir := dir / mod.fileRoot + let pkgDir := dir / mod.back let initFile := pkgDir / "__init__.py" if let .ok md ← initFile.metadata |>.toBaseIO then if md.type != .file then throw s!"{initFile} is not a regular file." - return (initFile, mod.components) + return (initFile, true) -- Fail both throw s!"{file} not found (also no {initFile})." /-- Generates the output filename for a module's spec file. -/ -public def strataFileName (mod : ModuleName) : String := s!"{mod.fileRoot}.pyspec.st.ion" +public def strataFileName (mod : ModuleName) : String := s!"{mod.back}.pyspec.st.ion" /-- Resolve a module name to a PySpec Ion file path under `specDir`. Tries the canonical path first (`specDir/servicelib/Storage.pyspec.st.ion`), @@ -141,47 +61,61 @@ public def strataFileName (mod : ModuleName) : String := s!"{mod.fileRoot}.pyspe for package modules. Returns `none` if neither exists. -/ public def specIonPath (mod : ModuleName) (specDir : System.FilePath) : BaseIO (Option System.FilePath) := do - let canonical := mod.foldlDirs (init := specDir) (· / ·) / mod.strataFileName + let modPath := mod.foldlDirs (init := specDir) (· / ·) + let canonical := modPath / mod.strataFileName if ← canonical.pathExists then return some canonical -- Fall back to __init__ layout for package modules - let initPath := mod.foldlDirs (init := specDir) (· / ·) / mod.fileRoot / "__init__.pyspec.st.ion" + let initPath := modPath / mod.back / "__init__.pyspec.st.ion" if ← initPath.pathExists then return some initPath return none -/-- Derive a ModuleName and its search root from a Python source file path. - For regular files, the root is the parent directory. - For package init files (`__init__.py`), the module name is the parent - directory name and the root is the grandparent. - In both cases, `findInPath mod root` resolves back to the original file. -/ -public def ofFile (pythonFile : System.FilePath) - : Except String (ModuleName × System.FilePath) := do - let (stem, root) := - if pythonFile.fileName == some "__init__.py" then - (pythonFile.parent >>= (·.fileName), pythonFile.parent >>= (·.parent)) - else - (pythonFile.fileStem, pythonFile.parent) - let some s := stem | .error s!"Cannot derive module name from {pythonFile}" - let some r := root | .error s!"Cannot derive search root from {pythonFile}" - if s.contains '.' then - .error s!"File stem '{s}' contains '.'; expected a simple module name (from {pythonFile})" - pure (← ofString s, r) - --- Unit tests for ofFile -private def testOfFile (path expectedMod expectedRoot : String) : Bool := - match ofFile path with - | .ok (mod, root) => toString mod == expectedMod && root.toString == expectedRoot - | .error _ => false - -#guard testOfFile "path/to/module.py" "module" "path/to" -#guard testOfFile "path/to/service/__init__.py" "service" "path/to" -#guard testOfFile "./module.py" "module" "." --- Bare filenames without a directory context are rejected -#guard match ofFile "module.py" with | .error _ => true | .ok _ => false -#guard match ofFile "__init__.py" with | .error _ => true | .ok _ => false --- Dotted file stems are rejected (would be silently split by ofString) -#guard match ofFile "path/to/foo.bar.py" with | .error _ => true | .ok _ => false - -end ModuleName + +def mkIdent (mod : ModuleName) (name : String) : PythonIdent := + { pythonModule := mod, name } + +end Strata.Python.ModuleName + +open Strata.Pipeline + +namespace Strata.Python.Specs + +/-- Type class for monads that support PySpec error and warning reporting. -/ +public class PySpecMClass (m : Type → Type) where + /-- Report an error at a specific source location. -/ + specError (loc : SourceRange) (message : String) : m Unit + /-- Report a warning at a specific source location. -/ + specWarning (loc : SourceRange) (message : String) : m Unit + /-- Run an action and check if any new errors were reported. -/ + runChecked {α} (act : m α) : m (Bool × α) + /-- Run an action and return `true` if no new errors or warnings were reported. -/ + runNoWarn {α} (act : m α) : m (Bool × α) + +open PySpecMClass (specError specWarning runChecked runNoWarn) + + +/-- String identifier for event types. -/ +public abbrev EventType := String + +/-- Event type for module imports. -/ +def importEvent : EventType := "import" + +/-- +Log message for event type if enabled in the given event set. +Output format: `[event]: message` +-/ +public def baseLogEvent (events : Std.HashSet EventType) + (event : EventType) (message : String) : BaseIO Unit := do + if event ∈ events then + let _ ← IO.eprintln s!"[{event}]: {message}" |>.toBaseIO + pure () + +/-- +Creates `PythonToStrataOptions` from an event set. + +Enables `logPerf` when `"perf"` is present. +-/ +def PythonToStrataOptions.ofEventSet (events : Std.HashSet EventType) : PythonToStrataOptions where + logPerf := events.contains "perf" inductive SpecValue | boolConst (b : Bool) @@ -210,7 +144,7 @@ structure TypeDecl where Map from Python identifiers to their type specifications. -/ structure TypeSignature where - rank : Std.HashMap String (Option (Std.HashMap String SpecValue)) + rank : Std.HashMap ModuleName (Option (Std.HashMap String SpecValue)) namespace TypeSignature @@ -222,7 +156,7 @@ def ofList (l : List TypeDecl) : TypeSignature where | .some none => .some none | .some (some m) => m |>.insert d.ident.name d.value -def insert (sig : TypeSignature) (name : String) (m : Option (Std.HashMap String SpecValue)) := +def insert (sig : TypeSignature) (name : ModuleName) (m : Option (Std.HashMap String SpecValue)) := { sig with rank := sig.rank.insert name m } end TypeSignature @@ -281,11 +215,11 @@ structure PySpecContext where strataDir : System.FilePath /-- Root directory for module resolution. Stays constant across nested imports. -/ baseSearchPath : System.FilePath - /-- Package prefix components for resolving relative imports to absolute names. - For `__init__.py` modules, this is all components (e.g., `#["boto3"]`). - For regular modules, this is all but the last (e.g., `#["boto3"]` for `boto3.client`). - Empty for top-level modules with no package. -/ - currentModulePrefix : Array String + /-- Package prefix for resolving relative imports. + For `__init__.py` modules, this is the full module name. + For regular modules, this is the parent (e.g., `boto3` for `boto3.client`). + `none` for top-level modules with no package. -/ + currentModulePrefix : Option ModuleName /-- Ref to file map registry for source-location error reporting. -/ fileMapsRef : IO.Ref FileMaps /-- Python module name for the current file (e.g., "boto3.dynamodb"). @@ -293,18 +227,18 @@ structure PySpecContext where currentModule : ModuleName /-- Resolve a module name to a file path, registering the file's FileMap - for source-location error reporting. Returns `(filePath, modulePrefix)` - where `modulePrefix` is the package prefix for resolving relative imports. -/ + for source-location error reporting. Returns `(filePath, isInit)` + where `isInit` indicates whether the file is a package `__init__.py`. -/ def PySpecContext.readModule (ctx : PySpecContext) (mod : ModuleName) - : EIO String (System.FilePath × Array String) := do - let (pythonPath, modulePrefix) ← mod.findInPath ctx.baseSearchPath + : EIO String (System.FilePath × Bool) := do + let (pythonPath, isInit) ← mod.findInPath ctx.baseSearchPath baseLogEvent ctx.eventSet "findFile" s!"Found {mod} as {pythonPath}" match ← IO.FS.readFile pythonPath |>.toBaseIO with | .ok contents => let fm := Lean.FileMap.ofString contents ctx.fileMapsRef.modify fun m => m.insert pythonPath fm - pure (pythonPath, modulePrefix) + pure (pythonPath, isInit) | .error msg => throw s!"Could not read file {pythonPath}: {msg}" @@ -321,8 +255,8 @@ def preludeAtoms : List (String × PythonIdent) := [ structure PySpecState where typeSigs : TypeSignature := preludeSig - errors : Array SpecError - warnings : Array SpecError + errors : Array PipelineMessage + warnings : Array PipelineMessage /-- This maps global identifiers to their value. -/ @@ -351,11 +285,13 @@ private def hasOverloadDecorator /-- Should we skip the given top-level name? -/ def shouldSkip (name : String) : PySpecM Bool := do let ctx ← read - let nameIdent := { pythonModule := toString ctx.currentModule, name } + let nameIdent := ctx.currentModule.mkIdent name return nameIdent ∈ ctx.skipNames +private def pySpecParsingPhase : Phase := Phase.base "pySpecParsing" + def specErrorAt (file : System.FilePath) (loc : SourceRange) (message : String) : PySpecM Unit := do - let e : SpecError := { file, loc, kind := .pySpecParsingError, message } + let e : PipelineMessage := { file, loc, phase := pySpecParsingPhase, kind := .pySpecParsingError, message } modify fun s => { s with errors := s.errors.push e } instance : PySpecMClass PySpecM where @@ -363,7 +299,7 @@ instance : PySpecMClass PySpecM where specErrorAt (←read).pythonFile loc message specWarning loc message := do let file := (←read).pythonFile - let w : SpecError := { file, loc, kind := .pySpecParsingWarning, message } + let w : PipelineMessage := { file, loc, phase := pySpecParsingPhase, kind := .pySpecParsingWarning, message } modify fun s => { s with warnings := s.warnings.push w } runChecked act := do let cnt := (←get).errors.size @@ -434,8 +370,7 @@ def valueAsType (loc : SourceRange) (v : SpecValue) : PySpecM SpecType := do return tp | _ => recordTypeRef loc val - let mod := toString (← read).currentModule - let pyIdent : PythonIdent := { pythonModule := mod, name := val } + let pyIdent := (← read).currentModule.mkIdent val return .ident loc pyIdent | _ => specError loc s!"Expected type instead of {repr v}." @@ -710,8 +645,7 @@ def pySpecArg (usedNames : Std.HashSet String) | some cl => if type.isSome then specError loc s!"Unexpected argument to {name}" - let mod := toString (← read).currentModule - pure <| .ident loc { pythonModule := mod, name := cl } #[] + pure <| .ident loc ((← read).currentModule.mkIdent cl) #[] assert! comment.isNone let argDefault ← match de with @@ -735,8 +669,8 @@ structure SpecAssertionContext where structure SpecAssertionState where assertions : Array Assertion := #[] postconditions : Array SpecExpr := #[] - errors : Array SpecError := #[] - warnings : Array SpecError := #[] + errors : Array PipelineMessage := #[] + warnings : Array PipelineMessage := #[] /-- Monad for extracting pre and post conditions from methods. -/ abbrev SpecAssertionM := ReaderT SpecAssertionContext (StateM SpecAssertionState) @@ -744,11 +678,11 @@ abbrev SpecAssertionM := ReaderT SpecAssertionContext (StateM SpecAssertionState instance : PySpecMClass SpecAssertionM where specError loc message := do let file := (←read) |>.filePath - let e : SpecError := { file, loc, kind := .pySpecParsingError, message } + let e : PipelineMessage := { file, loc, phase := pySpecParsingPhase, kind := .pySpecParsingError, message } modify fun s => { s with errors := s.errors.push e } specWarning loc message := do let file := (←read) |>.filePath - let w : SpecError := { file, loc, kind := .pySpecParsingWarning, message } + let w : PipelineMessage := { file, loc, phase := pySpecParsingPhase, kind := .pySpecParsingWarning, message } modify fun s => { s with warnings := s.warnings.push w } runChecked act := do let cnt := (←get).errors.size @@ -1298,8 +1232,7 @@ partial def pySpecClassBody (loc : SourceRange) (className : String) match value with | .Call _ (.Attribute _ (.Name _ ⟨_, "self"⟩ (.Load _)) ⟨_, innerClsName⟩ (.Load _)) _ _ => - let mod := toString (← read).currentModule - let pyIdent : PythonIdent := { pythonModule := mod, name := innerClsName } + let pyIdent := (← read).currentModule.mkIdent innerClsName let f : ClassField := { name := fieldName, type := .ident loc pyIdent #[], @@ -1332,10 +1265,10 @@ partial def pySpecClassBody (loc : SourceRange) (className : String) methods := methods } -def translateImportFrom (mod : String) (types : Std.HashMap String SpecValue) +def translateImportFrom (mod : ModuleName) (types : Std.HashMap String SpecValue) (names : Array (alias SourceRange)) : PySpecM Unit := do -- Check if module is a builtin (in prelude) - if so, don't generate extern declarations - let isBuiltinModule := preludeSig.rank.contains mod + let isBuiltinModule := mod ∈ preludeSig.rank for a in names do let name := a.name match types[name]? with @@ -1347,11 +1280,7 @@ def translateImportFrom (mod : String) (types : Std.HashMap String SpecValue) -- Generate extern declaration for imported types (but not for builtin modules) if !isBuiltinModule then if let .typeValue _ := tpv then - let source : PythonIdent := { - pythonModule := mod - name := name - } - pushSignature (.externTypeDecl asname source) + pushSignature (.externTypeDecl asname (mod.mkIdent name)) def getModifiedTime (f : System.FilePath) : IO IO.FS.SystemTime := do let md ← f.metadata @@ -1361,21 +1290,12 @@ def getModifiedTime (f : System.FilePath) : IO IO.FS.SystemTime := do Create a value map for module from signatures. -/ def signatureValueMap (mod : ModuleName) (sigs : Array Signature) : Std.HashMap String SpecValue := - let modName := toString mod let addType (m : Std.HashMap String SpecValue) (sig : Signature) := match sig with | .classDef d => - let pyIdent : PythonIdent := { - pythonModule := modName - name := d.name - } - m.insert d.name (.typeValue (.ident d.loc pyIdent)) + m.insert d.name (.typeValue (.ident d.loc (mod.mkIdent d.name))) | .typeDef d => - let pyIdent : PythonIdent := { - pythonModule := modName - name := d.name - } - m.insert d.name (.typeValue (.ident d.loc pyIdent)) + m.insert d.name (.typeValue (.ident d.loc (mod.mkIdent d.name))) | .externTypeDecl name source => m.insert name (.typeValue (.ident default source)) | .functionDecl .. => m @@ -1399,21 +1319,22 @@ public def isNewer (path : System.FilePath) (existing : IO.FS.Metadata) : BaseIO module prefix and prepends the remainder. E.g. `from ..X import Y` (level 2) in package `a.b.c` resolves to `a.b.X`. -/ def resolveRelativeModuleName (loc : SourceRange) (relName : String) (level : Int) - : PySpecM String := do - if level == 0 then return relName - let pfx := (←read).currentModulePrefix - if pfx.isEmpty then - specError loc - "Cannot use a relative import from a top-level module with no package" - return relName + : PySpecM ModuleName := do + let some m := ModuleName.ofString? relName + | specError loc s!"Invalid module name {relName}" + return default + if level == 0 then + return m + let some pfx := (←read).currentModulePrefix + | specError loc "Cannot use a relative import from a top-level module with no package" + return default let drop := level.toNat - 1 - if drop >= pfx.size then - specError loc <| - s!"Relative import (level {level}) goes beyond the top-level package; " ++ - s!"the current module is only {pfx.size} package level(s) deep" - return relName - let base := pfx.toSubarray (stop := pfx.size - drop) |>.toArray - return ".".intercalate (base.push relName).toList + let some base := pfx.parent (n := drop) + | specError loc <| + s!"Relative import (level {level}) goes beyond the top-level package; " ++ + s!"the current module is only {pfx.components.size} package level(s) deep" + return default + return base ++ m mutual @@ -1424,10 +1345,10 @@ Python source if not in cache. -/ partial def resolveModule (loc : SourceRange) (mod : ModuleName) : PySpecM (Std.HashMap String SpecValue) := do - let (pythonFile, childPrefix) ← + let (pythonFile, modPath) ← match ← (←read).readModule mod |>.toBaseIO with - | .ok r => - pure r + | .ok (path, isInit) => + pure (path, ModuleName.ModuleOfPath.mk mod isInit) | .error msg => specError loc msg return default @@ -1466,7 +1387,8 @@ partial def resolveModule (loc : SourceRange) (mod : ModuleName) : let ctx := { (←read) with pythonFile := pythonFile currentModule := mod - currentModulePrefix := childPrefix } + currentModulePrefix := modPath |>.modulePrefix? + } -- This does state shuffling to ensure warnings and errors maintain -- a reference count of 1 (for destructive updates). let s := ←get @@ -1490,37 +1412,26 @@ partial def resolveModule (loc : SourceRange) (mod : ModuleName) : return signatureValueMap mod sigs -partial def resolveModuleCached (loc : SourceRange) (mod : ModuleName) +partial def parseAndResolveModule (loc : SourceRange) (mod : ModuleName) : PySpecM (Option (Std.HashMap String SpecValue)) := do - let key := toString mod - match (←get).typeSigs.rank[key]? with + match (←get).typeSigs.rank[mod]? with | some types => return types | none => let (success, r) ← runChecked <| resolveModule loc mod let r := if success then some r else none - modify fun s => { s with typeSigs := s.typeSigs.insert key r } + modify fun s => { s with typeSigs := s.typeSigs.insert mod r } return r -/-- Parse a module name string and resolve it, returning `none` on - parse or resolution failure. -/ -partial def parseAndResolveModule (loc : SourceRange) (modName : String) - : PySpecM (Option (Std.HashMap String SpecValue)) := do - match ModuleName.ofString modName with - | .ok mod => resolveModuleCached loc mod - | .error msg => - specError loc msg - return none - /-- Resolve a module and register its exports under `"{asname}.{name}"`. If resolution fails, register `asname` as an opaque extern type. -/ partial def resolveAndRegisterModule (loc : SourceRange) - (mod asname : String) : PySpecM Unit := do + (mod : ModuleName) (asname : String) : PySpecM Unit := do if let some types ← parseAndResolveModule loc mod then for (name, tpv) in types do setNameValue s!"{asname}.{name}" tpv else - let source : PythonIdent := { pythonModule := mod, name := asname } + let source := mod.mkIdent asname let tpv : SpecValue := .typeValue (.ident loc source) setNameValue asname tpv pushSignature (.externTypeDecl asname source) @@ -1530,9 +1441,10 @@ partial def resolveAndRegisterModule (loc : SourceRange) partial def translateImport (loc : SourceRange) (names : Array (alias SourceRange)) : PySpecM Unit := do for a in names do - let mod := a.name - let asname := a.asname.getD mod - resolveAndRegisterModule loc mod asname + let asname := a.asname.getD a.name + match ModuleName.ofString? a.name with + | .some mod => resolveAndRegisterModule loc mod asname + | none => specError loc s!"Invalid module name {a.name}" /-- Handle a `from [..] module import name` statement. Supports absolute imports (level 0) and multi-level relative imports (level ≥ 1). @@ -1560,7 +1472,7 @@ partial def translateImportFromStmt (loc : SourceRange) for a in names do let name := a.name let asname := a.asname.getD name - let source : PythonIdent := { pythonModule := mod, name := name } + let source := mod.mkIdent name let tpv : SpecValue := .typeValue (.ident loc source) setNameValue asname tpv pushSignature (.externTypeDecl asname source) @@ -1640,8 +1552,7 @@ partial def translate (body : Array (stmt Strata.SourceRange)) : PySpecM Unit := let baseIdents ← resolveBaseClasses bases let (success, _) ← runChecked <| recordTypeDef loc className -- Add the class to nameMap so it can be used in forward references - let mod := toString (← read).currentModule - setNameValue className (.typeValue (.ident loc { pythonModule := mod, name := className } #[])) + setNameValue className (.typeValue (.ident loc ((← read).currentModule.mkIdent className) #[])) let d ← pySpecClassBody loc className baseIdents stmts let d := { d with exhaustive := isExhaustive } if success then @@ -1664,8 +1575,6 @@ partial def translateModuleAux (body : Array (Strata.Python.stmt Strata.SourceRa end - - /-- Translates Python AST statements to PySpec signatures with dependency resolution. -/ def translateModule (dialectFile searchPath strataDir pythonFile : System.FilePath) @@ -1675,8 +1584,8 @@ def translateModule (pythonCmd : String := "python") (events : Std.HashSet EventType := {}) (skipNames : Std.HashSet PythonIdent := {}) - (currentModulePrefix : Array String := #[]) : - BaseIO (FileMaps × Array Signature × Array SpecError × Array SpecError) := do + (currentModulePrefix : Option ModuleName := none) : + BaseIO (FileMaps × Array Signature × Array PipelineMessage × Array PipelineMessage) := do let fmm : FileMaps := {} let fmm := fmm.insert pythonFile fileMap let fileMapsRef : IO.Ref FileMaps ← IO.mkRef fmm @@ -1699,23 +1608,18 @@ def translateModule /-- Translates a Python source file to PySpec signatures. Main entry point for translation. -/ public def translateFile (dialectFile strataDir pythonFile searchPath : System.FilePath) + (moduleName : ModuleName) (pythonCmd : String := "python") (events : Std.HashSet EventType := {}) (skipNames : Std.HashSet PythonIdent := {}) - (moduleName : Option ModuleName := none) : EIO String (Array Signature × Array String) := do - let currentModule ← match moduleName with - | some m => pure m - | none => - let (mod, _) ← match ModuleName.ofFile pythonFile with - | .ok r => pure r - | .error e => throw e - pure mod - let mod := currentModule + let mod := moduleName -- Compute the package prefix for relative import resolution. - let modulePrefix := - if pythonFile.fileName == some "__init__.py" then mod.components - else mod.components.toSubarray (stop := mod.components.size - 1) |>.toArray + let modulePrefix : Option ModuleName := + if pythonFile.fileName == some "__init__.py" then + some moduleName + else + moduleName.parent let contents ← match ← IO.FS.readFile pythonFile |>.toBaseIO with | .ok b => pure b @@ -1743,8 +1647,8 @@ public def translateFile (pythonFile := pythonFile) (.ofString contents) body - currentModule - let ppErr (e : SpecError) : EIO String String := + mod + let ppErr (e : PipelineMessage) : EIO String String := match fmm[e.file]? with | none => throw s!"No location information for {e.file}" diff --git a/Strata/Languages/Python/Specs/DDM.lean b/Strata/Languages/Python/Specs/DDM.lean index 73c1ff245b..318bfcbad7 100644 --- a/Strata/Languages/Python/Specs/DDM.lean +++ b/Strata/Languages/Python/Specs/DDM.lean @@ -7,20 +7,15 @@ module public import Strata.DDM.Integration.Lean public import Strata.Languages.Python.Specs.Decls - -import Strata.DDM.AST -import Strata.DDM.Util.ByteArray -import Strata.DDM.Format import Strata.DDM.BuiltinDialects.Init public import Strata.DDM.Integration.Lean.OfAstM +import Strata.DDM.Format import Strata.DDM.Ion -public section - namespace Strata.Python /-- Converts a Python identifier to an annotated string for DDM serialization. -/ -private def PythonIdent.toDDM (d : PythonIdent) : Ann String SourceRange := +def PythonIdent.toDDM (d : PythonIdent) : Ann String SourceRange := ⟨.none, toString d⟩ namespace Specs @@ -199,7 +194,7 @@ def DDM.Int.ofDDM {α} : DDM.Int α → _root_.Int mutual -private def SpecIdent.toDDM (si : SpecIdent) (loc : SourceRange) : DDM.SpecType SourceRange := +def SpecIdent.toDDM (si : SpecIdent) (loc : SourceRange) : DDM.SpecType SourceRange := if si.args.isEmpty then .typeIdentNoArgs loc si.name.toDDM else @@ -207,7 +202,7 @@ private def SpecIdent.toDDM (si : SpecIdent) (loc : SourceRange) : DDM.SpecType termination_by sizeOf si decreasing_by cases si; decreasing_tactic -private def SpecTypedDict.toDDM (td : SpecTypedDict) (loc : SourceRange) : DDM.SpecType SourceRange := +def SpecTypedDict.toDDM (td : SpecTypedDict) (loc : SourceRange) : DDM.SpecType SourceRange := assert! td.fields.size = td.fieldTypes.size let argc := td.fieldTypes.size let a := Array.ofFn fun (⟨i, ilt⟩ : Fin argc) => @@ -216,7 +211,7 @@ private def SpecTypedDict.toDDM (td : SpecTypedDict) (loc : SourceRange) : DDM.S termination_by sizeOf td decreasing_by cases td; decreasing_tactic -private def SpecType.toDDM (d : SpecType) : DDM.SpecType SourceRange := +def SpecType.toDDM (d : SpecType) : DDM.SpecType SourceRange := let parts : Array (DDM.SpecType SourceRange) := let r := d.idents.attach.map fun ⟨si, _⟩ => si.toDDM d.loc let ints := d.intLits.toArray.qsort (· < ·) @@ -241,7 +236,7 @@ decreasing_by end -private def SpecAtomType.toDDM (d : SpecAtomType) +def SpecAtomType.toDDM (d : SpecAtomType) (loc : SourceRange := .none) : DDM.SpecType SourceRange := match d with | .ident nm args => @@ -259,10 +254,10 @@ private def SpecAtomType.toDDM (d : SpecAtomType) .typeTypedDict loc ⟨.none, a⟩ -private def SpecDefault.toDDM : Specs.SpecDefault → DDM.SpecDefault SourceRange +def SpecDefault.toDDM : Specs.SpecDefault → DDM.SpecDefault SourceRange | .none => .noneDefault .none -private def Arg.toDDM (d : Arg) : DDM.ArgDecl SourceRange := +def Arg.toDDM (d : Arg) : DDM.ArgDecl SourceRange := .mkArgDecl .none ⟨.none, d.name⟩ d.type.toDDM ⟨.none, d.default.map (·.toDDM)⟩ protected def SpecExpr.toDDM (e : SpecExpr) : DDM.SpecExprDecl SourceRange := @@ -299,18 +294,25 @@ def specExprFormatContext : FormatContext := def specExprFormatState : FormatState where openDialects := DDM.PythonSpecs_map.toList.foldl (init := {}) fun s d => s.insert d.name -instance : ToString SpecExpr where - toString e := (mformat (SpecExpr.toDDM e).toAst specExprFormatContext specExprFormatState).format.pretty +namespace SpecExpr + +public def toString (e : SpecExpr) : String := + (mformat (SpecExpr.toDDM e).toAst specExprFormatContext specExprFormatState).format.pretty + +public instance : ToString SpecExpr where + toString := SpecExpr.toString + +end SpecExpr -private def MessagePart.toDDM (p : MessagePart) : DDM.MessagePart SourceRange := +def MessagePart.toDDM (p : MessagePart) : DDM.MessagePart SourceRange := match p with | .str s => .strMessagePart .none ⟨.none, s⟩ | .expr e => .exprMessagePart .none e.toDDM -private def Assertion.toDDM (a : Assertion) : DDM.Assertion SourceRange := +def Assertion.toDDM (a : Assertion) : DDM.Assertion SourceRange := .mkAssertion .none a.formula.toDDM ⟨.none, a.message.map (·.toDDM)⟩ -private def FunctionDecl.toDDM (d : FunctionDecl) : DDM.FunDecl SourceRange := +def FunctionDecl.toDDM (d : FunctionDecl) : DDM.FunDecl SourceRange := .mkFunDecl d.loc (name := .mk d.nameLoc d.name) @@ -327,10 +329,10 @@ private def FunctionDecl.toDDM (d : FunctionDecl) : DDM.FunDecl SourceRange := d.postconditions.map fun e => .mkPostconditionEntry .none e.toDDM⟩) -private def ClassVariable.toDDM (cv : ClassVariable) : DDM.ClassVarDecl SourceRange := +def ClassVariable.toDDM (cv : ClassVariable) : DDM.ClassVarDecl SourceRange := .mkClassVarDecl .none ⟨.none, cv.name⟩ ⟨.none, cv.value⟩ -private partial def ClassDef.toDDMDecl (d : ClassDef) : DDM.ClassDecl SourceRange := +partial def ClassDef.toDDMDecl (d : ClassDef) : DDM.ClassDecl SourceRange := .mkClassDecl d.loc (.mk .none d.name) ⟨.none, d.bases.map (·.toDDM)⟩ ⟨.none, d.fields.map fun f => @@ -341,7 +343,7 @@ private partial def ClassDef.toDDMDecl (d : ClassDef) : DDM.ClassDecl SourceRang ⟨.none, d.methods.map (·.toDDM)⟩ ⟨.none, d.exhaustive⟩ -private def Signature.toDDM (sig : Signature) : DDM.Signature SourceRange := +def Signature.toDDM (sig : Signature) : DDM.Signature SourceRange := match sig with | .externTypeDecl name source => .externTypeDecl .none ⟨.none, name⟩ source.toDDM @@ -352,37 +354,46 @@ private def Signature.toDDM (sig : Signature) : DDM.Signature SourceRange := | .typeDef d => .typeDef d.loc (.mk d.nameLoc d.name) d.definition.toDDM -private def DDM.SpecType.fromDDM (d : DDM.SpecType SourceRange) : Specs.SpecType := +abbrev FromDDM := Except (SourceRange × String) + +def FromDDM.throw {α} (loc : SourceRange) (msg : String) : FromDDM α := + .error (loc, msg) + +def DDM.SpecType.fromDDM (d : DDM.SpecType SourceRange) : FromDDM Specs.SpecType := match d with | .typeClassNoArgs loc ⟨_, cl⟩ => - .ident loc { pythonModule := "", name := cl } #[] - | .typeClass loc ⟨_, cl⟩ ⟨_, args⟩ => - let a := args.map (·.fromDDM) - .ident loc { pythonModule := "", name := cl } a + match PythonIdent.ofString cl with + | none => .throw loc s!"Unsupported identifier {cl} in typeClass" + | some nm => .ok <| .ident loc nm #[] + | .typeClass loc ⟨_, cl⟩ ⟨_, args⟩ => do + let nm ← match PythonIdent.ofString cl with + | none => .throw loc s!"Unsupported identifier {cl} in typeClass" + | some nm => pure nm + let a ← args.mapM (·.fromDDM) + pure <| .ident loc nm a | .typeIdentNoArgs loc ⟨_, ident⟩ => - if let some pyIdent := PythonIdent.ofString ident then - .ident loc pyIdent #[] - else - panic! "Bad identifier" - | .typeIdent loc ⟨_, ident⟩ ⟨_, args⟩ => - let a := args.map (·.fromDDM) - if let some pyIdent := PythonIdent.ofString ident then - .ident loc pyIdent a - else - panic! "Bad identifier" - | .typeIntLiteral loc i => .intLiteral loc i.ofDDM - | .typeStringLiteral loc ⟨_, s⟩ => .stringLiteral loc s - | .typeTypedDict loc ⟨_, fields⟩ => + match PythonIdent.ofString ident with + | some pyIdent => .ok <| .ident loc pyIdent #[] + | none => .throw loc s!"Bad identifier: {ident}" + | .typeIdent loc ⟨_, ident⟩ ⟨_, args⟩ => do + let a ← args.mapM (·.fromDDM) + match PythonIdent.ofString ident with + | some pyIdent => pure <| .ident loc pyIdent a + | none => .throw loc s!"Bad identifier: {ident}" + | .typeIntLiteral loc i => .ok <| .intLiteral loc i.ofDDM + | .typeStringLiteral loc ⟨_, s⟩ => .ok <| .stringLiteral loc s + | .typeTypedDict loc ⟨_, fields⟩ => do let names := fields.map fun (.mkDictFieldDecl _ ⟨_, name⟩ _ _) => name - let types := fields.attach.map fun ⟨.mkDictFieldDecl _ _ tp _, mem⟩ => tp.fromDDM + let types ← fields.attach.mapM fun ⟨.mkDictFieldDecl _ _ tp _, mem⟩ => tp.fromDDM let required := fields.map fun (.mkDictFieldDecl _ _ _ ⟨_, r⟩) => r - .typedDict loc names types required - | .typeUnion loc ⟨_, args⟩ => + pure <| .typedDict loc names types required + | .typeUnion loc ⟨_, args⟩ => do if p : args.size > 0 then - args.attach.foldl (init := args[0].fromDDM) (start := 1) - fun a ⟨b, mem⟩ => SpecType.union loc a b.fromDDM + let init ← args[0].fromDDM + args.attach.foldlM (init := init) (start := 1) fun a ⟨b, mem⟩ => do + return .union loc a (← b.fromDDM) else - panic! "Expected non-empty union" + .throw loc "Expected non-empty union" termination_by sizeOf d decreasing_by · decreasing_tactic @@ -394,18 +405,18 @@ decreasing_by · decreasing_tactic · decreasing_tactic -private def DDM.SpecDefault.fromDDM : DDM.SpecDefault SourceRange → Specs.SpecDefault +def DDM.SpecDefault.fromDDM : DDM.SpecDefault SourceRange → Specs.SpecDefault | .noneDefault _ => .none -private def DDM.ArgDecl.fromDDM (d : DDM.ArgDecl SourceRange) : Specs.Arg := +def DDM.ArgDecl.fromDDM (d : DDM.ArgDecl SourceRange) : FromDDM Specs.Arg := do let .mkArgDecl _ ⟨_, name⟩ type ⟨_, default⟩ := d - { + pure { name := name - type := type.fromDDM + type := ← type.fromDDM default := default.map (·.fromDDM) } -private def DDM.SpecExprDecl.fromDDM (d : DDM.SpecExprDecl SourceRange) : Specs.SpecExpr := +def DDM.SpecExprDecl.fromDDM (d : DDM.SpecExprDecl SourceRange) : Specs.SpecExpr := match d with | .placeholderExpr loc => .placeholder loc | .varExpr loc ⟨_, name⟩ => .var name loc @@ -429,79 +440,81 @@ private def DDM.SpecExprDecl.fromDDM (d : DDM.SpecExprDecl SourceRange) : Specs. | .forallDictExpr loc dict ⟨_, keyVar⟩ ⟨_, valVar⟩ body => .forallDict dict.fromDDM keyVar valVar body.fromDDM loc -private def DDM.MessagePart.fromDDM (d : DDM.MessagePart SourceRange) : Specs.MessagePart := +def DDM.MessagePart.fromDDM (d : DDM.MessagePart SourceRange) : Specs.MessagePart := match d with | .strMessagePart _ ⟨_, s⟩ => .str s | .exprMessagePart _ e => .expr e.fromDDM -private def DDM.Assertion.fromDDM (d : DDM.Assertion SourceRange) : Specs.Assertion := +def DDM.Assertion.fromDDM (d : DDM.Assertion SourceRange) : Specs.Assertion := let .mkAssertion _ formula ⟨_, message⟩ := d { message := message.map (·.fromDDM), formula := formula.fromDDM } -private def DDM.FunDecl.fromDDM (d : DDM.FunDecl SourceRange) : Specs.FunctionDecl := +def DDM.FunDecl.fromDDM (d : DDM.FunDecl SourceRange) : FromDDM Specs.FunctionDecl := do let .mkFunDecl loc ⟨nameLoc, name⟩ ⟨_, args⟩ ⟨_, kwonly⟩ ⟨_, kwargs⟩ returnType ⟨_, isOverload⟩ ⟨_, preconditions⟩ ⟨_, postconditions⟩ := d - let kwargsOpt : Option (String × Specs.SpecType) := + let kwargsOpt : Option (String × Specs.SpecType) ← match kwargs with - | some (.mkKwargsDecl _ ⟨_, kn⟩ tp) => some (kn, tp.fromDDM) - | none => none - { + | some (.mkKwargsDecl _ ⟨_, kn⟩ tp) => + pure <| some (kn, ← tp.fromDDM) + | none => + pure none + pure { loc := loc nameLoc := nameLoc name := name args := { - args := args.map (·.fromDDM) - kwonly := kwonly.map (·.fromDDM) + args := ← args.mapM (·.fromDDM) + kwonly := ← kwonly.mapM (·.fromDDM) kwargs := kwargsOpt } - returnType := returnType.fromDDM + returnType := ← returnType.fromDDM isOverload := isOverload preconditions := preconditions.map (·.fromDDM) postconditions := postconditions.map fun | .mkPostconditionEntry _ e => e.fromDDM } -private def DDM.ClassDecl.fromDDM (d : DDM.ClassDecl SourceRange) : Specs.ClassDef := +def DDM.ClassDecl.fromDDM (d : DDM.ClassDecl SourceRange) : FromDDM Specs.ClassDef := do let .mkClassDecl ann ⟨_, name⟩ ⟨_, bases⟩ ⟨_, fields⟩ ⟨_, classVars⟩ ⟨_, subclasses⟩ ⟨_, methods⟩ ⟨_, exhaustive⟩ := d - { + pure { loc := ann name := name - bases := bases.map fun ⟨_, s⟩ => + bases := ← bases.mapM fun ⟨_, s⟩ => match PythonIdent.ofString s with - | some id => id - | none => panic! s!"Bad base class identifier: '{s}'" - fields := fields.map fun (.mkClassFieldDecl _ ⟨_, n⟩ tp ⟨_, cv⟩) => - { name := n, type := tp.fromDDM, constValue := cv.map (·.2) : ClassField } + | some id => pure id + | none => .throw ann s!"Bad base class identifier: '{s}'" + fields := ← fields.mapM fun (.mkClassFieldDecl _ ⟨_, n⟩ tp ⟨_, cv⟩) => do + pure { name := n, type := ← tp.fromDDM, constValue := cv.map (·.2) : ClassField } classVars := classVars.map fun (.mkClassVarDecl _ ⟨_, n⟩ ⟨_, v⟩) => { name := n, value := v : ClassVariable } - subclasses := subclasses.map (·.fromDDM) - methods := methods.map (·.fromDDM) + subclasses := ← subclasses.mapM (·.fromDDM) + methods := ← methods.mapM (·.fromDDM) exhaustive := exhaustive } -private def DDM.Command.fromDDM (cmd : DDM.Command SourceRange) : Specs.Signature := +def DDM.Command.fromDDM (cmd : DDM.Command SourceRange) : FromDDM Specs.Signature := match cmd with - | .externTypeDecl _ ⟨_, name⟩ ⟨_, ddmDefinition⟩ => - if let some definition := PythonIdent.ofString ddmDefinition then - .externTypeDecl name definition - else - panic! "Extern type decl definition has bad format." - | .classDef _ decl => - .classDef decl.fromDDM - | .functionDecl _ d => .functionDecl d.fromDDM - | .typeDef loc ⟨nameLoc, name⟩ definition => + | .externTypeDecl loc ⟨_, name⟩ ⟨_, ddmDefinition⟩ => + match PythonIdent.ofString ddmDefinition with + | some definition => .ok <| .externTypeDecl name definition + | none => .throw loc s!"Extern type decl definition has bad format: {ddmDefinition}" + | .classDef _ decl => do + pure <| .classDef (← decl.fromDDM) + | .functionDecl _ d => do + pure <| .functionDecl (← d.fromDDM) + | .typeDef loc ⟨nameLoc, name⟩ definition => do let d : TypeDef := { loc := loc nameLoc := nameLoc name := name - definition := definition.fromDDM + definition := ← definition.fromDDM } - .typeDef d + pure <| .typeDef d /-- Reads Python spec signatures from a DDM Ion file. -/ -def readDDM (path : System.FilePath) : EIO String (Array Signature) := do +public def readDDM (path : System.FilePath) : EIO String (Array Signature) := do let contents ← match ← IO.FS.readBinFile path |>.toBaseIO with | .ok r => pure r @@ -511,23 +524,21 @@ def readDDM (path : System.FilePath) : EIO String (Array Signature) := do let r := pgm.commands.mapM fun cmd => do let pySig ← DDM.Command.ofAst cmd - return pySig.fromDDM + match pySig.fromDDM with + | .ok sig => pure sig + | .error (_, msg) => .error msg match r with | .ok r => pure r | .error msg => throw msg | .error msg => throw msg /-- Converts Python spec signatures to a DDM program for serialization. -/ -def toDDMProgram (sigs : Array Signature) : Strata.Program := { - dialects := DDM.PythonSpecs_map - dialect := DDM.PythonSpecs.name - commands := sigs.map fun s => s.toDDM.toAst - } +def toDDMProgram (sigs : Array Signature) : Strata.Program := + .create DDM.PythonSpecs_map DDM.PythonSpecs.name (sigs.map fun s => s.toDDM.toAst) /-- Writes Python spec signatures to a DDM Ion file. -/ -def writeDDM (path : System.FilePath) (sigs : Array Signature) : IO Unit := do +public def writeDDM (path : System.FilePath) (sigs : Array Signature) : IO Unit := do let pgm := toDDMProgram sigs IO.FS.writeBinFile path <| pgm.toIon end Strata.Python.Specs -end diff --git a/Strata/Languages/Python/Specs/Decls.lean b/Strata/Languages/Python/Specs/Decls.lean index 21e2bc03c8..c1f2212b0c 100644 --- a/Strata/Languages/Python/Specs/Decls.lean +++ b/Strata/Languages/Python/Specs/Decls.lean @@ -6,39 +6,39 @@ module public import Std.Data.HashMap.Basic public import Strata.DDM.Util.SourceRange -public import Strata.Languages.Python.OverloadTable +public import Strata.Languages.Python.PythonIdent public section namespace Strata.Python namespace PythonIdent -def builtinsBool := mk "builtins" "bool" -def builtinsBytearray := mk "builtins" "bytearray" -def builtinsBytes := mk "builtins" "bytes" -def builtinsComplex := mk "builtins" "complex" -def builtinsDict := mk "builtins" "dict" -def builtinsException := mk "builtins" "Exception" -def builtinsFloat := mk "builtins" "float" -def builtinsInt := mk "builtins" "int" -def builtinsStr := mk "builtins" "str" -def noneType := mk "_types" "NoneType" - -def typingAny := mk "typing" "Any" -def typingBinaryIO := mk "typing" "BinaryIO" -def typingDict := mk "typing" "Dict" -def typingGenerator := mk "typing" "Generator" -def typingList := mk "typing" "List" -def typingLiteral := mk "typing" "Literal" -def typingMapping := mk "typing" "Mapping" -def typingOverload := mk "typing" "overload" -def typingSequence := mk "typing" "Sequence" -def typingTypedDict := mk "typing" "TypedDict" -def typingUnion := mk "typing" "Union" -def typingRequired := mk "typing" "Required" -def typingNotRequired := mk "typing" "NotRequired" -def typingUnpack := mk "typing" "Unpack" -def reCompile := mk "re" "compile" +def builtinsBool := ofComponent "builtins" "bool" +def builtinsBytearray := ofComponent "builtins" "bytearray" +def builtinsBytes := ofComponent "builtins" "bytes" +def builtinsComplex := ofComponent "builtins" "complex" +def builtinsDict := ofComponent "builtins" "dict" +def builtinsException := ofComponent "builtins" "Exception" +def builtinsFloat := ofComponent "builtins" "float" +def builtinsInt := ofComponent "builtins" "int" +def builtinsStr := ofComponent "builtins" "str" +def noneType := ofComponent "_types" "NoneType" + +def typingAny := ofComponent "typing" "Any" +def typingBinaryIO := ofComponent "typing" "BinaryIO" +def typingDict := ofComponent "typing" "Dict" +def typingGenerator := ofComponent "typing" "Generator" +def typingList := ofComponent "typing" "List" +def typingLiteral := ofComponent "typing" "Literal" +def typingMapping := ofComponent "typing" "Mapping" +def typingOverload := ofComponent "typing" "overload" +def typingSequence := ofComponent "typing" "Sequence" +def typingTypedDict := ofComponent "typing" "TypedDict" +def typingUnion := ofComponent "typing" "Union" +def typingRequired := ofComponent "typing" "Required" +def typingNotRequired := ofComponent "typing" "NotRequired" +def typingUnpack := ofComponent "typing" "Unpack" +def reCompile := ofComponent "re" "compile" end PythonIdent diff --git a/Strata/Languages/Python/Specs/Error.lean b/Strata/Languages/Python/Specs/Error.lean deleted file mode 100644 index 4894a24b0a..0000000000 --- a/Strata/Languages/Python/Specs/Error.lean +++ /dev/null @@ -1,75 +0,0 @@ -/- - Copyright Strata Contributors - - SPDX-License-Identifier: Apache-2.0 OR MIT --/ -module - -public import Strata.DDM.Util.SourceRange - -public section -namespace Strata.Python.Specs - -/-- A warning category for PySpec translation. - Uses an open vocabulary (string fields) so new categories can be added - without modifying an inductive type. -/ -structure WarningKind where - phase : String - category : String - deriving BEq, DecidableEq, Hashable, Ord, Repr - -instance : LT WarningKind where - lt a b := a.phase < b.phase ∨ (a.phase == b.phase ∧ a.category < b.category) - -instance (a b : WarningKind) : Decidable (a < b) := - inferInstanceAs (Decidable (a.phase < b.phase ∨ (a.phase == b.phase ∧ a.category < b.category))) - -namespace WarningKind - --- Type translation warnings -def unsupportedUnion : WarningKind := { phase := "pySpecToLaurel", category := "unsupportedUnion" } - --- Unsupported Optional patterns -def unsupportedOptionalFloat : WarningKind := { phase := "pySpecToLaurel", category := "unsupportedOptionalFloat" } -def unsupportedOptionalList : WarningKind := { phase := "pySpecToLaurel", category := "unsupportedOptionalList" } -def unsupportedOptionalDict : WarningKind := { phase := "pySpecToLaurel", category := "unsupportedOptionalDict" } -def unsupportedOptionalAny : WarningKind := { phase := "pySpecToLaurel", category := "unsupportedOptionalAny" } -def unsupportedOptionalBytes : WarningKind := { phase := "pySpecToLaurel", category := "unsupportedOptionalBytes" } - --- Internal type errors -def typeError : WarningKind := { phase := "pySpecToLaurel", category := "typeError" } - --- Precondition warnings -def placeholderExpr : WarningKind := { phase := "pySpecToLaurel", category := "placeholderExpr" } -def floatLiteral : WarningKind := { phase := "pySpecToLaurel", category := "floatLiteral" } -def isinstanceUnsupported : WarningKind := { phase := "pySpecToLaurel", category := "isinstanceUnsupported" } -def forallListUnsupported : WarningKind := { phase := "pySpecToLaurel", category := "forallListUnsupported" } -def forallDictUnsupported : WarningKind := { phase := "pySpecToLaurel", category := "forallDictUnsupported" } - --- Declaration warnings -def missingMethodSelf : WarningKind := { phase := "pySpecToLaurel", category := "missingMethodSelf" } -def kwargsExpansionError : WarningKind := { phase := "pySpecToLaurel", category := "kwargsExpansionError" } -def postconditionUnsupported : WarningKind := { phase := "pySpecToLaurel", category := "postconditionUnsupported" } - --- Overload dispatch warnings -def overloadNoArgs : WarningKind := { phase := "pySpecToLaurel", category := "overloadNoArgs" } -def overloadArgArity : WarningKind := { phase := "pySpecToLaurel", category := "overloadArgArity" } -def overloadArgNotStringLiteral : WarningKind := { phase := "pySpecToLaurel", category := "overloadArgNotStringLiteral" } -def overloadReturnArity : WarningKind := { phase := "pySpecToLaurel", category := "overloadReturnArity" } -def overloadReturnNotClass : WarningKind := { phase := "pySpecToLaurel", category := "overloadReturnNotClass" } - --- PySpec parsing phase (generic — callers don't yet distinguish categories) -def pySpecParsingError : WarningKind := { phase := "pySpecParsing", category := "error" } -def pySpecParsingWarning : WarningKind := { phase := "pySpecParsing", category := "warning" } - -end WarningKind - -/-- An error encountered while processing a PySpec file. -/ -structure SpecError where - file : System.FilePath - loc : SourceRange - kind : WarningKind - message : String - -end Strata.Python.Specs -end diff --git a/Strata/Languages/Python/Specs/IdentifyOverloads.lean b/Strata/Languages/Python/Specs/IdentifyOverloads.lean index ad441f595f..212acc5109 100644 --- a/Strata/Languages/Python/Specs/IdentifyOverloads.lean +++ b/Strata/Languages/Python/Specs/IdentifyOverloads.lean @@ -29,7 +29,7 @@ open Strata.Python (stmt expr keyword FunctionOverloads OverloadTable PythonIden /-- State accumulated while walking the AST. -/ public structure ResolveState where - modules : Std.HashSet String := {} + modules : Std.HashSet ModuleName := {} warnings : Array String := #[] /-- Monad for the overload-resolution walker. @@ -43,7 +43,7 @@ def warn (msg : String) : ResolveM Unit := { s with warnings := s.warnings.push msg } /-- Record a module name from a resolved overload. -/ -def recordModule (mod : String) : ResolveM Unit := +def recordModule (mod : ModuleName) : ResolveM Unit := modify fun s => { s with modules := s.modules.insert mod } diff --git a/Strata/Languages/Python/Specs/MessageKind.lean b/Strata/Languages/Python/Specs/MessageKind.lean new file mode 100644 index 0000000000..18ac867ddc --- /dev/null +++ b/Strata/Languages/Python/Specs/MessageKind.lean @@ -0,0 +1,73 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ +module +public import Strata.Pipeline.Messages + +public section +namespace Strata.Pipeline.MessageKind + +-- PySpec parsing phase +def pySpecReadError : MessageKind := + { category := "readError", impact := .configurationError } +def pySpecParsingError : MessageKind := + { category := "error", impact := .internalError } +def pySpecParsingWarning : MessageKind := + { category := "warning", impact := .knownLimitation } + +-- Overload dispatch errors (in PySpec-to-Laurel phase) +def overloadNoArgs : MessageKind := + { category := "overloadNoArgs", impact := .internalError } +def overloadReturnNotClass : MessageKind := + { category := "overloadReturnNotClass", impact := .internalError } +def overloadParamNameDisagreement : MessageKind := + { category := "overloadParamNameDisagreement", impact := .internalError } +def overloadArgNotStringLiteral : MessageKind := + { category := "overloadArgNotStringLiteral", impact := .internalError } + +-- Overload resolution phase +def overloadResolveWarning : MessageKind := + { category := "resolveWarning", impact := .internalWarning } + +-- PySpec.ToLaurel internal warnings/errors +def missingMethodSelf : MessageKind := + { category := "missingMethodSelf", impact := .internalWarning } +def typeError : MessageKind := + { category := "typeError", impact := .internalWarning } +def kwargsExpansionError : MessageKind := + { category := "kwargsExpansionError", impact := .internalWarning } + +-- Type translation warnings +def unsupportedUnion : MessageKind := + { category := "unsupportedUnion", impact := .knownLimitation } + +-- Precondition warnings +def placeholderExpr : MessageKind := + { category := "placeholderExpr", impact := .knownLimitation } +def floatLiteral : MessageKind := + { category := "floatLiteral", impact := .knownLimitation } +def isinstanceUnsupported : MessageKind := + { category := "isinstanceUnsupported", impact := .knownLimitation } +def forallListUnsupported : MessageKind := + { category := "forallListUnsupported", impact := .knownLimitation } +def forallDictUnsupported : MessageKind := + { category := "forallDictUnsupported", impact := .knownLimitation } + +-- PySpec-to-Laurel assembly phase +def functionSignatureError : MessageKind := + { category := "functionSignatureError", impact := .internalError } +def typeNameCollision : MessageKind := + { category := "typeNameCollision", impact := .internalError } +def procedureNameCollision : MessageKind := + { category := "procedureNameCollision", impact := .internalError } + +-- Module resolution phase +def invalidModuleName : MessageKind := + { category := "invalidModuleName", impact := .configurationError } +def missingPySpecModule : MessageKind := + { category := "missingPySpecModule", impact := .configurationError } + +end Strata.Pipeline.MessageKind +end diff --git a/Strata/Languages/Python/Specs/ToLaurel.lean b/Strata/Languages/Python/Specs/ToLaurel.lean index da75cc4076..2f8f07897b 100644 --- a/Strata/Languages/Python/Specs/ToLaurel.lean +++ b/Strata/Languages/Python/Specs/ToLaurel.lean @@ -7,12 +7,13 @@ module public import Strata.Languages.Laurel.Laurel import Strata.DDM.Format -import Strata.Languages.Python.OverloadTable import Strata.Languages.Python.PythonLaurelTypedExpr public import Strata.Languages.Python.Specs.Decls -public import Strata.Languages.Python.Specs.Error +public import Strata.Pipeline.Messages import Strata.Languages.Python.Specs.DDM import Strata.Util.DecideProp +public import Strata.Languages.Python.OverloadTable +import Strata.Languages.Python.Specs.MessageKind /-! # PySpec to Laurel Translation @@ -52,17 +53,20 @@ private def typeTestersMap : Std.HashMap PythonIdent String := /-- Fully qualified Laurel name for a `PythonIdent`: module dots become underscores. E.g., `"mylib.sub"` / `"Foo"` → `"mylib_sub_Foo"`. -/ def PythonIdent.toLaurelName (id : PythonIdent) : String := - let pfx := "_".intercalate (id.pythonModule.splitOn ".") - if pfx.isEmpty then id.name else pfx ++ "_" ++ id.name + id.toString (sep := "_") end -- public section end Strata.Python +namespace Strata.Python.Specs + +end Strata.Python.Specs + namespace Strata.Python.Specs.ToLaurel open Strata.Laurel open Strata.Python.Laurel -open Strata.Python.Specs (SpecError) +open Strata.Pipeline (PipelineMessage MessageKind Phase) /-! ## ToLaurelM Monad -/ @@ -75,7 +79,7 @@ structure ToLaurelContext where /-- State for PySpec to Laurel translation. -/ structure ToLaurelState where - errors : Array SpecError := #[] + errors : Array PipelineMessage := #[] procedures : Array Procedure := #[] types : Array TypeDefinition := #[] overloads : OverloadTable := {} @@ -87,9 +91,11 @@ structure ToLaurelState where /-- Monad for PySpec to Laurel translation. -/ abbrev ToLaurelM := ReaderT ToLaurelContext (StateM ToLaurelState) -/-- Report an error during translation. -/ -def reportError (kind : WarningKind) (loc : SourceRange) (message : String) : ToLaurelM Unit := do - let e : SpecError := ⟨(←read).filepath, loc, kind, message⟩ +/-- Report an error during translation. Phase is set to pySpecToLaurel since + this monad always runs during that phase. -/ +def reportError (kind : MessageKind) (loc : SourceRange) (message : String) : ToLaurelM Unit := do + let phase := Phase.base "pySpecToLaurel" + let e : PipelineMessage := ⟨(←read).filepath, loc, phase, kind, message⟩ modify fun s => { s with errors := s.errors.push e } def runChecked (act : ToLaurelM α) : ToLaurelM (α × Bool) := do @@ -108,24 +114,57 @@ def pushType (td : TypeDefinition) : ToLaurelM Unit := /-- Add an overload dispatch entry for a function. -/ def pushOverloadEntry (funcName : String) (paramName : String) - (literalValue : String) (returnType : PythonIdent) : ToLaurelM Unit := - modify fun s => - let existing := s.overloads.getD funcName {} - let updated : FunctionOverloads := { existing with - paramName := existing.paramName <|> some paramName - entries := existing.entries.insert literalValue returnType } - if existing.paramName.any (· != paramName) then - dbg_trace s!"Warning: overload entries for '{funcName}' disagree on \ - dispatch parameter name: existing '{existing.paramName.get!}', new '{paramName}'" - { s with overloads := s.overloads.insert funcName updated } - else - { s with overloads := s.overloads.insert funcName updated } - -/-- Prepend the module prefix to a name. Returns the name unchanged - if the prefix is empty. -/ + (literalValue : String) (returnType : PythonIdent) : ToLaurelM Unit := do + match (←get).overloads[funcName]? with + | none => + modify fun s => + let entry : FunctionOverloads := { + paramName := paramName + entries := {(literalValue, returnType)} + } + { s with overloads := s.overloads.insert funcName entry } + | some existing => + if existing.paramName != paramName then + reportError .overloadParamNameDisagreement default + s!"Overload entries for '{funcName}' disagree on dispatch parameter \ + name: existing '{existing.paramName}', new '{paramName}'" + modify fun s => + { s with overloads := s.overloads.modify funcName fun existing => + { existing with entries := existing.entries.insert literalValue returnType } + } + +/-- Extract an overload dispatch entry from an `@overload` function declaration. -/ +def extractOverloadEntry (func : FunctionDecl) : ToLaurelM Unit := do + let args := func.args.args + let .isTrue _ := decideProp (args.size > 0) + | reportError .overloadNoArgs func.loc + s!"Overloaded function '{func.name}' has no arguments" + return + let firstArgType := args[0].type + let literalValue ← + match firstArgType.asStringLiteral with + | some v => pure v + | none => + reportError .overloadArgNotStringLiteral func.loc + s!"Overloaded function '{func.name}': first argument \ + type '{firstArgType}' is not a \ + string literal (only string literal dispatch is \ + currently supported)" + return + let retType ← + match func.returnType.asIdent with + | some nm => pure nm + | none => + reportError .overloadReturnNotClass func.loc + s!"Overloaded function '{func.name}': return type \ + '{func.returnType}' is not a \ + class type" + return + pushOverloadEntry func.name args[0].name literalValue retType + +/-- Prepend the module prefix to a name. -/ def prefixName (name : String) : ToLaurelM String := do let ctx ← read - if ctx.modulePrefix.isEmpty then return name return ctx.modulePrefix ++ "_" ++ name /-! ## Helper Functions -/ @@ -570,38 +609,6 @@ def typeDefToLaurel (td : TypeDef) : ToLaurelM Unit := do instanceProcedures := [] }) -/-- Extract an overload dispatch entry from an `@overload` function declaration. - Looks for a `stringLiteral` in the first argument's type and an `.ident` - return type, and records them in the dispatch table. -/ -def extractOverloadEntry (func : FunctionDecl) : ToLaurelM Unit := do - let args := func.args.args - let .isTrue _ := decideProp (args.size > 0) - | reportError .overloadNoArgs func.loc - s!"Overloaded function '{func.name}' has no arguments" - return - let firstArgType := args[0].type - let literalValue ← - match firstArgType.asStringLiteral with - | some v => pure v - | none => - reportError .overloadArgNotStringLiteral func.loc - s!"Overloaded function '{func.name}': first argument \ - type '{firstArgType}' is not a \ - string literal (only string literal dispatch is \ - currently supported)" - return - let retType ← - match func.returnType.asIdent with - | some nm => pure nm - | none => - reportError .overloadReturnNotClass func.loc - s!"Overloaded function '{func.name}': return type \ - '{func.returnType}' is not a \ - class type" - return - -- args[0].name is the formal parameter name from the PySpec (not a call-site argument) - pushOverloadEntry func.name args[0].name literalValue retType - /-- Convert a single PySpec signature to Laurel declarations. -/ def signatureToLaurel (sig : Signature) : ToLaurelM Unit := match sig with @@ -623,7 +630,7 @@ def signatureToLaurel (sig : Signature) : ToLaurelM Unit := /-- Result of translating PySpec signatures to Laurel. -/ public structure TranslationResult where program : Laurel.Program - errors : Array SpecError + errors : Array PipelineMessage overloads : OverloadTable /-- Maps unprefixed class names to prefixed names for type resolution. -/ typeAliases : Std.HashMap String String := {} @@ -633,9 +640,12 @@ public structure TranslationResult where /-- Run the translation and return a Laurel Program, dispatch table, and any errors. -/ public def signaturesToLaurel (filepath : System.FilePath) (sigs : Array Signature) - (modulePrefix : String) + (moduleName : ModuleName) : TranslationResult := - let ctx : ToLaurelContext := { filepath, modulePrefix } + let ctx : ToLaurelContext := { + filepath, + modulePrefix := moduleName.toString (sep := "_") + } let ((), state) := (sigs.forM signatureToLaurel).run ctx |>.run {} let pgm : Laurel.Program := { staticProcedures := state.procedures.toList @@ -653,7 +663,7 @@ public def signaturesToLaurel (filepath : System.FilePath) (sigs : Array Signatu Processes `@overload` function declarations, ignoring classDef, typeDef, externTypeDecl, and non-overload functions. -/ public def extractOverloads (filepath : System.FilePath) (sigs : Array Signature) - : OverloadTable × Array SpecError := + : OverloadTable × Array PipelineMessage := let ctx : ToLaurelContext := { filepath, modulePrefix := "" } let action := sigs.forM fun sig => match sig with @@ -664,4 +674,5 @@ public def extractOverloads (filepath : System.FilePath) (sigs : Array Signature let ((), state) := action.run ctx |>.run {} (state.overloads, state.errors) + end Strata.Python.Specs.ToLaurel diff --git a/Strata/Pipeline/Context.lean b/Strata/Pipeline/Context.lean new file mode 100644 index 0000000000..408d38819d --- /dev/null +++ b/Strata/Pipeline/Context.lean @@ -0,0 +1,369 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ +module +public import Strata.Pipeline.Messages +import Lean.Data.Json.Printer +import all Strata.DDM.Util.String + +namespace Strata.Pipeline + +/-- Print to stdout and flush. -/ +def printlnFlush (msg : String) : BaseIO Unit := do + let _ ← (do IO.println msg; (← IO.getStdout).flush : IO Unit).toBaseIO + +/-- Output verbosity mode for the pipeline. -/ +public inductive OutputMode where + | quiet + | default + | profile + | verbose + deriving BEq, DecidableEq, Repr + +/-- Aggregated data for a single repeated phase (recursive for arbitrary nesting). -/ +structure RepeatedPhaseData where + count : Nat + totalNs : Nat + /-- Aggregated timings for nested subphases, in first-seen order. -/ + children : Array (String × RepeatedPhaseData) := #[] + deriving Inhabited + +namespace RepeatedPhaseData + +/-- Merge `incoming` children into `existing`, summing counts/totals for + matching names and recursively merging their children. -/ +partial def mergeChildren (existing incoming : Array (String × RepeatedPhaseData)) + : Array (String × RepeatedPhaseData) := + incoming.foldl (init := existing) fun acc (name, data) => + match acc.findIdx? (·.1 == name) with + | some idx => + let (_, prev) := acc[idx]! + acc.modify idx fun _ => (name, { + count := prev.count + data.count, + totalNs := prev.totalNs + data.totalNs, + children := mergeChildren prev.children data.children }) + | none => acc.push (name, data) + +end RepeatedPhaseData + +/-- Upsert a repeated phase entry: find by name in the array, merge elapsed + and children, or append a new entry. -/ +def addRepeatedEntry (arr : Array (String × RepeatedPhaseData)) + (name : String) (elapsed : Nat) (children : Array (String × RepeatedPhaseData)) + : Array (String × RepeatedPhaseData) := + match arr.findIdx? (·.1 == name) with + | some idx => + let (_, prev) := arr[idx]! + arr.modify idx fun _ => (name, { + count := prev.count + 1, + totalNs := prev.totalNs + elapsed, + children := RepeatedPhaseData.mergeChildren prev.children children }) + | none => arr.push (name, { count := 1, totalNs := elapsed, children }) + +/-- Per-phase scoped state: saved on phase entry, restored on exit. + Bundled into a single ref to ensure atomic save/restore. -/ +structure PhaseState where + repeatedPhases : Array (String × RepeatedPhaseData) := #[] + messageCounts : Std.HashMap String Nat := {} + deriving Inhabited + +/-- Pipeline context carrying immutable config and mutable state as individual IORefs. + This design allows any monad with BaseIO access to use pipeline capabilities + by passing a PipelineContext value directly. + + **Phase tracking state machine:** + + The phase system operates in two modes controlled by `repeatedDepthRef`: + - Mode N (normal, depth = 0): `withPhase` records individual timing entries + and prints `[start]`/`[end]` in profile mode. + - Mode R (repeated, depth > 0): `withPhase` silently aggregates timing into + `phaseStateRef.repeatedPhases` — no print, no individual `timingRef` entry. + + `withRepeatedPhase` increments `repeatedDepthRef` on entry and decrements on + exit. `withPhase` never changes the depth. + + **Invariants:** + - `currentPhaseRef` always reflects the innermost active scope's full path. + - `phaseStateRef` is scoped: saved on entry, restored on exit of both + `withPhase` and `withRepeatedPhase` — no cross-scope leakage. + - In mode R, all timing flows through `phaseStateRef.repeatedPhases` only; + `timingRef` is not touched until the enclosing mode-N `withPhase` flushes. + + **Thread safety:** PipelineContext is NOT thread-safe. The phase-tracking + refs assume single-threaded sequential access. Concurrent `withPhase` or + `withRepeatedPhase` calls on the same context will corrupt state. -/ +public structure PipelineContext where + private mk :: + outputMode : OutputMode + private pipelineStartTime : Nat + private profilePipeline : Bool := true + private messagesRef : IO.Ref (Array PipelineMessage) + private toolErrorsRef : IO.Ref (Array PipelineMessage) + private userCodeErrorsRef : IO.Ref (Array PipelineMessage) + /-- Full path of the innermost active phase. Managed via `push`/`pop` + by `withPhase` and `withRepeatedPhase` — `emitMessage` stamps this + on each diagnostic. -/ + private currentPhaseRef : IO.Ref Phase + /-- Nesting depth of `withRepeatedPhase` scopes. When > 0, `withPhase` + aggregates silently instead of recording individual timing entries. -/ + private repeatedDepthRef : IO.Ref Nat + /-- Per-phase scoped state (repeated subphases + message counts). + Saved and cleared on phase entry, restored on exit — each phase + sees only its own data. -/ + private phaseStateRef : IO.Ref PhaseState + /-- Caller-owned handle for JSONL metrics. The pipeline appends and flushes + per record but does not open or close the handle. -/ + private metricsHandle : Option IO.FS.Handle := none + +namespace PipelineContext + +/-- Create a fresh PipelineContext with new state refs. -/ +public def create + (outputMode : OutputMode := .default) + (profilePipeline : Bool := true) + (metricsHandle : Option IO.FS.Handle := none) : BaseIO PipelineContext := do + let startTime ← IO.monoNanosNow + let messagesRef ← IO.mkRef (α := Array PipelineMessage) #[] + let toolErrorsRef ← IO.mkRef (α := Array PipelineMessage) #[] + let userCodeErrorsRef ← IO.mkRef (α := Array PipelineMessage) #[] + let currentPhaseRef ← IO.mkRef (α := Phase) default + let repeatedDepthRef ← IO.mkRef 0 + let phaseStateRef ← IO.mkRef (α := PhaseState) {} + return { outputMode, pipelineStartTime := startTime, profilePipeline, + messagesRef, toolErrorsRef, userCodeErrorsRef, + currentPhaseRef, repeatedDepthRef, phaseStateRef, metricsHandle } + +/-- All accumulated pipeline messages. -/ +public def getMessages (ctx : PipelineContext) : BaseIO (Array PipelineMessage) := + ctx.messagesRef.get + +/-- Messages with `.internalError` or `.configurationError` impact. + These represent tool bugs or invalid invocations that we must fix. -/ +public def getToolErrors (ctx : PipelineContext) : BaseIO (Array PipelineMessage) := + ctx.toolErrorsRef.get + +/-- Messages with `.userCodeIssue` impact. + These represent definite errors in the user's Python source code. -/ +public def getUserCodeErrors (ctx : PipelineContext) : BaseIO (Array PipelineMessage) := + ctx.userCodeErrorsRef.get + +/-- Write a JSONL metric record to the metrics file (if open) and flush. -/ +public def emitMetric (ctx : PipelineContext) (json : Lean.Json) : BaseIO Unit := do + if let some h := ctx.metricsHandle then + let _ ← (do h.putStrLn json.compress; h.flush : IO Unit).toBaseIO + +/-- Get elapsed nanoseconds since pipeline start. -/ +public def elapsedNs (ctx : PipelineContext) : BaseIO Nat := do + let now ← IO.monoNanosNow + return now - ctx.pipelineStartTime + +/-- Common entry logic for `withPhase`: push the subphase name onto + `currentPhaseRef`, save and clear scoped phase state. -/ +def enterPhase (ctx : PipelineContext) (name : String) + : BaseIO PhaseState := do + ctx.currentPhaseRef.modify (·.subphase name) + ctx.phaseStateRef.modifyGet fun ps => (ps, {}) + +/-- Recursively print `[profile]` lines and emit JSONL metrics for aggregated + repeated phases. `parentPhase` is the phase under which these entries + are nested. -/ +partial def flushRepeatedEntries (ctx : PipelineContext) + (parentPhase : Phase) (entries : Array (String × RepeatedPhaseData)) + : BaseIO Unit := do + if entries.isEmpty then return + let childIndent := String.replicate (parentPhase.depth * 2) ' ' + for (name, data) in entries do + let subphase := parentPhase.subphase name + if ctx.outputMode == .profile || ctx.outputMode == .verbose then + let avg := if data.count > 0 then nsToMs (data.totalNs / data.count) else 0 + let timeSuffix := + if ctx.profilePipeline then + s!" (×{data.count}, total: {nsToMs data.totalNs}ms, avg: {avg}ms)" + else + "" + printlnFlush s!"{childIndent}[profile] {name}{timeSuffix}" + ctx.emitMetric (Lean.Json.mkObj [ + ("type", .str "timing"), ("phase", .str subphase.display), + ("start_ms", .num 0), ("end_ms", .num (nsToMs data.totalNs)), + ("count", .num data.count)]) + flushRepeatedEntries ctx subphase data.children + +/-- Mode-N entry: print [start] and return the start time. -/ +def enterPhaseNormal (ctx : PipelineContext) : BaseIO Nat := do + let phase ← ctx.currentPhaseRef.get + let startNs ← ctx.elapsedNs + if ctx.outputMode == .profile || ctx.outputMode == .verbose then + let indent := String.replicate ((phase.depth - 1) * 2) ' ' + let timeSuffix := if ctx.profilePipeline then s!" (time: {nsToMs startNs}ms)" else "" + printlnFlush s!"{indent}[start] {phase.leaf}{timeSuffix}" + return startNs + +/-- End the current phase in mode N: flush aggregated repeated subphases, + emit timing metric, print [end]/[warnings], then pop phase and restore state. -/ +def exitPhaseNormal (ctx : PipelineContext) + (saved : PhaseState) (startNs : Nat) : BaseIO Unit := do + let currentPhase ← ctx.currentPhaseRef.modifyGet fun p => (p, p.pop) + let ps ← ctx.phaseStateRef.modifyGet fun ps => (ps, saved) + flushRepeatedEntries ctx currentPhase ps.repeatedPhases + let now ← ctx.elapsedNs + ctx.emitMetric (Lean.Json.mkObj [ + ("type", .str "timing"), + ("phase", .str currentPhase.display), + ("start_ms", .num (nsToMs startNs)), + ("end_ms", .num (nsToMs now))]) + if ctx.outputMode == .profile || ctx.outputMode == .verbose then + let indent := String.replicate ((currentPhase.depth - 1) * 2) ' ' + let timeSuffix := if ctx.profilePipeline then s!" (time: {nsToMs now}ms)" else "" + printlnFlush s!"{indent}[end] {currentPhase.leaf}{timeSuffix}" + unless ps.messageCounts.isEmpty do + let parts := ps.messageCounts.toArray.map fun (cat, n) => s!"{n} {cat}" + let summary := String.intercalate ", " parts.toList + printlnFlush s!"{indent}[warnings] {currentPhase.leaf}: {summary}" + +/-- Mode-R exit for `withPhase`: accumulate elapsed time and nested children + into the saved repeated-phases array, then pop phase and restore state. -/ +def exitPhaseRepeated (ctx : PipelineContext) + (saved : PhaseState) (startNs : Nat) : BaseIO Unit := do + let now ← IO.monoNanosNow + let elapsed := now - startNs + let currentPhase ← ctx.currentPhaseRef.modifyGet fun p => (p, p.pop) + ctx.phaseStateRef.modify fun ps => + let children := ps.repeatedPhases + { saved with + repeatedPhases := + addRepeatedEntry saved.repeatedPhases currentPhase.leaf elapsed children } + +/-- Run an action as a named subphase of the current phase. + Nesting is determined by call structure — at the root the phase is + top-level, inside another `withPhase` it becomes a child. + + Outside a repeated phase: pushes a timing entry to `timingRef`, + prints `[start]`/`[end]` in profile/verbose mode, and flushes any + aggregated repeated subphases on exit. + + Inside a repeated phase (i.e. the action may run many times): + silently accumulates elapsed time into the enclosing + `repeatedPhasesRef`. No print, no individual timing entry. -/ +@[noinline] +public def withPhase {m α} [Monad m] [MonadLiftT BaseIO m] [MonadFinally m] + (ctx : PipelineContext) (name : String) (action : m α) : m α := do + let inRepeatedCnt ← ctx.repeatedDepthRef.get (m := BaseIO) + if inRepeatedCnt > 0 then + let saved ← ctx.enterPhase name + let startNs ← IO.monoNanosNow + try + action + finally + ctx.exitPhaseRepeated saved startNs + else + let saved ← ctx.enterPhase name + let startNs ← ctx.enterPhaseNormal + try + action + finally + ctx.exitPhaseNormal saved startNs + +/-- Run an action as a repeated subphase. Instead of recording individual + timing entries, accumulates count and total duration into the parent's + repeated-phases array. When the parent phase ends, the aggregated results + are flushed as single timing entries. Silent per-iteration. + + Sets `currentPhaseRef` so nested `emitMessage` calls get the correct + phase tag. Increments `repeatedDepthRef` so nested `withPhase` calls + aggregate silently. Saves/restores `phaseStateRef` for child + isolation. -/ +@[noinline] +public def withRepeatedPhase {m α} [Monad m] [MonadLiftT BaseIO m] [MonadFinally m] + (ctx : PipelineContext) (name : String) (action : m α) : m α := do + let saved ← ctx.enterPhase name + ctx.repeatedDepthRef.modify (m := BaseIO) (· + 1) + let startNs ← IO.monoNanosNow + try + action + finally + ctx.repeatedDepthRef.modify (m := BaseIO) (· - 1) + ctx.exitPhaseRepeated saved startNs + +/-- Time a pure expression as a repeated subphase. The `@[noinline]` + attribute prevents the compiler from hoisting `expr` outside the + timing window. Use this instead of `withRepeatedPhase` when the work + being timed is a pure (non-monadic) expression. -/ +@[noinline] +public def withRepeatedPhasePure {α} (ctx : PipelineContext) (name : String) + (expr : Unit → α) : BaseIO α := do + ctx.withRepeatedPhase (m := ReaderT Unit BaseIO) name (pure ∘ expr) () + +end PipelineContext + +/-- The pipeline monad: a reader over PipelineContext with EIO Unit. + Computations accumulate diagnostic messages in PipelineContext.messagesRef. + `emitMessageAndAbort` throws `()` to abort, but multiple messages (including + multiple error-impact messages) may accumulate before or across aborts. + The caller of `PipelineM.run` is responsible for inspecting the accumulated + messages and the outcome to determine the appropriate exit code. -/ +public abbrev PipelineM := ReaderT PipelineContext (EIO Unit) + +/-- Get the current phase from the pipeline context. -/ +public def getPhase : PipelineM Phase := do + let ctx ← read + ctx.currentPhaseRef.get + +/-- PipelineM wrapper for withPhase. -/ +@[noinline] +public def withPhase {α} (name : String) (action : PipelineM α) : PipelineM α := do + let ctx ← read + ctx.withPhase name (action.run ctx) + +/-- Append a pre-built PipelineMessage, emit metrics, and print in verbose mode. + Also buckets the message into specialized refs by impact. Does not throw. -/ +public def addMessage (msg : Pipeline.PipelineMessage) : Pipeline.PipelineM Unit := do + let ctx ← read + ctx.messagesRef.modify (·.push msg) + ctx.phaseStateRef.modify fun ps => + { ps with messageCounts := ps.messageCounts.alter msg.kind.category fun mv => some (mv.getD 0 + 1) } + match msg.kind.impact with + | .internalError | .configurationError => ctx.toolErrorsRef.modify (·.push msg) + | .userCodeIssue => ctx.userCodeErrorsRef.modify (·.push msg) + | _ => pure () + let mut fields : List (String × Lean.Json) := [ + ("type", .str "diagnostic"), ("phase", .str msg.phase.display), + ("file", .str msg.file.toString), ("category", .str msg.kind.category), + ("impact", .str (toString msg.kind.impact)), ("message", .str msg.message)] + unless msg.loc == default do + fields := fields ++ [("start", .num msg.loc.start.byteIdx), ("stop", .num msg.loc.stop.byteIdx)] + ctx.emitMetric (Lean.Json.mkObj fields) + if ctx.outputMode == .verbose then + let tag := toString msg.kind.impact + let indent := String.replicate ((msg.phase.depth - 1) * 2) ' ' + let _ ← (do IO.eprintln s!"{indent}[{tag}] {msg.file}: {msg.message}"; (← IO.getStderr).flush : IO Unit).toBaseIO + +/-- Emit a diagnostic message and continue. Tags with current phase. + The impact classification is for downstream consumers — callers may + accumulate multiple fatal-impact messages before aborting. -/ +public def emitMessage (kind : Pipeline.MessageKind) (message : String) + (file : System.FilePath := default) (loc : SourceRange := default) : Pipeline.PipelineM Unit := do + let phase ← getPhase + addMessage { file, loc, phase, kind, message } + +/-- Emit a diagnostic message and abort the pipeline. + Polymorphic return type allows use in expression position. -/ +public def emitMessageAndAbort (kind : Pipeline.MessageKind) (message : String) + (file : System.FilePath) (loc : SourceRange := default) : Pipeline.PipelineM α := do + emitMessage kind message file loc + throw () + +/-- All messages with a given impact. -/ +public def getMessagesByImpact (impact : MessageImpact) : PipelineM (Array PipelineMessage) := do + let ctx ← read + let msgs ← ctx.messagesRef.get + return msgs.filter (·.kind.impact == impact) + +/-- Whether any accumulated message has the given impact. -/ +public def hasImpact (impact : MessageImpact) : PipelineM Bool := do + let ctx ← read + let msgs ← ctx.messagesRef.get + return msgs.any (·.kind.impact == impact) + +end Strata.Pipeline diff --git a/Strata/Pipeline/Diagnostic.lean b/Strata/Pipeline/Diagnostic.lean new file mode 100644 index 0000000000..91d65ac1a1 --- /dev/null +++ b/Strata/Pipeline/Diagnostic.lean @@ -0,0 +1,42 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ +module + +public import Strata.Pipeline.Messages +public import Strata.Util.FileRange + +namespace Strata.Pipeline + +open Strata (DiagnosticType DiagnosticModel FileRange Uri) + +/-- Map a `DiagnosticType` to a `MessageKind`. + Each diagnostic severity maps to a category and impact. -/ +public def MessageKind.fromDiagnosticType : DiagnosticType → MessageKind + | .Warning => + { category := "warning", impact := .internalWarning } + | .UserError => + { category := "userError", impact := .userCodeIssue } + | .NotYetImplemented => + { category := "notYetImplemented", impact := .knownLimitation } + | .StrataBug => + { category := "error", impact := .internalError } + +/-- Convert a `DiagnosticModel` to a `PipelineMessage` using the given phase. -/ +public def PipelineMessage.fromDiagnostic (phase : Phase) (d : DiagnosticModel) : PipelineMessage := + let file : System.FilePath := match d.fileRange.file with + | .file path => path + { file + loc := d.fileRange.range + phase + kind := MessageKind.fromDiagnosticType d.type + message := d.message } + +/-- Convert a list of `DiagnosticModel` values to pipeline messages. -/ +public def PipelineMessage.fromDiagnostics (phase : Phase) (ds : List DiagnosticModel) + : Array PipelineMessage := + ds.toArray.map (PipelineMessage.fromDiagnostic phase) + +end Strata.Pipeline diff --git a/Strata/Pipeline/Messages.lean b/Strata/Pipeline/Messages.lean new file mode 100644 index 0000000000..4d31096cd8 --- /dev/null +++ b/Strata/Pipeline/Messages.lean @@ -0,0 +1,133 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ +module + +public import Strata.DDM.Util.SourceRange +import all Strata.DDM.Util.String + +public section +namespace Strata.Pipeline + +/-- Nanoseconds to milliseconds with rounding. -/ +def nsToMs (ns : Nat) : Nat := (ns + 500000) / 1000000 + +/-- A phase represents a position in the phase hierarchy. + Top-level phases have a single entry; subphases have multiple. + Ordering is determined by position in the timing array, not by name. -/ +structure Phase where + path : Array String := #[] + deriving BEq, DecidableEq, Hashable, Repr, Inhabited + +namespace Phase + +def base (name : String) : Phase := + { path := #[name] } + +def pop (p : Phase) : Phase := { path := p.path.pop } + +def subphase (parent : Phase) (name : String) : Phase := + { path := parent.path.push name } + +def depth (p : Phase) : Nat := p.path.size + +def leaf (p : Phase) : String := + match p.path.back? with + | some name => name + | none => "" + +def display (p : Phase) : String := + String.intercalate "." p.path.toList + +instance : ToString Phase where + toString p := p.display + +end Phase + + +/-- How severe / actionable is this message? -/ +inductive MessageImpact where + /-- An unexpected failure that prevented some output from being generated + (e.g., a malformed overload entry that was skipped). -/ + | internalError + /-- An unexpected condition that did not prevent output, but may indicate + a tool bug worth investigating. -/ + | internalWarning + /-- A known, documented limitation that may cause specs to be incomplete + or imprecise. -/ + | knownLimitation + /-- An issue detected in the user source code. -/ + | userCodeIssue + /-- The tool was invoked with invalid arguments or the on-disk pyspecs + are invalid (e.g., missing module, unreadable file). -/ + | configurationError + deriving BEq, DecidableEq, Hashable, Ord, Repr + +/-- +Whether this impact level typically warrants aborting the pipeline. + +N.B. Pipeline steps may want a custom abort strategy rather than +relying on this predicate. +-/ +def MessageImpact.isFatal : MessageImpact → Bool + | .internalError => true + | .configurationError => true + | .internalWarning => false + | .knownLimitation => false + | .userCodeIssue => true + +instance : ToString MessageImpact where + toString + | .internalError => "internalError" + | .internalWarning => "internalWarning" + | .knownLimitation => "knownLimitation" + | .userCodeIssue => "userCodeIssue" + | .configurationError => "configurationError" + +/-- A categorized message kind with category and impact. + The phase is derived from pipeline context at emit time. -/ +structure MessageKind where + category : String + impact : MessageImpact + deriving BEq, DecidableEq, Hashable, Ord, Repr + +instance : ToString MessageKind where + toString mk := mk.category + +namespace MessageKind + +-- Laurel lowering phase +def laurelLoweringError : MessageKind := + { category := "error", impact := .internalError } +def laurelLoweringNotImpl : MessageKind := + { category := "notYetImplemented", impact := .knownLimitation } +def laurelLoweringUserError : MessageKind := + { category := "userError", impact := .userCodeIssue } + +-- Laurel-to-Core translation phase +def laurelToCoreError : MessageKind := + { category := "error", impact := .internalError } + +-- Verification phase +def verificationError : MessageKind := + { category := "error", impact := .internalError } +def verificationTimeout : MessageKind := + { category := "solverTimeout", impact := .knownLimitation } + +end MessageKind + +/-- A located, categorized pipeline message. -/ +structure PipelineMessage where + file : System.FilePath + loc : SourceRange + phase : Phase + kind : MessageKind + message : String + +instance : ToString PipelineMessage where + toString m := s!"{m.file}: {m.phase}.{m.kind}: {m.message}" + +end Strata.Pipeline +end diff --git a/Strata/Pipeline/PyAnalyzeLaurel.lean b/Strata/Pipeline/PyAnalyzeLaurel.lean new file mode 100644 index 0000000000..a03fe7d187 --- /dev/null +++ b/Strata/Pipeline/PyAnalyzeLaurel.lean @@ -0,0 +1,140 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ +module + +public import Strata.Pipeline.Diagnostic +public import Strata.Util.Statistics +public import Strata.Languages.Core.EntryPoint +public import Strata.Languages.Core.Verifier +import Strata.Languages.Python.PySpecPipeline +import Strata.Languages.Python.PyFactory +import Strata.Languages.Laurel.Grammar.AbstractToConcreteTreeTranslator +import Strata.SimpleAPI + +namespace Strata.Pipeline + +/-- The outcome of the full pyAnalyzeLaurel pipeline. + Error details are derived from the accumulated messages in PipelineContext. -/ +public inductive PyAnalyzeOutcome where + /-- Pipeline completed verification successfully. -/ + | verified (vcResults : _root_.Core.VCResults) (coreProgram : Core.Program) + /-- Pipeline aborted due to a fatal error. -/ + | failed + +/-- Configuration for the pyAnalyzeLaurel pipeline. -/ +public structure PyAnalyzeConfig where + filePath : String + specDir : System.FilePath + dispatchModules : Array String := #[] + pyspecModules : Array String := #[] + sourcePath : Option String := none + keepAllFilesPrefix : Option String := none + verifyOptions : Core.VerifyOptions + entryPoint : Core.EntryPoint := Core.EntryPoint.roots + isBugFinding : Bool := true + outputMode : OutputMode := .default + skipVerification : Bool := false + profilePipeline : Bool := true + metricsHandle : Option IO.FS.Handle := none + +private def runPipeline (config : PyAnalyzeConfig) + : PipelineM (PyAnalyzeOutcome × Statistics) := do + let combinedLaurel ← withPhase "pythonAndSpecToLaurel" do + Strata.pythonAndSpecToLaurel + (specDir := config.specDir) + config.filePath config.dispatchModules config.pyspecModules config.sourcePath + + if config.outputMode == .verbose then + let _ ← (show IO Unit from do + IO.println "---- BEGIN Laurel Program ----" + IO.println (toString (Std.format combinedLaurel)) + IO.println "---- END Laurel Program ----").toBaseIO + + let uri := config.sourcePath.getD config.filePath + + let (coreProgram, laurelPassStats) ← withPhase "laurelToCore" do + let ctx ← read + let laurelResult ← + Strata.translateCombinedLaurelWithLowered combinedLaurel + (keepAllFilesPrefix := config.keepAllFilesPrefix) + (pipelineCtx := some ctx) |>.toBaseIO + match laurelResult with + | .ok (coreOpt, diags, _, stats) => + let phase ← getPhase + for msg in PipelineMessage.fromDiagnostics phase diags do + addMessage msg + if msg.kind.impact.isFatal then throw () + match coreOpt with + | some core => pure (core, stats) + | none => + emitMessageAndAbort (file := uri) .laurelToCoreError s!"Laurel to Core translation failed: {diags}" + | .error e => + emitMessageAndAbort (file := uri) .laurelToCoreError s!"Laurel translation error: {e}" + + if config.outputMode == .verbose then + let _ ← (show IO Unit from do + IO.println "---- BEGIN Core Program ----" + IO.println (toString coreProgram) + IO.println "---- END Core Program ----").toBaseIO + + if config.skipVerification then + return (PyAnalyzeOutcome.verified #[] coreProgram, laurelPassStats) + + let verifyResult ← withPhase "verification" do + let ctx ← read + let userSourcePath := config.sourcePath.getD config.filePath + let (_, userProcNames) := Strata.splitProcNames coreProgram [userSourcePath] + let (proceduresToVerify, inlinePhases) := + if config.isBugFinding then + let ⟨p, i⟩ := Core.chooseEntryProceduresAndBuildInlinePhases + coreProgram userProcNames config.entryPoint + (p, [i]) + else (userProcNames, []) + Strata.Core.verifyProgram coreProgram config.verifyOptions + (moreFns := Strata.Python.ReFactory) + (proceduresToVerify := some proceduresToVerify) + (externalPhases := [Strata.frontEndPhase]) + (prefixPhases := inlinePhases) + (keepAllFilesPrefix := config.keepAllFilesPrefix) + (pipelineCtx := some ctx) + |>.toBaseIO + + let vcResults ← + match verifyResult with + | .ok r => + pure r.mergeByAssertion + | .error msg => + emitMessageAndAbort (file := uri) .verificationError msg + + for vcResult in vcResults do + match vcResult.outcome with + | .error (.encoding msg) => + emitMessageAndAbort (file := uri) .verificationError msg + | .error (.solverTimeout msg) => + emitMessage .verificationTimeout msg + | .error (.solverCrash msg) => + emitMessageAndAbort (file := uri) .verificationError msg + | .ok _ => pure () + + return (PyAnalyzeOutcome.verified vcResults coreProgram, laurelPassStats) + +/-- Run the full pyAnalyzeLaurel pipeline: Python+PySpec to Laurel, + Laurel to Core, then SMT verification. + + Accumulates pipeline messages from all phases. The caller is responsible + for inspecting the outcome and accumulated messages to determine exit codes. -/ +public def runPyAnalyzePipeline (config : PyAnalyzeConfig) + : IO (PyAnalyzeOutcome × Statistics × PipelineContext) := do + let ctx ← PipelineContext.create + (outputMode := config.outputMode) + (profilePipeline := config.profilePipeline) + (metricsHandle := config.metricsHandle) + let result ← runPipeline config |>.run ctx |>.toBaseIO + match result with + | .ok (outcome, stats) => return (outcome, stats, ctx) + | .error () => return (.failed, {}, ctx) + +end Strata.Pipeline diff --git a/Strata/SimpleAPI.lean b/Strata/SimpleAPI.lean index 4e536c732b..449b8e61f7 100644 --- a/Strata/SimpleAPI.lean +++ b/Strata/SimpleAPI.lean @@ -62,7 +62,7 @@ public section namespace Strata -open Strata.Python.Specs (ModuleName) +open Strata.Python (ModuleName) /-! ### File I/O -/ @@ -328,11 +328,17 @@ def Core.verifyProgram (externalPhases : List Core.AbstractedPhase := []) (prefixPhases : List Core.PipelinePhase := []) (keepAllFilesPrefix : Option String := none) + (solver : Option Core.CoreSMTSolver := none) + (mkDischarge : Core.MkDischargeFn := Core.mkDischargeFn) + (pipelineCtx : Option Pipeline.PipelineContext := none) : EIO String Core.VCResults := do let runVerification (tempDir : System.FilePath) : IO Core.VCResults := EIO.toIO (IO.Error.userError ∘ toString) (Core.verify program tempDir proceduresToVerify options moreFns externalPhases prefixPhases - (keepAllFilesPrefix := keepAllFilesPrefix)) + (keepAllFilesPrefix := keepAllFilesPrefix) + (solver := solver) + (mkDischarge := mkDischarge) + (pipelineCtx := pipelineCtx)) let ioAction := match options.vcDirectory with | .some vcDir => IO.FS.createDirAll vcDir *> runVerification vcDir | .none => IO.FS.withTempDir runVerification @@ -398,32 +404,29 @@ inductive WarningOutput where deriving Inhabited, BEq /-- Recursively discover all Python modules under a directory. - Returns `(moduleName, filePath)` pairs. The `components` array - accumulates directory names as we recurse, forming the dotted - module name prefix. -/ + Returns `(moduleName, filePath)` pairs. -/ private partial def discoverModules (sourceDir : System.FilePath) : IO (Array (ModuleName × System.FilePath)) := do - let rec go (dir : System.FilePath) (components : Array String) + let rec go (dir : System.FilePath) (relPrefix : System.FilePath) : IO (Array (ModuleName × System.FilePath)) := do let mut acc := #[] let entries ← dir.readDir for entry in entries do + let relChild : System.FilePath := + if relPrefix.toString.isEmpty then + entry.fileName + else + relPrefix / entry.fileName if ← entry.path.isDir then - acc := acc ++ (← go entry.path (components.push entry.fileName)) + acc := acc ++ (← go entry.path relChild) else if entry.fileName.endsWith ".py" then - let parts := - if entry.fileName == "__init__.py" then - components - else - components.push (entry.fileName.takeWhile (· != '.') |>.toString) - if parts.isEmpty then continue - let dotted := ".".intercalate parts.toList - match ModuleName.ofString dotted with - | .ok mod => acc := acc.push (mod, entry.path) + match ModuleName.ofRelativePath relChild with + | .ok info => acc := acc.push (info.moduleName, entry.path) | .error msg => let _ ← IO.eprintln s!"warning: skipping {entry.path}: {msg}" |>.toBaseIO + continue return acc - go sourceDir #[] + go sourceDir ⟨""⟩ /-- Derive the output path for a Python file by mirroring the source directory structure and replacing `.py` with `.pyspec.st.ion`. -/ @@ -474,9 +477,9 @@ def pySpecsDir (sourceDir strataDir dialectFile : System.FilePath) else let mut result := #[] for m in modules do - let mod ← match ModuleName.ofString m with - | .ok r => pure r - | .error e => throw s!"Invalid module name '{m}': {e}" + let mod ← match ModuleName.ofString? m with + | some r => pure r + | none => throw s!"Invalid module name '{m}'" let (path, _) ← match ← ModuleName.findInPath mod sourceDir |>.toBaseIO with | .ok r => pure r @@ -508,9 +511,9 @@ def pySpecsDir (sourceDir strataDir dialectFile : System.FilePath) -- Translate Python.Specs.baseLogEvent events "import" s!"Translating {mod}" match ← Strata.Python.Specs.translateFile - dialectFile strataDir pythonFile sourceDir + dialectFile strataDir pythonFile sourceDir mod (events := events) (skipNames := skipIdents) - (moduleName := mod) (pythonCmd := pythonCmd) |>.toBaseIO with + (pythonCmd := pythonCmd) |>.toBaseIO with | .error msg => Python.Specs.baseLogEvent events "import" s!"Failed {mod}: {msg}" failures := failures.push (toString mod, msg) @@ -549,10 +552,16 @@ def pyTranslateLaurel (pyspecModules : Array String := #[]) (specDir : System.FilePath := ".") : EIO String (Core.Program × List DiagnosticModel) := do + let pctx ← Pipeline.PipelineContext.create (outputMode := .quiet) let laurel ← - match ← pythonAndSpecToLaurel pythonIonPath dispatchModules pyspecModules (specDir := specDir) |>.toBaseIO with + match ← (pythonAndSpecToLaurel pythonIonPath dispatchModules pyspecModules (specDir := specDir)).run pctx |>.toBaseIO with | .ok r => pure r - | .error err => throw (toString err) + | .error () => + let msgs ← pctx.getMessages + let detail := match msgs.back? with + | some m => m.message + | none => "Pipeline aborted" + throw detail let (coreOption, laurelTranslateErrors) ← IO.toEIO (fun e => s!"{e}") (translateCombinedLaurel laurel) match coreOption with | none => throw s!"Laurel to Core translation failed: {laurelTranslateErrors}" diff --git a/Strata/Transform/ANFEncoder.lean b/Strata/Transform/ANFEncoder.lean index 39390bb321..6f30267052 100644 --- a/Strata/Transform/ANFEncoder.lean +++ b/Strata/Transform/ANFEncoder.lean @@ -83,8 +83,20 @@ private def findDuplicates (exprs : List Expression.Expr) : List Expression.Expr /-- Replace all occurrences of any target with its corresponding replacement in an expression. Computes hashes bottom-up to avoid redundant traversals. - The map stores (target, replacement) pairs keyed by hash. -/ -def replaceExprs (replacements : Std.HashMap UInt64 (Expression.Expr × Expression.Expr)) + + The map values are lists of (target, replacement) pairs so that distinct + expressions sharing the same `LExpr.hashExpr` do not displace each other + on insertion. On lookup we walk the list with structural `==` to find + the matching target. The expected list length is 1 for typical inputs; + a non-trivial collision only adds the cost of a few extra `==` + comparisons. + + Collision safety is load-bearing for `anfEncodeBody`'s termination + argument: it guarantees that every duplicate found by + `findANFEncoderTargets` is actually rewritten by this function on the + same pass, so no unreplaced duplicate can survive into the next + iteration. -/ +def replaceExprs (replacements : Std.HashMap UInt64 (List (Expression.Expr × Expression.Expr))) (e : Expression.Expr) : Expression.Expr := (go e).2 where @@ -121,11 +133,15 @@ where let e' : Expression.Expr := .quant m k name ty tr' body' let kh : UInt64 := match k with | .all => 0 | .exist => 1 check (LExpr.hashQuantExpr kh (hash name) (LExpr.hashOptTy ty) htr hbody) e' - /-- Check if the hash matches a replacement target. -/ + /-- Check if the hash matches a replacement target. Walks the list of + pairs at this hash bucket and uses structural `==` to find the target, + so collisions never silently drop or misroute a replacement. -/ check (h : UInt64) (e : Expression.Expr) : UInt64 × Expression.Expr := match replacements[h]? with - | some (target, replacement) => - if e == target then (h, replacement) else (h, e) + | some pairs => + match pairs.find? (fun (t, _) => e == t) with + | some (_, replacement) => (h, replacement) + | none => (h, e) | none => (h, e) /-- Collect all subexpression hashes from an expression, @@ -190,26 +206,67 @@ private def findANFEncoderTargets (exprs : List Expression.Expr) : /-- Deduplicate a procedure's body by extracting common subexpressions into `var` declarations prepended to the body. Returns the modified body and the next available dedup index. + Assumes single-assignment (SSA-like) property of the post-PE Core IR: variables are assigned only once, so structurally equal expressions - always denote the same value within a procedure body. -/ + always denote the same value within a procedure body. + + Iterates to a fixpoint: a single pass cannot extract everything because + `removeSubsumed` deliberately drops duplicate subexpressions that are + contained in other (larger) duplicate expressions, to avoid creating + redundant `var` declarations. After the larger duplicate is lifted into + its own var declaration, those previously-subsumed inner duplicates + appear once in the new var-decl init and possibly again elsewhere in the + body, at which point the next iteration can extract them. + + Termination. Let `S(body)` be the set of distinct non-leaf, no-bvar + subexpressions of `body`. Then: + * `findANFEncoderTargets body ⊆ S(body)` and `S(body)` is finite. + * Each iteration replaces every occurrence of every target with a + fresh `fvar`. Fresh `fvar`s are leaves and are filtered out of all + future `S(...)` (via `!e.isLeaf`). + * Each new var-decl init is one of the just-extracted targets, which + was already in `S(body)`, so `S(newBody) ⊆ S(body)`. + * After extraction, every extracted target appears at most once in the + new body (in its own var-decl init), so it is no longer in + `findANFEncoderTargets newBody`. + Hence the iteration count is bounded by `|S(initial body)|`, which is in + turn bounded by the total expression size of the body. We pass that + bound as `fuel` so the recursion is structurally decreasing. -/ def anfEncodeBody (body : Statements) (startIdx : Nat) : Statements × Nat := - let targets := findANFEncoderTargets ((Statements.collectExprs body).flatMap collectSubexprs) - -- Build all var declarations and the replacement map - let (revDecls, replacements, nextIdx) := targets.foldl (fun (decls, repMap, idx) dup => - let freshName : CoreIdent := ⟨s!"{anfVarPrefix}{idx}", ()⟩ - let freshTy := dup.typeOf - let freshVar : Expression.Expr := .fvar () freshName freshTy - let ty : Expression.Ty := match freshTy with - | some mty => LTy.forAll [] mty - | none => LTy.forAll ["α"] (.ftvar "α") - let varDecl := Statement.init freshName ty (.det dup) .empty - let h := LExpr.hashExpr dup - (varDecl :: decls, repMap.insert h (dup, freshVar), idx + 1) - ) ([], ({} : Std.HashMap UInt64 (Expression.Expr × Expression.Expr)), startIdx) - -- Single pass: replace all targets at once - let body' := Statements.mapExprs (replaceExprs replacements) body - (revDecls.reverse ++ body', nextIdx) + let fuel := (Statements.collectExprs body).foldl (fun acc e => acc + LExpr.size _ e) 0 + go fuel body startIdx +where + go (fuel : Nat) (body : Statements) (startIdx : Nat) : Statements × Nat := + match fuel with + | 0 => (body, startIdx) + | fuel' + 1 => + let targets := findANFEncoderTargets ((Statements.collectExprs body).flatMap collectSubexprs) + if targets.isEmpty then + (body, startIdx) + else + -- Build all var declarations and the replacement map. The map value + -- is a list of (target, replacement) pairs to be collision-safe under + -- `LExpr.hashExpr`; see `replaceExprs` above. + let (revDecls, replacements, nextIdx) := targets.foldl (fun (decls, repMap, idx) dup => + let freshName : CoreIdent := ⟨s!"{anfVarPrefix}{idx}", ()⟩ + let freshTy := dup.typeOf + let freshVar : Expression.Expr := .fvar () freshName freshTy + let ty : Expression.Ty := match freshTy with + | some mty => LTy.forAll [] mty + | none => LTy.forAll ["α"] (.ftvar "α") + let varDecl := Statement.init freshName ty (.det dup) .empty + let h := LExpr.hashExpr dup + let pairs := repMap.getD h [] + (varDecl :: decls, repMap.insert h ((dup, freshVar) :: pairs), idx + 1) + ) ([], ({} : Std.HashMap UInt64 (List (Expression.Expr × Expression.Expr))), startIdx) + -- Replace all targets at once in the original body. + let body' := Statements.mapExprs (replaceExprs replacements) body + let newBody := revDecls.reverse ++ body' + -- Iterate: the newly-prepended var declarations may themselves + -- contain duplicated subexpressions that `removeSubsumed` dropped in + -- this pass. + go fuel' newBody nextIdx /-- Deduplicate all procedures in a program. Returns the modified program and whether any changes were made. -/ diff --git a/Strata/Transform/PrecondElim.lean b/Strata/Transform/PrecondElim.lean index ee77df3440..e16d42c025 100644 --- a/Strata/Transform/PrecondElim.lean +++ b/Strata/Transform/PrecondElim.lean @@ -76,6 +76,8 @@ private def classifyPrecondition (funcName : String) (precondIdx : Nat := 0) : O | .bv ⟨_, .SafeAdd⟩ | .bv ⟨_, .SafeSub⟩ | .bv ⟨_, .SafeMul⟩ | .bv ⟨_, .SafeNeg⟩ | .bv ⟨_, .SafeUAdd⟩ | .bv ⟨_, .SafeUSub⟩ | .bv ⟨_, .SafeUMul⟩ | .bv ⟨_, .SafeUNeg⟩ => some Imperative.MetaData.arithmeticOverflow + | .seq .Select | .seq .Update | .seq .Take | .seq .Drop => + some Imperative.MetaData.outOfBoundsAccess | _ => none /-- diff --git a/Strata/Util/FileRange.lean b/Strata/Util/FileRange.lean index dac1c7129d..f112109120 100644 --- a/Strata/Util/FileRange.lean +++ b/Strata/Util/FileRange.lean @@ -88,7 +88,7 @@ instance : Inhabited DiagnosticModel where /-- Create a DiagnosticModel from just a message (using default location). This should not be called, it only exists temporarily to enable incrementally migrating code without error locations -/ -def DiagnosticModel.fromMessage (msg : String) (type : DiagnosticType := DiagnosticType.UserError): DiagnosticModel := +def DiagnosticModel.fromMessage (msg : String) (type : DiagnosticType := DiagnosticType.UserError) : DiagnosticModel := { fileRange := FileRange.unknown, message := msg, type := type } /-- Create a DiagnosticModel from a Format (using default location). diff --git a/Strata/Util/Profile.lean b/Strata/Util/Profile.lean deleted file mode 100644 index 9bf748d1ca..0000000000 --- a/Strata/Util/Profile.lean +++ /dev/null @@ -1,18 +0,0 @@ -/- - Copyright Strata Contributors - - SPDX-License-Identifier: Apache-2.0 OR MIT --/ -module - -@[inline] public def nsToMs (ns : Nat) : Nat := (ns + 500000) / 1000000 - -/-- Run an action, printing its elapsed time in milliseconds to stdout when `profile` is true. -/ -public def profileStep {m α} [Monad m] [MonadLiftT BaseIO m] - (profile : Bool) (name : String) (action : m α) : m α := do - if !profile then return ← action - let start ← IO.monoNanosNow - let result ← action - let elapsed := (← IO.monoNanosNow) - start - let _ ← (IO.println s!"[profile] {name}: {nsToMs elapsed}ms" |>.toBaseIO) - pure result diff --git a/StrataMain.lean b/StrataMain.lean index 1e6700ff54..826504e67f 100644 --- a/StrataMain.lean +++ b/StrataMain.lean @@ -3,1554 +3,7 @@ SPDX-License-Identifier: Apache-2.0 OR MIT -/ -module +import StrataMainLib --- Executable with utilities for working with Strata files. -import Lean.Parser.Extension -import Strata.Backends.CBMC.CollectSymbols -import Strata.Backends.CBMC.GOTO.CoreToGOTOPipeline -import Strata.DDM.Integration.Java.Gen -import Strata.Languages.Core.Verifier -import Strata.Languages.Core.SarifOutput -import Strata.Languages.Core.ProgramEval -import Strata.Languages.Core.StatementEval -import Strata.Languages.C_Simp.Verify -import Strata.Languages.B3.Verifier.Program -import Strata.Languages.Laurel.LaurelCompilationPipeline -import Strata.Languages.Boole.Boole -import Strata.Languages.Boole.Verify -import Strata.Languages.Python.Python -import Strata.Languages.Python.Specs.IdentifyOverloads -import Strata.Languages.Python.Specs.ToLaurel -import Strata.Languages.Laurel.Grammar.AbstractToConcreteTreeTranslator -import Strata.Languages.Laurel.Laurel -import Strata.Languages.Core.EntryPoint -import Strata.Transform.ProcedureInlining -import Strata.Util.IO - -import Strata.SimpleAPI -import Strata.Util.Profile -import Strata.Util.Json -import Strata.DDM.BuiltinDialects -import Strata.DDM.Util.String -import Strata.Languages.Python.PyFactory -import Strata.Languages.Python.Specs -import Strata.Languages.Python.Specs.DDM -import Strata.Languages.Python.ReadPython - -open Strata - -open Core (VerifyOptions VerboseMode VerificationMode CheckLevel EntryPoint) -open Laurel (LaurelVerifyOptions LaurelTranslateOptions) - -/-! ## Exit codes - -All `strata` subcommands use a common exit code scheme: - -| Code | Category | Meaning | -|------|--------------------|-----------------------------------------------------------| -| 0 | Success | Analysis passed, inconclusive, or `--no-solve` completed. | -| 1 | User error | Bad input: invalid arguments, malformed source, etc. | -| 2 | Failures found | Analysis completed and found failures. | -| 3 | Internal error | SMT encoding failure, solver crash, or translation bug. | -| 4 | Known limitation | Intentionally unsupported language construct. | - -Codes 1–2 are **user-actionable** (fix the input or the code under analysis). -Codes 3–4 are **tool-side** (report as a bug or wait for support). -Exit 0 covers success, inconclusive results, and solver timeouts. -/ - -namespace ExitCode - def userError : UInt8 := 1 - def failuresFound : UInt8 := 2 - def internalError : UInt8 := 3 - def knownLimitation : UInt8 := 4 -end ExitCode - -def exitFailure {α} (message : String) (hint : String := "strata --help") : IO α := do - IO.eprintln s!"Exception: {message}\n\nRun {hint} for additional help." - IO.Process.exit ExitCode.userError - -/-- Exit with code 1 for user errors (bad input, malformed source, etc.). -/ -def exitUserError {α} (message : String) : IO α := do - IO.eprintln s!"❌ {message}" - IO.Process.exit ExitCode.userError - -/-- Exit with code 2 for analysis failures found. -/ -def exitFailuresFound {α} (message : String) : IO α := do - IO.eprintln s!"Failures found: {message}" - IO.Process.exit ExitCode.failuresFound - -/-- Exit with code 3 for internal errors (tool limitations or crashes). -/ -def exitInternalError {α} (message : String) : IO α := do - IO.eprintln s!"Exception: {message}" - IO.Process.exit ExitCode.internalError - -/-- Exit with code 4 for known limitations (unsupported constructs). -/ -def exitKnownLimitation {α} (message : String) : IO α := do - IO.eprintln s!"Known limitation: {message}" - IO.Process.exit ExitCode.knownLimitation - -/-- Like `exitFailure` but tailors the help hint to a specific subcommand. -/ -def exitCmdFailure {α} (cmdName : String) (message : String) : IO α := - exitFailure message (hint := s!"strata {cmdName} --help") - -/-- How a flag consumes arguments. -/ -inductive FlagArg where - | none -- boolean flag, e.g. --verbose - | arg (name : String) -- takes one value, e.g. --output - | repeat (name : String) -- takes one value, may appear multiple times, e.g. --include - -/-- A flag that a command accepts. -/ -structure Flag where - name : String -- flag name without "--", used as lookup key - help : String - takesArg : FlagArg := .none - -/-- Parsed flags from the command line. Stored as an ordered array so that - command-line position is preserved (needed by `transform` to bind - `--procedures`/`--functions` to the preceding `--pass`). - For `.arg` flags that appear more than once, `getString` returns the - **last** occurrence (last-writer-wins). -/ -structure ParsedFlags where - entries : Array (String × Option String) := #[] - -namespace ParsedFlags - -def getBool (pf : ParsedFlags) (name : String) : Bool := - pf.entries.any (·.1 == name) - -def getString (pf : ParsedFlags) (name : String) : Option String := - -- Scan from the end so last occurrence wins. - match pf.entries.findRev? (·.1 == name) with - | some (_, some v) => some v - | _ => Option.none - -def getRepeated (pf : ParsedFlags) (name : String) : Array String := - pf.entries.foldl (init := #[]) fun acc (n, v) => - if n == name then match v with | some s => acc.push s | none => acc else acc - -def insert (pf : ParsedFlags) (name : String) (value : Option String) : ParsedFlags := - { pf with entries := pf.entries.push (name, value) } - -def buildDialectFileMap (pflags : ParsedFlags) : IO Strata.DialectFileMap := do - let preloaded := Strata.Elab.LoadedDialects.builtin - |>.addDialect! Strata.Python.Python - |>.addDialect! Strata.Python.Specs.DDM.PythonSpecs - |>.addDialect! Strata.Core - |>.addDialect! Strata.Boole - |>.addDialect! Strata.Laurel.Laurel - |>.addDialect! Strata.smtReservedKeywordsDialect - |>.addDialect! Strata.SMTCore - |>.addDialect! Strata.SMT - |>.addDialect! Strata.SMTResponse - let mut sp ← Strata.DialectFileMap.new preloaded - for path in pflags.getRepeated "include" do - match ← sp.add path |>.toBaseIO with - | .error msg => exitFailure msg - | .ok sp' => sp := sp' - return sp - -end ParsedFlags - -def parseCheckMode (pflags : ParsedFlags) - (default : VerificationMode := .deductive) : IO VerificationMode := - match pflags.getString "check-mode" with - | .none => pure default - | .some s => match VerificationMode.ofString? s with - | .some m => pure m - | .none => exitFailure s!"Invalid check mode: '{s}'. Must be {VerificationMode.options}." - -def parseCheckLevel (pflags : ParsedFlags) - (default : CheckLevel := .minimal) : IO CheckLevel := - match pflags.getString "check-level" with - | .none => pure default - | .some s => match CheckLevel.ofString? s with - | .some l => pure l - | .none => exitFailure s!"Invalid check level: '{s}'. Must be {CheckLevel.options}." - -/-- Common CLI flags for VerifyOptions fields. - Commands can append these to their own flags list. - Note: `parseOnly`, `typeCheckOnly`, and `checkOnly` are omitted here - because they are specific to the `verify` command. -/ -def verifyOptionsFlags : List Flag := [ - { name := "check-mode", - help := s!"Check mode: {VerificationMode.options}. Default: 'deductive'.", - takesArg := .arg "mode" }, - { name := "check-level", - help := s!"Check level: {CheckLevel.options}. Default: 'minimal'.", - takesArg := .arg "level" }, - { name := "verbose", help := "Enable verbose output." }, - { name := "quiet", help := "Suppress warnings on stderr." }, - { name := "profile", help := "Print elapsed time for each pipeline step." }, - { name := "sarif", help := "Write results as SARIF to .sarif." }, - { name := "solver", - help := s!"SMT solver executable (default: {Core.defaultSolver}).", - takesArg := .arg "name" }, - { name := "solver-timeout", - help := "Solver timeout in seconds (default: 10).", - takesArg := .arg "seconds" }, - { name := "vc-directory", - help := "Store VCs in SMT-Lib format in .", - takesArg := .arg "dir" }, - { name := "no-solve", - help := "Generate SMT-Lib files but do not invoke the solver." }, - { name := "stop-on-first-error", - help := "Exit after the first verification error." }, - { name := "unique-bound-names", - help := "Use globally unique names for quantifier-bound variables." }, - { name := "use-array-theory", - help := "Use SMT-LIB Array theory instead of axiomatized maps." }, - { name := "remove-irrelevant-axioms", - help := "Prune irrelevant axioms: 'off', 'aggressive', or 'precise'.", - takesArg := .arg "mode" }, - { name := "overflow-checks", - help := "Comma-separated overflow checks to enable (signed,unsigned,float64,all,none).", - takesArg := .arg "checks" }, - { name := "path-cap", - help := "Maximum continuing paths between statements. 'none' (default) disables; N merges paths when count exceeds N.", - takesArg := .arg "N|none" } -] - -/-- Build a VerifyOptions from parsed CLI flags, starting from a base config. - Fields not present in the flags keep their base values. - Note: boolean flags can only enable a setting; a `true` in the base - cannot be turned off from the CLI (there is no `--no-X` syntax). -/ -def parseVerifyOptions (pflags : ParsedFlags) - (base : VerifyOptions := VerifyOptions.default) : IO VerifyOptions := do - let checkMode ← parseCheckMode pflags base.checkMode - let checkLevel ← parseCheckLevel pflags base.checkLevel - let solverTimeout ← match pflags.getString "solver-timeout" with - | .none => pure base.solverTimeout - | .some s => match s.toNat? with - | .some n => pure n - | .none => exitFailure s!"Invalid solver timeout: '{s}'" - let noSolve := pflags.getBool "no-solve" - let removeIrrelevantAxioms ← match pflags.getString "remove-irrelevant-axioms" with - | .none => pure base.removeIrrelevantAxioms - | .some "off" => pure .Off - | .some "aggressive" => pure .Aggressive - | .some "precise" => pure .Precise - | .some s => exitFailure s!"Invalid remove-irrelevant-axioms mode: '{s}'. Must be 'off', 'aggressive', or 'precise'." - let overflowChecks := match pflags.getString "overflow-checks" with - | .none => base.overflowChecks - | .some s => s.splitOn "," |>.foldl (fun acc c => - match c.trimAscii.toString with - | "signed" => { acc with signedBV := true } - | "unsigned" => { acc with unsignedBV := true } - | "float64" => { acc with float64 := true } - | "none" => { signedBV := false, unsignedBV := false, float64 := false } - | "all" => { signedBV := true, unsignedBV := true, float64 := true } - | _ => acc) { signedBV := false, unsignedBV := false, float64 := false } - let pathCap ← match pflags.getString "path-cap" with - | .none => pure base.pathCap - | .some "none" => pure .none - | .some s => match s.toNat? with - | .some n => if n == 0 then exitFailure "--path-cap must be at least 1 or 'none'." - else pure (.some n) - | .none => exitFailure s!"Invalid path-cap: '{s}'. Must be a positive number or 'none'." - let vcDirectory := (pflags.getString "vc-directory" |>.map (⟨·⟩ : String → System.FilePath)).orElse (fun _ => base.vcDirectory) - let skipSolver := noSolve || base.skipSolver - if skipSolver && vcDirectory.isNone then - exitFailure "--no-solve requires --vc-directory to specify where SMT files are stored." - pure { base with - verbose := if pflags.getBool "verbose" then .normal - else if pflags.getBool "quiet" then .quiet - else base.verbose, - solver := pflags.getString "solver" |>.getD base.solver, - solverTimeout, - checkMode, checkLevel, - stopOnFirstError := pflags.getBool "stop-on-first-error" || base.stopOnFirstError, - uniqueBoundNames := pflags.getBool "unique-bound-names" || base.uniqueBoundNames, - useArrayTheory := pflags.getBool "use-array-theory" || base.useArrayTheory, - removeIrrelevantAxioms, - outputSarif := pflags.getBool "sarif" || base.outputSarif, - profile := pflags.getBool "profile" || base.profile, - skipSolver, - alwaysGenerateSMT := noSolve || base.alwaysGenerateSMT, - overflowChecks, - vcDirectory, - pathCap - } - -/-- Additional CLI flags for `LaurelVerifyOptions` fields that are not already - covered by `verifyOptionsFlags`. -/ -def laurelTranslateFlags : List Flag := [ - { name := "keep-all-files", - help := "Store intermediate Laurel and Core programs in .", - takesArg := .arg "dir" } -] - -/-- All CLI flags accepted by Laurel verify commands. -/ -def laurelVerifyOptionsFlags : List Flag := verifyOptionsFlags ++ laurelTranslateFlags - -/-- Build a `LaurelVerifyOptions` from parsed CLI flags. -/ -def parseLaurelVerifyOptions (pflags : ParsedFlags) - (base : LaurelVerifyOptions := default) : IO LaurelVerifyOptions := do - let verifyOptions ← parseVerifyOptions pflags base.verifyOptions - let keepAllFilesPrefix := (pflags.getString "keep-all-files").orElse - (fun _ => base.translateOptions.keepAllFilesPrefix) - let translateOptions : LaurelTranslateOptions := - { base.translateOptions with - keepAllFilesPrefix - overflowChecks := verifyOptions.overflowChecks - profile := verifyOptions.profile } - return { translateOptions, verifyOptions } - -/-- Read and parse a Strata program file, loading the Core, C_Simp, and B3CST - dialects. Returns the parsed program and the input context (for source - location resolution), or an array of error messages on failure. -/ -private def readStrataProgram (file : String) - : IO (Except (Array Lean.Message) (Strata.Program × Lean.Parser.InputContext)) := do - let text ← Strata.Util.readInputSource file - let inputCtx := Lean.Parser.mkInputContext text (Strata.Util.displayName file) - let dctx := Elab.LoadedDialects.builtin - let dctx := dctx.addDialect! Core - let dctx := dctx.addDialect! Boole - let dctx := dctx.addDialect! C_Simp - let dctx := dctx.addDialect! B3CST - let leanEnv ← Lean.mkEmptyEnvironment 0 - match Strata.Elab.elabProgram dctx leanEnv inputCtx with - | .ok pgm => pure (.ok (pgm, inputCtx)) - | .error msgs => pure (.error msgs) - -structure Command where - name : String - args : List String - flags : List Flag := [] - help : String - callback : Vector String args.length → ParsedFlags → IO Unit - -def includeFlag : Flag := - { name := "include", help := "Add a dialect search path.", takesArg := .repeat "path" } - -def checkCommand : Command where - name := "check" - args := [ "file" ] - flags := [includeFlag] - help := "Parse and validate a Strata file (text or Ion). Reports errors and exits." - callback := fun v pflags => do - let fm ← pflags.buildDialectFileMap - let _ ← Strata.readStrataFile fm v[0] - -def toIonCommand : Command where - name := "toIon" - args := [ "input", "output" ] - flags := [includeFlag] - help := "Convert a Strata text file to Ion binary format." - callback := fun v pflags => do - let searchPath ← pflags.buildDialectFileMap - let pd ← Strata.readStrataFile searchPath v[0] - match pd with - | .dialect d => - IO.FS.writeBinFile v[1] d.toIon - | .program pgm => - IO.FS.writeBinFile v[1] pgm.toIon - -def printCommand : Command where - name := "print" - args := [ "file" ] - flags := [includeFlag] - help := "Pretty-print a Strata file (text or Ion) to stdout." - callback := fun v pflags => do - let searchPath ← pflags.buildDialectFileMap - -- Special case for already loaded dialects. - let ld ← searchPath.getLoaded - if mem : v[0] ∈ ld.dialects then - IO.print <| ld.dialects.format v[0] mem - return - let pd ← Strata.readStrataFile searchPath v[0] - match pd with - | .dialect d => - let ld ← searchPath.getLoaded - let .isTrue mem := (inferInstance : Decidable (d.name ∈ ld.dialects)) - | exitInternalError "Internal error reading file." - IO.print <| ld.dialects.format d.name mem - | .program pgm => - IO.print <| toString pgm - -def diffCommand : Command where - name := "diff" - args := [ "file1", "file2" ] - flags := [includeFlag] - help := "Compare two program files for syntactic equality. Reports the first difference found." - callback := fun v pflags => do - let fm ← pflags.buildDialectFileMap - let p1 ← Strata.readStrataFile fm v[0] - let p2 ← Strata.readStrataFile fm v[1] - match p1, p2 with - | .program p1, .program p2 => - if p1.dialect != p2.dialect then - exitFailure s!"Dialects differ: {p1.dialect} and {p2.dialect}" - let Decidable.isTrue eq := (inferInstance : Decidable (p1.commands.size = p2.commands.size)) - | exitFailure s!"Number of commands differ {p1.commands.size} and {p2.commands.size}" - for (c1, c2) in Array.zip p1.commands p2.commands do - if c1 != c2 then - exitFailure s!"Commands differ: {repr c1} and {repr c2}" - | _, _ => - exitFailure "Cannot compare dialect def with another dialect/program." - -def pySpecsCommand : Command where - name := "pySpecs" - args := [ "source_dir", "output_dir" ] - flags := [ - { name := "quiet", help := "Suppress default logging." }, - { name := "log", help := "Enable logging for an event type.", - takesArg := .repeat "event" }, - { name := "skip", - help := "Skip a top-level definition (module.name). Overloads are kept.", - takesArg := .repeat "name" }, - { name := "module", - help := "Translate only the named module (dot-separated). May be repeated.", - takesArg := .repeat "module" } - ] - help := "Translate Python specification files in a directory into Strata DDM Ion format. If --module is given, translates only those modules; otherwise translates all .py files. Creates subdirectories as needed. (Experimental)" - callback := fun v pflags => do - let quiet := pflags.getBool "quiet" - let mut events : Std.HashSet String := {} - if !quiet then - events := events.insert "import" - for e in pflags.getRepeated "log" do - events := events.insert e - let skipNames := pflags.getRepeated "skip" - let modules := pflags.getRepeated "module" - let warningOutput : Strata.WarningOutput := - if quiet then .none else .detail - -- Serialize embedded dialect for Python subprocess - IO.FS.withTempFile fun _handle dialectFile => do - IO.FS.writeBinFile dialectFile Strata.Python.Python.toIon - let r ← Strata.pySpecsDir (events := events) - (skipNames := skipNames) - (modules := modules) - (warningOutput := warningOutput) - v[0] v[1] dialectFile |>.toBaseIO - match r with - | .ok () => pure () - | .error msg => exitFailure msg - -/-- Derive Python source file path from Ion file path. - E.g., "tests/test_foo.python.st.ion" -> "tests/test_foo.py" -/ -def ionPathToPythonPath (ionPath : String) : Option String := - if ionPath.endsWith ".python.st.ion" then - let basePath := ionPath.dropEnd ".python.st.ion".length |>.toString - some (basePath ++ ".py") - else if ionPath.endsWith ".py.ion" then - some (ionPath.dropEnd ".ion".length |>.toString) - else - none - -/-- Try to read Python source file for source location reconstruction -/ -def tryReadPythonSource (ionPath : String) : IO (Option (String × String)) := do - match ionPathToPythonPath ionPath with - | none => return none - | some pyPath => - try - let content ← IO.FS.readFile pyPath - return some (pyPath, content) - catch _ => - return none - -/-- Format related position strings from metadata, if present. -/ -def formatRelatedPositions (md : Imperative.MetaData Core.Expression) - (mfm : Option (String × Lean.FileMap)) : String := - let ranges := Imperative.getRelatedFileRanges md - if ranges.isEmpty then "" else - match mfm with - | none => "" - | some (_, fm) => - let lines := ranges.filterMap fun fr => - if fr.range.isNone then none else - match fr.file with - | .file "" => some "\n Related location: in prelude file" - | .file _ => - let pos := fm.toPosition fr.range.start - some s!"\n Related location: line {pos.line}, col {pos.column}" - String.join lines.toList - -/-! ### pyAnalyzeLaurel result helpers - -The `pyAnalyzeLaurel` command emits two structured lines on stdout: -- `RESULT: ` — machine-readable category, always the last line. -- `DETAIL: ` — human-readable context (error message or VC counts). - -Exit codes follow the common scheme (see `ExitCode` above). -A successful run exits 0 with `RESULT: Analysis success` or `RESULT: Inconclusive`. -/ - -/-- Determines which VC results count as successes and which count as failures - for the purposes of the `pyAnalyzeLaurel` summary and exit code. - Implementation-error results are partitioned out first; the classifier then - partitions the rest into success / failure / inconclusive. - Narrowing `isFailure` (e.g. to only `alwaysFalseAndReachable`) automatically - widens inconclusive. - Future: may be extended with `isWarning` for non-fatal diagnostic categories. -/ -structure ResultClassifier where - isSuccess : Core.VCResult → Bool := (·.isSuccess) - isFailure : Core.VCResult → Bool := (·.isFailure) - -private def printPyAnalyzeResult (category : String) (detail : String) : IO Unit := do - IO.println s!"DETAIL: {detail}" - IO.println s!"RESULT: {category}" - -private def exitPyAnalyzeUserError {α} (message : String) : IO α := do - printPyAnalyzeResult "User error" message - IO.Process.exit ExitCode.userError - -private def exitPyAnalyzeFailuresFound {α} (detail : String) : IO α := do - printPyAnalyzeResult "Failures found" detail - IO.Process.exit ExitCode.failuresFound - -private def exitPyAnalyzeInternalError {α} (message : String) : IO α := do - printPyAnalyzeResult "Internal error" message - IO.Process.exit ExitCode.internalError - -private def exitPyAnalyzeKnownLimitation {α} (message : String) : IO α := do - printPyAnalyzeResult "Known limitation" message - IO.Process.exit ExitCode.knownLimitation - -/-- Print the final RESULT/DETAIL lines based on solver outcomes. - Always called on successful pipeline completion (as opposed to the - exit helpers above, which are called on early pipeline failure). - Classification uses successive partitioning: timeouts and implementation - errors are removed first, then the classifier partitions the rest into - success / failure / inconclusive (guaranteeing disjointness). - Unreachable count is reported as supplementary info. - - Exit-code priority (highest wins): - - Internal error (exit 3): encoding failures or solver crashes - - Failures found (exit 2): assertion violations - - Inconclusive / success / solver timeout (exit 0) -/ -private def printPyAnalyzeSummary (vcResults : Array Core.VCResult) - (checkMode : VerificationMode := .deductive) : IO Unit := do - let classifier : ResultClassifier := - match checkMode with - | .bugFinding | .bugFindingAssumingCompleteSpec => - { isSuccess := (·.isBugFindingSuccess) - isFailure := (·.isBugFindingFailure) } - | _ => {} - -- 1. Partition out implementation errors and timeouts (not classifiable). - let (implError, rest1) := - vcResults.partition (fun r => r.isImplementationError || r.hasSMTError) - let (timeouts, classifiable) := rest1.partition (·.isTimeout) - -- 2. Successive partitioning via the classifier: success → failure → inconclusive. - let (success, rest) := classifiable.partition classifier.isSuccess - let (failure, inconclusive) := rest.partition classifier.isFailure - -- 3. Unreachable is informational (not a separate partition). - let nUnreachable := vcResults.filter (·.isUnreachable) |>.size - let nImplError := implError.size - let nTimeout := timeouts.size - let nSuccess := success.size - let nFailure := failure.size - let nInconclusive := inconclusive.size - let unreachableStr := if nUnreachable > 0 then s!", {nUnreachable} unreachable" else "" - let implErrorStr := if nImplError > 0 then s!", {nImplError} internal errors" else "" - let timeoutStr := if nTimeout > 0 then s!", {nTimeout} solver timeouts" else "" - let counts := s!"{nSuccess} passed, {nFailure} failed, {nInconclusive} inconclusive{unreachableStr}{timeoutStr}{implErrorStr}" - if nImplError > 0 then - exitPyAnalyzeInternalError s!"An unexpected result was produced. {counts}" - else if nFailure > 0 then - exitPyAnalyzeFailuresFound counts - else - let label := - if nTimeout > 0 then "Solver timeout" - else if nInconclusive > 0 then "Inconclusive" - else "Analysis success" - printPyAnalyzeResult label counts - -private def deriveBaseName (file : String) : String := - let name := System.FilePath.fileName file |>.getD file - let suffixes := [".python.st.ion", ".py.ion", ".st.ion", ".st"] - match suffixes.find? (name.endsWith ·) with - | some sfx => (name.dropEnd sfx.length).toString - | none => name - - -def pyAnalyzeLaurelCommand : Command where - name := "pyAnalyzeLaurel" - args := [ "file" ] - flags := verifyOptionsFlags ++ [ - { name := "spec-dir", - help := "Directory containing compiled PySpec Ion files.", - takesArg := .arg "dir" }, - { name := "dispatch", - help := "Dispatch module name (e.g., servicelib).", - takesArg := .repeat "module" }, - { name := "pyspec", - help := "PySpec module name (e.g., servicelib.Storage).", - takesArg := .repeat "module" }, - { name := "keep-all-files", - help := "Store intermediate Laurel and Core programs in .", - takesArg := .arg "dir" }, - { name := "entry-point", - help := "Which procedures to verify: main (main fn only), roots (user procs with no user callers, default), or all (all user procs). Only valid in bugFinding mode.", - takesArg := .arg "mode" }, - { name := "warning-summary", - help := "Write PySpec warning summary as JSON to .", - takesArg := .arg "file" }, - { name := "skip-verification", - help := "Run Python-to-Laurel and Laurel-to-Core translation only (skip SMT verification).", - takesArg := .none }] - help := "Verify a Python Ion program via the Laurel pipeline. Translates Python to Laurel to Core, then runs SMT verification." - callback := fun v pflags => do - let verbose := pflags.getBool "verbose" - let profile := pflags.getBool "profile" - let quiet := pflags.getBool "quiet" - let outputSarif := pflags.getBool "sarif" - let filePath := v[0] - let pySourceOpt ← tryReadPythonSource filePath - let keepDir := pflags.getString "keep-all-files" - let baseName := deriveBaseName filePath - if let some dir := keepDir then - IO.FS.createDirAll dir - - let dispatchModules := pflags.getRepeated "dispatch" - let pyspecModules := pflags.getRepeated "pyspec" - let specDir := pflags.getString "spec-dir" |>.getD "." - unless ← System.FilePath.isDir specDir do - exitFailure s!"spec-dir '{specDir}' does not exist or is not a directory" - let sourcePath := pySourceOpt.map (·.1) - -- Build FileMap for source position resolution. - let mfm : Option (String × Lean.FileMap) := match pySourceOpt with - | some (pyPath, srcText) => some (pyPath, .ofString srcText) - | none => none - let warningSummaryFile := pflags.getString "warning-summary" - let combinedLaurel ← - match ← Strata.pythonAndSpecToLaurel filePath dispatchModules pyspecModules sourcePath - (specDir := specDir) (profile := profile) - (quiet := quiet) - (warningSummaryFile := warningSummaryFile) |>.toBaseIO with - | .ok r => pure r - | .error (.userCode range msg) => - let location := if range.isNone then "" else - match mfm with - | some (_, fm) => - let pos := fm.toPosition range.start - s!" at line {pos.line}, col {pos.column}" - | none => "" - let filePath' := sourcePath.getD filePath - let mut lines := #[ - s!"(set-info :file {Strata.escapeSMTStringLit filePath'})" - ] - unless range.isNone do - lines := lines.push s!"(set-info :start {range.start})" - lines := lines.push s!"(set-info :stop {range.stop})" - lines := lines.push s!"(set-info :error-message {Strata.escapeSMTStringLit msg})" - for line in lines do - IO.println line - IO.FS.writeFile "user_errors.txt" (String.intercalate "\n" lines.toList ++ "\n") - exitPyAnalyzeUserError s!"{msg}{location}" - | .error (.knownLimitation msg) => - exitPyAnalyzeKnownLimitation msg - | .error (.internal msg) => - exitPyAnalyzeInternalError msg - - if verbose then - IO.println "\n==== Laurel Program ====" - IO.println f!"{combinedLaurel}" - - let keepPrefix := keepDir.map (s!"{·}/{baseName}") - - let (coreProgramOption, laurelTranslateErrors, _loweredLaurel, laurelPassStats) ← - profileStep profile "Laurel to Core translation" do - Strata.translateCombinedLaurelWithLowered combinedLaurel - (keepAllFilesPrefix := keepPrefix) (profile := profile) - - if profile && !laurelPassStats.data.isEmpty then - IO.println laurelPassStats.format - - let coreProgram ← - match coreProgramOption with - | none => - exitPyAnalyzeInternalError s!"Laurel to Core translation failed: {laurelTranslateErrors}" - | some core => pure core - - if verbose then - IO.println "\n==== Core Program ====" - IO.print (Core.formatProgram coreProgram) - - -- When --skip-verification is set, report translation diagnostics and exit - -- without running SMT verification (stages 3-4). - if pflags.getBool "skip-verification" then do - if !laurelTranslateErrors.isEmpty then - IO.eprintln "\n==== Errors ====" - for err in laurelTranslateErrors do - IO.eprintln err - if outputSarif then - let files := match mfm with - | some (pyPath, fm) => Map.empty.insert (Strata.Uri.file pyPath) fm - | none => Map.empty - Core.Sarif.writeSarifOutput .deductive files #[] (filePath ++ ".sarif") - let nStrataBug := laurelTranslateErrors.filter (·.type == .StrataBug) |>.length - let nNotYetImpl := laurelTranslateErrors.filter (·.type == .NotYetImplemented) |>.length - let nUserError := laurelTranslateErrors.filter (·.type == .UserError) |>.length - let nWarning := laurelTranslateErrors.filter (·.type == .Warning) |>.length - let counts := s!"{nUserError} user errors, {nWarning} warnings, {nNotYetImpl} not yet implemented, {nStrataBug} internal errors" - if nStrataBug > 0 then - exitPyAnalyzeInternalError s!"Translation produced internal errors. {counts}" - else if nNotYetImpl > 0 then - exitPyAnalyzeKnownLimitation s!"Translation encountered unsupported constructs. {counts}" - else - printPyAnalyzeResult "Analysis success" counts - return - - -- Verify using Core verifier - -- --keep-all-files implies vc-directory if not explicitly set - let baseVcDir := keepDir.map (fun dir => (s!"{dir}/{baseName}" : System.FilePath)) - let pyAnalyzeBase : VerifyOptions := - { VerifyOptions.default with - verbose := .quiet, removeIrrelevantAxioms := .Precise, - vcDirectory := baseVcDir } - let options ← parseVerifyOptions pflags pyAnalyzeBase - let isBugFinding := options.checkMode == .bugFinding - || options.checkMode == .bugFindingAssumingCompleteSpec - - -- Parse --entry-point flag (only supported in bug-finding modes). - let entryPointFlag := pflags.getString "entry-point" - let entryPoint : EntryPoint ← - if isBugFinding then - match entryPointFlag with - | some s => - match EntryPoint.ofString? s with - | some ep => pure ep - | none => - exitPyAnalyzeUserError s!"Invalid --entry-point value '{s}'. Must be {EntryPoint.options}." - | none => pure .roots - else - if entryPointFlag.isSome then - exitPyAnalyzeUserError s!"--entry-point is unsupported in {options.checkMode} mode" - else pure .all - - -- Pick the procedures to verify and set up inlining phases. - let userSourcePath := sourcePath.getD filePath - let (_, userProcNames) := - Strata.splitProcNames coreProgram [userSourcePath] - let (proceduresToVerify, inlinePhases) := - if isBugFinding then - let ⟨p, i⟩ := Core.chooseEntryProceduresAndBuildInlinePhases coreProgram userProcNames entryPoint - (p, [i]) - else (userProcNames, []) - - let vcResults ← profileStep profile "SMT verification" do - match ← Core.verifyProgram coreProgram options - (moreFns := Strata.Python.ReFactory) - (proceduresToVerify := some proceduresToVerify) - (externalPhases := [Strata.frontEndPhase]) - (prefixPhases := inlinePhases) - (keepAllFilesPrefix := keepPrefix) - |>.toBaseIO with - | .ok r => pure r.mergeByAssertion - | .error msg => exitPyAnalyzeInternalError msg - - -- Print translation errors (always on stderr) - if !laurelTranslateErrors.isEmpty then - IO.eprintln "\n==== Errors ====" - for err in laurelTranslateErrors do - IO.eprintln err - - -- Print per-VC results by default, unless SARIF mode is used - if !outputSarif then - let mut s := "" - for vcResult in vcResults do - let fileMap := mfm.map (·.2) - let location := match Imperative.getFileRange vcResult.obligation.metadata with - | some fr => - if fr.range.isNone then "" - else s!"{fr.format fileMap (includeEnd? := false)}" - | none => "" - let messageSuffix := match vcResult.obligation.metadata.getPropertySummary with - | some msg => s!" - {msg}" - | none => s!" - {vcResult.obligation.label}" - let outcomeStr := vcResult.formatOutcome - let loc := if !location.isEmpty then s!"{location}: " else "unknown location: " - s := s ++ s!"{loc}{outcomeStr}{messageSuffix}\n" - IO.print s - -- Output in SARIF format if requested - if outputSarif then - let files := match mfm with - | some (pyPath, fm) => Map.empty.insert (Strata.Uri.file pyPath) fm - | none => Map.empty - Core.Sarif.writeSarifOutput options.checkMode files vcResults (filePath ++ ".sarif") - printPyAnalyzeSummary vcResults options.checkMode - -def pyAnalyzeToGotoCommand : Command where - name := "pyAnalyzeToGoto" - args := [ "file" ] - help := "Translate a Strata Python Ion file to CProver GOTO JSON files." - callback := fun v _ => do - let filePath := v[0] - let pySourceOpt ← tryReadPythonSource filePath - let sourcePathForMetadata := match pySourceOpt with - | some (pyPath, _) => pyPath - | none => filePath - let sourceText := pySourceOpt.map (·.2) - let newPgm ← Strata.pythonDirectToCore filePath sourcePathForMetadata - match Core.inlineProcedures newPgm { doInline := (fun _caller callee _ => callee ≠ "main") } with - | .error e => exitInternalError (toString e) - | .ok newPgm => - -- Type-check the full program (registers Python types like ExceptOrNone) - let Ctx := { Lambda.LContext.default with functions := Strata.Python.PythonFactory, knownTypes := Core.KnownTypes } - let Env := Lambda.TEnv.default - let (tcPgm, _) ← match Core.Program.typeCheck Ctx Env newPgm with - | .ok r => pure r - | .error e => exitInternalError s!"{e.format none}" - -- Find the main procedure - let some mainDecl := tcPgm.decls.find? fun d => - match d with - | .proc p _ => Core.CoreIdent.toPretty p.header.name == "main" - | _ => false - | exitInternalError "No main procedure found" - let some p := mainDecl.getProc? - | exitInternalError "main is not a procedure" - -- Translate procedure to GOTO (mirrors CoreToGOTO.transformToGoto post-typecheck logic) - let baseName := deriveBaseName filePath - let procName := Core.CoreIdent.toPretty p.header.name - let axioms := tcPgm.decls.filterMap fun d => d.getAxiom? - let distincts := tcPgm.decls.filterMap fun d => match d with - | .distinct name es _ => some (name, es) | _ => none - match procedureToGotoCtx Env p sourceText (axioms := axioms) (distincts := distincts) - with - | .error e => exitInternalError s!"{e}" - | .ok (ctx, liftedFuncs) => - let extraSyms ← match collectExtraSymbols tcPgm with - | .ok s => pure (Lean.toJson s) - | .error e => exitInternalError s!"{e}" - let (symtab, goto) ← emitProcWithLifted Env procName ctx liftedFuncs extraSyms - (moduleName := baseName) - let symTabFile := s!"{baseName}.symtab.json" - let gotoFile := s!"{baseName}.goto.json" - writeJsonFile symTabFile symtab - writeJsonFile gotoFile goto - IO.println s!"Written {symTabFile} and {gotoFile}" - -def pyTranslateLaurelCommand : Command where - name := "pyTranslateLaurel" - args := [ "file" ] - flags := [{ name := "pyspec", - help := "PySpec module name (e.g., servicelib.Storage).", - takesArg := .repeat "module" }, - { name := "dispatch", - help := "Dispatch module name (e.g., servicelib).", - takesArg := .repeat "module" }, - { name := "spec-dir", - help := "Directory containing compiled PySpec Ion files.", - takesArg := .arg "dir" }] - help := "Translate a Strata Python Ion file through Laurel to Strata Core. Write results to stdout." - callback := fun v pflags => do - let dispatchModules := pflags.getRepeated "dispatch" - let pyspecModules := pflags.getRepeated "pyspec" - let specDir := pflags.getString "spec-dir" |>.getD "." - unless ← System.FilePath.isDir specDir do - exitFailure s!"spec-dir '{specDir}' does not exist or is not a directory" - let coreProgram ← - match ← Strata.pyTranslateLaurel v[0] dispatchModules pyspecModules (specDir := specDir) |>.toBaseIO with - | .ok r => pure r - | .error msg => exitFailure msg - IO.print coreProgram - -def pyAnalyzeLaurelToGotoCommand : Command where - name := "pyAnalyzeLaurelToGoto" - args := [ "file" ] - flags := [{ name := "pyspec", - help := "PySpec module name (e.g., servicelib.Storage).", - takesArg := .repeat "module" }, - { name := "dispatch", - help := "Dispatch module name (e.g., servicelib).", - takesArg := .repeat "module" }, - { name := "spec-dir", - help := "Directory containing compiled PySpec Ion files.", - takesArg := .arg "dir" }] - help := "Translate a Strata Python Ion file through Laurel to CProver GOTO JSON files." - callback := fun v pflags => do - let filePath := v[0] - let dispatchModules := pflags.getRepeated "dispatch" - let pyspecModules := pflags.getRepeated "pyspec" - let specDir := pflags.getString "spec-dir" |>.getD "." - unless ← System.FilePath.isDir specDir do - exitFailure s!"spec-dir '{specDir}' does not exist or is not a directory" - let (coreProgram, laurelTranslateErrors) ← - match ← Strata.pyTranslateLaurel filePath dispatchModules pyspecModules (specDir := specDir) |>.toBaseIO with - | .ok r => pure r - | .error msg => exitFailure msg - let sourceText := (← tryReadPythonSource filePath).map (·.2) - let baseName := deriveBaseName filePath - match ← Strata.inlineCoreToGotoFiles coreProgram baseName sourceText - (factory := Strata.Python.PythonFactory) |>.toBaseIO with - | .ok () => pure () - | .error msg => exitFailure msg - -def javaGenCommand : Command where - name := "javaGen" - args := [ "dialect", "package", "output-dir" ] - flags := [includeFlag] - help := "Generate Java source files from a DDM dialect definition. Accepts a dialect name (e.g. Laurel) or a dialect file path." - callback := fun v pflags => do - let fm ← pflags.buildDialectFileMap - let ld ← fm.getLoaded - let d ← if mem : v[0] ∈ ld.dialects then - pure ld.dialects[v[0]] - else - match ← Strata.readStrataFile fm v[0] with - | .dialect d => pure d - | .program _ => exitFailure "Expected a dialect file, not a program file." - match Strata.Java.generateDialect d v[1] with - | .ok files => - Strata.Java.writeJavaFiles v[2] v[1] files - IO.println s!"Generated Java files for {d.name} in {v[2]}/{Strata.Java.packageToPath v[1]}" - | .error msg => - exitFailure s!"Error generating Java: {msg}" - -def laurelAnalyzeBinaryCommand : Command where - name := "laurelAnalyzeBinary" - args := [] - flags := laurelVerifyOptionsFlags - help := "Verify Laurel Ion programs read from stdin and print diagnostics. Combines multiple input files." - callback := fun _ pflags => do - let options ← parseLaurelVerifyOptions pflags - let stdinBytes ← (← IO.getStdin).readBinToEnd - let combinedProgram ← Strata.readLaurelIonProgram stdinBytes - let diagnostics ← Strata.Laurel.verifyToDiagnosticModels combinedProgram options - - IO.println s!"==== DIAGNOSTICS ====" - for diag in diagnostics do - IO.println s!"{Std.format diag.fileRange.file}:{diag.fileRange.range.start}-{diag.fileRange.range.stop}: {diag.message}" - -def pySpecToLaurelCommand : Command where - name := "pySpecToLaurel" - args := [ "python_path", "strata_path" ] - help := "Translate a PySpec Ion file to Laurel declarations. The Ion file must already exist." - callback := fun v _ => do - let pythonFile : System.FilePath := v[0] - let strataDir : System.FilePath := v[1] - let some mod := pythonFile.fileStem - | exitFailure s!"No stem {pythonFile}" - let .ok mod := Strata.Python.Specs.ModuleName.ofString mod - | exitFailure s!"Invalid module {mod}" - let ionFile := strataDir / mod.strataFileName - let sigs ← - match ← Strata.Python.Specs.readDDM ionFile |>.toBaseIO with - | .ok t => pure t - | .error msg => exitFailure s!"Could not read {ionFile}: {msg}" - let result := Strata.Python.Specs.ToLaurel.signaturesToLaurel pythonFile sigs "" - if result.errors.size > 0 then - IO.eprintln s!"{result.errors.size} translation warning(s):" - for err in result.errors do - IO.eprintln s!" {err.file}: {err.message}" - let pgm := result.program - IO.println s!"Laurel: {pgm.staticProcedures.length} procedure(s), {pgm.types.length} type(s)" - IO.println s!"Overloads: {result.overloads.size} function(s)" - for td in pgm.types do - IO.println s!" {Strata.Laurel.formatTypeDefinition td}" - for proc in pgm.staticProcedures do - IO.println s!" {Strata.Laurel.formatProcedure proc}" - -def pyResolveOverloadsCommand : Command where - name := "pyResolveOverloads" - args := [ "python_path", "dispatch_ion" ] - help := "Identify which overloaded service modules a \ - Python program uses. Prints one module name per \ - line to stdout." - callback := fun v _ => do - let pythonFile : System.FilePath := v[0] - let dispatchPath := v[1] - -- Read dispatch overload table - let overloads ← - match ← readDispatchOverloads #[dispatchPath] |>.toBaseIO with - | .ok (r, _) => pure r - | .error msg => exitFailure msg - -- Convert .py to Python AST - let stmts ← - IO.FS.withTempFile fun _handle dialectFile => do - IO.FS.writeBinFile dialectFile - Strata.Python.Python.toIon - match ← Strata.Python.pythonToStrata dialectFile pythonFile |>.toBaseIO with - | .ok s => pure s - | .error msg => exitFailure msg - -- Walk AST and collect modules - let state := - Strata.Python.Specs.IdentifyOverloads.resolveOverloads - overloads stmts - for w in state.warnings do - IO.eprintln s!"warning: {w}" - let sorted := state.modules.toArray.qsort (· < ·) - for m in sorted do - IO.println m - -def laurelParseCommand : Command where - name := "laurelParse" - args := [ "file" ] - help := "Parse a Laurel source file (no verification)." - callback := fun v _ => do - let _ ← Strata.readLaurelTextFile v[0] - IO.println "Parse successful" - -def laurelAnalyzeCommand : Command where - name := "laurelAnalyze" - args := [ "file" ] - flags := laurelVerifyOptionsFlags - help := "Analyze a Laurel source file. Write diagnostics to stdout." - callback := fun v pflags => do - let options ← parseLaurelVerifyOptions pflags - let laurelProgram ← Strata.readLaurelTextFile v[0] - let (vcResultsOption, errors) ← Strata.Laurel.verifyToVcResults laurelProgram options - if !errors.isEmpty then - IO.println s!"==== ERRORS ====" - for err in errors do - IO.println s!"{err.message}" - match vcResultsOption with - | none => return - | some vcResults => - IO.println s!"==== RESULTS ====" - for vc in vcResults do - IO.println s!"{vc.obligation.label}: {match vc.outcome with | .ok o => repr o | .error e => toString e}" - -def laurelAnalyzeToGotoCommand : Command where - name := "laurelAnalyzeToGoto" - args := [ "file" ] - help := "Translate a Laurel source file to CProver GOTO JSON files." - callback := fun v _ => do - let path : System.FilePath := v[0] - let content ← IO.FS.readFile path - let laurelProgram ← Strata.parseLaurelText path content - match ← Strata.Laurel.translate {} laurelProgram with - | (none, diags) => exitFailure s!"Core translation errors: {diags.map (·.message)}" - | (some coreProgram, errors) => - let Ctx := { Lambda.LContext.default with functions := Core.Factory, knownTypes := Core.KnownTypes } - let Env := Lambda.TEnv.default - let (tcPgm, _) ← match Core.Program.typeCheck Ctx Env coreProgram with - | .ok r => pure r - | .error e => exitInternalError s!"{e.format none}" - let procs := tcPgm.decls.filterMap fun d => d.getProc? - let funcs := tcPgm.decls.filterMap fun d => - match d.getFunc? with - | some f => - let name := Core.CoreIdent.toPretty f.name - if f.body.isSome && f.typeArgs.isEmpty - && name != "Int.DivT" && name != "Int.ModT" - then some f else none - | none => none - if procs.isEmpty && funcs.isEmpty then exitInternalError "No procedures or functions found" - let baseName := deriveBaseName path.toString - let typeSyms ← match collectExtraSymbols tcPgm with - | .ok s => pure s - | .error e => exitInternalError s!"{e}" - let typeSymsJson := Lean.toJson typeSyms - let sourceText := some content - let axioms := tcPgm.decls.filterMap fun d => d.getAxiom? - let distincts := tcPgm.decls.filterMap fun d => match d with - | .distinct name es _ => some (name, es) | _ => none - let mut symtabPairs : List (String × Lean.Json) := [] - let mut gotoFns : Array Lean.Json := #[] - let mut allLiftedFuncs : List Core.Function := [] - for p in procs do - let procName := Core.CoreIdent.toPretty p.header.name - match procedureToGotoCtx Env p (sourceText := sourceText) (axioms := axioms) (distincts := distincts) - with - | .error e => exitInternalError s!"{e}" - | .ok (ctx, liftedFuncs) => - allLiftedFuncs := allLiftedFuncs ++ liftedFuncs - let json ← IO.ofExcept (CoreToGOTO.CProverGOTO.Context.toJson procName ctx) - match json.symtab with - | .obj m => symtabPairs := symtabPairs ++ m.toList - | _ => pure () - match json.goto with - | .obj m => - match m.toList.find? (·.1 == "functions") with - | some (_, .arr fns) => gotoFns := gotoFns ++ fns - | _ => pure () - | _ => pure () - for f in funcs ++ allLiftedFuncs do - let funcName := Core.CoreIdent.toPretty f.name - match functionToGotoCtx Env f with - | .error e => exitInternalError s!"{e}" - | .ok ctx => - let json ← IO.ofExcept (CoreToGOTO.CProverGOTO.Context.toJson funcName ctx) - match json.symtab with - | .obj m => symtabPairs := symtabPairs ++ m.toList - | _ => pure () - match json.goto with - | .obj m => - match m.toList.find? (·.1 == "functions") with - | some (_, .arr fns) => gotoFns := gotoFns ++ fns - | _ => pure () - | _ => pure () - match typeSymsJson with - | .obj m => symtabPairs := symtabPairs ++ m.toList - | _ => pure () - -- Deduplicate: keep first occurrence of each symbol name (proper function - -- symbols come before basic symbol references from callers) - let mut seen : Std.HashSet String := {} - let mut dedupPairs : List (String × Lean.Json) := [] - for (k, v) in symtabPairs do - if !seen.contains k then - seen := seen.insert k - dedupPairs := dedupPairs ++ [(k, v)] - -- Add CBMC default symbols (architecture constants, builtins) - -- and wrap in {"symbolTable": ...} for symtab2gb - let symtabObj := dedupPairs.foldl - (fun (acc : Std.TreeMap.Raw String Lean.Json) (k, v) => acc.insert k v) - .empty - let symtab := CProverGOTO.wrapSymtab symtabObj (moduleName := baseName) - let goto := Lean.Json.mkObj [("functions", Lean.Json.arr gotoFns)] - let symTabFile := s!"{baseName}.symtab.json" - let gotoFile := s!"{baseName}.goto.json" - writeJsonFile symTabFile symtab - writeJsonFile gotoFile goto - IO.println s!"Written {symTabFile} and {gotoFile}" - -def laurelPrintCommand : Command where - name := "laurelPrint" - args := [] - help := "Read Laurel Ion from stdin and print in concrete syntax to stdout." - callback := fun _ _ => do - let stdinBytes ← (← IO.getStdin).readBinToEnd - let strataFiles ← Strata.readLaurelIonFiles stdinBytes - for strataFile in strataFiles do - IO.println s!"// File: {strataFile.filePath}" - let p := strataFile.program - let c := p.formatContext {} - let s := p.formatState - let fmt := p.commands.foldl (init := f!"") fun f cmd => - f ++ (Strata.mformat cmd c s).format - IO.println (fmt.pretty 100) - IO.println "" - -def prettyPrintCore (p : Core.Program) : String := - let decls := p.decls.map fun d => - let s := toString (Std.format d) - -- Add newlines after major sections in procedures - s.replace "preconditions:" "\n preconditions:" - |>.replace "postconditions:" "\n postconditions:" - |>.replace "body:" "\n body:\n " - |>.replace "assert [" "\n assert [" - |>.replace "init (" "\n init (" - |>.replace "while (" "\n while (" - |>.replace "if (" "\n if (" - |>.replace "call [" "\n call [" - |>.replace "else{" "\n else {" - |>.replace "}}" "}\n }" - String.intercalate "\n" decls - -def laurelToCoreCommand : Command where - name := "laurelToCore" - args := [ "file" ] - help := "Translate a Laurel source file to Core and print to stdout." - callback := fun v _ => do - let laurelProgram ← Strata.readLaurelTextFile v[0] - let (coreProgramOption, errors) ← Strata.Laurel.translate {} laurelProgram - if !errors.isEmpty then - IO.println s!"Core translation errors: {errors.map (·.message)}" - match coreProgramOption with - | none => return - | some coreProgram => IO.println (prettyPrintCore coreProgram) - -/-- Print a string word-wrapped to `width` columns with `indent` spaces of indentation. -/ -private def printIndented (indent : Nat) (s : String) (width : Nat := 80) : IO Unit := do - let pad := "".pushn ' ' indent - let words := s.splitOn " " |>.filter (!·.isEmpty) - let mut line := pad - let mut first := true - for word in words do - if first then - line := line ++ word - first := false - else if line.length + 1 + word.length > width then - IO.println line - line := pad ++ word - else - line := line ++ " " ++ word - unless line.length ≤ indent do - IO.println line - -structure CommandGroup where - name : String - commands : List Command - commonFlags : List Flag := [] - -private def validPasses := - "inlineProcedures, loopElim, callElim, filterProcedures, removeIrrelevantAxioms" - -/-- A single transform pass together with the `--procedures`/`--functions` - that were specified immediately after it on the command line. -/ -private structure PassConfig where - name : String - procedures : List String := [] - functions : List String := [] -deriving Inhabited - -/-- Walk the ordered flag entries and bind each `--procedures`/`--functions` - to the most recent `--pass`. -/ -private def buildPassConfigs (entries : Array (String × Option String)) - : IO (Array PassConfig) := do - let mut configs : Array PassConfig := #[] - for (flag, value) in entries do - match flag with - | "pass" => configs := configs.push { name := value.getD "" } - | "procedures" => - let some cur := configs.back? | exitFailure "--procedures must appear after a --pass" - let procs := (value.getD "").splitToList (· == ',') - configs := configs.pop.push { cur with procedures := cur.procedures ++ procs } - | "functions" => - let some cur := configs.back? | exitFailure "--functions must appear after a --pass" - let fns := (value.getD "").splitToList (· == ',') - configs := configs.pop.push { cur with functions := cur.functions ++ fns } - | _ => pure () - return configs - -def transformCommand : Command where - name := "transform" - args := [ "file" ] - flags := [ - { name := "pass", - help := s!"Transform pass to apply (repeatable, applied left to right). \ - Valid passes: {validPasses}. \ - --procedures and --functions after a --pass apply to that pass.", - takesArg := .repeat "name" }, - { name := "procedures", - help := "Comma-separated procedure names for the preceding --pass. \ - For filterProcedures: procedures to keep. \ - For inlineProcedures: procedures to inline.", - takesArg := .repeat "procs" }, - { name := "functions", - help := "Comma-separated function names for the preceding --pass (used by removeIrrelevantAxioms).", - takesArg := .repeat "funcs" }] - help := "Apply one or more transforms to a Core program and print the result." - callback := fun v pflags => do - let file := v[0] - let passConfigs ← buildPassConfigs pflags.entries - if passConfigs.isEmpty then - exitFailure s!"No --pass specified. Valid passes: {validPasses}." - -- Read and parse the Core program - let (pgm, _) ← match ← readStrataProgram file with - | .ok r => pure r - | .error msgs => - for e in msgs do println! s!"Error: {← e.toString}" - exitFailure s!"{msgs.size} parse error(s)" - match Strata.genericToCore pgm with - | .error msg => - exitFailure msg - | .ok initProgram => - -- Validate and convert pass configs to TransformPass values - let mut passes : List Strata.Core.TransformPass := [] - for pc in passConfigs do - match pc.name with - | "inlineProcedures" => - let opts : Core.InlineTransformOptions := - if pc.procedures.isEmpty then {} - else { doInline := (fun _caller callee _ => callee ∈ pc.procedures) } - passes := passes ++ [.inlineProcedures opts] - | "loopElim" => - passes := passes ++ [.loopElim] - | "callElim" => - passes := passes ++ [.callElim] - | "filterProcedures" => - if pc.procedures.isEmpty then - exitFailure "filterProcedures requires --procedures" - passes := passes ++ [.filterProcedures pc.procedures] - | "removeIrrelevantAxioms" => - if pc.functions.isEmpty then - exitFailure "removeIrrelevantAxioms requires --functions" - passes := passes ++ [.removeIrrelevantAxioms pc.functions] - | other => - exitFailure s!"Unknown pass '{other}'. Valid passes: {validPasses}." - -- Run all passes in a single CoreTransformM chain so fresh variable - -- counters accumulate and cached analyses are reused across passes. - match Strata.Core.runTransforms initProgram passes with - | .ok program => IO.print (Core.formatProgram program) - | .error e => exitFailure s!"Transform failed: {e}" - -def verifyCommand : Command where - name := "verify" - args := [ "file" ] - flags := verifyOptionsFlags ++ [ - { name := "check", help := "Process up until SMT generation, but don't solve." }, - { name := "type-check", help := "Exit after semantic dialect's type inference/checking." }, - { name := "parse-only", help := "Exit after DDM parsing and type checking." }, - { name := "output-format", help := "Output format (only 'sarif' supported).", takesArg := .arg "format" }, - { name := "procedures", help := "Verify only the specified procedures (comma-separated).", takesArg := .arg "procs" }] - help := "Verify a Strata program file (.core.st, .csimp.st, or .b3.st)." - callback := fun v pflags => do - let file := v[0] - let proceduresToVerify := pflags.getString "procedures" |>.map (·.splitToList (· == ',')) - let opts ← parseVerifyOptions pflags { VerifyOptions.default with verbose := .quiet } - let opts := { opts with - checkOnly := pflags.getBool "check", - typeCheckOnly := pflags.getBool "type-check", - parseOnly := pflags.getBool "parse-only", - outputSarif := opts.outputSarif || pflags.getString "output-format" == some "sarif" } - let (pgm, inputCtx) ← match ← readStrataProgram file with - | .ok r => pure r - | .error errors => - for e in errors do - let msg ← e.toString - println! s!"Error: {msg}" - println! f!"Finished with {errors.size} errors." - IO.Process.exit ExitCode.userError - println! s!"Successfully parsed." - if opts.parseOnly then return - if opts.typeCheckOnly then - let ans := if file.endsWith ".csimp.st" then - C_Simp.typeCheck pgm opts - else if pgm.dialect == "Boole" then - Boole.typeCheck pgm opts - else - typeCheck inputCtx pgm opts - match ans with - | .error e => - println! f!"{e.formatRange (some inputCtx.fileMap) true} {e.message}" - IO.Process.exit ExitCode.userError - | .ok _ => - println! f!"Program typechecked." - return - -- Full verification - let vcResults ← try - if file.endsWith ".csimp.st" then - C_Simp.verify pgm opts - else if file.endsWith ".b3.st" || file.endsWith ".b3cst.st" then - let ast ← match B3.Verifier.programToB3AST pgm with - | Except.error msg => throw (IO.userError s!"Failed to convert to B3 AST: {msg}") - | Except.ok ast => pure ast - let solver ← B3.Verifier.createInteractiveSolver opts.solver - let reports ← B3.Verifier.programToSMT ast solver - for report in reports do - IO.println s!"\nProcedure: {report.procedureName}" - for (result, _) in report.results do - let marker := if result.result.isError then "✗" else "✓" - let desc := match result.result with - | .error .counterexample => "counterexample found" - | .error .unknown => "unknown" - | .error .refuted => "refuted" - | .success .verified => "verified" - | .success .reachable => "reachable" - | .success .reachabilityUnknown => "reachability unknown" - IO.println s!" {marker} {desc}" - pure #[] - else if pgm.dialect == "Boole" then - Boole.verify opts.solver pgm inputCtx proceduresToVerify opts - else - verify pgm inputCtx proceduresToVerify opts - catch e => - println! f!"{e}" - IO.Process.exit ExitCode.internalError - if opts.outputSarif then - if file.endsWith ".csimp.st" then - println! "SARIF output is not supported for C_Simp files (.csimp.st) because location metadata is not preserved during translation to Core." - else - let uri := Strata.Uri.file file - let files := Map.empty.insert uri inputCtx.fileMap - Core.Sarif.writeSarifOutput opts.checkMode files vcResults (file ++ ".sarif") - for vcResult in vcResults do - let posStr := Imperative.MetaData.formatFileRangeD vcResult.obligation.metadata (some inputCtx.fileMap) - println! f!"{posStr} [{vcResult.obligation.label}]: \ - {vcResult.formatOutcome}" - let success := vcResults.all Core.VCResult.isSuccess - if success && !opts.checkOnly then - println! f!"All {vcResults.size} goals passed." - else if success && opts.checkOnly then - println! f!"Skipping verification." - else - let provedGoalCount := (vcResults.filter Core.VCResult.isSuccess).size - let failedGoalCount := (vcResults.filter Core.VCResult.isNotSuccess).size - -- Encoding failures, solver crashes, or per-check SMT errors (exit 3) - let hasImplError := vcResults.any (fun r => r.isImplementationError || r.hasSMTError) - -- Assertion violations that are not timeouts or internal errors (exit 2) - let hasFailure := vcResults.any (fun r => !r.isSuccess && !r.isTimeout && !r.isImplementationError && !r.hasSMTError) - println! f!"Finished with {provedGoalCount} goals passed, {failedGoalCount} failed." - if hasImplError then - IO.Process.exit ExitCode.internalError - else if hasFailure then - IO.Process.exit ExitCode.failuresFound - -def pyInterpretCommand : Command where - name := "pyInterpret" - args := [ "file" ] - flags := [{ name := "fuel", help := "Maximum execution steps.", takesArg := .arg "n" }] - ++ laurelTranslateFlags - help := "Interpret a Python Ion program concretely (Python → Laurel → Core → execute)." - callback := fun v pflags => do - let filePath := v[0] - let keepDir := pflags.getString "keep-all-files" - let fuel ← match pflags.getString "fuel" with - | some s => match s.toNat? with - | .some n => pure n - | .none => exitFailure s!"Invalid fuel: '{s}'" - | none => pure 10000 - - let (core, _diags) ← - match ← Strata.pythonAndSpecToLaurel filePath (specDir := ".") |>.toBaseIO with - | .ok laurel => - if let some dir := keepDir then - IO.FS.createDirAll dir - IO.FS.writeFile (dir ++ "/laurel.st") (toString (Std.format laurel)) - match ← Strata.translateCombinedLaurel laurel with - | (some core, diags) => pure (core, diags) - | (none, diags) => exitFailure s!"Laurel to Core translation failed: {diags}" - | .error msg => exitFailure (toString msg) - if let some dir := keepDir then - IO.FS.writeFile (dir ++ "/core.st") (toString (Std.format core)) - let core ← match Core.typeCheck Core.VerifyOptions.quiet core - (moreFns := Strata.Python.ReFactory) with - | .ok prog => pure prog - | .error e => - println! s!"Core type checking failed: {e.message}" - IO.Process.exit ExitCode.userError - match core.run with - | .ok E => - let mainProc := Core.Program.Procedure.find? core ⟨"__main__", ()⟩ - let outputNames := match mainProc with - | some p => p.header.outputs.keys.map (·.name) - | none => [] - let (lhs, exprEnv) := Core.Env.genVars outputNames E.exprEnv - let E := { E with exprEnv } - let E := Core.Statement.Command.runCall lhs "__main__" [] fuel E - match E.error with - | none => - IO.println "Execution completed successfully." - | some e => - IO.println s!"{Std.format e}" - IO.Process.exit ExitCode.failuresFound - | .error diag => - IO.eprintln s!"Error: {diag}" - IO.Process.exit ExitCode.failuresFound - -def commandGroups : List CommandGroup := [ - { name := "Core" - commands := [verifyCommand, transformCommand, checkCommand, toIonCommand, printCommand, diffCommand] - commonFlags := [includeFlag] }, - { name := "Code Generation" - commands := [javaGenCommand] }, - { name := "Python" - commands := [pyAnalyzeLaurelCommand, - pyResolveOverloadsCommand, - pySpecsCommand, pySpecToLaurelCommand, - pyAnalyzeLaurelToGotoCommand, - pyAnalyzeToGotoCommand, - pyTranslateLaurelCommand, - pyInterpretCommand] }, - { name := "Laurel" - commands := [laurelAnalyzeCommand, laurelAnalyzeBinaryCommand, - laurelAnalyzeToGotoCommand, laurelParseCommand, - laurelPrintCommand, laurelToCoreCommand] }, -] - -def commandList : List Command := - commandGroups.foldl (init := []) fun acc g => acc ++ g.commands - -def commandMap : Std.HashMap String Command := - commandList.foldl (init := {}) fun m c => m.insert c.name c - -/-- Print a single flag's name and help text at the given indentation. -/ -private def printFlag (indent : Nat) (flag : Flag) : IO Unit := do - let pad := "".pushn ' ' indent - match flag.takesArg with - | .arg argName | .repeat argName => - IO.println s!"{pad}--{flag.name} <{argName}> {flag.help}" - | .none => - IO.println s!"{pad}--{flag.name} {flag.help}" - -/-- Print help for all command groups. -/ -private def printGlobalHelp : IO Unit := do - IO.println "Usage: strata [flags]...\n" - IO.println "Command-line utilities for working with Strata.\n" - for group in commandGroups do - IO.println s!"{group.name}:" - for cmd in group.commands do - let cmdLine := cmd.args.foldl (init := cmd.name) fun s a => s!"{s} <{a}>" - IO.println s!" {cmdLine}" - printIndented 4 cmd.help - let perCmdFlags := cmd.flags.filter fun f => - !group.commonFlags.any fun cf => cf.name == f.name - if !perCmdFlags.isEmpty then - IO.println "" - IO.println " Flags:" - for flag in perCmdFlags do - printFlag 6 flag - IO.println "" - if !group.commonFlags.isEmpty then - IO.println " Common flags:" - for flag in group.commonFlags do - printFlag 4 flag - IO.println "" - -/-- Print help for a single command. -/ -private def printCommandHelp (cmd : Command) : IO Unit := do - let cmdLine := cmd.args.foldl (init := s!"strata {cmd.name}") fun s a => s!"{s} <{a}>" - let flagSummary := cmd.flags.foldl (init := "") fun s f => - match f.takesArg with - | .arg argName | .repeat argName => s!"{s} [--{f.name} <{argName}>]" - | .none => s!"{s} [--{f.name}]" - IO.println s!"Usage: {cmdLine}{flagSummary}\n" - printIndented 0 cmd.help - if !cmd.flags.isEmpty then - IO.println "\nFlags:" - for flag in cmd.flags do - printFlag 2 flag - -/-- Parse interleaved flags and positional arguments. Returns the collected - positional arguments and parsed flags. -/ -private def parseArgs (cmdName : String) - (flagMap : Std.HashMap String Flag) - (acc : Array String) (pflags : ParsedFlags) - (cmdArgs : List String) : IO (Array String × ParsedFlags) := do - match cmdArgs with - | arg :: cmdArgs => - if arg.startsWith "--" then - let raw := (arg.drop 2).toString - -- Support --flag=value syntax by splitting on first '=' - let (flagName, inlineValue) ← match raw.splitOn "=" with - | name :: value :: rest => - if !rest.isEmpty then - exitCmdFailure cmdName s!"Invalid option format: {arg}. Values must not contain '='." - pure (name, some value) - | _ => pure (raw, none) - match flagMap[flagName]? with - | some flag => - match flag.takesArg with - | .none => - parseArgs cmdName flagMap acc (pflags.insert flagName Option.none) cmdArgs - | .arg _ => - match inlineValue with - | some value => - parseArgs cmdName flagMap acc (pflags.insert flagName (some value)) cmdArgs - | none => - let value :: cmdArgs := cmdArgs - | exitCmdFailure cmdName s!"Expected value after {arg}." - parseArgs cmdName flagMap acc (pflags.insert flagName (some value)) cmdArgs - | .repeat _ => - match inlineValue with - | some value => - parseArgs cmdName flagMap acc (pflags.insert flagName (some value)) cmdArgs - | none => - let value :: cmdArgs := cmdArgs - | exitCmdFailure cmdName s!"Expected value after {arg}." - parseArgs cmdName flagMap acc (pflags.insert flagName (some value)) cmdArgs - | none => - exitCmdFailure cmdName s!"Unknown option {arg}." - else - parseArgs cmdName flagMap (acc.push arg) pflags cmdArgs - | [] => - pure (acc, pflags) - -public -def main (args : List String) : IO Unit := do - try do - match args with - | ["--help"] => printGlobalHelp - | cmd :: args => - match commandMap[cmd]? with - | none => exitFailure s!"Expected subcommand, got {cmd}." - | some cmd => - -- Handle per-command help before parsing flags. - if args.contains "--help" then - printCommandHelp cmd - return - -- Index the command's flags by name for O(1) lookup during parsing. - let flagMap : Std.HashMap String Flag := - cmd.flags.foldl (init := {}) fun m f => m.insert f.name f - -- Split raw args into positional arguments and parsed flags. - let (args, pflags) ← parseArgs cmd.name flagMap #[] {} args - if p : args.size = cmd.args.length then - cmd.callback ⟨args, p⟩ pflags - else - exitCmdFailure cmd.name s!"{cmd.name} expects {cmd.args.length} argument(s)." - | [] => do - exitFailure "Expected subcommand." - catch e => - exitFailure e.toString +def main (args : List String) : IO Unit := + runCommandMap commandMap commandGroups args diff --git a/StrataMainLib.lean b/StrataMainLib.lean new file mode 100644 index 0000000000..2ca6c87ba2 --- /dev/null +++ b/StrataMainLib.lean @@ -0,0 +1,1569 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +-- Library with utilities for working with Strata files. +import Lean.Parser.Extension +import Strata.Backends.CBMC.CollectSymbols +import Strata.Backends.CBMC.GOTO.CoreToGOTOPipeline +import Strata.DDM.Integration.Java.Gen +import Strata.Languages.Core.Verifier +import Strata.Languages.Core.SarifOutput +import Strata.Languages.Core.ProgramEval +import Strata.Languages.Core.StatementEval +import Strata.Languages.C_Simp.Verify +import Strata.Languages.B3.Verifier.Program +import Strata.Languages.Laurel.LaurelCompilationPipeline +import Strata.Pipeline.Diagnostic +import Strata.Pipeline.PyAnalyzeLaurel +import Strata.Languages.Boole.Boole +import Strata.Languages.Boole.Verify +import Strata.Languages.Python.Python +import Strata.Languages.Python.Specs.IdentifyOverloads +import Strata.Languages.Python.Specs.ToLaurel +import Strata.Languages.Laurel.Grammar.AbstractToConcreteTreeTranslator +import Strata.Languages.Laurel.Laurel +import Strata.Languages.Core.EntryPoint +import Strata.Transform.ProcedureInlining +import Strata.Util.IO + +import Strata.SimpleAPI +import Strata.Util.Json +import Strata.DDM.BuiltinDialects +import Strata.DDM.Util.String +import Strata.Languages.Python.PyFactory +import Strata.Languages.Python.Specs +import Strata.Languages.Python.Specs.DDM +import Strata.Languages.Python.ReadPython + +open Strata + +open Core (VerifyOptions VerboseMode VerificationMode CheckLevel EntryPoint) +open Laurel (LaurelVerifyOptions LaurelTranslateOptions) + +/-! ## Exit codes + +All `strata` subcommands use a common exit code scheme: + +| Code | Category | Meaning | +|------|--------------------|-----------------------------------------------------------| +| 0 | Success | Analysis passed, inconclusive, or `--no-solve` completed. | +| 1 | User error | Bad input: invalid arguments, malformed source, etc. | +| 2 | Failures found | Analysis completed and found failures. | +| 3 | Internal error | SMT encoding failure, solver crash, or translation bug. | +| 4 | Known limitation | Intentionally unsupported language construct. | + +Codes 1–2 are **user-actionable** (fix the input or the code under analysis). +Codes 3–4 are **tool-side** (report as a bug or wait for support). +Exit 0 covers success, inconclusive results, and solver timeouts. -/ + +namespace ExitCode + def userError : UInt8 := 1 + def failuresFound : UInt8 := 2 + def internalError : UInt8 := 3 + def knownLimitation : UInt8 := 4 +end ExitCode + +def exitFailure {α} (message : String) (hint : String := "strata --help") : IO α := do + IO.eprintln s!"Exception: {message}\n\nRun {hint} for additional help." + IO.Process.exit ExitCode.userError + +/-- Exit with code 1 for user errors (bad input, malformed source, etc.). -/ +def exitUserError {α} (message : String) : IO α := do + IO.eprintln s!"❌ {message}" + IO.Process.exit ExitCode.userError + +/-- Exit with code 2 for analysis failures found. -/ +def exitFailuresFound {α} (message : String) : IO α := do + IO.eprintln s!"Failures found: {message}" + IO.Process.exit ExitCode.failuresFound + +/-- Exit with code 3 for internal errors (tool limitations or crashes). -/ +def exitInternalError {α} (message : String) : IO α := do + IO.eprintln s!"Exception: {message}" + IO.Process.exit ExitCode.internalError + +/-- Exit with code 4 for known limitations (unsupported constructs). -/ +def exitKnownLimitation {α} (message : String) : IO α := do + IO.eprintln s!"Known limitation: {message}" + IO.Process.exit ExitCode.knownLimitation + +/-- Like `exitFailure` but tailors the help hint to a specific subcommand. -/ +def exitCmdFailure {α} (cmdName : String) (message : String) : IO α := + exitFailure message (hint := s!"strata {cmdName} --help") + +/-- How a flag consumes arguments. -/ +inductive FlagArg where + | none -- boolean flag, e.g. --verbose + | arg (name : String) -- takes one value, e.g. --output + | repeat (name : String) -- takes one value, may appear multiple times, e.g. --include + +/-- A flag that a command accepts. -/ +structure Flag where + name : String -- flag name without "--", used as lookup key + help : String + takesArg : FlagArg := .none + +/-- Parsed flags from the command line. Stored as an ordered array so that + command-line position is preserved (needed by `transform` to bind + `--procedures`/`--functions` to the preceding `--pass`). + For `.arg` flags that appear more than once, `getString` returns the + **last** occurrence (last-writer-wins). -/ +structure ParsedFlags where + entries : Array (String × Option String) := #[] + +namespace ParsedFlags + +def getBool (pf : ParsedFlags) (name : String) : Bool := + pf.entries.any (·.1 == name) + +def getString (pf : ParsedFlags) (name : String) : Option String := + -- Scan from the end so last occurrence wins. + match pf.entries.findRev? (·.1 == name) with + | some (_, some v) => some v + | _ => Option.none + +def getRepeated (pf : ParsedFlags) (name : String) : Array String := + pf.entries.foldl (init := #[]) fun acc (n, v) => + if n == name then match v with | some s => acc.push s | none => acc else acc + +def insert (pf : ParsedFlags) (name : String) (value : Option String) : ParsedFlags := + { pf with entries := pf.entries.push (name, value) } + +def buildDialectFileMap (pflags : ParsedFlags) : IO Strata.DialectFileMap := do + let preloaded := Strata.Elab.LoadedDialects.builtin + |>.addDialect! Strata.Python.Python + |>.addDialect! Strata.Python.Specs.DDM.PythonSpecs + |>.addDialect! Strata.Core + |>.addDialect! Strata.Boole + |>.addDialect! Strata.Laurel.Laurel + |>.addDialect! Strata.smtReservedKeywordsDialect + |>.addDialect! Strata.SMTCore + |>.addDialect! Strata.SMT + |>.addDialect! Strata.SMTResponse + let mut sp ← Strata.DialectFileMap.new preloaded + for path in pflags.getRepeated "include" do + match ← sp.add path |>.toBaseIO with + | .error msg => exitFailure msg + | .ok sp' => sp := sp' + return sp + +end ParsedFlags + +def parseCheckMode (pflags : ParsedFlags) + (default : VerificationMode := .deductive) : IO VerificationMode := + match pflags.getString "check-mode" with + | .none => pure default + | .some s => match VerificationMode.ofString? s with + | .some m => pure m + | .none => exitFailure s!"Invalid check mode: '{s}'. Must be {VerificationMode.options}." + +def parseCheckLevel (pflags : ParsedFlags) + (default : CheckLevel := .minimal) : IO CheckLevel := + match pflags.getString "check-level" with + | .none => pure default + | .some s => match CheckLevel.ofString? s with + | .some l => pure l + | .none => exitFailure s!"Invalid check level: '{s}'. Must be {CheckLevel.options}." + +/-- Common CLI flags for VerifyOptions fields. + Commands can append these to their own flags list. + Note: `parseOnly`, `typeCheckOnly`, and `checkOnly` are omitted here + because they are specific to the `verify` command. -/ +def verifyOptionsFlags : List Flag := [ + { name := "check-mode", + help := s!"Check mode: {VerificationMode.options}. Default: 'deductive'.", + takesArg := .arg "mode" }, + { name := "check-level", + help := s!"Check level: {CheckLevel.options}. Default: 'minimal'.", + takesArg := .arg "level" }, + { name := "verbose", help := "Enable verbose output." }, + { name := "quiet", help := "Suppress warnings on stderr." }, + { name := "profile", help := "Print elapsed time for each pipeline step." }, + { name := "sarif", help := "Write results as SARIF to .sarif." }, + { name := "solver", + help := s!"SMT solver executable (default: {Core.defaultSolver}).", + takesArg := .arg "name" }, + { name := "solver-timeout", + help := "Solver timeout in seconds (default: 10).", + takesArg := .arg "seconds" }, + { name := "vc-directory", + help := "Store VCs in SMT-Lib format in .", + takesArg := .arg "dir" }, + { name := "no-solve", + help := "Generate SMT-Lib files but do not invoke the solver." }, + { name := "stop-on-first-error", + help := "Exit after the first verification error." }, + { name := "unique-bound-names", + help := "Use globally unique names for quantifier-bound variables." }, + { name := "use-array-theory", + help := "Use SMT-LIB Array theory instead of axiomatized maps." }, + { name := "remove-irrelevant-axioms", + help := "Prune irrelevant axioms: 'off', 'aggressive', or 'precise'.", + takesArg := .arg "mode" }, + { name := "overflow-checks", + help := "Comma-separated overflow checks to enable (signed,unsigned,float64,all,none).", + takesArg := .arg "checks" }, + { name := "incremental", + help := "Use incremental solver backend (stdin/stdout) instead of batch file I/O." }, + { name := "path-cap", + help := "Maximum continuing paths between statements. 'none' (default) disables; N merges paths when count exceeds N.", + takesArg := .arg "N|none" }, + { name := "parallel", + help := "Number of parallel solver workers (default: 1, sequential).", + takesArg := .arg "N" } +] + +/-- Build a VerifyOptions from parsed CLI flags, starting from a base config. + Fields not present in the flags keep their base values. + Note: boolean flags can only enable a setting; a `true` in the base + cannot be turned off from the CLI (there is no `--no-X` syntax). -/ +def parseVerifyOptions (pflags : ParsedFlags) + (base : VerifyOptions := VerifyOptions.default) : IO VerifyOptions := do + let checkMode ← parseCheckMode pflags base.checkMode + let checkLevel ← parseCheckLevel pflags base.checkLevel + let solverTimeout ← match pflags.getString "solver-timeout" with + | .none => pure base.solverTimeout + | .some s => match s.toNat? with + | .some n => pure n + | .none => exitFailure s!"Invalid solver timeout: '{s}'" + let noSolve := pflags.getBool "no-solve" + let removeIrrelevantAxioms ← match pflags.getString "remove-irrelevant-axioms" with + | .none => pure base.removeIrrelevantAxioms + | .some "off" => pure .Off + | .some "aggressive" => pure .Aggressive + | .some "precise" => pure .Precise + | .some s => exitFailure s!"Invalid remove-irrelevant-axioms mode: '{s}'. Must be 'off', 'aggressive', or 'precise'." + let overflowChecks := match pflags.getString "overflow-checks" with + | .none => base.overflowChecks + | .some s => s.splitOn "," |>.foldl (fun acc c => + match c.trimAscii.toString with + | "signed" => { acc with signedBV := true } + | "unsigned" => { acc with unsignedBV := true } + | "float64" => { acc with float64 := true } + | "none" => { signedBV := false, unsignedBV := false, float64 := false } + | "all" => { signedBV := true, unsignedBV := true, float64 := true } + | _ => acc) { signedBV := false, unsignedBV := false, float64 := false } + let pathCap ← match pflags.getString "path-cap" with + | .none => pure base.pathCap + | .some "none" => pure .none + | .some s => match s.toNat? with + | .some n => if n == 0 then exitFailure "--path-cap must be at least 1 or 'none'." + else pure (.some n) + | .none => exitFailure s!"Invalid path-cap: '{s}'. Must be a positive number or 'none'." + let parallelWorkers ← match pflags.getString "parallel" with + | .none => pure base.parallelWorkers + | .some s => match s.toNat? with + | .some n => if n == 0 then exitFailure "--parallel must be at least 1." + else pure n + | .none => exitFailure s!"Invalid parallel workers: '{s}'. Must be a positive number." + let vcDirectory := (pflags.getString "vc-directory" |>.map (⟨·⟩ : String → System.FilePath)).orElse (fun _ => base.vcDirectory) + let skipSolver := noSolve || base.skipSolver + if skipSolver && vcDirectory.isNone then + exitFailure "--no-solve requires --vc-directory to specify where SMT files are stored." + pure { base with + verbose := if pflags.getBool "verbose" then .normal + else if pflags.getBool "quiet" then .quiet + else base.verbose, + solver := pflags.getString "solver" |>.getD base.solver, + solverTimeout, + checkMode, checkLevel, + stopOnFirstError := pflags.getBool "stop-on-first-error" || base.stopOnFirstError, + uniqueBoundNames := pflags.getBool "unique-bound-names" || base.uniqueBoundNames, + useArrayTheory := pflags.getBool "use-array-theory" || base.useArrayTheory, + removeIrrelevantAxioms, + outputSarif := pflags.getBool "sarif" || base.outputSarif, + profile := pflags.getBool "profile" || base.profile, + incremental := if noSolve then false else pflags.getBool "incremental" || base.incremental, + skipSolver, + alwaysGenerateSMT := noSolve || base.alwaysGenerateSMT, + overflowChecks, + vcDirectory, + pathCap, + parallelWorkers + } + +/-- Additional CLI flags for `LaurelVerifyOptions` fields that are not already + covered by `verifyOptionsFlags`. -/ +def laurelTranslateFlags : List Flag := [ + { name := "keep-all-files", + help := "Store intermediate Laurel and Core programs in .", + takesArg := .arg "dir" } +] + +/-- All CLI flags accepted by Laurel verify commands. -/ +def laurelVerifyOptionsFlags : List Flag := verifyOptionsFlags ++ laurelTranslateFlags + +/-- Build a `LaurelVerifyOptions` from parsed CLI flags. -/ +def parseLaurelVerifyOptions (pflags : ParsedFlags) + (base : LaurelVerifyOptions := default) : IO LaurelVerifyOptions := do + let verifyOptions ← parseVerifyOptions pflags base.verifyOptions + let keepAllFilesPrefix := (pflags.getString "keep-all-files").orElse + (fun _ => base.translateOptions.keepAllFilesPrefix) + let translateOptions : LaurelTranslateOptions := + { base.translateOptions with + keepAllFilesPrefix + overflowChecks := verifyOptions.overflowChecks } + return { translateOptions, verifyOptions } + +/-- Read and parse a Strata program file, loading the Core, C_Simp, and B3CST + dialects. Returns the parsed program and the input context (for source + location resolution), or an array of error messages on failure. -/ +private def readStrataProgram (file : String) + : IO (Except (Array Lean.Message) (Strata.Program × Lean.Parser.InputContext)) := do + let text ← Strata.Util.readInputSource file + let inputCtx := Lean.Parser.mkInputContext text (Strata.Util.displayName file) + let dctx := Elab.LoadedDialects.builtin + let dctx := dctx.addDialect! Core + let dctx := dctx.addDialect! Boole + let dctx := dctx.addDialect! C_Simp + let dctx := dctx.addDialect! B3CST + let leanEnv ← Lean.mkEmptyEnvironment 0 + match Strata.Elab.elabProgram dctx leanEnv inputCtx with + | .ok pgm => pure (.ok (pgm, inputCtx)) + | .error msgs => pure (.error msgs) + +structure Command where + name : String + args : List String + flags : List Flag := [] + help : String + callback : Vector String args.length → ParsedFlags → IO Unit + +def includeFlag : Flag := + { name := "include", help := "Add a dialect search path.", takesArg := .repeat "path" } + +def checkCommand : Command where + name := "check" + args := [ "file" ] + flags := [includeFlag] + help := "Parse and validate a Strata file (text or Ion). Reports errors and exits." + callback := fun v pflags => do + let fm ← pflags.buildDialectFileMap + let _ ← Strata.readStrataFile fm v[0] + +def toIonCommand : Command where + name := "toIon" + args := [ "input", "output" ] + flags := [includeFlag] + help := "Convert a Strata text file to Ion binary format." + callback := fun v pflags => do + let searchPath ← pflags.buildDialectFileMap + let pd ← Strata.readStrataFile searchPath v[0] + match pd with + | .dialect d => + IO.FS.writeBinFile v[1] d.toIon + | .program pgm => + IO.FS.writeBinFile v[1] pgm.toIon + +def printCommand : Command where + name := "print" + args := [ "file" ] + flags := [includeFlag] + help := "Pretty-print a Strata file (text or Ion) to stdout." + callback := fun v pflags => do + let searchPath ← pflags.buildDialectFileMap + -- Special case for already loaded dialects. + let ld ← searchPath.getLoaded + if mem : v[0] ∈ ld.dialects then + IO.print <| ld.dialects.format v[0] mem + return + let pd ← Strata.readStrataFile searchPath v[0] + match pd with + | .dialect d => + let ld ← searchPath.getLoaded + let .isTrue mem := (inferInstance : Decidable (d.name ∈ ld.dialects)) + | exitInternalError "Internal error reading file." + IO.print <| ld.dialects.format d.name mem + | .program pgm => + IO.print <| toString pgm + +def diffCommand : Command where + name := "diff" + args := [ "file1", "file2" ] + flags := [includeFlag] + help := "Compare two program files for syntactic equality. Reports the first difference found." + callback := fun v pflags => do + let fm ← pflags.buildDialectFileMap + let p1 ← Strata.readStrataFile fm v[0] + let p2 ← Strata.readStrataFile fm v[1] + match p1, p2 with + | .program p1, .program p2 => + if p1.dialect != p2.dialect then + exitFailure s!"Dialects differ: {p1.dialect} and {p2.dialect}" + let Decidable.isTrue eq := (inferInstance : Decidable (p1.commands.size = p2.commands.size)) + | exitFailure s!"Number of commands differ {p1.commands.size} and {p2.commands.size}" + for (c1, c2) in Array.zip p1.commands p2.commands do + if c1 != c2 then + exitFailure s!"Commands differ: {repr c1} and {repr c2}" + | _, _ => + exitFailure "Cannot compare dialect def with another dialect/program." + +def pySpecsCommand : Command where + name := "pySpecs" + args := [ "source_dir", "output_dir" ] + flags := [ + { name := "quiet", help := "Suppress default logging." }, + { name := "log", help := "Enable logging for an event type.", + takesArg := .repeat "event" }, + { name := "skip", + help := "Skip a top-level definition (module.name). Overloads are kept.", + takesArg := .repeat "name" }, + { name := "module", + help := "Translate only the named module (dot-separated). May be repeated.", + takesArg := .repeat "module" } + ] + help := "Translate Python specification files in a directory into Strata DDM Ion format. If --module is given, translates only those modules; otherwise translates all .py files. Creates subdirectories as needed. (Experimental)" + callback := fun v pflags => do + let quiet := pflags.getBool "quiet" + let mut events : Std.HashSet String := {} + if !quiet then + events := events.insert "import" + for e in pflags.getRepeated "log" do + events := events.insert e + let skipNames := pflags.getRepeated "skip" + let modules := pflags.getRepeated "module" + let warningOutput : Strata.WarningOutput := + if quiet then .none else .detail + -- Serialize embedded dialect for Python subprocess + IO.FS.withTempFile fun _handle dialectFile => do + IO.FS.writeBinFile dialectFile Strata.Python.Python.toIon + let r ← Strata.pySpecsDir (events := events) + (skipNames := skipNames) + (modules := modules) + (warningOutput := warningOutput) + v[0] v[1] dialectFile |>.toBaseIO + match r with + | .ok () => pure () + | .error msg => exitFailure msg + +/-- Derive Python source file path from Ion file path. + E.g., "tests/test_foo.python.st.ion" -> "tests/test_foo.py" -/ +def ionPathToPythonPath (ionPath : String) : Option String := + if ionPath.endsWith ".python.st.ion" then + let basePath := ionPath.dropEnd ".python.st.ion".length |>.toString + some (basePath ++ ".py") + else if ionPath.endsWith ".py.ion" then + some (ionPath.dropEnd ".ion".length |>.toString) + else + none + +/-- Try to read Python source file for source location reconstruction -/ +def tryReadPythonSource (ionPath : String) : IO (Option (String × String)) := do + match ionPathToPythonPath ionPath with + | none => return none + | some pyPath => + try + let content ← IO.FS.readFile pyPath + return some (pyPath, content) + catch _ => + return none + +/-- Format related position strings from metadata, if present. -/ +def formatRelatedPositions (md : Imperative.MetaData Core.Expression) + (mfm : Option (String × Lean.FileMap)) : String := + let ranges := Imperative.getRelatedFileRanges md + if ranges.isEmpty then "" else + match mfm with + | none => "" + | some (_, fm) => + let lines := ranges.filterMap fun fr => + if fr.range.isNone then none else + match fr.file with + | .file "" => some "\n Related location: in prelude file" + | .file _ => + let pos := fm.toPosition fr.range.start + some s!"\n Related location: line {pos.line}, col {pos.column}" + String.join lines.toList + +/-! ### pyAnalyzeLaurel result helpers + +The `pyAnalyzeLaurel` command emits two structured lines on stdout: +- `RESULT: ` — machine-readable category, always the last line. +- `DETAIL: ` — human-readable context (error message or VC counts). + +Exit codes follow the common scheme (see `ExitCode` above). +A successful run exits 0 with `RESULT: Analysis success` or `RESULT: Inconclusive`. -/ + +/-- Determines which VC results count as successes and which count as failures + for the purposes of the `pyAnalyzeLaurel` summary and exit code. + Implementation-error results are partitioned out first; the classifier then + partitions the rest into success / failure / inconclusive. + Narrowing `isFailure` (e.g. to only `alwaysFalseAndReachable`) automatically + widens inconclusive. + Future: may be extended with `isWarning` for non-fatal diagnostic categories. -/ +structure ResultClassifier where + isSuccess : Core.VCResult → Bool := (·.isSuccess) + isFailure : Core.VCResult → Bool := (·.isFailure) + +private def printPyAnalyzeResult (category : String) (detail : String) : IO Unit := do + IO.println s!"DETAIL: {detail}" + IO.println s!"RESULT: {category}" + +private def exitPyAnalyzeUserError {α} (message : String) : IO α := do + printPyAnalyzeResult "User error" message + IO.Process.exit ExitCode.userError + +private def exitPyAnalyzeFailuresFound {α} (detail : String) : IO α := do + printPyAnalyzeResult "Failures found" detail + IO.Process.exit ExitCode.failuresFound + +private def exitPyAnalyzeInternalError {α} (message : String) : IO α := do + printPyAnalyzeResult "Internal error" message + IO.Process.exit ExitCode.internalError + +private def exitPyAnalyzeKnownLimitation {α} (message : String) : IO α := do + printPyAnalyzeResult "Known limitation" message + IO.Process.exit ExitCode.knownLimitation + +/-- Print the final RESULT/DETAIL lines based on solver outcomes. + Always called on successful pipeline completion (as opposed to the + exit helpers above, which are called on early pipeline failure). + Classification uses successive partitioning: timeouts and implementation + errors are removed first, then the classifier partitions the rest into + success / failure / inconclusive (guaranteeing disjointness). + Unreachable count is reported as supplementary info. + + Exit-code priority (highest wins): + - Internal error (exit 3): encoding failures or solver crashes + - Failures found (exit 2): assertion violations + - Inconclusive / success / solver timeout (exit 0) -/ +private def printPyAnalyzeSummary (vcResults : Array Core.VCResult) + (checkMode : VerificationMode := .deductive) : IO Unit := do + let classifier : ResultClassifier := + match checkMode with + | .bugFinding | .bugFindingAssumingCompleteSpec => + { isSuccess := (·.isBugFindingSuccess) + isFailure := (·.isBugFindingFailure) } + | _ => {} + -- 1. Partition out implementation errors and timeouts (not classifiable). + let (implError, rest1) := + vcResults.partition (fun r => r.isImplementationError || r.hasSMTError) + let (timeouts, classifiable) := rest1.partition (·.isTimeout) + -- 2. Successive partitioning via the classifier: success → failure → inconclusive. + let (success, rest) := classifiable.partition classifier.isSuccess + let (failure, inconclusive) := rest.partition classifier.isFailure + -- 3. Unreachable is informational (not a separate partition). + let nUnreachable := vcResults.filter (·.isUnreachable) |>.size + let nImplError := implError.size + let nTimeout := timeouts.size + let nSuccess := success.size + let nFailure := failure.size + let nInconclusive := inconclusive.size + let unreachableStr := if nUnreachable > 0 then s!", {nUnreachable} unreachable" else "" + let implErrorStr := if nImplError > 0 then s!", {nImplError} internal errors" else "" + let timeoutStr := if nTimeout > 0 then s!", {nTimeout} solver timeouts" else "" + let counts := s!"{nSuccess} passed, {nFailure} failed, {nInconclusive} inconclusive{unreachableStr}{timeoutStr}{implErrorStr}" + if nImplError > 0 then + exitPyAnalyzeInternalError s!"An unexpected result was produced. {counts}" + else if nFailure > 0 then + exitPyAnalyzeFailuresFound counts + else + let label := + if nTimeout > 0 then "Solver timeout" + else if nInconclusive > 0 then "Inconclusive" + else "Analysis success" + printPyAnalyzeResult label counts + +private def deriveBaseName (file : String) : String := + let name := System.FilePath.fileName file |>.getD file + let suffixes := [".python.st.ion", ".py.ion", ".st.ion", ".st"] + match suffixes.find? (name.endsWith ·) with + | some sfx => (name.dropEnd sfx.length).toString + | none => name + + +/-- Write SMT-style user-error diagnostics to stdout and `user_errors.txt`, + and return a human-readable location suffix (e.g., " at line 42, col 5"). -/ +private def reportUserCodeError (range : SourceRange) (msg : String) + (mfm : Option (String × Lean.FileMap)) (filePath : String) : IO String := do + let location := if range.isNone then "" else + match mfm with + | some (_, fm) => + let pos := fm.toPosition range.start + s!" at line {pos.line}, col {pos.column}" + | none => "" + let mut lines := #[ + s!"(set-info :file {Strata.escapeSMTStringLit filePath})" + ] + unless range.isNone do + lines := lines.push s!"(set-info :start {range.start})" + lines := lines.push s!"(set-info :stop {range.stop})" + lines := lines.push s!"(set-info :error-message {Strata.escapeSMTStringLit msg})" + for line in lines do + IO.println line + IO.FS.Handle.mk "user_errors.txt" .write >>= fun h => + for line in lines do + h.putStrLn line + return location + +def pyAnalyzeLaurelCommand : Command where + name := "pyAnalyzeLaurel" + args := [ "file" ] + flags := verifyOptionsFlags ++ [ + { name := "spec-dir", + help := "Directory containing compiled PySpec Ion files.", + takesArg := .arg "dir" }, + { name := "dispatch", + help := "Dispatch module name (e.g., servicelib).", + takesArg := .repeat "module" }, + { name := "pyspec", + help := "PySpec module name (e.g., servicelib.Storage).", + takesArg := .repeat "module" }, + { name := "keep-all-files", + help := "Store intermediate Laurel and Core programs in .", + takesArg := .arg "dir" }, + { name := "entry-point", + help := "Which procedures to verify: main (main fn only), roots (user procs with no user callers, default), or all (all user procs). Only valid in bugFinding mode.", + takesArg := .arg "mode" }, + { name := "metrics", + help := "Write pipeline metrics (diagnostics, timing, outcome) as JSONL to .", + takesArg := .arg "file" }, + { name := "skip-verification", + help := "Run Python-to-Laurel and Laurel-to-Core translation only (skip SMT verification).", + takesArg := .none }] + help := "Verify a Python Ion program via the Laurel pipeline. Translates Python to Laurel to Core, then runs SMT verification." + callback := fun v pflags => do + let verbose := pflags.getBool "verbose" + let profile := pflags.getBool "profile" + let quiet := pflags.getBool "quiet" + let outputSarif := pflags.getBool "sarif" + let filePath := v[0] + let pySourceOpt ← tryReadPythonSource filePath + let keepDir := pflags.getString "keep-all-files" + let baseName := deriveBaseName filePath + if let some dir := keepDir then + IO.FS.createDirAll dir + + let dispatchModules := pflags.getRepeated "dispatch" + let pyspecModules := pflags.getRepeated "pyspec" + let specDir := pflags.getString "spec-dir" |>.getD "." + unless ← System.FilePath.isDir specDir do + exitFailure s!"spec-dir '{specDir}' does not exist or is not a directory" + let sourcePath := pySourceOpt.map (·.1) + -- Build FileMap for source position resolution. + let mfm : Option (String × Lean.FileMap) := match pySourceOpt with + | some (pyPath, srcText) => some (pyPath, .ofString srcText) + | none => none + let metricsHandle ← match pflags.getString "metrics" with + | some path => some <$> IO.FS.Handle.mk path .write + | none => pure none + + -- Parse verify options early (needed for pipeline config). + let keepPrefix := keepDir.map (s!"{·}/{baseName}") + let baseVcDir := keepDir.map (fun dir => (s!"{dir}/{baseName}" : System.FilePath)) + let pyAnalyzeBase : VerifyOptions := + { VerifyOptions.default with + verbose := .quiet, removeIrrelevantAxioms := .Precise, + vcDirectory := baseVcDir } + let options ← parseVerifyOptions pflags pyAnalyzeBase + let isBugFinding := options.checkMode == .bugFinding + || options.checkMode == .bugFindingAssumingCompleteSpec + + -- Parse --entry-point flag (only supported in bug-finding modes). + let entryPointFlag := pflags.getString "entry-point" + let entryPoint : EntryPoint ← + if isBugFinding then + match entryPointFlag with + | some s => + match EntryPoint.ofString? s with + | some ep => pure ep + | none => + exitPyAnalyzeUserError s!"Invalid --entry-point value '{s}'. Must be {EntryPoint.options}." + | none => pure .roots + else + if entryPointFlag.isSome then + exitPyAnalyzeUserError s!"--entry-point is unsupported in {options.checkMode} mode" + else pure .all + + -- Derive output mode from CLI flags. + let outputMode : Strata.Pipeline.OutputMode := + if verbose then .verbose + else if profile then .profile + else if quiet then .quiet + else .default + let skipVerification := pflags.getBool "skip-verification" + + -- Run the pipeline + let (outcome, laurelPassStats, pctx) ← Strata.Pipeline.runPyAnalyzePipeline { + filePath, specDir + dispatchModules, pyspecModules, sourcePath + keepAllFilesPrefix := keepPrefix + verifyOptions := options + entryPoint, isBugFinding + outputMode, skipVerification + metricsHandle + } + + -- Always print pipeline warnings + let msgs ← pctx.getMessages + if !quiet && msgs.size > 0 then + IO.eprintln s!"{msgs.size} pipeline warning(s)" + if verbose then + for err in msgs do + IO.eprintln s!" {err.file}: {err.phase}.{err.kind}: {err.message}" + + if profile && !laurelPassStats.data.isEmpty then + IO.println laurelPassStats.format + + -- Write outcome record to metrics file. + let emitOutcome (resultStr : String) (exitCode : UInt8) (detail : Option String := none) : IO Unit := do + let totalMs ← pctx.elapsedNs + let mut fields : List (String × Lean.Json) := [ + ("type", .str "outcome"), ("result", .str resultStr), + ("exit_code", .num exitCode.toNat), ("total_ms", .num (Strata.Pipeline.nsToMs totalMs))] + if let some d := detail then + fields := fields ++ [("detail", .str d)] + pctx.emitMetric (Lean.Json.mkObj fields) + + -- Handle pipeline outcome. + -- Exit code is f(outcome, messages) — see priority ordering in unify.md. + let toolErrors ← pctx.getToolErrors + let userErrors ← pctx.getUserCodeErrors + + -- Priority 1: internal/configuration errors always dominate + if let some lastErr := toolErrors.back? then + emitOutcome "internalError" ExitCode.internalError (detail := lastErr.message) + exitPyAnalyzeInternalError lastErr.message + -- Priority 2: user code errors + if let some lastErr := userErrors.back? then + emitOutcome "userError" ExitCode.userError (detail := lastErr.message) + let location ← reportUserCodeError lastErr.loc lastErr.message mfm (sourcePath.getD filePath) + exitPyAnalyzeUserError s!"{lastErr.message}{location}" + match outcome with + | .verified vcResults _coreProgram => + emitOutcome "verified" 0 + -- Print per-VC results by default, unless SARIF mode is used + if !outputSarif then + let mut s := "" + for vcResult in vcResults do + let fileMap := mfm.map (·.2) + let location := match Imperative.getFileRange vcResult.obligation.metadata with + | some fr => + if fr.range.isNone then "" + else s!"{fr.format fileMap (includeEnd? := false)}" + | none => "" + let messageSuffix := match vcResult.obligation.metadata.getPropertySummary with + | some msg => s!" - {msg}" + | none => s!" - {vcResult.obligation.label}" + let outcomeStr := vcResult.formatOutcome + let loc := if !location.isEmpty then s!"{location}: " else "unknown location: " + s := s ++ s!"{loc}{outcomeStr}{messageSuffix}\n" + IO.print s + -- Output in SARIF format if requested + if outputSarif then + let files := match mfm with + | some (pyPath, fm) => Map.empty.insert (Strata.Uri.file pyPath) fm + | none => Map.empty + Core.Sarif.writeSarifOutput options.checkMode files vcResults (filePath ++ ".sarif") + printPyAnalyzeSummary vcResults options.checkMode + | .failed => + -- Priority 4: known limitations + let knownLimitations := msgs.filter (·.kind.impact == .knownLimitation) + match knownLimitations.back? with + | some lastErr => + emitOutcome "knownLimitation" ExitCode.knownLimitation (detail := lastErr.message) + exitPyAnalyzeKnownLimitation lastErr.message + | none => + -- .failed with no classified impact = internal error + let msg : String := match msgs.back? with + | some m => m.message + | none => "Pipeline aborted" + emitOutcome "internalError" ExitCode.internalError (detail := msg) + exitPyAnalyzeInternalError msg + +def pyAnalyzeToGotoCommand : Command where + name := "pyAnalyzeToGoto" + args := [ "file" ] + help := "Translate a Strata Python Ion file to CProver GOTO JSON files." + callback := fun v _ => do + let filePath := v[0] + let pySourceOpt ← tryReadPythonSource filePath + let sourcePathForMetadata := match pySourceOpt with + | some (pyPath, _) => pyPath + | none => filePath + let sourceText := pySourceOpt.map (·.2) + let newPgm ← Strata.pythonDirectToCore filePath sourcePathForMetadata + match Core.inlineProcedures newPgm { doInline := (fun _caller callee _ => callee ≠ "main") } with + | .error e => exitInternalError (toString e) + | .ok newPgm => + -- Type-check the full program (registers Python types like ExceptOrNone) + let Ctx := { Lambda.LContext.default with functions := Strata.Python.PythonFactory, knownTypes := Core.KnownTypes } + let Env := Lambda.TEnv.default + let (tcPgm, _) ← match Core.Program.typeCheck Ctx Env newPgm with + | .ok r => pure r + | .error e => exitInternalError s!"{e.format none}" + -- Find the main procedure + let some mainDecl := tcPgm.decls.find? fun d => + match d with + | .proc p _ => Core.CoreIdent.toPretty p.header.name == "main" + | _ => false + | exitInternalError "No main procedure found" + let some p := mainDecl.getProc? + | exitInternalError "main is not a procedure" + -- Translate procedure to GOTO (mirrors CoreToGOTO.transformToGoto post-typecheck logic) + let baseName := deriveBaseName filePath + let procName := Core.CoreIdent.toPretty p.header.name + let axioms := tcPgm.decls.filterMap fun d => d.getAxiom? + let distincts := tcPgm.decls.filterMap fun d => match d with + | .distinct name es _ => some (name, es) | _ => none + match procedureToGotoCtx Env p sourceText (axioms := axioms) (distincts := distincts) + with + | .error e => exitInternalError s!"{e}" + | .ok (ctx, liftedFuncs) => + let extraSyms ← match collectExtraSymbols tcPgm with + | .ok s => pure (Lean.toJson s) + | .error e => exitInternalError s!"{e}" + let (symtab, goto) ← emitProcWithLifted Env procName ctx liftedFuncs extraSyms + (moduleName := baseName) + let symTabFile := s!"{baseName}.symtab.json" + let gotoFile := s!"{baseName}.goto.json" + writeJsonFile symTabFile symtab + writeJsonFile gotoFile goto + IO.println s!"Written {symTabFile} and {gotoFile}" + +def pyTranslateLaurelCommand : Command where + name := "pyTranslateLaurel" + args := [ "file" ] + flags := [{ name := "pyspec", + help := "PySpec module name (e.g., servicelib.Storage).", + takesArg := .repeat "module" }, + { name := "dispatch", + help := "Dispatch module name (e.g., servicelib).", + takesArg := .repeat "module" }, + { name := "spec-dir", + help := "Directory containing compiled PySpec Ion files.", + takesArg := .arg "dir" }] + help := "Translate a Strata Python Ion file through Laurel to Strata Core. Write results to stdout." + callback := fun v pflags => do + let dispatchModules := pflags.getRepeated "dispatch" + let pyspecModules := pflags.getRepeated "pyspec" + let specDir := pflags.getString "spec-dir" |>.getD "." + unless ← System.FilePath.isDir specDir do + exitFailure s!"spec-dir '{specDir}' does not exist or is not a directory" + let coreProgram ← + match ← Strata.pyTranslateLaurel v[0] dispatchModules pyspecModules (specDir := specDir) |>.toBaseIO with + | .ok r => pure r + | .error msg => exitFailure msg + IO.print coreProgram + +def pyAnalyzeLaurelToGotoCommand : Command where + name := "pyAnalyzeLaurelToGoto" + args := [ "file" ] + flags := [{ name := "pyspec", + help := "PySpec module name (e.g., servicelib.Storage).", + takesArg := .repeat "module" }, + { name := "dispatch", + help := "Dispatch module name (e.g., servicelib).", + takesArg := .repeat "module" }, + { name := "spec-dir", + help := "Directory containing compiled PySpec Ion files.", + takesArg := .arg "dir" }] + help := "Translate a Strata Python Ion file through Laurel to CProver GOTO JSON files." + callback := fun v pflags => do + let filePath := v[0] + let dispatchModules := pflags.getRepeated "dispatch" + let pyspecModules := pflags.getRepeated "pyspec" + let specDir := pflags.getString "spec-dir" |>.getD "." + unless ← System.FilePath.isDir specDir do + exitFailure s!"spec-dir '{specDir}' does not exist or is not a directory" + let (coreProgram, laurelTranslateErrors) ← + match ← Strata.pyTranslateLaurel filePath dispatchModules pyspecModules (specDir := specDir) |>.toBaseIO with + | .ok r => pure r + | .error msg => exitFailure msg + let sourceText := (← tryReadPythonSource filePath).map (·.2) + let baseName := deriveBaseName filePath + match ← Strata.inlineCoreToGotoFiles coreProgram baseName sourceText + (factory := Strata.Python.PythonFactory) |>.toBaseIO with + | .ok () => pure () + | .error msg => exitFailure msg + +def javaGenCommand : Command where + name := "javaGen" + args := [ "dialect", "package", "output-dir" ] + flags := [includeFlag] + help := "Generate Java source files from a DDM dialect definition. Accepts a dialect name (e.g. Laurel) or a dialect file path." + callback := fun v pflags => do + let fm ← pflags.buildDialectFileMap + let ld ← fm.getLoaded + let d ← if mem : v[0] ∈ ld.dialects then + pure ld.dialects[v[0]] + else + match ← Strata.readStrataFile fm v[0] with + | .dialect d => pure d + | .program _ => exitFailure "Expected a dialect file, not a program file." + match Strata.Java.generateDialect d v[1] with + | .ok files => + Strata.Java.writeJavaFiles v[2] v[1] files + IO.println s!"Generated Java files for {d.name} in {v[2]}/{Strata.Java.packageToPath v[1]}" + | .error msg => + exitFailure s!"Error generating Java: {msg}" + +def laurelAnalyzeBinaryCommand : Command where + name := "laurelAnalyzeBinary" + args := [] + flags := laurelVerifyOptionsFlags + help := "Verify Laurel Ion programs read from stdin and print diagnostics. Combines multiple input files." + callback := fun _ pflags => do + let options ← parseLaurelVerifyOptions pflags + let stdinBytes ← (← IO.getStdin).readBinToEnd + let combinedProgram ← Strata.readLaurelIonProgram stdinBytes + let diagnostics ← Strata.Laurel.verifyToDiagnosticModels combinedProgram options + + IO.println s!"==== DIAGNOSTICS ====" + for diag in diagnostics do + IO.println s!"{Std.format diag.fileRange.file}:{diag.fileRange.range.start}-{diag.fileRange.range.stop}: {diag.message}" + +def pySpecToLaurelCommand : Command where + name := "pySpecToLaurel" + args := [ "python_path", "strata_path" ] + help := "Translate a PySpec Ion file to Laurel declarations. The Ion file must already exist." + callback := fun v _ => do + let pythonFile : System.FilePath := v[0] + let strataDir : System.FilePath := v[1] + let some mod := pythonFile.fileStem + | exitFailure s!"No stem {pythonFile}" + let some mod := Strata.Python.ModuleName.ofString? mod + | exitFailure s!"Invalid module {mod}" + let ionFile := strataDir / mod.strataFileName + let sigs ← + match ← Strata.Python.Specs.readDDM ionFile |>.toBaseIO with + | .ok t => pure t + | .error msg => exitFailure s!"Could not read {ionFile}: {msg}" + let result := Strata.Python.Specs.ToLaurel.signaturesToLaurel pythonFile sigs mod + if result.errors.size > 0 then + IO.eprintln s!"{result.errors.size} translation warning(s):" + for err in result.errors do + IO.eprintln s!" {err.file}: {err.message}" + let pgm := result.program + IO.println s!"Laurel: {pgm.staticProcedures.length} procedure(s), {pgm.types.length} type(s)" + IO.println s!"Overloads: {result.overloads.size} function(s)" + for td in pgm.types do + IO.println s!" {Strata.Laurel.formatTypeDefinition td}" + for proc in pgm.staticProcedures do + IO.println s!" {Strata.Laurel.formatProcedure proc}" + +def pyResolveOverloadsCommand : Command where + name := "pyResolveOverloads" + args := [ "python_path", "dispatch_ion" ] + help := "Identify which overloaded service modules a \ + Python program uses. Prints one module name per \ + line to stdout." + callback := fun v _ => do + let pythonFile : System.FilePath := v[0] + let dispatchPath := v[1] + -- Read dispatch overload table + let pctx ← Strata.Pipeline.PipelineContext.create + let overloads ← match ← (readDispatchOverloads pctx #[dispatchPath]).toBaseIO with + | .ok r => pure r + | .error () => + for m in ← pctx.getMessages do + IO.eprintln s!"{m}" + exitFailure "readDispatchOverloads: fatal error" + -- Convert .py to Python AST + let stmts ← + IO.FS.withTempFile fun _handle dialectFile => do + IO.FS.writeBinFile dialectFile + Strata.Python.Python.toIon + match ← Strata.Python.pythonToStrata dialectFile pythonFile |>.toBaseIO with + | .ok s => pure s + | .error msg => exitFailure msg + -- Walk AST and collect modules + let state := + Strata.Python.Specs.IdentifyOverloads.resolveOverloads + overloads stmts + for w in state.warnings do + IO.eprintln s!"warning: {w}" + let sorted := state.modules.toArray.qsort (· < ·) + for m in sorted do + IO.println m + +def laurelParseCommand : Command where + name := "laurelParse" + args := [ "file" ] + help := "Parse a Laurel source file (no verification)." + callback := fun v _ => do + let _ ← Strata.readLaurelTextFile v[0] + IO.println "Parse successful" + +def laurelAnalyzeCommand : Command where + name := "laurelAnalyze" + args := [ "file" ] + flags := laurelVerifyOptionsFlags + help := "Analyze a Laurel source file. Write diagnostics to stdout." + callback := fun v pflags => do + let options ← parseLaurelVerifyOptions pflags + let laurelProgram ← Strata.readLaurelTextFile v[0] + let (vcResultsOption, errors) ← Strata.Laurel.verifyToVcResults laurelProgram options + if !errors.isEmpty then + IO.println s!"==== ERRORS ====" + for err in errors do + IO.println s!"{err.message}" + match vcResultsOption with + | none => return + | some vcResults => + IO.println s!"==== RESULTS ====" + for vc in vcResults do + IO.println s!"{vc.obligation.label}: {match vc.outcome with | .ok o => repr o | .error e => toString e}" + +def laurelAnalyzeToGotoCommand : Command where + name := "laurelAnalyzeToGoto" + args := [ "file" ] + help := "Translate a Laurel source file to CProver GOTO JSON files." + callback := fun v _ => do + let path : System.FilePath := v[0] + let content ← IO.FS.readFile path + let laurelProgram ← Strata.parseLaurelText path content + match ← Strata.Laurel.translate {} laurelProgram with + | (none, diags) => exitFailure s!"Core translation errors: {diags.map (·.message)}" + | (some coreProgram, errors) => + let Ctx := { Lambda.LContext.default with functions := Core.Factory, knownTypes := Core.KnownTypes } + let Env := Lambda.TEnv.default + let (tcPgm, _) ← match Core.Program.typeCheck Ctx Env coreProgram with + | .ok r => pure r + | .error e => exitInternalError s!"{e.format none}" + let procs := tcPgm.decls.filterMap fun d => d.getProc? + let funcs := tcPgm.decls.filterMap fun d => + match d.getFunc? with + | some f => + let name := Core.CoreIdent.toPretty f.name + if f.body.isSome && f.typeArgs.isEmpty + && name != "Int.DivT" && name != "Int.ModT" + then some f else none + | none => none + if procs.isEmpty && funcs.isEmpty then exitInternalError "No procedures or functions found" + let baseName := deriveBaseName path.toString + let typeSyms ← match collectExtraSymbols tcPgm with + | .ok s => pure s + | .error e => exitInternalError s!"{e}" + let typeSymsJson := Lean.toJson typeSyms + let sourceText := some content + let axioms := tcPgm.decls.filterMap fun d => d.getAxiom? + let distincts := tcPgm.decls.filterMap fun d => match d with + | .distinct name es _ => some (name, es) | _ => none + let mut symtabPairs : List (String × Lean.Json) := [] + let mut gotoFns : Array Lean.Json := #[] + let mut allLiftedFuncs : List Core.Function := [] + for p in procs do + let procName := Core.CoreIdent.toPretty p.header.name + match procedureToGotoCtx Env p (sourceText := sourceText) (axioms := axioms) (distincts := distincts) + with + | .error e => exitInternalError s!"{e}" + | .ok (ctx, liftedFuncs) => + allLiftedFuncs := allLiftedFuncs ++ liftedFuncs + let json ← IO.ofExcept (CoreToGOTO.CProverGOTO.Context.toJson procName ctx) + match json.symtab with + | .obj m => symtabPairs := symtabPairs ++ m.toList + | _ => pure () + match json.goto with + | .obj m => + match m.toList.find? (·.1 == "functions") with + | some (_, .arr fns) => gotoFns := gotoFns ++ fns + | _ => pure () + | _ => pure () + for f in funcs ++ allLiftedFuncs do + let funcName := Core.CoreIdent.toPretty f.name + match functionToGotoCtx Env f with + | .error e => exitInternalError s!"{e}" + | .ok ctx => + let json ← IO.ofExcept (CoreToGOTO.CProverGOTO.Context.toJson funcName ctx) + match json.symtab with + | .obj m => symtabPairs := symtabPairs ++ m.toList + | _ => pure () + match json.goto with + | .obj m => + match m.toList.find? (·.1 == "functions") with + | some (_, .arr fns) => gotoFns := gotoFns ++ fns + | _ => pure () + | _ => pure () + match typeSymsJson with + | .obj m => symtabPairs := symtabPairs ++ m.toList + | _ => pure () + -- Deduplicate: keep first occurrence of each symbol name (proper function + -- symbols come before basic symbol references from callers) + let mut seen : Std.HashSet String := {} + let mut dedupPairs : List (String × Lean.Json) := [] + for (k, v) in symtabPairs do + if !seen.contains k then + seen := seen.insert k + dedupPairs := dedupPairs ++ [(k, v)] + -- Add CBMC default symbols (architecture constants, builtins) + -- and wrap in {"symbolTable": ...} for symtab2gb + let symtabObj := dedupPairs.foldl + (fun (acc : Std.TreeMap.Raw String Lean.Json) (k, v) => acc.insert k v) + .empty + let symtab := CProverGOTO.wrapSymtab symtabObj (moduleName := baseName) + let goto := Lean.Json.mkObj [("functions", Lean.Json.arr gotoFns)] + let symTabFile := s!"{baseName}.symtab.json" + let gotoFile := s!"{baseName}.goto.json" + writeJsonFile symTabFile symtab + writeJsonFile gotoFile goto + IO.println s!"Written {symTabFile} and {gotoFile}" + +def laurelPrintCommand : Command where + name := "laurelPrint" + args := [] + help := "Read Laurel Ion from stdin and print in concrete syntax to stdout." + callback := fun _ _ => do + let stdinBytes ← (← IO.getStdin).readBinToEnd + let strataFiles ← Strata.readLaurelIonFiles stdinBytes + for strataFile in strataFiles do + IO.println s!"// File: {strataFile.filePath}" + let p := strataFile.program + let c := p.formatContext {} + let s := p.formatState + let fmt := p.commands.foldl (init := f!"") fun f cmd => + f ++ (Strata.mformat cmd c s).format + IO.println (fmt.pretty 100) + IO.println "" + +def prettyPrintCore (p : Core.Program) : String := + let decls := p.decls.map fun d => + let s := toString (Std.format d) + -- Add newlines after major sections in procedures + s.replace "preconditions:" "\n preconditions:" + |>.replace "postconditions:" "\n postconditions:" + |>.replace "body:" "\n body:\n " + |>.replace "assert [" "\n assert [" + |>.replace "init (" "\n init (" + |>.replace "while (" "\n while (" + |>.replace "if (" "\n if (" + |>.replace "call [" "\n call [" + |>.replace "else{" "\n else {" + |>.replace "}}" "}\n }" + String.intercalate "\n" decls + +def laurelToCoreCommand : Command where + name := "laurelToCore" + args := [ "file" ] + help := "Translate a Laurel source file to Core and print to stdout." + callback := fun v _ => do + let laurelProgram ← Strata.readLaurelTextFile v[0] + let (coreProgramOption, errors) ← Strata.Laurel.translate {} laurelProgram + if !errors.isEmpty then + IO.println s!"Core translation errors: {errors.map (·.message)}" + match coreProgramOption with + | none => return + | some coreProgram => IO.println (prettyPrintCore coreProgram) + +/-- Print a string word-wrapped to `width` columns with `indent` spaces of indentation. -/ +private def printIndented (indent : Nat) (s : String) (width : Nat := 80) : IO Unit := do + let pad := "".pushn ' ' indent + let words := s.splitOn " " |>.filter (!·.isEmpty) + let mut line := pad + let mut first := true + for word in words do + if first then + line := line ++ word + first := false + else if line.length + 1 + word.length > width then + IO.println line + line := pad ++ word + else + line := line ++ " " ++ word + unless line.length ≤ indent do + IO.println line + +structure CommandGroup where + name : String + commands : List Command + commonFlags : List Flag := [] + +private def validPasses := + "inlineProcedures, loopElim, callElim, filterProcedures, removeIrrelevantAxioms" + +/-- A single transform pass together with the `--procedures`/`--functions` + that were specified immediately after it on the command line. -/ +private structure PassConfig where + name : String + procedures : List String := [] + functions : List String := [] +deriving Inhabited + +/-- Walk the ordered flag entries and bind each `--procedures`/`--functions` + to the most recent `--pass`. -/ +private def buildPassConfigs (entries : Array (String × Option String)) + : IO (Array PassConfig) := do + let mut configs : Array PassConfig := #[] + for (flag, value) in entries do + match flag with + | "pass" => configs := configs.push { name := value.getD "" } + | "procedures" => + let some cur := configs.back? | exitFailure "--procedures must appear after a --pass" + let procs := (value.getD "").splitToList (· == ',') + configs := configs.pop.push { cur with procedures := cur.procedures ++ procs } + | "functions" => + let some cur := configs.back? | exitFailure "--functions must appear after a --pass" + let fns := (value.getD "").splitToList (· == ',') + configs := configs.pop.push { cur with functions := cur.functions ++ fns } + | _ => pure () + return configs + +def transformCommand : Command where + name := "transform" + args := [ "file" ] + flags := [ + { name := "pass", + help := s!"Transform pass to apply (repeatable, applied left to right). \ + Valid passes: {validPasses}. \ + --procedures and --functions after a --pass apply to that pass.", + takesArg := .repeat "name" }, + { name := "procedures", + help := "Comma-separated procedure names for the preceding --pass. \ + For filterProcedures: procedures to keep. \ + For inlineProcedures: procedures to inline.", + takesArg := .repeat "procs" }, + { name := "functions", + help := "Comma-separated function names for the preceding --pass (used by removeIrrelevantAxioms).", + takesArg := .repeat "funcs" }] + help := "Apply one or more transforms to a Core program and print the result." + callback := fun v pflags => do + let file := v[0] + let passConfigs ← buildPassConfigs pflags.entries + if passConfigs.isEmpty then + exitFailure s!"No --pass specified. Valid passes: {validPasses}." + -- Read and parse the Core program + let (pgm, _) ← match ← readStrataProgram file with + | .ok r => pure r + | .error msgs => + for e in msgs do println! s!"Error: {← e.toString}" + exitFailure s!"{msgs.size} parse error(s)" + match Strata.genericToCore pgm with + | .error msg => + exitFailure msg + | .ok initProgram => + -- Validate and convert pass configs to TransformPass values + let mut passes : List Strata.Core.TransformPass := [] + for pc in passConfigs do + match pc.name with + | "inlineProcedures" => + let opts : Core.InlineTransformOptions := + if pc.procedures.isEmpty then {} + else { doInline := (fun _caller callee _ => callee ∈ pc.procedures) } + passes := passes ++ [.inlineProcedures opts] + | "loopElim" => + passes := passes ++ [.loopElim] + | "callElim" => + passes := passes ++ [.callElim] + | "filterProcedures" => + if pc.procedures.isEmpty then + exitFailure "filterProcedures requires --procedures" + passes := passes ++ [.filterProcedures pc.procedures] + | "removeIrrelevantAxioms" => + if pc.functions.isEmpty then + exitFailure "removeIrrelevantAxioms requires --functions" + passes := passes ++ [.removeIrrelevantAxioms pc.functions] + | other => + exitFailure s!"Unknown pass '{other}'. Valid passes: {validPasses}." + -- Run all passes in a single CoreTransformM chain so fresh variable + -- counters accumulate and cached analyses are reused across passes. + match Strata.Core.runTransforms initProgram passes with + | .ok program => IO.print (Core.formatProgram program) + | .error e => exitFailure s!"Transform failed: {e}" + +def verifyCommand : Command where + name := "verify" + args := [ "file" ] + flags := verifyOptionsFlags ++ [ + { name := "check", help := "Process up until SMT generation, but don't solve." }, + { name := "type-check", help := "Exit after semantic dialect's type inference/checking." }, + { name := "parse-only", help := "Exit after DDM parsing and type checking." }, + { name := "output-format", help := "Output format (only 'sarif' supported).", takesArg := .arg "format" }, + { name := "procedures", help := "Verify only the specified procedures (comma-separated).", takesArg := .arg "procs" }] + help := "Verify a Strata program file (.core.st, .csimp.st, or .b3.st)." + callback := fun v pflags => do + let file := v[0] + let proceduresToVerify := pflags.getString "procedures" |>.map (·.splitToList (· == ',')) + let opts ← parseVerifyOptions pflags { VerifyOptions.default with verbose := .quiet } + let opts := { opts with + checkOnly := pflags.getBool "check", + typeCheckOnly := pflags.getBool "type-check", + parseOnly := pflags.getBool "parse-only", + outputSarif := opts.outputSarif || pflags.getString "output-format" == some "sarif" } + let (pgm, inputCtx) ← match ← readStrataProgram file with + | .ok r => pure r + | .error errors => + for e in errors do + let msg ← e.toString + println! s!"Error: {msg}" + println! f!"Finished with {errors.size} errors." + IO.Process.exit ExitCode.userError + println! s!"Successfully parsed." + if opts.parseOnly then return + if opts.typeCheckOnly then + let ans := if file.endsWith ".csimp.st" then + C_Simp.typeCheck pgm opts + else if pgm.dialect == "Boole" then + Boole.typeCheck pgm opts + else + typeCheck inputCtx pgm opts + match ans with + | .error e => + println! f!"{e.formatRange (some inputCtx.fileMap) true} {e.message}" + IO.Process.exit ExitCode.userError + | .ok _ => + println! f!"Program typechecked." + return + -- Full verification + let vcResults ← try + if file.endsWith ".csimp.st" then + C_Simp.verify pgm opts + else if file.endsWith ".b3.st" || file.endsWith ".b3cst.st" then + let ast ← match B3.Verifier.programToB3AST pgm with + | Except.error msg => throw (IO.userError s!"Failed to convert to B3 AST: {msg}") + | Except.ok ast => pure ast + let solver ← B3.Verifier.createInteractiveSolver opts.solver + let reports ← B3.Verifier.programToSMT ast solver + for report in reports do + IO.println s!"\nProcedure: {report.procedureName}" + for (result, _) in report.results do + let marker := if result.result.isError then "✗" else "✓" + let desc := match result.result with + | .error .counterexample => "counterexample found" + | .error .unknown => "unknown" + | .error .refuted => "refuted" + | .success .verified => "verified" + | .success .reachable => "reachable" + | .success .reachabilityUnknown => "reachability unknown" + IO.println s!" {marker} {desc}" + pure #[] + else if pgm.dialect == "Boole" then + Boole.verify opts.solver pgm inputCtx proceduresToVerify opts + else + verify pgm inputCtx proceduresToVerify opts + catch e => + println! f!"{e}" + IO.Process.exit ExitCode.internalError + if opts.outputSarif then + if file.endsWith ".csimp.st" then + println! "SARIF output is not supported for C_Simp files (.csimp.st) because location metadata is not preserved during translation to Core." + else + let uri := Strata.Uri.file file + let files := Map.empty.insert uri inputCtx.fileMap + Core.Sarif.writeSarifOutput opts.checkMode files vcResults (file ++ ".sarif") + for vcResult in vcResults do + let posStr := Imperative.MetaData.formatFileRangeD vcResult.obligation.metadata (some inputCtx.fileMap) + println! f!"{posStr} [{vcResult.obligation.label}]: \ + {vcResult.formatOutcome}" + let success := vcResults.all Core.VCResult.isSuccess + if success && !opts.checkOnly then + println! f!"All {vcResults.size} goals passed." + else if success && opts.checkOnly then + println! f!"Skipping verification." + else + let provedGoalCount := (vcResults.filter Core.VCResult.isSuccess).size + let failedGoalCount := (vcResults.filter Core.VCResult.isNotSuccess).size + -- Encoding failures, solver crashes, or per-check SMT errors (exit 3) + let hasImplError := vcResults.any (fun r => r.isImplementationError || r.hasSMTError) + -- Assertion violations that are not timeouts or internal errors (exit 2) + let hasFailure := vcResults.any (fun r => !r.isSuccess && !r.isTimeout && !r.isImplementationError && !r.hasSMTError) + println! f!"Finished with {provedGoalCount} goals passed, {failedGoalCount} failed." + if hasImplError then + IO.Process.exit ExitCode.internalError + else if hasFailure then + IO.Process.exit ExitCode.failuresFound + +def pyInterpretCommand : Command where + name := "pyInterpret" + args := [ "file" ] + flags := [{ name := "fuel", help := "Maximum execution steps.", takesArg := .arg "n" }] + ++ laurelTranslateFlags + help := "Interpret a Python Ion program concretely (Python → Laurel → Core → execute)." + callback := fun v pflags => do + let filePath := v[0] + let keepDir := pflags.getString "keep-all-files" + let fuel ← match pflags.getString "fuel" with + | some s => match s.toNat? with + | .some n => pure n + | .none => exitFailure s!"Invalid fuel: '{s}'" + | none => pure 10000 + + let quietCtx ← Strata.Pipeline.PipelineContext.create (outputMode := .quiet) + let (core, _diags) ← + match ← (Strata.pythonAndSpecToLaurel filePath (specDir := ".")).run quietCtx |>.toBaseIO with + | .ok laurel => + if let some dir := keepDir then + IO.FS.createDirAll dir + IO.FS.writeFile (dir ++ "/laurel.st") (toString (Std.format laurel)) + match ← Strata.translateCombinedLaurel laurel with + | (some core, diags) => pure (core, diags) + | (none, diags) => exitFailure s!"Laurel to Core translation failed: {diags}" + | .error () => + let msgs ← quietCtx.getMessages + let detail := match msgs.back? with | some m => m.message | none => "Pipeline aborted" + exitFailure detail + if let some dir := keepDir then + IO.FS.writeFile (dir ++ "/core.st") (toString (Std.format core)) + let core ← match Core.typeCheck Core.VerifyOptions.quiet core + (moreFns := Strata.Python.ReFactory) with + | .ok prog => pure prog + | .error e => + println! s!"Core type checking failed: {e.message}" + IO.Process.exit ExitCode.userError + match core.run with + | .ok E => + let mainProc := Core.Program.Procedure.find? core ⟨"__main__", ()⟩ + let outputNames := match mainProc with + | some p => p.header.outputs.keys.map (·.name) + | none => [] + let (lhs, exprEnv) := Core.Env.genVars outputNames E.exprEnv + let E := { E with exprEnv } + let E := Core.Statement.Command.runCall lhs "__main__" [] fuel E + match E.error with + | none => + IO.println "Execution completed successfully." + | some e => + IO.println s!"{Std.format e}" + IO.Process.exit ExitCode.failuresFound + | .error diag => + IO.eprintln s!"Error: {diag}" + IO.Process.exit ExitCode.failuresFound + +def commandGroups : List CommandGroup := [ + { name := "Core" + commands := [verifyCommand, transformCommand, checkCommand, toIonCommand, printCommand, diffCommand] + commonFlags := [includeFlag] }, + { name := "Code Generation" + commands := [javaGenCommand] }, + { name := "Python" + commands := [pyAnalyzeLaurelCommand, + pyResolveOverloadsCommand, + pySpecsCommand, pySpecToLaurelCommand, + pyAnalyzeLaurelToGotoCommand, + pyAnalyzeToGotoCommand, + pyTranslateLaurelCommand, + pyInterpretCommand] }, + { name := "Laurel" + commands := [laurelAnalyzeCommand, laurelAnalyzeBinaryCommand, + laurelAnalyzeToGotoCommand, laurelParseCommand, + laurelPrintCommand, laurelToCoreCommand] }, +] + +def commandList : List Command := + commandGroups.foldl (init := []) fun acc g => acc ++ g.commands + +def commandMap : Std.HashMap String Command := + commandList.foldl (init := {}) fun m c => m.insert c.name c + +/-- Print a single flag's name and help text at the given indentation. -/ +private def printFlag (indent : Nat) (flag : Flag) : IO Unit := do + let pad := "".pushn ' ' indent + match flag.takesArg with + | .arg argName | .repeat argName => + IO.println s!"{pad}--{flag.name} <{argName}> {flag.help}" + | .none => + IO.println s!"{pad}--{flag.name} {flag.help}" + +/-- Print help for all command groups. -/ +private def printGlobalHelp (groups : List CommandGroup := commandGroups) : IO Unit := do + IO.println "Usage: strata [flags]...\n" + IO.println "Command-line utilities for working with Strata.\n" + for group in groups do + IO.println s!"{group.name}:" + for cmd in group.commands do + let cmdLine := cmd.args.foldl (init := cmd.name) fun s a => s!"{s} <{a}>" + IO.println s!" {cmdLine}" + printIndented 4 cmd.help + let perCmdFlags := cmd.flags.filter fun f => + !group.commonFlags.any fun cf => cf.name == f.name + if !perCmdFlags.isEmpty then + IO.println "" + IO.println " Flags:" + for flag in perCmdFlags do + printFlag 6 flag + IO.println "" + if !group.commonFlags.isEmpty then + IO.println " Common flags:" + for flag in group.commonFlags do + printFlag 4 flag + IO.println "" + +/-- Print help for a single command. -/ +private def printCommandHelp (cmd : Command) : IO Unit := do + let cmdLine := cmd.args.foldl (init := s!"strata {cmd.name}") fun s a => s!"{s} <{a}>" + let flagSummary := cmd.flags.foldl (init := "") fun s f => + match f.takesArg with + | .arg argName | .repeat argName => s!"{s} [--{f.name} <{argName}>]" + | .none => s!"{s} [--{f.name}]" + IO.println s!"Usage: {cmdLine}{flagSummary}\n" + printIndented 0 cmd.help + if !cmd.flags.isEmpty then + IO.println "\nFlags:" + for flag in cmd.flags do + printFlag 2 flag + +/-- Parse interleaved flags and positional arguments. Returns the collected + positional arguments and parsed flags. -/ +private def parseArgs (cmdName : String) + (flagMap : Std.HashMap String Flag) + (acc : Array String) (pflags : ParsedFlags) + (cmdArgs : List String) : IO (Array String × ParsedFlags) := do + match cmdArgs with + | arg :: cmdArgs => + if arg.startsWith "--" then + let raw := (arg.drop 2).toString + -- Support --flag=value syntax by splitting on first '=' + let (flagName, inlineValue) ← match raw.splitOn "=" with + | name :: value :: rest => + if !rest.isEmpty then + exitCmdFailure cmdName s!"Invalid option format: {arg}. Values must not contain '='." + pure (name, some value) + | _ => pure (raw, none) + match flagMap[flagName]? with + | some flag => + match flag.takesArg with + | .none => + parseArgs cmdName flagMap acc (pflags.insert flagName Option.none) cmdArgs + | .arg _ => + match inlineValue with + | some value => + parseArgs cmdName flagMap acc (pflags.insert flagName (some value)) cmdArgs + | none => + let value :: cmdArgs := cmdArgs + | exitCmdFailure cmdName s!"Expected value after {arg}." + parseArgs cmdName flagMap acc (pflags.insert flagName (some value)) cmdArgs + | .repeat _ => + match inlineValue with + | some value => + parseArgs cmdName flagMap acc (pflags.insert flagName (some value)) cmdArgs + | none => + let value :: cmdArgs := cmdArgs + | exitCmdFailure cmdName s!"Expected value after {arg}." + parseArgs cmdName flagMap acc (pflags.insert flagName (some value)) cmdArgs + | none => + exitCmdFailure cmdName s!"Unknown option {arg}." + else + parseArgs cmdName flagMap (acc.push arg) pflags cmdArgs + | [] => + pure (acc, pflags) + +/-- Dispatch CLI arguments against a command map. This is the shared entry point + that both the default executable and downstream custom executables use. -/ +def runCommandMap (map : Std.HashMap String Command) + (groups : List CommandGroup) (args : List String) : IO Unit := do + try do + match args with + | ["--help"] => printGlobalHelp groups + | cmd :: args => + match map[cmd]? with + | none => exitFailure s!"Expected subcommand, got {cmd}." + | some cmd => + -- Handle per-command help before parsing flags. + if args.contains "--help" then + printCommandHelp cmd + return + -- Index the command's flags by name for O(1) lookup during parsing. + let flagMap : Std.HashMap String Flag := + cmd.flags.foldl (init := {}) fun m f => m.insert f.name f + -- Split raw args into positional arguments and parsed flags. + let (args, pflags) ← parseArgs cmd.name flagMap #[] {} args + if p : args.size = cmd.args.length then + cmd.callback ⟨args, p⟩ pflags + else + exitCmdFailure cmd.name s!"{cmd.name} expects {cmd.args.length} argument(s)." + | [] => do + exitFailure "Expected subcommand." + catch e => + exitFailure e.toString diff --git a/StrataTest/DDM/TypecheckSkip.lean b/StrataTest/DDM/TypecheckSkip.lean new file mode 100644 index 0000000000..eeb09f9730 --- /dev/null +++ b/StrataTest/DDM/TypecheckSkip.lean @@ -0,0 +1,218 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ +module + +import Strata.DDM.Integration.Lean + +/-! +# Tests for `dialect_option typecheck off;` + +When a dialect sets `dialect_option typecheck off;`, elaboration skips +`inferType` and `unifyTypes` for expression arguments. Implicit type +parameter slots are filled with anonymous type placeholders (`.tvar _ ""`). + +This allows programs to elaborate even when the type checker cannot infer +all type arguments — e.g., when a template-generated accessor with tvar +return type is composed with a polymorphic function that needs concrete +type arguments for unification. +-/ + +--------------------------------------------------------------------- +-- Dialect with typecheck ON (default). +-- Includes parameterized types (Lst), polymorphic functions with +-- implicit Type params (lst_select), and perField accessor templates +-- on parameterized datatypes (Maybe). +--------------------------------------------------------------------- + +#dialect +dialect TestTCOn; + +type Boole; +fn equal (tp : Type, a : tp, b : tp) : Boole => @[prec(15)] a " == " b; + +type Inte; +fn natToInt (n : Num) : Inte => n; + +type Lst (elem : Type); +fn lst_select (A : Type, s : Lst A, i : Inte) : A => + "Lst.sel" "(" s ", " i ")"; + +category Binding; +@[declare(name, tp)] +op mkBinding (name : Ident, tp : TypeP) : Binding => + @[prec(40)] name " : " tp; + +category Bindings; +@[scope(bindings)] +op mkBindings (bindings : CommaSepBy Binding) : Bindings => + " (" bindings ")"; + +category Constructor; +category ConstructorList; + +@[constructor(name, fields)] +op constructor_mk (name : Ident, fields : Option (CommaSepBy Binding)) : + Constructor => @[prec(50)] name "(" fields ")"; + +@[constructorListAtom(c)] +op constructorListAtom (c : Constructor) : ConstructorList => "\n " c; + +@[constructorListPush(cl, c)] +op constructorListPush (cl : ConstructorList, c : Constructor) + : ConstructorList => cl ",\n " c; + +category TypeVar; +@[declareTVar(name)] +op type_var (name : Ident) : TypeVar => name; + +category TypeArgs; +@[scope(args)] +op type_args (args : CommaSepBy TypeVar) : TypeArgs => "<" args ">"; + +category DatatypeDecl; +metadata declareDatatype (name : Ident, typeParams : Ident, + constructors : Ident, accessorTemplate : FunctionTemplate); + +@[declareDatatype(name, typeParams, constructors, + perField([.datatype, .literal "..", .field], + [.datatype], .fieldType))] +op datatype_decl (name : Ident, + typeParams : Option Bindings, + @[scopeTVar(typeParams)] constructors : ConstructorList) + : DatatypeDecl => + "datatype " name typeParams " {" constructors "\n}"; + +@[scope(datatypes), preRegisterTypes(datatypes)] +op command_datatypes (datatypes : NewlineSepBy DatatypeDecl) : Command => + datatypes ";\n"; + +@[declare(name, r)] +op command_constdecl (name : Ident, r : Type) : Command => + "const " name ":" r ";\n"; + +category Label; +op label (l : Ident) : Label => "[" l "]: "; + +category Statement; +category Block; + +op assert_stmt (label : Option Label, c : Boole) : Statement => + "assert " label c ";\n"; + +@[scope(c)] +op block (c : SemicolonSepBy Statement) : Block => + "{\n " indent(2, c) "}"; + +op command_procedure (name : Ident, + b : Bindings, + @[scope(b)] body : Block) : + Command => + "procedure " name b " returns ()\n" body ";\n"; +#end + +--------------------------------------------------------------------- +-- Same dialect with typecheck OFF. +-- Imports all declarations from TestTCOn but disables type checking. +-- The typecheck flag is a property of the program's primary dialect; +-- imported dialects' flags are not consulted during elaboration. +--------------------------------------------------------------------- + +#dialect +dialect TestTCOff; +import TestTCOn; +dialect_option typecheck off; +#end + +--------------------------------------------------------------------- +-- Test 1: Accessor result feeds into polymorphic fn. +-- +-- `Maybe..val(m)` returns `tvar "a"` (unresolved) because the +-- accessor template stores its type with tvars. When this flows into +-- `lst_select`, the type checker cannot infer the implicit `A : Type` +-- parameter via unification, producing an error. +-- +-- With typecheck off, no unification is attempted — the implicit type +-- param is filled with a skip placeholder and elaboration succeeds. +--------------------------------------------------------------------- + +/-- +error: Could not infer type parameter 2 for TestTCOn.lst_select +--- +error: Expression has type Inte when .|| expected. +-/ +#guard_msgs in +def typecheckOnFails := +#strata +program TestTCOn; + +datatype Maybe (a : Type) { Nothing(), Just(val: a) }; + +const m: Maybe (Lst Inte); + +procedure Test () returns () +{ + assert [t1]: Lst.sel(Maybe..val(m), 0) == 0; +}; +#end + +-- Same program with typecheck off — elaboration succeeds because +-- inferType/unifyTypes are skipped entirely for expression arguments. +def typecheckOffSucceeds := +#strata +program TestTCOff; + +datatype Maybe (a : Type) { Nothing(), Just(val: a) }; + +const m: Maybe (Lst Inte); + +procedure Test () returns () +{ + assert [t1]: Lst.sel(Maybe..val(m), 0) == 0; +}; +#end + +--------------------------------------------------------------------- +-- Test 2: Unresolved identifiers still fail with typecheck off. +-- +-- `typecheck off` only skips type inference/unification — name +-- resolution still operates normally. +--------------------------------------------------------------------- + +/-- +error: Unknown expr identifier undefined_name +-/ +#guard_msgs in +def typecheckOffStillCatchesUndefined := +#strata +program TestTCOff; + +procedure Test () returns () +{ + assert [t1]: undefined_name == 0; +}; +#end + +--------------------------------------------------------------------- +-- Test 3: Invalid dialect_option values produce clean errors. +--------------------------------------------------------------------- + +/-- +error: Expected 'on' or 'off' for option 'typecheck'. +-/ +#guard_msgs in +#dialect +dialect BadOptionValue; +dialect_option typecheck maybe; +#end + +/-- +error: Unknown option 'nonsense'. +-/ +#guard_msgs in +#dialect +dialect BadOptionName; +dialect_option nonsense on; +#end diff --git a/StrataTest/DL/Imperative/StepStmtTest.lean b/StrataTest/DL/Imperative/StepStmtTest.lean index ec26045867..15734ec0cc 100644 --- a/StrataTest/DL/Imperative/StepStmtTest.lean +++ b/StrataTest/DL/Imperative/StepStmtTest.lean @@ -508,7 +508,7 @@ theorem loopScopeTest : -- Need to reconcile the env shape. conv => rhs; rw [show Env.mk storeWithX miniEval false = { Env.mk (projectStore storeWithX storeWithXY) miniEval false with - hasFailure := false || false } from by simp [hproj, Bool.or_false]] + hasFailure := false || false } from by simp [hproj]] exact .step _ _ _ StepStmt.step_stmts_nil (.refl _) --------------------------------------------------------------------- diff --git a/StrataTest/DL/Imperative/Verify.lean b/StrataTest/DL/Imperative/Verify.lean index ad5a9a6f11..49d54578c7 100644 --- a/StrataTest/DL/Imperative/Verify.lean +++ b/StrataTest/DL/Imperative/Verify.lean @@ -7,6 +7,7 @@ import StrataTest.DL.Imperative.DDMTranslate import StrataTest.DL.Imperative.SMTEncoder import Strata.DL.Imperative.SMTUtils +import Strata.Pipeline.Messages --------------------------------------------------------------------- namespace Arith @@ -27,6 +28,7 @@ def typedVarToSMT (v : String) (ty : Ty) : Except Format (String × Strata.SMT.T def verify (cmds : Commands) (verbose : Bool) : EIO Format (Imperative.VCResults Arith.PureExpr) := do + let pctx ← Strata.Pipeline.PipelineContext.create (outputMode := .quiet) (profilePipeline := false) match typeCheckAndPartialEval cmds with | .error err => .error s!"[Strata.Arith.verify] Error during evaluation!\n\ @@ -56,7 +58,7 @@ def verify (cmds : Commands) (verbose : Bool) : -- (FIXME) ((Arith.Eval.ProofObligation.freeVars obligation).map (fun v => (v, Arith.Ty.Num))) "cvc5" filename.toString - #["--produce-models"] false false true false) + #["--produce-models"] false false true false pctx) match ans with | Except.ok (_, result, estate) => let vcres := { obligation, result, estate } diff --git a/StrataTest/DL/SMT/TranslateTests.lean b/StrataTest/DL/SMT/TranslateTests.lean index 43c7c93611..aee007efbc 100644 --- a/StrataTest/DL/SMT/TranslateTests.lean +++ b/StrataTest/DL/SMT/TranslateTests.lean @@ -140,3 +140,87 @@ info: ∀ (α : Type → Type → Type) [inst : ∀ (α_1 α_2 : Type), Nonempty [(.app .str_concat [(.prim (.bool true)), (.prim (.string "hi"))] (.prim .string)), (.prim (.string "hi"))] (.prim .bool)) + +-- `leftAssocOp` and `leftAssocOpBitVec` require at least two operands, +-- matching the existing error message and `Denote.leftAssoc`. Singletons +-- used to silently pass through the first operand; now they throw. + +/-- error: Error: expected at least two arguments for 'HAdd.hAdd', got '1' -/ +#guard_msgs in +#eval + elabQuery {} [] + (.app .eq + [(.app .add [(.prim (.int 0))] (.prim .int)), + (.prim (.int 0))] + (.prim .bool)) + +/-- error: Error: expected at least two arguments for 'And', got '1' -/ +#guard_msgs in +#eval + elabQuery {} [] + (.app .and + [(.app .eq [(.prim (.int 0)), (.prim (.int 0))] (.prim .bool))] + (.prim .bool)) + +/-- error: Error: expected at least two arguments for BitVec op, got '1' -/ +#guard_msgs in +#eval + let a : SMT.TermVar := { id := "a", ty := .prim (.bitvec 8) } + elabQuery {} [] + (.quant .all [a] a + (.app .eq + [(.app .bvadd [(.var a)] (.prim (.bitvec 8))), (.var a)] + (.prim .bool))) + +/-- error: Error: expected at least two arguments for BitVec op, got '1' -/ +#guard_msgs in +#eval + let a : SMT.TermVar := { id := "a", ty := .prim (.bitvec 8) } + elabQuery {} [] + (.quant .all [a] a + (.app .eq + [(.app .bvand [(.var a)] (.prim (.bitvec 8))), (.var a)] + (.prim .bool))) + +-- Empty-operand lists are still rejected, as before. + +/-- error: Error: expected at least two arguments for 'HAdd.hAdd', got '0' -/ +#guard_msgs in +#eval + elabQuery {} [] + (.app .eq [(.app .add [] (.prim .int)), (.prim (.int 0))] (.prim .bool)) + +-- Binary and ternary uses still produce the expected left-associated Expr. + +/-- info: 1 + 2 + 3 = 6 -/ +#guard_msgs in +#eval + elabQuery {} [] + (.app .eq + [(.app .add + [(.prim (.int 1)), (.prim (.int 2)), (.prim (.int 3))] + (.prim .int)), + (.prim (.int 6))] + (.prim .bool)) + +-- `.app .mod` is strictly binary in the SMT-Lib `Ints` theory and in +-- `Denote.denoteTerm`, so `translateTerm` now rejects any other arity rather +-- than silently lowering e.g. `.app .mod [x, y, z]` to `(x % y) % z`. + +/-- info: 10 % 3 = 1 -/ +#guard_msgs in +#eval + elabQuery {} [] + (.app .eq + [(.app .mod [(.prim (.int 10)), (.prim (.int 3))] (.prim .int)), + (.prim (.int 1))] + (.prim .bool)) + +/-- error: Error: 'mod' expects exactly two operands, got '3' -/ +#guard_msgs in +#eval + elabQuery {} [] + (.app .eq + [(.app .mod [(.prim (.int 10)), (.prim (.int 3)), (.prim (.int 2))] (.prim .int)), + (.prim (.int 1))] + (.prim .bool)) diff --git a/StrataTest/Languages/Boole/FeatureRequests/seq_slicing.lean b/StrataTest/Languages/Boole/FeatureRequests/seq_slicing.lean index e1287844bc..c4b2a152f2 100644 --- a/StrataTest/Languages/Boole/FeatureRequests/seq_slicing.lean +++ b/StrataTest/Languages/Boole/FeatureRequests/seq_slicing.lean @@ -73,7 +73,44 @@ rec function reconstruct(naf: Sequence int) : int ; #end -/-- info: +/-- +info: +Obligation: seq_slicing_seed_post_seq_slicing_seed_ensures_2_1470_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: seq_slicing_seed_post_seq_slicing_seed_ensures_5_1607_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: seq_slicing_seed_post_seq_slicing_seed_ensures_5_1607_calls_Sequence.select_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: seq_slicing_seed_post_seq_slicing_seed_ensures_6_1667_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: seq_slicing_seed_post_seq_slicing_seed_ensures_6_1667_calls_Sequence.select_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_head_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_tail_calls_Sequence.drop_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_mid_calls_Sequence.drop_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_mid_calls_Sequence.take_1 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: seq_slicing_seed_ensures_2_1470 Property: assert Result: ✅ pass @@ -94,6 +131,10 @@ Obligation: seq_slicing_seed_ensures_6_1667 Property: assert Result: ✅ pass +Obligation: seq_empty_bv64_seed_post_seq_empty_bv64_seed_ensures_8_1938_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: seq_empty_bv64_seed_ensures_7_1903 Property: assert Result: ✅ pass @@ -102,13 +143,22 @@ Obligation: seq_empty_bv64_seed_ensures_8_1938 Property: assert Result: ✅ pass +Obligation: reconstruct_body_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: reconstruct_body_calls_Sequence.drop_1 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: reconstruct_terminates_0 Property: assert Result: ✅ pass Obligation: reconstruct_terminates_1 Property: assert -Result: ✅ pass-/ +Result: ✅ pass +-/ #guard_msgs in #eval Strata.Boole.verify "cvc5" seqSlicingSeed (options := .quiet) @@ -133,7 +183,11 @@ spec { #end /-- info: -Obligation: seq_oob_seed_ensures_0_3481 +Obligation: set_v_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ❓ unknown + +Obligation: seq_oob_seed_ensures_0_4964 Property: assert Result: ❓ unknown-/ #guard_msgs in diff --git a/StrataTest/Languages/Boole/FeatureRequests/sha256_compact_indexed.lean b/StrataTest/Languages/Boole/FeatureRequests/sha256_compact_indexed.lean index 47ded901b9..a8a818df42 100644 --- a/StrataTest/Languages/Boole/FeatureRequests/sha256_compact_indexed.lean +++ b/StrataTest/Languages/Boole/FeatureRequests/sha256_compact_indexed.lean @@ -27,19 +27,21 @@ program Boole; type nat; function int_to_nat (i : int) : nat; type Set (T : Type); - function Seq_len (s : Sequence T) : nat { + function Seq_len (s : Sequence bv32) : nat { int_to_nat(Sequence.length(s)) } - function Seq_lib_insert (s : Sequence T, i : int, val : T) : Sequence T { + function Seq_lib_insert (s : Sequence bv32, i : int, val : bv32) : Sequence bv32 + requires 0 <= i && i <= Sequence.length(s); + { Sequence.append(Sequence.build(Sequence.take(s, i), val), Sequence.drop(s, i)) } - function Seq_new (len : nat, f : int -> T) : Sequence T; - function Seq_lib_map (s : Sequence T, f : int -> T -> U) : Sequence U; - function Seq_lib_map_values (s : Sequence T, f : T -> U) : Sequence U; - function Seq_lib_filter (s : Sequence T, p : T -> bool) : Sequence T; - function Seq_lib_sort_by (s : Sequence T, less : T -> T -> bool) : Sequence T; - function Seq_lib_to_set (s : Sequence T) : Set T; - function Set_finite (s : Set T) : bool; + function Seq_new (len : nat, f : int -> bv32) : Sequence bv32; + function Seq_lib_map (s : Sequence bv32, f : int -> bv32 -> bv32) : Sequence bv32; + function Seq_lib_map_values (s : Sequence bv32, f : bv32 -> bv32) : Sequence bv32; + function Seq_lib_filter (s : Sequence bv32, p : bv32 -> bool) : Sequence bv32; + function Seq_lib_sort_by (s : Sequence bv32, less : bv32 -> bv32 -> bool) : Sequence bv32; + function Seq_lib_to_set (s : Sequence bv32) : Set bv32; + function Set_finite (s : Set bv32) : bool; function bv8_to_bv32_u (x : bv8) : bv32; function k32 () : Sequence bv32 { Sequence.of_bv32[bv{32}(1116352408), bv{32}(1899447441), bv{32}(3049323471), bv{32}(3921009573), bv{32}(961987163), bv{32}(1508970993), bv{32}(2453635748), bv{32}(2870763221), bv{32}(3624381080), bv{32}(310598401), bv{32}(607225278), bv{32}(1426881987), bv{32}(1925078388), bv{32}(2162078206), bv{32}(2614888103), bv{32}(3248222580), bv{32}(3835390401), bv{32}(4022224774), bv{32}(264347078), bv{32}(604807628), bv{32}(770255983), bv{32}(1249150122), bv{32}(1555081692), bv{32}(1996064986), bv{32}(2554220882), bv{32}(2821834349), bv{32}(2952996808), bv{32}(3210313671), bv{32}(3336571891), bv{32}(3584528711), bv{32}(113926993), bv{32}(338241895), bv{32}(666307205), bv{32}(773529912), bv{32}(1294757372), bv{32}(1396182291), bv{32}(1695183700), bv{32}(1986661051), bv{32}(2177026350), bv{32}(2456956037), bv{32}(2730485921), bv{32}(2820302411), bv{32}(3259730800), bv{32}(3345764771), bv{32}(3516065817), bv{32}(3600352804), bv{32}(4094571909), bv{32}(275423344), bv{32}(430227734), bv{32}(506948616), bv{32}(659060556), bv{32}(883997877), bv{32}(958139571), bv{32}(1322822218), bv{32}(1537002063), bv{32}(1747873779), bv{32}(1955562222), bv{32}(2024104815), bv{32}(2227730452), bv{32}(2361852424), bv{32}(2428436474), bv{32}(2756734187), bv{32}(3204031479), bv{32}(3329325298)] @@ -55,11 +57,16 @@ spec { }; procedure to_u32s (block : Sequence bv8) returns (_pct_return : (Sequence bv32)) spec { + requires Sequence.length(block) >= 64; + ensures Sequence.length(_pct_return) == 16; } { var j : int; var res : (Sequence bv32); res := Sequence.of_bv32[bv{32}(0), bv{32}(0), bv{32}(0), bv{32}(0), bv{32}(0), bv{32}(0), bv{32}(0), bv{32}(0), bv{32}(0), bv{32}(0), bv{32}(0), bv{32}(0), bv{32}(0), bv{32}(0), bv{32}(0), bv{32}(0)]; - for i : int := 0 to 16 - 1{ + for i : int := 0 to 16 - 1 + invariant 0 <= i + invariant Sequence.length(res) == 16 + { j := i * 4; assert 0 <= 24 && 24 < 32; assert 0 <= 16 && 16 < 32; @@ -71,6 +78,8 @@ spec { }; procedure compress_u32 (state : Sequence bv32, block : Sequence bv32) returns (state_out : (Sequence bv32)) spec { + requires Sequence.length(state) >= 8 && Sequence.length(block) >= 16; + ensures Sequence.length(state_out) == Sequence.length(state); } { var tmp15 : bv32; var tmp16 : bv32; @@ -113,7 +122,11 @@ spec { f := Sequence.select(state_out, 5); g := Sequence.select(state_out, 6); h := Sequence.select(state_out, 7); - for i : int := 0 to 64 - 1{ + for i : int := 0 to 64 - 1 + invariant 0 <= i + invariant Sequence.length(block_local) >= 16 + invariant Sequence.length(state_out) >= 8 + { if (i < 16) { tmp36 := Sequence.select(block_local, i); } else { @@ -176,10 +189,15 @@ spec { }; procedure compress (state : Sequence bv32, blocks : Sequence (Sequence bv8)) returns (state_out : (Sequence bv32)) spec { + requires Sequence.length(state) >= 8; + requires ∀ k:int . 0 <= k && k < Sequence.length(blocks) ==> Sequence.length(Sequence.select(blocks, k)) >= 64; } { var tmp6 : (Sequence bv32); state_out := state; - for k : int := 0 to Sequence.length(blocks) - 1{ + for k : int := 0 to Sequence.length(blocks) - 1 + invariant 0 <= k + invariant Sequence.length(state_out) >= 8 + { call tmp6 := to_u32s(Sequence.select(blocks, k)); call state_out := compress_u32(state_out, tmp6); @@ -195,71 +213,303 @@ spec { #end /-- info: -Obligation: assert_1_3021 +Obligation: Seq_lib_insert_body_calls_Sequence.take_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: Seq_lib_insert_body_calls_Sequence.drop_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: assert_2_3107 +Property: assert +Result: ✅ pass + +Obligation: assert_3_3150 +Property: assert +Result: ✅ pass + +Obligation: entry_invariant_0_0 +Property: assert +Result: ✅ pass + +Obligation: entry_invariant_0_1 +Property: assert +Result: ✅ pass + +Obligation: assert_6_3830 +Property: assert +Result: ✅ pass + +Obligation: assert_7_3861 +Property: assert +Result: ✅ pass + +Obligation: assert_8_3892 +Property: assert +Result: ✅ pass + +Obligation: set_res_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_res_calls_Sequence.select_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_res_calls_Sequence.select_2 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_res_calls_Sequence.select_3 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_res_calls_Sequence.update_4 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: arbitrary_iter_maintain_invariant_0_0 +Property: assert +Result: ✅ pass + +Obligation: arbitrary_iter_maintain_invariant_0_1 +Property: assert +Result: ✅ pass + +Obligation: to_u32s_ensures_5_3422 +Property: assert +Result: ✅ pass + +Obligation: set_a_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_b_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_c_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_d_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_e_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_f_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_g_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_h_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: entry_invariant_0_0 +Property: assert +Result: ✅ pass + +Obligation: entry_invariant_0_1 +Property: assert +Result: ✅ pass + +Obligation: entry_invariant_0_2 +Property: assert +Result: ✅ pass + +Obligation: set_tmp36_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_w15_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: callElimAssert_rotate_right_requires_1_3056_39 Property: assert Result: ✅ pass -Obligation: assert_2_3064 +Obligation: callElimAssert_rotate_right_requires_1_3056_35 Property: assert Result: ✅ pass -Obligation: assert_3_3596 +Obligation: assert_12_5815 Property: assert Result: ✅ pass -Obligation: assert_4_3627 +Obligation: set_w2_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: callElimAssert_rotate_right_requires_1_3056_31 Property: assert Result: ✅ pass -Obligation: assert_5_3658 +Obligation: callElimAssert_rotate_right_requires_1_3056_27 +Property: assert +Result: ✅ pass + +Obligation: assert_13_6054 +Property: assert +Result: ✅ pass + +Obligation: set_new_w_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_new_w_calls_Sequence.select_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_block_local_calls_Sequence.update_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: callElimAssert_rotate_right_requires_1_3056_23 Property: assert Result: ✅ pass -Obligation: callElimAssert_rotate_right_requires_0_2970_39 +Obligation: callElimAssert_rotate_right_requires_1_3056_19 Property: assert Result: ✅ pass -Obligation: callElimAssert_rotate_right_requires_0_2970_35 +Obligation: callElimAssert_rotate_right_requires_1_3056_15 Property: assert Result: ✅ pass -Obligation: assert_7_5332 +Obligation: set_t1_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: callElimAssert_rotate_right_requires_1_3056_11 Property: assert Result: ✅ pass -Obligation: callElimAssert_rotate_right_requires_0_2970_31 +Obligation: callElimAssert_rotate_right_requires_1_3056_7 Property: assert Result: ✅ pass -Obligation: callElimAssert_rotate_right_requires_0_2970_27 +Obligation: callElimAssert_rotate_right_requires_1_3056_3 Property: assert Result: ✅ pass -Obligation: assert_8_5571 +Obligation: arbitrary_iter_maintain_invariant_0_0 Property: assert Result: ✅ pass -Obligation: callElimAssert_rotate_right_requires_0_2970_23 +Obligation: arbitrary_iter_maintain_invariant_0_1 Property: assert Result: ✅ pass -Obligation: callElimAssert_rotate_right_requires_0_2970_19 +Obligation: arbitrary_iter_maintain_invariant_0_2 Property: assert Result: ✅ pass -Obligation: callElimAssert_rotate_right_requires_0_2970_15 +Obligation: set_state_out_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.update_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.update_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.update_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.update_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.update_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.update_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.update_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_state_out_calls_Sequence.update_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: compress_u32_ensures_11_4416 +Property: assert +Result: ✅ pass + +Obligation: compress_pre_compress_requires_16_7819_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: entry_invariant_0_0 +Property: assert +Result: ✅ pass + +Obligation: entry_invariant_0_1 +Property: assert +Result: ✅ pass + +Obligation: init_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: callElimAssert_to_u32s_requires_4_3381_47 Property: assert Result: ✅ pass -Obligation: callElimAssert_rotate_right_requires_0_2970_11 +Obligation: callElimAssert_compress_u32_requires_10_4344_43 Property: assert Result: ✅ pass -Obligation: callElimAssert_rotate_right_requires_0_2970_7 +Obligation: arbitrary_iter_maintain_invariant_0_0 Property: assert Result: ✅ pass -Obligation: callElimAssert_rotate_right_requires_0_2970_3 +Obligation: arbitrary_iter_maintain_invariant_0_1 Property: assert Result: ✅ pass-/ #guard_msgs in diff --git a/StrataTest/Languages/Boole/otp.lean b/StrataTest/Languages/Boole/otp.lean index d60432fb8f..a066a21ad4 100644 --- a/StrataTest/Languages/Boole/otp.lean +++ b/StrataTest/Languages/Boole/otp.lean @@ -311,6 +311,34 @@ spec /-- info: +Obligation: Encrypt_post_Encrypt_ensures_4_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: Encrypt_post_Encrypt_ensures_4_calls_Sequence.select_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: Encrypt_post_Encrypt_ensures_4_calls_Sequence.select_2 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_cipher_calls_Sequence.take_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: loop_invariant_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: loop_invariant_calls_Sequence.select_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: loop_invariant_calls_Sequence.select_2 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: entry_invariant_0_0 Property: assert Result: ✅ pass @@ -331,6 +359,14 @@ Obligation: measure_lb_0 Property: assert Result: ✅ pass +Obligation: set_cipher_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_cipher_calls_Sequence.select_1 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: arbitrary_iter_maintain_invariant_0_0 Property: assert Result: ✅ pass @@ -359,6 +395,34 @@ Obligation: Encrypt_ensures_4 Property: assert Result: ✅ pass +Obligation: Decrypt_post_Decrypt_ensures_4_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: Decrypt_post_Decrypt_ensures_4_calls_Sequence.select_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: Decrypt_post_Decrypt_ensures_4_calls_Sequence.select_2 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_result_calls_Sequence.take_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: loop_invariant_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: loop_invariant_calls_Sequence.select_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: loop_invariant_calls_Sequence.select_2 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: entry_invariant_0_0 Property: assert Result: ✅ pass @@ -379,6 +443,14 @@ Obligation: measure_lb_0 Property: assert Result: ✅ pass +Obligation: set_result_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: set_result_calls_Sequence.select_1 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: arbitrary_iter_maintain_invariant_0_0 Property: assert Result: ✅ pass @@ -407,6 +479,14 @@ Obligation: Decrypt_ensures_4 Property: assert Result: ✅ pass +Obligation: RoundTrip_post_RoundTrip_ensures_4_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: RoundTrip_post_RoundTrip_ensures_4_calls_Sequence.select_1 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: callElimAssert_Encrypt_requires_0_13 Property: assert Result: ✅ pass @@ -419,6 +499,18 @@ Obligation: callElimAssert_Encrypt_requires_2_15 Property: assert Result: ✅ pass +Obligation: assume_callElimAssume_Encrypt_ensures_4_17_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: assume_callElimAssume_Encrypt_ensures_4_17_calls_Sequence.select_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: assume_callElimAssume_Encrypt_ensures_4_17_calls_Sequence.select_2 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: callElimAssert_Decrypt_requires_0_4 Property: assert Result: ✅ pass @@ -431,6 +523,18 @@ Obligation: callElimAssert_Decrypt_requires_2_6 Property: assert Result: ✅ pass +Obligation: assume_callElimAssume_Decrypt_ensures_4_8_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: assume_callElimAssume_Decrypt_ensures_4_8_calls_Sequence.select_1 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: assume_callElimAssume_Decrypt_ensures_4_8_calls_Sequence.select_2 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: RoundTrip_ensures_3 Property: assert Result: ✅ pass diff --git a/StrataTest/Languages/Core/Examples/AdvancedMaps.lean b/StrataTest/Languages/Core/Examples/AdvancedMaps.lean index 1de2d946c6..1af27aa0fc 100644 --- a/StrataTest/Languages/Core/Examples/AdvancedMaps.lean +++ b/StrataTest/Languages/Core/Examples/AdvancedMaps.lean @@ -59,15 +59,15 @@ spec { requires [P_requires_1]: c[0] == a; } { assert [c_0_eq_a]: c[0] == a; - c := c[1:=a]; + c[1] := a; assert [c_1_eq_a]: c[1] == a; assert [a0eq0]: a[0] == 0; - a := a[1:=1]; + a[1] := 1; assert [a1eq1]: a[1] == 1; - a := a[0:=1]; + a[0] := 1; assert [a0eq1]: a[0] == 1; assert [a0neq2]: !(a[0] == 2); - b := b[true:=-1]; + b[true] := -1; assert [bTrueEqTrue]: b[true] == -1; assert [mix]: a[1] == -(b[true]); }; diff --git a/StrataTest/Languages/Core/Examples/AdvancedQuantifiers.lean b/StrataTest/Languages/Core/Examples/AdvancedQuantifiers.lean index 758a8a0a5f..aac3dcd880 100644 --- a/StrataTest/Languages/Core/Examples/AdvancedQuantifiers.lean +++ b/StrataTest/Languages/Core/Examples/AdvancedQuantifiers.lean @@ -32,14 +32,14 @@ VCs: Label: a Property: assert Assumptions: -mapAllValues0: forall __q0 : (Map int int) :: forall __q1 : int :: __q0[__q1] == 0 +mapAllValues0: forall m : (Map int int) :: forall k : int :: m[k] == 0 Obligation: mArg@1[kArg@1] == 0 Label: Update_ensures_0 Property: assert Assumptions: -mapAllValues0: forall __q0 : (Map int int) :: forall __q1 : int :: __q0[__q1] == 0 +mapAllValues0: forall m : (Map int int) :: forall k : int :: m[k] == 0 Obligation: mArg@1[kArg@1] == 0 diff --git a/StrataTest/Languages/Core/Examples/Axioms.lean b/StrataTest/Languages/Core/Examples/Axioms.lean index a763936907..6ed5ff4f5c 100644 --- a/StrataTest/Languages/Core/Examples/Axioms.lean +++ b/StrataTest/Languages/Core/Examples/Axioms.lean @@ -51,7 +51,7 @@ Property: assert Assumptions: a1: x == 5 a2: y == 2 -f1: forall __q0 : int :: f(__q0) > __q0 +f1: forall y : int :: f(y) > y Obligation: x > y @@ -60,7 +60,7 @@ Property: assert Assumptions: a1: x == 5 a2: y == 2 -f1: forall __q0 : int :: f(__q0) > __q0 +f1: forall y : int :: f(y) > y Obligation: f(x + y) > 7 @@ -69,7 +69,7 @@ Property: assert Assumptions: a1: x == 5 a2: y == 2 -f1: forall __q0 : int :: f(__q0) > __q0 +f1: forall y : int :: f(y) > y Obligation: y == 2 @@ -78,7 +78,7 @@ Property: assert Assumptions: a1: x == 5 a2: y == 2 -f1: forall __q0 : int :: f(__q0) > __q0 +f1: forall y : int :: f(y) > y Obligation: f(y) > y @@ -139,10 +139,10 @@ VCs: Label: axiomPgm2_main_assert Property: assert Assumptions: -f_g_ax: forall __q0 : int :: { f(__q0) } - f(__q0) == g(__q0) + 1 -g_ax: forall __q0 : int :: { g(__q0), f(__q0) } - g(__q0) == __q0 * 2 +f_g_ax: forall x : int :: { f(x) } + f(x) == g(x) + 1 +g_ax: forall x : int :: { g(x), f(x) } + g(x) == x * 2 Obligation: x@1 >= 0 ==> f(x@1) > x@1 diff --git a/StrataTest/Languages/Core/Examples/FunctionPreconditions.lean b/StrataTest/Languages/Core/Examples/FunctionPreconditions.lean index a33b03a92c..04f09ead9b 100644 --- a/StrataTest/Languages/Core/Examples/FunctionPreconditions.lean +++ b/StrataTest/Languages/Core/Examples/FunctionPreconditions.lean @@ -415,7 +415,7 @@ Property: assert Assumptions: precond_allPositiveDiv_0: y@2 >= 0 Obligation: -forall __q0 : int :: __q0 > 0 ==> !(__q0 == 0) +forall x : int :: x > 0 ==> !(x == 0) --- info: diff --git a/StrataTest/Languages/Core/Examples/Loops.lean b/StrataTest/Languages/Core/Examples/Loops.lean index 729c9c9a23..ddcddd7c62 100644 --- a/StrataTest/Languages/Core/Examples/Loops.lean +++ b/StrataTest/Languages/Core/Examples/Loops.lean @@ -384,6 +384,7 @@ loop_entry$_1: -- Errors encountered during conversion: Unsupported construct in lopToExpr: 0-ary op not found: top Context: Global scope: + freeVars: [n] var loop_measure$_2 : int; assume [assume_loop_measure$_2]: loop_measure$_2 == n - x; assert [measure_lb_loop_measure$_2]: !(loop_measure$_2 < 0); diff --git a/StrataTest/Languages/Core/Examples/Quantifiers.lean b/StrataTest/Languages/Core/Examples/Quantifiers.lean index 6436514ca4..825a5dd4d4 100644 --- a/StrataTest/Languages/Core/Examples/Quantifiers.lean +++ b/StrataTest/Languages/Core/Examples/Quantifiers.lean @@ -54,17 +54,17 @@ VCs: Label: good_assert Property: assert Obligation: -forall __q0 : int :: !(__q0 == __q0 + 1) +forall l : int :: !(l == l + 1) Label: good Property: assert Obligation: -forall __q0 : int :: exists __q1 : int :: x@1 + 1 + (__q1 + __q0) == __q0 + (__q1 + (x@1 + 1)) +forall y : int :: exists z : int :: x@1 + 1 + (z + y) == y + (z + (x@1 + 1)) Label: bad Property: assert Obligation: -forall __q0 : int :: __q0 < x@1 +forall q : int :: q < x@1 --- info: @@ -93,42 +93,42 @@ VCs: Label: trigger_assert Property: assert Assumptions: -f_pos: forall __q0 : int :: { f(__q0) } - f(__q0) > 0 -g_neg: forall __q0 : int :: forall __q1 : int :: { g(__q0, __q1) } - __q0 > 0 ==> g(__q0, __q1) < 0 -f_and_g: forall __q0 : int :: forall __q1 : int :: { g(__q0, __q1), f(__q0) } - g(__q0, __q1) < f(__q0) -f_and_g2: forall __q0 : int :: forall __q1 : int :: { g(__q0, __q1), f(__q0) } - g(__q0, __q1) < f(__q0) +f_pos: forall x : int :: { f(x) } + f(x) > 0 +g_neg: forall x : int :: forall y : int :: { g(x, y) } + x > 0 ==> g(x, y) < 0 +f_and_g: forall x : int :: forall y : int :: { g(x, y), f(x) } + g(x, y) < f(x) +f_and_g2: forall x : int :: forall y : int :: { g(x, y), f(x) } + g(x, y) < f(x) Obligation: f(x@1) > 0 Label: multi_trigger_assert Property: assert Assumptions: -f_pos: forall __q0 : int :: { f(__q0) } - f(__q0) > 0 -g_neg: forall __q0 : int :: forall __q1 : int :: { g(__q0, __q1) } - __q0 > 0 ==> g(__q0, __q1) < 0 -f_and_g: forall __q0 : int :: forall __q1 : int :: { g(__q0, __q1), f(__q0) } - g(__q0, __q1) < f(__q0) -f_and_g2: forall __q0 : int :: forall __q1 : int :: { g(__q0, __q1), f(__q0) } - g(__q0, __q1) < f(__q0) +f_pos: forall x : int :: { f(x) } + f(x) > 0 +g_neg: forall x : int :: forall y : int :: { g(x, y) } + x > 0 ==> g(x, y) < 0 +f_and_g: forall x : int :: forall y : int :: { g(x, y), f(x) } + g(x, y) < f(x) +f_and_g2: forall x : int :: forall y : int :: { g(x, y), f(x) } + g(x, y) < f(x) Obligation: -forall __q0 : int :: g(x@1, __q0) < f(x@1) +forall y : int :: g(x@1, y) < f(x@1) Label: f_and_g Property: assert Assumptions: -f_pos: forall __q0 : int :: { f(__q0) } - f(__q0) > 0 -g_neg: forall __q0 : int :: forall __q1 : int :: { g(__q0, __q1) } - __q0 > 0 ==> g(__q0, __q1) < 0 -f_and_g: forall __q0 : int :: forall __q1 : int :: { g(__q0, __q1), f(__q0) } - g(__q0, __q1) < f(__q0) -f_and_g2: forall __q0 : int :: forall __q1 : int :: { g(__q0, __q1), f(__q0) } - g(__q0, __q1) < f(__q0) +f_pos: forall x : int :: { f(x) } + f(x) > 0 +g_neg: forall x : int :: forall y : int :: { g(x, y) } + x > 0 ==> g(x, y) < 0 +f_and_g: forall x : int :: forall y : int :: { g(x, y), f(x) } + g(x, y) < f(x) +f_and_g2: forall x : int :: forall y : int :: { g(x, y), f(x) } + g(x, y) < f(x) Obligation: g(f(x@1), x@1) < 0 diff --git a/StrataTest/Languages/Core/Examples/QuantifiersWithTypeAliases.lean b/StrataTest/Languages/Core/Examples/QuantifiersWithTypeAliases.lean index cfebbf7c50..ac60917729 100644 --- a/StrataTest/Languages/Core/Examples/QuantifiersWithTypeAliases.lean +++ b/StrataTest/Languages/Core/Examples/QuantifiersWithTypeAliases.lean @@ -42,10 +42,10 @@ type Ref; type Field; type Struct := Map Field int; type Heap := Map Ref Struct; -axiom [axiom_0]: forall __q0 : Struct :: forall __q1 : Field :: forall __q2 : Field :: forall __q3 : int :: !(__q1 == __q2) ==> __q0[__q1] == (__q0[__q2:=__q3])[__q1]; -axiom [axiom_1]: forall __q0 : Struct :: forall __q1 : Field :: forall __q2 : int :: (__q0[__q1:=__q2])[__q1] == __q2; -axiom [axiom_2]: forall __q0 : Heap :: forall __q1 : Ref :: forall __q2 : Ref :: forall __q3 : Struct :: !(__q1 == __q2) ==> __q0[__q1] == (__q0[__q2:=__q3])[__q1]; -axiom [axiom_3]: forall __q0 : Heap :: forall __q1 : Ref :: forall __q2 : Struct :: (__q0[__q1:=__q2])[__q1] == __q2; +axiom [axiom_0]: forall m : Struct :: forall okk : Field :: forall kk : Field :: forall vv : int :: !(okk == kk) ==> m[okk] == (m[kk:=vv])[okk]; +axiom [axiom_1]: forall m : Struct :: forall kk : Field :: forall vv : int :: (m[kk:=vv])[kk] == vv; +axiom [axiom_2]: forall m : Heap :: forall okk : Ref :: forall kk : Ref :: forall vv : Struct :: !(okk == kk) ==> m[okk] == (m[kk:=vv])[okk]; +axiom [axiom_3]: forall m : Heap :: forall kk : Ref :: forall vv : Struct :: (m[kk:=vv])[kk] == vv; procedure test (h : Heap, ref : Ref, field : Field) { var newH : Heap := h[ref:=(h[ref])[field:=(h[ref])[field] + 1]]; @@ -64,10 +64,10 @@ VCs: Label: assert0 Property: assert Assumptions: -axiom_0: forall __q0 : (Map Field int) :: forall __q1 : Field :: forall __q2 : Field :: forall __q3 : int :: !(__q1 == __q2) ==> __q0[__q1] == (__q0[__q2:=__q3])[__q1] -axiom_1: forall __q0 : (Map Field int) :: forall __q1 : Field :: forall __q2 : int :: (__q0[__q1:=__q2])[__q1] == __q2 -axiom_2: forall __q0 : (Map Ref (Map Field int)) :: forall __q1 : Ref :: forall __q2 : Ref :: forall __q3 : (Map Field int) :: !(__q1 == __q2) ==> __q0[__q1] == (__q0[__q2:=__q3])[__q1] -axiom_3: forall __q0 : (Map Ref (Map Field int)) :: forall __q1 : Ref :: forall __q2 : (Map Field int) :: (__q0[__q1:=__q2])[__q1] == __q2 +axiom_0: forall m : (Map Field int) :: forall okk : Field :: forall kk : Field :: forall vv : int :: !(okk == kk) ==> m[okk] == (m[kk:=vv])[okk] +axiom_1: forall m : (Map Field int) :: forall kk : Field :: forall vv : int :: (m[kk:=vv])[kk] == vv +axiom_2: forall m : (Map Ref (Map Field int)) :: forall okk : Ref :: forall kk : Ref :: forall vv : (Map Field int) :: !(okk == kk) ==> m[okk] == (m[kk:=vv])[okk] +axiom_3: forall m : (Map Ref (Map Field int)) :: forall kk : Ref :: forall vv : (Map Field int) :: (m[kk:=vv])[kk] == vv Obligation: ((h@1[ref@1:=(h@1[ref@1])[field@1:=(h@1[ref@1])[field@1] + 1]])[ref@1])[field@1] == (h@1[ref@1])[field@1] + 1 diff --git a/StrataTest/Languages/Core/Examples/Seq.lean b/StrataTest/Languages/Core/Examples/Seq.lean index c7e3c32b8e..d124c1d619 100644 --- a/StrataTest/Languages/Core/Examples/Seq.lean +++ b/StrataTest/Languages/Core/Examples/Seq.lean @@ -68,6 +68,13 @@ s_empty: Sequence.length(s) == 0 Obligation: Sequence.length(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30)) == 3 +Label: assert_t_0_calls_Sequence.select_0 +Property: out-of-bounds access check +Assumptions: +s_empty: Sequence.length(s) == 0 +Obligation: +true && 0 < Sequence.length(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30)) + Label: t_0 Property: assert Assumptions: @@ -75,6 +82,13 @@ s_empty: Sequence.length(s) == 0 Obligation: Sequence.select(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), 0) == 10 +Label: assert_t_1_calls_Sequence.select_0 +Property: out-of-bounds access check +Assumptions: +s_empty: Sequence.length(s) == 0 +Obligation: +true && 1 < Sequence.length(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30)) + Label: t_1 Property: assert Assumptions: @@ -82,6 +96,13 @@ s_empty: Sequence.length(s) == 0 Obligation: Sequence.select(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), 1) == 20 +Label: assert_t_2_calls_Sequence.select_0 +Property: out-of-bounds access check +Assumptions: +s_empty: Sequence.length(s) == 0 +Obligation: +true && 2 < Sequence.length(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30)) + Label: t_2 Property: assert Assumptions: @@ -102,14 +123,26 @@ Obligation: t_length Property: assert Result: ✅ pass +Obligation: assert_t_0_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: t_0 Property: assert Result: ✅ pass +Obligation: assert_t_1_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: t_1 Property: assert Result: ✅ pass +Obligation: assert_t_2_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: t_2 Property: assert Result: ✅ pass @@ -185,6 +218,13 @@ s_empty: Sequence.length(s) == 0 Obligation: Sequence.length(Sequence.append(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), Sequence.build(Sequence.build(s, 40), 50))) == 5 +Label: assert_append_elem_0_calls_Sequence.select_0 +Property: out-of-bounds access check +Assumptions: +s_empty: Sequence.length(s) == 0 +Obligation: +true && 0 < Sequence.length(Sequence.append(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), Sequence.build(Sequence.build(s, 40), 50))) + Label: append_elem_0 Property: assert Assumptions: @@ -192,6 +232,13 @@ s_empty: Sequence.length(s) == 0 Obligation: Sequence.select(Sequence.append(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), Sequence.build(Sequence.build(s, 40), 50)), 0) == 10 +Label: assert_append_elem_4_calls_Sequence.select_0 +Property: out-of-bounds access check +Assumptions: +s_empty: Sequence.length(s) == 0 +Obligation: +true && 4 < Sequence.length(Sequence.append(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), Sequence.build(Sequence.build(s, 40), 50))) + Label: append_elem_4 Property: assert Assumptions: @@ -199,6 +246,13 @@ s_empty: Sequence.length(s) == 0 Obligation: Sequence.select(Sequence.append(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), Sequence.build(Sequence.build(s, 40), 50)), 4) == 50 +Label: set_u_calls_Sequence.update_0 +Property: out-of-bounds access check +Assumptions: +s_empty: Sequence.length(s) == 0 +Obligation: +true && 1 < Sequence.length(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30)) + Label: update_length Property: assert Assumptions: @@ -206,6 +260,13 @@ s_empty: Sequence.length(s) == 0 Obligation: Sequence.length(Sequence.update(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), 1, 99)) == 3 +Label: assert_update_same_calls_Sequence.select_0 +Property: out-of-bounds access check +Assumptions: +s_empty: Sequence.length(s) == 0 +Obligation: +true && 1 < Sequence.length(Sequence.update(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), 1, 99)) + Label: update_same Property: assert Assumptions: @@ -213,6 +274,13 @@ s_empty: Sequence.length(s) == 0 Obligation: Sequence.select(Sequence.update(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), 1, 99), 1) == 99 +Label: assert_update_other_calls_Sequence.select_0 +Property: out-of-bounds access check +Assumptions: +s_empty: Sequence.length(s) == 0 +Obligation: +true && 0 < Sequence.length(Sequence.update(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), 1, 99)) + Label: update_other Property: assert Assumptions: @@ -227,6 +295,13 @@ s_empty: Sequence.length(s) == 0 Obligation: Sequence.contains(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), 20) +Label: set_u_calls_Sequence.take_0 +Property: out-of-bounds access check +Assumptions: +s_empty: Sequence.length(s) == 0 +Obligation: +true && 2 <= Sequence.length(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30)) + Label: take_length Property: assert Assumptions: @@ -234,6 +309,13 @@ s_empty: Sequence.length(s) == 0 Obligation: Sequence.length(Sequence.take(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), 2)) == 2 +Label: assert_take_elem_calls_Sequence.select_0 +Property: out-of-bounds access check +Assumptions: +s_empty: Sequence.length(s) == 0 +Obligation: +true && 0 < Sequence.length(Sequence.take(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), 2)) + Label: take_elem Property: assert Assumptions: @@ -241,6 +323,13 @@ s_empty: Sequence.length(s) == 0 Obligation: Sequence.select(Sequence.take(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), 2), 0) == 10 +Label: set_u_calls_Sequence.drop_0 +Property: out-of-bounds access check +Assumptions: +s_empty: Sequence.length(s) == 0 +Obligation: +true && 1 <= Sequence.length(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30)) + Label: drop_length Property: assert Assumptions: @@ -248,6 +337,13 @@ s_empty: Sequence.length(s) == 0 Obligation: Sequence.length(Sequence.drop(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), 1)) == 2 +Label: assert_drop_elem_calls_Sequence.select_0 +Property: out-of-bounds access check +Assumptions: +s_empty: Sequence.length(s) == 0 +Obligation: +true && 0 < Sequence.length(Sequence.drop(Sequence.build(Sequence.build(Sequence.build(s, 10), 20), 30), 1)) + Label: drop_elem Property: assert Assumptions: @@ -261,22 +357,42 @@ Obligation: append_length Property: assert Result: ✅ pass +Obligation: assert_append_elem_0_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: append_elem_0 Property: assert Result: ✅ pass +Obligation: assert_append_elem_4_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: append_elem_4 Property: assert Result: ✅ pass +Obligation: set_u_calls_Sequence.update_0 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: update_length Property: assert Result: ✅ pass +Obligation: assert_update_same_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: update_same Property: assert Result: ✅ pass +Obligation: assert_update_other_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: update_other Property: assert Result: ✅ pass @@ -285,18 +401,34 @@ Obligation: contains_yes Property: assert Result: ❓ unknown +Obligation: set_u_calls_Sequence.take_0 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: take_length Property: assert Result: ✅ pass +Obligation: assert_take_elem_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: take_elem Property: assert Result: ✅ pass +Obligation: set_u_calls_Sequence.drop_0 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: drop_length Property: assert Result: ✅ pass +Obligation: assert_drop_elem_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + Obligation: drop_elem Property: assert Result: ✅ pass @@ -305,3 +437,178 @@ Result: ✅ pass #eval verify seqOpsPgm --------------------------------------------------------------------- + +---------------------------------------------------------------------- +-- Tests for Sequence.empty() syntax (issue #1027) +---------------------------------------------------------------------- + +private def seqEmptyPgm := +#strata +program Core; + +procedure SeqEmpty() +{ + var s : Sequence int; + + // Create an empty sequence using Sequence.empty syntax + s := Sequence.empty(); + assert [empty_length]: Sequence.length(s) == 0; + + // Build on top of an empty sequence + s := Sequence.build(Sequence.empty(), 42); + assert [build_on_empty_length]: Sequence.length(s) == 1; + assert [build_on_empty_elem]: Sequence.select(s, 0) == 42; +}; +#end + +/-- info: true -/ +#guard_msgs in +-- No errors in translation. +#eval TransM.run Inhabited.default (translateProgram seqEmptyPgm) |>.snd |>.isEmpty + +/-- +info: program Core; + +procedure SeqEmpty () +{ + var s : (Sequence int); + s := Sequence.empty(); + assert [empty_length]: Sequence.length(s) == 0; + s := Sequence.build(Sequence.empty(), 42); + assert [build_on_empty_length]: Sequence.length(s) == 1; + assert [build_on_empty_elem]: Sequence.select(s, 0) == 42; +}; +-/ +#guard_msgs in +#eval TransM.run Inhabited.default (translateProgram seqEmptyPgm) |>.fst + +/-- +info: [Strata.Core] Type checking succeeded. + + +VCs: +Label: empty_length +Property: assert +Obligation: +Sequence.length(Sequence.empty()) == 0 + +Label: build_on_empty_length +Property: assert +Obligation: +Sequence.length(Sequence.build(Sequence.empty(), 42)) == 1 + +Label: assert_build_on_empty_elem_calls_Sequence.select_0 +Property: out-of-bounds access check +Obligation: +true && 0 < Sequence.length(Sequence.build(Sequence.empty(), 42)) + +Label: build_on_empty_elem +Property: assert +Obligation: +Sequence.select(Sequence.build(Sequence.empty(), 42), 0) == 42 + +--- +info: +Obligation: empty_length +Property: assert +Result: ✅ pass + +Obligation: build_on_empty_length +Property: assert +Result: ✅ pass + +Obligation: assert_build_on_empty_elem_calls_Sequence.select_0 +Property: out-of-bounds access check +Result: ✅ pass + +Obligation: build_on_empty_elem +Property: assert +Result: ✅ pass +-/ +#guard_msgs in +#eval verify seqEmptyPgm + +---------------------------------------------------------------------- + +-- Exercise various element types for Sequence.empty(). +private def seqEmptyTypesPgm := +#strata +program Core; + +procedure SeqEmptyTypes() +{ + var sb : Sequence bool; + var ssi : Sequence (Sequence int); + var smi : Sequence (Map int bool); + + sb := Sequence.empty(); + ssi := Sequence.empty(); + smi := Sequence.empty(); + + assert [bool_len]: Sequence.length(sb) == 0; + assert [seq_seq_len]: Sequence.length(ssi) == 0; + assert [seq_map_len]: Sequence.length(smi) == 0; +}; +#end + +/-- info: true -/ +#guard_msgs in +#eval TransM.run Inhabited.default (translateProgram seqEmptyTypesPgm) |>.snd |>.isEmpty + +/-- +info: program Core; + +procedure SeqEmptyTypes () +{ + var sb : (Sequence bool); + var ssi : (Sequence (Sequence int)); + var smi : (Sequence (Map int bool)); + sb := Sequence.empty(); + ssi := Sequence.empty(); + smi := Sequence.empty(); + assert [bool_len]: Sequence.length(sb) == 0; + assert [seq_seq_len]: Sequence.length(ssi) == 0; + assert [seq_map_len]: Sequence.length(smi) == 0; +}; +-/ +#guard_msgs in +#eval TransM.run Inhabited.default (translateProgram seqEmptyTypesPgm) |>.fst + +/-- +info: [Strata.Core] Type checking succeeded. + + +VCs: +Label: bool_len +Property: assert +Obligation: +Sequence.length(Sequence.empty()) == 0 + +Label: seq_seq_len +Property: assert +Obligation: +Sequence.length(Sequence.empty()) == 0 + +Label: seq_map_len +Property: assert +Obligation: +Sequence.length(Sequence.empty()) == 0 + +--- +info: +Obligation: bool_len +Property: assert +Result: ✅ pass + +Obligation: seq_seq_len +Property: assert +Result: ✅ pass + +Obligation: seq_map_len +Property: assert +Result: ✅ pass +-/ +#guard_msgs in +#eval verify seqEmptyTypesPgm + +---------------------------------------------------------------------- diff --git a/StrataTest/Languages/Core/Examples/SubstFvarsCaptureTests.lean b/StrataTest/Languages/Core/Examples/SubstFvarsCaptureTests.lean index 042d88705e..03e9b3e795 100644 --- a/StrataTest/Languages/Core/Examples/SubstFvarsCaptureTests.lean +++ b/StrataTest/Languages/Core/Examples/SubstFvarsCaptureTests.lean @@ -114,11 +114,11 @@ private def actualsBvar : List (LExpr CoreLParams.mono) := [.bvar () 0] -- Correct (with lifting): `forall z :: bvar 1 > bvar 0` (bvar 1 = outer y). -- The "out of bounds" error is expected: bvar!1 is only in-bounds when the iterated version incorrectly captures it. /-- -info: forall __q0 : int :: bvar!1 > __q0 +info: forall z : int :: bvar!1 > z -- Errors: Unsupported construct in lexprToExpr: bvar index out of bounds: 1 Context: Global scope: Scope 1: - boundVars: [__q0] + boundVars: [z] -/ #guard_msgs in #eval Std.ToFormat.format (substitutePrecondition precondBvar formalsBvar actualsBvar) diff --git a/StrataTest/Languages/Core/Examples/TypeVarImplicitlyQuantified.lean b/StrataTest/Languages/Core/Examples/TypeVarImplicitlyQuantified.lean index ff5296a0df..1c580bbf23 100644 --- a/StrataTest/Languages/Core/Examples/TypeVarImplicitlyQuantified.lean +++ b/StrataTest/Languages/Core/Examples/TypeVarImplicitlyQuantified.lean @@ -45,10 +45,10 @@ info: ok: program Core; type set := Map int bool; function diff (a : Map int bool, b : Map int bool) : Map int bool; function lambda_0 (l_0 : bool, l_1 : int, l_2 : int) : Map int int; -axiom [a1]: forall __q0 : (Map int bool) :: forall __q1 : (Map int bool) :: { diff(__q0, __q1) } - diff(__q0, __q1) == diff(__q1, __q0); -axiom [a2]: forall __q0 : bool :: forall __q1 : int :: forall __q2 : int :: forall __q3 : int :: { (lambda_0(__q0, __q1, __q2))[__q3] } - (lambda_0(__q0, __q1, __q2))[__q3] == (lambda_0(__q0, __q2, __q1))[__q3]; +axiom [a1]: forall a : (Map int bool) :: forall b : (Map int bool) :: { diff(a, b) } + diff(a, b) == diff(b, a); +axiom [a2]: forall l_0 : bool :: forall l_1 : int :: forall l_2 : int :: forall y : int :: { (lambda_0(l_0, l_1, l_2))[y] } + (lambda_0(l_0, l_1, l_2))[y] == (lambda_0(l_0, l_2, l_1))[y]; -/ #guard_msgs in #eval Core.typeCheck .default core_pgm.fst diff --git a/StrataTest/Languages/Core/Tests/ExprEvalTest.lean b/StrataTest/Languages/Core/Tests/ExprEvalTest.lean index 60ef7f173d..7bdd7092c0 100644 --- a/StrataTest/Languages/Core/Tests/ExprEvalTest.lean +++ b/StrataTest/Languages/Core/Tests/ExprEvalTest.lean @@ -59,6 +59,7 @@ Check whether concrete evaluation of e matches the SMT encoding of e. Returns false if e did not reduce to a constant. -/ def checkValid (e:LExpr CoreLParams.mono): IO Bool := do + let pctx ← Strata.Pipeline.PipelineContext.create (outputMode := .quiet) (profilePipeline := false) let tenv := TEnv.default let init_state := LState.init let e_fvs := LExpr.freeVars e @@ -75,7 +76,7 @@ def checkValid (e:LExpr CoreLParams.mono): IO Bool := do let ans ← Core.SMT.dischargeObligation { Core.VerifyOptions.default with verbose := .quiet } e_fvs_typed Imperative.MetaData.empty filename.toString - [] smt_term ctx true false (label := "exprEvalTest") + [] smt_term ctx true false (label := "exprEvalTest") (pctx := pctx) match ans with | .ok (.sat _, _, _) => return true | _ => diff --git a/StrataTest/Languages/Core/Tests/GeneratedLabels.lean b/StrataTest/Languages/Core/Tests/GeneratedLabels.lean index 1b5992dbc3..b97363ab5b 100644 --- a/StrataTest/Languages/Core/Tests/GeneratedLabels.lean +++ b/StrataTest/Languages/Core/Tests/GeneratedLabels.lean @@ -40,10 +40,10 @@ type Ref; type Field; type Struct := Map Field int; type Heap := Map Ref Struct; -axiom [axiom_0]: forall __q0 : Struct :: forall __q1 : Field :: forall __q2 : Field :: forall __q3 : int :: !(__q1 == __q2) ==> __q0[__q1] == (__q0[__q2:=__q3])[__q1]; -axiom [axiom_1]: forall __q0 : Struct :: forall __q1 : Field :: forall __q2 : int :: (__q0[__q1:=__q2])[__q1] == __q2; -axiom [axiom_2]: forall __q0 : Heap :: forall __q1 : Ref :: forall __q2 : Ref :: forall __q3 : Struct :: !(__q1 == __q2) ==> __q0[__q1] == (__q0[__q2:=__q3])[__q1]; -axiom [axiom_3]: forall __q0 : Heap :: forall __q1 : Ref :: forall __q2 : Struct :: (__q0[__q1:=__q2])[__q1] == __q2; +axiom [axiom_0]: forall m : Struct :: forall okk : Field :: forall kk : Field :: forall vv : int :: !(okk == kk) ==> m[okk] == (m[kk:=vv])[okk]; +axiom [axiom_1]: forall m : Struct :: forall kk : Field :: forall vv : int :: (m[kk:=vv])[kk] == vv; +axiom [axiom_2]: forall m : Heap :: forall okk : Ref :: forall kk : Ref :: forall vv : Struct :: !(okk == kk) ==> m[okk] == (m[kk:=vv])[okk]; +axiom [axiom_3]: forall m : Heap :: forall kk : Ref :: forall vv : Struct :: (m[kk:=vv])[kk] == vv; procedure test (h : Heap, ref : Ref, field : Field) { var newH : Heap := h[ref:=(h[ref])[field:=(h[ref])[field] + 1]]; @@ -61,10 +61,10 @@ VCs: Label: assert_0 Property: assert Assumptions: -axiom_0: forall __q0 : (Map Field int) :: forall __q1 : Field :: forall __q2 : Field :: forall __q3 : int :: !(__q1 == __q2) ==> __q0[__q1] == (__q0[__q2:=__q3])[__q1] -axiom_1: forall __q0 : (Map Field int) :: forall __q1 : Field :: forall __q2 : int :: (__q0[__q1:=__q2])[__q1] == __q2 -axiom_2: forall __q0 : (Map Ref (Map Field int)) :: forall __q1 : Ref :: forall __q2 : Ref :: forall __q3 : (Map Field int) :: !(__q1 == __q2) ==> __q0[__q1] == (__q0[__q2:=__q3])[__q1] -axiom_3: forall __q0 : (Map Ref (Map Field int)) :: forall __q1 : Ref :: forall __q2 : (Map Field int) :: (__q0[__q1:=__q2])[__q1] == __q2 +axiom_0: forall m : (Map Field int) :: forall okk : Field :: forall kk : Field :: forall vv : int :: !(okk == kk) ==> m[okk] == (m[kk:=vv])[okk] +axiom_1: forall m : (Map Field int) :: forall kk : Field :: forall vv : int :: (m[kk:=vv])[kk] == vv +axiom_2: forall m : (Map Ref (Map Field int)) :: forall okk : Ref :: forall kk : Ref :: forall vv : (Map Field int) :: !(okk == kk) ==> m[okk] == (m[kk:=vv])[okk] +axiom_3: forall m : (Map Ref (Map Field int)) :: forall kk : Ref :: forall vv : (Map Field int) :: (m[kk:=vv])[kk] == vv Obligation: ((h@1[ref@1:=(h@1[ref@1])[field@1:=(h@1[ref@1])[field@1] + 1]])[ref@1])[field@1] == (h@1[ref@1])[field@1] + 1 diff --git a/StrataTest/Languages/Core/Tests/Issue1146Test.lean b/StrataTest/Languages/Core/Tests/Issue1146Test.lean new file mode 100644 index 0000000000..86e7fc5370 --- /dev/null +++ b/StrataTest/Languages/Core/Tests/Issue1146Test.lean @@ -0,0 +1,56 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import Strata.Languages.Core.DDMTransform.Translate + +/-! +# Regression test for https://github.com/strata-org/Strata/issues/1146 + +A trailing `;` after a `function` body must not be silently accepted as an +empty `command_datatypes` block (which would later panic in +`translateDatatypes`), and a program mixing a datatype and a function must +translate cleanly. +-/ + +namespace Strata.Issue1146Test + +/-! ## Canonical form: datatype + function translates without error -/ + +private def datatypeAndFunction : Strata.Program := +#strata +program Core; + +datatype List () { Nil() }; + +function Len (xs : List) : int +{ + 0 +} +#end + +/-- info: true -/ +#guard_msgs in +#eval TransM.run Inhabited.default (translateProgram datatypeAndFunction) |>.snd |>.isEmpty + +/-! ## Stray trailing `;` after a function body is a parse error -/ + +/-- +error: unexpected token ';'; expected 'function', Core.Block or expected at least one element +-/ +#guard_msgs in +def strayTrailingSemi : Strata.Program := +#strata +program Core; + +datatype List () { Nil() }; + +function Len (xs : List) : int +{ + 0 +}; +#end + +end Strata.Issue1146Test diff --git a/StrataTest/Languages/Core/Tests/LambdaHigherOrderTests.lean b/StrataTest/Languages/Core/Tests/LambdaHigherOrderTests.lean index 379e681eda..fe93e8e0e8 100644 --- a/StrataTest/Languages/Core/Tests/LambdaHigherOrderTests.lean +++ b/StrataTest/Languages/Core/Tests/LambdaHigherOrderTests.lean @@ -93,7 +93,7 @@ info: [Strata.Core] Type checking succeeded. --- info: ok: program Core; -function apply (f : int -> int, x : int) : int { +inline function apply (f : int -> int, x : int) : int { f(x) } procedure TestLambdaApply (out result : int) diff --git a/StrataTest/Languages/Core/Tests/MutualRecursiveFunctionTests.lean b/StrataTest/Languages/Core/Tests/MutualRecursiveFunctionTests.lean index 10173ed316..9de26a834c 100644 --- a/StrataTest/Languages/Core/Tests/MutualRecursiveFunctionTests.lean +++ b/StrataTest/Languages/Core/Tests/MutualRecursiveFunctionTests.lean @@ -66,20 +66,20 @@ Obligation: Label: isEven_terminates_0 Property: assert Assumptions: -MyNat..adtRank_0: forall __q0 : MyNat :: { MyNat..adtRank(__q0) } - MyNat..adtRank(__q0) >= 0 -MyNat..adtRank_1: forall __q0 : MyNat :: { MyNat..adtRank(Succ(__q0)) } - MyNat..adtRank(__q0) < MyNat..adtRank(Succ(__q0)) +MyNat..adtRank_0: forall x : MyNat :: { MyNat..adtRank(x) } + MyNat..adtRank(x) >= 0 +MyNat..adtRank_1: forall pred : MyNat :: { MyNat..adtRank(Succ(pred)) } + MyNat..adtRank(pred) < MyNat..adtRank(Succ(pred)) Obligation: !(MyNat..isZero(n@3)) ==> MyNat..adtRank(MyNat..pred(n@3)) < MyNat..adtRank(n@3) Label: isOdd_terminates_0 Property: assert Assumptions: -MyNat..adtRank_0: forall __q0 : MyNat :: { MyNat..adtRank(__q0) } - MyNat..adtRank(__q0) >= 0 -MyNat..adtRank_1: forall __q0 : MyNat :: { MyNat..adtRank(Succ(__q0)) } - MyNat..adtRank(__q0) < MyNat..adtRank(Succ(__q0)) +MyNat..adtRank_0: forall x : MyNat :: { MyNat..adtRank(x) } + MyNat..adtRank(x) >= 0 +MyNat..adtRank_1: forall pred : MyNat :: { MyNat..adtRank(Succ(pred)) } + MyNat..adtRank(pred) < MyNat..adtRank(Succ(pred)) Obligation: !(MyNat..isZero(n@4)) ==> MyNat..adtRank(MyNat..pred(n@4)) < MyNat..adtRank(n@4) diff --git a/StrataTest/Languages/Core/Tests/ProgramEvalTests.lean b/StrataTest/Languages/Core/Tests/ProgramEvalTests.lean index d9ee5b06a5..a415360eeb 100644 --- a/StrataTest/Languages/Core/Tests/ProgramEvalTests.lean +++ b/StrataTest/Languages/Core/Tests/ProgramEvalTests.lean @@ -85,12 +85,16 @@ func update : ∀[k, v]. ((m : (Map k v)) (i : k) (x : v)) → (Map k v); func Sequence.length : ∀[a]. ((s : (Sequence a))) → int; func Sequence.empty : ∀[a]. () → (Sequence a); func Sequence.append : ∀[a]. ((s1 : (Sequence a)) (s2 : (Sequence a))) → (Sequence a); -func Sequence.select : ∀[a]. ((s : (Sequence a)) (i : int)) → a; +func Sequence.select : ∀[a]. ((s : (Sequence a)) (i : int)) → a + requires 0 <= i && i < Sequence.length(s); func Sequence.build : ∀[a]. ((s : (Sequence a)) (v : a)) → (Sequence a); -func Sequence.update : ∀[a]. ((s : (Sequence a)) (i : int) (v : a)) → (Sequence a); +func Sequence.update : ∀[a]. ((s : (Sequence a)) (i : int) (v : a)) → (Sequence a) + requires 0 <= i && i < Sequence.length(s); func Sequence.contains : ∀[a]. ((s : (Sequence a)) (v : a)) → bool; -func Sequence.take : ∀[a]. ((s : (Sequence a)) (n : int)) → (Sequence a); -func Sequence.drop : ∀[a]. ((s : (Sequence a)) (n : int)) → (Sequence a); +func Sequence.take : ∀[a]. ((s : (Sequence a)) (n : int)) → (Sequence a) + requires 0 <= n && n <= Sequence.length(s); +func Sequence.drop : ∀[a]. ((s : (Sequence a)) (n : int)) → (Sequence a) + requires 0 <= n && n <= Sequence.length(s); func Triggers.empty : () → Triggers; func Triggers.addGroup : ((g : TriggerGroup) (t : Triggers)) → Triggers; func TriggerGroup.empty : () → TriggerGroup; @@ -268,115 +272,115 @@ func Bv64.UAddOverflow : ((x : bv64) (y : bv64)) → bool; func Bv64.USubOverflow : ((x : bv64) (y : bv64)) → bool; func Bv64.UMulOverflow : ((x : bv64) (y : bv64)) → bool; func Bv1.SafeAdd : ((x : bv1) (y : bv1)) → bv1 - requires !(x <= y); + requires !(Bv.SAddOverflow(x, y)); func Bv1.SafeSub : ((x : bv1) (y : bv1)) → bv1 - requires !(x <= y); + requires !(Bv.SSubOverflow(x, y)); func Bv1.SafeMul : ((x : bv1) (y : bv1)) → bv1 - requires !(x <= y); + requires !(Bv.SMulOverflow(x, y)); func Bv1.SafeNeg : ((x : bv1)) → bv1 - requires !(!x); + requires !(Bv.SNegOverflow(x)); func Bv1.SafeUAdd : ((x : bv1) (y : bv1)) → bv1 - requires !(x <= y); + requires !(Bv.UAddOverflow(x, y)); func Bv1.SafeUSub : ((x : bv1) (y : bv1)) → bv1 - requires !(x <= y); + requires !(Bv.USubOverflow(x, y)); func Bv1.SafeUMul : ((x : bv1) (y : bv1)) → bv1 - requires !(x <= y); + requires !(Bv.UMulOverflow(x, y)); func Bv1.SafeUNeg : ((x : bv1)) → bv1 - requires !(!x); + requires !(Bv.UNegOverflow(x)); func Bv8.SafeAdd : ((x : bv8) (y : bv8)) → bv8 - requires !(x <= y); + requires !(Bv.SAddOverflow(x, y)); func Bv8.SafeSub : ((x : bv8) (y : bv8)) → bv8 - requires !(x <= y); + requires !(Bv.SSubOverflow(x, y)); func Bv8.SafeMul : ((x : bv8) (y : bv8)) → bv8 - requires !(x <= y); + requires !(Bv.SMulOverflow(x, y)); func Bv8.SafeNeg : ((x : bv8)) → bv8 - requires !(!x); + requires !(Bv.SNegOverflow(x)); func Bv8.SafeUAdd : ((x : bv8) (y : bv8)) → bv8 - requires !(x <= y); + requires !(Bv.UAddOverflow(x, y)); func Bv8.SafeUSub : ((x : bv8) (y : bv8)) → bv8 - requires !(x <= y); + requires !(Bv.USubOverflow(x, y)); func Bv8.SafeUMul : ((x : bv8) (y : bv8)) → bv8 - requires !(x <= y); + requires !(Bv.UMulOverflow(x, y)); func Bv8.SafeUNeg : ((x : bv8)) → bv8 - requires !(!x); + requires !(Bv.UNegOverflow(x)); func Bv16.SafeAdd : ((x : bv16) (y : bv16)) → bv16 - requires !(x <= y); + requires !(Bv.SAddOverflow(x, y)); func Bv16.SafeSub : ((x : bv16) (y : bv16)) → bv16 - requires !(x <= y); + requires !(Bv.SSubOverflow(x, y)); func Bv16.SafeMul : ((x : bv16) (y : bv16)) → bv16 - requires !(x <= y); + requires !(Bv.SMulOverflow(x, y)); func Bv16.SafeNeg : ((x : bv16)) → bv16 - requires !(!x); + requires !(Bv.SNegOverflow(x)); func Bv16.SafeUAdd : ((x : bv16) (y : bv16)) → bv16 - requires !(x <= y); + requires !(Bv.UAddOverflow(x, y)); func Bv16.SafeUSub : ((x : bv16) (y : bv16)) → bv16 - requires !(x <= y); + requires !(Bv.USubOverflow(x, y)); func Bv16.SafeUMul : ((x : bv16) (y : bv16)) → bv16 - requires !(x <= y); + requires !(Bv.UMulOverflow(x, y)); func Bv16.SafeUNeg : ((x : bv16)) → bv16 - requires !(!x); + requires !(Bv.UNegOverflow(x)); func Bv32.SafeAdd : ((x : bv32) (y : bv32)) → bv32 - requires !(x <= y); + requires !(Bv.SAddOverflow(x, y)); func Bv32.SafeSub : ((x : bv32) (y : bv32)) → bv32 - requires !(x <= y); + requires !(Bv.SSubOverflow(x, y)); func Bv32.SafeMul : ((x : bv32) (y : bv32)) → bv32 - requires !(x <= y); + requires !(Bv.SMulOverflow(x, y)); func Bv32.SafeNeg : ((x : bv32)) → bv32 - requires !(!x); + requires !(Bv.SNegOverflow(x)); func Bv32.SafeUAdd : ((x : bv32) (y : bv32)) → bv32 - requires !(x <= y); + requires !(Bv.UAddOverflow(x, y)); func Bv32.SafeUSub : ((x : bv32) (y : bv32)) → bv32 - requires !(x <= y); + requires !(Bv.USubOverflow(x, y)); func Bv32.SafeUMul : ((x : bv32) (y : bv32)) → bv32 - requires !(x <= y); + requires !(Bv.UMulOverflow(x, y)); func Bv32.SafeUNeg : ((x : bv32)) → bv32 - requires !(!x); + requires !(Bv.UNegOverflow(x)); func Bv64.SafeAdd : ((x : bv64) (y : bv64)) → bv64 - requires !(x <= y); + requires !(Bv.SAddOverflow(x, y)); func Bv64.SafeSub : ((x : bv64) (y : bv64)) → bv64 - requires !(x <= y); + requires !(Bv.SSubOverflow(x, y)); func Bv64.SafeMul : ((x : bv64) (y : bv64)) → bv64 - requires !(x <= y); + requires !(Bv.SMulOverflow(x, y)); func Bv64.SafeNeg : ((x : bv64)) → bv64 - requires !(!x); + requires !(Bv.SNegOverflow(x)); func Bv64.SafeUAdd : ((x : bv64) (y : bv64)) → bv64 - requires !(x <= y); + requires !(Bv.UAddOverflow(x, y)); func Bv64.SafeUSub : ((x : bv64) (y : bv64)) → bv64 - requires !(x <= y); + requires !(Bv.USubOverflow(x, y)); func Bv64.SafeUMul : ((x : bv64) (y : bv64)) → bv64 - requires !(x <= y); + requires !(Bv.UMulOverflow(x, y)); func Bv64.SafeUNeg : ((x : bv64)) → bv64 - requires !(!x); + requires !(Bv.UNegOverflow(x)); func Bv1.SafeSDiv : ((x : bv1) (y : bv1)) → bv1 requires !(y == bv{1}(0)) - requires !(x <= y); + requires !(Bv.SDivOverflow(x, y)); func Bv1.SafeSMod : ((x : bv1) (y : bv1)) → bv1 requires !(y == bv{1}(0)) - requires !(x <= y); + requires !(Bv.SDivOverflow(x, y)); func Bv8.SafeSDiv : ((x : bv8) (y : bv8)) → bv8 requires !(y == bv{8}(0)) - requires !(x <= y); + requires !(Bv.SDivOverflow(x, y)); func Bv8.SafeSMod : ((x : bv8) (y : bv8)) → bv8 requires !(y == bv{8}(0)) - requires !(x <= y); + requires !(Bv.SDivOverflow(x, y)); func Bv16.SafeSDiv : ((x : bv16) (y : bv16)) → bv16 requires !(y == bv{16}(0)) - requires !(x <= y); + requires !(Bv.SDivOverflow(x, y)); func Bv16.SafeSMod : ((x : bv16) (y : bv16)) → bv16 requires !(y == bv{16}(0)) - requires !(x <= y); + requires !(Bv.SDivOverflow(x, y)); func Bv32.SafeSDiv : ((x : bv32) (y : bv32)) → bv32 requires !(y == bv{32}(0)) - requires !(x <= y); + requires !(Bv.SDivOverflow(x, y)); func Bv32.SafeSMod : ((x : bv32) (y : bv32)) → bv32 requires !(y == bv{32}(0)) - requires !(x <= y); + requires !(Bv.SDivOverflow(x, y)); func Bv64.SafeSDiv : ((x : bv64) (y : bv64)) → bv64 requires !(y == bv{64}(0)) - requires !(x <= y); + requires !(Bv.SDivOverflow(x, y)); func Bv64.SafeSMod : ((x : bv64) (y : bv64)) → bv64 requires !(y == bv{64}(0)) - requires !(x <= y); + requires !(Bv.SDivOverflow(x, y)); Datatypes: diff --git a/StrataTest/Languages/Core/Tests/QuantifierBvarIndexTest.lean b/StrataTest/Languages/Core/Tests/QuantifierBvarIndexTest.lean index 4a806af60e..492b8a493b 100644 --- a/StrataTest/Languages/Core/Tests/QuantifierBvarIndexTest.lean +++ b/StrataTest/Languages/Core/Tests/QuantifierBvarIndexTest.lean @@ -36,7 +36,7 @@ info: [Strata.Core] Type checking succeeded. info: ok: program Core; function apply (f : int -> int, x : int) : int; -axiom [axiom_0]: forall __q0 : int -> int :: forall __q1 : int :: apply(__q0, __q1) == __q0(__q1); +axiom [axiom_0]: forall f : int -> int :: forall x : int :: apply(f, x) == f(x); -/ #guard_msgs in #eval (Std.format ((Core.typeCheck .default (translate axiomApplyBoundVar).stripMetaData))) @@ -71,7 +71,7 @@ function apply (f : int -> int, x : int) : int { } procedure Check (out result : bool) spec { - ensures [Check_ensures_0]: result == forall __q0 : int -> int :: forall __q1 : int :: apply(__q0, __q1) == __q0(__q1); + ensures [Check_ensures_0]: result == forall f : int -> int :: forall x : int :: apply(f, x) == f(x); } { result := true; }; diff --git a/StrataTest/Languages/Core/Tests/RecursiveFunctionTests.lean b/StrataTest/Languages/Core/Tests/RecursiveFunctionTests.lean index c27eac0880..43f6f1e761 100644 --- a/StrataTest/Languages/Core/Tests/RecursiveFunctionTests.lean +++ b/StrataTest/Languages/Core/Tests/RecursiveFunctionTests.lean @@ -61,10 +61,10 @@ Obligation: Label: listLen_terminates_0 Property: assert Assumptions: -IntList..adtRank_0: forall __q0 : IntList :: { IntList..adtRank(__q0) } - IntList..adtRank(__q0) >= 0 -IntList..adtRank_1: forall __q0 : int :: forall __q1 : IntList :: { IntList..adtRank(Cons(__q0, __q1)) } - IntList..adtRank(__q1) < IntList..adtRank(Cons(__q0, __q1)) +IntList..adtRank_0: forall x : IntList :: { IntList..adtRank(x) } + IntList..adtRank(x) >= 0 +IntList..adtRank_1: forall hd : int :: forall tl : IntList :: { IntList..adtRank(Cons(hd, tl)) } + IntList..adtRank(tl) < IntList..adtRank(Cons(hd, tl)) Obligation: !(IntList..isNil(xs@2)) ==> IntList..adtRank(IntList..tl(xs@2)) < IntList..adtRank(xs@2) @@ -168,10 +168,10 @@ Obligation: Label: listLen_terminates_0 Property: assert Assumptions: -IntList..adtRank_0: forall __q0 : IntList :: { IntList..adtRank(__q0) } - IntList..adtRank(__q0) >= 0 -IntList..adtRank_1: forall __q0 : int :: forall __q1 : IntList :: { IntList..adtRank(Cons(__q0, __q1)) } - IntList..adtRank(__q1) < IntList..adtRank(Cons(__q0, __q1)) +IntList..adtRank_0: forall x : IntList :: { IntList..adtRank(x) } + IntList..adtRank(x) >= 0 +IntList..adtRank_1: forall hd : int :: forall tl : IntList :: { IntList..adtRank(Cons(hd, tl)) } + IntList..adtRank(tl) < IntList..adtRank(Cons(hd, tl)) Obligation: !(IntList..isNil(xs@2)) ==> IntList..adtRank(IntList..tl(xs@2)) < IntList..adtRank(xs@2) @@ -360,10 +360,10 @@ Obligation: Label: listLen_terminates_0 Property: assert Assumptions: -IntList..adtRank_0: forall __q0 : IntList :: { IntList..adtRank(__q0) } - IntList..adtRank(__q0) >= 0 -IntList..adtRank_1: forall __q0 : int :: forall __q1 : IntList :: { IntList..adtRank(Cons(__q0, __q1)) } - IntList..adtRank(__q1) < IntList..adtRank(Cons(__q0, __q1)) +IntList..adtRank_0: forall x : IntList :: { IntList..adtRank(x) } + IntList..adtRank(x) >= 0 +IntList..adtRank_1: forall hd : int :: forall tl : IntList :: { IntList..adtRank(Cons(hd, tl)) } + IntList..adtRank(tl) < IntList..adtRank(Cons(hd, tl)) Obligation: !(IntList..isNil(xs@2)) ==> IntList..adtRank(IntList..tl(xs@2)) < IntList..adtRank(xs@2) diff --git a/StrataTest/Languages/Core/Tests/RoundtripTest.lean b/StrataTest/Languages/Core/Tests/RoundtripTest.lean new file mode 100644 index 0000000000..4ff9c9d88a --- /dev/null +++ b/StrataTest/Languages/Core/Tests/RoundtripTest.lean @@ -0,0 +1,235 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import Strata.Languages.Core.DDMTransform.ASTtoCST +import Strata.Languages.Core.DDMTransform.Translate +import Strata.DDM.Elab +import Strata.DDM.BuiltinDialects.Init + +/-! +# Core Roundtrip Tests + +Tests that `Core.formatProgram` produces output that can be parsed back to the +same AST. The roundtrip is: parse → translate → format → re-parse → re-translate +→ compare. +-/ + +namespace Strata.Test.Roundtrip + +open Strata +open Strata.CoreDDM +open Core +open Lean.Parser (InputContext) + +/-- Parse a string as a Core program and translate to AST. -/ +private def parseAndTranslate (input : String) : IO Core.Program := do + let dialects := Strata.Elab.LoadedDialects.ofDialects! #[initDialect, Core] + -- Strip "program Core;\n\n" header if present + let body := if input.startsWith "program Core;\n\n" then + (input.drop "program Core;\n\n".length).toString + else input + let inputCtx := Strata.Parser.stringInputContext ⟨"roundtrip-test"⟩ body + let strataProgram ← Strata.Elab.parseStrataProgramFromDialect dialects "Core" inputCtx + let (ast, errs) := TransM.run Inhabited.default (translateProgram strataProgram) + if !errs.isEmpty then + throw (IO.userError s!"Translation errors: {errs}") + pure ast + +/-- Perform a roundtrip test: parse → format → re-parse → compare. + Prints OK or FAIL with details. -/ +def roundtrip (program : Strata.Program) : IO Unit := do + -- First pass: translate to AST + let (ast1, errs1) := TransM.run Inhabited.default (translateProgram program) + if !errs1.isEmpty then + IO.println s!"FAIL: First translation errors: {errs1}" + return + -- Format back to text + let formatted := (Core.formatProgram ast1).pretty + -- Second pass: re-parse and re-translate + let ast2 ← parseAndTranslate formatted + -- Compare: format both ASTs and check they match + let formatted2 := (Core.formatProgram ast2).pretty + if formatted == formatted2 then + IO.println "OK" + else + IO.println s!"FAIL: Roundtrip mismatch.\nFirst format:\n{formatted}\nSecond format:\n{formatted2}" + +------------------------------------------------------------------------------- +-- Test: Basic types and type aliases +------------------------------------------------------------------------------- + +private def testTypesRoundtrip : Program := +#strata +program Core; + +type T0; +type Byte := bv8; +type IntMap := Map int int; +type T1 (x : Type); +type MyMap (a : Type, b : Type); +type Foo (a : Type, b : Type) := Map b a; +#end + +/-- info: OK -/ +#guard_msgs in +#eval roundtrip testTypesRoundtrip + +------------------------------------------------------------------------------- +-- Test: Polymorphic datatypes with parameterized types +------------------------------------------------------------------------------- + +private def testDatatypesRoundtrip : Program := +#strata +program Core; + +datatype List (a : Type) { + Nil(), + Cons(head : a, tail : List a) +}; + +datatype Tree (a : Type) { + Leaf(val : a), + Node(left : Tree a, right : Tree a) +}; +#end + +/-- +info: program Core; + +datatype List (a : Type) { + Nil(), + Cons(head : a, tail : List a) +}; +datatype Tree (a : Type) { + Leaf(val : a), + Node(left : Tree a, right : Tree a) +}; +-/ +#guard_msgs in +#eval do + let (ast, _) := TransM.run Inhabited.default (translateProgram testDatatypesRoundtrip) + IO.println f!"{Core.formatProgram ast}" + +/-- info: OK -/ +#guard_msgs in +#eval roundtrip testDatatypesRoundtrip + +------------------------------------------------------------------------------- +-- Test: Functions and axioms with quantifiers +------------------------------------------------------------------------------- + +private def testFunctionsRoundtrip : Program := +#strata +program Core; + +function f1(x : int) : int; +axiom [f1_ax]: (forall x : int :: f1(x) > x); + +function f2(x : int, y : bool) : bool; +axiom [f2_ax]: (forall x : int, y : bool :: + {f2(x, true), f2(x, false)} + f2(x, true) == true); + +function f3(x : T1) : Map T1 T2; +#end + +/-- info: OK -/ +#guard_msgs in +#eval roundtrip testFunctionsRoundtrip + +------------------------------------------------------------------------------- +-- Test: Procedures with specs +------------------------------------------------------------------------------- + +private def testProceduresRoundtrip : Program := +#strata +program Core; + +procedure Test(x : bool, out y : bool) +spec { + requires x == true; + ensures y == x; +} { + y := x; +}; +#end + +/-- info: OK -/ +#guard_msgs in +#eval roundtrip testProceduresRoundtrip + +------------------------------------------------------------------------------- +-- Test: Inline functions +------------------------------------------------------------------------------- + +private def testInlineFunctionRoundtrip : Program := +#strata +program Core; + +inline function double(x : int) : int { + x + x +} +#end + +/-- info: OK -/ +#guard_msgs in +#eval roundtrip testInlineFunctionRoundtrip + +------------------------------------------------------------------------------- +-- Test: Parameterized type arguments (the reversed-args bug) +------------------------------------------------------------------------------- + +private def testTypeArgsRoundtrip : Program := +#strata +program Core; + +type Pair (a : Type, b : Type); + +function f(x : Pair int bool) : int; +function g(x : Map int bool) : int; +#end + +/-- info: OK -/ +#guard_msgs in +#eval roundtrip testTypeArgsRoundtrip + +------------------------------------------------------------------------------- +-- Test: Array assignment (lhsArray: m[k] := v) +------------------------------------------------------------------------------- + +private def testLhsArrayRoundtrip : Program := +#strata +program Core; + +procedure MapUpdate(m : Map int int, out m : Map int int) +spec { + ensures true; +} { + m[0] := 1; +}; +#end + +/-- info: OK -/ +#guard_msgs in +#eval roundtrip testLhsArrayRoundtrip + +------------------------------------------------------------------------------- +-- Test: Sequence.empty with explicit type annotation +------------------------------------------------------------------------------- + +private def testSeqEmptyRoundtrip : Program := +#strata +program Core; + +function f(s : Sequence int) : bool; +axiom [f_ax]: f(Sequence.empty()) == true; +#end + +/-- info: OK -/ +#guard_msgs in +#eval roundtrip testSeqEmptyRoundtrip + +end Strata.Test.Roundtrip diff --git a/StrataTest/Languages/Core/Tests/SMTEncoderDatatypeTest.lean b/StrataTest/Languages/Core/Tests/SMTEncoderDatatypeTest.lean index 1dd7ff5a6c..0f1d023365 100644 --- a/StrataTest/Languages/Core/Tests/SMTEncoderDatatypeTest.lean +++ b/StrataTest/Languages/Core/Tests/SMTEncoderDatatypeTest.lean @@ -71,13 +71,13 @@ def treeDatatype : LDatatype Unit := Convert an expression to full SMT string including datatype declarations. `blocks` is a list of mutual blocks (each block is a list of mutually recursive datatypes). -/ -def toSMTStringWithDatatypeBlocks (e : LExpr CoreLParams.mono) (blocks : List (List (LDatatype Unit))) : IO String := do +def toSMTStringWithDatatypeBlocks (e : LExpr CoreLParams.mono) (blocks : List (List (LDatatype Unit))) (useArrayTheory : Bool := false): IO String := do match Env.init.addDatatypes blocks with | .error msg => return s!"Error creating environment: {msg}" | .ok env => -- Set the TypeFactory for correct datatype emission ordering let ctx := SMT.Context.default.withTypeFactory env.datatypes - match toSMTTerm env [] e ctx with + match toSMTTerm env [] e ctx useArrayTheory with | .error err => return err.pretty | .ok (smt, ctx) => -- Emit the full SMT output including datatype declarations @@ -85,7 +85,7 @@ def toSMTStringWithDatatypeBlocks (e : LExpr CoreLParams.mono) (blocks : List (L let solver ← Strata.SMT.Solver.bufferWriter b match (← ((do -- First emit datatypes - ctx.emitDatatypes + ctx.emitDatatypes useArrayTheory -- Then encode the term let _ ← (Strata.SMT.Encoder.encodeTerm smt).run Strata.SMT.EncoderState.init pure () @@ -102,8 +102,8 @@ def toSMTStringWithDatatypeBlocks (e : LExpr CoreLParams.mono) (blocks : List (L Convert an expression to full SMT string including datatype declarations. Each datatype is treated as its own (non-mutual) block. -/ -def toSMTStringWithDatatypes (e : LExpr CoreLParams.mono) (datatypes : List (LDatatype Unit)) : IO String := - toSMTStringWithDatatypeBlocks e (datatypes.map (fun d => [d])) +def toSMTStringWithDatatypes (e : LExpr CoreLParams.mono) (datatypes : List (LDatatype Unit)) (useArrayTheory : Bool := false): IO String := + toSMTStringWithDatatypeBlocks e (datatypes.map (fun d => [d])) useArrayTheory /-! ## Test Cases with Guard Messages -/ @@ -511,6 +511,68 @@ info: (declare-datatype IntList ( [[intListDatatype]] listLenFunc +/-- Container = MkContainer (data: Map int int) -/ +def containerWithMapDatatype : LDatatype Unit := + { name := "Container" + typeArgs := [] + constrs := [ + { name := ⟨"MkContainer", ()⟩, + args := [(⟨"data", ()⟩, .tcons "Map" [.int, .int])], + testerName := "Container..isMkContainer" } + ] + constrs_ne := by decide } + +-- Test: ADT constructor field with Map type should emit Array when useArrayTheory=true +/-- +info: (declare-datatype Container ( + (MkContainer (Container..data (Array Int Int))))) +; c +(declare-const c Container) +-/ +#guard_msgs in +#eval format <$> toSMTStringWithDatatypes + (.fvar () (⟨"c", ()⟩) (.some (.tcons "Container" []))) + [containerWithMapDatatype] true + +-- Test: Same datatype without useArrayTheory should keep Map +/-- +info: (declare-datatype Container ( + (MkContainer (Container..data (Map Int Int))))) +; c +(declare-const c Container) +-/ +#guard_msgs in +#eval format <$> toSMTStringWithDatatypes + (.fvar () (⟨"c", ()⟩) (.some (.tcons "Container" []))) + [containerWithMapDatatype] + +-- Test: ADT testers with Map type should emit Array when useArrayTheory=true +/-- +info: (declare-datatype Container ( + (MkContainer (Container..data (Array Int Int))))) +; xs +(declare-const xs Container) +-/ +#guard_msgs in +#eval format <$> toSMTStringWithDatatypes + (.app () (.op () (⟨"Container..isMkContainer", ()⟩) (.some (.arrow (.tcons "Container" []) .bool))) + (.fvar () (⟨"xs", ()⟩) (.some (.tcons "Container" [])))) + [containerWithMapDatatype] true + +-- Test: ADT destructors with Map type should emit Array when useArrayTheory=true +/-- +info: (declare-datatype Container ( + (MkContainer (Container..data (Array Int Int))))) +; xs +(declare-const xs Container) +-/ +#guard_msgs in +#eval format <$> toSMTStringWithDatatypes + (.app () (.op () (⟨"Container..data", ()⟩) (.some (.arrow (.tcons "Container" []) (.tcons "Map" [.int, .int])))) + (.fvar () (⟨"xs", ()⟩) (.some (.tcons "Container" [])))) + [containerWithMapDatatype] true + + end DatatypeTests end Core diff --git a/StrataTest/Languages/Core/Tests/SMTEncoderTests.lean b/StrataTest/Languages/Core/Tests/SMTEncoderTests.lean index 979e09a3d4..5ce57a9ecc 100644 --- a/StrataTest/Languages/Core/Tests/SMTEncoderTests.lean +++ b/StrataTest/Languages/Core/Tests/SMTEncoderTests.lean @@ -298,6 +298,7 @@ end ArrayTheory /-- info: (["c"], true) -/ #guard_msgs in #eval show IO _ from do + let pctx ← Strata.Pipeline.PipelineContext.create (outputMode := .quiet) (profilePipeline := false) -- Non-nullary UF: f(x : Int) : Int — should be excluded from ids let uf_f := UF.mk "f" [TermVar.mk "x" TermType.int] TermType.int -- Nullary UF: c : Int — should be included in ids @@ -310,7 +311,8 @@ end ArrayTheory let ((ids, _estate), _) ← Strata.SMT.SolverM.run solver (Strata.SMT.Encoder.encodeCore ctx (pure ()) [] obligationTerm md - (satisfiabilityCheck := false) (validityCheck := true) (label := "test")) + (satisfiabilityCheck := false) (validityCheck := true) (label := "test") + (pctx := pctx)) -- ids should contain "c" but not "f" let hasF := ids.any (· == "f") return (ids, !hasF) @@ -326,6 +328,7 @@ info: (set-logic ALL) -/ #guard_msgs in #eval show IO _ from do + let pctx ← Strata.Pipeline.PipelineContext.create (outputMode := .quiet) (profilePipeline := false) let ctx : SMT.Context := SMT.Context.default let obligationTerm := Term.prim (.bool true) let md : Imperative.MetaData Core.Expression := #[] @@ -334,7 +337,8 @@ info: (set-logic ALL) let _ ← Strata.SMT.SolverM.run solver (Strata.SMT.Encoder.encodeCore ctx (pure ()) [] obligationTerm md - (satisfiabilityCheck := false) (validityCheck := true) (label := "assert_bounds_check")) + (satisfiabilityCheck := false) (validityCheck := true) (label := "assert_bounds_check") + (pctx := pctx)) let contents ← b.get let smt := if h : contents.data.IsValidUTF8 @@ -353,6 +357,7 @@ info: (set-logic ALL) -/ #guard_msgs in #eval show IO _ from do + let pctx ← Strata.Pipeline.PipelineContext.create (outputMode := .quiet) (profilePipeline := false) let ctx : SMT.Context := SMT.Context.default let obligationTerm := Term.prim (.bool true) let md : Imperative.MetaData Core.Expression := @@ -362,7 +367,8 @@ info: (set-logic ALL) let _ ← Strata.SMT.SolverM.run solver (Strata.SMT.Encoder.encodeCore ctx (pure ()) [] obligationTerm md - (satisfiabilityCheck := false) (validityCheck := true) (label := "assert_bounds_check")) + (satisfiabilityCheck := false) (validityCheck := true) (label := "assert_bounds_check") + (pctx := pctx)) let contents ← b.get let smt := if h : contents.data.IsValidUTF8 @@ -386,6 +392,7 @@ info: (set-logic ALL) and check flags, and return the resulting SMT-LIB text. -/ private def captureEncodeCore (md : Imperative.MetaData Core.Expression) (satCheck validityCheck : Bool) (label : String := "test") : IO String := do + let pctx ← Strata.Pipeline.PipelineContext.create (outputMode := .quiet) (profilePipeline := false) let ctx : SMT.Context := SMT.Context.default let obligationTerm := Term.prim (.bool true) let b ← IO.mkRef { : IO.FS.Stream.Buffer } @@ -393,7 +400,8 @@ private def captureEncodeCore (md : Imperative.MetaData Core.Expression) let _ ← Strata.SMT.SolverM.run solver (Strata.SMT.Encoder.encodeCore ctx (pure ()) [] obligationTerm md - (satisfiabilityCheck := satCheck) (validityCheck := validityCheck) (label := label)) + (satisfiabilityCheck := satCheck) (validityCheck := validityCheck) (label := label) + (pctx := pctx)) let contents ← b.get return if h : contents.data.IsValidUTF8 then String.fromUTF8 contents.data h diff --git a/StrataTest/Languages/Core/Tests/SarifOutputTests.lean b/StrataTest/Languages/Core/Tests/SarifOutputTests.lean index 98c8cac9b0..ad488d023c 100644 --- a/StrataTest/Languages/Core/Tests/SarifOutputTests.lean +++ b/StrataTest/Languages/Core/Tests/SarifOutputTests.lean @@ -55,9 +55,10 @@ def makeFilesMap (file : String) : Map Strata.Uri Lean.FileMap := Map.empty.insert uri makeFileMap /-- Create a simple proof obligation for testing -/ -def makeObligation (label : String) (md : MetaData Expression := #[]) : ProofObligation Expression := +def makeObligation (label : String) (md : MetaData Expression := #[]) + (property : Imperative.PropertyType := .assert) : ProofObligation Expression := { label := label - property := .assert + property := property assumptions := [] obligation := Lambda.LExpr.boolConst () true metadata := md } @@ -65,8 +66,9 @@ def makeObligation (label : String) (md : MetaData Expression := #[]) : ProofObl /-- Create a VCResult for testing -/ def makeVCResult (label : String) (outcome : VCOutcome) (md : MetaData Expression := #[]) - (lexprModel : LExprModel := []) : VCResult := - { obligation := makeObligation label md + (lexprModel : LExprModel := []) + (property : Imperative.PropertyType := .assert) : VCResult := + { obligation := makeObligation label md property outcome := .ok outcome verbose := .normal lexprModel := lexprModel @@ -345,4 +347,33 @@ def makeVCResult (label : String) (outcome : VCOutcome) let sarif := vcResultsToSarif .deductive files vcResults Strata.Sarif.toJsonString sarif +/-! ## Property classification tests + +The SARIF `properties.propertyType` field should reflect the obligation's +`PropertyType`, not the default `"assert"`. -/ + +private def sarifPropertyType (vcr : VCResult) : String := + let files := makeFilesMap "/test/x.st" + (vcResultToSarifResult .deductive files vcr).properties.propertyType + +/-- info: "assert" -/ +#guard_msgs in +#eval sarifPropertyType (makeVCResult "t" (mkOutcome (.sat []) .unsat) (property := .assert)) + +/-- info: "division-by-zero" -/ +#guard_msgs in +#eval sarifPropertyType (makeVCResult "t" (mkOutcome (.sat []) .unsat) (property := .divisionByZero)) + +/-- info: "arithmetic-overflow" -/ +#guard_msgs in +#eval sarifPropertyType (makeVCResult "t" (mkOutcome (.sat []) .unsat) (property := .arithmeticOverflow)) + +/-- info: "out-of-bounds-access" -/ +#guard_msgs in +#eval sarifPropertyType (makeVCResult "t" (mkOutcome (.sat []) .unsat) (property := .outOfBoundsAccess)) + +/-- info: "cover" -/ +#guard_msgs in +#eval sarifPropertyType (makeVCResult "t" (mkOutcome (.sat []) .unsat) (property := .cover)) + end Core.Sarif.Tests diff --git a/StrataTest/Languages/Core/Tests/TerminationCheckTests.lean b/StrataTest/Languages/Core/Tests/TerminationCheckTests.lean index e5e44abeed..f36901473e 100644 --- a/StrataTest/Languages/Core/Tests/TerminationCheckTests.lean +++ b/StrataTest/Languages/Core/Tests/TerminationCheckTests.lean @@ -54,10 +54,10 @@ Obligation: Label: listLen_terminates_0 Property: assert Assumptions: -IntList..adtRank_0: forall __q0 : IntList :: { IntList..adtRank(__q0) } - IntList..adtRank(__q0) >= 0 -IntList..adtRank_1: forall __q0 : int :: forall __q1 : IntList :: { IntList..adtRank(Cons(__q0, __q1)) } - IntList..adtRank(__q1) < IntList..adtRank(Cons(__q0, __q1)) +IntList..adtRank_0: forall x : IntList :: { IntList..adtRank(x) } + IntList..adtRank(x) >= 0 +IntList..adtRank_1: forall hd : int :: forall tl : IntList :: { IntList..adtRank(Cons(hd, tl)) } + IntList..adtRank(tl) < IntList..adtRank(Cons(hd, tl)) Obligation: !(IntList..isNil(xs@2)) ==> IntList..adtRank(IntList..tl(xs@2)) < IntList..adtRank(xs@2) @@ -320,10 +320,10 @@ Obligation: Label: listLen_terminates_0 Property: assert Assumptions: -IntList..adtRank_0: forall __q0 : IntList :: { IntList..adtRank(__q0) } - IntList..adtRank(__q0) >= 0 -IntList..adtRank_1: forall __q0 : int :: forall __q1 : IntList :: { IntList..adtRank(Cons(__q0, __q1)) } - IntList..adtRank(__q1) < IntList..adtRank(Cons(__q0, __q1)) +IntList..adtRank_0: forall x : IntList :: { IntList..adtRank(x) } + IntList..adtRank(x) >= 0 +IntList..adtRank_1: forall hd : int :: forall tl : IntList :: { IntList..adtRank(Cons(hd, tl)) } + IntList..adtRank(tl) < IntList..adtRank(Cons(hd, tl)) Obligation: !(IntList..isNil(xs@2)) ==> IntList..adtRank(IntList..tl(xs@2)) < IntList..adtRank(xs@2) @@ -340,10 +340,10 @@ Obligation: Label: listSum_terminates_0 Property: assert Assumptions: -IntList..adtRank_0: forall __q0 : IntList :: { IntList..adtRank(__q0) } - IntList..adtRank(__q0) >= 0 -IntList..adtRank_1: forall __q0 : int :: forall __q1 : IntList :: { IntList..adtRank(Cons(__q0, __q1)) } - IntList..adtRank(__q1) < IntList..adtRank(Cons(__q0, __q1)) +IntList..adtRank_0: forall x : IntList :: { IntList..adtRank(x) } + IntList..adtRank(x) >= 0 +IntList..adtRank_1: forall hd : int :: forall tl : IntList :: { IntList..adtRank(Cons(hd, tl)) } + IntList..adtRank(tl) < IntList..adtRank(Cons(hd, tl)) Obligation: !(IntList..isNil(xs@4)) ==> IntList..adtRank(IntList..tl(xs@4)) < IntList..adtRank(xs@4) @@ -449,42 +449,42 @@ Obligation: Label: treeSize_terminates_0 Property: assert Assumptions: -Tree..adtRank_0: forall __q0 : Tree :: { Tree..adtRank(__q0) } - Tree..adtRank(__q0) >= 0 -Tree..adtRank_1: forall __q0 : Tree :: forall __q1 : Tree :: { Tree..adtRank(Branch(__q0, __q1)) } - Tree..adtRank(__q0) < Tree..adtRank(Branch(__q0, __q1)) -Tree..adtRank_2: forall __q0 : Tree :: forall __q1 : Tree :: { Tree..adtRank(Branch(__q0, __q1)) } - Tree..adtRank(__q1) < Tree..adtRank(Branch(__q0, __q1)) -Tree..adtRank_3: forall __q0 : int :: forall __q1 : Tree :: { Tree..adtRank(Chain(__q0, __q1)) } - Tree..adtRank(__q1) < Tree..adtRank(Chain(__q0, __q1)) +Tree..adtRank_0: forall x : Tree :: { Tree..adtRank(x) } + Tree..adtRank(x) >= 0 +Tree..adtRank_1: forall left : Tree :: forall right : Tree :: { Tree..adtRank(Branch(left, right)) } + Tree..adtRank(left) < Tree..adtRank(Branch(left, right)) +Tree..adtRank_2: forall left : Tree :: forall right : Tree :: { Tree..adtRank(Branch(left, right)) } + Tree..adtRank(right) < Tree..adtRank(Branch(left, right)) +Tree..adtRank_3: forall head : int :: forall tail : Tree :: { Tree..adtRank(Chain(head, tail)) } + Tree..adtRank(tail) < Tree..adtRank(Chain(head, tail)) Obligation: Tree..isBranch(t@2) ==> !(Tree..isLeaf(t@2)) ==> Tree..adtRank(Tree..left(t@2)) < Tree..adtRank(t@2) Label: treeSize_terminates_1 Property: assert Assumptions: -Tree..adtRank_0: forall __q0 : Tree :: { Tree..adtRank(__q0) } - Tree..adtRank(__q0) >= 0 -Tree..adtRank_1: forall __q0 : Tree :: forall __q1 : Tree :: { Tree..adtRank(Branch(__q0, __q1)) } - Tree..adtRank(__q0) < Tree..adtRank(Branch(__q0, __q1)) -Tree..adtRank_2: forall __q0 : Tree :: forall __q1 : Tree :: { Tree..adtRank(Branch(__q0, __q1)) } - Tree..adtRank(__q1) < Tree..adtRank(Branch(__q0, __q1)) -Tree..adtRank_3: forall __q0 : int :: forall __q1 : Tree :: { Tree..adtRank(Chain(__q0, __q1)) } - Tree..adtRank(__q1) < Tree..adtRank(Chain(__q0, __q1)) +Tree..adtRank_0: forall x : Tree :: { Tree..adtRank(x) } + Tree..adtRank(x) >= 0 +Tree..adtRank_1: forall left : Tree :: forall right : Tree :: { Tree..adtRank(Branch(left, right)) } + Tree..adtRank(left) < Tree..adtRank(Branch(left, right)) +Tree..adtRank_2: forall left : Tree :: forall right : Tree :: { Tree..adtRank(Branch(left, right)) } + Tree..adtRank(right) < Tree..adtRank(Branch(left, right)) +Tree..adtRank_3: forall head : int :: forall tail : Tree :: { Tree..adtRank(Chain(head, tail)) } + Tree..adtRank(tail) < Tree..adtRank(Chain(head, tail)) Obligation: Tree..isBranch(t@2) ==> !(Tree..isLeaf(t@2)) ==> Tree..adtRank(Tree..right(t@2)) < Tree..adtRank(t@2) Label: treeSize_terminates_2 Property: assert Assumptions: -Tree..adtRank_0: forall __q0 : Tree :: { Tree..adtRank(__q0) } - Tree..adtRank(__q0) >= 0 -Tree..adtRank_1: forall __q0 : Tree :: forall __q1 : Tree :: { Tree..adtRank(Branch(__q0, __q1)) } - Tree..adtRank(__q0) < Tree..adtRank(Branch(__q0, __q1)) -Tree..adtRank_2: forall __q0 : Tree :: forall __q1 : Tree :: { Tree..adtRank(Branch(__q0, __q1)) } - Tree..adtRank(__q1) < Tree..adtRank(Branch(__q0, __q1)) -Tree..adtRank_3: forall __q0 : int :: forall __q1 : Tree :: { Tree..adtRank(Chain(__q0, __q1)) } - Tree..adtRank(__q1) < Tree..adtRank(Chain(__q0, __q1)) +Tree..adtRank_0: forall x : Tree :: { Tree..adtRank(x) } + Tree..adtRank(x) >= 0 +Tree..adtRank_1: forall left : Tree :: forall right : Tree :: { Tree..adtRank(Branch(left, right)) } + Tree..adtRank(left) < Tree..adtRank(Branch(left, right)) +Tree..adtRank_2: forall left : Tree :: forall right : Tree :: { Tree..adtRank(Branch(left, right)) } + Tree..adtRank(right) < Tree..adtRank(Branch(left, right)) +Tree..adtRank_3: forall head : int :: forall tail : Tree :: { Tree..adtRank(Chain(head, tail)) } + Tree..adtRank(tail) < Tree..adtRank(Chain(head, tail)) Obligation: !(Tree..isBranch(t@2)) ==> !(Tree..isLeaf(t@2)) ==> Tree..adtRank(Tree..tail(t@2)) < Tree..adtRank(t@2) @@ -581,10 +581,10 @@ Obligation: Label: intListLen_terminates_0 Property: assert Assumptions: -MyList..adtRank_0: forall __q0 : (MyList int) :: { MyList..adtRank(__q0) } - MyList..adtRank(__q0) >= 0 -MyList..adtRank_1: forall __q0 : int :: forall __q1 : (MyList int) :: { MyList..adtRank(Cons(__q0, __q1)) } - MyList..adtRank(__q1) < MyList..adtRank(Cons(__q0, __q1)) +MyList..adtRank_0: forall x : (MyList int) :: { MyList..adtRank(x) } + MyList..adtRank(x) >= 0 +MyList..adtRank_1: forall hd : int :: forall tl : (MyList int) :: { MyList..adtRank(Cons(hd, tl)) } + MyList..adtRank(tl) < MyList..adtRank(Cons(hd, tl)) Obligation: !(MyList..isNil(xs@2)) ==> MyList..adtRank(MyList..tl(xs@2)) < MyList..adtRank(xs@2) @@ -674,10 +674,10 @@ Obligation: Label: zipLen_terminates_0 Property: assert Assumptions: -IntList..adtRank_0: forall __q0 : IntList :: { IntList..adtRank(__q0) } - IntList..adtRank(__q0) >= 0 -IntList..adtRank_1: forall __q0 : int :: forall __q1 : IntList :: { IntList..adtRank(Cons(__q0, __q1)) } - IntList..adtRank(__q1) < IntList..adtRank(Cons(__q0, __q1)) +IntList..adtRank_0: forall x : IntList :: { IntList..adtRank(x) } + IntList..adtRank(x) >= 0 +IntList..adtRank_1: forall hd : int :: forall tl : IntList :: { IntList..adtRank(Cons(hd, tl)) } + IntList..adtRank(tl) < IntList..adtRank(Cons(hd, tl)) Obligation: !(IntList..isNil(ys@2)) ==> !(IntList..isNil(xs@2)) ==> IntList..adtRank(IntList..tl(ys@2)) < IntList..adtRank(ys@2) @@ -1131,20 +1131,20 @@ Obligation: Label: listLen_terminates_0 Property: assert Assumptions: -IntList..adtRank_0: forall __q0 : IntList :: { IntList..adtRank(__q0) } - IntList..adtRank(__q0) >= 0 -IntList..adtRank_1: forall __q0 : int :: forall __q1 : IntList :: { IntList..adtRank(Cons(__q0, __q1)) } - IntList..adtRank(__q1) < IntList..adtRank(Cons(__q0, __q1)) +IntList..adtRank_0: forall x : IntList :: { IntList..adtRank(x) } + IntList..adtRank(x) >= 0 +IntList..adtRank_1: forall hd : int :: forall tl : IntList :: { IntList..adtRank(Cons(hd, tl)) } + IntList..adtRank(tl) < IntList..adtRank(Cons(hd, tl)) Obligation: !(IntList..isNil(xs@2)) ==> IntList..adtRank(IntList..tl(xs@2)) < IntList..adtRank(xs@2) Label: natToInt_terminates_0 Property: assert Assumptions: -MyNat..adtRank_0: forall __q0 : MyNat :: { MyNat..adtRank(__q0) } - MyNat..adtRank(__q0) >= 0 -MyNat..adtRank_1: forall __q0 : MyNat :: { MyNat..adtRank(Succ(__q0)) } - MyNat..adtRank(__q0) < MyNat..adtRank(Succ(__q0)) +MyNat..adtRank_0: forall x : MyNat :: { MyNat..adtRank(x) } + MyNat..adtRank(x) >= 0 +MyNat..adtRank_1: forall pred : MyNat :: { MyNat..adtRank(Succ(pred)) } + MyNat..adtRank(pred) < MyNat..adtRank(Succ(pred)) Obligation: !(MyNat..isZero(n@2)) ==> MyNat..adtRank(MyNat..pred(n@2)) < MyNat..adtRank(n@2) diff --git a/StrataTest/Languages/Core/Tests/TestASTtoCST.lean b/StrataTest/Languages/Core/Tests/TestASTtoCST.lean index 57b1251bec..981a7fc688 100644 --- a/StrataTest/Languages/Core/Tests/TestASTtoCST.lean +++ b/StrataTest/Languages/Core/Tests/TestASTtoCST.lean @@ -119,17 +119,17 @@ info: program Core; function fooConst () : int; axiom [fooConst_value]: fooConst == 5; function f1 (x : int) : int; -axiom [f1_ax1]: forall __q0 : int :: { f1(__q0) } - f1(__q0) > __q0; -axiom [f1_ax2_no_trigger]: forall __q0 : int :: f1(__q0) > __q0; +axiom [f1_ax1]: forall x : int :: { f1(x) } + f1(x) > x; +axiom [f1_ax2_no_trigger]: forall x : int :: f1(x) > x; function f2 (x : int, y : bool) : bool; -axiom [f2_ax]: forall __q0 : int :: forall __q1 : bool :: { f2(__q0, true), f2(__q0, false) } - f2(__q0, true) == true; +axiom [f2_ax]: forall x : int :: forall y : bool :: { f2(x, true), f2(x, false) } + f2(x, true) == true; function f3 (x : int, y : bool, z : regex) : bool; -axiom [f3_ax]: forall __q0 : int :: forall __q1 : bool :: forall __q2 : regex :: { f3(__q0, __q1, __q2), f2(__q0, __q1) } - f3(__q0, __q1, __q2) == f2(__q0, __q1); +axiom [f3_ax]: forall x : int :: forall y : bool :: forall z : regex :: { f3(x, y, z), f2(x, y) } + f3(x, y, z) == f2(x, y); function f4 (x : T1) : Map T1 T2; -axiom [foo_ax]: forall __q0 : int :: (f4(__q0))[1] == true; +axiom [foo_ax]: forall x : int :: (f4(x))[1] == true; function f5 (x : T1, y : T2) : T1 { x } @@ -425,8 +425,8 @@ info: program Core; procedure find_max (nums : Map bv64 bv32, nums_len : bv64, out ret : bv32) spec { requires [find_max_requires_0]: nums_len > bv{64}(0); - ensures [find_max_ensures_1]: forall __q0 : bv64 :: bv{64}(0) <= __q0 && __q0 < nums_len ==> ret >=s nums[__q0]; - ensures [find_max_ensures_2]: exists __q0 : bv64 :: bv{64}(0) <= __q0 && __q0 < nums_len && ret == nums[__q0]; + ensures [find_max_ensures_1]: forall x0 : bv64 :: bv{64}(0) <= x0 && x0 < nums_len ==> ret >=s nums[x0]; + ensures [find_max_ensures_2]: exists x0 : bv64 :: bv{64}(0) <= x0 && x0 < nums_len && ret == nums[x0]; } { var max : bv32; var i : bv64; @@ -436,8 +436,8 @@ spec { invariant nums_len > bv{64}(0) invariant bv{64}(0) <= i invariant i <= nums_len - invariant forall __q0 : bv64 :: bv{64}(0) <= __q0 && __q0 < i ==> max >=s nums[__q0] - invariant exists __q0 : bv64 :: bv{64}(0) <= __q0 && __q0 < i && max == nums[__q0] + invariant forall x0 : bv64 :: bv{64}(0) <= x0 && x0 < i ==> max >=s nums[x0] + invariant exists x0 : bv64 :: bv{64}(0) <= x0 && x0 < i && max == nums[x0] { if (nums[i] >s max) { max := nums[i]; diff --git a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean index 86ce51e683..0811d5e955 100644 --- a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean +++ b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean @@ -52,6 +52,7 @@ procedure test(n: int) ensures nat$constraint(r) { assert r >= 0; var y: int := n; assert nat$constraint(y); return y }; procedure $witness_nat() + opaque { var $witness: int := 0; assert nat$constraint($witness) }; -/ #guard_msgs in @@ -80,6 +81,7 @@ info: function pos$constraint(v: int): bool procedure test(b: bool) { if b then { var x: int := 1; assert pos$constraint(x) }; { var x: int := -5; x := -10 } }; procedure $witness_pos() + opaque { var $witness: int := 1; assert pos$constraint($witness) }; -/ #guard_msgs in @@ -104,6 +106,7 @@ info: function posint$constraint(x: int): bool procedure f() { var x: int; assume posint$constraint(x); assert x == 1 }; procedure $witness_posint() + opaque { var $witness: int := 1; assert posint$constraint($witness) }; -/ #guard_msgs in diff --git a/StrataTest/Languages/Laurel/Examples/Objects/T1_MutableFields.lean b/StrataTest/Languages/Laurel/Examples/Objects/T1_MutableFields.lean index 7dbf35022d..e46f03ef99 100644 --- a/StrataTest/Languages/Laurel/Examples/Objects/T1_MutableFields.lean +++ b/StrataTest/Languages/Laurel/Examples/Objects/T1_MutableFields.lean @@ -199,5 +199,5 @@ procedure fieldTargetInMultiAssign() }; "# -#guard_msgs(drop info, error) in +#guard_msgs (drop info, error) in #eval testInputWithOffset "MutableFields" program 14 processLaurelFile diff --git a/StrataTest/Languages/Laurel/Examples/Objects/T7_InstanceProcedures.lean b/StrataTest/Languages/Laurel/Examples/Objects/T7_InstanceProcedures.lean index ec05fcfd3d..189295102d 100644 --- a/StrataTest/Languages/Laurel/Examples/Objects/T7_InstanceProcedures.lean +++ b/StrataTest/Languages/Laurel/Examples/Objects/T7_InstanceProcedures.lean @@ -15,8 +15,8 @@ namespace Strata.Laurel def instanceProcedureProgram := r" composite Counter { var count: int - procedure increment(self: Counter) -// ^^^^^^^^^ error: Instance procedure 'increment' on composite type 'Counter' is not yet supported + procedure self_increment(self: Counter) +// ^^^^^^^^^^^^^^ error: Instance procedure 'self_increment' on composite type 'Counter' is not yet supported opaque { self#count := self#count + 1 diff --git a/StrataTest/Languages/Laurel/LiftHolesTest.lean b/StrataTest/Languages/Laurel/LiftHolesTest.lean index 14d25a4416..0f5a4997d3 100644 --- a/StrataTest/Languages/Laurel/LiftHolesTest.lean +++ b/StrataTest/Languages/Laurel/LiftHolesTest.lean @@ -332,4 +332,52 @@ procedure test() { var x: int := ; assert }; -- Nondet hole in function → should be rejected (not tested here since -- the error occurs at Core translation time, which requires the full pipeline). +/-! ## Holes inside datatype destructor / tester arguments -/ + +-- Hole as argument to a (safe) datatype destructor → typed as the parent +-- datatype, then lifted to a generated `$hole_0` returning that datatype. +-- Regression test for PR #1134: the destructor's `ResolvedNode` carries the +-- parent datatype's resolved Identifier (with `uniqueId`), so this works +-- without textual decoding of the override name. +/-- +info: function $hole_0() + returns ($result: IntList) + opaque; +procedure test() +{ var x: int := IntList..head($hole_0()) }; +-/ +#guard_msgs in +#eval! parseElimAndPrint r" +datatype IntList { Nil(), Cons(head: int, tail: IntList) } +procedure test() { var x: int := IntList..head() }; +" + +-- Hole as argument to an unsafe `!` destructor → same datatype recovery. +/-- +info: function $hole_0() + returns ($result: IntList) + opaque; +procedure test() +{ var x: int := IntList..head!($hole_0()) }; +-/ +#guard_msgs in +#eval! parseElimAndPrint r" +datatype IntList { Nil(), Cons(head: int, tail: IntList) } +procedure test() { var x: int := IntList..head!() }; +" + +-- Hole as argument to a tester → typed as the parent datatype. +/-- +info: function $hole_0() + returns ($result: IntList) + opaque; +procedure test() +{ assert IntList..isCons($hole_0()) }; +-/ +#guard_msgs in +#eval! parseElimAndPrint r" +datatype IntList { Nil(), Cons(head: int, tail: IntList) } +procedure test() { assert IntList..isCons() }; +" + end Laurel diff --git a/StrataTest/Languages/Laurel/ResolutionKindTests.lean b/StrataTest/Languages/Laurel/ResolutionKindTests.lean index acbef556b6..52355edf11 100644 --- a/StrataTest/Languages/Laurel/ResolutionKindTests.lean +++ b/StrataTest/Languages/Laurel/ResolutionKindTests.lean @@ -66,7 +66,7 @@ def typeAsStaticCall := r" composite Foo { } procedure bar() opaque { var x: int := Foo() -// ^^^^^ error: 'Foo' resolves to composite type, but expected parameter, static procedure, datatype constructor, constant +// ^^^^^ error: 'Foo' resolves to composite type, but expected parameter, static procedure, datatype constructor, datatype destructor, constant }; " @@ -97,4 +97,17 @@ composite Foo extends nat { } #guard_msgs (error, drop all) in #eval testInputWithOffset "ExtendConstrained" extendConstrained 90 processResolution +/-! ## Multi-output procedure used in expression position -/ + +def multiOutputInExpr := r" +procedure multi(x: int) returns (a: int, b: int) opaque; +procedure test() opaque { + assert multi(1) == 1 +// ^^^^^^^^ error: Multi-output procedure 'multi' used in expression position +}; +" + +#guard_msgs (error, drop all) in +#eval testInputWithOffset "MultiOutputInExpr" multiOutputInExpr 100 processResolution + end Laurel diff --git a/StrataTest/Languages/Laurel/TestExamples.lean b/StrataTest/Languages/Laurel/TestExamples.lean index 5affbb2813..00d14ae804 100644 --- a/StrataTest/Languages/Laurel/TestExamples.lean +++ b/StrataTest/Languages/Laurel/TestExamples.lean @@ -36,4 +36,17 @@ def processLaurelFileWithOptions (options : LaurelVerifyOptions) (input : InputC def processLaurelFile (input : InputContext) : IO (Array Diagnostic) := processLaurelFileWithOptions default input +/-- Path to the directory for intermediate files, inside the build directory. + Resolved from the current working directory so it works on any machine. -/ +def buildDir : IO String := do + let cwd ← IO.currentDir + return s!"{cwd}/.lake/build/intermediatePrograms/" + +/-- Debug helper: run the Laurel pipeline keeping intermediate pass outputs in `.lake/build/intermediatePrograms/`. + Not used by any test in this repo; invoke manually via `#eval processLaurelFileKeepIntermediates (stringInputContext …)` + when diagnosing pass-internal issues. -/ +def processLaurelFileKeepIntermediates (input : InputContext) : IO (Array Diagnostic) := do + let dir ← buildDir + processLaurelFileWithOptions { translateOptions := { keepAllFilesPrefix := dir}} input + end Laurel diff --git a/StrataTest/Languages/Python/PySpecArgTypeTest.lean b/StrataTest/Languages/Python/PySpecArgTypeTest.lean index d921aaab9c..21f8e48e21 100644 --- a/StrataTest/Languages/Python/PySpecArgTypeTest.lean +++ b/StrataTest/Languages/Python/PySpecArgTypeTest.lean @@ -17,7 +17,7 @@ namespace Strata.Python.PySpecArgTypeTest open Strata.Python.Specs open Strata (buildPySpecLaurel) -open Strata.Python (OverloadTable PythonFunctionDecl PyArgInfo highTypeToPyLauType) +open Strata.Python (ModuleName OverloadTable PythonFunctionDecl PyArgInfo highTypeToPyLauType) open Strata.Laurel (Procedure formatProcedure) private def loc : SourceRange := default @@ -43,10 +43,12 @@ private def buildSpecs (sigs : Array Signature) : IO Strata.PySpecLaurelResult : IO.FS.withTempDir fun dir => do let ionFile := dir / "test.pyspec.ion" writeDDM ionFile sigs - let result ← buildPySpecLaurel #[("", ionFile.toString)] {} |>.toBaseIO - match result with + let ctx ← Strata.Pipeline.PipelineContext.create + match ← (buildPySpecLaurel ctx #[(.ofComponent (.ofString "test"), ionFile.toString)] {}).toBaseIO with | .ok r => pure r - | .error msg => throw <| .userError msg + | .error () => + let msgs ← ctx.getMessages + throw <| .userError s!"buildPySpecLaurel failed: {msgs.map toString}" private def getFuncSigs (sigs : Array Signature) : IO (List PythonFunctionDecl) := do return (← buildSpecs sigs).functionSignatures @@ -55,10 +57,10 @@ private def unionType (elts : Array SpecType) : SpecType := SpecType.unionArray loc elts /-- -info: typed_func: x=Any[], y=Any[], z=Any[], w=Any[] -untyped_func: a=Any[] -mixed_func: p=Any[], q=Any[] -optional_func: s=Any[], n=Any[] +info: test_typed_func: x=Any[], y=Any[], z=Any[], w=Any[] +test_untyped_func: a=Any[] +test_mixed_func: p=Any[], q=Any[] +test_optional_func: s=Any[], n=Any[] -/ #guard_msgs in #eval do @@ -94,7 +96,7 @@ the pyspec Laurel procedure body contains the type assertions generated by preconditions redundant. -/ /-- -info: procedure typed_func(x: Any, y: Any): Any +info: procedure test_typed_func(x: Any, y: Any): Any opaque modifies * { result := ; assert Any..isfrom_int(x); assert Any..isfrom_str(y); assume Any..isfrom_float(result) }; @@ -108,8 +110,8 @@ info: procedure typed_func(x: Any, y: Any): Any (identType .builtinsFloat) ] let procs := result.laurelProgram.staticProcedures - let some proc := procs.find? (fun (p : Procedure) => p.name.text == "typed_func") - | throw <| IO.userError "typed_func not found" + let some proc := procs.find? (fun (p : Procedure) => p.name.text == "test_typed_func") + | throw <| IO.userError "test_typed_func not found" IO.println (toString (formatProcedure proc)) end Strata.Python.PySpecArgTypeTest diff --git a/StrataTest/Languages/Python/TestExamples.lean b/StrataTest/Languages/Python/TestExamples.lean index 559534dc43..1a9dfa4197 100644 --- a/StrataTest/Languages/Python/TestExamples.lean +++ b/StrataTest/Languages/Python/TestExamples.lean @@ -38,10 +38,16 @@ def withPythonToLaurel (pythonCmd : System.FilePath) (input : InputContext) let exitCode ← child.wait if exitCode ≠ 0 then throw <| .userError s!"py_to_strata failed (exit code {exitCode}): {stderr}" - match ← pythonAndSpecToLaurel ionFile.toString - (sourcePath := some pyFile.toString) |>.toBaseIO with + let pctx ← Pipeline.PipelineContext.create (outputMode := .quiet) + match ← (pythonAndSpecToLaurel ionFile.toString + (sourcePath := some pyFile.toString)).run pctx |>.toBaseIO with | .ok r => k r pyFile - | .error err => throw <| .userError s!"pythonAndSpecToLaurel failed: {err}" + | .error () => + let msgs ← pctx.getMessages + let detail := match msgs.back? with + | some m => m.message + | none => "Pipeline aborted" + throw <| .userError s!"pythonAndSpecToLaurel failed: {detail}" /-- Run the Python → Ion → Laurel pipeline and return the Laurel program. The caller can inspect the Laurel IR directly or continue to Core/SMT. -/ diff --git a/StrataTest/Languages/Python/ToLaurelTest.lean b/StrataTest/Languages/Python/ToLaurelTest.lean index 534056eafa..bbf7e437a5 100644 --- a/StrataTest/Languages/Python/ToLaurelTest.lean +++ b/StrataTest/Languages/Python/ToLaurelTest.lean @@ -15,17 +15,41 @@ signatures into Laurel programs. namespace Strata.Python.Specs.ToLaurel.Tests +open Strata.Python (ModuleName) open Strata.Python.Specs open Strata.Laurel /-! ## Test Infrastructure -/ +private def testModule : ModuleName := .ofComponent (.ofString "test") + private def assertEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do unless actual == expected do throw <| .userError s!"expected: {expected}\n actual: {actual}" private def loc : SourceRange := default +private def identType (nm : PythonIdent) : SpecType := + SpecType.ident default nm + +private def noneType : SpecType := SpecType.noneType default + +private def mkUnion (types : Array SpecType) := SpecType.unionArray loc types + +private def mkArg (name : String) (type : SpecType) (default : Option SpecDefault := none) : Arg := + { name, type, default := default } + +private def mkFuncSig (name : String) (returnType : SpecType) + (args : Array Arg := #[]) (kwonly : Array Arg := #[]) + : Signature := + .functionDecl { + loc := loc, nameLoc := loc, name := name + args := { args := args, kwonly := kwonly } + returnType := returnType + isOverload := false + preconditions := #[], postconditions := #[] + } + /-! ### Output Formatting -/ private def fmtHighType : HighType → String @@ -69,33 +93,261 @@ private def fmtTypeDef : TypeDefinition → String /-- Run signaturesToLaurel and print formatted output. Prints warnings (if any) before procedures so `#guard_msgs` can verify them. -/ -private def runTest (sigs : Array Signature) (modulePrefix : String := "") : IO Unit := do - let result := signaturesToLaurel "" sigs modulePrefix +private def runTest (sigs : Array Signature) (moduleName : ModuleName := testModule) : IO Unit := do + let result := signaturesToLaurel "" sigs moduleName for err in result.errors do - IO.println s!"warning: {err.kind.phase}.{err.kind.category}: {err.message}" + IO.println s!"warning: {err.phase}.{err.kind.category}: {err.message}" for td in result.program.types do IO.println (fmtTypeDef td) for proc in result.program.staticProcedures do IO.println (fmtProc proc) /-- Run signaturesToLaurel expecting errors. Print error messages. -/ -private def runTestErrors (sigs : Array Signature) (modulePrefix : String := "") : IO Unit := do - let result := signaturesToLaurel "" sigs modulePrefix +private def runTestErrors (sigs : Array Signature) (moduleName : ModuleName := testModule) : IO Unit := do + let result := signaturesToLaurel "" sigs moduleName assert! result.errors.size > 0 for err in result.errors do IO.println err.message /-- Run signaturesToLaurel and print warning kinds (phase.category: message). -/ -private def runTestWarningKinds (sigs : Array Signature) (modulePrefix : String := "") : IO Unit := do - let result := signaturesToLaurel "" sigs modulePrefix +private def runTestWarningKinds (sigs : Array Signature) (moduleName : ModuleName := testModule) : IO Unit := do + let result := signaturesToLaurel "" sigs moduleName assert! result.errors.size > 0 for err in result.errors do - IO.println s!"{err.kind.phase}.{err.kind.category}: {err.message}" + IO.println s!"{err.phase}.{err.kind.category}: {err.message}" + +/-- Helper to make a function signature with preconditions. -/ +private def mkFuncSigWithPrecond (name : String) (returnType : SpecType) + (preconditions : Array Assertion) (args : Array Arg := #[]) : Signature := + .functionDecl { + loc := loc, nameLoc := loc, name := name + args := { args := args, kwonly := #[] } + returnType := returnType + isOverload := false + preconditions := preconditions, postconditions := #[] + } + +/-- Helper to make a function signature with postconditions. -/ +private def mkFuncSigWithPostcond (name : String) (returnType : SpecType) + (postconditions : Array SpecExpr) : Signature := + .functionDecl { + loc := loc, nameLoc := loc, name := name + args := { args := #[], kwonly := #[] } + returnType := returnType + isOverload := false + preconditions := #[], postconditions := postconditions + } + + +/-! ## All function params and returns map to Any -/ + +/-- +info: procedure test_returns_int(x:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_returns_bool(a:UserDefined(Any), b:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_returns_real(flag:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_with_kwonly(x:UserDefined(Any), verbose:UserDefined(Any)) returns(result:UserDefined(Any)) +-/ +#guard_msgs in +#eval runTest #[ + mkFuncSig "returns_int" (identType .builtinsInt) + (args := #[mkArg "x" (identType .builtinsStr)]), + mkFuncSig "returns_bool" (identType .builtinsBool) + (args := #[mkArg "a" (identType .builtinsInt), + mkArg "b" (identType .builtinsFloat)]), + mkFuncSig "returns_real" (identType .builtinsFloat) + (args := #[mkArg "flag" (identType .builtinsBool)]), + mkFuncSig "with_kwonly" (identType .builtinsStr) + (args := #[mkArg "x" (identType .builtinsInt)]) + (kwonly := #[mkArg "verbose" (identType .builtinsBool) (default := some .none)]) +] + +/-! ## Complex types (Any, List, Dict, bytes) -/ + +/-- +info: procedure test_takes_any(x:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_takes_list(items:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_returns_dict() returns(result:UserDefined(Any)) +procedure test_typed_list() returns(result:UserDefined(Any)) +procedure test_typed_dict() returns(result:UserDefined(Any)) +-/ +#guard_msgs in +#eval runTest #[ + mkFuncSig "takes_any" (identType .builtinsInt) + (args := #[mkArg "x" (identType .typingAny)]), + mkFuncSig "takes_list" (identType .builtinsBool) + (args := #[mkArg "items" (identType .typingList)]), + mkFuncSig "returns_dict" (identType .typingDict), + mkFuncSig "typed_list" + (SpecType.ident loc .typingList #[identType .builtinsStr]), + mkFuncSig "typed_dict" + (SpecType.ident loc .typingDict + #[identType .builtinsStr, identType .builtinsInt]) +] + +/-! ## Literal types, TypedDict, and string-literal unions → Any -/ + +/-- +info: warning: pySpecToLaurel.unsupportedUnion: TypedDict 'TypedDict(f : builtins.str)' approximated as DictStrAny in type 'TypedDict(f : builtins.str)' +procedure test_int_literal_ret() returns(result:UserDefined(Any)) +procedure test_str_literal_ret() returns(result:UserDefined(Any)) +procedure test_typed_dict_ret() returns(result:UserDefined(Any)) +procedure test_str_enum() returns(result:UserDefined(Any)) +-/ +#guard_msgs in +#eval runTest #[ + mkFuncSig "int_literal_ret" (SpecType.intLiteral loc 42), + mkFuncSig "str_literal_ret" + (SpecType.stringLiteral loc "hello"), + mkFuncSig "typed_dict_ret" + (SpecType.typedDict loc #["f"] + #[identType .builtinsStr] #[true]), + mkFuncSig "str_enum" + (mkUnion #[SpecType.stringLiteral loc "A", SpecType.stringLiteral loc "B", + SpecType.stringLiteral loc "C"]) +] + +/-! ## Optional type patterns (Union[None, T]) → Any -/ + +/-- +info: warning: pySpecToLaurel.unsupportedUnion: TypedDict 'TypedDict(x : builtins.str)' approximated as DictStrAny in type 'Union[_types.NoneType, TypedDict(x : builtins.str)]' +procedure test_opt_str() returns(result:UserDefined(Any)) +procedure test_opt_int() returns(result:UserDefined(Any)) +procedure test_opt_bool(x:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_opt_typed_dict() returns(result:UserDefined(Any)) +procedure test_opt_str_enum() returns(result:UserDefined(Any)) +procedure test_opt_int_enum() returns(result:UserDefined(Any)) +-/ +#guard_msgs in +#eval runTest #[ + mkFuncSig "opt_str" + (mkUnion #[noneType, identType .builtinsStr]), + mkFuncSig "opt_int" + (mkUnion #[noneType, identType .builtinsInt]), + mkFuncSig "opt_bool" + (mkUnion #[noneType, identType .builtinsBool]) + (args := #[mkArg "x" + (mkUnion #[noneType, identType .builtinsStr])]), + mkFuncSig "opt_typed_dict" + (mkUnion #[noneType, + SpecType.typedDict loc #["x"] #[identType .builtinsStr] #[true]]), + mkFuncSig "opt_str_enum" + (mkUnion #[noneType, SpecType.stringLiteral loc "A", + SpecType.stringLiteral loc "B"]), + mkFuncSig "opt_int_enum" + (mkUnion #[noneType, SpecType.intLiteral loc 1, SpecType.intLiteral loc 2]) +] + +/-! ## Error cases (updated to verify MessageKind) -/ + +/-- +info: procedure test_f() returns(result:UserDefined(Any)) +-/ +#guard_msgs in +#eval runTest + #[mkFuncSig "f" + (identType (PythonIdent.ofComponent "foo" "Bar"))] + +/-- +info: procedure test_f() returns(result:UserDefined(Any)) +-/ +#guard_msgs in +#eval runTest + #[mkFuncSig "f" + (mkUnion #[identType .builtinsStr, + identType .builtinsInt])] + +/-- +info: warning: pySpecToLaurel.unsupportedUnion: No type tester for 'foo.Bar' in type 'Union[_types.NoneType, foo.Bar]' +procedure test_f() returns(result:UserDefined(Any)) +-/ +#guard_msgs in +#eval runTest + #[mkFuncSig "f" + (mkUnion #[noneType, + identType (PythonIdent.ofComponent "foo" "Bar")])] + +/-! ## Class and type definitions -/ + +/-- +info: type test_MyClass +type test_MyAlias +procedure test_my_func(x:UserDefined(Any), y:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_MyClass@get_value() returns(result:UserDefined(Any)) +-/ +#guard_msgs in +#eval runTest #[ + mkFuncSig "my_func" (identType .builtinsBool) + (args := #[mkArg "x" (identType .builtinsInt), + mkArg "y" (identType .builtinsStr) (some .none)]), + .classDef { + loc := loc, name := "MyClass" + methods := #[ + { loc := loc, nameLoc := loc, name := "get_value" + args := { args := #[mkArg "self" (identType .builtinsStr)], kwonly := #[] } + returnType := identType .builtinsStr + isOverload := false + preconditions := #[] + postconditions := #[] } + ] + }, + .typeDef { + loc := loc, nameLoc := loc + name := "MyAlias" + definition := identType .builtinsStr + } +] + +/-! ## NoneType and void return -/ + +/-- +info: procedure test_returns_none() returns(result:UserDefined(Any)) +procedure test_takes_none(x:UserDefined(Any)) returns(result:UserDefined(Any)) +-/ +#guard_msgs in +#eval runTest #[ + mkFuncSig "returns_none" noneType, + mkFuncSig "takes_none" noneType + (args := #[mkArg "x" noneType]) +] + +/-! ## Class types as UserDefined -/ + +/-- +info: type test_Foo +procedure test_uses_class(x:UserDefined(test_Foo)) returns(result:UserDefined(Any)) +-/ +#guard_msgs in +#eval runTest #[ + .classDef { + loc := loc, name := "Foo" + methods := #[] + }, + mkFuncSig "uses_class" (identType (.mkRaw testModule "Foo")) + (args := #[mkArg "x" (identType (.mkRaw testModule "Foo"))]) +] + +/-! ## Empty input -/ + +#guard_msgs in +#eval runTest #[] + +/-! ## Overload dispatch and method registry -/ + +/-- Helper to make an @overload function signature. -/ +private def mkOverload (name : String) (returnType : SpecType) + (args : Array Arg := #[]) : Signature := + .functionDecl { + loc := loc, nameLoc := loc, name := name + args := { args := args, kwonly := #[] } + returnType := returnType + isOverload := true + preconditions := #[], postconditions := #[] + } /-- Run signaturesToLaurel and print the full result: Laurel output, dispatch table, and method registry. Sorts by key for stable output. -/ -private def runFullTest (sigs : Array Signature) (modulePrefix : String := "") : IO Unit := do - let result := signaturesToLaurel "" sigs modulePrefix +private def runFullTest (sigs : Array Signature) (moduleName : ModuleName := testModule) : IO Unit := do + let result := signaturesToLaurel "" sigs moduleName if result.errors.size > 0 then IO.println s!"errors: {result.errors.size}" for err in result.errors do @@ -144,9 +396,8 @@ private def list_ := SpecType.ident loc .typingList private def dict_ := SpecType.ident loc .typingDict private def listOf (t : SpecType) := SpecType.ident loc .typingList #[t] private def dictOf (k v : SpecType) := SpecType.ident loc .typingDict #[k, v] -private def mkUnion (types : Array SpecType) := SpecType.unionArray loc types -private def pyClass (name : String) := SpecType.ident loc (PythonIdent.mk "" name) -private def externIdent (mod name : String) := PythonIdent.mk mod name +private def pyClass (name : String) := SpecType.ident loc (.mkRaw testModule name) +private def externIdent (mod name : String) := PythonIdent.mkRaw (.ofString! mod) name private def arg (name : String) (type : SpecType) (default : Option SpecDefault := none) : Arg := { name, type, default := default } @@ -194,10 +445,10 @@ private def externType (name : String) (ident : PythonIdent) : Signature := /-! ## All function params and returns map to Any -/ /-- -info: procedure returns_int(x:UserDefined(Any)) returns(result:UserDefined(Any)) -procedure returns_bool(a:UserDefined(Any), b:UserDefined(Any)) returns(result:UserDefined(Any)) -procedure returns_real(flag:UserDefined(Any)) returns(result:UserDefined(Any)) -procedure with_kwonly(x:UserDefined(Any), verbose:UserDefined(Any)) returns(result:UserDefined(Any)) +info: procedure test_returns_int(x:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_returns_bool(a:UserDefined(Any), b:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_returns_real(flag:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_with_kwonly(x:UserDefined(Any), verbose:UserDefined(Any)) returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest #[ @@ -212,11 +463,11 @@ procedure with_kwonly(x:UserDefined(Any), verbose:UserDefined(Any)) returns(resu /-! ## Complex types (Any, List, Dict, bytes) -/ /-- -info: procedure takes_any(x:UserDefined(Any)) returns(result:UserDefined(Any)) -procedure takes_list(items:UserDefined(Any)) returns(result:UserDefined(Any)) -procedure returns_dict() returns(result:UserDefined(Any)) -procedure typed_list() returns(result:UserDefined(Any)) -procedure typed_dict() returns(result:UserDefined(Any)) +info: procedure test_takes_any(x:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_takes_list(items:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_returns_dict() returns(result:UserDefined(Any)) +procedure test_typed_list() returns(result:UserDefined(Any)) +procedure test_typed_dict() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest #[ @@ -231,10 +482,10 @@ procedure typed_dict() returns(result:UserDefined(Any)) /-- info: warning: pySpecToLaurel.unsupportedUnion: TypedDict 'TypedDict(f : builtins.str)' approximated as DictStrAny in type 'TypedDict(f : builtins.str)' -procedure int_literal_ret() returns(result:UserDefined(Any)) -procedure str_literal_ret() returns(result:UserDefined(Any)) -procedure typed_dict_ret() returns(result:UserDefined(Any)) -procedure str_enum() returns(result:UserDefined(Any)) +procedure test_int_literal_ret() returns(result:UserDefined(Any)) +procedure test_str_literal_ret() returns(result:UserDefined(Any)) +procedure test_typed_dict_ret() returns(result:UserDefined(Any)) +procedure test_str_enum() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest #[ @@ -250,12 +501,12 @@ procedure str_enum() returns(result:UserDefined(Any)) /-- info: warning: pySpecToLaurel.unsupportedUnion: TypedDict 'TypedDict(x : builtins.str)' approximated as DictStrAny in type 'Union[_types.NoneType, TypedDict(x : builtins.str)]' -procedure opt_str() returns(result:UserDefined(Any)) -procedure opt_int() returns(result:UserDefined(Any)) -procedure opt_bool(x:UserDefined(Any)) returns(result:UserDefined(Any)) -procedure opt_typed_dict() returns(result:UserDefined(Any)) -procedure opt_str_enum() returns(result:UserDefined(Any)) -procedure opt_int_enum() returns(result:UserDefined(Any)) +procedure test_opt_str() returns(result:UserDefined(Any)) +procedure test_opt_int() returns(result:UserDefined(Any)) +procedure test_opt_bool(x:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_opt_typed_dict() returns(result:UserDefined(Any)) +procedure test_opt_str_enum() returns(result:UserDefined(Any)) +procedure test_opt_int_enum() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest #[ @@ -275,14 +526,14 @@ procedure opt_int_enum() returns(result:UserDefined(Any)) /-! ## Error cases (updated to verify WarningKind) -/ /-- -info: procedure f() returns(result:UserDefined(Any)) +info: procedure test_f() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest - #[func "f" (SpecType.ident loc (PythonIdent.mk "foo" "Bar"))] + #[func "f" (SpecType.ident loc (PythonIdent.ofComponent "foo" "Bar"))] /-- -info: procedure f() returns(result:UserDefined(Any)) +info: procedure test_f() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest @@ -290,20 +541,20 @@ info: procedure f() returns(result:UserDefined(Any)) /-- info: warning: pySpecToLaurel.unsupportedUnion: No type tester for 'foo.Bar' in type 'Union[_types.NoneType, foo.Bar]' -procedure f() returns(result:UserDefined(Any)) +procedure test_f() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest #[func "f" - (mkUnion #[none_, SpecType.ident loc (PythonIdent.mk "foo" "Bar")])] + (mkUnion #[none_, SpecType.ident loc (PythonIdent.ofComponent "foo" "Bar")])] /-! ## Class and type definitions -/ /-- -info: type MyClass -type MyAlias -procedure my_func(x:UserDefined(Any), y:UserDefined(Any)) returns(result:UserDefined(Any)) -procedure MyClass@get_value() returns(result:UserDefined(Any)) +info: type test_MyClass +type test_MyAlias +procedure test_my_func(x:UserDefined(Any), y:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_MyClass@get_value() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest #[ @@ -315,8 +566,8 @@ procedure MyClass@get_value() returns(result:UserDefined(Any)) /-! ## NoneType and void return -/ /-- -info: procedure returns_none() returns(result:UserDefined(Any)) -procedure takes_none(x:UserDefined(Any)) returns(result:UserDefined(Any)) +info: procedure test_returns_none() returns(result:UserDefined(Any)) +procedure test_takes_none(x:UserDefined(Any)) returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest #[ @@ -327,8 +578,8 @@ procedure takes_none(x:UserDefined(Any)) returns(result:UserDefined(Any)) /-! ## Class types as UserDefined -/ /-- -info: type Foo -procedure uses_class(x:UserDefined(Foo)) returns(result:UserDefined(Any)) +info: type test_Foo +procedure test_uses_class(x:UserDefined(test_Foo)) returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest #[ @@ -347,9 +598,9 @@ procedure uses_class(x:UserDefined(Foo)) returns(result:UserDefined(Any)) -- overloads dispatching on string literals, a service class with methods, -- and a regular function. /-- -info: type SvcClient -procedure SvcClient@do_thing(x:UserDefined(Any)) returns(result:UserDefined(Any)) -procedure helper() returns(result:UserDefined(Any)) +info: type test_SvcClient +procedure test_SvcClient@do_thing(x:UserDefined(Any)) returns(result:UserDefined(Any)) +procedure test_helper() returns(result:UserDefined(Any)) dispatch create_client: "svc_a" -> mod.client.SvcClient "svc_b" -> mod.other.OtherClient @@ -370,11 +621,11 @@ dispatch create_client: -- Overloads with locally-defined class return types. /-- -info: type Alpha -type Beta +info: type test_Alpha +type test_Beta dispatch make: - "a" -> .Alpha - "b" -> .Beta + "a" -> test.Alpha + "b" -> test.Beta -/ #guard_msgs in #eval runFullTest #[ @@ -439,7 +690,7 @@ body contains FieldSelect: false (.intLit 0 loc) loc }]) - ] "" + ] testModule assert! result.errors.size = 0 match result.program.staticProcedures with | proc :: _ => @@ -455,21 +706,21 @@ body contains FieldSelect: false -- bytes, bytearray, complex now map to Any (matching PythonToLaurel) /-- -info: procedure f() returns(result:UserDefined(Any)) +info: procedure test_f() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest #[func "f" bytes] /-- -info: procedure f() returns(result:UserDefined(Any)) +info: procedure test_f() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest #[func "f" bytearray] /-- -info: procedure f() returns(result:UserDefined(Any)) +info: procedure test_f() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest @@ -477,35 +728,35 @@ info: procedure f() returns(result:UserDefined(Any)) -- Optional patterns now map to Any without warnings /-- -info: procedure f() returns(result:UserDefined(Any)) +info: procedure test_f() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest #[func "f" (mkUnion #[none_, float_])] /-- -info: procedure f() returns(result:UserDefined(Any)) +info: procedure test_f() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest #[func "f" (mkUnion #[none_, list_])] /-- -info: procedure f() returns(result:UserDefined(Any)) +info: procedure test_f() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest #[func "f" (mkUnion #[none_, dict_])] /-- -info: procedure f() returns(result:UserDefined(Any)) +info: procedure test_f() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest #[func "f" (mkUnion #[none_, any])] /-- -info: procedure f() returns(result:UserDefined(Any)) +info: procedure test_f() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest @@ -590,7 +841,7 @@ info: pySpecToLaurel.kwargsExpansionError: **kw has non-TypedDict type; kwargs n -- Declaration: postconditions now translated (no warning) /-- -info: procedure f() returns(result:UserDefined(Any)) +info: procedure test_f() returns(result:UserDefined(Any)) -/ #guard_msgs in #eval runTest @@ -671,7 +922,7 @@ private def translatePrecondResult (preconditions : Array Assertion) args := { args, kwonly := #[] } returnType := str, isOverload := false preconditions, postconditions := #[] - }] "" + }] testModule /-- Translate a single function with preconditions and return `(bodyString, errorCount)`. -/ @@ -717,7 +968,7 @@ private def translatePrecond (preconditions : Array Assertion) preconditions := #[{ message := #[], formula := .containsKey (.var "kwargs" loc) "key" loc }] - postconditions := #[] }] "" + postconditions := #[] }] testModule let body := getBody result |>.getD "" assertEq result.errors.size 0 assert! body.contains "result := " @@ -768,7 +1019,7 @@ private def translateFunc (args : Array Arg := #[]) args := { args := args, kwonly := #[] } returnType, isOverload := false preconditions, postconditions - }] "" + }] testModule (getBody result |>.getD "", result.errors.size) -- No args, no preconditions: body has havoc + return type assume @@ -804,7 +1055,7 @@ private def translateFunc (args : Array Arg := #[]) -- Composite return type: no assume (no tester for user-defined types) #eval do let (body, errs) := translateFunc - (returnType := SpecType.ident loc (PythonIdent.mk "mod" "Cls")) + (returnType := SpecType.ident loc (PythonIdent.ofComponent "mod" "Cls")) assert! errs == 0 assert! !body.contains "assume" diff --git a/StrataTest/Languages/Python/expected_laurel/test_procedure_in_assert.expected b/StrataTest/Languages/Python/expected_laurel/test_procedure_in_assert.expected index 8acc805d68..8d71e8b122 100644 --- a/StrataTest/Languages/Python/expected_laurel/test_procedure_in_assert.expected +++ b/StrataTest/Languages/Python/expected_laurel/test_procedure_in_assert.expected @@ -1,10 +1,11 @@ -test_procedure_in_assert.py(8, 4): ✅ pass - assert(311) -test_procedure_in_assert.py(9, 4): ✅ pass - (Origin_timedelta_Requires) -test_procedure_in_assert.py(9, 4): ✅ pass - (Origin_timedelta_Requires)hours_type -test_procedure_in_assert.py(9, 4): ✅ pass - (Origin_timedelta_Requires)days_pos -test_procedure_in_assert.py(9, 4): ✅ pass - (Origin_timedelta_Requires)hours_pos -test_procedure_in_assert.py(10, 4): ✅ pass - assert(361) -test_procedure_in_assert.py(11, 4): ✅ pass - should pass +test_procedure_in_assert.py(4, 4): ✅ pass - assert(55) +test_procedure_in_assert.py(5, 4): ✅ pass - (Origin_timedelta_Requires) +test_procedure_in_assert.py(5, 4): ✅ pass - (Origin_timedelta_Requires)hours_type +test_procedure_in_assert.py(5, 4): ✅ pass - (Origin_timedelta_Requires)days_pos +test_procedure_in_assert.py(5, 4): ✅ pass - (Origin_timedelta_Requires)hours_pos +test_procedure_in_assert.py(5, 17): ✅ pass - Check PSub exception +test_procedure_in_assert.py(6, 4): ✅ pass - assert(117) +test_procedure_in_assert.py(7, 4): ✅ pass - should pass test_procedure_in_assert.py(3, 14): ✅ pass - (main ensures) Return type constraint -DETAIL: 8 passed, 0 failed, 0 inconclusive +DETAIL: 9 passed, 0 failed, 0 inconclusive RESULT: Analysis success diff --git a/StrataTest/Languages/Python/run_py_analyze.sh b/StrataTest/Languages/Python/run_py_analyze.sh index 6c271661b5..faf632dbd2 100755 --- a/StrataTest/Languages/Python/run_py_analyze.sh +++ b/StrataTest/Languages/Python/run_py_analyze.sh @@ -101,6 +101,43 @@ for test_file in tests/test_*.py; do fi done +# --- --metrics integration test --- +# Run one test file with --metrics and validate the JSONL output. +metrics_test_file=$(ls tests/test_*.py 2>/dev/null | head -1) +if [ -n "$metrics_test_file" ] && [ -z "$filter" ]; then + metrics_base=$(basename "$metrics_test_file" .py) + metrics_ion="tests/${metrics_base}.python.st.ion" + metrics_out=$(mktemp) + # Ion file should already exist from the loop above + if [ -f "$metrics_ion" ]; then + (cd ../../.. && ./.lake/build/bin/strata $command --metrics "$metrics_out" "StrataTest/Languages/Python/${metrics_ion}" 2>/dev/null) || true + if [ ! -s "$metrics_out" ]; then + echo "ERROR: --metrics file is empty for $metrics_base" + failed=1 + else + bad_lines=0 + while IFS= read -r line; do + [ -z "$line" ] && continue + if ! echo "$line" | python3 -c "import sys,json; d=json.load(sys.stdin); assert 'type' in d" 2>/dev/null; then + echo "ERROR: --metrics invalid JSON line: $line" + bad_lines=$((bad_lines + 1)) + fi + done < "$metrics_out" + # Check that an outcome record exists + if ! grep -q '"outcome"' "$metrics_out"; then + echo "ERROR: --metrics missing outcome record for $metrics_base" + failed=1 + elif [ $bad_lines -gt 0 ]; then + echo "ERROR: --metrics has $bad_lines invalid lines for $metrics_base" + failed=1 + else + echo "Test passed: --metrics JSONL ($metrics_base)" + fi + fi + fi + rm -f "$metrics_out" +fi + if [ $pending -eq 1 ]; then for test_file in tests/pending/test_*.py; do [ -f "$test_file" ] || continue diff --git a/StrataTest/Languages/Python/tests/test_procedure_in_assert.py b/StrataTest/Languages/Python/tests/test_procedure_in_assert.py index 6fb2a194a2..2ee6e652a7 100644 --- a/StrataTest/Languages/Python/tests/test_procedure_in_assert.py +++ b/StrataTest/Languages/Python/tests/test_procedure_in_assert.py @@ -1,12 +1,8 @@ from datetime import timedelta def main() -> int: - # Test that a procedure call (timedelta_func) can appear in an - # assignment whose result is then used in an assert. - # The call is in assignment position (not expression position), - # which is the correct pattern for multi-output procedures. base: int = 100 - delta = timedelta(days=7) + delta: Any = base - timedelta(days=7) result: int = 1 assert result == 1, "should pass" return result diff --git a/StrataTest/Pipeline/PhaseTimingTest.lean b/StrataTest/Pipeline/PhaseTimingTest.lean new file mode 100644 index 0000000000..6a9955fdc0 --- /dev/null +++ b/StrataTest/Pipeline/PhaseTimingTest.lean @@ -0,0 +1,131 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ +module + +meta import Strata.Pipeline.Context + +/-! ## Phase timing tests + +Exercises nesting of `withPhase` and `withRepeatedPhase` to validate that: +1. Repeated phases aggregate without crashing. +2. Messages emitted inside nested phases get the correct phase path. +3. Nested `withRepeatedPhase` inside `withRepeatedPhase` does not corrupt + the parent's aggregation map. +4. `withRepeatedPhasePure` evaluates its expression. +-/ + +open Strata.Pipeline + +meta def mkCtx : BaseIO PipelineContext := + PipelineContext.create (outputMode := .quiet) (profilePipeline := false) + +meta def check (cond : Bool) (msg : String) : IO Unit := + unless cond do throw <| IO.userError msg + +/-! ### Test 1: withRepeatedPhase aggregates without error -/ + +#guard_msgs in +#eval show IO Unit from do + let ctx ← mkCtx + ctx.withPhase "outer" (m := IO) do + for _ in List.range 5 do + ctx.withRepeatedPhase "iter" (m := IO) do + pure () + +/-! ### Test 2: withPhase nested inside withRepeatedPhase runs correctly -/ + +#guard_msgs in +#eval show IO Unit from do + let ctx ← mkCtx + ctx.withPhase "outer" (m := IO) do + for _ in List.range 3 do + ctx.withRepeatedPhase "iter" (m := IO) do + ctx.withPhase "inner" (m := IO) do + pure () + +/-! ### Test 3: Messages inside withRepeatedPhase get correct phase tag -/ + +#guard_msgs in +#eval show IO Unit from do + let ctx ← mkCtx + let pipelineAction : PipelineM Unit := do + Strata.Pipeline.withPhase "outer" do + for _ in List.range 2 do + let ctx ← read + ctx.withRepeatedPhase "iter" do + emitMessage .laurelLoweringNotImpl "test warning" + let _ ← pipelineAction.run ctx |>.toBaseIO + let msgs ← ctx.getMessages + check (msgs.size == 2) s!"Expected 2 messages, got {msgs.size}" + let expectedPhase := Phase.base "outer" |>.subphase "iter" + for msg in msgs do + check (msg.phase == expectedPhase) + s!"Expected phase '{expectedPhase}', got '{msg.phase}'" + +/-! ### Test 4: Nested withRepeatedPhase does not corrupt parent -/ + +#guard_msgs in +#eval show IO Unit from do + let ctx ← mkCtx + ctx.withPhase "outer" (m := IO) do + for _ in List.range 4 do + ctx.withRepeatedPhase "a" (m := IO) do + for _ in List.range 2 do + ctx.withRepeatedPhase "b" (m := IO) do + pure () + +/-! ### Test 5: Multiple distinct withPhase inside withRepeatedPhase -/ + +#guard_msgs in +#eval show IO Unit from do + let ctx ← mkCtx + ctx.withPhase "outer" (m := IO) do + for _ in List.range 3 do + ctx.withRepeatedPhase "iter" (m := IO) do + ctx.withPhase "preprocess" (m := IO) do pure () + ctx.withPhase "solve" (m := IO) do pure () + +/-! ### Test 6: Messages inside nested withPhase get deepest phase path -/ + +#guard_msgs in +#eval show IO Unit from do + let ctx ← mkCtx + let pipelineAction : PipelineM Unit := do + Strata.Pipeline.withPhase "parent" do + let ctx ← read + ctx.withRepeatedPhase "iter" do + Strata.Pipeline.withPhase "child" do + emitMessage .laurelLoweringNotImpl "deep msg" + let _ ← pipelineAction.run ctx |>.toBaseIO + let msgs ← ctx.getMessages + check (msgs.size == 1) s!"Expected 1 message, got {msgs.size}" + let expectedPhase := Phase.base "parent" |>.subphase "iter" |>.subphase "child" + match msgs[0]? with + | some msg => + check (msg.phase == expectedPhase) + s!"Expected phase '{expectedPhase}', got '{msg.phase}'" + | none => throw <| IO.userError "unreachable" + +/-! ### Test 7: withRepeatedPhasePure evaluates expression -/ + +/-- +info: withRepeatedPhasePure: evaluating +withRepeatedPhasePure: evaluating +withRepeatedPhasePure: evaluating +withRepeatedPhasePure: evaluating +-/ +#guard_msgs in +#eval show IO Unit from do + let ctx ← mkCtx + let evalRef ← IO.mkRef (0 : Nat) + ctx.withPhase "outer" (m := IO) do + for _ in List.range 4 do + let _ ← ctx.withRepeatedPhasePure "compute" fun () => + dbg_trace "withRepeatedPhasePure: evaluating" + 42 + evalRef.modify (· + 1) + let count ← evalRef.get + check (count == 4) s!"Expected 4 evaluations, got {count}" diff --git a/StrataTest/Transform/ANFEncoderTests.lean b/StrataTest/Transform/ANFEncoderTests.lean index 835b8230af..f40337acb8 100644 --- a/StrataTest/Transform/ANFEncoderTests.lean +++ b/StrataTest/Transform/ANFEncoderTests.lean @@ -138,4 +138,37 @@ procedure test (x : int, y : int) #guard_msgs in #eval IO.println (toString (anfEncodeProgram (translateCore uniqueSubexprProg)).2) +/-! ## Multi-pass: outer subsumes inner duplicate -/ + +-- `(x + 1) * 2` and `x + 1` are both duplicated, but `x + 1` is a subexpression +-- of `(x + 1) * 2` and so is dropped by `removeSubsumed` in pass 1. After pass +-- 1 lifts `(x + 1) * 2` into a var declaration, `x + 1` appears once in that +-- var-decl init AND once in the third assert, exposing a fresh duplicate that +-- pass 2 then extracts. Without fixpoint iteration the third assert would +-- still hold a free `(x + 1)` that duplicates the var-decl's init. +private def nestedDupProg := +#strata +program Core; +procedure test(x : int) { + assert ((x + 1) * 2 > 0); + assert ((x + 1) * 2 < 100); + assert (x + 1 > 50); +}; +#end + +/-- +info: program Core; + +procedure test (x : int) +{ + var $__anf.1 : int := x + 1; + var $__anf.0 : int := $__anf.1 * 2; + assert [assert_0]: $__anf.0 > 0; + assert [assert_1]: $__anf.0 < 100; + assert [assert_2]: $__anf.1 > 50; +}; +-/ +#guard_msgs in +#eval IO.println (toString (anfEncodeProgram (translateCore nestedDupProg)).2) + end Core.ANFEncoder.Tests diff --git a/StrataTest/Transform/PrecondElim.lean b/StrataTest/Transform/PrecondElim.lean index 651ec8b0de..3242d25afa 100644 --- a/StrataTest/Transform/PrecondElim.lean +++ b/StrataTest/Transform/PrecondElim.lean @@ -418,4 +418,80 @@ procedure test (inout g : int, y : int) #guard_msgs in #eval (Std.format (transformProgram loopGuardPrecondPgm)) +/-! ### Test 10: `collectPrecondAsserts` tags Sequence bounds obligations with `outOfBoundsAccess` + +Exercises the full `collectPrecondAsserts` path — the code called by +`transformStmt` / `mkContractWFProc` / `mkFuncWFProc` — and inspects the +metadata on the generated assert. Mirrors `OverflowCheckTest.lean`. -/ + +section SeqBoundsObligations + +open Strata Core Lambda Core.PrecondElim Imperative + +/-- Shared fvar fixtures so each per-op case below is a one-liner. -/ +private def fxS : Core.Expression.Expr := .fvar () ⟨"s", ()⟩ (some (Core.seqTy .int)) +private def fxI : Core.Expression.Expr := .fvar () ⟨"i", ()⟩ (some .int) +private def fxV : Core.Expression.Expr := .fvar () ⟨"v", ()⟩ (some .int) +private def fxN : Core.Expression.Expr := .fvar () ⟨"n", ()⟩ (some .int) +private def fxJ : Core.Expression.Expr := .fvar () ⟨"j", ()⟩ (some .int) + +/-- Check that `collectPrecondAsserts` produces exactly `expectedCount` + obligations from `expr`, each tagged with `outOfBoundsAccess`. -/ +private def assertOutOfBoundsObligations + (expr : Core.Expression.Expr) (expectedCount : Nat) : IO Unit := do + let stmts := collectPrecondAsserts Core.Factory expr "test" #[] + assert! stmts.length == expectedCount + for s in stmts do + let md : MetaData Core.Expression := match s with + | Statement.assert _ _ md => md | _ => #[] + assert! md.getPropertyType == some MetaData.outOfBoundsAccess + +-- Sequence.select / update / take / drop each emit one out-of-bounds obligation. +#eval assertOutOfBoundsObligations (LExpr.mkApp () Core.seqSelectOp [fxS, fxI]) 1 +#eval assertOutOfBoundsObligations (LExpr.mkApp () Core.seqUpdateOp [fxS, fxI, fxV]) 1 +#eval assertOutOfBoundsObligations (LExpr.mkApp () Core.seqTakeOp [fxS, fxN]) 1 +#eval assertOutOfBoundsObligations (LExpr.mkApp () Core.seqDropOp [fxS, fxN]) 1 + +-- Nested: `Sequence.select(Sequence.update(s, i, v), j)` emits two +-- obligations (one per partial call), both tagged `outOfBoundsAccess`. +#eval assertOutOfBoundsObligations + (LExpr.mkApp () Core.seqSelectOp [LExpr.mkApp () Core.seqUpdateOp [fxS, fxI, fxV], fxJ]) 2 + +-- Sequence.length is total: no precondition obligations generated. +#eval do + let stmts := collectPrecondAsserts Core.Factory + (LExpr.mkApp () Core.seqLengthOp [fxS]) "test" #[] + assert! stmts.isEmpty + +/-! #### Test 10a: Pretty-printed obligation shape per partial op + +Catches regressions that preserve count and metadata tag but corrupt the +obligation body (e.g. swapping `.Lt` for `.Le` at a call site, changing +the bound variable name, or swapping the lower/upper bound inside +`mkSeqBoundsPrecond`). -/ + +private def printFirstObligation (expr : Core.Expression.Expr) : IO Unit := do + let stmts := collectPrecondAsserts Core.Factory expr "test" #[] + match stmts.head? with + | some (Statement.assert _ e _) => IO.println s!"{Std.format e}" + | _ => IO.println "" + +/-- info: 0 <= i && i < Sequence.length(s) -/ +#guard_msgs in +#eval printFirstObligation (LExpr.mkApp () Core.seqSelectOp [fxS, fxI]) + +/-- info: 0 <= i && i < Sequence.length(s) -/ +#guard_msgs in +#eval printFirstObligation (LExpr.mkApp () Core.seqUpdateOp [fxS, fxI, fxV]) + +/-- info: 0 <= n && n <= Sequence.length(s) -/ +#guard_msgs in +#eval printFirstObligation (LExpr.mkApp () Core.seqTakeOp [fxS, fxN]) + +/-- info: 0 <= n && n <= Sequence.length(s) -/ +#guard_msgs in +#eval printFirstObligation (LExpr.mkApp () Core.seqDropOp [fxS, fxN]) + +end SeqBoundsObligations + end PrecondElimTests diff --git a/StrataTest/Util/Python.lean b/StrataTest/Util/Python.lean index a1528054f6..f6abf7a1ae 100644 --- a/StrataTest/Util/Python.lean +++ b/StrataTest/Util/Python.lean @@ -68,6 +68,18 @@ def miseWhere (runtime : String) (miseCmd : String := "mise") : IO (Option Syste throw <| .userError msg pure <| some stdout.trimAscii.toString +/-- +info: none +-/ +#guard_msgs in +#eval miseWhere "Python@1.0" + +/-- +info: none +-/ +#guard_msgs in +#eval miseWhere "Python@3.12" (miseCmd := "nonexisting-mise") + /-- This checks to see if a module is found. -/ @@ -97,18 +109,6 @@ def pythonCheckModule (pythonCmd : System.FilePath) (moduleName : String) : IO B throw <| .userError s!"{pythonCmd} has unexpected exit code {exitCode}" -/-- -info: none --/ -#guard_msgs in -#eval miseWhere "Python@1.0" - -/-- -info: none --/ -#guard_msgs in -#eval miseWhere "Python@3.12" (miseCmd := "nonexisting-mise") - /-- Utility to get Python 3 minor version. @@ -208,9 +208,5 @@ def withPython (action : System.FilePath → IO Unit) : IO Unit := do s!"Python Strata libraries not installed in {pythonCmd}." action pythonCmd -/-- Check if `needle` is a substring of `haystack`. -/ -def containsSubstr (haystack needle : String) : Bool := - (haystack.splitOn needle).length != 1 - end Strata.Python end diff --git a/StrataTestExtra/Languages/Python/AnalyzeLaurelTest.lean b/StrataTestExtra/Languages/Python/AnalyzeLaurelTest.lean index 0b2c761a2e..b706b5ede7 100644 --- a/StrataTestExtra/Languages/Python/AnalyzeLaurelTest.lean +++ b/StrataTestExtra/Languages/Python/AnalyzeLaurelTest.lean @@ -23,13 +23,17 @@ Messaging) are generic and not tied to any cloud provider. namespace Strata.Python.AnalyzeLaurelTest open Strata (pythonAndSpecToLaurel pySpecsDir) +open Strata.Pipeline (PipelineContext) -private meta def testDir : System.FilePath := +meta def quietCtx : BaseIO PipelineContext := + PipelineContext.create (outputMode := .quiet) + +meta def testDir : System.FilePath := "StrataTestExtra/Languages/Python/Specs/dispatch_test" /-- Compile a Python source file to a `.python.st.ion` Ion file. Returns the path to the generated Ion file. -/ -private meta def compilePython +meta def compilePython (pythonCmd : System.FilePath) (dialectFile : System.FilePath) (pyFile : System.FilePath) (outDir : System.FilePath) : IO System.FilePath := do @@ -57,7 +61,7 @@ private meta def compilePython /-- Set up the test fixture: compile all servicelib modules and return the spec directory. The dispatch and pyspec modules are resolved by name. -/ -private meta def setupFixture (pythonCmd : System.FilePath) +meta def setupFixture (pythonCmd : System.FilePath) (outDir : System.FilePath) : IO Unit := do IO.FS.withTempFile fun _handle dialectFile => do IO.FS.writeBinFile dialectFile Python.Python.toIon @@ -70,7 +74,7 @@ private meta def setupFixture (pythonCmd : System.FilePath) | .error msg => throw <| IO.userError s!"pySpecsDir failed: {msg}" /-- Compile a test Python file to Ion format. -/ -private meta def compileTestScript (pythonCmd : System.FilePath) +meta def compileTestScript (pythonCmd : System.FilePath) (pyFile : System.FilePath) (outDir : System.FilePath) : IO System.FilePath := do IO.FS.withTempFile fun _handle dialectFile => do @@ -78,17 +82,26 @@ private meta def compileTestScript (pythonCmd : System.FilePath) compilePython pythonCmd dialectFile pyFile outDir /-- Run pyAnalyzeLaurel on a test script within the shared fixture. -/ -private meta def runAnalyze +meta def runAnalyze (pythonCmd : System.FilePath) (tmpDir : System.FilePath) (scriptName : String) : IO (Except String Core.Program) := do let testIon ← compileTestScript pythonCmd (testDir / scriptName) tmpDir + let pctx ← quietCtx let laurel ← - match ← Strata.pythonAndSpecToLaurel testIon.toString + match ← (Strata.pythonAndSpecToLaurel testIon.toString (dispatchModules := #["servicelib"]) - (specDir := tmpDir) |>.toBaseIO with + (specDir := tmpDir)).run pctx |>.toBaseIO with | .ok r => pure r - | .error err => return .error (toString err) + | .error () => + -- Flag tool errors, then user errors, then general + if let some r := (← pctx.getToolErrors).back? then + return .error <| r.message + if let some r := (← pctx.getUserCodeErrors).back? then + return .error <| s!"User code error: {r.message}" + if let some m := (←pctx.getMessages).back? then + return .error m.message + return .error "Pipeline aborted for unspecified reason (bug)" match ← Strata.translateCombinedLaurel laurel with | (some core, []) => -- Also run Core type checking to catch semantic errors (e.g. Heap vs Any) @@ -100,18 +113,22 @@ private meta def runAnalyze /-- Run pyAnalyzeLaurel with inlining and verification. When `useRoots` is true, entry points are determined via the call graph (the CLI `--entry-point roots` default); otherwise only `__main__` is used. -/ -private meta def runAnalyzeAndVerify +meta def runAnalyzeAndVerify (pythonCmd : System.FilePath) (tmpDir : System.FilePath) (scriptName : String) (useRoots : Bool := false) : IO (Except String (Array Core.VCResult)) := do let testIon ← compileTestScript pythonCmd (testDir / scriptName) tmpDir + let pctx ← quietCtx let laurel ← - match ← Strata.pythonAndSpecToLaurel testIon.toString + match ← (Strata.pythonAndSpecToLaurel testIon.toString (dispatchModules := #["servicelib"]) - (specDir := tmpDir) |>.toBaseIO with + (specDir := tmpDir)).run pctx |>.toBaseIO with | .ok r => pure r - | .error err => return .error (toString err) + | .error () => + let msgs ← pctx.getMessages + let detail := match msgs.back? with | some m => m.message | none => "Pipeline aborted" + return .error detail let (coreProgramOption, _) ← Strata.translateCombinedLaurel laurel let coreProgram ← match coreProgramOption with | none => return .error "Laurel to Core translation failed" @@ -144,13 +161,13 @@ private meta def runAnalyzeAndVerify | .error msg => return .error (toString msg) /-- Expected outcome for a test case. -/ -private inductive Expected where +inductive Expected where | success | fail (msg : String) | failPrefix (pfx : String) /-- All dispatch test cases: (filename, expected outcome). -/ -private meta def testCases : List (String × Expected) := [ +meta def testCases : List (String × Expected) := [ -- Positive tests .mk "test_single_service.py" .success, .mk "test_multi_service.py" .success, @@ -206,7 +223,7 @@ private meta def testCases : List (String × Expected) := [ ] /-- Run a single test case and return an error message on failure, or `none` on success. -/ -private meta def runTestCase (pythonCmd : System.FilePath) (tmpDir : System.FilePath) +meta def runTestCase (pythonCmd : System.FilePath) (tmpDir : System.FilePath) (scriptName : String) (expected : Expected) : IO (Option String) := do let result ← runAnalyze pythonCmd tmpDir scriptName match expected, result with @@ -246,13 +263,17 @@ private meta def runTestCase (pythonCmd : System.FilePath) (tmpDir : System.File -- causes a type unification error in Core.typeCheck, which is expected. let task ← IO.asTask do let testIon ← compileTestScript pythonCmd (testDir / "test_class_any_as_composite.py") tmpDir + let pctx ← quietCtx let laurel ← - match ← Strata.pythonAndSpecToLaurel testIon.toString + match ← (Strata.pythonAndSpecToLaurel testIon.toString (dispatchModules := #["servicelib"]) (pyspecModules := #["servicelib.Storage"]) - (specDir := tmpDir) |>.toBaseIO with + (specDir := tmpDir)).run pctx |>.toBaseIO with | .ok r => pure r - | .error err => return some s!"test_class_any_as_composite.py: {err}" + | .error () => + let msgs ← pctx.getMessages + let detail := match msgs.back? with | some m => m.message | none => "Pipeline aborted" + return some s!"test_class_any_as_composite.py: {detail}" match ← Strata.translateCombinedLaurel laurel with | (some core, []) => match Core.typeCheck Core.VerifyOptions.quiet core (moreFns := Strata.Python.ReFactory) with @@ -372,12 +393,16 @@ recursively translates subclasses, so the type setupFixture pythonCmd tmpDir let testIon ← compileTestScript pythonCmd (testDir / "test_resolution_after_filter.py") tmpDir + let pctx ← quietCtx let combined ← - match ← Strata.pythonAndSpecToLaurel testIon.toString + match ← (Strata.pythonAndSpecToLaurel testIon.toString (dispatchModules := #["servicelib"]) - (specDir := tmpDir) |>.toBaseIO with + (specDir := tmpDir)).run pctx |>.toBaseIO with | .ok r => pure r - | .error err => throw <| IO.userError s!"pyAnalyzeLaurel failed: {err}" + | .error () => + let msgs ← pctx.getMessages + let detail := match msgs.back? with | some m => m.message | none => "Pipeline aborted" + throw <| IO.userError s!"pyAnalyzeLaurel failed: {detail}" let result := Laurel.resolve combined unless result.errors.isEmpty do let msgs := result.errors.toList.map (·.message) diff --git a/StrataTestExtra/Languages/Python/DictNoneTest.lean b/StrataTestExtra/Languages/Python/DictNoneTest.lean index 6ed9baaede..2dbe7a73c6 100644 --- a/StrataTestExtra/Languages/Python/DictNoneTest.lean +++ b/StrataTestExtra/Languages/Python/DictNoneTest.lean @@ -16,7 +16,7 @@ is correctly detected as a bug, both for direct assignments and dict unpacking. namespace Strata.Python.DictNoneTest -open Strata.Python (processPythonFile withPython containsSubstr) +open Strata.Python (processPythonFile withPython) open Strata.Parser (stringInputContext) -- Test 1: Using a valid int should succeed (0 diagnostics). @@ -32,7 +32,7 @@ open Strata.Parser (stringInputContext) throw <| .userError s!"Expected 0 diagnostics, got {diags.size}: {diags.map (·.message)}" private def isAssertionFailure (msg : String) : Bool := - containsSubstr msg "does not hold" || containsSubstr msg "could not be proved" + msg.contains "does not hold" || msg.contains "could not be proved" -- Test 2: Assigning None to an int variable with a value-dependent assertion. #guard_msgs in @@ -110,7 +110,7 @@ def main() -> None: match ← (processPythonFile pythonCmd (stringInputContext "test.py" program)).toBaseIO with | .ok _ => throw <| IO.userError "Expected error for len() on class without __len__" | .error err => - unless containsSubstr (toString err) "len() is not supported" do + unless (toString err).contains "len() is not supported" do throw <| IO.userError s!"Unexpected error: {err}" end Strata.Python.DictNoneTest diff --git a/StrataTestExtra/Languages/Python/Specs/IdentifyOverloadsTest.lean b/StrataTestExtra/Languages/Python/Specs/IdentifyOverloadsTest.lean index 6da52f61a7..604641c209 100644 --- a/StrataTestExtra/Languages/Python/Specs/IdentifyOverloadsTest.lean +++ b/StrataTestExtra/Languages/Python/Specs/IdentifyOverloadsTest.lean @@ -22,6 +22,7 @@ fewer. namespace Strata.Python.Specs.IdentifyOverloadsTest open Strata (readDispatchOverloads pySpecsDir pySpecOutputPath) +open Strata.Python (ModuleName) open Strata.Python.Specs.IdentifyOverloads (resolveOverloads) open Strata.Python (OverloadTable) @@ -75,10 +76,11 @@ private meta def buildOverloadTable throw <| .userError s!"pySpecsDir failed for {pyFile}: {msg}" let some ionPath := pySpecOutputPath testDir outDir pyFile | throw <| .userError s!"Cannot derive output path for {pyFile}" - match ← readDispatchOverloads #[ionPath.toString] |>.toBaseIO with - | .ok (tbl, _) => return tbl - | .error msg => - throw <| .userError s!"readDispatchOverloads failed: {msg}" + let ctx ← Strata.Pipeline.PipelineContext.create + match ← (readDispatchOverloads ctx #[ionPath.toString]).toBaseIO with + | .ok tbl => return tbl + | .error () => + throw <| .userError s!"readDispatchOverloads failed for {ionPath}" /-- Parse a user Python Ion file into statements. -/ private meta def parseStmts (ionPath : System.FilePath) @@ -94,32 +96,32 @@ private meta def resolveFile (pythonCmd : System.FilePath) (tbl : OverloadTable) (pyFile : System.FilePath) (outDir : System.FilePath) - : IO (Std.HashSet String) := do + : IO (Std.HashSet ModuleName) := do let ionPath ← compilePython pythonCmd pyFile outDir let stmts ← parseStmts ionPath return (resolveOverloads tbl stmts).modules /-- A test case: Python file and exact expected module set. -/ private structure TestCase where - file : String - expected : List String + file : System.FilePath + expected : List ModuleName private meta def testCases : List TestCase := [ -- Single service at top level { file := "test_single_service.py" - expected := ["servicelib.Storage"] }, + expected := [.ofString! "servicelib.Storage"] }, -- Multiple services { file := "test_multi_service.py" - expected := ["servicelib.Storage", "servicelib.Messaging"] }, + expected := [.ofString! "servicelib.Storage", .ofString! "servicelib.Messaging"] }, -- Dispatch inside a class method { file := "test_class_dispatch.py" - expected := ["servicelib.Storage"] }, + expected := [.ofString! "servicelib.Storage"] }, -- Dispatch in both branches of an if/else { file := "test_dispatch_in_conditional.py" - expected := ["servicelib.Storage", "servicelib.Messaging"] }, + expected := [.ofString! "servicelib.Storage", .ofString! "servicelib.Messaging"] }, -- Dispatch inside a try block { file := "test_dispatch_in_try.py" - expected := ["servicelib.Storage"] }, + expected := [.ofString! "servicelib.Storage"] }, -- No dispatch calls at all { file := "test_no_dispatch.py" expected := [] }, @@ -134,11 +136,11 @@ private meta def runTestCase (tbl : OverloadTable) (outDir : System.FilePath) (tc : TestCase) : IO (Option String) := do let modules ← resolveFile pythonCmd tbl (testDir / tc.file) outDir - let expected : Std.HashSet String := + let expected : Std.HashSet ModuleName := tc.expected.foldl (init := {}) fun s m => s.insert m if modules == expected then return none - let got := modules.toList - let exp := expected.toList + let got := modules.toList.map toString + let exp := expected.toList.map toString return some s!"{tc.file}: expected modules {exp}, got {got}" @@ -146,8 +148,8 @@ private meta def runTestCase IO.FS.withTempDir fun tmpDir => do let tbl ← buildOverloadTable pythonCmd tmpDir -- Launch all tests concurrently - let mut seen : Std.HashSet String := {} - let mut tasks : Array (String × Task (Except IO.Error (Option String))) := #[] + let mut seen : Std.HashSet System.FilePath := {} + let mut tasks : Array (System.FilePath × Task (Except IO.Error (Option String))) := #[] for tc in testCases do if tc.file ∈ seen then throw <| IO.userError s!"Duplicate test filename: {tc.file}" diff --git a/StrataTestExtra/Languages/Python/Specs/RelativeImportTest.lean b/StrataTestExtra/Languages/Python/Specs/RelativeImportTest.lean index 0c08b50acc..198d38f8b8 100644 --- a/StrataTestExtra/Languages/Python/Specs/RelativeImportTest.lean +++ b/StrataTestExtra/Languages/Python/Specs/RelativeImportTest.lean @@ -20,8 +20,8 @@ actionable error messages. namespace Strata.Python.Specs.RelativeImportTest -open Strata.Python.Specs (translateFile ModuleName) -open Strata.Python (containsSubstr) +open Strata.Python.Specs (translateFile) +open Strata.Python (ModuleName) private meta def testDir : System.FilePath := "StrataTestExtra/Languages/Python/Specs/import_test" @@ -38,26 +38,16 @@ private meta def runTest (pythonCmd : System.FilePath) (dialectFile : System.Fil let pythonFile := testDir / file let searchPath := if searchFromTestDir then testDir else pythonFile.parent.getD pythonFile - -- When searching from testDir, derive a multi-component module name - -- from the relative path so that currentModulePrefix is non-empty - -- (e.g. "service/rel_import_basic.py" → "service.rel_import_basic"). - let moduleName ← if searchFromTestDir then - let stem := if (file : String).endsWith "/__init__.py" then - ((file : String).dropEnd "/__init__.py".length).toString - else - ((file : String).dropEnd ".py".length).toString - let dotted := stem.replace "/" "." - match ModuleName.ofString dotted with - | .ok m => pure (some m) - | .error e => return some s!"{file}: bad module name: {e}" - else pure none + let moduleName ← match ModuleName.ofRelativePath file with + | .ok info => pure info.moduleName + | .error msg => return some s!"{file}: {msg}" let r ← translateFile (pythonCmd := toString pythonCmd) (dialectFile := dialectFile) (strataDir := strataDir) (pythonFile := pythonFile) (searchPath := searchPath) - (moduleName := moduleName) + moduleName |>.toBaseIO if expectedErrors.isEmpty then match r with @@ -68,7 +58,7 @@ private meta def runTest (pythonCmd : System.FilePath) (dialectFile : System.Fil | .ok _ => return some s!"{file}: expected error but translation succeeded" | .error msg => for expected in expectedErrors do - if !containsSubstr msg expected then + if !msg.contains expected then return some s!"{file}: error missing expected substring \"{expected}\"\nActual error:\n{msg}" return none diff --git a/StrataTestExtra/Languages/Python/SpecsTest.lean b/StrataTestExtra/Languages/Python/SpecsTest.lean index ff67b38e92..1bc0e66a03 100644 --- a/StrataTestExtra/Languages/Python/SpecsTest.lean +++ b/StrataTestExtra/Languages/Python/SpecsTest.lean @@ -5,7 +5,7 @@ -/ module meta import Strata.Languages.Python.Specs -meta import Strata.Languages.Python.Specs.DDM +meta import all Strata.Languages.Python.Specs.DDM import Strata.DDM.Ion import Strata.Languages.Python.PythonDialect meta import StrataTest.Util.Python @@ -226,6 +226,7 @@ meta def testCase : IO Unit := withPython fun pythonCmd => do (strataDir := strataDir) (pythonFile := testDir / "main.py") (searchPath := testDir) + (.ofComponent (.ofString "main")) |>.toBaseIO match r with | .ok (sigs, warnings) => @@ -264,6 +265,7 @@ meta def warningTestCase : IO Unit := withPython fun pythonCmd => do (strataDir := strataDir) (pythonFile := testDir / "warnings.py") (searchPath := testDir) + (.ofComponent (.ofString "warnings")) |>.toBaseIO match r with | .ok (sigs, warnings) => @@ -282,7 +284,7 @@ meta def warningTestCase : IO Unit := withPython fun pythonCmd => do "skipped Expr in function body" -- kw["a"] (bare expression) ] for expected in expectedWarnings do - if !warnings.any (containsSubstr · expected) then + if !warnings.any (·.contains expected) then let warnStr := warnings.foldl (init := "") fun acc w => s!"{acc}\n {w}" throw <| IO.userError s!"Missing expected warning containing \"{expected}\". Actual warnings:{warnStr}" @@ -294,7 +296,7 @@ meta def warningTestCase : IO Unit := withPython fun pythonCmd => do meta def testNegRoundTrip (v : Nat) : Bool := - DDM.Int.ofDDM (.negInt SourceRange.none ⟨.none, v⟩) = .negOfNat v + DDM.Int.ofDDM (.negInt SourceRange.none ⟨.none, v⟩) = Int.negOfNat v #guard testNegRoundTrip 0 #guard testNegRoundTrip 1 diff --git a/StrataTestExtra/Languages/Python/VerifyPythonTest.lean b/StrataTestExtra/Languages/Python/VerifyPythonTest.lean index a8343c61a1..2e23151c6f 100644 --- a/StrataTestExtra/Languages/Python/VerifyPythonTest.lean +++ b/StrataTestExtra/Languages/Python/VerifyPythonTest.lean @@ -17,7 +17,7 @@ Python → Laurel → Core → SMT pipeline and produces diagnostics. namespace Strata.Python.VerifyPythonTest open StrataTest.Util -open Strata.Python (processPythonFile processPythonToLaurel withPython containsSubstr manglePythonMethod) +open Strata.Python (processPythonFile processPythonToLaurel withPython manglePythonMethod) open Strata.Parser (stringInputContext) /-- Run the Python → Laurel pipeline and return the Laurel program together @@ -190,7 +190,7 @@ def main() -> None: throw <| IO.userError "Expected pipeline error for too many positional arguments" catch e => let msg := toString e - unless containsSubstr msg "too many positional arguments" do + unless msg.contains "too many positional arguments" do throw <| IO.userError s!"Expected 'too many positional arguments' error, got: {msg}" -- Extra positional args with **kwargs expansion should also error. @@ -209,7 +209,7 @@ def main() -> None: throw <| IO.userError "Expected pipeline error for too many positional arguments" catch e => let msg := toString e - unless containsSubstr msg "too many positional arguments" do + unless msg.contains "too many positional arguments" do throw <| IO.userError s!"Expected 'too many positional arguments' error, got: {msg}" -- Returning a Composite-typed value from a function with Any return type @@ -341,7 +341,7 @@ def main() -> None: let (laurel, output) ← toLaurel pythonCmd program let calcAdd := manglePythonMethod "Calculator" "add" assertOpaque laurel calcAdd - unless containsSubstr output s!"{calcAdd}(" do + unless output.contains s!"{calcAdd}(" do throw <| IO.userError s!"Expected '{calcAdd}(' in Laurel output but not found" -- self.field.method() resolution and composite field initialization: @@ -372,10 +372,10 @@ def main() -> None: let (_, output) ← toLaurel pythonCmd program let innerValidate := manglePythonMethod "Inner" "validate" -- self.inner.validate() must resolve to Inner@validate StaticCall - unless containsSubstr output s!"{innerValidate}(" do + unless output.contains s!"{innerValidate}(" do throw <| IO.userError s!"Expected '{innerValidate}(' in Laurel output but not found" -- Composite field assignment (self.inner: Inner = ...) uses New initialization - unless containsSubstr output "new Inner" do + unless output.contains "new Inner" do throw <| IO.userError s!"Expected 'new Inner' in Laurel output but not found" -- Inheritance guard: when a class is part of an inheritance hierarchy, @@ -412,7 +412,7 @@ def main() -> None: | none => throw <| .userError "main procedure not found" | some proc => let mainOutput := toString (Laurel.formatProcedure proc) - if containsSubstr mainOutput s!"{baseValue}(" then + if mainOutput.contains s!"{baseValue}(" then throw <| IO.userError s!"main should NOT call {baseValue} (inheritance guard)" -- Inheritance with field type conflict: B inherits A and redeclares field x @@ -478,7 +478,7 @@ def main() -> None: | none => throw <| .userError "main procedure not found" | some proc => let mainOutput := toString (Laurel.formatProcedure proc) - if containsSubstr mainOutput s!"{aF}(" then + if mainOutput.contains s!"{aF}(" then throw <| IO.userError s!"main should NOT call {aF} (inheritance dispatch unsound)" -- Cross-class method dispatch: a method in one class calls a method on @@ -508,10 +508,10 @@ def main() -> None: let engineGetHp := manglePythonMethod "Engine" "get_hp" let carHorsepower := manglePythonMethod "Car" "horsepower" -- self.engine.get_hp() should resolve to Engine@get_hp StaticCall - unless containsSubstr output s!"{engineGetHp}(" do + unless output.contains s!"{engineGetHp}(" do throw <| IO.userError s!"Expected '{engineGetHp}(' in Laurel output but not found" -- Car@horsepower should also be a StaticCall from main - unless containsSubstr output s!"{carHorsepower}(" do + unless output.contains s!"{carHorsepower}(" do throw <| IO.userError s!"Expected '{carHorsepower}(' in Laurel output but not found" -- Full pipeline: composite field assignment goes through the entire diff --git a/Tools/BoogieToStrata/IntegrationTests/BoogieToStrataIntegrationTests.cs b/Tools/BoogieToStrata/IntegrationTests/BoogieToStrataIntegrationTests.cs index 05e16b9b9f..2f921c041b 100644 --- a/Tools/BoogieToStrata/IntegrationTests/BoogieToStrataIntegrationTests.cs +++ b/Tools/BoogieToStrata/IntegrationTests/BoogieToStrataIntegrationTests.cs @@ -51,7 +51,28 @@ public static IEnumerable GetBoogieTestFiles() { } } + /// + /// Returns true if the first 5 lines of contain + /// the literal token "{:smack}". Files carrying this marker opt into the + /// --smack CLI flag, which gates the assert_. synthetic-requires + /// injection and InferModifies=true. + /// + private static bool HasSmackMarker(string filePath) { + if (!File.Exists(filePath)) return false; + using var reader = new StreamReader(filePath); + for (var i = 0; i < 5; i++) { + var line = reader.ReadLine(); + if (line == null) break; + if (line.Contains("{:smack}", StringComparison.Ordinal)) return true; + } + return false; + } + private (int, string, string) RunTranslation(string filePath) { + return RunTranslation(filePath, HasSmackMarker(filePath)); + } + + private (int, string, string) RunTranslation(string filePath, bool smack) { // Capture console output using var consoleOutput = new StringWriter(); using var consoleError = new StringWriter(); @@ -62,7 +83,8 @@ public static IEnumerable GetBoogieTestFiles() { try { Console.SetOut(consoleOutput); Console.SetError(consoleError); - exitCode = BoogieToStrata.Main([filePath]); + var args = smack ? new[] { "--smack", filePath } : new[] { filePath }; + exitCode = BoogieToStrata.Main(args); } catch (Exception) { exitCode = 1; } finally { @@ -149,6 +171,182 @@ public void VerifyTestFile(string fileName, string filePath) { Assert.Equal(expectedExitCode, proc.ExitCode); } + /// + /// Regression test: assert_. procedures must produce a single + /// merged spec block, not duplicates, regardless of whether the input + /// already has user-written specs. Two cases: + /// 1. existing ensures only — synthetic requires merges in. + /// 2. existing requires — synthetic requires is added alongside, + /// not silently dropped. + /// + [Fact] + public void SmackAssertProducesSingleMergedSpecBlock() { + var filePath = Path.Combine(TestsDirectory, "SmackAssertDuplicateSpec.bpl"); + Assert.True(File.Exists(filePath), $"Test file does not exist: {filePath}"); + + var (exitCode, standardOutput, errorOutput) = RunTranslation(filePath); + Assert.Equal(0, exitCode); + + // Three procedures (assert_.i32, assert_.i32_with_req, main); the two + // assert_. ones each produce one spec block; main has none. + // Count occurrences of "spec {" in the output: must be exactly 2. + var specCount = 0; + var searchFrom = 0; + while (true) { + var idx = standardOutput.IndexOf("spec {", searchFrom, StringComparison.Ordinal); + if (idx < 0) break; + specCount++; + searchFrom = idx + 1; + } + + output.WriteLine($"Output:\n{standardOutput}"); + Assert.Equal(2, specCount); + + // Overall sanity: the output contains at least one of each clause kind. + // (assert_.i32 has an `ensures`, both procedures have at least one + // `requires`.) + Assert.Contains("requires", standardOutput); + Assert.Contains("ensures", standardOutput); + + // The critical regression check: BOTH clauses must appear in the + // SECOND procedure's spec block (assert_.i32_with_req) — i.e., the + // requires-already-present case must not silently drop the synthetic + // clause. Procedures emit in source order, so the second `spec {` is + // assert_.i32_with_req's. + var firstSpec = standardOutput.IndexOf("spec {", StringComparison.Ordinal); + Assert.True(firstSpec >= 0, "Expected at least one spec block"); + var secondSpecStart = standardOutput.IndexOf("spec {", firstSpec + 1, StringComparison.Ordinal); + Assert.True(secondSpecStart >= 0, "Expected a second spec block for assert_.i32_with_req"); + var secondSpecEnd = standardOutput.IndexOf("}", secondSpecStart, StringComparison.Ordinal); + Assert.True(secondSpecEnd > secondSpecStart, "Second spec block missing closing brace"); + var secondSpec = standardOutput.Substring(secondSpecStart, secondSpecEnd - secondSpecStart); + + // The user-written `requires (p.0 > -1)` (sanitized to `p_0 > -(1)`) + // and the synthetic `requires (p.0 != 0)` (sanitized to `p_0 != 0`) + // must both be present in this single spec block. + Assert.Contains("p_0 > -(1)", secondSpec); + Assert.Contains("p_0 != 0", secondSpec); + } + + /// + /// Regression test: old() expressions must use the renamed name when a + /// variable has a name collision (e.g., global var `main` vs procedure `main`). + /// Previously, IdentifierExpr with Decl != null used NameOf() correctly, but + /// had a silent fallback to Name() when Decl was null — which would emit the + /// unrenamed (wrong) name in the collision case. Post-resolution, Decl should + /// always be non-null; the fallback masked potential bugs. + /// + [Fact] + public void OldExprUsesRenamedNameOnCollision() { + var filePath = Path.Combine(TestsDirectory, "OldExprRenameCollision.bpl"); + Assert.True(File.Exists(filePath), $"Test file does not exist: {filePath}"); + + var (exitCode, standardOutput, errorOutput) = RunTranslation(filePath); + + output.WriteLine($"Output:\n{standardOutput}"); + if (!string.IsNullOrEmpty(errorOutput)) { + output.WriteLine($"Error output: {errorOutput}"); + } + + Assert.Equal(0, exitCode); + + // The output should contain an `old` expression referencing the *renamed* + // variable (e.g., __var_main), not the raw name `main` which is the procedure. + // Look for the pattern "old __var_main" or similar renamed form. + Assert.Contains("old __var_main", standardOutput); + + // Also ensure the output does NOT contain "old main" (unrenamed fallback). + // This would indicate the fallback to Name() was used instead of NameOf(). + Assert.DoesNotContain("old main", standardOutput); + } + + /// + /// Regression test for InferModifies = true. + /// + /// SMACK-generated Boogie can omit explicit modifies clauses on procedures. + /// With InferModifies = true, Boogie's ModSetCollector infers them so that + /// the translator correctly emits globals as `inout` parameters (modified) + /// rather than read-only parameters. + /// + /// This test uses a .bpl file where procedure p() assigns to global g but + /// has no `modifies g;` clause. If InferModifies is working, the output + /// should contain `inout g` for procedure p. + /// + [Fact] + public void InferModifiesEmitsInoutForMutatedGlobal() { + var filePath = Path.Combine(TestsDirectory, "InferModifiesGlobal.bpl"); + Assert.True(File.Exists(filePath), $"Test file does not exist: {filePath}"); + + var (exitCode, standardOutput, errorOutput) = RunTranslation(filePath); + + output.WriteLine($"Output:\n{standardOutput}"); + if (!string.IsNullOrEmpty(errorOutput)) { + output.WriteLine($"Error output: {errorOutput}"); + } + + // Translation must succeed — if InferModifies is broken, Boogie would + // reject the program because g is assigned without a modifies clause. + Assert.Equal(0, exitCode); + + // The inferred modifies clause should cause the translator to emit + // `inout g` on procedure p's parameter list. + Assert.Contains("inout g", standardOutput); + } + + /// + /// Pin down the --smack gate: without the flag, the assert_. + /// pattern is treated as an opaque procedure (no synthetic requires + /// injected). Translation succeeds; the output does not contain a + /// requires clause for the assert_ procedure. + /// + [Fact] + public void SmackAssertWithoutFlagDoesNotInjectRequires() { + var filePath = Path.Combine(TestsDirectory, "SmackAssert.bpl"); + Assert.True(File.Exists(filePath), $"Test file does not exist: {filePath}"); + + var (exitCode, standardOutput, errorOutput) = RunTranslation(filePath, smack: false); + + output.WriteLine($"Output:\n{standardOutput}"); + if (!string.IsNullOrEmpty(errorOutput)) { + output.WriteLine($"Error output: {errorOutput}"); + } + + Assert.Equal(0, exitCode); + + // Without --smack, no synthetic requires is added, so no `requires` + // clause should appear anywhere in the translation of this file + // (the .bpl has no user-written requires either). + Assert.DoesNotContain("requires", standardOutput); + } + + /// + /// Pin down the --smack gate: without the flag, InferModifies is off. + /// A program that omits an explicit `modifies` clause on a procedure + /// that mutates a global is rejected at typecheck. + /// + [Fact] + public void InferModifiesOffWithoutSmackFlag() { + var filePath = Path.Combine(TestsDirectory, "InferModifiesGlobal.bpl"); + Assert.True(File.Exists(filePath), $"Test file does not exist: {filePath}"); + + var (exitCode, standardOutput, errorOutput) = RunTranslation(filePath, smack: false); + + output.WriteLine($"Exit code: {exitCode}"); + output.WriteLine($"Output:\n{standardOutput}"); + if (!string.IsNullOrEmpty(errorOutput)) { + output.WriteLine($"Error output: {errorOutput}"); + } + + // Without --smack, ResolveAndTypecheck rejects the program because + // procedure p mutates global g without an explicit `modifies g;` + // clause. BoogieToStrata.Main writes a "Failed to typecheck" line + // to stderr and returns exit code 1. Pin both signals so a future + // regression that fails for an unrelated reason (parse error, + // arg-handling change) doesn't silently pass this test. + Assert.Equal(1, exitCode); + Assert.Contains("Failed to typecheck", errorOutput); + } + [Fact] public void ErrorCodeWithNoArguments() { var result = BoogieToStrata.Main(Array.Empty()); diff --git a/Tools/BoogieToStrata/Source/BoogieToStrata.cs b/Tools/BoogieToStrata/Source/BoogieToStrata.cs index 1816f6a1fc..de97643552 100644 --- a/Tools/BoogieToStrata/Source/BoogieToStrata.cs +++ b/Tools/BoogieToStrata/Source/BoogieToStrata.cs @@ -3,22 +3,59 @@ namespace BoogieToStrata; public static class BoogieToStrata { + private const string Usage = "Usage: BoogieToStrata [--smack] "; + + private static bool _smack; + private static void PrintResolvedProgram(ExecutionEngineOptions options, ProcessedProgram prog) { var writer = new TokenTextWriter(Console.Out, options); - StrataGenerator.EmitProgramAsStrata(options, prog.Program, writer); + StrataGenerator.EmitProgramAsStrata(options, prog.Program, writer, _smack); + } + + /// + /// Parse args into (smack, filename). Returns false on any malformed + /// invocation (zero or two-plus positional args, unknown flags); the + /// caller should print Usage and return exit code 1. + /// + private static bool TryParseArgs(string[] args, out bool smack, out string filename) { + smack = false; + filename = ""; + string? positional = null; + foreach (var arg in args) { + if (arg == "--smack") { + smack = true; + } else if (arg.StartsWith("--")) { + return false; // unknown flag + } else if (positional == null) { + positional = arg; + } else { + return false; // two positional args + } + } + if (positional == null) return false; // no positional arg + filename = positional; + return true; } public static int Main(string[] args) { - if (args.Length != 1) { - Console.Error.WriteLine("Usage: BoogieToStrata "); + if (!TryParseArgs(args, out var smack, out var filename)) { + Console.Error.WriteLine(Usage); return 1; } - - var filename = args[0]; + _smack = smack; var options = new CommandLineOptions(Console.Out, new ConsolePrinter()) { Verify = false, - TypeEncodingMethod = CoreOptions.TypeEncoding.Predicates + TypeEncodingMethod = CoreOptions.TypeEncoding.Predicates, + // Under --smack, SMACK-generated Boogie often omits explicit + // `modifies` clauses on procedures that mutate globals. + // InferModifies runs ModSetCollector.CollectModifies to populate + // empty modifies clauses and suppresses modifies-clause + // typechecking (via CheckModifies), so that ResolveAndTypecheck + // does not reject SMACK programs missing modifies clauses. + // For strict Boogie input (no --smack), this stays false and + // missing modifies clauses are reported as typecheck errors. + InferModifies = smack }; var boogieEngine = ExecutionEngine.CreateWithoutSharedCache(options); @@ -44,4 +81,4 @@ public static int Main(string[] args) { return 0; } -} \ No newline at end of file +} diff --git a/Tools/BoogieToStrata/Source/StrataGenerator.cs b/Tools/BoogieToStrata/Source/StrataGenerator.cs index 368e6e7b5f..ad269fc77a 100644 --- a/Tools/BoogieToStrata/Source/StrataGenerator.cs +++ b/Tools/BoogieToStrata/Source/StrataGenerator.cs @@ -52,15 +52,33 @@ public class StrataGenerator : ReadOnlyVisitor { // Global variables collected from the program, used to convert them // into inout/input parameters on procedure headers and call sites. private List _globalVariables = []; - - private StrataGenerator(VCGenOptions options, TokenTextWriter writer, Program program) { + // Renames for declarations whose sanitized name collides with another + // declaration. Keyed by Boogie Declaration object to avoid ambiguity + // when two entities share the same original name (e.g., const main and + // procedure main). First-seen wins; later entities get prefixed. + // + // Registration order determines who wins a collision: + // 1. Procedures — registered first, always keep their name. + // 2. Implementations — claimed defensively (they share names with + // their procedures, but claiming guards against edge cases). + // 3. Constants, Functions, Globals — registered last; in a + // proc-vs-const collision the constant is always renamed. + private readonly Dictionary _renames = new(); + // True when the input is SMACK-generated Boogie. Gates SMACK-specific + // accommodations (synthetic `requires (p != 0)` on assert_. procedures). + // The companion `InferModifies = true` knob is set on the Boogie options + // by the BoogieToStrata.Main entrypoint, also gated on this flag. + private readonly bool _smack; + + private StrataGenerator(VCGenOptions options, TokenTextWriter writer, Program program, bool smack) { _options = options; _writer = writer; _program = program; + _smack = smack; } - public static void EmitProgramAsStrata(VCGenOptions options, Program p, TokenTextWriter writer) { - var generator = new StrataGenerator(options, writer, p); + public static void EmitProgramAsStrata(VCGenOptions options, Program p, TokenTextWriter writer, bool smack) { + var generator = new StrataGenerator(options, writer, p, smack); var fieldTypeCollector = new FieldTypeCollector(); fieldTypeCollector.Visit(p); @@ -74,6 +92,31 @@ public static void EmitProgramAsStrata(VCGenOptions options, Program p, TokenTex generator.FindSpecialTypes(); + // Build rename map for declarations with colliding sanitized names. + // Two kinds of collision are handled: + // 1. Cross-namespace: constant vs procedure sharing the same name + // 2. Sanitization: distinct names that map to the same string + // (e.g., $add.i32 and $add_i32 both become _add_i32) + // First-seen wins; colliding entities get a suffix (_2, _3, ...). + var claimed = new HashSet(); + + foreach (var proc in p.Procedures) + ClaimOrRename(proc, proc.Name, "__proc_", claimed, generator._renames); + // Defensive: implementations share names with their corresponding + // procedures, so they would normally never collide. Claiming them + // here guards against edge cases (e.g., an implementation whose + // procedure was pruned or renamed upstream). + foreach (var impl in p.Implementations) { + var sanitized = SanitizeNameForStrata(impl.Name); + claimed.Add(sanitized); + } + foreach (var c in liveDeclarations.OfType()) + ClaimOrRename(c, c.TypedIdent.Name, "__const_", claimed, generator._renames); + foreach (var f in liveDeclarations.OfType()) + ClaimOrRename(f, f.Name, "__func_", claimed, generator._renames); + foreach (var g in p.GlobalVariables) + ClaimOrRename(g, g.Name, "__var_", claimed, generator._renames); + var typeConstructors = p.TopLevelDeclarations.OfType().ToList(); if (typeConstructors.Count != 0) { generator.WriteLine("// Type constructors"); @@ -205,6 +248,29 @@ private static string SanitizeNameForStrata(string name) { .Replace("$", "_"); } + /// + /// Claim a sanitized name for , or rename it if the + /// name is already taken. The first declaration to claim a name wins; + /// subsequent colliders get a prefixed (and possibly suffixed) name recorded + /// in . + /// + private static void ClaimOrRename( + Declaration decl, + string originalName, + string prefix, + HashSet claimed, + Dictionary renames) { + var sanitized = SanitizeNameForStrata(originalName); + if (claimed.Add(sanitized)) return; + var candidate = $"{prefix}{sanitized}"; + if (!claimed.Add(candidate)) { + var i = 2; + while (!claimed.Add($"{candidate}_{i}")) i++; + candidate = $"{candidate}_{i}"; + } + renames[decl] = candidate; + } + private void AddUniqueConst(Type t, string name) { if (!_uniqueConstants.TryGetValue(t, out var value)) { value = new HashSet(); @@ -299,6 +365,12 @@ private string Name(string name) { return SanitizeNameForStrata(name); } + private string NameOf(Declaration decl, string originalName) { + if (_renames.TryGetValue(decl, out var renamed)) + return renamed; + return SanitizeNameForStrata(originalName); + } + private void WriteLine(string text) { _writer.WriteLine(text); } @@ -310,7 +382,10 @@ private void EmitOldExpr(Expr expr) { switch (expr) { case IdentifierExpr identExpr: WriteText("old "); - WriteText(Name(identExpr.Name)); + if (identExpr.Decl == null) + throw new StrataConversionException(identExpr.tok, + $"IdentifierExpr '{identExpr.Name}' has null Decl (expected non-null post-resolution)"); + WriteText(NameOf(identExpr.Decl, identExpr.Name)); break; case NAryExpr { Fun: MapSelect } mapSelect: WriteText("("); @@ -546,7 +621,10 @@ public override Expr VisitExpr(Expr node) { case LiteralExpr literalExpr: throw new StrataConversionException(node.tok, $"Unsupported literal type: {literalExpr}"); case IdentifierExpr identifierExpr: - WriteText(Name(identifierExpr.Name)); + if (identifierExpr.Decl == null) + throw new StrataConversionException(identifierExpr.tok, + $"IdentifierExpr '{identifierExpr.Name}' has null Decl (expected non-null post-resolution)"); + WriteText(NameOf(identifierExpr.Decl, identifierExpr.Name)); break; case NAryExpr nAryExpr: { var fun = nAryExpr.Fun; @@ -634,7 +712,7 @@ public override Expr VisitExpr(Expr node) { break; case FunctionCall functionCall: { - WriteText($"{Name(functionCall.FunctionName)}("); + WriteText($"{NameOf(functionCall.Func, functionCall.FunctionName)}("); EmitSeparated(args, e => VisitExpr(e), ", "); WriteText(")"); break; @@ -918,7 +996,7 @@ public override GotoCmd VisitGotoCmd(GotoCmd node) { private void EmitSimpleAssign(SimpleAssignLhs lhs, Expr rhs) { Indent(); - WriteText($"{Name(lhs.AssignedVariable.Name)} := "); + WriteText($"{NameOf(lhs.AssignedVariable.Decl, lhs.AssignedVariable.Name)} := "); VisitExpr(rhs); WriteLine(";"); } @@ -953,17 +1031,17 @@ public override Cmd VisitCallCmd(CallCmd node) { var modifiesNames = new HashSet(callee.Modifies.Select(m => m.Name)); Indent("call "); - WriteText($"{Name(callee.Name)}("); + WriteText($"{NameOf(callee, callee.Name)}("); // Emit: inout globals, then read-only globals, then original args, then out outputs. var needComma = false; foreach (var g in _globalVariables.Where(g => modifiesNames.Contains(g.Name))) { if (needComma) WriteText(", "); - WriteText($"inout {Name(g.Name)}"); + WriteText($"inout {NameOf(g, g.Name)}"); needComma = true; } foreach (var g in _globalVariables.Where(g => !modifiesNames.Contains(g.Name))) { if (needComma) WriteText(", "); - WriteText(Name(g.Name)); + WriteText(NameOf(g, g.Name)); needComma = true; } foreach (var arg in node.Ins) { @@ -983,7 +1061,7 @@ public override Cmd VisitCallCmd(CallCmd node) { public override Cmd VisitHavocCmd(HavocCmd node) { foreach (var x in node.Vars) { - IndentLine($"havoc {Name(x.Name)};"); + IndentLine($"havoc {NameOf(x.Decl, x.Name)};"); } // All assumptions come after all havocs! This allows where clauses @@ -1510,7 +1588,7 @@ public override Block VisitBlock(Block node) { public override Constant VisitConstant(Constant node) { var ti = node.TypedIdent; - var name = Name(ti.Name); + var name = NameOf(node, ti.Name); WriteText($"const {name} : "); VisitType(ti.Type); if (node.Unique) { @@ -1523,7 +1601,7 @@ public override Constant VisitConstant(Constant node) { public override GlobalVariable VisitGlobalVariable(GlobalVariable node) { var ti = node.TypedIdent; - WriteText($"var {Name(ti.Name)} : "); + WriteText($"var {NameOf(node, ti.Name)} : "); VisitType(ti.Type); WriteLine(";"); return node; @@ -1681,7 +1759,7 @@ private void MaybeEmitBuiltinBody(Function function) { } public override Function VisitFunction(Function node) { - WriteText($"function {Name(node.Name)}"); + WriteText($"function {NameOf(node, node.Name)}"); EmitTypeParameters(node.TypeParameters); WriteText("("); WriteFormals(node.InParams); @@ -1723,7 +1801,7 @@ private void WriteProcedureHeader(Procedure proc) { var modifiesGlobals = _globalVariables.Where(g => modifiesNames.Contains(g.Name)).ToList(); var readOnlyGlobals = _globalVariables.Where(g => !modifiesNames.Contains(g.Name)).ToList(); - WriteText($"procedure {Name(proc.Name)}"); + WriteText($"procedure {NameOf(proc, proc.Name)}"); EmitTypeParameters(proc.TypeParameters); WriteText("("); // Emit: inout globals, then read-only globals, then original inputs, then out outputs. @@ -1772,9 +1850,36 @@ private void WriteProcedureHeader(Procedure proc) { public override Procedure VisitProcedure(Procedure node) { if (!_program.Implementations.Any(i => i.Name.Equals(node.Name))) { - WriteProcedureHeader(node); - WriteLine(";"); - WriteLine(); + // Under --smack, SMACK encodes C assert(expr) as a call to + // assert_.*(cond). Inject a synthetic requires precondition so the + // call-elimination pass generates a VC checking the condition is + // non-zero. We add it to node.Requires so WriteProcedureHeader + // emits it inside a single spec block alongside any existing + // specs. The injection always fires when the name pattern matches + // (no Requires.Count == 0 guard) — if the procedure already has a + // hand-written requires, both clauses appear in the merged spec + // block, preserving the SMACK invariant unconditionally. + Requires? syntheticReq = null; + if (_smack && node.Name.StartsWith("assert_.") && node.InParams.Count > 0) { + var param = node.InParams[0]; + var paramExpr = new IdentifierExpr(param.tok, param); + var zero = new LiteralExpr(param.tok, Microsoft.BaseTypes.BigNum.FromInt(0)); + var neqExpr = Expr.Neq(paramExpr, zero); + syntheticReq = new Requires(false, neqExpr); + node.Requires.Add(syntheticReq); + } + + try { + WriteProcedureHeader(node); + WriteLine(";"); + WriteLine(); + } finally { + // Remove the synthetic requires to avoid mutating the shared + // AST, even if WriteProcedureHeader threw. + if (syntheticReq != null) { + node.Requires.Remove(syntheticReq); + } + } } return node; @@ -1787,7 +1892,7 @@ private void WriteFormals(IEnumerable variables, ref bool needComma, if (needComma) WriteText(", "); var name = v.TypedIdent.Name ?? ""; if (name == "") name = $"x{n++}"; - WriteText($"{prefix}{Name(name)} : "); + WriteText($"{prefix}{NameOf(v, name)} : "); VisitType(v.TypedIdent.Type); needComma = true; } diff --git a/Tools/BoogieToStrata/Tests/AssertPrefixFalsePositive.bpl b/Tools/BoogieToStrata/Tests/AssertPrefixFalsePositive.bpl new file mode 100644 index 0000000000..50bdf5579f --- /dev/null +++ b/Tools/BoogieToStrata/Tests/AssertPrefixFalsePositive.bpl @@ -0,0 +1,12 @@ +// {:smack} +// Regression test: procedures starting with "assert_" but NOT matching +// SMACK's assert_.TYPE pattern should NOT get a synthetic requires. +// Only assert_. (literal dot) is the SMACK pattern. + +procedure assert_helper(p: int) returns (r: int); + +procedure main() returns (r: int) +{ + // assert_helper is a normal procedure, passing 0 should be fine + call r := assert_helper(0); +} diff --git a/Tools/BoogieToStrata/Tests/AssertPrefixFalsePositive.expect b/Tools/BoogieToStrata/Tests/AssertPrefixFalsePositive.expect new file mode 100644 index 0000000000..e9e8e85d9e --- /dev/null +++ b/Tools/BoogieToStrata/Tests/AssertPrefixFalsePositive.expect @@ -0,0 +1,2 @@ +Successfully parsed. +All 0 goals passed. diff --git a/Tools/BoogieToStrata/Tests/GlobalVarRenameCollision.bpl b/Tools/BoogieToStrata/Tests/GlobalVarRenameCollision.bpl new file mode 100644 index 0000000000..5a1275cbff --- /dev/null +++ b/Tools/BoogieToStrata/Tests/GlobalVarRenameCollision.bpl @@ -0,0 +1,9 @@ +const $a.b: int; axiom $a.b > 0; +var $a_b: int; // both sanitize to _a_b +procedure main() returns (r: int) + modifies $a_b; +{ + $a_b := 1; + havoc $a_b; + r := $a.b + $a_b; +} diff --git a/Tools/BoogieToStrata/Tests/GlobalVarRenameCollision.expect b/Tools/BoogieToStrata/Tests/GlobalVarRenameCollision.expect new file mode 100644 index 0000000000..e9e8e85d9e --- /dev/null +++ b/Tools/BoogieToStrata/Tests/GlobalVarRenameCollision.expect @@ -0,0 +1,2 @@ +Successfully parsed. +All 0 goals passed. diff --git a/Tools/BoogieToStrata/Tests/InferModifiesGlobal.bpl b/Tools/BoogieToStrata/Tests/InferModifiesGlobal.bpl new file mode 100644 index 0000000000..0767f22518 --- /dev/null +++ b/Tools/BoogieToStrata/Tests/InferModifiesGlobal.bpl @@ -0,0 +1,17 @@ +// {:smack} +// Regression test for InferModifies = true. +// +// This procedure mutates the global variable `g` but has NO explicit +// `modifies g;` clause. With InferModifies = true, Boogie's +// ModSetCollector should infer the modifies clause so that the +// BoogieToStrata translator emits `inout g` for procedure p. +// If InferModifies is ever disabled or broken, this file will fail +// to translate correctly (g would be treated as read-only instead +// of inout). + +var g: int; + +procedure p() +{ + g := 1; +} diff --git a/Tools/BoogieToStrata/Tests/InferModifiesGlobal.expect b/Tools/BoogieToStrata/Tests/InferModifiesGlobal.expect new file mode 100644 index 0000000000..e9e8e85d9e --- /dev/null +++ b/Tools/BoogieToStrata/Tests/InferModifiesGlobal.expect @@ -0,0 +1,2 @@ +Successfully parsed. +All 0 goals passed. diff --git a/Tools/BoogieToStrata/Tests/NamespaceCollision.bpl b/Tools/BoogieToStrata/Tests/NamespaceCollision.bpl new file mode 100644 index 0000000000..caedc2c08a --- /dev/null +++ b/Tools/BoogieToStrata/Tests/NamespaceCollision.bpl @@ -0,0 +1,25 @@ +// Minimal reproduction: namespace collision bug in BoogieToStrata. +// Boogie allows a constant and procedure to share the same name +// because they live in separate namespaces. BoogieToStrata emits +// both into Strata Core's single namespace, causing: +// "a declaration of this name already exists" + +type ref = int; + +const main: ref; +axiom (main == 1000); + +var x: int; + +procedure main() + modifies x; +{ + var y: int; + x := 42; + assert x == 42; + // Use the constant in an expression that doesn't require the axiom to verify + y := 0; + if (main == 1000) { y := 1; } + // This assertion is trivially true regardless of the axiom + assert y == 0 || y == 1; +} diff --git a/Tools/BoogieToStrata/Tests/NamespaceCollision.expect b/Tools/BoogieToStrata/Tests/NamespaceCollision.expect new file mode 100644 index 0000000000..c2c458ba8c --- /dev/null +++ b/Tools/BoogieToStrata/Tests/NamespaceCollision.expect @@ -0,0 +1,4 @@ +Successfully parsed. +NamespaceCollision.core.st(23, 4) [assert_0]: ✅ pass +NamespaceCollision.core.st(33, 4) [assert_1]: ✅ pass +All 2 goals passed. diff --git a/Tools/BoogieToStrata/Tests/OldExprRenameCollision.bpl b/Tools/BoogieToStrata/Tests/OldExprRenameCollision.bpl new file mode 100644 index 0000000000..7494acd002 --- /dev/null +++ b/Tools/BoogieToStrata/Tests/OldExprRenameCollision.bpl @@ -0,0 +1,17 @@ +// Regression test: old() expression with a renamed (colliding) variable. +// +// The global variable `main` collides with procedure `main` after +// sanitization (cross-namespace collision). The variable must be +// renamed when emitted. An old() expression referencing that variable +// must use the *renamed* name, not the raw sanitized name. +// If the code silently falls back to Name() instead of NameOf(), the +// old() expression will reference the wrong (procedure) name. + +var main: int; + +procedure main() + modifies main; + ensures main == old(main) + 1; +{ + main := main + 1; +} diff --git a/Tools/BoogieToStrata/Tests/OldExprRenameCollision.expect b/Tools/BoogieToStrata/Tests/OldExprRenameCollision.expect new file mode 100644 index 0000000000..98a3632dbf --- /dev/null +++ b/Tools/BoogieToStrata/Tests/OldExprRenameCollision.expect @@ -0,0 +1,3 @@ +Successfully parsed. +OldExprRenameCollision.core.st(10, 2) [main_ensures_0]: ✅ pass +All 1 goals passed. diff --git a/Tools/BoogieToStrata/Tests/SanitizationCollision.bpl b/Tools/BoogieToStrata/Tests/SanitizationCollision.bpl new file mode 100644 index 0000000000..7a4f339a01 --- /dev/null +++ b/Tools/BoogieToStrata/Tests/SanitizationCollision.bpl @@ -0,0 +1,25 @@ +// Test case: sanitization collision in BoogieToStrata. +// +// SanitizeNameForStrata replaces '.', '$', '@', '#', '^' with '_'. +// This means distinct Boogie identifiers can map to the same Strata name: +// $add.i32 → _add_i32 +// $add_i32 → _add_i32 (collision!) +// +// The rename mechanism should detect this and disambiguate. + +type i32 = int; + +function {:inline} $add.i32(i1: i32, i2: i32) returns (i32) { i1 + i2 } +function {:inline} $add_i32(i1: i32, i2: i32) returns (i32) { i1 + i2 } + +procedure main() returns (r: i32) +ensures r == 5; +{ + var a: i32; + var b: i32; + a := $add.i32(2, 3); + b := $add_i32(2, 3); + assert a == 5; + assert b == 5; + r := a; +} diff --git a/Tools/BoogieToStrata/Tests/SanitizationCollision.expect b/Tools/BoogieToStrata/Tests/SanitizationCollision.expect new file mode 100644 index 0000000000..d7e99638ef --- /dev/null +++ b/Tools/BoogieToStrata/Tests/SanitizationCollision.expect @@ -0,0 +1,5 @@ +Successfully parsed. +SanitizationCollision.core.st(29, 4) [assert_0]: ✅ pass +SanitizationCollision.core.st(30, 4) [assert_1]: ✅ pass +SanitizationCollision.core.st(21, 2) [main_ensures_0]: ✅ pass +All 3 goals passed. diff --git a/Tools/BoogieToStrata/Tests/SmackAssert.bpl b/Tools/BoogieToStrata/Tests/SmackAssert.bpl new file mode 100644 index 0000000000..5094a86704 --- /dev/null +++ b/Tools/BoogieToStrata/Tests/SmackAssert.bpl @@ -0,0 +1,18 @@ +// {:smack} +// Minimal test case for SMACK assert_ pattern recognition. +// SMACK encodes C assert(expr) as a call to assert_.i32(cond). +// BoogieToStrata should recognize this pattern and emit: +// assert (cond != 0); +// instead of an opaque procedure call. + +type i32 = int; + +procedure assert_.i32(p.0: i32) returns ($r: i32); + +procedure main() returns ($r: i32) +{ + // assert(false) — should fail verification + call $r := assert_.i32(0); + $r := 0; + return; +} diff --git a/Tools/BoogieToStrata/Tests/SmackAssert.expect b/Tools/BoogieToStrata/Tests/SmackAssert.expect new file mode 100644 index 0000000000..8af84c0bbc --- /dev/null +++ b/Tools/BoogieToStrata/Tests/SmackAssert.expect @@ -0,0 +1,3 @@ +Successfully parsed. +SmackAssert.core.st(21, 6) [callElimAssert_assert__i32_requires_0_2]: ❌ fail +Finished with 0 goals passed, 1 failed. diff --git a/Tools/BoogieToStrata/Tests/SmackAssertDuplicateSpec.bpl b/Tools/BoogieToStrata/Tests/SmackAssertDuplicateSpec.bpl new file mode 100644 index 0000000000..3f3d49bce3 --- /dev/null +++ b/Tools/BoogieToStrata/Tests/SmackAssertDuplicateSpec.bpl @@ -0,0 +1,30 @@ +// {:smack} +// Regression test: assert_. procedures must produce a single +// merged spec block, not duplicates, regardless of whether the input +// already has user-written specs. +// +// Two procedures exercise the two cases: +// 1. assert_.i32 has only an existing `ensures` — output must merge +// the synthetic `requires (p.0 != 0)` with the existing ensures +// into one spec block. +// 2. assert_.i32_with_req has an existing `requires (p.0 > -1)` — +// output must contain BOTH requires clauses (the synthetic one +// and the user-written one) in a single spec block, not drop the +// synthetic one. + +type i32 = int; + +procedure assert_.i32(p.0: i32) returns ($r: i32); + ensures ($r == 0); + +procedure assert_.i32_with_req(p.0: i32) returns ($r: i32); + requires (p.0 > -1); + +procedure main() returns ($r: i32) +{ + // assert(true) -- should pass because p.0 != 0 holds for 1 + call $r := assert_.i32(1); + // call assert_.i32_with_req(1) — both `1 != 0` and `1 > -1` hold + call $r := assert_.i32_with_req(1); + return; +} diff --git a/Tools/BoogieToStrata/Tests/TypeSynonymChain.bpl b/Tools/BoogieToStrata/Tests/TypeSynonymChain.bpl new file mode 100644 index 0000000000..cc5811b580 --- /dev/null +++ b/Tools/BoogieToStrata/Tests/TypeSynonymChain.bpl @@ -0,0 +1,20 @@ +// Regression test for multi-level type synonym resolution. +// dealiasTypeExpr must recurse through: ref → i64 → int +// Without recursive resolution, comparison and arithmetic on `ref` +// trigger a panic because the type stays as a synonym instead of +// resolving to the base `int` type. + +type i64 = int; +type ref = i64; + +procedure main() returns (r: ref) +ensures r >= 0; +{ + var a: ref; + var b: ref; + a := 3; + b := a + 4; + assert b == 7; + assert a <= b; + r := b; +} diff --git a/Tools/BoogieToStrata/Tests/TypeSynonymChain.expect b/Tools/BoogieToStrata/Tests/TypeSynonymChain.expect new file mode 100644 index 0000000000..197fd783cb --- /dev/null +++ b/Tools/BoogieToStrata/Tests/TypeSynonymChain.expect @@ -0,0 +1,5 @@ +Successfully parsed. +TypeSynonymChain.core.st(22, 4) [assert_0]: ✅ pass +TypeSynonymChain.core.st(23, 4) [assert_1]: ✅ pass +TypeSynonymChain.core.st(14, 2) [main_ensures_0]: ✅ pass +All 3 goals passed. diff --git a/Tools/Python/strata/base.py b/Tools/Python/strata/base.py index 54a22f729b..deb773984b 100644 --- a/Tools/Python/strata/base.py +++ b/Tools/Python/strata/base.py @@ -1048,6 +1048,7 @@ def __init__(self, name: str): self.name = name self.imports = [] self.decls = [] + self.typecheck = True def add_import(self, name: str): self.imports.append(name) @@ -1111,6 +1112,12 @@ def to_ion(self): d.add_item("type", _importSym) d.add_item("name", i) r.append(d) + if not self.typecheck: + d = ion.IonPyDict() + d.add_item("type", ion_symbol("option")) + d.add_item("name", "typecheck") + d.add_item("value", "off") + r.append(d) for d in self.decls: r.append(d.to_ion()) return r @@ -1158,6 +1165,14 @@ def from_ion(fp) -> 'Dialect': assert field == "name", f"Unexpected field {field}" read_struct_end(reader) dialect.add_import(value) + case "option": + (field, opt_name) = read_field_string(reader) + assert field == "name", f"Unexpected field {field}" + (field, opt_value) = read_field_string(reader) + assert field == "value", f"Unexpected field {field}" + read_struct_end(reader) + if opt_name == "typecheck": + dialect.typecheck = (opt_value == "on") case "syncat": read_syncatdecl(reader, dialect) case "op": diff --git a/docs/verso/DDMDoc.lean b/docs/verso/DDMDoc.lean index 14c6fe8a6b..55ba6da48d 100644 --- a/docs/verso/DDMDoc.lean +++ b/docs/verso/DDMDoc.lean @@ -191,6 +191,37 @@ declared. This includes transitive imports of the dialect being imported. Imports the dialect _ident_. ::: +## Dialect Options + +Dialect options configure elaboration behavior for programs using the dialect. + +:::paragraph +`dialect_option` _name_ _value_`;` + +Sets the dialect option _name_ to _value_. +::: + +The following options are supported: + +- `typecheck` (`on` | `off`, default `on`): When set to `off`, the elaborator + skips type inference and unification for expression arguments. Implicit type + parameter slots are filled with anonymous type placeholders. Variable name + resolution and global context population still operate normally. + + The flag is a property of the program's primary dialect; imported dialects' + flags are not consulted during elaboration. + + This is intended for cases where the type checker cannot infer implicit type + arguments — notably when template-generated accessors with unresolved type + variable return types are composed with polymorphic functions that require + concrete type arguments for unification. + + With `typecheck off`, type errors in a program are not detected at + elaboration time. They will surface at later pipeline stages (VC generation, + symbolic evaluation, SMT encoding) with less-helpful diagnostics. Only use + this option when the elaboration error is a known type-checker limitation + rather than a real type mismatch. + ## Syntactic Categories Syntactic categories are introduced by the `category` declaration: diff --git a/editors/emacs/core-st-mode.el b/editors/emacs/core-st-mode.el index 6bcfb271d4..f2cffc48e1 100644 --- a/editors/emacs/core-st-mode.el +++ b/editors/emacs/core-st-mode.el @@ -22,12 +22,16 @@ '( "div" "mod" "sdiv" "smod" "safesdiv" "safesmod")) (defvar core-st-builtins - '( "Sequence.length" "Sequence.select" "Sequence.append" - "Sequence.build" "Sequence.update" "Sequence.contains" - "Sequence.take" "Sequence.drop" "str.len" "str.concat" "str.substr" - "str.to.re" "str.in.re" "str.prefixof" "str.suffixof" "re.allchar" - "re.all" "re.range" "re.concat" "re.*" "re.+" "re.loop" "re.union" - "re.inter" "re.comp" "re.none" "Int.DivT" "Int.ModT")) + '( "Sequence.empty" "Sequence.length" "Sequence.select" + "Sequence.append" "Sequence.build" "Sequence.update" + "Sequence.contains" "Sequence.take" "Sequence.drop" "str.len" + "str.concat" "str.substr" "str.to.re" "str.in.re" "str.prefixof" + "str.suffixof" "re.allchar" "re.all" "re.range" "re.concat" "re.*" + "re.+" "re.loop" "re.union" "re.inter" "re.comp" "re.none" + "Int.DivT" "Int.ModT" "Bv.SNegOverflow" "Bv.UNegOverflow" + "Bv.SAddOverflow" "Bv.SSubOverflow" "Bv.SMulOverflow" + "Bv.SDivOverflow" "Bv.UAddOverflow" "Bv.USubOverflow" + "Bv.UMulOverflow")) ;; Font-lock rules (defvar core-st-font-lock-keywords diff --git a/editors/vscode/syntaxes/core-st.tmLanguage.json b/editors/vscode/syntaxes/core-st.tmLanguage.json index 44e4208209..8a1dd9e289 100644 --- a/editors/vscode/syntaxes/core-st.tmLanguage.json +++ b/editors/vscode/syntaxes/core-st.tmLanguage.json @@ -84,7 +84,7 @@ ] }, "function-call": { - "match": "\\b(Sequence\\.length|Sequence\\.select|Sequence\\.append|Sequence\\.build|Sequence\\.update|Sequence\\.contains|Sequence\\.take|Sequence\\.drop|str\\.len|str\\.concat|str\\.substr|str\\.to\\.re|str\\.in\\.re|str\\.prefixof|str\\.suffixof|re\\.allchar|re\\.all|re\\.range|re\\.concat|re\\.\\*|re\\.\\+|re\\.loop|re\\.union|re\\.inter|re\\.comp|re\\.none|Int\\.DivT|Int\\.ModT|bvconcat\\{[0-9]+\\}\\{[0-9]+\\}|bvextract\\{[0-9]+\\}\\{[0-9]+\\}\\{[0-9]+\\})\\b", + "match": "\\b(Sequence\\.empty|Sequence\\.length|Sequence\\.select|Sequence\\.append|Sequence\\.build|Sequence\\.update|Sequence\\.contains|Sequence\\.take|Sequence\\.drop|str\\.len|str\\.concat|str\\.substr|str\\.to\\.re|str\\.in\\.re|str\\.prefixof|str\\.suffixof|re\\.allchar|re\\.all|re\\.range|re\\.concat|re\\.\\*|re\\.\\+|re\\.loop|re\\.union|re\\.inter|re\\.comp|re\\.none|Int\\.DivT|Int\\.ModT|Bv\\.SNegOverflow|Bv\\.UNegOverflow|Bv\\.SAddOverflow|Bv\\.SSubOverflow|Bv\\.SMulOverflow|Bv\\.SDivOverflow|Bv\\.UAddOverflow|Bv\\.USubOverflow|Bv\\.UMulOverflow|bvconcat\\{[0-9]+\\}\\{[0-9]+\\}|bvextract\\{[0-9]+\\}\\{[0-9]+\\}\\{[0-9]+\\})\\b", "captures": { "1": { "name": "support.function.builtin.core-st" } } diff --git a/lakefile.toml b/lakefile.toml index 5110d270ac..b1b4d644de 100644 --- a/lakefile.toml +++ b/lakefile.toml @@ -12,6 +12,9 @@ rev = "bump_to_v4.29.0-rc8" [[lean_lib]] name = "Strata" +[[lean_lib]] +name = "StrataMainLib" + [[lean_exe]] name = "strata" root = "StrataMain"