From 0a13464ced59b5f648f611d0b45b6845ae4b1a85 Mon Sep 17 00:00:00 2001 From: Philippe Saade Date: Tue, 11 Mar 2025 15:15:12 +0100 Subject: [PATCH 01/49] Add prototype creation script --- docker-compose.yml | 72 +++- docker/1_Data_Processing_save_ids/run.py | 41 -- .../Dockerfile | 4 +- .../requirements.txt | 0 .../run.py | 51 +++ .../Dockerfile | 0 .../requirements.txt | 0 .../run.py | 22 +- .../requirements.txt | 3 +- docker/3_Add_Wikidata_to_AstraDB/run.py | 40 +- docker/4_Run_Retrieval/run.py | 13 +- docker/6_Push_Huggingface/run.py | 8 +- docker/7_Create_Prototype/Dockerfile | 30 ++ docker/7_Create_Prototype/requirements.txt | 22 + docker/7_Create_Prototype/run.py | 125 ++++++ run_experiments.sh | 12 + src/JinaAI.py | 164 ++++---- src/__init__.py | 8 +- src/experimental_functions/word_embeding.py | 64 +++ src/language_variables/ar.py | 16 +- src/language_variables/de.py | 28 +- src/language_variables/en.py | 28 +- src/language_variables/json.py | 85 +--- src/language_variables/rdf.py | 16 +- src/wikidataCache.py | 98 +++++ src/wikidataEmbed.py | 199 ++++++--- src/wikidataItemDB.py | 383 ++++++++++++++++++ src/{wikidataDB.py => wikidataLangDB.py} | 188 ++------- src/wikidataRetriever.py | 345 ++++++++++------ 29 files changed, 1448 insertions(+), 617 deletions(-) delete mode 100644 docker/1_Data_Processing_save_ids/run.py rename docker/{1_Data_Processing_save_ids => 1_Data_Processing_save_labels_descriptions}/Dockerfile (79%) rename docker/{1_Data_Processing_save_ids => 1_Data_Processing_save_labels_descriptions}/requirements.txt (100%) create mode 100644 docker/1_Data_Processing_save_labels_descriptions/run.py rename docker/{2_Data_Processing_save_entities => 2_Data_Processing_save_items_per_lang}/Dockerfile (100%) rename docker/{2_Data_Processing_save_entities => 2_Data_Processing_save_items_per_lang}/requirements.txt (100%) rename docker/{2_Data_Processing_save_entities => 2_Data_Processing_save_items_per_lang}/run.py (59%) create mode 100644 docker/7_Create_Prototype/Dockerfile create mode 100644 docker/7_Create_Prototype/requirements.txt create mode 100644 docker/7_Create_Prototype/run.py create mode 100644 run_experiments.sh create mode 100644 src/experimental_functions/word_embeding.py create mode 100644 src/wikidataCache.py create mode 100644 src/wikidataItemDB.py rename src/{wikidataDB.py => wikidataLangDB.py} (55%) diff --git a/docker-compose.yml b/docker-compose.yml index 15574ac..37bc8b4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,8 +1,8 @@ services: - data_processing_save_ids: + data_processing_save_labels_descriptions: build: context: . - dockerfile: ./docker/1_Data_Processing_save_ids/Dockerfile + dockerfile: ./docker/1_Data_Processing_save_labels_descriptions/Dockerfile volumes: - ./data:/data # Mount the ./data folder from the host to /data in the container tty: true @@ -12,10 +12,10 @@ services: LANGUAGE: "de" OFFSET: 0 - data_processing_save_entities: + data_processing_save_items_per_lang: build: context: . 
- dockerfile: ./docker/2_Data_Processing_save_entities/Dockerfile + dockerfile: ./docker/2_Data_Processing_save_items_per_lang/Dockerfile volumes: - ./data:/data # Mount the ./data folder from the host to /data in the container tty: true @@ -43,12 +43,16 @@ services: PYTHONUNBUFFERED: 1 MODEL: "jina" SAMPLE: "true" - API_KEY: "datastax_wikidata_nvidia.json" + API_KEY: "datastax_wikidata2.json" EMBED_BATCH_SIZE: 8 - QUERY_BATCH_SIZE: 1000 - OFFSET: 2560000 - COLLECTION_NAME: "wikidata_test_v1" - LANGUAGE: 'ar' + QUERY_BATCH_SIZE: 100 + OFFSET: 120000 + COLLECTION_NAME: "wikidatav1" + LANGUAGE: 'en' + TEXTIFIER_LANGUAGE: 'en' + ELASTICSEARCH_URL: "http://localhost:9200" + ELASTICSEARCH: "false" + network_mode: "host" run_retrieval: build: @@ -70,17 +74,16 @@ services: environment: PYTHONUNBUFFERED: 1 MODEL: "jina" - API_KEY: "datastax_wikidata_nvidia.json" - COLLECTION_NAME: "wikidata_test_v1" + API_KEY: "datastax_wikidata.json" + COLLECTION_NAME: "wikidata_texttest" BATCH_SIZE: 100 - EVALUATION_PATH: "Mintaka/processed_dataframe_langtest.pkl" + EVALUATION_PATH: "Mintaka/processed_dataframe.pkl" # COMPARATIVE: "true" # COMPARATIVE_COLS: "Correct QID,Wrong QID" QUERY_COL: "Question" - # QUERY_LANGUAGE: "ar" - # DB_LANGUAGE: "en,ar" + QUERY_LANGUAGE: "en" + # DB_LANGUAGE: "en" PREFIX: "" - ELASTICSEARCH_URL: "http://localhost:9200" network_mode: "host" run_rerank: @@ -107,4 +110,41 @@ services: BATCH_SIZE: 1 QUERY_COL: "Question" LANGUAGE: "de" - network_mode: "host" \ No newline at end of file + network_mode: "host" + + push_huggingface: + build: + context: . + dockerfile: ./docker/6_Push_Huggingface/Dockerfile + volumes: + - ./data:/data + tty: true + container_name: push_huggingface + environment: + PYTHONUNBUFFERED: 1 + QUEUE_SIZE: 5000 + NUM_PROCESSES: 4 + SKIPLINES: 0 + ITERATION: 36 + + create_prototype: + build: + context: . 
+ dockerfile: ./docker/7_Create_Prototype/Dockerfile + volumes: + - ./data:/data + - ~/.cache/huggingface:/root/.cache/huggingface + tty: true + container_name: create_prototype + environment: + PYTHONUNBUFFERED: 1 + MODEL: "jinaapi" + API_KEY: "datastax_wikidata.json" + EMBED_BATCH_SIZE: 100 + NUM_PROCESSES: 16 + OFFSET: 0 + COLLECTION_NAME: "wikidata_prototype" + LANGUAGE: 'en' + TEXTIFIER_LANGUAGE: 'en' + CHUNK_NUM: 5 + network_mode: "host" \ No newline at end of file diff --git a/docker/1_Data_Processing_save_ids/run.py b/docker/1_Data_Processing_save_ids/run.py deleted file mode 100644 index c9034b2..0000000 --- a/docker/1_Data_Processing_save_ids/run.py +++ /dev/null @@ -1,41 +0,0 @@ -import sys -sys.path.append('../src') - -from wikidataDumpReader import WikidataDumpReader -from wikidataDB import WikidataID -from multiprocessing import Manager -import os -import time - -FILEPATH = os.getenv("FILEPATH", '../data/Wikidata/latest-all.json.bz2') -BATCH_SIZE = int(os.getenv("BATCH_SIZE", 1000)) -QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", 1500)) -NUM_PROCESSES = int(os.getenv("NUM_PROCESSES", 4)) -SKIPLINES = int(os.getenv("SKIPLINES", 0)) -LANGUAGE = os.getenv("LANGUAGE", 'en') - -def save_ids_to_sqlite(item, bulk_ids, sqlitDBlock): - if (item is not None) and WikidataID.is_in_wikipedia(item, language=LANGUAGE): - ids = WikidataID.extract_entity_ids(item, language=LANGUAGE) - bulk_ids.extend(ids) - - with sqlitDBlock: - if len(bulk_ids) > BATCH_SIZE: - worked = WikidataID.add_bulk_ids(list(bulk_ids[:BATCH_SIZE])) - if worked: - del bulk_ids[:BATCH_SIZE] - -if __name__ == "__main__": - multiprocess_manager = Manager() - sqlitDBlock = multiprocess_manager.Lock() - bulk_ids = multiprocess_manager.list() - - wikidata = WikidataDumpReader(FILEPATH, num_processes=NUM_PROCESSES, batch_size=BATCH_SIZE, queue_size=QUEUE_SIZE, skiplines=SKIPLINES) - wikidata.run(lambda item: save_ids_to_sqlite(item, bulk_ids, sqlitDBlock), max_iterations=None, verbose=True) - - while len(bulk_ids) > 0: - worked = WikidataID.add_bulk_ids(list(bulk_ids)) - if worked: - bulk_ids[:] = [] - else: - time.sleep(1) \ No newline at end of file diff --git a/docker/1_Data_Processing_save_ids/Dockerfile b/docker/1_Data_Processing_save_labels_descriptions/Dockerfile similarity index 79% rename from docker/1_Data_Processing_save_ids/Dockerfile rename to docker/1_Data_Processing_save_labels_descriptions/Dockerfile index 68bae7e..c67ee6f 100644 --- a/docker/1_Data_Processing_save_ids/Dockerfile +++ b/docker/1_Data_Processing_save_labels_descriptions/Dockerfile @@ -10,13 +10,13 @@ LABEL maintainer="philippe.saade@wikimedia.de" WORKDIR /app # Copy the requirements file into the container -COPY ./docker/1_Data_Processing_save_ids/requirements.txt requirements.txt +COPY ./docker/1_Data_Processing_save_labels_descriptions/requirements.txt requirements.txt # Install the dependencies RUN pip install --no-cache-dir -r requirements.txt # Copy the rest of the application code into the container -COPY ./docker/1_Data_Processing_save_ids /app +COPY ./docker/1_Data_Processing_save_labels_descriptions /app COPY ./src /src # Set up the volume for the data folder diff --git a/docker/1_Data_Processing_save_ids/requirements.txt b/docker/1_Data_Processing_save_labels_descriptions/requirements.txt similarity index 100% rename from docker/1_Data_Processing_save_ids/requirements.txt rename to docker/1_Data_Processing_save_labels_descriptions/requirements.txt diff --git a/docker/1_Data_Processing_save_labels_descriptions/run.py 
b/docker/1_Data_Processing_save_labels_descriptions/run.py new file mode 100644 index 0000000..7a151c3 --- /dev/null +++ b/docker/1_Data_Processing_save_labels_descriptions/run.py @@ -0,0 +1,51 @@ +import sys +sys.path.append('../src') + +from wikidataDumpReader import WikidataDumpReader +from wikidataItemDB import WikidataItem +from multiprocessing import Manager +import os +import time +import json + +FILEPATH = os.getenv("FILEPATH", '../data/Wikidata/latest-all.json.bz2') +PUSH_SIZE = int(os.getenv("PUSH_SIZE", 20000)) +QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", 15000)) +NUM_PROCESSES = int(os.getenv("NUM_PROCESSES", 4)) +SKIPLINES = int(os.getenv("SKIPLINES", 0)) +LANGUAGE = os.getenv("LANGUAGE", 'en') + +def save_items_to_sqlite(item, data_batch, sqlitDBlock): + if (item is not None): + labels = WikidataItem.clean_label_description(item['labels']) + descriptions = WikidataItem.clean_label_description(item['descriptions']) + labels = json.dumps(labels, separators=(',', ':')) + descriptions = json.dumps(descriptions, separators=(',', ':')) + in_wikipedia = WikidataItem.is_in_wikipedia(item) + data_batch.append({ + 'id': item['id'], + 'labels': labels, + 'descriptions': descriptions, + 'in_wikipedia': in_wikipedia, + }) + + with sqlitDBlock: + if len(data_batch) > PUSH_SIZE: + worked = WikidataItem.add_bulk_items(list(data_batch[:PUSH_SIZE])) + if worked: + del data_batch[:PUSH_SIZE] + +if __name__ == "__main__": + multiprocess_manager = Manager() + sqlitDBlock = multiprocess_manager.Lock() + data_batch = multiprocess_manager.list() + + wikidata = WikidataDumpReader(FILEPATH, num_processes=NUM_PROCESSES, queue_size=QUEUE_SIZE, skiplines=SKIPLINES) + wikidata.run(lambda item: save_items_to_sqlite(item, data_batch, sqlitDBlock), max_iterations=None, verbose=True) + + while len(data_batch) > 0: + worked = WikidataItem.add_bulk_items(list(data_batch)) + if worked: + del data_batch[:PUSH_SIZE] + else: + time.sleep(1) \ No newline at end of file diff --git a/docker/2_Data_Processing_save_entities/Dockerfile b/docker/2_Data_Processing_save_items_per_lang/Dockerfile similarity index 100% rename from docker/2_Data_Processing_save_entities/Dockerfile rename to docker/2_Data_Processing_save_items_per_lang/Dockerfile diff --git a/docker/2_Data_Processing_save_entities/requirements.txt b/docker/2_Data_Processing_save_items_per_lang/requirements.txt similarity index 100% rename from docker/2_Data_Processing_save_entities/requirements.txt rename to docker/2_Data_Processing_save_items_per_lang/requirements.txt diff --git a/docker/2_Data_Processing_save_entities/run.py b/docker/2_Data_Processing_save_items_per_lang/run.py similarity index 59% rename from docker/2_Data_Processing_save_entities/run.py rename to docker/2_Data_Processing_save_items_per_lang/run.py index 0252157..3fd7353 100644 --- a/docker/2_Data_Processing_save_entities/run.py +++ b/docker/2_Data_Processing_save_items_per_lang/run.py @@ -2,40 +2,40 @@ sys.path.append('../src') from wikidataDumpReader import WikidataDumpReader -from wikidataDB import WikidataID, WikidataEntity +from wikidataLangDB import WikidataLang from multiprocessing import Manager import os import time FILEPATH = os.getenv("FILEPATH", '../data/Wikidata/latest-all.json.bz2') -BATCH_SIZE = int(os.getenv("BATCH_SIZE", 1000)) +PUSH_SIZE = int(os.getenv("PUSH_SIZE", 2000)) QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", 1500)) -NUM_PROCESSES = int(os.getenv("NUM_PROCESSES", 4)) +NUM_PROCESSES = int(os.getenv("NUM_PROCESSES", 8)) SKIPLINES = int(os.getenv("SKIPLINES", 0)) LANGUAGE = 
os.getenv("LANGUAGE", 'en') def save_entities_to_sqlite(item, data_batch, sqlitDBlock): - if (item is not None) and WikidataID.get_id(item['id']): - item = WikidataEntity.normalise_item(item, language=LANGUAGE) + if (item is not None) and WikidataLang.is_in_wikipedia(item, language=LANGUAGE): + item = WikidataLang.normalise_item(item, language=LANGUAGE) data_batch.append(item) with sqlitDBlock: - if len(data_batch) > BATCH_SIZE: - worked = WikidataEntity.add_bulk_entities(list(data_batch[:BATCH_SIZE])) + if len(data_batch) > PUSH_SIZE: + worked = WikidataLang.add_bulk_entities(list(data_batch[:PUSH_SIZE])) if worked: - del data_batch[:BATCH_SIZE] + del data_batch[:PUSH_SIZE] if __name__ == "__main__": multiprocess_manager = Manager() sqlitDBlock = multiprocess_manager.Lock() data_batch = multiprocess_manager.list() - wikidata = WikidataDumpReader(FILEPATH, num_processes=NUM_PROCESSES, batch_size=BATCH_SIZE, queue_size=QUEUE_SIZE, skiplines=SKIPLINES) + wikidata = WikidataDumpReader(FILEPATH, num_processes=NUM_PROCESSES, queue_size=QUEUE_SIZE, skiplines=SKIPLINES) wikidata.run(lambda item: save_entities_to_sqlite(item, data_batch, sqlitDBlock), max_iterations=None, verbose=True) while len(data_batch) > 0: - worked = WikidataEntity.add_bulk_entities(list(data_batch)) + worked = WikidataLang.add_bulk_entities(list(data_batch)) if worked: - data_batch[:] = [] + del data_batch[:PUSH_SIZE] else: time.sleep(1) \ No newline at end of file diff --git a/docker/3_Add_Wikidata_to_AstraDB/requirements.txt b/docker/3_Add_Wikidata_to_AstraDB/requirements.txt index fbcfb4d..65289f0 100644 --- a/docker/3_Add_Wikidata_to_AstraDB/requirements.txt +++ b/docker/3_Add_Wikidata_to_AstraDB/requirements.txt @@ -17,5 +17,4 @@ langchain_experimental ragstack-ai-langchain[knowledge-store]==1.3.0 langchain-astradb astrapy -elasticsearch -mediawikiapi \ No newline at end of file +elasticsearch \ No newline at end of file diff --git a/docker/3_Add_Wikidata_to_AstraDB/run.py b/docker/3_Add_Wikidata_to_AstraDB/run.py index 9dc113c..616f77b 100644 --- a/docker/3_Add_Wikidata_to_AstraDB/run.py +++ b/docker/3_Add_Wikidata_to_AstraDB/run.py @@ -1,9 +1,9 @@ import sys sys.path.append('../src') -from wikidataDB import Session, WikidataID, WikidataEntity +from wikidataLangDB import Session, WikidataLang from wikidataEmbed import WikidataTextifier -from wikidataRetriever import AstraDBConnect +from wikidataRetriever import AstraDBConnect, KeywordSearchConnect import json from tqdm import tqdm @@ -14,6 +14,7 @@ MODEL = os.getenv("MODEL", "jina") SAMPLE = os.getenv("SAMPLE", "false").lower() == "true" +SAMPLE_PATH = os.getenv("SAMPLE_PATH", "../data/Evaluation Data/Sample IDs (EN).pkl") EMBED_BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", 100)) QUERY_BATCH_SIZE = int(os.getenv("QUERY_BATCH_SIZE", 1000)) OFFSET = int(os.getenv("OFFSET", 0)) @@ -23,6 +24,9 @@ TEXTIFIER_LANGUAGE = os.getenv("TEXTIFIER_LANGUAGE", None) DUMPDATE = os.getenv("DUMPDATE", '09/18/2024') +ELASTICSEARCH_URL = os.getenv("ELASTICSEARCH_URL", "http://localhost:9200") +ELASTICSEARCH = os.getenv("ELASTICSEARCH", "false").lower() == "true" + # Load the Database if not COLLECTION_NAME: raise ValueError("The COLLECTION_NAME environment variable is required") @@ -34,13 +38,17 @@ API_KEY_FILENAME = os.listdir("../API_tokens")[0] datastax_token = json.load(open(f"../API_tokens/{API_KEY_FILENAME}")) -textifier = WikidataTextifier(language=TEXTIFIER_LANGUAGE) -graph_store = AstraDBConnect(datastax_token, COLLECTION_NAME, model=MODEL, batch_size=EMBED_BATCH_SIZE, 
cache_embeddings=False) +textifier = WikidataTextifier(language=LANGUAGE, langvar_filename=TEXTIFIER_LANGUAGE) + +if ELASTICSEARCH: + graph_store = KeywordSearchConnect(ELASTICSEARCH_URL, index_name=COLLECTION_NAME) +else: + graph_store = AstraDBConnect(datastax_token, COLLECTION_NAME, model=MODEL, batch_size=EMBED_BATCH_SIZE, cache_embeddings=False) # Load the Sample IDs sample_ids = None if SAMPLE: - sample_ids = pickle.load(open("../data/Evaluation Data/Sample IDs (EN).pkl", "rb")) + sample_ids = pickle.load(open(SAMPLE_PATH, "rb")) sample_ids = sample_ids[sample_ids['In Wikipedia']] total_entities = len(sample_ids) @@ -50,7 +58,7 @@ def get_entity(session): # For each batch of sample QIDs, fetch the entities from the database for qid_batch in sample_qid_batches: - entities = session.query(WikidataEntity).filter(WikidataEntity.id.in_(qid_batch)).yield_per(QUERY_BATCH_SIZE) + entities = session.query(WikidataLang).filter(WikidataLang.id.in_(qid_batch)).yield_per(QUERY_BATCH_SIZE) for entity in entities: yield entity @@ -58,7 +66,7 @@ def get_entity(session): total_entities = 9203786 def get_entity(session): - entities = session.query(WikidataEntity).join(WikidataID, WikidataEntity.id == WikidataID.id).filter(WikidataID.in_wikipedia == True).offset(OFFSET).yield_per(QUERY_BATCH_SIZE) + entities = session.query(WikidataLang).offset(OFFSET).yield_per(QUERY_BATCH_SIZE) for entity in entities: yield entity @@ -71,22 +79,28 @@ def get_entity(session): for entity in entity_generator: progressbar.update(1) - chunks = textifier.chunk_text(entity, graph_store.tokenizer, max_length=graph_store.max_token_size) + if ELASTICSEARCH: + chunks = [textifier.entity_to_text(entity)] + else: + chunks = textifier.chunk_text(entity, graph_store.tokenizer, max_length=graph_store.max_token_size) + for chunk_i in range(len(chunks)): md5_hash = hashlib.md5(chunks[chunk_i].encode('utf-8')).hexdigest() metadata={ - # "MD5": md5_hash, - # "Label": entity.label, - # "Description": entity.description, - # "Aliases": entity.aliases, + "MD5": md5_hash, + "Label": entity.label, + "Description": entity.description, + "Aliases": entity.aliases, "Date": datetime.now().isoformat(), "QID": entity.id, "ChunkID": chunk_i+1, "Language": LANGUAGE, + "IsItem": ('Q' in entity.id), + "IsProperty": ('P' in entity.id), "DumpDate": DUMPDATE } graph_store.add_document(id=f"{entity.id}_{LANGUAGE}_{chunk_i+1}", text=chunks[chunk_i], metadata=metadata) - tqdm.write(progressbar.format_meter(progressbar.n, progressbar.total, progressbar.format_dict["elapsed"])) # tqdm is not wokring in docker compose. This is the alternative + tqdm.write(progressbar.format_meter(progressbar.n, progressbar.total, progressbar.format_dict["elapsed"])) # tqdm is not working in docker compose. 
This is the alternative graph_store.push_batch() diff --git a/docker/4_Run_Retrieval/run.py b/docker/4_Run_Retrieval/run.py index 022625a..490b071 100644 --- a/docker/4_Run_Retrieval/run.py +++ b/docker/4_Run_Retrieval/run.py @@ -1,7 +1,7 @@ import sys sys.path.append('../src') -from wikidataRetriever import AstraDBConnect, WikidataKeywordSearch +from wikidataRetriever import AstraDBConnect, KeywordSearchConnect import json from tqdm import tqdm @@ -22,9 +22,11 @@ QUERY_LANGUAGE = os.getenv("QUERY_LANGUAGE", 'en') DB_LANGUAGE = os.getenv("DB_LANGUAGE", None) RESTART = os.getenv("RESTART", "false").lower() == "true" -ELASTICSEARCH_URL = os.getenv("ELASTICSEARCH_URL", "http://localhost:9200") PREFIX = os.getenv("PREFIX", "") +ELASTICSEARCH_URL = os.getenv("ELASTICSEARCH_URL", "http://localhost:9200") +ELASTICSEARCH = os.getenv("ELASTICSEARCH", "false").lower() == "true" + OUTPUT_FILENAME = f"retrieval_results_{EVALUATION_PATH.split('/')[-2]}-{COLLECTION_NAME}-DB({DB_LANGUAGE})-Query({QUERY_LANGUAGE})" # OUTPUT_FILENAME = f"retrieval_results_{EVALUATION_PATH.split('/')[-2]}-keyword-search-{LANGUAGE}" if PREFIX != "": @@ -38,8 +40,11 @@ API_KEY_FILENAME = os.listdir("../API_tokens")[0] datastax_token = json.load(open(f"../API_tokens/{API_KEY_FILENAME}")) -graph_store = AstraDBConnect(datastax_token, COLLECTION_NAME, model=MODEL, batch_size=BATCH_SIZE, cache_embeddings=True) -# graph_store = WikidataKeywordSearch(ELASTICSEARCH_URL) +if ELASTICSEARCH: + graph_store = KeywordSearchConnect(ELASTICSEARCH_URL, index_name=COLLECTION_NAME) + OUTPUT_FILENAME += "_bm25" +else: + graph_store = AstraDBConnect(datastax_token, COLLECTION_NAME, model=MODEL, batch_size=BATCH_SIZE, cache_embeddings=True) #Load the Evaluation Dataset if not QUERY_COL: diff --git a/docker/6_Push_Huggingface/run.py b/docker/6_Push_Huggingface/run.py index c636307..56898f1 100644 --- a/docker/6_Push_Huggingface/run.py +++ b/docker/6_Push_Huggingface/run.py @@ -7,7 +7,7 @@ sys.path.append('../src') from wikidataDumpReader import WikidataDumpReader -from wikidataLabelsDB import WikidataLabels +from wikidataItemDB import WikidataItem from datasets import Dataset, load_dataset_builder @@ -22,8 +22,8 @@ def save_to_queue(item, data_queue): """Processes and puts cleaned item into the multiprocessing queue.""" - if (item is not None) and (WikidataLabels.is_in_wikipedia(item)): - claims = WikidataLabels.add_labels_batched(item['claims'], query_batch=100) + if (item is not None) and (WikidataItem.is_in_wikipedia(item)): + claims = WikidataItem.add_labels_batched(item['claims'], query_batch=100) data_queue.put({ 'id': item['id'], 'labels': json.dumps(item['labels'], separators=(',', ':')), @@ -82,7 +82,7 @@ def run_reader(): login(token=api_key) builder = load_dataset_builder("philippesaade/wikidata") -for i in range(45, 113): +for i in range(0, 113): split_name = f"chunk_{i}" if split_name not in builder.info.splits: filepath = f"../data/Wikidata/latest-all-chunks/chunk_{i}.json.gz" diff --git a/docker/7_Create_Prototype/Dockerfile b/docker/7_Create_Prototype/Dockerfile new file mode 100644 index 0000000..bf17994 --- /dev/null +++ b/docker/7_Create_Prototype/Dockerfile @@ -0,0 +1,30 @@ +# Use the official Python image from the Docker Hub +FROM python:3.9-slim +LABEL maintainer="philippe.saade@wikimedia.de" + +# Upgrade the pip, git and ubuntu versions to the most recent version +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + git \ + && rm -rf /var/lib/apt/lists/* \ + && pip install --upgrade pip 
setuptools wheel + +# Set the working directory in the container +WORKDIR /app + +# Copy the requirements file into the container +COPY ./docker/7_Create_Prototype/requirements.txt requirements.txt + +# Install the dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy the rest of the application code into the container +COPY ./docker/7_Create_Prototype /app +COPY ./src /src +COPY ./API_tokens /API_tokens + +# Set up the volume for the data folder +VOLUME [ "/data" ] + +# Run the Python script +CMD ["python", "run.py"] \ No newline at end of file diff --git a/docker/7_Create_Prototype/requirements.txt b/docker/7_Create_Prototype/requirements.txt new file mode 100644 index 0000000..c97586e --- /dev/null +++ b/docker/7_Create_Prototype/requirements.txt @@ -0,0 +1,22 @@ +tqdm +psutil +orjson +pandas + +# wikidataDB +sqlalchemy + +# wikidataEmbed +transformers +einops + +# wikidataRetriever +langchain-core +langchainhub +langchain_experimental +ragstack-ai-langchain[knowledge-store]==1.3.0 +langchain-astradb +astrapy +elasticsearch + +datasets \ No newline at end of file diff --git a/docker/7_Create_Prototype/run.py b/docker/7_Create_Prototype/run.py new file mode 100644 index 0000000..5be1dd3 --- /dev/null +++ b/docker/7_Create_Prototype/run.py @@ -0,0 +1,125 @@ +import sys +sys.path.append('../src') + +from wikidataEmbed import WikidataTextifier +from wikidataRetriever import AstraDBConnect +from datasets import load_dataset +from multiprocessing import Process, Queue, Manager + +import json +from tqdm import tqdm +import os +from datetime import datetime +import hashlib +from types import SimpleNamespace +import time + +MODEL = os.getenv("MODEL", "jinaapi") +QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", 5000)) +NUM_PROCESSES = int(os.getenv("NUM_PROCESSES", 4)) +EMBED_BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", 100)) +DB_API_KEY_FILENAME = os.getenv("DB_API_KEY", "datastax_wikidata.json") +COLLECTION_NAME = os.getenv("COLLECTION_NAME") + +CHUNK_NUM = os.getenv("CHUNK_NUM") +LANGUAGE = "en" +TEXTIFIER_LANGUAGE = "en" +DUMPDATE = "09/18/2024" + +# Load the Database +if not COLLECTION_NAME: + raise ValueError("The COLLECTION_NAME environment variable is required") + +if not TEXTIFIER_LANGUAGE: + TEXTIFIER_LANGUAGE = LANGUAGE + +FILEPATH = f"../data/Wikidata/chunks/chunk_{CHUNK_NUM}.json.gz" +chunk_sizes = 
{"chunk_0":992458,"chunk_1":802125,"chunk_2":589652,"chunk_3":310440,"chunk_4":43477,"chunk_5":156867,"chunk_6":141965,"chunk_7":74047,"chunk_8":27104,"chunk_9":70759,"chunk_10":71395,"chunk_11":186698,"chunk_12":153182,"chunk_13":137155,"chunk_14":929827,"chunk_15":853027,"chunk_16":571543,"chunk_17":335565,"chunk_18":47264,"chunk_19":135986,"chunk_20":160411,"chunk_21":76377,"chunk_22":26321,"chunk_23":70572,"chunk_24":68613,"chunk_25":179806,"chunk_26":159587,"chunk_27":139912,"chunk_28":876104,"chunk_29":864360,"chunk_30":590603,"chunk_31":358747,"chunk_32":47772,"chunk_33":135633,"chunk_34":159629,"chunk_35":81231,"chunk_36":24912,"chunk_37":69201,"chunk_38":67131,"chunk_39":172234,"chunk_40":167698,"chunk_41":142276,"chunk_42":821175,"chunk_43":892005,"chunk_44":600584,"chunk_45":374793,"chunk_46":47443,"chunk_47":134784,"chunk_48":155247,"chunk_49":86997,"chunk_50":24829,"chunk_51":68053,"chunk_52":63517,"chunk_53":167660,"chunk_54":175827,"chunk_55":142816,"chunk_56":765400,"chunk_57":900655,"chunk_58":628866,"chunk_59":396886,"chunk_60":46907,"chunk_61":135384,"chunk_62":154864,"chunk_63":88112,"chunk_64":23353,"chunk_65":67446,"chunk_66":40301,"chunk_67":176420,"chunk_68":183715,"chunk_69":149547,"chunk_70":713006,"chunk_71":901222,"chunk_72":652770,"chunk_73":419554,"chunk_74":52246,"chunk_75":134064,"chunk_76":153318,"chunk_77":92710,"chunk_78":22790,"chunk_79":66521,"chunk_80":34397,"chunk_81":173357,"chunk_82":186788,"chunk_83":153870,"chunk_84":657926,"chunk_85":902477,"chunk_86":655319,"chunk_87":455111,"chunk_88":69724,"chunk_89":133629,"chunk_90":146534,"chunk_91":101890,"chunk_92":21324,"chunk_93":65448,"chunk_94":33345,"chunk_95":162191,"chunk_96":192226,"chunk_97":159451,"chunk_98":598037,"chunk_99":903618,"chunk_100":662580,"chunk_101":484690,"chunk_102":86616,"chunk_103":135160,"chunk_104":106630,"chunk_105":142249,"chunk_106":19290,"chunk_107":60073,"chunk_108":39131,"chunk_109":155251,"chunk_110":190337,"chunk_111":166210,"chunk_112":26375} +total_entities = chunk_sizes[f"chunk_{CHUNK_NUM}"] + +datastax_token = json.load(open(f"../API_tokens/{DB_API_KEY_FILENAME}")) +dataset = load_dataset( + "philippesaade/wikidata", + data_files=f"data/chunk_{CHUNK_NUM}-*.parquet", + streaming=True, + split="train" +) + +def process_items(queue, progress_bar): + """Worker function that processes items from the queue and adds them to AstraDB.""" + datastax_token = json.load(open(f"../API_tokens/{DB_API_KEY_FILENAME}")) + graph_store = AstraDBConnect(datastax_token, COLLECTION_NAME, model=MODEL, batch_size=EMBED_BATCH_SIZE, cache_embeddings="wikidata_prototype") + textifier = WikidataTextifier(language=LANGUAGE, langvar_filename=TEXTIFIER_LANGUAGE) + + while True: + item = queue.get() + if item is None: + break # Exit condition for worker processes + + item_id = item['id'] + item_label = textifier.get_label(item_id, json.loads(item['labels'])) + item_description = textifier.get_description(item_id, json.loads(item['descriptions'])) + item_aliases = textifier.get_aliases(json.loads(item['aliases'])) + + if item_label is None: + continue + + entity_obj = SimpleNamespace() + entity_obj.id = item_id + entity_obj.label = item_label + entity_obj.description = item_description + entity_obj.aliases = item_aliases + entity_obj.claims = json.loads(item['claims']) + + chunks = textifier.chunk_text(entity_obj, graph_store.tokenizer, max_length=graph_store.max_token_size) + + for chunk_i, chunk in enumerate(chunks): + md5_hash = hashlib.md5(chunk.encode('utf-8')).hexdigest() + metadata = { + 
"MD5": md5_hash, + "Label": item_label, + "Description": item_description, + "Aliases": item_aliases, + "Date": datetime.now().isoformat(), + "QID": item_id, + "ChunkID": chunk_i + 1, + "Language": LANGUAGE, + "IsItem": ('Q' in item_id), + "IsProperty": ('P' in item_id), + "DumpDate": DUMPDATE + } + graph_store.add_document(id=f"{item_id}_{LANGUAGE}_{chunk_i+1}", text=chunk, metadata=metadata) + + with progress_bar.get_lock(): # Update tqdm safely from multiple processes + progress_bar.value += 1 + + while True: + if not graph_store.push_batch(): # Stop when batch is empty + break + +if __name__ == "__main__": + queue = Queue(maxsize=QUEUE_SIZE) + progress_bar = Manager().Value("i", 0) + + with tqdm(total=total_entities) as pbar: + processes = [] + for _ in range(NUM_PROCESSES): + p = Process(target=process_items, args=(queue, progress_bar)) + p.start() + processes.append(p) + + for item in dataset: + queue.put(item) + pbar.n = progress_bar.value + pbar.refresh() + + for _ in range(NUM_PROCESSES): + queue.put(None) + + while any(p.is_alive() for p in processes): + pbar.n = progress_bar.value + pbar.refresh() + time.sleep(1) + + for p in processes: + p.join() \ No newline at end of file diff --git a/run_experiments.sh b/run_experiments.sh new file mode 100644 index 0000000..d79fd90 --- /dev/null +++ b/run_experiments.sh @@ -0,0 +1,12 @@ +# docker compose run --build add_wikidata_to_astra +# docker compose run --build -e EVALUATION_PATH="Mintaka/processed_dataframe.pkl" -e QUERY_COL="Question" -e PREFIX="_nonewlines" -e COLLECTION_NAME="wikidatav1" -e QUERY_LANGUAGE="en" -e DB_LANGUAGE="en" -e API_KEY="datastax_wikidata2.json" run_retrieval +# docker compose run --build -e EVALUATION_PATH="LC_QuAD/processed_dataframe.pkl" -e QUERY_COL="Question" -e PREFIX="_nonewlines" -e COLLECTION_NAME="wikidatav1" -e QUERY_LANGUAGE="en" -e DB_LANGUAGE="en" -e API_KEY="datastax_wikidata2.json" run_retrieval +# docker compose run --build -e EVALUATION_PATH="REDFM/processed_dataframe.pkl" -e QUERY_COL="Sentence" -e PREFIX="_nonewlines" -e COLLECTION_NAME="wikidatav1" -e QUERY_LANGUAGE="en" -e DB_LANGUAGE="en" -e API_KEY="datastax_wikidata2.json" run_retrieval +# docker compose run --build -e EVALUATION_PATH="REDFM/processed_dataframe.pkl" -e QUERY_COL="Sentence no entity" -e PREFIX="_nonewlines_noentity" -e COLLECTION_NAME="wikidatav1" -e QUERY_LANGUAGE="en" -e DB_LANGUAGE="en" -e API_KEY="datastax_wikidata2.json" run_retrieval +# docker compose run --build -e EVALUATION_PATH="RuBQ/processed_dataframe.pkl" -e QUERY_COL="Question" -e PREFIX="_nonewlines" -e COLLECTION_NAME="wikidatav1" -e QUERY_LANGUAGE="en" -e DB_LANGUAGE="en" -e API_KEY="datastax_wikidata2.json" run_retrieval +# docker compose run --build -e EVALUATION_PATH="Wikidata-Disamb/processed_dataframe.pkl" -e QUERY_COL="Sentence" -e COMPARATIVE="true" -e COMPARATIVE_COLS="Correct QID,Wrong QID" -e COLLECTION_NAME="wikidatav1" -e PREFIX="_nonewlines" -e QUERY_LANGUAGE="en" -e DB_LANGUAGE="en" -e API_KEY="datastax_wikidata2.json" run_retrieval + +# docker compose run --build -e CHUNK_NUM=5 create_prototype +docker compose run --build -e CHUNK_NUM=6 create_prototype +docker compose run --build -e CHUNK_NUM=7 create_prototype +docker compose run --build -e CHUNK_NUM=8 create_prototype \ No newline at end of file diff --git a/src/JinaAI.py b/src/JinaAI.py index 9f367f8..2edacc0 100644 --- a/src/JinaAI.py +++ b/src/JinaAI.py @@ -1,52 +1,12 @@ - from typing import List -from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification 
-import torch -from sqlalchemy import Column, Text, create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker -from sqlalchemy.types import TypeDecorator import json import requests import numpy as np import base64 - -""" -SQLite database setup for caching the query embeddings for a faster evaluation process. -""" -engine = create_engine( - 'sqlite:///../data/Wikidata/sqlite_cacheembeddings.db', - pool_size=5, # Limit the number of open connections - max_overflow=10, # Allow extra connections beyond pool_size - pool_recycle=10 # Recycle connections every 10 seconds -) - -Base = declarative_base() -Session = sessionmaker(bind=engine) - -class JSONType(TypeDecorator): - """Custom SQLAlchemy type for JSON storage in SQLite.""" - impl = Text - - def process_bind_param(self, value, dialect): - if value is not None: - return json.dumps(value, separators=(',', ':')) - return None - - def process_result_value(self, value, dialect): - if value is not None: - return json.loads(value) - return None - -class CacheEmbeddings(Base): - """Represents a cache entry for a text string and its embedding.""" - __tablename__ = 'embeddings' - - text = Column(Text, primary_key=True) - embedding = Column(JSONType) +from wikidataCache import create_cache_embedding_model class JinaAIEmbedder: - def __init__(self, passage_task="retrieval.passage", query_task="retrieval.query", embedding_dim=1024, cache=False, api_key_path="../API_tokens/jina_api.json"): + def __init__(self, passage_task="retrieval.passage", query_task="retrieval.query", embedding_dim=1024, cache=None): """ Initializes the JinaAIEmbedder class with the model, tokenizer, and task identifiers. @@ -54,18 +14,22 @@ def __init__(self, passage_task="retrieval.passage", query_task="retrieval.query - passage_task (str): Task identifier for embedding documents. Defaults to "retrieval.passage". - query_task (str): Task identifier for embedding queries. Defaults to "retrieval.query". - embedding_dim (int): Dimensionality of the embeddings. Defaults to 1024. - - cache (bool): Whether to cache query embeddings in the database. Defaults to False. + - cache (str): Name of caching table. - api_key_path (str): Path to the JSON file containing the Jina API key. Defaults to "../API_tokens/jina_api.json". """ + from transformers import AutoModel, AutoTokenizer + import torch + self.passage_task = passage_task self.query_task = query_task self.embedding_dim = embedding_dim - self.cache = cache self.model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True).to('cuda') self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True) - self.api_key = json.load(open(api_key_path, 'r+'))['API_KEY'] + self.cache = (cache is not None) + if self.cache: + self.cache_model = create_cache_embedding_model(cache) def _cache_embedding(self, text: str, embedding: List[float]): """ @@ -77,14 +41,7 @@ def _cache_embedding(self, text: str, embedding: List[float]): """ if self.cache: embedding = embedding.tolist() - with Session() as session: - try: - cached = CacheEmbeddings(text=text, embedding=embedding) - session.merge(cached) - session.commit() - except Exception as e: - session.rollback() - raise e + self.cache_model.add_cache(id=text, embedding=embedding) def _get_cached_embedding(self, text: str) -> List[float]: """ @@ -97,13 +54,63 @@ def _get_cached_embedding(self, text: str) -> List[float]: - List[float] or None: The embedding if found in cache, otherwise None. 
""" if self.cache: - with Session() as session: - cached = session.query(CacheEmbeddings).filter_by(text=text).first() - if cached: - return cached.embedding + return self.cache_model.get_cache(id=text) return None - def api_embed(self, text, task="retrieval.query"): + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """ + Generates embeddings for a list of document (passage) texts. + + Caching is not used here by default to avoid storing large numbers of document embeddings. + + Parameters: + - texts (List[str]): A list of document texts to embed. + + Returns: + - List[List[float]]: A list of embedding vectors, each corresponding to a document. + """ + with torch.no_grad(): + embeddings = self.model.encode(texts, task=self.passage_task, truncate_dim=self.embedding_dim) + return embeddings + + def embed_query(self, text: str) -> List[float]: + """ + Generates an embedding for a single query string, optionally using and updating the cache. + + Parameters: + - text (str): The query text to embed. + + Returns: + - List[float]: The embedding vector corresponding to the query. + """ + cached_embedding = self._get_cached_embedding(text) + if cached_embedding: + return cached_embedding + + with torch.no_grad(): + embedding = self.model.encode([text], task=self.query_task, truncate_dim=self.embedding_dim)[0] + self._cache_embedding(text, embedding) + return embedding + +class JinaAIAPIEmbedder: + def __init__(self, passage_task="retrieval.passage", query_task="retrieval.query", embedding_dim=1024, cache=False, api_key_path="../API_tokens/jina_api.json"): + """ + Initializes the JinaAIEmbedder class with the model, tokenizer, and task identifiers. + + Parameters: + - passage_task (str): Task identifier for embedding documents. Defaults to "retrieval.passage". + - query_task (str): Task identifier for embedding queries. Defaults to "retrieval.query". + - embedding_dim (int): Dimensionality of the embeddings. Defaults to 1024. + - cache (str): Name of caching table. + - api_key_path (str): Path to the JSON file containing the Jina API key. Defaults to "../API_tokens/jina_api.json". + """ + self.passage_task = passage_task + self.query_task = query_task + self.embedding_dim = embedding_dim + + self.api_key = json.load(open(api_key_path, 'r+'))['API_KEY'] + + def api_embed(self, texts, task="retrieval.query"): """ Generates an embedding for the given text using the Jina Embeddings API. @@ -120,21 +127,29 @@ def api_embed(self, text, task="retrieval.query"): 'Authorization': f'Bearer {self.api_key}' } + if type(texts) is str: + texts = [texts] + data = { "model": "jina-embeddings-v3", "dimensions": self.embedding_dim, "embedding_type": "base64", "task": task, "late_chunking": False, - "input": [ - text - ] + "input": texts } response = requests.post(url, headers=headers, json=data) - binary_data = base64.b64decode(response.json()['data'][0]['embedding']) - embedding_array = np.frombuffer(binary_data, dtype=' List[List[float]]: """ @@ -148,8 +163,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: Returns: - List[List[float]]: A list of embedding vectors, each corresponding to a document. 
""" - with torch.no_grad(): - embeddings = self.model.encode(texts, task=self.passage_task, truncate_dim=self.embedding_dim) + embeddings = self.api_embed(texts, task=self.passage_task) return embeddings def embed_query(self, text: str) -> List[float]: @@ -162,14 +176,8 @@ def embed_query(self, text: str) -> List[float]: Returns: - List[float]: The embedding vector corresponding to the query. """ - cached_embedding = self._get_cached_embedding(text) - if cached_embedding: - return cached_embedding - - with torch.no_grad(): - embedding = self.model.encode([text], task=self.query_task, truncate_dim=self.embedding_dim)[0] - self._cache_embedding(text, embedding) - return embedding + embedding = self.api_embed([text], task=self.query_task) + return embedding class JinaAIReranker: def __init__(self, max_tokens=1024): @@ -182,6 +190,9 @@ def __init__(self, max_tokens=1024): Raises: - ValueError: If max_tokens is greater than 1024. """ + from transformers import AutoModelForSequenceClassification + import torch + if max_tokens > 1024: raise ValueError("Max token should be less than or equal to 1024") @@ -202,7 +213,4 @@ def rank(self, query: str, texts: List[str]) -> List[float]: sentence_pairs = [[query, doc] for doc in texts] with torch.no_grad(): - return self.model.compute_score(sentence_pairs, max_length=self.max_tokens) - -# Create tables if they don't already exist. -Base.metadata.create_all(engine) \ No newline at end of file + return self.model.compute_score(sentence_pairs, max_length=self.max_tokens) \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py index d3b8f15..d99d86d 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,6 +1,6 @@ from .wikidataDumpReader import WikidataDumpReader -from .wikidataDB import WikidataEntity, WikidataID -from .wikidataLabelsDB import WikidataLabels +from .wikidataLangDB import WikidataLang +from .wikidataItemDB import WikidataItem from .wikidataEmbed import WikidataTextifier -from .JinaAI import JinaAIEmbedder, JinaAIReranker -from .wikidataRetriever import AstraDBConnect \ No newline at end of file +from .JinaAI import JinaAIEmbedder, JinaAIReranker, JinaAIAPIEmbedder +from .wikidataRetriever import AstraDBConnect, KeywordSearchConnect \ No newline at end of file diff --git a/src/experimental_functions/word_embeding.py b/src/experimental_functions/word_embeding.py new file mode 100644 index 0000000..c009441 --- /dev/null +++ b/src/experimental_functions/word_embeding.py @@ -0,0 +1,64 @@ +import torch +import torch.nn.functional as F +from JinaAI import JinaAIEmbedder + +model = JinaAIEmbedder() + +def mean_pooling(model_output, attention_mask): + token_embeddings = model_output[0] + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( + input_mask_expanded.sum(1), min=1e-9 + ) + +def extract_specific_word_embedding(model, tokenizer, sentence, start_marker='', end_marker=''): + # Check if the markers are in the sentence + if start_marker in sentence and end_marker in sentence: + # Find the position of the markers + start_index = sentence.find(start_marker) + end_index = sentence.find(end_marker) + len(end_marker) + + # Split the sentence into three parts: before, target, and after + before_target = sentence[:start_index] + target_word = sentence[start_index + len(start_marker):end_index - len(end_marker)] + after_target = sentence[end_index:] + + # Clean the sentence by removing the markers + tokens_before 
= tokenizer(before_target, add_special_tokens=False)['input_ids'] + tokens_target = tokenizer(target_word, add_special_tokens=False)['input_ids'] + + # Reconstruct the input with markers removed and tokenize + clean_sentence = before_target + target_word + after_target + inputs = tokenizer(clean_sentence, return_tensors="pt").to(model.device) + + # Calculate the start and end positions of the target word tokens + target_start_pos = len(tokens_before) + target_end_pos = target_start_pos + len(tokens_target) + + # Run the model to get output embeddings + task = 'retrieval.query' + task_id = model._adaptation_map[task] + adapter_mask = torch.full((1,), task_id, dtype=torch.int32).to(model.device) + with torch.no_grad(): + outputs = model(**inputs, adapter_mask=adapter_mask) + + target_attention_mask = torch.zeros_like(inputs['attention_mask']) + target_attention_mask[:, target_start_pos:target_end_pos] = 1 + + # Apply mean pooling only to the target word's embeddings + # word_embeddings = mean_pooling(outputs, target_attention_mask) + # normalized_embeddings = F.normalize(word_embeddings, p=2, dim=1) + + normalized_embeddings = outputs[0][:, target_start_pos:target_end_pos].mean(dim=0) + return normalized_embeddings + else: + with torch.no_grad(): + outputs = model.encode(sentence, task='retrieval.query') + return outputs + +# Example usage +sentence = "I went to the river bank to take a swim. Afterwards, I went to the bank to withdraw some money." +embedding = extract_specific_word_embedding(model.model, model.tokenizer, sentence) +print("Extracted Embedding:", embedding) \ No newline at end of file diff --git a/src/language_variables/ar.py b/src/language_variables/ar.py index e0cb548..14fc86f 100644 --- a/src/language_variables/ar.py +++ b/src/language_variables/ar.py @@ -59,11 +59,15 @@ def qualifiers_to_text(qualifiers): """ text = "" for property_label, qualifier_values in qualifiers.items(): - if len(text) > 0: - text += f" ; " + if (qualifier_values is not None) and len(qualifier_values) > 0: + if len(text) > 0: + text += f" ; " - text += f"{property_label}: {'، '.join(qualifier_values)}" - return text + text += f"{property_label}: {'، '.join(qualifier_values)}" + + if len(text) > 0: + return f" ({text})" + return "" def properties_to_text(properties): """ @@ -77,7 +81,7 @@ def properties_to_text(properties): """ properties_text = "" for property_label, claim_values in properties.items(): - if len(claim_values) > 0: + if (claim_values is not None) and (len(claim_values) > 0): claims_text = "" for claim_value in claim_values: @@ -88,7 +92,7 @@ def properties_to_text(properties): qualifiers = claim_value.get('qualifiers', {}) if len(qualifiers) > 0: - claims_text += f" ({qualifiers_to_text(qualifiers)})" + claims_text += qualifiers_to_text(qualifiers) claims_text += f"»" diff --git a/src/language_variables/de.py b/src/language_variables/de.py index 08c8f59..c5eb103 100644 --- a/src/language_variables/de.py +++ b/src/language_variables/de.py @@ -59,11 +59,18 @@ def qualifiers_to_text(qualifiers): """ text = "" for property_label, qualifier_values in qualifiers.items(): - if len(text) > 0: - text += f" ; " + if (qualifier_values is not None) and len(qualifier_values) > 0: + if len(text) > 0: + text += f" " - text += f"{property_label}: {', '.join(qualifier_values)}" - return text + text += f"({property_label}: {', '.join(qualifier_values)})" + + elif (qualifier_values is not None): + text += f"(hat {property_label})" + + if len(text) > 0: + return f" {text}" + return "" def 
properties_to_text(properties): """ @@ -77,21 +84,22 @@ def properties_to_text(properties): """ properties_text = "" for property_label, claim_values in properties.items(): - if len(claim_values) > 0: + if (claim_values is not None) and (len(claim_values) > 0): claims_text = "" for claim_value in claim_values: if len(claims_text) > 0: - claims_text += f",\n " + claims_text += f", " - claims_text += f"„{claim_value['value']}" + claims_text += claim_value['value'] qualifiers = claim_value.get('qualifiers', {}) if len(qualifiers) > 0: - claims_text += f" ({qualifiers_to_text(qualifiers)})" + claims_text += qualifiers_to_text(qualifiers) - claims_text += f"“" + properties_text += f'\n- {property_label}: „{claims_text}“' - properties_text += f'\n- {property_label}: {claims_text}' + elif (claim_values is not None): + properties_text += f'\n- hat {property_label}' return properties_text \ No newline at end of file diff --git a/src/language_variables/en.py b/src/language_variables/en.py index 6c87778..a5f951d 100644 --- a/src/language_variables/en.py +++ b/src/language_variables/en.py @@ -59,11 +59,18 @@ def qualifiers_to_text(qualifiers): """ text = "" for property_label, qualifier_values in qualifiers.items(): - if len(text) > 0: - text += f" ; " + if (qualifier_values is not None) and len(qualifier_values) > 0: + if len(text) > 0: + text += f" " - text += f"{property_label}: {', '.join(qualifier_values)}" - return text + text += f"({property_label}: {', '.join(qualifier_values)})" + + elif (qualifier_values is not None): + text += f"(has {property_label})" + + if len(text) > 0: + return f" {text}" + return "" def properties_to_text(properties): """ @@ -77,21 +84,22 @@ def properties_to_text(properties): """ properties_text = "" for property_label, claim_values in properties.items(): - if len(claim_values) > 0: + if (claim_values is not None) and (len(claim_values) > 0): claims_text = "" for claim_value in claim_values: if len(claims_text) > 0: - claims_text += f",\n " + claims_text += f", " - claims_text += f"\"{claim_value['value']}" + claims_text += claim_value['value'] qualifiers = claim_value.get('qualifiers', {}) if len(qualifiers) > 0: - claims_text += f" ({qualifiers_to_text(qualifiers)})" + claims_text += qualifiers_to_text(qualifiers) - claims_text += f"\"" + properties_text += f'\n- {property_label}: "{claims_text}"' - properties_text += f'\n- {property_label}: {claims_text}' + elif (claim_values is not None): + properties_text += f'\n- has {property_label}' return properties_text \ No newline at end of file diff --git a/src/language_variables/json.py b/src/language_variables/json.py index 15002f0..eefbb6c 100644 --- a/src/language_variables/json.py +++ b/src/language_variables/json.py @@ -41,7 +41,7 @@ def merge_entity_text(label, description, aliases, properties): 'description': description, 'aliases': aliases, **properties - }, ensure_ascii=False) + }, ensure_ascii=False, indent=4) return text @@ -50,70 +50,19 @@ def compress_json(data): # Iterate through the items of the data for key, items in data.items(): - cleaned_items = [] - for item in items: - qualifiers = {k: v[0] if isinstance(v, list) and len(v) == 1 else v for k, v in item['qualifiers'].items()} - clean_item = {'value': item['value'], **qualifiers} - if len(clean_item) == 1: - clean_item = clean_item['value'] - - cleaned_items.append(clean_item) - - if len(cleaned_items) == 1: - cleaned_items = cleaned_items[0] - cleaned_data[key] = cleaned_items - - return cleaned_data - -def qualifiers_to_text(qualifiers): - """ - 
Converts a list of qualifiers to a readable text string. - Qualifiers provide additional information about a claim. - - Parameters: - - qualifiers: A dictionary of qualifiers with property IDs as keys and their values as lists. - - Returns: - - A string representation of the qualifiers. - """ - text = "" - for property_label, qualifier_values in qualifiers.items(): - if len(text) > 0: - text += f" ; " - - text += f"{property_label}: {', '.join(qualifier_values)}" - return text - -def properties_to_text(properties, label=""): - """ - Converts a list of properties (claims) to a readable text string. - - Parameters: - - properties: A dictionary of properties (claims) with property IDs as keys. - - Returns: - - A string representation of the properties and their values. - """ - properties_text = "" - for property_label, claim_values in properties.items(): - if len(claim_values) > 0: - - claims_text = "" - qualifier_exists = any([len(claim_value.get('qualifiers', {})) > 0 for claim_value in claim_values]) - if qualifier_exists: - for claim_value in claim_values: - if len(claims_text) > 0: - claims_text += f"\n" - - claims_text += f"{label}: {property_label}: {claim_value['value']}" - - qualifiers = claim_value.get('qualifiers', {}) - if len(qualifiers) > 0: - claims_text += f" ({qualifiers_to_text(qualifiers)})" - else: - claims_text = ', '.join([claim_value['value'] for claim_value in claim_values]) - claims_text = f"{label}: {property_label}: {claims_text}" - - properties_text += f'\n{claims_text}' - - return properties_text \ No newline at end of file + if (items is not None) and (len(items) > 0): + cleaned_items = [] + for item in items: + qualifiers = {k: v[0] if isinstance(v, list) and len(v) == 1 else v for k, v in item['qualifiers'].items()} + clean_item = {'value': item['value'], **qualifiers} + if len(clean_item) == 1: + clean_item = clean_item['value'] + + cleaned_items.append(clean_item) + + if len(cleaned_items) == 1: + cleaned_items = cleaned_items[0] + elif len(cleaned_items) > 1: + cleaned_data[key] = cleaned_items + + return cleaned_data \ No newline at end of file diff --git a/src/language_variables/rdf.py b/src/language_variables/rdf.py index 7fd498c..0e803c7 100644 --- a/src/language_variables/rdf.py +++ b/src/language_variables/rdf.py @@ -57,11 +57,15 @@ def qualifiers_to_text(qualifiers): """ text = "" for property_label, qualifier_values in qualifiers.items(): - if len(text) > 0: - text += f" ; " + if (qualifier_values is not None) and len(qualifier_values) > 0: + if len(text) > 0: + text += f" ; " - text += f"{property_label}: {', '.join(qualifier_values)}" - return text + text += f"{property_label}: {', '.join(qualifier_values)}" + + if len(text) > 0: + return f" ({text})" + return "" def properties_to_text(properties, label=""): """ @@ -75,7 +79,7 @@ def properties_to_text(properties, label=""): """ properties_text = "" for property_label, claim_values in properties.items(): - if len(claim_values) > 0: + if (claim_values is not None) and (len(claim_values) > 0): claims_text = "" qualifier_exists = any([len(claim_value.get('qualifiers', {})) > 0 for claim_value in claim_values]) @@ -88,7 +92,7 @@ def properties_to_text(properties, label=""): qualifiers = claim_value.get('qualifiers', {}) if len(qualifiers) > 0: - claims_text += f" ({qualifiers_to_text(qualifiers)})" + claims_text += qualifiers_to_text(qualifiers) else: claims_text = ', '.join([claim_value['value'] for claim_value in claim_values]) claims_text = f"{label}: {property_label}: {claims_text}" diff --git 
a/src/wikidataCache.py b/src/wikidataCache.py new file mode 100644 index 0000000..4922a1f --- /dev/null +++ b/src/wikidataCache.py @@ -0,0 +1,98 @@ +from sqlalchemy import Column, Text, create_engine, text +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from sqlalchemy.types import TypeDecorator +import json + +""" +SQLite database setup for caching the query embeddings for a faster evaluation process. +""" +engine = create_engine( + 'sqlite:///../data/Wikidata/sqlite_cacheembeddings.db', + pool_size=5, # Limit the number of open connections + max_overflow=10, # Allow extra connections beyond pool_size + pool_recycle=10 # Recycle connections every 10 seconds +) + +Base = declarative_base() +Session = sessionmaker(bind=engine) + +class JSONType(TypeDecorator): + """Custom SQLAlchemy type for JSON storage in SQLite.""" + impl = Text + + def process_bind_param(self, value, dialect): + if value is not None: + return json.dumps(value, separators=(',', ':')) + return None + + def process_result_value(self, value, dialect): + if value is not None: + return json.loads(value) + return None + +def create_cache_embedding_model(table_name): + """Factory function to create a dynamic CacheEmbeddings model.""" + + class CacheEmbeddings(Base): + __tablename__ = table_name + + id = Column(Text, primary_key=True) + embedding = Column(JSONType) + + @staticmethod + def add_cache(id, embedding): + with Session() as session: + try: + cached = CacheEmbeddings(id=id, embedding=embedding) + session.merge(cached) + session.commit() + return True + except Exception as e: + session.rollback() + raise e + + @staticmethod + def get_cache(id): + with Session() as session: + cached = session.query(CacheEmbeddings).filter_by(id=id).first() + if cached: + return cached.embedding + return None + + @staticmethod + def add_bulk_cache(data): + """ + Insert multiple label records in bulk. If a record with the same ID exists, + it is ignored (no update is performed). + + Parameters: + - data (list[dict]): A list of dictionaries, each containing 'id', 'labels', 'descriptions', and 'in_wikipedia' keys. + + Returns: + - bool: True if the operation was successful, False otherwise. + """ + worked = False + with Session() as session: + try: + session.execute( + text( + f""" + INSERT INTO {CacheEmbeddings.__tablename__} (id, embedding) + VALUES (:id, :embedding) + ON CONFLICT(id) DO NOTHING + """ + ), + data + ) + session.commit() + session.flush() + worked = True + except Exception as e: + session.rollback() + print(e) + return worked + + Base.metadata.create_all(engine) + + return CacheEmbeddings \ No newline at end of file diff --git a/src/wikidataEmbed.py b/src/wikidataEmbed.py index c7525ee..9f7b518 100644 --- a/src/wikidataEmbed.py +++ b/src/wikidataEmbed.py @@ -1,4 +1,4 @@ -from wikidataDB import WikidataEntity +from wikidataItemDB import WikidataItem import requests import time import json @@ -7,7 +7,7 @@ import importlib class WikidataTextifier: - def __init__(self, language='en'): + def __init__(self, language='en', langvar_filename=None): """ Initializes the WikidataTextifier with the specified language. @@ -16,18 +16,61 @@ def __init__(self, language='en'): """ self.language = language + langvar_filename = (langvar_filename if langvar_filename is not None else language) try: # Importing custom functions and variables from a formating python script in the language_variables folder. 
-            self.langvar = importlib.import_module(f"language_variables.{language}")
+            self.langvar = importlib.import_module(f"language_variables.{langvar_filename}")
         except Exception as e:
             raise ValueError(f"Language file for '{language}' not found.")
 
+    def get_label(self, id, labels=None):
+        if (labels is None) or (len(labels) == 0):
+            labels = WikidataItem.get_labels(id)
+
+        if (type(labels) is str):
+            return labels
+
+        # Take the label in the requested language; if missing, fall back to the multilingual ('mul') label
+        label = labels[self.language] if (self.language in labels) else (labels['mul'] if ('mul' in labels) else None)
+
+        if type(label) is dict:
+            label = label['value']
+        return label
+
+    def get_description(self, id, descriptions=None):
+        if (descriptions is None) or (len(descriptions) == 0):
+            descriptions = WikidataItem.get_descriptions(id)
+
+        if (type(descriptions) is str):
+            return descriptions
+
+        # Take the description in the requested language; if missing, fall back to the multilingual ('mul') description
+        description = descriptions[self.language] if (self.language in descriptions) else (descriptions['mul'] if ('mul' in descriptions) else None)
+
+        if type(description) is dict:
+            description = description['value']
+        return description
+
+    def get_aliases(self, aliases):
+        if (type(aliases) is list):
+            return aliases
+
+        if aliases is None:
+            return []
+
+        alias_set = set()
+        if self.language in aliases:
+            alias_set = set([x['value'] for x in aliases[self.language]])
+        if 'mul' in aliases:
+            alias_set = alias_set | set([x['value'] for x in aliases['mul']])
+        return list(alias_set)
+
     def entity_to_text(self, entity, properties=None):
         """
         Converts a Wikidata entity into a human-readable text string.
 
         Parameters:
-        - entity (WikidataEntity): A WikidataEntity object containing entity data (label, description, claims, etc.).
+        - entity: A Wikidata entity object containing entity data (label, description, claims, etc.).
         - properties (dict or None): A dictionary of properties (claims). If None, the properties will be derived from entity.claims.
 
         Returns:
@@ -36,14 +79,23 @@
         if properties is None:
             properties = self.properties_to_dict(entity.claims)
 
-        return self.langvar.merge_entity_text(entity.label, entity.description, entity.aliases, properties)
+        label = self.get_label(entity.id, labels=entity.label)
+
+        description = self.get_description(entity.id, descriptions=entity.description)
+        if (description is None) or (len(description) == 0):
+            instanceof = self.get_label('P31')
+            description = properties.get(instanceof, '')
+
+        aliases = self.get_aliases(entity.aliases)
+
+        return self.langvar.merge_entity_text(label, description, aliases, properties)
 
     def properties_to_dict(self, properties):
         """
         Converts a dictionary of properties (claims) into a dict suitable for text generation.
 
         Parameters:
-        - properties (dict): A dictionary of claims keyed by property IDs.
+        - properties (dict): A dictionary of claims keyed by property IDs. Each value is a list of claim statements for that property.
 
         Returns:
@@ -55,23 +107,30 @@
             rank_preferred_found = False
 
             for c in claim:
-                value = self.mainsnak_to_value(c.get('mainsnak', c))
-                qualifiers = self.qualifiers_to_dict(c.get('qualifiers', {}))
-                rank = c.get('rank', 'normal').lower()
-
-                # Only store "normal" ranks. if one "preferred" rank exists, then only store "preferred" ranks.
- if value: - if ((not rank_preferred_found) and (rank == 'normal')) or (rank == 'preferred'): - if (not rank_preferred_found) and (rank == 'preferred'): - rank_preferred_found = True - p_data = [] - - p_data.append({'value': value, 'qualifiers': qualifiers}) + try: + value = self.mainsnak_to_value(c.get('mainsnak', c)) + qualifiers = self.qualifiers_to_dict(c.get('qualifiers', {})) + rank = c.get('rank', 'normal').lower() + + if value is None: + p_data = None + break + + elif len(value) > 0: + # If a preferred rank exists, include values that are only preferred. Else include only values that are ranked normal (values with a depricated rank are never included) + if ((not rank_preferred_found) and (rank == 'normal')) or (rank == 'preferred'): + if (not rank_preferred_found) and (rank == 'preferred'): + rank_preferred_found = True + p_data = [] + + p_data.append({'value': value, 'qualifiers': qualifiers}) + except Exception as e: + print(c) + raise e - if len(p_data) > 0: - property = WikidataEntity.get_entity(pid) - if property: - properties_dict[property.label] = p_data + label = self.get_label(pid, claim[0].get('mainsnak', {}).get('property-labels')) + if label: + properties_dict[label] = p_data return properties_dict @@ -80,7 +139,7 @@ def qualifiers_to_dict(self, qualifiers): Converts qualifiers into a dictionary suitable for text generation. Parameters: - - qualifiers (dict): A dictionary of qualifiers keyed by property IDs. + - qualifiers (dict): A dictionary of qualifiers keyed by property IDs. Each value is a list of qualifier statements. Returns: @@ -92,13 +151,16 @@ def qualifiers_to_dict(self, qualifiers): for q in qualifier: value = self.mainsnak_to_value(q) - if value: + if value is None: + q_data = None + break + elif len(value) > 0: q_data.append(value) - if len(q_data) > 0: - property = WikidataEntity.get_entity(pid) - if property: - qualifier_dict[property.label] = q_data + label = self.get_label(pid, qualifier[0].get('property-labels')) + if label: + qualifier_dict[label] = q_data + return qualifier_dict def mainsnak_to_value(self, mainsnak): @@ -109,42 +171,55 @@ def mainsnak_to_value(self, mainsnak): - mainsnak (dict): A snak object containing the value and datatype information. Returns: - - str or None: A string representation of the value, or None if parsing fails. + - str or None: A string representation of the value. If the returned string is empty, the value is discarded from the text, and If None i retured, then the whole property is discarded. """ - if mainsnak.get('snaktype', '') == 'value': - if (mainsnak.get('datatype', '') == 'wikibase-item') or (mainsnak.get('datatype', '') == 'wikibase-property'): - entity_id = mainsnak['datavalue']['value']['id'] - entity = WikidataEntity.get_entity(entity_id) - if entity is None: - return None + # Extract the datavalue + snaktype = mainsnak.get('snaktype', 'value') + datavalue = mainsnak.get('datavalue') + if (datavalue is not None) and (type(datavalue) is not str): + datavalue = datavalue.get('value', datavalue) + + # Consider missing values + if (snaktype != 'value') or (datavalue is None): + return self.langvar.novalue - text = entity.label - return text + # If the values is based on a language, only consider the language that matched the text representation language. 
+ elif (type(datavalue) is dict) and ('language' in datavalue) and (datavalue['language'] != self.language): + return '' - elif mainsnak.get('datatype', '') == 'monolingualtext': - return mainsnak['datavalue']['value']['text'] + elif (mainsnak.get('datatype', '') == 'wikibase-item') or (mainsnak.get('datatype', '') == 'wikibase-property'): + if type(datavalue) is str: + return self.get_label(datavalue) - elif mainsnak.get('datatype', '') == 'string': - return mainsnak['datavalue']['value'] + entity_id = datavalue['id'] + label = self.get_label(entity_id, datavalue.get('labels')) + return label - elif mainsnak.get('datatype', '') == 'time': - try: - return self.time_to_text(mainsnak['datavalue']['value']) - except Exception as e: - print("Error in time formating:", e) - return mainsnak['datavalue']['value']["time"] + elif mainsnak.get('datatype', '') == 'monolingualtext': + return datavalue.get('text', datavalue) - elif mainsnak.get('datatype', '') == 'quantity': - try: - return self.quantity_to_text(mainsnak['datavalue']['value']) - except Exception as e: - print(e) - return mainsnak['datavalue']['value']['amount'] + elif mainsnak.get('datatype', '') == 'string': + return datavalue - elif mainsnak.get('snaktype', '') == 'novalue': - return self.langvar.novalue + elif mainsnak.get('datatype', '') == 'time': + try: + return self.time_to_text(datavalue) + except Exception as e: + print("Error in time formating:", e) + return datavalue["time"] - return None + elif mainsnak.get('datatype', '') == 'quantity': + try: + return self.quantity_to_text(datavalue) + except Exception as e: + print(e) + return datavalue['amount'] + + elif mainsnak.get('datatype', '') == 'external-id': + return None + + else: + return '' def quantity_to_text(self, quantity_data): """ @@ -156,6 +231,9 @@ def quantity_to_text(self, quantity_data): Returns: - str: A textual representation of the quantity (e.g., "5 kg"). """ + if quantity_data is None: + return None + quantity = quantity_data.get('amount') unit = quantity_data.get('unit') @@ -164,9 +242,7 @@ def quantity_to_text(self, quantity_data): unit = None else: unit_qid = unit.rsplit('/')[-1] - entity = WikidataEntity.get_entity(unit_qid) - if entity: - unit = entity.label + unit = self.get_label(unit_qid, quantity_data.get('unit-labels')) return quantity + (f" {unit}" if unit else "") @@ -180,6 +256,9 @@ def time_to_text(self, time_data): Returns: - str: A textual representation of the time with appropriate granularity. """ + if time_data is None: + return None + time_value = time_data['time'] precision = time_data['precision'] calendarmodel = time_data.get('calendarmodel', 'http://www.wikidata.org/entity/Q1985786') @@ -305,7 +384,7 @@ def chunk_text(self, entity, tokenizer, max_length=500): Splits a text representation of an entity into smaller chunks so that each chunk fits within the token limit of a given tokenizer. Parameters: - - entity (WikidataEntity): The entity to be textified and chunked. + - entity: The entity to be textified and chunked. - tokenizer: A tokenizer (e.g. from Hugging Face) used to count tokens. - max_length (int): The maximum number of tokens allowed per chunk (default is 500). 
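mainsnak_to_value now has a three-way return contract: a string to keep, an empty string to silently skip that single value, or None to drop the whole property (external identifiers, for example). A sketch exercising it with toy snak dicts; the values are invented, and running it assumes the repository's src/ layout plus the SQLite data directory used by wikidataItemDB:

from wikidataEmbed import WikidataTextifier

textifier = WikidataTextifier(language='en')

# Monolingual text in another language -> '' (this value is skipped).
print(repr(textifier.mainsnak_to_value({
    'datatype': 'monolingualtext',
    'datavalue': {'text': 'Bonjour', 'language': 'fr'},
})))

# External identifiers -> None (the whole property is dropped from the text).
print(textifier.mainsnak_to_value({
    'datatype': 'external-id',
    'datavalue': 'GND 118500775',
}))

# A snak without a datavalue -> the language-specific "no value" placeholder.
print(textifier.mainsnak_to_value({
    'snaktype': 'novalue',
    'datatype': 'wikibase-item',
}))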
diff --git a/src/wikidataItemDB.py b/src/wikidataItemDB.py new file mode 100644 index 0000000..9d85703 --- /dev/null +++ b/src/wikidataItemDB.py @@ -0,0 +1,383 @@ +from sqlalchemy import Column, Text, create_engine, text +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from sqlalchemy.types import TypeDecorator, Boolean +import json +import re + +""" +SQLite database setup for storing Wikidata labels in all languages. +""" +engine = create_engine(f'sqlite:///../data/Wikidata/sqlite_wikidata_items.db', + pool_size=5, # Limit the number of open connections + max_overflow=10, # Allow extra connections beyond pool_size + pool_recycle=10 # Recycle connections every 10 seconds +) +Base = declarative_base() +Session = sessionmaker(bind=engine) + +class JSONType(TypeDecorator): + """Custom SQLAlchemy type for JSON storage in SQLite.""" + impl = Text + cache_ok = False + + def process_bind_param(self, value, dialect): + if value is not None: + return json.dumps(value, separators=(',', ':')) + return None + + def process_result_value(self, value, dialect): + if value is not None: + return json.loads(value) + return None + +class WikidataItem(Base): + """ Represents a Wikidata entity's labels in multiple languages.""" + + __tablename__ = 'item' + + id = Column(Text, primary_key=True) + labels = Column(JSONType) + descriptions = Column(JSONType) + in_wikipedia = Column(Boolean, default=False) + + @staticmethod + def add_bulk_items(data): + """ + Insert multiple label records in bulk. If a record with the same ID exists, + it is ignored (no update is performed). + + Parameters: + - data (list[dict]): A list of dictionaries, each containing 'id', 'labels', 'descriptions', and 'in_wikipedia' keys. + + Returns: + - bool: True if the operation was successful, False otherwise. + """ + worked = False + with Session() as session: + try: + session.execute( + text( + """ + INSERT INTO item (id, labels, descriptions, in_wikipedia) + VALUES (:id, :labels, :descriptions, :in_wikipedia) + ON CONFLICT(id) DO NOTHING + """ + ), + data + ) + session.commit() + session.flush() + worked = True + except Exception as e: + session.rollback() + print(e) + return worked + + @staticmethod + def add_labels(id, labels, descriptions, in_wikipedia): + """ + Insert a single label record into the database. + + Parameters: + - id (str): The unique identifier for the entity. + - labels (dict): A dictionary of labels (e.g. { "en": "Label in English", "fr": "Label in French", ... }). + + Returns: + - bool: True if the operation was successful, False otherwise. + """ + worked = False + with Session() as session: + try: + new_entry = WikidataItem( + id=id, + labels=labels, + descriptions=descriptions, + in_wikipedia=in_wikipedia + ) + session.add(new_entry) + session.commit() + session.flush() + worked = True + except Exception as e: + session.rollback() + print(f"Error: {e}") + return worked + + @staticmethod + def get_labels(id): + """ + Retrieve labels for a given entity by its ID. + + Parameters: + - id (str): The unique identifier of the entity. + + Returns: + - dict: The labels dictionary if found, otherwise an empty dict. + """ + with Session() as session: + item = session.query(WikidataItem).filter_by(id=id).first() + if item is not None: + return item.labels + return {} + + @staticmethod + def get_descriptions(id): + """ + Retrieve labels for a given entity by its ID. + + Parameters: + - id (str): The unique identifier of the entity. 
+ + Returns: + - dict: The labels dictionary if found, otherwise an empty dict. + """ + with Session() as session: + item = session.query(WikidataItem).filter_by(id=id).first() + if item is not None: + return item.descriptions + return {} + + @staticmethod + def get_item(id): + """ + Retrieve item for a given entity by its ID. + + Parameters: + - id (str): The unique identifier of the entity. + + Returns: + - dict: The labels dictionary if found, otherwise an empty dict. + """ + with Session() as session: + item = session.query(WikidataItem).filter_by(id=id).first() + if item is not None: + return item + return {} + + @staticmethod + def clean_label_description(data): + clean_data = {} + for lang, label in data.items(): + clean_data[lang] = label['value'] + return clean_data + + @staticmethod + def is_in_wikipedia(entity): + """ + Check if a Wikidata entity has a corresponding Wikipedia entry in any language. + + Parameters: + - entity (dict): A Wikidata entity dictionary. + + Returns: + - bool: True if the entity has at least one sitelink ending in 'wiki', otherwise False. + """ + if ('sitelinks' in entity): + for s in entity['sitelinks']: + if s.endswith('wiki'): + return True + return False + + @staticmethod + def get_labels_list(id_list): + """ + Retrieve labels for multiple entities at once. + + Parameters: + - id_list (list[str]): A list of entity IDs. + + Returns: + - dict: A mapping of {entity_id: labels_dict} for each found ID. Missing IDs won't appear. + """ + with Session() as session: + rows = ( + session.query(WikidataItem.id, WikidataItem.labels) + .filter(WikidataItem.id.in_(id_list)) + .all() + ) + + return {row_id: row_labels for row_id, row_labels in rows if row_labels is not None} + + @staticmethod + def _remove_keys(data, keys_to_remove=['hash', 'property', 'numeric-id', 'qualifiers-order']): + """ + Recursively remove specific keys from a nested data structure. + + Parameters: + - data (dict or list): The data structure to clean. + - keys_to_remove (list): Keys to remove. Default includes 'hash', 'property', 'numeric-id', and 'qualifiers-order'. + + Returns: + - dict or list: The cleaned data structure with specified keys removed. + """ + if isinstance(data, dict): + data = {key: WikidataItem._remove_keys(value, keys_to_remove) for key, value in data.items() if key not in keys_to_remove} + elif isinstance(data, list): + data = [WikidataItem._remove_keys(item, keys_to_remove) for item in data] + return data + + @staticmethod + def _clean_datavalue(data): + """ + Remove unnecessary nested structures unless they match a Wikidata entity or property pattern. + + Parameters: + - data (dict or list): The data structure to clean. + + Returns: + - dict or list: The cleaned data. + """ + if isinstance(data, dict): + # If there's only one key and it's not a property or QID, recurse into it. + if (len(data.keys()) == 1) and not re.match(r"^[PQ]\d+$", list(data.keys())[0]): + data = WikidataItem._clean_datavalue(data[list(data.keys())[0]]) + else: + data = {key: WikidataItem._clean_datavalue(value) for key, value in data.items()} + elif isinstance(data, list): + data = [WikidataItem._clean_datavalue(item) for item in data] + return data + + @staticmethod + def _gather_labels_ids(data): + """ + Find and return all relevant Wikidata IDs (e.g., property, unit, or datavalue IDs) in the claims. + + Parameters: + - data (dict or list): The data structure to scan. + + Returns: + - list[str]: A list of discovered Wikidata IDs. 
+ """ + ids = set() + + if isinstance(data, dict): + if 'property' in data: + ids.add(data['property']) + + if 'unit' in data and data['unit'] != '1': + unit_id = data['unit'].split('/')[-1] + ids.add(unit_id) + + if ('datatype' in data + and 'datavalue' in data + and data['datatype'] in ('wikibase-item', 'wikibase-property')): + ids.add(data['datavalue']) + + for value in data.values(): + sub_ids = WikidataItem._gather_labels_ids(value) + ids.update(sub_ids) + + elif isinstance(data, list): + for item in data: + sub_ids = WikidataItem._gather_labels_ids(item) + ids.update(sub_ids) + + return list(ids) + + @staticmethod + def _add_labels_to_claims(data, labels_dict={}): + """ + For each found ID (property, unit, or datavalue) within the claims, + insert the corresponding labels from labels_dict or the database. + + Parameters: + - data (dict or list): The claims data structure. + - labels_dict (dict): An optional dict of {id: labels} for quick lookup. + + Returns: + - dict or list: The updated data with added label information. + """ + if isinstance(data, dict): + if 'property' in data: + if data['property'] in labels_dict: + labels = labels_dict[data['property']] + else: + labels = WikidataItem.get_labels(data['property']) + + data = { + **data, + 'property-labels': labels + } + + if ('unit' in data) and (data['unit'] != '1'): + id = data['unit'].split('/')[-1] + if id in labels_dict: + labels = labels_dict[id] + else: + labels = WikidataItem.get_labels(id) + + data = { + **data, + 'unit-labels': labels + } + + if ('datatype' in data) and ('datavalue' in data) and ((data['datatype'] == 'wikibase-item') or (data['datatype'] == 'wikibase-property')): + if data['datavalue'] in labels_dict: + labels = labels_dict[data['datavalue']] + else: + labels = WikidataItem.get_labels(data['datavalue']) + + data['datavalue'] = { + 'id': data['datavalue'], + 'labels': labels + } + + data = {key: WikidataItem._add_labels_to_claims(value, labels_dict=labels_dict) for key, value in data.items()} + + elif isinstance(data, list): + data = [WikidataItem._add_labels_to_claims(item, labels_dict=labels_dict) for item in data] + + return data + + @staticmethod + def add_labels_batched(claims, query_batch=100): + """ + Gather all relevant IDs from claims, batch-fetch their labels, then add them to the claims structure. + + Parameters: + - claims (dict or list): The claims data structure to update. + - query_batch (int): The batch size for querying labels in groups. Default is 100. + + Returns: + - dict or list: The updated claims with labels inserted. + """ + label_ids = WikidataItem._gather_labels_ids(claims) + + labels_dict = {} + for i in range(0, len(label_ids), query_batch): + temp_dict = WikidataItem.get_labels_list(label_ids[i:i+query_batch]) + labels_dict = {**labels_dict, **temp_dict} + + claims = WikidataItem._add_labels_to_claims(claims, labels_dict=labels_dict) + return claims + + @staticmethod + def clean_entity(entity): + """ + Clean a Wikidata entity's data by removing unneeded keys and adding label info to claims. + + Parameters: + - entity (dict): A Wikidata entity dictionary containing 'claims', 'labels', 'sitelinks', etc. + + Returns: + - dict: The cleaned entity with label data integrated into its claims. 
+ """ + clean_claims = WikidataItem._remove_keys(entity.get('claims', {}), ['hash', 'snaktype', 'type', 'entity-type', 'numeric-id', 'qualifiers-order', 'snaks-order']) + clean_claims = WikidataItem._clean_datavalue(clean_claims) + clean_claims = WikidataItem._remove_keys(clean_claims, ['id']) + clean_claims = WikidataItem.add_labels_batched(clean_claims) + + sitelinks = WikidataItem._remove_keys(entity.get('sitelinks', {}), ['badges']) + + return { + 'id': entity['id'], + 'labels': WikidataItem.clean_label_description(entity['labels']), + 'descriptions': WikidataItem.clean_label_description(entity['descriptions']), + 'aliases': entity['aliases'], + 'sitelinks': sitelinks, + 'claims': clean_claims + } + +# Create tables if they don't already exist. +Base.metadata.create_all(engine) \ No newline at end of file diff --git a/src/wikidataDB.py b/src/wikidataLangDB.py similarity index 55% rename from src/wikidataDB.py rename to src/wikidataLangDB.py index b97478e..30e2ce3 100644 --- a/src/wikidataDB.py +++ b/src/wikidataLangDB.py @@ -32,7 +32,7 @@ def process_result_value(self, value, dialect): return json.loads(value) return None -class WikidataEntity(Base): +class WikidataLang(Base): """Represents a Wikidata entity with label, description, aliases, and claims.""" __tablename__ = 'wikidata' @@ -93,7 +93,7 @@ def add_entity(id, label, description, claims, aliases): worked = False with Session() as session: try: - new_entry = WikidataEntity( + new_entry = WikidataLang( id=id, label=label, description=description, @@ -121,7 +121,24 @@ def get_entity(id): - WikidataEntity or None: The entity object if found, otherwise None. """ with Session() as session: - return session.query(WikidataEntity).filter_by(id=id).first() + return session.query(WikidataLang).filter_by(id=id).first() + + @staticmethod + def is_in_wikipedia(item, language='en'): + """ + Check if a Wikidata item has a corresponding Wikipedia entry. + + Parameters: + - item (dict): The Wikidata item. + - language (str): The Wikipedia language code. Default is 'en'. + + Returns: + - bool: True if the item has a Wikipedia sitelink and label/description in the specified language or 'mul'. 
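As a concrete illustration of the check documented above, a minimal item dict that passes it might look like the following (invented values; assumes WikidataLang is importable and its SQLite database can be created):

from wikidataLangDB import WikidataLang

item = {
    'sitelinks': {'enwiki': {'title': 'Douglas Adams'}},
    'labels': {'mul': {'value': 'Douglas Adams'}},        # 'mul' counts as a label
    'descriptions': {'en': {'value': 'English writer'}},
}
print(WikidataLang.is_in_wikipedia(item, language='en'))  # -> True
print(WikidataLang.is_in_wikipedia(item, language='de'))  # -> False: no 'dewiki' sitelink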
+ """ + condition = ('sitelinks' in item) and (f'{language}wiki' in item['sitelinks']) # Has an Wikipedia Sitelink + condition = condition and ((language in item['labels']) or ('mul' in item['labels'])) # Has a label with the corresponding language or multiligual + condition = condition and ((language in item['descriptions']) or ('mul' in item['descriptions'])) # Has a description with the corresponding language or multiligual + return condition @staticmethod def normalise_item(item, language='en'): @@ -137,8 +154,8 @@ def normalise_item(item, language='en'): """ label = item['labels'][language]['value'] if (language in item['labels']) else (item['labels']['mul']['value'] if ('mul' in item['labels']) else '') # Take the label from the language, if missing take it from the multiligual class description = item['descriptions'][language]['value'] if (language in item['descriptions']) else (item['descriptions']['mul']['value'] if ('mul' in item['descriptions']) else '') # Take the description from the language, if missing take it from the multiligual class - aliases = WikidataEntity._get_aliases(item, language=language) - claims = WikidataEntity._get_claims(item) + aliases = WikidataLang._get_aliases(item, language=language) + claims = WikidataLang._get_claims(item) return { 'id': item['id'], 'label': label, @@ -160,9 +177,9 @@ def _remove_keys(data, keys_to_remove=['hash', 'property', 'numeric-id', 'qualif - dict or list: The cleaned data structure with specified keys removed. """ if isinstance(data, dict): - return {key: WikidataEntity._remove_keys(value, keys_to_remove) for key, value in data.items() if key not in keys_to_remove} + return {key: WikidataLang._remove_keys(value, keys_to_remove) for key, value in data.items() if key not in keys_to_remove} elif isinstance(data, list): - return [WikidataEntity._remove_keys(item, keys_to_remove) for item in data] + return [WikidataLang._remove_keys(item, keys_to_remove) for item in data] else: return data @@ -184,8 +201,8 @@ def _get_claims(item): for i in x: if (i['type'] == 'statement') and (i['rank'] != 'deprecated'): pid_claims.append({ - 'mainsnak': WikidataEntity._remove_keys(i['mainsnak']) if 'mainsnak' in i else {}, - 'qualifiers': WikidataEntity._remove_keys(i['qualifiers']) if 'qualifiers' in i else {}, + 'mainsnak': WikidataLang._remove_keys(i['mainsnak']) if 'mainsnak' in i else {}, + 'qualifiers': WikidataLang._remove_keys(i['qualifiers']) if 'qualifiers' in i else {}, 'rank': i['rank'] }) if len(pid_claims) > 0: @@ -211,158 +228,5 @@ def _get_aliases(item, language='en'): aliases = aliases | set([x['value'] for x in item['aliases']['mul']]) return list(aliases) -class WikidataID(Base): - """ Represents an ID record in the database, indicating whether it appears in Wikipedia or is a property. """ - - __tablename__ = 'wikidataID' - - id = Column(Text, primary_key=True) - in_wikipedia = Column(Boolean, default=False) - is_property = Column(Boolean, default=False) - - @staticmethod - def add_bulk_ids(data): - """ - Add multiple IDs to the database in bulk. If an ID exists, update its boolean fields. - - Parameters: - - data (list[dict]): A list of dictionaries with 'id', 'in_wikipedia', and 'is_property' fields. - - Returns: - - bool: True if successful, False otherwise. 
- """ - worked = False - with Session() as session: - try: - session.execute( - text( - """ - INSERT INTO wikidataID (id, in_wikipedia, is_property) - VALUES (:id, :in_wikipedia, :is_property) - ON CONFLICT(id) DO UPDATE - SET - in_wikipedia = CASE WHEN excluded.in_wikipedia = TRUE THEN excluded.in_wikipedia ELSE wikidataID.in_wikipedia END, - is_property = CASE WHEN excluded.is_property = TRUE THEN excluded.is_property ELSE wikidataID.is_property END - """ - ), - data - ) - session.commit() - session.flush() - worked = True - except Exception as e: - session.rollback() - print(e) - return worked - - @staticmethod - def add_id(id, in_wikipedia=False, is_property=False): - """ - Add a single ID record to the database. - - Parameters: - - id (str): The unique identifier. - - in_wikipedia (bool): Whether the entity is in Wikipedia. Default is False. - - is_property (bool): Whether the entity is a property. Default is False. - - Returns: - - bool: True if successful, False otherwise. - """ - worked = False - with Session() as session: - try: - new_entry = WikidataID(id=id, in_wikipedia=in_wikipedia, is_property=is_property) - session.add(new_entry) - session.commit() - session.flush() - worked = True - except Exception as e: - session.rollback() - print(e) - return worked - - @staticmethod - def get_id(id): - """ - Retrieve a record by its ID. - - Parameters: - - id (str): The unique identifier of the record. - - Returns: - - WikidataID or None: The record if found, otherwise None. - """ - with Session() as session: - return session.query(WikidataID).filter_by(id=id).first() - - @staticmethod - def is_in_wikipedia(item, language='en'): - """ - Check if a Wikidata item has a corresponding Wikipedia entry. - - Parameters: - - item (dict): The Wikidata item. - - language (str): The Wikipedia language code. Default is 'en'. - - Returns: - - bool: True if the item has a Wikipedia sitelink and label/description in the specified language or 'mul'. - """ - condition = ('sitelinks' in item) and (f'{language}wiki' in item['sitelinks']) # Has an Wikipedia Sitelink - condition = condition and ((language in item['labels']) or ('mul' in item['labels'])) # Has a label with the corresponding language or multiligual - condition = condition and ((language in item['descriptions']) or ('mul' in item['descriptions'])) # Has a description with the corresponding language or multiligual - return condition - - @staticmethod - def extract_entity_ids(item, language='en'): - """ - Extract entity and property IDs from a Wikidata item (including claims, qualifiers, and units). - - Parameters: - - item (dict): The Wikidata item. - - language (str): The language code for additional checks. Default is 'en'. - - Returns: - - list[dict]: A list of dictionaries with 'id', 'in_wikipedia', and 'is_property' for each discovered ID. 
- """ - if item is None: - return [] - - batch_ids = [{'id': item['id'], 'in_wikipedia': WikidataID.is_in_wikipedia(item, language=language), 'is_property': False}] - - for pid,claim in item.get('claims', {}).items(): - batch_ids.append({'id': pid, 'in_wikipedia': False, 'is_property': True}) - - for c in claim: - if ('mainsnak' in c) and ('datavalue' in c['mainsnak']): - if (c['mainsnak'].get('datatype', '') == 'wikibase-item'): - id = c['mainsnak']['datavalue']['value']['id'] - batch_ids.append({'id': id, 'in_wikipedia': False, 'is_property': False}) - - elif (c['mainsnak'].get('datatype', '') == 'wikibase-property'): - id = c['mainsnak']['datavalue']['value']['id'] - batch_ids.append({'id': id, 'in_wikipedia': False, 'is_property': True}) - - elif (c['mainsnak'].get('datatype', '') == 'quantity') and (c['mainsnak']['datavalue']['value'].get('unit', '1') != '1'): - id = c['mainsnak']['datavalue']['value']['unit'].rsplit('/', 1)[1] - batch_ids.append({'id': id, 'in_wikipedia': False, 'is_property': False}) - - if 'qualifiers' in c: - for pid, qualifier in c['qualifiers'].items(): - batch_ids.append({'id': pid, 'in_wikipedia': False, 'is_property': True}) - for q in qualifier: - if ('datavalue' in q): - if (q['datatype'] == 'wikibase-item'): - id = q['datavalue']['value']['id'] - batch_ids.append({'id': id, 'in_wikipedia': False, 'is_property': False}) - - elif(q['datatype'] == 'wikibase-property'): - id = q['datavalue']['value']['id'] - batch_ids.append({'id': id, 'in_wikipedia': False, 'is_property': True}) - - elif (q['datatype'] == 'quantity') and (q['datavalue']['value'].get('unit', '1') != '1'): - id = q['datavalue']['value']['unit'].rsplit('/', 1)[1] - batch_ids.append({'id': id, 'in_wikipedia': False, 'is_property': False}) - return batch_ids - # Create tables if they don't already exist. Base.metadata.create_all(engine) \ No newline at end of file diff --git a/src/wikidataRetriever.py b/src/wikidataRetriever.py index 0782c01..e96d689 100644 --- a/src/wikidataRetriever.py +++ b/src/wikidataRetriever.py @@ -1,17 +1,9 @@ -from langchain_astradb import AstraDBVectorStore -from langchain_core.documents import Document -from astrapy.info import CollectionVectorServiceOptions -from transformers import AutoTokenizer -import requests -from JinaAI import JinaAIEmbedder import time -from elasticsearch import Elasticsearch - -from mediawikiapi import MediaWikiAPI -from mediawikiapi.config import Config +import json +from wikidataCache import create_cache_embedding_model class AstraDBConnect: - def __init__(self, datastax_token, collection_name, model='nvidia', batch_size=8, cache_embeddings=False): + def __init__(self, datastax_token, collection_name, model='nvidia', batch_size=8, cache_embeddings=None): """ Initialize the AstraDBConnect object with the corresponding embedding model. @@ -20,8 +12,16 @@ def __init__(self, datastax_token, collection_name, model='nvidia', batch_size=8 - collection_name (str): Name of the collection (table) where data is stored. - model (str): The embedding model to use ("nvidia" or "jina"). Default is 'nvidia'. - batch_size (int): Number of documents to accumulate before pushing to AstraDB. Default is 8. - - cache_embeddings (bool): Whether to cache embeddings when using the Jina model. Default is False. + - cache_embeddings (str): Name of the cache table. 
""" + from langchain_astradb import AstraDBVectorStore + from astrapy.info import CollectionVectorServiceOptions + from astrapy import DataAPIClient + from multiprocessing import Queue + + from transformers import AutoTokenizer + from JinaAI import JinaAIEmbedder, JinaAIAPIEmbedder + ASTRA_DB_APPLICATION_TOKEN = datastax_token['ASTRA_DB_APPLICATION_TOKEN'] ASTRA_DB_API_ENDPOINT = datastax_token["ASTRA_DB_API_ENDPOINT"] ASTRA_DB_KEYSPACE = datastax_token["ASTRA_DB_KEYSPACE"] @@ -29,8 +29,15 @@ def __init__(self, datastax_token, collection_name, model='nvidia', batch_size=8 self.batch_size = batch_size self.model = model self.collection_name = collection_name - self.doc_batch = [] - self.id_batch = [] + self.doc_batch = Queue() + + self.cache_on = (cache_embeddings is not None) + if self.cache_on: + self.cache_model = create_cache_embedding_model(cache_embeddings) + + client = DataAPIClient(datastax_token['ASTRA_DB_APPLICATION_TOKEN']) + database0 = client.get_database(datastax_token['ASTRA_DB_API_ENDPOINT']) + self.graph_store = database0.get_collection(collection_name) if model == 'nvidia': self.tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-large-unsupervised', trust_remote_code=True, clean_up_tokenization_spaces=False) @@ -41,7 +48,7 @@ def __init__(self, datastax_token, collection_name, model='nvidia', batch_size=8 model_name="NV-Embed-QA" ) - self.graph_store = AstraDBVectorStore( + self.vector_search = AstraDBVectorStore( collection_name=collection_name, collection_vector_service_options=collection_vector_service_options, token=ASTRA_DB_APPLICATION_TOKEN, @@ -49,13 +56,26 @@ def __init__(self, datastax_token, collection_name, model='nvidia', batch_size=8 namespace=ASTRA_DB_KEYSPACE, ) elif model == 'jina': - embeddings = JinaAIEmbedder(embedding_dim=1024, cache=cache_embeddings) - self.tokenizer = embeddings.tokenizer + self.embeddings = JinaAIEmbedder(embedding_dim=1024) + self.tokenizer = self.embeddings.tokenizer self.max_token_size = 1024 - self.graph_store = AstraDBVectorStore( + self.vector_search = AstraDBVectorStore( collection_name=collection_name, - embedding=embeddings, + embedding=self.embeddings, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + ) + + elif model == 'jinaapi': + self.embeddings = JinaAIAPIEmbedder(embedding_dim=1024) + self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True) + self.max_token_size = 1024 + + self.vector_search = AstraDBVectorStore( + collection_name=collection_name, + embedding=self.embeddings, token=ASTRA_DB_APPLICATION_TOKEN, api_endpoint=ASTRA_DB_API_ENDPOINT, namespace=ASTRA_DB_KEYSPACE, @@ -72,36 +92,52 @@ def add_document(self, id, text, metadata): - text (str): The text content of the document. - metadata (dict): Additional metadata about the document. """ - doc = Document(page_content=text, metadata=metadata) - self.doc_batch.append(doc) - self.id_batch.append(id) + doc = { + '_id': id, + 'content':text, + 'metadata':metadata + } + self.doc_batch.put(doc) # If we reach the batch size, push the accumulated documents to AstraDB - if len(self.doc_batch) >= self.batch_size: + if self.doc_batch.qsize() >= self.batch_size: self.push_batch() def push_batch(self): """ Push the current batch of documents to AstraDB for storage. - Retries automatically if a connection issue occurs, waiting for - an active internet connection. + Caches the embeddings into a SQLite database. 
""" - while True: + if self.doc_batch.empty(): + return False + + docs = [] + for _ in range(self.batch_size): try: - self.graph_store.add_documents(self.doc_batch, ids=self.id_batch) - self.doc_batch = [] - self.id_batch = [] + doc = self.doc_batch.get_nowait() + cache = self._get_cached_embedding(doc['_id']) + if cache is None: + docs.append(doc) + except: break - except Exception as e: - print(e) - while True: - try: - response = requests.get("https://www.google.com", timeout=5) - if response.status_code == 200: - break - except Exception as e: - print("Waiting for internet connection...") + + if len(docs) == 0: + return False + + vectors = self.embeddings.embed_documents([doc['content'] for doc in docs]) + + try: + self.graph_store.insert_many(docs, vectors=vectors) + except Exception as e: + print(e) + + self.cache_model.add_bulk_cache([{ + 'id': docs[i]['_id'], + 'embedding': json.dumps(vectors[i], separators=(',', ':'))} + for i in range(len(docs))]) + + return True def get_similar_qids(self, query, filter={}, K=50): """ @@ -117,21 +153,10 @@ def get_similar_qids(self, query, filter={}, K=50): where list_of_qids are the QIDs of the results and list_of_scores are the corresponding similarity scores. """ - while True: - try: - results = self.graph_store.similarity_search_with_relevance_scores(query, k=K, filter=filter) - qid_results = [r[0].metadata['QID'] for r in results] - score_results = [r[1] for r in results] - return qid_results, score_results - except Exception as e: - print(e) - while True: - try: - response = requests.get("https://www.google.com", timeout=5) - if response.status_code == 200: - break - except Exception as e: - time.sleep(5) + results = self.vector_search.similarity_search_with_relevance_scores(query, k=K, filter=filter) + qid_results = [r[0].metadata['QID']+"_"+r[0].metadata.get('Language', '') for r in results] + score_results = [r[1] for r in results] + return qid_results, score_results def batch_retrieve_comparative(self, queries_batch, comparative_batch, K=50, Language=None): """ @@ -186,7 +211,33 @@ def batch_retrieve(self, queries_batch, K=50, Language=None): qids, scores = zip(*results) return list(qids), list(scores) -class WikidataKeywordSearch: + def _cache_embedding(self, id, embedding): + """ + Caches the text and its embedding in the SQLite database. + + Parameters: + - text (str): The text string. + - embedding (List[float]): The embedding vector for the text. + """ + if self.cache_on: + embedding = embedding.tolist() + self.cache_model.add_cache(id=id, embedding=embedding) + + def _get_cached_embedding(self, id): + """ + Retrieves a previously cached embedding for the specified text. + + Parameters: + - text (str): The text string. + + Returns: + - List[float] or None: The embedding if found in cache, otherwise None. + """ + if self.cache_on: + return self.cache_model.get_cache(id=id) + return None + +class KeywordSearchConnect: def __init__(self, url, index_name = 'wikidata'): """ Initialize the WikidataKeywordSearch object with an Elasticsearch instance. @@ -195,97 +246,151 @@ def __init__(self, url, index_name = 'wikidata'): - url (str): URL (host) of the Elasticsearch server. - index_name (str): Name of the Elasticsearch index. Default is 'wikidata'. """ + from elasticsearch import Elasticsearch + self.index_name = index_name self.es = Elasticsearch(url) + self.create_index() - # Create the index if it doesn't already exist + def create_index(self): + """ + Create the index with appropriate settings and mappings to optimize search. 
+ """ if not self.es.indices.exists(index=self.index_name): self.es.indices.create(index=self.index_name, body={ - "mappings": { - "properties": { - "text": { - "type": "text" + "settings": { + "analysis": { + "analyzer": { + "rebuilt_standard": { + "tokenizer": "standard", + "filter": ["lowercase", "stop"] + } + } + } + }, + "mappings": { + "properties": { + "text": { + "type": "text", + "analyzer": "default" + }, + "metadata": { + "type": "object", + "properties": { + "QID": {"type": "keyword"}, + "Language": {"type": "keyword"}, + "Date": {"type": "keyword"} + } + } } } - } - }) + }) - def search(self, query, K=50): + def add_document(self, id, text, metadata): + """ + Add a document to the Elasticsearch index. """ - Perform a keyword-based search against the Elasticsearch index. + doc = { + 'text': text, + 'metadata': {'QID': metadata['QID'], 'Language': metadata['Language']} + } + try: + if self.es.exists(index=self.index_name, id=id): + return + self.es.index(index=self.index_name, id=id, body=doc) + except ConnectionError as e: + print("Connection error:", e) + time.sleep(1) - Parameters: - - query (str): The query string to match against document text. - - K (int): Number of top results to return. Default is 50. + def push_batch(self): + pass - Returns: - - list: A list of raw Elasticsearch hits, each containing a '_score' and '_source'. + def search(self, query, K=50): + """ + Perform a text search using Elasticsearch. + """ + search_body = { + "query": { + "match": { + "text": query + } + }, + "size": K + } + try: + response = self.es.search(index=self.index_name, body=search_body) + return [hit for hit in response['hits']['hits']] + except ConnectionError as e: + print("Connection error:", e) + return [] + + def get_similar_qids(self, query, filter=[], K=50): + """ + Retrieve documents based on similarity to a query, potentially with filtering. """ search_body = { "query": { "bool": { - "should": [ - { - "match": { - "text": { - "query": query, - "operator": "or", - "boost": 1.0 - } - } - }, - { - "match_all": { - "boost": 0.01 # Lower boost to make match_all results less relevant - } + "must": { + "match": { + "text": query } - ] + }, + "filter": filter } }, - "size": K, - "sort": [ - { - "_score": { - "order": "desc" - } - } - ] + "size": K } - response = self.es.search(index=self.index_name, body=search_body) - return [hit for hit in response['hits']['hits']] - - def get_similar_qids(self, query, filter_qid={}, K=50): - """ - Retrieve QIDs based on a keyword-based search. Optionally filter by QID. + try: + response = self.es.search(index=self.index_name, body=search_body) + qid_results = [hit['_source']['metadata']['QID'] for hit in response['hits']['hits']] + score_results = [hit['_score'] for hit in response['hits']['hits']] + return qid_results, score_results - Parameters: - - query (str): The search string. - - filter_qid (dict): Optional filter (currently unused, placeholder). - - K (int): Number of top results to return. Default is 50. + except Exception as e: + print("Search failed:", e) + return [] - Returns: - - tuple: (list_of_qids, list_of_scores) + def batch_retrieve_comparative(self, queries_batch, comparative_batch, K=50, Language=None): """ - results = self.search(query, K=K) - qid_results = [r['_id'].split("_")[0] for r in results] - score_results = [r['_score'] for r in results] - return qid_results, score_results - - def batch_retrieve(self, queries_batch, K=50): + Retrieve similar documents in a comparative fashion for each query and comparative item. 
""" - Perform keyword-based search for a batch of queries. + qids = [[] for _ in range(len(queries_batch))] + scores = [[] for _ in range(len(queries_batch))] - Parameters: - - queries_batch (pd.Series or list): A list or series of query strings. - - K (int): Number of top results to return for each query. Default is 50. + for i, query in enumerate(queries_batch): + for comp_col in comparative_batch.columns: + filter = [] + # Apply language filter if specified + if Language: + filter.append({"term": {"metadata.Language": Language}}) + # Apply QID filter specific to the comparative group + qid_filter = comparative_batch[comp_col].iloc[i] + filter.append({"term": {"metadata.QID": qid_filter}}) - Returns: - - tuple: (list_of_qid_lists, list_of_score_lists), each element corresponding to a single query. + result_qids, result_scores = self.get_similar_qids(query, filter=filter, K=K) + qids[i].extend(result_qids) + scores[i].extend(result_scores) + + return qids, scores + + def batch_retrieve(self, queries_batch, K=50, Language=None): """ - results = [ - self.get_similar_qids(queries_batch.iloc[i], K=K) - for i in range(len(queries_batch)) - ] + Perform batch searches and handle potential connection issues. + """ + filter = [] + if Language: + languages = Language.split(',') + filter.append({"bool": {"should": [{"term": {"metadata.Language": lang}} for lang in languages]}}) - qids, scores = zip(*results) + results = [] + for query in queries_batch: + try: + result = self.get_similar_qids(query, K=K, filter=filter) + results.append(result) + except ConnectionError as e: + print("Connection error during batch processing:", e) + time.sleep(1) + + qids, scores = zip(*results) if results else ([], []) return list(qids), list(scores) \ No newline at end of file From 4eb77d21bb2527ef62471ed9198bb2bd4ea14246 Mon Sep 17 00:00:00 2001 From: Philippe Saade Date: Tue, 11 Mar 2025 15:20:43 +0100 Subject: [PATCH 02/49] Fix tqdm --- docker/7_Create_Prototype/run.py | 59 +++++++++++++++----------------- 1 file changed, 28 insertions(+), 31 deletions(-) diff --git a/docker/7_Create_Prototype/run.py b/docker/7_Create_Prototype/run.py index 5be1dd3..cb69030 100644 --- a/docker/7_Create_Prototype/run.py +++ b/docker/7_Create_Prototype/run.py @@ -61,37 +61,34 @@ def process_items(queue, progress_bar): item_description = textifier.get_description(item_id, json.loads(item['descriptions'])) item_aliases = textifier.get_aliases(json.loads(item['aliases'])) - if item_label is None: - continue - - entity_obj = SimpleNamespace() - entity_obj.id = item_id - entity_obj.label = item_label - entity_obj.description = item_description - entity_obj.aliases = item_aliases - entity_obj.claims = json.loads(item['claims']) - - chunks = textifier.chunk_text(entity_obj, graph_store.tokenizer, max_length=graph_store.max_token_size) - - for chunk_i, chunk in enumerate(chunks): - md5_hash = hashlib.md5(chunk.encode('utf-8')).hexdigest() - metadata = { - "MD5": md5_hash, - "Label": item_label, - "Description": item_description, - "Aliases": item_aliases, - "Date": datetime.now().isoformat(), - "QID": item_id, - "ChunkID": chunk_i + 1, - "Language": LANGUAGE, - "IsItem": ('Q' in item_id), - "IsProperty": ('P' in item_id), - "DumpDate": DUMPDATE - } - graph_store.add_document(id=f"{item_id}_{LANGUAGE}_{chunk_i+1}", text=chunk, metadata=metadata) - - with progress_bar.get_lock(): # Update tqdm safely from multiple processes - progress_bar.value += 1 + if item_label is not None: + entity_obj = SimpleNamespace() + entity_obj.id = item_id + 
entity_obj.label = item_label + entity_obj.description = item_description + entity_obj.aliases = item_aliases + entity_obj.claims = json.loads(item['claims']) + + chunks = textifier.chunk_text(entity_obj, graph_store.tokenizer, max_length=graph_store.max_token_size) + + for chunk_i, chunk in enumerate(chunks): + md5_hash = hashlib.md5(chunk.encode('utf-8')).hexdigest() + metadata = { + "MD5": md5_hash, + "Label": item_label, + "Description": item_description, + "Aliases": item_aliases, + "Date": datetime.now().isoformat(), + "QID": item_id, + "ChunkID": chunk_i + 1, + "Language": LANGUAGE, + "IsItem": ('Q' in item_id), + "IsProperty": ('P' in item_id), + "DumpDate": DUMPDATE + } + graph_store.add_document(id=f"{item_id}_{LANGUAGE}_{chunk_i+1}", text=chunk, metadata=metadata) + + progress_bar.value += 1 while True: if not graph_store.push_batch(): # Stop when batch is empty From c9b360be9db0fbbc1b83d9a9fdeebe93b7d66c8e Mon Sep 17 00:00:00 2001 From: exowanderer Date: Tue, 11 Mar 2025 16:59:07 +0100 Subject: [PATCH 03/49] updated docker-compose.yml, docker7 run.py, wikidataItemDB, wikidataRetriever --- docker-compose.yml | 3 +- docker/7_Create_Prototype/run.py | 56 ++++++++++++++++++++++++------ src/wikidataItemDB.py | 59 ++++++++++++++++++++------------ src/wikidataRetriever.py | 4 +-- 4 files changed, 88 insertions(+), 34 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 37bc8b4..5d06906 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -141,10 +141,11 @@ services: MODEL: "jinaapi" API_KEY: "datastax_wikidata.json" EMBED_BATCH_SIZE: 100 + # QUEUE_SIZE: 5000 NUM_PROCESSES: 16 OFFSET: 0 COLLECTION_NAME: "wikidata_prototype" LANGUAGE: 'en' TEXTIFIER_LANGUAGE: 'en' - CHUNK_NUM: 5 + # CHUNK_NUM: 5 network_mode: "host" \ No newline at end of file diff --git a/docker/7_Create_Prototype/run.py b/docker/7_Create_Prototype/run.py index cb69030..da554c3 100644 --- a/docker/7_Create_Prototype/run.py +++ b/docker/7_Create_Prototype/run.py @@ -1,3 +1,4 @@ +# TODO: package with setup inside docker to avoid sys.path mods import sys sys.path.append('../src') @@ -15,13 +16,20 @@ import time MODEL = os.getenv("MODEL", "jinaapi") -QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", 5000)) NUM_PROCESSES = int(os.getenv("NUM_PROCESSES", 4)) EMBED_BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", 100)) + +QUEUE_SIZE = 2 * EMBED_BATCH_SIZE * NUM_PROCESSES # enough to not run out +QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", QUEUE_SIZE)) + DB_API_KEY_FILENAME = os.getenv("DB_API_KEY", "datastax_wikidata.json") COLLECTION_NAME = os.getenv("COLLECTION_NAME") CHUNK_NUM = os.getenv("CHUNK_NUM") + +assert(CHUNK_NUM is not None), \ + "Please provide `CHUNK_NUM` env var at docker run" + LANGUAGE = "en" TEXTIFIER_LANGUAGE = "en" DUMPDATE = "09/18/2024" @@ -34,7 +42,10 @@ TEXTIFIER_LANGUAGE = LANGUAGE FILEPATH = f"../data/Wikidata/chunks/chunk_{CHUNK_NUM}.json.gz" + +# TODO: Push this dict into a json chunk_sizes = 
{"chunk_0":992458,"chunk_1":802125,"chunk_2":589652,"chunk_3":310440,"chunk_4":43477,"chunk_5":156867,"chunk_6":141965,"chunk_7":74047,"chunk_8":27104,"chunk_9":70759,"chunk_10":71395,"chunk_11":186698,"chunk_12":153182,"chunk_13":137155,"chunk_14":929827,"chunk_15":853027,"chunk_16":571543,"chunk_17":335565,"chunk_18":47264,"chunk_19":135986,"chunk_20":160411,"chunk_21":76377,"chunk_22":26321,"chunk_23":70572,"chunk_24":68613,"chunk_25":179806,"chunk_26":159587,"chunk_27":139912,"chunk_28":876104,"chunk_29":864360,"chunk_30":590603,"chunk_31":358747,"chunk_32":47772,"chunk_33":135633,"chunk_34":159629,"chunk_35":81231,"chunk_36":24912,"chunk_37":69201,"chunk_38":67131,"chunk_39":172234,"chunk_40":167698,"chunk_41":142276,"chunk_42":821175,"chunk_43":892005,"chunk_44":600584,"chunk_45":374793,"chunk_46":47443,"chunk_47":134784,"chunk_48":155247,"chunk_49":86997,"chunk_50":24829,"chunk_51":68053,"chunk_52":63517,"chunk_53":167660,"chunk_54":175827,"chunk_55":142816,"chunk_56":765400,"chunk_57":900655,"chunk_58":628866,"chunk_59":396886,"chunk_60":46907,"chunk_61":135384,"chunk_62":154864,"chunk_63":88112,"chunk_64":23353,"chunk_65":67446,"chunk_66":40301,"chunk_67":176420,"chunk_68":183715,"chunk_69":149547,"chunk_70":713006,"chunk_71":901222,"chunk_72":652770,"chunk_73":419554,"chunk_74":52246,"chunk_75":134064,"chunk_76":153318,"chunk_77":92710,"chunk_78":22790,"chunk_79":66521,"chunk_80":34397,"chunk_81":173357,"chunk_82":186788,"chunk_83":153870,"chunk_84":657926,"chunk_85":902477,"chunk_86":655319,"chunk_87":455111,"chunk_88":69724,"chunk_89":133629,"chunk_90":146534,"chunk_91":101890,"chunk_92":21324,"chunk_93":65448,"chunk_94":33345,"chunk_95":162191,"chunk_96":192226,"chunk_97":159451,"chunk_98":598037,"chunk_99":903618,"chunk_100":662580,"chunk_101":484690,"chunk_102":86616,"chunk_103":135160,"chunk_104":106630,"chunk_105":142249,"chunk_106":19290,"chunk_107":60073,"chunk_108":39131,"chunk_109":155251,"chunk_110":190337,"chunk_111":166210,"chunk_112":26375} + total_entities = chunk_sizes[f"chunk_{CHUNK_NUM}"] datastax_token = json.load(open(f"../API_tokens/{DB_API_KEY_FILENAME}")) @@ -48,8 +59,17 @@ def process_items(queue, progress_bar): """Worker function that processes items from the queue and adds them to AstraDB.""" datastax_token = json.load(open(f"../API_tokens/{DB_API_KEY_FILENAME}")) - graph_store = AstraDBConnect(datastax_token, COLLECTION_NAME, model=MODEL, batch_size=EMBED_BATCH_SIZE, cache_embeddings="wikidata_prototype") - textifier = WikidataTextifier(language=LANGUAGE, langvar_filename=TEXTIFIER_LANGUAGE) + graph_store = AstraDBConnect( + datastax_token, + COLLECTION_NAME, + model=MODEL, + batch_size=EMBED_BATCH_SIZE, + cache_embeddings="wikidata_prototype" + ) + textifier = WikidataTextifier( + language=LANGUAGE, + langvar_filename=TEXTIFIER_LANGUAGE + ) while True: item = queue.get() @@ -58,10 +78,14 @@ def process_items(queue, progress_bar): item_id = item['id'] item_label = textifier.get_label(item_id, json.loads(item['labels'])) - item_description = textifier.get_description(item_id, json.loads(item['descriptions'])) + item_description = textifier.get_description( + item_id, + json.loads(item['descriptions']) + ) item_aliases = textifier.get_aliases(json.loads(item['aliases'])) if item_label is not None: + # TODO: Verify: If label does not exist, then skip item entity_obj = SimpleNamespace() entity_obj.id = item_id entity_obj.label = item_label @@ -69,7 +93,11 @@ def process_items(queue, progress_bar): entity_obj.aliases = item_aliases entity_obj.claims = 
json.loads(item['claims']) - chunks = textifier.chunk_text(entity_obj, graph_store.tokenizer, max_length=graph_store.max_token_size) + chunks = textifier.chunk_text( + entity_obj, + graph_store.tokenizer, + max_length=graph_store.max_token_size + ) for chunk_i, chunk in enumerate(chunks): md5_hash = hashlib.md5(chunk.encode('utf-8')).hexdigest() @@ -86,11 +114,17 @@ def process_items(queue, progress_bar): "IsProperty": ('P' in item_id), "DumpDate": DUMPDATE } - graph_store.add_document(id=f"{item_id}_{LANGUAGE}_{chunk_i+1}", text=chunk, metadata=metadata) + + graph_store.add_document( + id=f"{item_id}_{LANGUAGE}_{chunk_i+1}", + text=chunk, + metadata=metadata + ) progress_bar.value += 1 while True: + # Leftover Maintenance: Ensure that the batch is emptied out if not graph_store.push_batch(): # Stop when batch is empty break @@ -107,15 +141,17 @@ def process_items(queue, progress_bar): for item in dataset: queue.put(item) - pbar.n = progress_bar.value - pbar.refresh() + pbar.update(progress_bar.value - pbar.n) + # pbar.n = progress_bar.value + # pbar.refresh() for _ in range(NUM_PROCESSES): queue.put(None) while any(p.is_alive() for p in processes): - pbar.n = progress_bar.value - pbar.refresh() + pbar.update(progress_bar.value - pbar.n) + # pbar.n = progress_bar.value + # pbar.refresh() time.sleep(1) for p in processes: diff --git a/src/wikidataItemDB.py b/src/wikidataItemDB.py index 9d85703..dd86e0e 100644 --- a/src/wikidataItemDB.py +++ b/src/wikidataItemDB.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Text, create_engine, text +from sqlalchemy import Column, Text, String, Integer, create_engine, text from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from sqlalchemy.types import TypeDecorator, Boolean @@ -6,13 +6,17 @@ import re """ -SQLite database setup for storing Wikidata labels in all languages. +SQLite database setup for storing Wikidata labels & descriptions +in all languages. """ -engine = create_engine(f'sqlite:///../data/Wikidata/sqlite_wikidata_items.db', - pool_size=5, # Limit the number of open connections - max_overflow=10, # Allow extra connections beyond pool_size - pool_recycle=10 # Recycle connections every 10 seconds + +SQLITEDB_PATH = '../data/Wikidata/sqlite_wikidata_items.db' +engine = create_engine(f'sqlite:///{SQLITEDB_PATH}', + pool_size=5, # Limit the number of open connections + max_overflow=10, # Allow extra connections beyond pool_size + pool_recycle=10 # Recycle connections every 10 seconds ) + Base = declarative_base() Session = sessionmaker(bind=engine) @@ -36,6 +40,11 @@ class WikidataItem(Base): __tablename__ = 'item' + # TODO: convert ID to Integer and store existin IDs as qpid + """ + id = Column(Integer, primary_key=True) + qpid = Column(String, unique=True, index=True) + """ id = Column(Text, primary_key=True) labels = Column(JSONType) descriptions = Column(JSONType) @@ -53,26 +62,31 @@ def add_bulk_items(data): Returns: - bool: True if the operation was successful, False otherwise. """ - worked = False + worked = False # Assume the operation failed with Session() as session: try: - session.execute( - text( - """ - INSERT INTO item (id, labels, descriptions, in_wikipedia) - VALUES (:id, :labels, :descriptions, :in_wikipedia) - ON CONFLICT(id) DO NOTHING - """ - ), - data + # Use a text statement to operate bulk insert + # SQLAlchemy's ORM is unable to handle bulk inserts + # with ON CONFLICT. 
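Stepping back to the run.py changes earlier in this patch: the tqdm fix relies on a multiprocessing.Value counter that workers increment and the parent mirrors into the bar with pbar.update(value - pbar.n). A minimal self-contained sketch of that pattern, with invented names and the per-item work omitted (the increment is lock-protected here for safety):

from multiprocessing import Process, Queue, Value
from tqdm import tqdm
import time

def worker(queue, counter):
    while True:
        item = queue.get()
        if item is None:                  # sentinel: no more work
            break
        # ... process the item here ...
        with counter.get_lock():
            counter.value += 1

if __name__ == "__main__":
    queue, counter = Queue(), Value('i', 0)
    items = list(range(1000))
    workers = [Process(target=worker, args=(queue, counter)) for _ in range(4)]
    for p in workers:
        p.start()

    with tqdm(total=len(items)) as pbar:
        for item in items:
            queue.put(item)
            pbar.update(counter.value - pbar.n)   # advance by the delta only
        for _ in workers:
            queue.put(None)
        while any(p.is_alive() for p in workers):
            pbar.update(counter.value - pbar.n)
            time.sleep(1)
        pbar.update(counter.value - pbar.n)

    for p in workers:
        p.join()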
+ + insert_stmt = text( + """ + INSERT INTO item (id, labels, descriptions, in_wikipedia) + VALUES (:id, :labels, :descriptions, :in_wikipedia) + ON CONFLICT(id) DO NOTHING + """ ) + + # Execute the insert statement for each data entry. + session.execute(insert_stmt, data) session.commit() session.flush() - worked = True + worked = True # Mark the operation as successful except Exception as e: session.rollback() print(e) - return worked + + return worked # Return the operation status @staticmethod def add_labels(id, labels, descriptions, in_wikipedia): @@ -259,9 +273,12 @@ def _gather_labels_ids(data): unit_id = data['unit'].split('/')[-1] ids.add(unit_id) - if ('datatype' in data - and 'datavalue' in data - and data['datatype'] in ('wikibase-item', 'wikibase-property')): + datatype_in_data = 'datatype' in data + datavalue_in_data = 'datavalue' in data + data_datatype = data['datatype'] in ( + 'wikibase-item', 'wikibase-property' + ) + if datatype_in_data and datavalue_in_data and data_datatype: ids.add(data['datavalue']) for value in data.values(): diff --git a/src/wikidataRetriever.py b/src/wikidataRetriever.py index e96d689..effb5f6 100644 --- a/src/wikidataRetriever.py +++ b/src/wikidataRetriever.py @@ -3,14 +3,14 @@ from wikidataCache import create_cache_embedding_model class AstraDBConnect: - def __init__(self, datastax_token, collection_name, model='nvidia', batch_size=8, cache_embeddings=None): + def __init__(self, datastax_token, collection_name, model='jina', batch_size=8, cache_embeddings=None): """ Initialize the AstraDBConnect object with the corresponding embedding model. Parameters: - datastax_token (dict): Credentials for DataStax Astra, including token and API endpoint. - collection_name (str): Name of the collection (table) where data is stored. - - model (str): The embedding model to use ("nvidia" or "jina"). Default is 'nvidia'. + - model (str): The embedding model to use. Default is 'jina'. - batch_size (int): Number of documents to accumulate before pushing to AstraDB. Default is 8. - cache_embeddings (str): Name of the cache table. """ From 34d9bb45770e437c097ddde0b17b8ffe20b6284d Mon Sep 17 00:00:00 2001 From: exowanderer Date: Tue, 11 Mar 2025 17:13:26 +0100 Subject: [PATCH 04/49] added docstrings to wikidataEmbed --- src/wikidataEmbed.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/src/wikidataEmbed.py b/src/wikidataEmbed.py index 9f7b518..ee44832 100644 --- a/src/wikidataEmbed.py +++ b/src/wikidataEmbed.py @@ -10,20 +10,36 @@ class WikidataTextifier: def __init__(self, language='en', langvar_filename=None): """ Initializes the WikidataTextifier with the specified language. + Expected use cases: Hugging Face parquet or sqlite database. Parameters: - - language (str): The language code used by the textifier (default is "en"). + - language (str): The language code used by the textifier + Default is "en". """ self.language = language - langvar_filename = (langvar_filename if langvar_filename is not None else language) + langvar_filename = ( + langvar_filename if langvar_filename is not None else language + ) try: - # Importing custom functions and variables from a formating python script in the language_variables folder. - self.langvar = importlib.import_module(f"language_variables.{langvar_filename}") + # Importing custom functions and variables + # from a formating python script in the language_variables folder. 
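The textifier only relies on a couple of attributes from each language_variables module. The following is a guessed minimal interface based on what this file references (novalue and merge_entity_text); the real en.py, de.py and ar.py modules in the repository are more elaborate:

# language_variables/minimal_example.py  -- hypothetical module, illustration only
novalue = "no value"

def merge_entity_text(label, description, aliases, properties):
    """Render label, description, aliases and a {property: [{'value', 'qualifiers'}]} dict as text."""
    text = f"{label}, {description}"
    if aliases:
        text += f", also known as {', '.join(aliases)}"
    for prop, values in properties.items():
        if not values:
            continue
        rendered = "; ".join(v['value'] for v in values)  # qualifiers ignored in this sketch
        text += f"\n{prop}: {rendered}"
    return text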
+ self.langvar = importlib.import_module( + f"language_variables.{langvar_filename}" + ) except Exception as e: raise ValueError(f"Language file for '{language}' not found.") def get_label(self, id, labels=None): + """Retrieves the label for a Wikidata entity in a specified language. + + Args: + id (str): QID or PID from the ID column in the Wikidata db or JSON. + labels (dict, optional): Wikidata labels in all available languages, else None. Defaults to None. + + Returns: + str: Wikidata label from specified language or mul[tilingual]. + """ if (labels is None) or (len(labels) == 0): labels = WikidataItem.get_labels(id) @@ -38,6 +54,15 @@ def get_label(self, id, labels=None): return label def get_description(self, id, descriptions=None): + """Retrieves the description for a Wikidata entity in the specified language. + + Args: + id (str): QID or PID from the ID column in the Wikidata db or JSON. + descriptions (dict, optional): Wikidata descriptions in all available languages, else None. Defaults to None. + + Returns: + str: Wikidata description from specified language or mul[tilingual]. + """ if (descriptions is None) or (len(descriptions) == 0): descriptions = WikidataItem.get_descriptions(id) From 76a5f31d48cd671e67c680cc1e5a0c53fd0aaa9a Mon Sep 17 00:00:00 2001 From: exowanderer Date: Tue, 11 Mar 2025 17:18:19 +0100 Subject: [PATCH 05/49] added docstrings to wikidataEmbed --- src/wikidataEmbed.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/src/wikidataEmbed.py b/src/wikidataEmbed.py index ee44832..a7f34f2 100644 --- a/src/wikidataEmbed.py +++ b/src/wikidataEmbed.py @@ -7,6 +7,8 @@ import importlib class WikidataTextifier: + """_summary_ + """ def __init__(self, language='en', langvar_filename=None): """ Initializes the WikidataTextifier with the specified language. @@ -69,7 +71,8 @@ def get_description(self, id, descriptions=None): if (type(descriptions) is str): return descriptions - # Take the description from the language, if missing take it from the multiligual class + # Take the description from the language, + # if missing take it from the multiligual class description = descriptions[self.language] if (self.language in descriptions) else (descriptions['mul'] if ('mul' in descriptions) else None) if type(description) is dict: @@ -77,6 +80,14 @@ def get_description(self, id, descriptions=None): return description def get_aliases(self, aliases): + """Retrieves the aliases for a Wikidata entity in the specified language. + + Args: + aliases (dict, optional): Wikidata aliases in all available languages, else None. Defaults to None. + + Returns: + list: Wikidata aliases from specified language and mul[tilingual]. + """ if (type(aliases) is list): return aliases @@ -91,22 +102,27 @@ def get_aliases(self, aliases): return list(aliases) def entity_to_text(self, entity, properties=None): - """ - Converts a Wikidata entity into a human-readable text string. + """Converts a Wikidata entity into a human-readable text string. - Parameters: - - entity: A Wikidata entity object containing entity data (label, description, claims, etc.). - - properties (dict or None): A dictionary of properties (claims). If None, the properties will be derived from entity.claims. + Args: + entity (obj): A Wikidata entity object containing + entity data (label, description, claims, etc.) + properties (dict or None, optional): A dictionary of + properties (claims). If None, the properties will be derived + from entity.claims. Defaults to None. 
Returns: - - str: A human-readable representation of the entity, its description, aliases, and claims. + (str): A human-readable representation of the entity, its description, aliases, and claims. """ if properties is None: properties = self.properties_to_dict(entity.claims) label = self.get_label(entity.id, labels=entity.label) - description = self.get_description(entity.id, descriptions=entity.description) + description = self.get_description( + entity.id, + descriptions=entity.description + ) if (description is None) or (len(description) == 0): instanceof = self.get_label('P31') description = properties.get(instanceof, '') From 481a95b53feac4ab8daea69940edab8eb059b952 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Tue, 11 Mar 2025 17:21:01 +0100 Subject: [PATCH 06/49] added docstrings to wikidataEmbed --- src/wikidataEmbed.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/wikidataEmbed.py b/src/wikidataEmbed.py index a7f34f2..e04b16e 100644 --- a/src/wikidataEmbed.py +++ b/src/wikidataEmbed.py @@ -43,6 +43,8 @@ def get_label(self, id, labels=None): str: Wikidata label from specified language or mul[tilingual]. """ if (labels is None) or (len(labels) == 0): + # If the labels are not provided, fetch them from the Wikidata SQLDB + # TODO: Fetch from the Wikidata API if not found in the SQLDB labels = WikidataItem.get_labels(id) if (type(labels) is str): From d1dd406f7cc7c55c08bfa81c6342b6238e11dfb0 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Tue, 11 Mar 2025 17:32:42 +0100 Subject: [PATCH 07/49] added docstrings to wikidataEmbed --- src/wikidataEmbed.py | 43 +++++++++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/src/wikidataEmbed.py b/src/wikidataEmbed.py index e04b16e..5fe8681 100644 --- a/src/wikidataEmbed.py +++ b/src/wikidataEmbed.py @@ -47,18 +47,25 @@ def get_label(self, id, labels=None): # TODO: Fetch from the Wikidata API if not found in the SQLDB labels = WikidataItem.get_labels(id) - if (type(labels) is str): + if isinstance(labels, str): + # If the labels are a string, return them as is return labels - # Take the label from the language, if missing take it from the multiligual class - label = labels[self.language] if (self.language in labels) else (labels['mul'] if ('mul' in labels) else None) + # Take the label from the language, if missing take it + # from the multiligual class + + label = labels.get(self.language) + if label is None: + label = labels.get('mul') + + if isinstance(label, dict): + label = label.get('value') - if type(label) is dict: - label = label['value'] return label def get_description(self, id, descriptions=None): - """Retrieves the description for a Wikidata entity in the specified language. + """Retrieves the description for a Wikidata entity + in the specified language. Args: id (str): QID or PID from the ID column in the Wikidata db or JSON. @@ -68,19 +75,26 @@ def get_description(self, id, descriptions=None): str: Wikidata description from specified language or mul[tilingual]. 
""" if (descriptions is None) or (len(descriptions) == 0): + # If the descriptions are not provided, fetch them from the Wikidata SQLDB + # TODO: Fetch from the Wikidata API if not found in the SQLDB descriptions = WikidataItem.get_descriptions(id) - if (type(descriptions) is str): + if isinstance(descriptions, str): return descriptions + # Take the description from the language, # if missing take it from the multiligual class - description = descriptions[self.language] if (self.language in descriptions) else (descriptions['mul'] if ('mul' in descriptions) else None) + description = descriptions.get(self.language) + if description is None: + description = descriptions.get('mul') + + if isinstance(description, dict): + description = description.get('value') - if type(description) is dict: - description = description['value'] return description + def get_aliases(self, aliases): """Retrieves the aliases for a Wikidata entity in the specified language. @@ -96,11 +110,16 @@ def get_aliases(self, aliases): if aliases is None: return [] + + # Combine the aliases from the specified language and the + # multilingual class. Use set format to avoid duplicates. aliases = set() if self.language in aliases: - aliases = set([x['value'] for x in aliases[self.language]]) + aliases.update([x['value'] for x in aliases[self.language]]) + if 'mul' in aliases: - aliases = aliases | set([x['value'] for x in aliases['mul']]) + aliases.update([x['value'] for x in aliases['mul']]) + return list(aliases) def entity_to_text(self, entity, properties=None): From d84dadebb4b7964c41c2911dd957d5e902e9aaef Mon Sep 17 00:00:00 2001 From: exowanderer Date: Tue, 11 Mar 2025 17:34:34 +0100 Subject: [PATCH 08/49] added docstrings to wikidataEmbed --- src/wikidataEmbed.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/wikidataEmbed.py b/src/wikidataEmbed.py index 5fe8681..114fdb0 100644 --- a/src/wikidataEmbed.py +++ b/src/wikidataEmbed.py @@ -136,8 +136,10 @@ def entity_to_text(self, entity, properties=None): (str): A human-readable representation of the entity, its description, aliases, and claims. 
""" if properties is None: + # If properties are not provided, fetch them from the entity properties = self.properties_to_dict(entity.claims) + # Get the label, description, and aliases for the entity label = self.get_label(entity.id, labels=entity.label) description = self.get_description( @@ -145,12 +147,19 @@ def entity_to_text(self, entity, properties=None): descriptions=entity.description ) if (description is None) or (len(description) == 0): + # If the description is missing, try to get it + # from the `instance_of` property instanceof = self.get_label('P31') description = properties.get(instanceof, '') aliases = self.get_aliases(entity.aliases) - return self.langvar.merge_entity_text(label, description, aliases, properties) + return self.langvar.merge_entity_text( + label, + description, + aliases, + properties + ) def properties_to_dict(self, properties): """ From 3783b0d2a87cf243049899b7f22de4b0b28edcbd Mon Sep 17 00:00:00 2001 From: exowanderer Date: Tue, 11 Mar 2025 17:37:09 +0100 Subject: [PATCH 09/49] added docstrings to wikidataEmbed --- src/wikidataEmbed.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/wikidataEmbed.py b/src/wikidataEmbed.py index 114fdb0..f1dc2ef 100644 --- a/src/wikidataEmbed.py +++ b/src/wikidataEmbed.py @@ -154,6 +154,8 @@ def entity_to_text(self, entity, properties=None): aliases = self.get_aliases(entity.aliases) + # Merge the label, description, aliases, and properties into a single + # text string as the Data Model per language through langvar descriptors return self.langvar.merge_entity_text( label, description, @@ -167,10 +169,11 @@ def properties_to_dict(self, properties): Parameters: - properties (dict): A dictionary of claims keyed by property IDs. - Each value is a list of claim statements for that property. + Each value is a list of claim statements for that property. Returns: - - dict: A dictionary mapping property labels to a list of their parsed values (and qualifiers). + - dict: A dictionary mapping property labels to a list of + their parsed values (and qualifiers). """ properties_dict = {} for pid, claim in properties.items(): @@ -180,7 +183,9 @@ def properties_to_dict(self, properties): for c in claim: try: value = self.mainsnak_to_value(c.get('mainsnak', c)) - qualifiers = self.qualifiers_to_dict(c.get('qualifiers', {})) + qualifiers = self.qualifiers_to_dict( + c.get('qualifiers', {}) + ) rank = c.get('rank', 'normal').lower() if value is None: From 03f6f144a9a730269641eeb5d4596c71343bb678 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Tue, 11 Mar 2025 17:48:05 +0100 Subject: [PATCH 10/49] added docstrings to wikidataEmbed --- src/wikidataEmbed.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/wikidataEmbed.py b/src/wikidataEmbed.py index f1dc2ef..1f4e46b 100644 --- a/src/wikidataEmbed.py +++ b/src/wikidataEmbed.py @@ -110,7 +110,6 @@ def get_aliases(self, aliases): if aliases is None: return [] - # Combine the aliases from the specified language and the # multilingual class. Use set format to avoid duplicates. aliases = set() @@ -193,7 +192,10 @@ def properties_to_dict(self, properties): break elif len(value) > 0: - # If a preferred rank exists, include values that are only preferred. Else include only values that are ranked normal (values with a depricated rank are never included) + # If a preferred rank exists, include values that are + # only preferred. 
Else include only values that are + # ranked normal (values with a depricated rank are + # never included) if ((not rank_preferred_found) and (rank == 'normal')) or (rank == 'preferred'): if (not rank_preferred_found) and (rank == 'preferred'): rank_preferred_found = True From ed87ce4396dd9eb1ffa284ce29def049e81ba963 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Tue, 11 Mar 2025 18:00:18 +0100 Subject: [PATCH 11/49] moved mega dict to json and fixed junior eng file open bugs --- docker/7_Create_Prototype/run.py | 26 +++- .../wikidata_chunk_sizes_2024-09-18.json | 115 ++++++++++++++++++ run_experiments.sh | 71 +++++++++-- 3 files changed, 196 insertions(+), 16 deletions(-) create mode 100644 docker/7_Create_Prototype/wikidata_chunk_sizes_2024-09-18.json diff --git a/docker/7_Create_Prototype/run.py b/docker/7_Create_Prototype/run.py index da554c3..25c06d4 100644 --- a/docker/7_Create_Prototype/run.py +++ b/docker/7_Create_Prototype/run.py @@ -43,12 +43,24 @@ FILEPATH = f"../data/Wikidata/chunks/chunk_{CHUNK_NUM}.json.gz" -# TODO: Push this dict into a json -chunk_sizes = {"chunk_0":992458,"chunk_1":802125,"chunk_2":589652,"chunk_3":310440,"chunk_4":43477,"chunk_5":156867,"chunk_6":141965,"chunk_7":74047,"chunk_8":27104,"chunk_9":70759,"chunk_10":71395,"chunk_11":186698,"chunk_12":153182,"chunk_13":137155,"chunk_14":929827,"chunk_15":853027,"chunk_16":571543,"chunk_17":335565,"chunk_18":47264,"chunk_19":135986,"chunk_20":160411,"chunk_21":76377,"chunk_22":26321,"chunk_23":70572,"chunk_24":68613,"chunk_25":179806,"chunk_26":159587,"chunk_27":139912,"chunk_28":876104,"chunk_29":864360,"chunk_30":590603,"chunk_31":358747,"chunk_32":47772,"chunk_33":135633,"chunk_34":159629,"chunk_35":81231,"chunk_36":24912,"chunk_37":69201,"chunk_38":67131,"chunk_39":172234,"chunk_40":167698,"chunk_41":142276,"chunk_42":821175,"chunk_43":892005,"chunk_44":600584,"chunk_45":374793,"chunk_46":47443,"chunk_47":134784,"chunk_48":155247,"chunk_49":86997,"chunk_50":24829,"chunk_51":68053,"chunk_52":63517,"chunk_53":167660,"chunk_54":175827,"chunk_55":142816,"chunk_56":765400,"chunk_57":900655,"chunk_58":628866,"chunk_59":396886,"chunk_60":46907,"chunk_61":135384,"chunk_62":154864,"chunk_63":88112,"chunk_64":23353,"chunk_65":67446,"chunk_66":40301,"chunk_67":176420,"chunk_68":183715,"chunk_69":149547,"chunk_70":713006,"chunk_71":901222,"chunk_72":652770,"chunk_73":419554,"chunk_74":52246,"chunk_75":134064,"chunk_76":153318,"chunk_77":92710,"chunk_78":22790,"chunk_79":66521,"chunk_80":34397,"chunk_81":173357,"chunk_82":186788,"chunk_83":153870,"chunk_84":657926,"chunk_85":902477,"chunk_86":655319,"chunk_87":455111,"chunk_88":69724,"chunk_89":133629,"chunk_90":146534,"chunk_91":101890,"chunk_92":21324,"chunk_93":65448,"chunk_94":33345,"chunk_95":162191,"chunk_96":192226,"chunk_97":159451,"chunk_98":598037,"chunk_99":903618,"chunk_100":662580,"chunk_101":484690,"chunk_102":86616,"chunk_103":135160,"chunk_104":106630,"chunk_105":142249,"chunk_106":19290,"chunk_107":60073,"chunk_108":39131,"chunk_109":155251,"chunk_110":190337,"chunk_111":166210,"chunk_112":26375} +# wikidata_chunk_sizes_fname = "../data/Wikidata/chunk_sizes.json" +# TODO: Add location as env var +# TODO: Sync data format from DATADUMP to chunk_sizes.json +# TODO: Retrieve info from Hugging Face instead of storing it +wikidata_chunk_sizes_path = os.path.join( + "docker", + "7_Create_Prototype", + "wikidata_chunk_sizes_2024-09-18.json" +) + +with open(wikidata_chunk_sizes_path) as json_in: + chunk_sizes = json.load(json_in) total_entities = 
chunk_sizes[f"chunk_{CHUNK_NUM}"] -datastax_token = json.load(open(f"../API_tokens/{DB_API_KEY_FILENAME}")) +with open(f"../API_tokens/{DB_API_KEY_FILENAME}") as json_in: + datastax_token = json.load(json_in) + dataset = load_dataset( "philippesaade/wikidata", data_files=f"data/chunk_{CHUNK_NUM}-*.parquet", @@ -57,8 +69,12 @@ ) def process_items(queue, progress_bar): - """Worker function that processes items from the queue and adds them to AstraDB.""" - datastax_token = json.load(open(f"../API_tokens/{DB_API_KEY_FILENAME}")) + """Worker function that processes items from the queue + and adds them to AstraDB. + """ + with open(f"../API_tokens/{DB_API_KEY_FILENAME}") as json_in: + datastax_token = json.load(json_in) + graph_store = AstraDBConnect( datastax_token, COLLECTION_NAME, diff --git a/docker/7_Create_Prototype/wikidata_chunk_sizes_2024-09-18.json b/docker/7_Create_Prototype/wikidata_chunk_sizes_2024-09-18.json new file mode 100644 index 0000000..e474869 --- /dev/null +++ b/docker/7_Create_Prototype/wikidata_chunk_sizes_2024-09-18.json @@ -0,0 +1,115 @@ +{ + "chunk_0":992458, + "chunk_1":802125, + "chunk_2":589652, + "chunk_3":310440, + "chunk_4":43477, + "chunk_5":156867, + "chunk_6":141965, + "chunk_7":74047, + "chunk_8":27104, + "chunk_9":70759, + "chunk_10":71395, + "chunk_11":186698, + "chunk_12":153182, + "chunk_13":137155, + "chunk_14":929827, + "chunk_15":853027, + "chunk_16":571543, + "chunk_17":335565, + "chunk_18":47264, + "chunk_19":135986, + "chunk_20":160411, + "chunk_21":76377, + "chunk_22":26321, + "chunk_23":70572, + "chunk_24":68613, + "chunk_25":179806, + "chunk_26":159587, + "chunk_27":139912, + "chunk_28":876104, + "chunk_29":864360, + "chunk_30":590603, + "chunk_31":358747, + "chunk_32":47772, + "chunk_33":135633, + "chunk_34":159629, + "chunk_35":81231, + "chunk_36":24912, + "chunk_37":69201, + "chunk_38":67131, + "chunk_39":172234, + "chunk_40":167698, + "chunk_41":142276, + "chunk_42":821175, + "chunk_43":892005, + "chunk_44":600584, + "chunk_45":374793, + "chunk_46":47443, + "chunk_47":134784, + "chunk_48":155247, + "chunk_49":86997, + "chunk_50":24829, + "chunk_51":68053, + "chunk_52":63517, + "chunk_53":167660, + "chunk_54":175827, + "chunk_55":142816, + "chunk_56":765400, + "chunk_57":900655, + "chunk_58":628866, + "chunk_59":396886, + "chunk_60":46907, + "chunk_61":135384, + "chunk_62":154864, + "chunk_63":88112, + "chunk_64":23353, + "chunk_65":67446, + "chunk_66":40301, + "chunk_67":176420, + "chunk_68":183715, + "chunk_69":149547, + "chunk_70":713006, + "chunk_71":901222, + "chunk_72":652770, + "chunk_73":419554, + "chunk_74":52246, + "chunk_75":134064, + "chunk_76":153318, + "chunk_77":92710, + "chunk_78":22790, + "chunk_79":66521, + "chunk_80":34397, + "chunk_81":173357, + "chunk_82":186788, + "chunk_83":153870, + "chunk_84":657926, + "chunk_85":902477, + "chunk_86":655319, + "chunk_87":455111, + "chunk_88":69724, + "chunk_89":133629, + "chunk_90":146534, + "chunk_91":101890, + "chunk_92":21324, + "chunk_93":65448, + "chunk_94":33345, + "chunk_95":162191, + "chunk_96":192226, + "chunk_97":159451, + "chunk_98":598037, + "chunk_99":903618, + "chunk_100":662580, + "chunk_101":484690, + "chunk_102":86616, + "chunk_103":135160, + "chunk_104":106630, + "chunk_105":142249, + "chunk_106":19290, + "chunk_107":60073, + "chunk_108":39131, + "chunk_109":155251, + "chunk_110":190337, + "chunk_111":166210, + "chunk_112":26375 +} \ No newline at end of file diff --git a/run_experiments.sh b/run_experiments.sh index d79fd90..321513a 100644 --- a/run_experiments.sh 
+++ b/run_experiments.sh @@ -1,12 +1,61 @@ # docker compose run --build add_wikidata_to_astra -# docker compose run --build -e EVALUATION_PATH="Mintaka/processed_dataframe.pkl" -e QUERY_COL="Question" -e PREFIX="_nonewlines" -e COLLECTION_NAME="wikidatav1" -e QUERY_LANGUAGE="en" -e DB_LANGUAGE="en" -e API_KEY="datastax_wikidata2.json" run_retrieval -# docker compose run --build -e EVALUATION_PATH="LC_QuAD/processed_dataframe.pkl" -e QUERY_COL="Question" -e PREFIX="_nonewlines" -e COLLECTION_NAME="wikidatav1" -e QUERY_LANGUAGE="en" -e DB_LANGUAGE="en" -e API_KEY="datastax_wikidata2.json" run_retrieval -# docker compose run --build -e EVALUATION_PATH="REDFM/processed_dataframe.pkl" -e QUERY_COL="Sentence" -e PREFIX="_nonewlines" -e COLLECTION_NAME="wikidatav1" -e QUERY_LANGUAGE="en" -e DB_LANGUAGE="en" -e API_KEY="datastax_wikidata2.json" run_retrieval -# docker compose run --build -e EVALUATION_PATH="REDFM/processed_dataframe.pkl" -e QUERY_COL="Sentence no entity" -e PREFIX="_nonewlines_noentity" -e COLLECTION_NAME="wikidatav1" -e QUERY_LANGUAGE="en" -e DB_LANGUAGE="en" -e API_KEY="datastax_wikidata2.json" run_retrieval -# docker compose run --build -e EVALUATION_PATH="RuBQ/processed_dataframe.pkl" -e QUERY_COL="Question" -e PREFIX="_nonewlines" -e COLLECTION_NAME="wikidatav1" -e QUERY_LANGUAGE="en" -e DB_LANGUAGE="en" -e API_KEY="datastax_wikidata2.json" run_retrieval -# docker compose run --build -e EVALUATION_PATH="Wikidata-Disamb/processed_dataframe.pkl" -e QUERY_COL="Sentence" -e COMPARATIVE="true" -e COMPARATIVE_COLS="Correct QID,Wrong QID" -e COLLECTION_NAME="wikidatav1" -e PREFIX="_nonewlines" -e QUERY_LANGUAGE="en" -e DB_LANGUAGE="en" -e API_KEY="datastax_wikidata2.json" run_retrieval - -# docker compose run --build -e CHUNK_NUM=5 create_prototype -docker compose run --build -e CHUNK_NUM=6 create_prototype -docker compose run --build -e CHUNK_NUM=7 create_prototype -docker compose run --build -e CHUNK_NUM=8 create_prototype \ No newline at end of file +# docker compose run --build \ +# -e EVALUATION_PATH="Mintaka/processed_dataframe.pkl" \ +# -e QUERY_COL="Question" \ +# -e PREFIX="_nonewlines" \ +# -e COLLECTION_NAME="wikidatav1" \ +# -e QUERY_LANGUAGE="en" \ +# -e DB_LANGUAGE="en" \ +# -e API_KEY="datastax_wikidata2.json" run_retrieval + +# docker compose run --build \ +# -e EVALUATION_PATH="LC_QuAD/processed_dataframe.pkl" \ +# -e QUERY_COL="Question" \ +# -e PREFIX="_nonewlines" \ +# -e COLLECTION_NAME="wikidatav1" \ +# -e QUERY_LANGUAGE="en" \ +# -e DB_LANGUAGE="en" \ +# -e API_KEY="datastax_wikidata2.json" run_retrieval + +# docker compose run --build \ +# -e EVALUATION_PATH="REDFM/processed_dataframe.pkl" \ +# -e QUERY_COL="Sentence" \ +# -e PREFIX="_nonewlines" \ +# -e COLLECTION_NAME="wikidatav1" \ +# -e QUERY_LANGUAGE="en" \ +# -e DB_LANGUAGE="en" \ +# -e API_KEY="datastax_wikidata2.json" run_retrieval + +# docker compose run --build \ +# -e EVALUATION_PATH="REDFM/processed_dataframe.pkl" \ +# -e QUERY_COL="Sentence no entity" \ +# -e PREFIX="_nonewlines_noentity" \ +# -e COLLECTION_NAME="wikidatav1" \ +# -e QUERY_LANGUAGE="en" \ +# -e DB_LANGUAGE="en" \ +# -e API_KEY="datastax_wikidata2.json" run_retrieval + +# docker compose run --build \ +# -e EVALUATION_PATH="RuBQ/processed_dataframe.pkl" \ +# -e QUERY_COL="Question" \ +# -e PREFIX="_nonewlines" \ +# -e COLLECTION_NAME="wikidatav1" \ +# -e QUERY_LANGUAGE="en" \ +# -e DB_LANGUAGE="en" \ +# -e API_KEY="datastax_wikidata2.json" run_retrieval + +# docker compose run --build \ +# -e 
EVALUATION_PATH="Wikidata-Disamb/processed_dataframe.pkl" \ +# -e QUERY_COL="Sentence" \ +# -e COMPARATIVE="true" \ +# -e COMPARATIVE_COLS="Correct QID,Wrong QID" \ +# -e COLLECTION_NAME="wikidatav1" \ +# -e PREFIX="_nonewlines" \ +# -e QUERY_LANGUAGE="en" \ +# -e DB_LANGUAGE="en" \ +# -e API_KEY="datastax_wikidata2.json" run_retrieval + +docker compose run --build -e CHUNK_NUM=112 create_prototype +docker compose run --build -e CHUNK_NUM=111 create_prototype +docker compose run --build -e CHUNK_NUM=110 create_prototype +docker compose run --build -e CHUNK_NUM=109 create_prototype \ No newline at end of file From 90b013cfe4e9c7d3bf980afccd28779dca344905 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Tue, 11 Mar 2025 18:10:59 +0100 Subject: [PATCH 12/49] added data dir exist checks --- src/wikidataCache.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/wikidataCache.py b/src/wikidataCache.py index 4922a1f..325ad01 100644 --- a/src/wikidataCache.py +++ b/src/wikidataCache.py @@ -2,13 +2,25 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from sqlalchemy.types import TypeDecorator + +import os import json """ SQLite database setup for caching the query embeddings for a faster evaluation process. """ + +# TODO: Move to a configuration file +wikidata_cache_file = "wikidata_cache.db" + +wikidata_cache_dir = "../data/Wikidata" +wikidata_cache_path = os.path.join(wikidata_cache_dir, wikidata_cache_file) + +if not os.path.exists(wikidata_cache_dir): + os.makedirs(wikidata_cache_dir) + engine = create_engine( - 'sqlite:///../data/Wikidata/sqlite_cacheembeddings.db', + f'sqlite:///{wikidata_cache_path}', pool_size=5, # Limit the number of open connections max_overflow=10, # Allow extra connections beyond pool_size pool_recycle=10 # Recycle connections every 10 seconds From ffedc9a8829d031d5e21ddd901c8f4d15fe24a38 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Tue, 11 Mar 2025 18:29:27 +0100 Subject: [PATCH 13/49] modified test dir exists and create --- docker/7_Create_Prototype/run.py | 4 ++-- src/wikidataCache.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/docker/7_Create_Prototype/run.py b/docker/7_Create_Prototype/run.py index 25c06d4..dbda960 100644 --- a/docker/7_Create_Prototype/run.py +++ b/docker/7_Create_Prototype/run.py @@ -48,8 +48,8 @@ # TODO: Sync data format from DATADUMP to chunk_sizes.json # TODO: Retrieve info from Hugging Face instead of storing it wikidata_chunk_sizes_path = os.path.join( - "docker", - "7_Create_Prototype", + # "docker", + # "7_Create_Prototype", "wikidata_chunk_sizes_2024-09-18.json" ) diff --git a/src/wikidataCache.py b/src/wikidataCache.py index 325ad01..4da927a 100644 --- a/src/wikidataCache.py +++ b/src/wikidataCache.py @@ -16,8 +16,14 @@ wikidata_cache_dir = "../data/Wikidata" wikidata_cache_path = os.path.join(wikidata_cache_dir, wikidata_cache_file) -if not os.path.exists(wikidata_cache_dir): - os.makedirs(wikidata_cache_dir) +try: + if not os.path.exists(wikidata_cache_dir): + os.makedirs(wikidata_cache_dir) +except OSError as e: + print(f"Error creating directory {wikidata_cache_dir}: {e}") + +assert(os.path.exists(wikidata_cache_dir)), \ + f"Error creating directory {wikidata_cache_dir}" engine = create_engine( f'sqlite:///{wikidata_cache_path}', From 8c44bbeaedd50e75e9052ca3aa9b14cef142d7a0 Mon Sep 17 00:00:00 2001 From: Philippe Saade Date: Wed, 12 Mar 2025 17:24:00 +0100 Subject: [PATCH 14/49] Handle network errors from Jina's 
side --- src/wikidataRetriever.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/wikidataRetriever.py b/src/wikidataRetriever.py index e96d689..9500d8b 100644 --- a/src/wikidataRetriever.py +++ b/src/wikidataRetriever.py @@ -120,18 +120,26 @@ def push_batch(self): if cache is None: docs.append(doc) except: + # Queue is empty break if len(docs) == 0: return False - vectors = self.embeddings.embed_documents([doc['content'] for doc in docs]) - try: + vectors = self.embeddings.embed_documents( + [doc['content'] for doc in docs] + ) self.graph_store.insert_many(docs, vectors=vectors) except Exception as e: print(e) + # Put the documents back in the Queue and try again later. + for doc in docs: + self.doc_batch.put(doc) + + return False + self.cache_model.add_bulk_cache([{ 'id': docs[i]['_id'], 'embedding': json.dumps(vectors[i], separators=(',', ':'))} @@ -139,6 +147,11 @@ def push_batch(self): return True + def push_all(self): + while True: + if not self.push_batch(): # Stop when batch is empty + break + def get_similar_qids(self, query, filter={}, K=50): """ Retrieve similar QIDs for a given query string. From 771266021c2c4309845e885d817a02dac767675e Mon Sep 17 00:00:00 2001 From: Philippe Saade Date: Wed, 12 Mar 2025 20:42:00 +0100 Subject: [PATCH 15/49] change cache database to base64 for the vectors --- src/migrate_cache.py | 72 ++++++++++++++++++++++++++++++++++++++++++++ src/wikidataCache.py | 23 ++++++++++---- 2 files changed, 89 insertions(+), 6 deletions(-) create mode 100644 src/migrate_cache.py diff --git a/src/migrate_cache.py b/src/migrate_cache.py new file mode 100644 index 0000000..ca2429c --- /dev/null +++ b/src/migrate_cache.py @@ -0,0 +1,72 @@ +import sqlite3 +import json +import base64 +import numpy as np +from tqdm import tqdm + +DB_PATH = "../data/Wikidata/sqlite_cacheembeddings.db" +TABLE_NAME = "wikidata_prototype" # Change this to match your actual table name +BATCH_SIZE = 5000 # Process in smaller batches to avoid memory overload + +def convert_embeddings(): + """ + Convert JSON-stored embeddings into Base64-encoded binary format in batches. + Uses `fetchmany(BATCH_SIZE)` to process records iteratively. 
+ """ + conn = sqlite3.connect(DB_PATH) + cursor = conn.cursor() + + # Check if the embedding column exists (sanity check) + cursor.execute(f"PRAGMA table_info({TABLE_NAME})") + columns = [row[1] for row in cursor.fetchall()] + if "embedding" not in columns: + print("Error: 'embedding' column does not exist in the table!") + return + + # Count total records for progress tracking + cursor.execute(f"SELECT COUNT(*) FROM {TABLE_NAME}") + total_records = cursor.fetchone()[0] + + print(f"Total records to process: {total_records}") + + # Fetch records in batches using an iterator + offset = 0 + with tqdm(total=total_records, desc="Converting embeddings", unit="record") as pbar: + while True: + cursor.execute(f"SELECT id, embedding FROM {TABLE_NAME} LIMIT {BATCH_SIZE} OFFSET {offset}") + records = cursor.fetchall() + if not records: + break # Stop when there are no more records + + updated_records = [] + for id, json_embedding in records: + if json_embedding: + try: + # Convert JSON string to list of floats + embedding_list = json.loads(json_embedding) + + # Convert list of floats to Base64-encoded binary + binary_data = np.array(embedding_list, dtype=np.float32).tobytes() + base64_embedding = base64.b64encode(binary_data).decode('utf-8') + + updated_records.append((base64_embedding, id)) + except Exception as e: + print(f"\nSkipping ID {id} due to error: {e}") + + pbar.update(1) # Update progress bar for each record processed + + # Update database in batches + if updated_records: + cursor.executemany( + f"UPDATE {TABLE_NAME} SET embedding = ? WHERE id = ?", + updated_records + ) + conn.commit() # Commit every batch + + offset += BATCH_SIZE # Move to next batch + + print("Migration completed successfully.") + conn.close() + +if __name__ == "__main__": + convert_embeddings() \ No newline at end of file diff --git a/src/wikidataCache.py b/src/wikidataCache.py index 4922a1f..5fea4cf 100644 --- a/src/wikidataCache.py +++ b/src/wikidataCache.py @@ -3,6 +3,8 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.types import TypeDecorator import json +import base64 +import numpy as np """ SQLite database setup for caching the query embeddings for a faster evaluation process. 
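As context for the hunk below, here is a minimal, self-contained sketch of the Base64 round-trip this patch adopts. The helper names are illustrative (not part of the repository), and float32 is assumed, matching migrate_cache.py and the EmbeddingType column:

    import base64
    import numpy as np

    def encode_embedding(vector):
        # Pack a list of floats into float32 bytes, then Base64-encode to text.
        binary_data = np.asarray(vector, dtype=np.float32).tobytes()
        return base64.b64encode(binary_data).decode("utf-8")

    def decode_embedding(encoded):
        # Reverse the encoding: Base64 text -> float32 bytes -> list of floats.
        binary_data = base64.b64decode(encoded)
        return np.frombuffer(binary_data, dtype=np.float32).tolist()

    # Round-trip check: values representable in float32 survive unchanged.
    assert decode_embedding(encode_embedding([0.5, -1.25, 2.0])) == [0.5, -1.25, 2.0]

Compared with the earlier JSON encoding, the cache column stays plain text while a 1024-dimensional float32 vector takes roughly 5.5 KB as Base64 instead of several times that as a JSON array.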
@@ -17,18 +19,27 @@ Base = declarative_base() Session = sessionmaker(bind=engine) -class JSONType(TypeDecorator): - """Custom SQLAlchemy type for JSON storage in SQLite.""" +class EmbeddingType(TypeDecorator): + """Custom SQLAlchemy type for storing embeddings as Base64 strings in SQLite.""" impl = Text def process_bind_param(self, value, dialect): - if value is not None: - return json.dumps(value, separators=(',', ':')) + """Convert a list of floats (embedding) to a Base64 string before storing.""" + if value is not None and isinstance(value, list): + # Convert list to binary + binary_data = np.array(value, dtype=np.float32).tobytes() + # Encode to Base64 string + return base64.b64encode(binary_data).decode('utf-8') return None def process_result_value(self, value, dialect): + """Convert a Base64 string back to a list of floats when retrieving.""" if value is not None: - return json.loads(value) + # Decode Base64 + binary_data = base64.b64decode(value) + # Convert back to float32 list + embedding_array = np.frombuffer(binary_data, dtype=np.float32) + return embedding_array.tolist() return None def create_cache_embedding_model(table_name): @@ -38,7 +49,7 @@ class CacheEmbeddings(Base): __tablename__ = table_name id = Column(Text, primary_key=True) - embedding = Column(JSONType) + embedding = Column(EmbeddingType) @staticmethod def add_cache(id, embedding): From 75e477ce7437213a80eec0a6cd3f032b2faaf146 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Thu, 13 Mar 2025 10:59:05 +0100 Subject: [PATCH 16/49] Refactored docker/1_Data_Processing_save_labels_descriptions/run.py based on flake8 results --- docker-compose.yml | 2 +- .../Dockerfile | 2 +- .../run.py | 39 ++++++++++++++----- 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 5d06906..52d46ce 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -142,7 +142,7 @@ services: API_KEY: "datastax_wikidata.json" EMBED_BATCH_SIZE: 100 # QUEUE_SIZE: 5000 - NUM_PROCESSES: 16 + NUM_PROCESSES: 23 OFFSET: 0 COLLECTION_NAME: "wikidata_prototype" LANGUAGE: 'en' diff --git a/docker/1_Data_Processing_save_labels_descriptions/Dockerfile b/docker/1_Data_Processing_save_labels_descriptions/Dockerfile index c67ee6f..e412193 100644 --- a/docker/1_Data_Processing_save_labels_descriptions/Dockerfile +++ b/docker/1_Data_Processing_save_labels_descriptions/Dockerfile @@ -17,7 +17,7 @@ RUN pip install --no-cache-dir -r requirements.txt # Copy the rest of the application code into the container COPY ./docker/1_Data_Processing_save_labels_descriptions /app -COPY ./src /src +COPY ./src /app/src # Set up the volume for the data folder VOLUME [ "/data" ] diff --git a/docker/1_Data_Processing_save_labels_descriptions/run.py b/docker/1_Data_Processing_save_labels_descriptions/run.py index 7a151c3..3dbef30 100644 --- a/docker/1_Data_Processing_save_labels_descriptions/run.py +++ b/docker/1_Data_Processing_save_labels_descriptions/run.py @@ -1,13 +1,14 @@ -import sys -sys.path.append('../src') +# import sys +# sys.path.append('../src') -from wikidataDumpReader import WikidataDumpReader -from wikidataItemDB import WikidataItem from multiprocessing import Manager import os import time import json +from src.wikidataDumpReader import WikidataDumpReader +from src.wikidataItemDB import WikidataItem + FILEPATH = os.getenv("FILEPATH", '../data/Wikidata/latest-all.json.bz2') PUSH_SIZE = int(os.getenv("PUSH_SIZE", 20000)) QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", 15000)) @@ -15,10 +16,13 @@ SKIPLINES = 
int(os.getenv("SKIPLINES", 0)) LANGUAGE = os.getenv("LANGUAGE", 'en') + def save_items_to_sqlite(item, data_batch, sqlitDBlock): if (item is not None): labels = WikidataItem.clean_label_description(item['labels']) - descriptions = WikidataItem.clean_label_description(item['descriptions']) + descriptions = WikidataItem.clean_label_description( + item['descriptions'] + ) labels = json.dumps(labels, separators=(',', ':')) descriptions = json.dumps(descriptions, separators=(',', ':')) in_wikipedia = WikidataItem.is_in_wikipedia(item) @@ -31,21 +35,38 @@ def save_items_to_sqlite(item, data_batch, sqlitDBlock): with sqlitDBlock: if len(data_batch) > PUSH_SIZE: - worked = WikidataItem.add_bulk_items(list(data_batch[:PUSH_SIZE])) + worked = WikidataItem.add_bulk_items(list( + data_batch[:PUSH_SIZE] + )) if worked: del data_batch[:PUSH_SIZE] + if __name__ == "__main__": multiprocess_manager = Manager() sqlitDBlock = multiprocess_manager.Lock() data_batch = multiprocess_manager.list() - wikidata = WikidataDumpReader(FILEPATH, num_processes=NUM_PROCESSES, queue_size=QUEUE_SIZE, skiplines=SKIPLINES) - wikidata.run(lambda item: save_items_to_sqlite(item, data_batch, sqlitDBlock), max_iterations=None, verbose=True) + wikidata = WikidataDumpReader( + FILEPATH, + num_processes=NUM_PROCESSES, + queue_size=QUEUE_SIZE, + skiplines=SKIPLINES + ) + + wikidata.run( + lambda item: save_items_to_sqlite( + item, + data_batch, + sqlitDBlock + ), + max_iterations=None, + verbose=True + ) while len(data_batch) > 0: worked = WikidataItem.add_bulk_items(list(data_batch)) if worked: del data_batch[:PUSH_SIZE] else: - time.sleep(1) \ No newline at end of file + time.sleep(1) From d1a256ab24ab3dd720b381b6bffa0a8d7fd25bda Mon Sep 17 00:00:00 2001 From: exowanderer Date: Thu, 13 Mar 2025 11:03:26 +0100 Subject: [PATCH 17/49] removed unnecessary commented code at top of file --- docker/1_Data_Processing_save_labels_descriptions/run.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/docker/1_Data_Processing_save_labels_descriptions/run.py b/docker/1_Data_Processing_save_labels_descriptions/run.py index 3dbef30..78dc999 100644 --- a/docker/1_Data_Processing_save_labels_descriptions/run.py +++ b/docker/1_Data_Processing_save_labels_descriptions/run.py @@ -1,6 +1,3 @@ -# import sys -# sys.path.append('../src') - from multiprocessing import Manager import os import time From cf20d46a1d6da8ecd5b9161f8026ecaa95307a9b Mon Sep 17 00:00:00 2001 From: exowanderer Date: Thu, 13 Mar 2025 11:05:33 +0100 Subject: [PATCH 18/49] Added PYTHONPATH env variable to Docekr file to bypass need for sys.path.append operation in run.py --- docker/1_Data_Processing_save_labels_descriptions/Dockerfile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docker/1_Data_Processing_save_labels_descriptions/Dockerfile b/docker/1_Data_Processing_save_labels_descriptions/Dockerfile index e412193..c3f5f2f 100644 --- a/docker/1_Data_Processing_save_labels_descriptions/Dockerfile +++ b/docker/1_Data_Processing_save_labels_descriptions/Dockerfile @@ -17,10 +17,12 @@ RUN pip install --no-cache-dir -r requirements.txt # Copy the rest of the application code into the container COPY ./docker/1_Data_Processing_save_labels_descriptions /app -COPY ./src /app/src +COPY ./src /src # Set up the volume for the data folder VOLUME [ "/data" ] +ENV PYTHONPATH="${PYTHONPATH}:/src" + # Run the Python script CMD ["python", "run.py"] \ No newline at end of file From e78cd400d7c57973da18cc2e727fae6df490ed0d Mon Sep 17 00:00:00 2001 From: exowanderer Date: Thu, 
13 Mar 2025 11:10:35 +0100 Subject: [PATCH 19/49] Refactored docker/2_Data_Processing_save_items_per_lang/run.py based on flake8; add PYTHONPATH to docker/2_Data_Processing_save_items_per_lang/Docker --- .../Dockerfile | 2 + .../run.py | 40 ++++++++++++++----- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/docker/2_Data_Processing_save_items_per_lang/Dockerfile b/docker/2_Data_Processing_save_items_per_lang/Dockerfile index cb8433b..fd76838 100644 --- a/docker/2_Data_Processing_save_items_per_lang/Dockerfile +++ b/docker/2_Data_Processing_save_items_per_lang/Dockerfile @@ -22,5 +22,7 @@ COPY ./src /src # Set up the volume for the data folder VOLUME [ "/data" ] +ENV PYTHONPATH="${PYTHONPATH}:/src" + # Run the Python script CMD ["python", "run.py"] \ No newline at end of file diff --git a/docker/2_Data_Processing_save_items_per_lang/run.py b/docker/2_Data_Processing_save_items_per_lang/run.py index 3fd7353..aec08b6 100644 --- a/docker/2_Data_Processing_save_items_per_lang/run.py +++ b/docker/2_Data_Processing_save_items_per_lang/run.py @@ -1,12 +1,10 @@ -import sys -sys.path.append('../src') - -from wikidataDumpReader import WikidataDumpReader -from wikidataLangDB import WikidataLang from multiprocessing import Manager import os import time +from src.wikidataDumpReader import WikidataDumpReader +from src.wikidataLangDB import WikidataLang + FILEPATH = os.getenv("FILEPATH", '../data/Wikidata/latest-all.json.bz2') PUSH_SIZE = int(os.getenv("PUSH_SIZE", 2000)) QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", 1500)) @@ -14,28 +12,50 @@ SKIPLINES = int(os.getenv("SKIPLINES", 0)) LANGUAGE = os.getenv("LANGUAGE", 'en') + def save_entities_to_sqlite(item, data_batch, sqlitDBlock): - if (item is not None) and WikidataLang.is_in_wikipedia(item, language=LANGUAGE): + is_not_none = item is not None + if is_not_none: + lang_in_wp = WikidataLang.is_in_wikipedia(item, language=LANGUAGE) + + if is_not_none and lang_in_wp: item = WikidataLang.normalise_item(item, language=LANGUAGE) data_batch.append(item) with sqlitDBlock: if len(data_batch) > PUSH_SIZE: - worked = WikidataLang.add_bulk_entities(list(data_batch[:PUSH_SIZE])) + worked = WikidataLang.add_bulk_entities(list( + data_batch[:PUSH_SIZE] + )) if worked: del data_batch[:PUSH_SIZE] + if __name__ == "__main__": multiprocess_manager = Manager() sqlitDBlock = multiprocess_manager.Lock() data_batch = multiprocess_manager.list() - wikidata = WikidataDumpReader(FILEPATH, num_processes=NUM_PROCESSES, queue_size=QUEUE_SIZE, skiplines=SKIPLINES) - wikidata.run(lambda item: save_entities_to_sqlite(item, data_batch, sqlitDBlock), max_iterations=None, verbose=True) + wikidata = WikidataDumpReader( + FILEPATH, + num_processes=NUM_PROCESSES, + queue_size=QUEUE_SIZE, + skiplines=SKIPLINES + ) + + wikidata.run( + lambda item: save_entities_to_sqlite( + item, + data_batch, + sqlitDBlock + ), + max_iterations=None, + verbose=True + ) while len(data_batch) > 0: worked = WikidataLang.add_bulk_entities(list(data_batch)) if worked: del data_batch[:PUSH_SIZE] else: - time.sleep(1) \ No newline at end of file + time.sleep(1) From c0d50bd03856d24fab1b97f0e3ceadd6b07d3b27 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Thu, 13 Mar 2025 11:23:08 +0100 Subject: [PATCH 20/49] Refactored docker/3_Add_Wikidata_to_AstraDB/run.py based on flake8; add PYTHONPATH to docker/2_Data_Processing_save_items_per_lang/Docker --- docker/3_Add_Wikidata_to_AstraDB/Dockerfile | 2 + docker/3_Add_Wikidata_to_AstraDB/run.py | 93 +++++++++++++++------ 2 files changed, 71 insertions(+), 24 
deletions(-) diff --git a/docker/3_Add_Wikidata_to_AstraDB/Dockerfile b/docker/3_Add_Wikidata_to_AstraDB/Dockerfile index cff05c8..51591da 100644 --- a/docker/3_Add_Wikidata_to_AstraDB/Dockerfile +++ b/docker/3_Add_Wikidata_to_AstraDB/Dockerfile @@ -28,5 +28,7 @@ COPY ./API_tokens /API_tokens # Set up the volume for the data folder VOLUME [ "/data" ] +ENV PYTHONPATH="${PYTHONPATH}:/src" + # Run the Python script CMD ["python", "run.py"] \ No newline at end of file diff --git a/docker/3_Add_Wikidata_to_AstraDB/run.py b/docker/3_Add_Wikidata_to_AstraDB/run.py index 616f77b..9eab967 100644 --- a/docker/3_Add_Wikidata_to_AstraDB/run.py +++ b/docker/3_Add_Wikidata_to_AstraDB/run.py @@ -1,20 +1,20 @@ -import sys -sys.path.append('../src') - -from wikidataLangDB import Session, WikidataLang -from wikidataEmbed import WikidataTextifier -from wikidataRetriever import AstraDBConnect, KeywordSearchConnect - import json -from tqdm import tqdm import os import pickle -from datetime import datetime import hashlib +from datetime import datetime +from tqdm import tqdm + +from src.wikidataLangDB import Session, WikidataLang +from src.wikidataEmbed import WikidataTextifier +from src.wikidataRetriever import AstraDBConnect, KeywordSearchConnect + MODEL = os.getenv("MODEL", "jina") SAMPLE = os.getenv("SAMPLE", "false").lower() == "true" -SAMPLE_PATH = os.getenv("SAMPLE_PATH", "../data/Evaluation Data/Sample IDs (EN).pkl") +SAMPLE_PATH = os.getenv( + "SAMPLE_PATH", "../data/Evaluation Data/Sample IDs (EN).pkl" +) EMBED_BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", 100)) QUERY_BATCH_SIZE = int(os.getenv("QUERY_BATCH_SIZE", 1000)) OFFSET = int(os.getenv("OFFSET", 0)) @@ -27,6 +27,7 @@ ELASTICSEARCH_URL = os.getenv("ELASTICSEARCH_URL", "http://localhost:9200") ELASTICSEARCH = os.getenv("ELASTICSEARCH", "false").lower() == "true" +# TODO: refactor script into function and call after __name__ == "__main__" # Load the Database if not COLLECTION_NAME: raise ValueError("The COLLECTION_NAME environment variable is required") @@ -36,15 +37,29 @@ if not API_KEY_FILENAME: API_KEY_FILENAME = os.listdir("../API_tokens")[0] + datastax_token = json.load(open(f"../API_tokens/{API_KEY_FILENAME}")) -textifier = WikidataTextifier(language=LANGUAGE, langvar_filename=TEXTIFIER_LANGUAGE) +textifier = WikidataTextifier( + language=LANGUAGE, + langvar_filename=TEXTIFIER_LANGUAGE +) if ELASTICSEARCH: - graph_store = KeywordSearchConnect(ELASTICSEARCH_URL, index_name=COLLECTION_NAME) + graph_store = KeywordSearchConnect( + ELASTICSEARCH_URL, + index_name=COLLECTION_NAME + ) else: - graph_store = AstraDBConnect(datastax_token, COLLECTION_NAME, model=MODEL, batch_size=EMBED_BATCH_SIZE, cache_embeddings=False) - + graph_store = AstraDBConnect( + datastax_token, + COLLECTION_NAME, + model=MODEL, + batch_size=EMBED_BATCH_SIZE, + cache_embeddings=False + ) + +# TODO: refactor script into function and call after __name__ == "__main__" # Load the Sample IDs sample_ids = None if SAMPLE: @@ -54,23 +69,34 @@ def get_entity(session): sample_qids = list(sample_ids['QID'].values)[OFFSET:] - sample_qid_batches = [sample_qids[i:i + QUERY_BATCH_SIZE] for i in range(0, len(sample_qids), QUERY_BATCH_SIZE)] + sample_qid_batches = [ + sample_qids[i:i + QUERY_BATCH_SIZE] + for i in range(0, len(sample_qids), QUERY_BATCH_SIZE) + ] # For each batch of sample QIDs, fetch the entities from the database for qid_batch in sample_qid_batches: - entities = session.query(WikidataLang).filter(WikidataLang.id.in_(qid_batch)).yield_per(QUERY_BATCH_SIZE) + entities = 
session.query(WikidataLang).filter( + WikidataLang.id.in_(qid_batch) + ).yield_per(QUERY_BATCH_SIZE) + for entity in entities: yield entity - else: total_entities = 9203786 def get_entity(session): - entities = session.query(WikidataLang).offset(OFFSET).yield_per(QUERY_BATCH_SIZE) + entities = session.query( + WikidataLang).offset( + OFFSET + ).yield_per(QUERY_BATCH_SIZE) + for entity in entities: yield entity + if __name__ == "__main__": + # TODO: refactor script into function and call after __name__ == "__main__" with tqdm(total=total_entities-OFFSET) as progressbar: with Session() as session: entity_generator = get_entity(session) @@ -82,11 +108,18 @@ def get_entity(session): if ELASTICSEARCH: chunks = [textifier.entity_to_text(entity)] else: - chunks = textifier.chunk_text(entity, graph_store.tokenizer, max_length=graph_store.max_token_size) + chunks = textifier.chunk_text( + entity, + graph_store.tokenizer, + max_length=graph_store.max_token_size + ) for chunk_i in range(len(chunks)): - md5_hash = hashlib.md5(chunks[chunk_i].encode('utf-8')).hexdigest() - metadata={ + md5_hash = hashlib.md5( + chunks[chunk_i].encode('utf-8') + ).hexdigest() + + metadata = { "MD5": md5_hash, "Label": entity.label, "Description": entity.description, @@ -99,8 +132,20 @@ def get_entity(session): "IsProperty": ('P' in entity.id), "DumpDate": DUMPDATE } - graph_store.add_document(id=f"{entity.id}_{LANGUAGE}_{chunk_i+1}", text=chunks[chunk_i], metadata=metadata) - - tqdm.write(progressbar.format_meter(progressbar.n, progressbar.total, progressbar.format_dict["elapsed"])) # tqdm is not working in docker compose. This is the alternative + graph_store.add_document( + id=f"{entity.id}_{LANGUAGE}_{chunk_i+1}", + text=chunks[chunk_i], + metadata=metadata + ) + + # tqdm is not working in docker compose. 
+ # This is the alternative + tqdm.write( + progressbar.format_meter( + progressbar.n, + progressbar.total, + progressbar.format_dict["elapsed"] + ) + ) graph_store.push_batch() From 15921fabcf5b6a9545d5305676b64edf9f8e839c Mon Sep 17 00:00:00 2001 From: exowanderer Date: Thu, 13 Mar 2025 11:26:40 +0100 Subject: [PATCH 21/49] Added ENV PYTHONPATH="${PYTHONPATH}:/src" to all Dockerfiles in docker/*/Dockerfile --- docker/4_Run_Retrieval/Dockerfile | 2 ++ docker/5_Run_Rerank/Dockerfile | 2 ++ docker/6_Push_Huggingface/Dockerfile | 2 ++ docker/7_Create_Prototype/Dockerfile | 2 ++ 4 files changed, 8 insertions(+) diff --git a/docker/4_Run_Retrieval/Dockerfile b/docker/4_Run_Retrieval/Dockerfile index c263af7..c978351 100644 --- a/docker/4_Run_Retrieval/Dockerfile +++ b/docker/4_Run_Retrieval/Dockerfile @@ -28,5 +28,7 @@ COPY ./API_tokens /API_tokens # Set up the volume for the data folder VOLUME [ "/data" ] +ENV PYTHONPATH="${PYTHONPATH}:/src" + # Run the Python script CMD ["python", "run.py"] \ No newline at end of file diff --git a/docker/5_Run_Rerank/Dockerfile b/docker/5_Run_Rerank/Dockerfile index adfd690..fbdf497 100644 --- a/docker/5_Run_Rerank/Dockerfile +++ b/docker/5_Run_Rerank/Dockerfile @@ -28,5 +28,7 @@ COPY ./API_tokens /API_tokens # Set up the volume for the data folder VOLUME [ "/data" ] +ENV PYTHONPATH="${PYTHONPATH}:/src" + # Run the Python script CMD ["python", "run.py"] \ No newline at end of file diff --git a/docker/6_Push_Huggingface/Dockerfile b/docker/6_Push_Huggingface/Dockerfile index 3b551c4..72e86f3 100644 --- a/docker/6_Push_Huggingface/Dockerfile +++ b/docker/6_Push_Huggingface/Dockerfile @@ -23,5 +23,7 @@ COPY ./API_tokens /API_tokens # Set up the volume for the data folder VOLUME [ "/data" ] +ENV PYTHONPATH="${PYTHONPATH}:/src" + # Run the Python script CMD ["python", "run.py"] \ No newline at end of file diff --git a/docker/7_Create_Prototype/Dockerfile b/docker/7_Create_Prototype/Dockerfile index bf17994..33bbd7b 100644 --- a/docker/7_Create_Prototype/Dockerfile +++ b/docker/7_Create_Prototype/Dockerfile @@ -26,5 +26,7 @@ COPY ./API_tokens /API_tokens # Set up the volume for the data folder VOLUME [ "/data" ] +ENV PYTHONPATH="${PYTHONPATH}:/src" + # Run the Python script CMD ["python", "run.py"] \ No newline at end of file From e420e7b68b075a3681f708f42c163b7b6e67fe37 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Thu, 13 Mar 2025 11:45:54 +0100 Subject: [PATCH 22/49] Refactored docker/4_Run_Retrieval/run.py based on output from flake8 --- docker/4_Run_Retrieval/run.py | 109 ++++++++++++++++++++++++++-------- 1 file changed, 84 insertions(+), 25 deletions(-) diff --git a/docker/4_Run_Retrieval/run.py b/docker/4_Run_Retrieval/run.py index 490b071..29f825b 100644 --- a/docker/4_Run_Retrieval/run.py +++ b/docker/4_Run_Retrieval/run.py @@ -1,14 +1,12 @@ -import sys -sys.path.append('../src') - -from wikidataRetriever import AstraDBConnect, KeywordSearchConnect - import json -from tqdm import tqdm import pandas as pd import os import pickle +from tqdm import tqdm +from src.wikidataRetriever import AstraDBConnect, KeywordSearchConnect + +# TODO: change script to functional form with fucnctions called after __name__ MODEL = os.getenv("MODEL", "jina") BATCH_SIZE = int(os.getenv("BATCH_SIZE", 100)) API_KEY_FILENAME = os.getenv("API_KEY", None) @@ -27,8 +25,18 @@ ELASTICSEARCH_URL = os.getenv("ELASTICSEARCH_URL", "http://localhost:9200") ELASTICSEARCH = os.getenv("ELASTICSEARCH", "false").lower() == "true" -OUTPUT_FILENAME = 
f"retrieval_results_{EVALUATION_PATH.split('/')[-2]}-{COLLECTION_NAME}-DB({DB_LANGUAGE})-Query({QUERY_LANGUAGE})" -# OUTPUT_FILENAME = f"retrieval_results_{EVALUATION_PATH.split('/')[-2]}-keyword-search-{LANGUAGE}" +OUTPUT_FILENAME = ( + f"retrieval_results_{EVALUATION_PATH.split('/')[-2]}-{COLLECTION_NAME}-" + f"DB({DB_LANGUAGE})-Query({QUERY_LANGUAGE})" +) + +# TODO: remove unneccesary commented out code +# OUTPUT_FILENAME = ( +# f"retrieval_results_{EVALUATION_PATH.split('/')[-2]}-" +# f"keyword-search-{LANGUAGE}" +# ) + + if PREFIX != "": OUTPUT_FILENAME += PREFIX @@ -38,25 +46,43 @@ if not API_KEY_FILENAME: API_KEY_FILENAME = os.listdir("../API_tokens")[0] + print(f"API_KEY_FILENAME not provided. Using {API_KEY_FILENAME}") + datastax_token = json.load(open(f"../API_tokens/{API_KEY_FILENAME}")) if ELASTICSEARCH: - graph_store = KeywordSearchConnect(ELASTICSEARCH_URL, index_name=COLLECTION_NAME) + graph_store = KeywordSearchConnect( + ELASTICSEARCH_URL, + index_name=COLLECTION_NAME + ) OUTPUT_FILENAME += "_bm25" else: - graph_store = AstraDBConnect(datastax_token, COLLECTION_NAME, model=MODEL, batch_size=BATCH_SIZE, cache_embeddings=True) - -#Load the Evaluation Dataset + graph_store = AstraDBConnect( + datastax_token, + COLLECTION_NAME, + model=MODEL, + batch_size=BATCH_SIZE, + cache_embeddings=True + ) + +# Load the Evaluation Dataset if not QUERY_COL: raise ValueError("The QUERY_COL environment variable is required") if not EVALUATION_PATH: raise ValueError("The EVALUATION_PATH environment variable is required") -if not RESTART and os.path.exists(f"../data/Evaluation Data/{OUTPUT_FILENAME}.pkl"): +outputfile_exists = os.path.exists( + f"../data/Evaluation Data/{OUTPUT_FILENAME}.pkl" +) +if not RESTART and outputfile_exists: print(f"Loading data from: {OUTPUT_FILENAME}") - eval_data = pickle.load(open(f"../data/Evaluation Data/{OUTPUT_FILENAME}.pkl", "rb")) + pkl_fpath = f"../data/Evaluation Data/{OUTPUT_FILENAME}.pkl" + with open(pkl_fpath, "rb") as pkl_file: + eval_data = pickle.load(pkl_file) else: - eval_data = pickle.load(open(f"../data/Evaluation Data/{EVALUATION_PATH}", "rb")) + pkl_fpath = f"../data/Evaluation Data/{EVALUATION_PATH}" + with open(pkl_fpath, "rb") as pkl_file: + eval_data = pickle.load(pkl_file) if 'Language' in eval_data.columns: eval_data = eval_data[eval_data['Language'] == QUERY_LANGUAGE] @@ -70,23 +96,56 @@ if 'Retrieval Score' not in eval_data: eval_data['Retrieval Score'] = None - row_to_process = eval_data['Retrieval QIDs'].apply(lambda x: (x is None) or (len(x) == 0)) | eval_data['Retrieval Score'].apply(lambda x: (x is None) or (len(x) == 0)) # Find rows that havn't been processed + # TODO: Refactor this row_to_process to avoid nested .apply + # Find rows that haven't been processed + row_to_process = eval_data['Retrieval QIDs'].apply( + lambda x: (x is None) or (len(x) == 0) + ) | eval_data['Retrieval Score'].apply( + lambda x: (x is None) or (len(x) == 0) + ) + pkl_output_path = f"../data/Evaluation Data/{OUTPUT_FILENAME}.pkl" progressbar.update((~row_to_process).sum()) for i in range(0, row_to_process.sum(), BATCH_SIZE): batch_idx = eval_data[row_to_process].iloc[i:i+BATCH_SIZE].index batch = eval_data.loc[batch_idx] if COMPARATIVE: - batch_results = graph_store.batch_retrieve_comparative(batch[QUERY_COL], batch[COMPARATIVE_COLS.split(',')], K=K, Language=DB_LANGUAGE) + batch_results = graph_store.batch_retrieve_comparative( + batch[QUERY_COL], + batch[COMPARATIVE_COLS.split(',')], + K=K, + Language=DB_LANGUAGE + ) else: - batch_results = 
graph_store.batch_retrieve(batch[QUERY_COL], K=K, Language=DB_LANGUAGE) - - eval_data.loc[batch_idx, 'Retrieval QIDs'] = pd.Series(batch_results[0]).values - eval_data.loc[batch_idx, 'Retrieval Score'] = pd.Series(batch_results[1]).values - + batch_results = graph_store.batch_retrieve( + batch[QUERY_COL], + K=K, + Language=DB_LANGUAGE + ) + + eval_data.loc[batch_idx, 'Retrieval QIDs'] = pd.Series( + batch_results[0] + ).values + + eval_data.loc[batch_idx, 'Retrieval Score'] = pd.Series( + batch_results[1] + ).values + + # TODO: Create progress bar update funciton + # tqdm is not wokring in docker compose. This is the alternative progressbar.update(len(batch)) - tqdm.write(progressbar.format_meter(progressbar.n, progressbar.total, progressbar.format_dict["elapsed"])) # tqdm is not wokring in docker compose. This is the alternative + tqdm.write( + progressbar.format_meter( + progressbar.n, + progressbar.total, + progressbar.format_dict["elapsed"] + ) + ) if progressbar.n % 100 == 0: - pickle.dump(eval_data, open(f"../data/Evaluation Data/{OUTPUT_FILENAME}.pkl", "wb")) - pickle.dump(eval_data, open(f"../data/Evaluation Data/{OUTPUT_FILENAME}.pkl", "wb")) \ No newline at end of file + # BUG: Why is the pkl output file being written twice at end + with open(pkl_output_path, "wb") as pkl_file: + pickle.dump(eval_data, pkl_file) + + # BUG: Why is the pkl output file being written twice at end + pickle.dump(eval_data, open(pkl_output_path, "wb")) From efc1e5b721a1f2276cb37585a469ac93b147ef3b Mon Sep 17 00:00:00 2001 From: exowanderer Date: Thu, 13 Mar 2025 14:20:25 +0100 Subject: [PATCH 23/49] Refactors docker/5_Run_Rerank/run.py based on flake8 output --- docker/5_Run_Rerank/run.py | 58 +++++++++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/docker/5_Run_Rerank/run.py b/docker/5_Run_Rerank/run.py index e443482..71b2196 100644 --- a/docker/5_Run_Rerank/run.py +++ b/docker/5_Run_Rerank/run.py @@ -1,15 +1,16 @@ -import sys -sys.path.append('../src') - -from wikidataDB import WikidataEntity -from wikidataEmbed import WikidataTextifier -from JinaAI import JinaAIReranker - from tqdm import tqdm import pandas as pd import os import pickle +# BUG: Reviwer assumed wikidataDB was converted to wikidataLangDB +# Bug: Reviwer assumed WikidataEntity was converted to WikidataLang +# from src.wikidataDB import WikidataEntity +from src.wikidataLangDB import WikidataLang +from src.wikidataEmbed import WikidataTextifier +from src.JinaAI import JinaAIReranker + + MODEL = os.getenv("MODEL", "jina") BATCH_SIZE = int(os.getenv("BATCH_SIZE", 100)) RETRIEVAL_FILENAME = os.getenv("RETRIEVAL_FILENAME") @@ -20,10 +21,17 @@ textifier = WikidataTextifier(language=LANGUAGE) reranker = JinaAIReranker() -eval_data = pickle.load(open(f"../data/Evaluation Data/{RETRIEVAL_FILENAME}.pkl", "rb")) +pkl_fpath = f"../data/Evaluation Data/{RETRIEVAL_FILENAME}.pkl" +with open(pkl_fpath, "rb") as pkl_file: + eval_data = pickle.load(pkl_file) + + +# Rerank the QIDs def rerank_qids(query, qids, reranker, textifier): - entities = [WikidataEntity.get_entity(qid) for qid in qids] + # Bug: Reviwer assumed WikidataEntity was converted to WikidataLang + # entities = [WikidataEntity.get_entity(qid) for qid in qids] + entities = [WikidataLang.get_entity(qid) for qid in qids] texts = [textifier.entity_to_text(entity) for entity in entities] scores = reranker.rank(query, texts) @@ -31,6 +39,7 @@ def rerank_qids(query, qids, reranker, textifier): score_zip = sorted(score_zip, key=lambda x: -x[0]) return 
[x[1] for x in score_zip]
 
+
 if __name__ == "__main__":
     with tqdm(total=len(eval_data), disable=False) as progressbar:
         if 'Reranked QIDs' not in eval_data:
@@ -42,12 +51,33 @@ def rerank_qids(query, qids, reranker, textifier):
             row = eval_data[row_to_process].iloc[i]
 
             # Rerank the QIDs
-            ranked_qids = rerank_qids(row[QUERY_COL], row['Retrieval QIDs'], reranker, textifier)
+            ranked_qids = rerank_qids(
+                row[QUERY_COL],
+                row['Retrieval QIDs'],
+                reranker,
+                textifier
+            )
 
-            eval_data.loc[[row.index], 'Reranked QIDs'] = pd.Series(ranked_qids).values
+            eval_data.loc[[row.index], 'Reranked QIDs'] = pd.Series(
+                ranked_qids
+            ).values
 
+            # TODO: create new function to update tqdm progressbar
+            # tqdm is not working in docker compose. This is the alternative
             progressbar.update(1)
-            tqdm.write(progressbar.format_meter(progressbar.n, progressbar.total, progressbar.format_dict["elapsed"])) # tqdm is not wokring in docker compose. This is the alternative
+            tqdm.write(
+                progressbar.format_meter(
+                    progressbar.n,
+                    progressbar.total,
+                    progressbar.format_dict["elapsed"]
+                )
+            )
 
             if progressbar.n % 100 == 0:
-                pickle.dump(eval_data, open(f"../data/Evaluation Data/{RETRIEVAL_FILENAME}.pkl", "wb"))
-    pickle.dump(eval_data, open(f"../data/Evaluation Data/{RETRIEVAL_FILENAME}.pkl", "wb"))
\ No newline at end of file
+                pkl_fpath = f"../data/Evaluation Data/{RETRIEVAL_FILENAME}.pkl"
+                with open(pkl_fpath, "wb") as pkl_file:
+                    pickle.dump(eval_data, pkl_file)
+
+    # TODO: Why is this definition and open twice at the end?
+    pkl_fpath = f"../data/Evaluation Data/{RETRIEVAL_FILENAME}.pkl"
+    with open(pkl_fpath, "wb") as pkl_file:
+        pickle.dump(eval_data, pkl_file)

From c05ed2640dc512090abfa8bbce8000c7eb2da8bd Mon Sep 17 00:00:00 2001
From: exowanderer
Date: Thu, 13 Mar 2025 14:31:36 +0100
Subject: [PATCH 24/49] Refactors docker/6_Push_Huggingface/run.py based on flake8 output

---
 docker/5_Run_Rerank/run.py       |  3 +-
 docker/6_Push_Huggingface/run.py | 82 +++++++++++++++++++-------------
 2 files changed, 51 insertions(+), 34 deletions(-)

diff --git a/docker/5_Run_Rerank/run.py b/docker/5_Run_Rerank/run.py
index 71b2196..7cc26e0 100644
--- a/docker/5_Run_Rerank/run.py
+++ b/docker/5_Run_Rerank/run.py
@@ -4,7 +4,8 @@
 import pickle
 
 # BUG: Reviwer assumed wikidataDB was converted to wikidataLangDB
-# Bug: Reviwer assumed WikidataEntity was converted to WikidataLang
+# Bug: Reviewer assumed wikidataDB.WikidataEntity was converted
+# to wikidataLangDB.WikidataLang
 # from src.wikidataDB import WikidataEntity
 from src.wikidataLangDB import WikidataLang
 from src.wikidataEmbed import WikidataTextifier
diff --git a/docker/6_Push_Huggingface/run.py b/docker/6_Push_Huggingface/run.py
index 56898f1..753bcd7 100644
--- a/docker/6_Push_Huggingface/run.py
+++ b/docker/6_Push_Huggingface/run.py
@@ -1,13 +1,11 @@
 import os
 import json
-import sys
 
 from huggingface_hub import login
 from multiprocessing import Process, Value, Queue
 
-sys.path.append('../src')
-from wikidataDumpReader import WikidataDumpReader
-from wikidataItemDB import WikidataItem
+from src.wikidataDumpReader import WikidataDumpReader
+from src.wikidataItemDB import WikidataItem
 
 from datasets import Dataset, load_dataset_builder
 
@@ -18,21 +16,31 @@
 API_KEY_FILENAME = os.getenv("API_KEY", "huggingface_api.json")
 ITERATION = int(os.getenv("ITERATION", 0))
 
-api_key = json.load(open(f"../API_tokens/{API_KEY_FILENAME}"))['API_KEY']
+api_key_fpath = f"../API_tokens/{API_KEY_FILENAME}"
+with open(api_key_fpath) as f_in:
+    api_key = json.load(f_in)['API_KEY']
+
 
 def 
save_to_queue(item, data_queue): """Processes and puts cleaned item into the multiprocessing queue.""" if (item is not None) and (WikidataItem.is_in_wikipedia(item)): - claims = WikidataItem.add_labels_batched(item['claims'], query_batch=100) + claims = WikidataItem.add_labels_batched( + item['claims'], + query_batch=100 + ) data_queue.put({ 'id': item['id'], 'labels': json.dumps(item['labels'], separators=(',', ':')), - 'descriptions': json.dumps(item['descriptions'], separators=(',', ':')), + 'descriptions': json.dumps( + item['descriptions'], + separators=(',', ':') + ), 'aliases': json.dumps(item['aliases'], separators=(',', ':')), 'sitelinks': json.dumps(item['sitelinks'], separators=(',', ':')), 'claims': json.dumps(claims, separators=(',', ':')) }) + def chunk_generator(filepath, num_processes=2, queue_size=5000, skip_lines=0): """ A generator function that reads a chunk file with WikidataDumpReader, @@ -53,8 +61,12 @@ def chunk_generator(filepath, num_processes=2, queue_size=5000, skip_lines=0): # Define a function to feed items into the queue def run_reader(): - wikidata.run(lambda item: save_to_queue(item, data_queue), - max_iterations=None, verbose=True) + wikidata.run( + lambda item: save_to_queue(item, data_queue), + max_iterations=None, + verbose=True + ) + with finished.get_lock(): finished.value = 1 @@ -69,7 +81,8 @@ def run_reader(): break try: item = data_queue.get(timeout=1) - except: + except Exception as e: + print(f'Exception: {e}') continue if item: yield item @@ -77,26 +90,29 @@ def run_reader(): # Wait for the reader process to exit reader_proc.join() -# Now process each chunk file and push to the same Hugging Face repo -HF_REPO_ID = "wikidata" # Change to your actual repo on Hugging Face - -login(token=api_key) -builder = load_dataset_builder("philippesaade/wikidata") -for i in range(0, 113): - split_name = f"chunk_{i}" - if split_name not in builder.info.splits: - filepath = f"../data/Wikidata/latest-all-chunks/chunk_{i}.json.gz" - - print(f"Processing {filepath} -> split={split_name}") - - # Create a Dataset from the generator - ds_chunk = Dataset.from_generator(lambda: chunk_generator( - filepath, - num_processes=NUM_PROCESSES, - queue_size=QUEUE_SIZE, - skip_lines=SKIPLINES - )) - - # Push each chunk as a separate "split" under the same dataset repo - ds_chunk.push_to_hub(HF_REPO_ID, split=split_name) - print(f"Chunk {ITERATION} pushed to {HF_REPO_ID} as {split_name}.") + +if __name__ == "__main__": + # TODO: Convert the following into a function and run it here + # Now process each chunk file and push to the same Hugging Face repo + HF_REPO_ID = "wikidata" # Change to your actual repo on Hugging Face + + login(token=api_key) + builder = load_dataset_builder("philippesaade/wikidata") + for i in range(0, 113): + split_name = f"chunk_{i}" + if split_name not in builder.info.splits: + filepath = f"../data/Wikidata/latest-all-chunks/chunk_{i}.json.gz" + + print(f"Processing {filepath} -> split={split_name}") + + # Create a Dataset from the generator + ds_chunk = Dataset.from_generator(lambda: chunk_generator( + filepath, + num_processes=NUM_PROCESSES, + queue_size=QUEUE_SIZE, + skip_lines=SKIPLINES + )) + + # Push each chunk as a separate "split" under the same dataset repo + ds_chunk.push_to_hub(HF_REPO_ID, split=split_name) + print(f"Chunk {ITERATION} pushed to {HF_REPO_ID} as {split_name}.") From 3fd28203c752d92965ea4c4f470196737ec086eb Mon Sep 17 00:00:00 2001 From: exowanderer Date: Thu, 13 Mar 2025 17:02:15 +0100 Subject: [PATCH 25/49] Refactored 
docker/7_Create_Prototype/run.py based on flake8 output --- docker/7_Create_Prototype/run.py | 34 ++++++++++++++------------------ 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/docker/7_Create_Prototype/run.py b/docker/7_Create_Prototype/run.py index dbda960..427cb64 100644 --- a/docker/7_Create_Prototype/run.py +++ b/docker/7_Create_Prototype/run.py @@ -1,19 +1,16 @@ -# TODO: package with setup inside docker to avoid sys.path mods -import sys -sys.path.append('../src') +import json +import os +import hashlib +import time -from wikidataEmbed import WikidataTextifier -from wikidataRetriever import AstraDBConnect from datasets import load_dataset +from datetime import datetime from multiprocessing import Process, Queue, Manager - -import json from tqdm import tqdm -import os -from datetime import datetime -import hashlib from types import SimpleNamespace -import time + +from src.wikidataEmbed import WikidataTextifier +from src.wikidataRetriever import AstraDBConnect MODEL = os.getenv("MODEL", "jinaapi") NUM_PROCESSES = int(os.getenv("NUM_PROCESSES", 4)) @@ -27,8 +24,9 @@ CHUNK_NUM = os.getenv("CHUNK_NUM") -assert(CHUNK_NUM is not None), \ +assert CHUNK_NUM is not None, ( "Please provide `CHUNK_NUM` env var at docker run" +) LANGUAGE = "en" TEXTIFIER_LANGUAGE = "en" @@ -47,13 +45,9 @@ # TODO: Add location as env var # TODO: Sync data format from DATADUMP to chunk_sizes.json # TODO: Retrieve info from Hugging Face instead of storing it -wikidata_chunk_sizes_path = os.path.join( - # "docker", - # "7_Create_Prototype", - "wikidata_chunk_sizes_2024-09-18.json" -) +wikidata_chunksizes_path = os.path.join("wikidata_chunk_sizes_2024-09-18.json") -with open(wikidata_chunk_sizes_path) as json_in: +with open(wikidata_chunksizes_path) as json_in: chunk_sizes = json.load(json_in) total_entities = chunk_sizes[f"chunk_{CHUNK_NUM}"] @@ -68,6 +62,7 @@ split="train" ) + def process_items(queue, progress_bar): """Worker function that processes items from the queue and adds them to AstraDB. @@ -144,6 +139,7 @@ def process_items(queue, progress_bar): if not graph_store.push_batch(): # Stop when batch is empty break + if __name__ == "__main__": queue = Queue(maxsize=QUEUE_SIZE) progress_bar = Manager().Value("i", 0) @@ -171,4 +167,4 @@ def process_items(queue, progress_bar): time.sleep(1) for p in processes: - p.join() \ No newline at end of file + p.join() From 3f9f52e3935933e6307b0dcc668c96d1721abbe4 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Thu, 13 Mar 2025 17:28:51 +0100 Subject: [PATCH 26/49] Refactors src/JinaAI based on output from flake8 --- src/JinaAI.py | 122 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 88 insertions(+), 34 deletions(-) diff --git a/src/JinaAI.py b/src/JinaAI.py index 2edacc0..2ed5cbe 100644 --- a/src/JinaAI.py +++ b/src/JinaAI.py @@ -1,31 +1,48 @@ -from typing import List import json import requests import numpy as np import base64 + +import torch # torch no long imported + + +from typing import List from wikidataCache import create_cache_embedding_model + class JinaAIEmbedder: - def __init__(self, passage_task="retrieval.passage", query_task="retrieval.query", embedding_dim=1024, cache=None): + def __init__( + self, passage_task="retrieval.passage", + query_task="retrieval.query", embedding_dim=1024, cache=None): """ - Initializes the JinaAIEmbedder class with the model, tokenizer, and task identifiers. + Initializes the JinaAIEmbedder class with the model, tokenizer, + and task identifiers. 
Parameters: - - passage_task (str): Task identifier for embedding documents. Defaults to "retrieval.passage". - - query_task (str): Task identifier for embedding queries. Defaults to "retrieval.query". - - embedding_dim (int): Dimensionality of the embeddings. Defaults to 1024. + - passage_task (str): Task identifier for embedding documents. + Defaults to "retrieval.passage". + - query_task (str): Task identifier for embedding queries. + Defaults to "retrieval.query". + - embedding_dim (int): Dimensionality of the embeddings. + Defaults to 1024. - cache (str): Name of caching table. - - api_key_path (str): Path to the JSON file containing the Jina API key. Defaults to "../API_tokens/jina_api.json". + - api_key_path (str): Path to the JSON file containing the + Jina API key. Defaults to "../API_tokens/jina_api.json". """ from transformers import AutoModel, AutoTokenizer - import torch self.passage_task = passage_task self.query_task = query_task self.embedding_dim = embedding_dim - self.model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True).to('cuda') - self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True) + self.model = AutoModel.from_pretrained( + "jinaai/jina-embeddings-v3", + trust_remote_code=True + ).to('cuda') + self.tokenizer = AutoTokenizer.from_pretrained( + "jinaai/jina-embeddings-v3", + trust_remote_code=True + ) self.cache = (cache is not None) if self.cache: @@ -61,21 +78,30 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: """ Generates embeddings for a list of document (passage) texts. - Caching is not used here by default to avoid storing large numbers of document embeddings. + Caching is not used here by default to avoid storing + large numbers of document embeddings. Parameters: - texts (List[str]): A list of document texts to embed. Returns: - - List[List[float]]: A list of embedding vectors, each corresponding to a document. + - List[List[float]]: A list of embedding vectors, each corresponding + to a document. """ + with torch.no_grad(): - embeddings = self.model.encode(texts, task=self.passage_task, truncate_dim=self.embedding_dim) + embeddings = self.model.encode( + texts, + task=self.passage_task, + truncate_dim=self.embedding_dim + ) + return embeddings def embed_query(self, text: str) -> List[float]: """ - Generates an embedding for a single query string, optionally using and updating the cache. + Generates an embedding for a single query string, optionally using + and updating the cache. Parameters: - text (str): The query text to embed. @@ -88,21 +114,35 @@ def embed_query(self, text: str) -> List[float]: return cached_embedding with torch.no_grad(): - embedding = self.model.encode([text], task=self.query_task, truncate_dim=self.embedding_dim)[0] + embedding = self.model.encode( + [text], + task=self.query_task, + truncate_dim=self.embedding_dim + )[0] + self._cache_embedding(text, embedding) return embedding + class JinaAIAPIEmbedder: - def __init__(self, passage_task="retrieval.passage", query_task="retrieval.query", embedding_dim=1024, cache=False, api_key_path="../API_tokens/jina_api.json"): + def __init__( + self, passage_task="retrieval.passage", + query_task="retrieval.query", embedding_dim=1024, + api_key_path="../API_tokens/jina_api.json"): # cache=False, """ - Initializes the JinaAIEmbedder class with the model, tokenizer, and task identifiers. + Initializes the JinaAIEmbedder class with the model, tokenizer, + and task identifiers. 
Parameters: - - passage_task (str): Task identifier for embedding documents. Defaults to "retrieval.passage". - - query_task (str): Task identifier for embedding queries. Defaults to "retrieval.query". - - embedding_dim (int): Dimensionality of the embeddings. Defaults to 1024. - - cache (str): Name of caching table. - - api_key_path (str): Path to the JSON file containing the Jina API key. Defaults to "../API_tokens/jina_api.json". + - passage_task (str): Task identifier for embedding documents. + Defaults to "retrieval.passage". + - query_task (str): Task identifier for embedding queries. + Defaults to "retrieval.query". + - embedding_dim (int): Dimensionality of the embeddings. + Defaults to 1024. + - cache (str): Name of caching table. # BUG: cache is unused + - api_key_path (str): Path to the JSON file containing + the Jina API key. Defaults to "../API_tokens/jina_api.json". """ self.passage_task = passage_task self.query_task = query_task @@ -112,11 +152,12 @@ def __init__(self, passage_task="retrieval.passage", query_task="retrieval.query def api_embed(self, texts, task="retrieval.query"): """ - Generates an embedding for the given text using the Jina Embeddings API. + Generates an embedding for the given text using the Jina Embeddings API Parameters: - text (str): The text to embed. - - task (str): The task identifier (e.g., "retrieval.query" or "retrieval.passage"). + - task (str): The task identifier + (e.g., "retrieval.query" or "retrieval.passage"). Returns: - np.ndarray: The resulting embedding vector as a NumPy array. @@ -146,7 +187,8 @@ def api_embed(self, texts, task="retrieval.query"): embeddings = [] for item in response_data['data']: binary_data = base64.b64decode(item['embedding']) - embedding_array = np.frombuffer(binary_data, dtype=' List[List[float]]: """ Generates embeddings for a list of document (passage) texts. - Caching is not used here by default to avoid storing large numbers of document embeddings. + Caching is not used here by default to avoid storing large numbers + of document embeddings. Parameters: - texts (List[str]): A list of document texts to embed. Returns: - - List[List[float]]: A list of embedding vectors, each corresponding to a document. + - List[List[float]]: A list of embedding vectors, each corresponding + to a document. """ embeddings = self.api_embed(texts, task=self.passage_task) return embeddings def embed_query(self, text: str) -> List[float]: """ - Generates an embedding for a single query string, optionally using and updating the cache. + Generates an embedding for a single query string, optionally using + and updating the cache. Parameters: - text (str): The query text to embed. @@ -179,25 +224,30 @@ def embed_query(self, text: str) -> List[float]: embedding = self.api_embed([text], task=self.query_task) return embedding + class JinaAIReranker: def __init__(self, max_tokens=1024): """ - Initializes the JinaAIReranker with a maximum token length and the Jina Reranker model. + Initializes the JinaAIReranker with a maximum token length + and the Jina Reranker model. Parameters: - - max_tokens (int): Maximum sequence length for the reranker (must be <= 1024). + - max_tokens (int): Maximum sequence length for the reranker + (must be <= 1024). Raises: - ValueError: If max_tokens is greater than 1024. 
""" from transformers import AutoModelForSequenceClassification - import torch if max_tokens > 1024: raise ValueError("Max token should be less than or equal to 1024") self.max_tokens = max_tokens - self.model = AutoModelForSequenceClassification.from_pretrained('jinaai/jina-reranker-v2-base-multilingual', trust_remote_code=True).to('cuda') + self.model = AutoModelForSequenceClassification.from_pretrained( + 'jinaai/jina-reranker-v2-base-multilingual', + trust_remote_code=True + ).to('cuda') def rank(self, query: str, texts: List[str]) -> List[float]: """ @@ -208,9 +258,13 @@ def rank(self, query: str, texts: List[str]) -> List[float]: - texts (List[str]): A list of document texts to rank. Returns: - - List[float]: A list of relevance scores, each corresponding to one document. + - List[float]: A list of relevance scores, each corresponding + to one document. """ sentence_pairs = [[query, doc] for doc in texts] with torch.no_grad(): - return self.model.compute_score(sentence_pairs, max_length=self.max_tokens) \ No newline at end of file + return self.model.compute_score( + sentence_pairs, + max_length=self.max_tokens + ) From 3bc445430f6a54f70ccf9218bae401b3317f9de5 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Thu, 13 Mar 2025 17:30:35 +0100 Subject: [PATCH 27/49] Refactors src/__init__.py based on output from flake8 --- src/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/__init__.py b/src/__init__.py index d99d86d..9ef178f 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -3,4 +3,4 @@ from .wikidataItemDB import WikidataItem from .wikidataEmbed import WikidataTextifier from .JinaAI import JinaAIEmbedder, JinaAIReranker, JinaAIAPIEmbedder -from .wikidataRetriever import AstraDBConnect, KeywordSearchConnect \ No newline at end of file +from .wikidataRetriever import AstraDBConnect, KeywordSearchConnect From 19ada1486c20049a820c42081cc9e442df90ecf3 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Thu, 13 Mar 2025 17:43:18 +0100 Subject: [PATCH 28/49] Refactored src/experimental_functions/word_embeding.py based on output from flake8 --- src/experimental_functions/word_embeding.py | 73 ++++++++++++++++----- 1 file changed, 58 insertions(+), 15 deletions(-) diff --git a/src/experimental_functions/word_embeding.py b/src/experimental_functions/word_embeding.py index c009441..3333743 100644 --- a/src/experimental_functions/word_embeding.py +++ b/src/experimental_functions/word_embeding.py @@ -1,9 +1,10 @@ import torch -import torch.nn.functional as F +# import torch.nn.functional as F from JinaAI import JinaAIEmbedder model = JinaAIEmbedder() + def mean_pooling(model_output, attention_mask): token_embeddings = model_output[0] input_mask_expanded = ( @@ -13,7 +14,22 @@ def mean_pooling(model_output, attention_mask): input_mask_expanded.sum(1), min=1e-9 ) -def extract_specific_word_embedding(model, tokenizer, sentence, start_marker='', end_marker=''): + +def extract_specific_word_embedding( + model, tokenizer, sentence, start_marker='', end_marker=''): + """_summary_ # TODO: fill out docstring + + Args: + model (_type_): _description_ + tokenizer (_type_): _description_ + sentence (_type_): _description_ + start_marker (str, optional): _description_. Defaults to ''. + end_marker (str, optional): _description_. Defaults to ''. 
+ + Returns: + _type_: _description_ + """ + # Check if the markers are in the sentence if start_marker in sentence and end_marker in sentence: # Find the position of the markers @@ -21,29 +37,46 @@ def extract_specific_word_embedding(model, tokenizer, sentence, start_marker='bank to withdraw some money." +) +embedding = extract_specific_word_embedding( + model.model, + model.tokenizer, + sentence +) +print("Extracted Embedding:", embedding) From b42436b444f5e5c10f23c39b06210ebdacc11d16 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Thu, 13 Mar 2025 17:58:17 +0100 Subject: [PATCH 29/49] Refactored src/wikidataCache.py based on output from flake8 --- src/wikidataCache.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/src/wikidataCache.py b/src/wikidataCache.py index 4da927a..34018d2 100644 --- a/src/wikidataCache.py +++ b/src/wikidataCache.py @@ -7,7 +7,8 @@ import json """ -SQLite database setup for caching the query embeddings for a faster evaluation process. +SQLite database setup for caching the query embeddings for a faster +evaluation process. """ # TODO: Move to a configuration file @@ -22,7 +23,7 @@ except OSError as e: print(f"Error creating directory {wikidata_cache_dir}: {e}") -assert(os.path.exists(wikidata_cache_dir)), \ +assert os.path.exists(wikidata_cache_dir), \ f"Error creating directory {wikidata_cache_dir}" engine = create_engine( @@ -35,6 +36,7 @@ Base = declarative_base() Session = sessionmaker(bind=engine) + class JSONType(TypeDecorator): """Custom SQLAlchemy type for JSON storage in SQLite.""" impl = Text @@ -49,6 +51,7 @@ def process_result_value(self, value, dialect): return json.loads(value) return None + def create_cache_embedding_model(table_name): """Factory function to create a dynamic CacheEmbeddings model.""" @@ -73,7 +76,10 @@ def add_cache(id, embedding): @staticmethod def get_cache(id): with Session() as session: - cached = session.query(CacheEmbeddings).filter_by(id=id).first() + cached = session.query( + CacheEmbeddings + ).filter_by(id=id).first() + if cached: return cached.embedding return None @@ -81,28 +87,28 @@ def get_cache(id): @staticmethod def add_bulk_cache(data): """ - Insert multiple label records in bulk. If a record with the same ID exists, + Insert multiple label records in bulk. If a record with the same + ID exists, it is ignored (no update is performed). Parameters: - - data (list[dict]): A list of dictionaries, each containing 'id', 'labels', 'descriptions', and 'in_wikipedia' keys. + - data (list[dict]): A list of dictionaries, each containing 'id', + 'labels', 'descriptions', and 'in_wikipedia' keys. Returns: - bool: True if the operation was successful, False otherwise. 
""" worked = False with Session() as session: + exec_text = text( + f""" + INSERT INTO {CacheEmbeddings.__tablename__} (id, embedding) + VALUES (:id, :embedding) + ON CONFLICT(id) DO NOTHING + """ + ) try: - session.execute( - text( - f""" - INSERT INTO {CacheEmbeddings.__tablename__} (id, embedding) - VALUES (:id, :embedding) - ON CONFLICT(id) DO NOTHING - """ - ), - data - ) + session.execute(exec_text, data) session.commit() session.flush() worked = True @@ -113,4 +119,4 @@ def add_bulk_cache(data): Base.metadata.create_all(engine) - return CacheEmbeddings \ No newline at end of file + return CacheEmbeddings From 3e49a7479ebfd57dcce2f7711d4389d10558d882 Mon Sep 17 00:00:00 2001 From: Philippe Saade Date: Thu, 13 Mar 2025 18:10:31 +0100 Subject: [PATCH 30/49] Handle Jina and DataStax API errors --- docker/7_Create_Prototype/run.py | 43 ++++++++++++++++++++++++-------- src/wikidataRetriever.py | 31 ++++++++++++++--------- 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/docker/7_Create_Prototype/run.py b/docker/7_Create_Prototype/run.py index cb69030..3ced7b7 100644 --- a/docker/7_Create_Prototype/run.py +++ b/docker/7_Create_Prototype/run.py @@ -48,8 +48,17 @@ def process_items(queue, progress_bar): """Worker function that processes items from the queue and adds them to AstraDB.""" datastax_token = json.load(open(f"../API_tokens/{DB_API_KEY_FILENAME}")) - graph_store = AstraDBConnect(datastax_token, COLLECTION_NAME, model=MODEL, batch_size=EMBED_BATCH_SIZE, cache_embeddings="wikidata_prototype") - textifier = WikidataTextifier(language=LANGUAGE, langvar_filename=TEXTIFIER_LANGUAGE) + graph_store = AstraDBConnect( + datastax_token, + COLLECTION_NAME, + model=MODEL, + batch_size=EMBED_BATCH_SIZE, + cache_embeddings="wikidata_prototype" + ) + textifier = WikidataTextifier( + language=LANGUAGE, + langvar_filename=TEXTIFIER_LANGUAGE + ) while True: item = queue.get() @@ -57,9 +66,17 @@ def process_items(queue, progress_bar): break # Exit condition for worker processes item_id = item['id'] - item_label = textifier.get_label(item_id, json.loads(item['labels'])) - item_description = textifier.get_description(item_id, json.loads(item['descriptions'])) - item_aliases = textifier.get_aliases(json.loads(item['aliases'])) + item_label = textifier.get_label( + item_id, + json.loads(item['labels']) + ) + item_description = textifier.get_description( + item_id, + json.loads(item['descriptions']) + ) + item_aliases = textifier.get_aliases( + json.loads(item['aliases']) + ) if item_label is not None: entity_obj = SimpleNamespace() @@ -69,7 +86,11 @@ def process_items(queue, progress_bar): entity_obj.aliases = item_aliases entity_obj.claims = json.loads(item['claims']) - chunks = textifier.chunk_text(entity_obj, graph_store.tokenizer, max_length=graph_store.max_token_size) + chunks = textifier.chunk_text( + entity_obj, + graph_store.tokenizer, + max_length=graph_store.max_token_size + ) for chunk_i, chunk in enumerate(chunks): md5_hash = hashlib.md5(chunk.encode('utf-8')).hexdigest() @@ -86,13 +107,15 @@ def process_items(queue, progress_bar): "IsProperty": ('P' in item_id), "DumpDate": DUMPDATE } - graph_store.add_document(id=f"{item_id}_{LANGUAGE}_{chunk_i+1}", text=chunk, metadata=metadata) + graph_store.add_document( + id=f"{item_id}_{LANGUAGE}_{chunk_i+1}",\ + text=chunk, + metadata=metadata + ) progress_bar.value += 1 - while True: - if not graph_store.push_batch(): # Stop when batch is empty - break + graph_store.push_all() if __name__ == "__main__": queue = Queue(maxsize=QUEUE_SIZE) 
diff --git a/src/wikidataRetriever.py b/src/wikidataRetriever.py index 9500d8b..fd42068 100644 --- a/src/wikidataRetriever.py +++ b/src/wikidataRetriever.py @@ -17,6 +17,7 @@ def __init__(self, datastax_token, collection_name, model='nvidia', batch_size=8 from langchain_astradb import AstraDBVectorStore from astrapy.info import CollectionVectorServiceOptions from astrapy import DataAPIClient + from astrapy.exceptions import InsertManyException from multiprocessing import Queue from transformers import AutoTokenizer @@ -126,19 +127,25 @@ def push_batch(self): if len(docs) == 0: return False - try: - vectors = self.embeddings.embed_documents( - [doc['content'] for doc in docs] - ) - self.graph_store.insert_many(docs, vectors=vectors) - except Exception as e: - print(e) - - # Put the documents back in the Queue and try again later. - for doc in docs: - self.doc_batch.put(doc) + while True: + try: + vectors = self.embeddings.embed_documents( + [doc['content'] for doc in docs] + ) + break + except Exception as e: + print(e) + time.sleep(3) - return False + while True: + try: + self.graph_store.insert_many(docs, vectors=vectors) + break + except InsertManyException as e: + pass + except Exception as e: + print(e) + time.sleep(3) self.cache_model.add_bulk_cache([{ 'id': docs[i]['_id'], From 7db1df923cbea4d97d28d7880b9ee1c666a3cad2 Mon Sep 17 00:00:00 2001 From: Philippe Saade Date: Thu, 13 Mar 2025 18:12:07 +0100 Subject: [PATCH 31/49] Handle Jina and DataStax API errors --- src/wikidataRetriever.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/wikidataRetriever.py b/src/wikidataRetriever.py index fd42068..083a8f0 100644 --- a/src/wikidataRetriever.py +++ b/src/wikidataRetriever.py @@ -31,6 +31,7 @@ def __init__(self, datastax_token, collection_name, model='nvidia', batch_size=8 self.model = model self.collection_name = collection_name self.doc_batch = Queue() + self.InsertManyException = InsertManyException self.cache_on = (cache_embeddings is not None) if self.cache_on: @@ -141,7 +142,7 @@ def push_batch(self): try: self.graph_store.insert_many(docs, vectors=vectors) break - except InsertManyException as e: + except self.InsertManyException as e: pass except Exception as e: print(e) From b349ff12f872098a757dc1f6bf2e421def446be1 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Fri, 14 Mar 2025 00:05:25 +0100 Subject: [PATCH 32/49] tested new PYTHONPATH; Changed ENV PYTHONPATH=':/src' to ENV PYTHONPATH=':/'; update use of src. in various python imports; update use of src. in src.language_variables. 
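Concretely, the containers keep copying the repository's src/ directory into the image (COPY ./src /src) and now put the image root on PYTHONPATH, so the scripts import the code as the package "src". A minimal sketch of the resulting usage; the language and QID below are placeholders, not values the pipeline prescribes:

    # Assumes the Dockerfile layout: COPY ./src /src together with
    # ENV PYTHONPATH="${PYTHONPATH}:/", so /src resolves as the package "src".
    from src.wikidataEmbed import WikidataTextifier
    from src.wikidataLangDB import WikidataLang

    textifier = WikidataTextifier(language="en")  # placeholder language
    entity = WikidataLang.get_entity("Q42")  # placeholder QID
    print(textifier.entity_to_text(entity))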
--- .../1_Data_Processing_save_labels_descriptions/Dockerfile | 2 +- docker/2_Data_Processing_save_items_per_lang/Dockerfile | 2 +- docker/4_Run_Retrieval/Dockerfile | 2 +- docker/5_Run_Rerank/Dockerfile | 2 +- docker/6_Push_Huggingface/Dockerfile | 2 +- docker/7_Create_Prototype/Dockerfile | 2 +- src/JinaAI.py | 4 ++-- src/experimental_functions/word_embeding.py | 4 ++-- src/wikidataEmbed.py | 7 ++++--- src/wikidataRetriever.py | 4 ++-- 10 files changed, 16 insertions(+), 15 deletions(-) diff --git a/docker/1_Data_Processing_save_labels_descriptions/Dockerfile b/docker/1_Data_Processing_save_labels_descriptions/Dockerfile index c3f5f2f..cc94072 100644 --- a/docker/1_Data_Processing_save_labels_descriptions/Dockerfile +++ b/docker/1_Data_Processing_save_labels_descriptions/Dockerfile @@ -22,7 +22,7 @@ COPY ./src /src # Set up the volume for the data folder VOLUME [ "/data" ] -ENV PYTHONPATH="${PYTHONPATH}:/src" +ENV PYTHONPATH="${PYTHONPATH}:/" # Run the Python script CMD ["python", "run.py"] \ No newline at end of file diff --git a/docker/2_Data_Processing_save_items_per_lang/Dockerfile b/docker/2_Data_Processing_save_items_per_lang/Dockerfile index fd76838..69a5f2c 100644 --- a/docker/2_Data_Processing_save_items_per_lang/Dockerfile +++ b/docker/2_Data_Processing_save_items_per_lang/Dockerfile @@ -22,7 +22,7 @@ COPY ./src /src # Set up the volume for the data folder VOLUME [ "/data" ] -ENV PYTHONPATH="${PYTHONPATH}:/src" +ENV PYTHONPATH="${PYTHONPATH}:/" # Run the Python script CMD ["python", "run.py"] \ No newline at end of file diff --git a/docker/4_Run_Retrieval/Dockerfile b/docker/4_Run_Retrieval/Dockerfile index c978351..d6c16ca 100644 --- a/docker/4_Run_Retrieval/Dockerfile +++ b/docker/4_Run_Retrieval/Dockerfile @@ -28,7 +28,7 @@ COPY ./API_tokens /API_tokens # Set up the volume for the data folder VOLUME [ "/data" ] -ENV PYTHONPATH="${PYTHONPATH}:/src" +ENV PYTHONPATH="${PYTHONPATH}:/" # Run the Python script CMD ["python", "run.py"] \ No newline at end of file diff --git a/docker/5_Run_Rerank/Dockerfile b/docker/5_Run_Rerank/Dockerfile index fbdf497..fc44395 100644 --- a/docker/5_Run_Rerank/Dockerfile +++ b/docker/5_Run_Rerank/Dockerfile @@ -28,7 +28,7 @@ COPY ./API_tokens /API_tokens # Set up the volume for the data folder VOLUME [ "/data" ] -ENV PYTHONPATH="${PYTHONPATH}:/src" +ENV PYTHONPATH="${PYTHONPATH}:/" # Run the Python script CMD ["python", "run.py"] \ No newline at end of file diff --git a/docker/6_Push_Huggingface/Dockerfile b/docker/6_Push_Huggingface/Dockerfile index 72e86f3..88fa20b 100644 --- a/docker/6_Push_Huggingface/Dockerfile +++ b/docker/6_Push_Huggingface/Dockerfile @@ -23,7 +23,7 @@ COPY ./API_tokens /API_tokens # Set up the volume for the data folder VOLUME [ "/data" ] -ENV PYTHONPATH="${PYTHONPATH}:/src" +ENV PYTHONPATH="${PYTHONPATH}:/" # Run the Python script CMD ["python", "run.py"] \ No newline at end of file diff --git a/docker/7_Create_Prototype/Dockerfile b/docker/7_Create_Prototype/Dockerfile index 33bbd7b..548778a 100644 --- a/docker/7_Create_Prototype/Dockerfile +++ b/docker/7_Create_Prototype/Dockerfile @@ -26,7 +26,7 @@ COPY ./API_tokens /API_tokens # Set up the volume for the data folder VOLUME [ "/data" ] -ENV PYTHONPATH="${PYTHONPATH}:/src" +ENV PYTHONPATH="${PYTHONPATH}:/" # Run the Python script CMD ["python", "run.py"] \ No newline at end of file diff --git a/src/JinaAI.py b/src/JinaAI.py index 2ed5cbe..e278266 100644 --- a/src/JinaAI.py +++ b/src/JinaAI.py @@ -3,11 +3,11 @@ import numpy as np import base64 -import torch # torch no 
long imported +# import torch # torch no long imported from typing import List -from wikidataCache import create_cache_embedding_model +from src.wikidataCache import create_cache_embedding_model class JinaAIEmbedder: diff --git a/src/experimental_functions/word_embeding.py b/src/experimental_functions/word_embeding.py index 3333743..2e9b718 100644 --- a/src/experimental_functions/word_embeding.py +++ b/src/experimental_functions/word_embeding.py @@ -1,6 +1,6 @@ -import torch +# import torch # import torch.nn.functional as F -from JinaAI import JinaAIEmbedder +from src.JinaAI import JinaAIEmbedder model = JinaAIEmbedder() diff --git a/src/wikidataEmbed.py b/src/wikidataEmbed.py index 1f4e46b..72a2836 100644 --- a/src/wikidataEmbed.py +++ b/src/wikidataEmbed.py @@ -1,11 +1,12 @@ -from wikidataItemDB import WikidataItem import requests import time import json -from datetime import date, datetime import re import importlib +from datetime import date, datetime +from src.wikidataItemDB import WikidataItem + class WikidataTextifier: """_summary_ """ @@ -27,7 +28,7 @@ def __init__(self, language='en', langvar_filename=None): # Importing custom functions and variables # from a formating python script in the language_variables folder. self.langvar = importlib.import_module( - f"language_variables.{langvar_filename}" + f"src.language_variables.{langvar_filename}" ) except Exception as e: raise ValueError(f"Language file for '{language}' not found.") diff --git a/src/wikidataRetriever.py b/src/wikidataRetriever.py index 23f2ad1..d35dcaf 100644 --- a/src/wikidataRetriever.py +++ b/src/wikidataRetriever.py @@ -1,6 +1,6 @@ import time import json -from wikidataCache import create_cache_embedding_model +from src.wikidataCache import create_cache_embedding_model class AstraDBConnect: def __init__(self, datastax_token, collection_name, model='jina', batch_size=8, cache_embeddings=None): @@ -20,7 +20,7 @@ def __init__(self, datastax_token, collection_name, model='jina', batch_size=8, from multiprocessing import Queue from transformers import AutoTokenizer - from JinaAI import JinaAIEmbedder, JinaAIAPIEmbedder + from src.JinaAI import JinaAIEmbedder, JinaAIAPIEmbedder ASTRA_DB_APPLICATION_TOKEN = datastax_token['ASTRA_DB_APPLICATION_TOKEN'] ASTRA_DB_API_ENDPOINT = datastax_token["ASTRA_DB_API_ENDPOINT"] From e248427a976559fd484763f2b93a64f4db80e0e2 Mon Sep 17 00:00:00 2001 From: Philippe Saade Date: Fri, 14 Mar 2025 01:43:19 +0100 Subject: [PATCH 33/49] Fix duplicate ids error --- src/wikidataRetriever.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wikidataRetriever.py b/src/wikidataRetriever.py index 083a8f0..ee4e139 100644 --- a/src/wikidataRetriever.py +++ b/src/wikidataRetriever.py @@ -143,7 +143,7 @@ def push_batch(self): self.graph_store.insert_many(docs, vectors=vectors) break except self.InsertManyException as e: - pass + break except Exception as e: print(e) time.sleep(3) From 74921864e1fe91fdba99841b3bae34c5c1be5b3f Mon Sep 17 00:00:00 2001 From: exowanderer Date: Fri, 14 Mar 2025 09:55:05 +0100 Subject: [PATCH 34/49] Changed ENV PYTHONPATH=:/src to ENV PYTHONPATH=:/ in Docker/3*/Dockerfile --- docker/3_Add_Wikidata_to_AstraDB/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/3_Add_Wikidata_to_AstraDB/Dockerfile b/docker/3_Add_Wikidata_to_AstraDB/Dockerfile index 51591da..cf4665e 100644 --- a/docker/3_Add_Wikidata_to_AstraDB/Dockerfile +++ b/docker/3_Add_Wikidata_to_AstraDB/Dockerfile @@ -28,7 +28,7 @@ COPY ./API_tokens /API_tokens # Set up 
the volume for the data folder VOLUME [ "/data" ] -ENV PYTHONPATH="${PYTHONPATH}:/src" +ENV PYTHONPATH="${PYTHONPATH}:/" # Run the Python script CMD ["python", "run.py"] \ No newline at end of file From b22135119d555031e0d4884aa914856ae0fcbe54 Mon Sep 17 00:00:00 2001 From: Philippe Saade Date: Sat, 15 Mar 2025 14:19:39 +0100 Subject: [PATCH 35/49] Fix bulk caching of embeddings --- src/wikidataCache.py | 6 ++++++ src/wikidataRetriever.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/wikidataCache.py b/src/wikidataCache.py index 5fea4cf..bf1d575 100644 --- a/src/wikidataCache.py +++ b/src/wikidataCache.py @@ -84,6 +84,12 @@ def add_bulk_cache(data): - bool: True if the operation was successful, False otherwise. """ worked = False + embeddingtype = EmbeddingType() + for i in range(len(data)): + data[i]['embedding'] = embeddingtype.process_bind_param( + data[i]['embedding'] + ) + with Session() as session: try: session.execute( diff --git a/src/wikidataRetriever.py b/src/wikidataRetriever.py index ee4e139..4de6f09 100644 --- a/src/wikidataRetriever.py +++ b/src/wikidataRetriever.py @@ -150,7 +150,7 @@ def push_batch(self): self.cache_model.add_bulk_cache([{ 'id': docs[i]['_id'], - 'embedding': json.dumps(vectors[i], separators=(',', ':'))} + 'embedding': vectors[i]} for i in range(len(docs))]) return True From 58b0a8805a9be636f2bfe91f9ab986d9965e1861 Mon Sep 17 00:00:00 2001 From: Philippe Saade Date: Sat, 15 Mar 2025 14:22:34 +0100 Subject: [PATCH 36/49] Fix bulk caching of embeddings --- src/wikidataCache.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/wikidataCache.py b/src/wikidataCache.py index bf1d575..e53d138 100644 --- a/src/wikidataCache.py +++ b/src/wikidataCache.py @@ -87,7 +87,8 @@ def add_bulk_cache(data): embeddingtype = EmbeddingType() for i in range(len(data)): data[i]['embedding'] = embeddingtype.process_bind_param( - data[i]['embedding'] + data[i]['embedding'], + None ) with Session() as session: From b6aa2050afb60c91e25a641776005ae63c34534c Mon Sep 17 00:00:00 2001 From: Philippe Saade Date: Mon, 17 Mar 2025 11:15:43 +0100 Subject: [PATCH 37/49] Include merge cache script --- src/merge_cache.py | 91 ++++++++++++++++++++++++++++++++++++++++++++ src/migrate_cache.py | 2 +- 2 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 src/merge_cache.py diff --git a/src/merge_cache.py b/src/merge_cache.py new file mode 100644 index 0000000..d26c85b --- /dev/null +++ b/src/merge_cache.py @@ -0,0 +1,91 @@ +import sqlite3 +import glob +import os +import base64 +from tqdm import tqdm + +# Define database file pattern +db_files = glob.glob("../data/Wikidata/sqlite_cacheembeddings_*.db") + +# Define the target merged database +merged_db = "../data/Wikidata/sqlite_cacheembeddings_merged.db" +TABLE_NAME = "wikidata_prototype" + +# Batch size for processing +BATCH_SIZE = 1000 # Adjust based on performance needs + +# Create the merged database connection +conn_merged = sqlite3.connect(merged_db) +cursor_merged = conn_merged.cursor() + +# Create table in the merged database if it doesn't exist +cursor_merged.execute(f""" +CREATE TABLE IF NOT EXISTS {TABLE_NAME} ( + id TEXT PRIMARY KEY, + embedding TEXT +); +""") +conn_merged.commit() + +# Helper function to check if a string is a valid Base64 encoding +def is_valid_base64(s): + try: + if not s or not isinstance(s, str): + return False + base64.b64decode(s, validate=True) + return True + except Exception: + return False + +# Loop through all source databases +for db_file in 
db_files: + print(f"Processing {db_file}...") + + # Connect to the current database + conn_src = sqlite3.connect(db_file) + cursor_src = conn_src.cursor() + + # Get total record count for progress tracking + cursor_src.execute(f"SELECT COUNT(*) FROM {TABLE_NAME}") + total_records = cursor_src.fetchone()[0] + + # Fetch records in batches + offset = 0 + with tqdm(total=total_records, + desc=f"Merging {db_file}", unit="records") as pbar: + while True: + cursor_src.execute(f"SELECT id, embedding FROM {TABLE_NAME} LIMIT {BATCH_SIZE} OFFSET {offset}") + records = cursor_src.fetchall() + if not records: + break # No more records to process + + # Prepare batch for insertion + batch_data = [] + for id_, embedding in records: + if embedding and embedding.strip() and is_valid_base64(embedding): + cursor_merged.execute(f"SELECT embedding FROM {TABLE_NAME} WHERE id = ?", (id_,)) + existing = cursor_merged.fetchone() + + if existing is None or not is_valid_base64(existing[0]): + batch_data.append((id_, embedding, embedding)) # Prepare for bulk insert + + # Perform batch insert/update + if batch_data: + cursor_merged.executemany( + f""" + INSERT INTO {TABLE_NAME} (id, embedding) + VALUES (?, ?) + ON CONFLICT(id) DO UPDATE SET embedding = ? + """, + batch_data + ) + conn_merged.commit() + + offset += BATCH_SIZE # Move to the next batch + pbar.update(len(records)) # Update tqdm progress bar + + conn_src.close() + +# Close merged database connection +conn_merged.close() +print(f"Merge completed! Combined database saved as {merged_db}") diff --git a/src/migrate_cache.py b/src/migrate_cache.py index ca2429c..765859b 100644 --- a/src/migrate_cache.py +++ b/src/migrate_cache.py @@ -51,7 +51,7 @@ def convert_embeddings(): updated_records.append((base64_embedding, id)) except Exception as e: - print(f"\nSkipping ID {id} due to error: {e}") + pass pbar.update(1) # Update progress bar for each record processed From 30c3e3b26064b43f27b3b22add367e9125f0425f Mon Sep 17 00:00:00 2001 From: Philippe Saade Date: Mon, 17 Mar 2025 11:19:37 +0100 Subject: [PATCH 38/49] Vacuuming the database when migrating --- src/migrate_cache.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/migrate_cache.py b/src/migrate_cache.py index 765859b..e722b78 100644 --- a/src/migrate_cache.py +++ b/src/migrate_cache.py @@ -65,7 +65,12 @@ def convert_embeddings(): offset += BATCH_SIZE # Move to next batch + print("Optimizing database with VACUUM...") + cursor.execute("VACUUM;") + conn.commit() + print("Migration completed successfully.") + conn.close() if __name__ == "__main__": From d6ed4aa94bd0614cc88139f706bd10cd4e9ac186 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Mon, 17 Mar 2025 12:54:49 +0100 Subject: [PATCH 39/49] added more todos for docker 7 run.py --- docker/7_Create_Prototype/run.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/7_Create_Prototype/run.py b/docker/7_Create_Prototype/run.py index 57de162..2438508 100644 --- a/docker/7_Create_Prototype/run.py +++ b/docker/7_Create_Prototype/run.py @@ -45,6 +45,7 @@ # TODO: Add location as env var # TODO: Sync data format from DATADUMP to chunk_sizes.json # TODO: Retrieve info from Hugging Face instead of storing it +# TODO: Set this up as a separate script and run after __name__ wikidata_chunksizes_path = os.path.join("wikidata_chunk_sizes_2024-09-18.json") with open(wikidata_chunksizes_path) as json_in: From 3d1c9ae7ac27ee64895908fb7fbe08d16b070f6e Mon Sep 17 00:00:00 2001 From: exowanderer Date: Mon, 17 Mar 2025 17:16:33 +0100 Subject: [PATCH 40/49] added 
blank line at bottom of all Dockerfile files --- .../1_Data_Processing_save_labels_descriptions/Dockerfile | 2 +- docker/2_Data_Processing_save_items_per_lang/Dockerfile | 2 +- docker/3_Add_Wikidata_to_AstraDB/Dockerfile | 2 +- docker/4_Run_Retrieval/Dockerfile | 2 +- docker/5_Run_Rerank/Dockerfile | 2 +- docker/6_Push_Huggingface/Dockerfile | 2 +- docker/7_Create_Prototype/Dockerfile | 2 +- src/migrate_cache.py | 6 +++++- 8 files changed, 12 insertions(+), 8 deletions(-) diff --git a/docker/1_Data_Processing_save_labels_descriptions/Dockerfile b/docker/1_Data_Processing_save_labels_descriptions/Dockerfile index cc94072..9057498 100644 --- a/docker/1_Data_Processing_save_labels_descriptions/Dockerfile +++ b/docker/1_Data_Processing_save_labels_descriptions/Dockerfile @@ -25,4 +25,4 @@ VOLUME [ "/data" ] ENV PYTHONPATH="${PYTHONPATH}:/" # Run the Python script -CMD ["python", "run.py"] \ No newline at end of file +CMD ["python", "run.py"] diff --git a/docker/2_Data_Processing_save_items_per_lang/Dockerfile b/docker/2_Data_Processing_save_items_per_lang/Dockerfile index 69a5f2c..2c12567 100644 --- a/docker/2_Data_Processing_save_items_per_lang/Dockerfile +++ b/docker/2_Data_Processing_save_items_per_lang/Dockerfile @@ -25,4 +25,4 @@ VOLUME [ "/data" ] ENV PYTHONPATH="${PYTHONPATH}:/" # Run the Python script -CMD ["python", "run.py"] \ No newline at end of file +CMD ["python", "run.py"] diff --git a/docker/3_Add_Wikidata_to_AstraDB/Dockerfile b/docker/3_Add_Wikidata_to_AstraDB/Dockerfile index cf4665e..eb6f3ca 100644 --- a/docker/3_Add_Wikidata_to_AstraDB/Dockerfile +++ b/docker/3_Add_Wikidata_to_AstraDB/Dockerfile @@ -31,4 +31,4 @@ VOLUME [ "/data" ] ENV PYTHONPATH="${PYTHONPATH}:/" # Run the Python script -CMD ["python", "run.py"] \ No newline at end of file +CMD ["python", "run.py"] diff --git a/docker/4_Run_Retrieval/Dockerfile b/docker/4_Run_Retrieval/Dockerfile index d6c16ca..6936f7f 100644 --- a/docker/4_Run_Retrieval/Dockerfile +++ b/docker/4_Run_Retrieval/Dockerfile @@ -31,4 +31,4 @@ VOLUME [ "/data" ] ENV PYTHONPATH="${PYTHONPATH}:/" # Run the Python script -CMD ["python", "run.py"] \ No newline at end of file +CMD ["python", "run.py"] diff --git a/docker/5_Run_Rerank/Dockerfile b/docker/5_Run_Rerank/Dockerfile index fc44395..a624029 100644 --- a/docker/5_Run_Rerank/Dockerfile +++ b/docker/5_Run_Rerank/Dockerfile @@ -31,4 +31,4 @@ VOLUME [ "/data" ] ENV PYTHONPATH="${PYTHONPATH}:/" # Run the Python script -CMD ["python", "run.py"] \ No newline at end of file +CMD ["python", "run.py"] diff --git a/docker/6_Push_Huggingface/Dockerfile b/docker/6_Push_Huggingface/Dockerfile index 88fa20b..c835d98 100644 --- a/docker/6_Push_Huggingface/Dockerfile +++ b/docker/6_Push_Huggingface/Dockerfile @@ -26,4 +26,4 @@ VOLUME [ "/data" ] ENV PYTHONPATH="${PYTHONPATH}:/" # Run the Python script -CMD ["python", "run.py"] \ No newline at end of file +CMD ["python", "run.py"] diff --git a/docker/7_Create_Prototype/Dockerfile b/docker/7_Create_Prototype/Dockerfile index 548778a..3ebbb00 100644 --- a/docker/7_Create_Prototype/Dockerfile +++ b/docker/7_Create_Prototype/Dockerfile @@ -29,4 +29,4 @@ VOLUME [ "/data" ] ENV PYTHONPATH="${PYTHONPATH}:/" # Run the Python script -CMD ["python", "run.py"] \ No newline at end of file +CMD ["python", "run.py"] diff --git a/src/migrate_cache.py b/src/migrate_cache.py index e722b78..5096bd1 100644 --- a/src/migrate_cache.py +++ b/src/migrate_cache.py @@ -4,7 +4,9 @@ import numpy as np from tqdm import tqdm -DB_PATH = "../data/Wikidata/sqlite_cacheembeddings.db" +# 
Change this to match your actual database path +DB_PATH = "data/Wikidata/wikidata_cache.db" + TABLE_NAME = "wikidata_prototype" # Change this to match your actual table name BATCH_SIZE = 5000 # Process in smaller batches to avoid memory overload @@ -13,6 +15,8 @@ def convert_embeddings(): """ Convert JSON-stored embeddings into Base64-encoded binary format in batches. Uses `fetchmany(BATCH_SIZE)` to process records iteratively. """ + # TODO: Migrate away from global variables + print(f"Converting embeddings in {DB_PATH}...") conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() From 4d0e55a75906a32b23bf6951c21f68ae1da4a4cc Mon Sep 17 00:00:00 2001 From: exowanderer Date: Mon, 17 Mar 2025 17:19:07 +0100 Subject: [PATCH 41/49] added default setting for lang_in_wp to avoid collision --- docker/2_Data_Processing_save_items_per_lang/run.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/2_Data_Processing_save_items_per_lang/run.py b/docker/2_Data_Processing_save_items_per_lang/run.py index aec08b6..974f949 100644 --- a/docker/2_Data_Processing_save_items_per_lang/run.py +++ b/docker/2_Data_Processing_save_items_per_lang/run.py @@ -14,6 +14,7 @@ def save_entities_to_sqlite(item, data_batch, sqlitDBlock): + lang_in_wp = False # Default setting is_not_none = item is not None if is_not_none: lang_in_wp = WikidataLang.is_in_wikipedia(item, language=LANGUAGE) From 2eea23801eceab36e65054fd0a02a60362d136e0 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Mon, 17 Mar 2025 17:23:35 +0100 Subject: [PATCH 42/49] improved embedded conditional statements in 2_docker run.py --- .../run.py | 43 ++++++++++++------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/docker/2_Data_Processing_save_items_per_lang/run.py b/docker/2_Data_Processing_save_items_per_lang/run.py index 974f949..19af60d 100644 --- a/docker/2_Data_Processing_save_items_per_lang/run.py +++ b/docker/2_Data_Processing_save_items_per_lang/run.py @@ -14,22 +14,33 @@ def save_entities_to_sqlite(item, data_batch, sqlitDBlock): - lang_in_wp = False # Default setting - is_not_none = item is not None - if is_not_none: - lang_in_wp = WikidataLang.is_in_wikipedia(item, language=LANGUAGE) - - if is_not_none and lang_in_wp: - item = WikidataLang.normalise_item(item, language=LANGUAGE) - data_batch.append(item) - - with sqlitDBlock: - if len(data_batch) > PUSH_SIZE: - worked = WikidataLang.add_bulk_entities(list( - data_batch[:PUSH_SIZE] - )) - if worked: - del data_batch[:PUSH_SIZE] + """_summary_ + # TODO Add a docstring + + Args: + item (_type_): _description_ + data_batch (_type_): _description_ + sqlitDBlock (_type_): _description_ + """ + if item is None: + # Check if the item is a valid entity + return + + lang_in_wp = WikidataLang.is_in_wikipedia(item, language=LANGUAGE) + if not lang_in_wp: + # If the entity is not in the specified language Wikipedia, skip + return + + item = WikidataLang.normalise_item(item, language=LANGUAGE) + data_batch.append(item) + + with sqlitDBlock: + if len(data_batch) > PUSH_SIZE: + worked = WikidataLang.add_bulk_entities(list( + data_batch[:PUSH_SIZE] + )) + if worked: + del data_batch[:PUSH_SIZE] From 3828ba36cb585102d4abcae6f05a8dfedadd8bc1 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Mon, 17 Mar 2025 17:26:59 +0100 Subject: [PATCH 43/49] refactor open to with open in docker3 run.py --- docker/2_Data_Processing_save_items_per_lang/run.py | 10 +++++----- docker/3_Add_Wikidata_to_AstraDB/run.py | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git 
a/docker/2_Data_Processing_save_items_per_lang/run.py b/docker/2_Data_Processing_save_items_per_lang/run.py index 19af60d..a3f355d 100644 --- a/docker/2_Data_Processing_save_items_per_lang/run.py +++ b/docker/2_Data_Processing_save_items_per_lang/run.py @@ -13,14 +13,14 @@ LANGUAGE = os.getenv("LANGUAGE", 'en') -def save_entities_to_sqlite(item, data_batch, sqlitDBlock): +def save_entities_to_sqlite(item, data_batch, sqliteDBlock): """_summary_ # TODO Add a docstring Args: item (_type_): _description_ data_batch (_type_): _description_ - sqlitDBlock (_type_): _description_ + sqliteDBlock (_type_): _description_ """ if item is None: # Check if the item is a valid entity return @@ -34,7 +34,7 @@ def save_entities_to_sqlite(item, data_batch, sqliteDBlock): item = WikidataLang.normalise_item(item, language=LANGUAGE) data_batch.append(item) - with sqlitDBlock: + with sqliteDBlock: if len(data_batch) > PUSH_SIZE: worked = WikidataLang.add_bulk_entities(list( data_batch[:PUSH_SIZE] @@ -45,7 +45,7 @@ def save_entities_to_sqlite(item, data_batch, sqliteDBlock): if __name__ == "__main__": multiprocess_manager = Manager() - sqlitDBlock = multiprocess_manager.Lock() + sqliteDBlock = multiprocess_manager.Lock() data_batch = multiprocess_manager.list() wikidata = WikidataDumpReader( @@ -59,7 +59,7 @@ def save_entities_to_sqlite(item, data_batch, sqliteDBlock): lambda item: save_entities_to_sqlite( item, data_batch, - sqlitDBlock + sqliteDBlock ), max_iterations=None, verbose=True diff --git a/docker/3_Add_Wikidata_to_AstraDB/run.py b/docker/3_Add_Wikidata_to_AstraDB/run.py index 9eab967..4f7c062 100644 --- a/docker/3_Add_Wikidata_to_AstraDB/run.py +++ b/docker/3_Add_Wikidata_to_AstraDB/run.py @@ -38,7 +38,8 @@ if not API_KEY_FILENAME: API_KEY_FILENAME = os.listdir("../API_tokens")[0] -datastax_token = json.load(open(f"../API_tokens/{API_KEY_FILENAME}")) +with open(f"../API_tokens/{API_KEY_FILENAME}") as json_in: + datastax_token = json.load(json_in) textifier = WikidataTextifier( language=LANGUAGE, From 1f7d4ac972e24c5d410fd5b8cec3217eff7fb0a5 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Mon, 17 Mar 2025 17:31:26 +0100 Subject: [PATCH 44/49] refactor open to with open in docker4 run.py --- docker/4_Run_Retrieval/run.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docker/4_Run_Retrieval/run.py b/docker/4_Run_Retrieval/run.py index 29f825b..d3a8b08 100644 --- a/docker/4_Run_Retrieval/run.py +++ b/docker/4_Run_Retrieval/run.py @@ -48,7 +48,9 @@ API_KEY_FILENAME = os.listdir("../API_tokens")[0] print(f"API_KEY_FILENAME not provided. 
Using {API_KEY_FILENAME}") -datastax_token = json.load(open(f"../API_tokens/{API_KEY_FILENAME}")) + +with open(f"../API_tokens/{API_KEY_FILENAME}") as json_in: + datastax_token = json.load(json_in) if ELASTICSEARCH: graph_store = KeywordSearchConnect( From b560f9f3005079acaf1b107d1073a6eb9839eb6f Mon Sep 17 00:00:00 2001 From: exowanderer Date: Mon, 17 Mar 2025 17:40:30 +0100 Subject: [PATCH 45/49] added space at bottom of json file in docker7 --- docker/7_Create_Prototype/wikidata_chunk_sizes_2024-09-18.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/7_Create_Prototype/wikidata_chunk_sizes_2024-09-18.json b/docker/7_Create_Prototype/wikidata_chunk_sizes_2024-09-18.json index e474869..1b8c6f8 100644 --- a/docker/7_Create_Prototype/wikidata_chunk_sizes_2024-09-18.json +++ b/docker/7_Create_Prototype/wikidata_chunk_sizes_2024-09-18.json @@ -112,4 +112,4 @@ "chunk_110":190337, "chunk_111":166210, "chunk_112":26375 -} \ No newline at end of file +} From 72e74a75d5a2c78f82f33de301f013ac31752ebd Mon Sep 17 00:00:00 2001 From: exowanderer Date: Mon, 17 Mar 2025 17:41:02 +0100 Subject: [PATCH 46/49] added space at bottom of json file in run_exp.sh --- run_experiments.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_experiments.sh b/run_experiments.sh index 321513a..3d5e058 100644 --- a/run_experiments.sh +++ b/run_experiments.sh @@ -58,4 +58,4 @@ docker compose run --build -e CHUNK_NUM=112 create_prototype docker compose run --build -e CHUNK_NUM=111 create_prototype docker compose run --build -e CHUNK_NUM=110 create_prototype -docker compose run --build -e CHUNK_NUM=109 create_prototype \ No newline at end of file +docker compose run --build -e CHUNK_NUM=109 create_prototype From 2d5c4506708bfcbaa4b6661232406095a59c9dc5 Mon Sep 17 00:00:00 2001 From: exowanderer Date: Mon, 17 Mar 2025 17:43:29 +0100 Subject: [PATCH 47/49] added bash forloop example to run_exp.sh --- run_experiments.sh | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/run_experiments.sh b/run_experiments.sh index 3d5e058..51e0655 100644 --- a/run_experiments.sh +++ b/run_experiments.sh @@ -55,7 +55,11 @@ # -e DB_LANGUAGE="en" \ # -e API_KEY="datastax_wikidata2.json" run_retrieval -docker compose run --build -e CHUNK_NUM=112 create_prototype -docker compose run --build -e CHUNK_NUM=111 create_prototype -docker compose run --build -e CHUNK_NUM=110 create_prototype -docker compose run --build -e CHUNK_NUM=109 create_prototype +# docker compose run --build -e CHUNK_NUM=112 create_prototype +# docker compose run --build -e CHUNK_NUM=111 create_prototype +# docker compose run --build -e CHUNK_NUM=110 create_prototype +# docker compose run --build -e CHUNK_NUM=109 create_prototype + +for chunk_num in `seq 101 -1 100`; + do docker compose run --build -e CHUNK_NUM=$chunk_num create_prototype; +done \ No newline at end of file From b8376f1f705b4221ff4ee2cbea88dafd4b70df8e Mon Sep 17 00:00:00 2001 From: exowanderer Date: Mon, 17 Mar 2025 17:43:43 +0100 Subject: [PATCH 48/49] added space at bottom of json file in run_exp.sh --- run_experiments.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_experiments.sh b/run_experiments.sh index 51e0655..a10e6ec 100644 --- a/run_experiments.sh +++ b/run_experiments.sh @@ -62,4 +62,4 @@ for chunk_num in `seq 101 -1 100`; do docker compose run --build -e CHUNK_NUM=$chunk_num create_prototype; -done \ No newline at end of file +done From d101f074c555a66466fa321029b8ba59c8eb9c5a Mon Sep 17 00:00:00 2001 
From: exowanderer Date: Mon, 17 Mar 2025 17:54:07 +0100 Subject: [PATCH 49/49] fixed too long line in wikidataRetriever --- src/wikidataRetriever.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/wikidataRetriever.py b/src/wikidataRetriever.py index 7ae35a5..752368f 100644 --- a/src/wikidataRetriever.py +++ b/src/wikidataRetriever.py @@ -3,7 +3,9 @@ from src.wikidataCache import create_cache_embedding_model class AstraDBConnect: - def __init__(self, datastax_token, collection_name, model='jina', batch_size=8, cache_embeddings=None): + def __init__( + self, datastax_token, collection_name, model='jina', + batch_size=8, cache_embeddings=None): """ Initialize the AstraDBConnect object with the corresponding embedding model.