Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions dot_ring/ring_proof/columns/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import json
import os
import secrets
from dataclasses import dataclass
from typing import cast

from dot_ring.ring_proof.constants import DEFAULT_SIZE, MAX_RING_SIZE, OMEGAS, S_PRIME, SeedPoint
from dot_ring.ring_proof.constants import DEFAULT_SIZE, MAX_RING_SIZE, OMEGAS, S_PRIME, ZK_ROWS, SeedPoint
from dot_ring.ring_proof.curve.bandersnatch import TwistedEdwardCurve as TE
from dot_ring.ring_proof.helpers import Helpers as H
from dot_ring.ring_proof.params import RingProofParams
Expand All @@ -27,12 +28,30 @@ class Column:
commitment: G1Point | None = None
size: int = DEFAULT_SIZE

def interpolate(self, domain_omega: int = OMEGAS[DEFAULT_SIZE], prime: int = S_PRIME) -> None:
"""Fill `self.coeffs` from `self.evals` using FFT interpolation."""
def interpolate(
self,
domain_omega: int = OMEGAS[DEFAULT_SIZE],
prime: int = S_PRIME,
hidden: bool = False,
test_vectors: bool = False,
) -> None:
"""Fill `self.coeffs` from `self.evals` using FFT interpolation.

When ``hidden=True`` and ``test_vectors=False``, the last
``ZK_ROWS`` positions are filled with cryptographically random
field elements (random blinding) to preserve zero-knowledge.
"""
if self.coeffs is None:
if len(self.evals) > self.size:
raise ValueError(f"{self.name} evals length {len(self.evals)} exceeds column size {self.size}")
self.evals += [0] * (self.size - len(self.evals))
if hidden and not test_vectors:
capacity = self.size - ZK_ROWS
if len(self.evals) > capacity:
raise ValueError(f"{self.name} evals length {len(self.evals)} exceeds capacity {capacity} (size={self.size}, ZK_ROWS={ZK_ROWS})")
self.evals += [0] * (capacity - len(self.evals))
self.evals += [secrets.randbelow(prime) for _ in range(ZK_ROWS)]
else:
if len(self.evals) > self.size:
raise ValueError(f"{self.name} evals length {len(self.evals)} exceeds column size {self.size}")
self.evals += [0] * (self.size - len(self.evals))
self.coeffs = poly_interpolate_fft(self.evals, domain_omega, prime)

def commit(self) -> None:
Expand All @@ -53,6 +72,7 @@ class WitnessColumnBuilder:
prime: int = S_PRIME
max_ring_size: int = MAX_RING_SIZE
padding_rows: int = 4
test_vectors: bool = False

@classmethod
def from_params(
Expand All @@ -73,6 +93,7 @@ def from_params(
prime=params.prime,
max_ring_size=params.max_ring_size,
padding_rows=params.padding_rows,
test_vectors=params.test_vectors,
)

def _bits_vector(self) -> list[int]:
Expand Down Expand Up @@ -120,7 +141,7 @@ def build(self) -> tuple[Column, Column, Column, Column]:
Column("accip", acc_ip, size=self.size),
]
for col in columns:
col.interpolate(self.omega, self.prime)
col.interpolate(self.omega, self.prime, hidden=True, test_vectors=self.test_vectors)
col.commit()
return (columns[0], columns[1], columns[2], columns[3])

Expand Down
3 changes: 3 additions & 0 deletions dot_ring/ring_proof/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@

MAX_RING_SIZE: int = 255 # Upper bound enforced by the constraint system

ZK_ROWS: int = 3 # Number of random blinding rows for zero-knowledge (matches Rust ZK_ROWS)


__all__ = [
"S_PRIME",
Expand All @@ -78,4 +80,5 @@
"D_512",
"D_2048",
"MAX_RING_SIZE",
"ZK_ROWS",
]
3 changes: 3 additions & 0 deletions dot_ring/ring_proof/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class RingProofParams:
prime: int = S_PRIME
base_root: int = OMEGA_2048
base_root_size: int = 2048
test_vectors: bool = False
cv: ClassVar[CurveVariant] = Bandersnatch

def __post_init__(self) -> None:
Expand Down Expand Up @@ -166,6 +167,7 @@ def from_ring_size(
prime: int = S_PRIME,
base_root: int = OMEGA_2048,
base_root_size: int = 2048,
test_vectors: bool = False,
) -> RingProofParams:
"""
Automatically construct RingProofParams based on ring size.
Expand Down Expand Up @@ -210,4 +212,5 @@ def from_ring_size(
prime=prime,
base_root=base_root,
base_root_size=base_root_size,
test_vectors=test_vectors,
)
2 changes: 1 addition & 1 deletion tests/test_bandersnatch_ark.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_ring_proof():
ad = bytes.fromhex(item["ad"])
keys = RingVRF[Bandersnatch].parse_keys(bytes.fromhex(item["ring_pks"]))
start = time()
params = RingProofParams()
params = RingProofParams(test_vectors=True)
ring = Ring(keys, params)
ring_root = RingRoot.from_ring(ring, params)
ring_time = time()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ring_vrf/test_ring_vrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_ring_proof():
keys = RingVRF[Bandersnatch].parse_keys(bytes.fromhex(item["ring_pks"]))

start_time = time.time()
params = RingProofParams()
params = RingProofParams(test_vectors=True)
ring = Ring(keys, params)
ring_root = RingRoot.from_ring(ring, params)
ring_time = time.time()
Expand Down
40 changes: 36 additions & 4 deletions tests/test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ def verify_ring_vector(vector: dict[str, Any], curve) -> None:
for i in range(0, len(ring_pks_bytes), point_len):
ring_pks.append(ring_pks_bytes[i : i + point_len])

# Construct ring and ring root
params = RingProofParams()
# Construct ring and ring root (test_vectors=True for deterministic proofs)
params = RingProofParams(test_vectors=True)
ring = Ring(ring_pks, params)
ring_root = RingRoot.from_ring(ring, params)

Expand Down Expand Up @@ -417,8 +417,8 @@ def test_wrong_ring_root(self):
alpha = b"test_input"
ad = b"test_ad"

# Construct rings and ring roots
params = RingProofParams()
# Construct rings and ring roots (test_vectors=True for deterministic proofs)
params = RingProofParams(test_vectors=True)
ring_obj1 = Ring(ring1, params)
ring_root1 = RingRoot.from_ring(ring_obj1, params)
ring_obj2 = Ring(ring2, params)
Expand Down Expand Up @@ -470,3 +470,35 @@ def test_pedersen_deterministic(self):
assert proof1.ok.point_to_string() == proof2.ok.point_to_string()
assert proof1.s == proof2.s
assert proof1.sb == proof2.sb

def test_ring_nondeterministic(self):
"""Ring VRF proofs with default params (test_vectors=False) should be non-deterministic.

Two proofs from the same inputs should differ due to random ZK-row blinding,
but both must still verify correctly.
"""
sk = bytes.fromhex("0101010101010101010101010101010101010101010101010101010101010101")
pk = RingVRF[Bandersnatch].get_public_key(sk)

ring_keys = [pk]
for i in range(7):
other_sk = (i + 2).to_bytes(32, "little")
ring_keys.append(RingVRF[Bandersnatch].get_public_key(other_sk))

alpha = b"deterministic_test"
ad = b"test_ad"

# Default params: test_vectors=False → random ZK-row blinding
params = RingProofParams(test_vectors=False)
ring = Ring(ring_keys, params)
ring_root = RingRoot.from_ring(ring, params)

proof1 = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, ring, ring_root)
proof2 = RingVRF[Bandersnatch].prove(alpha, ad, sk, pk, ring, ring_root)

# Proof bytes should differ due to random blinding
assert proof1.to_bytes() != proof2.to_bytes(), "Ring proofs should be non-deterministic with random ZK-row blinding"

# Both proofs must still verify
assert proof1.verify(alpha, ad, ring, ring_root), "First proof verification failed"
assert proof2.verify(alpha, ad, ring, ring_root), "Second proof verification failed"
Loading