Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,39 @@

import argparse
import logging
import os

import torch
from safetensors.torch import save_file as safetensors_save_file

from src.utils.load_checkpoint import load_checkpoint


logger = logging.getLogger(__name__)

# Hyperparameter keys kept when converting a checkpoint; any other key in the
# source checkpoint's "hyper_parameters" dict is dropped by
# filter_hyper_parameters(). NOTE(review): LoRA fields are included —
# presumably so adapter configs survive conversion; confirm downstream loaders
# accept them.
ALLOWED_HYPERPARAMETER_KEYS = (
    "vocab_size",
    "hidden_size",
    "num_hidden_layers",
    "num_attention_heads",
    "intermediate_size",
    "hidden_act",
    "hidden_dropout_prob",
    "attention_probs_dropout_prob",
    "initializer_range",
    "layer_norm_eps",
    "pad_token_id",
    "position_embedding_type",
    "classifier_dropout",
    "rotary_theta",
    "ignore_index",
    "loss_type",
    "lora",
    "lora_alpha",
    "lora_r",
    "lora_dropout",
)

# PyTorch -> TE keymap
PYTORCH_TO_TE_KEYMAP = {
"model.layers.*.pre_attn_layer_norm.weight": "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
Expand Down Expand Up @@ -300,6 +325,11 @@ def convert_state_dict(src: dict, keymap: dict):
return dst_state_dict


def filter_hyper_parameters(hyper_parameters: dict) -> dict:
    """Return a copy of *hyper_parameters* restricted to conversion-compatible keys.

    Keys absent from ``ALLOWED_HYPERPARAMETER_KEYS`` are dropped; insertion
    order of the surviving keys is preserved.
    """
    filtered = {}
    for name, setting in hyper_parameters.items():
        if name in ALLOWED_HYPERPARAMETER_KEYS:
            filtered[name] = setting
    return filtered


def main():
"""Main function."""
logging.basicConfig(level=logging.INFO)
Expand All @@ -325,6 +355,7 @@ def main():
# Load source checkpoint (automatically detects format)
logger.info(f"Loading checkpoint from {args.src}")
src_checkpoint = load_checkpoint(args.src, map_location="cpu")
src_checkpoint["hyper_parameters"] = filter_hyper_parameters(src_checkpoint["hyper_parameters"])

# Perform conversion based on direction
if args.direction == "pytorch2te":
Expand All @@ -341,11 +372,19 @@ def main():
dst_state_dict = split_qkv(converted_state_dict, src_checkpoint["hyper_parameters"])

# Prepare final checkpoint
dst_checkpoint = {"state_dict": dst_state_dict, "hyper_parameters": src_checkpoint["hyper_parameters"]}
dst_checkpoint = {
"state_dict": dst_state_dict,
"hyper_parameters": src_checkpoint["hyper_parameters"],
}

# Save the converted checkpoint in pickled format
torch.save(dst_checkpoint, args.dst)
logger.info(f"Successfully converted checkpoint from {args.src} to {args.dst}")
logger.info(f"Successfully converted checkpoint saved to {args.dst}")

# Save the state_dict in safetensors format alongside the .ckpt file
safetensors_path = os.path.splitext(args.dst)[0] + ".safetensors"
safetensors_save_file(dst_state_dict, safetensors_path)
logger.info(f"Successfully saved safetensors checkpoint to {safetensors_path}")


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


# %%
import argparse
import json
import sys
from pathlib import Path
Expand All @@ -23,41 +24,52 @@
from tqdm import tqdm


sys.path.append("/workspace/codon_fm")
sys.path.append("/workspace/codonfm")
from src.tokenizer import Tokenizer


data_path = Path("/data/ncbi/processed_unfiltered")
tax_ids_to_remove = json.load(open("/data/ncbi/taxids_to_remove.json"))
metadata = json.load(open(data_path / "metadata.json"))
tokenizer = Tokenizer()


groups = set([x["file_name"][:-4] for x in metadata["file_metadata"]]) # noqa: C403
counts = {g: np.zeros(tokenizer.vocab_size) for g in groups}
for fm, cm in tqdm(zip(metadata["file_metadata"], metadata["chunks"]), total=len(metadata["file_metadata"])):
group = fm["file_name"][:-4]
if group in tax_ids_to_remove:
curr_taxids_to_remove = set(tax_ids_to_remove[group])
else:
curr_taxids_to_remove = set()
mmap = np.memmap(
data_path / cm["sequences"]["path"],
dtype=cm["sequences"]["dtype"],
mode="r",
shape=tuple(cm["sequences"]["shape"]),
)
idx_mmap = np.memmap(
data_path / cm["index"]["path"], dtype=cm["index"]["dtype"], mode="r", shape=tuple(cm["index"]["shape"])
)
for start, end, taxid in idx_mmap:
if taxid in curr_taxids_to_remove:
continue
seq = mmap[start:end]
idx, count = np.unique(seq, return_counts=True)
counts[group][idx] += count
def main(pretraining_processed_data_dir: Path, data_dir: Path):
    """Count per-group codon (token) frequencies over the pretraining corpus.

    Args:
        pretraining_processed_data_dir: Directory containing ``metadata.json``
            and the memmapped sequence/index chunk files it references.
        data_dir: Directory containing ``taxids_to_remove.json``; the output
            ``codon_counts_nopathogen.json`` is also written here.

    Writes a JSON mapping of group name -> list of token counts (length
    ``tokenizer.vocab_size``), skipping any sequence whose taxid appears in
    ``taxids_to_remove.json`` for its group.
    """
    # read_text + json.loads closes the files; the original leaked bare open() handles.
    tax_ids_to_remove = json.loads((data_dir / "taxids_to_remove.json").read_text())
    metadata = json.loads((pretraining_processed_data_dir / "metadata.json").read_text())
    tokenizer = Tokenizer()

    # Group name is the chunk file name with its 4-character extension stripped.
    groups = {x["file_name"][:-4] for x in metadata["file_metadata"]}
    counts = {g: np.zeros(tokenizer.vocab_size) for g in groups}
    for fm, cm in tqdm(zip(metadata["file_metadata"], metadata["chunks"]), total=len(metadata["file_metadata"])):
        group = fm["file_name"][:-4]
        # Taxids excluded for this group (e.g. pathogens); empty set when none listed.
        curr_taxids_to_remove = set(tax_ids_to_remove.get(group, ()))
        mmap = np.memmap(
            pretraining_processed_data_dir / cm["sequences"]["path"],
            dtype=cm["sequences"]["dtype"],
            mode="r",
            shape=tuple(cm["sequences"]["shape"]),
        )
        idx_mmap = np.memmap(
            pretraining_processed_data_dir / cm["index"]["path"],
            dtype=cm["index"]["dtype"],
            mode="r",
            shape=tuple(cm["index"]["shape"]),
        )
        # Each index row is (start, end, taxid) into the flat sequence memmap —
        # TODO confirm row layout against the preprocessing writer.
        for start, end, taxid in idx_mmap:
            if taxid in curr_taxids_to_remove:
                continue
            seq = mmap[start:end]
            idx, count = np.unique(seq, return_counts=True)
            counts[group][idx] += count

    # JSON cannot serialize ndarrays; convert to plain lists before dumping.
    serializable = {g: c.tolist() for g, c in counts.items()}
    with open(data_dir / "codon_counts_nopathogen.json", "w") as f:
        json.dump(serializable, f)


if __name__ == "__main__":
    # CLI entry point: both directory arguments are mandatory.
    arg_parser = argparse.ArgumentParser(description="Check codon frequency")
    arg_parser.add_argument("--pretraining_processed_data_dir", type=str, required=True)
    arg_parser.add_argument("--data_dir", type=str, required=True)
    cli_args = arg_parser.parse_args()
    main(Path(cli_args.pretraining_processed_data_dir), Path(cli_args.data_dir))
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
import argparse
import json
import os
import sys
from multiprocessing import Pool, cpu_count

import numpy as np
import polars as pl
import pyarrow.parquet as pq
from tqdm import tqdm


sys.path.append("/workspace/codonfm")
from src.tokenizer import Tokenizer


Expand Down
Loading