1 import { CryptoProxy } from '@proton/crypto';
2 import { arrayToHexString } from '@proton/crypto/lib/utils';
3 import { FAILED_TO_DOWNLOAD } from '@proton/llm/lib/constants';
4 import { postMessageIframeToParent } from '@proton/llm/lib/helpers';
5 import * as _ndarrayCache from '@proton/llm/resources/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ndarray-cache.json';
6 import throttle from '@proton/utils/throttle';
8 import type { AssistantConfig, AssistantConfigModel } from './types';
9 import { AssistantEvent } from './types';
// Shard manifest for the bundled model weights (list of params_shard_*.bin records
// with md5sum/nbytes), loaded from a static JSON asset and cast to its typed shape.
33 const ndarrayCache = _ndarrayCache as NdarrayCache;
// Cache-bucket identifiers used with the browser CacheStorage API.
// NOTE(review): these are members of an enum (CacheId) whose declaration line is
// not visible in this excerpt; a WASM member is also referenced elsewhere (L212).
37 MODEL = 'webllm/model',
38 CONFIG = 'webllm/config',
// Describes one file needed by the model: where to download it from, where to
// cache it, and optional integrity metadata (md5/size) when known upfront.
41 export type LlmFile = {
47 expectedSize?: number;
// Result of a completed streamed download: response metadata plus raw body chunks.
50 export type DownloadResult = {
52 chunks: Uint8Array<ArrayBufferLike>[];
// One opened Cache instance per CacheId bucket.
57 export type AppCaches = { [k in CacheId]: Cache };
59 // A function to monitor the progress of a single file.
60 type OneFileProgressCallback = (url: string, received: number, total: number) => void;
63 * Functions used on the iframe app to
64 * - Download the files that we need to use the model
65 * - Send download events to the parent app
68 // Initiate a download, monitors the progress, and returns the result when finished.
// Streams the response body chunk by chunk so progress can be reported while the
// file is still in flight; the raw chunks are returned so the caller can transfer
// them to the parent app for caching. Throws on a non-OK HTTP status.
// NOTE(review): this excerpt is missing some original lines (e.g. the downloadUrl
// parameter declaration and the `if (!ok)` guard implied by the throw below).
69 async function downloadFile(
71 callback: OneFileProgressCallback,
72 abortController: AbortController
73 ): Promise<DownloadResult> {
// Propagate the caller's abort signal into fetch so cancellation stops the stream.
74 const signal = abortController.signal;
75 const response = await fetch(downloadUrl, { signal });
76 const { status, statusText, ok } = response;
78 throw Error(`${downloadUrl}: ${FAILED_TO_DOWNLOAD} ${status} ${statusText}`);
80 const headers = new Headers(response.headers);
81 const reader = response.body?.getReader();
// Content-Length may be absent for some files; `+null!` would yield 0 here —
// downstream progress code tolerates an unknown (0) total.
82 const contentLength = +response.headers.get('Content-Length')!;
83 let receivedLength = 0;
84 const chunks: Uint8Array<ArrayBufferLike>[] = [];
86 // Debounce the progress callback call to avoid sending too many events to the iframe parent
87 const debouncedCallback = throttle((callback) => {
// Read loop body (loop header not visible in this excerpt): accumulate chunks
// and report cumulative received bytes for this file.
92 const { done, value } = await reader!.read();
97 receivedLength += value.length;
// Capture the current count in a const so the throttled closure reports a
// stable value rather than whatever receivedLength is when it finally fires.
98 const totalReceivedLength = receivedLength;
99 debouncedCallback(() => callback(downloadUrl, totalReceivedLength, contentLength));
// Headers are serialized to JSON because the result crosses a postMessage
// boundary; the parent re-parses them in storeInCache.
101 let headersMap = Object.fromEntries(headers);
102 const serializedHeaders = JSON.stringify(headersMap);
103 return { status, statusText, headers: serializedHeaders, chunks };
106 // Compute the file url using the model_url so that we cache the file correctly in the parent app
// Downloads each file in `files` one after another (skipping those in
// filesToIgnore), forwarding every completed file's data to the parent app via
// postMessage so the parent can store it in the right cache. Returns early if aborted.
107 async function downloadFilesSequentially(
109 callback: OneFileProgressCallback,
110 abortController: AbortController,
111 filesToIgnore: LlmFile[],
112 variantConfig: AssistantConfigModel,
115 let filesDownloaded = 0;
116 const totalFilesToDownload = files.length - filesToIgnore.length;
117 for (let i = 0; i < files.length; i++) {
// Check abort between files so a cancel takes effect promptly.
118 if (abortController.signal.aborted) return;
120 const { downloadUrl, cacheUrl, cacheId, expectedMd5 } = files[i];
122 // Do not download files that have already been downloaded and cached
123 let ignoreThisFile = filesToIgnore.some((f) => f.downloadUrl === downloadUrl);
124 if (!ignoreThisFile) {
125 // Start the download for a new file.
126 const downloadResult = await downloadFile(downloadUrl, callback, abortController);
129 postMessageIframeToParent(
131 type: AssistantEvent.DOWNLOAD_DATA,
134 headers: downloadResult.headers,
135 chunks: downloadResult.chunks,
136 status: downloadResult.status,
137 statusText: downloadResult.statusText,
139 // use the parent model url so that we put the right element in cache
// `terminate` tells the parent this was the last outstanding download.
143 terminate: filesDownloaded === totalFilesToDownload,
// Pass the chunks' underlying ArrayBuffers — presumably as the postMessage
// transfer list to avoid copying large weight shards (TODO confirm: the
// postMessageIframeToParent signature is not visible here).
147 downloadResult.chunks.map((chunk) => chunk.buffer)
153 // Prepare the list of all the files we need to download
// Builds the full file list for a model variant: the chat config JSON, the
// tokenizer, and every weight shard listed in the bundled ndarray-cache manifest.
// Each entry carries both a download URL (CDN) and a cache URL (parent origin).
154 function listFilesToDownload(variantConfig: AssistantConfigModel): LlmFile[] {
155 // From the iframe, we are downloading files using the model_download_url
156 // Then, before sending an event to the parent app,
157 // we will update the url so that we use the model_url domain instead for the file caching
// baseKey is the path component only, used as the cache-entry key prefix.
158 const baseKey = new URL(variantConfig.model_download_url).pathname;
159 const baseDownloadUrl = variantConfig.model_download_url;
160 const baseCacheUrl = variantConfig.model_url;
162 // Since we don't have access to the app using the model from the cache,
163 // put the destination cache identifiers
164 const files: LlmFile[] = [];
166 // "webllm/model" -> ".../mlc-chat-config.json"
167 downloadUrl: new URL('mlc-chat-config.json', baseDownloadUrl).href,
168 cacheUrl: new URL('mlc-chat-config.json', baseCacheUrl).href,
169 cacheKey: `${baseKey}mlc-chat-config.json`,
170 cacheId: CacheId.CONFIG,
174 // "webllm/model" -> ".../tokenizer.json"
175 downloadUrl: new URL('tokenizer.json', baseDownloadUrl).href,
176 cacheUrl: new URL('tokenizer.json', baseCacheUrl).href,
177 cacheKey: `${baseKey}tokenizer.json`,
178 cacheId: CacheId.MODEL,
// One entry per weight shard; md5sum/nbytes from the manifest let callers
// verify integrity and know the expected size before fetching.
182 ...ndarrayCache.records.map((record) => ({
183 // "webllm/model" -> ".../params_shard_*.bin"
184 downloadUrl: new URL(record.dataPath, baseDownloadUrl).href,
185 cacheUrl: new URL(record.dataPath, baseCacheUrl).href,
186 cacheKey: `${baseKey}${record.dataPath}`,
187 cacheId: CacheId.MODEL,
188 expectedMd5: record.md5sum,
189 expectedSize: record.nbytes,
195 // Retrieves all the files that we need to use the model
// Entry point (iframe side) for downloading a model variant. Resolves the variant
// in the MLC config, builds the file list, sets up per-file byte accounting, and
// streams DOWNLOAD_PROGRESS events to the parent while files download sequentially.
196 export async function downloadModel(
198 assistantConfig: AssistantConfig,
199 abortController: AbortController,
200 filesToIgnore: LlmFile[],
203 // Grab the entry for our chosen model inside mlc-config.
204 const variantConfig = assistantConfig.model_list.find((m) => m.model_id === variant);
205 if (variantConfig === undefined) {
206 console.error(`Model not found in MLC config: ${variant}`);
207 throw Error(`Model not found in MLC config: ${variant}`);
210 // Prepare a list of files to download
211 const files = listFilesToDownload(variantConfig);
213 // This first map tracks how many bytes we need to download.
214 // { url: expectedSize }.
215 // Thanks to it, we can compute the total overall size to download by summing the values.
216 // For most files, especially for the model weights, the size is specified in a meta file, so we know it upfront.
217 // Unfortunately, some small files (like wasm and tokenizer.json) have an unknown size. However, we start to fetch
218 // it, Content-Length should tell us the real size. Therefore, this map will be modified a few times, as we fetch
219 // some of these files that have an initially unknown size.
220 const expectedSizes: Map<string, number> = new Map();
221 for (const f of files) {
222 if (f.expectedSize) {
223 expectedSizes.set(f.downloadUrl, f.expectedSize!);
226 const overallExpectedSize = () => [...expectedSizes.values()].reduce((acc, n) => acc + n, 0);
228 // This second map tracks how many bytes we have received for each file.
229 // { url: receivedSize }
230 // The purpose is to track the overall bytes we've downloaded so far. We compute this by summing the values too.
231 // Consequently, it will be frequently updated, namely each time we receive a new chunk of data.
232 const receivedSizes: Map<string, number> = new Map();
233 for (const f of files) {
234 let ignoreThisFile = filesToIgnore.some((ignored) => ignored.downloadUrl === f.downloadUrl);
235 if (ignoreThisFile) {
// Already-cached files count as fully received from the start.
236 receivedSizes.set(f.downloadUrl, f.expectedSize || 0);
238 receivedSizes.set(f.downloadUrl, 0);
241 const overallReceived = () => [...receivedSizes.values()].reduce((acc, n) => acc + n, 0);
// A file is "finished" once its received byte count has reached its expected size.
243 const nFinishedFiles = () =>
244 files.filter((f) => {
245 let r = receivedSizes.get(f.downloadUrl);
246 let e = expectedSizes.get(f.downloadUrl);
247 return r !== undefined && e !== undefined && r >= e;
250 // Start downloading files
// Per-chunk progress callback, invoked (throttled) from downloadFile.
251 const updateProgressOneFile = (downloadUrl: string, received: number, total: number) => {
// Late-learned size: Content-Length fills in files whose size was unknown upfront.
252 if (!expectedSizes.has(downloadUrl) && total > 0) {
253 expectedSizes.set(downloadUrl, total);
// Cap received at total so a mismatched Content-Length can't push progress > 100%.
255 const receivedCapped = Math.min(received, total);
256 receivedSizes.set(downloadUrl, receivedCapped);
257 const r = overallReceived();
258 const e = overallExpectedSize();
260 // Send a message to the parent app so that we can update the download progress
261 postMessageIframeToParent(
263 type: AssistantEvent.DOWNLOAD_PROGRESS,
267 estimatedTotalBytes: e,
268 receivedFiles: nFinishedFiles(),
269 totalFiles: files.length,
276 await downloadFilesSequentially(
278 updateProgressOneFile,
287 * Functions used on the parent app to
288 * - Check which files of the model have already been downloaded
289 * - Store files in the cache
// Hex-encoded MD5 digest of `data`. Used purely as an integrity checksum for
// downloaded model files, not for security — hence the 'unsafeMD5' algorithm name.
292 async function computeMd5(data: Uint8Array) {
293 return arrayToHexString(await CryptoProxy.computeHash({ data, algorithm: 'unsafeMD5' }));
// Builds a synthetic Response wrapping locally-available data, with headers that
// mimic what a real server would send, so it can be stored in the Cache API as if
// it had been downloaded.
// @param data     serialized payload to wrap in the response body
// @param filename name advertised in Content-Disposition
// @param origin   value for Access-Control-Allow-Origin (the parent app's origin)
296 function buildFakeResponseForCache(data: string, filename: string, origin: string) {
297 const blob = new Blob([data], { type: 'text/plain;charset=utf-8' });
298 const headers = new Headers({
299 'Accept-Ranges': 'bytes',
300 'Access-Control-Allow-Origin': origin,
301 'Access-Control-Expose-Headers': 'Accept-Ranges,Content-Range',
// RFC 6266 form: filename* (RFC 5987 encoded) plus plain filename fallback.
302 'Content-Disposition': `inline; filename*=UTF-8''${filename}; filename="${filename}"`,
303 'Content-Length': blob.size.toString(),
304 'Content-Security-Policy': 'default-src none; sandbox',
305 'Content-Type': 'text/plain; charset=utf-8',
306 'Cross-Origin-Opener-Policy': 'same-origin',
307 Date: new Date().toUTCString(),
308 'Referrer-Policy': 'strict-origin-when-cross-origin',
// NOTE(review): the header-map close and the Response init options (status,
// headers) are on lines not visible in this excerpt.
311 const response = new Response(blob, {
319 // Store a piece of data that we don't actually need to download
320 // because we already have it in a JavaScript object right here.
321 // This creates a fake response that looks like the server sent
322 // it, and stores it in the cache.
// @param object   any JSON-serializable value to persist
// @param filename file name used both in the fake response and the cache key
// @param cacheId  the Cache instance to write into (despite the name, this is a
//                 Cache object, not a CacheId enum value)
// @param origin   parent app origin, forwarded for the CORS header
// @param baseKey  model path prefix; cache key is `${baseKey}${filename}`
323 async function storeLocalDataInCache(object: any, filename: string, cacheId: Cache, origin: string, baseKey: string) {
324 const data = JSON.stringify(object);
325 const response = buildFakeResponseForCache(data, filename, origin);
326 await cacheId.put(`${baseKey}${filename}`, response);
329 // Check if a given url is already present in the cache, and optionally verify its MD5 checksum.
// Returns false on a cache miss; on a hit, optionally validates the cached body's
// size and/or MD5 against `expects`. With no expectations, a hit alone is enough.
330 async function existsInCache(
333 expects?: { md5?: string; size?: number }
334 ): Promise<boolean> {
335 let cachedResponse = await cache.match(cacheUrl);
336 if (!cachedResponse) {
339 const expectedSize = expects?.size;
340 const expectedMd5 = expects?.md5;
342 const needsCheck = expectedSize !== undefined || expectedMd5 !== undefined;
// Reading the body consumes the matched Response clone held here, not the cache entry.
347 const arrayBuffer = await cachedResponse.arrayBuffer();
// Size check takes precedence: when a size is expected, the result is decided
// here and the MD5 branch below is never reached.
349 if (expectedSize !== undefined) {
350 return arrayBuffer.byteLength === expectedSize;
353 if (expectedMd5 !== undefined) {
354 // disabling for now; it takes too long to compute checksum for cached files
355 const BYPASS_CHECKSUM_FOR_CACHED_FILES = true;
356 if (BYPASS_CHECKSUM_FOR_CACHED_FILES) {
360 const data = new Uint8Array(arrayBuffer);
362 const actualMd5 = await computeMd5(data);
363 if (actualMd5 === expectedMd5) {
372 // Put files that we don't need to download from the iframe in the cache
// Seeds the caches with assets the parent app already has: the bundled
// ndarray-cache.json manifest (stored via a fake response) and the model's WASM
// runtime (fetched directly from the parent origin if not cached yet).
373 async function cacheParentAppFiles(appCaches: AppCaches, variantConfig: AssistantConfigModel) {
374 // Cache files we already have statically and don't need to download. We pretend we have downloaded it, but
375 // in fact we just create a fake response to a nonexistent request, pretend the server sent it to us, and store
// Keys are derived from model_url (parent origin), matching how the LLM runtime
// will later look these files up.
377 const origin = window.location.origin;
378 const baseKey = new URL(variantConfig.model_url).pathname;
379 // - "webllm/model" -> ".../ndarray-cache.json"
380 await storeLocalDataInCache(ndarrayCache, 'ndarray-cache.json', appCaches[CacheId.MODEL], origin, baseKey);
382 // Cache files that we stored in the app assets
383 // "webllm/wasm" -> ".../file.wasm"
384 const wasmUrl = variantConfig.model_lib_url; // 'https://mail.proton.me/.../file.wasm'
385 const isWasmInCache = await existsInCache(wasmUrl, appCaches[CacheId.WASM]);
386 if (!isWasmInCache) {
387 const wasmResponse = await fetch(wasmUrl);
388 await appCaches[CacheId.WASM].put(wasmUrl, wasmResponse);
392 // Search for files that we have already downloaded and stored in the cache
// Parent-side entry point: opens the three cache buckets, seeds them with local
// assets, then checks each required file (by cache URL and expected size) and
// reports which files still need downloading.
// @returns { filesAlreadyDownloaded, needsAdditionalDownload, appCaches }
393 export async function getCachedFiles(variant: string, assistantConfig: AssistantConfig) {
394 const filesAlreadyDownloaded: LlmFile[] = [];
397 const appCaches: AppCaches = {
398 [CacheId.MODEL]: await caches.open(CacheId.MODEL),
399 [CacheId.WASM]: await caches.open(CacheId.WASM),
400 [CacheId.CONFIG]: await caches.open(CacheId.CONFIG),
403 // Grab the entry for our chosen model inside mlc-config.
404 const variantConfig = assistantConfig.model_list.find((m) => m.model_id === variant);
405 if (variantConfig === undefined) {
406 console.error(`Model not found in MLC config: ${variant}`);
407 throw Error(`Model not found in MLC config: ${variant}`);
410 // Put files that we don't need to download from the iframe in the cache
411 await cacheParentAppFiles(appCaches, variantConfig);
413 // Prepare a list of files that we need to run the model
414 const files = listFilesToDownload(variantConfig);
416 // Check which files are present in the cache
// Only the size is verified here; MD5 is skipped (see existsInCache) as it is
// too slow to compute for large cached weight shards.
417 for (const f of files) {
418 const { cacheUrl, cacheId, expectedSize } = f;
419 const exists = await existsInCache(cacheUrl, appCaches[cacheId], { size: expectedSize });
422 filesAlreadyDownloaded.push(f);
426 const needsAdditionalDownload = filesAlreadyDownloaded.length !== files.length;
428 return { filesAlreadyDownloaded, needsAdditionalDownload, appCaches };
432 * Clears the cache related to the AI Assistant
433 * @returns A promise that resolves when the cache is cleared
// Deletes all three assistant cache buckets in parallel.
435 export async function deleteAssistantCachedFiles() {
436 return Promise.all([caches.delete(CacheId.MODEL), caches.delete(CacheId.WASM), caches.delete(CacheId.CONFIG)]);
439 // Store a fully downloaded file into the cache.
// Reassembles the streamed chunks, verifies the MD5 checksum when one is expected,
// and writes a Response (with the originally-serialized headers) under cacheUrl.
// Throws if the checksum mismatches or the cache write fails.
440 export async function storeInCache(
441 downloadResult: DownloadResult,
444 expectedMd5: string | undefined
446 const { status, statusText, headers, chunks } = downloadResult;
447 const blob = new Blob(chunks);
448 const arrayBuffer = await blob.arrayBuffer();
449 const data = new Uint8Array(arrayBuffer);
450 const actualMd5 = await computeMd5(data);
// No expected checksum means we accept the file as-is.
452 if (!expectedMd5 || actualMd5 === expectedMd5) {
// Headers were JSON-stringified on the iframe side (downloadFile); restore them.
454 const parsedHeaders = JSON.parse(headers);
455 await cache.put(cacheUrl, new Response(blob, { status, statusText, headers: parsedHeaders }));
457 throw new Error(`${cacheUrl}: error while storing in cache: ${e}`);
460 throw new Error(`${cacheUrl}: checksum failed, expected ${expectedMd5}, got ${actualMd5}`);
464 // Map the destination cache ID with the cache in which we want to store files