Skip to content

Commit

Permalink
feat: 评估提供模型输出预处理 hook (#611)
Browse files Browse the repository at this point in the history
* 提供预处理接口

* 修复引入副作用
  • Loading branch information
Dobiichi-Origami authored Jun 25, 2024
1 parent badca9b commit 752cfe5
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 37 deletions.
1 change: 0 additions & 1 deletion python/qianfan/common/client/openai_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import Any, AsyncIterator, Dict, Optional

Expand Down
34 changes: 0 additions & 34 deletions python/qianfan/dataset/local_data_operators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,34 +0,0 @@
from qianfan.dataset.local_data_operators.base import (
BaseLocalFilterOperator,
BaseLocalMapOperator,
)
from qianfan.dataset.local_data_operators.check_character_repetition_filter import (
LocalCheckCharacterRepetitionFilter,
)
from qianfan.dataset.local_data_operators.check_flagged_words import (
LocalCheckFlaggedWordsFilter,
)
from qianfan.dataset.local_data_operators.check_sentence_length_filter import (
LocalCheckEachSentenceIsLongEnoughFilter,
)
from qianfan.dataset.local_data_operators.check_special_characters import (
LocalCheckSpecialCharactersFilter,
)
from qianfan.dataset.local_data_operators.check_stopwords import (
LocalCheckStopwordsFilter,
)
from qianfan.dataset.local_data_operators.check_word_number import (
LocalCheckWordNumberFilter,
)

# Public names re-exported by this package. Each operator appears exactly
# once; a repeated entry adds nothing to the star-import and only obscures
# the real export list.
__all__ = [
    "BaseLocalMapOperator",
    "BaseLocalFilterOperator",
    "LocalCheckSpecialCharactersFilter",
    "LocalCheckCharacterRepetitionFilter",
    "LocalCheckEachSentenceIsLongEnoughFilter",
    "LocalCheckFlaggedWordsFilter",
    "LocalCheckStopwordsFilter",
    "LocalCheckWordNumberFilter",
]
21 changes: 19 additions & 2 deletions python/qianfan/evaluation/evaluation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import zipfile
from concurrent.futures import ALL_COMPLETED, Future, ThreadPoolExecutor, wait
from copy import copy
from typing import Any, Dict, List, Optional, Sequence, Set, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union

import pyarrow

Expand All @@ -42,6 +42,7 @@
from qianfan.dataset.data_source.utils import (
_download_file_from_url_streamly,
)
from qianfan.dataset.local_data_operators.base import BaseLocalMapOperator
from qianfan.errors import QianfanError
from qianfan.evaluation.consts import QianfanRefereeEvaluatorPromptTemplate
from qianfan.evaluation.evaluation_result import EvaluationResult
Expand Down Expand Up @@ -105,6 +106,10 @@ class EvaluationManager(BaseModel):
"""logic control center of evaluation"""

local_evaluators: Optional[List[LocalEvaluator]] = Field(default=None)
pre_processors: Optional[List[Union[BaseLocalMapOperator, Callable]]] = Field(
default=None
)

qianfan_evaluators: Optional[List[QianfanEvaluator]] = Field(default=None)
task_id: Optional[str] = Field(default=None)

Expand Down Expand Up @@ -162,8 +167,20 @@ def _eval_worker(
assert self.local_evaluators
for i in range(start, end):
result: Dict[str, Any] = {}
(single_input, single_reference, single_output) = (
input[i],
reference[i],
output[i],
)
# Run the model output through every configured pre-processor, in order,
# before it is handed to the evaluators. Each hook receives the current
# output positionally plus this sample's input and reference as keyword
# context; its return value replaces the output seen by the next hook.
# The reference itself must stay untouched: it is the ground truth the
# evaluators compare against.
for pre_processor in self.pre_processors or []:
    single_output = pre_processor(  # type: ignore
        single_output, input=single_input, reference=single_reference
    )

for evaluator in self.local_evaluators:
result.update(evaluator.evaluate(input[i], reference[i], output[i]))
result.update(
evaluator.evaluate(single_input, single_reference, single_output)
)
result_list.append(result)
return result_list

Expand Down

0 comments on commit 752cfe5

Please sign in to comment.