Skip to content

Commit b1d9aa6

Browse files
authored
Merge pull request #6 from AnswerDotAI/0.0.2
0.0.2
2 parents e789d93 + 7c02079 commit b1d9aa6

4 files changed

Lines changed: 30 additions & 7 deletions

File tree

byaldi/RAGModel.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def index(
106106
Returns:
107107
Dict[int, str]: Mapping of document IDs to the file names indexed under them.
108108
"""
109-
self.model.index(
109+
return self.model.index(
110110
input_path,
111111
index_name,
112112
doc_ids,
@@ -133,7 +133,7 @@ def add_to_index(
133133
Returns:
134134
Dict[int, str]: Mapping of document IDs to the file names indexed under them.
135135
"""
136-
self.model.add_to_index(
136+
return self.model.add_to_index(
137137
input_item, store_collection_with_index, doc_id, metadata=metadata
138138
)
139139

@@ -154,3 +154,6 @@ def search(
154154
Union[List[Result], List[List[Result]]]: A list of Result objects or a list of lists of Result objects.
155155
"""
156156
return self.model.search(query, k, return_base64_results)
157+
158+
def get_doc_ids_to_file_names(self):
159+
return self.model.get_doc_ids_to_file_names()

byaldi/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .RAGModel import RAGMultiModalModel
2+
from importlib.metadata import version
23

3-
__version__ = "0.0.1"
4+
__version__ = version("Byaldi")
45
__all__ = ["RAGMultiModalModel"]

byaldi/colpali.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
)
1818
from byaldi.objects import Result
1919
from .utils import capture_print
20+
# Import version directly from the package metadata
21+
from importlib.metadata import version
22+
VERSION = version("Byaldi")
2023

2124

2225
MOCK_IMAGE = Image.new("RGB", (448, 448), (255, 255, 255))
@@ -55,6 +58,7 @@ def __init__(
5558
self.indexed_embeddings = []
5659
self.embed_id_to_doc_id = {}
5760
self.doc_id_to_metadata = {}
61+
self.doc_ids_to_file_names = {}
5862
self.doc_ids = set()
5963

6064
# self.model = ColPali.from_pretrained(
@@ -135,6 +139,11 @@ def __init__(
135139
self.embed_id_to_doc_id = {int(k): v for k, v in self.embed_id_to_doc_id.items()}
136140
self.highest_doc_id = max(int(entry["doc_id"]) for entry in self.embed_id_to_doc_id.values())
137141
self.doc_ids = set(int(entry["doc_id"]) for entry in self.embed_id_to_doc_id.values())
142+
try:
143+
# We don't want this to error out with indexes created prior to 0.0.2
144+
self.doc_ids_to_file_names = srsly.read_gzip_json(index_path / "doc_ids_to_file_names.json.gz")
145+
except FileNotFoundError:
146+
pass
138147

139148
# Load metadata
140149
metadata_path = index_path / "metadata.json.gz"
@@ -211,12 +220,16 @@ def _export_index(self):
211220
"model_name": self.model_name,
212221
"full_document_collection": self.full_document_collection,
213222
"highest_doc_id": self.highest_doc_id,
223+
"library_version": VERSION,
214224
}
215225
srsly.write_gzip_json(index_path / "index_config.json.gz", index_config)
216226

217227
# Save embed_id_to_doc_id mapping
218228
srsly.write_gzip_json(index_path / "embed_id_to_doc_id.json.gz", self.embed_id_to_doc_id)
219229

230+
# Save doc_ids_to_file_names
231+
srsly.write_gzip_json(index_path / "doc_ids_to_file_names.json.gz", self.doc_ids_to_file_names)
232+
220233
# Save metadata
221234
srsly.write_gzip_json(index_path / "metadata.json.gz", self.doc_id_to_metadata)
222235

@@ -239,7 +252,7 @@ def index(
239252
store_collection_with_index: bool = False,
240253
overwrite: bool = False,
241254
metadata: Optional[List[Dict[str, Union[str, int]]]] = None,
242-
):
255+
) -> Dict[int, str]:
243256
if (
244257
self.index_name is not None
245258
and (index_name is None or self.index_name == index_name)
@@ -289,22 +302,25 @@ def index(
289302
doc_id = doc_ids[i] if doc_ids else self.highest_doc_id + 1
290303
doc_metadata = metadata[doc_id] if metadata else None
291304
self.add_to_index(item, store_collection_with_index, doc_id=doc_id, metadata=doc_metadata)
305+
self.doc_ids_to_file_names[doc_id] = str(item)
292306
else:
293307
if metadata is not None and len(metadata) != 1:
294308
raise ValueError("For a single document, metadata should be a list with one dictionary")
295309
doc_id = doc_ids[0] if doc_ids else self.highest_doc_id + 1
296310
doc_metadata = metadata[0] if metadata else None
297311
self.add_to_index(input_path, store_collection_with_index, doc_id=doc_id, metadata=doc_metadata)
312+
self.doc_ids_to_file_names[doc_id] = str(input_path)
298313

299314
self._export_index()
315+
return self.doc_ids_to_file_names
300316

301317
def add_to_index(
302318
self,
303319
input_item: Union[str, Path, Image.Image, List[Union[str, Path, Image.Image]]],
304320
store_collection_with_index: bool,
305321
doc_id: Optional[Union[int, List[int]]] = None,
306322
metadata: Optional[List[Dict[str, Union[str, int]]]] = None,
307-
):
323+
) -> Dict[int, str]:
308324
if self.index_name is None:
309325
raise ValueError("No index loaded. Use index() to create or load an index first.")
310326
if not hasattr(self, "highest_doc_id"):
@@ -339,18 +355,22 @@ def add_to_index(
339355
self._process_directory(item_path, store_collection_with_index, current_doc_id, current_metadata)
340356
else:
341357
self._process_and_add_to_index(item_path, store_collection_with_index, current_doc_id, current_metadata)
358+
self.doc_ids_to_file_names[current_doc_id] = str(item_path)
342359
elif isinstance(item, Image.Image):
343360
self._process_and_add_to_index(item, store_collection_with_index, current_doc_id, current_metadata)
361+
self.doc_ids_to_file_names[current_doc_id] = "In-memory Image"
344362
else:
345363
raise ValueError(f"Unsupported input type: {type(item)}")
346364

347365
self._export_index()
366+
return self.doc_ids_to_file_names
348367

349368
def _process_directory(self, directory: Path, store_collection_with_index: bool, base_doc_id: int, metadata: Optional[Dict[str, Union[str, int]]]):
350369
for i, item in enumerate(directory.iterdir()):
351370
print(f"Indexing file: {item}")
352371
current_doc_id = base_doc_id + i
353372
self._process_and_add_to_index(item, store_collection_with_index, current_doc_id, metadata)
373+
self.doc_ids_to_file_names[current_doc_id] = str(item)
354374

355375
def _process_and_add_to_index(
356376
self,

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ packages = [
99

1010
[project]
1111
name = "Byaldi"
12-
version = "0.0.1"
12+
version = "0.0.2"
1313
description = "Use late-interaction multi-modal models such as ColPALI in just a few lines of code."
1414
readme = "README.md"
1515
requires-python = ">=3.8"
@@ -25,7 +25,6 @@ maintainers = [
2525
dependencies = [
2626
"transformers",
2727
"torch",
28-
"ml-dtypes",
2928
"ninja",
3029
"pdf2image",
3130
"srsly",

0 commit comments

Comments
 (0)