Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multi-tenancy + sdk client related changes in agents #3432

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,23 @@
return actionFuture;
}

void deleteAgent(String agentId, ActionListener<DeleteResponse> listener);
/**
* Delete agent
* @param agentId The id of the agent to delete
* @param listener a listener to be notified of the result
*/
default void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteAgent(agentId, null, actionFuture);
}

Check warning on line 485 in client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java

View check run for this annotation

Codecov / codecov/patch

client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java#L483-L485

Added lines #L483 - L485 were not covered by tests

/**
* Delete agent
* @param agentId The id of the agent to delete
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener a listener to be notified of the result
*/
void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener);

/**
* Get a list of ToolMetadata and return ActionFuture.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
}

@Override
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId);
public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener) {
MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId, tenantId);
client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,11 @@ public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener)
listener.onResponse(deleteResponse);
}

@Override
public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void getConfig(String configId, ActionListener<MLConfig> listener) {
listener.onResponse(mlConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1182,7 +1182,7 @@ public void deleteAgent() {

ArgumentCaptor<DeleteResponse> argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class);

machineLearningNodeClient.deleteAgent(agentId, deleteAgentActionListener);
machineLearningNodeClient.deleteAgent(agentId, null, deleteAgentActionListener);

verify(client).execute(eq(MLAgentDeleteAction.INSTANCE), isA(MLAgentDeleteRequest.class), any());
verify(deleteAgentActionListener).onResponse(argumentCaptor.capture());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,7 @@ public MLModel(StreamInput input) throws IOException {
if (input.readBoolean()) {
modelInterface = input.readMap(StreamInput::readString, StreamInput::readString);
}
if (streamInputVersion.onOrAfter(VERSION_2_19_0)) {
tenantId = input.readOptionalString();
}
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.ml.common.agent;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

import java.io.IOException;
Expand Down Expand Up @@ -63,6 +65,7 @@
private Instant lastUpdateTime;
private String appType;
private Boolean isHidden;
private final String tenantId;

@Builder(toBuilder = true)
public MLAgent(
Expand All @@ -76,7 +79,8 @@
Instant createdTime,
Instant lastUpdateTime,
String appType,
Boolean isHidden
Boolean isHidden,
String tenantId
) {
this.name = name;
this.type = type;
Expand All @@ -90,6 +94,7 @@
this.appType = appType;
// is_hidden field isn't going to be set by user. It will be set by the code.
this.isHidden = isHidden;
this.tenantId = tenantId;
validate();
}

Expand Down Expand Up @@ -155,6 +160,7 @@
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) {
isHidden = input.readOptionalBoolean();
}
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
validate();
}

Expand All @@ -169,7 +175,7 @@
} else {
out.writeBoolean(false);
}
if (tools != null && tools.size() > 0) {
if (tools != null && !tools.isEmpty()) {
out.writeBoolean(true);
out.writeInt(tools.size());
for (MLToolSpec tool : tools) {
Expand All @@ -178,7 +184,7 @@
} else {
out.writeBoolean(false);
}
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
out.writeBoolean(true);
out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString);
} else {
Expand All @@ -197,6 +203,9 @@
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) {
out.writeOptionalBoolean(isHidden);
}
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
}

@Override
Expand Down Expand Up @@ -236,6 +245,9 @@
if (isHidden != null) {
builder.field(MLModel.IS_HIDDEN_FIELD, isHidden);
}
if (tenantId != null) {
builder.field(TENANT_ID_FIELD, tenantId);

Check warning on line 249 in common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java#L249

Added line #L249 was not covered by tests
}
builder.endObject();
return builder;
}
Expand All @@ -260,6 +272,7 @@
Instant lastUpdateTime = null;
String appType = null;
boolean isHidden = false;
String tenantId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -305,6 +318,9 @@
if (parseHidden)
isHidden = parser.booleanValue();
break;
case TENANT_ID_FIELD:
tenantId = parser.textOrNull();
break;

Check warning on line 323 in common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java#L322-L323

Added lines #L322 - L323 were not covered by tests
default:
parser.skipChildren();
break;
Expand All @@ -324,11 +340,11 @@
.lastUpdateTime(lastUpdateTime)
.appType(appType)
.isHidden(isHidden)
.tenantId(tenantId)
.build();
}

public static MLAgent fromStream(StreamInput in) throws IOException {
MLAgent agent = new MLAgent(in);
return agent;
return new MLAgent(in);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.ml.common.agent;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

import java.io.IOException;
Expand All @@ -22,6 +24,7 @@
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;

@EqualsAndHashCode
@Getter
Expand All @@ -41,6 +44,8 @@
private Map<String, String> parameters;
private boolean includeOutputInAgentResponse;
private Map<String, String> configMap;
@Setter
private String tenantId;

@Builder(toBuilder = true)
public MLToolSpec(
Expand All @@ -49,7 +54,8 @@
String description,
Map<String, String> parameters,
boolean includeOutputInAgentResponse,
Map<String, String> configMap
Map<String, String> configMap,
String tenantId
) {
if (type == null) {
throw new IllegalArgumentException("tool type is null");
Expand All @@ -60,9 +66,11 @@
this.parameters = parameters;
this.includeOutputInAgentResponse = includeOutputInAgentResponse;
this.configMap = configMap;
this.tenantId = tenantId;
}

public MLToolSpec(StreamInput input) throws IOException {
Version streamInputVersion = input.getVersion();
type = input.readString();
name = input.readOptionalString();
description = input.readOptionalString();
Expand All @@ -73,13 +81,15 @@
if (input.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG) && input.readBoolean()) {
configMap = input.readMap(StreamInput::readString, StreamInput::readOptionalString);
}
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
}

public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeString(type);
out.writeOptionalString(name);
out.writeOptionalString(description);
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
out.writeBoolean(true);
out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeOptionalString);
} else {
Expand All @@ -94,6 +104,9 @@
out.writeBoolean(false);
}
}
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
out.writeOptionalString(tenantId);
}
}

@Override
Expand All @@ -108,13 +121,16 @@
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
builder.field(PARAMETERS_FIELD, parameters);
}
builder.field(INCLUDE_OUTPUT_IN_AGENT_RESPONSE, includeOutputInAgentResponse);
if (configMap != null && !configMap.isEmpty()) {
builder.field(CONFIG_FIELD, configMap);
}
if (tenantId != null) {
builder.field(TENANT_ID_FIELD, tenantId);

Check warning on line 132 in common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java#L132

Added line #L132 was not covered by tests
}
builder.endObject();
return builder;
}
Expand All @@ -126,6 +142,7 @@
Map<String, String> parameters = null;
boolean includeOutputInAgentResponse = false;
Map<String, String> configMap = null;
String tenantId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -151,6 +168,9 @@
case CONFIG_FIELD:
configMap = getParameterMap(parser.map());
break;
case TENANT_ID_FIELD:
tenantId = parser.textOrNull();
break;

Check warning on line 173 in common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java#L172-L173

Added lines #L172 - L173 were not covered by tests
default:
parser.skipChildren();
break;
Expand All @@ -164,11 +184,11 @@
.parameters(parameters)
.includeOutputInAgentResponse(includeOutputInAgentResponse)
.configMap(configMap)
.tenantId(tenantId)
.build();
}

public static MLToolSpec fromStream(StreamInput in) throws IOException {
MLToolSpec toolSpec = new MLToolSpec(in);
return toolSpec;
return new MLToolSpec(in);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ protected Map<String, String> createDecryptedHeaders(Map<String, String> headers
for (String key : headers.keySet()) {
decryptedHeaders.put(key, substitutor.replace(headers.get(key)));
}
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
for (String key : decryptedHeaders.keySet()) {
decryptedHeaders.put(key, substitutor.replace(decryptedHeaders.get(key)));
Expand Down Expand Up @@ -142,11 +142,11 @@ public void removeCredential() {
@Override
public String getActionEndpoint(String action, Map<String, String> parameters) {
Optional<ConnectorAction> actionEndpoint = findAction(action);
if (!actionEndpoint.isPresent()) {
if (actionEndpoint.isEmpty()) {
return null;
}
String predictEndpoint = actionEndpoint.get().getUrl();
if (parameters != null && parameters.size() > 0) {
if (parameters != null && !parameters.isEmpty()) {
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
predictEndpoint = substitutor.replace(predictEndpoint);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,7 @@ private void parseFromStream(StreamInput input) throws IOException {
if (input.readBoolean()) {
this.connectorClientConfig = new ConnectorClientConfig(input);
}
if (streamInputVersion.onOrAfter(VERSION_2_19_0)) {
this.tenantId = input.readOptionalString();
}
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import lombok.Getter;
import lombok.Setter;

@Setter
@Getter
@InputDataSet(MLInputDataType.REMOTE)
public class RemoteInferenceInputDataSet extends MLInputDataset {
Expand All @@ -45,7 +46,7 @@ public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException {
super(MLInputDataType.REMOTE);
Version streamInputVersion = streamInput.getVersion();
if (streamInput.readBoolean()) {
parameters = streamInput.readMap(s -> s.readString(), s -> s.readString());
parameters = streamInput.readMap(StreamInput::readString, StreamInput::readString);
}
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) {
if (streamInput.readBoolean()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
TextDocsInputDataSet textInputDataSet = (TextDocsInputDataSet) this.inputDataset;
List<String> docs = textInputDataSet.getDocs();
ModelResultFilter resultFilter = textInputDataSet.getResultFilter();
if (docs != null && docs.size() > 0) {
if (docs != null && !docs.isEmpty()) {
builder.field(TEXT_DOCS_FIELD, docs.toArray(new String[0]));
}
if (resultFilter != null) {
builder.field(RETURN_BYTES_FIELD, resultFilter.isReturnBytes());
builder.field(RETURN_NUMBER_FIELD, resultFilter.isReturnNumber());
List<String> targetResponse = resultFilter.getTargetResponse();
if (targetResponse != null && targetResponse.size() > 0) {
if (targetResponse != null && !targetResponse.isEmpty()) {
builder.field(TARGET_RESPONSE_FIELD, targetResponse.toArray(new String[0]));
}
List<Integer> targetPositions = resultFilter.getTargetResponsePositions();
if (targetPositions != null && targetPositions.size() > 0) {
if (targetPositions != null && !targetPositions.isEmpty()) {
builder.field(TARGET_RESPONSE_POSITIONS_FIELD, targetPositions.toArray(new Integer[0]));
}
}
Expand Down
Loading
Loading