Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Agent API #163

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ private static class ChatMessages {

private final List<ChatRequestMessage> newMessages;
private final List<ChatRequestMessage> allMessages;
private final List<OpenAIChatMessageContent> newChatMessageContent;
private final List<OpenAIChatMessageContent<?>> newChatMessageContent;

public ChatMessages(List<ChatRequestMessage> allMessages) {
this.allMessages = Collections.unmodifiableList(allMessages);
Expand All @@ -195,7 +195,7 @@ public ChatMessages(List<ChatRequestMessage> allMessages) {
private ChatMessages(
List<ChatRequestMessage> allMessages,
List<ChatRequestMessage> newMessages,
List<OpenAIChatMessageContent> newChatMessageContent) {
List<OpenAIChatMessageContent<?>> newChatMessageContent) {
this.allMessages = Collections.unmodifiableList(allMessages);
this.newMessages = Collections.unmodifiableList(newMessages);
this.newChatMessageContent = Collections.unmodifiableList(newChatMessageContent);
Expand All @@ -219,8 +219,8 @@ public ChatMessages add(ChatRequestMessage requestMessage) {
}

@CheckReturnValue
public ChatMessages addChatMessage(List<OpenAIChatMessageContent> chatMessageContent) {
ArrayList<OpenAIChatMessageContent> tmpChatMessageContent = new ArrayList<>(
public ChatMessages addChatMessage(List<OpenAIChatMessageContent<?>> chatMessageContent) {
ArrayList<OpenAIChatMessageContent<?>> tmpChatMessageContent = new ArrayList<>(
newChatMessageContent);
tmpChatMessageContent.addAll(chatMessageContent);

Expand Down Expand Up @@ -357,19 +357,16 @@ private Mono<ChatMessages> internalChatMessageContentsAsync(
// If we don't want to attempt to invoke any functions
// Or if we are auto-invoking, but we somehow end up with other than 1 choice even though only 1 was requested
if (autoInvokeAttempts == 0 || responseMessages.size() != 1) {
return getChatMessageContentsAsync(completions)
.flatMap(m -> {
return Mono.just(messages.addChatMessage(m));
});
List<OpenAIChatMessageContent<?>> chatMessageContents = getChatMessageContentsAsync(completions);
return Mono.just(messages.addChatMessage(chatMessageContents));
}
// Or if there are no tool calls to be done
ChatResponseMessage response = responseMessages.get(0);
List<ChatCompletionsToolCall> toolCalls = response.getToolCalls();
if (toolCalls == null || toolCalls.isEmpty()) {
return getChatMessageContentsAsync(completions)
.flatMap(m -> {
return Mono.just(messages.addChatMessage(m));
});
List<OpenAIChatMessageContent<?>> chatMessageContents = getChatMessageContentsAsync(
completions);
return Mono.just(messages.addChatMessage(chatMessageContents));
}

ChatRequestAssistantMessage requestMessage = new ChatRequestAssistantMessage(
Expand Down Expand Up @@ -592,7 +589,7 @@ private OpenAIFunctionToolCall extractOpenAIFunctionToolCall(
arguments);
}

private Mono<List<OpenAIChatMessageContent>> getChatMessageContentsAsync(
private List<OpenAIChatMessageContent<?>> getChatMessageContentsAsync(
ChatCompletions completions) {
FunctionResultMetadata<CompletionsUsage> completionMetadata = FunctionResultMetadata.build(
completions.getId(),
Expand All @@ -606,22 +603,28 @@ private Mono<List<OpenAIChatMessageContent>> getChatMessageContentsAsync(
.filter(Objects::nonNull)
.collect(Collectors.toList());

return Flux.fromIterable(responseMessages)
.flatMap(response -> {
List<OpenAIChatMessageContent<?>> chatMessageContent =
responseMessages
.stream()
.map(response -> {
try {
return Mono.just(new OpenAIChatMessageContent(
return new OpenAIChatMessageContent<>(
AuthorRole.ASSISTANT,
response.getContent(),
this.getModelId(),
null,
null,
completionMetadata,
formOpenAiToolCalls(response)));
} catch (Exception e) {
return Mono.error(e);
formOpenAiToolCalls(response));
} catch (SKCheckedException e) {
LOGGER.warn("Failed to form chat message content", e);
return null;
}
})
.collectList();
.filter(Objects::nonNull)
.collect(Collectors.toList());

return chatMessageContent;
}

private List<ChatMessageContent<?>> toOpenAIChatMessageContent(
Expand Down Expand Up @@ -931,7 +934,7 @@ private static boolean hasToolCallBeenExecuted(List<ChatRequestMessage> chatRequ
}

private static List<ChatRequestMessage> getChatRequestMessages(
List<? extends ChatMessageContent> messages) {
List<? extends ChatMessageContent<?>> messages) {
if (messages == null || messages.isEmpty()) {
return new ArrayList<>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public OpenAIChatMessageContent(
@Nullable String modelId,
@Nullable T innerContent,
@Nullable Charset encoding,
@Nullable FunctionResultMetadata metadata,
@Nullable FunctionResultMetadata<?> metadata,
@Nullable List<OpenAIFunctionToolCall> toolCall) {
super(authorRole, content, modelId, innerContent, encoding, metadata);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import com.microsoft.semantickernel.services.chatcompletion.message.ChatMessageTextContent;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Spliterator;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.function.Consumer;
import javax.annotation.Nullable;

Expand All @@ -18,7 +20,7 @@
*/
public class ChatHistory implements Iterable<ChatMessageContent<?>> {

private final List<ChatMessageContent<?>> chatMessageContents;
private final Collection<ChatMessageContent<?>> chatMessageContents;

/**
* The default constructor
Expand All @@ -33,7 +35,7 @@ public ChatHistory() {
* @param instructions The instructions to add to the chat history
*/
public ChatHistory(@Nullable String instructions) {
this.chatMessageContents = new ArrayList<>();
this.chatMessageContents = new ConcurrentLinkedQueue<>();
if (instructions != null) {
this.chatMessageContents.add(
ChatMessageTextContent.systemMessage(instructions));
Expand All @@ -45,8 +47,8 @@ public ChatHistory(@Nullable String instructions) {
*
* @param chatMessageContents The chat message contents to add to the chat history
*/
public ChatHistory(List<? extends ChatMessageContent> chatMessageContents) {
this.chatMessageContents = new ArrayList(chatMessageContents);
public ChatHistory(List<? extends ChatMessageContent<?>> chatMessageContents) {
this.chatMessageContents = new ConcurrentLinkedQueue<>(chatMessageContents);
}

/**
Expand All @@ -55,7 +57,7 @@ public ChatHistory(List<? extends ChatMessageContent> chatMessageContents) {
* @return List of messages in the chat
*/
public List<ChatMessageContent<?>> getMessages() {
return Collections.unmodifiableList(chatMessageContents);
return Collections.unmodifiableList(new ArrayList<>(chatMessageContents));
}

/**
Expand All @@ -67,7 +69,7 @@ public Optional<ChatMessageContent<?>> getLastMessage() {
if (chatMessageContents.isEmpty()) {
return Optional.empty();
}
return Optional.of(chatMessageContents.get(chatMessageContents.size() - 1));
return Optional.of(((ConcurrentLinkedQueue<ChatMessageContent<?>>)chatMessageContents).peek());
}

/**
Expand Down Expand Up @@ -114,7 +116,7 @@ public Spliterator<ChatMessageContent<?>> spliterator() {
* @param metadata The metadata of the message
*/
public ChatHistory addMessage(AuthorRole authorRole, String content, Charset encoding,
FunctionResultMetadata metadata) {
FunctionResultMetadata<?> metadata) {
chatMessageContents.add(
ChatMessageTextContent.builder()
.withAuthorRole(authorRole)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.agents;

import java.util.List;

import javax.annotation.Nullable;

import reactor.core.publisher.Mono;

/**
* Base abstraction for all Semantic Kernel agents. An agent instance
* may participate in one or more conversations, or {@link AgentChat}.
* A conversation may include one or more agents.
*
* In addition to identity and descriptive meta-data, an {@link Agent}
* must define its communication protocol, or {@link AgentChannel}.
*
* @param The type of {@code AgentChannel} associated with the agent.
*/
public abstract class Agent {

/**
* The description of the agent (optional)
*/
private final String description;

/**
* The identifier of the agent (optional).
* Default to a random guid value, but may be overridden.
*/
private final String id;

/**
* The name of the agent (optional)
*/
private final String name;

/**
* Construct a new {@link Agent} instance.
* @param id The identifier of the agent.
* @param name The name of the agent.
* @param description The description of the agent.
*/
protected Agent(
@Nullable String id,
@Nullable String name,
@Nullable String description) {
this.id = id;
this.name = name;
this.description = description;
}

/**
* Get the description of the agent.
* @return The description of the agent.
*/
public String getDescription() {
return description;
}

/**
* Get the identifier of the agent.
* @return The identifier of the agent.
*/
public String getId() {
return id;
}

/**
* Get the name of the agent.
* @return The name of the agent.
*/
public String getName() {
return name;
}

/**
* Set of keys to establish channel affinity.
* Two specific agents of the same type may each require their own channel. This is
* why the channel type alone is insufficient.
* For example, two OpenAI Assistant agents each targeting a different Azure OpenAI endpoint
* would require their own channel. In this case, the endpoint could be expressed as an additional key.
*/
protected abstract List<String> getChannelKeys();

/**
* Produce the an {@link AgentChannel} appropriate for the agent type.
* Every agent conversation, or {@link AgentChat}, will establish one or more
* {@link AgentChannel} objects according to the specific {@link Agent} type.
*
* @return An {@link AgentChannel} appropriate for the agent type.
*/
protected abstract Mono<AgentChannel> createChannelAsync();

/**
* Base class for agent builders.
*/
public abstract static class Builder<TAgent extends Agent> {

protected String id;
protected String name;
protected String description;

public Builder<? extends Agent> withId(String id) {
this.id = id;
return this;
}

public Builder<? extends Agent> withName(String name) {
this.name = name;
return this;
}

public Builder<? extends Agent> withDescription(String description) {
this.description = description;
return this;
}

public abstract TAgent build();

protected Builder() {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.agents;

import java.util.List;

import javax.annotation.Nonnull;

import com.microsoft.semantickernel.services.chatcompletion.ChatMessageContent;

import reactor.core.publisher.Mono;

/**
* Defines the communication protocol for a particular {@code Agent} type.
* An {@code Agent} provides its own {@code AgentChannel} via
* {@link Agent#createChannelAsync()}.
* @param <? extends Agent> The type of agent that this channel is associated with.
*/
public interface AgentChannel {


/**
* Receive the conversation messages. Used when joining a conversation and also during each agent interaction.
*
* @param history The chat history at the point the channel is created.
* @return A future task that completes when the conversation messages are received.
*/
Mono<Void> receiveAsync(List<ChatMessageContent<?>> history);

/**
* Perform a discrete incremental interaction between a single Agent and AgentChat.
* @param agent The agent actively interacting with the chat.
* @return Asynchronous enumeration of messages.
*/
Mono<List<ChatMessageContent<?>>> invokeAsync(@Nonnull Agent agent);

/**
* Retrieve the message history specific to this channel.
* @return Asynchronous enumeration of messages.
*/
Mono<List<ChatMessageContent<?>>> getHistoryAsync();
}
Loading