Skip to content

Commit

Permalink
refactor: centralize cancellation token in chat model (#14644)
Browse files Browse the repository at this point in the history
So far it is the responsibility of the chat agent to create the cancellation token and react to the chat model being canceled to update the token that is passed to the language model request. This is very indirect and forces sub agents that may contribute to providing an answer to do the same over and over again for their language model requests.

Instead one chat request/response should have one cancellation token that is being created automatically. The UI can then invoke cancel on the chat request/response and all agents that are involved can reuse the cancellation token for their requests.
  • Loading branch information
planger authored Dec 20, 2024
1 parent 2098ff0 commit 8721248
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 21 deletions.
8 changes: 3 additions & 5 deletions packages/ai-chat-ui/src/browser/chat-view-widget.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
// *****************************************************************************
import { CommandService, deepClone, Emitter, Event, MessageService } from '@theia/core';
import { ChatRequest, ChatRequestModel, ChatRequestModelImpl, ChatService, ChatSession } from '@theia/ai-chat';
import { ChatRequest, ChatRequestModel, ChatService, ChatSession } from '@theia/ai-chat';
import { BaseWidget, codicon, ExtractableWidget, Message, PanelLayout, PreferenceService, StatefulWidget } from '@theia/core/lib/browser';
import { nls } from '@theia/core/lib/common/nls';
import { inject, injectable, postConstruct } from '@theia/core/shared/inversify';
Expand Down Expand Up @@ -165,7 +165,7 @@ export class ChatViewWidget extends BaseWidget implements ExtractableWidget, Sta
const requestProgress = await this.chatService.sendRequest(this.chatSession.id, chatRequest);
requestProgress?.responseCompleted.then(responseModel => {
if (responseModel.isError) {
this.messageService.error(responseModel.errorObject?.message ?? 'An error occurred druring chat service invocation.');
this.messageService.error(responseModel.errorObject?.message ?? 'An error occurred during chat service invocation.');
}
});
if (!requestProgress) {
Expand All @@ -176,9 +176,7 @@ export class ChatViewWidget extends BaseWidget implements ExtractableWidget, Sta
}

protected onCancel(requestModel: ChatRequestModel): void {
// TODO we should pass a cancellation token with the request (or retrieve one from the request invocation) so we can cleanly cancel here
// For now we cancel manually via casting
(requestModel as ChatRequestModelImpl).response.cancel();
this.chatService.cancelRequest(requestModel.session.id, requestModel.id);
}

lock(): void {
Expand Down
11 changes: 2 additions & 9 deletions packages/ai-chat/src/common/chat-agents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import {
LanguageModelStreamResponsePart,
MessageActor,
} from '@theia/ai-core/lib/common';
import { CancellationToken, CancellationTokenSource, ContributionProvider, ILogger, isArray } from '@theia/core';
import { CancellationToken, ContributionProvider, ILogger, isArray } from '@theia/core';
import { inject, injectable, named, postConstruct, unmanaged } from '@theia/core/shared/inversify';
import { ChatAgentService } from './chat-agent-service';
import {
Expand Down Expand Up @@ -186,18 +186,11 @@ export abstract class AbstractChatAgent {
}
this.getTools(request)?.forEach(tool => tools.set(tool.id, tool));

const cancellationToken = new CancellationTokenSource();
request.response.onDidChange(() => {
if (request.response.isCanceled) {
cancellationToken.cancel();
}
});

const languageModelResponse = await this.callLlm(
languageModel,
messages,
tools.size > 0 ? Array.from(tools.values()) : undefined,
cancellationToken.token
request.response.cancellationToken
);
await this.addContentsToResponse(languageModelResponse, request);
await this.onResponseComplete(request);
Expand Down
23 changes: 17 additions & 6 deletions packages/ai-chat/src/common/chat-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
*--------------------------------------------------------------------------------------------*/
// Partially copied from https://github.com/microsoft/vscode/blob/a2cab7255c0df424027be05d58e1b7b941f4ea60/src/vs/workbench/contrib/chat/common/chatModel.ts

import { Command, Emitter, Event, generateUuid, URI } from '@theia/core';
import { CancellationToken, CancellationTokenSource, Command, Emitter, Event, generateUuid, URI } from '@theia/core';
import { MarkdownString, MarkdownStringImpl } from '@theia/core/lib/common/markdown-rendering';
import { Position } from '@theia/core/shared/vscode-languageserver-protocol';
import { ChatAgentLocation } from './chat-agents';
Expand Down Expand Up @@ -397,6 +397,10 @@ export class ChatModelImpl implements ChatModel {
return this._requests;
}

getRequest(id: string): ChatRequestModelImpl | undefined {
return this._requests.find(request => request.id === id);
}

get id(): string {
return this._id;
}
Expand Down Expand Up @@ -466,6 +470,10 @@ export class ChatRequestModelImpl implements ChatRequestModel {
get agentId(): string | undefined {
return this._agentId;
}

cancel(): void {
this.response.cancel();
}
}

export class ErrorChatResponseContentImpl implements ErrorChatResponseContent {
Expand Down Expand Up @@ -798,11 +806,11 @@ class ChatResponseModelImpl implements ChatResponseModel {
protected _progressMessages: ChatProgressMessage[];
protected _response: ChatResponseImpl;
protected _isComplete: boolean;
protected _isCanceled: boolean;
protected _isWaitingForInput: boolean;
protected _agentId?: string;
protected _isError: boolean;
protected _errorObject: Error | undefined;
protected _cancellationToken: CancellationTokenSource;

constructor(requestId: string, agentId?: string) {
// TODO accept serialized data as a parameter to restore a previously saved ChatResponseModel
Expand All @@ -813,9 +821,9 @@ class ChatResponseModelImpl implements ChatResponseModel {
response.onDidChange(() => this._onDidChangeEmitter.fire());
this._response = response;
this._isComplete = false;
this._isCanceled = false;
this._isWaitingForInput = false;
this._agentId = agentId;
this._cancellationToken = new CancellationTokenSource();
}

get id(): string {
Expand Down Expand Up @@ -870,7 +878,7 @@ class ChatResponseModelImpl implements ChatResponseModel {
}

get isCanceled(): boolean {
return this._isCanceled;
return this._cancellationToken.token.isCancellationRequested;
}

get isWaitingForInput(): boolean {
Expand All @@ -892,12 +900,16 @@ class ChatResponseModelImpl implements ChatResponseModel {
}

cancel(): void {
this._cancellationToken.cancel();
this._isComplete = true;
this._isCanceled = true;
this._isWaitingForInput = false;
this._onDidChangeEmitter.fire();
}

get cancellationToken(): CancellationToken {
return this._cancellationToken.token;
}

waitForInput(): void {
this._isWaitingForInput = true;
this._onDidChangeEmitter.fire();
Expand All @@ -910,7 +922,6 @@ class ChatResponseModelImpl implements ChatResponseModel {

error(error: Error): void {
this._isComplete = true;
this._isCanceled = false;
this._isWaitingForInput = false;
this._isError = true;
this._errorObject = error;
Expand Down
6 changes: 6 additions & 0 deletions packages/ai-chat/src/common/chat-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ export interface ChatService {
sessionId: string,
request: ChatRequest
): Promise<ChatRequestInvocation | undefined>;

cancelRequest(sessionId: string, requestId: string): Promise<void>;
}

interface ChatSessionInternal extends ChatSession {
Expand Down Expand Up @@ -219,6 +221,10 @@ export class ChatServiceImpl implements ChatService {
return invocation;
}

async cancelRequest(sessionId: string, requestId: string): Promise<void> {
return this.getSession(sessionId)?.model.getRequest(requestId)?.response.cancel();
}

protected getAgent(parsedRequest: ParsedChatRequest): ChatAgent | undefined {
const agentPart = this.getMentionedAgent(parsedRequest);
if (agentPart) {
Expand Down
2 changes: 1 addition & 1 deletion packages/ai-chat/src/common/command-chat-agents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ export class CommandChatAgent extends AbstractTextToModelParsingChatAgent<Parsed
const theiaCommand = this.commandRegistry.getCommand(parsedCommand.commandId);
if (theiaCommand === undefined) {
console.error(`No Theia Command with id ${parsedCommand.commandId}`);
request.response.cancel();
request.cancel();
}
const args = parsedCommand.arguments !== undefined &&
parsedCommand.arguments.length > 0
Expand Down

0 comments on commit 8721248

Please sign in to comment.