Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/build_and_publish_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ name: build_and_publish_docs
on:
push:
branches:
- dev
- master
- test/**
pull_request:
branches:
- dev
Expand Down
30 changes: 30 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,36 @@ poetry init
- Include docstrings for functions and classes
- Write unit tests for new features or bug fixes

### Docstrings
Use the Google style docstring format:

https://www.sphinx-doc.org/en/master/usage/extensions/example_google.html

Example from Dialog2graph:
```python
class ModelStorage(BaseModel):
"""
ModelStorage is a class for managing the storage of model configurations and instances.
It provides functionality to load configurations from a YAML file, add new models to the storage,
and save the current storage state back to a YAML file.

Attributes:
storage (Dict[str, StoredData]): A dictionary that holds the stored model configurations
and their corresponding instances.
"""

storage: Dict[str, StoredData] = Field(default_factory=dict)

def load(self, path: Path):
"""
Load model configurations from a YAML file into the storage.

Args:
path (str): The file path to the YAML file containing model configurations.
"""
...
```

## Pull Request Format

- Name of your PR (keep it simple yet meaningful)
Expand Down
3 changes: 2 additions & 1 deletion dialog2graph/pipelines/core/dialog_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ def remove_duplicated_paths(node_paths: list[list[int]]) -> list[list[int]]:


def get_dialog_triplets(seq: list[list[dict]]) -> set[tuple[str]]:
"""Find all dialog triplets with (source, edge, target) utterances
"""Get all dialog triplets with (source, edge, target) utterances
from sequence of dialogs

Args:
seq: sequence of dialogs
Expand Down
56 changes: 54 additions & 2 deletions dialog2graph/pipelines/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@

from datetime import datetime
import networkx as nx
import gravis as gv
from pydantic import BaseModel
from typing import Optional, Any
import matplotlib.pyplot as plt
import abc
import colorsys

from dialog2graph.utils.logger import Logger

Expand Down Expand Up @@ -238,7 +240,7 @@ def visualise(self, *args, **kwargs):
plt.axis("off")
plt.show()

def visualise_short(self, name, *args, **kwargs):
def visualise_short(self, name="", *args, **kwargs):
"""Create a compact visualization of the graph.

Args:
Expand All @@ -251,7 +253,7 @@ def visualise_short(self, name, *args, **kwargs):
try:
pos = nx.nx_agraph.pygraphviz_layout(self.graph)
except ImportError as e:
pos = nx.kamada_kawai_layout(self.graph)
pos = nx.spring_layout(self.graph)
logger.warning(
f"{e}.\nInstall pygraphviz from http://pygraphviz.github.io/ .\nFalling back to default layout."
)
Expand Down Expand Up @@ -285,6 +287,56 @@ def visualise_short(self, name, *args, **kwargs):
plt.axis("off")
plt.show()

def visualise_interactive(self, *args, **kwargs) -> gv._internal.plotting.data_structures.Figure:
    """Visualise the graph interactively using the "gravis" library.

    Node and edge colours encode their ``frequency`` attribute when one is
    present (the hue is derived from ``1 / frequency / 2``); items without a
    positive frequency are drawn black. Hover text lists the frequency and
    the utterances attached to each node/edge.

    Returns:
        A gravis figure object representing the interactive graph
        visualization.
    """

    def _freq_to_colors(items: list[dict]) -> tuple[list[str], list[int]]:
        # Per-item colour/frequency extraction. Unlike checking only the
        # first item, this also copes with an empty list (no IndexError)
        # and with frequency == 0 (no ZeroDivisionError) — connect_nodes
        # initializes frequencies to 0, so both cases are reachable.
        colors: list[str] = []
        freqs: list[int] = []
        for item in items:
            freq = item.get("frequency", 0)
            freqs.append(freq)
            if freq > 0:
                rgb = colorsys.hsv_to_rgb(1 / freq / 2, 1.0, 1.0)
                colors.append("#%02x%02x%02x" % tuple(round(255 * c) for c in rgb))
            else:
                colors.append("#000000")
        return colors, freqs

    node_colors, node_frequency = _freq_to_colors(self.graph_dict["nodes"])
    edge_colors, edge_frequency = _freq_to_colors(self.graph_dict["edges"])

    graph = {"graph": {}}
    graph["graph"]["nodes"] = {
        str(node["id"]): {
            "label": f"{node['id']}:{len(node['utterances'])}",
            "metadata": {
                "hover": f"frequency: {node_frequency[idx]}\n"
                + "\n".join(
                    f"{i + 1}: {utt}" for i, utt in enumerate(node["utterances"])
                ),
                "color": node_colors[idx],
            },
        }
        for idx, node in enumerate(self.graph_dict["nodes"])
    }
    graph["graph"]["edges"] = [
        {
            "source": str(edge["source"]),
            "target": str(edge["target"]),
            "label": len(edge["utterances"]),
            "metadata": {
                "hover": f"frequency: {edge_frequency[idx]}\n"
                + "\n".join(
                    f"{i + 1}: {utt}" for i, utt in enumerate(edge["utterances"])
                ),
                "color": edge_colors[idx],
            },
        }
        for idx, edge in enumerate(self.graph_dict["edges"])
    ]
    return gv.vis(
        graph,
        show_node_label=True,
        show_edge_label=True,
        node_label_data_source="label",
        edge_label_data_source="label",
        edge_label_size_factor=1.7,
        layout_algorithm="hierarchicalRepulsion",
    )



def find_nodes_by_utterance(self, utterance: str) -> list[dict]:
"""Find nodes containing a specific utterance.

Expand Down
19 changes: 18 additions & 1 deletion dialog2graph/utils/dg_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,19 @@ def connect_nodes(
"""
edges = []
node_store = NodeStore(nodes, utt_sim)
for idx in range(len(nodes)):
nodes[idx]["frequency"] = 0
for dialog in dialogs:
turns = dialog.to_list()
dialog_store = DialogStore(turns, utt_sim)
for node in nodes:
for utt in node["utterances"]:
ids = dialog_store.search_assistant(utt)
ids = dialog_store.search_store(
dialog_store.assistant_store,
dialog_store.assistant_size,
utt
)
node["frequency"] += len(ids)
if ids:
for id, user_utt in zip(ids, dialog_store.get_user_by_id(ids=ids)):
if len(turns) > 2 * (int(id) + 1):
Expand Down Expand Up @@ -66,6 +73,7 @@ def connect_nodes(
"utterances"
]
+ [user_utt],
"frequency": 0
}
)
else:
Expand All @@ -74,8 +82,17 @@ def connect_nodes(
"source": node["id"],
"target": target,
"utterances": [user_utt],
"frequency": 0,
}
)
for edge in edges:
for utt in edge["utterances"]:
ids = dialog_store.search_store(
dialog_store.user_store,
dialog_store.user_size,
utt
)
edge["frequency"] += len(ids)
return {"edges": edges, "nodes": nodes}


Expand Down
43 changes: 25 additions & 18 deletions dialog2graph/utils/vector_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@ class DialogStore:
User and assistant utterances vectorized separately

Attributes:
_assistant_store: store for assistant utterances
_user_store: store for user utterances
_assistant_size: number of assistant utterances
assistant_store: store for assistant utterances
user_store: store for user utterances
assistant_size: number of assistant utterances
user_size: number of user utterances
_score_threshold: similarity threshold
"""

_assistant_store: Chroma
_user_store: Chroma
_assistant_size: int
assistant_store: Chroma
user_store: Chroma
assistant_size: int
user_size: int
_score_threshold: int

def _load_dialog(
Expand All @@ -39,10 +41,10 @@ def _load_dialog(
dialog: list of dicts in a form {"participant": "user" or "assistant", "text": text}
embedder: embedding function for vector store
"""
self._assistant_store = Chroma(
self.assistant_store = Chroma(
collection_name=str(uuid.uuid4()), embedding_function=embedder
)
self._user_store = Chroma(
self.user_store = Chroma(
collection_name=str(uuid.uuid4()), embedding_function=embedder
)
assistant_docs = [
Expand All @@ -53,11 +55,12 @@ def _load_dialog(
]
user_docs = [
Document(page_content=turn["text"].lower(), id=id, metadata={"id": id})
for id, turn in enumerate(d for d in dialog if d["participant"] == "user")
for id, turn in enumerate([d for d in dialog if d["participant"] == "user"])
]
self._assistant_size = len(assistant_docs)
self._assistant_store.add_documents(documents=assistant_docs)
self._user_store.add_documents(documents=user_docs)
self.assistant_size = len(assistant_docs)
self.user_size = len(user_docs)
self.assistant_store.add_documents(documents=assistant_docs)
self.user_store.add_documents(documents=user_docs)

def __init__(
self,
Expand All @@ -75,17 +78,20 @@ def __init__(
self._score_threshold = score_threshold
self._load_dialog(dialog, embedder)

def search_assistant(self, utterance) -> list[str]:
"""Search for utterance over assistant store
def search_store(self, store: Chroma, size: int, utterance: str) -> list[str]:
"""Search for utterance over store

Args:
store: Chroma store
size: size of the store
utterance: utterance to search for
Returns:
list of found documents ids of assistant store
list of found documents ids
"""
docs = self._assistant_store.similarity_search_with_relevance_scores(

docs = store.similarity_search_with_relevance_scores(
utterance.lower(),
k=self._assistant_size,
k=size,
score_threshold=self._score_threshold,
)
res = [d[0].metadata["id"] for d in docs]
Expand All @@ -94,6 +100,7 @@ def search_assistant(self, utterance) -> list[str]:

return res


def get_user_by_id(self, ids: list[str]) -> list[str]:
    """Fetch the stored user utterances for the given document ids.

    Args:
        ids: list of document ids to look up in the user store
    Returns:
        list of the matching user utterances
    """
    return self.user_store.get(ids=ids)["documents"]


Expand Down
Loading
Loading