1 import { useCallback, useEffect, useRef } from 'react';
3 import { c } from 'ttag';
5 import { useApi, useNotifications, useUser } from '@proton/components/hooks';
6 import useAssistantTelemetry from '@proton/components/hooks/assistant/useAssistantTelemetry';
7 import useStateRef from '@proton/hooks/useStateRef';
8 import type { AssistantHooksProps, AssistantRunningActions, GenerateAssistantResult } from '@proton/llm/lib';
9 import { MODEL_UNLOADED } from '@proton/llm/lib';
14 UNLOAD_ASSISTANT_TIMEOUT,
18 } from '@proton/llm/lib';
19 import { GpuLlmManager } from '@proton/llm/lib/actions';
20 import type useAssistantCommons from '@proton/llm/lib/hooks/useAssistantCommons';
24 GenerationCallbackDetails,
27 } from '@proton/llm/lib/types';
28 import { ASSISTANT_TYPE, ERROR_TYPE, GENERATION_SELECTION_TYPE } from '@proton/shared/lib/assistant';
29 import { domIsBusy } from '@proton/shared/lib/busy';
30 import { isElectronApp } from '@proton/shared/lib/helpers/desktop';
31 import { traceInitiativeError } from '@proton/shared/lib/helpers/sentry';
34 commonState: ReturnType<typeof useAssistantCommons>;
// Hook implementing the "local" (on-device, WebGPU) writing-assistant backend.
// Wires together model download / GPU load state, telemetry reporting and
// per-assistant running-action tracking.
// NOTE(review): many interior lines are elided in this view; comments describe
// only the statements visible here.
38 export const useAssistantLocal = ({ commonState, active }: Props): AssistantHooksProps => {
40 const { createNotification } = useNotifications();
41 const [user] = useUser();
// Refs for the LLM manager (download/cache/GPU orchestration) and the loaded
// model instance; both stay null until initialisation runs.
43 const llmManager = useRef<LlmManager | null>(null);
44 const llmModel = useRef<LlmModel | null>(null);
46 const assistantConfigRef = useRef<AssistantConfig>();
47 /** In order to be able to wait for config to be set */
48 const assistantConfigPromiseRef = useRef<Promise<void>>();
// Local download/load state — presumably backed by useStateRef so async code
// can read the latest value via a ref (TODO confirm: declaration is elided).
54 downloadReceivedBytes,
64 isModelDownloaded: false,
65 isModelDownloading: false,
66 downloadReceivedBytes: 0,
68 downloadPaused: false,
69 isModelLoadedOnGPU: false,
70 isModelLoadingOnGPU: false,
71 // ref to know if the user downloaded the model in this session
72 userDownloadedModel: false,
73 // Value used to know when we are checking for model files in cache
74 isCheckingCache: false,
// Token counter for the current generation, flushed into the telemetry report.
77 const generatedTokensNumber = useRef(0);
// Single in-flight init promise so concurrent init callers share one run.
78 const initPromise = useRef<Promise<void>>();
86 hasCompatibleHardware,
88 assistantSubscriptionStatus,
// Telemetry senders used by download / GPU-load / unload / request flows below.
93 sendRequestAssistantReport,
94 sendUnloadModelAssistantReport,
95 sendDownloadAssistantReport,
96 sendLoadModelAssistantReport,
97 } = useAssistantTelemetry();
102 getRunningActionFromAssistantID,
106 // eslint-disable-next-line @typescript-eslint/no-use-before-define
107 } = useRunningActions({ addSpecificError });
// Creates the GPU LLM manager and kicks off fetching the assistant model
// config from the API. Also probes the browser cache so init can skip the
// download phase when all model files are already present.
// NOTE(review): interior lines are elided in this view; comments only cover
// the visible statements.
109 const handleGetAssistantConfig = () => {
110 llmManager.current = new GpuLlmManager();
// Stored in a ref so initAssistant can await config readiness before running.
112 assistantConfigPromiseRef.current = new Promise((resolve, reject) => {
113 void queryAssistantModels(api)
114 .then(async (models) => {
115 const config = buildMLCConfig(models);
117 assistantConfigRef.current = config;
119 if (llmManager.current) {
120 // Check if user has all needed files in cache, so that we can set the state and avoid
121 // going through the download phase when it's not needed during init
122 setLocalState((localState) => ({
124 isCheckingCache: true,
// Cache probe: flag the model as downloaded when every file is cached already.
126 await llmManager.current.isDownloaded(config).then((isDownloaded) => {
127 setLocalState((localState) => ({
129 isModelDownloading: false,
130 isModelDownloaded: isDownloaded,
131 isCheckingCache: false,
136 // Resolve the config promise ref so that we can proceed with init if needed
// Lazily create the LLM manager and fetch the model config once the assistant
// is active and every compatibility gate (feature flag, hardware, browser)
// passes; guarded so it only ever runs once per session.
144 if (active && !llmManager.current && canShowAssistant && hasCompatibleHardware && hasCompatibleBrowser) {
145 // Start llm manager and get assistant models API side
146 handleGetAssistantConfig();
148 }, [active, canShowAssistant, hasCompatibleHardware, hasCompatibleBrowser]);
// Progress callback handed to startDownload: mirrors byte counts into local
// state and records that this session actually performed a download (used to
// decide whether to show the "ready to use" notification after init).
150 const downloadCallback = (info: DownloadProgressInfo) => {
151 setLocalState((localState) => ({
153 downloadModelSize: info.estimatedTotalBytes,
154 downloadReceivedBytes: info.receivedBytes,
155 userDownloadedModel: true,
// Downloads the model files (resumable via the manager), tracking duration for
// telemetry. On failure, maps the error to a specific global error type, traces
// it, resets the downloading flag and rethrows with the user-facing message.
// NOTE(review): interior lines (try/catch framing, error-message assignment)
// are elided in this view.
159 const downloadModel = async () => {
161 if (!llmManager.current || !assistantConfigRef.current) {
162 return; // throw an error?
164 // Clean global errors when downloading the model
167 const startDownloadingTime = performance.now();
168 setLocalState((localState) => ({
170 isModelDownloading: true,
// `completed` is false when the download did not finish (e.g. paused/cancelled).
172 const completed = await llmManager.current.startDownload(downloadCallback, assistantConfigRef.current);
173 setLocalState((localState) => ({
175 isModelDownloading: false,
176 isModelDownloaded: completed,
179 // fixme: this can report partial download time if we were resuming a previous download session
180 const endDownloadingTime = performance.now();
181 const downloadingTime = endDownloadingTime - startDownloadingTime;
182 sendDownloadAssistantReport(downloadingTime);
// Error mapping: cache-write failure vs network/request failure vs generic fail.
188 if (e.message === CACHING_FAILED) {
189 addGlobalError(ASSISTANT_TYPE.LOCAL, ERROR_TYPE.CACHING_FAILED);
191 const isRequestError = e.message.includes(FAILED_TO_DOWNLOAD);
193 ASSISTANT_TYPE.LOCAL,
194 isRequestError ? ERROR_TYPE.DOWNLOAD_REQUEST_FAIL : ERROR_TYPE.DOWNLOAD_FAIL
198 traceInitiativeError('assistant', e);
200 setLocalState((localState) => ({
202 isModelDownloading: false,
204 throw new Error(errorMessage);
// Loads the downloaded model onto the GPU and stores the handle in llmModel,
// reporting load duration via telemetry. On failure: traces the error, raises
// a LOADGPU_FAIL global error, resets the GPU flags and rethrows with the
// user-facing message. NOTE(review): try/catch framing is elided in this view.
208 const loadModelOnGPU = async () => {
210 if (!llmManager.current || !assistantConfigRef.current) {
211 return; // throw an error?
213 const startLoadingTime = performance.now();
214 // Clean global errors when loading the model
217 setLocalState((localState) => ({
219 isModelLoadingOnGPU: true,
221 const model = await llmManager.current.loadOnGpu(assistantConfigRef.current);
222 llmModel.current = model;
223 setLocalState((localState) => ({
225 isModelLoadingOnGPU: false,
226 isModelLoadedOnGPU: true,
228 const endLoadingTime = performance.now();
229 const loadingTime = endLoadingTime - startLoadingTime;
231 sendLoadModelAssistantReport(loadingTime);
233 traceInitiativeError('assistant', e);
235 const errorMessage = addGlobalError(ASSISTANT_TYPE.LOCAL, ERROR_TYPE.LOADGPU_FAIL);
236 setLocalState((localState) => ({
238 isModelLoadingOnGPU: false,
239 isModelLoadedOnGPU: false,
241 throw new Error(errorMessage);
// Unloads the model from the GPU (best effort) and reports telemetry. Failures
// are traced and surfaced as an UNLOAD_FAIL global error instead of thrown.
// NOTE(review): the `?.` on unload() is redundant inside the truthiness guard,
// but left untouched — surrounding try/catch lines are elided in this view.
245 const unloadModelOnGPU = async () => {
247 if (llmModel.current) {
248 await llmModel.current?.unload();
249 setLocalState((localState) => ({
251 isModelLoadedOnGPU: false,
254 sendUnloadModelAssistantReport();
257 traceInitiativeError('assistant', e);
259 addGlobalError(ASSISTANT_TYPE.LOCAL, ERROR_TYPE.UNLOAD_FAIL);
// Orchestrates assistant initialisation: ensures config is fetched, downloads
// the model if needed, then loads it on the GPU. Concurrent callers share the
// same in-flight promise (initPromise ref), which is cleared again on both
// success and failure so init can be retried.
// NOTE(review): interior lines (deps array, free-user guard body, destructuring
// of localStateRef) are elided in this view.
263 const initAssistant = useCallback(() => {
264 // Reset download pause state
265 setLocalState((localState) => ({
267 downloadPaused: false,
270 // If the assistant is already initializing, then we simply wait for the end of the initialization
271 if (initPromise.current) {
272 return initPromise.current;
275 initPromise.current = (async () => {
277 * To init the assistant
278 * 1 - We start by downloading the model if not downloaded yet and model is not downloading at the moment
279 * 2 - Then we can load the model on the GPU if not loaded yet and not loading at the moment
282 // Use try catch in case one of the steps fails, so that we don't run the next step
284 let completedDownload;
286 // If assistant config is not set at all, start get config process manually
287 // Typically, this happens when starting the assistant for the first time and choosing local mode.
288 // We trigger the init manually, but the useEffect that handle getting the config has not run yet.
289 if (!llmManager.current) {
290 handleGetAssistantConfig();
293 // Ensure config is set before starting init
294 if (assistantConfigPromiseRef.current) {
295 await assistantConfigPromiseRef.current;
298 // We don't want to go through the init process with free users.
299 // Free users have no trial period, so they shouldn't be able to download the model,
300 // and of course they shouldn't pass in the loading on GPU step too
311 } = localStateRef.current;
// Step 1: download unless already downloaded or a download is in progress.
313 if (!isModelDownloaded && !isModelDownloading) {
314 completedDownload = await downloadModel();
315 } else if (isModelDownloaded) {
316 completedDownload = true;
// Step 2: GPU load only after a completed download and when not already loaded/loading.
319 if (completedDownload && !isModelLoadedOnGPU && !isModelLoadingOnGPU) {
320 await loadModelOnGPU();
323 // Show a notification only when the user had to download the model
324 if (completedDownload && userDownloadedModel) {
326 text: c('Notification').t`The writing assistant is ready to use`,
329 setLocalState((localState) => ({
331 userDownloadedModel: false,
336 // Reset init promise after init or when init failed so that we can
337 // - Start init again if necessary
338 // - Proceed if init completed
339 initPromise.current = undefined;
// Pauses/cancels the ongoing model download and clears the shared init promise
// so a later init can restart cleanly; flags the UI as "download paused".
343 const cancelDownloadModel = () => {
344 if (llmManager.current) {
345 llmManager.current.cancelDownload();
346 initPromise.current = undefined;
347 setLocalState((localState) => ({
349 downloadPaused: true,
// Resumes a paused download by re-running init (fire-and-forget) and clearing
// the paused flag; the manager's startDownload picks up from cached progress.
354 const resumeDownloadModel = () => {
355 void initAssistant();
356 setLocalState((localState) => ({
358 downloadPaused: false,
// Runs one generation for an assistant instance: waits for init, performs the
// LLM action, streams tokens through `callback` (elided in this view) and
// reports telemetry. Cancellation is handled by wrapping the whole flow in a
// promise whose resolver is registered as the "running action" — resolving it
// aborts either the pending init or the live generation.
// NOTE(review): many interior lines (error-path bodies, callback invocation)
// are elided; comments only cover visible statements.
362 const generateResult = async ({ action, callback, assistantID, hasSelection }: GenerateAssistantResult) => {
363 // TODO prevent submit if user made too much harmful requests recently
365 // Do not start multiple actions in the same assistant
366 const runningActionInAssistant = getRunningActionFromAssistantID(assistantID);
367 if (runningActionInAssistant) {
371 // Reset generation errors in this assistant
372 cleanSpecificErrors(assistantID);
374 let isResolved = false;
376 // The generation needs to be stopped in two different cases:
377 // 1 - The assistant is ready (everything is loaded) and the user stops it.
378 //     In that case, we can stop the running action.
379 // 2 - The assistant is still loading, but the user submitted a request and cancelled it.
380 //     In that case, we don't have the running action yet, so we need to cancel the promise.
381 // That's why the entire function is run into a Promise. We can then store it in a ref and cancel it when needed.
382 await new Promise<void>(async (res) => {
383 const resolve = () => {
// Register the resolver so a cancel during init can settle this promise.
387 addRunningAction(assistantID, resolve);
389 const ingestionStart = performance.now();
391 // Start the initialization in case the assistant is not loaded yet.
392 // If it is loaded already, then nothing will be done,
393 // else we will wait for the init process to be finished before starting the generation
394 await initAssistant();
396 // If an error occurred during the init, we set an error.
397 // In that case, we can stop the generation before going further.
401 const ingestionEnd = performance.now();
402 const ingestionTime = ingestionEnd - ingestionStart;
405 // If the promise is resolved, we cancelled it after a user interaction.
406 // We don't want to generate a result anymore.
407 if (llmModel.current && !isResolved) {
408 let promptRejectedOnce = false;
// Per-token callback from the LLM: counts tokens, watches for harmful output
// and drops tokens that arrive after the action was cancelled.
409 const generationCallback = (fulltext: string, details?: GenerationCallbackDetails) => {
410 if (promptRejectedOnce) {
413 generatedTokensNumber.current++;
414 const isHarmful = details?.harmful;
416 // Used to prevent adding additional tokens that we receive after cancelling a running action
417 const isRunningAction = assistantID in runningActionsRef.current;
420 if (isRunningAction) {
// Harmful content detected: remember the rejection and stop the action.
424 promptRejectedOnce = true;
425 cleanRunningActions(assistantID);
429 const generationStart = performance.now();
// Starting a generation kicks off the trial period for eligible users.
431 if (assistantSubscriptionStatus.trialStatus === 'trial-not-started') {
432 await assistantSubscriptionStatus.start();
435 const runningAction = await llmModel.current.performAction(action, generationCallback);
// Re-register with a resolver that also cancels the live LLM action.
437 addRunningAction(assistantID, () => {
438 // Resolve is needed to end parent promise
440 // Here we stop the LLM generation
441 runningAction.cancel();
444 await runningAction.waitForCompletion();
446 // Throw an error if the user made a harmful request
447 if (promptRejectedOnce) {
448 throw new PromptRejectedError();
451 // Send telemetry report
452 const generationEnd = performance.now();
453 const generationTime = generationEnd - generationStart;
454 sendRequestAssistantReport({
455 assistantType: ASSISTANT_TYPE.LOCAL,
456 generationType: getGenerationType(action),
457 selectionType: hasSelection
458 ? GENERATION_SELECTION_TYPE.HAS_SELECTION
459 : GENERATION_SELECTION_TYPE.NO_SELECTION,
462 tokensGenerated: generatedTokensNumber.current,
464 generatedTokensNumber.current = 0;
467 // Sometimes the model is being unloaded automatically by the llm lib.
468 // If we detect that the model is not loaded during generate, load it again and retry generation
469 if (e === MODEL_UNLOADED) {
470 await loadModelOnGPU();
471 return generateResult({ action, callback, assistantID, hasSelection });
// Error mapping: harmful-prompt rejection vs generic generation failure.
474 if (e.name === 'PromptRejectedError') {
477 assistantType: ASSISTANT_TYPE.LOCAL,
478 errorType: ERROR_TYPE.GENERATION_HARMFUL,
483 assistantType: ASSISTANT_TYPE.LOCAL,
484 errorType: ERROR_TYPE.GENERATION_FAIL,
487 traceInitiativeError('assistant', e);
490 // Reset assistant result when an error occurred while generating
491 // Otherwise, on next submit the previous result will be displayed for a few ms
495 // Reset the generating state
496 cleanRunningActions(assistantID);
// Fully resets the local assistant: cancels an in-flight download, unloads the
// GPU model, stops all running actions and zeroes the download/load state.
// NOTE(review): the running-action cleanup body is elided in this view.
500 const resetAssistantState = () => {
501 // Cancel model downloading
502 if (initPromise.current) {
503 void cancelDownloadModel();
505 // Unload model from GPU
506 if (isModelLoadedOnGPU) {
507 void unloadModelOnGPU();
509 // Cancel all running actions
510 if (Object.keys(runningActions).length) {
// Reset the state to default values
515 setLocalState((localState) => ({
517 isModelDownloaded: false,
518 isModelDownloading: false,
519 isModelLoadedOnGPU: false,
520 isModelLoadingOnGPU: false,
521 downloadPaused: false,
522 downloadModelSize: 0,
523 downloadReceivedBytes: 0,
527 // Unload the model after some time of non usage
528 // If model is loaded on the GPU, check every X minutes if user is busy
529 // Reset the timeout completely when user is generating a new result
// No-op while the model is not on the GPU; interval cleanup lines are elided
// in this view.
531 if (!isModelLoadedOnGPU) {
535 const id = setInterval(() => {
// Keep the model loaded while the desktop app window is focused (user likely
// still active); domIsBusy presumably covers the web case — TODO confirm.
539 if (isElectronApp && document.hasFocus()) {
543 void unloadModelOnGPU();
544 }, UNLOAD_ASSISTANT_TIMEOUT);
549 }, [isModelLoadedOnGPU, runningActions]);
// Public AssistantHooksProps surface returned to consumers: config, download
// progress, GPU state and action controls. NOTE(review): most fields of this
// object are elided in this view.
552 assistantConfig: assistantConfigRef.current,
558 downloadReceivedBytes,
566 // GPU loading related
574 cancelRunningAction: cleanRunningActions,
576 closeAssistant: closeAssistant(cleanRunningActions),
582 interface UseRunningActionsProps {
583 addSpecificError: ReturnType<typeof useAssistantCommons>['addSpecificError'];
// Tracks per-assistant running generations as a map of assistantID -> resolver.
// Invoking a resolver settles the wrapper promise created in generateResult,
// which stops (or prevents) that assistant's generation.
// NOTE(review): this hook's closing lines extend past the visible view.
586 function useRunningActions({ addSpecificError }: UseRunningActionsProps) {
587 const [runningActions, setRunningAction, runningActionsRef] = useStateRef<AssistantRunningActions>({});
// Registers (or replaces) the resolver for an assistant's running action.
589 const addRunningAction = (assistantID: string, resolver: () => void) => {
590 setRunningAction((runningActions) => ({
592 [assistantID]: resolver,
// Resolves and removes the running action for the given assistant; failures
// during cancellation are traced and reported, not thrown.
596 const cleanRunningActions = (assistantID: string) => {
598 const runningAction = runningActionsRef.current[assistantID];
601 setRunningAction((runningAction) => {
// Copy after delete so React sees a new object reference.
602 delete runningAction[assistantID];
603 return { ...runningAction };
607 traceInitiativeError('assistant', e);
610 assistantType: ASSISTANT_TYPE.LOCAL,
611 errorType: ERROR_TYPE.GENERATION_CANCEL_FAIL,
619 getRunningActionFromAssistantID: (assistantID: string) => assistantID in runningActions,
620 resetResolvers: () => setRunningAction({}),