From 4726773bd7112849ebbe550ecc2511f223e2e0dd Mon Sep 17 00:00:00 2001 From: Aryan Amit Barsainyan Date: Mon, 27 Apr 2026 16:17:43 +0530 Subject: [PATCH 1/2] Fix RDD estimation issues --- cais/components/dataset_cleaner.py | 5 +- cais/components/query_interpreter.py | 5 +- .../regression_discontinuity/estimator.py | 51 +++++++++++++------ cais/models.py | 2 + cais/prompts/method_identification_prompts.py | 10 +++- 5 files changed, 53 insertions(+), 20 deletions(-) diff --git a/cais/components/dataset_cleaner.py b/cais/components/dataset_cleaner.py index 3ebb524..0d1d4e9 100644 --- a/cais/components/dataset_cleaner.py +++ b/cais/components/dataset_cleaner.py @@ -200,12 +200,13 @@ def _run_script_text(script: str, dataset_path: str, cleaned_path: str) -> Tuple # ---------- main pipeline blocks ---------- -def _plan_transformation_spec(llm, dataset_path: str, causal_method: str, causal_query: str, variables: Dict[str, Any]) -> Dict[str, Any]: +def _plan_transformation_spec(llm, dataset_path: str, causal_method: str, causal_query: str, variables: Dict[str, Any], dataset_description: Optional[str]) -> Dict[str, Any]: prof = _profile_dataset(dataset_path) human = { "dataset_path": dataset_path, "dataset_profile": prof["profile"], + "dataset_description": dataset_description or "No dataset description provided", "causal_method": causal_method, "causal_query": causal_query or "", "variables": variables @@ -262,7 +263,7 @@ def run_cleaning_stage(dataset_path: str, # 1) PLAN method = causal_method or variables.get("method") or "" - spec = _plan_transformation_spec(llm, dataset_path, method, original_query or "", variables) + spec = _plan_transformation_spec(llm, dataset_path, method, original_query or "", variables, dataset_description) #print(spec) # 2) CODEGEN diff --git a/cais/components/query_interpreter.py b/cais/components/query_interpreter.py index de661c9..a2182f6 100644 --- a/cais/components/query_interpreter.py +++ b/cais/components/query_interpreter.py @@ -313,10 
+313,12 @@ def interpret_query(query_info: Dict[str, Any], dataset_analysis: Dict[str, Any] if rdd_result: running_variable = rdd_result.running_variable cutoff_value = rdd_result.cutoff_value + treat_above_cutoff = rdd_result.treat_above_cutoff if running_variable not in columns or cutoff_value is None: running_variable = None cutoff_value = None - logger.info(f"LLM identified RDD: Running={running_variable}, Cutoff={cutoff_value}") + treat_above_cutoff = None + logger.info(f"LLM identified RDD: Running={running_variable}, Cutoff={cutoff_value}, Treat Above Cutoff={treat_above_cutoff}") ## For graph based methods exclude_cols = [treatment_variable, outcome_variable] @@ -377,6 +379,7 @@ def interpret_query(query_info: Dict[str, Any], dataset_analysis: Dict[str, Any] ## for rdd "running_variable": running_variable, "cutoff_value": cutoff_value, + "treat_above_cutoff": treat_above_cutoff, ## for rct "is_rct": is_rct, "treatment_reference_level": treatment_reference_level, diff --git a/cais/methods/regression_discontinuity/estimator.py b/cais/methods/regression_discontinuity/estimator.py index 42296d5..b76df1f 100644 --- a/cais/methods/regression_discontinuity/estimator.py +++ b/cais/methods/regression_discontinuity/estimator.py @@ -44,13 +44,15 @@ def estimate_effect(self, df, variables, query=None): hasattr(variables, 'treatment_variable') and hasattr(variables, 'outcome_variable') and hasattr(variables, 'running_variable') and - hasattr(variables, 'cutoff_value') + hasattr(variables, 'cutoff_value') and + hasattr(variables, 'treat_above_cutoff') ) treatment = variables.treatment_variable outcome = variables.outcome_variable running_var = variables.running_variable cutoff = variables.cutoff_value + treat_above_cutoff = variables.treat_above_cutoff covariates = variables.confounders covariates = covariates if covariates else [] @@ -68,7 +70,8 @@ def estimate_effect(self, df, variables, query=None): outcome=outcome, running_variable=running_var, cutoff_value=cutoff, - 
covariates=covariates + covariates=covariates, + treat_above_cutoff=treat_above_cutoff, ) except Exception as e: raise ValueError(f"Couldn't calculate effect using RDD: {e}") @@ -114,11 +117,13 @@ def _call_llm_for_var(llm: BaseChatModel, prompt: str, pydantic_model: BaseModel if rdd_result: running_variable = rdd_result.running_variable cutoff_value = rdd_result.cutoff_value + treat_above_cutoff = rdd_result.treat_above_cutoff if running_variable not in columns or cutoff_value is None: running_variable = None cutoff_value = None - logger.info(f"LLM identified RDD: Running={running_variable}, Cutoff={cutoff_value}") - return cutoff_value + treat_above_cutoff = None + logger.info(f"LLM identified RDD: Running={running_variable}, Cutoff={cutoff_value}, Treat Above Cutoff={treat_above_cutoff}") + return cutoff_value, treat_above_cutoff @@ -138,7 +143,7 @@ def _call_llm_for_var(llm: BaseChatModel, prompt: str, pydantic_model: BaseModel _rdd_em_import_error_message = f"An unexpected error occurred during import from evan-magnusson/rdd: {e}" logger.warning(_rdd_em_import_error_message) -def estimate_effect_fallback(df: pd.DataFrame, treatment: str, outcome: str, running_variable: str, cutoff_value: float, covariates: Optional[List[str]], **kwargs) -> Dict[str, Any]: +def estimate_effect_fallback(df: pd.DataFrame, treatment: str, outcome: str, running_variable: str, cutoff_value: float, covariates: Optional[List[str]], treat_above_cutoff: Optional[bool] = None, **kwargs) -> Dict[str, Any]: """Estimate RDD effect using simple linear regression comparison fallback.""" logger.warning("Main RDD estimation failed. 
Using fallback simple linear regression comparison.") if covariates: @@ -202,8 +207,14 @@ def estimate_effect_fallback(df: pd.DataFrame, treatment: str, outcome: str, run # The coefficient for 'above_cutoff' represents the jump at the cutoff effect = results.params['above_cutoff'] + if not treat_above_cutoff: + effect = effect * -1 p_value = results.pvalues['above_cutoff'] - conf_int = results.conf_int().loc['above_cutoff'].tolist() + L, U = results.conf_int().loc['above_cutoff'].tolist() + if not treat_above_cutoff: + L, U = -U, -L + conf_int = [L, U] + std_err = results.bse['above_cutoff'] return { @@ -218,7 +229,7 @@ def estimate_effect_fallback(df: pd.DataFrame, treatment: str, outcome: str, run def effect_estimate_rdd(df: pd.DataFrame, outcome: str, running_variable: str, cutoff_value: float, treatment: Optional[str] = None, covariates: Optional[List[str]] = None, - bandwidth: Optional[float] = None, **kwargs) -> Dict[str, Any]: + bandwidth: Optional[float] = None, treat_above_cutoff: Optional[bool] = None, **kwargs) -> Dict[str, Any]: """ Estimates RDD effect using the 'evan-magnusson/rdd' package. Uses IK optimal bandwidth selection from the same package by default. 
@@ -290,12 +301,16 @@ def effect_estimate_rdd(df: pd.DataFrame, outcome: str, running_variable: str, c # Extract results - using 'TREATED' based on the provided summary output effect = sm_results.params.get('TREATED') + if not treat_above_cutoff: + effect = effect * -1 std_err = sm_results.bse.get('TREATED') p_value = sm_results.pvalues.get('TREATED') - + conf_int_series = sm_results.conf_int() conf_int = conf_int_series.loc['TREATED'].tolist() if 'TREATED' in conf_int_series.index else [None, None] - + if not treat_above_cutoff: + if 'TREATED' in conf_int_series.index: + conf_int = [conf_int[1] * -1, conf_int[0] * -1] n_obs = model.nobs # or model.n_ if nobs is not available (check package details) # The formula is implicit in the local linear regression performed by the package @@ -323,7 +338,7 @@ def effect_estimate_rdd(df: pd.DataFrame, outcome: str, running_variable: str, c def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str, running_variable: str, cutoff_value: float, covariates: Optional[List[str]] = None, bandwidth: Optional[float] = None, query: Optional[str] = None, - llm: Optional[BaseChatModel] = None, **kwargs) -> Dict[str, Any]: + llm: Optional[BaseChatModel] = None, treat_above_cutoff: Optional[bool] = None, **kwargs) -> Dict[str, Any]: """ Estimates the causal effect using Regression Discontinuity Design. @@ -341,6 +356,7 @@ def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str, running_vari bandwidth: Optional bandwidth around the cutoff. If None, a default is used. query: Optional user query for context. llm: Optional Language Model instance. + treat_above_cutoff: True if the treatment is assigned above the cutoff, False if below the cutoff. **kwargs: Additional keyword arguments for underlying methods. 
Returns: @@ -348,15 +364,19 @@ def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str, running_vari """ required_args = { "running_variable": running_variable, - "cutoff_value": cutoff_value + "cutoff_value": cutoff_value, } if any(val is None for val in required_args.values()): - raise ValueError(f"Missing required RDD arguments: running_variable and cutoff_value must be provided.") + raise ValueError(f"Missing required RDD arguments: running_variable, cutoff_value must be provided.") + + if treat_above_cutoff is None: + logger.warning("`treat_above_cutoff` is not provided. Assuming treatment is assigned above the cutoff.") + treat_above_cutoff = True results = {} rdd_em_estimation_error = None # Error from effect_estimate_rdd (evan-magnusson) fallback_estimation_error = None # Error from estimate_effect_fallback - + # --- Try effect_estimate_rdd (evan-magnusson/rdd) First --- try: logger.info("Attempting RDD estimation using 'effect_estimate_rdd' (evan-magnusson/rdd package).") @@ -368,7 +388,8 @@ def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str, running_vari cutoff_value, treatment=treatment, # For API consistency, though evan-magnusson/rdd doesn't use it explicitly covariates=covariates, - bandwidth=bandwidth, + bandwidth=bandwidth, + treat_above_cutoff=treat_above_cutoff, **kwargs ) results['method_used'] = 'evan-magnusson/rdd' # Ensure method_used is set @@ -384,7 +405,7 @@ def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str, running_vari if not results: # If effect_estimate_rdd wasn't used or failed logger.info("'effect_estimate_rdd' did not produce results. 
Attempting fallback using 'estimate_effect_fallback'.") try: - fallback_results = estimate_effect_fallback(df, treatment, outcome, running_variable, cutoff_value, covariates, bandwidth=bandwidth, **kwargs) + fallback_results = estimate_effect_fallback(df, treatment, outcome, running_variable, cutoff_value, covariates, bandwidth=bandwidth, treat_above_cutoff=treat_above_cutoff, **kwargs) results.update(fallback_results) results['method_used'] = 'Fallback RDD (Linear Interaction with Robust Errors)' fallback_estimation_error = None # Clear fallback error if it succeeded diff --git a/cais/models.py b/cais/models.py index d475803..3e2ad57 100644 --- a/cais/models.py +++ b/cais/models.py @@ -23,6 +23,7 @@ class LLMRDDVars(BaseModel): """Pydantic model for identifying RDD variables.""" running_variable: Optional[str] = Field(None, description="The identified running variable column name.") cutoff_value: Optional[Union[float, int]] = Field(None, description="The identified cutoff value.") + treat_above_cutoff: Optional[bool] = Field(None, description="True if the treatment is assigned above the cutoff, False if below the cutoff.") class LLMRCTCheck(BaseModel): """Pydantic model for checking if data is RCT.""" @@ -124,6 +125,7 @@ class Variables(BaseModel): treatment_state: Optional[str] = None running_variable: Optional[str] = None cutoff_value: Optional[Union[float, int]] = None + treat_above_cutoff: Optional[bool] = None is_rct: Optional[bool] = Field(False, description="Flag indicating if the dataset is from an RCT.") treatment_reference_level: Optional[Union[float, str]] = Field(None, description="The specified reference/control level for a multi-valued treatment variable.") interaction_term_suggested: Optional[bool] = Field(False, description="Whether the query or context suggests an interaction term with the treatment might be relevant.") diff --git a/cais/prompts/method_identification_prompts.py b/cais/prompts/method_identification_prompts.py index 1fc1329..d0f4dfc 
100644 --- a/cais/prompts/method_identification_prompts.py +++ b/cais/prompts/method_identification_prompts.py @@ -64,6 +64,7 @@ You need to identify if a running variable exists for performing Regression Discontinuity Design (RDD) to answer the user query. Go through the data description and available columns carefully. You need to be strict with the assessment. In RDD, treatment assignment (for analysis) is determined by whether a continuous variable crosses a specific threshold. +Whether the treatment is assigned above or below the cutoff can be determined by the user query or dataset description. User Query: "{query}" Dataset Description: {description} @@ -82,7 +83,11 @@ Step 3: Identify the cutoff value from design - What specific threshold value determines treatment assignment? -Step 4: Final determination +Step 4: Identify if the treatment is assigned above or below the cutoff +- Is the treatment assigned above the cutoff? Store it as a boolean value in 'treat_above_cutoff' in the JSON. +- Store null if you are unsure. + +Step 5: Final determination - Only suggest RDD if both running variable and cutoff value can be identified - Return null if the assignment mechanism is not threshold-based or you are unsure. 
@@ -90,7 +95,8 @@ {{ "running_variable": "column_name_or_null", - "cutoff_value": numeric_value_or_null + "cutoff_value": numeric_value_or_null, + "treat_above_cutoff": true_false_or_null }} """ From b52cceda4896ae9cc33cca9a04da236186303581 Mon Sep 17 00:00:00 2001 From: Aryan Amit Barsainyan Date: Thu, 30 Apr 2026 10:26:56 +0530 Subject: [PATCH 2/2] add unit tests and fix rdd design compliance --- cais/components/decision_tree.py | 2 + cais/components/explanation_generator.py | 1 + cais/components/method_validator.py | 16 +++- cais/components/query_interpreter.py | 1 + cais/models.py | 3 +- cais/tools/method_executor_tool.py | 2 +- cais/tools/method_validator_tool.py | 1 + cais/utils/agent.py | 2 +- .../cais/components/test_decision_tree_llm.py | 1 + .../test_rdd_estimator.py | 68 +++++++++++++++++++ .../test_components/test_query_interpreter.py | 2 +- 11 files changed, 90 insertions(+), 9 deletions(-) diff --git a/cais/components/decision_tree.py b/cais/components/decision_tree.py index 81a2ef1..cf550f6 100644 --- a/cais/components/decision_tree.py +++ b/cais/components/decision_tree.py @@ -366,6 +366,7 @@ def rule_based_select_method(dataset_analysis, variables, is_rct, llm, dataset_d "has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False), "frontdoor_criterion": variables.get("frontdoor_criterion", False), "cutoff_value": variables.get("cutoff_value"), + "treat_above_cutoff": variables.get("treat_above_cutoff"), "covariate_overlap_score": variables.get("covariate_overlap_result", 0)} properties["is_rct"] = is_rct @@ -404,6 +405,7 @@ def select_method(self, df: pd.DataFrame, treatment: str, outcome: str, covariat "instrument_variable": query_details.get("instrument_variable"), "running_variable": query_details.get("running_variable"), "cutoff_value": query_details.get("cutoff_value"), + "treat_above_cutoff": query_details.get("treat_above_cutoff"), "is_rct": query_details.get("is_rct", False), 
"has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False), "frontdoor_criterion": query_details.get("frontdoor_criterion", False), diff --git a/cais/components/explanation_generator.py b/cais/components/explanation_generator.py index b5b06c3..5f57387 100644 --- a/cais/components/explanation_generator.py +++ b/cais/components/explanation_generator.py @@ -414,6 +414,7 @@ def explain_application(method: str, treatment: str, outcome: str, f"I will focus on observations close to the cutoff value " f"({variables.get('cutoff_value')}) of the running variable " f"({variables.get('running_variable')}), where treatment assignment changes. " + f"Is the treatment assigned above the cutoff? Answer: {variables.get('treat_above_cutoff')} (False means below the cutoff, True means above the cutoff, None means not determined). " f"By comparing outcomes just above and below this threshold, I can estimate " f"the local causal effect of {treatment} on {outcome}." 
), diff --git a/cais/components/method_validator.py b/cais/components/method_validator.py index b7f5cc8..0ea2784 100644 --- a/cais/components/method_validator.py +++ b/cais/components/method_validator.py @@ -13,15 +13,20 @@ from cais.config import get_llm_client -def rdd_design_compliance(df: pd.DataFrame, running_variable: str, treatment: str, cutoff_value: float) -> dict: +def rdd_design_compliance(df: pd.DataFrame, running_variable: str, treatment: str, cutoff_value: float, treat_above_cutoff: Optional[bool] = None) -> dict: """Check whether treatment assignment closely follows the cutoff rule T ≈ 1{X >= c}.""" try: - above = df[running_variable] >= cutoff_value + if treat_above_cutoff is None: + treat_above_cutoff = True # Default to above cutoff + if treat_above_cutoff: + mask = df[running_variable] >= cutoff_value + else: + mask = df[running_variable] <= cutoff_value if treatment not in df.columns: return {"ok": False, "reason": f"Treatment column '{treatment}' not found."} t = df[treatment] # Compliance: fraction of rows where T matches the cutoff rule - compliance = (t == above.astype(t.dtype)).mean() + compliance = (t == mask.astype(t.dtype)).mean() # Allow for fuzzy RDD — flag as ok if compliance >= 0.75 return {"ok": float(compliance) >= 0.75, "compliance_rate": float(compliance)} except Exception as e: @@ -405,6 +410,7 @@ def validate_regression_discontinuity(validation_result: Dict[str, Any], cutoff_value = variables.get("cutoff_value") treatment = variables.get("treatment_variable") outcome = variables.get("outcome_variable") + treat_above_cutoff = variables.get("treat_above_cutoff") df = pd.read_csv(dataset_analysis['dataset_info']['file_path']) # Required fields @@ -431,7 +437,7 @@ def validate_regression_discontinuity(validation_result: Dict[str, Any], return # 1) Enforced-by-design check: is treatment determined by cutoff? 
- design = rdd_design_compliance(df, running_variable, treatment, cutoff_value) + design = rdd_design_compliance(df, running_variable, treatment, cutoff_value, treat_above_cutoff) validation_result.setdefault("evidence", {})["rdd_design"] = design if not design.get("ok", False): validation_result["concerns"].append( @@ -462,7 +468,7 @@ def validate_regression_discontinuity(validation_result: Dict[str, Any], # 3) Assumption status (per text) validation_result.setdefault("evidence", {})["rdd_notes"] = { "assumption_status": "Untestable; assess visually around cutoff for an abrupt jump.", - "design_enforcement": "Treatment is (or should be) determined by a cutoff variable.", + "design_enforcement": "Treatment is (or should be) determined by a cutoff variable. Whether the treatment is assigned above or below the cutoff is determined by the user query or dataset description.", "visual_recommendation": "Plot outcome vs running within a symmetric window around the cutoff; inspect for a jump." } diff --git a/cais/components/query_interpreter.py b/cais/components/query_interpreter.py index a2182f6..d118595 100644 --- a/cais/components/query_interpreter.py +++ b/cais/components/query_interpreter.py @@ -288,6 +288,7 @@ def interpret_query(query_info: Dict[str, Any], dataset_analysis: Dict[str, Any] instrument_variable = None running_variable = None cutoff_value = None + treat_above_cutoff = None is_rct = None smd_score = None diff --git a/cais/models.py b/cais/models.py index 3e2ad57..094ebd0 100644 --- a/cais/models.py +++ b/cais/models.py @@ -23,7 +23,7 @@ class LLMRDDVars(BaseModel): """Pydantic model for identifying RDD variables.""" running_variable: Optional[str] = Field(None, description="The identified running variable column name.") cutoff_value: Optional[Union[float, int]] = Field(None, description="The identified cutoff value.") - treat_above_cutoff: Optional[bool] = Field(None, description="True if the treatment is assigned above the cutoff, False if below the 
cutoff.") + treat_above_cutoff: Optional[bool] = Field(None, description="True if the treatment is assigned above the cutoff, False if below the cutoff, None if unsure.") class LLMRCTCheck(BaseModel): """Pydantic model for checking if data is RCT.""" @@ -101,6 +101,7 @@ class QueryInfo(BaseModel): instrument_hints: Optional[List[str]] = None running_variable_hints: Optional[List[str]] = None cutoff_value_hint: Optional[Union[float, int]] = None + treat_above_cutoff_hint: Optional[bool] = None class QueryInterpreterInput(BaseModel): """Input structure for the query interpreter tool.""" diff --git a/cais/tools/method_executor_tool.py b/cais/tools/method_executor_tool.py index 2ff0f00..361c181 100644 --- a/cais/tools/method_executor_tool.py +++ b/cais/tools/method_executor_tool.py @@ -101,7 +101,7 @@ def method_executor_tool(inputs: MethodExecutorInput, original_query: Optional[s # Avoid passing the entire variables_dict as estimate_func expects specific args kwargs_for_method = {} for key in ["instrument_variable", "time_variable", "group_variable", - "running_variable", "cutoff_value", "did_term", "did_canonical", "treatment_time", "treatment_state"]: + "running_variable", "cutoff_value", "treat_above_cutoff", "did_term", "did_canonical", "treatment_time", "treatment_state"]: if key in variables_dict and variables_dict[key] is not None: kwargs_for_method[key] = variables_dict[key] diff --git a/cais/tools/method_validator_tool.py b/cais/tools/method_validator_tool.py index 68f757b..be98c23 100644 --- a/cais/tools/method_validator_tool.py +++ b/cais/tools/method_validator_tool.py @@ -43,6 +43,7 @@ def extract_properties_from_inputs(inputs: MethodValidatorInput) -> Dict[str, An "has_temporal_structure": dataset_analysis_dict.get("temporal_structure", {}).get("has_temporal_structure", False), "frontdoor_criterion": variables_dict.get("frontdoor_criterion", False), "cutoff_value": variables_dict.get("cutoff_value"), + "treat_above_cutoff": 
variables_dict.get("treat_above_cutoff"), "covariate_overlap_score": variables_dict.get("covariate_overlap_result", 0), "is_rct": variables_dict.get("is_rct", False) } diff --git a/cais/utils/agent.py b/cais/utils/agent.py index a9c7337..b96216b 100644 --- a/cais/utils/agent.py +++ b/cais/utils/agent.py @@ -181,7 +181,7 @@ def create_agent_prompt(tools: List[tool]) -> ChatPromptTemplate: **IMPORTANT TOOL USAGE:** 1. **Action Input Format:** The value for 'Action Input' MUST be a single, valid JSON object string. Do NOT include any other text or formatting around the JSON string. 2. **Argument Gathering:** You MUST gather ALL required arguments for the Action Input JSON from the initial Human input AND the 'Observation' outputs of PREVIOUS steps. Look carefully at the required arguments for the tool you are calling. -3. **Data Handoff:** The 'Observation' from a previous step often contains structured data needed by the next tool. For example, the 'variables' output from `query_interpreter_tool` contains fields like `treatment_variable`, `outcome_variable`, `covariates`, `time_variable`, `instrument_variable`, `running_variable`, `cutoff_value`, and `is_rct`. When calling `method_selector_tool`, you MUST construct its required `variables` input argument by including **ALL** these relevant fields identified by the `query_interpreter_tool` in the previous Observation. Similarly, pass the full `dataset_analysis`, `dataset_description`, and `original_query` when required by the next tool. +3. **Data Handoff:** The 'Observation' from a previous step often contains structured data needed by the next tool. For example, the 'variables' output from `query_interpreter_tool` contains fields like `treatment_variable`, `outcome_variable`, `covariates`, `time_variable`, `instrument_variable`, `running_variable`, `cutoff_value`, `treat_above_cutoff`, and `is_rct`. 
When calling `method_selector_tool`, you MUST construct its required `variables` input argument by including **ALL** these relevant fields identified by the `query_interpreter_tool` in the previous Observation. Similarly, pass the full `dataset_analysis`, `dataset_description`, and `original_query` when required by the next tool. IMPORTANT WORKFLOW: ------------------- diff --git a/tests/cais/components/test_decision_tree_llm.py b/tests/cais/components/test_decision_tree_llm.py index acb2232..399ee01 100644 --- a/tests/cais/components/test_decision_tree_llm.py +++ b/tests/cais/components/test_decision_tree_llm.py @@ -111,6 +111,7 @@ def test_select_method_observational_running_var_llm_selects_rdd(self): rdd_variables["instrument_variable"] = None # Make IV less likely rdd_variables["running_variable"] = "age" rdd_variables["cutoff_value"] = 65 + rdd_variables["treat_above_cutoff"] = True self._create_mock_llm_response({ "selected_method": REGRESSION_DISCONTINUITY, diff --git a/tests/cais/methods/regression_discontinuity/test_rdd_estimator.py b/tests/cais/methods/regression_discontinuity/test_rdd_estimator.py index 1448e18..f097361 100644 --- a/tests/cais/methods/regression_discontinuity/test_rdd_estimator.py +++ b/tests/cais/methods/regression_discontinuity/test_rdd_estimator.py @@ -35,6 +35,34 @@ def sample_rdd_data(): return df +@pytest.fixture +def sample_rdd_data_treatment_below_cutoff(): + """Generates synthetic data suitable for RDD testing with treatment assigned below cutoff.""" + np.random.seed(123) + n_samples = 200 + cutoff = 50.0 + treatment_effect = 10.0 + + # Running variable centered around cutoff + running_var = np.random.uniform(cutoff - 20, cutoff + 20, n_samples) + # Treatment assigned below cutoff + treatment = (running_var < cutoff).astype(int) + # Covariate correlated with running variable + covariate1 = 0.5 * running_var + np.random.normal(0, 5, n_samples) + # Outcome depends on running var (parallel slopes), treatment, and covariate + error 
= np.random.normal(0, 5, n_samples) + outcome = (10 + 0.8 * running_var + + treatment_effect * treatment + + 2.0 * covariate1 + error) + + df = pd.DataFrame({ + 'outcome': outcome, + 'treatment_indicator': treatment, + 'running_var': running_var, + 'covariate1': covariate1 + }) + return df + # --- Test Cases --- def test_estimate_effect_missing_args(sample_rdd_data): @@ -181,3 +209,43 @@ def test_estimate_effect_no_data_in_bandwidth(sample_rdd_data): cutoff_value=50.0, bandwidth=0.01, # Extremely small bandwidth ) + + +@patch('cais.methods.regression_discontinuity.estimator.run_rdd_diagnostics') +@patch('cais.methods.regression_discontinuity.estimator.interpret_rdd_results') +@patch('cais.methods.regression_discontinuity.estimator.effect_estimate_rdd') +def test_estimate_effect_primary_success_treatment_below_cutoff(mock_em_rdd, mock_interpret, mock_diagnostics, sample_rdd_data_treatment_below_cutoff): + """Test successful estimation using the mocked evan-magnusson/rdd path with treatment assigned below cutoff.""" + mock_em_rdd.return_value = { + 'effect_estimate': 10.5, + 'standard_error': 1.25, + 'p_value': 0.01, + 'confidence_interval': [8.0, 13.0], + 'method_details': 'RDD (evan-magnusson/rdd package, Bandwidth: 5.0000)', + 'bandwidth_used': 5.0, + 'formula': 'local linear', + 'model_summary': 'summary' + } + mock_diagnostics.return_value = {"status": "Success"} + mock_interpret.return_value = "LLM Interpretation" + + results = estimate_effect( + sample_rdd_data_treatment_below_cutoff, + 'treatment_indicator', + 'outcome', + running_variable='running_var', + cutoff_value=50.0, + bandwidth=5.0, + treat_above_cutoff=False, + ) + + mock_em_rdd.assert_called_once() + assert results['method_used'] == 'evan-magnusson/rdd' + assert results['effect_estimate'] == 10.5 + assert results['p_value'] == 0.01 + assert results['confidence_interval'] == [8.0, 13.0] + assert results['standard_error'] == 1.25 + assert 'diagnostics' in results + assert 'interpretation' in results 
+ mock_diagnostics.assert_called_once() + mock_interpret.assert_called_once() diff --git a/tests/cais/test_components/test_query_interpreter.py b/tests/cais/test_components/test_query_interpreter.py index 8ec9045..61360ab 100644 --- a/tests/cais/test_components/test_query_interpreter.py +++ b/tests/cais/test_components/test_query_interpreter.py @@ -72,7 +72,7 @@ def mock_llm_call_router(*args, **kwargs): elif pydantic_model_passed == LLMIVars: return MagicMock(instrument_variable=None) elif pydantic_model_passed == LLMRDDVars: - return MagicMock(running_variable=None, cutoff_value=None) + return MagicMock(running_variable=None, cutoff_value=None, treat_above_cutoff=None) elif pydantic_model_passed == LLMRCTCheck: return MagicMock(is_rct=False, reasoning="No indication of RCT.") elif pydantic_model_passed == LLMInteractionSuggestion: