// packages/llm/lib/hooks/useAssistantLocal.ts
import { useCallback, useEffect, useRef } from 'react';

import { c } from 'ttag';

import { useApi, useNotifications, useUser } from '@proton/components/hooks';
import useAssistantTelemetry from '@proton/components/hooks/assistant/useAssistantTelemetry';
import useStateRef from '@proton/hooks/useStateRef';
import type { AssistantHooksProps, AssistantRunningActions, GenerateAssistantResult } from '@proton/llm/lib';
import { MODEL_UNLOADED } from '@proton/llm/lib';
import {
    CACHING_FAILED,
    FAILED_TO_DOWNLOAD,
    PromptRejectedError,
    UNLOAD_ASSISTANT_TIMEOUT,
    buildMLCConfig,
    getGenerationType,
    queryAssistantModels,
} from '@proton/llm/lib';
import { GpuLlmManager } from '@proton/llm/lib/actions';
import type useAssistantCommons from '@proton/llm/lib/hooks/useAssistantCommons';
import type {
    AssistantConfig,
    DownloadProgressInfo,
    GenerationCallbackDetails,
    LlmManager,
    LlmModel,
} from '@proton/llm/lib/types';
import { ASSISTANT_TYPE, ERROR_TYPE, GENERATION_SELECTION_TYPE } from '@proton/shared/lib/assistant';
import { domIsBusy } from '@proton/shared/lib/busy';
import { isElectronApp } from '@proton/shared/lib/helpers/desktop';
import { traceInitiativeError } from '@proton/shared/lib/helpers/sentry';

interface Props {
    commonState: ReturnType<typeof useAssistantCommons>;
    active: boolean;
}
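
/**
 * Hook handling the local (on-device) writing assistant: fetching the model
 * config, downloading the model files, loading them on the GPU and running
 * generations through the LlmManager/LlmModel interfaces.
 *
 * A minimal consumer sketch (the provider wiring around it is assumed, not
 * part of this file):
 *
 *     const commonState = useAssistantCommons();
 *     const assistant = useAssistantLocal({ commonState, active: true });
 *     // void assistant.initAssistant();
 */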
export const useAssistantLocal = ({ commonState, active }: Props): AssistantHooksProps => {
    const api = useApi();
    const { createNotification } = useNotifications();
    const [user] = useUser();

    const llmManager = useRef<LlmManager | null>(null);
    const llmModel = useRef<LlmModel | null>(null);

    const assistantConfigRef = useRef<AssistantConfig>();
    /** In order to be able to wait for the config to be set */
    const assistantConfigPromiseRef = useRef<Promise<void>>();

    const [
        {
            downloadModelSize,
            downloadPaused,
            downloadReceivedBytes,
            isCheckingCache,
            isModelDownloaded,
            isModelDownloading,
            isModelLoadedOnGPU,
            isModelLoadingOnGPU,
        },
        setLocalState,
        localStateRef,
    ] = useStateRef({
        isModelDownloaded: false,
        isModelDownloading: false,
        downloadReceivedBytes: 0,
        downloadModelSize: 0,
        downloadPaused: false,
        isModelLoadedOnGPU: false,
        isModelLoadingOnGPU: false,
        // Flag to know whether the user downloaded the model in this session
        userDownloadedModel: false,
        // Set while we are checking for model files in the cache
        isCheckingCache: false,
    });

    const generatedTokensNumber = useRef(0);
    const initPromise = useRef<Promise<void>>();

    const {
        addSpecificError,
        cleanSpecificErrors,
        addGlobalError,
        cleanGlobalErrors,
        hasCompatibleBrowser,
        hasCompatibleHardware,
        canShowAssistant,
        assistantSubscriptionStatus,
        closeAssistant,
    } = commonState;

    const {
        sendRequestAssistantReport,
        sendUnloadModelAssistantReport,
        sendDownloadAssistantReport,
        sendLoadModelAssistantReport,
    } = useAssistantTelemetry();

    const {
        addRunningAction,
        runningActions,
        getRunningActionFromAssistantID,
        cleanRunningActions,
        resetResolvers,
        runningActionsRef,
        // eslint-disable-next-line @typescript-eslint/no-use-before-define
    } = useRunningActions({ addSpecificError });
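
    /**
     * Create the LLM manager and fetch the model config from the API. The
     * in-flight promise is kept in `assistantConfigPromiseRef` so that
     * `initAssistant` can await the config before downloading or loading.
     */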
    const handleGetAssistantConfig = () => {
        llmManager.current = new GpuLlmManager();

        assistantConfigPromiseRef.current = new Promise((resolve, reject) => {
            void queryAssistantModels(api)
                .then(async (models) => {
                    const config = buildMLCConfig(models);
                    if (config) {
                        assistantConfigRef.current = config;

                        if (llmManager.current) {
                            // Check if the user already has all needed files in the cache, so that we can set the
                            // state and skip the download phase during init when it's not needed
                            setLocalState((localState) => ({
                                ...localState,
                                isCheckingCache: true,
                            }));
                            await llmManager.current.isDownloaded(config).then((isDownloaded) => {
                                setLocalState((localState) => ({
                                    ...localState,
                                    isModelDownloading: false,
                                    isModelDownloaded: isDownloaded,
                                    isCheckingCache: false,
                                }));
                            });
                        }
                    }
                    // Resolve the config promise ref so that we can proceed with init if needed
                    resolve();
                })
                .catch(reject);
        });
    };

    useEffect(() => {
        if (active && !llmManager.current && canShowAssistant && hasCompatibleHardware && hasCompatibleBrowser) {
            // Start the LLM manager and fetch the assistant models from the API
            handleGetAssistantConfig();
        }
    }, [active, canShowAssistant, hasCompatibleHardware, hasCompatibleBrowser]);

    const downloadCallback = (info: DownloadProgressInfo) => {
        setLocalState((localState) => ({
            ...localState,
            downloadModelSize: info.estimatedTotalBytes,
            downloadReceivedBytes: info.receivedBytes,
            userDownloadedModel: true,
        }));
    };
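
    /**
     * Download the model files, reporting progress through `downloadCallback`.
     * Resolves to `true` when the download completed, `false` when it stopped
     * early (e.g. paused by the user), and throws when it failed.
     */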
    const downloadModel = async () => {
        try {
            if (!llmManager.current || !assistantConfigRef.current) {
                return; // throw an error?
            }
            // Clean global errors when downloading the model
            cleanGlobalErrors();

            const startDownloadingTime = performance.now();
            setLocalState((localState) => ({
                ...localState,
                isModelDownloading: true,
            }));
            const completed = await llmManager.current.startDownload(downloadCallback, assistantConfigRef.current);
            setLocalState((localState) => ({
                ...localState,
                isModelDownloading: false,
                isModelDownloaded: completed,
            }));
            if (completed) {
                // fixme: this can report partial download time if we were resuming a previous download session
                const endDownloadingTime = performance.now();
                const downloadingTime = endDownloadingTime - startDownloadingTime;
                sendDownloadAssistantReport(downloadingTime);
                return true;
            }
            return false;
        } catch (e: any) {
            let errorMessage;
            if (e.message === CACHING_FAILED) {
                errorMessage = addGlobalError(ASSISTANT_TYPE.LOCAL, ERROR_TYPE.CACHING_FAILED);
            } else {
                const isRequestError = e.message.includes(FAILED_TO_DOWNLOAD);
                errorMessage = addGlobalError(
                    ASSISTANT_TYPE.LOCAL,
                    isRequestError ? ERROR_TYPE.DOWNLOAD_REQUEST_FAIL : ERROR_TYPE.DOWNLOAD_FAIL
                );
            }

            traceInitiativeError('assistant', e);
            console.error(e);
            setLocalState((localState) => ({
                ...localState,
                isModelDownloading: false,
            }));
            throw new Error(errorMessage);
        }
    };
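
    /**
     * Load the downloaded model weights on the GPU and keep the resulting
     * `LlmModel` in a ref so that `generateResult` can call `performAction` on it.
     */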
    const loadModelOnGPU = async () => {
        try {
            if (!llmManager.current || !assistantConfigRef.current) {
                return; // throw an error?
            }
            const startLoadingTime = performance.now();
            // Clean global errors when loading the model
            cleanGlobalErrors();

            setLocalState((localState) => ({
                ...localState,
                isModelLoadingOnGPU: true,
            }));
            const model = await llmManager.current.loadOnGpu(assistantConfigRef.current);
            llmModel.current = model;
            setLocalState((localState) => ({
                ...localState,
                isModelLoadingOnGPU: false,
                isModelLoadedOnGPU: true,
            }));
            const endLoadingTime = performance.now();
            const loadingTime = endLoadingTime - startLoadingTime;

            sendLoadModelAssistantReport(loadingTime);
        } catch (e: any) {
            traceInitiativeError('assistant', e);
            console.error(e);
            const errorMessage = addGlobalError(ASSISTANT_TYPE.LOCAL, ERROR_TYPE.LOADGPU_FAIL);
            setLocalState((localState) => ({
                ...localState,
                isModelLoadingOnGPU: false,
                isModelLoadedOnGPU: false,
            }));
            throw new Error(errorMessage);
        }
    };

    const unloadModelOnGPU = async () => {
        try {
            if (llmModel.current) {
                await llmModel.current.unload();
                setLocalState((localState) => ({
                    ...localState,
                    isModelLoadedOnGPU: false,
                }));

                sendUnloadModelAssistantReport();
            }
        } catch (e: any) {
            traceInitiativeError('assistant', e);
            console.error(e);
            addGlobalError(ASSISTANT_TYPE.LOCAL, ERROR_TYPE.UNLOAD_FAIL);
        }
    };
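
    /**
     * Initialize the assistant: make sure the config is available, download the
     * model if needed, then load it on the GPU. Concurrent callers share the same
     * promise, so for instance (hypothetical caller):
     *
     *     await Promise.all([initAssistant(), initAssistant()]); // init runs only once
     */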
    const initAssistant = useCallback(() => {
        // Reset download pause state
        setLocalState((localState) => ({
            ...localState,
            downloadPaused: false,
        }));

        // If the assistant is already initializing, then we simply wait for the end of the initialization
        if (initPromise.current) {
            return initPromise.current;
        }

        initPromise.current = (async () => {
            /*
             * To init the assistant:
             * 1 - Download the model if it is not downloaded yet and not downloading at the moment
             * 2 - Load the model on the GPU if it is not loaded yet and not loading at the moment
             */

            // Use try/catch in case one of the steps fails, so that we don't run the next step
            try {
                let completedDownload;

                // If the assistant config is not set at all, start the get-config process manually.
                // Typically, this happens when starting the assistant for the first time and choosing local mode:
                // we trigger the init manually, but the useEffect that handles getting the config has not run yet.
                if (!llmManager.current) {
                    handleGetAssistantConfig();
                }

                // Ensure the config is set before starting init
                if (assistantConfigPromiseRef.current) {
                    await assistantConfigPromiseRef.current;
                }

                // We don't want to go through the init process with free users.
                // Free users have no trial period, so they shouldn't be able to download the model,
                // and they shouldn't go through the GPU loading step either.
                if (user.isFree) {
                    return;
                }

                const {
                    isModelDownloaded,
                    isModelDownloading,
                    isModelLoadedOnGPU,
                    isModelLoadingOnGPU,
                    userDownloadedModel,
                } = localStateRef.current;

                if (!isModelDownloaded && !isModelDownloading) {
                    completedDownload = await downloadModel();
                } else if (isModelDownloaded) {
                    completedDownload = true;
                }

                if (completedDownload && !isModelLoadedOnGPU && !isModelLoadingOnGPU) {
                    await loadModelOnGPU();
                }

                // Show a notification only when the user had to download the model
                if (completedDownload && userDownloadedModel) {
                    createNotification({
                        text: c('Notification').t`The writing assistant is ready to use`,
                    });

                    setLocalState((localState) => ({
                        ...localState,
                        userDownloadedModel: false,
                    }));
                }
            } catch {
                // Errors are already reported (and thrown) by downloadModel and loadModelOnGPU
            }
        })().finally(() => {
            // Reset the init promise after init, or when init failed, so that we can
            // - Start init again if necessary
            // - Proceed if init completed
            initPromise.current = undefined;
        });
    }, [user]);
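
    /**
     * Pausing the download also clears the init promise so that a later
     * `resumeDownloadModel` call can restart `initAssistant` from scratch;
     * already downloaded files stay in the cache.
     */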
    const cancelDownloadModel = () => {
        if (llmManager.current) {
            llmManager.current.cancelDownload();
            initPromise.current = undefined;
            setLocalState((localState) => ({
                ...localState,
                downloadPaused: true,
            }));
        }
    };

    const resumeDownloadModel = () => {
        void initAssistant();
        setLocalState((localState) => ({
            ...localState,
            downloadPaused: false,
        }));
    };
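
    /**
     * Run a generation for one assistant instance, initializing the assistant
     * first if needed. Illustrative call (the exact `action` shape is defined
     * by `GenerateAssistantResult`, not shown here):
     *
     *     await generateResult({ assistantID, action, hasSelection: false, callback: (text) => setResult(text) });
     */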
    const generateResult = async ({ action, callback, assistantID, hasSelection }: GenerateAssistantResult) => {
        // TODO prevent submit if the user made too many harmful requests recently

        // Do not start multiple actions in the same assistant
        const runningActionInAssistant = getRunningActionFromAssistantID(assistantID);
        if (runningActionInAssistant) {
            return;
        }

        // Reset generation errors in this assistant
        cleanSpecificErrors(assistantID);

        let isResolved = false;

        // The generation needs to be stopped in two different cases:
        // 1 - The assistant is ready (everything is loaded) and the user stops it.
        //      In that case, we can stop the running action.
        // 2 - The assistant is still loading, but the user submitted a request and cancelled it.
        //      In that case, we don't have the running action yet, so we need to cancel the promise.
        //      That's why the entire function is wrapped in a Promise: we can store it in a ref and cancel it when needed.
        await new Promise<void>(async (res) => {
            const resolve = () => {
                res();
                isResolved = true;
            };
            addRunningAction(assistantID, resolve);

            const ingestionStart = performance.now();
            try {
                // Start the initialization in case the assistant is not loaded yet.
                // If it is loaded already, nothing will be done;
                // otherwise we wait for the init process to finish before starting the generation.
                await initAssistant();
            } catch {
                // If an error occurred during the init, an error has been set already.
                // In that case, we can stop the generation before going further.
                return;
            }

            const ingestionEnd = performance.now();
            const ingestionTime = ingestionEnd - ingestionStart;

            try {
                // If the promise is resolved, we cancelled it after a user interaction.
                // We don't want to generate a result anymore.
                if (llmModel.current && !isResolved) {
                    let promptRejectedOnce = false;
                    const generationCallback = (fulltext: string, details?: GenerationCallbackDetails) => {
                        if (promptRejectedOnce) {
                            return;
                        }
                        generatedTokensNumber.current++;
                        const isHarmful = details?.harmful;

                        // Used to prevent adding additional tokens that we receive after cancelling a running action
                        const isRunningAction = assistantID in runningActionsRef.current;

                        if (!isHarmful) {
                            if (isRunningAction) {
                                callback(fulltext);
                            }
                        } else {
                            promptRejectedOnce = true;
                            cleanRunningActions(assistantID);
                        }
                    };

                    const generationStart = performance.now();

                    if (assistantSubscriptionStatus.trialStatus === 'trial-not-started') {
                        await assistantSubscriptionStatus.start();
                    }

                    const runningAction = await llmModel.current.performAction(action, generationCallback);

                    addRunningAction(assistantID, () => {
                        // Resolve is needed to end the parent promise
                        resolve();
                        // Here we stop the LLM generation
                        runningAction.cancel();
                    });

                    await runningAction.waitForCompletion();

                    // Throw an error if the user made a harmful request
                    if (promptRejectedOnce) {
                        throw new PromptRejectedError();
                    }

                    // Send telemetry report
                    const generationEnd = performance.now();
                    const generationTime = generationEnd - generationStart;
                    sendRequestAssistantReport({
                        assistantType: ASSISTANT_TYPE.LOCAL,
                        generationType: getGenerationType(action),
                        selectionType: hasSelection
                            ? GENERATION_SELECTION_TYPE.HAS_SELECTION
                            : GENERATION_SELECTION_TYPE.NO_SELECTION,
                        ingestionTime,
                        generationTime,
                        tokensGenerated: generatedTokensNumber.current,
                    });
                    generatedTokensNumber.current = 0;
                }
            } catch (e: any) {
                // Sometimes the model gets unloaded automatically by the llm lib.
                // If we detect that the model is not loaded during generate, load it again and retry the generation.
                if (e === MODEL_UNLOADED) {
                    await loadModelOnGPU();
                    return generateResult({ action, callback, assistantID, hasSelection });
                }

                if (e.name === 'PromptRejectedError') {
                    addSpecificError({
                        assistantID,
                        assistantType: ASSISTANT_TYPE.LOCAL,
                        errorType: ERROR_TYPE.GENERATION_HARMFUL,
                    });
                } else {
                    addSpecificError({
                        assistantID,
                        assistantType: ASSISTANT_TYPE.LOCAL,
                        errorType: ERROR_TYPE.GENERATION_FAIL,
                    });
                }
                traceInitiativeError('assistant', e);
                console.error(e);

                // Reset the assistant result when an error occurred while generating.
                // Otherwise, on the next submit the previous result would be displayed for a few ms.
                callback('');
            }

            // Reset the generating state
            cleanRunningActions(assistantID);
        });
    };
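
    /**
     * Tear the assistant down: cancel an ongoing download, unload the model
     * from the GPU, drop every running action and reset the local state.
     */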
    const resetAssistantState = () => {
        // Cancel model downloading
        if (initPromise.current) {
            cancelDownloadModel();
        }
        // Unload model from GPU
        if (isModelLoadedOnGPU) {
            void unloadModelOnGPU();
        }
        // Cancel all running actions
        if (Object.keys(runningActions).length) {
            resetResolvers();
        }

        // Reset all states
        setLocalState((localState) => ({
            ...localState,
            isModelDownloaded: false,
            isModelDownloading: false,
            isModelLoadedOnGPU: false,
            isModelLoadingOnGPU: false,
            downloadPaused: false,
            downloadModelSize: 0,
            downloadReceivedBytes: 0,
        }));
    };

    // Unload the model after some time of inactivity.
    // If the model is loaded on the GPU, check every X minutes whether the user is busy,
    // and reset the timer completely when the user is generating a new result.
    useEffect(() => {
        if (!isModelLoadedOnGPU) {
            return;
        }

        const id = setInterval(() => {
            if (domIsBusy()) {
                return;
            }
            if (isElectronApp && document.hasFocus()) {
                return;
            }

            void unloadModelOnGPU();
        }, UNLOAD_ASSISTANT_TIMEOUT);

        return () => {
            clearInterval(id);
        };
    }, [isModelLoadedOnGPU, runningActions]);

    return {
        assistantConfig: assistantConfigRef.current,

        initAssistant,

        // Download related
        downloadModelSize,
        downloadReceivedBytes,
        downloadPaused,
        isModelDownloaded,
        isModelDownloading,
        cancelDownloadModel,
        resumeDownloadModel,
        isCheckingCache,

        // GPU loading related
        isModelLoadedOnGPU,
        isModelLoadingOnGPU,
        unloadModelOnGPU,

        // Generate related
        generateResult,
        runningActions,
        cancelRunningAction: cleanRunningActions,

        closeAssistant: closeAssistant(cleanRunningActions),

        resetAssistantState,
    };
};

interface UseRunningActionsProps {
    addSpecificError: ReturnType<typeof useAssistantCommons>['addSpecificError'];
}
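
/**
 * Keep track of cancellation resolvers for running actions, keyed by assistant ID.
 * Illustrative contract (names below are placeholders):
 *
 *     addRunningAction('some-assistant-id', onCancel);
 *     cleanRunningActions('some-assistant-id'); // calls onCancel, then forgets it
 */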
function useRunningActions({ addSpecificError }: UseRunningActionsProps) {
    const [runningActions, setRunningAction, runningActionsRef] = useStateRef<AssistantRunningActions>({});

    const addRunningAction = (assistantID: string, resolver: () => void) => {
        setRunningAction((runningActions) => ({
            ...runningActions,
            [assistantID]: resolver,
        }));
    };

    const cleanRunningActions = (assistantID: string) => {
        try {
            const runningAction = runningActionsRef.current[assistantID];
            if (runningAction) {
                runningAction();
                setRunningAction((runningActions) => {
                    delete runningActions[assistantID];
                    return { ...runningActions };
                });
            }
        } catch (e: any) {
            traceInitiativeError('assistant', e);
            addSpecificError({
                assistantID,
                assistantType: ASSISTANT_TYPE.LOCAL,
                errorType: ERROR_TYPE.GENERATION_CANCEL_FAIL,
            });
        }
    };

    return {
        addRunningAction,
        cleanRunningActions,
        getRunningActionFromAssistantID: (assistantID: string) => assistantID in runningActions,
        resetResolvers: () => setRunningAction({}),
        runningActions,
        runningActionsRef,
    };
}