@@ -6,6 +6,9 @@ use libsql::{
66 params:: { IntoParams , IntoValue } ,
77 Connection , Database , Value ,
88} ;
9+ use rand:: distributions:: Uniform ;
10+ use rand:: prelude:: * ;
11+ use std:: collections:: HashSet ;
912
1013async fn setup ( ) -> Connection {
1114 let db = Database :: open ( ":memory:" ) . unwrap ( ) ;
@@ -650,3 +653,102 @@ async fn deserialize_row() {
650653 assert_eq ! ( data. status, Status :: Draft ) ;
651654 assert_eq ! ( data. wrapper, Wrapper ( Status :: Published ) ) ;
652655}
656+
657+ #[ tokio:: test]
658+ #[ ignore]
659+ // fuzz test can be run explicitly with following command:
660+ // cargo test vector_fuzz_test -- --nocapture --include-ignored
661+ async fn vector_fuzz_test ( ) {
662+ let mut global_rng = rand:: thread_rng ( ) ;
663+ for attempt in 0 ..10000 {
664+ let seed = global_rng. next_u64 ( ) ;
665+
666+ let mut rng =
667+ rand:: rngs:: StdRng :: from_seed ( unsafe { std:: mem:: transmute ( [ seed, seed, seed, seed] ) } ) ;
668+ let db = Database :: open ( ":memory:" ) . unwrap ( ) ;
669+ let conn = db. connect ( ) . unwrap ( ) ;
670+ let dim = rng. gen_range ( 1 ..=1536 ) ;
671+ let operations = rng. gen_range ( 1 ..128 ) ;
672+ println ! (
673+ "============== ATTEMPT {} (seed {}u64, dim {}, operations {}) ================" ,
674+ attempt, seed, dim, operations
675+ ) ;
676+
677+ let _ = conn
678+ . execute (
679+ & format ! (
680+ "CREATE TABLE users (id INTEGER PRIMARY KEY, v FLOAT32({}) )" ,
681+ dim
682+ ) ,
683+ ( ) ,
684+ )
685+ . await ;
686+ // println!("CREATE TABLE users (id INTEGER PRIMARY KEY, v FLOAT32({}) );", dim);
687+ let _ = conn
688+ . execute (
689+ "CREATE INDEX users_idx ON users ( libsql_vector_idx(v) );" ,
690+ ( ) ,
691+ )
692+ . await ;
693+ // println!("CREATE INDEX users_idx ON users ( libsql_vector_idx(v) );");
694+
695+ let mut next_id = 1 ;
696+ let mut alive = HashSet :: new ( ) ;
697+ let uniform = Uniform :: new ( -1.0 , 1.0 ) ;
698+ for _ in 0 ..operations {
699+ let operation = rng. gen_range ( 0 ..4 ) ;
700+ let vector: Vec < f32 > = ( 0 ..dim) . map ( |_| rng. sample ( uniform) ) . collect ( ) ;
701+ let vector_str = format ! (
702+ "[{}]" ,
703+ vector
704+ . iter( )
705+ . map( |x| format!( "{}" , x) )
706+ . collect:: <Vec <String >>( )
707+ . join( "," )
708+ ) ;
709+ if operation == 0 {
710+ // println!("INSERT INTO users VALUES ({}, vector('{}') );", next_id, vector_str);
711+ conn. execute (
712+ "INSERT INTO users VALUES (?, vector(?) )" ,
713+ libsql:: params![ next_id, vector_str] ,
714+ )
715+ . await
716+ . unwrap ( ) ;
717+ alive. insert ( next_id) ;
718+ next_id += 1 ;
719+ } else if operation == 1 {
720+ let id = rng. gen_range ( 0 ..next_id) ;
721+ // println!("DELETE FROM users WHERE id = {};", id);
722+ conn. execute ( "DELETE FROM users WHERE id = ?" , libsql:: params![ id] )
723+ . await
724+ . unwrap ( ) ;
725+ alive. remove ( & id) ;
726+ } else if operation == 2 && !alive. is_empty ( ) {
727+ let id = alive. iter ( ) . collect :: < Vec < _ > > ( ) [ rng. gen_range ( 0 ..alive. len ( ) ) ] ;
728+ // println!("UPDATE users SET v = vector('{}') WHERE id = {};", vector_str, id);
729+ conn. execute (
730+ "UPDATE users SET v = vector(?) WHERE id = ?" ,
731+ libsql:: params![ vector_str, id] ,
732+ )
733+ . await
734+ . unwrap ( ) ;
735+ } else if operation == 3 {
736+ let k = rng. gen_range ( 1 ..200 ) ;
737+ // println!("SELECT * FROM vector_top_k('users_idx', '{}', {});", vector_str, k);
738+ let result = conn
739+ . query (
740+ "SELECT * FROM vector_top_k('users_idx', ?, ?)" ,
741+ libsql:: params![ vector_str, k] ,
742+ )
743+ . await
744+ . unwrap ( ) ;
745+ let count = result. into_stream ( ) . count ( ) . await ;
746+ assert ! ( count <= alive. len( ) ) ;
747+ if alive. len ( ) > 0 {
748+ assert ! ( count > 0 ) ;
749+ }
750+ }
751+ }
752+ let _ = conn. execute ( "REINDEX users;" , ( ) ) . await . unwrap ( ) ;
753+ }
754+ }
0 commit comments