Skip to content

Commit

Permalink
feat: gate use of v2 endpoint with waffle flag (#152)
Browse files Browse the repository at this point in the history
* feat: gate use of v2 endpoint with waffle flag

* fix: update version
  • Loading branch information
alangsto authored Jan 8, 2025
1 parent 1e87b89 commit 93e9d3a
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 38 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ Change Log
Unreleased
**********

4.7.0 - 2025-01-07
******************
* Gate use of the Xpert platform v2 endpoint with a waffle flag.

4.6.3 - 2025-01-06
******************
* Uses CourseEnrollment instead of CourseMode to get the upgrade deadline required to calculate if a learner's audit trial is expired.
Expand Down
2 changes: 1 addition & 1 deletion learning_assistant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
Plugin for a learning assistant backend, intended for use within edx-platform.
"""

__version__ = '4.6.3'
__version__ = '4.7.0'

default_app_config = 'learning_assistant.apps.LearningAssistantConfig' # pylint: disable=invalid-name
33 changes: 16 additions & 17 deletions learning_assistant/toggles.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,6 @@

WAFFLE_NAMESPACE = 'learning_assistant'

# .. toggle_name: learning_assistant.enable_course_content
# .. toggle_implementation: CourseWaffleFlag
# .. toggle_default: False
# .. toggle_description: Waffle flag to enable the course content integration with the learning assistant
# .. toggle_use_cases: temporary
# .. toggle_creation_date: 2024-01-08
# .. toggle_target_removal_date: 2024-01-31
# .. toggle_tickets: COSMO-80
ENABLE_COURSE_CONTENT = 'enable_course_content'

# .. toggle_name: learning_assistant.enable_chat_history
# .. toggle_implementation: CourseWaffleFlag
# .. toggle_default: False
Expand All @@ -24,8 +14,17 @@
# .. toggle_tickets: COSMO-436
ENABLE_CHAT_HISTORY = 'enable_chat_history'

# .. toggle_name: learning_assistant.enable_v2_endpoint
# .. toggle_implementation: CourseWaffleFlag
# .. toggle_default: False
# .. toggle_description: Waffle flag to enable use of the internal Xpert platform v2 endpoint
# .. toggle_use_cases: temporary
# .. toggle_creation_date: 2025-01-06
# .. toggle_target_removal_date: 2025-01-31
ENABLE_V2_ENDPOINT = 'enable_v2_endpoint'


def _is_learning_assistant_waffle_flag_enabled(flag_name, course_key):
def _is_learning_assistant_waffle_flag_enabled(flag_name, course_key=None):
"""
Import and return Waffle flag for enabling the summary hook.
"""
Expand All @@ -37,15 +36,15 @@ def _is_learning_assistant_waffle_flag_enabled(flag_name, course_key):
return False


def course_content_enabled(course_key):
def chat_history_enabled(course_key):
"""
Return whether the learning_assistant.enable_course_content WaffleFlag is on.
Return whether the learning_assistant.enable_chat_history WaffleFlag is on.
"""
return _is_learning_assistant_waffle_flag_enabled(ENABLE_COURSE_CONTENT, course_key)
return _is_learning_assistant_waffle_flag_enabled(ENABLE_CHAT_HISTORY, course_key)


def chat_history_enabled(course_key):
def v2_endpoint_enabled():
"""
Return whether the learning_assistant.enable_chat_history WaffleFlag is on.
Return whether the learning_assistant.enable_v2_endpoint WaffleFlag is on.
"""
return _is_learning_assistant_waffle_flag_enabled(ENABLE_CHAT_HISTORY, course_key)
return _is_learning_assistant_waffle_flag_enabled(ENABLE_V2_ENDPOINT)
25 changes: 17 additions & 8 deletions learning_assistant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from requests.exceptions import ConnectTimeout
from rest_framework import status as http_status

from learning_assistant.toggles import v2_endpoint_enabled

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -47,30 +49,37 @@ def get_reduced_message_list(prompt_template, message_list):
# insert message at beginning of list, because we are traversing the message list from most recent to oldest
new_message_list.insert(0, new_message)

system_message = {'role': 'system', 'content': prompt_template}

return [system_message] + new_message_list
return new_message_list


def create_request_body(prompt_template, message_list):
"""
Form request body to be passed to the chat endpoint.
"""
messages = get_reduced_message_list(prompt_template, message_list)

response_body = {
'message_list': get_reduced_message_list(prompt_template, message_list),
'message_list': [{'role': 'system', 'content': prompt_template}] + messages,
}

if v2_endpoint_enabled():
response_body = {
'client_id': getattr(settings, 'CHAT_COMPLETION_CLIENT_ID', 'edx_olc_la'),
'system_message': prompt_template,
'messages': messages,
}

return response_body


def get_chat_response(prompt_template, message_list):
"""
Pass message list to chat endpoint, as defined by the CHAT_COMPLETION_API setting.
"""
completion_endpoint = getattr(settings, 'CHAT_COMPLETION_API', None)
completion_endpoint_key = getattr(settings, 'CHAT_COMPLETION_API_KEY', None)
if completion_endpoint and completion_endpoint_key:
headers = {'Content-Type': 'application/json', 'x-api-key': completion_endpoint_key}
completion_endpoint = getattr(settings, 'CHAT_COMPLETION_API_V2', None) if v2_endpoint_enabled() \
else getattr(settings, 'CHAT_COMPLETION_API', None)
if completion_endpoint:
headers = {'Content-Type': 'application/json'}
connect_timeout = getattr(settings, 'CHAT_COMPLETION_API_CONNECT_TIMEOUT', 1)
read_timeout = getattr(settings, 'CHAT_COMPLETION_API_READ_TIMEOUT', 15)

Expand Down
2 changes: 1 addition & 1 deletion test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def root(*args):
}]

CHAT_COMPLETION_API = 'https://test.edx.org/'
CHAT_COMPLETION_API_KEY = 'endpoint_key'
CHAT_COMPLETION_API_V2 = 'https://test.edx.org/v2'
CHAT_COMPLETION_API_CONNECT_TIMEOUT = 0.5
CHAT_COMPLETION_API_READ_TIMEOUT = 10

Expand Down
41 changes: 30 additions & 11 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,6 @@ def test_no_endpoint_setting(self):
self.assertEqual(status_code, 404)
self.assertEqual(message, 'Completion endpoint is not defined.')

@override_settings(CHAT_COMPLETION_API_KEY=None)
def test_no_endpoint_key_setting(self):
status_code, message = self.get_response()
self.assertEqual(status_code, 404)
self.assertEqual(message, 'Completion endpoint is not defined.')

@responses.activate
def test_200_response(self):
message_response = {'role': 'assistant', 'content': 'See you later!'}
Expand Down Expand Up @@ -91,7 +85,7 @@ def test_post_request_structure(self, mock_requests):
completion_endpoint = settings.CHAT_COMPLETION_API
connect_timeout = settings.CHAT_COMPLETION_API_CONNECT_TIMEOUT
read_timeout = settings.CHAT_COMPLETION_API_READ_TIMEOUT
headers = {'Content-Type': 'application/json', 'x-api-key': settings.CHAT_COMPLETION_API_KEY}
headers = {'Content-Type': 'application/json'}

response_body = {
'message_list': [{'role': 'system', 'content': self.prompt_template}] + self.message_list,
Expand All @@ -105,6 +99,31 @@ def test_post_request_structure(self, mock_requests):
timeout=(connect_timeout, read_timeout)
)

@patch('learning_assistant.utils.v2_endpoint_enabled')
@patch('learning_assistant.utils.requests')
def test_post_request_structure_v2_endpoint(self, mock_requests, mock_v2_enabled):
mock_requests.post = MagicMock()
mock_v2_enabled.return_value = True

completion_endpoint_v2 = settings.CHAT_COMPLETION_API_V2
connect_timeout = settings.CHAT_COMPLETION_API_CONNECT_TIMEOUT
read_timeout = settings.CHAT_COMPLETION_API_READ_TIMEOUT
headers = {'Content-Type': 'application/json'}

response_body = {
'client_id': 'edx_olc_la',
'system_message': self.prompt_template,
'messages': self.message_list,
}

self.get_response()
mock_requests.post.assert_called_with(
completion_endpoint_v2,
headers=headers,
data=json.dumps(response_body),
timeout=(connect_timeout, read_timeout)
)


class GetReducedMessageListTests(TestCase):
"""
Expand All @@ -126,18 +145,18 @@ def test_message_list_reduced(self):
"""
# pass in copy of list, as it is modified as part of the reduction
reduced_message_list = get_reduced_message_list(self.prompt_template, self.message_list)
self.assertEqual(len(reduced_message_list), 2)
self.assertEqual(len(reduced_message_list), 1)
self.assertEqual(
reduced_message_list,
[{'role': 'system', 'content': self.prompt_template}] + self.message_list[-1:]
self.message_list[-1:]
)

def test_message_list(self):
reduced_message_list = get_reduced_message_list(self.prompt_template, self.message_list)
self.assertEqual(len(reduced_message_list), 3)
self.assertEqual(len(reduced_message_list), 2)
self.assertEqual(
reduced_message_list,
[{'role': 'system', 'content': self.prompt_template}] + self.message_list
self.message_list
)


Expand Down

0 comments on commit 93e9d3a

Please sign in to comment.