Skip to content

Commit

Permalink
multi-tenancy + sdk client related changes in agents
Browse files Browse the repository at this point in the history
Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os committed Jan 24, 2025
1 parent af96fe0 commit 6b2ef3d
Show file tree
Hide file tree
Showing 50 changed files with 1,344 additions and 581 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,23 @@ default ActionFuture<DeleteResponse> deleteAgent(String agentId) {
return actionFuture;
}

void deleteAgent(String agentId, ActionListener<DeleteResponse> listener);
/**
* Delete agent
* @param agentId The id of the agent to delete
* @param listener a listener to be notified of the result
*/
default void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteAgent(agentId, null, actionFuture);
}

/**
* Delete agent
* @param agentId The id of the agent to delete
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener a listener to be notified of the result
*/
void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener);

/**
* Get a list of ToolMetadata and return ActionFuture.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
}

@Override
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId);
public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener) {
MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId, tenantId);
client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,11 @@ public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener)
listener.onResponse(deleteResponse);
}

@Override
public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void getConfig(String configId, ActionListener<MLConfig> listener) {
listener.onResponse(mlConfig);
Expand Down
4 changes: 1 addition & 3 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,7 @@ public MLModel(StreamInput input) throws IOException {
if (input.readBoolean()) {
modelInterface = input.readMap(StreamInput::readString, StreamInput::readString);
}
if (streamInputVersion.onOrAfter(VERSION_2_19_0)) {
tenantId = input.readOptionalString();
}
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
}
}

Expand Down
26 changes: 21 additions & 5 deletions common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.ml.common.agent;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

import java.io.IOException;
Expand Down Expand Up @@ -63,6 +65,7 @@ public class MLAgent implements ToXContentObject, Writeable {
private Instant lastUpdateTime;
private String appType;
private Boolean isHidden;
private final String tenantId;

@Builder(toBuilder = true)
public MLAgent(
Expand All @@ -76,7 +79,8 @@ public MLAgent(
Instant createdTime,
Instant lastUpdateTime,
String appType,
Boolean isHidden
Boolean isHidden,
String tenantId
) {
this.name = name;
this.type = type;
Expand All @@ -90,6 +94,7 @@ public MLAgent(
this.appType = appType;
// is_hidden field isn't going to be set by user. It will be set by the code.
this.isHidden = isHidden;
this.tenantId = tenantId;
validate();
}

Expand Down Expand Up @@ -155,6 +160,7 @@ public MLAgent(StreamInput input) throws IOException {
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) {
isHidden = input.readOptionalBoolean();
}
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
validate();
}

Expand All @@ -169,7 +175,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
if (tools != null && tools.size() > 0) {
if (tools != null && !tools.isEmpty()) {
out.writeBoolean(true);
out.writeInt(tools.size());
for (MLToolSpec tool : tools) {
Expand All @@ -178,7 +184,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
out.writeBoolean(true);
out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString);
} else {
Expand All @@ -197,6 +203,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) {
out.writeOptionalBoolean(isHidden);
}
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
}

@Override
Expand Down Expand Up @@ -236,6 +245,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (isHidden != null) {
builder.field(MLModel.IS_HIDDEN_FIELD, isHidden);
}
if (tenantId != null) {
builder.field(TENANT_ID_FIELD, tenantId);
}
builder.endObject();
return builder;
}
Expand All @@ -260,6 +272,7 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid
Instant lastUpdateTime = null;
String appType = null;
boolean isHidden = false;
String tenantId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -305,6 +318,9 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid
if (parseHidden)
isHidden = parser.booleanValue();
break;
case TENANT_ID_FIELD:
tenantId = parser.textOrNull();
break;
default:
parser.skipChildren();
break;
Expand All @@ -324,11 +340,11 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid
.lastUpdateTime(lastUpdateTime)
.appType(appType)
.isHidden(isHidden)
.tenantId(tenantId)
.build();
}

public static MLAgent fromStream(StreamInput in) throws IOException {
MLAgent agent = new MLAgent(in);
return agent;
return new MLAgent(in);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.ml.common.agent;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

import java.io.IOException;
Expand All @@ -22,6 +24,7 @@
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;

@EqualsAndHashCode
@Getter
Expand All @@ -41,6 +44,8 @@ public class MLToolSpec implements ToXContentObject {
private Map<String, String> parameters;
private boolean includeOutputInAgentResponse;
private Map<String, String> configMap;
@Setter
private String tenantId;

@Builder(toBuilder = true)
public MLToolSpec(
Expand All @@ -49,7 +54,8 @@ public MLToolSpec(
String description,
Map<String, String> parameters,
boolean includeOutputInAgentResponse,
Map<String, String> configMap
Map<String, String> configMap,
String tenantId
) {
if (type == null) {
throw new IllegalArgumentException("tool type is null");
Expand All @@ -60,9 +66,11 @@ public MLToolSpec(
this.parameters = parameters;
this.includeOutputInAgentResponse = includeOutputInAgentResponse;
this.configMap = configMap;
this.tenantId = tenantId;
}

public MLToolSpec(StreamInput input) throws IOException {
Version streamInputVersion = input.getVersion();
type = input.readString();
name = input.readOptionalString();
description = input.readOptionalString();
Expand All @@ -73,13 +81,15 @@ public MLToolSpec(StreamInput input) throws IOException {
if (input.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG) && input.readBoolean()) {
configMap = input.readMap(StreamInput::readString, StreamInput::readOptionalString);
}
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
}

public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeString(type);
out.writeOptionalString(name);
out.writeOptionalString(description);
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
out.writeBoolean(true);
out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString);
} else {
Expand All @@ -94,6 +104,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
}
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
}

@Override
Expand All @@ -108,13 +121,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
builder.field(PARAMETERS_FIELD, parameters);
}
builder.field(INCLUDE_OUTPUT_IN_AGENT_RESPONSE, includeOutputInAgentResponse);
if (configMap != null && !configMap.isEmpty()) {
builder.field(CONFIG_FIELD, configMap);
}
if (tenantId != null) {
builder.field(TENANT_ID_FIELD, tenantId);
}
builder.endObject();
return builder;
}
Expand All @@ -126,6 +142,7 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
Map<String, String> parameters = null;
boolean includeOutputInAgentResponse = false;
Map<String, String> configMap = null;
String tenantId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -151,6 +168,9 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
case CONFIG_FIELD:
configMap = getParameterMap(parser.map());
break;
case TENANT_ID_FIELD:
tenantId = parser.textOrNull();
break;
default:
parser.skipChildren();
break;
Expand All @@ -164,11 +184,11 @@ public static MLToolSpec parse(XContentParser parser) throws IOException {
.parameters(parameters)
.includeOutputInAgentResponse(includeOutputInAgentResponse)
.configMap(configMap)
.tenantId(tenantId)
.build();
}

public static MLToolSpec fromStream(StreamInput in) throws IOException {
MLToolSpec toolSpec = new MLToolSpec(in);
return toolSpec;
return new MLToolSpec(in);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ protected Map<String, String> createDecryptedHeaders(Map<String, String> headers
for (String key : headers.keySet()) {
decryptedHeaders.put(key, substitutor.replace(headers.get(key)));
}
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
for (String key : decryptedHeaders.keySet()) {
decryptedHeaders.put(key, substitutor.replace(decryptedHeaders.get(key)));
Expand Down Expand Up @@ -142,11 +142,11 @@ public void removeCredential() {
@Override
public String getActionEndpoint(String action, Map<String, String> parameters) {
Optional<ConnectorAction> actionEndpoint = findAction(action);
if (!actionEndpoint.isPresent()) {
if (actionEndpoint.isEmpty()) {
return null;
}
String predictEndpoint = actionEndpoint.get().getUrl();
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
predictEndpoint = substitutor.replace(predictEndpoint);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,7 @@ private void parseFromStream(StreamInput input) throws IOException {
if (input.readBoolean()) {
this.connectorClientConfig = new ConnectorClientConfig(input);
}
if (streamInputVersion.onOrAfter(VERSION_2_19_0)) {
this.tenantId = input.readOptionalString();
}
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import lombok.Getter;
import lombok.Setter;

@Setter
@Getter
@InputDataSet(MLInputDataType.REMOTE)
public class RemoteInferenceInputDataSet extends MLInputDataset {
Expand All @@ -45,7 +46,7 @@ public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
super(MLInputDataType.REMOTE);
Version streamInputVersion = streamInput.getVersion();
if (streamInput.readBoolean()) {
parameters = streamInput.readMap(s -> s.readString(), s -> s.readString());
parameters = streamInput.readMap(StreamInput::readString, StreamInput::readString);
}
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
if (streamInput.readBoolean()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
TextDocsInputDataSet textInputDataSet = (TextDocsInputDataSet) this.inputDataset;
List<String> docs = textInputDataSet.getDocs();
ModelResultFilter resultFilter = textInputDataSet.getResultFilter();
if (docs != null && docs.size() > 0) {
if (docs != null && !docs.isEmpty()) {
builder.field(TEXT_DOCS_FIELD, docs.toArray(new String[0]));
}
if (resultFilter != null) {
builder.field(RETURN_BYTES_FIELD, resultFilter.isReturnBytes());
builder.field(RETURN_NUMBER_FIELD, resultFilter.isReturnNumber());
List<String> targetResponse = resultFilter.getTargetResponse();
if (targetResponse != null && targetResponse.size() > 0) {
if (targetResponse != null && !targetResponse.isEmpty()) {
builder.field(TARGET_RESPONSE_FIELD, targetResponse.toArray(new String[0]));
}
List<Integer> targetPositions = resultFilter.getTargetResponsePositions();
if (targetPositions != null && targetPositions.size() > 0) {
if (targetPositions != null && !targetPositions.isEmpty()) {
builder.field(TARGET_RESPONSE_POSITIONS_FIELD, targetPositions.toArray(new Integer[0]));
}
}
Expand Down
Loading

0 comments on commit 6b2ef3d

Please sign in to comment.