3333import vector_service_pb2_grpc # noqa: E402
3434
3535# ── Configuration ─────────────────────────────────────────────────────────────
36- SIDECAR_ADDR = os .getenv ("SIDECAR_ADDR" , "localhost:50051" )
37- TAXI_DATA_URL = "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2023-01.parquet"
38- RAW_FILE = os .path .join (REPO_ROOT , "data" , "raw" , "yellow_tripdata_2023-01.parquet" )
39- DEMO_FILE = os .path .join (REPO_ROOT , "data" , "demo" , "taxi_trips_10k.parquet" )
40- INDEX_DIR = os .path .join (REPO_ROOT , "data" , "indexes" )
41- INDEX_FILE = os .path .join (INDEX_DIR , "nyc_taxi_2023.index" )
42- SAMPLE_SIZE = 10_000
43- RANDOM_SEED = 42
44- BATCH_SIZE = 256 # texts per gRPC batch call
36+ SIDECAR_ADDR = os .getenv ("SIDECAR_ADDR" , "localhost:50051" )
37+ TAXI_DATA_URL = "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2023-01.parquet"
38+ RAW_FILE = os .path .join (REPO_ROOT , "data" , "raw" , "yellow_tripdata_2023-01.parquet" )
39+ DEMO_FILE = os .path .join (REPO_ROOT , "data" , "demo" , "taxi_trips_10k.parquet" )
40+ INDEX_DIR = os .path .join (REPO_ROOT , "data" , "indexes" )
41+ INDEX_FILE = os .path .join (INDEX_DIR , "nyc_taxi_2023.index" )
42+ SAMPLE_SIZE = 10_000
43+ RANDOM_SEED = 42
44+ BATCH_SIZE = 256 # texts per gRPC batch call
4545
4646
4747# ── Step 1: Download and sample ───────────────────────────────────────────────
@@ -60,7 +60,7 @@ def report(block, block_size, total):
6060 downloaded = block * block_size
6161 if total > 0 :
6262 pct = min (100 , downloaded * 100 // total )
63- mb = downloaded / 1_048_576
63+ mb = downloaded / 1_048_576
6464 print (f"\r { pct } % ({ mb :.0f} MB)" , end = "" , flush = True )
6565
6666 urllib .request .urlretrieve (TAXI_DATA_URL , RAW_FILE , reporthook = report )
@@ -107,12 +107,12 @@ def _check_sidecar():
107107
108108def _make_text (row ) -> str :
109109 """Convert a taxi trip row into a natural-language string for embedding."""
110- pu = int (row .get ('PULocationID' , 0 ))
111- do = int (row .get ('DOLocationID' , 0 ))
112- dist = float (row .get ('trip_distance' , 0 ))
113- fare = float (row .get ('fare_amount' , 0 ))
110+ pu = int (row .get ('PULocationID' , 0 ))
111+ do = int (row .get ('DOLocationID' , 0 ))
112+ dist = float (row .get ('trip_distance' , 0 ))
113+ fare = float (row .get ('fare_amount' , 0 ))
114114 passengers = int (row .get ('passenger_count' , 1 ))
115- pax = "passengers" if passengers > 1 else "passenger"
115+ pax = "passengers" if passengers > 1 else "passenger"
116116 return (
117117 f"Yellow taxi trip from zone { pu } to zone { do } , "
118118 f"{ dist :.1f} miles, ${ fare :.2f} fare, { passengers } { pax } "
@@ -129,14 +129,14 @@ def generate_embeddings(demo_file: str) -> np.ndarray:
129129 texts = [_make_text (row ) for _ , row in df .iterrows ()]
130130
131131 channel = grpc .insecure_channel (SIDECAR_ADDR )
132- stub = vector_service_pb2_grpc .EmbeddingServiceStub (channel )
132+ stub = vector_service_pb2_grpc .EmbeddingServiceStub (channel )
133133
134134 all_embeddings = []
135135 total = len (texts )
136136
137137 for start in range (0 , total , BATCH_SIZE ):
138138 batch = texts [start : start + BATCH_SIZE ]
139- request = vector_service_pb2 .EmbeddingBatchRequest (texts = batch )
139+ request = vector_service_pb2 .EmbeddingBatchRequest (texts = batch )
140140 response = stub .GenerateEmbeddingBatch (request )
141141 for emb in response .embeddings :
142142 all_embeddings .append (emb .vector ) # field name is `vector`
@@ -163,11 +163,11 @@ def build_faiss_index(embeddings: np.ndarray) -> str:
163163 # For 10K vectors: nlist=32 gives ~300 vectors/cell (√10K ≈ 100, but 32
164164 # is safer for training), m=8 subvectors × 8 bits = 1 byte/subvector
165165 nlist = 32
166- m = 8
166+ m = 8
167167 nbits = 8
168168
169169 quantizer = faiss .IndexFlatL2 (d )
170- index = faiss .IndexIVFPQ (quantizer , d , nlist , m , nbits )
170+ index = faiss .IndexIVFPQ (quantizer , d , nlist , m , nbits )
171171
172172 print (f" Training IVF{ nlist } ,PQ{ m } ×{ nbits } on { n :,} vectors..." )
173173 index .train (embeddings )
@@ -208,7 +208,7 @@ def main():
208208 return
209209
210210 try :
211- demo_file = download_sample ()
211+ demo_file = download_sample ()
212212 embeddings = generate_embeddings (demo_file )
213213 build_faiss_index (embeddings )
214214
0 commit comments