-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTokenLattice.php
More file actions
122 lines (102 loc) · 3.34 KB
/
TokenLattice.php
File metadata and controls
122 lines (102 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
<?php
declare(strict_types=1);
namespace Codewithkyrian\Tokenizers\DataStructures;
/**
 * A lattice of candidate tokens laid over a sentence.
 *
 * Candidate tokens are registered with {@see insert()} as (position, length, score, id)
 * spans measured in characters; {@see viterbi()} then selects the chain of spans from the
 * beginning-of-sentence node to the end-of-sentence node with the highest summed score.
 */
class TokenLattice
{
    /** @var int Number of characters in the sentence, as reported by mb_strlen(). */
    public int $len;

    /** @var TokenLatticeNode[] Every node in creation order; a node's id is its index here. */
    public array $nodes = [];

    /** @var TokenLatticeNode[][] Nodes grouped by the character position at which they start. */
    public array $beginNodes = [];

    /** @var TokenLatticeNode[][] Nodes grouped by the character position at which they end (pos + length). */
    public array $endNodes = [];

    /**
     * Builds an empty lattice holding only the two sentinel nodes.
     *
     * The BOS sentinel (node id 0) is zero-length and registered as ending at position 0;
     * the EOS sentinel (node id 1) is zero-length and registered as beginning at position
     * `$len`. Both carry a score of 0.0.
     *
     * @param string   $sentence   The text the lattice spans.
     * @param int|null $bosTokenId Token id stored on the beginning-of-sentence sentinel.
     * @param int|null $eosTokenId Token id stored on the end-of-sentence sentinel.
     */
    public function __construct(
        public string $sentence,
        public ?int $bosTokenId,
        public ?int $eosTokenId,
    ) {
        $this->len = mb_strlen($sentence);

        // One bucket per character boundary: 0 .. len inclusive.
        $bucketCount = $this->len + 1;
        $this->beginNodes = array_fill(0, $bucketCount, []);
        $this->endNodes = array_fill(0, $bucketCount, []);

        $bosNode = new TokenLatticeNode($this->bosTokenId, 0, 0, 0, 0.0);
        $eosNode = new TokenLatticeNode($this->eosTokenId, 1, $this->len, 0, 0.0);

        $this->nodes[] = $bosNode;
        $this->nodes[] = $eosNode;

        $this->endNodes[0][] = $bosNode;
        $this->beginNodes[$this->len][] = $eosNode;
    }

    /**
     * Adds a candidate token to the lattice.
     *
     * No range validation is performed on `$pos` or `$length`.
     *
     * @param int   $pos     Character offset at which the token starts.
     * @param int   $length  Number of characters the token covers.
     * @param float $score   Score of the token; scores are summed along a path.
     * @param int   $tokenId Vocabulary id of the token.
     */
    public function insert(int $pos, int $length, float $score, int $tokenId): void
    {
        // The next free index in $nodes doubles as the node id.
        $candidate = new TokenLatticeNode($tokenId, \count($this->nodes), $pos, $length, $score);

        $this->beginNodes[$pos][] = $candidate;
        $this->endNodes[$pos + $length][] = $candidate;
        $this->nodes[] = $candidate;
    }

    /**
     * Finds the highest-scoring path through the lattice.
     *
     * The forward pass visits positions 0..len; for each node starting at a position it
     * picks, among the nodes ending there, the predecessor that maximises
     * `predecessor.backtraceScore + node.score`, and records it in the node's `prev` and
     * `backtraceScore`. The path is then read backwards from the EOS sentinel.
     *
     * @return TokenLatticeNode[] The nodes on the best path in sentence order, without the
     *                            BOS/EOS sentinels; an empty array when some position has no
     *                            node starting at it or no node ending at it.
     */
    public function viterbi(): array
    {
        $lastPos = $this->len;

        for ($pos = 0; $pos <= $lastPos; ++$pos) {
            $starting = $this->beginNodes[$pos] ?? [];
            if ([] === $starting) {
                // Nothing starts here, so no path can cross this position.
                return [];
            }

            foreach ($starting as $current) {
                $current->prev = null;

                $winner = null;
                $winnerScore = 0.0;
                foreach ($this->endNodes[$pos] as $candidate) {
                    $total = $candidate->backtraceScore + $current->score;
                    // Keep the incumbent unless this is the first candidate or a strictly better one.
                    if (null !== $winner && !($total > $winnerScore)) {
                        continue;
                    }
                    $winner = $candidate;
                    $winnerScore = $total;
                }

                if (null === $winner) {
                    // Nothing ends here, so this node has no predecessor.
                    return [];
                }

                $current->prev = $winner;
                $current->backtraceScore = $winnerScore;
            }
        }

        // The EOS sentinel is always the first node registered at the final position.
        $cursor = $this->beginNodes[$lastPos][0]->prev;
        if (null === $cursor) {
            return [];
        }

        // Walk back until the node without a predecessor (the BOS sentinel), which is not collected.
        $path = [];
        while (null !== $cursor->prev) {
            $path[] = $cursor;
            $cursor = $cursor->prev;
        }

        return array_reverse($path);
    }

    /**
     * Returns the slice of the sentence covered by a node.
     *
     * @param TokenLatticeNode $node Node whose `pos` and `length` select the slice.
     *
     * @return string The covered characters, extracted with mb_substr().
     */
    public function piece(TokenLatticeNode $node): string
    {
        return mb_substr($this->sentence, $node->pos, $node->length);
    }

    /**
     * Runs {@see viterbi()} and returns the text of each node on the best path.
     *
     * @return string[]
     */
    public function tokens(): array
    {
        return array_map(
            fn (TokenLatticeNode $node): string => $this->piece($node),
            $this->viterbi(),
        );
    }

    /**
     * Runs {@see viterbi()} and returns the token id of each node on the best path.
     *
     * @return int[]
     */
    public function tokenIds(): array
    {
        $ids = [];
        foreach ($this->viterbi() as $node) {
            $ids[] = $node->tokenId;
        }

        return $ids;
    }
}