Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cais/components/dataset_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,13 @@ def _run_script_text(script: str, dataset_path: str, cleaned_path: str) -> Tuple

# ---------- main pipeline blocks ----------

def _plan_transformation_spec(llm, dataset_path: str, causal_method: str, causal_query: str, variables: Dict[str, Any]) -> Dict[str, Any]:
def _plan_transformation_spec(llm, dataset_path: str, causal_method: str, causal_query: str, variables: Dict[str, Any], dataset_description: Optional[str]) -> Dict[str, Any]:
prof = _profile_dataset(dataset_path)

human = {
"dataset_path": dataset_path,
"dataset_profile": prof["profile"],
"dataset_description": dataset_description or "No dataset description provided",
"causal_method": causal_method,
"causal_query": causal_query or "",
"variables": variables
Expand Down Expand Up @@ -262,7 +263,7 @@ def run_cleaning_stage(dataset_path: str,

# 1) PLAN
method = causal_method or variables.get("method") or ""
spec = _plan_transformation_spec(llm, dataset_path, method, original_query or "", variables)
spec = _plan_transformation_spec(llm, dataset_path, method, original_query or "", variables, dataset_description)
#print(spec)

# 2) CODEGEN
Expand Down
2 changes: 2 additions & 0 deletions cais/components/decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def rule_based_select_method(dataset_analysis, variables, is_rct, llm, dataset_d
"has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False),
"frontdoor_criterion": variables.get("frontdoor_criterion", False),
"cutoff_value": variables.get("cutoff_value"),
"treat_above_cutoff": variables.get("treat_above_cutoff"),
"covariate_overlap_score": variables.get("covariate_overlap_result", 0)}

properties["is_rct"] = is_rct
Expand Down Expand Up @@ -404,6 +405,7 @@ def select_method(self, df: pd.DataFrame, treatment: str, outcome: str, covariat
"instrument_variable": query_details.get("instrument_variable"),
"running_variable": query_details.get("running_variable"),
"cutoff_value": query_details.get("cutoff_value"),
"treat_above_cutoff": query_details.get("treat_above_cutoff"),
"is_rct": query_details.get("is_rct", False),
"has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False),
"frontdoor_criterion": query_details.get("frontdoor_criterion", False),
Expand Down
1 change: 1 addition & 0 deletions cais/components/explanation_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ def explain_application(method: str, treatment: str, outcome: str,
f"I will focus on observations close to the cutoff value "
f"({variables.get('cutoff_value')}) of the running variable "
f"({variables.get('running_variable')}), where treatment assignment changes. "
f"Is the treatment assigned above the cutoff? Answer: {variables.get('treat_above_cutoff')} (False means below the cutoff, True means above the cutoff, None means not determined). "
f"By comparing outcomes just above and below this threshold, I can estimate "
f"the local causal effect of {treatment} on {outcome}."
),
Expand Down
16 changes: 11 additions & 5 deletions cais/components/method_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@
from cais.config import get_llm_client


def rdd_design_compliance(df: pd.DataFrame, running_variable: str, treatment: str, cutoff_value: float) -> dict:
def rdd_design_compliance(df: pd.DataFrame, running_variable: str, treatment: str, cutoff_value: float, treat_above_cutoff: Optional[bool] = None) -> dict:
"""Check whether treatment assignment closely follows the cutoff rule T ≈ 1{X >= c}."""
try:
above = df[running_variable] >= cutoff_value
if treat_above_cutoff is None:
treat_above_cutoff = True # Default to above cutoff
if treat_above_cutoff:
mask = df[running_variable] >= cutoff_value
else:
mask = df[running_variable] <= cutoff_value
if treatment not in df.columns:
return {"ok": False, "reason": f"Treatment column '{treatment}' not found."}
t = df[treatment]
# Compliance: fraction of rows where T matches the cutoff rule
compliance = (t == above.astype(t.dtype)).mean()
compliance = (t == mask.astype(t.dtype)).mean()
# Allow for fuzzy RDD — flag as ok if compliance >= 0.75
return {"ok": float(compliance) >= 0.75, "compliance_rate": float(compliance)}
except Exception as e:
Expand Down Expand Up @@ -405,6 +410,7 @@ def validate_regression_discontinuity(validation_result: Dict[str, Any],
cutoff_value = variables.get("cutoff_value")
treatment = variables.get("treatment_variable")
outcome = variables.get("outcome_variable")
treat_above_cutoff = variables.get("treat_above_cutoff")
df = pd.read_csv(dataset_analysis['dataset_info']['file_path'])

# Required fields
Expand All @@ -431,7 +437,7 @@ def validate_regression_discontinuity(validation_result: Dict[str, Any],
return

# 1) Enforced-by-design check: is treatment determined by cutoff?
design = rdd_design_compliance(df, running_variable, treatment, cutoff_value)
design = rdd_design_compliance(df, running_variable, treatment, cutoff_value, treat_above_cutoff)
validation_result.setdefault("evidence", {})["rdd_design"] = design
if not design.get("ok", False):
validation_result["concerns"].append(
Expand Down Expand Up @@ -462,7 +468,7 @@ def validate_regression_discontinuity(validation_result: Dict[str, Any],
# 3) Assumption status (per text)
validation_result.setdefault("evidence", {})["rdd_notes"] = {
"assumption_status": "Untestable; assess visually around cutoff for an abrupt jump.",
"design_enforcement": "Treatment is (or should be) determined by a cutoff variable.",
"design_enforcement": "Treatment is (or should be) determined by a cutoff variable. Whether the treatment is assigned above or below the cutoff is determined by the user query or dataset description.",
"visual_recommendation": "Plot outcome vs running within a symmetric window around the cutoff; inspect for a jump."
}

Expand Down
6 changes: 5 additions & 1 deletion cais/components/query_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def interpret_query(query_info: Dict[str, Any], dataset_analysis: Dict[str, Any]
instrument_variable = None
running_variable = None
cutoff_value = None
treat_above_cutoff = None
is_rct = None
smd_score = None

Expand All @@ -313,10 +314,12 @@ def interpret_query(query_info: Dict[str, Any], dataset_analysis: Dict[str, Any]
if rdd_result:
running_variable = rdd_result.running_variable
cutoff_value = rdd_result.cutoff_value
treat_above_cutoff = rdd_result.treat_above_cutoff
if running_variable not in columns or cutoff_value is None:
running_variable = None
cutoff_value = None
logger.info(f"LLM identified RDD: Running={running_variable}, Cutoff={cutoff_value}")
treat_above_cutoff = None
logger.info(f"LLM identified RDD: Running={running_variable}, Cutoff={cutoff_value}, Treat Above Cutoff={treat_above_cutoff}")

## For graph based methods
exclude_cols = [treatment_variable, outcome_variable]
Expand Down Expand Up @@ -377,6 +380,7 @@ def interpret_query(query_info: Dict[str, Any], dataset_analysis: Dict[str, Any]
## for rdd
"running_variable": running_variable,
"cutoff_value": cutoff_value,
"treat_above_cutoff": treat_above_cutoff,
## for rct
"is_rct": is_rct,
"treatment_reference_level": treatment_reference_level,
Expand Down
51 changes: 36 additions & 15 deletions cais/methods/regression_discontinuity/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,15 @@ def estimate_effect(self, df, variables, query=None):
hasattr(variables, 'treatment_variable') and
hasattr(variables, 'outcome_variable') and
hasattr(variables, 'running_variable') and
hasattr(variables, 'cutoff_value')
hasattr(variables, 'cutoff_value') and
hasattr(variables, 'treat_above_cutoff')
)

treatment = variables.treatment_variable
outcome = variables.outcome_variable
running_var = variables.running_variable
cutoff = variables.cutoff_value
treat_above_cutoff = variables.treat_above_cutoff
covariates = variables.confounders
covariates = covariates if covariates else []

Expand All @@ -68,7 +70,8 @@ def estimate_effect(self, df, variables, query=None):
outcome=outcome,
running_variable=running_var,
cutoff_value=cutoff,
covariates=covariates
covariates=covariates,
treat_above_cutoff=treat_above_cutoff,
)
except Exception as e:
raise ValueError(f"Couldn't calculate effect using RDD: {e}")
Expand Down Expand Up @@ -114,11 +117,13 @@ def _call_llm_for_var(llm: BaseChatModel, prompt: str, pydantic_model: BaseModel
if rdd_result:
running_variable = rdd_result.running_variable
cutoff_value = rdd_result.cutoff_value
treat_above_cutoff = rdd_result.treat_above_cutoff
if running_variable not in columns or cutoff_value is None:
running_variable = None
cutoff_value = None
logger.info(f"LLM identified RDD: Running={running_variable}, Cutoff={cutoff_value}")
return cutoff_value
treat_above_cutoff = None
logger.info(f"LLM identified RDD: Running={running_variable}, Cutoff={cutoff_value}, Treat Above Cutoff={treat_above_cutoff}")
return cutoff_value, treat_above_cutoff



Expand All @@ -138,7 +143,7 @@ def _call_llm_for_var(llm: BaseChatModel, prompt: str, pydantic_model: BaseModel
_rdd_em_import_error_message = f"An unexpected error occurred during import from evan-magnusson/rdd: {e}"
logger.warning(_rdd_em_import_error_message)

def estimate_effect_fallback(df: pd.DataFrame, treatment: str, outcome: str, running_variable: str, cutoff_value: float, covariates: Optional[List[str]], **kwargs) -> Dict[str, Any]:
def estimate_effect_fallback(df: pd.DataFrame, treatment: str, outcome: str, running_variable: str, cutoff_value: float, covariates: Optional[List[str]], treat_above_cutoff: Optional[bool] = None, **kwargs) -> Dict[str, Any]:
"""Estimate RDD effect using simple linear regression comparison fallback."""
logger.warning("Main RDD estimation failed. Using fallback simple linear regression comparison.")
if covariates:
Expand Down Expand Up @@ -202,8 +207,14 @@ def estimate_effect_fallback(df: pd.DataFrame, treatment: str, outcome: str, run

# The coefficient for 'above_cutoff' represents the jump at the cutoff
effect = results.params['above_cutoff']
if not treat_above_cutoff:
effect = effect * -1
p_value = results.pvalues['above_cutoff']
conf_int = results.conf_int().loc['above_cutoff'].tolist()
L, U = results.conf_int().loc['above_cutoff'].tolist()
if not treat_above_cutoff:
L, U = -U, -L
conf_int = [L, U]

std_err = results.bse['above_cutoff']

return {
Expand All @@ -218,7 +229,7 @@ def estimate_effect_fallback(df: pd.DataFrame, treatment: str, outcome: str, run

def effect_estimate_rdd(df: pd.DataFrame, outcome: str, running_variable: str, cutoff_value: float,
treatment: Optional[str] = None, covariates: Optional[List[str]] = None,
bandwidth: Optional[float] = None, **kwargs) -> Dict[str, Any]:
bandwidth: Optional[float] = None, treat_above_cutoff: Optional[bool] = None, **kwargs) -> Dict[str, Any]:
"""
Estimates RDD effect using the 'evan-magnusson/rdd' package.
Uses IK optimal bandwidth selection from the same package by default.
Expand Down Expand Up @@ -290,12 +301,16 @@ def effect_estimate_rdd(df: pd.DataFrame, outcome: str, running_variable: str, c

# Extract results - using 'TREATED' based on the provided summary output
effect = sm_results.params.get('TREATED')
if not treat_above_cutoff:
effect = effect * -1
std_err = sm_results.bse.get('TREATED')
p_value = sm_results.pvalues.get('TREATED')

conf_int_series = sm_results.conf_int()
conf_int = conf_int_series.loc['TREATED'].tolist() if 'TREATED' in conf_int_series.index else [None, None]

if not treat_above_cutoff:
if 'TREATED' in conf_int_series.index:
conf_int = [conf_int[1] * -1, conf_int[0] * -1]
n_obs = model.nobs # or model.n_ if nobs is not available (check package details)

# The formula is implicit in the local linear regression performed by the package
Expand Down Expand Up @@ -323,7 +338,7 @@ def effect_estimate_rdd(df: pd.DataFrame, outcome: str, running_variable: str, c
def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str, running_variable: str,
cutoff_value: float, covariates: Optional[List[str]] = None,
bandwidth: Optional[float] = None, query: Optional[str] = None,
llm: Optional[BaseChatModel] = None, **kwargs) -> Dict[str, Any]:
llm: Optional[BaseChatModel] = None, treat_above_cutoff: Optional[bool] = None, **kwargs) -> Dict[str, Any]:
"""
Estimates the causal effect using Regression Discontinuity Design.

Expand All @@ -341,22 +356,27 @@ def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str, running_vari
bandwidth: Optional bandwidth around the cutoff. If None, a default is used.
query: Optional user query for context.
llm: Optional Language Model instance.
treat_above_cutoff: True if the treatment is assigned above the cutoff, False if below the cutoff.
**kwargs: Additional keyword arguments for underlying methods.

Returns:
Dictionary containing estimation results.
"""
required_args = {
"running_variable": running_variable,
"cutoff_value": cutoff_value
"cutoff_value": cutoff_value,
}
if any(val is None for val in required_args.values()):
raise ValueError(f"Missing required RDD arguments: running_variable and cutoff_value must be provided.")
raise ValueError(f"Missing required RDD arguments: running_variable, cutoff_value must be provided.")

if treat_above_cutoff is None:
logger.warning("`treat_above_cutoff` is not provided. Assuming treatment is assigned above the cutoff.")
treat_above_cutoff = True

results = {}
rdd_em_estimation_error = None # Error from effect_estimate_rdd (evan-magnusson)
fallback_estimation_error = None # Error from estimate_effect_fallback

# --- Try effect_estimate_rdd (evan-magnusson/rdd) First ---
try:
logger.info("Attempting RDD estimation using 'effect_estimate_rdd' (evan-magnusson/rdd package).")
Expand All @@ -368,7 +388,8 @@ def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str, running_vari
cutoff_value,
treatment=treatment, # For API consistency, though evan-magnusson/rdd doesn't use it explicitly
covariates=covariates,
bandwidth=bandwidth,
bandwidth=bandwidth,
treat_above_cutoff=treat_above_cutoff,
**kwargs
)
results['method_used'] = 'evan-magnusson/rdd' # Ensure method_used is set
Expand All @@ -384,7 +405,7 @@ def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str, running_vari
if not results: # If effect_estimate_rdd wasn't used or failed
logger.info("'effect_estimate_rdd' did not produce results. Attempting fallback using 'estimate_effect_fallback'.")
try:
fallback_results = estimate_effect_fallback(df, treatment, outcome, running_variable, cutoff_value, covariates, bandwidth=bandwidth, **kwargs)
fallback_results = estimate_effect_fallback(df, treatment, outcome, running_variable, cutoff_value, covariates, bandwidth=bandwidth, treat_above_cutoff=treat_above_cutoff, **kwargs)
results.update(fallback_results)
results['method_used'] = 'Fallback RDD (Linear Interaction with Robust Errors)'
fallback_estimation_error = None # Clear fallback error if it succeeded
Expand Down
3 changes: 3 additions & 0 deletions cais/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class LLMRDDVars(BaseModel):
"""Pydantic model for identifying RDD variables."""
running_variable: Optional[str] = Field(None, description="The identified running variable column name.")
cutoff_value: Optional[Union[float, int]] = Field(None, description="The identified cutoff value.")
treat_above_cutoff: Optional[bool] = Field(None, description="True if the treatment is assigned above the cutoff, False if below the cutoff, None if unsure.")

class LLMRCTCheck(BaseModel):
"""Pydantic model for checking if data is RCT."""
Expand Down Expand Up @@ -100,6 +101,7 @@ class QueryInfo(BaseModel):
instrument_hints: Optional[List[str]] = None
running_variable_hints: Optional[List[str]] = None
cutoff_value_hint: Optional[Union[float, int]] = None
treat_above_cutoff_hint: Optional[bool] = None

class QueryInterpreterInput(BaseModel):
"""Input structure for the query interpreter tool."""
Expand All @@ -124,6 +126,7 @@ class Variables(BaseModel):
treatment_state: Optional[str] = None
running_variable: Optional[str] = None
cutoff_value: Optional[Union[float, int]] = None
treat_above_cutoff: Optional[bool] = None
is_rct: Optional[bool] = Field(False, description="Flag indicating if the dataset is from an RCT.")
treatment_reference_level: Optional[Union[float, str]] = Field(None, description="The specified reference/control level for a multi-valued treatment variable.")
interaction_term_suggested: Optional[bool] = Field(False, description="Whether the query or context suggests an interaction term with the treatment might be relevant.")
Expand Down
10 changes: 8 additions & 2 deletions cais/prompts/method_identification_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
You need to identify if a running variable exists for performing Regression Discontinuity Design (RDD) to answer the user query.
Go through the data description and available columns carefully. You need to be strict with the assessment.
In RDD, treatment assignment (for analysis) is determined by whether a continuous variable crosses a specific threshold.
Whether the treatment is assigned above or below the cutoff can be determined by the user query or dataset description.

User Query: "{query}"
Dataset Description: {description}
Expand All @@ -82,15 +83,20 @@
Step 3: Identify the cutoff value from design
- What specific threshold value determines treatment assignment?

Step 4: Final determination
Step 4: Identify if the treatment is assigned above or below the cutoff
- Is the treatment assigned above the cutoff? Store it as a boolean value in 'treat_above_cutoff' in the JSON.
- Store null if you are unsure.

Step 5: Final determination
- Only suggest RDD if both running variable and cutoff value can be identified
- Return null if the assignment mechanism is not threshold-based or you are unsure.

Important: Return only valid JSON. No explanations, reasoning, or markdown formatting.

{{
"running_variable": "column_name_or_null",
"cutoff_value": numeric_value_or_null
"cutoff_value": numeric_value_or_null,
"treat_above_cutoff": true_false_or_null,
}}
"""

Expand Down
2 changes: 1 addition & 1 deletion cais/tools/method_executor_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def method_executor_tool(inputs: MethodExecutorInput, original_query: Optional[s
# Avoid passing the entire variables_dict as estimate_func expects specific args
kwargs_for_method = {}
for key in ["instrument_variable", "time_variable", "group_variable",
"running_variable", "cutoff_value", "did_term", "did_canonical", "treatment_time", "treatment_state"]:
"running_variable", "cutoff_value", "treat_above_cutoff", "did_term", "did_canonical", "treatment_time", "treatment_state"]:
if key in variables_dict and variables_dict[key] is not None:
kwargs_for_method[key] = variables_dict[key]

Expand Down
Loading
Loading