diff --git a/Cargo.lock b/Cargo.lock index 2d88e5a73..692d213ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -930,7 +930,6 @@ dependencies = [ "cargo_metadata 0.19.2", "ceno_emul", "ceno_host", - "ceno_recursion", "ceno_zkvm", "clap", "console", @@ -1121,48 +1120,6 @@ dependencies = [ "tiny-keccak", ] -[[package]] -name = "ceno_recursion" -version = "0.1.0" -dependencies = [ - "bincode 1.3.3", - "ceno-examples", - "ceno_emul", - "ceno_host", - "ceno_zkvm", - "clap", - "ff_ext", - "gkr_iop", - "itertools 0.13.0", - "mpcs", - "multilinear_extensions", - "openvm", - "openvm-circuit", - "openvm-continuations", - "openvm-cuda-backend", - "openvm-instructions", - "openvm-native-circuit", - "openvm-native-compiler", - "openvm-native-compiler-derive", - "openvm-native-recursion", - "openvm-rv32im-circuit", - "openvm-sdk", - "openvm-stark-backend", - "openvm-stark-sdk", - "p3", - "parse-size", - "rand 0.8.5", - "serde", - "serde_json", - "sumcheck", - "tracing", - "tracing-forest", - "tracing-subscriber", - "transcript", - "whir", - "witness", -] - [[package]] name = "ceno_rt" version = "0.1.0" @@ -2235,7 +2192,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" dependencies = [ "once_cell", "p3", @@ -2411,6 +2368,7 @@ dependencies = [ "multilinear_extensions", "once_cell", "p3", + "p3-field 0.4.1", "rand 0.8.5", "rayon", "serde", @@ -3040,7 +2998,7 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" dependencies = [ - "spin", + "spin 0.9.8", ] [[package]] @@ -3240,7 +3198,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" dependencies = [ "bincode 1.3.3", "clap", @@ -3264,7 +3222,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" dependencies = [ "either", "ff_ext", @@ -3755,7 +3713,6 @@ dependencies = [ "itertools 0.14.0", "libc", "memmap2", - "metrics", "openvm-circuit-derive", "openvm-circuit-primitives", "openvm-circuit-primitives-derive", @@ -3766,8 +3723,8 @@ dependencies = [ "openvm-poseidon2-air", "openvm-stark-backend", "openvm-stark-sdk", - "p3-baby-bear", - "p3-field", + "p3-baby-bear 0.1.0", + "p3-field 0.1.0", "rand 0.8.5", "rustc-hash", "serde", @@ -3847,15 +3804,15 @@ dependencies = [ "openvm-cuda-common", "openvm-stark-backend", "openvm-stark-sdk", - "p3-baby-bear", - "p3-commit", - "p3-dft", - "p3-field", - "p3-fri", - "p3-matrix", - "p3-merkle-tree", - "p3-symmetric", - "p3-util", + "p3-baby-bear 0.1.0", + "p3-commit 0.1.0", + "p3-dft 0.1.0", + "p3-field 0.1.0", + "p3-fri 0.1.0", + "p3-matrix 0.1.0", + "p3-merkle-tree 0.1.0", + "p3-symmetric 0.1.0", + "p3-util 0.1.0", "rustc-hash", "serde", "serde_json", @@ -4091,7 +4048,7 @@ dependencies = [ "openvm-rv32im-transpiler", "openvm-stark-backend", "openvm-stark-sdk", - "p3-field", + "p3-field 0.1.0", "rand 0.8.5", "serde", "static_assertions", @@ -4105,7 +4062,6 @@ source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fv1.4.1-scr dependencies = [ "backtrace", "itertools 0.14.0", - "metrics", "num-bigint 0.4.6", "num-integer", "openvm-circuit", @@ -4138,17 +4094,16 @@ dependencies = [ "cfg-if", "itertools 0.14.0", "lazy_static", - "metrics", "openvm-circuit", "openvm-native-circuit", "openvm-native-compiler", "openvm-native-compiler-derive", "openvm-stark-backend", "openvm-stark-sdk", - "p3-dft", - "p3-fri", - "p3-merkle-tree", - "p3-symmetric", + "p3-dft 0.1.0", + "p3-fri 0.1.0", + "p3-merkle-tree 0.1.0", + "p3-symmetric 0.1.0", "rand 0.8.5", "serde", "serde_json", @@ -4162,7 +4117,7 @@ source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fv1.4.1-scr dependencies = [ "openvm-instructions", "openvm-transpiler", - "p3-field", + "p3-field 0.1.0", ] [[package]] @@ -4248,10 +4203,10 @@ dependencies = [ "openvm-cuda-builder", "openvm-stark-backend", "openvm-stark-sdk", - "p3-monty-31", - "p3-poseidon2", - "p3-poseidon2-air", - "p3-symmetric", + "p3-monty-31 0.1.0", + "p3-poseidon2 0.1.0", + "p3-poseidon2-air 0.1.0", + "p3-symmetric 0.1.0", "rand 0.8.5", "zkhash", ] @@ -4302,7 +4257,7 @@ version = "1.4.1" source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fv1.4.1-scroll-ext#ef22e8ecb9965091783d2c0369b8379e7f683f53" dependencies = [ "openvm-custom-insn", - "p3-field", + "p3-field 0.1.0", "strum_macros", ] @@ -4364,7 +4319,7 @@ dependencies = [ "openvm-stark-backend", "openvm-stark-sdk", "openvm-transpiler", - "p3-fri", + "p3-fri 0.1.0", "rand 0.8.5", "rrs-lib", "serde", @@ -4443,16 +4398,14 @@ dependencies = [ "derive-new 0.7.0", "eyre", "itertools 0.14.0", - "metrics", - "p3-air", - "p3-challenger", - "p3-commit", - "p3-field", - "p3-matrix", - "p3-maybe-rayon", + "p3-air 0.1.0", + "p3-challenger 0.1.0", + "p3-commit 0.1.0", + "p3-field 0.1.0", + "p3-matrix 0.1.0", + "p3-maybe-rayon 0.1.0", "p3-uni-stark", - "p3-util", - "rayon", + "p3-util 0.1.0", "rustc-hash", "serde", "serde_json", @@ -4474,18 +4427,18 @@ dependencies = [ "metrics-tracing-context", "metrics-util", "openvm-stark-backend", - "p3-baby-bear", + "p3-baby-bear 0.1.0", "p3-blake3", "p3-bn254-fr", - "p3-dft", - "p3-fri", - "p3-goldilocks", + "p3-dft 0.1.0", + "p3-fri 0.1.0", + "p3-goldilocks 0.1.0", "p3-keccak", "p3-koala-bear", - "p3-merkle-tree", - "p3-poseidon", - "p3-poseidon2", - "p3-symmetric", + "p3-merkle-tree 0.1.0", + "p3-poseidon 0.1.0", + "p3-poseidon2 0.1.0", + "p3-symmetric 0.1.0", "rand 0.8.5", "serde", "serde_json", @@ -4508,7 +4461,6 @@ dependencies = [ "openvm-platform", "openvm-stark-backend", "rrs-lib", - "rustc-demangle", "thiserror 1.0.69", ] @@ -4555,26 +4507,26 @@ dependencies = [ [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" -dependencies = [ - "p3-air", - "p3-baby-bear", - "p3-challenger", - "p3-commit", - "p3-dft", - "p3-field", - "p3-fri", - "p3-goldilocks", - "p3-matrix", - "p3-maybe-rayon", - "p3-mds", - "p3-merkle-tree", - "p3-monty-31", - "p3-poseidon", - "p3-poseidon2", - "p3-poseidon2-air", - "p3-symmetric", - "p3-util", +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" +dependencies = [ + "p3-air 0.4.1", + "p3-baby-bear 0.4.1", + "p3-challenger 0.4.1", + "p3-commit 0.4.1", + "p3-dft 0.4.1", + "p3-field 0.4.1", + "p3-fri 0.4.1", + "p3-goldilocks 0.4.1", + "p3-matrix 0.4.1", + "p3-maybe-rayon 0.4.1", + "p3-mds 0.4.1", + "p3-merkle-tree 0.4.1", + "p3-monty-31 0.4.1", + "p3-poseidon 0.4.1", + "p3-poseidon2 0.4.1", + "p3-poseidon2-air 0.4.1", + "p3-symmetric 0.4.1", + "p3-util 0.4.1", ] [[package]] @@ -4582,8 +4534,18 @@ name = "p3-air" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ - "p3-field", - "p3-matrix", + "p3-field 0.1.0", + "p3-matrix 0.1.0", +] + +[[package]] +name = "p3-air" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60414dc4fe4b8676bd4b6136b309185e6b3c006eb5564ef4cf5dfae6d9d47f32" +dependencies = [ + "p3-field 0.4.1", + "p3-matrix 0.4.1", ] [[package]] @@ -4591,23 +4553,38 @@ name = "p3-baby-bear" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ - "p3-field", - "p3-mds", - "p3-monty-31", - "p3-poseidon2", - "p3-symmetric", + "p3-field 0.1.0", + "p3-mds 0.1.0", + "p3-monty-31 0.1.0", + "p3-poseidon2 0.1.0", + "p3-symmetric 0.1.0", "rand 0.8.5", "serde", ] +[[package]] +name = "p3-baby-bear" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f2fecd03416a20949dc7cd4b481c37d744c4d398467f94213c65279a0f00048" +dependencies = [ + "p3-challenger 0.4.1", + "p3-field 0.4.1", + "p3-mds 0.4.1", + "p3-monty-31 0.4.1", + "p3-poseidon2 0.4.1", + "p3-symmetric 0.4.1", + "rand 0.9.2", +] + [[package]] name = "p3-blake3" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "blake3", - "p3-symmetric", - "p3-util", + "p3-symmetric 0.1.0", + "p3-util 0.1.0", ] [[package]] @@ -4618,9 +4595,9 @@ dependencies = [ "ff 0.13.1", "halo2curves", "num-bigint 0.4.6", - "p3-field", - "p3-poseidon2", - "p3-symmetric", + "p3-field 0.1.0", + "p3-poseidon2 0.1.0", + "p3-symmetric 0.1.0", "rand 0.8.5", "serde", ] @@ -4630,10 +4607,24 @@ name = "p3-challenger" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ - "p3-field", - "p3-maybe-rayon", - "p3-symmetric", - "p3-util", + "p3-field 0.1.0", + "p3-maybe-rayon 0.1.0", + "p3-symmetric 0.1.0", + "p3-util 0.1.0", + "tracing", +] + +[[package]] +name = "p3-challenger" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8a66da8af6115b9e2df4363cd55efebf2c6d30de0af3e99dac56dd7b77aff24" +dependencies = [ + "p3-field 0.4.1", + "p3-maybe-rayon 0.4.1", + "p3-monty-31 0.4.1", + "p3-symmetric 0.4.1", + "p3-util 0.4.1", "tracing", ] @@ -4643,11 +4634,26 @@ version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", - "p3-challenger", - "p3-dft", - "p3-field", - "p3-matrix", - "p3-util", + "p3-challenger 0.1.0", + "p3-dft 0.1.0", + "p3-field 0.1.0", + "p3-matrix 0.1.0", + "p3-util 0.1.0", + "serde", +] + +[[package]] +name = "p3-commit" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95104feb4b9895733f92204ec70ba8944dbab39c39b235c0a00adf1456149619" +dependencies = [ + "itertools 0.14.0", + "p3-challenger 0.4.1", + "p3-dft 0.4.1", + "p3-field 0.4.1", + "p3-matrix 0.4.1", + "p3-util 0.4.1", "serde", ] @@ -4657,10 +4663,25 @@ version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", - "p3-field", - "p3-matrix", - "p3-maybe-rayon", - "p3-util", + "p3-field 0.1.0", + "p3-matrix 0.1.0", + "p3-maybe-rayon 0.1.0", + "p3-util 0.1.0", + "tracing", +] + +[[package]] +name = "p3-dft" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81b2f57569293b9964b1bae68d64e796bfbf3c271718268beb53a0fb761a5819" +dependencies = [ + "itertools 0.14.0", + "p3-field 0.4.1", + "p3-matrix 0.4.1", + "p3-maybe-rayon 0.4.1", + "p3-util 0.4.1", + "spin 0.10.0", "tracing", ] @@ -4674,58 +4695,126 @@ dependencies = [ "num-integer", "num-traits", "nums", - "p3-maybe-rayon", - "p3-util", + "p3-maybe-rayon 0.1.0", + "p3-util 0.1.0", "rand 0.8.5", "serde", "tracing", ] +[[package]] +name = "p3-field" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56aae7630ff6df83fb7421d5bd97df27620e5f0e29422b7e8f6a294d44cce297" +dependencies = [ + "itertools 0.14.0", + "num-bigint 0.4.6", + "p3-maybe-rayon 0.4.1", + "p3-util 0.4.1", + "paste", + "rand 0.9.2", + "serde", + "tracing", +] + [[package]] name = "p3-fri" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", - "p3-challenger", - "p3-commit", - "p3-dft", - "p3-field", - "p3-interpolation", - "p3-matrix", - "p3-maybe-rayon", - "p3-util", + "p3-challenger 0.1.0", + "p3-commit 0.1.0", + "p3-dft 0.1.0", + "p3-field 0.1.0", + "p3-interpolation 0.1.0", + "p3-matrix 0.1.0", + "p3-maybe-rayon 0.1.0", + "p3-util 0.1.0", "rand 0.8.5", "serde", "tracing", ] +[[package]] +name = "p3-fri" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0e9a7053c439444f5c4be80ecc08b255a046bc5d23762abe8a4460ae0fca583" +dependencies = [ + "itertools 0.14.0", + "p3-challenger 0.4.1", + "p3-commit 0.4.1", + "p3-dft 0.4.1", + "p3-field 0.4.1", + "p3-interpolation 0.4.1", + "p3-matrix 0.4.1", + "p3-maybe-rayon 0.4.1", + "p3-util 0.4.1", + "rand 0.9.2", + "serde", + "thiserror 2.0.12", + "tracing", +] + [[package]] name = "p3-goldilocks" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "num-bigint 0.4.6", - "p3-dft", - "p3-field", - "p3-mds", - "p3-poseidon", - "p3-poseidon2", - "p3-symmetric", - "p3-util", + "p3-dft 0.1.0", + "p3-field 0.1.0", + "p3-mds 0.1.0", + "p3-poseidon 0.1.0", + "p3-poseidon2 0.1.0", + "p3-symmetric 0.1.0", + "p3-util 0.1.0", "rand 0.8.5", "serde", ] +[[package]] +name = "p3-goldilocks" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85324dc45db4196ce0083971393124f5ed03741507f9165d5c923c97890b4838" +dependencies = [ + "num-bigint 0.4.6", + "p3-challenger 0.4.1", + "p3-dft 0.4.1", + "p3-field 0.4.1", + "p3-mds 0.4.1", + "p3-poseidon2 0.4.1", + "p3-symmetric 0.4.1", + "p3-util 0.4.1", + "paste", + "rand 0.9.2", + "serde", +] + [[package]] name = "p3-interpolation" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ - "p3-field", - "p3-matrix", - "p3-maybe-rayon", - "p3-util", + "p3-field 0.1.0", + "p3-matrix 0.1.0", + "p3-maybe-rayon 0.1.0", + "p3-util 0.1.0", +] + +[[package]] +name = "p3-interpolation" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b0bb6a709b26cead74e7c605f4e51e793642870e54a7c280a05cd66b7914866" +dependencies = [ + "p3-field 0.4.1", + "p3-matrix 0.4.1", + "p3-maybe-rayon 0.4.1", + "p3-util 0.4.1", ] [[package]] @@ -4734,9 +4823,9 @@ version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", - "p3-field", - "p3-symmetric", - "p3-util", + "p3-field 0.1.0", + "p3-symmetric 0.1.0", + "p3-util 0.1.0", "tiny-keccak", ] @@ -4745,11 +4834,11 @@ name = "p3-keccak-air" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ - "p3-air", - "p3-field", - "p3-matrix", - "p3-maybe-rayon", - "p3-util", + "p3-air 0.1.0", + "p3-field 0.1.0", + "p3-matrix 0.1.0", + "p3-maybe-rayon 0.1.0", + "p3-util 0.1.0", "rand 0.8.5", "tracing", ] @@ -4759,11 +4848,11 @@ name = "p3-koala-bear" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ - "p3-field", - "p3-mds", - "p3-monty-31", - "p3-poseidon2", - "p3-symmetric", + "p3-field 0.1.0", + "p3-mds 0.1.0", + "p3-monty-31 0.1.0", + "p3-poseidon2 0.1.0", + "p3-symmetric 0.1.0", "rand 0.8.5", "serde", ] @@ -4774,19 +4863,41 @@ version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", - "p3-field", - "p3-maybe-rayon", - "p3-util", + "p3-field 0.1.0", + "p3-maybe-rayon 0.1.0", + "p3-util 0.1.0", "rand 0.8.5", "serde", "tracing", "transpose", ] +[[package]] +name = "p3-matrix" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d916550e4261126457d4f139fc3156fc796b1cf2f2687bf1c9b269b1efa8ad42" +dependencies = [ + "itertools 0.14.0", + "p3-field 0.4.1", + "p3-maybe-rayon 0.4.1", + "p3-util 0.4.1", + "rand 0.9.2", + "serde", + "tracing", + "transpose", +] + [[package]] name = "p3-maybe-rayon" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" + +[[package]] +name = "p3-maybe-rayon" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0db6a290f867061aed54593d48f0dfd7ff2d0f706a603d03209fd0eac79518f3" dependencies = [ "rayon", ] @@ -4797,31 +4908,63 @@ version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", - "p3-dft", - "p3-field", - "p3-matrix", - "p3-symmetric", - "p3-util", + "p3-dft 0.1.0", + "p3-field 0.1.0", + "p3-matrix 0.1.0", + "p3-symmetric 0.1.0", + "p3-util 0.1.0", "rand 0.8.5", ] +[[package]] +name = "p3-mds" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "745a478473a5f3699f76b284378651eaa9d74e74f820b34ea563a4a72ab8a4a6" +dependencies = [ + "p3-dft 0.4.1", + "p3-field 0.4.1", + "p3-symmetric 0.4.1", + "p3-util 0.4.1", + "rand 0.9.2", +] + [[package]] name = "p3-merkle-tree" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", - "p3-commit", - "p3-field", - "p3-matrix", - "p3-maybe-rayon", - "p3-symmetric", - "p3-util", + "p3-commit 0.1.0", + "p3-field 0.1.0", + "p3-matrix 0.1.0", + "p3-maybe-rayon 0.1.0", + "p3-symmetric 0.1.0", + "p3-util 0.1.0", "rand 0.8.5", "serde", "tracing", ] +[[package]] +name = "p3-merkle-tree" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "615f09d1c83ca2ad0dd1f8fb4e496445f9c24a224bac81b98849973f444ee86c" +dependencies = [ + "itertools 0.14.0", + "p3-commit 0.4.1", + "p3-field 0.4.1", + "p3-matrix 0.4.1", + "p3-maybe-rayon 0.4.1", + "p3-symmetric 0.4.1", + "p3-util 0.4.1", + "rand 0.9.2", + "serde", + "thiserror 2.0.12", + "tracing", +] + [[package]] name = "p3-monty-31" version = "0.1.0" @@ -4829,66 +4972,141 @@ source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62c dependencies = [ "itertools 0.14.0", "num-bigint 0.4.6", - "p3-dft", - "p3-field", - "p3-matrix", - "p3-maybe-rayon", - "p3-mds", - "p3-poseidon2", - "p3-symmetric", - "p3-util", + "p3-dft 0.1.0", + "p3-field 0.1.0", + "p3-matrix 0.1.0", + "p3-maybe-rayon 0.1.0", + "p3-mds 0.1.0", + "p3-poseidon2 0.1.0", + "p3-symmetric 0.1.0", + "p3-util 0.1.0", "rand 0.8.5", "serde", "tracing", "transpose", ] +[[package]] +name = "p3-monty-31" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f124f989bc5697728a9e71d2094eda673c45a536c6a8b8ec87b7f3660393aad0" +dependencies = [ + "itertools 0.14.0", + "num-bigint 0.4.6", + "p3-dft 0.4.1", + "p3-field 0.4.1", + "p3-matrix 0.4.1", + "p3-maybe-rayon 0.4.1", + "p3-mds 0.4.1", + "p3-poseidon2 0.4.1", + "p3-symmetric 0.4.1", + "p3-util 0.4.1", + "paste", + "rand 0.9.2", + "serde", + "spin 0.10.0", + "tracing", + "transpose", +] + [[package]] name = "p3-poseidon" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ - "p3-field", - "p3-mds", - "p3-symmetric", + "p3-field 0.1.0", + "p3-mds 0.1.0", + "p3-symmetric 0.1.0", "rand 0.8.5", ] +[[package]] +name = "p3-poseidon" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dc0930e45272609b239052346e2abe8965adaf22b8237eddb679d659af53f28" +dependencies = [ + "p3-field 0.4.1", + "p3-mds 0.4.1", + "p3-symmetric 0.4.1", + "rand 0.9.2", +] + [[package]] name = "p3-poseidon2" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "gcd", - "p3-field", - "p3-mds", - "p3-symmetric", + "p3-field 0.1.0", + "p3-mds 0.1.0", + "p3-symmetric 0.1.0", "rand 0.8.5", ] +[[package]] +name = "p3-poseidon2" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b0c96988fd809e7a3086d8d683ddb93c965f8bb08b37c82e3617d12347bf77f" +dependencies = [ + "p3-field 0.4.1", + "p3-mds 0.4.1", + "p3-symmetric 0.4.1", + "p3-util 0.4.1", + "rand 0.9.2", +] + [[package]] name = "p3-poseidon2-air" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ - "p3-air", - "p3-field", - "p3-matrix", - "p3-maybe-rayon", - "p3-poseidon2", - "p3-util", + "p3-air 0.1.0", + "p3-field 0.1.0", + "p3-matrix 0.1.0", + "p3-maybe-rayon 0.1.0", + "p3-poseidon2 0.1.0", + "p3-util 0.1.0", "rand 0.8.5", "tikv-jemallocator", "tracing", ] +[[package]] +name = "p3-poseidon2-air" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a0c44c47992126b5eb4f5a33444d6059b883c1ea520f1d34590d46338314178" +dependencies = [ + "p3-air 0.4.1", + "p3-field 0.4.1", + "p3-matrix 0.4.1", + "p3-maybe-rayon 0.4.1", + "p3-poseidon2 0.4.1", + "rand 0.9.2", + "tracing", +] + [[package]] name = "p3-symmetric" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", - "p3-field", + "p3-field 0.1.0", + "serde", +] + +[[package]] +name = "p3-symmetric" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dabf1c93a83305b291118dec6632357da69f3137d33fc1791225e38fcb615836" +dependencies = [ + "itertools 0.14.0", + "p3-field 0.4.1", "serde", ] @@ -4898,14 +5116,14 @@ version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", - "p3-air", - "p3-challenger", - "p3-commit", - "p3-dft", - "p3-field", - "p3-matrix", - "p3-maybe-rayon", - "p3-util", + "p3-air 0.1.0", + "p3-challenger 0.1.0", + "p3-commit 0.1.0", + "p3-dft 0.1.0", + "p3-field 0.1.0", + "p3-matrix 0.1.0", + "p3-maybe-rayon 0.1.0", + "p3-util 0.1.0", "serde", "tracing", ] @@ -4918,6 +5136,15 @@ dependencies = [ "serde", ] +[[package]] +name = "p3-util" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92074eab13c8a30d23ad7bcf99b82787a04c843133a0cba39ca1cf39d434492" +dependencies = [ + "serde", +] + [[package]] name = "pairing" version = "0.22.0" @@ -5123,7 +5350,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" dependencies = [ "ff_ext", "p3", @@ -5438,9 +5665,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" dependencies = [ "either", "rayon-core", @@ -5448,9 +5675,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.1" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -6080,7 +6307,7 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" dependencies = [ "cfg-if", "dashu", @@ -6092,7 +6319,7 @@ dependencies = [ "multilinear_extensions", "num", "p256 0.13.2 (registry+https://github.com/rust-lang/crates.io-index)", - "p3-field", + "p3", "rug", "serde", "snowbridge-amcl", @@ -6105,6 +6332,15 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "spin" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" @@ -6205,7 +6441,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" dependencies = [ "either", "ff_ext", @@ -6223,7 +6459,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" dependencies = [ "itertools 0.13.0", "p3", @@ -6630,7 +6866,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -6924,7 +7160,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" dependencies = [ "bincode 1.3.3", "clap", @@ -7211,7 +7447,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index b20888473..7a647d9dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,12 +6,12 @@ members = [ "ceno_serde", "ceno_rt", "ceno_zkvm", - "ceno_recursion", "derive", "examples-builder", "examples", "guest_libs/*", ] +exclude = ["ceno_recursion_v2", "ceno_recursion"] resolver = "2" [workspace.package] @@ -27,16 +27,17 @@ version = "0.1.0" ceno_crypto_primitives = { git = "https://github.com/scroll-tech/ceno-patch.git", package = "ceno_crypto_primitives", branch = "main" } ceno_syscall = { git = "https://github.com/scroll-tech/ceno-patch.git", package = "ceno_syscall", branch = "main" } -ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = "v1.0.0-alpha.22" } -mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.22" } -multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.22" } -p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.22" } -poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.22" } -sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.22" } -sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.22" } -transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", tag = "v1.0.0-alpha.22" } -whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.22" } -witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.22" } +ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", branch = "feat/bump-p3" } +mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", branch = "feat/bump-p3", features = ["whir"] } +multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", branch = "feat/bump-p3" } +p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", branch = "feat/bump-p3" } +p3-field = { version = "=0.4.1", default-features = false } +poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", branch = "feat/bump-p3" } +sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", branch = "feat/bump-p3" } +sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", branch = "feat/bump-p3" } +transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", branch = "feat/bump-p3" } +whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", branch = "feat/bump-p3" } +witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", branch = "feat/bump-p3" } anyhow = { version = "1.0", default-features = false } bincode = "1" @@ -60,7 +61,7 @@ proptest = "1" rand = "0.8" rand_chacha = { version = "0.3", features = ["serde1"] } rand_core = "0.6" -rayon = "1.10" +rayon = "^1.11" rustc-hash = "2.0.0" secp = "0.4.1" serde = { version = "1.0", features = ["derive", "rc"] } diff --git a/ceno_cli/Cargo.toml b/ceno_cli/Cargo.toml index 1ddf61247..b63a726cf 100644 --- a/ceno_cli/Cargo.toml +++ b/ceno_cli/Cargo.toml @@ -28,7 +28,6 @@ tikv-jemallocator = { version = "0.6", optional = true } ceno_emul = { path = "../ceno_emul" } ceno_host = { path = "../ceno_host" } -ceno_recursion = { path = "../ceno_recursion" } ceno_zkvm = { path = "../ceno_zkvm" } openvm-circuit.workspace = true @@ -49,7 +48,7 @@ mpcs.workspace = true vergen-git2 = { version = "9.1.0", features = ["build", "cargo", "rustc", "emit_and_set"] } [features] -gpu = ["gkr_iop/gpu", "ceno_zkvm/gpu", "ceno_recursion/gpu", "dep:openvm-cuda-backend", "openvm-native-circuit/cuda"] +gpu = ["gkr_iop/gpu", "ceno_zkvm/gpu", "dep:openvm-cuda-backend", "openvm-native-circuit/cuda"] jemalloc = ["dep:tikv-jemallocator", "ceno_zkvm/jemalloc"] jemalloc-prof = ["jemalloc", "tikv-jemallocator?/profiling"] nightly-features = [ diff --git a/ceno_cli/src/lib.rs b/ceno_cli/src/lib.rs index 680c89419..5b2fe3c4c 100644 --- a/ceno_cli/src/lib.rs +++ b/ceno_cli/src/lib.rs @@ -1 +1 @@ -pub mod sdk; +// SDK temporarily disabled due to OpenVM dependency incompatibility diff --git a/ceno_recursion/src/aggregation/internal.rs b/ceno_recursion/src/aggregation/internal.rs index bd494a0e9..fb3f40c88 100644 --- a/ceno_recursion/src/aggregation/internal.rs +++ b/ceno_recursion/src/aggregation/internal.rs @@ -19,7 +19,7 @@ use openvm_stark_sdk::{ config::{FriParameters, baby_bear_poseidon2::BabyBearPoseidon2Config}, openvm_stark_backend::p3_field::PrimeField32, }; -use p3::field::FieldAlgebra; +use p3_field::PrimeCharacteristicRing as FieldAlgebra; use openvm_continuations::verifier::{ common::{ diff --git a/ceno_recursion/src/aggregation/mod.rs b/ceno_recursion/src/aggregation/mod.rs index 55af4f532..723be01c6 100644 --- a/ceno_recursion/src/aggregation/mod.rs +++ b/ceno_recursion/src/aggregation/mod.rs @@ -26,6 +26,7 @@ use openvm_continuations::{ internal::types::{InternalVmVerifierInput, InternalVmVerifierPvs, VmStarkProof}, }, }; +use crate::field_ext::CanonicalFieldExt; #[cfg(feature = "gpu")] use openvm_cuda_backend::engine::GpuBabyBearPoseidon2Engine as BabyBearPoseidon2Engine; use openvm_native_circuit::{NativeBuilder, NativeConfig}; @@ -56,7 +57,7 @@ use openvm_stark_sdk::{ openvm_stark_backend::keygen::types::MultiStarkVerifyingKey, p3_bn254_fr::Bn254Fr, }; -use p3::field::FieldAlgebra; +use p3_field::PrimeCharacteristicRing as FieldAlgebra; use serde::{Deserialize, Serialize}; use std::{borrow::Borrow, sync::Arc, time::Instant}; pub type RecPcs = Basefold; @@ -720,7 +721,7 @@ mod tests { }; use mpcs::{Basefold, BasefoldRSParams}; use openvm_stark_sdk::{config::setup_tracing_with_log_level, p3_bn254_fr::Bn254Fr}; - use p3::field::FieldAlgebra; + use p3_field::PrimeCharacteristicRing as FieldAlgebra; use std::fs::File; pub fn aggregation_inner_thread() { diff --git a/ceno_recursion/src/arithmetics/mod.rs b/ceno_recursion/src/arithmetics/mod.rs index fc88bd498..e6a1a4a5a 100644 --- a/ceno_recursion/src/arithmetics/mod.rs +++ b/ceno_recursion/src/arithmetics/mod.rs @@ -10,7 +10,8 @@ use openvm_native_circuit::EXT_DEG; use openvm_native_compiler::prelude::*; use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::challenger::{FeltChallenger, duplex::DuplexChallengerVariable}; -use openvm_stark_backend::p3_field::{FieldAlgebra, FieldExtensionAlgebra}; +use openvm_stark_backend::p3_field::{FieldExtensionAlgebra, PrimeCharacteristicRing as FieldAlgebra}; +use crate::field_ext::CanonicalFieldExt; type E = BabyBearExt4; const MAX_NUM_VARS: usize = 25; @@ -1060,7 +1061,8 @@ mod tests { conversion::{CompilerOptions, convert_program}, ir::Ext, }; - use p3::{babybear::BabyBear, field::FieldAlgebra}; + use p3::babybear::BabyBear; + use p3_field::PrimeCharacteristicRing as FieldAlgebra; use crate::arithmetics::eval_stacked_wellform_address_vec; diff --git a/ceno_recursion/src/basefold_verifier/field.rs b/ceno_recursion/src/basefold_verifier/field.rs index 64eea0328..2580fcbaa 100644 --- a/ceno_recursion/src/basefold_verifier/field.rs +++ b/ceno_recursion/src/basefold_verifier/field.rs @@ -36,7 +36,7 @@ const TWO_ADIC_GENERATORS: [usize; 33] = [ ]; use openvm_native_compiler::prelude::*; -use p3_field::FieldAlgebra; +use p3_field::PrimeCharacteristicRing as FieldAlgebra; fn two_adic_generator( builder: &mut Builder, @@ -51,4 +51,4 @@ fn two_adic_generator( builder.set_value(&two_adic_generator, i, C::F::from_canonical_usize(TWO_ADIC_GENERATORS[i.value()])); }); builder.get(&two_adic_generator, bits) -} \ No newline at end of file +} diff --git a/ceno_recursion/src/basefold_verifier/mmcs.rs b/ceno_recursion/src/basefold_verifier/mmcs.rs index d4a59e0ad..5602f1385 100644 --- a/ceno_recursion/src/basefold_verifier/mmcs.rs +++ b/ceno_recursion/src/basefold_verifier/mmcs.rs @@ -95,7 +95,7 @@ pub mod tests { use openvm_native_circuit::{Native, NativeConfig}; use openvm_native_compiler::asm::AsmBuilder; use openvm_native_recursion::hints::Hintable; - use p3::field::FieldAlgebra; + use p3_field::PrimeCharacteristicRing as FieldAlgebra; use super::{E, F, MmcsCommitment, MmcsVerifierInput, mmcs_verify_batch}; diff --git a/ceno_recursion/src/basefold_verifier/query_phase.rs b/ceno_recursion/src/basefold_verifier/query_phase.rs index 7499207d0..cb28cb070 100644 --- a/ceno_recursion/src/basefold_verifier/query_phase.rs +++ b/ceno_recursion/src/basefold_verifier/query_phase.rs @@ -9,12 +9,14 @@ use openvm_native_recursion::{ use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3::{ commit::ExtensionMmcs, - field::{Field, FieldAlgebra}, + field::Field, }; +use p3_field::PrimeCharacteristicRing as FieldAlgebra; use serde::Deserialize; use super::{basefold::*, extension_mmcs::*, mmcs::*, rs::*, utils::*}; use crate::{arithmetics::eq_eval_with_index, tower_verifier::binding::*}; +use crate::field_ext::CanonicalFieldExt; pub type F = BabyBear; pub type E = BabyBearExt4; diff --git a/ceno_recursion/src/basefold_verifier/rs.rs b/ceno_recursion/src/basefold_verifier/rs.rs index c70ce2d49..9b6d5259b 100644 --- a/ceno_recursion/src/basefold_verifier/rs.rs +++ b/ceno_recursion/src/basefold_verifier/rs.rs @@ -4,7 +4,8 @@ use std::{cell::RefCell, collections::BTreeMap}; use openvm_native_compiler::{asm::AsmConfig, prelude::*}; use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::p3_baby_bear::BabyBear; -use p3::field::{FieldAlgebra, extension::BinomialExtensionField}; +use p3::field::extension::BinomialExtensionField; +use p3_field::PrimeCharacteristicRing as FieldAlgebra; use serde::Deserialize; use super::{structs::*, utils::pow_felt_bits}; @@ -230,7 +231,8 @@ pub mod tests { use openvm_stark_sdk::{ config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, }; - use p3::field::{FieldAlgebra, extension::BinomialExtensionField}; + use p3::field::extension::BinomialExtensionField; + use p3_field::PrimeCharacteristicRing as FieldAlgebra; type SC = BabyBearPoseidon2Config; diff --git a/ceno_recursion/src/basefold_verifier/utils.rs b/ceno_recursion/src/basefold_verifier/utils.rs index 7cc2b6c55..831c1f2f5 100644 --- a/ceno_recursion/src/basefold_verifier/utils.rs +++ b/ceno_recursion/src/basefold_verifier/utils.rs @@ -1,6 +1,7 @@ use openvm_native_compiler::ir::*; use openvm_native_recursion::vars::HintSlice; -use p3::{babybear::BabyBear, field::FieldAlgebra}; +use p3::babybear::BabyBear; +use p3_field::PrimeCharacteristicRing as FieldAlgebra; use crate::basefold_verifier::mmcs::MmcsProof; diff --git a/ceno_recursion/src/basefold_verifier/verifier.rs b/ceno_recursion/src/basefold_verifier/verifier.rs index 7b6ad8d2f..873854f84 100644 --- a/ceno_recursion/src/basefold_verifier/verifier.rs +++ b/ceno_recursion/src/basefold_verifier/verifier.rs @@ -13,7 +13,7 @@ use openvm_native_recursion::challenger::{ duplex::DuplexChallengerVariable, }; use openvm_stark_sdk::p3_baby_bear::BabyBear; -use p3::field::FieldAlgebra; +use p3_field::PrimeCharacteristicRing as FieldAlgebra; pub type F = BabyBear; pub type E = BabyBearExt4; @@ -176,7 +176,8 @@ pub mod tests { use openvm_native_recursion::{challenger::duplex::DuplexChallengerVariable, hints::Hintable}; use openvm_stark_backend::p3_challenger::GrindingChallenger; use openvm_stark_sdk::{config::baby_bear_poseidon2::Challenger, p3_baby_bear::BabyBear}; - use p3::field::{Field, FieldAlgebra}; + use p3::field::Field; + use p3_field::PrimeCharacteristicRing as FieldAlgebra; use rand::thread_rng; use serde::Deserialize; use transcript::{BasicTranscript, Transcript}; diff --git a/ceno_recursion/src/extensions/mod.rs b/ceno_recursion/src/extensions/mod.rs index 6e92fbd60..bbf9e992b 100644 --- a/ceno_recursion/src/extensions/mod.rs +++ b/ceno_recursion/src/extensions/mod.rs @@ -17,7 +17,7 @@ mod tests { }; use openvm_stark_backend::{ config::StarkGenericConfig, - p3_field::{Field, FieldAlgebra}, + p3_field::{Field, PrimeCharacteristicRing as FieldAlgebra}, }; use openvm_stark_sdk::{ config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, diff --git a/ceno_recursion/src/lib.rs b/ceno_recursion/src/lib.rs index 1d60bf584..b6e0a9ebe 100644 --- a/ceno_recursion/src/lib.rs +++ b/ceno_recursion/src/lib.rs @@ -2,6 +2,7 @@ #![allow(clippy::too_many_arguments)] mod arithmetics; mod basefold_verifier; +mod field_ext; pub mod constants; mod tower_verifier; mod transcript; diff --git a/ceno_recursion/src/tower_verifier/program.rs b/ceno_recursion/src/tower_verifier/program.rs index b1fda581d..c64068dd3 100644 --- a/ceno_recursion/src/tower_verifier/program.rs +++ b/ceno_recursion/src/tower_verifier/program.rs @@ -13,7 +13,8 @@ use crate::{ use openvm_native_compiler::prelude::*; use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::challenger::{FeltChallenger, duplex::DuplexChallengerVariable}; -use openvm_stark_backend::p3_field::FieldAlgebra; +use openvm_stark_backend::p3_field::PrimeCharacteristicRing as FieldAlgebra; +use crate::field_ext::CanonicalFieldExt; const NATIVE_SUMCHECK_CTX_LEN: usize = 9; pub fn iop_verifier_state_verify( @@ -494,7 +495,7 @@ pub fn verify_tower_proof( // use p3_baby_bear::BabyBear; // use p3_field::extension::BinomialExtensionField; // use p3_field::Field; -// use p3_field::FieldAlgebra; +// use p3_field::PrimeCharacteristicRing as FieldAlgebra; // use rand::thread_rng; // // type F = BabyBear; diff --git a/ceno_recursion/src/transcript/mod.rs b/ceno_recursion/src/transcript/mod.rs index 9ea0e1ed4..5792dc3b2 100644 --- a/ceno_recursion/src/transcript/mod.rs +++ b/ceno_recursion/src/transcript/mod.rs @@ -3,9 +3,8 @@ use openvm_native_compiler::prelude::*; use openvm_native_recursion::challenger::{ CanObserveVariable, CanSampleBitsVariable, duplex::DuplexChallengerVariable, }; -use openvm_stark_backend::p3_field::FieldAlgebra; - use crate::arithmetics::challenger_multi_observe; +use crate::field_ext::CanonicalFieldExt; pub fn transcript_observe_label( builder: &mut Builder, diff --git a/ceno_recursion/src/zkvm_verifier/binding.rs b/ceno_recursion/src/zkvm_verifier/binding.rs index e2ed1f4c9..73d89443e 100644 --- a/ceno_recursion/src/zkvm_verifier/binding.rs +++ b/ceno_recursion/src/zkvm_verifier/binding.rs @@ -25,10 +25,11 @@ use openvm_native_compiler::{ }; use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; -use openvm_stark_backend::p3_field::{FieldAlgebra, extension::BinomialExtensionField}; +use openvm_stark_backend::p3_field::{PrimeCharacteristicRing as FieldAlgebra, extension::BinomialExtensionField}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3::field::FieldExtensionAlgebra; use sumcheck::structs::IOPProof; +use crate::field_ext::CanonicalFieldExt; pub type F = BabyBear; pub type E = BinomialExtensionField; diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index 129cbcc5b..130746696 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -54,8 +54,9 @@ use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::challenger::{ CanObserveVariable, FeltChallenger, duplex::DuplexChallengerVariable, }; -use openvm_stark_backend::p3_field::FieldAlgebra; +use openvm_stark_backend::p3_field::PrimeCharacteristicRing as FieldAlgebra; use p3::babybear::BabyBear; +use crate::field_ext::CanonicalFieldExt; type F = BabyBear; type E = BabyBearExt4; diff --git a/ceno_recursion_v2/Cargo.lock b/ceno_recursion_v2/Cargo.lock new file mode 100644 index 000000000..5d3ab7f31 --- /dev/null +++ b/ceno_recursion_v2/Cargo.lock @@ -0,0 +1,4483 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "abi_stable" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69d6512d3eb05ffe5004c59c206de7f99c34951504056ce23fc953842f12c445" +dependencies = [ + "abi_stable_derive", + "abi_stable_shared", + "const_panic", + "core_extensions", + "crossbeam-channel", + "generational-arena", + "libloading", + "lock_api", + "parking_lot", + "paste", + "repr_offset", + "rustc_version", + "serde", + "serde_derive", + "serde_json", +] + +[[package]] +name = "abi_stable_derive" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7178468b407a4ee10e881bc7a328a65e739f0863615cca4429d43916b05e898" +dependencies = [ + "abi_stable_shared", + "as_derive_utils", + "core_extensions", + "proc-macro2", + "quote", + "rustc_version", + "syn 1.0.109", + "typed-arena", +] + +[[package]] +name = "abi_stable_shared" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2b5df7688c123e63f4d4d649cba63f2967ba7f7861b1664fca3f77d3dad2b63" +dependencies = [ + "core_extensions", +] + +[[package]] +name = "addr2line" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b5d307320b3181d6d7954e663bd7c774a838b8220fe0593c86d9fb09f498b4b" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "ark-ff" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec847af850f44ad29048935519032c33da8aa03340876d351dfab5660d2966ba" +dependencies = [ + "ark-ff-asm", + "ark-ff-macros", + "ark-serialize", + "ark-std", + "derivative", + "digest", + "itertools 0.10.5", + "num-bigint", + "num-traits", + "paste", + "rustc_version", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ed4aa4fe255d0bc6d79373f7e31d2ea147bcf486cba1be5ba7ea85abdb92348" +dependencies = [ + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-ff-macros" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7abe79b0e4288889c4574159ab790824d0033b9fdcb2a112a3182fac2e514565" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-serialize" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb7b85a02b83d2f22f89bd5cac66c9c89474240cb6207cb1efc16d098e822a5" +dependencies = [ + "ark-std", + "digest", + "num-bigint", +] + +[[package]] +name = "ark-std" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94893f1e0c6eeab764ade8dc4c0db24caf4fe7cbbaafc0eba0a9030f447b5185" +dependencies = [ + "num-traits", + "rand 0.8.5", +] + +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + +[[package]] +name = "as_derive_utils" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff3c96645900a44cf11941c111bd08a6573b0e2f9f69bc9264b179d8fae753c4" +dependencies = [ + "core_extensions", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "az" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be5eb007b7cacc6c660343e96f650fedf4b5a77512399eb952ca6642cf8d13f7" + +[[package]] +name = "backtrace" +version = "0.3.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb531853791a215d7c62a30daf0dde835f381ab5de4589cfe7c649d2cbe92bd6" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "serde", + "windows-link", +] + +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "base64ct" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" + +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] +name = "bitcode" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6ed1b54d8dc333e7be604d00fa9262f4635485ffea923647b6521a5fff045d" +dependencies = [ + "bytemuck", + "serde", +] + +[[package]] +name = "bitcoin-io" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dee39a0ee5b4095224a0cfc6bf4cc1baf0f9624b96b367e53b66d974e51d953" + +[[package]] +name = "bitcoin_hashes" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26ec84b80c482df901772e931a9a681e26a1b9ee2302edeff23cb30328745c8b" +dependencies = [ + "bitcoin-io", + "hex-conservative", +] + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + +[[package]] +name = "blake2b_simd" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b79834656f71332577234b50bfc009996f7449e0c056884e6a02492ded0ca2f3" +dependencies = [ + "arrayref", + "arrayvec", + "constant_time_eq", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array 0.14.9", +] + +[[package]] +name = "bls12_381" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3c196a77437e7cc2fb515ce413a6401291578b5afc8ecb29a3c7ab957f05941" +dependencies = [ + "ff 0.12.1", + "group 0.12.1", + "pairing", + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + +[[package]] +name = "byte-slice-cast" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7575182f7272186991736b70173b0ea045398f984bf5ebbb3804736ce1330c9d" + +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "bytesize" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bd91ee7b2422bcb158d90ef4d14f75ef67f340943fc4149891dcce8f8b972a3" + +[[package]] +name = "cc" +version = "1.2.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +dependencies = [ + "find-msvc-tools", + "jobserver", + "libc", + "shlex", +] + +[[package]] +name = "ceno-examples" +version = "0.1.0" +dependencies = [ + "glob", +] + +[[package]] +name = "ceno_crypto_primitives" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno-patch.git?branch=main#b79232a6fc80c799f584380273fd6d3055b36808" +dependencies = [ + "ceno_syscall", + "elliptic-curve", +] + +[[package]] +name = "ceno_emul" +version = "0.1.0" +dependencies = [ + "anyhow", + "ceno_rt", + "ceno_syscall", + "elf", + "ff_ext", + "itertools 0.13.0", + "k256 0.13.4 (git+https://github.com/scroll-tech/elliptic-curves?branch=ceno%2Fk256-13.4)", + "multilinear_extensions", + "num", + "num-derive", + "num-traits", + "once_cell", + "p256 0.13.2 (git+https://github.com/scroll-tech/elliptic-curves?branch=ceno%2Fk256-13.4)", + "rayon", + "rrs-succinct", + "rustc-hash", + "secp", + "serde", + "smallvec", + "sp1-curves", + "strum", + "strum_macros", + "substrate-bn", + "tiny-keccak", + "tracing", + "typenum", +] + +[[package]] +name = "ceno_host" +version = "0.1.0" +dependencies = [ + "anyhow", + "ceno_emul", + "ceno_serde", + "itertools 0.13.0", + "serde", + "tiny-keccak", +] + +[[package]] +name = "ceno_recursion_v2" +version = "0.1.0" +dependencies = [ + "bincode", + "ceno-examples", + "ceno_emul", + "ceno_host", + "ceno_zkvm", + "clap", + "derive-new 0.6.0", + "eyre", + "ff_ext", + "gkr_iop", + "itertools 0.13.0", + "mpcs", + "multilinear_extensions", + "openvm", + "openvm-circuit", + "openvm-circuit-primitives", + "openvm-continuations", + "openvm-cpu-backend", + "openvm-cuda-backend", + "openvm-cuda-common", + "openvm-poseidon2-air", + "openvm-recursion-circuit", + "openvm-recursion-circuit-derive", + "openvm-stark-backend", + "openvm-stark-sdk", + "openvm-verify-stark-host", + "p3", + "p3-air", + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-symmetric", + "parse-size", + "rand 0.8.5", + "serde", + "serde_json", + "strum", + "strum_macros", + "sumcheck", + "tracing", + "tracing-forest", + "tracing-subscriber", + "transcript", + "whir", + "witness", +] + +[[package]] +name = "ceno_rt" +version = "0.1.0" +dependencies = [ + "ceno_serde", + "getrandom 0.2.17", + "getrandom 0.3.4", + "serde", +] + +[[package]] +name = "ceno_serde" +version = "0.1.0" +dependencies = [ + "bytemuck", + "serde", +] + +[[package]] +name = "ceno_syscall" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/ceno-patch.git?branch=main#b79232a6fc80c799f584380273fd6d3055b36808" + +[[package]] +name = "ceno_zkvm" +version = "0.1.0" +dependencies = [ + "arrayref", + "base64", + "bincode", + "ceno-examples", + "ceno_emul", + "ceno_host", + "cfg-if", + "clap", + "derive", + "either", + "ff_ext", + "generic-array 1.3.5", + "generic_static", + "gkr_iop", + "glob", + "itertools 0.13.0", + "metrics 0.24.3", + "mpcs", + "multilinear_extensions", + "ndarray", + "num", + "num-bigint", + "once_cell", + "p3", + "parse-size", + "prettytable-rs", + "rand 0.8.5", + "rayon", + "rustc-hash", + "serde", + "serde_json", + "smallvec", + "sp1-curves", + "strum", + "strum_macros", + "sumcheck", + "tempfile", + "tiny-keccak", + "tracing", + "tracing-forest", + "tracing-subscriber", + "transcript", + "typenum", + "whir", + "witness", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "clap" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "clap_lex" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" + +[[package]] +name = "cobs" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa961b519f0b462e3a3b4a34b64d119eeaca1d59af726fe450bbba07a9fc0a1" +dependencies = [ + "thiserror 2.0.18", +] + +[[package]] +name = "colorchoice" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + +[[package]] +name = "const_format" +version = "0.2.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7faa7469a93a566e9ccc1c73fe783b4a65c274c5ace346038dca9c39fe0030ad" +dependencies = [ + "const_format_proc_macros", +] + +[[package]] +name = "const_format_proc_macros" +version = "0.2.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d57c2eccfb16dbac1f4e61e206105db5820c9d26c3c472bc17c774259ef7744" +dependencies = [ + "proc-macro2", + "quote", + "unicode-xid", +] + +[[package]] +name = "const_panic" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e262cdaac42494e3ae34c43969f9cdeb7da178bdb4b66fa6a1ea2edb4c8ae652" +dependencies = [ + "typewit", +] + +[[package]] +name = "constant_time_eq" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" + +[[package]] +name = "core_extensions" +version = "1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42bb5e5d0269fd4f739ea6cedaf29c16d81c27a7ce7582008e90eb50dcd57003" +dependencies = [ + "core_extensions_proc_macros", +] + +[[package]] +name = "core_extensions_proc_macros" +version = "1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "533d38ecd2709b7608fb8e18e4504deb99e9a72879e6aa66373a76d8dc4259ea" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + +[[package]] +name = "crypto-bigint" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" +dependencies = [ + "generic-array 0.14.9", + "rand_core 0.6.4", + "subtle", + "zeroize", +] + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array 0.14.9", + "typenum", +] + +[[package]] +name = "csv" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde_core", +] + +[[package]] +name = "csv-core" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" +dependencies = [ + "memchr", +] + +[[package]] +name = "ctor" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67773048316103656a637612c4a62477603b777d91d9c62ff2290f9cde178fdb" +dependencies = [ + "ctor-proc-macro", + "dtor", +] + +[[package]] +name = "ctor-proc-macro" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2931af7e13dc045d8e9d26afccc6fa115d64e115c9c84b1166288b46f6782c2" + +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "dashu" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85b3e5ac1e23ff1995ef05b912e2b012a8784506987a2651552db2c73fb3d7e0" +dependencies = [ + "dashu-base", + "dashu-float", + "dashu-int", + "dashu-macros", + "dashu-ratio", + "rustversion", +] + +[[package]] +name = "dashu-base" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0b80bf6b85aa68c58ffea2ddb040109943049ce3fbdf4385d0380aef08ef289" + +[[package]] +name = "dashu-float" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85078445a8dbd2e1bd21f04a816f352db8d333643f0c9b78ca7c3d1df71063e7" +dependencies = [ + "dashu-base", + "dashu-int", + "num-modular", + "num-order", + "rustversion", + "static_assertions", +] + +[[package]] +name = "dashu-int" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee99d08031ca34a4d044efbbb21dff9b8c54bb9d8c82a189187c0651ffdb9fbf" +dependencies = [ + "cfg-if", + "dashu-base", + "num-modular", + "num-order", + "rustversion", + "static_assertions", +] + +[[package]] +name = "dashu-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93381c3ef6366766f6e9ed9cf09e4ef9dec69499baf04f0c60e70d653cf0ab10" +dependencies = [ + "dashu-base", + "dashu-float", + "dashu-int", + "dashu-ratio", + "paste", + "proc-macro2", + "quote", + "rustversion", +] + +[[package]] +name = "dashu-ratio" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e33b04dd7ce1ccf8a02a69d3419e354f2bbfdf4eb911a0b7465487248764c9" +dependencies = [ + "dashu-base", + "dashu-float", + "dashu-int", + "num-modular", + "num-order", + "rustversion", +] + +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "const-oid", + "pem-rfc7468", + "zeroize", +] + +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive" +version = "0.1.0" +dependencies = [ + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive-new" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "derive-new" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "derive_more" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", + "unicode-xid", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "const-oid", + "crypto-common", + "subtle", +] + +[[package]] +name = "dirs-next" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" +dependencies = [ + "cfg-if", + "dirs-sys-next", +] + +[[package]] +name = "dirs-sys-next" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + +[[package]] +name = "downcast-rs" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" + +[[package]] +name = "dtor" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "404d02eeb088a82cfd873006cb713fe411306c7d182c344905e101fb1167d301" +dependencies = [ + "dtor-proc-macro", +] + +[[package]] +name = "dtor-proc-macro" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f678cf4a922c215c63e0de95eb1ff08a958a81d47e485cf9da1e27bf6305cfa5" + +[[package]] +name = "ecdsa" +version = "0.16.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" +dependencies = [ + "der", + "digest", + "elliptic-curve", + "rfc6979 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "signature", + "spki", +] + +[[package]] +name = "ecdsa" +version = "0.16.9" +source = "git+https://github.com/sp1-patches/signatures.git?tag=patch-16.9-sp1-4.1.0#1880299a48fe7ef249edaa616fd411239fb5daf1" +dependencies = [ + "der", + "digest", + "elliptic-curve", + "rfc6979 0.4.0 (git+https://github.com/sp1-patches/signatures.git?tag=patch-16.9-sp1-4.1.0)", + "signature", + "spki", +] + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +dependencies = [ + "serde", +] + +[[package]] +name = "elf" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4445909572dbd556c457c849c4ca58623d84b27c8fff1e74b0b4227d8b90d17b" + +[[package]] +name = "elliptic-curve" +version = "0.13.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" +dependencies = [ + "base16ct", + "crypto-bigint", + "digest", + "ff 0.13.1", + "generic-array 0.14.9", + "group 0.13.0", + "hkdf", + "pem-rfc7468", + "pkcs8", + "rand_core 0.6.4", + "sec1", + "subtle", + "zeroize", +] + +[[package]] +name = "embedded-io" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced" + +[[package]] +name = "embedded-io" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d" + +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + +[[package]] +name = "endian-type" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" + +[[package]] +name = "enum_dispatch" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa18ce2bc66555b3218614519ac839ddb759a7d6720732f979ef8d13be147ecd" +dependencies = [ + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "eyre" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" +dependencies = [ + "indenter", + "once_cell", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "ff" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d013fc25338cc558c5c2cfbad646908fb23591e2404481826742b651c9af7160" +dependencies = [ + "bitvec", + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "ff" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0b50bfb653653f9ca9095b427bed08ab8d75a137839d9ad64eb11810d5b6393" +dependencies = [ + "bitvec", + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "ff_ext" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" +dependencies = [ + "once_cell", + "p3", + "rand_core 0.6.4", + "serde", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + +[[package]] +name = "generational-arena" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877e94aff08e743b651baaea359664321055749b398adff8740a7399af7796e7" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "generic-array" +version = "0.14.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bb6743198531e02858aeaea5398fcc883e71851fcbcb5a2f773e2fb6cb1edf2" +dependencies = [ + "typenum", + "version_check", + "zeroize", +] + +[[package]] +name = "generic-array" +version = "1.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaf57c49a95fd1fe24b90b3033bee6dc7e8f1288d51494cb44e627c295e38542" +dependencies = [ + "rustversion", + "serde_core", + "typenum", +] + +[[package]] +name = "generic_static" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28ccff179d8070317671db09aee6d20affc26e88c5394714553b04f509b43a60" +dependencies = [ + "once_cell", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi 5.3.0", + "wasip2", +] + +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + +[[package]] +name = "getset" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf0fc11e47561d47397154977bc219f4cf809b2974facc3ccb3b89e2436f912" +dependencies = [ + "proc-macro-error2", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "gimli" +version = "0.32.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7" + +[[package]] +name = "gkr_iop" +version = "0.1.0" +dependencies = [ + "bincode", + "either", + "ff_ext", + "itertools 0.13.0", + "mpcs", + "multilinear_extensions", + "once_cell", + "p3", + "p3-field", + "rand 0.8.5", + "rayon", + "serde", + "smallvec", + "strum", + "strum_macros", + "sumcheck", + "thiserror 2.0.18", + "thread_local", + "tracing", + "tracing-forest", + "tracing-subscriber", + "transcript", + "witness", +] + +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + +[[package]] +name = "gmp-mpfr-sys" +version = "1.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60f8970a75c006bb2f8ae79c6768a116dd215fa8346a87aed99bf9d82ca43394" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "group" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dfbfb3a6cfbd390d5c9564ab283a0349b9b9fcd46a706c1eb10e0db70bfbac7" +dependencies = [ + "ff 0.12.1", + "memuse", + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff 0.13.1", + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "halo2" +version = "0.1.0-beta.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a23c779b38253fe1538102da44ad5bd5378495a61d2c4ee18d64eaa61ae5995" +dependencies = [ + "halo2_proofs", +] + +[[package]] +name = "halo2_proofs" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e925780549adee8364c7f2b685c753f6f3df23bde520c67416e93bf615933760" +dependencies = [ + "blake2b_simd", + "ff 0.12.1", + "group 0.12.1", + "pasta_curves 0.4.1", + "rand_core 0.6.4", + "rayon", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hex-conservative" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fda06d18ac606267c40c04e41b9947729bf8b9efe74bd4e82b61a5f26a510b9f" +dependencies = [ + "arrayvec", +] + +[[package]] +name = "hex-literal" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e712f64ec3850b98572bffac52e2c6f282b29fe6c5fa6d42334b30be438d95c1" + +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "impl-trait-for-tuples" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "indenter" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "964de6e86d545b246d84badc0fef527924ace5134f30641c203ef52ba83f58d5" + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "is-terminal" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys", +] + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "jubjub" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a575df5f985fe1cd5b2b05664ff6accfc46559032b954529fd225a2168d27b0f" +dependencies = [ + "bitvec", + "bls12_381", + "ff 0.12.1", + "group 0.12.1", + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "k256" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6e3919bbaa2945715f0bb6d3934a173d1e9a59ac23767fbaaef277265a7411b" +dependencies = [ + "cfg-if", + "ecdsa 0.16.9 (registry+https://github.com/rust-lang/crates.io-index)", + "elliptic-curve", + "once_cell", + "sha2", + "signature", +] + +[[package]] +name = "k256" +version = "0.13.4" +source = "git+https://github.com/scroll-tech/elliptic-curves?branch=ceno%2Fk256-13.4#17adc274db2fb10510449026ec785ae4fc234540" +dependencies = [ + "ceno_crypto_primitives", + "ceno_syscall", + "cfg-if", + "ecdsa 0.16.9 (git+https://github.com/sp1-patches/signatures.git?tag=patch-16.9-sp1-4.1.0)", + "elliptic-curve", + "hex", + "once_cell", + "sha2", +] + +[[package]] +name = "keccak" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" +dependencies = [ + "cpufeatures", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin 0.9.8", +] + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "libc" +version = "0.2.182" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" + +[[package]] +name = "libloading" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" +dependencies = [ + "cfg-if", + "winapi", +] + +[[package]] +name = "libm" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" + +[[package]] +name = "libredox" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1744e39d1d6a9948f4f388969627434e31128196de472883b39f148769bfe30a" +dependencies = [ + "libc", +] + +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "lockfree-object-pool" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e" + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "memmap2" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "714098028fe011992e1c3962653c96b2d578c4b4bce9036e15ff220319b1e0e3" +dependencies = [ + "libc", +] + +[[package]] +name = "memuse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d97bbf43eb4f088f8ca469930cde17fa036207c9a5e02ccc5107c4e8b17c964" + +[[package]] +name = "metrics" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3045b4193fbdc5b5681f32f11070da9be3609f189a79f3390706d42587f46bb5" +dependencies = [ + "ahash", + "portable-atomic", +] + +[[package]] +name = "metrics" +version = "0.24.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d5312e9ba3771cfa961b585728215e3d972c950a3eed9252aa093d6301277e8" +dependencies = [ + "ahash", + "portable-atomic", +] + +[[package]] +name = "metrics-tracing-context" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62a6a1f7141f1d9bc7a886b87536bbfc97752e08b369e1e0453a9acfab5f5da4" +dependencies = [ + "indexmap", + "itoa", + "lockfree-object-pool", + "metrics 0.23.1", + "metrics-util", + "once_cell", + "tracing", + "tracing-core", + "tracing-subscriber", +] + +[[package]] +name = "metrics-util" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4259040465c955f9f2f1a4a8a16dc46726169bca0f88e8fb2dbeced487c3e828" +dependencies = [ + "aho-corasick", + "crossbeam-epoch", + "crossbeam-utils", + "hashbrown 0.14.5", + "indexmap", + "metrics 0.23.1", + "num_cpus", + "ordered-float", + "quanta", + "radix_trie", + "sketches-ddsketch", +] + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", +] + +[[package]] +name = "mpcs" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" +dependencies = [ + "bincode", + "clap", + "ff_ext", + "itertools 0.13.0", + "multilinear_extensions", + "num-integer", + "p3", + "rand 0.8.5", + "rand_chacha 0.3.1", + "rayon", + "serde", + "sumcheck", + "tracing", + "tracing-subscriber", + "transcript", + "whir", + "witness", +] + +[[package]] +name = "multilinear_extensions" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" +dependencies = [ + "either", + "ff_ext", + "itertools 0.13.0", + "p3", + "rand 0.8.5", + "rayon", + "serde", + "tracing", +] + +[[package]] +name = "ndarray" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "nibble_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" +dependencies = [ + "smallvec", +] + +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-modular" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17bb261bf36fa7d83f4c294f834e91256769097b3cb505d44831e0a179ac647f" + +[[package]] +name = "num-order" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "537b596b97c40fcf8056d153049eb22f481c17ebce72a513ec9286e4986d1bb6" +dependencies = [ + "num-modular", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "num_enum" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f646caf906c20226733ed5b1374287eb97e3c2a5c227ce668c1f2ce20ae57c9" +dependencies = [ + "num_enum_derive", +] + +[[package]] +name = "num_enum_derive" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcbff9bc912032c62bf65ef1d5aea88983b420f4f839db1e9b0c281a25c9c799" +dependencies = [ + "proc-macro-crate 1.3.1", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "object" +version = "0.37.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe" +dependencies = [ + "memchr", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + +[[package]] +name = "openvm" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "bytemuck", + "num-bigint", + "openvm-custom-insn", + "openvm-platform", + "openvm-rv32im-guest", + "serde", +] + +[[package]] +name = "openvm-circuit" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "abi_stable", + "backtrace", + "bytesize", + "cfg-if", + "dashmap", + "derivative", + "derive-new 0.6.0", + "derive_more", + "enum_dispatch", + "eyre", + "getset", + "itertools 0.14.0", + "libc", + "memmap2", + "openvm-circuit-derive", + "openvm-circuit-primitives", + "openvm-circuit-primitives-derive", + "openvm-cpu-backend", + "openvm-instructions", + "openvm-poseidon2-air", + "openvm-stark-backend", + "p3-baby-bear", + "p3-field", + "rand 0.9.2", + "rustc-hash", + "serde", + "serde-big-array", + "static_assertions", + "thiserror 1.0.69", + "tracing", +] + +[[package]] +name = "openvm-circuit-derive" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "itertools 0.14.0", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "openvm-circuit-primitives" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "derive-new 0.6.0", + "itertools 0.14.0", + "num-bigint", + "num-traits", + "openvm-circuit-primitives-derive", + "openvm-cpu-backend", + "openvm-cuda-builder 2.0.0-alpha (git+https://github.com/openvm-org/stark-backend.git?branch=develop-v2)", + "openvm-stark-backend", + "rand 0.9.2", + "tracing", +] + +[[package]] +name = "openvm-circuit-primitives-derive" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "itertools 0.14.0", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "openvm-codec-derive" +version = "2.0.0-alpha" +source = "git+https://github.com/hero78119/stark-backend.git?branch=develop-v2#1aa35f00ecbb45d75d383407872ef34fd0192b7b" +dependencies = [ + "proc-macro-crate 1.3.1", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "openvm-continuations" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "cfg-if", + "derive-new 0.6.0", + "eyre", + "itertools 0.14.0", + "num-bigint", + "openvm-circuit", + "openvm-circuit-primitives", + "openvm-cpu-backend", + "openvm-poseidon2-air", + "openvm-recursion-circuit", + "openvm-recursion-circuit-derive", + "openvm-stark-backend", + "openvm-stark-sdk", + "openvm-verify-stark-host", + "p3-air", + "p3-bn254", + "p3-field", + "p3-matrix", + "tracing", +] + +[[package]] +name = "openvm-cpu-backend" +version = "2.0.0-alpha" +source = "git+https://github.com/hero78119/stark-backend.git?branch=develop-v2#1aa35f00ecbb45d75d383407872ef34fd0192b7b" +dependencies = [ + "cfg-if", + "derive-new 0.7.0", + "getset", + "itertools 0.14.0", + "openvm-stark-backend", + "p3-air", + "p3-baby-bear", + "p3-dft", + "p3-field", + "p3-interpolation", + "p3-matrix", + "p3-maybe-rayon", + "p3-util", + "rayon", + "rustc-hash", + "serde", + "thiserror 1.0.69", + "tracing", +] + +[[package]] +name = "openvm-cuda-backend" +version = "2.0.0-alpha" +source = "git+https://github.com/hero78119/stark-backend.git?branch=develop-v2#1aa35f00ecbb45d75d383407872ef34fd0192b7b" +dependencies = [ + "derive-new 0.7.0", + "getset", + "glob", + "itertools 0.14.0", + "openvm-cuda-builder 2.0.0-alpha (git+https://github.com/hero78119/stark-backend.git?branch=develop-v2)", + "openvm-cuda-common", + "openvm-stark-backend", + "openvm-stark-sdk", + "p3-baby-bear", + "p3-dft", + "p3-field", + "p3-symmetric", + "p3-util", + "rand 0.9.2", + "rustc-hash", + "serde", + "thiserror 1.0.69", + "tracing", +] + +[[package]] +name = "openvm-cuda-builder" +version = "2.0.0-alpha" +source = "git+https://github.com/hero78119/stark-backend.git?branch=develop-v2#1aa35f00ecbb45d75d383407872ef34fd0192b7b" +dependencies = [ + "cc", + "glob", +] + +[[package]] +name = "openvm-cuda-builder" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/stark-backend.git?branch=develop-v2#47a0a7523b07d8664f1c0758510962d977a68ec5" +dependencies = [ + "cc", + "glob", +] + +[[package]] +name = "openvm-cuda-common" +version = "2.0.0-alpha" +source = "git+https://github.com/hero78119/stark-backend.git?branch=develop-v2#1aa35f00ecbb45d75d383407872ef34fd0192b7b" +dependencies = [ + "bytesize", + "ctor", + "lazy_static", + "metrics 0.23.1", + "openvm-cuda-builder 2.0.0-alpha (git+https://github.com/hero78119/stark-backend.git?branch=develop-v2)", + "thiserror 1.0.69", + "tracing", +] + +[[package]] +name = "openvm-custom-insn" +version = "0.1.0" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "openvm-instructions" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "backtrace", + "derive-new 0.6.0", + "itertools 0.14.0", + "num-bigint", + "num-traits", + "openvm-instructions-derive", + "openvm-stark-backend", + "serde", + "strum", + "strum_macros", +] + +[[package]] +name = "openvm-instructions-derive" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "quote", + "syn 2.0.117", +] + +[[package]] +name = "openvm-platform" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "libm", + "openvm-custom-insn", + "openvm-rv32im-guest", +] + +[[package]] +name = "openvm-poseidon2-air" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "derivative", + "lazy_static", + "openvm-cuda-builder 2.0.0-alpha (git+https://github.com/openvm-org/stark-backend.git?branch=develop-v2)", + "openvm-stark-backend", + "openvm-stark-sdk", + "p3-poseidon2", + "p3-poseidon2-air", + "p3-symmetric", + "rand 0.9.2", + "zkhash", +] + +[[package]] +name = "openvm-recursion-circuit" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "derive-new 0.6.0", + "eyre", + "itertools 0.14.0", + "openvm-circuit", + "openvm-circuit-primitives", + "openvm-cpu-backend", + "openvm-poseidon2-air", + "openvm-recursion-circuit-derive", + "openvm-stark-backend", + "openvm-stark-sdk", + "p3-air", + "p3-baby-bear", + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-symmetric", + "strum", + "strum_macros", + "tracing", +] + +[[package]] +name = "openvm-recursion-circuit-derive" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "quote", + "syn 2.0.117", +] + +[[package]] +name = "openvm-rv32im-guest" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "openvm-custom-insn", + "p3-field", + "strum_macros", +] + +[[package]] +name = "openvm-stark-backend" +version = "2.0.0-alpha" +source = "git+https://github.com/hero78119/stark-backend.git?branch=develop-v2#1aa35f00ecbb45d75d383407872ef34fd0192b7b" +dependencies = [ + "cfg-if", + "derivative", + "derive-new 0.7.0", + "eyre", + "getset", + "hex-literal", + "itertools 0.14.0", + "metrics 0.23.1", + "openvm-codec-derive", + "p3-air", + "p3-challenger", + "p3-dft", + "p3-field", + "p3-interpolation", + "p3-matrix", + "p3-maybe-rayon", + "p3-symmetric", + "p3-util", + "postcard", + "rayon", + "rustc-hash", + "serde", + "serde_json", + "thiserror 1.0.69", + "tracing", +] + +[[package]] +name = "openvm-stark-sdk" +version = "2.0.0-alpha" +source = "git+https://github.com/hero78119/stark-backend.git?branch=develop-v2#1aa35f00ecbb45d75d383407872ef34fd0192b7b" +dependencies = [ + "dashmap", + "derive-new 0.7.0", + "eyre", + "hex-literal", + "itertools 0.14.0", + "metrics 0.23.1", + "metrics-tracing-context", + "metrics-util", + "num-bigint", + "openvm-cpu-backend", + "openvm-stark-backend", + "p3-baby-bear", + "p3-bn254", + "p3-field", + "p3-poseidon2", + "rand 0.9.2", + "serde", + "serde_json", + "static_assertions", + "tracing", + "tracing-forest", + "tracing-subscriber", + "zkhash", +] + +[[package]] +name = "openvm-verify-stark-host" +version = "2.0.0-alpha" +source = "git+https://github.com/openvm-org/openvm.git?branch=develop-v2.0.0-beta#9ddd546c306affb7753f8b42870b6b27a7cd3be7" +dependencies = [ + "bitcode", + "eyre", + "openvm-circuit", + "openvm-recursion-circuit-derive", + "openvm-stark-backend", + "openvm-stark-sdk", + "p3-field", + "serde", + "thiserror 1.0.69", + "zstd", +] + +[[package]] +name = "ordered-float" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" +dependencies = [ + "num-traits", +] + +[[package]] +name = "p256" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9863ad85fa8f4460f9c48cb909d38a0d689dba1f6f6988a5e3e0d31071bcd4b" +dependencies = [ + "ecdsa 0.16.9 (registry+https://github.com/rust-lang/crates.io-index)", + "elliptic-curve", + "primeorder 0.13.6 (registry+https://github.com/rust-lang/crates.io-index)", + "sha2", +] + +[[package]] +name = "p256" +version = "0.13.2" +source = "git+https://github.com/scroll-tech/elliptic-curves?branch=ceno%2Fk256-13.4#17adc274db2fb10510449026ec785ae4fc234540" +dependencies = [ + "ceno_crypto_primitives", + "ceno_syscall", + "ecdsa 0.16.9 (registry+https://github.com/rust-lang/crates.io-index)", + "elliptic-curve", + "primeorder 0.13.6 (git+https://github.com/scroll-tech/elliptic-curves?branch=ceno%2Fk256-13.4)", + "sha2", +] + +[[package]] +name = "p3" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" +dependencies = [ + "p3-air", + "p3-baby-bear", + "p3-challenger", + "p3-commit", + "p3-dft", + "p3-field", + "p3-fri", + "p3-goldilocks", + "p3-matrix", + "p3-maybe-rayon", + "p3-mds", + "p3-merkle-tree", + "p3-monty-31", + "p3-poseidon", + "p3-poseidon2", + "p3-poseidon2-air", + "p3-symmetric", + "p3-util", +] + +[[package]] +name = "p3-air" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60414dc4fe4b8676bd4b6136b309185e6b3c006eb5564ef4cf5dfae6d9d47f32" +dependencies = [ + "p3-field", + "p3-matrix", +] + +[[package]] +name = "p3-baby-bear" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f2fecd03416a20949dc7cd4b481c37d744c4d398467f94213c65279a0f00048" +dependencies = [ + "p3-challenger", + "p3-field", + "p3-mds", + "p3-monty-31", + "p3-poseidon2", + "p3-symmetric", + "rand 0.9.2", +] + +[[package]] +name = "p3-bn254" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c408855a82df5911b8e877acc2fd48f22534b80a4984783fa2a292acdf52e6a8" +dependencies = [ + "num-bigint", + "p3-field", + "p3-poseidon2", + "p3-symmetric", + "p3-util", + "paste", + "rand 0.9.2", + "serde", +] + +[[package]] +name = "p3-challenger" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8a66da8af6115b9e2df4363cd55efebf2c6d30de0af3e99dac56dd7b77aff24" +dependencies = [ + "p3-field", + "p3-maybe-rayon", + "p3-monty-31", + "p3-symmetric", + "p3-util", + "tracing", +] + +[[package]] +name = "p3-commit" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95104feb4b9895733f92204ec70ba8944dbab39c39b235c0a00adf1456149619" +dependencies = [ + "itertools 0.14.0", + "p3-challenger", + "p3-dft", + "p3-field", + "p3-matrix", + "p3-util", + "serde", +] + +[[package]] +name = "p3-dft" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81b2f57569293b9964b1bae68d64e796bfbf3c271718268beb53a0fb761a5819" +dependencies = [ + "itertools 0.14.0", + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-util", + "spin 0.10.0", + "tracing", +] + +[[package]] +name = "p3-field" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56aae7630ff6df83fb7421d5bd97df27620e5f0e29422b7e8f6a294d44cce297" +dependencies = [ + "itertools 0.14.0", + "num-bigint", + "p3-maybe-rayon", + "p3-util", + "paste", + "rand 0.9.2", + "serde", + "tracing", +] + +[[package]] +name = "p3-fri" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0e9a7053c439444f5c4be80ecc08b255a046bc5d23762abe8a4460ae0fca583" +dependencies = [ + "itertools 0.14.0", + "p3-challenger", + "p3-commit", + "p3-dft", + "p3-field", + "p3-interpolation", + "p3-matrix", + "p3-maybe-rayon", + "p3-util", + "rand 0.9.2", + "serde", + "thiserror 2.0.18", + "tracing", +] + +[[package]] +name = "p3-goldilocks" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85324dc45db4196ce0083971393124f5ed03741507f9165d5c923c97890b4838" +dependencies = [ + "num-bigint", + "p3-challenger", + "p3-dft", + "p3-field", + "p3-mds", + "p3-poseidon2", + "p3-symmetric", + "p3-util", + "paste", + "rand 0.9.2", + "serde", +] + +[[package]] +name = "p3-interpolation" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b0bb6a709b26cead74e7c605f4e51e793642870e54a7c280a05cd66b7914866" +dependencies = [ + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-util", +] + +[[package]] +name = "p3-matrix" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d916550e4261126457d4f139fc3156fc796b1cf2f2687bf1c9b269b1efa8ad42" +dependencies = [ + "itertools 0.14.0", + "p3-field", + "p3-maybe-rayon", + "p3-util", + "rand 0.9.2", + "serde", + "tracing", + "transpose", +] + +[[package]] +name = "p3-maybe-rayon" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0db6a290f867061aed54593d48f0dfd7ff2d0f706a603d03209fd0eac79518f3" +dependencies = [ + "rayon", +] + +[[package]] +name = "p3-mds" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "745a478473a5f3699f76b284378651eaa9d74e74f820b34ea563a4a72ab8a4a6" +dependencies = [ + "p3-dft", + "p3-field", + "p3-symmetric", + "p3-util", + "rand 0.9.2", +] + +[[package]] +name = "p3-merkle-tree" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "615f09d1c83ca2ad0dd1f8fb4e496445f9c24a224bac81b98849973f444ee86c" +dependencies = [ + "itertools 0.14.0", + "p3-commit", + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-symmetric", + "p3-util", + "rand 0.9.2", + "serde", + "thiserror 2.0.18", + "tracing", +] + +[[package]] +name = "p3-monty-31" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f124f989bc5697728a9e71d2094eda673c45a536c6a8b8ec87b7f3660393aad0" +dependencies = [ + "itertools 0.14.0", + "num-bigint", + "p3-dft", + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-mds", + "p3-poseidon2", + "p3-symmetric", + "p3-util", + "paste", + "rand 0.9.2", + "serde", + "spin 0.10.0", + "tracing", + "transpose", +] + +[[package]] +name = "p3-poseidon" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dc0930e45272609b239052346e2abe8965adaf22b8237eddb679d659af53f28" +dependencies = [ + "p3-field", + "p3-mds", + "p3-symmetric", + "rand 0.9.2", +] + +[[package]] +name = "p3-poseidon2" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b0c96988fd809e7a3086d8d683ddb93c965f8bb08b37c82e3617d12347bf77f" +dependencies = [ + "p3-field", + "p3-mds", + "p3-symmetric", + "p3-util", + "rand 0.9.2", +] + +[[package]] +name = "p3-poseidon2-air" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a0c44c47992126b5eb4f5a33444d6059b883c1ea520f1d34590d46338314178" +dependencies = [ + "p3-air", + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-poseidon2", + "rand 0.9.2", + "tracing", +] + +[[package]] +name = "p3-symmetric" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dabf1c93a83305b291118dec6632357da69f3137d33fc1791225e38fcb615836" +dependencies = [ + "itertools 0.14.0", + "p3-field", + "serde", +] + +[[package]] +name = "p3-util" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92074eab13c8a30d23ad7bcf99b82787a04c843133a0cba39ca1cf39d434492" +dependencies = [ + "serde", +] + +[[package]] +name = "pairing" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135590d8bdba2b31346f9cd1fb2a912329f5135e832a4f422942eb6ead8b6b3b" +dependencies = [ + "group 0.12.1", +] + +[[package]] +name = "parity-scale-codec" +version = "3.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "799781ae679d79a948e13d4824a40970bfa500058d245760dd857301059810fa" +dependencies = [ + "arrayvec", + "byte-slice-cast", + "const_format", + "impl-trait-for-tuples", + "parity-scale-codec-derive", + "rustversion", +] + +[[package]] +name = "parity-scale-codec-derive" +version = "3.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34b4653168b563151153c9e4c08ebed57fb8262bebfa79711552fa983c623e7a" +dependencies = [ + "proc-macro-crate 3.5.0", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "parse-size" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "487f2ccd1e17ce8c1bfab3a65c89525af41cfad4c8659021a1e9a2aacd73b89b" + +[[package]] +name = "pasta_curves" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cc65faf8e7313b4b1fbaa9f7ca917a0eed499a9663be71477f87993604341d8" +dependencies = [ + "blake2b_simd", + "ff 0.12.1", + "group 0.12.1", + "lazy_static", + "rand 0.8.5", + "static_assertions", + "subtle", +] + +[[package]] +name = "pasta_curves" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e57598f73cc7e1b2ac63c79c517b31a0877cd7c402cdcaa311b5208de7a095" +dependencies = [ + "blake2b_simd", + "ff 0.13.1", + "group 0.13.0", + "lazy_static", + "rand 0.8.5", + "static_assertions", + "subtle", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "portable-atomic-util" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a9db96d7fa8782dd8c15ce32ffe8680bbd1e978a43bf51a34d39483540495f5" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "poseidon" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" +dependencies = [ + "ff_ext", + "p3", + "serde", +] + +[[package]] +name = "postcard" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6764c3b5dd454e283a30e6dfe78e9b31096d9e32036b5d1eaac7a6119ccb9a24" +dependencies = [ + "cobs", + "embedded-io 0.4.0", + "embedded-io 0.6.1", + "serde", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn 2.0.117", +] + +[[package]] +name = "prettytable-rs" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eea25e07510aa6ab6547308ebe3c036016d162b8da920dbb079e3ba8acf3d95a" +dependencies = [ + "csv", + "encode_unicode", + "is-terminal", + "lazy_static", + "term", + "unicode-width", +] + +[[package]] +name = "primeorder" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "353e1ca18966c16d9deb1c69278edbc5f194139612772bd9537af60ac231e1e6" +dependencies = [ + "elliptic-curve", +] + +[[package]] +name = "primeorder" +version = "0.13.6" +source = "git+https://github.com/scroll-tech/elliptic-curves?branch=ceno%2Fk256-13.4#17adc274db2fb10510449026ec785ae4fc234540" +dependencies = [ + "elliptic-curve", +] + +[[package]] +name = "proc-macro-crate" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4c021e1093a56626774e81216a4ce732a735e5bad4868a03f3ed65ca0c3919" +dependencies = [ + "once_cell", + "toml_edit 0.19.15", +] + +[[package]] +name = "proc-macro-crate" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" +dependencies = [ + "toml_edit 0.25.4+spec-1.1.0", +] + +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quanta" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + +[[package]] +name = "radix_trie" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" +dependencies = [ + "endian-type", + "nibble_vec", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", + "serde", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.17", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "redox_users" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" +dependencies = [ + "getrandom 0.2.17", + "libredox", + "thiserror 1.0.69", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "repr_offset" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb1070755bd29dffc19d0971cab794e607839ba2ef4b69a9e6fbc8733c1b72ea" +dependencies = [ + "tstr", +] + +[[package]] +name = "rfc6979" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" +dependencies = [ + "hmac", + "subtle", +] + +[[package]] +name = "rfc6979" +version = "0.4.0" +source = "git+https://github.com/sp1-patches/signatures.git?tag=patch-16.9-sp1-4.1.0#1880299a48fe7ef249edaa616fd411239fb5daf1" +dependencies = [ + "hmac", + "subtle", +] + +[[package]] +name = "rrs-succinct" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3372685893a9f67d18e98e792d690017287fd17379a83d798d958e517d380fa9" +dependencies = [ + "downcast-rs", + "num_enum", + "paste", +] + +[[package]] +name = "rug" +version = "1.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de190ec858987c79cad4da30e19e546139b3339331282832af004d0ea7829639" +dependencies = [ + "az", + "gmp-mpfr-sys", + "libc", + "libm", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b50b8869d9fc858ce7266cce0194bd74df58b9d0e3f6df3a9fc8eb470d95c09d" + +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + +[[package]] +name = "rustc-hex" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e75f6a532d0fd9f7f13144f392b6ad56a32696bfcd9c78f797f16bbb6f072d6" + +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "scale-info" +version = "2.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "346a3b32eba2640d17a9cb5927056b08f3de90f65b72fe09402c2ad07d684d0b" +dependencies = [ + "cfg-if", + "derive_more", + "parity-scale-codec", + "scale-info-derive", +] + +[[package]] +name = "scale-info-derive" +version = "2.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6630024bf739e2179b91fb424b28898baf819414262c5d376677dbff1fe7ebf" +dependencies = [ + "proc-macro-crate 3.5.0", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "sec1" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" +dependencies = [ + "base16ct", + "der", + "generic-array 0.14.9", + "pkcs8", + "subtle", + "zeroize", +] + +[[package]] +name = "secp" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85ed54b1141d8cec428d8a4abf01282755ba4e4c8a621dd23fa2e0ed761814c2" +dependencies = [ + "base16ct", + "once_cell", + "secp256k1", + "subtle", +] + +[[package]] +name = "secp256k1" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b50c5943d326858130af85e049f2661ba3c78b26589b8ab98e65e80ae44a1252" +dependencies = [ + "bitcoin_hashes", + "rand 0.8.5", + "secp256k1-sys", +] + +[[package]] +name = "secp256k1-sys" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4387882333d3aa8cb20530a17c69a3752e97837832f34f6dccc760e715001d9" +dependencies = [ + "cc", +] + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde-big-array" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f" +dependencies = [ + "serde", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core 0.6.4", +] + +[[package]] +name = "sketches-ddsketch" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85636c14b73d81f541e525f585c0a2109e6744e1565b5c1668e31c70c10ed65c" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +dependencies = [ + "serde", +] + +[[package]] +name = "snowbridge-amcl" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460a9ed63cdf03c1b9847e8a12a5f5ba19c4efd5869e4a737e05be25d7c427e5" +dependencies = [ + "parity-scale-codec", + "scale-info", +] + +[[package]] +name = "sp1-curves" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" +dependencies = [ + "cfg-if", + "dashu", + "elliptic-curve", + "ff_ext", + "generic-array 1.3.5", + "itertools 0.13.0", + "k256 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)", + "multilinear_extensions", + "num", + "p256 0.13.2 (registry+https://github.com/rust-lang/crates.io-index)", + "p3", + "rug", + "serde", + "snowbridge-amcl", + "typenum", +] + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + +[[package]] +name = "spin" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" +dependencies = [ + "lock_api", +] + +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "strum" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.117", +] + +[[package]] +name = "substrate-bn" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b5bbfa79abbae15dd642ea8176a21a635ff3c00059961d1ea27ad04e5b441c" +dependencies = [ + "byteorder", + "crunchy", + "lazy_static", + "rand 0.8.5", + "rustc-hex", +] + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "sumcheck" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" +dependencies = [ + "either", + "ff_ext", + "itertools 0.13.0", + "multilinear_extensions", + "p3", + "rayon", + "serde", + "sumcheck_macro", + "thiserror 1.0.69", + "tracing", + "transcript", +] + +[[package]] +name = "sumcheck_macro" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" +dependencies = [ + "itertools 0.13.0", + "p3", + "proc-macro2", + "quote", + "rand 0.8.5", + "syn 2.0.117", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + +[[package]] +name = "tempfile" +version = "3.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0" +dependencies = [ + "fastrand", + "getrandom 0.4.2", + "once_cell", + "rustix", + "windows-sys", +] + +[[package]] +name = "term" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c59df8ac95d96ff9bede18eb7300b0fda5e5d8d90960e76f8e14ae765eedbf1f" +dependencies = [ + "dirs-next", + "rustversion", + "winapi", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + +[[package]] +name = "toml_datetime" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" + +[[package]] +name = "toml_datetime" +version = "1.0.0+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_edit" +version = "0.19.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" +dependencies = [ + "indexmap", + "toml_datetime 0.6.11", + "winnow 0.5.40", +] + +[[package]] +name = "toml_edit" +version = "0.25.4+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7193cbd0ce53dc966037f54351dbbcf0d5a642c7f0038c382ef9e677ce8c13f2" +dependencies = [ + "indexmap", + "toml_datetime 1.0.0+spec-1.1.0", + "toml_parser", + "winnow 0.7.15", +] + +[[package]] +name = "toml_parser" +version = "1.0.9+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" +dependencies = [ + "winnow 0.7.15", +] + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-forest" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee40835db14ddd1e3ba414292272eddde9dad04d3d4b65509656414d1c42592f" +dependencies = [ + "ansi_term", + "smallvec", + "thiserror 1.0.69", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "transcript" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" +dependencies = [ + "ff_ext", + "itertools 0.13.0", + "p3", + "poseidon", +] + +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + +[[package]] +name = "tstr" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f8e0294f14baae476d0dd0a2d780b2e24d66e349a9de876f5126777a37bdba7" +dependencies = [ + "tstr_proc_macros", +] + +[[package]] +name = "tstr_proc_macros" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e78122066b0cb818b8afd08f7ed22f7fdbc3e90815035726f0840d0d26c0747a" + +[[package]] +name = "typed-arena" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6af6ae20167a9ece4bcb41af5b80f8a1f1df981f6391189ce00fd257af04126a" + +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + +[[package]] +name = "typewit" +version = "1.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8c1ae7cc0fdb8b842d65d127cb981574b0d2b249b74d1c7a2986863dc134f71" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn 2.0.117", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "web-sys" +version = "0.3.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "whir" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" +dependencies = [ + "bincode", + "clap", + "derive_more", + "ff_ext", + "itertools 0.14.0", + "multilinear_extensions", + "p3", + "rand 0.8.5", + "rand_chacha 0.3.1", + "rayon", + "serde", + "sumcheck", + "tracing", + "transcript", + "transpose", + "witness", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "winnow" +version = "0.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876" +dependencies = [ + "memchr", +] + +[[package]] +name = "winnow" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" +dependencies = [ + "memchr", +] + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn 2.0.117", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn 2.0.117", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "witness" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=feat%2Fbump-p3#9c6e6b63024811a1b51f7e685bacda2567b539db" +dependencies = [ + "ff_ext", + "multilinear_extensions", + "p3", + "rand 0.8.5", + "rayon", + "tracing", +] + +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + +[[package]] +name = "zerocopy" +version = "0.8.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a789c6e490b576db9f7e6b6d661bcc9799f7c0ac8352f56ea20193b2681532e5" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f65c489a7071a749c849713807783f70672b28094011623e200cb86dcb835953" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "zkhash" +version = "0.2.0" +source = "git+https://github.com/HorizenLabs/poseidon2.git?rev=bb476b9#bb476b9ca38198cf5092487283c8b8c5d4317c4e" +dependencies = [ + "ark-ff", + "ark-std", + "bitvec", + "blake2", + "bls12_381", + "byteorder", + "cfg-if", + "group 0.12.1", + "group 0.13.0", + "halo2", + "hex", + "jubjub", + "lazy_static", + "pasta_curves 0.5.1", + "rand 0.8.5", + "serde", + "sha2", + "sha3", + "subtle", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/ceno_recursion_v2/Cargo.toml b/ceno_recursion_v2/Cargo.toml new file mode 100644 index 000000000..def566ad7 --- /dev/null +++ b/ceno_recursion_v2/Cargo.toml @@ -0,0 +1,71 @@ +[package] +categories = ["cryptography", "zk", "blockchain", "ceno"] +description = "Next-generation recursion circuits for Ceno built on OpenVM v2" +edition = "2024" +keywords = ["cryptography", "zk", "blockchain", "ceno"] +license = "MIT OR Apache-2.0" +name = "ceno_recursion_v2" +readme = "../README.md" +repository = "https://github.com/scroll-tech/ceno" +version = "0.1.0" + +[dependencies] +bincode = "1" +ceno-examples = { path = "../examples-builder" } +ceno_emul = { path = "../ceno_emul" } +ceno_host = { path = "../ceno_host" } +ceno_zkvm = { path = "../ceno_zkvm" } +clap = { version = "4.5", features = ["derive"] } +continuations-v2 = { git = "https://github.com/openvm-org/openvm.git", package = "openvm-continuations", branch = "develop-v2.0.0-beta", default-features = false } +derive-new = "0.6.0" +eyre = "0.6" +ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", branch = "feat/bump-p3" } +gkr_iop = { path = "../gkr_iop" } +itertools = "0.13" +mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", branch = "feat/bump-p3" } +multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", branch = "feat/bump-p3" } +openvm = { git = "https://github.com/openvm-org/openvm.git", branch = "develop-v2.0.0-beta", default-features = false } +openvm-circuit = { git = "https://github.com/openvm-org/openvm.git", branch = "develop-v2.0.0-beta", default-features = false } +openvm-circuit-primitives = { git = "https://github.com/openvm-org/openvm.git", branch = "develop-v2.0.0-beta", default-features = false } +openvm-poseidon2-air = { git = "https://github.com/openvm-org/openvm.git", branch = "develop-v2.0.0-beta", package = "openvm-poseidon2-air", default-features = false } +openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", branch = "develop-v2", default-features = false } +openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", branch = "develop-v2" } +openvm-cuda-backend = { git = "https://github.com/openvm-org/stark-backend.git", branch = "develop-v2", optional = true } +openvm-cuda-common = { git = "https://github.com/openvm-org/stark-backend.git", branch = "develop-v2", optional = true } +openvm-cpu-backend = { git = "https://github.com/openvm-org/stark-backend.git", branch = "develop-v2", default-features = false } +p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", branch = "feat/bump-p3" } +p3-air = { version = "=0.4.1", default-features = false } +p3-field = { version = "=0.4.1", default-features = false } +p3-matrix = { version = "=0.4.1", default-features = false } +p3-maybe-rayon = { version = "=0.4.1", default-features = false } +p3-symmetric = { version = "=0.4.1", default-features = false } +parse-size = "1.1" +rand = "0.8" +recursion-circuit = { git = "https://github.com/openvm-org/openvm.git", package = "openvm-recursion-circuit", branch = "develop-v2.0.0-beta", default-features = false } +serde = { version = "1.0", features = ["derive", "rc"] } +serde_json = "1.0" +stark-recursion-circuit-derive = { git = "https://github.com/openvm-org/openvm.git", package = "openvm-recursion-circuit-derive", branch = "develop-v2.0.0-beta" } +strum = "0.26" +strum_macros = "0.26" +sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", branch = "feat/bump-p3" } +tracing = { version = "0.1", features = ["attributes"] } +tracing-forest = { version = "0.1.6" } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", branch = "feat/bump-p3" } +verify-stark = { git = "https://github.com/openvm-org/openvm.git", package = "openvm-verify-stark-host", branch = "develop-v2.0.0-beta", default-features = false } +whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", branch = "feat/bump-p3" } +witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", branch = "feat/bump-p3" } + +[features] +cuda = [ + "dep:openvm-cuda-backend", + "dep:openvm-cuda-common", +] +default = [] + +[patch."https://github.com/openvm-org/stark-backend.git"] +openvm-stark-backend = { git = "https://github.com/hero78119/stark-backend.git", branch = "develop-v2", default-features = false } +openvm-stark-sdk = { git = "https://github.com/hero78119/stark-backend.git", branch = "develop-v2" } +openvm-cuda-backend = { git = "https://github.com/hero78119/stark-backend.git", branch = "develop-v2", optional = true } +openvm-cuda-common = { git = "https://github.com/hero78119/stark-backend.git", branch = "develop-v2", optional = true } +openvm-cpu-backend = { git = "https://github.com/hero78119/stark-backend.git", branch = "develop-v2", default-features = false } \ No newline at end of file diff --git a/ceno_recursion_v2/clippy.toml b/ceno_recursion_v2/clippy.toml new file mode 100644 index 000000000..41690a1eb --- /dev/null +++ b/ceno_recursion_v2/clippy.toml @@ -0,0 +1,29 @@ +# TODO(Matthias): review and see which exception we can remove over time. +# Eg removing syn is blocked by ark-ff-asm cutting a new release +# (https://github.com/arkworks-rs/algebra/issues/813) amongst other things. +allowed-duplicate-crates = [ + "hashbrown", + "itertools", + "regex-automata", + "regex-syntax", + "syn", + "thiserror", + "thiserror-impl", + "windows-sys", + "tracing-subscriber", + "wasi", + "getrandom", + "zerocopy", + "zerocopy-derive", + "proc-macro-crate", + "toml_edit", + "toml_datetime", + "winnow", + "generic-array", + "num-bigint", + "rfc6979", + "k256", + "p256", + "primeorder", + "ecdsa", +] diff --git a/ceno_recursion_v2/docs/main_spec.md b/ceno_recursion_v2/docs/main_spec.md new file mode 100644 index 000000000..242a195db --- /dev/null +++ b/ceno_recursion_v2/docs/main_spec.md @@ -0,0 +1,96 @@ +## Main Module (`src/main`) + +The Main module bridges the reduced GKR claim into a “global” sumcheck AIR. It receives the +`input_layer_claim` emitted by `TowerInputAir`, replays a one-layer sumcheck (currently a pass-through +check), and hands the resulting claim back to downstream modules. + +### MainAir (`src/main/air.rs`) + +| Column | Shape | Description | +|----------------|----------|-------------------------------------------------------------------------| +| `is_enabled` | scalar | Row selector. Disabled rows carry padding. | +| `proof_idx` | scalar | Outer loop counter shared with GKR inputs. | +| `idx` | scalar | Module index within the proof (matches `TowerInputAir`). | +| `is_first_idx` | scalar | Flags the first row for each `(proof_idx, idx)` pair. | +| `is_first` | scalar | Always `1` on real rows (there is a single row per `(proof_idx, idx)`). | +| `tidx` | scalar | Transcript cursor at which the Main claim applies. | +| `claim_in` | `[D_EF]` | The folded claim received from `TowerInputAir`. | +| `claim_out` | `[D_EF]` | The claim returned by `MainSumcheckAir` (expected to match `claim_in`). | + +#### Constraints + +- `NestedForLoopSubAir<2>` enforces boolean enablement, padding-after-padding, and lexicographic + ordering over `(proof_idx, idx)`, using `is_first_idx` / `is_first` to mark loop resets. +- On `is_first` rows, the AIR receives `MainMessage` and constrains the local columns to match the + bus payload (`idx`, `tidx`, and `claim_in`). +- Every enabled row sends `MainSumcheckInputMessage { idx, tidx, claim }` to the sumcheck AIR. +- The AIR immediately receives `MainSumcheckOutputMessage` and constrains `claim_out` to equal the + returned payload. This keeps transcript state explicit even though the current sumcheck logic is a + no-op. +- A simple consistency check enforces `claim_in == claim_out`, ensuring the pass-through sumcheck + cannot mutate the claim silently. + +#### Bus Interactions + +- **MainBus.receive** (from `TowerInputAir`): `(idx, tidx, claim_in)` on `is_first` rows. +- **MainSumcheckInputBus.send**: forwards `(idx, tidx, claim_in)` on every enabled row. +- **MainSumcheckOutputBus.receive**: ingests `(idx, claim_out)` (one message per `(proof_idx, idx)` + because the sumcheck only emits on its `is_last_round`). +- **TranscriptBus**: currently unused (the transcript positions are enforced implicitly through the + provided `tidx`), but columns are wired so future revisions can observe claims if needed. + +### MainSumcheckAir (`src/main/sumcheck`) + +| Column | Shape | Description | +|------------------|----------|---------------------------------------------------------------------| +| `is_enabled` | scalar | Row selector. | +| `proof_idx` | scalar | Matches the producer AIR. | +| `idx` | scalar | Module index within the proof. | +| `is_first_idx` | scalar | Flags the first row for each `(proof_idx, idx)` pair. | +| `is_first_round` | scalar | Indicates the first round for the current `(proof_idx, idx)` block. | +| `is_last_round` | scalar | Marks the final round; used to gate the output message. | +| `is_dummy` | scalar | Allows a placeholder row when `num_rounds = 0`. | +| `round` | scalar | Round counter (starts at 0 and increments each sub-round). | +| `tidx` | scalar | Transcript cursor for the current round (`+4·D_EF` per transition). | +| `ev1/ev2/ev3` | `[D_EF]` | Sumcheck polynomial evaluations at 1/2/3. | +| `claim_in` | `[D_EF]` | Claim entering the round. | +| `claim_out` | `[D_EF]` | Claim produced by cubic interpolation (fed into the next round). | +| `prev_challenge` | `[D_EF]` | The previous transcript challenge (ξ) used in the eq term. | +| `challenge` | `[D_EF]` | The round’s sampled challenge (rᵢ). | +| `eq_in` | `[D_EF]` | Running eq evaluation prior to this round. | +| `eq_out` | `[D_EF]` | Updated eq evaluation after applying the round challenge. | + +#### Constraints + +- `NestedForLoopSubAir<2>` runs over `(proof_idx, idx)` while treating `is_first_round` as the + innermost loop reset. It enforces boolean flags, padding-after-padding, and lexicographic + ordering. +- `round` is zeroed on `is_first_round` rows and increments by 1 on transitions within the same + `(proof_idx, idx)`. The transcript cursor `tidx` increases by `4·D_EF` per round. +- `is_last_round` is constrained to equal `NestedForLoopSubAir::local_is_last`, so it flips to 1 on + the final enabled row for each `(proof_idx, idx)` pair. +- On `is_first_round`, the AIR receives `MainSumcheckInputMessage { idx, tidx, claim }` and seeds the + local columns. `eq_in` is set to one, and `claim_in` is forced to the received payload. +- Each round computes `ev0 = claim_in - ev1`, feeds `ev0/ev1/ev2/ev3` through the optimized cubic + interpolator, and constrains `claim_out`. `claim_out` is copied to the next row’s `claim_in` for + transitions. +- Eq values update via `eq_out = eq_in * (ξ·rᵢ + (1-ξ)(1-rᵢ))`, with propagation to the next row on + transitions. Dummy rows (for zero-round proofs) carry `is_dummy = 1`, which suppresses bus traffic. +- Only rows with `is_last_round = 1` may send the result back; all other rows keep the claim inside + the module. + +#### Bus Interactions + +- **MainSumcheckInputBus.receive**: `(idx, tidx, claim_in)` on `is_first_round` rows (and only when + `is_dummy = 0`). +- **MainSumcheckOutputBus.send**: `(idx, claim_out)` gated by `is_last_round` and `!is_dummy`, so the + claim returns to `MainAir` exactly once per `(proof_idx, idx)`. + +--- + +### Sumcheck Notes + +The Main sumcheck now mirrors the GKR layer sumcheck structure: it emits one row per round, tracks +`round`/`eq`/challenge evolution, and only releases the folded claim on `is_last_round`. The current +trace generator still fills the evaluation/challenge fields with placeholder zeros until real tower +data is connected, so the AIR behaves as a pass-through while preserving the full protocol shape. diff --git a/ceno_recursion_v2/docs/proof_shape_spec.md b/ceno_recursion_v2/docs/proof_shape_spec.md new file mode 100644 index 000000000..49e2306ab --- /dev/null +++ b/ceno_recursion_v2/docs/proof_shape_spec.md @@ -0,0 +1,135 @@ +# Proof Shape Module Spec + +This spec summarizes the components under `src/proof_shape`. The module is forked from upstream recursion code so we can +adapt it to Ceno’s ZKVM while keeping behavior aligned with OpenVM. + +## ProofShapeModule (`src/proof_shape/mod.rs`) + +### Purpose + +- Verify child-proof trace metadata (heights, cached commits, public values) against the child verifying key. +- Route transcript/bus traffic related to those checks (power/range lookups, permutation commitments, GKR cross-module + messages). +- Produce CPU (and optional CUDA) traces for the ProofShape and PublicValues AIRs, plus aggregate preflight info used + later in recursion. + +### Key Fields + +- `per_air: Vec`: built from `RecursionVk.circuit_vks` in circuit-index order; currently stores + `is_required = false`, `num_public_values = instance_openings().len()`, placeholder widths (`main_width = 0`, + `cached_widths = vec![0; max_cached]`), and read/write/log lookup counts (`num_read_count`, `num_write_count`, + `num_logup_count`) used by the GKR checks. +- Cached/commit parameters are currently fixed in `ProofShapeModule::new`: `min_cached_idx = 0`, `max_cached = 2`, + `commit_mult = 100`. +- `idx_encoder`: enforces permutation ordering between `idx` (VK order) and `sorted_idx` (runtime order). +- Bus handles: power/range checker, proof-shape permutation, starting tidx, number of public values, GKR module, + air-shape, expression-claim, fraction-folder, hyperdim lookup, lifted heights, commitments, transcript, n_lift, cached + commit. + +### Tracegen Flow + +1. Build `ProofShapeChip::<4,8>` (CPU) / GPU equivalent, parameterized by cached-commit bounds and range/power checker + handles. +2. Gather context (`StandardTracegenCtx`) of `(vk, proofs, preflights)` and produce row-major traces for both ProofShape + and PublicValues airs. +3. Preflight builder (`Preflight::populate_proof_shape`) collects sorted trace metadata, starting tidx values, cached + commits, and transcript positions for public values; these feed back into recursion aggregates. + +### Module Interactions + +- Sends/receives bus messages enumerated in the AIR sections below. +- `ProofShapeModule::new` wires buses via `BusIndexManager`; `commit_child_vk` commits the child VK once per recursion + instance (currently unimplemented while ZKVM wiring is in progress). + +## ProofShapeAir (`src/proof_shape/proof_shape/air.rs`) + +### Column Groups + +| Group | Columns | Notes | +|-----------------------------|------------------------------------------------------------------------|-----------------------------------------------------------------------------------------| +| Row selectors | `proof_idx`, `is_valid`, `is_first`, `is_last`, `is_present` | Manage per-proof iteration and summary row detection. | +| Ordering & metadata | `idx`, `sorted_idx`, `log_height`, `height`, `need_rot`, `num_present` | Track VK ordering vs runtime order, enforce height monotonicity, rotation requirements. | +| Transcript anchors | `starting_tidx`, `starting_cidx` | Anchor where per-air transcript reads start; exported via buses. | +| Height decomposition | `height_limbs[NUM_LIMBS]` | Enforce limb decomposition/range checks for `height`. | +| Hyperdim summary | `n_max`, `is_n_max_greater`, `num_air_id_lookups`, `num_columns` | Track max `log_height` across present AIRs and auxiliary per-air lookup counts. | +| Cached commit bookkeeping | `cached_idx_flags`, `cached_idx_value`, `cached_commits` | Track how many cached columns exist and their transcript tidx positions. | +| Bookkeeping for permutation | Encoder-specific subcolumns (idx flags) verifying sorted order. + +### Constraints Overview + +- **Looping**: `NestedForLoopSubAir<1>` runs per proof, iterating through `idx` values and ensuring `is_valid`+`is_last` + drive transitions. +- **Permutation**: `ProofShapePermutationBus` enforces that runtime order (`sorted_idx`) is a permutation of VK order ( + `idx`). `idx_encoder` ensures only one row per column and enforces boolean flags. +- **Trace heights**: Range checker ensures `log_height` is monotonically non-increasing; when `is_present = 1`, + `height = 2^{log_height}`. Hyperdim bus uses unsigned `n = log_height` (`n_abs = n`, `n_sign_bit = 0`). +- **Rotation/caching**: Rows with `need_rot = 1` record rotation requirements on `CommitmentsBus` and `CachedCommitBus`. + `starting_cidx`/`starting_tidx` communicate the first column/ transcript offset for each AIR. +- **Expression lookups**: `ExpressionClaimNMaxBus`, `FractionFolderInputBus`, and `NLiftBus` mirror the computed + `n_logup`, `n_max`, and `n_lift = log_height` metadata so batch constraint and fraction-folder modules can cross-check + expectations. `AirShapeBus` exposes additional per-AIR properties (`NumRead`, `NumWrite`, `NumLk`) so GKR AIRs can + enforce that their runtime layer counts match the verifying-key declarations. `NumInteractions` is currently emitted + as + `0` in this AIR. + +### Bus Interactions + +- Sends on: `ProofShapePermutationBus`, `HyperdimBus`, `LiftedHeightsBus`, `CommitmentsBus`, `ExpressionClaimNMaxBus`, + `FractionFolderInputBus`, `NLiftBus`, `StartingTidxBus`, `NumPublicValuesBus`, `CachedCommitBus` (if continuations + enabled). +- Receives from: `ProofShapePermutationBus` (VK order), `TowerModuleBus` (per-proof configuration), `AirShapeBus` + (per-air property lookups, including the new `NumRead` / `NumWrite` / `NumLk` counters that downstream GKR AIRs + enforce), `PowerCheckerBus` (for PoW enforcement), `RangeCheckerBus` (monotonic log heights), + `TranscriptBus` (sample/observe tidx-aligned data), `CachedCommitBus` (continuations), `CommitmentsBus` (when reading + transcript commitments). + +### Summary Row Logic + +On the row with `is_last = 1`, additional checks happen: + +- Emit final `n_logup/n_max` via `ExpressionClaimNMaxBus` and `NLiftBus`. +- Update `ProofShapePreflight` fields in the transcript (tracked via tidx) so future recursion layers know where + ProofShape stopped reading. + +## PublicValuesAir (`src/proof_shape/pvs/air.rs`) + +### Columns + +| Column | Description | +|----------------------------------------|-----------------------------------------------------------| +| `is_valid` | Row selector; invalid rows carry padding data. | +| `proof_idx`, `air_idx`, `pv_idx` | Identify the proof/AIR/public-value index triple. | +| `is_first_in_proof`, `is_first_in_air` | Lower-level loop markers used for sequencing constraints. | +| `tidx` | Transcript cursor for the public value read. | +| `value` | The actual public value field element. | + +### Constraints + +- `NestedForLoopSubAir<1>` enforces that enabled rows form contiguous `(proof_idx, air_idx)` segments and increments + `pv_idx` and `tidx` appropriately when staying within the same AIR. +- On `is_first_in_proof`, enforce `pv_idx = 0` and `tidx = starting_tidx` supplied via preflights/ProofShape module. +- For padding rows, force `proof_idx = num_proofs` to match upstream convention. + +### Interactions + +- `PublicValuesBus.send`: publishes each `(air_idx, pv_idx, value)` pair so downstream modules can replay the values; + optionally doubled when `continuations_enabled`. +- `NumPublicValuesBus.receive`: on first-in-air rows, ingests `(air_idx, tidx, num_pvs)` to cross-check counts derived + from ProofShape. +- `TranscriptBus.receive`: ensures the transcript sees the same public values at the given `tidx` (read-only). + +## Trace Generators + +- `ProofShapeChip::` (CPU) and module-level tracegen wiring are currently placeholders in this + branch, producing zero-filled traces with the requested height while AIR wiring stabilizes. +- `ProofShapeChipGpu`/CUDA ABI wrappers remain available behind feature gates. +- `PublicValuesTraceGenerator` walks each proof’s `public_values` arrays, emits `(proof_idx, air_idx, pv_idx)` rows, + pads to powers of two, and records transcript progression. +- CUDA ABI wrappers (`cuda_abi.rs`) expose raw tracegen entry points for GPU builds. + +## Preflight & Metadata + +- `ProofShapePreflight` stores the sorted trace metadata, per-air transcript anchors (`starting_tidx`), cached commit + tidx list, and summary scalars (`n_logup`, `n_max`). +- During transcript preflight (`ProofShapeModule::preflight`), the module replays transcript interactions (observing + cached commitments, sampling challenges) and writes the preflight struct for later modules (e.g., GKR) to consume. diff --git a/ceno_recursion_v2/docs/system_spec.md b/ceno_recursion_v2/docs/system_spec.md new file mode 100644 index 000000000..7c1216cf6 --- /dev/null +++ b/ceno_recursion_v2/docs/system_spec.md @@ -0,0 +1,113 @@ +# System Module Spec + +This document summarizes the aggregation layer under `src/system`. The code mirrors upstream `recursion_circuit::system` +but is forked so we can swap in ZKVM verifying keys (`RecursionVk`). + +## Type Aliases (`src/system/types.rs`) + +- `RecursionField = BabyBearExt4` and `RecursionPcs = Basefold` unify ZKVM field + choices across the crate. +- `RecursionVk = ZKVMVerifyingKey` replaces the upstream `MultiStarkVerifyingKey` so + future traits accept ZKVM proofs/VKs natively. +- `RecursionProof = ZKVMProof` is the canonical proof type exposed to modules; + `convert_proof_from_zkvm` / `convert_vk_from_zkvm` are bridge placeholders and currently `unimplemented!()`. + +## Preflight Records (`src/system/preflight.rs`) + +- Local fork of the upstream `Preflight`/`ProofShapePreflight`/`TowerPreflight` structs so we can evolve transcript + layout + and bookkeeping independently of OpenVM. +- Only the fields that current modules need are mirrored (trace metadata, tidx checkpoints, transcript log, Poseidon + inputs). Additional upstream functionality stays commented out until required. + +## Frame Shim (`src/system/frame.rs`) + +- Local copy of upstream `system::frame` because the originals are `pub(crate)`. +- Provides `StarkVkeyFrame` and `MultiStarkVkeyFrame` structs used by modules (e.g., ProofShape) when exposing + verifying-key metadata to AIRs. +- Each frame strips non-deterministic data (only clones params, cached commitments, interaction counts) to keep AIR + traces stable. + +## POW Checker Constant + +- `POW_CHECKER_HEIGHT: usize = 32` mirrors the upstream constant so modules (ProofShape, batch-constraint) can + type-check their `PowerChecker` gadgets without reaching into a private upstream module. + +## GlobalCtxCpu Override (`src/system/mod.rs`) + +- The upstream `GlobalCtxCpu` binds `TraceGenModule` to `[Proof]`. We shadow it locally with a + struct of the same name that implements `GlobalTraceGenCtx` but sets `type MultiProof = [RecursionProof]`. +- This keeps all CPU tracegen entry points on ZKVM proofs while leaving the trait definitions untouched; CUDA tracegen + continues to use the upstream GPU context. + +## VerifierTraceGen Trait + +Located at `src/system/mod.rs:28`. + +Responsibilities: + +1. `new(child_vk, config) -> Self`: build the recursive subcircuit using the child verifying key and the user-provided + `VerifierConfig`. +2. `commit_child_vk(engine, child_vk)`: write commitments for the child verifying key into the proof transcript. +3. `generate_proving_ctxs(...)`: orchestrate per-module trace generation (transcript, proof shape, main, GKR), run + preflights, and collect `AirProvingContext`s. +4. `generate_proving_ctxs_base(...)`: helper that synthesizes a default `VerifierExternalData` (empty poseidon/range + inputs, no required heights) and calls the trait method. + +The trait is generic over both the prover backend (`PB`) and the Stark protocol configuration (`SC`), enabling CPU/GPU +backends. + +## VerifierSubCircuit (`src/system/mod.rs:90`) + +Fields capture the stateful modules that participate in recursive verification: + +- `bus_inventory: BusInventory`: record of allocated buses ensuring consistent indices. +- `bus_idx_manager: BusIndexManager`: allocator used when wiring modules. +- `transcript: TranscriptModule`: handles Fiat–Shamir transcript operations across the entire recursion proof. +- `proof_shape: ProofShapeModule`: enforces child trace metadata (see `proof_shape_spec.md`). +- `main_module: MainModule`: validates main-module constraints and participates in tracegen orchestration. +- `gkr: TowerModule`: verifies the GKR proof emitted by the child STARK (see `docs/gkr_air_spec.md`). + +### Trait Implementation Status + +- `VerifierTraceGen` is implemented for CPU: `new`, `commit_child_vk`, `generate_proving_ctxs`, and + `generate_proving_ctxs_base` are active. +- `AggregationSubCircuit` methods `airs`, `bus_inventory`, `next_bus_idx`, and `max_num_proofs` are active. +- Remaining placeholders are bridge converters in `src/system/types.rs` (`convert_proof_from_zkvm`, + `convert_vk_from_zkvm`) and selected module internals that are intentionally stubbed while wiring stabilizes. + +## AggregationSubCircuit Impl + +- `airs()` returns a full list of `AirRef`s from transcript, proof-shape, main, GKR, plus power-checker and + exp-bits-len AIRs. +- `bus_inventory()` returns a reference to the internal inventory so orchestration code can inspect bus handles. +- `next_bus_idx()` returns the current allocator cursor via `BusIndexManager`. +- `max_num_proofs()` returns the const generic bound used by aggregation provers. + +## How Modules Fit Together + +1. **TranscriptModule** absorbs all Fiat–Shamir sampling/observations (PoW, alpha, lambda, mu, sumcheck evaluations). + Other modules refer to transcript locations via shared tidx counters. +2. **ProofShapeModule** reads the child proof metadata and emits bus messages for downstream modules ( + height summaries, cached commitments, public values, etc.). +3. **MainModule** enforces core verifier constraints linked to transcript/proof-shape outputs. +4. **TowerModule** consumes those messages plus the child GKR proof to verify the folding of claims (see separate spec). +5. **VerifierSubCircuit** orchestrates these modules: it shares `BusInventory`, ensures every module gets consistent + handles, and sequences trace generation so transcript state advances consistently. + +## Current Semantics Note + +- The older system-level implication check relating child trace-height constraints to a + `sum(num_interactions * lifted_height) < max_interaction_count` bound is currently commented out in + `VerifierSubCircuit::new_with_options`. +- In the current ProofShape AIR, hyperdim is unsigned (`n = log_height`, `n_sign_bit = 0`) and + `AirShapeProperty::NumInteractions` is emitted as `0`. + +## Pending Work / Notes + +- ZKVM proof objects now flow through every CPU tracegen module; `VerifierSubCircuit::commit_child_vk` still needs + end-to-end bridge converters (`convert_proof_from_zkvm` / `convert_vk_from_zkvm`) are still pending. +- Bus wiring currently happens upstream; replicating it locally may require copying additional files if upstream keeps + types `pub(crate)`. +- All module constructors should remain aligned with upstream layout to minimize future rebase conflicts; prefer small + local wrappers over structural rewrites. diff --git a/ceno_recursion_v2/docs/tower_air_spec.md b/ceno_recursion_v2/docs/tower_air_spec.md new file mode 100644 index 000000000..a08dc2bb2 --- /dev/null +++ b/ceno_recursion_v2/docs/tower_air_spec.md @@ -0,0 +1,269 @@ +# GKR AIR Spec + +This document captures the current behavior of each GKR-related AIR that lives in `src/tower`. It mirrors the code so we +can reason about constraints or plan refactors without diving back into Rust. Update the relevant section whenever an +AIR’s columns, constraints, or interactions change. + +## TowerInputAir (`src/tower/input/air.rs`) + +### Columns + +| Field | Shape | Description | +|-----------------------|-----------------|----------------------------------------------------------------------| +| `is_enabled` | scalar | Row selector (0 = padding). | +| `proof_idx` | scalar | Outer proof loop index enforced by nested sub-AIRs. | +| `idx` | scalar | Inner loop index enumerating AIR instances within a proof. | +| `n_layer` | scalar | Number of active GKR layers for the proof. | +| `is_n_layer_zero` | scalar | Flag for `n_layer == 0` (drives “no interaction” branches). | +| `is_n_layer_zero_aux` | `IsZeroAuxCols` | Witness used by `IsZeroSubAir` to enforce the zero test. | +| `tidx` | scalar | Transcript cursor at start of the proof. | +| `r0_claim` | `[D_EF]` | Root numerator commitment supplied to `TowerLayerAir`. | +| `w0_claim` | `[D_EF]` | Root witness commitment supplied to `TowerLayerAir`. | +| `q0_claim` | `[D_EF]` | Root denominator commitment supplied to `TowerLayerAir`. | +| `alpha_logup` | `[D_EF]` | Transcript challenge sampled before passing inputs to GKR layers. | +| `input_layer_claim` | `[D_EF]` | Folded claim returned from `TowerLayerAir`. | +| `layer_output_lambda` | `[D_EF]` | Batching challenge sampled in the final GKR layer (zeros if unused). | +| `layer_output_mu` | `[D_EF]` | Reduction point sampled in the final GKR layer (zeros if unused). | + +### Row Constraints + +- **Enablement / indexing**: A `NestedForLoopSubAir<2>` enforces boolean `is_enabled`, padding-after-padding, and + consecutive `(proof_idx, idx)` pairs for enabled rows. +- **Zero test**: `IsZeroSubAir` checks `n_logup` against `is_n_logup_zero`, unlocking the “no interaction” path. +- **Input layer defaults**: When `n_logup == 0`, the input-layer claim must be `[0, α]` (numerator zero, denominator + equals `alpha_logup`). +- **Transcript math**: Local expressions derive the transcript offsets for alpha sampling, per-layer reductions, and the + xi-sampling window directly from `n_layer`. No auxiliary `n_max` adjustment is needed. + +### Interactions + +- **Internal buses** + - `TowerLayerInputBus.send`: emits `(idx, tidx skip roots, r0/w0/q0_claim)` when interactions exist. + - `TowerLayerOutputBus.receive`: pulls reduced `(idx, layer_idx_end, input_layer_claim, lambda, mu)` back. +- **External buses** + - `TowerModuleBus.receive`: initial module message `(idx, tidx, n_logup)` per enabled row. + - `BatchConstraintModuleBus.send`: forwards the final input-layer claim with the final transcript index. + - `TranscriptBus`: sample `alpha_logup` and observe `q0_claim` only when `has_interactions`. + +### Notes + +- Local booleans `has_interactions` gate all downstream activity, so future refactors must keep those semantics aligned + with the code branches. + +## TowerLayerAir (`src/tower/layer/air.rs`) + +### Columns + +| Field | Shape | Description | +|------------------------------------|---------------|----------------------------------------------------------------------------| +| `is_enabled` | scalar | Row selector. | +| `proof_idx` | scalar | Proof counter shared with input AIR. | +| `idx` | scalar | AIR index within the proof (matches the input AIR). | +| `is_first_air_idx` | scalar | First row flag for each `(proof_idx, idx)` block. | +| `is_first` | scalar | Indicates the first layer row of a proof. | +| `is_dummy` | scalar | Marks padding rows that still satisfy constraints. | +| `layer_idx` | scalar | Layer number, enforced to start at 0 and increment per transition. | +| `tidx` | scalar | Transcript cursor at the start of the layer. | +| `lambda` | `[D_EF]` | Fresh batching challenge sampled for non-root layers. | +| `lambda_prime` | `[D_EF]` | Challenge inherited from the previous layer (root layer pins it to `1`). | +| `mu` | `[D_EF]` | Reduction point sampled from transcript. | +| `sumcheck_claim_in` | `[D_EF]` | Combined claim passed to the layer sumcheck AIR. | +| `read_claim` | `[D_EF]` | Folded product contribution with respect to `lambda`. | +| `read_claim_prime` | `[D_EF]` | Companion folded claim with respect to `lambda_prime` (root = r₀). | +| `write_claim` | `[D_EF]` | Same as above for the write accumulator. | +| `write_claim_prime` | `[D_EF]` | Companion write claim. | +| `logup_claim` | `[D_EF]` | LogUp folded claim w.r.t. `lambda`. | +| `logup_claim_prime` | `[D_EF]` | LogUp folded claim w.r.t. `lambda_prime` (root = q₀). | +| `num_read_count` | scalar | Declared accumulator length for the read prod AIR (must equal `n_logup`). | +| `num_write_count` | scalar | Declared accumulator length for the write prod AIR (must equal `n_logup`). | +| `num_logup_count` | scalar | Declared accumulator length for the logup AIR (must equal `n_logup`). | +| `eq_at_r_prime` | `[D_EF]` | Product of eq evaluations returned from sumcheck. | +| `r0_claim`, `w0_claim`, `q0_claim` | `[D_EF]` each | Root evaluations supplied by `TowerInputAir`. | + +### Row Constraints + +- **Looping**: `NestedForLoopSubAir<2>` continues to enforce boolean enablement, padding-after-padding, and + lexicographic ordering for `(proof_idx, idx)`. `is_first_air_idx` scopes the per-proof input bus handshake to the very + first active row, while `is_first` marks the first layer row. +- **Layer counter**: `layer_idx = 0` on the `is_first` row and increments by one on every transition flagged by the loop + helper. +- **`lambda_prime` propagation**: On the root row, `lambda_prime` must equal `[1, 0, …, 0]`; on each transition the next + row’s `lambda_prime` is constrained to equal the previous row’s sampled `lambda`. This lets downstream AIRs reuse the + same logic for both initialization and continuing layers. +- **Root comparisons**: When `is_first = 1`, the `_prime` claims received from downstream AIRs must match the supplied + `r0_claim`, `w0_claim`, `q0_claim`. This replaces the old local interpolation logic. +- **Inter-layer propagation**: `next.sumcheck_claim_in = read_claim + write_claim + logup_claim` on transitions. The + `_prime` versions feed `sumcheck_claim_out = read_claim_prime + write_claim_prime + logup_claim_prime`, which is what + the sumcheck AIR receives. +- **Count consistency**: `num_read_count`, `num_write_count`, and `num_logup_count` are all constrained to equal + `n_logup`, and each is individually range-checked against ProofShape metadata via `AirShapeBus`. +- **Transcript timing**: Same `tidx` arithmetic as before, but now the post-sumcheck transcript window must also cover + the sample/observe operations that the product/logup AIRs perform themselves. + +### Interactions + +- **Layer buses** + - `layer_input.receive`: only on the first non-dummy row; provides `(idx, tidx, r0/w0/q0_claim)`. + - `layer_output.send`: on the last non-dummy row; reports `(idx, tidx_end, layer_idx_end, folded claim, lambda, mu)` + back to `TowerInputAir` so the caller can record the transcript state for downstream verifiers. +- **Sumcheck buses** + - `sumcheck_input.send`: for non-root layers, dispatches `(layer_idx, is_last_layer, tidx + D_EF, claim)` to the + sumcheck AIR. + - `sumcheck_output.receive`: ingests `(claim_out, eq_at_r_prime)` and re-encodes them into local columns. + - `sumcheck_challenge.send`: posts the `mu` challenge as round 0 for the next layer’s sumcheck. +- **Transcript bus** + - Samples `lambda` (non-root) and `mu`, observes all `p/q` evaluations. +- **Xi randomness bus** + - On the proof’s final layer, sends `mu` as the shared xi challenge consumed by later modules. +- **Prod/logup buses** + - Sends `(idx, layer_idx, tidx, lambda, lambda_prime, mu)` to the read/write prod AIRs every row (dummy rows are + masked out). Receives back both `lambda_claim` and `lambda_prime_claim` along with `num_read_count` / + `num_write_count`. + - Sends the same challenge payload to the logup AIR and receives its dual claims plus `num_logup_count`. + - No separate “init” buses exist anymore; setting `lambda_prime = 1` on the root row instructs the sub-AIRs to act + as + the initialization accumulators whose outputs are compared directly against `r0/w0/q0`. + +### Notes + +- Dummy rows allow reusing the same AIR width even when no layer work is pending; constraints are guarded by + `is_not_dummy` to avoid accidentally constraining padding rows. +- The transcript math (5·`D_EF` per layer after sumcheck) must stay synchronized with `TowerInputAir`’s tidx + bookkeeping. + +## TowerProdSumCheckClaimAir (`src/tower/layer/prod_claim/air.rs`) + +### Columns & Loops + +- `NestedForLoopSubAir<2>` enumerates `(proof_idx, idx)` and treats `layer_idx` as an inner counter controlled by + `is_first_layer`; within each `(proof_idx, idx, layer_idx)` triple an `index_id` column counts accumulator rows. +- Columns include: + - Loop/indexing flags (`is_enabled`, `is_first_layer`, `is_first`, `is_dummy`, `index_id`, `num_read_count`, + `num_write_count`). + - Metadata observed from `TowerLayerAir`: `layer_idx`, `tidx`, challenges `lambda`, `lambda_prime`, `mu`. + - Transcript observations: `p_xi_0`, `p_xi_1`, interpolated `p_xi`. + - Dual running powers/sums: `(pow_lambda, acc_sum)` for the standard sumcheck, `(pow_lambda_prime, acc_sum_prime)` + for + the root-compatible accumulator. + +### Constraints + +- Clamp `index_id` to zero on the first row of every layer triple, increment it while `stay_in_layer = 1`, and enforce + `index_id + 1 = num_read_count` / `num_write_count` on the rows that send results. +- Recompute `p_xi` via the usual linear interpolation in `mu`. +- Update both accumulators: + - `acc_sum_next = acc_sum + p_xi * pow_lambda`, with `pow_lambda_next = pow_lambda * lambda`. + - `acc_sum_prime_next = acc_sum_prime + (p_xi_0 * p_xi_1) * pow_lambda_prime`, + `pow_lambda_prime_next = pow_lambda_prime * lambda_prime`. +- The root-layer behavior falls out automatically: when `lambda_prime = 1`, the `_prime` accumulator simply sums + pairwise products, so the final row exports `r0`/`w0`-style claims. + +### Interactions + +- First row per layer triple receives + `TowerProdLayerChallengeMessage { idx, layer_idx, tidx, lambda, lambda_prime, mu }`. +- Final row sends `TowerProdSumClaimMessage { lambda_claim = acc_sum, lambda_prime_claim = acc_sum_prime }` alongside + the + appropriate `num_*_count`. Read/write variants simply use different buses. + +## TowerLogUpSumCheckClaimAir (`src/tower/layer/logup_claim/air.rs`) + +### Columns & Loops + +- Shares the same `(proof_idx, idx)` outer loop and `index_id` accumulator counter as the product AIR. +- Columns: + - Loop metadata plus `num_logup_count`. + - Transcript data `p_xi_0`, `p_xi_1`, `q_xi_0`, `q_xi_1`, interpolated `p_xi`, `q_xi`. + - Challenges `lambda`, `lambda_prime`, `mu`. + - Running powers `pow_lambda`, `pow_lambda_prime`. + - Accumulators: `acc_sum` for the standard `(p_xi + lambda * q_xi)` contribution, `acc_p_cross`, `acc_q_cross` for + the + log-up initialization terms that previously lived in their own AIR. + +### Constraints + +- Recompute `p_xi`, `q_xi` every row, then derive the cross terms + `p_cross = p_xi_0 * q_xi_1 + p_xi_1 * q_xi_0`, `q_cross = q_xi_0 * q_xi_1`. +- Accumulators: + - `acc_sum_next = acc_sum + pow_lambda * (p_xi + lambda * q_xi)`. + - `acc_p_cross_next = acc_p_cross + pow_lambda_prime * p_cross`. + - `acc_q_cross_next = acc_q_cross + pow_lambda_prime * lambda_prime * q_cross`. + Root-layer behavior again follows from `lambda_prime = 1`. +- `pow_lambda` and `pow_lambda_prime` follow the same multiplicative recurrence as in the product AIR. +- `index_id` bookkeeping and “final row sends” conditions mirror the product AIR. + +### Interactions + +- Receives the layer challenge message with both `lambda` and `lambda_prime` on the first row. +- Final row sends `TowerLogupClaimMessage { lambda_claim = acc_sum, lambda_prime_claim = acc_q_cross }` plus + `num_logup_count`. (The `acc_p_cross` value remains internal because only the denominator-style accumulator is needed + upstream at the moment.) + +## TowerLogUpSumCheckClaimAir (`src/tower/layer/logup_claim/air.rs`) + +### Columns & Loops + +- Shares the `(proof_idx, idx, layer_idx)` nested-loop structure and reuses `index_id` to count accumulator rows. +- Columns mirror the product AIR plus the denominator evaluations: `is_enabled`, the loop counters/flags, + `(p_xi_0, p_xi_1, q_xi_0, q_xi_1)`, interpolated `(p_xi, q_xi)`, `lambda`, `mu`, `pow_lambda`, `acc_sum`, + `index_id`, and `num_logup_count`. + +### Constraints + +- Recomputes both `p_xi` and `q_xi` in every row. +- Uses the existing log-up contribution `acc_sum_next = acc_sum + (lambda * q_xi) * pow_lambda`. +- `index_id` obeys the same initialization/increment/final-row checks against `num_logup_count` as the product AIR. +- Only the final accumulator row per `(proof_idx, idx, layer_idx)` may drive `TowerLogupClaimBus`. + +### Interactions + +- Layer metadata is consumed on the row flagged by `is_first_layer`. +- Folded logup claim is emitted exactly once per triple when the accumulator row counter reaches `num_logup_count`. + +## TowerLayerSumcheckAir (`src/tower/sumcheck/air.rs`) + +### Columns + +| Field | Shape | Description | +|-------------------------------|----------|-------------------------------------------------------------| +| `is_enabled` | scalar | Row selector. +| `proof_idx` | scalar | Proof counter. +| `idx` | scalar | Module index within the proof (mirrors `TowerLayerAir`). +| `layer_idx` | scalar | Layer whose sumcheck is being executed. +| `is_first_idx` | scalar | First sumcheck row for the current `(proof_idx, idx)` pair. | +| `is_first_layer` | scalar | First round row for the current layer. +| `is_first_round` | scalar | First round inside the layer. +| `is_dummy` | scalar | Padding flag. +| `is_last_layer` | scalar | Whether this layer is the final GKR layer. +| `round` | scalar | Sub-round index within the layer (0 .. layer_idx-1). +| `tidx` | scalar | Transcript cursor before reading evaluations. +| `ev1`, `ev2`, `ev3` | `[D_EF]` | Polynomial evaluations at points 1,2,3 (point 0 inferred). +| `claim_in`, `claim_out` | `[D_EF]` | Incoming/outgoing claims for each round. +| `prev_challenge`, `challenge` | `[D_EF]` | Previous xi component and the new random challenge. +| `eq_in`, `eq_out` | `[D_EF]` | Running eq accumulator before/after this round. + +### Row Constraints + +- **Looping**: `NestedForLoopSubAir<3>` now iterates over `(proof_idx, idx, layer_idx)` with the sumcheck round serving + as the innermost loop. The `is_first_idx` flag gates reset logic when we advance to a new module instance, while + `is_first_layer` protects the per-layer bookkeeping just before the round loop begins. +- **Round counter**: `round` starts at 0 and increments each transition; final round enforces `round = layer_idx - 1`. +- **Eq accumulator**: `eq_in = 1` on the first round; `eq_out = update_eq(eq_in, prev_challenge, challenge)` and + propagates forward. +- **Claim flow**: `claim_out` computed via `interpolate_cubic_at_0123` using `(claim_in - ev1)` as `ev0`; + `next.claim_in = claim_out` across transitions. +- **Transcript timing**: Each transition bumps `next.tidx = tidx + 4·D_EF` (three observations + challenge sample). + +### Interactions + +- `sumcheck_input.receive`: first non-dummy round pulls `(layer_idx, is_last_layer, tidx, claim)` from `TowerLayerAir`. +- `sumcheck_output.send`: last non-dummy round returns `(claim_out, eq_at_r_prime)` to the layer AIR. +- `sumcheck_challenge.receive/send`: enforces challenge chaining between layers/rounds (`prev_challenge` from prior + layer, `challenge` published for the next layer or eq export). +- All three buses now include the `idx` field so messages disambiguate distinct module instances inside the same proof. +- `transcript_bus.observe_ext`: records `ev1/ev2/ev3`, followed by `sample_ext` of `challenge`. + +### Notes + +- Dummy rows short-circuit all bus traffic; guard send/receive calls with `is_not_dummy`. +- The layout assumes cubic polynomials (degree 3) and would need updates if the sumcheck arity changes. diff --git a/ceno_recursion_v2/rust-toolchain.toml b/ceno_recursion_v2/rust-toolchain.toml new file mode 100644 index 000000000..b4514fe67 --- /dev/null +++ b/ceno_recursion_v2/rust-toolchain.toml @@ -0,0 +1,5 @@ +[toolchain] +channel = "nightly-2025-11-20" +targets = ["riscv32im-unknown-none-elf"] +# We need the sources for build-std. +components = ["rust-src"] diff --git a/ceno_recursion_v2/rustfmt.toml b/ceno_recursion_v2/rustfmt.toml new file mode 100644 index 000000000..1ff1f984d --- /dev/null +++ b/ceno_recursion_v2/rustfmt.toml @@ -0,0 +1,8 @@ +comment_width = 300 +edition = "2024" +imports_granularity = "Crate" +max_width = 100 +newline_style = "Unix" +normalize_comments = true +style_edition = "2024" +wrap_comments = false diff --git a/ceno_recursion_v2/skills/ceno-recursion-principles/SKILL.md b/ceno_recursion_v2/skills/ceno-recursion-principles/SKILL.md new file mode 100644 index 000000000..790c50fb4 --- /dev/null +++ b/ceno_recursion_v2/skills/ceno-recursion-principles/SKILL.md @@ -0,0 +1,130 @@ +--- +name: ceno-recursion-principles +description: Migration playbook for `ceno_recursion_v2` when integrating OpenVM recursion components with Ceno `RecursionProof`/`RecursionVk`. Focus on minimal forking, RecursionProof-first seams, preflight-owned replay, and strict placeholder/validation policy. +--- + +# Ceno Recursion Principles + +## Overview + +This skill captures the standing orders for evolving `ceno_recursion_v2`: reuse upstream OpenVM crates whenever possible, fork only where type boundaries force divergence, and keep ZKVM/OpenVM bridge logic at narrow seams. + +## Quick Triggers + +Use this skill when: +- Modifying `ceno_recursion_v2/src/system` or `src/continuation/**` +- Replacing `Proof` flows with `RecursionProof` +- Swapping child VK flows from `MultiStarkVerifyingKey` to `RecursionVk` +- Copying/patching OpenVM modules (recursion/continuation) into the Ceno crate +- Debugging trace/air mismatches during continuation proving + +## Core Principles + +1. **Minimal Divergence** - Fork only the seam that must diverge. Keep upstream for AIRs/modules that do not require Ceno type changes. +2. **RecursionProof First** - Public APIs in local forked modules should prefer `RecursionProof` and `RecursionVk` aliases over upstream `Proof` and `MultiStarkVerifyingKey`. +3. **Bridge Locality** - Put conversion stubs and TODOs in bridge points (`system/types.rs`, local trace adapters), not spread across unrelated AIR logic. +4. **Preflight Owns Replay** - Transcript/replay ordering is computed during preflight; later blob/trace generation should consume read-only replay data. +5. **Placeholder Discipline** - Temporary mocked values are allowed only if shape-correct and tagged with explicit TODO ownership. +6. **Invariant First** - Preserve count/order invariants (`airs()` vs `per_trace`) before semantic completeness. +7. **Visible, Reversible Deltas** - Prefer small, reviewable patches and avoid broad upstream copy unless absolutely required. + +## Minimal Fork Decision Matrix + +Fork locally only when at least one applies: +- Child proof/VK type at boundary must change to `RecursionProof`/`RecursionVk`. +- Upstream item is private (`pub(crate)`) and needed in local integration path. +- Upstream interface cannot inject Ceno-specific data without invasive changes. + +Do not fork when: +- Change is only wiring/imports and can be done in local caller. +- Upstream module already supports required behavior through existing interfaces. + +When forking, keep: +- Original file/module layout for future diffability. +- Fork scope minimal (single module seam first, then expand only if blocked). + +## Workflow + +### 1. Establish Type Seams First +- Confirm aliases in `src/system/types.rs` (`RecursionProof`, `RecursionVk`). +- Update constructor/trait signatures at seam files before touching AIR internals. + +### 2. Keep AIR/Trace Ordering Consistent +- Ensure `src/circuit/inner/mod.rs` `airs()` order exactly matches context order produced in `src/circuit/inner/trace.rs`. +- If pre/post contexts are re-enabled, corresponding AIR entries must be present. + +### 3. Placeholder Policy (Temporary) +- If data is missing from `RecursionProof`, use deterministic zero mocks. +- Add explicit comments: `TODO(recursion-proof-bridge): ...`. +- Mocked traces must be width-correct for their AIR and satisfy basic row-0 invariants. + +### 4. Replay Ownership Rule +- Preflight computes and records transcript/replay ordering. +- Blob/trace generation should consume replay records only (no hidden replay recomputation). + +### 5. Validation Loop +- Iterate with: `cargo check` -> target test -> capture first failing reason -> minimal fix. +- Prefer turning panics into structured errors where possible for diagnosability. +- Keep temporary diagnostics narrow and removable. + +### 6. Cargo/Test Hygiene +- Run checks on `ceno_recursion_v2` after each nontrivial seam change. +- Keep target regression test command handy: + - `RUST_MIN_STACK=33554432 RUST_BACKTRACE=1 cargo test -p ceno_recursion_v2 leaf_app_proof_round_trip_placeholder -- --nocapture` + +## Acceptance Checklist for Migration PRs + +Before continuing to next module, verify: +- `airs().len()` and proving-context trace count match. +- AIR ordering matches trace ordering (pre -> verifier -> post when enabled). +- Any mocked data has `TODO(recursion-proof-bridge)` and clear ownership. +- No new broad forks were introduced without matrix justification. +- Current top blocker is explicitly identified by latest test/check run. + +## Reusable PR Checklist Template + +Copy this block into each migration PR description and mark each item as done or N/A with rationale. + +```markdown +## Migration Checklist (ceno-recursion-principles) + +### Scope and Forking +- [ ] Fork scope is minimal and justified by the decision matrix. +- [ ] Upstream modules remain in use unless a concrete seam forces divergence. +- [ ] Forked files preserve upstream layout for easier future sync. + +### Type Seams +- [ ] New/updated public seams use `RecursionProof` / `RecursionVk` where applicable. +- [ ] Bridge logic is localized (for example `src/system/types.rs` or local trace adapters). +- [ ] No unrelated AIR/business logic was modified only to pass types through. + +### AIR and Trace Invariants +- [ ] `airs()` order matches proving context order exactly. +- [ ] `airs().len()` matches proving-trace count. +- [ ] Re-enabled pre/post contexts have corresponding AIR entries. + +### Placeholder Policy +- [ ] Any mocked value is deterministic and shape-correct. +- [ ] Each mocked value has `TODO(recursion-proof-bridge): ...` with ownership. +- [ ] Placeholder traces satisfy basic row-0 invariants for their AIR. + +### Replay Ownership +- [ ] Replay/transcript ordering is computed in preflight. +- [ ] Blob/trace generation consumes preflight replay records read-only. + +### Validation Evidence +- [ ] `cargo check` run after the latest nontrivial seam change. +- [ ] Target regression test run (or explicit blocker reason recorded). +- [ ] Current top blocker (if any) is stated with file/path and first failing message. +``` + +## Reference Paths + +- Skill source in repo: `skills/ceno-recursion-principles/SKILL.md` +- Skill source in global codex dir: `~/.codex/skills/ceno-recursion-principles/SKILL.md` +- Local seam hotspots: + - `src/system/types.rs` + - `src/circuit/inner/mod.rs` + - `src/circuit/inner/trace.rs` + - `src/continuation/prover/inner/mod.rs` + - `src/gkr/mod.rs`, `src/system/preflight/mod.rs` diff --git a/ceno_recursion_v2/skills/ceno-recursion-principles/agents/openai.yaml b/ceno_recursion_v2/skills/ceno-recursion-principles/agents/openai.yaml new file mode 100644 index 000000000..2d4faa272 --- /dev/null +++ b/ceno_recursion_v2/skills/ceno-recursion-principles/agents/openai.yaml @@ -0,0 +1,4 @@ +interface: + display_name: "Ceno Recursion" + short_description: "Migration playbook for Ceno recursion module forks" + default_prompt: "Follow ceno-recursion-principles: minimal fork matrix, RecursionProof-first seams, preflight-owned replay, TODO-tagged placeholder policy, and air/trace ordering invariants." diff --git a/ceno_recursion_v2/src/batch_constraint/expr_eval/constraints_folding/air.rs b/ceno_recursion_v2/src/batch_constraint/expr_eval/constraints_folding/air.rs new file mode 100644 index 000000000..743d241f1 --- /dev/null +++ b/ceno_recursion_v2/src/batch_constraint/expr_eval/constraints_folding/air.rs @@ -0,0 +1,185 @@ +use std::borrow::Borrow; + +use openvm_circuit_primitives::{SubAir, utils::assert_array_eq}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{PrimeCharacteristicRing, extension::BinomiallyExtendable}; +use p3_matrix::Matrix; +use stark_recursion_circuit_derive::AlignedBorrow; + +use crate::{ + batch_constraint::bus::{ + ConstraintsFoldingBus, ConstraintsFoldingMessage, EqNOuterBus, EqNOuterMessage, + ExpressionClaimBus, ExpressionClaimMessage, + }, + bus::{NLiftBus, NLiftMessage, TranscriptBus}, + subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}, + utils::{ext_field_add, ext_field_multiply, ext_field_multiply_scalar}, +}; + +#[derive(AlignedBorrow, Copy, Clone)] +#[repr(C)] +pub struct ConstraintsFoldingCols { + pub is_valid: T, + pub is_first: T, + pub proof_idx: T, + + pub air_idx: T, + pub sort_idx: T, + pub constraint_idx: T, + pub n_lift: T, + + pub lambda_tidx: T, + pub lambda: [T; D_EF], + + pub value: [T; D_EF], + pub cur_sum: [T; D_EF], + pub eq_n: [T; D_EF], + + pub is_first_in_air: T, +} + +pub struct ConstraintsFoldingAir { + pub transcript_bus: TranscriptBus, + pub constraint_bus: ConstraintsFoldingBus, + pub expression_claim_bus: ExpressionClaimBus, + pub eq_n_outer_bus: EqNOuterBus, + pub n_lift_bus: NLiftBus, +} + +impl BaseAirWithPublicValues for ConstraintsFoldingAir {} +impl PartitionedBaseAir for ConstraintsFoldingAir {} + +impl BaseAir for ConstraintsFoldingAir { + fn width(&self) -> usize { + ConstraintsFoldingCols::::width() + } +} + +impl Air for ConstraintsFoldingAir +where + ::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = ( + main.row_slice(0).expect("window should have two elements"), + main.row_slice(1).expect("window should have two elements"), + ); + + let local: &ConstraintsFoldingCols = (*local).borrow(); + let next: &ConstraintsFoldingCols = (*next).borrow(); + + type LoopSubAir = NestedForLoopSubAir<2>; + LoopSubAir {}.eval( + builder, + ( + NestedForLoopIoCols { + is_enabled: local.is_valid, + counter: [local.proof_idx, local.sort_idx], + is_first: [local.is_first, local.is_first_in_air], + } + .map_into(), + NestedForLoopIoCols { + is_enabled: next.is_valid, + counter: [next.proof_idx, next.sort_idx], + is_first: [next.is_first, next.is_first_in_air], + } + .map_into(), + ), + ); + + let is_same_proof = next.is_valid - next.is_first; + let is_same_air = next.is_valid - next.is_first_in_air; + + // =========================== indices consistency =============================== + // When we are within one air, constraint_idx increases by 0/1 + builder + .when(is_same_air.clone()) + .assert_bool(next.constraint_idx - local.constraint_idx); + // First constraint_idx within an air is zero + builder + .when(local.is_first_in_air) + .assert_zero(local.constraint_idx); + builder + .when(is_same_air.clone()) + .assert_eq(local.n_lift, next.n_lift); + + // ======================== lambda and cur sum consistency ============================ + assert_array_eq(&mut builder.when(is_same_proof), local.lambda, next.lambda); + assert_array_eq( + &mut builder.when(is_same_air.clone()), + local.cur_sum, + ext_field_add( + local.value, + ext_field_multiply::(local.lambda, next.cur_sum), + ), + ); + assert_array_eq( + &mut builder.when(is_same_air.clone()), + local.eq_n, + next.eq_n, + ); + // numerator and the last element of the message are just the corresponding values + assert_array_eq( + &mut builder.when(AB::Expr::ONE - is_same_air.clone()), + local.cur_sum, + local.value, + ); + + self.n_lift_bus.receive( + builder, + local.proof_idx, + NLiftMessage { + air_idx: local.air_idx, + n_lift: local.n_lift, + }, + local.is_first_in_air * local.is_valid, + ); + self.constraint_bus.receive( + builder, + local.proof_idx, + ConstraintsFoldingMessage { + air_idx: local.air_idx.into(), + constraint_idx: local.constraint_idx - AB::Expr::ONE, + value: local.value.map(Into::into), + }, + local.is_valid * (AB::Expr::ONE - local.is_first_in_air), + ); + let folded_sum: [AB::Expr; D_EF] = ext_field_add( + ext_field_multiply_scalar::(next.cur_sum, is_same_air.clone()), + ext_field_multiply_scalar::(local.cur_sum, AB::Expr::ONE - is_same_air), + ); + self.expression_claim_bus.send( + builder, + local.proof_idx, + ExpressionClaimMessage { + is_interaction: AB::Expr::ZERO, + idx: local.sort_idx.into(), + value: ext_field_multiply(folded_sum, local.eq_n), + }, + local.is_first_in_air * local.is_valid, + ); + self.transcript_bus.sample_ext( + builder, + local.proof_idx, + local.lambda_tidx, + local.lambda, + local.is_valid * local.is_first, + ); + + self.eq_n_outer_bus.lookup_key( + builder, + local.proof_idx, + EqNOuterMessage { + is_sharp: AB::Expr::ZERO, + n: local.n_lift.into(), + value: local.eq_n.map(Into::into), + }, + local.is_first_in_air * local.is_valid, + ); + } +} diff --git a/ceno_recursion_v2/src/batch_constraint/expr_eval/constraints_folding/mod.rs b/ceno_recursion_v2/src/batch_constraint/expr_eval/constraints_folding/mod.rs new file mode 100644 index 000000000..26ed4a40d --- /dev/null +++ b/ceno_recursion_v2/src/batch_constraint/expr_eval/constraints_folding/mod.rs @@ -0,0 +1,5 @@ +mod air; +mod trace; + +pub use air::*; +pub use trace::*; diff --git a/ceno_recursion_v2/src/batch_constraint/expr_eval/constraints_folding/trace.rs b/ceno_recursion_v2/src/batch_constraint/expr_eval/constraints_folding/trace.rs new file mode 100644 index 000000000..02ca5f716 --- /dev/null +++ b/ceno_recursion_v2/src/batch_constraint/expr_eval/constraints_folding/trace.rs @@ -0,0 +1,404 @@ +use std::borrow::BorrowMut; + +use openvm_stark_sdk::config::baby_bear_poseidon2::{D_EF, EF, F}; +use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; +use p3_matrix::dense::RowMajorMatrix; +use p3_maybe_rayon::prelude::*; + +use crate::{ + batch_constraint::expr_eval::constraints_folding::air::ConstraintsFoldingCols, + system::{Preflight, RecursionVk}, + tracegen::RowMajorChip, + utils::{MultiProofVecVec, MultiVecWithBounds}, +}; + +#[derive(Copy, Clone)] +#[repr(C)] +pub(crate) struct ConstraintsFoldingRecord { + sort_idx: usize, + air_idx: usize, + constraint_idx: usize, + node_idx: usize, + is_first_in_air: bool, + value: EF, +} + +pub(crate) struct ConstraintsFoldingBlob { + pub(crate) records: MultiProofVecVec, + // (n, value), n is before lift, can be negative + pub(crate) folded_claims: MultiProofVecVec<(isize, EF)>, +} + +impl ConstraintsFoldingBlob { + pub fn new( + child_vk: &RecursionVk, + expr_evals: &MultiVecWithBounds, + preflights: &[&Preflight], + ) -> Self { + let mut max_air_idx = 0usize; + for key in child_vk.circuit_index_to_name.keys().copied() { + max_air_idx = max_air_idx.max(key); + } + let mut constraints = vec![Vec::::new(); max_air_idx + 1]; + for (&air_idx, name) in &child_vk.circuit_index_to_name { + let expr_len = child_vk + .circuit_vks + .get(name) + .and_then(|vk| vk.cs.gkr_circuit.as_ref()) + .and_then(|circuit| circuit.layers.get(0)) + .map(|layer| layer.exprs.len()) + .unwrap_or_default(); + if air_idx >= constraints.len() { + constraints.resize(air_idx + 1, vec![]); + } + constraints[air_idx] = (0..expr_len).collect(); + } + + let mut records = MultiProofVecVec::new(); + let mut folded = MultiProofVecVec::new(); + for (pidx, preflight) in preflights.iter().enumerate() { + let lambda_tidx = preflight.batch_constraint.lambda_tidx; + let lambda = EF::from_basis_coefficients_slice( + &preflight.transcript.values()[lambda_tidx..lambda_tidx + D_EF], + ) + .unwrap(); + + let vdata = &preflight.proof_shape.sorted_trace_vdata; + for (sort_idx, (air_idx, v)) in vdata.iter().enumerate() { + let constrs = &constraints[*air_idx]; + records.push(ConstraintsFoldingRecord { + // dummy to avoid handling case with no constraints + sort_idx, + air_idx: *air_idx, + constraint_idx: 0, + node_idx: 0, + is_first_in_air: true, + value: EF::ZERO, + }); + let mut folded_claim = EF::ZERO; + let mut lambda_pow = EF::ONE; + for (constraint_idx, &constr) in constrs.iter().enumerate() { + let value = expr_evals[[pidx, *air_idx]][constr]; + folded_claim += lambda_pow * value; + lambda_pow *= lambda; + records.push(ConstraintsFoldingRecord { + sort_idx, + air_idx: *air_idx, + constraint_idx: constraint_idx + 1, + node_idx: constr, + is_first_in_air: false, + value, + }); + } + let l_skip = preflight.proof_shape.l_skip; + let n_lift = v.log_height.saturating_sub(l_skip); + let n = v.log_height as isize - l_skip as isize; + folded.push(( + n, + folded_claim * preflight.batch_constraint.eq_ns_frontloaded[n_lift], + )); + } + records.end_proof(); + folded.end_proof(); + } + Self { + records, + folded_claims: folded, + } + } +} + +pub struct ConstraintsFoldingTraceGenerator; + +impl RowMajorChip for ConstraintsFoldingTraceGenerator { + type Ctx<'a> = (&'a ConstraintsFoldingBlob, &'a [&'a Preflight]); + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_trace( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + let (blob, preflights) = ctx; + let width = ConstraintsFoldingCols::::width(); + + let total_height = blob.records.len(); + debug_assert!(total_height > 0); + let padding_height = if let Some(height) = required_height { + if height < total_height { + return None; + } + height + } else { + total_height.next_power_of_two() + }; + let mut trace = vec![F::ZERO; padding_height * width]; + + let mut cur_height = 0; + for (pidx, preflight) in preflights.iter().enumerate() { + let lambda_tidx = preflight.batch_constraint.lambda_tidx; + let lambda_slice = &preflight.transcript.values()[lambda_tidx..lambda_tidx + D_EF]; + let records = &blob.records[pidx]; + + trace[cur_height * width..(cur_height + records.len()) * width] + .par_chunks_exact_mut(width) + .zip(records.par_iter()) + .for_each(|(chunk, record)| { + let cols: &mut ConstraintsFoldingCols<_> = chunk.borrow_mut(); + let n_lift = preflight.proof_shape.sorted_trace_vdata[record.sort_idx] + .1 + .log_height + .saturating_sub(preflight.proof_shape.l_skip); + + cols.is_valid = F::ONE; + cols.proof_idx = F::from_usize(pidx); + cols.air_idx = F::from_usize(record.air_idx); + cols.sort_idx = F::from_usize(record.sort_idx); + cols.constraint_idx = F::from_usize(record.constraint_idx); + cols.n_lift = F::from_usize(n_lift); + cols.lambda_tidx = F::from_usize(lambda_tidx); + cols.lambda.copy_from_slice(lambda_slice); + cols.value + .copy_from_slice(record.value.as_basis_coefficients_slice()); + cols.eq_n.copy_from_slice( + preflight.batch_constraint.eq_ns_frontloaded[n_lift] + .as_basis_coefficients_slice(), + ); + cols.is_first_in_air = F::from_bool(record.is_first_in_air); + }); + + // Setting `cur_sum` + let mut cur_sum = EF::ZERO; + let lambda = EF::from_basis_coefficients_slice(lambda_slice).unwrap(); + trace[cur_height * width..(cur_height + records.len()) * width] + .chunks_exact_mut(width) + .rev() + .for_each(|chunk| { + let cols: &mut ConstraintsFoldingCols<_> = chunk.borrow_mut(); + cur_sum = + cur_sum * lambda + EF::from_basis_coefficients_slice(&cols.value).unwrap(); + cols.cur_sum + .copy_from_slice(cur_sum.as_basis_coefficients_slice()); + if cols.is_first_in_air == F::ONE { + cur_sum = EF::ZERO; + } + }); + + { + let cols: &mut ConstraintsFoldingCols<_> = + trace[cur_height * width..(cur_height + 1) * width].borrow_mut(); + cols.is_first = F::ONE; + } + cur_height += records.len(); + } + Some(RowMajorMatrix::new(trace, width)) + } +} + +#[cfg(feature = "cuda")] +pub(in crate::batch_constraint) mod cuda { + use openvm_circuit_primitives::cuda_abi::UInt2; + use openvm_cuda_backend::{GpuBackend, base::DeviceMatrix}; + use openvm_cuda_common::{copy::MemCopyH2D, d_buffer::DeviceBuffer}; + use openvm_stark_backend::prover::AirProvingContext; + + use super::*; + use crate::{ + batch_constraint::cuda_abi::{ + AffineFpExt, FpExtWithTidx, constraints_folding_tracegen, + constraints_folding_tracegen_temp_bytes, + }, + cuda::{preflight::PreflightGpu, vk::VerifyingKeyGpu}, + tracegen::ModuleChip, + }; + + pub struct ConstraintsFoldingBlobGpu { + // Per proof, per AIR, per constraint + pub values: Vec>>, + // Per proof + pub constraints_folding_per_proof: Vec, + // For compatibility with CPU tracegen + pub folded_claims: MultiProofVecVec<(isize, EF)>, + } + + impl ConstraintsFoldingBlobGpu { + pub fn new( + vk: &VerifyingKeyGpu, + expr_evals: &MultiVecWithBounds, + preflights: &[PreflightGpu], + ) -> Self { + let constraints = vk + .cpu + .inner + .per_air + .iter() + .map(|vk| vk.symbolic_constraints.constraints.constraint_idx.clone()) + .collect_vec(); + + let mut values = Vec::with_capacity(preflights.len()); + let mut constraints_folding_per_proof = Vec::with_capacity(preflights.len()); + let mut folded_claims = MultiProofVecVec::new(); + + for (pidx, preflight) in preflights.iter().enumerate() { + let lambda_tidx = preflight.cpu.batch_constraint.lambda_tidx; + let lambda = EF::from_basis_coefficients_slice( + &preflight.cpu.transcript.values()[lambda_tidx..lambda_tidx + D_EF], + ) + .unwrap(); + + let vdata = &preflight.cpu.proof_shape.sorted_trace_vdata; + let mut proof_values = Vec::with_capacity(vdata.len()); + + for (air_idx, v) in vdata.iter() { + let mut folded_claim = EF::ZERO; + let mut lambda_pow = EF::ONE; + + let air_values = std::iter::once(EF::ZERO) + .chain(constraints[*air_idx].iter().map(|&constr| { + let value = expr_evals[[pidx, *air_idx]][constr]; + folded_claim += lambda_pow * value; + lambda_pow *= lambda; + value + })) + .collect_vec(); + proof_values.push(air_values); + + let n_lift = v.log_height.saturating_sub(vk.system_params.l_skip); + let n = v.log_height as isize - vk.system_params.l_skip as isize; + folded_claims.push(( + n, + folded_claim * preflight.cpu.batch_constraint.eq_ns_frontloaded[n_lift], + )); + } + + values.push(proof_values); + constraints_folding_per_proof.push(FpExtWithTidx { + value: lambda, + tidx: lambda_tidx as u32, + }); + folded_claims.end_proof(); + } + + Self { + values, + constraints_folding_per_proof, + folded_claims, + } + } + } + + impl ModuleChip for ConstraintsFoldingTraceGenerator { + type Ctx<'a> = ( + &'a VerifyingKeyGpu, + &'a [PreflightGpu], + &'a ConstraintsFoldingBlobGpu, + ); + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_proving_ctx( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + let (child_vk, preflights_gpu, blob) = ctx; + + let mut num_valid_rows = 0u32; + let mut row_bounds = Vec::with_capacity(preflights_gpu.len()); + let mut constraint_bounds = Vec::with_capacity(preflights_gpu.len()); + let mut proof_and_sort_idxs = vec![]; + + let flat_values = blob + .values + .iter() + .enumerate() + .flat_map(|(proof_idx, proof_values)| { + let mut num_constraints_in_proof = 0; + let mut proof_constraint_bounds = Vec::with_capacity(proof_values.len()); + for (sort_idx, air_values) in proof_values.iter().enumerate() { + let num_constraints = air_values.len(); + num_constraints_in_proof += num_constraints as u32; + proof_constraint_bounds.push(num_constraints_in_proof); + proof_and_sort_idxs.extend(std::iter::repeat_n( + UInt2 { + x: proof_idx as u32, + y: sort_idx as u32, + }, + num_constraints, + )); + } + num_valid_rows += num_constraints_in_proof; + row_bounds.push(num_valid_rows); + constraint_bounds.push(proof_constraint_bounds.to_device().unwrap()); + proof_values.iter().flatten().copied() + }) + .collect_vec(); + let eq_ns = preflights_gpu + .iter() + .map(|preflight| { + preflight + .cpu + .batch_constraint + .eq_ns_frontloaded + .to_device() + .unwrap() + }) + .collect_vec(); + + let height = if let Some(height) = required_height { + if height < num_valid_rows as usize { + return None; + } + height + } else { + (num_valid_rows as usize).next_power_of_two() + }; + let width = ConstraintsFoldingCols::::width(); + let d_trace = DeviceMatrix::::with_capacity(height, width); + + let d_proof_and_sort_idxs = proof_and_sort_idxs.to_device().unwrap(); + let d_values = flat_values.to_device().unwrap(); + let d_cur_sum_evals = DeviceBuffer::::with_capacity(d_values.len()); + + let d_constraint_bounds = constraint_bounds.iter().map(|b| b.as_ptr()).collect_vec(); + let d_sorted_trace_heights = preflights_gpu + .iter() + .map(|preflight| preflight.proof_shape.sorted_trace_heights.as_ptr()) + .collect_vec(); + let d_eq_ns = eq_ns.iter().map(|b| b.as_ptr()).collect_vec(); + + let d_per_proof = blob.constraints_folding_per_proof.to_device().unwrap(); + + unsafe { + let temp_bytes = constraints_folding_tracegen_temp_bytes( + &d_proof_and_sort_idxs, + &d_cur_sum_evals, + num_valid_rows, + ) + .unwrap(); + let d_temp_buffer = DeviceBuffer::::with_capacity(temp_bytes); + constraints_folding_tracegen( + d_trace.buffer(), + height, + width, + &d_proof_and_sort_idxs, + &d_cur_sum_evals, + &d_values, + &row_bounds, + d_constraint_bounds, + d_sorted_trace_heights, + d_eq_ns, + &d_per_proof, + preflights_gpu.len() as u32, + child_vk.per_air.len() as u32, + num_valid_rows, + child_vk.system_params.l_skip as u32, + &d_temp_buffer, + temp_bytes, + ) + .unwrap(); + } + + Some(AirProvingContext::simple_no_pis(d_trace)) + } + } +} diff --git a/ceno_recursion_v2/src/batch_constraint/expr_eval/mod.rs b/ceno_recursion_v2/src/batch_constraint/expr_eval/mod.rs new file mode 100644 index 000000000..996852018 --- /dev/null +++ b/ceno_recursion_v2/src/batch_constraint/expr_eval/mod.rs @@ -0,0 +1,5 @@ +pub mod constraints_folding; +pub mod symbolic_expression; + +pub use constraints_folding::*; +pub use symbolic_expression::*; diff --git a/ceno_recursion_v2/src/batch_constraint/expr_eval/symbolic_expression/air.rs b/ceno_recursion_v2/src/batch_constraint/expr_eval/symbolic_expression/air.rs new file mode 100644 index 000000000..919a034d8 --- /dev/null +++ b/ceno_recursion_v2/src/batch_constraint/expr_eval/symbolic_expression/air.rs @@ -0,0 +1,413 @@ +use core::array; +use std::borrow::Borrow; + +use openvm_circuit_primitives::{encoder::Encoder, utils::assert_array_eq}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, air_builders::PartitionedAirBuilder, + interaction::InteractionBuilder, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir}; +use p3_field::{Field, PrimeCharacteristicRing, extension::BinomiallyExtendable}; +use p3_matrix::Matrix; +use stark_recursion_circuit_derive::AlignedBorrow; +use strum::{EnumCount, IntoEnumIterator}; +use strum_macros::EnumIter; + +use crate::{ + batch_constraint::bus::{ + ConstraintsFoldingBus, ConstraintsFoldingMessage, InteractionsFoldingBus, + InteractionsFoldingMessage, SymbolicExpressionBus, SymbolicExpressionMessage, + }, + bus::{ + AirPresenceBus, AirPresenceBusMessage, AirShapeBus, AirShapeBusMessage, ColumnClaimsBus, + ColumnClaimsMessage, HyperdimBus, HyperdimBusMessage, PublicValuesBus, + PublicValuesBusMessage, SelHypercubeBus, SelHypercubeBusMessage, SelUniBus, + SelUniBusMessage, + }, + proof_shape::bus::AirShapeProperty, + utils::{ + base_to_ext, ext_field_add, ext_field_multiply, ext_field_multiply_scalar, + ext_field_subtract, scalar_subtract_ext_field, + }, +}; + +pub const NUM_FLAGS: usize = 5; +pub const ENCODER_MAX_DEGREE: u32 = 2; + +#[derive(Debug, Clone, Copy, EnumIter, EnumCount)] +pub enum NodeKind { + WitIn = 0, + StructuralWitIn = 1, + Fixed = 2, + Instance = 3, + SelIsFirst = 4, + SelIsLast = 5, + SelIsTransition = 6, + Constant = 7, + Add = 8, + Sub = 9, + Neg = 10, + Mul = 11, + InteractionMult = 12, + InteractionMsgComp = 13, + InteractionBusIndex = 14, +} + +impl Default for NodeKind { + fn default() -> Self { + NodeKind::WitIn + } +} + +#[derive(AlignedBorrow, Copy, Clone)] +#[repr(C)] +pub struct CachedSymbolicExpressionColumns { + pub flags: [T; NUM_FLAGS], + pub air_idx: T, + pub node_or_interaction_idx: T, + pub attrs: [T; 3], + pub fanout: T, + pub is_constraint: T, + pub constraint_idx: T, +} + +#[derive(AlignedBorrow, Copy, Clone)] +#[repr(C)] +pub struct SingleMainSymbolicExpressionColumns { + /// 0 = absent proof, 1 = proof present but air absent, 2 = proof+air present. + pub slot_state: T, + pub args: [T; 2 * D_EF], + pub sort_idx: T, + pub n_abs: T, + pub is_n_neg: T, +} + +pub struct SymbolicExpressionAir { + pub expr_bus: SymbolicExpressionBus, + pub hyperdim_bus: HyperdimBus, + pub air_shape_bus: AirShapeBus, + pub air_presence_bus: AirPresenceBus, + pub column_claims_bus: ColumnClaimsBus, + pub interactions_folding_bus: InteractionsFoldingBus, + pub constraints_folding_bus: ConstraintsFoldingBus, + pub public_values_bus: PublicValuesBus, + pub sel_hypercube_bus: SelHypercubeBus, + pub sel_uni_bus: SelUniBus, + + pub cnt_proofs: usize, +} + +impl BaseAirWithPublicValues for SymbolicExpressionAir {} + +impl PartitionedBaseAir for SymbolicExpressionAir { + fn cached_main_widths(&self) -> Vec { + vec![CachedSymbolicExpressionColumns::::width()] + } + + fn common_main_width(&self) -> usize { + SingleMainSymbolicExpressionColumns::::width() * self.cnt_proofs + } +} + +impl BaseAir for SymbolicExpressionAir { + fn width(&self) -> usize { + CachedSymbolicExpressionColumns::::width() + + SingleMainSymbolicExpressionColumns::::width() * self.cnt_proofs + } +} + +impl Air + for SymbolicExpressionAir +where + ::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, +{ + fn eval(&self, builder: &mut AB) { + let cached_local = builder.cached_mains()[0] + .row_slice(0) + .expect("cached window should have a row") + .to_vec(); + let main_local = builder + .common_main() + .row_slice(0) + .expect("main window should have a row") + .to_vec(); + let main_next = builder + .common_main() + .row_slice(1) + .expect("main window should have two rows") + .to_vec(); + + let cached_cols: &CachedSymbolicExpressionColumns = + cached_local.as_slice().borrow(); + let main_cols: Vec<&SingleMainSymbolicExpressionColumns> = main_local + .chunks(SingleMainSymbolicExpressionColumns::::width()) + .map(|chunk| chunk.borrow()) + .collect(); + let next_main_cols: Vec<&SingleMainSymbolicExpressionColumns> = main_next + .chunks(SingleMainSymbolicExpressionColumns::::width()) + .map(|chunk| chunk.borrow()) + .collect(); + + let enc = Encoder::new(NodeKind::COUNT, ENCODER_MAX_DEGREE, true); + let flags = cached_cols.flags; + let is_valid_row = enc.is_valid::(&flags); + + let is_arg0_node_idx = enc.contains_flag::( + &flags, + &[ + NodeKind::Add, + NodeKind::Sub, + NodeKind::Mul, + NodeKind::Neg, + NodeKind::InteractionMult, + NodeKind::InteractionMsgComp, + NodeKind::WitIn, + NodeKind::StructuralWitIn, + NodeKind::Fixed, + NodeKind::Instance, + ] + .map(|x| x as usize), + ); + let is_arg1_node_idx = enc.contains_flag::( + &flags, + &[ + NodeKind::Add, + NodeKind::Sub, + NodeKind::Mul, + NodeKind::InteractionMsgComp, + ] + .map(|x| x as usize), + ); + + for (proof_idx, (&cols, &next_cols)) in main_cols.iter().zip(&next_main_cols).enumerate() { + let proof_idx = AB::F::from_usize(proof_idx); + + let slot_state: AB::Expr = cols.slot_state.into(); + let next_slot_state: AB::Expr = next_cols.slot_state.into(); + let proof_present = slot_state.clone() + * (AB::Expr::from_u8(3) - slot_state.clone()) + * AB::F::TWO.inverse(); + let next_proof_present = next_slot_state.clone() + * (AB::Expr::from_u8(3) - next_slot_state) + * AB::F::TWO.inverse(); + let air_present = + slot_state.clone() * (slot_state.clone() - AB::Expr::ONE) * AB::F::TWO.inverse(); + + let arg_ef0: [AB::Var; D_EF] = cols.args[..D_EF].try_into().unwrap(); + let arg_ef1: [AB::Var; D_EF] = cols.args[D_EF..2 * D_EF].try_into().unwrap(); + + builder.assert_tern(cols.slot_state); + builder + .when(cols.is_n_neg) + .assert_eq(cols.slot_state, AB::Expr::TWO); + builder + .when(air_present.clone()) + .assert_one(is_valid_row.clone()); + builder + .when_transition() + .assert_eq(proof_present.clone(), next_proof_present); + + let mut value = [AB::Expr::ZERO; D_EF]; + for node_kind in NodeKind::iter() { + let sel = enc.get_flag_expr::(node_kind as usize, &flags); + let expr = match node_kind { + NodeKind::Add => ext_field_add::(arg_ef0, arg_ef1), + NodeKind::Sub => ext_field_subtract::(arg_ef0, arg_ef1), + NodeKind::Neg => scalar_subtract_ext_field::(AB::Expr::ZERO, arg_ef0), + NodeKind::Mul => ext_field_multiply::(arg_ef0, arg_ef1), + NodeKind::Constant => base_to_ext(cached_cols.attrs[0]), + NodeKind::Instance => base_to_ext(cols.args[0]), + NodeKind::SelIsFirst => ext_field_multiply(arg_ef0, arg_ef1), + NodeKind::SelIsLast => ext_field_multiply(arg_ef0, arg_ef1), + NodeKind::SelIsTransition => scalar_subtract_ext_field( + AB::Expr::ONE, + ext_field_multiply(arg_ef0, arg_ef1), + ), + NodeKind::WitIn + | NodeKind::StructuralWitIn + | NodeKind::Fixed + | NodeKind::InteractionMult + | NodeKind::InteractionMsgComp => arg_ef0.map(Into::into), + NodeKind::InteractionBusIndex => { + base_to_ext(cached_cols.attrs[0] + AB::Expr::ONE) + } + }; + value = ext_field_add::( + value, + ext_field_multiply_scalar::(expr, sel), + ); + } + + self.expr_bus.add_key_with_lookups( + builder, + proof_idx, + SymbolicExpressionMessage { + air_idx: cached_cols.air_idx.into(), + node_idx: cached_cols.node_or_interaction_idx.into(), + value: value.clone(), + }, + air_present.clone() * cached_cols.fanout, + ); + self.expr_bus.lookup_key( + builder, + proof_idx, + SymbolicExpressionMessage { + air_idx: cached_cols.air_idx, + node_idx: cached_cols.attrs[0], + value: arg_ef0, + }, + air_present.clone() * is_arg0_node_idx.clone(), + ); + self.expr_bus.lookup_key( + builder, + proof_idx, + SymbolicExpressionMessage { + air_idx: cached_cols.air_idx, + node_idx: cached_cols.attrs[1], + value: arg_ef1, + }, + air_present.clone() * is_arg1_node_idx.clone(), + ); + + let is_var = enc.contains_flag::( + &flags, + &[NodeKind::WitIn, NodeKind::StructuralWitIn, NodeKind::Fixed].map(|x| x as usize), + ); + self.column_claims_bus.receive( + builder, + proof_idx, + ColumnClaimsMessage { + sort_idx: cols.sort_idx.into(), + part_idx: cached_cols.attrs[1].into(), + col_idx: cached_cols.attrs[0].into(), + claim: array::from_fn(|i| cols.args[i].into()), + is_rot: cached_cols.attrs[2].into(), + }, + is_var * air_present.clone(), + ); + self.public_values_bus.receive( + builder, + proof_idx, + PublicValuesBusMessage { + air_idx: cached_cols.air_idx, + pv_idx: cached_cols.attrs[0], + value: cols.args[0], + }, + enc.get_flag_expr::(NodeKind::Instance as usize, &flags) * air_present.clone(), + ); + self.air_shape_bus.lookup_key( + builder, + proof_idx, + AirShapeBusMessage { + sort_idx: cols.sort_idx.into(), + property_idx: AirShapeProperty::AirId.to_field(), + value: cached_cols.air_idx.into(), + }, + air_present.clone(), + ); + self.air_presence_bus.lookup_key( + builder, + proof_idx, + AirPresenceBusMessage { + air_idx: cached_cols.air_idx.into(), + is_present: air_present.clone(), + }, + proof_present * is_valid_row.clone(), + ); + self.hyperdim_bus.lookup_key( + builder, + proof_idx, + HyperdimBusMessage { + sort_idx: cols.sort_idx, + n_abs: cols.n_abs, + n_sign_bit: cols.is_n_neg, + }, + air_present.clone(), + ); + + let is_sel = enc.contains_flag::( + &flags, + &[ + NodeKind::SelIsFirst, + NodeKind::SelIsLast, + NodeKind::SelIsTransition, + ] + .map(|x| x as usize), + ); + let is_first = enc.get_flag_expr::(NodeKind::SelIsFirst as usize, &flags); + self.sel_uni_bus.lookup_key( + builder, + proof_idx, + SelUniBusMessage { + n: AB::Expr::NEG_ONE * cols.n_abs * cols.is_n_neg, + is_first: is_first.clone(), + value: arg_ef0.map(Into::into), + }, + air_present.clone() * is_sel.clone(), + ); + self.sel_hypercube_bus.lookup_key( + builder, + proof_idx, + SelHypercubeBusMessage { + n: cols.n_abs.into(), + is_first: is_first.clone(), + value: arg_ef1.map(Into::into), + }, + is_sel.clone() * (air_present.clone() - cols.is_n_neg), + ); + assert_array_eq( + &mut builder.when(is_sel.clone() * cols.is_n_neg), + arg_ef1, + [ + AB::Expr::ONE, + AB::Expr::ZERO, + AB::Expr::ZERO, + AB::Expr::ZERO, + ], + ); + + let is_mult = enc.get_flag_expr::(NodeKind::InteractionMult as usize, &flags); + let is_bus_index = + enc.get_flag_expr::(NodeKind::InteractionBusIndex as usize, &flags); + let is_interaction = enc.contains_flag::( + &flags, + &[NodeKind::InteractionMult, NodeKind::InteractionMsgComp].map(|x| x as usize), + ); + self.interactions_folding_bus.send( + builder, + proof_idx, + InteractionsFoldingMessage { + air_idx: cached_cols.air_idx.into(), + interaction_idx: cached_cols.node_or_interaction_idx.into(), + is_mult, + idx_in_message: cached_cols.attrs[1].into(), + value: value.clone(), + }, + is_interaction * air_present.clone(), + ); + self.interactions_folding_bus.send( + builder, + proof_idx, + InteractionsFoldingMessage { + air_idx: cached_cols.air_idx.into(), + interaction_idx: cached_cols.node_or_interaction_idx.into(), + is_mult: AB::Expr::ZERO, + idx_in_message: AB::Expr::NEG_ONE, + value: value.clone(), + }, + is_bus_index * air_present.clone(), + ); + self.constraints_folding_bus.send( + builder, + proof_idx, + ConstraintsFoldingMessage { + air_idx: cached_cols.air_idx.into(), + constraint_idx: cached_cols.constraint_idx.into(), + value: value.clone(), + }, + cached_cols.is_constraint * air_present, + ); + } + } +} diff --git a/ceno_recursion_v2/src/batch_constraint/expr_eval/symbolic_expression/mod.rs b/ceno_recursion_v2/src/batch_constraint/expr_eval/symbolic_expression/mod.rs new file mode 100644 index 000000000..c68123602 --- /dev/null +++ b/ceno_recursion_v2/src/batch_constraint/expr_eval/symbolic_expression/mod.rs @@ -0,0 +1,5 @@ +pub(crate) mod air; +pub(crate) mod trace; + +pub use air::*; +pub use trace::*; diff --git a/ceno_recursion_v2/src/batch_constraint/expr_eval/symbolic_expression/trace.rs b/ceno_recursion_v2/src/batch_constraint/expr_eval/symbolic_expression/trace.rs new file mode 100644 index 000000000..47e8f6f06 --- /dev/null +++ b/ceno_recursion_v2/src/batch_constraint/expr_eval/symbolic_expression/trace.rs @@ -0,0 +1,520 @@ +use core::cmp::min; +use std::borrow::BorrowMut; + +use openvm_circuit_primitives::encoder::Encoder; +use openvm_stark_backend::{ + air_builders::symbolic::{SymbolicExpressionNode, symbolic_variable::Entry}, + poly_common::{Squarable, eval_eq_uni_at_one}, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::{D_EF, EF, F}; +use p3_field::{BasedVectorSpace, PrimeCharacteristicRing, TwoAdicField}; +use p3_matrix::dense::RowMajorMatrix; +use p3_maybe_rayon::prelude::*; +use strum::EnumCount; + +use crate::{ + batch_constraint::expr_eval::symbolic_expression::air::{ + CachedSymbolicExpressionColumns, ENCODER_MAX_DEGREE, NodeKind, + SingleMainSymbolicExpressionColumns, + }, + system::{Preflight, RecursionField, RecursionVk, convert_vk_from_zkvm}, + tracegen::RowMajorChip, + utils::{MultiVecWithBounds, interaction_length}, +}; +use ceno_zkvm::structs::ComposedConstrainSystem; +use multilinear_extensions::{Expression, Fixed}; + +pub struct SymbolicExpressionTraceGenerator { + pub max_num_proofs: usize, +} + +pub struct SymbolicExpressionCtx<'a> { + pub vk: &'a RecursionVk, + pub preflights: &'a [&'a Preflight], + pub expr_evals: &'a MultiVecWithBounds, +} + +impl RowMajorChip for SymbolicExpressionTraceGenerator { + type Ctx<'a> = SymbolicExpressionCtx<'a>; + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_trace( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + let child_vk = convert_vk_from_zkvm(ctx.vk); + let child_vk = child_vk.as_ref(); + let preflights = ctx.preflights; + let max_num_proofs = self.max_num_proofs; + let expr_evals = ctx.expr_evals; + let l_skip = child_vk.inner.params.l_skip; + + let single_main_width = SingleMainSymbolicExpressionColumns::::width(); + let main_width = single_main_width * max_num_proofs; + + struct Record { + args: [F; 2 * D_EF], + sort_idx: usize, + n_abs: usize, + is_n_neg: usize, + } + let mut records = vec![]; + + for (proof_idx, preflight) in preflights.iter().enumerate() { + let rs = &preflight.batch_constraint.sumcheck_rnd; + if rs.is_empty() { + continue; + } + let (&rs_0, rs_rest) = rs.split_first().unwrap(); + let mut is_first_uni_by_log_height = vec![]; + let mut is_last_uni_by_log_height = vec![]; + + for (log_height, &r_pow) in rs_0 + .exp_powers_of_2() + .take(l_skip + 1) + .collect::>() + .iter() + .rev() + .enumerate() + { + is_first_uni_by_log_height.push(eval_eq_uni_at_one(log_height, r_pow)); + is_last_uni_by_log_height.push(eval_eq_uni_at_one( + log_height, + r_pow * F::two_adic_generator(log_height), + )); + } + let mut is_first_mle_by_n = vec![EF::ONE]; + let mut is_last_mle_by_n = vec![EF::ONE]; + for (i, &r) in rs_rest.iter().enumerate() { + is_first_mle_by_n.push(is_first_mle_by_n[i] * (EF::ONE - r)); + is_last_mle_by_n.push(is_last_mle_by_n[i] * r); + } + + for (air_idx, vk) in child_vk.inner.per_air.iter().enumerate() { + let constraints = &vk.symbolic_constraints.constraints; + let expr_evals = &expr_evals[[proof_idx, air_idx]]; + + if expr_evals.is_empty() { + let n = constraints.nodes.len() + + vk.symbolic_constraints + .interactions + .iter() + .map(interaction_length) + .sum::() + + vk.unused_variables.len(); + records.resize_with(records.len() + n, || None); + continue; + } + + let (sort_idx, trace_vdata) = preflight + .proof_shape + .sorted_trace_vdata + .iter() + .enumerate() + .find_map(|(sort_idx, (idx, vdata))| { + (*idx == air_idx).then_some((sort_idx, vdata)) + }) + .unwrap(); + + let log_height = trace_vdata.log_height; + let (n_abs, is_n_neg) = if log_height < l_skip { + (l_skip - log_height, 1) + } else { + (log_height - l_skip, 0) + }; + + for (node_idx, node) in constraints.nodes.iter().enumerate() { + let mut record = Record { + args: [F::ZERO; 2 * D_EF], + sort_idx, + n_abs, + is_n_neg, + }; + match node { + SymbolicExpressionNode::Variable(var) => match var.entry { + Entry::Preprocessed { .. } | Entry::Main { .. } | Entry::Public => { + record.args[..D_EF].copy_from_slice( + expr_evals[node_idx].as_basis_coefficients_slice(), + ); + } + Entry::Permutation { .. } => unreachable!(), + Entry::Challenge | Entry::Exposed => unreachable!(), + }, + SymbolicExpressionNode::IsFirstRow => { + record.args[..D_EF].copy_from_slice( + is_first_uni_by_log_height[min(log_height, l_skip)] + .as_basis_coefficients_slice(), + ); + record.args[D_EF..2 * D_EF].copy_from_slice( + is_first_mle_by_n[log_height.saturating_sub(l_skip)] + .as_basis_coefficients_slice(), + ); + } + SymbolicExpressionNode::IsLastRow + | SymbolicExpressionNode::IsTransition => { + record.args[..D_EF].copy_from_slice( + is_last_uni_by_log_height[min(log_height, l_skip)] + .as_basis_coefficients_slice(), + ); + record.args[D_EF..2 * D_EF].copy_from_slice( + is_last_mle_by_n[log_height.saturating_sub(l_skip)] + .as_basis_coefficients_slice(), + ); + } + SymbolicExpressionNode::Constant(_) => {} + SymbolicExpressionNode::Add { + left_idx, + right_idx, + .. + } + | SymbolicExpressionNode::Sub { + left_idx, + right_idx, + .. + } + | SymbolicExpressionNode::Mul { + left_idx, + right_idx, + .. + } => { + record.args[..D_EF].copy_from_slice( + expr_evals[*left_idx].as_basis_coefficients_slice(), + ); + record.args[D_EF..2 * D_EF].copy_from_slice( + expr_evals[*right_idx].as_basis_coefficients_slice(), + ); + } + SymbolicExpressionNode::Neg { idx, .. } => { + record.args[..D_EF] + .copy_from_slice(expr_evals[*idx].as_basis_coefficients_slice()); + } + }; + records.push(Some(record)); + } + for interaction in &vk.symbolic_constraints.interactions { + let mut args = [F::ZERO; 2 * D_EF]; + args[..D_EF].copy_from_slice( + expr_evals[interaction.count].as_basis_coefficients_slice(), + ); + records.push(Some(Record { + args, + sort_idx, + n_abs, + is_n_neg, + })); + + for &node_idx in &interaction.message { + let mut args = [F::ZERO; 2 * D_EF]; + args[..D_EF] + .copy_from_slice(expr_evals[node_idx].as_basis_coefficients_slice()); + records.push(Some(Record { + args, + sort_idx, + n_abs, + is_n_neg, + })); + } + + args.fill(F::ZERO); + args[0] = F::from_u16(interaction.bus_index + 1); + records.push(Some(Record { + args, + sort_idx, + n_abs, + is_n_neg, + })); + } + + let mut node_idx = constraints.nodes.len(); + for unused_var in &vk.unused_variables { + if matches!( + unused_var.entry, + Entry::Permutation { .. } + | Entry::Public + | Entry::Challenge + | Entry::Exposed + ) { + continue; + } + let mut args = [F::ZERO; 2 * D_EF]; + args[..D_EF] + .copy_from_slice(expr_evals[node_idx].as_basis_coefficients_slice()); + records.push(Some(Record { + args, + sort_idx, + n_abs, + is_n_neg, + })); + node_idx += 1; + } + } + } + + let num_valid_rows = records.len() / preflights.len(); + let height = if let Some(height) = required_height { + if height < num_valid_rows { + return None; + } + height + } else { + num_valid_rows.max(1).next_power_of_two() + }; + let mut main_trace = F::zero_vec(main_width * height); + main_trace + .par_chunks_exact_mut(main_width) + .enumerate() + .for_each(|(row_idx, row)| { + if row_idx >= num_valid_rows { + return; + } + for proof_idx in 0..max_num_proofs { + if proof_idx >= preflights.len() { + continue; + } + let record_idx = proof_idx * num_valid_rows + row_idx; + let Some(record) = records[record_idx].as_ref() else { + continue; + }; + let start = proof_idx * single_main_width; + let end = start + single_main_width; + let cols: &mut SingleMainSymbolicExpressionColumns<_> = + row[start..end].borrow_mut(); + cols.slot_state = F::from_u8(2); + cols.args = record.args; + cols.sort_idx = F::from_usize(record.sort_idx); + cols.n_abs = F::from_usize(record.n_abs); + cols.is_n_neg = F::from_usize(record.is_n_neg); + } + }); + + Some(RowMajorMatrix::new(main_trace, main_width)) + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct CachedRecord { + pub kind: NodeKind, + pub air_idx: usize, + pub node_idx: usize, + pub attrs: [usize; 3], + pub is_constraint: bool, + pub constraint_idx: usize, + pub fanout: usize, +} + +#[derive(Debug, Clone, Default)] +pub struct CachedTraceRecord { + pub records: Vec, +} + +pub fn build_cached_trace_record(child_vk: &RecursionVk) -> CachedTraceRecord { + let mut records = Vec::new(); + for (&air_idx, circuit_name) in &child_vk.circuit_index_to_name { + let Some(circuit_vk) = child_vk.circuit_vks.get(circuit_name) else { + continue; + }; + let Some(gkr) = circuit_vk.cs.gkr_circuit.as_ref() else { + continue; + }; + let Some(layer) = gkr.layers.first() else { + continue; + }; + let counts = Counts::from_css(&circuit_vk.cs); + let offsets = Offsets::new(&counts); + let mut builder = AirBuilder::new(&mut records, air_idx); + push_base_nodes(&mut builder, &counts, &offsets); + for (constraint_idx, expr) in layer.exprs.iter().enumerate() { + let root_idx = build_expression_nodes(expr, &mut builder, &offsets); + builder.mark_constraint(root_idx, constraint_idx); + } + } + + CachedTraceRecord { records } +} + +fn push_base_nodes(builder: &mut AirBuilder<'_>, counts: &Counts, offsets: &Offsets) { + for local in 0..counts.num_witin { + let global = offsets.witin + local; + builder.push(NodeKind::WitIn, [global, local, 0]); + } + for local in 0..counts.num_structural_witin { + let global = offsets.structural + local; + builder.push(NodeKind::StructuralWitIn, [global, local, 0]); + } + for local in 0..counts.num_fixed { + let global = offsets.fixed + local; + builder.push(NodeKind::Fixed, [global, local, 0]); + } + for local in 0..counts.num_instance { + let global = offsets.instance + local; + builder.push(NodeKind::Instance, [global, local, 0]); + } +} + +struct Counts { + num_witin: usize, + num_structural_witin: usize, + num_fixed: usize, + num_instance: usize, +} + +impl Counts { + fn from_css(cs: &ComposedConstrainSystem) -> Self { + let css = &cs.zkvm_v1_css; + Self { + num_witin: css.num_witin as usize, + num_structural_witin: css.num_structural_witin as usize, + num_fixed: css.num_fixed, + num_instance: css.instance_openings.len(), + } + } +} + +struct Offsets { + witin: usize, + structural: usize, + fixed: usize, + instance: usize, +} + +impl Offsets { + fn new(counts: &Counts) -> Self { + let witin = 0; + let structural = witin + counts.num_witin; + let fixed = structural + counts.num_structural_witin; + let instance = fixed + counts.num_fixed; + Self { + witin, + structural, + fixed, + instance, + } + } +} + +struct AirBuilder<'a> { + records: &'a mut Vec, + air_idx: usize, + air_start: usize, + next_local_idx: usize, +} + +impl<'a> AirBuilder<'a> { + fn new(records: &'a mut Vec, air_idx: usize) -> Self { + let air_start = records.len(); + Self { + records, + air_idx, + air_start, + next_local_idx: 0, + } + } + + fn push(&mut self, kind: NodeKind, attrs: [usize; 3]) -> usize { + let node_idx = self.next_local_idx; + self.next_local_idx += 1; + self.records.push(CachedRecord { + kind, + air_idx: self.air_idx, + node_idx, + attrs, + is_constraint: false, + constraint_idx: 0, + fanout: 0, + }); + node_idx + } + + fn bump_fanout(&mut self, local_idx: usize) { + let global_idx = self.air_start + local_idx; + if let Some(record) = self.records.get_mut(global_idx) { + record.fanout = record.fanout.saturating_add(1); + } + } + + fn mark_constraint(&mut self, local_idx: usize, constraint_idx: usize) { + let global_idx = self.air_start + local_idx; + if let Some(record) = self.records.get_mut(global_idx) { + record.is_constraint = true; + record.constraint_idx = constraint_idx; + } + } +} + +fn build_expression_nodes( + expr: &Expression, + builder: &mut AirBuilder<'_>, + offsets: &Offsets, +) -> usize { + match expr { + Expression::WitIn(id) => offsets.witin + (*id as usize), + Expression::StructuralWitIn(id, _) => offsets.structural + (*id as usize), + Expression::Fixed(Fixed(idx)) => offsets.fixed + *idx, + Expression::Instance(instance) | Expression::InstanceScalar(instance) => { + offsets.instance + instance.0 + } + Expression::Constant(_) => builder.push(NodeKind::Constant, [0, 0, 0]), + Expression::Challenge(ch_id, pow, _, _) => { + builder.push(NodeKind::Constant, [*ch_id as usize, *pow, 1]) + } + Expression::Sum(left, right) => { + let left_idx = build_expression_nodes(left, builder, offsets); + let right_idx = build_expression_nodes(right, builder, offsets); + builder.bump_fanout(left_idx); + builder.bump_fanout(right_idx); + builder.push(NodeKind::Add, [left_idx, right_idx, 0]) + } + Expression::Product(left, right) => { + let left_idx = build_expression_nodes(left, builder, offsets); + let right_idx = build_expression_nodes(right, builder, offsets); + builder.bump_fanout(left_idx); + builder.bump_fanout(right_idx); + builder.push(NodeKind::Mul, [left_idx, right_idx, 0]) + } + Expression::ScaledSum(x, a, b) => { + let mul_left = build_expression_nodes(a, builder, offsets); + let mul_right = build_expression_nodes(x, builder, offsets); + builder.bump_fanout(mul_left); + builder.bump_fanout(mul_right); + let mul_idx = builder.push(NodeKind::Mul, [mul_left, mul_right, 0]); + let b_idx = build_expression_nodes(b, builder, offsets); + builder.bump_fanout(mul_idx); + builder.bump_fanout(b_idx); + builder.push(NodeKind::Add, [mul_idx, b_idx, 0]) + } + } +} + +pub fn generate_symbolic_expr_cached_trace( + cached_trace_record: &CachedTraceRecord, +) -> RowMajorMatrix { + let encoder = Encoder::new(NodeKind::COUNT, ENCODER_MAX_DEGREE, true); + let cached_width = CachedSymbolicExpressionColumns::::width(); + let records = &cached_trace_record.records; + + let height = records.len().next_power_of_two(); + let mut cached_trace = F::zero_vec(cached_width * height); + cached_trace + .par_chunks_exact_mut(cached_width) + .zip(records) + .for_each(|(row, record)| { + let cols: &mut CachedSymbolicExpressionColumns<_> = row.borrow_mut(); + + for (i, x) in encoder + .get_flag_pt(record.kind as usize) + .into_iter() + .enumerate() + { + cols.flags[i] = F::from_u32(x); + } + cols.air_idx = F::from_usize(record.air_idx); + cols.node_or_interaction_idx = F::from_usize(record.node_idx); + cols.attrs = record.attrs.map(F::from_usize); + cols.is_constraint = F::from_bool(record.is_constraint); + cols.constraint_idx = F::from_usize(record.constraint_idx); + cols.fanout = F::from_usize(record.fanout); + }); + + RowMajorMatrix::new(cached_trace, cached_width) +} diff --git a/ceno_recursion_v2/src/batch_constraint/expression_claim/air.rs b/ceno_recursion_v2/src/batch_constraint/expression_claim/air.rs new file mode 100644 index 000000000..1c5bf2654 --- /dev/null +++ b/ceno_recursion_v2/src/batch_constraint/expression_claim/air.rs @@ -0,0 +1,222 @@ +use std::borrow::Borrow; + +use openvm_circuit_primitives::utils::{assert_array_eq, not}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{PrimeCharacteristicRing, extension::BinomiallyExtendable}; +use p3_matrix::Matrix; +use stark_recursion_circuit_derive::AlignedBorrow; + +use crate::{ + batch_constraint::bus::{ + BatchConstraintConductorBus, BatchConstraintConductorMessage, + BatchConstraintInnerMessageType, EqNOuterBus, EqNOuterMessage, ExpressionClaimBus, + ExpressionClaimMessage, + }, + bus::{ + ExpressionClaimNMaxBus, ExpressionClaimNMaxMessage, HyperdimBus, HyperdimBusMessage, + MainExpressionClaimBus, MainExpressionClaimMessage, + }, + primitives::bus::{PowerCheckerBus, PowerCheckerBusMessage}, + utils::{base_to_ext, ext_field_add, ext_field_multiply, ext_field_multiply_scalar}, +}; + +/// For each proof, this AIR will receive 2t interaction claims and t constraint claims. +/// (2 interaction claims and 1 constraint claim per trace). +/// These values are folded (algebraic batching) with mu into a single value, which +/// should match the final sumcheck claim. +#[derive(AlignedBorrow, Copy, Clone, Debug)] +#[repr(C)] +pub struct ExpressionClaimCols { + pub is_valid: T, + pub is_first: T, + pub proof_idx: T, + + pub is_interaction: T, + /// Index within the proof, 0 ~ 2t-1 are interaction claims, 0~t-1 are constraint claims. + pub idx: T, + pub idx_parity: T, + pub trace_idx: T, + /// The received evaluation claim. Note that for interactions, this is without norm_factor and + /// eq_sharp_ns. These are interactions_evals (without norm_factor and eq_sharp_ns) and + /// constraint_evals in the rust verifier. + pub value: [T; D_EF], + /// Receive from eq_ns AIR + pub eq_sharp_ns: [T; D_EF], + + /// For folding with mu. + pub cur_sum: [T; D_EF], + pub mu: [T; D_EF], + pub multiplier: [T; D_EF], + + /// Need to know n as if n<0, we need to multiply some norm_factor. + pub n_abs: T, + pub n_abs_pow: T, + pub n_sign: T, + /// The round idx for final sumcheck claim. + pub num_multilinear_sumcheck_rounds: T, +} + +pub struct ExpressionClaimAir { + pub expression_claim_n_max_bus: ExpressionClaimNMaxBus, + pub expr_claim_bus: ExpressionClaimBus, + pub mu_bus: BatchConstraintConductorBus, + pub main_claim_bus: MainExpressionClaimBus, + pub eq_n_outer_bus: EqNOuterBus, + pub pow_checker_bus: PowerCheckerBus, + pub hyperdim_bus: HyperdimBus, +} + +impl BaseAirWithPublicValues for ExpressionClaimAir {} +impl PartitionedBaseAir for ExpressionClaimAir {} + +impl BaseAir for ExpressionClaimAir { + fn width(&self) -> usize { + ExpressionClaimCols::::width() + } +} + +impl Air for ExpressionClaimAir +where + ::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = ( + main.row_slice(0).expect("window should have two elements"), + main.row_slice(1).expect("window should have two elements"), + ); + + let local: &ExpressionClaimCols = (*local).borrow(); + let next: &ExpressionClaimCols = (*next).borrow(); + + builder.assert_bool(local.is_valid); + builder.assert_bool(local.is_first); + builder.assert_bool(local.is_interaction); + builder.assert_bool(local.idx_parity); + builder.assert_bool(local.n_sign); + builder + .when(local.is_first) + .assert_one(local.is_interaction); + builder.when(local.is_first).assert_zero(local.idx_parity); + builder + .when(local.is_interaction) + .assert_eq(local.idx_parity + next.idx_parity, AB::Expr::ONE); + builder + .when(local.idx_parity) + .assert_one(local.is_interaction); + + // === cum sum folding === + // cur_sum = next_cur_sum * mu + value * multiplier + assert_array_eq( + &mut builder.when(local.is_valid * not(next.is_first)), + local.cur_sum, + ext_field_add::( + ext_field_multiply::(local.value, local.multiplier), + ext_field_multiply::(next.cur_sum, local.mu), + ), + ); + // multiplier = 1 if not interaction + assert_array_eq( + &mut builder.when(not(local.is_interaction)).when(local.is_valid), + local.multiplier, + base_to_ext::(AB::Expr::ONE), + ); + + // IF negative n and numerator + assert_array_eq( + &mut builder.when(local.n_sign * (local.is_interaction - local.idx_parity)), + ext_field_multiply_scalar::(local.multiplier, local.n_abs_pow), + local.eq_sharp_ns, + ); + // ELSE 1 + assert_array_eq( + &mut builder.when(local.is_interaction * (AB::Expr::ONE - local.n_sign)), + local.multiplier, + local.eq_sharp_ns, + ); + // ELSE 2 + assert_array_eq( + &mut builder.when(local.idx_parity), + local.multiplier, + local.eq_sharp_ns, + ); + + // === interactions === + self.expr_claim_bus.receive( + builder, + local.proof_idx, + ExpressionClaimMessage { + is_interaction: local.is_interaction, + idx: local.idx, + value: local.value, + }, + local.is_valid, + ); + + self.mu_bus.lookup_key( + builder, + local.proof_idx, + BatchConstraintConductorMessage { + msg_type: BatchConstraintInnerMessageType::Mu.to_field(), + idx: AB::Expr::ZERO, + value: local.mu.map(Into::into), + }, + local.is_first * local.is_valid, + ); + + // Receive n_max value from proof shape air + self.expression_claim_n_max_bus.receive( + builder, + local.proof_idx, + ExpressionClaimNMaxMessage { + n_max: local.num_multilinear_sumcheck_rounds, + }, + local.is_first * local.is_valid, + ); + + self.main_claim_bus.receive( + builder, + local.proof_idx, + MainExpressionClaimMessage { + idx: local.idx.into(), + claim: local.cur_sum.map(Into::into), + }, + local.is_first * local.is_valid, + ); + + self.hyperdim_bus.lookup_key( + builder, + local.proof_idx, + HyperdimBusMessage { + sort_idx: local.trace_idx.into(), + n_abs: local.n_abs.into(), + n_sign_bit: local.n_sign.into(), + }, + local.is_valid * (local.is_interaction - local.idx_parity), + ); + + self.eq_n_outer_bus.lookup_key( + builder, + local.proof_idx, + EqNOuterMessage { + is_sharp: AB::Expr::ONE, + n: local.n_abs * (AB::Expr::ONE - local.n_sign), + value: local.eq_sharp_ns.map(Into::into), + }, + local.is_valid * local.is_interaction, + ); + + self.pow_checker_bus.lookup_key( + builder, + PowerCheckerBusMessage { + log: local.n_abs.into(), + exp: local.n_abs_pow.into(), + }, + local.is_valid * local.is_interaction, + ); + } +} diff --git a/ceno_recursion_v2/src/batch_constraint/expression_claim/mod.rs b/ceno_recursion_v2/src/batch_constraint/expression_claim/mod.rs new file mode 100644 index 000000000..bb2d9517a --- /dev/null +++ b/ceno_recursion_v2/src/batch_constraint/expression_claim/mod.rs @@ -0,0 +1,8 @@ +mod air; +mod trace; + +pub use air::{ExpressionClaimAir, ExpressionClaimCols}; +pub(in crate::batch_constraint) use trace::{ + ExpressionClaimBlob, ExpressionClaimCtx, ExpressionClaimTraceGenerator, + generate_expression_claim_blob, +}; diff --git a/ceno_recursion_v2/src/batch_constraint/expression_claim/trace.rs b/ceno_recursion_v2/src/batch_constraint/expression_claim/trace.rs new file mode 100644 index 000000000..7e4a98766 --- /dev/null +++ b/ceno_recursion_v2/src/batch_constraint/expression_claim/trace.rs @@ -0,0 +1,165 @@ +use std::borrow::BorrowMut; + +use openvm_stark_backend::proof::Proof; +use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, D_EF, EF, F}; +use p3_field::{BasedVectorSpace, Field, PrimeCharacteristicRing, PrimeField32}; +use p3_matrix::dense::RowMajorMatrix; +use p3_maybe_rayon::prelude::*; + +use super::ExpressionClaimCols; +use crate::{ + primitives::pow::PowerCheckerCpuTraceGenerator, + system::{POW_CHECKER_HEIGHT, Preflight}, + tracegen::RowMajorChip, + utils::MultiProofVecVec, +}; + +pub struct ExpressionClaimBlob { + // (n, value), n is before lift, can be negative + claims: MultiProofVecVec<(isize, EF)>, +} + +pub fn generate_expression_claim_blob( + cf_folded_claims: &MultiProofVecVec<(isize, EF)>, + if_folded_claims: &MultiProofVecVec<(isize, EF)>, +) -> ExpressionClaimBlob { + let mut claims = MultiProofVecVec::new(); + for pidx in 0..cf_folded_claims.num_proofs() { + claims.extend(if_folded_claims[pidx].iter().cloned()); + claims.extend(cf_folded_claims[pidx].iter().cloned()); + claims.end_proof(); + } + ExpressionClaimBlob { claims } +} + +pub struct ExpressionClaimTraceGenerator; + +pub(crate) struct ExpressionClaimCtx<'a> { + pub blob: &'a ExpressionClaimBlob, + pub proofs: &'a [&'a Proof], + pub preflights: &'a [&'a Preflight], + pub pow_checker: &'a PowerCheckerCpuTraceGenerator<2, POW_CHECKER_HEIGHT>, +} + +impl RowMajorChip for ExpressionClaimTraceGenerator { + type Ctx<'a> = ExpressionClaimCtx<'a>; + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_trace( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + let blob = ctx.blob; + let proofs = ctx.proofs; + let preflights = ctx.preflights; + let pow_checker = ctx.pow_checker; + let width = ExpressionClaimCols::::width(); + + let num_valid = blob.claims.len(); + let padded_height = if let Some(height) = required_height { + if height < num_valid { + return None; + } + height + } else { + num_valid.next_power_of_two() + }; + let mut trace = vec![F::ZERO; padded_height * width]; + let mut cur_height = 0; + for (pidx, preflight) in preflights.iter().enumerate() { + let claims = &blob.claims[pidx]; + + let num_rounds = proofs[pidx] + .batch_constraint_proof + .sumcheck_round_polys + .len(); + let num_present = preflight.proof_shape.sorted_trace_vdata.len(); + debug_assert_eq!(claims.len(), 3 * num_present); + let mu_tidx = preflight.batch_constraint.tidx_before_univariate - D_EF; + + trace[cur_height * width..(cur_height + claims.len()) * width] + .par_chunks_exact_mut(width) + .enumerate() + .for_each(|(i, chunk)| { + let n_lift = claims[i].0.max(0) as usize; + let n_abs = claims[i].0.unsigned_abs(); + let is_interaction = i < 2 * num_present; + if is_interaction { + pow_checker.add_pow(n_abs); + } + let cols: &mut ExpressionClaimCols<_> = chunk.borrow_mut(); + cols.is_first = F::from_bool(i == 0); + cols.is_valid = F::ONE; + cols.proof_idx = F::from_usize(pidx); + cols.is_interaction = F::from_bool(is_interaction); + cols.num_multilinear_sumcheck_rounds = F::from_usize(num_rounds); + cols.idx = F::from_usize(if i < 2 * num_present { + i + } else { + i - 2 * num_present + }); + cols.idx_parity = F::from_bool(is_interaction && i % 2 == 1); + let trace_idx = if is_interaction { + i / 2 + } else { + i - 2 * num_present + }; + cols.trace_idx = F::from_usize(trace_idx); + cols.mu + .copy_from_slice(&preflight.transcript.values()[mu_tidx..mu_tidx + D_EF]); + cols.value + .copy_from_slice(claims[i].1.as_basis_coefficients_slice()); + cols.eq_sharp_ns.copy_from_slice( + preflight.batch_constraint.eq_sharp_ns_frontloaded[n_lift] + .as_basis_coefficients_slice(), + ); + cols.multiplier + .copy_from_slice(EF::ONE.as_basis_coefficients_slice()); + cols.n_abs = F::from_usize(n_abs); + cols.n_sign = F::from_bool(claims[i].0.is_negative()); + cols.n_abs_pow = F::from_usize(1 << n_abs); + }); + + // Setting `cur_sum` + let mut cur_sum = EF::ZERO; + let mu = EF::from_basis_coefficients_slice( + &preflight.transcript.values()[mu_tidx..mu_tidx + D_EF], + ) + .unwrap(); + trace[cur_height * width..(cur_height + claims.len()) * width] + .chunks_exact_mut(width) + .rev() + .for_each(|chunk| { + let cols: &mut ExpressionClaimCols<_> = chunk.borrow_mut(); + // if it's interaction, we need to multiply by eq_sharp_ns and norm_factor + let multiplier = if cols.is_interaction == F::ONE { + let mut mult = + EF::from_basis_coefficients_slice(&cols.eq_sharp_ns).unwrap(); + if cols.n_sign == F::ONE && cols.idx.as_canonical_u32() % 2 == 0 { + mult *= F::from_u32(1 << cols.n_abs.as_canonical_u32()).inverse(); + } + mult + } else { + EF::ONE + }; + cols.multiplier + .copy_from_slice(multiplier.as_basis_coefficients_slice()); + cur_sum = cur_sum * mu + + EF::from_basis_coefficients_slice(&cols.value).unwrap() * multiplier; + cols.cur_sum + .copy_from_slice(cur_sum.as_basis_coefficients_slice()); + }); + + cur_height += claims.len(); + } + trace[cur_height * width..] + .par_chunks_mut(width) + .enumerate() + .for_each(|(i, chunk)| { + let cols: &mut ExpressionClaimCols = chunk.borrow_mut(); + cols.proof_idx = F::from_usize(preflights.len() + i); + }); + Some(RowMajorMatrix::new(trace, width)) + } +} diff --git a/ceno_recursion_v2/src/batch_constraint/mod.rs b/ceno_recursion_v2/src/batch_constraint/mod.rs new file mode 100644 index 000000000..f0f279ed9 --- /dev/null +++ b/ceno_recursion_v2/src/batch_constraint/mod.rs @@ -0,0 +1,309 @@ +use std::sync::Arc; + +use ceno_zkvm::scheme::ZKVMChipProof; +use openvm_cpu_backend::CpuBackend; +use openvm_poseidon2_air::POSEIDON2_WIDTH; +use openvm_stark_backend::{ + AirRef, FiatShamirTranscript, StarkEngine, StarkProtocolConfig, TranscriptHistory, + prover::{CommittedTraceData, TraceCommitter}, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, EF, F}; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::dense::RowMajorMatrix; + +use crate::{ + batch_constraint::{ + bus::{ + BatchConstraintConductorBus, ConstraintsFoldingBus, EqNOuterBus, ExpressionClaimBus, + InteractionsFoldingBus, SymbolicExpressionBus, + }, + expr_eval::{ + ConstraintsFoldingAir, ConstraintsFoldingCols, + symbolic_expression::{ + CachedSymbolicExpressionColumns, SingleMainSymbolicExpressionColumns, + SymbolicExpressionAir, + }, + }, + expression_claim::{ExpressionClaimAir, ExpressionClaimCols}, + }, + bus::{AirPresenceBus, ColumnClaimsBus, SelHypercubeBus, SelUniBus}, + system::{ + AirModule, BatchConstraintPreflight, BusIndexManager, BusInventory, GlobalCtxCpu, + Preflight, RecursionField, RecursionProof, RecursionVk, TraceGenModule, + }, +}; + +pub mod expr_eval; +pub mod expression_claim; +pub mod bus { + use p3_field::PrimeCharacteristicRing; + pub use recursion_circuit::batch_constraint::bus::*; + + #[repr(u8)] + #[derive(Debug, Copy, Clone)] + pub enum BatchConstraintInnerMessageType { + R, + Xi, + Mu, + } + + impl BatchConstraintInnerMessageType { + pub fn to_field(self) -> T { + T::from_u8(self as u8) + } + } +} + +pub use expr_eval::CachedTraceRecord; + +pub struct BatchConstraintModule { + transcript_bus: crate::bus::TranscriptBus, + hyperdim_bus: crate::bus::HyperdimBus, + air_shape_bus: crate::bus::AirShapeBus, + air_presence_bus: AirPresenceBus, + column_claims_bus: ColumnClaimsBus, + public_values_bus: crate::bus::PublicValuesBus, + sel_hypercube_bus: SelHypercubeBus, + sel_uni_bus: SelUniBus, + + expression_claim_n_max_bus: crate::bus::ExpressionClaimNMaxBus, + n_lift_bus: crate::bus::NLiftBus, + main_expression_claim_bus: crate::bus::MainExpressionClaimBus, + power_checker_bus: recursion_circuit::primitives::bus::PowerCheckerBus, + + batch_constraint_conductor_bus: BatchConstraintConductorBus, + eq_n_outer_bus: EqNOuterBus, + symbolic_expression_bus: SymbolicExpressionBus, + expression_claim_bus: ExpressionClaimBus, + interactions_folding_bus: InteractionsFoldingBus, + constraints_folding_bus: ConstraintsFoldingBus, + + max_num_proofs: usize, +} + +impl BatchConstraintModule { + pub fn new(b: &mut BusIndexManager, bus_inventory: BusInventory, max_num_proofs: usize) -> Self { + Self { + transcript_bus: bus_inventory.transcript_bus, + hyperdim_bus: bus_inventory.hyperdim_bus, + air_shape_bus: bus_inventory.air_shape_bus, + air_presence_bus: AirPresenceBus::new(b.new_bus_idx()), + column_claims_bus: ColumnClaimsBus::new(b.new_bus_idx()), + public_values_bus: bus_inventory.public_values_bus, + sel_hypercube_bus: SelHypercubeBus::new(b.new_bus_idx()), + sel_uni_bus: SelUniBus::new(b.new_bus_idx()), + + expression_claim_n_max_bus: bus_inventory.expression_claim_n_max_bus, + n_lift_bus: bus_inventory.n_lift_bus, + main_expression_claim_bus: bus_inventory.main_expression_claim_bus, + power_checker_bus: bus_inventory.power_checker_bus, + + batch_constraint_conductor_bus: BatchConstraintConductorBus::new(b.new_bus_idx()), + eq_n_outer_bus: EqNOuterBus::new(b.new_bus_idx()), + symbolic_expression_bus: SymbolicExpressionBus::new(b.new_bus_idx()), + expression_claim_bus: ExpressionClaimBus::new(b.new_bus_idx()), + interactions_folding_bus: InteractionsFoldingBus::new(b.new_bus_idx()), + constraints_folding_bus: ConstraintsFoldingBus::new(b.new_bus_idx()), + max_num_proofs, + } + } + + #[tracing::instrument(level = "trace", skip_all)] + pub fn run_preflight( + &self, + _child_vk: &RecursionVk, + proof: &RecursionProof, + preflight: &mut Preflight, + ts: &mut TS, + ) where + TS: FiatShamirTranscript + + TranscriptHistory, + { + // Constraint batching challenge. + let lambda_tidx = ts.len(); + let _lambda = FiatShamirTranscript::::sample_ext(ts); + + // Replay a lightweight subset of batch-constraint transcript observes from per-chip + // sumcheck messages, then sample mu. + for chip_proof in proof + .chip_proofs + .values() + .flat_map(|instances| instances.iter()) + { + observe_main_sumcheck_msgs(ts, chip_proof); + } + let _mu = FiatShamirTranscript::::sample_ext(ts); + let tidx_before_univariate = ts.len(); + + let mut sumcheck_rnd = vec![]; + for chip_proof in proof + .chip_proofs + .values() + .flat_map(|instances| instances.iter()) + { + if let Some(layer) = chip_proof + .gkr_iop_proof + .as_ref() + .and_then(|proof| proof.0.first()) + { + for msg in &layer.main.proof.proofs { + for eval in msg.evaluations.iter().take(3) { + ts.observe_ext(*eval); + } + sumcheck_rnd.push(FiatShamirTranscript::::sample(ts)); + } + } + } + if sumcheck_rnd.is_empty() { + // Keep downstream preflight consumers shape-safe when this bridge has no rounds. + sumcheck_rnd.push(F::ZERO); + } + + let n_max = preflight + .proof_shape + .sorted_trace_vdata + .iter() + .map(|(_, v)| v.log_height) + .max() + .unwrap_or(0); + let eq_ns_frontloaded = vec![EF::ONE; n_max + 1]; + let eq_sharp_ns_frontloaded = vec![EF::ONE; n_max + 1]; + + // TODO(recursion-proof-bridge): replace placeholder eq vectors with verifier-equivalent + // frontloaded eq_n / eq_sharp_n computation derived from xi and sumcheck randomness. + preflight.batch_constraint = BatchConstraintPreflight { + lambda_tidx, + tidx_before_univariate, + sumcheck_rnd, + eq_ns_frontloaded, + eq_sharp_ns_frontloaded, + }; + } + + fn placeholder_air_widths(&self) -> [usize; 3] { + [ + CachedSymbolicExpressionColumns::::width() + + SingleMainSymbolicExpressionColumns::::width() * self.max_num_proofs, + ConstraintsFoldingCols::::width(), + ExpressionClaimCols::::width(), + ] + } +} + +impl AirModule for BatchConstraintModule { + fn num_airs(&self) -> usize { + 3 + } + + fn airs>(&self) -> Vec> { + let symbolic_expression_air = SymbolicExpressionAir { + expr_bus: self.symbolic_expression_bus, + hyperdim_bus: self.hyperdim_bus, + air_shape_bus: self.air_shape_bus, + air_presence_bus: self.air_presence_bus, + column_claims_bus: self.column_claims_bus, + interactions_folding_bus: self.interactions_folding_bus, + constraints_folding_bus: self.constraints_folding_bus, + public_values_bus: self.public_values_bus, + sel_hypercube_bus: self.sel_hypercube_bus, + sel_uni_bus: self.sel_uni_bus, + cnt_proofs: self.max_num_proofs, + }; + let constraints_folding_air = ConstraintsFoldingAir { + transcript_bus: self.transcript_bus, + constraint_bus: self.constraints_folding_bus, + expression_claim_bus: self.expression_claim_bus, + eq_n_outer_bus: self.eq_n_outer_bus, + n_lift_bus: self.n_lift_bus, + }; + let expression_claim_air = ExpressionClaimAir { + expression_claim_n_max_bus: self.expression_claim_n_max_bus, + expr_claim_bus: self.expression_claim_bus, + mu_bus: self.batch_constraint_conductor_bus, + main_claim_bus: self.main_expression_claim_bus, + eq_n_outer_bus: self.eq_n_outer_bus, + pow_checker_bus: self.power_checker_bus, + hyperdim_bus: self.hyperdim_bus, + }; + vec![ + Arc::new(symbolic_expression_air) as AirRef<_>, + Arc::new(constraints_folding_air) as AirRef<_>, + Arc::new(expression_claim_air) as AirRef<_>, + ] + } +} + +impl> TraceGenModule> + for BatchConstraintModule +{ + type ModuleSpecificCtx<'a> = (); + + #[tracing::instrument(skip_all)] + fn generate_proving_ctxs( + &self, + _child_vk: &RecursionVk, + _proofs: &[RecursionProof], + _preflights: &[Preflight], + _ctx: &Self::ModuleSpecificCtx<'_>, + required_heights: Option<&[usize]>, + ) -> Option>>> { + let widths = self.placeholder_air_widths(); + let air_count = required_heights + .map(|heights| heights.len()) + .unwrap_or(self.num_airs()); + + Some( + (0..air_count) + .map(|idx| { + let height = required_heights + .and_then(|heights| heights.get(idx).copied()) + .unwrap_or(1); + if required_heights.is_some() && height < 2 { + return None; + } + let width = widths.get(idx).copied().unwrap_or(1); + let rows = height.max(2); + let cols = width.max(1); + let matrix = RowMajorMatrix::new(vec![F::ZERO; rows * cols], cols); + Some(openvm_stark_backend::prover::AirProvingContext::simple_no_pis(matrix)) + }) + .collect::>>()?, + ) + } +} + +fn observe_main_sumcheck_msgs(ts: &mut TS, chip_proof: &ZKVMChipProof) +where + TS: FiatShamirTranscript, +{ + if let Some(proofs) = &chip_proof.main_sumcheck_proofs { + for msg in proofs { + for eval in msg.evaluations.iter().take(3) { + ts.observe_ext(*eval); + } + } + } +} + +pub fn cached_trace_record(child_vk: &RecursionVk) -> CachedTraceRecord { + expr_eval::symbolic_expression::build_cached_trace_record(child_vk) +} + +pub fn commit_child_vk( + engine: &E, + child_vk: &RecursionVk, +) -> CommittedTraceData> +where + E: StarkEngine>, + SC: StarkProtocolConfig, +{ + let cached_trace = expr_eval::symbolic_expression::generate_symbolic_expr_cached_trace( + &cached_trace_record(child_vk), + ); + let (commitment, data) = engine.device().commit(&[&cached_trace]).unwrap(); + CommittedTraceData { + commitment, + data: Arc::new(data), + trace: cached_trace, + } +} diff --git a/ceno_recursion_v2/src/bn254.rs b/ceno_recursion_v2/src/bn254.rs new file mode 100644 index 000000000..bcd11dd82 --- /dev/null +++ b/ceno_recursion_v2/src/bn254.rs @@ -0,0 +1,53 @@ +use openvm_stark_sdk::config::baby_bear_poseidon2::{DIGEST_SIZE, F}; +use p3_field::{PrimeCharacteristicRing, PrimeField32}; + +pub const BN254_BYTES: usize = 32; + +/// Minimal byte wrapper for commit values used by the forked inner circuit. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct CommitBytes([u8; BN254_BYTES]); + +impl CommitBytes { + pub fn new(bytes: [u8; BN254_BYTES]) -> Self { + Self(bytes) + } + + pub fn as_slice(&self) -> &[u8; BN254_BYTES] { + &self.0 + } + + pub fn reverse(&mut self) { + self.0.reverse(); + } +} + +impl From<[F; DIGEST_SIZE]> for CommitBytes { + fn from(value: [F; DIGEST_SIZE]) -> Self { + Self::from(value.map(|x| x.as_canonical_u32())) + } +} + +impl From<[u32; DIGEST_SIZE]> for CommitBytes { + fn from(value: [u32; DIGEST_SIZE]) -> Self { + let mut bytes = [0u8; BN254_BYTES]; + for (idx, limb) in value.iter().enumerate() { + let start = idx * 4; + bytes[start..start + 4].copy_from_slice(&limb.to_le_bytes()); + } + Self(bytes) + } +} + +impl From for [u32; DIGEST_SIZE] { + fn from(value: CommitBytes) -> Self { + core::array::from_fn(|idx| { + let start = idx * 4; + u32::from_le_bytes([ + value.0[start], + value.0[start + 1], + value.0[start + 2], + value.0[start + 3], + ]) + }) + } +} diff --git a/ceno_recursion_v2/src/bus.rs b/ceno_recursion_v2/src/bus.rs new file mode 100644 index 000000000..85abacb7c --- /dev/null +++ b/ceno_recursion_v2/src/bus.rs @@ -0,0 +1,59 @@ +use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; +use recursion_circuit::{bus as upstream, define_typed_per_proof_permutation_bus}; +pub use upstream::{ + AirPresenceBus, AirPresenceBusMessage, AirShapeBus, AirShapeBusMessage, CachedCommitBus, + CachedCommitBusMessage, ColumnClaimsBus, ColumnClaimsMessage, CommitmentsBus, + CommitmentsBusMessage, ExpressionClaimNMaxBus, ExpressionClaimNMaxMessage, + FractionFolderInputBus, FractionFolderInputMessage, HyperdimBus, HyperdimBusMessage, + LiftedHeightsBus, LiftedHeightsBusMessage, NLiftBus, NLiftMessage, PublicValuesBus, + PublicValuesBusMessage, SelHypercubeBus, SelHypercubeBusMessage, SelUniBus, SelUniBusMessage, + TranscriptBus, TranscriptBusMessage, +}; + +#[repr(C)] +#[derive(stark_recursion_circuit_derive::AlignedBorrow, Debug, Clone, Copy)] +pub struct TowerModuleMessage { + pub idx: T, + pub tidx: T, + pub n_logup: T, +} + +define_typed_per_proof_permutation_bus!(TowerModuleBus, TowerModuleMessage); + +#[repr(C)] +#[derive(stark_recursion_circuit_derive::AlignedBorrow, Debug, Clone, Copy)] +pub struct MainMessage { + pub idx: T, + pub tidx: T, + pub claim: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(MainBus, MainMessage); + +#[repr(C)] +#[derive(stark_recursion_circuit_derive::AlignedBorrow, Debug, Clone, Copy)] +pub struct MainSumcheckInputMessage { + pub idx: T, + pub tidx: T, + pub claim: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(MainSumcheckInputBus, MainSumcheckInputMessage); + +#[repr(C)] +#[derive(stark_recursion_circuit_derive::AlignedBorrow, Debug, Clone, Copy)] +pub struct MainSumcheckOutputMessage { + pub idx: T, + pub claim: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(MainSumcheckOutputBus, MainSumcheckOutputMessage); + +#[repr(C)] +#[derive(stark_recursion_circuit_derive::AlignedBorrow, Debug, Clone, Copy)] +pub struct MainExpressionClaimMessage { + pub idx: T, + pub claim: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(MainExpressionClaimBus, MainExpressionClaimMessage); diff --git a/ceno_recursion_v2/src/circuit/deferral/mod.rs b/ceno_recursion_v2/src/circuit/deferral/mod.rs new file mode 100644 index 000000000..8d8ef2ace --- /dev/null +++ b/ceno_recursion_v2/src/circuit/deferral/mod.rs @@ -0,0 +1 @@ +pub const DEF_HOOK_PVS_AIR_ID: usize = 0; diff --git a/ceno_recursion_v2/src/circuit/inner/bus.rs b/ceno_recursion_v2/src/circuit/inner/bus.rs new file mode 100644 index 000000000..f8d744bd7 --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/bus.rs @@ -0,0 +1,11 @@ +use recursion_circuit::define_typed_per_proof_lookup_bus; +use stark_recursion_circuit_derive::AlignedBorrow; + +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct PvsAirConsistencyMessage { + pub deferral_flag: T, + pub has_verifier_pvs: T, +} + +define_typed_per_proof_lookup_bus!(PvsAirConsistencyBus, PvsAirConsistencyMessage); diff --git a/ceno_recursion_v2/src/circuit/inner/def_pvs/air.rs b/ceno_recursion_v2/src/circuit/inner/def_pvs/air.rs new file mode 100644 index 000000000..46b0c7c14 --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/def_pvs/air.rs @@ -0,0 +1,268 @@ +use std::{array::from_fn, borrow::Borrow}; + +use openvm_circuit_primitives::utils::{assert_array_eq, not}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir}; +use p3_field::{Field, PrimeCharacteristicRing}; +use p3_matrix::Matrix; +use recursion_circuit::{ + bus::{ + CachedCommitBus, CachedCommitBusMessage, Poseidon2CompressBus, Poseidon2CompressMessage, + PublicValuesBus, PublicValuesBusMessage, + }, + prelude::DIGEST_SIZE, +}; +use stark_recursion_circuit_derive::AlignedBorrow; +use verify_stark::pvs::{CONSTRAINT_EVAL_AIR_ID, DEF_PVS_AIR_ID, DeferralPvs}; + +use crate::{ + bn254::CommitBytes, + circuit::{ + CONSTRAINT_EVAL_CACHED_INDEX, + deferral::DEF_HOOK_PVS_AIR_ID, + inner::bus::{PvsAirConsistencyBus, PvsAirConsistencyMessage}, + }, + utils::digests_to_poseidon2_input, +}; + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct DeferralPvsCols { + pub row_idx: F, + pub deferral_flag: F, + pub has_verifier_pvs: F, + + pub proof_idx: F, + pub is_present: F, + pub single_present_is_right: F, + + pub child_pvs: DeferralPvs, +} + +pub struct DeferralPvsAir { + pub public_values_bus: PublicValuesBus, + pub cached_commit_bus: CachedCommitBus, + pub poseidon2_bus: Poseidon2CompressBus, + pub pvs_air_consistency_bus: PvsAirConsistencyBus, + + pub expected_def_hook_commit: CommitBytes, +} + +impl BaseAir for DeferralPvsAir { + fn width(&self) -> usize { + DeferralPvsCols::::width() + } +} +impl BaseAirWithPublicValues for DeferralPvsAir { + fn num_public_values(&self) -> usize { + DeferralPvs::::width() + } +} +impl PartitionedBaseAir for DeferralPvsAir {} + +impl Air for DeferralPvsAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = ( + main.row_slice(0).expect("window should have two elements"), + main.row_slice(1).expect("window should have two elements"), + ); + let local: &DeferralPvsCols = (*local).borrow(); + let next: &DeferralPvsCols = (*next).borrow(); + + // This AIR may have 1 or 2 rows. There are 4 valid 1-row cases: + // - deferral_flag == 0: child deferral pvs are unset + // - deferral_flag == 1 && proof_idx == 0: wrapping a deferral proof + // - deferral_flag == 1 && proof_idx == 1: combining a VM and deferral proof + // - deferral_flag == 2: wrapping a combined proof + // + // There are 2 valid 2-row cases, both with deferral_flag == 1: + // - Both child proofs are present + // - The first proof is present and the second is absent + // constrain that when hash_pvs is set we have exactly 2 def rows + builder.assert_bool(local.row_idx); + builder.when_first_row().assert_zero(local.row_idx); + builder + .when_transition() + .assert_one(next.row_idx - local.row_idx); + + let has_two_rows = (next.row_idx - local.row_idx).square(); + let has_one_row = not::(has_two_rows.clone()); + + // constrain that all the present rows are at the beginning + builder.assert_bool(local.is_present); + builder + .when_transition() + .assert_bool(local.is_present - next.is_present); + + // constrain if deferral_flag is set, there is at least one present proof + builder.assert_tern(local.deferral_flag); + builder.assert_eq(local.deferral_flag, next.deferral_flag); + builder + .when(local.deferral_flag) + .assert_bool(local.is_present + next.is_present - AB::Expr::ONE); + + // basic constraints for consistency columns + builder.when_first_row().assert_bool(local.proof_idx); + builder + .when_first_row() + .when(local.proof_idx) + .assert_one(has_one_row.clone()); + builder.assert_bool(local.has_verifier_pvs); + builder.assert_eq(local.has_verifier_pvs, next.has_verifier_pvs); + + // constrain single_present_is_right is set when there is 1 present and + // 1 absent row + builder.assert_bool(local.single_present_is_right); + builder.assert_eq(local.single_present_is_right, next.single_present_is_right); + builder + .when(local.single_present_is_right) + .assert_one(local.is_present + next.is_present); + + // When deferral_flag is unset, there must be a single row with zeros for + // public values. + let mut when_flag_not_one = builder.when_ne(local.deferral_flag, AB::Expr::ONE); + let mut when_invalid = when_flag_not_one.when_ne(local.deferral_flag, AB::Expr::TWO); + + when_invalid.assert_one(has_one_row.clone()); + when_invalid.assert_zero(local.is_present); + for child_pv in local.child_pvs.as_slice() { + when_invalid.assert_zero(*child_pv); + } + + // If there are two rows and a proof is absent, it represents an accumulator + // Merkle subtree that has been left untouched. We constrain its initial and + // final accumulator hashes to be equal. Additionally, if there are two rows + // then the child_pvs depth should be equal. + assert_array_eq( + &mut builder + .when(has_two_rows.clone()) + .when(not(local.is_present)), + local.child_pvs.initial_acc_hash, + local.child_pvs.final_acc_hash, + ); + + builder + .when(has_two_rows.clone()) + .assert_eq(local.child_pvs.depth, next.child_pvs.depth); + + // If this row is present then we need to receive the child public values + // from ProofShapeModule. At the hook level this is at DEF_HOOK_PVS_AIR_ID, + // at every other level it will be at DEF_PVS_AIR_ID. + let def_pvs_air_idx = AB::Expr::from_usize(DEF_PVS_AIR_ID) * local.has_verifier_pvs + + AB::Expr::from_usize(DEF_HOOK_PVS_AIR_ID) * not(local.has_verifier_pvs); + for (pv_idx, value) in local.child_pvs.as_slice().iter().enumerate() { + self.public_values_bus.receive( + builder, + local.proof_idx, + PublicValuesBusMessage { + air_idx: def_pvs_air_idx.clone(), + pv_idx: AB::Expr::from_usize(pv_idx), + value: (*value).into(), + }, + local.is_present, + ); + } + + // We look up proof metadata from VerifierPvsAir here to ensure consistency + // on each row. + self.pvs_air_consistency_bus.lookup_key( + builder, + local.proof_idx, + PvsAirConsistencyMessage { + deferral_flag: local.deferral_flag, + has_verifier_pvs: local.has_verifier_pvs, + }, + local.is_present, + ); + + // If this row corresponds to a direct deferral hook circuit child (i.e. + // has_verifier_pvs == 0), receive the child's cached trace commit and + // constrain it to an expected constant. + let expected_def_hook_commit = + >::into(self.expected_def_hook_commit); + self.cached_commit_bus.receive( + builder, + local.proof_idx, + CachedCommitBusMessage { + air_idx: AB::Expr::from_usize(CONSTRAINT_EVAL_AIR_ID), + cached_idx: AB::Expr::from_usize(CONSTRAINT_EVAL_CACHED_INDEX), + cached_commit: expected_def_hook_commit.map(AB::Expr::from_u32), + }, + local.is_present * not(local.has_verifier_pvs), + ); + + // Finally, we constrain the public values to be consistent with the + // child's. If there is one row then the pvs are simply passed through. + // If there are two, then initial_acc_hash and final_acc_hash are + // combined and depth is incremented by 1. + let &DeferralPvs::<_> { + initial_acc_hash, + final_acc_hash, + depth, + } = builder.public_values().borrow(); + + // constrain that pvs are passed through if there is one row + let mut when_one_row = builder.when(has_one_row); + when_one_row.assert_eq(local.child_pvs.depth, depth); + + assert_array_eq( + &mut when_one_row, + local.child_pvs.initial_acc_hash, + initial_acc_hash, + ); + assert_array_eq( + &mut when_one_row, + local.child_pvs.final_acc_hash, + final_acc_hash, + ); + + // constrain that pvs are updated properly if there are two rows + let row_delta = next.row_idx - local.row_idx; + let single_present_is_left = not(local.single_present_is_right); + let single_present_is_local = + row_delta.clone() * (row_delta + AB::Expr::ONE) * AB::F::TWO.inverse(); + + let left_init_child = from_fn(|i| { + single_present_is_left.clone() * local.child_pvs.initial_acc_hash[i] + + local.single_present_is_right * next.child_pvs.initial_acc_hash[i] + }); + let right_init_child = from_fn(|i| { + local.single_present_is_right * local.child_pvs.initial_acc_hash[i] + + single_present_is_left.clone() * next.child_pvs.initial_acc_hash[i] + }); + + self.poseidon2_bus.lookup_key( + builder, + Poseidon2CompressMessage { + input: digests_to_poseidon2_input(left_init_child, right_init_child), + output: initial_acc_hash.map(Into::into), + }, + single_present_is_local.clone(), + ); + + let left_final_child = from_fn(|i| { + single_present_is_left.clone() * local.child_pvs.final_acc_hash[i] + + local.single_present_is_right * next.child_pvs.final_acc_hash[i] + }); + let right_final_child = from_fn(|i| { + local.single_present_is_right * local.child_pvs.final_acc_hash[i] + + single_present_is_left.clone() * next.child_pvs.final_acc_hash[i] + }); + + self.poseidon2_bus.lookup_key( + builder, + Poseidon2CompressMessage { + input: digests_to_poseidon2_input(left_final_child, right_final_child), + output: final_acc_hash.map(Into::into), + }, + single_present_is_local, + ); + + builder + .when(has_two_rows) + .assert_one(depth.into() - local.child_pvs.depth); + } +} diff --git a/ceno_recursion_v2/src/circuit/inner/def_pvs/mod.rs b/ceno_recursion_v2/src/circuit/inner/def_pvs/mod.rs new file mode 100644 index 000000000..26ed4a40d --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/def_pvs/mod.rs @@ -0,0 +1,5 @@ +mod air; +mod trace; + +pub use air::*; +pub use trace::*; diff --git a/ceno_recursion_v2/src/circuit/inner/def_pvs/trace.rs b/ceno_recursion_v2/src/circuit/inner/def_pvs/trace.rs new file mode 100644 index 000000000..fad680a63 --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/def_pvs/trace.rs @@ -0,0 +1,142 @@ +use std::borrow::BorrowMut; + +use itertools::Itertools; +use openvm_cpu_backend::CpuBackend; +use openvm_poseidon2_air::POSEIDON2_WIDTH; +use openvm_stark_backend::prover::AirProvingContext; +use openvm_stark_sdk::config::baby_bear_poseidon2::{ + BabyBearPoseidon2Config, DIGEST_SIZE, F, poseidon2_compress_with_capacity, +}; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::dense::RowMajorMatrix; +use verify_stark::pvs::{DEF_PVS_AIR_ID, DeferralPvs}; + +use crate::{ + circuit::{ + deferral::DEF_HOOK_PVS_AIR_ID, + inner::{ProofsType, def_pvs::air::DeferralPvsCols}, + }, + system::RecursionProof, + utils::digests_to_poseidon2_input, +}; + +pub fn generate_proving_ctx( + proofs: &[RecursionProof], + proofs_type: ProofsType, + child_is_app: bool, + absent_trace_pvs: Option<(DeferralPvs, bool)>, +) -> ( + AirProvingContext>, + Vec<[F; POSEIDON2_WIDTH]>, +) { + assert!( + absent_trace_pvs.is_none() + || (matches!(proofs_type, ProofsType::Deferral) && proofs.len() == 1), + "absent_trace_pvs is only valid for single-proof deferral aggregation" + ); + let mut proof_idxs = vec![]; + let (num_rows, def_flag) = match proofs_type { + ProofsType::Vm => (1, 0), + ProofsType::Deferral => { + proof_idxs = (0..proofs.len()).collect_vec(); + (proofs.len() + absent_trace_pvs.is_some() as usize, 1) + } + ProofsType::Mix => { + proof_idxs.push(1); + (1, 1) + } + ProofsType::Combined => { + proof_idxs.push(0); + (1, 2) + } + }; + + let width = DeferralPvsCols::::width(); + let mut trace = vec![F::ZERO; num_rows * width]; + let mut chunks = trace.chunks_exact_mut(width); + + let mut child_pvs_vec = vec![]; + let single_present_is_right = if let Some((_, is_right)) = absent_trace_pvs.as_ref() { + *is_right + } else { + false + }; + + for (row_idx, proof_idx) in proof_idxs.iter().enumerate() { + let proof = &proofs[*proof_idx]; + let chunk = chunks.next().unwrap(); + let cols: &mut DeferralPvsCols = chunk.borrow_mut(); + cols.row_idx = F::from_usize(row_idx); + cols.proof_idx = F::from_usize(*proof_idx); + cols.is_present = F::ONE; + cols.deferral_flag = F::from_usize(def_flag); + cols.has_verifier_pvs = F::from_bool(!child_is_app); + cols.single_present_is_right = F::from_bool(single_present_is_right); + + let _ = proof; + let _air_id = if child_is_app { + DEF_HOOK_PVS_AIR_ID + } else { + DEF_PVS_AIR_ID + }; + // TODO(recursion-proof-bridge): RecursionProof does not expose per-air public values yet. + // Use zeroed child deferral PVS until proof -> verifier-PVS extraction is implemented. + cols.child_pvs = DeferralPvs { + initial_acc_hash: [F::ZERO; DIGEST_SIZE], + final_acc_hash: [F::ZERO; DIGEST_SIZE], + depth: F::ZERO, + }; + child_pvs_vec.push(cols.child_pvs); + } + + if let Some((pvs, _)) = absent_trace_pvs { + let chunk = chunks.next().unwrap(); + let cols: &mut DeferralPvsCols = chunk.borrow_mut(); + cols.row_idx = F::ONE; + cols.deferral_flag = F::from_usize(def_flag); + cols.has_verifier_pvs = F::from_bool(!child_is_app); + cols.single_present_is_right = F::from_bool(single_present_is_right); + cols.child_pvs = pvs; + child_pvs_vec.push(cols.child_pvs); + } + + let mut poseidon2_inputs = vec![]; + let mut public_values = vec![F::ZERO; DeferralPvs::::width()]; + let pvs: &mut DeferralPvs = public_values.as_mut_slice().borrow_mut(); + + if child_pvs_vec.len() == 1 { + *pvs = child_pvs_vec[0]; + } else if child_pvs_vec.len() == 2 { + let first_child = child_pvs_vec[0]; + let second_child = child_pvs_vec[1]; + let (left_initial, right_initial, left_final, right_final) = if single_present_is_right { + ( + second_child.initial_acc_hash, + first_child.initial_acc_hash, + second_child.final_acc_hash, + first_child.final_acc_hash, + ) + } else { + ( + first_child.initial_acc_hash, + second_child.initial_acc_hash, + first_child.final_acc_hash, + second_child.final_acc_hash, + ) + }; + pvs.initial_acc_hash = poseidon2_compress_with_capacity(left_initial, right_initial).0; + poseidon2_inputs.push(digests_to_poseidon2_input(left_initial, right_initial)); + pvs.final_acc_hash = poseidon2_compress_with_capacity(left_final, right_final).0; + poseidon2_inputs.push(digests_to_poseidon2_input(left_final, right_final)); + pvs.depth = first_child.depth + F::ONE; + } + + ( + AirProvingContext { + cached_mains: vec![], + common_main: RowMajorMatrix::new(trace, width), + public_values, + }, + poseidon2_inputs, + ) +} diff --git a/ceno_recursion_v2/src/circuit/inner/mod.rs b/ceno_recursion_v2/src/circuit/inner/mod.rs new file mode 100644 index 000000000..2e5f25df7 --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/mod.rs @@ -0,0 +1,117 @@ +use std::sync::Arc; + +use openvm_stark_backend::{AirRef, StarkProtocolConfig}; +use recursion_circuit::prelude::F; +use verify_stark::pvs::{DEF_PVS_AIR_ID, DeferralPvs, VM_PVS_AIR_ID, VmPvs}; + +use crate::{ + bn254::CommitBytes, + circuit::{Circuit, inner::bus::PvsAirConsistencyBus}, + system::AggregationSubCircuit, +}; + +pub mod app { + pub use openvm_circuit::arch::{ + CONNECTOR_AIR_ID, MERKLE_AIR_ID, PROGRAM_AIR_ID, PROGRAM_CACHED_TRACE_INDEX, + }; +} + +pub mod bus; +pub mod def_pvs; +pub mod unset; +pub mod verifier; +pub mod vm_pvs; + +mod trace; +pub use trace::*; + +#[derive(derive_new::new, Clone)] +pub struct InnerCircuit { + pub verifier_circuit: Arc, + pub def_hook_commit: Option, +} + +impl, S: AggregationSubCircuit> Circuit for InnerCircuit { + fn airs(&self) -> Vec> { + let bus_inventory = self.verifier_circuit.bus_inventory(); + let public_values_bus = bus_inventory.public_values_bus; + let cached_commit_bus = bus_inventory.cached_commit_bus; + let poseidon2_compress_bus = bus_inventory.poseidon2_compress_bus; + let pvs_air_consistency_bus = + PvsAirConsistencyBus::new(self.verifier_circuit.next_bus_idx()); + + let deferral_enabled = self.def_hook_commit.is_some(); + + let deferral_config = if deferral_enabled { + verifier::VerifierDeferralConfig::Enabled { + poseidon2_bus: poseidon2_compress_bus, + } + } else { + verifier::VerifierDeferralConfig::Disabled + }; + + let verifier_pvs_air = Arc::new(verifier::VerifierPvsAir { + public_values_bus, + cached_commit_bus, + pvs_air_consistency_bus, + deferral_config, + }) as AirRef; + + let vm_pvs_air = Arc::new(vm_pvs::VmPvsAir { + public_values_bus, + cached_commit_bus, + pvs_air_consistency_bus, + deferral_enabled, + }) as AirRef; + + let (idx2_air, post_airs): (AirRef, Vec>) = if deferral_enabled { + let def_pvs_air = Arc::new(def_pvs::DeferralPvsAir { + public_values_bus, + cached_commit_bus, + poseidon2_bus: poseidon2_compress_bus, + pvs_air_consistency_bus, + expected_def_hook_commit: self + .def_hook_commit + .expect("def_hook_commit must be set when deferral is enabled"), + }) as AirRef; + let unset_vm_pvs_air = Arc::new(unset::UnsetPvsAir { + public_values_bus, + pvs_air_consistency_bus, + air_idx: VM_PVS_AIR_ID, + num_pvs: VmPvs::::width(), + def_flag: 1, + }) as AirRef; + let unset_def_pvs_air = Arc::new(unset::UnsetPvsAir { + public_values_bus, + pvs_air_consistency_bus, + air_idx: DEF_PVS_AIR_ID, + num_pvs: DeferralPvs::::width(), + def_flag: 0, + }) as AirRef; + (def_pvs_air, vec![unset_vm_pvs_air, unset_def_pvs_air]) + } else { + let unset_dummy_air = Arc::new(unset::UnsetPvsAir { + public_values_bus, + pvs_air_consistency_bus, + air_idx: 0, + num_pvs: 0, + def_flag: 0, + }) as AirRef; + (unset_dummy_air, vec![]) + }; + + [verifier_pvs_air, vm_pvs_air, idx2_air] + .into_iter() + .chain(self.verifier_circuit.airs()) + .chain(post_airs) + .collect() + } +} + +impl, S: AggregationSubCircuit> + continuations_v2::circuit::Circuit for InnerCircuit +{ + fn airs(&self) -> Vec> { + >::airs(self) + } +} diff --git a/ceno_recursion_v2/src/circuit/inner/trace.rs b/ceno_recursion_v2/src/circuit/inner/trace.rs new file mode 100644 index 000000000..365b7eff3 --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/trace.rs @@ -0,0 +1,150 @@ +use openvm_cpu_backend::CpuBackend; +#[cfg(feature = "cuda")] +use openvm_cuda_backend::GpuBackend; +use openvm_poseidon2_air::POSEIDON2_WIDTH; +use openvm_stark_backend::prover::{AirProvingContext, ProverBackend}; +use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, DIGEST_SIZE, F}; +use p3_field::PrimeCharacteristicRing; +use verify_stark::pvs::DeferralPvs; + +use crate::system::RecursionProof; + +#[derive(Copy, Clone)] +pub enum ProofsType { + Vm, + Deferral, + Mix, + Combined, +} + +// Trait that inner and compression provers use to remain generic in PB +pub trait InnerTraceGen { + fn new(deferral_enabled: bool) -> Self; + fn generate_pre_verifier_subcircuit_ctxs( + &self, + proofs: &[RecursionProof], + proofs_type: ProofsType, + absent_trace_pvs: Option<(DeferralPvs, bool)>, + child_is_app: bool, + child_dag_commit: PB::Commitment, + ) -> (Vec>, Vec<[F; POSEIDON2_WIDTH]>); + fn generate_post_verifier_subcircuit_ctxs( + &self, + proofs: &[RecursionProof], + proofs_type: ProofsType, + child_is_app: bool, + ) -> Vec>; +} + +pub struct InnerTraceGenImpl { + pub deferral_enabled: bool, +} + +impl InnerTraceGen> for InnerTraceGenImpl { + fn new(deferral_enabled: bool) -> Self { + Self { deferral_enabled } + } + + fn generate_pre_verifier_subcircuit_ctxs( + &self, + proofs: &[RecursionProof], + proofs_type: ProofsType, + absent_trace_pvs: Option<(DeferralPvs, bool)>, + child_is_app: bool, + child_dag_commit: [F; DIGEST_SIZE], + ) -> ( + Vec>>, + Vec<[F; POSEIDON2_WIDTH]>, + ) { + let (verifier_ctx, poseidon2_inputs) = super::verifier::generate_proving_ctx( + proofs, + proofs_type, + child_is_app, + child_dag_commit, + self.deferral_enabled, + ); + let vm_ctx = super::vm_pvs::generate_proving_ctx( + proofs, + proofs_type, + child_is_app, + self.deferral_enabled, + ); + + let mut poseidon2_inputs = poseidon2_inputs; + let idx2_ctx = if self.deferral_enabled { + let (def_pvs_ctx, def_poseidon2_inputs) = super::def_pvs::generate_proving_ctx( + proofs, + proofs_type, + child_is_app, + absent_trace_pvs, + ); + poseidon2_inputs.extend_from_slice(&def_poseidon2_inputs); + def_pvs_ctx + } else { + super::unset::generate_proving_ctx(&[], child_is_app) + }; + + (vec![verifier_ctx, vm_ctx, idx2_ctx], poseidon2_inputs) + } + + fn generate_post_verifier_subcircuit_ctxs( + &self, + proofs: &[RecursionProof], + proofs_type: ProofsType, + child_is_app: bool, + ) -> Vec>> { + if !self.deferral_enabled { + return vec![]; + } + + let (vm_unset, def_unset) = match proofs_type { + ProofsType::Vm => (vec![], proofs.iter().enumerate().map(|(i, _)| i).collect()), + ProofsType::Deferral => (proofs.iter().enumerate().map(|(i, _)| i).collect(), vec![]), + ProofsType::Mix => (vec![1], vec![0]), + ProofsType::Combined => (vec![], vec![]), + }; + vec![ + super::unset::generate_proving_ctx(&vm_unset, child_is_app), + super::unset::generate_proving_ctx(&def_unset, child_is_app), + ] + } +} + +#[cfg(feature = "cuda")] +impl InnerTraceGen for InnerTraceGenImpl { + fn new(deferral_enabled: bool) -> Self { + Self { deferral_enabled } + } + + fn generate_pre_verifier_subcircuit_ctxs( + &self, + proofs: &[RecursionProof], + proofs_type: ProofsType, + absent_trace_pvs: Option<(DeferralPvs, bool)>, + child_is_app: bool, + child_dag_commit: [F; DIGEST_SIZE], + ) -> ( + Vec>, + Vec<[F; POSEIDON2_WIDTH]>, + ) { + let _ = ( + self, + proofs, + proofs_type, + absent_trace_pvs, + child_is_app, + child_dag_commit, + ); + (vec![], vec![]) + } + + fn generate_post_verifier_subcircuit_ctxs( + &self, + proofs: &[RecursionProof], + proofs_type: ProofsType, + child_is_app: bool, + ) -> Vec> { + let _ = (self, proofs, proofs_type, child_is_app); + vec![] + } +} diff --git a/ceno_recursion_v2/src/circuit/inner/unset/air.rs b/ceno_recursion_v2/src/circuit/inner/unset/air.rs new file mode 100644 index 000000000..c18d3fa81 --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/unset/air.rs @@ -0,0 +1,75 @@ +use std::borrow::Borrow; + +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::Matrix; +use recursion_circuit::bus::{PublicValuesBus, PublicValuesBusMessage}; +use stark_recursion_circuit_derive::AlignedBorrow; + +use crate::circuit::inner::bus::{PvsAirConsistencyBus, PvsAirConsistencyMessage}; + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct UnsetPvsCols { + pub proof_idx: F, + pub is_valid: F, +} + +pub struct UnsetPvsAir { + pub public_values_bus: PublicValuesBus, + pub pvs_air_consistency_bus: PvsAirConsistencyBus, + pub air_idx: usize, + pub num_pvs: usize, + pub def_flag: u32, +} + +impl BaseAir for UnsetPvsAir { + fn width(&self) -> usize { + UnsetPvsCols::::width() + } +} +impl BaseAirWithPublicValues for UnsetPvsAir {} +impl PartitionedBaseAir for UnsetPvsAir {} + +impl Air for UnsetPvsAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0).expect("window should have two elements"); + let next = main.row_slice(1).expect("window should have two elements"); + let local: &UnsetPvsCols = (*local).borrow(); + let next: &UnsetPvsCols = (*next).borrow(); + + builder.assert_bool(local.is_valid); + builder + .when_transition() + .assert_one(next.proof_idx - local.proof_idx); + + let air_idx = AB::F::from_usize(self.air_idx); + + for pv_idx in 0..self.num_pvs { + self.public_values_bus.receive( + builder, + local.proof_idx, + PublicValuesBusMessage { + air_idx, + pv_idx: AB::F::from_usize(pv_idx), + value: AB::F::ZERO, + }, + local.is_valid, + ); + } + + self.pvs_air_consistency_bus.lookup_key( + builder, + local.proof_idx, + PvsAirConsistencyMessage { + deferral_flag: AB::F::from_u32(self.def_flag), + has_verifier_pvs: AB::F::ONE, + }, + local.is_valid, + ); + } +} diff --git a/ceno_recursion_v2/src/circuit/inner/unset/mod.rs b/ceno_recursion_v2/src/circuit/inner/unset/mod.rs new file mode 100644 index 000000000..26ed4a40d --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/unset/mod.rs @@ -0,0 +1,5 @@ +mod air; +mod trace; + +pub use air::*; +pub use trace::*; diff --git a/ceno_recursion_v2/src/circuit/inner/unset/trace.rs b/ceno_recursion_v2/src/circuit/inner/unset/trace.rs new file mode 100644 index 000000000..c0b2c743b --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/unset/trace.rs @@ -0,0 +1,34 @@ +use std::borrow::BorrowMut; + +use openvm_cpu_backend::CpuBackend; +use openvm_stark_backend::prover::AirProvingContext; +use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, F}; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::dense::RowMajorMatrix; + +use crate::circuit::inner::unset::UnsetPvsCols; + +pub fn generate_proving_ctx( + unset_proof_idxs: &[usize], + child_is_app: bool, +) -> AirProvingContext> { + let num_valid = if child_is_app { + 0 + } else { + unset_proof_idxs.len() + }; + + let height = num_valid.max(1).next_power_of_two(); + let width = UnsetPvsCols::::width(); + let mut trace = vec![F::ZERO; height * width]; + let mut chunks = trace.chunks_exact_mut(width); + + for proof_idx in unset_proof_idxs.iter().take(num_valid) { + let chunk = chunks.next().unwrap(); + let cols: &mut UnsetPvsCols = chunk.borrow_mut(); + cols.is_valid = F::ONE; + cols.proof_idx = F::from_usize(*proof_idx); + } + + AirProvingContext::simple_no_pis(RowMajorMatrix::new(trace, width)) +} diff --git a/ceno_recursion_v2/src/circuit/inner/verifier/air.rs b/ceno_recursion_v2/src/circuit/inner/verifier/air.rs new file mode 100644 index 000000000..5fd7cd536 --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/verifier/air.rs @@ -0,0 +1,619 @@ +use std::{array::from_fn, borrow::Borrow}; + +use openvm_circuit_primitives::utils::{and, assert_array_eq, not}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir}; +use p3_field::{Field, PrimeCharacteristicRing}; +use p3_matrix::Matrix; +use recursion_circuit::{ + bus::{ + CachedCommitBus, CachedCommitBusMessage, Poseidon2CompressBus, Poseidon2CompressMessage, + PublicValuesBus, PublicValuesBusMessage, + }, + prelude::DIGEST_SIZE, + utils::assert_zeros, +}; +use stark_recursion_circuit_derive::AlignedBorrow; +use verify_stark::pvs::{ + CONSTRAINT_EVAL_AIR_ID, DagCommit, VERIFIER_PVS_AIR_ID, VerifierBasePvs, VerifierDefPvs, +}; + +use crate::{ + circuit::{ + CONSTRAINT_EVAL_CACHED_INDEX, + inner::bus::{PvsAirConsistencyBus, PvsAirConsistencyMessage}, + }, + utils::digests_to_poseidon2_input, +}; + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct VerifierPvsCols { + pub proof_idx: F, + pub is_valid: F, + pub has_verifier_pvs: F, + pub child_pvs: VerifierBasePvs, +} + +pub struct VerifierPvsAir { + pub public_values_bus: PublicValuesBus, + pub cached_commit_bus: CachedCommitBus, + pub pvs_air_consistency_bus: PvsAirConsistencyBus, + pub deferral_config: VerifierDeferralConfig, +} + +impl BaseAir for VerifierPvsAir { + fn width(&self) -> usize { + VerifierPvsCols::::width() + self.deferral_config.width() + } +} +impl BaseAirWithPublicValues for VerifierPvsAir { + fn num_public_values(&self) -> usize { + VerifierBasePvs::::width() + self.deferral_config.num_public_values() + } +} +impl PartitionedBaseAir for VerifierPvsAir {} + +impl Air for VerifierPvsAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = ( + main.row_slice(0).expect("window should have two elements"), + main.row_slice(1).expect("window should have two elements"), + ); + + let base_cols_width = VerifierPvsCols::::width(); + let (base_local, def_local) = local.split_at(base_cols_width); + let (base_next, def_next) = next.split_at(base_cols_width); + + let local: &VerifierPvsCols = (*base_local).borrow(); + let next: &VerifierPvsCols = (*base_next).borrow(); + + // This AIR can optionally handle deferrals, the constraints for which are defined in + // function eval_deferrals. We expect dag_commit_cond to be a boolean value that is + // true iff local and next's app, leaf, and internal-for-leaf DAG commits should be + // constrained for equality. + let (dag_commit_cond, deferral_flag, consistency_mult) = match self.deferral_config { + VerifierDeferralConfig::Enabled { poseidon2_bus } => { + let def_local: &VerifierDeferralCols = (*def_local).borrow(); + let def_next: &VerifierDeferralCols = (*def_next).borrow(); + self.eval_deferrals(builder, local, next, def_local, def_next, poseidon2_bus) + } + VerifierDeferralConfig::Disabled => { + debug_assert_eq!(def_local.len(), 0); + ( + and(local.is_valid, next.is_valid), + AB::Expr::ZERO, + AB::Expr::ONE, + ) + } + }; + + // Constrain basic features about the non-pvs columns. + builder.assert_bool(local.is_valid); + builder.when_first_row().assert_one(local.is_valid); + builder + .when_transition() + .assert_bool(local.is_valid - next.is_valid); + + builder.when_first_row().assert_zero(local.proof_idx); + builder + .when_transition() + .when(and(local.is_valid, next.is_valid)) + .assert_eq(local.proof_idx + AB::F::ONE, next.proof_idx); + + builder.assert_bool(local.has_verifier_pvs); + builder + .when(local.has_verifier_pvs) + .assert_one(local.is_valid); + + // We constrain the consistency of verifier-specific public values. We can determine + // what layer a verifier is at using the has_verifier_pvs and internal_flag columns. + // There are several cases we cover: + // - has_verifier_pvs == 0: leaf verifier, app (or deferral circuit) children + // - has_verifier_pvs == 1 && internal_flag == 0: internal verifier with leaf children + // - has_verifier_pvs == 1 && internal_flag == 1: internal_for_leaf children + // - has_verifier_pvs == 1 && internal_flag == 2: internal_recursive children + // - recursion_flag == 1: 2nd (i.e. index 1) internal_recursive layer + // - recursion_flag == 1: 3rd internal_recursive layer or beyond + // constrain the verifier pvs flags and internal_recursive_dag_tommit are the same + // across all valid rows + let both_valid = and(local.is_valid, next.is_valid); + let mut when_both_valid = builder.when(both_valid.clone()); + + when_both_valid.assert_eq(local.has_verifier_pvs, next.has_verifier_pvs); + when_both_valid.assert_eq(local.child_pvs.internal_flag, next.child_pvs.internal_flag); + when_both_valid.assert_eq( + local.child_pvs.recursion_flag, + next.child_pvs.recursion_flag, + ); + + assert_dag_commit_eq( + &mut when_both_valid, + local.child_pvs.internal_recursive_dag_commit, + next.child_pvs.internal_recursive_dag_commit, + ); + + // constrain the other commits are the same when needed + let mut when_dag_compare = builder.when(dag_commit_cond); + + assert_dag_commit_eq( + &mut when_dag_compare, + local.child_pvs.app_dag_commit, + next.child_pvs.app_dag_commit, + ); + assert_dag_commit_eq( + &mut when_dag_compare, + local.child_pvs.leaf_dag_commit, + next.child_pvs.leaf_dag_commit, + ); + assert_dag_commit_eq( + &mut when_dag_compare, + local.child_pvs.internal_for_leaf_dag_commit, + next.child_pvs.internal_for_leaf_dag_commit, + ); + + // constrain that the flags are ternary + builder.assert_tern(local.child_pvs.internal_flag); + builder.assert_tern(local.child_pvs.recursion_flag); + + // constrain that internal_flag is 2 when recursion_flag is set, and not 2 otherwise + builder + .when(local.child_pvs.recursion_flag) + .assert_eq(local.child_pvs.internal_flag, AB::F::TWO); + builder + .when( + (local.child_pvs.recursion_flag - AB::F::ONE) + * (local.child_pvs.recursion_flag - AB::F::TWO), + ) + .assert_bool(local.child_pvs.internal_flag); + + // constrain that child commits are 0 when they shouldn't be defined + let is_leaf = not(local.has_verifier_pvs); + let is_internal = local.has_verifier_pvs; + + builder + .when(is_leaf.clone()) + .assert_zero(local.child_pvs.internal_flag); + + assert_dag_commit_unset( + &mut builder.when(is_leaf.clone()), + local.child_pvs.app_dag_commit, + ); + assert_dag_commit_unset( + &mut builder.when( + (local.child_pvs.internal_flag - AB::F::ONE) + * (local.child_pvs.internal_flag - AB::F::TWO), + ), + local.child_pvs.leaf_dag_commit, + ); + assert_dag_commit_unset( + &mut builder.when(local.child_pvs.internal_flag - AB::F::TWO), + local.child_pvs.internal_for_leaf_dag_commit, + ); + assert_dag_commit_unset( + &mut builder.when(local.child_pvs.recursion_flag - AB::F::TWO), + local.child_pvs.internal_recursive_dag_commit, + ); + + // We need to receive public values from ProofShapeModule to ensure the values being read + // here are correct. This AIR will only read values if it's internal. + let verifier_pvs_id = AB::Expr::from_usize(VERIFIER_PVS_AIR_ID); + + for (pv_idx, value) in local.child_pvs.as_slice().iter().enumerate() { + self.public_values_bus.receive( + builder, + local.proof_idx, + PublicValuesBusMessage { + air_idx: verifier_pvs_id.clone(), + pv_idx: AB::Expr::from_usize(pv_idx), + value: (*value).into(), + }, + local.is_valid * is_internal, + ); + } + + // We also need to receive cached commits from ProofShapeModule. Note that the + // app/deferral circuit cached commits are received in another AIR, so only the + // internal verifier will receive them here. + let is_internal_flag_zero = (local.child_pvs.internal_flag - AB::F::ONE) + * (local.child_pvs.internal_flag - AB::F::TWO) + * AB::F::TWO.inverse(); + let is_internal_flag_one = + (AB::Expr::TWO - local.child_pvs.internal_flag) * local.child_pvs.internal_flag; + let is_recursion_flag_one = + (AB::Expr::TWO - local.child_pvs.recursion_flag) * local.child_pvs.recursion_flag; + let is_recursion_flag_two = (local.child_pvs.recursion_flag - AB::F::ONE) + * local.child_pvs.recursion_flag + * AB::F::TWO.inverse(); + let cached_commit = from_fn(|i| { + is_internal_flag_zero.clone() * local.child_pvs.app_dag_commit.cached_commit[i] + + is_internal_flag_one.clone() * local.child_pvs.leaf_dag_commit.cached_commit[i] + + is_recursion_flag_one.clone() + * local.child_pvs.internal_for_leaf_dag_commit.cached_commit[i] + + is_recursion_flag_two.clone() + * local.child_pvs.internal_recursive_dag_commit.cached_commit[i] + }); + + self.cached_commit_bus.receive( + builder, + local.proof_idx, + CachedCommitBusMessage { + air_idx: AB::Expr::from_usize(CONSTRAINT_EVAL_AIR_ID), + cached_idx: AB::Expr::from_usize(CONSTRAINT_EVAL_CACHED_INDEX), + cached_commit, + }, + local.is_valid * is_internal, + ); + + // We provide proof metadata for lookup here to ensure consistency between AIRs that + // process public values. + self.pvs_air_consistency_bus.add_key_with_lookups( + builder, + local.proof_idx, + PvsAirConsistencyMessage { + deferral_flag, + has_verifier_pvs: local.has_verifier_pvs.into(), + }, + local.is_valid * consistency_mult, + ); + + // Finally, we need to constrain that the public values this AIR produces are consistent + // with the child's. Note that we only impose constraints for layers below the current + // one - it is impossible for the current layer to know its own commit, and future layers + // will catch if we preemptively define a current or future verifier commit. + let base_pvs_width = VerifierBasePvs::::width(); + let &VerifierBasePvs::<_> { + internal_flag, + app_dag_commit, + leaf_dag_commit, + internal_for_leaf_dag_commit, + recursion_flag, + internal_recursive_dag_commit, + } = builder.public_values()[0..base_pvs_width].borrow(); + + // constrain internal_flag is 0 at the leaf level + builder + .when(and(local.is_valid, is_leaf.clone())) + .assert_zero(internal_flag); + + // constrain recursion_flag is 0 at the leaf and internal_for_leaf levels + builder + .when( + local.is_valid + * (local.child_pvs.internal_flag - AB::F::ONE) + * (local.child_pvs.internal_flag - AB::F::TWO), + ) + .assert_zero(recursion_flag); + + // constraint internal_flag is incremented properly at internal levels + builder + .when(is_internal) + .when_ne(local.child_pvs.internal_flag, AB::F::TWO) + .assert_eq(internal_flag, local.child_pvs.internal_flag + AB::F::ONE); + + // constrain app_dag_commit is set at all internal levels and matches the first row + assert_dag_commit_eq( + &mut builder.when_first_row().when(is_internal), + local.child_pvs.app_dag_commit, + app_dag_commit, + ); + + // constrain verifier-specific pvs at all internal_recursive levels + builder + .when(local.child_pvs.internal_flag) + .assert_zero(internal_flag.into() - AB::F::TWO); + assert_dag_commit_eq( + &mut builder.when_first_row().when(local.child_pvs.internal_flag), + local.child_pvs.leaf_dag_commit, + leaf_dag_commit, + ); + + // constrain recursion_flag is 1 at the first internal_recursive level + builder + .when(local.child_pvs.internal_flag * (local.child_pvs.internal_flag - AB::F::TWO)) + .assert_one(recursion_flag); + + // constrain verifier-specific pvs at internal_recursive levels after the first + builder + .when(local.child_pvs.recursion_flag) + .assert_eq(recursion_flag, AB::F::TWO); + assert_dag_commit_eq( + &mut builder + .when_first_row() + .when(local.child_pvs.recursion_flag), + local.child_pvs.internal_for_leaf_dag_commit, + internal_for_leaf_dag_commit, + ); + + // constrain verifier-specific pvs at internal_recursive levels after the second + assert_dag_commit_eq( + &mut builder.when( + local.child_pvs.recursion_flag * (local.child_pvs.recursion_flag - AB::F::ONE), + ), + local.child_pvs.internal_recursive_dag_commit, + internal_recursive_dag_commit, + ); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// DEFERRAL SUPPORT +/////////////////////////////////////////////////////////////////////////////// + +pub enum VerifierDeferralConfig { + Enabled { poseidon2_bus: Poseidon2CompressBus }, + Disabled, +} + +impl VerifierDeferralConfig { + pub fn width(&self) -> usize { + match self { + VerifierDeferralConfig::Enabled { .. } => VerifierDeferralCols::::width(), + VerifierDeferralConfig::Disabled => 0, + } + } + + pub fn num_public_values(&self) -> usize { + match self { + VerifierDeferralConfig::Enabled { .. } => VerifierDefPvs::::width(), + VerifierDeferralConfig::Disabled => 0, + } + } +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct VerifierDeferralCols { + pub is_last: F, + pub intermediate_def_vk_commit: [F; DIGEST_SIZE], + pub child_pvs: VerifierDefPvs, +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct VerifierCombinedPvs { + pub base: VerifierBasePvs, + pub def: VerifierDefPvs, +} + +impl VerifierPvsAir { + fn eval_deferrals( + &self, + builder: &mut AB, + base_local: &VerifierPvsCols, + base_next: &VerifierPvsCols, + def_local: &VerifierDeferralCols, + def_next: &VerifierDeferralCols, + poseidon2_bus: Poseidon2CompressBus, + ) -> (AB::Expr, AB::Expr, AB::Expr) + where + AB: AirBuilder + InteractionBuilder + AirBuilderWithPublicValues, + { + // The deferral_flag should be 0 if a proof has only VM public values defined, 1 if + // only deferral public values, and 2 if both. There are 4 valid cases: + // - All valid rows have deferral_flag == 0 + // - All valid rows have deferral_flag == 1 + // - There are exactly two rows with deferral_flag == row_idx + // - There is exactly one row with deferral_flag == 2 + let delta = def_next.child_pvs.deferral_flag - def_local.child_pvs.deferral_flag; + builder.assert_tern(def_local.child_pvs.deferral_flag); + + // constrain that is_last is correctly set on the last valid row + builder.assert_bool(def_local.is_last); + builder + .when(def_local.is_last) + .assert_one(base_local.is_valid); + builder + .when(and(base_local.is_valid, not(def_local.is_last))) + .assert_one(base_next.is_valid); + builder + .when(def_local.is_last) + .assert_zero(base_next.is_valid * base_next.proof_idx); + builder + .when_last_row() + .when(base_local.is_valid) + .assert_one(def_local.is_last); + + // constrain that delta is 0 or 1 + builder.when_transition().assert_bool(delta.clone()); + + // constrain that if deferral_flag is 1 or 2, it cannot change later (note that if + // deferral_flag is 1 or 2 there may only be 1 or 2 rows) + builder + .when_transition() + .when(def_local.child_pvs.deferral_flag) + .assert_zero(delta.clone()); + + // constrain that the 0->1 transition happens only on the first row + builder + .when_transition() + .when(base_local.proof_idx) + .assert_zero(delta.clone()); + + // constrain that if first row is 2, it must be the only valid row + builder + .when(def_local.child_pvs.deferral_flag) + .when_ne(def_local.child_pvs.deferral_flag, AB::F::ONE) + .assert_one(def_local.is_last); + + // constrain row 1 to be the last on the 0->1 transition + builder + .when_transition() + .when(delta.clone()) + .assert_one(def_next.is_last); + + // We also need to constrain the deferral-related public values. In particular, the + // def_hook_vk_commit should be defined exactly when internal_for_leaf_dag_commit + // is for deferral_flag == 1. + // constrain that delta == 1 only at some internal_recursive layer + builder + .when(delta.clone()) + .assert_eq(base_local.child_pvs.internal_flag, AB::F::TWO); + builder + .when(def_local.child_pvs.deferral_flag) + .when_ne(def_local.child_pvs.deferral_flag, AB::F::ONE) + .assert_eq(base_local.child_pvs.internal_flag, AB::F::TWO); + + // constrain that def_hook_vk_commit is unset when internal_for_leaf_dag_commit is + assert_zeros( + &mut builder.when(base_local.child_pvs.internal_flag - AB::F::TWO), + def_local.child_pvs.def_hook_vk_commit, + ); + + // constrain def_hook_vk_commit when internal_flag is 2 and deferral_flag is 1 + let half = AB::F::TWO.inverse(); + let is_def_hook_vk_defined = base_local.child_pvs.internal_flag + * (base_local.child_pvs.internal_flag - AB::Expr::ONE) + * def_local.child_pvs.deferral_flag + * (AB::Expr::TWO - def_local.child_pvs.deferral_flag) + * half; + + poseidon2_bus.lookup_key( + builder, + Poseidon2CompressMessage { + input: digests_to_poseidon2_input( + base_local.child_pvs.app_dag_commit.cached_commit, + base_local.child_pvs.leaf_dag_commit.cached_commit, + ), + output: def_local.intermediate_def_vk_commit, + }, + is_def_hook_vk_defined.clone(), + ); + + poseidon2_bus.lookup_key( + builder, + Poseidon2CompressMessage { + input: digests_to_poseidon2_input( + def_local.intermediate_def_vk_commit, + base_local + .child_pvs + .internal_for_leaf_dag_commit + .cached_commit, + ), + output: def_local.child_pvs.def_hook_vk_commit, + }, + is_def_hook_vk_defined, + ); + + // We need to receive dedeferral-specific public values from ProofShapeModule to + // ensure the values being read are correct. + let verifier_pvs_id = AB::Expr::from_usize(VERIFIER_PVS_AIR_ID); + let pvs_offset = VerifierBasePvs::::width(); + + for (pv_idx, value) in def_local.child_pvs.as_slice().iter().enumerate() { + self.public_values_bus.receive( + builder, + base_local.proof_idx, + PublicValuesBusMessage { + air_idx: verifier_pvs_id.clone(), + pv_idx: AB::Expr::from_usize(pv_idx + pvs_offset), + value: (*value).into(), + }, + base_local.is_valid * base_local.has_verifier_pvs, + ); + } + + // Finally, we need to constrain that the deferral-specific public values this AIR + // produces are consistent with the child's. + let &VerifierCombinedPvs::<_> { + base: base_pvs, + def: def_pvs, + } = builder.public_values().borrow(); + + let &VerifierBasePvs::<_> { + internal_flag, + app_dag_commit, + leaf_dag_commit, + internal_for_leaf_dag_commit, + .. + } = base_pvs.as_slice().borrow(); + + let &VerifierDefPvs::<_> { + deferral_flag, + def_hook_vk_commit, + } = def_pvs.as_slice().borrow(); + + // constrain deferral_flag either matches each row, or is 2 when delta is non-zero + builder + .when(delta.clone()) + .assert_eq(deferral_flag, AB::F::TWO); + builder + .when_ne(delta.clone(), AB::F::ONE) + .when_ne(delta.clone(), -AB::F::ONE) + .assert_eq(deferral_flag, def_local.child_pvs.deferral_flag); + + // constrain def_hook_vk_commit matches if set in child_pvs + assert_array_eq( + &mut builder + .when(base_local.child_pvs.recursion_flag) + .when(def_local.child_pvs.deferral_flag), + def_local.child_pvs.def_hook_vk_commit, + def_hook_vk_commit, + ); + + // constrain def_hook_vk_commit when internal_flag is 2 and deferral_flag is 1 + let is_def_hook_vk_defined = internal_flag.into() + * (internal_flag.into() - AB::Expr::ONE) + * deferral_flag.into() + * (AB::Expr::TWO - deferral_flag.into()) + * half; + + poseidon2_bus.lookup_key( + builder, + Poseidon2CompressMessage { + input: digests_to_poseidon2_input( + app_dag_commit.cached_commit, + leaf_dag_commit.cached_commit, + ) + .map(Into::into), + output: def_local.intermediate_def_vk_commit.map(Into::into), + }, + is_def_hook_vk_defined.clone(), + ); + + poseidon2_bus.lookup_key( + builder, + Poseidon2CompressMessage { + input: digests_to_poseidon2_input( + def_local.intermediate_def_vk_commit.map(Into::into), + internal_for_leaf_dag_commit.cached_commit.map(Into::into), + ), + output: def_hook_vk_commit.map(Into::into), + }, + is_def_hook_vk_defined, + ); + + // Finally, we need to generate some expressions for use in the outer constraints. + // dag_commit_cond is non-zero iff on a transition row and all deferral flags are + // the same, and consistency_mult is the number of lookups this AIR will receive + // on the PvsAirConsistencyBus. + let dag_commit_cond = + and(base_local.is_valid, not(def_local.is_last)) * (AB::Expr::ONE - delta); + let deferral_flag = def_local.child_pvs.deferral_flag.into(); + let consistency_mult = base_local.has_verifier_pvs + AB::Expr::ONE; + + (dag_commit_cond, deferral_flag, consistency_mult) + } +} + +fn assert_dag_commit_eq(builder: &mut AB, left: DagCommit, right: DagCommit) +where + AB: AirBuilder, + I1: Into + Copy, + I2: Into + Copy, +{ + assert_array_eq(builder, left.cached_commit, right.cached_commit); + assert_array_eq(builder, left.vk_pre_hash, right.vk_pre_hash); +} + +fn assert_dag_commit_unset(builder: &mut AB, commit: DagCommit) +where + AB: AirBuilder, + I: Into + Copy, +{ + assert_zeros(builder, commit.cached_commit); + assert_zeros(builder, commit.vk_pre_hash); +} diff --git a/ceno_recursion_v2/src/circuit/inner/verifier/mod.rs b/ceno_recursion_v2/src/circuit/inner/verifier/mod.rs new file mode 100644 index 000000000..26ed4a40d --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/verifier/mod.rs @@ -0,0 +1,5 @@ +mod air; +mod trace; + +pub use air::*; +pub use trace::*; diff --git a/ceno_recursion_v2/src/circuit/inner/verifier/trace.rs b/ceno_recursion_v2/src/circuit/inner/verifier/trace.rs new file mode 100644 index 000000000..d80a9d8ab --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/verifier/trace.rs @@ -0,0 +1,73 @@ +use std::borrow::BorrowMut; + +use openvm_cpu_backend::CpuBackend; +use openvm_poseidon2_air::POSEIDON2_WIDTH; +use openvm_stark_backend::prover::AirProvingContext; +use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, DIGEST_SIZE, F}; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::dense::RowMajorMatrix; +use verify_stark::pvs::{VerifierBasePvs, VerifierDefPvs}; + +use crate::{ + circuit::inner::{ + ProofsType, + verifier::air::{VerifierDeferralCols, VerifierPvsCols}, + }, + system::RecursionProof, +}; + +pub fn generate_proving_ctx( + proofs: &[RecursionProof], + proofs_type: ProofsType, + child_is_app: bool, + child_dag_commit: [F; DIGEST_SIZE], + deferral_enabled: bool, +) -> ( + AirProvingContext>, + Vec<[F; POSEIDON2_WIDTH]>, +) { + // TODO(recursion-proof-bridge): populate verifier trace/public values from RecursionProof. + // Any verifier-specific values not available on RecursionProof are currently zero-mocked. + let _ = (proofs, proofs_type, child_is_app, child_dag_commit); + + let rows = proofs.len().max(1).next_power_of_two(); + let width = VerifierPvsCols::::width() + + if deferral_enabled { + VerifierDeferralCols::::width() + } else { + 0 + }; + + let mut trace = vec![F::ZERO; rows * width]; + + if rows > 0 { + let first_row = &mut trace[..width]; + let base_width = VerifierPvsCols::::width(); + let (base_row, def_row) = first_row.split_at_mut(base_width); + + let cols: &mut VerifierPvsCols = base_row.borrow_mut(); + cols.proof_idx = F::ZERO; + cols.is_valid = F::ONE; + cols.has_verifier_pvs = F::ZERO; + + if deferral_enabled { + let def_cols: &mut VerifierDeferralCols = def_row.borrow_mut(); + def_cols.is_last = F::ONE; + def_cols.child_pvs.deferral_flag = F::ZERO; + } + } + + let mut num_public_values = VerifierBasePvs::::width(); + if deferral_enabled { + num_public_values += VerifierDefPvs::::width(); + } + + ( + AirProvingContext { + cached_mains: vec![], + common_main: RowMajorMatrix::new(trace, width), + public_values: vec![F::ZERO; num_public_values], + }, + vec![], + ) +} diff --git a/ceno_recursion_v2/src/circuit/inner/vm_pvs/air.rs b/ceno_recursion_v2/src/circuit/inner/vm_pvs/air.rs new file mode 100644 index 000000000..0530f5e2e --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/vm_pvs/air.rs @@ -0,0 +1,375 @@ +use std::borrow::Borrow; + +use openvm_circuit::system::connector::DEFAULT_SUSPEND_EXIT_CODE; +use openvm_circuit_primitives::utils::{and, assert_array_eq, not}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::DIGEST_SIZE; +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir}; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::Matrix; +use recursion_circuit::bus::{ + CachedCommitBus, CachedCommitBusMessage, PublicValuesBus, PublicValuesBusMessage, +}; +use stark_recursion_circuit_derive::AlignedBorrow; +use verify_stark::pvs::{VM_PVS_AIR_ID, VmPvs}; + +use crate::circuit::inner::{ + app::*, + bus::{PvsAirConsistencyBus, PvsAirConsistencyMessage}, +}; + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct VmPvsCols { + pub proof_idx: F, + pub is_valid: F, + pub is_last: F, + pub has_verifier_pvs: F, + pub child_pvs: VmPvs, +} + +pub struct VmPvsAir { + pub public_values_bus: PublicValuesBus, + pub cached_commit_bus: CachedCommitBus, + pub pvs_air_consistency_bus: PvsAirConsistencyBus, + pub deferral_enabled: bool, +} + +impl BaseAir for VmPvsAir { + fn width(&self) -> usize { + VmPvsCols::::width() + (self.deferral_enabled as usize) + } +} +impl BaseAirWithPublicValues for VmPvsAir { + fn num_public_values(&self) -> usize { + VmPvs::::width() + } +} +impl PartitionedBaseAir for VmPvsAir {} + +impl Air for VmPvsAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = ( + main.row_slice(0).expect("window should have two elements"), + main.row_slice(1).expect("window should have two elements"), + ); + + let base_cols_width = VmPvsCols::::width(); + let (base_local, def_local) = local.split_at(base_cols_width); + let (base_next, next_local) = next.split_at(base_cols_width); + + let local: &VmPvsCols = (*base_local).borrow(); + let next: &VmPvsCols = (*base_next).borrow(); + + // If deferrals are enabled, this AIR expects an additional deferral_flag column. It + // can be either 0 or 2 here, and in the latter case there can only be one row. + let (deferral_flag, has_vm_pvs) = if self.deferral_enabled { + debug_assert_eq!(def_local.len(), 1); + debug_assert_eq!(next_local.len(), 1); + self.eval_deferrals(builder, local, def_local[0], next_local[0]) + } else { + debug_assert_eq!(def_local.len(), 0); + debug_assert_eq!(next_local.len(), 0); + (AB::Expr::ZERO, AB::Expr::ONE) + }; + + // Basic constraints for non-public value columns. + // constrain all valid rows are at the beginning + builder.assert_bool(local.is_valid); + builder + .when_first_row() + .assert_eq(local.is_valid, has_vm_pvs); + builder + .when_transition() + .assert_bool(local.is_valid - next.is_valid); + + // constrain increasing proof_idx + builder.when_first_row().assert_zero(local.proof_idx); + builder + .when_transition() + .when(and(local.is_valid, next.is_valid)) + .assert_eq(local.proof_idx + AB::Expr::ONE, next.proof_idx); + + // constrain is_last, note proof_idx on the first row is 0 + builder.assert_bool(local.is_last); + builder.when(local.is_last).assert_one(local.is_valid); + builder + .when(and(local.is_valid, not(local.is_last))) + .assert_one(next.is_valid); + builder + .when(local.is_last) + .assert_zero(next.is_valid * next.proof_idx); + builder + .when_last_row() + .when(local.is_valid) + .assert_one(local.is_last); + + // constrain has_verifier_pvs, which will be compared with the other pv AIRs + builder.assert_bool(local.has_verifier_pvs); + builder + .when(local.has_verifier_pvs) + .assert_one(local.is_valid); + builder + .when(and(local.is_valid, next.is_valid)) + .assert_eq(local.has_verifier_pvs, next.has_verifier_pvs); + + // We first constrain segment adjacency, i.e. that rows in the trace are such that the + // first row is the (chronologically) first segment, and adjacent rows correspond to + // adjacent segments. + // constrain that is_terminate is the last valid proof + builder.assert_bool(local.child_pvs.is_terminate); + builder + .when(local.child_pvs.is_terminate) + .assert_one(local.is_last); + builder + .when(local.child_pvs.is_terminate) + .assert_zero(local.child_pvs.exit_code); + + // constrain that non-terminal segments exited successfully + builder + .when(and(local.is_valid, not(local.child_pvs.is_terminate))) + .assert_eq( + local.child_pvs.exit_code, + AB::F::from_u32(DEFAULT_SUSPEND_EXIT_CODE), + ); + + // when local and next are valid, constrain increasing proof_idx and adjacency + let mut when_both_valid = builder.when(and(local.is_valid, not(local.is_last))); + when_both_valid.assert_eq(local.child_pvs.final_pc, next.child_pvs.initial_pc); + assert_array_eq( + &mut when_both_valid, + local.child_pvs.final_root, + next.child_pvs.initial_root, + ); + + // We receive public values from ProofShapeModule to ensure the values being read here + // are correct. The leaf verifier reads public values from PROGRAM_AIR_ID, + // CONNECTOR_AIR_ID, and MERKLE_AID_ID while the internal verifier reads the full + // VmPvs from VM_PVS_AIR_ID. + let is_leaf = not(local.has_verifier_pvs); + let is_internal = local.has_verifier_pvs; + + let mut internal_pv_idx = 0u8; + let internal_air_id = is_internal * AB::Expr::from_usize(VM_PVS_AIR_ID); + let mut internal_pp = || { + let ret = is_internal * AB::Expr::from_u8(internal_pv_idx); + internal_pv_idx += 1; + ret + }; + + // receive program_commit + let cond_program_air_id = + is_leaf.clone() * AB::Expr::from_usize(PROGRAM_AIR_ID) + internal_air_id.clone(); + + for (didx, value) in local.child_pvs.program_commit.iter().enumerate() { + self.public_values_bus.receive( + builder, + local.proof_idx, + PublicValuesBusMessage { + air_idx: cond_program_air_id.clone(), + pv_idx: is_leaf.clone() * AB::Expr::from_usize(didx) + internal_pp(), + value: (*value).into(), + }, + local.is_valid * is_internal, + ); + } + + // receive connector public values + let cond_connector_air_id = + is_leaf.clone() * AB::Expr::from_usize(CONNECTOR_AIR_ID) + internal_air_id.clone(); + + self.public_values_bus.receive( + builder, + local.proof_idx, + PublicValuesBusMessage { + air_idx: cond_connector_air_id.clone(), + pv_idx: internal_pp(), + value: local.child_pvs.initial_pc.into(), + }, + local.is_valid, + ); + + self.public_values_bus.receive( + builder, + local.proof_idx, + PublicValuesBusMessage { + air_idx: cond_connector_air_id.clone(), + pv_idx: is_leaf.clone() + internal_pp(), + value: local.child_pvs.final_pc.into(), + }, + local.is_valid, + ); + + self.public_values_bus.receive( + builder, + local.proof_idx, + PublicValuesBusMessage { + air_idx: cond_connector_air_id.clone(), + pv_idx: is_leaf.clone() * AB::Expr::TWO + internal_pp(), + value: local.child_pvs.exit_code.into(), + }, + local.is_valid, + ); + + self.public_values_bus.receive( + builder, + local.proof_idx, + PublicValuesBusMessage { + air_idx: cond_connector_air_id.clone(), + pv_idx: is_leaf.clone() * AB::Expr::from_u8(3) + internal_pp(), + value: local.child_pvs.is_terminate.into(), + }, + local.is_valid, + ); + + // receive memory Merkle public values + let cond_merkle_air_id = + is_leaf.clone() * AB::Expr::from_usize(MERKLE_AIR_ID) + internal_air_id.clone(); + + for (didx, value) in local.child_pvs.initial_root.iter().enumerate() { + self.public_values_bus.receive( + builder, + local.proof_idx, + PublicValuesBusMessage { + air_idx: cond_merkle_air_id.clone(), + pv_idx: is_leaf.clone() * AB::Expr::from_usize(didx) + internal_pp(), + value: (*value).into(), + }, + local.is_valid, + ); + } + + for (didx, value) in local.child_pvs.final_root.iter().enumerate() { + self.public_values_bus.receive( + builder, + local.proof_idx, + PublicValuesBusMessage { + air_idx: cond_merkle_air_id.clone(), + pv_idx: is_leaf.clone() * AB::Expr::from_usize(didx + DIGEST_SIZE) + + internal_pp(), + value: (*value).into(), + }, + local.is_valid, + ); + } + + // At the leaf level, this AIR is responsible for receiving the cached trace commit + // program_commit. + self.cached_commit_bus.receive( + builder, + local.proof_idx, + CachedCommitBusMessage { + air_idx: AB::Expr::from_usize(PROGRAM_AIR_ID), + cached_idx: AB::Expr::from_usize(PROGRAM_CACHED_TRACE_INDEX), + cached_commit: local.child_pvs.program_commit.map(Into::into), + }, + local.is_valid * is_leaf, + ); + + // We look up proof metadata from VerifierPvsAir here to ensure consistency on each row. + self.pvs_air_consistency_bus.lookup_key( + builder, + local.proof_idx, + PvsAirConsistencyMessage { + deferral_flag, + has_verifier_pvs: local.has_verifier_pvs.into(), + }, + local.is_valid, + ); + + // Finally, we need to constrain that the public values this AIR produces are consistent + // with the child's. Initial output pvs must match the first row, and final output pvs + // must match the last. + let &VmPvs::<_> { + program_commit, + initial_pc, + final_pc, + exit_code, + is_terminate, + initial_root, + final_root, + } = builder.public_values().borrow(); + + // constrain first proof pvs + builder + .when_first_row() + .assert_eq(local.child_pvs.initial_pc, initial_pc); + assert_array_eq( + &mut builder.when_first_row(), + local.child_pvs.initial_root, + initial_root, + ); + + // constrain last proof pvs + builder + .when(local.is_last) + .assert_eq(local.child_pvs.final_pc, final_pc); + builder + .when(local.is_last) + .assert_eq(local.child_pvs.exit_code, exit_code); + builder + .when(local.is_last) + .assert_eq(local.child_pvs.is_terminate, is_terminate); + assert_array_eq( + &mut builder.when(local.is_last), + local.child_pvs.final_root, + final_root, + ); + + // constrain program_commit + assert_array_eq( + &mut builder.when(local.is_valid), + local.child_pvs.program_commit, + program_commit, + ); + } +} + +impl VmPvsAir { + fn eval_deferrals( + &self, + builder: &mut AB, + local: &VmPvsCols, + local_def_flag: AB::Var, + next_def_flag: AB::Var, + ) -> (AB::Expr, AB::Expr) + where + AB: AirBuilder + InteractionBuilder + AirBuilderWithPublicValues, + { + // Constrain that deferral_flag must be in {0, 1, 2}. If: + // - deferral_flag == 0: all proofs have VmPvs only, ignore deferral-related constraints + // - deferral_flag == 1: all proofs have DeferralPvs only, there should be no valid rows + // and output public values should all be 0 + // - deferral_flag == 2: there is a single child proof with both sets of pvs + builder.assert_tern(local_def_flag); + builder.assert_eq(local_def_flag, next_def_flag); + + let mut when_deferral_flag = builder.when(local_def_flag); + when_deferral_flag.assert_zero(local.proof_idx); + + let mut when_deferral_flag_two = when_deferral_flag.when_ne(local_def_flag, AB::Expr::ONE); + when_deferral_flag_two.assert_one(local.is_valid); + when_deferral_flag_two.assert_one(local.is_last); + + let mut when_deferral_flag_one = when_deferral_flag.when_ne(local_def_flag, AB::Expr::TWO); + when_deferral_flag_one.assert_zero(local.is_valid); + + let vm_pvs: &VmPvs<_> = builder.public_values().borrow(); + let vm_pvs = vm_pvs.as_slice().to_vec(); + + for value in vm_pvs { + builder + .when(local_def_flag) + .when_ne(local_def_flag, AB::Expr::TWO) + .assert_zero(value); + } + + ( + local_def_flag.into(), + (local_def_flag - AB::Expr::ONE).square(), + ) + } +} diff --git a/ceno_recursion_v2/src/circuit/inner/vm_pvs/mod.rs b/ceno_recursion_v2/src/circuit/inner/vm_pvs/mod.rs new file mode 100644 index 000000000..26ed4a40d --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/vm_pvs/mod.rs @@ -0,0 +1,5 @@ +mod air; +mod trace; + +pub use air::*; +pub use trace::*; diff --git a/ceno_recursion_v2/src/circuit/inner/vm_pvs/trace.rs b/ceno_recursion_v2/src/circuit/inner/vm_pvs/trace.rs new file mode 100644 index 000000000..1c411d586 --- /dev/null +++ b/ceno_recursion_v2/src/circuit/inner/vm_pvs/trace.rs @@ -0,0 +1,56 @@ +use openvm_cpu_backend::CpuBackend; +use openvm_stark_backend::prover::AirProvingContext; +use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, F}; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::dense::RowMajorMatrix; +use std::borrow::BorrowMut; +use verify_stark::pvs::VmPvs; + +use crate::{ + circuit::inner::{ProofsType, vm_pvs::air::VmPvsCols}, + system::RecursionProof, +}; + +pub fn generate_proving_ctx( + proofs: &[RecursionProof], + proofs_type: ProofsType, + child_is_app: bool, + deferral_enabled: bool, +) -> AirProvingContext> { + // TODO(recursion-proof-bridge): populate VM PVS from RecursionProof once projection exists. + // For now we return shape-correct zero rows and zero public values. + let _ = (proofs, proofs_type, child_is_app, deferral_enabled); + + let rows = proofs.len().max(1).next_power_of_two(); + let width = VmPvsCols::::width() + (deferral_enabled as usize); + let mut trace = vec![F::ZERO; rows * width]; + + if rows > 0 { + let first_row = &mut trace[..width]; + let (base_row, def_row) = first_row.split_at_mut(VmPvsCols::::width()); + let cols: &mut VmPvsCols = base_row.borrow_mut(); + cols.proof_idx = F::ZERO; + cols.is_valid = F::ONE; + cols.is_last = F::ONE; + cols.has_verifier_pvs = F::ZERO; + cols.child_pvs.is_terminate = F::ONE; + cols.child_pvs.exit_code = F::ZERO; + + if deferral_enabled { + // deferral_flag for VmPvsAir is 0 or 2; choose 0 for the mocked VM-only case. + def_row[0] = F::ZERO; + } + } + + let trace = RowMajorMatrix::new(trace, width); + let mut public_values = vec![F::ZERO; VmPvs::::width()]; + let pvs: &mut VmPvs = public_values.as_mut_slice().borrow_mut(); + pvs.is_terminate = F::ONE; + pvs.exit_code = F::ZERO; + + AirProvingContext { + cached_mains: vec![], + common_main: trace, + public_values, + } +} diff --git a/ceno_recursion_v2/src/circuit/mod.rs b/ceno_recursion_v2/src/circuit/mod.rs new file mode 100644 index 000000000..6d310e74e --- /dev/null +++ b/ceno_recursion_v2/src/circuit/mod.rs @@ -0,0 +1,20 @@ +use std::sync::Arc; + +use openvm_stark_backend::{AirRef, StarkProtocolConfig}; +use recursion_circuit::prelude::F; + +pub mod deferral; +pub mod inner; + +pub const CONSTRAINT_EVAL_CACHED_INDEX: usize = 0; + +// TODO: move to stark-backend-v2 +pub trait Circuit> { + fn airs(&self) -> Vec>; +} + +impl, C: Circuit> Circuit for Arc { + fn airs(&self) -> Vec> { + self.as_ref().airs() + } +} diff --git a/ceno_recursion_v2/src/continuation/mod.rs b/ceno_recursion_v2/src/continuation/mod.rs new file mode 100644 index 000000000..b38239e63 --- /dev/null +++ b/ceno_recursion_v2/src/continuation/mod.rs @@ -0,0 +1,2 @@ +pub mod prover; +pub mod tests; diff --git a/ceno_recursion_v2/src/continuation/prover/inner/mod.rs b/ceno_recursion_v2/src/continuation/prover/inner/mod.rs new file mode 100644 index 000000000..1e3b47dd3 --- /dev/null +++ b/ceno_recursion_v2/src/continuation/prover/inner/mod.rs @@ -0,0 +1,247 @@ +use std::sync::Arc; + +use ceno_zkvm::scheme::ZKVMProof; +use continuations_v2::SC; +use eyre::Result; +use mpcs::{Basefold, BasefoldRSParams}; +use openvm_poseidon2_air::POSEIDON2_WIDTH; +use openvm_stark_backend::{ + StarkEngine, + keygen::types::{MultiStarkProvingKey, MultiStarkVerifyingKey}, + proof::Proof, + prover::{CommittedTraceData, DeviceMultiStarkProvingKey, ProverBackend, ProvingContext}, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::{ + Digest, EF, F, default_duplex_sponge_recorder, +}; +use p3_field::PrimeCharacteristicRing; +use verify_stark::pvs::DeferralPvs; + +use crate::{ + circuit::{ + Circuit, + inner::{InnerCircuit, InnerTraceGen, ProofsType}, + }, + system::{ + AggregationSubCircuit, RecursionField, RecursionVk, VerifierConfig, VerifierExternalData, + VerifierTraceGen, + }, +}; + +pub use continuations_v2::prover::ChildVkKind; +use continuations_v2::prover::debug_constraints; +use openvm_stark_backend::prover::DeviceDataTransporter; + +pub use openvm_stark_backend::SystemParams; + +/// Forked inner prover that will bridge Ceno ZKVM proofs with OpenVM recursion. +pub struct InnerAggregationProver< + PB: ProverBackend, + S: AggregationSubCircuit + VerifierTraceGen, + T: InnerTraceGen, +> { + pk: Arc>, + d_pk: DeviceMultiStarkProvingKey, + vk: Arc>, + + agg_node_tracegen: T, + + child_vk: Arc, + child_vk_pcs_data: CommittedTraceData, + circuit: Arc>, + + self_vk_pcs_data: Option>, +} + +impl< + PB: ProverBackend, + S: AggregationSubCircuit + VerifierTraceGen, + T: InnerTraceGen, +> InnerAggregationProver +{ + pub fn new>( + child_vk: Arc, + system_params: SystemParams, + is_self_recursive: bool, + def_hook_commit: Option, + ) -> Self { + let verifier_circuit = S::new( + child_vk.clone(), + VerifierConfig { + continuations_enabled: true, + ..Default::default() + }, + ); + let engine = Eg::new(system_params); + let child_vk_pcs_data = verifier_circuit.commit_child_vk(&engine, &child_vk); + let circuit = Arc::new(InnerCircuit::new( + Arc::new(verifier_circuit), + def_hook_commit.map(|d| d.into()), + )); + let (pk, vk) = engine.keygen(&circuit.airs()); + let d_pk = engine.device().transport_pk_to_device(&pk); + let self_vk_pcs_data = if is_self_recursive { + unimplemented!( + "Self-recursive inner prover support requires converting the local VK into RecursionVk" + ) + } else { + None + }; + let agg_node_tracegen = T::new(def_hook_commit.is_some()); + Self { + pk: Arc::new(pk), + d_pk, + vk: Arc::new(vk), + agg_node_tracegen, + child_vk, + child_vk_pcs_data, + circuit, + self_vk_pcs_data, + } + } + + #[allow(dead_code)] + pub fn from_pk>( + child_vk: Arc, + pk: Arc>, + is_self_recursive: bool, + def_hook_commit: Option, + ) -> Self { + let verifier_circuit = S::new( + child_vk.clone(), + VerifierConfig { + continuations_enabled: true, + ..Default::default() + }, + ); + let engine = Eg::new(pk.params.clone()); + let child_vk_pcs_data = verifier_circuit.commit_child_vk(&engine, &child_vk); + let circuit = Arc::new(InnerCircuit::new( + Arc::new(verifier_circuit), + def_hook_commit.map(|d| d.into()), + )); + let vk = Arc::new(pk.get_vk()); + let d_pk = engine.device().transport_pk_to_device(&pk); + let self_vk_pcs_data = if is_self_recursive { + unimplemented!( + "Self-recursive inner prover support requires converting the local VK into RecursionVk" + ) + } else { + None + }; + let agg_node_tracegen = T::new(def_hook_commit.is_some()); + Self { + pk, + d_pk, + vk, + agg_node_tracegen, + child_vk, + child_vk_pcs_data, + circuit, + self_vk_pcs_data, + } + } +} + +impl< + PB: ProverBackend, + S: AggregationSubCircuit + VerifierTraceGen, + T: InnerTraceGen, +> InnerAggregationProver +where + PB::Matrix: Clone, +{ + pub fn agg_prove_no_def>( + &self, + proofs: &[ZKVMProof>], + child_vk_kind: ChildVkKind, + ) -> Result> { + let ctx = self.generate_proving_ctx(proofs, child_vk_kind, ProofsType::Vm, None); + if tracing::enabled!(tracing::Level::DEBUG) { + // TODO enable trace height + // trace_heights_tracing_info::<_, SC>(&ctx.per_trace, &self.circuit.airs()); + } + + let engine = E::new(self.pk.params.clone()); + #[cfg(debug_assertions)] + debug_constraints(&self.circuit, &ctx, &engine); + let proof = engine.prove(&self.d_pk, ctx)?; + #[cfg(debug_assertions)] + engine.verify(&self.vk, &proof)?; + Ok(proof) + } + + fn generate_proving_ctx( + &self, + proofs: &[ZKVMProof>], + child_vk_kind: ChildVkKind, + proofs_type: ProofsType, + absent_trace_pvs: Option<(DeferralPvs, bool)>, + ) -> ProvingContext { + assert!(proofs.len() <= self.circuit.verifier_circuit.max_num_proofs()); + + let (child_vk, child_vk_pcs_data) = match child_vk_kind { + ChildVkKind::RecursiveSelf => { + unimplemented!("RecursiveSelf proving is not wired for RecursionVk yet") + } + _ => (&self.child_vk, self.child_vk_pcs_data.clone()), + }; + let child_is_app = matches!(child_vk_kind, ChildVkKind::App); + + let (pre_ctxs, poseidon2_compress_inputs) = self + .agg_node_tracegen + .generate_pre_verifier_subcircuit_ctxs( + proofs, + proofs_type, + absent_trace_pvs, + child_is_app, + child_vk_pcs_data.commitment, + ); + + let poseidon2_permute_inputs: Vec<[F; POSEIDON2_WIDTH]> = vec![]; + let range_check_inputs = vec![]; + let mut external_data = VerifierExternalData { + poseidon2_compress_inputs: &poseidon2_compress_inputs, + poseidon2_permute_inputs: &poseidon2_permute_inputs, + range_check_inputs: &range_check_inputs, + required_heights: None, + final_transcript_state: None, + }; + + let subcircuit_ctxs = self + .circuit + .verifier_circuit + .generate_proving_ctxs( + child_vk, + child_vk_pcs_data.clone(), + proofs, + &mut external_data, + default_duplex_sponge_recorder(), + ) + .expect("verifier sub-circuit ctx generation"); + + let post_ctxs = self + .agg_node_tracegen + .generate_post_verifier_subcircuit_ctxs(proofs, proofs_type, child_is_app); + + ProvingContext { + per_trace: pre_ctxs + .into_iter() + .chain(subcircuit_ctxs) + .chain(post_ctxs) + .enumerate() + .collect(), + } + } + + pub fn get_vk(&self) -> Arc> { + self.vk.clone() + } + + pub fn get_self_vk_pcs_data(&self) -> Option> + where + CommittedTraceData: Clone, + { + self.self_vk_pcs_data.clone() + } +} diff --git a/ceno_recursion_v2/src/continuation/prover/mod.rs b/ceno_recursion_v2/src/continuation/prover/mod.rs new file mode 100644 index 000000000..fd7698483 --- /dev/null +++ b/ceno_recursion_v2/src/continuation/prover/mod.rs @@ -0,0 +1,11 @@ +use continuations_v2::SC; +use openvm_cpu_backend::CpuBackend; + +use crate::{circuit::inner::InnerTraceGenImpl, system::VerifierSubCircuit}; + +mod inner; + +pub use inner::*; + +pub type InnerCpuProver = + InnerAggregationProver, VerifierSubCircuit, InnerTraceGenImpl>; diff --git a/ceno_recursion_v2/src/continuation/tests/mod.rs b/ceno_recursion_v2/src/continuation/tests/mod.rs new file mode 100644 index 000000000..5bb48c4ab --- /dev/null +++ b/ceno_recursion_v2/src/continuation/tests/mod.rs @@ -0,0 +1,52 @@ +#[cfg(test)] +mod prover_integration { + use crate::{ + continuation::prover::{ChildVkKind, InnerCpuProver}, + system::utils::test_system_params_zero_pow, + }; + use bincode; + use ceno_zkvm::{scheme::ZKVMProof, structs::ZKVMVerifyingKey}; + use eyre::Result; + use mpcs::{Basefold, BasefoldRSParams}; + use openvm_stark_sdk::{ + config::baby_bear_poseidon2::{BabyBearPoseidon2CpuEngine, DuplexSponge}, + p3_baby_bear::BabyBear, + }; + use p3::field::extension::BinomialExtensionField; + use std::{fs::File, sync::Arc}; + + type Engine = BabyBearPoseidon2CpuEngine; + type E = BinomialExtensionField; + + #[test] + fn leaf_app_proof_round_trip_placeholder() -> Result<()> { + let proof_path = "./src/imported/proof.bin"; + let vk_path = "./src/imported/vk.bin"; + + let zkvm_proofs: Vec>> = + bincode::deserialize_from(File::open(proof_path).expect("open proof file")) + .expect("deserialize zkvm proofs"); + + let child_vk: ZKVMVerifyingKey> = + bincode::deserialize_from(File::open(vk_path).expect("open vk file")) + .expect("deserialize vk file"); + + const MAX_NUM_PROOFS: usize = 1; + let system_params = test_system_params_zero_pow(5, 16, 3); + let leaf_prover = InnerCpuProver::::new::( + Arc::new(child_vk), + system_params, + false, + None, + ); + + let _leaf_proof = leaf_prover.agg_prove_no_def::(&zkvm_proofs, ChildVkKind::App)?; + let overall_size = bincode::serialized_size(&_leaf_proof).expect("serialization error"); + println!("proof size {:.2}mb.", byte_to_mb(overall_size)); + Ok(()) + } + + fn byte_to_mb(byte_size: u64) -> f64 { + byte_size as f64 / (1024.0 * 1024.0) + } +} diff --git a/ceno_recursion_v2/src/cuda/mod.rs b/ceno_recursion_v2/src/cuda/mod.rs new file mode 100644 index 000000000..88e5d0195 --- /dev/null +++ b/ceno_recursion_v2/src/cuda/mod.rs @@ -0,0 +1,30 @@ +use openvm_cuda_common::{copy::MemCopyH2D, d_buffer::DeviceBuffer, error::MemCopyError}; +use recursion_circuit::system::GlobalTraceGenCtx; + +pub mod preflight; +pub mod proof; +pub mod types; +pub mod vk; + +pub use preflight::PreflightGpu; +pub use proof::ProofGpu; +pub use vk::VerifyingKeyGpu; + +pub struct GlobalCtxGpu; + +impl GlobalTraceGenCtx for GlobalCtxGpu { + type ChildVerifyingKey = VerifyingKeyGpu; + type MultiProof = [ProofGpu]; + type PreflightRecords = [PreflightGpu]; +} + +pub fn to_device_or_nullptr(h2d: &[T]) -> Result, MemCopyError> +where + [T]: MemCopyH2D, +{ + if h2d.is_empty() { + Ok(DeviceBuffer::new()) + } else { + h2d.to_device() + } +} diff --git a/ceno_recursion_v2/src/cuda/preflight.rs b/ceno_recursion_v2/src/cuda/preflight.rs new file mode 100644 index 000000000..22f7fddd7 --- /dev/null +++ b/ceno_recursion_v2/src/cuda/preflight.rs @@ -0,0 +1,127 @@ +use openvm_cuda_backend::prelude::EF; +use openvm_cuda_common::d_buffer::DeviceBuffer; +use openvm_stark_sdk::config::baby_bear_poseidon2::Digest; + +use crate::system::{Preflight, RecursionProof, RecursionVk}; + +use super::{ + to_device_or_nullptr, + types::{TraceHeight, TraceMetadata}, +}; + +#[derive(Debug)] +pub struct PreflightGpu { + pub cpu: Preflight, + pub transcript: TranscriptLog, + pub proof_shape: ProofShapePreflightGpu, + pub gkr: TowerPreflightGpu, + pub batch_constraint: BatchConstraintPreflightGpu, + pub stacking: StackingPreflightGpu, + pub whir: WhirPreflightGpu, +} + +#[derive(Debug, Clone, Default)] +pub struct TranscriptLog { + _dummy: usize, +} + +#[derive(Debug)] +pub struct ProofShapePreflightGpu { + pub sorted_trace_heights: DeviceBuffer, + pub sorted_trace_metadata: DeviceBuffer, + pub sorted_cached_commits: DeviceBuffer, + pub per_row_tidx: DeviceBuffer, + pub pvs_tidx: DeviceBuffer, + pub post_tidx: usize, + pub num_present: usize, + pub n_max: usize, + pub n_logup: usize, + pub final_cidx: usize, + pub final_total_interactions: usize, + pub main_commit: Digest, +} + +#[derive(Debug, Clone, Default)] +pub struct TowerPreflightGpu { + _dummy: usize, +} + +#[derive(Debug)] +pub struct BatchConstraintPreflightGpu { + pub sumcheck_rnd: DeviceBuffer, +} + +#[derive(Debug)] +pub struct StackingPreflightGpu { + pub sumcheck_rnd: DeviceBuffer, +} + +#[derive(Debug, Clone, Default)] +pub struct WhirPreflightGpu { + _dummy: usize, +} + +impl PreflightGpu { + pub fn new(vk: &RecursionVk, proof: &RecursionProof, preflight: &Preflight) -> Self { + PreflightGpu { + cpu: preflight.clone(), + transcript: Self::transcript(preflight), + proof_shape: Self::proof_shape(vk, proof, preflight), + gkr: Self::gkr(preflight), + batch_constraint: Self::batch_constraint(preflight), + stacking: Self::stacking(preflight), + whir: Self::whir(preflight), + } + } + + fn transcript(_preflight: &Preflight) -> TranscriptLog { + TranscriptLog { _dummy: 0 } + } + + fn proof_shape( + _vk: &RecursionVk, + _proof: &RecursionProof, + _preflight: &Preflight, + ) -> ProofShapePreflightGpu { + let empty_heights: [TraceHeight; 0] = []; + let empty_metadata: [TraceMetadata; 0] = []; + let empty_commits: [Digest; 0] = []; + let empty_indices: [usize; 0] = []; + ProofShapePreflightGpu { + sorted_trace_heights: to_device_or_nullptr(&empty_heights).unwrap(), + sorted_trace_metadata: to_device_or_nullptr(&empty_metadata).unwrap(), + sorted_cached_commits: to_device_or_nullptr(&empty_commits).unwrap(), + per_row_tidx: to_device_or_nullptr(&empty_indices).unwrap(), + pvs_tidx: to_device_or_nullptr(&empty_indices).unwrap(), + post_tidx: 0, + num_present: 0, + n_max: 0, + n_logup: 0, + final_cidx: 0, + final_total_interactions: 0, + main_commit: Digest::default(), + } + } + + fn gkr(_preflight: &Preflight) -> TowerPreflightGpu { + TowerPreflightGpu { _dummy: 0 } + } + + fn batch_constraint(_preflight: &Preflight) -> BatchConstraintPreflightGpu { + let empty: [EF; 0] = []; + BatchConstraintPreflightGpu { + sumcheck_rnd: to_device_or_nullptr(&empty).unwrap(), + } + } + + fn stacking(_preflight: &Preflight) -> StackingPreflightGpu { + let empty: [EF; 0] = []; + StackingPreflightGpu { + sumcheck_rnd: to_device_or_nullptr(&empty).unwrap(), + } + } + + fn whir(_preflight: &Preflight) -> WhirPreflightGpu { + WhirPreflightGpu { _dummy: 0 } + } +} diff --git a/ceno_recursion_v2/src/cuda/proof.rs b/ceno_recursion_v2/src/cuda/proof.rs new file mode 100644 index 000000000..802a452a0 --- /dev/null +++ b/ceno_recursion_v2/src/cuda/proof.rs @@ -0,0 +1,70 @@ +use openvm_cuda_common::d_buffer::DeviceBuffer; + +use crate::system::{RecursionProof, RecursionVk}; + +use super::{to_device_or_nullptr, types::PublicValueData}; + +pub struct ProofGpu { + pub cpu: RecursionProof, + pub proof_shape: ProofShapeProofGpu, + pub gkr: TowerProofGpu, + pub batch_constraint: BatchConstraintProofGpu, + pub stacking: StackingProofGpu, + pub whir: WhirProofGpu, +} + +pub struct ProofShapeProofGpu { + pub public_values: DeviceBuffer, +} + +pub struct TowerProofGpu { + _dummy: usize, +} + +pub struct BatchConstraintProofGpu { + _dummy: usize, +} + +pub struct StackingProofGpu { + _dummy: usize, +} + +pub struct WhirProofGpu { + _dummy: usize, +} + +impl ProofGpu { + pub fn new(_vk: &RecursionVk, proof: &RecursionProof) -> Self { + ProofGpu { + cpu: proof.clone(), + proof_shape: Self::proof_shape(), + gkr: Self::gkr(proof), + batch_constraint: Self::batch_constraint(proof), + stacking: Self::stacking(proof), + whir: Self::whir(proof), + } + } + + fn proof_shape() -> ProofShapeProofGpu { + let empty: [PublicValueData; 0] = []; + ProofShapeProofGpu { + public_values: to_device_or_nullptr(&empty).unwrap(), + } + } + + fn gkr(_proof: &RecursionProof) -> TowerProofGpu { + TowerProofGpu { _dummy: 0 } + } + + fn batch_constraint(_proof: &RecursionProof) -> BatchConstraintProofGpu { + BatchConstraintProofGpu { _dummy: 0 } + } + + fn stacking(_proof: &RecursionProof) -> StackingProofGpu { + StackingProofGpu { _dummy: 0 } + } + + fn whir(_proof: &RecursionProof) -> WhirProofGpu { + WhirProofGpu { _dummy: 0 } + } +} diff --git a/ceno_recursion_v2/src/cuda/types.rs b/ceno_recursion_v2/src/cuda/types.rs new file mode 100644 index 000000000..d5c7502e6 --- /dev/null +++ b/ceno_recursion_v2/src/cuda/types.rs @@ -0,0 +1,37 @@ +use openvm_stark_sdk::config::baby_bear_poseidon2::F; + +#[repr(C)] +#[derive(Debug, Default)] +pub struct TraceHeight { + pub air_idx: usize, + pub log_height: u8, +} + +#[repr(C)] +#[derive(Debug, Default)] +pub struct TraceMetadata { + pub cached_idx: usize, + pub starting_cidx: usize, + pub total_interactions: usize, + pub num_air_id_lookups: usize, +} + +#[repr(C)] +#[derive(Debug, Default)] +pub struct PublicValueData { + pub air_idx: usize, + pub air_num_pvs: usize, + pub num_airs: usize, + pub pv_idx: usize, + pub value: F, +} + +#[repr(C)] +#[derive(Debug, Default)] +pub struct AirData { + pub num_cached: usize, + pub num_interactions_per_row: usize, + pub total_width: usize, + pub has_preprocessed: bool, + pub need_rot: bool, +} diff --git a/ceno_recursion_v2/src/cuda/vk.rs b/ceno_recursion_v2/src/cuda/vk.rs new file mode 100644 index 000000000..0d45b000c --- /dev/null +++ b/ceno_recursion_v2/src/cuda/vk.rs @@ -0,0 +1,46 @@ +use openvm_cuda_common::d_buffer::DeviceBuffer; +use openvm_stark_backend::{ + SystemParams, WhirProximityStrategy, interaction::LogUpSecurityParameters, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::{DIGEST_SIZE, Digest, F}; + +use crate::system::RecursionVk; + +use super::types::AirData; + +pub struct VerifyingKeyGpu { + pub cpu: RecursionVk, + pub per_air: DeviceBuffer, + pub system_params: SystemParams, + pub pre_hash: [F; DIGEST_SIZE], +} + +impl VerifyingKeyGpu { + pub fn new(vk: &RecursionVk) -> Self { + Self { + cpu: vk.clone(), + per_air: DeviceBuffer::new(), + system_params: placeholder_system_params(), + pre_hash: Digest::default(), + } + } +} + +fn placeholder_system_params() -> SystemParams { + SystemParams::new( + 1, // log_blowup + 1, // l_skip + 1, // n_stack + 1, // w_stack + 1, // log_final_poly_len + 1, // folding_pow_bits + 1, // mu_pow_bits + WhirProximityStrategy::UniqueDecoding, + 80, // security bits + LogUpSecurityParameters { + max_interaction_count: 1, + log_max_message_length: 1, + pow_bits: 0, + }, + ) +} diff --git a/ceno_recursion_v2/src/lib.rs b/ceno_recursion_v2/src/lib.rs new file mode 100644 index 000000000..e006612e7 --- /dev/null +++ b/ceno_recursion_v2/src/lib.rs @@ -0,0 +1,19 @@ +pub mod batch_constraint; +pub mod bn254; +pub mod circuit; +pub mod continuation; +pub mod main; +pub mod proof_shape; +pub mod system; +pub mod tower; +pub mod tracegen; +pub mod transcript; +pub mod utils; + +#[cfg(feature = "cuda")] +pub mod cuda; + +pub mod bus; +pub use recursion_circuit::{primitives, subairs}; + +pub use recursion_circuit::define_typed_per_proof_permutation_bus; diff --git a/ceno_recursion_v2/src/main/air.rs b/ceno_recursion_v2/src/main/air.rs new file mode 100644 index 000000000..d94927630 --- /dev/null +++ b/ceno_recursion_v2/src/main/air.rs @@ -0,0 +1,126 @@ +use core::borrow::Borrow; + +use openvm_circuit_primitives::{SubAir, utils::assert_array_eq}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::Field; +use p3_matrix::Matrix; +use recursion_circuit::subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}; +use stark_recursion_circuit_derive::AlignedBorrow; + +use crate::bus::{ + MainBus, MainExpressionClaimBus, MainExpressionClaimMessage, MainMessage, MainSumcheckInputBus, + MainSumcheckInputMessage, MainSumcheckOutputBus, MainSumcheckOutputMessage, +}; + +#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct MainCols { + pub is_enabled: T, + pub proof_idx: T, + pub idx: T, + pub is_first_idx: T, + pub is_first: T, + pub tidx: T, + pub claim_in: [T; D_EF], + pub claim_out: [T; D_EF], +} + +pub struct MainAir { + pub main_bus: MainBus, + pub sumcheck_input_bus: MainSumcheckInputBus, + pub sumcheck_output_bus: MainSumcheckOutputBus, + pub expression_claim_bus: MainExpressionClaimBus, +} + +impl BaseAir for MainAir { + fn width(&self) -> usize { + MainCols::::width() + } +} + +impl BaseAirWithPublicValues for MainAir {} +impl PartitionedBaseAir for MainAir {} + +impl Air for MainAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local_row, next_row) = ( + main.row_slice(0).expect("window should have two elements"), + main.row_slice(1).expect("window should have two elements"), + ); + let local: &MainCols = (*local_row).borrow(); + let next: &MainCols = (*next_row).borrow(); + + type LoopSubAir = NestedForLoopSubAir<2>; + LoopSubAir {}.eval( + builder, + ( + NestedForLoopIoCols { + is_enabled: local.is_enabled, + counter: [local.proof_idx, local.idx], + is_first: [local.is_first_idx, local.is_first], + } + .map_into(), + NestedForLoopIoCols { + is_enabled: next.is_enabled, + counter: [next.proof_idx, next.idx], + is_first: [next.is_first_idx, next.is_first], + } + .map_into(), + ), + ); + + let receive_mask = local.is_enabled * local.is_first; + self.main_bus.receive( + builder, + local.proof_idx, + MainMessage { + idx: local.idx.into(), + tidx: local.tidx.into(), + claim: local.claim_in.map(Into::into), + }, + receive_mask, + ); + + self.sumcheck_input_bus.send( + builder, + local.proof_idx, + MainSumcheckInputMessage { + idx: local.idx.into(), + tidx: local.tidx.into(), + claim: local.claim_in.map(Into::into), + }, + local.is_enabled, + ); + + self.sumcheck_output_bus.receive( + builder, + local.proof_idx, + MainSumcheckOutputMessage { + idx: local.idx.into(), + claim: local.claim_out.map(Into::into), + }, + local.is_enabled, + ); + + assert_array_eq( + &mut builder.when(local.is_enabled), + local.claim_in, + local.claim_out, + ); + + self.expression_claim_bus.send( + builder, + local.proof_idx, + MainExpressionClaimMessage { + idx: local.idx.into(), + claim: local.claim_out.map(Into::into), + }, + local.is_enabled * local.is_first, + ); + } +} diff --git a/ceno_recursion_v2/src/main/mod.rs b/ceno_recursion_v2/src/main/mod.rs new file mode 100644 index 000000000..08cd3238b --- /dev/null +++ b/ceno_recursion_v2/src/main/mod.rs @@ -0,0 +1,314 @@ +mod air; +mod sumcheck; +mod trace; + +use std::sync::Arc; + +use ceno_zkvm::scheme::ZKVMChipProof; +use eyre::{Result, bail, eyre}; +use openvm_cpu_backend::CpuBackend; +use openvm_poseidon2_air::POSEIDON2_WIDTH; +use openvm_stark_backend::{ + AirRef, FiatShamirTranscript, ReadOnlyTranscript, StarkProtocolConfig, TranscriptHistory, + prover::AirProvingContext, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, EF, F}; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::dense::RowMajorMatrix; + +use self::{ + air::MainAir, + sumcheck::{ + MainSumcheckAir, MainSumcheckRecord, MainSumcheckRoundRecord, MainSumcheckTraceGenerator, + }, + trace::{MainRecord, MainTraceGenerator}, +}; +use crate::{ + bus::{MainBus, MainExpressionClaimBus, MainSumcheckInputBus, MainSumcheckOutputBus}, + system::{ + AirModule, BusIndexManager, BusInventory, ChipTranscriptRange, GlobalCtxCpu, Preflight, + RecursionField, RecursionProof, RecursionVk, TraceGenModule, + }, + tower::convert_logup_claim, + tracegen::{ModuleChip, RowMajorChip}, +}; + +pub use air::MainCols; +pub use sumcheck::MainSumcheckCols; + +#[derive(Clone)] +pub struct MainModule { + main_bus: MainBus, + sumcheck_input_bus: MainSumcheckInputBus, + sumcheck_output_bus: MainSumcheckOutputBus, + expression_claim_bus: MainExpressionClaimBus, +} + +impl MainModule { + pub fn new(b: &mut BusIndexManager, bus_inventory: BusInventory) -> Self { + let _ = b; + let main_bus = bus_inventory.main_bus; + let sumcheck_input_bus = bus_inventory.main_sumcheck_input_bus; + let sumcheck_output_bus = bus_inventory.main_sumcheck_output_bus; + let expression_claim_bus = bus_inventory.main_expression_claim_bus; + Self { + main_bus, + sumcheck_input_bus, + sumcheck_output_bus, + expression_claim_bus, + } + } + + fn collect_records( + &self, + _child_vk: &RecursionVk, + proofs: &[RecursionProof], + preflights: &[Preflight], + ) -> Result> { + if proofs.len() != preflights.len() { + bail!( + "proof/preflight length mismatch ({} proofs vs {} preflights)", + proofs.len(), + preflights.len() + ); + } + + let mut paired = Vec::new(); + for (proof_idx, (proof, preflight)) in proofs.iter().zip(preflights).enumerate() { + let mut chip_pf_iter = preflight.main.chips.iter(); + let mut saw_chip = false; + for (&chip_idx, chip_instances) in &proof.chip_proofs { + for (instance_idx, chip_proof) in chip_instances.iter().enumerate() { + saw_chip = true; + let pf_entry = chip_pf_iter + .next() + .ok_or_else(|| eyre!( + "missing main preflight entry for chip {chip_idx} instance {instance_idx}" + ))?; + if pf_entry.chip_idx != chip_idx || pf_entry.instance_idx != instance_idx { + bail!( + "main preflight chip mismatch: expected ({}, {}), got ({}, {})", + chip_idx, + instance_idx, + pf_entry.chip_idx, + pf_entry.instance_idx + ); + } + let claim = input_layer_claim(chip_proof); + let mut ts = ReadOnlyTranscript::new(&preflight.transcript, pf_entry.tidx); + record_main_transcript(&mut ts, chip_idx, chip_proof); + + let main_record = MainRecord { + proof_idx, + idx: chip_idx, + tidx: pf_entry.tidx, + claim, + }; + let sumcheck_record = build_sumcheck_record_from_chip( + proof_idx, + chip_idx, + claim, + chip_proof, + pf_entry.tidx, + ); + paired.push((main_record, sumcheck_record)); + } + } + + if !saw_chip { + paired.push(( + MainRecord { + proof_idx, + ..MainRecord::default() + }, + MainSumcheckRecord::default(), + )); + } + } + + if paired.is_empty() { + paired.push((MainRecord::default(), MainSumcheckRecord::default())); + } + + Ok(paired) + } +} + +impl AirModule for MainModule { + fn num_airs(&self) -> usize { + 2 + } + + fn airs>(&self) -> Vec> { + let main_air = MainAir { + main_bus: self.main_bus, + sumcheck_input_bus: self.sumcheck_input_bus, + sumcheck_output_bus: self.sumcheck_output_bus, + expression_claim_bus: self.expression_claim_bus, + }; + let main_sumcheck_air = MainSumcheckAir { + sumcheck_input_bus: self.sumcheck_input_bus, + sumcheck_output_bus: self.sumcheck_output_bus, + }; + vec![Arc::new(main_air) as AirRef<_>, Arc::new(main_sumcheck_air)] + } +} + +impl MainModule { + pub fn run_preflight( + &self, + child_vk: &RecursionVk, + proof: &RecursionProof, + preflight: &mut Preflight, + ts: &mut TS, + ) where + TS: FiatShamirTranscript + + TranscriptHistory, + { + let _ = (self, child_vk); + for (&chip_idx, chip_instances) in &proof.chip_proofs { + for (instance_idx, chip_proof) in chip_instances.iter().enumerate() { + let tidx = ts.len(); + record_main_transcript(ts, chip_idx, chip_proof); + preflight.main.chips.push(ChipTranscriptRange { + chip_idx, + instance_idx, + tidx, + }); + } + } + } +} + +impl> TraceGenModule> for MainModule { + type ModuleSpecificCtx<'a> = (); + + fn generate_proving_ctxs( + &self, + child_vk: &RecursionVk, + proofs: &[RecursionProof], + preflights: &[Preflight], + _ctx: &Self::ModuleSpecificCtx<'_>, + required_heights: Option<&[usize]>, + ) -> Option>>> { + let mut paired = self.collect_records(child_vk, proofs, preflights).ok()?; + paired.sort_by_key(|(record, _)| (record.proof_idx, record.idx)); + let (main_records, sumcheck_records): (Vec<_>, Vec<_>) = paired.into_iter().unzip(); + let ctx = MainTraceCtx { + main_records: &main_records, + sumcheck_records: &sumcheck_records, + }; + let chips = [MainModuleChip::Main, MainModuleChip::Sumcheck]; + let span = tracing::Span::current(); + let contexts = chips + .into_iter() + .enumerate() + .map(|(idx, chip)| { + let _guard = span.enter(); + chip.generate_proving_ctx( + &ctx, + required_heights.and_then(|heights| heights.get(idx).copied()), + ) + }) + .collect::>>()?; + + Some(contexts) + } +} + +struct MainTraceCtx<'a> { + main_records: &'a [MainRecord], + sumcheck_records: &'a [MainSumcheckRecord], +} + +enum MainModuleChip { + Main, + Sumcheck, +} + +impl RowMajorChip for MainModuleChip { + type Ctx<'a> = MainTraceCtx<'a>; + + fn generate_trace( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + match self { + MainModuleChip::Main => { + MainTraceGenerator.generate_trace(&ctx.main_records, required_height) + } + MainModuleChip::Sumcheck => { + MainSumcheckTraceGenerator.generate_trace(&ctx.sumcheck_records, required_height) + } + } + } +} + +fn input_layer_claim(chip_proof: &ZKVMChipProof) -> EF { + let layer_count = chip_proof + .tower_proof + .logup_specs_eval + .iter() + .map(|spec_layers| spec_layers.len()) + .chain( + chip_proof + .tower_proof + .prod_specs_eval + .iter() + .map(|spec_layers| spec_layers.len()), + ) + .max() + .unwrap_or(0); + if layer_count == 0 { + return EF::ZERO; + } + convert_logup_claim(chip_proof, layer_count - 1)[0] +} + +fn build_sumcheck_record_from_chip( + proof_idx: usize, + chip_idx: usize, + claim: EF, + chip_proof: &ZKVMChipProof, + tidx: usize, +) -> MainSumcheckRecord { + let rounds = chip_proof + .gkr_iop_proof + .as_ref() + .and_then(|proof| proof.0.first()) + .map(|layer| { + layer + .main + .proof + .proofs + .iter() + .map(|msg| { + let mut evals = [EF::ZERO; 3]; + for (dst, src) in evals.iter_mut().zip(msg.evaluations.iter().take(3)) { + *dst = *src; + } + MainSumcheckRoundRecord { evaluations: evals } + }) + .collect::>() + }) + .unwrap_or_default(); + + MainSumcheckRecord { + proof_idx, + idx: chip_idx, + tidx, + claim, + rounds, + } +} + +fn record_main_transcript( + ts: &mut TS, + _chip_idx: usize, + chip_proof: &ZKVMChipProof, +) where + TS: FiatShamirTranscript, +{ + ts.observe_ext(input_layer_claim(chip_proof)); +} diff --git a/ceno_recursion_v2/src/main/sumcheck/air.rs b/ceno_recursion_v2/src/main/sumcheck/air.rs new file mode 100644 index 000000000..9179527b8 --- /dev/null +++ b/ceno_recursion_v2/src/main/sumcheck/air.rs @@ -0,0 +1,236 @@ +use core::borrow::Borrow; + +use openvm_circuit_primitives::{SubAir, utils::assert_array_eq}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{Field, PrimeCharacteristicRing, extension::BinomiallyExtendable}; +use p3_matrix::Matrix; +use recursion_circuit::{ + subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}, + utils::{ + assert_one_ext, ext_field_add, ext_field_multiply, ext_field_multiply_scalar, + ext_field_one_minus, ext_field_subtract, + }, +}; +use stark_recursion_circuit_derive::AlignedBorrow; + +use crate::bus::{ + MainSumcheckInputBus, MainSumcheckInputMessage, MainSumcheckOutputBus, + MainSumcheckOutputMessage, +}; + +#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct MainSumcheckCols { + pub is_enabled: T, + pub proof_idx: T, + pub idx: T, + pub is_first_idx: T, + pub is_first_round: T, + pub is_last_round: T, + pub is_dummy: T, + pub round: T, + pub tidx: T, + pub ev1: [T; D_EF], + pub ev2: [T; D_EF], + pub ev3: [T; D_EF], + pub claim_in: [T; D_EF], + pub claim_out: [T; D_EF], + pub prev_challenge: [T; D_EF], + pub challenge: [T; D_EF], + pub eq_in: [T; D_EF], + pub eq_out: [T; D_EF], +} + +pub struct MainSumcheckAir { + pub sumcheck_input_bus: MainSumcheckInputBus, + pub sumcheck_output_bus: MainSumcheckOutputBus, +} + +impl BaseAir for MainSumcheckAir { + fn width(&self) -> usize { + MainSumcheckCols::::width() + } +} + +impl BaseAirWithPublicValues for MainSumcheckAir {} +impl PartitionedBaseAir for MainSumcheckAir {} + +impl Air for MainSumcheckAir +where + AB: AirBuilder + InteractionBuilder, + ::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local_row, next_row) = ( + main.row_slice(0).expect("window should have two elements"), + main.row_slice(1).expect("window should have two elements"), + ); + let local: &MainSumcheckCols = (*local_row).borrow(); + let next: &MainSumcheckCols = (*next_row).borrow(); + + builder.assert_bool(local.is_dummy.clone()); + builder.assert_bool(local.is_last_round.clone()); + builder.assert_bool(local.is_first_round.clone()); + + type LoopSubAir = NestedForLoopSubAir<2>; + LoopSubAir {}.eval( + builder, + ( + NestedForLoopIoCols { + is_enabled: local.is_enabled, + counter: [local.proof_idx, local.idx], + is_first: [local.is_first_idx, local.is_first_round.clone()], + } + .map_into(), + NestedForLoopIoCols { + is_enabled: next.is_enabled, + counter: [next.proof_idx, next.idx], + is_first: [next.is_first_idx, next.is_first_round.clone()], + } + .map_into(), + ), + ); + + let is_transition_round = + LoopSubAir::local_is_transition(next.is_enabled, next.is_first_round.clone()); + let computed_is_last = LoopSubAir::local_is_last( + local.is_enabled, + next.is_enabled, + next.is_first_round.clone(), + ); + + builder + .when(local.is_enabled.clone()) + .assert_eq(local.is_last_round.clone(), computed_is_last.clone()); + + builder + .when(local.is_first_round.clone()) + .assert_zero(local.round); + builder + .when(is_transition_round.clone()) + .assert_eq(next.round, local.round.clone() + AB::Expr::ONE); + + builder.when(is_transition_round.clone()).assert_eq( + next.tidx, + local.tidx.clone().into() + AB::Expr::from_usize(4 * D_EF), + ); + + assert_one_ext(&mut builder.when(local.is_first_round.clone()), local.eq_in); + let eq_out = update_eq(local.eq_in, local.prev_challenge, local.challenge); + assert_array_eq( + &mut builder.when(local.is_enabled.clone()), + local.eq_out, + eq_out, + ); + assert_array_eq( + &mut builder.when(is_transition_round.clone()), + local.eq_out, + next.eq_in, + ); + + let ev0 = ext_field_subtract(local.claim_in, local.ev1); + let claim_out = + interpolate_cubic_at_0123(ev0, local.ev1, local.ev2, local.ev3, local.challenge); + assert_array_eq(builder, local.claim_out, claim_out); + assert_array_eq( + &mut builder.when(is_transition_round.clone()), + local.claim_out, + next.claim_in, + ); + + let is_not_dummy = AB::Expr::ONE - local.is_dummy.clone(); + + let receive_mask = + local.is_enabled.clone() * local.is_first_round.clone() * is_not_dummy.clone(); + self.sumcheck_input_bus.receive( + builder, + local.proof_idx, + MainSumcheckInputMessage { + idx: local.idx.into(), + tidx: local.tidx.into(), + claim: local.claim_in.map(Into::into), + }, + receive_mask, + ); + + let send_mask = local.is_enabled.clone() * local.is_last_round.clone() * is_not_dummy; + self.sumcheck_output_bus.send( + builder, + local.proof_idx, + MainSumcheckOutputMessage { + idx: local.idx.into(), + claim: local.claim_out.map(Into::into), + }, + send_mask, + ); + } +} + +fn interpolate_cubic_at_0123( + ev0: [FA; D_EF], + ev1: [F; D_EF], + ev2: [F; D_EF], + ev3: [F; D_EF], + x: [F; D_EF], +) -> [FA; D_EF] +where + F: Into + Copy, + FA: PrimeCharacteristicRing, + FA::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, +{ + let three: FA = FA::from_usize(3); + let inv2: FA = FA::from_prime_subfield(FA::PrimeSubfield::from_usize(2).inverse()); + let inv6: FA = FA::from_prime_subfield(FA::PrimeSubfield::from_usize(6).inverse()); + + let s1: [FA; D_EF] = ext_field_subtract(ev1, ev0.clone()); + let s2: [FA; D_EF] = ext_field_subtract(ev2, ev0.clone()); + let s3: [FA; D_EF] = ext_field_subtract(ev3, ev0.clone()); + + let d3: [FA; D_EF] = ext_field_subtract::( + s3, + ext_field_multiply_scalar::(ext_field_subtract::(s2.clone(), s1.clone()), three), + ); + + let p: [FA; D_EF] = ext_field_multiply_scalar(d3.clone(), inv6); + + let q: [FA; D_EF] = ext_field_subtract::( + ext_field_multiply_scalar::(ext_field_subtract::(s2, d3), inv2), + s1.clone(), + ); + + let r: [FA; D_EF] = ext_field_subtract::(s1, ext_field_add::(p.clone(), q.clone())); + + ext_field_add::( + ext_field_multiply::( + ext_field_add::( + ext_field_multiply::(ext_field_add::(ext_field_multiply::(p, x), q), x), + r, + ), + x, + ), + ev0, + ) +} + +fn update_eq(eq_in: [F; D_EF], prev_challenge: [F; D_EF], challenge: [F; D_EF]) -> [FA; D_EF] +where + F: Into + Copy, + FA: PrimeCharacteristicRing, + FA::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, +{ + ext_field_multiply::( + eq_in, + ext_field_add::( + ext_field_multiply::(prev_challenge, challenge), + ext_field_multiply::( + ext_field_one_minus::(prev_challenge), + ext_field_one_minus::(challenge), + ), + ), + ) +} diff --git a/ceno_recursion_v2/src/main/sumcheck/mod.rs b/ceno_recursion_v2/src/main/sumcheck/mod.rs new file mode 100644 index 000000000..c6f57b1e0 --- /dev/null +++ b/ceno_recursion_v2/src/main/sumcheck/mod.rs @@ -0,0 +1,5 @@ +mod air; +mod trace; + +pub use air::{MainSumcheckAir, MainSumcheckCols}; +pub use trace::{MainSumcheckRecord, MainSumcheckRoundRecord, MainSumcheckTraceGenerator}; diff --git a/ceno_recursion_v2/src/main/sumcheck/trace.rs b/ceno_recursion_v2/src/main/sumcheck/trace.rs new file mode 100644 index 000000000..772cfa0ba --- /dev/null +++ b/ceno_recursion_v2/src/main/sumcheck/trace.rs @@ -0,0 +1,113 @@ +use core::{borrow::BorrowMut, convert::TryInto}; + +use openvm_stark_sdk::config::baby_bear_poseidon2::{D_EF, EF, F}; +use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; +use p3_matrix::dense::RowMajorMatrix; + +use super::air::MainSumcheckCols; +use crate::tracegen::RowMajorChip; + +#[derive(Default, Debug, Clone)] +pub struct MainSumcheckRoundRecord { + pub evaluations: [EF; 3], +} + +#[derive(Default, Debug, Clone)] +pub struct MainSumcheckRecord { + pub proof_idx: usize, + pub idx: usize, + pub tidx: usize, + pub claim: EF, + pub rounds: Vec, +} + +impl MainSumcheckRecord { + fn total_rows(&self) -> usize { + self.rounds.len().max(1) + } +} + +pub struct MainSumcheckTraceGenerator; + +impl RowMajorChip for MainSumcheckTraceGenerator { + type Ctx<'a> = &'a [MainSumcheckRecord]; + + fn generate_trace( + &self, + records: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + let width = MainSumcheckCols::::width(); + let num_valid_rows: usize = records.iter().map(MainSumcheckRecord::total_rows).sum(); + let num_valid_rows = num_valid_rows.max(1); + let height = if let Some(height) = required_height { + if height < num_valid_rows { + return None; + } + height + } else { + num_valid_rows.next_power_of_two() + }; + + let mut trace = vec![F::ZERO; height * width]; + if records.is_empty() { + return Some(RowMajorMatrix::new(trace, width)); + } + + let zero_challenge: [F; D_EF] = EF::ZERO.as_basis_coefficients_slice().try_into().unwrap(); + let mut row_offset = 0; + + for record in records.iter() { + let rows = record.total_rows(); + let has_rounds = !record.rounds.is_empty(); + let claim_value = record.claim; + let eq_value = EF::ONE; + + for round_idx in 0..rows { + let offset = row_offset * width; + let cols_slice = &mut trace[offset..offset + width]; + let cols: &mut MainSumcheckCols = cols_slice.borrow_mut(); + + let is_first_round = round_idx == 0; + let is_last_round = round_idx + 1 == rows; + cols.is_enabled = F::ONE; + cols.proof_idx = F::from_usize(record.proof_idx); + cols.idx = F::from_usize(record.idx); + cols.is_first_idx = F::from_bool(is_first_round); + cols.is_first_round = F::from_bool(is_first_round); + cols.is_last_round = F::from_bool(is_last_round); + cols.is_dummy = F::from_bool(!has_rounds); + cols.round = F::from_usize(round_idx); + cols.tidx = F::from_usize(record.tidx + 4 * D_EF * round_idx); + + let evals = record + .rounds + .get(round_idx) + .map(|round| round.evaluations) + .unwrap_or([EF::ZERO; 3]); + cols.ev1 = evals[0].as_basis_coefficients_slice().try_into().unwrap(); + cols.ev2 = evals[1].as_basis_coefficients_slice().try_into().unwrap(); + cols.ev3 = evals[2].as_basis_coefficients_slice().try_into().unwrap(); + + let claim_in_basis: [F; D_EF] = claim_value + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.claim_in = claim_in_basis; + cols.claim_out = claim_in_basis; + + let eq_basis: [F; D_EF] = + eq_value.as_basis_coefficients_slice().try_into().unwrap(); + cols.eq_in = eq_basis; + cols.eq_out = eq_basis; + + cols.prev_challenge = zero_challenge; + cols.challenge = zero_challenge; + + row_offset += 1; + } + } + + Some(RowMajorMatrix::new(trace, width)) + } +} diff --git a/ceno_recursion_v2/src/main/trace.rs b/ceno_recursion_v2/src/main/trace.rs new file mode 100644 index 000000000..f312401e8 --- /dev/null +++ b/ceno_recursion_v2/src/main/trace.rs @@ -0,0 +1,103 @@ +use core::borrow::BorrowMut; + +use openvm_stark_sdk::config::baby_bear_poseidon2::{D_EF, EF, F}; +use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; +use p3_matrix::dense::RowMajorMatrix; + +use super::air::MainCols; +use crate::tracegen::RowMajorChip; + +#[derive(Clone, Debug, Default)] +pub struct MainRecord { + pub proof_idx: usize, + pub idx: usize, + pub tidx: usize, + pub claim: EF, +} + +pub struct MainTraceGenerator; + +impl RowMajorChip for MainTraceGenerator { + type Ctx<'a> = &'a [MainRecord]; + + fn generate_trace( + &self, + records: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + generate_trace::, _>(records, required_height, fill_main_cols) + } +} + +fn generate_trace( + records: &[MainRecord], + required_height: Option, + mut fill: Fill, +) -> Option> +where + C: ColumnAccess, + Fill: FnMut(&MainRecord, &mut C, bool), +{ + let width = C::width(); + let num_rows = records.len().max(1); + let height = if let Some(height) = required_height { + if height < num_rows { + return None; + } + height + } else { + num_rows.next_power_of_two().max(1) + }; + let mut trace = vec![F::ZERO; height * width]; + if records.is_empty() { + return Some(RowMajorMatrix::new(trace, width)); + } + + let mut prev_proof_idx = usize::MAX; + let mut prev_idx = usize::MAX; + for (row_idx, record) in records.iter().enumerate() { + let offset = row_idx * width; + let cols_slice = &mut trace[offset..offset + width]; + let cols = C::from_bytes(cols_slice); + fill( + record, + cols, + prev_proof_idx != record.proof_idx || prev_idx != record.idx, + ); + prev_proof_idx = record.proof_idx; + prev_idx = record.idx; + } + + Some(RowMajorMatrix::new(trace, width)) +} + +trait ColumnAccess: Sized { + fn width() -> usize; + fn from_bytes(slice: &mut [F]) -> &mut Self; +} + +impl ColumnAccess for MainCols { + fn width() -> usize { + MainCols::::width() + } + + fn from_bytes(slice: &mut [F]) -> &mut Self { + slice.borrow_mut() + } +} + +fn fill_main_cols(record: &MainRecord, cols: &mut MainCols, is_new_pair: bool) { + cols.is_enabled = F::ONE; + cols.proof_idx = F::from_usize(record.proof_idx); + cols.idx = F::from_usize(record.idx); + cols.is_first_idx = F::from_bool(is_new_pair); + cols.is_first = F::ONE; + cols.tidx = F::from_usize(record.tidx); + let claim_basis: [F; D_EF] = record + .claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.claim_in = claim_basis; + cols.claim_out = claim_basis; +} diff --git a/ceno_recursion_v2/src/proof_shape/bus.rs b/ceno_recursion_v2/src/proof_shape/bus.rs new file mode 100644 index 000000000..7427ae1ec --- /dev/null +++ b/ceno_recursion_v2/src/proof_shape/bus.rs @@ -0,0 +1,48 @@ +use p3_field::PrimeCharacteristicRing; +use stark_recursion_circuit_derive::AlignedBorrow; + +use crate::define_typed_per_proof_permutation_bus; + +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct ProofShapePermutationMessage { + pub idx: T, +} + +define_typed_per_proof_permutation_bus!(ProofShapePermutationBus, ProofShapePermutationMessage); + +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct StartingTidxMessage { + pub air_idx: T, + pub tidx: T, +} + +define_typed_per_proof_permutation_bus!(StartingTidxBus, StartingTidxMessage); + +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct NumPublicValuesMessage { + pub air_idx: T, + pub tidx: T, + pub num_pvs: T, +} + +define_typed_per_proof_permutation_bus!(NumPublicValuesBus, NumPublicValuesMessage); + +#[repr(u8)] +#[derive(Debug, Copy, Clone)] +pub enum AirShapeProperty { + AirId, + NumInteractions, + NeedRot, + NumRead, + NumWrite, + NumLk, +} + +impl AirShapeProperty { + pub fn to_field(self) -> T { + T::from_u8(self as u8) + } +} diff --git a/ceno_recursion_v2/src/proof_shape/cuda_abi.rs b/ceno_recursion_v2/src/proof_shape/cuda_abi.rs new file mode 100644 index 000000000..1e04d00d1 --- /dev/null +++ b/ceno_recursion_v2/src/proof_shape/cuda_abi.rs @@ -0,0 +1,101 @@ +#![allow(clippy::missing_safety_doc)] + +use openvm_cuda_backend::prelude::{Digest, F}; +use openvm_cuda_common::{d_buffer::DeviceBuffer, error::CudaError}; + +use crate::cuda::types::{AirData, PublicValueData, TraceHeight, TraceMetadata}; + +#[repr(C)] +pub(crate) struct ProofShapePerProof { + pub num_present: usize, + pub n_max: usize, + pub n_logup: usize, + pub final_cidx: usize, + pub final_total_interactions: usize, + pub main_commit: Digest, +} + +#[repr(C)] +pub(crate) struct ProofShapeTracegenInputs { + pub num_airs: usize, + pub l_skip: usize, + pub max_interaction_count: u32, + pub max_cached: usize, + pub min_cached_idx: usize, + pub pre_hash: Digest, + pub range_checker_8_ptr: *mut u32, + pub range_checker_5_ptr: *mut u32, + pub pow_checker_ptr: *mut u32, +} + +unsafe extern "C" { + fn _proof_shape_tracegen( + d_trace: *mut F, + height: usize, + d_air_data: *const AirData, + d_per_row_tidx: *const *const usize, + d_sorted_trace_heights: *const *const TraceHeight, + d_sorted_trace_metadata: *const *const TraceMetadata, + d_cached_commits: *const *const Digest, + d_per_proof: *const ProofShapePerProof, + num_proofs: usize, + inputs: *const ProofShapeTracegenInputs, + ) -> i32; + fn _public_values_recursion_tracegen( + d_trace: *mut F, + height: usize, + d_pvs_data: *const *const PublicValueData, + d_pvs_tidx: *const *const usize, + num_proofs: usize, + num_pvs: usize, + ) -> i32; +} + +#[allow(clippy::too_many_arguments)] +pub unsafe fn proof_shape_tracegen( + d_trace: &DeviceBuffer, + height: usize, + d_air_data: &DeviceBuffer, + d_per_row_tidx: Vec<*const usize>, + d_sorted_trace_heights: Vec<*const TraceHeight>, + d_sorted_trace_metadata: Vec<*const TraceMetadata>, + d_cached_commits: Vec<*const Digest>, + d_per_proof: &DeviceBuffer, + num_proofs: usize, + inputs: &ProofShapeTracegenInputs, +) -> Result<(), CudaError> { + unsafe { + CudaError::from_result(_proof_shape_tracegen( + d_trace.as_mut_ptr(), + height, + d_air_data.as_ptr(), + d_per_row_tidx.as_ptr(), + d_sorted_trace_heights.as_ptr(), + d_sorted_trace_metadata.as_ptr(), + d_cached_commits.as_ptr(), + d_per_proof.as_ptr(), + num_proofs, + inputs as *const ProofShapeTracegenInputs, + )) + } +} + +pub unsafe fn public_values_tracegen( + d_trace: &DeviceBuffer, + height: usize, + d_pvs_data: Vec<*const PublicValueData>, + d_pvs_tidx: Vec<*const usize>, + num_proofs: usize, + num_pvs: usize, +) -> Result<(), CudaError> { + unsafe { + CudaError::from_result(_public_values_recursion_tracegen( + d_trace.as_mut_ptr(), + height, + d_pvs_data.as_ptr(), + d_pvs_tidx.as_ptr(), + num_proofs, + num_pvs, + )) + } +} diff --git a/ceno_recursion_v2/src/proof_shape/mod.rs b/ceno_recursion_v2/src/proof_shape/mod.rs new file mode 100644 index 000000000..6f088e4cd --- /dev/null +++ b/ceno_recursion_v2/src/proof_shape/mod.rs @@ -0,0 +1,403 @@ +use std::sync::Arc; + +use itertools::Itertools; +use openvm_circuit_primitives::encoder::Encoder; +use openvm_cpu_backend::CpuBackend; +use openvm_stark_backend::{ + AirRef, FiatShamirTranscript, StarkProtocolConfig, TranscriptHistory, + keygen::types::VerifierSinglePreprocessedData, prover::AirProvingContext, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::{ + BabyBearPoseidon2Config, DIGEST_SIZE, Digest, F, +}; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::dense::RowMajorMatrix; + +use crate::{ + proof_shape::{ + bus::{NumPublicValuesBus, ProofShapePermutationBus, StartingTidxBus}, + proof_shape::ProofShapeAir, + pvs::PublicValuesAir, + }, + system::{ + AirModule, BusIndexManager, BusInventory, GlobalCtxCpu, POW_CHECKER_HEIGHT, Preflight, + RecursionProof, RecursionVk, TraceGenModule, TraceVData, + }, + tracegen::RowMajorChip, +}; +use recursion_circuit::primitives::{ + bus::{PowerCheckerBus, RangeCheckerBus}, + pow::PowerCheckerCpuTraceGenerator, + range::{RangeCheckerAir, RangeCheckerCols}, +}; + +pub mod bus; +#[allow(clippy::module_inception)] +pub mod proof_shape; +pub mod pvs; + +#[cfg(feature = "cuda")] +mod cuda_abi; + +#[derive(Clone)] +pub struct AirMetadata { + is_required: bool, + num_public_values: usize, + main_width: usize, + cached_widths: Vec, + num_read_count: usize, + num_write_count: usize, + num_logup_count: usize, + preprocessed_width: Option, + preprocessed_data: Option>, +} + +pub struct ProofShapeModule { + // Verifying key fields + per_air: Vec, + + // Buses (inventory for external, others are internal) + bus_inventory: BusInventory, + range_bus: RangeCheckerBus, + pow_bus: PowerCheckerBus, + permutation_bus: ProofShapePermutationBus, + starting_tidx_bus: StartingTidxBus, + num_pvs_bus: NumPublicValuesBus, + + // Required for ProofShapeAir tracegen + constraints + idx_encoder: Arc, + min_cached_idx: usize, + max_cached: usize, + commit_mult: usize, + + // Module sends extra public values message for use outside of verifier + // sub-circuit if true + continuations_enabled: bool, +} + +impl ProofShapeModule { + pub fn new( + child_vk: &RecursionVk, + b: &mut BusIndexManager, + bus_inventory: BusInventory, + continuations_enabled: bool, + ) -> Self { + let num_airs = child_vk.circuit_vks.len(); + let idx_encoder = Arc::new(Encoder::new(num_airs, 2, true)); + + let min_cached_idx = 0; + let _min_cached = 1; + let max_cached = 2; + + let per_air = extract_air_metadata_from_vk(child_vk, max_cached); + + let range_bus = bus_inventory.range_checker_bus; + let pow_bus = bus_inventory.power_checker_bus; + Self { + per_air, + bus_inventory, + range_bus, + pow_bus, + permutation_bus: ProofShapePermutationBus::new(b.new_bus_idx()), + starting_tidx_bus: StartingTidxBus::new(b.new_bus_idx()), + num_pvs_bus: NumPublicValuesBus::new(b.new_bus_idx()), + idx_encoder, + min_cached_idx, + max_cached, + commit_mult: 100, + continuations_enabled, + } + } + + #[tracing::instrument(level = "trace", skip_all)] + pub fn run_preflight( + &self, + child_vk: &RecursionVk, + proof: &RecursionProof, + preflight: &mut Preflight, + ts: &mut TS, + ) where + TS: FiatShamirTranscript + TranscriptHistory, + { + let _ = self; + + // Verifier preprocess: absorb raw public input polynomials. + for raw_pi in &proof.raw_pi { + for &value in raw_pi { + ts.observe(value); + } + } + + // Build per-air shape metadata from present chip proofs. + let mut sorted_trace_vdata = proof + .chip_proofs + .iter() + .map(|(&chip_idx, chip_instances)| { + let num_instances: usize = chip_instances + .iter() + .flat_map(|instance| instance.num_instances.iter()) + .copied() + .sum(); + let padded = num_instances.max(1).next_power_of_two(); + let log_height = padded.ilog2() as usize; + (chip_idx, TraceVData { log_height }) + }) + .collect_vec(); + sorted_trace_vdata.sort_by_key(|(air_idx, v)| (usize::MAX - v.log_height, *air_idx)); + preflight.proof_shape.sorted_trace_vdata = sorted_trace_vdata; + preflight.proof_shape.l_skip = 0; + + // Verifier preprocess: absorb (circuit_idx, num_instance...) for all chip proofs. + for (&chip_idx, chip_instances) in &proof.chip_proofs { + ts.observe(F::from_usize(chip_idx)); + for num_instance in chip_instances + .iter() + .flat_map(|instance| &instance.num_instances) + { + ts.observe(F::from_usize(*num_instance)); + } + } + + // TODO(recursion-proof-bridge): absorb fixed/witness commitments once the local + // preflight bridge can encode PCS commitments into the Fiat-Shamir transcript. + preflight.proof_shape.alpha_tidx = ts.len(); + let _alpha = FiatShamirTranscript::::sample_ext(ts); + preflight.proof_shape.beta_tidx = ts.len(); + let _beta = FiatShamirTranscript::::sample_ext(ts); + preflight.proof_shape.fork_start_tidx = ts.len(); + + let _ = child_vk; + } + + fn placeholder_air_widths(&self) -> Vec { + let proof_shape_width = proof_shape::ProofShapeCols::::width() + + self.idx_encoder.width() + + self.max_cached * DIGEST_SIZE; + let pvs_width = pvs::PublicValuesCols::::width(); + let range_width = RangeCheckerCols::::width(); + // TODO(recursion-proof-bridge): replace proof-shape module placeholder contexts with + // real tracegen so RangeCheckerAir rows are semantically valid, not only width-correct. + vec![proof_shape_width, pvs_width, range_width] + } +} + +fn extract_rwlk_counts(child_vk: &RecursionVk, expected_len: usize) -> Vec<(usize, usize, usize)> { + (0..expected_len) + .map(|idx| { + child_vk + .circuit_index_to_name + .get(&idx) + .and_then(|name| child_vk.circuit_vks.get(name)) + .map(|circuit_vk| { + let cs = circuit_vk.get_cs(); + (cs.num_reads(), cs.num_writes(), cs.num_lks()) + }) + .unwrap_or_else(|| { + // TODO: Populate GKR count metadata once every AIR is backed by a concrete VK. + (0, 0, 0) + }) + }) + .collect() +} + +fn extract_air_metadata_from_vk(child_vk: &RecursionVk, max_cached: usize) -> Vec { + let rwlk_counts = extract_rwlk_counts(child_vk, child_vk.circuit_vks.len()); + (0..child_vk.circuit_vks.len()) + .map(|idx| { + let (num_read_count, num_write_count, num_logup_count) = + rwlk_counts.get(idx).copied().unwrap_or((0, 0, 0)); + + let num_public_values = child_vk + .circuit_index_to_name + .get(&idx) + .and_then(|name| child_vk.circuit_vks.get(name)) + .map(|circuit_vk| circuit_vk.get_cs().instance_openings().len()) + .unwrap_or(0); + + AirMetadata { + is_required: false, + num_public_values, + main_width: 0, + cached_widths: vec![0; max_cached], + num_read_count, + num_write_count, + num_logup_count, + preprocessed_width: None, + preprocessed_data: None, + } + }) + .collect_vec() +} + +impl AirModule for ProofShapeModule { + fn num_airs(&self) -> usize { + 3 + } + + fn airs>(&self) -> Vec> { + let proof_shape_air = ProofShapeAir::<4, 8> { + per_air: self.per_air.clone(), + min_cached_idx: self.min_cached_idx, + max_cached: self.max_cached, + commit_mult: self.commit_mult, + idx_encoder: self.idx_encoder.clone(), + range_bus: self.range_bus, + pow_bus: self.pow_bus, + permutation_bus: self.permutation_bus, + starting_tidx_bus: self.starting_tidx_bus, + num_pvs_bus: self.num_pvs_bus, + fraction_folder_input_bus: self.bus_inventory.fraction_folder_input_bus, + expression_claim_n_max_bus: self.bus_inventory.expression_claim_n_max_bus, + tower_module_bus: self.bus_inventory.tower_module_bus, + air_shape_bus: self.bus_inventory.air_shape_bus, + hyperdim_bus: self.bus_inventory.hyperdim_bus, + lifted_heights_bus: self.bus_inventory.lifted_heights_bus, + commitments_bus: self.bus_inventory.commitments_bus, + transcript_bus: self.bus_inventory.transcript_bus, + n_lift_bus: self.bus_inventory.n_lift_bus, + cached_commit_bus: self.bus_inventory.cached_commit_bus, + continuations_enabled: self.continuations_enabled, + }; + let pvs_air = PublicValuesAir { + public_values_bus: self.bus_inventory.public_values_bus, + num_pvs_bus: self.num_pvs_bus, + transcript_bus: self.bus_inventory.transcript_bus, + continuations_enabled: self.continuations_enabled, + }; + let range_checker = RangeCheckerAir::<8> { + bus: self.range_bus, + }; + vec![ + Arc::new(proof_shape_air) as AirRef<_>, + Arc::new(pvs_air) as AirRef<_>, + Arc::new(range_checker) as AirRef<_>, + ] + } +} + +impl> TraceGenModule> + for ProofShapeModule +{ + // (pow_checker, external_range_checks) + type ModuleSpecificCtx<'a> = ( + Arc>, + &'a [usize], + ); + + #[tracing::instrument(skip_all)] + fn generate_proving_ctxs( + &self, + child_vk: &RecursionVk, + proofs: &[RecursionProof], + preflights: &[Preflight], + ctx: &>>::ModuleSpecificCtx<'_>, + required_heights: Option<&[usize]>, + ) -> Option>>> { + let _ = (child_vk, proofs, preflights, ctx); + let widths = self.placeholder_air_widths(); + let num_airs = required_heights + .map(|heights| heights.len()) + .unwrap_or_else(|| self.num_airs()); + Some( + (0..num_airs) + .map(|idx| { + let height = required_heights + .and_then(|heights| heights.get(idx).copied()) + .unwrap_or(1); + let width = widths.get(idx).copied().unwrap_or(1); + zero_air_ctx(height, width) + }) + .collect(), + ) + } +} + +fn zero_air_ctx>( + height: usize, + width: usize, +) -> AirProvingContext> { + let rows = height.max(1); + let cols = width.max(1); + let matrix = RowMajorMatrix::new(vec![F::ZERO; rows * cols], cols); + AirProvingContext::simple_no_pis(matrix) +} + +#[allow(dead_code)] +#[derive(strum_macros::Display, strum::EnumDiscriminants)] +#[strum_discriminants(repr(usize))] +enum ProofShapeModuleChip { + ProofShape(proof_shape::ProofShapeChip<4, 8>), + PublicValues, +} + +impl RowMajorChip for ProofShapeModuleChip { + type Ctx<'a> = (&'a RecursionVk, &'a [RecursionProof], &'a [Preflight]); + + #[tracing::instrument( + name = "wrapper.generate_trace", + level = "trace", + skip_all, + fields(air = %self) + )] + fn generate_trace( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + let _ = ctx; + let rows = required_height.unwrap_or(1).max(1); + let width = match self { + ProofShapeModuleChip::ProofShape(chip) => chip.placeholder_width(), + ProofShapeModuleChip::PublicValues => pvs::PublicValuesCols::::width(), + }; + Some(RowMajorMatrix::new(vec![F::ZERO; rows * width], width)) + } +} + +#[cfg(feature = "cuda")] +mod cuda_tracegen { + use openvm_cuda_backend::{GpuBackend, base::DeviceMatrix}; + + use super::*; + use crate::cuda::{ + GlobalCtxGpu, preflight::PreflightGpu, proof::ProofGpu, vk::VerifyingKeyGpu, + }; + + impl TraceGenModule for ProofShapeModule { + type ModuleSpecificCtx<'a> = (); + + #[tracing::instrument(skip_all)] + fn generate_proving_ctxs( + &self, + child_vk: &VerifyingKeyGpu, + proofs: &[ProofGpu], + preflights: &[PreflightGpu], + _ctx: &>::ModuleSpecificCtx<'_>, + required_heights: Option<&[usize]>, + ) -> Option>> { + let _ = (child_vk, proofs, preflights); + let widths = self.placeholder_air_widths(); + let air_count = required_heights + .map(|heights| heights.len()) + .unwrap_or_else(|| self.num_airs()); + Some( + (0..air_count) + .map(|idx| { + let height = required_heights + .and_then(|heights| heights.get(idx).copied()) + .unwrap_or(1); + let width = widths.get(idx).copied().unwrap_or(1); + zero_gpu_ctx(height, width) + }) + .collect(), + ) + } + } + + fn zero_gpu_ctx(height: usize, width: usize) -> AirProvingContext { + let rows = height.max(1); + let cols = width.max(1); + let trace = DeviceMatrix::with_capacity(rows, cols); + AirProvingContext::simple_no_pis(trace) + } +} diff --git a/ceno_recursion_v2/src/proof_shape/proof_shape/air.rs b/ceno_recursion_v2/src/proof_shape/proof_shape/air.rs new file mode 100644 index 000000000..992a0eb5a --- /dev/null +++ b/ceno_recursion_v2/src/proof_shape/proof_shape/air.rs @@ -0,0 +1,785 @@ +use std::{borrow::Borrow, sync::Arc}; + +use itertools::fold; +use openvm_circuit_primitives::{ + SubAir, + encoder::Encoder, + utils::{and, not, or}, +}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::DIGEST_SIZE; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{PrimeCharacteristicRing, PrimeField32}; +use p3_matrix::Matrix; +use stark_recursion_circuit_derive::AlignedBorrow; + +use crate::{ + bus::{ + AirShapeBus, AirShapeBusMessage, CachedCommitBus, CachedCommitBusMessage, CommitmentsBus, + CommitmentsBusMessage, ExpressionClaimNMaxBus, ExpressionClaimNMaxMessage, + FractionFolderInputBus, FractionFolderInputMessage, HyperdimBus, HyperdimBusMessage, + LiftedHeightsBus, LiftedHeightsBusMessage, NLiftBus, NLiftMessage, TowerModuleBus, + TowerModuleMessage, TranscriptBus, TranscriptBusMessage, + }, + primitives::bus::{ + PowerCheckerBus, PowerCheckerBusMessage, RangeCheckerBus, RangeCheckerBusMessage, + }, + proof_shape::{ + AirMetadata, + bus::{ + AirShapeProperty, NumPublicValuesBus, NumPublicValuesMessage, ProofShapePermutationBus, + ProofShapePermutationMessage, StartingTidxBus, StartingTidxMessage, + }, + }, + subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}, +}; + +#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct ProofShapeCols { + pub proof_idx: F, + pub is_valid: F, + pub is_first: F, + pub is_last: F, + + // loop: proof_idx -> idx (air idx) + pub idx: F, + pub sorted_idx: F, + /// Represents log2 trace height when `is_present`. + /// + /// Has a special use on summary row (when `is_last`). + pub log_height: F, + /// Whether this AIR needs rotation openings. + pub need_rot: F, + + // First possible tidx and non-main cidx of the current AIR + pub starting_tidx: F, + pub starting_cidx: F, + + // Columns that may be read from the transcript. Note that cached_commits is also read + // from the transcript. + pub is_present: F, + + /// Will be constrained to be `2^log_height` when `is_present`. + /// + /// Has a special use on summary row (when `is_last`). + pub height: F, + + // Number of present AIRs so far + pub num_present: F, + + /// Limb decomposition of `height` used for range/decomposition checks. + pub height_limbs: [F; NUM_LIMBS], + + /// The maximum hypercube dimension across all present AIR traces, or zero. + /// Computed as max(0, n0, n1, ...) where ni = log_height_i for each present trace. + pub n_max: F, + pub is_n_max_greater: F, + + pub num_air_id_lookups: F, + pub num_columns: F, +} + +// Variable-length columns are stored at the end +pub struct ProofShapeVarCols<'a, F> { + pub idx_flags: &'a [F], // [F; IDX_FLAGS] + pub cached_commits: &'a [[F; DIGEST_SIZE]], // [[F; DIGEST_SIZE]; MAX_CACHED] +} + +/// AIR for verifying the proof shape (trace heights, widths, commitments) of a child proof +/// within the recursion circuit. +/// +/// The AIR enforces per-AIR shape consistency and forwards metadata to downstream buses. +pub struct ProofShapeAir { + // Parameters derived from vk + pub per_air: Vec, + pub min_cached_idx: usize, + pub max_cached: usize, + pub commit_mult: usize, + + // Primitives + pub idx_encoder: Arc, + pub range_bus: RangeCheckerBus, + pub pow_bus: PowerCheckerBus, + + // Internal buses + pub permutation_bus: ProofShapePermutationBus, + pub starting_tidx_bus: StartingTidxBus, + pub num_pvs_bus: NumPublicValuesBus, + + // Inter-module buses + pub tower_module_bus: TowerModuleBus, + pub air_shape_bus: AirShapeBus, + pub expression_claim_n_max_bus: ExpressionClaimNMaxBus, + pub fraction_folder_input_bus: FractionFolderInputBus, + pub hyperdim_bus: HyperdimBus, + pub lifted_heights_bus: LiftedHeightsBus, + pub commitments_bus: CommitmentsBus, + pub transcript_bus: TranscriptBus, + pub n_lift_bus: NLiftBus, + + // For continuations + pub cached_commit_bus: CachedCommitBus, + pub continuations_enabled: bool, +} + +impl BaseAir + for ProofShapeAir +{ + fn width(&self) -> usize { + ProofShapeCols::::width() + + self.idx_encoder.width() + + self.max_cached * DIGEST_SIZE + } +} +impl BaseAirWithPublicValues + for ProofShapeAir +{ +} +impl PartitionedBaseAir + for ProofShapeAir +{ +} + +impl Air + for ProofShapeAir +where + AB::F: PrimeField32, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + + let (local, next) = ( + main.row_slice(0).expect("window should have two elements"), + main.row_slice(1).expect("window should have two elements"), + ); + let const_width = ProofShapeCols::::width(); + + let localv = borrow_var_cols::( + &local[const_width..], + self.idx_encoder.width(), + self.max_cached, + ); + let local: &ProofShapeCols = (*local)[..const_width].borrow(); + let next: &ProofShapeCols = (*next)[..const_width].borrow(); + let n_logup = local.starting_cidx; + + self.idx_encoder.eval(builder, localv.idx_flags); + + NestedForLoopSubAir::<1> {}.eval( + builder, + ( + NestedForLoopIoCols { + is_enabled: local.is_valid + local.is_last, + counter: [local.proof_idx.into()], + is_first: [local.is_first.into()], + }, + NestedForLoopIoCols { + is_enabled: next.is_valid + next.is_last, + counter: [next.proof_idx.into()], + is_first: [next.is_first.into()], + }, + ), + ); + builder + .when(and(local.is_valid, not(local.is_last))) + .assert_eq(local.proof_idx, next.proof_idx); + + builder.assert_bool(local.is_present); + builder.when(local.is_present).assert_one(local.is_valid); + + builder + .when(local.is_first) + .assert_eq(local.is_present, local.num_present); + builder.when(local.is_valid).assert_eq( + local.num_present + next.is_present * next.is_valid, + next.num_present, + ); + + /////////////////////////////////////////////////////////////////////////////////////////// + // PERMUTATION AND SORTING + /////////////////////////////////////////////////////////////////////////////////////////// + builder.when(local.is_first).assert_zero(local.sorted_idx); + builder + .when(next.sorted_idx) + .assert_eq(local.sorted_idx, next.sorted_idx - AB::F::ONE); + + self.permutation_bus.send( + builder, + local.proof_idx, + ProofShapePermutationMessage { + idx: local.sorted_idx, + }, + local.is_valid, + ); + + self.permutation_bus.receive( + builder, + local.proof_idx, + ProofShapePermutationMessage { idx: local.idx }, + local.is_valid, + ); + + builder + .when(and(not(local.is_present), local.is_valid)) + .assert_zero(local.height); + builder + .when(and(not(local.is_present), local.is_valid)) + .assert_zero(local.log_height); + + // Range check difference using ExponentBus to ensure local.log_height >= next.log_height + self.range_bus.lookup_key( + builder, + RangeCheckerBusMessage { + value: local.log_height - next.log_height, + max_bits: AB::Expr::from_usize(5), + }, + and(local.is_valid, not(next.is_last)), + ); + + /////////////////////////////////////////////////////////////////////////////////////////// + // VK FIELD SELECTION + /////////////////////////////////////////////////////////////////////////////////////////// + // Select values for TranscriptBus + let mut is_required = AB::Expr::ZERO; + let mut is_min_cached = AB::Expr::ZERO; + let mut has_preprocessed = AB::Expr::ZERO; + let mut cached_present = vec![AB::Expr::ZERO; self.max_cached]; + + // Select values for LiftedHeightsBus + let mut main_common_width = AB::Expr::ZERO; + let mut preprocessed_stacked_width = AB::Expr::ZERO; + let mut cached_widths = vec![AB::Expr::ZERO; self.max_cached]; + + // Select values for CommitmentsBus + let mut preprocessed_commit = [AB::Expr::ZERO; DIGEST_SIZE]; + + // Select values for NumPublicValuesBus + let mut num_pvs = AB::Expr::ZERO; + let mut has_pvs = AB::Expr::ZERO; + let mut num_read_count = AB::Expr::ZERO; + let mut num_write_count = AB::Expr::ZERO; + let mut num_logup_count = AB::Expr::ZERO; + + for (i, air_data) in self.per_air.iter().enumerate() { + // We keep a running tally of how many transcript reads there should be up to any + // given point, and use that to constrain initial_tidx + let is_current_air = self.idx_encoder.get_flag_expr::(i, localv.idx_flags); + let mut when_current = builder.when(is_current_air.clone()); + + when_current.assert_eq(local.idx, AB::F::from_usize(i)); + + main_common_width += is_current_air.clone() * AB::F::from_usize(air_data.main_width); + + if air_data.num_public_values != 0 { + has_pvs += is_current_air.clone(); + } + num_pvs += is_current_air.clone() * AB::F::from_usize(air_data.num_public_values); + + if air_data.is_required { + is_required += is_current_air.clone(); + when_current.assert_one(local.is_present); + } + + if i == self.min_cached_idx { + is_min_cached += is_current_air.clone(); + } + + if let Some(preprocessed) = &air_data.preprocessed_data { + when_current.assert_eq( + local.log_height, + AB::Expr::from_usize(0usize.wrapping_add_signed(preprocessed.hypercube_dim)), + ); + has_preprocessed += is_current_air.clone(); + + preprocessed_stacked_width += is_current_air.clone() + * AB::F::from_usize(air_data.preprocessed_width.unwrap()); + (0..DIGEST_SIZE).for_each(|didx| { + preprocessed_commit[didx] += is_current_air.clone() + * AB::F::from_u32(preprocessed.commit[didx].as_canonical_u32()); + }); + } + + for (cached_idx, width) in air_data.cached_widths.iter().enumerate() { + cached_present[cached_idx] += is_current_air.clone(); + cached_widths[cached_idx] += is_current_air.clone() * AB::Expr::from_usize(*width); + } + + num_read_count += + is_current_air.clone() * AB::Expr::from_usize(air_data.num_read_count); + num_write_count += + is_current_air.clone() * AB::Expr::from_usize(air_data.num_write_count); + num_logup_count += + is_current_air.clone() * AB::Expr::from_usize(air_data.num_logup_count); + } + + /////////////////////////////////////////////////////////////////////////////////////////// + // TRANSCRIPT OBSERVATIONS + /////////////////////////////////////////////////////////////////////////////////////////// + let is_first_idx = self.idx_encoder.get_flag_expr::(0, localv.idx_flags); + builder + .when(is_first_idx.clone()) + .assert_eq(local.starting_tidx, AB::Expr::from_usize(2 * DIGEST_SIZE)); + + self.starting_tidx_bus.receive( + builder, + local.proof_idx, + StartingTidxMessage { + air_idx: local.idx * local.is_valid + + AB::Expr::from_usize(self.per_air.len()) * local.is_last, + tidx: local.starting_tidx.into(), + }, + or( + local.is_last, + and(local.is_valid, not::(is_first_idx)), + ), + ); + + let mut tidx = local.starting_tidx.into(); + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: tidx.clone(), + value: local.is_present.into(), + is_sample: AB::Expr::ZERO, + }, + not::(is_required.clone()) * local.is_valid, + ); + tidx += not::(is_required) * local.is_valid; + + for (didx, commit_val) in preprocessed_commit.iter().enumerate() { + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: tidx.clone() + AB::Expr::from_usize(didx), + value: commit_val.clone(), + is_sample: AB::Expr::ZERO, + }, + has_preprocessed.clone() * local.is_present, + ); + } + tidx += has_preprocessed.clone() * AB::Expr::from_usize(DIGEST_SIZE) * local.is_present; + + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: tidx.clone(), + value: local.log_height.into(), + is_sample: AB::Expr::ZERO, + }, + not::(has_preprocessed.clone()) * local.is_present, + ); + tidx += not::(has_preprocessed.clone()) * local.is_present; + + (0..self.max_cached).for_each(|i| { + for didx in 0..DIGEST_SIZE { + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: tidx.clone(), + value: localv.cached_commits[i][didx].into(), + is_sample: AB::Expr::ZERO, + }, + cached_present[i].clone() * local.is_present, + ); + tidx += cached_present[i].clone() * local.is_present; + } + }); + + let num_pvs_tidx = tidx.clone(); + tidx += num_pvs.clone() * local.is_present; + + // constrain next air tid + self.starting_tidx_bus.send( + builder, + local.proof_idx, + StartingTidxMessage { + air_idx: local.idx + AB::F::ONE, + tidx, + }, + local.is_valid, + ); + + for didx in 0..DIGEST_SIZE { + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: AB::Expr::from_usize(didx), + value: localv.cached_commits[self.max_cached - 1][didx].into(), + is_sample: AB::Expr::ZERO, + }, + local.is_last, + ); + + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: AB::Expr::from_usize(didx + DIGEST_SIZE), + value: localv.cached_commits[self.max_cached - 1][didx].into(), + is_sample: AB::Expr::ZERO, + }, + is_min_cached.clone() * local.is_valid, + ); + } + + /////////////////////////////////////////////////////////////////////////////////////////// + // AIR SHAPE LOOKUP + /////////////////////////////////////////////////////////////////////////////////////////// + self.air_shape_bus.add_key_with_lookups( + builder, + local.proof_idx, + AirShapeBusMessage { + sort_idx: local.sorted_idx.into(), + property_idx: AirShapeProperty::AirId.to_field(), + value: local.idx.into(), + }, + local.is_present * local.num_air_id_lookups, + ); + + self.air_shape_bus.add_key_with_lookups( + builder, + local.proof_idx, + AirShapeBusMessage { + sort_idx: local.sorted_idx.into(), + property_idx: AirShapeProperty::NumInteractions.to_field(), + value: AB::Expr::ZERO, + }, + local.is_present, + ); + + self.air_shape_bus.add_key_with_lookups( + builder, + local.proof_idx, + AirShapeBusMessage { + sort_idx: local.sorted_idx.into(), + property_idx: AirShapeProperty::NeedRot.to_field(), + value: local.need_rot.into(), + }, + local.is_present * local.num_columns, + ); + self.air_shape_bus.add_key_with_lookups( + builder, + local.proof_idx, + AirShapeBusMessage { + sort_idx: local.sorted_idx.into(), + property_idx: AirShapeProperty::NumRead.to_field(), + value: num_read_count.clone(), + }, + // each layer lookup once if current air was present + local.is_present * n_logup, + ); + self.air_shape_bus.add_key_with_lookups( + builder, + local.proof_idx, + AirShapeBusMessage { + sort_idx: local.sorted_idx.into(), + property_idx: AirShapeProperty::NumWrite.to_field(), + value: num_write_count.clone(), + }, + local.is_present * n_logup, + ); + self.air_shape_bus.add_key_with_lookups( + builder, + local.proof_idx, + AirShapeBusMessage { + sort_idx: local.sorted_idx.into(), + property_idx: AirShapeProperty::NumLk.to_field(), + value: num_logup_count, + }, + local.is_present * n_logup, + ); + + /////////////////////////////////////////////////////////////////////////////////////////// + // HYPERDIM LOOKUP + /////////////////////////////////////////////////////////////////////////////////////////// + let n = local.log_height.into(); + builder.assert_bool(local.need_rot); + builder + .when(not(local.is_present)) + .assert_zero(local.need_rot); + builder + .when(not(local.is_present)) + .assert_zero(local.num_columns); + let n_abs = n.clone(); + // We range check n in [0, 32). + self.range_bus.lookup_key( + builder, + RangeCheckerBusMessage { + value: n_abs.clone(), + max_bits: AB::Expr::from_usize(5), + }, + local.is_present, + ); + + self.hyperdim_bus.add_key_with_lookups( + builder, + local.proof_idx, + HyperdimBusMessage { + sort_idx: local.sorted_idx.into(), + n_abs: n_abs.clone(), + n_sign_bit: AB::Expr::ZERO, + }, + local.is_present * (local.num_air_id_lookups + AB::F::ONE), + ); + + /////////////////////////////////////////////////////////////////////////////////////////// + // LIFTED HEIGHTS LOOKUP + STACKING COMMITMENTS + /////////////////////////////////////////////////////////////////////////////////////////// + self.pow_bus.lookup_key( + builder, + PowerCheckerBusMessage { + log: local.log_height.into(), + exp: local.height.into(), + }, + local.is_present, + ); + + self.lifted_heights_bus.add_key_with_lookups( + builder, + local.proof_idx, + LiftedHeightsBusMessage { + sort_idx: local.sorted_idx.into(), + part_idx: AB::Expr::ZERO, + commit_idx: AB::Expr::ZERO, + hypercube_dim: n.clone(), + lifted_height: local.height.into(), + log_lifted_height: local.log_height.into(), + }, + local.is_present * main_common_width, + ); + + builder + .when(and(local.is_first, local.is_valid)) + .assert_one(local.starting_cidx); + let mut cidx_offset = AB::Expr::ZERO; + + self.lifted_heights_bus.add_key_with_lookups( + builder, + local.proof_idx, + LiftedHeightsBusMessage { + sort_idx: local.sorted_idx.into(), + part_idx: cidx_offset.clone() + AB::F::ONE, + commit_idx: cidx_offset.clone() + local.starting_cidx, + hypercube_dim: n.clone(), + lifted_height: local.height.into(), + log_lifted_height: local.log_height.into(), + }, + local.is_present * preprocessed_stacked_width, + ); + + self.commitments_bus.add_key_with_lookups( + builder, + local.proof_idx, + CommitmentsBusMessage { + major_idx: AB::Expr::ZERO, + minor_idx: cidx_offset.clone() + local.starting_cidx, + commitment: preprocessed_commit, + }, + has_preprocessed.clone() * local.is_present * AB::Expr::from_usize(self.commit_mult), + ); + cidx_offset += has_preprocessed.clone(); + + (0..self.max_cached).for_each(|cached_idx| { + self.lifted_heights_bus.add_key_with_lookups( + builder, + local.proof_idx, + LiftedHeightsBusMessage { + sort_idx: local.sorted_idx.into(), + part_idx: cidx_offset.clone() + AB::F::ONE, + commit_idx: cidx_offset.clone() + local.starting_cidx, + hypercube_dim: n.clone(), + lifted_height: local.height.into(), + log_lifted_height: local.log_height.into(), + }, + local.is_present * cached_widths[cached_idx].clone(), + ); + + self.commitments_bus.add_key_with_lookups( + builder, + local.proof_idx, + CommitmentsBusMessage { + major_idx: AB::Expr::ZERO, + minor_idx: cidx_offset.clone() + local.starting_cidx, + commitment: localv.cached_commits[cached_idx].map(Into::into), + }, + cached_present[cached_idx].clone() + * local.is_present + * AB::Expr::from_usize(self.commit_mult), + ); + cidx_offset += cached_present[cached_idx].clone(); + + self.cached_commit_bus.send( + builder, + local.proof_idx, + CachedCommitBusMessage { + air_idx: local.idx.into(), + cached_idx: AB::Expr::from_usize(cached_idx), + cached_commit: localv.cached_commits[cached_idx].map(Into::into), + }, + cached_present[cached_idx].clone() + * local.is_valid + * AB::Expr::from_bool(self.continuations_enabled), + ); + }); + + builder + .when(and(local.is_valid, not(next.is_last))) + .assert_eq(local.starting_cidx + cidx_offset, next.starting_cidx); + + self.commitments_bus.add_key_with_lookups( + builder, + local.proof_idx, + CommitmentsBusMessage { + major_idx: AB::Expr::ZERO, + minor_idx: AB::Expr::ZERO, + commitment: localv.cached_commits[self.max_cached - 1].map(Into::into), + }, + is_min_cached.clone() * local.is_valid * AB::Expr::from_usize(self.commit_mult), + ); + + /////////////////////////////////////////////////////////////////////////////////////////// + // NUM PUBLIC VALUES + /////////////////////////////////////////////////////////////////////////////////////////// + self.num_pvs_bus.send( + builder, + local.proof_idx, + NumPublicValuesMessage { + air_idx: local.idx.into(), + tidx: num_pvs_tidx, + num_pvs, + }, + local.is_present * has_pvs, + ); + + /////////////////////////////////////////////////////////////////////////////////////////// + // HEIGHT + GKR MESSAGE + /////////////////////////////////////////////////////////////////////////////////////////// + builder.when(local.is_valid).assert_eq( + fold( + local.height_limbs.iter().enumerate(), + AB::Expr::ZERO, + |acc, (i, limb)| acc + (AB::Expr::from_u32(1 << (i * LIMB_BITS)) * *limb), + ), + local.height, + ); + + for i in 0..NUM_LIMBS { + self.range_bus.lookup_key( + builder, + RangeCheckerBusMessage { + value: local.height_limbs[i].into(), + max_bits: AB::Expr::from_usize(LIMB_BITS), + }, + local.is_valid, + ); + } + + // While the (N + 1)-th row (index N) is invalid, we use it to store the final number + // of total cells. We thus can always constrain local.total_cells + local.num_cells = + // next.total_cells when local is valid, and when we're on the summary row we can send + // the stacking main width message. + // + // Note that we must constrain that the is_last flag is set correctly, i.e. it must + // only be set on the row immediately after the N-th. + builder.assert_bool(local.is_last); + builder.when(local.is_last).assert_zero(local.is_valid); + builder.when(next.is_last).assert_one(local.is_valid); + builder + .when(local.sorted_idx - AB::F::from_usize(self.per_air.len() - 1)) + .assert_zero(next.is_last); + builder + .when(next.is_last) + .assert_zero(local.sorted_idx - AB::F::from_usize(self.per_air.len() - 1)); + + // Constrain n_max on each row. Also constrain that local.is_n_max_greater is one when + // n_max is greater than n_logup, and zero otherwise. + builder + .when(local.is_first) + .assert_eq(local.n_max, n_abs.clone()); + builder + .when(local.is_valid) + .assert_eq(local.n_max, next.n_max); + + builder.assert_bool(local.is_n_max_greater); + self.range_bus.lookup_key( + builder, + RangeCheckerBusMessage { + value: (local.n_max - n_logup) * (local.is_n_max_greater * AB::F::TWO - AB::F::ONE), + max_bits: AB::Expr::from_usize(5), + }, + local.is_last, + ); + + self.tower_module_bus.send( + builder, + local.proof_idx, + TowerModuleMessage { + idx: local.idx.into(), + tidx: local.starting_tidx.into(), + n_logup: n_logup.into(), + }, + local.is_last, + ); + + // Send n_max value to expression claim air + self.expression_claim_n_max_bus.send( + builder, + local.proof_idx, + ExpressionClaimNMaxMessage { + n_max: local.n_max.into(), + }, + local.is_last, + ); + + // Send n_lift to constraint folding air + self.n_lift_bus.send( + builder, + local.proof_idx, + NLiftMessage { + air_idx: local.idx.into(), + n_lift: local.log_height.into(), + }, + local.is_present, + ); + + // Send count of present airs to fraction folder air + self.fraction_folder_input_bus.send( + builder, + local.proof_idx, + FractionFolderInputMessage { + num_present_airs: local.num_present, + }, + local.is_last, + ); + } +} + +pub(super) fn borrow_var_cols( + slice: &[F], + idx_flags: usize, + max_cached: usize, +) -> ProofShapeVarCols<'_, F> { + let flags_idx = 0; + let cached_commits_idx = flags_idx + idx_flags; + + let cached_commits = &slice[cached_commits_idx..cached_commits_idx + max_cached * DIGEST_SIZE]; + let cached_commits: &[[F; DIGEST_SIZE]] = unsafe { + std::slice::from_raw_parts( + cached_commits.as_ptr() as *const [F; DIGEST_SIZE], + max_cached, + ) + }; + + ProofShapeVarCols { + idx_flags: &slice[flags_idx..cached_commits_idx], + cached_commits, + } +} diff --git a/ceno_recursion_v2/src/proof_shape/proof_shape/cuda.rs b/ceno_recursion_v2/src/proof_shape/proof_shape/cuda.rs new file mode 100644 index 000000000..a9a0b225b --- /dev/null +++ b/ceno_recursion_v2/src/proof_shape/proof_shape/cuda.rs @@ -0,0 +1,143 @@ +use std::sync::Arc; + +use itertools::Itertools; +use openvm_cuda_backend::{GpuBackend, base::DeviceMatrix, prelude::Digest}; +use openvm_cuda_common::{copy::MemCopyH2D, memory_manager::MemTracker}; +use openvm_stark_backend::prover::AirProvingContext; +use openvm_stark_sdk::config::baby_bear_poseidon2::DIGEST_SIZE; + +use crate::{ + cuda::{preflight::PreflightGpu, vk::VerifyingKeyGpu}, + proof_shape::{cuda_abi::proof_shape_tracegen, proof_shape::ProofShapeCols}, + system::POW_CHECKER_HEIGHT, + tracegen::ModuleChip, +}; +use recursion_circuit::primitives::{ + pow::cuda::PowerCheckerGpuTraceGenerator, range::cuda::RangeCheckerGpuTraceGenerator, +}; + +#[repr(C)] +pub(crate) struct ProofShapePerProof { + num_present: usize, + n_max: usize, + n_logup: usize, + final_cidx: usize, + final_total_interactions: usize, + main_commit: Digest, +} + +#[repr(C)] +pub(crate) struct ProofShapeTracegenInputs { + num_airs: usize, + l_skip: usize, + max_interaction_count: u32, + max_cached: usize, + min_cached_idx: usize, + pre_hash: Digest, + range_checker_8_ptr: *mut u32, + range_checker_5_ptr: *mut u32, + pow_checker_ptr: *mut u32, +} + +#[derive(derive_new::new)] +pub(in crate::proof_shape) struct ProofShapeChipGpu +{ + encoder_width: usize, + min_cached_idx: usize, + max_cached: usize, + range_checker: Arc>, + pow_checker: Arc>, +} + +const NUM_LIMBS: usize = 4; +const LIMB_BITS: usize = 8; +impl ModuleChip for ProofShapeChipGpu { + type Ctx<'a> = (&'a VerifyingKeyGpu, &'a [PreflightGpu]); + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_proving_ctx( + &self, + ctx: &Self::Ctx<'_>, + height: Option, + ) -> Option> { + let (vk_gpu, preflights_gpu) = ctx; + let mem = MemTracker::start("tracegen.proof_shape"); + let num_valid_rows = preflights_gpu.len() * (vk_gpu.per_air.len() + 1); + let height = if let Some(height) = height { + if height < num_valid_rows { + return None; + } + height + } else { + num_valid_rows.next_power_of_two() + }; + let encoder_width = self.encoder_width; + let min_cached_idx = self.min_cached_idx; + let max_cached = self.max_cached; + let range_checker = &self.range_checker; + let pow_checker = &self.pow_checker; + let num_airs = vk_gpu.per_air.len(); + let width = + ProofShapeCols::::width() + encoder_width + max_cached * DIGEST_SIZE; + let trace = DeviceMatrix::with_capacity(height, width); + + let per_row_tidx = preflights_gpu + .iter() + .map(|preflight| preflight.proof_shape.per_row_tidx.as_ptr()) + .collect_vec(); + let sorted_trace_heights = preflights_gpu + .iter() + .map(|preflight| preflight.proof_shape.sorted_trace_heights.as_ptr()) + .collect_vec(); + let sorted_trace_metadata = preflights_gpu + .iter() + .map(|preflight| preflight.proof_shape.sorted_trace_metadata.as_ptr()) + .collect_vec(); + let cached_commits = preflights_gpu + .iter() + .map(|preflight| preflight.proof_shape.sorted_cached_commits.as_ptr()) + .collect_vec(); + let per_proof = preflights_gpu + .iter() + .map(|preflight| ProofShapePerProof { + num_present: preflight.proof_shape.num_present, + n_max: preflight.proof_shape.n_max, + n_logup: preflight.proof_shape.n_logup, + final_cidx: preflight.proof_shape.final_cidx, + final_total_interactions: preflight.proof_shape.final_total_interactions, + main_commit: preflight.proof_shape.main_commit, + }) + .collect_vec() + .to_device() + .unwrap(); + let inputs = ProofShapeTracegenInputs { + num_airs, + l_skip: vk_gpu.system_params.l_skip, + max_interaction_count: vk_gpu.system_params.logup.max_interaction_count, + max_cached, + min_cached_idx, + pre_hash: vk_gpu.pre_hash, + range_checker_8_ptr: range_checker.count_mut_ptr(), + range_checker_5_ptr: pow_checker.range_count_mut_ptr(), + pow_checker_ptr: pow_checker.pow_count_mut_ptr(), + }; + + unsafe { + proof_shape_tracegen( + trace.buffer(), + height, + &vk_gpu.per_air, + per_row_tidx, + sorted_trace_heights, + sorted_trace_metadata, + cached_commits, + &per_proof, + preflights_gpu.len(), + &inputs, + ) + .unwrap(); + } + mem.emit_metrics(); + Some(AirProvingContext::simple_no_pis(trace)) + } +} diff --git a/ceno_recursion_v2/src/proof_shape/proof_shape/mod.rs b/ceno_recursion_v2/src/proof_shape/proof_shape/mod.rs new file mode 100644 index 000000000..f0f196f9d --- /dev/null +++ b/ceno_recursion_v2/src/proof_shape/proof_shape/mod.rs @@ -0,0 +1,5 @@ +mod air; +mod trace; + +pub use air::*; +pub(in crate::proof_shape) use trace::*; diff --git a/ceno_recursion_v2/src/proof_shape/proof_shape/trace.rs b/ceno_recursion_v2/src/proof_shape/proof_shape/trace.rs new file mode 100644 index 000000000..f37cbe415 --- /dev/null +++ b/ceno_recursion_v2/src/proof_shape/proof_shape/trace.rs @@ -0,0 +1,53 @@ +use std::sync::Arc; + +use openvm_circuit_primitives::encoder::Encoder; +use openvm_stark_backend::keygen::types::MultiStarkVerifyingKey; +use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, DIGEST_SIZE, F}; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::dense::RowMajorMatrix; + +use super::air::ProofShapeCols; +use crate::{ + primitives::{pow::PowerCheckerCpuTraceGenerator, range::RangeCheckerCpuTraceGenerator}, + system::{POW_CHECKER_HEIGHT, Preflight, RecursionProof}, + tracegen::RowMajorChip, +}; + +#[derive(derive_new::new)] +#[allow(dead_code)] +pub(in crate::proof_shape) struct ProofShapeChip { + idx_encoder: Arc, + min_cached_idx: usize, + max_cached: usize, + range_checker: Arc>, + pow_checker: Arc>, +} + +impl ProofShapeChip { + pub(in crate::proof_shape) fn placeholder_width(&self) -> usize { + ProofShapeCols::::width() + + self.idx_encoder.width() + + self.max_cached * DIGEST_SIZE + } +} + +impl RowMajorChip + for ProofShapeChip +{ + type Ctx<'a> = ( + &'a MultiStarkVerifyingKey, + &'a [RecursionProof], + &'a [Preflight], + ); + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_trace( + &self, + _ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + let rows = required_height.unwrap_or(1).max(1); + let width = self.placeholder_width(); + Some(RowMajorMatrix::new(vec![F::ZERO; rows * width], width)) + } +} diff --git a/ceno_recursion_v2/src/proof_shape/pvs/air.rs b/ceno_recursion_v2/src/proof_shape/pvs/air.rs new file mode 100644 index 000000000..64374745a --- /dev/null +++ b/ceno_recursion_v2/src/proof_shape/pvs/air.rs @@ -0,0 +1,143 @@ +use std::borrow::Borrow; + +use openvm_circuit_primitives::{AlignedBorrow, SubAir, utils::not}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{PrimeCharacteristicRing, PrimeField32}; +use p3_matrix::Matrix; + +use crate::{ + bus::{PublicValuesBus, PublicValuesBusMessage, TranscriptBus, TranscriptBusMessage}, + proof_shape::bus::{NumPublicValuesBus, NumPublicValuesMessage}, + subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}, +}; + +#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct PublicValuesCols { + pub is_valid: F, + + pub proof_idx: F, + pub air_idx: F, + pub pv_idx: F, + + pub is_first_in_proof: F, + pub is_first_in_air: F, + + pub tidx: F, + pub value: F, +} + +pub struct PublicValuesAir { + pub public_values_bus: PublicValuesBus, + pub num_pvs_bus: NumPublicValuesBus, + pub transcript_bus: TranscriptBus, + pub(crate) continuations_enabled: bool, +} + +impl BaseAir for PublicValuesAir { + fn width(&self) -> usize { + PublicValuesCols::::width() + } +} +impl BaseAirWithPublicValues for PublicValuesAir {} +impl PartitionedBaseAir for PublicValuesAir {} + +impl Air for PublicValuesAir +where + AB::F: PrimeField32, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + + let (local, next) = ( + main.row_slice(0).expect("window should have two elements"), + main.row_slice(1).expect("window should have two elements"), + ); + let local: &PublicValuesCols = (*local).borrow(); + let next: &PublicValuesCols = (*next).borrow(); + + NestedForLoopSubAir::<1> {}.eval( + builder, + ( + NestedForLoopIoCols { + is_enabled: local.is_valid, + counter: [local.proof_idx], + is_first: [local.is_first_in_proof], + } + .map_into(), + NestedForLoopIoCols { + is_enabled: next.is_valid, + counter: [next.proof_idx], + is_first: [next.is_first_in_proof], + } + .map_into(), + ), + ); + // Constrain is_first_for_air, send NumPublicValuesBus message when true + builder.assert_bool(local.is_first_in_air); + builder + .when(local.is_first_in_proof) + .assert_one(local.is_first_in_air); + builder + .when(local.is_first_in_air) + .assert_one(local.is_valid); + builder + .when(next.is_valid * (next.air_idx - local.air_idx)) + .assert_one(next.is_first_in_air); + + let is_same_air = local.is_valid * next.is_valid * not(next.is_first_in_air); + self.num_pvs_bus.receive( + builder, + local.proof_idx, + NumPublicValuesMessage { + air_idx: local.air_idx.into(), + tidx: local.tidx - local.pv_idx, + num_pvs: local.pv_idx + AB::Expr::ONE, + }, + local.is_valid - is_same_air.clone(), + ); + + let mut when_same_air = builder.when(is_same_air); + when_same_air.assert_eq(local.air_idx, next.air_idx); + when_same_air.assert_eq(next.pv_idx, local.pv_idx + AB::Expr::ONE); + when_same_air.assert_eq(next.tidx, local.tidx + AB::Expr::ONE); + + self.public_values_bus.send( + builder, + local.proof_idx, + PublicValuesBusMessage { + air_idx: local.air_idx, + pv_idx: local.pv_idx, + value: local.value, + }, + local.is_valid, + ); + if self.continuations_enabled { + self.public_values_bus.send( + builder, + local.proof_idx, + PublicValuesBusMessage { + air_idx: local.air_idx, + pv_idx: local.pv_idx, + value: local.value, + }, + local.is_valid, + ); + } + + // Receive transcript read of public values + self.transcript_bus.receive( + builder, + local.proof_idx, + TranscriptBusMessage { + tidx: local.tidx.into(), + value: local.value.into(), + is_sample: AB::Expr::ZERO, + }, + local.is_valid, + ); + } +} diff --git a/ceno_recursion_v2/src/proof_shape/pvs/cuda.rs b/ceno_recursion_v2/src/proof_shape/pvs/cuda.rs new file mode 100644 index 000000000..f80be3528 --- /dev/null +++ b/ceno_recursion_v2/src/proof_shape/pvs/cuda.rs @@ -0,0 +1,66 @@ +use openvm_cuda_backend::{GpuBackend, base::DeviceMatrix}; +use openvm_cuda_common::memory_manager::MemTracker; +use openvm_stark_backend::prover::AirProvingContext; + +use crate::{ + cuda::{preflight::PreflightGpu, proof::ProofGpu}, + proof_shape::{cuda_abi::public_values_tracegen, pvs::PublicValuesCols}, + tracegen::ModuleChip, +}; + +pub struct PublicValuesGpuTraceGenerator; + +impl ModuleChip for PublicValuesGpuTraceGenerator { + type Ctx<'a> = (&'a [ProofGpu], &'a [PreflightGpu]); + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_proving_ctx( + &self, + ctx: &Self::Ctx<'_>, + height: Option, + ) -> Option> { + let (proofs_gpu, preflights_gpu) = ctx; + let mem = MemTracker::start("tracegen.public_values"); + debug_assert_eq!(proofs_gpu.len(), preflights_gpu.len()); + + let num_pvs = proofs_gpu[0].proof_shape.public_values.len(); + let num_valid_rows = proofs_gpu + .iter() + .map(|proof| proof.proof_shape.public_values.len()) + .sum::(); + + let height = if let Some(height) = height { + if height < num_valid_rows { + return None; + } + height + } else { + num_valid_rows.next_power_of_two() + }; + let width = PublicValuesCols::::width(); + let trace = DeviceMatrix::with_capacity(height, width); + + let pvs_data = proofs_gpu + .iter() + .map(|proof| proof.proof_shape.public_values.as_ptr()) + .collect::>(); + let pvs_tidx = preflights_gpu + .iter() + .map(|preflight| preflight.proof_shape.pvs_tidx.as_ptr()) + .collect::>(); + + unsafe { + public_values_tracegen( + trace.buffer(), + height, + pvs_data, + pvs_tidx, + proofs_gpu.len(), + num_pvs, + ) + .unwrap(); + } + mem.emit_metrics(); + Some(AirProvingContext::simple_no_pis(trace)) + } +} diff --git a/ceno_recursion_v2/src/proof_shape/pvs/mod.rs b/ceno_recursion_v2/src/proof_shape/pvs/mod.rs new file mode 100644 index 000000000..c53333434 --- /dev/null +++ b/ceno_recursion_v2/src/proof_shape/pvs/mod.rs @@ -0,0 +1,8 @@ +mod air; +mod trace; + +pub use air::*; +pub use trace::*; + +#[cfg(feature = "cuda")] +pub(crate) mod cuda; diff --git a/ceno_recursion_v2/src/proof_shape/pvs/trace.rs b/ceno_recursion_v2/src/proof_shape/pvs/trace.rs new file mode 100644 index 000000000..a979b2c68 --- /dev/null +++ b/ceno_recursion_v2/src/proof_shape/pvs/trace.rs @@ -0,0 +1,26 @@ +use openvm_stark_sdk::config::baby_bear_poseidon2::F; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::dense::RowMajorMatrix; + +use crate::{ + proof_shape::pvs::PublicValuesCols, + system::{Preflight, RecursionProof}, + tracegen::RowMajorChip, +}; + +pub struct PublicValuesTraceGenerator; + +impl RowMajorChip for PublicValuesTraceGenerator { + type Ctx<'a> = (&'a [RecursionProof], &'a [Preflight]); + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_trace( + &self, + _ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + let rows = required_height.unwrap_or(1).max(1); + let width = PublicValuesCols::::width(); + Some(RowMajorMatrix::new(vec![F::ZERO; rows * width], width)) + } +} diff --git a/ceno_recursion_v2/src/system/bus_inventory.rs b/ceno_recursion_v2/src/system/bus_inventory.rs new file mode 100644 index 000000000..568057e05 --- /dev/null +++ b/ceno_recursion_v2/src/system/bus_inventory.rs @@ -0,0 +1,107 @@ +use recursion_circuit::{ + bus::{ + AirShapeBus, FinalTranscriptStateBus, MerkleVerifyBus, Poseidon2CompressBus, + Poseidon2PermuteBus, XiRandomnessBus, + }, + primitives::bus::{ExpBitsLenBus, PowerCheckerBus, RangeCheckerBus, RightShiftBus}, + system::BusIndexManager, +}; + +use crate::bus::{ + CachedCommitBus as LocalCachedCommitBus, CommitmentsBus as LocalCommitmentsBus, + ExpressionClaimNMaxBus as LocalExpressionClaimNMaxBus, + FractionFolderInputBus as LocalFractionFolderInputBus, HyperdimBus as LocalHyperdimBus, + LiftedHeightsBus as LocalLiftedHeightsBus, MainBus, MainExpressionClaimBus, + MainSumcheckInputBus, MainSumcheckOutputBus, NLiftBus as LocalNLiftBus, + PublicValuesBus as LocalPublicValuesBus, TowerModuleBus, TranscriptBus as LocalTranscriptBus, +}; + +#[derive(Clone, Debug)] +pub struct BusInventory { + pub transcript_bus: LocalTranscriptBus, + pub poseidon2_permute_bus: Poseidon2PermuteBus, + pub poseidon2_compress_bus: Poseidon2CompressBus, + pub merkle_verify_bus: MerkleVerifyBus, + pub tower_module_bus: TowerModuleBus, + pub expression_claim_n_max_bus: LocalExpressionClaimNMaxBus, + pub fraction_folder_input_bus: LocalFractionFolderInputBus, + pub air_shape_bus: AirShapeBus, + pub hyperdim_bus: LocalHyperdimBus, + pub lifted_heights_bus: LocalLiftedHeightsBus, + pub commitments_bus: LocalCommitmentsBus, + pub n_lift_bus: LocalNLiftBus, + pub cached_commit_bus: LocalCachedCommitBus, + pub public_values_bus: LocalPublicValuesBus, + pub range_checker_bus: RangeCheckerBus, + pub power_checker_bus: PowerCheckerBus, + pub exp_bits_len_bus: ExpBitsLenBus, + pub main_bus: MainBus, + pub main_sumcheck_input_bus: MainSumcheckInputBus, + pub main_sumcheck_output_bus: MainSumcheckOutputBus, + pub main_expression_claim_bus: MainExpressionClaimBus, + pub right_shift_bus: RightShiftBus, + pub xi_randomness_bus: XiRandomnessBus, + pub final_state_bus: FinalTranscriptStateBus, +} + +impl BusInventory { + pub fn new(b: &mut BusIndexManager) -> Self { + let transcript_bus = LocalTranscriptBus::new(b.new_bus_idx()); + let poseidon2_permute_bus = Poseidon2PermuteBus::new(b.new_bus_idx()); + let poseidon2_compress_bus = Poseidon2CompressBus::new(b.new_bus_idx()); + let merkle_verify_bus = MerkleVerifyBus::new(b.new_bus_idx()); + + let gkr_bus_idx = b.new_bus_idx(); + let tower_module_bus = TowerModuleBus::new(gkr_bus_idx); + + let air_shape_bus = AirShapeBus::new(b.new_bus_idx()); + let hyperdim_bus = LocalHyperdimBus::new(b.new_bus_idx()); + let lifted_heights_bus = LocalLiftedHeightsBus::new(b.new_bus_idx()); + let commitments_bus = LocalCommitmentsBus::new(b.new_bus_idx()); + let public_values_bus = LocalPublicValuesBus::new(b.new_bus_idx()); + let range_checker_bus = RangeCheckerBus::new(b.new_bus_idx()); + let power_checker_bus = PowerCheckerBus::new(b.new_bus_idx()); + let expression_claim_n_max_bus = LocalExpressionClaimNMaxBus::new(b.new_bus_idx()); + let fraction_folder_input_bus = LocalFractionFolderInputBus::new(b.new_bus_idx()); + let n_lift_bus = LocalNLiftBus::new(b.new_bus_idx()); + + let xi_randomness_bus = XiRandomnessBus::new(b.new_bus_idx()); + + let exp_bits_len_bus = ExpBitsLenBus::new(b.new_bus_idx()); + let right_shift_bus = RightShiftBus::new(b.new_bus_idx()); + let main_bus = MainBus::new(b.new_bus_idx()); + let main_sumcheck_input_bus = MainSumcheckInputBus::new(b.new_bus_idx()); + let main_sumcheck_output_bus = MainSumcheckOutputBus::new(b.new_bus_idx()); + let main_expression_claim_bus = MainExpressionClaimBus::new(b.new_bus_idx()); + + let cached_commit_bus = LocalCachedCommitBus::new(b.new_bus_idx()); + let final_state_bus = FinalTranscriptStateBus::new(b.new_bus_idx()); + + Self { + transcript_bus, + poseidon2_permute_bus, + poseidon2_compress_bus, + merkle_verify_bus, + tower_module_bus: tower_module_bus, + expression_claim_n_max_bus, + fraction_folder_input_bus, + air_shape_bus, + hyperdim_bus, + lifted_heights_bus, + commitments_bus, + n_lift_bus, + cached_commit_bus, + public_values_bus, + range_checker_bus, + power_checker_bus, + exp_bits_len_bus, + main_bus, + main_sumcheck_input_bus, + main_sumcheck_output_bus, + main_expression_claim_bus, + right_shift_bus, + xi_randomness_bus, + final_state_bus, + } + } +} diff --git a/ceno_recursion_v2/src/system/frame.rs b/ceno_recursion_v2/src/system/frame.rs new file mode 100644 index 000000000..d35033bcf --- /dev/null +++ b/ceno_recursion_v2/src/system/frame.rs @@ -0,0 +1,50 @@ +use itertools::Itertools; +use openvm_stark_backend::{ + SystemParams, + keygen::types::{ + MultiStarkVerifyingKey, StarkVerifyingKey, StarkVerifyingParams, + VerifierSinglePreprocessedData, + }, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, Digest, F}; + +/// Frame-friendly versions of verifying key structures that strip non-deterministic fields. +/// Copied from upstream because the originals are `pub(crate)`; keeping them local avoids +/// changing ProofShape logic while still letting the fork build against private upstream APIs. +#[derive(Clone)] +pub struct StarkVkeyFrame { + pub preprocessed_data: Option>, + pub params: StarkVerifyingParams, + pub num_interactions: usize, + pub max_constraint_degree: u8, + pub is_required: bool, +} + +#[derive(Clone)] +pub struct MultiStarkVkeyFrame { + pub params: SystemParams, + pub per_air: Vec, + pub max_constraint_degree: usize, +} + +impl From<&StarkVerifyingKey> for StarkVkeyFrame { + fn from(vk: &StarkVerifyingKey) -> Self { + Self { + preprocessed_data: vk.preprocessed_data.clone(), + params: vk.params.clone(), + num_interactions: vk.num_interactions(), + max_constraint_degree: vk.max_constraint_degree, + is_required: vk.is_required, + } + } +} + +impl From<&MultiStarkVerifyingKey> for MultiStarkVkeyFrame { + fn from(mvk: &MultiStarkVerifyingKey) -> Self { + Self { + params: mvk.inner.params.clone(), + per_air: mvk.inner.per_air.iter().map(Into::into).collect_vec(), + max_constraint_degree: mvk.max_constraint_degree(), + } + } +} diff --git a/ceno_recursion_v2/src/system/mod.rs b/ceno_recursion_v2/src/system/mod.rs new file mode 100644 index 000000000..50b58c327 --- /dev/null +++ b/ceno_recursion_v2/src/system/mod.rs @@ -0,0 +1,522 @@ +pub mod frame; +mod preflight; +mod types; + +pub use crate::proof_shape::ProofShapeModule; +pub use preflight::{ + BatchConstraintPreflight, ChipTranscriptRange, MainPreflight, Preflight, ProofShapePreflight, + TowerChipTranscriptRange, TowerPreflight, TraceVData, +}; +pub use recursion_circuit::system::{ + AirModule, BusIndexManager, GlobalTraceGenCtx, TraceGenModule, VerifierConfig, + VerifierExternalData, +}; +mod bus_inventory; +pub mod utils; + +pub use bus_inventory::BusInventory; +pub use types::{ + RecursionField, RecursionPcs, RecursionProof, RecursionVk, convert_proof_from_zkvm, + convert_vk_from_zkvm, +}; + +use std::{iter, mem, sync::Arc}; + +use self::utils::test_system_params_zero_pow; +use crate::{ + batch_constraint::{self, BatchConstraintModule}, + main::MainModule, + tower::TowerModule, + transcript::TranscriptModule, +}; +use openvm_cpu_backend::CpuBackend; +use openvm_poseidon2_air::POSEIDON2_WIDTH; +use openvm_stark_backend::{ + AirRef, FiatShamirTranscript, StarkEngine, StarkProtocolConfig, TranscriptHistory, + interaction::BusIndex, + p3_maybe_rayon::prelude::*, + prover::{AirProvingContext, CommittedTraceData, ProverBackend}, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, F}; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::Matrix; +use recursion_circuit::primitives::{ + exp_bits_len::{ExpBitsLenAir, ExpBitsLenTraceGenerator}, + pow::{PowerCheckerAir, PowerCheckerCpuTraceGenerator}, +}; +use tracing::Span; + +pub const POW_CHECKER_HEIGHT: usize = 32; + +/// Local override of the upstream CPU tracegen context so modules accept ZKVM proofs. +pub struct GlobalCtxCpu; + +impl GlobalTraceGenCtx for GlobalCtxCpu { + type ChildVerifyingKey = RecursionVk; + type MultiProof = [RecursionProof]; + type PreflightRecords = [Preflight]; +} + +/// Local fork of AggregationSubCircuit so ceno modules depend on local BusInventory. +pub trait AggregationSubCircuit { + fn airs>(&self) -> Vec>; + + fn bus_inventory(&self) -> &BusInventory; + + fn next_bus_idx(&self) -> BusIndex; + + fn max_num_proofs(&self) -> usize; +} + +pub trait VerifierTraceGen> { + fn new(child_vk: Arc, config: VerifierConfig) -> Self; + + fn commit_child_vk>( + &self, + engine: &E, + child_vk: &RecursionVk, + ) -> CommittedTraceData; + + #[allow(clippy::ptr_arg)] + fn generate_proving_ctxs< + TS: FiatShamirTranscript + + TranscriptHistory, + >( + &self, + child_vk: &RecursionVk, + child_vk_pcs_data: CommittedTraceData, + proofs: &[RecursionProof], + external_data: &mut VerifierExternalData<'_>, + initial_transcript: TS, + ) -> Option>>; + + fn generate_proving_ctxs_base< + TS: FiatShamirTranscript + + TranscriptHistory, + >( + &self, + child_vk: &RecursionVk, + child_vk_pcs_data: CommittedTraceData, + proofs: &[RecursionProof], + initial_transcript: TS, + ) -> Vec> { + let poseidon2_compress_inputs = vec![]; + let poseidon2_permute_inputs = vec![]; + let range_check_inputs = vec![]; + + let mut external_data = VerifierExternalData { + poseidon2_compress_inputs: &poseidon2_compress_inputs, + poseidon2_permute_inputs: &poseidon2_permute_inputs, + range_check_inputs: &range_check_inputs, + required_heights: None, + final_transcript_state: None, + }; + + self.generate_proving_ctxs::( + child_vk, + child_vk_pcs_data, + proofs, + &mut external_data, + initial_transcript, + ) + .unwrap() + } +} + +/// The recursive verifier sub-circuit consists of multiple chips, grouped into **modules**. +/// +/// This struct is stateful. +pub struct VerifierSubCircuit { + pub(crate) bus_inventory: BusInventory, + pub(crate) bus_idx_manager: BusIndexManager, + pub(crate) transcript: TranscriptModule, + pub(crate) proof_shape: ProofShapeModule, + pub(crate) main_module: MainModule, + pub(crate) gkr: TowerModule, + pub(crate) batch_constraint: BatchConstraintModule, +} + +#[derive(Copy, Clone)] +enum TraceModuleRef<'a> { + Transcript(&'a TranscriptModule), + ProofShape(&'a ProofShapeModule), + Main(&'a MainModule), + Tower(&'a TowerModule), + BatchConstraint(&'a BatchConstraintModule), +} + +impl<'a> TraceModuleRef<'a> { + fn name(self) -> &'static str { + match self { + TraceModuleRef::Transcript(_) => "Transcript", + TraceModuleRef::ProofShape(_) => "ProofShape", + TraceModuleRef::Main(_) => "Main", + TraceModuleRef::Tower(_) => "Tower", + TraceModuleRef::BatchConstraint(_) => "BatchConstraint", + } + } + + #[tracing::instrument(name = "wrapper.run_preflight", level = "trace", skip_all)] + fn run_preflight( + self, + child_vk: &RecursionVk, + proof: &RecursionProof, + preflight: &mut Preflight, + sponge: &mut TS, + ) where + TS: FiatShamirTranscript + + TranscriptHistory, + { + match self { + TraceModuleRef::ProofShape(module) => { + module.run_preflight(child_vk, proof, preflight, sponge) + } + TraceModuleRef::Main(module) => { + module.run_preflight(child_vk, proof, preflight, sponge) + } + TraceModuleRef::Tower(module) => { + module.run_preflight(child_vk, proof, preflight, sponge) + } + TraceModuleRef::BatchConstraint(module) => { + module.run_preflight(child_vk, proof, preflight, sponge) + } + TraceModuleRef::Transcript(_) => { + panic!("Transcript module does not participate in preflight") + } + } + } + + #[allow(clippy::too_many_arguments)] + #[tracing::instrument(name = "wrapper.generate_proving_ctxs", level = "trace", skip_all)] + fn generate_cpu_ctxs>( + self, + child_vk: &RecursionVk, + proofs: &[RecursionProof], + preflights: &[Preflight], + pow_checker_gen: &Arc>, + exp_bits_len_gen: &ExpBitsLenTraceGenerator, + external_data: &VerifierExternalData<'_>, + required_heights: Option<&[usize]>, + ) -> Option>>> { + match self { + TraceModuleRef::Transcript(module) => module.generate_proving_ctxs( + child_vk, + proofs, + preflights, + &( + external_data.poseidon2_permute_inputs.as_slice(), + external_data.poseidon2_compress_inputs.as_slice(), + ), + required_heights, + ), + TraceModuleRef::ProofShape(module) => module.generate_proving_ctxs( + child_vk, + proofs, + preflights, + &( + pow_checker_gen.clone(), + external_data.range_check_inputs.as_slice(), + ), + required_heights, + ), + TraceModuleRef::Main(module) => { + module.generate_proving_ctxs(child_vk, proofs, preflights, &(), required_heights) + } + TraceModuleRef::Tower(module) => module.generate_proving_ctxs( + child_vk, + proofs, + preflights, + exp_bits_len_gen, + required_heights, + ), + TraceModuleRef::BatchConstraint(module) => { + module.generate_proving_ctxs(child_vk, proofs, preflights, &(), required_heights) + } + } + } +} + +impl VerifierSubCircuit { + pub fn new(child_vk: Arc) -> Self { + Self::new_with_options(child_vk, VerifierConfig::default()) + } + + pub fn new_with_options(child_vk: Arc, config: VerifierConfig) -> Self { + // let child_mvk = convert_vk_from_zkvm(child_vk.as_ref()); + // let proof_shape_constraint = LinearConstraint { + // coefficients: child_mvk + // .inner + // .per_air + // .iter() + // .map(|avk| avk.num_interactions() as u32) + // .collect(), + // threshold: child_mvk.inner.params.logup.max_interaction_count, + // }; + // for (i, constraint) in child_mvk.inner.trace_height_constraints.iter().enumerate() { + // assert!( + // constraint.is_implied_by(&proof_shape_constraint), + // "child_vk trace_height_constraint[{i}] is not implied by ProofShapeAir's check. \ + // The recursion circuit cannot enforce this constraint. \ + // Constraint: coefficients={:?}, threshold={}", + // constraint.coefficients, + // constraint.threshold, + // ); + // } + + let mut bus_idx_manager = BusIndexManager::new(); + let bus_inventory = BusInventory::new(&mut bus_idx_manager); + let system_params = test_system_params_zero_pow(2, 8, 3); + + let transcript = TranscriptModule::new( + bus_inventory.clone(), + system_params, + config.final_state_bus_enabled, + ); + let proof_shape = ProofShapeModule::new( + child_vk.as_ref(), + &mut bus_idx_manager, + bus_inventory.clone(), + config.continuations_enabled, + ); + let main_module = MainModule::new(&mut bus_idx_manager, bus_inventory.clone()); + let gkr = TowerModule::new( + child_vk.as_ref(), + &mut bus_idx_manager, + bus_inventory.clone(), + ); + let batch_constraint = BatchConstraintModule::new( + &mut bus_idx_manager, + bus_inventory.clone(), + MAX_NUM_PROOFS, + ); + + VerifierSubCircuit { + bus_inventory, + bus_idx_manager, + transcript, + proof_shape, + main_module, + gkr, + batch_constraint, + } + } + + /// Runs preflight for a single proof. + #[tracing::instrument(name = "execute_preflight", skip_all)] + fn run_preflight( + &self, + mut sponge: TS, + child_vk: &RecursionVk, + proof: &RecursionProof, + ) -> Preflight + where + TS: FiatShamirTranscript + + TranscriptHistory, + { + let mut preflight = Preflight::default(); + let modules = [ + TraceModuleRef::ProofShape(&self.proof_shape), + TraceModuleRef::Tower(&self.gkr), + TraceModuleRef::Main(&self.main_module), + // TODO(batch-constraint): uncomment after fixing SymbolicExpressionAir trace/preprocessed + // shape assumptions that currently trigger SymbolicEvaluator OOB in release tests. + // TraceModuleRef::BatchConstraint(&self.batch_constraint), + ]; + for module in modules { + module.run_preflight(child_vk, proof, &mut preflight, &mut sponge); + } + preflight.transcript = sponge.into_log(); + preflight + } + + #[allow(clippy::type_complexity)] + fn split_required_heights<'a>( + &self, + required_heights: Option<&'a [usize]>, + ) -> (Vec>, Option, Option) { + let t_n = self.transcript.num_airs(); + let ps_n = self.proof_shape.num_airs(); + let gkr_n = self.gkr.num_airs(); + let main_n = self.main_module.num_airs(); + let module_air_counts = [t_n, ps_n, gkr_n, main_n]; + + let Some(heights) = required_heights else { + return (vec![None; module_air_counts.len()], None, None); + }; + + let total_module_airs: usize = module_air_counts.iter().sum(); + let total = total_module_airs + 2; + assert_eq!(heights.len(), total); + + let mut offset = 0usize; + let mut per_module = Vec::with_capacity(module_air_counts.len()); + for n in module_air_counts { + per_module.push(Some(&heights[offset..offset + n])); + offset += n; + } + debug_assert_eq!(heights.len() - offset, 2); + + (per_module, Some(heights[offset]), Some(heights[offset + 1])) + } +} + +impl, const MAX_NUM_PROOFS: usize> + VerifierTraceGen, SC> for VerifierSubCircuit +{ + fn new(child_vk: Arc, config: VerifierConfig) -> Self { + Self::new_with_options(child_vk, config) + } + + fn commit_child_vk>>( + &self, + engine: &E, + child_vk: &RecursionVk, + ) -> CommittedTraceData> { + batch_constraint::commit_child_vk(engine, child_vk) + } + + #[tracing::instrument(name = "subcircuit_generate_proving_ctxs", skip_all)] + fn generate_proving_ctxs< + TS: FiatShamirTranscript + + TranscriptHistory, + >( + &self, + child_vk: &RecursionVk, + _child_vk_pcs_data: CommittedTraceData>, + proofs: &[RecursionProof], + external_data: &mut VerifierExternalData<'_>, + initial_transcript: TS, + ) -> Option>>> { + debug_assert!(proofs.len() <= MAX_NUM_PROOFS); + + let span = Span::current(); + let child_vk_recursion = child_vk; + let this = self; + let preflights = std::thread::scope(|s| { + let handles: Vec<_> = proofs + .iter() + .map(|zk_proof| { + let child_vk = child_vk_recursion; + let sponge = initial_transcript.clone(); + let span = span.clone(); + s.spawn(move || { + let _guard = span.enter(); + this.run_preflight(sponge, child_vk, zk_proof) + }) + }) + .collect(); + handles + .into_iter() + .map(|h| h.join().unwrap()) + .collect::>() + }); + + if let Some(final_transcript_state) = &mut external_data.final_transcript_state { + final_transcript_state.fill(F::ZERO); + } + + let power_checker_gen = + Arc::new(PowerCheckerCpuTraceGenerator::<2, POW_CHECKER_HEIGHT>::default()); + let exp_bits_len_gen = ExpBitsLenTraceGenerator::default(); + + let (module_required, power_checker_required, exp_bits_len_required) = + self.split_required_heights(external_data.required_heights); + + let modules = [ + TraceModuleRef::Transcript(&self.transcript), + TraceModuleRef::ProofShape(&self.proof_shape), + TraceModuleRef::Tower(&self.gkr), + TraceModuleRef::Main(&self.main_module), + // TODO(batch-constraint): re-enable once batch tracegen/preflight alignment is fixed. + // TraceModuleRef::BatchConstraint(&self.batch_constraint), + ]; + + let span = Span::current(); + let ctxs_by_module = modules + .into_par_iter() + .zip(module_required) + .map(|(module, required_heights)| { + let _guard = span.enter(); + module.generate_cpu_ctxs( + child_vk, + proofs, + &preflights, + &power_checker_gen, + &exp_bits_len_gen, + external_data, + required_heights, + ) + }) + .collect::>(); + + for (module, module_ctxs) in modules.into_iter().zip(ctxs_by_module.iter()) { + if module_ctxs.is_none() { + eprintln!( + "subcircuit_generate_proving_ctxs: module {} returned None", + module.name() + ); + } + } + + let ctxs_by_module: Vec>>> = + ctxs_by_module.into_iter().collect::>>()?; + let mut ctx_per_trace = ctxs_by_module.into_iter().flatten().collect::>(); + if power_checker_required.is_some_and(|h| h != POW_CHECKER_HEIGHT) { + return None; + } + let power_height = power_checker_required.unwrap_or(POW_CHECKER_HEIGHT); + let power_trace = power_checker_gen.generate_trace_row_major(); + if power_trace.height() != power_height { + return None; + } + ctx_per_trace.push(AirProvingContext::simple_no_pis(power_trace)); + + let exp_bits_height = exp_bits_len_required; + let exp_bits_trace = exp_bits_len_gen.generate_trace_row_major(exp_bits_height)?; + ctx_per_trace.push(AirProvingContext::simple_no_pis(exp_bits_trace)); + Some(ctx_per_trace) + } +} + +fn peek_bus_idx(manager: &BusIndexManager) -> BusIndex { + // SAFETY: BusIndexManager is currently a transparent wrapper around a single BusIndex field. + unsafe { mem::transmute::(manager.clone()) } +} + +impl AggregationSubCircuit for VerifierSubCircuit { + fn airs>(&self) -> Vec> { + let exp_bits_len_air = ExpBitsLenAir::new( + self.bus_inventory.exp_bits_len_bus, + self.bus_inventory.right_shift_bus, + ); + let power_checker_air = PowerCheckerAir::<2, POW_CHECKER_HEIGHT> { + pow_bus: self.bus_inventory.power_checker_bus, + range_bus: self.bus_inventory.range_checker_bus, + }; + + iter::empty() + .chain(self.transcript.airs()) + .chain(self.proof_shape.airs()) + .chain(self.gkr.airs()) + .chain(self.main_module.airs()) + // TODO(batch-constraint): re-chain batch AIRs after BatchConstraintModule is stable. + // .chain(self.batch_constraint.airs()) + .chain([ + Arc::new(power_checker_air) as AirRef<_>, + Arc::new(exp_bits_len_air) as AirRef<_>, + ]) + .collect() + } + + fn bus_inventory(&self) -> &BusInventory { + &self.bus_inventory + } + + fn next_bus_idx(&self) -> BusIndex { + peek_bus_idx(&self.bus_idx_manager) + } + + fn max_num_proofs(&self) -> usize { + MAX_NUM_PROOFS + } +} diff --git a/ceno_recursion_v2/src/system/preflight/mod.rs b/ceno_recursion_v2/src/system/preflight/mod.rs new file mode 100644 index 000000000..25c193633 --- /dev/null +++ b/ceno_recursion_v2/src/system/preflight/mod.rs @@ -0,0 +1,68 @@ +use openvm_poseidon2_air::POSEIDON2_WIDTH; +use openvm_stark_backend::TranscriptLog; +use openvm_stark_sdk::config::baby_bear_poseidon2::{EF, F}; + +use crate::tower::TowerReplayResult; + +/// Placeholder types mirroring upstream recursion preflight records. +/// These will be populated with real transcript metadata once the +/// ZKVM bridge is fully implemented. +#[derive(Clone, Debug, Default)] +pub struct Preflight { + pub transcript: TranscriptLog, + pub proof_shape: ProofShapePreflight, + pub main: MainPreflight, + pub gkr: TowerPreflight, + pub batch_constraint: BatchConstraintPreflight, +} + +#[derive(Clone, Debug, Default)] +pub struct ProofShapePreflight { + pub sorted_trace_vdata: Vec<(usize, TraceVData)>, + pub l_skip: usize, + pub fork_start_tidx: usize, + pub alpha_tidx: usize, + pub beta_tidx: usize, +} + +#[derive(Clone, Debug, Default)] +pub struct TraceVData { + pub log_height: usize, +} + +#[derive(Clone, Debug, Default)] +pub struct MainPreflight { + pub chips: Vec, +} + +#[derive(Clone, Debug, Default)] +pub struct TowerPreflight { + pub chips: Vec, +} + +#[derive(Clone, Debug, Default)] +pub struct TowerChipTranscriptRange { + pub chip_idx: usize, + pub instance_idx: usize, + pub tidx: usize, + pub tower_replay: TowerReplayResult, +} + +#[derive(Clone, Debug, Default)] +pub struct BatchConstraintPreflight { + pub lambda_tidx: usize, + pub tidx_before_univariate: usize, + pub sumcheck_rnd: Vec, + pub eq_ns_frontloaded: Vec, + pub eq_sharp_ns_frontloaded: Vec, +} + +#[derive(Clone, Debug, Default)] +pub struct ChipTranscriptRange { + pub chip_idx: usize, + pub instance_idx: usize, + pub tidx: usize, +} + +#[allow(dead_code)] +pub type PoseidonWord = [F; POSEIDON2_WIDTH]; diff --git a/ceno_recursion_v2/src/system/types.rs b/ceno_recursion_v2/src/system/types.rs new file mode 100644 index 000000000..99b11cc6b --- /dev/null +++ b/ceno_recursion_v2/src/system/types.rs @@ -0,0 +1,22 @@ +use std::sync::Arc; + +use ceno_zkvm::{scheme::ZKVMProof, structs::ZKVMVerifyingKey}; +use ff_ext::BabyBearExt4; +use mpcs::{Basefold, BasefoldRSParams}; +use openvm_stark_backend::{keygen::types::MultiStarkVerifyingKey, proof::Proof}; +use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config; + +pub type RecursionField = BabyBearExt4; +pub type RecursionPcs = Basefold; +pub type RecursionVk = ZKVMVerifyingKey; +pub type RecursionProof = ZKVMProof; + +pub fn convert_proof_from_zkvm(_proof: &RecursionProof) -> Proof { + unimplemented!("Bridge ZKVMProof -> Proof conversion"); +} + +pub fn convert_vk_from_zkvm( + _vk: &RecursionVk, +) -> Arc> { + unimplemented!("Bridge ZKVMVerifyingKey -> MultiStarkVerifyingKey conversion"); +} diff --git a/ceno_recursion_v2/src/system/utils.rs b/ceno_recursion_v2/src/system/utils.rs new file mode 100644 index 000000000..655ef2956 --- /dev/null +++ b/ceno_recursion_v2/src/system/utils.rs @@ -0,0 +1,64 @@ +use openvm_stark_backend::{ + SystemParams, WhirConfig, WhirParams, WhirProximityStrategy, + interaction::LogUpSecurityParameters, +}; + +fn test_whir_config_small( + log_blowup: usize, + log_stacked_height: usize, + k_whir: usize, + log_final_poly_len: usize, +) -> WhirConfig { + let params = WhirParams { + k: k_whir, + log_final_poly_len, + query_phase_pow_bits: 1, + folding_pow_bits: 2, + mu_pow_bits: 3, + proximity: WhirProximityStrategy::SplitUniqueList { + m: 3, + list_start_round: 1, + }, + }; + let security_bits = 5; + WhirConfig::new(log_blowup, log_stacked_height, params, security_bits) +} + +/// Trace heights cannot exceed `2^{l_skip + n_stack}` and stacked cells cannot exceed +/// `w_stack * 2^{l_skip + n_stack}` when using these system params. +fn test_system_params_small(l_skip: usize, n_stack: usize, k_whir: usize) -> SystemParams { + let log_final_poly_len = (n_stack + l_skip) % k_whir; + test_system_params_small_with_poly_len(l_skip, n_stack, k_whir, log_final_poly_len, 5) +} + +pub fn test_system_params_zero_pow(l_skip: usize, n_stack: usize, k_whir: usize) -> SystemParams { + let mut params = test_system_params_small(l_skip, n_stack, k_whir); + params.whir.mu_pow_bits = 0; + params.whir.folding_pow_bits = 0; + params.whir.query_phase_pow_bits = 0; + params +} + +fn test_system_params_small_with_poly_len( + l_skip: usize, + n_stack: usize, + k_whir: usize, + log_final_poly_len: usize, + max_constraint_degree: usize, +) -> SystemParams { + assert!(log_final_poly_len < l_skip + n_stack); + let log_blowup = 1; + SystemParams { + l_skip, + n_stack, + w_stack: 1 << 12, + log_blowup, + whir: test_whir_config_small(log_blowup, l_skip + n_stack, k_whir, log_final_poly_len), + logup: LogUpSecurityParameters { + max_interaction_count: 1 << 30, + log_max_message_length: 7, + pow_bits: 2, + }, + max_constraint_degree, + } +} diff --git a/ceno_recursion_v2/src/tower/bus.rs b/ceno_recursion_v2/src/tower/bus.rs new file mode 100644 index 000000000..ab41c3c30 --- /dev/null +++ b/ceno_recursion_v2/src/tower/bus.rs @@ -0,0 +1,146 @@ +use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; +use stark_recursion_circuit_derive::AlignedBorrow; + +use crate::define_typed_per_proof_permutation_bus; + +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerXiSamplerMessage { + pub idx: T, + pub tidx: T, +} + +define_typed_per_proof_permutation_bus!(TowerXiSamplerBus, TowerXiSamplerMessage); + +/// Message sent from TowerInputAir to TowerLayerAir +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerLayerInputMessage { + pub idx: T, + pub tidx: T, + pub r0_claim: [T; D_EF], + pub w0_claim: [T; D_EF], + pub q0_claim: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(TowerLayerInputBus, TowerLayerInputMessage); + +/// Message sent from TowerInputAir to TowerLayerAir +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerLayerOutputMessage { + pub idx: T, + pub tidx: T, + pub layer_idx_end: T, + pub input_layer_claim: [T; D_EF], + pub lambda: [T; D_EF], + pub mu: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(TowerLayerOutputBus, TowerLayerOutputMessage); + +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerProdLayerChallengeMessage { + pub idx: T, + pub layer_idx: T, + pub tidx: T, + pub lambda: [T; D_EF], + pub lambda_prime: [T; D_EF], + pub mu: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(TowerProdReadClaimInputBus, TowerProdLayerChallengeMessage); +define_typed_per_proof_permutation_bus!( + TowerProdWriteClaimInputBus, + TowerProdLayerChallengeMessage +); + +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerProdSumClaimMessage { + pub idx: T, + pub layer_idx: T, + pub lambda_claim: [T; D_EF], + pub lambda_prime_claim: [T; D_EF], + pub num_prod_count: T, +} + +define_typed_per_proof_permutation_bus!(TowerProdReadClaimBus, TowerProdSumClaimMessage); +define_typed_per_proof_permutation_bus!(TowerProdWriteClaimBus, TowerProdSumClaimMessage); + +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerLogupLayerChallengeMessage { + pub idx: T, + pub layer_idx: T, + pub tidx: T, + pub lambda: [T; D_EF], + pub lambda_prime: [T; D_EF], + pub mu: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(TowerLogupClaimInputBus, TowerLogupLayerChallengeMessage); + +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerLogupClaimMessage { + pub idx: T, + pub layer_idx: T, + pub lambda_claim: [T; D_EF], + pub lambda_prime_claim: [T; D_EF], + pub num_logup_count: T, +} + +define_typed_per_proof_permutation_bus!(TowerLogupClaimBus, TowerLogupClaimMessage); + +/// Message sent from TowerLayerAir to TowerLayerSumcheckAir +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerSumcheckInputMessage { + /// Module index within the proof + pub idx: T, + /// GKR layer index + pub layer_idx: T, + pub is_last_layer: T, + /// Transcript index for sumcheck + pub tidx: T, + /// Combined claim to verify + pub claim: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(TowerSumcheckInputBus, TowerSumcheckInputMessage); + +/// Message sent from TowerLayerSumcheckAir to TowerLayerAir +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerSumcheckOutputMessage { + /// Module index within the proof + pub idx: T, + /// GKR layer index + pub layer_idx: T, + /// Transcript index after sumcheck + pub tidx: T, + /// New claim after sumcheck + pub claim_out: [T; D_EF], + /// Equality polynomial evaluation at r' + pub eq_at_r_prime: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(TowerSumcheckOutputBus, TowerSumcheckOutputMessage); + +/// Message for passing challenges between consecutive sumcheck sub-rounds +#[repr(C)] +#[derive(AlignedBorrow, Debug, Clone)] +pub struct TowerSumcheckChallengeMessage { + /// Module index within the proof + pub idx: T, + /// GKR layer index + pub layer_idx: T, + /// Sumcheck round number + pub sumcheck_round: T, + /// The challenge value + pub challenge: [T; D_EF], +} + +define_typed_per_proof_permutation_bus!(TowerSumcheckChallengeBus, TowerSumcheckChallengeMessage); diff --git a/ceno_recursion_v2/src/tower/input/air.rs b/ceno_recursion_v2/src/tower/input/air.rs new file mode 100644 index 000000000..c481f2591 --- /dev/null +++ b/ceno_recursion_v2/src/tower/input/air.rs @@ -0,0 +1,238 @@ +use core::borrow::Borrow; + +use crate::{ + bus::{MainBus, MainMessage, TowerModuleBus, TowerModuleMessage, TranscriptBus}, + tower::bus::{ + TowerLayerInputBus, TowerLayerInputMessage, TowerLayerOutputBus, TowerLayerOutputMessage, + }, +}; +use openvm_circuit_primitives::{ + SubAir, + is_zero::{IsZeroAuxCols, IsZeroIo, IsZeroSubAir}, + utils::not, +}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{Field, PrimeCharacteristicRing}; +use p3_matrix::Matrix; +use recursion_circuit::{ + subairs::proof_idx::{ProofIdxIoCols, ProofIdxSubAir}, + utils::assert_zeros, +}; +use stark_recursion_circuit_derive::AlignedBorrow; + +#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct TowerInputCols { + /// Whether the current row is enabled (i.e. not padding) + pub is_enabled: T, + + pub proof_idx: T, + pub idx: T, + + pub n_logup: T, + + /// Flag indicating whether there are any interactions + /// n_logup = 0 <=> total_interactions = 0 + pub is_n_logup_zero: T, + pub is_n_logup_zero_aux: IsZeroAuxCols, + + /// Transcript index + pub tidx: T, + + pub r0_claim: [T; D_EF], + pub w0_claim: [T; D_EF], + /// Root denominator claim + pub q0_claim: [T; D_EF], + + pub alpha_logup: [T; D_EF], + + pub input_layer_claim: [T; D_EF], + pub layer_output_lambda: [T; D_EF], + pub layer_output_mu: [T; D_EF], +} + +/// The TowerInputAir handles reading and passing the TowerInput +pub struct TowerInputAir { + // Buses + pub tower_module_bus: TowerModuleBus, + pub main_bus: MainBus, + pub transcript_bus: TranscriptBus, + pub layer_input_bus: TowerLayerInputBus, + pub layer_output_bus: TowerLayerOutputBus, +} + +impl BaseAir for TowerInputAir { + fn width(&self) -> usize { + TowerInputCols::::width() + } +} + +impl BaseAirWithPublicValues for TowerInputAir {} +impl PartitionedBaseAir for TowerInputAir {} + +impl Air for TowerInputAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = ( + main.row_slice(0).expect("window should have two elements"), + main.row_slice(1).expect("window should have two elements"), + ); + let local: &TowerInputCols = (*local).borrow(); + let next: &TowerInputCols = (*next).borrow(); + + /////////////////////////////////////////////////////////////////////// + // Proof Index Constraints + /////////////////////////////////////////////////////////////////////// + + // This subair has the following constraints: + // 1. Boolean enabled flag + // 2. Disabled rows are followed by disabled rows + // 3. Proof index increments by exactly one between enabled rows + ProofIdxSubAir.eval( + builder, + ( + ProofIdxIoCols { + is_enabled: local.is_enabled, + proof_idx: local.proof_idx, + } + .map_into(), + ProofIdxIoCols { + is_enabled: next.is_enabled, + proof_idx: next.proof_idx, + } + .map_into(), + ), + ); + + /////////////////////////////////////////////////////////////////////// + // Base Constraints + /////////////////////////////////////////////////////////////////////// + + // 1. Check if n_logup is zero (no logup constraints needed) + IsZeroSubAir.eval( + builder, + ( + IsZeroIo::new( + local.n_logup.into(), + local.is_n_logup_zero.into(), + local.is_enabled.into(), + ), + local.is_n_logup_zero_aux.inv, + ), + ); + + /////////////////////////////////////////////////////////////////////// + // Output Constraints + /////////////////////////////////////////////////////////////////////// + + let has_interactions = AB::Expr::ONE - local.is_n_logup_zero; + // Input layer claim defaults to zero when no interactions + assert_zeros( + &mut builder.when(not::(has_interactions.clone())), + local.input_layer_claim, + ); + assert_zeros( + &mut builder.when(not::(has_interactions.clone())), + local.layer_output_lambda, + ); + assert_zeros( + &mut builder.when(not::(has_interactions.clone())), + local.layer_output_mu, + ); + + /////////////////////////////////////////////////////////////////////// + // Module Interactions + /////////////////////////////////////////////////////////////////////// + + let num_layers = local.n_logup; + + // Add PoW (if any) and alpha, beta + let tidx_after_alpha_beta = local.tidx + AB::Expr::from_usize(2 * D_EF); + // Add GKR layers + Sumcheck + let tidx_after_gkr_layers = tidx_after_alpha_beta.clone() + + has_interactions.clone() + * num_layers.clone() + * (num_layers.clone() + AB::Expr::TWO) + * AB::Expr::from_usize(2 * D_EF); + // 1. TowerLayerInputBus + // 1a. Send input to TowerLayerAir + self.layer_input_bus.send( + builder, + local.proof_idx, + TowerLayerInputMessage { + idx: local.idx.into(), + // Skip q0_claim + tidx: (tidx_after_alpha_beta + AB::Expr::from_usize(D_EF)) + * has_interactions.clone(), + r0_claim: local.r0_claim.map(Into::into), + w0_claim: local.w0_claim.map(Into::into), + q0_claim: local.q0_claim.map(Into::into), + }, + local.is_enabled * has_interactions.clone(), + ); + // 2. TowerLayerOutputBus + // 2a. Receive input layer claim from TowerLayerAir + self.layer_output_bus.receive( + builder, + local.proof_idx, + TowerLayerOutputMessage { + idx: local.idx.into(), + tidx: tidx_after_gkr_layers.clone(), + layer_idx_end: num_layers.clone() - AB::Expr::ONE, + input_layer_claim: local.input_layer_claim.map(Into::into), + lambda: local.layer_output_lambda.map(Into::into), + mu: local.layer_output_mu.map(Into::into), + }, + local.is_enabled * has_interactions.clone(), + ); + /////////////////////////////////////////////////////////////////////// + // External Interactions + /////////////////////////////////////////////////////////////////////// + + // 1. TowerModuleBus + // 1a. Receive initial GKR module message on first layer + self.tower_module_bus.receive( + builder, + local.proof_idx, + TowerModuleMessage { + idx: local.idx.into(), + tidx: local.tidx.into(), + n_logup: local.n_logup.into(), + }, + local.is_enabled, + ); + + // 2. TranscriptBus + // 2a. Sample alpha_logup challenge + self.transcript_bus.sample_ext( + builder, + local.proof_idx, + local.tidx, + local.alpha_logup.map(Into::into), + local.is_enabled, + ); + // 2b. Observe `q0_claim` claim + self.transcript_bus.observe_ext( + builder, + local.proof_idx, + local.tidx + AB::Expr::from_usize(2 * D_EF), + local.q0_claim, + local.is_enabled * has_interactions.clone(), + ); + + self.main_bus.send( + builder, + local.proof_idx, + MainMessage { + idx: local.idx.into(), + tidx: tidx_after_gkr_layers.clone(), + claim: local.input_layer_claim.map(Into::into), + }, + local.is_enabled * has_interactions, + ); + } +} diff --git a/ceno_recursion_v2/src/tower/input/mod.rs b/ceno_recursion_v2/src/tower/input/mod.rs new file mode 100644 index 000000000..d42282ada --- /dev/null +++ b/ceno_recursion_v2/src/tower/input/mod.rs @@ -0,0 +1,5 @@ +mod air; +mod trace; + +pub use air::{TowerInputAir, TowerInputCols}; +pub use trace::{TowerInputRecord, TowerInputTraceGenerator}; diff --git a/ceno_recursion_v2/src/tower/input/trace.rs b/ceno_recursion_v2/src/tower/input/trace.rs new file mode 100644 index 000000000..7ff45f6ac --- /dev/null +++ b/ceno_recursion_v2/src/tower/input/trace.rs @@ -0,0 +1,89 @@ +use core::borrow::BorrowMut; + +use super::TowerInputCols; +use crate::tracegen::RowMajorChip; +use openvm_circuit_primitives::{TraceSubRowGenerator, is_zero::IsZeroSubAir}; +use openvm_stark_backend::p3_maybe_rayon::prelude::*; +use openvm_stark_sdk::config::baby_bear_poseidon2::{EF, F}; +use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; +use p3_matrix::dense::RowMajorMatrix; + +#[derive(Debug, Clone, Default)] +pub struct TowerInputRecord { + pub proof_idx: usize, + pub idx: usize, + pub tidx: usize, + pub n_logup: usize, + pub alpha_logup: EF, + pub input_layer_claim: EF, +} + +pub struct TowerInputTraceGenerator; + +impl RowMajorChip for TowerInputTraceGenerator { + // (gkr_input_records, q0_claims) + type Ctx<'a> = (&'a [TowerInputRecord], &'a [EF]); + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_trace( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + let (gkr_input_records, q0_claims) = ctx; + debug_assert_eq!(gkr_input_records.len(), q0_claims.len()); + + let width = TowerInputCols::::width(); + + // Each record generates exactly 1 row + let num_valid_rows = gkr_input_records.len(); + let height = if let Some(height) = required_height { + if height < num_valid_rows { + return None; + } + height + } else { + num_valid_rows.next_power_of_two() + }; + let mut trace = vec![F::ZERO; height * width]; + + let (data_slice, _) = trace.split_at_mut(num_valid_rows * width); + + // Process each proof row + data_slice + .par_chunks_mut(width) + .zip(gkr_input_records.par_iter().zip(q0_claims.par_iter())) + .for_each(|(row_data, (record, q0_claim))| { + let cols: &mut TowerInputCols = row_data.borrow_mut(); + + cols.is_enabled = F::ONE; + cols.proof_idx = F::from_usize(record.proof_idx); + cols.idx = F::from_usize(record.idx); + + cols.tidx = F::from_usize(record.tidx); + + cols.n_logup = F::from_usize(record.n_logup); + IsZeroSubAir.generate_subrow( + cols.n_logup, + (&mut cols.is_n_logup_zero_aux.inv, &mut cols.is_n_logup_zero), + ); + + let q0_basis = q0_claim.as_basis_coefficients_slice(); + cols.r0_claim.copy_from_slice(q0_basis); + cols.w0_claim.copy_from_slice(q0_basis); + cols.q0_claim.copy_from_slice(q0_basis); + cols.alpha_logup = record + .alpha_logup + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.input_layer_claim = record + .input_layer_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + }); + + Some(RowMajorMatrix::new(trace, width)) + } +} diff --git a/ceno_recursion_v2/src/tower/layer/air.rs b/ceno_recursion_v2/src/tower/layer/air.rs new file mode 100644 index 000000000..b591b0d86 --- /dev/null +++ b/ceno_recursion_v2/src/tower/layer/air.rs @@ -0,0 +1,456 @@ +use core::borrow::Borrow; + +use openvm_circuit_primitives::{SubAir, utils::assert_array_eq}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{Field, PrimeCharacteristicRing, extension::BinomiallyExtendable}; +use p3_matrix::Matrix; +use stark_recursion_circuit_derive::AlignedBorrow; + +use crate::{ + bus::{AirShapeBus, AirShapeBusMessage}, + proof_shape::bus::AirShapeProperty, + tower::{ + TowerSumcheckChallengeBus, TowerSumcheckChallengeMessage, + bus::{ + TowerLayerInputBus, TowerLayerInputMessage, TowerLayerOutputBus, + TowerLayerOutputMessage, TowerLogupClaimBus, TowerLogupClaimInputBus, + TowerLogupClaimMessage, TowerLogupLayerChallengeMessage, + TowerProdLayerChallengeMessage, TowerProdReadClaimBus, TowerProdReadClaimInputBus, + TowerProdSumClaimMessage, TowerProdWriteClaimBus, TowerProdWriteClaimInputBus, + TowerSumcheckInputBus, TowerSumcheckInputMessage, TowerSumcheckOutputBus, + TowerSumcheckOutputMessage, + }, + }, +}; + +use recursion_circuit::{ + bus::TranscriptBus, + subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}, + utils::{assert_zeros, ext_field_add, ext_field_multiply}, +}; + +#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct TowerLayerCols { + /// Whether the current row is enabled (i.e. not padding) + pub is_enabled: T, + pub proof_idx: T, + pub idx: T, + pub is_first_air_idx: T, + pub is_first: T, + + /// An enabled row which is not involved in any interactions + /// but should satisfy air constraints + pub is_dummy: T, + + /// GKR layer index + pub layer_idx: T, + + /// Transcript index at the start of this layer + pub tidx: T, + + /// Sampled batching challenge + pub lambda: [T; D_EF], + /// Challenge inherited from previous layer + pub lambda_prime: [T; D_EF], + /// Reduction point + pub mu: [T; D_EF], + + pub sumcheck_claim_in: [T; D_EF], + + pub read_claim: [T; D_EF], + pub read_claim_prime: [T; D_EF], + pub write_claim: [T; D_EF], + pub write_claim_prime: [T; D_EF], + pub logup_claim: [T; D_EF], + pub logup_claim_prime: [T; D_EF], + pub num_read_count: T, + pub num_write_count: T, + pub num_logup_count: T, + + /// Received from TowerLayerSumcheckAir + pub eq_at_r_prime: [T; D_EF], + + pub r0_claim: [T; D_EF], + pub w0_claim: [T; D_EF], + pub q0_claim: [T; D_EF], +} + +/// The TowerLayerAir handles layer-to-layer transitions in the GKR protocol +pub struct TowerLayerAir { + // External buses + pub transcript_bus: TranscriptBus, + pub air_shape_bus: AirShapeBus, + // Internal buses + pub layer_input_bus: TowerLayerInputBus, + pub layer_output_bus: TowerLayerOutputBus, + pub sumcheck_input_bus: TowerSumcheckInputBus, + pub sumcheck_output_bus: TowerSumcheckOutputBus, + pub sumcheck_challenge_bus: TowerSumcheckChallengeBus, + pub prod_read_claim_input_bus: TowerProdReadClaimInputBus, + pub prod_read_claim_bus: TowerProdReadClaimBus, + pub prod_write_claim_input_bus: TowerProdWriteClaimInputBus, + pub prod_write_claim_bus: TowerProdWriteClaimBus, + pub logup_claim_input_bus: TowerLogupClaimInputBus, + pub logup_claim_bus: TowerLogupClaimBus, +} + +impl BaseAir for TowerLayerAir { + fn width(&self) -> usize { + TowerLayerCols::::width() + } +} + +impl BaseAirWithPublicValues for TowerLayerAir {} +impl PartitionedBaseAir for TowerLayerAir {} + +impl Air for TowerLayerAir +where + ::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = ( + main.row_slice(0).expect("window should have two elements"), + main.row_slice(1).expect("window should have two elements"), + ); + let local: &TowerLayerCols = (*local).borrow(); + let next: &TowerLayerCols = (*next).borrow(); + + /////////////////////////////////////////////////////////////////////// + // Boolean Constraints + /////////////////////////////////////////////////////////////////////// + + builder.assert_bool(local.is_dummy); + builder.assert_bool(local.is_first_air_idx); + + /////////////////////////////////////////////////////////////////////// + // Proof Index and Loop Constraints + /////////////////////////////////////////////////////////////////////// + + type LoopSubAir = NestedForLoopSubAir<2>; + + // This subair has the following constraints: + // 1. Boolean enabled flag + // 2. Disabled rows are followed by disabled rows + // 3. Proof index increments by exactly one between enabled rows + LoopSubAir {}.eval( + builder, + ( + NestedForLoopIoCols { + is_enabled: local.is_enabled, + counter: [local.proof_idx, local.idx], + is_first: [local.is_first_air_idx, local.is_first], + } + .map_into(), + NestedForLoopIoCols { + is_enabled: next.is_enabled, + counter: [next.proof_idx, next.idx], + is_first: [next.is_first_air_idx, next.is_first], + } + .map_into(), + ), + ); + + let is_transition = LoopSubAir::local_is_transition(next.is_enabled, next.is_first); + let is_last = LoopSubAir::local_is_last(local.is_enabled, next.is_enabled, next.is_first); + // Layer index starts from 0 + builder.when(local.is_first).assert_zero(local.layer_idx); + // Layer index increments by 1 + builder + .when(is_transition.clone()) + .assert_eq(next.layer_idx, local.layer_idx + AB::Expr::ONE); + + // constrain lambda_prime + let lambda_prime_one = { + let mut arr = core::array::from_fn(|_| AB::Expr::ZERO); + arr[0] = AB::Expr::ONE; + arr + }; + assert_array_eq( + &mut builder.when(local.is_first), + local.lambda_prime, + lambda_prime_one, + ); + // constrain lambda_prime + assert_array_eq( + &mut builder.when(is_transition.clone()), + next.lambda_prime, + local.lambda, + ); + + /////////////////////////////////////////////////////////////////////// + // Root Layer Constraints + /////////////////////////////////////////////////////////////////////// + + assert_zeros( + &mut builder.when(local.is_first), + local.sumcheck_claim_in.map(Into::into), + ); + + /////////////////////////////////////////////////////////////////////// + // Inter-Layer Constraints + /////////////////////////////////////////////////////////////////////// + + let read_plus_write = ext_field_add::(local.read_claim, local.write_claim); + let folded_claim = ext_field_add::(read_plus_write, local.logup_claim); + assert_array_eq( + &mut builder.when(is_transition.clone()), + next.sumcheck_claim_in, + folded_claim.clone(), + ); + + // Transcript index increment + let tidx_after_sumcheck = local.tidx + // Sample lambda on non-root layer + + (AB::Expr::ONE - local.is_first) * AB::Expr::from_usize(D_EF) + + local.layer_idx * AB::Expr::from_usize(4 * D_EF); + let tidx_end = tidx_after_sumcheck.clone() + AB::Expr::from_usize(5 * D_EF); + builder + .when(is_transition.clone()) + .assert_eq(next.tidx, tidx_end.clone()); + + /////////////////////////////////////////////////////////////////////// + // Module Interactions + /////////////////////////////////////////////////////////////////////// + + let is_not_dummy = AB::Expr::ONE - local.is_dummy; + let is_non_root_layer = local.is_enabled * (AB::Expr::ONE - local.is_first); + + let lookup_enable = local.is_enabled * is_not_dummy.clone(); + self.air_shape_bus.lookup_key( + builder, + local.proof_idx, + AirShapeBusMessage { + sort_idx: local.idx.into(), + property_idx: AirShapeProperty::NumRead.to_field(), + value: local.num_read_count.into(), + }, + lookup_enable.clone(), + ); + self.air_shape_bus.lookup_key( + builder, + local.proof_idx, + AirShapeBusMessage { + sort_idx: local.idx.into(), + property_idx: AirShapeProperty::NumWrite.to_field(), + value: local.num_write_count.into(), + }, + lookup_enable.clone(), + ); + self.air_shape_bus.lookup_key( + builder, + local.proof_idx, + AirShapeBusMessage { + sort_idx: local.idx.into(), + property_idx: AirShapeProperty::NumLk.to_field(), + value: local.num_logup_count.into(), + }, + lookup_enable.clone(), + ); + + let tidx_for_claims = tidx_after_sumcheck.clone(); + self.prod_read_claim_input_bus.send( + builder, + local.proof_idx, + TowerProdLayerChallengeMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + tidx: tidx_for_claims.clone(), + lambda: local.lambda.map(Into::into), + lambda_prime: local.lambda_prime.map(Into::into), + mu: local.mu.map(Into::into), + }, + is_not_dummy.clone(), + ); + // TODO separate lambda, lambda_prime for prod-write the relation should be local.lambda^(num_read) + self.prod_write_claim_input_bus.send( + builder, + local.proof_idx, + TowerProdLayerChallengeMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + tidx: tidx_for_claims.clone(), + lambda: local.lambda.map(Into::into), + lambda_prime: local.lambda_prime.map(Into::into), + mu: local.mu.map(Into::into), + }, + is_not_dummy.clone(), + ); + // TODO separate lambda, lambda_prime for logup the relation should be local.lambda^(num_read + num_write) + self.logup_claim_input_bus.send( + builder, + local.proof_idx, + TowerLogupLayerChallengeMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + tidx: tidx_for_claims.clone(), + lambda: local.lambda.map(Into::into), + lambda_prime: local.lambda_prime.map(Into::into), + mu: local.mu.map(Into::into), + }, + is_not_dummy.clone(), + ); + self.prod_read_claim_bus.receive( + builder, + local.proof_idx, + TowerProdSumClaimMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + lambda_claim: local.read_claim.map(Into::into), + lambda_prime_claim: local.read_claim_prime.map(Into::into), + num_prod_count: local.num_read_count.into(), + }, + is_not_dummy.clone(), + ); + self.prod_write_claim_bus.receive( + builder, + local.proof_idx, + TowerProdSumClaimMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + lambda_claim: local.write_claim.map(Into::into), + lambda_prime_claim: local.write_claim_prime.map(Into::into), + num_prod_count: local.num_write_count.into(), + }, + is_not_dummy.clone(), + ); + self.logup_claim_bus.receive( + builder, + local.proof_idx, + TowerLogupClaimMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + lambda_claim: local.logup_claim.map(Into::into), + lambda_prime_claim: local.logup_claim_prime.map(Into::into), + num_logup_count: local.num_logup_count.into(), + }, + is_not_dummy.clone(), + ); + + let root_layer_mask = local.is_first * is_not_dummy.clone(); + assert_array_eq( + &mut builder.when(root_layer_mask.clone()), + local.read_claim_prime, + local.r0_claim, + ); + assert_array_eq( + &mut builder.when(root_layer_mask.clone()), + local.write_claim_prime, + local.w0_claim, + ); + assert_array_eq( + &mut builder.when(root_layer_mask), + local.logup_claim_prime, + local.q0_claim, + ); + + // 1. TowerLayerInputBus + // 1a. Receive GKR layers input + self.layer_input_bus.receive( + builder, + local.proof_idx, + TowerLayerInputMessage { + idx: local.idx.into(), + tidx: local.tidx.into(), + r0_claim: local.r0_claim.map(Into::into), + w0_claim: local.w0_claim.map(Into::into), + q0_claim: local.q0_claim.map(Into::into), + }, + local.is_first_air_idx * is_not_dummy.clone(), + ); + // 2. TowerLayerOutputBus + // 2a. Send GKR input layer claims back + self.layer_output_bus.send( + builder, + local.proof_idx, + TowerLayerOutputMessage { + idx: local.idx.into(), + tidx: tidx_end, + layer_idx_end: local.layer_idx.into(), + input_layer_claim: folded_claim.map(Into::into), + lambda: local.lambda.map(Into::into), + mu: local.mu.map(Into::into), + }, + is_last.clone() * is_not_dummy.clone(), + ); + // 3. TowerSumcheckInputBus + // 3a. Send claim to sumcheck + // only send sumcheck on non root layer + self.sumcheck_input_bus.send( + builder, + local.proof_idx, + TowerSumcheckInputMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + is_last_layer: is_last.clone(), + tidx: local.tidx + AB::Expr::from_usize(D_EF), + claim: local.sumcheck_claim_in.map(Into::into), + }, + is_non_root_layer.clone() * is_not_dummy.clone(), + ); + // 3. TowerSumcheckOutputBus + // 3a. Receive sumcheck results + let prime_fold = ext_field_add::(local.read_claim_prime, local.write_claim_prime); + let sumcheck_claim_out = ext_field_multiply::( + ext_field_add::(prime_fold, local.logup_claim_prime), + local.eq_at_r_prime, + ); + self.sumcheck_output_bus.receive( + builder, + local.proof_idx, + TowerSumcheckOutputMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + tidx: tidx_after_sumcheck.clone(), + claim_out: sumcheck_claim_out.map(Into::into), + eq_at_r_prime: local.eq_at_r_prime.map(Into::into), + }, + is_non_root_layer.clone() * is_not_dummy.clone(), + ); + // 4. TowerSumcheckChallengeBus + // 4a. Send challenge mu + self.sumcheck_challenge_bus.send( + builder, + local.proof_idx, + TowerSumcheckChallengeMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + sumcheck_round: AB::Expr::ZERO, + challenge: local.mu.map(Into::into), + }, + is_transition.clone() * is_not_dummy.clone(), + ); + + /////////////////////////////////////////////////////////////////////// + // External Interactions + /////////////////////////////////////////////////////////////////////// + + // 1. TranscriptBus + // sample lambda and mu + // in root & intermediate layer: for next.sumcheck_claim_in + // in last layer: for send back to GKR input layer + // 1a. Sample `lambda` + self.transcript_bus.sample_ext( + builder, + local.proof_idx, + local.tidx, + local.lambda, + local.is_enabled * is_not_dummy.clone(), + ); + // 1b. Observe layer claims + let tidx = tidx_after_sumcheck; + // 1c. Sample `mu` + self.transcript_bus.sample_ext( + builder, + local.proof_idx, + tidx, + local.mu, + local.is_enabled * is_not_dummy.clone(), + ); + } +} diff --git a/ceno_recursion_v2/src/tower/layer/logup_claim/air.rs b/ceno_recursion_v2/src/tower/layer/logup_claim/air.rs new file mode 100644 index 000000000..d25d604b0 --- /dev/null +++ b/ceno_recursion_v2/src/tower/layer/logup_claim/air.rs @@ -0,0 +1,287 @@ +use core::borrow::Borrow; + +use openvm_circuit_primitives::{SubAir, utils::assert_array_eq}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{Field, PrimeCharacteristicRing, extension::BinomiallyExtendable}; +use p3_matrix::Matrix; +use stark_recursion_circuit_derive::AlignedBorrow; + +use crate::tower::bus::{ + TowerLogupClaimBus, TowerLogupClaimInputBus, TowerLogupClaimMessage, + TowerLogupLayerChallengeMessage, +}; +use recursion_circuit::{ + bus::TranscriptBus, + subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}, + utils::{assert_zeros, ext_field_add, ext_field_multiply, ext_field_subtract}, +}; + +#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct TowerLogupSumCheckClaimCols { + pub is_enabled: T, + pub proof_idx: T, + pub idx: T, + pub is_first_layer: T, + pub is_first: T, + pub is_dummy: T, + + pub layer_idx: T, + pub index_id: T, + pub tidx: T, + + pub lambda: [T; D_EF], + pub lambda_prime: [T; D_EF], + pub mu: [T; D_EF], + + pub p_xi_0: [T; D_EF], + pub p_xi_1: [T; D_EF], + pub q_xi_0: [T; D_EF], + pub q_xi_1: [T; D_EF], + pub p_xi: [T; D_EF], + pub q_xi: [T; D_EF], + + pub pow_lambda: [T; D_EF], + pub pow_lambda_prime: [T; D_EF], + pub acc_sum: [T; D_EF], + pub acc_p_cross: [T; D_EF], + pub acc_q_cross: [T; D_EF], + pub num_logup_count: T, +} + +pub struct TowerLogupSumCheckClaimAir { + pub transcript_bus: TranscriptBus, + pub logup_claim_input_bus: TowerLogupClaimInputBus, + pub logup_claim_bus: TowerLogupClaimBus, +} + +impl BaseAir for TowerLogupSumCheckClaimAir { + fn width(&self) -> usize { + TowerLogupSumCheckClaimCols::::width() + } +} + +impl BaseAirWithPublicValues for TowerLogupSumCheckClaimAir {} +impl PartitionedBaseAir for TowerLogupSumCheckClaimAir {} + +impl Air for TowerLogupSumCheckClaimAir +where + AB: AirBuilder + InteractionBuilder, + ::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local_row, next_row) = ( + main.row_slice(0).expect("window should have two elements"), + main.row_slice(1).expect("window should have two elements"), + ); + let local: &TowerLogupSumCheckClaimCols = (*local_row).borrow(); + let next: &TowerLogupSumCheckClaimCols = (*next_row).borrow(); + + builder.assert_bool(local.is_dummy); + builder.assert_bool(local.is_first_layer); + + type LoopSubAir = NestedForLoopSubAir<2>; + LoopSubAir {}.eval( + builder, + ( + NestedForLoopIoCols { + is_enabled: local.is_enabled, + counter: [local.proof_idx, local.idx], + is_first: [local.is_first_layer, local.is_first], + } + .map_into(), + NestedForLoopIoCols { + is_enabled: next.is_enabled, + counter: [next.proof_idx, next.idx], + is_first: [next.is_first_layer, next.is_first], + } + .map_into(), + ), + ); + + let is_transition = LoopSubAir::local_is_transition(next.is_enabled, next.is_first); + let is_last_layer_row = + LoopSubAir::local_is_last(local.is_enabled, next.is_enabled, next.is_first); + let stay_in_layer = AB::Expr::ONE - is_transition.clone(); + let is_not_dummy = local.is_enabled * (AB::Expr::ONE - local.is_dummy); + + builder + .when(local.is_first) + .assert_zero(local.layer_idx.clone()); + builder + .when(is_transition.clone()) + .assert_eq(next.layer_idx, local.layer_idx + AB::Expr::ONE); + + builder + .when(local.is_first_layer) + .assert_zero(local.index_id.clone()); + builder + .when(local.is_enabled * next.is_enabled * next.is_first_layer) + .assert_zero(next.index_id.clone()); + builder + .when(is_not_dummy.clone() * stay_in_layer.clone()) + .assert_eq(next.index_id, local.index_id + AB::Expr::ONE); + builder + .when(is_last_layer_row.clone() * is_not_dummy.clone()) + .assert_eq( + local.index_id + AB::Expr::ONE, + local.num_logup_count.clone(), + ); + + assert_zeros( + &mut builder.when(local.is_first * is_not_dummy.clone()), + local.acc_sum.map(Into::into), + ); + assert_zeros( + &mut builder.when(local.is_first * is_not_dummy.clone()), + local.acc_p_cross.map(Into::into), + ); + assert_zeros( + &mut builder.when(local.is_first * is_not_dummy.clone()), + local.acc_q_cross.map(Into::into), + ); + builder + .when(local.is_first * is_not_dummy.clone()) + .assert_eq(local.pow_lambda[0], AB::Expr::ONE); + for limb in local.pow_lambda.iter().copied().skip(1) { + builder + .when(local.is_first * is_not_dummy.clone()) + .assert_zero(limb); + } + builder + .when(local.is_first * is_not_dummy.clone()) + .assert_eq(local.pow_lambda_prime[0], AB::Expr::ONE); + for limb in local.pow_lambda_prime.iter().copied().skip(1) { + builder + .when(local.is_first * is_not_dummy.clone()) + .assert_zero(limb); + } + + let delta_p = ext_field_subtract::(local.p_xi_1, local.p_xi_0); + let expected_p_xi = + ext_field_add::(local.p_xi_0, ext_field_multiply(delta_p, local.mu)); + assert_array_eq(builder, local.p_xi, expected_p_xi); + + let delta_q = ext_field_subtract::(local.q_xi_1, local.q_xi_0); + let expected_q_xi = + ext_field_add::(local.q_xi_0, ext_field_multiply(delta_q, local.mu)); + assert_array_eq(builder, local.q_xi, expected_q_xi); + + let (p_cross_term, q_cross_term) = + compute_recursive_relations(local.p_xi_0, local.q_xi_0, local.p_xi_1, local.q_xi_1); + + let lambda = local.lambda.map(Into::into); + let pow_lambda = local.pow_lambda.map(Into::into); + let combined = ext_field_add::( + local.p_xi, + ext_field_multiply::(lambda.clone(), local.q_xi), + ); + let contribution = ext_field_multiply::(pow_lambda.clone(), combined); + let acc_sum_with_cur = ext_field_add::(local.acc_sum, contribution); + let acc_sum_export = acc_sum_with_cur.clone(); + + assert_array_eq( + &mut builder.when(stay_in_layer.clone()), + next.acc_sum, + acc_sum_with_cur, + ); + let pow_lambda_next = ext_field_multiply::(pow_lambda, lambda.clone()); + assert_array_eq( + &mut builder.when(stay_in_layer.clone()), + next.pow_lambda, + pow_lambda_next, + ); + + let pow_lambda_prime = local.pow_lambda_prime.map(Into::into); + let lambda_prime = local.lambda_prime.map(Into::into); + let acc_p_with_cur = ext_field_add::( + local.acc_p_cross, + ext_field_multiply::(pow_lambda_prime.clone(), p_cross_term), + ); + assert_array_eq( + &mut builder.when(stay_in_layer.clone()), + next.acc_p_cross, + acc_p_with_cur.clone(), + ); + let scaled_q_term = ext_field_multiply::( + ext_field_multiply::(pow_lambda_prime.clone(), lambda_prime.clone()), + q_cross_term, + ); + let acc_q_with_cur = ext_field_add::(local.acc_q_cross, scaled_q_term); + assert_array_eq( + &mut builder.when(stay_in_layer.clone()), + next.acc_q_cross, + acc_q_with_cur.clone(), + ); + let pow_lambda_prime_next = + ext_field_multiply::(pow_lambda_prime, lambda_prime.clone()); + assert_array_eq( + &mut builder.when(stay_in_layer.clone()), + next.pow_lambda_prime, + pow_lambda_prime_next, + ); + + self.logup_claim_input_bus.receive( + builder, + local.proof_idx, + TowerLogupLayerChallengeMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + tidx: local.tidx.into(), + lambda: lambda.clone(), + lambda_prime: lambda_prime.clone(), + mu: local.mu.map(Into::into), + }, + local.is_first_layer * is_not_dummy.clone(), + ); + + self.logup_claim_bus.send( + builder, + local.proof_idx, + TowerLogupClaimMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + lambda_claim: acc_sum_export.map(Into::into), + lambda_prime_claim: acc_q_with_cur.map(Into::into), + num_logup_count: local.num_logup_count.into(), + }, + is_last_layer_row * is_not_dummy.clone(), + ); + + let mut tidx = local.tidx.into(); + for claim in [local.p_xi_0, local.q_xi_0, local.p_xi_1, local.q_xi_1] { + self.transcript_bus.observe_ext( + builder, + local.proof_idx, + tidx.clone(), + claim, + local.is_enabled * is_not_dummy.clone(), + ); + tidx += AB::Expr::from_usize(D_EF); + } + } +} + +fn compute_recursive_relations( + p_xi_0: [F; D_EF], + q_xi_0: [F; D_EF], + p_xi_1: [F; D_EF], + q_xi_1: [F; D_EF], +) -> ([FA; D_EF], [FA; D_EF]) +where + F: Into + Copy, + FA: PrimeCharacteristicRing, + FA::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, +{ + let p_cross_term = ext_field_add::( + ext_field_multiply::(p_xi_0, q_xi_1), + ext_field_multiply::(p_xi_1, q_xi_0), + ); + let q_cross_term = ext_field_multiply::(q_xi_0, q_xi_1); + (p_cross_term, q_cross_term) +} diff --git a/ceno_recursion_v2/src/tower/layer/logup_claim/mod.rs b/ceno_recursion_v2/src/tower/layer/logup_claim/mod.rs new file mode 100644 index 000000000..fb0028371 --- /dev/null +++ b/ceno_recursion_v2/src/tower/layer/logup_claim/mod.rs @@ -0,0 +1,5 @@ +pub mod air; +pub mod trace; + +pub use air::{TowerLogupSumCheckClaimAir, TowerLogupSumCheckClaimCols}; +pub use trace::TowerLogupSumCheckClaimTraceGenerator; diff --git a/ceno_recursion_v2/src/tower/layer/logup_claim/trace.rs b/ceno_recursion_v2/src/tower/layer/logup_claim/trace.rs new file mode 100644 index 000000000..22c43b17d --- /dev/null +++ b/ceno_recursion_v2/src/tower/layer/logup_claim/trace.rs @@ -0,0 +1,225 @@ +use core::borrow::BorrowMut; + +use openvm_stark_backend::p3_maybe_rayon::prelude::*; +use openvm_stark_sdk::config::baby_bear_poseidon2::{D_EF, EF, F}; +use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; +use p3_matrix::dense::RowMajorMatrix; + +use super::TowerLogupSumCheckClaimCols; +use crate::{ + tower::{TowerTowerEvalRecord, interpolate_pair, layer::trace::TowerLayerRecord}, + tracegen::RowMajorChip, +}; + +pub struct TowerLogupSumCheckClaimTraceGenerator; + +type LogupTraceCtx<'a> = ( + &'a [TowerLayerRecord], + &'a [TowerTowerEvalRecord], + &'a [Vec], +); + +fn logup_rows_for_record(record: &TowerLayerRecord) -> usize { + if record.layer_count() == 0 { + 1 + } else { + (0..record.layer_count()) + .map(|layer_idx| record.logup_count_at(layer_idx).max(1)) + .sum() + } +} + +impl RowMajorChip for TowerLogupSumCheckClaimTraceGenerator { + type Ctx<'a> = LogupTraceCtx<'a>; + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_trace( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + let (records, towers, mus_records) = ctx; + let width = TowerLogupSumCheckClaimCols::::width(); + let rows_per_proof: Vec = records.iter().map(logup_rows_for_record).collect(); + let num_valid_rows: usize = rows_per_proof.iter().sum(); + let height = if let Some(height) = required_height { + if height < num_valid_rows { + return None; + } + height + } else { + num_valid_rows.next_power_of_two().max(1) + }; + let mut trace = vec![F::ZERO; height * width]; + let (data_slice, _) = trace.split_at_mut(num_valid_rows * width); + let mut trace_slices: Vec<&mut [F]> = Vec::with_capacity(rows_per_proof.len()); + let mut remaining = data_slice; + for &rows in &rows_per_proof { + let (chunk, rest) = remaining.split_at_mut(rows * width); + trace_slices.push(chunk); + remaining = rest; + } + + trace_slices + .par_iter_mut() + .zip( + records + .par_iter() + .zip(towers.par_iter()) + .zip(mus_records.par_iter()), + ) + .for_each(|(chunk, ((record, tower), mus_for_proof))| { + if record.layer_count() == 0 { + debug_assert_eq!(chunk.len(), width); + let row_data = &mut chunk[..width]; + let cols: &mut TowerLogupSumCheckClaimCols = row_data.borrow_mut(); + cols.is_enabled = F::ONE; + cols.is_first_layer = F::ONE; + cols.is_first = F::ONE; + cols.is_dummy = F::ONE; + cols.proof_idx = F::from_usize(record.proof_idx); + cols.idx = F::from_usize(record.idx); + cols.layer_idx = F::ZERO; + cols.index_id = F::ZERO; + cols.tidx = F::from_usize(record.tidx); + cols.lambda = [F::ZERO; D_EF]; + let mut lambda_prime_one = [F::ZERO; D_EF]; + lambda_prime_one[0] = F::ONE; + cols.lambda_prime = lambda_prime_one; + cols.mu = [F::ZERO; D_EF]; + cols.p_xi_0 = [F::ZERO; D_EF]; + cols.p_xi_1 = [F::ZERO; D_EF]; + cols.q_xi_0 = [F::ZERO; D_EF]; + cols.q_xi_1 = [F::ZERO; D_EF]; + cols.p_xi = [F::ZERO; D_EF]; + cols.q_xi = [F::ZERO; D_EF]; + cols.pow_lambda = lambda_prime_one; + cols.pow_lambda_prime = lambda_prime_one; + cols.acc_sum = [F::ZERO; D_EF]; + cols.acc_p_cross = [F::ZERO; D_EF]; + cols.acc_q_cross = [F::ZERO; D_EF]; + cols.num_logup_count = F::ONE; + return; + } + + let mut proof_row_idx = 0usize; + let mut chunk_iter = chunk.chunks_mut(width); + + for layer_idx in 0..record.layer_count() { + let logup_rows = tower + .logup_layers + .get(layer_idx) + .map(|rows| rows.as_slice()) + .unwrap_or(&[]); + let total_rows = record.logup_count_at(layer_idx).max(1); + debug_assert!( + total_rows == logup_rows.len().max(1), + "unexpected logup count mismatch at layer {layer_idx}" + ); + + let lambda = record.lambda_at(layer_idx); + let lambda_prime = record.lambda_prime_at(layer_idx); + let mu = mus_for_proof.get(layer_idx).copied().unwrap_or(EF::ZERO); + let lambda_basis: [F; D_EF] = + lambda.as_basis_coefficients_slice().try_into().unwrap(); + let lambda_prime_basis: [F; D_EF] = lambda_prime + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + let mu_basis: [F; D_EF] = mu.as_basis_coefficients_slice().try_into().unwrap(); + let tidx = record.claim_tidx(layer_idx); + + let mut pow_lambda = EF::ONE; + let mut pow_lambda_prime = EF::ONE; + let mut acc_sum = EF::ZERO; + let mut acc_p_cross = EF::ZERO; + let mut acc_q_cross = EF::ZERO; + + for row_in_layer in 0..total_rows { + let row = chunk_iter + .next() + .expect("chunk should have enough rows for layer"); + let cols: &mut TowerLogupSumCheckClaimCols = row.borrow_mut(); + let is_real = row_in_layer < logup_rows.len(); + let quad = if is_real { + logup_rows[row_in_layer] + } else { + [EF::ZERO; 4] + }; + let p_vals = [quad[0], quad[1]]; + let q_vals = [quad[2], quad[3]]; + let p_xi_0 = p_vals[0]; + let p_xi_1 = p_vals[1]; + let q_xi_0 = q_vals[0]; + let q_xi_1 = q_vals[1]; + let p_xi = interpolate_pair(p_vals, mu); + let q_xi = interpolate_pair(q_vals, mu); + let combined = p_xi + lambda * q_xi; + let p_cross = p_xi_0 * q_xi_1 + p_xi_1 * q_xi_0; + let q_cross = q_xi_0 * q_xi_1; + + let contribution = if is_real { + pow_lambda * combined + } else { + EF::ZERO + }; + let p_cross_contribution = if is_real { + pow_lambda_prime * p_cross + } else { + EF::ZERO + }; + let q_cross_contribution = if is_real { + pow_lambda_prime * lambda_prime * q_cross + } else { + EF::ZERO + }; + + cols.is_enabled = F::ONE; + cols.is_dummy = F::from_bool(!is_real); + cols.is_first_layer = F::from_bool(proof_row_idx == 0); + cols.is_first = F::from_bool(row_in_layer == 0); + cols.proof_idx = F::from_usize(record.proof_idx); + cols.idx = F::from_usize(record.idx); + cols.layer_idx = F::from_usize(layer_idx); + cols.index_id = F::from_usize(row_in_layer); + cols.tidx = F::from_usize(tidx); + cols.lambda = lambda_basis; + cols.lambda_prime = lambda_prime_basis; + cols.mu = mu_basis; + cols.p_xi_0 = p_xi_0.as_basis_coefficients_slice().try_into().unwrap(); + cols.p_xi_1 = p_xi_1.as_basis_coefficients_slice().try_into().unwrap(); + cols.q_xi_0 = q_xi_0.as_basis_coefficients_slice().try_into().unwrap(); + cols.q_xi_1 = q_xi_1.as_basis_coefficients_slice().try_into().unwrap(); + cols.p_xi = p_xi.as_basis_coefficients_slice().try_into().unwrap(); + cols.q_xi = q_xi.as_basis_coefficients_slice().try_into().unwrap(); + cols.pow_lambda = + pow_lambda.as_basis_coefficients_slice().try_into().unwrap(); + cols.pow_lambda_prime = pow_lambda_prime + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.acc_sum = acc_sum.as_basis_coefficients_slice().try_into().unwrap(); + cols.acc_p_cross = acc_p_cross + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.acc_q_cross = acc_q_cross + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.num_logup_count = F::from_usize(total_rows); + + acc_sum += contribution; + acc_p_cross += p_cross_contribution; + acc_q_cross += q_cross_contribution; + pow_lambda *= lambda; + pow_lambda_prime *= lambda_prime; + + proof_row_idx += 1; + } + } + }); + + Some(RowMajorMatrix::new(trace, width)) + } +} diff --git a/ceno_recursion_v2/src/tower/layer/mod.rs b/ceno_recursion_v2/src/tower/layer/mod.rs new file mode 100644 index 000000000..4eec4862c --- /dev/null +++ b/ceno_recursion_v2/src/tower/layer/mod.rs @@ -0,0 +1,15 @@ +mod air; +pub mod logup_claim; +pub mod prod_claim; +mod trace; + +pub use air::{TowerLayerAir, TowerLayerCols}; +pub use logup_claim::{ + TowerLogupSumCheckClaimAir, TowerLogupSumCheckClaimCols, TowerLogupSumCheckClaimTraceGenerator, +}; +pub use prod_claim::{ + TowerProdReadSumCheckClaimAir, TowerProdReadSumCheckClaimTraceGenerator, + TowerProdSumCheckClaimCols, TowerProdWriteSumCheckClaimAir, + TowerProdWriteSumCheckClaimTraceGenerator, +}; +pub use trace::{TowerLayerRecord, TowerLayerTraceGenerator}; diff --git a/ceno_recursion_v2/src/tower/layer/prod_claim/air.rs b/ceno_recursion_v2/src/tower/layer/prod_claim/air.rs new file mode 100644 index 000000000..b6db73cfb --- /dev/null +++ b/ceno_recursion_v2/src/tower/layer/prod_claim/air.rs @@ -0,0 +1,282 @@ +use core::borrow::Borrow; + +use openvm_circuit_primitives::{SubAir, utils::assert_array_eq}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{Field, PrimeCharacteristicRing, extension::BinomiallyExtendable}; +use p3_matrix::Matrix; +use stark_recursion_circuit_derive::AlignedBorrow; + +use crate::tower::bus::{ + TowerProdLayerChallengeMessage, TowerProdReadClaimBus, TowerProdReadClaimInputBus, + TowerProdSumClaimMessage, TowerProdWriteClaimBus, TowerProdWriteClaimInputBus, +}; +use recursion_circuit::{ + bus::TranscriptBus, + subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}, + utils::{assert_zeros, ext_field_add, ext_field_multiply, ext_field_subtract}, +}; + +#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct TowerProdSumCheckClaimCols { + pub is_enabled: T, + pub proof_idx: T, + pub idx: T, + pub is_first_layer: T, + pub is_first: T, + pub is_dummy: T, + + pub layer_idx: T, + pub index_id: T, + pub tidx: T, + + pub lambda: [T; D_EF], + pub lambda_prime: [T; D_EF], + pub mu: [T; D_EF], + pub p_xi_0: [T; D_EF], + pub p_xi_1: [T; D_EF], + pub p_xi: [T; D_EF], + pub pow_lambda: [T; D_EF], + pub pow_lambda_prime: [T; D_EF], + pub acc_sum: [T; D_EF], + pub acc_sum_prime: [T; D_EF], + pub num_prod_count: T, +} + +pub struct TowerProdSumCheckClaimAir { + pub transcript_bus: TranscriptBus, + pub prod_claim_input_bus: IB, + pub prod_claim_bus: OB, +} + +pub type TowerProdReadSumCheckClaimAir = + TowerProdSumCheckClaimAir; +pub type TowerProdWriteSumCheckClaimAir = + TowerProdSumCheckClaimAir; + +impl BaseAir for TowerProdSumCheckClaimAir { + fn width(&self) -> usize { + TowerProdSumCheckClaimCols::::width() + } +} + +impl BaseAirWithPublicValues + for TowerProdSumCheckClaimAir +{ +} +impl PartitionedBaseAir for TowerProdSumCheckClaimAir {} + +impl TowerProdSumCheckClaimAir { + fn eval_core( + &self, + builder: &mut AB, + mut recv_challenge: Recv, + mut send_claim: Send, + ) where + AB: AirBuilder + InteractionBuilder, + ::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, + Recv: FnMut(&IB, &mut AB, AB::Var, TowerProdLayerChallengeMessage, AB::Expr), + Send: FnMut(&OB, &mut AB, AB::Var, TowerProdSumClaimMessage, AB::Expr), + { + let main = builder.main(); + let (local_row, next_row) = ( + main.row_slice(0).expect("window should have two elements"), + main.row_slice(1).expect("window should have two elements"), + ); + let local: &TowerProdSumCheckClaimCols = (*local_row).borrow(); + let next: &TowerProdSumCheckClaimCols = (*next_row).borrow(); + + builder.assert_bool(local.is_dummy); + builder.assert_bool(local.is_first_layer); + + type LoopSubAir = NestedForLoopSubAir<2>; + LoopSubAir {}.eval( + builder, + ( + NestedForLoopIoCols { + is_enabled: local.is_enabled, + counter: [local.proof_idx, local.idx], + is_first: [local.is_first_layer, local.is_first], + } + .map_into(), + NestedForLoopIoCols { + is_enabled: next.is_enabled, + counter: [next.proof_idx, next.idx], + is_first: [next.is_first_layer, next.is_first], + } + .map_into(), + ), + ); + + let is_transition = LoopSubAir::local_is_transition(next.is_enabled, next.is_first); + let is_last_layer_row = + LoopSubAir::local_is_last(local.is_enabled, next.is_enabled, next.is_first); + let is_not_dummy = local.is_enabled * (AB::Expr::ONE - local.is_dummy); + let stay_in_layer = AB::Expr::ONE - is_transition.clone(); + + builder + .when(local.is_first) + .assert_zero(local.layer_idx.clone()); + builder + .when(is_transition.clone()) + .assert_eq(next.layer_idx, local.layer_idx + AB::Expr::ONE); + + builder + .when(local.is_first_layer) + .assert_zero(local.index_id.clone()); + builder + .when(local.is_enabled * next.is_enabled * next.is_first_layer) + .assert_zero(next.index_id.clone()); + builder + .when(is_not_dummy.clone() * stay_in_layer.clone()) + .assert_eq(next.index_id, local.index_id + AB::Expr::ONE); + builder + .when(is_last_layer_row.clone() * is_not_dummy.clone()) + .assert_eq(local.index_id + AB::Expr::ONE, local.num_prod_count.clone()); + + assert_zeros( + &mut builder.when(local.is_first * is_not_dummy.clone()), + local.acc_sum.map(Into::into), + ); + assert_zeros( + &mut builder.when(local.is_first * is_not_dummy.clone()), + local.acc_sum_prime.map(Into::into), + ); + builder + .when(local.is_first * is_not_dummy.clone()) + .assert_eq(local.pow_lambda[0], AB::Expr::ONE); + for limb in local.pow_lambda.iter().copied().skip(1) { + builder + .when(local.is_first * is_not_dummy.clone()) + .assert_zero(limb); + } + builder + .when(local.is_first * is_not_dummy.clone()) + .assert_eq(local.pow_lambda_prime[0], AB::Expr::ONE); + for limb in local.pow_lambda_prime.iter().copied().skip(1) { + builder + .when(local.is_first * is_not_dummy.clone()) + .assert_zero(limb); + } + + let delta = ext_field_subtract::(local.p_xi_1, local.p_xi_0); + let expected_p_xi = + ext_field_add::(local.p_xi_0, ext_field_multiply(delta, local.mu)); + assert_array_eq(builder, local.p_xi, expected_p_xi); + + let pow_lambda = local.pow_lambda.map(Into::into); + let contribution = ext_field_multiply::(local.p_xi, pow_lambda.clone()); + let acc_sum_with_cur = ext_field_add::(local.acc_sum, contribution); + let acc_sum_export = acc_sum_with_cur.clone(); + + let prime_product = ext_field_multiply::(local.p_xi_0, local.p_xi_1); + let pow_lambda_prime = local.pow_lambda_prime.map(Into::into); + let prime_contribution = + ext_field_multiply::(pow_lambda_prime.clone(), prime_product); + let acc_sum_prime_with_cur = + ext_field_add::(local.acc_sum_prime, prime_contribution); + let acc_sum_prime_export = acc_sum_prime_with_cur.clone(); + + assert_array_eq( + &mut builder.when(stay_in_layer.clone()), + next.acc_sum, + acc_sum_with_cur, + ); + assert_array_eq( + &mut builder.when(stay_in_layer.clone()), + next.acc_sum_prime, + acc_sum_prime_with_cur, + ); + + let lambda = local.lambda.map(Into::into); + let pow_lambda_next = ext_field_multiply::(pow_lambda, lambda.clone()); + assert_array_eq( + &mut builder.when(stay_in_layer.clone()), + next.pow_lambda, + pow_lambda_next, + ); + let lambda_prime = local.lambda_prime.map(Into::into); + let pow_lambda_prime_next = + ext_field_multiply::(pow_lambda_prime, lambda_prime.clone()); + assert_array_eq( + &mut builder.when(stay_in_layer.clone()), + next.pow_lambda_prime, + pow_lambda_prime_next, + ); + + recv_challenge( + &self.prod_claim_input_bus, + builder, + local.proof_idx, + TowerProdLayerChallengeMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + tidx: local.tidx.into(), + lambda, + lambda_prime: lambda_prime.clone(), + mu: local.mu.map(Into::into), + }, + local.is_first_layer * is_not_dummy.clone(), + ); + + send_claim( + &self.prod_claim_bus, + builder, + local.proof_idx, + TowerProdSumClaimMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + lambda_claim: acc_sum_export.map(Into::into), + lambda_prime_claim: acc_sum_prime_export.map(Into::into), + num_prod_count: local.num_prod_count.into(), + }, + is_last_layer_row * is_not_dummy.clone(), + ); + + let mut tidx = local.tidx.into(); + self.transcript_bus.observe_ext( + builder, + local.proof_idx, + tidx.clone(), + local.p_xi_0, + local.is_enabled * is_not_dummy.clone(), + ); + tidx += AB::Expr::from_usize(D_EF); + self.transcript_bus.observe_ext( + builder, + local.proof_idx, + tidx, + local.p_xi_1, + local.is_enabled * is_not_dummy, + ); + } +} + +macro_rules! impl_prod_sum_air { + ($ty:ty) => { + impl Air for $ty + where + AB: AirBuilder + InteractionBuilder, + ::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, + { + fn eval(&self, builder: &mut AB) { + self.eval_core( + builder, + |bus, builder, proof_idx, msg, mult| { + bus.receive(builder, proof_idx, msg, mult); + }, + |bus, builder, proof_idx, msg, mult| { + bus.send(builder, proof_idx, msg, mult); + }, + ); + } + } + }; +} + +impl_prod_sum_air!(TowerProdReadSumCheckClaimAir); +impl_prod_sum_air!(TowerProdWriteSumCheckClaimAir); diff --git a/ceno_recursion_v2/src/tower/layer/prod_claim/mod.rs b/ceno_recursion_v2/src/tower/layer/prod_claim/mod.rs new file mode 100644 index 000000000..0f69c2772 --- /dev/null +++ b/ceno_recursion_v2/src/tower/layer/prod_claim/mod.rs @@ -0,0 +1,9 @@ +pub mod air; +pub mod trace; + +pub use air::{ + TowerProdReadSumCheckClaimAir, TowerProdSumCheckClaimCols, TowerProdWriteSumCheckClaimAir, +}; +pub use trace::{ + TowerProdReadSumCheckClaimTraceGenerator, TowerProdWriteSumCheckClaimTraceGenerator, +}; diff --git a/ceno_recursion_v2/src/tower/layer/prod_claim/trace.rs b/ceno_recursion_v2/src/tower/layer/prod_claim/trace.rs new file mode 100644 index 000000000..c7d96ac98 --- /dev/null +++ b/ceno_recursion_v2/src/tower/layer/prod_claim/trace.rs @@ -0,0 +1,240 @@ +use core::borrow::BorrowMut; + +use openvm_stark_backend::p3_maybe_rayon::prelude::*; +use openvm_stark_sdk::config::baby_bear_poseidon2::{D_EF, EF, F}; +use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; +use p3_matrix::dense::RowMajorMatrix; + +use super::TowerProdSumCheckClaimCols; +use crate::{ + tower::{TowerTowerEvalRecord, interpolate_pair, layer::trace::TowerLayerRecord}, + tracegen::RowMajorChip, +}; + +pub struct TowerProdReadSumCheckClaimTraceGenerator; +pub struct TowerProdWriteSumCheckClaimTraceGenerator; + +type ProdTraceCtx<'a> = ( + &'a [TowerLayerRecord], + &'a [TowerTowerEvalRecord], + &'a [Vec], +); + +fn prod_rows_for_record(record: &TowerLayerRecord, is_write: bool) -> usize { + if record.layer_count() == 0 { + 1 + } else { + (0..record.layer_count()) + .map(|layer_idx| { + if is_write { + record.write_count_at(layer_idx).max(1) + } else { + record.read_count_at(layer_idx).max(1) + } + }) + .sum() + } +} + +#[allow(clippy::too_many_arguments)] +fn generate_prod_trace( + records: &[TowerLayerRecord], + towers: &[TowerTowerEvalRecord], + mus_records: &[Vec], + is_write: bool, + required_height: Option, +) -> Option> { + let width = TowerProdSumCheckClaimCols::::width(); + let rows_per_proof: Vec = records + .iter() + .map(|record| prod_rows_for_record(record, is_write)) + .collect(); + let num_valid_rows: usize = rows_per_proof.iter().sum(); + let height = if let Some(height) = required_height { + if height < num_valid_rows { + return None; + } + height + } else { + num_valid_rows.next_power_of_two().max(1) + }; + let mut trace = vec![F::ZERO; height * width]; + let (data_slice, _) = trace.split_at_mut(num_valid_rows * width); + let mut trace_slices: Vec<&mut [F]> = Vec::with_capacity(rows_per_proof.len()); + let mut remaining = data_slice; + for &rows in &rows_per_proof { + let (chunk, rest) = remaining.split_at_mut(rows * width); + trace_slices.push(chunk); + remaining = rest; + } + + trace_slices + .par_iter_mut() + .zip( + records + .par_iter() + .zip(towers.par_iter()) + .zip(mus_records.par_iter()), + ) + .for_each(|(chunk, ((record, tower), mus_for_proof))| { + if record.layer_count() == 0 { + debug_assert_eq!(chunk.len(), width); + let row_data = &mut chunk[..width]; + let cols: &mut TowerProdSumCheckClaimCols = row_data.borrow_mut(); + cols.is_enabled = F::ONE; + cols.is_first_layer = F::ONE; + cols.is_first = F::ONE; + cols.is_dummy = F::ONE; + cols.proof_idx = F::from_usize(record.proof_idx); + cols.idx = F::from_usize(record.idx); + cols.layer_idx = F::ZERO; + cols.index_id = F::ZERO; + cols.tidx = F::from_usize(record.tidx); + cols.lambda = [F::ZERO; D_EF]; + let mut lambda_prime_one = [F::ZERO; D_EF]; + lambda_prime_one[0] = F::ONE; + cols.lambda_prime = lambda_prime_one; + cols.mu = [F::ZERO; D_EF]; + cols.p_xi_0 = [F::ZERO; D_EF]; + cols.p_xi_1 = [F::ZERO; D_EF]; + cols.p_xi = [F::ZERO; D_EF]; + cols.pow_lambda = lambda_prime_one; + cols.pow_lambda_prime = lambda_prime_one; + cols.acc_sum = [F::ZERO; D_EF]; + cols.acc_sum_prime = [F::ZERO; D_EF]; + cols.num_prod_count = F::ONE; + return; + } + + let mut proof_row_idx = 0usize; + let mut chunk_iter = chunk.chunks_mut(width); + + for layer_idx in 0..record.layer_count() { + let active_rows = if is_write { + tower + .write_layers + .get(layer_idx) + .map(|rows| rows.as_slice()) + .unwrap_or(&[]) + } else { + tower + .read_layers + .get(layer_idx) + .map(|rows| rows.as_slice()) + .unwrap_or(&[]) + }; + let total_rows = if is_write { + record.write_count_at(layer_idx).max(1) + } else { + record.read_count_at(layer_idx).max(1) + }; + debug_assert!( + total_rows == active_rows.len().max(1), + "unexpected prod count mismatch at layer {layer_idx}" + ); + let lambda = record.lambda_at(layer_idx); + let lambda_prime = record.lambda_prime_at(layer_idx); + let mu = mus_for_proof.get(layer_idx).copied().unwrap_or(EF::ZERO); + let lambda_basis: [F; D_EF] = + lambda.as_basis_coefficients_slice().try_into().unwrap(); + let lambda_prime_basis: [F; D_EF] = lambda_prime + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + let mu_basis: [F; D_EF] = mu.as_basis_coefficients_slice().try_into().unwrap(); + let tidx = record.claim_tidx(layer_idx); + + let mut pow_lambda = EF::ONE; + let mut pow_lambda_prime = EF::ONE; + let mut acc_sum = EF::ZERO; + let mut acc_sum_prime = EF::ZERO; + + for row_in_layer in 0..total_rows { + let row = chunk_iter + .next() + .expect("chunk should have enough rows for layer"); + let cols: &mut TowerProdSumCheckClaimCols = row.borrow_mut(); + let is_real = row_in_layer < active_rows.len(); + let pair = if is_real { + active_rows[row_in_layer] + } else { + [EF::ZERO; 2] + }; + let p_xi_0 = pair[0]; + let p_xi_1 = pair[1]; + let p_xi = interpolate_pair(pair, mu); + let prime_product = p_xi_0 * p_xi_1; + let contribution = if is_real { pow_lambda * p_xi } else { EF::ZERO }; + let prime_contribution = if is_real { + pow_lambda_prime * prime_product + } else { + EF::ZERO + }; + + cols.is_enabled = F::ONE; + cols.is_dummy = F::from_bool(!is_real); + cols.is_first_layer = F::from_bool(proof_row_idx == 0); + cols.is_first = F::from_bool(row_in_layer == 0); + cols.proof_idx = F::from_usize(record.proof_idx); + cols.idx = F::from_usize(record.idx); + cols.layer_idx = F::from_usize(layer_idx); + cols.index_id = F::from_usize(row_in_layer); + cols.tidx = F::from_usize(tidx); + cols.lambda = lambda_basis; + cols.lambda_prime = lambda_prime_basis; + cols.mu = mu_basis; + cols.p_xi_0 = p_xi_0.as_basis_coefficients_slice().try_into().unwrap(); + cols.p_xi_1 = p_xi_1.as_basis_coefficients_slice().try_into().unwrap(); + cols.p_xi = p_xi.as_basis_coefficients_slice().try_into().unwrap(); + cols.pow_lambda = pow_lambda.as_basis_coefficients_slice().try_into().unwrap(); + cols.pow_lambda_prime = pow_lambda_prime + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.acc_sum = acc_sum.as_basis_coefficients_slice().try_into().unwrap(); + cols.acc_sum_prime = acc_sum_prime + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.num_prod_count = F::from_usize(total_rows); + + acc_sum += contribution; + acc_sum_prime += prime_contribution; + pow_lambda *= lambda; + pow_lambda_prime *= lambda_prime; + + proof_row_idx += 1; + } + } + }); + + Some(RowMajorMatrix::new(trace, width)) +} + +impl RowMajorChip for TowerProdReadSumCheckClaimTraceGenerator { + type Ctx<'a> = ProdTraceCtx<'a>; + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_trace( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + let (records, towers, mus_records) = ctx; + generate_prod_trace(records, towers, mus_records, false, required_height) + } +} + +impl RowMajorChip for TowerProdWriteSumCheckClaimTraceGenerator { + type Ctx<'a> = ProdTraceCtx<'a>; + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_trace( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + let (records, towers, mus_records) = ctx; + generate_prod_trace(records, towers, mus_records, true, required_height) + } +} diff --git a/ceno_recursion_v2/src/tower/layer/trace.rs b/ceno_recursion_v2/src/tower/layer/trace.rs new file mode 100644 index 000000000..cc5f215ef --- /dev/null +++ b/ceno_recursion_v2/src/tower/layer/trace.rs @@ -0,0 +1,314 @@ +use core::borrow::BorrowMut; + +use openvm_stark_backend::p3_maybe_rayon::prelude::*; +use openvm_stark_sdk::config::baby_bear_poseidon2::{D_EF, EF, F}; +use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; +use p3_matrix::dense::RowMajorMatrix; + +use super::TowerLayerCols; +use crate::tracegen::RowMajorChip; + +/// Minimal record for parallel tower layer trace generation +#[derive(Debug, Clone, Default)] +pub struct TowerLayerRecord { + pub proof_idx: usize, + pub idx: usize, + pub tidx: usize, + pub layer_claims: Vec<[EF; 4]>, + pub lambdas: Vec, + pub eq_at_r_primes: Vec, + pub read_counts: Vec, + pub write_counts: Vec, + pub logup_counts: Vec, + pub read_claims: Vec, + pub read_prime_claims: Vec, + pub write_claims: Vec, + pub write_prime_claims: Vec, + pub logup_claims: Vec, + pub logup_prime_claims: Vec, + pub sumcheck_claims: Vec, +} + +impl TowerLayerRecord { + #[inline] + pub(crate) fn layer_count(&self) -> usize { + self.layer_claims.len() + } + + #[inline] + pub(crate) fn lambda_at(&self, layer_idx: usize) -> EF { + self.lambdas.get(layer_idx).copied().unwrap_or(EF::ZERO) + } + + #[inline] + pub(crate) fn lambda_prime_at(&self, layer_idx: usize) -> EF { + if layer_idx == 0 { + EF::ONE + } else { + self.lambdas + .get(layer_idx.saturating_sub(1)) + .copied() + .unwrap_or(EF::ONE) + } + } + + #[inline] + pub(crate) fn eq_at(&self, layer_idx: usize) -> EF { + self.eq_at_r_primes + .get(layer_idx) + .copied() + .unwrap_or(EF::ZERO) + } + + #[inline] + pub(crate) fn sumcheck_claim_at(&self, layer_idx: usize) -> EF { + self.sumcheck_claims + .get(layer_idx) + .copied() + .unwrap_or(EF::ZERO) + } + + #[inline] + pub(crate) fn read_claim_at(&self, layer_idx: usize) -> (EF, EF) { + ( + self.read_claims.get(layer_idx).copied().unwrap_or(EF::ZERO), + self.read_prime_claims + .get(layer_idx) + .copied() + .unwrap_or(EF::ZERO), + ) + } + + #[inline] + pub(crate) fn write_claim_at(&self, layer_idx: usize) -> (EF, EF) { + ( + self.write_claims + .get(layer_idx) + .copied() + .unwrap_or(EF::ZERO), + self.write_prime_claims + .get(layer_idx) + .copied() + .unwrap_or(EF::ZERO), + ) + } + + #[inline] + pub(crate) fn logup_claim_at(&self, layer_idx: usize) -> (EF, EF) { + ( + self.logup_claims + .get(layer_idx) + .copied() + .unwrap_or(EF::ZERO), + self.logup_prime_claims + .get(layer_idx) + .copied() + .unwrap_or(EF::ZERO), + ) + } + + #[inline] + pub(crate) fn layer_tidx(&self, layer_idx: usize) -> usize { + if layer_idx == 0 { + self.tidx + } else { + let j = layer_idx; + self.tidx + D_EF * (2 * j * j + 4 * j - 1) + } + } + + #[inline] + pub(crate) fn read_count_at(&self, layer_idx: usize) -> usize { + self.read_counts.get(layer_idx).copied().unwrap_or(1) + } + + #[inline] + pub(crate) fn write_count_at(&self, layer_idx: usize) -> usize { + self.write_counts.get(layer_idx).copied().unwrap_or(1) + } + + #[inline] + pub(crate) fn logup_count_at(&self, layer_idx: usize) -> usize { + self.logup_counts.get(layer_idx).copied().unwrap_or(1) + } + + #[inline] + pub(crate) fn claim_tidx(&self, layer_idx: usize) -> usize { + let base = self.layer_tidx(layer_idx); + let extra = if layer_idx == 0 { 0 } else { D_EF }; + base + extra + layer_idx * 4 * D_EF + } +} + +pub struct TowerLayerTraceGenerator; + +impl RowMajorChip for TowerLayerTraceGenerator { + // (gkr_layer_records, mus, q0_claims) + type Ctx<'a> = (&'a [TowerLayerRecord], &'a [Vec], &'a [EF]); + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_trace( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + let (gkr_layer_records, mus, q0_claims) = ctx; + debug_assert_eq!(gkr_layer_records.len(), mus.len()); + debug_assert_eq!(gkr_layer_records.len(), q0_claims.len()); + + let width = TowerLayerCols::::width(); + let rows_per_proof: Vec = gkr_layer_records + .iter() + .map(|record| record.layer_count().max(1)) + .collect(); + let num_valid_rows: usize = rows_per_proof.iter().sum(); + let height = if let Some(height) = required_height { + if height < num_valid_rows { + return None; + } + height + } else { + num_valid_rows.next_power_of_two().max(1) + }; + let mut trace = vec![F::ZERO; height * width]; + + let (data_slice, _) = trace.split_at_mut(num_valid_rows * width); + let mut trace_slices: Vec<&mut [F]> = Vec::with_capacity(rows_per_proof.len()); + let mut remaining = data_slice; + + for &num_rows in &rows_per_proof { + let chunk_size = num_rows * width; + let (chunk, rest) = remaining.split_at_mut(chunk_size); + trace_slices.push(chunk); + remaining = rest; + } + + trace_slices + .par_iter_mut() + .zip( + gkr_layer_records + .par_iter() + .zip(mus.par_iter()) + .zip(q0_claims.par_iter()), + ) + .for_each(|(chunk, ((record, mus_for_proof), q0_claim))| { + let q0_basis = q0_claim.as_basis_coefficients_slice(); + let mus_for_proof = mus_for_proof.as_slice(); + + if record.layer_claims.is_empty() { + debug_assert_eq!(chunk.len(), width); + let row_data = &mut chunk[..width]; + let cols: &mut TowerLayerCols = row_data.borrow_mut(); + cols.is_enabled = F::ONE; + cols.proof_idx = F::from_usize(record.proof_idx); + cols.idx = F::from_usize(record.idx); + cols.is_first_air_idx = F::ONE; + cols.is_first = F::ONE; + cols.is_dummy = F::ONE; + cols.layer_idx = F::ZERO; + cols.tidx = F::from_usize(record.tidx); + cols.lambda = [F::ZERO; D_EF]; + let mut lambda_prime_one = [F::ZERO; D_EF]; + lambda_prime_one[0] = F::ONE; + cols.lambda_prime = lambda_prime_one; + cols.mu = [F::ZERO; D_EF]; + cols.sumcheck_claim_in = [F::ZERO; D_EF]; + cols.read_claim = [F::ZERO; D_EF]; + cols.read_claim_prime = [F::ZERO; D_EF]; + cols.write_claim = [F::ZERO; D_EF]; + cols.write_claim_prime = [F::ZERO; D_EF]; + cols.logup_claim = [F::ZERO; D_EF]; + cols.logup_claim_prime = [F::ZERO; D_EF]; + cols.num_read_count = F::ZERO; + cols.num_write_count = F::ZERO; + cols.num_logup_count = F::ZERO; + cols.eq_at_r_prime = [F::ZERO; D_EF]; + cols.r0_claim.copy_from_slice(q0_basis); + cols.w0_claim.copy_from_slice(q0_basis); + cols.q0_claim.copy_from_slice(q0_basis); + return; + } + + chunk + .chunks_mut(width) + .take(record.layer_count()) + .enumerate() + .for_each(|(layer_idx, row_data)| { + let cols: &mut TowerLayerCols = row_data.borrow_mut(); + cols.is_enabled = F::ONE; + cols.is_dummy = F::ZERO; + cols.proof_idx = F::from_usize(record.proof_idx); + cols.idx = F::from_usize(record.idx); + cols.is_first_air_idx = F::from_bool(layer_idx == 0); + cols.is_first = F::from_bool(layer_idx == 0); + cols.layer_idx = F::from_usize(layer_idx); + cols.tidx = F::from_usize(record.layer_tidx(layer_idx)); + cols.lambda = record + .lambda_at(layer_idx) + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.lambda_prime = record + .lambda_prime_at(layer_idx) + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + let mu = mus_for_proof.get(layer_idx).copied().unwrap_or(EF::ZERO); + cols.mu = mu.as_basis_coefficients_slice().try_into().unwrap(); + let sumcheck_claim = if layer_idx == 0 { + EF::ZERO + } else { + record.sumcheck_claim_at(layer_idx) + }; + cols.sumcheck_claim_in = sumcheck_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + let (read_claim, read_prime) = record.read_claim_at(layer_idx); + cols.read_claim = + read_claim.as_basis_coefficients_slice().try_into().unwrap(); + let (write_claim, write_prime) = record.write_claim_at(layer_idx); + cols.write_claim = write_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + let (logup_claim, logup_prime) = record.logup_claim_at(layer_idx); + cols.logup_claim = logup_claim + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.num_read_count = F::from_usize(record.read_count_at(layer_idx).max(1)); + cols.num_write_count = + F::from_usize(record.write_count_at(layer_idx).max(1)); + cols.num_logup_count = + F::from_usize(record.logup_count_at(layer_idx).max(1)); + cols.eq_at_r_prime = record + .eq_at(layer_idx) + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.r0_claim.copy_from_slice(q0_basis); + cols.w0_claim.copy_from_slice(q0_basis); + cols.q0_claim.copy_from_slice(q0_basis); + if layer_idx == 0 { + cols.read_claim_prime.copy_from_slice(&cols.r0_claim); + cols.write_claim_prime.copy_from_slice(&cols.w0_claim); + cols.logup_claim_prime.copy_from_slice(&cols.q0_claim); + } else { + cols.read_claim_prime = + read_prime.as_basis_coefficients_slice().try_into().unwrap(); + cols.write_claim_prime = write_prime + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + cols.logup_claim_prime = logup_prime + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + } + }); + }); + + Some(RowMajorMatrix::new(trace, width)) + } +} diff --git a/ceno_recursion_v2/src/tower/mod.rs b/ceno_recursion_v2/src/tower/mod.rs new file mode 100644 index 000000000..fa1d0c3c6 --- /dev/null +++ b/ceno_recursion_v2/src/tower/mod.rs @@ -0,0 +1,881 @@ +//! # GKR Air Module +//! +//! The GKR protocol reduces a fractional sum claim $\sum_{y \in H_{\ell+n}} +//! \frac{\hat{p}(y)}{\hat{q}(y)} = 0$ to evaluation claims on the input layer polynomials at a +//! random point. This is done through a layer-by-layer recursive reduction, where each layer uses a +//! sumcheck protocol. +//! +//! The GKR Air Module verifies the [`TowerProof`](openvm_stark_backend::proof::TowerProof) struct and +//! consists of four AIRs: +//! +//! 1. **TowerInputAir** - Handles initial setup, coordinates other AIRs, and sends final claims to +//! batch constraint module +//! 2. **TowerLayerAir** - Manages layer-by-layer GKR reduction (verifies +//! [`verify_gkr`](openvm_stark_backend::verifier::fractional_sumcheck_gkr::verify_gkr)) +//! 3. **TowerLayerSumcheckAir** - Executes sumcheck protocol for each layer (verifies +//! [`verify_gkr_sumcheck`](openvm_stark_backend::verifier::fractional_sumcheck_gkr::verify_gkr_sumcheck)) +//! +//! ## Architecture +//! +//! ```text +//! ┌─────────────────┐ +//! │ │───────────────────► TranscriptBus +//! │ │ +//! TowerModuleBus ────────────────►│ TowerInputAir │───────────────────► ExpBitsLenBus +//! │ │ +//! │ │───────────────────► BatchConstraintModuleBus +//! └─────────────────┘ +//! ┆ ▲ +//! ┆ ┆ +//! TowerLayerInputBus ┆ ┆ TowerLayerOutputBus +//! ┆ ┆ +//! ▼ ┆ +//! ┌─────────────────────────┐ +//! │ │──────────────► TranscriptBus +//! ┌┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄│ TowerLayerAir │ +//! ┆ │ │──────────────► XiRandomnessBus +//! ┆ └─────────────────────────┘ +//! ┆ ┆ ▲ +//! ┆ ┆ ┆ +//! ┆ TowerSumcheckInputBus ┆ ┆ TowerSumcheckOutputBus +//! ┆ ┆ ┆ +//! ┆ ▼ ┆ +//! ┆ TowerSumcheckChallengeBus ┌─────────────────────────┐ +//! ┆┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄│ │──────────────► TranscriptBus +//! ┆ │ TowerLayerSumcheckAir │ +//! └┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄►│ │──────────────► XiRandomnessBus +//! └─────────────────────────┘ +//! ``` + +use std::sync::Arc; + +use ::sumcheck::structs::IOPProverMessage; +use openvm_cpu_backend::CpuBackend; +use openvm_poseidon2_air::POSEIDON2_WIDTH; +use openvm_stark_backend::{ + AirRef, FiatShamirTranscript, ReadOnlyTranscript, StarkProtocolConfig, TranscriptHistory, + p3_maybe_rayon::prelude::*, prover::AirProvingContext, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, EF, F}; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::dense::RowMajorMatrix; +use recursion_circuit::primitives::exp_bits_len::ExpBitsLenTraceGenerator; +use strum::EnumCount; +use tracing::error; + +use crate::{ + system::{ + AirModule, BusIndexManager, BusInventory, GlobalCtxCpu, Preflight, RecursionField, + RecursionProof, RecursionVk, TowerChipTranscriptRange, TraceGenModule, + }, + tower::{ + bus::{TowerLayerInputBus, TowerLayerOutputBus}, + input::{TowerInputAir, TowerInputRecord, TowerInputTraceGenerator}, + layer::{ + TowerLayerAir, TowerLayerRecord, TowerLayerTraceGenerator, TowerLogupSumCheckClaimAir, + TowerLogupSumCheckClaimTraceGenerator, TowerProdReadSumCheckClaimAir, + TowerProdReadSumCheckClaimTraceGenerator, TowerProdWriteSumCheckClaimAir, + TowerProdWriteSumCheckClaimTraceGenerator, + }, + sumcheck::{TowerLayerSumcheckAir, TowerSumcheckRecord, TowerSumcheckTraceGenerator}, + tower::replay_tower_proof, + }, + tracegen::{ModuleChip, RowMajorChip}, +}; +use ceno_zkvm::{scheme::ZKVMChipProof, structs::VerifyingKey}; +use eyre::Result; + +// Internal bus definitions +mod bus; +pub use bus::{ + TowerLogupClaimBus, TowerLogupClaimInputBus, TowerLogupClaimMessage, + TowerLogupLayerChallengeMessage, TowerProdLayerChallengeMessage, TowerProdReadClaimBus, + TowerProdReadClaimInputBus, TowerProdSumClaimMessage, TowerProdWriteClaimBus, + TowerProdWriteClaimInputBus, TowerSumcheckChallengeBus, TowerSumcheckChallengeMessage, + TowerSumcheckInputBus, TowerSumcheckInputMessage, TowerSumcheckOutputBus, + TowerSumcheckOutputMessage, +}; + +// Sub-modules for different AIRs +pub mod input; +pub mod layer; +pub mod sumcheck; +mod tower; +pub(crate) use tower::TowerReplayResult; +pub struct TowerModule { + // Global bus inventory + bus_inventory: BusInventory, + // Module buses + layer_input_bus: TowerLayerInputBus, + layer_output_bus: TowerLayerOutputBus, + sumcheck_input_bus: TowerSumcheckInputBus, + sumcheck_output_bus: TowerSumcheckOutputBus, + sumcheck_challenge_bus: TowerSumcheckChallengeBus, + prod_read_claim_input_bus: TowerProdReadClaimInputBus, + prod_read_claim_bus: TowerProdReadClaimBus, + prod_write_claim_input_bus: TowerProdWriteClaimInputBus, + prod_write_claim_bus: TowerProdWriteClaimBus, + logup_claim_input_bus: TowerLogupClaimInputBus, + logup_claim_bus: TowerLogupClaimBus, +} + +#[derive(Clone, Debug, Default)] +pub(crate) struct TowerTowerEvalRecord { + pub(crate) read_layers: Vec>, + pub(crate) write_layers: Vec>, + pub(crate) logup_layers: Vec>, +} + +struct TowerBlobCpu { + input_records: Vec, + layer_records: Vec, + tower_records: Vec, + sumcheck_records: Vec, + mus_records: Vec>, + q0_claims: Vec, +} + +impl TowerModule { + pub fn new(_vk: &RecursionVk, b: &mut BusIndexManager, bus_inventory: BusInventory) -> Self { + TowerModule { + bus_inventory, + layer_input_bus: TowerLayerInputBus::new(b.new_bus_idx()), + layer_output_bus: TowerLayerOutputBus::new(b.new_bus_idx()), + sumcheck_input_bus: TowerSumcheckInputBus::new(b.new_bus_idx()), + sumcheck_output_bus: TowerSumcheckOutputBus::new(b.new_bus_idx()), + sumcheck_challenge_bus: TowerSumcheckChallengeBus::new(b.new_bus_idx()), + prod_read_claim_input_bus: TowerProdReadClaimInputBus::new(b.new_bus_idx()), + prod_read_claim_bus: TowerProdReadClaimBus::new(b.new_bus_idx()), + prod_write_claim_input_bus: TowerProdWriteClaimInputBus::new(b.new_bus_idx()), + prod_write_claim_bus: TowerProdWriteClaimBus::new(b.new_bus_idx()), + logup_claim_input_bus: TowerLogupClaimInputBus::new(b.new_bus_idx()), + logup_claim_bus: TowerLogupClaimBus::new(b.new_bus_idx()), + } + } + + #[tracing::instrument(level = "trace", skip_all)] + pub fn run_preflight( + &self, + child_vk: &RecursionVk, + proof: &RecursionProof, + preflight: &mut Preflight, + ts: &mut TS, + ) where + TS: FiatShamirTranscript + + TranscriptHistory, + { + let _ = (self, child_vk); + for (&chip_idx, chip_instances) in &proof.chip_proofs { + for (instance_idx, chip_proof) in chip_instances.iter().enumerate() { + let tidx = ts.len(); + let _ = record_gkr_transcript(ts, chip_idx, chip_proof); + + let tower_replay = match circuit_vk_for_idx(child_vk, chip_idx) { + Some(circuit_vk) => match replay_tower_proof(chip_proof, circuit_vk) { + Ok(replay) => replay, + Err(err) => { + error!( + ?err, + chip_idx, "failed to replay tower proof during preflight" + ); + eprintln!( + "failed to replay tower proof during preflight for chip {chip_idx}: {err:?}" + ); + TowerReplayResult::default() + } + }, + None => { + eprintln!( + "missing circuit verifying key during GKR preflight for chip {chip_idx}" + ); + TowerReplayResult::default() + } + }; + + preflight.gkr.chips.push(TowerChipTranscriptRange { + chip_idx, + instance_idx, + tidx, + tower_replay, + }); + } + } + } +} + +pub(crate) fn convert_logup_claim( + chip_proof: &ZKVMChipProof, + layer_idx: usize, +) -> [EF; 4] { + chip_proof + .tower_proof + .logup_specs_eval + .iter() + .find_map(|spec_layers| spec_layers.get(layer_idx)) + .map(|evals| { + let mut claim = [EF::ZERO; 4]; + for (dst, src) in claim.iter_mut().zip(evals.iter()) { + *dst = *src; + } + claim + }) + .unwrap_or([EF::ZERO; 4]) +} + +fn convert_sumcheck_evals(msg: &IOPProverMessage) -> [EF; 3] { + let mut evals = [EF::ZERO; 3]; + for (dst, src) in evals.iter_mut().zip(msg.evaluations.iter()) { + *dst = *src; + } + evals +} + +pub(crate) fn interpolate_pair(values: [EF; 2], mu: EF) -> EF { + let delta = values[1] - values[0]; + values[0] + delta * mu +} + +fn accumulate_prod_claims(rows: &[[EF; 2]], lambda: EF, lambda_prime: EF, mu: EF) -> (EF, EF) { + let mut pow_lambda = EF::ONE; + let mut pow_lambda_prime = EF::ONE; + let mut acc_sum = EF::ZERO; + let mut acc_sum_prime = EF::ZERO; + + for pair in rows { + let p_xi = interpolate_pair(*pair, mu); + let prime_product = pair[0] * pair[1]; + acc_sum += pow_lambda * p_xi; + acc_sum_prime += pow_lambda_prime * prime_product; + pow_lambda *= lambda; + pow_lambda_prime *= lambda_prime; + } + + (acc_sum, acc_sum_prime) +} + +fn accumulate_logup_claims(rows: &[[EF; 4]], lambda: EF, lambda_prime: EF, mu: EF) -> (EF, EF) { + let mut pow_lambda = EF::ONE; + let mut pow_lambda_prime = EF::ONE; + let mut acc_sum = EF::ZERO; + let mut acc_q = EF::ZERO; + + for quad in rows { + let p_vals = [quad[0], quad[1]]; + let q_vals = [quad[2], quad[3]]; + let p_xi = interpolate_pair(p_vals, mu); + let q_xi = interpolate_pair(q_vals, mu); + acc_sum += pow_lambda * (p_xi + lambda * q_xi); + let q_cross = quad[2] * quad[3]; + acc_q += pow_lambda_prime * lambda_prime * q_cross; + pow_lambda *= lambda; + pow_lambda_prime *= lambda_prime; + } + + (acc_sum, acc_q) +} + +fn circuit_vk_for_idx<'a>( + vk: &'a RecursionVk, + chip_idx: usize, +) -> Option<&'a VerifyingKey> { + vk.circuit_index_to_name + .get(&chip_idx) + .and_then(|name| vk.circuit_vks.get(name)) +} + +fn build_chip_records( + proof_idx: usize, + chip_idx: usize, + chip_proof: &ZKVMChipProof, + _circuit_vk: &VerifyingKey, + replay: &TowerReplayResult, + alpha_logup: EF, + tidx: usize, +) -> Result<( + TowerInputRecord, + TowerLayerRecord, + TowerTowerEvalRecord, + TowerSumcheckRecord, + Vec, + EF, +)> { + let spec_layer_count = chip_proof + .tower_proof + .logup_specs_eval + .iter() + .map(Vec::len) + .chain(chip_proof.tower_proof.prod_specs_eval.iter().map(Vec::len)) + .max() + .unwrap_or(0); + let layer_count = replay.layers.len().max(spec_layer_count); + + let read_count = chip_proof.r_out_evals.len(); + let write_count = chip_proof.w_out_evals.len(); + let logup_count = chip_proof.lk_out_evals.len(); + + let mut read_layers = vec![Vec::with_capacity(read_count); layer_count]; + let mut write_layers = vec![Vec::with_capacity(write_count); layer_count]; + let mut logup_layers = vec![Vec::with_capacity(logup_count); layer_count]; + + for (spec_idx, rounds) in chip_proof.tower_proof.prod_specs_eval.iter().enumerate() { + for layer_idx in 0..layer_count { + let mut pair = [EF::ZERO; 2]; + if let Some(values) = rounds.get(layer_idx) { + for (dst, src) in pair.iter_mut().zip(values.iter().take(2)) { + *dst = *src; + } + } + if spec_idx < read_count { + read_layers[layer_idx].push(pair); + } else { + write_layers[layer_idx].push(pair); + } + } + } + + for rounds in &chip_proof.tower_proof.logup_specs_eval { + for layer_idx in 0..layer_count { + let mut quad = [EF::ZERO; 4]; + if let Some(values) = rounds.get(layer_idx) { + for (dst, src) in quad.iter_mut().zip(values.iter().take(4)) { + *dst = *src; + } + } + logup_layers[layer_idx].push(quad); + } + } + + let tower_record = TowerTowerEvalRecord { + read_layers, + write_layers, + logup_layers, + }; + + let mut layer_record = TowerLayerRecord { + proof_idx, + idx: chip_idx, + tidx: 0, + layer_claims: Vec::with_capacity(layer_count), + lambdas: vec![EF::ZERO; layer_count], + eq_at_r_primes: vec![EF::ZERO; layer_count], + read_counts: vec![1; layer_count], + write_counts: vec![1; layer_count], + logup_counts: vec![1; layer_count], + read_claims: vec![EF::ZERO; layer_count], + read_prime_claims: vec![EF::ZERO; layer_count], + write_claims: vec![EF::ZERO; layer_count], + write_prime_claims: vec![EF::ZERO; layer_count], + logup_claims: vec![EF::ZERO; layer_count], + logup_prime_claims: vec![EF::ZERO; layer_count], + sumcheck_claims: vec![EF::ZERO; layer_count], + }; + + for layer_idx in 0..layer_count { + let read_len = tower_record + .read_layers + .get(layer_idx) + .map(|rows| rows.len()) + .unwrap_or(0); + let write_len = tower_record + .write_layers + .get(layer_idx) + .map(|rows| rows.len()) + .unwrap_or(0); + let logup_len = tower_record + .logup_layers + .get(layer_idx) + .map(|rows| rows.len()) + .unwrap_or(0); + // NOTE: some chip only got read or write + // eyre::ensure!( + // read_len == write_len, + // "read/write prod spec count mismatch at layer {layer_idx}: read={read_len}, write={write_len}" + // ); + layer_record.read_counts[layer_idx] = read_len.max(1); + layer_record.write_counts[layer_idx] = write_len.max(1); + layer_record.logup_counts[layer_idx] = logup_len.max(1); + } + + for layer_idx in 0..layer_count { + layer_record + .layer_claims + .push(convert_logup_claim(chip_proof, layer_idx)); + } + + let input_layer_claim = layer_record + .layer_claims + .last() + .map(|claim| claim[0]) + .unwrap_or(EF::ZERO); + + let mut sumcheck_record = TowerSumcheckRecord { + proof_idx, + tidx: 0, + evals: Vec::new(), + ris: Vec::new(), + claims: vec![EF::ZERO; layer_count], + }; + + for round_msgs in &chip_proof.tower_proof.proofs { + for msg in round_msgs { + sumcheck_record.evals.push(convert_sumcheck_evals(msg)); + } + } + let mut mus_record = vec![EF::ZERO; layer_count]; + + let q0_claim = chip_proof + .lk_out_evals + .get(0) + .and_then(|evals| evals.get(2)) + .copied() + .unwrap_or(EF::ZERO); + + let input_record = TowerInputRecord { + proof_idx, + idx: chip_idx, + tidx, + n_logup: layer_count, + alpha_logup, + input_layer_claim, + }; + let flattened_ris: Vec = replay + .layers + .iter() + .flat_map(|layer| layer.challenges.iter().copied()) + .collect(); + sumcheck_record.ris = flattened_ris; + if !replay.layers.is_empty() { + eyre::ensure!( + sumcheck_record.ris.len() == sumcheck_record.evals.len(), + "tower replay produced mismatched round counts: replay challenges={}, sumcheck eval rounds={}", + sumcheck_record.ris.len(), + sumcheck_record.evals.len() + ); + } + for (layer_idx, data) in replay.layers.iter().enumerate() { + if layer_idx < layer_record.eq_at_r_primes.len() { + layer_record.eq_at_r_primes[layer_idx] = data.eq_at_r; + layer_record.lambdas[layer_idx] = data.lambda; + mus_record[layer_idx] = data.mu; + } + if layer_idx < sumcheck_record.claims.len() { + sumcheck_record.claims[layer_idx] = data.claim_in; + layer_record.sumcheck_claims[layer_idx] = data.claim_in; + } + } + + for layer_idx in 0..layer_count { + let lambda = layer_record + .lambdas + .get(layer_idx) + .copied() + .unwrap_or(EF::ZERO); + let lambda_prime = layer_record.lambda_prime_at(layer_idx); + let mu = mus_record.get(layer_idx).copied().unwrap_or(EF::ZERO); + + if let Some(rows) = tower_record.read_layers.get(layer_idx) { + let (claim, prime) = accumulate_prod_claims(rows, lambda, lambda_prime, mu); + layer_record.read_claims[layer_idx] = claim; + layer_record.read_prime_claims[layer_idx] = prime; + } + if let Some(rows) = tower_record.write_layers.get(layer_idx) { + let (claim, prime) = accumulate_prod_claims(rows, lambda, lambda_prime, mu); + layer_record.write_claims[layer_idx] = claim; + layer_record.write_prime_claims[layer_idx] = prime; + } + if let Some(rows) = tower_record.logup_layers.get(layer_idx) { + let (claim, prime) = accumulate_logup_claims(rows, lambda, lambda_prime, mu); + layer_record.logup_claims[layer_idx] = claim; + layer_record.logup_prime_claims[layer_idx] = prime; + } + } + + Ok(( + input_record, + layer_record, + tower_record, + sumcheck_record, + mus_record, + q0_claim, + )) +} + +impl AirModule for TowerModule { + fn num_airs(&self) -> usize { + TowerModuleChipDiscriminants::COUNT + } + + fn airs>(&self) -> Vec> { + let gkr_input_air = TowerInputAir { + tower_module_bus: self.bus_inventory.tower_module_bus, + main_bus: self.bus_inventory.main_bus, + transcript_bus: self.bus_inventory.transcript_bus, + layer_input_bus: self.layer_input_bus, + layer_output_bus: self.layer_output_bus, + }; + + let gkr_layer_air = TowerLayerAir { + transcript_bus: self.bus_inventory.transcript_bus, + air_shape_bus: self.bus_inventory.air_shape_bus, + layer_input_bus: self.layer_input_bus, + layer_output_bus: self.layer_output_bus, + sumcheck_input_bus: self.sumcheck_input_bus, + sumcheck_output_bus: self.sumcheck_output_bus, + sumcheck_challenge_bus: self.sumcheck_challenge_bus, + prod_read_claim_input_bus: self.prod_read_claim_input_bus, + prod_read_claim_bus: self.prod_read_claim_bus, + prod_write_claim_input_bus: self.prod_write_claim_input_bus, + prod_write_claim_bus: self.prod_write_claim_bus, + logup_claim_input_bus: self.logup_claim_input_bus, + logup_claim_bus: self.logup_claim_bus, + }; + + let gkr_prod_read_sum_air = TowerProdReadSumCheckClaimAir { + transcript_bus: self.bus_inventory.transcript_bus, + prod_claim_input_bus: self.prod_read_claim_input_bus, + prod_claim_bus: self.prod_read_claim_bus, + }; + + let gkr_prod_write_sum_air = TowerProdWriteSumCheckClaimAir { + transcript_bus: self.bus_inventory.transcript_bus, + prod_claim_input_bus: self.prod_write_claim_input_bus, + prod_claim_bus: self.prod_write_claim_bus, + }; + + let gkr_logup_sum_air = TowerLogupSumCheckClaimAir { + transcript_bus: self.bus_inventory.transcript_bus, + logup_claim_input_bus: self.logup_claim_input_bus, + logup_claim_bus: self.logup_claim_bus, + }; + + let gkr_sumcheck_air = TowerLayerSumcheckAir::new( + self.bus_inventory.transcript_bus, + self.bus_inventory.xi_randomness_bus, + self.sumcheck_input_bus, + self.sumcheck_output_bus, + self.sumcheck_challenge_bus, + ); + + vec![ + Arc::new(gkr_input_air) as AirRef<_>, + Arc::new(gkr_layer_air) as AirRef<_>, + Arc::new(gkr_prod_read_sum_air) as AirRef<_>, + Arc::new(gkr_prod_write_sum_air) as AirRef<_>, + Arc::new(gkr_logup_sum_air) as AirRef<_>, + Arc::new(gkr_sumcheck_air) as AirRef<_>, + ] + } +} + +impl TowerModule { + #[tracing::instrument(skip_all)] + fn generate_blob( + &self, + child_vk: &RecursionVk, + proofs: &[RecursionProof], + preflights: &[Preflight], + exp_bits_len_gen: &ExpBitsLenTraceGenerator, + ) -> Result { + let _ = (self, preflights, exp_bits_len_gen); + build_gkr_blob(child_vk, proofs, preflights) + } +} + +pub(crate) fn build_gkr_blob( + child_vk: &RecursionVk, + proofs: &[RecursionProof], + preflights: &[Preflight], +) -> Result { + let mut input_records = Vec::new(); + let mut layer_records = Vec::new(); + let mut tower_records = Vec::new(); + let mut sumcheck_records = Vec::new(); + let mut mus_records = Vec::new(); + let mut q0_claims = Vec::new(); + + eyre::ensure!( + proofs.len() == preflights.len(), + "proof/preflight length mismatch" + ); + + for (proof_idx, (proof, preflight)) in proofs.iter().zip(preflights).enumerate() { + let mut has_chip = false; + let mut chip_preflight_entries = preflight.gkr.chips.iter(); + for (&chip_idx, chip_instances) in &proof.chip_proofs { + for (instance_idx, chip_proof) in chip_instances.iter().enumerate() { + has_chip = true; + let pf_entry = chip_preflight_entries.next().ok_or_else(|| { + eyre::eyre!( + "missing GKR preflight entry for chip {chip_idx} instance {instance_idx}" + ) + })?; + if pf_entry.chip_idx != chip_idx || pf_entry.instance_idx != instance_idx { + return Err(eyre::eyre!( + "tower preflight chip mismatch (expected ({}, {}), found ({}, {}))", + chip_idx, + instance_idx, + pf_entry.chip_idx, + pf_entry.instance_idx + )); + } + let mut ts = ReadOnlyTranscript::new(&preflight.transcript, pf_entry.tidx); + let alpha_logup = record_gkr_transcript(&mut ts, chip_idx, chip_proof); + + let circuit_vk = circuit_vk_for_idx(child_vk, chip_idx).ok_or_else(|| { + eyre::eyre!("missing circuit verifying key for index {chip_idx}") + })?; + println!( + "processing chip name: {:?}", + child_vk.circuit_index_to_name.get(&chip_idx) + ); + let ( + input_record, + layer_record, + tower_record, + sumcheck_record, + mus_record, + q0_claim, + ) = build_chip_records( + proof_idx, + chip_idx, + chip_proof, + circuit_vk, + &pf_entry.tower_replay, + alpha_logup, + pf_entry.tidx, + )?; + input_records.push(input_record); + layer_records.push(layer_record); + tower_records.push(tower_record); + sumcheck_records.push(sumcheck_record); + mus_records.push(mus_record); + q0_claims.push(q0_claim); + } + } + + if !has_chip { + input_records.push(TowerInputRecord { + proof_idx, + ..Default::default() + }); + layer_records.push(TowerLayerRecord { + idx: 0, + proof_idx, + ..Default::default() + }); + tower_records.push(TowerTowerEvalRecord::default()); + sumcheck_records.push(TowerSumcheckRecord { + proof_idx, + ..Default::default() + }); + mus_records.push(vec![]); + q0_claims.push(EF::ZERO); + } + } + + if input_records.is_empty() { + input_records.push(TowerInputRecord::default()); + layer_records.push(TowerLayerRecord::default()); + sumcheck_records.push(TowerSumcheckRecord::default()); + tower_records.push(TowerTowerEvalRecord::default()); + mus_records.push(vec![]); + q0_claims.push(EF::ZERO); + } + + Ok(TowerBlobCpu { + input_records, + layer_records, + tower_records, + sumcheck_records, + mus_records, + q0_claims, + }) +} + +fn record_gkr_transcript( + ts: &mut TS, + _chip_idx: usize, + chip_proof: &ZKVMChipProof, +) -> EF +where + TS: FiatShamirTranscript, +{ + if let Some(q0) = chip_proof + .lk_out_evals + .get(0) + .and_then(|evals| evals.get(2)) + { + ts.observe_ext(*q0); + } + FiatShamirTranscript::::sample_ext(ts) +} + +impl> TraceGenModule> for TowerModule { + type ModuleSpecificCtx<'a> = ExpBitsLenTraceGenerator; + + #[tracing::instrument(skip_all)] + fn generate_proving_ctxs( + &self, + child_vk: &RecursionVk, + proofs: &[RecursionProof], + preflights: &[Preflight], + ctx: &ExpBitsLenTraceGenerator, + required_heights: Option<&[usize]>, + ) -> Option>>> { + let blob = match self.generate_blob(child_vk, proofs, preflights, ctx) { + Ok(blob) => blob, + Err(err) => { + error!(?err, "failed to build GKR trace blob"); + eprintln!("failed to build GKR trace blob: {err:?}"); + return None; + } + }; + let chips = [ + TowerModuleChip::Input, + TowerModuleChip::Layer, + TowerModuleChip::ProdReadClaim, + TowerModuleChip::ProdWriteClaim, + TowerModuleChip::LogupClaim, + TowerModuleChip::LayerSumcheck, + ]; + + let span = tracing::Span::current(); + chips + .par_iter() + .map(|chip| { + let _guard = span.enter(); + chip.generate_proving_ctx( + &blob, + required_heights.and_then(|heights| heights.get(chip.index()).copied()), + ) + }) + .collect::>() + .into_iter() + .collect() + } +} + +// To reduce the number of structs and trait implementations, we collect them into a single enum +// with enum dispatch. +#[derive(strum_macros::Display, strum::EnumDiscriminants)] +#[strum_discriminants(derive(strum_macros::EnumCount))] +#[strum_discriminants(repr(usize))] +enum TowerModuleChip { + Input, + Layer, + ProdReadClaim, + ProdWriteClaim, + LogupClaim, + LayerSumcheck, +} + +impl TowerModuleChip { + fn index(&self) -> usize { + TowerModuleChipDiscriminants::from(self) as usize + } +} + +impl RowMajorChip for TowerModuleChip { + type Ctx<'a> = TowerBlobCpu; + + #[tracing::instrument( + name = "wrapper.generate_trace", + level = "trace", + skip_all, + fields(air = %self) + )] + fn generate_trace( + &self, + blob: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + use TowerModuleChip::*; + match self { + Input => TowerInputTraceGenerator + .generate_trace(&(&blob.input_records, &blob.q0_claims), required_height), + Layer => TowerLayerTraceGenerator.generate_trace( + &(&blob.layer_records, &blob.mus_records, &blob.q0_claims), + required_height, + ), + ProdReadClaim => TowerProdReadSumCheckClaimTraceGenerator.generate_trace( + &(&blob.layer_records, &blob.tower_records, &blob.mus_records), + required_height, + ), + ProdWriteClaim => TowerProdWriteSumCheckClaimTraceGenerator.generate_trace( + &(&blob.layer_records, &blob.tower_records, &blob.mus_records), + required_height, + ), + LogupClaim => TowerLogupSumCheckClaimTraceGenerator.generate_trace( + &(&blob.layer_records, &blob.tower_records, &blob.mus_records), + required_height, + ), + LayerSumcheck => TowerSumcheckTraceGenerator.generate_trace( + &(&blob.sumcheck_records, &blob.mus_records), + required_height, + ), + } + } +} + +#[cfg(feature = "cuda")] +mod cuda_tracegen { + use openvm_cuda_backend::GpuBackend; + + use super::*; + use crate::{ + cuda::{GlobalCtxGpu, preflight::PreflightGpu, proof::ProofGpu, vk::VerifyingKeyGpu}, + tracegen::cuda::generate_gpu_proving_ctx, + }; + + impl TraceGenModule for TowerModule { + type ModuleSpecificCtx<'a> = ExpBitsLenTraceGenerator; + + #[tracing::instrument(skip_all)] + fn generate_proving_ctxs( + &self, + child_vk: &VerifyingKeyGpu, + proofs: &[ProofGpu], + preflights: &[PreflightGpu], + exp_bits_len_gen: &ExpBitsLenTraceGenerator, + required_heights: Option<&[usize]>, + ) -> Option>> { + let proofs_cpu: Vec<_> = proofs.iter().map(|proof| proof.cpu.clone()).collect(); + let preflights_cpu: Vec<_> = preflights + .iter() + .map(|preflight| preflight.cpu.clone()) + .collect(); + let blob = match self.generate_blob( + &child_vk.cpu, + &proofs_cpu, + &preflights_cpu, + exp_bits_len_gen, + ) { + Ok(blob) => blob, + Err(err) => { + error!(?err, "failed to build GKR trace blob (cuda)"); + return None; + } + }; + + let chips = [ + TowerModuleChip::Input, + TowerModuleChip::Layer, + TowerModuleChip::ProdReadClaim, + TowerModuleChip::ProdWriteClaim, + TowerModuleChip::LogupClaim, + TowerModuleChip::LayerSumcheck, + ]; + + chips + .iter() + .map(|chip| { + generate_gpu_proving_ctx( + chip, + &blob, + required_heights.and_then(|heights| heights.get(chip.index()).copied()), + ) + }) + .collect() + } + } +} diff --git a/ceno_recursion_v2/src/tower/sumcheck/air.rs b/ceno_recursion_v2/src/tower/sumcheck/air.rs new file mode 100644 index 000000000..a7a564007 --- /dev/null +++ b/ceno_recursion_v2/src/tower/sumcheck/air.rs @@ -0,0 +1,396 @@ +use core::borrow::Borrow; + +use openvm_circuit_primitives::{SubAir, utils::assert_array_eq}; +use openvm_stark_backend::{ + BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::{Field, PrimeCharacteristicRing, extension::BinomiallyExtendable}; +use p3_matrix::Matrix; +use stark_recursion_circuit_derive::AlignedBorrow; + +use crate::tower::bus::{ + TowerSumcheckChallengeBus, TowerSumcheckChallengeMessage, TowerSumcheckInputBus, + TowerSumcheckInputMessage, TowerSumcheckOutputBus, TowerSumcheckOutputMessage, +}; +use recursion_circuit::{ + bus::{TranscriptBus, XiRandomnessBus, XiRandomnessMessage}, + subairs::nested_for_loop::{NestedForLoopIoCols, NestedForLoopSubAir}, + utils::{ + assert_one_ext, ext_field_add, ext_field_multiply, ext_field_multiply_scalar, + ext_field_one_minus, ext_field_subtract, + }, +}; + +#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct TowerLayerSumcheckCols { + /// Whether the current row is enabled (i.e. not padding) + pub is_enabled: T, + pub proof_idx: T, + pub idx: T, + pub layer_idx: T, + pub is_first_idx: T, + pub is_first_layer: T, + pub is_first_round: T, + + /// An enabled row which is not involved in any interactions + /// but should satisfy air constraints + pub is_dummy: T, + + pub is_last_layer: T, + + /// Sumcheck sub-round index within this layer_idx (0..layer_idx-1) + // perf(ayush): can probably remove round if XiRandomnessMessage takes tidx instead + pub round: T, + + /// Transcript index + pub tidx: T, + + /// s(1) in extension field + pub ev1: [T; D_EF], + /// s(2) in extension field + pub ev2: [T; D_EF], + /// s(3) in extension field + pub ev3: [T; D_EF], + + /// The claim coming into this sub-round (either from previous sub-round or initial) + pub claim_in: [T; D_EF], + /// The claim going out of this sub-round (result of cubic interpolation) + pub claim_out: [T; D_EF], + + /// Component `round` of the original point ξ^{(j-1)} + /// (corresponding to `gkr_r[round]`) + pub prev_challenge: [T; D_EF], + /// The sampled challenge for this sub-round (corresponds to `ri`) + pub challenge: [T; D_EF], + + /// The eq value coming into this sub-round + pub eq_in: [T; D_EF], + /// The eq value going out (updated for this round) + pub eq_out: [T; D_EF], +} + +pub struct TowerLayerSumcheckAir { + pub transcript_bus: TranscriptBus, + pub xi_randomness_bus: XiRandomnessBus, + pub sumcheck_input_bus: TowerSumcheckInputBus, + pub sumcheck_output_bus: TowerSumcheckOutputBus, + pub sumcheck_challenge_bus: TowerSumcheckChallengeBus, +} + +impl TowerLayerSumcheckAir { + pub fn new( + transcript_bus: TranscriptBus, + xi_randomness_bus: XiRandomnessBus, + sumcheck_input_bus: TowerSumcheckInputBus, + sumcheck_output_bus: TowerSumcheckOutputBus, + sumcheck_challenge_bus: TowerSumcheckChallengeBus, + ) -> Self { + Self { + transcript_bus, + xi_randomness_bus, + sumcheck_input_bus, + sumcheck_output_bus, + sumcheck_challenge_bus, + } + } +} + +impl BaseAir for TowerLayerSumcheckAir { + fn width(&self) -> usize { + TowerLayerSumcheckCols::::width() + } +} + +impl BaseAirWithPublicValues for TowerLayerSumcheckAir {} +impl PartitionedBaseAir for TowerLayerSumcheckAir {} + +impl Air for TowerLayerSumcheckAir +where + ::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = ( + main.row_slice(0).expect("window should have two elements"), + main.row_slice(1).expect("window should have two elements"), + ); + let local: &TowerLayerSumcheckCols = (*local).borrow(); + let next: &TowerLayerSumcheckCols = (*next).borrow(); + + /////////////////////////////////////////////////////////////////////// + // Boolean Constraints + /////////////////////////////////////////////////////////////////////// + + builder.assert_bool(local.is_dummy); + builder.assert_bool(local.is_last_layer); + + /////////////////////////////////////////////////////////////////////// + // Proof Index and Loop Constraints + /////////////////////////////////////////////////////////////////////// + + type LoopSubAir = NestedForLoopSubAir<3>; + LoopSubAir {}.eval( + builder, + ( + NestedForLoopIoCols { + is_enabled: local.is_enabled, + counter: [local.proof_idx, local.idx, local.layer_idx], + is_first: [ + local.is_first_idx, + local.is_first_layer, + local.is_first_round, + ], + } + .map_into(), + NestedForLoopIoCols { + is_enabled: next.is_enabled, + counter: [next.proof_idx, next.idx, next.layer_idx], + is_first: [next.is_first_idx, next.is_first_layer, next.is_first_round], + } + .map_into(), + ), + ); + + let is_transition_round = + LoopSubAir::local_is_transition(next.is_enabled, next.is_first_round); + let is_last_round = + LoopSubAir::local_is_last(local.is_enabled, next.is_enabled, next.is_first_round); + + // Sumcheck round flag starts at 0 + builder.when(local.is_first_round).assert_zero(local.round); + // Sumcheck round flag increments by 1 + builder + .when(is_transition_round.clone()) + .assert_eq(next.round, local.round + AB::Expr::ONE); + // Sumcheck round flag end + builder + .when(is_last_round.clone()) + .assert_eq(local.round, local.layer_idx - AB::Expr::ONE); + + /////////////////////////////////////////////////////////////////////// + // Round Constraints + /////////////////////////////////////////////////////////////////////// + + // Eq initialization: eq_in = 1 at first round + assert_one_ext(&mut builder.when(local.is_first_round), local.eq_in); + + // Eq update: incrementally compute eq *= (xi * ri + (1-xi) * (1-ri)) + let eq_out: [AB::Expr; D_EF] = + update_eq(local.eq_in, local.prev_challenge, local.challenge); + assert_array_eq(&mut builder.when(local.is_enabled), local.eq_out, eq_out); + + // Eq propagation + assert_array_eq( + &mut builder.when(is_transition_round.clone()), + local.eq_out, + next.eq_in, + ); + + // Compute s(0) = claim_in - s(1) + let ev0: [AB::Expr; D_EF] = ext_field_subtract(local.claim_in, local.ev1); + + // Cubic interpolation: compute claim_out from polynomial evals at 0,1,2,3 + let claim_out: [AB::Expr; D_EF] = + interpolate_cubic_at_0123(ev0, local.ev1, local.ev2, local.ev3, local.challenge); + assert_array_eq(builder, local.claim_out, claim_out); + + // Claim propagation + assert_array_eq( + &mut builder.when(is_transition_round.clone()), + local.claim_out, + next.claim_in, + ); + + // Transcript index increment + builder.when(is_transition_round.clone()).assert_eq( + next.tidx, + local.tidx.into() + AB::Expr::from_usize(4 * D_EF), + ); + + /////////////////////////////////////////////////////////////////////// + // Module Interactions + /////////////////////////////////////////////////////////////////////// + + let is_not_dummy = AB::Expr::ONE - local.is_dummy; + + // 1. TowerSumcheckInputBus + // 1a. Receive initial sumcheck input on first round + self.sumcheck_input_bus.receive( + builder, + local.proof_idx, + TowerSumcheckInputMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + is_last_layer: local.is_last_layer.into(), + tidx: local.tidx.into(), + claim: local.claim_in.map(Into::into), + }, + local.is_first_round * is_not_dummy.clone(), + ); + // 2. TowerSumcheckOutputBus + // 2a. Send output back to TowerLayerAir on final round + self.sumcheck_output_bus.send( + builder, + local.proof_idx, + TowerSumcheckOutputMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + tidx: local.tidx.into() + AB::Expr::from_usize(4 * D_EF), + claim_out: local.claim_out.map(Into::into), + eq_at_r_prime: local.eq_out.map(Into::into), + }, + is_last_round.clone() * is_not_dummy.clone(), + ); + + // 3. TowerSumcheckChallengeBus + // 3a. Receive challenge from previous GKR layer_idx sumcheck + self.sumcheck_challenge_bus.receive( + builder, + local.proof_idx, + TowerSumcheckChallengeMessage { + idx: local.idx.clone().into(), + layer_idx: local.layer_idx - AB::Expr::ONE, + sumcheck_round: local.round.into(), + challenge: local.prev_challenge.map(Into::into), + }, + local.is_enabled * is_not_dummy.clone(), + ); + // 3b. Send challenge to next GKR layer_idx sumcheck for eq calculation + self.sumcheck_challenge_bus.send( + builder, + local.proof_idx, + TowerSumcheckChallengeMessage { + idx: local.idx.into(), + layer_idx: local.layer_idx.into(), + sumcheck_round: local.round.into() + AB::Expr::ONE, + challenge: local.challenge.map(Into::into), + }, + local.is_enabled * (AB::Expr::ONE - local.is_last_layer) * is_not_dummy.clone(), + ); + + /////////////////////////////////////////////////////////////////////// + // External Interactions + /////////////////////////////////////////////////////////////////////// + + // 1. TranscriptBus + // 1a. Observe evaluations + let mut tidx = local.tidx.into(); + for eval in [local.ev1, local.ev2, local.ev3].into_iter() { + self.transcript_bus.observe_ext( + builder, + local.proof_idx, + tidx.clone(), + eval, + local.is_enabled * is_not_dummy.clone(), + ); + tidx += AB::Expr::from_usize(D_EF); + } + // 1b. Sample challenge `ri` + self.transcript_bus.sample_ext( + builder, + local.proof_idx, + tidx, + local.challenge, + local.is_enabled * is_not_dummy.clone(), + ); + + // 2. XiRandomnessBus + // 2a. Send last challenge + self.xi_randomness_bus.send( + builder, + local.proof_idx, + XiRandomnessMessage { + idx: local.round + AB::Expr::ONE, + xi: local.challenge.map(Into::into), + }, + local.is_enabled * local.is_last_layer * is_not_dummy.clone(), + ); + } +} + +/// Interpolates a cubic polynomial at a point using evaluations at 0, 1, 2, 3. +/// +/// Given evaluations `claim_in, ev1, ev2, ev3` (where ev0 = claim_in - ev1) and a point `x`, +/// computes `f(x)` using Lagrange interpolation optimized for these specific points. +pub(super) fn interpolate_cubic_at_0123( + ev0: [FA; D_EF], + ev1: [F; D_EF], + ev2: [F; D_EF], + ev3: [F; D_EF], + x: [F; D_EF], +) -> [FA; D_EF] +where + F: Into + Copy, + FA: PrimeCharacteristicRing, + FA::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, +{ + let three: FA = FA::from_usize(3); + let inv2: FA = FA::from_prime_subfield(FA::PrimeSubfield::from_usize(2).inverse()); + let inv6: FA = FA::from_prime_subfield(FA::PrimeSubfield::from_usize(6).inverse()); + + // s1 = ev1 - ev0 + let s1: [FA; D_EF] = ext_field_subtract(ev1, ev0.clone()); + // s2 = ev2 - ev0 + let s2: [FA; D_EF] = ext_field_subtract(ev2, ev0.clone()); + // s3 = ev3 - ev0 + let s3: [FA; D_EF] = ext_field_subtract(ev3, ev0.clone()); + + // d3 = s3 - (s2 - s1) * 3 + let d3: [FA; D_EF] = ext_field_subtract::( + s3, + ext_field_multiply_scalar::(ext_field_subtract::(s2.clone(), s1.clone()), three), + ); + + // p = d3 / 6 + let p: [FA; D_EF] = ext_field_multiply_scalar(d3.clone(), inv6); + + // q = (s2 - d3) / 2 - s1 + let q: [FA; D_EF] = ext_field_subtract::( + ext_field_multiply_scalar::(ext_field_subtract::(s2, d3), inv2), + s1.clone(), + ); + + // r = s1 - p - q + let r: [FA; D_EF] = ext_field_subtract::(s1, ext_field_add::(p.clone(), q.clone())); + + // result = ((p * x + q) * x + r) * x + ev0 + ext_field_add::( + ext_field_multiply::( + ext_field_add::( + ext_field_multiply::(ext_field_add::(ext_field_multiply::(p, x), q), x), + r, + ), + x, + ), + ev0, + ) +} + +/// Updates the eq evaluation incrementally for one sumcheck round. +/// +/// Computes: `eq_out = eq_in * (prev_challenge * challenge + (1 - prev_challenge) * (1 - +/// challenge))` where `prev_challenge` is xi and `challenge` is ri. +pub(super) fn update_eq( + eq_in: [F; D_EF], + prev_challenge: [F; D_EF], + challenge: [F; D_EF], +) -> [FA; D_EF] +where + F: Into + Copy, + FA: PrimeCharacteristicRing, + FA::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, +{ + ext_field_multiply::( + eq_in, + ext_field_add::( + ext_field_multiply::(prev_challenge, challenge), + ext_field_multiply::( + ext_field_one_minus::(prev_challenge), + ext_field_one_minus::(challenge), + ), + ), + ) +} diff --git a/ceno_recursion_v2/src/tower/sumcheck/mod.rs b/ceno_recursion_v2/src/tower/sumcheck/mod.rs new file mode 100644 index 000000000..efc5f9af9 --- /dev/null +++ b/ceno_recursion_v2/src/tower/sumcheck/mod.rs @@ -0,0 +1,5 @@ +mod air; +mod trace; + +pub use air::{TowerLayerSumcheckAir, TowerLayerSumcheckCols}; +pub use trace::{TowerSumcheckRecord, TowerSumcheckTraceGenerator}; diff --git a/ceno_recursion_v2/src/tower/sumcheck/trace.rs b/ceno_recursion_v2/src/tower/sumcheck/trace.rs new file mode 100644 index 000000000..f0742c14b --- /dev/null +++ b/ceno_recursion_v2/src/tower/sumcheck/trace.rs @@ -0,0 +1,237 @@ +use core::borrow::BorrowMut; + +use openvm_stark_backend::{p3_maybe_rayon::prelude::*, poly_common::interpolate_cubic_at_0123}; +use openvm_stark_sdk::config::baby_bear_poseidon2::{D_EF, EF, F}; +use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; +use p3_matrix::dense::RowMajorMatrix; + +use super::TowerLayerSumcheckCols; +use crate::tracegen::RowMajorChip; + +#[derive(Default, Debug, Clone)] +pub struct TowerSumcheckRecord { + pub proof_idx: usize, + pub tidx: usize, + pub evals: Vec<[EF; 3]>, + pub ris: Vec, + pub claims: Vec, +} + +impl TowerSumcheckRecord { + #[inline] + pub fn num_layers(&self) -> usize { + self.claims.len() + } + + #[inline] + pub fn total_rounds(&self) -> usize { + let layers = self.num_layers(); + layers * (layers + 1) / 2 + } + + #[inline] + fn layer_start_index(layer_idx: usize) -> usize { + layer_idx * (layer_idx + 1) / 2 + } + + #[inline] + fn layer_rounds(layer_idx: usize) -> usize { + layer_idx + 1 + } + + #[inline] + fn derive_tidx(&self, layer_idx: usize, round_in_layer: usize) -> usize { + let rounds_before_layer = Self::layer_start_index(layer_idx); + self.tidx + 4 * D_EF * (rounds_before_layer + round_in_layer) + 6 * D_EF * layer_idx + } + + #[inline] + fn prev_challenge(layer_idx: usize, round_in_layer: usize, mus: &[EF], ris: &[EF]) -> EF { + if round_in_layer == 0 { + mus[layer_idx] + } else { + let prev_layer = layer_idx + .checked_sub(1) + .expect("round_in_layer > 0 only occurs for non-root layers"); + let offset = Self::layer_start_index(prev_layer) + (round_in_layer - 1); + ris[offset] + } + } +} + +pub struct TowerSumcheckTraceGenerator; + +impl RowMajorChip for TowerSumcheckTraceGenerator { + // (gkr_sumcheck_records, mus) + type Ctx<'a> = (&'a [TowerSumcheckRecord], &'a [Vec]); + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_trace( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option> { + let (gkr_sumcheck_records, mus) = ctx; + debug_assert_eq!(gkr_sumcheck_records.len(), mus.len()); + + let width = TowerLayerSumcheckCols::::width(); + + // Calculate rows per proof + let rows_per_proof: Vec = gkr_sumcheck_records + .iter() + .map(|record| record.total_rounds().max(1)) + .collect(); + + // Calculate total rows + let num_valid_rows: usize = rows_per_proof.iter().sum(); + let height = if let Some(height) = required_height { + if height < num_valid_rows { + return None; + } + height + } else { + num_valid_rows.next_power_of_two() + }; + let mut trace = vec![F::ZERO; height * width]; + + // Split trace into chunks for each proof + let (data_slice, _) = trace.split_at_mut(num_valid_rows * width); + let mut trace_slices: Vec<&mut [F]> = Vec::with_capacity(rows_per_proof.len()); + let mut remaining = data_slice; + + for &num_rows in &rows_per_proof { + let chunk_size = num_rows * width; + let (chunk, rest) = remaining.split_at_mut(chunk_size); + trace_slices.push(chunk); + remaining = rest; + } + + // Process each proof in parallel + trace_slices + .par_iter_mut() + .zip(gkr_sumcheck_records.par_iter().zip(mus.par_iter())) + .for_each(|(proof_trace, (record, mus_for_proof))| { + let mus_for_proof = mus_for_proof.as_slice(); + let total_rounds = record.total_rounds(); + let num_layers = record.num_layers(); + + debug_assert_eq!(record.ris.len(), total_rounds); + debug_assert_eq!(record.evals.len(), total_rounds); + debug_assert!(mus_for_proof.len() >= num_layers); + + if total_rounds == 0 { + debug_assert_eq!(proof_trace.len(), width); + let row_data = &mut proof_trace[..width]; + let cols: &mut TowerLayerSumcheckCols = row_data.borrow_mut(); + cols.is_enabled = F::ONE; + cols.tidx = F::from_usize(D_EF); + cols.proof_idx = F::from_usize(record.proof_idx); + cols.idx = F::ZERO; + cols.layer_idx = F::ONE; + cols.is_first_round = F::ONE; + cols.is_first_idx = F::ONE; + cols.is_first_layer = F::ONE; + cols.is_last_layer = F::ONE; + cols.is_dummy = F::ONE; + cols.eq_in = [F::ONE, F::ZERO, F::ZERO, F::ZERO]; + cols.eq_out = [F::ONE, F::ZERO, F::ZERO, F::ZERO]; + cols.claim_in = [F::ONE, F::ZERO, F::ZERO, F::ZERO]; + cols.claim_out = [F::ONE, F::ZERO, F::ZERO, F::ZERO]; + return; + } + + let mut global_round_idx = 0usize; + let mut row_iter = proof_trace.chunks_mut(width); + + for layer_idx in 0..num_layers { + let layer_rounds = TowerSumcheckRecord::layer_rounds(layer_idx); + let layer_idx_value = layer_idx + 1; + let is_last_layer = layer_idx == num_layers.saturating_sub(1); + + let mut claim = record.claims[layer_idx]; + let mut eq = EF::ONE; + + for round_in_layer in 0..layer_rounds { + let challenge = record.ris[global_round_idx]; + let evals = record.evals[global_round_idx]; + let prev_challenge = TowerSumcheckRecord::prev_challenge( + layer_idx, + round_in_layer, + mus_for_proof, + &record.ris, + ); + + let prev_challenge_base: [F; D_EF] = prev_challenge + .as_basis_coefficients_slice() + .try_into() + .unwrap(); + let challenge_base: [F; D_EF] = + challenge.as_basis_coefficients_slice().try_into().unwrap(); + + let eval1_base: [F; D_EF] = + evals[0].as_basis_coefficients_slice().try_into().unwrap(); + let eval2_base: [F; D_EF] = + evals[1].as_basis_coefficients_slice().try_into().unwrap(); + let eval3_base: [F; D_EF] = + evals[2].as_basis_coefficients_slice().try_into().unwrap(); + + let claim_in_base: [F; D_EF] = + claim.as_basis_coefficients_slice().try_into().unwrap(); + let eq_in_base: [F; D_EF] = + eq.as_basis_coefficients_slice().try_into().unwrap(); + + let ev0 = claim - evals[0]; + let evals_full = [ev0, evals[0], evals[1], evals[2]]; + let claim_out = interpolate_cubic_at_0123(&evals_full, challenge); + let eq_factor = prev_challenge * challenge + + (EF::ONE - prev_challenge) * (EF::ONE - challenge); + let eq_out = eq * eq_factor; + + let claim_out_base: [F; D_EF] = + claim_out.as_basis_coefficients_slice().try_into().unwrap(); + let eq_out_base: [F; D_EF] = + eq_out.as_basis_coefficients_slice().try_into().unwrap(); + + let cols: &mut TowerLayerSumcheckCols = + row_iter.next().unwrap().borrow_mut(); + cols.is_enabled = F::ONE; + cols.proof_idx = F::from_usize(record.proof_idx); + cols.idx = F::ZERO; + + cols.layer_idx = F::from_usize(layer_idx_value); + cols.is_last_layer = F::from_bool(is_last_layer); + + cols.round = F::from_usize(round_in_layer); + cols.is_first_round = F::from_bool(round_in_layer == 0); + cols.is_first_layer = F::from_bool(round_in_layer == 0); + cols.is_first_idx = + F::from_bool(layer_idx_value == 1 && round_in_layer == 0); + + let tidx = record.derive_tidx(layer_idx, round_in_layer); + cols.tidx = F::from_usize(tidx); + + cols.ev1 = eval1_base; + cols.ev2 = eval2_base; + cols.ev3 = eval3_base; + + cols.prev_challenge = prev_challenge_base; + cols.challenge = challenge_base; + + cols.claim_in = claim_in_base; + cols.claim_out = claim_out_base; + + cols.eq_in = eq_in_base; + cols.eq_out = eq_out_base; + + claim = claim_out; + eq = eq_out; + global_round_idx += 1; + } + } + + debug_assert_eq!(global_round_idx, total_rounds); + }); + + Some(RowMajorMatrix::new(trace, width)) + } +} diff --git a/ceno_recursion_v2/src/tower/tower.rs b/ceno_recursion_v2/src/tower/tower.rs new file mode 100644 index 000000000..ee1e65def --- /dev/null +++ b/ceno_recursion_v2/src/tower/tower.rs @@ -0,0 +1,333 @@ +use std::marker::PhantomData; + +use ceno_zkvm::{ + scheme::{ZKVMChipProof, constants::NUM_FANIN}, + structs::{TowerProofs, VerifyingKey}, +}; +use eyre::{Result, ensure}; +use itertools::izip; +use mpcs::Point; +use multilinear_extensions::{ + mle::{IntoMLE, PointAndEval}, + util::ceil_log2, + virtual_poly::{VPAuxInfo, build_eq_x_r_vec_sequential, eq_eval}, +}; +use p3_field::PrimeCharacteristicRing; +use sumcheck::{ + structs::{IOPProof, IOPVerifierState}, + util::get_challenge_pows, +}; +use transcript::{Transcript, basic::BasicTranscript}; +use witness::next_pow2_instance_padding; + +use crate::system::RecursionField; + +#[derive(Debug, Clone)] +pub struct TowerLayerData { + pub claim_in: RecursionField, + pub claim_out: RecursionField, + pub eq_at_r: RecursionField, + pub mu: RecursionField, + pub lambda: RecursionField, + pub challenges: Vec, +} + +#[derive(Debug, Clone, Default)] +pub struct TowerReplayResult { + pub layers: Vec, +} + +pub fn replay_tower_proof( + chip_proof: &ZKVMChipProof, + vk: &VerifyingKey, +) -> Result { + let cs = &vk.cs; + let tower_proof = &chip_proof.tower_proof; + + let num_instances: usize = chip_proof.num_instances.iter().sum(); + let next_pow2_instance = next_pow2_instance_padding(num_instances); + let mut log2_num_instances = ceil_log2(next_pow2_instance); + if cs.has_ecc_ops() { + log2_num_instances += 1; + } + let rotation_vars = cs.rotation_vars().unwrap_or(0); + let num_var_with_rotation = log2_num_instances + rotation_vars; + + let read_count = cs.num_reads(); + let write_count = cs.num_writes(); + let lookup_count = cs.num_lks(); + let num_batched = read_count + write_count + lookup_count; + + let prod_out_evals: Vec> = chip_proof + .r_out_evals + .iter() + .chain(chip_proof.w_out_evals.iter()) + .cloned() + .collect(); + let logup_out_evals = chip_proof.lk_out_evals.clone(); + + let num_prod_spec = prod_out_evals.len(); + let num_logup_spec = logup_out_evals.len(); + ensure!( + num_prod_spec == tower_proof.prod_specs_eval.len(), + "prod spec mismatch" + ); + ensure!( + num_logup_spec == tower_proof.logup_specs_eval.len(), + "logup spec mismatch" + ); + + let mut transcript = BasicTranscript::::new(b"ceno-recursion-tower-tower"); + let log2_num_fanin = ceil_log2(NUM_FANIN); + + let mut alpha_pows = get_challenge_pows(num_prod_spec + num_logup_spec * 2, &mut transcript); + + let challenge_from_pows = |pows: &[RecursionField]| -> RecursionField { + pows.get(1).copied().unwrap_or(RecursionField::ONE) + }; + let initial_rt: Point = transcript + .sample_and_append_vec(b"product_sum", log2_num_fanin) + .into_iter() + .collect(); + + let mut prod_spec_point_n_eval: Vec> = prod_out_evals + .iter() + .map(|evals| { + PointAndEval::new( + initial_rt.clone(), + evals.clone().into_mle().evaluate(&initial_rt), + ) + }) + .collect(); + + let (mut logup_spec_p_point_n_eval, mut logup_spec_q_point_n_eval) = logup_out_evals + .iter() + .map(|evals| { + let (p1, p2, q1, q2) = (evals[0], evals[1], evals[2], evals[3]); + ( + PointAndEval::new( + initial_rt.clone(), + vec![p1, p2].into_mle().evaluate(&initial_rt), + ), + PointAndEval::new( + initial_rt.clone(), + vec![q1, q2].into_mle().evaluate(&initial_rt), + ), + ) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + + let initial_claim = izip!(&prod_spec_point_n_eval, &alpha_pows) + .map(|(point_eval, alpha)| point_eval.eval * *alpha) + .sum::() + + izip!( + izip!(&logup_spec_p_point_n_eval, &logup_spec_q_point_n_eval) + .flat_map(|(p, q)| vec![p, q]), + &alpha_pows[num_prod_spec..] + ) + .map(|(point_eval, alpha)| point_eval.eval * *alpha) + .sum::(); + + let mut point_and_eval = PointAndEval::new(initial_rt, initial_claim); + let mut layers = Vec::new(); + let max_num_variables = num_var_with_rotation; + let num_variables = vec![num_var_with_rotation; num_batched]; + + for round in 0..max_num_variables.saturating_sub(1) { + let out_rt = point_and_eval.point.clone(); + let out_claim = point_and_eval.eval; + + let round_msgs = tower_proof + .proofs + .get(round) + .ok_or_else(|| eyre::eyre!("missing tower sumcheck round {round}"))?; + let sumcheck_claim = IOPVerifierState::verify( + out_claim, + &IOPProof { + proofs: round_msgs.clone(), + }, + &VPAuxInfo { + max_degree: NUM_FANIN + 1, + max_num_variables: (round + 1) * log2_num_fanin, + phantom: PhantomData, + }, + &mut transcript, + ); + + let rt: Point = sumcheck_claim.point.iter().map(|c| c.elements).collect(); + let eq = eq_eval(&out_rt, &rt); + + let expected = compute_expected_evaluation( + tower_proof, + round, + &alpha_pows, + eq, + &prod_spec_point_n_eval, + &logup_spec_p_point_n_eval, + &logup_spec_q_point_n_eval, + &num_variables, + )?; + // TEMP: Relax strict replay equality while refactoring transcript/plumbing. + // ensure!( + // expected == sumcheck_claim.expected_evaluation, + // "tower sumcheck mismatch at layer {round}" + // ); + + let r_merge = transcript.sample_and_append_vec(b"merge", log2_num_fanin); + let mu = r_merge[0]; + let coeffs = build_eq_x_r_vec_sequential(&r_merge); + let rt_prime = [rt.clone(), r_merge].concat(); + + let next_alpha_pows = + get_challenge_pows(num_prod_spec + num_logup_spec * 2, &mut transcript); + + update_point_evals( + tower_proof, + round, + &rt_prime, + &coeffs, + &mut prod_spec_point_n_eval, + &mut logup_spec_p_point_n_eval, + &mut logup_spec_q_point_n_eval, + ); + + let next_eval = aggregate_next_eval( + round, + &next_alpha_pows, + &num_variables, + &prod_spec_point_n_eval, + &logup_spec_p_point_n_eval, + &logup_spec_q_point_n_eval, + ); + + layers.push(TowerLayerData { + claim_in: out_claim, + claim_out: sumcheck_claim.expected_evaluation, + eq_at_r: eq, + mu, + lambda: challenge_from_pows(&alpha_pows), + challenges: sumcheck_claim.point.iter().map(|c| c.elements).collect(), + }); + + point_and_eval = PointAndEval::new(rt_prime, next_eval); + alpha_pows = next_alpha_pows; + } + + Ok(TowerReplayResult { layers }) +} + +#[allow(clippy::too_many_arguments)] +fn compute_expected_evaluation( + tower_proof: &TowerProofs, + round: usize, + alpha_pows: &[RecursionField], + eq: RecursionField, + _prod_point_eval: &[PointAndEval], + logup_p_point_eval: &[PointAndEval], + logup_q_point_eval: &[PointAndEval], + num_variables: &[usize], +) -> Result { + let num_prod_spec = tower_proof.prod_specs_eval.len(); + let mut total = RecursionField::ZERO; + for ((spec_idx, alpha), max_round) in (0..num_prod_spec) + .zip(alpha_pows.iter()) + .zip(num_variables.iter()) + { + if round < max_round.saturating_sub(1) { + let eval = tower_proof.prod_specs_eval[spec_idx][round] + .iter() + .copied() + .product::(); + total += eq * *alpha * eval; + } + } + + for (((spec_idx, alpha_chunk), max_round), (_p_eval, _q_eval)) in + (0..tower_proof.logup_specs_eval.len()) + .zip(alpha_pows[num_prod_spec..].chunks(2)) + .zip(num_variables[num_prod_spec..].iter()) + .zip(logup_p_point_eval.iter().zip(logup_q_point_eval.iter())) + { + if round < max_round.saturating_sub(1) { + let evals = &tower_proof.logup_specs_eval[spec_idx][round]; + let (p1, p2, q1, q2) = (evals[0], evals[1], evals[2], evals[3]); + total += eq * (alpha_chunk[0] * (p1 * q2 + p2 * q1) + alpha_chunk[1] * (q1 * q2)); + } + } + + Ok(total) +} + +fn update_point_evals( + tower_proof: &TowerProofs, + round: usize, + rt_prime: &Point, + coeffs: &[RecursionField], + prod_point_eval: &mut [PointAndEval], + logup_p_point_eval: &mut [PointAndEval], + logup_q_point_eval: &mut [PointAndEval], +) { + for (spec_idx, point_eval) in prod_point_eval.iter_mut().enumerate() { + if round < tower_proof.prod_specs_eval[spec_idx].len() { + let evals = &tower_proof.prod_specs_eval[spec_idx][round]; + let merged = izip!(evals.iter(), coeffs.iter()) + .map(|(a, b)| *a * *b) + .sum(); + *point_eval = PointAndEval::new(rt_prime.clone(), merged); + } + } + + for (spec_idx, (p_eval, q_eval)) in logup_p_point_eval + .iter_mut() + .zip(logup_q_point_eval.iter_mut()) + .enumerate() + { + if round < tower_proof.logup_specs_eval[spec_idx].len() { + let evals = &tower_proof.logup_specs_eval[spec_idx][round]; + let (p_slice, q_slice) = evals.split_at(2); + let merged_p = izip!(p_slice.iter(), coeffs.iter()) + .map(|(a, b)| *a * *b) + .sum(); + let merged_q = izip!(q_slice.iter(), coeffs.iter()) + .map(|(a, b)| *a * *b) + .sum(); + *p_eval = PointAndEval::new(rt_prime.clone(), merged_p); + *q_eval = PointAndEval::new(rt_prime.clone(), merged_q); + } + } +} + +fn aggregate_next_eval( + round: usize, + alpha_pows: &[RecursionField], + num_variables: &[usize], + prod_point_eval: &[PointAndEval], + logup_p_point_eval: &[PointAndEval], + logup_q_point_eval: &[PointAndEval], +) -> RecursionField { + let num_prod_spec = prod_point_eval.len(); + let mut total = RecursionField::ZERO; + + for ((point_eval, alpha), max_round) in prod_point_eval + .iter() + .zip(alpha_pows.iter()) + .zip(num_variables.iter()) + { + if round + 1 < *max_round { + total += *alpha * point_eval.eval; + } + } + + for (((p_eval, q_eval), alpha_chunk), max_round) in logup_p_point_eval + .iter() + .zip(logup_q_point_eval.iter()) + .zip(alpha_pows[num_prod_spec..].chunks(2)) + .zip(num_variables[num_prod_spec..].iter()) + { + if round + 1 < *max_round { + total += alpha_chunk[0] * p_eval.eval + alpha_chunk[1] * q_eval.eval; + } + } + + total +} diff --git a/ceno_recursion_v2/src/tracegen.rs b/ceno_recursion_v2/src/tracegen.rs new file mode 100644 index 000000000..39fe2f0f6 --- /dev/null +++ b/ceno_recursion_v2/src/tracegen.rs @@ -0,0 +1,82 @@ +use openvm_cpu_backend::CpuBackend; +use openvm_stark_backend::{ + StarkProtocolConfig, + prover::{AirProvingContext, ProverBackend}, +}; +use openvm_stark_sdk::config::baby_bear_poseidon2::F; +use p3_matrix::dense::RowMajorMatrix; + +use crate::system::{Preflight, RecursionProof, RecursionVk}; + +/// Backend-generic trait to generate a proving context +pub(crate) trait ModuleChip { + /// Context needed for trace generation (e.g., VK, proofs, preflights). + type Ctx<'a>; + + /// Generate an AirProvingContext. If required_height is Some(..), then the + /// resulting trace matrices must have height required_height. This function + /// should return None iff required_height is defined AND the matrix requires + /// more than required_height rows. + fn generate_proving_ctx( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option>; +} + +/// Trait to generate a CPU row-major common trace +pub(crate) trait RowMajorChip { + /// Context needed for trace generation (e.g., VK, proofs, preflights). + type Ctx<'a>; + + /// Generate row major trace with the same semantics as TraceGenerator + fn generate_trace( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option>; +} + +pub(crate) struct StandardTracegenCtx<'a> { + pub vk: &'a RecursionVk, + pub proofs: &'a [RecursionProof], + pub preflights: &'a [&'a Preflight], +} + +impl, T: RowMajorChip> ModuleChip> for T { + type Ctx<'a> = T::Ctx<'a>; + + #[tracing::instrument(level = "trace", skip_all)] + fn generate_proving_ctx( + &self, + ctx: &Self::Ctx<'_>, + required_height: Option, + ) -> Option>> { + let common_main_rm = self.generate_trace(ctx, required_height); + common_main_rm.map(AirProvingContext::simple_no_pis) + } +} + +#[cfg(feature = "cuda")] +pub(crate) mod cuda { + use openvm_cuda_backend::{GpuBackend, data_transporter::transport_matrix_h2d_row}; + + use super::*; + use crate::cuda::{preflight::PreflightGpu, proof::ProofGpu, vk::VerifyingKeyGpu}; + + pub(crate) struct StandardTracegenGpuCtx<'a> { + pub vk: &'a VerifyingKeyGpu, + pub proofs: &'a [ProofGpu], + pub preflights: &'a [PreflightGpu], + } + + pub(crate) fn generate_gpu_proving_ctx>( + t: &T, + ctx: &T::Ctx<'_>, + required_height: Option, + ) -> Option> { + let common_main_rm = t.generate_trace(ctx, required_height); + common_main_rm + .map(|m| AirProvingContext::simple_no_pis(transport_matrix_h2d_row(&m).unwrap())) + } +} diff --git a/ceno_recursion_v2/src/transcript/mod.rs b/ceno_recursion_v2/src/transcript/mod.rs new file mode 100644 index 000000000..7ef1e9128 --- /dev/null +++ b/ceno_recursion_v2/src/transcript/mod.rs @@ -0,0 +1,370 @@ +use core::borrow::BorrowMut; +use std::sync::Arc; + +use openvm_cpu_backend::CpuBackend; +use openvm_poseidon2_air::{POSEIDON2_WIDTH, Poseidon2Config, Poseidon2SubChip}; +use openvm_stark_backend::{AirRef, StarkProtocolConfig, SystemParams, prover::AirProvingContext}; +use openvm_stark_sdk::{ + config::baby_bear_poseidon2::{F, poseidon2_perm}, + p3_baby_bear::Poseidon2BabyBear, +}; +use p3_air::BaseAir; +use p3_field::{PrimeCharacteristicRing, PrimeField32}; +use p3_matrix::dense::RowMajorMatrix; +use p3_symmetric::Permutation; + +use crate::system::{ + AirModule, BusInventory, GlobalCtxCpu, Preflight, RecursionProof, RecursionVk, TraceGenModule, +}; +use recursion_circuit::transcript::{ + merkle_verify::{MerkleVerifyAir, MerkleVerifyCols}, + poseidon2::{CHUNK, Poseidon2Air, Poseidon2Cols}, + transcript::{TranscriptAir, TranscriptCols}, +}; + +// Should be 1 when 3 <= max_constraint_degree < 7. +const SBOX_REGISTERS: usize = 1; + +pub struct TranscriptModule { + pub bus_inventory: BusInventory, + params: SystemParams, + final_state_bus_enabled: bool, + + sub_chip: Poseidon2SubChip, + perm: Poseidon2BabyBear, +} + +impl TranscriptModule { + pub fn new( + bus_inventory: BusInventory, + params: SystemParams, + final_state_bus_enabled: bool, + ) -> Self { + let sub_chip = Poseidon2SubChip::::new(Poseidon2Config::default().constants); + Self { + bus_inventory, + params, + final_state_bus_enabled, + sub_chip, + perm: poseidon2_perm().clone(), + } + } + + #[tracing::instrument(name = "generate_trace.transcript", level = "trace", skip_all)] + fn build_transcript_trace( + &self, + preflights: &[Preflight], + required_height: Option, + ) -> Option<(RowMajorMatrix, Vec<[F; POSEIDON2_WIDTH]>)> { + let transcript_width = TranscriptCols::::width(); + let mut valid_rows = Vec::with_capacity(preflights.len()); + + let mut transcript_valid_rows = 0usize; + for preflight in preflights { + let mut cur_is_sample = false; + let mut count = 0usize; + let mut num_valid_rows = 0usize; + + for op_is_sample in preflight.transcript.samples() { + if *op_is_sample { + if !cur_is_sample { + num_valid_rows += 1; + cur_is_sample = true; + count = 1; + } else { + if count == CHUNK { + num_valid_rows += 1; + count = 0; + } + count += 1; + } + } else if cur_is_sample { + num_valid_rows += 1; + cur_is_sample = false; + count = 1; + } else { + if count == CHUNK { + num_valid_rows += 1; + count = 0; + } + count += 1; + } + } + + if count > 0 { + num_valid_rows += 1; + } + valid_rows.push(num_valid_rows); + transcript_valid_rows += num_valid_rows; + } + + let transcript_num_rows = if let Some(height) = required_height { + if height == 0 || height < transcript_valid_rows { + return None; + } + height + } else if transcript_valid_rows == 0 { + 1 + } else { + transcript_valid_rows.next_power_of_two() + }; + + let mut transcript_trace = vec![F::ZERO; transcript_num_rows * transcript_width]; + let mut poseidon2_perm_inputs = vec![]; + + let mut skip = 0usize; + for (pidx, preflight) in preflights.iter().enumerate() { + let mut tidx = 0usize; + let mut prev_poseidon_state = [F::ZERO; POSEIDON2_WIDTH]; + let off = skip * transcript_width; + let end = off + valid_rows[pidx] * transcript_width; + + for (i, row) in transcript_trace[off..end] + .chunks_exact_mut(transcript_width) + .enumerate() + { + let cols: &mut TranscriptCols = row.borrow_mut(); + cols.proof_idx = F::from_usize(pidx); + if i == 0 { + cols.is_proof_start = F::ONE; + } + let is_sample = preflight.transcript.samples()[tidx]; + cols.is_sample = F::from_bool(is_sample); + cols.tidx = F::from_usize(tidx); + cols.mask[0] = F::ONE; + cols.prev_state = prev_poseidon_state; + + if is_sample { + debug_assert_eq!( + cols.prev_state[CHUNK - 1], + preflight.transcript.values()[tidx] + ); + } else { + cols.prev_state[0] = preflight.transcript.values()[tidx]; + } + + tidx += 1; + let mut idx = 1usize; + let mut permuted = false; + loop { + if tidx >= preflight.transcript.len() { + break; + } + + if preflight.transcript.samples()[tidx] != is_sample { + permuted = preflight.transcript.samples()[tidx]; + break; + } + + cols.mask[idx] = F::ONE; + if is_sample { + debug_assert_eq!( + cols.prev_state[CHUNK - 1 - idx], + preflight.transcript.values()[tidx] + ); + } else { + cols.prev_state[idx] = preflight.transcript.values()[tidx]; + } + + tidx += 1; + idx += 1; + if idx == CHUNK { + permuted = tidx < preflight.transcript.len() + && (!is_sample || preflight.transcript.samples()[tidx]); + break; + } + } + + prev_poseidon_state = cols.prev_state; + if permuted { + self.perm.permute_mut(&mut prev_poseidon_state); + poseidon2_perm_inputs.push(cols.prev_state); + } + cols.post_state = prev_poseidon_state; + } + + skip += valid_rows[pidx]; + debug_assert_eq!(tidx, preflight.transcript.len()); + } + + Some(( + RowMajorMatrix::new(transcript_trace, transcript_width), + poseidon2_perm_inputs, + )) + } + + fn dedup_poseidon_inputs( + poseidon2_perm_inputs: Vec<[F; POSEIDON2_WIDTH]>, + poseidon2_compress_inputs: Vec<[F; POSEIDON2_WIDTH]>, + ) -> (Vec<[F; POSEIDON2_WIDTH]>, Vec) { + let mut keyed_states: Vec<([u32; POSEIDON2_WIDTH], [F; POSEIDON2_WIDTH], bool)> = + Vec::with_capacity(poseidon2_perm_inputs.len() + poseidon2_compress_inputs.len()); + + for state in poseidon2_perm_inputs { + keyed_states.push((state.map(|x| x.as_canonical_u32()), state, true)); + } + for state in poseidon2_compress_inputs { + keyed_states.push((state.map(|x| x.as_canonical_u32()), state, false)); + } + + keyed_states.sort_unstable_by(|a, b| a.0.cmp(&b.0)); + + let mut deduped = Vec::new(); + let mut counts: Vec = Vec::new(); + let mut last_key: Option<[u32; POSEIDON2_WIDTH]> = None; + + for (key, state, is_perm) in keyed_states { + if last_key == Some(key) { + let last = counts.last_mut().expect("counts not empty"); + if is_perm { + last.perm += 1; + } else { + last.compress += 1; + } + } else { + deduped.push(state); + counts.push(if is_perm { + Poseidon2Count { + perm: 1, + compress: 0, + } + } else { + Poseidon2Count { + perm: 0, + compress: 1, + } + }); + last_key = Some(key); + } + } + + (deduped, counts) + } +} + +impl AirModule for TranscriptModule { + fn num_airs(&self) -> usize { + 3 + } + + fn airs>(&self) -> Vec> { + let transcript_air = TranscriptAir { + transcript_bus: self.bus_inventory.transcript_bus, + poseidon2_permute_bus: self.bus_inventory.poseidon2_permute_bus, + final_state_bus: self + .final_state_bus_enabled + .then_some(self.bus_inventory.final_state_bus), + }; + let poseidon2_air = Poseidon2Air:: { + subair: self.sub_chip.air.clone(), + poseidon2_permute_bus: self.bus_inventory.poseidon2_permute_bus, + poseidon2_compress_bus: self.bus_inventory.poseidon2_compress_bus, + }; + let merkle_verify_air = MerkleVerifyAir { + poseidon2_compress_bus: self.bus_inventory.poseidon2_compress_bus, + merkle_verify_bus: self.bus_inventory.merkle_verify_bus, + commitments_bus: self.bus_inventory.commitments_bus, + right_shift_bus: self.bus_inventory.right_shift_bus, + k: self.params.k_whir(), + }; + vec![ + Arc::new(transcript_air), + Arc::new(poseidon2_air), + Arc::new(merkle_verify_air), + ] + } +} + +#[repr(C)] +#[derive(Copy, Clone, Default)] +struct Poseidon2Count { + perm: u32, + compress: u32, +} + +impl> TraceGenModule> + for TranscriptModule +{ + // (external poseidon2 permute, external poseidon2 compress) + type ModuleSpecificCtx<'a> = (&'a [[F; POSEIDON2_WIDTH]], &'a [[F; POSEIDON2_WIDTH]]); + + #[tracing::instrument(skip_all)] + fn generate_proving_ctxs( + &self, + child_vk: &RecursionVk, + proofs: &[RecursionProof], + preflights: &[Preflight], + ctx: &Self::ModuleSpecificCtx<'_>, + required_heights: Option<&[usize]>, + ) -> Option>>> { + let _ = (child_vk, proofs); + + let (required_transcript, required_poseidon2, required_merkle_verify) = + if let Some(heights) = required_heights { + if heights.len() != 3 { + return None; + } + (Some(heights[0]), Some(heights[1]), Some(heights[2])) + } else { + (None, None, None) + }; + + // TODO(recursion-proof-bridge): Implement MerkleVerify trace generation using + // RecursionProof/RecursionVk once those fields are available in local bridge APIs. + let merkle_rows = required_merkle_verify.unwrap_or(1); + if merkle_rows == 0 { + return None; + } + let merkle_verify_trace = RowMajorMatrix::new( + vec![F::ZERO; merkle_rows * MerkleVerifyCols::::width()], + MerkleVerifyCols::::width(), + ); + + let (transcript_trace, mut poseidon2_perm_inputs) = + self.build_transcript_trace(preflights, required_transcript)?; + let mut poseidon2_compress_inputs = Vec::new(); + + poseidon2_perm_inputs.extend_from_slice(ctx.0); + poseidon2_compress_inputs.extend_from_slice(ctx.1); + + let poseidon2_trace = { + let (mut poseidon_states, poseidon_counts) = + Self::dedup_poseidon_inputs(poseidon2_perm_inputs, poseidon2_compress_inputs); + let poseidon2_valid_rows = poseidon_states.len(); + let poseidon2_num_rows = if let Some(height) = required_poseidon2 { + if height == 0 || poseidon2_valid_rows > height { + return None; + } + height + } else if poseidon2_valid_rows == 0 { + 1 + } else { + poseidon2_valid_rows.next_power_of_two() + }; + + poseidon_states.resize(poseidon2_num_rows, [F::ZERO; POSEIDON2_WIDTH]); + + let inner_width = self.sub_chip.air.width(); + let poseidon2_width = Poseidon2Cols::::width(); + let inner_trace = self.sub_chip.generate_trace(poseidon_states); + let mut poseidon_trace = vec![F::ZERO; poseidon2_num_rows * poseidon2_width]; + + for (i, row) in poseidon_trace.chunks_exact_mut(poseidon2_width).enumerate() { + let inner_off = i * inner_width; + row[..inner_width] + .copy_from_slice(&inner_trace.values[inner_off..inner_off + inner_width]); + let cols: &mut Poseidon2Cols = row.borrow_mut(); + let count = poseidon_counts.get(i).copied().unwrap_or_default(); + cols.permute_mult = F::from_u32(count.perm); + cols.compress_mult = F::from_u32(count.compress); + } + RowMajorMatrix::new(poseidon_trace, poseidon2_width) + }; + + Some(vec![ + AirProvingContext::simple_no_pis(transcript_trace), + AirProvingContext::simple_no_pis(poseidon2_trace), + AirProvingContext::simple_no_pis(merkle_verify_trace), + ]) + } +} diff --git a/ceno_recursion_v2/src/utils.rs b/ceno_recursion_v2/src/utils.rs new file mode 100644 index 000000000..c9051c5ce --- /dev/null +++ b/ceno_recursion_v2/src/utils.rs @@ -0,0 +1,290 @@ +use std::ops::Index; + +use openvm_poseidon2_air::POSEIDON2_WIDTH; +use openvm_stark_backend::interaction::Interaction; +use openvm_stark_sdk::config::baby_bear_poseidon2::{CHUNK, D_EF, DIGEST_SIZE, F, poseidon2_perm}; +use p3_air::AirBuilder; +use p3_field::{PrimeCharacteristicRing, extension::BinomiallyExtendable}; +use p3_symmetric::Permutation; + +pub fn base_to_ext(x: impl Into) -> [FA; D_EF] +where + FA: PrimeCharacteristicRing, +{ + [x.into(), FA::ZERO, FA::ZERO, FA::ZERO] +} + +pub fn ext_field_one_minus(x: [impl Into; D_EF]) -> [FA; D_EF] +where + FA: PrimeCharacteristicRing, +{ + let [x0, x1, x2, x3] = x.map(Into::into); + [FA::ONE - x0, -x1, -x2, -x3] +} + +pub fn ext_field_add(x: [impl Into; D_EF], y: [impl Into; D_EF]) -> [FA; D_EF] +where + FA: PrimeCharacteristicRing, +{ + let [x0, x1, x2, x3] = x.map(Into::into); + let [y0, y1, y2, y3] = y.map(Into::into); + [x0 + y0, x1 + y1, x2 + y2, x3 + y3] +} + +pub fn ext_field_subtract(x: [impl Into; D_EF], y: [impl Into; D_EF]) -> [FA; D_EF] +where + FA: PrimeCharacteristicRing, +{ + let [x0, x1, x2, x3] = x.map(Into::into); + let [y0, y1, y2, y3] = y.map(Into::into); + [x0 - y0, x1 - y1, x2 - y2, x3 - y3] +} + +pub fn ext_field_multiply(x: [impl Into; D_EF], y: [impl Into; D_EF]) -> [FA; D_EF] +where + FA: PrimeCharacteristicRing, + FA::PrimeSubfield: BinomiallyExtendable<{ D_EF }>, +{ + let [x0, x1, x2, x3] = x.map(Into::into); + let [y0, y1, y2, y3] = y.map(Into::into); + + let w = FA::from_prime_subfield(FA::PrimeSubfield::W); + + let z0_beta_terms = x1.clone() * y3.clone() + x2.clone() * y2.clone() + x3.clone() * y1.clone(); + let z1_beta_terms = x2.clone() * y3.clone() + x3.clone() * y2.clone(); + let z2_beta_terms = x3.clone() * y3.clone(); + + [ + x0.clone() * y0.clone() + z0_beta_terms * w.clone(), + x0.clone() * y1.clone() + x1.clone() * y0.clone() + z1_beta_terms * w.clone(), + x0.clone() * y2.clone() + + x1.clone() * y1.clone() + + x2.clone() * y0.clone() + + z2_beta_terms * w, + x0 * y3 + x1 * y2 + x2 * y1 + x3 * y0, + ] +} + +pub fn ext_field_add_scalar(x: [impl Into; D_EF], y: impl Into) -> [FA; D_EF] +where + FA: PrimeCharacteristicRing, +{ + let [x0, x1, x2, x3] = x.map(Into::into); + [x0 + y.into(), x1, x2, x3] +} + +pub fn ext_field_subtract_scalar(x: [impl Into; D_EF], y: impl Into) -> [FA; D_EF] +where + FA: PrimeCharacteristicRing, +{ + let [x0, x1, x2, x3] = x.map(Into::into); + [x0 - y.into(), x1, x2, x3] +} + +pub fn scalar_subtract_ext_field(x: impl Into, y: [impl Into; D_EF]) -> [FA; D_EF] +where + FA: PrimeCharacteristicRing, +{ + let [y0, y1, y2, y3] = y.map(Into::into); + [x.into() - y0, -y1, -y2, -y3] +} + +pub fn ext_field_multiply_scalar(x: [impl Into; D_EF], y: impl Into) -> [FA; D_EF] +where + FA: PrimeCharacteristicRing, +{ + let [x0, x1, x2, x3] = x.map(Into::into); + let y = y.into(); + [x0 * y.clone(), x1 * y.clone(), x2 * y.clone(), x3 * y] +} + +pub fn assert_zeros(builder: &mut AB, array: [impl Into; N]) +where + AB: AirBuilder, +{ + for elem in array.into_iter() { + builder.assert_zero(elem); + } +} + +pub fn assert_one_ext(builder: &mut AB, array: [impl Into; D_EF]) +where + AB: AirBuilder, +{ + for (i, elem) in array.into_iter().enumerate() { + if i == 0 { + builder.assert_one(elem); + } else { + builder.assert_zero(elem); + } + } +} + +#[derive(Debug, Clone)] +pub struct MultiProofVecVec { + data: Vec, + bounds: Vec, +} + +impl MultiProofVecVec { + pub fn new() -> Self { + Self { + data: Vec::new(), + bounds: vec![0], + } + } + + pub fn push(&mut self, x: T) { + self.data.push(x); + } + + pub fn extend(&mut self, iter: impl IntoIterator) { + self.data.extend(iter); + } + + pub fn extend_from_slice(&mut self, slice: &[T]) + where + T: Clone, + { + self.data.extend_from_slice(slice); + } + + pub fn end_proof(&mut self) { + self.bounds.push(self.data.len()); + } + + pub fn len(&self) -> usize { + self.data.len() + } + + pub fn num_proofs(&self) -> usize { + self.bounds.len() - 1 + } +} + +impl Index for MultiProofVecVec { + type Output = [T]; + + fn index(&self, index: usize) -> &Self::Output { + debug_assert!(index < self.num_proofs()); + &self.data[self.bounds[index]..self.bounds[index + 1]] + } +} + +#[derive(Debug, Clone)] +pub struct MultiVecWithBounds { + pub data: Vec, + pub bounds: [Vec; DIM_MINUS_ONE], +} + +impl MultiVecWithBounds { + pub fn new() -> Self { + Self { + data: Vec::new(), + bounds: core::array::from_fn(|_| vec![0]), + } + } + + pub fn push(&mut self, x: T) { + self.data.push(x); + } + + pub fn extend(&mut self, iter: impl IntoIterator) { + self.data.extend(iter); + } + + pub fn close_level(&mut self, level: usize) { + debug_assert!(level < DIM_MINUS_ONE); + for i in level..DIM_MINUS_ONE - 1 { + self.bounds[i].push(self.bounds[i + 1].len()); + } + self.bounds[DIM_MINUS_ONE - 1].push(self.data.len()); + } +} + +impl Index<[usize; DIM_MINUS_ONE]> + for MultiVecWithBounds +{ + type Output = [T]; + + fn index(&self, index: [usize; DIM_MINUS_ONE]) -> &Self::Output { + let mut idx = 0; + for i in 0..DIM_MINUS_ONE { + idx += index[i]; + if i < DIM_MINUS_ONE - 1 { + idx = self.bounds[i][idx]; + } + } + &self.data[self.bounds[DIM_MINUS_ONE - 1][idx]..self.bounds[DIM_MINUS_ONE - 1][idx + 1]] + } +} + +pub fn poseidon2_hash_slice(vals: &[F]) -> ([F; CHUNK], Vec<[F; POSEIDON2_WIDTH]>) { + let num_chunks = vals.len().div_ceil(CHUNK); + let mut pre_states = Vec::with_capacity(num_chunks); + let perm = poseidon2_perm(); + let mut state = [F::ZERO; POSEIDON2_WIDTH]; + let mut i = 0; + for &val in vals { + state[i] = val; + i += 1; + if i == CHUNK { + pre_states.push(state); + perm.permute_mut(&mut state); + i = 0; + } + } + if i != 0 { + pre_states.push(state); + perm.permute_mut(&mut state); + } + (state[..CHUNK].try_into().unwrap(), pre_states) +} + +pub fn digests_to_poseidon2_input( + x: [T; DIGEST_SIZE], + y: [T; DIGEST_SIZE], +) -> [T; POSEIDON2_WIDTH] { + core::array::from_fn(|i| { + if i < DIGEST_SIZE { + x[i].clone() + } else { + y[i - DIGEST_SIZE].clone() + } + }) +} + +pub fn poseidon2_hash_slice_with_states( + vals: &[F], +) -> ( + [F; CHUNK], + Vec<[F; POSEIDON2_WIDTH]>, + Vec<[F; POSEIDON2_WIDTH]>, +) { + let num_chunks = vals.len().div_ceil(CHUNK); + let mut pre_states = Vec::with_capacity(num_chunks); + let mut post_states = Vec::with_capacity(num_chunks); + let perm = poseidon2_perm(); + let mut state = [F::ZERO; POSEIDON2_WIDTH]; + let mut i = 0; + for &val in vals { + state[i] = val; + i += 1; + if i == CHUNK { + pre_states.push(state); + perm.permute_mut(&mut state); + post_states.push(state); + i = 0; + } + } + if i != 0 { + pre_states.push(state); + perm.permute_mut(&mut state); + post_states.push(state); + } + (state[..CHUNK].try_into().unwrap(), pre_states, post_states) +} + +pub fn interaction_length(interaction: &Interaction) -> usize { + interaction.message.len() + 2 +} diff --git a/ceno_recursion_v2/taplo.toml b/ceno_recursion_v2/taplo.toml new file mode 100644 index 000000000..d857ad4e0 --- /dev/null +++ b/ceno_recursion_v2/taplo.toml @@ -0,0 +1,6 @@ +# Configuration doc: https://taplo.tamasfe.dev/configuration/formatter-options.html +[formatting] +align_comments = false +array_auto_collapse = false +array_auto_expand = false +reorder_keys = true diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 09e5a1131..f5277ce58 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -19,7 +19,8 @@ use gkr_iop::hal::ProverBackend; use mpcs::{ Basefold, BasefoldRSParams, PolynomialCommitmentScheme, SecurityLevel, Whir, WhirDefaultSpec, }; -use p3::field::FieldAlgebra; +use p3::field::PrimeCharacteristicRing; + use serde::{Serialize, de::DeserializeOwned}; use std::{fs, panic, panic::AssertUnwindSafe, path::PathBuf}; use tracing::{error, level_filters::LevelFilter}; diff --git a/ceno_zkvm/src/chip_handler/global_state.rs b/ceno_zkvm/src/chip_handler/global_state.rs index 7fd24ad72..ea4f976da 100644 --- a/ceno_zkvm/src/chip_handler/global_state.rs +++ b/ceno_zkvm/src/chip_handler/global_state.rs @@ -4,7 +4,7 @@ use gkr_iop::error::CircuitBuilderError; use super::GlobalStateRegisterMachineChipOperations; use crate::{circuit_builder::CircuitBuilder, structs::RAMType}; use multilinear_extensions::{Expression, ToExpr}; -use p3::field::FieldAlgebra; +use p3::field::PrimeCharacteristicRing; impl GlobalStateRegisterMachineChipOperations for CircuitBuilder<'_, E> { fn state_in( @@ -13,7 +13,7 @@ impl GlobalStateRegisterMachineChipOperations for CircuitB ts: Expression, ) -> Result<(), CircuitBuilderError> { let record: Vec> = vec![ - E::BaseField::from_canonical_u64(RAMType::GlobalState as u64).expr(), + E::BaseField::from_u64(RAMType::GlobalState as u64).expr(), pc, ts, ]; @@ -26,7 +26,7 @@ impl GlobalStateRegisterMachineChipOperations for CircuitB ts: Expression, ) -> Result<(), CircuitBuilderError> { let record: Vec> = vec![ - E::BaseField::from_canonical_u64(RAMType::GlobalState as u64).expr(), + E::BaseField::from_u64(RAMType::GlobalState as u64).expr(), pc, ts, ]; diff --git a/ceno_zkvm/src/gadgets/add4.rs b/ceno_zkvm/src/gadgets/add4.rs index d3fcf16b5..89115aeb1 100644 --- a/ceno_zkvm/src/gadgets/add4.rs +++ b/ceno_zkvm/src/gadgets/add4.rs @@ -104,9 +104,9 @@ impl Add4Operation { self.is_carry_1[i] = F::from_bool(carry[i] == 1); self.is_carry_2[i] = F::from_bool(carry[i] == 2); self.is_carry_3[i] = F::from_bool(carry[i] == 3); - self.carry[i] = F::from_canonical_u8(carry[i]); + self.carry[i] = F::from_u8(carry[i]); debug_assert!(carry[i] <= 3); - debug_assert_eq!(self.value[i], F::from_canonical_u32(res % base)); + debug_assert_eq!(self.value[i], F::from_u32(res % base)); } // Range check. diff --git a/ceno_zkvm/src/gadgets/field/field_inner_product.rs b/ceno_zkvm/src/gadgets/field/field_inner_product.rs index 7d7c430de..80b2199db 100644 --- a/ceno_zkvm/src/gadgets/field/field_inner_product.rs +++ b/ceno_zkvm/src/gadgets/field/field_inner_product.rs @@ -37,7 +37,7 @@ use std::fmt::Debug; use crate::{ gadgets::{ util::{compute_root_quotient_and_shift, split_u16_limbs_to_u8_limbs}, - util_expr::eval_field_operation, + util_expr::{eval_field_operation, poly_mul_expr}, }, witness::LkMultiplicity, }; @@ -169,7 +169,7 @@ impl FieldInnerProductCols { let p_inner_product = p_a_vec .iter() .zip(p_b_vec.iter()) - .map(|(p_a, p_b)| p_a * p_b) + .map(|(p_a, p_b)| poly_mul_expr(p_a, p_b)) .collect::>() .iter() .fold(p_zero, |acc, x| acc + x); @@ -177,7 +177,7 @@ impl FieldInnerProductCols { let p_inner_product_minus_result = &p_inner_product - &p_result; let p_limbs = Polynomial::from_iter(P::modulus_field_iter::().map(|x| x.expr())); - let p_vanishing = &p_inner_product_minus_result - &(&p_carry * &p_limbs); + let p_vanishing = &p_inner_product_minus_result - &poly_mul_expr(&p_carry, &p_limbs); let p_witness_low = self.witness_low.0.iter().into(); let p_witness_high = self.witness_high.0.iter().into(); diff --git a/ceno_zkvm/src/gadgets/field/field_op.rs b/ceno_zkvm/src/gadgets/field/field_op.rs index b11d933dd..97e9caf09 100644 --- a/ceno_zkvm/src/gadgets/field/field_op.rs +++ b/ceno_zkvm/src/gadgets/field/field_op.rs @@ -38,7 +38,7 @@ use crate::{ gadgets::{ field::FieldOperation, util::{compute_root_quotient_and_shift, split_u16_limbs_to_u8_limbs}, - util_expr::eval_field_operation, + util_expr::{eval_field_operation, poly_mul_expr, poly_scale_expr}, }, witness::LkMultiplicity, }; @@ -119,7 +119,7 @@ impl FieldOpCols { let p_modulus_limbs = modulus .to_bytes_le() .iter() - .map(|x| F::from_canonical_u8(*x)) + .map(|x| F::from_u8(*x)) .collect::>(); let p_modulus: Polynomial = p_modulus_limbs.iter().into(); let p_result: Polynomial = P::to_limbs_field::(&result).into(); @@ -267,14 +267,17 @@ impl FieldOpCols { let is_mul: Expression = is_mul.expr(); let is_div: Expression = is_div.expr(); - let p_result = p_res_param.clone() * (is_add.clone() + is_mul.clone()) - + p_a_param.clone() * (is_sub.clone() + is_div.clone()); + let p_result = poly_scale_expr(&p_res_param, is_add.clone() + is_mul.clone()) + + poly_scale_expr(&p_a_param, is_sub.clone() + is_div.clone()); let p_add = p_a_param.clone() + p_b.clone(); let p_sub = p_res_param.clone() + p_b.clone(); - let p_mul = p_a_param.clone() * p_b.clone(); - let p_div = p_res_param * p_b.clone(); - let p_op = p_add * is_add + p_sub * is_sub + p_mul * is_mul + p_div * is_div; + let p_mul = poly_mul_expr(&p_a_param, &p_b); + let p_div = poly_mul_expr(&p_res_param, &p_b); + let p_op = poly_scale_expr(&p_add, is_add) + + poly_scale_expr(&p_sub, is_sub) + + poly_scale_expr(&p_mul, is_mul) + + poly_scale_expr(&p_div, is_div); self.eval_with_polynomials(builder, p_op, modulus.clone(), p_result) } @@ -298,7 +301,7 @@ impl FieldOpCols { let p_c: Polynomial> = (c).clone().into(); let p_result: Polynomial<_> = self.result.clone().into(); - let p_op = p_a * p_b + p_c; + let p_op = poly_mul_expr(&p_a, &p_b) + p_c; self.eval_with_polynomials(builder, p_op, modulus.clone(), p_result) } @@ -326,7 +329,7 @@ impl FieldOpCols { }; let p_op: Polynomial> = match op { FieldOperation::Add | FieldOperation::Sub => p_a + p_b, - FieldOperation::Mul | FieldOperation::Div => p_a * p_b, + FieldOperation::Mul | FieldOperation::Div => poly_mul_expr(&p_a, &p_b), }; self.eval_with_polynomials(builder, p_op, modulus.clone(), p_result) } @@ -349,7 +352,7 @@ impl FieldOpCols { let p_modulus: Polynomial> = modulus.into(); let p_carry: Polynomial> = self.carry.clone().into(); let p_op_minus_result: Polynomial> = p_op - &p_result; - let p_vanishing = p_op_minus_result - &(&p_carry * &p_modulus); + let p_vanishing = p_op_minus_result - &poly_mul_expr(&p_carry, &p_modulus); let p_witness_low = self.witness_low.0.iter().into(); let p_witness_high = self.witness_high.0.iter().into(); eval_field_operation::(builder, &p_vanishing, &p_witness_low, &p_witness_high)?; diff --git a/ceno_zkvm/src/gadgets/field/field_sqrt.rs b/ceno_zkvm/src/gadgets/field/field_sqrt.rs index 22986c550..91a97fc39 100644 --- a/ceno_zkvm/src/gadgets/field/field_sqrt.rs +++ b/ceno_zkvm/src/gadgets/field/field_sqrt.rs @@ -100,7 +100,7 @@ impl FieldSqrtCols { self.range.populate(record, &sqrt, &modulus); let sqrt_bytes = P::to_limbs(&sqrt); - self.lsb = F::from_canonical_u8(sqrt_bytes[0] & 1); + self.lsb = F::from_u8(sqrt_bytes[0] & 1); record.lookup_and_byte(sqrt_bytes[0] as u64, 1); diff --git a/ceno_zkvm/src/gadgets/field/range.rs b/ceno_zkvm/src/gadgets/field/range.rs index f1fe274da..39ed07d70 100644 --- a/ceno_zkvm/src/gadgets/field/range.rs +++ b/ceno_zkvm/src/gadgets/field/range.rs @@ -81,15 +81,15 @@ impl FieldLtCols { assert!(byte <= modulus_byte); if byte < modulus_byte { *flag = 1; - self.lhs_comparison_byte = F::from_canonical_u8(*byte); - self.rhs_comparison_byte = F::from_canonical_u8(*modulus_byte); + self.lhs_comparison_byte = F::from_u8(*byte); + self.rhs_comparison_byte = F::from_u8(*modulus_byte); record.lookup_ltu_byte(*byte as u64, *modulus_byte as u64); break; } } for (byte, flag) in izip!(byte_flags.iter(), self.byte_flags.0.iter_mut()) { - *flag = F::from_canonical_u8(*byte); + *flag = F::from_u8(*byte); } } } diff --git a/ceno_zkvm/src/gadgets/fixed_rotate_right.rs b/ceno_zkvm/src/gadgets/fixed_rotate_right.rs index a1fa8499f..babeba6d2 100644 --- a/ceno_zkvm/src/gadgets/fixed_rotate_right.rs +++ b/ceno_zkvm/src/gadgets/fixed_rotate_right.rs @@ -66,13 +66,13 @@ impl FixedRotateRightOperation { impl FixedRotateRightOperation { pub fn populate(&mut self, record: &mut LkMultiplicity, input: u32, rotation: usize) -> u32 { - let input_bytes = input.to_le_bytes().map(F::from_canonical_u8); + let input_bytes = input.to_le_bytes().map(F::from_u8); let expected = input.rotate_right(rotation as u32); // Compute some constants with respect to the rotation needed for the rotation. let nb_bytes_to_shift = Self::nb_bytes_to_shift(rotation); let nb_bits_to_shift = Self::nb_bits_to_shift(rotation); - let carry_multiplier = F::from_canonical_u32(Self::carry_multiplier(rotation)); + let carry_multiplier = F::from_u32(Self::carry_multiplier(rotation)); // Perform the byte shift. let input_bytes_rotated = Word([ @@ -93,8 +93,8 @@ impl FixedRotateRightOperation { let (shift, carry) = shr_carry(b, c); record.lookup_shr_byte(shift as u64, carry as u64, nb_bits_to_shift as u64); - self.shift[i] = F::from_canonical_u8(shift); - self.carry[i] = F::from_canonical_u8(carry); + self.shift[i] = F::from_u8(shift); + self.carry[i] = F::from_u8(carry); if i == WORD_SIZE - 1 { first_shift = self.shift[i]; diff --git a/ceno_zkvm/src/gadgets/fixed_shift_right.rs b/ceno_zkvm/src/gadgets/fixed_shift_right.rs index d7827e2dc..7f4005d27 100644 --- a/ceno_zkvm/src/gadgets/fixed_shift_right.rs +++ b/ceno_zkvm/src/gadgets/fixed_shift_right.rs @@ -67,13 +67,13 @@ impl FixedShiftRightOperation { impl FixedShiftRightOperation { pub fn populate(&mut self, record: &mut LkMultiplicity, input: u32, rotation: usize) -> u32 { - let input_bytes = input.to_le_bytes().map(F::from_canonical_u8); + let input_bytes = input.to_le_bytes().map(F::from_u8); let expected = input >> rotation; // Compute some constants with respect to the rotation needed for the rotation. let nb_bytes_to_shift = Self::nb_bytes_to_shift(rotation); let nb_bits_to_shift = Self::nb_bits_to_shift(rotation); - let carry_multiplier = F::from_canonical_u32(Self::carry_multiplier(rotation)); + let carry_multiplier = F::from_u32(Self::carry_multiplier(rotation)); // Perform the byte shift. let mut word = [F::ZERO; WORD_SIZE]; @@ -95,8 +95,8 @@ impl FixedShiftRightOperation { record.lookup_shr_byte(shift as u64, carry as u64, nb_bits_to_shift as u64); - self.shift[i] = F::from_canonical_u8(shift); - self.carry[i] = F::from_canonical_u8(carry); + self.shift[i] = F::from_u8(shift); + self.carry[i] = F::from_u8(carry); if i == WORD_SIZE - 1 { first_shift = self.shift[i]; diff --git a/ceno_zkvm/src/gadgets/is_zero.rs b/ceno_zkvm/src/gadgets/is_zero.rs index bd8191095..e45657e14 100644 --- a/ceno_zkvm/src/gadgets/is_zero.rs +++ b/ceno_zkvm/src/gadgets/is_zero.rs @@ -87,7 +87,7 @@ impl IsZeroOperation { impl IsZeroOperation { pub fn populate(&mut self, a: u32) -> u32 { - self.populate_from_field_element(F::from_canonical_u32(a)) + self.populate_from_field_element(F::from_u32(a)) } pub fn populate_from_field_element(&mut self, a: F) -> u32 { diff --git a/ceno_zkvm/src/gadgets/mod.rs b/ceno_zkvm/src/gadgets/mod.rs index 349b7cf64..fb8577911 100644 --- a/ceno_zkvm/src/gadgets/mod.rs +++ b/ceno_zkvm/src/gadgets/mod.rs @@ -28,5 +28,6 @@ pub use poseidon2::{Poseidon2BabyBearConfig, Poseidon2Config}; pub use signed::Signed; pub use signed_ext::SignedExtendConfig; pub use signed_limbs::{UIntLimbsLT, UIntLimbsLTConfig}; +pub(crate) use util_expr::poly_scale_expr; pub use word::*; pub use xor::*; diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs index 1dbfcccb5..b1dad2002 100644 --- a/ceno_zkvm/src/gadgets/poseidon2.rs +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -3,17 +3,15 @@ use std::{ borrow::{Borrow, BorrowMut}, iter::from_fn, - mem::transmute, }; use ff_ext::{BabyBearExt4, ExtensionField}; use gkr_iop::error::CircuitBuilderError; use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use num_bigint::BigUint; use p3::{ - babybear::BabyBearInternalLayerParameters, - field::{Field, FieldAlgebra, PrimeField}, + babybear::{BabyBearInternalLayerParameters, BabyBearParameters}, + field::{Field, PrimeCharacteristicRing as FieldAlgebra, PrimeField}, monty_31::InternalLayerBaseParameters, poseidon2::{GenericPoseidon2LinearLayers, MDSMat4, mds_light_permutation}, poseidon2_air::{FullRound, PartialRound, Poseidon2Cols, SBox, num_cols}, @@ -86,28 +84,15 @@ pub struct Poseidon2Config< #[derive(Debug, Clone)] pub struct Poseidon2LinearLayers; -impl GenericPoseidon2LinearLayers - for Poseidon2LinearLayers +impl GenericPoseidon2LinearLayers for Poseidon2LinearLayers +where + BabyBearInternalLayerParameters: InternalLayerBaseParameters, { - fn internal_linear_layer(state: &mut [F; WIDTH]) { - // this only works when F is BabyBear field for now - let babybear_prime = BigUint::from(0x7800_0001u32); - if F::order() == babybear_prime { - let diag_m1_matrix = &>::INTERNAL_DIAG_MONTY; - let diag_m1_matrix: &[F; WIDTH] = unsafe { transmute(diag_m1_matrix) }; - let sum = state.iter().cloned().sum::(); - for (input, diag_m1) in state.iter_mut().zip(diag_m1_matrix) { - *input = sum + F::from_f(*diag_m1) * *input; - } - } else { - panic!("Unsupported field"); - } + fn internal_linear_layer(state: &mut [R; WIDTH]) { + BabyBearInternalLayerParameters::generic_internal_linear_layer(state); } - fn external_linear_layer(state: &mut [F; WIDTH]) { + fn external_linear_layer(state: &mut [R; WIDTH]) { mds_light_permutation(state, &MDSMat4); } } @@ -120,6 +105,8 @@ impl< const HALF_FULL_ROUNDS: usize, const PARTIAL_ROUNDS: usize, > Poseidon2Config +where + BabyBearInternalLayerParameters: InternalLayerBaseParameters, { // constraints taken from poseidon2_air/src/air.rs fn eval_sbox( @@ -205,25 +192,7 @@ impl< } fn internal_linear_layer(state: &mut [Expression; STATE_WIDTH]) { - let sum: Expression = state.iter().map(|s| s.get_monomial_form()).sum(); - // reduce to monomial form - let sum = sum.get_monomial_form(); - let babybear_prime = BigUint::from(0x7800_0001u32); - if E::BaseField::order() == babybear_prime { - // BabyBear - let diag_m1_matrix_bb = - &>:: - INTERNAL_DIAG_MONTY; - let diag_m1_matrix: &[E::BaseField; STATE_WIDTH] = - unsafe { transmute(diag_m1_matrix_bb) }; - for (input, diag_m1) in state.iter_mut().zip_eq(diag_m1_matrix) { - let updated = sum.clone() + Expression::from_f(*diag_m1) * input.clone(); - // reduce to monomial form - *input = updated.get_monomial_form(); - } - } else { - panic!("Unsupported field"); - } + BabyBearInternalLayerParameters::generic_internal_linear_layer(state); } pub fn construct( @@ -288,7 +257,7 @@ impl< .inputs .iter_mut() .zip_eq(post_linear_layer_cols[0..STATE_WIDTH].iter()) - .for_each(|(input, post_linear)| { + .for_each(|(input, post_linear): (&mut Expression, &WitIn)| { cb.require_equal( || "post_linear_layer = input", post_linear.expr(), @@ -325,7 +294,7 @@ impl< .inputs .iter_mut() .zip_eq(post_linear_layer_cols[STATE_WIDTH + PARTIAL_ROUNDS..].iter()) - .for_each(|(input, post_linear)| { + .for_each(|(input, post_linear): (&mut Expression, &WitIn)| { cb.require_equal( || "post_linear_layer = input", post_linear.expr(), @@ -433,7 +402,7 @@ impl< ////////////////////////////////////////////////////////////////////////// fn generate_trace_rows_for_perm< F: PrimeField, - LinearLayers: GenericPoseidon2LinearLayers, + LinearLayers: GenericPoseidon2LinearLayers, const WIDTH: usize, const SBOX_DEGREE: u64, const SBOX_REGISTERS: usize, @@ -518,7 +487,7 @@ fn generate_trace_rows_for_perm< #[inline] fn generate_full_round< F: PrimeField, - LinearLayers: GenericPoseidon2LinearLayers, + LinearLayers: GenericPoseidon2LinearLayers, const WIDTH: usize, const SBOX_DEGREE: u64, const SBOX_REGISTERS: usize, @@ -546,7 +515,7 @@ fn generate_full_round< #[inline] fn generate_partial_round< F: PrimeField, - LinearLayers: GenericPoseidon2LinearLayers, + LinearLayers: GenericPoseidon2LinearLayers, const WIDTH: usize, const SBOX_DEGREE: u64, const SBOX_REGISTERS: usize, diff --git a/ceno_zkvm/src/gadgets/signed_ext.rs b/ceno_zkvm/src/gadgets/signed_ext.rs index 4be082386..0b3cf721c 100644 --- a/ceno_zkvm/src/gadgets/signed_ext.rs +++ b/ceno_zkvm/src/gadgets/signed_ext.rs @@ -6,7 +6,8 @@ use crate::{ use ff_ext::{ExtensionField, FieldInto}; use gkr_iop::error::CircuitBuilderError; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::FieldAlgebra; + +use p3::field::PrimeCharacteristicRing; use std::marker::PhantomData; use witness::set_val; @@ -91,7 +92,7 @@ impl SignedExtendConfig { ) -> Result<(), CircuitBuilderError> { let msb = val >> (self.n_bits - 1); lk_multiplicity.assert_const_range(2 * val - (msb << self.n_bits), self.n_bits); - set_val!(instance, self.msb, E::BaseField::from_canonical_u64(msb)); + set_val!(instance, self.msb, E::BaseField::from_u64(msb)); Ok(()) } diff --git a/ceno_zkvm/src/gadgets/signed_limbs.rs b/ceno_zkvm/src/gadgets/signed_limbs.rs index 7ed83d296..4665ddfa3 100644 --- a/ceno_zkvm/src/gadgets/signed_limbs.rs +++ b/ceno_zkvm/src/gadgets/signed_limbs.rs @@ -7,7 +7,8 @@ use crate::{ use ff_ext::{ExtensionField, FieldInto, SmallField}; use gkr_iop::error::CircuitBuilderError; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::FieldAlgebra; + +use p3::field::PrimeCharacteristicRing; use std::{array, marker::PhantomData}; use witness::set_val; @@ -70,13 +71,11 @@ impl UIntLimbsLT { circuit_builder.require_zero( || "a_diff", - a_diff.expr() - * (E::BaseField::from_canonical_u32(1 << LIMB_BITS).expr() - a_diff.expr()), + a_diff.expr() * (E::BaseField::from_u32(1 << LIMB_BITS).expr() - a_diff.expr()), )?; circuit_builder.require_zero( || "b_diff", - b_diff.expr() - * (E::BaseField::from_canonical_u32(1 << LIMB_BITS).expr() - b_diff.expr()), + b_diff.expr() * (E::BaseField::from_u32(1 << LIMB_BITS).expr() - b_diff.expr()), )?; let mut prefix_sum = Expression::ZERO; @@ -86,7 +85,7 @@ impl UIntLimbsLT { b_msb_f.expr() - a_msb_f.expr() } else { b_expr[i].expr() - a_expr[i].expr() - }) * (E::BaseField::from_canonical_u8(2).expr() * cmp_lt.expr() + }) * (E::BaseField::from_u8(2).expr() * cmp_lt.expr() - E::BaseField::ONE.expr()); prefix_sum += diff_marker[i].expr(); circuit_builder.require_zero( @@ -122,7 +121,7 @@ impl UIntLimbsLT { || "a_msb_f_signed_range_check", a_msb_f.expr() + if is_sign_comparison { - E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr() + E::BaseField::from_u32(1 << (LIMB_BITS - 1)).expr() } else { Expression::ZERO }, @@ -132,7 +131,7 @@ impl UIntLimbsLT { || "b_msb_f_signed_range_check", b_msb_f.expr() + if is_sign_comparison { - E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr() + E::BaseField::from_u32(1 << (LIMB_BITS - 1)).expr() } else { Expression::ZERO }, @@ -170,23 +169,23 @@ impl UIntLimbsLT { // otherwise read_rs1_msb_f and read_rs2_msb_f let (a_msb_f, a_msb_range) = if is_a_neg { ( - -E::BaseField::from_canonical_u32((1 << LIMB_BITS) - a[UINT_LIMBS - 1] as u32), + -E::BaseField::from_u32((1 << LIMB_BITS) - a[UINT_LIMBS - 1] as u32), a[UINT_LIMBS - 1] - (1 << (LIMB_BITS - 1)), ) } else { ( - E::BaseField::from_canonical_u16(a[UINT_LIMBS - 1]), + E::BaseField::from_u16(a[UINT_LIMBS - 1]), a[UINT_LIMBS - 1] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)), ) }; let (b_msb_f, b_msb_range) = if is_b_neg { ( - -E::BaseField::from_canonical_u32((1 << LIMB_BITS) - b[UINT_LIMBS - 1] as u32), + -E::BaseField::from_u32((1 << LIMB_BITS) - b[UINT_LIMBS - 1] as u32), b[UINT_LIMBS - 1] - (1 << (LIMB_BITS - 1)), ) } else { ( - E::BaseField::from_canonical_u16(b[UINT_LIMBS - 1]), + E::BaseField::from_u16(b[UINT_LIMBS - 1]), b[UINT_LIMBS - 1] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)), ) }; diff --git a/ceno_zkvm/src/gadgets/util.rs b/ceno_zkvm/src/gadgets/util.rs index 121ab2937..adc74d813 100644 --- a/ceno_zkvm/src/gadgets/util.rs +++ b/ceno_zkvm/src/gadgets/util.rs @@ -28,11 +28,11 @@ use sp1_curves::polynomial::Polynomial; fn biguint_to_field(num: BigUint) -> F { let mut x = F::ZERO; - let mut power = F::from_canonical_u32(1u32); - let base = F::from_canonical_u64((1 << 32) % F::MODULUS_U64); + let mut power = F::from_u32(1u32); + let base = F::from_u64((1 << 32) % F::MODULUS_U64); let digits = num.iter_u32_digits(); for digit in digits.into_iter() { - x += F::from_canonical_u32(digit) * power; + x += F::from_u32(digit) * power; power *= base; } x @@ -58,7 +58,7 @@ pub fn compute_root_quotient_and_shift( debug_assert_eq!(p_vanishing_eval, F::ZERO); // Compute the witness polynomial by witness(x) = vanishing(x) / (x - 2^nb_bits_per_limb). - let root_monomial = F::from_canonical_u32(2u32.pow(nb_bits_per_limb)); + let root_monomial = F::from_u32(2u32.pow(nb_bits_per_limb)); let p_quotient = p_vanishing.root_quotient(root_monomial); debug_assert_eq!(p_quotient.degree(), p_vanishing.degree() - 1); @@ -78,7 +78,7 @@ pub fn compute_root_quotient_and_shift( // Shifting the witness polynomial to make it positive p_quotient_coefficients .into_iter() - .map(|x| x + F::from_canonical_u64(offset_u64)) + .map(|x| x + F::from_u64(offset_u64)) .collect::>() } @@ -88,12 +88,12 @@ pub fn split_u16_limbs_to_u8_limbs(slice: &[F]) -> (Vec, Vec> 8) as u8) - .map(|x| F::from_canonical_u8(x)) + .map(|x| F::from_u8(x)) .collect(), ) } diff --git a/ceno_zkvm/src/gadgets/util_expr.rs b/ceno_zkvm/src/gadgets/util_expr.rs index 44e21a7a2..a3d4271a3 100644 --- a/ceno_zkvm/src/gadgets/util_expr.rs +++ b/ceno_zkvm/src/gadgets/util_expr.rs @@ -25,9 +25,37 @@ use ff_ext::ExtensionField; use gkr_iop::{circuit_builder::CircuitBuilder, error::CircuitBuilderError}; use multilinear_extensions::{Expression, ToExpr}; -use p3::field::FieldAlgebra; + +use p3::field::PrimeCharacteristicRing; use sp1_curves::{params::FieldParameters, polynomial::Polynomial}; +pub fn poly_mul_expr( + a: &Polynomial>, + b: &Polynomial>, +) -> Polynomial> { + let mut coeffs = + vec![Expression::::ZERO; a.coefficients().len() + b.coefficients().len() - 1]; + for (i, coeff_a) in a.coefficients().iter().enumerate() { + for (j, coeff_b) in b.coefficients().iter().enumerate() { + coeffs[i + j] = coeffs[i + j].clone() + coeff_a.clone() * coeff_b.clone(); + } + } + Polynomial::new(coeffs) +} + +pub fn poly_scale_expr( + poly: &Polynomial>, + scalar: Expression, +) -> Polynomial> { + Polynomial::new( + poly.coefficients() + .iter() + .cloned() + .map(|c| c * scalar.clone()) + .collect(), + ) +} + pub fn eval_field_operation( builder: &mut CircuitBuilder, p_vanishing: &Polynomial>, @@ -35,21 +63,20 @@ pub fn eval_field_operation( p_witness_high: &Polynomial>, ) -> Result<(), CircuitBuilderError> { // Reconstruct and shift back the witness polynomial - let limb: Expression = - E::BaseField::from_canonical_u32(2u32.pow(P::NB_BITS_PER_LIMB as u32)).expr(); + let limb: Expression = E::BaseField::from_u32(2u32.pow(P::NB_BITS_PER_LIMB as u32)).expr(); - let p_witness_shifted = p_witness_low + &(p_witness_high * limb.clone()); + let p_witness_shifted = p_witness_low + &poly_scale_expr(p_witness_high, limb.clone()); // Shift down the witness polynomial. Shifting is needed to range check that each // coefficient w_i of the witness polynomial satisfies |w_i| < 2^WITNESS_OFFSET. - let offset: Expression = E::BaseField::from_canonical_u32(P::WITNESS_OFFSET as u32).expr(); + let offset: Expression = E::BaseField::from_u32(P::WITNESS_OFFSET as u32).expr(); let len = p_witness_shifted.coefficients().len(); let p_witness = p_witness_shifted - Polynomial::new(vec![offset; len]); // Multiply by (x-2^NB_BITS_PER_LIMB) and make the constraint let root_monomial = Polynomial::new(vec![-limb, E::BaseField::ONE.expr()]); - let constraints = p_vanishing - &(p_witness * root_monomial); + let constraints = p_vanishing - &poly_mul_expr(&p_witness, &root_monomial); for constr in constraints.as_coefficients() { builder.require_zero(|| "eval_field_operation require zero", constr)?; } diff --git a/ceno_zkvm/src/gadgets/word.rs b/ceno_zkvm/src/gadgets/word.rs index a11a80587..d59d92084 100644 --- a/ceno_zkvm/src/gadgets/word.rs +++ b/ceno_zkvm/src/gadgets/word.rs @@ -130,7 +130,7 @@ impl IndexMut for Word { impl From for Word { fn from(value: u32) -> Self { - Word(value.to_le_bytes().map(F::from_canonical_u8)) + Word(value.to_le_bytes().map(F::from_u8)) } } diff --git a/ceno_zkvm/src/gadgets/xor.rs b/ceno_zkvm/src/gadgets/xor.rs index d2157aa10..e6a1bd795 100644 --- a/ceno_zkvm/src/gadgets/xor.rs +++ b/ceno_zkvm/src/gadgets/xor.rs @@ -60,7 +60,7 @@ impl XorOperation { let y_bytes = y.to_le_bytes(); for i in 0..WORD_SIZE { let xor = x_bytes[i] ^ y_bytes[i]; - self.value[i] = F::from_canonical_u8(xor); + self.value[i] = F::from_u8(xor); record.lookup_xor_byte(x_bytes[i] as u64, y_bytes[i] as u64); } diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 9dd99ef92..e5ae3eef9 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -12,7 +12,8 @@ use gkr_iop::{ }; use itertools::Itertools; use multilinear_extensions::{ToExpr, util::max_usable_threads}; -use p3::field::FieldAlgebra; + +use p3::field::PrimeCharacteristicRing; use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, slice::ParallelSlice, diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index 027483d1e..cbea52181 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -14,7 +14,8 @@ use crate::{ use ceno_emul::{InsnKind, StepRecord}; use ff_ext::{ExtensionField, FieldInto}; use multilinear_extensions::{ToExpr, WitIn}; -use p3::field::FieldAlgebra; + +use p3::field::PrimeCharacteristicRing; use std::marker::PhantomData; use witness::set_val; @@ -84,7 +85,7 @@ impl Instruction for AddiInstruction { let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); let imm = step.insn().imm as i16 as u16; - set_val!(instance, config.imm, E::BaseField::from_canonical_u16(imm)); + set_val!(instance, config.imm, E::BaseField::from_u16(imm)); let imm_sign_extend = imm_sign_extend(true, step.insn().imm as i16); set_val!( diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 6311fc2aa..793f63838 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -21,7 +21,7 @@ use crate::{ use ceno_emul::InsnKind; use gkr_iop::tables::{LookupTable, ops::XorTable}; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::{Field, FieldAlgebra}; +use p3::field::{Field, PrimeCharacteristicRing as FieldAlgebra}; use witness::set_val; pub struct AuipcConfig { @@ -68,8 +68,7 @@ impl Instruction for AuipcInstruction { .iter() .enumerate() .fold(E::BaseField::ZERO.expr(), |acc, (i, &val)| { - acc + val.expr() - * E::BaseField::from_canonical_u32(1 << (i * UInt8::::LIMB_BITS)).expr() + acc + val.expr() * E::BaseField::from_u32(1 << (i * UInt8::::LIMB_BITS)).expr() }); let i_insn = IInstructionConfig::::construct_circuit( @@ -88,16 +87,13 @@ impl Instruction for AuipcInstruction { .enumerate() .fold(E::BaseField::ZERO.expr(), |acc, (i, val)| { acc + val.expr() - * E::BaseField::from_canonical_u32(1 << ((i + 1) * UInt8::::LIMB_BITS)) - .expr() + * E::BaseField::from_u32(1 << ((i + 1) * UInt8::::LIMB_BITS)).expr() }); // Compute the most significant limb of PC let pc_msl = (i_insn.vm_state.pc.expr() - intermed_val.expr()) - * (E::BaseField::from_canonical_usize( - 1 << (UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1)), - ) - .inverse()) + * (E::BaseField::from_usize(1 << (UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1))) + .inverse()) .expr(); // The vector pc_limbs contains the actual limbs of PC in little endian order @@ -113,7 +109,7 @@ impl Instruction for AuipcInstruction { let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); let additional_bits = (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); - let additional_bits = E::BaseField::from_canonical_u32(additional_bits); + let additional_bits = E::BaseField::from_u32(additional_bits); circuit_builder.logic_u8( LookupTable::Xor, pc_limbs_expr[3].expr(), @@ -123,7 +119,7 @@ impl Instruction for AuipcInstruction { let mut carry: [Expression; UINT_BYTE_LIMBS] = std::array::from_fn(|_| E::BaseField::ZERO.expr()); - let carry_divide = E::BaseField::from_canonical_usize(1 << UInt8::::LIMB_BITS) + let carry_divide = E::BaseField::from_usize(1 << UInt8::::LIMB_BITS) .inverse() .expr(); @@ -169,13 +165,13 @@ impl Instruction for AuipcInstruction { let pc = split_to_u8(step.pc().before.0); for (val, witin) in izip!(pc.iter().skip(1), config.pc_limbs) { lk_multiplicity.assert_ux::<8>(*val as u64); - set_val!(instance, witin, E::BaseField::from_canonical_u8(*val)); + set_val!(instance, witin, E::BaseField::from_u8(*val)); } let imm = InsnRecord::::imm_internal(&step.insn()).0 as u32; let imm = split_to_u8(imm); for (val, witin) in izip!(imm.iter(), config.imm_limbs) { lk_multiplicity.assert_ux::<8>(*val as u64); - set_val!(instance, witin, E::BaseField::from_canonical_u8(*val)); + set_val!(instance, witin, E::BaseField::from_u8(*val)); } // constrain pc msb limb range via xor let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs index 3622dad73..e0e4a4056 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs @@ -21,7 +21,7 @@ use crate::{ witness::LkMultiplicity, }; use multilinear_extensions::Expression; -pub use p3::field::FieldAlgebra; +use p3::field::PrimeCharacteristicRing; pub struct BranchCircuit(PhantomData<(E, I)>); @@ -160,8 +160,8 @@ impl Instruction for BranchCircuit Instruction for BranchCircuit { @@ -97,10 +98,10 @@ impl Instruction for ArithInstruction = - dividend_sign.expr() * E::BaseField::from_canonical_u32((1 << LIMB_BITS) - 1).expr(); + dividend_sign.expr() * E::BaseField::from_u32((1 << LIMB_BITS) - 1).expr(); let divisor_ext: Expression = - divisor_sign.expr() * E::BaseField::from_canonical_u32((1 << LIMB_BITS) - 1).expr(); - let carry_divide = E::BaseField::from_canonical_u32(1 << UInt::::LIMB_BITS).inverse(); + divisor_sign.expr() * E::BaseField::from_u32((1 << LIMB_BITS) - 1).expr(); + let carry_divide = E::BaseField::from_u32(1 << UInt::::LIMB_BITS).inverse(); let mut carry_expr: [Expression; UINT_LIMBS] = array::from_fn(|_| E::BaseField::ZERO.expr()); @@ -126,7 +127,7 @@ impl Instruction for ArithInstruction = - quotient_sign.expr() * E::BaseField::from_canonical_u32((1 << LIMB_BITS) - 1).expr(); + quotient_sign.expr() * E::BaseField::from_u32((1 << LIMB_BITS) - 1).expr(); let mut carry_ext: [Expression; UINT_LIMBS] = array::from_fn(|_| E::BaseField::ZERO.expr()); @@ -172,8 +173,7 @@ impl Instruction for ArithInstruction Instruction for ArithInstruction Instruction for ArithInstruction::new_unchecked(|| "remainder_prime", cb)?; let remainder_prime_expr = remainder_prime.expr(); @@ -276,8 +274,7 @@ impl Instruction for ArithInstruction Instruction for ArithInstruction Instruction for ArithInstruction { cb.assert_dynamic_range( || "div_rem_range_check_dividend_last", - E::BaseField::from_canonical_u32(2).expr() + E::BaseField::from_u32(2).expr() * (dividend_expr[UINT_LIMBS - 1].clone() - dividend_sign.expr() * sign_mask.expr()), - E::BaseField::from_canonical_u32(16).expr(), + E::BaseField::from_u32(16).expr(), )?; cb.assert_dynamic_range( || "div_rem_range_check_divisor_last", - E::BaseField::from_canonical_u32(2).expr() + E::BaseField::from_u32(2).expr() * (divisor_expr[UINT_LIMBS - 1].clone() - divisor_sign.expr() * sign_mask.expr()), - E::BaseField::from_canonical_u32(16).expr(), + E::BaseField::from_u32(16).expr(), )?; } InsnKind::DIVU | InsnKind::REMU => { @@ -474,12 +471,12 @@ impl Instruction for ArithInstruction Instruction for ArithInstruction Instruction for ArithInstruction( WriteMEM::construct_circuit( cb, value_ptr_0.prev_value.as_ref().unwrap().value() - + E::BaseField::from_canonical_u32(ByteAddr::from((i * WORD_SIZE) as u32).0) - .expr(), + + E::BaseField::from_u32(ByteAddr::from((i * WORD_SIZE) as u32).0).expr(), val_before.clone(), val_after.clone(), vm_state.ts, @@ -279,10 +279,7 @@ fn build_fp_op_circuit( WriteMEM::construct_circuit( cb, value_ptr_1.prev_value.as_ref().unwrap().value() - + E::BaseField::from_canonical_u32( - ByteAddr::from((i * WORD_SIZE) as u32).0, - ) - .expr(), + + E::BaseField::from_u32(ByteAddr::from((i * WORD_SIZE) as u32).0).expr(), val_before.clone(), val_before.clone(), vm_state.ts, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs index 6715a0b74..2b1daab2c 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs @@ -12,7 +12,7 @@ use gkr_iop::{ }; use itertools::{Itertools, izip}; use multilinear_extensions::{ToExpr, util::max_usable_threads}; -use p3::{field::FieldAlgebra, matrix::Matrix}; +use p3::matrix::Matrix; use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}, slice::ParallelSlice, @@ -43,6 +43,7 @@ use crate::{ tables::{InsnRecord, RMMCollections}, witness::LkMultiplicity, }; +use p3::field::PrimeCharacteristicRing; pub trait Fp2AddSpec: FpOpField { const SYSCALL_CODE: u32; @@ -180,8 +181,7 @@ fn build_fp2_add_circuit Instruction for HaltInstruction { // read exit_code from arg0 (X10 register) let (_, lt_x10_cfg) = cb.register_read( || "read x10", - E::BaseField::from_canonical_u64(ceno_emul::Platform::reg_arg0() as u64), + E::BaseField::from_u64(ceno_emul::Platform::reg_arg0() as u64), prev_x10_ts.expr(), ecall_cfg.ts.expr() + Tracer::SUBCYCLE_RS2, exit_code, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index e088cc0cc..5ffa68633 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -12,7 +12,7 @@ use gkr_iop::{ }; use itertools::{Itertools, izip}; use multilinear_extensions::{ToExpr, util::max_usable_threads}; -use p3::{field::FieldAlgebra, matrix::Matrix}; +use p3::matrix::Matrix; use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, slice::ParallelSlice, @@ -40,6 +40,7 @@ use crate::{ tables::{InsnRecord, RMMCollections}, witness::LkMultiplicity, }; +use p3::field::PrimeCharacteristicRing; #[derive(Debug)] pub struct EcallKeccakConfig { @@ -121,10 +122,7 @@ impl Instruction for KeccakInstruction { WriteMEM::construct_circuit( cb, state_ptr.prev_value.as_ref().unwrap().value() - + E::BaseField::from_canonical_u32( - ByteAddr::from((i * WORD_SIZE) as u32).0, - ) - .expr(), + + E::BaseField::from_u32(ByteAddr::from((i * WORD_SIZE) as u32).0).expr(), val_before.clone(), val_after.clone(), vm_state.ts, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs index f3a39093f..5b694101b 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs @@ -14,7 +14,7 @@ use gkr_iop::{ use itertools::{Itertools, chain, izip}; use multilinear_extensions::{ToExpr, util::max_usable_threads}; use num_bigint::BigUint; -use p3::{field::FieldAlgebra, matrix::Matrix}; +use p3::matrix::Matrix; use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}, slice::ParallelSlice, @@ -52,6 +52,7 @@ use crate::{ tables::{InsnRecord, RMMCollections}, witness::LkMultiplicity, }; +use p3::field::PrimeCharacteristicRing; #[derive(Debug)] pub struct EcallUint256MulConfig { @@ -142,10 +143,7 @@ impl Instruction for Uint256MulInstruction { cb, // mem address := word_ptr_0 + i word_ptr_0.prev_value.as_ref().unwrap().value() - + E::BaseField::from_canonical_u32( - ByteAddr::from((i * WORD_SIZE) as u32).0, - ) - .expr(), + + E::BaseField::from_u32(ByteAddr::from((i * WORD_SIZE) as u32).0).expr(), val_before.clone(), val_after.clone(), vm_state.ts, @@ -165,10 +163,7 @@ impl Instruction for Uint256MulInstruction { cb, // mem address := word_ptr_1 + i word_ptr_1.prev_value.as_ref().unwrap().value() - + E::BaseField::from_canonical_u32( - ByteAddr::from((i * WORD_SIZE) as u32).0, - ) - .expr(), + + E::BaseField::from_u32(ByteAddr::from((i * WORD_SIZE) as u32).0).expr(), val_before.clone(), val_before.clone(), vm_state.ts, @@ -495,10 +490,7 @@ impl Instruction for Uint256InvInstr cb, // mem address := word_ptr_0 + i word_ptr_0.prev_value.as_ref().unwrap().value() - + E::BaseField::from_canonical_u32( - ByteAddr::from((i * WORD_SIZE) as u32).0, - ) - .expr(), + + E::BaseField::from_u32(ByteAddr::from((i * WORD_SIZE) as u32).0).expr(), val_before.clone(), val_after.clone(), vm_state.ts, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index 80c85ef7a..f13058b6e 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -13,7 +13,7 @@ use gkr_iop::{ }; use itertools::{Itertools, izip}; use multilinear_extensions::{ToExpr, util::max_usable_threads}; -use p3::{field::FieldAlgebra, matrix::Matrix}; +use p3::matrix::Matrix; use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}, slice::ParallelSlice, @@ -41,6 +41,7 @@ use crate::{ tables::{InsnRecord, RMMCollections}, witness::LkMultiplicity, }; +use p3::field::PrimeCharacteristicRing; #[derive(Debug)] pub struct EcallWeierstrassAddAssignConfig { @@ -144,10 +145,7 @@ impl Instruction cb, // mem address := point_ptr_0 + i point_ptr_0.prev_value.as_ref().unwrap().value() - + E::BaseField::from_canonical_u32( - ByteAddr::from((i * WORD_SIZE) as u32).0, - ) - .expr(), + + E::BaseField::from_u32(ByteAddr::from((i * WORD_SIZE) as u32).0).expr(), val_before.clone(), val_after.clone(), vm_state.ts, @@ -165,10 +163,8 @@ impl Instruction cb, // mem address := point_ptr_1 + i point_ptr_1.prev_value.as_ref().unwrap().value() - + E::BaseField::from_canonical_u32( - ByteAddr::from((i * WORD_SIZE) as u32).0, - ) - .expr(), + + E::BaseField::from_u32(ByteAddr::from((i * WORD_SIZE) as u32).0) + .expr(), val_before.clone(), val_before.clone(), vm_state.ts, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 72f5f71d8..36c66122c 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -13,7 +13,7 @@ use gkr_iop::{ }; use itertools::{Itertools, izip}; use multilinear_extensions::{ToExpr, util::max_usable_threads}; -use p3::{field::FieldAlgebra, matrix::Matrix}; +use p3::matrix::Matrix; use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}, slice::ParallelSlice, @@ -41,6 +41,7 @@ use crate::{ tables::{InsnRecord, RMMCollections}, witness::LkMultiplicity, }; +use p3::field::PrimeCharacteristicRing; #[derive(Debug)] pub struct EcallWeierstrassDoubleAssignConfig< @@ -139,10 +140,7 @@ impl Instruction { @@ -35,7 +35,7 @@ impl OpFixedRS OpFixedRS MemAddr { .sum(); // Range check the middle bits, that is the low limb excluding the low bits. - let shift_right = E::BaseField::from_canonical_u64(1 << Self::N_LOW_BITS) + let shift_right = E::BaseField::from_u64(1 << Self::N_LOW_BITS) .inverse() .expr(); let mid_u14 = (&limbs[0] - low_sum) * shift_right; diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index a766ea795..6692a1c5d 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -20,7 +20,7 @@ use crate::{ use ceno_emul::{InsnKind, PC_STEP_SIZE}; use gkr_iop::tables::{LookupTable, ops::XorTable}; use multilinear_extensions::{Expression, ToExpr}; -use p3::field::FieldAlgebra; +use p3::field::PrimeCharacteristicRing; pub struct JalConfig { pub j_insn: JInstructionConfig, @@ -69,7 +69,7 @@ impl Instruction for JalInstruction { let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UInt8::::NUM_LIMBS - 1); let additional_bits = (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); - let additional_bits = E::BaseField::from_canonical_u32(additional_bits); + let additional_bits = E::BaseField::from_u32(additional_bits); circuit_builder.logic_u8( LookupTable::Xor, rd_exprs[3].expr(), @@ -84,7 +84,7 @@ impl Instruction for JalInstruction { .enumerate() .fold(Expression::ZERO, |acc, (i, val)| { acc + val.expr() - * E::BaseField::from_canonical_u32(1 << (i * UInt8::::LIMB_BITS)).expr() + * E::BaseField::from_u32(1 << (i * UInt8::::LIMB_BITS)).expr() }), j_insn.vm_state.pc.expr() + PC_STEP_SIZE, )?; diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index 2331c2f82..c148469ad 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -18,7 +18,7 @@ use crate::{ use ceno_emul::{InsnKind, PC_STEP_SIZE}; use ff_ext::FieldInto; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::FieldAlgebra; +use p3::field::PrimeCharacteristicRing as FieldAlgebra; pub struct JalrConfig { pub i_insn: IInstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index 7c51728ac..67504af7f 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -23,7 +23,7 @@ use crate::{ use ceno_emul::{InsnKind, PC_STEP_SIZE}; use ff_ext::FieldInto; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::{Field, FieldAlgebra}; +use p3::field::{Field, PrimeCharacteristicRing as FieldAlgebra}; pub struct JalrConfig { pub i_insn: IInstructionConfig, @@ -64,8 +64,8 @@ impl Instruction for JalrInstruction { let vm_state = StateInOut::construct_circuit(circuit_builder, true)?; let rd_high = circuit_builder.create_witin(|| "rd_high"); let rd_low: Expression<_> = vm_state.pc.expr() - + E::BaseField::from_canonical_usize(PC_STEP_SIZE).expr() - - rd_high.expr() * E::BaseField::from_canonical_u32(1 << UInt::::LIMB_BITS).expr(); + + E::BaseField::from_usize(PC_STEP_SIZE).expr() + - rd_high.expr() * E::BaseField::from_u32(1 << UInt::::LIMB_BITS).expr(); // rd range check // rd_low circuit_builder.assert_const_range(|| "rd_low_u16", rd_low.expr(), UInt::::LIMB_BITS)?; @@ -102,15 +102,15 @@ impl Instruction for JalrInstruction { // 1. rs1 + imm = jump_pc_addr + overflow*2^32 // 3. next_pc = jump_pc_addr aligned to even value (round down) - let inv = E::BaseField::from_canonical_u32(1 << UInt::::LIMB_BITS).inverse(); + let inv = E::BaseField::from_u32(1 << UInt::::LIMB_BITS).inverse(); let carry = (rs1_read.expr()[0].expr() + imm.expr() - jump_pc_addr.uint_unaligned().expr()[0].expr()) * inv.expr(); circuit_builder.assert_bit(|| "carry_lo_bit", carry.expr())?; - let imm_extend_limb = imm_sign.expr() - * E::BaseField::from_canonical_u32((1 << UInt::::LIMB_BITS) - 1).expr(); + let imm_extend_limb = + imm_sign.expr() * E::BaseField::from_u32((1 << UInt::::LIMB_BITS) - 1).expr(); let carry = (rs1_read.expr()[1].expr() + imm_extend_limb.expr() + carry - jump_pc_addr.uint_unaligned().expr()[1].expr()) * inv.expr(); @@ -166,11 +166,7 @@ impl Instruction for JalrInstruction { config .rs1_read .assign_value(instance, Value::new_unchecked(rs1)); - set_val!( - instance, - config.rd_high, - E::BaseField::from_canonical_u16(rd_limb[1]) - ); + set_val!(instance, config.rd_high, E::BaseField::from_u16(rd_limb[1])); let (sum, _) = rs1.overflowing_add_signed(i32::from_ne_bytes([ imm_sign_extend[0] as u8, diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index deb7b5736..60681f1b0 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -20,7 +20,8 @@ use crate::{ }; use ceno_emul::InsnKind; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::FieldAlgebra; + +use p3::field::PrimeCharacteristicRing; use witness::set_val; pub struct LuiConfig { @@ -75,14 +76,14 @@ impl Instruction for LuiInstruction { .enumerate() .fold(Expression::ZERO, |acc, (i, val)| { acc + val.expr() - * E::BaseField::from_canonical_u32(1 << (i * UInt8::::LIMB_BITS)).expr() + * E::BaseField::from_u32(1 << (i * UInt8::::LIMB_BITS)).expr() }); // imm * 2^4 is the correct composition of intermed_val in case of LUI circuit_builder.require_equal( || "imm * 2^4 is the correct composition of intermed_val in case of LUI", intermed_val.expr(), - imm.expr() * E::BaseField::from_canonical_u32(1 << (12 - UInt8::::LIMB_BITS)).expr(), + imm.expr() * E::BaseField::from_u32(1 << (12 - UInt8::::LIMB_BITS)).expr(), )?; Ok(LuiConfig { @@ -106,7 +107,7 @@ impl Instruction for LuiInstruction { let rd_written = split_to_u8(step.rd().unwrap().value.after); for (val, witin) in izip!(rd_written.iter().skip(1), config.rd_written) { lk_multiplicity.assert_ux::<8>(*val as u64); - set_val!(instance, witin, E::BaseField::from_canonical_u8(*val)); + set_val!(instance, witin, E::BaseField::from_u8(*val)); } let imm = InsnRecord::::imm_internal(&step.insn()).0 as u64; set_val!(instance, config.imm, imm); diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index 3a8da4a09..0693fb661 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -10,7 +10,7 @@ use either::Either; use ff_ext::{ExtensionField, FieldInto}; use itertools::izip; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::{Field, FieldAlgebra}; +use p3::field::{Field, PrimeCharacteristicRing}; use witness::set_val; pub struct MemWordUtil { @@ -80,7 +80,7 @@ impl MemWordUtil { // extract the least significant byte from u16 limb let rs2_limb_bytes = alloc_bytes(cb, "rs2_limb[0]", 1)?; - let u8_base_inv = E::BaseField::from_canonical_u64(1 << 8).inverse(); + let u8_base_inv = E::BaseField::from_u64(1 << 8).inverse(); cb.assert_ux::<_, _, 8>( || "rs2_limb[0].le_bytes[1]", u8_base_inv.expr() * (&rs2_limbs[0] - rs2_limb_bytes[0].expr()), @@ -149,7 +149,7 @@ impl MemWordUtil { match N_ZEROS { 0 => { for (&col, byte) in izip!(&self.prev_limb_bytes, prev_limb.to_le_bytes()) { - set_val!(instance, col, E::BaseField::from_canonical_u8(byte)); + set_val!(instance, col, E::BaseField::from_u8(byte)); lk_multiplicity.assert_ux::<8>(byte as u64); } @@ -160,18 +160,18 @@ impl MemWordUtil { set_val!( instance, self.rs2_limb_bytes[0], - E::BaseField::from_canonical_u8(rs2_limb.to_le_bytes()[0]) + E::BaseField::from_u8(rs2_limb.to_le_bytes()[0]) ); rs2_limb.to_le_bytes().into_iter().for_each(|byte| { lk_multiplicity.assert_ux::<8>(byte as u64); }); let change = if low_bits[0] == 0 { - E::BaseField::from_canonical_u16((prev_limb.to_le_bytes()[1] as u16) << 8) - + E::BaseField::from_canonical_u8(rs2_limb.to_le_bytes()[0]) + E::BaseField::from_u16((prev_limb.to_le_bytes()[1] as u16) << 8) + + E::BaseField::from_u8(rs2_limb.to_le_bytes()[0]) } else { - E::BaseField::from_canonical_u16((rs2_limb.to_le_bytes()[0] as u16) << 8) - + E::BaseField::from_canonical_u8(prev_limb.to_le_bytes()[0]) + E::BaseField::from_u16((rs2_limb.to_le_bytes()[0] as u16) << 8) + + E::BaseField::from_u8(prev_limb.to_le_bytes()[0]) }; set_val!(instance, expected_limb_witin, change); } diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 818e8902a..a26ff5c52 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -18,7 +18,8 @@ use ceno_emul::{ByteAddr, InsnKind, StepRecord}; use ff_ext::{ExtensionField, FieldInto}; use itertools::izip; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::FieldAlgebra; + +use p3::field::PrimeCharacteristicRing; use std::marker::PhantomData; pub struct LoadConfig { @@ -198,11 +199,7 @@ impl Instruction for LoadInstruction Instruction for LoadInstruction(byte as u64); - set_val!(instance, col, E::BaseField::from_canonical_u8(byte)); + set_val!(instance, col, E::BaseField::from_u8(byte)); } } let val = match I::INST_KIND { diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 5a9ed40eb..10e61df94 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -21,7 +21,7 @@ use ceno_emul::{ByteAddr, InsnKind, StepRecord}; use ff_ext::{ExtensionField, FieldInto}; use itertools::izip; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::{Field, FieldAlgebra}; +use p3::field::{Field, PrimeCharacteristicRing as FieldAlgebra}; use std::marker::PhantomData; pub struct LoadConfig { @@ -75,15 +75,15 @@ impl Instruction for LoadInstruction::LIMB_BITS).inverse(); + let inv = E::BaseField::from_u32(1 << UInt::::LIMB_BITS).inverse(); let carry = (rs1_read.expr()[0].expr() + imm.expr() - memory_addr.uint_unaligned().expr()[0].expr()) * inv.expr(); circuit_builder.assert_bit(|| "carry_lo_bit", carry.expr())?; - let imm_extend_limb = imm_sign.expr() - * E::BaseField::from_canonical_u32((1 << UInt::::LIMB_BITS) - 1).expr(); + let imm_extend_limb = + imm_sign.expr() * E::BaseField::from_u32((1 << UInt::::LIMB_BITS) - 1).expr(); let carry = (rs1_read.expr()[1].expr() + imm_extend_limb.expr() + carry - memory_addr.uint_unaligned().expr()[1].expr()) * inv.expr(); @@ -223,11 +223,7 @@ impl Instruction for LoadInstruction Instruction for LoadInstruction(byte as u64); - set_val!(instance, col, E::BaseField::from_canonical_u8(byte)); + set_val!(instance, col, E::BaseField::from_u8(byte)); } } let val = match I::INST_KIND { diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index a1bd7a812..e7565e436 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -20,7 +20,7 @@ use crate::{ use ceno_emul::{ByteAddr, InsnKind, StepRecord}; use ff_ext::{ExtensionField, FieldInto}; use multilinear_extensions::{ToExpr, WitIn}; -use p3::field::{Field, FieldAlgebra}; +use p3::field::{Field, PrimeCharacteristicRing as FieldAlgebra}; use std::marker::PhantomData; pub struct StoreConfig { @@ -79,15 +79,15 @@ impl Instruction } // rs1 + imm = mem_addr - let inv = E::BaseField::from_canonical_u32(1 << UInt::::LIMB_BITS).inverse(); + let inv = E::BaseField::from_u32(1 << UInt::::LIMB_BITS).inverse(); let carry = (rs1_read.expr()[0].expr() + imm.expr() - memory_addr.uint_unaligned().expr()[0].expr()) * inv.expr(); circuit_builder.assert_bit(|| "carry_lo_bit", carry.expr())?; - let imm_extend_limb = imm_sign.expr() - * E::BaseField::from_canonical_u32((1 << UInt::::LIMB_BITS) - 1).expr(); + let imm_extend_limb = + imm_sign.expr() * E::BaseField::from_u32((1 << UInt::::LIMB_BITS) - 1).expr(); let carry = (rs1_read.expr()[1].expr() + imm_extend_limb.expr() + carry - memory_addr.uint_unaligned().expr()[1].expr()) * inv.expr(); diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs index d5eb551e1..4a582cdba 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs @@ -82,7 +82,7 @@ use std::marker::PhantomData; use ceno_emul::{InsnKind, StepRecord}; use ff_ext::{ExtensionField, SmallField}; -use p3::{field::FieldAlgebra, goldilocks::Goldilocks}; +use p3::{field::PrimeCharacteristicRing as FieldAlgebra, goldilocks::Goldilocks}; use crate::{ circuit_builder::CircuitBuilder, @@ -338,8 +338,8 @@ impl Instruction for MulhInstructionBas } MulhSignDependencies::UU { constrain_rd } => { // assign nonzero value (u32::MAX - rd) - let rd_f = E::BaseField::from_canonical_u64(rd as u64); - let avoid_f = E::BaseField::from_canonical_u32(u32::MAX); + let rd_f = E::BaseField::from_u64(rd as u64); + let avoid_f = E::BaseField::from_u32(u32::MAX); constrain_rd.assign_instance(instance, rd_f, avoid_f)?; // only take the low part of the product @@ -351,12 +351,8 @@ impl Instruction for MulhInstructionBas assert_eq!(prod_lo, rd); let prod_hi = prod >> BIT_WIDTH; - let avoid_f = E::BaseField::from_canonical_u32(u32::MAX); - constrain_rd.assign_instance( - instance, - E::BaseField::from_canonical_u64(prod_hi), - avoid_f, - )?; + let avoid_f = E::BaseField::from_u32(u32::MAX); + constrain_rd.assign_instance(instance, E::BaseField::from_u64(prod_hi), avoid_f)?; prod_hi as u32 } MulhSignDependencies::SU { diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index f3bddff1b..755aa766d 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -16,7 +16,7 @@ use crate::{ use ceno_emul::{InsnKind, StepRecord}; use ff_ext::{ExtensionField, FieldInto}; use multilinear_extensions::{Expression, ToExpr as _, WitIn}; -use p3::field::{Field, FieldAlgebra}; +use p3::field::{Field, PrimeCharacteristicRing as FieldAlgebra}; use witness::set_val; use crate::e2e::ShardContext; @@ -63,7 +63,7 @@ impl Instruction for MulhInstructionBas let rs1_expr = rs1_read.expr(); let rs2_expr = rs2_read.expr(); - let carry_divide = E::BaseField::from_canonical_u32(1 << UInt::::LIMB_BITS).inverse(); + let carry_divide = E::BaseField::from_u32(1 << UInt::::LIMB_BITS).inverse(); let rd_low: [_; UINT_LIMBS] = array::from_fn(|i| circuit_builder.create_witin(|| format!("rd_low_{i}"))); @@ -86,12 +86,12 @@ impl Instruction for MulhInstructionBas circuit_builder.assert_dynamic_range( || format!("range_check_rd_low_{i}"), rd_low.expr(), - E::BaseField::from_canonical_u32(16).expr(), + E::BaseField::from_u32(16).expr(), )?; circuit_builder.assert_dynamic_range( || format!("range_check_carry_low_{i}"), carry_low.expr(), - E::BaseField::from_canonical_u32(18).expr(), + E::BaseField::from_u32(18).expr(), )?; } @@ -125,17 +125,17 @@ impl Instruction for MulhInstructionBas circuit_builder.assert_dynamic_range( || format!("range_check_high_{i}"), rd_high.expr(), - E::BaseField::from_canonical_u32(16).expr(), + E::BaseField::from_u32(16).expr(), )?; circuit_builder.assert_dynamic_range( || format!("range_check_carry_high_{i}"), carry_high.expr(), - E::BaseField::from_canonical_u32(18).expr(), + E::BaseField::from_u32(18).expr(), )?; } - let sign_mask = E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)); - let ext_inv = E::BaseField::from_canonical_u32((1 << LIMB_BITS) - 1).inverse(); + let sign_mask = E::BaseField::from_u32(1 << (LIMB_BITS - 1)); + let ext_inv = E::BaseField::from_u32((1 << LIMB_BITS) - 1).inverse(); let rs1_sign: Expression = rs1_ext.expr() * ext_inv.expr(); let rs2_sign: Expression = rs2_ext.expr() * ext_inv.expr(); @@ -147,15 +147,15 @@ impl Instruction for MulhInstructionBas // Implement MULH circuit here circuit_builder.assert_dynamic_range( || "mulh_range_check_rs1_last", - E::BaseField::from_canonical_u32(2).expr() + E::BaseField::from_u32(2).expr() * (rs1_expr[UINT_LIMBS - 1].clone() - rs1_sign * sign_mask.expr()), - E::BaseField::from_canonical_u32(16).expr(), + E::BaseField::from_u32(16).expr(), )?; circuit_builder.assert_dynamic_range( || "mulh_range_check_rs2_last", - E::BaseField::from_canonical_u32(2).expr() + E::BaseField::from_u32(2).expr() * (rs2_expr[UINT_LIMBS - 1].clone() - rs2_sign * sign_mask.expr()), - E::BaseField::from_canonical_u32(16).expr(), + E::BaseField::from_u32(16).expr(), )?; } InsnKind::MULHU => { @@ -169,14 +169,14 @@ impl Instruction for MulhInstructionBas .require_zero(|| "mulhsu_rs2_sign_zero", rs2_sign.clone())?; circuit_builder.assert_dynamic_range( || "mulhsu_range_check_rs1_last", - E::BaseField::from_canonical_u32(2).expr() + E::BaseField::from_u32(2).expr() * (rs1_expr[UINT_LIMBS - 1].clone() - rs1_sign * sign_mask.expr()), - E::BaseField::from_canonical_u32(16).expr(), + E::BaseField::from_u32(16).expr(), )?; circuit_builder.assert_dynamic_range( || "mulhsu_range_check_rs2_last", rs2_expr[UINT_LIMBS - 1].clone() - rs2_sign * sign_mask.expr(), - E::BaseField::from_canonical_u32(16).expr(), + E::BaseField::from_u32(16).expr(), )?; } InsnKind::MUL => (), diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 310d17491..f5a670052 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -17,7 +17,7 @@ use ceno_emul::InsnKind; use ff_ext::{ExtensionField, FieldInto}; use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::{Field, FieldAlgebra}; +use p3::field::{Field, PrimeCharacteristicRing as FieldAlgebra}; use std::{array, marker::PhantomData}; use witness::set_val; @@ -70,23 +70,21 @@ impl bit_shift_marker_i.expr(), )?; bit_marker_sum += bit_shift_marker_i.expr(); - bit_shift += E::BaseField::from_canonical_usize(i).expr() * bit_shift_marker_i.expr(); + bit_shift += E::BaseField::from_usize(i).expr() * bit_shift_marker_i.expr(); match kind { InsnKind::SLL | InsnKind::SLLI => { circuit_builder.condition_require_zero( || "bit_multiplier_left_condition", bit_shift_marker_i.expr(), - bit_multiplier_left.expr() - - E::BaseField::from_canonical_usize(1 << i).expr(), + bit_multiplier_left.expr() - E::BaseField::from_usize(1 << i).expr(), )?; } InsnKind::SRL | InsnKind::SRLI | InsnKind::SRA | InsnKind::SRAI => { circuit_builder.condition_require_zero( || "bit_multiplier_right_condition", bit_shift_marker_i.expr(), - bit_multiplier_right.expr() - - E::BaseField::from_canonical_usize(1 << i).expr(), + bit_multiplier_right.expr() - E::BaseField::from_usize(1 << i).expr(), )?; } _ => unreachable!(), @@ -104,8 +102,7 @@ impl limb_shift_marker[i].expr(), )?; limb_marker_sum += limb_shift_marker[i].expr(); - limb_shift += - E::BaseField::from_canonical_usize(i).expr() * limb_shift_marker[i].expr(); + limb_shift += E::BaseField::from_usize(i).expr() * limb_shift_marker[i].expr(); for j in 0..NUM_LIMBS { match kind { @@ -122,7 +119,7 @@ impl } else { bit_shift_carry[j - i - 1].expr() } + b[j - i].expr() * bit_multiplier_left.expr() - - E::BaseField::from_canonical_usize(1 << LIMB_BITS).expr() + - E::BaseField::from_usize(1 << LIMB_BITS).expr() * bit_shift_carry[j - i].expr(); circuit_builder.condition_require_zero( || format!("limb_shift_marker_a_expected_a_left_{i}_{j}",), @@ -139,17 +136,16 @@ impl limb_shift_marker[i].expr(), a[j].expr() - b_sign.expr() - * E::BaseField::from_canonical_usize((1 << LIMB_BITS) - 1) - .expr(), + * E::BaseField::from_usize((1 << LIMB_BITS) - 1).expr(), )?; } else { - let expected_a_right = - if j + i == NUM_LIMBS - 1 { - b_sign.expr() * (bit_multiplier_right.expr() - Expression::ONE) - } else { - bit_shift_carry[j + i + 1].expr() - } * E::BaseField::from_canonical_usize(1 << LIMB_BITS).expr() - + (b[j + i].expr() - bit_shift_carry[j + i].expr()); + let expected_a_right = if j + i == NUM_LIMBS - 1 { + b_sign.expr() * (bit_multiplier_right.expr() - Expression::ONE) + } else { + bit_shift_carry[j + i + 1].expr() + } * E::BaseField::from_usize(1 << LIMB_BITS) + .expr() + + (b[j + i].expr() - bit_shift_carry[j + i].expr()); circuit_builder.condition_require_zero( || format!("limb_shift_marker_a_expected_a_right_{i}_{j}",), @@ -165,11 +161,11 @@ impl circuit_builder.require_one(|| "limb_marker_sum_one_hot", limb_marker_sum.expr())?; // Check that bit_shift and limb_shift are correct. - let num_bits = E::BaseField::from_canonical_usize(NUM_LIMBS * LIMB_BITS); + let num_bits = E::BaseField::from_usize(NUM_LIMBS * LIMB_BITS); circuit_builder.assert_const_range( || "bit_shift_vs_limb_shift", (c[0].expr() - - limb_shift * E::BaseField::from_canonical_usize(LIMB_BITS).expr() + - limb_shift * E::BaseField::from_usize(LIMB_BITS).expr() - bit_shift.expr()) * num_bits.inverse().expr(), LIMB_BITS - ((NUM_LIMBS * LIMB_BITS) as u32).ilog2() as usize, @@ -177,13 +173,13 @@ impl if !matches!(kind, InsnKind::SRA | InsnKind::SRAI) { circuit_builder.require_zero(|| "b_sign_zero", b_sign.expr())?; } else { - let mask = E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)).expr(); + let mask = E::BaseField::from_u32(1 << (LIMB_BITS - 1)).expr(); let b_sign_shifted = b_sign.expr() * mask.expr(); circuit_builder.lookup_xor_byte( b[NUM_LIMBS - 1].expr(), mask.expr(), b[NUM_LIMBS - 1].expr() + mask.expr() - - (E::BaseField::from_canonical_u32(2).expr()) * b_sign_shifted.expr(), + - (E::BaseField::from_u32(2).expr()) * b_sign_shifted.expr(), )?; } @@ -226,12 +222,12 @@ impl InsnKind::SLL | InsnKind::SLLI => set_val!( instance, self.bit_multiplier_left, - E::BaseField::from_canonical_usize(1 << bit_shift) + E::BaseField::from_usize(1 << bit_shift) ), _ => set_val!( instance, self.bit_multiplier_right, - E::BaseField::from_canonical_usize(1 << bit_shift) + E::BaseField::from_usize(1 << bit_shift) ), }; @@ -240,7 +236,7 @@ impl _ => b[i] % (1 << bit_shift), }); for (val, witin) in bit_shift_carry.iter().zip_eq(&self.bit_shift_carry) { - set_val!(instance, witin, E::BaseField::from_canonical_u32(*val)); + set_val!(instance, witin, E::BaseField::from_u32(*val)); lk_multiplicity.assert_dynamic_range(*val as u64, bit_shift as u64); } for (i, witin) in self.bit_shift_marker.iter().enumerate() { @@ -437,7 +433,7 @@ impl Instruction for ShiftImmInstructio step: &ceno_emul::StepRecord, ) -> Result<(), crate::error::ZKVMError> { let imm = step.insn().imm as i16 as u16; - set_val!(instance, config.imm, E::BaseField::from_canonical_u16(imm)); + set_val!(instance, config.imm, E::BaseField::from_u16(imm)); // rs1 let rs1_read = split_to_u8::(step.rs1().unwrap().value); // rd diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs index e2df652b1..bc50eb535 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs @@ -20,6 +20,7 @@ use ceno_emul::{InsnKind, SWord, StepRecord, Word}; use ff_ext::{ExtensionField, FieldInto}; use gkr_iop::gadgets::IsLtConfig; use multilinear_extensions::{ToExpr, WitIn}; +use p3::field::PrimeCharacteristicRing; use std::marker::PhantomData; use witness::set_val; diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index b2449614e..93e2a86c5 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -19,7 +19,8 @@ use crate::{ use ceno_emul::{InsnKind, StepRecord, Word}; use ff_ext::{ExtensionField, FieldInto}; use multilinear_extensions::{ToExpr, WitIn}; -use p3::field::FieldAlgebra; + +use p3::field::PrimeCharacteristicRing; use std::marker::PhantomData; use witness::set_val; @@ -114,7 +115,7 @@ impl Instruction for SetLessThanImmInst .assign_value(instance, Value::new_unchecked(rs1)); let imm = step.insn().imm as i16 as u16; - set_val!(instance, config.imm, E::BaseField::from_canonical_u16(imm)); + set_val!(instance, config.imm, E::BaseField::from_u16(imm)); // according to riscvim32 spec, imm always do signed extension let imm_sign_extend = imm_sign_extend(true, step.insn().imm as i16); set_val!( diff --git a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs index 78c8bfe95..3e879938f 100644 --- a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs +++ b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs @@ -7,7 +7,7 @@ use multilinear_extensions::{ mle::{MultilinearExtension, Point, PointAndEval}, util::ceil_log2, }; -use p3::{field::FieldAlgebra, util::indices_arr}; +use p3::{field::PrimeCharacteristicRing as FieldAlgebra, util::indices_arr}; use std::{array::from_fn, mem::transmute, sync::Arc}; use sumcheck::{ macros::{entered_span, exit_span}, @@ -52,7 +52,7 @@ fn not_expr(a: Expression) -> Expression { } fn xor_expr(a: Expression, b: Expression) -> Expression { - a.clone() + b.clone() - E::BaseField::from_canonical_u32(2).expr() * a * b + a.clone() + b.clone() - E::BaseField::from_u32(2).expr() * a * b } fn zero_expr() -> Expression { @@ -411,7 +411,7 @@ fn iota_layer( let sel_type = SelectorType::Whole(layer.eq.expr()); iota_out_evals.iter().enumerate().for_each(|(i, out_eval)| { let expr = { - let round_bit = E::BaseField::from_canonical_u64((round_value >> i) & 1).expr(); + let round_bit = E::BaseField::from_u64((round_value >> i) & 1).expr(); xor_expr(bits[i].clone(), round_bit) }; system.add_non_zero_constraint( diff --git a/ceno_zkvm/src/precompiles/fptower/fp.rs b/ceno_zkvm/src/precompiles/fptower/fp.rs index ba7c04308..b7e65b650 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp.rs @@ -34,7 +34,7 @@ use gkr_iop::{ use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn, util::max_usable_threads}; use num::BigUint; -use p3::field::FieldAlgebra; + use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, prelude::ParallelSlice, @@ -53,6 +53,7 @@ use crate::{ precompiles::{SelectorTypeLayout, utils::merge_u8_slice_to_u16_limbs_pairs_and_extend}, witness::LkMultiplicity, }; +use p3::field::PrimeCharacteristicRing; pub const fn num_fp_cols() -> usize { size_of::>() @@ -155,9 +156,9 @@ impl FpOpLayout { cols: &mut FpOpWitCols, lk_multiplicity: &mut LkMultiplicity, ) { - cols.is_add = E::BaseField::from_canonical_u8((instance.op == FieldOperation::Add) as u8); - cols.is_sub = E::BaseField::from_canonical_u8((instance.op == FieldOperation::Sub) as u8); - cols.is_mul = E::BaseField::from_canonical_u8((instance.op == FieldOperation::Mul) as u8); + cols.is_add = E::BaseField::from_u8((instance.op == FieldOperation::Add) as u8); + cols.is_sub = E::BaseField::from_u8((instance.op == FieldOperation::Sub) as u8); + cols.is_mul = E::BaseField::from_u8((instance.op == FieldOperation::Mul) as u8); cols.x_limbs = P::to_limbs_field(&instance.x); cols.y_limbs = P::to_limbs_field(&instance.y); diff --git a/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs b/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs index 76d32b31a..4ba84f9d8 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs @@ -34,7 +34,7 @@ use gkr_iop::{ use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn, util::max_usable_threads}; use num::BigUint; -use p3::field::FieldAlgebra; + use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, prelude::ParallelSlice, @@ -53,6 +53,7 @@ use crate::{ precompiles::{SelectorTypeLayout, utils::merge_u8_slice_to_u16_limbs_pairs_and_extend}, witness::LkMultiplicity, }; +use p3::field::PrimeCharacteristicRing; pub const fn num_fp2_addsub_cols() -> usize { size_of::>() @@ -164,7 +165,7 @@ impl Fp2AddSubAssignLayout { cols: &mut Fp2AddSubAssignWitCols, lk_multiplicity: &mut LkMultiplicity, ) { - cols.is_add = E::BaseField::from_canonical_u8((instance.op == FieldOperation::Add) as u8); + cols.is_add = E::BaseField::from_u8((instance.op == FieldOperation::Add) as u8); cols.a0 = P::to_limbs_field(&instance.a0); cols.a1 = P::to_limbs_field(&instance.a1); cols.b0 = P::to_limbs_field(&instance.b0); diff --git a/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs b/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs index c9160e6d9..fe24b8a58 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs @@ -34,7 +34,7 @@ use gkr_iop::{ use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn, util::max_usable_threads}; use num::BigUint; -use p3::field::FieldAlgebra; + use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, prelude::ParallelSlice, @@ -53,6 +53,7 @@ use crate::{ precompiles::{SelectorTypeLayout, utils::merge_u8_slice_to_u16_limbs_pairs_and_extend}, witness::LkMultiplicity, }; +use p3::field::PrimeCharacteristicRing; pub const fn num_fp2_mul_cols() -> usize { size_of::>() diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index 52391267a..d2026f354 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -23,7 +23,7 @@ use multilinear_extensions::{ util::{ceil_log2, max_usable_threads}, }; use ndarray::{ArrayView, Ix2, Ix3, s}; -use p3::field::FieldAlgebra; + use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}, slice::{ParallelSlice, ParallelSliceMut}, @@ -47,6 +47,7 @@ use crate::{ }, scheme::utils::gkr_witness, }; +use p3::field::PrimeCharacteristicRing; pub const ROUNDS: usize = 24; pub const ROUNDS_CEIL_LOG2: usize = 5; // log_2(24.next_pow2()) @@ -620,7 +621,7 @@ where // RC.iter() // .flat_map(|x| { // (0..8) - // .map(|i| E::BaseField::from_canonical_u64((x >> (i << 3)) & 0xFF)) + // .map(|i| E::BaseField::from_u64((x >> (i << 3)) & 0xFF)) // .collect_vec() // }) // .collect_vec(), @@ -973,7 +974,7 @@ pub fn setup_gkr_circuit() WriteMEM::construct_circuit( &mut cb, // mem address := state_ptr + i - state_ptr.expr() + E::BaseField::from_canonical_u32(i as u32).expr(), + state_ptr.expr() + E::BaseField::from_u32(i as u32).expr(), val_before.clone(), val_after.clone(), vm_state.ts, @@ -1197,7 +1198,7 @@ pub fn run_lookup_keccakf // .to_vec() // .iter() // .flat_map(|e| vec![*e as u32, (e >> 32) as u32]) - // .map(|e| Goldilocks::from_canonical_u64(e as u64)) + // .map(|e| Goldilocks::from_u64(e as u64)) // .collect_vec(), // instance_outputs[i] // ); diff --git a/ceno_zkvm/src/precompiles/sha256/extend.rs b/ceno_zkvm/src/precompiles/sha256/extend.rs index 235e37b95..4937aa2d1 100644 --- a/ceno_zkvm/src/precompiles/sha256/extend.rs +++ b/ceno_zkvm/src/precompiles/sha256/extend.rs @@ -32,7 +32,7 @@ use gkr_iop::{ }; use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn, util::max_usable_threads}; -use p3::field::{FieldAlgebra, TwoAdicField}; +use p3::field::{PrimeCharacteristicRing as FieldAlgebra, TwoAdicField}; use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, prelude::ParallelSlice, @@ -448,11 +448,9 @@ mod tests { expected_output.iter().take(SHA_EXTEND_ROUNDS).enumerate() { let row_idx = instance_idx * SHA_EXTEND_ROUNDS + round_idx; - let output_word: [_; WORD_SIZE] = phase1.row_slice(row_idx) - [out_index..out_index + 4] - .to_vec() - .try_into() - .unwrap(); + let row = phase1.row_slice(row_idx).expect("phase1 row out of bounds"); + let output_word: [_; WORD_SIZE] = + row[out_index..out_index + WORD_SIZE].try_into().unwrap(); let expected_word = Word::::from(*expected_word_u32); assert_eq!( output_word, expected_word.0, diff --git a/ceno_zkvm/src/precompiles/uint256.rs b/ceno_zkvm/src/precompiles/uint256.rs index e59e97c55..62aecf41c 100644 --- a/ceno_zkvm/src/precompiles/uint256.rs +++ b/ceno_zkvm/src/precompiles/uint256.rs @@ -27,7 +27,9 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, e2e::ShardContext, error::ZKVMError, - gadgets::{FieldOperation, IsZeroOperation, field_op::FieldOpCols, range::FieldLtCols}, + gadgets::{ + FieldOperation, IsZeroOperation, field_op::FieldOpCols, poly_scale_expr, range::FieldLtCols, + }, instructions::riscv::insn_base::{StateInOut, WriteMEM}, precompiles::{SelectorTypeLayout, utils::merge_u8_slice_to_u16_limbs_pairs_and_extend}, scheme::utils::gkr_witness, @@ -53,7 +55,8 @@ use multilinear_extensions::{ util::{ceil_log2, max_usable_threads}, }; use num::{BigUint, One, Zero}; -use p3::field::FieldAlgebra; + +use p3::field::PrimeCharacteristicRing; use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}, slice::ParallelSlice, @@ -230,9 +233,14 @@ impl ProtocolBuilder for Uint256MulLayout { coeff_2_256.resize(32, Expression::ZERO); coeff_2_256.push(Expression::ONE); let modulus_polynomial: Polynomial> = (*modulus_limbs).into(); - let p_modulus: Polynomial> = modulus_polynomial - * (1 - modulus_is_zero.expr()) - + Polynomial::from_coefficients(&coeff_2_256) * modulus_is_zero.expr(); + let modulus_is_zero_expr = modulus_is_zero.expr(); + let p_modulus: Polynomial> = poly_scale_expr( + &modulus_polynomial, + Expression::ONE - modulus_is_zero_expr.clone(), + ) + poly_scale_expr( + &Polynomial::from_coefficients(&coeff_2_256), + modulus_is_zero_expr.clone(), + ); // Evaluate the uint256 multiplication wits.output @@ -673,7 +681,7 @@ pub fn setup_uint256mul_gkr_circuit() WriteMEM::construct_circuit( &mut cb, // mem address := state_ptr_0 + i - number_ptr.expr() + E::BaseField::from_canonical_u32(i as u32).expr(), + number_ptr.expr() + E::BaseField::from_u32(i as u32).expr(), val_before.clone(), val_after.clone(), vm_state.ts, @@ -694,7 +702,7 @@ pub fn setup_uint256mul_gkr_circuit() &mut cb, // mem address := state_ptr_1 + i number_ptr.expr() - + E::BaseField::from_canonical_u32((limb_len * j + i) as u32).expr(), + + E::BaseField::from_u32((limb_len * j + i) as u32).expr(), val_before.clone(), val_before.clone(), vm_state.ts, diff --git a/ceno_zkvm/src/precompiles/utils.rs b/ceno_zkvm/src/precompiles/utils.rs index 4d63d7f90..669fa5f42 100644 --- a/ceno_zkvm/src/precompiles/utils.rs +++ b/ceno_zkvm/src/precompiles/utils.rs @@ -2,11 +2,12 @@ use ff_ext::ExtensionField; use gkr_iop::circuit_builder::expansion_expr; use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr}; -use p3::field::FieldAlgebra; + +use p3::field::PrimeCharacteristicRing; use smallvec::SmallVec; pub fn not8_expr(expr: Expression) -> Expression { - E::BaseField::from_canonical_u8(0xFF).expr() - expr + E::BaseField::from_u8(0xFF).expr() - expr } pub fn set_slice_felts_from_u64(dst: &mut [E::BaseField], start_index: usize, iter: I) @@ -15,7 +16,7 @@ where I: IntoIterator, { for (i, word) in iter.into_iter().enumerate() { - dst[start_index + i] = E::BaseField::from_canonical_u64(word); + dst[start_index + i] = E::BaseField::from_u64(word); } } diff --git a/ceno_zkvm/src/precompiles/weierstrass/test_utils.rs b/ceno_zkvm/src/precompiles/weierstrass/test_utils.rs index ac60c1ea2..b3a5fcc83 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/test_utils.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/test_utils.rs @@ -1,6 +1,6 @@ use generic_array::GenericArray; -use num::bigint::RandBigInt; -use rand::{Rng, SeedableRng}; +use num_bigint::BigUint; +use rand::{Rng, RngCore, SeedableRng}; use sp1_curves::{ EllipticCurve, params::NumWords, @@ -16,11 +16,11 @@ pub fn random_point_pairs( let base = SwCurve::::generator(); (0..num_instances) .map(|_| { - let x = rng.gen_biguint(24); + let x = gen_biguint(&mut rng, 24); - let mut y = rng.gen_biguint(24); + let mut y = gen_biguint(&mut rng, 24); while y == x { - y = rng.gen_biguint(24); + y = gen_biguint(&mut rng, 24); } let x_base = base.clone().sw_scalar_mul(&x); @@ -40,7 +40,7 @@ pub fn random_points( let base = SwCurve::::generator(); (0..num_instances) .map(|_| { - let x = rng.gen_biguint(24); + let x = gen_biguint(&mut rng, 24); let x_base = base.clone().sw_scalar_mul(&x); x_base.to_words_le().try_into().unwrap() }) @@ -55,7 +55,7 @@ pub fn random_decompress_instances( let base = SwCurve::::generator(); (0..num_instances) .map(|_| { - let x = rng.gen_biguint(24); + let x = gen_biguint(&mut rng, 24); let sign_bit = rng.gen_bool(0.5); let x_base = base.clone().sw_scalar_mul(&x); EllipticCurveDecompressInstance { @@ -66,3 +66,20 @@ pub fn random_decompress_instances( }) .collect() } + +fn gen_biguint(rng: &mut R, bits: u64) -> BigUint { + if bits == 0 { + return BigUint::from(0u8); + } + let num_bytes = ((bits + 7) / 8) as usize; + let mut buf = vec![0u8; num_bytes]; + rng.fill_bytes(&mut buf); + let excess_bits = (num_bytes as u64 * 8) - bits; + if excess_bits > 0 { + let mask = 0xffu8 >> excess_bits; + if let Some(last) = buf.last_mut() { + *last &= mask; + } + } + BigUint::from_bytes_be(&buf) +} diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs index 012dcab80..80162f5a5 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs @@ -45,7 +45,7 @@ use multilinear_extensions::{ util::{ceil_log2, max_usable_threads}, }; use num::BigUint; -use p3::field::FieldAlgebra; + use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, prelude::{IntoParallelRefIterator, ParallelSlice}, @@ -75,6 +75,7 @@ use crate::{ structs::PointAndEval, witness::LkMultiplicity, }; +use p3::field::PrimeCharacteristicRing; #[derive(Clone, Debug, AlignedBorrow)] #[repr(C)] @@ -478,7 +479,7 @@ pub fn setup_gkr_circuit() WriteMEM::construct_circuit( &mut cb, // mem address := state_ptr_0 + i - point_ptr_0.expr() + E::BaseField::from_canonical_u32(i as u32).expr(), + point_ptr_0.expr() + E::BaseField::from_u32(i as u32).expr(), val_before.clone(), val_after.clone(), vm_state.ts, @@ -496,10 +497,7 @@ pub fn setup_gkr_circuit() &mut cb, // mem address := state_ptr_1 + i point_ptr_0.expr() - + E::BaseField::from_canonical_u32( - (layout.output32_exprs.len() + i) as u32, - ) - .expr(), + + E::BaseField::from_u32((layout.output32_exprs.len() + i) as u32).expr(), val_before.clone(), val_before.clone(), vm_state.ts, diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs index d6400a2d7..c9135055c 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs @@ -46,7 +46,7 @@ use multilinear_extensions::{ util::{ceil_log2, max_usable_threads}, }; use num::{BigUint, One, Zero}; -use p3::field::FieldAlgebra; + use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, prelude::{IntoParallelRefIterator, ParallelSlice}, @@ -85,6 +85,7 @@ use crate::{ structs::PointAndEval, witness::LkMultiplicity, }; +use p3::field::PrimeCharacteristicRing; #[derive(Clone, Debug, AlignedBorrow)] #[repr(C)] @@ -195,8 +196,8 @@ impl cols.sign_bit = E::BaseField::from_bool(instance.sign_bit); cols.old_output32 = GenericArray::generate(|i| { [ - E::BaseField::from_canonical_u32(instance.old_y_words[i] & ((1 << 16) - 1)), - E::BaseField::from_canonical_u32((instance.old_y_words[i] >> 16) & ((1 << 16) - 1)), + E::BaseField::from_u32(instance.old_y_words[i] & ((1 << 16) - 1)), + E::BaseField::from_u32((instance.old_y_words[i] >> 16) & ((1 << 16) - 1)), ] }); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs index 686baa397..0efe9312d 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs @@ -45,7 +45,7 @@ use multilinear_extensions::{ util::{ceil_log2, max_usable_threads}, }; use num::BigUint; -use p3::field::FieldAlgebra; + use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, prelude::{IntoParallelRefIterator, ParallelSlice}, @@ -76,6 +76,7 @@ use crate::{ structs::PointAndEval, witness::LkMultiplicity, }; +use p3::field::PrimeCharacteristicRing; #[derive(Clone, Debug, AlignedBorrow)] #[repr(C)] @@ -505,7 +506,7 @@ pub fn setup_gkr_circuit(&self) -> Vec> { vec![ - vec![E::BaseField::from_canonical_u32(self.exit_code & 0xffff)], - vec![E::BaseField::from_canonical_u32( - (self.exit_code >> 16) & 0xffff, - )], - vec![E::BaseField::from_canonical_u32(self.init_pc)], - vec![E::BaseField::from_canonical_u64(self.init_cycle)], - vec![E::BaseField::from_canonical_u32(self.end_pc)], - vec![E::BaseField::from_canonical_u64(self.end_cycle)], - vec![E::BaseField::from_canonical_u32(self.shard_id)], - vec![E::BaseField::from_canonical_u32(self.heap_start_addr)], - vec![E::BaseField::from_canonical_u32(self.heap_shard_len)], - vec![E::BaseField::from_canonical_u32(self.hint_start_addr)], - vec![E::BaseField::from_canonical_u32(self.hint_shard_len)], + vec![E::BaseField::from_u32(self.exit_code & 0xffff)], + vec![E::BaseField::from_u32((self.exit_code >> 16) & 0xffff)], + vec![E::BaseField::from_u32(self.init_pc)], + vec![E::BaseField::from_u64(self.init_cycle)], + vec![E::BaseField::from_u32(self.end_pc)], + vec![E::BaseField::from_u64(self.end_cycle)], + vec![E::BaseField::from_u32(self.shard_id)], + vec![E::BaseField::from_u32(self.heap_start_addr)], + vec![E::BaseField::from_u32(self.heap_shard_len)], + vec![E::BaseField::from_u32(self.hint_start_addr)], + vec![E::BaseField::from_u32(self.hint_shard_len)], ] .into_iter() .chain( @@ -142,7 +139,7 @@ impl PublicValues { self.public_io .iter() .map(|value| { - E::BaseField::from_canonical_u16( + E::BaseField::from_u16( ((value >> (limb_index * LIMB_BITS)) & LIMB_MASK) as u16, ) }) @@ -153,7 +150,7 @@ impl PublicValues { .chain( self.shard_rw_sum .iter() - .map(|value| vec![E::BaseField::from_canonical_u32(*value)]) + .map(|value| vec![E::BaseField::from_u32(*value)]) .collect_vec(), ) .collect::>() diff --git a/ceno_zkvm/src/scheme/gpu/mod.rs b/ceno_zkvm/src/scheme/gpu/mod.rs index 03b789f6d..e9017580e 100644 --- a/ceno_zkvm/src/scheme/gpu/mod.rs +++ b/ceno_zkvm/src/scheme/gpu/mod.rs @@ -805,7 +805,6 @@ fn build_tower_witness_gpu<'buf, E: ExtensionField>( let stream = gkr_iop::gpu::get_thread_stream(); use crate::scheme::constants::{NUM_FANIN, NUM_FANIN_LOGUP}; use ceno_gpu::{CudaHal as _, bb31::GpuPolynomialExt}; - use p3::field::FieldAlgebra; let ComposedConstrainSystem { zkvm_v1_css: cs, .. diff --git a/ceno_zkvm/src/scheme/gpu/util.rs b/ceno_zkvm/src/scheme/gpu/util.rs index df52010ef..7e8fff7d4 100644 --- a/ceno_zkvm/src/scheme/gpu/util.rs +++ b/ceno_zkvm/src/scheme/gpu/util.rs @@ -20,6 +20,7 @@ use crate::{ }; use crate::scheme::gpu::BB31Ext; +use p3::field::PrimeCharacteristicRing; pub fn expect_basic_transcript>( transcript: &mut T, @@ -56,7 +57,7 @@ fn read_base_value_from_gpu<'a, E: ExtensionField>( .get(index, stream.as_ref()) .map_err(|e| hal_to_backend_error(format!("failed to read GPU buffer: {e:?}")))?; let canonical = raw.as_canonical_u32(); - Ok(E::BaseField::from_canonical_u32(canonical)) + Ok(E::BaseField::from_u32(canonical)) } GpuFieldType::Ext(_) => Err(hal_to_backend_error( "expected base-field polynomial for final-sum extraction", diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index a477b1c6a..f95631b03 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -30,7 +30,7 @@ use multilinear_extensions::{ util::ceil_log2, utils::{eval_by_expr, eval_by_expr_with_fixed, eval_by_expr_with_instance}, }; -use p3::field::{Field, FieldAlgebra}; +use p3::field::{Field, PrimeCharacteristicRing as FieldAlgebra}; use rand::thread_rng; use std::{ cmp::max, @@ -479,9 +479,9 @@ fn load_once_tables( ( challenge.map(|c| { - c.as_base_slice() + c.as_bases() .iter() - .map(|b| b.to_canonical_u64()) + .map(|b: &E::BaseField| b.to_canonical_u64()) .collect_vec() }), table, @@ -490,7 +490,8 @@ fn load_once_tables( // reinitialize per generic type E ( challenges_repr.clone().map(|repr| { - E::from_base_iter(repr.iter().copied().map(E::BaseField::from_canonical_u64)) + E::from_basis_coefficients_iter(repr.iter().copied().map(E::BaseField::from_u64)) + .expect("challenge repr must describe a valid extension element") }), table.clone(), ) @@ -1224,7 +1225,7 @@ Hints: let w_selector_vec = w_selector.get_base_field_vec(); let write_rlc_records = filter_mle_by_predicate(write_rlc_records, |i, _v| { - ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) + ram_type_vec[i] == E::from_u32($ram_type as u32) && w_selector_vec[i] == E::BaseField::ONE }); if write_rlc_records.is_empty() { @@ -1327,7 +1328,7 @@ Hints: ); let r_selector_vec = r_selector.get_base_field_vec(); let read_records = filter_mle_by_predicate(read_records, |i, _v| { - ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) + ram_type_vec[i] == E::from_u32($ram_type as u32) && r_selector_vec[i] == E::BaseField::ONE }); if read_records.is_empty() { @@ -1640,7 +1641,7 @@ mod tests { }; use ff_ext::{FieldInto, GoldilocksExt2}; use multilinear_extensions::{ToExpr, WitIn, mle::IntoMLE}; - use p3::{field::FieldAlgebra, goldilocks::Goldilocks}; + use p3::{field::PrimeCharacteristicRing as FieldAlgebra, goldilocks::Goldilocks}; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_val}; #[derive(Debug)] @@ -1720,12 +1721,9 @@ mod tests { let _ = RangeCheckCircuit::construct_circuit(&mut builder).unwrap(); let wits_in = vec![ - vec![ - Goldilocks::from_canonical_u64(3u64), - Goldilocks::from_canonical_u64(5u64), - ] - .into_mle() - .into(), + vec![Goldilocks::from_u64(3u64), Goldilocks::from_u64(5u64)] + .into_mle() + .into(), ]; let challenge = [1.into_f(), 1000.into_f()]; @@ -1760,9 +1758,7 @@ mod tests { GoldilocksExt2::ONE, GoldilocksExt2::ZERO, )), - Box::new( - Goldilocks::from_canonical_u64(ROMType::Dynamic as u64).expr() - ), + Box::new(Goldilocks::from_u64(ROMType::Dynamic as u64).expr()), )), Box::new(Expression::Challenge( 1, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 99ad1ac39..08082d885 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -3,6 +3,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, hal::ProverBackend, }; +use p3::field::PrimeCharacteristicRing; use std::{ collections::{BTreeMap, HashMap}, marker::PhantomData, @@ -24,7 +25,6 @@ use multilinear_extensions::{ Expression, Instance, mle::{IntoMLE, MultilinearExtension}, }; -use p3::field::FieldAlgebra; use std::iter::Iterator; use sumcheck::{ macros::{entered_span, exit_span}, @@ -207,10 +207,9 @@ impl< let circuit_idx = self.pk.circuit_name_to_index.get(name).unwrap(); // write (circuit_idx, num_var) to transcript - transcript.append_field_element(&E::BaseField::from_canonical_usize(*circuit_idx)); + transcript.append_field_element(&E::BaseField::from_usize(*circuit_idx)); for num_instance in num_instances { - transcript - .append_field_element(&E::BaseField::from_canonical_usize(*num_instance)); + transcript.append_field_element(&E::BaseField::from_usize(*num_instance)); } } @@ -381,9 +380,8 @@ impl< return scheduler.execute(tasks, transcript, |task, transcript| { // Append circuit_idx to per-task forked transcript (matching verifier) - transcript.append_field_element(&E::BaseField::from_canonical_u64( - task.circuit_idx as u64, - )); + transcript + .append_field_element(&E::BaseField::from_u64(task.circuit_idx as u64)); // SAFETY: TypeId check above (before closure) guarantees PB = GpuBackend. let gpu_input: ProofInput<'static, gkr_iop::gpu::GpuBackend> = @@ -418,8 +416,7 @@ impl< // Uses execute_sequentially directly to avoid Send+Sync requirement on the closure. scheduler.execute_sequentially(tasks, transcript, |mut task, transcript| { // Append circuit_idx to per-task forked transcript (matching verifier) - transcript - .append_field_element(&E::BaseField::from_canonical_u64(task.circuit_idx as u64)); + transcript.append_field_element(&E::BaseField::from_u64(task.circuit_idx as u64)); // Prepare: deferred extraction for GPU, no-op for CPU self.device.prepare_chip_input(&mut task, witness_data); diff --git a/ceno_zkvm/src/scheme/scheduler.rs b/ceno_zkvm/src/scheme/scheduler.rs index 438421b2e..7f164cec7 100644 --- a/ceno_zkvm/src/scheme/scheduler.rs +++ b/ceno_zkvm/src/scheme/scheduler.rs @@ -18,7 +18,7 @@ use crate::{ use ff_ext::ExtensionField; use gkr_iop::hal::ProverBackend; use mpcs::Point; -use p3::field::FieldAlgebra; +use p3::field::PrimeCharacteristicRing; use std::{collections::HashMap, sync::OnceLock}; use transcript::Transcript; static CHIP_PROVING_MODE: OnceLock = OnceLock::new(); @@ -186,7 +186,7 @@ impl ChipScheduler { // Fork: clone parent + append task_id // (identical to ForkableTranscript::fork default impl) let mut forked = parent_transcript.clone(); - forked.append_field_element(&::BaseField::from_canonical_u64( + forked.append_field_element(&::BaseField::from_u64( task_id as u64, )); @@ -247,7 +247,7 @@ impl ChipScheduler { if tasks.len() == 1 { let task = tasks.remove(0); let mut fork = transcript.clone(); - fork.append_field_element(&::BaseField::from_canonical_u64( + fork.append_field_element(&::BaseField::from_u64( task.task_id as u64, )); let result = execute_task(task, &mut fork)?; @@ -348,9 +348,7 @@ impl ChipScheduler { // (identical to ForkableTranscript::fork default impl) let mut local_transcript = tr.0.clone(); local_transcript.append_field_element( - &::BaseField::from_canonical_u64( - task_id as u64, - ), + &::BaseField::from_u64(task_id as u64), ); let result = execute_fn(task, &mut local_transcript); diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index f9b6b4f76..db291f277 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -3,7 +3,7 @@ use ff_ext::{ExtensionField, FromUniformBytes}; use multilinear_extensions::Expression; // The extension field and curve definition are adapted from // https://github.com/succinctlabs/sp1/blob/v5.2.1/crates/stark/src/septic_curve.rs -use p3::field::{Field, FieldAlgebra}; +use p3::field::{Field, PrimeCharacteristicRing as FieldAlgebra}; use rand::RngCore; use serde::{Deserialize, Serialize}; use std::{ @@ -233,8 +233,8 @@ impl SepticExtension { pub fn square(&self) -> Self { let mut result = [F::ZERO; 7]; - let two = F::from_canonical_u32(2); - let five = F::from_canonical_u32(5); + let two = F::from_u32(2); + let five = F::from_u32(5); // i < j for i in 0..7 { @@ -423,7 +423,7 @@ impl From<[u32; 7]> for SepticExtension { fn from(arr: [u32; 7]) -> Self { let mut result = [F::ZERO; 7]; for i in 0..7 { - result[i] = F::from_canonical_u32(arr[i]); + result[i] = F::from_u32(arr[i]); } Self(result) } @@ -549,8 +549,8 @@ impl Mul for &SepticExtension { fn mul(self, other: Self) -> Self::Output { let mut result = [F::ZERO; 7]; - let five = F::from_canonical_u32(5); - let two = F::from_canonical_u32(2); + let five = F::from_u32(5); + let two = F::from_u32(2); for i in 0..7 { for j in 0..7 { let term = self.0[i] * other.0[j]; @@ -683,8 +683,8 @@ impl Mul for &SymbolicSepticExtension { fn mul(self, other: Self) -> Self::Output { let mut result = vec![Expression::Constant(Either::Left(E::BaseField::ZERO)); 7]; - let five = Expression::Constant(Either::Left(E::BaseField::from_canonical_u32(5))); - let two = Expression::Constant(Either::Left(E::BaseField::from_canonical_u32(2))); + let five = Expression::Constant(Either::Left(E::BaseField::from_u32(5))); + let two = Expression::Constant(Either::Left(E::BaseField::from_u32(2))); for i in 0..7 { for j in 0..7 { @@ -770,7 +770,7 @@ impl SepticPoint { // if there exists y such that (x, y) is on the curve, return one of them pub fn from_x(x: SepticExtension) -> Option { let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into(); - let a: F = F::from_canonical_u32(2); + let a: F = F::from_u32(2); let y2 = x.square() * &x + (&x * a) + &b; if y2.is_square() { @@ -792,9 +792,9 @@ impl SepticPoint { Self { x, y, is_infinity } } pub fn double(&self) -> Self { - let a = F::from_canonical_u32(2); - let three = F::from_canonical_u32(3); - let two = F::from_canonical_u32(2); + let a = F::from_u32(2); + let three = F::from_u32(3); + let two = F::from_u32(2); let x1 = &self.x; let y1 = &self.y; @@ -891,7 +891,7 @@ impl SepticPoint { } let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into(); - let a: F = F::from_canonical_u32(2); + let a: F = F::from_u32(2); self.y.square() == self.x.square() * &self.x + (&self.x * a) + b } @@ -955,7 +955,7 @@ impl SepticJacobianPoint { } let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into(); - let a: F = F::from_canonical_u32(2); + let a: F = F::from_u32(2); let z2 = self.z.square(); let z4 = z2.square(); @@ -1015,7 +1015,7 @@ impl Add for &SepticJacobianPoint { } } - let two = F::from_canonical_u32(2); + let two = F::from_u32(2); let h = u2 - &u1; let i = (&h * two).square(); let j = &h * &i; @@ -1052,10 +1052,10 @@ impl SepticJacobianPoint { return SepticJacobianPoint::point_at_infinity(); } - let two = F::from_canonical_u32(2); - let three = F::from_canonical_u32(3); - let eight = F::from_canonical_u32(8); - let a = F::from_canonical_u32(2); // The curve coefficient a + let two = F::from_u32(2); + let three = F::from_u32(3); + let eight = F::from_u32(8); + let a = F::from_u32(2); // The curve coefficient a // xx = x1^2 let xx = self.x.square(); diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 91a4e4563..32f68d14f 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -25,6 +25,7 @@ use gkr_iop::cpu::default_backend_config; #[cfg(feature = "gpu")] use gkr_iop::gpu::{MultilinearExtensionGpu, gpu_prover::*}; use multilinear_extensions::{ToExpr, WitIn, mle::MultilinearExtension}; +use p3::field::PrimeCharacteristicRing; use std::marker::PhantomData; #[cfg(feature = "gpu")] use std::sync::Arc; @@ -48,7 +49,7 @@ use mpcs::{ PolynomialCommitmentScheme, SecurityLevel, SecurityLevel::Conjecture100bits, WhirDefault, }; use multilinear_extensions::{mle::IntoMLE, util::ceil_log2}; -use p3::field::FieldAlgebra; + use rand::thread_rng; use transcript::{BasicTranscript, Transcript}; diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index aa2be21fe..6a3bc00f7 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -565,7 +565,7 @@ mod tests { smart_slice::SmartSlice, util::ceil_log2, }; - use p3::field::FieldAlgebra; + use p3::field::PrimeCharacteristicRing; use crate::scheme::utils::{ infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, @@ -576,8 +576,8 @@ mod tests { type E = GoldilocksExt2; let num_product_fanin = 2; let last_layer: Vec> = vec![ - vec![E::ONE, E::from_canonical_u64(2u64)].into_mle(), - vec![E::from_canonical_u64(3u64), E::from_canonical_u64(4u64)].into_mle(), + vec![E::ONE, E::from_u64(2u64)].into_mle(), + vec![E::from_u64(3u64), E::from_u64(4u64)].into_mle(), ]; let num_vars = ceil_log2(last_layer[0].evaluations().len()) + 1; let res = infer_tower_product_witness(num_vars, last_layer.clone(), 2); @@ -610,16 +610,10 @@ mod tests { let num_product_fanin = 2; // [[1, 2], [3, 4], [5, 6], [7, 8]] let input_mles: Vec> = vec![ - vec![E::ONE, E::from_canonical_u64(2u64)].into_mle().into(), - vec![E::from_canonical_u64(3u64), E::from_canonical_u64(4u64)] - .into_mle() - .into(), - vec![E::from_canonical_u64(5u64), E::from_canonical_u64(6u64)] - .into_mle() - .into(), - vec![E::from_canonical_u64(7u64), E::from_canonical_u64(8u64)] - .into_mle() - .into(), + vec![E::ONE, E::from_u64(2u64)].into_mle().into(), + vec![E::from_u64(3u64), E::from_u64(4u64)].into_mle().into(), + vec![E::from_u64(5u64), E::from_u64(6u64)].into_mle().into(), + vec![E::from_u64(7u64), E::from_u64(8u64)].into_mle().into(), ]; let res = interleaving_mles_to_mles(&input_mles, 2, num_product_fanin, E::ONE); // [[1, 3, 5, 7], [2, 4, 6, 8]] @@ -627,18 +621,18 @@ mod tests { res[0].get_ext_field_vec(), vec![ E::ONE, - E::from_canonical_u64(3u64), - E::from_canonical_u64(5u64), - E::from_canonical_u64(7u64) + E::from_u64(3u64), + E::from_u64(5u64), + E::from_u64(7u64) ], ); assert_eq!( res[1].get_ext_field_vec(), vec![ - E::from_canonical_u64(2u64), - E::from_canonical_u64(4u64), - E::from_canonical_u64(6u64), - E::from_canonical_u64(8u64) + E::from_u64(2u64), + E::from_u64(4u64), + E::from_u64(6u64), + E::from_u64(8u64) ], ); } @@ -651,13 +645,9 @@ mod tests { // case 1: test limb level padding // [[1,2],[3,4],[5,6]]] let input_mles: Vec> = vec![ - vec![E::ONE, E::from_canonical_u64(2u64)].into_mle().into(), - vec![E::from_canonical_u64(3u64), E::from_canonical_u64(4u64)] - .into_mle() - .into(), - vec![E::from_canonical_u64(5u64), E::from_canonical_u64(6u64)] - .into_mle() - .into(), + vec![E::ONE, E::from_u64(2u64)].into_mle().into(), + vec![E::from_u64(3u64), E::from_u64(4u64)].into_mle().into(), + vec![E::from_u64(5u64), E::from_u64(6u64)].into_mle().into(), ]; let res = interleaving_mles_to_mles(&input_mles, 2, num_product_fanin, E::ZERO); // [[1, 3, 5, 0], [2, 4, 6, 0]] @@ -665,42 +655,33 @@ mod tests { res[0].get_ext_field_vec(), vec![ E::ONE, - E::from_canonical_u64(3u64), - E::from_canonical_u64(5u64), - E::from_canonical_u64(0u64) + E::from_u64(3u64), + E::from_u64(5u64), + E::from_u64(0u64) ], ); assert_eq!( res[1].get_ext_field_vec(), vec![ - E::from_canonical_u64(2u64), - E::from_canonical_u64(4u64), - E::from_canonical_u64(6u64), - E::from_canonical_u64(0u64) + E::from_u64(2u64), + E::from_u64(4u64), + E::from_u64(6u64), + E::from_u64(0u64) ], ); // case 2: test instance level padding // [[1,0],[3,0],[5,0]]] let input_mles: Vec> = vec![ - vec![E::ONE, E::from_canonical_u64(0u64)].into_mle().into(), - vec![E::from_canonical_u64(3u64), E::from_canonical_u64(0u64)] - .into_mle() - .into(), - vec![E::from_canonical_u64(5u64), E::from_canonical_u64(0u64)] - .into_mle() - .into(), + vec![E::ONE, E::from_u64(0u64)].into_mle().into(), + vec![E::from_u64(3u64), E::from_u64(0u64)].into_mle().into(), + vec![E::from_u64(5u64), E::from_u64(0u64)].into_mle().into(), ]; let res = interleaving_mles_to_mles(&input_mles, 1, num_product_fanin, E::ONE); // [[1, 3, 5, 1], [1, 1, 1, 1]] assert_eq!( res[0].get_ext_field_vec(), - vec![ - E::ONE, - E::from_canonical_u64(3u64), - E::from_canonical_u64(5u64), - E::ONE - ], + vec![E::ONE, E::from_u64(3u64), E::from_u64(5u64), E::ONE], ); assert_eq!(res[1].get_ext_field_vec(), vec![E::ONE; 4],); } @@ -711,14 +692,14 @@ mod tests { let num_product_fanin = 2; // one instance, 2 mles: [[2], [3]] let input_mles: Vec> = vec![ - vec![E::from_canonical_u64(2u64)].into_mle().into(), - vec![E::from_canonical_u64(3u64)].into_mle().into(), + vec![E::from_u64(2u64)].into_mle().into(), + vec![E::from_u64(3u64)].into_mle().into(), ]; let res = interleaving_mles_to_mles(&input_mles, 1, num_product_fanin, E::ONE); // [[2, 3], [1, 1]] assert_eq!( res[0].get_ext_field_vec(), - vec![E::from_canonical_u64(2u64), E::from_canonical_u64(3u64)], + vec![E::from_u64(2u64), E::from_u64(3u64)], ); assert_eq!(res[1].get_ext_field_vec(), vec![E::ONE, E::ONE],); } @@ -730,12 +711,12 @@ mod tests { let q: Vec> = vec![ vec![1, 2, 3, 4] .into_iter() - .map(E::from_canonical_u64) + .map(E::from_u64) .collect_vec() .into_mle(), vec![5, 6, 7, 8] .into_iter() - .map(E::from_canonical_u64) + .map(E::from_u64) .collect_vec() .into_mle(), ]; @@ -778,53 +759,32 @@ mod tests { assert_eq!( layer[0].evaluations().clone(), FieldType::::Ext(SmartSlice::Owned(vec![ - vec![1 + 5] - .into_iter() - .map(E::from_canonical_u64) - .sum::(), - vec![2 + 6] - .into_iter() - .map(E::from_canonical_u64) - .sum::() + vec![1 + 5].into_iter().map(E::from_u64).sum::(), + vec![2 + 6].into_iter().map(E::from_u64).sum::() ])) ); // next layer p2 assert_eq!( layer[1].evaluations().clone(), FieldType::::Ext(SmartSlice::Owned(vec![ - vec![3 + 7] - .into_iter() - .map(E::from_canonical_u64) - .sum::(), - vec![4 + 8] - .into_iter() - .map(E::from_canonical_u64) - .sum::() + vec![3 + 7].into_iter().map(E::from_u64).sum::(), + vec![4 + 8].into_iter().map(E::from_u64).sum::() ])) ); // next layer q1 assert_eq!( layer[2].evaluations().clone(), FieldType::::Ext(SmartSlice::Owned(vec![ - vec![5].into_iter().map(E::from_canonical_u64).sum::(), - vec![2 * 6] - .into_iter() - .map(E::from_canonical_u64) - .sum::() + vec![5].into_iter().map(E::from_u64).sum::(), + vec![2 * 6].into_iter().map(E::from_u64).sum::() ])) ); // next layer q2 assert_eq!( layer[3].evaluations().clone(), FieldType::::Ext(SmartSlice::Owned(vec![ - vec![3 * 7] - .into_iter() - .map(E::from_canonical_u64) - .sum::(), - vec![4 * 8] - .into_iter() - .map(E::from_canonical_u64) - .sum::() + vec![3 * 7].into_iter().map(E::from_u64).sum::(), + vec![4 * 8].into_iter().map(E::from_u64).sum::() ])) ); @@ -837,7 +797,7 @@ mod tests { FieldType::::Ext(SmartSlice::Owned(vec![ vec![(1 + 5) * (3 * 7) + (3 + 7) * 5] .into_iter() - .map(E::from_canonical_u64) + .map(E::from_u64) .sum::(), ])) ); @@ -848,7 +808,7 @@ mod tests { FieldType::::Ext(SmartSlice::Owned(vec![ vec![(2 + 6) * (4 * 8) + (4 + 8) * (2 * 6)] .into_iter() - .map(E::from_canonical_u64) + .map(E::from_u64) .sum::(), ])) ); @@ -857,10 +817,7 @@ mod tests { layer[2].evaluations().clone(), // q12 * q11 FieldType::::Ext(SmartSlice::Owned(vec![ - vec![(3 * 7) * 5] - .into_iter() - .map(E::from_canonical_u64) - .sum::(), + vec![(3 * 7) * 5].into_iter().map(E::from_u64).sum::(), ])) ); // q2 @@ -870,7 +827,7 @@ mod tests { FieldType::::Ext(SmartSlice::Owned(vec![ vec![(4 * 8) * (2 * 6)] .into_iter() - .map(E::from_canonical_u64) + .map(E::from_u64) .sum::(), ])) ); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 4d02c6e98..fbb635567 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,5 +1,6 @@ use either::Either; use ff_ext::ExtensionField; +use p3::field::PrimeCharacteristicRing; use std::{ iter::{self, once, repeat_n}, marker::PhantomData, @@ -38,7 +39,7 @@ use multilinear_extensions::{ utils::eval_by_expr_with_instance, virtual_poly::{VPAuxInfo, build_eq_x_r_vec_sequential, eq_eval}, }; -use p3::field::FieldAlgebra; + use sumcheck::{ structs::{IOPProof, IOPVerifierState}, util::get_challenge_pows, @@ -116,13 +117,13 @@ impl> ZKVMVerifier } // each shard set init cycle = Tracer::SUBCYCLES_PER_INSN // to satisfy initial reads for all prev_cycle = 0 < init_cycle - assert_eq!(vm_proof.pi_evals[INIT_CYCLE_IDX], E::from_canonical_u64(Tracer::SUBCYCLES_PER_INSN)); + assert_eq!(vm_proof.pi_evals[INIT_CYCLE_IDX], E::from_u64(Tracer::SUBCYCLES_PER_INSN)); // check init_pc match prev end_pc if let Some(prev_pc) = prev_pc { assert_eq!(vm_proof.pi_evals[INIT_PC_IDX], prev_pc); } else { // first chunk, check program entry - assert_eq!(vm_proof.pi_evals[INIT_PC_IDX], E::from_canonical_u32(self.vk.entry_pc)); + assert_eq!(vm_proof.pi_evals[INIT_PC_IDX], E::from_u32(self.vk.entry_pc)); } let end_pc = vm_proof.pi_evals[END_PC_IDX]; @@ -191,7 +192,7 @@ impl> ZKVMVerifier // check shard id assert_eq!( vm_proof.raw_pi[SHARD_ID_IDX], - vec![E::BaseField::from_canonical_usize(shard_id)] + vec![E::BaseField::from_usize(shard_id)] ); // verify constant poly(s) evaluation result match @@ -223,10 +224,10 @@ impl> ZKVMVerifier // write (circuit_idx, num_instance) to transcript for (circuit_idx, proofs) in vm_proof.chip_proofs.iter() { - transcript.append_field_element(&E::BaseField::from_canonical_u32(*circuit_idx as u32)); + transcript.append_field_element(&E::BaseField::from_u32(*circuit_idx as u32)); // length of proof.num_instances will be constrained in verify_chip_proof for num_instance in proofs.iter().flat_map(|proof| &proof.num_instances) { - transcript.append_field_element(&E::BaseField::from_canonical_usize(*num_instance)); + transcript.append_field_element(&E::BaseField::from_usize(*num_instance)); } } @@ -340,7 +341,7 @@ impl> ZKVMVerifier }) .sum::(); - transcript.append_field_element(&E::BaseField::from_canonical_u64(*index as u64)); + transcript.append_field_element(&E::BaseField::from_u64(*index as u64)); if circuit_vk.get_cs().is_with_lk_table() { logup_sum -= chip_logup_sum; } else { @@ -391,8 +392,7 @@ impl> ZKVMVerifier shard_ec_sum = shard_ec_sum + chip_shard_ec_sum; } } - logup_sum -= E::from_canonical_u64(dummy_table_item_multiplicity as u64) - * dummy_table_item.inverse(); + logup_sum -= E::from_u64(dummy_table_item_multiplicity as u64) * dummy_table_item.inverse(); #[cfg(debug_assertions)] { diff --git a/ceno_zkvm/src/state.rs b/ceno_zkvm/src/state.rs index caf079a6d..5db6ffa9c 100644 --- a/ceno_zkvm/src/state.rs +++ b/ceno_zkvm/src/state.rs @@ -5,7 +5,7 @@ use crate::{ structs::RAMType, }; use multilinear_extensions::{Expression, ToExpr}; -use p3::field::FieldAlgebra; +use p3::field::PrimeCharacteristicRing; pub trait StateCircuit { fn initial_global_state( @@ -23,7 +23,7 @@ impl StateCircuit for GlobalState { circuit_builder: &mut crate::circuit_builder::CircuitBuilder, ) -> Result, ZKVMError> { let states: Vec> = vec![ - E::BaseField::from_canonical_u64(RAMType::GlobalState as u64).expr(), + E::BaseField::from_u64(RAMType::GlobalState as u64).expr(), circuit_builder.query_init_pc()?.expr(), circuit_builder.query_init_cycle()?.expr(), ]; @@ -35,7 +35,7 @@ impl StateCircuit for GlobalState { circuit_builder: &mut crate::circuit_builder::CircuitBuilder, ) -> Result, ZKVMError> { let states: Vec> = vec![ - E::BaseField::from_canonical_u64(RAMType::GlobalState as u64).expr(), + E::BaseField::from_u64(RAMType::GlobalState as u64).expr(), circuit_builder.query_end_pc()?.expr(), circuit_builder.query_end_cycle()?.expr(), ]; diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 3894828d9..3cc811f5e 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -13,7 +13,8 @@ use ff_ext::{ExtensionField, FieldInto, SmallField}; use gkr_iop::utils::i64_to_base; use itertools::Itertools; use multilinear_extensions::{Expression, Fixed, ToExpr, WitIn}; -use p3::field::FieldAlgebra; + +use p3::field::PrimeCharacteristicRing; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; use std::{collections::HashMap, marker::PhantomData}; use witness::{ @@ -110,10 +111,7 @@ impl InsnRecord { pub fn imm_internal(insn: &Instruction) -> (i64, F) { match (insn.kind, InsnFormat::from(insn.kind)) { // logic imm - (XORI | ORI | ANDI, _) => ( - insn.imm as i16 as i64, - F::from_canonical_u16(insn.imm as u16), - ), + (XORI | ORI | ANDI, _) => (insn.imm as i16 as i64, F::from_u16(insn.imm as u16)), // for imm operate with program counter => convert to field value (_, B | J) => (insn.imm as i64, i64_to_base(insn.imm as i64)), // AUIPC @@ -121,22 +119,16 @@ impl InsnRecord { // riv32 u type lower 12 bits are 0 // take all except for least significant limb (8 bit) (insn.imm as u32 >> 8) as i64, - F::from_wrapped_u32(insn.imm as u32 >> 8), + F::from_u32(insn.imm as u32 >> 8), ), // U type (_, U) => ( (insn.imm as u32 >> 12) as i64, - F::from_wrapped_u32(insn.imm as u32 >> 12), - ), - (JALR, _) => ( - insn.imm as i16 as i64, - F::from_canonical_u16(insn.imm as i16 as u16), + F::from_u32(insn.imm as u32 >> 12), ), + (JALR, _) => (insn.imm as i16 as i64, F::from_u16(insn.imm as i16 as u16)), // for default imm to operate with register value - _ => ( - insn.imm as i16 as i64, - F::from_canonical_u16(insn.imm as i16 as u16), - ), + _ => (insn.imm as i16 as i64, F::from_u16(insn.imm as i16 as u16)), } } @@ -146,7 +138,7 @@ impl InsnRecord { // logic imm (XORI | ORI | ANDI, _) => ( (insn.imm >> LIMB_BITS) as i16 as i64, - F::from_canonical_u16((insn.imm >> LIMB_BITS) as u16), + F::from_u16((insn.imm >> LIMB_BITS) as u16), ), // Unsigned view. (_, R | U) => (false as i64, F::from_bool(false)), @@ -296,11 +288,7 @@ impl TableCircuit for ProgramTableCircuit { .zip_eq(structural_witness.par_rows_mut()) .zip(prog_mlt) .for_each(|((row, structural_row), mlt)| { - set_val!( - row, - config.mlt, - E::BaseField::from_canonical_u64(mlt as u64) - ); + set_val!(row, config.mlt, E::BaseField::from_u64(mlt as u64)); *structural_row.last_mut().unwrap() = E::BaseField::ONE; }); diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 7e5f1293d..9e66c4612 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -696,6 +696,7 @@ impl LocalFinalRAMTableConfig { #[cfg(test)] mod tests { + use p3::field::PrimeCharacteristicRing; use std::iter::successors; use crate::{ @@ -711,7 +712,7 @@ mod tests { use gkr_iop::RAMType; use itertools::Itertools; use multilinear_extensions::mle::MultilinearExtension; - use p3::{field::FieldAlgebra, goldilocks::Goldilocks as F}; + use p3::goldilocks::Goldilocks as F; use witness::next_pow2_instance_padding; #[test] @@ -760,7 +761,7 @@ mod tests { structural_witness.to_mles()[addr_column].clone(); // Expect addresses to proceed consecutively inside the padding as well let expected = successors(Some(addr_padded_view.get_base_field_vec()[0]), |idx| { - Some(*idx + F::from_canonical_u64(WORD_SIZE as u64)) + Some(*idx + F::from_u64(WORD_SIZE as u64)) }) .take(next_pow2_instance_padding( structural_witness.num_instances(), diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs index a95664085..1a1dd0cbd 100644 --- a/ceno_zkvm/src/tables/range/range_impl.rs +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -74,14 +74,14 @@ impl DynamicRangeTableConfig { } let range_content = std::iter::once(F::ZERO) - .chain((0..=max_bits).flat_map(|i| (0..(1 << i)).map(|j| F::from_canonical_usize(j)))) + .chain((0..=max_bits).flat_map(|i| (0..(1 << i)).map(|j| F::from_usize(j)))) + .collect::>(); + let bits_content = std::iter::once(F::ZERO) + .chain( + (0..=max_bits) + .flat_map(|i| std::iter::repeat_n(i, 1 << i).map(|j| F::from_usize(j))), + ) .collect::>(); - let bits_content = - std::iter::once(F::ZERO) - .chain((0..=max_bits).flat_map(|i| { - std::iter::repeat_n(i, 1 << i).map(|j| F::from_canonical_usize(j)) - })) - .collect::>(); witness .par_rows_mut() @@ -90,7 +90,7 @@ impl DynamicRangeTableConfig { .zip(range_content.par_iter()) .zip(bits_content.par_iter()) .for_each(|((((row, structural_row), mlt), i), b)| { - set_val!(row, self.mlt, F::from_canonical_u64(*mlt as u64)); + set_val!(row, self.mlt, F::from_u64(*mlt as u64)); set_val!(structural_row, self.range, i); set_val!(structural_row, self.bits, b); *structural_row.last_mut().unwrap() = F::ONE; @@ -181,9 +181,9 @@ impl DoubleRangeTableConfig { .for_each(|((row, structural_row), (idx, mlt))| { let a = idx >> self.range_a_bits; let b = idx & ((1 << self.range_a_bits) - 1); - set_val!(row, self.mlt, F::from_canonical_u64(*mlt as u64)); - set_val!(structural_row, self.range_a, F::from_canonical_usize(a)); - set_val!(structural_row, self.range_b, F::from_canonical_usize(b)); + set_val!(row, self.mlt, F::from_u64(*mlt as u64)); + set_val!(structural_row, self.range_a, F::from_usize(a)); + set_val!(structural_row, self.range_b, F::from_usize(b)); *structural_row.last_mut().unwrap() = F::ONE; }); diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 23897fce8..36ef342da 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -24,7 +24,7 @@ use gkr_iop::{ use itertools::{Itertools, chain}; use multilinear_extensions::{Expression, ToExpr, WitIn, util::max_usable_threads}; use p3::{ - field::{Field, FieldAlgebra}, + field::{Field, PrimeCharacteristicRing as FieldAlgebra}, matrix::{Matrix, dense::RowMajorMatrix}, symmetric::Permutation, }; @@ -108,13 +108,13 @@ impl ShardRamRecord { ) -> ECPoint { let mut nonce = 0; let mut input = vec![ - E::BaseField::from_canonical_u32(self.addr), - E::BaseField::from_canonical_u32(self.ram_type as u32), - E::BaseField::from_canonical_u32(self.value & 0xFFFF), // lower 16 bits - E::BaseField::from_canonical_u32((self.value >> 16) & 0xFFFF), // higher 16 bits - E::BaseField::from_canonical_u64(self.shard), - E::BaseField::from_canonical_u64(self.global_clk), - E::BaseField::from_canonical_u32(nonce), + E::BaseField::from_u32(self.addr), + E::BaseField::from_u32(self.ram_type as u32), + E::BaseField::from_u32(self.value & 0xFFFF), // lower 16 bits + E::BaseField::from_u32((self.value >> 16) & 0xFFFF), // higher 16 bits + E::BaseField::from_u64(self.shard), + E::BaseField::from_u64(self.global_clk), + E::BaseField::from_u32(nonce), E::BaseField::ZERO, E::BaseField::ZERO, E::BaseField::ZERO, @@ -148,7 +148,7 @@ impl ShardRamRecord { } else { // try again with different nonce nonce += 1; - input[6] = E::BaseField::from_canonical_u32(nonce); + input[6] = E::BaseField::from_u32(nonce); } } } @@ -349,19 +349,19 @@ impl ShardRamCircuit { instance[witin.id as usize] = *fe; }); - let ram_type = E::BaseField::from_canonical_u32(record.ram_type as u32); + let ram_type = E::BaseField::from_u32(record.ram_type as u32); let mut input = [E::BaseField::ZERO; 16]; let k = UINT_LIMBS; - input[0] = E::BaseField::from_canonical_u32(record.addr); + input[0] = E::BaseField::from_u32(record.addr); input[1] = ram_type; input[2..(k + 2)] .iter_mut() .zip(value.as_u16_limbs().iter()) - .for_each(|(i, v)| *i = E::BaseField::from_canonical_u16(*v)); - input[2 + k] = E::BaseField::from_canonical_u64(record.shard); - input[2 + k + 1] = E::BaseField::from_canonical_u64(record.global_clk); - input[2 + k + 2] = E::BaseField::from_canonical_u32(*nonce); + .for_each(|(i, v)| *i = E::BaseField::from_u16(*v)); + input[2 + k] = E::BaseField::from_u64(record.shard); + input[2 + k + 1] = E::BaseField::from_u64(record.global_clk); + input[2 + k + 2] = E::BaseField::from_u32(*nonce); config .perm_config diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index b1ee7c07f..2d73bba73 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -1,3 +1,4 @@ +use p3::field::PrimeCharacteristicRing; mod arithmetic; pub mod constants; mod logic; @@ -16,7 +17,7 @@ use ff_ext::{ExtensionField, SmallField}; use gkr_iop::error::CircuitBuilderError; use itertools::{Itertools, enumerate}; use multilinear_extensions::{Expression, ToExpr, WitIn, util::ceil_log2}; -use p3::field::FieldAlgebra; + use std::{ borrow::Cow, mem::{self}, @@ -187,7 +188,7 @@ impl UIntLimbs { limbs .into_iter() .take(Self::NUM_LIMBS) - .map(|limb| E::BaseField::from_canonical_u64(limb.into()).expr()) + .map(|limb| E::BaseField::from_u64(limb.into()).expr()) .collect::>>(), ), carries: None, @@ -264,7 +265,7 @@ impl UIntLimbs { for (wire, limb) in wires.iter().zip( limbs_values .iter() - .map(|v| E::BaseField::from_canonical_u64(*v as u64)) + .map(|v| E::BaseField::from_u64(*v as u64)) .chain(std::iter::repeat(E::BaseField::ZERO)), ) { instance[wire.id as usize] = limb; @@ -290,7 +291,7 @@ impl UIntLimbs { for (wire, carry) in carries.iter().zip( carry_values .iter() - .map(|v| E::BaseField::from_canonical_u64(Into::::into(*v))) + .map(|v| E::BaseField::from_u64(Into::::into(*v))) .chain(std::iter::repeat(E::BaseField::ZERO)), ) { instance[wire.id as usize] = carry; @@ -483,7 +484,7 @@ impl UIntLimbs { pub fn counter_vector(size: usize) -> Vec> { let num_vars = ceil_log2(size); let number_of_limbs = num_vars.div_ceil(C); - let cell_modulo = F::from_canonical_u64(1 << C); + let cell_modulo = F::from_u64(1 << C); let mut res = vec![vec![F::ZERO; number_of_limbs]]; @@ -754,7 +755,7 @@ impl<'a, T: Into + From + Copy + Default> Value<'a, T> { pub fn u16_fields(&self) -> Vec { self.limbs .iter() - .map(|v| F::from_canonical_u64(*v as u64)) + .map(|v| F::from_u64(*v as u64)) .collect_vec() } diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index d07a66e46..5e90b88e5 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -8,7 +8,7 @@ use crate::{ instructions::riscv::config::IsEqualConfig, }; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::FieldAlgebra; +use p3::field::PrimeCharacteristicRing; impl UIntLimbs { const POW_OF_C: usize = 2_usize.pow(C as u32); @@ -86,9 +86,7 @@ impl UIntLimbs { // convert Expression::Constant to limbs let b_limbs = (0..Self::NUM_LIMBS) - .map(|i| { - E::BaseField::from_canonical_u64((b >> (C * i)) & Self::LIMB_BIT_MASK).expr() - }) + .map(|i| E::BaseField::from_u64((b >> (C * i)) & Self::LIMB_BIT_MASK).expr()) .collect_vec(); self.internal_add(cb, &b_limbs, with_overflow) @@ -319,8 +317,7 @@ mod tests { use ff_ext::{ExtensionField, GoldilocksExt2}; use itertools::Itertools; use multilinear_extensions::{ToExpr, utils::eval_by_expr}; - use p3::field::FieldAlgebra; - + use p3::field::PrimeCharacteristicRing as _; type E = GoldilocksExt2; #[test] fn test_add64_16_no_carries() { @@ -448,7 +445,7 @@ mod tests { let challenges = vec![E::ONE; witness_values.len()]; let uint_a = UIntLimbs::::new(|| "uint_a", &mut cb).unwrap(); let uint_c = if let Some(const_b) = const_b { - let const_b = E::BaseField::from_canonical_u64(const_b).expr(); + let const_b = E::BaseField::from_u64(const_b).expr(); uint_a .add_const(|| "uint_c", &mut cb, const_b, overflow) .unwrap() @@ -505,13 +502,10 @@ mod tests { let wit: Vec = witness_values .iter() .cloned() - .map(E::from_canonical_u64) + .map(E::from_u64) .collect_vec(); uint_c.expr().iter().zip(result).for_each(|(c, ret)| { - assert_eq!( - eval_by_expr(&wit, &[], &challenges, c), - E::from_canonical_u64(ret) - ); + assert_eq!(eval_by_expr(&wit, &[], &challenges, c), E::from_u64(ret)); }); // overflow @@ -684,13 +678,10 @@ mod tests { let wit: Vec = witness_values .iter() .cloned() - .map(E::from_canonical_u64) + .map(E::from_u64) .collect_vec(); uint_c.expr().iter().zip(result).for_each(|(c, ret)| { - assert_eq!( - eval_by_expr(&wit, &[], &challenges, c), - E::from_canonical_u64(ret) - ); + assert_eq!(eval_by_expr(&wit, &[], &challenges, c), E::from_u64(ret)); }); // overflow @@ -716,8 +707,7 @@ mod tests { use ff_ext::{ExtensionField, GoldilocksExt2}; use itertools::Itertools; use multilinear_extensions::mle::{ArcMultilinearExtension, MultilinearExtension}; - use p3::field::FieldAlgebra; - + use p3::field::PrimeCharacteristicRing; type E = GoldilocksExt2; // 18446744069414584321 trait ValueToArcMle { @@ -732,7 +722,7 @@ mod tests { let mle: ArcMultilinearExtension = MultilinearExtension::from_evaluation_vec_smart( 0, - vec![E::BaseField::from_canonical_u64(*a)], + vec![E::BaseField::from_u64(*a)], ) .into(); mle diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 276622cf7..855f00efb 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -8,7 +8,7 @@ use std::{ use ff_ext::ExtensionField; pub use gkr_iop::utils::i64_to_base; use itertools::Itertools; -use p3::field::Field; +use p3::field::{Field, PrimeCharacteristicRing}; #[cfg(feature = "u16limb_circuit")] use crate::instructions::riscv::constants::UINT_LIMBS; @@ -17,8 +17,6 @@ use multilinear_extensions::Expression; #[cfg(feature = "u16limb_circuit")] use multilinear_extensions::ToExpr; #[cfg(feature = "u16limb_circuit")] -use p3::field::FieldAlgebra; - pub fn split_to_u8>(value: u32) -> Vec { (0..(u32::BITS / 8)) .scan(value, |acc, _| { @@ -132,10 +130,7 @@ pub fn imm_sign_extend_circuit( if !require_signed { [imm, E::BaseField::ZERO.expr()] } else { - [ - imm, - is_signed * E::BaseField::from_canonical_u16(0xffff).expr(), - ] + [imm, is_signed * E::BaseField::from_u16(0xffff).expr()] } } diff --git a/clippy.toml b/clippy.toml index 41690a1eb..f4b68d53f 100644 --- a/clippy.toml +++ b/clippy.toml @@ -26,4 +26,8 @@ allowed-duplicate-crates = [ "p256", "primeorder", "ecdsa", + "rand", + "rand_chacha", + "rand_core", + "spin", ] diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index f93ef4335..abce781b6 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -20,6 +20,7 @@ mpcs.workspace = true multilinear_extensions.workspace = true once_cell.workspace = true p3.workspace = true +p3-field.workspace = true rand.workspace = true rayon.workspace = true serde.workspace = true diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index d84cb4d55..f14a041ee 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -7,13 +7,12 @@ use serde::de::DeserializeOwned; use std::{collections::HashMap, iter::once, marker::PhantomData}; use ff_ext::ExtensionField; +use p3_field::PrimeCharacteristicRing; use crate::{ RAMType, error::CircuitBuilderError, gkr::layer::ROTATION_OPENING_COUNT, selector::SelectorType, tables::LookupTable, }; -use p3::field::FieldAlgebra; - pub mod ram; #[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)] @@ -295,7 +294,7 @@ impl ConstraintSystem { record: Vec>, ) -> Result<(), CircuitBuilderError> { let rlc_record = self.rlc_chip_record( - std::iter::once(E::BaseField::from_canonical_u64(rom_type as u64).expr()) + std::iter::once(E::BaseField::from_u64(rom_type as u64).expr()) .chain(record.clone()) .collect(), ); @@ -1015,7 +1014,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { cb.lk_record( name_fn, LookupTable::Dynamic, - vec![expr, E::BaseField::from_canonical_usize(max_bits).expr()], + vec![expr, E::BaseField::from_usize(max_bits).expr()], ) }, ) @@ -1037,7 +1036,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { cb.lk_record( name_fn, LookupTable::Dynamic, - vec![expr, E::BaseField::from_canonical_usize(8).expr()], + vec![expr, E::BaseField::from_usize(8).expr()], ) }, ) @@ -1437,7 +1436,7 @@ pub fn expansion_expr( .fold((0, E::BaseField::ZERO.expr()), |acc, (sz, felt)| { ( acc.0 + sz, - acc.1 * E::BaseField::from_canonical_u64(1 << sz).expr() + felt.expr(), + acc.1 * E::BaseField::from_u64(1 << sz).expr() + felt.expr(), ) }); diff --git a/gkr_iop/src/gadgets/is_lt.rs b/gkr_iop/src/gadgets/is_lt.rs index d3f4a2ac6..17f871655 100644 --- a/gkr_iop/src/gadgets/is_lt.rs +++ b/gkr_iop/src/gadgets/is_lt.rs @@ -2,7 +2,7 @@ use crate::utils::i64_to_base; use ff_ext::{ExtensionField, FieldInto, SmallField}; use itertools::izip; use multilinear_extensions::{Expression, ToExpr, WitIn, power_sequence}; -use p3::field::Field; +use p3_field::Field; use std::fmt::Display; use witness::set_val; @@ -220,13 +220,7 @@ impl InnerLtConfig { lhs: u64, rhs: u64, ) -> Result<(), CircuitBuilderError> { - self.assign_instance_field( - instance, - lkm, - F::from_canonical_u64(lhs), - F::from_canonical_u64(rhs), - lhs < rhs, - ) + self.assign_instance_field(instance, lkm, F::from_u64(lhs), F::from_u64(rhs), lhs < rhs) } /// Assign instance values to this configuration where the ordering is diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index a67862df7..29cf65fa5 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -7,7 +7,7 @@ use multilinear_extensions::{ mle::{Point, PointAndEval}, monomial::Term, }; -use p3::field::FieldAlgebra; +use p3_field::PrimeCharacteristicRing; use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ops::Neg, sync::Arc, vec::IntoIter}; diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index cf1da6df2..5ee8f58cb 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -9,7 +9,7 @@ use multilinear_extensions::{ utils::{eval_by_expr, eval_by_expr_with_instance, expr_convert_to_witins}, virtual_poly::VPAuxInfo, }; -use p3::field::{FieldAlgebra, dot_product}; +use p3_field::{PrimeCharacteristicRing, dot_product}; use smallvec::SmallVec; use std::{cmp::Ordering, collections::BTreeMap, marker::PhantomData, ops::Neg}; use sumcheck::{ diff --git a/gkr_iop/src/gkr/layer_constraint_system.rs b/gkr_iop/src/gkr/layer_constraint_system.rs index 6bbb1cdc7..52234516b 100644 --- a/gkr_iop/src/gkr/layer_constraint_system.rs +++ b/gkr_iop/src/gkr/layer_constraint_system.rs @@ -10,8 +10,7 @@ use crate::{ use ff_ext::ExtensionField; use itertools::{Itertools, chain, izip}; use multilinear_extensions::{Expression, Fixed, ToExpr, WitnessId, rlc_chip_record}; -use p3::field::FieldAlgebra; - +use p3_field::PrimeCharacteristicRing; #[derive(Clone, Debug, Default)] pub struct RotationParams { pub rotation_eqs: Option<[Expression; ROTATION_OPENING_COUNT]>, @@ -102,7 +101,7 @@ impl LayerConstraintSystem { pub fn lookup_and8(&mut self, a: Expression, b: Expression, c: Expression) { let rlc_record = rlc_chip_record( vec![ - E::BaseField::from_canonical_u64(LookupTable::And as u64).expr(), + E::BaseField::from_u64(LookupTable::And as u64).expr(), a, b, c, @@ -116,7 +115,7 @@ impl LayerConstraintSystem { pub fn lookup_xor8(&mut self, a: Expression, b: Expression, c: Expression) { let rlc_record = rlc_chip_record( vec![ - E::BaseField::from_canonical_u64(LookupTable::Xor as u64).expr(), + E::BaseField::from_u64(LookupTable::Xor as u64).expr(), a, b, c, @@ -135,7 +134,7 @@ impl LayerConstraintSystem { let rlc_record = rlc_chip_record( vec![ // TODO: layer constrain system is deprecated - E::BaseField::from_canonical_u64(LookupTable::Dynamic as u64).expr(), + E::BaseField::from_u64(LookupTable::Dynamic as u64).expr(), value.clone(), ], self.alpha.clone(), @@ -145,8 +144,8 @@ impl LayerConstraintSystem { if size < 16 { let rlc_record = rlc_chip_record( vec![ - E::BaseField::from_canonical_u64(LookupTable::Dynamic as u64).expr(), - value * E::BaseField::from_canonical_u64(1 << (16 - size)).expr(), + E::BaseField::from_u64(LookupTable::Dynamic as u64).expr(), + value * E::BaseField::from_u64(1 << (16 - size)).expr(), ], self.alpha.clone(), self.beta.clone(), @@ -470,7 +469,7 @@ pub fn expansion_expr( .fold((0, E::BaseField::ZERO.expr()), |acc, (sz, felt)| { ( acc.0 + sz, - acc.1 * E::BaseField::from_canonical_u64(1 << sz).expr() + felt.expr(), + acc.1 * E::BaseField::from_u64(1 << sz).expr() + felt.expr(), ) }); diff --git a/gkr_iop/src/selector.rs b/gkr_iop/src/selector.rs index 857dbd588..543c74f9e 100644 --- a/gkr_iop/src/selector.rs +++ b/gkr_iop/src/selector.rs @@ -10,7 +10,7 @@ use multilinear_extensions::{ util::ceil_log2, virtual_poly::{build_eq_x_r_vec, eq_eval}, }; -use p3::field::FieldAlgebra; +use p3_field::PrimeCharacteristicRing; use rayon::{ iter::{IntoParallelIterator, ParallelIterator}, slice::ParallelSliceMut, @@ -386,7 +386,7 @@ mod tests { use multilinear_extensions::{ StructuralWitIn, ToExpr, util::ceil_log2, virtual_poly::build_eq_x_r_vec, }; - use p3::field::FieldAlgebra; + use p3_field::PrimeCharacteristicRing; use rand::thread_rng; use crate::selector::{SelectorContext, SelectorType}; diff --git a/gkr_iop/src/utils.rs b/gkr_iop/src/utils.rs index e1c8d7453..c4565a737 100644 --- a/gkr_iop/src/utils.rs +++ b/gkr_iop/src/utils.rs @@ -8,7 +8,7 @@ use multilinear_extensions::{ util::ceil_log2, virtual_poly::{build_eq_x_r_vec, eq_eval}, }; -use p3::field::FieldAlgebra; +use p3_field::PrimeCharacteristicRing; use rayon::{ iter::{IndexedParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator}, slice::{ParallelSlice, ParallelSliceMut}, @@ -105,9 +105,9 @@ pub fn rotation_selector_eval( pub fn i64_to_base(x: i64) -> F { if x >= 0 { - F::from_canonical_u64(x as u64) + F::from_u64(x as u64) } else { - -F::from_canonical_u64((-x) as u64) + -F::from_u64((-x) as u64) } } @@ -182,7 +182,7 @@ pub fn eq_eval_less_or_equal_than(max_idx: usize, a: &[E], b: let mut running_product = vec![E::ZERO; b.len() + 1]; running_product[b.len()] = E::ONE; for i in (0..b.len()).rev() { - let bit = E::from_canonical_u64(((max_idx >> i) & 1) as u64); + let bit = E::from_u64(((max_idx >> i) & 1) as u64); running_product[i] = running_product[i + 1] * (a[i] * b[i] * bit + (E::ONE - a[i]) * (E::ONE - b[i]) * (E::ONE - bit)); } @@ -220,12 +220,12 @@ pub fn eval_wellform_address_vec( r: &[E], descending: bool, ) -> E { - let (offset, scaled) = (E::from_canonical_u64(offset), E::from_canonical_u64(scaled)); + let (offset, scaled) = (E::from_u64(offset), E::from_u64(scaled)); let tmp = scaled * r.iter() .scan(E::ONE, |state, x| { let result = *x * *state; - *state *= E::from_canonical_u64(2); // Update the state for the next power of 2 + *state *= E::from_u64(2); // Update the state for the next power of 2 Some(result) }) .sum::(); @@ -285,7 +285,7 @@ pub fn eval_stacked_constant_vec(r: &[E]) -> E { let mut res = E::ZERO; for (i, r) in r.iter().enumerate().skip(1) { - res = res * (E::ONE - *r) + E::from_canonical_usize(i) * *r; + res = res * (E::ONE - *r) + E::from_usize(i) * *r; } res } @@ -312,7 +312,8 @@ pub fn eval_outer_repeated_incremental_vec(k: u64, r: &[E]) - #[cfg(test)] mod tests { use ff_ext::{FromUniformBytes, GoldilocksExt2}; - use p3::{field::FieldAlgebra, goldilocks::Goldilocks}; + use p3::goldilocks::Goldilocks; + use p3_field::PrimeCharacteristicRing; use std::{iter, sync::Arc}; type E = GoldilocksExt2; @@ -334,11 +335,7 @@ mod tests { fn test_rotation_next_base_mle_eval() { type E = GoldilocksExt2; let bh = BooleanHypercube::new(5); - let poly = make_mle::( - (0..128u64) - .map(Goldilocks::from_canonical_u64) - .collect_vec(), - ); + let poly = make_mle::((0..128u64).map(Goldilocks::from_u64).collect_vec()); let rotated = rotation_next_base_mle(&bh, &poly, 5); let mut rng = rand::thread_rng(); @@ -360,15 +357,15 @@ mod tests { #[test] fn test_eval_stacked_wellform_address_vec() { let r = [ - E::from_canonical_usize(123), - E::from_canonical_usize(456), - E::from_canonical_usize(789), - E::from_canonical_usize(3210), - E::from_canonical_usize(9876), + E::from_usize(123), + E::from_usize(456), + E::from_usize(789), + E::from_usize(3210), + E::from_usize(9876), ]; for n in 0..r.len() { let v = iter::once(E::ZERO) - .chain((0..=n).flat_map(|i| (0..(1 << i)).map(E::from_canonical_usize))) + .chain((0..=n).flat_map(|i| (0..(1 << i)).map(E::from_usize))) .collect::>(); let poly = MultilinearExtension::from_evaluations_ext_vec(n + 1, v); assert_eq!( @@ -381,15 +378,15 @@ mod tests { #[test] fn test_eval_stacked_constant_vec() { let r = [ - E::from_canonical_usize(123), - E::from_canonical_usize(456), - E::from_canonical_usize(789), - E::from_canonical_usize(3210), - E::from_canonical_usize(9876), + E::from_usize(123), + E::from_usize(456), + E::from_usize(789), + E::from_usize(3210), + E::from_usize(9876), ]; for n in 0..r.len() { let v = iter::once(E::ZERO) - .chain((0..=n).flat_map(|i| iter::repeat_n(i, 1 << i).map(E::from_canonical_usize))) + .chain((0..=n).flat_map(|i| iter::repeat_n(i, 1 << i).map(E::from_usize))) .collect::>(); let poly = MultilinearExtension::from_evaluations_ext_vec(n + 1, v); assert_eq!( @@ -402,16 +399,16 @@ mod tests { #[test] fn test_eval_inner_repeating_incremental_vec() { let r = [ - E::from_canonical_usize(123), - E::from_canonical_usize(456), - E::from_canonical_usize(789), - E::from_canonical_usize(3210), - E::from_canonical_usize(9876), + E::from_usize(123), + E::from_usize(456), + E::from_usize(789), + E::from_usize(3210), + E::from_usize(9876), ]; for n in 1..=r.len() { for k in 0..=n { let v = (0..(1 << (n - k))) - .flat_map(|i| iter::repeat_n(E::from_canonical_usize(i), 1 << k)) + .flat_map(|i| iter::repeat_n(E::from_usize(i), 1 << k)) .collect::>(); let poly = MultilinearExtension::from_evaluations_ext_vec(n, v); assert_eq!( @@ -425,16 +422,16 @@ mod tests { #[test] fn test_eval_outer_repeating_incremental_vec() { let r = [ - E::from_canonical_usize(123), - E::from_canonical_usize(456), - E::from_canonical_usize(789), - E::from_canonical_usize(3210), - E::from_canonical_usize(9876), + E::from_usize(123), + E::from_usize(456), + E::from_usize(789), + E::from_usize(3210), + E::from_usize(9876), ]; for n in 1..=r.len() { for k in 0..=n { let v = iter::repeat_n(0, 1 << (n - k)) - .flat_map(|_| (0..(1 << k)).map(E::from_canonical_usize)) + .flat_map(|_| (0..(1 << k)).map(E::from_usize)) .collect::>(); let poly = MultilinearExtension::from_evaluations_ext_vec(n, v); assert_eq!(