Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
- Put edits in the edits array and not putting the latest in the root turn,
  use null edits array in all non-root turns.
- Make edits array nullable
- Add an extra onHistoryUpdate event when erasing the chat history
- Rename last_edited_time to created_time
  • Loading branch information
yrliou committed Jul 16, 2024
1 parent 9c465a2 commit 26769e4
Show file tree
Hide file tree
Showing 15 changed files with 135 additions and 109 deletions.
2 changes: 1 addition & 1 deletion browser/ai_chat/android/ai_chat_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ static void JNI_BraveLeoUtils_OpenLeoQuery(
mojom::CharacterType::HUMAN, mojom::ActionType::QUERY,
mojom::ConversationTurnVisibility::VISIBLE,
base::android::ConvertJavaStringToUTF8(query), std::nullopt, std::nullopt,
base::Time::Now(), std::vector<mojom::ConversationTurnPtr>{});
base::Time::Now(), std::nullopt);
chat_tab_helper->SubmitHumanConversationEntry(std::move(turn));
#endif
}
Expand Down
2 changes: 1 addition & 1 deletion browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ void AIChatUIPageHandler::SubmitHumanConversationEntry(
mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New(
CharacterType::HUMAN, mojom::ActionType::UNSPECIFIED,
ConversationTurnVisibility::VISIBLE, input, std::nullopt, std::nullopt,
base::Time::Now(), std::vector<mojom::ConversationTurnPtr>{});
base::Time::Now(), std::nullopt);
active_chat_tab_helper_->SubmitHumanConversationEntry(std::move(turn));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,13 @@ void ChromeAutocompleteProviderClient::OpenLeo(const std::u16string& query) {

// Send the query to the AIChat's backend.
ai_chat::mojom::ConversationTurnPtr turn =
ai_chat::mojom::ConversationTurn::New();
turn->character_type = ai_chat::mojom::CharacterType::HUMAN;
turn->action_type = ai_chat::mojom::ActionType::QUERY;
turn->visibility = ai_chat::mojom::ConversationTurnVisibility::VISIBLE;
turn->text = base::UTF16ToUTF8(query);
turn->selected_text = std::nullopt;
ai_chat::mojom::ConversationTurn::New(
ai_chat::mojom::CharacterType::HUMAN,
ai_chat::mojom::ActionType::QUERY,
ai_chat::mojom::ConversationTurnVisibility::VISIBLE,
base::UTF16ToUTF8(query) /* text */, std::nullopt /* selected_text */,
std::nullopt /* events */, base::Time::Now(),
std::nullopt /* edits */);

chat_tab_helper->SubmitHumanConversationEntry(std::move(turn));

Expand Down
69 changes: 46 additions & 23 deletions components/ai_chat/core/browser/conversation_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ void ConversationDriver::UpdateOrCreateLastAssistantEntry(
CharacterType::ASSISTANT, mojom::ActionType::RESPONSE,
ConversationTurnVisibility::VISIBLE, "", std::nullopt,
std::vector<mojom::ConversationEntryEventPtr>{}, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{});
std::nullopt);
chat_history_.push_back(std::move(entry));
}

Expand Down Expand Up @@ -870,8 +870,7 @@ void ConversationDriver::AddSubmitSelectedTextError(
const std::string& question = GetActionTypeQuestion(action_type);
mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New(
CharacterType::HUMAN, action_type, ConversationTurnVisibility::VISIBLE,
question, selected_text, std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{});
question, selected_text, std::nullopt, base::Time::Now(), std::nullopt);
AddToConversationHistory(std::move(turn));
SetAPIError(error);
}
Expand Down Expand Up @@ -928,8 +927,7 @@ void ConversationDriver::SubmitSelectedTextWithQuestion(
// Use sidebar.
mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New(
CharacterType::HUMAN, action_type, ConversationTurnVisibility::VISIBLE,
question, selected_text, std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{});
question, selected_text, std::nullopt, base::Time::Now(), std::nullopt);

SubmitHumanConversationEntry(std::move(turn));
} else {
Expand All @@ -941,6 +939,12 @@ void ConversationDriver::SubmitHumanConversationEntry(
mojom::ConversationTurnPtr turn) {
VLOG(1) << __func__;
DVLOG(4) << __func__ << ": " << turn->text;

// If there's edits, use the last one as the latest turn.
bool has_edits = turn->edits && !turn->edits->empty();
mojom::ConversationTurnPtr& latest_turn =
has_edits ? turn->edits->back() : turn;

// Decide if this entry needs to wait for one of:
// - user to be opted-in
// - conversation to be active
Expand All @@ -967,44 +971,49 @@ void ConversationDriver::SubmitHumanConversationEntry(
return;
}

DCHECK(turn->character_type == CharacterType::HUMAN);
DCHECK(latest_turn->character_type == CharacterType::HUMAN);

is_request_in_progress_ = true;
for (auto& obs : observers_) {
obs.OnAPIRequestInProgress(IsRequestInProgress());
}

// If it's a suggested question, remove it
auto found_question_iter = base::ranges::find(suggestions_, turn->text);
auto found_question_iter =
base::ranges::find(suggestions_, latest_turn->text);
if (found_question_iter != suggestions_.end()) {
suggestions_.erase(found_question_iter);
OnSuggestedQuestionsChanged();
}

// Directly modify Entry's text to remove engine-breaking substrings
engine_->SanitizeInput(turn->text);
if (turn->selected_text) {
engine_->SanitizeInput(*turn->selected_text);
if (!has_edits) { // Edits are already sanitized.
engine_->SanitizeInput(latest_turn->text);
}

if (latest_turn->selected_text) {
engine_->SanitizeInput(*latest_turn->selected_text);
}

// TODO(petemill): Tokenize the summary question so that we
// don't have to do this weird substitution.
// TODO(jocelyn): Assigning turn.type below is a workaround for now since
// callers of SubmitHumanConversationEntry mojo API currently don't have
// action_type specified.
std::string question_part = turn->text;
if (turn->action_type == mojom::ActionType::UNSPECIFIED) {
if (turn->text == l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_PAGE)) {
turn->action_type = mojom::ActionType::SUMMARIZE_PAGE;
std::string question_part = latest_turn->text;
if (latest_turn->action_type == mojom::ActionType::UNSPECIFIED) {
if (latest_turn->text ==
l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_PAGE)) {
latest_turn->action_type = mojom::ActionType::SUMMARIZE_PAGE;
question_part =
l10n_util::GetStringUTF8(IDS_AI_CHAT_QUESTION_SUMMARIZE_PAGE);
} else if (turn->text ==
} else if (latest_turn->text ==
l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_VIDEO)) {
turn->action_type = mojom::ActionType::SUMMARIZE_VIDEO;
latest_turn->action_type = mojom::ActionType::SUMMARIZE_VIDEO;
question_part =
l10n_util::GetStringUTF8(IDS_AI_CHAT_QUESTION_SUMMARIZE_VIDEO);
} else {
turn->action_type = mojom::ActionType::QUERY;
latest_turn->action_type = mojom::ActionType::QUERY;
}
}

Expand Down Expand Up @@ -1169,8 +1178,7 @@ void ConversationDriver::SubmitSummarizationRequest() {
CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_PAGE,
ConversationTurnVisibility::VISIBLE,
l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_PAGE), std::nullopt,
std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{});
std::nullopt, base::Time::Now(), std::nullopt);
SubmitHumanConversationEntry(std::move(turn));
}

Expand Down Expand Up @@ -1348,17 +1356,32 @@ void ConversationDriver::ModifyConversation(uint32_t turn_index,

std::string sanitized_input = new_text;
engine_->SanitizeInput(sanitized_input);
if (sanitized_input.empty() || sanitized_input == turn->text) {
const auto& current_text = turn->edits && !turn->edits->empty()
? turn->edits->back()->text
: turn->text;
if (sanitized_input.empty() || sanitized_input == current_text) {
return;
}

turn->edits.push_back(turn.Clone());
turn->text = sanitized_input;
turn->last_edited_time = base::Time::Now();
// turn->selected_text and turn->events are actually std::nullopt for
// editable human turns in our current implementation, just use std::nullopt
// here directly to be more explicit and avoid confusion.
auto edited_turn = mojom::ConversationTurn::New(
turn->character_type, turn->action_type, turn->visibility,
sanitized_input, std::nullopt /* selected_text */,
std::nullopt /* events */, base::Time::Now(), std::nullopt /* edits */);
if (!turn->edits) {
turn->edits.emplace();
}
turn->edits->emplace_back(std::move(edited_turn));

// Modifying human turn, drop anything after this turn_index and resubmit.
auto new_turn = std::move(chat_history_.at(turn_index));
chat_history_.erase(chat_history_.begin() + turn_index, chat_history_.end());
for (auto& obs : observers_) {
obs.OnHistoryUpdate();
}

SubmitHumanConversationEntry(std::move(new_turn));
}

Expand Down
66 changes: 33 additions & 33 deletions components/ai_chat/core/browser/conversation_driver_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,13 +330,11 @@ TEST_F(ConversationDriverUnitTest, SubmitSelectedText) {
mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT,
mojom::ConversationTurnVisibility::VISIBLE,
l10n_util::GetStringUTF8(IDS_AI_CHAT_QUESTION_SUMMARIZE_SELECTED_TEXT),
"I have spoken.", std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{}));
"I have spoken.", std::nullopt, base::Time::Now(), std::nullopt));
expected_history.push_back(mojom::ConversationTurn::New(
mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE,
mojom::ConversationTurnVisibility::VISIBLE, "This is the way.",
std::nullopt, std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{}));
std::nullopt, std::nullopt, base::Time::Now(), std::nullopt));
EXPECT_EQ(history.size(), expected_history.size());
for (size_t i = 0; i < history.size(); i++) {
EXPECT_TRUE(CompareConversationTurn(history[i], expected_history[i]));
Expand Down Expand Up @@ -375,24 +373,20 @@ TEST_F(ConversationDriverUnitTest, SubmitSelectedText) {
mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT,
mojom::ConversationTurnVisibility::VISIBLE,
l10n_util::GetStringUTF8(IDS_AI_CHAT_QUESTION_SUMMARIZE_SELECTED_TEXT),
"I have spoken.", std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{}));
"I have spoken.", std::nullopt, base::Time::Now(), std::nullopt));
expected_history2.push_back(mojom::ConversationTurn::New(
mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE,
mojom::ConversationTurnVisibility::VISIBLE, "This is the way.",
std::nullopt, std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{}));
std::nullopt, std::nullopt, base::Time::Now(), std::nullopt));
expected_history2.push_back(mojom::ConversationTurn::New(
mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT,
mojom::ConversationTurnVisibility::VISIBLE,
l10n_util::GetStringUTF8(IDS_AI_CHAT_QUESTION_SUMMARIZE_SELECTED_TEXT),
"I have spoken again.", std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{}));
"I have spoken again.", std::nullopt, base::Time::Now(), std::nullopt));
expected_history2.push_back(mojom::ConversationTurn::New(
mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE,
mojom::ConversationTurnVisibility::VISIBLE, "This is the way.",
std::nullopt, std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{}));
std::nullopt, std::nullopt, base::Time::Now(), std::nullopt));
EXPECT_EQ(history2.size(), expected_history2.size());
for (size_t i = 0; i < history2.size(); i++) {
EXPECT_TRUE(CompareConversationTurn(history2[i], expected_history2[i]));
Expand Down Expand Up @@ -729,49 +723,55 @@ TEST_F(ConversationDriverUnitTest, ModifyConversation) {
}));

// Setup history for testing.
auto last_edited_time1 = base::Time::Now();
auto created_time1 = base::Time::Now();
std::vector<mojom::ConversationTurnPtr> history;
history.push_back(mojom::ConversationTurn::New(
mojom::CharacterType::HUMAN, mojom::ActionType::QUERY,
mojom::ConversationTurnVisibility::VISIBLE, "prompt1", std::nullopt,
std::nullopt, last_edited_time1,
std::vector<mojom::ConversationTurnPtr>{}));
std::nullopt, created_time1, std::nullopt));
history.push_back(mojom::ConversationTurn::New(
mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE,
mojom::ConversationTurnVisibility::VISIBLE, "answer1", std::nullopt,
std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{}));
std::nullopt, base::Time::Now(), std::nullopt));
conversation_driver_->SetChatHistoryForTesting(std::move(history));

// Modify an entry for the first time.
conversation_driver_->ModifyConversation(0, "prompt2");
const auto& conversation_history =
conversation_driver_->GetConversationHistory();
ASSERT_EQ(conversation_history.size(), 1u);
EXPECT_EQ(conversation_history[0]->text, "prompt2");
EXPECT_NE(conversation_history[0]->last_edited_time, last_edited_time1);
ASSERT_EQ(conversation_history[0]->edits.size(), 1u);
EXPECT_EQ(conversation_history[0]->edits.at(0)->text, "prompt1");
EXPECT_EQ(conversation_history[0]->edits.at(0)->last_edited_time,
last_edited_time1);
EXPECT_EQ(conversation_history[0]->text, "prompt1");
EXPECT_EQ(conversation_history[0]->created_time, created_time1);

ASSERT_TRUE(conversation_history[0]->edits);
ASSERT_EQ(conversation_history[0]->edits->size(), 1u);
EXPECT_EQ(conversation_history[0]->edits->at(0)->text, "prompt2");
EXPECT_NE(conversation_history[0]->edits->at(0)->created_time, created_time1);
EXPECT_FALSE(conversation_history[0]->edits->at(0)->edits);

WaitForOnEngineCompletionComplete();
ASSERT_EQ(conversation_history.size(), 2u);
EXPECT_EQ(conversation_history[1]->text, "new answer");

auto last_edited_time2 = conversation_history[0]->last_edited_time;
auto created_time2 = conversation_history[0]->edits->at(0)->created_time;

// Modify the same entry again.
conversation_driver_->ModifyConversation(0, "prompt3");
ASSERT_EQ(conversation_history.size(), 1u);
EXPECT_EQ(conversation_history[0]->text, "prompt3");
EXPECT_NE(conversation_history[0]->last_edited_time, last_edited_time2);
ASSERT_EQ(conversation_history[0]->edits.size(), 2u);
EXPECT_EQ(conversation_history[0]->edits.at(0)->text, "prompt1");
EXPECT_EQ(conversation_history[0]->edits.at(0)->last_edited_time,
last_edited_time1);
EXPECT_EQ(conversation_history[0]->edits.at(1)->text, "prompt2");
EXPECT_EQ(conversation_history[0]->edits.at(1)->last_edited_time,
last_edited_time2);
EXPECT_EQ(conversation_history[0]->text, "prompt1");
EXPECT_EQ(conversation_history[0]->created_time, created_time1);

ASSERT_TRUE(conversation_history[0]->edits);
ASSERT_EQ(conversation_history[0]->edits->size(), 2u);
EXPECT_EQ(conversation_history[0]->edits->at(0)->text, "prompt2");
EXPECT_EQ(conversation_history[0]->edits->at(0)->created_time, created_time2);
EXPECT_FALSE(conversation_history[0]->edits->at(0)->edits);

EXPECT_EQ(conversation_history[0]->edits->at(1)->text, "prompt3");
EXPECT_NE(conversation_history[0]->edits->at(1)->created_time, created_time1);
EXPECT_NE(conversation_history[0]->edits->at(1)->created_time, created_time2);
EXPECT_FALSE(conversation_history[0]->edits->at(1)->edits);

WaitForOnEngineCompletionComplete();
ASSERT_EQ(conversation_history.size(), 2u);
EXPECT_EQ(conversation_history[1]->text, "new answer");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,11 @@ TEST_F(EngineConsumerClaudeUnitTest, TestGenerateAssistantResponse) {
mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT,
mojom::ConversationTurnVisibility::VISIBLE,
"Which show is this catchphrase from?", "I have spoken.", std::nullopt,
base::Time::Now(), std::vector<mojom::ConversationTurnPtr>{}));
base::Time::Now(), std::nullopt));
history.push_back(mojom::ConversationTurn::New(
mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE,
mojom::ConversationTurnVisibility::VISIBLE, "The Mandalorian.",
std::nullopt, std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{}));
std::nullopt, std::nullopt, base::Time::Now(), std::nullopt));
auto* mock_remote_completion_client = GetMockRemoteCompletionClient();
std::string prompt_before_time_and_date =
"\n\nHuman: Here is the text of a web page in <page> tags:\n<page>\nThis "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,17 +215,16 @@ TEST_F(EngineConsumerConversationAPIUnitTest,
mojom::CharacterType::HUMAN, mojom::ActionType::QUERY,
mojom::ConversationTurnVisibility::VISIBLE,
"Which show is this catchphrase from?", "I have spoken.", std::nullopt,
base::Time::Now(), std::vector<mojom::ConversationTurnPtr>{}));
base::Time::Now(), std::nullopt));
history.push_back(mojom::ConversationTurn::New(
mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE,
mojom::ConversationTurnVisibility::VISIBLE, "The Mandalorian.",
std::nullopt, std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{}));
std::nullopt, std::nullopt, base::Time::Now(), std::nullopt));
history.push_back(mojom::ConversationTurn::New(
mojom::CharacterType::HUMAN, mojom::ActionType::RESPONSE,
mojom::ConversationTurnVisibility::VISIBLE,
"Is it related to a broader series?", std::nullopt, std::nullopt,
base::Time::Now(), std::vector<mojom::ConversationTurnPtr>{}));
base::Time::Now(), std::nullopt));
std::string expected_events = R"([
{"role": "user", "type": "pageText", "content": "This is my page. I have spoken."},
{"role": "user", "type": "pageExcerpt", "content": "I have spoken."},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,11 @@ TEST_F(EngineConsumerLlamaUnitTest, TestGenerateAssistantResponse) {
mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT,
mojom::ConversationTurnVisibility::VISIBLE,
"Which show is this catchphrase from?", "This is the way.", std::nullopt,
base::Time::Now(), std::vector<mojom::ConversationTurnPtr>{}));
base::Time::Now(), std::nullopt));
history.push_back(mojom::ConversationTurn::New(
mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE,
mojom::ConversationTurnVisibility::VISIBLE, "The Mandalorian.",
std::nullopt, std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{}));
std::nullopt, std::nullopt, base::Time::Now(), std::nullopt));
auto* mock_remote_completion_client =
static_cast<MockRemoteCompletionClient*>(engine_->GetAPIForTesting());
std::string prompt_before_time_and_date =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,12 @@ TEST_F(EngineConsumerOAIUnitTest, TestGenerateAssistantResponse) {
history.push_back(mojom::ConversationTurn::New(
mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT,
mojom::ConversationTurnVisibility::VISIBLE, human_input, selected_text,
std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{}));
std::nullopt, base::Time::Now(), std::nullopt));

history.push_back(mojom::ConversationTurn::New(
mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE,
mojom::ConversationTurnVisibility::VISIBLE, assistant_input, std::nullopt,
std::nullopt, base::Time::Now(),
std::vector<mojom::ConversationTurnPtr>{}));
std::nullopt, base::Time::Now(), std::nullopt));

std::string expected_human_input =
base::StrCat({base::ReplaceStringPlaceholders(
Expand Down
10 changes: 8 additions & 2 deletions components/ai_chat/core/common/mojom/ai_chat.mojom
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,14 @@ struct ConversationTurn {
// be important.
array<ConversationEntryEvent>? events;

mojo_base.mojom.Time last_edited_time;
array<ConversationTurn> edits;
mojo_base.mojom.Time created_time;

// Edits to this turn, sorted by time of creation, with the most recent edit
// at the end of the array. When this appears, the value of |text| field is
// the original text of the turn, the last entry of this array should be used
// instead of the original turn text when submitting the turn to the AI
// engine or displaying the most recent text to users.
array<ConversationTurn>? edits;
};

// Represents an AI engine model choice, usually for the user to choose for a
Expand Down
Loading

0 comments on commit 26769e4

Please sign in to comment.