Commit 60a3e2d (1 parent: 99f2293): showing 1 changed file with 277 additions and 0 deletions.
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import All Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "PyTerrier 0.10.0 has loaded Terrier 5.8 (built by craigm on 2023-11-01 18:05) and terrier-helper 0.0.8\n",
      "\n",
      "No etc/terrier.properties, using terrier.default.properties for bootstrap configuration.\n"
     ]
    }
   ],
   "source": [
    "# Imports\n",
    "from tira.third_party_integrations import ensure_pyterrier_is_loaded, persist_and_normalize_run\n",
    "from tira.rest_api_client import Client\n",
    "ensure_pyterrier_is_loaded()\n",
    "import pandas as pd\n",
    "import pyterrier as pt\n",
    "from tqdm import tqdm\n",
    "from jnius import autoclass\n",
    "import gzip\n",
    "import json\n",
    "import re\n",
    "\n",
    "# Create a REST client to the TIRA platform for retrieving the pre-indexed data.\n",
    "tira = Client()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the Dataset and the Index\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
"# The dataset: the union of the IR Anthology and the ACL Anthology\n", | ||
"# This line creates an IRDSDataset object and registers it under the name provided as an argument.\n", | ||
"dataset = 'antique-test-20230107-training'\n", | ||
"pt_dataset = pt.get_dataset(f'irds:ir-benchmarks/{dataset}')\n", | ||
"bm25 = tira.pt.from_submission('ir-benchmarks/tira-ir-starter/BM25 Re-Rank (tira-ir-starter-pyterrier)', dataset)\n", | ||
"\n", | ||
"# A (pre-built) PyTerrier index loaded from TIRA\n", | ||
"index = tira.pt.index('ir-lab-sose-2024/tira-ir-starter/Index (tira-ir-starter-pyterrier)', pt_dataset)" | ||
] | ||
}, | ||
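  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To verify the dataset loaded correctly, the next cell peeks at the first few topics. This is a minimal sketch that only uses the `pt_dataset` object defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: a quick look at the topics that are retrieved against later in the notebook.\n",
    "pt_dataset.get_topics('text').head()"
   ]
  },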
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stopwords"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ir-benchmarks/antique-test-20230107-training documents:   0%|          | 1912/403666 [00:02<03:27, 1939.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13:47:56.872 [ForkJoinPool-1-worker-3] WARN org.terrier.structures.indexing.Indexer - Adding an empty document to the index (2824443_2) - further warnings are suppressed\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ir-benchmarks/antique-test-20230107-training documents: 100%|██████████| 403666/403666 [00:44<00:00, 9033.35it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13:48:43.988 [ForkJoinPool-1-worker-3] WARN org.terrier.structures.indexing.Indexer - Indexed 1570 empty documents\n"
     ]
    }
   ],
   "source": [
    "def create_index(documents, stopwords):\n",
    "    indexer = pt.IterDictIndexer(\"/tmp/index\", overwrite=True, meta={'docno': 100, 'text': 20480}, stopwords=stopwords)\n",
    "    index_ref = indexer.index(documents)\n",
    "    return pt.IndexFactory.of(index_ref)\n",
    "\n",
"chatGPTStopwords =[\n", | ||
" 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 'yours', \n", | ||
" 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 'her', 'hers', 'herself', \n", | ||
" 'it', 'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves', 'what', 'which', \n", | ||
" 'who', 'whom', 'this', 'that', 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', \n", | ||
" 'been', 'being', 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', \n", | ||
" 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 'at', 'by', \n", | ||
" 'for', 'with', 'about', 'against', 'between', 'into', 'through', 'during', 'before', \n", | ||
" 'after', 'above', 'below', 'to', 'from', 'up', 'down', 'in', 'out', 'on', 'off', 'over', \n", | ||
" 'under', 'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why', \n", | ||
" 'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', \n", | ||
" 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', \n", | ||
" 'will', 'just', 'don', 'should', 'now'\n", | ||
"]\n", | ||
"\n", | ||
"index = create_index(pt_dataset.get_corpus_iter(), chatGPTStopwords)\n", | ||
"\n", | ||
"bm25_chatGPTStopwords = pt.BatchRetrieve(index, wmodel=\"BM25\")" | ||
] | ||
}, | ||
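  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the next cell prints the collection statistics of the freshly built index. This is a minimal sketch that only assumes the `index` object created above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: print the collection statistics (number of documents, tokens, etc.)\n",
    "# of the index that was just built with the custom stopword list.\n",
    "print(index.getCollectionStatistics().toString())"
   ]
  },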
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Query Expansion with Large Language Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "llama_sq_zs = tira.pt.transform_queries('ir-benchmarks/tu-dresden-03/qe-llama-sq-zs', dataset, prefix='llm_expansion_')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokeniser = pt.autoclass(\"org.terrier.indexing.tokenisation.Tokeniser\").getTokeniser()\n",
    "\n",
    "def pt_tokenize(text):\n",
    "    return ' '.join(tokeniser.getTokens(text))\n",
    "\n",
"def expand_query(topic):\n", | ||
" ret = ' '.join([topic['query'], topic['query'], topic['query'], topic['query'], topic['query'], topic['llm_expansion_query']])\n", | ||
"\n", | ||
" # apply the tokenization\n", | ||
" return pt_tokenize(ret)\n", | ||
"\n", | ||
"# we wrap this into an pyterrier transformer\n", | ||
"# Documentation: https://pyterrier.readthedocs.io/en/latest/apply.html\n", | ||
"pt_expand_query = pt.apply.query(expand_query)" | ||
] | ||
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline_llama_sq_zs = (llama_sq_zs >> pt_expand_query) >> bm25_chatGPTStopwords"
   ]
  },
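  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the expansion step produces, the sketch below runs only the query-rewriting part of the pipeline on the topics and shows the first few rows. It assumes the transformers defined above and that a PyTerrier transformer can be called directly on a topics dataframe."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: apply only the LLM expansion + re-weighting to the topics (no retrieval)\n",
    "# to inspect the rewritten queries that BM25 will actually see.\n",
    "expanded_topics = (llama_sq_zs >> pt_expand_query)(pt_dataset.get_topics('text'))\n",
    "expanded_topics.head()"
   ]
  },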
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>recall_1000</th>\n",
       "      <th>ndcg_cut_5</th>\n",
       "      <th>ndcg_cut.10</th>\n",
       "      <th>recip_rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>BM25_chatgptstopwords+Llama-SQ-ZS \\t</td>\n",
       "      <td>0.808404</td>\n",
       "      <td>0.566703</td>\n",
       "      <td>0.53322</td>\n",
       "      <td>0.928343</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                   name  recall_1000  ndcg_cut_5  ndcg_cut.10  \\\n",
       "0  BM25_chatgptstopwords+Llama-SQ-ZS \\t     0.808404    0.566703      0.53322   \n",
       "\n",
       "   recip_rank  \n",
       "0    0.928343  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pt.Experiment(\n",
    "    retr_systems=[pipeline_llama_sq_zs],\n",
    "    topics=pt_dataset.get_topics('text'),\n",
    "    qrels=pt_dataset.get_qrels(),\n",
    "    names=['BM25_chatgptstopwords+Llama-SQ-ZS'],\n",
    "    eval_metrics=['recall_1000', 'ndcg_cut_5', 'ndcg_cut.10', 'recip_rank']\n",
    ")"
   ]
  }
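  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The helper `persist_and_normalize_run` is imported at the top of the notebook but not used. The sketch below shows one way it could be called to materialise the run of the final pipeline and write it in TREC format; the exact keyword arguments (`system_name`, `default_output`) are an assumption and may differ between tira versions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: materialise the run of the final pipeline over all topics.\n",
    "run = pipeline_llama_sq_zs(pt_dataset.get_topics('text'))\n",
    "\n",
    "# Assumption: the keyword arguments of persist_and_normalize_run may differ between tira versions.\n",
    "persist_and_normalize_run(run, system_name='bm25-chatgpt-stopwords-llama-sq-zs', default_output='../runs')"
   ]
  }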
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
} |