Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion docling_graph/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ class ModelsConfig(BaseModel):
llm: LLMConfig = Field(default_factory=LLMConfig)
vlm: VLMConfig = Field(default_factory=VLMConfig)

class Neo4jConfig(BaseModel):
    """Configuration for the Neo4j database connection and ingestion.

    Every field has a default, so the model can be built with no arguments.
    """

    # Connection URI handed to the Neo4j driver.
    uri: str = Field(default="bolt://localhost:7687", description="Neo4j URI")
    # Credentials used to authenticate against the server.
    # NOTE(review): the default password is a placeholder; override it in any
    # real deployment.
    username: str = Field(default="neo4j", description="Database username")
    password: str = Field(default="password", description="Database password")
    # Name of the database the export session is opened against.
    database: str = Field(default="neo4j", description="Target database name")
    # Rows sent per write query. Must be at least 1: a step of 0 makes the
    # batching range() raise, and a negative step yields no batches at all.
    batch_size: int = Field(default=1000, ge=1, description="Batch size for ingestion")
    # "merge" writes with MERGE, "create" writes with CREATE.
    write_mode: Literal["merge", "create"] = Field(default="merge", description="Write strategy")

class PipelineConfig(BaseModel):
"""
Expand Down Expand Up @@ -101,13 +110,15 @@ class PipelineConfig(BaseModel):
# Models configuration (flat only, with defaults)
models: ModelsConfig = Field(default_factory=ModelsConfig)

neo4j: Neo4jConfig = Field(default_factory=Neo4jConfig)

# Extract settings (with defaults)
use_chunking: bool = True
llm_consolidation: bool = False
max_batch_size: int = 1

# Export settings (with defaults)
export_format: Literal["csv", "cypher"] = Field(default="csv")
export_format: Literal["csv", "cypher", "neo4j"] = Field(default="csv")
export_docling: bool = Field(default=True)
export_docling_json: bool = Field(default=True)
export_markdown: bool = Field(default=True)
Expand Down Expand Up @@ -153,6 +164,7 @@ def to_dict(self) -> Dict[str, Any]:
"reverse_edges": self.reverse_edges,
"output_dir": self.output_dir,
"models": self.models.model_dump(),
"neo4j": self.neo4j.model_dump(),
}

def run(self) -> None:
Expand Down Expand Up @@ -185,6 +197,7 @@ def generate_yaml_dict(cls) -> Dict[str, Any]:
},
},
"models": default_config.models.model_dump(),
"neo4j": default_config.neo4j.model_dump(),
"output": {
"directory": str(default_config.output_dir),
},
Expand Down
3 changes: 3 additions & 0 deletions docling_graph/db_clients/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .neo4j_client import Neo4jExporter

__all__ = ["Neo4jExporter"]
156 changes: 156 additions & 0 deletions docling_graph/db_clients/neo4j_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from typing import Any, Dict, List, Optional
import networkx as nx
from neo4j import GraphDatabase, Driver
from rich import print as rich_print

class Neo4jExporter:
    """Exporter for populating a live Neo4j database from a NetworkX graph.

    Nodes are grouped by their sanitized ``label`` attribute and edges by
    their upper-cased, sanitized ``label`` attribute. The label / relationship
    type is interpolated into the query text, and each group is written with
    ``UNWIND $batch`` queries of at most ``batch_size`` rows.
    """

    # Accepted values for ``write_mode`` (compared after lower-casing).
    _WRITE_MODES = ("merge", "create")

    def __init__(
        self,
        uri: str,
        auth: Optional[tuple[str, str]] = None,
        database: str = "neo4j",
        batch_size: int = 1000,
        write_mode: str = "merge",  # "merge" or "create"
    ):
        """
        Initialize the Neo4j exporter.

        No connection is opened here; the driver is created on first use.

        Args:
            uri: Neo4j database URI (e.g., 'bolt://localhost:7687')
            auth: Tuple of (username, password)
            database: Database name to use
            batch_size: Maximum number of rows sent with a single query
            write_mode: Strategy for writing nodes ('merge' updates existing, 'create' adds new)

        Raises:
            ValueError: If ``write_mode`` is not 'merge' or 'create', or if
                ``batch_size`` is smaller than 1.
        """
        normalized_mode = write_mode.lower()
        if normalized_mode not in self._WRITE_MODES:
            raise ValueError(
                f"Unknown write_mode: {write_mode!r} (expected 'merge' or 'create')"
            )
        # A step of 0 would make range() raise and a negative step would
        # silently produce no batches, so reject both up front.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

        self.uri = uri
        self.auth = auth
        self.database = database
        self.batch_size = batch_size
        self.write_mode = normalized_mode
        self._driver: Optional["Driver"] = None

    def _get_driver(self) -> "Driver":
        """Return the cached driver, creating it on first call."""
        if self._driver is None:
            self._driver = GraphDatabase.driver(self.uri, auth=self.auth)
        return self._driver

    def close(self) -> None:
        """Close the driver if one is open and forget it."""
        if self._driver is not None:
            self._driver.close()
            self._driver = None

    def export(self, graph: "nx.DiGraph") -> None:
        """
        Export the NetworkX graph to Neo4j.

        Nodes are written before relationships because the relationship
        queries MATCH their endpoints by ``id``. The driver is closed when the
        export finishes, whether it succeeded or failed.

        Args:
            graph: The NetworkX directed graph to export

        Raises:
            Exception: Any error raised while exporting is reported and
                re-raised unchanged.
        """
        if graph.number_of_nodes() == 0:
            rich_print("[yellow]Graph is empty. Skipping Neo4j export.[/yellow]")
            return

        driver = self._get_driver()

        try:
            with driver.session(database=self.database) as session:
                # 1. Export Nodes
                self._export_nodes(session, graph)

                # 2. Export Relationships
                self._export_edges(session, graph)

            rich_print(f"[green]Successfully exported graph to Neo4j database '{self.database}'[/green]")
        except Exception as e:
            rich_print(f"[red]Failed to export to Neo4j:[/red] {e}")
            raise
        finally:
            self.close()

    @staticmethod
    def _sanitize_identifier(raw: Any, default: str) -> str:
        """Reduce ``raw`` to alphanumerics and underscores, or return ``default``.

        ``None`` and values that sanitize to an empty string map to
        ``default``; non-string values are converted with ``str()`` first.
        """
        if raw is None:
            return default
        cleaned = "".join(ch for ch in str(raw) if ch.isalnum() or ch == "_")
        return cleaned or default

    def _write_keyword(self) -> str:
        """Return the Cypher write clause matching ``write_mode``."""
        return "MERGE" if self.write_mode == "merge" else "CREATE"

    def _node_cypher(self, label: str) -> str:
        """Build the batched node query for one (already sanitized) label."""
        # Backticks keep the label valid even when it starts with a digit.
        return (
            "UNWIND $batch AS row "
            f"{self._write_keyword()} (n:`{label}` {{id: row.id}}) "
            "SET n += row.properties"
        )

    def _edge_cypher(self, rel_type: str) -> str:
        """Build the batched relationship query for one (already sanitized) type."""
        # NOTE(review): endpoints are matched on `id` across all labels;
        # consider constraining the MATCH by label if lookups become slow.
        return (
            "UNWIND $batch AS row "
            "MATCH (source {id: row.source_id}) "
            "MATCH (target {id: row.target_id}) "
            f"{self._write_keyword()} (source)-[r:`{rel_type}`]->(target) "
            "SET r += row.properties"
        )

    def _group_nodes(self, graph: "nx.DiGraph") -> Dict[str, List[Dict[str, Any]]]:
        """Group node rows by sanitized label.

        Each row is ``{"id": node_id, "properties": {...}}``, which is the
        shape the node query reads (``row.id`` and ``row.properties``).
        """
        grouped: Dict[str, List[Dict[str, Any]]] = {}
        for node_id, data in graph.nodes(data=True):
            label = self._sanitize_identifier(data.get("label"), "Entity")
            properties = {key: value for key, value in data.items() if key != "label"}
            properties["id"] = node_id  # Ensure ID is a property
            grouped.setdefault(label, []).append({"id": node_id, "properties": properties})
        return grouped

    def _group_edges(self, graph: "nx.DiGraph") -> Dict[str, List[Dict[str, Any]]]:
        """Group edge rows by upper-cased, sanitized relationship type.

        Endpoint ids are kept apart from the edge attributes so they are used
        for matching only and are not stored on the relationship.
        """
        grouped: Dict[str, List[Dict[str, Any]]] = {}
        for source, target, data in graph.edges(data=True):
            raw_type = data.get("label")
            rel_type = self._sanitize_identifier(
                None if raw_type is None else str(raw_type).upper(),
                "RELATED_TO",
            )
            properties = {key: value for key, value in data.items() if key != "label"}
            grouped.setdefault(rel_type, []).append(
                {"source_id": source, "target_id": target, "properties": properties}
            )
        return grouped

    def _run_batches(self, session, cypher: str, rows: List[Dict[str, Any]]) -> int:
        """Run ``cypher`` once per ``batch_size`` slice of ``rows``.

        Returns:
            The number of rows submitted.
        """
        submitted = 0
        for start in range(0, len(rows), self.batch_size):
            chunk = rows[start : start + self.batch_size]
            session.run(cypher, batch=chunk)
            submitted += len(chunk)
        return submitted

    def _export_nodes(self, session, graph: "nx.DiGraph") -> None:
        """Batch write nodes to Neo4j, one query template per label."""
        total_nodes = sum(
            self._run_batches(session, self._node_cypher(label), rows)
            for label, rows in self._group_nodes(graph).items()
        )
        rich_print(f" - Exported {total_nodes} nodes")

    def _export_edges(self, session, graph: "nx.DiGraph") -> None:
        """Batch write edges to Neo4j, one query template per relationship type."""
        total_edges = sum(
            self._run_batches(session, self._edge_cypher(rel_type), rows)
            for rel_type, rows in self._group_edges(graph).items()
        )
        rich_print(f" - Exported {total_edges} edges")
16 changes: 14 additions & 2 deletions docling_graph/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
# Import LLM clients
from .llm_clients import BaseLlmClient, get_client

from .db_clients.neo4j_client import Neo4jExporter

def _load_template_class(template_str: str) -> type[BaseModel]:
"""Dynamically imports a Pydantic model class from a string."""
Expand Down Expand Up @@ -230,11 +231,22 @@ def run_pipeline(config: Union[PipelineConfig, Dict[str, Any]]) -> None:

if export_format == "csv":
CSVExporter().export(knowledge_graph, output_dir)
rich_print(f"[green][/green] Saved CSV files to [green]{output_dir}[/green]")
rich_print(f"[green][/green] Saved CSV files to [green]{output_dir}[/green]")
elif export_format == "cypher":
cypher_path = output_dir / f"{base_name}_graph.cypher"
CypherExporter().export(knowledge_graph, cypher_path)
rich_print(f"[green]→[/green] Saved Cypher script to [green]{cypher_path}[/green]")
rich_print(f"[green]✔[/green] Saved Cypher script to [green]{cypher_path}[/green]")
elif export_format == "neo4j":
# Extract Neo4j config
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try the import `from .db_clients.neo4j_client import Neo4jExporter` here.

neo_conf = conf.get("neo4j", {})
exporter = Neo4jExporter(
uri=neo_conf.get("uri", "bolt://localhost:7687"),
auth=(neo_conf.get("username", "neo4j"), neo_conf.get("password", "password")),
database=neo_conf.get("database", "neo4j"),
batch_size=neo_conf.get("batch_size", 1000),
write_mode=neo_conf.get("write_mode", "merge")
)
exporter.export(knowledge_graph)
else:
raise ValueError(f"Unknown export format: {export_format}")

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
"rich>=13,<14",
"typer[all]>=0.12,<1.0.0",
"python-dotenv>=1.0,<2.0",
"neo4j>=5.0.0",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this an optional dependency?

]

[project.urls]
Expand Down