From 9045dc4b242f60c929bf782fee6193ea345f06d8 Mon Sep 17 00:00:00 2001 From: Jeremy Fowers <80718789+jeremyfowers@users.noreply.github.com> Date: Thu, 29 Aug 2024 11:40:02 -0400 Subject: [PATCH] Turnkey-LLM (aka lemonade) (#225) --- .github/workflows/publish-to-test-pypi.yml | 9 +- .github/workflows/test_lemonade.yml | 46 +++ .github/workflows/test_turnkey.yml | 50 ++-- README.md | 8 + docs/mmlu_accuracy.md | 100 +++++++ docs/perplexity.md | 72 +++++ examples/llm/turnkey_llm.ipynb | 184 ++++++++++++ setup.py | 40 +++ src/turnkeyml/llm/README.md | 129 ++++++++ src/turnkeyml/llm/__init__.py | 1 + src/turnkeyml/llm/cache.py | 32 ++ src/turnkeyml/llm/cli.py | 124 ++++++++ src/turnkeyml/llm/leap.py | 143 +++++++++ src/turnkeyml/llm/tools/__init__.py | 0 src/turnkeyml/llm/tools/adapter.py | 82 +++++ src/turnkeyml/llm/tools/chat.py | 257 ++++++++++++++++ src/turnkeyml/llm/tools/huggingface_load.py | 252 ++++++++++++++++ src/turnkeyml/llm/tools/mmlu.py | 270 +++++++++++++++++ src/turnkeyml/llm/tools/ort_genai/__init__.py | 0 src/turnkeyml/llm/tools/ort_genai/oga.py | 279 ++++++++++++++++++ src/turnkeyml/llm/tools/perplexity.py | 144 +++++++++ .../llm/tools/ryzenai_npu/__init__.py | 0 .../llm/tools/ryzenai_npu/ryzenai_npu.py | 253 ++++++++++++++++ src/turnkeyml/version.py | 2 +- test/llm_api.py | 57 ++++ 25 files changed, 2506 insertions(+), 28 deletions(-) create mode 100644 .github/workflows/test_lemonade.yml create mode 100644 docs/mmlu_accuracy.md create mode 100644 docs/perplexity.md create mode 100644 examples/llm/turnkey_llm.ipynb create mode 100644 src/turnkeyml/llm/README.md create mode 100644 src/turnkeyml/llm/__init__.py create mode 100644 src/turnkeyml/llm/cache.py create mode 100644 src/turnkeyml/llm/cli.py create mode 100644 src/turnkeyml/llm/leap.py create mode 100644 src/turnkeyml/llm/tools/__init__.py create mode 100644 src/turnkeyml/llm/tools/adapter.py create mode 100644 src/turnkeyml/llm/tools/chat.py create mode 100644 src/turnkeyml/llm/tools/huggingface_load.py create mode 100644 src/turnkeyml/llm/tools/mmlu.py create mode 100644 src/turnkeyml/llm/tools/ort_genai/__init__.py create mode 100644 src/turnkeyml/llm/tools/ort_genai/oga.py create mode 100644 src/turnkeyml/llm/tools/perplexity.py create mode 100644 src/turnkeyml/llm/tools/ryzenai_npu/__init__.py create mode 100644 src/turnkeyml/llm/tools/ryzenai_npu/ryzenai_npu.py create mode 100644 test/llm_api.py diff --git a/.github/workflows/publish-to-test-pypi.yml b/.github/workflows/publish-to-test-pypi.yml index eff66388..7d05d25c 100644 --- a/.github/workflows/publish-to-test-pypi.yml +++ b/.github/workflows/publish-to-test-pypi.yml @@ -24,15 +24,20 @@ jobs: run: >- python -m pip install build --user - name: Build a binary wheel and a source tarball - run: >- + run: | python -m build --sdist --wheel --outdir dist/ . 
+ version=$(python setup.py --version) + echo "VERSION=$version" >> $GITHUB_ENV - name: Test wheel shell: bash -el {0} run: | python -m pip install --upgrade pip - pip install dist/*.whl + pip install "dist/turnkeyml-${{ env.VERSION }}-py3-none-any.whl" models=$(turnkey models-location --quiet) turnkey -i $models/selftest/linear.py discover export-pytorch + # Test LLMs as well + pip install "dist/turnkeyml-${{ env.VERSION }}-py3-none-any.whl[llm]" + lemonade -i facebook/opt-125m huggingface-load llm-prompt -p "Hello, my thoughts are" - name: Publish distribution package to PyPI if: startsWith(github.ref, 'refs/tags/v') uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/test_lemonade.yml b/.github/workflows/test_lemonade.yml new file mode 100644 index 00000000..89b53cdc --- /dev/null +++ b/.github/workflows/test_lemonade.yml @@ -0,0 +1,46 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Lint and Test Lemonade + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +permissions: + contents: read + +jobs: + make-lemonade: + env: + LEMONADE_CI_MODE: "True" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Miniconda with 64-bit Python + uses: conda-incubator/setup-miniconda@v2 + with: + miniconda-version: "latest" + activate-environment: lemon + python-version: "3.10" + - name: Install dependencies + shell: bash -el {0} + run: | + python -m pip install --upgrade pip + conda install pylint + python -m pip check + pip install -e .[llm] + - name: Lint with PyLint + shell: bash -el {0} + run: | + pylint src/turnkeyml/llm --rcfile .pylintrc --disable E0401 + - name: Run lemonade tests + shell: bash -el {0} + run: | + lemonade -i facebook/opt-125m huggingface-load llm-prompt -p "hi" --max-new-tokens 10 + + python test/llm_api.py + + diff --git a/.github/workflows/test_turnkey.yml b/.github/workflows/test_turnkey.yml index 12ebb06e..7c21c110 100644 --- a/.github/workflows/test_turnkey.yml +++ b/.github/workflows/test_turnkey.yml @@ -42,7 +42,7 @@ jobs: - name: Lint with PyLint shell: bash -el {0} run: | - pylint src/turnkeyml --rcfile .pylintrc + pylint src/turnkeyml --rcfile .pylintrc --ignore-paths src/turnkeyml/llm pylint examples --rcfile .pylintrc --disable E0401,E0611 - name: Test with unittest shell: bash -el {0} @@ -77,31 +77,31 @@ jobs: rm -rf ~/.cache/turnkey pip install -e examples/cli/plugins/example_tool turnkey -i examples/cli/scripts/hello_world.py discover export-pytorch example-plugin-tool benchmark - - name: Install and Start Slurm - if: runner.os != 'Windows' - shell: bash -el {0} - run: | - sudo apt update -y - sudo apt install slurm-wlm -y - cp test/helpers/slurm.conf test/helpers/slurm_modified.conf - sed -i "s/YOUR_HOSTNAME_HERE/$HOSTNAME/" test/helpers/slurm_modified.conf - sudo mv test/helpers/slurm_modified.conf /etc/slurm/slurm.conf - sudo service slurmd start - sudo service slurmctld start - sudo service munge start - - name: Test turnkey on Slurm - if: runner.os != 'Windows' - shell: bash -el {0} - run: | - # Create conda environment for Slurm using srun (sbatch + wait) - export SKIP_REQUIREMENTS_INSTALL="True" - export TORCH_CPU="True" - srun src/turnkeyml/cli/setup_venv.sh + # - name: Install and Start Slurm + # if: runner.os != 'Windows' + # shell: bash -el {0} + # run: | + # sudo apt update -y + # sudo apt install 
slurm-wlm -y + # cp test/helpers/slurm.conf test/helpers/slurm_modified.conf + # sed -i "s/YOUR_HOSTNAME_HERE/$HOSTNAME/" test/helpers/slurm_modified.conf + # sudo mv test/helpers/slurm_modified.conf /etc/slurm/slurm.conf + # sudo service slurmd start + # sudo service slurmctld start + # sudo service munge start + # - name: Test turnkey on Slurm + # if: runner.os != 'Windows' + # shell: bash -el {0} + # run: | + # # Create conda environment for Slurm using srun (sbatch + wait) + # export SKIP_REQUIREMENTS_INSTALL="True" + # export TORCH_CPU="True" + # srun src/turnkeyml/cli/setup_venv.sh - # Run tests on Slurm - export TURNKEY_SLURM_USE_DEFAULT_MEMORY="True" - turnkey -i models/selftest/linear.py --use-slurm --cache-dir local_cache discover export-pytorch - bash test/helpers/check_slurm_output.sh slurm-2.out + # # Run tests on Slurm + # export TURNKEY_SLURM_USE_DEFAULT_MEMORY="True" + # turnkey -i models/selftest/linear.py --use-slurm --cache-dir local_cache discover export-pytorch + # bash test/helpers/check_slurm_output.sh slurm-2.out # Below tests are commented out as the GitHub runner runs out of space installing the requirements # - name: Check installation of requirements.txt and their compatibility with turnkey diff --git a/README.md b/README.md index ce230669..534dfd58 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,12 @@ We are on a mission to make it easy to use the most important tools in the ONNX ecosystem. TurnkeyML accomplishes this by providing a no-code CLI, `turnkey`, as well as a low-code API, that provide seamless integration of these tools. +We also provide [`turnkey-llm`](https://github.com/onnx/turnkeyml/tree/main/src/turnkeyml/llm), which has LLM-specific tools for prompting, accuracy measurement, and serving on a variety of runtimes (Huggingface, onnxruntime-genai) and hardware (CPU, GPU, and NPU). + ## Getting Started +### Quick Start + The easiest way to get started is: 1. `pip install turnkeyml` 2. Copy a PyTorch example of a model, like the one on this [Huggingface BERT model card](https://huggingface.co/google-bert/bert-base-uncased), into a file named `bert.py`. @@ -21,6 +25,10 @@ output = model(**encoded_input) ``` 3. `turnkey -i bert.py discover export-pytorch`: make a BERT ONNX file from this `bert.py` example. +### LLMs + +For LLM setup instructions, see [`turnkey-llm`](https://github.com/onnx/turnkeyml/tree/main/src/turnkeyml/llm). + ## Demo Here's `turnkey` in action: BERT-Base is exported from PyTorch to ONNX using `torch.onnx.export`, optimized for inference with `onnxruntime`, and converted to fp16 with `onnxmltools`: diff --git a/docs/mmlu_accuracy.md b/docs/mmlu_accuracy.md new file mode 100644 index 00000000..426f655f --- /dev/null +++ b/docs/mmlu_accuracy.md @@ -0,0 +1,100 @@ + +# Using the MMLU accuracy test tools + +The Massive Multitask Language Understanding (MMLU) benchmark is a comprehensive evaluation framework designed to assess the capabilities of language models across a wide range of subjects and disciplines. It encompasses a diverse set of questions covering topics from humanities to natural sciences, aiming to measure a model's depth and breadth of knowledge and its ability to generalize across different types of language understanding tasks. For detailed list of subjects tested refer [here](#detailed-list-of-subjects-categories-tested). + +This tool provides an automated way to evaluate language models on the MMLU benchmark. 
It automates the process of downloading the dataset, preparing evaluation prompts, running the model to generate answers, and calculating accuracy metrics across different subjects within the MMLU dataset.
+
+## Dataset
+The MMLU dataset is downloaded automatically by the script into the data directory (by default, `<lemonade cache>/data/mmlu`) the first time you run the benchmark. The data is sourced from [here](https://people.eecs.berkeley.edu/~hendrycks/data.tar).
+
+## Running the Benchmark
+
+`lemonade -i facebook/opt-125m huggingface-load accuracy-mmlu --ntrain 5 --tests astronomy`
+
+### Optional arguments:
+
+`--ntrain`: The number of examples from each subject's development (dev) set that are prepended to the prompt as context for the test questions (default: 5).
+
+In few-shot learning with language models, "shots" are the examples provided to the model to help it understand or adapt to the task at hand without explicit training. Setting `--ntrain` to 5 gives the standard 5-shot MMLU setting. The model is expected to generate an answer to each test question based on the context provided by the preceding question-answer pairs.
+
+`--data-dir`: The directory where the MMLU data is stored (default: `<lemonade cache>/data/mmlu`).
+
+`--tests`: Specific tests to run, identified by their subject names. Accepts multiple test names.
+
+## How It Works
+
+1. `Data Preparation:` On the first run, the script downloads the MMLU dataset and extracts it into the specified data directory. It then prepares the data by reading the development and test sets for the specified subjects.
+
+1. `Prompt Generation:` For each subject, the script generates prompts from the development set to provide context for the test questions. This includes a configurable number of training examples (`--ntrain`) to help the model understand the task.
+
+1. `Model Evaluation:` The specified language model is used to generate answers to each test question. The testing methodology is adopted from [here](https://github.com/hendrycks/test).
+
+1. `Accuracy Calculation:` The script compares the model-generated answers against the correct answers to calculate accuracy metrics for each subject.
+
+1. `Saving Results:` Detailed results for each subject, including questions, prompts, correct and generated answers, and overall accuracy, are saved to CSV files in the specified results directory. A summary CSV file compiling accuracy metrics across all evaluated subjects is also generated and available in the cache directory (see the sketch below for one way to inspect it).
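The sketch below is purely illustrative and is not part of the tool: it shows one way to inspect `summary_results.csv` with pandas. The path assumes the default lemonade cache location and the build name that `lemonade` derives from the `facebook/opt-125m` checkpoint; adjust both to match your run.

```python
# Hypothetical post-processing of the MMLU summary written by accuracy-mmlu.
# Assumptions: default cache dir (~/.cache/lemonade) and build name
# "facebook_opt-125m"; adjust these to match your cache dir and checkpoint.
import os

import pandas as pd

summary_path = os.path.expanduser(
    "~/.cache/lemonade/facebook_opt-125m/mmlu/summary_results.csv"
)
df = pd.read_csv(summary_path)

# Per-subject accuracy as reported by the tool
print(df[["Subject", "Accuracy"]])

# Overall accuracy across all evaluated subjects, weighted by question count
overall = (df["Accuracy"] * df["Total Questions"]).sum() / df["Total Questions"].sum()
print(f"Weighted overall accuracy: {overall:.1%}")
```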
+ +## Detailed list of subjects/ categories tested + +| Test Subject | Category | +|----------------------------------|-------------------| +| Abstract Algebra | Math | +| Anatomy | Health | +| Astronomy | Physics | +| Business Ethics | Business | +| Clinical Knowledge | Health | +| College Biology | Biology | +| College Chemistry | Chemistry | +| College Computer Science | Computer Science | +| College Mathematics | Math | +| College Medicine | Health | +| College Physics | Physics | +| Computer Security | Computer Science | +| Conceptual Physics | Physics | +| Econometrics | Economics | +| Electrical Engineering | Engineering | +| Elementary Mathematics | Math | +| Formal Logic | Philosophy | +| Global Facts | Other | +| High School Biology | Biology | +| High School Chemistry | Chemistry | +| High School Computer Science | Computer Science | +| High School European History | History | +| High School Geography | Geography | +| High School Government and Politics | Politics | +| High School Macroeconomics | Economics | +| High School Mathematics | Math | +| High School Microeconomics | Economics | +| High School Physics | Physics | +| High School Psychology | Psychology | +| High School Statistics | Math | +| High School US History | History | +| High School World History | History | +| Human Aging | Health | +| Human Sexuality | Culture | +| International Law | Law | +| Jurisprudence | Law | +| Logical Fallacies | Philosophy | +| Machine Learning | Computer Science | +| Management | Business | +| Marketing | Business | +| Medical Genetics | Health | +| Miscellaneous | Other | +| Moral Disputes | Philosophy | +| Moral Scenarios | Philosophy | +| Nutrition | Health | +| Philosophy | Philosophy | +| Prehistory | History | +| Professional Accounting | Other | +| Professional Law | Law | +| Professional Medicine | Health | +| Professional Psychology | Psychology | +| Public Relations | Politics | +| Security Studies | Politics | +| Sociology | Culture | +| US Foreign Policy | Politics | +| Virology | Health | +| World Religions | Philosophy | diff --git a/docs/perplexity.md b/docs/perplexity.md new file mode 100644 index 00000000..0ab8bc9b --- /dev/null +++ b/docs/perplexity.md @@ -0,0 +1,72 @@ + +# Perplexity Evaluation + + +## Overview + +Perplexity is a measurement of how well a probability model predicts a sample. A lower perplexity indicates the model is more confident in its predictions. In the context of language models, perplexity measures the likelihood of the sequence according to the model, given as: + +`Perplexity (P) = exp(Average Negative Log-Likelihood)` + +`Where Average Negative Log-Likelihood = (1/N) * Sum[-log p(x_i) from i=1 to N]` + + +## Script Functionality + +### Key Components + +- **`max_length`**: The maximum input length the model can handle at once (set by the model's configuration). +- **`stride`**: The step size for the window, set to half of `max_length` to ensure some overlap and preserve context. +- **`seq_len`**: The total length of the tokenized input. + +### Detailed Steps + +1. **Load Model and Tokenizer**: Receive the model and tokenizer with specified configurations. +2. **Load and Prepare Data**: Loads the "wikitext-2-raw-v1" dataset and concatenates texts with double newlines. The data is then tokenized. +3. 
**Sliding Window Perplexity Calculation**: The script uses a sliding window approach (with a stride of half the window size) to calculate the perplexity for subsets of the data, adjusting for the maximum input length of the model:
+   - For each window, input data is processed, and the corresponding labels are adjusted to mask out irrelevant parts (using `-100`).
+   - The model computes the logits and loss for each window.
+   - Predicted and actual words at the end of each window are logged for analysis.
+4. **Logging to CSV**: Summarizes the context window, predicted and actual next words, and loss for each window into a CSV file for further analysis.
+5. **Perplexity Calculation**: Calculates the total negative log-likelihood adjusted by the effective token count for each window, then computes the average across all tokens to determine the perplexity.
+
+### Example Outputs
+
+The script outputs a CSV file named `summary_results.csv` with the following columns:
+
+- **Context (partial context displayed for brevity)**
+- **Predicted next word**
+- **Actual next word**
+- **Loss for this window**
+
+These entries help in understanding how the model is performing at each step of the text.
+
+## How to Interpret Perplexity Results
+
+### Understanding Perplexity
+
+**Definition:** Perplexity is the exponential of the average negative log-likelihood of a model on a given test set.
+
+**Lower values are better:** A lower perplexity score indicates that the model assigns a higher probability to the sample, i.e., it is more certain about its predictions and is performing better.
+
+### Interpretation
+
+**High Perplexity:** Indicates confusion or a high level of uncertainty in the model's predictions. A high perplexity can suggest that the model's language understanding is poor or that the model is not well-tuned for the given data.
+
+**Low Perplexity:** Suggests that the model's predictions are more accurate and that it assigns higher probabilities to the actual observed outcomes. This is indicative of a model that has a good grasp of the language patterns seen in the test set.
+
+### Practical Implications
+
+**Model Comparison:** Perplexity is particularly useful for comparing different versions of the same model (e.g., before and after quantization, fine-tuning, or training on additional data). The model with the lower perplexity is generally considered better at modeling the language of the test corpus.
+
+**Model Selection for Applications:** For applications involving language generation (like machine translation, text summarization, or chatbots), selecting a model with lower perplexity might result in more fluent, coherent, and contextually appropriate text output.
+
+**Diagnosing Model Fit:** High perplexity could indicate underfitting, where the model is too simple to capture the complexity of the language data. It can also help in diagnosing whether the model is well-suited for the specific domain of the text being modeled.
+
+### Caveats in Interpretation
+
+**Dependency on Test Set:** Perplexity is highly dependent on the test set used. A model can show very different perplexity scores on different datasets, so it's important to consider the nature and domain of the test set when evaluating perplexity.
+
+**Not a Complete Measure:** While perplexity measures how uncertain a model is about its predictions, it does not directly measure how coherent or contextually appropriate generated texts are.
Other qualitative assessments and metrics might be necessary to fully evaluate a language model's output. + +**Comparison Across Different Data:** Comparing perplexity scores across models trained or tested on different datasets can be misleading because the intrinsic difficulty of the datasets can affect the perplexity. + diff --git a/examples/llm/turnkey_llm.ipynb b/examples/llm/turnkey_llm.ipynb new file mode 100644 index 00000000..07d7e329 --- /dev/null +++ b/examples/llm/turnkey_llm.ipynb @@ -0,0 +1,184 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# LLMs on RyzenAI with TurnkeyML\n", + "\n", + "This notebook will demonstrate how to bring up an example application that uses a RyzenAI to perform LLM inference. We will use the `turnkeyml.llm` APIs in order to make this as quick as possible. This notebook makes use of both the `RyzenAI NPU`, as well as the `RyzenAI Radeon integrated GPU (iGPU)`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Application\n", + "\n", + "Our example application will prompt the user for input and then return the LLM's reponse. This is the same technology stack used to create AMD GAIA, which shows how Retrieval Augmented Generation (RAG), agentic workflows, and other advanced techniques can be layered on top of RyzenAI." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define our application: prompt the user and print the LLM's response\n", + "# We define this in a way that makes the NPU and iGPU interchangable\n", + "\n", + "def application(model, tokenizer):\n", + " while True:\n", + " # Prompt the user\n", + " user_prompt = input(\"What is your prompt to the LLM? \")\n", + " print(\"Prompt:\",user_prompt)\n", + "\n", + " # Exit the application if the user prompts \"exit\"\n", + " if user_prompt == \"exit\":\n", + " print(\"Done!\")\n", + " return\n", + "\n", + " # Tokenize the user's prompt\n", + " input_ids = tokenizer(user_prompt, return_tensors=\"pt\").input_ids\n", + "\n", + " # Generate the response\n", + " # Limit the response length to 30 tokens so that we have time to\n", + " # try a few prompts\n", + " llm_response = model.generate(input_ids, max_new_tokens=30)\n", + "\n", + " # Decode the response into text\n", + " decoded_response = tokenizer.decode(llm_response[0])\n", + "\n", + " # Print the response, then prompt for another input\n", + " print(\"Response:\",decoded_response)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## RyzenAI NPU Model Initialization\n", + "\n", + "### Prequisites for NPU\n", + "\n", + "- `ryzenai-transformers` conda environment is installed and activated.\n", + "- Access to `meta-llama/Llama-2-7b-chat-hf` on Hugging Face.\n", + "- Install the TurnkeyML-LLM in your `ryzenai-transformers` environment, see https://github.com/onnx/turnkeyml/tree/main/src/turnkeyml/llm/README.md#install-ryzenai-npu\n", + "- Also `pip install jupyter` in your `ryzenai-transformers` environment.\n", + "\n", + "### Starting Up\n", + "\n", + "- Run `conda activate ryzenai-transformers`\n", + "- Run `setup_phx.bat` or `setup_stx.bat` on your PHX (RyzenAI 7000 or RyzenAI 300, respectively)\n", + "- Run `jupyter notebook`\n", + "- Copy the URL printed from the previous command, and use that as the kernel for this notebook.\n", + " - Example: \n", + " > Or copy and paste one of these URLs:\n", + " >\n", + " > 
http://localhost:8888/tree?token=14796c43ce39ef9a3296b7c7c26335e01f7bdc8b0fd4efce\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the turnkey APIs\n", + "from turnkeyml.llm import leap\n", + "\n", + "# Load the model on to RyzenAI NPU\n", + "# NOTE: this takes a couple of minutes, but after you've done it once\n", + "# you can keep reusing the `model` instance in subsequent notebook cells\n", + "npu_model, npu_tokenizer = leap.from_pretrained(\n", + " \"meta-llama/Llama-2-7b-chat-hf\", recipe=\"ryzenai-npu\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## NPU Application" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run the application on NPU\n", + "# User should prompt \"exit\" to stop the application\n", + "application(npu_model, npu_tokenizer)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Radeon iGPU Initialization\n", + "\n", + "### Prequisites for iGPU\n", + "\n", + "- `turnkeyml[llm-oga-dml]` is installed into an activated conda environment.\n", + "- Download a copy of `Phi-3-mini`\n", + "- See https://github.com/onnx/turnkeyml/tree/main/src/turnkeyml/llm/README.md#install-onnxruntime-genai for details" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the turnkey APIs\n", + "from turnkeyml.llm import leap\n", + "\n", + "# Load the model on iGPU\n", + "igpu_model, igpu_tokenizer = leap.from_pretrained(\n", + " \"microsoft/Phi-3-mini-4k-instruct\", recipe=\"oga-dml-igpu\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Radeon iGPU Application" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run the application on iGPU\n", + "# User should prompt \"exit\" to stop the application\n", + "application(igpu_model, igpu_tokenizer)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/setup.py b/setup.py index 87857efb..1cdc8cd5 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,10 @@ "turnkeyml.sequence", "turnkeyml.cli", "turnkeyml.common", + "turnkeyml.llm", + "turnkeyml.llm.tools", + "turnkeyml.llm.tools.ort_genai", + "turnkeyml.llm.tools.ryzenai_npu", "turnkeyml_models", "turnkeyml_models.graph_convolutions", "turnkeyml_models.selftest", @@ -44,10 +48,46 @@ "GitPython>=3.1.40", "psutil", ], + extras_require={ + "llm": [ + "tqdm", + "torch>=2.0.0", + "transformers", + "accelerate", + "py-cpuinfo", + "sentencepiece", + "datasets", + "fastapi", + "uvicorn[standard]", + ], + "llm-oga-dml": [ + "onnxruntime-directml==1.19.0", + "onnxruntime-genai-directml==0.4.0", + "tqdm", + "torch>=2.0.0", + "transformers", + "accelerate", + "py-cpuinfo", + "sentencepiece", + "datasets", + "fastapi", + "uvicorn[standard]", + ], + "llm-oga-npu": [ + "transformers", + "torch", + "onnx==1.16.0", + "onnxruntime==1.18.0", + "numpy==1.26.4", + "uvicorn[standard]", + ], + }, classifiers=[], entry_points={ "console_scripts": 
[ "turnkey=turnkeyml:turnkeycli", + "turnkey-llm=turnkeyml.llm:lemonadecli", + "lemonade=turnkeyml.llm:lemonadecli", ] }, python_requires=">=3.8, <3.12", diff --git a/src/turnkeyml/llm/README.md b/src/turnkeyml/llm/README.md new file mode 100644 index 00000000..92308ea8 --- /dev/null +++ b/src/turnkeyml/llm/README.md @@ -0,0 +1,129 @@ +# Turnkey-LLM + +Welcome to the project page for `turnkey-llm` (aka, "lemonade" the turnkey LLM Aide)! +Contents: + +1. [Getting Started](#getting-started) +1. [Install Specialized Tools](#install-specialized-tools) +1. [Code Organization](#code-organization) +1. [Contributing](#contributing) + +# Getting Started + +`turnkey-llm` introduces a brand new set of LLM-focused tools. + +## Install + +1. Clone: `git clone https://github.com/onnx/turnkeyml.git` +1. `cd turnkeyml` (where `turnkeyml` is the repo root of your TurnkeyML clone) + - Note: be sure to run these installation instructions from the repo root. +1. Create and activate a conda environment: + 1. `conda create -n tk-llm python=3.10` + 1. `conda activate tk-llm` +1. Install lemonade: `pip install -e .[llm]` + - or `pip install -e .[llm-oga-dml]` if you want to use `onnxruntime-genai` (see [OGA](#install-onnxruntime-genai)) +1. `lemonade -h` to explore the LLM tools + +## Syntax + +The `lemonade` CLI uses the same style of syntax as `turnkey`, but with a new set of LLM-specific tools. You can read about that syntax [here](https://github.com/onnx/turnkeyml#how-it-works). + +## Chatting + +To chat with your LLM try: + +`lemonade -i facebook/opt-125m huggingface-load llm-prompt -p "Hello, my thoughts are"` + +The LLM will run on CPU with your provided prompt, and the LLM's response to your prompt will be printed to the screen. You can replace the `"Hello, my thoughts are"` with any prompt you like. + +You can also replace the `facebook/opt-125m` with any Huggingface checkpoint you like, including LLaMA-2, Phi-2, Qwen, Mamba, etc. + +You can also set the `--device` argument in `huggingface-load` to load your LLM on a different device. + +Run `lemonade huggingface-load -h` and `lemonade llm-prompt -h` to learn more about those tools. + +## Accuracy + +To measure the accuracy of an LLM using MMLU, try this: + +`lemonade -i facebook/opt-125m huggingface-load accuracy-mmlu --tests management` + +That command will run just the management test from MMLU on your LLM and save the score to the lemonade cache at `~/.cache/lemonade`. + +You can run the full suite of MMLU subjects by omitting the `--test` argument. You can learn more about this with `lemonade accuracy-mmlu -h. + +## Serving + +You can launch a WebSocket server for your LLM with: + +`lemonade -i facebook/opt-125m huggingface-load serve` + +Once the server has launched, you can connect to it from your own application, or interact directly by following the on-screen instructions to open a basic web app. + +Note that the `llm-prompt`, `accuracy-mmlu`, and `serve` tools can all be used with other model-loading tools, for example `onnxruntime-genai` or `ryzenai-transformers`. See [Install Specialized Tools](#install-specialized-tools) for details. + +## API + +Lemonade is also available via API. 
Here's a quick example of how to load and prompt an LLM:
+
+```python
+import turnkeyml.llm.tools.huggingface_load as hf
+import turnkeyml.llm.tools.chat as cl
+from turnkeyml.state import State
+
+state = State(cache_dir="cache", build_name="test")
+
+state = hf.HuggingfaceLoad().run(state, input="facebook/opt-125m")
+state = cl.LLMPrompt().run(state, prompt="hi", max_new_tokens=15)
+
+print("Response:", state.response)
+```
+
+# Install Specialized Tools
+
+Lemonade supports specialized tools that each require their own setup steps. **Note:** These tools will only appear in `lemonade -h` if you run in an environment that has completed setup.
+
+## Install OnnxRuntime-GenAI
+
+To install support for [onnxruntime-genai](https://github.com/microsoft/onnxruntime-genai) (e.g., the `oga-load` Tool), use `pip install -e .[llm-oga-dml]` instead of the default installation command.
+
+Next, you need to get an OGA model. Per the OGA instructions, we suggest Phi-3-Mini. Use the following command to download it from Hugging Face, making sure to set `--local-dir` to the `REPO_ROOT/src/turnkeyml/llm/tools/ort_genai/models` directory.
+
+`huggingface-cli download microsoft/Phi-3-mini-4k-instruct-onnx --include directml/directml-int4-awq-block-128* --local-dir REPO_ROOT/src/turnkeyml/llm/tools/ort_genai/models/phi-3-mini-4k-instruct`
+
+You can try it out with: `lemonade -i microsoft/Phi-3-mini-4k-instruct oga-load --device igpu --dtype int4 llm-prompt -p "Hello, my thoughts are"`
+
+You can also try Phi-3-Mini-128k-Instruct with the following commands:
+
+`huggingface-cli download microsoft/Phi-3-mini-128k-instruct-onnx --include directml/directml-int4-awq-block-128* --local-dir REPO_ROOT/src/turnkeyml/llm/tools/ort_genai/models/phi-3-mini-128k-instruct`
+
+`lemonade -i microsoft/Phi-3-mini-128k-instruct oga-load --device igpu --dtype int4 llm-prompt -p "Hello, my thoughts are"`
+
+> Note: no other models or devices are officially supported by `lemonade` on OGA at this time. Contributions appreciated!
+
+## Install Ryzen AI NPU
+
+To run your LLMs on the Ryzen AI NPU, first install and set up the `ryzenai-transformers` conda environment (see instructions [here](https://github.com/amd/RyzenAI-SW/tree/main/example/transformers)). Then, install `lemonade` into `ryzenai-transformers`. The `ryzenai-npu-load` Tool will become available in that environment.
+
+You can try it out with: `lemonade -i meta-llama/Llama-2-7b-chat-hf ryzenai-npu-load --device DEVICE llm-prompt -p "Hello, my thoughts are"`
+
+Where `DEVICE` is either "phx" or "stx" if you have a RyzenAI 7xxx/8xxx or 3xx/9xxx processor, respectively.
+
+> Note: only `meta-llama/Llama-2-7b-chat-hf` and `microsoft/Phi-3-mini-4k-instruct` are supported by `lemonade` at this time. Contributions appreciated!
+
+# Contributing
+
+If you decide to contribute, please:
+
+- do so via a pull request.
+- write your code in keeping with the same style as the rest of this repo's code.
+- add a test under `test/llm_api.py` that provides coverage of your new feature.
+
+The best way to contribute is to add new tools to cover more devices and usage scenarios.
+
+To add a new tool:
+
+1. (Optional) Create a new `.py` file under `src/turnkeyml/llm/tools` (or use an existing file if your tool fits into a pre-existing family of tools).
+1. Define a new class that inherits the `Tool` class from `TurnkeyML`.
+1. Register the class by adding it to the list of `tools` near the top of `src/turnkeyml/llm/cli.py` (see the sketch below).
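For illustration, here is a rough skeleton of a new tool. It follows the pattern of the tools added in this PR (e.g., `LLMPrompt` in `src/turnkeyml/llm/tools/chat.py`); the `my-tool` name, the `--example-arg` flag, and the `example_stat` stat are placeholders, not anything that exists in the codebase.

```python
# Illustrative only: a minimal Tool skeleton modeled on the tools in this PR.
# The tool name, argument, and stat below are placeholders.
import argparse

from turnkeyml.state import State
from turnkeyml.tools import Tool


class MyTool(Tool):
    """
    Placeholder tool: consumes the State produced by earlier tools in the
    sequence (e.g., state.model and state.tokenizer), does some work, and
    records a stat before passing State along.
    """

    unique_name = "my-tool"

    def __init__(self):
        super().__init__(monitor_message="Running my placeholder tool")
        self.status_stats = ["example_stat"]

    @staticmethod
    def parser(add_help: bool = True) -> argparse.ArgumentParser:
        parser = __class__.helpful_parser(
            short_description="One-line description shown in lemonade -h",
            add_help=add_help,
        )
        parser.add_argument(
            "--example-arg", type=int, default=1, help="A placeholder argument"
        )
        return parser

    def run(self, state: State, example_arg: int = 1) -> State:
        # Do the tool's work here, then save any stats you want reported
        state.save_stat("example_stat", example_arg)
        return state
```

After defining the class, register `MyTool` in the `tools` list near the top of `src/turnkeyml/llm/cli.py` so that it appears in `lemonade -h`.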
diff --git a/src/turnkeyml/llm/__init__.py b/src/turnkeyml/llm/__init__.py new file mode 100644 index 00000000..f4899d7d --- /dev/null +++ b/src/turnkeyml/llm/__init__.py @@ -0,0 +1 @@ +from .cli import main as lemonadecli diff --git a/src/turnkeyml/llm/cache.py b/src/turnkeyml/llm/cache.py new file mode 100644 index 00000000..5c0241f8 --- /dev/null +++ b/src/turnkeyml/llm/cache.py @@ -0,0 +1,32 @@ +import os + +# Allow an environment variable to override the default +# location for the build cache +if os.environ.get("LEMONADE_CACHE_DIR"): + DEFAULT_CACHE_DIR = os.path.expanduser(os.environ.get("LEMONADE_CACHE_DIR")) +else: + DEFAULT_CACHE_DIR = os.path.expanduser("~/.cache/lemonade") + + +def checkpoint_to_model_name(checkpoint_name: str) -> str: + """ + Get the model's name by stripping the author's name from the checkpoint name + """ + + return checkpoint_name.split("/")[1] + + +class Keys: + MODEL = "model" + PER_ITERATION_LATENCY = "per_iteration_latency" + MEAN_LATENCY = "mean_latency" + STD_DEV_LATENCY = "std_dev_latency" + MEAN_TOKENS_PER_SECOND = "mean_tokens_per_second" + STD_DEV_TOKENS_PER_SECOND = "std_dev_tokens_per_second" + SECONDS_TO_FIRST_TOKEN = "seconds_to_first_token" + STD_DEV_SECONDS_TO_FIRST_TOKEN = "std_dev_seconds_to_first_token" + CHECKPOINT = "checkpoint" + DTYPE = "dtype" + PROMPT_TOKENS = "prompt_tokens" + CACHE_DIR = "cache_dir" + DEVICE = "device" diff --git a/src/turnkeyml/llm/cli.py b/src/turnkeyml/llm/cli.py new file mode 100644 index 00000000..1151b2f9 --- /dev/null +++ b/src/turnkeyml/llm/cli.py @@ -0,0 +1,124 @@ +import os +from turnkeyml.tools import FirstTool, NiceHelpFormatter +import turnkeyml.common.filesystem as fs +import turnkeyml.cli.cli as cli +from turnkeyml.sequence import Sequence +from turnkeyml.tools.management_tools import Cache, Version +from turnkeyml.tools.report import Report +from turnkeyml.state import State + +from turnkeyml.llm.tools.huggingface_load import ( + HuggingfaceLoad, + AdaptHuggingface, +) + +import turnkeyml.llm.cache as cache +from turnkeyml.llm.tools.mmlu import AccuracyMMLU +from turnkeyml.llm.tools.perplexity import AccuracyPerplexity +from turnkeyml.llm.tools.chat import LLMPrompt, Serve + + +def main(): + + # List the available tools + tools = [ + HuggingfaceLoad, + AccuracyMMLU, + AccuracyPerplexity, + LLMPrompt, + AdaptHuggingface, + Serve, + # Inherited from TurnkeyML + Report, + Cache, + Version, + ] + + # Import onnxruntime-genai recipes + try: + from turnkeyml.llm.tools.ort_genai.oga import OgaLoad + + tools = tools + [OgaLoad] + + except ModuleNotFoundError: + pass + + # Import RyzenAI NPU modules only if RyzenAI NPU is installed + try: + from turnkeyml.llm.tools.ryzenai_npu.ryzenai_npu import RyzenAINPULoad + + tools = tools + [RyzenAINPULoad] + except ModuleNotFoundError: + pass + + + + + + # Define the argument parser + parser = cli.CustomArgumentParser( + description="Turnkey analysis and benchmarking of GenAI models. " + "This utility is a toolchain. 
To use it, provide a list of tools and " + "their arguments.", + formatter_class=NiceHelpFormatter, + ) + + parser.add_argument( + "-i", + "--input", + help="The input that will be evaluated by the tool sequence " + "(e.g., huggingface checkpoints)", + ) + + parser.add_argument( + "-d", + "--cache-dir", + help="Cache directory where the results of each tool will " + f"be stored (defaults to {cache.DEFAULT_CACHE_DIR})", + required=False, + default=cache.DEFAULT_CACHE_DIR, + ) + + parser.add_argument( + "--lean-cache", + dest="lean_cache", + help="Delete all build artifacts (e.g., .onnx files) when the command completes", + action="store_true", + ) + + global_args, tool_instances, evaluation_tools = cli.parse_tools(parser, tools) + + if len(evaluation_tools) > 0: + if not issubclass(evaluation_tools[0], FirstTool): + parser.error( + "The first tool in the sequence needs to be one " + "of the 'tools that can start a sequence.' Use " + "`turnkey-llm -h` to see that list of tools." + ) + # Run the evaluation tools as a build + sequence = Sequence(tools=tool_instances) + + # Forward the selected input to the first tool in the sequence + first_tool_args = next(iter(sequence.tools.values())) + first_tool_args.append("--input") + first_tool_args.append(global_args["input"]) + + state = State( + cache_dir=global_args["cache_dir"], + build_name=global_args["input"].replace("/", "_"), + sequence_info=sequence.info, + ) + sequence.launch( + state, + lean_cache=global_args["lean_cache"], + ) + else: + # Run the management tools + for management_tool, argv in tool_instances.items(): + # Support "~" in the cache_dir argument + parsed_cache_dir = os.path.expanduser(global_args[fs.Keys.CACHE_DIR]) + management_tool.parse_and_run(parsed_cache_dir, argv) + + +if __name__ == "__main__": + main() diff --git a/src/turnkeyml/llm/leap.py b/src/turnkeyml/llm/leap.py new file mode 100644 index 00000000..78f17868 --- /dev/null +++ b/src/turnkeyml/llm/leap.py @@ -0,0 +1,143 @@ +# pylint: disable=no-member + +from typing import Tuple, Dict +from turnkeyml.state import State +import turnkeyml.common.printing as printing +import turnkeyml.llm.cache as cache +from turnkeyml.llm.tools.adapter import ModelAdapter, TokenizerAdapter + + +class NotSupported(Exception): + """ + Indicates that a checkpoint/recipe pair are not supported + together at this time. + """ + + def __init__(self, msg): + super().__init__(msg) + printing.log_error(msg) + + +def _raise_not_supported(recipe, checkpoint): + raise NotSupported( + f"Recipe {recipe} does not have support for checkpoint {checkpoint}" + ) + + +def _make_state(recipe, checkpoint) -> Dict: + return State(cache_dir=cache.DEFAULT_CACHE_DIR, build_name=f"{checkpoint}_{recipe}") + + +class HuggingfaceCudaTokenizer(TokenizerAdapter): + """ + Wrap the Huggingface tokenizer class by sending the encoded + tokenizer inputs to the dGPU. + + This allows LEAP recipes to be fungible by saving the user the + additional step of managing the input's device location. + """ + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def __call__(self, prompt, **kwargs): + return self.tokenizer(prompt, **kwargs).to(device="cuda") + + def decode(self, response, **kwargs): + return self.tokenizer.decode(response, **kwargs) + + +def from_pretrained( + checkpoint: str, + recipe: str = "hf-cpu", +) -> Tuple[ModelAdapter, TokenizerAdapter]: + """ + Load an LLM and the corresponding tokenizer using a bespoke lemonade recipe. + + Not all recipes are available with all checkpoints. 
A leap.NotSupported exception + will be raised in these cases. + + Args: + - checkpoint: huggingface checkpoint that defines the LLM + - recipe: defines the implementation and hardware used for the LLM + + Recipe choices: + - hf-cpu: Huggingface Transformers implementation for CPU with max-perf settings + - hf-dgpu: Huggingface Transformers implementation on dGPU (via device="cuda") + - dml-og-igpu: DirectML implementation for iGPU based on onnxruntime-genai + - ryzenai-npu: RyzenAI implementation of huggingface transformers PyTorch model + + Returns: + - model: LLM instance with a generate() method that invokes the recipe + - tokenizer: tokenizer instance compatible with the model, which supports + the encode (call) and decode() methods. + """ + + if recipe == "hf-cpu": + # Huggingface Transformers recipe for CPU + # Huggingface supports all checkpoints, so there is nothing to check for + + import torch + from turnkeyml.llm.tools.huggingface_load import HuggingfaceLoad + + state = _make_state(recipe, checkpoint) + + state = HuggingfaceLoad().run( + state, + input=checkpoint, + dtype=torch.bfloat16, + ) + + return state.model, state.tokenizer + + if recipe == "hf-dgpu": + # Huggingface Transformers recipe for discrete GPU (Nvidia, Instinct, Radeon) + + import torch + from turnkeyml.llm.tools.huggingface_load import HuggingfaceLoad + + state = _make_state(recipe, checkpoint) + + state = HuggingfaceLoad().run( + state, + input=checkpoint, + dtype=torch.bfloat16, + device="cuda", + ) + + # Wrap the tokenizer to ensure that inputs are placed on the dGPU device + tokenizer = HuggingfaceCudaTokenizer(state.tokenizer) + + return state.model, tokenizer + + elif recipe == "oga-dml-igpu": + import turnkeyml.llm.tools.ort_genai.oga as oga + + state = _make_state(recipe, checkpoint) + + state = oga.OgaLoad().run( + state, + device="igpu", + dtype="int4", + ) + + return state.model, state.tokenizer + + elif recipe == "ryzenai-npu": + if ( + checkpoint != "TheBloke/Llama-2-7b-Chat-fp16" + and checkpoint != "meta-llama/Llama-2-7b-chat-hf" + and checkpoint != "microsoft/Phi-3-mini-4k-instruct" + ): + _raise_not_supported(recipe, checkpoint) + + import turnkeyml.llm.tools.ryzenai_npu.ryzenai_npu as ryzenai_npu + + state = _make_state(recipe, checkpoint) + + state = ryzenai_npu.RyzenAINPULoad().run(state, checkpoint, device="phx") + + return state.model, state.tokenizer + + else: + _raise_not_supported(recipe, checkpoint) diff --git a/src/turnkeyml/llm/tools/__init__.py b/src/turnkeyml/llm/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/turnkeyml/llm/tools/adapter.py b/src/turnkeyml/llm/tools/adapter.py new file mode 100644 index 00000000..01824371 --- /dev/null +++ b/src/turnkeyml/llm/tools/adapter.py @@ -0,0 +1,82 @@ +import abc + + +class ModelAdapter(abc.ABC): + """ + Base class for adapting an LLM to work with lemonade's standardized tools + """ + + def __init__(self): + """ + Self-benchmarking ModelAdapters can store their results in the + tokens_per_second and time_to_first_token members. + """ + self.tokens_per_second = None + self.time_to_first_token = None + self.type = "generic" + + @abc.abstractmethod + def generate(self, input_ids, max_new_tokens=512): + """ + Generate is the primary method required by lemonade's accuracy tools + + We try to keep the signature here minimal to allow for maximum compatibility + with recipe components, which themselves may not support a lot of arguments. 
+ """ + + +class TokenizerAdapter(abc.ABC): + """ + Base class for adapting an LLM's tokenizer to work with lemonade's standard tools + """ + + @abc.abstractmethod + def __call__(self, prompt: str): + """ + Args: + prompt: text that should be encoded and passed to the LLM as input_ids + + Returns: input_ids + """ + + @abc.abstractmethod + def decode(self, response) -> str: + """ + Args: + response: tokens from the LLM that should be decoded into text + + Returns: text response of the LLM + """ + + +class PassthroughTokenizerResult: + """ + Data structure for holding a tokenizer result where the input_ids + are packaged in a non-standard way, but we still want to adhere to + standard interfaces (e.g., result.input_ids). + + For example: CLI-based tools that have their own internal tokenizer that + isn't exposed to the user. In this case we can pass the prompt through as + a string. + """ + + def __init__(self, prompt): + self.input_ids = prompt + + +class PassthroughTokenizer(TokenizerAdapter): + """ + Tokenizer adapter that forwards the prompt to input_ids as text, + and then forwards a text LLM response through decode() as text. + + Useful for CLI-based tools that have their own internal tokenizer that + isn't exposed to the user. + """ + + # pylint: disable=unused-argument + def __call__(self, prompt: str, **kwargs): + return PassthroughTokenizerResult(prompt) + + # pylint: disable=unused-argument + def decode(self, response: str, **kwargs): + return response diff --git a/src/turnkeyml/llm/tools/chat.py b/src/turnkeyml/llm/tools/chat.py new file mode 100644 index 00000000..21491038 --- /dev/null +++ b/src/turnkeyml/llm/tools/chat.py @@ -0,0 +1,257 @@ +import argparse +from threading import Thread +import asyncio +from fastapi import FastAPI, WebSocket +from fastapi.responses import HTMLResponse +from pydantic import BaseModel +from transformers import TextIteratorStreamer +import uvicorn +from turnkeyml.state import State +from turnkeyml.tools import Tool +from turnkeyml.llm.tools.adapter import ModelAdapter, TokenizerAdapter + + +class LLMPrompt(Tool): + """ + Send a prompt to an LLM instance and print the response to the screen. + + Required input state: + - state.model: LLM instance that supports the generate() method. + - state.tokenizer: LLM tokenizer instance that supports the __call__() (ie, encode) + and decode() methods. + + Output state produced: + - "response": text response from the LLM. 
+ """ + + unique_name = "llm-prompt" + + def __init__(self): + super().__init__(monitor_message="Prompting LLM") + + self.status_stats = ["response"] + + @staticmethod + def parser(add_help: bool = True) -> argparse.ArgumentParser: + parser = __class__.helpful_parser( + short_description="Prompt an LLM and print the result", + add_help=add_help, + ) + + parser.add_argument("--prompt", "-p", help="Prompt input to the LLM") + + parser.add_argument( + "--max-new-tokens", + "-m", + default=512, + type=int, + help="Maximum number of new tokens in the response", + ) + + return parser + + def run( + self, + state: State, + prompt: str = "Hello", + max_new_tokens: int = 512, + ) -> State: + + model: ModelAdapter = state.model + tokenizer: TokenizerAdapter = state.tokenizer + + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + response = model.generate(input_ids, max_new_tokens=max_new_tokens) + response_text = tokenizer.decode(response[0], skip_special_tokens=True).strip() + + state.response = response_text + state.save_stat("response", response_text) + + return state + + +class Serve(Tool): + """ + Open a web server that apps can use to communicate with the LLM. + + There are two ways interact with the server: + - Send an http request to "http://localhost:8000/generate" and + receive back a response with the complete prompt. + - Open a WebSocket with "ws://localhost:8000" and receive a + streaming response to the prompt. + + The WebSocket functionality is demonstrated by the webpage served at + http://localhost:8000, which you can visit with a web browser after + opening the server. + + Required input state: + - state.model: model instance serve. Must be compatible with the + huggingface TextIteratorStreamer. + - state.tokenizer: tokenizer instance used to generate inputs for the + model. Must be compatible with the huggingface TextIteratorStreamer. + + Output state produced: None + """ + + unique_name = "serve" + + def __init__(self): + # Disable the build logger since the server is interactive + super().__init__( + monitor_message="Launching LLM Server", + enable_logger=False, + ) + + @staticmethod + def parser(add_help: bool = True) -> argparse.ArgumentParser: + parser = __class__.helpful_parser( + short_description="Open an HTTP server for the model", + add_help=add_help, + ) + + parser.add_argument( + "--max-new-tokens", + required=False, + type=int, + default=300, + help="Number of new tokens the LLM should make (default: 300)", + ) + + return parser + + def run( + self, + state: State, + max_new_tokens: int = 300, + ) -> State: + + # Disable the build monitor since the server is persistent and interactive + if self.progress: + self.progress.terminate() + print("\n") + + app = FastAPI() + + # Load the model and tokenizer + model = state.model + tokenizer = state.tokenizer + + class Message(BaseModel): + text: str + + html = """ + + + + Chat + + +

+            [HTML for the "Lemonade Chat" demo page: a basic chat UI plus a WebSocket client script that streams responses from the /ws endpoint]
+ + + + """ + + @app.get("/") + async def get(): + return HTMLResponse(html) + + @app.post("/generate") + async def generate_response(message: Message): + input_ids = tokenizer(message.text, return_tensors="pt").input_ids + response = model.generate( + input_ids, + max_new_tokens=max_new_tokens, + do_sample=True, + top_k=50, + top_p=0.95, + temperature=0.7, + pad_token_id=tokenizer.eos_token_id, + ) + generated_text = tokenizer.decode(response[0], skip_special_tokens=True) + + # Remove the input prompt from the generated text + generated_text = generated_text.replace(message.text, "").strip() + + return {"response": generated_text} + + @app.websocket("/ws") + async def stream_response(websocket: WebSocket): + await websocket.accept() + while True: + + message = await websocket.receive_text() + + if message == "done": + break + input_ids = tokenizer(message, return_tensors="pt").input_ids + + # Set up the generation parameters + if isinstance(model, ModelAdapter) and model.type == "ort-genai": + # Onnxruntime-genai models + import turnkeyml.llm.tools.ort_genai.oga as oga + + streamer = oga.OrtGenaiStreamer(tokenizer) + + else: + # Huggingface-like models + streamer = TextIteratorStreamer( + tokenizer, + skip_prompt=True, + ) + generation_kwargs = { + "input_ids": input_ids, + "streamer": streamer, + "max_new_tokens": max_new_tokens, + "do_sample": True, + "top_k": 50, + "top_p": 0.95, + "temperature": 0.7, + "pad_token_id": tokenizer.eos_token_id, + } + + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + # Generate the response using streaming + for new_text in streamer: + print(new_text, end="", flush=True) + + # Send the generated text to the client + await asyncio.sleep(0.1) # Add a small delay (adjust as needed) + await websocket.send_text(new_text) + + print("\n") + thread.join() + + await websocket.close() + + uvicorn.run(app, host="localhost", port=8000) + + return state diff --git a/src/turnkeyml/llm/tools/huggingface_load.py b/src/turnkeyml/llm/tools/huggingface_load.py new file mode 100644 index 00000000..789702d8 --- /dev/null +++ b/src/turnkeyml/llm/tools/huggingface_load.py @@ -0,0 +1,252 @@ +import argparse +from typing import Dict, Optional +import json +import transformers +import torch +from turnkeyml.state import State +import turnkeyml.common.status as status +from turnkeyml.tools import Tool, FirstTool +from turnkeyml.llm.tools.adapter import ModelAdapter +from turnkeyml.llm.cache import Keys + +# Command line interfaces for tools will use string inputs for data +# types, however the internal tool logic will need to know the actual +# torch type +str_to_dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "int8_static": torch.int8, + "int8_dynamic": torch.int8, +} + + +def make_example_inputs(state: State) -> Dict: + """ + Create a dictionary of LLM inputs that can be passed as an argument + into quantization, ONNX export, etc. + """ + + tokenizer = state.tokenizer + inputs_ids = tokenizer("Hello there", return_tensors="pt").input_ids + return {"input_ids": inputs_ids} + + +class HuggingfaceLoad(FirstTool): + """ + Load an LLM as a torch.nn.Module using the Hugging Face transformers + from_pretrained() API. + + Expected input: a checkpoint to load + + Output state produced: + - state.model: instance of torch.nn.Module that implements an LLM. + - state.inputs: tokenized example inputs to the model, in the form of a + dictionary of kwargs. 
+ - state.tokenizer: instance of Hugging Face PretrainedTokenizer. + - state.dtype: data type of the model. + - state.checkpoint: pretrained checkpoint used to load the model. + """ + + unique_name = "huggingface-load" + + def __init__(self): + super().__init__(monitor_message="Loading Huggingface checkpoint") + + self.status_stats = [Keys.DTYPE] + + @staticmethod + def parser(add_help: bool = True) -> argparse.ArgumentParser: + parser = __class__.helpful_parser( + short_description="Load an LLM as torch.nn.Module using huggingface from_pretrained()", + add_help=add_help, + ) + + default_dtype = "float32" + parser.add_argument( + "--dtype", + "-d", + required=False, + default=default_dtype, + help=f"Data type to load the model in (default: {default_dtype}).", + ) + + choices = ["cpu", "cuda"] + for cuda in range(15): + choices.append(f"cuda:{cuda}") + parser.add_argument( + "--device", + required=False, + default=None, + choices=choices, + help="Move the model and inputs to a device using the .to() method " + "(default: don't call the .to() method)", + ) + + parser.add_argument( + "--load-kwargs", + required=False, + default="{}", + type=json.loads, + help="Arbitrary kwargs, in json format, that will be passed as " + "from_pretrained(**kwargs). " + r"Example: --load-kwargs='{\"trust_remote_code\": true} would result in " + "from_pretrained(trust_remote_code=True)", + ) + + parser.add_argument( + "--channels-last", + default=True, + type=bool, + help="Whether to format the model in memory using " + "channels-last (default: True)", + ) + + return parser + + def parse(self, state: State, args, known_only=True) -> argparse.Namespace: + + parsed_args = super().parse(state, args, known_only) + + # Save stats about the user's input (do this prior to decoding) + state.save_stat(Keys.CHECKPOINT, parsed_args.input) + state.save_stat(Keys.DTYPE, parsed_args.dtype) + + # Decode dtype arg into a torch value + parsed_args.dtype = str_to_dtype[parsed_args.dtype] + + return parsed_args + + def run( + self, + state: State, + input: str = "", + dtype: torch.dtype = torch.float32, + device: Optional[str] = None, + load_kwargs: Optional[Dict] = None, + channels_last: bool = True, + ) -> State: + + checkpoint = input + + if load_kwargs is None: + load_kwargs_to_use = {} + else: + load_kwargs_to_use = load_kwargs + + if vars(state).get(Keys.MODEL): + raise ValueError("HuggingfaceLoad must be the first tool in the sequence") + + model = transformers.AutoModelForCausalLM.from_pretrained( + checkpoint, + torch_dtype=dtype, + low_cpu_mem_usage=True, + **load_kwargs_to_use, + ) + + # Only call the model.to() method if an argument to this function + # provides a reason to do so + to_args = {} + if channels_last: + to_args["memory_format"] = torch.channels_last + if device: + to_args["device"] = device + if to_args: + model.to(**to_args) + + model = model.eval() + + try: + tokenizer = transformers.AutoTokenizer.from_pretrained( + checkpoint, use_fast=False, model_max_length=4096, padding_side="left" + ) + except ValueError: + # Sometimes those specific tokenizer flags are not supported, in which + # case we try to just load a simple tokenizer + tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) + + # Pass the model and inputs into state + state.model = model + state.tokenizer = tokenizer + state.dtype = dtype + state.checkpoint = checkpoint + state.device = device + + # Save stats about the model + state.save_stat(Keys.CHECKPOINT, checkpoint) + state.save_stat(Keys.DTYPE, str(dtype).split(".")[1]) + 
state.save_stat(Keys.DEVICE, device) + + # Create a UniqueInvocationInfo and ModelInfo so that we can display status + # at the end of the sequence + status.add_to_state(state=state, name=input, model=model) + + return state + + +class HuggingfaceAdapter(ModelAdapter): + """ + Wrapper class for Huggingface LLMs that set generate() arguments to + make them more accurate and pleasant to chat with: + + repetition_penalty: helps the LLM avoid repeating the same short + phrase in the response over and over. + temperature: helps the LLM stay focused on the prompt. + do_sample: apply the temperature. + """ + + def __init__(self, model, dtype=torch.float32, device="cpu"): + super().__init__() + self.model = model + self.dtype = dtype + self.device = device + + def generate(self, input_ids, max_new_tokens=512, repetition_penalty=1.2, + do_sample=True, temperature=0.1, **kwargs): + amp_enabled = ( + True + if (self.dtype == torch.float16 or self.dtype == torch.bfloat16) + else False + ) + + # Move input_ids to the same device as the model + input_ids = input_ids.to(self.device) + + with torch.no_grad(), torch.inference_mode(), torch.cpu.amp.autocast( + enabled=amp_enabled, dtype=self.dtype + ): + return self.model.generate( + input_ids=input_ids, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + do_sample=do_sample, + temperature=temperature, + **kwargs + ) + + +class AdaptHuggingface(Tool): + """ + Apply specific settings to make Huggingface LLMs + more accurate and pleasant to chat with. + """ + + unique_name = "adapt-huggingface" + + def __init__(self): + super().__init__(monitor_message="Adapting Huggingface LLM") + + @staticmethod + def parser(add_help: bool = True) -> argparse.ArgumentParser: + parser = __class__.helpful_parser( + short_description="Apply accuracy-boosting settings to huggingface LLMs", + add_help=add_help, + ) + + return parser + + def run(self, state: State) -> State: + + state.model = HuggingfaceAdapter(state.model, state.dtype, state.device) + + return state diff --git a/src/turnkeyml/llm/tools/mmlu.py b/src/turnkeyml/llm/tools/mmlu.py new file mode 100644 index 00000000..2777fed3 --- /dev/null +++ b/src/turnkeyml/llm/tools/mmlu.py @@ -0,0 +1,270 @@ +import argparse +import os +import tarfile +from pathlib import Path +from typing import List, Optional +import tqdm +import numpy as np +import pandas as pd +import requests +from turnkeyml.state import State +from turnkeyml.tools import Tool +import turnkeyml.common.printing as printing +import turnkeyml.common.build as build + +# Constants +choices = ["A", "B", "C", "D"] +dataset_url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar" + + +class AccuracyMMLU(Tool): + """ + See docs/mmlu_accuracy.md for more details + """ + + unique_name = "accuracy-mmlu" + + def __init__(self): + super().__init__(monitor_message="Measuring accuracy with MMLU") + self.status_stats = [] + + @staticmethod + def parser(add_help: bool = True) -> argparse.ArgumentParser: + parser = __class__.helpful_parser( + short_description="Run accuracy benchmark using Massive Multitask " + "Language Understanding (MMLU) test", + add_help=add_help, + ) + + parser.add_argument( + "--ntrain", + type=int, + default=5, + help="Number of training examples to use. Default set to 5 for `5 Shot`", + ) + parser.add_argument( + "--data-dir", + type=str, + required=False, + help="Directory containing test and dev data (default: lemonade cache).", + ) + parser.add_argument( + "--tests", + nargs="+", + help=( + "Specific tests to run. 
For a single quick test, we suggest 'management'." + + "Default: run all tests." + ), + ) + return parser + + def run( + self, + state: State, + ntrain: int = 5, + data_dir: Optional[str] = None, + tests: List[str] = None, + ) -> State: + + if data_dir: + data_dir_to_use = data_dir + else: + data_dir_to_use = os.path.join(state.cache_dir, "data", "mmlu") + + # Setup MMLU dataset + dataset_dir = download_and_extract_dataset(data_dir_to_use, dataset_url) + + model_results_dir = os.path.join( + build.output_dir(state.cache_dir, state.build_name), "mmlu" + ) + os.makedirs(model_results_dir, exist_ok=True) + + tests_to_run = [ + f.replace("_test.csv", "") + for f in sorted(os.listdir(os.path.join(dataset_dir, "test"))) + if f.endswith("_test.csv") + ] + if tests is not None: + unsupported_tests = set(tests) - set(tests_to_run) + if unsupported_tests: + printing.log_warning( + "Warning: Unsupported tests specified and will be ignored:" + + f"{', '.join(unsupported_tests)}" + ) + tests_to_run = [test for test in tests if test in tests_to_run] + + tokenizer = state.tokenizer + model = state.model + + summary_data = [] + for subject in tqdm.tqdm(tests_to_run): + dev_df = _safe_read_csv( + os.path.join(dataset_dir, "dev", f"{subject}_dev.csv") + )[:ntrain] + test_df = _safe_read_csv( + os.path.join(dataset_dir, "test", f"{subject}_test.csv") + ) + + detailed_results, acc = _eval_model( + ntrain, subject, model, tokenizer, dev_df, test_df + ) + subject_results_df = pd.DataFrame(detailed_results) + subject_csv_path = os.path.join( + model_results_dir, f"{subject}_detailed_results.csv" + ) + subject_results_df.to_csv(subject_csv_path, index=False) + + # Update summary_data with total questions and correct answers + correct_answers_count = sum( + result["Correct"] for result in detailed_results + ) + summary_data.append( + { + "Subject": subject, + "Accuracy": acc, + "Total Questions": len(test_df), + "Correct Answers": correct_answers_count, + } + ) + + # Save accuracy results to stats file + # And display in the CLI + stat_name = f"mmlu_{subject}_accuracy" + stat_units_name = f"{stat_name}_units" + state.save_stat(stat_name, float(acc) * 100) + state.save_stat(stat_units_name, "%") + self.status_stats.append(stat_name) + + # Save accuracy results to CSV file + summary_df = pd.DataFrame(summary_data) + summary_df.to_csv( + os.path.join(model_results_dir, "summary_results.csv"), index=False + ) + return state + + +def _list_tests(data_dir): + """Lists all available tests based on the files in the test data directory.""" + test_files = [ + f for f in os.listdir(os.path.join(data_dir, "test")) if f.endswith("_test.csv") + ] + print( + "Available tests:", + *[f.replace("_test.csv", "") for f in sorted(test_files)], + sep="\n", + ) + + +def _format_subject(subject): + """Formats a subject string by replacing underscores with spaces.""" + return " ".join(subject.split("_")) + + +def _safe_read_csv(path): + """Safely reads a CSV file and returns a DataFrame.""" + try: + return pd.read_csv(path, header=None) + except FileNotFoundError: + printing.log_error(f"Error: File not found - {path}") + except Exception as e: # pylint: disable=broad-except + printing.log_error(f"An error occurred while reading {path}: {e}") + + +def _format_example(df, idx, include_answer=True): + """Formats an example from the dataframe into a prompt string.""" + prompt = df.iloc[idx, 0] + for j in range(1, df.shape[1] - 1): + prompt += f"\n{choices[j-1]}. 
{df.iloc[idx, j]}" + prompt += "\nAnswer_:" + if include_answer: + prompt += f" {df.iloc[idx, -1]}\n\n" + return prompt + + +def _gen_prompt(train_df, subject, k=-1): + """Generates a prompt string from multiple examples.""" + prompt = ( + "The following are multiple choice questions (with answers) about " + + f"{_format_subject(subject)}.\n\n" + ) + for i in range(min(k, train_df.shape[0]) if k != -1 else train_df.shape[0]): + prompt += _format_example(train_df, i) + return prompt + + +def _generate_response(tokenizer, model, input_ids): + """Generates a model response for the given input IDs.""" + try: + response = model.generate(input_ids, max_new_tokens=1) + return tokenizer.decode(response[0], skip_special_tokens=True).strip() + except Exception as e: # pylint: disable=broad-except + printing.log_warning(f"Error during model generation: {e}") + return "" # Return an empty string on failure + + +def download_and_extract_dataset(data_cache_dir: str, dataset_url: str): + """ + Download the dataset from the given URL and extract it into the target directory. + """ + + # Create the directory if it does not exist + Path(data_cache_dir).mkdir(parents=True, exist_ok=True) + + # Check if the data already exists to avoid re-downloading + if not os.listdir(data_cache_dir): # Checks if the directory is empty + printing.log_info(f"Downloading dataset to {data_cache_dir}") + + # Download the dataset + response = requests.get(dataset_url, stream=True) + if response.status_code == 200: + tar_path = os.path.join(data_cache_dir, "data.tar") + with open(tar_path, "wb") as f: + f.write(response.raw.read()) + + printing.log_info("Extracting dataset...") + # Extract the tar file + with tarfile.open(tar_path) as tar: + tar.extractall(path=data_cache_dir) + os.remove(tar_path) + printing.log_info("Dataset ready.") + else: + printing.log_info("Failed to download the dataset.") + else: + printing.log_info( + f"Dataset already exists in {data_cache_dir}, skipping download." 
+ ) + + # MMLU data is stored in data.tar/data + return os.path.join(data_cache_dir, "data") + + +def _eval_model(ntrain, subject, model, tokenizer, dev_df, test_df): + """Evaluates the model on the test data for a given subject.""" + detailed_results = [] + + for i in range(test_df.shape[0]): + prompt = _gen_prompt(dev_df, subject, ntrain) + _format_example( + test_df, i, include_answer=False + ) + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + response_text = _generate_response(tokenizer, model, input_ids) + try: + pred_label = response_text[-1].upper() + # Handle models generating empty outputs + except IndexError: + pred_label = "-" + + label = test_df.iloc[i, -1].strip().upper() + detailed_results.append( + { + "Question": test_df.iloc[i, 0], + "Prompt": prompt, + "Correct Answer": label, + "Generated Answer": pred_label, + "Correct": pred_label == label, + } + ) + + acc = np.mean([res["Correct"] for res in detailed_results]) + return detailed_results, acc diff --git a/src/turnkeyml/llm/tools/ort_genai/__init__.py b/src/turnkeyml/llm/tools/ort_genai/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/turnkeyml/llm/tools/ort_genai/oga.py b/src/turnkeyml/llm/tools/ort_genai/oga.py new file mode 100644 index 00000000..3aa78e06 --- /dev/null +++ b/src/turnkeyml/llm/tools/ort_genai/oga.py @@ -0,0 +1,279 @@ +# onnxruntime_genai is not lint-friendly yet and PyLint can't +# find any of the class methods +# pylint: disable=no-member + +import argparse +import os +import time +from queue import Queue +import onnxruntime_genai as og +from turnkeyml.state import State +from turnkeyml.tools import FirstTool +import turnkeyml.common.status as status +from turnkeyml.llm.tools.adapter import ( + ModelAdapter, + TokenizerAdapter, + PassthroughTokenizerResult, +) +from turnkeyml.llm.cache import Keys + + +class OrtGenaiTokenizer(TokenizerAdapter): + def __init__(self, model: og.Model): + # Initialize the tokenizer and produce the initial tokens. 
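+        # og.Tokenizer is built from an og.Model instance rather than from a
+        # checkpoint path, so OgaLoad constructs this adapter only after the
+        # model itself has been loaded.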
+ self.tokenizer = og.Tokenizer(model) + # Placeholder value since some code will try to query it + # If we actually need this to return a proper value, then + # og.GeneratorParams.eos_token_id has it + self.eos_token_id = None + + def __call__(self, prompt: str, return_tensors="np"): + tokens = self.tokenizer.encode(prompt) + + return PassthroughTokenizerResult(tokens) + + # onnxruntime_genai's tokenizer doesn't support any arguments + # yet, so we just ignore skip_special_tokens and hope it + # doesn't have a major negative effect + # pylint: disable=unused-argument + def decode(self, response, skip_special_tokens=True) -> str: + return self.tokenizer.decode(response) + + +class OrtGenaiStreamer: + def __init__(self, tokenizer: OrtGenaiTokenizer, timeout=None): + self.tokenizer = tokenizer + self.text_queue = Queue() + self.stop_signal = None + self.timeout = timeout + + def add_text(self, text: str): + self.text_queue.put(text, timeout=self.timeout) + + def done(self): + self.text_queue.put(self.stop_signal, timeout=self.timeout) + + def __iter__(self): + return self + + def __next__(self): + value = self.text_queue.get(timeout=self.timeout) + if value == self.stop_signal: + raise StopIteration() + else: + return value + + +class OrtGenaiModel(ModelAdapter): + + def __init__(self, input_folder): + super().__init__() + self.model = og.Model(input_folder) + self.type = "ort-genai" + + def generate( + self, + input_ids, + max_new_tokens=512, + do_sample=True, + top_k=50, + top_p=1.0, + temperature=0.7, + streamer: OrtGenaiStreamer = None, + pad_token_id=None, + ): + params = og.GeneratorParams(self.model) + + if pad_token_id: + params.pad_token_id = pad_token_id + + max_length = len(input_ids) + max_new_tokens + + params.input_ids = input_ids + params.set_search_options( + do_sample=do_sample, + top_k=top_k, + top_p=top_p, + temperature=temperature, + max_length=max_length, + min_length=max_length, + ) + params.try_graph_capture_with_max_batch_size(1) + + generator = og.Generator(self.model, params) + + if streamer is None: + prompt_start_time = time.perf_counter() + generator.compute_logits() + generator.generate_next_token() + prompt_end_time = time.perf_counter() + + self.time_to_first_token = prompt_end_time - prompt_start_time + + if max_new_tokens > 1: + + token_gen_times = [] + while not generator.is_done(): + token_gen_start_time = time.perf_counter() + generator.compute_logits() + generator.generate_next_token() + token_gen_end_time = time.perf_counter() + + token_gen_times.append(token_gen_end_time - token_gen_start_time) + + if token_gen_times: + # List will be empty if we generated 1 or 0 tokens, and we don't + # want a divide-by-zero error in those cases + avg_token_gen_latency_s = sum(token_gen_times) / len( + token_gen_times + ) + self.tokens_per_second = 1 / avg_token_gen_latency_s + + return [generator.get_sequence(0)] + else: + tokenizer_stream = streamer.tokenizer.tokenizer.create_stream() + while not generator.is_done(): + generator.compute_logits() + generator.generate_next_token() + + new_token = generator.get_next_tokens()[0] + new_text = tokenizer_stream.decode(new_token) + + streamer.add_text(new_text) + + streamer.add_text("") + streamer.done() + + +# Short names for checkpoints +# So that we don't violate pylint line lengths :) +llama_3 = "meta-llama/Meta-Llama-3-8B" +llama_2 = "meta-llama/Llama-2-7b-chat-hf" +phi_3_mini_4k = "microsoft/Phi-3-mini-4k-instruct" +phi_3_mini_128k = "microsoft/Phi-3-mini-128k-instruct" +qwen_1dot5 = "Qwen/Qwen1.5-7B" + + +class 
OgaLoad(FirstTool): + """ + Tool that loads an LLM in OnnxRuntime-GenAI for use with DirectML. + + Input: path to a checkpoint. Supported choices: + llama_3 = "meta-llama/Meta-Llama-3-8B" + llama_2 = "meta-llama/Llama-2-7b-chat-hf" + phi_3_mini_4k = "microsoft/Phi-3-mini-4k-instruct" + phi_3_mini_128k = "microsoft/Phi-3-mini-128k-instruct" + + Output: + state.model: handle to a Huggingface-style LLM loaded on DirectML device + state.tokenizer = Huggingface-style LLM tokenizer instance + state.dtype = data type of the model on DirectML device + + Note: This tool expects the onnxruntime-genai-directml library to be pre-installed. + If that library is not installed, this tool will not load. + """ + + unique_name = "oga-load" + + def __init__(self): + super().__init__(monitor_message="Loading OnnxRuntime-GenAI model") + + self.status_stats = [Keys.DTYPE, Keys.DEVICE] + + @staticmethod + def parser(add_help: bool = True) -> argparse.ArgumentParser: + parser = __class__.helpful_parser( + short_description="Load model in onnxruntime-genai (OGA)", + add_help=add_help, + ) + + parser.add_argument( + "-d", + "--device", + choices=["igpu", "npu"], + default="igpu", + help="Which device to load the model on to (default: igpu)", + ) + + parser.add_argument( + "--dtype", + choices=["int4"], + required=True, + help="Data type to load the model in", + ) + + return parser + + def run( + self, + state: State, + input: str = phi_3_mini_128k, + device: str = "igpu", + dtype: str = "int4", + ) -> State: + + checkpoint = input + + # Map of models[device][dtype][checkpoint] to the name of the model folder on disk + supported_models = { + "igpu": { + "int4": { + phi_3_mini_128k: os.path.join( + "phi-3-mini-128k-instruct", + "directml", + "directml-int4-awq-block-128", + ), + phi_3_mini_4k: os.path.join( + "phi-3-mini-4k-instruct", + "directml", + "directml-int4-awq-block-128", + ), + }, + }, + "npu": { + "int4": { + llama_2: "llama2-7b-int4", + llama_3: "llama3-8b-int4", + qwen_1dot5: "qwen1.5-7b-int4", + } + }, + } + + try: + dir_name = supported_models[device][dtype][checkpoint] + except KeyError as e: + raise ValueError( + "The device;dtype;checkpoint combination is not supported: " + f"{device};{dtype};{checkpoint}. 
The supported combinations " + f"are: {supported_models}" + ) from e + + model_dir = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "models", + dir_name, + ) + + # The NPU requires the CWD to be in the model folder + current_cwd = os.getcwd() + if device == "npu": + os.chdir(model_dir) + # Required environment variable for NPU + os.environ["DOD_ROOT"] = ".\\bins" + + state.model = OrtGenaiModel(model_dir) + state.tokenizer = OrtGenaiTokenizer(state.model.model) + state.dtype = dtype + + state.save_stat(Keys.CHECKPOINT, checkpoint) + state.save_stat(Keys.DTYPE, dtype) + state.save_stat(Keys.DEVICE, device) + + # Create a UniqueInvocationInfo and ModelInfo so that we can display status + # at the end of the sequence + status.add_to_state(state=state, name=input, model=input) + + # Put the CWD back to its original value + os.chdir(current_cwd) + + return state diff --git a/src/turnkeyml/llm/tools/perplexity.py b/src/turnkeyml/llm/tools/perplexity.py new file mode 100644 index 00000000..dd3ebf48 --- /dev/null +++ b/src/turnkeyml/llm/tools/perplexity.py @@ -0,0 +1,144 @@ +import os +import argparse +import pandas as pd +import torch +from datasets import load_dataset +from tqdm import tqdm +from turnkeyml.state import State +from turnkeyml.tools import Tool +import turnkeyml.common.printing as printing +import turnkeyml.common.build as build + + +class AccuracyPerplexity(Tool): + """ + Measure perplexity of an LLM using the wikitext dataset. + Refer to docs/perplexity.md for more details. + + Required input state: + - state.model: instance that provides a __call__() method that returns + output.logits and supports model.config.max_position_embeddings + - state.tokenizer: instance of Hugging Face PretrainedTokenizer + + Output state produced: None + + See docs/perplexity.md for more details. 
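+
+    Example CLI invocation (an illustrative sketch, assuming this package is
+    installed so that the lemonade CLI and its tools are available):
+
+        lemonade -i facebook/opt-125m huggingface-load accuracy-perplexity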
+ """ + + unique_name = "accuracy-perplexity" + + def __init__(self): + super().__init__(monitor_message="Measuring perplexity") + + @staticmethod + def parser(add_help: bool = True) -> argparse.ArgumentParser: + parser = __class__.helpful_parser( + short_description="Measure Perplexity score using Wikitext-2 dataset", + add_help=add_help, + ) + return parser + + def run( + self, + state: State, + ) -> State: + + try: + printing.log_info("Downloading dataset ...") + dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + except Exception as e: # pylint: disable=broad-except + printing.log_error(f"Error during dataset load: {e}") + raise e + + tokenizer = state.tokenizer + model = state.model + # Tokenize the entire test dataset text, joining entries with double new lines + encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt") + + # Retrieve the maximum input length that the model can handle + try: + max_length = model.config.max_position_embeddings + except AttributeError: + # Some LLMs do not have the config.max_position_embeddings attribute + # However, most LLMs support at least 2048 context length, so this + # try-except will allow a few more LLMs to work + max_length = 2048 + # Set stride to half of the maximum input length for overlapping window processing + # Refer to docs/perplexity.md for more information on sliding window + stride = max_length // 2 + # Determine the total sequence length of the tokenized input + seq_len = encodings.input_ids.size(1) + + negative_log_likelihoods = [] + summary_data = [] + prev_end_location = 0 + + model_results_dir = os.path.join( + build.output_dir(state.cache_dir, state.build_name), "perplexity" + ) + + for begin_location in tqdm(range(0, seq_len, stride)): + end_location = min(begin_location + max_length, seq_len) + target_len = end_location - prev_end_location + input_ids = encodings.input_ids[:, begin_location:end_location] + target_ids = input_ids.clone() + target_ids[:, :-target_len] = -100 + + # Forward pass the model to get logits + with torch.no_grad(): + try: + outputs = model(input_ids, labels=target_ids) + logits = outputs.logits + except Exception as e: # pylint: disable=broad-except + printing.log_error( + f"Error during model forward pass execution: {e}" + ) + + # Compute loss manually for visualization + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = target_ids[..., 1:].contiguous() + effective_token_count = (target_ids != -100).sum().item() + negative_log_likelihoods.append( + (outputs.loss.item(), effective_token_count) + ) + + # Decode predicted and actual next words for the last token position + predictions = torch.argmax(shift_logits, dim=-1) + predicted_tokens = predictions[:, -1] + actual_tokens = shift_labels[:, -1] + + predicted_words = tokenizer.batch_decode( + predicted_tokens, skip_special_tokens=True + ) + actual_words = tokenizer.batch_decode( + actual_tokens, skip_special_tokens=True + ) + context = tokenizer.decode(input_ids[0, :]) + + summary_data.append( + { + "Context": context[-stride:], + "Predicted next word": predicted_words, + "Actual next word": actual_words, + "Loss for this window": outputs.loss.item(), + } + ) + prev_end_location = end_location + + # Total loss calculation considering the number of tokens for each segment + total_loss = sum(loss * count for loss, count in negative_log_likelihoods) + total_tokens = sum(count for _, count in negative_log_likelihoods) + + # Calculate average negative_log_likelihood and perplexity + 
average_negative_log_likelihood = total_loss / total_tokens + perplexity = torch.exp(torch.tensor(average_negative_log_likelihood)) + + # Save accuracy results to stats file + state.save_stat("perplexity_score", float(perplexity.item())) + + # Save accuracy results to CSV file + summary_df = pd.DataFrame(summary_data) + summary_df.to_csv( + os.path.join(model_results_dir, "summary_results.csv"), index=False + ) + return state diff --git a/src/turnkeyml/llm/tools/ryzenai_npu/__init__.py b/src/turnkeyml/llm/tools/ryzenai_npu/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/turnkeyml/llm/tools/ryzenai_npu/ryzenai_npu.py b/src/turnkeyml/llm/tools/ryzenai_npu/ryzenai_npu.py new file mode 100644 index 00000000..491b5687 --- /dev/null +++ b/src/turnkeyml/llm/tools/ryzenai_npu/ryzenai_npu.py @@ -0,0 +1,253 @@ +import os +import argparse +import torch +from transformers import ( + LlamaForCausalLM, + LlamaTokenizer, + AutoTokenizer, + PreTrainedTokenizerFast, +) +from ryzenai_llm_engine import RyzenAILLMEngine, TransformConfig +from ryzenai_llm_quantizer import QuantConfig, RyzenAILLMQuantizer +from modeling_phi3 import Phi3ForCausalLM +from turnkeyml.state import State +from turnkeyml.tools import FirstTool +from turnkeyml.llm.tools.adapter import ModelAdapter +from turnkeyml.llm.cache import Keys + +npu_root_dir = os.path.dirname(__file__) +quantized_models_path = os.path.join(npu_root_dir, "quantized_models") +if not os.path.exists(quantized_models_path): + os.mkdir(quantized_models_path) + + +class LlamaModelEval(LlamaForCausalLM): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_name = "llama-2-7b-chat" + self.tokenizer = None + + def forward(self, *args, **kwargs): + outputs = super().forward(*args, **kwargs) # pylint: disable=no-member + return outputs + + +class Phi3ModelEval(Phi3ForCausalLM): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_name = "phi-3-mini-4k-instruct" + self.tokenizer = None + + def forward(self, *args, **kwargs): + outputs = super().forward(*args, **kwargs) + return outputs + + def get_position_embeddings(self): + raise NotImplementedError( + f"`get_position_embeddings` is not implemented for {self.__class__}`. " + f"To implement it, you should overwrite this method in the class {self.__class__} " + f"in `modeling_{self.__class__.__module__}.py`" + ) + + def resize_position_embeddings(self, new_num_position_embeddings: int): + raise NotImplementedError( + f"`resize_position_embeddings` is not implemented for {self.__class__}`." + f"To implement it, you should overwrite this method in the class {self.__class__} " + f"in `modeling_{self.__class__.__module__}.py`" + ) + + +class RyzenAiModel(ModelAdapter): + """ + RyzenAI NPU models require an attention_mask of all 1's to be passed + as input to generate. This class exists for the purpose of inserting + that attention mask. + """ + + def __init__(self, model): + super().__init__() + self.model = model + + # pylint: disable=arguments-differ + def generate(self, input_ids, **kwargs): + attention_mask = torch.ones(input_ids.shape) + return self.model.generate( + input_ids=input_ids, attention_mask=attention_mask, **kwargs + ) + + def __getattr__(self, name): + """ + Forward all attribute access to self.model. + """ + return getattr(self.model, name) + + +class RyzenAINPULoad(FirstTool): + """ + Tool that loads an LLM checkpoint on to a RyzenAI NPU. + + Input: the name or path to a checkpoint. 
Supported options: + "TheBloke/Llama-2-7b-Chat-fp16" + "meta-llama/Llama-2-7b-chat-hf" + "microsoft/Phi-3-mini-4k-instruct" + "meta-llama/Meta-Llama-3-8B-Instruct" + "meta-llama/Meta-Llama-3-8B" + + Output: + state.model: handle to a Huggingface-style LLM loaded on NPU + state.tokenizer = Huggingface-style LLM tokenizer instance + state.dtype = data type of the model on NPU + + Note: This tool expects the ryzenai-transformers library to be pre-installed. + If that library is not installed, this tool will not load. + """ + + unique_name = "ryzenai-npu-load" + + def __init__(self): + super().__init__(monitor_message="Loading LLM on RyzenAI NPU") + + self.status_stats = [Keys.DTYPE] + + @staticmethod + def parser(add_help: bool = True) -> argparse.ArgumentParser: + parser = __class__.helpful_parser( + short_description="Quantize and transform a model using AWQ \ + in int4 format in RyzenAI NPU", + add_help=add_help, + ) + + parser.add_argument("--device", required=True, choices=["phx", "stx"]) + + return parser + + # pylint: disable=C0103 + def run(self, state: State, input: str = "", device=None) -> State: + + checkpoint = input + + w_bit = 4 + group_size = 128 + + if ( + checkpoint == "TheBloke/Llama-2-7b-Chat-fp16" + or checkpoint == "meta-llama/Llama-2-7b-chat-hf" + ): + model_name = "llama-2-7b-chat" + algorithm = "awqplus" + flash_attention_plus = False + trust_remote_code = False + CausalLMModel = LlamaModelEval + LMTokenizer = LlamaTokenizer + quantized_model_path = os.path.join( + quantized_models_path, + f"quantized_llama-2-7b-chat_w{w_bit}_g{group_size}_{algorithm}.pth", + ) + + elif ( + checkpoint == "meta-llama/Meta-Llama-3-8B-Instruct" + or checkpoint == "meta-llama/Meta-Llama-3-8B" + ): + model_name = checkpoint.replace("meta-llama/", "") + algorithm = "awqplus" + flash_attention_plus = False + trust_remote_code = False + CausalLMModel = LlamaModelEval + LMTokenizer = PreTrainedTokenizerFast + quantized_model_path = os.path.join( + quantized_models_path, + f"quantized_{model_name}_w{w_bit}_g{group_size}_{algorithm}.pth", + ) + + elif checkpoint == "microsoft/Phi-3-mini-4k-instruct": + model_name = "phi-3-mini-4k-instruct" + algorithm = "pergrp" + flash_attention_plus = False + trust_remote_code = True + CausalLMModel = Phi3ModelEval + LMTokenizer = AutoTokenizer + + quantized_model_path = os.path.join( + quantized_models_path, + f"quantized_Phi-3-mini-4k-instruct_w{w_bit}_g{group_size}_{algorithm}.pth", + ) + + else: + raise ValueError(f"Model {checkpoint} is not a supported model.") + + if not os.path.exists(quantized_model_path): + + model = CausalLMModel.from_pretrained( + checkpoint, + torch_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + attn_implementation="eager", + ) + + model.tokenizer = LMTokenizer.from_pretrained( + checkpoint, trust_remote_code=trust_remote_code + ) + + quant_config = QuantConfig( + quant_mode=algorithm, + model_name=checkpoint, + dataset="raw", + w_bit=w_bit, + group_size=group_size, + use_qscales=True, + ) + + model = RyzenAILLMQuantizer.quantize(model, quant_config=quant_config) + torch.save(model, quantized_model_path) + else: + model = torch.load(quantized_model_path) + + if device == "phx": + fast_attention = False + elif device == "stx": + fast_attention = True + else: + raise Exception(f"Use a supported device instead of {device}") + + # Different library versions support different flags + # We maintain a safe set of flags and a cutting-edge set of flags, + # and attempt each + try: + transform_config = TransformConfig( + 
flash_attention_plus=flash_attention_plus, + fast_attention=fast_attention, + fast_mlp=device != "phx", + fast_norm=device != "phx", + precision="w4abf16", + model_name=model_name, + target="aie", + w_bit=w_bit, + group_size=group_size, + profilegemm=False, + ) + except TypeError: + transform_config = TransformConfig( + flash_attention_plus=False, + fast_attention=False, + fast_mlp=False, + precision="w4abf16", + model_name=model_name, + target="aie", + w_bit=w_bit, + group_size=group_size, + profilegemm=False, + ) + + model = RyzenAILLMEngine.transform(model, transform_config) + model = model.to(torch.bfloat16) + model.eval() + + state.model = RyzenAiModel(model) + state.tokenizer = model.tokenizer + state.dtype = "int4" + + state.save_stat(Keys.CHECKPOINT, checkpoint) + state.save_stat(Keys.DEVICE, "ryzenai-npu") + state.save_stat(Keys.DTYPE, "int4") + + return state diff --git a/src/turnkeyml/version.py b/src/turnkeyml/version.py index ce1305bf..76ad18b8 100644 --- a/src/turnkeyml/version.py +++ b/src/turnkeyml/version.py @@ -1 +1 @@ -__version__ = "4.0.0" +__version__ = "4.0.1" diff --git a/test/llm_api.py b/test/llm_api.py new file mode 100644 index 00000000..28ed5bbe --- /dev/null +++ b/test/llm_api.py @@ -0,0 +1,57 @@ +import unittest +import shutil +import os +from turnkeyml.state import State +import turnkeyml.common.filesystem as fs +import turnkeyml.common.test_helpers as common +from turnkeyml.llm.tools.huggingface_load import HuggingfaceLoad +from turnkeyml.llm.tools.mmlu import AccuracyMMLU +from turnkeyml.llm.tools.chat import LLMPrompt + +ci_mode = os.getenv("LEMONADE_CI_MODE", False) + + +class Testing(unittest.TestCase): + def setUp(self) -> None: + shutil.rmtree(cache_dir, ignore_errors=True) + + def test_001_prompt(self): + """ + Test the LLM Prompt tool + """ + + checkpoint = "facebook/opt-125m" + prompt = "my solution to the test is" + + state = State( + cache_dir=cache_dir, + build_name="test", + ) + + state = HuggingfaceLoad().run(state, input=checkpoint) + state = LLMPrompt().run(state, prompt=prompt, max_new_tokens=15) + + assert len(state.response) > len(prompt), state.response + + def test_002_accuracy_mmlu(self): + # Test MMLU benchmarking with known model + checkpoint = "facebook/opt-125m" + subject = ["management"] + + state = State( + cache_dir=cache_dir, + build_name="test", + ) + + state = HuggingfaceLoad().run(state, input=checkpoint) + state = AccuracyMMLU().run(state, ntrain=5, tests=subject) + + stats = fs.Stats(state.cache_dir, state.build_name).stats + assert stats[f"mmlu_{subject[0]}_accuracy"] > 0 + + + + +if __name__ == "__main__": + cache_dir, _ = common.create_test_dir("lemonade_api") + unittest.main()
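
Usage note: the new evaluation tools compose through the same State-passing API that test/llm_api.py exercises. The snippet below is a minimal sketch rather than code taken from this patch; the cache directory and build name are placeholder values, and the MMLU and wikitext data are downloaded on first use.

    import os

    from turnkeyml.state import State
    from turnkeyml.llm.tools.huggingface_load import HuggingfaceLoad
    from turnkeyml.llm.tools.mmlu import AccuracyMMLU
    from turnkeyml.llm.tools.perplexity import AccuracyPerplexity

    # Placeholder cache location for this sketch; any writable directory works
    cache_dir = os.path.expanduser("~/.cache/lemonade_demo")

    # Each tool consumes and returns the shared State object
    state = State(cache_dir=cache_dir, build_name="opt_125m_eval")
    state = HuggingfaceLoad().run(state, input="facebook/opt-125m")
    state = AccuracyMMLU().run(state, ntrain=5, tests=["management"])
    state = AccuracyPerplexity().run(state)

The MMLU step can also be driven from the command line by chaining tool names, for example: lemonade -i facebook/opt-125m huggingface-load accuracy-mmlu --tests management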