Skip to content

Commit

Permalink
feat: improve chat UI performance
Browse files Browse the repository at this point in the history
  • Loading branch information
carlrobertoh committed Aug 2, 2024
1 parent 658e78f commit 6479604
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 154 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,37 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import ee.carlrobert.codegpt.events.CodeGPTEvent;
import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.settings.GeneralSettingsState;
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.completion.CompletionEventListener;
import java.util.List;
import javax.swing.SwingWorker;
import okhttp3.sse.EventSource;

public class CompletionRequestHandler {

private final StringBuilder messageBuilder = new StringBuilder();
private final CompletionResponseEventListener completionResponseEventListener;
private SwingWorker<Void, String> swingWorker;
private EventSource eventSource;

public CompletionRequestHandler(CompletionResponseEventListener completionResponseEventListener) {
this.completionResponseEventListener = completionResponseEventListener;
}

public void call(CallParameters callParameters) {
swingWorker = new CompletionRequestWorker(callParameters);
swingWorker.execute();
try {
eventSource = startCall(callParameters, new RequestCompletionEventListener(callParameters));
} catch (TotalUsageExceededException e) {
completionResponseEventListener.handleTokensExceeded(
callParameters.getConversation(),
callParameters.getMessage());
} finally {
sendInfo(callParameters);
}
}

public void cancel() {
if (eventSource != null) {
eventSource.cancel();
}
swingWorker.cancel(true);
}

private EventSource startCall(
Expand All @@ -57,79 +59,48 @@ private void handleCallException(Throwable ex) {
completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex);
}

private class CompletionRequestWorker extends SwingWorker<Void, String> {
class RequestCompletionEventListener implements CompletionEventListener<String> {

private final CallParameters callParameters;

public CompletionRequestWorker(CallParameters callParameters) {
public RequestCompletionEventListener(CallParameters callParameters) {
this.callParameters = callParameters;
}

protected Void doInBackground() {
var settings = GeneralSettings.getCurrentState();
@Override
public void onEvent(String data) {
try {
eventSource = startCall(callParameters, new RequestCompletionEventListener());
} catch (TotalUsageExceededException e) {
completionResponseEventListener.handleTokensExceeded(
callParameters.getConversation(),
callParameters.getMessage());
} finally {
sendInfo(settings);
var event = new ObjectMapper().readValue(data, CodeGPTEvent.class);
completionResponseEventListener.handleCodeGPTEvent(event);
} catch (JsonProcessingException e) {
// ignore
}
return null;
}

protected void process(List<String> chunks) {
@Override
public void onMessage(String message, EventSource eventSource) {
messageBuilder.append(message);
callParameters.getMessage().setResponse(messageBuilder.toString());
for (String text : chunks) {
messageBuilder.append(text);
completionResponseEventListener.handleMessage(text);
}
completionResponseEventListener.handleMessage(message);
}

class RequestCompletionEventListener implements CompletionEventListener<String> {

@Override
public void onEvent(String data) {
try {
var event = new ObjectMapper().readValue(data, CodeGPTEvent.class);
completionResponseEventListener.handleCodeGPTEvent(event);
} catch (JsonProcessingException e) {
// ignore
}
}

@Override
public void onMessage(String message, EventSource eventSource) {
publish(message);
}

@Override
public void onComplete(StringBuilder messageBuilder) {
completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters);
}

@Override
public void onCancelled(StringBuilder messageBuilder) {
completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters);
}
@Override
public void onComplete(StringBuilder messageBuilder) {
completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters);
}

@Override
public void onError(ErrorDetails error, Throwable ex) {
try {
completionResponseEventListener.handleError(error, ex);
} finally {
sendError(error, ex);
}
}
@Override
public void onCancelled(StringBuilder messageBuilder) {
completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters);
}

private void sendInfo(GeneralSettingsState settings) {
TelemetryAction.COMPLETION.createActionMessage()
.property("conversationId", callParameters.getConversation().getId().toString())
.property("model", callParameters.getConversation().getModel())
.property("service", settings.getSelectedService().getCode().toLowerCase())
.send();
@Override
public void onError(ErrorDetails error, Throwable ex) {
try {
completionResponseEventListener.handleError(error, ex);
} finally {
sendError(error, ex);
}
}

private void sendError(ErrorDetails error, Throwable ex) {
Expand All @@ -147,4 +118,12 @@ private void sendError(ErrorDetails error, Throwable ex) {
telemetryMessage.send();
}
}

private void sendInfo(CallParameters callParameters) {
TelemetryAction.COMPLETION.createActionMessage()
.property("conversationId", callParameters.getConversation().getId().toString())
.property("model", callParameters.getConversation().getModel())
.property("service", GeneralSettings.getSelectedService().getCode().toLowerCase())
.send();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import static java.lang.String.format;

import com.intellij.openapi.Disposable;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.project.Project;
import com.intellij.ui.JBColor;
import com.intellij.util.ui.JBUI;
import ee.carlrobert.codegpt.CodeGPTKeys;
import ee.carlrobert.codegpt.EncodingManager;
import ee.carlrobert.codegpt.ReferencedFile;
import ee.carlrobert.codegpt.actions.ActionType;
import ee.carlrobert.codegpt.completions.CallParameters;
Expand Down Expand Up @@ -41,7 +41,6 @@
import java.util.UUID;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;
import kotlin.Unit;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
Expand Down Expand Up @@ -111,7 +110,7 @@ public void sendMessage(Message message) {
}

public void sendMessage(Message message, ConversationType conversationType) {
SwingUtilities.invokeLater(() -> {
ApplicationManager.getApplication().invokeLater(() -> {
var referencedFiles = project.getUserData(CodeGPTKeys.SELECTED_FILES);
var chatToolWindowPanel = project.getService(ChatToolWindowContentManager.class)
.tryFindChatToolWindowPanel();
Expand All @@ -127,6 +126,7 @@ public void sendMessage(Message message, ConversationType conversationType) {

chatToolWindowPanel.ifPresent(panel -> panel.clearNotifications(project));
}
totalTokensPanel.updateConversationTokens(conversation);

var userMessagePanel = new UserMessagePanel(project, message, this);
var attachedFilePath = CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH.get(project);
Expand All @@ -142,7 +142,6 @@ public void sendMessage(Message message, ConversationType conversationType) {

var responsePanel = createResponsePanel(message, conversationType);
messagePanel.add(responsePanel);
updateTotalTokens(message);
call(callParameters, responsePanel);
});
}
Expand All @@ -163,12 +162,6 @@ private CallParameters getCallParameters(
return callParameters;
}

private void updateTotalTokens(Message message) {
int userPromptTokens = EncodingManager.getInstance().countTokens(message.getPrompt());
int conversationTokens = EncodingManager.getInstance().countConversationTokens(conversation);
totalTokensPanel.updateConversationTokens(conversationTokens + userPromptTokens);
}

private ResponsePanel createResponsePanel(Message message, ConversationType conversationType) {
return new ResponsePanel()
.withReloadAction(() -> reloadMessage(message, conversation, conversationType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import ee.carlrobert.codegpt.ui.OverlayUtil;
import ee.carlrobert.codegpt.ui.textarea.UserInputPanel;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import javax.swing.SwingUtilities;

abstract class ToolWindowCompletionResponseEventListener implements
CompletionResponseEventListener {
Expand Down Expand Up @@ -54,17 +53,16 @@ public ToolWindowCompletionResponseEventListener(
@Override
public void handleMessage(String partialMessage) {
try {
ApplicationManager.getApplication()
.invokeLater(() -> {
responseContainer.update(partialMessage);
messageBuilder.append(partialMessage);

if (!completed) {
var ongoingTokens = encodingManager.countTokens(messageBuilder.toString());
totalTokensPanel.update(
totalTokensPanel.getTokenDetails().getTotal() + ongoingTokens);
}
});
responseContainer.update(partialMessage);
messageBuilder.append(partialMessage);

if (!completed) {
var ongoingTokens = encodingManager.countTokens(messageBuilder.toString());
ApplicationManager.getApplication().invokeLater(() -> {
totalTokensPanel.update(
totalTokensPanel.getTokenDetails().getTotal() + ongoingTokens);
});
}
} catch (Exception e) {
responseContainer.displayError("Something went wrong.");
throw new RuntimeException("Error while updating the content", e);
Expand All @@ -73,7 +71,7 @@ public void handleMessage(String partialMessage) {

@Override
public void handleError(ErrorDetails error, Throwable ex) {
SwingUtilities.invokeLater(() -> {
ApplicationManager.getApplication().invokeLater(() -> {
try {
if ("insufficient_quota".equals(error.getCode())) {
responseContainer.displayQuotaExceeded();
Expand All @@ -90,7 +88,7 @@ public void handleError(ErrorDetails error, Throwable ex) {

@Override
public void handleTokensExceeded(Conversation conversation, Message message) {
SwingUtilities.invokeLater(() -> {
ApplicationManager.getApplication().invokeLater(() -> {
var answer = OverlayUtil.showTokenLimitExceededDialog();
if (answer == OK) {
TelemetryAction.IDE_ACTION.createActionMessage()
Expand All @@ -110,7 +108,7 @@ public void handleTokensExceeded(Conversation conversation, Message message) {
public void handleCompleted(String fullMessage, CallParameters callParameters) {
conversationService.saveMessage(fullMessage, callParameters);

SwingUtilities.invokeLater(() -> {
ApplicationManager.getApplication().invokeLater(() -> {
try {
responsePanel.enableActions();
totalTokensPanel.updateUserPromptTokens(textArea.getText());
Expand All @@ -123,7 +121,8 @@ public void handleCompleted(String fullMessage, CallParameters callParameters) {

@Override
public void handleCodeGPTEvent(CodeGPTEvent event) {
responseContainer.displayWebSearchItem(event.getEvent().getDetails());
ApplicationManager.getApplication().invokeLater(() ->
responseContainer.displayWebSearchItem(event.getEvent().getDetails()));
}

private void stopStreaming(ChatMessageResponseBody responseContainer) {
Expand Down
Loading

0 comments on commit 6479604

Please sign in to comment.