Merge branch 'renovate/all-minor-patch' into 'main'
[ProtonMail-WebClient.git] / packages / llm / lib / downloader.ts
blob3d2bf181f4a405ca129127c8ae509aff8fbf42a5
1 import { CryptoProxy } from '@proton/crypto';
2 import { arrayToHexString } from '@proton/crypto/lib/utils';
3 import { FAILED_TO_DOWNLOAD } from '@proton/llm/lib/constants';
4 import { postMessageIframeToParent } from '@proton/llm/lib/helpers';
5 import * as _ndarrayCache from '@proton/llm/resources/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ndarray-cache.json';
6 import throttle from '@proton/utils/throttle';
8 import type { AssistantConfig, AssistantConfigModel } from './types';
9 import { AssistantEvent } from './types';
11 type NdarrayCache = {
12     metadata: {
13         ParamSize: number;
14         ParamBytes: number;
15         BitsPerParam: number;
16     };
17     records: {
18         dataPath: string;
19         format: string;
20         nbytes: number;
21         records: {
22             name: string;
23             shape: number[];
24             dtype: string;
25             format: string;
26             nbytes: number;
27             byteOffset: number;
28         }[];
29         md5sum: string;
30     }[];
33 const ndarrayCache = _ndarrayCache as NdarrayCache;
35 export enum CacheId {
36     WASM = 'webllm/wasm',
37     MODEL = 'webllm/model',
38     CONFIG = 'webllm/config',
41 export type LlmFile = {
42     downloadUrl: string;
43     cacheUrl: string;
44     cacheKey: string;
45     cacheId: CacheId;
46     expectedMd5?: string;
47     expectedSize?: number;
50 export type DownloadResult = {
51     headers: string;
52     chunks: Uint8Array<ArrayBufferLike>[];
53     statusText: string;
54     status: number;
57 export type AppCaches = { [k in CacheId]: Cache };
59 // A function to monitor the progress of a single file.
60 type OneFileProgressCallback = (url: string, received: number, total: number) => void;
62 /**
63  * Functions used on the iframe app to
64  * - Download the files that we need to use the model
65  * - Send download events to the parent app
66  */
68 // Initiate a download, monitors the progress, and returns the result when finished.
69 async function downloadFile(
70     downloadUrl: string,
71     callback: OneFileProgressCallback,
72     abortController: AbortController
73 ): Promise<DownloadResult> {
74     const signal = abortController.signal;
75     const response = await fetch(downloadUrl, { signal });
76     const { status, statusText, ok } = response;
77     if (!ok) {
78         throw Error(`${downloadUrl}: ${FAILED_TO_DOWNLOAD} ${status} ${statusText}`);
79     }
80     const headers = new Headers(response.headers);
81     const reader = response.body?.getReader();
82     const contentLength = +response.headers.get('Content-Length')!;
83     let receivedLength = 0;
84     const chunks: Uint8Array<ArrayBufferLike>[] = [];
86     // Debounce the progress callback call to avoid sending too many events to the iframe parent
87     const debouncedCallback = throttle((callback) => {
88         callback();
89     }, 200);
91     while (true) {
92         const { done, value } = await reader!.read();
93         if (done) {
94             break;
95         }
96         chunks.push(value);
97         receivedLength += value.length;
98         const totalReceivedLength = receivedLength;
99         debouncedCallback(() => callback(downloadUrl, totalReceivedLength, contentLength));
100     }
101     let headersMap = Object.fromEntries(headers);
102     const serializedHeaders = JSON.stringify(headersMap);
103     return { status, statusText, headers: serializedHeaders, chunks };
106 // Compute the file url using the model_url so that we cache the file correctly in the parent app
107 async function downloadFilesSequentially(
108     files: LlmFile[],
109     callback: OneFileProgressCallback,
110     abortController: AbortController,
111     filesToIgnore: LlmFile[],
112     variantConfig: AssistantConfigModel,
113     parentURL: string
114 ) {
115     let filesDownloaded = 0;
116     const totalFilesToDownload = files.length - filesToIgnore.length;
117     for (let i = 0; i < files.length; i++) {
118         if (abortController.signal.aborted) return;
120         const { downloadUrl, cacheUrl, cacheId, expectedMd5 } = files[i];
122         // Do not download files that have already been downloaded and cached
123         let ignoreThisFile = filesToIgnore.some((f) => f.downloadUrl === downloadUrl);
124         if (!ignoreThisFile) {
125             // Start the download for a new file.
126             const downloadResult = await downloadFile(downloadUrl, callback, abortController);
127             filesDownloaded++;
129             postMessageIframeToParent(
130                 {
131                     type: AssistantEvent.DOWNLOAD_DATA,
132                     payload: {
133                         downloadResult: {
134                             headers: downloadResult.headers,
135                             chunks: downloadResult.chunks,
136                             status: downloadResult.status,
137                             statusText: downloadResult.statusText,
138                         },
139                         // use the parent model url so that we put the right element in cache
140                         cacheId,
141                         cacheUrl,
142                         expectedMd5,
143                         terminate: filesDownloaded === totalFilesToDownload,
144                     },
145                 },
146                 parentURL,
147                 downloadResult.chunks.map((chunk) => chunk.buffer)
148             );
149         }
150     }
153 // Prepare the list of all the files we need to download
154 function listFilesToDownload(variantConfig: AssistantConfigModel): LlmFile[] {
155     // From the iframe, we are downloading files using the model_download_url
156     // Then, before sending an event to the parent app,
157     // we will update the url so that we use the model_url domain instead for the file caching
158     const baseKey = new URL(variantConfig.model_download_url).pathname;
159     const baseDownloadUrl = variantConfig.model_download_url;
160     const baseCacheUrl = variantConfig.model_url;
162     // Since we don't have access to the app using the model from the cache,
163     // put the destination cache identifiers
164     const files: LlmFile[] = [];
165     files.push({
166         // "webllm/model" -> ".../mlc-chat-config.json"
167         downloadUrl: new URL('mlc-chat-config.json', baseDownloadUrl).href,
168         cacheUrl: new URL('mlc-chat-config.json', baseCacheUrl).href,
169         cacheKey: `${baseKey}mlc-chat-config.json`,
170         cacheId: CacheId.CONFIG,
171     });
173     files.push({
174         // "webllm/model" -> ".../tokenizer.json"
175         downloadUrl: new URL('tokenizer.json', baseDownloadUrl).href,
176         cacheUrl: new URL('tokenizer.json', baseCacheUrl).href,
177         cacheKey: `${baseKey}tokenizer.json`,
178         cacheId: CacheId.MODEL,
179     });
181     files.push(
182         ...ndarrayCache.records.map((record) => ({
183             // "webllm/model" -> ".../params_shard_*.bin"
184             downloadUrl: new URL(record.dataPath, baseDownloadUrl).href,
185             cacheUrl: new URL(record.dataPath, baseCacheUrl).href,
186             cacheKey: `${baseKey}${record.dataPath}`,
187             cacheId: CacheId.MODEL,
188             expectedMd5: record.md5sum,
189             expectedSize: record.nbytes,
190         }))
191     );
192     return files;
195 // Retrieves all the files that we need to use the model
196 export async function downloadModel(
197     variant: string,
198     assistantConfig: AssistantConfig,
199     abortController: AbortController,
200     filesToIgnore: LlmFile[],
201     parentURL: string
202 ) {
203     // Grab the entry for our chosen model inside mlc-config.
204     const variantConfig = assistantConfig.model_list.find((m) => m.model_id === variant);
205     if (variantConfig === undefined) {
206         console.error(`Model not found in MLC config: ${variant}`);
207         throw Error(`Model not found in MLC config: ${variant}`);
208     }
210     // Prepare a list of files to download
211     const files = listFilesToDownload(variantConfig);
213     // This first map tracks how many bytes we need to download.
214     //   { url: expectedSize }.
215     // Thanks to it, we can compute the total overall size to download by summing the values.
216     // For most files, especially for the model weights, the size is specified in a meta file, so we know it upfront.
217     // Unfortunately, some small files (like wasm and tokenizer.json) have an unknown size. However, we start to fetch
218     // it, Content-Length should tell us the real size. Therefore, this map will be modified a few times, as we fetch
219     // some of these files that have an initially unknown size.
220     const expectedSizes: Map<string, number> = new Map();
221     for (const f of files) {
222         if (f.expectedSize) {
223             expectedSizes.set(f.downloadUrl, f.expectedSize!);
224         }
225     }
226     const overallExpectedSize = () => [...expectedSizes.values()].reduce((acc, n) => acc + n, 0);
228     // This second map tracks how many bytes we have received for each file.
229     //   { url: receivedSize }
230     // The purpose is to track the overall bytes we've downloaded so far. We compute this by summing the values too.
231     // Consequently, it will be frequently updated, namely each time we receive a new chunk of data.
232     const receivedSizes: Map<string, number> = new Map();
233     for (const f of files) {
234         let ignoreThisFile = filesToIgnore.some((ignored) => ignored.downloadUrl === f.downloadUrl);
235         if (ignoreThisFile) {
236             receivedSizes.set(f.downloadUrl, f.expectedSize || 0);
237         } else {
238             receivedSizes.set(f.downloadUrl, 0);
239         }
240     }
241     const overallReceived = () => [...receivedSizes.values()].reduce((acc, n) => acc + n, 0);
243     const nFinishedFiles = () =>
244         files.filter((f) => {
245             let r = receivedSizes.get(f.downloadUrl);
246             let e = expectedSizes.get(f.downloadUrl);
247             return r !== undefined && e !== undefined && r >= e;
248         }).length;
250     // Start downloading files
251     const updateProgressOneFile = (downloadUrl: string, received: number, total: number) => {
252         if (!expectedSizes.has(downloadUrl) && total > 0) {
253             expectedSizes.set(downloadUrl, total);
254         }
255         const receivedCapped = Math.min(received, total);
256         receivedSizes.set(downloadUrl, receivedCapped);
257         const r = overallReceived();
258         const e = overallExpectedSize();
260         // Send a message to the parent app so that we can update the download progress
261         postMessageIframeToParent(
262             {
263                 type: AssistantEvent.DOWNLOAD_PROGRESS,
264                 payload: {
265                     progress: {
266                         receivedBytes: r,
267                         estimatedTotalBytes: e,
268                         receivedFiles: nFinishedFiles(),
269                         totalFiles: files.length,
270                     },
271                 },
272             },
273             parentURL
274         );
275     };
276     await downloadFilesSequentially(
277         files,
278         updateProgressOneFile,
279         abortController,
280         filesToIgnore,
281         variantConfig,
282         parentURL
283     );
287  * Functions used on the parent app to
288  * - Check which files of the model have already been downloaded
289  * - Store files in the cache
290  */
292 async function computeMd5(data: Uint8Array) {
293     return arrayToHexString(await CryptoProxy.computeHash({ data, algorithm: 'unsafeMD5' }));
296 function buildFakeResponseForCache(data: string, filename: string, origin: string) {
297     const blob = new Blob([data], { type: 'text/plain;charset=utf-8' });
298     const headers = new Headers({
299         'Accept-Ranges': 'bytes',
300         'Access-Control-Allow-Origin': origin,
301         'Access-Control-Expose-Headers': 'Accept-Ranges,Content-Range',
302         'Content-Disposition': `inline; filename*=UTF-8''${filename}; filename="${filename}"`,
303         'Content-Length': blob.size.toString(),
304         'Content-Security-Policy': 'default-src none; sandbox',
305         'Content-Type': 'text/plain; charset=utf-8',
306         'Cross-Origin-Opener-Policy': 'same-origin',
307         Date: new Date().toUTCString(),
308         'Referrer-Policy': 'strict-origin-when-cross-origin',
309         Vary: 'Origin',
310     });
311     const response = new Response(blob, {
312         status: 200,
313         statusText: 'OK',
314         headers: headers,
315     });
316     return response;
319 // Store a piece of data that we don't actually need to download
320 // because we already have it in a JavaScript object right here.
321 // This creates a fake response that looks like the server sent
322 // it, and stores it in the cache.
323 async function storeLocalDataInCache(object: any, filename: string, cacheId: Cache, origin: string, baseKey: string) {
324     const data = JSON.stringify(object);
325     const response = buildFakeResponseForCache(data, filename, origin);
326     await cacheId.put(`${baseKey}${filename}`, response);
329 // Check if a given url is already present in the cache, and optionally verify its MD5 checksum.
330 async function existsInCache(
331     cacheUrl: string,
332     cache: Cache,
333     expects?: { md5?: string; size?: number }
334 ): Promise<boolean> {
335     let cachedResponse = await cache.match(cacheUrl);
336     if (!cachedResponse) {
337         return false;
338     }
339     const expectedSize = expects?.size;
340     const expectedMd5 = expects?.md5;
342     const needsCheck = expectedSize !== undefined || expectedMd5 !== undefined;
343     if (!needsCheck) {
344         return true;
345     }
347     const arrayBuffer = await cachedResponse.arrayBuffer();
349     if (expectedSize !== undefined) {
350         return arrayBuffer.byteLength === expectedSize;
351     }
353     if (expectedMd5 !== undefined) {
354         // disabling for now; it takes too long to compute checksum for cached files
355         const BYPASS_CHECKSUM_FOR_CACHED_FILES = true;
356         if (BYPASS_CHECKSUM_FOR_CACHED_FILES) {
357             return true;
358         }
360         const data = new Uint8Array(arrayBuffer);
362         const actualMd5 = await computeMd5(data);
363         if (actualMd5 === expectedMd5) {
364             return true;
365         }
366         return false;
367     }
369     return true;
372 // Put files that we don't need to download from the iframe in the cache
373 async function cacheParentAppFiles(appCaches: AppCaches, variantConfig: AssistantConfigModel) {
374     // Cache files we already have statically and don't need to download. We pretend we have downloaded it, but
375     // in fact we just create a fake response to a nonexistent request, pretend the server sent it to us, and store
376     // it in the cache.
377     const origin = window.location.origin;
378     const baseKey = new URL(variantConfig.model_url).pathname;
379     // - "webllm/model" -> ".../ndarray-cache.json"
380     await storeLocalDataInCache(ndarrayCache, 'ndarray-cache.json', appCaches[CacheId.MODEL], origin, baseKey);
382     // Cache files that we stored in the app assets
383     // "webllm/wasm" -> ".../file.wasm"
384     const wasmUrl = variantConfig.model_lib_url; // 'https://mail.proton.me/.../file.wasm'
385     const isWasmInCache = await existsInCache(wasmUrl, appCaches[CacheId.WASM]);
386     if (!isWasmInCache) {
387         const wasmResponse = await fetch(wasmUrl);
388         await appCaches[CacheId.WASM].put(wasmUrl, wasmResponse);
389     }
392 // Search for files that we have already downloaded and stored in the cache
393 export async function getCachedFiles(variant: string, assistantConfig: AssistantConfig) {
394     const filesAlreadyDownloaded: LlmFile[] = [];
396     // Open caches
397     const appCaches: AppCaches = {
398         [CacheId.MODEL]: await caches.open(CacheId.MODEL),
399         [CacheId.WASM]: await caches.open(CacheId.WASM),
400         [CacheId.CONFIG]: await caches.open(CacheId.CONFIG),
401     };
403     // Grab the entry for our chosen model inside mlc-config.
404     const variantConfig = assistantConfig.model_list.find((m) => m.model_id === variant);
405     if (variantConfig === undefined) {
406         console.error(`Model not found in MLC config: ${variant}`);
407         throw Error(`Model not found in MLC config: ${variant}`);
408     }
410     // Put files that we don't need to download from the iframe in the cache
411     await cacheParentAppFiles(appCaches, variantConfig);
413     // Prepare a list of files that we need to run the model
414     const files = listFilesToDownload(variantConfig);
416     // Check which files are present in the cache
417     for (const f of files) {
418         const { cacheUrl, cacheId, expectedSize } = f;
419         const exists = await existsInCache(cacheUrl, appCaches[cacheId], { size: expectedSize });
421         if (exists) {
422             filesAlreadyDownloaded.push(f);
423         }
424     }
426     const needsAdditionalDownload = filesAlreadyDownloaded.length !== files.length;
428     return { filesAlreadyDownloaded, needsAdditionalDownload, appCaches };
432  * Clears the cache related to the AI Assistant
433  * @returns A promise that resolves when the cache is cleared
434  */
435 export async function deleteAssistantCachedFiles() {
436     return Promise.all([caches.delete(CacheId.MODEL), caches.delete(CacheId.WASM), caches.delete(CacheId.CONFIG)]);
439 // Store a fully downloaded file into the cache.
440 export async function storeInCache(
441     downloadResult: DownloadResult,
442     cacheUrl: string,
443     cache: Cache,
444     expectedMd5: string | undefined
445 ) {
446     const { status, statusText, headers, chunks } = downloadResult;
447     const blob = new Blob(chunks);
448     const arrayBuffer = await blob.arrayBuffer();
449     const data = new Uint8Array(arrayBuffer);
450     const actualMd5 = await computeMd5(data);
452     if (!expectedMd5 || actualMd5 === expectedMd5) {
453         try {
454             const parsedHeaders = JSON.parse(headers);
455             await cache.put(cacheUrl, new Response(blob, { status, statusText, headers: parsedHeaders }));
456         } catch (e) {
457             throw new Error(`${cacheUrl}: error while storing in cache: ${e}`);
458         }
459     } else {
460         throw new Error(`${cacheUrl}: checksum failed, expected ${expectedMd5}, got ${actualMd5}`);
461     }
464 // Map the destination cache ID with the cache in which we want to store files