Skip to content

Commit 0ff08cb

Browse files
committed
feat: downloader with CDN redirect support and retry handling
1 parent 36ad1b1 commit 0ff08cb

39 files changed

Lines changed: 6541 additions & 182 deletions

telegram/cdn.go

Lines changed: 264 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,291 @@
11
package telegram
22

33
import (
4+
"context"
45
"crypto/rsa"
56
"encoding/pem"
67

78
"github.com/go-faster/errors"
89

910
"github.com/gotd/td/crypto"
11+
"github.com/gotd/td/exchange"
1012
"github.com/gotd/td/tg"
1113
)
1214

13-
func parseCDNKeys(keys ...tg.CDNPublicKey) ([]*rsa.PublicKey, error) {
14-
r := make([]*rsa.PublicKey, 0, len(keys))
15+
// Single key because help.getCdnConfig has no request params.
16+
const helpGetCDNConfigSingleflightKey = "help.getCdnConfig"
17+
18+
type cdnKeyEntry struct {
19+
dcID int
20+
key *rsa.PublicKey
21+
}
22+
23+
type fetchedCDNKeys struct {
24+
all []exchange.PublicKey
25+
byDC map[int][]exchange.PublicKey
26+
}
27+
28+
func clonePublicKeys(keys []exchange.PublicKey) []exchange.PublicKey {
29+
return append([]exchange.PublicKey(nil), keys...)
30+
}
31+
32+
func mergePublicKeys(primary, fallback []exchange.PublicKey) []exchange.PublicKey {
33+
if len(primary) == 0 && len(fallback) == 0 {
34+
return nil
35+
}
36+
37+
out := make([]exchange.PublicKey, 0, len(primary)+len(fallback))
38+
seen := make(map[int64]struct{}, len(primary)+len(fallback))
39+
appendUnique := func(keys []exchange.PublicKey) {
40+
for _, key := range keys {
41+
fp := key.Fingerprint()
42+
if _, ok := seen[fp]; ok {
43+
continue
44+
}
45+
seen[fp] = struct{}{}
46+
out = append(out, key)
47+
}
48+
}
49+
50+
// Prefer primary keyset order and use fallback only for missing fingerprints.
51+
appendUnique(primary)
52+
appendUnique(fallback)
53+
return out
54+
}
55+
56+
func parseCDNKeyEntries(keys ...tg.CDNPublicKey) ([]cdnKeyEntry, error) {
57+
r := make([]cdnKeyEntry, 0, len(keys))
1558

1659
for _, key := range keys {
1760
block, _ := pem.Decode([]byte(key.PublicKey))
1861
if block == nil {
1962
continue
2063
}
2164

22-
key, err := crypto.ParseRSA(block.Bytes)
65+
parsedKey, err := crypto.ParseRSA(block.Bytes)
2366
if err != nil {
2467
return nil, errors.Wrap(err, "parse RSA from PEM")
2568
}
2669

27-
r = append(r, key)
70+
r = append(r, cdnKeyEntry{
71+
dcID: key.DCID,
72+
key: parsedKey,
73+
})
2874
}
2975

3076
return r, nil
3177
}
78+
79+
func buildCDNKeysCache(entries []cdnKeyEntry) fetchedCDNKeys {
80+
result := fetchedCDNKeys{
81+
all: make([]exchange.PublicKey, 0, len(entries)),
82+
byDC: make(map[int][]exchange.PublicKey),
83+
}
84+
85+
seenAll := make(map[int64]struct{}, len(entries))
86+
seenByDC := make(map[int]map[int64]struct{})
87+
88+
for _, entry := range entries {
89+
key := exchange.PublicKey{RSA: entry.key}
90+
fingerprint := key.Fingerprint()
91+
92+
if _, ok := seenAll[fingerprint]; !ok {
93+
seenAll[fingerprint] = struct{}{}
94+
result.all = append(result.all, key)
95+
}
96+
97+
seen, ok := seenByDC[entry.dcID]
98+
if !ok {
99+
seen = map[int64]struct{}{}
100+
seenByDC[entry.dcID] = seen
101+
}
102+
if _, ok := seen[fingerprint]; ok {
103+
continue
104+
}
105+
seen[fingerprint] = struct{}{}
106+
result.byDC[entry.dcID] = append(result.byDC[entry.dcID], key)
107+
}
108+
109+
return result
110+
}
111+
112+
func copyCDNKeysByDC(byDC map[int][]exchange.PublicKey) map[int][]exchange.PublicKey {
113+
if len(byDC) == 0 {
114+
return nil
115+
}
116+
117+
r := make(map[int][]exchange.PublicKey, len(byDC))
118+
for dcID, keys := range byDC {
119+
r[dcID] = append([]exchange.PublicKey(nil), keys...)
120+
}
121+
return r
122+
}
123+
124+
func cloneFetchedCDNKeys(keys fetchedCDNKeys) fetchedCDNKeys {
125+
return fetchedCDNKeys{
126+
all: clonePublicKeys(keys.all),
127+
byDC: copyCDNKeysByDC(keys.byDC),
128+
}
129+
}
130+
131+
func (c *Client) cachedCDNKeys() ([]exchange.PublicKey, bool, uint64) {
132+
c.cdnKeysMux.Lock()
133+
defer c.cdnKeysMux.Unlock()
134+
135+
return clonePublicKeys(c.cdnKeys), c.cdnKeysSet, c.cdnKeysGen
136+
}
137+
138+
func (c *Client) cachedCDNKeysForDC(dcID int) ([]exchange.PublicKey, bool) {
139+
c.cdnKeysMux.Lock()
140+
defer c.cdnKeysMux.Unlock()
141+
142+
return clonePublicKeys(c.cdnKeysByDC[dcID]), c.cdnKeysSet
143+
}
144+
145+
func (c *Client) cdnConfigFetchContext(caller context.Context) context.Context {
146+
if c.ctx != nil {
147+
// Bind network request lifetime to client lifecycle, not to the first
148+
// singleflight caller.
149+
return c.ctx
150+
}
151+
152+
// Caller cancellation is handled outside singleflight wait loop; request
153+
// itself should not inherit first caller deadline/cancellation.
154+
return context.WithoutCancel(caller)
155+
}
156+
157+
func (c *Client) loadCDNKeys(ctx context.Context) (fetchedCDNKeys, error) {
158+
resultCh := c.cdnKeysLoad.DoChan(helpGetCDNConfigSingleflightKey, func() (interface{}, error) {
159+
// singleflight ensures only one goroutine issues help.getCdnConfig;
160+
// others wait and reuse same result.
161+
cfg, err := c.tg.HelpGetCDNConfig(c.cdnConfigFetchContext(ctx))
162+
if err != nil {
163+
return nil, errors.Wrap(err, "help.getCdnConfig")
164+
}
165+
166+
entries, err := parseCDNKeyEntries(cfg.PublicKeys...)
167+
if err != nil {
168+
return nil, errors.Wrap(err, "parse CDN public keys")
169+
}
170+
return buildCDNKeysCache(entries), nil
171+
})
172+
173+
select {
174+
case <-ctx.Done():
175+
return fetchedCDNKeys{}, ctx.Err()
176+
case result := <-resultCh:
177+
if result.Err != nil {
178+
return fetchedCDNKeys{}, result.Err
179+
}
180+
181+
keys, ok := result.Val.(fetchedCDNKeys)
182+
if !ok {
183+
return fetchedCDNKeys{}, errors.Errorf("unexpected CDN keys type %T", result.Val)
184+
}
185+
return cloneFetchedCDNKeys(keys), nil
186+
}
187+
}
188+
189+
func (c *Client) fetchCDNKeys(ctx context.Context) ([]exchange.PublicKey, error) {
190+
const maxVersionRetries = 3
191+
for attempt := 0; attempt < maxVersionRetries; attempt++ {
192+
// Fast path: fully cached, no network requests.
193+
cached, set, startGen := c.cachedCDNKeys()
194+
if set {
195+
return cached, nil
196+
}
197+
// Snapshot generation to detect invalidation races after in-flight load.
198+
199+
keys, err := c.loadCDNKeys(ctx)
200+
if err != nil {
201+
return nil, err
202+
}
203+
204+
c.cdnKeysMux.Lock()
205+
switch {
206+
case c.cdnKeysSet:
207+
// Another goroutine already populated cache while we were waiting.
208+
cached := clonePublicKeys(c.cdnKeys)
209+
c.cdnKeysMux.Unlock()
210+
return cached, nil
211+
case c.cdnKeysGen != startGen:
212+
// Cache was invalidated (fingerprint miss) during in-flight request.
213+
// Discard stale result and retry from fresh generation.
214+
c.cdnKeysMux.Unlock()
215+
continue
216+
default:
217+
// Safe to commit fetched keys into cache.
218+
c.cdnKeys = clonePublicKeys(keys.all)
219+
c.cdnKeysByDC = copyCDNKeysByDC(keys.byDC)
220+
c.cdnKeysSet = true
221+
cached := clonePublicKeys(c.cdnKeys)
222+
c.cdnKeysMux.Unlock()
223+
return cached, nil
224+
}
225+
}
226+
227+
return nil, errors.New("cdn keys cache changed concurrently")
228+
}
229+
230+
func (c *Client) refreshCDNKeys(ctx context.Context) ([]exchange.PublicKey, error) {
231+
const maxVersionRetries = 3
232+
for attempt := 0; attempt < maxVersionRetries; attempt++ {
233+
c.cdnKeysMux.Lock()
234+
startGen := c.cdnKeysGen
235+
c.cdnKeysMux.Unlock()
236+
237+
keys, err := c.loadCDNKeys(ctx)
238+
if err != nil {
239+
return nil, err
240+
}
241+
242+
c.cdnKeysMux.Lock()
243+
if c.cdnKeysGen != startGen {
244+
// Fingerprint invalidation happened while refresh was in-flight.
245+
// Discard stale result and refetch for fresh generation.
246+
c.cdnKeysMux.Unlock()
247+
continue
248+
}
249+
c.cdnKeys = clonePublicKeys(keys.all)
250+
c.cdnKeysByDC = copyCDNKeysByDC(keys.byDC)
251+
c.cdnKeysSet = true
252+
cached := clonePublicKeys(c.cdnKeys)
253+
c.cdnKeysMux.Unlock()
254+
255+
return cached, nil
256+
}
257+
258+
return nil, errors.New("cdn keys cache changed concurrently")
259+
}
260+
261+
func (c *Client) fetchCDNKeysForDC(ctx context.Context, dcID int) ([]exchange.PublicKey, error) {
262+
keys, set := c.cachedCDNKeysForDC(dcID)
263+
if !set {
264+
if _, err := c.fetchCDNKeys(ctx); err != nil {
265+
return nil, err
266+
}
267+
}
268+
269+
const maxRefreshAttempts = 3
270+
for attempt := 0; attempt < maxRefreshAttempts; attempt++ {
271+
if err := ctx.Err(); err != nil {
272+
return nil, err
273+
}
274+
275+
keys, _ = c.cachedCDNKeysForDC(dcID)
276+
if len(keys) > 0 {
277+
return keys, nil
278+
}
279+
if attempt == maxRefreshAttempts-1 {
280+
break
281+
}
282+
283+
// Requested CDN DC is missing in current snapshot; retry bounded
284+
// help.getCdnConfig refreshes to handle eventual config propagation.
285+
if _, err := c.refreshCDNKeys(ctx); err != nil {
286+
return nil, err
287+
}
288+
}
289+
290+
return nil, errors.Errorf("no CDN public keys for CDN DC %d after %d refresh attempts", dcID, maxRefreshAttempts)
291+
}

telegram/cdn_conn_dead.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package telegram
2+
3+
import (
4+
"github.com/go-faster/errors"
5+
"go.uber.org/zap"
6+
7+
"github.com/gotd/td/exchange"
8+
"github.com/gotd/td/mtproto"
9+
"github.com/gotd/td/pool"
10+
)
11+
12+
func (c *Client) handleCDNConnDead(dcID int, err error) {
13+
if errors.Is(err, exchange.ErrKeyFingerprintNotFound) {
14+
c.log.Warn("Resetting cached CDN keys after fingerprint miss",
15+
zap.Int("dc_id", dcID),
16+
)
17+
c.cdnKeysMux.Lock()
18+
c.cdnKeys = nil
19+
c.cdnKeysByDC = nil
20+
c.cdnKeysSet = false
21+
// Bump generation so in-flight help.getCdnConfig results are discarded if
22+
// they were started before invalidation.
23+
c.cdnKeysGen++
24+
c.cdnKeysMux.Unlock()
25+
// Drop current singleflight entry so next attempt refetches keys.
26+
c.cdnKeysLoad.Forget(helpGetCDNConfigSingleflightKey)
27+
28+
// Close asynchronously: callback may be invoked from pool worker
29+
// goroutine, and synchronous self-close can deadlock on Wait().
30+
// Queue closes through bounded workers to avoid unbounded goroutine fan-out.
31+
c.cdnPools.invalidateDC(dcID)
32+
// Fingerprint miss is recoverable and handled internally by invalidation
33+
// + reconnect with fresh keys, no external onDead signal is needed.
34+
return
35+
}
36+
37+
if !errors.Is(err, mtproto.ErrPFSDropKeysRequired) {
38+
// Keep legacy callback semantics for all non-PFS errors.
39+
if c.onDead != nil {
40+
c.onDead(err)
41+
}
42+
return
43+
}
44+
45+
c.log.Warn("Dropping stored CDN session key after PFS key reset request",
46+
zap.Int("dc_id", dcID),
47+
)
48+
c.sessionsMux.Lock()
49+
s, ok := c.cdnSessions[dcID]
50+
if !ok {
51+
s = pool.NewSyncSession(pool.Session{DC: dcID})
52+
c.cdnSessions[dcID] = s
53+
}
54+
s.Store(pool.Session{DC: dcID})
55+
c.sessionsMux.Unlock()
56+
57+
if c.onDead != nil {
58+
c.onDead(err)
59+
}
60+
}

0 commit comments

Comments
 (0)