1 import { stringToUtf8Array, utf8ArrayToString } from '@proton/crypto/lib/utils'
2 import type { DocumentKeys, NodeMeta, PublicNodeMeta } from '@proton/drive-store'
3 import type { EncryptMessage } from '../../UseCase/EncryptMessage'
4 import type { AnonymousEncryptionMetadata, EncryptionMetadata } from '../../Types/EncryptionMetadata'
5 import type { LoggerInterface } from '@proton/utils/logs'
6 import { WebsocketConnection } from '../../Realtime/WebsocketConnection'
8 InternalEventBusInterface,
9 WebsocketConnectionInterface,
12 } from '@proton/docs-shared'
13 import { BroadcastSource, ProcessedIncomingRealtimeEventMessage, assertUnreachableAndLog } from '@proton/docs-shared'
14 import type { GetRealtimeUrlAndToken } from '../../UseCase/CreateRealtimeValetToken'
15 import type { WebsocketServiceInterface } from './WebsocketServiceInterface'
16 import metrics from '@proton/metrics'
18 ServerMessageWithEvents,
19 ServerMessageWithDocumentUpdates,
20 ServerMessageWithMessageAcks,
21 ConnectionReadyPayload,
22 } from '@proton/docs-proto'
29 CreateClientEventMessage,
31 CreateDocumentUpdateMessage,
32 DocumentUpdateVersion,
33 CreateClientMessageWithDocumentUpdates,
34 ConnectionCloseReason,
36 } from '@proton/docs-proto'
37 import { c } from 'ttag'
38 import type { DecryptMessage } from '../../UseCase/DecryptMessage'
39 import type { DocumentConnectionRecord } from './DocumentConnectionRecord'
40 import { GenerateUUID } from '../../Util/GenerateUuid'
41 import { AckLedger } from './AckLedger/AckLedger'
42 import type { AckLedgerInterface } from './AckLedger/AckLedgerInterface'
43 import type { WebsocketConnectionEventPayloads } from '../../Realtime/WebsocketEvent/WebsocketConnectionEventPayloads'
44 import { WebsocketConnectionEvent } from '../../Realtime/WebsocketEvent/WebsocketConnectionEvent'
45 import { DocsApiErrorCode } from '@proton/shared/lib/api/docs'
46 import { UpdateDebouncer } from './Debouncer/UpdateDebouncer'
47 import { UpdateDebouncerEventType } from './Debouncer/UpdateDebouncerEventType'
48 import { DocumentDebounceMode } from './Debouncer/DocumentDebounceMode'
49 import { PostApplicationError } from '../../Application/ApplicationEvent'
50 import type { MetricService } from '../Metrics/MetricService'
51 import type { UserState } from '../../State/UserState'
52 import { isPrivateDocumentKeys, type PublicDocumentKeys } from '../../Types/DocumentEntitlements'
/**
 * Realtime service for documents: keeps one connection record per document
 * (keyed by link id) and a single acknowledgement ledger that posted document
 * updates are registered with.
 */
56 export class WebsocketService implements WebsocketServiceInterface {
// Connection records, looked up by the document's link id.
57 private connections: Record<LinkID, DocumentConnectionRecord> = {}
// The ledger is given the logger and a bound handleLedgerStatusChangeCallback.
// NOTE(review): this initializer reads `this.logger`, which is declared below as a
// parameter property; that only works if parameter properties are assigned before
// field initializers run — confirm against the compiler's class-field settings.
58 readonly ledger: AckLedgerInterface = new AckLedger(this.logger, this.handleLedgerStatusChangeCallback.bind(this))
// Private parameter properties — presumably the constructor's parameter list
// (NOTE(review): the `constructor(` line itself is not visible in this view).
61 private userState: UserState,
62 private _createRealtimeValetToken: GetRealtimeUrlAndToken,
63 private _encryptMessage: EncryptMessage,
64 private _decryptMessage: DecryptMessage,
65 private logger: LoggerInterface,
66 private eventBus: InternalEventBusInterface,
67 private metricService: MetricService,
68 private appVersion: string,
// Listen for page unload so handleWindowUnload can react to unsaved work.
70 window.addEventListener('beforeunload', this.handleWindowUnload)
// Detaches the same handler reference that was registered above.
// NOTE(review): presumably belongs to a teardown method whose header is not visible here.
74 window.removeEventListener('beforeunload', this.handleWindowUnload)
// Visits the debouncer and connection of every record; the loop body is not visible in this view.
77 for (const { debouncer, connection } of Object.values(this.connections)) {
/**
 * Callback passed (bound) to the AckLedger in the `ledger` field initializer.
 * Publishes an AckStatusChange event on the event bus.
 */
85 handleLedgerStatusChangeCallback(): void {
86 this.eventBus.publish<WebsocketConnectionEventPayloads[WebsocketConnectionEvent.AckStatusChange]>({
87 type: WebsocketConnectionEvent.AckStatusChange,
/**
 * `beforeunload` listener. Declared as an arrow-function property so that `this`
 * still refers to the service when the function is handed to add/removeEventListener.
 *
 * @param event - the browser's BeforeUnloadEvent; preventDefault() is called on it
 *   when there is unsaved or unacknowledged work.
 */
94 handleWindowUnload = (event: BeforeUnloadEvent): void => {
95 const connections = Object.values(this.connections)
// First pass: any debouncer still holding updates blocks the unload.
// NOTE(review): a line between the check and preventDefault() is not visible in this view.
97 for (const { debouncer } of connections) {
98 if (debouncer.hasPendingUpdates()) {
100 event.preventDefault()
// Second pass: if the ledger reports concerning or errored messages,
// trigger a retry of all failed updates and block the unload as well.
104 if (this.ledger.hasConcerningMessages() || this.ledger.hasErroredMessages()) {
105 this.retryAllFailedDocumentUpdates()
106 event.preventDefault()
/**
 * Walks every connection record and checks whether its debouncer has pending updates.
 * NOTE(review): the statement executed inside the `if` is not visible in this view
 * (presumably a flush of the debouncer — confirm).
 */
110 flushPendingUpdates(): void {
111 const connections = Object.values(this.connections)
113 for (const { debouncer } of connections) {
114 if (debouncer.hasPendingUpdates()) {
// Connection factory for a single document (NOTE(review): the method header line is not
// visible in this view; the log message suggests it creates a connection).
// Parameters: the document's node meta, its keys, and an options object whose
// `commitId` getter is evaluated each time a realtime token is requested.
// Returns a WebsocketConnectionInterface.
121 document: NodeMeta | PublicNodeMeta,
122 keys: DocumentKeys | PublicDocumentKeys,
123 options: { commitId: () => string | undefined },
124 ): WebsocketConnectionInterface {
125 this.logger.info(`Creating connection for document ${document.linkId}`)
// Callbacks handed to the WebsocketConnection; most of them translate a connection
// lifecycle notification into an event-bus publication.
127 const callbacks: WebsocketCallbacks = {
128 onConnecting: () => {
129 this.eventBus.publish<WebsocketConnectionEventPayloads[WebsocketConnectionEvent.Connecting]>({
130 type: WebsocketConnectionEvent.Connecting,
// Publishes ConnectionEstablishedButNotYetReady (the callback's name line is not visible here).
138 this.eventBus.publish<
139 WebsocketConnectionEventPayloads[WebsocketConnectionEvent.ConnectionEstablishedButNotYetReady]
141 type: WebsocketConnectionEvent.ConnectionEstablishedButNotYetReady,
// The reason supplied by the connection is forwarded as `serverReason`.
148 onFailToConnect: (reason) => {
149 this.eventBus.publish<WebsocketConnectionEventPayloads[WebsocketConnectionEvent.FailedToConnect]>({
150 type: WebsocketConnectionEvent.FailedToConnect,
153 serverReason: reason,
158 onClose: (reason) => {
159 this.eventBus.publish<WebsocketConnectionEventPayloads[WebsocketConnectionEvent.Disconnected]>({
160 type: WebsocketConnectionEvent.Disconnected,
163 serverReason: reason,
// Incoming binary frames are dispatched asynchronously.
// NOTE(review): the promise is discarded with `void`, and handleConnectionMessage can
// throw ('Connection not found', 'Unknown message type', decryption failures), so a
// rejection here would go unhandled — confirm this is intended.
168 onMessage: (message) => {
169 void this.handleConnectionMessage(document, message)
172 onEncryptionError: (error) => {
173 this.eventBus.publish<WebsocketConnectionEventPayloads[WebsocketConnectionEvent.EncryptionError]>({
174 type: WebsocketConnectionEvent.EncryptionError,
// Token failures: CommitIdOutOfSync gets its own event; the error code is also logged.
182 onFailToGetToken: (errorCode) => {
183 if (errorCode === DocsApiErrorCode.CommitIdOutOfSync) {
184 this.eventBus.publish<
185 WebsocketConnectionEventPayloads[WebsocketConnectionEvent.FailedToGetTokenCommitIdOutOfSync]
187 type: WebsocketConnectionEvent.FailedToGetTokenCommitIdOutOfSync,
193 this.logger.error(`Failed to get token: ${errorCode}`)
// Requests a realtime URL/token for this document at the current commit id; on success
// the value is also passed to handleRetrievedValetTokenResult.
197 getUrlAndToken: async () => {
198 const result = await this._createRealtimeValetToken.execute(document, options.commitId())
200 if (!result.isFailed()) {
201 this.handleRetrievedValetTokenResult(result.getValue())
208 const connection = new WebsocketConnection(callbacks, this.metricService, this.logger, this.appVersion)
// The debouncer reports via events: DidFlush carries the merged update to encrypt and
// broadcast, WillFlush is surfaced as a Saving event.
// NOTE(review): the DidFlush handler's promise is also discarded with `void`, and
// handleDocumentUpdateDebouncerFlush throws when encryption fails.
210 const debouncer = new UpdateDebouncer(document, this.logger, (event) => {
211 if (event.type === UpdateDebouncerEventType.DidFlush) {
212 void this.handleDocumentUpdateDebouncerFlush(document, event.mergedUpdate)
213 } else if (event.type === UpdateDebouncerEventType.WillFlush) {
214 this.eventBus.publish<WebsocketConnectionEventPayloads[WebsocketConnectionEvent.Saving]>({
215 type: WebsocketConnectionEvent.Saving,
// Store the record under the document's link id (an existing entry for the same id is overwritten).
223 this.connections[document.linkId] = {
/**
 * Copies the `includeDocumentNameInEmails` preference from a retrieved URL/token result
 * into the user state property `currentDocumentEmailDocTitleEnabled`.
 */
233 handleRetrievedValetTokenResult(result: RealtimeUrlAndToken): void {
234 this.userState.setProperty('currentDocumentEmailDocTitleEnabled', result.preferences.includeDocumentNameInEmails)
/**
 * Type guard: accepts a value as ConnectionReadyPayload when it is an object whose
 * `clientUpgradeRecommended` and `clientUpgradeRequired` fields are both booleans.
 *
 * NOTE(review): `typeof null` is also 'object', so a null check must sit between the
 * first condition and the property reads (that line is not visible in this view —
 * confirm). The `any` parameter could be `unknown` without affecting callers.
 */
237 isConnectionReadyPayload(obj: any): obj is ConnectionReadyPayload {
239 typeof obj === 'object' &&
241 typeof obj.clientUpgradeRecommended === 'boolean' &&
242 typeof obj.clientUpgradeRequired === 'boolean'
/**
 * Handles the server's "ready" notification for a document connection.
 *
 * @param record - the connection record the notification arrived on
 * @param content - UTF-8 encoded JSON that may describe a ConnectionReadyPayload
 */
246 onDocumentConnectionReadyToBroadcast(record: DocumentConnectionRecord, content: Uint8Array): void {
247 this.logger.info('Received ready to broadcast message from RTS')
// Unblock both sides: the connection may now send, the debouncer may now flush.
249 record.connection.markAsReadyToAcceptMessages()
250 record.debouncer.markAsReadyToFlush()
// Stays undefined unless the content parses into a valid payload.
252 let readinessInformation: ConnectionReadyPayload | undefined
254 const parsed = JSON.parse(utf8ArrayToString(content))
255 if (this.isConnectionReadyPayload(parsed)) {
256 readinessInformation = parsed
// NOTE(review): presumably the catch branch of a try around the parse above — the
// try/catch lines are not visible in this view.
259 this.logger.error('Unable to parse content from ConnectionReady message')
262 this.eventBus.publish<WebsocketConnectionEventPayloads[WebsocketConnectionEvent.ConnectedAndReady]>({
263 type: WebsocketConnectionEvent.ConnectedAndReady,
265 document: record.document,
266 readinessInformation: readinessInformation,
// Resend anything still unacknowledged now that the connection accepts messages.
270 this.retryFailedDocumentUpdatesForDoc(record.document)
/**
 * Invokes retryFailedDocumentUpdatesForDoc once for every connection record's document.
 */
273 retryAllFailedDocumentUpdates(): void {
274 this.logger.info('Retrying all failed document updates')
276 for (const record of Object.values(this.connections)) {
277 this.retryFailedDocumentUpdatesForDoc(record.document)
/**
 * Re-broadcasts the ledger's unacknowledged updates over the given document's connection.
 *
 * @param document - node meta whose link id selects the connection record
 * @throws Error('Connection not found') — presumably when no record exists for the link id
 *   (the guard line is not visible in this view)
 */
281 retryFailedDocumentUpdatesForDoc(document: NodeMeta | PublicNodeMeta): void {
282 const record = this.getConnectionRecord(document.linkId)
284 throw new Error('Connection not found')
// NOTE(review): the ledger is queried without any document argument, so the list is not
// filtered per document here; combined with retryAllFailedDocumentUpdates this would send
// the same updates once per connection — confirm that is intended.
287 const failedUpdates = this.ledger.getUnacknowledgedUpdates()
288 if (failedUpdates.length === 0) {
292 this.logger.info(`Retrying ${failedUpdates.length} failed document updates ${failedUpdates.map((u) => u.uuid)}`)
// All pending updates are packed into one client message and serialized once.
294 const message = CreateClientMessageWithDocumentUpdates({
295 updates: failedUpdates,
298 const messageWrapper = new ClientMessage({ documentUpdatesMessage: message })
300 const binary = messageWrapper.serializeBinary()
302 this.logger.info(`Broadcasting failed document update of size: ${binary.byteLength} bytes`)
// Fire-and-forget: the broadcast promise is intentionally not awaited.
304 void record.connection.broadcastMessage(binary, BroadcastSource.RetryingMessagesAfterReconnect)
306 metrics.docs_document_updates_total.increment({})
/**
 * Looks up the connection record stored for a link id.
 *
 * @returns the record, or undefined when nothing is stored under that id
 */
309 getConnectionRecord(linkId: LinkID): DocumentConnectionRecord | undefined {
310 return this.connections[linkId]
/**
 * Reports whether the connection stored for the given document is connected.
 * NOTE(review): the handling of a missing record sits on lines not visible in this view.
 */
313 isConnected(nodeMeta: NodeMeta | PublicNodeMeta): boolean {
314 const record = this.getConnectionRecord(nodeMeta.linkId)
319 return record.connection.isConnected()
/**
 * Connects the document's connection immediately unless it is already connected.
 *
 * @throws Error('Connection not found') — presumably when no record exists for the link id
 *   (the guard line is not visible in this view)
 */
322 async reconnectToDocumentWithoutDelay(nodeMeta: NodeMeta | PublicNodeMeta): Promise<void> {
323 const record = this.getConnectionRecord(nodeMeta.linkId)
325 throw new Error('Connection not found')
// Already connected: only log (presumably followed by an early return — not visible here).
328 if (record.connection.isConnected()) {
329 this.logger.info(`Connection is already connected`)
333 this.logger.info(`Reconnecting to document without delay`)
335 await record.connection.connect()
/**
 * Encrypts a merged update produced by the debouncer, registers it with the ack ledger
 * and broadcasts it over the document's connection.
 *
 * @param nodeMeta - document whose connection record is used
 * @param mergedUpdate - the merged update bytes delivered with the debouncer's DidFlush event
 * @throws Error('Connection not found') — presumably when no record exists for the link id
 *   (the guard line is not visible in this view); encryptMessage also throws on failure
 */
338 async handleDocumentUpdateDebouncerFlush(
339 nodeMeta: NodeMeta | PublicNodeMeta,
340 mergedUpdate: Uint8Array,
342 const record = this.getConnectionRecord(nodeMeta.linkId)
344 throw new Error('Connection not found')
347 const { keys, connection } = record
// The author address is only set for private keys; otherwise it is left undefined.
349 const metadata: EncryptionMetadata | AnonymousEncryptionMetadata = {
350 authorAddress: isPrivateDocumentKeys(keys) ? keys.userOwnAddress : undefined,
351 timestamp: Date.now(),
352 version: DocumentUpdateVersion.V1,
355 const encryptedContent = await this.encryptMessage(
360 BroadcastSource.DocumentBufferFlush,
// Fresh identifier for this update; it is included in the log line below.
363 const uuid = GenerateUUID()
365 const message = CreateDocumentUpdateMessage({
366 content: encryptedContent,
371 const messageWrapper = new ClientMessage({ documentUpdatesMessage: message })
372 const binary = messageWrapper.serializeBinary()
374 this.logger.info(`Broadcasting document update ${uuid} of size: ${binary.byteLength} bytes`)
// Register with the ledger before sending.
376 this.ledger.messagePosted(message)
// Fire-and-forget: the broadcast promise is intentionally not awaited.
378 void connection.broadcastMessage(binary, BroadcastSource.DocumentBufferFlush)
380 metrics.docs_document_updates_total.increment({})
/**
 * Queues one or more raw document updates on the document's debouncer; nothing is
 * encrypted or sent in the lines visible here.
 *
 * @param nodeMeta - document whose connection record is used
 * @param rawContent - a single update or an array of updates
 * @throws Error('Connection not found') — presumably when no record exists for the link id
 *   (the guard line is not visible in this view)
 */
383 async sendDocumentUpdateMessage(
384 nodeMeta: NodeMeta | PublicNodeMeta,
385 rawContent: Uint8Array | Uint8Array[],
387 const record = this.getConnectionRecord(nodeMeta.linkId)
389 throw new Error('Connection not found')
392 const { debouncer } = record
// Each update is wrapped in a DecryptedValue; a single update becomes a one-element array.
// NOTE(review): the arrow parameter `c` shadows the `c` imported from 'ttag' within this
// expression — harmless here, but a rename would avoid confusion.
394 debouncer.addUpdates(
395 Array.isArray(rawContent) ? rawContent.map((c) => new DecryptedValue(c)) : [new DecryptedValue(rawContent)],
/**
 * Encrypts and broadcasts a client event over the document's connection, unless the
 * event is suppressed by the debounce mode or the connection cannot broadcast yet.
 *
 * @param nodeMeta - document whose connection record is used
 * @param rawContent - unencrypted event content
 * @param source - broadcast source, forwarded to encryption and to the connection
 * NOTE(review): a `type` parameter is used below but its declaration line is not visible
 * in this view.
 * @throws Error('Connection not found') — presumably when no record exists for the link id
 *   (the guard line is not visible in this view); encryptMessage also throws on failure
 */
399 async sendEventMessage(
400 nodeMeta: NodeMeta | PublicNodeMeta,
401 rawContent: Uint8Array,
403 source: BroadcastSource,
405 const record = this.getConnectionRecord(nodeMeta.linkId)
407 throw new Error('Connection not found')
410 const { keys, connection, debouncer } = record
// In single-player mode, presence and comment events are logged and not sent.
412 if (debouncer.getMode() === DocumentDebounceMode.SinglePlayer) {
413 const eventsThatShouldNotBeSentIfInSinglePlayerMode: EventTypeEnum[] = [
414 EventTypeEnum.ClientIsBroadcastingItsPresenceState,
415 EventTypeEnum.ClientHasSentACommentMessage,
418 if (eventsThatShouldNotBeSentIfInSinglePlayerMode.includes(type)) {
419 this.logger.info('Not in real time mode. Not sending event:', EventTypeEnum[type])
// NOTE(review): `record.connection` is the same object as the destructured `connection`
// above; using one spelling throughout would be more consistent.
424 if (!record.connection.canBroadcastMessages()) {
425 this.logger.info(`Not sending event ${EventTypeEnum[type]} because connection is not ready`)
// The author address is only set for private keys; otherwise it is left undefined.
429 const metadata: EncryptionMetadata | AnonymousEncryptionMetadata = {
430 authorAddress: isPrivateDocumentKeys(keys) ? keys.userOwnAddress : undefined,
431 timestamp: Date.now(),
432 version: ClientEventVersion.V1,
435 const encryptedContent = await this.encryptMessage(rawContent, metadata, nodeMeta, keys, source)
436 const message = CreateClientEventMessage({
437 content: encryptedContent,
442 const messageWrapper = new ClientMessage({ eventsMessage: message })
443 const binary = messageWrapper.serializeBinary()
446 `Broadcasting event message of type: ${EventTypeEnum[type]} from source: ${source} size: ${binary.byteLength} bytes`,
// Fire-and-forget: the broadcast promise is intentionally not awaited.
449 void connection.broadcastMessage(binary, source)
/**
 * Encrypts content with the injected EncryptMessage use case.
 *
 * On failure it publishes an EncryptionError event, counts a comments error metric when
 * the source is the comments controller, logs, posts an application error and throws.
 *
 * NOTE(review): a `content` parameter is used below but its declaration line is not
 * visible in this view.
 * @param metadata - encryption metadata passed to the use case
 * @param keys - document keys passed to the use case
 * @param source - only inspected on failure, to decide whether to count a comments error
 * @returns the encrypted bytes, copied into a new Uint8Array
 * @throws Error when encryption fails
 */
452 async encryptMessage(
454 metadata: EncryptionMetadata | AnonymousEncryptionMetadata,
455 document: NodeMeta | PublicNodeMeta,
456 keys: DocumentKeys | PublicDocumentKeys,
457 source: BroadcastSource,
458 ): Promise<Uint8Array> {
459 const result = await this._encryptMessage.execute(content, metadata, keys)
461 if (result.isFailed()) {
// Translated user-facing text, reused for the application error below.
462 const message = c('Error').t`We are having trouble saving recent edits. Please refresh the page.`
464 this.eventBus.publish<WebsocketConnectionEventPayloads[WebsocketConnectionEvent.EncryptionError]>({
465 type: WebsocketConnectionEvent.EncryptionError,
472 if (source === BroadcastSource.CommentsController) {
473 metrics.docs_comments_error_total.increment({
474 reason: 'encryption_error',
478 this.logger.error('Unable to encrypt realtime message', result.getError())
// The same string is used for both the error body and its title.
480 PostApplicationError(this.eventBus, {
481 translatedError: message,
482 translatedErrorTitle: message,
486 throw new Error(`Unable to encrypt message: ${result.getError()}`)
489 return new Uint8Array(result.getValue())
/**
 * Deserializes a binary server message and routes it by type to the document-updates,
 * events or acknowledgement handler.
 *
 * @param document - document whose connection record is used
 * @param data - serialized ServerMessage bytes
 * @throws Error('Connection not found') — presumably when no record exists for the link id
 *   (the guard line is not visible in this view)
 * @throws Error('Unknown message type') when none of the three type checks match
 */
492 async handleConnectionMessage(document: NodeMeta | PublicNodeMeta, data: Uint8Array): Promise<void> {
493 const record = this.getConnectionRecord(document.linkId)
495 throw new Error('Connection not found')
498 const message = ServerMessage.deserializeBinary(data)
499 const type = ServerMessageType.create(message.type)
501 if (type.hasDocumentUpdates()) {
502 await this.handleIncomingDocumentUpdatesMessage(record, message.documentUpdatesMessage)
503 } else if (type.hasEvents()) {
504 await this.handleIncomingEventsMessage(record, message.eventsMessage)
505 } else if (type.isMessageAck()) {
506 await this.handleAckMessage(record, message.acksMessage)
508 throw new Error('Unknown message type')
/**
 * Forwards a server acknowledgement message to the ledger, then publishes a Saved event
 * for the record's document.
 */
512 async handleAckMessage(record: DocumentConnectionRecord, message: ServerMessageWithMessageAcks): Promise<void> {
513 this.ledger.messageAcknowledgementReceived(message)
514 this.eventBus.publish<WebsocketConnectionEventPayloads[WebsocketConnectionEvent.Saved]>({
515 type: WebsocketConnectionEvent.Saved,
517 document: record.document,
/**
 * Processes document updates received from the server: each update may switch the
 * debouncer to realtime mode, is decrypted, and is published as a DocumentUpdateMessage event.
 *
 * @throws Error when an update cannot be decrypted (a decryption error metric is counted first)
 */
522 async handleIncomingDocumentUpdatesMessage(
523 record: DocumentConnectionRecord,
524 message: ServerMessageWithDocumentUpdates,
// Nothing to process for an empty batch (the body of this `if` is not visible here).
526 if (message.updates.documentUpdates.length === 0) {
531 `Received ${message.updates.documentUpdates.length} document updates with ids ${message.updates.documentUpdates.map((u) => u.uuid)}`,
534 const { keys, debouncer, document } = record
536 for (const update of message.updates.documentUpdates) {
// Both flags require private keys, so with non-private keys an incoming update never
// triggers the mode switch here.
// NOTE(review): with private keys and a non-empty own address, an update without an
// author address already satisfies the first condition, so the second check looks
// redundant — confirm.
537 const isReceivedUpdateFromOtherUser = isPrivateDocumentKeys(keys) && update.authorAddress !== keys.userOwnAddress
538 const isReceivedUpdateFromAnonymousUser = isPrivateDocumentKeys(keys) && !update.authorAddress
539 if (isReceivedUpdateFromOtherUser || isReceivedUpdateFromAnonymousUser) {
540 this.switchToRealtimeMode(debouncer, 'receiving DU from other user')
543 const decryptionResult = await this._decryptMessage.execute({
545 documentContentKey: keys.documentContentKey,
// NOTE(review): throwing here presumably aborts the remaining updates of the batch
// (loop nesting cannot be confirmed from this view).
548 if (decryptionResult.isFailed()) {
549 metrics.docs_document_updates_decryption_error_total.increment({
552 throw new Error(`Failed to decrypt document update: ${decryptionResult.getError()}`)
555 const decrypted = decryptionResult.getValue()
557 this.eventBus.publish<WebsocketConnectionEventPayloads[WebsocketConnectionEvent.DocumentUpdateMessage]>({
558 type: WebsocketConnectionEvent.DocumentUpdateMessage,
/**
 * Puts the debouncer into realtime mode and logs why.
 *
 * @param debouncer - debouncer whose mode is changed
 * @param reason - free text appended to the log line
 */
567 switchToRealtimeMode(debouncer: UpdateDebouncer, reason: string): void {
// Already in realtime mode (presumably followed by an early return — not visible here).
568 if (debouncer.getMode() === DocumentDebounceMode.Realtime) {
572 this.logger.info('Switching to realtime mode due to', reason)
575 debouncer.setMode(DocumentDebounceMode.Realtime)
/**
 * Processes a batch of server events: switches to realtime mode for selected event
 * types, handles or collects each event by type, then publishes the collected events
 * as a single EventMessage.
 */
578 async handleIncomingEventsMessage(record: DocumentConnectionRecord, message: ServerMessageWithEvents): Promise<void> {
579 const { keys, debouncer, document } = record
// Events collected here are published together after the loop.
581 const processedMessages: ProcessedIncomingRealtimeEventMessage[] = []
// Receiving either of these event types switches the debouncer to realtime mode.
583 const eventsThatTakeUsIntoRealtimeMode: EventTypeEnum[] = [
584 EventTypeEnum.ClientIsRequestingOtherClientsToBroadcastTheirState,
585 EventTypeEnum.ClientIsBroadcastingItsPresenceState,
588 for (const event of message.events) {
589 if (eventsThatTakeUsIntoRealtimeMode.includes(event.type)) {
590 this.switchToRealtimeMode(debouncer, `receiving event ${EventTypeEnum[event.type]}`)
593 const type = EventType.create(event.type)
595 this.logger.info('Handling event from RTS:', EventTypeEnum[event.type])
597 switch (type.value) {
// Group with no visible handling (presumably ignored — the statement after these labels
// is not visible in this view).
598 case EventTypeEnum.ServerIsPlacingEmptyActivityIndicatorInStreamToIndicateTheStreamIsStillActive:
599 case EventTypeEnum.ClientIsDebugRequestingServerToPerformCommit:
600 case EventTypeEnum.ServerIsNotifyingOtherServersToDisconnectAllClientsFromTheStream:
601 case EventTypeEnum.ServerIsRequestingOtherServersToBroadcastParticipants:
602 case EventTypeEnum.ServerIsInformingOtherServersOfTheParticipants:
// Server is ready: handled directly with the raw event content.
604 case EventTypeEnum.ServerIsReadyToAcceptClientMessages:
605 this.onDocumentConnectionReadyToBroadcast(record, event.content)
// Collected as processed messages (constructor arguments not visible in this view).
607 case EventTypeEnum.ClientIsRequestingOtherClientsToBroadcastTheirState:
608 case EventTypeEnum.ServerIsRequestingClientToBroadcastItsState:
609 case EventTypeEnum.ServerHasMoreOrLessGivenTheClientEverythingItHas:
610 processedMessages.push(
611 new ProcessedIncomingRealtimeEventMessage({
// Commit-updated notification: the event content is passed through as received.
616 case EventTypeEnum.ServerIsInformingClientThatTheDocumentCommitHasBeenUpdated:
617 processedMessages.push(
618 new ProcessedIncomingRealtimeEventMessage({
620 content: event.content,
// Comment and presence events are decrypted with the document content key before being
// collected; a decryption failure is logged rather than thrown.
624 case EventTypeEnum.ClientHasSentACommentMessage:
625 case EventTypeEnum.ClientIsBroadcastingItsPresenceState: {
626 const decryptionResult = await this._decryptMessage.execute({
628 documentContentKey: keys.documentContentKey,
631 if (decryptionResult.isFailed()) {
632 this.logger.error(`Failed to decrypt event: ${decryptionResult.getError()}`)
636 const decrypted = decryptionResult.getValue()
638 processedMessages.push(
639 new ProcessedIncomingRealtimeEventMessage({
641 content: decrypted.content,
// Exhaustiveness helper for event types not handled above.
648 assertUnreachableAndLog(type.value)
652 this.eventBus.publish<WebsocketConnectionEventPayloads[WebsocketConnectionEvent.EventMessage]>({
653 type: WebsocketConnectionEvent.EventMessage,
656 message: processedMessages,
662 * This is a debug utility exposed in development by the Debug menu and allows the client to force the RTS to commit immediately
663 * (rather than waiting for the next scheduled commit cycle)
// Sends a ClientIsDebugRequestingServerToPerformCommit event whose content is a UTF-8
// JSON object holding the caller's own address.
// Throws Error('Connection not found') — presumably when no record exists for the link id
// (the guard line is not visible in this view).
665 public async debugSendCommitCommandToRTS(document: NodeMeta, keys: DocumentKeys): Promise<void> {
666 this.logger.info('Sending commit command to RTS')
668 const record = this.getConnectionRecord(document.linkId)
670 throw new Error('Connection not found')
673 const content = stringToUtf8Array(JSON.stringify({ authorAddress: keys.userOwnAddress }))
// NOTE(review): the send is discarded with `void` inside an async method, so the returned
// promise settles before the event is sent and a failure in sendEventMessage would be an
// unhandled rejection — consider awaiting it.
675 void this.sendEventMessage(
678 EventTypeEnum.ClientIsDebugRequestingServerToPerformCommit,
679 BroadcastSource.CommitDocumentUseCase,
/**
 * Disconnects the connection stored for the given link id with a normal-closure code.
 *
 * @param document - any object carrying the `linkId` of the connection to close
 * @throws Error('Connection not found') — presumably when no record exists for the link id
 *   (the guard line is not visible in this view)
 */
683 public closeConnection(document: { linkId: string }): void {
684 this.logger.info('Closing connection')
686 const record = this.getConnectionRecord(document.linkId)
688 throw new Error('Connection not found')
// Fire-and-forget: the disconnect promise is intentionally not awaited.
691 void record.connection.disconnect(ConnectionCloseReason.CODES.NORMAL_CLOSURE)