From ef1991ce4ace6db229ea2a4f00f06d7fc9cb0c65 Mon Sep 17 00:00:00 2001 From: Laerdon Kim <96972420+laerdon@users.noreply.github.com> Date: Wed, 1 Apr 2026 12:23:32 -0400 Subject: [PATCH 1/6] Introduce pluggable decision policies for ForecasterModel (#1) * refactor forecaster model to use decision policies Co-authored-by: Laerdon Kim * update forecaster models for score and policy orchestration Co-authored-by: Laerdon Kim --------- Co-authored-by: Cursor Agent Co-authored-by: Laerdon Kim --- convokit/__init__.py | 1 + convokit/decisionpolicy/__init__.py | 3 + convokit/decisionpolicy/decisionPolicy.py | 41 ++++ .../decisionpolicy/deferralDecisionPolicy.py | 144 ++++++++++++++ .../decisionpolicy/thresholdDecisionPolicy.py | 57 ++++++ convokit/forecaster/CRAFTModel.py | 178 ++++++++++++++---- .../forecaster/TransformerDecoderModel.py | 81 +++++--- .../forecaster/TransformerEncoderModel.py | 96 +++++++--- convokit/forecaster/__init__.py | 1 + convokit/forecaster/cumulativeBoW.py | 35 ++++ convokit/forecaster/forecaster.py | 4 +- convokit/forecaster/forecasterModel.py | 61 +++++- docs/source/decisionpolicy.rst | 9 + docs/source/forecaster.rst | 1 + 14 files changed, 626 insertions(+), 86 deletions(-) create mode 100644 convokit/decisionpolicy/__init__.py create mode 100644 convokit/decisionpolicy/decisionPolicy.py create mode 100644 convokit/decisionpolicy/deferralDecisionPolicy.py create mode 100644 convokit/decisionpolicy/thresholdDecisionPolicy.py create mode 100644 docs/source/decisionpolicy.rst diff --git a/convokit/__init__.py b/convokit/__init__.py index 9bd3462b9..d2a8f4f9a 100644 --- a/convokit/__init__.py +++ b/convokit/__init__.py @@ -21,6 +21,7 @@ "classifier": ".classifier", "ranker": ".ranker", "forecaster": ".forecaster", + "decisionpolicy": ".decisionpolicy", "fighting_words": ".fighting_words", "paired_prediction": ".paired_prediction", "bag_of_words": ".bag_of_words", diff --git a/convokit/decisionpolicy/__init__.py b/convokit/decisionpolicy/__init__.py 
new file mode 100644 index 000000000..43e3ab9bc --- /dev/null +++ b/convokit/decisionpolicy/__init__.py @@ -0,0 +1,3 @@ +from .decisionPolicy import * +from .thresholdDecisionPolicy import * +from .deferralDecisionPolicy import * diff --git a/convokit/decisionpolicy/decisionPolicy.py b/convokit/decisionpolicy/decisionPolicy.py new file mode 100644 index 000000000..f0aca8546 --- /dev/null +++ b/convokit/decisionpolicy/decisionPolicy.py @@ -0,0 +1,41 @@ +from abc import ABC, abstractmethod +from typing import Callable + + +class DecisionPolicy(ABC): + """ + Abstract interface for converting a conversational context into an action. + """ + + def __init__(self): + self._labeler = None + + @property + def labeler(self): + return self._labeler + + @labeler.setter + def labeler(self, value: Callable): + self._labeler = value + + @abstractmethod + def decide(self, context, score_fn: Callable) -> int: + """ + Decide whether to intervene for a context. + + :param context: context tuple supplied by Forecaster + :param score_fn: callable that maps a context tuple to a scalar score + :return: integer action label (currently 0/1) + """ + pass + + @abstractmethod + def fit(self, contexts, val_contexts=None, score_fn: Callable = None): + """ + Fit policy-specific parameters if needed. 
+ + :param contexts: training contexts for policy fitting + :param val_contexts: optional validation contexts + :param score_fn: optional scorer callable exposed by ForecasterModel + """ + pass diff --git a/convokit/decisionpolicy/deferralDecisionPolicy.py b/convokit/decisionpolicy/deferralDecisionPolicy.py new file mode 100644 index 000000000..ab8d2848e --- /dev/null +++ b/convokit/decisionpolicy/deferralDecisionPolicy.py @@ -0,0 +1,144 @@ +from itertools import tee +from typing import Callable, List, Optional + +import numpy as np +from sklearn.metrics import roc_curve + +from .decisionPolicy import DecisionPolicy + + +class _synthetic_speaker: + def __init__(self, speaker_id: str): + self.id = speaker_id + + +class _synthetic_utterance: + def __init__(self, text: str, utterance_id: str, speaker_id: str): + self.text = text + self.id = utterance_id + self.speaker_ = _synthetic_speaker(speaker_id) + self.meta = {} + + def get_conversation(self): + return None + + +class DeferralDecisionPolicy(DecisionPolicy): + """ + Decision policy that can defer intervention using simulated next utterances. 
+ """ + + def __init__( + self, + simulator=None, + threshold: float = 0.5, + num_simulations: int = 3, + aggregation: str = "mean", + ): + super().__init__() + self.simulator = simulator + self.threshold = float(threshold) + self.num_simulations = int(num_simulations) + self.aggregation = aggregation + + def _aggregate_scores(self, scores: List[float]) -> float: + if len(scores) == 0: + return 0.0 + if self.aggregation == "max": + return float(np.max(scores)) + if self.aggregation == "min": + return float(np.min(scores)) + return float(np.mean(scores)) + + def get_simulations(self, context, simulator=None, k: Optional[int] = None) -> List[str]: + simulator = simulator if simulator is not None else self.simulator + if k is None: + k = self.num_simulations + if simulator is None: + return [] + if callable(simulator): + sims = simulator(context, k) + return list(sims)[:k] + if hasattr(simulator, "get_simulations"): + sims = simulator.get_simulations(context, k) + return list(sims)[:k] + if hasattr(simulator, "transform"): + sims = simulator.transform(iter([context])) + if context.current_utterance.id in sims.index: + col_name = sims.columns[0] + return list(sims.loc[context.current_utterance.id][col_name])[:k] + return [] + + def _build_simulated_context(self, context, simulation_text: str, simulation_idx: int): + current_utt = context.current_utterance + synthetic_utt = _synthetic_utterance( + text=simulation_text, + utterance_id=f"{current_utt.id}__sim_{simulation_idx}", + speaker_id="simulator", + ) + new_context_utts = list(context.context) + [synthetic_utt] + context_cls = context.__class__ + return context_cls( + context=new_context_utts, + current_utterance=synthetic_utt, + future_context=None, + conversation_id=context.conversation_id, + ) + + def _decision_score(self, context, score_fn: Callable) -> float: + current_score = float(score_fn(context)) + simulations = self.get_simulations(context) + if len(simulations) == 0: + return current_score + 
simulation_scores = [] + for idx, sim_text in enumerate(simulations): + sim_context = self._build_simulated_context(context, sim_text, idx) + simulation_scores.append(float(score_fn(sim_context))) + return self._aggregate_scores([current_score] + simulation_scores) + + def decide(self, context, score_fn: Callable) -> int: + decision_score = self._decision_score(context, score_fn) + return int(decision_score > self.threshold) + + def fit(self, contexts, val_contexts=None, score_fn: Callable = None): + if self.simulator is not None and hasattr(self.simulator, "fit"): + if val_contexts is None: + sim_contexts = contexts + sim_val_contexts = None + else: + sim_contexts, contexts = tee(contexts, 2) + sim_val_contexts, val_contexts = tee(val_contexts, 2) + self.simulator.fit(sim_contexts, sim_val_contexts) + + if val_contexts is None or score_fn is None or self.labeler is None: + return {"threshold": self.threshold} + val_contexts = list(val_contexts) + if len(val_contexts) == 0: + return {"threshold": self.threshold} + + highest_convo_scores = {} + convo_labels = {} + for context in val_contexts: + convo_id = context.conversation_id + score = self._decision_score(context, score_fn) + label = int(self.labeler(context.current_utterance.get_conversation())) + if convo_id not in highest_convo_scores: + highest_convo_scores[convo_id] = score + else: + highest_convo_scores[convo_id] = max(highest_convo_scores[convo_id], score) + convo_labels[convo_id] = label + + convo_ids = list(highest_convo_scores.keys()) + y_true = np.asarray([convo_labels[c] for c in convo_ids]) + y_score = np.asarray([highest_convo_scores[c] for c in convo_ids]) + try: + _, _, thresholds = roc_curve(y_true, y_score) + except ValueError: + return {"threshold": self.threshold} + if len(thresholds) == 0: + return {"threshold": self.threshold} + + accs = [((y_score > t).astype(int) == y_true).mean() for t in thresholds] + best_idx = int(np.argmax(accs)) + self.threshold = float(thresholds[best_idx]) + 
return {"threshold": self.threshold, "best_val_accuracy": float(accs[best_idx])} diff --git a/convokit/decisionpolicy/thresholdDecisionPolicy.py b/convokit/decisionpolicy/thresholdDecisionPolicy.py new file mode 100644 index 000000000..0f3814a44 --- /dev/null +++ b/convokit/decisionpolicy/thresholdDecisionPolicy.py @@ -0,0 +1,57 @@ +from typing import Callable + +import numpy as np +from sklearn.metrics import roc_curve + +from .decisionPolicy import DecisionPolicy + + +class ThresholdDecisionPolicy(DecisionPolicy): + """ + A simple decision policy that predicts 1 when score > threshold. + """ + + def __init__(self, threshold: float = 0.5): + super().__init__() + self.threshold = float(threshold) + + def decide(self, context, score_fn: Callable) -> int: + return int(score_fn(context) > self.threshold) + + def fit(self, contexts, val_contexts=None, score_fn: Callable = None): + if val_contexts is None or score_fn is None or self.labeler is None: + return {"best_threshold": self.threshold} + + val_contexts = list(val_contexts) + if len(val_contexts) == 0: + return {"best_threshold": self.threshold} + + highest_convo_scores = {} + convo_labels = {} + for context in val_contexts: + convo_id = context.conversation_id + score = score_fn(context) + label = int(self.labeler(context.current_utterance.get_conversation())) + if convo_id not in highest_convo_scores: + highest_convo_scores[convo_id] = score + else: + highest_convo_scores[convo_id] = max(highest_convo_scores[convo_id], score) + convo_labels[convo_id] = label + + convo_ids = list(highest_convo_scores.keys()) + y_true = np.asarray([convo_labels[c] for c in convo_ids]) + y_score = np.asarray([highest_convo_scores[c] for c in convo_ids]) + + # roc_curve can fail when only one class is present; keep current threshold in that case. 
+ try: + _, _, thresholds = roc_curve(y_true, y_score) + except ValueError: + return {"best_threshold": self.threshold} + + if len(thresholds) == 0: + return {"best_threshold": self.threshold} + + accs = [((y_score > t).astype(int) == y_true).mean() for t in thresholds] + best_idx = int(np.argmax(accs)) + self.threshold = float(thresholds[best_idx]) + return {"best_threshold": self.threshold, "best_val_accuracy": float(accs[best_idx])} diff --git a/convokit/forecaster/CRAFTModel.py b/convokit/forecaster/CRAFTModel.py index b0937e7f4..e100e0a1e 100644 --- a/convokit/forecaster/CRAFTModel.py +++ b/convokit/forecaster/CRAFTModel.py @@ -10,13 +10,12 @@ from convokit import download, warn from convokit.convokitConfig import ConvoKitConfig from .CRAFT.model import EncoderRNN, ContextEncoderRNN, SingleTargetClf -from .CRAFT.runners import Predictor, trainIters, evaluateDataset +from .CRAFT.runners import Predictor, trainIters, evaluateBatch from .forecasterModel import ForecasterModel -import numpy as np -import torch.nn.functional as F from torch import optim, nn from typing import Dict, Union import os +from convokit.decisionpolicy import ThresholdDecisionPolicy # parameters baked into the model design (because the provided models were saved with these parameters); # these cannot be changed by the user @@ -89,8 +88,9 @@ def __init__( decision_threshold: Union[float, str] = "auto", torch_device: str = "cpu", config: dict = DEFAULT_CONFIG, + decision_policy=None, ): - super().__init__() + super().__init__(decision_policy=decision_policy) # load the initial weights and store this as the current model if initial_weights in MODEL_FILENAME_MAP: @@ -131,18 +131,36 @@ def __init__( raise TypeError("CRAFTModel: decision_threshold must be either a float or 'auto'") self._decision_threshold = DECISION_THRESHOLDS.get(initial_weights, 0.5) + if isinstance(self.decision_policy, ThresholdDecisionPolicy): + self.decision_policy.threshold = float(self._decision_threshold) + self._device 
= torch.device(torch_device) self._config = config + self._inference_components = None + + @property + def best_threshold(self): + if hasattr(self.decision_policy, "threshold"): + return self.decision_policy.threshold + return None - def _context_to_craft_data(self, contexts): + @best_threshold.setter + def best_threshold(self, value): + if hasattr(self.decision_policy, "threshold"): + self.decision_policy.threshold = float(value) + + def _context_to_craft_data(self, contexts, include_labels=True): """ Convert context utterances to a list of token-lists using the model's vocabulary object, maintaining the original temporal ordering """ pairs = [] for context in contexts: - convo = context.current_utterance.get_conversation() - label = self.labeler(convo) + if include_labels: + convo = context.current_utterance.get_conversation() + label = self.labeler(convo) + else: + label = 0 processed_context = processContext(self._voc, context, label) utt = processed_context[-1]["tokens"][: (MAX_LENGTH - 1)] context_utts = [u["tokens"][: (MAX_LENGTH - 1)] for u in processed_context] @@ -188,7 +206,17 @@ def _init_craft(self): return embedding, encoder, context_encoder, attack_clf - def fit(self, contexts, val_contexts=None): + def _get_inference_components(self): + if self._inference_components is None: + embedding, encoder, context_encoder, attack_clf = self._init_craft() + encoder.eval() + context_encoder.eval() + attack_clf.eval() + predictor = Predictor(encoder, context_encoder, attack_clf) + self._inference_components = (encoder, context_encoder, predictor) + return self._inference_components + + def fit_belief_estimator(self, contexts, val_contexts=None): """ Fine-tune the CRAFT model, and save the best model according to validation performance. @@ -196,12 +224,12 @@ def fit(self, contexts, val_contexts=None): :param val_contexts: an iterator over context tuples to be used only for validation. 
IMPORTANT: this is marked Optional only for compatibility with the generic Forecaster API; CRAFT actually REQUIRES a validation set so leaving this parameter at None will raise an error! """ # convert the input contexts into CRAFT's data format - train_pairs = self._context_to_craft_data(contexts) + train_pairs = self._context_to_craft_data(contexts, include_labels=True) print("Processed", len(train_pairs), "context tuples for model training") # val_contexts is made Optional to conform to the Forecaster spec, but in reality CRAFT requires a validation set if val_contexts is None: raise ValueError("CRAFTModel requires a validation set!") - val_pairs = self._context_to_craft_data(val_contexts) + val_pairs = self._context_to_craft_data(val_contexts, include_labels=True) print("Processed", len(val_pairs), "context tuples for model validation") # initialize the CRAFT model with whatever weights we currently have saved @@ -252,6 +280,50 @@ def fit(self, contexts, val_contexts=None): # save the resulting checkpoints so we can load them later during transform self._model = best_model + self._inference_components = None + + def fit_decision_policy(self, contexts, val_contexts=None): + return super().fit_decision_policy(contexts, val_contexts) + + def fit(self, contexts, val_contexts=None): + return super().fit(contexts, val_contexts) + + def score(self, context) -> float: + encoder, context_encoder, predictor = self._get_inference_components() + score_pairs = self._context_to_craft_data([context], include_labels=False) + batch, batch_dialogs, _, true_batch_size = next( + batchIterator(self._voc, score_pairs, batch_size=1, shuffle=False) + ) + ( + input_variable, + dialog_lengths, + utt_lengths, + batch_indices, + dialog_indices, + labels, + convo_ids, + target_variable, + mask, + max_target_len, + ) = batch + dialog_lengths_list = [len(x) for x in batch_dialogs] + _, scores = evaluateBatch( + encoder, + context_encoder, + predictor, + self._voc, + input_variable, + 
dialog_lengths, + dialog_lengths_list, + utt_lengths, + batch_indices, + dialog_indices, + true_batch_size, + self._device, + MAX_LENGTH, + threshold=self.best_threshold if self.best_threshold is not None else 0.5, + ) + return float(scores[0].item()) def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): """ @@ -264,34 +336,72 @@ def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_n :return: a Pandas DataFrame, with one row for each context, indexed by the ID of that context's current utterance. Contains two columns, one with raw probabilities named according to forecast_prob_attribute_name, and one with discretized (binary) forecasts named according to forecast_attribute_name """ # convert the input contexts into CRAFT's data format - test_pairs = self._context_to_craft_data(contexts) + contexts = list(contexts) + context_by_utt_id = {context.current_utterance.id: context for context in contexts} + test_pairs = self._context_to_craft_data(contexts, include_labels=False) print("Processed", len(test_pairs), "context tuples for model evaluation") # initialize the CRAFT model with whatever weights we currently have saved - embedding, encoder, context_encoder, attack_clf = self._init_craft() - - # Set dropout layers to eval mode - encoder.eval() - context_encoder.eval() - attack_clf.eval() + encoder, context_encoder, predictor = self._get_inference_components() - # Initialize the pipeline - predictor = Predictor(encoder, context_encoder, attack_clf) - - # Run the pipeline! 
- forecasts_df = evaluateDataset( - test_pairs, - encoder, - context_encoder, - predictor, - self._voc, - self._config["batch_size"], - self._device, - MAX_LENGTH, - batchIterator, - self._decision_threshold, - forecast_attribute_name, - forecast_prob_attribute_name, + output_df = {"id": [], forecast_attribute_name: [], forecast_prob_attribute_name: []} + batch_iterator = batchIterator( + self._voc, test_pairs, self._config["batch_size"], shuffle=False + ) + n_iters = len(test_pairs) // self._config["batch_size"] + int( + len(test_pairs) % self._config["batch_size"] > 0 ) + for iteration in range(1, n_iters + 1): + batch, batch_dialogs, _, true_batch_size = next(batch_iterator) + ( + input_variable, + dialog_lengths, + utt_lengths, + batch_indices, + dialog_indices, + labels, + convo_ids, + target_variable, + mask, + max_target_len, + ) = batch + dialog_lengths_list = [len(x) for x in batch_dialogs] + _, scores = evaluateBatch( + encoder, + context_encoder, + predictor, + self._voc, + input_variable, + dialog_lengths, + dialog_lengths_list, + utt_lengths, + batch_indices, + dialog_indices, + true_batch_size, + self._device, + MAX_LENGTH, + threshold=self.best_threshold if self.best_threshold is not None else 0.5, + ) + for i in range(true_batch_size): + score = float(scores[i].item()) + utt_id = convo_ids[i] + context = context_by_utt_id[utt_id] + + def score_fn(scored_context): + scored_utt_id = scored_context.current_utterance.id + if scored_utt_id == utt_id: + return score + return self.score(scored_context) + + pred = self.decision_policy.decide(context, score_fn) + output_df["id"].append(utt_id) + output_df[forecast_attribute_name].append(int(pred)) + output_df[forecast_prob_attribute_name].append(score) + print( + "Iteration: {}; Percent complete: {:.1f}%".format( + iteration, iteration / n_iters * 100 + ) + ) + forecasts_df = pd.DataFrame(output_df).set_index("id") return forecasts_df diff --git a/convokit/forecaster/TransformerDecoderModel.py 
b/convokit/forecaster/TransformerDecoderModel.py index a8d1813ea..e2df302c2 100644 --- a/convokit/forecaster/TransformerDecoderModel.py +++ b/convokit/forecaster/TransformerDecoderModel.py @@ -14,6 +14,7 @@ from sklearn.metrics import roc_curve from .forecasterModel import ForecasterModel from .TransformerForecasterConfig import TransformerForecasterConfig +from convokit.decisionpolicy import ThresholdDecisionPolicy import shutil @@ -72,7 +73,9 @@ def __init__( config=DEFAULT_CONFIG, system_msg=None, question_msg=None, + decision_policy=None, ): + super().__init__(decision_policy=decision_policy) self.max_seq_length = 4_096 * 2 self.model, tokenizer = FastLanguageModel.from_pretrained( model_name=model_name_or_path, @@ -100,7 +103,6 @@ def __init__( "Will the above conversation derail into a personal attack now or at any point in the future? " "Strictly start your answer with Yes or No, otherwise the answer is invalid." ) - self.best_threshold = 0.5 if not os.path.exists(config.output_dir): os.makedirs(config.output_dir) @@ -108,6 +110,17 @@ def __init__( return + @property + def best_threshold(self): + if hasattr(self.decision_policy, "threshold"): + return self.decision_policy.threshold + return None + + @best_threshold.setter + def best_threshold(self, value): + if hasattr(self.decision_policy, "threshold"): + self.decision_policy.threshold = float(value) + def _context_mode(self, context): """ Select the utterances to include in the input context based on the configured context mode. @@ -218,7 +231,7 @@ def _context_to_llm_data(self, contexts): print(f"There are {len(dataset)} samples") return Dataset.from_list(dataset) - def fit(self, train_contexts, val_contexts): + def fit_belief_estimator(self, train_contexts, val_contexts=None): """ Fine-tune the TransformerDecoder model using LoRA and save the best model based on validation performance. 
@@ -282,7 +295,6 @@ def fit(self, train_contexts, val_contexts): ), ) trainer.train() - _ = self._tune_threshold(self, val_contexts) return def _tune_threshold(self, val_contexts): @@ -310,7 +322,8 @@ def _tune_threshold(self, val_contexts): checkpoints = [cp for cp in os.listdir(self.config.output_dir) if "checkpoint-" in cp] if checkpoints == []: checkpoints.append("zero-shot") - best_val_accuracy = 0 + best_val_accuracy = -1 + best_checkpoint = checkpoints[0] val_convo_ids = set() utt2convo = {} val_labels_dict = {} @@ -334,7 +347,7 @@ def _tune_threshold(self, val_contexts): FastLanguageModel.for_inference(self.model) utt2score = {} for context in tqdm(val_contexts): - utt_score, _ = self._predict(context) + utt_score = self.score(context) utt_id = context.current_utterance.id utt2score[utt_id] = utt_score # for each CONVERSATION, whether or not it triggers will be effectively determined by what the highest score it ever got was @@ -392,26 +405,7 @@ def acc_with_threshold(y_true, y_score, thresh): ) return best_config - def _predict(self, context, threshold=None): - """ - Run inference on a single context using the fine-tuned TransformerDecoder model. - - This method prepares the input from the given context, generates a single-token - prediction (either "Yes" or "No"), and computes the softmax probability for "Yes". - The output is a confidence score and a binary prediction based on the given or - default threshold. - - :param context: A context tuple containing the current utterance and conversation history. - :param threshold: (Optional) A float threshold for converting the predicted probability into a binary label. - If not provided, `self.best_threshold` is used. 
- - :return: A tuple (`utt_score`, `utt_pred`), where: - - `utt_score` is the softmax probability assigned to "Yes" - - `utt_pred` is the binary prediction (1 if `utt_score > threshold`, else 0) - """ - # Enabling inference with different checkpoints to _tune_best_val_accuracy - if not threshold: - threshold = self.best_threshold + def score(self, context) -> float: FastLanguageModel.for_inference(self.model) context_utts = self._context_mode(context) inputs = self._tokenize(context_utts).to(self.config.device) @@ -432,9 +426,44 @@ def _predict(self, context, threshold=None): utt_score = F.softmax(torch.tensor([yes_logit, no_logit], dtype=torch.float32), dim=0)[ 0 ].item() - utt_pred = int(utt_score > threshold) + return utt_score + + def _predict(self, context, threshold=None): + """ + Run inference on a single context using the fine-tuned TransformerDecoder model. + + This method prepares the input from the given context, generates a single-token + prediction (either "Yes" or "No"), and computes the softmax probability for "Yes". + The output is a confidence score and a binary prediction based on the given or + default threshold. + + :param context: A context tuple containing the current utterance and conversation history. + :param threshold: (Optional) A float threshold for converting the predicted probability into a binary label. + If not provided, `self.best_threshold` is used. + + :return: A tuple (`utt_score`, `utt_pred`), where: + - `utt_score` is the softmax probability assigned to "Yes" + - `utt_pred` is the binary prediction (1 if `utt_score > threshold`, else 0) + """ + utt_score = self.score(context) + # keep threshold override for backward compatibility. 
+ if threshold is not None: + utt_pred = int(utt_score > threshold) + else: + utt_pred = self.decision_policy.decide(context, self.score) return utt_score, utt_pred + def fit_decision_policy(self, contexts, val_contexts=None): + if ( + val_contexts is not None + and isinstance(self.decision_policy, ThresholdDecisionPolicy) + ): + return self._tune_threshold(val_contexts) + return super().fit_decision_policy(contexts, val_contexts) + + def fit(self, contexts, val_contexts=None): + return super().fit(contexts, val_contexts) + def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): """ Generate forecasts using the fine-tuned TransformerDecoder model on the provided contexts, and save the predictions to the output directory specified in the configuration. diff --git a/convokit/forecaster/TransformerEncoderModel.py b/convokit/forecaster/TransformerEncoderModel.py index 7203818b0..e86a9d133 100644 --- a/convokit/forecaster/TransformerEncoderModel.py +++ b/convokit/forecaster/TransformerEncoderModel.py @@ -17,6 +17,7 @@ from sklearn.metrics import roc_curve from .forecasterModel import ForecasterModel from .TransformerForecasterConfig import TransformerForecasterConfig +from convokit.decisionpolicy import ThresholdDecisionPolicy import shutil os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -43,15 +44,14 @@ class TransformerEncoderModel(ForecasterModel): :param config: (Optional) TransformerForecasterConfig object containing parameters for training and evaluation. 
""" - def __init__(self, model_name_or_path, config=DEFAULT_CONFIG): - super().__init__() + def __init__(self, model_name_or_path, config=DEFAULT_CONFIG, decision_policy=None): + super().__init__(decision_policy=decision_policy) self.tokenizer = AutoTokenizer.from_pretrained( model_name_or_path, model_max_length=512, truncation_side="left", padding_side="right", ) - self.best_threshold = 0.5 model_config = AutoConfig.from_pretrained( model_name_or_path, num_labels=2, problem_type="single_label_classification" ) @@ -63,6 +63,17 @@ def __init__(self, model_name_or_path, config=DEFAULT_CONFIG): self.config = config return + @property + def best_threshold(self): + if hasattr(self.decision_policy, "threshold"): + return self.decision_policy.threshold + return None + + @best_threshold.setter + def best_threshold(self, value): + if hasattr(self.decision_policy, "threshold"): + self.decision_policy.threshold = float(value) + def _context_mode(self, context): """ Select the utterances to include in the input context based on the configured context mode. 
@@ -153,11 +164,11 @@ def _context_to_bert_data(self, contexts): @torch.inference_mode @torch.no_grad - def _predict( + def _score_dataset( self, dataset, model=None, - threshold=0.5, + threshold=None, forecast_prob_attribute_name="forecast_prob", forecast_attribute_name="forecast", ): @@ -179,6 +190,8 @@ def _predict( """ if not model: model = self.model.to(self.config.device) + if threshold is None: + threshold = self.best_threshold if self.best_threshold is not None else 0.5 utt_ids = [] preds = [] scores = [] @@ -198,6 +211,26 @@ def _predict( {forecast_attribute_name: preds, forecast_prob_attribute_name: scores}, index=utt_ids ) + @torch.inference_mode + @torch.no_grad + def score(self, context) -> float: + self.model.eval() + context_utts = self._context_mode(context) + tokenized_context = self._tokenize(context_utts) + input_ids = ( + torch.tensor(tokenized_context["input_ids"], dtype=torch.long) + .to(self.config.device) + .reshape([1, -1]) + ) + attention_mask = ( + torch.tensor(tokenized_context["attention_mask"], dtype=torch.long) + .to(self.config.device) + .reshape([1, -1]) + ) + outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) + probs = F.softmax(outputs.logits, dim=-1) + return probs[0, 1].item() + def _tune_threshold(self, val_dataset, val_contexts): """ Tune the decision threshold and select the best model checkpoint based on validation accuracy. @@ -221,7 +254,10 @@ def _tune_threshold(self, val_dataset, val_contexts): :return: A dictionary containing the best checkpoint path, best threshold, and best validation accuracy. 
""" checkpoints = [cp for cp in os.listdir(self.config.output_dir) if "checkpoint-" in cp] - best_val_accuracy = 0 + if len(checkpoints) == 0: + raise ValueError("no checkpoints found for threshold tuning") + best_val_accuracy = -1 + best_checkpoint = checkpoints[0] val_convo_ids = set() utt2convo = {} val_labels_dict = {} @@ -238,7 +274,7 @@ def _tune_threshold(self, val_dataset, val_contexts): finetuned_model = AutoModelForSequenceClassification.from_pretrained( full_model_path ).to(self.config.device) - val_scores = self._predict(val_dataset, model=finetuned_model) + val_scores = self._score_dataset(val_dataset, model=finetuned_model) # for each CONVERSATION, whether or not it triggers will be effectively determined by what the highest score it ever got was highest_convo_scores = {convo_id: -1 for convo_id in val_convo_ids} for utt_id in val_scores.index: @@ -266,7 +302,7 @@ def acc_with_threshold(y_true, y_score, thresh): self.best_threshold = thresholds[best_acc_idx] self.model = finetuned_model - eval_forecasts_df = self._predict(val_dataset, threshold=self.best_threshold) + eval_forecasts_df = self._score_dataset(val_dataset, threshold=self.best_threshold) eval_prediction_file = os.path.join(self.config.output_dir, "val_predictions.csv") eval_forecasts_df.to_csv(eval_prediction_file) @@ -291,7 +327,7 @@ def acc_with_threshold(y_true, y_score, thresh): ) return best_config - def fit(self, contexts, val_contexts): + def fit_belief_estimator(self, contexts, val_contexts=None): """ Fine-tune the TransformerEncoder model, and save the best model according to validation performance. @@ -301,12 +337,10 @@ def fit(self, contexts, val_contexts): held-out validation set. :param contexts: an iterator over context tuples, provided by the Forecaster framework - :param val_contexts: an iterator over context tuples to be used only for validation. + :param val_contexts: optional validation contexts (not used by this stage). 
""" - val_contexts = list(val_contexts) train_pairs = self._context_to_bert_data(contexts) - val_for_tuning_pairs = self._context_to_bert_data(val_contexts) - dataset = DatasetDict({"train": train_pairs, "val_for_tuning": val_for_tuning_pairs}) + dataset = DatasetDict({"train": train_pairs}) dataset.set_format("torch") training_args = TrainingArguments( @@ -324,9 +358,25 @@ def fit(self, contexts, val_contexts): ) trainer = Trainer(model=self.model, args=training_args, train_dataset=dataset["train"]) trainer.train() - _ = self._tune_threshold(dataset["val_for_tuning"], val_contexts) return + def fit_decision_policy(self, contexts, val_contexts=None): + if val_contexts is None: + return super().fit_decision_policy(contexts, val_contexts) + val_contexts = list(val_contexts) + val_for_tuning_pairs = self._context_to_bert_data(val_contexts) + val_for_tuning_pairs.set_format("torch") + return self._tune_threshold(val_for_tuning_pairs, val_contexts) + + def fit(self, contexts, val_contexts=None): + return super().fit(contexts, val_contexts) + + def _predict(self, context, threshold=None): + utt_score = self.score(context) + if threshold is not None: + return utt_score, int(utt_score > threshold) + return utt_score, self.decision_policy.decide(context, self.score) + def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): """ Generate forecasts using the fine-tuned TransformerEncoder model on the provided contexts, and save the predictions to the output directory specified in the configuration. @@ -337,14 +387,16 @@ def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_n :return: a Pandas DataFrame, with one row for each context, indexed by the ID of that context's current utterance. 
Contains two columns, one with raw probabilities named according to forecast_prob_attribute_name, and one with discretized (binary) forecasts named according to forecast_attribute_name """ - test_pairs = self._context_to_bert_data(contexts) - dataset = DatasetDict({"test": test_pairs}) - dataset.set_format("torch") - forecasts_df = self._predict( - dataset["test"], - threshold=self.best_threshold, - forecast_attribute_name=forecast_attribute_name, - forecast_prob_attribute_name=forecast_prob_attribute_name, + utt_ids = [] + preds = [] + scores = [] + for context in tqdm(contexts): + utt_score, utt_pred = self._predict(context) + utt_ids.append(context.current_utterance.id) + preds.append(utt_pred) + scores.append(utt_score) + forecasts_df = pd.DataFrame( + {forecast_attribute_name: preds, forecast_prob_attribute_name: scores}, index=utt_ids ) prediction_file = os.path.join(self.config.output_dir, "test_predictions.csv") diff --git a/convokit/forecaster/__init__.py b/convokit/forecaster/__init__.py index 12a137604..71fac5534 100644 --- a/convokit/forecaster/__init__.py +++ b/convokit/forecaster/__init__.py @@ -1,6 +1,7 @@ from .forecaster import * from .forecasterModel import * from .cumulativeBoW import * +from convokit.decisionpolicy import * import sys # Import CRAFT models if torch is available diff --git a/convokit/forecaster/cumulativeBoW.py b/convokit/forecaster/cumulativeBoW.py index 50fe0373a..c14b9c61a 100644 --- a/convokit/forecaster/cumulativeBoW.py +++ b/convokit/forecaster/cumulativeBoW.py @@ -25,11 +25,15 @@ def __init__( use_tokens=False, forecast_attribute_name: str = "prediction", forecast_prob_attribute_name: str = "score", + decision_policy=None, ): super().__init__( + decision_policy=decision_policy, forecast_attribute_name=forecast_attribute_name, forecast_prob_attribute_name=forecast_prob_attribute_name, ) + self.forecast_attribute_name = forecast_attribute_name + self.forecast_prob_attribute_name = forecast_prob_attribute_name if vectorizer 
is None: print("Initializing default unigram CountVectorizer...") if use_tokens: @@ -66,6 +70,10 @@ def __init__( else: self.clf_model = clf_model + @staticmethod + def _context_to_text(context): + return " ".join([u.text for u in context.context]) + @staticmethod def _combine_contexts(id_to_context_others): """ @@ -114,3 +122,30 @@ def forecast(self, id_to_context_reply_label): data=list(zip(ids, preds, pred_probs)), columns=["id", self.forecast_attribute_name, self.forecast_prob_attribute_name], ).set_index("id") + + def fit_belief_estimator(self, contexts, val_contexts=None): + contexts = list(contexts) + X_raw = [self._context_to_text(context) for context in contexts] + y = [self.labeler(context.current_utterance.get_conversation()) for context in contexts] + X = self.vectorizer.fit_transform(X_raw) + self.clf_model.fit(X, y) + + def score(self, context) -> float: + X = self.vectorizer.transform([self._context_to_text(context)]) + return float(self.clf_model.predict_proba(X)[0, 1]) + + def fit(self, contexts, val_contexts=None): + return super().fit(contexts, val_contexts) + + def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): + utt_ids = [] + preds = [] + scores = [] + for context in contexts: + utt_score, utt_pred = self._predict(context) + utt_ids.append(context.current_utterance.id) + preds.append(utt_pred) + scores.append(utt_score) + return pd.DataFrame( + {forecast_attribute_name: preds, forecast_prob_attribute_name: scores}, index=utt_ids + ) diff --git a/convokit/forecaster/forecaster.py b/convokit/forecaster/forecaster.py index c3c259ac5..1df481122 100644 --- a/convokit/forecaster/forecaster.py +++ b/convokit/forecaster/forecaster.py @@ -6,7 +6,9 @@ import numpy as np from matplotlib import pyplot as plt -# Define a namedtuple template to represent conversational context tuples +# define a namedtuple template to represent conversational context tuples. 
+# this alias is kept for backwards compatibility; decision policies should +# accept the same structure. ContextTuple = namedtuple( "ContextTuple", ["context", "current_utterance", "future_context", "conversation_id"] ) diff --git a/convokit/forecaster/forecasterModel.py b/convokit/forecaster/forecasterModel.py index 0051ff32e..474aa30e3 100644 --- a/convokit/forecaster/forecasterModel.py +++ b/convokit/forecaster/forecasterModel.py @@ -1,6 +1,9 @@ from abc import ABC, abstractmethod +from itertools import tee from typing import Callable +from convokit.decisionpolicy import ThresholdDecisionPolicy + class ForecasterModel(ABC): """ @@ -9,8 +12,9 @@ class ForecasterModel(ABC): in a consistent format, defined above. """ - def __init__(self): + def __init__(self, decision_policy=None, **kwargs): self._labeler = None + self._decision_policy = decision_policy or ThresholdDecisionPolicy() @property def labeler(self): @@ -19,17 +23,68 @@ def labeler(self): @labeler.setter def labeler(self, value: Callable): self._labeler = value + if self._decision_policy is not None: + self._decision_policy.labeler = value + + @property + def decision_policy(self): + return self._decision_policy + + @decision_policy.setter + def decision_policy(self, value): + self._decision_policy = value + if self._decision_policy is not None: + self._decision_policy.labeler = self._labeler - @abstractmethod def fit(self, contexts, val_contexts=None): """ - Train this conversational forecasting model on the given data + Train this conversational forecasting model on the given data by fitting + both the belief estimator and the decision policy. :param contexts: an iterator over context tuples :param val_contexts: an optional second iterator over context tuples to be used as a separate held-out validation set. Concrete ForecasterModel implementations may choose to ignore this, or conversely even enforce its presence. 
""" + belief_contexts, policy_contexts = tee(contexts, 2) + if val_contexts is None: + belief_val_contexts = None + policy_val_contexts = None + else: + belief_val_contexts, policy_val_contexts = tee(val_contexts, 2) + self.fit_belief_estimator(belief_contexts, belief_val_contexts) + self.fit_decision_policy(policy_contexts, policy_val_contexts) + + @abstractmethod + def fit_belief_estimator(self, contexts, val_contexts=None): + """ + Fit only the belief estimator component that produces continuous scores. + """ + pass + + def fit_decision_policy(self, contexts, val_contexts=None): + """ + Fit only the decision policy component. + """ + if self.decision_policy is not None: + return self.decision_policy.fit( + contexts=contexts, val_contexts=val_contexts, score_fn=self.score + ) + return None + + @abstractmethod + def score(self, context) -> float: + """ + Produce the belief estimator score for a context. + """ pass + def _predict(self, context): + """ + Return both belief score and policy action for a context. + """ + utt_score = self.score(context) + utt_pred = self.decision_policy.decide(context, self.score) + return utt_score, utt_pred + @abstractmethod def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): """ diff --git a/docs/source/decisionpolicy.rst b/docs/source/decisionpolicy.rst new file mode 100644 index 000000000..48eebd1a7 --- /dev/null +++ b/docs/source/decisionpolicy.rst @@ -0,0 +1,9 @@ +Decision Policy +=============== + +The decision policy API separates belief estimation (continuous scores) from +intervention decisions (discrete actions). This keeps ``Forecaster`` unchanged +while allowing flexible action logic in ``ForecasterModel``. + +.. 
automodule:: convokit.decisionpolicy + :members: diff --git a/docs/source/forecaster.rst b/docs/source/forecaster.rst index e688aa7c7..cd0fe35b0 100644 --- a/docs/source/forecaster.rst +++ b/docs/source/forecaster.rst @@ -41,6 +41,7 @@ These are subclasses of ForecasterModel, each implementing forecasting models us .. toctree:: :maxdepth: 1 + Decision Policy CRAFT Model Transformer Encoder-based Model Transformer Decoder-based Model From ed76ff61f3c5659b6717e2a29bd22e37515fc164 Mon Sep 17 00:00:00 2001 From: laerdon Date: Sat, 11 Apr 2026 19:29:14 +0000 Subject: [PATCH 2/6] reformulated decisionpolicy threshold tuning --- .gitignore | 1 + convokit/decisionpolicy/decisionPolicy.py | 98 ++++++++- .../decisionpolicy/deferralDecisionPolicy.py | 164 ++++++++------- .../decisionpolicy/thresholdDecisionPolicy.py | 48 ++--- .../forecaster/TransformerDecoderModel.py | 180 +++++------------ .../forecaster/TransformerEncoderModel.py | 132 +++---------- convokit/forecaster/forecaster.py | 31 ++- convokit/forecaster/forecasterModel.py | 92 +++++++-- examples/forecaster/train_deferral.py | 186 ++++++++++++++++++ 9 files changed, 546 insertions(+), 386 deletions(-) create mode 100644 examples/forecaster/train_deferral.py diff --git a/.gitignore b/.gitignore index 68ebdbf40..07ed59bff 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,4 @@ venv-triadmotif/ !/docs/source/img/* env/ website/docs/build +examples/forecaster/deferral_out diff --git a/convokit/decisionpolicy/decisionPolicy.py b/convokit/decisionpolicy/decisionPolicy.py index f0aca8546..40b87b857 100644 --- a/convokit/decisionpolicy/decisionPolicy.py +++ b/convokit/decisionpolicy/decisionPolicy.py @@ -1,5 +1,9 @@ from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, Tuple, Optional, Dict, Any + +import numpy as np +from sklearn.metrics import roc_curve +from tqdm import tqdm class DecisionPolicy(ABC): @@ -18,14 +22,100 @@ def labeler(self): def labeler(self, value: Callable): 
self._labeler = value + def _fit_with_model_checkpoint_selection(self, val_contexts, score_fn: Callable = None): + if score_fn is None: + return None + forecaster_model = getattr(score_fn, "__self__", None) + if forecaster_model is None: + return None + get_checkpoints = getattr(forecaster_model, "get_checkpoints", None) + load_checkpoint = getattr(forecaster_model, "load_checkpoint", None) + finalize_best_checkpoint_selection = getattr( + forecaster_model, "finalize_best_checkpoint_selection", None + ) + if not callable(get_checkpoints) or not callable(load_checkpoint): + return None + + checkpoints = list(get_checkpoints()) + if len(checkpoints) == 0: + return None + + best_config = None + best_checkpoint = None + best_val_accuracy = -1.0 + for checkpoint_name in checkpoints: + load_checkpoint(checkpoint_name) + fit_result = self._fit_threshold_for_loaded_model(val_contexts, score_fn=score_fn) + print(f"accuracy: {checkpoint_name} {fit_result['best_val_accuracy']}") + if fit_result["best_val_accuracy"] > best_val_accuracy: + best_checkpoint = checkpoint_name + best_val_accuracy = fit_result["best_val_accuracy"] + best_config = { + "best_checkpoint": checkpoint_name, + "best_threshold": float(fit_result["best_threshold"]), + "best_val_accuracy": float(fit_result["best_val_accuracy"]), + } + + if best_config is None: + return None + + if hasattr(self, "threshold"): + self.threshold = float(best_config["best_threshold"]) + load_checkpoint(best_checkpoint) + if callable(finalize_best_checkpoint_selection): + finalize_best_checkpoint_selection( + best_checkpoint, + best_config, + val_contexts=val_contexts, + score_fn=score_fn, + ) + return best_config + + def _fit_threshold_for_loaded_model(self, val_contexts, score_fn: Callable): + y_true, y_score = self._get_validation_arrays(val_contexts, score_fn) + default_threshold = float(getattr(self, "threshold", 0.5)) + if len(y_true) == 0: + return {"best_threshold": default_threshold, "best_val_accuracy": 0.0} + + try: + 
_, _, thresholds = roc_curve(y_true, y_score) + except ValueError: + thresholds = np.asarray([default_threshold], dtype=float) + + if len(thresholds) == 0: + thresholds = np.asarray([default_threshold], dtype=float) + + accs = [((y_score > t).astype(int) == y_true).mean() for t in thresholds] + best_idx = int(np.argmax(accs)) + best_threshold = float(thresholds[best_idx]) + return {"best_threshold": best_threshold, "best_val_accuracy": float(accs[best_idx])} + + def _get_validation_arrays(self, val_contexts, score_fn: Callable): + highest_convo_scores = {} + convo_labels = {} + for context in tqdm(val_contexts): + convo_id = context.conversation_id + score = float(score_fn(context)) + label = int(self.labeler(context.current_utterance.get_conversation())) + if convo_id not in highest_convo_scores: + highest_convo_scores[convo_id] = score + else: + highest_convo_scores[convo_id] = max(highest_convo_scores[convo_id], score) + convo_labels[convo_id] = label + + convo_ids = list(highest_convo_scores.keys()) + y_true = np.asarray([convo_labels[c] for c in convo_ids]) + y_score = np.asarray([highest_convo_scores[c] for c in convo_ids]) + return y_true, y_score + @abstractmethod - def decide(self, context, score_fn: Callable) -> int: + def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]: """ Decide whether to intervene for a context. 
:param context: context tuple supplied by Forecaster :param score_fn: callable that maps a context tuple to a scalar score - :return: integer action label (currently 0/1) + :return: tuple containing the score, the integer action label (currently 0/1), and any additional metadata """ pass @@ -38,4 +128,4 @@ def fit(self, contexts, val_contexts=None, score_fn: Callable = None): :param val_contexts: optional validation contexts :param score_fn: optional scorer callable exposed by ForecasterModel """ - pass + pass \ No newline at end of file diff --git a/convokit/decisionpolicy/deferralDecisionPolicy.py b/convokit/decisionpolicy/deferralDecisionPolicy.py index ab8d2848e..848e1154f 100644 --- a/convokit/decisionpolicy/deferralDecisionPolicy.py +++ b/convokit/decisionpolicy/deferralDecisionPolicy.py @@ -1,8 +1,4 @@ -from itertools import tee -from typing import Callable, List, Optional - -import numpy as np -from sklearn.metrics import roc_curve +from typing import Callable, List, Optional, Dict, Any, Tuple from .decisionPolicy import DecisionPolicy @@ -25,56 +21,63 @@ def get_conversation(self): class DeferralDecisionPolicy(DecisionPolicy): """ - Decision policy that can defer intervention using simulated next utterances. + Decision policy that defers intervention by looking ahead at simulated next utterances. + + :param simulator: utterance simulator model (must have a ``transform(contexts)`` method + returning a DataFrame indexed by utterance id). if the simulator exposes + ``get_num_simulations()``, ``num_simulations`` is capped to that value. + :param threshold: probability threshold above which a context is flagged. + :param tau: minimum number of simulated branches that must exceed the threshold + before an intervention is issued. + :param num_simulations: how many simulated branches to use per context (capped to + simulator's ``get_num_simulations()`` if available). 
+ :param store_simulations: if True, simulated reply strings are cached during decide() + and written to corpus utterance metadata by post_transform(). + :param simulated_reply_attribute_name: metadata field name used when storing simulations + on corpus utterances (only relevant when store_simulations=True). """ def __init__( self, - simulator=None, - threshold: float = 0.5, - num_simulations: int = 3, - aggregation: str = "mean", + simulator, + threshold, + tau: int = 5, + num_simulations: int = 10, + store_simulations: bool = False, + simulated_reply_attribute_name: str = "sim_replies", + sim_replies_forecast_probs_attribute_name: str = "sim_replies_forecast_probs", ): super().__init__() self.simulator = simulator self.threshold = float(threshold) - self.num_simulations = int(num_simulations) - self.aggregation = aggregation - - def _aggregate_scores(self, scores: List[float]) -> float: - if len(scores) == 0: - return 0.0 - if self.aggregation == "max": - return float(np.max(scores)) - if self.aggregation == "min": - return float(np.min(scores)) - return float(np.mean(scores)) - - def get_simulations(self, context, simulator=None, k: Optional[int] = None) -> List[str]: - simulator = simulator if simulator is not None else self.simulator - if k is None: - k = self.num_simulations - if simulator is None: + self.tau = int(tau) + n = int(num_simulations) + if simulator is not None and hasattr(simulator, "get_num_simulations"): + n = min(n, int(simulator.get_num_simulations())) + self.num_simulations = n + self.store_simulations = store_simulations + self.simulated_reply_attribute_name = simulated_reply_attribute_name + self.sim_replies_forecast_probs_attribute_name = sim_replies_forecast_probs_attribute_name + self._sim_cache: dict = {} + self._sim_score_cache: dict = {} + + def get_simulations(self, context, simulator=None) -> List[str]: + sim = simulator if simulator is not None else self.simulator + if sim is None or not hasattr(sim, "transform"): + return [] + 
sims = sim.transform(iter([context])) + utt_id = context.current_utterance.id + if utt_id not in sims.index or sims.shape[1] == 0: return [] - if callable(simulator): - sims = simulator(context, k) - return list(sims)[:k] - if hasattr(simulator, "get_simulations"): - sims = simulator.get_simulations(context, k) - return list(sims)[:k] - if hasattr(simulator, "transform"): - sims = simulator.transform(iter([context])) - if context.current_utterance.id in sims.index: - col_name = sims.columns[0] - return list(sims.loc[context.current_utterance.id][col_name])[:k] - return [] + col_name = sims.columns[0] + return list(sims.loc[utt_id][col_name])[: self.num_simulations] def _build_simulated_context(self, context, simulation_text: str, simulation_idx: int): current_utt = context.current_utterance synthetic_utt = _synthetic_utterance( text=simulation_text, utterance_id=f"{current_utt.id}__sim_{simulation_idx}", - speaker_id="simulator", + speaker_id="", ) new_context_utts = list(context.context) + [synthetic_utt] context_cls = context.__class__ @@ -85,60 +88,49 @@ def _build_simulated_context(self, context, simulation_text: str, simulation_idx conversation_id=context.conversation_id, ) - def _decision_score(self, context, score_fn: Callable) -> float: + def _decision_score(self, context, score_fn: Callable): current_score = float(score_fn(context)) simulations = self.get_simulations(context) - if len(simulations) == 0: - return current_score simulation_scores = [] for idx, sim_text in enumerate(simulations): sim_context = self._build_simulated_context(context, sim_text, idx) simulation_scores.append(float(score_fn(sim_context))) - return self._aggregate_scores([current_score] + simulation_scores) - - def decide(self, context, score_fn: Callable) -> int: - decision_score = self._decision_score(context, score_fn) - return int(decision_score > self.threshold) + if self.store_simulations and simulations: + utt_id = context.current_utterance.id + self._sim_cache[utt_id] = 
simulations + self._sim_score_cache[utt_id] = simulation_scores + return current_score, simulations, simulation_scores + + def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]: + decision_score, simulations, simulation_scores = self._decision_score(context, score_fn) + num_simulations_above_threshold = sum(1 for score in simulation_scores if score > self.threshold) + + # we want to return an awry prediction if we have + return (decision_score, + 1 if decision_score > self.threshold and num_simulations_above_threshold > 5 else 0, + { + self.simulated_reply_attribute_name: simulations, + self.sim_replies_forecast_probs_attribute_name: simulation_scores, + } + ) def fit(self, contexts, val_contexts=None, score_fn: Callable = None): - if self.simulator is not None and hasattr(self.simulator, "fit"): - if val_contexts is None: - sim_contexts = contexts - sim_val_contexts = None - else: - sim_contexts, contexts = tee(contexts, 2) - sim_val_contexts, val_contexts = tee(val_contexts, 2) - self.simulator.fit(sim_contexts, sim_val_contexts) - if val_contexts is None or score_fn is None or self.labeler is None: - return {"threshold": self.threshold} + print("either no validation contexts/score function/labeler were provided, returning current threshold") + return {"best_threshold": self.threshold} + val_contexts = list(val_contexts) if len(val_contexts) == 0: - return {"threshold": self.threshold} - - highest_convo_scores = {} - convo_labels = {} - for context in val_contexts: - convo_id = context.conversation_id - score = self._decision_score(context, score_fn) - label = int(self.labeler(context.current_utterance.get_conversation())) - if convo_id not in highest_convo_scores: - highest_convo_scores[convo_id] = score - else: - highest_convo_scores[convo_id] = max(highest_convo_scores[convo_id], score) - convo_labels[convo_id] = label - - convo_ids = list(highest_convo_scores.keys()) - y_true = np.asarray([convo_labels[c] for c in 
convo_ids]) - y_score = np.asarray([highest_convo_scores[c] for c in convo_ids]) - try: - _, _, thresholds = roc_curve(y_true, y_score) - except ValueError: - return {"threshold": self.threshold} - if len(thresholds) == 0: - return {"threshold": self.threshold} - - accs = [((y_score > t).astype(int) == y_true).mean() for t in thresholds] - best_idx = int(np.argmax(accs)) - self.threshold = float(thresholds[best_idx]) - return {"threshold": self.threshold, "best_val_accuracy": float(accs[best_idx])} + print("no validation contexts were provided, returning current threshold") + return {"best_threshold": self.threshold} + + fit_result = self._fit_with_model_checkpoint_selection(val_contexts, score_fn=score_fn) + if isinstance(fit_result, dict): + if "best_threshold" in fit_result: + self.threshold = float(fit_result["best_threshold"]) + return fit_result + + fit_result = self._fit_threshold_for_loaded_model(val_contexts, score_fn=score_fn) + if "best_threshold" in fit_result: + self.threshold = float(fit_result["best_threshold"]) + return fit_result diff --git a/convokit/decisionpolicy/thresholdDecisionPolicy.py b/convokit/decisionpolicy/thresholdDecisionPolicy.py index 0f3814a44..44a8b0e3c 100644 --- a/convokit/decisionpolicy/thresholdDecisionPolicy.py +++ b/convokit/decisionpolicy/thresholdDecisionPolicy.py @@ -1,7 +1,4 @@ -from typing import Callable - -import numpy as np -from sklearn.metrics import roc_curve +from typing import Callable, Tuple from .decisionPolicy import DecisionPolicy @@ -15,43 +12,26 @@ def __init__(self, threshold: float = 0.5): super().__init__() self.threshold = float(threshold) - def decide(self, context, score_fn: Callable) -> int: - return int(score_fn(context) > self.threshold) + def decide(self, context, score_fn: Callable) -> Tuple[float, int]: + return score_fn(context), int(score_fn(context) > self.threshold) def fit(self, contexts, val_contexts=None, score_fn: Callable = None): if val_contexts is None or score_fn is None or 
self.labeler is None: + print("either no validation contexts/score function/labeler were provided, returning current threshold") return {"best_threshold": self.threshold} val_contexts = list(val_contexts) if len(val_contexts) == 0: + print("no validation contexts were provided, returning current threshold") return {"best_threshold": self.threshold} - highest_convo_scores = {} - convo_labels = {} - for context in val_contexts: - convo_id = context.conversation_id - score = score_fn(context) - label = int(self.labeler(context.current_utterance.get_conversation())) - if convo_id not in highest_convo_scores: - highest_convo_scores[convo_id] = score - else: - highest_convo_scores[convo_id] = max(highest_convo_scores[convo_id], score) - convo_labels[convo_id] = label - - convo_ids = list(highest_convo_scores.keys()) - y_true = np.asarray([convo_labels[c] for c in convo_ids]) - y_score = np.asarray([highest_convo_scores[c] for c in convo_ids]) - - # roc_curve can fail when only one class is present; keep current threshold in that case. 
- try: - _, _, thresholds = roc_curve(y_true, y_score) - except ValueError: - return {"best_threshold": self.threshold} - - if len(thresholds) == 0: - return {"best_threshold": self.threshold} + fit_result = self._fit_with_model_checkpoint_selection(val_contexts, score_fn=score_fn) + if isinstance(fit_result, dict): + if "best_threshold" in fit_result: + self.threshold = float(fit_result["best_threshold"]) + return fit_result - accs = [((y_score > t).astype(int) == y_true).mean() for t in thresholds] - best_idx = int(np.argmax(accs)) - self.threshold = float(thresholds[best_idx]) - return {"best_threshold": self.threshold, "best_val_accuracy": float(accs[best_idx])} + fit_result = self._fit_threshold_for_loaded_model(val_contexts, score_fn=score_fn) + if "best_threshold" in fit_result: + self.threshold = float(fit_result["best_threshold"]) + return fit_result diff --git a/convokit/forecaster/TransformerDecoderModel.py b/convokit/forecaster/TransformerDecoderModel.py index e2df302c2..449a1e4d9 100644 --- a/convokit/forecaster/TransformerDecoderModel.py +++ b/convokit/forecaster/TransformerDecoderModel.py @@ -1,3 +1,4 @@ +from itertools import tee import unsloth from unsloth import FastLanguageModel, is_bfloat16_supported from unsloth.chat_templates import get_chat_template @@ -5,18 +6,13 @@ import torch.nn.functional as F from trl import SFTTrainer, SFTConfig from datasets import Dataset +from collections import defaultdict -import json import os from tqdm import tqdm import pandas as pd -import numpy as np -from sklearn.metrics import roc_curve from .forecasterModel import ForecasterModel from .TransformerForecasterConfig import TransformerForecasterConfig -from convokit.decisionpolicy import ThresholdDecisionPolicy -import shutil - def _get_template_map(model_name_or_path): """ @@ -121,6 +117,22 @@ def best_threshold(self, value): if hasattr(self.decision_policy, "threshold"): self.decision_policy.threshold = float(value) + def get_checkpoints(self): + checkpoints 
= [cp for cp in os.listdir(self.config.output_dir) if "checkpoint-" in cp] + if len(checkpoints) == 0: + return ["zero-shot"] + return checkpoints + + def load_checkpoint(self, checkpoint_name): + if checkpoint_name == "zero-shot": + return + full_model_path = os.path.join(self.config.output_dir, checkpoint_name) + self.model, _ = FastLanguageModel.from_pretrained( + model_name=full_model_path, + max_seq_length=self.max_seq_length, + load_in_4bit=True, + ) + def _context_mode(self, context): """ Select the utterances to include in the input context based on the configured context mode. @@ -273,9 +285,9 @@ def fit_belief_estimator(self, train_contexts, val_contexts=None): model=self.model, tokenizer=self.tokenizer, train_dataset=train_dataset, + max_seq_length=self.max_seq_length, args=SFTConfig( dataset_text_field="text", - max_seq_length=self.max_seq_length, per_device_train_batch_size=self.config.per_device_batch_size, gradient_accumulation_steps=self.config.gradient_accumulation_steps, warmup_steps=10, @@ -297,114 +309,6 @@ def fit_belief_estimator(self, train_contexts, val_contexts=None): trainer.train() return - def _tune_threshold(self, val_contexts): - """ - Tune the decision threshold and select the best model checkpoint based on validation accuracy. - - This method evaluates all model checkpoints in the configured output directory using a - held-out validation set. - - The selected model, threshold, and associated metadata are stored in: - - `self.model`: the best-performing fine-tuned model - - `self.best_threshold`: the optimal decision threshold - - `dev_config.json`: file containing best checkpoint metadata - - `val_predictions.csv`: CSV file with forecast outputs on the validation set - - Additionally, all non-optimal model checkpoints are removed to save disk space, and the - tokenizer is saved to the directory of the best checkpoint. - - :param val_dataset: A HuggingFace-compatible dataset containing features for validation. 
- :param val_contexts: An iterable of context tuples corresponding to the validation set. - Used to map utterance IDs to conversation IDs and extract ground-truth labels. - - :return: A dictionary containing the best checkpoint path, best threshold, and best validation accuracy. - """ - checkpoints = [cp for cp in os.listdir(self.config.output_dir) if "checkpoint-" in cp] - if checkpoints == []: - checkpoints.append("zero-shot") - best_val_accuracy = -1 - best_checkpoint = checkpoints[0] - val_convo_ids = set() - utt2convo = {} - val_labels_dict = {} - val_contexts = list(val_contexts) - for context in val_contexts: - convo_id = context.conversation_id - utt_id = context.current_utterance.id - label = self.labeler(context.current_utterance.get_conversation()) - utt2convo[utt_id] = convo_id - val_labels_dict[convo_id] = label - val_convo_ids.add(convo_id) - val_convo_ids = list(val_convo_ids) - for cp in checkpoints: - if cp != "zero-shot": - full_model_path = os.path.join(self.config.output_dir, cp) - self.model, _ = FastLanguageModel.from_pretrained( - model_name=full_model_path, - max_seq_length=self.max_seq_length, - load_in_4bit=True, - ) - FastLanguageModel.for_inference(self.model) - utt2score = {} - for context in tqdm(val_contexts): - utt_score = self.score(context) - utt_id = context.current_utterance.id - utt2score[utt_id] = utt_score - # for each CONVERSATION, whether or not it triggers will be effectively determined by what the highest score it ever got was - highest_convo_scores = {convo_id: -1 for convo_id in val_convo_ids} - - for utt_id in utt2convo: - convo_id = utt2convo[utt_id] - utt_score = utt2score[utt_id] - if utt_score > highest_convo_scores[convo_id]: - highest_convo_scores[convo_id] = utt_score - - val_labels = np.asarray([int(val_labels_dict[c]) for c in val_convo_ids]) - val_scores = np.asarray([highest_convo_scores[c] for c in val_convo_ids]) - # use scikit learn to find candidate threshold cutoffs - _, _, thresholds = 
roc_curve(val_labels, val_scores) - - def acc_with_threshold(y_true, y_score, thresh): - y_pred = (y_score > thresh).astype(int) - return (y_pred == y_true).mean() - - accs = [acc_with_threshold(val_labels, val_scores, t) for t in thresholds] - best_acc_idx = np.argmax(accs) - - print("Accuracy:", cp, accs[best_acc_idx]) - if accs[best_acc_idx] > best_val_accuracy: - best_checkpoint = cp - best_val_accuracy = accs[best_acc_idx] - self.best_threshold = thresholds[best_acc_idx] - - # Save the best config - best_config = {} - best_config["best_checkpoint"] = best_checkpoint - best_config["best_threshold"] = self.best_threshold - best_config["best_val_accuracy"] = best_val_accuracy - config_file = os.path.join(self.config.output_dir, "dev_config.json") - with open(config_file, "w") as outfile: - json_object = json.dumps(best_config, indent=4) - outfile.write(json_object) - # Load best model - best_model_path = os.path.join(self.config.output_dir, best_checkpoint) - self.model, _ = FastLanguageModel.from_pretrained( - model_name=best_model_path, - max_seq_length=self.max_seq_length, - load_in_4bit=True, - ) - - # Clean other checkpoints to save disk space. - for root, _, _ in os.walk(self.config.output_dir): - if ("checkpoint" in root) and (best_checkpoint not in root): - print("Deleting:", root) - shutil.rmtree(root) - # Save the tokenizer. 
- self.tokenizer.save_pretrained( - os.path.join(self.config.output_dir, best_config["best_checkpoint"]) - ) - return best_config - def score(self, context) -> float: FastLanguageModel.for_inference(self.model) context_utts = self._context_mode(context) @@ -450,19 +354,14 @@ def _predict(self, context, threshold=None): if threshold is not None: utt_pred = int(utt_score > threshold) else: - utt_pred = self.decision_policy.decide(context, self.score) + utt_score, utt_pred, _ = self.decision_policy.decide(context, self.score) return utt_score, utt_pred - def fit_decision_policy(self, contexts, val_contexts=None): - if ( - val_contexts is not None - and isinstance(self.decision_policy, ThresholdDecisionPolicy) - ): - return self._tune_threshold(val_contexts) - return super().fit_decision_policy(contexts, val_contexts) - def fit(self, contexts, val_contexts=None): - return super().fit(contexts, val_contexts) + val_contexts_belief_estimator, val_contexts_decision_policy = tee(val_contexts, 2) + self.fit_belief_estimator(contexts, val_contexts_belief_estimator) + self.fit_decision_policy(contexts, val_contexts_decision_policy, score_fn=self.score) + return def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): """ @@ -478,15 +377,36 @@ def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_n utt_ids = [] preds = [] scores = [] + metadatas = defaultdict(list) + # for safety/flexibility we can accept either only score and pred or also the metadata for context in tqdm(contexts): - utt_score, utt_pred = self._predict(context) - + result = self.decision_policy.decide(context, self.score) + + if len(result) == 2: + utt_score, utt_pred = result + utt_metadata = {} + # no metadata + elif len(result) == 3: + utt_score, utt_pred, utt_metadata = result + for key in utt_metadata.keys(): + metadatas[key].append(utt_metadata.get(key, None)) + else: + raise ValueError( + "decision_policy.decide() must return (utt_score, utt_pred) " 
+ "or (utt_score, utt_pred, metadata_dict)" + ) utt_ids.append(context.current_utterance.id) preds.append(utt_pred) scores.append(utt_score) - forecasts_df = pd.DataFrame( - {forecast_attribute_name: preds, forecast_prob_attribute_name: scores}, index=utt_ids - ) + cols = { + forecast_attribute_name: preds, + forecast_prob_attribute_name: scores, + } + for key, series in metadatas.items(): + assert len(series) == len(preds), "Metadata series length must match number of predictions" + cols[key] = series # each series same length as preds + forecasts_df = pd.DataFrame(cols, index=utt_ids) + prediction_file = os.path.join(self.config.output_dir, "test_predictions.csv") forecasts_df.to_csv(prediction_file) return forecasts_df diff --git a/convokit/forecaster/TransformerEncoderModel.py b/convokit/forecaster/TransformerEncoderModel.py index e86a9d133..f2130ac4d 100644 --- a/convokit/forecaster/TransformerEncoderModel.py +++ b/convokit/forecaster/TransformerEncoderModel.py @@ -11,14 +11,9 @@ import os import pandas as pd -import numpy as np -import json from tqdm import tqdm -from sklearn.metrics import roc_curve from .forecasterModel import ForecasterModel from .TransformerForecasterConfig import TransformerForecasterConfig -from convokit.decisionpolicy import ThresholdDecisionPolicy -import shutil os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -74,6 +69,29 @@ def best_threshold(self, value): if hasattr(self.decision_policy, "threshold"): self.decision_policy.threshold = float(value) + def get_checkpoints(self): + return [cp for cp in os.listdir(self.config.output_dir) if "checkpoint-" in cp] + + def load_checkpoint(self, checkpoint_name): + full_model_path = os.path.join(self.config.output_dir, checkpoint_name) + self.model = AutoModelForSequenceClassification.from_pretrained(full_model_path).to( + self.config.device + ) + + def finalize_best_checkpoint_selection( + self, best_checkpoint, best_config, val_contexts=None, score_fn=None + ): + 
super().finalize_best_checkpoint_selection( + best_checkpoint, best_config, val_contexts=val_contexts, score_fn=score_fn + ) + if val_contexts is None or self.best_threshold is None: + return + val_dataset = self._context_to_bert_data(val_contexts) + val_dataset.set_format("torch") + eval_forecasts_df = self._score_dataset(val_dataset, threshold=self.best_threshold) + eval_prediction_file = os.path.join(self.config.output_dir, "val_predictions.csv") + eval_forecasts_df.to_csv(eval_prediction_file) + def _context_mode(self, context): """ Select the utterances to include in the input context based on the configured context mode. @@ -231,102 +249,6 @@ def score(self, context) -> float: probs = F.softmax(outputs.logits, dim=-1) return probs[0, 1].item() - def _tune_threshold(self, val_dataset, val_contexts): - """ - Tune the decision threshold and select the best model checkpoint based on validation accuracy. - - This method evaluates all model checkpoints in the configured output directory using a - held-out validation set. - - The selected model, threshold, and associated metadata are stored in: - - `self.model`: the best-performing fine-tuned model - - `self.best_threshold`: the optimal decision threshold - - `dev_config.json`: file containing best checkpoint metadata - - `val_predictions.csv`: CSV file with forecast outputs on the validation set - - Additionally, all non-optimal model checkpoints are removed to save disk space, and the - tokenizer is saved to the directory of the best checkpoint. - - :param val_dataset: A HuggingFace-compatible dataset containing features for validation. - :param val_contexts: An iterable of context tuples corresponding to the validation set. - Used to map utterance IDs to conversation IDs and extract ground-truth labels. - - :return: A dictionary containing the best checkpoint path, best threshold, and best validation accuracy. 
- """ - checkpoints = [cp for cp in os.listdir(self.config.output_dir) if "checkpoint-" in cp] - if len(checkpoints) == 0: - raise ValueError("no checkpoints found for threshold tuning") - best_val_accuracy = -1 - best_checkpoint = checkpoints[0] - val_convo_ids = set() - utt2convo = {} - val_labels_dict = {} - for context in val_contexts: - convo_id = context.conversation_id - utt_id = context.current_utterance.id - label = self.labeler(context.current_utterance.get_conversation()) - utt2convo[utt_id] = convo_id - val_labels_dict[convo_id] = label - val_convo_ids.add(convo_id) - val_convo_ids = list(val_convo_ids) - for cp in checkpoints: - full_model_path = os.path.join(self.config.output_dir, cp) - finetuned_model = AutoModelForSequenceClassification.from_pretrained( - full_model_path - ).to(self.config.device) - val_scores = self._score_dataset(val_dataset, model=finetuned_model) - # for each CONVERSATION, whether or not it triggers will be effectively determined by what the highest score it ever got was - highest_convo_scores = {convo_id: -1 for convo_id in val_convo_ids} - for utt_id in val_scores.index: - convo_id = utt2convo[utt_id] - utt_score = val_scores.loc[utt_id].forecast_prob - if utt_score > highest_convo_scores[convo_id]: - highest_convo_scores[convo_id] = utt_score - - val_labels = np.asarray([int(val_labels_dict[c]) for c in val_convo_ids]) - val_scores = np.asarray([highest_convo_scores[c] for c in val_convo_ids]) - # use scikit learn to find candidate threshold cutoffs - _, _, thresholds = roc_curve(val_labels, val_scores) - - def acc_with_threshold(y_true, y_score, thresh): - y_pred = (y_score > thresh).astype(int) - return (y_pred == y_true).mean() - - accs = [acc_with_threshold(val_labels, val_scores, t) for t in thresholds] - best_acc_idx = np.argmax(accs) - - print("Accuracy:", cp, accs[best_acc_idx]) - if accs[best_acc_idx] > best_val_accuracy: - best_checkpoint = cp - best_val_accuracy = accs[best_acc_idx] - self.best_threshold = 
thresholds[best_acc_idx] - self.model = finetuned_model - - eval_forecasts_df = self._score_dataset(val_dataset, threshold=self.best_threshold) - eval_prediction_file = os.path.join(self.config.output_dir, "val_predictions.csv") - eval_forecasts_df.to_csv(eval_prediction_file) - - # Save the best config - best_config = {} - best_config["best_checkpoint"] = best_checkpoint - best_config["best_threshold"] = self.best_threshold - best_config["best_val_accuracy"] = best_val_accuracy - config_file = os.path.join(self.config.output_dir, "dev_config.json") - with open(config_file, "w") as outfile: - json_object = json.dumps(best_config, indent=4) - outfile.write(json_object) - - # Clean other checkpoints to save disk space. - for root, _, _ in os.walk(self.config.output_dir): - if ("checkpoint" in root) and (best_checkpoint not in root): - print("Deleting:", root) - shutil.rmtree(root) - # Save the tokenizer. - self.tokenizer.save_pretrained( - os.path.join(self.config.output_dir, best_config["best_checkpoint"]) - ) - return best_config - def fit_belief_estimator(self, contexts, val_contexts=None): """ Fine-tune the TransformerEncoder model, and save the best model according to validation performance. 
@@ -360,14 +282,6 @@ def fit_belief_estimator(self, contexts, val_contexts=None): trainer.train() return - def fit_decision_policy(self, contexts, val_contexts=None): - if val_contexts is None: - return super().fit_decision_policy(contexts, val_contexts) - val_contexts = list(val_contexts) - val_for_tuning_pairs = self._context_to_bert_data(val_contexts) - val_for_tuning_pairs.set_format("torch") - return self._tune_threshold(val_for_tuning_pairs, val_contexts) - def fit(self, contexts, val_contexts=None): return super().fit(contexts, val_contexts) diff --git a/convokit/forecaster/forecaster.py b/convokit/forecaster/forecaster.py index 1df481122..c89dddaf9 100644 --- a/convokit/forecaster/forecaster.py +++ b/convokit/forecaster/forecaster.py @@ -121,6 +121,20 @@ def fit( self.forecaster_model.fit(contexts, val_contexts) return self + + def fit_decision_policy(self, corpus, context_selector, val_context_selector): + contexts = self._create_context_iterator(corpus, context_selector, include_future_context=True) + val_contexts = None + if val_context_selector is not None: + val_contexts = self._create_context_iterator(corpus, val_context_selector, include_future_context=True) + return self.forecaster_model.fit_decision_policy(contexts, val_contexts) + + def fit_belief_estimator(self, corpus, context_selector, val_context_selector): + contexts = self._create_context_iterator(corpus, context_selector, include_future_context=True) + val_contexts = None + if val_context_selector is not None: + val_contexts = self._create_context_iterator(corpus, val_context_selector, include_future_context=True) + return self.forecaster_model.fit_belief_estimator(contexts, val_contexts) def transform( self, @@ -142,19 +156,16 @@ def transform( contexts, self.forecast_attribute_name, self.forecast_prob_attribute_name ) + # generalize addition of metadata columns + meta_columns = list(forecast_df.columns) for utt in corpus.iter_utterances(): if utt.id in forecast_df.index: - utt.add_meta( - 
self.forecast_attribute_name, - forecast_df.loc[utt.id][self.forecast_attribute_name], - ) - utt.add_meta( - self.forecast_prob_attribute_name, - forecast_df.loc[utt.id][self.forecast_prob_attribute_name], - ) + row = forecast_df.loc[utt.id] + for col in meta_columns: + utt.add_meta(col, row[col]) else: - utt.add_meta(self.forecast_attribute_name, None) - utt.add_meta(self.forecast_prob_attribute_name, None) + for col in meta_columns: + utt.add_meta(col, None) return corpus diff --git a/convokit/forecaster/forecasterModel.py b/convokit/forecaster/forecasterModel.py index 474aa30e3..48d41ffa7 100644 --- a/convokit/forecaster/forecasterModel.py +++ b/convokit/forecaster/forecasterModel.py @@ -2,6 +2,10 @@ from itertools import tee from typing import Callable +import json +import os +import shutil + from convokit.decisionpolicy import ThresholdDecisionPolicy @@ -36,6 +40,7 @@ def decision_policy(self, value): if self._decision_policy is not None: self._decision_policy.labeler = self._labeler + @abstractmethod def fit(self, contexts, val_contexts=None): """ Train this conversational forecasting model on the given data by fitting @@ -44,14 +49,7 @@ def fit(self, contexts, val_contexts=None): :param contexts: an iterator over context tuples :param val_contexts: an optional second iterator over context tuples to be used as a separate held-out validation set. Concrete ForecasterModel implementations may choose to ignore this, or conversely even enforce its presence. 
""" - belief_contexts, policy_contexts = tee(contexts, 2) - if val_contexts is None: - belief_val_contexts = None - policy_val_contexts = None - else: - belief_val_contexts, policy_val_contexts = tee(val_contexts, 2) - self.fit_belief_estimator(belief_contexts, belief_val_contexts) - self.fit_decision_policy(policy_contexts, policy_val_contexts) + pass @abstractmethod def fit_belief_estimator(self, contexts, val_contexts=None): @@ -60,16 +58,83 @@ def fit_belief_estimator(self, contexts, val_contexts=None): """ pass - def fit_decision_policy(self, contexts, val_contexts=None): + def fit_decision_policy(self, contexts, val_contexts=None, score_fn: Callable = None): """ Fit only the decision policy component. """ if self.decision_policy is not None: - return self.decision_policy.fit( - contexts=contexts, val_contexts=val_contexts, score_fn=self.score + if score_fn is None: + score_fn = self.score + fit_result = self.decision_policy.fit( + contexts=contexts, val_contexts=val_contexts, score_fn=score_fn ) + self._json_dump_fit_result(fit_result) + return fit_result return None + def _json_dump_fit_result(self, fit_result): + if not isinstance(fit_result, dict): + return + + output_dir = getattr(getattr(self, "config", None), "output_dir", None) + if output_dir is None: + return + + config_file = os.path.join(output_dir, "dev_config.json") + existing_config = {} + if os.path.exists(config_file): + try: + with open(config_file, "r") as infile: + existing_config = json.load(infile) + except (json.JSONDecodeError, OSError): + existing_config = {} + + if "best_checkpoint" in fit_result: + existing_config["best_checkpoint"] = fit_result["best_checkpoint"] + if "best_threshold" in fit_result: + existing_config["best_threshold"] = float(fit_result["best_threshold"]) + if "best_val_accuracy" in fit_result: + existing_config["best_val_accuracy"] = float(fit_result["best_val_accuracy"]) + + with open(config_file, "w") as outfile: + json.dump(existing_config, outfile, indent=4) + 
+ def get_checkpoints(self): + return [] + + def load_checkpoint(self, checkpoint_name): + raise NotImplementedError("checkpoint loading is not implemented for this model") + + def finalize_best_checkpoint_selection( + self, best_checkpoint, best_config, val_contexts=None, score_fn: Callable = None + ): + if best_checkpoint is None: + return + self._cleanup_checkpoints(best_checkpoint) + self._save_tokenizer_checkpoint(best_checkpoint) + + def _cleanup_checkpoints(self, best_checkpoint): + output_dir = getattr(getattr(self, "config", None), "output_dir", None) + if output_dir is None or best_checkpoint is None: + return + + for root, _, _ in os.walk(output_dir): + if ("checkpoint" in root) and (best_checkpoint not in root): + print(f"deleting: {root}") + shutil.rmtree(root) + + def _save_tokenizer_checkpoint(self, best_checkpoint): + tokenizer = getattr(self, "tokenizer", None) + output_dir = getattr(getattr(self, "config", None), "output_dir", None) + if ( + tokenizer is None + or output_dir is None + or best_checkpoint is None + or not hasattr(tokenizer, "save_pretrained") + ): + return + tokenizer.save_pretrained(os.path.join(output_dir, best_checkpoint)) + @abstractmethod def score(self, context) -> float: """ @@ -80,9 +145,10 @@ def score(self, context) -> float: def _predict(self, context): """ Return both belief score and policy action for a context. + + This method is deprecated in favor of using the self.decision_policy.decide method. 
""" - utt_score = self.score(context) - utt_pred = self.decision_policy.decide(context, self.score) + utt_score, utt_pred = self.decision_policy.decide(context, self.score) return utt_score, utt_pred @abstractmethod diff --git a/examples/forecaster/train_deferral.py b/examples/forecaster/train_deferral.py new file mode 100644 index 000000000..dec75fe49 --- /dev/null +++ b/examples/forecaster/train_deferral.py @@ -0,0 +1,186 @@ +""" +end-to-end training script for a TransformerDecoderModel forecaster with a DeferralDecisionPolicy. + +flow: + 1. load the CGA-CMV corpus + 2. build a DeferralDecisionPolicy backed by an UnslothUtteranceSimulatorModel + 3. build a TransformerDecoderModel (forecaster backbone) with that policy attached + 4. wrap both in a Forecaster + 5. fit: LoRA fine-tune the forecaster, then fit the decision policy on the val set + 6. evaluate on the test set and print metrics + +usage: + python train_deferral.py [--device cuda] [--gpu 0] +""" + +import argparse +import os +import sys + +# ensure the repo root is on the path when running directly +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + + +def main(args): + import convokit + from convokit import Corpus, Forecaster, download + from convokit.forecaster.TransformerDecoderModel import TransformerDecoderModel + from convokit.forecaster.TransformerForecasterConfig import TransformerForecasterConfig + from convokit.decisionpolicy import DeferralDecisionPolicy, ThresholdDecisionPolicy + + # ------------------------------------------------------------------ # + # 1. corpus + # ------------------------------------------------------------------ # + print("[info] loading corpus...") + corpus = Corpus( + filename=download( + "conversations-gone-awry-cmv-corpus", + data_dir=args.data_dir, + ) + ) + + labeler = "has_removed_comment" + + # ------------------------------------------------------------------ # + # 2. 
context selectors + # ------------------------------------------------------------------ # + def train_selector(ctx): + """last context of every train conversation (matches original craft/llm training setup)""" + convo = ctx.current_utterance.get_conversation() + return ( + convo.meta.get("split") == "train" + and len(ctx.future_context) == 0 + ) + + def val_selector(ctx): + return ctx.current_utterance.get_conversation().meta.get("split") == "val" + + def test_selector(ctx): + convo = ctx.current_utterance.get_conversation() + convo_len = len(convo.get_chronological_utterance_list()) + return ( + convo.meta.get("split") == "test" + # exclude the very last context (the toxic turn itself) + and len(ctx.context) < convo_len + ) + + + # 3. simulator model + # # + # 4. decision policy + policy = ThresholdDecisionPolicy( + threshold=0.5926666259765625, + ) + + # 5. forecaster model + print("[info] loading forecaster model...") + config = TransformerForecasterConfig( + output_dir=args.output_dir, + per_device_batch_size=args.batch_size, + gradient_accumulation_steps=args.grad_accum, + num_train_epochs=args.epochs, + learning_rate=args.lr, + random_seed=args.seed, + context_mode="normal", + device=args.device, + ) + + forecaster_model = TransformerDecoderModel( + model_name_or_path=args.forecaster_model, + config=config, + decision_policy=policy, + ) + + # 6. forecaster wrapper + forecaster = Forecaster( + forecaster_model=forecaster_model, + labeler=labeler, + ) + + # 7. fit + # print("[info] fitting forecaster (belief estimator + decision policy)...") + # forecaster.fit( + # corpus=corpus, + # context_selector=train_selector, + # val_context_selector=val_selector, + # ) + + forecaster.fit_decision_policy( + corpus=corpus, + context_selector=train_selector, + val_context_selector=val_selector, + ) + + # print(forecaster.forecaster_model.decision_policy.threshold) + + # # 8. 
evaluate on test set + # print("[info] running transform on test set...") + # corpus = forecaster.transform( + # corpus=corpus, + # context_selector=test_selector, + # ) + + # print("[info] computing metrics...") + # forecaster.summarize( + # corpus=corpus, + # selector=lambda convo: convo.meta.get("split") == "test", + # ) + + # optional: inspect a few utterances with stored simulations + if args.store_simulations: + print("\n[info] sample utterances with stored simulations:") + shown = 0 + for utt in corpus.iter_utterances(): + # show only utterances that were forecasted and have sim_replies + if ( + utt.meta.get("forecast") is not None + and utt.meta.get("sim_replies") is not None + ): + print("---") + print("text :", utt.text[:120]) + print("forecast_prob :", utt.meta["forecast_prob"]) + print("forecast :", utt.meta["forecast"]) + print("sim_replies :", utt.meta["sim_replies"][:2]) + print("sim_probs :", utt.meta["sim_replies_forecast_probs"][:2]) + shown += 1 + if shown >= 3: + break + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="train forecaster with DeferralDecisionPolicy") + + # paths + parser.add_argument("--forecaster-model", required=True, + help="hf model name or local path for the decoder forecaster") + parser.add_argument("--simulator-model", required=True, + help="hf model name or local path for the utterance simulator") + parser.add_argument("--output-dir", default="./deferral_output", + help="directory to save checkpoints and predictions") + parser.add_argument("--data-dir", default="./", + help="directory to download/find the corpus") + + # training hyperparams + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--batch-size", type=int, default=2) + parser.add_argument("--grad-accum", type=int, default=32) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--seed", type=int, default=1) + + # deferral policy hyperparams + parser.add_argument("--num-simulations", 
type=int, default=10, + help="number of simulated branches per context") + parser.add_argument("--tau", type=int, default=5, + help="minimum simulated branches above threshold to intervene") + + # misc + parser.add_argument("--device", default="cuda") + parser.add_argument("--gpu", type=int, default=3, + help="which gpu to use (sets CUDA_VISIBLE_DEVICES)") + parser.add_argument("--store-simulations", action="store_true", + help="write simulated replies and their forecast probs to corpus metadata") + + args = parser.parse_args() + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) + + main(args) From 21a314ff65dfbc8a8350f462d9686c6f5c6ede70 Mon Sep 17 00:00:00 2001 From: laerdon Date: Sun, 26 Apr 2026 21:13:48 +0000 Subject: [PATCH 3/6] final 1/x --- convokit/decisionpolicy/decisionPolicy.py | 31 +++- .../decisionpolicy/deferralDecisionPolicy.py | 98 ++++++++++-- .../decisionpolicy/thresholdDecisionPolicy.py | 15 +- convokit/forecaster/CRAFTModel.py | 15 +- .../forecaster/TransformerDecoderModel.py | 148 ++++++++++++++++-- .../forecaster/TransformerEncoderModel.py | 22 ++- convokit/forecaster/forecaster.py | 6 +- convokit/forecaster/forecasterModel.py | 35 ++++- 8 files changed, 332 insertions(+), 38 deletions(-) diff --git a/convokit/decisionpolicy/decisionPolicy.py b/convokit/decisionpolicy/decisionPolicy.py index 40b87b857..eff7ac755 100644 --- a/convokit/decisionpolicy/decisionPolicy.py +++ b/convokit/decisionpolicy/decisionPolicy.py @@ -11,8 +11,17 @@ class DecisionPolicy(ABC): Abstract interface for converting a conversational context into an action. """ - def __init__(self): + def __init__( + self, + forecast_prob_attribute_name: str = "forecast_prob", + reuse_cached_forecast_probs: bool = True, + ): self._labeler = None + # name of the utterance-meta field that may already hold a forecast prob + # from a prior Forecaster.transform() pass. kept in sync with the owning + # ForecasterModel / Forecaster when they are wired up. 
+ self.forecast_prob_attribute_name = forecast_prob_attribute_name + self.reuse_cached_forecast_probs = bool(reuse_cached_forecast_probs) @property def labeler(self): @@ -22,6 +31,18 @@ def labeler(self): def labeler(self, value: Callable): self._labeler = value + def _score(self, context, score_fn: Callable) -> float: + # prefer a previously written forecast prob on the current utterance meta + # so policies don't re-invoke the belief estimator on utterances the + # forecaster has already transformed. synthetic / simulated utterances + # carry an empty meta and always fall through to score_fn. + if self.reuse_cached_forecast_probs: + meta = getattr(getattr(context, "current_utterance", None), "meta", None) or {} + cached = meta.get(self.forecast_prob_attribute_name) + if cached is not None: + return float(cached) + return float(score_fn(context)) + def _fit_with_model_checkpoint_selection(self, val_contexts, score_fn: Callable = None): if score_fn is None: return None @@ -43,6 +64,11 @@ def _fit_with_model_checkpoint_selection(self, val_contexts, score_fn: Callable best_config = None best_checkpoint = None best_val_accuracy = -1.0 + # while sweeping checkpoints, any cached forecast_prob on utterance meta + # reflects whichever checkpoint's transform() ran last, not the one we + # are currently evaluating. force a fresh score_fn call for each sweep. 
+ prior_reuse_flag = self.reuse_cached_forecast_probs + self.reuse_cached_forecast_probs = False for checkpoint_name in checkpoints: load_checkpoint(checkpoint_name) fit_result = self._fit_threshold_for_loaded_model(val_contexts, score_fn=score_fn) @@ -55,6 +81,7 @@ def _fit_with_model_checkpoint_selection(self, val_contexts, score_fn: Callable "best_threshold": float(fit_result["best_threshold"]), "best_val_accuracy": float(fit_result["best_val_accuracy"]), } + self.reuse_cached_forecast_probs = prior_reuse_flag if best_config is None: return None @@ -95,7 +122,7 @@ def _get_validation_arrays(self, val_contexts, score_fn: Callable): convo_labels = {} for context in tqdm(val_contexts): convo_id = context.conversation_id - score = float(score_fn(context)) + score = self._score(context, score_fn) label = int(self.labeler(context.current_utterance.get_conversation())) if convo_id not in highest_convo_scores: highest_convo_scores[convo_id] = score diff --git a/convokit/decisionpolicy/deferralDecisionPolicy.py b/convokit/decisionpolicy/deferralDecisionPolicy.py index 848e1154f..335eb2ae4 100644 --- a/convokit/decisionpolicy/deferralDecisionPolicy.py +++ b/convokit/decisionpolicy/deferralDecisionPolicy.py @@ -35,6 +35,11 @@ class DeferralDecisionPolicy(DecisionPolicy): and written to corpus utterance metadata by post_transform(). :param simulated_reply_attribute_name: metadata field name used when storing simulations on corpus utterances (only relevant when store_simulations=True). + :param reuse_cached_simulations: if True (default), simulations already present on the + current utterance's metadata under ``simulated_reply_attribute_name`` are reused + instead of re-invoking the simulator. similarly, cached simulation scores under + ``sim_replies_forecast_probs_attribute_name`` are reused when they align with the + reused simulations, skipping re-scoring. set to False to force regeneration. 
""" def __init__( @@ -46,8 +51,14 @@ def __init__( store_simulations: bool = False, simulated_reply_attribute_name: str = "sim_replies", sim_replies_forecast_probs_attribute_name: str = "sim_replies_forecast_probs", + reuse_cached_simulations: bool = True, + forecast_prob_attribute_name: str = "forecast_prob", + reuse_cached_forecast_probs: bool = True, ): - super().__init__() + super().__init__( + forecast_prob_attribute_name=forecast_prob_attribute_name, + reuse_cached_forecast_probs=reuse_cached_forecast_probs, + ) self.simulator = simulator self.threshold = float(threshold) self.tau = int(tau) @@ -58,10 +69,49 @@ def __init__( self.store_simulations = store_simulations self.simulated_reply_attribute_name = simulated_reply_attribute_name self.sim_replies_forecast_probs_attribute_name = sim_replies_forecast_probs_attribute_name + self.reuse_cached_simulations = bool(reuse_cached_simulations) self._sim_cache: dict = {} self._sim_score_cache: dict = {} + def _get_utt_meta(self, context): + # unified accessor so both real Utterance and _synthetic_utterance work; returns {} if absent. + return getattr(context.current_utterance, "meta", {}) or {} + + def _get_cached_simulations(self, context) -> Optional[List[str]]: + # returns cached simulation strings for this utterance if available on its metadata, else None. + if not self.reuse_cached_simulations: + return None + meta = self._get_utt_meta(context) + cached = meta.get(self.simulated_reply_attribute_name) + if cached is None: + return None + cached_list = list(cached) + if len(cached_list) == 0: + return None + return cached_list[: self.num_simulations] + + def _get_cached_simulation_scores( + self, context, num_expected: int + ) -> Optional[List[float]]: + # returns cached per-simulation scores aligned with reused simulations, else None. 
+ if not self.reuse_cached_simulations or num_expected == 0: + return None + meta = self._get_utt_meta(context) + cached = meta.get(self.sim_replies_forecast_probs_attribute_name) + if cached is None: + return None + cached_list = list(cached) + # if the cached scores are shorter than the reused simulations, fall back to re-scoring + # rather than silently mixing cached and fresh scores. + if len(cached_list) < num_expected: + return None + return [float(x) for x in cached_list[:num_expected]] + def get_simulations(self, context, simulator=None) -> List[str]: + # fast path: reuse pre-computed simulations from utterance metadata when present. + cached = self._get_cached_simulations(context) + if cached is not None: + return cached sim = simulator if simulator is not None else self.simulator if sim is None or not hasattr(sim, "transform"): return [] @@ -89,12 +139,21 @@ def _build_simulated_context(self, context, simulation_text: str, simulation_idx ) def _decision_score(self, context, score_fn: Callable): - current_score = float(score_fn(context)) + current_score = self._score(context, score_fn) simulations = self.get_simulations(context) - simulation_scores = [] - for idx, sim_text in enumerate(simulations): - sim_context = self._build_simulated_context(context, sim_text, idx) - simulation_scores.append(float(score_fn(sim_context))) + # the get_simulations method actively checks if cached simulations exist + + # fast path: if cached per-simulation scores align with the reused simulations, + # skip re-scoring the simulated contexts entirely. + cached_scores = self._get_cached_simulation_scores(context, len(simulations)) + if cached_scores is not None: + simulation_scores = cached_scores + else: + simulation_scores = [] + for idx, sim_text in enumerate(simulations): + sim_context = self._build_simulated_context(context, sim_text, idx) + # synthetic utterances have empty meta so _score falls through to score_fn. 
+ simulation_scores.append(self._score(sim_context, score_fn)) if self.store_simulations and simulations: utt_id = context.current_utterance.id self._sim_cache[utt_id] = simulations @@ -102,16 +161,27 @@ def _decision_score(self, context, score_fn: Callable): return current_score, simulations, simulation_scores def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]: + max_defer_index = 4 decision_score, simulations, simulation_scores = self._decision_score(context, score_fn) num_simulations_above_threshold = sum(1 for score in simulation_scores if score > self.threshold) - - # we want to return an awry prediction if we have - return (decision_score, - 1 if decision_score > self.threshold and num_simulations_above_threshold > 5 else 0, - { - self.simulated_reply_attribute_name: simulations, - self.sim_replies_forecast_probs_attribute_name: simulation_scores, - } + num_simulations = len(simulations) + # context.context contains chronological_utts[: i+1] (includes current_utterance), + # so the current utterance's position in the conversation is len(context.context) - 1. + utt_index = max(0, len(getattr(context, "context", []) or []) - 1) + # past the deferral window we always commit when fp > threshold, mirroring the + # `i < 4` early-only deferral in performance_utils.no_tricks. 
+ past_defer_window = max_defer_index is not None and utt_index >= max_defer_index + defer_eligible = not past_defer_window + num_calm = num_simulations - num_simulations_above_threshold + # defer = defer_eligible and (num_calm > self.tau) + defer = (num_calm > self.tau) + return ( + decision_score, + 1 if decision_score > self.threshold and not defer else 0, + { + self.simulated_reply_attribute_name: simulations, + self.sim_replies_forecast_probs_attribute_name: simulation_scores, + }, ) def fit(self, contexts, val_contexts=None, score_fn: Callable = None): diff --git a/convokit/decisionpolicy/thresholdDecisionPolicy.py b/convokit/decisionpolicy/thresholdDecisionPolicy.py index 44a8b0e3c..16df7ff2f 100644 --- a/convokit/decisionpolicy/thresholdDecisionPolicy.py +++ b/convokit/decisionpolicy/thresholdDecisionPolicy.py @@ -8,12 +8,21 @@ class ThresholdDecisionPolicy(DecisionPolicy): A simple decision policy that predicts 1 when score > threshold. """ - def __init__(self, threshold: float = 0.5): - super().__init__() + def __init__( + self, + threshold: float = 0.5, + forecast_prob_attribute_name: str = "forecast_prob", + reuse_cached_forecast_probs: bool = True, + ): + super().__init__( + forecast_prob_attribute_name=forecast_prob_attribute_name, + reuse_cached_forecast_probs=reuse_cached_forecast_probs, + ) self.threshold = float(threshold) def decide(self, context, score_fn: Callable) -> Tuple[float, int]: - return score_fn(context), int(score_fn(context) > self.threshold) + score = self._score(context, score_fn) + return score, int(score > self.threshold) def fit(self, contexts, val_contexts=None, score_fn: Callable = None): if val_contexts is None or score_fn is None or self.labeler is None: diff --git a/convokit/forecaster/CRAFTModel.py b/convokit/forecaster/CRAFTModel.py index e100e0a1e..80aa0b35d 100644 --- a/convokit/forecaster/CRAFTModel.py +++ b/convokit/forecaster/CRAFTModel.py @@ -344,6 +344,7 @@ def transform(self, contexts, forecast_attribute_name, 
forecast_prob_attribute_n # initialize the CRAFT model with whatever weights we currently have saved encoder, context_encoder, predictor = self._get_inference_components() + base_columns = {"id", forecast_attribute_name, forecast_prob_attribute_name} output_df = {"id": [], forecast_attribute_name: [], forecast_prob_attribute_name: []} batch_iterator = batchIterator( self._voc, test_pairs, self._config["batch_size"], shuffle=False @@ -393,10 +394,20 @@ def score_fn(scored_context): return score return self.score(scored_context) - pred = self.decision_policy.decide(context, score_fn) + utt_score, pred, utt_metadata = self._parse_decision_result( + self.decision_policy.decide(context, score_fn) + ) + current_idx = len(output_df["id"]) output_df["id"].append(utt_id) output_df[forecast_attribute_name].append(int(pred)) - output_df[forecast_prob_attribute_name].append(score) + output_df[forecast_prob_attribute_name].append(utt_score) + existing_metadata_keys = [key for key in output_df if key not in base_columns] + for key in existing_metadata_keys: + output_df[key].append(utt_metadata.get(key, None)) + for key, value in utt_metadata.items(): + if key not in output_df: + output_df[key] = [None] * current_idx + output_df[key].append(value) print( "Iteration: {}; Percent complete: {:.1f}%".format( iteration, iteration / n_iters * 100 diff --git a/convokit/forecaster/TransformerDecoderModel.py b/convokit/forecaster/TransformerDecoderModel.py index 449a1e4d9..a7d17782f 100644 --- a/convokit/forecaster/TransformerDecoderModel.py +++ b/convokit/forecaster/TransformerDecoderModel.py @@ -1,4 +1,4 @@ -from itertools import tee +from itertools import tee, islice import unsloth from unsloth import FastLanguageModel, is_bfloat16_supported from unsloth.chat_templates import get_chat_template @@ -249,8 +249,7 @@ def fit_belief_estimator(self, train_contexts, val_contexts=None): This method applies Low-Rank Adaptation (LoRA) to the decoder model, converts the training contexts into 
text-based input for LLM fine-tuning, and trains the model - using HuggingFace's `SFTTrainer`. After training, it tunes a decision threshold on - a held-out validation set to optimize binary forecast classification. + using HuggingFace's `SFTTrainer`. :param contexts: an iterator over context tuples, provided by the Forecaster framework :param val_contexts: an iterator over context tuples to be used only for validation. @@ -354,7 +353,16 @@ def _predict(self, context, threshold=None): if threshold is not None: utt_pred = int(utt_score > threshold) else: - utt_score, utt_pred, _ = self.decision_policy.decide(context, self.score) + result = self.decision_policy.decide(context, self.score) + if len(result) == 2: + utt_score, utt_pred = result + elif len(result) == 3: + utt_score, utt_pred, _ = result + else: + raise ValueError( + "decision_policy.decide() must return (utt_score, utt_pred) " + "or (utt_score, utt_pred, metadata_dict)" + ) return utt_score, utt_pred def fit(self, contexts, val_contexts=None): @@ -363,13 +371,14 @@ def fit(self, contexts, val_contexts=None): self.fit_decision_policy(contexts, val_contexts_decision_policy, score_fn=self.score) return - def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): + def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name, verbose=False): """ Generate forecasts using the fine-tuned TransformerDecoder model on the provided contexts, and save the predictions to the output directory specified in the configuration. 
:param contexts: context tuples from the Forecaster framework :param forecast_attribute_name: Forecaster will use this to look up the table column containing your model's discretized predictions (see output specification below) :param forecast_prob_attribute_name: Forecaster will use this to look up the table column containing your model's raw forecast probabilities (see output specification below) + :param verbose: if True, print verbose transform logging during the transformation :return: a Pandas DataFrame, with one row for each context, indexed by the ID of that context's current utterance. Contains two columns, one with raw probabilities named according to forecast_prob_attribute_name, and one with discretized (binary) forecasts named according to forecast_attribute_name """ @@ -378,8 +387,45 @@ def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_n preds = [] scores = [] metadatas = defaultdict(list) + # TODO(metrics): temporary running metric logging during transform; remove before merge. 
+ report_every_n = 250 + prediction_file = os.path.join(self.config.output_dir, "predictions.csv") + if os.path.exists(prediction_file): + os.remove(prediction_file) + next_flush_start = 0 + csv_header_written = False + convo_forecasts = {} + convo_labels = {} + + def _compute_conversation_metrics(): + common_convo_ids = [cid for cid in convo_forecasts if cid in convo_labels] + if len(common_convo_ids) == 0: + return None + tp = 0 + fp = 0 + tn = 0 + fn = 0 + for convo_id in common_convo_ids: + pred = int(convo_forecasts[convo_id] > 0) + label = int(convo_labels[convo_id]) + if label == 1 and pred == 1: + tp += 1 + elif label == 0 and pred == 1: + fp += 1 + elif label == 0 and pred == 0: + tn += 1 + elif label == 1 and pred == 0: + fn += 1 + n = len(common_convo_ids) + acc = (tp + tn) / n if n > 0 else 0.0 + p = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + r = tp / (tp + fn) if (tp + fn) > 0 else 0.0 + fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0 + f1 = (2 * p * r / (p + r)) if (p + r) > 0 else 0.0 + return {"n": n, "acc": acc, "p": p, "r": r, "fpr": fpr, "f1": f1} # for safety/flexibility we can accept either only score and pred or also the metadata - for context in tqdm(contexts): + progress = tqdm(contexts) + for idx, context in enumerate(progress, start=1): result = self.decision_policy.decide(context, self.score) if len(result) == 2: @@ -388,8 +434,10 @@ def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_n # no metadata elif len(result) == 3: utt_score, utt_pred, utt_metadata = result - for key in utt_metadata.keys(): - metadatas[key].append(utt_metadata.get(key, None)) + # coerce None metadata to {} so policies that return (score, pred, None) + # don't crash downstream utt_metadata.items() / .get() calls. 
+ if utt_metadata is None: + utt_metadata = {} else: raise ValueError( "decision_policy.decide() must return (utt_score, utt_pred) " @@ -398,15 +446,91 @@ def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_n utt_ids.append(context.current_utterance.id) preds.append(utt_pred) scores.append(utt_score) + current_idx = len(preds) - 1 + existing_metadata_keys = list(metadatas.keys()) + for key in existing_metadata_keys: + metadatas[key].append(utt_metadata.get(key, None)) + for key, value in utt_metadata.items(): + if key not in metadatas: + metadatas[key] = [None] * current_idx + metadatas[key].append(value) + + convo_id = getattr(context, "conversation_id", None) + try: + convo = context.current_utterance.get_conversation() + if convo_id is None and convo is not None: + convo_id = convo.id + if convo_id is not None: + if convo_id in convo_forecasts: + convo_forecasts[convo_id] = max(convo_forecasts[convo_id], int(utt_pred)) + else: + convo_forecasts[convo_id] = int(utt_pred) + if convo_id not in convo_labels: + convo_labels[convo_id] = int(self.labeler(convo)) + except Exception: + pass + + if idx % report_every_n == 0: + batch_cols = { + forecast_attribute_name: preds[next_flush_start:idx], + forecast_prob_attribute_name: scores[next_flush_start:idx], + } + for key, series in metadatas.items(): + batch_cols[key] = series[next_flush_start:idx] + batch_df = pd.DataFrame(batch_cols, index=utt_ids[next_flush_start:idx]) + batch_df.to_csv( + prediction_file, + mode="a" if csv_header_written else "w", + header=not csv_header_written, + ) + csv_header_written = True + next_flush_start = idx + + running_metrics = _compute_conversation_metrics() + if verbose: + if running_metrics is not None: + tqdm.write( + f"[info] transform metrics running: " + f"processed_contexts={idx}, conversations={running_metrics['n']}, " + f"acc={running_metrics['acc']:.4f}, p={running_metrics['p']:.4f}, " + f"r={running_metrics['r']:.4f}, 
fpr={running_metrics['fpr']:.4f}, " + f"f1={running_metrics['f1']:.4f}" + ) + else: + tqdm.write( + f"[info] transform metrics running: " + f"processed_contexts={idx}, conversations=0" + ) + total_processed = len(preds) + if total_processed > next_flush_start: + batch_cols = { + forecast_attribute_name: preds[next_flush_start:total_processed], + forecast_prob_attribute_name: scores[next_flush_start:total_processed], + } + for key, series in metadatas.items(): + batch_cols[key] = series[next_flush_start:total_processed] + batch_df = pd.DataFrame(batch_cols, index=utt_ids[next_flush_start:total_processed]) + batch_df.to_csv( + prediction_file, + mode="a" if csv_header_written else "w", + header=not csv_header_written, + ) + csv_header_written = True cols = { forecast_attribute_name: preds, forecast_prob_attribute_name: scores, } + final_metrics = _compute_conversation_metrics() + if final_metrics is not None: + tqdm.write( + f"[info] final transform metrics: " + f"processed_contexts={len(preds)}, conversations={final_metrics['n']}, " + f"acc={final_metrics['acc']:.4f}, p={final_metrics['p']:.4f}, " + f"r={final_metrics['r']:.4f}, fpr={final_metrics['fpr']:.4f}, " + f"f1={final_metrics['f1']:.4f}" + ) for key, series in metadatas.items(): assert len(series) == len(preds), "Metadata series length must match number of predictions" cols[key] = series # each series same length as preds forecasts_df = pd.DataFrame(cols, index=utt_ids) - - prediction_file = os.path.join(self.config.output_dir, "test_predictions.csv") - forecasts_df.to_csv(prediction_file) - return forecasts_df + return forecasts_df \ No newline at end of file diff --git a/convokit/forecaster/TransformerEncoderModel.py b/convokit/forecaster/TransformerEncoderModel.py index f2130ac4d..f7e8a07d6 100644 --- a/convokit/forecaster/TransformerEncoderModel.py +++ b/convokit/forecaster/TransformerEncoderModel.py @@ -14,6 +14,7 @@ from tqdm import tqdm from .forecasterModel import ForecasterModel from 
.TransformerForecasterConfig import TransformerForecasterConfig +from itertools import tee os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -283,13 +284,28 @@ def fit_belief_estimator(self, contexts, val_contexts=None): return def fit(self, contexts, val_contexts=None): - return super().fit(contexts, val_contexts) + val_contexts_belief_estimator, val_contexts_decision_policy = tee(val_contexts, 2) + self.fit_belief_estimator(contexts, val_contexts_belief_estimator) + self.fit_decision_policy(contexts, val_contexts_decision_policy, score_fn=self.score) + return def _predict(self, context, threshold=None): utt_score = self.score(context) + # keep threshold override for backward compatibility. if threshold is not None: - return utt_score, int(utt_score > threshold) - return utt_score, self.decision_policy.decide(context, self.score) + utt_pred = int(utt_score > threshold) + else: + result = self.decision_policy.decide(context, self.score) + if len(result) == 2: + utt_score, utt_pred = result + elif len(result) == 3: + utt_score, utt_pred, _ = result + else: + raise ValueError( + "decision_policy.decide() must return (utt_score, utt_pred) " + "or (utt_score, utt_pred, metadata_dict)" + ) + return utt_score, utt_pred def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): """ diff --git a/convokit/forecaster/forecaster.py b/convokit/forecaster/forecaster.py index c89dddaf9..0e0b485cb 100644 --- a/convokit/forecaster/forecaster.py +++ b/convokit/forecaster/forecaster.py @@ -50,6 +50,9 @@ def __init__( # also give the underlying ForecasterModel access to the labeler function self.forecaster_model.labeler = self.labeler + # keep the decision policy's forecast_prob cache key aligned with the + # meta field that Forecaster.transform() writes to. 
+ self.forecaster_model.forecast_prob_attribute_name = self.forecast_prob_attribute_name def _create_context_iterator( self, @@ -140,6 +143,7 @@ def transform( self, corpus: Corpus, context_selector: Callable[[ContextTuple], bool] = lambda context: True, + **kwargs, ) -> Corpus: """ Wrapper method for applying the underlying conversational forecasting model to make forecasts over the Conversations in a given Corpus. @@ -153,7 +157,7 @@ def transform( """ contexts = self._create_context_iterator(corpus, context_selector) forecast_df = self.forecaster_model.transform( - contexts, self.forecast_attribute_name, self.forecast_prob_attribute_name + contexts, self.forecast_attribute_name, self.forecast_prob_attribute_name, **kwargs, ) # generalize addition of metadata columns diff --git a/convokit/forecaster/forecasterModel.py b/convokit/forecaster/forecasterModel.py index 48d41ffa7..168d52786 100644 --- a/convokit/forecaster/forecasterModel.py +++ b/convokit/forecaster/forecasterModel.py @@ -18,6 +18,7 @@ class ForecasterModel(ABC): def __init__(self, decision_policy=None, **kwargs): self._labeler = None + self._forecast_prob_attribute_name = "forecast_prob" self._decision_policy = decision_policy or ThresholdDecisionPolicy() @property @@ -30,6 +31,18 @@ def labeler(self, value: Callable): if self._decision_policy is not None: self._decision_policy.labeler = value + @property + def forecast_prob_attribute_name(self) -> str: + return self._forecast_prob_attribute_name + + @forecast_prob_attribute_name.setter + def forecast_prob_attribute_name(self, value: str): + # keeps the decision policy's cache key aligned with the forecaster's + # meta field so policies can reuse previously written forecast probs. 
+ self._forecast_prob_attribute_name = value + if self._decision_policy is not None: + self._decision_policy.forecast_prob_attribute_name = value + @property def decision_policy(self): return self._decision_policy @@ -39,6 +52,9 @@ def decision_policy(self, value): self._decision_policy = value if self._decision_policy is not None: self._decision_policy.labeler = self._labeler + self._decision_policy.forecast_prob_attribute_name = ( + self._forecast_prob_attribute_name + ) @abstractmethod def fit(self, contexts, val_contexts=None): @@ -148,9 +164,26 @@ def _predict(self, context): This method is deprecated in favor of using the self.decision_policy.decide method. """ - utt_score, utt_pred = self.decision_policy.decide(context, self.score) + utt_score, utt_pred, _ = self._parse_decision_result( + self.decision_policy.decide(context, self.score) + ) return utt_score, utt_pred + def _parse_decision_result(self, result): + if len(result) == 2: + utt_score, utt_pred = result + utt_metadata = {} + elif len(result) == 3: + utt_score, utt_pred, utt_metadata = result + if utt_metadata is None: + utt_metadata = {} + else: + raise ValueError( + "decision_policy.decide() must return (utt_score, utt_pred) " + "or (utt_score, utt_pred, metadata_dict)" + ) + return float(utt_score), int(utt_pred), utt_metadata + @abstractmethod def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): """ From 924f909f3f0ca955698d7d76316474c03d499d91 Mon Sep 17 00:00:00 2001 From: laerdon Date: Tue, 28 Apr 2026 05:15:43 +0000 Subject: [PATCH 4/6] download config --- download_config.json | 6 +- .../decisionpolicy/decisionpolicy_demo.ipynb | 1217 +++++++++++++++++ 2 files changed, 1221 insertions(+), 2 deletions(-) create mode 100644 examples/decisionpolicy/decisionpolicy_demo.ipynb diff --git a/download_config.json b/download_config.json index 09bf149fb..3027a0d28 100644 --- a/download_config.json +++ b/download_config.json @@ -41,7 +41,8 @@ "ubuntu-chat-logs": 0, 
"contextual-abuse": 0, "news-interview": 0, - "emotional-support": 0 + "emotional-support": 0, + "decisionpolicy-demo": 0 }, "DatasetURLs": { "chromium-corpus": "http://zissou.infosci.cornell.edu/convokit/datasets/chromium-corpus/chromium-corpus.zip", @@ -115,7 +116,8 @@ "ubuntu-chat-logs": "https://zissou.infosci.cornell.edu/convokit/datasets/ubuntu-chat-logs/ubuntu-chat-logs.zip", "contextual-abuse": "https://zissou.infosci.cornell.edu/convokit/datasets/contextual-abuse/contextual-abuse.zip", "news-interview": "https://zissou.infosci.cornell.edu/convokit/datasets/news-interview/news-interview.zip", - "emotional-support": "https://zissou.infosci.cornell.edu/convokit/datasets/emotional-support/emotional-support.zip" + "emotional-support": "https://zissou.infosci.cornell.edu/convokit/datasets/emotional-support/emotional-support.zip", + "decisionpolicy-demo": "https://zissou.infosci.cornell.edu/convokit/datasets/decisionpolicy-demo/decisionpolicy-demo.zip" }, "ModelURLS": { "craft-wiki-pretrained": [ diff --git a/examples/decisionpolicy/decisionpolicy_demo.ipynb b/examples/decisionpolicy/decisionpolicy_demo.ipynb new file mode 100644 index 000000000..a3b44be6f --- /dev/null +++ b/examples/decisionpolicy/decisionpolicy_demo.ipynb @@ -0,0 +1,1217 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "90659cc6", + "metadata": {}, + "source": [ + "# Decision Policy Demo\n", + "\n", + "This notebook provides code demonstrating how to use Decision Policies, as introduced in the paper *Wait! There's a Way Out*. It also provides code for running the experiments in the paper. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "703021a8", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '2'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd8b87be", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import argparse\n", + "import sys\n", + "import glob\n", + "\n", + "from functools import partial\n", + "import json\n", + "from convokit import Corpus, Forecaster, download\n", + "from convokit.forecaster.TransformerDecoderModel import TransformerDecoderModel\n", + "from convokit.forecaster.TransformerForecasterConfig import TransformerForecasterConfig\n", + "from convokit.decisionpolicy import DeferralDecisionPolicy, ThresholdDecisionPolicy\n", + "from convokit.utterance_simulator.unslothUtteranceSimulatorModel import UnslothUtteranceSimulatorModel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f977d3", + "metadata": {}, + "outputs": [], + "source": [ + "# Set repo root\n", + "\n", + "from pathlib import Path\n", + "\n", + "repo_root = Path.cwd()\n", + "while repo_root.name != \"ConvoKit\":\n", + " repo_root = repo_root.parent\n", + "\n", + "repo_root = str(repo_root)" + ] + }, + { + "cell_type": "markdown", + "id": "dab74642", + "metadata": {}, + "source": [ + "Having imported our DeferralDecisionPolicy and ThresholdDecisionPolicy, we now will first define all of the other decision policies to benchmark. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "5dabf0bd", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Callable, List, Optional, Dict, Any, Tuple\n", + "import numpy as np\n", + "from convokit.decisionpolicy import DecisionPolicy\n", + "\n", + "\n", + "class _synthetic_speaker:\n", + " def __init__(self, speaker_id: str):\n", + " self.id = speaker_id\n", + "\n", + "\n", + "class _synthetic_utterance:\n", + " def __init__(self, text: str, utterance_id: str, speaker_id: str):\n", + " self.text = text\n", + " self.id = utterance_id\n", + " self.speaker_ = _synthetic_speaker(speaker_id)\n", + " self.meta = {}\n", + "\n", + " def get_conversation(self):\n", + " return None\n", + "\n", + "\n", + "class RandomDeferralDecisionPolicy(DecisionPolicy):\n", + " \"\"\"\n", + " Decision policy that defers intervention by looking ahead at simulated next utterances.\n", + "\n", + " :param simulator: utterance simulator model (must have a ``transform(contexts)`` method\n", + " returning a DataFrame indexed by utterance id). 
if the simulator exposes\n", + " ``get_num_simulations()``, ``num_simulations`` is capped to that value.\n", + " :param threshold: probability threshold above which a context is flagged.\n", + " :param tau: minimum number of simulated branches that must exceed the threshold\n", + " before an intervention is issued.\n", + " :param num_simulations: how many simulated branches to use per context (capped to\n", + " simulator's ``get_num_simulations()`` if available).\n", + " :param store_simulations: if True, simulated reply strings are cached during decide()\n", + " and written to corpus utterance metadata by post_transform().\n", + " :param simulated_reply_attribute_name: metadata field name used when storing simulations\n", + " on corpus utterances (only relevant when store_simulations=True).\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " simulator,\n", + " threshold,\n", + " deferral_probability: float = 0.1515,\n", + " reuse_cached_forecast_probs: bool = True,\n", + " forecast_prob_attribute_name: str = \"forecast_prob\",\n", + " ):\n", + " # forward the cache flag to the base class so its _score helper honors it.\n", + " # without this, reuse_cached_probabilities on this subclass had no effect.\n", + " super().__init__(\n", + " forecast_prob_attribute_name=forecast_prob_attribute_name,\n", + " reuse_cached_forecast_probs=reuse_cached_forecast_probs,\n", + " )\n", + " self.simulator = simulator\n", + " self.threshold = float(threshold)\n", + " self.deferral_probability = float(deferral_probability)\n", + "\n", + " def _decision_score(self, context, score_fn: Callable):\n", + " # use base _score so a cached forecast_prob on the utterance meta is reused\n", + " # instead of re-invoking the belief estimator.\n", + " return self._score(context, score_fn)\n", + "\n", + " def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]:\n", + " decision_score = self._score(context, score_fn)\n", + "\n", + " p = 
float(np.random.rand())\n", + " \n", + "\n", + " # return an empty metadata dict (not None) so downstream code that iterates\n", + " # utt_metadata.items() in TransformerDecoderModel.transform doesn't crash.\n", + " return (decision_score,\n", + " 1 if decision_score > self.threshold and p > self.deferral_probability else 0,\n", + " {}\n", + " )\n", + "\n", + " def fit(self, contexts, val_contexts=None, score_fn: Callable = None):\n", + " if val_contexts is None or score_fn is None or self.labeler is None:\n", + " print(\"either no validation contexts/score function/labeler were provided, returning current threshold\")\n", + " return {\"best_threshold\": self.threshold}\n", + "\n", + " val_contexts = list(val_contexts)\n", + " if len(val_contexts) == 0:\n", + " print(\"no validation contexts were provided, returning current threshold\")\n", + " return {\"best_threshold\": self.threshold}\n", + "\n", + " fit_result = self._fit_with_model_checkpoint_selection(val_contexts, score_fn=score_fn)\n", + " if isinstance(fit_result, dict):\n", + " if \"best_threshold\" in fit_result:\n", + " self.threshold = float(fit_result[\"best_threshold\"])\n", + " return fit_result\n", + "\n", + " fit_result = self._fit_threshold_for_loaded_model(val_contexts, score_fn=score_fn)\n", + " if \"best_threshold\" in fit_result:\n", + " self.threshold = float(fit_result[\"best_threshold\"])\n", + " return fit_result" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9d3db5bc", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Callable, Optional, Dict, Any, Tuple\n", + "\n", + "from convokit.decisionpolicy import DeferralDecisionPolicy\n", + "\n", + "\n", + "class SimulationAverageDecisionPolicy(DeferralDecisionPolicy):\n", + " \"\"\"\n", + " decision policy that intervenes if the mean of the simulated next-utterance\n", + " scores is at or above the threshold.\n", + "\n", + " this subclass inherits all simulation fetching, per-utterance metadata\n", + 
" caching, sim-score caching, and threshold fitting from\n", + " DeferralDecisionPolicy. the only differences are:\n", + " * no ``tau`` parameter (unused; forwarded as 0 to super)\n", + " * ``decide`` predicts based on mean(simulation_scores) >= threshold\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " simulator,\n", + " threshold,\n", + " num_simulations: int = 10,\n", + " store_simulations: bool = False,\n", + " simulated_reply_attribute_name: str = \"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name: str = \"sim_replies_forecast_probs\",\n", + " reuse_cached_simulations: bool = True,\n", + " ):\n", + " # tau is irrelevant for the mean-based decision rule, so we pin it to 0\n", + " # upstream rather than expose it to callers of this subclass.\n", + " super().__init__(\n", + " simulator=simulator,\n", + " threshold=threshold,\n", + " tau=0,\n", + " num_simulations=num_simulations,\n", + " store_simulations=store_simulations,\n", + " simulated_reply_attribute_name=simulated_reply_attribute_name,\n", + " sim_replies_forecast_probs_attribute_name=sim_replies_forecast_probs_attribute_name,\n", + " reuse_cached_simulations=reuse_cached_simulations,\n", + " )\n", + "\n", + " def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]:\n", + " decision_score, simulations, simulation_scores = self._decision_score(context, score_fn)\n", + " # empty simulation_scores would zero-divide. this happens when the\n", + " # simulator returns no completions for a context (e.g. end-of-conversation\n", + " # contexts that slip through the selector). 
treat as no intervention so\n", + " # a single degenerate context doesn't abort the whole transform run.\n", + " if len(simulation_scores) == 0:\n", + " average_simulation_score = 0.0\n", + " pred = 0\n", + " else:\n", + " average_simulation_score = sum(simulation_scores) / len(simulation_scores)\n", + " pred = 1 if average_simulation_score >= self.threshold else 0\n", + " return (\n", + " decision_score,\n", + " pred,\n", + " {\n", + " self.simulated_reply_attribute_name: simulations,\n", + " self.sim_replies_forecast_probs_attribute_name: simulation_scores,\n", + " },\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d0d1ea7a", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Callable, Optional, Dict, Any, Tuple\n", + "\n", + "from convokit.decisionpolicy import DeferralDecisionPolicy\n", + "\n", + "\n", + "class SimulationMajorityDecisionPolicy(DeferralDecisionPolicy):\n", + " \"\"\"\n", + " decision policy that intervenes if at least ``tau`` of the simulated next\n", + " utterances score above the threshold, ignoring the current utterance score.\n", + "\n", + " this subclass inherits all simulation fetching, per-utterance metadata\n", + " caching, sim-score caching, and threshold fitting from\n", + " DeferralDecisionPolicy. 
the only difference is in ``decide``: the gate\n", + " ``decision_score > threshold`` is dropped so that only the simulated-branch\n", + " vote count drives the prediction.\n", + " \"\"\"\n", + "\n", + " def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]:\n", + " decision_score, simulations, simulation_scores = self._decision_score(context, score_fn)\n", + " num_simulations_above_threshold = sum(\n", + " 1 for score in simulation_scores if score > self.threshold\n", + " )\n", + " return (\n", + " decision_score,\n", + " 1 if num_simulations_above_threshold >= self.tau else 0,\n", + " {\n", + " self.simulated_reply_attribute_name: simulations,\n", + " self.sim_replies_forecast_probs_attribute_name: simulation_scores,\n", + " },\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "id": "0498693e", + "metadata": {}, + "source": [ + "Since we use simulations in our baselines, we must define a simulation configuration as follows: " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "638fb31c", + "metadata": {}, + "outputs": [], + "source": [ + "SIMULATOR_TRAIN_CONFIG = {\n", + " \"per_device_train_batch_size\": 16,\n", + " \"per_device_eval_batch_size\": 16,\n", + " \"eval_strategy\": \"steps\",\n", + " \"save_strategy\": \"steps\",\n", + " \"save_steps\": 30,\n", + " \"gradient_accumulation_steps\": 4,\n", + " \"warmup_steps\": 5,\n", + " \"num_train_epochs\": 1,\n", + " \"eval_steps\": 30,\n", + " \"learning_rate\": 2e-4,\n", + " \"logging_steps\": 5,\n", + " \"optim\": \"adamw_8bit\",\n", + " \"weight_decay\": 0.01,\n", + " \"lr_scheduler_type\": \"linear\",\n", + " \"output_dir\": \"outputs/simulator_finetune\",\n", + " \"logging_dir\": \"logs\",\n", + " \"load_best_model_at_end\": True,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "bf55d140", + "metadata": {}, + "outputs": [], + "source": [ + "TAU = 7\n", + "DEFERRAL_PROBABILITY_THRESHOLD = 0.2518938553561718\n", + 
"NUM_SIMULATIONS = 10\n", + "OUTPUT_DIR = \"benchmark_preannotated\"\n", + "SEEDS = [1,2,3,4,5]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "aedd02fe", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==((====))== Unsloth 2025.7.11: Fast Llama patching. Transformers: 4.53.3.\n", + " \\\\ /| NVIDIA RTX A6000. Num GPUs = 1. Max memory: 47.536 GB. Platform: Linux.\n", + "O^O/ \\_/ \\ Torch: 2.7.1+cu128. CUDA: 8.6. CUDA Toolkit: 12.8. Triton: 3.1.0+cf34004b8a\n", + "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.31.post1. FA2 = False]\n", + " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", + "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Unsloth 2025.7.11 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.\n", + "Unsloth: Already have LoRA adapters! We shall skip this step.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Unsloth: Training embed_tokens in mixed precision to save VRAM\n", + "Unsloth: Training lm_head in mixed precision to save VRAM\n" + ] + } + ], + "source": [ + "simulator_model = UnslothUtteranceSimulatorModel(\n", + " model_name=\"/reef/lyk25/dynamic_training/game_analysis/outputs/checkpoint-74\",\n", + " train_config=SIMULATOR_TRAIN_CONFIG,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "fe147ae8", + "metadata": {}, + "source": [ + "After loading the simulator model, we define context selectors." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a9502041", + "metadata": {}, + "outputs": [], + "source": [ + "def context_selector(context_tuple, split):\n", + " \"\"\"\n", + " We use this generic function for the training, validation, and test data.\n", + " In all cases, its job is to select only those contexts for which the\n", + " FUTURE context is not empty, so we have a next utterance to predict.\n", + " \"\"\"\n", + " matches_split = (context_tuple.current_utterance.get_conversation().meta.get(\"split\") == split)\n", + " is_end = (len(context_tuple.future_context) == 0)\n", + " return matches_split and not is_end\n", + "\n", + "def make_data_selector(split):\n", + " return lambda context_tuple: context_tuple.current_utterance.get_conversation().meta.get(\"split\") == split\n", + "\n", + "train_context_selector = partial(context_selector, split=\"train\")\n", + "val_context_selector = partial(context_selector, split=\"val\")\n", + "test_context_selector = partial(context_selector, split=\"test\")" + ] + }, + { + "cell_type": "markdown", + "id": "f707a3a5", + "metadata": {}, + "source": [ + "Below is a script that fully reproduces the experiments, except for regenerating the simulations (which takes substantially longer)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dde130ed", + "metadata": {}, + "outputs": [], + "source": [ + "for seed_idx in SEEDS:\n", + " train_corpus = Corpus(filename=download('conversations-gone-awry-cmv-corpus-large'))\n", + " corpus = Corpus(filename=f'/reef/lyk25/theres-a-way-out-ACL26-internal/outputs/benchmark_preannotated/seed-{seed_idx}-ThresholdDecisionPolicy/ThresholdDecisionPolicy')\n", + " # corpus = Corpus(filename=f'/reef/lyk25/dynamic_training/game_analysis/corpi/test/test-son-seed-{seed_idx}')\n", + "\n", + " config = TransformerForecasterConfig(\n", + " output_dir=f\"outputs/{OUTPUT_DIR}/forecaster_{seed_idx}\",\n", + " per_device_batch_size=16,\n", + " gradient_accumulation_steps=1,\n", + " num_train_epochs=1,\n", + " learning_rate=1e-5,\n", + " random_seed=seed_idx,\n", + " context_mode=\"normal\",\n", + " device=\"cuda\",\n", + " )\n", + "\n", + " # TODO this will have to be edited\n", + " forecaster_model = TransformerDecoderModel(\n", + " model_name_or_path=\"google/gemma-2-9b-it\",\n", + " config=config,\n", + " )\n", + "\n", + " forecaster = Forecaster(\n", + " forecaster_model=forecaster_model,\n", + " labeler='has_removed_comment',\n", + " )\n", + "\n", + " forecaster.fit_belief_estimator(\n", + " corpus=train_corpus,\n", + " context_selector=train_context_selector,\n", + " val_context_selector=val_context_selector,\n", + " )\n", + "\n", + "\n", + "\n", + " # ---\n", + " cfg_path = os.path.join(repo_root, \"saves\", f\"seed-{seed_idx}\", \"dev_config.json\")\n", + " with open(cfg_path) as f:\n", + " cfg = json.load(f)\n", + " best_threshold = cfg['best_threshold']\n", + "\n", + " for policy_trial in [ThresholdDecisionPolicy, DeferralDecisionPolicy]:\n", + " print('---')\n", + " print(f\"Fitting policy {policy_trial.__name__} for seed {seed_idx}\")\n", + " if policy_trial == ThresholdDecisionPolicy:\n", + " policy = ThresholdDecisionPolicy(\n", + " threshold=best_threshold,\n", + " 
reuse_cached_forecast_probs=False,\n", + " )\n", + " elif policy_trial == DeferralDecisionPolicy:\n", + " policy = DeferralDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " tau=TAU,\n", + " reuse_cached_forecast_probs=False,\n", + " )\n", + " elif policy_trial == RandomDeferralDecisionPolicy:\n", + " policy = RandomDeferralDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " deferral_probability=DEFERRAL_PROBABILITY_THRESHOLD,\n", + " reuse_cached_forecast_probs=False,\n", + " )\n", + " elif policy_trial == SimulationAverageDecisionPolicy:\n", + " policy = SimulationAverageDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " num_simulations=NUM_SIMULATIONS,\n", + " store_simulations=False,\n", + " simulated_reply_attribute_name=\"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name=\"sim_replies_forecast_probs\",\n", + " reuse_cached_forecast_probs=False,\n", + " )\n", + " elif policy_trial == SimulationMajorityDecisionPolicy:\n", + " policy = SimulationMajorityDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " tau=TAU,\n", + " num_simulations=NUM_SIMULATIONS,\n", + " store_simulations=False,\n", + " simulated_reply_attribute_name=\"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name=\"sim_replies_forecast_probs\",\n", + " reuse_cached_forecast_probs=False,\n", + " )\n", + " \n", + " # attach the decision policy to the underlying forecaster model;\n", + " # Forecaster itself does not accept decision_policy in its constructor.\n", + " forecaster_model.decision_policy = policy\n", + "\n", + " forecaster = Forecaster(\n", + " forecaster_model=forecaster_model,\n", + " labeler='has_removed_comment',\n", + " )\n", + "\n", + " print('starting transformation.')\n", + " # evaluate the forecaster on the test set\n", + " forecaster.transform(\n", + " corpus=corpus,\n", + " 
context_selector=make_data_selector('test'),\n", + " verbose=True,\n", + " )\n", + " print('transformation complete.')\n", + "\n", + " output_dir = f\"outputs/{OUTPUT_DIR}/seed-{seed_idx}-{policy_trial.__name__}\"\n", + " os.makedirs(output_dir, exist_ok=True)\n", + " corpus.dump(name=f\"{policy_trial.__name__}\", base_path=output_dir)\n", + " print('corpus dumped.')\n", + "\n", + " print('starting summarization.')\n", + " # forecaster.summarize expects a conversation-level selector (Callable[[Conversation], bool]),\n", + " # unlike the context-tuple selectors used in fit/transform.\n", + " def summarize_selector(convo):\n", + " return convo.meta.get(\"split\") == \"test\"\n", + " conversational_forecasts_df, metrics = forecaster.summarize(\n", + " corpus=corpus,\n", + " selector=summarize_selector,\n", + " )\n", + " print('summarization complete.')\n", + " \n", + " # path to the seed output directory\n", + " seed_folder = f\"outputs/{OUTPUT_DIR}/seed-{seed_idx}-{policy_trial.__name__}\"\n", + "\n", + " # ensure the directory exists\n", + " os.makedirs(seed_folder, exist_ok=True)\n", + "\n", + " # save conversational_forecasts_df as CSV\n", + " conversational_forecasts_df.to_csv(os.path.join(seed_folder, \"conversational_forecasts.csv\"), index=False)\n", + "\n", + " # save metrics as JSON\n", + " with open(os.path.join(seed_folder, \"metrics.json\"), \"w\") as f:\n", + " json.dump(metrics, f, indent=2)" + ] + }, + { + "cell_type": "markdown", + "id": "e3c6f428", + "metadata": {}, + "source": [ + "A faster reproduction is possible by skipping the training and transformation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f4a29c2", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO insert code here to download all" + ] + }, + { + "cell_type": "markdown", + "id": "29a0cdc1", + "metadata": {}, + "source": [ + "# Human Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d673d0d6", + "metadata": {}, + "outputs": [], + "source": [ + "# import human data SQL here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d55b3bab", + "metadata": {}, + "outputs": [], + "source": [ + "import sqlite3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92766a80", + "metadata": {}, + "outputs": [], + "source": [ + "# connect to the game database\n", + "db_path = '/reef/sqt2/cga-eval/human/game_db.sqlite'\n", + "connection = sqlite3.connect(db_path)\n", + "cursor = connection.cursor()\n", + "\n", + "print(f\"[info] connected to database: {db_path}\")\n", + "cursor.execute(\"SELECT COUNT(*) FROM results\")\n", + "\n", + "# # names_list = []\n", + "# # answers_list = []\n", + "# # scores_list = []\n", + "# # comments_list = []\n", + "\n", + "# TIME_CUTOFF = 1714059814630 # timestamp to cut off previous results\n", + "# TIME_CUTOFF_END = 1730385762530\n", + "\n", + "# for row in connection.execute('SELECT * FROM results'):\n", + "# if row[0] != 'yc2727' and row[0] != 'sqt2':\n", + "# answers = json.loads(row[1])\n", + "# if answers[0]['start_time'] > TIME_CUTOFF and answers[0]['start_time'] < TIME_CUTOFF_END:\n", + "# names_list.append(row[0])\n", + "# answers_list.append(answers)\n", + "# scores_list.append(row[2])\n", + "# comments_list.append(row[3])\n", + "\n", + "if round_n == 1:\n", + " # for round_n 1\n", + " names_list_1 = []\n", + " answers_list_1 = []\n", + " scores_list_1 = []\n", + " comments_list_1 = []\n", + "\n", + " TIME_CUTOFF_1 = 1730395000000\n", + " TIME_CUTOFF_END_1 = 1730397199000\n", + "\n", + " for row in connection.execute('SELECT * 
FROM results'):\n", + " if row[0] != 'yc2727' and row[0] != 'sqt2':\n", + " answers = json.loads(row[1])\n", + " if answers[0]['start_time'] > TIME_CUTOFF_1 and answers[0]['start_time'] < TIME_CUTOFF_END_1:\n", + " names_list_1.append(row[0])\n", + " answers_list_1.append(answers)\n", + " scores_list_1.append(row[2])\n", + " comments_list_1.append(row[3])\n", + "elif round_n == 2:\n", + " # for round 2\n", + " names_list_2 = []\n", + " answers_list_2 = []\n", + " scores_list_2 = []\n", + " comments_list_2 = []\n", + "\n", + " TIME_CUTOFF_2_a = 1731002910000\n", + " TIME_CUTOFF_END_2_a = 1731016180000\n", + "\n", + " for row in connection.execute('SELECT * FROM results'):\n", + " if row[0] != 'yc2727' and row[0] != 'sqt2':\n", + " answers = json.loads(row[1])\n", + " if answers[0]['start_time'] > TIME_CUTOFF_2_a and answers[0]['start_time'] < TIME_CUTOFF_END_2_a:\n", + " names_list_2.append(row[0])\n", + " answers_list_2.append(answers)\n", + " scores_list_2.append(row[2])\n", + " comments_list_2.append(row[3])\n", + "\n", + " # The 2_b cutoff below is to specifically add data from 'ljl2' for a later second round window,\n", + " # likely because they submitted their data late or for a different time block.\n", + " TIME_CUTOFF_2_b = 1732212720000\n", + " TIME_CUTOFF_END_2_b = 1740000000000\n", + "\n", + " for row in connection.execute('SELECT * FROM results'):\n", + " if row[0].lower() == 'ljl2':\n", + " answers = json.loads(row[1])\n", + " if answers[0]['start_time'] > TIME_CUTOFF_2_b and answers[0]['start_time'] < TIME_CUTOFF_END_2_b:\n", + " names_list_2.append(row[0])\n", + " answers_list_2.append(answers)\n", + " scores_list_2.append(row[2])\n", + " comments_list_2.append(row[3])\n", + "\n", + " # print(f\"Round 1: {len(names_list_1)} entries\")\n", + " # print(f\"Round 2: {len(names_list_2)} entries\")\n", + " # names_list_1, names_list_2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "889d7ede", + "metadata": {}, + "outputs": [], + 
"source": [ + "human_map1 = {\n", + " \"vn72\": ['d3efg03', 'd3f07tv', 'f1353ra', 'f13u0n2', 'fegw71k', 'fegwvia', 'ikpw1ol', 'ikpxahv', 'isz5n6g', 'isz6a7m'],\n", + " \"nac86\": ['dyd7cfv', 'dydawov', 'ewtjxp2', 'ewtowqu', 'f00nxjs', 'f014doo', 'ii3cpve', 'ii4ivim', 'ikzb571', 'ikzcblg'],\n", + " # \"sj597\": ['dbnr5nd', 'dbnxbzn', 'g1sm9mt', 'g1spg0k', 'g297nn7', 'g2998qp', 'h2h7yl0', 'h2hauaj', 'ifknq44', 'ifkoazb'],\n", + " \"yc2727\": ['d08x3kl', 'd08x4q0', 'd3efg03', 'd3f07tv', 'ewtjxp2', 'ewtowqu', 'hnupywh', 'hnuu3ip', 'ih2fn94', 'ih3iwph'],\n", + " # \"ex36\": ['d0fyu9x', 'd0fzp40', 'dg7wmdb', 'dg7x3eu', 'ej6jusi', 'ej6ollb', 'ewtjxp2', 'ewtowqu', 'f0dryv4', 'f0dsbw4'],\n", + " \"kz88\": ['doi3vuz', 'doi44j8', 'e9mybxp', 'e9mynuy', 'g1q1973', 'g1q3upb', 'ggwdbsa', 'ggwei5y', 'gmxoyuu', 'gmxqrdz'],\n", + " \"LJL2\": ['ffpj88q', 'ffpk80c', 'flixm8f', 'fljdrtt', 'fmre7l9', 'fmspcex', 'fphcwq0', 'g3hmlew', 'givok1i', 'givp578'],\n", + " \"lyk25\": ['dpirnlc', 'dpkib9q', 'e1vstxd', 'e1vv6x1', 'ewr7ls3', 'ewr7msm', 'fixjxxs', 'fixlxvy', 'fmre7l9', 'fmspcex'],\n", + " \"sqt2\": ['d3efg03', 'd3f07tv', 'g0fwpzc', 'g0ggz8e', 'g8nyz4f', 'g8nzx3m', 'h4i75b9', 'h4i79lc', 'h6ikmzc', 'h6ime20'],\n", + " \"cd326\": ['djh1bt1', 'djh2v24', 'dmlny27', 'dmlqnzz', 'f83akyz', 'f83hb2k', 'h28lknz', 'h28ltfw', 'i4rgtqf', 'i4rh1id'],\n", + " \"tg352\": ['dm8exht', 'dm8qu78', 'dvisfl0', 'dvivs9p', 'dyx2jn4', 'dyx3au1', 'gvb1ekl', 'gvdnrrb', 'i4rgtqf', 'i4rh1id'],\n", + "}\n", + "\n", + "# all_convo_ids = []\n", + "# for k,v in human_map1.items():\n", + "# all_convo_ids.extend(v)\n", + "# all_convo_ids = list(set(all_convo_ids))\n", + "# len(all_convo_ids)\n", + "\n", + "all_convo_ids = []\n", + "# # collect from round 1\n", + "if round_n == 1:\n", + " for i in range(len(answers_list_1)):\n", + " for j in range(len(answers_list_1[i])):\n", + " all_convo_ids.append(answers_list_1[i][j]['id'])\n", + "# collect from round 2\n", + "elif round_n == 2:\n", + " for i in 
range(len(answers_list_2)):\n", + " for j in range(len(answers_list_2[i])):\n", + " all_convo_ids.append(answers_list_2[i][j]['id'])\n", + "all_convo_ids = list(set(all_convo_ids))\n", + "print(f\"[info] total unique conversation ids: {len(all_convo_ids)}\")\n", + "print(all_convo_ids)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "168fb025", + "metadata": {}, + "outputs": [], + "source": [ + "# we want all utterances to have a human_guesses meta field\n", + "for convo in corpus.iter_conversations():\n", + " for utt in convo.get_chronological_utterance_list():\n", + " utt.add_meta('human_guesses', [])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27fc646a", + "metadata": {}, + "outputs": [], + "source": [ + "# now add the human guesses to the utterances, tagged with their player ids\n", + "\n", + "# record every positive guess per session, using answers_list_1 or answers_list_2 depending on round_n\n", + "\n", + "# from collections import defaultdict\n", + "\n", + "# for i, participant_sessions in enumerate(answers_list):\n", + "# participant_id = names_list[i]\n", + "# print(participant_id)\n", + "# for session in participant_sessions:\n", + "# convo_id = session['id']\n", + "# convo = corpus.get_conversation(convo_id)\n", + "# utts = convo.get_chronological_utterance_list()\n", + "# actions = session.get('actions', [])\n", + "\n", + "# for action_idx, action in enumerate(actions):\n", + "# if action.get('guess') is True:\n", + "# utt = utts[action_idx]\n", + "# # get current guesses (returns deep copy if exists, or empty list if not)\n", + "# # create a new list to avoid mutating the copy\n", + "# current_guesses = list(utt.meta.get('human_guesses', []))\n", + "# # append new participant and set back\n", + "# current_guesses.append(participant_id)\n", + "# utt.add_meta('human_guesses', current_guesses)\n", + "if round_n == 1:\n", + " for i, participant_sessions in enumerate(answers_list_1):\n", + " participant_id = names_list_1[i]\n", + " 
print(participant_id)\n", + " for session in participant_sessions:\n", + " convo_id = session['id']\n", + " convo = corpus.get_conversation(convo_id)\n", + " utts = convo.get_chronological_utterance_list()\n", + " actions = session.get('actions', [])\n", + "\n", + " for action_idx, action in enumerate(actions):\n", + " if action.get('guess') is True:\n", + " utt = utts[action_idx]\n", + " # get current guesses (returns deep copy if exists, or empty list if not)\n", + " # create a new list to avoid mutating the copy\n", + " current_guesses = list(utt.meta.get('human_guesses', []))\n", + " # append new participant and set back\n", + " current_guesses.append(participant_id)\n", + " utt.add_meta('human_guesses', current_guesses)\n", + "elif round_n == 2:\n", + " for i, participant_sessions in enumerate(answers_list_2):\n", + " participant_id = names_list_2[i]\n", + " print(participant_id)\n", + " for session in participant_sessions:\n", + " convo_id = session['id']\n", + " convo = corpus.get_conversation(convo_id)\n", + " utts = convo.get_chronological_utterance_list()\n", + " actions = session.get('actions', [])\n", + "\n", + " for action_idx, action in enumerate(actions):\n", + " if action.get('guess') is True:\n", + " utt = utts[action_idx]\n", + " # get current guesses (returns deep copy if exists, or empty list if not)\n", + " # create a new list to avoid mutating the copy\n", + " current_guesses = list(utt.meta.get('human_guesses', []))\n", + " # append new participant and set back\n", + " current_guesses.append(participant_id)\n", + " utt.add_meta('human_guesses', current_guesses)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd613b4b", + "metadata": {}, + "outputs": [], + "source": [ + "# initialize model prediction metadata fields for each utterance\n", + "for convo in corpus.iter_conversations():\n", + " for utt in convo.get_chronological_utterance_list():\n", + " utt.add_meta('model_forecast_probs', {}) # will store {seed: 
prob}\n", + " utt.add_meta('model_forecasts', {}) # will store {seed: binary_forecast}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "873cb706", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO need to choose which decisionpolicy/forecaster run to use to get the forecast probabilities\n", + "\n", + "# load model predictions from all 5 seeds\n", + "test_corpi_path = '/reef/lyk25/dynamic_training/game_analysis/corpi/test'\n", + "\n", + "for seed_num in range(1, 6):\n", + " seed_corpus_path = f'{test_corpi_path}/test-son-seed-{seed_num}'\n", + " print(f\"[info] loading predictions from seed {seed_num}: {seed_corpus_path}\")\n", + " \n", + " # load the seed corpus\n", + " seed_corpus = Corpus(filename=seed_corpus_path)\n", + " \n", + " # iterate through conversations in our main corpus\n", + " for convo in corpus.iter_conversations():\n", + " convo_id = convo.id\n", + " \n", + " # check if this conversation exists in the seed corpus\n", + " if convo_id not in seed_corpus.conversations:\n", + " print(f\"[warning] convo {convo_id} not found in seed {seed_num}\")\n", + " continue\n", + " \n", + " # get the corresponding conversation from seed corpus\n", + " seed_convo = seed_corpus.get_conversation(convo_id)\n", + " \n", + " # get utterances from both corpora\n", + " main_utts = convo.get_chronological_utterance_list()\n", + " seed_utts = seed_convo.get_chronological_utterance_list()\n", + " \n", + " # match utterances and copy predictions\n", + " for main_utt, seed_utt in zip(main_utts, seed_utts):\n", + " # verify they're the same utterance\n", + " if main_utt.id != seed_utt.id:\n", + " print(f\"[warning] utterance mismatch: {main_utt.id} != {seed_utt.id}\")\n", + " continue\n", + " \n", + " # extract predictions from seed utterance\n", + " forecast_prob = seed_utt.meta.get('forecast_prob', None)\n", + " forecast = seed_utt.meta.get('forecast', None)\n", + " \n", + " # add to main utterance's metadata\n", + " if forecast_prob is not 
None:\n", + " current_probs = dict(main_utt.meta.get('model_forecast_probs', {}))\n", + " current_probs[f'seed_{seed_num}'] = forecast_prob\n", + " main_utt.add_meta('model_forecast_probs', current_probs)\n", + " \n", + " if forecast is not None:\n", + " current_forecasts = dict(main_utt.meta.get('model_forecasts', {}))\n", + " current_forecasts[f'seed_{seed_num}'] = forecast\n", + " main_utt.add_meta('model_forecasts', current_forecasts)\n", + "\n", + "print(\"\\n[pass] model predictions from all seeds added to corpus\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b0f0e44", + "metadata": {}, + "outputs": [], + "source": [ + "# horizon comparison between gemma and humans, using the same convention as\n", + "# performance_utils_wiki.calculate_current_performance:\n", + "# - only count true positives (ground truth awry AND trigger fired)\n", + "# - horizon = (len(utts) - first_trigger_index) - 1\n", + "# - first trigger only (break after first positive)\n", + "#\n", + "# reports per-round results for humans (round 1 and round 2), pulling data\n", + "# directly from the game sqlite so it works regardless of which round_n the\n", + "# rest of the notebook was run under. ljl2 is included via the late round-2 window; only yc2727 and sqt2 are excluded.\n", + "\n", + "import sqlite3\n", + "import json as _json\n", + "from collections import defaultdict\n", + "import numpy as np\n", + "\n", + "def _horizon_mean(horizons):\n", + " # matches performance_utils_wiki: h = mean(horizons) - 1\n", + " if len(horizons) == 0:\n", + " return float('nan'), 0\n", + " return float(np.mean(horizons)) - 1, len(horizons)\n", + "\n", + "# always excluded (staff/test accounts). 
ljl2 is a legitimate late submitter\n", + "# in the round 2_b window, not an exclusion.\n", + "base_excluded = {'yc2727', 'sqt2'}\n", + "round1_excluded = set(base_excluded)\n", + "round2_excluded = set(base_excluded)\n", + "\n", + "# gemma: per-seed first-trigger horizons on TP convos only (ground truth awry)\n", + "gemma_horizons_by_seed = defaultdict(list)\n", + "for convo in corpus.iter_conversations():\n", + " if not convo.meta.get('has_removed_comment', False):\n", + " continue # skip non-awry convos; horizon is only meaningful on TPs\n", + " utts = convo.get_chronological_utterance_list()\n", + " for seed_num in range(1, 6):\n", + " for i, utt in enumerate(utts):\n", + " if utt.meta.get('model_forecasts', {}).get(f'seed_{seed_num}') == 1:\n", + " gemma_horizons_by_seed[seed_num].append(len(utts) - i)\n", + " break\n", + "\n", + "print('gemma horizon (TPs only, mean - 1):')\n", + "gemma_per_seed_h = []\n", + "for seed_num in sorted(gemma_horizons_by_seed.keys()):\n", + " h, n = _horizon_mean(gemma_horizons_by_seed[seed_num])\n", + " gemma_per_seed_h.append(h)\n", + " print(f' seed {seed_num}: n={n}, h={h:.4f}')\n", + "gemma_h_mean_of_seeds = float(np.mean(gemma_per_seed_h)) if gemma_per_seed_h else float('nan')\n", + "all_gemma_vals = [v for vs in gemma_horizons_by_seed.values() for v in vs]\n", + "gemma_h_pooled, gemma_n_pooled = _horizon_mean(all_gemma_vals)\n", + "print(f' mean over seeds: h={gemma_h_mean_of_seeds:.4f}')\n", + "print(f' pooled: n={gemma_n_pooled}, h={gemma_h_pooled:.4f}')\n", + "\n", + "# pull round-specific human guesses directly from the sqlite db\n", + "db_path = '/reef/sqt2/cga-eval/human/game_db.sqlite'\n", + "_conn = sqlite3.connect(db_path)\n", + "\n", + "def _load_round_answers(round_n):\n", + " # returns list of (participant_id, sessions) for the given round\n", + " entries = []\n", + " if round_n == 1:\n", + " excluded = round1_excluded\n", + " t_start, t_end = 1730395000000, 1730397199000\n", + " for row in _conn.execute('SELECT 
* FROM results'):\n", + " if row[0] in excluded:\n", + " continue\n", + " answers = _json.loads(row[1])\n", + " if answers and answers[0]['start_time'] > t_start and answers[0]['start_time'] < t_end:\n", + " entries.append((row[0], answers))\n", + " elif round_n == 2:\n", + " excluded = round2_excluded\n", + " t_start_a, t_end_a = 1731002910000, 1731016180000\n", + " t_start_b, t_end_b = 1732212720000, 1740000000000\n", + " for row in _conn.execute('SELECT * FROM results'):\n", + " name = row[0]\n", + " if name in excluded:\n", + " continue\n", + " answers = _json.loads(row[1])\n", + " if not answers:\n", + " continue\n", + " st = answers[0]['start_time']\n", + " in_a = t_start_a < st < t_end_a\n", + " # the 2_b late window is only valid for ljl2 (matches cell 7)\n", + " in_b = (name.lower() == 'ljl2') and (t_start_b < st < t_end_b)\n", + " if in_a or in_b:\n", + " entries.append((name, answers))\n", + " return entries\n", + "\n", + "def _unique_convo_ids(entries):\n", + " # unique convo ids seen across all included players' sessions\n", + " ids = set()\n", + " for _, sessions in entries:\n", + " for session in sessions:\n", + " cid = session.get('id')\n", + " if cid is not None:\n", + " ids.add(cid)\n", + " return ids\n", + "\n", + "def _compute_human_horizons(entries):\n", + " # returns dict: player_id -> list of horizons (awry convos only, first guess per convo)\n", + " by_player = defaultdict(list)\n", + " for participant_id, sessions in entries:\n", + " for session in sessions:\n", + " convo_id = session.get('id')\n", + " if convo_id is None:\n", + " continue\n", + " try:\n", + " convo = corpus.get_conversation(convo_id)\n", + " except KeyError:\n", + " # convo not in the loaded corpus (e.g., different round loaded)\n", + " continue\n", + " if not convo.meta.get('has_removed_comment', False):\n", + " continue # skip non-awry convos; horizon only counted on TPs\n", + " utts = convo.get_chronological_utterance_list()\n", + " for action_idx, action in 
enumerate(session.get('actions', [])):\n", + " if action.get('guess') is True:\n", + " if action_idx < len(utts):\n", + " by_player[participant_id].append(len(utts) - action_idx)\n", + " break # only first guess per (player, convo)\n", + " return by_player\n", + "\n", + "def _print_human_horizons(label, by_player):\n", + " print(f'human horizon ({label}, TPs only, mean - 1):')\n", + " player_h = []\n", + " for player, horizons in by_player.items():\n", + " h, n = _horizon_mean(horizons)\n", + " player_h.append(h)\n", + " print(f' player {player}: n={n}, h={h:.4f}')\n", + " mean_of_players = float(np.mean(player_h)) if player_h else float('nan')\n", + " all_vals = [v for vs in by_player.values() for v in vs]\n", + " h_pooled, n_pooled = _horizon_mean(all_vals)\n", + " print(f' mean over players: h={mean_of_players:.4f}')\n", + " print(f' pooled: n={n_pooled}, h={h_pooled:.4f}')\n", + " return mean_of_players, h_pooled\n", + "\n", + "print()\n", + "EXPECTED_N_CONVOS = 84\n", + "round1_entries = _load_round_answers(1)\n", + "round1_convo_ids = _unique_convo_ids(round1_entries)\n", + "print(f'round 1: {len(round1_entries)} included players, {len(round1_convo_ids)} unique convos')\n", + "assert len(round1_convo_ids) == EXPECTED_N_CONVOS, (\n", + " f'round 1 expected {EXPECTED_N_CONVOS} unique convos but got {len(round1_convo_ids)}'\n", + ")\n", + "round1_by_player = _compute_human_horizons(round1_entries)\n", + "r1_mean, r1_pooled = _print_human_horizons('round 1', round1_by_player)\n", + "\n", + "print()\n", + "round2_entries = _load_round_answers(2)\n", + "round2_convo_ids = _unique_convo_ids(round2_entries)\n", + "print(f'round 2: {len(round2_entries)} included players, {len(round2_convo_ids)} unique convos')\n", + "assert len(round2_convo_ids) == EXPECTED_N_CONVOS, (\n", + " f'round 2 expected {EXPECTED_N_CONVOS} unique convos but got {len(round2_convo_ids)}'\n", + ")\n", + "round2_by_player = _compute_human_horizons(round2_entries)\n", + "r2_mean, r2_pooled = 
_print_human_horizons('round 2', round2_by_player)\n", + "\n", + "_conn.close()\n", + "\n", + "print()\n", + "print('summary (comparable to h in performance_utils_wiki):')\n", + "print(f' gemma h (mean over seeds): {gemma_h_mean_of_seeds:.4f}')\n", + "print(f' gemma h (pooled): {gemma_h_pooled:.4f}')\n", + "print(f' round1 h (mean over players): {r1_mean:.4f}')\n", + "print(f' round1 h (pooled): {r1_pooled:.4f}')\n", + "print(f' round2 h (mean over players): {r2_mean:.4f}')\n", + "print(f' round2 h (pooled): {r2_pooled:.4f}')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44ed25ac", + "metadata": {}, + "outputs": [], + "source": [ + "# benchmark humans, mirroring the gemma benchmark (cell 40).\n", + "# computes accuracy, precision, recall, f1, fpr, specificity, fnr for:\n", + "# - gemma (aggregated over seeds 1-5)\n", + "# - round 1 humans (aggregate \"any\" rule, plus per-player)\n", + "# - round 2 humans (aggregate \"any\" rule, plus per-player)\n", + "# uses the included entries loaded in the previous cell:\n", + "# round1_entries, round2_entries\n", + "# and the corpus in memory for ground truth.\n", + "from collections import defaultdict\n", + "import numpy as np\n", + "def _compute_metrics(tp, fp, tn, fn):\n", + " total = tp + fp + tn + fn\n", + " accuracy = (tp + tn) / total if total > 0 else 0.0\n", + " precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0\n", + " recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0\n", + " f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0\n", + " fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0\n", + " specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0\n", + " fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0\n", + " return {\n", + " 'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn,\n", + " 'accuracy': accuracy, 'precision': precision, 'recall': recall,\n", + " 'f1': f1, 'fpr': fpr, 'specificity': specificity, 'fnr': fnr,\n", + " }\n", + "def _gt(convo):\n", + " return 
bool(convo.meta.get('has_removed_comment', False))\n", + "# gemma benchmark, per seed, then mean and std\n", + "gemma_per_seed = []\n", + "for seed_num in range(1, 6):\n", + " tp = fp = tn = fn = 0\n", + " for convo in corpus.iter_conversations():\n", + " utts = convo.get_chronological_utterance_list()\n", + " pred = any(\n", + " utt.meta.get('model_forecasts', {}).get(f'seed_{seed_num}') == 1\n", + " for utt in utts\n", + " )\n", + " truth = _gt(convo)\n", + " if pred and truth: tp += 1\n", + " elif pred and not truth: fp += 1\n", + " elif not pred and not truth: tn += 1\n", + " elif not pred and truth: fn += 1\n", + " gemma_per_seed.append(_compute_metrics(tp, fp, tn, fn))\n", + "gemma_mean = {k: float(np.mean([m[k] for m in gemma_per_seed])) for k in gemma_per_seed[0]}\n", + "gemma_std = {k: float(np.std([m[k] for m in gemma_per_seed], ddof=1)) for k in gemma_per_seed[0]}\n", + "def _build_per_player_preds(entries):\n", + " # returns dict[player_id] -> dict[convo_id] -> 0/1\n", + " preds = defaultdict(dict)\n", + " for player, sessions in entries:\n", + " for s in sessions:\n", + " cid = s.get('id')\n", + " if cid is None:\n", + " continue\n", + " try:\n", + " corpus.get_conversation(cid)\n", + " except KeyError:\n", + " continue\n", + " actions = s.get('actions', [])\n", + " preds[player][cid] = 1 if any(a.get('guess') is True for a in actions) else 0\n", + " return preds\n", + "def _aggregate_any(preds):\n", + " # per-convo prediction = 1 if any player who saw it guessed\n", + " by_convo = defaultdict(list)\n", + " for player, convo_preds in preds.items():\n", + " for cid, p in convo_preds.items():\n", + " by_convo[cid].append(p)\n", + " tp = fp = tn = fn = 0\n", + " for cid, plist in by_convo.items():\n", + " try:\n", + " convo = corpus.get_conversation(cid)\n", + " except KeyError:\n", + " continue\n", + " pred = 1 if any(plist) else 0\n", + " truth = 1 if _gt(convo) else 0\n", + " if pred and truth: tp += 1\n", + " elif pred and not truth: fp += 1\n", + " 
elif not pred and not truth: tn += 1\n", + " elif not pred and truth: fn += 1\n", + " return _compute_metrics(tp, fp, tn, fn), len(by_convo)\n", + "def _per_player_metrics(preds):\n", + " rows = {}\n", + " for player, convo_preds in preds.items():\n", + " tp = fp = tn = fn = 0\n", + " for cid, p in convo_preds.items():\n", + " try:\n", + " convo = corpus.get_conversation(cid)\n", + " except KeyError:\n", + " continue\n", + " truth = 1 if _gt(convo) else 0\n", + " if p and truth: tp += 1\n", + " elif p and not truth: fp += 1\n", + " elif not p and not truth: tn += 1\n", + " elif not p and truth: fn += 1\n", + " rows[player] = _compute_metrics(tp, fp, tn, fn)\n", + " return rows\n", + "# round 1 and round 2 humans\n", + "round1_preds = _build_per_player_preds(round1_entries)\n", + "round1_agg, round1_n_convos = _aggregate_any(round1_preds)\n", + "round1_per_player = _per_player_metrics(round1_preds)\n", + "round2_preds = _build_per_player_preds(round2_entries)\n", + "round2_agg, round2_n_convos = _aggregate_any(round2_preds)\n", + "round2_per_player = _per_player_metrics(round2_preds)\n", + "# metric mean over players (a different aggregation view)\n", + "def _mean_over_players(per_player):\n", + " keys = ['accuracy', 'precision', 'recall', 'f1', 'fpr', 'specificity', 'fnr']\n", + " return {k: float(np.mean([m[k] for m in per_player.values()])) for k in keys}\n", + "round1_mean_over_players = _mean_over_players(round1_per_player)\n", + "round2_mean_over_players = _mean_over_players(round2_per_player)\n", + "# printing\n", + "metric_order = ['accuracy', 'precision', 'recall', 'f1', 'fpr', 'specificity', 'fnr']\n", + "def _print_per_player(label, per_player):\n", + " print(f'{label} - per-player metrics:')\n", + " header = f\" {'player':<10}\" + \"\".join(f\"{m:>12}\" for m in metric_order) + f\"{'n_convos':>10}\"\n", + " print(header)\n", + " print(' ' + '-' * (len(header) - 2))\n", + " for player in sorted(per_player):\n", + " m = per_player[player]\n", + " n = 
m['tp'] + m['fp'] + m['tn'] + m['fn']\n", + " row = f\" {player:<10}\" + \"\".join(f\"{m[k]:>12.4f}\" for k in metric_order) + f\"{n:>10d}\"\n", + " print(row)\n", + "_print_per_player('round 1', round1_per_player)\n", + "print()\n", + "_print_per_player('round 2', round2_per_player)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "lyk25-env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From e448225da65af060f2bc9fe4235005a74cb09341 Mon Sep 17 00:00:00 2001 From: laerdon Date: Tue, 28 Apr 2026 07:35:48 +0000 Subject: [PATCH 5/6] update notebook --- .../decisionpolicy/decisionpolicy_demo.ipynb | 3738 +++++++++++------ 1 file changed, 2530 insertions(+), 1208 deletions(-) diff --git a/examples/decisionpolicy/decisionpolicy_demo.ipynb b/examples/decisionpolicy/decisionpolicy_demo.ipynb index a3b44be6f..79caffd36 100644 --- a/examples/decisionpolicy/decisionpolicy_demo.ipynb +++ b/examples/decisionpolicy/decisionpolicy_demo.ipynb @@ -1,1217 +1,2539 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "90659cc6", - "metadata": {}, - "source": [ - "# Decision Policy Demo\n", - "\n", - "This notebook will provide code demonstrating how to use Decision Policies as introduced in Wait! There's a Way Out. This notebook will also provide code for running the experiments in the paper. 
" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "703021a8", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "os.environ['CUDA_VISIBLE_DEVICES'] = '2'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fd8b87be", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import argparse\n", - "import sys\n", - "import glob\n", - "\n", - "from functools import partial\n", - "import json\n", - "from convokit import Corpus, Forecaster, download\n", - "from convokit.forecaster.TransformerDecoderModel import TransformerDecoderModel\n", - "from convokit.forecaster.TransformerForecasterConfig import TransformerForecasterConfig\n", - "from convokit.decisionpolicy import DeferralDecisionPolicy, ThresholdDecisionPolicy\n", - "from convokit.utterance_simulator.unslothUtteranceSimulatorModel import UnslothUtteranceSimulatorModel" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "15f977d3", - "metadata": {}, - "outputs": [], - "source": [ - "# Set repo root\n", - "\n", - "from pathlib import Path\n", - "\n", - "repo_root = Path.cwd()\n", - "while repo_root.name != \"ConvoKit\":\n", - " repo_root = repo_root.parent\n", - "\n", - "repo_root = str(repo_root)" - ] - }, - { - "cell_type": "markdown", - "id": "dab74642", - "metadata": {}, - "source": [ - "Having imported our DeferralDecisionPolicy and ThresholdDecisionPolicy, we now will first define all of the other decision policies to benchmark. 
" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "5dabf0bd", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Callable, List, Optional, Dict, Any, Tuple\n", - "import numpy as np\n", - "from convokit.decisionpolicy import DecisionPolicy\n", - "\n", - "\n", - "class _synthetic_speaker:\n", - " def __init__(self, speaker_id: str):\n", - " self.id = speaker_id\n", - "\n", - "\n", - "class _synthetic_utterance:\n", - " def __init__(self, text: str, utterance_id: str, speaker_id: str):\n", - " self.text = text\n", - " self.id = utterance_id\n", - " self.speaker_ = _synthetic_speaker(speaker_id)\n", - " self.meta = {}\n", - "\n", - " def get_conversation(self):\n", - " return None\n", - "\n", - "\n", - "class RandomDeferralDecisionPolicy(DecisionPolicy):\n", - " \"\"\"\n", - " Decision policy that defers intervention by looking ahead at simulated next utterances.\n", - "\n", - " :param simulator: utterance simulator model (must have a ``transform(contexts)`` method\n", - " returning a DataFrame indexed by utterance id). 
if the simulator exposes\n", - " ``get_num_simulations()``, ``num_simulations`` is capped to that value.\n", - " :param threshold: probability threshold above which a context is flagged.\n", - " :param tau: minimum number of simulated branches that must exceed the threshold\n", - " before an intervention is issued.\n", - " :param num_simulations: how many simulated branches to use per context (capped to\n", - " simulator's ``get_num_simulations()`` if available).\n", - " :param store_simulations: if True, simulated reply strings are cached during decide()\n", - " and written to corpus utterance metadata by post_transform().\n", - " :param simulated_reply_attribute_name: metadata field name used when storing simulations\n", - " on corpus utterances (only relevant when store_simulations=True).\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " simulator,\n", - " threshold,\n", - " deferral_probability: float = 0.1515,\n", - " reuse_cached_forecast_probs: bool = True,\n", - " forecast_prob_attribute_name: str = \"forecast_prob\",\n", - " ):\n", - " # forward the cache flag to the base class so its _score helper honors it.\n", - " # without this, reuse_cached_probabilities on this subclass had no effect.\n", - " super().__init__(\n", - " forecast_prob_attribute_name=forecast_prob_attribute_name,\n", - " reuse_cached_forecast_probs=reuse_cached_forecast_probs,\n", - " )\n", - " self.simulator = simulator\n", - " self.threshold = float(threshold)\n", - " self.deferral_probability = float(deferral_probability)\n", - "\n", - " def _decision_score(self, context, score_fn: Callable):\n", - " # use base _score so a cached forecast_prob on the utterance meta is reused\n", - " # instead of re-invoking the belief estimator.\n", - " return self._score(context, score_fn)\n", - "\n", - " def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]:\n", - " decision_score = self._score(context, score_fn)\n", - "\n", - " p = 
float(np.random.rand())\n", - " \n", - "\n", - " # return an empty metadata dict (not None) so downstream code that iterates\n", - " # utt_metadata.items() in TransformerDecoderModel.transform doesn't crash.\n", - " return (decision_score,\n", - " 1 if decision_score > self.threshold and p > self.deferral_probability else 0,\n", - " {}\n", - " )\n", - "\n", - " def fit(self, contexts, val_contexts=None, score_fn: Callable = None):\n", - " if val_contexts is None or score_fn is None or self.labeler is None:\n", - " print(\"either no validation contexts/score function/labeler were provided, returning current threshold\")\n", - " return {\"best_threshold\": self.threshold}\n", - "\n", - " val_contexts = list(val_contexts)\n", - " if len(val_contexts) == 0:\n", - " print(\"no validation contexts were provided, returning current threshold\")\n", - " return {\"best_threshold\": self.threshold}\n", - "\n", - " fit_result = self._fit_with_model_checkpoint_selection(val_contexts, score_fn=score_fn)\n", - " if isinstance(fit_result, dict):\n", - " if \"best_threshold\" in fit_result:\n", - " self.threshold = float(fit_result[\"best_threshold\"])\n", - " return fit_result\n", - "\n", - " fit_result = self._fit_threshold_for_loaded_model(val_contexts, score_fn=score_fn)\n", - " if \"best_threshold\" in fit_result:\n", - " self.threshold = float(fit_result[\"best_threshold\"])\n", - " return fit_result" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "9d3db5bc", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Callable, Optional, Dict, Any, Tuple\n", - "\n", - "from convokit.decisionpolicy import DeferralDecisionPolicy\n", - "\n", - "\n", - "class SimulationAverageDecisionPolicy(DeferralDecisionPolicy):\n", - " \"\"\"\n", - " decision policy that intervenes if the mean of the simulated next-utterance\n", - " scores is at or above the threshold.\n", - "\n", - " this subclass inherits all simulation fetching, per-utterance metadata\n", - 
" caching, sim-score caching, and threshold fitting from\n", - " DeferralDecisionPolicy. the only differences are:\n", - " * no ``tau`` parameter (unused; forwarded as 0 to super)\n", - " * ``decide`` predicts based on mean(simulation_scores) >= threshold\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " simulator,\n", - " threshold,\n", - " num_simulations: int = 10,\n", - " store_simulations: bool = False,\n", - " simulated_reply_attribute_name: str = \"sim_replies\",\n", - " sim_replies_forecast_probs_attribute_name: str = \"sim_replies_forecast_probs\",\n", - " reuse_cached_simulations: bool = True,\n", - " ):\n", - " # tau is irrelevant for the mean-based decision rule, so we pin it to 0\n", - " # upstream rather than expose it to callers of this subclass.\n", - " super().__init__(\n", - " simulator=simulator,\n", - " threshold=threshold,\n", - " tau=0,\n", - " num_simulations=num_simulations,\n", - " store_simulations=store_simulations,\n", - " simulated_reply_attribute_name=simulated_reply_attribute_name,\n", - " sim_replies_forecast_probs_attribute_name=sim_replies_forecast_probs_attribute_name,\n", - " reuse_cached_simulations=reuse_cached_simulations,\n", - " )\n", - "\n", - " def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]:\n", - " decision_score, simulations, simulation_scores = self._decision_score(context, score_fn)\n", - " # empty simulation_scores would zero-divide. this happens when the\n", - " # simulator returns no completions for a context (e.g. end-of-conversation\n", - " # contexts that slip through the selector). 
treat as no intervention so\n", - " # a single degenerate context doesn't abort the whole transform run.\n", - " if len(simulation_scores) == 0:\n", - " average_simulation_score = 0.0\n", - " pred = 0\n", - " else:\n", - " average_simulation_score = sum(simulation_scores) / len(simulation_scores)\n", - " pred = 1 if average_simulation_score >= self.threshold else 0\n", - " return (\n", - " decision_score,\n", - " pred,\n", - " {\n", - " self.simulated_reply_attribute_name: simulations,\n", - " self.sim_replies_forecast_probs_attribute_name: simulation_scores,\n", - " },\n", - " )\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "d0d1ea7a", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Callable, Optional, Dict, Any, Tuple\n", - "\n", - "from convokit.decisionpolicy import DeferralDecisionPolicy\n", - "\n", - "\n", - "class SimulationMajorityDecisionPolicy(DeferralDecisionPolicy):\n", - " \"\"\"\n", - " decision policy that intervenes if at least ``tau`` of the simulated next\n", - " utterances score above the threshold, ignoring the current utterance score.\n", - "\n", - " this subclass inherits all simulation fetching, per-utterance metadata\n", - " caching, sim-score caching, and threshold fitting from\n", - " DeferralDecisionPolicy. 
the only difference is in ``decide``: the gate\n", - " ``decision_score > threshold`` is dropped so that only the simulated-branch\n", - " vote count drives the prediction.\n", - " \"\"\"\n", - "\n", - " def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]:\n", - " decision_score, simulations, simulation_scores = self._decision_score(context, score_fn)\n", - " num_simulations_above_threshold = sum(\n", - " 1 for score in simulation_scores if score > self.threshold\n", - " )\n", - " return (\n", - " decision_score,\n", - " 1 if num_simulations_above_threshold >= self.tau else 0,\n", - " {\n", - " self.simulated_reply_attribute_name: simulations,\n", - " self.sim_replies_forecast_probs_attribute_name: simulation_scores,\n", - " },\n", - " )\n" - ] - }, - { - "cell_type": "markdown", - "id": "0498693e", - "metadata": {}, - "source": [ - "Since we use simulations in our baselines, we must define a simulation configuration as follows: " - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "638fb31c", - "metadata": {}, - "outputs": [], - "source": [ - "SIMULATOR_TRAIN_CONFIG = {\n", - " \"per_device_train_batch_size\": 16,\n", - " \"per_device_eval_batch_size\": 16,\n", - " \"eval_strategy\": \"steps\",\n", - " \"save_strategy\": \"steps\",\n", - " \"save_steps\": 30,\n", - " \"gradient_accumulation_steps\": 4,\n", - " \"warmup_steps\": 5,\n", - " \"num_train_epochs\": 1,\n", - " \"eval_steps\": 30,\n", - " \"learning_rate\": 2e-4,\n", - " \"logging_steps\": 5,\n", - " \"optim\": \"adamw_8bit\",\n", - " \"weight_decay\": 0.01,\n", - " \"lr_scheduler_type\": \"linear\",\n", - " \"output_dir\": \"outputs/simulator_finetune\",\n", - " \"logging_dir\": \"logs\",\n", - " \"load_best_model_at_end\": True,\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "bf55d140", - "metadata": {}, - "outputs": [], - "source": [ - "TAU = 7\n", - "DEFERRAL_PROBABILITY_THRESHOLD = 0.2518938553561718\n", - 
"NUM_SIMULATIONS = 10\n", - "OUTPUT_DIR = \"benchmark_preannotated\"\n", - "SEEDS = [1,2,3,4,5]" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "aedd02fe", - "metadata": {}, - "outputs": [ + "cells": [ + { + "cell_type": "markdown", + "id": "90659cc6", + "metadata": {}, + "source": [ + "# Decision Policy Demo\n", + "\n", + "This notebook will provide code demonstrating how to use Decision Policies as introduced in Wait! There's a Way Out. This notebook will also provide code for running the experiments in the paper. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "703021a8", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO\n", + "\n", + "import os\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '2'" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fd8b87be", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2026-04-28 07:19:59.106329: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2026-04-28 07:19:59.126464: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "E0000 00:00:1777360799.150730 1126746 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "E0000 00:00:1777360799.158776 1126746 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "W0000 00:00:1777360799.179235 1126746 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "W0000 00:00:1777360799.179259 1126746 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "W0000 00:00:1777360799.179261 1126746 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "W0000 00:00:1777360799.179264 1126746 computation_placer.cc:177] computation placer already registered. 
Please check linkage and avoid linking the same target more than once.\n", + "2026-04-28 07:19:59.185112: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🦥 Unsloth Zoo will now patch everything to make training faster!\n" + ] + } + ], + "source": [ + "import os\n", + "import argparse\n", + "import sys\n", + "import glob\n", + "\n", + "from functools import partial\n", + "import json\n", + "from convokit import Corpus, Forecaster, download\n", + "from convokit.forecaster.TransformerDecoderModel import TransformerDecoderModel\n", + "from convokit.forecaster.TransformerForecasterConfig import TransformerForecasterConfig\n", + "from convokit.decisionpolicy import DeferralDecisionPolicy, ThresholdDecisionPolicy\n", + "from convokit.utterance_simulator.unslothUtteranceSimulatorModel import UnslothUtteranceSimulatorModel" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "15f977d3", + "metadata": {}, + "outputs": [], + "source": [ + "# Set repo root\n", + "\n", + "from pathlib import Path\n", + "\n", + "repo_root = Path.cwd()\n", + "while repo_root.name != \"ConvoKit\":\n", + " repo_root = repo_root.parent\n", + "\n", + "repo_root = str(repo_root)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18f2375b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] using cached decisionpolicy-demo at /home/lyk25/.convokit/saved-corpora/decisionpolicy-demo\n", + "[info] loaded 26 corpora from /home/lyk25/.convokit/saved-corpora/decisionpolicy-demo\n" + ] + } + ], + "source": [ + "# TODO temporary pre-merge downloader for decisionpolicy-demo\n", + "from 
pathlib import Path\n", + "import json\n", + "import urllib.request\n", + "import zipfile\n", + "\n", + "from convokit import Corpus\n", + "\n", + "DOWNLOAD_CONFIG_URL = (\n", + " \"https://raw.githubusercontent.com/laerdon/ConvoKit/\"\n", + " \"master/download_config.json\"\n", + ")\n", + "\n", + "\n", + "def get_decisionpolicy_demo_url(config_url=DOWNLOAD_CONFIG_URL):\n", + " print(f\"[info] reading download config from {config_url}\")\n", + " with urllib.request.urlopen(config_url) as response:\n", + " dataset_config = json.load(response)\n", + " try:\n", + " return dataset_config[\"DatasetURLs\"][\"decisionpolicy-demo\"]\n", + " except KeyError as exc:\n", + " raise KeyError(\n", + " \"decisionpolicy-demo is missing from laerdon/ConvoKit master download_config.json\"\n", + " ) from exc\n", + "\n", + "\n", + "def download_decisionpolicy_demo(data_dir=None):\n", + " data_root = Path(data_dir or \"~/.convokit/saved-corpora\").expanduser()\n", + " dataset_name = \"decisionpolicy-demo\"\n", + " dataset_dir = data_root / dataset_name\n", + " zip_path = data_root / f\"{dataset_name}.zip\"\n", + "\n", + " if any(path.is_dir() and (path / \"index.json\").exists() for path in dataset_dir.rglob(\"*\")):\n", + " print(f\"[info] using cached {dataset_name} at {dataset_dir}\")\n", + " return dataset_dir\n", + "\n", + " url = get_decisionpolicy_demo_url()\n", + " data_root.mkdir(parents=True, exist_ok=True)\n", + " print(f\"[info] downloading {dataset_name} from {url}\")\n", + " urllib.request.urlretrieve(url, zip_path)\n", + "\n", + " print(f\"[info] extracting {zip_path} to {data_root}\")\n", + " with zipfile.ZipFile(zip_path, \"r\") as zipf:\n", + " zipf.extractall(data_root)\n", + "\n", + " if not dataset_dir.exists():\n", + " raise FileNotFoundError(f\"expected extracted folder missing: {dataset_dir}\")\n", + " return dataset_dir\n", + "\n", + "base = download_decisionpolicy_demo()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "858a8431", + 
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] using cached decisionpolicy-demo at /home/lyk25/.convokit/saved-corpora/decisionpolicy-demo\n", + "[info] test-seed corpora: 5 (test-seed-1, test-seed-2, test-seed-3, test-seed-4, test-seed-5); seed-* policy corpora in corpora_all: 25\n" + ] + } + ], + "source": [ + "import re\n", + "corpus_dirs = sorted(\n", + " corpus_dir\n", + " for corpus_dir in base.rglob(\"*\")\n", + " if corpus_dir.is_dir() and (corpus_dir / \"index.json\").exists()\n", + ")\n", + "if not corpus_dirs:\n", + " raise FileNotFoundError(f\"no convokit corpora found under {base}\")\n", + "_seed_policy_re = re.compile(r\"^seed-(\\d+)-(.+)$\")\n", + "_test_seed_re = re.compile(r\"^test-seed-(\\d+)$\")\n", + "corpora_all = {}\n", + "corpora = {}\n", + "for corpus_dir in corpus_dirs:\n", + " if _test_seed_re.match(corpus_dir.name):\n", + " corpora[corpus_dir.name] = Corpus(filename=str(corpus_dir))\n", + " continue\n", + " m_test = _test_seed_re.match(corpus_dir.parent.name)\n", + " if m_test:\n", + " corpora[f\"test-seed-{m_test.group(1)}\"] = Corpus(filename=str(corpus_dir))\n", + " continue\n", + " m = _seed_policy_re.match(corpus_dir.parent.name)\n", + " if m and corpus_dir.name == m.group(2):\n", + " corpora_all[f\"seed-{m.group(1)}-{m.group(2)}\"] = Corpus(filename=str(corpus_dir))\n", + "CORPUS_POLICY_PER_SEED = \"DeferralDecisionPolicy\"\n", + "if not corpora:\n", + " for key in sorted(corpora_all):\n", + " m = _seed_policy_re.match(key)\n", + " if m and m.group(2) == CORPUS_POLICY_PER_SEED:\n", + " corpora[f\"test-seed-{m.group(1)}\"] = corpora_all[key]\n", + "if not corpora:\n", + " raise FileNotFoundError(\n", + " f\"no test-seed- corpora under {base} and none derived from {CORPUS_POLICY_PER_SEED!r}\"\n", + " )\n", + "print(\n", + " f\"[info] test-seed corpora: {len(corpora)} ({', '.join(sorted(corpora))}); \"\n", + " f\"seed-* policy corpora in corpora_all: {len(corpora_all)}\"\n", + 
")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a0bbbd7f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'seed-1-DeferralDecisionPolicy': ,\n", + " 'seed-1-RandomDeferralDecisionPolicy': ,\n", + " 'seed-1-SimulationAverageDecisionPolicy': ,\n", + " 'seed-1-SimulationMajorityDecisionPolicy': ,\n", + " 'seed-1-ThresholdDecisionPolicy': ,\n", + " 'seed-2-DeferralDecisionPolicy': ,\n", + " 'seed-2-RandomDeferralDecisionPolicy': ,\n", + " 'seed-2-SimulationAverageDecisionPolicy': ,\n", + " 'seed-2-SimulationMajorityDecisionPolicy': ,\n", + " 'seed-2-ThresholdDecisionPolicy': ,\n", + " 'seed-3-DeferralDecisionPolicy': ,\n", + " 'seed-3-RandomDeferralDecisionPolicy': ,\n", + " 'seed-3-SimulationAverageDecisionPolicy': ,\n", + " 'seed-3-SimulationMajorityDecisionPolicy': ,\n", + " 'seed-3-ThresholdDecisionPolicy': ,\n", + " 'seed-4-DeferralDecisionPolicy': ,\n", + " 'seed-4-RandomDeferralDecisionPolicy': ,\n", + " 'seed-4-SimulationAverageDecisionPolicy': ,\n", + " 'seed-4-SimulationMajorityDecisionPolicy': ,\n", + " 'seed-4-ThresholdDecisionPolicy': ,\n", + " 'seed-5-DeferralDecisionPolicy': ,\n", + " 'seed-5-RandomDeferralDecisionPolicy': ,\n", + " 'seed-5-SimulationAverageDecisionPolicy': ,\n", + " 'seed-5-SimulationMajorityDecisionPolicy': ,\n", + " 'seed-5-ThresholdDecisionPolicy': }" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "corpora_all" + ] + }, + { + "cell_type": "markdown", + "id": "dab74642", + "metadata": {}, + "source": [ + "Having imported our DeferralDecisionPolicy and ThresholdDecisionPolicy, we now will first define all of the other decision policies to benchmark. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5dabf0bd", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Callable, List, Optional, Dict, Any, Tuple\n", + "import numpy as np\n", + "from convokit.decisionpolicy import DecisionPolicy\n", + "\n", + "\n", + "class _synthetic_speaker:\n", + " def __init__(self, speaker_id: str):\n", + " self.id = speaker_id\n", + "\n", + "\n", + "class _synthetic_utterance:\n", + " def __init__(self, text: str, utterance_id: str, speaker_id: str):\n", + " self.text = text\n", + " self.id = utterance_id\n", + " self.speaker_ = _synthetic_speaker(speaker_id)\n", + " self.meta = {}\n", + "\n", + " def get_conversation(self):\n", + " return None\n", + "\n", + "\n", + "class RandomDeferralDecisionPolicy(DecisionPolicy):\n", + " \"\"\"\n", + " Decision policy that defers intervention by looking ahead at simulated next utterances.\n", + "\n", + " :param simulator: utterance simulator model (must have a ``transform(contexts)`` method\n", + " returning a DataFrame indexed by utterance id). 
if the simulator exposes\n", + " ``get_num_simulations()``, ``num_simulations`` is capped to that value.\n", + " :param threshold: probability threshold above which a context is flagged.\n", + " :param tau: minimum number of simulated branches that must exceed the threshold\n", + " before an intervention is issued.\n", + " :param num_simulations: how many simulated branches to use per context (capped to\n", + " simulator's ``get_num_simulations()`` if available).\n", + " :param store_simulations: if True, simulated reply strings are cached during decide()\n", + " and written to corpus utterance metadata by post_transform().\n", + " :param simulated_reply_attribute_name: metadata field name used when storing simulations\n", + " on corpus utterances (only relevant when store_simulations=True).\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " simulator,\n", + " threshold,\n", + " deferral_probability: float = 0.1515,\n", + " reuse_cached_forecast_probs: bool = True,\n", + " forecast_prob_attribute_name: str = \"forecast_prob\",\n", + " ):\n", + " # forward the cache flag to the base class so its _score helper honors it.\n", + " # without this, reuse_cached_forecast_probs on this subclass had no effect.\n", + " super().__init__(\n", + " forecast_prob_attribute_name=forecast_prob_attribute_name,\n", + " reuse_cached_forecast_probs=reuse_cached_forecast_probs,\n", + " )\n", + " self.simulator = simulator\n", + " self.threshold = float(threshold)\n", + " self.deferral_probability = float(deferral_probability)\n", + "\n", + " def _decision_score(self, context, score_fn: Callable):\n", + " # use base _score so a cached forecast_prob on the utterance meta is reused\n", + " # instead of re-invoking the belief estimator.\n", + " return self._score(context, score_fn)\n", + "\n", + " def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]:\n", + " decision_score = self._score(context, score_fn)\n", + "\n", + " p = 
float(np.random.rand())\n", + " \n", + "\n", + " # return an empty metadata dict (not None) so downstream code that iterates\n", + " # utt_metadata.items() in TransformerDecoderModel.transform doesn't crash.\n", + " return (decision_score,\n", + " 1 if decision_score > self.threshold and p > self.deferral_probability else 0,\n", + " {}\n", + " )\n", + "\n", + " def fit(self, contexts, val_contexts=None, score_fn: Callable = None):\n", + " if val_contexts is None or score_fn is None or self.labeler is None:\n", + " print(\"either no validation contexts/score function/labeler were provided, returning current threshold\")\n", + " return {\"best_threshold\": self.threshold}\n", + "\n", + " val_contexts = list(val_contexts)\n", + " if len(val_contexts) == 0:\n", + " print(\"no validation contexts were provided, returning current threshold\")\n", + " return {\"best_threshold\": self.threshold}\n", + "\n", + " fit_result = self._fit_with_model_checkpoint_selection(val_contexts, score_fn=score_fn)\n", + " if isinstance(fit_result, dict):\n", + " if \"best_threshold\" in fit_result:\n", + " self.threshold = float(fit_result[\"best_threshold\"])\n", + " return fit_result\n", + "\n", + " fit_result = self._fit_threshold_for_loaded_model(val_contexts, score_fn=score_fn)\n", + " if \"best_threshold\" in fit_result:\n", + " self.threshold = float(fit_result[\"best_threshold\"])\n", + " return fit_result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d3db5bc", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Callable, Optional, Dict, Any, Tuple\n", + "\n", + "from convokit.decisionpolicy import DeferralDecisionPolicy\n", + "\n", + "\n", + "class SimulationAverageDecisionPolicy(DeferralDecisionPolicy):\n", + " \"\"\"\n", + " decision policy that intervenes if the mean of the simulated next-utterance\n", + " scores is at or above the threshold.\n", + "\n", + " this subclass inherits all simulation fetching, per-utterance metadata\n", 
+ " caching, sim-score caching, and threshold fitting from\n", + " DeferralDecisionPolicy. the only differences are:\n", + " * no ``tau`` parameter (unused; forwarded as 0 to super)\n", + " * ``decide`` predicts based on mean(simulation_scores) >= threshold\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " simulator,\n", + " threshold,\n", + " num_simulations: int = 10,\n", + " store_simulations: bool = False,\n", + " simulated_reply_attribute_name: str = \"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name: str = \"sim_replies_forecast_probs\",\n", + " reuse_cached_simulations: bool = True,\n", + " ):\n", + " # tau is irrelevant for the mean-based decision rule, so we pin it to 0\n", + " # upstream rather than expose it to callers of this subclass.\n", + " super().__init__(\n", + " simulator=simulator,\n", + " threshold=threshold,\n", + " tau=0,\n", + " num_simulations=num_simulations,\n", + " store_simulations=store_simulations,\n", + " simulated_reply_attribute_name=simulated_reply_attribute_name,\n", + " sim_replies_forecast_probs_attribute_name=sim_replies_forecast_probs_attribute_name,\n", + " reuse_cached_simulations=reuse_cached_simulations,\n", + " )\n", + "\n", + " def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]:\n", + " decision_score, simulations, simulation_scores = self._decision_score(context, score_fn)\n", + " # empty simulation_scores would zero-divide. this happens when the\n", + " # simulator returns no completions for a context (e.g. end-of-conversation\n", + " # contexts that slip through the selector). 
treat as no intervention so\n", + " # a single degenerate context doesn't abort the whole transform run.\n", + " if len(simulation_scores) == 0:\n", + " average_simulation_score = 0.0\n", + " pred = 0\n", + " else:\n", + " average_simulation_score = sum(simulation_scores) / len(simulation_scores)\n", + " pred = 1 if average_simulation_score >= self.threshold else 0\n", + " return (\n", + " decision_score,\n", + " pred,\n", + " {\n", + " self.simulated_reply_attribute_name: simulations,\n", + " self.sim_replies_forecast_probs_attribute_name: simulation_scores,\n", + " },\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0d1ea7a", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Callable, Optional, Dict, Any, Tuple\n", + "\n", + "from convokit.decisionpolicy import DeferralDecisionPolicy\n", + "\n", + "\n", + "class SimulationMajorityDecisionPolicy(DeferralDecisionPolicy):\n", + " \"\"\"\n", + " decision policy that intervenes if at least ``tau`` of the simulated next\n", + " utterances score above the threshold, ignoring the current utterance score.\n", + "\n", + " this subclass inherits all simulation fetching, per-utterance metadata\n", + " caching, sim-score caching, and threshold fitting from\n", + " DeferralDecisionPolicy. 
the only difference is in ``decide``: the gate\n", + " ``decision_score > threshold`` is dropped so that only the simulated-branch\n", + " vote count drives the prediction.\n", + " \"\"\"\n", + "\n", + " def decide(self, context, score_fn: Callable) -> Tuple[float, int, Optional[Dict[str, Any]]]:\n", + " decision_score, simulations, simulation_scores = self._decision_score(context, score_fn)\n", + " num_simulations_above_threshold = sum(\n", + " 1 for score in simulation_scores if score > self.threshold\n", + " )\n", + " return (\n", + " decision_score,\n", + " 1 if num_simulations_above_threshold >= self.tau else 0,\n", + " {\n", + " self.simulated_reply_attribute_name: simulations,\n", + " self.sim_replies_forecast_probs_attribute_name: simulation_scores,\n", + " },\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "id": "0498693e", + "metadata": {}, + "source": [ + "Since we use simulations in our baselines, we must define a simulation configuration as follows: " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "638fb31c", + "metadata": {}, + "outputs": [], + "source": [ + "SIMULATOR_TRAIN_CONFIG = {\n", + " \"per_device_train_batch_size\": 16,\n", + " \"per_device_eval_batch_size\": 16,\n", + " \"eval_strategy\": \"steps\",\n", + " \"save_strategy\": \"steps\",\n", + " \"save_steps\": 30,\n", + " \"gradient_accumulation_steps\": 4,\n", + " \"warmup_steps\": 5,\n", + " \"num_train_epochs\": 1,\n", + " \"eval_steps\": 30,\n", + " \"learning_rate\": 2e-4,\n", + " \"logging_steps\": 5,\n", + " \"optim\": \"adamw_8bit\",\n", + " \"weight_decay\": 0.01,\n", + " \"lr_scheduler_type\": \"linear\",\n", + " \"output_dir\": \"outputs/simulator_finetune\",\n", + " \"logging_dir\": \"logs\",\n", + " \"load_best_model_at_end\": True,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf55d140", + "metadata": {}, + "outputs": [], + "source": [ + "TAU = 7\n", + "DEFERRAL_PROBABILITY_THRESHOLD = 0.2518938553561718\n", + 
"NUM_SIMULATIONS = 10\n", + "OUTPUT_DIR = \"benchmark_preannotated\"\n", + "SEEDS = [1,2,3,4,5]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aedd02fe", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==((====))== Unsloth 2025.7.11: Fast Llama patching. Transformers: 4.53.3.\n", + " \\\\ /| NVIDIA RTX A6000. Num GPUs = 1. Max memory: 47.536 GB. Platform: Linux.\n", + "O^O/ \\_/ \\ Torch: 2.7.1+cu128. CUDA: 8.6. CUDA Toolkit: 12.8. Triton: 3.1.0+cf34004b8a\n", + "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.31.post1. FA2 = False]\n", + " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", + "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Unsloth 2025.7.11 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.\n", + "Unsloth: Already have LoRA adapters! We shall skip this step.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Unsloth: Training embed_tokens in mixed precision to save VRAM\n", + "Unsloth: Training lm_head in mixed precision to save VRAM\n" + ] + } + ], + "source": [ + "simulator_model = UnslothUtteranceSimulatorModel(\n", + " model_name=\"/reef/lyk25/dynamic_training/game_analysis/outputs/checkpoint-74\",\n", + " train_config=SIMULATOR_TRAIN_CONFIG,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "fe147ae8", + "metadata": {}, + "source": [ + "After loading the simulator model, we define context selectors." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a9502041", + "metadata": {}, + "outputs": [], + "source": [ + "def context_selector(context_tuple, split):\n", + " \"\"\"\n", + " We use this generic function for both training and validation data.\n", + " In both cases, its job is to select only those contexts for which the\n", + " FUTURE context is not empty, so we have a next utterance to predict.\n", + " \"\"\"\n", + " matches_split = (context_tuple.current_utterance.get_conversation().meta.get(\"split\") == split)\n", + " is_end = (len(context_tuple.future_context) == 0)\n", + " return matches_split and not is_end\n", + "\n", + "def make_data_selector(split):\n", + " return lambda context_tuple: context_tuple.current_utterance.get_conversation().meta.get(\"split\") == split\n", + "\n", + "train_context_selector = partial(context_selector, split=\"train\")\n", + "val_context_selector = partial(context_selector, split=\"val\")\n", + "test_context_selector = partial(context_selector, split=\"test\")" + ] + }, + { + "cell_type": "markdown", + "id": "f707a3a5", + "metadata": {}, + "source": [ + "Below is a script that fully reproduces the results, except for regenerating the simulations (generating simulations requires substantially more compute)."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dde130ed", + "metadata": {}, + "outputs": [], + "source": [ + "for seed_idx in SEEDS:\n", + " train_corpus = Corpus(filename=download('conversations-gone-awry-cmv-corpus-large'))\n", + " # corpus = Corpus(filename=f'/reef/lyk25/theres-a-way-out-ACL26-internal/outputs/benchmark_preannotated/seed-{seed_idx}-ThresholdDecisionPolicy/ThresholdDecisionPolicy')\n", + " # corpus = Corpus(filename=f'/reef/lyk25/dynamic_training/game_analysis/corpi/test/test-son-seed-{seed_idx}')\n", + " corpus = Corpus(filename=download('conversations-gone-awry-cmv-corpus-large'))\n", + " corpus.filter_conversations_by(lambda convo: convo.meta['split'] == 'test')\n", + "\n", + " config = TransformerForecasterConfig(\n", + " output_dir=f\"outputs/{OUTPUT_DIR}/forecaster_{seed_idx}\",\n", + " per_device_batch_size=16,\n", + " gradient_accumulation_steps=1,\n", + " num_train_epochs=1,\n", + " learning_rate=1e-5,\n", + " random_seed=seed_idx,\n", + " context_mode=\"normal\",\n", + " device=\"cuda\",\n", + " )\n", + "\n", + " # TODO this will have to be edited\n", + " forecaster_model = TransformerDecoderModel(\n", + " model_name_or_path=\"google/gemma-2-9b-it\",\n", + " config=config,\n", + " )\n", + "\n", + " forecaster = Forecaster(\n", + " forecaster_model=forecaster_model,\n", + " labeler='has_removed_comment',\n", + " )\n", + "\n", + " forecaster.fit_belief_estimator(\n", + " corpus=train_corpus,\n", + " context_selector=train_context_selector,\n", + " val_context_selector=val_context_selector,\n", + " )\n", + "\n", + " # ---\n", + " cfg_path = os.path.join(repo_root, \"saves\", f\"seed-{seed_idx}\", \"dev_config.json\")\n", + " with open(cfg_path) as f:\n", + " cfg = json.load(f)\n", + " best_threshold = cfg['best_threshold']\n", + "\n", + " for policy_trial in [ThresholdDecisionPolicy, DeferralDecisionPolicy, RandomDeferralDecisionPolicy, SimulationAverageDecisionPolicy, SimulationMajorityDecisionPolicy]:\n", + " 
print('---')\n", + " print(f\"Fitting policy {policy_trial.__name__} for seed {seed_idx}\")\n", + " if policy_trial == ThresholdDecisionPolicy:\n", + " policy = ThresholdDecisionPolicy(\n", + " threshold=best_threshold,\n", + " reuse_cached_forecast_probs=False,\n", + " )\n", + " elif policy_trial == DeferralDecisionPolicy:\n", + " policy = DeferralDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " tau=TAU,\n", + " reuse_cached_forecast_probs=False,\n", + " )\n", + " elif policy_trial == RandomDeferralDecisionPolicy:\n", + " policy = RandomDeferralDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " deferral_probability=DEFERRAL_PROBABILITY_THRESHOLD,\n", + " reuse_cached_forecast_probs=False,\n", + " )\n", + " elif policy_trial == SimulationAverageDecisionPolicy:\n", + " policy = SimulationAverageDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " num_simulations=NUM_SIMULATIONS,\n", + " store_simulations=False,\n", + " simulated_reply_attribute_name=\"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name=\"sim_replies_forecast_probs\",\n", + " reuse_cached_forecast_probs=False,\n", + " )\n", + " elif policy_trial == SimulationMajorityDecisionPolicy:\n", + " policy = SimulationMajorityDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " tau=TAU,\n", + " num_simulations=NUM_SIMULATIONS,\n", + " store_simulations=False,\n", + " simulated_reply_attribute_name=\"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name=\"sim_replies_forecast_probs\",\n", + " reuse_cached_forecast_probs=False,\n", + " )\n", + " \n", + " # attach the decision policy to the underlying forecaster model;\n", + " # Forecaster itself does not accept decision_policy in its constructor.\n", + " forecaster_model.decision_policy = policy\n", + "\n", + " forecaster = Forecaster(\n", + " 
forecaster_model=forecaster_model,\n", + " labeler='has_removed_comment',\n", + " )\n", + "\n", + " print('starting transformation.')\n", + " # evaluate the forecaster on the test set\n", + " forecaster.transform(\n", + " corpus=corpus,\n", + " context_selector=make_data_selector('test'),\n", + " verbose=True,\n", + " )\n", + " print('transformation complete.')\n", + "\n", + " output_dir = f\"outputs/{OUTPUT_DIR}/seed-{seed_idx}-{policy_trial.__name__}\"\n", + " os.makedirs(output_dir, exist_ok=True)\n", + " corpus.dump(name=f\"{policy_trial.__name__}\", base_path=output_dir)\n", + " print('corpus dumped.')\n", + "\n", + " print('starting summarization.')\n", + " # forecaster.summarize expects a conversation-level selector (Callable[[Conversation], bool]),\n", + " # unlike the context-tuple selectors used in fit/transform.\n", + " def summarize_selector(convo):\n", + " return convo.meta.get(\"split\") == \"test\"\n", + " conversational_forecasts_df, metrics = forecaster.summarize(\n", + " corpus=corpus,\n", + " selector=summarize_selector,\n", + " )\n", + " print('summarization complete.')\n", + " \n", + " # path to the seed output directory\n", + " seed_folder = f\"outputs/{OUTPUT_DIR}/seed-{seed_idx}-{policy_trial.__name__}\"\n", + "\n", + " # ensure the directory exists\n", + " os.makedirs(seed_folder, exist_ok=True)\n", + "\n", + " # save conversational_forecasts_df as CSV\n", + " conversational_forecasts_df.to_csv(os.path.join(seed_folder, \"conversational_forecasts.csv\"), index=False)\n", + "\n", + " # save metrics as JSON\n", + " with open(os.path.join(seed_folder, \"metrics.json\"), \"w\") as f:\n", + " json.dump(metrics, f, indent=2)" + ] + }, + { + "cell_type": "markdown", + "id": "e3c6f428", + "metadata": {}, + "source": [ + "A faster reproduction is possible by skipping the training and transformation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f4a29c2", + "metadata": {}, + "outputs": [], + "source": [ + "for seed_idx in SEEDS:\n", + " config = TransformerForecasterConfig(\n", + " output_dir=f\"outputs/{OUTPUT_DIR}/forecaster_{seed_idx}\",\n", + " per_device_batch_size=16,\n", + " gradient_accumulation_steps=1,\n", + " num_train_epochs=1,\n", + " learning_rate=1e-5,\n", + " random_seed=seed_idx,\n", + " context_mode=\"normal\",\n", + " device=\"cuda\",\n", + " )\n", + "\n", + " # TODO this will have to be edited\n", + " forecaster_model = TransformerDecoderModel(\n", + " model_name_or_path=\"google/gemma-2-9b-it\",\n", + " config=config,\n", + " )\n", + "\n", + " forecaster = Forecaster(\n", + " forecaster_model=forecaster_model,\n", + " labeler='has_removed_comment',\n", + " )\n", + "\n", + " # remove training---we can instead use the cached forecast probabilities.\n", + "\n", + " # ---\n", + " cfg_path = os.path.join(repo_root, \"saves\", f\"seed-{seed_idx}\", \"dev_config.json\")\n", + " with open(cfg_path) as f:\n", + " cfg = json.load(f)\n", + " best_threshold = cfg['best_threshold']\n", + "\n", + " for policy_trial in [ThresholdDecisionPolicy, DeferralDecisionPolicy, RandomDeferralDecisionPolicy, SimulationAverageDecisionPolicy, SimulationMajorityDecisionPolicy]:\n", + " corpus_name = f\"seed-{seed_idx}-{policy_trial.__name__}\"\n", + " if corpus_name in corpora:\n", + " corpus = corpora_all[corpus_name]\n", + " else:\n", + " raise KeyError(f\"missing corpus {corpus_name}\")\n", + "\n", + " print('---')\n", + " print(f\"fitting policy {policy_trial.__name__} for seed {seed_idx}\")\n", + " if policy_trial == ThresholdDecisionPolicy:\n", + " policy = ThresholdDecisionPolicy(\n", + " threshold=best_threshold,\n", + " )\n", + " elif policy_trial == DeferralDecisionPolicy:\n", + " policy = DeferralDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " tau=TAU,\n", + " )\n", + " elif policy_trial 
== RandomDeferralDecisionPolicy:\n", + " policy = RandomDeferralDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " deferral_probability=DEFERRAL_PROBABILITY_THRESHOLD,\n", + " )\n", + " elif policy_trial == SimulationAverageDecisionPolicy:\n", + " policy = SimulationAverageDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " num_simulations=NUM_SIMULATIONS,\n", + " store_simulations=False,\n", + " simulated_reply_attribute_name=\"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name=\"sim_replies_forecast_probs\",\n", + " )\n", + " elif policy_trial == SimulationMajorityDecisionPolicy:\n", + " policy = SimulationMajorityDecisionPolicy(\n", + " simulator=simulator_model,\n", + " threshold=best_threshold,\n", + " tau=TAU,\n", + " num_simulations=NUM_SIMULATIONS,\n", + " store_simulations=False,\n", + " simulated_reply_attribute_name=\"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name=\"sim_replies_forecast_probs\",\n", + " )\n", + " \n", + " # attach the decision policy to the underlying forecaster model;\n", + " # Forecaster itself does not accept decision_policy in its constructor.\n", + " forecaster_model.decision_policy = policy\n", + "\n", + " forecaster = Forecaster(\n", + " forecaster_model=forecaster_model,\n", + " labeler='has_removed_comment',\n", + " )\n", + "\n", + " print('starting transformation.')\n", + " # evaluate the forecaster on the test set\n", + " forecaster.transform(\n", + " corpus=corpus,\n", + " context_selector=make_data_selector('test'),\n", + " verbose=True,\n", + " )\n", + " print('transformation complete.')\n", + "\n", + " output_dir = f\"outputs/{OUTPUT_DIR}/seed-{seed_idx}-{policy_trial.__name__}\"\n", + " os.makedirs(output_dir, exist_ok=True)\n", + " corpus.dump(name=f\"{policy_trial.__name__}\", base_path=output_dir)\n", + " print('corpus dumped.')\n", + "\n", + " print('starting summarization.')\n", + " # 
forecaster.summarize expects a conversation-level selector (Callable[[Conversation], bool]),\n", + " # unlike the context-tuple selectors used in fit/transform.\n", + " def summarize_selector(convo):\n", + " return convo.meta.get(\"split\") == \"test\"\n", + " conversational_forecasts_df, metrics = forecaster.summarize(\n", + " corpus=corpus,\n", + " selector=summarize_selector,\n", + " )\n", + " print('summarization complete.')\n", + " \n", + " # path to the seed output directory\n", + " seed_folder = f\"outputs/{OUTPUT_DIR}/seed-{seed_idx}-{policy_trial.__name__}\"\n", + "\n", + " # ensure the directory exists\n", + " os.makedirs(seed_folder, exist_ok=True)\n", + "\n", + " # save conversational_forecasts_df as CSV\n", + " conversational_forecasts_df.to_csv(os.path.join(seed_folder, \"conversational_forecasts.csv\"), index=False)\n", + "\n", + " # save metrics as JSON\n", + " with open(os.path.join(seed_folder, \"metrics.json\"), \"w\") as f:\n", + " json.dump(metrics, f, indent=2)" + ] + }, + { + "cell_type": "markdown", + "id": "29a0cdc1", + "metadata": {}, + "source": [ + "## Human benchmark analysis" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "d673d0d6", + "metadata": {}, + "outputs": [], + "source": [ + "# import human data SQL here" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "==((====))== Unsloth 2025.7.11: Fast Llama patching. Transformers: 4.53.3.\n", - " \\\\ /| NVIDIA RTX A6000. Num GPUs = 1. Max memory: 47.536 GB. Platform: Linux.\n", - "O^O/ \\_/ \\ Torch: 2.7.1+cu128. CUDA: 8.6. CUDA Toolkit: 12.8. Triton: 3.1.0+cf34004b8a\n", - "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.31.post1. 
FA2 = False]\n", - " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", - "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n" - ] + "cell_type": "code", + "execution_count": 13, + "id": "d55b3bab", + "metadata": {}, + "outputs": [], + "source": [ + "import sqlite3" + ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Unsloth 2025.7.11 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.\n", - "Unsloth: Already have LoRA adapters! We shall skip this step.\n" - ] + "cell_type": "code", + "execution_count": 16, + "id": "6047d55a", + "metadata": {}, + "outputs": [], + "source": [ + "round_n = 1" + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Unsloth: Training embed_tokens in mixed precision to save VRAM\n", - "Unsloth: Training lm_head in mixed precision to save VRAM\n" - ] + "cell_type": "code", + "execution_count": 17, + "id": "92766a80", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] connected to database: /reef/sqt2/cga-eval/human/game_db.sqlite\n" + ] + } + ], + "source": [ + "# connect to the game database\n", + "db_path = '/reef/sqt2/cga-eval/human/game_db.sqlite'\n", + "connection = sqlite3.connect(db_path)\n", + "cursor = connection.cursor()\n", + "\n", + "print(f\"[info] connected to database: {db_path}\")\n", + "cursor.execute(\"SELECT COUNT(*) FROM results\")\n", + "\n", + "# # names_list = []\n", + "# # answers_list = []\n", + "# # scores_list = []\n", + "# # comments_list = []\n", + "\n", + "# TIME_CUTOFF = 1714059814630 # timestamp to cut off previous results\n", + "# TIME_CUTOFF_END = 1730385762530\n", + "\n", + "# for row in connection.execute('SELECT * FROM results'):\n", + "# if row[0] != 'yc2727' and row[0] != 'sqt2':\n", + "# answers = json.loads(row[1])\n", + "# if answers[0]['start_time'] > TIME_CUTOFF and answers[0]['start_time'] < TIME_CUTOFF_END:\n", + "# 
names_list.append(row[0])\n", + "# answers_list.append(answers)\n", + "# scores_list.append(row[2])\n", + "# comments_list.append(row[3])\n", + "\n", + "if round_n == 1:\n", + " # for round_n 1\n", + " names_list_1 = []\n", + " answers_list_1 = []\n", + " scores_list_1 = []\n", + " comments_list_1 = []\n", + "\n", + " TIME_CUTOFF_1 = 1730395000000\n", + " TIME_CUTOFF_END_1 = 1730397199000\n", + "\n", + " for row in connection.execute('SELECT * FROM results'):\n", + " if row[0] != 'yc2727' and row[0] != 'sqt2':\n", + " answers = json.loads(row[1])\n", + " if answers[0]['start_time'] > TIME_CUTOFF_1 and answers[0]['start_time'] < TIME_CUTOFF_END_1:\n", + " names_list_1.append(row[0])\n", + " answers_list_1.append(answers)\n", + " scores_list_1.append(row[2])\n", + " comments_list_1.append(row[3])\n", + "elif round_n == 2:\n", + " # for round 2\n", + " names_list_2 = []\n", + " answers_list_2 = []\n", + " scores_list_2 = []\n", + " comments_list_2 = []\n", + "\n", + " TIME_CUTOFF_2_a = 1731002910000\n", + " TIME_CUTOFF_END_2_a = 1731016180000\n", + "\n", + " for row in connection.execute('SELECT * FROM results'):\n", + " if row[0] != 'yc2727' and row[0] != 'sqt2':\n", + " answers = json.loads(row[1])\n", + " if answers[0]['start_time'] > TIME_CUTOFF_2_a and answers[0]['start_time'] < TIME_CUTOFF_END_2_a:\n", + " names_list_2.append(row[0])\n", + " answers_list_2.append(answers)\n", + " scores_list_2.append(row[2])\n", + " comments_list_2.append(row[3])\n", + "\n", + " # The 2_b cutoff below is to specifically add data from 'ljl2' for a later second round window,\n", + " # likely because they submitted their data late or for a different time block.\n", + " TIME_CUTOFF_2_b = 1732212720000\n", + " TIME_CUTOFF_END_2_b = 1740000000000\n", + "\n", + " for row in connection.execute('SELECT * FROM results'):\n", + " if row[0].lower() == 'ljl2':\n", + " answers = json.loads(row[1])\n", + " if answers[0]['start_time'] > TIME_CUTOFF_2_b and answers[0]['start_time'] < 
TIME_CUTOFF_END_2_b:\n", + " names_list_2.append(row[0])\n", + " answers_list_2.append(answers)\n", + " scores_list_2.append(row[2])\n", + " comments_list_2.append(row[3])\n", + "\n", + " # print(f\"Round 1: {len(names_list_1)} entries\")\n", + " # print(f\"Round 2: {len(names_list_2)} entries\")\n", + " # names_list_1, names_list_2" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "889d7ede", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] total unique conversation ids: 84\n", + "['givp578', 'e1vstxd', 'f1353ra', 'ewr7ls3', 'dmlny27', 'dg7wmdb', 'flixm8f', 'f13u0n2', 'ggwdbsa', 'd3f07tv', 'isz5n6g', 'givok1i', 'dyx3au1', 'g1sm9mt', 'f0dryv4', 'doi44j8', 'g2998qp', 'dpirnlc', 'h2h7yl0', 'dg7x3eu', 'ej6jusi', 'ikzb571', 'e1vv6x1', 'ifkoazb', 'dpkib9q', 'g1q3upb', 'dydawov', 'g1q1973', 'dm8exht', 'fphcwq0', 'gmxqrdz', 'dmlqnzz', 'dyd7cfv', 'i4rgtqf', 'doi3vuz', 'd0fyu9x', 'f014doo', 'dm8qu78', 'gmxoyuu', 'fmre7l9', 'ii3cpve', 'ii4ivim', 'e9mynuy', 'djh1bt1', 'f83hb2k', 'ewtowqu', 'dbnxbzn', 'fixjxxs', 'dvivs9p', 'ej6ollb', 'g1spg0k', 'ggwei5y', 'd0fzp40', 'gvb1ekl', 'djh2v24', 'dvisfl0', 'ewtjxp2', 'dyx2jn4', 'ewr7msm', 'd3efg03', 'fixlxvy', 'fegw71k', 'ikzcblg', 'ffpk80c', 'fmspcex', 'f00nxjs', 'gvdnrrb', 'h28ltfw', 'f0dsbw4', 'dbnr5nd', 'ikpw1ol', 'ffpj88q', 'i4rh1id', 'e9mybxp', 'isz6a7m', 'ifknq44', 'g297nn7', 'fegwvia', 'ikpxahv', 'f83akyz', 'h2hauaj', 'h28lknz', 'g3hmlew', 'fljdrtt']\n" + ] + } + ], + "source": [ + "human_map1 = {\n", + " \"vn72\": ['d3efg03', 'd3f07tv', 'f1353ra', 'f13u0n2', 'fegw71k', 'fegwvia', 'ikpw1ol', 'ikpxahv', 'isz5n6g', 'isz6a7m'],\n", + " \"nac86\": ['dyd7cfv', 'dydawov', 'ewtjxp2', 'ewtowqu', 'f00nxjs', 'f014doo', 'ii3cpve', 'ii4ivim', 'ikzb571', 'ikzcblg'],\n", + " # \"sj597\": ['dbnr5nd', 'dbnxbzn', 'g1sm9mt', 'g1spg0k', 'g297nn7', 'g2998qp', 'h2h7yl0', 'h2hauaj', 'ifknq44', 'ifkoazb'],\n", + " \"yc2727\": ['d08x3kl', 'd08x4q0', 'd3efg03', 'd3f07tv', 
'ewtjxp2', 'ewtowqu', 'hnupywh', 'hnuu3ip', 'ih2fn94', 'ih3iwph'],\n", + " # \"ex36\": ['d0fyu9x', 'd0fzp40', 'dg7wmdb', 'dg7x3eu', 'ej6jusi', 'ej6ollb', 'ewtjxp2', 'ewtowqu', 'f0dryv4', 'f0dsbw4'],\n", + " \"kz88\": ['doi3vuz', 'doi44j8', 'e9mybxp', 'e9mynuy', 'g1q1973', 'g1q3upb', 'ggwdbsa', 'ggwei5y', 'gmxoyuu', 'gmxqrdz'],\n", + " \"LJL2\": ['ffpj88q', 'ffpk80c', 'flixm8f', 'fljdrtt', 'fmre7l9', 'fmspcex', 'fphcwq0', 'g3hmlew', 'givok1i', 'givp578'],\n", + " \"lyk25\": ['dpirnlc', 'dpkib9q', 'e1vstxd', 'e1vv6x1', 'ewr7ls3', 'ewr7msm', 'fixjxxs', 'fixlxvy', 'fmre7l9', 'fmspcex'],\n", + " \"sqt2\": ['d3efg03', 'd3f07tv', 'g0fwpzc', 'g0ggz8e', 'g8nyz4f', 'g8nzx3m', 'h4i75b9', 'h4i79lc', 'h6ikmzc', 'h6ime20'],\n", + " \"cd326\": ['djh1bt1', 'djh2v24', 'dmlny27', 'dmlqnzz', 'f83akyz', 'f83hb2k', 'h28lknz', 'h28ltfw', 'i4rgtqf', 'i4rh1id'],\n", + " \"tg352\": ['dm8exht', 'dm8qu78', 'dvisfl0', 'dvivs9p', 'dyx2jn4', 'dyx3au1', 'gvb1ekl', 'gvdnrrb', 'i4rgtqf', 'i4rh1id'],\n", + "}\n", + "\n", + "# all_convo_ids = []\n", + "# for k,v in human_map1.items():\n", + "# all_convo_ids.extend(v)\n", + "# all_convo_ids = list(set(all_convo_ids))\n", + "# len(all_convo_ids)\n", + "\n", + "all_convo_ids = []\n", + "# # collect from round 1\n", + "if round_n == 1:\n", + " for i in range(len(answers_list_1)):\n", + " for j in range(len(answers_list_1[i])):\n", + " all_convo_ids.append(answers_list_1[i][j]['id'])\n", + "# collect from round 2\n", + "elif round_n == 2:\n", + " for i in range(len(answers_list_2)):\n", + " for j in range(len(answers_list_2[i])):\n", + " all_convo_ids.append(answers_list_2[i][j]['id'])\n", + "all_convo_ids = list(set(all_convo_ids))\n", + "print(f\"[info] total unique conversation ids: {len(all_convo_ids)}\")\n", + "print(all_convo_ids)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "e55a3803", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset already exists at 
/reef/lyk25/ConvoKit/examples/forecaster/conversations-gone-awry-cmv-corpus-large\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "corpus = Corpus(filename=download('conversations-gone-awry-cmv-corpus-large'))\n", + "corpus.filter_conversations_by(lambda convo: convo.id in all_convo_ids)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "168fb025", + "metadata": {}, + "outputs": [], + "source": [ + "# we want all utterances to have a human_guesses meta field\n", + "for convo in corpus.iter_conversations():\n", + " for utt in convo.get_chronological_utterance_list():\n", + " utt.add_meta('human_guesses', [])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27fc646a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "vn72\n", + "nac86\n", + "sj597\n", + "ex36\n", + "kz88\n", + "LJL2\n", + "lyk25\n", + "cd326\n", + "tg352\n" + ] + } + ], + "source": [ + "if round_n == 1:\n", + " for i, participant_sessions in enumerate(answers_list_1):\n", + " participant_id = names_list_1[i]\n", + " print(participant_id)\n", + " for session in participant_sessions:\n", + " convo_id = session['id']\n", + " convo = corpus.get_conversation(convo_id)\n", + " utts = convo.get_chronological_utterance_list()\n", + " actions = session.get('actions', [])\n", + "\n", + " for action_idx, action in enumerate(actions):\n", + " if action.get('guess') is True:\n", + " utt = utts[action_idx]\n", + " # get current guesses (returns deep copy if exists, or empty list if not)\n", + " # create a new list to avoid mutating the copy\n", + " current_guesses = list(utt.meta.get('human_guesses', []))\n", + " # append new participant and set back\n", + " current_guesses.append(participant_id)\n", + " utt.add_meta('human_guesses', current_guesses)\n", + "elif round_n == 2:\n", + " for i, 
participant_sessions in enumerate(answers_list_2):\n", + " participant_id = names_list_2[i]\n", + " print(participant_id)\n", + " for session in participant_sessions:\n", + " convo_id = session['id']\n", + " convo = corpus.get_conversation(convo_id)\n", + " utts = convo.get_chronological_utterance_list()\n", + " actions = session.get('actions', [])\n", + "\n", + " for action_idx, action in enumerate(actions):\n", + " if action.get('guess') is True:\n", + " utt = utts[action_idx]\n", + " # get current guesses (returns deep copy if exists, or empty list if not)\n", + " # create a new list to avoid mutating the copy\n", + " current_guesses = list(utt.meta.get('human_guesses', []))\n", + " # append new participant and set back\n", + " current_guesses.append(participant_id)\n", + " utt.add_meta('human_guesses', current_guesses)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "fd613b4b", + "metadata": {}, + "outputs": [], + "source": [ + "# initialize model prediction metadata fields for each utterance\n", + "for convo in corpus.iter_conversations():\n", + " for utt in convo.get_chronological_utterance_list():\n", + " utt.add_meta('model_forecast_probs', {}) # will store {seed: prob}\n", + " utt.add_meta('model_forecasts', {}) # will store {seed: binary_forecast}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "873cb706", + "metadata": {}, + "outputs": [], + "source": [ + "# copy model predictions from the corpora loaded above\n", + "loaded_seed_nums = sorted(\n", + " int(corpus_name.rsplit(\"-\", 1)[-1])\n", + " for corpus_name in corpora\n", + " if corpus_name.startswith(\"test-seed-\")\n", + ")\n", + "\n", + "if not loaded_seed_nums:\n", + " available = \", \".join(sorted(corpora))\n", + " raise KeyError(f\"no test seed corpora found; available corpora: {available}\")\n", + "\n", + "for seed_num in loaded_seed_nums:\n", + " seed_key = f\"test-seed-{seed_num}\"\n", + " print(f\"[info] loading predictions from seed 
{seed_num}: {seed_key}\")\n", + " seed_corpus = corpora[seed_key]\n", + "\n", + " for convo in corpus.iter_conversations():\n", + " seed_convo = seed_corpus.get_conversation(convo.id)\n", + " main_utts = convo.get_chronological_utterance_list()\n", + " seed_utts = seed_convo.get_chronological_utterance_list()\n", + "\n", + " for main_utt, seed_utt in zip(main_utts, seed_utts):\n", + " forecast_prob = seed_utt.meta.get(\"forecast_prob\")\n", + " forecast = seed_utt.meta.get(\"forecast\")\n", + "\n", + " if forecast_prob is not None:\n", + " current_probs = dict(main_utt.meta.get(\"model_forecast_probs\", {}))\n", + " current_probs[f\"seed_{seed_num}\"] = forecast_prob\n", + " main_utt.add_meta(\"model_forecast_probs\", current_probs)\n", + "\n", + " if forecast is not None:\n", + " current_forecasts = dict(main_utt.meta.get(\"model_forecasts\", {}))\n", + " current_forecasts[f\"seed_{seed_num}\"] = forecast\n", + " main_utt.add_meta(\"model_forecasts\", current_forecasts)\n", + "\n", + "print(\"\\n[pass] model predictions from all seeds added to corpus\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "5b0f0e44", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "gemma horizon (TPs only, mean - 1):\n", + " mean over seeds: h=nan\n", + " pooled: n=0, h=nan\n", + "human horizon (round 1, TPs only, mean - 1):\n", + " player vn72: n=1, h=6.0000\n", + " player nac86: n=2, h=0.5000\n", + " player sj597: n=4, h=2.2500\n", + " player ex36: n=1, h=2.0000\n", + " player kz88: n=2, h=2.5000\n", + " player LJL2: n=4, h=1.5000\n", + " player lyk25: n=4, h=2.5000\n", + " player cd326: n=2, h=1.5000\n", + " player tg352: n=2, h=5.0000\n", + " mean over players: h=2.6389\n", + " pooled: n=22, h=2.3636\n", + "human horizon (round 2, TPs only, mean - 1):\n", + " player sj597: n=4, h=1.2500\n", + " player nac86: n=2, h=4.0000\n", + " player ex36: n=2, h=0.5000\n", + " player vn72: n=3, h=1.6667\n", + " player 
lyk25: n=4, h=3.0000\n", + " player kz88: n=1, h=3.0000\n", + " player cd326: n=4, h=3.2500\n", + " player tg352: n=3, h=2.0000\n", + " player LJL2: n=2, h=0.5000\n", + " mean over players: h=2.1296\n", + " pooled: n=25, h=2.1600\n", + "\n", + "summary (comparable to h in performance_utils_wiki):\n", + " gemma h (mean over seeds): nan\n", + " round1 h (mean over players): 2.6389\n", + " round2 h (mean over players): 2.1296\n" + ] + } + ], + "source": [ + "# horizon comparison between gemma and humans, using the same convention as\n", + "# performance_utils_wiki.calculate_current_performance:\n", + "# - only count true positives (ground truth awry AND trigger fired)\n", + "# - horizon = (len(utts) - first_trigger_index) - 1\n", + "# - first trigger only (break after first positive)\n", + "#\n", + "# reports per-round results for humans (round 1 and round 2), pulling data\n", + "# directly from the game sqlite so it works regardless of which round_n the\n", + "# rest of the notebook was run under. ljl2 is not excluded (only yc2727 and sqt2 are).\n", + "\n", + "import sqlite3\n", + "import json as _json\n", + "from collections import defaultdict\n", + "import numpy as np\n", + "\n", + "def _horizon_mean(horizons):\n", + " # matches performance_utils_wiki: h = mean(horizons) - 1\n", + " if len(horizons) == 0:\n", + " return float('nan'), 0\n", + " return float(np.mean(horizons)) - 1, len(horizons)\n", + "\n", + "# always excluded (staff/test accounts).
ljl2 is a legitimate late submitter\n", + "# in the round 2_b window, not an exclusion.\n", + "base_excluded = {'yc2727', 'sqt2'}\n", + "round1_excluded = set(base_excluded)\n", + "round2_excluded = set(base_excluded)\n", + "\n", + "# gemma: per-seed first-trigger horizons on TP convos only (ground truth awry)\n", + "gemma_horizons_by_seed = defaultdict(list)\n", + "for convo in corpus.iter_conversations():\n", + " if not convo.meta.get('has_removed_comment', False):\n", + " continue # skip non-awry convos; horizon is only meaningful on TPs\n", + " utts = convo.get_chronological_utterance_list()\n", + " for seed_num in range(1, 6):\n", + " for i, utt in enumerate(utts):\n", + " if utt.meta.get('model_forecasts', {}).get(f'seed_{seed_num}') == 1:\n", + " gemma_horizons_by_seed[seed_num].append(len(utts) - i)\n", + " break\n", + "\n", + "print('gemma horizon (TPs only, mean - 1):')\n", + "gemma_per_seed_h = []\n", + "for seed_num in sorted(gemma_horizons_by_seed.keys()):\n", + " h, n = _horizon_mean(gemma_horizons_by_seed[seed_num])\n", + " gemma_per_seed_h.append(h)\n", + " print(f' seed {seed_num}: n={n}, h={h:.4f}')\n", + "gemma_h_mean_of_seeds = float(np.mean(gemma_per_seed_h)) if gemma_per_seed_h else float('nan')\n", + "all_gemma_vals = [v for vs in gemma_horizons_by_seed.values() for v in vs]\n", + "gemma_h_pooled, gemma_n_pooled = _horizon_mean(all_gemma_vals)\n", + "print(f' mean over seeds: h={gemma_h_mean_of_seeds:.4f}')\n", + "print(f' pooled: n={gemma_n_pooled}, h={gemma_h_pooled:.4f}')\n", + "\n", + "# pull round-specific human guesses directly from the sqlite db\n", + "db_path = '/reef/sqt2/cga-eval/human/game_db.sqlite'\n", + "_conn = sqlite3.connect(db_path)\n", + "\n", + "def _load_round_answers(round_n):\n", + " # returns list of (participant_id, sessions) for the given round\n", + " entries = []\n", + " if round_n == 1:\n", + " excluded = round1_excluded\n", + " t_start, t_end = 1730395000000, 1730397199000\n", + " for row in _conn.execute('SELECT 
* FROM results'):\n", + " if row[0] in excluded:\n", + " continue\n", + " answers = _json.loads(row[1])\n", + " if answers and answers[0]['start_time'] > t_start and answers[0]['start_time'] < t_end:\n", + " entries.append((row[0], answers))\n", + " elif round_n == 2:\n", + " excluded = round2_excluded\n", + " t_start_a, t_end_a = 1731002910000, 1731016180000\n", + " t_start_b, t_end_b = 1732212720000, 1740000000000\n", + " for row in _conn.execute('SELECT * FROM results'):\n", + " name = row[0]\n", + " if name in excluded:\n", + " continue\n", + " answers = _json.loads(row[1])\n", + " if not answers:\n", + " continue\n", + " st = answers[0]['start_time']\n", + " in_a = t_start_a < st < t_end_a\n", + " # the 2_b late window is only valid for ljl2 (matches cell 7)\n", + " in_b = (name.lower() == 'ljl2') and (t_start_b < st < t_end_b)\n", + " if in_a or in_b:\n", + " entries.append((name, answers))\n", + " return entries\n", + "\n", + "def _unique_convo_ids(entries):\n", + " # unique convo ids seen across all included players' sessions\n", + " ids = set()\n", + " for _, sessions in entries:\n", + " for session in sessions:\n", + " cid = session.get('id')\n", + " if cid is not None:\n", + " ids.add(cid)\n", + " return ids\n", + "\n", + "def _compute_human_horizons(entries):\n", + " # returns dict: player_id -> list of horizons (awry convos only, first guess per convo)\n", + " by_player = defaultdict(list)\n", + " for participant_id, sessions in entries:\n", + " for session in sessions:\n", + " convo_id = session.get('id')\n", + " if convo_id is None:\n", + " continue\n", + " try:\n", + " convo = corpus.get_conversation(convo_id)\n", + " except KeyError:\n", + " # convo not in the loaded corpus (e.g., different round loaded)\n", + " continue\n", + " if not convo.meta.get('has_removed_comment', False):\n", + " continue # skip non-awry convos; horizon only counted on TPs\n", + " utts = convo.get_chronological_utterance_list()\n", + " for action_idx, action in
enumerate(session.get('actions', [])):\n", + " if action.get('guess') is True:\n", + " if action_idx < len(utts):\n", + " by_player[participant_id].append(len(utts) - action_idx)\n", + " break # only first guess per (player, convo)\n", + " return by_player\n", + "\n", + "def _print_human_horizons(label, by_player):\n", + " print(f'human horizon ({label}, TPs only, mean - 1):')\n", + " player_h = []\n", + " for player, horizons in by_player.items():\n", + " h, n = _horizon_mean(horizons)\n", + " player_h.append(h)\n", + " print(f' player {player}: n={n}, h={h:.4f}')\n", + " mean_of_players = float(np.mean(player_h)) if player_h else float('nan')\n", + " all_vals = [v for vs in by_player.values() for v in vs]\n", + " h_pooled, n_pooled = _horizon_mean(all_vals)\n", + " print(f' mean over players: h={mean_of_players:.4f}')\n", + " print(f' pooled: n={n_pooled}, h={h_pooled:.4f}')\n", + " return mean_of_players, h_pooled\n", + "\n", + "EXPECTED_N_CONVOS = 84\n", + "round1_entries = _load_round_answers(1)\n", + "round1_convo_ids = _unique_convo_ids(round1_entries)\n", + "round1_by_player = _compute_human_horizons(round1_entries)\n", + "r1_mean, r1_pooled = _print_human_horizons('round 1', round1_by_player)\n", + "\n", + "round2_entries = _load_round_answers(2)\n", + "round2_convo_ids = _unique_convo_ids(round2_entries)\n", + "round2_by_player = _compute_human_horizons(round2_entries)\n", + "r2_mean, r2_pooled = _print_human_horizons('round 2', round2_by_player)\n", + "\n", + "_conn.close()\n", + "\n", + "print()\n", + "print('summary (comparable to h in performance_utils_wiki):')\n", + "print(f' gemma h (mean over seeds): {gemma_h_mean_of_seeds:.4f}')\n", + "print(f' round1 h (mean over players): {r1_mean:.4f}')\n", + "print(f' round2 h (mean over players): {r2_mean:.4f}')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "44ed25ac", + "metadata": {}, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "import numpy as np\n", 
+ "def _compute_metrics(tp, fp, tn, fn):\n", + " total = tp + fp + tn + fn\n", + " accuracy = (tp + tn) / total if total > 0 else 0.0\n", + " precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0\n", + " recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0\n", + " f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0\n", + " fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0\n", + " specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0\n", + " fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0\n", + " return {\n", + " 'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn,\n", + " 'accuracy': accuracy, 'precision': precision, 'recall': recall,\n", + " 'f1': f1, 'fpr': fpr, 'specificity': specificity, 'fnr': fnr,\n", + " }\n", + "def _gt(convo):\n", + " return bool(convo.meta.get('has_removed_comment', False))\n", + "# gemma benchmark, per seed, then mean and std\n", + "gemma_per_seed = []\n", + "for seed_num in range(1, 6):\n", + " tp = fp = tn = fn = 0\n", + " for convo in corpus.iter_conversations():\n", + " utts = convo.get_chronological_utterance_list()\n", + " pred = any(\n", + " utt.meta.get('model_forecasts', {}).get(f'seed_{seed_num}') == 1\n", + " for utt in utts\n", + " )\n", + " truth = _gt(convo)\n", + " if pred and truth: tp += 1\n", + " elif pred and not truth: fp += 1\n", + " elif not pred and not truth: tn += 1\n", + " elif not pred and truth: fn += 1\n", + " gemma_per_seed.append(_compute_metrics(tp, fp, tn, fn))\n", + "gemma_mean = {k: float(np.mean([m[k] for m in gemma_per_seed])) for k in gemma_per_seed[0]}\n", + "gemma_std = {k: float(np.std([m[k] for m in gemma_per_seed], ddof=1)) for k in gemma_per_seed[0]}\n", + "def _build_per_player_preds(entries):\n", + " # returns dict[player_id] -> dict[convo_id] -> 0/1\n", + " preds = defaultdict(dict)\n", + " for player, sessions in entries:\n", + " for s in sessions:\n", + " cid = s.get('id')\n", + " if cid is None:\n", + " continue\n", + " try:\n", + " corpus.get_conversation(cid)\n", + " 
except KeyError:\n", + " continue\n", + " actions = s.get('actions', [])\n", + " preds[player][cid] = 1 if any(a.get('guess') is True for a in actions) else 0\n", + " return preds\n", + "def _aggregate_any(preds):\n", + " # per-convo prediction = 1 if any player who saw it guessed\n", + " by_convo = defaultdict(list)\n", + " for player, convo_preds in preds.items():\n", + " for cid, p in convo_preds.items():\n", + " by_convo[cid].append(p)\n", + " tp = fp = tn = fn = 0\n", + " for cid, plist in by_convo.items():\n", + " try:\n", + " convo = corpus.get_conversation(cid)\n", + " except KeyError:\n", + " continue\n", + " pred = 1 if any(plist) else 0\n", + " truth = 1 if _gt(convo) else 0\n", + " if pred and truth: tp += 1\n", + " elif pred and not truth: fp += 1\n", + " elif not pred and not truth: tn += 1\n", + " elif not pred and truth: fn += 1\n", + " return _compute_metrics(tp, fp, tn, fn), len(by_convo)\n", + "def _per_player_metrics(preds):\n", + " rows = {}\n", + " for player, convo_preds in preds.items():\n", + " tp = fp = tn = fn = 0\n", + " for cid, p in convo_preds.items():\n", + " try:\n", + " convo = corpus.get_conversation(cid)\n", + " except KeyError:\n", + " continue\n", + " truth = 1 if _gt(convo) else 0\n", + " if p and truth: tp += 1\n", + " elif p and not truth: fp += 1\n", + " elif not p and not truth: tn += 1\n", + " elif not p and truth: fn += 1\n", + " rows[player] = _compute_metrics(tp, fp, tn, fn)\n", + " return rows\n", + "# round 1 and round 2 humans\n", + "round1_preds = _build_per_player_preds(round1_entries)\n", + "round1_agg, round1_n_convos = _aggregate_any(round1_preds)\n", + "round1_per_player = _per_player_metrics(round1_preds)\n", + "round2_preds = _build_per_player_preds(round2_entries)\n", + "round2_agg, round2_n_convos = _aggregate_any(round2_preds)\n", + "round2_per_player = _per_player_metrics(round2_preds)\n", + "# metric mean over players (a different aggregation view)\n", + "def _mean_over_players(per_player):\n", + " keys 
= ['accuracy', 'precision', 'recall', 'f1', 'fpr', 'specificity', 'fnr']\n", + " return {k: float(np.mean([m[k] for m in per_player.values()])) for k in keys}\n", + "round1_mean_over_players = _mean_over_players(round1_per_player)\n", + "round2_mean_over_players = _mean_over_players(round2_per_player)\n", + "# printing\n", + "metric_order = ['accuracy', 'precision', 'recall', 'f1', 'fpr', 'specificity', 'fnr']\n", + "def _print_per_player(label, per_player):\n", + " print(f'{label} - per-player metrics:')\n", + " header = f\" {'player':<10}\" + \"\".join(f\"{m:>12}\" for m in metric_order) + f\"{'n_convos':>10}\"\n", + " print(header)\n", + " print(' ' + '-' * (len(header) - 2))\n", + " for player in sorted(per_player):\n", + " m = per_player[player]\n", + " n = m['tp'] + m['fp'] + m['tn'] + m['fn']\n", + " row = f\" {player:<10}\" + \"\".join(f\"{m[k]:>12.4f}\" for k in metric_order) + f\"{n:>10d}\"\n", + " print(row)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "68c5005a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "round 1 - per-player metrics:\n", + " player accuracy precision recall f1 fpr specificity fnr n_convos\n", + " --------------------------------------------------------------------------------------------------------\n", + " LJL2 0.8000 0.8000 0.8000 0.8000 0.2000 0.8000 0.2000 10\n", + " cd326 0.6000 0.6667 0.4000 0.5000 0.2000 0.8000 0.6000 10\n", + " ex36 0.4000 0.3333 0.2000 0.2500 0.4000 0.6000 0.8000 10\n", + " kz88 0.7000 1.0000 0.4000 0.5714 0.0000 1.0000 0.6000 10\n", + " lyk25 0.7000 0.6667 0.8000 0.7273 0.4000 0.6000 0.2000 10\n", + " nac86 0.7000 1.0000 0.4000 0.5714 0.0000 1.0000 0.6000 10\n", + " sj597 0.8000 0.8000 0.8000 0.8000 0.2000 0.8000 0.2000 10\n", + " tg352 0.5000 0.5000 0.4000 0.4444 0.4000 0.6000 0.6000 10\n", + " vn72 0.4000 0.3333 0.2000 0.2500 0.4000 0.6000 0.8000 10\n", + "\n", + "round 2 - per-player metrics:\n", + " player accuracy 
precision recall f1 fpr specificity fnr n_convos\n", + " --------------------------------------------------------------------------------------------------------\n", + " LJL2 0.6000 0.6667 0.4000 0.5000 0.2000 0.8000 0.6000 10\n", + " cd326 0.9000 1.0000 0.8000 0.8889 0.0000 1.0000 0.2000 10\n", + " ex36 0.6000 0.6667 0.4000 0.5000 0.2000 0.8000 0.6000 10\n", + " kz88 0.4000 0.3333 0.2000 0.2500 0.4000 0.6000 0.8000 10\n", + " lyk25 0.9000 1.0000 0.8000 0.8889 0.0000 1.0000 0.2000 10\n", + " nac86 0.6000 0.6667 0.4000 0.5000 0.2000 0.8000 0.6000 10\n", + " sj597 0.9000 1.0000 0.8000 0.8889 0.0000 1.0000 0.2000 10\n", + " tg352 0.7000 0.7500 0.6000 0.6667 0.2000 0.8000 0.4000 10\n", + " vn72 0.7000 0.7500 0.6000 0.6667 0.2000 0.8000 0.4000 10\n", + "\n", + "aggregate benchmark (gemma: mean±std over 5 seeds; humans: any-player rule + mean over players)\n", + "group accuracy precision recall f1 fpr specificity fnr\n", + "--------------------------------------------------------------------------------------------------------------------------------------\n", + "gemma (mean±std over seeds) 0.700±0.018 0.679±0.026 0.762±0.029 0.718±0.012 0.362±0.052 0.638±0.052 0.238±0.029\n", + "round1 humans (mean over players) 0.6222 0.6778 0.4889 0.5461 0.2444 0.7556 0.5111\n", + "round2 humans (mean over players) 0.7000 0.7593 0.5556 0.6389 0.1556 0.8444 0.4444\n", + "\n", + "round 1 unique convos: 84\n", + "round 2 unique convos: 84\n" + ] + } + ], + "source": [ + "pm = '\\u00b1'\n", + "metric_order = ['accuracy', 'precision', 'recall', 'f1', 'fpr', 'specificity', 'fnr']\n", + "\n", + "def _print_per_player(label, per_player):\n", + " print(f'{label} - per-player metrics:')\n", + " header = f\" {'player':<10}\" + \"\".join(f\"{m:>12}\" for m in metric_order) + f\"{'n_convos':>10}\"\n", + " print(header)\n", + " print(' ' + '-' * (len(header) - 2))\n", + " for player in sorted(per_player):\n", + " m = per_player[player]\n", + " n = m['tp'] + m['fp'] + m['tn'] + m['fn']\n", + " row = 
f\" {player:<10}\" + \"\".join(f\"{m[k]:>12.4f}\" for k in metric_order) + f\"{n:>10d}\"\n", + " print(row)\n", + "\n", + "_print_per_player('round 1', round1_per_player)\n", + "print()\n", + "_print_per_player('round 2', round2_per_player)\n", + "\n", + "print()\n", + "print(f'aggregate benchmark (gemma: mean{pm}std over 5 seeds; humans: any-player rule + mean over players)')\n", + "\n", + "header = f\"{'group':<36}\" + \"\".join(f\"{m:>14}\" for m in metric_order)\n", + "print(header)\n", + "print('-' * len(header))\n", + "\n", + "def _fmt_mean_std(mean_dict, std_dict):\n", + " return \"\".join(f\" {mean_dict[k]:.3f}{pm}{std_dict[k]:.3f}\" for k in metric_order)\n", + "\n", + "def _fmt_mean(mean_dict):\n", + " return \"\".join(f\"{mean_dict[k]:>14.4f}\" for k in metric_order)\n", + "\n", + "gemma_label = f'gemma (mean{pm}std over seeds)'\n", + "print(f\"{gemma_label:<36}\" + _fmt_mean_std(gemma_mean, gemma_std))\n", + "print(f\"{'round1 humans (mean over players)':<36}\" + _fmt_mean(round1_mean_over_players))\n", + "print(f\"{'round2 humans (mean over players)':<36}\" + _fmt_mean(round2_mean_over_players))\n", + "\n", + "print()\n", + "print(f'round 1 unique convos: {round1_n_convos}')\n", + "print(f'round 2 unique convos: {round2_n_convos}')" + ] + }, + { + "cell_type": "markdown", + "id": "9c68d1bd", + "metadata": {}, + "source": [ + "## Validation of forecast probability decrease" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "5e527737", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pooled trigger_all current perf, best threshold per seed: 0.5532229185317815 (12359/22340)\n", + "pooled delayed_all best threshold per seed, k=7: 0.8353165159447882 (3510/4202)\n" + ] + } + ], + "source": [ + "import json\n", + "from pathlib import Path\n", + "from convokit import Corpus\n", + "\n", + "corpus_base = Path(\"/reef/lyk25/dynamic_training/game_analysis/corpi/test\")\n", + "config_base = 
Path(\"/reef/sqt2/FinalAAO/cga-cmv-large/google/gemma-2-9b-it\")\n", + "\n", + "corpus_dirs = sorted(\n", + " corpus_dir\n", + " for corpus_dir in corpus_base.rglob(\"*\")\n", + " if corpus_dir.is_dir() and (corpus_dir / \"index.json\").exists()\n", + ")\n", + "\n", + "if not corpus_dirs:\n", + " raise FileNotFoundError(f\"no convokit corpora found under {corpus_base}\")\n", + "\n", + "\n", + "def _test_seed_key(corpus_dir):\n", + " parent_name = corpus_dir.parent.name\n", + " if parent_name.startswith(\"test-seed-\"):\n", + " return parent_name\n", + " seed_idx = corpus_dir.name.rsplit(\"-\", 1)[-1]\n", + " return f\"test-seed-{seed_idx}\"\n", + "\n", + "\n", + "corpus_path_by_key = {}\n", + "for corpus_dir in corpus_dirs:\n", + " key = _test_seed_key(corpus_dir)\n", + " corpus_path_by_key[key] = corpus_dir\n", + "\n", + "\n", + "def is_calm_sim_forecast(forecast):\n", + " if isinstance(forecast, str):\n", + " forecast = forecast.strip().lower()\n", + " if forecast in {\"0\", \"false\", \"calm\"}:\n", + " return True\n", + " if forecast in {\"1\", \"true\", \"awry\"}:\n", + " return False\n", + " return int(forecast) == 0\n", + "\n", + "\n", + "def count_decreases_at_indices(corpus, get_indices):\n", + " total = 0\n", + " reduction = 0\n", + "\n", + " for convo in corpus.iter_conversations():\n", + " utts = convo.get_chronological_utterance_list()\n", + "\n", + " for i in get_indices(utts):\n", + " if i + 1 >= len(utts):\n", + " continue\n", + "\n", + " f1 = utts[i].meta[\"forecast_prob\"]\n", + " f2 = utts[i + 1].meta[\"forecast_prob\"]\n", + " total += 1\n", + "\n", + " if f1 > f2:\n", + " reduction += 1\n", + "\n", + " return reduction, total\n", + "\n", + "\n", + "def current_trigger_indices(utts, pred_threshold):\n", + " return [\n", + " i\n", + " for i, utt in enumerate(utts)\n", + " if utt.meta[\"forecast_prob\"] > pred_threshold\n", + " ]\n", + "\n", + "\n", + "def delayed_indices_from_sim_forecasts(utts, pred_threshold, k):\n", + " delayed = []\n", + 
"\n", + " for i, utt in enumerate(utts):\n", + " if utt.meta[\"forecast_prob\"] <= pred_threshold:\n", + " continue\n", + "\n", + " sim_forecasts = utt.meta[\"sim_replies_forecasts\"]\n", + " calm_sim_replies = sum(\n", + " 1 for forecast in sim_forecasts if is_calm_sim_forecast(forecast)\n", + " )\n", + "\n", + " if calm_sim_replies > k:\n", + " delayed.append(i)\n", + "\n", + " return delayed\n", + "\n", + "\n", + "current_reduction = 0\n", + "current_total = 0\n", + "delayed_reduction = 0\n", + "delayed_total = 0\n", + "\n", + "for seed in range(1, 6):\n", + " seed_key = f\"test-seed-{seed}\"\n", + " corpus = Corpus(filename=str(corpus_path_by_key[seed_key]))\n", + "\n", + " with open(config_base / f\"seed-{seed}\" / \"dev_config.json\") as f:\n", + " best_threshold = json.load(f)[\"best_threshold\"]\n", + "\n", + " r, t = count_decreases_at_indices(\n", + " corpus,\n", + " lambda utts: current_trigger_indices(utts, best_threshold),\n", + " )\n", + " current_reduction += r\n", + " current_total += t\n", + "\n", + " r, t = count_decreases_at_indices(\n", + " corpus,\n", + " lambda utts: delayed_indices_from_sim_forecasts(utts, best_threshold, k=7),\n", + " )\n", + " delayed_reduction += r\n", + " delayed_total += t\n", + "\n", + "pooled_trigger_all = current_reduction / current_total\n", + "pooled_delayed_all = delayed_reduction / delayed_total\n", + "\n", + "print(f\"pooled trigger_all current perf, best threshold per seed: {pooled_trigger_all} ({current_reduction}/{current_total})\")\n", + "print(f\"pooled delayed_all best threshold per seed, k=7: {pooled_delayed_all} ({delayed_reduction}/{delayed_total})\")" + ] + }, + { + "cell_type": "markdown", + "id": "cb154714", + "metadata": {}, + "source": [ + "## Calculating oracle threshold" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "cba93b72", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] using 5 test-seed corpora: test-seed-1, 
test-seed-2, test-seed-3, test-seed-4, test-seed-5\n" + ] + } + ], + "source": [ + "# corpora comes from the decisionpolicy-demo download cell; use only test-seed- entries.\n", + "import re\n", + "\n", + "if \"corpora\" not in globals() or not corpora:\n", + " raise RuntimeError(\n", + " \"corpora is empty: run the download cell above first.\"\n", + " )\n", + "\n", + "_prev_keys = tuple(sorted(corpora))\n", + "_test_seed_re = re.compile(r\"^test-seed-(\\d+)$\")\n", + "\n", + "corpora = {\n", + " k: corpora[k]\n", + " for k in sorted(\n", + " (k for k in _prev_keys if _test_seed_re.match(k)),\n", + " key=lambda k: int(_test_seed_re.match(k).group(1)),\n", + " )\n", + "}\n", + "\n", + "if not corpora:\n", + " raise RuntimeError(\n", + " \"no test-seed- corpora after filter; had keys: \"\n", + " + \", \".join(_prev_keys)\n", + " )\n", + "\n", + "print(\n", + " f\"[info] using {len(corpora)} test-seed corpora: \"\n", + " + \", \".join(corpora)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "3be18b5d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] per-seed best thresholds: {'test-seed-1': 0.5926666259765625, 'test-seed-2': 0.622459352016449, 'test-seed-3': 0.6513549089431763, 'test-seed-4': 0.6513549089431763, 'test-seed-5': 0.6513549089431763}\n", + "[info] mean best threshold: 0.633838\n", + "[info] generating baseline roc curve with 400 thresholds in [0.483838, 0.783838]\n", + "[info] testing k values: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import json\n", + "\n", + "seed_best_thresholds = {\n", + " \"test-seed-1\": 0.5926666259765625,\n", + " \"test-seed-2\": 0.622459352016449,\n", + " \"test-seed-3\": 0.6513549089431763,\n", + " \"test-seed-4\": 0.6513549089431763,\n", + " \"test-seed-5\": 0.6513549089431763,\n", + "}\n", + "\n", + "mean_best_threshold = 
float(np.mean(list(seed_best_thresholds.values())))\n", + "search_radius = 0.15\n", + "num_thresholds = 400\n", + "baseline_thresholds = np.linspace(\n", + " mean_best_threshold - search_radius,\n", + " mean_best_threshold + search_radius,\n", + " num_thresholds,\n", + ")\n", + "\n", + "print(f\"[info] per-seed best thresholds: {seed_best_thresholds}\")\n", + "print(f\"[info] mean best threshold: {mean_best_threshold:.6f}\")\n", + "print(f\"[info] generating baseline roc curve with {len(baseline_thresholds)} thresholds in [{baseline_thresholds[0]:.6f}, {baseline_thresholds[-1]:.6f}]\")\n", + "\n", + "# k values to test for our method\n", + "k_values = list(range(1, 11)) # 1 to 10\n", + "print(f\"[info] testing k values: {k_values}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f6128fef", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] computing baseline roc curves for all seeds...\n", + "\n", + "[info] processing test-son-seed-1...\n", + "[info] test-son-seed-1 baseline progress: 1/400\n", + "[info] test-son-seed-1 baseline progress: 21/400\n", + "[info] test-son-seed-1 baseline progress: 41/400\n", + "[info] test-son-seed-1 baseline progress: 61/400\n", + "[info] test-son-seed-1 baseline progress: 81/400\n", + "[info] test-son-seed-1 baseline progress: 101/400\n", + "[info] test-son-seed-1 baseline progress: 121/400\n", + "[info] test-son-seed-1 baseline progress: 141/400\n", + "[info] test-son-seed-1 baseline progress: 161/400\n", + "[info] test-son-seed-1 baseline progress: 181/400\n", + "[info] test-son-seed-1 baseline progress: 201/400\n", + "[info] test-son-seed-1 baseline progress: 221/400\n", + "[info] test-son-seed-1 baseline progress: 241/400\n", + "[info] test-son-seed-1 baseline progress: 261/400\n", + "[info] test-son-seed-1 baseline progress: 281/400\n", + "[info] test-son-seed-1 baseline progress: 301/400\n", + "[info] test-son-seed-1 baseline progress: 321/400\n", 
+ "[info] test-son-seed-1 baseline progress: 341/400\n", + "[info] test-son-seed-1 baseline progress: 361/400\n", + "[info] test-son-seed-1 baseline progress: 381/400\n", + "[pass] test-son-seed-1 baseline complete - 400 points\n", + "\n", + "[info] processing test-son-seed-2...\n", + "[info] test-son-seed-2 baseline progress: 1/400\n", + "[info] test-son-seed-2 baseline progress: 21/400\n", + "[info] test-son-seed-2 baseline progress: 41/400\n", + "[info] test-son-seed-2 baseline progress: 61/400\n", + "[info] test-son-seed-2 baseline progress: 81/400\n", + "[info] test-son-seed-2 baseline progress: 101/400\n", + "[info] test-son-seed-2 baseline progress: 121/400\n", + "[info] test-son-seed-2 baseline progress: 141/400\n", + "[info] test-son-seed-2 baseline progress: 161/400\n", + "[info] test-son-seed-2 baseline progress: 181/400\n", + "[info] test-son-seed-2 baseline progress: 201/400\n", + "[info] test-son-seed-2 baseline progress: 221/400\n", + "[info] test-son-seed-2 baseline progress: 241/400\n", + "[info] test-son-seed-2 baseline progress: 261/400\n", + "[info] test-son-seed-2 baseline progress: 281/400\n", + "[info] test-son-seed-2 baseline progress: 301/400\n", + "[info] test-son-seed-2 baseline progress: 321/400\n", + "[info] test-son-seed-2 baseline progress: 341/400\n", + "[info] test-son-seed-2 baseline progress: 361/400\n", + "[info] test-son-seed-2 baseline progress: 381/400\n", + "[pass] test-son-seed-2 baseline complete - 400 points\n", + "\n", + "[info] processing test-son-seed-3...\n", + "[info] test-son-seed-3 baseline progress: 1/400\n", + "[info] test-son-seed-3 baseline progress: 21/400\n", + "[info] test-son-seed-3 baseline progress: 41/400\n", + "[info] test-son-seed-3 baseline progress: 61/400\n", + "[info] test-son-seed-3 baseline progress: 81/400\n", + "[info] test-son-seed-3 baseline progress: 101/400\n", + "[info] test-son-seed-3 baseline progress: 121/400\n", + "[info] test-son-seed-3 baseline progress: 141/400\n", + "[info] 
test-son-seed-3 baseline progress: 161/400\n", + "[info] test-son-seed-3 baseline progress: 181/400\n", + "[info] test-son-seed-3 baseline progress: 201/400\n", + "[info] test-son-seed-3 baseline progress: 221/400\n", + "[info] test-son-seed-3 baseline progress: 241/400\n", + "[info] test-son-seed-3 baseline progress: 261/400\n", + "[info] test-son-seed-3 baseline progress: 281/400\n", + "[info] test-son-seed-3 baseline progress: 301/400\n", + "[info] test-son-seed-3 baseline progress: 321/400\n", + "[info] test-son-seed-3 baseline progress: 341/400\n", + "[info] test-son-seed-3 baseline progress: 361/400\n", + "[info] test-son-seed-3 baseline progress: 381/400\n", + "[pass] test-son-seed-3 baseline complete - 400 points\n", + "\n", + "[info] processing test-son-seed-4...\n", + "[info] test-son-seed-4 baseline progress: 1/400\n", + "[info] test-son-seed-4 baseline progress: 21/400\n", + "[info] test-son-seed-4 baseline progress: 41/400\n", + "[info] test-son-seed-4 baseline progress: 61/400\n", + "[info] test-son-seed-4 baseline progress: 81/400\n", + "[info] test-son-seed-4 baseline progress: 101/400\n", + "[info] test-son-seed-4 baseline progress: 121/400\n", + "[info] test-son-seed-4 baseline progress: 141/400\n", + "[info] test-son-seed-4 baseline progress: 161/400\n", + "[info] test-son-seed-4 baseline progress: 181/400\n", + "[info] test-son-seed-4 baseline progress: 201/400\n", + "[info] test-son-seed-4 baseline progress: 221/400\n", + "[info] test-son-seed-4 baseline progress: 241/400\n", + "[info] test-son-seed-4 baseline progress: 261/400\n", + "[info] test-son-seed-4 baseline progress: 281/400\n", + "[info] test-son-seed-4 baseline progress: 301/400\n", + "[info] test-son-seed-4 baseline progress: 321/400\n", + "[info] test-son-seed-4 baseline progress: 341/400\n", + "[info] test-son-seed-4 baseline progress: 361/400\n", + "[info] test-son-seed-4 baseline progress: 381/400\n", + "[pass] test-son-seed-4 baseline complete - 400 points\n", + "\n", + "[info] 
processing test-son-seed-5...\n", + "[info] test-son-seed-5 baseline progress: 1/400\n", + "[info] test-son-seed-5 baseline progress: 21/400\n", + "[info] test-son-seed-5 baseline progress: 41/400\n", + "[info] test-son-seed-5 baseline progress: 61/400\n", + "[info] test-son-seed-5 baseline progress: 81/400\n", + "[info] test-son-seed-5 baseline progress: 101/400\n", + "[info] test-son-seed-5 baseline progress: 121/400\n", + "[info] test-son-seed-5 baseline progress: 141/400\n", + "[info] test-son-seed-5 baseline progress: 161/400\n", + "[info] test-son-seed-5 baseline progress: 181/400\n", + "[info] test-son-seed-5 baseline progress: 201/400\n", + "[info] test-son-seed-5 baseline progress: 221/400\n", + "[info] test-son-seed-5 baseline progress: 241/400\n", + "[info] test-son-seed-5 baseline progress: 261/400\n", + "[info] test-son-seed-5 baseline progress: 281/400\n", + "[info] test-son-seed-5 baseline progress: 301/400\n", + "[info] test-son-seed-5 baseline progress: 321/400\n", + "[info] test-son-seed-5 baseline progress: 341/400\n", + "[info] test-son-seed-5 baseline progress: 361/400\n", + "[info] test-son-seed-5 baseline progress: 381/400\n", + "[pass] test-son-seed-5 baseline complete - 400 points\n", + "\n", + "[pass] all baseline roc curves computed for 5 seeds\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from convokit.decisionpolicy import ThresholdDecisionPolicy\n", + "from convokit.forecaster.forecaster import ContextTuple\n", + "\n", + "\n", + "def _cached_forecast_score(context):\n", + " meta = getattr(context.current_utterance, \"meta\", {}) or {}\n", + " if \"forecast_prob\" not in meta:\n", + " raise KeyError(f\"missing forecast_prob for utterance {context.current_utterance.id}\")\n", + " return meta[\"forecast_prob\"]\n", + "\n", + "\n", + "def _threshold_policy_performance(corpus, policy):\n", + " tp = fp = tn = fn = 0\n", + " horizons = []\n", + "\n", + " for convo in corpus.iter_conversations():\n", + " 
utts = convo.get_chronological_utterance_list()\n", + " pred = 0\n", + " first_trigger_idx = None\n", + "\n", + " for utt_idx, utt in enumerate(utts):\n", + " context = ContextTuple(\n", + " context=utts[: utt_idx + 1],\n", + " current_utterance=utt,\n", + " future_context=utts[utt_idx + 1 :],\n", + " conversation_id=convo.id,\n", + " )\n", + " _, utt_pred = policy.decide(context, _cached_forecast_score)\n", + " if int(utt_pred) == 1:\n", + " pred = 1\n", + " first_trigger_idx = utt_idx\n", + " break\n", + "\n", + " truth = bool(convo.meta.get(\"has_removed_comment\", False))\n", + " if pred and truth:\n", + " tp += 1\n", + " horizons.append(len(utts) - first_trigger_idx)\n", + " elif pred and not truth:\n", + " fp += 1\n", + " elif not pred and not truth:\n", + " tn += 1\n", + " else:\n", + " fn += 1\n", + "\n", + " h = float(np.mean(horizons)) - 1 if horizons else float(\"nan\")\n", + " return {\n", + " \"confusion_matrix\": {\"TP\": tp, \"FP\": fp, \"TN\": tn, \"FN\": fn},\n", + " \"h\": h,\n", + " }\n", + "\n", + "\n", + "# step 1: compute baseline roc curve with ThresholdDecisionPolicy\n", + "print(\"[info] computing baseline roc curves for all seeds...\")\n", + "all_baseline_results = {}\n", + "\n", + "for seed_name, corpus in corpora.items():\n", + " print(f\"\\n[info] processing {seed_name}...\")\n", + " baseline_results = []\n", + "\n", + " for i, threshold in enumerate(baseline_thresholds):\n", + " if i % 20 == 0:\n", + " print(f\"[info] {seed_name} baseline progress: {i+1}/{len(baseline_thresholds)}\")\n", + "\n", + " policy = ThresholdDecisionPolicy(threshold=threshold)\n", + " results = _threshold_policy_performance(corpus, policy)\n", + " tp = results[\"confusion_matrix\"][\"TP\"]\n", + " fp = results[\"confusion_matrix\"][\"FP\"]\n", + " tn = results[\"confusion_matrix\"][\"TN\"]\n", + " fn = results[\"confusion_matrix\"][\"FN\"]\n", + "\n", + " # calculate tpr, fpr, accuracy, precision, recall, f1\n", + " # recall == tpr; kept as a separate field 
for downstream readability\n", + " tpr = tp / (tp + fn) if (tp + fn) > 0 else 0\n", + " fpr = fp / (fp + tn) if (fp + tn) > 0 else 0\n", + " total = tp + fp + tn + fn\n", + " accuracy = (tp + tn) / total if total > 0 else 0\n", + " precision = tp / (tp + fp) if (tp + fp) > 0 else 0\n", + " recall = tpr\n", + " f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n", + "\n", + " baseline_results.append({\n", + " \"seed\": seed_name,\n", + " \"threshold\": threshold,\n", + " \"tpr\": tpr,\n", + " \"fpr\": fpr,\n", + " \"accuracy\": accuracy,\n", + " \"precision\": precision,\n", + " \"recall\": recall,\n", + " \"f1\": f1,\n", + " \"h\": results.get(\"h\", float(\"nan\")),\n", + " \"tp\": tp,\n", + " \"fp\": fp,\n", + " \"tn\": tn,\n", + " \"fn\": fn,\n", + " })\n", + "\n", + " all_baseline_results[seed_name] = pd.DataFrame(baseline_results)\n", + " print(f\"[pass] {seed_name} baseline complete - {len(all_baseline_results[seed_name])} points\")\n", + "\n", + "print(f\"\\n[pass] all baseline roc curves computed for {len(all_baseline_results)} seeds\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "022d240f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] all_baseline_results has been pickled to all_baseline_results.pkl\n" + ] + } + ], + "source": [ + "import pickle\n", + "\n", + "with open('all_baseline_results.pkl', 'wb') as f:\n", + " pickle.dump(all_baseline_results, f)\n", + "print(\"[info] all_baseline_results has been pickled to all_baseline_results.pkl\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "aa6450a5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "[info] processing test-son-seed-1...\n", + "[info] config seed: seed-1\n", + "[info] k method threshold 0.5926666259765625\n", + "[info] test-son-seed-1 computing k=1...\n", + "[info] test-son-seed-1 
computing k=2...\n", + "[info] test-son-seed-1 computing k=3...\n", + "[info] test-son-seed-1 computing k=4...\n", + "[info] test-son-seed-1 computing k=5...\n", + "[info] test-son-seed-1 computing k=6...\n", + "[info] test-son-seed-1 computing k=7...\n", + "[info] test-son-seed-1 computing k=8...\n", + "[info] test-son-seed-1 computing k=9...\n", + "[info] test-son-seed-1 computing k=10...\n", + "[pass] test-son-seed-1 k-method complete - 10 points\n", + "\n", + "[info] processing test-son-seed-2...\n", + "[info] config seed: seed-2\n", + "[info] k method threshold 0.622459352016449\n", + "[info] test-son-seed-2 computing k=1...\n", + "[info] test-son-seed-2 computing k=2...\n", + "[info] test-son-seed-2 computing k=3...\n", + "[info] test-son-seed-2 computing k=4...\n", + "[info] test-son-seed-2 computing k=5...\n", + "[info] test-son-seed-2 computing k=6...\n", + "[info] test-son-seed-2 computing k=7...\n", + "[info] test-son-seed-2 computing k=8...\n", + "[info] test-son-seed-2 computing k=9...\n", + "[info] test-son-seed-2 computing k=10...\n", + "[pass] test-son-seed-2 k-method complete - 10 points\n", + "\n", + "[info] processing test-son-seed-3...\n", + "[info] config seed: seed-3\n", + "[info] k method threshold 0.6513549089431763\n", + "[info] test-son-seed-3 computing k=1...\n", + "[info] test-son-seed-3 computing k=2...\n", + "[info] test-son-seed-3 computing k=3...\n", + "[info] test-son-seed-3 computing k=4...\n", + "[info] test-son-seed-3 computing k=5...\n", + "[info] test-son-seed-3 computing k=6...\n", + "[info] test-son-seed-3 computing k=7...\n", + "[info] test-son-seed-3 computing k=8...\n", + "[info] test-son-seed-3 computing k=9...\n", + "[info] test-son-seed-3 computing k=10...\n", + "[pass] test-son-seed-3 k-method complete - 10 points\n", + "\n", + "[info] processing test-son-seed-4...\n", + "[info] config seed: seed-4\n", + "[info] k method threshold 0.6513549089431763\n", + "[info] test-son-seed-4 computing k=1...\n", + "[info] 
test-son-seed-4 computing k=2...\n", + "[info] test-son-seed-4 computing k=3...\n", + "[info] test-son-seed-4 computing k=4...\n", + "[info] test-son-seed-4 computing k=5...\n", + "[info] test-son-seed-4 computing k=6...\n", + "[info] test-son-seed-4 computing k=7...\n", + "[info] test-son-seed-4 computing k=8...\n", + "[info] test-son-seed-4 computing k=9...\n", + "[info] test-son-seed-4 computing k=10...\n", + "[pass] test-son-seed-4 k-method complete - 10 points\n", + "\n", + "[info] processing test-son-seed-5...\n", + "[info] config seed: seed-5\n", + "[info] k method threshold 0.6513549089431763\n", + "[info] test-son-seed-5 computing k=1...\n", + "[info] test-son-seed-5 computing k=2...\n", + "[info] test-son-seed-5 computing k=3...\n", + "[info] test-son-seed-5 computing k=4...\n", + "[info] test-son-seed-5 computing k=5...\n", + "[info] test-son-seed-5 computing k=6...\n", + "[info] test-son-seed-5 computing k=7...\n", + "[info] test-son-seed-5 computing k=8...\n", + "[info] test-son-seed-5 computing k=9...\n", + "[info] test-son-seed-5 computing k=10...\n", + "[pass] test-son-seed-5 k-method complete - 10 points\n", + "\n", + "[pass] all k-method results computed for 5 seeds\n" + ] + } + ], + "source": [ + "import json\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from convokit.decisionpolicy import DeferralDecisionPolicy\n", + "from convokit.forecaster.forecaster import ContextTuple\n", + "\n", + "\n", + "def _cached_forecast_score(context):\n", + " meta = getattr(context.current_utterance, \"meta\", {}) or {}\n", + " if \"forecast_prob\" not in meta:\n", + " raise KeyError(f\"missing forecast_prob for utterance {context.current_utterance.id}\")\n", + " return meta[\"forecast_prob\"]\n", + "\n", + "\n", + "def _deferral_policy_performance(corpus, policy):\n", + " tp = fp = tn = fn = 0\n", + " horizons = []\n", + "\n", + " for convo in corpus.iter_conversations():\n", + " utts = convo.get_chronological_utterance_list()\n", + " pred = 
0\n", + " first_trigger_idx = None\n", + "\n", + " for utt_idx, utt in enumerate(utts):\n", + " context = ContextTuple(\n", + " context=utts[: utt_idx + 1],\n", + " current_utterance=utt,\n", + " future_context=utts[utt_idx + 1 :],\n", + " conversation_id=convo.id,\n", + " )\n", + " result = policy.decide(context, _cached_forecast_score)\n", + " utt_pred = int(result[1])\n", + " if utt_pred == 1:\n", + " pred = 1\n", + " first_trigger_idx = utt_idx\n", + " break\n", + "\n", + " truth = bool(convo.meta.get(\"has_removed_comment\", False))\n", + " if pred and truth:\n", + " tp += 1\n", + " horizons.append(len(utts) - first_trigger_idx)\n", + " elif pred and not truth:\n", + " fp += 1\n", + " elif not pred and not truth:\n", + " tn += 1\n", + " else:\n", + " fn += 1\n", + "\n", + " h = float(np.mean(horizons)) - 1 if horizons else float(\"nan\")\n", + " return {\n", + " \"confusion_matrix\": {\"TP\": tp, \"FP\": fp, \"TN\": tn, \"FN\": fn},\n", + " \"h\": h,\n", + " }\n", + "\n", + "\n", + "all_k_results = {}\n", + "\n", + "for seed_name, corpus in corpora.items():\n", + " print(f\"\\n[info] processing {seed_name}...\")\n", + " k_results = []\n", + "\n", + " pruned_seed_name = seed_name[len(seed_name) - 6 :]\n", + " print(f\"[info] config seed: {pruned_seed_name}\")\n", + "\n", + " with open(f\"/reef/sqt2/FinalAAO/cga-cmv-large/google/gemma-2-9b-it/{pruned_seed_name}/dev_config.json\", \"r\") as f:\n", + " dev_config = json.load(f)\n", + "\n", + " k_method_threshold = dev_config[\"best_threshold\"]\n", + " print(\"[info] k method threshold\", k_method_threshold)\n", + "\n", + " for k in k_values:\n", + " print(f\"[info] {seed_name} computing k={k}...\")\n", + " policy = DeferralDecisionPolicy(\n", + " simulator=None,\n", + " threshold=k_method_threshold,\n", + " tau=k,\n", + " num_simulations=10,\n", + " store_simulations=False,\n", + " simulated_reply_attribute_name=\"sim_replies\",\n", + " sim_replies_forecast_probs_attribute_name=\"sim_replies_forecast_probs\",\n", 
+ " reuse_cached_simulations=True,\n", + " )\n", + " results = _deferral_policy_performance(corpus, policy)\n", + " tp = results[\"confusion_matrix\"][\"TP\"]\n", + " fp = results[\"confusion_matrix\"][\"FP\"]\n", + " tn = results[\"confusion_matrix\"][\"TN\"]\n", + " fn = results[\"confusion_matrix\"][\"FN\"]\n", + "\n", + " # calculate tpr, fpr, accuracy, precision, recall, f1\n", + " # recall == tpr; kept as a separate field for downstream readability\n", + " tpr = tp / (tp + fn) if (tp + fn) > 0 else 0\n", + " fpr = fp / (fp + tn) if (fp + tn) > 0 else 0\n", + " total = tp + fp + tn + fn\n", + " accuracy = (tp + tn) / total if total > 0 else 0\n", + " precision = tp / (tp + fp) if (tp + fp) > 0 else 0\n", + " recall = tpr\n", + " f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n", + "\n", + " k_results.append({\n", + " \"seed\": seed_name,\n", + " \"k\": k,\n", + " \"threshold\": k_method_threshold,\n", + " \"tpr\": tpr,\n", + " \"fpr\": fpr,\n", + " \"accuracy\": accuracy,\n", + " \"precision\": precision,\n", + " \"recall\": recall,\n", + " \"f1\": f1,\n", + " \"h\": results.get(\"h\", float(\"nan\")),\n", + " \"tp\": tp,\n", + " \"fp\": fp,\n", + " \"tn\": tn,\n", + " \"fn\": fn,\n", + " })\n", + "\n", + " all_k_results[seed_name] = pd.DataFrame(k_results)\n", + " print(f\"[pass] {seed_name} k-method complete - {len(all_k_results[seed_name])} points\")\n", + "\n", + "print(f\"\\n[pass] all k-method results computed for {len(all_k_results)} seeds\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "bba8d667", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] matching fprs for all seeds...\n", + "\n", + "[info] matching fprs for test-son-seed-1...\n", + "[PASS] test-son-seed-1 matched 10 fpr points\n", + "\n", + "[info] matching fprs for test-son-seed-2...\n", + "[PASS] test-son-seed-2 matched 10 fpr points\n", + "\n", + "[info] matching fprs 
for test-son-seed-3...\n", + "[PASS] test-son-seed-3 matched 10 fpr points\n", + "\n", + "[info] matching fprs for test-son-seed-4...\n", + "[PASS] test-son-seed-4 matched 10 fpr points\n", + "\n", + "[info] matching fprs for test-son-seed-5...\n", + "[PASS] test-son-seed-5 matched 10 fpr points\n", + "\n", + "[PASS] all matching complete for 5 seeds\n" + ] + } + ], + "source": [ + "# step 3: for each k's FPR, find the closest baseline FPR (per seed)\n", + "print(\"[info] matching fprs for all seeds...\")\n", + "all_matched_results = {}\n", + "\n", + "for seed_name in corpora.keys():\n", + " print(f\"\\n[info] matching fprs for {seed_name}...\")\n", + " matched_results = []\n", + " \n", + " k_df = all_k_results[seed_name]\n", + " baseline_df = all_baseline_results[seed_name]\n", + " \n", + " for _, k_row in k_df.iterrows():\n", + " k_fpr = k_row['fpr']\n", + " k_tpr = k_row['tpr']\n", + " k_val = k_row['k']\n", + " \n", + " # find closest baseline fpr\n", + " fpr_diffs = np.abs(baseline_df['fpr'] - k_fpr)\n", + " closest_idx = fpr_diffs.idxmin()\n", + " baseline_match = baseline_df.iloc[closest_idx]\n", + " \n", + " matched_results.append({\n", + " 'seed': seed_name,\n", + " 'k': k_val,\n", + " 'k_fpr': k_fpr,\n", + " 'k_tpr': k_tpr,\n", + " 'baseline_fpr': baseline_match['fpr'],\n", + " 'baseline_tpr': baseline_match['tpr'],\n", + " 'baseline_accuracy': baseline_match['accuracy'],\n", + " 'baseline_precision': baseline_match['precision'],\n", + " 'baseline_recall': baseline_match['recall'],\n", + " 'baseline_f1': baseline_match['f1'],\n", + " 'baseline_h': baseline_match['h'],\n", + " 'baseline_threshold': baseline_match['threshold'],\n", + " 'fpr_diff': abs(k_fpr - baseline_match['fpr']),\n", + " 'tpr_improvement': k_tpr - baseline_match['tpr']\n", + " })\n", + " \n", + " all_matched_results[seed_name] = pd.DataFrame(matched_results)\n", + " print(f\"[PASS] {seed_name} matched {len(all_matched_results[seed_name])} fpr points\")\n", + "\n", + "print(f\"\\n[PASS] 
all matching complete for {len(all_matched_results)} seeds\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "95588d73", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[info] matched fpr-oracle metrics for tau=7 (mean ± std over seeds, n=5):\n", + " tau baseline_accuracy_mean baseline_accuracy_std baseline_precision_mean baseline_precision_std baseline_recall_mean baseline_recall_std baseline_f1_mean baseline_f1_std baseline_h_mean baseline_h_std fpr_diff_mean fpr_diff_std tpr_improvement_mean tpr_improvement_std baseline_threshold_mean baseline_threshold_std baseline_fpr_mean baseline_fpr_std baseline_tpr_mean baseline_tpr_std\n", + " 7 0.7002 0.0088 0.7152 0.0100 0.6677 0.0494 0.6896 0.0218 2.6965 0.1029 0.0036 0.0028 0.0155 0.0055 0.6900 0.0245 0.2674 0.0331 0.6677 0.0494\n" + ] + } + ], + "source": [ + "# average matched-baseline (oracle) metrics for tau=7, across all seeds\n", + "target_tau = 7\n", + "matched_long = pd.concat(\n", + " [df.assign(seed=seed_name) for seed_name, df in all_matched_results.items()],\n", + " ignore_index=True,\n", + ")\n", + "matched_long = matched_long[matched_long[\"k\"] == target_tau].copy()\n", + "\n", + "if matched_long.empty:\n", + " raise ValueError(f\"no matched results found for tau={target_tau}\")\n", + "\n", + "oracle_cols = [\n", + " \"baseline_accuracy\",\n", + " \"baseline_precision\",\n", + " \"baseline_recall\",\n", + " \"baseline_f1\",\n", + " \"baseline_h\",\n", + " \"fpr_diff\",\n", + " \"tpr_improvement\",\n", + " \"baseline_threshold\",\n", + " \"baseline_fpr\",\n", + " \"baseline_tpr\",\n", + "]\n", + "\n", + "oracle_stats_per_tau = (\n", + " matched_long.groupby(\"k\")[oracle_cols]\n", + " .agg([\"mean\", \"std\"])\n", + " .reset_index()\n", + " .rename(columns={\"k\": \"tau\"})\n", + ")\n", + "\n", + "# reformat columns to single-level, e.g. 
baseline_fpr_mean\n", + "oracle_stats_per_tau.columns = [\"tau\"] + [\n", + " f\"{col}_{stat}\" for col in oracle_cols for stat in [\"mean\", \"std\"]\n", + "]\n", + "\n", + "with pd.option_context(\n", + " \"display.float_format\",\n", + " \"{:.4f}\".format,\n", + " \"display.max_columns\",\n", + " None,\n", + " \"display.width\",\n", + " 200,\n", + "):\n", + " print(\"[info] matched fpr-oracle metrics for tau=7 \"\n", + " f\"(mean ± std over seeds, n={matched_long['seed'].nunique()}):\")\n", + " print(oracle_stats_per_tau.to_string(index=False))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "lyk25-env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" } - ], - "source": [ - "simulator_model = UnslothUtteranceSimulatorModel(\n", - " model_name=\"/reef/lyk25/dynamic_training/game_analysis/outputs/checkpoint-74\",\n", - " train_config=SIMULATOR_TRAIN_CONFIG,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "fe147ae8", - "metadata": {}, - "source": [ - "After loading the simulator model, we define context selectors." 
- ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "a9502041", - "metadata": {}, - "outputs": [], - "source": [ - "def context_selector(context_tuple, split):\n", - " \"\"\"\n", - " We use this generic function for both training and validation data.\n", - " In both cases, its job is to select only those contexts for which the\n", - " FUTURE context is not empty, so we have a next utterance to predict.\n", - " \"\"\"\n", - " matches_split = (context_tuple.current_utterance.get_conversation().meta.get(\"split\") == split)\n", - " is_end = (len(context_tuple.future_context) == 0)\n", - " return matches_split and not is_end\n", - "\n", - "def make_data_selector(split):\n", - " return lambda context_tuple: context_tuple.current_utterance.get_conversation().meta.get(\"split\") == split\n", - "\n", - "train_context_selector = partial(context_selector, split=\"train\")\n", - "val_context_selector = partial(context_selector, split=\"val\")\n", - "test_context_selector = partial(context_selector, split=\"test\")" - ] - }, - { - "cell_type": "markdown", - "id": "f707a3a5", - "metadata": {}, - "source": [ - "Below is a script to fully reproduce, sans regenerating simulations (which takes substantially longer)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dde130ed", - "metadata": {}, - "outputs": [], - "source": [ - "for seed_idx in SEEDS:\n", - " train_corpus = Corpus(filename=download('conversations-gone-awry-cmv-corpus-large'))\n", - " corpus = Corpus(filename=f'/reef/lyk25/theres-a-way-out-ACL26-internal/outputs/benchmark_preannotated/seed-{seed_idx}-ThresholdDecisionPolicy/ThresholdDecisionPolicy')\n", - " # corpus = Corpus(filename=f'/reef/lyk25/dynamic_training/game_analysis/corpi/test/test-son-seed-{seed_idx}')\n", - "\n", - " config = TransformerForecasterConfig(\n", - " output_dir=f\"outputs/{OUTPUT_DIR}/forecaster_{seed_idx}\",\n", - " per_device_batch_size=16,\n", - " gradient_accumulation_steps=1,\n", - " num_train_epochs=1,\n", - " learning_rate=1e-5,\n", - " random_seed=seed_idx,\n", - " context_mode=\"normal\",\n", - " device=\"cuda\",\n", - " )\n", - "\n", - " # TODO this will have to be edited\n", - " forecaster_model = TransformerDecoderModel(\n", - " model_name_or_path=\"google/gemma-2-9b-it\",\n", - " config=config,\n", - " )\n", - "\n", - " forecaster = Forecaster(\n", - " forecaster_model=forecaster_model,\n", - " labeler='has_removed_comment',\n", - " )\n", - "\n", - " forecaster.fit_belief_estimator(\n", - " corpus=train_corpus,\n", - " context_selector=train_context_selector,\n", - " val_context_selector=val_context_selector,\n", - " )\n", - "\n", - "\n", - "\n", - " # ---\n", - " cfg_path = os.path.join(repo_root, \"saves\", f\"seed-{seed_idx}\", \"dev_config.json\")\n", - " with open(cfg_path) as f:\n", - " cfg = json.load(f)\n", - " best_threshold = cfg['best_threshold']\n", - "\n", - " for policy_trial in [ThresholdDecisionPolicy, DeferralDecisionPolicy]:\n", - " print('---')\n", - " print(f\"Fitting policy {policy_trial.__name__} for seed {seed_idx}\")\n", - " if policy_trial == ThresholdDecisionPolicy:\n", - " policy = ThresholdDecisionPolicy(\n", - " threshold=best_threshold,\n", - " 
reuse_cached_forecast_probs=False,\n", - " )\n", - " elif policy_trial == DeferralDecisionPolicy:\n", - " policy = DeferralDecisionPolicy(\n", - " simulator=simulator_model,\n", - " threshold=best_threshold,\n", - " tau=TAU,\n", - " reuse_cached_forecast_probs=False,\n", - " )\n", - " elif policy_trial == RandomDeferralDecisionPolicy:\n", - " policy = RandomDeferralDecisionPolicy(\n", - " simulator=simulator_model,\n", - " threshold=best_threshold,\n", - " deferral_probability=DEFERRAL_PROBABILITY_THRESHOLD,\n", - " reuse_cached_forecast_probs=False,\n", - " )\n", - " elif policy_trial == SimulationAverageDecisionPolicy:\n", - " policy = SimulationAverageDecisionPolicy(\n", - " simulator=simulator_model,\n", - " threshold=best_threshold,\n", - " num_simulations=NUM_SIMULATIONS,\n", - " store_simulations=False,\n", - " simulated_reply_attribute_name=\"sim_replies\",\n", - " sim_replies_forecast_probs_attribute_name=\"sim_replies_forecast_probs\",\n", - " reuse_cached_forecast_probs=False,\n", - " )\n", - " elif policy_trial == SimulationMajorityDecisionPolicy:\n", - " policy = SimulationMajorityDecisionPolicy(\n", - " simulator=simulator_model,\n", - " threshold=best_threshold,\n", - " tau=TAU,\n", - " num_simulations=NUM_SIMULATIONS,\n", - " store_simulations=False,\n", - " simulated_reply_attribute_name=\"sim_replies\",\n", - " sim_replies_forecast_probs_attribute_name=\"sim_replies_forecast_probs\",\n", - " reuse_cached_forecast_probs=False,\n", - " )\n", - " \n", - " # attach the decision policy to the underlying forecaster model;\n", - " # Forecaster itself does not accept decision_policy in its constructor.\n", - " forecaster_model.decision_policy = policy\n", - "\n", - " forecaster = Forecaster(\n", - " forecaster_model=forecaster_model,\n", - " labeler='has_removed_comment',\n", - " )\n", - "\n", - " print('starting transformation.')\n", - " # evaluate the forecaster on the test set\n", - " forecaster.transform(\n", - " corpus=corpus,\n", - " 
context_selector=make_data_selector('test'),\n", - " verbose=True,\n", - " )\n", - " print('transformation complete.')\n", - "\n", - " output_dir = f\"outputs/{OUTPUT_DIR}/seed-{seed_idx}-{policy_trial.__name__}\"\n", - " os.makedirs(output_dir, exist_ok=True)\n", - " corpus.dump(name=f\"{policy_trial.__name__}\", base_path=output_dir)\n", - " print('corpus dumped.')\n", - "\n", - " print('starting summarization.')\n", - " # forecaster.summarize expects a conversation-level selector (Callable[[Conversation], bool]),\n", - " # unlike the context-tuple selectors used in fit/transform.\n", - " def summarize_selector(convo):\n", - " return convo.meta.get(\"split\") == \"test\"\n", - " conversational_forecasts_df, metrics = forecaster.summarize(\n", - " corpus=corpus,\n", - " selector=summarize_selector,\n", - " )\n", - " print('summarization complete.')\n", - " \n", - " # path to the seed output directory\n", - " seed_folder = f\"outputs/{OUTPUT_DIR}/seed-{seed_idx}-{policy_trial.__name__}\"\n", - "\n", - " # ensure the directory exists\n", - " os.makedirs(seed_folder, exist_ok=True)\n", - "\n", - " # save conversational_forecasts_df as CSV\n", - " conversational_forecasts_df.to_csv(os.path.join(seed_folder, \"conversational_forecasts.csv\"), index=False)\n", - "\n", - " # save metrics as JSON\n", - " with open(os.path.join(seed_folder, \"metrics.json\"), \"w\") as f:\n", - " json.dump(metrics, f, indent=2)" - ] - }, - { - "cell_type": "markdown", - "id": "e3c6f428", - "metadata": {}, - "source": [ - "A faster reproduction is possible by skipping the training and transformation." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2f4a29c2", - "metadata": {}, - "outputs": [], - "source": [ - "# TODO insert code here to download all" - ] - }, - { - "cell_type": "markdown", - "id": "29a0cdc1", - "metadata": {}, - "source": [ - "# Human Data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d673d0d6", - "metadata": {}, - "outputs": [], - "source": [ - "# import human data SQL here" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d55b3bab", - "metadata": {}, - "outputs": [], - "source": [ - "import sqlite3" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "92766a80", - "metadata": {}, - "outputs": [], - "source": [ - "# connect to the game database\n", - "db_path = '/reef/sqt2/cga-eval/human/game_db.sqlite'\n", - "connection = sqlite3.connect(db_path)\n", - "cursor = connection.cursor()\n", - "\n", - "print(f\"[info] connected to database: {db_path}\")\n", - "cursor.execute(\"SELECT COUNT(*) FROM results\")\n", - "\n", - "# # names_list = []\n", - "# # answers_list = []\n", - "# # scores_list = []\n", - "# # comments_list = []\n", - "\n", - "# TIME_CUTOFF = 1714059814630 # timestamp to cut off previous results\n", - "# TIME_CUTOFF_END = 1730385762530\n", - "\n", - "# for row in connection.execute('SELECT * FROM results'):\n", - "# if row[0] != 'yc2727' and row[0] != 'sqt2':\n", - "# answers = json.loads(row[1])\n", - "# if answers[0]['start_time'] > TIME_CUTOFF and answers[0]['start_time'] < TIME_CUTOFF_END:\n", - "# names_list.append(row[0])\n", - "# answers_list.append(answers)\n", - "# scores_list.append(row[2])\n", - "# comments_list.append(row[3])\n", - "\n", - "if round_n == 1:\n", - " # for round_n 1\n", - " names_list_1 = []\n", - " answers_list_1 = []\n", - " scores_list_1 = []\n", - " comments_list_1 = []\n", - "\n", - " TIME_CUTOFF_1 = 1730395000000\n", - " TIME_CUTOFF_END_1 = 1730397199000\n", - "\n", - " for row in connection.execute('SELECT * 
FROM results'):\n", - " if row[0] != 'yc2727' and row[0] != 'sqt2':\n", - " answers = json.loads(row[1])\n", - " if answers[0]['start_time'] > TIME_CUTOFF_1 and answers[0]['start_time'] < TIME_CUTOFF_END_1:\n", - " names_list_1.append(row[0])\n", - " answers_list_1.append(answers)\n", - " scores_list_1.append(row[2])\n", - " comments_list_1.append(row[3])\n", - "elif round_n == 2:\n", - " # for round 2\n", - " names_list_2 = []\n", - " answers_list_2 = []\n", - " scores_list_2 = []\n", - " comments_list_2 = []\n", - "\n", - " TIME_CUTOFF_2_a = 1731002910000\n", - " TIME_CUTOFF_END_2_a = 1731016180000\n", - "\n", - " for row in connection.execute('SELECT * FROM results'):\n", - " if row[0] != 'yc2727' and row[0] != 'sqt2':\n", - " answers = json.loads(row[1])\n", - " if answers[0]['start_time'] > TIME_CUTOFF_2_a and answers[0]['start_time'] < TIME_CUTOFF_END_2_a:\n", - " names_list_2.append(row[0])\n", - " answers_list_2.append(answers)\n", - " scores_list_2.append(row[2])\n", - " comments_list_2.append(row[3])\n", - "\n", - " # The 2_b cutoff below is to specifically add data from 'ljl2' for a later second round window,\n", - " # likely because they submitted their data late or for a different time block.\n", - " TIME_CUTOFF_2_b = 1732212720000\n", - " TIME_CUTOFF_END_2_b = 1740000000000\n", - "\n", - " for row in connection.execute('SELECT * FROM results'):\n", - " if row[0].lower() == 'ljl2':\n", - " answers = json.loads(row[1])\n", - " if answers[0]['start_time'] > TIME_CUTOFF_2_b and answers[0]['start_time'] < TIME_CUTOFF_END_2_b:\n", - " names_list_2.append(row[0])\n", - " answers_list_2.append(answers)\n", - " scores_list_2.append(row[2])\n", - " comments_list_2.append(row[3])\n", - "\n", - " # print(f\"Round 1: {len(names_list_1)} entries\")\n", - " # print(f\"Round 2: {len(names_list_2)} entries\")\n", - " # names_list_1, names_list_2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "889d7ede", - "metadata": {}, - "outputs": [], - 
"source": [ - "human_map1 = {\n", - " \"vn72\": ['d3efg03', 'd3f07tv', 'f1353ra', 'f13u0n2', 'fegw71k', 'fegwvia', 'ikpw1ol', 'ikpxahv', 'isz5n6g', 'isz6a7m'],\n", - " \"nac86\": ['dyd7cfv', 'dydawov', 'ewtjxp2', 'ewtowqu', 'f00nxjs', 'f014doo', 'ii3cpve', 'ii4ivim', 'ikzb571', 'ikzcblg'],\n", - " # \"sj597\": ['dbnr5nd', 'dbnxbzn', 'g1sm9mt', 'g1spg0k', 'g297nn7', 'g2998qp', 'h2h7yl0', 'h2hauaj', 'ifknq44', 'ifkoazb'],\n", - " \"yc2727\": ['d08x3kl', 'd08x4q0', 'd3efg03', 'd3f07tv', 'ewtjxp2', 'ewtowqu', 'hnupywh', 'hnuu3ip', 'ih2fn94', 'ih3iwph'],\n", - " # \"ex36\": ['d0fyu9x', 'd0fzp40', 'dg7wmdb', 'dg7x3eu', 'ej6jusi', 'ej6ollb', 'ewtjxp2', 'ewtowqu', 'f0dryv4', 'f0dsbw4'],\n", - " \"kz88\": ['doi3vuz', 'doi44j8', 'e9mybxp', 'e9mynuy', 'g1q1973', 'g1q3upb', 'ggwdbsa', 'ggwei5y', 'gmxoyuu', 'gmxqrdz'],\n", - " \"LJL2\": ['ffpj88q', 'ffpk80c', 'flixm8f', 'fljdrtt', 'fmre7l9', 'fmspcex', 'fphcwq0', 'g3hmlew', 'givok1i', 'givp578'],\n", - " \"lyk25\": ['dpirnlc', 'dpkib9q', 'e1vstxd', 'e1vv6x1', 'ewr7ls3', 'ewr7msm', 'fixjxxs', 'fixlxvy', 'fmre7l9', 'fmspcex'],\n", - " \"sqt2\": ['d3efg03', 'd3f07tv', 'g0fwpzc', 'g0ggz8e', 'g8nyz4f', 'g8nzx3m', 'h4i75b9', 'h4i79lc', 'h6ikmzc', 'h6ime20'],\n", - " \"cd326\": ['djh1bt1', 'djh2v24', 'dmlny27', 'dmlqnzz', 'f83akyz', 'f83hb2k', 'h28lknz', 'h28ltfw', 'i4rgtqf', 'i4rh1id'],\n", - " \"tg352\": ['dm8exht', 'dm8qu78', 'dvisfl0', 'dvivs9p', 'dyx2jn4', 'dyx3au1', 'gvb1ekl', 'gvdnrrb', 'i4rgtqf', 'i4rh1id'],\n", - "}\n", - "\n", - "# all_convo_ids = []\n", - "# for k,v in human_map1.items():\n", - "# all_convo_ids.extend(v)\n", - "# all_convo_ids = list(set(all_convo_ids))\n", - "# len(all_convo_ids)\n", - "\n", - "all_convo_ids = []\n", - "# # collect from round 1\n", - "if round_n == 1:\n", - " for i in range(len(answers_list_1)):\n", - " for j in range(len(answers_list_1[i])):\n", - " all_convo_ids.append(answers_list_1[i][j]['id'])\n", - "# collect from round 2\n", - "elif round_n == 2:\n", - " for i in 
range(len(answers_list_2)):\n", - " for j in range(len(answers_list_2[i])):\n", - " all_convo_ids.append(answers_list_2[i][j]['id'])\n", - "all_convo_ids = list(set(all_convo_ids))\n", - "print(f\"[info] total unique conversation ids: {len(all_convo_ids)}\")\n", - "print(all_convo_ids)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "168fb025", - "metadata": {}, - "outputs": [], - "source": [ - "# we want all utterances to have a human_guesses meta field\n", - "for convo in corpus.iter_conversations():\n", - " for utt in convo.get_chronological_utterance_list():\n", - " utt.add_meta('human_guesses', [])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "27fc646a", - "metadata": {}, - "outputs": [], - "source": [ - "# now we r gonna add the human guesses to the utterances, with their player ids\n", - "\n", - "# extract human first awry guess indices, now using answers_list_2\n", - "\n", - "# from collections import defaultdict\n", - "\n", - "# for i, participant_sessions in enumerate(answers_list):\n", - "# participant_id = names_list[i]\n", - "# print(participant_id)\n", - "# for session in participant_sessions:\n", - "# convo_id = session['id']\n", - "# convo = corpus.get_conversation(convo_id)\n", - "# utts = convo.get_chronological_utterance_list()\n", - "# actions = session.get('actions', [])\n", - "\n", - "# for action_idx, action in enumerate(actions):\n", - "# if action.get('guess') is True:\n", - "# utt = utts[action_idx]\n", - "# # get current guesses (returns deep copy if exists, or empty list if not)\n", - "# # create a new list to avoid mutating the copy\n", - "# current_guesses = list(utt.meta.get('human_guesses', []))\n", - "# # append new participant and set back\n", - "# current_guesses.append(participant_id)\n", - "# utt.add_meta('human_guesses', current_guesses)\n", - "if round_n == 1:\n", - " for i, participant_sessions in enumerate(answers_list_1):\n", - " participant_id = names_list_1[i]\n", - " 
print(participant_id)\n", - " for session in participant_sessions:\n", - " convo_id = session['id']\n", - " convo = corpus.get_conversation(convo_id)\n", - " utts = convo.get_chronological_utterance_list()\n", - " actions = session.get('actions', [])\n", - "\n", - " for action_idx, action in enumerate(actions):\n", - " if action.get('guess') is True:\n", - " utt = utts[action_idx]\n", - " # get current guesses (returns deep copy if exists, or empty list if not)\n", - " # create a new list to avoid mutating the copy\n", - " current_guesses = list(utt.meta.get('human_guesses', []))\n", - " # append new participant and set back\n", - " current_guesses.append(participant_id)\n", - " utt.add_meta('human_guesses', current_guesses)\n", - "elif round_n == 2:\n", - " for i, participant_sessions in enumerate(answers_list_2):\n", - " participant_id = names_list_2[i]\n", - " print(participant_id)\n", - " for session in participant_sessions:\n", - " convo_id = session['id']\n", - " convo = corpus.get_conversation(convo_id)\n", - " utts = convo.get_chronological_utterance_list()\n", - " actions = session.get('actions', [])\n", - "\n", - " for action_idx, action in enumerate(actions):\n", - " if action.get('guess') is True:\n", - " utt = utts[action_idx]\n", - " # get current guesses (returns deep copy if exists, or empty list if not)\n", - " # create a new list to avoid mutating the copy\n", - " current_guesses = list(utt.meta.get('human_guesses', []))\n", - " # append new participant and set back\n", - " current_guesses.append(participant_id)\n", - " utt.add_meta('human_guesses', current_guesses)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fd613b4b", - "metadata": {}, - "outputs": [], - "source": [ - "# initialize model prediction metadata fields for each utterance\n", - "for convo in corpus.iter_conversations():\n", - " for utt in convo.get_chronological_utterance_list():\n", - " utt.add_meta('model_forecast_probs', {}) # will store {seed: 
prob}\n", - " utt.add_meta('model_forecasts', {}) # will store {seed: binary_forecast}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "873cb706", - "metadata": {}, - "outputs": [], - "source": [ - "# TODO need to choose which decisionpolicy/forecaster run to use to get the forecast probabilities\n", - "\n", - "# load model predictions from all 5 seeds\n", - "test_corpi_path = '/reef/lyk25/dynamic_training/game_analysis/corpi/test'\n", - "\n", - "for seed_num in range(1, 6):\n", - " seed_corpus_path = f'{test_corpi_path}/test-son-seed-{seed_num}'\n", - " print(f\"[info] loading predictions from seed {seed_num}: {seed_corpus_path}\")\n", - " \n", - " # load the seed corpus\n", - " seed_corpus = Corpus(filename=seed_corpus_path)\n", - " \n", - " # iterate through conversations in our main corpus\n", - " for convo in corpus.iter_conversations():\n", - " convo_id = convo.id\n", - " \n", - " # check if this conversation exists in the seed corpus\n", - " if convo_id not in seed_corpus.conversations:\n", - " print(f\"[warning] convo {convo_id} not found in seed {seed_num}\")\n", - " continue\n", - " \n", - " # get the corresponding conversation from seed corpus\n", - " seed_convo = seed_corpus.get_conversation(convo_id)\n", - " \n", - " # get utterances from both corpora\n", - " main_utts = convo.get_chronological_utterance_list()\n", - " seed_utts = seed_convo.get_chronological_utterance_list()\n", - " \n", - " # match utterances and copy predictions\n", - " for main_utt, seed_utt in zip(main_utts, seed_utts):\n", - " # verify they're the same utterance\n", - " if main_utt.id != seed_utt.id:\n", - " print(f\"[warning] utterance mismatch: {main_utt.id} != {seed_utt.id}\")\n", - " continue\n", - " \n", - " # extract predictions from seed utterance\n", - " forecast_prob = seed_utt.meta.get('forecast_prob', None)\n", - " forecast = seed_utt.meta.get('forecast', None)\n", - " \n", - " # add to main utterance's metadata\n", - " if forecast_prob is not 
None:\n", - " current_probs = dict(main_utt.meta.get('model_forecast_probs', {}))\n", - " current_probs[f'seed_{seed_num}'] = forecast_prob\n", - " main_utt.add_meta('model_forecast_probs', current_probs)\n", - " \n", - " if forecast is not None:\n", - " current_forecasts = dict(main_utt.meta.get('model_forecasts', {}))\n", - " current_forecasts[f'seed_{seed_num}'] = forecast\n", - " main_utt.add_meta('model_forecasts', current_forecasts)\n", - "\n", - "print(\"\\n[pass] model predictions from all seeds added to corpus\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5b0f0e44", - "metadata": {}, - "outputs": [], - "source": [ - "# horizon comparison between gemma and humans, using the same convention as\n", - "# performance_utils_wiki.calculate_current_performance:\n", - "# - only count true positives (ground truth awry AND trigger fired)\n", - "# - horizon = (len(utts) - first_trigger_index) - 1\n", - "# - first trigger only (break after first positive)\n", - "#\n", - "# reports per-round results for humans (round 1 and round 2), pulling data\n", - "# directly from the game sqlite so it works regardless of which round_n the\n", - "# rest of the notebook was run under. ljl2 is excluded.\n", - "\n", - "import sqlite3\n", - "import json as _json\n", - "from collections import defaultdict\n", - "import numpy as np\n", - "\n", - "def _horizon_mean(horizons):\n", - " # matches performance_utils_wiki: h = mean(horizons) - 1\n", - " if len(horizons) == 0:\n", - " return float('nan'), 0\n", - " return float(np.mean(horizons)) - 1, len(horizons)\n", - "\n", - "# always excluded (staff/test accounts). 
ljl2 is a legitimate late submitter\n", - "# in the round 2_b window, not an exclusion.\n", - "base_excluded = {'yc2727', 'sqt2'}\n", - "round1_excluded = set(base_excluded)\n", - "round2_excluded = set(base_excluded)\n", - "\n", - "# gemma: per-seed first-trigger horizons on TP convos only (ground truth awry)\n", - "gemma_horizons_by_seed = defaultdict(list)\n", - "for convo in corpus.iter_conversations():\n", - " if not convo.meta.get('has_removed_comment', False):\n", - " continue # skip non-awry convos; horizon is only meaningful on TPs\n", - " utts = convo.get_chronological_utterance_list()\n", - " for seed_num in range(1, 6):\n", - " for i, utt in enumerate(utts):\n", - " if utt.meta.get('model_forecasts', {}).get(f'seed_{seed_num}') == 1:\n", - " gemma_horizons_by_seed[seed_num].append(len(utts) - i)\n", - " break\n", - "\n", - "print('gemma horizon (TPs only, mean - 1):')\n", - "gemma_per_seed_h = []\n", - "for seed_num in sorted(gemma_horizons_by_seed.keys()):\n", - " h, n = _horizon_mean(gemma_horizons_by_seed[seed_num])\n", - " gemma_per_seed_h.append(h)\n", - " print(f' seed {seed_num}: n={n}, h={h:.4f}')\n", - "gemma_h_mean_of_seeds = float(np.mean(gemma_per_seed_h)) if gemma_per_seed_h else float('nan')\n", - "all_gemma_vals = [v for vs in gemma_horizons_by_seed.values() for v in vs]\n", - "gemma_h_pooled, gemma_n_pooled = _horizon_mean(all_gemma_vals)\n", - "print(f' mean over seeds: h={gemma_h_mean_of_seeds:.4f}')\n", - "print(f' pooled: n={gemma_n_pooled}, h={gemma_h_pooled:.4f}')\n", - "\n", - "# pull round-specific human guesses directly from the sqlite db\n", - "db_path = '/reef/sqt2/cga-eval/human/game_db.sqlite'\n", - "_conn = sqlite3.connect(db_path)\n", - "\n", - "def _load_round_answers(round_n):\n", - " # returns list of (participant_id, sessions) for the given round\n", - " entries = []\n", - " if round_n == 1:\n", - " excluded = round1_excluded\n", - " t_start, t_end = 1730395000000, 1730397199000\n", - " for row in _conn.execute('SELECT 
* FROM results'):\n", - " if row[0] in excluded:\n", - " continue\n", - " answers = _json.loads(row[1])\n", - " if answers and answers[0]['start_time'] > t_start and answers[0]['start_time'] < t_end:\n", - " entries.append((row[0], answers))\n", - " elif round_n == 2:\n", - " excluded = round2_excluded\n", - " t_start_a, t_end_a = 1731002910000, 1731016180000\n", - " t_start_b, t_end_b = 1732212720000, 1740000000000\n", - " for row in _conn.execute('SELECT * FROM results'):\n", - " name = row[0]\n", - " if name in excluded:\n", - " continue\n", - " answers = _json.loads(row[1])\n", - " if not answers:\n", - " continue\n", - " st = answers[0]['start_time']\n", - " in_a = t_start_a < st < t_end_a\n", - " # the 2_b late window is only valid for ljl2 (matches cell 7)\n", - " in_b = (name.lower() == 'ljl2') and (t_start_b < st < t_end_b)\n", - " if in_a or in_b:\n", - " entries.append((name, answers))\n", - " return entries\n", - "\n", - "def _unique_convo_ids(entries):\n", - " # unique convo ids seen across all included players' sessions\n", - " ids = set()\n", - " for _, sessions in entries:\n", - " for session in sessions:\n", - " cid = session.get('id')\n", - " if cid is not None:\n", - " ids.add(cid)\n", - " return ids\n", - "\n", - "def _compute_human_horizons(entries):\n", - " # returns dict: player_id -> list of horizons (non-awry convos only, first guess per convo)\n", - " by_player = defaultdict(list)\n", - " for participant_id, sessions in entries:\n", - " for session in sessions:\n", - " convo_id = session.get('id')\n", - " if convo_id is None:\n", - " continue\n", - " try:\n", - " convo = corpus.get_conversation(convo_id)\n", - " except KeyError:\n", - " # convo not in the loaded corpus (e.g., different round loaded)\n", - " continue\n", - " if not convo.meta.get('has_removed_comment', False):\n", - " continue # skip non-awry convos; horizon only counted on TPs\n", - " utts = convo.get_chronological_utterance_list()\n", - " for action_idx, action in 
enumerate(session.get('actions', [])):\n", - " if action.get('guess') is True:\n", - " if action_idx < len(utts):\n", - " by_player[participant_id].append(len(utts) - action_idx)\n", - " break # only first guess per (player, convo)\n", - " return by_player\n", - "\n", - "def _print_human_horizons(label, by_player):\n", - " print(f'human horizon ({label}, TPs only, mean - 1):')\n", - " player_h = []\n", - " for player, horizons in by_player.items():\n", - " h, n = _horizon_mean(horizons)\n", - " player_h.append(h)\n", - " print(f' player {player}: n={n}, h={h:.4f}')\n", - " mean_of_players = float(np.mean(player_h)) if player_h else float('nan')\n", - " all_vals = [v for vs in by_player.values() for v in vs]\n", - " h_pooled, n_pooled = _horizon_mean(all_vals)\n", - " print(f' mean over players: h={mean_of_players:.4f}')\n", - " print(f' pooled: n={n_pooled}, h={h_pooled:.4f}')\n", - " return mean_of_players, h_pooled\n", - "\n", - "print()\n", - "EXPECTED_N_CONVOS = 84\n", - "round1_entries = _load_round_answers(1)\n", - "round1_convo_ids = _unique_convo_ids(round1_entries)\n", - "print(f'round 1: {len(round1_entries)} included players, {len(round1_convo_ids)} unique convos')\n", - "assert len(round1_convo_ids) == EXPECTED_N_CONVOS, (\n", - " f'round 1 expected {EXPECTED_N_CONVOS} unique convos but got {len(round1_convo_ids)}'\n", - ")\n", - "round1_by_player = _compute_human_horizons(round1_entries)\n", - "r1_mean, r1_pooled = _print_human_horizons('round 1', round1_by_player)\n", - "\n", - "print()\n", - "round2_entries = _load_round_answers(2)\n", - "round2_convo_ids = _unique_convo_ids(round2_entries)\n", - "print(f'round 2: {len(round2_entries)} included players, {len(round2_convo_ids)} unique convos')\n", - "assert len(round2_convo_ids) == EXPECTED_N_CONVOS, (\n", - " f'round 2 expected {EXPECTED_N_CONVOS} unique convos but got {len(round2_convo_ids)}'\n", - ")\n", - "round2_by_player = _compute_human_horizons(round2_entries)\n", - "r2_mean, r2_pooled = 
_print_human_horizons('round 2', round2_by_player)\n", - "\n", - "_conn.close()\n", - "\n", - "print()\n", - "print('summary (comparable to h in performance_utils_wiki):')\n", - "print(f' gemma h (mean over seeds): {gemma_h_mean_of_seeds:.4f}')\n", - "print(f' gemma h (pooled): {gemma_h_pooled:.4f}')\n", - "print(f' round1 h (mean over players): {r1_mean:.4f}')\n", - "print(f' round1 h (pooled): {r1_pooled:.4f}')\n", - "print(f' round2 h (mean over players): {r2_mean:.4f}')\n", - "print(f' round2 h (pooled): {r2_pooled:.4f}')\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "44ed25ac", - "metadata": {}, - "outputs": [], - "source": [ - "# benchmark humans, mirroring the gemma benchmark (cell 40).\n", - "# computes accuracy, precision, recall, f1, fpr, specificity, fnr for:\n", - "# - gemma (aggregated over seeds 1-5)\n", - "# - round 1 humans (aggregate \"any\" rule, plus per-player)\n", - "# - round 2 humans (aggregate \"any\" rule, plus per-player)\n", - "# uses the included entries loaded in the previous cell:\n", - "# round1_entries, round2_entries\n", - "# and the corpus in memory for ground truth.\n", - "from collections import defaultdict\n", - "import numpy as np\n", - "def _compute_metrics(tp, fp, tn, fn):\n", - " total = tp + fp + tn + fn\n", - " accuracy = (tp + tn) / total if total > 0 else 0.0\n", - " precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0\n", - " recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0\n", - " f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0\n", - " fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0\n", - " specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0\n", - " fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0\n", - " return {\n", - " 'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn,\n", - " 'accuracy': accuracy, 'precision': precision, 'recall': recall,\n", - " 'f1': f1, 'fpr': fpr, 'specificity': specificity, 'fnr': fnr,\n", - " }\n", - "def _gt(convo):\n", - " return 
bool(convo.meta.get('has_removed_comment', False))\n", - "# gemma benchmark, per seed, then mean and std\n", - "gemma_per_seed = []\n", - "for seed_num in range(1, 6):\n", - " tp = fp = tn = fn = 0\n", - " for convo in corpus.iter_conversations():\n", - " utts = convo.get_chronological_utterance_list()\n", - " pred = any(\n", - " utt.meta.get('model_forecasts', {}).get(f'seed_{seed_num}') == 1\n", - " for utt in utts\n", - " )\n", - " truth = _gt(convo)\n", - " if pred and truth: tp += 1\n", - " elif pred and not truth: fp += 1\n", - " elif not pred and not truth: tn += 1\n", - " elif not pred and truth: fn += 1\n", - " gemma_per_seed.append(_compute_metrics(tp, fp, tn, fn))\n", - "gemma_mean = {k: float(np.mean([m[k] for m in gemma_per_seed])) for k in gemma_per_seed[0]}\n", - "gemma_std = {k: float(np.std([m[k] for m in gemma_per_seed], ddof=1)) for k in gemma_per_seed[0]}\n", - "def _build_per_player_preds(entries):\n", - " # returns dict[player_id] -> dict[convo_id] -> 0/1\n", - " preds = defaultdict(dict)\n", - " for player, sessions in entries:\n", - " for s in sessions:\n", - " cid = s.get('id')\n", - " if cid is None:\n", - " continue\n", - " try:\n", - " corpus.get_conversation(cid)\n", - " except KeyError:\n", - " continue\n", - " actions = s.get('actions', [])\n", - " preds[player][cid] = 1 if any(a.get('guess') is True for a in actions) else 0\n", - " return preds\n", - "def _aggregate_any(preds):\n", - " # per-convo prediction = 1 if any player who saw it guessed\n", - " by_convo = defaultdict(list)\n", - " for player, convo_preds in preds.items():\n", - " for cid, p in convo_preds.items():\n", - " by_convo[cid].append(p)\n", - " tp = fp = tn = fn = 0\n", - " for cid, plist in by_convo.items():\n", - " try:\n", - " convo = corpus.get_conversation(cid)\n", - " except KeyError:\n", - " continue\n", - " pred = 1 if any(plist) else 0\n", - " truth = 1 if _gt(convo) else 0\n", - " if pred and truth: tp += 1\n", - " elif pred and not truth: fp += 1\n", - " 
elif not pred and not truth: tn += 1\n", - " elif not pred and truth: fn += 1\n", - " return _compute_metrics(tp, fp, tn, fn), len(by_convo)\n", - "def _per_player_metrics(preds):\n", - " rows = {}\n", - " for player, convo_preds in preds.items():\n", - " tp = fp = tn = fn = 0\n", - " for cid, p in convo_preds.items():\n", - " try:\n", - " convo = corpus.get_conversation(cid)\n", - " except KeyError:\n", - " continue\n", - " truth = 1 if _gt(convo) else 0\n", - " if p and truth: tp += 1\n", - " elif p and not truth: fp += 1\n", - " elif not p and not truth: tn += 1\n", - " elif not p and truth: fn += 1\n", - " rows[player] = _compute_metrics(tp, fp, tn, fn)\n", - " return rows\n", - "# round 1 and round 2 humans\n", - "round1_preds = _build_per_player_preds(round1_entries)\n", - "round1_agg, round1_n_convos = _aggregate_any(round1_preds)\n", - "round1_per_player = _per_player_metrics(round1_preds)\n", - "round2_preds = _build_per_player_preds(round2_entries)\n", - "round2_agg, round2_n_convos = _aggregate_any(round2_preds)\n", - "round2_per_player = _per_player_metrics(round2_preds)\n", - "# metric mean over players (a different aggregation view)\n", - "def _mean_over_players(per_player):\n", - " keys = ['accuracy', 'precision', 'recall', 'f1', 'fpr', 'specificity', 'fnr']\n", - " return {k: float(np.mean([m[k] for m in per_player.values()])) for k in keys}\n", - "round1_mean_over_players = _mean_over_players(round1_per_player)\n", - "round2_mean_over_players = _mean_over_players(round2_per_player)\n", - "# printing\n", - "metric_order = ['accuracy', 'precision', 'recall', 'f1', 'fpr', 'specificity', 'fnr']\n", - "def _print_per_player(label, per_player):\n", - " print(f'{label} - per-player metrics:')\n", - " header = f\" {'player':<10}\" + \"\".join(f\"{m:>12}\" for m in metric_order) + f\"{'n_convos':>10}\"\n", - " print(header)\n", - " print(' ' + '-' * (len(header) - 2))\n", - " for player in sorted(per_player):\n", - " m = per_player[player]\n", - " n = 
m['tp'] + m['fp'] + m['tn'] + m['fn']\n", - " row = f\" {player:<10}\" + \"\".join(f\"{m[k]:>12.4f}\" for k in metric_order) + f\"{n:>10d}\"\n", - " print(row)\n", - "_print_per_player('round 1', round1_per_player)\n", - "print()\n", - "_print_per_player('round 2', round2_per_player)\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "lyk25-env", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "nbformat": 4, + "nbformat_minor": 5 } From 44f77189cd992bd8976451c6b7e9ea0aadf83b01 Mon Sep 17 00:00:00 2001 From: Laerdon Kim <96972420+laerdon@users.noreply.github.com> Date: Thu, 30 Apr 2026 23:19:07 -0400 Subject: [PATCH 6/6] Delete examples/forecaster/train_deferral.py --- examples/forecaster/train_deferral.py | 186 -------------------------- 1 file changed, 186 deletions(-) delete mode 100644 examples/forecaster/train_deferral.py diff --git a/examples/forecaster/train_deferral.py b/examples/forecaster/train_deferral.py deleted file mode 100644 index dec75fe49..000000000 --- a/examples/forecaster/train_deferral.py +++ /dev/null @@ -1,186 +0,0 @@ -""" -end-to-end training script for a TransformerDecoderModel forecaster with a DeferralDecisionPolicy. - -flow: - 1. load the CGA-CMV corpus - 2. build a DeferralDecisionPolicy backed by an UnslothUtteranceSimulatorModel - 3. build a TransformerDecoderModel (forecaster backbone) with that policy attached - 4. wrap both in a Forecaster - 5. fit: LoRA fine-tune the forecaster, then fit the decision policy on the val set - 6. 
evaluate on the test set and print metrics - -usage: - python train_deferral.py [--device cuda] [--gpu 0] -""" - -import argparse -import os -import sys - -# ensure the repo root is on the path when running directly -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - - -def main(args): - import convokit - from convokit import Corpus, Forecaster, download - from convokit.forecaster.TransformerDecoderModel import TransformerDecoderModel - from convokit.forecaster.TransformerForecasterConfig import TransformerForecasterConfig - from convokit.decisionpolicy import DeferralDecisionPolicy, ThresholdDecisionPolicy - - # ------------------------------------------------------------------ # - # 1. corpus - # ------------------------------------------------------------------ # - print("[info] loading corpus...") - corpus = Corpus( - filename=download( - "conversations-gone-awry-cmv-corpus", - data_dir=args.data_dir, - ) - ) - - labeler = "has_removed_comment" - - # ------------------------------------------------------------------ # - # 2. context selectors - # ------------------------------------------------------------------ # - def train_selector(ctx): - """last context of every train conversation (matches original craft/llm training setup)""" - convo = ctx.current_utterance.get_conversation() - return ( - convo.meta.get("split") == "train" - and len(ctx.future_context) == 0 - ) - - def val_selector(ctx): - return ctx.current_utterance.get_conversation().meta.get("split") == "val" - - def test_selector(ctx): - convo = ctx.current_utterance.get_conversation() - convo_len = len(convo.get_chronological_utterance_list()) - return ( - convo.meta.get("split") == "test" - # exclude the very last context (the toxic turn itself) - and len(ctx.context) < convo_len - ) - - - # 3. simulator model - # # - # 4. decision policy - policy = ThresholdDecisionPolicy( - threshold=0.5926666259765625, - ) - - # 5. 
forecaster model - print("[info] loading forecaster model...") - config = TransformerForecasterConfig( - output_dir=args.output_dir, - per_device_batch_size=args.batch_size, - gradient_accumulation_steps=args.grad_accum, - num_train_epochs=args.epochs, - learning_rate=args.lr, - random_seed=args.seed, - context_mode="normal", - device=args.device, - ) - - forecaster_model = TransformerDecoderModel( - model_name_or_path=args.forecaster_model, - config=config, - decision_policy=policy, - ) - - # 6. forecaster wrapper - forecaster = Forecaster( - forecaster_model=forecaster_model, - labeler=labeler, - ) - - # 7. fit - # print("[info] fitting forecaster (belief estimator + decision policy)...") - # forecaster.fit( - # corpus=corpus, - # context_selector=train_selector, - # val_context_selector=val_selector, - # ) - - forecaster.fit_decision_policy( - corpus=corpus, - context_selector=train_selector, - val_context_selector=val_selector, - ) - - # print(forecaster.forecaster_model.decision_policy.threshold) - - # # 8. 
evaluate on test set - # print("[info] running transform on test set...") - # corpus = forecaster.transform( - # corpus=corpus, - # context_selector=test_selector, - # ) - - # print("[info] computing metrics...") - # forecaster.summarize( - # corpus=corpus, - # selector=lambda convo: convo.meta.get("split") == "test", - # ) - - # optional: inspect a few utterances with stored simulations - if args.store_simulations: - print("\n[info] sample utterances with stored simulations:") - shown = 0 - for utt in corpus.iter_utterances(): - # show only utterances that were forecasted and have sim_replies - if ( - utt.meta.get("forecast") is not None - and utt.meta.get("sim_replies") is not None - ): - print("---") - print("text :", utt.text[:120]) - print("forecast_prob :", utt.meta["forecast_prob"]) - print("forecast :", utt.meta["forecast"]) - print("sim_replies :", utt.meta["sim_replies"][:2]) - print("sim_probs :", utt.meta["sim_replies_forecast_probs"][:2]) - shown += 1 - if shown >= 3: - break - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="train forecaster with DeferralDecisionPolicy") - - # paths - parser.add_argument("--forecaster-model", required=True, - help="hf model name or local path for the decoder forecaster") - parser.add_argument("--simulator-model", required=True, - help="hf model name or local path for the utterance simulator") - parser.add_argument("--output-dir", default="./deferral_output", - help="directory to save checkpoints and predictions") - parser.add_argument("--data-dir", default="./", - help="directory to download/find the corpus") - - # training hyperparams - parser.add_argument("--epochs", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=2) - parser.add_argument("--grad-accum", type=int, default=32) - parser.add_argument("--lr", type=float, default=1e-4) - parser.add_argument("--seed", type=int, default=1) - - # deferral policy hyperparams - parser.add_argument("--num-simulations", 
type=int, default=10, - help="number of simulated branches per context") - parser.add_argument("--tau", type=int, default=5, - help="minimum simulated branches above threshold to intervene") - - # misc - parser.add_argument("--device", default="cuda") - parser.add_argument("--gpu", type=int, default=3, - help="which gpu to use (sets CUDA_VISIBLE_DEVICES)") - parser.add_argument("--store-simulations", action="store_true", - help="write simulated replies and their forecast probs to corpus metadata") - - args = parser.parse_args() - os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) - - main(args)