diff --git a/pathways/typing/mermaid.py b/pathways/typing/mermaid.py index 50008d6..899f0c3 100644 --- a/pathways/typing/mermaid.py +++ b/pathways/typing/mermaid.py @@ -79,7 +79,7 @@ def create_cart_diagram(root: Node) -> str: links = [] for node in root.preorder(): - probabilities = getattr(node, "class_probabilities", None) + probabilities = node.class_probabilities if node.is_leaf and probabilities: prob_shapes, prob_links = create_segment_probability_stack( @@ -174,9 +174,10 @@ def create_segment_probability_stack( return shapes, links -def create_form_diagram(root: Node, *, skip_notes: bool = False) -> str: +def create_form_diagram(root: Node, *, skip_notes: bool = False, threshold: float = 0.0) -> str: """Create mermaid diagram for typing form.""" header = "flowchart TD" + threshold = threshold / 100.0 shapes = { "segment": "stadium", "select_one": "rectangle", @@ -195,14 +196,20 @@ def create_form_diagram(root: Node, *, skip_notes: bool = False) -> str: continue is_segment_leaf = node.name == "segment" - probabilities = getattr(node, "class_probabilities", None) + probabilities = node.class_probabilities if is_segment_leaf and probabilities: - prob_shapes, prob_links = create_segment_probability_stack( - node, probabilities, "circle" - ) - shapes_lst.extend(prob_shapes) - links.extend(prob_links) + max_prob = max(probabilities.values()) + if max_prob < threshold: + prob_shapes, prob_links = create_segment_probability_stack( + node, probabilities, "circle" + ) + shapes_lst.extend(prob_shapes) + links.extend(prob_links) + else: + shape_label = get_form_shape_label(node) + shape = draw_shape(node.uid, shape_label, "circle") + shapes_lst.append(shape) else: shape_type = "circle" if is_segment_leaf else shapes[node.question.type] shape_label = get_form_shape_label(node) diff --git a/pathways/typing/options.py b/pathways/typing/options.py index 89dc033..3581eed 100644 --- a/pathways/typing/options.py +++ b/pathways/typing/options.py @@ -98,20 +98,48 @@ def add_segment_note( def add_segment_notes( - root: Node, settings_config: dict, segments_config: dict | None = None + root: Node, + settings_config: dict, + segments_config: dict | None = None, + low_confidence_threshold: float = 0.0, ) -> Node: - """Add notes once segments are assigned.""" + """Add notes once segments are assigned. + + If confidence_threshold is provided (percentage), calculate max probability. If max_probability < threshold, + segment + dead-end note will be applied. Otherwise, only segment note is applied. + """ + low_confidence_threshold = low_confidence_threshold / 100 new_root = copy.deepcopy(root) note_label = { - key.replace("segment_note", "label"): value + key.replace("segment_note", "label"): value.replace("\\n", "\n") if isinstance(value, str) else value for key, value in settings_config.items() if key.startswith("segment_note") } + low_conf_label = { + key.replace("deadend_note", "label"): value.replace("\\n", "\n") if isinstance(value, str) else value + for key, value in settings_config.items() + if key.startswith("deadend_note") + } + for node in new_root.preorder(): if node.is_leaf and node.name == "segment": - add_segment_note(node, note_label, segments_config) - return new_root + use_low_conf = False + if low_confidence_threshold > 0 and node.class_probabilities: + max_prob = max(node.class_probabilities.values()) + use_low_conf = max_prob < low_confidence_threshold + final_label = note_label.copy() + if use_low_conf: + for key, seg_note in final_label.items(): + low_conf_note = low_conf_label.get( + key, + "\n[Low segment assignment confidence]\n" + "We recommend stopping this survey and starting with a new respondent." + ) + final_label[key] = seg_note + low_conf_note + + add_segment_note(node, final_label, segments_config) + return new_root def enforce_relevance(root: Node) -> Node: """Enforce relevance rules for the node. @@ -280,7 +308,7 @@ def exit_deadends( # create note for dead-end deadend_label = { - key.replace("deadend_note", "label"): value + key.replace("deadend_note", "label"): value.replace("\\n", "\n") if isinstance(value, str) else value for key, value in settings_config.items() if key.startswith("deadend_note") }