diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 34873cdc..218a89c3 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -6,6 +6,9 @@ v3.7.1
 
 *Release date: In development*
 
+- Add text analysis tool to get an overall match of a text against a list of expected characteristics
+  using AI libraries that come with the `ai` extra dependency
+
 v3.7.0
 ------
 
diff --git a/toolium/test/utils/ai_utils/test_text_analysis.py b/toolium/test/utils/ai_utils/test_text_analysis.py
new file mode 100644
index 00000000..3d677a4c
--- /dev/null
+++ b/toolium/test/utils/ai_utils/test_text_analysis.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+"""
+Copyright 2025 Telefónica Innovación Digital, S.L.
+This file is part of Toolium.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import os
+import json
+import pytest
+
+from toolium.driver_wrappers_pool import DriverWrappersPool
+from toolium.utils.ai_utils.text_analysis import get_text_criteria_analysis
+
+
+def configure_default_openai_model():
+    """
+    Configure OpenAI model used in unit tests
+    """
+    config = DriverWrappersPool.get_default_wrapper().config
+    try:
+        config.add_section('AI')
+    except Exception:
+        pass
+    config.set('AI', 'openai_model', 'gpt-4.1-mini')
+
+
+get_analysis_examples = (
+    ('How are you today?', ["is a greeting phrase", "is a question"], 0.7, 1),
+    ('Today is sunny', ["is an affirmation", "talks about the weather"], 0.7, 1),
+    ('I love programming', ["expresses a positive sentiment"], 0.7, 1),
+    ('How are you today?', ["is an affirmation", "talks about the weather"], 0.0, 0.2),
+    ('Today is sunny', ["is a greeting phrase", "is a question"], 0.0, 0.2),
+    ('I love programming', ["is a greeting phrase", "is a question"], 0.0, 0.2),
+)
+
+
+@pytest.mark.skipif(os.getenv("AZURE_OPENAI_API_KEY") is None,
+                    reason="AZURE_OPENAI_API_KEY environment variable not set")
+@pytest.mark.parametrize('input_text, features_list, expected_low, expected_high', get_analysis_examples)
+def test_get_text_analysis(input_text, features_list, expected_low, expected_high):
+    similarity = json.loads(get_text_criteria_analysis(input_text, features_list, azure=True))
+    assert expected_low <= similarity['overall_match'] <= expected_high
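Note: the test above exercises the new helper end to end against a live model. A
minimal usage sketch, assuming Azure OpenAI credentials are available in the
environment (e.g. AZURE_OPENAI_API_KEY) and that the `ai` extra is installed:

    import json
    from toolium.utils.ai_utils.text_analysis import get_text_criteria_analysis

    # Returns a JSON string such as '{"overall_match": 0.95, "features": []}'
    response = get_text_criteria_analysis('How are you today?',
                                          ['is a greeting phrase', 'is a question'],
                                          azure=True)
    analysis = json.loads(response)
    print(analysis['overall_match'])  # expected close to 1 for this example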
diff --git a/toolium/utils/ai_utils/openai.py b/toolium/utils/ai_utils/openai.py
index 3a1d6072..82b58df9 100644
--- a/toolium/utils/ai_utils/openai.py
+++ b/toolium/utils/ai_utils/openai.py
@@ -49,12 +49,16 @@ def openai_request(system_message, user_message, model_name=None, azure=False, *
     model_name = model_name or config.get_optional('AI', 'openai_model', 'gpt-4o-mini')
     logger.info(f"Calling to OpenAI API with model {model_name}")
     client = AzureOpenAI(**kwargs) if azure else OpenAI(**kwargs)
+    msg = []
+    if isinstance(system_message, list):
+        for prompt in system_message:
+            msg.append({"role": "system", "content": prompt})
+    else:
+        msg.append({"role": "system", "content": system_message})
+    msg.append({"role": "user", "content": user_message})
     completion = client.chat.completions.create(
         model=model_name,
-        messages=[
-            {"role": "system", "content": system_message},
-            {"role": "user", "content": user_message},
-        ],
+        messages=msg,
     )
     response = completion.choices[0].message.content
     logger.debug(f"OpenAI response: {response}")
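Note: with this change, openai_request accepts either a single system prompt or a
list of prompts, each list item becoming its own system message. A minimal sketch,
assuming standard (non-Azure) OpenAI credentials are configured:

    from toolium.utils.ai_utils.openai import openai_request

    # Single system message, as before
    openai_request("You are a helpful assistant.", "What is Toolium?")
    # Several system messages, new with this change
    openai_request(["You are a helpful assistant.", "Answer in one sentence."],
                   "What is Toolium?")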
diff --git a/toolium/utils/ai_utils/text_analysis.py b/toolium/utils/ai_utils/text_analysis.py
new file mode 100644
index 00000000..33b016c0
--- /dev/null
+++ b/toolium/utils/ai_utils/text_analysis.py
@@ -0,0 +1,114 @@
+# -*- coding: utf-8 -*-
+"""
+Copyright 2025 Telefónica Innovación Digital, S.L.
+This file is part of Toolium.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import logging
+import json
+
+from toolium.utils.ai_utils.openai import openai_request
+
+# Configure logger
+logger = logging.getLogger(__name__)
+
+# flake8: noqa E501
+def build_system_message(characteristics):
+    """
+    Build system message for text criteria analysis prompt.
+
+    :param characteristics: list of target characteristics to evaluate
+    """
+    feature_list = "\n".join(f"- {c}" for c in characteristics)
+    base_prompt = f"""
+    You are an assistant that scores how well a given text matches a set of target characteristics and returns a JSON object.
+
+    You will receive a user message that contains ONLY the text to analyze.
+
+    Target characteristics:
+    {feature_list}
+
+    Tasks:
+    1) For EACH characteristic, decide how well the text satisfies it on a scale from 0.0 (does not satisfy it at all) to 1.0 (perfectly satisfies it). Consider style, tone and content when relevant.
+    2) ONLY for each low scored characteristic (<=0.2), output:
+       - "name": the exact characteristic name as listed above.
+       - "score": a float between 0.0 and 0.2.
+    3) Compute an overall score "overall_match" between 0.0 and 1.0 that summarizes how well the text matches the whole set. It does not have to be a simple arithmetic mean, but must still be in [0.0, 1.0].
+
+    Output format (IMPORTANT):
+    Return ONLY a single valid JSON object with this exact top-level structure and property names:
+
+    {{
+        "overall_match": float,
+        "features": [
+            {{
+                "name": string,
+                "score": float
+            }}
+        ]
+    }}
+
+    Constraints:
+    - Do NOT include high scored (>0.2) characteristics in the "features" list.
+    - Use a dot as decimal separator (e.g. 0.75, not 0,75).
+    - Use at most 2 decimal places for all scores.
+    - Do NOT include any text outside the JSON (no Markdown, no comments, no explanations).
+    - If a characteristic is not applicable to the text, give it a low score (<= 0.2).
+    """
+    return base_prompt.strip()
+
+
+def get_text_criteria_analysis(text_input, text_criteria, model_name=None, azure=False, **kwargs):
+    """
+    Get a text criteria analysis using OpenAI or Azure OpenAI to evaluate how well a
+    given text matches a set of target characteristics.
+    The response is a structured JSON object with an overall match score and the
+    individual scores of the low scored characteristics.
+
+    :param text_input: text to analyze
+    :param text_criteria: list of target characteristics to evaluate
+    :param model_name: name of the OpenAI model to use
+    :param azure: whether to use Azure OpenAI or standard OpenAI
+    :param kwargs: additional parameters to be used by OpenAI client
+    :returns: response from OpenAI
+    """
+    # Build prompt using base prompt and target features
+    system_message = build_system_message(text_criteria)
+    return openai_request(system_message, text_input, model_name, azure, **kwargs)
+
+
+def assert_text_criteria(text_input, text_criteria, threshold, model_name=None, azure=False, **kwargs):
+    """
+    Get text criteria analysis and assert that the overall match score reaches the threshold.
+
+    :param text_input: text to analyze
+    :param text_criteria: list of target characteristics to evaluate
+    :param threshold: minimum overall match score to consider the text acceptable
+    :param model_name: name of the OpenAI model to use
+    :param azure: whether to use Azure OpenAI or standard OpenAI
+    :param kwargs: additional parameters to be used by OpenAI client
+    :raises AssertionError: if overall match score is below threshold
+    """
+    analysis = json.loads(get_text_criteria_analysis(text_input, text_criteria, model_name, azure, **kwargs))
+    overall_match = analysis.get("overall_match", 0.0)
+    if overall_match < threshold:
+        logger.error(f"Text criteria analysis failed: overall match {overall_match} "
+                     f"is below threshold {threshold}\n"
+                     f"Failed features: {analysis.get('features', [])}")
+        raise AssertionError(f"Text criteria analysis failed: overall match {overall_match} "
+                             f"is below threshold {threshold}\n"
+                             f"Failed features: {analysis.get('features', [])}")
+    logger.info(f"Text criteria analysis passed: overall match {overall_match} "
+                f"meets threshold {threshold}\n"
+                f"Low scored features: {analysis.get('features', [])}")
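Note: assert_text_criteria wraps the analysis in a pass/fail check. A minimal
sketch, assuming the same Azure OpenAI configuration as in the examples above:

    from toolium.utils.ai_utils.text_analysis import assert_text_criteria

    # Raises AssertionError when the overall match is below the threshold,
    # logging the low scored characteristics reported by the model
    assert_text_criteria('Today is sunny',
                         ['is an affirmation', 'talks about the weather'],
                         threshold=0.7, azure=True)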