Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 54 additions & 43 deletions tgrag/construct_relational_database/construct_relational_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pickle
import sqlite3
from pathlib import Path
from typing import cast
from typing import List, cast

import numpy as np
import pandas as pd
Expand All @@ -15,7 +15,7 @@

from tgrag.utils.args import parse_args
from tgrag.utils.logger import setup_logging
from tgrag.utils.path import get_root_dir, get_scratch
from tgrag.utils.path import discover_subfolders, get_root_dir, get_scratch
from tgrag.utils.rd_utils import table_has_data
from tgrag.utils.seed import seed_everything
from tgrag.utils.target_generation import strict_exact_etld1_match
Expand All @@ -33,11 +33,11 @@


def construct_formatted_data(
db_path: Path,
node_csv: Path,
output_path: Path,
subfolders: List[Path],
dqr_csv: Path,
seed: int = 42,
D: int = 128,
D: int = 64,
chunk_size: int = 1_000_000,
) -> None:
dqr_df = pd.read_csv(dqr_csv)
Expand All @@ -47,46 +47,49 @@ def construct_formatted_data(
}

rng = np.random.default_rng(seed=seed)
output_path = db_path / 'features.json'
output_path = output_path / 'features.json'

if output_path.exists():
logging.info(f'{output_path} already exists, returning.')
return

logging.info(f'Processing {node_csv} in chunks of {chunk_size:,} rows...')

included: int = 0
with open(output_path, 'w') as f_out:
for chunk in tqdm(
pd.read_csv(node_csv, chunksize=chunk_size),
desc='Reading vertices',
unit='chunk',
):
x_chunk = rng.normal(size=(len(chunk), D)).astype(np.float32)
for folder in tqdm(subfolders, desc='Processing Subfolders'):
node_csv = folder / 'vertices.csv'

logging.info(f'Processing {node_csv} in chunks of {chunk_size:,} rows...')

for i, (_, row) in tqdm(
enumerate(chunk.iterrows()), desc='Iterating chunk'
for chunk in tqdm(
pd.read_csv(node_csv, chunksize=chunk_size),
desc='Reading vertices',
unit='chunk',
):
raw_domain = str(row['domain']).strip()
x_chunk = rng.normal(size=(len(chunk), D)).astype(np.float32)

for i, (_, row) in tqdm(
enumerate(chunk.iterrows()), desc='Iterating chunk'
):
raw_domain = str(row['domain']).strip()

etld1 = strict_exact_etld1_match(raw_domain, dqr_domains)
etld1 = strict_exact_etld1_match(raw_domain, dqr_domains)

if etld1 is None:
y = -1.0
else:
included += 1
y = float(dqr_domains[etld1]['pc1'])
if etld1 is None:
y = -1.0
else:
included += 1
y = float(dqr_domains[etld1]['pc1'])

record = {
'domain': raw_domain,
'ts': int(row['ts']),
'y': y,
'x': x_chunk[i].tolist(),
}
record = {
'domain': raw_domain,
'ts': int(row['ts']),
'y': y,
'x': x_chunk[i].tolist(),
}

f_out.write(json.dumps(record) + '\n')
f_out.write(json.dumps(record) + '\n')

logging.info(f'There are {included} domains that exist in DQR')
logging.info(f'There are {included} domains that intersect with DQR')
logging.info(f'Streaming write complete to {output_path}')


Expand Down Expand Up @@ -121,9 +124,9 @@ def initialize_graph_db(db_path: Path) -> sqlite3.Connection:


def construct_masks_from_json(
nid_map_path: Path, json_path: Path, db_path: Path, seed: int = 0
output_path: Path, nid_map_path: Path, json_path: Path, seed: int
) -> None:
output_path = db_path / 'split_idx.pt'
output_path = output_path / 'split_idx.pt'
if output_path.exists():
logging.info(f'{output_path} already exists, returning.')
return
Expand Down Expand Up @@ -238,7 +241,7 @@ def populate_from_json(

x = np.array(record['x'], dtype=np.float32).tobytes()
con.execute(
'INSERT INTO domain VALUES (?, ?, ?, ?)',
'INSERT OR IGNORE INTO domain VALUES (?, ?, ?, ?)',
(id, int(record['ts']), x, float(record['y'])),
)
logging.info('Database populated')
Expand All @@ -257,24 +260,32 @@ def main() -> None:
setup_logging(meta_args.log_file_path)
seed_everything(meta_args.global_seed)

db_path = scratch / cast(str, meta_args.database_folder)
node_path = scratch / cast(str, meta_args.node_file)
base_dir = scratch / cast(str, meta_args.database_folder)
aggregate_out = base_dir / 'aggregate'
aggregate_out.parent.mkdir(parents=True, exist_ok=True)

logging.info(f'Scanning base directory: {base_dir}')
subfolders = discover_subfolders(base_dir)
dqr_path = root / 'data' / 'dqr' / 'domain_pc1.csv'

construct_formatted_data(db_path=db_path, node_csv=node_path, dqr_csv=dqr_path)
db_path = scratch / cast(str, meta_args.database_folder)

construct_formatted_data(
output_path=aggregate_out, subfolders=subfolders, dqr_csv=dqr_path
)
construct_masks_from_json(
nid_map_path=db_path / 'nid_map.pkl',
json_path=db_path / 'features.json',
db_path=db_path,
output_path=aggregate_out,
nid_map_path=aggregate_out / 'global_domain_to_id.pkl',
json_path=aggregate_out / 'features.json',
seed=meta_args.global_seed,
)
con = initialize_graph_db(db_path=db_path)
populate_from_json(
con=con,
nid_map_path=db_path / 'nid_map.pkl',
json_path=db_path / 'features.json',
nid_map_path=aggregate_out / 'global_domain_to_id.pkl',
json_path=aggregate_out / 'features.json',
)
populate_edges(con=con, edges_path=db_path / 'edges_with_id.csv')
populate_edges(con=con, edges_path=aggregate_out / 'edges_with_id.csv')
logging.info('Completed.')


Expand Down
146 changes: 146 additions & 0 deletions tgrag/construct_relational_database/domain_to_id_mapping_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import argparse
import faulthandler
import logging
import pickle
from pathlib import Path
from typing import Dict, List, cast

import numpy as np
import pandas as pd
from tqdm import tqdm

from tgrag.utils.args import parse_args
from tgrag.utils.logger import setup_logging
from tgrag.utils.path import discover_subfolders, get_root_dir, get_scratch
from tgrag.utils.seed import seed_everything

# Command-line interface: the only option is where to find the YAML config.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description='Aggregate domain-to-ID mapping and rewrite all vertices/edges.',
)
parser.add_argument(
    '--config-file',
    default='configs/tgl/base.yaml',
    type=str,
    help='Path to yaml configuration file',
)


def build_global_mapping(
    output_folder: Path, subfolders: List[Path], chunk_size: int = 1_000_000
) -> Dict[str, int]:
    """Scan all vertices.csv files and build one global domain-to-id mapping.

    IDs are dense integers (0..n-1), assigned in order of first appearance
    across all subfolders.  Two artifacts are written to *output_folder*:
    ``global_domain_to_id.pkl`` (the mapping itself) and
    ``global_domain_ids.npy`` (the contiguous id range as int64).

    Args:
        output_folder: Directory the mapping artifacts are written into.
        subfolders: Crawl subfolders, each expected to contain ``vertices.csv``.
        chunk_size: Rows per pandas chunk while streaming each CSV.

    Returns:
        The domain -> integer id mapping.
    """
    logging.info('Building global domain-to-id mapping...')

    domain_to_id: Dict[str, int] = {}

    for folder in subfolders:
        node_csv = folder / 'vertices.csv'
        logging.info(f'Scanning domains in: {node_csv}')

        for chunk in tqdm(
            pd.read_csv(node_csv, chunksize=chunk_size),
            desc=f'Scanning {folder.name}',
            unit='chunk',
        ):
            # Dedupe at the pandas level first: Series.unique() preserves
            # order of first appearance, so ids come out identical to a
            # row-by-row scan, but the dict is probed once per *unique*
            # domain in the chunk instead of once per row.
            for domain in chunk['domain'].astype(str).unique():
                if domain not in domain_to_id:
                    # len() is always the next unused dense id.
                    domain_to_id[domain] = len(domain_to_id)

    logging.info(f'Total unique domains: {len(domain_to_id):,}')
    with open(output_folder / 'global_domain_to_id.pkl', 'wb') as f:
        pickle.dump(domain_to_id, f)

    np.save(
        output_folder / 'global_domain_ids.npy',
        np.arange(len(domain_to_id), dtype=np.int64),
    )
    return domain_to_id


def aggregate_rewrite(
    subfolders: List[Path],
    domain_to_id: Dict[str, int],
    aggregate_out: Path,
    chunk_size: int = 1_000_000,
) -> None:
    """Rewrite and append all subfolder vertices/edges into global outputs.

    Streams each subfolder's ``vertices.csv`` and ``edges.csv`` in chunks,
    replaces domain strings with their global integer ids, and appends the
    rows to ``vertices_with_id.csv`` / ``edges_with_id.csv`` under
    *aggregate_out*.  Both outputs are truncated and re-headered first.
    """
    aggregate_out.mkdir(parents=True, exist_ok=True)

    vertices_out = aggregate_out / 'vertices_with_id.csv'
    edges_out = aggregate_out / 'edges_with_id.csv'

    # Truncate both outputs and emit the header row once, up front.
    vertices_out.write_text('id,ts\n')
    edges_out.write_text('src_id,dst_id,ts\n')

    for folder in subfolders:
        logging.info(f'Aggregating subfolder: {folder}')

        with open(vertices_out, 'a') as sink:
            vertex_chunks = pd.read_csv(
                folder / 'vertices.csv', chunksize=chunk_size
            )
            for chunk in tqdm(
                vertex_chunks, desc=f'Vertices {folder.name}', unit='chunk'
            ):
                chunk['id'] = chunk['domain'].map(domain_to_id)
                mapped = chunk[['id', 'ts']].astype({'id': 'int64'})
                mapped.to_csv(sink, header=False, index=False)

        with open(edges_out, 'a') as sink:
            edge_chunks = pd.read_csv(folder / 'edges.csv', chunksize=chunk_size)
            for chunk in tqdm(
                edge_chunks, desc=f'Edges {folder.name}', unit='chunk'
            ):
                chunk['src_id'] = chunk['src'].map(domain_to_id)
                chunk['dst_id'] = chunk['dst'].map(domain_to_id)
                mapped = chunk[['src_id', 'dst_id', 'ts']].astype(
                    {'src_id': 'int64', 'dst_id': 'int64'}
                )
                mapped.to_csv(sink, header=False, index=False)

    logging.info(f'Aggregate outputs complete at: {aggregate_out}')


def main() -> None:
    """Entry point: build the global domain mapping, then rewrite all subfolders."""
    faulthandler.enable()

    root = get_root_dir()
    scratch = get_scratch()
    cli = parser.parse_args()

    meta_args, _ = parse_args(root / cli.config_file)

    setup_logging(meta_args.log_file_path)
    seed_everything(meta_args.global_seed)

    # All aggregate artifacts live in <scratch>/<database_folder>/aggregate.
    base_dir = scratch / cast(str, meta_args.database_folder)
    aggregate_out = base_dir / 'aggregate'
    aggregate_out.mkdir(parents=True, exist_ok=True)

    logging.info(f'Scanning base directory: {base_dir}')
    subfolders = discover_subfolders(base_dir)
    if not subfolders:
        raise RuntimeError(f'No valid subfolders found in {base_dir}')

    mapping = build_global_mapping(
        output_folder=aggregate_out, subfolders=subfolders
    )
    aggregate_rewrite(subfolders, mapping, aggregate_out)

    logging.info('Completed.')


# Run the aggregation pipeline only when executed as a script.
if __name__ == '__main__':
    main()
7 changes: 5 additions & 2 deletions tgrag/experiments/gnn_experiments/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tgrag.experiments.gnn_experiments.gnn_experiment import run_gnn_baseline
from tgrag.utils.args import parse_args
from tgrag.utils.logger import setup_logging
from tgrag.utils.mem import mem
from tgrag.utils.path import get_root_dir, get_scratch
from tgrag.utils.plot import (
load_all_loss_tuples,
Expand All @@ -35,7 +36,7 @@

def main() -> None:
root = get_root_dir()
scratch = get_scratch()
get_scratch()
args = parser.parse_args()
config_file_path = root / args.config_file
meta_args, experiment_args = parse_args(config_file_path)
Expand All @@ -57,6 +58,7 @@ def main() -> None:

logging.info(f'Encoding Dictionary: {encoding_dict}')

logging.info(f'Memory before: {mem():2f} MB')
dataset = TemporalDataset(
root=f'{root}/data/',
node_file=cast(str, meta_args.node_file),
Expand All @@ -68,9 +70,10 @@ def main() -> None:
index_col=meta_args.index_col,
encoding=encoding_dict,
seed=meta_args.global_seed,
processed_dir=f'{scratch}/{meta_args.processed_location}',
processed_dir=cast(str, meta_args.processed_location),
) # Map to .to_cpu()
logging.info('In-Memory Dataset loaded.')
logging.info(f'Memory after TemporalDataset load: {mem():2f} MB')

for experiment, experiment_arg in experiment_args.exp_args.items():
logging.info(f'\n**Running**: {experiment}')
Expand Down
3 changes: 3 additions & 0 deletions tgrag/experiments/gnn_experiments/main_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from tgrag.utils.args import parse_args
from tgrag.utils.logger import setup_logging
from tgrag.utils.mem import mem
from tgrag.utils.path import get_root_dir, get_scratch
from tgrag.utils.plot import (
load_all_loss_tuples,
Expand Down Expand Up @@ -42,6 +43,7 @@ def main() -> None:

logging.info(f'Scratch Location: {scratch}')

logging.info(f'Memory before ZarrDataset: {mem():2f} MB')
dataset = ZarrDataset(
root=f'{root}/data/',
node_file=cast(str, meta_args.node_file),
Expand All @@ -56,6 +58,7 @@ def main() -> None:
database_folder=cast(str, meta_args.database_folder),
)
logging.info('In-Memory Zarr Dataset loaded.')
logging.info(f'Memory after ZarrDataset load: {mem():2f} MB')
zarr_path = scratch / cast(str, meta_args.database_folder) / 'embeddings.zarr'
logging.info(f'Reading Zarr storage from: {zarr_path}')
embeddings = zarr.open_array(str(zarr_path))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def run_weak_supervision_forward(

def main() -> None:
root = get_root_dir()
scratch = get_scratch()
get_scratch()
args = parser.parse_args()
config_file_path = root / args.config_file
meta_args, experiment_args = parse_args(config_file_path)
Expand All @@ -141,7 +141,7 @@ def main() -> None:
index_col=meta_args.index_col,
encoding=encoding_dict,
seed=meta_args.global_seed,
processed_dir=f'{scratch}/{meta_args.processed_location}',
processed_dir=cast(str, meta_args.processed_location),
) # Map to .to_cpu()
logging.info('In-Memory Dataset loaded.')
weight_directory = (
Expand Down
Loading