-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathoci_baai_reranker.py
163 lines (132 loc) · 4.97 KB
/
oci_baai_reranker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""
File name: oci_baai_reranker.py
Author: Luigi Saetta
Date created: 2023-12-30
Date last modified: 2024-08-03
Python Version: 3.9
Description:
This module provides the base class to integrate a reranker
deployed as Model Deployment in OCI Data Science
as reranker in llama-index
Inspired by:
https://github.com/run-llama/llama_index/blob/main/llama_index/postprocessor/cohere_rerank.py
Usage:
Import this module into other scripts to use its functions.
Example:
baai_reranker = OCIBAAIReranker(
auth=api_keys_config,
deployment_id=RERANKER_ID, region="eu-frankfurt-1")
reranker = OCILLamaReranker(oci_reranker=baai_reranker, top_n=TOP_N)
License:
This code is released under the MIT License.
Notes:
This is part of a set of demos showing how to use Oracle Vector DB,
OCI GenAI service, Oracle GenAI Embeddings, to build a RAG solution,
where all the data (text + embeddings) are stored in Oracle DB 23c
Warnings:
This module is in development, may change in future versions.
"""
import base64
import logging
import requests
import cloudpickle
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
class OCIBAAIReranker:
    """
    Client for a BAAI reranker deployed as an OCI Data Science Model Deployment.

    The input pairs are serialized with cloudpickle, base64-encoded and POSTed
    to the model deployment /predict endpoint; results come back as a JSON
    response with a "prediction" list of scores.
    """

    def __init__(self, auth, deployment_id, region="eu-frankfurt-1", timeout=60):
        """
        auth: dict with the OCI auth objects; must contain a "signer" entry
              used to sign the HTTPS requests
        deployment_id: the OCID of the model deployment
        region: the OCI region where the deployment is
        timeout: timeout in seconds for the HTTPS call to the deployment
        """
        self.auth = auth
        self.deployment_id = deployment_id
        self.timeout = timeout

        # build the invocation endpoint of the model deployment
        base_url = f"https://modeldeployment.{region}.oci.customer-oci.com/"
        self.endpoint = f"{base_url}{self.deployment_id}/predict"

        logging.info("Created OCI reranker client...")
        logging.info("Region: %s ...", region)
        logging.info("Deployment id: %s ...", deployment_id)
        logging.info("")

    def _build_body(self, input_list):
        """
        Build the JSON body for the HTTPS POST call.

        The deployed model expects the payload serialized with cloudpickle
        and base64-encoded; "data_type" tells the server-side scoring
        function how to deserialize it.
        """
        # it has been difficult to figure out how to do this... seems to work
        # maybe there are other ways
        val_ser = base64.b64encode(cloudpickle.dumps(input_list)).decode("utf-8")

        body = {"data": val_ser, "data_type": "numpy.ndarray"}

        return body

    @classmethod
    def class_name(cls) -> str:
        """
        Return the class name.
        """
        return "OCIBAAIReranker"

    def _compute_score(self, x):
        """
        Expose the original interface of the deployed model
        (see BAAI reranker compute_score).

        x: a list of pairs of strings to be compared,
           example: [["input1", "input2"]]

        Returns the parsed JSON response, or [] on any error.
        """
        # prepare the body for the Model Deployment invocation
        body = self._build_body(x)

        try:
            # invoke the deployment
            response = requests.post(
                self.endpoint,
                json=body,
                auth=self.auth["signer"],
                timeout=self.timeout,
            )

            if response.status_code == 200:
                return response.json()

            logging.error(
                "Error in OCIBAAIReranker compute_score: %s", response.json()
            )
            return []
        except Exception:
            # logging.exception keeps the traceback; return [] so callers
            # can degrade gracefully instead of crashing
            logging.exception("Error in OCIBAAIReranker compute_score...")
            return []

    def rerank(self, query, texts, top_n=2):
        """
        Invoke the Model Deployment with the reranker.

        - query: the query string
        - texts: List[str], compared and reranked against query
        - top_n: how many results to return

        Returns a list of {"text", "index", "relevance_score"} dicts sorted
        by decreasing relevance_score, truncated to top_n; [] on any error.
        """
        # the BAAI reranker expects a list of [query, text] pairs, e.g.
        # [['what is panda?', 'The giant panda is a bear.'],
        #  ['what is panda?', 'It is an animal living in China']]
        pairs = [[query, text] for text in texts]

        try:
            response = self._compute_score(pairs)

            # an empty response ([]) means the call failed upstream
            if not response:
                return []

            # pair each input text with its score; this block has been
            # inspired by the code of the cohere_reranker
            data = [
                {"text": text, "index": index, "relevance_score": score}
                for index, (text, score) in enumerate(
                    zip(texts, response["prediction"])
                )
            ]

            # sort in decreasing score and output only top_n
            sorted_data = sorted(
                data, key=lambda item: item["relevance_score"], reverse=True
            )

            return sorted_data[:top_n]
        except Exception:
            logging.exception("Error in OCIBAAIReranker rerank...")
            return []