diff --git a/.gitignore b/.gitignore index 68ebdbf4..07ed59bf 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,4 @@ venv-triadmotif/ !/docs/source/img/* env/ website/docs/build +examples/forecaster/deferral_out diff --git a/convokit/__init__.py b/convokit/__init__.py index 9bd3462b..d2a8f4f9 100644 --- a/convokit/__init__.py +++ b/convokit/__init__.py @@ -21,6 +21,7 @@ "classifier": ".classifier", "ranker": ".ranker", "forecaster": ".forecaster", + "decisionpolicy": ".decisionpolicy", "fighting_words": ".fighting_words", "paired_prediction": ".paired_prediction", "bag_of_words": ".bag_of_words", diff --git a/convokit/decisionpolicy/__init__.py b/convokit/decisionpolicy/__init__.py new file mode 100644 index 00000000..43e3ab9b --- /dev/null +++ b/convokit/decisionpolicy/__init__.py @@ -0,0 +1,3 @@ +from .decisionPolicy import * +from .thresholdDecisionPolicy import * +from .deferralDecisionPolicy import * diff --git a/convokit/decisionpolicy/decisionPolicy.py b/convokit/decisionpolicy/decisionPolicy.py new file mode 100644 index 00000000..eff7ac75 --- /dev/null +++ b/convokit/decisionpolicy/decisionPolicy.py @@ -0,0 +1,158 @@ +from abc import ABC, abstractmethod +from typing import Callable, Tuple, Optional, Dict, Any + +import numpy as np +from sklearn.metrics import roc_curve +from tqdm import tqdm + + +class DecisionPolicy(ABC): + """ + Abstract interface for converting a conversational context into an action. + """ + + def __init__( + self, + forecast_prob_attribute_name: str = "forecast_prob", + reuse_cached_forecast_probs: bool = True, + ): + self._labeler = None + # name of the utterance-meta field that may already hold a forecast prob + # from a prior Forecaster.transform() pass. kept in sync with the owning + # ForecasterModel / Forecaster when they are wired up. + self.forecast_prob_attribute_name = forecast_prob_attribute_name + self.reuse_cached_forecast_probs = bool(reuse_cached_forecast_probs) + + @property + def labeler(self): + return self._labeler + + @labeler.setter + def labeler(self, value: Callable): + self._labeler = value + + def _score(self, context, score_fn: Callable) -> float: + # prefer a previously written forecast prob on the current utterance meta + # so policies don't re-invoke the belief estimator on utterances the + # forecaster has already transformed. synthetic / simulated utterances + # carry an empty meta and always fall through to score_fn. + if self.reuse_cached_forecast_probs: + meta = getattr(getattr(context, "current_utterance", None), "meta", None) or {} + cached = meta.get(self.forecast_prob_attribute_name) + if cached is not None: + return float(cached) + return float(score_fn(context)) + + def _fit_with_model_checkpoint_selection(self, val_contexts, score_fn: Callable = None): + if score_fn is None: + return None + forecaster_model = getattr(score_fn, "__self__", None) + if forecaster_model is None: + return None + get_checkpoints = getattr(forecaster_model, "get_checkpoints", None) + load_checkpoint = getattr(forecaster_model, "load_checkpoint", None) + finalize_best_checkpoint_selection = getattr( + forecaster_model, "finalize_best_checkpoint_selection", None + ) + if not callable(get_checkpoints) or not callable(load_checkpoint): + return None + + checkpoints = list(get_checkpoints()) + if len(checkpoints) == 0: + return None + + best_config = None + best_checkpoint = None + best_val_accuracy = -1.0 + # while sweeping checkpoints, any cached forecast_prob on utterance meta + # reflects whichever checkpoint's transform() ran last, not the one we + # are currently evaluating. force a fresh score_fn call for each sweep. + prior_reuse_flag = self.reuse_cached_forecast_probs + self.reuse_cached_forecast_probs = False + for checkpoint_name in checkpoints: + load_checkpoint(checkpoint_name) + fit_result = self._fit_threshold_for_loaded_model(val_contexts, score_fn=score_fn) + print(f"accuracy: {checkpoint_name} {fit_result['best_val_accuracy']}") + if fit_result["best_val_accuracy"] > best_val_accuracy: + best_checkpoint = checkpoint_name + best_val_accuracy = fit_result["best_val_accuracy"] + best_config = { + "best_checkpoint": checkpoint_name, + "best_threshold": float(fit_result["best_threshold"]), + "best_val_accuracy": float(fit_result["best_val_accuracy"]), + } + self.reuse_cached_forecast_probs = prior_reuse_flag + + if best_config is None: + return None + + if hasattr(self, "threshold"): + self.threshold = float(best_config["best_threshold"]) + load_checkpoint(best_checkpoint) + if callable(finalize_best_checkpoint_selection): + finalize_best_checkpoint_selection( + best_checkpoint, + best_config, + val_contexts=val_contexts, + score_fn=score_fn, + ) + return best_config + + def _fit_threshold_for_loaded_model(self, val_contexts, score_fn: Callable): + y_true, y_score = self._get_validation_arrays(val_contexts, score_fn) + default_threshold = float(getattr(self, "threshold", 0.5)) + if len(y_true) == 0: + return {"best_threshold": default_threshold, "best_val_accuracy": 0.0} + + try: + _, _, thresholds = roc_curve(y_true, y_score) + except ValueError: + thresholds = np.asarray([default_threshold], dtype=float) + + if len(thresholds) == 0: + thresholds = np.asarray([default_threshold], dtype=float) + + accs = [((y_score > t).astype(int) == y_true).mean() for t in thresholds] + best_idx = int(np.argmax(accs)) + best_threshold = float(thresholds[best_idx]) + return {"best_threshold": best_threshold, "best_val_accuracy": float(accs[best_idx])} + + def _get_validation_arrays(self, val_contexts, score_fn: Callable): + highest_convo_scores = {} + convo_labels = {} + for context in tqdm(val_contexts): + convo_id = context.conversation_id + score = self._score(context, score_fn) + label = int(self.labeler(context.current_utterance.get_conversation())) + if convo_id not in highest_convo_scores: + highest_convo_scores[convo_id] = score + else: + highest_convo_scores[convo_id] = max(highest_convo_scores[convo_id], score) + convo_labels[convo_id] = label + + convo_ids = list(highest_convo_scores.keys()) + y_true = np.asarray([convo_labels[c] for c in convo_ids]) + y_score = np.asarray([highest_convo_scores[c] for c in convo_ids]) + return y_true, y_score + + @abstractmethod + def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]: + """ + Decide whether to intervene for a context. + + :param context: context tuple supplied by Forecaster + :param score_fn: callable that maps a context tuple to a scalar score + :return: tuple containing the score, the integer action label (currently 0/1), and any additional metadata + """ + pass + + @abstractmethod + def fit(self, contexts, val_contexts=None, score_fn: Callable = None): + """ + Fit policy-specific parameters if needed. + + :param contexts: training contexts for policy fitting + :param val_contexts: optional validation contexts + :param score_fn: optional scorer callable exposed by ForecasterModel + """ + pass \ No newline at end of file diff --git a/convokit/decisionpolicy/deferralDecisionPolicy.py b/convokit/decisionpolicy/deferralDecisionPolicy.py new file mode 100644 index 00000000..335eb2ae --- /dev/null +++ b/convokit/decisionpolicy/deferralDecisionPolicy.py @@ -0,0 +1,206 @@ +from typing import Callable, List, Optional, Dict, Any, Tuple + +from .decisionPolicy import DecisionPolicy + + +class _synthetic_speaker: + def __init__(self, speaker_id: str): + self.id = speaker_id + + +class _synthetic_utterance: + def __init__(self, text: str, utterance_id: str, speaker_id: str): + self.text = text + self.id = utterance_id + self.speaker_ = _synthetic_speaker(speaker_id) + self.meta = {} + + def get_conversation(self): + return None + + +class DeferralDecisionPolicy(DecisionPolicy): + """ + Decision policy that defers intervention by looking ahead at simulated next utterances. + + :param simulator: utterance simulator model (must have a ``transform(contexts)`` method + returning a DataFrame indexed by utterance id). if the simulator exposes + ``get_num_simulations()``, ``num_simulations`` is capped to that value. + :param threshold: probability threshold above which a context is flagged. + :param tau: minimum number of simulated branches that must exceed the threshold + before an intervention is issued. + :param num_simulations: how many simulated branches to use per context (capped to + simulator's ``get_num_simulations()`` if available). + :param store_simulations: if True, simulated reply strings are cached during decide() + and written to corpus utterance metadata by post_transform(). + :param simulated_reply_attribute_name: metadata field name used when storing simulations + on corpus utterances (only relevant when store_simulations=True). + :param reuse_cached_simulations: if True (default), simulations already present on the + current utterance's metadata under ``simulated_reply_attribute_name`` are reused + instead of re-invoking the simulator. similarly, cached simulation scores under + ``sim_replies_forecast_probs_attribute_name`` are reused when they align with the + reused simulations, skipping re-scoring. set to False to force regeneration. + """ + + def __init__( + self, + simulator, + threshold, + tau: int = 5, + num_simulations: int = 10, + store_simulations: bool = False, + simulated_reply_attribute_name: str = "sim_replies", + sim_replies_forecast_probs_attribute_name: str = "sim_replies_forecast_probs", + reuse_cached_simulations: bool = True, + forecast_prob_attribute_name: str = "forecast_prob", + reuse_cached_forecast_probs: bool = True, + ): + super().__init__( + forecast_prob_attribute_name=forecast_prob_attribute_name, + reuse_cached_forecast_probs=reuse_cached_forecast_probs, + ) + self.simulator = simulator + self.threshold = float(threshold) + self.tau = int(tau) + n = int(num_simulations) + if simulator is not None and hasattr(simulator, "get_num_simulations"): + n = min(n, int(simulator.get_num_simulations())) + self.num_simulations = n + self.store_simulations = store_simulations + self.simulated_reply_attribute_name = simulated_reply_attribute_name + self.sim_replies_forecast_probs_attribute_name = sim_replies_forecast_probs_attribute_name + self.reuse_cached_simulations = bool(reuse_cached_simulations) + self._sim_cache: dict = {} + self._sim_score_cache: dict = {} + + def _get_utt_meta(self, context): + # unified accessor so both real Utterance and _synthetic_utterance work; returns {} if absent. + return getattr(context.current_utterance, "meta", {}) or {} + + def _get_cached_simulations(self, context) -> Optional[List[str]]: + # returns cached simulation strings for this utterance if available on its metadata, else None. + if not self.reuse_cached_simulations: + return None + meta = self._get_utt_meta(context) + cached = meta.get(self.simulated_reply_attribute_name) + if cached is None: + return None + cached_list = list(cached) + if len(cached_list) == 0: + return None + return cached_list[: self.num_simulations] + + def _get_cached_simulation_scores( + self, context, num_expected: int + ) -> Optional[List[float]]: + # returns cached per-simulation scores aligned with reused simulations, else None. + if not self.reuse_cached_simulations or num_expected == 0: + return None + meta = self._get_utt_meta(context) + cached = meta.get(self.sim_replies_forecast_probs_attribute_name) + if cached is None: + return None + cached_list = list(cached) + # if the cached scores are shorter than the reused simulations, fall back to re-scoring + # rather than silently mixing cached and fresh scores. + if len(cached_list) < num_expected: + return None + return [float(x) for x in cached_list[:num_expected]] + + def get_simulations(self, context, simulator=None) -> List[str]: + # fast path: reuse pre-computed simulations from utterance metadata when present. + cached = self._get_cached_simulations(context) + if cached is not None: + return cached + sim = simulator if simulator is not None else self.simulator + if sim is None or not hasattr(sim, "transform"): + return [] + sims = sim.transform(iter([context])) + utt_id = context.current_utterance.id + if utt_id not in sims.index or sims.shape[1] == 0: + return [] + col_name = sims.columns[0] + return list(sims.loc[utt_id][col_name])[: self.num_simulations] + + def _build_simulated_context(self, context, simulation_text: str, simulation_idx: int): + current_utt = context.current_utterance + synthetic_utt = _synthetic_utterance( + text=simulation_text, + utterance_id=f"{current_utt.id}__sim_{simulation_idx}", + speaker_id="", + ) + new_context_utts = list(context.context) + [synthetic_utt] + context_cls = context.__class__ + return context_cls( + context=new_context_utts, + current_utterance=synthetic_utt, + future_context=None, + conversation_id=context.conversation_id, + ) + + def _decision_score(self, context, score_fn: Callable): + current_score = self._score(context, score_fn) + simulations = self.get_simulations(context) + # the get_simulations method actively checks if cached simulations exist + + # fast path: if cached per-simulation scores align with the reused simulations, + # skip re-scoring the simulated contexts entirely. + cached_scores = self._get_cached_simulation_scores(context, len(simulations)) + if cached_scores is not None: + simulation_scores = cached_scores + else: + simulation_scores = [] + for idx, sim_text in enumerate(simulations): + sim_context = self._build_simulated_context(context, sim_text, idx) + # synthetic utterances have empty meta so _score falls through to score_fn. + simulation_scores.append(self._score(sim_context, score_fn)) + if self.store_simulations and simulations: + utt_id = context.current_utterance.id + self._sim_cache[utt_id] = simulations + self._sim_score_cache[utt_id] = simulation_scores + return current_score, simulations, simulation_scores + + def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]: + max_defer_index = 4 + decision_score, simulations, simulation_scores = self._decision_score(context, score_fn) + num_simulations_above_threshold = sum(1 for score in simulation_scores if score > self.threshold) + num_simulations = len(simulations) + # context.context contains chronological_utts[: i+1] (includes current_utterance), + # so the current utterance's position in the conversation is len(context.context) - 1. + utt_index = max(0, len(getattr(context, "context", []) or []) - 1) + # past the deferral window we always commit when fp > threshold, mirroring the + # `i < 4` early-only deferral in performance_utils.no_tricks. + past_defer_window = max_defer_index is not None and utt_index >= max_defer_index + defer_eligible = not past_defer_window + num_calm = num_simulations - num_simulations_above_threshold + # defer = defer_eligible and (num_calm > self.tau) + defer = (num_calm > self.tau) + return ( + decision_score, + 1 if decision_score > self.threshold and not defer else 0, + { + self.simulated_reply_attribute_name: simulations, + self.sim_replies_forecast_probs_attribute_name: simulation_scores, + }, + ) + + def fit(self, contexts, val_contexts=None, score_fn: Callable = None): + if val_contexts is None or score_fn is None or self.labeler is None: + print("either no validation contexts/score function/labeler were provided, returning current threshold") + return {"best_threshold": self.threshold} + + val_contexts = list(val_contexts) + if len(val_contexts) == 0: + print("no validation contexts were provided, returning current threshold") + return {"best_threshold": self.threshold} + + fit_result = self._fit_with_model_checkpoint_selection(val_contexts, score_fn=score_fn) + if isinstance(fit_result, dict): + if "best_threshold" in fit_result: + self.threshold = float(fit_result["best_threshold"]) + return fit_result + + fit_result = self._fit_threshold_for_loaded_model(val_contexts, score_fn=score_fn) + if "best_threshold" in fit_result: + self.threshold = float(fit_result["best_threshold"]) + return fit_result diff --git a/convokit/decisionpolicy/thresholdDecisionPolicy.py b/convokit/decisionpolicy/thresholdDecisionPolicy.py new file mode 100644 index 00000000..16df7ff2 --- /dev/null +++ b/convokit/decisionpolicy/thresholdDecisionPolicy.py @@ -0,0 +1,46 @@ +from typing import Callable, Tuple + +from .decisionPolicy import DecisionPolicy + + +class ThresholdDecisionPolicy(DecisionPolicy): + """ + A simple decision policy that predicts 1 when score > threshold. + """ + + def __init__( + self, + threshold: float = 0.5, + forecast_prob_attribute_name: str = "forecast_prob", + reuse_cached_forecast_probs: bool = True, + ): + super().__init__( + forecast_prob_attribute_name=forecast_prob_attribute_name, + reuse_cached_forecast_probs=reuse_cached_forecast_probs, + ) + self.threshold = float(threshold) + + def decide(self, context, score_fn: Callable) -> Tuple[float, int]: + score = self._score(context, score_fn) + return score, int(score > self.threshold) + + def fit(self, contexts, val_contexts=None, score_fn: Callable = None): + if val_contexts is None or score_fn is None or self.labeler is None: + print("either no validation contexts/score function/labeler were provided, returning current threshold") + return {"best_threshold": self.threshold} + + val_contexts = list(val_contexts) + if len(val_contexts) == 0: + print("no validation contexts were provided, returning current threshold") + return {"best_threshold": self.threshold} + + fit_result = self._fit_with_model_checkpoint_selection(val_contexts, score_fn=score_fn) + if isinstance(fit_result, dict): + if "best_threshold" in fit_result: + self.threshold = float(fit_result["best_threshold"]) + return fit_result + + fit_result = self._fit_threshold_for_loaded_model(val_contexts, score_fn=score_fn) + if "best_threshold" in fit_result: + self.threshold = float(fit_result["best_threshold"]) + return fit_result diff --git a/convokit/forecaster/CRAFTModel.py b/convokit/forecaster/CRAFTModel.py index b0937e7f..80aa0b35 100644 --- a/convokit/forecaster/CRAFTModel.py +++ b/convokit/forecaster/CRAFTModel.py @@ -10,13 +10,12 @@ from convokit import download, warn from convokit.convokitConfig import ConvoKitConfig from .CRAFT.model import EncoderRNN, ContextEncoderRNN, SingleTargetClf -from .CRAFT.runners import Predictor, trainIters, evaluateDataset +from .CRAFT.runners import Predictor, trainIters, evaluateBatch from .forecasterModel import ForecasterModel -import numpy as np -import torch.nn.functional as F from torch import optim, nn from typing import Dict, Union import os +from convokit.decisionpolicy import ThresholdDecisionPolicy # parameters baked into the model design (because the provided models were saved with these parameters); # these cannot be changed by the user @@ -89,8 +88,9 @@ def __init__( decision_threshold: Union[float, str] = "auto", torch_device: str = "cpu", config: dict = DEFAULT_CONFIG, + decision_policy=None, ): - super().__init__() + super().__init__(decision_policy=decision_policy) # load the initial weights and store this as the current model if initial_weights in MODEL_FILENAME_MAP: @@ -131,18 +131,36 @@ def __init__( raise TypeError("CRAFTModel: decision_threshold must be either a float or 'auto'") self._decision_threshold = DECISION_THRESHOLDS.get(initial_weights, 0.5) + if isinstance(self.decision_policy, ThresholdDecisionPolicy): + self.decision_policy.threshold = float(self._decision_threshold) + self._device = torch.device(torch_device) self._config = config + self._inference_components = None + + @property + def best_threshold(self): + if hasattr(self.decision_policy, "threshold"): + return self.decision_policy.threshold + return None - def _context_to_craft_data(self, contexts): + @best_threshold.setter + def best_threshold(self, value): + if hasattr(self.decision_policy, "threshold"): + self.decision_policy.threshold = float(value) + + def _context_to_craft_data(self, contexts, include_labels=True): """ Convert context utterances to a list of token-lists using the model's vocabulary object, maintaining the original temporal ordering """ pairs = [] for context in contexts: - convo = context.current_utterance.get_conversation() - label = self.labeler(convo) + if include_labels: + convo = context.current_utterance.get_conversation() + label = self.labeler(convo) + else: + label = 0 processed_context = processContext(self._voc, context, label) utt = processed_context[-1]["tokens"][: (MAX_LENGTH - 1)] context_utts = [u["tokens"][: (MAX_LENGTH - 1)] for u in processed_context] @@ -188,7 +206,17 @@ def _init_craft(self): return embedding, encoder, context_encoder, attack_clf - def fit(self, contexts, val_contexts=None): + def _get_inference_components(self): + if self._inference_components is None: + embedding, encoder, context_encoder, attack_clf = self._init_craft() + encoder.eval() + context_encoder.eval() + attack_clf.eval() + predictor = Predictor(encoder, context_encoder, attack_clf) + self._inference_components = (encoder, context_encoder, predictor) + return self._inference_components + + def fit_belief_estimator(self, contexts, val_contexts=None): """ Fine-tune the CRAFT model, and save the best model according to validation performance. @@ -196,12 +224,12 @@ def fit(self, contexts, val_contexts=None): :param val_contexts: an iterator over context tuples to be used only for validation. IMPORTANT: this is marked Optional only for compatibility with the generic Forecaster API; CRAFT actually REQUIRES a validation set so leaving this parameter at None will raise an error! """ # convert the input contexts into CRAFT's data format - train_pairs = self._context_to_craft_data(contexts) + train_pairs = self._context_to_craft_data(contexts, include_labels=True) print("Processed", len(train_pairs), "context tuples for model training") # val_contexts is made Optional to conform to the Forecaster spec, but in reality CRAFT requires a validation set if val_contexts is None: raise ValueError("CRAFTModel requires a validation set!") - val_pairs = self._context_to_craft_data(val_contexts) + val_pairs = self._context_to_craft_data(val_contexts, include_labels=True) print("Processed", len(val_pairs), "context tuples for model validation") # initialize the CRAFT model with whatever weights we currently have saved @@ -252,6 +280,50 @@ def fit(self, contexts, val_contexts=None): # save the resulting checkpoints so we can load them later during transform self._model = best_model + self._inference_components = None + + def fit_decision_policy(self, contexts, val_contexts=None): + return super().fit_decision_policy(contexts, val_contexts) + + def fit(self, contexts, val_contexts=None): + return super().fit(contexts, val_contexts) + + def score(self, context) -> float: + encoder, context_encoder, predictor = self._get_inference_components() + score_pairs = self._context_to_craft_data([context], include_labels=False) + batch, batch_dialogs, _, true_batch_size = next( + batchIterator(self._voc, score_pairs, batch_size=1, shuffle=False) + ) + ( + input_variable, + dialog_lengths, + utt_lengths, + batch_indices, + dialog_indices, + labels, + convo_ids, + target_variable, + mask, + max_target_len, + ) = batch + dialog_lengths_list = [len(x) for x in batch_dialogs] + _, scores = evaluateBatch( + encoder, + context_encoder, + predictor, + self._voc, + input_variable, + dialog_lengths, + dialog_lengths_list, + utt_lengths, + batch_indices, + dialog_indices, + true_batch_size, + self._device, + MAX_LENGTH, + threshold=self.best_threshold if self.best_threshold is not None else 0.5, + ) + return float(scores[0].item()) def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): """ @@ -264,34 +336,83 @@ def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_n :return: a Pandas DataFrame, with one row for each context, indexed by the ID of that context's current utterance. Contains two columns, one with raw probabilities named according to forecast_prob_attribute_name, and one with discretized (binary) forecasts named according to forecast_attribute_name """ # convert the input contexts into CRAFT's data format - test_pairs = self._context_to_craft_data(contexts) + contexts = list(contexts) + context_by_utt_id = {context.current_utterance.id: context for context in contexts} + test_pairs = self._context_to_craft_data(contexts, include_labels=False) print("Processed", len(test_pairs), "context tuples for model evaluation") # initialize the CRAFT model with whatever weights we currently have saved - embedding, encoder, context_encoder, attack_clf = self._init_craft() - - # Set dropout layers to eval mode - encoder.eval() - context_encoder.eval() - attack_clf.eval() + encoder, context_encoder, predictor = self._get_inference_components() - # Initialize the pipeline - predictor = Predictor(encoder, context_encoder, attack_clf) - - # Run the pipeline! - forecasts_df = evaluateDataset( - test_pairs, - encoder, - context_encoder, - predictor, - self._voc, - self._config["batch_size"], - self._device, - MAX_LENGTH, - batchIterator, - self._decision_threshold, - forecast_attribute_name, - forecast_prob_attribute_name, + base_columns = {"id", forecast_attribute_name, forecast_prob_attribute_name} + output_df = {"id": [], forecast_attribute_name: [], forecast_prob_attribute_name: []} + batch_iterator = batchIterator( + self._voc, test_pairs, self._config["batch_size"], shuffle=False + ) + n_iters = len(test_pairs) // self._config["batch_size"] + int( + len(test_pairs) % self._config["batch_size"] > 0 ) + for iteration in range(1, n_iters + 1): + batch, batch_dialogs, _, true_batch_size = next(batch_iterator) + ( + input_variable, + dialog_lengths, + utt_lengths, + batch_indices, + dialog_indices, + labels, + convo_ids, + target_variable, + mask, + max_target_len, + ) = batch + dialog_lengths_list = [len(x) for x in batch_dialogs] + _, scores = evaluateBatch( + encoder, + context_encoder, + predictor, + self._voc, + input_variable, + dialog_lengths, + dialog_lengths_list, + utt_lengths, + batch_indices, + dialog_indices, + true_batch_size, + self._device, + MAX_LENGTH, + threshold=self.best_threshold if self.best_threshold is not None else 0.5, + ) + for i in range(true_batch_size): + score = float(scores[i].item()) + utt_id = convo_ids[i] + context = context_by_utt_id[utt_id] + + def score_fn(scored_context): + scored_utt_id = scored_context.current_utterance.id + if scored_utt_id == utt_id: + return score + return self.score(scored_context) + + utt_score, pred, utt_metadata = self._parse_decision_result( + self.decision_policy.decide(context, score_fn) + ) + current_idx = len(output_df["id"]) + output_df["id"].append(utt_id) + output_df[forecast_attribute_name].append(int(pred)) + output_df[forecast_prob_attribute_name].append(utt_score) + existing_metadata_keys = [key for key in output_df if key not in base_columns] + for key in existing_metadata_keys: + output_df[key].append(utt_metadata.get(key, None)) + for key, value in utt_metadata.items(): + if key not in output_df: + output_df[key] = [None] * current_idx + output_df[key].append(value) + print( + "Iteration: {}; Percent complete: {:.1f}%".format( + iteration, iteration / n_iters * 100 + ) + ) + forecasts_df = pd.DataFrame(output_df).set_index("id") return forecasts_df diff --git a/convokit/forecaster/TransformerDecoderModel.py b/convokit/forecaster/TransformerDecoderModel.py index a8d1813e..a7d17782 100644 --- a/convokit/forecaster/TransformerDecoderModel.py +++ b/convokit/forecaster/TransformerDecoderModel.py @@ -1,3 +1,4 @@ +from itertools import tee, islice import unsloth from unsloth import FastLanguageModel, is_bfloat16_supported from unsloth.chat_templates import get_chat_template @@ -5,17 +6,13 @@ import torch.nn.functional as F from trl import SFTTrainer, SFTConfig from datasets import Dataset +from collections import defaultdict -import json import os from tqdm import tqdm import pandas as pd -import numpy as np -from sklearn.metrics import roc_curve from .forecasterModel import ForecasterModel from .TransformerForecasterConfig import TransformerForecasterConfig -import shutil - def _get_template_map(model_name_or_path): """ @@ -72,7 +69,9 @@ def __init__( config=DEFAULT_CONFIG, system_msg=None, question_msg=None, + decision_policy=None, ): + super().__init__(decision_policy=decision_policy) self.max_seq_length = 4_096 * 2 self.model, tokenizer = FastLanguageModel.from_pretrained( model_name=model_name_or_path, @@ -100,7 +99,6 @@ def __init__( "Will the above conversation derail into a personal attack now or at any point in the future? " "Strictly start your answer with Yes or No, otherwise the answer is invalid." ) - self.best_threshold = 0.5 if not os.path.exists(config.output_dir): os.makedirs(config.output_dir) @@ -108,6 +106,33 @@ def __init__( return + @property + def best_threshold(self): + if hasattr(self.decision_policy, "threshold"): + return self.decision_policy.threshold + return None + + @best_threshold.setter + def best_threshold(self, value): + if hasattr(self.decision_policy, "threshold"): + self.decision_policy.threshold = float(value) + + def get_checkpoints(self): + checkpoints = [cp for cp in os.listdir(self.config.output_dir) if "checkpoint-" in cp] + if len(checkpoints) == 0: + return ["zero-shot"] + return checkpoints + + def load_checkpoint(self, checkpoint_name): + if checkpoint_name == "zero-shot": + return + full_model_path = os.path.join(self.config.output_dir, checkpoint_name) + self.model, _ = FastLanguageModel.from_pretrained( + model_name=full_model_path, + max_seq_length=self.max_seq_length, + load_in_4bit=True, + ) + def _context_mode(self, context): """ Select the utterances to include in the input context based on the configured context mode. @@ -218,14 +243,13 @@ def _context_to_llm_data(self, contexts): print(f"There are {len(dataset)} samples") return Dataset.from_list(dataset) - def fit(self, train_contexts, val_contexts): + def fit_belief_estimator(self, train_contexts, val_contexts=None): """ Fine-tune the TransformerDecoder model using LoRA and save the best model based on validation performance. This method applies Low-Rank Adaptation (LoRA) to the decoder model, converts the training contexts into text-based input for LLM fine-tuning, and trains the model - using HuggingFace's `SFTTrainer`. After training, it tunes a decision threshold on - a held-out validation set to optimize binary forecast classification. + using HuggingFace's `SFTTrainer`. :param contexts: an iterator over context tuples, provided by the Forecaster framework :param val_contexts: an iterator over context tuples to be used only for validation. @@ -260,9 +284,9 @@ def fit(self, train_contexts, val_contexts): model=self.model, tokenizer=self.tokenizer, train_dataset=train_dataset, + max_seq_length=self.max_seq_length, args=SFTConfig( dataset_text_field="text", - max_seq_length=self.max_seq_length, per_device_train_batch_size=self.config.per_device_batch_size, gradient_accumulation_steps=self.config.gradient_accumulation_steps, warmup_steps=10, @@ -282,136 +306,9 @@ def fit(self, train_contexts, val_contexts): ), ) trainer.train() - _ = self._tune_threshold(self, val_contexts) return - def _tune_threshold(self, val_contexts): - """ - Tune the decision threshold and select the best model checkpoint based on validation accuracy. - - This method evaluates all model checkpoints in the configured output directory using a - held-out validation set. - - The selected model, threshold, and associated metadata are stored in: - - `self.model`: the best-performing fine-tuned model - - `self.best_threshold`: the optimal decision threshold - - `dev_config.json`: file containing best checkpoint metadata - - `val_predictions.csv`: CSV file with forecast outputs on the validation set - - Additionally, all non-optimal model checkpoints are removed to save disk space, and the - tokenizer is saved to the directory of the best checkpoint. - - :param val_dataset: A HuggingFace-compatible dataset containing features for validation. - :param val_contexts: An iterable of context tuples corresponding to the validation set. - Used to map utterance IDs to conversation IDs and extract ground-truth labels. - - :return: A dictionary containing the best checkpoint path, best threshold, and best validation accuracy. - """ - checkpoints = [cp for cp in os.listdir(self.config.output_dir) if "checkpoint-" in cp] - if checkpoints == []: - checkpoints.append("zero-shot") - best_val_accuracy = 0 - val_convo_ids = set() - utt2convo = {} - val_labels_dict = {} - val_contexts = list(val_contexts) - for context in val_contexts: - convo_id = context.conversation_id - utt_id = context.current_utterance.id - label = self.labeler(context.current_utterance.get_conversation()) - utt2convo[utt_id] = convo_id - val_labels_dict[convo_id] = label - val_convo_ids.add(convo_id) - val_convo_ids = list(val_convo_ids) - for cp in checkpoints: - if cp != "zero-shot": - full_model_path = os.path.join(self.config.output_dir, cp) - self.model, _ = FastLanguageModel.from_pretrained( - model_name=full_model_path, - max_seq_length=self.max_seq_length, - load_in_4bit=True, - ) - FastLanguageModel.for_inference(self.model) - utt2score = {} - for context in tqdm(val_contexts): - utt_score, _ = self._predict(context) - utt_id = context.current_utterance.id - utt2score[utt_id] = utt_score - # for each CONVERSATION, whether or not it triggers will be effectively determined by what the highest score it ever got was - highest_convo_scores = {convo_id: -1 for convo_id in val_convo_ids} - - for utt_id in utt2convo: - convo_id = utt2convo[utt_id] - utt_score = utt2score[utt_id] - if utt_score > highest_convo_scores[convo_id]: - highest_convo_scores[convo_id] = utt_score - - val_labels = np.asarray([int(val_labels_dict[c]) for c in val_convo_ids]) - val_scores = np.asarray([highest_convo_scores[c] for c in val_convo_ids]) - # use scikit learn to find candidate threshold cutoffs - _, _, thresholds = roc_curve(val_labels, val_scores) - - def acc_with_threshold(y_true, y_score, thresh): - y_pred = (y_score > thresh).astype(int) - return (y_pred == y_true).mean() - - accs = [acc_with_threshold(val_labels, val_scores, t) for t in thresholds] - best_acc_idx = np.argmax(accs) - - print("Accuracy:", cp, accs[best_acc_idx]) - if accs[best_acc_idx] > best_val_accuracy: - best_checkpoint = cp - best_val_accuracy = accs[best_acc_idx] - self.best_threshold = thresholds[best_acc_idx] - - # Save the best config - best_config = {} - best_config["best_checkpoint"] = best_checkpoint - best_config["best_threshold"] = self.best_threshold - best_config["best_val_accuracy"] = best_val_accuracy - config_file = os.path.join(self.config.output_dir, "dev_config.json") - with open(config_file, "w") as outfile: - json_object = json.dumps(best_config, indent=4) - outfile.write(json_object) - # Load best model - best_model_path = os.path.join(self.config.output_dir, best_checkpoint) - self.model, _ = FastLanguageModel.from_pretrained( - model_name=best_model_path, - max_seq_length=self.max_seq_length, - load_in_4bit=True, - ) - - # Clean other checkpoints to save disk space. - for root, _, _ in os.walk(self.config.output_dir): - if ("checkpoint" in root) and (best_checkpoint not in root): - print("Deleting:", root) - shutil.rmtree(root) - # Save the tokenizer. - self.tokenizer.save_pretrained( - os.path.join(self.config.output_dir, best_config["best_checkpoint"]) - ) - return best_config - - def _predict(self, context, threshold=None): - """ - Run inference on a single context using the fine-tuned TransformerDecoder model. - - This method prepares the input from the given context, generates a single-token - prediction (either "Yes" or "No"), and computes the softmax probability for "Yes". - The output is a confidence score and a binary prediction based on the given or - default threshold. - - :param context: A context tuple containing the current utterance and conversation history. - :param threshold: (Optional) A float threshold for converting the predicted probability into a binary label. - If not provided, `self.best_threshold` is used. - - :return: A tuple (`utt_score`, `utt_pred`), where: - - `utt_score` is the softmax probability assigned to "Yes" - - `utt_pred` is the binary prediction (1 if `utt_score > threshold`, else 0) - """ - # Enabling inference with different checkpoints to _tune_best_val_accuracy - if not threshold: - threshold = self.best_threshold + def score(self, context) -> float: FastLanguageModel.for_inference(self.model) context_utts = self._context_mode(context) inputs = self._tokenize(context_utts).to(self.config.device) @@ -432,16 +329,56 @@ def _predict(self, context, threshold=None): utt_score = F.softmax(torch.tensor([yes_logit, no_logit], dtype=torch.float32), dim=0)[ 0 ].item() - utt_pred = int(utt_score > threshold) + return utt_score + + def _predict(self, context, threshold=None): + """ + Run inference on a single context using the fine-tuned TransformerDecoder model. + + This method prepares the input from the given context, generates a single-token + prediction (either "Yes" or "No"), and computes the softmax probability for "Yes". + The output is a confidence score and a binary prediction based on the given or + default threshold. + + :param context: A context tuple containing the current utterance and conversation history. + :param threshold: (Optional) A float threshold for converting the predicted probability into a binary label. + If not provided, `self.best_threshold` is used. + + :return: A tuple (`utt_score`, `utt_pred`), where: + - `utt_score` is the softmax probability assigned to "Yes" + - `utt_pred` is the binary prediction (1 if `utt_score > threshold`, else 0) + """ + utt_score = self.score(context) + # keep threshold override for backward compatibility. + if threshold is not None: + utt_pred = int(utt_score > threshold) + else: + result = self.decision_policy.decide(context, self.score) + if len(result) == 2: + utt_score, utt_pred = result + elif len(result) == 3: + utt_score, utt_pred, _ = result + else: + raise ValueError( + "decision_policy.decide() must return (utt_score, utt_pred) " + "or (utt_score, utt_pred, metadata_dict)" + ) return utt_score, utt_pred - def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): + def fit(self, contexts, val_contexts=None): + val_contexts_belief_estimator, val_contexts_decision_policy = tee(val_contexts, 2) + self.fit_belief_estimator(contexts, val_contexts_belief_estimator) + self.fit_decision_policy(contexts, val_contexts_decision_policy, score_fn=self.score) + return + + def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name, verbose=False): """ Generate forecasts using the fine-tuned TransformerDecoder model on the provided contexts, and save the predictions to the output directory specified in the configuration. :param contexts: context tuples from the Forecaster framework :param forecast_attribute_name: Forecaster will use this to look up the table column containing your model's discretized predictions (see output specification below) :param forecast_prob_attribute_name: Forecaster will use this to look up the table column containing your model's raw forecast probabilities (see output specification below) + :param verbose: if True, print verbose transform logging during the transformation :return: a Pandas DataFrame, with one row for each context, indexed by the ID of that context's current utterance. Contains two columns, one with raw probabilities named according to forecast_prob_attribute_name, and one with discretized (binary) forecasts named according to forecast_attribute_name """ @@ -449,15 +386,151 @@ def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_n utt_ids = [] preds = [] scores = [] - for context in tqdm(contexts): - utt_score, utt_pred = self._predict(context) - + metadatas = defaultdict(list) + # TODO(metrics): temporary running metric logging during transform; remove before merge. + report_every_n = 250 + prediction_file = os.path.join(self.config.output_dir, "predictions.csv") + if os.path.exists(prediction_file): + os.remove(prediction_file) + next_flush_start = 0 + csv_header_written = False + convo_forecasts = {} + convo_labels = {} + + def _compute_conversation_metrics(): + common_convo_ids = [cid for cid in convo_forecasts if cid in convo_labels] + if len(common_convo_ids) == 0: + return None + tp = 0 + fp = 0 + tn = 0 + fn = 0 + for convo_id in common_convo_ids: + pred = int(convo_forecasts[convo_id] > 0) + label = int(convo_labels[convo_id]) + if label == 1 and pred == 1: + tp += 1 + elif label == 0 and pred == 1: + fp += 1 + elif label == 0 and pred == 0: + tn += 1 + elif label == 1 and pred == 0: + fn += 1 + n = len(common_convo_ids) + acc = (tp + tn) / n if n > 0 else 0.0 + p = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + r = tp / (tp + fn) if (tp + fn) > 0 else 0.0 + fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0 + f1 = (2 * p * r / (p + r)) if (p + r) > 0 else 0.0 + return {"n": n, "acc": acc, "p": p, "r": r, "fpr": fpr, "f1": f1} + # for safety/flexibility we can accept either only score and pred or also the metadata + progress = tqdm(contexts) + for idx, context in enumerate(progress, start=1): + result = self.decision_policy.decide(context, self.score) + + if len(result) == 2: + utt_score, utt_pred = result + utt_metadata = {} + # no metadata + elif len(result) == 3: + utt_score, utt_pred, utt_metadata = result + # coerce None metadata to {} so policies that return (score, pred, None) + # don't crash downstream utt_metadata.items() / .get() calls. + if utt_metadata is None: + utt_metadata = {} + else: + raise ValueError( + "decision_policy.decide() must return (utt_score, utt_pred) " + "or (utt_score, utt_pred, metadata_dict)" + ) utt_ids.append(context.current_utterance.id) preds.append(utt_pred) scores.append(utt_score) - forecasts_df = pd.DataFrame( - {forecast_attribute_name: preds, forecast_prob_attribute_name: scores}, index=utt_ids - ) - prediction_file = os.path.join(self.config.output_dir, "test_predictions.csv") - forecasts_df.to_csv(prediction_file) - return forecasts_df + current_idx = len(preds) - 1 + existing_metadata_keys = list(metadatas.keys()) + for key in existing_metadata_keys: + metadatas[key].append(utt_metadata.get(key, None)) + for key, value in utt_metadata.items(): + if key not in metadatas: + metadatas[key] = [None] * current_idx + metadatas[key].append(value) + + convo_id = getattr(context, "conversation_id", None) + try: + convo = context.current_utterance.get_conversation() + if convo_id is None and convo is not None: + convo_id = convo.id + if convo_id is not None: + if convo_id in convo_forecasts: + convo_forecasts[convo_id] = max(convo_forecasts[convo_id], int(utt_pred)) + else: + convo_forecasts[convo_id] = int(utt_pred) + if convo_id not in convo_labels: + convo_labels[convo_id] = int(self.labeler(convo)) + except Exception: + pass + + if idx % report_every_n == 0: + batch_cols = { + forecast_attribute_name: preds[next_flush_start:idx], + forecast_prob_attribute_name: scores[next_flush_start:idx], + } + for key, series in metadatas.items(): + batch_cols[key] = series[next_flush_start:idx] + batch_df = pd.DataFrame(batch_cols, index=utt_ids[next_flush_start:idx]) + batch_df.to_csv( + prediction_file, + mode="a" if csv_header_written else "w", + header=not csv_header_written, + ) + csv_header_written = True + next_flush_start = idx + + running_metrics = _compute_conversation_metrics() + if verbose: + if running_metrics is not None: + tqdm.write( + f"[info] transform metrics running: " + f"processed_contexts={idx}, conversations={running_metrics['n']}, " + f"acc={running_metrics['acc']:.4f}, p={running_metrics['p']:.4f}, " + f"r={running_metrics['r']:.4f}, fpr={running_metrics['fpr']:.4f}, " + f"f1={running_metrics['f1']:.4f}" + ) + else: + tqdm.write( + f"[info] transform metrics running: " + f"processed_contexts={idx}, conversations=0" + ) + total_processed = len(preds) + if total_processed > next_flush_start: + batch_cols = { + forecast_attribute_name: preds[next_flush_start:total_processed], + forecast_prob_attribute_name: scores[next_flush_start:total_processed], + } + for key, series in metadatas.items(): + batch_cols[key] = series[next_flush_start:total_processed] + batch_df = pd.DataFrame(batch_cols, index=utt_ids[next_flush_start:total_processed]) + batch_df.to_csv( + prediction_file, + mode="a" if csv_header_written else "w", + header=not csv_header_written, + ) + csv_header_written = True + cols = { + forecast_attribute_name: preds, + forecast_prob_attribute_name: scores, + } + final_metrics = _compute_conversation_metrics() + if final_metrics is not None: + tqdm.write( + f"[info] final transform metrics: " + f"processed_contexts={len(preds)}, conversations={final_metrics['n']}, " + f"acc={final_metrics['acc']:.4f}, p={final_metrics['p']:.4f}, " + f"r={final_metrics['r']:.4f}, fpr={final_metrics['fpr']:.4f}, " + f"f1={final_metrics['f1']:.4f}" + ) + for key, series in metadatas.items(): + assert len(series) == len(preds), "Metadata series length must match number of predictions" + cols[key] = series # each series same length as preds + forecasts_df = pd.DataFrame(cols, index=utt_ids) + return forecasts_df \ No newline at end of file diff --git a/convokit/forecaster/TransformerEncoderModel.py b/convokit/forecaster/TransformerEncoderModel.py index 7203818b..f7e8a07d 100644 --- a/convokit/forecaster/TransformerEncoderModel.py +++ b/convokit/forecaster/TransformerEncoderModel.py @@ -11,13 +11,10 @@ import os import pandas as pd -import numpy as np -import json from tqdm import tqdm -from sklearn.metrics import roc_curve from .forecasterModel import ForecasterModel from .TransformerForecasterConfig import TransformerForecasterConfig -import shutil +from itertools import tee os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -43,15 +40,14 @@ class TransformerEncoderModel(ForecasterModel): :param config: (Optional) TransformerForecasterConfig object containing parameters for training and evaluation. """ - def __init__(self, model_name_or_path, config=DEFAULT_CONFIG): - super().__init__() + def __init__(self, model_name_or_path, config=DEFAULT_CONFIG, decision_policy=None): + super().__init__(decision_policy=decision_policy) self.tokenizer = AutoTokenizer.from_pretrained( model_name_or_path, model_max_length=512, truncation_side="left", padding_side="right", ) - self.best_threshold = 0.5 model_config = AutoConfig.from_pretrained( model_name_or_path, num_labels=2, problem_type="single_label_classification" ) @@ -63,6 +59,40 @@ def __init__(self, model_name_or_path, config=DEFAULT_CONFIG): self.config = config return + @property + def best_threshold(self): + if hasattr(self.decision_policy, "threshold"): + return self.decision_policy.threshold + return None + + @best_threshold.setter + def best_threshold(self, value): + if hasattr(self.decision_policy, "threshold"): + self.decision_policy.threshold = float(value) + + def get_checkpoints(self): + return [cp for cp in os.listdir(self.config.output_dir) if "checkpoint-" in cp] + + def load_checkpoint(self, checkpoint_name): + full_model_path = os.path.join(self.config.output_dir, checkpoint_name) + self.model = AutoModelForSequenceClassification.from_pretrained(full_model_path).to( + self.config.device + ) + + def finalize_best_checkpoint_selection( + self, best_checkpoint, best_config, val_contexts=None, score_fn=None + ): + super().finalize_best_checkpoint_selection( + best_checkpoint, best_config, val_contexts=val_contexts, score_fn=score_fn + ) + if val_contexts is None or self.best_threshold is None: + return + val_dataset = self._context_to_bert_data(val_contexts) + val_dataset.set_format("torch") + eval_forecasts_df = self._score_dataset(val_dataset, threshold=self.best_threshold) + eval_prediction_file = os.path.join(self.config.output_dir, "val_predictions.csv") + eval_forecasts_df.to_csv(eval_prediction_file) + def _context_mode(self, context): """ Select the utterances to include in the input context based on the configured context mode. @@ -153,11 +183,11 @@ def _context_to_bert_data(self, contexts): @torch.inference_mode @torch.no_grad - def _predict( + def _score_dataset( self, dataset, model=None, - threshold=0.5, + threshold=None, forecast_prob_attribute_name="forecast_prob", forecast_attribute_name="forecast", ): @@ -179,6 +209,8 @@ def _predict( """ if not model: model = self.model.to(self.config.device) + if threshold is None: + threshold = self.best_threshold if self.best_threshold is not None else 0.5 utt_ids = [] preds = [] scores = [] @@ -198,100 +230,27 @@ def _predict( {forecast_attribute_name: preds, forecast_prob_attribute_name: scores}, index=utt_ids ) - def _tune_threshold(self, val_dataset, val_contexts): - """ - Tune the decision threshold and select the best model checkpoint based on validation accuracy. - - This method evaluates all model checkpoints in the configured output directory using a - held-out validation set. - - The selected model, threshold, and associated metadata are stored in: - - `self.model`: the best-performing fine-tuned model - - `self.best_threshold`: the optimal decision threshold - - `dev_config.json`: file containing best checkpoint metadata - - `val_predictions.csv`: CSV file with forecast outputs on the validation set - - Additionally, all non-optimal model checkpoints are removed to save disk space, and the - tokenizer is saved to the directory of the best checkpoint. - - :param val_dataset: A HuggingFace-compatible dataset containing features for validation. - :param val_contexts: An iterable of context tuples corresponding to the validation set. - Used to map utterance IDs to conversation IDs and extract ground-truth labels. - - :return: A dictionary containing the best checkpoint path, best threshold, and best validation accuracy. - """ - checkpoints = [cp for cp in os.listdir(self.config.output_dir) if "checkpoint-" in cp] - best_val_accuracy = 0 - val_convo_ids = set() - utt2convo = {} - val_labels_dict = {} - for context in val_contexts: - convo_id = context.conversation_id - utt_id = context.current_utterance.id - label = self.labeler(context.current_utterance.get_conversation()) - utt2convo[utt_id] = convo_id - val_labels_dict[convo_id] = label - val_convo_ids.add(convo_id) - val_convo_ids = list(val_convo_ids) - for cp in checkpoints: - full_model_path = os.path.join(self.config.output_dir, cp) - finetuned_model = AutoModelForSequenceClassification.from_pretrained( - full_model_path - ).to(self.config.device) - val_scores = self._predict(val_dataset, model=finetuned_model) - # for each CONVERSATION, whether or not it triggers will be effectively determined by what the highest score it ever got was - highest_convo_scores = {convo_id: -1 for convo_id in val_convo_ids} - for utt_id in val_scores.index: - convo_id = utt2convo[utt_id] - utt_score = val_scores.loc[utt_id].forecast_prob - if utt_score > highest_convo_scores[convo_id]: - highest_convo_scores[convo_id] = utt_score - - val_labels = np.asarray([int(val_labels_dict[c]) for c in val_convo_ids]) - val_scores = np.asarray([highest_convo_scores[c] for c in val_convo_ids]) - # use scikit learn to find candidate threshold cutoffs - _, _, thresholds = roc_curve(val_labels, val_scores) - - def acc_with_threshold(y_true, y_score, thresh): - y_pred = (y_score > thresh).astype(int) - return (y_pred == y_true).mean() - - accs = [acc_with_threshold(val_labels, val_scores, t) for t in thresholds] - best_acc_idx = np.argmax(accs) - - print("Accuracy:", cp, accs[best_acc_idx]) - if accs[best_acc_idx] > best_val_accuracy: - best_checkpoint = cp - best_val_accuracy = accs[best_acc_idx] - self.best_threshold = thresholds[best_acc_idx] - self.model = finetuned_model - - eval_forecasts_df = self._predict(val_dataset, threshold=self.best_threshold) - eval_prediction_file = os.path.join(self.config.output_dir, "val_predictions.csv") - eval_forecasts_df.to_csv(eval_prediction_file) - - # Save the best config - best_config = {} - best_config["best_checkpoint"] = best_checkpoint - best_config["best_threshold"] = self.best_threshold - best_config["best_val_accuracy"] = best_val_accuracy - config_file = os.path.join(self.config.output_dir, "dev_config.json") - with open(config_file, "w") as outfile: - json_object = json.dumps(best_config, indent=4) - outfile.write(json_object) - - # Clean other checkpoints to save disk space. - for root, _, _ in os.walk(self.config.output_dir): - if ("checkpoint" in root) and (best_checkpoint not in root): - print("Deleting:", root) - shutil.rmtree(root) - # Save the tokenizer. - self.tokenizer.save_pretrained( - os.path.join(self.config.output_dir, best_config["best_checkpoint"]) + @torch.inference_mode + @torch.no_grad + def score(self, context) -> float: + self.model.eval() + context_utts = self._context_mode(context) + tokenized_context = self._tokenize(context_utts) + input_ids = ( + torch.tensor(tokenized_context["input_ids"], dtype=torch.long) + .to(self.config.device) + .reshape([1, -1]) ) - return best_config + attention_mask = ( + torch.tensor(tokenized_context["attention_mask"], dtype=torch.long) + .to(self.config.device) + .reshape([1, -1]) + ) + outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) + probs = F.softmax(outputs.logits, dim=-1) + return probs[0, 1].item() - def fit(self, contexts, val_contexts): + def fit_belief_estimator(self, contexts, val_contexts=None): """ Fine-tune the TransformerEncoder model, and save the best model according to validation performance. @@ -301,12 +260,10 @@ def fit(self, contexts, val_contexts): held-out validation set. :param contexts: an iterator over context tuples, provided by the Forecaster framework - :param val_contexts: an iterator over context tuples to be used only for validation. + :param val_contexts: optional validation contexts (not used by this stage). """ - val_contexts = list(val_contexts) train_pairs = self._context_to_bert_data(contexts) - val_for_tuning_pairs = self._context_to_bert_data(val_contexts) - dataset = DatasetDict({"train": train_pairs, "val_for_tuning": val_for_tuning_pairs}) + dataset = DatasetDict({"train": train_pairs}) dataset.set_format("torch") training_args = TrainingArguments( @@ -324,9 +281,32 @@ def fit(self, contexts, val_contexts): ) trainer = Trainer(model=self.model, args=training_args, train_dataset=dataset["train"]) trainer.train() - _ = self._tune_threshold(dataset["val_for_tuning"], val_contexts) return + def fit(self, contexts, val_contexts=None): + val_contexts_belief_estimator, val_contexts_decision_policy = tee(val_contexts, 2) + self.fit_belief_estimator(contexts, val_contexts_belief_estimator) + self.fit_decision_policy(contexts, val_contexts_decision_policy, score_fn=self.score) + return + + def _predict(self, context, threshold=None): + utt_score = self.score(context) + # keep threshold override for backward compatibility. + if threshold is not None: + utt_pred = int(utt_score > threshold) + else: + result = self.decision_policy.decide(context, self.score) + if len(result) == 2: + utt_score, utt_pred = result + elif len(result) == 3: + utt_score, utt_pred, _ = result + else: + raise ValueError( + "decision_policy.decide() must return (utt_score, utt_pred) " + "or (utt_score, utt_pred, metadata_dict)" + ) + return utt_score, utt_pred + def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): """ Generate forecasts using the fine-tuned TransformerEncoder model on the provided contexts, and save the predictions to the output directory specified in the configuration. @@ -337,14 +317,16 @@ def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_n :return: a Pandas DataFrame, with one row for each context, indexed by the ID of that context's current utterance. Contains two columns, one with raw probabilities named according to forecast_prob_attribute_name, and one with discretized (binary) forecasts named according to forecast_attribute_name """ - test_pairs = self._context_to_bert_data(contexts) - dataset = DatasetDict({"test": test_pairs}) - dataset.set_format("torch") - forecasts_df = self._predict( - dataset["test"], - threshold=self.best_threshold, - forecast_attribute_name=forecast_attribute_name, - forecast_prob_attribute_name=forecast_prob_attribute_name, + utt_ids = [] + preds = [] + scores = [] + for context in tqdm(contexts): + utt_score, utt_pred = self._predict(context) + utt_ids.append(context.current_utterance.id) + preds.append(utt_pred) + scores.append(utt_score) + forecasts_df = pd.DataFrame( + {forecast_attribute_name: preds, forecast_prob_attribute_name: scores}, index=utt_ids ) prediction_file = os.path.join(self.config.output_dir, "test_predictions.csv") diff --git a/convokit/forecaster/__init__.py b/convokit/forecaster/__init__.py index 12a13760..71fac553 100644 --- a/convokit/forecaster/__init__.py +++ b/convokit/forecaster/__init__.py @@ -1,6 +1,7 @@ from .forecaster import * from .forecasterModel import * from .cumulativeBoW import * +from convokit.decisionpolicy import * import sys # Import CRAFT models if torch is available diff --git a/convokit/forecaster/cumulativeBoW.py b/convokit/forecaster/cumulativeBoW.py index 50fe0373..c14b9c61 100644 --- a/convokit/forecaster/cumulativeBoW.py +++ b/convokit/forecaster/cumulativeBoW.py @@ -25,11 +25,15 @@ def __init__( use_tokens=False, forecast_attribute_name: str = "prediction", forecast_prob_attribute_name: str = "score", + decision_policy=None, ): super().__init__( + decision_policy=decision_policy, forecast_attribute_name=forecast_attribute_name, forecast_prob_attribute_name=forecast_prob_attribute_name, ) + self.forecast_attribute_name = forecast_attribute_name + self.forecast_prob_attribute_name = forecast_prob_attribute_name if vectorizer is None: print("Initializing default unigram CountVectorizer...") if use_tokens: @@ -66,6 +70,10 @@ def __init__( else: self.clf_model = clf_model + @staticmethod + def _context_to_text(context): + return " ".join([u.text for u in context.context]) + @staticmethod def _combine_contexts(id_to_context_others): """ @@ -114,3 +122,30 @@ def forecast(self, id_to_context_reply_label): data=list(zip(ids, preds, pred_probs)), columns=["id", self.forecast_attribute_name, self.forecast_prob_attribute_name], ).set_index("id") + + def fit_belief_estimator(self, contexts, val_contexts=None): + contexts = list(contexts) + X_raw = [self._context_to_text(context) for context in contexts] + y = [self.labeler(context.current_utterance.get_conversation()) for context in contexts] + X = self.vectorizer.fit_transform(X_raw) + self.clf_model.fit(X, y) + + def score(self, context) -> float: + X = self.vectorizer.transform([self._context_to_text(context)]) + return float(self.clf_model.predict_proba(X)[0, 1]) + + def fit(self, contexts, val_contexts=None): + return super().fit(contexts, val_contexts) + + def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): + utt_ids = [] + preds = [] + scores = [] + for context in contexts: + utt_score, utt_pred = self._predict(context) + utt_ids.append(context.current_utterance.id) + preds.append(utt_pred) + scores.append(utt_score) + return pd.DataFrame( + {forecast_attribute_name: preds, forecast_prob_attribute_name: scores}, index=utt_ids + ) diff --git a/convokit/forecaster/forecaster.py b/convokit/forecaster/forecaster.py index c3c259ac..0e0b485c 100644 --- a/convokit/forecaster/forecaster.py +++ b/convokit/forecaster/forecaster.py @@ -6,7 +6,9 @@ import numpy as np from matplotlib import pyplot as plt -# Define a namedtuple template to represent conversational context tuples +# define a namedtuple template to represent conversational context tuples. +# this alias is kept for backwards compatibility; decision policies should +# accept the same structure. ContextTuple = namedtuple( "ContextTuple", ["context", "current_utterance", "future_context", "conversation_id"] ) @@ -48,6 +50,9 @@ def __init__( # also give the underlying ForecasterModel access to the labeler function self.forecaster_model.labeler = self.labeler + # keep the decision policy's forecast_prob cache key aligned with the + # meta field that Forecaster.transform() writes to. + self.forecaster_model.forecast_prob_attribute_name = self.forecast_prob_attribute_name def _create_context_iterator( self, @@ -119,11 +124,26 @@ def fit( self.forecaster_model.fit(contexts, val_contexts) return self + + def fit_decision_policy(self, corpus, context_selector, val_context_selector): + contexts = self._create_context_iterator(corpus, context_selector, include_future_context=True) + val_contexts = None + if val_context_selector is not None: + val_contexts = self._create_context_iterator(corpus, val_context_selector, include_future_context=True) + return self.forecaster_model.fit_decision_policy(contexts, val_contexts) + + def fit_belief_estimator(self, corpus, context_selector, val_context_selector): + contexts = self._create_context_iterator(corpus, context_selector, include_future_context=True) + val_contexts = None + if val_context_selector is not None: + val_contexts = self._create_context_iterator(corpus, val_context_selector, include_future_context=True) + return self.forecaster_model.fit_belief_estimator(contexts, val_contexts) def transform( self, corpus: Corpus, context_selector: Callable[[ContextTuple], bool] = lambda context: True, + **kwargs, ) -> Corpus: """ Wrapper method for applying the underlying conversational forecasting model to make forecasts over the Conversations in a given Corpus. @@ -137,22 +157,19 @@ def transform( """ contexts = self._create_context_iterator(corpus, context_selector) forecast_df = self.forecaster_model.transform( - contexts, self.forecast_attribute_name, self.forecast_prob_attribute_name + contexts, self.forecast_attribute_name, self.forecast_prob_attribute_name, **kwargs, ) + # generalize addition of metadata columns + meta_columns = list(forecast_df.columns) for utt in corpus.iter_utterances(): if utt.id in forecast_df.index: - utt.add_meta( - self.forecast_attribute_name, - forecast_df.loc[utt.id][self.forecast_attribute_name], - ) - utt.add_meta( - self.forecast_prob_attribute_name, - forecast_df.loc[utt.id][self.forecast_prob_attribute_name], - ) + row = forecast_df.loc[utt.id] + for col in meta_columns: + utt.add_meta(col, row[col]) else: - utt.add_meta(self.forecast_attribute_name, None) - utt.add_meta(self.forecast_prob_attribute_name, None) + for col in meta_columns: + utt.add_meta(col, None) return corpus diff --git a/convokit/forecaster/forecasterModel.py b/convokit/forecaster/forecasterModel.py index 0051ff32..168d5278 100644 --- a/convokit/forecaster/forecasterModel.py +++ b/convokit/forecaster/forecasterModel.py @@ -1,6 +1,13 @@ from abc import ABC, abstractmethod +from itertools import tee from typing import Callable +import json +import os +import shutil + +from convokit.decisionpolicy import ThresholdDecisionPolicy + class ForecasterModel(ABC): """ @@ -9,8 +16,10 @@ class ForecasterModel(ABC): in a consistent format, defined above. """ - def __init__(self): + def __init__(self, decision_policy=None, **kwargs): self._labeler = None + self._forecast_prob_attribute_name = "forecast_prob" + self._decision_policy = decision_policy or ThresholdDecisionPolicy() @property def labeler(self): @@ -19,17 +28,162 @@ def labeler(self): @labeler.setter def labeler(self, value: Callable): self._labeler = value + if self._decision_policy is not None: + self._decision_policy.labeler = value + + @property + def forecast_prob_attribute_name(self) -> str: + return self._forecast_prob_attribute_name + + @forecast_prob_attribute_name.setter + def forecast_prob_attribute_name(self, value: str): + # keeps the decision policy's cache key aligned with the forecaster's + # meta field so policies can reuse previously written forecast probs. + self._forecast_prob_attribute_name = value + if self._decision_policy is not None: + self._decision_policy.forecast_prob_attribute_name = value + + @property + def decision_policy(self): + return self._decision_policy + + @decision_policy.setter + def decision_policy(self, value): + self._decision_policy = value + if self._decision_policy is not None: + self._decision_policy.labeler = self._labeler + self._decision_policy.forecast_prob_attribute_name = ( + self._forecast_prob_attribute_name + ) @abstractmethod def fit(self, contexts, val_contexts=None): """ - Train this conversational forecasting model on the given data + Train this conversational forecasting model on the given data by fitting + both the belief estimator and the decision policy. :param contexts: an iterator over context tuples :param val_contexts: an optional second iterator over context tuples to be used as a separate held-out validation set. Concrete ForecasterModel implementations may choose to ignore this, or conversely even enforce its presence. """ pass + @abstractmethod + def fit_belief_estimator(self, contexts, val_contexts=None): + """ + Fit only the belief estimator component that produces continuous scores. + """ + pass + + def fit_decision_policy(self, contexts, val_contexts=None, score_fn: Callable = None): + """ + Fit only the decision policy component. + """ + if self.decision_policy is not None: + if score_fn is None: + score_fn = self.score + fit_result = self.decision_policy.fit( + contexts=contexts, val_contexts=val_contexts, score_fn=score_fn + ) + self._json_dump_fit_result(fit_result) + return fit_result + return None + + def _json_dump_fit_result(self, fit_result): + if not isinstance(fit_result, dict): + return + + output_dir = getattr(getattr(self, "config", None), "output_dir", None) + if output_dir is None: + return + + config_file = os.path.join(output_dir, "dev_config.json") + existing_config = {} + if os.path.exists(config_file): + try: + with open(config_file, "r") as infile: + existing_config = json.load(infile) + except (json.JSONDecodeError, OSError): + existing_config = {} + + if "best_checkpoint" in fit_result: + existing_config["best_checkpoint"] = fit_result["best_checkpoint"] + if "best_threshold" in fit_result: + existing_config["best_threshold"] = float(fit_result["best_threshold"]) + if "best_val_accuracy" in fit_result: + existing_config["best_val_accuracy"] = float(fit_result["best_val_accuracy"]) + + with open(config_file, "w") as outfile: + json.dump(existing_config, outfile, indent=4) + + def get_checkpoints(self): + return [] + + def load_checkpoint(self, checkpoint_name): + raise NotImplementedError("checkpoint loading is not implemented for this model") + + def finalize_best_checkpoint_selection( + self, best_checkpoint, best_config, val_contexts=None, score_fn: Callable = None + ): + if best_checkpoint is None: + return + self._cleanup_checkpoints(best_checkpoint) + self._save_tokenizer_checkpoint(best_checkpoint) + + def _cleanup_checkpoints(self, best_checkpoint): + output_dir = getattr(getattr(self, "config", None), "output_dir", None) + if output_dir is None or best_checkpoint is None: + return + + for root, _, _ in os.walk(output_dir): + if ("checkpoint" in root) and (best_checkpoint not in root): + print(f"deleting: {root}") + shutil.rmtree(root) + + def _save_tokenizer_checkpoint(self, best_checkpoint): + tokenizer = getattr(self, "tokenizer", None) + output_dir = getattr(getattr(self, "config", None), "output_dir", None) + if ( + tokenizer is None + or output_dir is None + or best_checkpoint is None + or not hasattr(tokenizer, "save_pretrained") + ): + return + tokenizer.save_pretrained(os.path.join(output_dir, best_checkpoint)) + + @abstractmethod + def score(self, context) -> float: + """ + Produce the belief estimator score for a context. + """ + pass + + def _predict(self, context): + """ + Return both belief score and policy action for a context. + + This method is deprecated in favor of using the self.decision_policy.decide method. + """ + utt_score, utt_pred, _ = self._parse_decision_result( + self.decision_policy.decide(context, self.score) + ) + return utt_score, utt_pred + + def _parse_decision_result(self, result): + if len(result) == 2: + utt_score, utt_pred = result + utt_metadata = {} + elif len(result) == 3: + utt_score, utt_pred, utt_metadata = result + if utt_metadata is None: + utt_metadata = {} + else: + raise ValueError( + "decision_policy.decide() must return (utt_score, utt_pred) " + "or (utt_score, utt_pred, metadata_dict)" + ) + return float(utt_score), int(utt_pred), utt_metadata + @abstractmethod def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): """ diff --git a/docs/source/decisionpolicy.rst b/docs/source/decisionpolicy.rst new file mode 100644 index 00000000..48eebd1a --- /dev/null +++ b/docs/source/decisionpolicy.rst @@ -0,0 +1,9 @@ +Decision Policy +=============== + +The decision policy API separates belief estimation (continuous scores) from +intervention decisions (discrete actions). This keeps ``Forecaster`` unchanged +while allowing flexible action logic in ``ForecasterModel``. + +.. automodule:: convokit.decisionpolicy + :members: diff --git a/docs/source/forecaster.rst b/docs/source/forecaster.rst index e688aa7c..cd0fe35b 100644 --- a/docs/source/forecaster.rst +++ b/docs/source/forecaster.rst @@ -41,6 +41,7 @@ These are subclasses of ForecasterModel, each implementing forecasting models us .. toctree:: :maxdepth: 1 + Decision Policy CRAFT Model Transformer Encoder-based Model Transformer Decoder-based Model diff --git a/download_config.json b/download_config.json index 09bf149f..3027a0d2 100644 --- a/download_config.json +++ b/download_config.json @@ -41,7 +41,8 @@ "ubuntu-chat-logs": 0, "contextual-abuse": 0, "news-interview": 0, - "emotional-support": 0 + "emotional-support": 0, + "decisionpolicy-demo": 0 }, "DatasetURLs": { "chromium-corpus": "http://zissou.infosci.cornell.edu/convokit/datasets/chromium-corpus/chromium-corpus.zip", @@ -115,7 +116,8 @@ "ubuntu-chat-logs": "https://zissou.infosci.cornell.edu/convokit/datasets/ubuntu-chat-logs/ubuntu-chat-logs.zip", "contextual-abuse": "https://zissou.infosci.cornell.edu/convokit/datasets/contextual-abuse/contextual-abuse.zip", "news-interview": "https://zissou.infosci.cornell.edu/convokit/datasets/news-interview/news-interview.zip", - "emotional-support": "https://zissou.infosci.cornell.edu/convokit/datasets/emotional-support/emotional-support.zip" + "emotional-support": "https://zissou.infosci.cornell.edu/convokit/datasets/emotional-support/emotional-support.zip", + "decisionpolicy-demo": "https://zissou.infosci.cornell.edu/convokit/datasets/decisionpolicy-demo/decisionpolicy-demo.zip" }, "ModelURLS": { "craft-wiki-pretrained": [ diff --git a/examples/decisionpolicy/decisionpolicy_demo.ipynb b/examples/decisionpolicy/decisionpolicy_demo.ipynb new file mode 100644 index 00000000..79caffd3 --- /dev/null +++ b/examples/decisionpolicy/decisionpolicy_demo.ipynb @@ -0,0 +1,2539 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "90659cc6", + "metadata": {}, + "source": [ + "# Decision Policy Demo\n", + "\n", + "This notebook will provide code demonstrating how to use Decision Policies as introduced in Wait! There's a Way Out. This notebook will also provide code for running the experiments in the paper. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "703021a8", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO\n", + "\n", + "import os\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '2'" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fd8b87be", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2026-04-28 07:19:59.106329: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2026-04-28 07:19:59.126464: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "E0000 00:00:1777360799.150730 1126746 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "E0000 00:00:1777360799.158776 1126746 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "W0000 00:00:1777360799.179235 1126746 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "W0000 00:00:1777360799.179259 1126746 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "W0000 00:00:1777360799.179261 1126746 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "W0000 00:00:1777360799.179264 1126746 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "2026-04-28 07:19:59.185112: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🦥 Unsloth Zoo will now patch everything to make training faster!\n" + ] + } + ], + "source": [ + "import os\n", + "import argparse\n", + "import sys\n", + "import glob\n", + "\n", + "from functools import partial\n", + "import json\n", + "from convokit import Corpus, Forecaster, download\n", + "from convokit.forecaster.TransformerDecoderModel import TransformerDecoderModel\n", + "from convokit.forecaster.TransformerForecasterConfig import TransformerForecasterConfig\n", + "from convokit.decisionpolicy import DeferralDecisionPolicy, ThresholdDecisionPolicy\n", + "from convokit.utterance_simulator.unslothUtteranceSimulatorModel import UnslothUtteranceSimulatorModel" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "15f977d3", + "metadata": {}, + "outputs": [], + "source": [ + "# Set repo root\n", + "\n", + "from pathlib import Path\n", + "\n", + "repo_root = Path.cwd()\n", + "while repo_root.name != \"ConvoKit\":\n", + " repo_root = repo_root.parent\n", + "\n", + "repo_root = str(repo_root)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18f2375b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] using cached decisionpolicy-demo at /home/lyk25/.convokit/saved-corpora/decisionpolicy-demo\n", + "[info] loaded 26 corpora from /home/lyk25/.convokit/saved-corpora/decisionpolicy-demo\n" + ] + } + ], + "source": [ + "# TODO temporary pre-merge downloader for decisionpolicy-demo\n", + "from pathlib import Path\n", + "import json\n", + "import urllib.request\n", + "import zipfile\n", + "\n", + "from convokit import Corpus\n", + "\n", + "DOWNLOAD_CONFIG_URL = (\n", + " \"https://raw.githubusercontent.com/laerdon/ConvoKit/\"\n", + " \"master/download_config.json\"\n", + ")\n", + "\n", + "\n", + "def get_decisionpolicy_demo_url(config_url=DOWNLOAD_CONFIG_URL):\n", + " print(f\"[info] reading download config from {config_url}\")\n", + " with urllib.request.urlopen(config_url) as response:\n", + " dataset_config = json.load(response)\n", + " try:\n", + " return dataset_config[\"DatasetURLs\"][\"decisionpolicy-demo\"]\n", + " except KeyError as exc:\n", + " raise KeyError(\n", + " \"decisionpolicy-demo is missing from laerdon/ConvoKit master download_config.json\"\n", + " ) from exc\n", + "\n", + "\n", + "def download_decisionpolicy_demo(data_dir=None):\n", + " data_root = Path(data_dir or \"~/.convokit/saved-corpora\").expanduser()\n", + " dataset_name = \"decisionpolicy-demo\"\n", + " dataset_dir = data_root / dataset_name\n", + " zip_path = data_root / f\"{dataset_name}.zip\"\n", + "\n", + " if any(path.is_dir() and (path / \"index.json\").exists() for path in dataset_dir.rglob(\"*\")):\n", + " print(f\"[info] using cached {dataset_name} at {dataset_dir}\")\n", + " return dataset_dir\n", + "\n", + " url = get_decisionpolicy_demo_url()\n", + " data_root.mkdir(parents=True, exist_ok=True)\n", + " print(f\"[info] downloading {dataset_name} from {url}\")\n", + " urllib.request.urlretrieve(url, zip_path)\n", + "\n", + " print(f\"[info] extracting {zip_path} to {data_root}\")\n", + " with zipfile.ZipFile(zip_path, \"r\") as zipf:\n", + " zipf.extractall(data_root)\n", + "\n", + " if not dataset_dir.exists():\n", + " raise FileNotFoundError(f\"expected extracted folder missing: {dataset_dir}\")\n", + " return dataset_dir\n", + "\n", + "base = download_decisionpolicy_demo()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "858a8431", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] using cached decisionpolicy-demo at /home/lyk25/.convokit/saved-corpora/decisionpolicy-demo\n", + "[info] test-seed corpora: 5 (test-seed-1, test-seed-2, test-seed-3, test-seed-4, test-seed-5); seed-* policy corpora in corpora_all: 25\n" + ] + } + ], + "source": [ + "import re\n", + "corpus_dirs = sorted(\n", + " corpus_dir\n", + " for corpus_dir in base.rglob(\"*\")\n", + " if corpus_dir.is_dir() and (corpus_dir / \"index.json\").exists()\n", + ")\n", + "if not corpus_dirs:\n", + " raise FileNotFoundError(f\"no convokit corpora found under {base}\")\n", + "_seed_policy_re = re.compile(r\"^seed-(\\d+)-(.+)$\")\n", + "_test_seed_re = re.compile(r\"^test-seed-(\\d+)$\")\n", + "corpora_all = {}\n", + "corpora = {}\n", + "for corpus_dir in corpus_dirs:\n", + " if _test_seed_re.match(corpus_dir.name):\n", + " corpora[corpus_dir.name] = Corpus(filename=str(corpus_dir))\n", + " continue\n", + " m_test = _test_seed_re.match(corpus_dir.parent.name)\n", + " if m_test:\n", + " corpora[f\"test-seed-{m_test.group(1)}\"] = Corpus(filename=str(corpus_dir))\n", + " continue\n", + " m = _seed_policy_re.match(corpus_dir.parent.name)\n", + " if m and corpus_dir.name == m.group(2):\n", + " corpora_all[f\"seed-{m.group(1)}-{m.group(2)}\"] = Corpus(filename=str(corpus_dir))\n", + "CORPUS_POLICY_PER_SEED = \"DeferralDecisionPolicy\"\n", + "if not corpora:\n", + " for key in sorted(corpora_all):\n", + " m = _seed_policy_re.match(key)\n", + " if m and m.group(2) == CORPUS_POLICY_PER_SEED:\n", + " corpora[f\"test-seed-{m.group(1)}\"] = corpora_all[key]\n", + "if not corpora:\n", + " raise FileNotFoundError(\n", + " f\"no test-seed- corpora under {base} and none derived from {CORPUS_POLICY_PER_SEED!r}\"\n", + " )\n", + "print(\n", + " f\"[info] test-seed corpora: {len(corpora)} ({', '.join(sorted(corpora))}); \"\n", + " f\"seed-* policy corpora in corpora_all: {len(corpora_all)}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a0bbbd7f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'seed-1-DeferralDecisionPolicy': ,\n", + " 'seed-1-RandomDeferralDecisionPolicy': ,\n", + " 'seed-1-SimulationAverageDecisionPolicy': ,\n", + " 'seed-1-SimulationMajorityDecisionPolicy': ,\n", + " 'seed-1-ThresholdDecisionPolicy': ,\n", + " 'seed-2-DeferralDecisionPolicy': ,\n", + " 'seed-2-RandomDeferralDecisionPolicy': ,\n", + " 'seed-2-SimulationAverageDecisionPolicy': ,\n", + " 'seed-2-SimulationMajorityDecisionPolicy': ,\n", + " 'seed-2-ThresholdDecisionPolicy': ,\n", + " 'seed-3-DeferralDecisionPolicy': ,\n", + " 'seed-3-RandomDeferralDecisionPolicy': ,\n", + " 'seed-3-SimulationAverageDecisionPolicy': ,\n", + " 'seed-3-SimulationMajorityDecisionPolicy': ,\n", + " 'seed-3-ThresholdDecisionPolicy': ,\n", + " 'seed-4-DeferralDecisionPolicy': ,\n", + " 'seed-4-RandomDeferralDecisionPolicy': ,\n", + " 'seed-4-SimulationAverageDecisionPolicy': ,\n", + " 'seed-4-SimulationMajorityDecisionPolicy': ,\n", + " 'seed-4-ThresholdDecisionPolicy': ,\n", + " 'seed-5-DeferralDecisionPolicy': ,\n", + " 'seed-5-RandomDeferralDecisionPolicy': ,\n", + " 'seed-5-SimulationAverageDecisionPolicy': ,\n", + " 'seed-5-SimulationMajorityDecisionPolicy': ,\n", + " 'seed-5-ThresholdDecisionPolicy': }" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "corpora_all" + ] + }, + { + "cell_type": "markdown", + "id": "dab74642", + "metadata": {}, + "source": [ + "Having imported our DeferralDecisionPolicy and ThresholdDecisionPolicy, we now will first define all of the other decision policies to benchmark. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5dabf0bd", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Callable, List, Optional, Dict, Any, Tuple\n", + "import numpy as np\n", + "from convokit.decisionpolicy import DecisionPolicy\n", + "\n", + "\n", + "class _synthetic_speaker:\n", + " def __init__(self, speaker_id: str):\n", + " self.id = speaker_id\n", + "\n", + "\n", + "class _synthetic_utterance:\n", + " def __init__(self, text: str, utterance_id: str, speaker_id: str):\n", + " self.text = text\n", + " self.id = utterance_id\n", + " self.speaker_ = _synthetic_speaker(speaker_id)\n", + " self.meta = {}\n", + "\n", + " def get_conversation(self):\n", + " return None\n", + "\n", + "\n", + "class RandomDeferralDecisionPolicy(DecisionPolicy):\n", + " \"\"\"\n", + " Decision policy that defers intervention by looking ahead at simulated next utterances.\n", + "\n", + " :param simulator: utterance simulator model (must have a ``transform(contexts)`` method\n", + " returning a DataFrame indexed by utterance id). if the simulator exposes\n", + " ``get_num_simulations()``, ``num_simulations`` is capped to that value.\n", + " :param threshold: probability threshold above which a context is flagged.\n", + " :param tau: minimum number of simulated branches that must exceed the threshold\n", + " before an intervention is issued.\n", + " :param num_simulations: how many simulated branches to use per context (capped to\n", + " simulator's ``get_num_simulations()`` if available).\n", + " :param store_simulations: if True, simulated reply strings are cached during decide()\n", + " and written to corpus utterance metadata by post_transform().\n", + " :param simulated_reply_attribute_name: metadata field name used when storing simulations\n", + " on corpus utterances (only relevant when store_simulations=True).\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " simulator,\n", + " threshold,\n", + " deferral_probability: float = 0.1515,\n", + " reuse_cached_forecast_probs: bool = True,\n", + " forecast_prob_attribute_name: str = \"forecast_prob\",\n", + " ):\n", + " # forward the cache flag to the base class so its _score helper honors it.\n", + " # without this, reuse_cached_probabilities on this subclass had no effect.\n", + " super().__init__(\n", + " forecast_prob_attribute_name=forecast_prob_attribute_name,\n", + " reuse_cached_forecast_probs=reuse_cached_forecast_probs,\n", + " )\n", + " self.simulator = simulator\n", + " self.threshold = float(threshold)\n", + " self.deferral_probability = float(deferral_probability)\n", + "\n", + " def _decision_score(self, context, score_fn: Callable):\n", + " # use base _score so a cached forecast_prob on the utterance meta is reused\n", + " # instead of re-invoking the belief estimator.\n", + " return self._score(context, score_fn)\n", + "\n", + " def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]:\n", + " decision_score = self._score(context, score_fn)\n", + "\n", + " p = float(np.random.rand())\n", + " \n", + "\n", + " # return an empty metadata dict (not None) so downstream code that iterates\n", + " # utt_metadata.items() in TransformerDecoderModel.transform doesn't crash.\n", + " return (decision_score,\n", + " 1 if decision_score > self.threshold and p > self.deferral_probability else 0,\n", + " {}\n", + " )\n", + "\n", + " def fit(self, contexts, val_contexts=None, score_fn: Callable = None):\n", + " if val_contexts is None or score_fn is None or self.labeler is None:\n", + " print(\"either no validation contexts/score function/labeler were provided, returning current threshold\")\n", + " return {\"best_threshold\": self.threshold}\n", + "\n", + " val_contexts = list(val_contexts)\n", + " if len(val_contexts) == 0:\n", + " print(\"no validation contexts were provided, returning current threshold\")\n", + " return {\"best_threshold\": self.threshold}\n", + "\n", + " fit_result = self._fit_with_model_checkpoint_selection(val_contexts, score_fn=score_fn)\n", + " if isinstance(fit_result, dict):\n", + " if \"best_threshold\" in fit_result:\n", + " self.threshold = float(fit_result[\"best_threshold\"])\n", + " return fit_result\n", + "\n", + " fit_result = self._fit_threshold_for_loaded_model(val_contexts, score_fn=score_fn)\n", + " if \"best_threshold\" in fit_result:\n", + " self.threshold = float(fit_result[\"best_threshold\"])\n", + " return fit_result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d3db5bc", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Callable, Optional, Dict, Any, Tuple\n", + "\n", + "from convokit.decisionpolicy import DeferralDecisionPolicy\n", + "\n", + "\n", + "class SimulationAverageDecisionPolicy(DeferralDecisionPolicy):\n", + " \"\"\"\n", + " decision policy that intervenes if the mean of the simulated next-utterance\n", + " scores is at or above the threshold.\n", + "\n", + " this subclass inherits all simulation fetching, per-utterance metadata\n", + " caching, sim-score caching, and threshold fitting from\n", + " DeferralDecisionPolicy. the only differences are:\n", + " * no ``tau`` parameter (unused; forwarded as 0 to super)\n", + " * ``decide`` predicts based on mean(simulation_scores) >= threshold\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " simulator,\n", + " threshold,\n", + " num_simulations: int = 10,\n", + " store_simulations: bool = False,\n", + " simulated_reply_attribute_name: str = \"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name: str = \"sim_replies_forecast_probs\",\n", + " reuse_cached_simulations: bool = True,\n", + " ):\n", + " # tau is irrelevant for the mean-based decision rule, so we pin it to 0\n", + " # upstream rather than expose it to callers of this subclass.\n", + " super().__init__(\n", + " simulator=simulator,\n", + " threshold=threshold,\n", + " tau=0,\n", + " num_simulations=num_simulations,\n", + " store_simulations=store_simulations,\n", + " simulated_reply_attribute_name=simulated_reply_attribute_name,\n", + " sim_replies_forecast_probs_attribute_name=sim_replies_forecast_probs_attribute_name,\n", + " reuse_cached_simulations=reuse_cached_simulations,\n", + " )\n", + "\n", + " def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]:\n", + " decision_score, simulations, simulation_scores = self._decision_score(context, score_fn)\n", + " # empty simulation_scores would zero-divide. this happens when the\n", + " # simulator returns no completions for a context (e.g. end-of-conversation\n", + " # contexts that slip through the selector). treat as no intervention so\n", + " # a single degenerate context doesn't abort the whole transform run.\n", + " if len(simulation_scores) == 0:\n", + " average_simulation_score = 0.0\n", + " pred = 0\n", + " else:\n", + " average_simulation_score = sum(simulation_scores) / len(simulation_scores)\n", + " pred = 1 if average_simulation_score >= self.threshold else 0\n", + " return (\n", + " decision_score,\n", + " pred,\n", + " {\n", + " self.simulated_reply_attribute_name: simulations,\n", + " self.sim_replies_forecast_probs_attribute_name: simulation_scores,\n", + " },\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0d1ea7a", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Callable, Optional, Dict, Any, Tuple\n", + "\n", + "from convokit.decisionpolicy import DeferralDecisionPolicy\n", + "\n", + "\n", + "class SimulationMajorityDecisionPolicy(DeferralDecisionPolicy):\n", + " \"\"\"\n", + " decision policy that intervenes if at least ``tau`` of the simulated next\n", + " utterances score above the threshold, ignoring the current utterance score.\n", + "\n", + " this subclass inherits all simulation fetching, per-utterance metadata\n", + " caching, sim-score caching, and threshold fitting from\n", + " DeferralDecisionPolicy. the only difference is in ``decide``: the gate\n", + " ``decision_score > threshold`` is dropped so that only the simulated-branch\n", + " vote count drives the prediction.\n", + " \"\"\"\n", + "\n", + " def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]:\n", + " decision_score, simulations, simulation_scores = self._decision_score(context, score_fn)\n", + " num_simulations_above_threshold = sum(\n", + " 1 for score in simulation_scores if score > self.threshold\n", + " )\n", + " return (\n", + " decision_score,\n", + " 1 if num_simulations_above_threshold >= self.tau else 0,\n", + " {\n", + " self.simulated_reply_attribute_name: simulations,\n", + " self.sim_replies_forecast_probs_attribute_name: simulation_scores,\n", + " },\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "id": "0498693e", + "metadata": {}, + "source": [ + "Since we use simulations in our baselines, we must define a simulation configuration as follows: " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "638fb31c", + "metadata": {}, + "outputs": [], + "source": [ + "SIMULATOR_TRAIN_CONFIG = {\n", + " \"per_device_train_batch_size\": 16,\n", + " \"per_device_eval_batch_size\": 16,\n", + " \"eval_strategy\": \"steps\",\n", + " \"save_strategy\": \"steps\",\n", + " \"save_steps\": 30,\n", + " \"gradient_accumulation_steps\": 4,\n", + " \"warmup_steps\": 5,\n", + " \"num_train_epochs\": 1,\n", + " \"eval_steps\": 30,\n", + " \"learning_rate\": 2e-4,\n", + " \"logging_steps\": 5,\n", + " \"optim\": \"adamw_8bit\",\n", + " \"weight_decay\": 0.01,\n", + " \"lr_scheduler_type\": \"linear\",\n", + " \"output_dir\": \"outputs/simulator_finetune\",\n", + " \"logging_dir\": \"logs\",\n", + " \"load_best_model_at_end\": True,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf55d140", + "metadata": {}, + "outputs": [], + "source": [ + "TAU = 7\n", + "DEFERRAL_PROBABILITY_THRESHOLD = 0.2518938553561718\n", + "NUM_SIMULATIONS = 10\n", + "OUTPUT_DIR = \"benchmark_preannotated\"\n", + "SEEDS = [1,2,3,4,5]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aedd02fe", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==((====))== Unsloth 2025.7.11: Fast Llama patching. Transformers: 4.53.3.\n", + " \\\\ /| NVIDIA RTX A6000. Num GPUs = 1. Max memory: 47.536 GB. Platform: Linux.\n", + "O^O/ \\_/ \\ Torch: 2.7.1+cu128. CUDA: 8.6. CUDA Toolkit: 12.8. Triton: 3.1.0+cf34004b8a\n", + "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.31.post1. FA2 = False]\n", + " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", + "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Unsloth 2025.7.11 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.\n", + "Unsloth: Already have LoRA adapters! We shall skip this step.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Unsloth: Training embed_tokens in mixed precision to save VRAM\n", + "Unsloth: Training lm_head in mixed precision to save VRAM\n" + ] + } + ], + "source": [ + "simulator_model = UnslothUtteranceSimulatorModel(\n", + " model_name=\"/reef/lyk25/dynamic_training/game_analysis/outputs/checkpoint-74\",\n", + " train_config=SIMULATOR_TRAIN_CONFIG,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "fe147ae8", + "metadata": {}, + "source": [ + "After loading the simulator model, we define context selectors." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a9502041", + "metadata": {}, + "outputs": [], + "source": [ + "def context_selector(context_tuple, split):\n", + " \"\"\"\n", + " We use this generic function for both training and validation data.\n", + " In both cases, its job is to select only those contexts for which the\n", + " FUTURE context is not empty, so we have a next utterance to predict.\n", + " \"\"\"\n", + " matches_split = (context_tuple.current_utterance.get_conversation().meta.get(\"split\") == split)\n", + " is_end = (len(context_tuple.future_context) == 0)\n", + " return matches_split and not is_end\n", + "\n", + "def make_data_selector(split):\n", + " return lambda context_tuple: context_tuple.current_utterance.get_conversation().meta.get(\"split\") == split\n", + "\n", + "train_context_selector = partial(context_selector, split=\"train\")\n", + "val_context_selector = partial(context_selector, split=\"val\")\n", + "test_context_selector = partial(context_selector, split=\"test\")" + ] + }, + { + "cell_type": "markdown", + "id": "f707a3a5", + "metadata": {}, + "source": [ + "Below is a script to fully reproduce, sans regenerating simulations (simulations requires substantially more compute)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dde130ed", + "metadata": {}, + "outputs": [], + "source": [ + "for seed_idx in SEEDS:\n", + " train_corpus = Corpus(filename=download('conversations-gone-awry-cmv-corpus-large'))\n", + " # corpus = Corpus(filename=f'/reef/lyk25/theres-a-way-out-ACL26-internal/outputs/benchmark_preannotated/seed-{seed_idx}-ThresholdDecisionPolicy/ThresholdDecisionPolicy')\n", + " # corpus = Corpus(filename=f'/reef/lyk25/dynamic_training/game_analysis/corpi/test/test-son-seed-{seed_idx}')\n", + " corpus = Corpus(filename=download('conversations-gone-awry-cmv-corpus-large'))\n", + " corpus.filter_conversations_by(lambda convo: convo.meta['split'] == 'test')\n", + "\n", + " config = TransformerForecasterConfig(\n", + " output_dir=f\"outputs/{OUTPUT_DIR}/forecaster_{seed_idx}\",\n", + " per_device_batch_size=16,\n", + " gradient_accumulation_steps=1,\n", + " num_train_epochs=1,\n", + " learning_rate=1e-5,\n", + " random_seed=seed_idx,\n", + " context_mode=\"normal\",\n", + " device=\"cuda\",\n", + " )\n", + "\n", + " # TODO this will have to be edited\n", + " forecaster_model = TransformerDecoderModel(\n", + " model_name_or_path=\"google/gemma-2-9b-it\",\n", + " config=config,\n", + " )\n", + "\n", + " forecaster = Forecaster(\n", + " forecaster_model=forecaster_model,\n", + " labeler='has_removed_comment',\n", + " )\n", + "\n", + " forecaster.fit_belief_estimator(\n", + " corpus=train_corpus,\n", + " context_selector=train_context_selector,\n", + " val_context_selector=val_context_selector,\n", + " )\n", + "\n", + " # ---\n", + " cfg_path = os.path.join(repo_root, \"saves\", f\"seed-{seed_idx}\", \"dev_config.json\")\n", + " with open(cfg_path) as f:\n", + " cfg = json.load(f)\n", + " best_threshold = cfg['best_threshold']\n", + "\n", + " for policy_trial in [ThresholdDecisionPolicy, DeferralDecisionPolicy, RandomDeferralDecisionPolicy, SimulationAverageDecisionPolicy, SimulationMajorityDecisionPolicy]:\n", + " print('---')\n", + " print(f\"Fitting policy {policy_trial.__name__} for seed {seed_idx}\")\n", + " if policy_trial == ThresholdDecisionPolicy:\n", + " policy = ThresholdDecisionPolicy(\n", + " threshold=best_threshold,\n", + " reuse_cached_forecast_probs=False,\n", + " )\n", + " elif policy_trial == DeferralDecisionPolicy:\n", + " policy = DeferralDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " tau=TAU,\n", + " reuse_cached_forecast_probs=False,\n", + " )\n", + " elif policy_trial == RandomDeferralDecisionPolicy:\n", + " policy = RandomDeferralDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " deferral_probability=DEFERRAL_PROBABILITY_THRESHOLD,\n", + " reuse_cached_forecast_probs=False,\n", + " )\n", + " elif policy_trial == SimulationAverageDecisionPolicy:\n", + " policy = SimulationAverageDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " num_simulations=NUM_SIMULATIONS,\n", + " store_simulations=False,\n", + " simulated_reply_attribute_name=\"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name=\"sim_replies_forecast_probs\",\n", + " reuse_cached_forecast_probs=False,\n", + " )\n", + " elif policy_trial == SimulationMajorityDecisionPolicy:\n", + " policy = SimulationMajorityDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " tau=TAU,\n", + " num_simulations=NUM_SIMULATIONS,\n", + " store_simulations=False,\n", + " simulated_reply_attribute_name=\"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name=\"sim_replies_forecast_probs\",\n", + " reuse_cached_forecast_probs=False,\n", + " )\n", + " \n", + " # attach the decision policy to the underlying forecaster model;\n", + " # Forecaster itself does not accept decision_policy in its constructor.\n", + " forecaster_model.decision_policy = policy\n", + "\n", + " forecaster = Forecaster(\n", + " forecaster_model=forecaster_model,\n", + " labeler='has_removed_comment',\n", + " )\n", + "\n", + " print('starting transformation.')\n", + " # evaluate the forecaster on the test set\n", + " forecaster.transform(\n", + " corpus=corpus,\n", + " context_selector=make_data_selector('test'),\n", + " verbose=True,\n", + " )\n", + " print('transformation complete.')\n", + "\n", + " output_dir = f\"outputs/{OUTPUT_DIR}/seed-{seed_idx}-{policy_trial.__name__}\"\n", + " os.makedirs(output_dir, exist_ok=True)\n", + " corpus.dump(name=f\"{policy_trial.__name__}\", base_path=output_dir)\n", + " print('corpus dumped.')\n", + "\n", + " print('starting summarization.')\n", + " # forecaster.summarize expects a conversation-level selector (Callable[[Conversation], bool]),\n", + " # unlike the context-tuple selectors used in fit/transform.\n", + " def summarize_selector(convo):\n", + " return convo.meta.get(\"split\") == \"test\"\n", + " conversational_forecasts_df, metrics = forecaster.summarize(\n", + " corpus=corpus,\n", + " selector=summarize_selector,\n", + " )\n", + " print('summarization complete.')\n", + " \n", + " # path to the seed output directory\n", + " seed_folder = f\"outputs/{OUTPUT_DIR}/seed-{seed_idx}-{policy_trial.__name__}\"\n", + "\n", + " # ensure the directory exists\n", + " os.makedirs(seed_folder, exist_ok=True)\n", + "\n", + " # save conversational_forecasts_df as CSV\n", + " conversational_forecasts_df.to_csv(os.path.join(seed_folder, \"conversational_forecasts.csv\"), index=False)\n", + "\n", + " # save metrics as JSON\n", + " with open(os.path.join(seed_folder, \"metrics.json\"), \"w\") as f:\n", + " json.dump(metrics, f, indent=2)" + ] + }, + { + "cell_type": "markdown", + "id": "e3c6f428", + "metadata": {}, + "source": [ + "A faster reproduction is possible by skipping the training and transformation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f4a29c2", + "metadata": {}, + "outputs": [], + "source": [ + "for seed_idx in SEEDS:\n", + " config = TransformerForecasterConfig(\n", + " output_dir=f\"outputs/{OUTPUT_DIR}/forecaster_{seed_idx}\",\n", + " per_device_batch_size=16,\n", + " gradient_accumulation_steps=1,\n", + " num_train_epochs=1,\n", + " learning_rate=1e-5,\n", + " random_seed=seed_idx,\n", + " context_mode=\"normal\",\n", + " device=\"cuda\",\n", + " )\n", + "\n", + " # TODO this will have to be edited\n", + " forecaster_model = TransformerDecoderModel(\n", + " model_name_or_path=\"google/gemma-2-9b-it\",\n", + " config=config,\n", + " )\n", + "\n", + " forecaster = Forecaster(\n", + " forecaster_model=forecaster_model,\n", + " labeler='has_removed_comment',\n", + " )\n", + "\n", + " # remove training---we can instead use the cached forecast probabilities.\n", + "\n", + " # ---\n", + " cfg_path = os.path.join(repo_root, \"saves\", f\"seed-{seed_idx}\", \"dev_config.json\")\n", + " with open(cfg_path) as f:\n", + " cfg = json.load(f)\n", + " best_threshold = cfg['best_threshold']\n", + "\n", + " for policy_trial in [ThresholdDecisionPolicy, DeferralDecisionPolicy, RandomDeferralDecisionPolicy, SimulationAverageDecisionPolicy, SimulationMajorityDecisionPolicy]:\n", + " corpus_name = f\"seed-{seed_idx}-{policy_trial.__name__}\"\n", + " if corpus_name in corpora:\n", + " corpus = corpora_all[corpus_name]\n", + " else:\n", + " raise KeyError(f\"missing corpus {corpus_name}\")\n", + "\n", + " print('---')\n", + " print(f\"fitting policy {policy_trial.__name__} for seed {seed_idx}\")\n", + " if policy_trial == ThresholdDecisionPolicy:\n", + " policy = ThresholdDecisionPolicy(\n", + " threshold=best_threshold,\n", + " )\n", + " elif policy_trial == DeferralDecisionPolicy:\n", + " policy = DeferralDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " tau=TAU,\n", + " )\n", + " elif policy_trial == RandomDeferralDecisionPolicy:\n", + " policy = RandomDeferralDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " deferral_probability=DEFERRAL_PROBABILITY_THRESHOLD,\n", + " )\n", + " elif policy_trial == SimulationAverageDecisionPolicy:\n", + " policy = SimulationAverageDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " num_simulations=NUM_SIMULATIONS,\n", + " store_simulations=False,\n", + " simulated_reply_attribute_name=\"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name=\"sim_replies_forecast_probs\",\n", + " )\n", + " elif policy_trial == SimulationMajorityDecisionPolicy:\n", + " policy = SimulationMajorityDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " tau=TAU,\n", + " num_simulations=NUM_SIMULATIONS,\n", + " store_simulations=False,\n", + " simulated_reply_attribute_name=\"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name=\"sim_replies_forecast_probs\",\n", + " )\n", + " \n", + " # attach the decision policy to the underlying forecaster model;\n", + " # Forecaster itself does not accept decision_policy in its constructor.\n", + " forecaster_model.decision_policy = policy\n", + "\n", + " forecaster = Forecaster(\n", + " forecaster_model=forecaster_model,\n", + " labeler='has_removed_comment',\n", + " )\n", + "\n", + " print('starting transformation.')\n", + " # evaluate the forecaster on the test set\n", + " forecaster.transform(\n", + " corpus=corpus,\n", + " context_selector=make_data_selector('test'),\n", + " verbose=True,\n", + " )\n", + " print('transformation complete.')\n", + "\n", + " output_dir = f\"outputs/{OUTPUT_DIR}/seed-{seed_idx}-{policy_trial.__name__}\"\n", + " os.makedirs(output_dir, exist_ok=True)\n", + " corpus.dump(name=f\"{policy_trial.__name__}\", base_path=output_dir)\n", + " print('corpus dumped.')\n", + "\n", + " print('starting summarization.')\n", + " # forecaster.summarize expects a conversation-level selector (Callable[[Conversation], bool]),\n", + " # unlike the context-tuple selectors used in fit/transform.\n", + " def summarize_selector(convo):\n", + " return convo.meta.get(\"split\") == \"test\"\n", + " conversational_forecasts_df, metrics = forecaster.summarize(\n", + " corpus=corpus,\n", + " selector=summarize_selector,\n", + " )\n", + " print('summarization complete.')\n", + " \n", + " # path to the seed output directory\n", + " seed_folder = f\"outputs/{OUTPUT_DIR}/seed-{seed_idx}-{policy_trial.__name__}\"\n", + "\n", + " # ensure the directory exists\n", + " os.makedirs(seed_folder, exist_ok=True)\n", + "\n", + " # save conversational_forecasts_df as CSV\n", + " conversational_forecasts_df.to_csv(os.path.join(seed_folder, \"conversational_forecasts.csv\"), index=False)\n", + "\n", + " # save metrics as JSON\n", + " with open(os.path.join(seed_folder, \"metrics.json\"), \"w\") as f:\n", + " json.dump(metrics, f, indent=2)" + ] + }, + { + "cell_type": "markdown", + "id": "29a0cdc1", + "metadata": {}, + "source": [ + "## Human benchmark analysis" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "d673d0d6", + "metadata": {}, + "outputs": [], + "source": [ + "# import human data SQL here" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "d55b3bab", + "metadata": {}, + "outputs": [], + "source": [ + "import sqlite3" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "6047d55a", + "metadata": {}, + "outputs": [], + "source": [ + "round_n = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "92766a80", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] connected to database: /reef/sqt2/cga-eval/human/game_db.sqlite\n" + ] + } + ], + "source": [ + "# connect to the game database\n", + "db_path = '/reef/sqt2/cga-eval/human/game_db.sqlite'\n", + "connection = sqlite3.connect(db_path)\n", + "cursor = connection.cursor()\n", + "\n", + "print(f\"[info] connected to database: {db_path}\")\n", + "cursor.execute(\"SELECT COUNT(*) FROM results\")\n", + "\n", + "# # names_list = []\n", + "# # answers_list = []\n", + "# # scores_list = []\n", + "# # comments_list = []\n", + "\n", + "# TIME_CUTOFF = 1714059814630 # timestamp to cut off previous results\n", + "# TIME_CUTOFF_END = 1730385762530\n", + "\n", + "# for row in connection.execute('SELECT * FROM results'):\n", + "# if row[0] != 'yc2727' and row[0] != 'sqt2':\n", + "# answers = json.loads(row[1])\n", + "# if answers[0]['start_time'] > TIME_CUTOFF and answers[0]['start_time'] < TIME_CUTOFF_END:\n", + "# names_list.append(row[0])\n", + "# answers_list.append(answers)\n", + "# scores_list.append(row[2])\n", + "# comments_list.append(row[3])\n", + "\n", + "if round_n == 1:\n", + " # for round_n 1\n", + " names_list_1 = []\n", + " answers_list_1 = []\n", + " scores_list_1 = []\n", + " comments_list_1 = []\n", + "\n", + " TIME_CUTOFF_1 = 1730395000000\n", + " TIME_CUTOFF_END_1 = 1730397199000\n", + "\n", + " for row in connection.execute('SELECT * FROM results'):\n", + " if row[0] != 'yc2727' and row[0] != 'sqt2':\n", + " answers = json.loads(row[1])\n", + " if answers[0]['start_time'] > TIME_CUTOFF_1 and answers[0]['start_time'] < TIME_CUTOFF_END_1:\n", + " names_list_1.append(row[0])\n", + " answers_list_1.append(answers)\n", + " scores_list_1.append(row[2])\n", + " comments_list_1.append(row[3])\n", + "elif round_n == 2:\n", + " # for round 2\n", + " names_list_2 = []\n", + " answers_list_2 = []\n", + " scores_list_2 = []\n", + " comments_list_2 = []\n", + "\n", + " TIME_CUTOFF_2_a = 1731002910000\n", + " TIME_CUTOFF_END_2_a = 1731016180000\n", + "\n", + " for row in connection.execute('SELECT * FROM results'):\n", + " if row[0] != 'yc2727' and row[0] != 'sqt2':\n", + " answers = json.loads(row[1])\n", + " if answers[0]['start_time'] > TIME_CUTOFF_2_a and answers[0]['start_time'] < TIME_CUTOFF_END_2_a:\n", + " names_list_2.append(row[0])\n", + " answers_list_2.append(answers)\n", + " scores_list_2.append(row[2])\n", + " comments_list_2.append(row[3])\n", + "\n", + " # The 2_b cutoff below is to specifically add data from 'ljl2' for a later second round window,\n", + " # likely because they submitted their data late or for a different time block.\n", + " TIME_CUTOFF_2_b = 1732212720000\n", + " TIME_CUTOFF_END_2_b = 1740000000000\n", + "\n", + " for row in connection.execute('SELECT * FROM results'):\n", + " if row[0].lower() == 'ljl2':\n", + " answers = json.loads(row[1])\n", + " if answers[0]['start_time'] > TIME_CUTOFF_2_b and answers[0]['start_time'] < TIME_CUTOFF_END_2_b:\n", + " names_list_2.append(row[0])\n", + " answers_list_2.append(answers)\n", + " scores_list_2.append(row[2])\n", + " comments_list_2.append(row[3])\n", + "\n", + " # print(f\"Round 1: {len(names_list_1)} entries\")\n", + " # print(f\"Round 2: {len(names_list_2)} entries\")\n", + " # names_list_1, names_list_2" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "889d7ede", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] total unique conversation ids: 84\n", + "['givp578', 'e1vstxd', 'f1353ra', 'ewr7ls3', 'dmlny27', 'dg7wmdb', 'flixm8f', 'f13u0n2', 'ggwdbsa', 'd3f07tv', 'isz5n6g', 'givok1i', 'dyx3au1', 'g1sm9mt', 'f0dryv4', 'doi44j8', 'g2998qp', 'dpirnlc', 'h2h7yl0', 'dg7x3eu', 'ej6jusi', 'ikzb571', 'e1vv6x1', 'ifkoazb', 'dpkib9q', 'g1q3upb', 'dydawov', 'g1q1973', 'dm8exht', 'fphcwq0', 'gmxqrdz', 'dmlqnzz', 'dyd7cfv', 'i4rgtqf', 'doi3vuz', 'd0fyu9x', 'f014doo', 'dm8qu78', 'gmxoyuu', 'fmre7l9', 'ii3cpve', 'ii4ivim', 'e9mynuy', 'djh1bt1', 'f83hb2k', 'ewtowqu', 'dbnxbzn', 'fixjxxs', 'dvivs9p', 'ej6ollb', 'g1spg0k', 'ggwei5y', 'd0fzp40', 'gvb1ekl', 'djh2v24', 'dvisfl0', 'ewtjxp2', 'dyx2jn4', 'ewr7msm', 'd3efg03', 'fixlxvy', 'fegw71k', 'ikzcblg', 'ffpk80c', 'fmspcex', 'f00nxjs', 'gvdnrrb', 'h28ltfw', 'f0dsbw4', 'dbnr5nd', 'ikpw1ol', 'ffpj88q', 'i4rh1id', 'e9mybxp', 'isz6a7m', 'ifknq44', 'g297nn7', 'fegwvia', 'ikpxahv', 'f83akyz', 'h2hauaj', 'h28lknz', 'g3hmlew', 'fljdrtt']\n" + ] + } + ], + "source": [ + "human_map1 = {\n", + " \"vn72\": ['d3efg03', 'd3f07tv', 'f1353ra', 'f13u0n2', 'fegw71k', 'fegwvia', 'ikpw1ol', 'ikpxahv', 'isz5n6g', 'isz6a7m'],\n", + " \"nac86\": ['dyd7cfv', 'dydawov', 'ewtjxp2', 'ewtowqu', 'f00nxjs', 'f014doo', 'ii3cpve', 'ii4ivim', 'ikzb571', 'ikzcblg'],\n", + " # \"sj597\": ['dbnr5nd', 'dbnxbzn', 'g1sm9mt', 'g1spg0k', 'g297nn7', 'g2998qp', 'h2h7yl0', 'h2hauaj', 'ifknq44', 'ifkoazb'],\n", + " \"yc2727\": ['d08x3kl', 'd08x4q0', 'd3efg03', 'd3f07tv', 'ewtjxp2', 'ewtowqu', 'hnupywh', 'hnuu3ip', 'ih2fn94', 'ih3iwph'],\n", + " # \"ex36\": ['d0fyu9x', 'd0fzp40', 'dg7wmdb', 'dg7x3eu', 'ej6jusi', 'ej6ollb', 'ewtjxp2', 'ewtowqu', 'f0dryv4', 'f0dsbw4'],\n", + " \"kz88\": ['doi3vuz', 'doi44j8', 'e9mybxp', 'e9mynuy', 'g1q1973', 'g1q3upb', 'ggwdbsa', 'ggwei5y', 'gmxoyuu', 'gmxqrdz'],\n", + " \"LJL2\": ['ffpj88q', 'ffpk80c', 'flixm8f', 'fljdrtt', 'fmre7l9', 'fmspcex', 'fphcwq0', 'g3hmlew', 'givok1i', 'givp578'],\n", + " \"lyk25\": ['dpirnlc', 'dpkib9q', 'e1vstxd', 'e1vv6x1', 'ewr7ls3', 'ewr7msm', 'fixjxxs', 'fixlxvy', 'fmre7l9', 'fmspcex'],\n", + " \"sqt2\": ['d3efg03', 'd3f07tv', 'g0fwpzc', 'g0ggz8e', 'g8nyz4f', 'g8nzx3m', 'h4i75b9', 'h4i79lc', 'h6ikmzc', 'h6ime20'],\n", + " \"cd326\": ['djh1bt1', 'djh2v24', 'dmlny27', 'dmlqnzz', 'f83akyz', 'f83hb2k', 'h28lknz', 'h28ltfw', 'i4rgtqf', 'i4rh1id'],\n", + " \"tg352\": ['dm8exht', 'dm8qu78', 'dvisfl0', 'dvivs9p', 'dyx2jn4', 'dyx3au1', 'gvb1ekl', 'gvdnrrb', 'i4rgtqf', 'i4rh1id'],\n", + "}\n", + "\n", + "# all_convo_ids = []\n", + "# for k,v in human_map1.items():\n", + "# all_convo_ids.extend(v)\n", + "# all_convo_ids = list(set(all_convo_ids))\n", + "# len(all_convo_ids)\n", + "\n", + "all_convo_ids = []\n", + "# # collect from round 1\n", + "if round_n == 1:\n", + " for i in range(len(answers_list_1)):\n", + " for j in range(len(answers_list_1[i])):\n", + " all_convo_ids.append(answers_list_1[i][j]['id'])\n", + "# collect from round 2\n", + "elif round_n == 2:\n", + " for i in range(len(answers_list_2)):\n", + " for j in range(len(answers_list_2[i])):\n", + " all_convo_ids.append(answers_list_2[i][j]['id'])\n", + "all_convo_ids = list(set(all_convo_ids))\n", + "print(f\"[info] total unique conversation ids: {len(all_convo_ids)}\")\n", + "print(all_convo_ids)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "e55a3803", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset already exists at /reef/lyk25/ConvoKit/examples/forecaster/conversations-gone-awry-cmv-corpus-large\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "corpus = Corpus(filename=download('conversations-gone-awry-cmv-corpus-large'))\n", + "corpus.filter_conversations_by(lambda convo: convo.id in all_convo_ids)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "168fb025", + "metadata": {}, + "outputs": [], + "source": [ + "# we want all utterances to have a human_guesses meta field\n", + "for convo in corpus.iter_conversations():\n", + " for utt in convo.get_chronological_utterance_list():\n", + " utt.add_meta('human_guesses', [])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27fc646a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "vn72\n", + "nac86\n", + "sj597\n", + "ex36\n", + "kz88\n", + "LJL2\n", + "lyk25\n", + "cd326\n", + "tg352\n" + ] + } + ], + "source": [ + "if round_n == 1:\n", + " for i, participant_sessions in enumerate(answers_list_1):\n", + " participant_id = names_list_1[i]\n", + " print(participant_id)\n", + " for session in participant_sessions:\n", + " convo_id = session['id']\n", + " convo = corpus.get_conversation(convo_id)\n", + " utts = convo.get_chronological_utterance_list()\n", + " actions = session.get('actions', [])\n", + "\n", + " for action_idx, action in enumerate(actions):\n", + " if action.get('guess') is True:\n", + " utt = utts[action_idx]\n", + " # get current guesses (returns deep copy if exists, or empty list if not)\n", + " # create a new list to avoid mutating the copy\n", + " current_guesses = list(utt.meta.get('human_guesses', []))\n", + " # append new participant and set back\n", + " current_guesses.append(participant_id)\n", + " utt.add_meta('human_guesses', current_guesses)\n", + "elif round_n == 2:\n", + " for i, participant_sessions in enumerate(answers_list_2):\n", + " participant_id = names_list_2[i]\n", + " print(participant_id)\n", + " for session in participant_sessions:\n", + " convo_id = session['id']\n", + " convo = corpus.get_conversation(convo_id)\n", + " utts = convo.get_chronological_utterance_list()\n", + " actions = session.get('actions', [])\n", + "\n", + " for action_idx, action in enumerate(actions):\n", + " if action.get('guess') is True:\n", + " utt = utts[action_idx]\n", + " # get current guesses (returns deep copy if exists, or empty list if not)\n", + " # create a new list to avoid mutating the copy\n", + " current_guesses = list(utt.meta.get('human_guesses', []))\n", + " # append new participant and set back\n", + " current_guesses.append(participant_id)\n", + " utt.add_meta('human_guesses', current_guesses)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "fd613b4b", + "metadata": {}, + "outputs": [], + "source": [ + "# initialize model prediction metadata fields for each utterance\n", + "for convo in corpus.iter_conversations():\n", + " for utt in convo.get_chronological_utterance_list():\n", + " utt.add_meta('model_forecast_probs', {}) # will store {seed: prob}\n", + " utt.add_meta('model_forecasts', {}) # will store {seed: binary_forecast}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "873cb706", + "metadata": {}, + "outputs": [], + "source": [ + "# copy model predictions from the corpora loaded above\n", + "loaded_seed_nums = sorted(\n", + " int(corpus_name.rsplit(\"-\", 1)[-1])\n", + " for corpus_name in corpora\n", + " if corpus_name.startswith(\"test-seed-\")\n", + ")\n", + "\n", + "if not loaded_seed_nums:\n", + " available = \", \".join(sorted(corpora))\n", + " raise KeyError(f\"no test seed corpora found; available corpora: {available}\")\n", + "\n", + "for seed_num in loaded_seed_nums:\n", + " seed_key = f\"test-seed-{seed_num}\"\n", + " print(f\"[info] loading predictions from seed {seed_num}: {seed_key}\")\n", + " seed_corpus = corpora[seed_key]\n", + "\n", + " for convo in corpus.iter_conversations():\n", + " seed_convo = seed_corpus.get_conversation(convo.id)\n", + " main_utts = convo.get_chronological_utterance_list()\n", + " seed_utts = seed_convo.get_chronological_utterance_list()\n", + "\n", + " for main_utt, seed_utt in zip(main_utts, seed_utts):\n", + " forecast_prob = seed_utt.meta.get(\"forecast_prob\")\n", + " forecast = seed_utt.meta.get(\"forecast\")\n", + "\n", + " if forecast_prob is not None:\n", + " current_probs = dict(main_utt.meta.get(\"model_forecast_probs\", {}))\n", + " current_probs[f\"seed_{seed_num}\"] = forecast_prob\n", + " main_utt.add_meta(\"model_forecast_probs\", current_probs)\n", + "\n", + " if forecast is not None:\n", + " current_forecasts = dict(main_utt.meta.get(\"model_forecasts\", {}))\n", + " current_forecasts[f\"seed_{seed_num}\"] = forecast\n", + " main_utt.add_meta(\"model_forecasts\", current_forecasts)\n", + "\n", + "print(\"\\n[pass] model predictions from all seeds added to corpus\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "5b0f0e44", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "gemma horizon (TPs only, mean - 1):\n", + " mean over seeds: h=nan\n", + " pooled: n=0, h=nan\n", + "human horizon (round 1, TPs only, mean - 1):\n", + " player vn72: n=1, h=6.0000\n", + " player nac86: n=2, h=0.5000\n", + " player sj597: n=4, h=2.2500\n", + " player ex36: n=1, h=2.0000\n", + " player kz88: n=2, h=2.5000\n", + " player LJL2: n=4, h=1.5000\n", + " player lyk25: n=4, h=2.5000\n", + " player cd326: n=2, h=1.5000\n", + " player tg352: n=2, h=5.0000\n", + " mean over players: h=2.6389\n", + " pooled: n=22, h=2.3636\n", + "human horizon (round 2, TPs only, mean - 1):\n", + " player sj597: n=4, h=1.2500\n", + " player nac86: n=2, h=4.0000\n", + " player ex36: n=2, h=0.5000\n", + " player vn72: n=3, h=1.6667\n", + " player lyk25: n=4, h=3.0000\n", + " player kz88: n=1, h=3.0000\n", + " player cd326: n=4, h=3.2500\n", + " player tg352: n=3, h=2.0000\n", + " player LJL2: n=2, h=0.5000\n", + " mean over players: h=2.1296\n", + " pooled: n=25, h=2.1600\n", + "\n", + "summary (comparable to h in performance_utils_wiki):\n", + " gemma h (mean over seeds): nan\n", + " round1 h (mean over players): 2.6389\n", + " round2 h (mean over players): 2.1296\n" + ] + } + ], + "source": [ + "# horizon comparison between gemma and humans, using the same convention as\n", + "# performance_utils_wiki.calculate_current_performance:\n", + "# - only count true positives (ground truth awry AND trigger fired)\n", + "# - horizon = (len(utts) - first_trigger_index) - 1\n", + "# - first trigger only (break after first positive)\n", + "#\n", + "# reports per-round results for humans (round 1 and round 2), pulling data\n", + "# directly from the game sqlite so it works regardless of which round_n the\n", + "# rest of the notebook was run under. ljl2 is excluded.\n", + "\n", + "import sqlite3\n", + "import json as _json\n", + "from collections import defaultdict\n", + "import numpy as np\n", + "\n", + "def _horizon_mean(horizons):\n", + " # matches performance_utils_wiki: h = mean(horizons) - 1\n", + " if len(horizons) == 0:\n", + " return float('nan'), 0\n", + " return float(np.mean(horizons)) - 1, len(horizons)\n", + "\n", + "# always excluded (staff/test accounts). ljl2 is a legitimate late submitter\n", + "# in the round 2_b window, not an exclusion.\n", + "base_excluded = {'yc2727', 'sqt2'}\n", + "round1_excluded = set(base_excluded)\n", + "round2_excluded = set(base_excluded)\n", + "\n", + "# gemma: per-seed first-trigger horizons on TP convos only (ground truth awry)\n", + "gemma_horizons_by_seed = defaultdict(list)\n", + "for convo in corpus.iter_conversations():\n", + " if not convo.meta.get('has_removed_comment', False):\n", + " continue # skip non-awry convos; horizon is only meaningful on TPs\n", + " utts = convo.get_chronological_utterance_list()\n", + " for seed_num in range(1, 6):\n", + " for i, utt in enumerate(utts):\n", + " if utt.meta.get('model_forecasts', {}).get(f'seed_{seed_num}') == 1:\n", + " gemma_horizons_by_seed[seed_num].append(len(utts) - i)\n", + " break\n", + "\n", + "print('gemma horizon (TPs only, mean - 1):')\n", + "gemma_per_seed_h = []\n", + "for seed_num in sorted(gemma_horizons_by_seed.keys()):\n", + " h, n = _horizon_mean(gemma_horizons_by_seed[seed_num])\n", + " gemma_per_seed_h.append(h)\n", + " print(f' seed {seed_num}: n={n}, h={h:.4f}')\n", + "gemma_h_mean_of_seeds = float(np.mean(gemma_per_seed_h)) if gemma_per_seed_h else float('nan')\n", + "all_gemma_vals = [v for vs in gemma_horizons_by_seed.values() for v in vs]\n", + "gemma_h_pooled, gemma_n_pooled = _horizon_mean(all_gemma_vals)\n", + "print(f' mean over seeds: h={gemma_h_mean_of_seeds:.4f}')\n", + "print(f' pooled: n={gemma_n_pooled}, h={gemma_h_pooled:.4f}')\n", + "\n", + "# pull round-specific human guesses directly from the sqlite db\n", + "db_path = '/reef/sqt2/cga-eval/human/game_db.sqlite'\n", + "_conn = sqlite3.connect(db_path)\n", + "\n", + "def _load_round_answers(round_n):\n", + " # returns list of (participant_id, sessions) for the given round\n", + " entries = []\n", + " if round_n == 1:\n", + " excluded = round1_excluded\n", + " t_start, t_end = 1730395000000, 1730397199000\n", + " for row in _conn.execute('SELECT * FROM results'):\n", + " if row[0] in excluded:\n", + " continue\n", + " answers = _json.loads(row[1])\n", + " if answers and answers[0]['start_time'] > t_start and answers[0]['start_time'] < t_end:\n", + " entries.append((row[0], answers))\n", + " elif round_n == 2:\n", + " excluded = round2_excluded\n", + " t_start_a, t_end_a = 1731002910000, 1731016180000\n", + " t_start_b, t_end_b = 1732212720000, 1740000000000\n", + " for row in _conn.execute('SELECT * FROM results'):\n", + " name = row[0]\n", + " if name in excluded:\n", + " continue\n", + " answers = _json.loads(row[1])\n", + " if not answers:\n", + " continue\n", + " st = answers[0]['start_time']\n", + " in_a = t_start_a < st < t_end_a\n", + " # the 2_b late window is only valid for ljl2 (matches cell 7)\n", + " in_b = (name.lower() == 'ljl2') and (t_start_b < st < t_end_b)\n", + " if in_a or in_b:\n", + " entries.append((name, answers))\n", + " return entries\n", + "\n", + "def _unique_convo_ids(entries):\n", + " # unique convo ids seen across all included players' sessions\n", + " ids = set()\n", + " for _, sessions in entries:\n", + " for session in sessions:\n", + " cid = session.get('id')\n", + " if cid is not None:\n", + " ids.add(cid)\n", + " return ids\n", + "\n", + "def _compute_human_horizons(entries):\n", + " # returns dict: player_id -> list of horizons (non-awry convos only, first guess per convo)\n", + " by_player = defaultdict(list)\n", + " for participant_id, sessions in entries:\n", + " for session in sessions:\n", + " convo_id = session.get('id')\n", + " if convo_id is None:\n", + " continue\n", + " try:\n", + " convo = corpus.get_conversation(convo_id)\n", + " except KeyError:\n", + " # convo not in the loaded corpus (e.g., different round loaded)\n", + " continue\n", + " if not convo.meta.get('has_removed_comment', False):\n", + " continue # skip non-awry convos; horizon only counted on TPs\n", + " utts = convo.get_chronological_utterance_list()\n", + " for action_idx, action in enumerate(session.get('actions', [])):\n", + " if action.get('guess') is True:\n", + " if action_idx < len(utts):\n", + " by_player[participant_id].append(len(utts) - action_idx)\n", + " break # only first guess per (player, convo)\n", + " return by_player\n", + "\n", + "def _print_human_horizons(label, by_player):\n", + " print(f'human horizon ({label}, TPs only, mean - 1):')\n", + " player_h = []\n", + " for player, horizons in by_player.items():\n", + " h, n = _horizon_mean(horizons)\n", + " player_h.append(h)\n", + " print(f' player {player}: n={n}, h={h:.4f}')\n", + " mean_of_players = float(np.mean(player_h)) if player_h else float('nan')\n", + " all_vals = [v for vs in by_player.values() for v in vs]\n", + " h_pooled, n_pooled = _horizon_mean(all_vals)\n", + " print(f' mean over players: h={mean_of_players:.4f}')\n", + " print(f' pooled: n={n_pooled}, h={h_pooled:.4f}')\n", + " return mean_of_players, h_pooled\n", + "\n", + "EXPECTED_N_CONVOS = 84\n", + "round1_entries = _load_round_answers(1)\n", + "round1_convo_ids = _unique_convo_ids(round1_entries)\n", + "round1_by_player = _compute_human_horizons(round1_entries)\n", + "r1_mean, r1_pooled = _print_human_horizons('round 1', round1_by_player)\n", + "\n", + "round2_entries = _load_round_answers(2)\n", + "round2_convo_ids = _unique_convo_ids(round2_entries)\n", + "round2_by_player = _compute_human_horizons(round2_entries)\n", + "r2_mean, r2_pooled = _print_human_horizons('round 2', round2_by_player)\n", + "\n", + "_conn.close()\n", + "\n", + "print()\n", + "print('summary (comparable to h in performance_utils_wiki):')\n", + "print(f' gemma h (mean over seeds): {gemma_h_mean_of_seeds:.4f}')\n", + "print(f' round1 h (mean over players): {r1_mean:.4f}')\n", + "print(f' round2 h (mean over players): {r2_mean:.4f}')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "44ed25ac", + "metadata": {}, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "import numpy as np\n", + "def _compute_metrics(tp, fp, tn, fn):\n", + " total = tp + fp + tn + fn\n", + " accuracy = (tp + tn) / total if total > 0 else 0.0\n", + " precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0\n", + " recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0\n", + " f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0\n", + " fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0\n", + " specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0\n", + " fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0\n", + " return {\n", + " 'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn,\n", + " 'accuracy': accuracy, 'precision': precision, 'recall': recall,\n", + " 'f1': f1, 'fpr': fpr, 'specificity': specificity, 'fnr': fnr,\n", + " }\n", + "def _gt(convo):\n", + " return bool(convo.meta.get('has_removed_comment', False))\n", + "# gemma benchmark, per seed, then mean and std\n", + "gemma_per_seed = []\n", + "for seed_num in range(1, 6):\n", + " tp = fp = tn = fn = 0\n", + " for convo in corpus.iter_conversations():\n", + " utts = convo.get_chronological_utterance_list()\n", + " pred = any(\n", + " utt.meta.get('model_forecasts', {}).get(f'seed_{seed_num}') == 1\n", + " for utt in utts\n", + " )\n", + " truth = _gt(convo)\n", + " if pred and truth: tp += 1\n", + " elif pred and not truth: fp += 1\n", + " elif not pred and not truth: tn += 1\n", + " elif not pred and truth: fn += 1\n", + " gemma_per_seed.append(_compute_metrics(tp, fp, tn, fn))\n", + "gemma_mean = {k: float(np.mean([m[k] for m in gemma_per_seed])) for k in gemma_per_seed[0]}\n", + "gemma_std = {k: float(np.std([m[k] for m in gemma_per_seed], ddof=1)) for k in gemma_per_seed[0]}\n", + "def _build_per_player_preds(entries):\n", + " # returns dict[player_id] -> dict[convo_id] -> 0/1\n", + " preds = defaultdict(dict)\n", + " for player, sessions in entries:\n", + " for s in sessions:\n", + " cid = s.get('id')\n", + " if cid is None:\n", + " continue\n", + " try:\n", + " corpus.get_conversation(cid)\n", + " except KeyError:\n", + " continue\n", + " actions = s.get('actions', [])\n", + " preds[player][cid] = 1 if any(a.get('guess') is True for a in actions) else 0\n", + " return preds\n", + "def _aggregate_any(preds):\n", + " # per-convo prediction = 1 if any player who saw it guessed\n", + " by_convo = defaultdict(list)\n", + " for player, convo_preds in preds.items():\n", + " for cid, p in convo_preds.items():\n", + " by_convo[cid].append(p)\n", + " tp = fp = tn = fn = 0\n", + " for cid, plist in by_convo.items():\n", + " try:\n", + " convo = corpus.get_conversation(cid)\n", + " except KeyError:\n", + " continue\n", + " pred = 1 if any(plist) else 0\n", + " truth = 1 if _gt(convo) else 0\n", + " if pred and truth: tp += 1\n", + " elif pred and not truth: fp += 1\n", + " elif not pred and not truth: tn += 1\n", + " elif not pred and truth: fn += 1\n", + " return _compute_metrics(tp, fp, tn, fn), len(by_convo)\n", + "def _per_player_metrics(preds):\n", + " rows = {}\n", + " for player, convo_preds in preds.items():\n", + " tp = fp = tn = fn = 0\n", + " for cid, p in convo_preds.items():\n", + " try:\n", + " convo = corpus.get_conversation(cid)\n", + " except KeyError:\n", + " continue\n", + " truth = 1 if _gt(convo) else 0\n", + " if p and truth: tp += 1\n", + " elif p and not truth: fp += 1\n", + " elif not p and not truth: tn += 1\n", + " elif not p and truth: fn += 1\n", + " rows[player] = _compute_metrics(tp, fp, tn, fn)\n", + " return rows\n", + "# round 1 and round 2 humans\n", + "round1_preds = _build_per_player_preds(round1_entries)\n", + "round1_agg, round1_n_convos = _aggregate_any(round1_preds)\n", + "round1_per_player = _per_player_metrics(round1_preds)\n", + "round2_preds = _build_per_player_preds(round2_entries)\n", + "round2_agg, round2_n_convos = _aggregate_any(round2_preds)\n", + "round2_per_player = _per_player_metrics(round2_preds)\n", + "# metric mean over players (a different aggregation view)\n", + "def _mean_over_players(per_player):\n", + " keys = ['accuracy', 'precision', 'recall', 'f1', 'fpr', 'specificity', 'fnr']\n", + " return {k: float(np.mean([m[k] for m in per_player.values()])) for k in keys}\n", + "round1_mean_over_players = _mean_over_players(round1_per_player)\n", + "round2_mean_over_players = _mean_over_players(round2_per_player)\n", + "# printing\n", + "metric_order = ['accuracy', 'precision', 'recall', 'f1', 'fpr', 'specificity', 'fnr']\n", + "def _print_per_player(label, per_player):\n", + " print(f'{label} - per-player metrics:')\n", + " header = f\" {'player':<10}\" + \"\".join(f\"{m:>12}\" for m in metric_order) + f\"{'n_convos':>10}\"\n", + " print(header)\n", + " print(' ' + '-' * (len(header) - 2))\n", + " for player in sorted(per_player):\n", + " m = per_player[player]\n", + " n = m['tp'] + m['fp'] + m['tn'] + m['fn']\n", + " row = f\" {player:<10}\" + \"\".join(f\"{m[k]:>12.4f}\" for k in metric_order) + f\"{n:>10d}\"\n", + " print(row)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "68c5005a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "round 1 - per-player metrics:\n", + " player accuracy precision recall f1 fpr specificity fnr n_convos\n", + " --------------------------------------------------------------------------------------------------------\n", + " LJL2 0.8000 0.8000 0.8000 0.8000 0.2000 0.8000 0.2000 10\n", + " cd326 0.6000 0.6667 0.4000 0.5000 0.2000 0.8000 0.6000 10\n", + " ex36 0.4000 0.3333 0.2000 0.2500 0.4000 0.6000 0.8000 10\n", + " kz88 0.7000 1.0000 0.4000 0.5714 0.0000 1.0000 0.6000 10\n", + " lyk25 0.7000 0.6667 0.8000 0.7273 0.4000 0.6000 0.2000 10\n", + " nac86 0.7000 1.0000 0.4000 0.5714 0.0000 1.0000 0.6000 10\n", + " sj597 0.8000 0.8000 0.8000 0.8000 0.2000 0.8000 0.2000 10\n", + " tg352 0.5000 0.5000 0.4000 0.4444 0.4000 0.6000 0.6000 10\n", + " vn72 0.4000 0.3333 0.2000 0.2500 0.4000 0.6000 0.8000 10\n", + "\n", + "round 2 - per-player metrics:\n", + " player accuracy precision recall f1 fpr specificity fnr n_convos\n", + " --------------------------------------------------------------------------------------------------------\n", + " LJL2 0.6000 0.6667 0.4000 0.5000 0.2000 0.8000 0.6000 10\n", + " cd326 0.9000 1.0000 0.8000 0.8889 0.0000 1.0000 0.2000 10\n", + " ex36 0.6000 0.6667 0.4000 0.5000 0.2000 0.8000 0.6000 10\n", + " kz88 0.4000 0.3333 0.2000 0.2500 0.4000 0.6000 0.8000 10\n", + " lyk25 0.9000 1.0000 0.8000 0.8889 0.0000 1.0000 0.2000 10\n", + " nac86 0.6000 0.6667 0.4000 0.5000 0.2000 0.8000 0.6000 10\n", + " sj597 0.9000 1.0000 0.8000 0.8889 0.0000 1.0000 0.2000 10\n", + " tg352 0.7000 0.7500 0.6000 0.6667 0.2000 0.8000 0.4000 10\n", + " vn72 0.7000 0.7500 0.6000 0.6667 0.2000 0.8000 0.4000 10\n", + "\n", + "aggregate benchmark (gemma: mean±std over 5 seeds; humans: any-player rule + mean over players)\n", + "group accuracy precision recall f1 fpr specificity fnr\n", + "--------------------------------------------------------------------------------------------------------------------------------------\n", + "gemma (mean±std over seeds) 0.700±0.018 0.679±0.026 0.762±0.029 0.718±0.012 0.362±0.052 0.638±0.052 0.238±0.029\n", + "round1 humans (mean over players) 0.6222 0.6778 0.4889 0.5461 0.2444 0.7556 0.5111\n", + "round2 humans (mean over players) 0.7000 0.7593 0.5556 0.6389 0.1556 0.8444 0.4444\n", + "\n", + "round 1 unique convos: 84\n", + "round 2 unique convos: 84\n" + ] + } + ], + "source": [ + "pm = '\\u00b1'\n", + "metric_order = ['accuracy', 'precision', 'recall', 'f1', 'fpr', 'specificity', 'fnr']\n", + "\n", + "def _print_per_player(label, per_player):\n", + " print(f'{label} - per-player metrics:')\n", + " header = f\" {'player':<10}\" + \"\".join(f\"{m:>12}\" for m in metric_order) + f\"{'n_convos':>10}\"\n", + " print(header)\n", + " print(' ' + '-' * (len(header) - 2))\n", + " for player in sorted(per_player):\n", + " m = per_player[player]\n", + " n = m['tp'] + m['fp'] + m['tn'] + m['fn']\n", + " row = f\" {player:<10}\" + \"\".join(f\"{m[k]:>12.4f}\" for k in metric_order) + f\"{n:>10d}\"\n", + " print(row)\n", + "\n", + "_print_per_player('round 1', round1_per_player)\n", + "print()\n", + "_print_per_player('round 2', round2_per_player)\n", + "\n", + "print()\n", + "print(f'aggregate benchmark (gemma: mean{pm}std over 5 seeds; humans: any-player rule + mean over players)')\n", + "\n", + "header = f\"{'group':<36}\" + \"\".join(f\"{m:>14}\" for m in metric_order)\n", + "print(header)\n", + "print('-' * len(header))\n", + "\n", + "def _fmt_mean_std(mean_dict, std_dict):\n", + " return \"\".join(f\" {mean_dict[k]:.3f}{pm}{std_dict[k]:.3f}\" for k in metric_order)\n", + "\n", + "def _fmt_mean(mean_dict):\n", + " return \"\".join(f\"{mean_dict[k]:>14.4f}\" for k in metric_order)\n", + "\n", + "gemma_label = f'gemma (mean{pm}std over seeds)'\n", + "print(f\"{gemma_label:<36}\" + _fmt_mean_std(gemma_mean, gemma_std))\n", + "print(f\"{'round1 humans (mean over players)':<36}\" + _fmt_mean(round1_mean_over_players))\n", + "print(f\"{'round2 humans (mean over players)':<36}\" + _fmt_mean(round2_mean_over_players))\n", + "\n", + "print()\n", + "print(f'round 1 unique convos: {round1_n_convos}')\n", + "print(f'round 2 unique convos: {round2_n_convos}')" + ] + }, + { + "cell_type": "markdown", + "id": "9c68d1bd", + "metadata": {}, + "source": [ + "## Validation of forecast probability decrease" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "5e527737", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pooled trigger_all current perf, best threshold per seed: 0.5532229185317815 (12359/22340)\n", + "pooled delayed_all best threshold per seed, k=7: 0.8353165159447882 (3510/4202)\n" + ] + } + ], + "source": [ + "import json\n", + "from pathlib import Path\n", + "from convokit import Corpus\n", + "\n", + "corpus_base = Path(\"/reef/lyk25/dynamic_training/game_analysis/corpi/test\")\n", + "config_base = Path(\"/reef/sqt2/FinalAAO/cga-cmv-large/google/gemma-2-9b-it\")\n", + "\n", + "corpus_dirs = sorted(\n", + " corpus_dir\n", + " for corpus_dir in corpus_base.rglob(\"*\")\n", + " if corpus_dir.is_dir() and (corpus_dir / \"index.json\").exists()\n", + ")\n", + "\n", + "if not corpus_dirs:\n", + " raise FileNotFoundError(f\"no convokit corpora found under {corpus_base}\")\n", + "\n", + "\n", + "def _test_seed_key(corpus_dir):\n", + " parent_name = corpus_dir.parent.name\n", + " if parent_name.startswith(\"test-seed-\"):\n", + " return parent_name\n", + " seed_idx = corpus_dir.name.rsplit(\"-\", 1)[-1]\n", + " return f\"test-seed-{seed_idx}\"\n", + "\n", + "\n", + "corpus_path_by_key = {}\n", + "for corpus_dir in corpus_dirs:\n", + " key = _test_seed_key(corpus_dir)\n", + " corpus_path_by_key[key] = corpus_dir\n", + "\n", + "\n", + "def is_calm_sim_forecast(forecast):\n", + " if isinstance(forecast, str):\n", + " forecast = forecast.strip().lower()\n", + " if forecast in {\"0\", \"false\", \"calm\"}:\n", + " return True\n", + " if forecast in {\"1\", \"true\", \"awry\"}:\n", + " return False\n", + " return int(forecast) == 0\n", + "\n", + "\n", + "def count_decreases_at_indices(corpus, get_indices):\n", + " total = 0\n", + " reduction = 0\n", + "\n", + " for convo in corpus.iter_conversations():\n", + " utts = convo.get_chronological_utterance_list()\n", + "\n", + " for i in get_indices(utts):\n", + " if i + 1 >= len(utts):\n", + " continue\n", + "\n", + " f1 = utts[i].meta[\"forecast_prob\"]\n", + " f2 = utts[i + 1].meta[\"forecast_prob\"]\n", + " total += 1\n", + "\n", + " if f1 > f2:\n", + " reduction += 1\n", + "\n", + " return reduction, total\n", + "\n", + "\n", + "def current_trigger_indices(utts, pred_threshold):\n", + " return [\n", + " i\n", + " for i, utt in enumerate(utts)\n", + " if utt.meta[\"forecast_prob\"] > pred_threshold\n", + " ]\n", + "\n", + "\n", + "def delayed_indices_from_sim_forecasts(utts, pred_threshold, k):\n", + " delayed = []\n", + "\n", + " for i, utt in enumerate(utts):\n", + " if utt.meta[\"forecast_prob\"] <= pred_threshold:\n", + " continue\n", + "\n", + " sim_forecasts = utt.meta[\"sim_replies_forecasts\"]\n", + " calm_sim_replies = sum(\n", + " 1 for forecast in sim_forecasts if is_calm_sim_forecast(forecast)\n", + " )\n", + "\n", + " if calm_sim_replies > k:\n", + " delayed.append(i)\n", + "\n", + " return delayed\n", + "\n", + "\n", + "current_reduction = 0\n", + "current_total = 0\n", + "delayed_reduction = 0\n", + "delayed_total = 0\n", + "\n", + "for seed in range(1, 6):\n", + " seed_key = f\"test-seed-{seed}\"\n", + " corpus = Corpus(filename=str(corpus_path_by_key[seed_key]))\n", + "\n", + " with open(config_base / f\"seed-{seed}\" / \"dev_config.json\") as f:\n", + " best_threshold = json.load(f)[\"best_threshold\"]\n", + "\n", + " r, t = count_decreases_at_indices(\n", + " corpus,\n", + " lambda utts: current_trigger_indices(utts, best_threshold),\n", + " )\n", + " current_reduction += r\n", + " current_total += t\n", + "\n", + " r, t = count_decreases_at_indices(\n", + " corpus,\n", + " lambda utts: delayed_indices_from_sim_forecasts(utts, best_threshold, k=7),\n", + " )\n", + " delayed_reduction += r\n", + " delayed_total += t\n", + "\n", + "pooled_trigger_all = current_reduction / current_total\n", + "pooled_delayed_all = delayed_reduction / delayed_total\n", + "\n", + "print(f\"pooled trigger_all current perf, best threshold per seed: {pooled_trigger_all} ({current_reduction}/{current_total})\")\n", + "print(f\"pooled delayed_all best threshold per seed, k=7: {pooled_delayed_all} ({delayed_reduction}/{delayed_total})\")" + ] + }, + { + "cell_type": "markdown", + "id": "cb154714", + "metadata": {}, + "source": [ + "## Calculating oracle threshold" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "cba93b72", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] using 5 test-seed corpora: test-seed-1, test-seed-2, test-seed-3, test-seed-4, test-seed-5\n" + ] + } + ], + "source": [ + "# corpora comes from the decisionpolicy-demo download cell; use only test-seed- entries.\n", + "import re\n", + "\n", + "if \"corpora\" not in globals() or not corpora:\n", + " raise RuntimeError(\n", + " \"corpora is empty: run the download cell above first.\"\n", + " )\n", + "\n", + "_prev_keys = tuple(sorted(corpora))\n", + "_test_seed_re = re.compile(r\"^test-seed-(\\d+)$\")\n", + "\n", + "corpora = {\n", + " k: corpora[k]\n", + " for k in sorted(\n", + " (k for k in _prev_keys if _test_seed_re.match(k)),\n", + " key=lambda k: int(_test_seed_re.match(k).group(1)),\n", + " )\n", + "}\n", + "\n", + "if not corpora:\n", + " raise RuntimeError(\n", + " \"no test-seed- corpora after filter; had keys: \"\n", + " + \", \".join(_prev_keys)\n", + " )\n", + "\n", + "print(\n", + " f\"[info] using {len(corpora)} test-seed corpora: \"\n", + " + \", \".join(corpora)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "3be18b5d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] per-seed best thresholds: {'test-seed-1': 0.5926666259765625, 'test-seed-2': 0.622459352016449, 'test-seed-3': 0.6513549089431763, 'test-seed-4': 0.6513549089431763, 'test-seed-5': 0.6513549089431763}\n", + "[info] mean best threshold: 0.633838\n", + "[info] generating baseline roc curve with 400 thresholds in [0.483838, 0.783838]\n", + "[info] testing k values: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import json\n", + "\n", + "seed_best_thresholds = {\n", + " \"test-seed-1\": 0.5926666259765625,\n", + " \"test-seed-2\": 0.622459352016449,\n", + " \"test-seed-3\": 0.6513549089431763,\n", + " \"test-seed-4\": 0.6513549089431763,\n", + " \"test-seed-5\": 0.6513549089431763,\n", + "}\n", + "\n", + "mean_best_threshold = float(np.mean(list(seed_best_thresholds.values())))\n", + "search_radius = 0.15\n", + "num_thresholds = 400\n", + "baseline_thresholds = np.linspace(\n", + " mean_best_threshold - search_radius,\n", + " mean_best_threshold + search_radius,\n", + " num_thresholds,\n", + ")\n", + "\n", + "print(f\"[info] per-seed best thresholds: {seed_best_thresholds}\")\n", + "print(f\"[info] mean best threshold: {mean_best_threshold:.6f}\")\n", + "print(f\"[info] generating baseline roc curve with {len(baseline_thresholds)} thresholds in [{baseline_thresholds[0]:.6f}, {baseline_thresholds[-1]:.6f}]\")\n", + "\n", + "# k values to test for our method\n", + "k_values = list(range(1, 11)) # 1 to 10\n", + "print(f\"[info] testing k values: {k_values}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f6128fef", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] computing baseline roc curves for all seeds...\n", + "\n", + "[info] processing test-son-seed-1...\n", + "[info] test-son-seed-1 baseline progress: 1/400\n", + "[info] test-son-seed-1 baseline progress: 21/400\n", + "[info] test-son-seed-1 baseline progress: 41/400\n", + "[info] test-son-seed-1 baseline progress: 61/400\n", + "[info] test-son-seed-1 baseline progress: 81/400\n", + "[info] test-son-seed-1 baseline progress: 101/400\n", + "[info] test-son-seed-1 baseline progress: 121/400\n", + "[info] test-son-seed-1 baseline progress: 141/400\n", + "[info] test-son-seed-1 baseline progress: 161/400\n", + "[info] test-son-seed-1 baseline progress: 181/400\n", + "[info] test-son-seed-1 baseline progress: 201/400\n", + "[info] test-son-seed-1 baseline progress: 221/400\n", + "[info] test-son-seed-1 baseline progress: 241/400\n", + "[info] test-son-seed-1 baseline progress: 261/400\n", + "[info] test-son-seed-1 baseline progress: 281/400\n", + "[info] test-son-seed-1 baseline progress: 301/400\n", + "[info] test-son-seed-1 baseline progress: 321/400\n", + "[info] test-son-seed-1 baseline progress: 341/400\n", + "[info] test-son-seed-1 baseline progress: 361/400\n", + "[info] test-son-seed-1 baseline progress: 381/400\n", + "[pass] test-son-seed-1 baseline complete - 400 points\n", + "\n", + "[info] processing test-son-seed-2...\n", + "[info] test-son-seed-2 baseline progress: 1/400\n", + "[info] test-son-seed-2 baseline progress: 21/400\n", + "[info] test-son-seed-2 baseline progress: 41/400\n", + "[info] test-son-seed-2 baseline progress: 61/400\n", + "[info] test-son-seed-2 baseline progress: 81/400\n", + "[info] test-son-seed-2 baseline progress: 101/400\n", + "[info] test-son-seed-2 baseline progress: 121/400\n", + "[info] test-son-seed-2 baseline progress: 141/400\n", + "[info] test-son-seed-2 baseline progress: 161/400\n", + "[info] test-son-seed-2 baseline progress: 181/400\n", + "[info] test-son-seed-2 baseline progress: 201/400\n", + "[info] test-son-seed-2 baseline progress: 221/400\n", + "[info] test-son-seed-2 baseline progress: 241/400\n", + "[info] test-son-seed-2 baseline progress: 261/400\n", + "[info] test-son-seed-2 baseline progress: 281/400\n", + "[info] test-son-seed-2 baseline progress: 301/400\n", + "[info] test-son-seed-2 baseline progress: 321/400\n", + "[info] test-son-seed-2 baseline progress: 341/400\n", + "[info] test-son-seed-2 baseline progress: 361/400\n", + "[info] test-son-seed-2 baseline progress: 381/400\n", + "[pass] test-son-seed-2 baseline complete - 400 points\n", + "\n", + "[info] processing test-son-seed-3...\n", + "[info] test-son-seed-3 baseline progress: 1/400\n", + "[info] test-son-seed-3 baseline progress: 21/400\n", + "[info] test-son-seed-3 baseline progress: 41/400\n", + "[info] test-son-seed-3 baseline progress: 61/400\n", + "[info] test-son-seed-3 baseline progress: 81/400\n", + "[info] test-son-seed-3 baseline progress: 101/400\n", + "[info] test-son-seed-3 baseline progress: 121/400\n", + "[info] test-son-seed-3 baseline progress: 141/400\n", + "[info] test-son-seed-3 baseline progress: 161/400\n", + "[info] test-son-seed-3 baseline progress: 181/400\n", + "[info] test-son-seed-3 baseline progress: 201/400\n", + "[info] test-son-seed-3 baseline progress: 221/400\n", + "[info] test-son-seed-3 baseline progress: 241/400\n", + "[info] test-son-seed-3 baseline progress: 261/400\n", + "[info] test-son-seed-3 baseline progress: 281/400\n", + "[info] test-son-seed-3 baseline progress: 301/400\n", + "[info] test-son-seed-3 baseline progress: 321/400\n", + "[info] test-son-seed-3 baseline progress: 341/400\n", + "[info] test-son-seed-3 baseline progress: 361/400\n", + "[info] test-son-seed-3 baseline progress: 381/400\n", + "[pass] test-son-seed-3 baseline complete - 400 points\n", + "\n", + "[info] processing test-son-seed-4...\n", + "[info] test-son-seed-4 baseline progress: 1/400\n", + "[info] test-son-seed-4 baseline progress: 21/400\n", + "[info] test-son-seed-4 baseline progress: 41/400\n", + "[info] test-son-seed-4 baseline progress: 61/400\n", + "[info] test-son-seed-4 baseline progress: 81/400\n", + "[info] test-son-seed-4 baseline progress: 101/400\n", + "[info] test-son-seed-4 baseline progress: 121/400\n", + "[info] test-son-seed-4 baseline progress: 141/400\n", + "[info] test-son-seed-4 baseline progress: 161/400\n", + "[info] test-son-seed-4 baseline progress: 181/400\n", + "[info] test-son-seed-4 baseline progress: 201/400\n", + "[info] test-son-seed-4 baseline progress: 221/400\n", + "[info] test-son-seed-4 baseline progress: 241/400\n", + "[info] test-son-seed-4 baseline progress: 261/400\n", + "[info] test-son-seed-4 baseline progress: 281/400\n", + "[info] test-son-seed-4 baseline progress: 301/400\n", + "[info] test-son-seed-4 baseline progress: 321/400\n", + "[info] test-son-seed-4 baseline progress: 341/400\n", + "[info] test-son-seed-4 baseline progress: 361/400\n", + "[info] test-son-seed-4 baseline progress: 381/400\n", + "[pass] test-son-seed-4 baseline complete - 400 points\n", + "\n", + "[info] processing test-son-seed-5...\n", + "[info] test-son-seed-5 baseline progress: 1/400\n", + "[info] test-son-seed-5 baseline progress: 21/400\n", + "[info] test-son-seed-5 baseline progress: 41/400\n", + "[info] test-son-seed-5 baseline progress: 61/400\n", + "[info] test-son-seed-5 baseline progress: 81/400\n", + "[info] test-son-seed-5 baseline progress: 101/400\n", + "[info] test-son-seed-5 baseline progress: 121/400\n", + "[info] test-son-seed-5 baseline progress: 141/400\n", + "[info] test-son-seed-5 baseline progress: 161/400\n", + "[info] test-son-seed-5 baseline progress: 181/400\n", + "[info] test-son-seed-5 baseline progress: 201/400\n", + "[info] test-son-seed-5 baseline progress: 221/400\n", + "[info] test-son-seed-5 baseline progress: 241/400\n", + "[info] test-son-seed-5 baseline progress: 261/400\n", + "[info] test-son-seed-5 baseline progress: 281/400\n", + "[info] test-son-seed-5 baseline progress: 301/400\n", + "[info] test-son-seed-5 baseline progress: 321/400\n", + "[info] test-son-seed-5 baseline progress: 341/400\n", + "[info] test-son-seed-5 baseline progress: 361/400\n", + "[info] test-son-seed-5 baseline progress: 381/400\n", + "[pass] test-son-seed-5 baseline complete - 400 points\n", + "\n", + "[pass] all baseline roc curves computed for 5 seeds\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from convokit.decisionpolicy import ThresholdDecisionPolicy\n", + "from convokit.forecaster.forecaster import ContextTuple\n", + "\n", + "\n", + "def _cached_forecast_score(context):\n", + " meta = getattr(context.current_utterance, \"meta\", {}) or {}\n", + " if \"forecast_prob\" not in meta:\n", + " raise KeyError(f\"missing forecast_prob for utterance {context.current_utterance.id}\")\n", + " return meta[\"forecast_prob\"]\n", + "\n", + "\n", + "def _threshold_policy_performance(corpus, policy):\n", + " tp = fp = tn = fn = 0\n", + " horizons = []\n", + "\n", + " for convo in corpus.iter_conversations():\n", + " utts = convo.get_chronological_utterance_list()\n", + " pred = 0\n", + " first_trigger_idx = None\n", + "\n", + " for utt_idx, utt in enumerate(utts):\n", + " context = ContextTuple(\n", + " context=utts[: utt_idx + 1],\n", + " current_utterance=utt,\n", + " future_context=utts[utt_idx + 1 :],\n", + " conversation_id=convo.id,\n", + " )\n", + " _, utt_pred = policy.decide(context, _cached_forecast_score)\n", + " if int(utt_pred) == 1:\n", + " pred = 1\n", + " first_trigger_idx = utt_idx\n", + " break\n", + "\n", + " truth = bool(convo.meta.get(\"has_removed_comment\", False))\n", + " if pred and truth:\n", + " tp += 1\n", + " horizons.append(len(utts) - first_trigger_idx)\n", + " elif pred and not truth:\n", + " fp += 1\n", + " elif not pred and not truth:\n", + " tn += 1\n", + " else:\n", + " fn += 1\n", + "\n", + " h = float(np.mean(horizons)) - 1 if horizons else float(\"nan\")\n", + " return {\n", + " \"confusion_matrix\": {\"TP\": tp, \"FP\": fp, \"TN\": tn, \"FN\": fn},\n", + " \"h\": h,\n", + " }\n", + "\n", + "\n", + "# step 1: compute baseline roc curve with ThresholdDecisionPolicy\n", + "print(\"[info] computing baseline roc curves for all seeds...\")\n", + "all_baseline_results = {}\n", + "\n", + "for seed_name, corpus in corpora.items():\n", + " print(f\"\\n[info] processing {seed_name}...\")\n", + " baseline_results = []\n", + "\n", + " for i, threshold in enumerate(baseline_thresholds):\n", + " if i % 20 == 0:\n", + " print(f\"[info] {seed_name} baseline progress: {i+1}/{len(baseline_thresholds)}\")\n", + "\n", + " policy = ThresholdDecisionPolicy(threshold=threshold)\n", + " results = _threshold_policy_performance(corpus, policy)\n", + " tp = results[\"confusion_matrix\"][\"TP\"]\n", + " fp = results[\"confusion_matrix\"][\"FP\"]\n", + " tn = results[\"confusion_matrix\"][\"TN\"]\n", + " fn = results[\"confusion_matrix\"][\"FN\"]\n", + "\n", + " # calculate tpr, fpr, accuracy, precision, recall, f1\n", + " # recall == tpr; kept as a separate field for downstream readability\n", + " tpr = tp / (tp + fn) if (tp + fn) > 0 else 0\n", + " fpr = fp / (fp + tn) if (fp + tn) > 0 else 0\n", + " total = tp + fp + tn + fn\n", + " accuracy = (tp + tn) / total if total > 0 else 0\n", + " precision = tp / (tp + fp) if (tp + fp) > 0 else 0\n", + " recall = tpr\n", + " f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n", + "\n", + " baseline_results.append({\n", + " \"seed\": seed_name,\n", + " \"threshold\": threshold,\n", + " \"tpr\": tpr,\n", + " \"fpr\": fpr,\n", + " \"accuracy\": accuracy,\n", + " \"precision\": precision,\n", + " \"recall\": recall,\n", + " \"f1\": f1,\n", + " \"h\": results.get(\"h\", float(\"nan\")),\n", + " \"tp\": tp,\n", + " \"fp\": fp,\n", + " \"tn\": tn,\n", + " \"fn\": fn,\n", + " })\n", + "\n", + " all_baseline_results[seed_name] = pd.DataFrame(baseline_results)\n", + " print(f\"[pass] {seed_name} baseline complete - {len(all_baseline_results[seed_name])} points\")\n", + "\n", + "print(f\"\\n[pass] all baseline roc curves computed for {len(all_baseline_results)} seeds\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "022d240f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] all_baseline_results has been pickled to all_baseline_results.pkl\n" + ] + } + ], + "source": [ + "import pickle\n", + "\n", + "with open('all_baseline_results.pkl', 'wb') as f:\n", + " pickle.dump(all_baseline_results, f)\n", + "print(\"[info] all_baseline_results has been pickled to all_baseline_results.pkl\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "aa6450a5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "[info] processing test-son-seed-1...\n", + "[info] config seed: seed-1\n", + "[info] k method threshold 0.5926666259765625\n", + "[info] test-son-seed-1 computing k=1...\n", + "[info] test-son-seed-1 computing k=2...\n", + "[info] test-son-seed-1 computing k=3...\n", + "[info] test-son-seed-1 computing k=4...\n", + "[info] test-son-seed-1 computing k=5...\n", + "[info] test-son-seed-1 computing k=6...\n", + "[info] test-son-seed-1 computing k=7...\n", + "[info] test-son-seed-1 computing k=8...\n", + "[info] test-son-seed-1 computing k=9...\n", + "[info] test-son-seed-1 computing k=10...\n", + "[pass] test-son-seed-1 k-method complete - 10 points\n", + "\n", + "[info] processing test-son-seed-2...\n", + "[info] config seed: seed-2\n", + "[info] k method threshold 0.622459352016449\n", + "[info] test-son-seed-2 computing k=1...\n", + "[info] test-son-seed-2 computing k=2...\n", + "[info] test-son-seed-2 computing k=3...\n", + "[info] test-son-seed-2 computing k=4...\n", + "[info] test-son-seed-2 computing k=5...\n", + "[info] test-son-seed-2 computing k=6...\n", + "[info] test-son-seed-2 computing k=7...\n", + "[info] test-son-seed-2 computing k=8...\n", + "[info] test-son-seed-2 computing k=9...\n", + "[info] test-son-seed-2 computing k=10...\n", + "[pass] test-son-seed-2 k-method complete - 10 points\n", + "\n", + "[info] processing test-son-seed-3...\n", + "[info] config seed: seed-3\n", + "[info] k method threshold 0.6513549089431763\n", + "[info] test-son-seed-3 computing k=1...\n", + "[info] test-son-seed-3 computing k=2...\n", + "[info] test-son-seed-3 computing k=3...\n", + "[info] test-son-seed-3 computing k=4...\n", + "[info] test-son-seed-3 computing k=5...\n", + "[info] test-son-seed-3 computing k=6...\n", + "[info] test-son-seed-3 computing k=7...\n", + "[info] test-son-seed-3 computing k=8...\n", + "[info] test-son-seed-3 computing k=9...\n", + "[info] test-son-seed-3 computing k=10...\n", + "[pass] test-son-seed-3 k-method complete - 10 points\n", + "\n", + "[info] processing test-son-seed-4...\n", + "[info] config seed: seed-4\n", + "[info] k method threshold 0.6513549089431763\n", + "[info] test-son-seed-4 computing k=1...\n", + "[info] test-son-seed-4 computing k=2...\n", + "[info] test-son-seed-4 computing k=3...\n", + "[info] test-son-seed-4 computing k=4...\n", + "[info] test-son-seed-4 computing k=5...\n", + "[info] test-son-seed-4 computing k=6...\n", + "[info] test-son-seed-4 computing k=7...\n", + "[info] test-son-seed-4 computing k=8...\n", + "[info] test-son-seed-4 computing k=9...\n", + "[info] test-son-seed-4 computing k=10...\n", + "[pass] test-son-seed-4 k-method complete - 10 points\n", + "\n", + "[info] processing test-son-seed-5...\n", + "[info] config seed: seed-5\n", + "[info] k method threshold 0.6513549089431763\n", + "[info] test-son-seed-5 computing k=1...\n", + "[info] test-son-seed-5 computing k=2...\n", + "[info] test-son-seed-5 computing k=3...\n", + "[info] test-son-seed-5 computing k=4...\n", + "[info] test-son-seed-5 computing k=5...\n", + "[info] test-son-seed-5 computing k=6...\n", + "[info] test-son-seed-5 computing k=7...\n", + "[info] test-son-seed-5 computing k=8...\n", + "[info] test-son-seed-5 computing k=9...\n", + "[info] test-son-seed-5 computing k=10...\n", + "[pass] test-son-seed-5 k-method complete - 10 points\n", + "\n", + "[pass] all k-method results computed for 5 seeds\n" + ] + } + ], + "source": [ + "import json\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from convokit.decisionpolicy import DeferralDecisionPolicy\n", + "from convokit.forecaster.forecaster import ContextTuple\n", + "\n", + "\n", + "def _cached_forecast_score(context):\n", + " meta = getattr(context.current_utterance, \"meta\", {}) or {}\n", + " if \"forecast_prob\" not in meta:\n", + " raise KeyError(f\"missing forecast_prob for utterance {context.current_utterance.id}\")\n", + " return meta[\"forecast_prob\"]\n", + "\n", + "\n", + "def _deferral_policy_performance(corpus, policy):\n", + " tp = fp = tn = fn = 0\n", + " horizons = []\n", + "\n", + " for convo in corpus.iter_conversations():\n", + " utts = convo.get_chronological_utterance_list()\n", + " pred = 0\n", + " first_trigger_idx = None\n", + "\n", + " for utt_idx, utt in enumerate(utts):\n", + " context = ContextTuple(\n", + " context=utts[: utt_idx + 1],\n", + " current_utterance=utt,\n", + " future_context=utts[utt_idx + 1 :],\n", + " conversation_id=convo.id,\n", + " )\n", + " result = policy.decide(context, _cached_forecast_score)\n", + " utt_pred = int(result[1])\n", + " if utt_pred == 1:\n", + " pred = 1\n", + " first_trigger_idx = utt_idx\n", + " break\n", + "\n", + " truth = bool(convo.meta.get(\"has_removed_comment\", False))\n", + " if pred and truth:\n", + " tp += 1\n", + " horizons.append(len(utts) - first_trigger_idx)\n", + " elif pred and not truth:\n", + " fp += 1\n", + " elif not pred and not truth:\n", + " tn += 1\n", + " else:\n", + " fn += 1\n", + "\n", + " h = float(np.mean(horizons)) - 1 if horizons else float(\"nan\")\n", + " return {\n", + " \"confusion_matrix\": {\"TP\": tp, \"FP\": fp, \"TN\": tn, \"FN\": fn},\n", + " \"h\": h,\n", + " }\n", + "\n", + "\n", + "all_k_results = {}\n", + "\n", + "for seed_name, corpus in corpora.items():\n", + " print(f\"\\n[info] processing {seed_name}...\")\n", + " k_results = []\n", + "\n", + " pruned_seed_name = seed_name[len(seed_name) - 6 :]\n", + " print(f\"[info] config seed: {pruned_seed_name}\")\n", + "\n", + " with open(f\"/reef/sqt2/FinalAAO/cga-cmv-large/google/gemma-2-9b-it/{pruned_seed_name}/dev_config.json\", \"r\") as f:\n", + " dev_config = json.load(f)\n", + "\n", + " k_method_threshold = dev_config[\"best_threshold\"]\n", + " print(\"[info] k method threshold\", k_method_threshold)\n", + "\n", + " for k in k_values:\n", + " print(f\"[info] {seed_name} computing k={k}...\")\n", + " policy = DeferralDecisionPolicy(\n", + " simulator=None,\n", + " threshold=k_method_threshold,\n", + " tau=k,\n", + " num_simulations=10,\n", + " store_simulations=False,\n", + " simulated_reply_attribute_name=\"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name=\"sim_replies_forecast_probs\",\n", + " reuse_cached_simulations=True,\n", + " )\n", + " results = _deferral_policy_performance(corpus, policy)\n", + " tp = results[\"confusion_matrix\"][\"TP\"]\n", + " fp = results[\"confusion_matrix\"][\"FP\"]\n", + " tn = results[\"confusion_matrix\"][\"TN\"]\n", + " fn = results[\"confusion_matrix\"][\"FN\"]\n", + "\n", + " # calculate tpr, fpr, accuracy, precision, recall, f1\n", + " # recall == tpr; kept as a separate field for downstream readability\n", + " tpr = tp / (tp + fn) if (tp + fn) > 0 else 0\n", + " fpr = fp / (fp + tn) if (fp + tn) > 0 else 0\n", + " total = tp + fp + tn + fn\n", + " accuracy = (tp + tn) / total if total > 0 else 0\n", + " precision = tp / (tp + fp) if (tp + fp) > 0 else 0\n", + " recall = tpr\n", + " f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n", + "\n", + " k_results.append({\n", + " \"seed\": seed_name,\n", + " \"k\": k,\n", + " \"threshold\": k_method_threshold,\n", + " \"tpr\": tpr,\n", + " \"fpr\": fpr,\n", + " \"accuracy\": accuracy,\n", + " \"precision\": precision,\n", + " \"recall\": recall,\n", + " \"f1\": f1,\n", + " \"h\": results.get(\"h\", float(\"nan\")),\n", + " \"tp\": tp,\n", + " \"fp\": fp,\n", + " \"tn\": tn,\n", + " \"fn\": fn,\n", + " })\n", + "\n", + " all_k_results[seed_name] = pd.DataFrame(k_results)\n", + " print(f\"[pass] {seed_name} k-method complete - {len(all_k_results[seed_name])} points\")\n", + "\n", + "print(f\"\\n[pass] all k-method results computed for {len(all_k_results)} seeds\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "bba8d667", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] matching fprs for all seeds...\n", + "\n", + "[info] matching fprs for test-son-seed-1...\n", + "[PASS] test-son-seed-1 matched 10 fpr points\n", + "\n", + "[info] matching fprs for test-son-seed-2...\n", + "[PASS] test-son-seed-2 matched 10 fpr points\n", + "\n", + "[info] matching fprs for test-son-seed-3...\n", + "[PASS] test-son-seed-3 matched 10 fpr points\n", + "\n", + "[info] matching fprs for test-son-seed-4...\n", + "[PASS] test-son-seed-4 matched 10 fpr points\n", + "\n", + "[info] matching fprs for test-son-seed-5...\n", + "[PASS] test-son-seed-5 matched 10 fpr points\n", + "\n", + "[PASS] all matching complete for 5 seeds\n" + ] + } + ], + "source": [ + "# step 3: for each k's FPR, find the closest baseline FPR (per seed)\n", + "print(\"[info] matching fprs for all seeds...\")\n", + "all_matched_results = {}\n", + "\n", + "for seed_name in corpora.keys():\n", + " print(f\"\\n[info] matching fprs for {seed_name}...\")\n", + " matched_results = []\n", + " \n", + " k_df = all_k_results[seed_name]\n", + " baseline_df = all_baseline_results[seed_name]\n", + " \n", + " for _, k_row in k_df.iterrows():\n", + " k_fpr = k_row['fpr']\n", + " k_tpr = k_row['tpr']\n", + " k_val = k_row['k']\n", + " \n", + " # find closest baseline fpr\n", + " fpr_diffs = np.abs(baseline_df['fpr'] - k_fpr)\n", + " closest_idx = fpr_diffs.idxmin()\n", + " baseline_match = baseline_df.iloc[closest_idx]\n", + " \n", + " matched_results.append({\n", + " 'seed': seed_name,\n", + " 'k': k_val,\n", + " 'k_fpr': k_fpr,\n", + " 'k_tpr': k_tpr,\n", + " 'baseline_fpr': baseline_match['fpr'],\n", + " 'baseline_tpr': baseline_match['tpr'],\n", + " 'baseline_accuracy': baseline_match['accuracy'],\n", + " 'baseline_precision': baseline_match['precision'],\n", + " 'baseline_recall': baseline_match['recall'],\n", + " 'baseline_f1': baseline_match['f1'],\n", + " 'baseline_h': baseline_match['h'],\n", + " 'baseline_threshold': baseline_match['threshold'],\n", + " 'fpr_diff': abs(k_fpr - baseline_match['fpr']),\n", + " 'tpr_improvement': k_tpr - baseline_match['tpr']\n", + " })\n", + " \n", + " all_matched_results[seed_name] = pd.DataFrame(matched_results)\n", + " print(f\"[PASS] {seed_name} matched {len(all_matched_results[seed_name])} fpr points\")\n", + "\n", + "print(f\"\\n[PASS] all matching complete for {len(all_matched_results)} seeds\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "95588d73", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] matched fpr-oracle metrics for tau=7 (mean ± std over seeds, n=5):\n", + " tau baseline_accuracy_mean baseline_accuracy_std baseline_precision_mean baseline_precision_std baseline_recall_mean baseline_recall_std baseline_f1_mean baseline_f1_std baseline_h_mean baseline_h_std fpr_diff_mean fpr_diff_std tpr_improvement_mean tpr_improvement_std baseline_threshold_mean baseline_threshold_std baseline_fpr_mean baseline_fpr_std baseline_tpr_mean baseline_tpr_std\n", + " 7 0.7002 0.0088 0.7152 0.0100 0.6677 0.0494 0.6896 0.0218 2.6965 0.1029 0.0036 0.0028 0.0155 0.0055 0.6900 0.0245 0.2674 0.0331 0.6677 0.0494\n" + ] + } + ], + "source": [ + "# average matched-baseline (oracle) metrics for tau=7, across all seeds\n", + "target_tau = 7\n", + "matched_long = pd.concat(\n", + " [df.assign(seed=seed_name) for seed_name, df in all_matched_results.items()],\n", + " ignore_index=True,\n", + ")\n", + "matched_long = matched_long[matched_long[\"k\"] == target_tau].copy()\n", + "\n", + "if matched_long.empty:\n", + " raise ValueError(f\"no matched results found for tau={target_tau}\")\n", + "\n", + "oracle_cols = [\n", + " \"baseline_accuracy\",\n", + " \"baseline_precision\",\n", + " \"baseline_recall\",\n", + " \"baseline_f1\",\n", + " \"baseline_h\",\n", + " \"fpr_diff\",\n", + " \"tpr_improvement\",\n", + " \"baseline_threshold\",\n", + " \"baseline_fpr\",\n", + " \"baseline_tpr\",\n", + "]\n", + "\n", + "oracle_stats_per_tau = (\n", + " matched_long.groupby(\"k\")[oracle_cols]\n", + " .agg([\"mean\", \"std\"])\n", + " .reset_index()\n", + " .rename(columns={\"k\": \"tau\"})\n", + ")\n", + "\n", + "# reformat columns to single-level, e.g. baseline_fpr_mean\n", + "oracle_stats_per_tau.columns = [\"tau\"] + [\n", + " f\"{col}_{stat}\" for col in oracle_cols for stat in [\"mean\", \"std\"]\n", + "]\n", + "\n", + "with pd.option_context(\n", + " \"display.float_format\",\n", + " \"{:.4f}\".format,\n", + " \"display.max_columns\",\n", + " None,\n", + " \"display.width\",\n", + " 200,\n", + "):\n", + " print(\"[info] matched fpr-oracle metrics for tau=7 \"\n", + " f\"(mean ± std over seeds, n={matched_long['seed'].nunique()}):\")\n", + " print(oracle_stats_per_tau.to_string(index=False))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "lyk25-env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}