diff --git a/dot_ring/ring_proof/columns/columns.py b/dot_ring/ring_proof/columns/columns.py index 9534c06..a669454 100644 --- a/dot_ring/ring_proof/columns/columns.py +++ b/dot_ring/ring_proof/columns/columns.py @@ -2,10 +2,11 @@ import json import os +import secrets from dataclasses import dataclass from typing import cast -from dot_ring.ring_proof.constants import DEFAULT_SIZE, MAX_RING_SIZE, OMEGAS, S_PRIME, SeedPoint +from dot_ring.ring_proof.constants import DEFAULT_SIZE, MAX_RING_SIZE, OMEGAS, S_PRIME, ZK_ROWS, SeedPoint from dot_ring.ring_proof.curve.bandersnatch import TwistedEdwardCurve as TE from dot_ring.ring_proof.helpers import Helpers as H from dot_ring.ring_proof.params import RingProofParams @@ -27,12 +28,30 @@ class Column: commitment: G1Point | None = None size: int = DEFAULT_SIZE - def interpolate(self, domain_omega: int = OMEGAS[DEFAULT_SIZE], prime: int = S_PRIME) -> None: - """Fill `self.coeffs` from `self.evals` using FFT interpolation.""" + def interpolate( + self, + domain_omega: int = OMEGAS[DEFAULT_SIZE], + prime: int = S_PRIME, + hidden: bool = False, + test_vectors: bool = False, + ) -> None: + """Fill `self.coeffs` from `self.evals` using FFT interpolation. + + When ``hidden=True`` and ``test_vectors=False``, the last + ``ZK_ROWS`` positions are filled with cryptographically random + field elements (random blinding) to preserve zero-knowledge. + """ if self.coeffs is None: - if len(self.evals) > self.size: - raise ValueError(f"{self.name} evals length {len(self.evals)} exceeds column size {self.size}") - self.evals += [0] * (self.size - len(self.evals)) + if hidden and not test_vectors: + capacity = self.size - ZK_ROWS + if len(self.evals) > capacity: + raise ValueError(f"{self.name} evals length {len(self.evals)} exceeds capacity {capacity} (size={self.size}, ZK_ROWS={ZK_ROWS})") + self.evals += [0] * (capacity - len(self.evals)) + self.evals += [secrets.randbelow(prime) for _ in range(ZK_ROWS)] + else: + if len(self.evals) > self.size: + raise ValueError(f"{self.name} evals length {len(self.evals)} exceeds column size {self.size}") + self.evals += [0] * (self.size - len(self.evals)) self.coeffs = poly_interpolate_fft(self.evals, domain_omega, prime) def commit(self) -> None: @@ -53,6 +72,7 @@ class WitnessColumnBuilder: prime: int = S_PRIME max_ring_size: int = MAX_RING_SIZE padding_rows: int = 4 + test_vectors: bool = False @classmethod def from_params( @@ -73,6 +93,7 @@ def from_params( prime=params.prime, max_ring_size=params.max_ring_size, padding_rows=params.padding_rows, + test_vectors=params.test_vectors, ) def _bits_vector(self) -> list[int]: @@ -120,7 +141,7 @@ def build(self) -> tuple[Column, Column, Column, Column]: Column("accip", acc_ip, size=self.size), ] for col in columns: - col.interpolate(self.omega, self.prime) + col.interpolate(self.omega, self.prime, hidden=True, test_vectors=self.test_vectors) col.commit() return (columns[0], columns[1], columns[2], columns[3]) diff --git a/dot_ring/ring_proof/constants.py b/dot_ring/ring_proof/constants.py index 116a732..04ecf39 100644 --- a/dot_ring/ring_proof/constants.py +++ b/dot_ring/ring_proof/constants.py @@ -62,6 +62,8 @@ MAX_RING_SIZE: int = 255 # Upper bound enforced by the constraint system +ZK_ROWS: int = 3 # Number of random blinding rows for zero-knowledge (matches Rust ZK_ROWS) + __all__ = [ "S_PRIME", @@ -78,4 +80,5 @@ "D_512", "D_2048", "MAX_RING_SIZE", + "ZK_ROWS", ] diff --git a/dot_ring/ring_proof/params.py b/dot_ring/ring_proof/params.py index 855734d..d1357be 100644 --- a/dot_ring/ring_proof/params.py +++ b/dot_ring/ring_proof/params.py @@ -102,6 +102,7 @@ class RingProofParams: prime: int = S_PRIME base_root: int = OMEGA_2048 base_root_size: int = 2048 + test_vectors: bool = False cv: ClassVar[CurveVariant] = Bandersnatch def __post_init__(self) -> None: @@ -166,6 +167,7 @@ def from_ring_size( prime: int = S_PRIME, base_root: int = OMEGA_2048, base_root_size: int = 2048, + test_vectors: bool = False, ) -> RingProofParams: """ Automatically construct RingProofParams based on ring size. @@ -210,4 +212,5 @@ def from_ring_size( prime=prime, base_root=base_root, base_root_size=base_root_size, + test_vectors=test_vectors, ) diff --git a/tests/test_bandersnatch_ark.py b/tests/test_bandersnatch_ark.py index 55a25f7..41a6d04 100644 --- a/tests/test_bandersnatch_ark.py +++ b/tests/test_bandersnatch_ark.py @@ -121,7 +121,7 @@ def test_ring_proof(): ad = bytes.fromhex(item["ad"]) keys = RingVRF[Bandersnatch].parse_keys(bytes.fromhex(item["ring_pks"])) start = time() - params = RingProofParams() + params = RingProofParams(test_vectors=True) ring = Ring(keys, params) ring_root = RingRoot.from_ring(ring, params) ring_time = time() diff --git a/tests/test_ring_vrf/test_ring_vrf.py b/tests/test_ring_vrf/test_ring_vrf.py index 164c84e..43c8194 100644 --- a/tests/test_ring_vrf/test_ring_vrf.py +++ b/tests/test_ring_vrf/test_ring_vrf.py @@ -24,7 +24,7 @@ def test_ring_proof(): keys = RingVRF[Bandersnatch].parse_keys(bytes.fromhex(item["ring_pks"])) start_time = time.time() - params = RingProofParams() + params = RingProofParams(test_vectors=True) ring = Ring(keys, params) ring_root = RingRoot.from_ring(ring, params) ring_time = time.time() diff --git a/tests/test_vectors.py b/tests/test_vectors.py index dcb30f7..a0c613a 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -271,8 +271,8 @@ def verify_ring_vector(vector: dict[str, Any], curve) -> None: for i in range(0, len(ring_pks_bytes), point_len): ring_pks.append(ring_pks_bytes[i : i + point_len]) - # Construct ring and ring root - params = RingProofParams() + # Construct ring and ring root (test_vectors=True for deterministic proofs) + params = RingProofParams(test_vectors=True) ring = Ring(ring_pks, params) ring_root = RingRoot.from_ring(ring, params) @@ -417,8 +417,8 @@ def test_wrong_ring_root(self): alpha = b"test_input" ad = b"test_ad" - # Construct rings and ring roots - params = RingProofParams() + # Construct rings and ring roots (test_vectors=True for deterministic proofs) + params = RingProofParams(test_vectors=True) ring_obj1 = Ring(ring1, params) ring_root1 = RingRoot.from_ring(ring_obj1, params) ring_obj2 = Ring(ring2, params) @@ -470,3 +470,35 @@ def test_pedersen_deterministic(self): assert proof1.ok.point_to_string() == proof2.ok.point_to_string() assert proof1.s == proof2.s assert proof1.sb == proof2.sb + + def test_ring_nondeterministic(self): + """Ring VRF proofs with default params (test_vectors=False) should be non-deterministic. + + Two proofs from the same inputs should differ due to random ZK-row blinding, + but both must still verify correctly. + """ + sk = bytes.fromhex("0101010101010101010101010101010101010101010101010101010101010101") + pk = RingVRF[Bandersnatch].get_public_key(sk) + + ring_keys = [pk] + for i in range(7): + other_sk = (i + 2).to_bytes(32, "little") + ring_keys.append(RingVRF[Bandersnatch].get_public_key(other_sk)) + + alpha = b"deterministic_test" + ad = b"test_ad" + + # Default params: test_vectors=False → random ZK-row blinding + params = RingProofParams(test_vectors=False) + ring = Ring(ring_keys, params) + ring_root = RingRoot.from_ring(ring, params) + + proof1 = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, ring, ring_root) + proof2 = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, ring, ring_root) + + # Proof bytes should differ due to random blinding + assert proof1.to_bytes() != proof2.to_bytes(), "Ring proofs should be non-deterministic with random ZK-row blinding" + + # Both proofs must still verify + assert proof1.verify(alpha, ad, ring, ring_root), "First proof verification failed" + assert proof2.verify(alpha, ad, ring, ring_root), "Second proof verification failed"