diff --git a/pythonvectordbceph.py b/pythonvectordbceph.py
index 0422544..e6317e4 100644
--- a/pythonvectordbceph.py
+++ b/pythonvectordbceph.py
@@ -15,6 +15,7 @@
 from timm.data import resolve_data_config
 from timm.data.transforms_factory import create_transform
 import torch
+from langchain_text_splitters import CharacterTextSplitter
 # this is need for only when second image embedding function is used
 # from transformers import AutoFeatureExtractor, AutoModelForImageClassification
 
@@ -71,6 +72,10 @@ def __call__(self, imagepath):
 
 object_type = os.getenv("OBJECT_TYPE")
 
+# CHUNK_SIZE is optional; int(os.getenv("CHUNK_SIZE")) would raise TypeError
+# when the variable is unset, so default to "1" (no chunking) before converting.
+chunk_size = int(os.getenv("CHUNK_SIZE", "1"))
+
 app = Flask(__name__)
 
 
@@ -87,16 +92,22 @@ def pythonvectordbappceph():
     event_type = event_data['Records'][0]['eventName']
     app.logger.debug(object_key)
     tags = event_data['Records'][0]['s3']['object']['tags']
-    app.logger.debug("tags : " + str(tags))
+    # tags may arrive as an empty list; normalize to a dict for the JSON field
+    if len(tags) == 0:
+        tags = {}
     # Create collection which includes the id, object url, and embedded vector
     if not client.has_collection(collection_name=collection_name):
         fields = [
-            FieldSchema(name='url', dtype=DataType.VARCHAR, max_length=2048, is_primary=True), # VARCHARS need a maximum length, so for this example they are set to 200 characters
+            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
+            FieldSchema(name='url', dtype=DataType.VARCHAR, max_length=2048), # VARCHARs need a maximum length; set to 2048 characters here
             FieldSchema(name='embedded_vector', dtype=DataType.FLOAT_VECTOR, dim=int(os.getenv("VECTOR_DIMENSION"))),
+            FieldSchema(name='start_offset', dtype=DataType.INT64, nullable=True),
+            FieldSchema(name='end_offset', dtype=DataType.INT64, nullable=True),
             FieldSchema(name='tags', dtype=DataType.JSON, nullable=True)
         ]
         schema = CollectionSchema(fields=fields, enable_dynamic_field=True)
         client.create_collection(collection_name=collection_name, schema=schema)
+
         index_params = client.prepare_index_params()
         index_params.add_index(field_name="embedded_vector", metric_type="L2", index_type="IVF_FLAT", params={"nlist": 16384})
         client.create_index(collection_name=collection_name, index_params=index_params)
@@ -116,11 +127,30 @@ def pythonvectordbappceph():
         case "TEXT":
             object_content = object_data["Body"].read().decode("utf-8")
             objectlist = []
-            objectlist.append(object_content)
-            # default embedding function provided by milvus, it has some size limtation for the object
-            # embedding_fn = milvus_model.DefaultEmbeddingFunction() #dimension 768
+            if chunk_size < 1:
+                app.logger.error("chunk size cannot be less than one")
+                return ''
+            if chunk_size > 1:
+                object_size = object_data["ContentLength"]
+                if object_size == 0:
+                    app.logger.debug("object of size zero cannot be chunked")
+                    return ''
+                text_splitter = CharacterTextSplitter(
+                    separator=".",
+                    chunk_size=chunk_size,
+                    chunk_overlap=0,
+                    length_function=len,
+                    is_separator_regex=False,
+                )
+                objectlist = text_splitter.split_text(object_content)
+                app.logger.debug("chunk size " + str(chunk_size) + " no of chunks " + str(len(objectlist)))
+            else:
+                objectlist.append(object_content)
+                # default embedding function provided by milvus, it has some size limitation for the object
+                # embedding_fn = milvus_model.DefaultEmbeddingFunction() #dimension 768
             embedding_fn = milvus_model.dense.SentenceTransformerEmbeddingFunction(model_name='all-MiniLM-L6-v2',device='cpu') # dimension 384
             vectors = embedding_fn.encode_documents(objectlist)
+            app.logger.debug("vector length " + str(len(vectors)))
             vector = vectors[0]
 
         case "IMAGE":
@@ -144,14 +174,30 @@ def pythonvectordbappceph():
         case _:
             app.logger.error("Unknown object format")
 
-    app.logger.debug(vector)
-
-    if len(tags) > 0:
-        data = [ {"embedded_vector": vector, "url": object_url, "tags": tags} ]
+    # delete any existing entries for this object first, otherwise re-inserting
+    # the same object would create duplicate entries
+    res = client.delete(collection_name=collection_name,
+                        filter='url == "' + object_url + '"')
+
+    data = []
+    # NOTE(review): nullable fields did not accept missing values as expected,
+    # so the offset attributes are only included when chunking is enabled
+    if chunk_size > 1:
+        start_offset = 0
+        for i in range(len(objectlist)):
+            end_offset = start_offset + len(objectlist[i])
+            if len(tags) > 0:
+                data.append({"embedded_vector": vectors[i], "url": object_url, "start_offset": start_offset, "end_offset": end_offset, "tags": tags})
+            else:
+                data.append({"embedded_vector": vectors[i], "url": object_url, "start_offset": start_offset, "end_offset": end_offset})
+            start_offset = end_offset + 1
     else:
-        data = [ {"embedded_vector": vector, "url": object_url} ]
+        if len(tags) > 0:
+            data.append({"embedded_vector": vector, "url": object_url, "tags": tags})
+        else:
+            data.append({"embedded_vector": vector, "url": object_url})
 
-    res = client.upsert(collection_name=collection_name, data=data)
+    res = client.insert(collection_name=collection_name, data=data)
     app.logger.debug(res)
 
     return ''
diff --git a/requirements.txt b/requirements.txt
index 9a48cbd..bb45e0b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ torch
 timm
 scikit-learn
 sentence-transformers
+langchain-text-splitters
diff --git a/sample-deployment-text.yaml b/sample-deployment-text.yaml
index cddaa1a..eb3844c 100644
--- a/sample-deployment-text.yaml
+++ b/sample-deployment-text.yaml
@@ -47,3 +47,4 @@ data:
   MILVUS_ENDPOINT : "http://my-release-milvus.default.svc:19530"
   OBJECT_TYPE : "TEXT"
   VECTOR_DIMENSION: "384"
+# CHUNK_SIZE : "500"