Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions packages/filesystem/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,38 @@ describe("AuthVerify", () => {
).resolves.toEqual(["new-access", "new-access", "new-access"]);
expect(fetchMock).toHaveBeenCalledTimes(1);
});

it("concurrent initial token verification should share one auth request and save once", async () => {
vi.useFakeTimers();
const createSpy = vi.spyOn(chrome.tabs, "create").mockImplementation(() => Promise.resolve({ id: 1 }) as any);
const originalGet = (chrome.tabs as any).get;
(chrome.tabs as any).get = vi.fn().mockRejectedValue(new Error("closed"));
const saveSpy = vi.spyOn(LocalStorageDAO.prototype, "saveValue");
const fetchMock = vi.fn().mockResolvedValue({
json: vi.fn().mockResolvedValue({
code: 0,
data: {
token: {
access_token: "initial-access",
refresh_token: "initial-refresh",
},
},
}),
} as unknown as Response);
vi.stubGlobal("fetch", fetchMock);

try {
const auth = Promise.all([AuthVerify("onedrive"), AuthVerify("onedrive"), AuthVerify("onedrive")]);
await vi.advanceTimersByTimeAsync(1000);

await expect(auth).resolves.toEqual(["initial-access", "initial-access", "initial-access"]);
expect(createSpy).toHaveBeenCalledTimes(1);
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(saveSpy).toHaveBeenCalledTimes(1);
expect(saveSpy).toHaveBeenCalledWith(key, expect.objectContaining({ accessToken: "initial-access" }));
} finally {
(chrome.tabs as any).get = originalGet;
vi.useRealTimers();
}
});
});
34 changes: 23 additions & 11 deletions packages/filesystem/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ export type Token = {
createtime: number;
};
const refreshTokenPromises: Partial<Record<NetDiskType, Promise<string>>> = {};
const authTokenPromises: Partial<Record<NetDiskType, Promise<Token>>> = {};

function refreshAccessToken(
netDiskType: NetDiskType,
Expand Down Expand Up @@ -126,19 +127,30 @@ export async function AuthVerify(netDiskType: NetDiskType, invalid?: boolean) {
}
// token不存在,或者没有accessToken,重新获取
if (!token || !token.accessToken) {
// 强制重新获取token
await NetDisk(netDiskType);
const resp = await GetNetDiskToken(netDiskType);
if (resp.code !== 0) {
throw new WarpTokenError(new Error(resp.msg));
if (!authTokenPromises[netDiskType]) {
const authPromise = (async () => {
// 强制重新获取token
await NetDisk(netDiskType);
const resp = await GetNetDiskToken(netDiskType);
if (resp.code !== 0) {
throw new WarpTokenError(new Error(resp.msg));
}
const newToken = {
accessToken: resp.data.token.access_token,
refreshToken: resp.data.token.refresh_token,
createtime: Date.now(),
};
await localStorageDAO.saveValue(key, newToken);
return newToken;
})().finally(() => {
if (authTokenPromises[netDiskType] === authPromise) {
delete authTokenPromises[netDiskType];
}
});
authTokenPromises[netDiskType] = authPromise;
}
token = {
accessToken: resp.data.token.access_token,
refreshToken: resp.data.token.refresh_token,
createtime: Date.now(),
};
token = await authTokenPromises[netDiskType];
invalid = false;
await localStorageDAO.saveValue(key, token);
}
// token未过期(一小时内)及有效则保留,不用刷新token
const unexpired = Date.now() < token.createtime + 3600000;
Expand Down
Loading