diff --git a/docker-compose.yml b/docker-compose.yml index 15574ac..52d46ce 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,8 +1,8 @@ services: - data_processing_save_ids: + data_processing_save_labels_descriptions: build: context: . - dockerfile: ./docker/1_Data_Processing_save_ids/Dockerfile + dockerfile: ./docker/1_Data_Processing_save_labels_descriptions/Dockerfile volumes: - ./data:/data # Mount the ./data folder from the host to /data in the container tty: true @@ -12,10 +12,10 @@ services: LANGUAGE: "de" OFFSET: 0 - data_processing_save_entities: + data_processing_save_items_per_lang: build: context: . - dockerfile: ./docker/2_Data_Processing_save_entities/Dockerfile + dockerfile: ./docker/2_Data_Processing_save_items_per_lang/Dockerfile volumes: - ./data:/data # Mount the ./data folder from the host to /data in the container tty: true @@ -43,12 +43,16 @@ services: PYTHONUNBUFFERED: 1 MODEL: "jina" SAMPLE: "true" - API_KEY: "datastax_wikidata_nvidia.json" + API_KEY: "datastax_wikidata2.json" EMBED_BATCH_SIZE: 8 - QUERY_BATCH_SIZE: 1000 - OFFSET: 2560000 - COLLECTION_NAME: "wikidata_test_v1" - LANGUAGE: 'ar' + QUERY_BATCH_SIZE: 100 + OFFSET: 120000 + COLLECTION_NAME: "wikidatav1" + LANGUAGE: 'en' + TEXTIFIER_LANGUAGE: 'en' + ELASTICSEARCH_URL: "http://localhost:9200" + ELASTICSEARCH: "false" + network_mode: "host" run_retrieval: build: @@ -70,17 +74,16 @@ services: environment: PYTHONUNBUFFERED: 1 MODEL: "jina" - API_KEY: "datastax_wikidata_nvidia.json" - COLLECTION_NAME: "wikidata_test_v1" + API_KEY: "datastax_wikidata.json" + COLLECTION_NAME: "wikidata_texttest" BATCH_SIZE: 100 - EVALUATION_PATH: "Mintaka/processed_dataframe_langtest.pkl" + EVALUATION_PATH: "Mintaka/processed_dataframe.pkl" # COMPARATIVE: "true" # COMPARATIVE_COLS: "Correct QID,Wrong QID" QUERY_COL: "Question" - # QUERY_LANGUAGE: "ar" - # DB_LANGUAGE: "en,ar" + QUERY_LANGUAGE: "en" + # DB_LANGUAGE: "en" PREFIX: "" - ELASTICSEARCH_URL: "http://localhost:9200" network_mode: "host" run_rerank: @@ -107,4 +110,42 @@ services: BATCH_SIZE: 1 QUERY_COL: "Question" LANGUAGE: "de" - network_mode: "host" \ No newline at end of file + network_mode: "host" + + push_huggingface: + build: + context: . + dockerfile: ./docker/6_Push_Huggingface/Dockerfile + volumes: + - ./data:/data + tty: true + container_name: push_huggingface + environment: + PYTHONUNBUFFERED: 1 + QUEUE_SIZE: 5000 + NUM_PROCESSES: 4 + SKIPLINES: 0 + ITERATION: 36 + + create_prototype: + build: + context: . 
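The environment blocks in these services are read back inside each stage's run.py via os.getenv; a minimal sketch of that pattern, with variable names taken from the services above and defaults shown only for illustration:

import os

MODEL = os.getenv("MODEL", "jinaapi")
EMBED_BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", 100))
NUM_PROCESSES = int(os.getenv("NUM_PROCESSES", 4))
COLLECTION_NAME = os.getenv("COLLECTION_NAME")  # no default: must be set in docker-compose.yml
CHUNK_NUM = os.getenv("CHUNK_NUM")              # selects which dump chunk to process

if not COLLECTION_NAME:
    raise ValueError("The COLLECTION_NAME environment variable is required")
assert CHUNK_NUM is not None, "Provide CHUNK_NUM at docker run"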
+ dockerfile: ./docker/7_Create_Prototype/Dockerfile + volumes: + - ./data:/data + - ~/.cache/huggingface:/root/.cache/huggingface + tty: true + container_name: create_prototype + environment: + PYTHONUNBUFFERED: 1 + MODEL: "jinaapi" + API_KEY: "datastax_wikidata.json" + EMBED_BATCH_SIZE: 100 + # QUEUE_SIZE: 5000 + NUM_PROCESSES: 23 + OFFSET: 0 + COLLECTION_NAME: "wikidata_prototype" + LANGUAGE: 'en' + TEXTIFIER_LANGUAGE: 'en' + # CHUNK_NUM: 5 + network_mode: "host" \ No newline at end of file diff --git a/docker/1_Data_Processing_save_ids/run.py b/docker/1_Data_Processing_save_ids/run.py deleted file mode 100644 index c9034b2..0000000 --- a/docker/1_Data_Processing_save_ids/run.py +++ /dev/null @@ -1,41 +0,0 @@ -import sys -sys.path.append('../src') - -from wikidataDumpReader import WikidataDumpReader -from wikidataDB import WikidataID -from multiprocessing import Manager -import os -import time - -FILEPATH = os.getenv("FILEPATH", '../data/Wikidata/latest-all.json.bz2') -BATCH_SIZE = int(os.getenv("BATCH_SIZE", 1000)) -QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", 1500)) -NUM_PROCESSES = int(os.getenv("NUM_PROCESSES", 4)) -SKIPLINES = int(os.getenv("SKIPLINES", 0)) -LANGUAGE = os.getenv("LANGUAGE", 'en') - -def save_ids_to_sqlite(item, bulk_ids, sqlitDBlock): - if (item is not None) and WikidataID.is_in_wikipedia(item, language=LANGUAGE): - ids = WikidataID.extract_entity_ids(item, language=LANGUAGE) - bulk_ids.extend(ids) - - with sqlitDBlock: - if len(bulk_ids) > BATCH_SIZE: - worked = WikidataID.add_bulk_ids(list(bulk_ids[:BATCH_SIZE])) - if worked: - del bulk_ids[:BATCH_SIZE] - -if __name__ == "__main__": - multiprocess_manager = Manager() - sqlitDBlock = multiprocess_manager.Lock() - bulk_ids = multiprocess_manager.list() - - wikidata = WikidataDumpReader(FILEPATH, num_processes=NUM_PROCESSES, batch_size=BATCH_SIZE, queue_size=QUEUE_SIZE, skiplines=SKIPLINES) - wikidata.run(lambda item: save_ids_to_sqlite(item, bulk_ids, sqlitDBlock), max_iterations=None, verbose=True) - - while len(bulk_ids) > 0: - worked = WikidataID.add_bulk_ids(list(bulk_ids)) - if worked: - bulk_ids[:] = [] - else: - time.sleep(1) \ No newline at end of file diff --git a/docker/1_Data_Processing_save_ids/Dockerfile b/docker/1_Data_Processing_save_labels_descriptions/Dockerfile similarity index 72% rename from docker/1_Data_Processing_save_ids/Dockerfile rename to docker/1_Data_Processing_save_labels_descriptions/Dockerfile index 68bae7e..9057498 100644 --- a/docker/1_Data_Processing_save_ids/Dockerfile +++ b/docker/1_Data_Processing_save_labels_descriptions/Dockerfile @@ -10,17 +10,19 @@ LABEL maintainer="philippe.saade@wikimedia.de" WORKDIR /app # Copy the requirements file into the container -COPY ./docker/1_Data_Processing_save_ids/requirements.txt requirements.txt +COPY ./docker/1_Data_Processing_save_labels_descriptions/requirements.txt requirements.txt # Install the dependencies RUN pip install --no-cache-dir -r requirements.txt # Copy the rest of the application code into the container -COPY ./docker/1_Data_Processing_save_ids /app +COPY ./docker/1_Data_Processing_save_labels_descriptions /app COPY ./src /src # Set up the volume for the data folder VOLUME [ "/data" ] +ENV PYTHONPATH="${PYTHONPATH}:/" + # Run the Python script -CMD ["python", "run.py"] \ No newline at end of file +CMD ["python", "run.py"] diff --git a/docker/1_Data_Processing_save_ids/requirements.txt b/docker/1_Data_Processing_save_labels_descriptions/requirements.txt similarity index 100% rename from 
docker/1_Data_Processing_save_ids/requirements.txt rename to docker/1_Data_Processing_save_labels_descriptions/requirements.txt diff --git a/docker/1_Data_Processing_save_labels_descriptions/run.py b/docker/1_Data_Processing_save_labels_descriptions/run.py new file mode 100644 index 0000000..78dc999 --- /dev/null +++ b/docker/1_Data_Processing_save_labels_descriptions/run.py @@ -0,0 +1,69 @@ +from multiprocessing import Manager +import os +import time +import json + +from src.wikidataDumpReader import WikidataDumpReader +from src.wikidataItemDB import WikidataItem + +FILEPATH = os.getenv("FILEPATH", '../data/Wikidata/latest-all.json.bz2') +PUSH_SIZE = int(os.getenv("PUSH_SIZE", 20000)) +QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", 15000)) +NUM_PROCESSES = int(os.getenv("NUM_PROCESSES", 4)) +SKIPLINES = int(os.getenv("SKIPLINES", 0)) +LANGUAGE = os.getenv("LANGUAGE", 'en') + + +def save_items_to_sqlite(item, data_batch, sqlitDBlock): + if (item is not None): + labels = WikidataItem.clean_label_description(item['labels']) + descriptions = WikidataItem.clean_label_description( + item['descriptions'] + ) + labels = json.dumps(labels, separators=(',', ':')) + descriptions = json.dumps(descriptions, separators=(',', ':')) + in_wikipedia = WikidataItem.is_in_wikipedia(item) + data_batch.append({ + 'id': item['id'], + 'labels': labels, + 'descriptions': descriptions, + 'in_wikipedia': in_wikipedia, + }) + + with sqlitDBlock: + if len(data_batch) > PUSH_SIZE: + worked = WikidataItem.add_bulk_items(list( + data_batch[:PUSH_SIZE] + )) + if worked: + del data_batch[:PUSH_SIZE] + + +if __name__ == "__main__": + multiprocess_manager = Manager() + sqlitDBlock = multiprocess_manager.Lock() + data_batch = multiprocess_manager.list() + + wikidata = WikidataDumpReader( + FILEPATH, + num_processes=NUM_PROCESSES, + queue_size=QUEUE_SIZE, + skiplines=SKIPLINES + ) + + wikidata.run( + lambda item: save_items_to_sqlite( + item, + data_batch, + sqlitDBlock + ), + max_iterations=None, + verbose=True + ) + + while len(data_batch) > 0: + worked = WikidataItem.add_bulk_items(list(data_batch)) + if worked: + del data_batch[:PUSH_SIZE] + else: + time.sleep(1) diff --git a/docker/2_Data_Processing_save_entities/run.py b/docker/2_Data_Processing_save_entities/run.py deleted file mode 100644 index 0252157..0000000 --- a/docker/2_Data_Processing_save_entities/run.py +++ /dev/null @@ -1,41 +0,0 @@ -import sys -sys.path.append('../src') - -from wikidataDumpReader import WikidataDumpReader -from wikidataDB import WikidataID, WikidataEntity -from multiprocessing import Manager -import os -import time - -FILEPATH = os.getenv("FILEPATH", '../data/Wikidata/latest-all.json.bz2') -BATCH_SIZE = int(os.getenv("BATCH_SIZE", 1000)) -QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", 1500)) -NUM_PROCESSES = int(os.getenv("NUM_PROCESSES", 4)) -SKIPLINES = int(os.getenv("SKIPLINES", 0)) -LANGUAGE = os.getenv("LANGUAGE", 'en') - -def save_entities_to_sqlite(item, data_batch, sqlitDBlock): - if (item is not None) and WikidataID.get_id(item['id']): - item = WikidataEntity.normalise_item(item, language=LANGUAGE) - data_batch.append(item) - - with sqlitDBlock: - if len(data_batch) > BATCH_SIZE: - worked = WikidataEntity.add_bulk_entities(list(data_batch[:BATCH_SIZE])) - if worked: - del data_batch[:BATCH_SIZE] - -if __name__ == "__main__": - multiprocess_manager = Manager() - sqlitDBlock = multiprocess_manager.Lock() - data_batch = multiprocess_manager.list() - - wikidata = WikidataDumpReader(FILEPATH, num_processes=NUM_PROCESSES, batch_size=BATCH_SIZE, 
queue_size=QUEUE_SIZE, skiplines=SKIPLINES)
-    wikidata.run(lambda item: save_entities_to_sqlite(item, data_batch, sqlitDBlock), max_iterations=None, verbose=True)
-
-    while len(data_batch) > 0:
-        worked = WikidataEntity.add_bulk_entities(list(data_batch))
-        if worked:
-            data_batch[:] = []
-        else:
-            time.sleep(1)
\ No newline at end of file
diff --git a/docker/2_Data_Processing_save_entities/Dockerfile b/docker/2_Data_Processing_save_items_per_lang/Dockerfile
similarity index 92%
rename from docker/2_Data_Processing_save_entities/Dockerfile
rename to docker/2_Data_Processing_save_items_per_lang/Dockerfile
index cb8433b..2c12567 100644
--- a/docker/2_Data_Processing_save_entities/Dockerfile
+++ b/docker/2_Data_Processing_save_items_per_lang/Dockerfile
@@ -22,5 +22,7 @@ COPY ./src /src
 # Set up the volume for the data folder
 VOLUME [ "/data" ]
 
+ENV PYTHONPATH="${PYTHONPATH}:/"
+
 # Run the Python script
-CMD ["python", "run.py"]
\ No newline at end of file
+CMD ["python", "run.py"]
diff --git a/docker/2_Data_Processing_save_entities/requirements.txt b/docker/2_Data_Processing_save_items_per_lang/requirements.txt
similarity index 100%
rename from docker/2_Data_Processing_save_entities/requirements.txt
rename to docker/2_Data_Processing_save_items_per_lang/requirements.txt
diff --git a/docker/2_Data_Processing_save_items_per_lang/run.py b/docker/2_Data_Processing_save_items_per_lang/run.py
new file mode 100644
index 0000000..a3f355d
--- /dev/null
+++ b/docker/2_Data_Processing_save_items_per_lang/run.py
@@ -0,0 +1,73 @@
+from multiprocessing import Manager
+import os
+import time
+
+from src.wikidataDumpReader import WikidataDumpReader
+from src.wikidataLangDB import WikidataLang
+
+FILEPATH = os.getenv("FILEPATH", '../data/Wikidata/latest-all.json.bz2')
+PUSH_SIZE = int(os.getenv("PUSH_SIZE", 2000))
+QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", 1500))
+NUM_PROCESSES = int(os.getenv("NUM_PROCESSES", 8))
+SKIPLINES = int(os.getenv("SKIPLINES", 0))
+LANGUAGE = os.getenv("LANGUAGE", 'en')
+
+
+def save_entities_to_sqlite(item, data_batch, sqliteDBlock):
+    """Normalise a Wikidata item and stage it for a bulk SQLite insert.
+    Items without a sitelink to the configured language's Wikipedia are skipped.
+
+    Args:
+        item (dict | None): Entity parsed from the Wikidata dump, or None.
+        data_batch (multiprocessing list proxy): Shared batch awaiting insert.
+        sqliteDBlock (multiprocessing lock): Guards the bulk-insert flush.
+    """
+    if item is None:
+        # Skip items that could not be parsed
+        return
+
+    lang_in_wp = WikidataLang.is_in_wikipedia(item, language=LANGUAGE)
+    if not lang_in_wp:
+        # If the entity is not in the specified language Wikipedia, skip
+        return
+
+    item = WikidataLang.normalise_item(item, language=LANGUAGE)
+    data_batch.append(item)
+
+    with sqliteDBlock:
+        if len(data_batch) > PUSH_SIZE:
+            worked = WikidataLang.add_bulk_entities(list(
+                data_batch[:PUSH_SIZE]
+            ))
+            if worked:
+                del data_batch[:PUSH_SIZE]
+
+
+if __name__ == "__main__":
+    multiprocess_manager = Manager()
+    sqliteDBlock = multiprocess_manager.Lock()
+    data_batch = multiprocess_manager.list()
+
+    wikidata = WikidataDumpReader(
+        FILEPATH,
+        num_processes=NUM_PROCESSES,
+        queue_size=QUEUE_SIZE,
+        skiplines=SKIPLINES
+    )
+
+    wikidata.run(
+        lambda item: save_entities_to_sqlite(
+            item,
+            data_batch,
+            sqliteDBlock
+        ),
+        max_iterations=None,
+        verbose=True
+    )
+
+    while len(data_batch) > 0:
+        worked = WikidataLang.add_bulk_entities(list(data_batch))
+        if worked:
+            del data_batch[:]
+        else:
+            time.sleep(1)
diff --git a/docker/3_Add_Wikidata_to_AstraDB/Dockerfile b/docker/3_Add_Wikidata_to_AstraDB/Dockerfile
index cff05c8..eb6f3ca 100644
--- a/docker/3_Add_Wikidata_to_AstraDB/Dockerfile
+++ 
b/docker/3_Add_Wikidata_to_AstraDB/Dockerfile @@ -28,5 +28,7 @@ COPY ./API_tokens /API_tokens # Set up the volume for the data folder VOLUME [ "/data" ] +ENV PYTHONPATH="${PYTHONPATH}:/" + # Run the Python script -CMD ["python", "run.py"] \ No newline at end of file +CMD ["python", "run.py"] diff --git a/docker/3_Add_Wikidata_to_AstraDB/requirements.txt b/docker/3_Add_Wikidata_to_AstraDB/requirements.txt index fbcfb4d..65289f0 100644 --- a/docker/3_Add_Wikidata_to_AstraDB/requirements.txt +++ b/docker/3_Add_Wikidata_to_AstraDB/requirements.txt @@ -17,5 +17,4 @@ langchain_experimental ragstack-ai-langchain[knowledge-store]==1.3.0 langchain-astradb astrapy -elasticsearch -mediawikiapi \ No newline at end of file +elasticsearch \ No newline at end of file diff --git a/docker/3_Add_Wikidata_to_AstraDB/run.py b/docker/3_Add_Wikidata_to_AstraDB/run.py index 9dc113c..4f7c062 100644 --- a/docker/3_Add_Wikidata_to_AstraDB/run.py +++ b/docker/3_Add_Wikidata_to_AstraDB/run.py @@ -1,19 +1,20 @@ -import sys -sys.path.append('../src') - -from wikidataDB import Session, WikidataID, WikidataEntity -from wikidataEmbed import WikidataTextifier -from wikidataRetriever import AstraDBConnect - import json -from tqdm import tqdm import os import pickle -from datetime import datetime import hashlib +from datetime import datetime +from tqdm import tqdm + +from src.wikidataLangDB import Session, WikidataLang +from src.wikidataEmbed import WikidataTextifier +from src.wikidataRetriever import AstraDBConnect, KeywordSearchConnect + MODEL = os.getenv("MODEL", "jina") SAMPLE = os.getenv("SAMPLE", "false").lower() == "true" +SAMPLE_PATH = os.getenv( + "SAMPLE_PATH", "../data/Evaluation Data/Sample IDs (EN).pkl" +) EMBED_BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", 100)) QUERY_BATCH_SIZE = int(os.getenv("QUERY_BATCH_SIZE", 1000)) OFFSET = int(os.getenv("OFFSET", 0)) @@ -23,6 +24,10 @@ TEXTIFIER_LANGUAGE = os.getenv("TEXTIFIER_LANGUAGE", None) DUMPDATE = os.getenv("DUMPDATE", '09/18/2024') +ELASTICSEARCH_URL = os.getenv("ELASTICSEARCH_URL", "http://localhost:9200") +ELASTICSEARCH = os.getenv("ELASTICSEARCH", "false").lower() == "true" + +# TODO: refactor script into function and call after __name__ == "__main__" # Load the Database if not COLLECTION_NAME: raise ValueError("The COLLECTION_NAME environment variable is required") @@ -32,37 +37,67 @@ if not API_KEY_FILENAME: API_KEY_FILENAME = os.listdir("../API_tokens")[0] -datastax_token = json.load(open(f"../API_tokens/{API_KEY_FILENAME}")) -textifier = WikidataTextifier(language=TEXTIFIER_LANGUAGE) -graph_store = AstraDBConnect(datastax_token, COLLECTION_NAME, model=MODEL, batch_size=EMBED_BATCH_SIZE, cache_embeddings=False) +with open(f"../API_tokens/{API_KEY_FILENAME}") as json_in: + datastax_token = json.load(json_in) +textifier = WikidataTextifier( + language=LANGUAGE, + langvar_filename=TEXTIFIER_LANGUAGE +) + +if ELASTICSEARCH: + graph_store = KeywordSearchConnect( + ELASTICSEARCH_URL, + index_name=COLLECTION_NAME + ) +else: + graph_store = AstraDBConnect( + datastax_token, + COLLECTION_NAME, + model=MODEL, + batch_size=EMBED_BATCH_SIZE, + cache_embeddings=False + ) + +# TODO: refactor script into function and call after __name__ == "__main__" # Load the Sample IDs sample_ids = None if SAMPLE: - sample_ids = pickle.load(open("../data/Evaluation Data/Sample IDs (EN).pkl", "rb")) + sample_ids = pickle.load(open(SAMPLE_PATH, "rb")) sample_ids = sample_ids[sample_ids['In Wikipedia']] total_entities = len(sample_ids) def get_entity(session): sample_qids = 
list(sample_ids['QID'].values)[OFFSET:] - sample_qid_batches = [sample_qids[i:i + QUERY_BATCH_SIZE] for i in range(0, len(sample_qids), QUERY_BATCH_SIZE)] + sample_qid_batches = [ + sample_qids[i:i + QUERY_BATCH_SIZE] + for i in range(0, len(sample_qids), QUERY_BATCH_SIZE) + ] # For each batch of sample QIDs, fetch the entities from the database for qid_batch in sample_qid_batches: - entities = session.query(WikidataEntity).filter(WikidataEntity.id.in_(qid_batch)).yield_per(QUERY_BATCH_SIZE) + entities = session.query(WikidataLang).filter( + WikidataLang.id.in_(qid_batch) + ).yield_per(QUERY_BATCH_SIZE) + for entity in entities: yield entity - else: total_entities = 9203786 def get_entity(session): - entities = session.query(WikidataEntity).join(WikidataID, WikidataEntity.id == WikidataID.id).filter(WikidataID.in_wikipedia == True).offset(OFFSET).yield_per(QUERY_BATCH_SIZE) + entities = session.query( + WikidataLang).offset( + OFFSET + ).yield_per(QUERY_BATCH_SIZE) + for entity in entities: yield entity + if __name__ == "__main__": + # TODO: refactor script into function and call after __name__ == "__main__" with tqdm(total=total_entities-OFFSET) as progressbar: with Session() as session: entity_generator = get_entity(session) @@ -71,22 +106,47 @@ def get_entity(session): for entity in entity_generator: progressbar.update(1) - chunks = textifier.chunk_text(entity, graph_store.tokenizer, max_length=graph_store.max_token_size) + if ELASTICSEARCH: + chunks = [textifier.entity_to_text(entity)] + else: + chunks = textifier.chunk_text( + entity, + graph_store.tokenizer, + max_length=graph_store.max_token_size + ) + for chunk_i in range(len(chunks)): - md5_hash = hashlib.md5(chunks[chunk_i].encode('utf-8')).hexdigest() - metadata={ - # "MD5": md5_hash, - # "Label": entity.label, - # "Description": entity.description, - # "Aliases": entity.aliases, + md5_hash = hashlib.md5( + chunks[chunk_i].encode('utf-8') + ).hexdigest() + + metadata = { + "MD5": md5_hash, + "Label": entity.label, + "Description": entity.description, + "Aliases": entity.aliases, "Date": datetime.now().isoformat(), "QID": entity.id, "ChunkID": chunk_i+1, "Language": LANGUAGE, + "IsItem": ('Q' in entity.id), + "IsProperty": ('P' in entity.id), "DumpDate": DUMPDATE } - graph_store.add_document(id=f"{entity.id}_{LANGUAGE}_{chunk_i+1}", text=chunks[chunk_i], metadata=metadata) - - tqdm.write(progressbar.format_meter(progressbar.n, progressbar.total, progressbar.format_dict["elapsed"])) # tqdm is not wokring in docker compose. This is the alternative + graph_store.add_document( + id=f"{entity.id}_{LANGUAGE}_{chunk_i+1}", + text=chunks[chunk_i], + metadata=metadata + ) + + # tqdm is not working in docker compose. 
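Each chunk pushed by this script gets a deterministic document id and a metadata payload; a condensed, self-contained sketch of that scheme (build_chunk_metadata is an illustrative helper, not a function in the repo, and the Label/Description/Aliases fields are omitted here):

import hashlib
from datetime import datetime

def build_chunk_metadata(qid, chunk_text, chunk_i, language="en", dump_date="09/18/2024"):
    # Mirrors the metadata dict assembled above for every text chunk
    return {
        "MD5": hashlib.md5(chunk_text.encode("utf-8")).hexdigest(),
        "Date": datetime.now().isoformat(),
        "QID": qid,
        "ChunkID": chunk_i + 1,
        "Language": language,
        "IsItem": ("Q" in qid),
        "IsProperty": ("P" in qid),
        "DumpDate": dump_date,
    }

# Document ids follow the same pattern as add_document above, f"{entity.id}_{LANGUAGE}_{chunk_i+1}":
doc_id = "Q42_en_1"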
+ # This is the alternative + tqdm.write( + progressbar.format_meter( + progressbar.n, + progressbar.total, + progressbar.format_dict["elapsed"] + ) + ) graph_store.push_batch() diff --git a/docker/4_Run_Retrieval/Dockerfile b/docker/4_Run_Retrieval/Dockerfile index c263af7..6936f7f 100644 --- a/docker/4_Run_Retrieval/Dockerfile +++ b/docker/4_Run_Retrieval/Dockerfile @@ -28,5 +28,7 @@ COPY ./API_tokens /API_tokens # Set up the volume for the data folder VOLUME [ "/data" ] +ENV PYTHONPATH="${PYTHONPATH}:/" + # Run the Python script -CMD ["python", "run.py"] \ No newline at end of file +CMD ["python", "run.py"] diff --git a/docker/4_Run_Retrieval/run.py b/docker/4_Run_Retrieval/run.py index 022625a..d3a8b08 100644 --- a/docker/4_Run_Retrieval/run.py +++ b/docker/4_Run_Retrieval/run.py @@ -1,14 +1,12 @@ -import sys -sys.path.append('../src') - -from wikidataRetriever import AstraDBConnect, WikidataKeywordSearch - import json -from tqdm import tqdm import pandas as pd import os import pickle +from tqdm import tqdm +from src.wikidataRetriever import AstraDBConnect, KeywordSearchConnect + +# TODO: change script to functional form with fucnctions called after __name__ MODEL = os.getenv("MODEL", "jina") BATCH_SIZE = int(os.getenv("BATCH_SIZE", 100)) API_KEY_FILENAME = os.getenv("API_KEY", None) @@ -22,11 +20,23 @@ QUERY_LANGUAGE = os.getenv("QUERY_LANGUAGE", 'en') DB_LANGUAGE = os.getenv("DB_LANGUAGE", None) RESTART = os.getenv("RESTART", "false").lower() == "true" -ELASTICSEARCH_URL = os.getenv("ELASTICSEARCH_URL", "http://localhost:9200") PREFIX = os.getenv("PREFIX", "") -OUTPUT_FILENAME = f"retrieval_results_{EVALUATION_PATH.split('/')[-2]}-{COLLECTION_NAME}-DB({DB_LANGUAGE})-Query({QUERY_LANGUAGE})" -# OUTPUT_FILENAME = f"retrieval_results_{EVALUATION_PATH.split('/')[-2]}-keyword-search-{LANGUAGE}" +ELASTICSEARCH_URL = os.getenv("ELASTICSEARCH_URL", "http://localhost:9200") +ELASTICSEARCH = os.getenv("ELASTICSEARCH", "false").lower() == "true" + +OUTPUT_FILENAME = ( + f"retrieval_results_{EVALUATION_PATH.split('/')[-2]}-{COLLECTION_NAME}-" + f"DB({DB_LANGUAGE})-Query({QUERY_LANGUAGE})" +) + +# TODO: remove unneccesary commented out code +# OUTPUT_FILENAME = ( +# f"retrieval_results_{EVALUATION_PATH.split('/')[-2]}-" +# f"keyword-search-{LANGUAGE}" +# ) + + if PREFIX != "": OUTPUT_FILENAME += PREFIX @@ -36,22 +46,45 @@ if not API_KEY_FILENAME: API_KEY_FILENAME = os.listdir("../API_tokens")[0] -datastax_token = json.load(open(f"../API_tokens/{API_KEY_FILENAME}")) + print(f"API_KEY_FILENAME not provided. 
Using {API_KEY_FILENAME}") + -graph_store = AstraDBConnect(datastax_token, COLLECTION_NAME, model=MODEL, batch_size=BATCH_SIZE, cache_embeddings=True) -# graph_store = WikidataKeywordSearch(ELASTICSEARCH_URL) +with open(f"../API_tokens/{API_KEY_FILENAME}") as json_in: + datastax_token = json.load(json_in) -#Load the Evaluation Dataset +if ELASTICSEARCH: + graph_store = KeywordSearchConnect( + ELASTICSEARCH_URL, + index_name=COLLECTION_NAME + ) + OUTPUT_FILENAME += "_bm25" +else: + graph_store = AstraDBConnect( + datastax_token, + COLLECTION_NAME, + model=MODEL, + batch_size=BATCH_SIZE, + cache_embeddings=True + ) + +# Load the Evaluation Dataset if not QUERY_COL: raise ValueError("The QUERY_COL environment variable is required") if not EVALUATION_PATH: raise ValueError("The EVALUATION_PATH environment variable is required") -if not RESTART and os.path.exists(f"../data/Evaluation Data/{OUTPUT_FILENAME}.pkl"): +outputfile_exists = os.path.exists( + f"../data/Evaluation Data/{OUTPUT_FILENAME}.pkl" +) +if not RESTART and outputfile_exists: print(f"Loading data from: {OUTPUT_FILENAME}") - eval_data = pickle.load(open(f"../data/Evaluation Data/{OUTPUT_FILENAME}.pkl", "rb")) + pkl_fpath = f"../data/Evaluation Data/{OUTPUT_FILENAME}.pkl" + with open(pkl_fpath, "rb") as pkl_file: + eval_data = pickle.load(pkl_file) else: - eval_data = pickle.load(open(f"../data/Evaluation Data/{EVALUATION_PATH}", "rb")) + pkl_fpath = f"../data/Evaluation Data/{EVALUATION_PATH}" + with open(pkl_fpath, "rb") as pkl_file: + eval_data = pickle.load(pkl_file) if 'Language' in eval_data.columns: eval_data = eval_data[eval_data['Language'] == QUERY_LANGUAGE] @@ -65,23 +98,56 @@ if 'Retrieval Score' not in eval_data: eval_data['Retrieval Score'] = None - row_to_process = eval_data['Retrieval QIDs'].apply(lambda x: (x is None) or (len(x) == 0)) | eval_data['Retrieval Score'].apply(lambda x: (x is None) or (len(x) == 0)) # Find rows that havn't been processed + # TODO: Refactor this row_to_process to avoid nested .apply + # Find rows that haven't been processed + row_to_process = eval_data['Retrieval QIDs'].apply( + lambda x: (x is None) or (len(x) == 0) + ) | eval_data['Retrieval Score'].apply( + lambda x: (x is None) or (len(x) == 0) + ) + pkl_output_path = f"../data/Evaluation Data/{OUTPUT_FILENAME}.pkl" progressbar.update((~row_to_process).sum()) for i in range(0, row_to_process.sum(), BATCH_SIZE): batch_idx = eval_data[row_to_process].iloc[i:i+BATCH_SIZE].index batch = eval_data.loc[batch_idx] if COMPARATIVE: - batch_results = graph_store.batch_retrieve_comparative(batch[QUERY_COL], batch[COMPARATIVE_COLS.split(',')], K=K, Language=DB_LANGUAGE) + batch_results = graph_store.batch_retrieve_comparative( + batch[QUERY_COL], + batch[COMPARATIVE_COLS.split(',')], + K=K, + Language=DB_LANGUAGE + ) else: - batch_results = graph_store.batch_retrieve(batch[QUERY_COL], K=K, Language=DB_LANGUAGE) - - eval_data.loc[batch_idx, 'Retrieval QIDs'] = pd.Series(batch_results[0]).values - eval_data.loc[batch_idx, 'Retrieval Score'] = pd.Series(batch_results[1]).values - + batch_results = graph_store.batch_retrieve( + batch[QUERY_COL], + K=K, + Language=DB_LANGUAGE + ) + + eval_data.loc[batch_idx, 'Retrieval QIDs'] = pd.Series( + batch_results[0] + ).values + + eval_data.loc[batch_idx, 'Retrieval Score'] = pd.Series( + batch_results[1] + ).values + + # TODO: Create progress bar update funciton + # tqdm is not wokring in docker compose. 
This is the alternative progressbar.update(len(batch)) - tqdm.write(progressbar.format_meter(progressbar.n, progressbar.total, progressbar.format_dict["elapsed"])) # tqdm is not wokring in docker compose. This is the alternative + tqdm.write( + progressbar.format_meter( + progressbar.n, + progressbar.total, + progressbar.format_dict["elapsed"] + ) + ) if progressbar.n % 100 == 0: - pickle.dump(eval_data, open(f"../data/Evaluation Data/{OUTPUT_FILENAME}.pkl", "wb")) - pickle.dump(eval_data, open(f"../data/Evaluation Data/{OUTPUT_FILENAME}.pkl", "wb")) \ No newline at end of file + # BUG: Why is the pkl output file being written twice at end + with open(pkl_output_path, "wb") as pkl_file: + pickle.dump(eval_data, pkl_file) + + # BUG: Why is the pkl output file being written twice at end + pickle.dump(eval_data, open(pkl_output_path, "wb")) diff --git a/docker/5_Run_Rerank/Dockerfile b/docker/5_Run_Rerank/Dockerfile index adfd690..a624029 100644 --- a/docker/5_Run_Rerank/Dockerfile +++ b/docker/5_Run_Rerank/Dockerfile @@ -28,5 +28,7 @@ COPY ./API_tokens /API_tokens # Set up the volume for the data folder VOLUME [ "/data" ] +ENV PYTHONPATH="${PYTHONPATH}:/" + # Run the Python script -CMD ["python", "run.py"] \ No newline at end of file +CMD ["python", "run.py"] diff --git a/docker/5_Run_Rerank/run.py b/docker/5_Run_Rerank/run.py index e443482..7cc26e0 100644 --- a/docker/5_Run_Rerank/run.py +++ b/docker/5_Run_Rerank/run.py @@ -1,15 +1,17 @@ -import sys -sys.path.append('../src') - -from wikidataDB import WikidataEntity -from wikidataEmbed import WikidataTextifier -from JinaAI import JinaAIReranker - from tqdm import tqdm import pandas as pd import os import pickle +# BUG: Reviwer assumed wikidataDB was converted to wikidataLangDB +# Bug: Reviwer assumed wikidataDB.WikidataEntity was converted +# to wikidataLangDB.WikidataLang +# from src.wikidataDB import WikidataEntity +from src.wikidataLangDB import WikidataLang +from src.wikidataEmbed import WikidataTextifier +from src.JinaAI import JinaAIReranker + + MODEL = os.getenv("MODEL", "jina") BATCH_SIZE = int(os.getenv("BATCH_SIZE", 100)) RETRIEVAL_FILENAME = os.getenv("RETRIEVAL_FILENAME") @@ -20,10 +22,17 @@ textifier = WikidataTextifier(language=LANGUAGE) reranker = JinaAIReranker() -eval_data = pickle.load(open(f"../data/Evaluation Data/{RETRIEVAL_FILENAME}.pkl", "rb")) +pkl_fpath = f"../data/Evaluation Data/{RETRIEVAL_FILENAME}.pkl" +with open(pkl_fpath, "rb") as pkl_file: + eval_data = pickle.load(pkl_file) + + +# Rerank the QIDs def rerank_qids(query, qids, reranker, textifier): - entities = [WikidataEntity.get_entity(qid) for qid in qids] + # Bug: Reviwer assumed WikidataEntity was converted to WikidataLang + # entities = [WikidataEntity.get_entity(qid) for qid in qids] + entities = [WikidataLang.get_entity(qid) for qid in qids] texts = [textifier.entity_to_text(entity) for entity in entities] scores = reranker.rank(query, texts) @@ -31,6 +40,7 @@ def rerank_qids(query, qids, reranker, textifier): score_zip = sorted(score_zip, key=lambda x: -x[0]) return [x[1] for x in score_zip] + if __name__ == "__main__": with tqdm(total=len(eval_data), disable=False) as progressbar: if 'Reranked QIDs' not in eval_data: @@ -42,12 +52,33 @@ def rerank_qids(query, qids, reranker, textifier): row = eval_data[row_to_process].iloc[i] # Rerank the QIDs - ranked_qids = rerank_qids(row[QUERY_COL], row['Retrieval QIDs'], reranker, textifier) + ranked_qids = rerank_qids( + row[QUERY_COL], + row['Retrieval QIDs'], + reranker, + textifier + ) - 
eval_data.loc[[row.index], 'Reranked QIDs'] = pd.Series(ranked_qids).values
+            eval_data.loc[[row.index], 'Reranked QIDs'] = pd.Series(
+                ranked_qids
+            ).values
+            # TODO: create new function to update tqdm progressbar
+            # tqdm is not working in docker compose. This is the alternative
             progressbar.update(1)
-            tqdm.write(progressbar.format_meter(progressbar.n, progressbar.total, progressbar.format_dict["elapsed"])) # tqdm is not wokring in docker compose. This is the alternative
+            tqdm.write(
+                progressbar.format_meter(
+                    progressbar.n,
+                    progressbar.total,
+                    progressbar.format_dict["elapsed"]
+                )
+            )
             if progressbar.n % 100 == 0:
-                pickle.dump(eval_data, open(f"../data/Evaluation Data/{RETRIEVAL_FILENAME}.pkl", "wb"))
-    pickle.dump(eval_data, open(f"../data/Evaluation Data/{RETRIEVAL_FILENAME}.pkl", "wb"))
\ No newline at end of file
+                pkl_fpath = f"../data/Evaluation Data/{RETRIEVAL_FILENAME}.pkl"
+                with open(pkl_fpath, "wb") as pkl_file:
+                    pickle.dump(eval_data, pkl_file)
+
+    # TODO: Why is this definition and open twice at the end?
+    pkl_fpath = f"../data/Evaluation Data/{RETRIEVAL_FILENAME}.pkl"
+    with open(pkl_fpath, "wb") as pkl_file:
+        pickle.dump(eval_data, pkl_file)
diff --git a/docker/6_Push_Huggingface/Dockerfile b/docker/6_Push_Huggingface/Dockerfile
index 3b551c4..c835d98 100644
--- a/docker/6_Push_Huggingface/Dockerfile
+++ b/docker/6_Push_Huggingface/Dockerfile
@@ -23,5 +23,7 @@ COPY ./API_tokens /API_tokens
 # Set up the volume for the data folder
 VOLUME [ "/data" ]
 
+ENV PYTHONPATH="${PYTHONPATH}:/"
+
 # Run the Python script
-CMD ["python", "run.py"]
\ No newline at end of file
+CMD ["python", "run.py"]
diff --git a/docker/6_Push_Huggingface/run.py b/docker/6_Push_Huggingface/run.py
index c636307..753bcd7 100644
--- a/docker/6_Push_Huggingface/run.py
+++ b/docker/6_Push_Huggingface/run.py
@@ -1,13 +1,11 @@
 import os
 import json
-import sys
 
 from huggingface_hub import login
 from multiprocessing import Process, Value, Queue
 
-sys.path.append('../src')
-from wikidataDumpReader import WikidataDumpReader
-from wikidataLabelsDB import WikidataLabels
+from src.wikidataDumpReader import WikidataDumpReader
+from src.wikidataItemDB import WikidataItem
 
 from datasets import Dataset, load_dataset_builder
 
@@ -18,21 +16,31 @@ API_KEY_FILENAME = os.getenv("API_KEY", "huggingface_api.json")
 ITERATION = int(os.getenv("ITERATION", 0))
 
-api_key = json.load(open(f"../API_tokens/{API_KEY_FILENAME}"))['API_KEY']
+api_key_fpath = f"../API_tokens/{API_KEY_FILENAME}"
+with open(api_key_fpath) as f_in:
+    api_key = json.load(f_in)['API_KEY']
+
 
 def save_to_queue(item, data_queue):
     """Processes and puts cleaned item into the multiprocessing queue."""
-    if (item is not None) and (WikidataLabels.is_in_wikipedia(item)):
-        claims = WikidataLabels.add_labels_batched(item['claims'], query_batch=100)
+    if (item is not None) and (WikidataItem.is_in_wikipedia(item)):
+        claims = WikidataItem.add_labels_batched(
+            item['claims'],
+            query_batch=100
+        )
         data_queue.put({
             'id': item['id'],
             'labels': json.dumps(item['labels'], separators=(',', ':')),
-            'descriptions': json.dumps(item['descriptions'], separators=(',', ':')),
+            'descriptions': json.dumps(
+                item['descriptions'],
+                separators=(',', ':')
+            ),
             'aliases': json.dumps(item['aliases'], separators=(',', ':')),
             'sitelinks': json.dumps(item['sitelinks'], separators=(',', ':')),
             'claims': json.dumps(claims, separators=(',', ':'))
         })
 
+
 def chunk_generator(filepath, num_processes=2, queue_size=5000, skip_lines=0):
     """
     A generator function that reads a 
chunk file with WikidataDumpReader, @@ -53,8 +61,12 @@ def chunk_generator(filepath, num_processes=2, queue_size=5000, skip_lines=0): # Define a function to feed items into the queue def run_reader(): - wikidata.run(lambda item: save_to_queue(item, data_queue), - max_iterations=None, verbose=True) + wikidata.run( + lambda item: save_to_queue(item, data_queue), + max_iterations=None, + verbose=True + ) + with finished.get_lock(): finished.value = 1 @@ -69,7 +81,8 @@ def run_reader(): break try: item = data_queue.get(timeout=1) - except: + except Exception as e: + print(f'Exception: {e}') continue if item: yield item @@ -77,26 +90,29 @@ def run_reader(): # Wait for the reader process to exit reader_proc.join() -# Now process each chunk file and push to the same Hugging Face repo -HF_REPO_ID = "wikidata" # Change to your actual repo on Hugging Face - -login(token=api_key) -builder = load_dataset_builder("philippesaade/wikidata") -for i in range(45, 113): - split_name = f"chunk_{i}" - if split_name not in builder.info.splits: - filepath = f"../data/Wikidata/latest-all-chunks/chunk_{i}.json.gz" - - print(f"Processing {filepath} -> split={split_name}") - - # Create a Dataset from the generator - ds_chunk = Dataset.from_generator(lambda: chunk_generator( - filepath, - num_processes=NUM_PROCESSES, - queue_size=QUEUE_SIZE, - skip_lines=SKIPLINES - )) - - # Push each chunk as a separate "split" under the same dataset repo - ds_chunk.push_to_hub(HF_REPO_ID, split=split_name) - print(f"Chunk {ITERATION} pushed to {HF_REPO_ID} as {split_name}.") + +if __name__ == "__main__": + # TODO: Convert the following into a function and run it here + # Now process each chunk file and push to the same Hugging Face repo + HF_REPO_ID = "wikidata" # Change to your actual repo on Hugging Face + + login(token=api_key) + builder = load_dataset_builder("philippesaade/wikidata") + for i in range(0, 113): + split_name = f"chunk_{i}" + if split_name not in builder.info.splits: + filepath = f"../data/Wikidata/latest-all-chunks/chunk_{i}.json.gz" + + print(f"Processing {filepath} -> split={split_name}") + + # Create a Dataset from the generator + ds_chunk = Dataset.from_generator(lambda: chunk_generator( + filepath, + num_processes=NUM_PROCESSES, + queue_size=QUEUE_SIZE, + skip_lines=SKIPLINES + )) + + # Push each chunk as a separate "split" under the same dataset repo + ds_chunk.push_to_hub(HF_REPO_ID, split=split_name) + print(f"Chunk {ITERATION} pushed to {HF_REPO_ID} as {split_name}.") diff --git a/docker/7_Create_Prototype/Dockerfile b/docker/7_Create_Prototype/Dockerfile new file mode 100644 index 0000000..3ebbb00 --- /dev/null +++ b/docker/7_Create_Prototype/Dockerfile @@ -0,0 +1,32 @@ +# Use the official Python image from the Docker Hub +FROM python:3.9-slim +LABEL maintainer="philippe.saade@wikimedia.de" + +# Upgrade the pip, git and ubuntu versions to the most recent version +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + git \ + && rm -rf /var/lib/apt/lists/* \ + && pip install --upgrade pip setuptools wheel + +# Set the working directory in the container +WORKDIR /app + +# Copy the requirements file into the container +COPY ./docker/7_Create_Prototype/requirements.txt requirements.txt + +# Install the dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy the rest of the application code into the container +COPY ./docker/7_Create_Prototype /app +COPY ./src /src +COPY ./API_tokens /API_tokens + +# Set up the volume for the data folder +VOLUME [ "/data" ] 
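The create_prototype stage introduced below feeds Hugging Face dataset items into a multiprocessing.Queue and lets worker processes embed and push them; a self-contained sketch of that producer/consumer shape, with the embedding work replaced by a placeholder (process_item is illustrative, not repo code):

from multiprocessing import Process, Queue

NUM_WORKERS = 4

def process_item(item):
    # Placeholder for textify -> chunk -> embed -> add_document
    pass

def worker(queue):
    while True:
        item = queue.get()
        if item is None:      # one sentinel per worker ends the loop
            break
        process_item(item)

if __name__ == "__main__":
    queue = Queue(maxsize=100)
    workers = [Process(target=worker, args=(queue,)) for _ in range(NUM_WORKERS)]
    for w in workers:
        w.start()
    for item in range(1000):  # stands in for the streamed dataset
        queue.put(item)
    for _ in range(NUM_WORKERS):
        queue.put(None)       # sentinels so every worker exits
    for w in workers:
        w.join()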
+ +ENV PYTHONPATH="${PYTHONPATH}:/" + +# Run the Python script +CMD ["python", "run.py"] diff --git a/docker/7_Create_Prototype/requirements.txt b/docker/7_Create_Prototype/requirements.txt new file mode 100644 index 0000000..c97586e --- /dev/null +++ b/docker/7_Create_Prototype/requirements.txt @@ -0,0 +1,22 @@ +tqdm +psutil +orjson +pandas + +# wikidataDB +sqlalchemy + +# wikidataEmbed +transformers +einops + +# wikidataRetriever +langchain-core +langchainhub +langchain_experimental +ragstack-ai-langchain[knowledge-store]==1.3.0 +langchain-astradb +astrapy +elasticsearch + +datasets \ No newline at end of file diff --git a/docker/7_Create_Prototype/run.py b/docker/7_Create_Prototype/run.py new file mode 100644 index 0000000..2438508 --- /dev/null +++ b/docker/7_Create_Prototype/run.py @@ -0,0 +1,177 @@ +import json +import os +import hashlib +import time + +from datasets import load_dataset +from datetime import datetime +from multiprocessing import Process, Queue, Manager +from tqdm import tqdm +from types import SimpleNamespace + +from src.wikidataEmbed import WikidataTextifier +from src.wikidataRetriever import AstraDBConnect + +MODEL = os.getenv("MODEL", "jinaapi") +NUM_PROCESSES = int(os.getenv("NUM_PROCESSES", 4)) +EMBED_BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", 100)) + +QUEUE_SIZE = 2 * EMBED_BATCH_SIZE * NUM_PROCESSES # enough to not run out +QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", QUEUE_SIZE)) + +DB_API_KEY_FILENAME = os.getenv("DB_API_KEY", "datastax_wikidata.json") +COLLECTION_NAME = os.getenv("COLLECTION_NAME") + +CHUNK_NUM = os.getenv("CHUNK_NUM") + +assert CHUNK_NUM is not None, ( + "Please provide `CHUNK_NUM` env var at docker run" +) + +LANGUAGE = "en" +TEXTIFIER_LANGUAGE = "en" +DUMPDATE = "09/18/2024" + +# Load the Database +if not COLLECTION_NAME: + raise ValueError("The COLLECTION_NAME environment variable is required") + +if not TEXTIFIER_LANGUAGE: + TEXTIFIER_LANGUAGE = LANGUAGE + +FILEPATH = f"../data/Wikidata/chunks/chunk_{CHUNK_NUM}.json.gz" + +# wikidata_chunk_sizes_fname = "../data/Wikidata/chunk_sizes.json" +# TODO: Add location as env var +# TODO: Sync data format from DATADUMP to chunk_sizes.json +# TODO: Retrieve info from Hugging Face instead of storing it +# TODO: Set this up as a separate script and run after __name__ +wikidata_chunksizes_path = os.path.join("wikidata_chunk_sizes_2024-09-18.json") + +with open(wikidata_chunksizes_path) as json_in: + chunk_sizes = json.load(json_in) + +total_entities = chunk_sizes[f"chunk_{CHUNK_NUM}"] + +with open(f"../API_tokens/{DB_API_KEY_FILENAME}") as json_in: + datastax_token = json.load(json_in) + +dataset = load_dataset( + "philippesaade/wikidata", + data_files=f"data/chunk_{CHUNK_NUM}-*.parquet", + streaming=True, + split="train" +) + + +def process_items(queue, progress_bar): + """Worker function that processes items from the queue + and adds them to AstraDB. 
+ """ + with open(f"../API_tokens/{DB_API_KEY_FILENAME}") as json_in: + datastax_token = json.load(json_in) + + graph_store = AstraDBConnect( + datastax_token, + COLLECTION_NAME, + model=MODEL, + batch_size=EMBED_BATCH_SIZE, + cache_embeddings="wikidata_prototype" + ) + textifier = WikidataTextifier( + language=LANGUAGE, + langvar_filename=TEXTIFIER_LANGUAGE + ) + + while True: + item = queue.get() + if item is None: + break # Exit condition for worker processes + + item_id = item['id'] + + item_label = textifier.get_label( + item_id, + json.loads(item['labels']) + ) + + item_description = textifier.get_description( + item_id, + json.loads(item['descriptions']) + ) + item_aliases = textifier.get_aliases( + json.loads(item['aliases']) + ) + + if item_label is not None: + # TODO: Verify: If label does not exist, then skip item + entity_obj = SimpleNamespace() + entity_obj.id = item_id + entity_obj.label = item_label + entity_obj.description = item_description + entity_obj.aliases = item_aliases + entity_obj.claims = json.loads(item['claims']) + + chunks = textifier.chunk_text( + entity_obj, + graph_store.tokenizer, + max_length=graph_store.max_token_size + ) + + for chunk_i, chunk in enumerate(chunks): + md5_hash = hashlib.md5(chunk.encode('utf-8')).hexdigest() + metadata = { + "MD5": md5_hash, + "Label": item_label, + "Description": item_description, + "Aliases": item_aliases, + "Date": datetime.now().isoformat(), + "QID": item_id, + "ChunkID": chunk_i + 1, + "Language": LANGUAGE, + "IsItem": ('Q' in item_id), + "IsProperty": ('P' in item_id), + "DumpDate": DUMPDATE + } + + graph_store.add_document( + id=f"{item_id}_{LANGUAGE}_{chunk_i+1}", + + text=chunk, + metadata=metadata + ) + + progress_bar.value += 1 + + + graph_store.push_all() + + +if __name__ == "__main__": + queue = Queue(maxsize=QUEUE_SIZE) + progress_bar = Manager().Value("i", 0) + + with tqdm(total=total_entities) as pbar: + processes = [] + for _ in range(NUM_PROCESSES): + p = Process(target=process_items, args=(queue, progress_bar)) + p.start() + processes.append(p) + + for item in dataset: + queue.put(item) + pbar.update(progress_bar.value - pbar.n) + # pbar.n = progress_bar.value + # pbar.refresh() + + for _ in range(NUM_PROCESSES): + queue.put(None) + + while any(p.is_alive() for p in processes): + pbar.update(progress_bar.value - pbar.n) + # pbar.n = progress_bar.value + # pbar.refresh() + time.sleep(1) + + for p in processes: + p.join() diff --git a/docker/7_Create_Prototype/wikidata_chunk_sizes_2024-09-18.json b/docker/7_Create_Prototype/wikidata_chunk_sizes_2024-09-18.json new file mode 100644 index 0000000..1b8c6f8 --- /dev/null +++ b/docker/7_Create_Prototype/wikidata_chunk_sizes_2024-09-18.json @@ -0,0 +1,115 @@ +{ + "chunk_0":992458, + "chunk_1":802125, + "chunk_2":589652, + "chunk_3":310440, + "chunk_4":43477, + "chunk_5":156867, + "chunk_6":141965, + "chunk_7":74047, + "chunk_8":27104, + "chunk_9":70759, + "chunk_10":71395, + "chunk_11":186698, + "chunk_12":153182, + "chunk_13":137155, + "chunk_14":929827, + "chunk_15":853027, + "chunk_16":571543, + "chunk_17":335565, + "chunk_18":47264, + "chunk_19":135986, + "chunk_20":160411, + "chunk_21":76377, + "chunk_22":26321, + "chunk_23":70572, + "chunk_24":68613, + "chunk_25":179806, + "chunk_26":159587, + "chunk_27":139912, + "chunk_28":876104, + "chunk_29":864360, + "chunk_30":590603, + "chunk_31":358747, + "chunk_32":47772, + "chunk_33":135633, + "chunk_34":159629, + "chunk_35":81231, + "chunk_36":24912, + "chunk_37":69201, + "chunk_38":67131, + "chunk_39":172234, 
+ "chunk_40":167698, + "chunk_41":142276, + "chunk_42":821175, + "chunk_43":892005, + "chunk_44":600584, + "chunk_45":374793, + "chunk_46":47443, + "chunk_47":134784, + "chunk_48":155247, + "chunk_49":86997, + "chunk_50":24829, + "chunk_51":68053, + "chunk_52":63517, + "chunk_53":167660, + "chunk_54":175827, + "chunk_55":142816, + "chunk_56":765400, + "chunk_57":900655, + "chunk_58":628866, + "chunk_59":396886, + "chunk_60":46907, + "chunk_61":135384, + "chunk_62":154864, + "chunk_63":88112, + "chunk_64":23353, + "chunk_65":67446, + "chunk_66":40301, + "chunk_67":176420, + "chunk_68":183715, + "chunk_69":149547, + "chunk_70":713006, + "chunk_71":901222, + "chunk_72":652770, + "chunk_73":419554, + "chunk_74":52246, + "chunk_75":134064, + "chunk_76":153318, + "chunk_77":92710, + "chunk_78":22790, + "chunk_79":66521, + "chunk_80":34397, + "chunk_81":173357, + "chunk_82":186788, + "chunk_83":153870, + "chunk_84":657926, + "chunk_85":902477, + "chunk_86":655319, + "chunk_87":455111, + "chunk_88":69724, + "chunk_89":133629, + "chunk_90":146534, + "chunk_91":101890, + "chunk_92":21324, + "chunk_93":65448, + "chunk_94":33345, + "chunk_95":162191, + "chunk_96":192226, + "chunk_97":159451, + "chunk_98":598037, + "chunk_99":903618, + "chunk_100":662580, + "chunk_101":484690, + "chunk_102":86616, + "chunk_103":135160, + "chunk_104":106630, + "chunk_105":142249, + "chunk_106":19290, + "chunk_107":60073, + "chunk_108":39131, + "chunk_109":155251, + "chunk_110":190337, + "chunk_111":166210, + "chunk_112":26375 +} diff --git a/run_experiments.sh b/run_experiments.sh new file mode 100644 index 0000000..a10e6ec --- /dev/null +++ b/run_experiments.sh @@ -0,0 +1,65 @@ +# docker compose run --build add_wikidata_to_astra +# docker compose run --build \ +# -e EVALUATION_PATH="Mintaka/processed_dataframe.pkl" \ +# -e QUERY_COL="Question" \ +# -e PREFIX="_nonewlines" \ +# -e COLLECTION_NAME="wikidatav1" \ +# -e QUERY_LANGUAGE="en" \ +# -e DB_LANGUAGE="en" \ +# -e API_KEY="datastax_wikidata2.json" run_retrieval + +# docker compose run --build \ +# -e EVALUATION_PATH="LC_QuAD/processed_dataframe.pkl" \ +# -e QUERY_COL="Question" \ +# -e PREFIX="_nonewlines" \ +# -e COLLECTION_NAME="wikidatav1" \ +# -e QUERY_LANGUAGE="en" \ +# -e DB_LANGUAGE="en" \ +# -e API_KEY="datastax_wikidata2.json" run_retrieval + +# docker compose run --build \ +# -e EVALUATION_PATH="REDFM/processed_dataframe.pkl" \ +# -e QUERY_COL="Sentence" \ +# -e PREFIX="_nonewlines" \ +# -e COLLECTION_NAME="wikidatav1" \ +# -e QUERY_LANGUAGE="en" \ +# -e DB_LANGUAGE="en" \ +# -e API_KEY="datastax_wikidata2.json" run_retrieval + +# docker compose run --build \ +# -e EVALUATION_PATH="REDFM/processed_dataframe.pkl" \ +# -e QUERY_COL="Sentence no entity" \ +# -e PREFIX="_nonewlines_noentity" \ +# -e COLLECTION_NAME="wikidatav1" \ +# -e QUERY_LANGUAGE="en" \ +# -e DB_LANGUAGE="en" \ +# -e API_KEY="datastax_wikidata2.json" run_retrieval + +# docker compose run --build \ +# -e EVALUATION_PATH="RuBQ/processed_dataframe.pkl" \ +# -e QUERY_COL="Question" \ +# -e PREFIX="_nonewlines" \ +# -e COLLECTION_NAME="wikidatav1" \ +# -e QUERY_LANGUAGE="en" \ +# -e DB_LANGUAGE="en" \ +# -e API_KEY="datastax_wikidata2.json" run_retrieval + +# docker compose run --build \ +# -e EVALUATION_PATH="Wikidata-Disamb/processed_dataframe.pkl" \ +# -e QUERY_COL="Sentence" \ +# -e COMPARATIVE="true" \ +# -e COMPARATIVE_COLS="Correct QID,Wrong QID" \ +# -e COLLECTION_NAME="wikidatav1" \ +# -e PREFIX="_nonewlines" \ +# -e QUERY_LANGUAGE="en" \ +# -e DB_LANGUAGE="en" \ +# -e 
API_KEY="datastax_wikidata2.json" run_retrieval + +# docker compose run --build -e CHUNK_NUM=112 create_prototype +# docker compose run --build -e CHUNK_NUM=111 create_prototype +# docker compose run --build -e CHUNK_NUM=110 create_prototype +# docker compose run --build -e CHUNK_NUM=109 create_prototype + +for chunk_num in `seq 101 -1 100`; + do docker compose run --build -e CHUNK_NUM=$chunk_num create_prototype; +done diff --git a/src/JinaAI.py b/src/JinaAI.py index 9f367f8..e278266 100644 --- a/src/JinaAI.py +++ b/src/JinaAI.py @@ -1,71 +1,52 @@ - -from typing import List -from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification -import torch -from sqlalchemy import Column, Text, create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker -from sqlalchemy.types import TypeDecorator import json import requests import numpy as np import base64 -""" -SQLite database setup for caching the query embeddings for a faster evaluation process. -""" -engine = create_engine( - 'sqlite:///../data/Wikidata/sqlite_cacheembeddings.db', - pool_size=5, # Limit the number of open connections - max_overflow=10, # Allow extra connections beyond pool_size - pool_recycle=10 # Recycle connections every 10 seconds -) - -Base = declarative_base() -Session = sessionmaker(bind=engine) - -class JSONType(TypeDecorator): - """Custom SQLAlchemy type for JSON storage in SQLite.""" - impl = Text - - def process_bind_param(self, value, dialect): - if value is not None: - return json.dumps(value, separators=(',', ':')) - return None +# import torch # torch no long imported - def process_result_value(self, value, dialect): - if value is not None: - return json.loads(value) - return None -class CacheEmbeddings(Base): - """Represents a cache entry for a text string and its embedding.""" - __tablename__ = 'embeddings' +from typing import List +from src.wikidataCache import create_cache_embedding_model - text = Column(Text, primary_key=True) - embedding = Column(JSONType) class JinaAIEmbedder: - def __init__(self, passage_task="retrieval.passage", query_task="retrieval.query", embedding_dim=1024, cache=False, api_key_path="../API_tokens/jina_api.json"): + def __init__( + self, passage_task="retrieval.passage", + query_task="retrieval.query", embedding_dim=1024, cache=None): """ - Initializes the JinaAIEmbedder class with the model, tokenizer, and task identifiers. + Initializes the JinaAIEmbedder class with the model, tokenizer, + and task identifiers. Parameters: - - passage_task (str): Task identifier for embedding documents. Defaults to "retrieval.passage". - - query_task (str): Task identifier for embedding queries. Defaults to "retrieval.query". - - embedding_dim (int): Dimensionality of the embeddings. Defaults to 1024. - - cache (bool): Whether to cache query embeddings in the database. Defaults to False. - - api_key_path (str): Path to the JSON file containing the Jina API key. Defaults to "../API_tokens/jina_api.json". + - passage_task (str): Task identifier for embedding documents. + Defaults to "retrieval.passage". + - query_task (str): Task identifier for embedding queries. + Defaults to "retrieval.query". + - embedding_dim (int): Dimensionality of the embeddings. + Defaults to 1024. + - cache (str): Name of caching table. + - api_key_path (str): Path to the JSON file containing the + Jina API key. Defaults to "../API_tokens/jina_api.json". 
""" + from transformers import AutoModel, AutoTokenizer + self.passage_task = passage_task self.query_task = query_task self.embedding_dim = embedding_dim - self.cache = cache - self.model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True).to('cuda') - self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True) + self.model = AutoModel.from_pretrained( + "jinaai/jina-embeddings-v3", + trust_remote_code=True + ).to('cuda') + self.tokenizer = AutoTokenizer.from_pretrained( + "jinaai/jina-embeddings-v3", + trust_remote_code=True + ) - self.api_key = json.load(open(api_key_path, 'r+'))['API_KEY'] + self.cache = (cache is not None) + if self.cache: + self.cache_model = create_cache_embedding_model(cache) def _cache_embedding(self, text: str, embedding: List[float]): """ @@ -77,14 +58,7 @@ def _cache_embedding(self, text: str, embedding: List[float]): """ if self.cache: embedding = embedding.tolist() - with Session() as session: - try: - cached = CacheEmbeddings(text=text, embedding=embedding) - session.merge(cached) - session.commit() - except Exception as e: - session.rollback() - raise e + self.cache_model.add_cache(id=text, embedding=embedding) def _get_cached_embedding(self, text: str) -> List[float]: """ @@ -97,19 +71,93 @@ def _get_cached_embedding(self, text: str) -> List[float]: - List[float] or None: The embedding if found in cache, otherwise None. """ if self.cache: - with Session() as session: - cached = session.query(CacheEmbeddings).filter_by(text=text).first() - if cached: - return cached.embedding + return self.cache_model.get_cache(id=text) return None - def api_embed(self, text, task="retrieval.query"): + def embed_documents(self, texts: List[str]) -> List[List[float]]: """ - Generates an embedding for the given text using the Jina Embeddings API. + Generates embeddings for a list of document (passage) texts. + + Caching is not used here by default to avoid storing + large numbers of document embeddings. + + Parameters: + - texts (List[str]): A list of document texts to embed. + + Returns: + - List[List[float]]: A list of embedding vectors, each corresponding + to a document. + """ + + with torch.no_grad(): + embeddings = self.model.encode( + texts, + task=self.passage_task, + truncate_dim=self.embedding_dim + ) + + return embeddings + + def embed_query(self, text: str) -> List[float]: + """ + Generates an embedding for a single query string, optionally using + and updating the cache. + + Parameters: + - text (str): The query text to embed. + + Returns: + - List[float]: The embedding vector corresponding to the query. + """ + cached_embedding = self._get_cached_embedding(text) + if cached_embedding: + return cached_embedding + + with torch.no_grad(): + embedding = self.model.encode( + [text], + task=self.query_task, + truncate_dim=self.embedding_dim + )[0] + + self._cache_embedding(text, embedding) + return embedding + + +class JinaAIAPIEmbedder: + def __init__( + self, passage_task="retrieval.passage", + query_task="retrieval.query", embedding_dim=1024, + api_key_path="../API_tokens/jina_api.json"): # cache=False, + """ + Initializes the JinaAIEmbedder class with the model, tokenizer, + and task identifiers. + + Parameters: + - passage_task (str): Task identifier for embedding documents. + Defaults to "retrieval.passage". + - query_task (str): Task identifier for embedding queries. + Defaults to "retrieval.query". + - embedding_dim (int): Dimensionality of the embeddings. + Defaults to 1024. 
+ - cache (str): Name of caching table. # BUG: cache is unused + - api_key_path (str): Path to the JSON file containing + the Jina API key. Defaults to "../API_tokens/jina_api.json". + """ + self.passage_task = passage_task + self.query_task = query_task + self.embedding_dim = embedding_dim + + self.api_key = json.load(open(api_key_path, 'r+'))['API_KEY'] + + def api_embed(self, texts, task="retrieval.query"): + """ + Generates an embedding for the given text using the Jina Embeddings API Parameters: - text (str): The text to embed. - - task (str): The task identifier (e.g., "retrieval.query" or "retrieval.passage"). + - task (str): The task identifier + (e.g., "retrieval.query" or "retrieval.passage"). Returns: - np.ndarray: The resulting embedding vector as a NumPy array. @@ -120,41 +168,52 @@ def api_embed(self, text, task="retrieval.query"): 'Authorization': f'Bearer {self.api_key}' } + if type(texts) is str: + texts = [texts] + data = { "model": "jina-embeddings-v3", "dimensions": self.embedding_dim, "embedding_type": "base64", "task": task, "late_chunking": False, - "input": [ - text - ] + "input": texts } response = requests.post(url, headers=headers, json=data) - binary_data = base64.b64decode(response.json()['data'][0]['embedding']) - embedding_array = np.frombuffer(binary_data, dtype=' List[List[float]]: """ Generates embeddings for a list of document (passage) texts. - Caching is not used here by default to avoid storing large numbers of document embeddings. + Caching is not used here by default to avoid storing large numbers + of document embeddings. Parameters: - texts (List[str]): A list of document texts to embed. Returns: - - List[List[float]]: A list of embedding vectors, each corresponding to a document. + - List[List[float]]: A list of embedding vectors, each corresponding + to a document. """ - with torch.no_grad(): - embeddings = self.model.encode(texts, task=self.passage_task, truncate_dim=self.embedding_dim) + embeddings = self.api_embed(texts, task=self.passage_task) return embeddings def embed_query(self, text: str) -> List[float]: """ - Generates an embedding for a single query string, optionally using and updating the cache. + Generates an embedding for a single query string, optionally using + and updating the cache. Parameters: - text (str): The query text to embed. @@ -162,31 +221,33 @@ def embed_query(self, text: str) -> List[float]: Returns: - List[float]: The embedding vector corresponding to the query. """ - cached_embedding = self._get_cached_embedding(text) - if cached_embedding: - return cached_embedding + embedding = self.api_embed([text], task=self.query_task) + return embedding - with torch.no_grad(): - embedding = self.model.encode([text], task=self.query_task, truncate_dim=self.embedding_dim)[0] - self._cache_embedding(text, embedding) - return embedding class JinaAIReranker: def __init__(self, max_tokens=1024): """ - Initializes the JinaAIReranker with a maximum token length and the Jina Reranker model. + Initializes the JinaAIReranker with a maximum token length + and the Jina Reranker model. Parameters: - - max_tokens (int): Maximum sequence length for the reranker (must be <= 1024). + - max_tokens (int): Maximum sequence length for the reranker + (must be <= 1024). Raises: - ValueError: If max_tokens is greater than 1024. 
""" + from transformers import AutoModelForSequenceClassification + if max_tokens > 1024: raise ValueError("Max token should be less than or equal to 1024") self.max_tokens = max_tokens - self.model = AutoModelForSequenceClassification.from_pretrained('jinaai/jina-reranker-v2-base-multilingual', trust_remote_code=True).to('cuda') + self.model = AutoModelForSequenceClassification.from_pretrained( + 'jinaai/jina-reranker-v2-base-multilingual', + trust_remote_code=True + ).to('cuda') def rank(self, query: str, texts: List[str]) -> List[float]: """ @@ -197,12 +258,13 @@ def rank(self, query: str, texts: List[str]) -> List[float]: - texts (List[str]): A list of document texts to rank. Returns: - - List[float]: A list of relevance scores, each corresponding to one document. + - List[float]: A list of relevance scores, each corresponding + to one document. """ sentence_pairs = [[query, doc] for doc in texts] with torch.no_grad(): - return self.model.compute_score(sentence_pairs, max_length=self.max_tokens) - -# Create tables if they don't already exist. -Base.metadata.create_all(engine) \ No newline at end of file + return self.model.compute_score( + sentence_pairs, + max_length=self.max_tokens + ) diff --git a/src/__init__.py b/src/__init__.py index d3b8f15..9ef178f 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,6 +1,6 @@ from .wikidataDumpReader import WikidataDumpReader -from .wikidataDB import WikidataEntity, WikidataID -from .wikidataLabelsDB import WikidataLabels +from .wikidataLangDB import WikidataLang +from .wikidataItemDB import WikidataItem from .wikidataEmbed import WikidataTextifier -from .JinaAI import JinaAIEmbedder, JinaAIReranker -from .wikidataRetriever import AstraDBConnect \ No newline at end of file +from .JinaAI import JinaAIEmbedder, JinaAIReranker, JinaAIAPIEmbedder +from .wikidataRetriever import AstraDBConnect, KeywordSearchConnect diff --git a/src/experimental_functions/word_embeding.py b/src/experimental_functions/word_embeding.py new file mode 100644 index 0000000..2e9b718 --- /dev/null +++ b/src/experimental_functions/word_embeding.py @@ -0,0 +1,107 @@ +# import torch +# import torch.nn.functional as F +from src.JinaAI import JinaAIEmbedder + +model = JinaAIEmbedder() + + +def mean_pooling(model_output, attention_mask): + token_embeddings = model_output[0] + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( + input_mask_expanded.sum(1), min=1e-9 + ) + + +def extract_specific_word_embedding( + model, tokenizer, sentence, start_marker='', end_marker=''): + """_summary_ # TODO: fill out docstring + + Args: + model (_type_): _description_ + tokenizer (_type_): _description_ + sentence (_type_): _description_ + start_marker (str, optional): _description_. Defaults to ''. + end_marker (str, optional): _description_. Defaults to ''. 
+ + Returns: + _type_: _description_ + """ + + # Check if the markers are in the sentence + if start_marker in sentence and end_marker in sentence: + # Find the position of the markers + start_index = sentence.find(start_marker) + end_index = sentence.find(end_marker) + len(end_marker) + + # Split the sentence into three parts: before, target, and after + mid_left_index = start_index + len(start_marker) + mid_right_index = end_index - len(end_marker) + before_target = sentence[:start_index] + target_word = sentence[mid_left_index:mid_right_index] + after_target = sentence[end_index:] + + # Clean the sentence by removing the markers + tokens_before = tokenizer( + before_target, + add_special_tokens=False + )['input_ids'] + + tokens_target = tokenizer( + target_word, + add_special_tokens=False + )['input_ids'] + + # Reconstruct the input with markers removed and tokenize + clean_sentence = before_target + target_word + after_target + inputs = tokenizer( + clean_sentence, + return_tensors="pt" + ).to(model.device) + + # Run the model to get output embeddings + task = 'retrieval.query' + task_id = model._adaptation_map[task] + adapter_mask = torch.full( + (1,), + task_id, + dtype=torch.int32 + ).to(model.device) + + with torch.no_grad(): + outputs = model(**inputs, adapter_mask=adapter_mask) + + # Calculate the start and end positions of the target word tokens + target_start_pos = len(tokens_before) + target_end_pos = target_start_pos + len(tokens_target) + + target_attention_mask = torch.zeros_like(inputs['attention_mask']) + target_attention_mask[:, target_start_pos:target_end_pos] = 1 + + # Apply mean pooling only to the target word's embeddings + # word_embeddings = mean_pooling(outputs, target_attention_mask) + # normalized_embeddings = F.normalize(word_embeddings, p=2, dim=1) + + normalized_embeddings = ( + outputs[0][:, target_start_pos:target_end_pos].mean(dim=0) + ) + return normalized_embeddings + else: + with torch.no_grad(): + outputs = model.encode(sentence, task='retrieval.query') + return outputs + + +# Example usage +sentence = ( + "I went to the river bank to take a swim. " + "Afterwards, I went to the bank to withdraw some money." 
+) +embedding = extract_specific_word_embedding( + model.model, + model.tokenizer, + sentence +) +print("Extracted Embedding:", embedding) diff --git a/src/language_variables/ar.py b/src/language_variables/ar.py index e0cb548..14fc86f 100644 --- a/src/language_variables/ar.py +++ b/src/language_variables/ar.py @@ -59,11 +59,15 @@ def qualifiers_to_text(qualifiers): """ text = "" for property_label, qualifier_values in qualifiers.items(): - if len(text) > 0: - text += f" ; " + if (qualifier_values is not None) and len(qualifier_values) > 0: + if len(text) > 0: + text += f" ; " - text += f"{property_label}: {'، '.join(qualifier_values)}" - return text + text += f"{property_label}: {'، '.join(qualifier_values)}" + + if len(text) > 0: + return f" ({text})" + return "" def properties_to_text(properties): """ @@ -77,7 +81,7 @@ def properties_to_text(properties): """ properties_text = "" for property_label, claim_values in properties.items(): - if len(claim_values) > 0: + if (claim_values is not None) and (len(claim_values) > 0): claims_text = "" for claim_value in claim_values: @@ -88,7 +92,7 @@ def properties_to_text(properties): qualifiers = claim_value.get('qualifiers', {}) if len(qualifiers) > 0: - claims_text += f" ({qualifiers_to_text(qualifiers)})" + claims_text += qualifiers_to_text(qualifiers) claims_text += f"»" diff --git a/src/language_variables/de.py b/src/language_variables/de.py index 08c8f59..c5eb103 100644 --- a/src/language_variables/de.py +++ b/src/language_variables/de.py @@ -59,11 +59,18 @@ def qualifiers_to_text(qualifiers): """ text = "" for property_label, qualifier_values in qualifiers.items(): - if len(text) > 0: - text += f" ; " + if (qualifier_values is not None) and len(qualifier_values) > 0: + if len(text) > 0: + text += f" " - text += f"{property_label}: {', '.join(qualifier_values)}" - return text + text += f"({property_label}: {', '.join(qualifier_values)})" + + elif (qualifier_values is not None): + text += f"(hat {property_label})" + + if len(text) > 0: + return f" {text}" + return "" def properties_to_text(properties): """ @@ -77,21 +84,22 @@ def properties_to_text(properties): """ properties_text = "" for property_label, claim_values in properties.items(): - if len(claim_values) > 0: + if (claim_values is not None) and (len(claim_values) > 0): claims_text = "" for claim_value in claim_values: if len(claims_text) > 0: - claims_text += f",\n " + claims_text += f", " - claims_text += f"„{claim_value['value']}" + claims_text += claim_value['value'] qualifiers = claim_value.get('qualifiers', {}) if len(qualifiers) > 0: - claims_text += f" ({qualifiers_to_text(qualifiers)})" + claims_text += qualifiers_to_text(qualifiers) - claims_text += f"“" + properties_text += f'\n- {property_label}: „{claims_text}“' - properties_text += f'\n- {property_label}: {claims_text}' + elif (claim_values is not None): + properties_text += f'\n- hat {property_label}' return properties_text \ No newline at end of file diff --git a/src/language_variables/en.py b/src/language_variables/en.py index 6c87778..a5f951d 100644 --- a/src/language_variables/en.py +++ b/src/language_variables/en.py @@ -59,11 +59,18 @@ def qualifiers_to_text(qualifiers): """ text = "" for property_label, qualifier_values in qualifiers.items(): - if len(text) > 0: - text += f" ; " + if (qualifier_values is not None) and len(qualifier_values) > 0: + if len(text) > 0: + text += f" " - text += f"{property_label}: {', '.join(qualifier_values)}" - return text + text += f"({property_label}: {', '.join(qualifier_values)})" 
+ + elif (qualifier_values is not None): + text += f"(has {property_label})" + + if len(text) > 0: + return f" {text}" + return "" def properties_to_text(properties): """ @@ -77,21 +84,22 @@ def properties_to_text(properties): """ properties_text = "" for property_label, claim_values in properties.items(): - if len(claim_values) > 0: + if (claim_values is not None) and (len(claim_values) > 0): claims_text = "" for claim_value in claim_values: if len(claims_text) > 0: - claims_text += f",\n " + claims_text += f", " - claims_text += f"\"{claim_value['value']}" + claims_text += claim_value['value'] qualifiers = claim_value.get('qualifiers', {}) if len(qualifiers) > 0: - claims_text += f" ({qualifiers_to_text(qualifiers)})" + claims_text += qualifiers_to_text(qualifiers) - claims_text += f"\"" + properties_text += f'\n- {property_label}: "{claims_text}"' - properties_text += f'\n- {property_label}: {claims_text}' + elif (claim_values is not None): + properties_text += f'\n- has {property_label}' return properties_text \ No newline at end of file diff --git a/src/language_variables/json.py b/src/language_variables/json.py index 15002f0..eefbb6c 100644 --- a/src/language_variables/json.py +++ b/src/language_variables/json.py @@ -41,7 +41,7 @@ def merge_entity_text(label, description, aliases, properties): 'description': description, 'aliases': aliases, **properties - }, ensure_ascii=False) + }, ensure_ascii=False, indent=4) return text @@ -50,70 +50,19 @@ def compress_json(data): # Iterate through the items of the data for key, items in data.items(): - cleaned_items = [] - for item in items: - qualifiers = {k: v[0] if isinstance(v, list) and len(v) == 1 else v for k, v in item['qualifiers'].items()} - clean_item = {'value': item['value'], **qualifiers} - if len(clean_item) == 1: - clean_item = clean_item['value'] - - cleaned_items.append(clean_item) - - if len(cleaned_items) == 1: - cleaned_items = cleaned_items[0] - cleaned_data[key] = cleaned_items - - return cleaned_data - -def qualifiers_to_text(qualifiers): - """ - Converts a list of qualifiers to a readable text string. - Qualifiers provide additional information about a claim. - - Parameters: - - qualifiers: A dictionary of qualifiers with property IDs as keys and their values as lists. - - Returns: - - A string representation of the qualifiers. - """ - text = "" - for property_label, qualifier_values in qualifiers.items(): - if len(text) > 0: - text += f" ; " - - text += f"{property_label}: {', '.join(qualifier_values)}" - return text - -def properties_to_text(properties, label=""): - """ - Converts a list of properties (claims) to a readable text string. - - Parameters: - - properties: A dictionary of properties (claims) with property IDs as keys. - - Returns: - - A string representation of the properties and their values. 
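To make the new English rendering concrete, here is an illustrative input for the updated properties_to_text above and the string it is expected to produce (values are invented; the other languages follow the same pattern with their own separators):

properties = {
    'position held': [
        {'value': 'president', 'qualifiers': {'start time': ['2017']}}
    ],
    'award received': [],  # exists but has no usable values
}
# properties_to_text(properties) ->
# '\n- position held: "president (start time: 2017)"\n- has award received'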
- """ - properties_text = "" - for property_label, claim_values in properties.items(): - if len(claim_values) > 0: - - claims_text = "" - qualifier_exists = any([len(claim_value.get('qualifiers', {})) > 0 for claim_value in claim_values]) - if qualifier_exists: - for claim_value in claim_values: - if len(claims_text) > 0: - claims_text += f"\n" - - claims_text += f"{label}: {property_label}: {claim_value['value']}" - - qualifiers = claim_value.get('qualifiers', {}) - if len(qualifiers) > 0: - claims_text += f" ({qualifiers_to_text(qualifiers)})" - else: - claims_text = ', '.join([claim_value['value'] for claim_value in claim_values]) - claims_text = f"{label}: {property_label}: {claims_text}" - - properties_text += f'\n{claims_text}' - - return properties_text \ No newline at end of file + if (items is not None) and (len(items) > 0): + cleaned_items = [] + for item in items: + qualifiers = {k: v[0] if isinstance(v, list) and len(v) == 1 else v for k, v in item['qualifiers'].items()} + clean_item = {'value': item['value'], **qualifiers} + if len(clean_item) == 1: + clean_item = clean_item['value'] + + cleaned_items.append(clean_item) + + if len(cleaned_items) == 1: + cleaned_items = cleaned_items[0] + elif len(cleaned_items) > 1: + cleaned_data[key] = cleaned_items + + return cleaned_data \ No newline at end of file diff --git a/src/language_variables/rdf.py b/src/language_variables/rdf.py index 7fd498c..0e803c7 100644 --- a/src/language_variables/rdf.py +++ b/src/language_variables/rdf.py @@ -57,11 +57,15 @@ def qualifiers_to_text(qualifiers): """ text = "" for property_label, qualifier_values in qualifiers.items(): - if len(text) > 0: - text += f" ; " + if (qualifier_values is not None) and len(qualifier_values) > 0: + if len(text) > 0: + text += f" ; " - text += f"{property_label}: {', '.join(qualifier_values)}" - return text + text += f"{property_label}: {', '.join(qualifier_values)}" + + if len(text) > 0: + return f" ({text})" + return "" def properties_to_text(properties, label=""): """ @@ -75,7 +79,7 @@ def properties_to_text(properties, label=""): """ properties_text = "" for property_label, claim_values in properties.items(): - if len(claim_values) > 0: + if (claim_values is not None) and (len(claim_values) > 0): claims_text = "" qualifier_exists = any([len(claim_value.get('qualifiers', {})) > 0 for claim_value in claim_values]) @@ -88,7 +92,7 @@ def properties_to_text(properties, label=""): qualifiers = claim_value.get('qualifiers', {}) if len(qualifiers) > 0: - claims_text += f" ({qualifiers_to_text(qualifiers)})" + claims_text += qualifiers_to_text(qualifiers) else: claims_text = ', '.join([claim_value['value'] for claim_value in claim_values]) claims_text = f"{label}: {property_label}: {claims_text}" diff --git a/src/merge_cache.py b/src/merge_cache.py new file mode 100644 index 0000000..d26c85b --- /dev/null +++ b/src/merge_cache.py @@ -0,0 +1,91 @@ +import sqlite3 +import glob +import os +import base64 +from tqdm import tqdm + +# Define database file pattern +db_files = glob.glob("../data/Wikidata/sqlite_cacheembeddings_*.db") + +# Define the target merged database +merged_db = "../data/Wikidata/sqlite_cacheembeddings_merged.db" +TABLE_NAME = "wikidata_prototype" + +# Batch size for processing +BATCH_SIZE = 1000 # Adjust based on performance needs + +# Create the merged database connection +conn_merged = sqlite3.connect(merged_db) +cursor_merged = conn_merged.cursor() + +# Create table in the merged database if it doesn't exist +cursor_merged.execute(f""" +CREATE TABLE IF NOT 
EXISTS {TABLE_NAME} ( + id TEXT PRIMARY KEY, + embedding TEXT +); +""") +conn_merged.commit() + +# Helper function to check if a string is a valid Base64 encoding +def is_valid_base64(s): + try: + if not s or not isinstance(s, str): + return False + base64.b64decode(s, validate=True) + return True + except Exception: + return False + +# Loop through all source databases +for db_file in db_files: + print(f"Processing {db_file}...") + + # Connect to the current database + conn_src = sqlite3.connect(db_file) + cursor_src = conn_src.cursor() + + # Get total record count for progress tracking + cursor_src.execute(f"SELECT COUNT(*) FROM {TABLE_NAME}") + total_records = cursor_src.fetchone()[0] + + # Fetch records in batches + offset = 0 + with tqdm(total=total_records, + desc=f"Merging {db_file}", unit="records") as pbar: + while True: + cursor_src.execute(f"SELECT id, embedding FROM {TABLE_NAME} LIMIT {BATCH_SIZE} OFFSET {offset}") + records = cursor_src.fetchall() + if not records: + break # No more records to process + + # Prepare batch for insertion + batch_data = [] + for id_, embedding in records: + if embedding and embedding.strip() and is_valid_base64(embedding): + cursor_merged.execute(f"SELECT embedding FROM {TABLE_NAME} WHERE id = ?", (id_,)) + existing = cursor_merged.fetchone() + + if existing is None or not is_valid_base64(existing[0]): + batch_data.append((id_, embedding, embedding)) # Prepare for bulk insert + + # Perform batch insert/update + if batch_data: + cursor_merged.executemany( + f""" + INSERT INTO {TABLE_NAME} (id, embedding) + VALUES (?, ?) + ON CONFLICT(id) DO UPDATE SET embedding = ? + """, + batch_data + ) + conn_merged.commit() + + offset += BATCH_SIZE # Move to the next batch + pbar.update(len(records)) # Update tqdm progress bar + + conn_src.close() + +# Close merged database connection +conn_merged.close() +print(f"Merge completed! Combined database saved as {merged_db}") diff --git a/src/migrate_cache.py b/src/migrate_cache.py new file mode 100644 index 0000000..5096bd1 --- /dev/null +++ b/src/migrate_cache.py @@ -0,0 +1,81 @@ +import sqlite3 +import json +import base64 +import numpy as np +from tqdm import tqdm + +# Change this to match your actual database path +DB_PATH = "data/Wikidata/wikidata_cache.db" + +TABLE_NAME = "wikidata_prototype" # Change this to match your actual table name +BATCH_SIZE = 5000 # Process in smaller batches to avoid memory overload + +def convert_embeddings(): + """ + Convert JSON-stored embeddings into Base64-encoded binary format in batches. + Uses `fetchmany(BATCH_SIZE)` to process records iteratively. 
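A minimal sketch of the per-record conversion this script performs (float32 layout assumed, mirroring the batch loop below):

import base64
import json
import numpy as np

def json_to_base64(json_embedding: str) -> str:
    # JSON list of floats -> base64-encoded float32 bytes.
    values = json.loads(json_embedding)
    binary = np.array(values, dtype=np.float32).tobytes()
    return base64.b64encode(binary).decode('utf-8')

print(json_to_base64("[0.25, -1.0, 3.5]"))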
+ """ + # TODO: Migrate away from global variables + print(f"Converting embeddings in {DB_PATH}...") + conn = sqlite3.connect(DB_PATH) + cursor = conn.cursor() + + # Check if the embedding column exists (sanity check) + cursor.execute(f"PRAGMA table_info({TABLE_NAME})") + columns = [row[1] for row in cursor.fetchall()] + if "embedding" not in columns: + print("Error: 'embedding' column does not exist in the table!") + return + + # Count total records for progress tracking + cursor.execute(f"SELECT COUNT(*) FROM {TABLE_NAME}") + total_records = cursor.fetchone()[0] + + print(f"Total records to process: {total_records}") + + # Fetch records in batches using an iterator + offset = 0 + with tqdm(total=total_records, desc="Converting embeddings", unit="record") as pbar: + while True: + cursor.execute(f"SELECT id, embedding FROM {TABLE_NAME} LIMIT {BATCH_SIZE} OFFSET {offset}") + records = cursor.fetchall() + if not records: + break # Stop when there are no more records + + updated_records = [] + for id, json_embedding in records: + if json_embedding: + try: + # Convert JSON string to list of floats + embedding_list = json.loads(json_embedding) + + # Convert list of floats to Base64-encoded binary + binary_data = np.array(embedding_list, dtype=np.float32).tobytes() + base64_embedding = base64.b64encode(binary_data).decode('utf-8') + + updated_records.append((base64_embedding, id)) + except Exception as e: + pass + + pbar.update(1) # Update progress bar for each record processed + + # Update database in batches + if updated_records: + cursor.executemany( + f"UPDATE {TABLE_NAME} SET embedding = ? WHERE id = ?", + updated_records + ) + conn.commit() # Commit every batch + + offset += BATCH_SIZE # Move to next batch + + print("Optimizing database with VACUUM...") + cursor.execute("VACUUM;") + conn.commit() + + print("Migration completed successfully.") + + conn.close() + +if __name__ == "__main__": + convert_embeddings() \ No newline at end of file diff --git a/src/wikidataCache.py b/src/wikidataCache.py new file mode 100644 index 0000000..d0a29cd --- /dev/null +++ b/src/wikidataCache.py @@ -0,0 +1,140 @@ +from sqlalchemy import Column, Text, create_engine, text +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from sqlalchemy.types import TypeDecorator + +import os +import json +import base64 +import numpy as np + +""" +SQLite database setup for caching the query embeddings for a faster +evaluation process. 
+""" + +# TODO: Move to a configuration file +wikidata_cache_file = "wikidata_cache.db" + +wikidata_cache_dir = "../data/Wikidata" +wikidata_cache_path = os.path.join(wikidata_cache_dir, wikidata_cache_file) + +try: + if not os.path.exists(wikidata_cache_dir): + os.makedirs(wikidata_cache_dir) +except OSError as e: + print(f"Error creating directory {wikidata_cache_dir}: {e}") + +assert os.path.exists(wikidata_cache_dir), \ + f"Error creating directory {wikidata_cache_dir}" + +engine = create_engine( + f'sqlite:///{wikidata_cache_path}', + pool_size=5, # Limit the number of open connections + max_overflow=10, # Allow extra connections beyond pool_size + pool_recycle=10 # Recycle connections every 10 seconds +) + +Base = declarative_base() +Session = sessionmaker(bind=engine) + +class EmbeddingType(TypeDecorator): + """Custom SQLAlchemy type for storing embeddings as Base64 strings in SQLite.""" + + impl = Text + + def process_bind_param(self, value, dialect): + """Convert a list of floats (embedding) to a Base64 string before storing.""" + if value is not None and isinstance(value, list): + # Convert list to binary + binary_data = np.array(value, dtype=np.float32).tobytes() + # Encode to Base64 string + return base64.b64encode(binary_data).decode('utf-8') + return None + + def process_result_value(self, value, dialect): + """Convert a Base64 string back to a list of floats when retrieving.""" + if value is not None: + # Decode Base64 + binary_data = base64.b64decode(value) + # Convert back to float32 list + embedding_array = np.frombuffer(binary_data, dtype=np.float32) + return embedding_array.tolist() + return None + + +def create_cache_embedding_model(table_name): + """Factory function to create a dynamic CacheEmbeddings model.""" + + class CacheEmbeddings(Base): + __tablename__ = table_name + + id = Column(Text, primary_key=True) + embedding = Column(EmbeddingType) + + @staticmethod + def add_cache(id, embedding): + with Session() as session: + try: + cached = CacheEmbeddings(id=id, embedding=embedding) + session.merge(cached) + session.commit() + return True + except Exception as e: + session.rollback() + raise e + + @staticmethod + def get_cache(id): + with Session() as session: + cached = session.query( + CacheEmbeddings + ).filter_by(id=id).first() + + if cached: + return cached.embedding + return None + + @staticmethod + def add_bulk_cache(data): + """ + Insert multiple label records in bulk. If a record with the same + ID exists, + it is ignored (no update is performed). + + Parameters: + - data (list[dict]): A list of dictionaries, each containing 'id', + 'labels', 'descriptions', and 'in_wikipedia' keys. + + Returns: + - bool: True if the operation was successful, False otherwise. 
+ """ + worked = False + embeddingtype = EmbeddingType() + for i in range(len(data)): + data[i]['embedding'] = embeddingtype.process_bind_param( + data[i]['embedding'], + None + ) + + with Session() as session: + exec_text = text( + f""" + INSERT INTO {CacheEmbeddings.__tablename__} (id, embedding) + VALUES (:id, :embedding) + ON CONFLICT(id) DO NOTHING + """ + ) + try: + session.execute(exec_text, data) + session.commit() + session.flush() + worked = True + except Exception as e: + session.rollback() + print(e) + return worked + + Base.metadata.create_all(engine) + + return CacheEmbeddings diff --git a/src/wikidataEmbed.py b/src/wikidataEmbed.py index c7525ee..72a2836 100644 --- a/src/wikidataEmbed.py +++ b/src/wikidataEmbed.py @@ -1,53 +1,179 @@ -from wikidataDB import WikidataEntity import requests import time import json -from datetime import date, datetime import re import importlib +from datetime import date, datetime +from src.wikidataItemDB import WikidataItem + class WikidataTextifier: - def __init__(self, language='en'): + """_summary_ + """ + def __init__(self, language='en', langvar_filename=None): """ Initializes the WikidataTextifier with the specified language. + Expected use cases: Hugging Face parquet or sqlite database. Parameters: - - language (str): The language code used by the textifier (default is "en"). + - language (str): The language code used by the textifier + Default is "en". """ self.language = language + langvar_filename = ( + langvar_filename if langvar_filename is not None else language + ) try: - # Importing custom functions and variables from a formating python script in the language_variables folder. - self.langvar = importlib.import_module(f"language_variables.{language}") + # Importing custom functions and variables + # from a formating python script in the language_variables folder. + self.langvar = importlib.import_module( + f"src.language_variables.{langvar_filename}" + ) except Exception as e: raise ValueError(f"Language file for '{language}' not found.") - def entity_to_text(self, entity, properties=None): + def get_label(self, id, labels=None): + """Retrieves the label for a Wikidata entity in a specified language. + + Args: + id (str): QID or PID from the ID column in the Wikidata db or JSON. + labels (dict, optional): Wikidata labels in all available languages, else None. Defaults to None. + + Returns: + str: Wikidata label from specified language or mul[tilingual]. """ - Converts a Wikidata entity into a human-readable text string. + if (labels is None) or (len(labels) == 0): + # If the labels are not provided, fetch them from the Wikidata SQLDB + # TODO: Fetch from the Wikidata API if not found in the SQLDB + labels = WikidataItem.get_labels(id) - Parameters: - - entity (WikidataEntity): A WikidataEntity object containing entity data (label, description, claims, etc.). - - properties (dict or None): A dictionary of properties (claims). If None, the properties will be derived from entity.claims. + if isinstance(labels, str): + # If the labels are a string, return them as is + return labels + + # Take the label from the language, if missing take it + # from the multiligual class + + label = labels.get(self.language) + if label is None: + label = labels.get('mul') + + if isinstance(label, dict): + label = label.get('value') + + return label + + def get_description(self, id, descriptions=None): + """Retrieves the description for a Wikidata entity + in the specified language. 
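An illustrative view of the language fallback implemented in get_label above; get_description below follows the same pattern:

labels = {'de': {'value': 'Erde'}, 'mul': {'value': 'Earth'}}
# WikidataTextifier(language='en').get_label('Q2', labels) -> 'Earth'  (falls back to 'mul')
# WikidataTextifier(language='de').get_label('Q2', labels) -> 'Erde'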
+ + Args: + id (str): QID or PID from the ID column in the Wikidata db or JSON. + descriptions (dict, optional): Wikidata descriptions in all available languages, else None. Defaults to None. + + Returns: + str: Wikidata description from specified language or mul[tilingual]. + """ + if (descriptions is None) or (len(descriptions) == 0): + # If the descriptions are not provided, fetch them from the Wikidata SQLDB + # TODO: Fetch from the Wikidata API if not found in the SQLDB + descriptions = WikidataItem.get_descriptions(id) + + if isinstance(descriptions, str): + return descriptions + + + # Take the description from the language, + # if missing take it from the multiligual class + description = descriptions.get(self.language) + if description is None: + description = descriptions.get('mul') + + if isinstance(description, dict): + description = description.get('value') + + return description + + + def get_aliases(self, aliases): + """Retrieves the aliases for a Wikidata entity in the specified language. + + Args: + aliases (dict, optional): Wikidata aliases in all available languages, else None. Defaults to None. + + Returns: + list: Wikidata aliases from specified language and mul[tilingual]. + """ + if (type(aliases) is list): + return aliases + + if aliases is None: + return [] + + # Combine the aliases from the specified language and the + # multilingual class. Use set format to avoid duplicates. + aliases = set() + if self.language in aliases: + aliases.update([x['value'] for x in aliases[self.language]]) + + if 'mul' in aliases: + aliases.update([x['value'] for x in aliases['mul']]) + + return list(aliases) + + def entity_to_text(self, entity, properties=None): + """Converts a Wikidata entity into a human-readable text string. + + Args: + entity (obj): A Wikidata entity object containing + entity data (label, description, claims, etc.) + properties (dict or None, optional): A dictionary of + properties (claims). If None, the properties will be derived + from entity.claims. Defaults to None. Returns: - - str: A human-readable representation of the entity, its description, aliases, and claims. + (str): A human-readable representation of the entity, its description, aliases, and claims. """ if properties is None: + # If properties are not provided, fetch them from the entity properties = self.properties_to_dict(entity.claims) - return self.langvar.merge_entity_text(entity.label, entity.description, entity.aliases, properties) + # Get the label, description, and aliases for the entity + label = self.get_label(entity.id, labels=entity.label) + + description = self.get_description( + entity.id, + descriptions=entity.description + ) + if (description is None) or (len(description) == 0): + # If the description is missing, try to get it + # from the `instance_of` property + instanceof = self.get_label('P31') + description = properties.get(instanceof, '') + + aliases = self.get_aliases(entity.aliases) + + # Merge the label, description, aliases, and properties into a single + # text string as the Data Model per language through langvar descriptors + return self.langvar.merge_entity_text( + label, + description, + aliases, + properties + ) def properties_to_dict(self, properties): """ Converts a dictionary of properties (claims) into a dict suitable for text generation. Parameters: - - properties (dict): A dictionary of claims keyed by property IDs. - Each value is a list of claim statements for that property. + - properties (dict): A dictionary of claims keyed by property IDs. 
+ Each value is a list of claim statements for that property. Returns: - - dict: A dictionary mapping property labels to a list of their parsed values (and qualifiers). + - dict: A dictionary mapping property labels to a list of + their parsed values (and qualifiers). """ properties_dict = {} for pid, claim in properties.items(): @@ -55,23 +181,35 @@ def properties_to_dict(self, properties): rank_preferred_found = False for c in claim: - value = self.mainsnak_to_value(c.get('mainsnak', c)) - qualifiers = self.qualifiers_to_dict(c.get('qualifiers', {})) - rank = c.get('rank', 'normal').lower() - - # Only store "normal" ranks. if one "preferred" rank exists, then only store "preferred" ranks. - if value: - if ((not rank_preferred_found) and (rank == 'normal')) or (rank == 'preferred'): - if (not rank_preferred_found) and (rank == 'preferred'): - rank_preferred_found = True - p_data = [] - - p_data.append({'value': value, 'qualifiers': qualifiers}) + try: + value = self.mainsnak_to_value(c.get('mainsnak', c)) + qualifiers = self.qualifiers_to_dict( + c.get('qualifiers', {}) + ) + rank = c.get('rank', 'normal').lower() + + if value is None: + p_data = None + break + + elif len(value) > 0: + # If a preferred rank exists, include values that are + # only preferred. Else include only values that are + # ranked normal (values with a depricated rank are + # never included) + if ((not rank_preferred_found) and (rank == 'normal')) or (rank == 'preferred'): + if (not rank_preferred_found) and (rank == 'preferred'): + rank_preferred_found = True + p_data = [] + + p_data.append({'value': value, 'qualifiers': qualifiers}) + except Exception as e: + print(c) + raise e - if len(p_data) > 0: - property = WikidataEntity.get_entity(pid) - if property: - properties_dict[property.label] = p_data + label = self.get_label(pid, claim[0].get('mainsnak', {}).get('property-labels')) + if label: + properties_dict[label] = p_data return properties_dict @@ -80,7 +218,7 @@ def qualifiers_to_dict(self, qualifiers): Converts qualifiers into a dictionary suitable for text generation. Parameters: - - qualifiers (dict): A dictionary of qualifiers keyed by property IDs. + - qualifiers (dict): A dictionary of qualifiers keyed by property IDs. Each value is a list of qualifier statements. Returns: @@ -92,13 +230,16 @@ def qualifiers_to_dict(self, qualifiers): for q in qualifier: value = self.mainsnak_to_value(q) - if value: + if value is None: + q_data = None + break + elif len(value) > 0: q_data.append(value) - if len(q_data) > 0: - property = WikidataEntity.get_entity(pid) - if property: - qualifier_dict[property.label] = q_data + label = self.get_label(pid, qualifier[0].get('property-labels')) + if label: + qualifier_dict[label] = q_data + return qualifier_dict def mainsnak_to_value(self, mainsnak): @@ -109,42 +250,55 @@ def mainsnak_to_value(self, mainsnak): - mainsnak (dict): A snak object containing the value and datatype information. Returns: - - str or None: A string representation of the value, or None if parsing fails. + - str or None: A string representation of the value. If the returned string is empty, the value is discarded from the text, and If None i retured, then the whole property is discarded. 
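Two illustrative snaks (structures simplified) and the values mainsnak_to_value is expected to yield for language='en':

snak_text = {'snaktype': 'value', 'datatype': 'monolingualtext',
             'datavalue': {'value': {'text': 'Hello', 'language': 'en'}}}
snak_extid = {'snaktype': 'value', 'datatype': 'external-id',
              'datavalue': {'value': '0000-0001-2345-6789'}}
# textifier.mainsnak_to_value(snak_text)   -> 'Hello'
# textifier.mainsnak_to_value(snak_extid)  -> None  (the whole property is dropped)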
""" - if mainsnak.get('snaktype', '') == 'value': - if (mainsnak.get('datatype', '') == 'wikibase-item') or (mainsnak.get('datatype', '') == 'wikibase-property'): - entity_id = mainsnak['datavalue']['value']['id'] - entity = WikidataEntity.get_entity(entity_id) - if entity is None: - return None + # Extract the datavalue + snaktype = mainsnak.get('snaktype', 'value') + datavalue = mainsnak.get('datavalue') + if (datavalue is not None) and (type(datavalue) is not str): + datavalue = datavalue.get('value', datavalue) + + # Consider missing values + if (snaktype != 'value') or (datavalue is None): + return self.langvar.novalue - text = entity.label - return text + # If the values is based on a language, only consider the language that matched the text representation language. + elif (type(datavalue) is dict) and ('language' in datavalue) and (datavalue['language'] != self.language): + return '' - elif mainsnak.get('datatype', '') == 'monolingualtext': - return mainsnak['datavalue']['value']['text'] + elif (mainsnak.get('datatype', '') == 'wikibase-item') or (mainsnak.get('datatype', '') == 'wikibase-property'): + if type(datavalue) is str: + return self.get_label(datavalue) - elif mainsnak.get('datatype', '') == 'string': - return mainsnak['datavalue']['value'] + entity_id = datavalue['id'] + label = self.get_label(entity_id, datavalue.get('labels')) + return label - elif mainsnak.get('datatype', '') == 'time': - try: - return self.time_to_text(mainsnak['datavalue']['value']) - except Exception as e: - print("Error in time formating:", e) - return mainsnak['datavalue']['value']["time"] + elif mainsnak.get('datatype', '') == 'monolingualtext': + return datavalue.get('text', datavalue) - elif mainsnak.get('datatype', '') == 'quantity': - try: - return self.quantity_to_text(mainsnak['datavalue']['value']) - except Exception as e: - print(e) - return mainsnak['datavalue']['value']['amount'] + elif mainsnak.get('datatype', '') == 'string': + return datavalue - elif mainsnak.get('snaktype', '') == 'novalue': - return self.langvar.novalue + elif mainsnak.get('datatype', '') == 'time': + try: + return self.time_to_text(datavalue) + except Exception as e: + print("Error in time formating:", e) + return datavalue["time"] + + elif mainsnak.get('datatype', '') == 'quantity': + try: + return self.quantity_to_text(datavalue) + except Exception as e: + print(e) + return datavalue['amount'] - return None + elif mainsnak.get('datatype', '') == 'external-id': + return None + + else: + return '' def quantity_to_text(self, quantity_data): """ @@ -156,6 +310,9 @@ def quantity_to_text(self, quantity_data): Returns: - str: A textual representation of the quantity (e.g., "5 kg"). """ + if quantity_data is None: + return None + quantity = quantity_data.get('amount') unit = quantity_data.get('unit') @@ -164,9 +321,7 @@ def quantity_to_text(self, quantity_data): unit = None else: unit_qid = unit.rsplit('/')[-1] - entity = WikidataEntity.get_entity(unit_qid) - if entity: - unit = entity.label + unit = self.get_label(unit_qid, quantity_data.get('unit-labels')) return quantity + (f" {unit}" if unit else "") @@ -180,6 +335,9 @@ def time_to_text(self, time_data): Returns: - str: A textual representation of the time with appropriate granularity. 
""" + if time_data is None: + return None + time_value = time_data['time'] precision = time_data['precision'] calendarmodel = time_data.get('calendarmodel', 'http://www.wikidata.org/entity/Q1985786') @@ -305,7 +463,7 @@ def chunk_text(self, entity, tokenizer, max_length=500): Splits a text representation of an entity into smaller chunks so that each chunk fits within the token limit of a given tokenizer. Parameters: - - entity (WikidataEntity): The entity to be textified and chunked. + - entity: The entity to be textified and chunked. - tokenizer: A tokenizer (e.g. from Hugging Face) used to count tokens. - max_length (int): The maximum number of tokens allowed per chunk (default is 500). diff --git a/src/wikidataItemDB.py b/src/wikidataItemDB.py new file mode 100644 index 0000000..dd86e0e --- /dev/null +++ b/src/wikidataItemDB.py @@ -0,0 +1,400 @@ +from sqlalchemy import Column, Text, String, Integer, create_engine, text +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from sqlalchemy.types import TypeDecorator, Boolean +import json +import re + +""" +SQLite database setup for storing Wikidata labels & descriptions +in all languages. +""" + +SQLITEDB_PATH = '../data/Wikidata/sqlite_wikidata_items.db' +engine = create_engine(f'sqlite:///{SQLITEDB_PATH}', + pool_size=5, # Limit the number of open connections + max_overflow=10, # Allow extra connections beyond pool_size + pool_recycle=10 # Recycle connections every 10 seconds +) + +Base = declarative_base() +Session = sessionmaker(bind=engine) + +class JSONType(TypeDecorator): + """Custom SQLAlchemy type for JSON storage in SQLite.""" + impl = Text + cache_ok = False + + def process_bind_param(self, value, dialect): + if value is not None: + return json.dumps(value, separators=(',', ':')) + return None + + def process_result_value(self, value, dialect): + if value is not None: + return json.loads(value) + return None + +class WikidataItem(Base): + """ Represents a Wikidata entity's labels in multiple languages.""" + + __tablename__ = 'item' + + # TODO: convert ID to Integer and store existin IDs as qpid + """ + id = Column(Integer, primary_key=True) + qpid = Column(String, unique=True, index=True) + """ + id = Column(Text, primary_key=True) + labels = Column(JSONType) + descriptions = Column(JSONType) + in_wikipedia = Column(Boolean, default=False) + + @staticmethod + def add_bulk_items(data): + """ + Insert multiple label records in bulk. If a record with the same ID exists, + it is ignored (no update is performed). + + Parameters: + - data (list[dict]): A list of dictionaries, each containing 'id', 'labels', 'descriptions', and 'in_wikipedia' keys. + + Returns: + - bool: True if the operation was successful, False otherwise. + """ + worked = False # Assume the operation failed + with Session() as session: + try: + # Use a text statement to operate bulk insert + # SQLAlchemy's ORM is unable to handle bulk inserts + # with ON CONFLICT. + + insert_stmt = text( + """ + INSERT INTO item (id, labels, descriptions, in_wikipedia) + VALUES (:id, :labels, :descriptions, :in_wikipedia) + ON CONFLICT(id) DO NOTHING + """ + ) + + # Execute the insert statement for each data entry. 
+ session.execute(insert_stmt, data) + session.commit() + session.flush() + worked = True # Mark the operation as successful + except Exception as e: + session.rollback() + print(e) + + return worked # Return the operation status + + @staticmethod + def add_labels(id, labels, descriptions, in_wikipedia): + """ + Insert a single label record into the database. + + Parameters: + - id (str): The unique identifier for the entity. + - labels (dict): A dictionary of labels (e.g. { "en": "Label in English", "fr": "Label in French", ... }). + + Returns: + - bool: True if the operation was successful, False otherwise. + """ + worked = False + with Session() as session: + try: + new_entry = WikidataItem( + id=id, + labels=labels, + descriptions=descriptions, + in_wikipedia=in_wikipedia + ) + session.add(new_entry) + session.commit() + session.flush() + worked = True + except Exception as e: + session.rollback() + print(f"Error: {e}") + return worked + + @staticmethod + def get_labels(id): + """ + Retrieve labels for a given entity by its ID. + + Parameters: + - id (str): The unique identifier of the entity. + + Returns: + - dict: The labels dictionary if found, otherwise an empty dict. + """ + with Session() as session: + item = session.query(WikidataItem).filter_by(id=id).first() + if item is not None: + return item.labels + return {} + + @staticmethod + def get_descriptions(id): + """ + Retrieve labels for a given entity by its ID. + + Parameters: + - id (str): The unique identifier of the entity. + + Returns: + - dict: The labels dictionary if found, otherwise an empty dict. + """ + with Session() as session: + item = session.query(WikidataItem).filter_by(id=id).first() + if item is not None: + return item.descriptions + return {} + + @staticmethod + def get_item(id): + """ + Retrieve item for a given entity by its ID. + + Parameters: + - id (str): The unique identifier of the entity. + + Returns: + - dict: The labels dictionary if found, otherwise an empty dict. + """ + with Session() as session: + item = session.query(WikidataItem).filter_by(id=id).first() + if item is not None: + return item + return {} + + @staticmethod + def clean_label_description(data): + clean_data = {} + for lang, label in data.items(): + clean_data[lang] = label['value'] + return clean_data + + @staticmethod + def is_in_wikipedia(entity): + """ + Check if a Wikidata entity has a corresponding Wikipedia entry in any language. + + Parameters: + - entity (dict): A Wikidata entity dictionary. + + Returns: + - bool: True if the entity has at least one sitelink ending in 'wiki', otherwise False. + """ + if ('sitelinks' in entity): + for s in entity['sitelinks']: + if s.endswith('wiki'): + return True + return False + + @staticmethod + def get_labels_list(id_list): + """ + Retrieve labels for multiple entities at once. + + Parameters: + - id_list (list[str]): A list of entity IDs. + + Returns: + - dict: A mapping of {entity_id: labels_dict} for each found ID. Missing IDs won't appear. + """ + with Session() as session: + rows = ( + session.query(WikidataItem.id, WikidataItem.labels) + .filter(WikidataItem.id.in_(id_list)) + .all() + ) + + return {row_id: row_labels for row_id, row_labels in rows if row_labels is not None} + + @staticmethod + def _remove_keys(data, keys_to_remove=['hash', 'property', 'numeric-id', 'qualifiers-order']): + """ + Recursively remove specific keys from a nested data structure. + + Parameters: + - data (dict or list): The data structure to clean. + - keys_to_remove (list): Keys to remove. 
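A sketch of the payload shape add_bulk_items appears to expect; because the raw INSERT above bypasses the JSONType decorator, labels and descriptions are pre-serialized here (an assumption based on the statement binding, not a documented contract):

import json

row = {
    'id': 'Q42',
    'labels': json.dumps({'en': 'Douglas Adams'}, separators=(',', ':')),
    'descriptions': json.dumps({'en': 'English writer'}, separators=(',', ':')),
    'in_wikipedia': True,
}
# WikidataItem.add_bulk_items([row])  -> True on success; duplicate ids are ignored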
Default includes 'hash', 'property', 'numeric-id', and 'qualifiers-order'. + + Returns: + - dict or list: The cleaned data structure with specified keys removed. + """ + if isinstance(data, dict): + data = {key: WikidataItem._remove_keys(value, keys_to_remove) for key, value in data.items() if key not in keys_to_remove} + elif isinstance(data, list): + data = [WikidataItem._remove_keys(item, keys_to_remove) for item in data] + return data + + @staticmethod + def _clean_datavalue(data): + """ + Remove unnecessary nested structures unless they match a Wikidata entity or property pattern. + + Parameters: + - data (dict or list): The data structure to clean. + + Returns: + - dict or list: The cleaned data. + """ + if isinstance(data, dict): + # If there's only one key and it's not a property or QID, recurse into it. + if (len(data.keys()) == 1) and not re.match(r"^[PQ]\d+$", list(data.keys())[0]): + data = WikidataItem._clean_datavalue(data[list(data.keys())[0]]) + else: + data = {key: WikidataItem._clean_datavalue(value) for key, value in data.items()} + elif isinstance(data, list): + data = [WikidataItem._clean_datavalue(item) for item in data] + return data + + @staticmethod + def _gather_labels_ids(data): + """ + Find and return all relevant Wikidata IDs (e.g., property, unit, or datavalue IDs) in the claims. + + Parameters: + - data (dict or list): The data structure to scan. + + Returns: + - list[str]: A list of discovered Wikidata IDs. + """ + ids = set() + + if isinstance(data, dict): + if 'property' in data: + ids.add(data['property']) + + if 'unit' in data and data['unit'] != '1': + unit_id = data['unit'].split('/')[-1] + ids.add(unit_id) + + datatype_in_data = 'datatype' in data + datavalue_in_data = 'datavalue' in data + data_datatype = data['datatype'] in ( + 'wikibase-item', 'wikibase-property' + ) + if datatype_in_data and datavalue_in_data and data_datatype: + ids.add(data['datavalue']) + + for value in data.values(): + sub_ids = WikidataItem._gather_labels_ids(value) + ids.update(sub_ids) + + elif isinstance(data, list): + for item in data: + sub_ids = WikidataItem._gather_labels_ids(item) + ids.update(sub_ids) + + return list(ids) + + @staticmethod + def _add_labels_to_claims(data, labels_dict={}): + """ + For each found ID (property, unit, or datavalue) within the claims, + insert the corresponding labels from labels_dict or the database. + + Parameters: + - data (dict or list): The claims data structure. + - labels_dict (dict): An optional dict of {id: labels} for quick lookup. + + Returns: + - dict or list: The updated data with added label information. 
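An illustrative before/after of the label injection documented above (labels are placeholders; key order may differ):

before = {'property': 'P31', 'datatype': 'wikibase-item', 'datavalue': 'Q5'}
after = {'property': 'P31', 'datatype': 'wikibase-item',
         'datavalue': {'id': 'Q5', 'labels': {'en': 'human'}},
         'property-labels': {'en': 'instance of'}}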
+ """ + if isinstance(data, dict): + if 'property' in data: + if data['property'] in labels_dict: + labels = labels_dict[data['property']] + else: + labels = WikidataItem.get_labels(data['property']) + + data = { + **data, + 'property-labels': labels + } + + if ('unit' in data) and (data['unit'] != '1'): + id = data['unit'].split('/')[-1] + if id in labels_dict: + labels = labels_dict[id] + else: + labels = WikidataItem.get_labels(id) + + data = { + **data, + 'unit-labels': labels + } + + if ('datatype' in data) and ('datavalue' in data) and ((data['datatype'] == 'wikibase-item') or (data['datatype'] == 'wikibase-property')): + if data['datavalue'] in labels_dict: + labels = labels_dict[data['datavalue']] + else: + labels = WikidataItem.get_labels(data['datavalue']) + + data['datavalue'] = { + 'id': data['datavalue'], + 'labels': labels + } + + data = {key: WikidataItem._add_labels_to_claims(value, labels_dict=labels_dict) for key, value in data.items()} + + elif isinstance(data, list): + data = [WikidataItem._add_labels_to_claims(item, labels_dict=labels_dict) for item in data] + + return data + + @staticmethod + def add_labels_batched(claims, query_batch=100): + """ + Gather all relevant IDs from claims, batch-fetch their labels, then add them to the claims structure. + + Parameters: + - claims (dict or list): The claims data structure to update. + - query_batch (int): The batch size for querying labels in groups. Default is 100. + + Returns: + - dict or list: The updated claims with labels inserted. + """ + label_ids = WikidataItem._gather_labels_ids(claims) + + labels_dict = {} + for i in range(0, len(label_ids), query_batch): + temp_dict = WikidataItem.get_labels_list(label_ids[i:i+query_batch]) + labels_dict = {**labels_dict, **temp_dict} + + claims = WikidataItem._add_labels_to_claims(claims, labels_dict=labels_dict) + return claims + + @staticmethod + def clean_entity(entity): + """ + Clean a Wikidata entity's data by removing unneeded keys and adding label info to claims. + + Parameters: + - entity (dict): A Wikidata entity dictionary containing 'claims', 'labels', 'sitelinks', etc. + + Returns: + - dict: The cleaned entity with label data integrated into its claims. + """ + clean_claims = WikidataItem._remove_keys(entity.get('claims', {}), ['hash', 'snaktype', 'type', 'entity-type', 'numeric-id', 'qualifiers-order', 'snaks-order']) + clean_claims = WikidataItem._clean_datavalue(clean_claims) + clean_claims = WikidataItem._remove_keys(clean_claims, ['id']) + clean_claims = WikidataItem.add_labels_batched(clean_claims) + + sitelinks = WikidataItem._remove_keys(entity.get('sitelinks', {}), ['badges']) + + return { + 'id': entity['id'], + 'labels': WikidataItem.clean_label_description(entity['labels']), + 'descriptions': WikidataItem.clean_label_description(entity['descriptions']), + 'aliases': entity['aliases'], + 'sitelinks': sitelinks, + 'claims': clean_claims + } + +# Create tables if they don't already exist. 
+Base.metadata.create_all(engine) \ No newline at end of file diff --git a/src/wikidataDB.py b/src/wikidataLangDB.py similarity index 55% rename from src/wikidataDB.py rename to src/wikidataLangDB.py index b97478e..30e2ce3 100644 --- a/src/wikidataDB.py +++ b/src/wikidataLangDB.py @@ -32,7 +32,7 @@ def process_result_value(self, value, dialect): return json.loads(value) return None -class WikidataEntity(Base): +class WikidataLang(Base): """Represents a Wikidata entity with label, description, aliases, and claims.""" __tablename__ = 'wikidata' @@ -93,7 +93,7 @@ def add_entity(id, label, description, claims, aliases): worked = False with Session() as session: try: - new_entry = WikidataEntity( + new_entry = WikidataLang( id=id, label=label, description=description, @@ -121,7 +121,24 @@ def get_entity(id): - WikidataEntity or None: The entity object if found, otherwise None. """ with Session() as session: - return session.query(WikidataEntity).filter_by(id=id).first() + return session.query(WikidataLang).filter_by(id=id).first() + + @staticmethod + def is_in_wikipedia(item, language='en'): + """ + Check if a Wikidata item has a corresponding Wikipedia entry. + + Parameters: + - item (dict): The Wikidata item. + - language (str): The Wikipedia language code. Default is 'en'. + + Returns: + - bool: True if the item has a Wikipedia sitelink and label/description in the specified language or 'mul'. + """ + condition = ('sitelinks' in item) and (f'{language}wiki' in item['sitelinks']) # Has an Wikipedia Sitelink + condition = condition and ((language in item['labels']) or ('mul' in item['labels'])) # Has a label with the corresponding language or multiligual + condition = condition and ((language in item['descriptions']) or ('mul' in item['descriptions'])) # Has a description with the corresponding language or multiligual + return condition @staticmethod def normalise_item(item, language='en'): @@ -137,8 +154,8 @@ def normalise_item(item, language='en'): """ label = item['labels'][language]['value'] if (language in item['labels']) else (item['labels']['mul']['value'] if ('mul' in item['labels']) else '') # Take the label from the language, if missing take it from the multiligual class description = item['descriptions'][language]['value'] if (language in item['descriptions']) else (item['descriptions']['mul']['value'] if ('mul' in item['descriptions']) else '') # Take the description from the language, if missing take it from the multiligual class - aliases = WikidataEntity._get_aliases(item, language=language) - claims = WikidataEntity._get_claims(item) + aliases = WikidataLang._get_aliases(item, language=language) + claims = WikidataLang._get_claims(item) return { 'id': item['id'], 'label': label, @@ -160,9 +177,9 @@ def _remove_keys(data, keys_to_remove=['hash', 'property', 'numeric-id', 'qualif - dict or list: The cleaned data structure with specified keys removed. 
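A minimal item that passes WikidataLang.is_in_wikipedia(item, language='en'), trimmed to the fields the check looks at:

item = {'sitelinks': {'enwiki': {'title': 'Douglas Adams'}},
        'labels': {'en': {'value': 'Douglas Adams'}},
        'descriptions': {'mul': {'value': 'English author'}}}
# -> True: an 'enwiki' sitelink plus a label and a description in 'en' or 'mul'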
""" if isinstance(data, dict): - return {key: WikidataEntity._remove_keys(value, keys_to_remove) for key, value in data.items() if key not in keys_to_remove} + return {key: WikidataLang._remove_keys(value, keys_to_remove) for key, value in data.items() if key not in keys_to_remove} elif isinstance(data, list): - return [WikidataEntity._remove_keys(item, keys_to_remove) for item in data] + return [WikidataLang._remove_keys(item, keys_to_remove) for item in data] else: return data @@ -184,8 +201,8 @@ def _get_claims(item): for i in x: if (i['type'] == 'statement') and (i['rank'] != 'deprecated'): pid_claims.append({ - 'mainsnak': WikidataEntity._remove_keys(i['mainsnak']) if 'mainsnak' in i else {}, - 'qualifiers': WikidataEntity._remove_keys(i['qualifiers']) if 'qualifiers' in i else {}, + 'mainsnak': WikidataLang._remove_keys(i['mainsnak']) if 'mainsnak' in i else {}, + 'qualifiers': WikidataLang._remove_keys(i['qualifiers']) if 'qualifiers' in i else {}, 'rank': i['rank'] }) if len(pid_claims) > 0: @@ -211,158 +228,5 @@ def _get_aliases(item, language='en'): aliases = aliases | set([x['value'] for x in item['aliases']['mul']]) return list(aliases) -class WikidataID(Base): - """ Represents an ID record in the database, indicating whether it appears in Wikipedia or is a property. """ - - __tablename__ = 'wikidataID' - - id = Column(Text, primary_key=True) - in_wikipedia = Column(Boolean, default=False) - is_property = Column(Boolean, default=False) - - @staticmethod - def add_bulk_ids(data): - """ - Add multiple IDs to the database in bulk. If an ID exists, update its boolean fields. - - Parameters: - - data (list[dict]): A list of dictionaries with 'id', 'in_wikipedia', and 'is_property' fields. - - Returns: - - bool: True if successful, False otherwise. - """ - worked = False - with Session() as session: - try: - session.execute( - text( - """ - INSERT INTO wikidataID (id, in_wikipedia, is_property) - VALUES (:id, :in_wikipedia, :is_property) - ON CONFLICT(id) DO UPDATE - SET - in_wikipedia = CASE WHEN excluded.in_wikipedia = TRUE THEN excluded.in_wikipedia ELSE wikidataID.in_wikipedia END, - is_property = CASE WHEN excluded.is_property = TRUE THEN excluded.is_property ELSE wikidataID.is_property END - """ - ), - data - ) - session.commit() - session.flush() - worked = True - except Exception as e: - session.rollback() - print(e) - return worked - - @staticmethod - def add_id(id, in_wikipedia=False, is_property=False): - """ - Add a single ID record to the database. - - Parameters: - - id (str): The unique identifier. - - in_wikipedia (bool): Whether the entity is in Wikipedia. Default is False. - - is_property (bool): Whether the entity is a property. Default is False. - - Returns: - - bool: True if successful, False otherwise. - """ - worked = False - with Session() as session: - try: - new_entry = WikidataID(id=id, in_wikipedia=in_wikipedia, is_property=is_property) - session.add(new_entry) - session.commit() - session.flush() - worked = True - except Exception as e: - session.rollback() - print(e) - return worked - - @staticmethod - def get_id(id): - """ - Retrieve a record by its ID. - - Parameters: - - id (str): The unique identifier of the record. - - Returns: - - WikidataID or None: The record if found, otherwise None. - """ - with Session() as session: - return session.query(WikidataID).filter_by(id=id).first() - - @staticmethod - def is_in_wikipedia(item, language='en'): - """ - Check if a Wikidata item has a corresponding Wikipedia entry. 
- - Parameters: - - item (dict): The Wikidata item. - - language (str): The Wikipedia language code. Default is 'en'. - - Returns: - - bool: True if the item has a Wikipedia sitelink and label/description in the specified language or 'mul'. - """ - condition = ('sitelinks' in item) and (f'{language}wiki' in item['sitelinks']) # Has an Wikipedia Sitelink - condition = condition and ((language in item['labels']) or ('mul' in item['labels'])) # Has a label with the corresponding language or multiligual - condition = condition and ((language in item['descriptions']) or ('mul' in item['descriptions'])) # Has a description with the corresponding language or multiligual - return condition - - @staticmethod - def extract_entity_ids(item, language='en'): - """ - Extract entity and property IDs from a Wikidata item (including claims, qualifiers, and units). - - Parameters: - - item (dict): The Wikidata item. - - language (str): The language code for additional checks. Default is 'en'. - - Returns: - - list[dict]: A list of dictionaries with 'id', 'in_wikipedia', and 'is_property' for each discovered ID. - """ - if item is None: - return [] - - batch_ids = [{'id': item['id'], 'in_wikipedia': WikidataID.is_in_wikipedia(item, language=language), 'is_property': False}] - - for pid,claim in item.get('claims', {}).items(): - batch_ids.append({'id': pid, 'in_wikipedia': False, 'is_property': True}) - - for c in claim: - if ('mainsnak' in c) and ('datavalue' in c['mainsnak']): - if (c['mainsnak'].get('datatype', '') == 'wikibase-item'): - id = c['mainsnak']['datavalue']['value']['id'] - batch_ids.append({'id': id, 'in_wikipedia': False, 'is_property': False}) - - elif (c['mainsnak'].get('datatype', '') == 'wikibase-property'): - id = c['mainsnak']['datavalue']['value']['id'] - batch_ids.append({'id': id, 'in_wikipedia': False, 'is_property': True}) - - elif (c['mainsnak'].get('datatype', '') == 'quantity') and (c['mainsnak']['datavalue']['value'].get('unit', '1') != '1'): - id = c['mainsnak']['datavalue']['value']['unit'].rsplit('/', 1)[1] - batch_ids.append({'id': id, 'in_wikipedia': False, 'is_property': False}) - - if 'qualifiers' in c: - for pid, qualifier in c['qualifiers'].items(): - batch_ids.append({'id': pid, 'in_wikipedia': False, 'is_property': True}) - for q in qualifier: - if ('datavalue' in q): - if (q['datatype'] == 'wikibase-item'): - id = q['datavalue']['value']['id'] - batch_ids.append({'id': id, 'in_wikipedia': False, 'is_property': False}) - - elif(q['datatype'] == 'wikibase-property'): - id = q['datavalue']['value']['id'] - batch_ids.append({'id': id, 'in_wikipedia': False, 'is_property': True}) - - elif (q['datatype'] == 'quantity') and (q['datavalue']['value'].get('unit', '1') != '1'): - id = q['datavalue']['value']['unit'].rsplit('/', 1)[1] - batch_ids.append({'id': id, 'in_wikipedia': False, 'is_property': False}) - return batch_ids - # Create tables if they don't already exist. 
Base.metadata.create_all(engine) \ No newline at end of file diff --git a/src/wikidataRetriever.py b/src/wikidataRetriever.py index 0782c01..752368f 100644 --- a/src/wikidataRetriever.py +++ b/src/wikidataRetriever.py @@ -1,27 +1,30 @@ -from langchain_astradb import AstraDBVectorStore -from langchain_core.documents import Document -from astrapy.info import CollectionVectorServiceOptions -from transformers import AutoTokenizer -import requests -from JinaAI import JinaAIEmbedder import time -from elasticsearch import Elasticsearch - -from mediawikiapi import MediaWikiAPI -from mediawikiapi.config import Config +import json +from src.wikidataCache import create_cache_embedding_model class AstraDBConnect: - def __init__(self, datastax_token, collection_name, model='nvidia', batch_size=8, cache_embeddings=False): + def __init__( + self, datastax_token, collection_name, model='jina', + batch_size=8, cache_embeddings=None): """ Initialize the AstraDBConnect object with the corresponding embedding model. Parameters: - datastax_token (dict): Credentials for DataStax Astra, including token and API endpoint. - collection_name (str): Name of the collection (table) where data is stored. - - model (str): The embedding model to use ("nvidia" or "jina"). Default is 'nvidia'. + - model (str): The embedding model to use. Default is 'jina'. - batch_size (int): Number of documents to accumulate before pushing to AstraDB. Default is 8. - - cache_embeddings (bool): Whether to cache embeddings when using the Jina model. Default is False. + - cache_embeddings (str): Name of the cache table. """ + from langchain_astradb import AstraDBVectorStore + from astrapy.info import CollectionVectorServiceOptions + from astrapy import DataAPIClient + from astrapy.exceptions import InsertManyException + from multiprocessing import Queue + + from transformers import AutoTokenizer + from src.JinaAI import JinaAIEmbedder, JinaAIAPIEmbedder + ASTRA_DB_APPLICATION_TOKEN = datastax_token['ASTRA_DB_APPLICATION_TOKEN'] ASTRA_DB_API_ENDPOINT = datastax_token["ASTRA_DB_API_ENDPOINT"] ASTRA_DB_KEYSPACE = datastax_token["ASTRA_DB_KEYSPACE"] @@ -29,8 +32,16 @@ def __init__(self, datastax_token, collection_name, model='nvidia', batch_size=8 self.batch_size = batch_size self.model = model self.collection_name = collection_name - self.doc_batch = [] - self.id_batch = [] + self.doc_batch = Queue() + self.InsertManyException = InsertManyException + + self.cache_on = (cache_embeddings is not None) + if self.cache_on: + self.cache_model = create_cache_embedding_model(cache_embeddings) + + client = DataAPIClient(datastax_token['ASTRA_DB_APPLICATION_TOKEN']) + database0 = client.get_database(datastax_token['ASTRA_DB_API_ENDPOINT']) + self.graph_store = database0.get_collection(collection_name) if model == 'nvidia': self.tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-large-unsupervised', trust_remote_code=True, clean_up_tokenization_spaces=False) @@ -41,7 +52,7 @@ def __init__(self, datastax_token, collection_name, model='nvidia', batch_size=8 model_name="NV-Embed-QA" ) - self.graph_store = AstraDBVectorStore( + self.vector_search = AstraDBVectorStore( collection_name=collection_name, collection_vector_service_options=collection_vector_service_options, token=ASTRA_DB_APPLICATION_TOKEN, @@ -49,13 +60,26 @@ def __init__(self, datastax_token, collection_name, model='nvidia', batch_size=8 namespace=ASTRA_DB_KEYSPACE, ) elif model == 'jina': - embeddings = JinaAIEmbedder(embedding_dim=1024, cache=cache_embeddings) - self.tokenizer = 
embeddings.tokenizer + self.embeddings = JinaAIEmbedder(embedding_dim=1024) + self.tokenizer = self.embeddings.tokenizer + self.max_token_size = 1024 + + self.vector_search = AstraDBVectorStore( + collection_name=collection_name, + embedding=self.embeddings, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + ) + + elif model == 'jinaapi': + self.embeddings = JinaAIAPIEmbedder(embedding_dim=1024) + self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True) self.max_token_size = 1024 - self.graph_store = AstraDBVectorStore( + self.vector_search = AstraDBVectorStore( collection_name=collection_name, - embedding=embeddings, + embedding=self.embeddings, token=ASTRA_DB_APPLICATION_TOKEN, api_endpoint=ASTRA_DB_API_ENDPOINT, namespace=ASTRA_DB_KEYSPACE, @@ -72,36 +96,71 @@ def add_document(self, id, text, metadata): - text (str): The text content of the document. - metadata (dict): Additional metadata about the document. """ - doc = Document(page_content=text, metadata=metadata) - self.doc_batch.append(doc) - self.id_batch.append(id) + doc = { + '_id': id, + 'content': text, + 'metadata': metadata + } + self.doc_batch.put(doc) # If we reach the batch size, push the accumulated documents to AstraDB - if len(self.doc_batch) >= self.batch_size: + if self.doc_batch.qsize() >= self.batch_size: self.push_batch() def push_batch(self): """ Push the current batch of documents to AstraDB for storage. - Retries automatically if a connection issue occurs, waiting for - an active internet connection. + Caches the embeddings into a SQLite database. """ + if self.doc_batch.empty(): + return False + + docs = [] + for _ in range(self.batch_size): + try: + doc = self.doc_batch.get_nowait() + cache = self._get_cached_embedding(doc['_id']) + if cache is None: + docs.append(doc) + except Exception: + # Queue is empty + break + + if len(docs) == 0: + return False + while True: try: - self.graph_store.add_documents(self.doc_batch, ids=self.id_batch) - self.doc_batch = [] - self.id_batch = [] + vectors = self.embeddings.embed_documents( + [doc['content'] for doc in docs] + ) break except Exception as e: print(e) - while True: - try: - response = requests.get("https://www.google.com", timeout=5) - if response.status_code == 200: - break - except Exception as e: - print("Waiting for internet connection...") + time.sleep(3) + + while True: + try: + self.graph_store.insert_many(docs, vectors=vectors) + break + except self.InsertManyException as e: + break + except Exception as e: + print(e) + time.sleep(3) + + if self.cache_on: + self.cache_model.add_bulk_cache([{ + 'id': docs[i]['_id'], + 'embedding': vectors[i]} + for i in range(len(docs))]) + + return True + + def push_all(self): + while True: + if not self.push_batch(): # Stop when batch is empty + break def get_similar_qids(self, query, filter={}, K=50): """ @@ -117,21 +176,10 @@ def get_similar_qids(self, query, filter={}, K=50): where list_of_qids are the QIDs of the results and list_of_scores are the corresponding similarity scores.
""" - while True: - try: - results = self.graph_store.similarity_search_with_relevance_scores(query, k=K, filter=filter) - qid_results = [r[0].metadata['QID'] for r in results] - score_results = [r[1] for r in results] - return qid_results, score_results - except Exception as e: - print(e) - while True: - try: - response = requests.get("https://www.google.com", timeout=5) - if response.status_code == 200: - break - except Exception as e: - time.sleep(5) + results = self.vector_search.similarity_search_with_relevance_scores(query, k=K, filter=filter) + qid_results = [r[0].metadata['QID']+"_"+r[0].metadata.get('Language', '') for r in results] + score_results = [r[1] for r in results] + return qid_results, score_results def batch_retrieve_comparative(self, queries_batch, comparative_batch, K=50, Language=None): """ @@ -186,7 +234,33 @@ def batch_retrieve(self, queries_batch, K=50, Language=None): qids, scores = zip(*results) return list(qids), list(scores) -class WikidataKeywordSearch: + def _cache_embedding(self, id, embedding): + """ + Caches the text and its embedding in the SQLite database. + + Parameters: + - text (str): The text string. + - embedding (List[float]): The embedding vector for the text. + """ + if self.cache_on: + embedding = embedding.tolist() + self.cache_model.add_cache(id=id, embedding=embedding) + + def _get_cached_embedding(self, id): + """ + Retrieves a previously cached embedding for the specified text. + + Parameters: + - text (str): The text string. + + Returns: + - List[float] or None: The embedding if found in cache, otherwise None. + """ + if self.cache_on: + return self.cache_model.get_cache(id=id) + return None + +class KeywordSearchConnect: def __init__(self, url, index_name = 'wikidata'): """ Initialize the WikidataKeywordSearch object with an Elasticsearch instance. @@ -195,97 +269,151 @@ def __init__(self, url, index_name = 'wikidata'): - url (str): URL (host) of the Elasticsearch server. - index_name (str): Name of the Elasticsearch index. Default is 'wikidata'. """ + from elasticsearch import Elasticsearch + self.index_name = index_name self.es = Elasticsearch(url) + self.create_index() - # Create the index if it doesn't already exist + def create_index(self): + """ + Create the index with appropriate settings and mappings to optimize search. + """ if not self.es.indices.exists(index=self.index_name): self.es.indices.create(index=self.index_name, body={ - "mappings": { - "properties": { - "text": { - "type": "text" + "settings": { + "analysis": { + "analyzer": { + "rebuilt_standard": { + "tokenizer": "standard", + "filter": ["lowercase", "stop"] + } + } + } + }, + "mappings": { + "properties": { + "text": { + "type": "text", + "analyzer": "default" + }, + "metadata": { + "type": "object", + "properties": { + "QID": {"type": "keyword"}, + "Language": {"type": "keyword"}, + "Date": {"type": "keyword"} + } + } } } - } - }) + }) - def search(self, query, K=50): + def add_document(self, id, text, metadata): + """ + Add a document to the Elasticsearch index. """ - Perform a keyword-based search against the Elasticsearch index. + doc = { + 'text': text, + 'metadata': {'QID': metadata['QID'], 'Language': metadata['Language']} + } + try: + if self.es.exists(index=self.index_name, id=id): + return + self.es.index(index=self.index_name, id=id, body=doc) + except ConnectionError as e: + print("Connection error:", e) + time.sleep(1) - Parameters: - - query (str): The query string to match against document text. - - K (int): Number of top results to return. 
Default is 50. + def push_batch(self): + pass - Returns: - - list: A list of raw Elasticsearch hits, each containing a '_score' and '_source'. + def search(self, query, K=50): + """ + Perform a text search using Elasticsearch. + """ + search_body = { + "query": { + "match": { + "text": query + } + }, + "size": K + } + try: + response = self.es.search(index=self.index_name, body=search_body) + return [hit for hit in response['hits']['hits']] + except ConnectionError as e: + print("Connection error:", e) + return [] + + def get_similar_qids(self, query, filter=[], K=50): + """ + Retrieve documents based on similarity to a query, potentially with filtering. """ search_body = { "query": { "bool": { - "should": [ - { - "match": { - "text": { - "query": query, - "operator": "or", - "boost": 1.0 - } - } - }, - { - "match_all": { - "boost": 0.01 # Lower boost to make match_all results less relevant - } + "must": { + "match": { + "text": query } - ] + }, + "filter": filter } }, - "size": K, - "sort": [ - { - "_score": { - "order": "desc" - } - } - ] + "size": K } - response = self.es.search(index=self.index_name, body=search_body) - return [hit for hit in response['hits']['hits']] + try: + response = self.es.search(index=self.index_name, body=search_body) + qid_results = [hit['_source']['metadata']['QID'] for hit in response['hits']['hits']] + score_results = [hit['_score'] for hit in response['hits']['hits']] + return qid_results, score_results - def get_similar_qids(self, query, filter_qid={}, K=50): - """ - Retrieve QIDs based on a keyword-based search. Optionally filter by QID. - - Parameters: - - query (str): The search string. - - filter_qid (dict): Optional filter (currently unused, placeholder). - - K (int): Number of top results to return. Default is 50. + except Exception as e: + print("Search failed:", e) + return [] - Returns: - - tuple: (list_of_qids, list_of_scores) + def batch_retrieve_comparative(self, queries_batch, comparative_batch, K=50, Language=None): """ - results = self.search(query, K=K) - qid_results = [r['_id'].split("_")[0] for r in results] - score_results = [r['_score'] for r in results] - return qid_results, score_results - - def batch_retrieve(self, queries_batch, K=50): + Retrieve similar documents in a comparative fashion for each query and comparative item. """ - Perform keyword-based search for a batch of queries. + qids = [[] for _ in range(len(queries_batch))] + scores = [[] for _ in range(len(queries_batch))] - Parameters: - - queries_batch (pd.Series or list): A list or series of query strings. - - K (int): Number of top results to return for each query. Default is 50. + for i, query in enumerate(queries_batch): + for comp_col in comparative_batch.columns: + filter = [] + # Apply language filter if specified + if Language: + filter.append({"term": {"metadata.Language": Language}}) + # Apply QID filter specific to the comparative group + qid_filter = comparative_batch[comp_col].iloc[i] + filter.append({"term": {"metadata.QID": qid_filter}}) - Returns: - - tuple: (list_of_qid_lists, list_of_score_lists), each element corresponding to a single query. + result_qids, result_scores = self.get_similar_qids(query, filter=filter, K=K) + qids[i].extend(result_qids) + scores[i].extend(result_scores) + + return qids, scores + + def batch_retrieve(self, queries_batch, K=50, Language=None): """ - results = [ - self.get_similar_qids(queries_batch.iloc[i], K=K) - for i in range(len(queries_batch)) - ] + Perform batch searches and handle potential connection issues. 
+ """ + filter = [] + if Language: + languages = Language.split(',') + filter.append({"bool": {"should": [{"term": {"metadata.Language": lang}} for lang in languages]}}) - qids, scores = zip(*results) + results = [] + for query in queries_batch: + try: + result = self.get_similar_qids(query, K=K, filter=filter) + results.append(result) + except ConnectionError as e: + print("Connection error during batch processing:", e) + time.sleep(1) + + qids, scores = zip(*results) if results else ([], []) return list(qids), list(scores) \ No newline at end of file