Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Leo): Support modifying user prompts #24558

Merged
merged 1 commit into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions browser/ai_chat/android/ai_chat_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "base/android/jni_android.h"
#include "base/android/jni_string.h"
#include "base/time/time.h"
#include "brave/build/android/jni_headers/BraveLeoUtils_jni.h"
#include "brave/components/ai_chat/core/common/buildflags/buildflags.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
Expand All @@ -30,8 +31,8 @@ static void JNI_BraveLeoUtils_OpenLeoQuery(
mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New(
mojom::CharacterType::HUMAN, mojom::ActionType::QUERY,
mojom::ConversationTurnVisibility::VISIBLE,
base::android::ConvertJavaStringToUTF8(query), std::nullopt,
std::nullopt);
base::android::ConvertJavaStringToUTF8(query), std::nullopt, std::nullopt,
base::Time::Now(), std::nullopt);
chat_tab_helper->SubmitHumanConversationEntry(std::move(turn));
#endif
}
Expand Down
11 changes: 10 additions & 1 deletion browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "base/notreached.h"
#include "base/strings/utf_string_conversions.h"
#include "base/time/time.h"
#include "brave/browser/ui/side_panel/ai_chat/ai_chat_side_panel_utils.h"
#include "brave/components/ai_chat/core/browser/constants.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h"
Expand Down Expand Up @@ -141,7 +142,8 @@ void AIChatUIPageHandler::SubmitHumanConversationEntry(

mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New(
CharacterType::HUMAN, mojom::ActionType::UNSPECIFIED,
ConversationTurnVisibility::VISIBLE, input, std::nullopt, std::nullopt);
ConversationTurnVisibility::VISIBLE, input, std::nullopt, std::nullopt,
base::Time::Now(), std::nullopt);
active_chat_tab_helper_->SubmitHumanConversationEntry(std::move(turn));
}

Expand Down Expand Up @@ -516,4 +518,11 @@ void AIChatUIPageHandler::OnGetPremiumStatus(
}
}

// Forwards a request to modify the conversation turn at |turn_index| with
// |new_text| to the active chat tab helper. Does nothing when there is no
// active chat tab helper.
void AIChatUIPageHandler::ModifyConversation(uint32_t turn_index,
                                             const std::string& new_text) {
  if (!active_chat_tab_helper_) {
    return;
  }

  active_chat_tab_helper_->ModifyConversation(turn_index, new_text);
}

} // namespace ai_chat
2 changes: 2 additions & 0 deletions browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ class AIChatUIPageHandler : public ai_chat::mojom::PageHandler,
void ClosePanel() override;
void GetActionMenuList(GetActionMenuListCallback callback) override;
void OpenModelSupportUrl() override;
void ModifyConversation(uint32_t turn_index,
const std::string& new_text) override;

// content::WebContentsObserver:
void OnVisibilityChanged(content::Visibility visibility) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,13 @@ void ChromeAutocompleteProviderClient::OpenLeo(const std::u16string& query) {

// Send the query to the AIChat's backend.
ai_chat::mojom::ConversationTurnPtr turn =
ai_chat::mojom::ConversationTurn::New();
turn->character_type = ai_chat::mojom::CharacterType::HUMAN;
turn->action_type = ai_chat::mojom::ActionType::QUERY;
turn->visibility = ai_chat::mojom::ConversationTurnVisibility::VISIBLE;
turn->text = base::UTF16ToUTF8(query);
turn->selected_text = std::nullopt;
ai_chat::mojom::ConversationTurn::New(
ai_chat::mojom::CharacterType::HUMAN,
ai_chat::mojom::ActionType::QUERY,
ai_chat::mojom::ConversationTurnVisibility::VISIBLE,
base::UTF16ToUTF8(query) /* text */, std::nullopt /* selected_text */,
std::nullopt /* events */, base::Time::Now(),
std::nullopt /* edits */);

chat_tab_helper->SubmitHumanConversationEntry(std::move(turn));

Expand Down
3 changes: 3 additions & 0 deletions components/ai_chat/core/browser/constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ base::span<const webui::LocalizedString> GetLocalizedStrings() {
{"feedbackPremiumNote", IDS_CHAT_UI_FEEDBACK_PREMIUM_NOTE},
{"submitButtonLabel", IDS_CHAT_UI_SUBMIT_BUTTON_LABEL},
{"cancelButtonLabel", IDS_CHAT_UI_CANCEL_BUTTON_LABEL},
{"saveButtonLabel", IDS_CHAT_UI_SAVE_BUTTON_LABEL},
{"editedLabel", IDS_CHAT_UI_EDITED_LABEL},
{"editButtonLabel", IDS_CHAT_UI_EDIT_BUTTON_LABEL},
{"optionNotHelpful", IDS_CHAT_UI_OPTION_NOT_HELPFUL},
{"optionIncorrect", IDS_CHAT_UI_OPTION_INCORRECT},
{"optionUnsafeHarmful", IDS_CHAT_UI_OPTION_UNSAFE_HARMFUL},
Expand Down
88 changes: 72 additions & 16 deletions components/ai_chat/core/browser/conversation_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "base/strings/strcat.h"
#include "base/strings/string_util.h"
#include "base/strings/utf_string_conversions.h"
#include "base/time/time.h"
#include "base/values.h"
#include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h"
#include "brave/components/ai_chat/core/browser/ai_chat_metrics.h"
Expand Down Expand Up @@ -416,7 +417,8 @@ void ConversationDriver::UpdateOrCreateLastAssistantEntry(
mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New(
CharacterType::ASSISTANT, mojom::ActionType::RESPONSE,
ConversationTurnVisibility::VISIBLE, "", std::nullopt,
std::vector<mojom::ConversationEntryEventPtr>{});
std::vector<mojom::ConversationEntryEventPtr>{}, base::Time::Now(),
std::nullopt);
chat_history_.push_back(std::move(entry));
}

Expand Down Expand Up @@ -868,7 +870,7 @@ void ConversationDriver::AddSubmitSelectedTextError(
const std::string& question = GetActionTypeQuestion(action_type);
mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New(
CharacterType::HUMAN, action_type, ConversationTurnVisibility::VISIBLE,
question, selected_text, std::nullopt);
question, selected_text, std::nullopt, base::Time::Now(), std::nullopt);
AddToConversationHistory(std::move(turn));
SetAPIError(error);
}
Expand Down Expand Up @@ -925,7 +927,7 @@ void ConversationDriver::SubmitSelectedTextWithQuestion(
// Use sidebar.
mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New(
CharacterType::HUMAN, action_type, ConversationTurnVisibility::VISIBLE,
question, selected_text, std::nullopt);
question, selected_text, std::nullopt, base::Time::Now(), std::nullopt);

SubmitHumanConversationEntry(std::move(turn));
} else {
Expand All @@ -937,6 +939,12 @@ void ConversationDriver::SubmitHumanConversationEntry(
mojom::ConversationTurnPtr turn) {
VLOG(1) << __func__;
DVLOG(4) << __func__ << ": " << turn->text;

// If there are edits, use the last one as the latest turn.
bool has_edits = turn->edits && !turn->edits->empty();
mojom::ConversationTurnPtr& latest_turn =
has_edits ? turn->edits->back() : turn;

// Decide if this entry needs to wait for one of:
// - user to be opted-in
// - conversation to be active
Expand All @@ -963,44 +971,49 @@ void ConversationDriver::SubmitHumanConversationEntry(
return;
}

DCHECK(turn->character_type == CharacterType::HUMAN);
DCHECK(latest_turn->character_type == CharacterType::HUMAN);

is_request_in_progress_ = true;
for (auto& obs : observers_) {
obs.OnAPIRequestInProgress(IsRequestInProgress());
}

// If it's a suggested question, remove it
auto found_question_iter = base::ranges::find(suggestions_, turn->text);
auto found_question_iter =
base::ranges::find(suggestions_, latest_turn->text);
if (found_question_iter != suggestions_.end()) {
suggestions_.erase(found_question_iter);
OnSuggestedQuestionsChanged();
}

// Directly modify Entry's text to remove engine-breaking substrings
engine_->SanitizeInput(turn->text);
if (turn->selected_text) {
engine_->SanitizeInput(*turn->selected_text);
if (!has_edits) { // Edits are already sanitized.
engine_->SanitizeInput(latest_turn->text);
}

if (latest_turn->selected_text) {
engine_->SanitizeInput(*latest_turn->selected_text);
}

// TODO(petemill): Tokenize the summary question so that we
// don't have to do this weird substitution.
// TODO(jocelyn): Assigning action_type below is a workaround for now since
// callers of SubmitHumanConversationEntry mojo API currently don't have
// action_type specified.
std::string question_part = turn->text;
if (turn->action_type == mojom::ActionType::UNSPECIFIED) {
if (turn->text == l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_PAGE)) {
turn->action_type = mojom::ActionType::SUMMARIZE_PAGE;
std::string question_part = latest_turn->text;
if (latest_turn->action_type == mojom::ActionType::UNSPECIFIED) {
if (latest_turn->text ==
l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_PAGE)) {
latest_turn->action_type = mojom::ActionType::SUMMARIZE_PAGE;
question_part =
l10n_util::GetStringUTF8(IDS_AI_CHAT_QUESTION_SUMMARIZE_PAGE);
} else if (turn->text ==
} else if (latest_turn->text ==
l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_VIDEO)) {
turn->action_type = mojom::ActionType::SUMMARIZE_VIDEO;
latest_turn->action_type = mojom::ActionType::SUMMARIZE_VIDEO;
question_part =
l10n_util::GetStringUTF8(IDS_AI_CHAT_QUESTION_SUMMARIZE_VIDEO);
} else {
turn->action_type = mojom::ActionType::QUERY;
latest_turn->action_type = mojom::ActionType::QUERY;
}
}

Expand Down Expand Up @@ -1165,7 +1178,7 @@ void ConversationDriver::SubmitSummarizationRequest() {
CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_PAGE,
ConversationTurnVisibility::VISIBLE,
l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_PAGE), std::nullopt,
std::nullopt);
std::nullopt, base::Time::Now(), std::nullopt);
SubmitHumanConversationEntry(std::move(turn));
}

Expand Down Expand Up @@ -1329,4 +1342,47 @@ void ConversationDriver::SendFeedback(
: std::nullopt,
std::move(on_complete));
}

// Records |new_text| as a new edit of the human turn at |turn_index| and
// resubmits that turn.
//
// Returns without changing anything when:
//  - |turn_index| is not a valid index into chat_history_,
//  - the turn at |turn_index| is an ASSISTANT turn, or
//  - the sanitized |new_text| is empty or equal to the current text of the
//    turn (the text of its most recent edit if it has any, otherwise its
//    original text).
//
// Otherwise a new ConversationTurn carrying the sanitized text and the
// current time is appended to the edits of the turn, the turn and every entry
// after it are removed from chat_history_, observers are notified through
// OnHistoryUpdate(), and the turn is resubmitted through
// SubmitHumanConversationEntry().
void ConversationDriver::ModifyConversation(uint32_t turn_index,
const std::string& new_text) {
// Ignore indices outside the current history.
if (turn_index >= chat_history_.size()) {
return;
}

auto& turn = chat_history_.at(turn_index);
if (turn->character_type == CharacterType::ASSISTANT) { // not supported yet
return;
}

// Sanitize a copy so the comparison below is made against sanitized text.
std::string sanitized_input = new_text;
engine_->SanitizeInput(sanitized_input);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SanitizeInput is already called by SubmitHumanConversationEntry. Do you think we need to call it twice for the purpose of being able to ignore it if it matches the existing turn text?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Intentionally called here so we can check here, I've added a check to avoid another sanitization happening in 26769e4.

const auto& current_text = turn->edits && !turn->edits->empty()
? turn->edits->back()->text
: turn->text;
if (sanitized_input.empty() || sanitized_input == current_text) {
return;
}

// turn->selected_text and turn->events are actually std::nullopt for
// editable human turns in our current implementation, just use std::nullopt
// here directly to be more explicit and avoid confusion.
auto edited_turn = mojom::ConversationTurn::New(
turn->character_type, turn->action_type, turn->visibility,
sanitized_input, std::nullopt /* selected_text */,
std::nullopt /* events */, base::Time::Now(), std::nullopt /* edits */);
if (!turn->edits) {
turn->edits.emplace();
}
turn->edits->emplace_back(std::move(edited_turn));

// Modifying human turn, drop anything after this turn_index and resubmit.
auto new_turn = std::move(chat_history_.at(turn_index));
chat_history_.erase(chat_history_.begin() + turn_index, chat_history_.end());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're modifying chat_history_ then technically shouldn't we let the observers know? Or would you rather wait for the edited turn to appear in the history after calling SubmitHumanConversationEntry?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Originally I didn't do so since it would be called in SubmitHumanConversationEntry, but I'm fine having another one here after we clear it. Added in 26769e4.

for (auto& obs : observers_) {
obs.OnHistoryUpdate();
}

SubmitHumanConversationEntry(std::move(new_turn));
}

} // namespace ai_chat
7 changes: 7 additions & 0 deletions components/ai_chat/core/browser/conversation_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ class ConversationDriver : public ModelService::Observer {
GeneratedTextCallback received_callback,
EngineConsumer::GenerationCompletedCallback completed_callback);

void ModifyConversation(uint32_t turn_index, const std::string& new_text);

void RateMessage(bool is_liked,
uint32_t turn_id,
mojom::PageHandler::RateMessageCallback callback);
Expand All @@ -169,6 +171,11 @@ class ConversationDriver : public ModelService::Observer {
}
EngineConsumer* GetEngineForTesting() { return engine_.get(); }

// Test-only hook: makes |history| the current chat history. The previous
// contents of chat_history_ end up in the by-value parameter and are released
// together with it.
void SetChatHistoryForTesting(std::vector<mojom::ConversationTurnPtr> history) {
  chat_history_.swap(history);
}

protected:
virtual GURL GetPageURL() const = 0;
virtual std::u16string GetPageTitle() const = 0;
Expand Down
Loading
Loading