-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate.py
More file actions
80 lines (66 loc) · 2.53 KB
/
generate.py
File metadata and controls
80 lines (66 loc) · 2.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os, glob, torch
from datetime import datetime
from transformers import AutoTokenizer, T5ForConditionalGeneration
# Directory that holds the trained checkpoint; the tokenizer and the model
# weights are both read from it when this module is imported.
MODEL_DIR = "./code2doc_model_final"

# Use the GPU when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
model = T5ForConditionalGeneration.from_pretrained(MODEL_DIR)
# Move the weights to the selected device; inputs are moved to the same
# device before generation.
model.to(device)
def generate_doc(code_snippet: str,
                 max_input_length: int = 512,
                 max_output_length: int = 128,
                 num_beams: int = 4):
    """Generate documentation text for a code snippet with the loaded model.

    The snippet is stripped of surrounding whitespace, prefixed with
    ``"summarize: "``, tokenized and passed to ``model.generate``.

    Args:
        code_snippet: Source code to describe.
        max_input_length: ``max_length`` given to the tokenizer; longer
            inputs are truncated.
        max_output_length: ``max_length`` given to ``model.generate``.
        num_beams: Number of beams used for beam search.

    Returns:
        The first generated sequence, decoded without special tokens and
        stripped of surrounding whitespace.
    """
    prompt = f"summarize: {code_snippet.strip()}"

    encoded = tokenizer(
        prompt,
        max_length=max_input_length,
        truncation=True,
        padding="longest",
        return_tensors="pt",
    )
    # The tensors must live on the same device as the model.
    encoded = encoded.to(device)

    generation_options = {
        "input_ids": encoded.input_ids,
        "attention_mask": encoded.attention_mask,
        "max_length": max_output_length,
        "num_beams": num_beams,
        "early_stopping": True,
    }
    sequences = model.generate(**generation_options)

    # Only the first returned sequence is decoded.
    decoded = tokenizer.decode(sequences[0], skip_special_tokens=True)
    return decoded.strip()
def get_lang_tag(filename: str) -> str:
    """Return the Markdown code-fence language tag for *filename*.

    The tag is chosen from the file extension. The extension is compared
    case-insensitively, so ``Main.PY`` gets the same tag as ``main.py``.

    Args:
        filename: File name or path whose extension selects the tag.

    Returns:
        ``"javascript"``, ``"python"``, ``"java"`` or ``"go"`` for the
        supported extensions, and an empty string for anything else
        (including names without an extension).
    """
    tags = {
        '.js': 'javascript',
        '.py': 'python',
        '.java': 'java',
        '.go': 'go',
    }
    # Lower-case the extension so upper- or mixed-case suffixes still match.
    ext = os.path.splitext(filename)[1].lower()
    return tags.get(ext, '')
def _code_fence(code: str) -> str:
    """Return a backtick fence longer than any backtick run inside *code*."""
    longest = run = 0
    for ch in code:
        run = run + 1 if ch == "`" else 0
        if run > longest:
            longest = run
    # A fence needs at least three backticks, and more than the longest run
    # in the content, or the content would close the block early.
    return "`" * max(3, longest + 1)


def batch_generate(input_dir: str, base_output: str = "./output/results"):
    """Write a Markdown documentation file for each source file in a tree.

    Searches ``input_dir`` recursively for ``.js``, ``.py``, ``.java`` and
    ``.go`` files, generates documentation for each with ``generate_doc``
    and writes ``<relative path>.docs.md`` files into a new
    ``test_<timestamp>`` directory under ``base_output``. The output mirrors
    the directory layout of ``input_dir`` so files that share a base name in
    different sub-directories do not overwrite each other.

    Files that cannot be read or decoded as UTF-8, files that are empty or
    whitespace-only, and files longer than 50,000 characters are skipped.

    Args:
        input_dir: Root directory to search for source files.
        base_output: Directory under which the timestamped result directory
            is created.
    """
    # Create timestamped directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(base_output, f"test_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    # File extensions to support
    extensions = ["*.js", "*.py", "*.java", "*.go"]
    files = []
    for ext in extensions:
        files.extend(glob.glob(os.path.join(input_dir, "**", ext), recursive=True))

    # Sorted so every run processes the files in the same order.
    for path in sorted(files):
        try:
            with open(path, "r", encoding="utf-8") as f:
                code = f.read()
        except (OSError, UnicodeDecodeError) as err:
            # One unreadable or non-UTF-8 file must not abort the whole batch.
            print(f"Skipping {path}: {err}")
            continue

        if not code.strip() or len(code) > 50_000:
            continue

        doc = generate_doc(code)

        # Output file setup: keep the path relative to input_dir so that
        # same-named files from different sub-directories stay separate.
        base_name = os.path.basename(path)
        lang_tag = get_lang_tag(base_name)
        rel_path = os.path.relpath(path, input_dir)
        out_path = os.path.join(output_dir, rel_path + ".docs.md")
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

        fence = _code_fence(code)
        with open(out_path, "w", encoding="utf-8") as out:
            out.write(f"# Documentation for `{base_name}`\n\n")
            out.write(f"{fence}{lang_tag}\n{code}\n{fence}\n\n")
            out.write("## Generated Documentation\n\n")
            out.write(doc + "\n")

        print(f"📝 {path} → {out_path}")
if __name__ == "__main__":
    # Script entry point: document every supported file under ./examples,
    # writing results to the default output location.
    batch_generate("./examples")