Skip to content

Commit

Permalink
Fix minor issues related to total tokens calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
carlrobertoh committed Nov 14, 2023
1 parent 346218b commit 318dd42
Show file tree
Hide file tree
Showing 24 changed files with 425 additions and 398 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public void actionPerformed(@NotNull AnActionEvent event) {
if (project != null) {
try {
ConversationService.getInstance().clearAll();
project.getService(StandardChatToolWindowContentManager.class).resetActiveTab();
project.getService(StandardChatToolWindowContentManager.class).resetAll();
} finally {
TelemetryAction.IDE_ACTION.createActionMessage()
.property("action", ActionType.DELETE_ALL_CONVERSATIONS.name())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@ public class CompletionRequestHandler {

private final StringBuilder messageBuilder = new StringBuilder();
private final boolean useContextualSearch;
private final ToolWindowCompletionEventListener toolWindowCompletionEventListener;
private final CompletionResponseEventListener completionResponseEventListener;
private SwingWorker<Void, String> swingWorker;
private EventSource eventSource;

public CompletionRequestHandler(
boolean useContextualSearch,
ToolWindowCompletionEventListener toolWindowCompletionEventListener) {
CompletionResponseEventListener completionResponseEventListener) {
this.useContextualSearch = useContextualSearch;
this.toolWindowCompletionEventListener = toolWindowCompletionEventListener;
this.completionResponseEventListener = completionResponseEventListener;
}

public void call(Conversation conversation, Message message, boolean isRetry) {
swingWorker = new CompletionRequestWorker(conversation, message, isRetry);
public void call(Conversation conversation, Message message, boolean retry) {
swingWorker = new CompletionRequestWorker(conversation, message, retry);
swingWorker.execute();
}

Expand All @@ -56,7 +56,7 @@ private EventSource startCall(
if (ex instanceof TotalUsageExceededException) {
errorMessage = "The length of the context exceeds the maximum limit that the model can handle. Try reducing the input message or maximum completion token size.";
}
toolWindowCompletionEventListener.handleError(new ErrorDetails(errorMessage), ex);
completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex);
throw ex;
}
}
Expand All @@ -65,12 +65,12 @@ private class CompletionRequestWorker extends SwingWorker<Void, String> {

private final Conversation conversation;
private final Message message;
private final boolean isRetry;
private final boolean retry;

public CompletionRequestWorker(Conversation conversation, Message message, boolean isRetry) {
public CompletionRequestWorker(Conversation conversation, Message message, boolean retry) {
this.conversation = conversation;
this.message = message;
this.isRetry = isRetry;
this.retry = retry;
}

protected Void doInBackground() {
Expand All @@ -79,10 +79,10 @@ protected Void doInBackground() {
eventSource = startCall(
conversation,
message,
isRetry,
retry,
new YouRequestCompletionEventListener());
} catch (TotalUsageExceededException e) {
toolWindowCompletionEventListener.handleTokensExceeded(conversation, message);
completionResponseEventListener.handleTokensExceeded(conversation, message);
} finally {
sendInfo(settings);
}
Expand All @@ -93,15 +93,15 @@ protected void process(List<String> chunks) {
message.setResponse(messageBuilder.toString());
for (String text : chunks) {
messageBuilder.append(text);
toolWindowCompletionEventListener.handleMessage(text);
completionResponseEventListener.handleMessage(text);
}
}

class YouRequestCompletionEventListener implements YouCompletionEventListener {

@Override
public void onSerpResults(List<YouSerpResult> results) {
toolWindowCompletionEventListener.handleSerpResults(results, message);
completionResponseEventListener.handleSerpResults(results, message);
}

@Override
Expand All @@ -111,14 +111,17 @@ public void onMessage(String message) {

@Override
public void onComplete(StringBuilder messageBuilder) {
toolWindowCompletionEventListener.handleCompleted(messageBuilder.toString(), message,
conversation, isRetry);
completionResponseEventListener.handleCompleted(
messageBuilder.toString(),
message,
conversation,
retry);
}

@Override
public void onError(ErrorDetails error, Throwable ex) {
try {
toolWindowCompletionEventListener.handleError(error, ex);
completionResponseEventListener.handleError(error, ex);
} finally {
sendError(error, ex);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,18 @@ public YouCompletionRequest buildYouCompletionRequest(Message message) {
public OpenAIChatCompletionRequest buildOpenAIChatCompletionRequest(
String model,
Message message,
boolean isRetry) {
return buildOpenAIChatCompletionRequest(model, message, isRetry, false, null);
boolean retry) {
return buildOpenAIChatCompletionRequest(model, message, retry, false, null);
}

public OpenAIChatCompletionRequest buildOpenAIChatCompletionRequest(
@Nullable String model,
Message message,
boolean isRetry,
boolean retry,
boolean useContextualSearch,
@Nullable String overriddenPath) {
var builder = new OpenAIChatCompletionRequest.Builder(
buildMessages(model, message, isRetry, useContextualSearch))
buildMessages(model, message, retry, useContextualSearch))
.setModel(model)
.setMaxTokens(ConfigurationState.getInstance().getMaxTokens())
.setTemperature(ConfigurationState.getInstance().getTemperature());
Expand All @@ -125,7 +125,7 @@ public OpenAIChatCompletionRequest buildOpenAIChatCompletionRequest(

public List<OpenAIChatCompletionMessage> buildMessages(
Message message,
boolean isRetry,
boolean retry,
boolean useContextualSearch) {
var messages = new ArrayList<OpenAIChatCompletionMessage>();
if (useContextualSearch) {
Expand All @@ -138,7 +138,7 @@ public List<OpenAIChatCompletionMessage> buildMessages(
systemPrompt.isEmpty() ? COMPLETION_SYSTEM_PROMPT : systemPrompt));

for (var prevMessage : conversation.getMessages()) {
if (isRetry && prevMessage.getId().equals(message.getId())) {
if (retry && prevMessage.getId().equals(message.getId())) {
break;
}
messages.add(new OpenAIChatCompletionMessage("user", prevMessage.getPrompt()));
Expand All @@ -152,9 +152,9 @@ public List<OpenAIChatCompletionMessage> buildMessages(
private List<OpenAIChatCompletionMessage> buildMessages(
@Nullable String model,
Message message,
boolean isRetry,
boolean retry,
boolean useContextualSearch) {
var messages = buildMessages(message, isRetry, useContextualSearch);
var messages = buildMessages(message, retry, useContextualSearch);

if (model == null || SettingsState.getInstance().getSelectedService() == ServiceType.YOU) {
return messages;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ee.carlrobert.llm.client.you.completion.YouSerpResult;
import java.util.List;

public interface ToolWindowCompletionEventListener {
public interface CompletionResponseEventListener {

default void handleMessage(String message) {}

Expand All @@ -18,7 +18,7 @@ default void handleCompleted(
String fullMessage,
Message message,
Conversation conversation,
boolean isRetry) {}
boolean retry) {}

default void handleSerpResults(List<YouSerpResult> results, Message message) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,8 @@ public void removeMessage(UUID messageId) {
.filter(message -> !message.getId().equals(messageId))
.collect(toList()));
}

public void removeMessages() {
messages.clear();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ public void saveMessage(
String response,
Message message,
Conversation conversation,
boolean isRetry) {
boolean retry) {
var conversationMessages = conversation.getMessages();
if (isRetry && !conversationMessages.isEmpty()) {
if (retry && !conversationMessages.isEmpty()) {
var messageToBeSaved = conversationMessages.stream()
.filter(item -> item.getId().equals(message.getId()))
.findFirst().orElseThrow();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,6 @@ private void resetActiveTab() {
throw new RuntimeException("Could not find current project.");
}

project.getService(StandardChatToolWindowContentManager.class).resetActiveTab();
project.getService(StandardChatToolWindowContentManager.class).resetAll();
}
}
Loading

0 comments on commit 318dd42

Please sign in to comment.