Skip to content

Commit a7be34c

Browse files
added notebooks and some fixes
1 parent 2340d1b commit a7be34c

File tree

17 files changed

+3487
-2458
lines changed

main.ipynb

Lines changed: 0 additions & 116 deletions
This file was deleted.

notebooks/sql_generator.ipynb

Lines changed: 130 additions & 0 deletions
Large diffs are not rendered by default.

notebooks/upstream.ipynb

Lines changed: 1652 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@ test = [
4646
"pytest-asyncio>=1.1.0",
4747
]
4848
lint = ["ruff"]
49-
dev = [
50-
"pytest>=8.4.1",
51-
"pytest-cov>=6.2.1",
52-
]
5349

5450
[tool.ruff]
5551
src = ["src"]

src/data_tools/analysis/models.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
1+
import json
2+
import os
13
import uuid
24

3-
from typing import Any, Dict
5+
from typing import Any, Dict, Optional
46

7+
import yaml
8+
9+
from data_tools.common.exception import errors
10+
from data_tools.core import settings
511
from data_tools.dataframes.factory import DataFrameFactory
12+
from data_tools.dataframes.models import ColumnProfile
13+
from data_tools.models.resources.model import Column, ColumnProfilingMetrics
14+
from data_tools.models.resources.source import Source, SourceTables
615

716

817
class DataSet:
@@ -22,3 +31,49 @@ def __init__(self, df: Any, name: str):
2231

2332
# A dictionary to store the results of each analysis step
2433
self.results: Dict[str, Any] = {}
34+
35+
# FIXME - this is a temporary solution to save the results of the analysis
36+
# need to use model while executing the pipeline
37+
def save_yaml(self, file_path: Optional[str] = None) -> None:
38+
if file_path is None:
39+
file_path = f"{self.name}.yml"
40+
file_path = os.path.join(settings.PROJECT_BASE, file_path)
41+
42+
column_profiles = self.results.get("column_profiles")
43+
44+
table_description = self.results.get("table_glossary")
45+
table_tags = self.results.get("business_glossary_and_tags")
46+
47+
if column_profiles is None or table_description is None or table_tags is None:
48+
raise errors.NotFoundError(
49+
"Column profiles not found in the dataset results. Ensure profiling steps were executed."
50+
)
51+
52+
columns: list[Column] = []
53+
54+
for column_profile in column_profiles.values():
55+
column_profile = ColumnProfile.model_validate(column_profile)
56+
column = Column(
57+
name=column_profile.column_name,
58+
description=column_profile.business_glossary,
59+
type=column_profile.datatype_l1,
60+
category=column_profile.datatype_l2,
61+
tags=column_profile.business_tags,
62+
profiling_metrics=ColumnProfilingMetrics(
63+
count=column_profile.count,
64+
null_count=column_profile.null_count,
65+
distinct_count=column_profile.distinct_count,
66+
sample_data=column_profile.sample_data,
67+
),
68+
)
69+
columns.append(column)
70+
71+
table = SourceTables(name=self.name, description=table_description, columns=columns)
72+
73+
source = Source(name="healthcare", description=table_description, schema="public", database="", table=table)
74+
75+
sources = {"sources": [json.loads(source.model_dump_json())]}
76+
77+
# Save the YAML representation of the sources
78+
with open(file_path, "w") as file:
79+
yaml.dump(sources, file, sort_keys=False, default_flow_style=False)

src/data_tools/analysis/steps.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,22 @@ def analyze(self, dataset: DataSet) -> None:
3838
Performs column-level profiling for each column.
3939
This step depends on the 'table_profile' result.
4040
"""
41-
41+
4242
# Dependency check
43-
if 'table_profile' not in dataset.results:
43+
if "table_profile" not in dataset.results:
4444
raise RuntimeError("TableProfiler must be run before ColumnProfiler.")
4545

46-
table_profile: ProfilingOutput = dataset.results['table_profile']
46+
table_profile: ProfilingOutput = dataset.results["table_profile"]
4747
all_column_profiles = {}
4848

4949
for col_name in table_profile.columns:
5050
# We would add a method to our DataFrame wrapper to get stats for a single column
51-
stats = dataset.dataframe_wrapper.column_profile(dataset.raw_df, dataset.name, col_name, settings.UPSTREAM_SAMPLE_LIMIT)
51+
stats = dataset.dataframe_wrapper.column_profile(
52+
dataset.raw_df, dataset.name, col_name, settings.UPSTREAM_SAMPLE_LIMIT
53+
)
5254
all_column_profiles[col_name] = stats
53-
54-
dataset.results['column_profiles'] = all_column_profiles
55+
56+
dataset.results["column_profiles"] = all_column_profiles
5557

5658

5759
class DataTypeIdentifierL1(AnalysisStep):
@@ -60,19 +62,21 @@ def analyze(self, dataset: DataSet) -> None:
6062
Performs datatype identification level 1 for each column.
6163
This step depends on the 'column_profiles' result.
6264
"""
63-
65+
6466
# Dependency check
65-
if 'column_profiles' not in dataset.results:
67+
if "column_profiles" not in dataset.results:
6668
raise RuntimeError("TableProfiler and ColumnProfiler must be run before DatatypeIdentifierL1.")
6769

68-
column_profiles: dict[str, ColumnProfile] = dataset.results['column_profiles']
70+
column_profiles: dict[str, ColumnProfile] = dataset.results["column_profiles"]
6971

70-
column_datatypes_l1 = dataset.dataframe_wrapper.datatype_identification_l1(dataset.raw_df, dataset.name, column_profiles)
72+
column_datatypes_l1 = dataset.dataframe_wrapper.datatype_identification_l1(
73+
dataset.raw_df, dataset.name, column_profiles
74+
)
7175

7276
for column in column_datatypes_l1:
7377
column_profiles[column.column_name].datatype_l1 = column.datatype_l1
7478

75-
dataset.results['column_datatypes_l1'] = column_datatypes_l1
79+
dataset.results["column_datatypes_l1"] = column_datatypes_l1
7680

7781

7882
class DataTypeIdentifierL2(AnalysisStep):
@@ -81,19 +85,21 @@ def analyze(self, dataset: DataSet) -> None:
8185
Performs datatype identification level 2 for each column.
8286
This step depends on the 'column_datatypes_l1' result.
8387
"""
84-
88+
8589
# Dependency check
86-
if 'column_profiles' not in dataset.results:
90+
if "column_profiles" not in dataset.results:
8791
raise RuntimeError("TableProfiler and ColumnProfiler must be run before DatatypeIdentifierL2.")
8892

89-
column_profiles: dict[str, ColumnProfile] = dataset.results['column_profiles']
93+
column_profiles: dict[str, ColumnProfile] = dataset.results["column_profiles"]
9094
columns_with_samples = [DataTypeIdentificationL2Input(**col.model_dump()) for col in column_profiles.values()]
91-
column_datatypes_l2 = dataset.dataframe_wrapper.datatype_identification_l2(dataset.raw_df, dataset.name, columns_with_samples)
95+
column_datatypes_l2 = dataset.dataframe_wrapper.datatype_identification_l2(
96+
dataset.raw_df, dataset.name, columns_with_samples
97+
)
9298

9399
for column in column_datatypes_l2:
94100
column_profiles[column.column_name].datatype_l2 = column.datatype_l2
95101

96-
dataset.results['column_datatypes_l2'] = column_datatypes_l2
102+
dataset.results["column_datatypes_l2"] = column_datatypes_l2
97103

98104

99105
class KeyIdentifier(AnalysisStep):
@@ -102,21 +108,22 @@ def analyze(self, dataset: DataSet) -> None:
102108
Performs key identification for the dataset.
103109
This step depends on the datatype identification results.
104110
"""
105-
if 'column_datatypes_l1' not in dataset.results or 'column_datatypes_l2' not in dataset.results:
111+
if "column_datatypes_l1" not in dataset.results or "column_datatypes_l2" not in dataset.results:
106112
raise RuntimeError("DataTypeIdentifierL1 and L2 must be run before KeyIdentifier.")
107-
108-
column_profiles: dict[str, ColumnProfile] = dataset.results['column_profiles']
113+
114+
column_profiles: dict[str, ColumnProfile] = dataset.results["column_profiles"]
109115
column_profiles_df = pd.DataFrame([col.model_dump() for col in column_profiles.values()])
110116

111117
key = dataset.dataframe_wrapper.key_identification(dataset.name, column_profiles_df)
112-
dataset.results["key"] = key
118+
if key is not None:
119+
dataset.results["key"] = key
113120

114121

115122
class BusinessGlossaryGenerator(AnalysisStep):
116123
    def __init__(self, domain: str):
        """
        Initializes the BusinessGlossaryGenerator for a specific industry domain.

        :param domain: The industry domain to which the dataset belongs; used
            as context when generating glossary terms and tags.
        """
        self.domain = domain
@@ -125,10 +132,10 @@ def analyze(self, dataset: DataSet) -> None:
125132
"""
126133
Generates business glossary terms and tags for each column in the dataset.
127134
"""
128-
if 'column_datatypes_l1' not in dataset.results:
135+
if "column_datatypes_l1" not in dataset.results:
129136
raise RuntimeError("DataTypeIdentifierL1 must be run before Business Glossary Generation.")
130-
131-
column_profiles: dict[str, ColumnProfile] = dataset.results['column_profiles']
137+
138+
column_profiles: dict[str, ColumnProfile] = dataset.results["column_profiles"]
132139
column_profiles_df = pd.DataFrame([col.model_dump() for col in column_profiles.values()])
133140

134141
glossary_output = dataset.dataframe_wrapper.generate_business_glossary(
@@ -138,7 +145,6 @@ def analyze(self, dataset: DataSet) -> None:
138145
for column in glossary_output.columns:
139146
column_profiles[column.column_name].business_glossary = column.business_glossary
140147
column_profiles[column.column_name].business_tags = column.business_tags
141-
142-
dataset.results["business_glossary_and_tags"] = glossary_output
143-
dataset.results['table_glossary'] = glossary_output.table_glossary
144148

149+
dataset.results["business_glossary_and_tags"] = glossary_output
150+
dataset.results["table_glossary"] = glossary_output.table_glossary

src/data_tools/core/settings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
class Settings(BaseSettings):
1414
"""Global Configuration"""
1515

16-
UPSTREAM_SAMPLE_LIMIT: int = 10000
16+
UPSTREAM_SAMPLE_LIMIT: int = 10
1717
MODEL_DIR_PATH: str = str(Path(os.path.split(os.path.abspath(__file__))[0]).parent.joinpath("artifacts"))
1818
MODEL_RESULTS_PATH: str = os.path.join("model", "model_results")
1919

@@ -24,7 +24,7 @@ class Settings(BaseSettings):
2424

2525
DI_MODEL_VERSION: str = "13052023"
2626

27-
PROJECT_BASE: str = "/home/juhel-phanju/Documents/backup/MIGRATION/codes/poc/dbt/ecom/ecom/models"
27+
PROJECT_BASE: str
2828

2929
MCP_SERVER_NAME: str = "data-tools"
3030
MCP_SERVER_DESCRIPTION: str = "Data Tools for MCP"

src/data_tools/dataframes/dataframe.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any
2+
from typing import Any, Optional
33

44
import pandas as pd
55

@@ -20,7 +20,7 @@
2020
)
2121

2222

23-
class DataframeAdatper(ABC):
23+
class DataFrame(ABC):
2424
@abstractmethod
2525
def profile(self, df: Any) -> ProfilingOutput:
2626
pass
@@ -31,7 +31,7 @@ def column_profile(
3131
df: Any,
3232
table_name: str,
3333
column_name: str,
34-
sample_limit: int = 200,
34+
sample_limit: int = 10,
3535
) -> ColumnProfile:
3636
pass
3737

@@ -94,7 +94,7 @@ def key_identification(
9494
self,
9595
table_name: str,
9696
column_stats: pd.DataFrame,
97-
) -> KeyIdentificationOutput:
97+
) -> Optional[str]:
9898
"""
9999
Identifies potential primary keys in the DataFrame based on column profiles.
100100
@@ -104,12 +104,12 @@ def key_identification(
104104
`column_profile` method.
105105
106106
Returns:
107-
A KeyIdentificationOutput model containing the identified primary key column.
107+
A string (column name) containing the identified primary key column.
108108
"""
109109
ki_model = KeyIdentificationLLM(profiling_data=column_stats)
110110
ki_result = ki_model()
111111
output = KeyIdentificationOutput(**ki_result)
112-
return output
112+
return output.column_name
113113

114114
def generate_business_glossary(
115115
self,

0 commit comments

Comments
 (0)