-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathadvanced_nlu.py
More file actions
48 lines (40 loc) · 2.11 KB
/
advanced_nlu.py
File metadata and controls
48 lines (40 loc) · 2.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
class AdvancedNLU:
def __init__(self, model_name="facebook/bart-large-mnli"):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
def understand(self, text: str, context: List[Dict[str, str]] = None) -> Dict[str, Any]:
# Prepare input
if context:
context_text = " ".join([f"{item['role']}: {item['content']}" for item in context])
input_text = f"{context_text}\nHuman: {text}\nAI:"
else:
input_text = f"Human: {text}\nAI:"
# Tokenize and generate
inputs = self.tokenizer(input_text, return_tensors="pt", max_length=1024, truncation=True).to(self.device)
outputs = self.model.generate(**inputs, max_length=100, num_return_sequences=1, do_sample=True)
# Decode and parse the output
decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# Here, implement more sophisticated parsing based on the expected structure of the output
# For this example, we'll use a simple approach
intent, entities = self._parse_output(decoded_output)
return {
"intent": intent,
"entities": entities,
"raw_output": decoded_output
}
def _parse_output(self, output: str) -> Tuple[str, List[Dict[str, str]]]:
# Implement more sophisticated parsing here
# This is a placeholder implementation
parts = output.split(", ")
intent = parts[0] if parts else "unknown"
entities = [{"type": "entity", "value": entity} for entity in parts[1:]]
return intent, entities
# Usage
nlu = AdvancedNLU()
result = nlu.understand("Book a flight to New York for tomorrow",
context=[{"role": "Human", "content": "I need to travel soon."},
{"role": "AI", "content": "Certainly! Where would you like to go?"}])
print(result)