Open
Labels
enhancement: New feature or request
Description
Currently, users need to implement character encoding/decoding manually for text generation tasks. This creates boilerplate code and potential inconsistencies.
The current text generation example requires users to:
- Manually implement character-to-embedding conversion
- Handle vocabulary management themselves
- Implement temperature sampling from scratch
Proposed Solution
Add a comprehensive set of text generation utilities:
Core Utilities
pub struct TextVocabulary {
    char_to_idx: HashMap<char, usize>,
    idx_to_char: HashMap<usize, char>,
    vocab_size: usize,
}
impl TextVocabulary {
    pub fn from_text(text: &str) -> Self
    pub fn char_to_index(&self, ch: char) -> Option<usize>
    pub fn index_to_char(&self, idx: usize) -> Option<char>
    pub fn size(&self) -> usize
}
pub struct CharacterEmbedding {
    embedding_matrix: Array2<f64>, // (vocab_size, embed_dim)
}
impl CharacterEmbedding {
    pub fn new(vocab_size: usize, embed_dim: usize) -> Self
    pub fn forward(&self, char_indices: &[usize]) -> Array2<f64>
    pub fn lookup(&self, char_idx: usize) -> Array1<f64>
}
// Sampling utilities
pub fn sample_with_temperature(logits: &Array1<f64>, temperature: f64) -> usize
pub fn sample_top_k(logits: &Array1<f64>, k: usize, temperature: f64) -> usize
pub fn sample_nucleus(logits: &Array1<f64>, p: f64, temperature: f64) -> usize
Benefits
- Reduces boilerplate code for users
- Provides consistent, tested implementations
- Enables multiple sampling strategies
- Simplifies character-level text generation
- Makes the library more user-friendly
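To make the vocabulary part of the proposal concrete, here is a minimal sketch of how `TextVocabulary` could be implemented with the standard library only. It follows the signatures listed under Core Utilities. Assigning indices in sorted character order is an assumption of this sketch (it keeps the mapping deterministic across runs), not something the proposal specifies.

```rust
use std::collections::{BTreeSet, HashMap};

/// Character-level vocabulary mapping characters to indices and back.
pub struct TextVocabulary {
    char_to_idx: HashMap<char, usize>,
    idx_to_char: HashMap<usize, char>,
    vocab_size: usize,
}

impl TextVocabulary {
    /// Builds the vocabulary from the unique characters of `text`.
    /// Assumption: indices follow sorted character order so that the
    /// same text always produces the same mapping.
    pub fn from_text(text: &str) -> Self {
        let unique: BTreeSet<char> = text.chars().collect();
        let mut char_to_idx = HashMap::new();
        let mut idx_to_char = HashMap::new();
        for (idx, ch) in unique.into_iter().enumerate() {
            char_to_idx.insert(ch, idx);
            idx_to_char.insert(idx, ch);
        }
        let vocab_size = char_to_idx.len();
        Self { char_to_idx, idx_to_char, vocab_size }
    }

    /// Returns the index of `ch`, or `None` if it is not in the vocabulary.
    pub fn char_to_index(&self, ch: char) -> Option<usize> {
        self.char_to_idx.get(&ch).copied()
    }

    /// Returns the character stored at `idx`, or `None` if out of range.
    pub fn index_to_char(&self, idx: usize) -> Option<char> {
        self.idx_to_char.get(&idx).copied()
    }

    /// Number of distinct characters in the vocabulary.
    pub fn size(&self) -> usize {
        self.vocab_size
    }
}
```

Returning `Option` from both lookups, as the proposal does, leaves the handling of out-of-vocabulary characters to the caller instead of panicking inside the library.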
Usage Example
// Create the vocabulary and embeddings
let vocab = TextVocabulary::from_text("Hello world!");
let embeddings = CharacterEmbedding::new(vocab.size(), 64);
// Look up the embedding for a single input character
let input_idx = vocab.char_to_index('H').unwrap();
let char_embedding = embeddings.lookup(input_idx);
// ... LSTM forward pass ...
let next_char_idx = sample_with_temperature(&logits, 0.8);
// Map the sampled index back to a character
let next_char = vocab.index_to_char(next_char_idx);
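The three sampling strategies could be sketched as follows. This is an illustration under stated assumptions, not the proposed API: it takes `&[f64]` instead of `&Array1<f64>` so it compiles without `ndarray`, and it takes an extra `u` argument (a uniform draw in `[0, 1)` supplied by the caller) so it needs no random-number crate and stays deterministic in tests. The helpers `softmax_with_temperature` and `sample_from_probs` are hypothetical names introduced here.

```rust
/// Numerically stable softmax over logits divided by `temperature`.
fn softmax_with_temperature(logits: &[f64], temperature: f64) -> Vec<f64> {
    assert!(!logits.is_empty(), "logits must not be empty");
    assert!(temperature > 0.0, "temperature must be positive");
    let scaled: Vec<f64> = logits.iter().map(|&l| l / temperature).collect();
    // Subtract the maximum before exponentiating to avoid overflow.
    let max = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = scaled.iter().map(|&s| (s - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Inverse-CDF draw: first index whose cumulative probability exceeds `u`.
fn sample_from_probs(probs: &[f64], u: f64) -> usize {
    let mut cumulative = 0.0;
    for (idx, &p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return idx;
        }
    }
    probs.len() - 1 // guard against floating-point rounding
}

/// Assumption: `u` is a uniform sample in [0, 1) provided by the caller.
pub fn sample_with_temperature(logits: &[f64], temperature: f64, u: f64) -> usize {
    sample_from_probs(&softmax_with_temperature(logits, temperature), u)
}

/// Keeps the `k` largest logits, then samples among them.
pub fn sample_top_k(logits: &[f64], k: usize, temperature: f64, u: f64) -> usize {
    assert!(!logits.is_empty(), "logits must not be empty");
    let k = k.clamp(1, logits.len());
    // Indices sorted by descending logit.
    let mut order: Vec<usize> = (0..logits.len()).collect();
    order.sort_by(|&a, &b| logits[b].total_cmp(&logits[a]));
    order.truncate(k);
    let kept: Vec<f64> = order.iter().map(|&i| logits[i]).collect();
    let probs = softmax_with_temperature(&kept, temperature);
    order[sample_from_probs(&probs, u)]
}

/// Keeps the smallest set of tokens whose probability mass reaches `p`,
/// renormalizes within that set, then samples.
pub fn sample_nucleus(logits: &[f64], p: f64, temperature: f64, u: f64) -> usize {
    let probs = softmax_with_temperature(logits, temperature);
    let mut order: Vec<usize> = (0..probs.len()).collect();
    order.sort_by(|&a, &b| probs[b].total_cmp(&probs[a]));
    let mut mass = 0.0;
    let mut cutoff = 0;
    for &i in &order {
        mass += probs[i];
        cutoff += 1;
        if mass >= p {
            break;
        }
    }
    order.truncate(cutoff);
    let kept: Vec<f64> = order.iter().map(|&i| probs[i] / mass).collect();
    order[sample_from_probs(&kept, u)]
}
```

Lower temperatures concentrate probability on the largest logit, so sampling approaches greedy decoding. Top-k and nucleus sampling both drop the low-probability tail before drawing. A real implementation would more likely own its random source (for example via the `rand` crate) than accept `u` from the caller.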