diff --git a/.github/workflows/validation.yml b/.github/workflows/validation.yml new file mode 100644 index 0000000..28cf936 --- /dev/null +++ b/.github/workflows/validation.yml @@ -0,0 +1,21 @@ +name: Question Validation Sync Check + +on: + pull_request: + branches: [ main ] + +jobs: + validate-prompt: + runs-on: ubuntu-latest + permissions: + contents: read + defaults: + run: + working-directory: ./scripts/python-scripts + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Setup Environment + run: pip3 install -r requirements.txt + - name: Validate prompt + run: python3 sync.py -t validate \ No newline at end of file diff --git a/.gitignore b/.gitignore index 0c06f5e..2f71729 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,6 @@ embeddings_model vector_db -values.env \ No newline at end of file +values.env + +.venv \ No newline at end of file diff --git a/Makefile b/Makefile index 389cf2f..72d704c 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,7 @@ RAG_CONTENT_IMAGE ?= quay.io/redhat-ai-dev/rag-content:release-1.7-lcs +VENV := $(CURDIR)/scripts/python-scripts/.venv +PYTHON := $(VENV)/bin/python3 +PIP := $(VENV)/bin/pip3 default: help @@ -20,4 +23,21 @@ help: ## Show this help screen # TODO (Jdubrick): Replace reference to lightspeed-core/lightspeed-providers once bug is addressed. 
update-question-validation: - curl -o ./config/providers.d/inline/safety/lightspeed_question_validity.yaml https://raw.githubusercontent.com/Jdubrick/lightspeed-providers/refs/heads/devai/resources/external_providers/inline/safety/lightspeed_question_validity.yaml \ No newline at end of file + curl -o ./config/providers.d/inline/safety/lightspeed_question_validity.yaml https://raw.githubusercontent.com/Jdubrick/lightspeed-providers/refs/heads/devai/resources/external_providers/inline/safety/lightspeed_question_validity.yaml + +$(VENV)/bin/activate: ./scripts/python-scripts/requirements.txt + python3 -m venv $(VENV) + $(PIP) install -r scripts/python-scripts/requirements.txt + touch $(VENV)/bin/activate + +define run_sync + cd ./scripts/python-scripts && \ + $(PYTHON) sync.py -t $(1) +endef + +.PHONY: validate-prompt-templates update-prompt-templates +validate-prompt-templates: $(VENV)/bin/activate + $(call run_sync,validate) + +update-prompt-templates: $(VENV)/bin/activate + $(call run_sync,update) diff --git a/README.md b/README.md index e98b3d7..700615d 100644 --- a/README.md +++ b/README.md @@ -141,6 +141,8 @@ To deploy on a cluster see [DEPLOYMENT.md](./docs/DEPLOYMENT.md). | ---- | ----| | **get-rag** | Gets the RAG data and the embeddings model from the rag-content image registry to your local project directory | | **update-question-validation** | Updates the question validation content in `providers.d` | +| **validate-prompt-templates** | Validates prompt values in run.yaml. **Requires Python >= 3.11** | +| **update-prompt-templates** | Updates the prompt values in run.yaml. 
**Requires Python >= 3.11** | ## Contributing diff --git a/run.yaml b/run.yaml index 33a786f..6e586cf 100644 --- a/run.yaml +++ b/run.yaml @@ -2,32 +2,32 @@ version: '2' image_name: minimal-viable-llama-stack-configuration apis: - - agents - - datasetio - - eval - - inference - - post_training - - safety - - scoring - - telemetry - - tool_runtime - - vector_io +- agents +- datasetio +- eval +- inference +- post_training +- safety +- scoring +- telemetry +- tool_runtime +- vector_io benchmarks: [] -container_image: null +container_image: datasets: [] external_providers_dir: "/app-root/config/providers.d" inference_store: db_path: .llama/distributions/ollama/inference_store.db type: sqlite -logging: null +logging: metadata_store: db_path: .llama/distributions/ollama/registry.db - namespace: null + namespace: type: sqlite models: - model_id: sentence-transformers/all-mpnet-base-v2 metadata: - embedding_dimension: 768 + embedding_dimension: 768 model_type: embedding provider_id: sentence-transformers provider_model_id: "/app-root/embeddings_model" @@ -36,7 +36,7 @@ providers: - config: persistence_store: db_path: .llama/distributions/ollama/agents_store.db - namespace: null + namespace: type: sqlite responses_store: db_path: .llama/distributions/ollama/responses_store.db @@ -47,14 +47,14 @@ providers: - config: kvstore: db_path: .llama/distributions/ollama/huggingface_datasetio.db - namespace: null + namespace: type: sqlite provider_id: huggingface provider_type: remote::huggingface - config: kvstore: db_path: .llama/distributions/ollama/localfs_datasetio.db - namespace: null + namespace: type: sqlite provider_id: localfs provider_type: inline::localfs @@ -62,7 +62,7 @@ providers: - config: kvstore: db_path: .llama/distributions/ollama/meta_reference_eval.db - namespace: null + namespace: type: sqlite provider_id: meta-reference provider_type: inline::meta-reference @@ -94,80 +94,80 @@ providers: - config: checkpoint_format: huggingface device: cpu - 
distributed_backend: null + distributed_backend: dpo_output_dir: "." provider_id: huggingface provider_type: inline::huggingface safety: - - config: - excluded_categories: [] - provider_id: llama-guard - provider_type: inline::llama-guard - - provider_id: lightspeed_question_validity - provider_type: inline::lightspeed_question_validity - config: - model_id: ${env.VALIDATION_PROVIDER:=vllm}/${env.VALIDATION_MODEL_NAME} - model_prompt: |- - Instructions: + - config: + excluded_categories: [] + provider_id: llama-guard + provider_type: inline::llama-guard + - provider_id: lightspeed_question_validity + provider_type: inline::lightspeed_question_validity + config: + model_id: ${env.VALIDATION_PROVIDER:=vllm}/${env.VALIDATION_MODEL_NAME} + model_prompt: |- + Instructions: + + You area question classification tool. You are an expert in the following categories: + - Backstage + - Red Hat Developer Hub (RHDH) + - Kubernetes + - Openshift + - CI/CD + - GitOps + - Pipelines + - Developer Portals + - Deployments + - Software Catalogs + - Software Templates + - Tech Docs - You area question classification tool. You are an expert in the following categories: - - Backstage - - Red Hat Developer Hub (RHDH) - - Kubernetes - - Openshift - - CI/CD - - GitOps - - Pipelines - - Developer Portals - - Deployments - - Software Catalogs - - Software Templates - - Tech Docs - - Your job is to determine if a user's question is related to the categories you are an expert in. If the question is related to those categories, \ - or any features that may be related to those categories, you will answer with ${allowed}. + Your job is to determine if a user's question is related to the categories you are an expert in. If the question is related to those categories, \ + or any features that may be related to those categories, you will answer with ${allowed}. - If a question is not related to your expert categories, answer with ${rejected}. 
+ If a question is not related to your expert categories, answer with ${rejected}. - You do not need to explain your answer. + You do not need to explain your answer. - Below are some example questions: - Example Question: - Why is the sky blue? - Example Response: - ${rejected} + Below are some example questions: + Example Question: + Why is the sky blue? + Example Response: + ${rejected} - Example Question: - Can you help configure my cluster to automatically scale? - Example Response: - ${allowed} + Example Question: + Can you help configure my cluster to automatically scale? + Example Response: + ${allowed} - Example Question: - How do I create import an existing software template in Backstage? - Example Response: - ${allowed} + Example Question: + How do I create import an existing software template in Backstage? + Example Response: + ${allowed} - Example Question: - How do I accomplish a task in RHDH? - Example Response: - ${allowed} + Example Question: + How do I accomplish a task in RHDH? + Example Response: + ${allowed} - Example Question: - How do I explore a component in RHDH catalog? - Example Response: - ${allowed} + Example Question: + How do I explore a component in RHDH catalog? + Example Response: + ${allowed} - Example Question: - How can I integrate GitOps into my pipeline? - Example Response: - ${allowed} + Example Question: + How can I integrate GitOps into my pipeline? + Example Response: + ${allowed} - Question: - ${message} - Response: - invalid_question_response: |- - Hi, I'm the Red Hat Developer Hub Lightspeed assistant, I can help you with questions about Red Hat Developer Hub or Backstage. - Please ensure your question is about these topics, and feel free to ask again! + Question: + ${message} + Response: + invalid_question_response: |- + Hi, I'm the Red Hat Developer Hub Lightspeed assistant, I can help you with questions about Red Hat Developer Hub or Backstage. 
"""Validate or update the question-validation prompts stored in run.yaml.

The upstream prompt templates are fetched from ``URL``. With ``-t validate``
they are compared against the ``lightspeed_question_validity`` safety provider
in the local ``run.yaml`` (exit code 1 plus a diff when they differ). With
``-t update`` the local ``run.yaml`` is rewritten with the upstream values.
"""
import argparse
import difflib
import re
import sys
from pathlib import Path

URL = "https://raw.githubusercontent.com/redhat-developer/rhdh-plugins/refs/heads/main/workspaces/lightspeed/plugins/lightspeed-backend/src/prompts/rhdh-profile.py"
# Anchored to this file's location (scripts/python-scripts/) instead of the
# current working directory, so the script can be launched from anywhere.
RUN_PATH = str(Path(__file__).resolve().parent / ".." / ".." / "run.yaml")
PROVIDER_ID = "lightspeed_question_validity"
VALID_TYPES = ("validate", "update")
REQUEST_TIMEOUT = 30  # seconds; requests never times out unless told to


def _normalise(text) -> str:
    """Return *text* without leading/trailing newlines; ``None`` becomes ``""``."""
    return (text or "").strip("\n")


def extract_prompts(content: str) -> dict[str, str]:
    """
    Pulls the two prompt templates out of the upstream Python source text.

    Returns a dictionary with the keys ``validator_prompt`` and
    ``invalid_resp``. Raises ``ValueError`` when either triple-quoted
    assignment cannot be found in *content*.
    """
    validator_pattern = r"QUESTION_VALIDATOR_PROMPT_TEMPLATE\s*=\s*f?(['\"]{3})(.*?)\1"
    rejection_pattern = r"INVALID_QUERY_RESP\s*=\s*(['\"]{3})(.*?)\1"
    validator_match = re.search(validator_pattern, content, re.DOTALL)
    rejection_match = re.search(rejection_pattern, content, re.DOTALL)
    if not validator_match:
        raise ValueError("QUESTION_VALIDATOR_PROMPT_TEMPLATE not found")
    if not rejection_match:
        raise ValueError("INVALID_QUERY_RESP not found")
    return {
        "validator_prompt": validator_match.group(2),
        "invalid_resp": rejection_match.group(2),
    }


def fetch_and_load(url: str) -> dict[str, str]:
    """
    Fetches the contents from the upstream URL and returns a dictionary with the
    desired prompt templates.

    Raises ``requests.HTTPError`` on a non-2xx response and ``ValueError`` when
    the templates are not present in the downloaded text.
    """
    # Imported lazily so `--help`, usage errors and the pure helpers do not
    # require the third-party dependencies to be installed.
    import requests

    # Without an explicit timeout a stalled connection would hang CI forever.
    response = requests.get(url, timeout=REQUEST_TIMEOUT)
    response.raise_for_status()
    return extract_prompts(response.text)


def replace_values(prompt: str) -> str:
    """
    Replaces templated values to ones used with the safety shield.
    """
    values_to_replace = {
        "{SUBJECT_REJECTED}": "${rejected}",
        "{SUBJECT_ALLOWED}": "${allowed}",
        "{{query}}": "${message}",
    }
    new_prompt = prompt
    for replacee, replacement in values_to_replace.items():
        new_prompt = new_prompt.replace(replacee, replacement)
    return new_prompt


def is_valid(incoming_prompt: dict[str, str], current_prompt: dict[str, str]) -> bool:
    """
    Validates if the contents of the run.yaml file are equivalent to
    the upstream.

    Leading/trailing newlines are ignored. A missing key counts as an empty
    prompt, so an incomplete run.yaml is reported as invalid instead of
    crashing on ``None``.
    """
    return all(
        _normalise(incoming_prompt.get(key)) == _normalise(current_prompt.get(key))
        for key in ("validator_prompt", "invalid_resp")
    )


def _find_validity_config(data):
    """Return the ``config`` mapping of the question-validity safety provider.

    Raises ``ValueError`` when the loaded YAML has no ``providers.safety``
    section, no provider with id ``PROVIDER_ID``, or that provider has no
    ``config`` mapping.
    """
    try:
        safety_providers = data["providers"]["safety"]
    except (KeyError, TypeError) as err:
        raise ValueError("providers.safety section not found in run.yaml") from err
    for provider in safety_providers or []:
        if provider.get("provider_id") == PROVIDER_ID:
            config = provider.get("config")
            if config is None:
                raise ValueError(f"Safety provider '{PROVIDER_ID}' has no config")
            return config
    raise ValueError(f"Safety provider '{PROVIDER_ID}' not found in run.yaml")


def fetch_current_prompts(file_path: str = RUN_PATH) -> dict[str, str]:
    """
    Grabs the question validation prompt templates for both validation and
    rejected response from the local run.yaml file.

    Raises ``ValueError`` when the question-validity provider is missing.
    Prompts that are absent from the provider config are returned as ``""``.
    """
    from ruamel.yaml import YAML

    yaml = YAML()
    yaml.preserve_quotes = True
    with open(file_path, "r", encoding="utf-8") as f:
        data = yaml.load(f)
    config = _find_validity_config(data)
    return {
        "validator_prompt": config.get("model_prompt") or "",
        "invalid_resp": config.get("invalid_question_response") or "",
    }


def update_yaml_file(incoming_prompts: dict[str, str], file_path: str) -> None:
    """
    Updates the local run.yaml file with the upstream prompt templates.

    Raises ``ValueError`` (leaving the file untouched) when the
    question-validity provider is missing, rather than rewriting the file
    unchanged and reporting success.
    """
    from ruamel.yaml import YAML
    from ruamel.yaml.scalarstring import LiteralScalarString

    yaml = YAML()
    yaml.preserve_quotes = True
    yaml.width = 4096  # keep long prompt lines from being wrapped on dump
    with open(file_path, "r", encoding="utf-8") as f:
        data = yaml.load(f)

    config = _find_validity_config(data)
    # LiteralScalarString makes ruamel emit the prompts as literal block
    # scalars (like the existing `|-` entries) instead of quoted strings.
    config["model_prompt"] = LiteralScalarString(
        _normalise(incoming_prompts.get("validator_prompt"))
    )
    config["invalid_question_response"] = LiteralScalarString(
        _normalise(incoming_prompts.get("invalid_resp"))
    )

    with open(file_path, "w", encoding="utf-8") as f:
        yaml.dump(data, f)


def _section_diff(current, incoming) -> str:
    """Return a unified diff (run.yaml -> upstream) of two prompt strings."""
    return "\n".join(
        difflib.unified_diff(
            _normalise(current).splitlines(),
            _normalise(incoming).splitlines(),
            fromfile="run.yaml",
            tofile="upstream",
            lineterm="",
        )
    )


def format_diff(incoming_prompts: dict[str, str], current_prompts: dict[str, str]) -> str:
    """
    Builds the report printed by ``output_diff``: one titled unified diff per
    prompt. Prompts are normalised exactly like ``is_valid`` does, so newline
    padding never shows up as a difference.
    """
    return "\n".join(
        [
            "Validation Prompt",
            "-----",
            _section_diff(
                current_prompts.get("validator_prompt"),
                incoming_prompts.get("validator_prompt"),
            ),
            "Invalid Response Prompt",
            "-----",
            _section_diff(
                current_prompts.get("invalid_resp"),
                incoming_prompts.get("invalid_resp"),
            ),
        ]
    )


def output_diff(incoming_prompts: dict[str, str], current_prompts: dict[str, str]) -> None:
    """
    Outputs the difference between the upstream prompt templates and the
    local run.yaml.
    """
    print(format_diff(incoming_prompts, current_prompts))


def main(args: argparse.Namespace) -> None:
    """
    Entrypoint to the program.

    Exits with 0 when validation succeeds, 1 when the prompts differ and 2
    when ``args.type`` is not one of ``VALID_TYPES``.
    """
    # Reject a bad action before any network or file I/O, and with a non-zero
    # status so a typo in CI cannot pass silently.
    if args.type not in VALID_TYPES:
        print("Type incorrect.")
        sys.exit(2)

    incoming_prompts = fetch_and_load(URL)
    incoming_prompts["validator_prompt"] = replace_values(
        incoming_prompts.get("validator_prompt")
    )

    if args.type == "validate":
        current_prompts = fetch_current_prompts()
        if is_valid(incoming_prompts, current_prompts):
            print("Contents are valid.")
            sys.exit(0)
        print("Contents invalid.")
        output_diff(incoming_prompts, current_prompts)
        sys.exit(1)

    update_yaml_file(incoming_prompts, RUN_PATH)
    print("Contents updated.")


def _parse_args(argv=None) -> argparse.Namespace:
    """Parse the command line; ``-t/--type`` is mandatory and restricted."""
    parser = argparse.ArgumentParser(
        description="Tool for validating and/or updating the question validation portions of the Llama Stack config."
    )
    parser.add_argument(
        "-t",
        "--type",
        choices=VALID_TYPES,
        required=True,
        help="Type of action you want to perform. Can be 'validate' or 'update' to either validate or update the contents of run.yaml with the upstream.",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    main(_parse_args())