1 import type { GenerationConfig, WebWorkerEngine } from '@mlc-ai/web-llm';
3 import { getTransformForAction } from '@proton/llm/lib/actions';
4 import type { Action, GenerationCallback, PromiseReject, PromiseResolve, RunningAction } from '@proton/llm/lib/types';
/**
 * A single in-flight text generation on a web-llm engine.
 *
 * The constructor starts the generation immediately; callers can observe
 * completion through `waitForCompletion()` and query cancellation through
 * `isCancelled()`.
 */
6 export class BaseRunningAction implements RunningAction {
// The action this run was created for; assigned in the constructor.
7 private action_: Action;
// Engine handle; `this.chat.interruptGenerate()` is called on it when the run is interrupted.
9 protected chat: WebWorkerEngine;
// Cleared (set to false) when generation is interrupted.
11 protected running: boolean;
// NOTE(review): no read or write of `done` is visible here; presumably set once generation finishes — confirm.
13 protected done: boolean;
// Initialised to false by the constructor, set to true on interruption, reported by `isCancelled()`.
15 protected cancelled: boolean;
// Promise handed out by `waitForCompletion()`; settled through `finishedPromiseSignals`.
17 protected finishedPromise: Promise<void>;
// The resolve/reject pair captured from the `finishedPromise` executor.
19 protected finishedPromiseSignals: { resolve: PromiseResolve; reject: PromiseReject } | undefined;
// The promise chain built on `chat.generate(...)` in the constructor.
22 protected generation: Promise<void>;
/**
 * Starts generating text for `prompt` and wires up progress reporting and
 * completion signalling.
 *
 * @param prompt - Text passed to `chat.generate`.
 * @param callback - Invoked on every progress update with the transformed text
 *   and a `{ slow, harmful }` status object.
 * @param chat - Engine used to generate; its chat is reset once generation settles.
 * @param action - Selects the output transform (via `getTransformForAction`) and
 *   is stored on the instance.
 * @param stop - Optional extra stop strings, appended to the built-in ones.
 */
24 constructor(prompt: string, callback: GenerationCallback, chat: WebWorkerEngine, action: Action, stop?: string[]) {
// Timing state shared across progress callbacks, used to estimate generation speed.
25 let firstTimestamp: number | null = null;
26 let lastTimestamp: number | null = null;
// Most recent gaps (in ms) between consecutive progress callbacks.
27 let intervals: number[] = [];
// Action-specific post-processing applied to the raw model output.
29 const transform = getTransformForAction(action);
// Progress handler given to `chat.generate`; `message` is forwarded to `transform`.
31 const generateProgressCallback = (_step: number, message: string) => {
32 const now = Date.now();
// NOTE(review): the statements that update firstTimestamp/lastTimestamp and declare `slow`
// are not visible in this excerpt — confirm against the full file.
34 if (firstTimestamp === null) {
37 if (lastTimestamp !== null) {
38 const intervalMs = now - lastTimestamp;
39 const elapsedMs = now - firstTimestamp;
40 const elapsedSec = elapsedMs / 1000;
// Keep a sliding window of at most the 10 latest intervals.
41 intervals = [...intervals, intervalMs].slice(-10);
42 const meanIntervalMs = intervals.reduce((a, b) => a + b, 0) / intervals.length;
43 let meanIntervalSec = meanIntervalMs / 1000.0;
// Callback rate per second; treated as tokens/sec (assumes one callback per token — TODO confirm).
44 const tokenPerSec = 1.0 / meanIntervalSec;
// Flag as slow only after more than 5 s have elapsed and the rate is under 2 per second.
45 slow = elapsedSec > 5 && tokenPerSec < 2;
49 const fulltext = transform(message);
// An `undefined` result from the transform is what marks the output as harmful.
50 const harmful = fulltext === undefined;
// A falsy transform result is reported to the caller as an empty string.
51 callback(fulltext || '', { slow, harmful });
// The Promise executor runs synchronously, so `finishedPromiseSignals` is populated
// before the generation chain below can settle (hence the `!` assertions there).
54 this.finishedPromise = new Promise<void>((resolve: PromiseResolve, reject: PromiseReject) => {
55 this.finishedPromiseSignals = { resolve, reject };
// Built-in stop strings followed by any caller-supplied ones.
58 const stopStrings = ['<|', '\n[Your Name]\n', ...(stop || [])];
// NOTE(review): the body of this config is not visible here; presumably it carries `stopStrings` — confirm.
59 const genConfig: GenerationConfig = {
63 this.generation = chat
64 .generate(prompt, generateProgressCallback, undefined, genConfig)
// Settle `finishedPromise` to mirror the outcome of the generation.
66 this.finishedPromiseSignals!.resolve();
70 this.finishedPromiseSignals!.reject(e);
// Runs after the generation chain settles, whatever the outcome.
72 .finally(async () => {
74 await chat.resetChat();
80 this.cancelled = false;
81 this.action_ = action;
88 isRunning(): boolean {
// Returns the current value of the `cancelled` flag.
96 isCancelled(): boolean {
97 return this.cancelled;
102 this.chat.interruptGenerate();
103 this.running = false;
105 this.cancelled = true;
111 waitForCompletion(): Promise<void> {
112 return this.finishedPromise;