Merge branch 'INDA-330-pii-update' into 'main'
[ProtonMail-WebClient.git] / packages / llm / lib / runningAction.ts
blob 8476f4041cdc16fdb15d13cb01cd97bf12cdd5d9
import type { GenerationConfig, WebWorkerEngine } from '@mlc-ai/web-llm';

import { getTransformForAction } from '@proton/llm/lib/actions';
import type { Action, GenerationCallback, PromiseReject, PromiseResolve, RunningAction } from '@proton/llm/lib/types';
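
/**
 * Wraps a single in-flight generation on a WebLLM WebWorkerEngine for one Action:
 * streams partial output to the caller through a GenerationCallback, tracks the
 * running/done/cancelled lifecycle, and exposes cancellation and a completion promise.
 */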
export class BaseRunningAction implements RunningAction {
    private action_: Action;

    protected chat: WebWorkerEngine;

    protected running: boolean;

    protected done: boolean;

    protected cancelled: boolean;

    protected finishedPromise: Promise<void>;

    protected finishedPromiseSignals: { resolve: PromiseResolve; reject: PromiseReject } | undefined;

    // @ts-ignore
    protected generation: Promise<void>;

    constructor(prompt: string, callback: GenerationCallback, chat: WebWorkerEngine, action: Action, stop?: string[]) {
        // Timing state used to estimate generation speed from the intervals between streamed chunks.
        let firstTimestamp: number | null = null;
        let lastTimestamp: number | null = null;
        let intervals: number[] = [];

        const transform = getTransformForAction(action);

        // Called by the engine for each streamed chunk with the full text generated so far.
        const generateProgressCallback = (_step: number, message: string) => {
            const now = Date.now();
            let slow = false;
            if (firstTimestamp === null) {
                firstTimestamp = now;
            }
            if (lastTimestamp !== null) {
                const intervalMs = now - lastTimestamp;
                const elapsedMs = now - firstTimestamp;
                const elapsedSec = elapsedMs / 1000;
                // Mean of the last 10 inter-token intervals gives the current token rate;
                // report the generation as slow once it has run for over 5s at under 2 tokens/sec.
                intervals = [...intervals, intervalMs].slice(-10);
                const meanIntervalMs = intervals.reduce((a, b) => a + b, 0) / intervals.length;
                const meanIntervalSec = meanIntervalMs / 1000.0;
                const tokenPerSec = 1.0 / meanIntervalSec;
                slow = elapsedSec > 5 && tokenPerSec < 2;
            }
            lastTimestamp = now;

            // The action-specific transform returns undefined when the text must not be shown,
            // which is surfaced to the caller as `harmful`.
            const fulltext = transform(message);
            const harmful = fulltext === undefined;
            callback(fulltext || '', { slow, harmful });
        };

        // Expose the resolve/reject handlers so the generation chain below can settle this promise.
        this.finishedPromise = new Promise<void>((resolve: PromiseResolve, reject: PromiseReject) => {
            this.finishedPromiseSignals = { resolve, reject };
        });

        // Stop on chat-template markers and on a literal "[Your Name]" signature line,
        // plus any caller-supplied stop strings.
        const stopStrings = ['<|', '\n[Your Name]\n', ...(stop || [])];
        const genConfig: GenerationConfig = {
            stop: stopStrings,
        };

        // Start generating immediately. Failures mark the action as done and reject the
        // finished promise; in every case the engine's chat state is reset afterwards.
        this.generation = chat
            .generate(prompt, generateProgressCallback, undefined, genConfig)
            .then(() => {
                this.finishedPromiseSignals!.resolve();
            })
            .catch((e) => {
                this.done = true;
                this.finishedPromiseSignals!.reject(e);
            })
            .finally(async () => {
                this.running = false;
                await chat.resetChat();
            });

        this.chat = chat;
        this.running = true;
        this.done = false;
        this.cancelled = false;
        this.action_ = action;
    }

    action(): Action {
        return this.action_;
    }

    isRunning(): boolean {
        return this.running;
    }

    isDone(): boolean {
        return this.done;
    }

    isCancelled(): boolean {
        return this.cancelled;
    }

    // Interrupts an in-flight generation; returns false if nothing was running.
    cancel(): boolean {
        if (this.running) {
            this.chat.interruptGenerate();
            this.running = false;
            this.done = false;
            this.cancelled = true;
            return true;
        }
        return false;
    }

    waitForCompletion(): Promise<void> {
        return this.finishedPromise;
    }
}
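
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original module). `exampleUsage` and
// `buildPrompt` are hypothetical stand-ins for however callers in @proton/llm build
// a prompt from an Action; the console logging is likewise only for demonstration.
// It shows the intended lifecycle: construct with an engine and a streaming callback,
// then await completion (or call `cancel()` to interrupt the generation).
async function exampleUsage(
    action: Action,
    engine: WebWorkerEngine,
    buildPrompt: (action: Action) => string // hypothetical prompt builder
): Promise<void> {
    const callback: GenerationCallback = (fulltext, info) => {
        if (!info?.harmful) {
            console.log(fulltext); // full text generated so far
        }
        if (info?.slow) {
            console.warn('generation is running below ~2 tokens/sec');
        }
    };
    const running = new BaseRunningAction(buildPrompt(action), callback, engine, action);
    await running.waitForCompletion();
}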