diff --git a/src/occupational_classification_utils/llm/llm.py b/src/occupational_classification_utils/llm/llm.py index 9bdc188..aa7cc83 100644 --- a/src/occupational_classification_utils/llm/llm.py +++ b/src/occupational_classification_utils/llm/llm.py @@ -122,7 +122,7 @@ async def get_soc_code( self, job_title: str, job_description: str, - level_of_education: str, + level_of_education: str | None, manage_others: bool, industry_descr: str, ) -> SocResponse: @@ -269,6 +269,7 @@ async def unambiguous_soc_code( # noqa: PLR0913 semantic_search_results: list[dict], job_title: str | None = None, job_description: str | None = None, + level_of_education: str | None = None, candidates_limit: int = config["llm"]["candidates_limit"], code_digits: int = config["llm"]["code_digits"], correlation_id: str | None = None, @@ -288,11 +289,17 @@ async def unambiguous_soc_code( # noqa: PLR0913 if (job_description is None or job_description in {"", " "}) else job_description ) + level_of_education = ( + "Unknown" + if (level_of_education is None or level_of_education in {"", " "}) + else level_of_education + ) call_dict = { "industry_descr": industry_descr, "job_title": job_title, "job_description": job_description, + "level_of_education": level_of_education, "soc_candidates": soc_candidates, } @@ -305,6 +312,7 @@ async def unambiguous_soc_code( # noqa: PLR0913 "LLM request sent - unambiguous_soc_code", job_title=truncate_identifier(job_title), job_description=truncate_identifier(job_description), + level_of_education=truncate_identifier(str(level_of_education)), industry_descr=truncate_identifier(industry_descr), correlation_id=correlation_id or "", ) @@ -391,19 +399,42 @@ async def unambiguous_soc_code( # noqa: PLR0913 return validated_answer, call_dict - async def formulate_open_question( + async def formulate_open_question( # noqa: PLR0913 self, industry_descr: str, job_title: str | None = None, job_description: str | None = None, + level_of_education: str | None = None, 
llm_output: RagCandidate | None = None, correlation_id: str | None = None, - ) -> tuple[OpenFollowUp, dict[str, Any]]: - """Formulate an open-ended follow-up (mirrors SIC formulate_open_question).""" + ) -> tuple[OpenFollowUp, Any]: + """Formulates an open-ended question using respondent data and survey design guidelines. + + Args: + industry_descr (str): The description of the industry. + job_title (str, optional): The job title. Defaults to None. + job_description (str, optional): The job description. Defaults to None. + level_of_education (str, optional): The level of education. Defaults to None. + llm_output (RagCandidate, optional): The response from the LLM model. + correlation_id (str, optional): Optional correlation ID for request tracking. - def prep_call_dict(industry_descr, job_title, job_description, llm_output): + Returns: + OpenFollowUp: The generated response to the query. + + Raises: + ValueError: If there is an error during the parsing of the response. + ValueError: If the default embedding handler is required but + not loaded correctly. 
+ + """ + + def prep_call_dict( + industry_descr, job_title, job_description, level_of_education, llm_output + ): + # Helper function to prepare the call dictionary is_job_title_present = job_title is None or job_title in {"", " "} job_title = "Unknown" if is_job_title_present else job_title + is_job_description_present = job_description is None or job_description in { "", " ", @@ -411,17 +442,26 @@ def prep_call_dict(industry_descr, job_title, job_description, llm_output): job_description = ( "Unknown" if is_job_description_present else job_description ) - return { + level_of_education = ( + "Unknown" + if (level_of_education is None or level_of_education in {"", " "}) + else level_of_education + ) + + call_dict = { "industry_descr": industry_descr, "job_title": job_title, "job_description": job_description, + "level_of_education": level_of_education, "llm_output": str(llm_output), } + return call_dict call_dict = prep_call_dict( industry_descr=industry_descr, job_title=job_title, job_description=job_description, + level_of_education=level_of_education, llm_output=llm_output, ) @@ -430,10 +470,13 @@ def prep_call_dict(industry_descr, job_title, job_description, llm_output): logger.debug(final_prompt) chain = self.soc_prompt_openfollowup | self.llm + + # Log LLM request sent logger.info( "LLM request sent - formulate_open_question", job_title=truncate_identifier(job_title), job_description=truncate_identifier(job_description), + level_of_education=truncate_identifier(str(level_of_education)), industry_descr=truncate_identifier(industry_descr), correlation_id=correlation_id or "", ) @@ -459,9 +502,11 @@ def prep_call_dict(industry_descr, job_title, job_description, llm_output): llm_duration_ms = int((time.perf_counter() - llm_start) * 1000) + # Parse the output to the desired format parser = PydanticOutputParser(pydantic_object=OpenFollowUp) try: validated_answer = parser.parse(str(response.content)) + # Log LLM response received after successful parse has_followup 
= bool(getattr(validated_answer, "followup", None)) logger.info( "LLM response received for open question prompt", @@ -487,8 +532,8 @@ def prep_call_dict(industry_descr, job_title, job_description, llm_output): correlation_id=correlation_id or "", ) try: - fix_chain = FIX_PARSING_PROMPT | self.llm - response = await fix_chain.ainvoke( + chain = FIX_PARSING_PROMPT | self.llm + response = await chain.ainvoke( { "llm_output": str(response.content), "format_instructions": parser.get_format_instructions(), @@ -497,6 +542,7 @@ def prep_call_dict(industry_descr, job_title, job_description, llm_output): ) validated_answer = parser.parse(str(response.content)) logger.debug("Successfully parsed reformatted response.") + except (ValueError, AttributeError) as parse_error2: logger.error( f"Failed to parse response again: {parse_error2}", @@ -524,6 +570,7 @@ async def sa_rag_soc_code( # noqa: PLR0913 industry_descr: str, job_title: str | None = None, job_description: str | None = None, + level_of_education: str | None = None, code_digits: int = config["llm"]["code_digits"], candidates_limit: int = config["llm"]["candidates_limit"], short_list: list[dict[Any, Any]] | None = None, @@ -537,6 +584,7 @@ async def sa_rag_soc_code( # noqa: PLR0913 industry_descr (str): The description of the industry. job_title (str, optional): The job title. Defaults to None. job_description (str, optional): The job description. Defaults to None. + level_of_education (str, optional): The level of education required for the job. Defaults to None. code_digits (int, optional): The number of digits in the generated SOC code. Defaults to 4. 
candidates_limit (int, optional): The maximum number of SOC code candidates @@ -554,7 +602,9 @@ async def sa_rag_soc_code( # noqa: PLR0913 """ - def prep_call_dict(industry_descr, job_title, job_description, soc_codes): + def prep_call_dict( + industry_descr, job_title, job_description, level_of_education, soc_codes + ): # Helper function to prepare the call dictionary is_job_title_present = job_title is None or job_title in {"", " "} job_title = "Unknown" if is_job_title_present else job_title @@ -571,6 +621,7 @@ def prep_call_dict(industry_descr, job_title, job_description, soc_codes): "industry_descr": industry_descr, "job_title": job_title, "job_description": job_description, + "level_of_education": (level_of_education or "").strip() or "Unknown", "soc_index": soc_codes, } return call_dict @@ -588,6 +639,7 @@ def prep_call_dict(industry_descr, job_title, job_description, soc_codes): industry_descr=industry_descr, job_title=job_title, job_description=job_description, + level_of_education=level_of_education, soc_codes=soc_codes, ) diff --git a/src/occupational_classification_utils/llm/prompt.py b/src/occupational_classification_utils/llm/prompt.py index ddf504b..3955d04 100644 --- a/src/occupational_classification_utils/llm/prompt.py +++ b/src/occupational_classification_utils/llm/prompt.py @@ -104,6 +104,7 @@ - Company's main activity: {industry_descr} - Job Title: {job_title} - Job Description: {job_description} +- Level of Education: {level_of_education} ===Relevant subset of UK SOC 2020=== {soc_index} @@ -155,6 +156,7 @@ }, ) + FIX_PARSING_PROMPT = PromptTemplate.from_template( """You are a meticulous assistant tasked with ensuring that the output from a language model adheres strictly to the required JSON format. 
@@ -196,6 +198,7 @@ - Company's main activity: {industry_descr} - Job Title: {job_title} - Job Description: {job_description} +- Level of Education: {level_of_education} ===Shortlist=== {soc_candidates} @@ -228,6 +231,7 @@ - Company's main activity: {industry_descr} - Job title: {job_title} - Job description: {job_description} +- Level of Education: {level_of_education} - Shortlist from previous model: {llm_output} - Note: These are candidate occupational categories; do not mention codes or "SOC" to the respondent. diff --git a/src/occupational_classification_utils/models/response_model.py b/src/occupational_classification_utils/models/response_model.py index 0c2a950..177813e 100644 --- a/src/occupational_classification_utils/models/response_model.py +++ b/src/occupational_classification_utils/models/response_model.py @@ -277,61 +277,89 @@ class SurveyAssistSocResponse(BaseModel): class UnambiguousResponse(BaseModel): - """Represents a response model for classification code assignment (two-step SOC). + """Represents a response model for classification code assignment. - Same generic field names as SIC ``UnambiguousResponse`` for parity across schemes. + Attributes: + codable (bool): True only if enough information is provided to assign + an unambiguous single classification code, False otherwise. + class_code (Optional[str]): Full classification code (to the required number of digits) + assigned based on provided respondent's data. Must be present if codable=True, + must be None if codable=False. + class_descriptive (Optional[str]): Descriptive label of the classification category. + Must be present if codable=True, must be None if codable=False. + alt_candidates (list[RagCandidate]): Short list of possible classification codes with their + descriptive labels and estimated likelihoods. + reasoning (str): Step by step reasoning behind the classification selected. 
""" codable: bool = Field( - description=( - "True only if enough information is provided to decide an unambiguous " - "classification code, False otherwise." - ) + description="True only if enough information is provided to decide an unambiguous " + "classification code, False otherwise." ) + class_code: str | None = Field( default=None, - description=( - "Full classification code assigned from respondent data. " - "Present if codable=True, None if codable=False." - ), + description="Full classification code (to the required number of digits) " + "assigned based on provided respondent's data. Must be present if codable=True, " + "must be None if codable=False.", ) + class_descriptive: str | None = Field( default=None, - description=( - "Descriptive label for class_code. Present if codable=True, " - "None if codable=False." - ), + description="Descriptive label of the classification category. " + "Must be present if codable=True, must be None if codable=False.", ) + alt_candidates: list[RagCandidate] = Field( default_factory=list, - description="Short list of possible classification codes with likelihoods.", - min_length=1, - max_length=10, + description="Short list of possible classification codes with their " + "descriptive labels and estimated likelihoods.", + max_length=10, # Limit to at most 10 candidates ) + reasoning: str = Field( description="Step by step reasoning behind the classification selected.", - min_length=50, + min_length=50, # Ensure detailed reasoning is provided ) @field_validator("alt_candidates") @classmethod - def validate_alt_candidates(cls, v: list[RagCandidate]) -> list[RagCandidate]: - """Validate alternative candidate count.""" - if not 1 <= len(v) <= MAX_ALT_CANDIDATES: - raise ValueError("alt_candidates must contain between 1 and 10 items.") + def validate_alt_candidates(cls, v): + """Validates the number of alternative candidates. + + Ensures the number of candidates is less than or equal to the maximum allowed. 
+ + Args: + v (list): The list of alternative candidates. + + Returns: + list: The validated list of candidates. + + Raises: + ValueError: If the number of candidates is not within the allowed range. + """ + if not len(v) <= MAX_ALT_CANDIDATES: + raise ValueError("alt_candidates must contain no more than 10 items.") return v class OpenFollowUp(BaseModel): - """Open-ended follow-up question when SOC cannot be assigned unambiguously.""" + """Represents a response model for open ended follow-up question. + + Attributes: + followup (str): Question to ask user in order to collect + additional information to enable reliable classification assignment. + reasoning (str): Reasoning explaining how follow-up question will help + assign classification code. + """ followup: str | None = Field( - description=( - "Question to collect additional information for reliable SOC assignment." - ), + description="""Question to ask user in order to collect additional information + to enable reliable classification assignment.""", default="", ) reasoning: str = Field( - description="Reasoning explaining how the follow-up question helps classification.", + description="""Reasoning explaining how follow-up question will help + assign classification code.""", default="", ) diff --git a/tests/test_llm.py b/tests/test_llm.py index 0e9ea99..ba91dd5 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -16,7 +16,9 @@ from occupational_classification.data_access.soc_data_access import ( load_soc_structure as lib_load_soc_structure, ) -from occupational_classification.hierarchy.soc_hierarchy import load_hierarchy +from occupational_classification.hierarchy.soc_hierarchy import ( + load_hierarchy, +) from occupational_classification_utils.llm.llm import ClassificationLLM from occupational_classification_utils.llm.prompt import SA_SOC_PROMPT_RAG @@ -425,18 +427,28 @@ async def test_unambiguous_soc_code_followup_is_str( assert isinstance(result, str) +@pytest.fixture +def mock_soc(): + """Minimal 
SOC hierarchy from the packaged example lookup table.""" + ref = ("occupational_classification", "data/example_soc_lookup_data.csv") + with as_file(files(ref[0]).joinpath(ref[1])) as path: + p = str(path) + idx = lib_load_soc_index(p) + soc = load_hierarchy(lib_load_soc_structure(p), idx) + return soc + + @pytest.mark.llm async def test_llm_response_mocked_formulate_open_question( mocker, prompt_candidate_soc ): - """formulate_open_question returns typed response and call dict with mocked output.""" mock_object_dict = {"class_code": "", "class_descriptive": "", "likelihood": 0.5} mock_object_json = json.dumps(mock_object_dict) mock_message = mocker.Mock(spec=AIMessage) mock_message.content = mock_object_json - mocker.patch( + mock_patcher = mocker.patch( # noqa: F841 "occupational_classification_utils.llm.llm.ChatVertexAI.ainvoke", return_value=mock_message, ) @@ -445,18 +457,8 @@ async def test_llm_response_mocked_formulate_open_question( industry_descr="", job_title="", job_description="", + level_of_education="", llm_output="", ) assert isinstance(result[0], OpenFollowUp) assert isinstance(result[1], dict) - - -@pytest.fixture -def mock_soc(): - """Minimal SOC hierarchy from the packaged example lookup table.""" - ref = ("occupational_classification", "data/example_soc_lookup_data.csv") - with as_file(files(ref[0]).joinpath(ref[1])) as path: - p = str(path) - idx = lib_load_soc_index(p) - soc = load_hierarchy(lib_load_soc_structure(p), idx) - return soc