From 436d8c72e6a17b0638415f4d051cd98bf43fa4e1 Mon Sep 17 00:00:00 2001 From: Nick von Pentz <12549658+nvonpentz@users.noreply.github.com> Date: Tue, 23 Jan 2024 00:39:15 -0500 Subject: [PATCH] Add support for Brave Services Key V2 (uplift to 1.62.x) (#21688) Add support for Brave Services Key V2 (#21542) * Add support for Brave Services Key V2 And use for AI chat * Refactor * Move logic to brave_service_keys * Generalize logic such that callers can sign over multiple headers * Separate digest header generation * Rename AI_CHAT_SERVICE_KEY -> SERVICE_KEY_AI_CHAT * Separate SERVICE_KEY_AI_CHAT from signing logic * * Switch from base::span<> to const std::vector<>& * Move unused header to .cc file * Break apart functions and add more unit tests * Update GetAuthorizationHeaders * Pass the URL, HTTP method, full list of headers, and a list of headers to actually be signed to GetAuthorizationHeaders * Instead of using std::vector> for the list of headers, instead use base::flat_map since that matches the headers passed to the APIRequestHelper * Enforce header ordering specified by headers_to_sign * Generate (request-target) header if supplied, and add test from spec * Pass url and method to CreateSignatureString This way, (request-target) can be generated inside there and thus be unit tested. Adjust unit tests. * Add VLOG(1) when header to sign does not exist Also DCHECK(false) for good measure. * Add SERVICE_KEY_AI_CHAT and KEY_ID to config.js This way they can be sourced from .env. * Update tests * Link to specific section test vectors are from * Remove the "(created)" header from headers_to_sign (it's not included in the test vector) * Use //crypto instad of //crypto:crypto in components/brave_service_keys/BUILD.gn * Use constexpr for http method constant * Don't use a reference to the digest header * Use NOTREACHED_NORETURN() instead of DCHECK and VLOG(1) * Use CHECK for url in GetAuthorizationHeader Brave Server URLs should always be defined * Uncomment base/flat_map.h include in unittest * Add comment explaining KEY_ID * Add comments explaining functions in service_key_utils * Add is_official_build check for service_key_ai_chat * Update header constants * Use existing constants for kDigest and kAuthorization * Change kRequestTarget to kRequestTargetHeader * Rename service_key_utils.* -> brave_service_key_utils.* * nit: use base::StrCat and .append() * Make headers a const& in CreateSignatureString * Rename KEY_ID -> BRAVE_SERVICES_KEY_ID * Fix formatting of string * Rename SERVICE_KEY_AI_CHAT-> SERVICE_KEY_AICHAT * Apply Jenkinsfile patch * Revert "Apply Jenkinsfile patch" This reverts commit 513bfdae0a400bd645322e1f4585a11d33d686a8. --- build/commands/lib/config.js | 4 + components/ai_chat/core/browser/BUILD.gn | 5 +- .../engine/remote_completion_client.cc | 30 +++-- .../ai_chat/core/common/buildflags/BUILD.gn | 13 +- components/brave_service_keys/BUILD.gn | 51 ++++++++ .../brave_service_key_utils.cc | 110 ++++++++++++++++ .../brave_service_key_utils.h | 44 +++++++ .../brave_service_key_utils_unittest.cc | 123 ++++++++++++++++++ test/BUILD.gn | 1 + 9 files changed, 369 insertions(+), 12 deletions(-) create mode 100644 components/brave_service_keys/BUILD.gn create mode 100644 components/brave_service_keys/brave_service_key_utils.cc create mode 100644 components/brave_service_keys/brave_service_key_utils.h create mode 100644 components/brave_service_keys/brave_service_key_utils_unittest.cc diff --git a/build/commands/lib/config.js b/build/commands/lib/config.js index d06007bc2733..5afa27140fbf 100644 --- a/build/commands/lib/config.js +++ b/build/commands/lib/config.js @@ -234,6 +234,8 @@ const Config = function () { this.gomaServerHost.endsWith('.brave.com') || this.rbeService.includes('.brave.com:') || this.rbeService.includes('.engflow.com:') + this.brave_services_key_id = getNPMConfig(['brave_services_key_id']) || '' + this.service_key_aichat = getNPMConfig(['service_key_aichat']) || '' } Config.prototype.isReleaseBuild = function () { @@ -408,6 +410,8 @@ Config.prototype.buildArgs = function () { brave_services_staging_domain: this.braveServicesStagingDomain, brave_services_dev_domain: this.braveServicesDevDomain, enable_dangling_raw_ptr_checks: this.enable_dangling_raw_ptr_checks, + brave_services_key_id: this.brave_services_key_id, + service_key_aichat: this.service_key_aichat, ...this.extraGnArgs, } diff --git a/components/ai_chat/core/browser/BUILD.gn b/components/ai_chat/core/browser/BUILD.gn index d89adeeacd8b..bd57144d1e84 100644 --- a/components/ai_chat/core/browser/BUILD.gn +++ b/components/ai_chat/core/browser/BUILD.gn @@ -11,14 +11,14 @@ static_library("browser") { sources = [ "ai_chat_credential_manager.cc", "ai_chat_credential_manager.h", - "conversation_driver.cc", - "conversation_driver.h", "ai_chat_feedback_api.cc", "ai_chat_feedback_api.h", "ai_chat_metrics.cc", "ai_chat_metrics.h", "constants.cc", "constants.h", + "conversation_driver.cc", + "conversation_driver.h", "engine/engine_consumer.h", "engine/engine_consumer_claude.cc", "engine/engine_consumer_claude.h", @@ -37,6 +37,7 @@ static_library("browser") { "//brave/components/ai_chat/core/common/buildflags", "//brave/components/ai_chat/core/common/mojom", "//brave/components/api_request_helper", + "//brave/components/brave_service_keys", "//brave/components/brave_stats/browser", "//brave/components/constants", "//brave/components/l10n/common", diff --git a/components/ai_chat/core/browser/engine/remote_completion_client.cc b/components/ai_chat/core/browser/engine/remote_completion_client.cc index 10714f1d09d6..3281b236bf69 100644 --- a/components/ai_chat/core/browser/engine/remote_completion_client.cc +++ b/components/ai_chat/core/browser/engine/remote_completion_client.cc @@ -21,6 +21,7 @@ #include "brave/components/ai_chat/core/browser/constants.h" #include "brave/components/ai_chat/core/common/buildflags/buildflags.h" #include "brave/components/ai_chat/core/common/features.h" +#include "brave/components/brave_service_keys/brave_service_key_utils.h" #include "brave/components/constants/brave_services_key.h" #include "net/http/http_status_code.h" #include "net/traffic_annotation/network_traffic_annotation.h" @@ -31,7 +32,8 @@ namespace ai_chat { namespace { -constexpr char kAIChatCompletionPath[] = "v1/complete"; +constexpr char kAIChatCompletionPath[] = "v2/complete"; +constexpr char kHttpMethod[] = "POST"; net::NetworkTrafficAnnotationTag GetNetworkTrafficAnnotationTag() { return net::DefineNetworkTrafficAnnotation("ai_chat", R"( @@ -146,7 +148,23 @@ void RemoteCompletionClient::OnFetchPremiumCredential( absl::optional credential) { bool premium_enabled = credential.has_value(); const GURL api_url = GetEndpointUrl(premium_enabled, kAIChatCompletionPath); + const bool is_sse_enabled = + ai_chat::features::kAIChatSSE.Get() && !data_received_callback.is_null(); + const base::Value::Dict& dict = + CreateApiParametersDict(prompt, model_name_, stop_sequences_, + std::move(extra_stop_sequences), is_sse_enabled); + const std::string request_body = CreateJSONRequestBody(dict); + base::flat_map headers; + const auto digest_header = brave_service_keys::GetDigestHeader(request_body); + headers.emplace(digest_header.first, digest_header.second); + auto result = brave_service_keys::GetAuthorizationHeader( + BUILDFLAG(SERVICE_KEY_AICHAT), headers, api_url, kHttpMethod, {"digest"}); + if (result) { + std::pair authorization_header = result.value(); + headers.emplace(authorization_header.first, authorization_header.second); + } + if (premium_enabled) { // Add Leo premium SKU credential as a Cookie header. std::string cookie_header_value = @@ -156,12 +174,6 @@ void RemoteCompletionClient::OnFetchPremiumCredential( headers.emplace("x-brave-key", BUILDFLAG(BRAVE_SERVICES_KEY)); headers.emplace("Accept", "text/event-stream"); - const bool is_sse_enabled = - ai_chat::features::kAIChatSSE.Get() && !data_received_callback.is_null(); - - const base::Value::Dict& dict = - CreateApiParametersDict(prompt, model_name_, stop_sequences_, - std::move(extra_stop_sequences), is_sse_enabled); if (is_sse_enabled) { VLOG(2) << "Making streaming AI Chat API Request"; auto on_received = base::BindRepeating( @@ -172,7 +184,7 @@ void RemoteCompletionClient::OnFetchPremiumCredential( weak_ptr_factory_.GetWeakPtr(), credential, std::move(data_completed_callback)); - api_request_helper_.RequestSSE("POST", api_url, CreateJSONRequestBody(dict), + api_request_helper_.RequestSSE(kHttpMethod, api_url, request_body, "application/json", std::move(on_received), std::move(on_complete), headers, {}); } else { @@ -182,7 +194,7 @@ void RemoteCompletionClient::OnFetchPremiumCredential( weak_ptr_factory_.GetWeakPtr(), credential, std::move(data_completed_callback)); - api_request_helper_.Request("POST", api_url, CreateJSONRequestBody(dict), + api_request_helper_.Request(kHttpMethod, api_url, request_body, "application/json", std::move(on_complete), headers, {}); } diff --git a/components/ai_chat/core/common/buildflags/BUILD.gn b/components/ai_chat/core/common/buildflags/BUILD.gn index 39fb76ab0b5c..8d3a07934a78 100644 --- a/components/ai_chat/core/common/buildflags/BUILD.gn +++ b/components/ai_chat/core/common/buildflags/BUILD.gn @@ -7,9 +7,20 @@ import("//brave/build/config.gni") import("//brave/components/ai_chat/core/common/buildflags/buildflags.gni") import("//build/buildflag_header.gni") +declare_args() { + service_key_aichat = "" +} + +if (is_official_build) { + assert(service_key_aichat != "") +} + buildflag_header("buildflags") { header = "buildflags.h" - flags = [ "ENABLE_AI_CHAT=$enable_ai_chat" ] + flags = [ + "ENABLE_AI_CHAT=$enable_ai_chat", + "SERVICE_KEY_AICHAT=\"$service_key_aichat\"", + ] # Enable for desktop (all channels) and android (only dev and # nightly channels). diff --git a/components/brave_service_keys/BUILD.gn b/components/brave_service_keys/BUILD.gn new file mode 100644 index 000000000000..87cac319b872 --- /dev/null +++ b/components/brave_service_keys/BUILD.gn @@ -0,0 +1,51 @@ +# Copyright (c) 2024 The Brave Authors. All rights reserved. +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this file, +# You can obtain one at https://mozilla.org/MPL/2.0/. + +import("//brave/build/config.gni") +import("//build/buildflag_header.gni") +import("//testing/test.gni") + +declare_args() { + # BRAVE_SERVICES_KEY_ID = TARGET_OS + '-' + CV_MAJOR + '-' + RELEASE_CHANNEL + # It is used (in combination with a secret seed) to generate service keys + # for each service, OS, chrominum version, release channel + # combination. + brave_services_key_id = "" +} + +if (is_official_build) { + assert(brave_services_key_id != "") +} + +buildflag_header("buildflags") { + header = "buildflags.h" + flags = [ "BRAVE_SERVICES_KEY_ID=\"$brave_services_key_id\"" ] +} + +static_library("brave_service_keys") { + sources = [ + "brave_service_key_utils.cc", + "brave_service_key_utils.h", + ] + + deps = [ + ":buildflags", + "//base", + "//crypto", + "//net", + "//url:url", + ] +} +source_set("unit_tests") { + testonly = true + sources = [ "brave_service_key_utils_unittest.cc" ] + + deps = [ + ":brave_service_keys", + ":buildflags", + "//base", + "//testing/gtest", + ] +} diff --git a/components/brave_service_keys/brave_service_key_utils.cc b/components/brave_service_keys/brave_service_key_utils.cc new file mode 100644 index 000000000000..0bbe0d0e21b4 --- /dev/null +++ b/components/brave_service_keys/brave_service_key_utils.cc @@ -0,0 +1,110 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "brave/components/brave_service_keys/brave_service_key_utils.h" + +#include + +#include "base/base64.h" +#include "base/notreached.h" +#include "base/strings/strcat.h" +#include "base/strings/string_util.h" +#include "brave/components/brave_service_keys/buildflags.h" +#include "crypto/hmac.h" +#include "crypto/sha2.h" +#include "net/http/http_auth_scheme.h" +#include "net/http/http_request_headers.h" + +namespace brave_service_keys { + +namespace { + +constexpr char kRequestTargetHeader[] = "(request-target)"; + +} // namespace + +std::pair GetDigestHeader( + const std::string& payload) { + const std::string value = base::StrCat( + {"SHA-256=", base::Base64Encode(crypto::SHA256HashString(payload))}); + return std::make_pair(net::kDigestAuthScheme, value); +} + +std::pair CreateSignatureString( + const base::flat_map& headers, + const GURL& url, + const std::string& method, + const std::vector& headers_to_sign) { + std::string header_names; + std::string signature_string; + + for (const auto& header_to_sign : headers_to_sign) { + // Prepend some padding / newlines if this isn't the first + // header to sign + if (!header_names.empty()) { + header_names.append(" "); + signature_string.append("\n"); + } + header_names.append(header_to_sign); + + // Handle the special case header (request-target) by constructing + // the value instead of getting it from headers. + if (header_to_sign == kRequestTargetHeader) { + signature_string.append( + base::StrCat({kRequestTargetHeader, ": ", base::ToLowerASCII(method), + " ", url.PathForRequest()})); + continue; + } + + // For all the headers to sign, we expect their values to be be in the + // headers flat_map and use the value there to add to the signature string. + auto header = headers.find(header_to_sign); + if (header == headers.end()) { + NOTREACHED_NORETURN() + << "Can't sign over non-existent header " << header_to_sign; + } + signature_string.append( + base::StrCat({header_to_sign, ": ", header->second})); + } + + return std::make_pair(header_names, signature_string); +} + +std::optional> GetAuthorizationHeader( + const std::string& service_key, + const base::flat_map& headers, + const GURL& url, + const std::string& method, + const std::vector& headers_to_sign) { + CHECK(url.is_valid()); + auto [header_names, signature_string] = + CreateSignatureString(headers, url, method, headers_to_sign); + + // Create the signature using the service_key. + crypto::HMAC hmac(crypto::HMAC::SHA256); + const size_t signature_digest_length = hmac.DigestLength(); + std::vector signature_digest(signature_digest_length); + const bool success = hmac.Init(service_key) && + hmac.Sign(signature_string, &signature_digest[0], + signature_digest.size()); + if (!success) { + return std::nullopt; + } + + // Create the authorization header. + std::string signature_digest_base64; + base::Base64Encode( + std::string(signature_digest.begin(), signature_digest.end()), + &signature_digest_base64); + + const std::string value = + base::StrCat({"Signature keyId=\"", BUILDFLAG(BRAVE_SERVICES_KEY_ID), + "\",algorithm=\"hs2019\",headers=\"", header_names, + "\",signature=\"", signature_digest_base64, "\""}); + + return std::make_pair(net::HttpRequestHeaders::kAuthorization, value); +} + +} // namespace brave_service_keys diff --git a/components/brave_service_keys/brave_service_key_utils.h b/components/brave_service_keys/brave_service_key_utils.h new file mode 100644 index 000000000000..7d2460ef3be4 --- /dev/null +++ b/components/brave_service_keys/brave_service_key_utils.h @@ -0,0 +1,44 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_BRAVE_SERVICE_KEYS_BRAVE_SERVICE_KEY_UTILS_H_ +#define BRAVE_COMPONENTS_BRAVE_SERVICE_KEYS_BRAVE_SERVICE_KEY_UTILS_H_ + +#include +#include +#include +#include + +#include "base/containers/flat_map.h" +#include "url/gurl.h" + +namespace brave_service_keys { + +// Calculates the SHA-256 hash of the supplied payload and returns a pair +// comprising of the digest header field, and header value in the format +// "SHA-256=". +std::pair GetDigestHeader(const std::string& payload); + +// Generates the the string to be signed over and included in the authorization +// header. See +// https://datatracker.ietf.org/doc/html/draft-cavage-http-signatures-08#section-2.3:w +std::pair CreateSignatureString( + const base::flat_map& headers, + const GURL& url, + const std::string& method, + const std::vector& headers_to_sign); + +// Generates an authorization header field and value pair using the provided +// service key to sign over specified headers. +std::optional> GetAuthorizationHeader( + const std::string& service_key, + const base::flat_map& headers, + const GURL& url, + const std::string& method, + const std::vector& headers_to_sign); + +} // namespace brave_service_keys + +#endif // BRAVE_COMPONENTS_BRAVE_SERVICE_KEYS_BRAVE_SERVICE_KEY_UTILS_H_ diff --git a/components/brave_service_keys/brave_service_key_utils_unittest.cc b/components/brave_service_keys/brave_service_key_utils_unittest.cc new file mode 100644 index 000000000000..28a84b267866 --- /dev/null +++ b/components/brave_service_keys/brave_service_key_utils_unittest.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "brave/components/brave_service_keys/brave_service_key_utils.h" + +#include "base/containers/flat_map.h" +#include "base/strings/strcat.h" +#include "brave/components/brave_service_keys/buildflags.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace brave_service_keys { + +TEST(BraveServicesUtilsUnittest, GetDigestHeader) { + // Test vector is from + // https://www.ietf.org/archive/id/draft-ietf-httpbis-digest-headers-04.html#section-10.4 + const auto& header = GetDigestHeader("{\"hello\": \"world\"}"); + EXPECT_EQ(header.first, "digest"); + EXPECT_EQ(header.second, + "SHA-256=X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE="); +} + +TEST(BraveServicesUtilsUnittest, CreateSignatureString) { + const GURL url = GURL("http://example.com/foo"); + base::flat_map headers = { + {"digest", "SHA-256=X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE="}, + {"content-type", "application/json"}, + {"host", "example.com"}}; + // Test for no headers + auto result = CreateSignatureString(headers, url, "GET", {}); + EXPECT_EQ(result.first, ""); + EXPECT_EQ(result.second, ""); + + // Test for single header + result = CreateSignatureString(headers, url, "GET", {"digest"}); + EXPECT_EQ(result.first, "digest"); + EXPECT_EQ(result.second, + "digest: SHA-256=X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE="); + + // Test for multiple headers in specified order + result = + CreateSignatureString(headers, url, "GET", {"content-type", "digest"}); + EXPECT_EQ(result.first, "content-type digest"); + EXPECT_EQ(result.second, + "content-type: application/json\ndigest: " + "SHA-256=X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE="); + + // Test for multiple headers in reverse order + result = + CreateSignatureString(headers, url, "GET", {"digest", "content-type"}); + EXPECT_EQ(result.first, "digest content-type"); + EXPECT_EQ(result.second, + "digest: " + "SHA-256=X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE=\ncontent-" + "type: application/json"); + + // Test vector from + // https://datatracker.ietf.org/doc/html/draft-cavage-http-signatures-08#section-3.1.3 + headers = {{"(request-target)", "post /foo"}, + {"host", "example.org"}, + {"date", "Tue, 07 Jun 2014 20:51:35 GMT"}, + {"digest", "SHA-256=X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE="}, + {"content-length", "18"}}; + result = CreateSignatureString( + headers, url, "POST", + {"(request-target)", "host", "date", "digest", "content-length"}); + EXPECT_EQ(result.first, "(request-target) host date digest content-length"); + EXPECT_EQ(result.second, + "(request-target): post /foo\n" + "host: example.org\n" + "date: Tue, 07 Jun 2014 20:51:35 GMT\n" + "digest: SHA-256=X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE=\n" + "content-length: 18"); + + // Try without explicitly setting (request-target) + headers = {{"host", "example.org"}, + {"date", "Tue, 07 Jun 2014 20:51:35 GMT"}, + {"digest", "SHA-256=X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE="}, + {"content-length", "18"}}; + result = CreateSignatureString( + headers, url, "POST", + {"(request-target)", "host", "date", "digest", "content-length"}); + EXPECT_EQ(result.first, "(request-target) host date digest content-length"); + EXPECT_EQ(result.second, + "(request-target): post /foo\n" + "host: example.org\n" + "date: Tue, 07 Jun 2014 20:51:35 GMT\n" + "digest: SHA-256=X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE=\n" + "content-length: 18"); +} + +TEST(BraveServicesUtilsUnittest, GetAuthorizationHeader) { + const auto& digest_header = GetDigestHeader("{\"hello\": \"world\"}"); + base::flat_map headers; + headers[digest_header.first] = digest_header.second; + const std::string service_key = + "bacfb4d7e93c6df045f66fa4bf438402b43ba2c9e3ce9b4eef470d24e32378e8"; + auto result = GetAuthorizationHeader( + service_key, headers, GURL("https://example.com"), "POST", {"digest"}); + ASSERT_TRUE(result); + EXPECT_EQ(result->first, "Authorization"); + EXPECT_EQ( + result->second, + base::StrCat({"Signature keyId=\"", BUILDFLAG(BRAVE_SERVICES_KEY_ID), + "\",algorithm=\"hs2019\",headers=\"digest\",signature=\"" + "jumtKp4LQDzIBpuGKIEI/mxrr9AEcSzvRGD6PfYyAq8=\""})); + + // Try again with (request-target) + result = GetAuthorizationHeader(service_key, headers, + GURL("https://example.com/test/v1?a=b"), + "POST", {"(request-target)", "digest"}); + ASSERT_TRUE(result); + EXPECT_EQ(result->first, "Authorization"); + EXPECT_EQ( + result->second, + base::StrCat({"Signature keyId=\"", BUILDFLAG(BRAVE_SERVICES_KEY_ID), + "\",algorithm=\"hs2019\",headers=\"(request-target) " + "digest\",signature=\"" + "kBICAlSiWuMoMr4Rws1KzyXOE6qK91jcAs8v9C7t4QQ=\""})); +} + +} // namespace brave_service_keys diff --git a/test/BUILD.gn b/test/BUILD.gn index 9a3cc1a6b19b..29b15b0f67cb 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -214,6 +214,7 @@ test("brave_unit_tests") { "//brave/components/brave_search/common", "//brave/components/brave_search_conversion", "//brave/components/brave_search_conversion:unit_tests", + "//brave/components/brave_service_keys:unit_tests", "//brave/components/brave_shields/browser", "//brave/components/brave_shields/common", "//brave/components/brave_shields/common:mojom",